ikraamkb committed on
Commit
4f113b7
·
verified ·
1 Parent(s): 76f340e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +38 -83
app.py CHANGED
@@ -103,104 +103,59 @@ async def get_docs(request: Request):
103
  from fastapi import FastAPI
104
  from fastapi.responses import RedirectResponse
105
  import gradio as gr
106
- from transformers import pipeline
107
- import pdfplumber
108
- import docx
109
- from pptx import Presentation
110
  from PIL import Image
111
- import pytesseract
112
- import easyocr
113
- import os
114
- from io import BytesIO
115
 
116
- # Initialize FastAPI app
117
  app = FastAPI()
118
 
119
- # Load models
120
- qa_pipeline = pipeline("question-answering", model="bert-large-uncased-whole-word-masking-finetuned-squad")
121
- image_qa_pipeline = pipeline("vqa", model="Salesforce/blip-vqa-base")
122
- reader = easyocr.Reader(['en'])
123
 
124
- # File parsing functions with error handling
125
- def extract_text_from_pdf(file):
126
- try:
127
- with pdfplumber.open(file) as pdf:
128
- return "\n".join(page.extract_text() for page in pdf.pages if page.extract_text())
129
- except Exception as e:
130
- return f"Error reading PDF: {str(e)}"
131
-
132
- def extract_text_from_docx(file):
133
- try:
134
- doc = docx.Document(file)
135
- return "\n".join(para.text for para in doc.paragraphs)
136
- except Exception as e:
137
- return f"Error reading DOCX: {str(e)}"
138
-
139
- def extract_text_from_pptx(file):
140
- try:
141
- prs = Presentation(file)
142
- return "\n".join(shape.text for slide in prs.slides for shape in slide.shapes if hasattr(shape, "text"))
143
- except Exception as e:
144
- return f"Error reading PPTX: {str(e)}"
145
-
146
- def extract_text_from_image(file):
147
- try:
148
- # Use easyocr for better image text extraction
149
- return easyocr.Reader(['en']).readtext(file)
150
- except Exception as e:
151
- return f"Error reading image: {str(e)}"
152
-
153
- # Main QA logic for documents and images
154
- def answer_question(question, file):
155
- file_ext = os.path.splitext(file.name)[-1].lower()
156
-
157
- if file_ext == ".pdf":
158
- context = extract_text_from_pdf(file)
159
- elif file_ext == ".docx":
160
- context = extract_text_from_docx(file)
161
- elif file_ext == ".pptx":
162
- context = extract_text_from_pptx(file)
163
- elif file_ext in [".png", ".jpg", ".jpeg", ".bmp"]:
164
- context = extract_text_from_image(file)
165
- else:
166
- return "❌ Unsupported file format."
167
 
168
- if not context.strip():
169
- return "⚠️ No readable text found in the document."
 
 
 
 
 
170
 
171
- result = qa_pipeline(question=question, context=context)
172
- return result["answer"]
 
173
 
174
- # Create Gradio interfaces for both document and image QA
175
  doc_interface = gr.Interface(
176
- fn=answer_question,
177
- inputs=[
178
- gr.Textbox(label="Ask a question"),
179
- gr.File(label="Upload a document (PDF, DOCX, PPTX)")
180
- ],
181
- outputs=gr.Textbox(label="Answer"),
182
- title="Document Question Answering",
183
- description="Upload a document and ask a question. Get answers from the document content.",
184
- )
185
-
186
- img_interface = gr.Interface(
187
- fn=answer_question,
188
- inputs=[
189
- gr.Textbox(label="Ask a question"),
190
- gr.File(label="Upload an image (PNG, JPG, etc.)")
191
- ],
192
- outputs=gr.Textbox(label="Answer"),
193
- title="Image Question Answering",
194
- description="Upload an image and ask a question. Get answers from the text extracted from the image.",
195
  )
196
 
197
- # Create a Tabbed Interface to switch between document and image QA
198
  demo = gr.TabbedInterface([doc_interface, img_interface], ["Document QA", "Image QA"])
199
 
200
- # Mount Gradio app in FastAPI
201
  app = gr.mount_gradio_app(app, demo, path="/")
202
 
203
- # Redirect to Gradio interface
204
  @app.get("/")
205
  def home():
206
  return RedirectResponse(url="/")
 
from fastapi import FastAPI
from fastapi.responses import RedirectResponse
import gradio as gr
from transformers import ViltForQuestionAnswering, ViltProcessor
from PIL import Image
import torch
 
 
 
109
 
110
# Initialize FastAPI
app = FastAPI()

# Load the ViLT visual-question-answering model and its processor.
# The ViltProcessor this file imports pairs with ViltForQuestionAnswering;
# both are loaded from the same VQA-finetuned checkpoint so the processor's
# preprocessing matches what the model expects.
VQA_CHECKPOINT = "dandelin/vilt-b32-finetuned-vqa"
model = ViltForQuestionAnswering.from_pretrained(VQA_CHECKPOINT)
processor = ViltProcessor.from_pretrained(VQA_CHECKPOINT)
 
116
 
117
+ # Function to handle image question answering
118
def answer_question_from_image(image, question):
    """Answer a natural-language question about an uploaded image.

    Args:
        image: The image passed in by the caller; ``None`` when no image
            was provided.
        question: The question text; ``None``, empty, or whitespace-only
            values are treated as "no question".

    Returns:
        A prompt asking for both inputs when either one is missing.
        Otherwise the answer label of the highest-scoring class, or
        ``"Predicted answer ID: <index>"`` when the model config carries
        no label for that index.
    """
    # Treat a None question like a blank one: None has no .strip(), so it
    # must be rejected before the whitespace check runs.
    if image is None or not question or not question.strip():
        return "Please upload an image and enter a question."

    # Process input
    inputs = processor(images=image, text=question, return_tensors="pt")
    with torch.no_grad():  # inference only; skip gradient bookkeeping
        outputs = model(**inputs)
    predicted_idx = outputs.logits.argmax(-1).item()

    # Decode the class index through the label mapping on the model config
    # so the user sees an answer rather than an opaque number.
    id2label = getattr(model.config, "id2label", None) or {}
    answer = id2label.get(predicted_idx)
    if answer is None:
        # No label known for this index: fall back to reporting the index.
        return f"Predicted answer ID: {predicted_idx}"
    return str(answer)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
132
 
133
# Image QA tab: an image upload and a question box in, plain text out.
img_interface = gr.Interface(
    fn=answer_question_from_image,
    inputs=[
        gr.Image(label="Upload Image"),
        gr.Textbox(label="Ask a Question"),
    ],
    outputs="text",
    title="AI Image Question Answering",
)
140
 
141
+ # Dummy doc QA interface (replace with your own implementation)
142
def dummy_doc_qa(doc, question) -> str:
    """Return a fixed placeholder reply for the Document QA tab.

    Args:
        doc: The uploaded document; accepted but not read.
        question: The question text; accepted but not read.

    Returns:
        The same placeholder sentence on every call.
    """
    return "This is a placeholder for Document QA."
144
 
 
# Document QA tab: a file upload and a question box in, plain text out.
doc_interface = gr.Interface(
    fn=dummy_doc_qa,
    inputs=[
        gr.File(label="Upload Document"),
        gr.Textbox(label="Ask a Question"),
    ],
    outputs="text",
    title="Document Question Answering",
)
151
 
152
# Combine both interfaces into one tabbed UI.
demo = gr.TabbedInterface(
    [doc_interface, img_interface],
    ["Document QA", "Image QA"],
)

# Mount the tabbed Gradio UI inside the FastAPI app at the root path "/".
app = gr.mount_gradio_app(app, demo, path="/")
157
 
158
# Root route that answers with a redirect.
# NOTE(review): the redirect target is "/", the same path this route is
# registered on, and the Gradio UI is mounted at "/" as well. Confirm
# whether this handler is ever reached or should point somewhere else.
@app.get("/")
def home():
    """Return a redirect response pointing at "/"."""
    response = RedirectResponse(url="/")
    return response