Jawad committed on
Commit
d036d71
·
1 Parent(s): b0e8abd

change default date

Browse files
Files changed (2) hide show
  1. model.py +2 -2
  2. stream_app.py +6 -5
model.py CHANGED
@@ -32,7 +32,7 @@ def split(predictors, target):
32
  predictors_train, predictors_test, target_train, target_test = train_test_split(predictors,
33
  target,
34
  test_size=0.2,
35
- random_state=42)
36
  return predictors_train, predictors_test, target_train, target_test
37
 
38
 
@@ -50,7 +50,7 @@ def run_cv_training(predictors_train, target_train):
50
  'subsample': 0.8,
51
  'objective': 'reg:squarederror'}
52
  cv_results = cv(dtrain=data_train, params=params, nfold=2,
53
- num_boost_round=1000, early_stopping_rounds=3, metrics="rmse", as_pandas=True, seed=123)
54
  best_value = cv_results['test-rmse-mean'].values[-1]
55
  best_round = cv_results.index[-1]
56
  xgb_csv.append(
 
32
  predictors_train, predictors_test, target_train, target_test = train_test_split(predictors,
33
  target,
34
  test_size=0.2,
35
+ random_state=50)
36
  return predictors_train, predictors_test, target_train, target_test
37
 
38
 
 
50
  'subsample': 0.8,
51
  'objective': 'reg:squarederror'}
52
  cv_results = cv(dtrain=data_train, params=params, nfold=2,
53
+ num_boost_round=1000, early_stopping_rounds=3, metrics="rmse", as_pandas=True, seed=50)
54
  best_value = cv_results['test-rmse-mean'].values[-1]
55
  best_round = cv_results.index[-1]
56
  xgb_csv.append(
stream_app.py CHANGED
@@ -44,7 +44,7 @@ else:
44
  select_auth = authorities.name.sort_values()
45
 
46
  authority = st.sidebar.selectbox('Authority', ['All', *select_auth])
47
- min_year, max_year = st.sidebar.slider('Decisions year', min_value=2001, max_value=2021, value=(2001, 2021))
48
 
49
  # apply filters
50
  authority_filter = True
@@ -184,14 +184,15 @@ if st.button('Run training'):
184
  st.metric(label="Training size", value=predictors_train.shape[0])
185
  st.metric(label="Test size", value=predictors_test.shape[0])
186
 
187
- #run cross validation
188
  st.subheader("Cross validation error")
189
  xgb_cv, best_params = run_cv_training(predictors_train, target_train)
190
- print(best_params)
191
- st.json(best_params)
192
- xgb_cv.to_csv('cv_results.csv')
193
  st.line_chart(xgb_cv[[col for col in xgb_cv.columns if "mean" in col]])
 
 
194
 
 
195
  xgb_model = run_training(predictors_train, target_train, best_params[1], best_params[2])
196
 
197
  # evaluate model error
 
44
  select_auth = authorities.name.sort_values()
45
 
46
  authority = st.sidebar.selectbox('Authority', ['All', *select_auth])
47
+ min_year, max_year = st.sidebar.slider('Decisions year', min_value=2001, max_value=2021, value=(2008, 2021))
48
 
49
  # apply filters
50
  authority_filter = True
 
184
  st.metric(label="Training size", value=predictors_train.shape[0])
185
  st.metric(label="Test size", value=predictors_test.shape[0])
186
 
187
+ # run cross validation
188
  st.subheader("Cross validation error")
189
  xgb_cv, best_params = run_cv_training(predictors_train, target_train)
190
+
 
 
191
  st.line_chart(xgb_cv[[col for col in xgb_cv.columns if "mean" in col]])
192
+ st.subheader("Selected variables")
193
+ st.json(best_params)
194
 
195
+ # train final model
196
  xgb_model = run_training(predictors_train, target_train, best_params[1], best_params[2])
197
 
198
  # evaluate model error