Update app.py
app.py CHANGED
@@ -152,8 +152,111 @@ app = gr.mount_gradio_app(app, demo, path="/")
 def home():
     return RedirectResponse(url="/")
 """
+import gradio as gr
+import numpy as np
+import fitz  # PyMuPDF
 import torch
-
-
-
-
+from fastapi import FastAPI
+from transformers import pipeline
+from PIL import Image
+from starlette.responses import RedirectResponse
+from openpyxl import load_workbook
+from docx import Document
+from pptx import Presentation
+
+# ✅ Initialize FastAPI
+app = FastAPI()
+
+# ✅ Check if CUDA is Available (For Debugging)
+device = "cpu"
+print(f"✅ Running on: {device}")
+
+# ✅ Lazy Load Model Function (Loads Only When Needed)
+def get_qa_pipeline():
+    print("🔄 Loading QA Model on CPU...")
+    return pipeline("text-generation", model="TinyLlama/TinyLlama-1.1B-Chat-v1.0", device=-1)
+
+def get_image_captioning_pipeline():
+    print("🔄 Loading Image Captioning Model on CPU...")
+    return pipeline("image-to-text", model="nlpconnect/vit-gpt2-image-captioning", device=-1)
+
+# ✅ File Type Validation
+ALLOWED_EXTENSIONS = {"pdf", "docx", "pptx", "xlsx"}
+
+def validate_file_type(file):
+    print(f"🔍 Validating file: {file.name}")
+    ext = file.name.split(".")[-1].lower()
+    return None if ext in ALLOWED_EXTENSIONS else f"❌ Unsupported file format: {ext}"
+
+# ✅ Extract Text Functions (Optimized)
+def extract_text_from_pdf(file):
+    print("📄 Extracting text from PDF...")
+    with fitz.open(file.name) as doc:
+        return " ".join(page.get_text() for page in doc)
+
+def extract_text_from_docx(file):
+    print("📄 Extracting text from DOCX...")
+    doc = Document(file.name)
+    return " ".join(p.text for p in doc.paragraphs)
+
+def extract_text_from_pptx(file):
+    print("📄 Extracting text from PPTX...")
+    ppt = Presentation(file.name)
+    return " ".join(shape.text for slide in ppt.slides for shape in slide.shapes if hasattr(shape, "text"))
+
+def extract_text_from_excel(file):
+    print("📄 Extracting text from Excel...")
+    wb = load_workbook(file.name, data_only=True)
+    return " ".join(" ".join(str(cell) for cell in row if cell) for sheet in wb.worksheets for row in sheet.iter_rows(values_only=True))
+
+# ✅ Question Answering Function (Efficient Processing)
+async def answer_question(file, question: str):
+    print("📂 Processing file for QA...")
+
+    validation_error = validate_file_type(file)
+    if validation_error:
+        return validation_error
+
+    file_ext = file.name.split(".")[-1].lower()
+    text = ""
+
+    if file_ext == "pdf":
+        text = extract_text_from_pdf(file)
+    elif file_ext == "docx":
+        text = extract_text_from_docx(file)
+    elif file_ext == "pptx":
+        text = extract_text_from_pptx(file)
+    elif file_ext == "xlsx":
+        text = extract_text_from_excel(file)
+
+    if not text.strip():
+        return "⚠️ No text extracted from the document."
+
+    print("✂️ Truncating text for faster processing...")
+    truncated_text = text[:1024]  # Reduce to 1024 characters for better speed
+
+    qa_pipeline = get_qa_pipeline()
+    response = qa_pipeline(f"Question: {question}\nContext: {truncated_text}")
+
+    return response[0]["generated_text"]
+
+# ✅ Gradio UI
+with gr.Blocks() as demo:
+    gr.Markdown("## 📄 AI-Powered Document & Image QA")
+
+    with gr.Row():
+        file_input = gr.File(label="Upload Document")
+        question_input = gr.Textbox(label="Ask a Question", placeholder="What is this document about?")
+
+    answer_output = gr.Textbox(label="Answer")
+    submit_btn = gr.Button("Get Answer")
+
+    submit_btn.click(answer_question, inputs=[file_input, question_input], outputs=answer_output)
+
+# ✅ Mount Gradio with FastAPI
+app = gr.mount_gradio_app(app, demo, path="/demo")
+
+@app.get("/")
+def home():
+    return RedirectResponse(url="/demo")
+
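
To try this revision outside the Space, a minimal launch sketch follows. It is not part of the commit and assumes the new file is saved as app.py with the diff's dependencies installed (fastapi, uvicorn, gradio, transformers, torch, pymupdf, python-docx, python-pptx, openpyxl, pillow, numpy):

import uvicorn

# Local-run sketch (an assumption, not in the diff): serves the FastAPI app
# that the commit builds. The Gradio UI is mounted at /demo, and "/" now
# redirects there instead of serving the UI directly.
if __name__ == "__main__":
    # 7860 is the conventional Hugging Face Spaces port; any free port works.
    uvicorn.run("app:app", host="0.0.0.0", port=7860)

With the server up, visiting http://localhost:7860/ should redirect to the Gradio interface at /demo, which is the behavior this commit changes: before it, the app was mounted at "/" (as the hunk header shows), whereas it now lives at "/demo" behind a redirect.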