Jawad committed on
Commit
b0e8abd
·
1 Parent(s): 1678afb

add cross validation

Browse files
Files changed (3) hide show
  1. model.py +33 -16
  2. requirements.txt +19 -1
  3. stream_app.py +11 -3
model.py CHANGED
@@ -1,3 +1,4 @@
 
1
  import numpy as np
2
  import pandas as pd
3
  import xgboost as xgb
@@ -29,24 +30,40 @@ def prepare_data(monetary_decision):
29
 
30
  def split(predictors, target):
31
  predictors_train, predictors_test, target_train, target_test = train_test_split(predictors,
32
- target,
33
- test_size=0.2,
34
- random_state=42)
35
  return predictors_train, predictors_test, target_train, target_test
36
 
37
 
38
- def run_training(predictors_train, predictors_test):
39
- data_train = xgb.DMatrix(predictors_train, label=predictors_test, enable_categorical=True)
40
- params = {'max_depth': 4,
41
- 'learning_rate': 0.05,
42
- 'colsample_bytree': 0.3,
43
- 'subsample': 0.8,
44
- 'gamma': 0.5,
45
- 'objective': 'reg:squarederror'}
46
- num_round = 1000
47
- #xgb_cv = cv(dtrain=data_train, params=params, nfold=3,
48
- # num_boost_round=1000, early_stopping_rounds=10, metrics="rmse", as_pandas=True, seed=123)
49
- return xgb.train(params, data_train, num_round)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
50
 
51
 
52
  def predict(model, predictors):
@@ -55,4 +72,4 @@ def predict(model, predictors):
55
 
56
 
57
  def features_importance(model):
58
- return pd.Series(model.get_score(importance_type='gain')).sort_values()
 
1
+ import itertools
2
  import numpy as np
3
  import pandas as pd
4
  import xgboost as xgb
 
30
 
31
  def split(predictors, target):
32
  predictors_train, predictors_test, target_train, target_test = train_test_split(predictors,
33
+ target,
34
+ test_size=0.2,
35
+ random_state=42)
36
  return predictors_train, predictors_test, target_train, target_test
37
 
38
 
39
+ def run_cv_training(predictors_train, target_train):
40
+ data_train = xgb.DMatrix(predictors_train, label=target_train, enable_categorical=True)
41
+ xgb_csv = []
42
+ best_params = (100, {}, 10)
43
+ for eta, max_depth, col_num in itertools.product([0.05, 0.01], [10, 15], [0.3, 0.8]):
44
+ prefix = f"{str(eta)}_{str(max_depth)}_{str(col_num)}"
45
+ params = {
46
+ 'learning_rate': eta,
47
+ 'max_depth': max_depth,
48
+ 'colsample_bytree': col_num,
49
+ #'gamma': 0.5,
50
+ 'subsample': 0.8,
51
+ 'objective': 'reg:squarederror'}
52
+ cv_results = cv(dtrain=data_train, params=params, nfold=2,
53
+ num_boost_round=1000, early_stopping_rounds=3, metrics="rmse", as_pandas=True, seed=123)
54
+ best_value = cv_results['test-rmse-mean'].values[-1]
55
+ best_round = cv_results.index[-1]
56
+ xgb_csv.append(
57
+ cv_results.rename(columns={col: f'{prefix}_{col}' for col in cv_results.columns}).tail(10).reset_index())
58
+ if best_value < best_params[0]:
59
+ best_params = (best_value, params, best_round)
60
+
61
+ return pd.concat(xgb_csv, axis=1), best_params
62
+
63
+
64
+ def run_training(predictors_train, target_train, params, num_rounds):
65
+ data_train = xgb.DMatrix(predictors_train, label=target_train, enable_categorical=True)
66
+ return xgb.train(params, data_train, num_rounds)
67
 
68
 
69
  def predict(model, predictors):
 
72
 
73
 
74
  def features_importance(model):
75
+ return pd.Series(model.get_score(importance_type='gain')).sort_values()
requirements.txt CHANGED
@@ -12,6 +12,7 @@ certifi==2021.5.30
12
  cffi==1.14.6
13
  charset-normalizer==2.0.6
14
  click==7.1.2
 
15
  cycler==0.10.0
16
  debugpy==1.5.0
17
  decorator==5.1.0
@@ -20,12 +21,14 @@ entrypoints==0.3
20
  gitdb==4.0.7
21
  GitPython==3.1.24
22
  idna==3.2
 
23
  ipykernel==6.4.1
24
  ipython==7.28.0
25
  ipython-genutils==0.2.0
26
  ipywidgets==7.6.5
27
  jedi==0.18.0
28
  Jinja2==3.0.2
 
29
  jsonschema==4.0.1
30
  jupyter-client==7.0.6
31
  jupyter-core==4.8.1
@@ -51,31 +54,45 @@ pexpect==4.8.0
51
  pickleshare==0.7.5
52
  Pillow==8.3.2
53
  plotly==5.3.1
 
 
54
  prometheus-client==0.11.0
55
  prompt-toolkit==3.0.20
56
  protobuf==3.18.1
57
  ptyprocess==0.7.0
 
58
  pyarrow==5.0.0
 
 
59
  pycparser==2.20
60
  pydeck==0.7.0
61
  Pygments==2.10.0
 
62
  pyparsing==2.4.7
63
  pyrsistent==0.18.0
 
 
 
64
  python-dateutil==2.8.2
65
  pytz==2021.3
66
  pyzmq==22.3.0
 
67
  requests==2.26.0
 
68
  scipy==1.7.1
69
  seaborn==0.11.2
70
  Send2Trash==1.8.0
71
  six==1.16.0
 
72
  smmap==4.0.0
73
  statsmodels==0.13.0
74
- streamlit==0.89.0
75
  tenacity==8.0.1
76
  terminado==0.12.1
77
  testpath==0.5.0
 
78
  toml==0.10.2
 
79
  toolz==0.11.1
80
  tornado==6.1
81
  traitlets==5.1.0
@@ -87,3 +104,4 @@ watchdog==2.1.6
87
  wcwidth==0.2.5
88
  webencodings==0.5.1
89
  widgetsnbextension==3.5.1
 
 
12
  cffi==1.14.6
13
  charset-normalizer==2.0.6
14
  click==7.1.2
15
+ coverage==6.1.1
16
  cycler==0.10.0
17
  debugpy==1.5.0
18
  decorator==5.1.0
 
21
  gitdb==4.0.7
22
  GitPython==3.1.24
23
  idna==3.2
24
+ iniconfig==1.1.1
25
  ipykernel==6.4.1
26
  ipython==7.28.0
27
  ipython-genutils==0.2.0
28
  ipywidgets==7.6.5
29
  jedi==0.18.0
30
  Jinja2==3.0.2
31
+ joblib==1.1.0
32
  jsonschema==4.0.1
33
  jupyter-client==7.0.6
34
  jupyter-core==4.8.1
 
54
  pickleshare==0.7.5
55
  Pillow==8.3.2
56
  plotly==5.3.1
57
+ pluggy==1.0.0
58
+ pprintpp==0.4.0
59
  prometheus-client==0.11.0
60
  prompt-toolkit==3.0.20
61
  protobuf==3.18.1
62
  ptyprocess==0.7.0
63
+ py==1.10.0
64
  pyarrow==5.0.0
65
+ pycountry==20.7.3
66
+ pycountry-convert==0.7.2
67
  pycparser==2.20
68
  pydeck==0.7.0
69
  Pygments==2.10.0
70
+ Pympler==0.9
71
  pyparsing==2.4.7
72
  pyrsistent==0.18.0
73
+ pytest==6.2.5
74
+ pytest-cov==3.0.0
75
+ pytest-mock==3.6.1
76
  python-dateutil==2.8.2
77
  pytz==2021.3
78
  pyzmq==22.3.0
79
+ repoze.lru==0.7
80
  requests==2.26.0
81
+ scikit-learn==1.0.1
82
  scipy==1.7.1
83
  seaborn==0.11.2
84
  Send2Trash==1.8.0
85
  six==1.16.0
86
+ sklearn==0.0
87
  smmap==4.0.0
88
  statsmodels==0.13.0
89
+ streamlit==1.2.0
90
  tenacity==8.0.1
91
  terminado==0.12.1
92
  testpath==0.5.0
93
+ threadpoolctl==3.0.0
94
  toml==0.10.2
95
+ tomli==1.2.2
96
  toolz==0.11.1
97
  tornado==6.1
98
  traitlets==5.1.0
 
104
  wcwidth==0.2.5
105
  webencodings==0.5.1
106
  widgetsnbextension==3.5.1
107
+ xgboost==1.5.0
stream_app.py CHANGED
@@ -8,7 +8,7 @@ import plotly.figure_factory as ff
8
  import scipy
9
  import numpy as np
10
  from data_processing import load_data, process_data, get_monetary_dataframe, get_themes_per_year
11
- from model import prepare_data, run_training, split, predict, features_importance
12
 
13
 
14
  def _max_width_():
@@ -57,7 +57,7 @@ decision_scope = decisions[authority_filter & year_filter]
57
 
58
  st.subheader("Dataset Description")
59
 
60
- st.metric('Number of validated decisions liked to organisations (and not individuals)', decision_scope.shape[0])
61
 
62
  st.metric('Decisions with monetary sanctions',
63
  decision_scope[decision_scope.monetary_sanction > 0].shape[0])
@@ -184,7 +184,15 @@ if st.button('Run training'):
184
  st.metric(label="Training size", value=predictors_train.shape[0])
185
  st.metric(label="Test size", value=predictors_test.shape[0])
186
 
187
- xgb_model = run_training(predictors_train, target_train)
 
 
 
 
 
 
 
 
188
 
189
  # evaluate model error
190
  target_train_predicted = predict(xgb_model, predictors_train)
 
8
  import scipy
9
  import numpy as np
10
  from data_processing import load_data, process_data, get_monetary_dataframe, get_themes_per_year
11
+ from model import prepare_data, run_training, split, predict, features_importance, run_cv_training
12
 
13
 
14
  def _max_width_():
 
57
 
58
  st.subheader("Dataset Description")
59
 
60
+ st.metric('Number of validated decisions linked to organisations (and not individuals)', decision_scope.shape[0])
61
 
62
  st.metric('Decisions with monetary sanctions',
63
  decision_scope[decision_scope.monetary_sanction > 0].shape[0])
 
184
  st.metric(label="Training size", value=predictors_train.shape[0])
185
  st.metric(label="Test size", value=predictors_test.shape[0])
186
 
187
+ #run cross validation
188
+ st.subheader("Cross validation error")
189
+ xgb_cv, best_params = run_cv_training(predictors_train, target_train)
190
+ print(best_params)
191
+ st.json(best_params)
192
+ xgb_cv.to_csv('cv_results.csv')
193
+ st.line_chart(xgb_cv[[col for col in xgb_cv.columns if "mean" in col]])
194
+
195
+ xgb_model = run_training(predictors_train, target_train, best_params[1], best_params[2])
196
 
197
  # evaluate model error
198
  target_train_predicted = predict(xgb_model, predictors_train)