Jawad commited on
Commit
fe0f3db
·
1 Parent(s): 2715942

add traning

Browse files
Files changed (3) hide show
  1. data_processing.py +2 -0
  2. model.py +53 -0
  3. stream_app.py +120 -4
data_processing.py CHANGED
@@ -82,6 +82,8 @@ def get_monetary_dataframe(decision_scope):
82
  monetary_decision['log10_org_revenues'] = monetary_decision.org_revenues.apply(np.log10)
83
  monetary_decision['log10_monetary_sanction'] = monetary_decision.monetary_sanction.apply(np.log10)
84
  monetary_decision['same_country'] = (monetary_decision.org_country == monetary_decision.authorities_country)
 
 
85
  return monetary_decision
86
 
87
 
 
82
  monetary_decision['log10_org_revenues'] = monetary_decision.org_revenues.apply(np.log10)
83
  monetary_decision['log10_monetary_sanction'] = monetary_decision.monetary_sanction.apply(np.log10)
84
  monetary_decision['same_country'] = (monetary_decision.org_country == monetary_decision.authorities_country)
85
+ monetary_decision['monetary_sanction_rate'] = monetary_decision.monetary_sanction/monetary_decision.org_revenues
86
+ monetary_decision['log10_monetary_sanction_rate'] = monetary_decision.monetary_sanction_rate.apply(np.log10)
87
  return monetary_decision
88
 
89
 
model.py ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import pandas as pd
3
+ import xgboost as xgb
4
+ from sklearn.model_selection import train_test_split
5
+
6
+
7
+ def prepare_data(monetary_decision):
8
+ monetary_decision = monetary_decision.reset_index(drop=True)
9
+ time = round((monetary_decision.decision_date - monetary_decision.decision_date.min()) / np.timedelta64(1, "M"))
10
+ monetary_decision.loc[:, ('time')] = time
11
+ col_num = ['log10_org_revenues',
12
+ 'time']
13
+ col_cat = ['authorities_country',
14
+ 'type',
15
+ 'violation_theme',
16
+ 'justice_type',
17
+ 'org_country',
18
+ 'org_currency',
19
+ 'org_continent',
20
+ 'same_country',
21
+ 'org_company_type']
22
+ predictors = monetary_decision[col_num + col_cat]
23
+ target = monetary_decision.log10_monetary_sanction
24
+ for col in col_cat:
25
+ predictors[col] = predictors[col].astype("category")
26
+ return predictors, target
27
+
28
+
29
+ def split(predictors, target):
30
+ predictors_train, predictors_test, target_train, target_test = train_test_split(predictors,
31
+ target,
32
+ test_size=0.2,
33
+ random_state=42)
34
+ return predictors_train, predictors_test, target_train, target_test
35
+
36
+
37
+ def run_training(predictors_train, predictors_test):
38
+ data_train = xgb.DMatrix(predictors_train, label=predictors_test, enable_categorical=True)
39
+ param = {'max_depth': 5,
40
+ 'learning_rate': .2,
41
+ 'colsample_bytree': 0.3,
42
+ 'objective': 'reg:squarederror'}
43
+ num_round = 50
44
+ return xgb.train(param, data_train, num_round)
45
+
46
+
47
+ def predict(model, predictors):
48
+ data = xgb.DMatrix(predictors, enable_categorical=True)
49
+ return model.predict(data)
50
+
51
+
52
+ def features_importance(model):
53
+ return pd.Series(model.get_score(importance_type='gain')).sort_values()
stream_app.py CHANGED
@@ -1,9 +1,14 @@
1
  # -*- coding: utf-8 -*-
 
2
  import streamlit as st
3
 
4
  import plotly.express as px
 
5
 
 
 
6
  from data_processing import load_data, process_data, get_monetary_dataframe, get_themes_per_year
 
7
 
8
 
9
  def _max_width_():
@@ -30,12 +35,15 @@ st.write("by [Teolex](https://www.theolex.io/)")
30
  data = load_data()
31
  decisions, organizations, authorities = process_data(data)
32
 
33
- st.sidebar.title("Parameters")
34
- authorities_country = st.sidebar.selectbox('Authority country', authorities.country.unique())
35
 
36
- select_auth = authorities[authorities.country == authorities_country].name.sort_values()
37
- authority = st.sidebar.selectbox('Authority', ['All', *select_auth])
 
 
38
 
 
39
  min_year, max_year = st.sidebar.slider('Decisions year', min_value=2001, max_value=2021, value=(2010, 2021))
40
 
41
  # apply filters
@@ -90,6 +98,43 @@ fig = px.scatter(monetary_decision,
90
  width=1000, height=600)
91
  st.plotly_chart(fig)
92
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
93
  st.subheader("Sum of monetary sanctions over time ")
94
  st.markdown("The graph shows the cumulated monetary sanction per year for each violation theme")
95
  chart_data = get_themes_per_year(monetary_decision)
@@ -101,3 +146,74 @@ fig = px.area(chart_data, x="year",
101
  line_group="violation_theme",
102
  width=1000, height=600)
103
  st.plotly_chart(fig)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  # -*- coding: utf-8 -*-
2
+ import pandas as pd
3
  import streamlit as st
4
 
5
  import plotly.express as px
6
+ import plotly.figure_factory as ff
7
 
8
+ import scipy
9
+ import numpy as np
10
  from data_processing import load_data, process_data, get_monetary_dataframe, get_themes_per_year
11
+ from model import prepare_data, run_training, split, predict, features_importance
12
 
13
 
14
  def _max_width_():
 
35
  data = load_data()
36
  decisions, organizations, authorities = process_data(data)
37
 
38
+ st.sidebar.title("Authorities parameters")
39
+ authorities_country = st.sidebar.selectbox('Authority country', ['All', *authorities.country.unique()])
40
 
41
+ if authorities_country != 'All':
42
+ select_auth = authorities[authorities.country == authorities_country].name.sort_values()
43
+ else:
44
+ select_auth = authorities.name.sort_values()
45
 
46
+ authority = st.sidebar.selectbox('Authority', ['All', *select_auth])
47
  min_year, max_year = st.sidebar.slider('Decisions year', min_value=2001, max_value=2021, value=(2010, 2021))
48
 
49
  # apply filters
 
98
  width=1000, height=600)
99
  st.plotly_chart(fig)
100
 
101
+ fig = px.scatter(monetary_decision[~monetary_decision.org_revenues.isnull()],
102
+ x="decision_date",
103
+ size="log10_monetary_sanction",
104
+ y="org_revenues",
105
+ log_y=True,
106
+ template="simple_white",
107
+ color="same_country",
108
+ hover_name="monetary_sanction",
109
+ width=1000, height=600)
110
+ st.plotly_chart(fig)
111
+
112
+ fig = px.histogram(monetary_decision, x="log10_monetary_sanction",
113
+ # y="log10_org_revenues",
114
+ color="same_country",
115
+ marginal="box", # or violin, rug
116
+ template="simple_white",
117
+ width=1000, height=600, nbins=40, opacity=0.5,
118
+ hover_data=monetary_decision.columns)
119
+
120
+ st.plotly_chart(fig)
121
+
122
+ fig = px.histogram(monetary_decision, x="log10_monetary_sanction_rate",
123
+ # y="log10_org_revenues",
124
+ color="same_country",
125
+ marginal="box", # or violin, rug
126
+ template="simple_white",
127
+ width=1000, height=600, nbins=40, opacity=0.5,
128
+ hover_data=monetary_decision.columns)
129
+
130
+ st.plotly_chart(fig)
131
+
132
+ p = scipy.stats.ks_2samp(monetary_decision[monetary_decision.same_country]['log10_monetary_sanction_rate'],
133
+ monetary_decision[~monetary_decision.same_country]['log10_monetary_sanction_rate']
134
+ , alternative='two-sided', mode='auto')
135
+
136
+ st.metric(label="p-value", value=f"{round(p.pvalue, 2)}%")
137
+
138
  st.subheader("Sum of monetary sanctions over time ")
139
  st.markdown("The graph shows the cumulated monetary sanction per year for each violation theme")
140
  chart_data = get_themes_per_year(monetary_decision)
 
146
  line_group="violation_theme",
147
  width=1000, height=600)
148
  st.plotly_chart(fig)
149
+
150
+ st.sidebar.title("Organizations view")
151
+
152
+ col_x = ['log10_org_revenues', 'authorities_country', 'violation_theme', 'org_country', 'org_company_type']
153
+
154
+ predictors, target = prepare_data(monetary_decision)
155
+
156
+ st.title("Training phase")
157
+ st.markdown("Plot taget distribution: log 10 of monetary sanctions")
158
+ fig = ff.create_distplot([target], [' log 10 of monetary sanctions'], bin_size=0.1)
159
+ fig.update_layout(width=1000,
160
+ template="simple_white",
161
+ height=600,
162
+ bargap=0.01)
163
+ st.plotly_chart(fig)
164
+
165
+ # split data set
166
+ predictors_train, predictors_test, target_train, target_test = split(predictors, target)
167
+
168
+ # train the model
169
+ xgb_model = run_training(predictors_train, target_train)
170
+
171
+ # evaluate model error
172
+ target_train_predicted = predict(xgb_model, predictors_train)
173
+ training_bias = np.mean(target_train_predicted - target_train)
174
+ st.metric(label="Training bias", value=training_bias)
175
+
176
+ target_test_predicted = predict(xgb_model, predictors_test)
177
+ test_errors = target_test_predicted - target_test
178
+ test_bias = np.mean(test_errors)
179
+ st.metric(label="Test bias", value=test_bias)
180
+
181
+ fig = ff.create_distplot([test_errors], ['errors distribution'], bin_size=0.1)
182
+ fig.update_layout(width=1000,
183
+ template="simple_white",
184
+ height=600,
185
+ bargap=0.01)
186
+ st.plotly_chart(fig)
187
+
188
+ st.subheader("Plot features importance for the trained model")
189
+ xgb_features_importance = features_importance(xgb_model)
190
+
191
+ fig = px.bar(xgb_features_importance,
192
+ orientation='h',
193
+ width=1000,
194
+ template="simple_white",
195
+ height=600,
196
+ )
197
+ st.plotly_chart(fig)
198
+
199
+ st.subheader("Plot predicted vs real")
200
+ import plotly.graph_objs as go
201
+
202
+ compare = pd.concat([pd.DataFrame({'target': target_test, 'predicted': target_test_predicted, 'sample': 'test'}),
203
+ pd.DataFrame({'target': target_train, 'predicted': target_train_predicted, 'sample': 'train'})])
204
+ fig = px.scatter(
205
+ compare,
206
+ x='predicted',
207
+ y='target',
208
+ color='sample',
209
+ marginal_y="violin",
210
+ width=1000,
211
+ template="simple_white",
212
+ height=600,
213
+ trendline="ols")
214
+
215
+ st.plotly_chart(fig)
216
+
217
+ sample_revenues = st.sidebar.number_input('Yearly revenues', value=1000000)
218
+ authority = st.sidebar.selectbox('Organization country', predictors.org_country.cat.categories)
219
+ authority = st.sidebar.selectbox('Organization activity', predictors.org_company_type.cat.categories)