brahmanarisetty committed
Commit 0c61b1f · verified · 1 Parent(s): d6da279

Update app.py

Files changed (1):
  1. app.py  +24 -85

app.py CHANGED
@@ -2,12 +2,11 @@
 """
 IT Support Chatbot Application
 - Converts the original Colab notebook into a deployable Gradio app.
-- Loads data from a local CSV file.
+- Connects to a prebuilt Qdrant index instead of rebuilding it on startup.
 - Uses environment variables for API keys.
 - Implements a RAG pipeline with LLaMA 3.1, Qdrant, and Hybrid Retrieval.
 """
 
-
 # --- CELL 1: Imports, Logging & Reproducibility ---
 import os
 import random
@@ -41,8 +40,7 @@ logging.basicConfig(
 )
 logger = logging.getLogger(__name__)
 
-# Apply nest_asyncio for environments like notebooks
-nest_asyncio.apply()
+nest_asyncio.apply()  # Apply nest_asyncio for environments like notebooks
 
 # Reproducibility
 SEED = 42
@@ -50,12 +48,10 @@ random.seed(SEED)
 np.random.seed(SEED)
 torch.manual_seed(SEED)
 
-# --- CELL 0: load secrets from env vars ---
-QDRANT_HOST = os.getenv("QDRANT_HOST")
+# --- CELL 0: Load secrets from environment variables ---
+QDRANT_HOST = os.getenv("QDRANT_HOST")
 QDRANT_API_KEY = os.getenv("QDRANT_API_KEY")
-HF_TOKEN = os.getenv("HF_TOKEN")
-
-# --- CELL 2: Environment & Qdrant Connection Setup ---
+HF_TOKEN = os.getenv("HF_TOKEN")
 
 if not all([QDRANT_HOST, QDRANT_API_KEY, HF_TOKEN]):
     raise EnvironmentError(
@@ -73,80 +69,48 @@ qdrant = qdrant_client.QdrantClient(
 )
 COLLECTION_NAME = "it_support_rag"
 
-
 # --- CELL 3: Load Dataset & Build Documents ---
-# Load data from a local CSV file.
-# Make sure this CSV file is in the same directory as app.py when deploying.
-CSV_PATH = "data.csv"  # Or whatever you name your CSV file
+CSV_PATH = "data.csv"
 if not os.path.exists(CSV_PATH):
     raise FileNotFoundError(
-        f"The data file was not found at {CSV_PATH}. "
-        "Please upload your data CSV and name it correctly."
+        f"The data file was not found at {CSV_PATH}. Please upload your data CSV and name it correctly."
     )
 
 df = pd.read_csv(CSV_PATH, encoding="ISO-8859-1")
-
 case_docs: List[Document] = []
 for _, row in df.iterrows():
     text = str(row.get("text_chunk", ""))
     meta = {
         "source_dataset": str(row.get("source_dataset", ""))[:50],
-        "category": str(row.get("category", ""))[:100],
-        "orig_query": str(row.get("original_query", ""))[:200],
-        "orig_solution": str(row.get("original_solution", ""))[:200]
+        "category": str(row.get("category", ""))[:100],
+        "orig_query": str(row.get("original_query", ""))[:200],
+        "orig_solution": str(row.get("original_solution", ""))[:200],
     }
     case_docs.append(Document(text=text, metadata=meta))
 logger.info(f"Loaded {len(case_docs)} documents from {CSV_PATH}.")
 
-
-# --- CELL 4: Create Vector Index ---
-# Embedding model
-device = "cuda" if torch.cuda.is_available() else "cpu"
-logger.info(f"Using device: {device}")
-embed_model = HuggingFaceEmbedding(
-    model_name="BAAI/bge-large-en-v1.5",
-    device=device
-)
-
-# Node parser for chunking
-node_parser = SentenceSplitter(
-    chunk_size=1024,
-    chunk_overlap=100,
-    paragraph_separator="\n\n"
-)
-
-# Qdrant-backed vector store
+# --- CELL 4: Load prebuilt Vector Index ---
 vector_store = QdrantVectorStore(
     client=qdrant,
     collection_name=COLLECTION_NAME,
     prefer_grpc=False
 )
-
-# Build the index (will upload to Qdrant if collection doesn't exist)
-# Note: This step can be slow the first time it's run.
-logger.info("Initializing VectorStoreIndex...")
-index = VectorStoreIndex.from_documents(
-    documents=case_docs,
-    storage_context=StorageContext.from_defaults(vector_store=vector_store),
-    embed_model=embed_model,
-    node_parser=node_parser,
-    show_progress=True
-)
-logger.info("VectorStoreIndex initialized successfully.")
-
+storage_context = StorageContext.from_defaults(vector_store=vector_store)
+index = VectorStoreIndex.load_from_storage(storage_context)
+logger.info("✅ Loaded existing VectorStoreIndex from Qdrant")
 
 # --- CELL 5: Define Hybrid Retriever & Reranker ---
-Settings.llm = None  # We will use our own LLM pipeline
+Settings.llm = None  # We will use our own LLM pipeline
 
 class HybridRetriever(BaseRetriever):
     def __init__(self, dense, bm25):
         super().__init__()
         self.dense = dense
         self.bm25 = bm25
+
     def _retrieve(self, query_bundle: QueryBundle) -> List[Document]:
         dense_hits = self.dense.retrieve(query_bundle)
         bm25_hits = self.bm25.retrieve(query_bundle)
-
         combined = dense_hits + bm25_hits
         unique = []
         seen = set()
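
Note on the new CELL 4: `VectorStoreIndex.load_from_storage(storage_context)` does not appear in the documented llama-index API (the core helper is `load_index_from_storage`, and it expects a persisted docstore rather than a bare vector store). For an index that lives entirely in Qdrant, the documented route is `VectorStoreIndex.from_vector_store`. A minimal sketch, assuming the collection was built with the same `BAAI/bge-large-en-v1.5` embedder that the removed build step used:

from llama_index.core import VectorStoreIndex
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
from llama_index.vector_stores.qdrant import QdrantVectorStore

def attach_to_existing_index(client, collection_name):
    """Reattach to a prebuilt Qdrant collection without re-embedding or re-uploading."""
    vector_store = QdrantVectorStore(client=client, collection_name=collection_name)
    # Query-time embeddings must come from the same model used to build the index;
    # BAAI/bge-large-en-v1.5 is an assumption carried over from the removed CELL 4.
    embed_model = HuggingFaceEmbedding(model_name="BAAI/bge-large-en-v1.5")
    return VectorStoreIndex.from_vector_store(vector_store=vector_store, embed_model=embed_model)
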
@@ -159,7 +123,7 @@ class HybridRetriever(BaseRetriever):
 
 # Instantiate retrievers
 dense_retriever = index.as_retriever(similarity_top_k=10)
-bm25_nodes = node_parser.get_nodes_from_documents(case_docs)
+bm25_nodes = SentenceSplitter(chunk_size=1024, chunk_overlap=100).get_nodes_from_documents(case_docs)
 bm25_retriever = BM25Retriever.from_defaults(
     nodes=bm25_nodes,
     similarity_top_k=10,
@@ -169,7 +133,7 @@ hybrid_retriever = HybridRetriever(dense=dense_retriever, bm25=bm25_retriever)
 reranker = SentenceTransformerRerank(
     model="cross-encoder/ms-marco-MiniLM-L-2-v2",
     top_n=4,
-    device=device
+    device="cuda" if torch.cuda.is_available() else "cpu"
 )
 
 query_engine = index.as_query_engine(
@@ -178,7 +142,6 @@ query_engine = index.as_query_engine(
     llm=None
 )
 
-
 # --- CELL 6: Load & Quantize LLaMA Model ---
 quant_config = BitsAndBytesConfig(
     load_in_4bit=True,
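
Note on CELL 6: only the first `quant_config` argument falls inside this hunk's context. For reference, a typical 4-bit NF4 configuration in transformers looks like the sketch below; everything beyond `load_in_4bit=True` is an assumption, not this commit's actual settings.

import torch
from transformers import BitsAndBytesConfig

# Illustrative 4-bit setup; only load_in_4bit=True is visible in the diff above.
quant_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",              # assumed
    bnb_4bit_compute_dtype=torch.bfloat16,  # assumed
    bnb_4bit_use_double_quant=True,         # assumed
)
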
@@ -204,7 +167,6 @@ generator = pipeline(
     device_map="auto"
 )
 
-
 # --- CELL 7: Chat Logic and Prompting ---
 SYSTEM_PROMPT = (
     "You are a friendly and helpful Level 0 IT Support Assistant. "
@@ -230,26 +192,22 @@ GREETINGS = {"hello", "hi", "hey", "good morning", "good afternoon", "good eveni
 
 def format_history(history):
     return "".join(
-        f"{HDR['usr']}\n{u}{HDR['eot']}{HDR['ast']}\n{a}{HDR['eot']}"
-        for u, a in history
+        f"{HDR['usr']}\n{u}{HDR['eot']}{HDR['ast']}\n{a}{HDR['eot']}" for u, a in history
     )
 
 def build_prompt(query, context, history):
     if query.lower().strip() in GREETINGS:
         return None, "greeting"
-
     words = query.strip().split()
     if len(words) < 3:
         return (
             "Could you provide more detail about what you're experiencing? "
             "Any error messages or steps you've tried will help me assist you."
         ), "clarify"
-
     context_str = "\n---\n".join(node.text for node in context) if context else "No context provided."
     hist_str = format_history(history[-3:])
-
     prompt = (
-        f"<|begin_of_text|>"
+        "<|begin_of_text|>"
         f"{HDR['sys']}\n{SYSTEM_PROMPT}{HDR['eot']}"
         f"{hist_str}"
         f"{HDR['usr']}\nContext:\n{context_str}\n\nQuestion: {query}{HDR['eot']}"
@@ -260,22 +218,17 @@ def build_prompt(query, context, history):
 def chat(query, temperature=0.7, top_p=0.9):
     global chat_history
     prompt, mode = build_prompt(query, [], chat_history)
-
     if mode == "greeting":
         reply = "Hello there! How can I help with your IT support question today?"
         chat_history.append((query, reply))
         return reply
-
     if mode == "clarify":
         reply = prompt
         chat_history.append((query, reply))
         return reply
-
     response = query_engine.query(query)
     context_nodes = response.source_nodes
-
     prompt, _ = build_prompt(query, context_nodes, chat_history)
-
     gen_args = {
         "do_sample": True,
         "max_new_tokens": 350,
@@ -283,19 +236,15 @@ def chat(query, temperature=0.7, top_p=0.9):
         "top_p": top_p,
         "eos_token_id": tokenizer.eos_token_id
     }
-
     output = generator(prompt, **gen_args)
     text = output[0]["generated_text"]
     answer = text.split(HDR["ast"])[-1].strip()
-
     chat_history.append((query, answer))
     return answer, context_nodes
 
-
 # --- CELL 8: Gradio Interface ---
 with gr.Blocks(theme=gr.themes.Soft(), title="💬 Level 0 IT Support Chatbot") as demo:
     gr.Markdown("### 🤖 Level 0 IT Support Chatbot (RAG + Qdrant + LLaMA3)")
-
     with gr.Row():
         with gr.Column(scale=3):
             chatbot = gr.Chatbot(label="Chat", height=500, bubble_full_width=False)
@@ -310,35 +259,25 @@ with gr.Blocks(theme=gr.themes.Soft(), title="💬 Level 0 IT Support Chatbot")
             top_p_slider = gr.Slider(0.0, 1.0, value=0.9, step=0.01, label="Top-p")
     with gr.Accordion("Show Retrieved Context", open=False):
         context_display = gr.Textbox(label="Retrieved Context", interactive=False, lines=10)
-
     def respond(message, history, k, temp, top_p):
         global chat_history
-        # Update retriever k value
         dense_retriever.similarity_top_k = k
         bm25_retriever.similarity_top_k = k
-
-        # Get response and context
         reply, context_nodes = chat(message, temperature=temp, top_p=top_p)
-
-        # Format context for display
-        ctx_text = "\n\n---\n\n".join([f"**Source {i+1} (Score: {node.score:.4f})**\n{node.text}" for i, node in enumerate(context_nodes)])
-
+        ctx_text = "\n\n---\n\n".join([
+            f"**Source {i+1} (Score: {node.score:.4f})**\n{node.text}"
+            for i, node in enumerate(context_nodes)
+        ])
         history.append([message, reply])
         return "", history, ctx_text
-
     def clear_chat():
         global chat_history
         chat_history = []
         return [], None
-
-    # Event Listeners
     inp.submit(respond, [inp, chatbot, k_slider, temp_slider, top_p_slider], [inp, chatbot, context_display])
     send_btn.click(respond, [inp, chatbot, k_slider, temp_slider, top_p_slider], [inp, chatbot, context_display])
    clear_btn.click(clear_chat, None, [chatbot, context_display], queue=False)
 
-# --- Main execution block ---
 if __name__ == "__main__":
-    # The launch() command will start a web server that serves the interface.
-    # It will block the script from exiting.
     logger.info("Launching Gradio interface...")
     demo.launch(server_name="0.0.0.0", server_port=7860)
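
Note on the `chat()`/`respond()` contract: `chat()` returns a bare string on the greeting and clarify paths but a `(answer, context_nodes)` tuple otherwise, so the `reply, context_nodes = chat(...)` unpacking in `respond()` raises on greetings and short inputs. A hedged sketch of a wrapper that normalizes the return shape (`chat_safe` is a hypothetical helper, not part of the commit):

def chat_safe(query, temperature=0.7, top_p=0.9):
    """Give chat() a uniform (reply, context_nodes) shape for respond() to unpack."""
    result = chat(query, temperature=temperature, top_p=top_p)
    if isinstance(result, tuple):
        return result
    return result, []  # greeting/clarify paths return a bare string and no context

`respond()` would then call `chat_safe` instead of `chat`; with an empty context list, the `ctx_text` join simply yields an empty string.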
 