Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -1,4 +1,4 @@
|
|
| 1 |
-
import gradio as gr
|
| 2 |
import uvicorn
|
| 3 |
import numpy as np
|
| 4 |
import fitz # PyMuPDF
|
|
@@ -111,6 +111,137 @@ with gr.Blocks() as demo:
|
|
| 111 |
# ✅ Mount Gradio with FastAPI
|
| 112 |
app = gr.mount_gradio_app(app, demo, path="/")
|
| 113 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 114 |
# NOTE(review): this handler is registered for "/" and redirects to "/" —
# the same path — so a client that reached it would be sent straight back to
# it. Presumably the Gradio app mounted at "/" answers the request first;
# confirm before relying on this route.
@app.get("/")
def home():
    """Redirect the client to "/"."""
    return RedirectResponse(url="/")
|
|
|
|
| 1 |
+
"""import gradio as gr
|
| 2 |
import uvicorn
|
| 3 |
import numpy as np
|
| 4 |
import fitz # PyMuPDF
|
|
|
|
| 111 |
# β
Mount Gradio with FastAPI
|
| 112 |
app = gr.mount_gradio_app(app, demo, path="/")
|
| 113 |
|
| 114 |
+
@app.get("/")
|
| 115 |
+
def home():
|
| 116 |
+
return RedirectResponse(url="/")
|
| 117 |
+
|
| 118 |
+
# β
Run FastAPI + Gradio
|
| 119 |
+
if __name__ == "__main__":
|
| 120 |
+
uvicorn.run(app, host="0.0.0.0", port=7860)
|
| 121 |
+
"""
|
| 122 |
+
import gradio as gr
|
| 123 |
+
import uvicorn
|
| 124 |
+
import numpy as np
|
| 125 |
+
import fitz # PyMuPDF
|
| 126 |
+
import tika
|
| 127 |
+
import torch
|
| 128 |
+
from fastapi import FastAPI
|
| 129 |
+
from transformers import pipeline
|
| 130 |
+
from PIL import Image
|
| 131 |
+
from io import BytesIO
|
| 132 |
+
from starlette.responses import RedirectResponse
|
| 133 |
+
from tika import parser
|
| 134 |
+
from openpyxl import load_workbook
|
| 135 |
+
import os
|
| 136 |
+
|
| 137 |
+
# Initialize Tika for DOCX & PPTX parsing
tika.initVM()

# Initialize FastAPI
app = FastAPI()

# Load models once at import time so every request reuses the same instances.
device = "cuda" if torch.cuda.is_available() else "cpu"
qa_pipeline = pipeline(
    "text-generation",
    model="TinyLlama/TinyLlama-1.1B-Chat-v1.0",
    device=device,
)
# Place the captioning model on the same device as the QA model; previously
# it ignored `device` and was created with the pipeline's default placement.
image_captioning_pipeline = pipeline(
    "image-to-text",
    model="Salesforce/blip-image-captioning-base",
    device=device,
)

# Lower-case file extensions (without the dot) accepted for upload.
ALLOWED_EXTENSIONS = {"pdf", "docx", "pptx", "xlsx"}
|
| 149 |
+
|
| 150 |
+
# ✅ Function to Validate File Type
|
| 151 |
+
def validate_file_type(file):
    """Check that an uploaded file has a supported extension.

    Args:
        file: The value received for the upload. Plain strings are accepted
            without inspection; any other object is expected to expose its
            path through a ``name`` attribute.

    Returns:
        ``None`` when the input is acceptable, otherwise a user-facing error
        message.
    """
    # Strings are treated as text input and deliberately skip the check.
    if isinstance(file, str):
        return None

    name = getattr(file, "name", None)
    if not name:
        return "❌ Invalid file format!"

    # splitext only inspects the final path component, so a dot in a
    # directory name or a file without any extension is not mistaken for one
    # (splitting on "." would report the whole path as the "format").
    ext = os.path.splitext(name)[1].lstrip(".").lower()
    if ext not in ALLOWED_EXTENSIONS:
        return f"❌ Unsupported file format: {ext or '(none)'}"
    return None
|
| 160 |
+
|
| 161 |
+
# ✅ Extract Text from PDF
|
| 162 |
+
def extract_text_from_pdf(pdf_bytes):
    """Return the text of every page of a PDF, joined by newlines.

    Args:
        pdf_bytes: Raw content of the PDF file.

    Returns:
        The concatenated page text.
    """
    # The context manager closes the document even if a page fails to parse,
    # instead of leaving it open until garbage collection.
    with fitz.open(stream=pdf_bytes, filetype="pdf") as doc:
        return "\n".join(page.get_text() for page in doc)
|
| 165 |
+
|
| 166 |
+
# ✅ Extract Text from DOCX & PPTX using Tika
|
| 167 |
+
def extract_text_with_tika(file_bytes):
    """Extract plain text from a document buffer using Tika.

    Args:
        file_bytes: Raw content of the document.

    Returns:
        The extracted text with surrounding whitespace removed, or an empty
        string when nothing could be extracted.
    """
    parsed = parser.from_buffer(file_bytes)
    if not parsed:
        return ""
    # The "content" key may be present with a None value, in which case
    # .get("content", "") would still return None and .strip() would raise.
    return (parsed.get("content") or "").strip()
|
| 170 |
+
|
| 171 |
+
# ✅ Extract Text from Excel
|
| 172 |
+
def extract_text_from_excel(file_path):
    """Return the cell values of every worksheet as text.

    Each row becomes one line with its non-empty cells separated by single
    spaces; rows from all sheets are joined by newlines.

    Args:
        file_path: Path of the workbook to read.

    Returns:
        The workbook content as a single string.
    """
    wb = load_workbook(file_path, data_only=True)
    try:
        lines = []
        for sheet in wb.worksheets:
            for row in sheet.iter_rows(values_only=True):
                # Skip only genuinely empty cells; a plain truthiness test
                # would also drop legitimate values such as 0 or False.
                cells = [str(cell) for cell in row if cell is not None and cell != ""]
                lines.append(" ".join(cells))
    finally:
        wb.close()
    return "\n".join(lines)
|
| 179 |
+
|
| 180 |
+
# ✅ Truncate Long Text for Model
|
| 181 |
+
def truncate_text(text, max_length=2048):
    """Limit *text* to at most *max_length* characters.

    Args:
        text: The text to shorten.
        max_length: Maximum number of characters to keep.

    Returns:
        *text* unchanged when it already fits, otherwise its leading
        *max_length* characters.
    """
    if len(text) <= max_length:
        return text
    return text[:max_length]
|
| 183 |
+
|
| 184 |
+
# ✅ Answer Questions from Image or Document
|
| 185 |
+
def _generate_answer(question, context):
    """Run the QA model on a question/context prompt and return its reply."""
    # return_full_text=False keeps the prompt (which embeds the whole context)
    # out of the returned answer; by default the pipeline echoes it back.
    response = qa_pipeline(
        f"Question: {question}\nContext: {context}",
        return_full_text=False,
    )
    return response[0]["generated_text"]


def answer_question(file, question: str):
    """Answer a question about an uploaded image or document.

    Args:
        file: Either a NumPy array holding an image, or an uploaded-file
            object that exposes the file's path through ``name``.
        question: The question to answer.

    Returns:
        The model's answer, or a user-facing message describing why the
        input could not be processed.
    """
    # Images given as NumPy arrays are captioned first; the caption is then
    # used as the context for the question.
    # NOTE(review): the UI passes a gr.File value here, so confirm whether
    # this branch is reachable from the interface.
    if isinstance(file, np.ndarray):
        image = Image.fromarray(file)
        caption = image_captioning_pipeline(image)[0]["generated_text"]
        return _generate_answer(question, caption)

    validation_error = validate_file_type(file)
    if validation_error:
        return validation_error

    # The upload is read from the path stored in `file.name`. The broad
    # except is intentional: this is the request boundary and any failure
    # (missing attribute, unreadable path) is reported to the user.
    try:
        path = file.name
        with open(path, "rb") as f:
            file_bytes = f.read()
    except Exception as e:
        return f"❌ Error reading file: {str(e)}"

    if not file_bytes:
        return "❌ Could not read file content!"

    # Take the extension from the final path component only.
    file_ext = os.path.splitext(path)[1].lstrip(".").lower()

    if file_ext == "pdf":
        text = extract_text_from_pdf(file_bytes)
    elif file_ext in ("docx", "pptx"):
        text = extract_text_with_tika(file_bytes)
    elif file_ext == "xlsx":
        # The Excel extractor reads from the path rather than the bytes.
        text = extract_text_from_excel(path)
    else:
        return "❌ Unsupported file format!"

    # Guard against None as well as blank text so a failed extraction is
    # reported instead of raising on .strip().
    if not text or not text.strip():
        return "⚠️ No text extracted from the document."

    return _generate_answer(question, truncate_text(text))
|
| 228 |
+
|
| 229 |
+
# ✅ Gradio Interface (Unified for Images & Documents)
|
| 230 |
+
# Build the single-page UI: an upload widget and a question box feed
# answer_question, whose return value is shown in the answer textbox.
with gr.Blocks() as demo:
    # NOTE(review): the character after "##" looks like a mis-encoded emoji;
    # confirm against the original file before changing the heading.
    gr.Markdown("## π AI-Powered Document & Image QA")

    with gr.Row():
        file_input = gr.File(label="Upload Document / Image")
        question_input = gr.Textbox(label="Ask a Question", placeholder="What is this document about?")

    answer_output = gr.Textbox(label="Answer")

    # Clicking the button calls answer_question(file, question).
    submit_btn = gr.Button("Get Answer")
    submit_btn.click(answer_question, inputs=[file_input, question_input], outputs=answer_output)

# ✅ Mount Gradio with FastAPI
app = gr.mount_gradio_app(app, demo, path="/")
|
| 244 |
+
|
| 245 |
# NOTE(review): this handler is registered for "/" and redirects to "/" —
# the same path — so a client that reached it would be sent straight back to
# it. Presumably the Gradio app mounted at "/" answers the request first;
# confirm before relying on this route.
@app.get("/")
def home():
    """Redirect the client to "/"."""
    return RedirectResponse(url="/")
|