Spaces:
Runtime error
Runtime error
| import numpy as np | |
| import pandas as pd | |
| import xgboost as xgb | |
| from xgboost import cv | |
| from sklearn.model_selection import train_test_split | |
def prepare_data(monetary_decision):
    """Build the predictor table and target series used for modelling.

    Adds a ``time`` feature holding the number of months (rounded to the
    nearest whole month) between each ``decision_date`` and the earliest
    ``decision_date`` in the frame, then casts every categorical predictor
    to the pandas ``category`` dtype.

    Args:
        monetary_decision: DataFrame containing ``decision_date``,
            ``log10_monetary_sanction`` and every column listed in
            ``col_num`` / ``col_cat`` below. The caller's frame is not
            modified; work is done on a re-indexed copy.

    Returns:
        A ``(predictors, target)`` tuple: ``predictors`` is a new DataFrame
        with the numeric columns followed by the categorical ones, and
        ``target`` is the ``log10_monetary_sanction`` column. Both carry a
        fresh 0..n-1 index.
    """
    # Mean Gregorian month (365.2425 days / 12) expressed in seconds. This is
    # the same length NumPy assigns to timedelta64(1, "M"), but spelled as an
    # unambiguous duration: pandas 2.x rejects the "M" unit in timedelta
    # arithmetic.
    seconds_per_month = 2629746

    monetary_decision = monetary_decision.reset_index(drop=True)

    elapsed = (monetary_decision["decision_date"]
               - monetary_decision["decision_date"].min())
    monetary_decision["time"] = (
        elapsed / pd.Timedelta(seconds=seconds_per_month)
    ).round()

    col_num = ['log10_org_revenues',
               'time']
    col_cat = ['authorities_country',
               'type',
               'violation_theme',
               'justice_type',
               'org_country',
               'org_currency',
               'org_continent',
               'same_country',
               'org_company_type']

    # astype() returns a new frame, so the categorical casts never write into
    # a slice of ``monetary_decision`` (no SettingWithCopyWarning).
    predictors = monetary_decision[col_num + col_cat].astype(
        {col: "category" for col in col_cat}
    )
    target = monetary_decision["log10_monetary_sanction"]
    return predictors, target
def split(predictors, target):
    """Split predictors and target into train and test partitions.

    Delegates to ``train_test_split`` with 20% of the rows held out for
    testing and a fixed ``random_state`` of 42.

    Args:
        predictors: Feature table to split.
        target: Target values aligned with ``predictors``.

    Returns:
        A 4-tuple ``(predictors_train, predictors_test, target_train,
        target_test)``.
    """
    partitions = train_test_split(predictors, target, test_size=0.2, random_state=42)
    train_x, test_x, train_y, test_y = partitions
    return train_x, test_x, train_y, test_y
def run_training(predictors_train, predictors_test):
    """Cross-validate, print the CV history, and fit an XGBoost regressor.

    Args:
        predictors_train: Feature table wrapped in the training ``DMatrix``.
        predictors_test: Passed to ``xgb.DMatrix`` as ``label``.
            NOTE(review): despite its name, this argument is used as the
            training labels -- confirm callers pass the training target here.

    Returns:
        The booster returned by ``xgb.train`` after ``num_round`` rounds.
    """
    num_round = 1000
    booster_params = {
        'max_depth': 4,
        'learning_rate': 0.05,
        'colsample_bytree': 0.3,
        'subsample': 0.8,
        'gamma': 0.5,
        'objective': 'reg:squarederror',
    }

    dtrain = xgb.DMatrix(predictors_train, label=predictors_test, enable_categorical=True)

    # 3-fold CV with early stopping. The resulting history is only printed.
    # NOTE(review): the final fit below always runs the full ``num_round``
    # rounds and does not use the early-stopped round count -- confirm this
    # is intended.
    cv_history = cv(
        dtrain=dtrain,
        params=booster_params,
        nfold=3,
        num_boost_round=num_round,
        early_stopping_rounds=10,
        metrics="rmse",
        as_pandas=True,
        seed=123,
    )
    print(cv_history)

    booster = xgb.train(booster_params, dtrain, num_round)
    return booster
def predict(model, predictors):
    """Return the model's predictions for a table of predictors.

    Args:
        model: Object exposing ``predict(DMatrix)``, such as the booster
            returned by ``run_training``.
        predictors: Feature table; wrapped in an ``xgb.DMatrix`` with
            categorical support enabled before prediction.

    Returns:
        Whatever ``model.predict`` returns for the wrapped data.
    """
    dmatrix = xgb.DMatrix(predictors, enable_categorical=True)
    predictions = model.predict(dmatrix)
    return predictions
def features_importance(model):
    """Return per-feature gain importance, sorted in ascending order.

    Args:
        model: Object exposing ``get_score(importance_type=...)`` that
            returns a mapping of feature name to score.

    Returns:
        A pandas Series indexed by feature name, holding the ``'gain'``
        scores sorted from smallest to largest.
    """
    gain_by_feature = model.get_score(importance_type='gain')
    importance = pd.Series(gain_by_feature)
    return importance.sort_values()