Update app.py

app.py CHANGED
@@ -15,10 +15,11 @@ try:
     from sentence_transformers import SentenceTransformer
     import numpy as np
     from sklearn.metrics.pairwise import cosine_similarity
+    import google.generativeai as genai
     RAG_DEPENDENCIES_AVAILABLE = True
 except ImportError as e:
     print(f"RAG dependencies not available: {e}")
-    print("Please install: pip install sentence-transformers scikit-learn")
+    print("Please install: pip install sentence-transformers scikit-learn google-generativeai")
     RAG_DEPENDENCIES_AVAILABLE = False
     SentenceTransformer = None
 import os
@@ -320,21 +321,32 @@ hf_token = os.getenv('HF_TOKEN')
 # Don't load models initially - load them on demand
 model_status = "✅ Models ready (Dynamic loading)"

-# Initialize embedding model
+# Initialize embedding model and Gemini API
 if RAG_DEPENDENCIES_AVAILABLE:
     try:
         print("Loading embedding model for RAG...")
-        # Use CPU for embedding model to save GPU memory for main models
         embedding_model = SentenceTransformer('all-MiniLM-L6-v2', device='cpu')
         print("✅ Embedding model loaded successfully (CPU)")
+
+        # Initialize Gemini API
+        gemini_api_key = os.getenv('GEMINI_API_KEY')
+        if gemini_api_key:
+            genai.configure(api_key=gemini_api_key)
+            gemini_model = genai.GenerativeModel('gemma-3n-e4b-it')
+            print("✅ Gemini API configured successfully")
+        else:
+            print("❌ GEMINI_API_KEY not found in environment")
+            gemini_model = None
     except Exception as e:
-        print(f"❌ Error loading
+        print(f"❌ Error loading models: {e}")
         import traceback
         traceback.print_exc()
         embedding_model = None
+        gemini_model = None
 else:
     print("❌ RAG dependencies not available")
     embedding_model = None
+    gemini_model = None

 # Model management functions
 def load_dolphin_model():
@@ -371,59 +383,29 @@ def unload_dolphin_model():
         torch.cuda.empty_cache()
         print("✅ DOLPHIN model unloaded")

-def
-    """
-    global
-
-    if current_model == "chatbot":
-        return chatbot_model, chatbot_processor
+def initialize_gemini_model():
+    """Initialize Gemini API model"""
+    global gemini_model

-
-
+    if gemini_model is not None:
+        return gemini_model

     try:
-
-
-
-
-
-
-
-
-
-
-
-            chatbot_processor = AutoProcessor.from_pretrained(
-                "google/gemma-3n-e4b-it",
-                token=hf_token
-            )
-
-            current_model = "chatbot"
-            print("✅ Gemma chatbot model loaded")
-            return chatbot_model, chatbot_processor
-        else:
-            print("❌ No HF_TOKEN found")
-            return None, None
+        gemini_api_key = os.getenv('GEMINI_API_KEY')
+        if not gemini_api_key:
+            print("❌ GEMINI_API_KEY not found in environment")
+            return None
+
+        print("Initializing Gemini API...")
+        genai.configure(api_key=gemini_api_key)
+        gemini_model = genai.GenerativeModel('gemma-3n-e4b-it')
+        print("✅ Gemini API model ready")
+        return gemini_model
     except Exception as e:
-        print(f"❌ Error
+        print(f"❌ Error initializing Gemini model: {e}")
         import traceback
         traceback.print_exc()
-        return None
-
-def unload_chatbot_model():
-    """Unload chatbot model to free memory"""
-    global chatbot_model, chatbot_processor, current_model
-
-    if chatbot_model is not None:
-        print("Unloading Gemma chatbot model...")
-        del chatbot_model, chatbot_processor
-        chatbot_model = None
-        chatbot_processor = None
-        if current_model == "chatbot":
-            current_model = None
-        if torch.cuda.is_available():
-            torch.cuda.empty_cache()
-        print("✅ Gemma chatbot model unloaded")
+        return None


 # Global state for managing tabs
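
As a quick sanity check for the new API path, the flow in initialize_gemini_model() can be exercised outside the app. This is a minimal sketch, assuming GEMINI_API_KEY is exported and the google-generativeai package is installed; it is not part of the commit.

# Minimal sketch of the google-generativeai flow used above (not in this commit).
import os
import google.generativeai as genai

genai.configure(api_key=os.environ["GEMINI_API_KEY"])  # assumes the key is set in the environment
model = genai.GenerativeModel('gemma-3n-e4b-it')       # same model id as in the diff
response = model.generate_content("Reply with the single word: ready")
print(response.text)                                   # plain-text answer from the API
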
@@ -431,12 +413,10 @@ processed_markdown = ""
 show_results_tab = False
 document_chunks = []
 document_embeddings = None
-embedding_model = None

-# Global model state
+# Global model state
 dolphin_model = None
-
-chatbot_processor = None
+gemini_model = None
 current_model = None # Track which model is currently loaded


@@ -518,9 +498,8 @@ def process_uploaded_pdf(pdf_file, progress=gr.Progress()):
         document_embeddings = create_embeddings(document_chunks)
         print(f"Created {len(document_chunks)} chunks")

-        #
+        # Keep DOLPHIN model loaded for GPU usage
         progress(0.95, desc="Preparing chatbot...")
-        unload_dolphin_model()

         show_results_tab = True
         progress(1.0, desc="PDF processed successfully!")
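
The hunk above calls create_embeddings(), and the chat path further down calls retrieve_relevant_chunks(); neither helper is shown in this commit. A minimal sketch of how they could be implemented with the imports the diff already uses (SentenceTransformer, scikit-learn cosine_similarity) follows; any name or parameter not referenced in the diff (such as top_k) is an assumption.

# Hypothetical implementations of the RAG helpers referenced in the diff (not in this commit).
import numpy as np
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity

embedding_model = SentenceTransformer('all-MiniLM-L6-v2', device='cpu')  # as in the diff

def create_embeddings(chunks):
    # Encode each markdown chunk into a dense vector (CPU is enough for MiniLM).
    return embedding_model.encode(chunks, convert_to_numpy=True)

def retrieve_relevant_chunks(query, chunks, embeddings, top_k=3):
    # Rank chunks by cosine similarity between the query embedding and the chunk embeddings.
    query_vec = embedding_model.encode([query], convert_to_numpy=True)
    scores = cosine_similarity(query_vec, embeddings)[0]
    top_idx = np.argsort(scores)[::-1][:top_k]
    return [chunks[i] for i in top_idx]
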
@@ -549,11 +528,10 @@ def clear_all():
     document_chunks = []
     document_embeddings = None

-    # Unload
+    # Unload DOLPHIN model
     unload_dolphin_model()
-    unload_chatbot_model()

-    return None, "
+    return None, "", gr.Tabs(visible=False)


 # Create Gradio interface
@@ -608,12 +586,14 @@ with gr.Blocks(
         # Home Tab
         with gr.TabItem("π Home", id="home"):
             embedding_status = "✅ RAG ready" if embedding_model else "❌ RAG not loaded"
+            gemini_status = "✅ Gemini API ready" if gemini_model else "❌ Gemini API not configured"
             current_status = f"Currently loaded: {current_model or 'None'}"
             gr.Markdown(
                 "# Scholar Express\n"
-                "### Upload a research paper to get a web-friendly version
+                "### Upload a research paper to get a web-friendly version and an AI chatbot powered by Gemini API. DOLPHIN model runs on GPU for optimal performance.\n"
                 f"**System:** {model_status}\n"
                 f"**RAG System:** {embedding_status}\n"
+                f"**Gemini API:** {gemini_status}\n"
                 f"**Status:** {current_status}"
             )

@@ -648,7 +628,7 @@ with gr.Blocks(

             # Status output (hidden during processing)
             status_output = gr.Markdown(
-                "
+                "",
                 elem_classes="status-message"
             )

@@ -685,7 +665,7 @@ with gr.Blocks(
             send_btn = gr.Button("Send", variant="primary", scale=1)

             gr.Markdown(
-                "*Ask questions about your processed document. The AI uses RAG (Retrieval-Augmented Generation) to find relevant sections and provide accurate answers.*",
+                "*Ask questions about your processed document. The AI uses RAG (Retrieval-Augmented Generation) with Gemini API to find relevant sections and provide accurate answers.*",
                 elem_id="chat-notice"
             )

@@ -714,7 +694,7 @@ with gr.Blocks(
             outputs=[chat_tab]
         )

-        # Chatbot functionality
+        # Chatbot functionality with Gemini API
         def chatbot_response(message, history):
             if not message.strip():
                 return history
@@ -723,61 +703,42 @@
                 return history + [[message, "❌ Please process a PDF document first before asking questions."]]

             try:
-                #
-                model
+                # Initialize Gemini model
+                model = initialize_gemini_model()

-                if model is None
-                    return history + [[message, "❌ Failed to
+                if model is None:
+                    return history + [[message, "❌ Failed to initialize Gemini model. Please check your GEMINI_API_KEY."]]

-                # Use RAG to get relevant chunks
+                # Use RAG to get relevant chunks from markdown
                 if document_chunks and len(document_chunks) > 0:
                     relevant_chunks = retrieve_relevant_chunks(message, document_chunks, document_embeddings)
                     context = "\n\n".join(relevant_chunks)
                 else:
                     # Fallback to truncated document if RAG fails
-                    context = processed_markdown[:
-
-                # Create chat messages with shorter context
-                messages = [
-                    {
-                        "role": "system",
-                        "content": [{"type": "text", "text": "You are a helpful assistant. Answer questions about the document concisely."}]
-                    },
-                    {
-                        "role": "user",
-                        "content": [{"type": "text", "text": f"Context:\n{context}\n\nQ: {message}"}]
-                    }
-                ]
+                    context = processed_markdown[:2000] + "..." if len(processed_markdown) > 2000 else processed_markdown

-                #
-
-
-
-
-
-
-
-
-                input_len = inputs["input_ids"].shape[-1]
+                # Create prompt for Gemini
+                prompt = f"""You are a helpful assistant that answers questions about documents. Use the provided context to answer questions accurately and concisely.
+
+Context from the document:
+{context}
+
+Question: {message}
+
+Please provide a clear and helpful answer based on the context provided."""

-
-
-                    **inputs,
-                    max_new_tokens=300, # Can be higher now with single model
-                    do_sample=False,
-                    temperature=0.7,
-                    pad_token_id=processor.tokenizer.pad_token_id,
-                    use_cache=True, # Can enable cache with single model
-                    num_beams=1
-                )
-                generation = generation[0][input_len:]
+                # Generate response using Gemini API
+                response = model.generate_content(prompt)

-
+                response_text = response.text if hasattr(response, 'text') else str(response)

-                return history + [[message,
+                return history + [[message, response_text]]

             except Exception as e:
                 error_msg = f"❌ Error generating response: {str(e)}"
+                print(f"Full error: {e}")
+                import traceback
+                traceback.print_exc()
                 return history + [[message, error_msg]]

         send_btn.click(
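
The chat path relies on document_chunks built earlier from the processed markdown; the chunking step itself is outside this diff. A rough sketch of one way the markdown could be split into overlapping chunks for the retrieval step is shown below; the helper name and the sizes are hypothetical.

# Hypothetical chunking helper (not part of this commit).
def chunk_markdown(markdown_text, chunk_size=1000, overlap=200):
    # Split the document into fixed-size windows with some overlap so that
    # sentences near chunk boundaries still land in at least one chunk.
    chunks = []
    start = 0
    while start < len(markdown_text):
        chunks.append(markdown_text[start:start + chunk_size])
        start += chunk_size - overlap
    return chunks

document_chunks = chunk_markdown(processed_markdown)  # processed_markdown comes from the DOLPHIN step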