Ipshitaa committed on
Commit
ec2628d
·
1 Parent(s): 24c00be

updated main.py

Browse files
Files changed (3) hide show
  1. endpoint.py +50 -216
  2. main.py +17 -1
  3. requirements.txt +2 -0
endpoint.py CHANGED
@@ -1,231 +1,65 @@
1
- from flask import Flask, request, jsonify
2
- import os
3
- from dotenv import load_dotenv
4
- from llama_index.core import VectorStoreIndex, Settings, StorageContext, load_index_from_storage
5
- from llama_index.embeddings.huggingface import HuggingFaceEmbedding
6
  from llama_index.llms.groq import Groq
7
- import pandas as pd
8
- from llama_index.core import Document
9
-
10
- app = Flask(__name__)
11
-
12
- # --- Configuration ---
13
- PERSIST_DIR = "./storage"
14
- EMBED_MODEL = "sentence-transformers/all-MiniLM-L6-v2"
15
- LLM_MODEL = "llama3-8b-8192"
16
- CSV_FILE_PATH = "shl_assessments.csv"
17
-
18
- # --- Root Route (for health check) ---
19
- @app.route("/", methods=["GET"])
20
- def home():
21
- return "🧠 SHL Chatbot API is running!", 200
22
-
23
- # --- Utility Functions ---
24
- def load_groq_llm():
25
- load_dotenv()
26
- api_key = os.getenv("GROQ_API_KEY")
27
- if not api_key:
28
- raise ValueError("GROQ_API_KEY not found in .env file or environment variables")
29
- return Groq(model=LLM_MODEL, api_key=api_key, temperature=0.1)
30
-
31
- def load_embeddings():
32
- return HuggingFaceEmbedding(model_name=EMBED_MODEL)
33
-
34
- def load_data_from_csv(csv_path):
35
- try:
36
- df = pd.read_csv(csv_path)
37
- required_columns = ["Assessment Name", "URL", "Remote Testing Support",
38
- "Adaptive/IRT Support", "Duration (min)", "Test Type"]
39
- if not all(col in df.columns for col in required_columns):
40
- raise ValueError(f"CSV file must contain columns: {', '.join(required_columns)}")
41
- return df.to_dict(orient="records")
42
- except FileNotFoundError:
43
- raise FileNotFoundError(f"CSV file not found at {csv_path}")
44
- except Exception as e:
45
- raise Exception(f"Error reading CSV: {e}")
46
-
47
- def build_index(data):
48
- Settings.embed_model = load_embeddings()
49
- Settings.llm = load_groq_llm()
50
- documents = [
51
- Document(text=f"Name: {item['Assessment Name']}, URL: {item['URL']}, Remote Testing: {item['Remote Testing Support']}, Adaptive/IRT: {item['Adaptive/IRT Support']}, Duration: {item['Duration (min)']}, Type: {item['Test Type']}")
52
- for item in data
53
- ]
54
- index = VectorStoreIndex.from_documents(documents)
55
- index.storage_context.persist(persist_dir=PERSIST_DIR)
56
- return index
57
-
58
- def load_chat_engine():
59
- if not os.path.exists(PERSIST_DIR):
60
- return None
61
- Settings.embed_model = load_embeddings()
62
- Settings.llm = load_groq_llm()
63
- storage_context = StorageContext.from_defaults(persist_dir=PERSIST_DIR)
64
- index = load_index_from_storage(storage_context)
65
- return index.as_chat_engine(chat_mode="context", verbose=True)
66
-
67
- # --- Load or Build Index ---
68
- try:
69
- chat_engine = load_chat_engine()
70
- if chat_engine is None:
71
- assessment_data = load_data_from_csv(CSV_FILE_PATH)
72
- build_index(assessment_data)
73
- chat_engine = load_chat_engine()
74
- except Exception as e:
75
- print(f"❌ Error initializing chat engine: {e}")
76
- chat_engine = None
77
-
78
- # --- Endpoint ---
79
- @app.route("/assessments", methods=["POST"])
80
- def get_assessments():
81
- data = request.get_json()
82
- query = data.get("query")
83
-
84
- if not query:
85
- return jsonify({"error": "No query provided"}), 400
86
-
87
- if chat_engine:
88
- try:
89
- response = chat_engine.chat(query)
90
- results = []
91
-
92
- for node in response.source_nodes:
93
- try:
94
- parts = node.node.text.split(", ")
95
- results.append({
96
- "assessment_name": parts[0].split(": ")[1] if len(parts) > 0 else "N/A",
97
- "assessment_url": parts[1].split(": ")[1] if len(parts) > 1 else "N/A",
98
- "remote_testing_support": parts[2].split(": ")[1] if len(parts) > 2 else "N/A",
99
- "adaptive_irt_support": parts[3].split(": ")[1] if len(parts) > 3 else "N/A",
100
- "duration": parts[4].split(": ")[1] if len(parts) > 4 else "N/A",
101
- "test_type": parts[5].split(": ")[1] if len(parts) > 5 else "N/A"
102
- })
103
- except:
104
- results.append({"error": "Error parsing assessment info"})
105
-
106
- return jsonify({"query": query, "response": results}), 200
107
-
108
- except Exception as e:
109
- return jsonify({"error": f"Chat processing error: {e}"}), 500
110
- else:
111
- return jsonify({"error": "Chat engine not initialized"}), 500
112
-
113
- # --- Entry Point for Local Debugging ---
114
- if __name__ == "__main__":
115
- app.run(host="0.0.0.0", port=10000)
116
-
117
- from flask import Flask, request, jsonify
118
- import os
119
- from dotenv import load_dotenv
120
- from llama_index.core import VectorStoreIndex, Settings, StorageContext, load_index_from_storage
121
  from llama_index.embeddings.huggingface import HuggingFaceEmbedding
122
- from llama_index.llms.groq import Groq
123
- import pandas as pd
124
- from llama_index.core import Document
125
-
126
- app = Flask(__name__)
127
-
128
- # --- Configuration ---
129
- PERSIST_DIR = "./storage"
130
- EMBED_MODEL = "sentence-transformers/all-MiniLM-L6-v2"
131
- LLM_MODEL = "llama3-8b-8192"
132
- CSV_FILE_PATH = "shl_assessments.csv"
133
-
134
- # --- Root Route (for health check) ---
135
- @app.route("/", methods=["GET"])
136
- def home():
137
- return "🧠 SHL Chatbot API is running!", 200
138
 
139
- # --- Utility Functions ---
140
- def load_groq_llm():
141
- load_dotenv()
142
- api_key = os.getenv("GROQ_API_KEY")
143
- if not api_key:
144
- raise ValueError("GROQ_API_KEY not found in .env file or environment variables")
145
- return Groq(model=LLM_MODEL, api_key=api_key, temperature=0.1)
146
 
147
- def load_embeddings():
148
- return HuggingFaceEmbedding(model_name=EMBED_MODEL)
 
149
 
150
- def load_data_from_csv(csv_path):
151
- try:
152
- df = pd.read_csv(csv_path)
153
- required_columns = ["Assessment Name", "URL", "Remote Testing Support",
154
- "Adaptive/IRT Support", "Duration (min)", "Test Type"]
155
- if not all(col in df.columns for col in required_columns):
156
- raise ValueError(f"CSV file must contain columns: {', '.join(required_columns)}")
157
- return df.to_dict(orient="records")
158
- except FileNotFoundError:
159
- raise FileNotFoundError(f"CSV file not found at {csv_path}")
160
- except Exception as e:
161
- raise Exception(f"Error reading CSV: {e}")
162
 
163
- def build_index(data):
164
- Settings.embed_model = load_embeddings()
165
- Settings.llm = load_groq_llm()
166
- documents = [
167
- Document(text=f"Name: {item['Assessment Name']}, URL: {item['URL']}, Remote Testing: {item['Remote Testing Support']}, Adaptive/IRT: {item['Adaptive/IRT Support']}, Duration: {item['Duration (min)']}, Type: {item['Test Type']}")
168
- for item in data
169
- ]
170
- index = VectorStoreIndex.from_documents(documents)
171
- index.storage_context.persist(persist_dir=PERSIST_DIR)
172
- return index
173
 
174
- def load_chat_engine():
175
- if not os.path.exists(PERSIST_DIR):
176
- return None
177
- Settings.embed_model = load_embeddings()
178
- Settings.llm = load_groq_llm()
179
- storage_context = StorageContext.from_defaults(persist_dir=PERSIST_DIR)
180
- index = load_index_from_storage(storage_context)
181
- return index.as_chat_engine(chat_mode="context", verbose=True)
182
 
183
- # --- Load or Build Index ---
184
- try:
185
- chat_engine = load_chat_engine()
186
- if chat_engine is None:
187
- assessment_data = load_data_from_csv(CSV_FILE_PATH)
188
- build_index(assessment_data)
189
- chat_engine = load_chat_engine()
190
- except Exception as e:
191
- print(f"❌ Error initializing chat engine: {e}")
192
- chat_engine = None
193
 
194
- # --- Endpoint ---
195
- @app.route("/assessments", methods=["POST"])
196
- def get_assessments():
197
- data = request.get_json()
198
- query = data.get("query")
199
 
200
- if not query:
201
- return jsonify({"error": "No query provided"}), 400
 
 
202
 
203
- if chat_engine:
204
- try:
205
- response = chat_engine.chat(query)
206
- results = []
207
 
208
- for node in response.source_nodes:
209
- try:
210
- parts = node.node.text.split(", ")
211
- results.append({
212
- "assessment_name": parts[0].split(": ")[1] if len(parts) > 0 else "N/A",
213
- "assessment_url": parts[1].split(": ")[1] if len(parts) > 1 else "N/A",
214
- "remote_testing_support": parts[2].split(": ")[1] if len(parts) > 2 else "N/A",
215
- "adaptive_irt_support": parts[3].split(": ")[1] if len(parts) > 3 else "N/A",
216
- "duration": parts[4].split(": ")[1] if len(parts) > 4 else "N/A",
217
- "test_type": parts[5].split(": ")[1] if len(parts) > 5 else "N/A"
218
- })
219
- except:
220
- results.append({"error": "Error parsing assessment info"})
221
 
222
- return jsonify({"query": query, "response": results}), 200
 
 
 
 
 
 
 
 
 
 
 
223
 
224
- except Exception as e:
225
- return jsonify({"error": f"Chat processing error: {e}"}), 500
226
- else:
227
- return jsonify({"error": "Chat engine not initialized"}), 500
228
 
229
- # --- Entry Point for Local Debugging ---
230
- if __name__ == "__main__":
231
- app.run(host="0.0.0.0", port=10000)
 
 
 
1
+ # endpoint.py
2
+ from fastapi import FastAPI, Request
3
+ from pydantic import BaseModel
4
+ from llama_index.core import Settings, StorageContext, load_index_from_storage
 
5
  from llama_index.llms.groq import Groq
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6
  from llama_index.embeddings.huggingface import HuggingFaceEmbedding
7
+ import os,json
8
+ from dotenv import load_dotenv
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9
 
10
+ # Load secrets
11
+ load_dotenv()
12
+ GROQ_API_KEY = os.getenv("GROQ_API_KEY")
 
 
 
 
13
 
14
+ # Init LLM and Embedding model
15
+ Settings.llm = Groq(model="llama3-8b-8192", api_key=GROQ_API_KEY)
16
+ Settings.embed_model = HuggingFaceEmbedding(model_name="sentence-transformers/all-MiniLM-L6-v2")
17
 
18
+ # Load index
19
+ PERSIST_DIR = "./storage"
20
+ storage_context = StorageContext.from_defaults(persist_dir=PERSIST_DIR)
21
+ index = load_index_from_storage(storage_context)
22
+ chat_engine = index.as_chat_engine(chat_mode="context", verbose=False)
 
 
 
 
 
 
 
23
 
24
+ app = FastAPI()
 
 
 
 
 
 
 
 
 
25
 
26
+ class QueryRequest(BaseModel):
27
+ question: str
 
 
 
 
 
 
28
 
29
+ class RecommendRequest(BaseModel):
30
+ query: str
 
 
 
 
 
 
 
 
31
 
32
+ @app.get("/health")
33
+ def health_check():
34
+ return {"status": "healthy"}
 
 
35
 
36
+ @app.post("/recommend")
37
+ async def recommend(request: RecommendRequest):
38
+ prompt = f"""
39
+ You are an intelligent assistant that recommends SHL assessments based on user queries.
40
 
41
+ Using the query: "{request.query}", return **all relevant and matching** SHL assessments (at least 1 and up to 10).
 
 
 
42
 
43
+ Only respond in this exact JSON format:
 
 
 
 
 
 
 
 
 
 
 
 
44
 
45
+ {{
46
+ "recommended_assessments": [
47
+ {{
48
+ "url": "Valid URL in string",
49
+ "adaptive_support": "Yes/No",
50
+ "description": "Description in string",
51
+ "duration": 60,
52
+ "remote_support": "Yes/No",
53
+ "test_type": ["List of string"]
54
+ }}
55
+ ]
56
+ }}
57
 
58
+ Do not include any explanations or extra text. Only return pure JSON. Respond with as many matching assessments as possible (up to 10).
59
+ """
 
 
60
 
61
+ response = chat_engine.chat(prompt)
62
+ try:
63
+ return json.loads(response.response)
64
+ except Exception:
65
+ return {"error": "Model response was not valid JSON", "raw": response.response}
main.py CHANGED
@@ -172,7 +172,23 @@ def main():
172
  with st.chat_message("assistant"):
173
  try:
174
  # Add formatting instructions to the prompt
175
- formatted_prompt = f"{prompt}. Please provide a list of all matching SHL assessments (minimum 1, maximum 10). For each assessment, include the following details: Assessment Name: [Name], URL: [URL], Remote Testing Support: [Yes/No], Adaptive/IRT Support: [Yes/No], Duration: [Duration], Test Type: [Type]. If there are no matching assessments, please state that."
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
176
  response = chat_engine.chat(formatted_prompt)
177
  st.markdown(f"<span style='color: white;'>🤖 {response.response}</span>", unsafe_allow_html=True)
178
  st.session_state.messages.append({"role": "assistant", "content": response.response})
 
172
  with st.chat_message("assistant"):
173
  try:
174
  # Add formatting instructions to the prompt
175
+ formatted_prompt = f"""
176
+ {prompt}
177
+
178
+ Please provide a list of all matching SHL assessments (minimum 1, maximum 10).
179
+
180
+ For each matching assessment, follow this exact format:
181
+
182
+ • Assessment Name: [Name]
183
+ URL: [URL]
184
+ Remote Testing Support: [Yes/No]
185
+ Adaptive/IRT Support: [Yes/No]
186
+ Duration: [Duration in minutes]
187
+ Test Type: [Test Type]
188
+
189
+ If there are no matches, clearly state that. Respond in a clean, readable bullet-point format.Do not use any "+" signs. Do not return JSON or markdown tables. Do not bold anything.
190
+
191
+ """
192
  response = chat_engine.chat(formatted_prompt)
193
  st.markdown(f"<span style='color: white;'>🤖 {response.response}</span>", unsafe_allow_html=True)
194
  st.session_state.messages.append({"role": "assistant", "content": response.response})
requirements.txt CHANGED
@@ -10,3 +10,5 @@ groq==0.22.0
10
  streamlit
11
  Flask
12
  gunicorn
 
 
 
10
  streamlit
11
  Flask
12
  gunicorn
13
+ fastapi
14
+ uvicorn