brahmanarisetty committed
Commit f8e3778 · verified · 1 Parent(s): 3e5b384

Upload 2 files

Files changed (2):
  1. app.py +339 -0
  2. requirements.txt +18 -0
app.py ADDED
@@ -0,0 +1,339 @@
# -*- coding: utf-8 -*-
"""
IT Support Chatbot Application
- Converts the original Colab notebook into a deployable Gradio app.
- Loads data from a local CSV file.
- Uses environment variables for API keys.
- Implements a RAG pipeline with LLaMA 3.1, Qdrant, and hybrid retrieval.
"""

# --- CELL 1: Imports, Logging & Reproducibility ---
import os
import random
import logging
import numpy as np
import torch
import nest_asyncio
import pandas as pd
import gradio as gr
from typing import List

# Llama-Index & Transformers
from llama_index.core import (
    SimpleDirectoryReader, VectorStoreIndex, StorageContext,
    PromptTemplate, Settings, QueryBundle, Document
)
from llama_index.core.postprocessor import SentenceTransformerRerank
from llama_index.core.query_engine import RetrieverQueryEngine
from llama_index.core.retrievers import BaseRetriever
from llama_index.core.schema import NodeWithScore
from llama_index.retrievers.bm25 import BM25Retriever
from llama_index.vector_stores.qdrant import QdrantVectorStore
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
from llama_index.core.node_parser import SentenceSplitter
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline, BitsAndBytesConfig
from huggingface_hub import login
import qdrant_client

# Configure logging
logging.basicConfig(
    format='%(asctime)s %(levelname)s: %(message)s',
    level=logging.INFO
)
logger = logging.getLogger(__name__)

# Apply nest_asyncio for environments like notebooks
nest_asyncio.apply()

# Reproducibility
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

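# Seeding note: torch.manual_seed() seeds the CPU and CUDA RNGs on current
# PyTorch. Generation below still uses do_sample=True, so replies are sampled;
# the fixed seed only makes the sampling sequence repeatable across runs.
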
# --- CELL 2: Environment & Qdrant Connection Setup ---
# Read credentials from environment variables (never hard-code them).
QDRANT_HOST = os.environ.get("QDRANT_HOST")
QDRANT_API_KEY = os.environ.get("QDRANT_API_KEY")
HF_TOKEN = os.environ.get("HF_TOKEN")

if not all([QDRANT_HOST, QDRANT_API_KEY, HF_TOKEN]):
    raise EnvironmentError(
        "Please set QDRANT_HOST, QDRANT_API_KEY, and HF_TOKEN environment variables."
    )

# Login to Hugging Face
login(token=HF_TOKEN)

# Initialize Qdrant client
qdrant = qdrant_client.QdrantClient(
    url=QDRANT_HOST,
    api_key=QDRANT_API_KEY,
    prefer_grpc=False
)
COLLECTION_NAME = "it_support_rag"

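# Optional connectivity check (illustrative sketch; the DEBUG_QDRANT_PING flag
# is not part of the original app): get_collections() does a cheap round-trip
# and fails fast if the URL or API key is wrong.
if os.environ.get("DEBUG_QDRANT_PING"):
    logger.info(f"Qdrant reachable; collections: {qdrant.get_collections()}")
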
# --- CELL 3: Load Dataset & Build Documents ---
# Load data from a local CSV file.
# Make sure this CSV file is in the same directory as app.py when deploying.
CSV_PATH = "data.csv"  # Or whatever you name your CSV file
if not os.path.exists(CSV_PATH):
    raise FileNotFoundError(
        f"The data file was not found at {CSV_PATH}. "
        "Please upload your data CSV and name it correctly."
    )

df = pd.read_csv(CSV_PATH, encoding="ISO-8859-1")

case_docs: List[Document] = []
for _, row in df.iterrows():
    text = str(row.get("text_chunk", ""))
    meta = {
        "source_dataset": str(row.get("source_dataset", ""))[:50],
        "category": str(row.get("category", ""))[:100],
        "orig_query": str(row.get("original_query", ""))[:200],
        "orig_solution": str(row.get("original_solution", ""))[:200]
    }
    case_docs.append(Document(text=text, metadata=meta))
logger.info(f"Loaded {len(case_docs)} documents from {CSV_PATH}.")

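# Expected CSV schema, inferred from the columns read above (the sample row is
# an illustrative placeholder, not real data):
#   text_chunk,source_dataset,category,original_query,original_solution
#   "Open Services and restart the Print Spooler...",helpdesk,Printing,"Printer offline","Restart spooler"
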
# --- CELL 4: Create Vector Index ---
# Embedding model
device = "cuda" if torch.cuda.is_available() else "cpu"
logger.info(f"Using device: {device}")
embed_model = HuggingFaceEmbedding(
    model_name="BAAI/bge-large-en-v1.5",
    device=device
)

# Node parser for chunking
node_parser = SentenceSplitter(
    chunk_size=1024,
    chunk_overlap=100,
    paragraph_separator="\n\n"
)

# Qdrant-backed vector store
vector_store = QdrantVectorStore(
    client=qdrant,
    collection_name=COLLECTION_NAME,
    prefer_grpc=False
)

# Build the index (uploads to Qdrant if the collection doesn't exist).
# Note: this step can be slow the first time it's run.
logger.info("Initializing VectorStoreIndex...")
index = VectorStoreIndex.from_documents(
    documents=case_docs,
    storage_context=StorageContext.from_defaults(vector_store=vector_store),
    embed_model=embed_model,
    transformations=[node_parser],  # chunk documents with the SentenceSplitter above
    show_progress=True
)
logger.info("VectorStoreIndex initialized successfully.")

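# Note (sketch, not part of the original flow): from_documents() re-embeds and
# re-uploads every chunk on each start-up. If the Qdrant collection is already
# populated, attaching to it directly is much faster:
#   index = VectorStoreIndex.from_vector_store(vector_store, embed_model=embed_model)
# A collection-exists check would be needed before choosing that path.
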
# --- CELL 5: Define Hybrid Retriever & Reranker ---
Settings.llm = None  # We will use our own LLM pipeline

class HybridRetriever(BaseRetriever):
    """Union of dense (vector) and sparse (BM25) retrieval, deduplicated by node id."""

    def __init__(self, dense, bm25):
        super().__init__()
        self.dense = dense
        self.bm25 = bm25

    def _retrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]:
        dense_hits = self.dense.retrieve(query_bundle)
        bm25_hits = self.bm25.retrieve(query_bundle)

        # Merge both hit lists, keeping the first occurrence of each node
        combined = dense_hits + bm25_hits
        unique = []
        seen = set()
        for hit in combined:
            nid = hit.node.node_id
            if nid not in seen:
                seen.add(nid)
                unique.append(hit)
        return unique

# Instantiate retrievers
dense_retriever = index.as_retriever(similarity_top_k=10)
bm25_nodes = node_parser.get_nodes_from_documents(case_docs)
bm25_retriever = BM25Retriever.from_defaults(
    nodes=bm25_nodes,
    similarity_top_k=10,
)
hybrid_retriever = HybridRetriever(dense=dense_retriever, bm25=bm25_retriever)

reranker = SentenceTransformerRerank(
    model="cross-encoder/ms-marco-MiniLM-L-2-v2",
    top_n=4,
    device=device
)

# Build the query engine around the hybrid retriever. RetrieverQueryEngine is
# used here because index.as_query_engine() builds its own retriever rather
# than using the hybrid one passed in.
query_engine = RetrieverQueryEngine.from_args(
    retriever=hybrid_retriever,
    node_postprocessors=[reranker],
    llm=None
)

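# With similarity_top_k=10 on each retriever, the deduplicated union holds
# between 10 (the two hit lists fully overlap) and 20 (disjoint) candidates;
# the cross-encoder then scores every (query, chunk) pair and keeps only the
# top_n=4 for the prompt.
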
# --- CELL 6: Load & Quantize LLaMA Model ---
quant_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.bfloat16
)

MODEL_ID = "meta-llama/Llama-3.1-8B-Instruct"
logger.info(f"Loading model: {MODEL_ID}")
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, use_fast=True)
llm = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    quantization_config=quant_config,
    device_map="auto"
)
logger.info("Model loaded successfully.")

# The model is already dispatched by device_map="auto" above, so the pipeline
# simply reuses its placement.
generator = pipeline(
    task="text-generation",
    model=llm,
    tokenizer=tokenizer
)

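# Rough sizing note: with NF4 4-bit quantization the 8B weights occupy roughly
# 4-5 GB of VRAM (plus KV-cache and activation overhead), versus ~16 GB in
# bfloat16 -- which is what makes this app feasible on a single mid-range GPU.
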
# --- CELL 7: Chat Logic and Prompting ---
SYSTEM_PROMPT = (
    "You are a friendly and helpful Level 0 IT Support Assistant. "
    "Use a conversational tone and guide users step-by-step. "
    "If the user's question lacks details or clarity, ask a concise follow-up question "
    "to gather the information you need before providing a solution. "
    "Once clarified, then:\n"
    "1. Diagnose the problem.\n"
    "2. Provide step-by-step solutions with bullet points.\n"
    "3. Offer additional recommendations or safety warnings.\n"
    "4. End with a polite closing."
)

# LLaMA 3 chat-template special tokens
HDR = {
    "sys": "<|start_header_id|>system<|end_header_id|>",
    "usr": "<|start_header_id|>user<|end_header_id|>",
    "ast": "<|start_header_id|>assistant<|end_header_id|>",
    "eot": "<|eot_id|>"
}

chat_history = []
GREETINGS = {"hello", "hi", "hey", "good morning", "good afternoon", "good evening"}

def format_history(history):
    return "".join(
        f"{HDR['usr']}\n{u}{HDR['eot']}{HDR['ast']}\n{a}{HDR['eot']}"
        for u, a in history
    )

def build_prompt(query, context, history):
    if query.lower().strip() in GREETINGS:
        return None, "greeting"

    words = query.strip().split()
    if len(words) < 3:
        return (
            "Could you provide more detail about what you're experiencing? "
            "Any error messages or steps you've tried will help me assist you."
        ), "clarify"

    context_str = "\n---\n".join(node.text for node in context) if context else "No context provided."
    hist_str = format_history(history[-3:])  # keep only the last three turns

    prompt = (
        f"<|begin_of_text|>"
        f"{HDR['sys']}\n{SYSTEM_PROMPT}{HDR['eot']}"
        f"{hist_str}"
        f"{HDR['usr']}\nContext:\n{context_str}\n\nQuestion: {query}{HDR['eot']}"
        f"{HDR['ast']}\n"
    )
    return prompt, "rag"

def chat(query, temperature=0.7, top_p=0.9):
    global chat_history
    prompt, mode = build_prompt(query, [], chat_history)

    # Every mode returns a (reply, context_nodes) pair so callers can unpack
    # the result uniformly.
    if mode == "greeting":
        reply = "Hello there! How can I help with your IT support question today?"
        chat_history.append((query, reply))
        return reply, []

    if mode == "clarify":
        reply = prompt  # build_prompt() returns the follow-up question here
        chat_history.append((query, reply))
        return reply, []

    # Retrieve (hybrid) and rerank, then rebuild the prompt with real context
    response = query_engine.query(query)
    context_nodes = response.source_nodes

    prompt, _ = build_prompt(query, context_nodes, chat_history)

    gen_args = {
        "do_sample": True,
        "max_new_tokens": 350,
        "temperature": temperature,
        "top_p": top_p,
        "eos_token_id": tokenizer.eos_token_id
    }

    output = generator(prompt, **gen_args)
    text = output[0]["generated_text"]
    # The pipeline echoes the prompt; keep only the final assistant turn
    answer = text.split(HDR["ast"])[-1].strip()

    chat_history.append((query, answer))
    return answer, context_nodes

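# Illustrative behaviour of the three chat() modes (example inputs only):
#   chat("hi")          -> canned greeting, empty context
#   chat("vpn broken")  -> clarification request (under 3 words), empty context
#   chat("My VPN drops every few minutes on Windows 11")
#                       -> retrieved + reranked context, generated RAG answer
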
# --- CELL 8: Gradio Interface ---
with gr.Blocks(theme=gr.themes.Soft(), title="💬 Level 0 IT Support Chatbot") as demo:
    gr.Markdown("### 🤖 Level 0 IT Support Chatbot (RAG + Qdrant + LLaMA3)")

    with gr.Row():
        with gr.Column(scale=3):
            chatbot = gr.Chatbot(label="Chat", height=500, bubble_full_width=False)
            inp = gr.Textbox(placeholder="Ask your IT support question...", label="Your Message", lines=2)
            with gr.Row():
                send_btn = gr.Button("Send", variant="primary")
                clear_btn = gr.Button("Clear Chat", variant="secondary")
        with gr.Column(scale=1):
            gr.Markdown("### ⚙️ Settings")
            k_slider = gr.Slider(1, 20, value=10, step=1, label="Context Hits (k)")
            temp_slider = gr.Slider(0.0, 1.0, value=0.7, step=0.01, label="Temperature")
            top_p_slider = gr.Slider(0.0, 1.0, value=0.9, step=0.01, label="Top-p")
            with gr.Accordion("Show Retrieved Context", open=False):
                context_display = gr.Textbox(label="Retrieved Context", interactive=False, lines=10)

    def respond(message, history, k, temp, top_p):
        # Update retriever k values from the slider
        dense_retriever.similarity_top_k = k
        bm25_retriever.similarity_top_k = k

        # Get response and context
        reply, context_nodes = chat(message, temperature=temp, top_p=top_p)

        # Format reranked context for the side panel (empty for non-RAG turns)
        ctx_text = "\n\n---\n\n".join(
            f"**Source {i+1} (Score: {node.score:.4f})**\n{node.text}"
            for i, node in enumerate(context_nodes)
        )

        history.append([message, reply])
        return "", history, ctx_text

    def clear_chat():
        global chat_history
        chat_history = []
        return [], None

    # Event Listeners
    inp.submit(respond, [inp, chatbot, k_slider, temp_slider, top_p_slider], [inp, chatbot, context_display])
    send_btn.click(respond, [inp, chatbot, k_slider, temp_slider, top_p_slider], [inp, chatbot, context_display])
    clear_btn.click(clear_chat, None, [chatbot, context_display], queue=False)

# --- Main execution block ---
if __name__ == "__main__":
    # launch() starts a web server for the interface and blocks until exit.
    logger.info("Launching Gradio interface...")
    demo.launch(server_name="0.0.0.0", server_port=7860)
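
# To run locally (assuming the env vars above are set and data.csv sits next
# to this file): `python app.py`, then open http://localhost:7860.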
requirements.txt ADDED
@@ -0,0 +1,18 @@
llama-index-core
llama-index-vector-stores-qdrant
llama-index-embeddings-huggingface
llama-index-retrievers-bm25
llama-index-llms-huggingface
sentence-transformers
transformers
accelerate
gradio
qdrant-client
bitsandbytes
rouge-score
bert-score
evaluate
nest_asyncio
torch
pandas
numpy