MagicDash committed on
Commit 2f69b10 · verified · 1 Parent(s): a3e812d

Update webapp.py

Files changed (1)
  1. webapp.py +854 -856
webapp.py CHANGED
import pandas as pd
import seaborn as sns
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import numpy as np
import google.generativeai as genai
import PIL.Image
from PIL import Image
from werkzeug.utils import secure_filename
import os
import json
from fpdf import FPDF
from fastapi import FastAPI, File, UploadFile, Form, HTTPException
from fastapi.responses import HTMLResponse, FileResponse
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates
from starlette.requests import Request
from typing import List
import textwrap
from IPython.display import display, Markdown
import shutil
import urllib.parse
import re
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_community.document_loaders import PyPDFLoader, UnstructuredCSVLoader, UnstructuredExcelLoader, Docx2txtLoader, UnstructuredPowerPointLoader
from langchain.chains import StuffDocumentsChain
from langchain.chains.llm import LLMChain
from langchain.prompts import PromptTemplate
from langchain_community.vectorstores import FAISS
from langchain_google_genai import GoogleGenerativeAIEmbeddings
from langchain.text_splitter import CharacterTextSplitter

app = FastAPI()
app.mount("/static", StaticFiles(directory="static"), name="static")
templates = Jinja2Templates(directory="templates")

sns.set_theme(color_codes=True)
uploaded_df = None
document_analyzed = False
question_responses = []
# Upload state shared across routes; set in /process/ and read in /result and /ask
file_path = None
file_extension = None

# Gemini API key, read from the environment rather than hardcoded in the source
api = os.environ.get("GOOGLE_API_KEY")

def format_text(text):
    # Replace **text** with <b>text</b>
    text = re.sub(r'\*\*(.*?)\*\*', r'<b>\1</b>', text)
    # Replace any remaining * with <br>
    text = text.replace('*', '<br>')
    return text

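# Illustrative example: format_text("**Insight:** * rising sales")
# returns "<b>Insight:</b> <br> rising sales".
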
def clean_data(df):
    # Step 1: Clean currency-related columns
    for col in df.columns:
        if any(x in col.lower() for x in ['value', 'price', 'cost', 'amount']):
            if df[col].dtype == 'object':
                df[col] = df[col].str.replace('$', '', regex=False).str.replace('£', '', regex=False).str.replace('€', '', regex=False).replace(r'[^\d.-]', '', regex=True).astype(float)

    # Step 2: Drop columns with more than 25% missing values
    null_percentage = df.isnull().sum() / len(df)
    columns_to_drop = null_percentage[null_percentage > 0.25].index
    df.drop(columns=columns_to_drop, inplace=True)

    # Step 3: Fill missing values for remaining columns
    for col in df.columns:
        if df[col].isnull().sum() > 0:
            if null_percentage[col] <= 0.25:
                if df[col].dtype in ['float64', 'int64']:
                    median_value = df[col].median()
                    df[col] = df[col].fillna(median_value)

    # Step 4: Convert object-type columns to lowercase
    for col in df.columns:
        if df[col].dtype == 'object':
            df[col] = df[col].str.lower()

    # Step 5: Drop columns with only one unique value
    unique_value_columns = [col for col in df.columns if df[col].nunique() == 1]
    df.drop(columns=unique_value_columns, inplace=True)

    return df


def clean_data2(df):
    for col in df.columns:
        if any(x in col.lower() for x in ['value', 'price', 'cost', 'amount']):
            if df[col].dtype == 'object':
                df[col] = df[col].str.replace('$', '', regex=False)
                df[col] = df[col].str.replace('£', '', regex=False)
                df[col] = df[col].str.replace('€', '', regex=False)
                df[col] = df[col].replace(r'[^\d.-]', '', regex=True).astype(float)

    null_percentage = df.isnull().sum() / len(df)

    for col in df.columns:
        if df[col].isnull().sum() > 0:
            if null_percentage[col] <= 0.25:
                if df[col].dtype in ['float64', 'int64']:
                    median_value = df[col].median()
                    df[col] = df[col].fillna(median_value)

    for col in df.columns:
        if df[col].dtype == 'object':
            df[col] = df[col].str.lower()

    return df

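# Illustrative usage (hypothetical column): clean_data2 turns a "price" column of
# ["$1,200", "$950", "$1,000", None] into [1200.0, 950.0, 1000.0, 1000.0]
# (currency symbols stripped, the single NaN filled with the median of 1000.0,
# since exactly 25% of the values were missing).
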
def generate_plot(df, plot_path, plot_type):
    df = clean_data(df)
    excluded_words = ["name", "postal", "date", "phone", "address", "code", "id"]

    if plot_type == 'countplot':
        cat_vars = [col for col in df.select_dtypes(include='object').columns
                    if all(word not in col.lower() for word in excluded_words) and df[col].nunique() > 1]

        for col in cat_vars:
            if df[col].nunique() > 10:
                top_categories = df[col].value_counts().index[:10]
                df[col] = df[col].apply(lambda x: x if x in top_categories else 'Other')

        num_cols = len(cat_vars)
        num_rows = (num_cols + 1) // 2
        fig, axs = plt.subplots(nrows=num_rows, ncols=2, figsize=(15, 5*num_rows), squeeze=False)
        axs = axs.flatten()

        for i, var in enumerate(cat_vars):
            category_counts = df[var].value_counts()
            top_values = category_counts.index[:10][::-1]
            filtered_df = df.copy()
            filtered_df[var] = pd.Categorical(filtered_df[var], categories=top_values, ordered=True)
            sns.countplot(x=var, data=filtered_df, order=top_values, ax=axs[i])
            axs[i].set_title(var)
            axs[i].tick_params(axis='x', rotation=30)

            total = len(filtered_df[var])
            for p in axs[i].patches:
                height = p.get_height()
                axs[i].annotate(f'{height/total:.1%}', (p.get_x() + p.get_width() / 2., height), ha='center', va='bottom')

            sample_size = filtered_df.shape[0]

        for i in range(num_cols, len(axs)):
            fig.delaxes(axs[i])

    elif plot_type == 'histplot':
        num_vars = [col for col in df.select_dtypes(include=['int', 'float']).columns
                    if all(word not in col.lower() for word in excluded_words)]
        num_cols = len(num_vars)
        num_rows = (num_cols + 2) // 3
        fig, axs = plt.subplots(nrows=num_rows, ncols=min(3, num_cols), figsize=(15, 5*num_rows), squeeze=False)
        axs = axs.flatten()

        plot_index = 0

        for i, var in enumerate(num_vars):
            if len(df[var].unique()) == len(df):
                fig.delaxes(axs[plot_index])
            else:
                sns.histplot(df[var], ax=axs[plot_index], kde=True, stat="percent")
                axs[plot_index].set_title(var)
                axs[plot_index].set_xlabel('')

                sample_size = df.shape[0]

            plot_index += 1

        for i in range(plot_index, len(axs)):
            fig.delaxes(axs[i])

    fig.tight_layout()
    fig.savefig(plot_path)
    plt.close(fig)
    return plot_path

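# Illustrative usage (hypothetical file):
#   generate_plot(pd.read_csv("sales.csv"), "static/plot1.png", "countplot")
# writes a grid of count plots for the categorical columns and returns the path.
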
@app.get("/", response_class=HTMLResponse)
async def read_form(request: Request):
    return templates.TemplateResponse("upload.html", {"request": request})

@app.post("/process/", response_class=HTMLResponse)
async def process_file(request: Request, file: UploadFile = File(...)):
    global df, uploaded_file, document_analyzed, file_path, file_extension
    uploaded_file = file
    file_location = f"static/{secure_filename(file.filename)}"

    # Save the uploaded file to the server
    with open(file_location, "wb") as buffer:
        shutil.copyfileobj(file.file, buffer)

    # Load DataFrame based on file type
    file_extension = os.path.splitext(file.filename)[1]
    if file_extension == '.csv':
        file_path = 'dataset.csv'
        df = pd.read_csv(file_location, delimiter=",")
        df.to_csv(file_path, index=False)  # Save as dataset.csv
    elif file_extension == '.xlsx':
        file_path = 'dataset.xlsx'
        df = pd.read_excel(file_location)
        df.to_excel(file_path, index=False)  # Save as dataset.xlsx
    else:
        raise HTTPException(status_code=415, detail="Unsupported file format")

    # Get columns of the DataFrame
    columns = df.columns.tolist()

    return templates.TemplateResponse("upload.html", {"request": request, "columns": columns})

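# Illustrative request (assuming the default host/port used at the bottom of this file):
#   curl -F "file=@sales.csv" http://127.0.0.1:8000/process/
# re-renders upload.html with the detected column names.
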
@app.post("/result")
async def result(request: Request,
                 target: str = Form(...),
                 algorithm: str = Form(...)):
    global df, api
    global plot1_path, plot2_path, plot3_path, plot4_path, plot5_path, plot6_path, plot7_path, plot8_path, plot9_path, plot10_path, plot11_path
    global response1, response2, response3, response4, response5, response6, response7, response8, response9, response10, response11

    if not api:
        raise HTTPException(status_code=500, detail="GOOGLE_API_KEY is not set")
    excluded_words = ["name", "postal", "date", "phone", "address", "id"]

    if df[target].dtype in ['float64', 'int64']:
        unique_values = df[target].nunique()

        # If unique values > 20, treat it as regression, else classification
        if unique_values > 20:
            method = "Regression"
        else:
            method = "Classification"
    else:
        # If the target is not numeric, treat it as classification
        method = "Classification"

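    # Example: a numeric target with 25 distinct values is treated as Regression,
    # while a 3-level string target (or a numeric one with <= 20 levels) is
    # treated as Classification.
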
    # Initialize the optional responses and plot paths to None
    response3 = None
    plot3_path = None
    response4 = None
    plot4_path = None
    response6 = None
    plot6_path = None
    response8 = None
    plot8_path = None
    response9 = None
    plot9_path = None
    response10 = None
    plot10_path = None
    response11 = None
    plot11_path = None

    if method == "Classification":
        cat_vars = [col for col in df.select_dtypes(include=['object']).columns
                    if all(word not in col.lower() for word in excluded_words)]

        # Exclude the target variable from the list if it exists in cat_vars
        if target in cat_vars:
            cat_vars.remove(target)

        # Create a figure with subplots, but only include the required number of subplots
        num_cols = len(cat_vars)
        num_rows = (num_cols + 2) // 3  # Make sure there are enough rows for the subplots
        fig, axs = plt.subplots(nrows=num_rows, ncols=3, figsize=(15, 5*num_rows), squeeze=False)
        axs = axs.flatten()

        # Create a count plot for each categorical variable
        for i, var in enumerate(cat_vars):
            top_categories = df[var].value_counts().nlargest(5).index
            # Exclude rows with NaN values in the variable; copy to avoid SettingWithCopyWarning
            filtered_df = df[df[var].notnull() & df[var].isin(top_categories)].copy()

            # Replace less frequent categories with "Other" if there are more than 5 unique values
            if df[var].nunique() > 5:
                other_categories = df[var].value_counts().index[5:]
                filtered_df[var] = filtered_df[var].apply(lambda x: x if x in top_categories else 'Other')

            sns.countplot(x=var, hue=target, stat="percent", data=filtered_df, ax=axs[i])
            axs[i].tick_params(axis='x', rotation=45)

            # Change y-axis label to represent percentage
            axs[i].set_ylabel('Percentage')

            # Annotate the subplot with sample size
            sample_size = df.shape[0]
            axs[i].annotate(f'Sample Size = {sample_size}', xy=(0.5, 0.9), xycoords='axes fraction', ha='center', va='center')

        # Remove any remaining blank subplots
        for i in range(num_cols, len(axs)):
            fig.delaxes(axs[i])

        plt.xticks(rotation=45)
        plt.tight_layout()
        plot3_path = "static/multiclass_barplot.png"
        plt.savefig(plot3_path)
        plt.close(fig)

        # Response 3: Gemini commentary on the count plots
        def to_markdown(text):
            text = text.replace('•', ' *')
            return Markdown(textwrap.indent(text, '> ', predicate=lambda _: True))

        genai.configure(api_key=api)

        img = PIL.Image.open("static/multiclass_barplot.png")
        model = genai.GenerativeModel('gemini-1.5-flash-latest')
        response = model.generate_content(["As a marketing consultant, I want to understand consumer insights based on the chart and the market context so I can use the key findings to formulate actionable insights", img])
        response.resolve()
        response3 = format_text(response.text)

    if method == "Classification":
        # Generate Multiclass Pairplot
        pairplot_fig = sns.pairplot(df, hue=target)
        plot6_path = "static/pair1.png"  # Use plot6_path
        pairplot_fig.savefig(plot6_path)  # Save the pairplot as a PNG file

        # Google Gemini Integration
        genai.configure(api_key=api)
        img = PIL.Image.open(plot6_path)
        model = genai.GenerativeModel('gemini-1.5-flash-latest')

        # Generate response based on the pairplot
        response = model.generate_content([
            "You are a professional Data Analyst, write the complete conclusion and actionable insight based on the image. Explain it by points.",
            img
        ])
        response.resolve()

        # Assign the response to response6
        response6 = format_text(response.text)

        # Include response6 and plot6_path in the data dictionary to be passed to the template

    if method == "Classification":
        # Multiclass Histplot
        # Get the names of all columns
        cat_cols = df.columns.tolist()

        # Get the names of all numeric columns, excluding the target
        int_vars = df.select_dtypes(include=['int', 'float']).columns.tolist()
        int_vars = [col for col in int_vars if col != target]

        # Create a figure with subplots
        num_cols = len(int_vars)
        num_rows = (num_cols + 2) // 3  # Make sure there are enough rows for the subplots
        fig, axs = plt.subplots(nrows=num_rows, ncols=3, figsize=(15, 5*num_rows), squeeze=False)
        axs = axs.flatten()

        # Create a histogram for each numeric variable with the target as hue
        for i, var in enumerate(int_vars):
            top_categories = df[var].value_counts().nlargest(10).index
            filtered_df = df[df[var].notnull() & df[var].isin(top_categories)]
            sns.histplot(data=df, x=var, hue=target, kde=True, ax=axs[i], stat="percent")
            axs[i].set_title(var)

            # Annotate the subplot with sample size
            sample_size = df.shape[0]
            axs[i].annotate(f'Sample Size = {sample_size}', xy=(0.5, 0.9), xycoords='axes fraction', ha='center', va='center')

        # Remove any extra empty subplots if needed
        if num_cols < len(axs):
            for i in range(num_cols, len(axs)):
                fig.delaxes(axs[i])

        # Adjust spacing between subplots
        fig.tight_layout()
        plt.xticks(rotation=45)
        plot4_path = "static/multiclass_histplot.png"
        plt.savefig(plot4_path)
        plt.close(fig)

        # Response 4: Gemini commentary on the histograms
        genai.configure(api_key=api)

        img = PIL.Image.open("static/multiclass_histplot.png")
        model = genai.GenerativeModel('gemini-1.5-flash-latest')
        response4 = model.generate_content(["As a marketing consultant, I want to understand consumer insights based on the chart and the market context so I can use the key findings to formulate actionable insights", img])
        response4.resolve()
        response4 = format_text(response4.text)

    # Generate Pairplot
    pairplot_fig = sns.pairplot(df)
    plot5_path = "static/pair2.png"
    pairplot_fig.savefig(plot5_path)  # Save the pairplot as a PNG file

    # Google Gemini Integration
    genai.configure(api_key=api)
    img = PIL.Image.open(plot5_path)
    model = genai.GenerativeModel('gemini-1.5-flash-latest')

    # Generate response based on the pairplot
    response = model.generate_content([
        "You are a professional Data Analyst, write the complete conclusion and actionable insight based on the image. Explain it by points.",
        img
    ])
    response.resolve()

    # Assign the response to response5
    response5 = format_text(response.text)

    def generate_gemini_response(plot_path):
        genai.configure(api_key=api)
        img = Image.open(plot_path)
        model = genai.GenerativeModel('gemini-1.5-flash-latest')
        response = model.generate_content([
            "As a marketing consultant, I want to understand consumer insights based on the chart and the market context so I can use the key findings to formulate actionable insights",
            img
        ])
        response.resolve()
        return response.text

    plot1_path = generate_plot(df, 'static/plot1.png', 'countplot')
    plot2_path = generate_plot(df, 'static/plot2.png', 'histplot')

    response1 = format_text(generate_gemini_response(plot1_path))
    response2 = format_text(generate_gemini_response(plot2_path))

    from sklearn import preprocessing
    for col in df.select_dtypes(include=['object']).columns:
        # Initialize a LabelEncoder object
        label_encoder = preprocessing.LabelEncoder()

        # Fit the encoder to the unique values in the column
        label_encoder.fit(df[col].unique())

        # Transform the column using the encoder
        df[col] = label_encoder.transform(df[col])

    # Display Correlation Heatmap
    plot7_path = "static/correlation_matrix.png"
    fig, ax = plt.subplots(figsize=(30, 24))
    correlation_matrix = df.corr()
    sns.heatmap(correlation_matrix, annot=True, fmt='.2f', cmap='coolwarm', ax=ax)
    plt.savefig(plot7_path)
    plt.close(fig)

    response7 = format_text(generate_gemini_response(plot7_path))

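    # Illustrative effect (hypothetical column): a "city" column holding
    # ['ny', 'la', 'sf'] is encoded to [1, 0, 2] (classes are sorted alphabetically),
    # which is what makes the correlation matrix above computable for former text columns.
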
    X = df.drop(target, axis=1)
    y = df[target]
    from sklearn.model_selection import train_test_split
    from sklearn.metrics import accuracy_score
    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=0)

    from scipy import stats
    threshold = 3

    for col in X_train.columns:
        if X_train[col].nunique() > 20:
            # Calculate Z-scores for the column
            z_scores = np.abs(stats.zscore(X_train[col]))
            # Find and remove outliers based on the threshold
            outlier_indices = np.where(z_scores > threshold)[0]
            X_train = X_train.drop(X_train.index[outlier_indices])
            y_train = y_train.drop(y_train.index[outlier_indices])

    from sklearn.tree import DecisionTreeRegressor
    from sklearn.tree import DecisionTreeClassifier
    from sklearn.model_selection import GridSearchCV
    from sklearn import metrics
    from sklearn.metrics import mean_absolute_percentage_error
    import math

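    # Worked example: for a training column with mean 100 and standard deviation 10,
    # rows with values outside [70, 130] (|z| > 3) are dropped from X_train and y_train.
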
    if algorithm == "Decision Tree":

        if method == "Regression":
            dtree = DecisionTreeRegressor()
            param_grid = {
                'max_depth': [4, 6, 8],
                'min_samples_split': [4, 6, 8],
                'min_samples_leaf': [1, 2, 3, 4],
                'random_state': [0, 42],
                'max_features': [None, 'sqrt', 'log2']  # 'auto' was removed in newer scikit-learn; None uses all features
            }
            grid_search = GridSearchCV(dtree, param_grid, cv=5, scoring='neg_mean_squared_error')
            grid_search.fit(X_train, y_train)
            best_params = grid_search.best_params_
            dtree = DecisionTreeRegressor(**best_params)
            dtree.fit(X_train, y_train)

            y_pred = dtree.predict(X_test)
            mae = metrics.mean_absolute_error(y_test, y_pred)
            mse = metrics.mean_squared_error(y_test, y_pred)
            r2 = metrics.r2_score(y_test, y_pred)
            rmse = np.sqrt(mse)

            # Feature importance visualization
            imp_df = pd.DataFrame({
                "Feature Name": X_train.columns,
                "Importance": dtree.feature_importances_
            })
            fi = imp_df.sort_values(by="Importance", ascending=False).head(10)
            fig, ax = plt.subplots(figsize=(10, 8))
            sns.barplot(data=fi, x='Importance', y='Feature Name', ax=ax)
            ax.set_title('Top 10 Feature Importance (Decision Tree Regressor)', fontsize=18)
            plot8_path = "static/dtree_regressor.png"
            plt.savefig(plot8_path)
            plt.close(fig)
            response8 = format_text(generate_gemini_response(plot8_path))

        elif method == "Classification":
            dtree = DecisionTreeClassifier()
            param_grid = {
                'max_depth': [3, 4, 5, 6, 7],
                'min_samples_split': [2, 3, 4],
                'min_samples_leaf': [1, 2, 3],
                'random_state': [0, 42]
            }
            grid_search = GridSearchCV(dtree, param_grid, cv=5)
            grid_search.fit(X_train, y_train)
            best_params = grid_search.best_params_
            dtree = DecisionTreeClassifier(**best_params)
            dtree.fit(X_train, y_train)

            y_pred = dtree.predict(X_test)
            acc = metrics.accuracy_score(y_test, y_pred)
            f1 = metrics.f1_score(y_test, y_pred, average='micro')
            prec = metrics.precision_score(y_test, y_pred, average='micro')
            recall = metrics.recall_score(y_test, y_pred, average='micro')

            # Feature importance visualization
            imp_df = pd.DataFrame({
                "Feature Name": X_train.columns,
                "Importance": dtree.feature_importances_
            })
            fi = imp_df.sort_values(by="Importance", ascending=False).head(10)
            fig, ax = plt.subplots(figsize=(10, 8))
            sns.barplot(data=fi, x='Importance', y='Feature Name', ax=ax)
            ax.set_title('Top 10 Feature Importance (Decision Tree Classifier)', fontsize=18)
            plot9_path = "static/dtree_classifier.png"
            plt.savefig(plot9_path)
            plt.close(fig)
            response9 = format_text(generate_gemini_response(plot9_path))

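    # Note (illustrative): GridSearchCV fits one model per parameter combination and
    # fold, so the regressor grid above spans 3*3*4*2*3 = 216 combinations,
    # i.e. 1,080 fits with cv=5.
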
    from sklearn.ensemble import RandomForestRegressor
    from sklearn.ensemble import RandomForestClassifier

    if algorithm == "Random Forest":

        if method == "Regression":
            rf = RandomForestRegressor()
            param_grid = {
                'max_depth': [4, 6, 8],
                'random_state': [0, 42],
                'max_features': [None, 'sqrt', 'log2']  # 'auto' was removed in newer scikit-learn; None uses all features
            }
            grid_search = GridSearchCV(rf, param_grid, cv=5, scoring='neg_mean_squared_error')
            grid_search.fit(X_train, y_train)
            best_params = grid_search.best_params_
            rf = RandomForestRegressor(**best_params)
            rf.fit(X_train, y_train)

            y_pred = rf.predict(X_test)
            mae = metrics.mean_absolute_error(y_test, y_pred)
            mse = metrics.mean_squared_error(y_test, y_pred)
            r2 = metrics.r2_score(y_test, y_pred)
            rmse = np.sqrt(mse)

            # Feature importance visualization
            imp_df = pd.DataFrame({
                "Feature Name": X_train.columns,
                "Importance": rf.feature_importances_
            })
            fi = imp_df.sort_values(by="Importance", ascending=False).head(10)
            fig, ax = plt.subplots(figsize=(10, 8))
            sns.barplot(data=fi, x='Importance', y='Feature Name', ax=ax)
            ax.set_title('Top 10 Feature Importance (Random Forest Regressor)', fontsize=18)
            plot10_path = "static/rf_regressor.png"
            plt.savefig(plot10_path)
            plt.close(fig)
            response10 = format_text(generate_gemini_response(plot10_path))

        elif method == "Classification":
            rf = RandomForestClassifier()
            param_grid = {
                'max_depth': [3, 4, 5, 6],
                'random_state': [0, 42]
            }
            grid_search = GridSearchCV(rf, param_grid, cv=5)
            grid_search.fit(X_train, y_train)
            best_params = grid_search.best_params_
            rf = RandomForestClassifier(**best_params)
            rf.fit(X_train, y_train)

            y_pred = rf.predict(X_test)
            acc = metrics.accuracy_score(y_test, y_pred)
            f1 = metrics.f1_score(y_test, y_pred, average='micro')
            prec = metrics.precision_score(y_test, y_pred, average='micro')
            recall = metrics.recall_score(y_test, y_pred, average='micro')

            # Feature importance visualization
            imp_df = pd.DataFrame({
                "Feature Name": X_train.columns,
                "Importance": rf.feature_importances_
            })
            fi = imp_df.sort_values(by="Importance", ascending=False).head(10)
            fig, ax = plt.subplots(figsize=(10, 8))
            sns.barplot(data=fi, x='Importance', y='Feature Name', ax=ax)
            ax.set_title('Top 10 Feature Importance (Random Forest Classifier)', fontsize=18)
            plot11_path = "static/rf_classifier.png"
            plt.savefig(plot11_path)
            plt.close(fig)
            response11 = format_text(generate_gemini_response(plot11_path))

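    # Note: with average='micro' on a single-label task, precision, recall and F1
    # above all reduce to the same value as overall accuracy.
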

    document_analyzed = True

    data = {
        "request": request,
        "response1": response1,
        "response2": response2,
        "response5": response5,
        "response7": response7,
        "plot1_path": plot1_path,
        "plot2_path": plot2_path,
        "plot5_path": plot5_path,
        "plot7_path": plot7_path,
        "show_conversation": document_analyzed,
        "question_responses": question_responses
    }

    # Conditionally include the optional responses and plot paths if they exist
    if response3:
        data["response3"] = response3
    if plot3_path:
        data["plot3_path"] = plot3_path
    if response4:
        data["response4"] = response4
    if plot4_path:
        data["plot4_path"] = plot4_path
    if response6:
        data["response6"] = response6
    if plot6_path:
        data["plot6_path"] = plot6_path
    if response8:
        data["response8"] = response8
    if plot8_path:
        data["plot8_path"] = plot8_path
    if response9:
        data["response9"] = response9
    if plot9_path:
        data["plot9_path"] = plot9_path
    if response10:
        data["response10"] = response10
    if plot10_path:
        data["plot10_path"] = plot10_path
    if response11:
        data["response11"] = response11
    if plot11_path:
        data["plot11_path"] = plot11_path

    return templates.TemplateResponse("upload.html", data)

# Route for asking questions
@app.post("/ask", response_class=HTMLResponse)
async def ask_question(request: Request, question: str = Form(...)):
    global file_extension, question_responses, api
    global plot1_path, plot2_path, plot3_path, plot4_path, plot5_path, plot6_path, plot7_path, plot8_path, plot9_path, plot10_path, plot11_path
    global response1, response2, response3, response4, response5, response6, response7, response8, response9, response10, response11
    global document_analyzed

    # Check if a file has been uploaded
    if not file_extension:
        raise HTTPException(status_code=400, detail="No file has been uploaded yet.")
    if not api:
        raise HTTPException(status_code=500, detail="GOOGLE_API_KEY is not set")

    # Initialize the LLM model
    llm = ChatGoogleGenerativeAI(model="gemini-1.5-flash-latest", google_api_key=api)

    # Determine the file extension and select the appropriate loader
    file_path = ''
    loader = None

    if file_extension.endswith('.csv'):
        file_path = 'dataset.csv'
        loader = UnstructuredCSVLoader(file_path, mode="elements")
    elif file_extension.endswith('.xlsx'):
        file_path = 'dataset.xlsx'
        loader = UnstructuredExcelLoader(file_path, mode="elements")
    else:
        raise HTTPException(status_code=400, detail="Unsupported file format")

    # Load and process the document
    try:
        docs = loader.load()
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error loading document: {str(e)}")

    # Combine document text
    text = "\n".join([doc.page_content for doc in docs])
    os.environ["GOOGLE_API_KEY"] = api

    # Initialize embeddings and create FAISS vector store
    embeddings = GoogleGenerativeAIEmbeddings(model="models/embedding-001")
    text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
    chunks = text_splitter.split_text(text)
    document_search = FAISS.from_texts(chunks, embeddings)

    # Generate query embedding and perform similarity search
    query_embedding = embeddings.embed_query(question)
    results = document_search.similarity_search_by_vector(query_embedding, k=3)

    if results:
        retrieved_texts = " ".join([result.page_content for result in results])

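        # At this point retrieved_texts holds the top-3 matching chunks; the chain
        # below answers from the full text, the answer is re-indexed in FAISS, and
        # the closest match to the question is returned as the response.
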
        # Define the Summarize Chain for the question
        latest_response = "" if not question_responses else question_responses[-1][1]
        template1 = (
            f"{question} Answer the question based on the following:\n\"{text}\"\n:" +
            (f" Answer the Question with only 3 sentences. Latest conversation: {latest_response}" if latest_response else "")
        )
        prompt1 = PromptTemplate.from_template(template1)

        # Initialize the LLMChain with the prompt
        llm_chain1 = LLMChain(llm=llm, prompt=prompt1)

        # Invoke the chain to get the summary
        try:
            response_chain = llm_chain1.invoke({"text": text})
            summary1 = response_chain["text"]
        except Exception as e:
            raise HTTPException(status_code=500, detail=f"Error invoking LLMChain: {str(e)}")

        # Generate embeddings for the summary
        try:
            summary_embedding = embeddings.embed_query(summary1)
            document_search = FAISS.from_texts([summary1], embeddings)
        except Exception as e:
            raise HTTPException(status_code=500, detail=f"Error generating embeddings: {str(e)}")

        # Perform a search on the FAISS vector database
        try:
            if document_search:
                query_embedding = embeddings.embed_query(question)
                results = document_search.similarity_search_by_vector(query_embedding, k=1)

                if results:
                    current_response = format_text(results[0].page_content)
                else:
                    current_response = "No matching document found in the database."
            else:
                current_response = "Vector database not initialized."
        except Exception as e:
            raise HTTPException(status_code=500, detail=f"Error during similarity search: {str(e)}")
    else:
        current_response = "No relevant results found."

    # Append the question and response from the FAISS search
    current_question = f"You asked: {question}"
    question_responses.append((current_question, current_response))

    # Save all results to output_summary.json
    save_to_json(question_responses)

    data = {
        "request": request,
        "response1": response1,
        "response2": response2,
        "response5": response5,
        "response7": response7,
        "plot1_path": plot1_path,
        "plot2_path": plot2_path,
        "plot5_path": plot5_path,
        "plot7_path": plot7_path,
        "show_conversation": True,
        "question_responses": question_responses
    }

    # Conditionally include the optional responses and plot paths if they exist
    if response3:
        data["response3"] = response3
    if plot3_path:
        data["plot3_path"] = plot3_path
    if response4:
        data["response4"] = response4
    if plot4_path:
        data["plot4_path"] = plot4_path
    if response6:
        data["response6"] = response6
    if plot6_path:
        data["plot6_path"] = plot6_path
    if response8:
        data["response8"] = response8
    if plot8_path:
        data["plot8_path"] = plot8_path
    if response9:
        data["response9"] = response9
    if plot9_path:
        data["plot9_path"] = plot9_path
    if response10:
        data["response10"] = response10
    if plot10_path:
        data["plot10_path"] = plot10_path
    if response11:
        data["response11"] = response11
    if plot11_path:
        data["plot11_path"] = plot11_path

    return templates.TemplateResponse("upload.html", data)


def save_to_json(question_responses):
    outputs = {
        "question_responses": question_responses
    }
    with open("output_summary.json", "w") as outfile:
        json.dump(outputs, outfile)


if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="127.0.0.1", port=8000)
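
# To run locally (assumes GOOGLE_API_KEY is exported in the environment):
#   uvicorn webapp:app --host 127.0.0.1 --port 8000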
 
1
+ import pandas as pd
2
+ import seaborn as sns
3
+ import matplotlib
4
+ import matplotlib.pyplot as plt
5
+ matplotlib.use('Agg')
6
+ import numpy as np
7
+ import google.generativeai as genai
8
+ from PIL import Image
9
+ from werkzeug.utils import secure_filename
10
+ import os
11
+ import json
12
+ from fpdf import FPDF
13
+ from fastapi import FastAPI, File, UploadFile, Form, HTTPException
14
+ from fastapi.responses import HTMLResponse, FileResponse
15
+ from fastapi.staticfiles import StaticFiles
16
+ from fastapi.templating import Jinja2Templates
17
+ from starlette.requests import Request
18
+ from typing import List
19
+ import textwrap
20
+ from IPython.display import display, Markdown
21
+ from PIL import Image
22
+ import shutil
23
+ from werkzeug.utils import secure_filename
24
+ import urllib.parse
25
+ import re
26
+ from langchain_google_genai import ChatGoogleGenerativeAI
27
+ from langchain_community.document_loaders import PyPDFLoader, UnstructuredCSVLoader, UnstructuredExcelLoader, Docx2txtLoader, UnstructuredPowerPointLoader
28
+ from langchain.chains import StuffDocumentsChain
29
+ from langchain.chains.llm import LLMChain
30
+ from langchain.prompts import PromptTemplate
31
+ from langchain.vectorstores import FAISS
32
+ from langchain_google_genai import GoogleGenerativeAIEmbeddings
33
+ from langchain.text_splitter import CharacterTextSplitter
34
+
35
+ app = FastAPI()
36
+ app.mount("/static", StaticFiles(directory="static"), name="static")
37
+ templates = Jinja2Templates(directory="templates")
38
+
39
+ sns.set_theme(color_codes=True)
40
+ uploaded_df = None
41
+ document_analyzed = False
42
+ question_responses = []
43
+
44
+
45
+ def format_text(text):
46
+ # Replace **text** with <b>text</b>
47
+ text = re.sub(r'\*\*(.*?)\*\*', r'<b>\1</b>', text)
48
+ # Replace any remaining * with <br>
49
+ text = text.replace('*', '<br>')
50
+ return text
51
+
52
+ def clean_data(df):
53
+ # Step 1: Clean currency-related columns
54
+ for col in df.columns:
55
+ if any(x in col.lower() for x in ['value', 'price', 'cost', 'amount']):
56
+ if df[col].dtype == 'object':
57
+ df[col] = df[col].str.replace('$', '').str.replace('£', '').str.replace('€', '').replace('[^\d.-]', '', regex=True).astype(float)
58
+
59
+ # Step 2: Drop columns with more than 25% missing values
60
+ null_percentage = df.isnull().sum() / len(df)
61
+ columns_to_drop = null_percentage[null_percentage > 0.25].index
62
+ df.drop(columns=columns_to_drop, inplace=True)
63
+
64
+ # Step 3: Fill missing values for remaining columns
65
+ for col in df.columns:
66
+ if df[col].isnull().sum() > 0:
67
+ if null_percentage[col] <= 0.25:
68
+ if df[col].dtype in ['float64', 'int64']:
69
+ median_value = df[col].median()
70
+ df[col].fillna(median_value, inplace=True)
71
+
72
+ # Step 4: Convert object-type columns to lowercase
73
+ for col in df.columns:
74
+ if df[col].dtype == 'object':
75
+ df[col] = df[col].str.lower()
76
+
77
+ # Step 5: Drop columns with only one unique value
78
+ unique_value_columns = [col for col in df.columns if df[col].nunique() == 1]
79
+ df.drop(columns=unique_value_columns, inplace=True)
80
+
81
+ return df
82
+
83
+
84
+
85
+
86
+ def clean_data2(df):
87
+ for col in df.columns:
88
+ if 'value' in col or 'price' in col or 'cost' in col or 'amount' in col or 'Value' in col or 'Price' in col or 'Cost' in col or 'Amount' in col:
89
+ if df[col].dtype == 'object':
90
+ df[col] = df[col].str.replace('$', '')
91
+ df[col] = df[col].str.replace('£', '')
92
+ df[col] = df[col].str.replace('€', '')
93
+ df[col] = df[col].replace('[^\d.-]', '', regex=True).astype(float)
94
+
95
+ null_percentage = df.isnull().sum() / len(df)
96
+
97
+ for col in df.columns:
98
+ if df[col].isnull().sum() > 0:
99
+ if null_percentage[col] <= 0.25:
100
+ if df[col].dtype in ['float64', 'int64']:
101
+ median_value = df[col].median()
102
+ df[col].fillna(median_value, inplace=True)
103
+
104
+ for col in df.columns:
105
+ if df[col].dtype == 'object':
106
+ df[col] = df[col].str.lower()
107
+
108
+ return df
109
+
110
+
111
+
112
+ def generate_plot(df, plot_path, plot_type):
113
+ df = clean_data(df)
114
+ excluded_words = ["name", "postal", "date", "phone", "address", "code", "id"]
115
+
116
+ if plot_type == 'countplot':
117
+ cat_vars = [col for col in df.select_dtypes(include='object').columns
118
+ if all(word not in col.lower() for word in excluded_words) and df[col].nunique() > 1]
119
+
120
+ for col in cat_vars:
121
+ if df[col].nunique() > 10:
122
+ top_categories = df[col].value_counts().index[:10]
123
+ df[col] = df[col].apply(lambda x: x if x in top_categories else 'Other')
124
+
125
+ num_cols = len(cat_vars)
126
+ num_rows = (num_cols + 1) // 2
127
+ fig, axs = plt.subplots(nrows=num_rows, ncols=2, figsize=(15, 5*num_rows))
128
+ axs = axs.flatten()
129
+
130
+ for i, var in enumerate(cat_vars):
131
+ category_counts = df[var].value_counts()
132
+ top_values = category_counts.index[:10][::-1]
133
+ filtered_df = df.copy()
134
+ filtered_df[var] = pd.Categorical(filtered_df[var], categories=top_values, ordered=True)
135
+ sns.countplot(x=var, data=filtered_df, order=top_values, ax=axs[i])
136
+ axs[i].set_title(var)
137
+ axs[i].tick_params(axis='x', rotation=30)
138
+
139
+ total = len(filtered_df[var])
140
+ for p in axs[i].patches:
141
+ height = p.get_height()
142
+ axs[i].annotate(f'{height/total:.1%}', (p.get_x() + p.get_width() / 2., height), ha='center', va='bottom')
143
+
144
+ sample_size = filtered_df.shape[0]
145
+
146
+
147
+ for i in range(num_cols, len(axs)):
148
+ fig.delaxes(axs[i])
149
+
150
+ elif plot_type == 'histplot':
151
+ num_vars = [col for col in df.select_dtypes(include=['int', 'float']).columns
152
+ if all(word not in col.lower() for word in excluded_words)]
153
+ num_cols = len(num_vars)
154
+ num_rows = (num_cols + 2) // 3
155
+ fig, axs = plt.subplots(nrows=num_rows, ncols=min(3, num_cols), figsize=(15, 5*num_rows))
156
+ axs = axs.flatten()
157
+
158
+ plot_index = 0
159
+
160
+ for i, var in enumerate(num_vars):
161
+ if len(df[var].unique()) == len(df):
162
+ fig.delaxes(axs[plot_index])
163
+ else:
164
+ sns.histplot(df[var], ax=axs[plot_index], kde=True, stat="percent")
165
+ axs[plot_index].set_title(var)
166
+ axs[plot_index].set_xlabel('')
167
+
168
+ sample_size = df.shape[0]
169
+
170
+
171
+ plot_index += 1
172
+
173
+ for i in range(plot_index, len(axs)):
174
+ fig.delaxes(axs[i])
175
+
176
+ fig.tight_layout()
177
+ fig.savefig(plot_path)
178
+ plt.close(fig)
179
+ return plot_path
180
+
181
+ @app.get("/", response_class=HTMLResponse)
182
+ async def read_form(request: Request):
183
+ return templates.TemplateResponse("upload.html", {"request": request})
184
+
185
+ @app.post("/process/", response_class=HTMLResponse)
186
+ async def process_file(request: Request, file: UploadFile = File(...)):
187
+ global df, uploaded_file, document_analyzed, file_path, file_extension
188
+ uploaded_file = file
189
+ file_location = f"static/{file.filename}"
190
+
191
+ # Save the uploaded file to the server
192
+ with open(file_location, "wb") as buffer:
193
+ shutil.copyfileobj(file.file, buffer)
194
+
195
+ # Load DataFrame based on file type
196
+ file_extension = os.path.splitext(file.filename)[1]
197
+ if file_extension == '.csv':
198
+ file_path = 'dataset.csv'
199
+ df = pd.read_csv(file_location, delimiter=",")
200
+ df.to_csv(file_path, index=False) # Save as dataset.csv
201
+ elif file_extension == '.xlsx':
202
+ file_path = 'dataset.xlsx'
203
+ df = pd.read_excel(file_location)
204
+ df.to_excel(file_path, index=False) # Save as dataset.xlsx
205
+ else:
206
+ raise HTTPException(status_code=415, detail="Unsupported file format")
207
+
208
+ # Get columns of the DataFrame
209
+ columns = df.columns.tolist()
210
+
211
+ return templates.TemplateResponse("upload.html", {"request": request, "columns": columns})
212
+
213
+
214
+ @app.post("/result")
215
+ async def result(request: Request,
216
+ target: str = Form(...),
217
+ algorithm: str = Form(...)):
218
+ global df, api
219
+ global plot1_path, plot2_path, plot3_path, plot4_path, plot5_path, plot6_path, plot7_path, plot8_path, plot9_path, plot10_path, plot11_path
220
+ global response1, response2, response3, response4, response5, response6, response7, response8, response9, response10, response11
221
+
222
+
223
+ api = "AIzaSyCFI6cTqFdS-mpZBfi7kxwygewtnuF7PfA"
224
+ excluded_words = ["name", "postal", "date", "phone", "address", "id"]
225
+
226
+ if df[target].dtype in ['float64', 'int64']:
227
+ unique_values = df[target].nunique()
228
+
229
+ # If unique values > 20, treat it as regression, else classification
230
+ if unique_values > 20:
231
+ method = "Regression"
232
+ else:
233
+ method = "Classification"
234
+ else:
235
+ # If the target is not numeric, treat it as classification
236
+ method = "Classification"
237
+
238
+
239
+
240
+ # Initialize response3 and plot3_path to None
241
+ response3 = None
242
+ plot3_path = None
243
+ response4 = None
244
+ plot4_path = None
245
+ response6 = None
246
+ plot6_path = None
247
+ response8 = None # Initialize response8
248
+ plot8_path = None # Initialize plot8_path
249
+ response9 = None # Initialize response9
250
+ plot9_path = None # Initialize plot9_path
251
+ response10 = None # Initialize response8
252
+ plot10_path = None # Initialize plot8_path
253
+ response11 = None # Initialize response9
254
+ plot11_path = None # Initialize plot9_path
255
+
256
+ if method == "Classification":
257
+ cat_vars = [col for col in df.select_dtypes(include=['object']).columns
258
+ if all(word not in col.lower() for word in excluded_words)]
259
+
260
+ # Exclude the target variable from the list if it exists in cat_vars
261
+ if target in cat_vars:
262
+ cat_vars.remove(target)
263
+
264
+ # Create a figure with subplots, but only include the required number of subplots
265
+ num_cols = len(cat_vars)
266
+ num_rows = (num_cols + 2) // 3 # To make sure there are enough rows for the subplots
267
+ fig, axs = plt.subplots(nrows=num_rows, ncols=3, figsize=(15, 5*num_rows))
268
+ axs = axs.flatten()
269
+
270
+ # Create a count plot for each categorical variable
271
+ for i, var in enumerate(cat_vars):
272
+ top_categories = df[var].value_counts().nlargest(5).index
273
+ filtered_df = df[df[var].notnull() & df[var].isin(top_categories)] # Exclude rows with NaN values in the variable
274
+
275
+ # Replace less frequent categories with "Other" if there are more than 5 unique values
276
+ if df[var].nunique() > 5:
277
+ other_categories = df[var].value_counts().index[5:]
278
+ filtered_df[var] = filtered_df[var].apply(lambda x: x if x in top_categories else 'Other')
279
+
280
+ sns.countplot(x=var, hue=target, stat="percent", data=filtered_df, ax=axs[i])
281
+ axs[i].set_xticklabels(axs[i].get_xticklabels(), rotation=45)
282
+
283
+ # Change y-axis label to represent percentage
284
+ axs[i].set_ylabel('Percentage')
285
+
286
+ # Annotate the subplot with sample size
287
+ sample_size = df.shape[0]
288
+ axs[i].annotate(f'Sample Size = {sample_size}', xy=(0.5, 0.9), xycoords='axes fraction', ha='center', va='center')
289
+
290
+ # Remove any remaining blank subplots
291
+ for i in range(num_cols, len(axs)):
292
+ fig.delaxes(axs[i])
293
+
294
+ plt.xticks(rotation=45)
295
+ plt.tight_layout()
296
+ plot3_path = "static/multiclass_barplot.png"
297
+ plt.savefig(plot3_path)
298
+ plt.close(fig)
299
+
300
+ #response 3
301
+ def to_markdown(text):
302
+ text = text.replace('•', ' *')
303
+ return Markdown(textwrap.indent(text, '> ', predicate=lambda _: True))
304
+
305
+ genai.configure(api_key=api)
306
+
307
+ import PIL.Image
308
+
309
+ img = PIL.Image.open("static/multiclass_barplot.png")
310
+ model = genai.GenerativeModel('gemini-1.5-flash-latest')
311
+ #response = model.generate_content(img)
312
+ response = model.generate_content(["As a marketing consulant, I want to understand consumer insighst based on the chart and the market context so I can use the key findings to formulate actionable insights", img])
313
+ response.resolve()
314
+ response3 = format_text(response.text)
315
+
316
+
317
+ if method == "Classification":
318
+ # Generate Multiclass Pairplot
319
+ pairplot_fig = sns.pairplot(df, hue=target)
320
+ plot6_path = "static/pair1.png" # Use plot6_path
321
+ pairplot_fig.savefig(plot6_path) # Save the pairplot as a PNG file
322
+
323
+
324
+ # Google Gemini Integration
325
+ genai.configure(api_key=api)
326
+ img = PIL.Image.open(plot6_path)
327
+ model = genai.GenerativeModel('gemini-1.5-flash-latest')
328
+
329
+ # Generate response based on the pairplot
330
+ response = model.generate_content([
331
+ "You are a professional Data Analyst, write the complete conclusion and actionable insight based on the image. Explain it by points.",
332
+ img
333
+ ])
334
+ response.resolve()
335
+
336
+ # Assign the response to response6
337
+ response6 = format_text(response.text)
338
+
339
+ # Include response6 and plot6_path in the data dictionary to be passed to the template
340
+
341
+
342
+ if method == "Classification":
343
+ # Multiclass Histplot
344
+ # Get the names of all columns with data type 'object' (categorical columns)
345
+ cat_cols = df.columns.tolist()
346
+
347
+ # Get the names of all columns with data type 'int'
348
+ int_vars = df.select_dtypes(include=['int', 'float']).columns.tolist()
349
+ int_vars = [col for col in int_vars if col != target]
350
+
351
+ # Create a figure with subplots
352
+ num_cols = len(int_vars)
353
+ num_rows = (num_cols + 2) // 3 # To make sure there are enough rows for the subplots
354
+ fig, axs = plt.subplots(nrows=num_rows, ncols=3, figsize=(15, 5*num_rows))
355
+ axs = axs.flatten()
356
+
357
+ # Create a histogram for each integer variable with hue='Attrition'
358
+ for i, var in enumerate(int_vars):
359
+ top_categories = df[var].value_counts().nlargest(10).index
360
+ filtered_df = df[df[var].notnull() & df[var].isin(top_categories)]
361
+ sns.histplot(data=df, x=var, hue=target, kde=True, ax=axs[i], stat="percent")
362
+ axs[i].set_title(var)
363
+
364
+ # Annotate the subplot with sample size
365
+ sample_size = df.shape[0]
366
+ axs[i].annotate(f'Sample Size = {sample_size}', xy=(0.5, 0.9), xycoords='axes fraction', ha='center', va='center')
367
+
368
+ # Remove any extra empty subplots if needed
369
+ if num_cols < len(axs):
370
+ for i in range(num_cols, len(axs)):
371
+ fig.delaxes(axs[i])
372
+
373
+ # Adjust spacing between subplots
374
+ fig.tight_layout()
375
+ plt.xticks(rotation=45)
376
+ plot4_path = "static/multiclass_histplot.png"
377
+ plt.savefig(plot4_path)
378
+ plt.close(fig)
379
+
380
+ #response 4
381
+ def to_markdown(text):
382
+ text = text.replace('•', ' *')
383
+ return Markdown(textwrap.indent(text, '> ', predicate=lambda _: True))
384
+
385
+ genai.configure(api_key=api)
386
+
387
+ import PIL.Image
388
+
389
+ img = PIL.Image.open("static/multiclass_histplot.png")
390
+ model = genai.GenerativeModel('gemini-1.5-flash-latest')
391
+ response4 = model.generate_content(img)
392
+ response4 = model.generate_content(["As a marketing consulant, I want to understand consumer insighst based on the chart and the market context so I can use the key findings to formulate actionable insights", img])
393
+ response4.resolve()
394
+ response4 = format_text(response4.text)
395
+
396
+
397
+
398
+
399
+
400
+ # Generate Pairplot
401
+ pairplot_fig = sns.pairplot(df)
402
+ plot5_path = "static/pair2.png"
403
+ pairplot_fig.savefig(plot5_path) # Save the pairplot as a PNG file
404
+
405
+ # Google Gemini Integration
406
+ genai.configure(api_key=api)
407
+ img = PIL.Image.open(plot5_path)
408
+ model = genai.GenerativeModel('gemini-1.5-flash-latest')
409
+
410
+ # Generate response based on the pairplot
411
+ response = model.generate_content([
412
+ "You are a professional Data Analyst, write the complete conclusion and actionable insight based on the image. Explain it by points.",
413
+ img
414
+ ])
415
+ response.resolve()
416
+
417
+ # Assign the response to response5
418
+ response5 = format_text(response.text)
419
+
420
+ def generate_gemini_response(plot_path):
421
+
422
+
423
+ genai.configure(api_key=api)
424
+ img = Image.open(plot_path)
425
+ model = genai.GenerativeModel('gemini-1.5-flash-latest')
426
+ response = model.generate_content([
427
+ " As a marketing consultant, I want to understand consumer insights based on the chart and the market context so I can use the key findings to formulate actionable insights",
428
+ img
429
+ ])
430
+ response.resolve()
431
+ return response.text
432
+
433
+ plot1_path = generate_plot(df, 'static/plot1.png', 'countplot')
434
+ plot2_path = generate_plot(df, 'static/plot2.png', 'histplot')
435
+
436
+ response1 = format_text((generate_gemini_response(plot1_path)))
437
+ response2 = format_text((generate_gemini_response(plot2_path)))
438
+
439
+ from sklearn import preprocessing
440
+ for col in df.select_dtypes(include=['object']).columns:
441
+
442
+ # Initialize a LabelEncoder object
443
+ label_encoder = preprocessing.LabelEncoder()
444
+
445
+ # Fit the encoder to the unique values in the column
446
+ label_encoder.fit(df[col].unique())
447
+
448
+ # Transform the column using the encoder
449
+ df[col] = label_encoder.transform(df[col])
450
+
451
+
452
+ # Display Correlation Heatmap
453
+ plot7_path = "static/correlation_matrix.png"
454
+ fig, ax = plt.subplots(figsize=(30, 24))
455
+ correlation_matrix = df.corr()
456
+ sns.heatmap(correlation_matrix, annot=True, fmt='.2f', cmap='coolwarm', ax=ax)
457
+ plt.savefig(plot7_path)
458
+ plt.close(fig)
459
+
460
+ img = PIL.Image.open(plot7_path)
461
+ response7 = format_text((generate_gemini_response(plot7_path)))
462
+
463
+
464
+
465
+
466
+
467
+ X = df.drop(target, axis=1)
468
+ y = df[target]
469
+ from sklearn.model_selection import train_test_split
470
+ from sklearn.metrics import accuracy_score
471
+ X_train, X_test, y_train, y_test = train_test_split(X,y, test_size=0.2,random_state=0)
472
+
473
+ from scipy import stats
474
+ threshold = 3
475
+
476
+ for col in X_train.columns:
477
+ if X_train[col].nunique() > 20:
478
+ # Calculate Z-scores for the column
479
+ z_scores = np.abs(stats.zscore(X_train[col]))
480
+ # Find and remove outliers based on the threshold
481
+ outlier_indices = np.where(z_scores > threshold)[0]
482
+ X_train = X_train.drop(X_train.index[outlier_indices])
483
+ y_train = y_train.drop(y_train.index[outlier_indices])
484
+
485
+
486
+
487
+
488
+ from sklearn.tree import DecisionTreeRegressor
489
+ from sklearn.tree import DecisionTreeClassifier
490
+ from sklearn.model_selection import GridSearchCV
491
+ from sklearn import metrics
492
+ from sklearn.metrics import mean_absolute_percentage_error
493
+ import math
494
+
495
+
496
+ if algorithm == "Decision Tree":
497
+
498
+ if method == "Regression":
499
+ dtree = DecisionTreeRegressor()
500
+ param_grid = {
501
+ 'max_depth': [4, 6, 8],
502
+ 'min_samples_split': [4, 6, 8],
503
+ 'min_samples_leaf': [1, 2, 3, 4],
504
+ 'random_state': [0, 42],
505
+ 'max_features': ['auto', 'sqrt', 'log2']
506
+ }
507
+ grid_search = GridSearchCV(dtree, param_grid, cv=5, scoring='neg_mean_squared_error')
508
+ grid_search.fit(X_train, y_train)
509
+ best_params = grid_search.best_params_
510
+ dtree = DecisionTreeRegressor(**best_params)
511
+ dtree.fit(X_train, y_train)
512
+
513
+ y_pred = dtree.predict(X_test)
514
+ mae = metrics.mean_absolute_error(y_test, y_pred)
515
+ mse = metrics.mean_squared_error(y_test, y_pred)
516
+ r2 = metrics.r2_score(y_test, y_pred)
517
+ rmse = np.sqrt(mse)
518
+
519
+ # Feature importance visualization
520
+ imp_df = pd.DataFrame({
521
+ "Feature Name": X_train.columns,
522
+ "Importance": dtree.feature_importances_
523
+ })
524
+ fi = imp_df.sort_values(by="Importance", ascending=False).head(10)
525
+ fig, ax = plt.subplots(figsize=(10, 8))
526
+ sns.barplot(data=fi, x='Importance', y='Feature Name', ax=ax)
527
+ ax.set_title('Top 10 Feature Importance (Decision Tree Regressor)', fontsize=18)
528
+ plot8_path = "static/dtree_regressor.png"
529
+ plt.savefig(plot8_path)
530
+ img = PIL.Image.open(plot8_path)
531
+ response8 = format_text((generate_gemini_response(plot8_path)))
532
+
533
+
534
+ elif method == "Classification":
535
+ dtree = DecisionTreeClassifier()
536
+ param_grid = {
537
+ 'max_depth': [3, 4, 5, 6, 7],
538
+ 'min_samples_split': [2, 3, 4],
539
+ 'min_samples_leaf': [1, 2, 3],
540
+ 'random_state': [0, 42]
541
+ }
542
+ grid_search = GridSearchCV(dtree, param_grid, cv=5)
543
+ grid_search.fit(X_train, y_train)
544
+ best_params = grid_search.best_params_
545
+ dtree = DecisionTreeClassifier(**best_params)
546
+ dtree.fit(X_train, y_train)
547
+
548
+ y_pred = dtree.predict(X_test)
549
+ acc = metrics.accuracy_score(y_test, y_pred)
550
+ f1 = metrics.f1_score(y_test, y_pred, average='micro')
551
+ prec = metrics.precision_score(y_test, y_pred, average='micro')
552
+ recall = metrics.recall_score(y_test, y_pred, average='micro')
553
+
554
+ # Feature importance visualization
555
+ imp_df = pd.DataFrame({
556
+ "Feature Name": X_train.columns,
557
+ "Importance": dtree.feature_importances_
558
+ })
559
+ fi = imp_df.sort_values(by="Importance", ascending=False).head(10)
560
+ fig, ax = plt.subplots(figsize=(10, 8))
561
+ sns.barplot(data=fi, x='Importance', y='Feature Name', ax=ax)
562
+ ax.set_title('Top 10 Feature Importance (Decision Tree Classifier)', fontsize=18)
563
+ plot9_path = "static/dtree_classifier.png"
564
+ plt.savefig(plot9_path)
565
+ img = PIL.Image.open(plot9_path)
566
+ response9 = format_text((generate_gemini_response(plot9_path)))
567
+
568
+
569
+
570
+ from sklearn.ensemble import RandomForestRegressor
571
+ from sklearn.ensemble import RandomForestClassifier
572
+
573
+ if algorithm == "Random Forest":
574
+
575
+ if method == "Regression":
576
+ rf = RandomForestRegressor()
577
+ param_grid = {
578
+ 'max_depth': [4, 6, 8],
579
+ 'random_state': [0, 42],
580
+ 'max_features': ['auto', 'sqrt', 'log2']
581
+ }
582
+ grid_search = GridSearchCV(rf, param_grid, cv=5, scoring='neg_mean_squared_error')
583
+ grid_search.fit(X_train, y_train)
584
+ best_params = grid_search.best_params_
585
+ rf = RandomForestRegressor(**best_params)
586
+ rf.fit(X_train, y_train)
587
+
588
+ y_pred = rf.predict(X_test)
589
+ mae = metrics.mean_absolute_error(y_test, y_pred)
590
+ mse = metrics.mean_squared_error(y_test, y_pred)
591
+ r2 = metrics.r2_score(y_test, y_pred)
592
+ rmse = np.sqrt(mse)
593
+
594
+ # Feature importance visualization
595
+ imp_df = pd.DataFrame({
596
+ "Feature Name": X_train.columns,
597
+ "Importance": rf.feature_importances_
598
+ })
599
+ fi = imp_df.sort_values(by="Importance", ascending=False).head(10)
600
+ fig, ax = plt.subplots(figsize=(10, 8))
601
+ sns.barplot(data=fi, x='Importance', y='Feature Name', ax=ax)
602
+ ax.set_title('Top 10 Feature Importance (Random Forest Regressor)', fontsize=18)
603
+ plot10_path = "static/rf_regressor.png"
604
+ plt.savefig(plot10_path)
605
+ img = PIL.Image.open(plot10_path)
606
+ response10 = format_text((generate_gemini_response(plot10_path)))
607
+
608
+ elif method == "Classification":
609
+ rf = RandomForestClassifier()
610
+ param_grid = {
611
+ 'max_depth': [3, 4, 5, 6],
612
+ 'random_state': [0, 42]
613
+ }
614
+ grid_search = GridSearchCV(rf, param_grid, cv=5)
615
+ grid_search.fit(X_train, y_train)
616
+ best_params = grid_search.best_params_
617
+ rf = RandomForestClassifier(**best_params)
618
+ rf.fit(X_train, y_train)
619
+
620
+ y_pred = rf.predict(X_test)
621
+ acc = metrics.accuracy_score(y_test, y_pred)
622
+ f1 = metrics.f1_score(y_test, y_pred, average='micro')
623
+ prec = metrics.precision_score(y_test, y_pred, average='micro')
624
+ recall = metrics.recall_score(y_test, y_pred, average='micro')
625
+
626
+ # Feature importance visualization
627
+ imp_df = pd.DataFrame({
628
+ "Feature Name": X_train.columns,
629
+ "Importance": rf.feature_importances_
630
+ })
631
+ fi = imp_df.sort_values(by="Importance", ascending=False).head(10)
632
+ fig, ax = plt.subplots(figsize=(10, 8))
633
+ sns.barplot(data=fi, x='Importance', y='Feature Name', ax=ax)
634
+ ax.set_title('Top 10 Feature Importance (Random Forest Classifier)', fontsize=18)
635
+ plot11_path = "static/rf_classifier.png"
636
+ plt.savefig(plot11_path)
637
+ img = PIL.Image.open(plot11_path)
638
+ response11 = format_text((generate_gemini_response(plot11_path)))
639
+
640
+
641
+
642
+ document_analyzed = True
643
+
644
+
645
+
646
+    data = {
+        "request": request,
+        "response1": response1,
+        "response2": response2,
+        "response5": response5,
+        "response7": response7,
+        "plot1_path": plot1_path,
+        "plot2_path": plot2_path,
+        "plot5_path": plot5_path,
+        "plot7_path": plot7_path,
+        "show_conversation": document_analyzed,
+        "question_responses": question_responses
+    }
+
+    # Conditionally include the optional responses and plot paths when they exist
+    if response3:
+        data["response3"] = response3
+    if plot3_path:
+        data["plot3_path"] = plot3_path
+    if response4:
+        data["response4"] = response4
+    if plot4_path:
+        data["plot4_path"] = plot4_path
+    if response6:
+        data["response6"] = response6
+    if plot6_path:
+        data["plot6_path"] = plot6_path
+    if response8:
+        data["response8"] = response8
+    if plot8_path:
+        data["plot8_path"] = plot8_path
+    if response9:
+        data["response9"] = response9
+    if plot9_path:
+        data["plot9_path"] = plot9_path
+    if response10:
+        data["response10"] = response10
+    if plot10_path:
+        data["plot10_path"] = plot10_path
+    if response11:
+        data["response11"] = response11
+    if plot11_path:
+        data["plot11_path"] = plot11_path
+
+    return templates.TemplateResponse("upload.html", data)
+
+
+# Route for asking questions
+@app.post("/ask", response_class=HTMLResponse)
+async def ask_question(request: Request, question: str = Form(...)):
+    global file_extension, question_responses, api
+    global plot1_path, plot2_path, plot3_path, plot4_path, plot5_path, plot6_path, plot7_path, plot8_path, plot9_path, plot10_path, plot11_path
+    global response1, response2, response3, response4, response5, response6, response7, response8, response9, response10, response11
+    global document_analyzed
+
+    # Check that a file has been uploaded
+    if not file_extension:
+        raise HTTPException(status_code=400, detail="No file has been uploaded yet.")
+
+    # Initialize the LLM
+    llm = ChatGoogleGenerativeAI(model="gemini-1.5-flash-latest", google_api_key=api)
+
+    # Pick the loader that matches the uploaded file's extension
+    file_path = ''
+    loader = None
+
+    if file_extension.endswith('.csv'):
+        file_path = 'dataset.csv'
+        loader = UnstructuredCSVLoader(file_path, mode="elements")
+    elif file_extension.endswith('.xlsx'):
+        file_path = 'dataset.xlsx'
+        loader = UnstructuredExcelLoader(file_path, mode="elements")
+    else:
+        raise HTTPException(status_code=400, detail="Unsupported file format")
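+    # NOTE: with mode="elements", the Unstructured loaders return one Document per parsed
+    # element rather than a single Document for the whole file, which is why the texts are
+    # re-joined below before chunking.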
+
+    # Load and process the document
+    try:
+        docs = loader.load()
+    except Exception as e:
+        raise HTTPException(status_code=500, detail=f"Error loading document: {str(e)}")
+
+    # Combine the document text
+    text = "\n".join([doc.page_content for doc in docs])
+    os.environ["GOOGLE_API_KEY"] = api
+
+    # Initialize embeddings and build a FAISS vector store over the chunked text
+    embeddings = GoogleGenerativeAIEmbeddings(model="models/embedding-001")
+    text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
+    chunks = text_splitter.split_text(text)
+    document_search = FAISS.from_texts(chunks, embeddings)
+
+    # Embed the question and retrieve the three closest chunks
+    query_embedding = embeddings.embed_query(question)
+    results = document_search.similarity_search_by_vector(query_embedding, k=3)
+
+    if results:
+        retrieved_texts = " ".join([result.page_content for result in results])
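+        # NOTE: retrieved_texts is currently unused; the prompt below is built from the full
+        # document text. Substituting retrieved_texts for `text` in the template would give a
+        # conventional retrieval-augmented prompt and keep the context window small.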
+
+        # Build the question prompt, carrying over the latest exchange for continuity
+        latest_response = "" if not question_responses else question_responses[-1][1]
+        template1 = (
+            f"{question} Answer the question based on the following:\n\"{text}\"\n:" +
+            (f" Answer the question in only 3 sentences. Latest conversation: {latest_response}" if latest_response else "")
+        )
+        prompt1 = PromptTemplate.from_template(template1)
+
+        # Initialize the LLMChain with the prompt
+        llm_chain1 = LLMChain(llm=llm, prompt=prompt1)
+
+        # Invoke the chain to get the answer
+        try:
+            response_chain = llm_chain1.invoke({"text": text})
+            summary1 = response_chain["text"]
+        except Exception as e:
+            raise HTTPException(status_code=500, detail=f"Error invoking LLMChain: {str(e)}")
+
+        # Embed the generated answer and index it in a fresh FAISS store
+        try:
+            summary_embedding = embeddings.embed_query(summary1)  # computed but not used below
+            document_search = FAISS.from_texts([summary1], embeddings)
+        except Exception as e:
+            raise HTTPException(status_code=500, detail=f"Error generating embeddings: {str(e)}")
+
+        # Perform a search on the FAISS vector database
+        try:
+            if document_search:
+                query_embedding = embeddings.embed_query(question)
+                results = document_search.similarity_search_by_vector(query_embedding, k=1)
+
+                if results:
+                    current_response = format_text(results[0].page_content)
+                else:
+                    current_response = "No matching document found in the database."
+            else:
+                current_response = "Vector database not initialized."
+        except Exception as e:
+            raise HTTPException(status_code=500, detail=f"Error during similarity search: {str(e)}")
+    else:
+        current_response = "No relevant results found."
+
+    # Record the question together with the response from the FAISS search
+    current_question = f"You asked: {question}"
+    question_responses.append((current_question, current_response))
+
+    # Save all results to output_summary.json
+    save_to_json(question_responses)
+
+    data = {
+        "request": request,
+        "response1": response1,
+        "response2": response2,
+        "response5": response5,
+        "response7": response7,
+        "plot1_path": plot1_path,
+        "plot2_path": plot2_path,
+        "plot5_path": plot5_path,
+        "plot7_path": plot7_path,
+        "show_conversation": True,
+        "question_responses": question_responses
+    }
+
+    # Conditionally include the optional responses and plot paths when they exist
+    if response3:
+        data["response3"] = response3
+    if plot3_path:
+        data["plot3_path"] = plot3_path
+    if response4:
+        data["response4"] = response4
+    if plot4_path:
+        data["plot4_path"] = plot4_path
+    if response6:
+        data["response6"] = response6
+    if plot6_path:
+        data["plot6_path"] = plot6_path
+    if response8:
+        data["response8"] = response8
+    if plot8_path:
+        data["plot8_path"] = plot8_path
+    if response9:
+        data["response9"] = response9
+    if plot9_path:
+        data["plot9_path"] = plot9_path
+    if response10:
+        data["response10"] = response10
+    if plot10_path:
+        data["plot10_path"] = plot10_path
+    if response11:
+        data["response11"] = response11
+    if plot11_path:
+        data["plot11_path"] = plot11_path
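+    # NOTE: this context-building block duplicates the one in the upload route; both could
+    # call a shared helper (e.g. a hypothetical build_template_context()) so the two stay
+    # in sync.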
+
+    return templates.TemplateResponse("upload.html", data)
+
+
+def save_to_json(question_responses):
+    outputs = {
+        "question_responses": question_responses
+    }
+    with open("output_summary.json", "w") as outfile:
+        json.dump(outputs, outfile)
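+    # NOTE: json.dump serializes the (question, response) tuples as JSON arrays, and mode "w"
+    # rewrites the whole history on every call rather than appending.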
+