JawadRouen committed
Commit eb06de4 · unverified · 2 Parent(s): 115a68c 81fbf23

Merge pull request #3 from THEOLEX-IO/report

Files changed (5):
  1. data_processing.py +4 -3
  2. model.py +40 -31
  3. requirements.txt +20 -3
  4. score_app.py +4 -2
  5. stream_app.py +58 -17
data_processing.py CHANGED
```diff
@@ -43,7 +43,7 @@ def load_data():
 def process_data(data):
     decisions = pd.DataFrame(data['decisions'])
     decisions['year'] = pd.to_datetime(decisions['decision_date']).dt.year
-    decisions.monetary_sanction = decisions.monetary_sanction.astype(float)
+    decisions.monetary_sanction = pd.to_numeric(decisions.monetary_sanction, errors='coerce').fillna(0)
     # keep validated decisions
     decisions = decisions[decisions.status == 'V']
     decisions.decision_date = pd.to_datetime(decisions['decision_date']).dt.date
@@ -63,7 +63,6 @@ def process_data(data):
     decisions = decisions.merge(organizations, left_on='organizations', right_on='org_id')
     # remove Individual
     decisions = decisions[decisions.org_company_type != "Individual"]
-
     # work on authorities
     authorities = pd.DataFrame(data['authorities'])
     authorities.index = authorities.url.apply(get_id)
@@ -79,12 +78,14 @@ def process_data(data):
 def get_monetary_dataframe(decision_scope):
     monetary_decision = decision_scope[decision_scope.monetary_sanction > 0]
     monetary_decision['has_revenues'] = (monetary_decision.org_revenues != "")
-    monetary_decision['org_revenues'] = monetary_decision.org_revenues.str.replace('', '0').astype(float)
+    monetary_decision['org_revenues'] = pd.to_numeric(monetary_decision.org_revenues, errors='coerce').fillna(0)
     monetary_decision['log10_org_revenues'] = monetary_decision.org_revenues.apply(lambda x: np.log10(x+1))
     monetary_decision['log10_monetary_sanction'] = monetary_decision.monetary_sanction.apply(lambda x: np.log10(x+1))
     monetary_decision['same_country'] = (monetary_decision.org_country == monetary_decision.authorities_country)
     monetary_decision['monetary_sanction_rate'] = monetary_decision.monetary_sanction/monetary_decision.org_revenues
     monetary_decision['log10_monetary_sanction_rate'] = monetary_decision.monetary_sanction_rate.apply(np.log10)
+    time = round((monetary_decision.decision_date - monetary_decision.decision_date.min()) / np.timedelta64(1, "M"))
+    monetary_decision['time'] = time
     return monetary_decision
```
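The two parsing changes replace brittle string casts with `pd.to_numeric(..., errors='coerce').fillna(0)`, which tolerates empty strings and malformed values instead of raising. A minimal standalone sketch of the difference, with made-up sample values:

```python
import pandas as pd

revenues = pd.Series(["1000000", "", "n/a"])

# old approach: .astype(float) raises ValueError on "" and "n/a"
# revenues.astype(float)

# new approach: unparseable entries become NaN, then 0
cleaned = pd.to_numeric(revenues, errors="coerce").fillna(0)
print(cleaned.tolist())  # [1000000.0, 0.0, 0.0]
```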
model.py CHANGED
```diff
@@ -1,3 +1,4 @@
+import itertools
 import numpy as np
 import pandas as pd
 import xgboost as xgb
@@ -5,48 +6,56 @@ from xgboost import cv
 from sklearn.model_selection import train_test_split


-def prepare_data(monetary_decision):
+def prepare_predictors(monetary_decision, col_num, col_cat):
     monetary_decision = monetary_decision.reset_index(drop=True)
-    time = round((monetary_decision.decision_date - monetary_decision.decision_date.min()) / np.timedelta64(1, "M"))
-    monetary_decision.loc[:, ('time')] = time
-    col_num = ['log10_org_revenues',
-               'time']
-    col_cat = ['authorities_country',
-               'type',
-               'violation_theme',
-               'justice_type',
-               'org_country',
-               'org_currency',
-               'org_continent',
-               'same_country',
-               'org_company_type']
     predictors = monetary_decision[col_num + col_cat]
-    target = monetary_decision.log10_monetary_sanction
     for col in col_cat:
         predictors[col] = predictors[col].astype("category")
+    return predictors
+
+
+def prepare_data(monetary_decision, col_num, col_cat):
+    predictors = prepare_predictors(monetary_decision, col_num, col_cat)
+    target = monetary_decision.log10_monetary_sanction
     return predictors, target


 def split(predictors, target):
     predictors_train, predictors_test, target_train, target_test = train_test_split(predictors,
-                                                                                    target,
-                                                                                    test_size=0.2,
-                                                                                    random_state=42)
+                                                                                    target,
+                                                                                    test_size=0.2,
+                                                                                    random_state=50)
     return predictors_train, predictors_test, target_train, target_test


-def run_training(predictors_train, predictors_test):
-    data_train = xgb.DMatrix(predictors_train, label=predictors_test, enable_categorical=True)
-    params = {'max_depth': 4,
-              'learning_rate': 0.05,
-              'colsample_bytree': 0.3,
-              'subsample': 0.8,
-              'gamma': 0.5,
-              'objective': 'reg:squarederror'}
-    num_round = 1000
-    #xgb_cv = cv(dtrain=data_train, params=params, nfold=3,
-    #            num_boost_round=1000, early_stopping_rounds=10, metrics="rmse", as_pandas=True, seed=123)
-    return xgb.train(params, data_train, num_round)
+def run_cv_training(predictors_train, target_train):
+    data_train = xgb.DMatrix(predictors_train, label=target_train, enable_categorical=True)
+    xgb_csv = []
+    best_params = (100, {}, 10)
+    for eta, max_depth, col_num in itertools.product([0.05, 0.01], [10, 15], [0.3, 0.8]):
+        prefix = f"{str(eta)}_{str(max_depth)}_{str(col_num)}"
+        params = {
+            'learning_rate': eta,
+            'max_depth': max_depth,
+            'colsample_bytree': col_num,
+            # 'gamma': 0.5,
+            'subsample': 0.8,
+            'objective': 'reg:squarederror'}
+        cv_results = cv(dtrain=data_train, params=params, nfold=2,
+                        num_boost_round=1000, early_stopping_rounds=3, metrics="rmse", as_pandas=True, seed=50)
+        best_value = cv_results['test-rmse-mean'].values[-1]
+        best_round = cv_results.index[-1]
+        xgb_csv.append(
+            cv_results.rename(columns={col: f'{prefix}_{col}' for col in cv_results.columns}).tail(10).reset_index())
+        if best_value < best_params[0]:
+            best_params = (best_value, params, best_round)
+
+    return pd.concat(xgb_csv, axis=1), best_params
+
+
+def run_training(predictors_train, target_train, params, num_rounds):
+    data_train = xgb.DMatrix(predictors_train, label=target_train, enable_categorical=True)
+    return xgb.train(params, data_train, num_rounds)


 def predict(model, predictors):
@@ -55,4 +64,4 @@ def predict(model, predictors):


 def features_importance(model):
-    return pd.Series(model.get_score(importance_type='gain')).sort_values()
+    return pd.Series(model.get_score(importance_type='gain')).sort_values()
```
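The refactor pulls the feature lists out of `prepare_data` (they are now passed in from the Streamlit sidebar), fixes the old `run_training` bug that passed `predictors_test` as the DMatrix label, and replaces the commented-out `cv` call with a real grid search: `run_cv_training` returns the concatenated CV curves plus a `(best_rmse, params, num_rounds)` tuple that `run_training` consumes. A hedged usage sketch of the new flow, assuming a `monetary_decision` frame that already carries the engineered columns from data_processing.py:

```python
from model import prepare_data, split, run_cv_training, run_training, predict

# assumed input: monetary_decision comes from get_monetary_dataframe(...)
col_num = ['log10_org_revenues', 'time']
col_cat = ['authorities_country', 'org_country', 'org_company_type']

predictors, target = prepare_data(monetary_decision, col_num, col_cat)
predictors_train, predictors_test, target_train, target_test = split(predictors, target)

# grid search over learning_rate / max_depth / colsample_bytree via xgb.cv
cv_curves, (best_rmse, best_xgb_params, best_num_rounds) = run_cv_training(
    predictors_train, target_train)

# retrain on the winning configuration, then score the held-out split
xgb_model = run_training(predictors_train, target_train, best_xgb_params, best_num_rounds)
target_test_predicted = predict(xgb_model, predictors_test)
```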
requirements.txt CHANGED
```diff
@@ -12,6 +12,7 @@ certifi==2021.5.30
 cffi==1.14.6
 charset-normalizer==2.0.6
 click==7.1.2
+coverage==6.1.1
 cycler==0.10.0
 debugpy==1.5.0
 decorator==5.1.0
@@ -20,12 +21,14 @@ entrypoints==0.3
 gitdb==4.0.7
 GitPython==3.1.24
 idna==3.2
+iniconfig==1.1.1
 ipykernel==6.4.1
 ipython==7.28.0
 ipython-genutils==0.2.0
 ipywidgets==7.6.5
 jedi==0.18.0
 Jinja2==3.0.2
+joblib==1.1.0
 jsonschema==4.0.1
 jupyter-client==7.0.6
 jupyter-core==4.8.1
@@ -51,31 +54,45 @@ pexpect==4.8.0
 pickleshare==0.7.5
 Pillow==8.3.2
 plotly==5.3.1
+pluggy==1.0.0
+pprintpp==0.4.0
 prometheus-client==0.11.0
 prompt-toolkit==3.0.20
 protobuf==3.18.1
 ptyprocess==0.7.0
+py==1.10.0
 pyarrow==5.0.0
+pycountry==20.7.3
+pycountry-convert==0.7.2
 pycparser==2.20
 pydeck==0.7.0
 Pygments==2.10.0
+Pympler==0.9
 pyparsing==2.4.7
 pyrsistent==0.18.0
+pytest==6.2.5
+pytest-cov==3.0.0
+pytest-mock==3.6.1
 python-dateutil==2.8.2
 pytz==2021.3
 pyzmq==22.3.0
+repoze.lru==0.7
 requests==2.26.0
+scikit-learn==1.0.1
 scipy==1.7.1
 seaborn==0.11.2
 Send2Trash==1.8.0
 six==1.16.0
+sklearn==0.0
 smmap==4.0.0
 statsmodels==0.13.0
-streamlit==0.89.0
+streamlit==1.2.0
 tenacity==8.0.1
 terminado==0.12.1
 testpath==0.5.0
+threadpoolctl==3.0.0
 toml==0.10.2
+tomli==1.2.2
 toolz==0.11.1
 tornado==6.1
 traitlets==5.1.0
@@ -87,5 +104,5 @@ watchdog==2.1.6
 wcwidth==0.2.5
 webencodings==0.5.1
 widgetsnbextension==3.5.1
-xgboost
-sklearn
+xgboost==1.5.0
+sklearn
```
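One thing worth flagging in this file: it now pins both `scikit-learn==1.0.1` and `sklearn==0.0` (the PyPI `sklearn` package is, to my knowledge, just a stub whose only role is to install scikit-learn), and a third, unpinned `sklearn` line still closes the file. If that duplication is unintended, the tail could be trimmed to:

```
# scikit-learn==1.0.1 and sklearn==0.0 are already pinned above,
# so the trailing unpinned sklearn line is redundant
xgboost==1.5.0
```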
score_app.py CHANGED
```diff
@@ -2,6 +2,7 @@
 import streamlit as st
 import requests
 import pandas as pd
+import numpy as np
 import datetime
 from data import headers

@@ -18,7 +19,7 @@ created_at = st.sidebar.date_input('Date input', value=datetime.date(2021, 1, 1)
 @st.cache
 def load_data(source_type, start_date):
     def get_decision_hist(d_id):
-        url = f"https://www.theolex.io/data/decisions/{d_id}/return_hist/"
+        url = f"https://www.theolex.io/data/decisions/{int(d_id)}/return_hist/"
         res = requests.get(url, headers=headers)
         return res.json()

@@ -34,7 +35,8 @@ def load_data(source_type, start_date):
     data_sources = data_sources[data_sources.created_at >= start_date]

     # get decisions history
-    data_list = [(_id, get_decision_hist(_id)) for _id in data_sources['decision_id']]
+    # can be optimized by filtering first on validated decision for decision table
+    data_list = [(_id, get_decision_hist(_id)) for _id in data_sources['decision_id'] if not np.isnan(_id)]
     return [(_id, pd.DataFrame(pd.DataFrame(data).fields.to_dict()).T)
             for _id, data in data_list if len(data) > 0]
```
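The `int(d_id)` cast and the `np.isnan(_id)` guard both deal with the same pandas quirk: a `decision_id` column containing missing values is upcast to float64, so ids would render as `42.0` in the URL and NaNs would slip into the request loop. A small standalone sketch of the failure mode, with made-up ids:

```python
import numpy as np
import pandas as pd

decision_ids = pd.Series([12, None, 47])  # the missing value forces float64
print(decision_ids.dtype)                 # float64

for _id in decision_ids:
    if not np.isnan(_id):
        # without int(), the path would read .../decisions/12.0/return_hist/
        print(f"https://www.theolex.io/data/decisions/{int(_id)}/return_hist/")
```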
stream_app.py CHANGED
```diff
@@ -1,6 +1,7 @@
 # -*- coding: utf-8 -*-
 import pandas as pd
 import streamlit as st
+from scipy import stats

 import plotly.express as px
 import plotly.figure_factory as ff
@@ -8,7 +9,7 @@ import plotly.figure_factory as ff
 import scipy
 import numpy as np
 from data_processing import load_data, process_data, get_monetary_dataframe, get_themes_per_year
-from model import prepare_data, run_training, split, predict, features_importance
+from model import prepare_predictors, prepare_data, run_training, split, predict, features_importance, run_cv_training


 def _max_width_():
@@ -44,7 +45,7 @@ else:
 select_auth = authorities.name.sort_values()

 authority = st.sidebar.selectbox('Authority', ['All', *select_auth])
-min_year, max_year = st.sidebar.slider('Decisions year', min_value=2001, max_value=2021, value=(2001, 2021))
+min_year, max_year = st.sidebar.slider('Decisions year', min_value=2001, max_value=2021, value=(2008, 2021))

 # apply filters
 authority_filter = True
@@ -57,7 +58,7 @@ decision_scope = decisions[authority_filter & year_filter]

 st.subheader("Dataset Description")

-st.metric('Number of validated decisions liked to organisations (and not individuals)', decision_scope.shape[0])
+st.metric('Number of validated decisions linked to organisations (and not individuals)', decision_scope.shape[0])

 st.metric('Decisions with monetary sanctions',
           decision_scope[decision_scope.monetary_sanction > 0].shape[0])
@@ -157,36 +158,63 @@ with st.expander("Data exploration"):
                       width=1000, height=600)
     st.plotly_chart(fig)

-
 ##############################################
 ####
 # build ML model
 ####
 ##############################################
 st.title("Training phase")
-
-predictors, target = prepare_data(monetary_decision)
+xgb_model = None
+col_num_all = ['log10_org_revenues',
+               'time']
+col_cat_all = ['authorities_country',
+               'type',
+               'violation_theme',
+               'justice_type',
+               'org_country',
+               'org_currency',
+               'org_continent',
+               'same_country',
+               'org_company_type']
+
+st.sidebar.title("Training params")
+col_num = st.sidebar.multiselect('Numeric variables',
+                                 col_num_all, col_num_all)
+col_cat = st.sidebar.multiselect('Categorical variables',
+                                 col_cat_all, col_cat_all)
 # train the model
+predictors, target = prepare_data(monetary_decision, col_num, col_cat)
 if st.button('Run training'):
     with st.expander("Training results"):
+        # Study distribution
         st.write(f"dataset size: {monetary_decision.shape[0]}")
-        st.markdown("Plot taget distribution: log 10 of monetary sanctions")
-        fig = ff.create_distplot([target], [' log 10 of monetary sanctions'], bin_size=0.1)
+        st.markdown("Plot target distribution: log 10 of monetary sanctions")
+        fig = ff.create_distplot([target], ['log 10 of monetary sanctions'], bin_size=0.1)
         fig.update_layout(width=1000,
                           template="simple_white",
                           height=600,
                           bargap=0.01)
         st.plotly_chart(fig)

-        # split data set
+        # Split data set
         predictors_train, predictors_test, target_train, target_test = split(predictors, target)
         st.subheader("Split dataset between training and test:")
         st.metric(label="Training size", value=predictors_train.shape[0])
         st.metric(label="Test size", value=predictors_test.shape[0])

-        xgb_model = run_training(predictors_train, target_train)
+        # Run cross validation
+        st.subheader("Cross validation error")
+        with st.spinner('Wait for it...'):
+            xgb_cv, best_params = run_cv_training(predictors_train, target_train)
+
+        st.line_chart(xgb_cv[[col for col in xgb_cv.columns if "mean" in col]])
+        st.subheader("Selected variables")
+        st.json(best_params)
+
+        # Train final model
+        xgb_model = run_training(predictors_train, target_train, best_params[1], best_params[2])

-        # evaluate model error
+        # Evaluate model error
         target_train_predicted = predict(xgb_model, predictors_train)
         training_bias = np.mean(target_train_predicted - target_train)
         st.metric(label="Training bias", value=training_bias)
@@ -196,7 +224,7 @@ if st.button('Run training'):
         test_bias = np.mean(test_errors)
         st.metric(label="Test bias", value=test_bias)

-        fig = ff.create_distplot([test_errors], ['errors distribution'], bin_size=0.1)
+        fig = ff.create_distplot([test_errors], ['errors distribution'], bin_size=0.2)
         fig.update_layout(width=1000,
                           template="simple_white",
                           height=600,
@@ -251,8 +279,21 @@ if st.button('Run training'):
         R_sq = corr_matrix[0, 1] ** 2
         st.metric(label="Explained variation thanks to model (R^2)", value=f"{round(100 * R_sq, 2)}%")

-st.sidebar.title("Organizations view")
-col_x = ['log10_org_revenues', 'authorities_country', 'violation_theme', 'org_country', 'org_company_type']
-sample_revenues = st.sidebar.number_input('Yearly revenues', value=1000000)
-authority = st.sidebar.selectbox('Organization country', predictors.org_country.cat.categories)
-authority = st.sidebar.selectbox('Organization activity', predictors.org_company_type.cat.categories)
+        st.subheader("Residuals & homoscedasticity")
+        # st.metric(label="Explained variation thanks to model (R^2)", value=f"{round(100 * R_sq, 2)}%")
+
+        print(stats.pearsonr(test_errors, target_test))
+
+st.title("Organizations view")
+col1, col2, col3 = st.columns(3)
+to_predict = {}
+with col1:
+    to_predict['log10_org_revenues'] = [np.log10(st.number_input('Yearly revenues', value=100000000))]
+    for col in col_cat:
+        to_predict[col] = [st.selectbox(f'{col}', predictors[col].cat.categories)]
+    print(to_predict)
+
+df_to_predict = prepare_predictors(pd.DataFrame.from_dict(to_predict), col_num, col_cat)
+if xgb_model:
+    predicted = predict(xgb_model, df_to_predict)
+    print(predicted)
```
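A caveat on the new "Organizations view": stream_app.py re-runs top to bottom on every widget interaction, and `xgb_model` is reset to `None` at module level each time, so the trained booster is typically gone again by the time a user fills in the organization inputs. A hypothetical workaround, not part of this commit, would be to park the booster in `st.session_state`, which is available in the streamlit==1.2.0 pinned above:

```python
# hypothetical sketch: persist the trained booster across Streamlit reruns
if st.button('Run training'):
    ...  # cross-validate and train as above, then:
    st.session_state['xgb_model'] = xgb_model

xgb_model = st.session_state.get('xgb_model')  # survives widget-triggered reruns
if xgb_model is not None:
    predicted = predict(xgb_model, df_to_predict)
    st.write(predicted)
```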