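# NOTE: the full FastAPI + Gradio app below is wrapped in a module-level string,
# i.e. commented out; only the CUDA diagnostics at the bottom of this file run.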
| """import gradio as gr | |
| import numpy as np | |
| import fitz # PyMuPDF | |
| import torch | |
| import asyncio | |
| from fastapi import FastAPI | |
| from transformers import pipeline | |
| from PIL import Image | |
| from starlette.responses import RedirectResponse | |
| from openpyxl import load_workbook | |
| from docx import Document | |
| from pptx import Presentation | |

# Initialize FastAPI
app = FastAPI()

# Use GPU if available
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"✅ Using device: {device}")

# Load models lazily so startup stays fast and memory is used only when needed
def get_qa_pipeline():
    print("🔄 Loading QA pipeline model...")
    # float16 is only safe on GPU; fall back to float32 on CPU
    dtype = torch.float16 if device == "cuda" else torch.float32
    return pipeline("text-generation", model="TinyLlama/TinyLlama-1.1B-Chat-v1.0", device=device, torch_dtype=dtype)

def get_image_captioning_pipeline():
    print("🔄 Loading Image Captioning model...")
    return pipeline("image-to-text", model="nlpconnect/vit-gpt2-image-captioning", device=device)

ALLOWED_EXTENSIONS = {"pdf", "docx", "pptx", "xlsx"}
MAX_INPUT_LENGTH = 1024  # Limit context length for faster processing

# ✅ Validate File Type
def validate_file_type(file):
    if hasattr(file, "name"):
        ext = file.name.split(".")[-1].lower()
        print(f"📂 File extension detected: {ext}")
        if ext not in ALLOWED_EXTENSIONS:
            print(f"❌ Unsupported file format: {ext}")
            return f"❌ Unsupported file format: {ext}"
        return None
    print("❌ Invalid file format!")
    return "❌ Invalid file format!"

# ✅ Extract Text from PDF
async def extract_text_from_pdf(file):
    print(f"📄 Extracting text from PDF: {file.name}")
    loop = asyncio.get_event_loop()
    text = await loop.run_in_executor(None, lambda: "\n".join(page.get_text() for page in fitz.open(file.name)))
    print(f"✅ Extracted {len(text)} characters from PDF")
    return text

# ✅ Extract Text from DOCX
async def extract_text_from_docx(file):
    print(f"📄 Extracting text from DOCX: {file.name}")
    loop = asyncio.get_event_loop()
    # Use the temp-file path, consistent with the other extractors
    text = await loop.run_in_executor(None, lambda: "\n".join(p.text for p in Document(file.name).paragraphs))
    print(f"✅ Extracted {len(text)} characters from DOCX")
    return text

# ✅ Extract Text from PPTX
async def extract_text_from_pptx(file):
    print(f"📄 Extracting text from PPTX: {file.name}")
    loop = asyncio.get_event_loop()
    text = await loop.run_in_executor(None, lambda: "\n".join(shape.text for slide in Presentation(file.name).slides for shape in slide.shapes if hasattr(shape, "text")))
    print(f"✅ Extracted {len(text)} characters from PPTX")
    return text

# ✅ Extract Text from Excel
async def extract_text_from_excel(file):
    print(f"📄 Extracting text from Excel: {file.name}")
    loop = asyncio.get_event_loop()
    # "cell is not None" keeps legitimate zero values that a bare truthiness test would drop
    text = await loop.run_in_executor(None, lambda: "\n".join(" ".join(str(cell) for cell in row if cell is not None) for sheet in load_workbook(file.name, data_only=True).worksheets for row in sheet.iter_rows(values_only=True)))
    print(f"✅ Extracted {len(text)} characters from Excel")
    return text

# ✅ Truncate Long Text
def truncate_text(text):
    print(f"✂️ Truncating text to {MAX_INPUT_LENGTH} characters (if needed)...")
    return text[:MAX_INPUT_LENGTH]
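
# Note: truncation here is by characters, not tokens; the model's context window
# is measured in tokens, so MAX_INPUT_LENGTH is only a rough safeguard.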

# ✅ Answer Questions from Image or Document
async def answer_question(file, image, question: str):
    print(f"❓ Question received: {question}")

    # Image path: caption the image first, then answer against the caption
    if image is not None and isinstance(image, np.ndarray):
        print("🖼️ Processing image for captioning...")
        pil_image = Image.fromarray(image)
        image_captioning = get_image_captioning_pipeline()
        caption = image_captioning(pil_image)[0]["generated_text"]
        print(f"📷 Generated caption: {caption}")
        qa = get_qa_pipeline()
        print("🤖 Running QA model...")
        response = qa(f"Question: {question}\nContext: {caption}")
        print(f"✅ Model response: {response[0]['generated_text']}")
        return response[0]["generated_text"]

    validation_error = validate_file_type(file)
    if validation_error:
        return validation_error

    file_ext = file.name.split(".")[-1].lower()

    # Extract text asynchronously
    if file_ext == "pdf":
        text = await extract_text_from_pdf(file)
    elif file_ext == "docx":
        text = await extract_text_from_docx(file)
    elif file_ext == "pptx":
        text = await extract_text_from_pptx(file)
    elif file_ext == "xlsx":
        text = await extract_text_from_excel(file)
    else:
        print("❌ Unsupported file format!")
        return "❌ Unsupported file format!"

    if not text:
        print("⚠️ No text extracted from the document.")
        return "⚠️ No text extracted from the document."

    truncated_text = truncate_text(text)

    # Run the QA model without blocking the event loop
    print("🤖 Running QA model...")
    loop = asyncio.get_event_loop()
    qa = get_qa_pipeline()
    response = await loop.run_in_executor(None, qa, f"Question: {question}\nContext: {truncated_text}")
    print(f"✅ Model response: {response[0]['generated_text']}")
    return response[0]["generated_text"]

# ✅ Gradio Interface (Separate File & Image Inputs)
with gr.Blocks() as demo:
    gr.Markdown("## 📚 AI-Powered Document & Image QA")
    with gr.Row():
        file_input = gr.File(label="Upload Document")
        image_input = gr.Image(label="Upload Image")
    question_input = gr.Textbox(label="Ask a Question", placeholder="What is this document about?")
    answer_output = gr.Textbox(label="Answer")
    submit_btn = gr.Button("Get Answer")
    # Wire both inputs so the image branch in answer_question is actually reachable
    submit_btn.click(answer_question, inputs=[file_input, image_input, question_input], outputs=answer_output)

# ✅ Mount Gradio with FastAPI
# Mount the UI under /gradio: a root mount plus a root route redirecting to "/"
# would either clash or redirect in an endless loop.
app = gr.mount_gradio_app(app, demo, path="/gradio")

@app.get("/")
def home():
    return RedirectResponse(url="/gradio")
"""
import torch

print("CUDA Available:", torch.cuda.is_available())
print("Torch Device Count:", torch.cuda.device_count())
print("Current Device:", torch.cuda.current_device() if torch.cuda.is_available() else "CPU")
print("CUDA Device Name:", torch.cuda.get_device_name(0) if torch.cuda.is_available() else "None")