File size: 5,882 Bytes
935d12d
ffda1f9
 
 
a768964
ffda1f9
a768964
ffda1f9
 
 
 
 
 
 
 
 
 
a768964
ffda1f9
a768964
 
ffda1f9
 
a768964
ffda1f9
 
da9e0ce
a768964
ffda1f9
 
a768964
ffda1f9
 
a768964
da9e0ce
 
 
 
ffda1f9
 
 
 
 
 
 
a768964
ffda1f9
 
 
da9e0ce
ffda1f9
a768964
ffda1f9
 
 
 
da9e0ce
ffda1f9
a768964
ffda1f9
 
da9e0ce
a768964
ffda1f9
 
 
 
 
 
 
 
da9e0ce
ffda1f9
 
da9e0ce
ffda1f9
 
 
 
 
 
 
 
a768964
ffda1f9
 
a768964
ffda1f9
 
 
 
 
da9e0ce
 
a768964
da9e0ce
a768964
ffda1f9
 
 
1e4a65e
935d12d
9325c19
df1ed5e
9325c19
b36b2d0
 
935d12d
4f113b7
b36b2d0
935d12d
 
 
b36b2d0
 
 
935d12d
b36b2d0
 
 
 
 
 
 
 
 
 
 
 
 
4f113b7
b36b2d0
 
 
9325c19
b36b2d0
 
 
9325c19
b36b2d0
 
 
 
 
 
 
 
9325c19
b36b2d0
df1ed5e
b36b2d0
 
4f113b7
 
9325c19
 
b36b2d0
 
 
 
 
 
df1ed5e
b36b2d0
 
df1ed5e
7a6dca4
df1ed5e
b36b2d0
df1ed5e
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
"""from fastapi import FastAPI, Form, File, UploadFile
from fastapi.responses import RedirectResponse
from fastapi.staticfiles import StaticFiles
from pydantic import BaseModel
from transformers import pipeline
import os
from PIL import Image
import io
import pdfplumber
import docx
import openpyxl
import pytesseract
from io import BytesIO
import fitz  # PyMuPDF
import easyocr
from fastapi.templating import Jinja2Templates
from starlette.requests import Request

# Initialize the app
app = FastAPI()

# Mount the static directory to serve HTML, CSS, JS files
app.mount("/static", StaticFiles(directory="static"), name="static")

# Initialize transformers pipelines
qa_pipeline = pipeline("question-answering", model="microsoft/phi-2", tokenizer="microsoft/phi-2")
image_qa_pipeline = pipeline("vqa", model="Salesforce/blip-vqa-base")

# Initialize EasyOCR for image-based text extraction
reader = easyocr.Reader(['en'])

# Define a template for rendering HTML
templates = Jinja2Templates(directory="templates")

# Ensure temp_files directory exists
temp_dir = "temp_files"
os.makedirs(temp_dir, exist_ok=True)

# Function to process PDFs
def extract_pdf_text(file_path: str):
    with pdfplumber.open(file_path) as pdf:
        text = ""
        for page in pdf.pages:
            text += page.extract_text()
    return text

# Function to process DOCX files
def extract_docx_text(file_path: str):
    doc = docx.Document(file_path)
    text = "\n".join([para.text for para in doc.paragraphs])
    return text

# Function to process PPTX files
def extract_pptx_text(file_path: str):
    from pptx import Presentation
    prs = Presentation(file_path)
    text = "\n".join([shape.text for slide in prs.slides for shape in slide.shapes if hasattr(shape, "text")])
    return text

# Function to extract text from images using OCR
def extract_text_from_image(image: Image):
    return pytesseract.image_to_string(image)

# Home route
@app.get("/")
def home():
    return RedirectResponse(url="/docs")

# Function to answer questions based on document content
@app.post("/question-answering-doc")
async def question_answering_doc(question: str = Form(...), file: UploadFile = File(...)):
    file_path = os.path.join(temp_dir, file.filename)
    with open(file_path, "wb") as f:
        f.write(await file.read())

    if file.filename.endswith(".pdf"):
        text = extract_pdf_text(file_path)
    elif file.filename.endswith(".docx"):
        text = extract_docx_text(file_path)
    elif file.filename.endswith(".pptx"):
        text = extract_pptx_text(file_path)
    else:
        return {"error": "Unsupported file format"}

    qa_result = qa_pipeline(question=question, context=text)
    return {"answer": qa_result['answer']}

# Function to answer questions based on images
@app.post("/question-answering-image")
async def question_answering_image(question: str = Form(...), image_file: UploadFile = File(...)):
    image = Image.open(BytesIO(await image_file.read()))
    image_text = extract_text_from_image(image)

    image_qa_result = image_qa_pipeline({"image": image, "question": question})
    
    return {"answer": image_qa_result[0]['answer'], "image_text": image_text}

# Serve the application in Hugging Face space
@app.get("/docs")
async def get_docs(request: Request):
    return templates.TemplateResponse("index.html", {"request": request})
"""
from fastapi import FastAPI
from fastapi.responses import RedirectResponse
import gradio as gr

from transformers import pipeline, ViltProcessor, ViltForQuestionAnswering, AutoTokenizer, AutoModelForCausalLM
from PIL import Image
import torch
import fitz  # PyMuPDF for PDF

# Host ASGI application; the Gradio UI is mounted onto it further down.
app = FastAPI()

# ========== Document QA Setup ==========
# Tokenizer and causal LM are loaded from the same checkpoint, so name it once.
DOC_MODEL_ID = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
doc_tokenizer = AutoTokenizer.from_pretrained(DOC_MODEL_ID)
doc_model = AutoModelForCausalLM.from_pretrained(DOC_MODEL_ID)

def read_pdf(file):
    """Extract the plain text of every page of a PDF, in page order.

    Args:
        file: The uploaded PDF. Accepted forms:
            - a filesystem path (``str`` / ``os.PathLike``), which is what
              ``gr.File`` passes in Gradio 4+;
            - an object whose ``.name`` is the path of an existing file
              (e.g. Gradio 3 temp-file wrappers or an open file);
            - raw ``bytes`` / ``bytearray``;
            - any other binary file-like object exposing ``.read()``.

    Returns:
        The concatenated text of all pages (empty string if the PDF has no
        text layer).
    """
    import os

    if isinstance(file, (bytes, bytearray)):
        source = {"stream": bytes(file)}
    elif isinstance(file, (str, os.PathLike)):
        source = {"filename": os.fspath(file)}
    elif isinstance(getattr(file, "name", None), str) and os.path.isfile(file.name):
        # Prefer the on-disk path over .read(): wrappers handed over by the UI
        # may already be closed or positioned at EOF.
        source = {"filename": file.name}
    else:
        source = {"stream": file.read()}

    # The context manager closes the document (and its file handle) even if
    # text extraction raises.
    with fitz.open(filetype="pdf", **source) as doc:
        return "".join(page.get_text() for page in doc)

def answer_question_from_doc(file, question):
    """Answer *question* using the text of the uploaded PDF as context.

    Args:
        file: The uploaded PDF as supplied by the Gradio ``gr.File`` input
            (anything ``read_pdf`` accepts), or ``None`` if nothing was uploaded.
        question: The user's question; ``None`` or blank input is rejected.

    Returns:
        The model's generated answer, or a short guidance message when an
        input is missing or the document yields no extractable text.
    """
    if file is None or not (question or "").strip():
        return "Please upload a document and ask a question."
    text = read_pdf(file)
    if not text.strip():
        # e.g. a scanned PDF without a text layer: nothing to ground an answer on.
        return "No extractable text was found in the document."

    max_new_tokens = 100
    # Prompt tokens plus generated tokens must fit in the model's context window.
    max_positions = getattr(doc_model.config, "max_position_embeddings", 2048)
    prompt_budget = max_positions - max_new_tokens
    # Small safety margin: decode -> re-encode of the truncated context is not
    # guaranteed to reproduce exactly the same token count.
    retokenize_slack = 8

    suffix = f"\nQuestion: {question}\nAnswer:"
    overhead = len(doc_tokenizer(f"Context: {suffix}")["input_ids"])
    # Truncate the *document*, not the assembled prompt. Right-truncating the
    # whole prompt silently drops the question and the "Answer:" cue whenever
    # the document is long.
    context_ids = doc_tokenizer(text, add_special_tokens=False)["input_ids"]
    keep = max(prompt_budget - overhead - retokenize_slack, 0)
    context = doc_tokenizer.decode(context_ids[:keep], skip_special_tokens=True)
    prompt = f"Context: {context}{suffix}"

    # Hard cap kept as a final guard only; normally the prompt is already short enough.
    inputs = doc_tokenizer(prompt, return_tensors="pt", truncation=True, max_length=prompt_budget)
    with torch.no_grad():
        outputs = doc_model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            pad_token_id=doc_tokenizer.eos_token_id,  # tokenizer may define no pad token
        )
    # Decode only the newly generated tokens, so prompt/document text cannot
    # leak into the returned answer.
    new_tokens = outputs[0][inputs["input_ids"].shape[-1]:]
    answer = doc_tokenizer.decode(new_tokens, skip_special_tokens=True)
    # The model may continue the Q/A pattern with an invented follow-up
    # question; keep only the first answer.
    return answer.split("\nQuestion:")[0].strip()

# ========== Image QA Setup ==========
# Processor (image + text preprocessing) and classifier share one checkpoint.
VQA_MODEL_ID = "dandelin/vilt-b32-finetuned-vqa"
vqa_processor = ViltProcessor.from_pretrained(VQA_MODEL_ID)
vqa_model = ViltForQuestionAnswering.from_pretrained(VQA_MODEL_ID)

def answer_question_from_image(image, question):
    """Answer a free-form *question* about *image* with the ViLT VQA model.

    Args:
        image: The image supplied by the Gradio ``gr.Image`` input, or
            ``None`` if nothing was uploaded.
        question: The user's question; ``None`` or blank input is rejected.

    Returns:
        The label of the highest-scoring answer class from
        ``vqa_model.config.id2label``, or a guidance message when an input
        is missing.
    """
    if image is None or not (question or "").strip():
        return "Please upload an image and ask a question."
    # ViLT's text embeddings have a fixed, short position table; truncate the
    # question to it so long questions don't make the forward pass fail.
    inputs = vqa_processor(
        image,
        question,
        return_tensors="pt",
        truncation=True,
        max_length=vqa_model.config.max_position_embeddings,
    )
    with torch.no_grad():
        logits = vqa_model(**inputs).logits
    # The model is a classifier over a fixed answer vocabulary: pick the top class.
    predicted_id = logits.argmax(-1).item()
    return vqa_model.config.id2label[predicted_id]

# ========== Gradio Interfaces ==========
# Document tab: PDF upload + question -> generated text answer.
doc_interface = gr.Interface(
    fn=answer_question_from_doc,
    inputs=[
        # read_pdf only understands PDFs; restrict the picker so other formats
        # are rejected in the browser instead of failing server-side.
        gr.File(label="Upload Document (PDF)", file_types=[".pdf"]),
        gr.Textbox(label="Ask a Question"),
    ],
    outputs="text",
    title="Document Question Answering",
)

# Image tab: image upload + question -> predicted answer label.
img_interface = gr.Interface(
    fn=answer_question_from_image,
    inputs=[gr.Image(label="Upload Image"), gr.Textbox(label="Ask a Question")],
    outputs="text",
    title="Image Question Answering",
)

# ========== Combine and Mount ==========
# Present both interfaces as tabs of a single Gradio app.
demo = gr.TabbedInterface(
    interface_list=[doc_interface, img_interface],
    tab_names=["Document QA", "Image QA"],
)
# Serve the tabbed UI from the FastAPI app's root path.
app = gr.mount_gradio_app(app=app, blocks=demo, path="/")

@app.get("/")
def root():
    """Redirect requests for the site root to ``/``.

    NOTE(review): this route redirects to the very path it is registered on,
    so if it were ever the handler for ``/`` a browser would loop on the
    redirect. It is registered after the Gradio app is mounted at ``/``
    (the mount call just above); presumably that earlier mount matches every
    request and this handler is never reached -- confirm, and consider
    deleting the route.
    """
    return RedirectResponse(url="/")