Spaces: Runtime error
Update app.py
app.py CHANGED
@@ -11,9 +11,16 @@ import numpy as np
 from PIL import Image
 from transformers import AutoProcessor, VisionEncoderDecoderModel, Gemma3nForConditionalGeneration, pipeline
 import torch
-
-
-
+try:
+    from sentence_transformers import SentenceTransformer
+    import numpy as np
+    from sklearn.metrics.pairwise import cosine_similarity
+    RAG_DEPENDENCIES_AVAILABLE = True
+except ImportError as e:
+    print(f"RAG dependencies not available: {e}")
+    print("Please install: pip install sentence-transformers scikit-learn")
+    RAG_DEPENDENCIES_AVAILABLE = False
+    SentenceTransformer = None
 import os
 import tempfile
 import uuid
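The guarded import block above is what the rest of this commit keys off: RAG_DEPENDENCIES_AVAILABLE is checked before the embedding model is loaded (next hunk), and binding SentenceTransformer = None keeps later references from raising NameError when the packages are absent.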
@@ -316,13 +323,20 @@ except Exception as e:
     model_status = f"❌ Model failed to load: {str(e)}"
 
 # Initialize embedding model for RAG
-
-
-
-
-
-
-
+if RAG_DEPENDENCIES_AVAILABLE:
+    try:
+        print("Loading embedding model for RAG...")
+        # Use GPU for embedding model with 24GB VRAM
+        device = "cuda" if torch.cuda.is_available() else "cpu"
+        embedding_model = SentenceTransformer('all-MiniLM-L6-v2', device=device)
+        print(f"✅ Embedding model loaded successfully ({device})")
+    except Exception as e:
+        print(f"❌ Error loading embedding model: {e}")
+        import traceback
+        traceback.print_exc()
+        embedding_model = None
+else:
+    print("❌ RAG dependencies not available")
     embedding_model = None
 
 # Initialize chatbot model
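As a quick sanity check outside the commit itself, the loaded model can be exercised like this (hypothetical snippet, assuming the initialization above succeeded):

# Hypothetical smoke test, not part of the commit.
vec = embedding_model.encode(["hello world"], show_progress_bar=False)
print(vec.shape)  # (1, 384): all-MiniLM-L6-v2 produces 384-dimensional vectors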
@@ -369,7 +383,7 @@ embedding_model = None
 # chatbot_model is initialized above
 
 
-def chunk_document(text, chunk_size=
+def chunk_document(text, chunk_size=500, overlap=50):
     """Split document into overlapping chunks for RAG"""
     words = text.split()
     chunks = []
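The hunk cuts off after the first lines of chunk_document, so the overlap logic itself is not visible. A minimal sketch consistent with the new signature and docstring (the real body may differ) would be:

def chunk_document(text, chunk_size=500, overlap=50):
    """Split document into overlapping chunks for RAG."""
    words = text.split()
    chunks = []
    # Advance by chunk_size - overlap words so consecutive chunks
    # share `overlap` words of context at their boundary.
    step = max(chunk_size - overlap, 1)
    for start in range(0, len(words), step):
        chunk = " ".join(words[start:start + chunk_size])
        if chunk:
            chunks.append(chunk)
    return chunks

The overlap means a sentence that straddles a chunk boundary still appears intact in at least one chunk.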
@@ -387,8 +401,8 @@ def create_embeddings(chunks):
         return None
 
     try:
-        # Process in
-        batch_size =
+        # Process in larger batches with 24GB GPU
+        batch_size = 64
         embeddings = []
 
         for i in range(0, len(chunks), batch_size):
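The body of the batching loop is likewise truncated in the hunk. Assuming the standard SentenceTransformer batch API, a sketch of the whole function might read (everything outside the visible lines is a guess):

import numpy as np

def create_embeddings(chunks):
    if embedding_model is None:
        return None
    try:
        # Process in larger batches with 24GB GPU
        batch_size = 64
        embeddings = []
        for i in range(0, len(chunks), batch_size):
            batch = chunks[i:i + batch_size]
            # encode() returns an array of shape (len(batch), dim)
            embeddings.append(embedding_model.encode(batch, show_progress_bar=False))
        return np.vstack(embeddings)
    except Exception as e:
        print(f"Error creating embeddings: {e}")
        return None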
@@ -401,10 +415,10 @@ def create_embeddings(chunks):
         print(f"Error creating embeddings: {e}")
         return None
 
-def retrieve_relevant_chunks(question, chunks, embeddings, top_k=
+def retrieve_relevant_chunks(question, chunks, embeddings, top_k=3):
     """Retrieve most relevant chunks for a question"""
     if embedding_model is None or embeddings is None:
-        return chunks[:
+        return chunks[:3]  # Fallback to first 3 chunks
 
     try:
         question_embedding = embedding_model.encode([question], show_progress_bar=False)
@@ -417,7 +431,7 @@ def retrieve_relevant_chunks(question, chunks, embeddings, top_k=2):
         return relevant_chunks
     except Exception as e:
         print(f"Error retrieving chunks: {e}")
-        return chunks[:
+        return chunks[:3]  # Fallback
 
 def process_uploaded_pdf(pdf_file, progress=gr.Progress()):
     """Main processing function for uploaded PDF"""
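The ranking step in the middle of retrieve_relevant_chunks is unchanged by this commit and therefore not shown. Given the cosine_similarity import added at the top of the file, it plausibly boils down to a helper like this (rank_chunks is a hypothetical name; in app.py the logic is inline):

import numpy as np
from sklearn.metrics.pairwise import cosine_similarity

def rank_chunks(question_embedding, embeddings, chunks, top_k=3):
    # Cosine similarity of the question against every chunk embedding;
    # question_embedding is (1, dim), embeddings is (n_chunks, dim).
    similarities = cosine_similarity(question_embedding, embeddings)[0]
    # Indices of the top_k highest-scoring chunks, best match first.
    top_indices = np.argsort(similarities)[::-1][:top_k]
    return [chunks[i] for i in top_indices]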
@@ -467,10 +481,6 @@ def clear_all():
     document_chunks = []
     document_embeddings = None
 
-    # Clear GPU memory
-    if torch.cuda.is_available():
-        torch.cuda.empty_cache()
-
     return None, "", gr.Tabs(visible=False)
 
 
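Dropping the per-request torch.cuda.empty_cache() calls here (and around generation below) is consistent with the commit's larger-VRAM assumption: empty_cache only hands PyTorch's cached blocks back to the driver, so calling it on every request forces slower reallocation without freeing anything the allocator could not already reuse.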
@@ -676,23 +686,15 @@ with gr.Blocks(
         input_len = inputs["input_ids"].shape[-1]
 
         with torch.inference_mode():
-            # Clear cache before generation
-            if torch.cuda.is_available():
-                torch.cuda.empty_cache()
-
             generation = chatbot_model.generate(
                 **inputs,
-                max_new_tokens=
+                max_new_tokens=400,  # Increased for 24GB GPU
                 do_sample=False,
                 temperature=0.7,
                 pad_token_id=chatbot_processor.tokenizer.pad_token_id,
-                use_cache=
+                use_cache=True  # Enable KV cache with more VRAM
             )
             generation = generation[0][input_len:]
-
-            # Clear cache after generation
-            if torch.cuda.is_available():
-                torch.cuda.empty_cache()
 
         response = chatbot_processor.decode(generation, skip_special_tokens=True)
 
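One caveat in the generation call as committed: with do_sample=False decoding is greedy, so temperature=0.7 has no effect (recent transformers versions warn about exactly this combination), and the line could be dropped. Enabling use_cache=True, on the other hand, is a genuine speedup, since the KV cache avoids recomputing attention over the prompt for every generated token.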