import gradio as gr
import numpy as np
import fitz  # PyMuPDF
import tika
import torch
from fastapi import FastAPI
from transformers import pipeline
from PIL import Image
from io import BytesIO
from starlette.responses import RedirectResponse
from tika import parser
from openpyxl import load_workbook

# Initialize Tika for DOCX & PPTX parsing (ensure Java is installed)
tika.initVM()

# Initialize FastAPI
app = FastAPI()

# Load models
device = "cuda" if torch.cuda.is_available() else "cpu"
qa_pipeline = pipeline(
    "text-generation",
    model="TinyLlama/TinyLlama-1.1B-Chat-v1.0",
    device=device,
)
image_captioning_pipeline = pipeline(
    "image-to-text",
    model="nlpconnect/vit-gpt2-image-captioning",
    device=device,
)

ALLOWED_EXTENSIONS = {"pdf", "docx", "pptx", "xlsx"}


# ✅ Validate File Type
def validate_file_type(file):
    if hasattr(file, "name"):
        ext = file.name.split(".")[-1].lower()
        if ext not in ALLOWED_EXTENSIONS:
            return f"❌ Unsupported file format: {ext}"
        return None
    return "❌ Invalid file format!"


# ✅ Extract Text from PDF
def extract_text_from_pdf(file):
    with fitz.open(file.name) as doc:
        return "\n".join(page.get_text() for page in doc)


# ✅ Extract Text from DOCX & PPTX using Tika
def extract_text_with_tika(file):
    return parser.from_file(file.name)["content"]


# ✅ Extract Text from Excel
def extract_text_from_excel(file):
    wb = load_workbook(file.name, data_only=True)
    text = []
    for sheet in wb.worksheets:
        for row in sheet.iter_rows(values_only=True):
            text.append(" ".join(str(cell) for cell in row if cell is not None))
    return "\n".join(text)


# ✅ Truncate Long Text for the Model
def truncate_text(text, max_length=2048):
    return text[:max_length] if len(text) > max_length else text


# ✅ Answer Questions from an Image or a Document
def answer_question(file, image, question: str):
    # Image path: caption the image, then answer against the caption
    if isinstance(image, np.ndarray):
        pil_image = Image.fromarray(image)
        caption = image_captioning_pipeline(pil_image)[0]["generated_text"]
        response = qa_pipeline(
            f"Question: {question}\nContext: {caption}", max_new_tokens=256
        )
        return response[0]["generated_text"]

    if file is None:
        return "❌ Please upload a document or an image."

    validation_error = validate_file_type(file)
    if validation_error:
        return validation_error

    file_ext = file.name.split(".")[-1].lower()

    # Extract text from supported documents
    if file_ext == "pdf":
        text = extract_text_from_pdf(file)
    elif file_ext in ["docx", "pptx"]:
        text = extract_text_with_tika(file)
    elif file_ext == "xlsx":
        text = extract_text_from_excel(file)
    else:
        return "❌ Unsupported file format!"

    if not text:
        return "⚠️ No text extracted from the document."

    truncated_text = truncate_text(text)
    response = qa_pipeline(
        f"Question: {question}\nContext: {truncated_text}", max_new_tokens=256
    )
    return response[0]["generated_text"]


# ✅ Gradio Interface (Separate File & Image Inputs)
with gr.Blocks() as demo:
    gr.Markdown("## 📄 AI-Powered Document & Image QA")
    with gr.Row():
        file_input = gr.File(label="Upload Document")
        image_input = gr.Image(label="Upload Image")
    question_input = gr.Textbox(label="Ask a Question", placeholder="What is this document about?")
    answer_output = gr.Textbox(label="Answer")
    submit_btn = gr.Button("Get Answer")
    # Wire both the file and the image inputs so either source can be queried
    submit_btn.click(
        answer_question,
        inputs=[file_input, image_input, question_input],
        outputs=answer_output,
    )

# ✅ Mount Gradio with FastAPI (under /gradio so the FastAPI root stays free)
app = gr.mount_gradio_app(app, demo, path="/gradio")


# Redirect the FastAPI root to the Gradio UI
@app.get("/")
def home():
    return RedirectResponse(url="/gradio")
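

# Optional local launcher — a minimal sketch, assuming this file is saved as app.py
# (on Hugging Face Spaces or behind another ASGI server, this block is not needed;
# host and port are illustrative and can be adjusted).
if __name__ == "__main__":
    import uvicorn

    # Serve the combined FastAPI + Gradio app; the UI is then at http://localhost:7860/gradio
    uvicorn.run(app, host="0.0.0.0", port=7860)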