ikraamkb committed on
Commit
4e1a845
Β·
verified Β·
1 Parent(s): 29f5581

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +34 -4
app.py CHANGED
@@ -16,12 +16,15 @@ app = FastAPI()
16
 
17
  # Use GPU if available
18
  device = "cuda" if torch.cuda.is_available() else "cpu"
 
19
 
20
  # Function to load models lazily
21
  def get_qa_pipeline():
 
22
  return pipeline("text-generation", model="TinyLlama/TinyLlama-1.1B-Chat-v1.0", device=device, torch_dtype=torch.float16)
23
 
24
  def get_image_captioning_pipeline():
 
25
  return pipeline("image-to-text", model="nlpconnect/vit-gpt2-image-captioning")
26
 
27
  ALLOWED_EXTENSIONS = {"pdf", "docx", "pptx", "xlsx"}
@@ -31,44 +34,66 @@ MAX_INPUT_LENGTH = 1024 # Limit input length for faster processing
31
  def validate_file_type(file):
32
  if hasattr(file, "name"):
33
  ext = file.name.split(".")[-1].lower()
 
34
  if ext not in ALLOWED_EXTENSIONS:
 
35
  return f"❌ Unsupported file format: {ext}"
36
  return None
 
37
  return "❌ Invalid file format!"
38
 
39
  # βœ… Extract Text from PDF
40
  async def extract_text_from_pdf(file):
 
41
  loop = asyncio.get_event_loop()
42
- return await loop.run_in_executor(None, lambda: "\n".join([page.get_text() for page in fitz.open(file.name)]))
 
 
43
 
44
  # βœ… Extract Text from DOCX
45
  async def extract_text_from_docx(file):
 
46
  loop = asyncio.get_event_loop()
47
- return await loop.run_in_executor(None, lambda: "\n".join([p.text for p in Document(file).paragraphs]))
 
 
48
 
49
  # βœ… Extract Text from PPTX
50
  async def extract_text_from_pptx(file):
 
51
  loop = asyncio.get_event_loop()
52
- return await loop.run_in_executor(None, lambda: "\n".join([shape.text for slide in Presentation(file).slides for shape in slide.shapes if hasattr(shape, "text")]))
 
 
53
 
54
  # βœ… Extract Text from Excel
55
  async def extract_text_from_excel(file):
 
56
  loop = asyncio.get_event_loop()
57
- return await loop.run_in_executor(None, lambda: "\n".join([" ".join(str(cell) for cell in row if cell) for sheet in load_workbook(file.name, data_only=True).worksheets for row in sheet.iter_rows(values_only=True)]))
 
 
58
 
59
  # βœ… Truncate Long Text
60
  def truncate_text(text):
 
61
  return text[:MAX_INPUT_LENGTH] if len(text) > MAX_INPUT_LENGTH else text
62
 
63
  # βœ… Answer Questions from Image or Document
64
  async def answer_question(file, question: str):
 
 
65
  if isinstance(file, np.ndarray): # Image Processing
 
66
  image = Image.fromarray(file)
67
  image_captioning = get_image_captioning_pipeline()
68
  caption = image_captioning(image)[0]['generated_text']
 
69
 
70
  qa = get_qa_pipeline()
 
71
  response = qa(f"Question: {question}\nContext: {caption}")
 
72
  return response[0]["generated_text"]
73
 
74
  validation_error = validate_file_type(file)
@@ -87,18 +112,22 @@ async def answer_question(file, question: str):
87
  elif file_ext == "xlsx":
88
  text = await extract_text_from_excel(file)
89
  else:
 
90
  return "❌ Unsupported file format!"
91
 
92
  if not text:
 
93
  return "⚠️ No text extracted from the document."
94
 
95
  truncated_text = truncate_text(text)
96
 
97
  # Run QA model asynchronously
 
98
  loop = asyncio.get_event_loop()
99
  qa = get_qa_pipeline()
100
  response = await loop.run_in_executor(None, qa, f"Question: {question}\nContext: {truncated_text}")
101
 
 
102
  return response[0]["generated_text"]
103
 
104
  # βœ… Gradio Interface (Separate File & Image Inputs)
@@ -115,6 +144,7 @@ with gr.Blocks() as demo:
115
 
116
  submit_btn.click(answer_question, inputs=[file_input, question_input], outputs=answer_output)
117
 
 
118
  # βœ… Mount Gradio with FastAPI
119
  app = gr.mount_gradio_app(app, demo, path="/")
120
 
 
16
 
17
  # Use GPU if available
18
  device = "cuda" if torch.cuda.is_available() else "cpu"
19
+ print(f"βœ… Using device: {device}")
20
 
21
  # Function to load models lazily
22
  def get_qa_pipeline():
23
+ print("πŸ”„ Loading QA pipeline model...")
24
  return pipeline("text-generation", model="TinyLlama/TinyLlama-1.1B-Chat-v1.0", device=device, torch_dtype=torch.float16)
25
 
26
  def get_image_captioning_pipeline():
27
+ print("πŸ”„ Loading Image Captioning model...")
28
  return pipeline("image-to-text", model="nlpconnect/vit-gpt2-image-captioning")
29
 
30
  ALLOWED_EXTENSIONS = {"pdf", "docx", "pptx", "xlsx"}
 
34
  def validate_file_type(file):
35
  if hasattr(file, "name"):
36
  ext = file.name.split(".")[-1].lower()
37
+ print(f"πŸ“ File extension detected: {ext}")
38
  if ext not in ALLOWED_EXTENSIONS:
39
+ print(f"❌ Unsupported file format: {ext}")
40
  return f"❌ Unsupported file format: {ext}"
41
  return None
42
+ print("❌ Invalid file format!")
43
  return "❌ Invalid file format!"
44
 
45
  # βœ… Extract Text from PDF
46
  async def extract_text_from_pdf(file):
47
+ print(f"πŸ“„ Extracting text from PDF: {file.name}")
48
  loop = asyncio.get_event_loop()
49
+ text = await loop.run_in_executor(None, lambda: "\n".join([page.get_text() for page in fitz.open(file.name)]))
50
+ print(f"βœ… Extracted {len(text)} characters from PDF")
51
+ return text
52
 
53
  # βœ… Extract Text from DOCX
54
  async def extract_text_from_docx(file):
55
+ print(f"πŸ“„ Extracting text from DOCX: {file.name}")
56
  loop = asyncio.get_event_loop()
57
+ text = await loop.run_in_executor(None, lambda: "\n".join([p.text for p in Document(file).paragraphs]))
58
+ print(f"βœ… Extracted {len(text)} characters from DOCX")
59
+ return text
60
 
61
  # βœ… Extract Text from PPTX
62
  async def extract_text_from_pptx(file):
63
+ print(f"πŸ“„ Extracting text from PPTX: {file.name}")
64
  loop = asyncio.get_event_loop()
65
+ text = await loop.run_in_executor(None, lambda: "\n".join([shape.text for slide in Presentation(file).slides for shape in slide.shapes if hasattr(shape, "text")]))
66
+ print(f"βœ… Extracted {len(text)} characters from PPTX")
67
+ return text
68
 
69
  # βœ… Extract Text from Excel
70
  async def extract_text_from_excel(file):
71
+ print(f"πŸ“„ Extracting text from Excel: {file.name}")
72
  loop = asyncio.get_event_loop()
73
+ text = await loop.run_in_executor(None, lambda: "\n".join([" ".join(str(cell) for cell in row if cell) for sheet in load_workbook(file.name, data_only=True).worksheets for row in sheet.iter_rows(values_only=True)]))
74
+ print(f"βœ… Extracted {len(text)} characters from Excel")
75
+ return text
76
 
77
  # βœ… Truncate Long Text
78
  def truncate_text(text):
79
+ print(f"βœ‚οΈ Truncating text to {MAX_INPUT_LENGTH} characters (if needed)...")
80
  return text[:MAX_INPUT_LENGTH] if len(text) > MAX_INPUT_LENGTH else text
81
 
82
  # βœ… Answer Questions from Image or Document
83
  async def answer_question(file, question: str):
84
+ print(f"❓ Question received: {question}")
85
+
86
  if isinstance(file, np.ndarray): # Image Processing
87
+ print("πŸ–ΌοΈ Processing image for captioning...")
88
  image = Image.fromarray(file)
89
  image_captioning = get_image_captioning_pipeline()
90
  caption = image_captioning(image)[0]['generated_text']
91
+ print(f"πŸ“ Generated caption: {caption}")
92
 
93
  qa = get_qa_pipeline()
94
+ print("πŸ€– Running QA model...")
95
  response = qa(f"Question: {question}\nContext: {caption}")
96
+ print(f"βœ… Model response: {response[0]['generated_text']}")
97
  return response[0]["generated_text"]
98
 
99
  validation_error = validate_file_type(file)
 
112
  elif file_ext == "xlsx":
113
  text = await extract_text_from_excel(file)
114
  else:
115
+ print("❌ Unsupported file format!")
116
  return "❌ Unsupported file format!"
117
 
118
  if not text:
119
+ print("⚠️ No text extracted from the document.")
120
  return "⚠️ No text extracted from the document."
121
 
122
  truncated_text = truncate_text(text)
123
 
124
  # Run QA model asynchronously
125
+ print("πŸ€– Running QA model...")
126
  loop = asyncio.get_event_loop()
127
  qa = get_qa_pipeline()
128
  response = await loop.run_in_executor(None, qa, f"Question: {question}\nContext: {truncated_text}")
129
 
130
+ print(f"βœ… Model response: {response[0]['generated_text']}")
131
  return response[0]["generated_text"]
132
 
133
  # βœ… Gradio Interface (Separate File & Image Inputs)
 
144
 
145
  submit_btn.click(answer_question, inputs=[file_input, question_input], outputs=answer_output)
146
 
147
+
148
  # βœ… Mount Gradio with FastAPI
149
  app = gr.mount_gradio_app(app, demo, path="/")
150