maria355 committed on
Commit b3d0543 · verified · 1 Parent(s): a90a526

Update app.py

Files changed (1):
  1. app.py +56 -313
app.py CHANGED
@@ -1,16 +1,13 @@
  import streamlit as st
  import torch
  from PIL import Image
- from transformers import Blip2Processor, Blip2ForConditionalGeneration, BlipProcessor, BlipForQuestionAnswering
  import io
  import time
- import requests
- from typing import List, Dict
- import json

  # Set page config
  st.set_page_config(
- page_title="🚀 Advanced BLIP-2 Caption Generator",
  page_icon="🚀",
  layout="wide",
  initial_sidebar_state="expanded"
@@ -41,225 +38,67 @@ st.markdown("""
  border-radius: 5px;
  margin: 1rem 0;
  }
- .analysis-box {
- background-color: #f8f9fa;
- border: 1px solid #dee2e6;
- border-radius: 8px;
- padding: 1rem;
- margin: 0.5rem 0;
- }
- .location-box {
- background-color: #e8f5e8;
- border-left: 4px solid #28a745;
- padding: 1rem;
- border-radius: 5px;
- margin: 1rem 0;
- }
- .objects-box {
- background-color: #fff3cd;
- border-left: 4px solid #ffc107;
- padding: 1rem;
- border-radius: 5px;
- margin: 1rem 0;
- }
  </style>
  """, unsafe_allow_html=True)

  @st.cache_resource
- def load_models():
- """Load and cache the BLIP-2 model and BLIP VQA model"""
  try:
  device = "cuda" if torch.cuda.is_available() else "cpu"

- # Load BLIP-2 for general captioning
- blip2_model_name = "Salesforce/blip2-opt-2.7b"
- blip2_processor = Blip2Processor.from_pretrained(blip2_model_name)
- blip2_model = Blip2ForConditionalGeneration.from_pretrained(
- blip2_model_name,
  torch_dtype=torch.float16 if device == "cuda" else torch.float32,
  device_map="auto" if device == "cuda" else None
  )

- # Load BLIP for Visual Question Answering
- blip_model_name = "Salesforce/blip-vqa-base"
- blip_processor = BlipProcessor.from_pretrained(blip_model_name)
- blip_model = BlipForQuestionAnswering.from_pretrained(
- blip_model_name,
- torch_dtype=torch.float16 if device == "cuda" else torch.float32
- )
-
  if device == "cpu":
- blip2_model = blip2_model.to(device)
- blip_model = blip_model.to(device)

- return blip2_processor, blip2_model, blip_processor, blip_model, device
  except Exception as e:
- st.error(f"Error loading models: {str(e)}")
- return None, None, None, None, None

- def generate_basic_caption(image, processor, model, device, prompt=""):
- """Generate basic caption for the uploaded image"""
  try:
  if prompt:
  inputs = processor(image, text=prompt, return_tensors="pt").to(device)
  else:
  inputs = processor(image, return_tensors="pt").to(device)

  with torch.no_grad():
  generated_ids = model.generate(
  **inputs,
- max_length=100,
  num_beams=5,
  temperature=0.7,
  do_sample=True,
  early_stopping=True
  )

  caption = processor.decode(generated_ids[0], skip_special_tokens=True)
  return caption
  except Exception as e:
  st.error(f"Error generating caption: {str(e)}")
  return None

- def ask_visual_question(image, question, processor, model, device):
- """Ask specific questions about the image using BLIP VQA"""
- try:
- inputs = processor(image, question, return_tensors="pt").to(device)
-
- with torch.no_grad():
- out = model.generate(**inputs, max_length=50, num_beams=3)
-
- answer = processor.decode(out[0], skip_special_tokens=True)
- return answer
- except Exception as e:
- return "Unable to determine"
-
- def analyze_location_and_objects(image, blip_processor, blip_model, device):
- """Analyze image for locations, landmarks, and objects"""
- location_questions = [
- "What country is this?",
- "What city is this?",
- "What landmark is this?",
- "Where is this place?",
- "What famous building is this?",
- "What monument is this?",
- "What geographical location is shown?",
- "What tourist attraction is this?",
- "What state or province is this?",
- "What region is this?",
- "What continent is this in?",
- "What neighborhood is this?",
- "What district is this?",
- "What area is this?"
- ]
-
- object_questions = [
- "What objects can you see in this image?",
- "What are the main things in this picture?",
- "What vehicles are in this image?",
- "What buildings are visible?",
- "What natural features are shown?",
- "What people are doing in this image?",
- "What animals are in this picture?",
- "What food items can you see?",
- "What clothing can you see?",
- "What activities are happening?",
- "What weather is shown?",
- "What time of day is it?",
- "What season does this appear to be?",
- "What colors dominate this image?"
- ]
-
- architectural_questions = [
- "What type of architecture is this?",
- "What style of building is this?",
- "What historical period does this represent?",
- "What cultural elements are visible?",
- "What materials is this building made of?",
- "What architectural features are prominent?",
- "What type of structure is this?",
- "What design style is shown?"
- ]
-
- location_info = {}
- object_info = {}
- architectural_info = {}
-
- # Analyze locations
- for question in location_questions:
- answer = ask_visual_question(image, question, blip_processor, blip_model, device)
- if answer and answer.lower() not in ["no", "none", "unable to determine", "unknown", "unanswerable"]:
- location_info[question] = answer
-
- # Analyze objects
- for question in object_questions:
- answer = ask_visual_question(image, question, blip_processor, blip_model, device)
- if answer and answer.lower() not in ["no", "none", "unable to determine", "unknown", "unanswerable"]:
- object_info[question] = answer
-
- # Analyze architecture
- for question in architectural_questions:
- answer = ask_visual_question(image, question, blip_processor, blip_model, device)
- if answer and answer.lower() not in ["no", "none", "unable to determine", "unknown", "unanswerable"]:
- architectural_info[question] = answer
-
- return location_info, object_info, architectural_info
-
- def generate_enhanced_caption(basic_caption, location_info, object_info, architectural_info):
- """Generate enhanced caption combining all analysis"""
- enhanced_parts = [basic_caption]
-
- if location_info:
- location_details = []
- for question, answer in location_info.items():
- if "country" in question.lower():
- location_details.append(f"Located in {answer}")
- elif "city" in question.lower():
- location_details.append(f"in {answer}")
- elif "landmark" in question.lower() or "monument" in question.lower():
- location_details.append(f"showing {answer}")
- elif "building" in question.lower():
- location_details.append(f"featuring {answer}")
- elif "state" in question.lower() or "province" in question.lower():
- location_details.append(f"in {answer}")
- elif "region" in question.lower():
- location_details.append(f"in the {answer} region")
-
- if location_details:
- enhanced_parts.append(" ".join(location_details[:3])) # Limit to avoid too long captions
-
- if architectural_info:
- arch_details = []
- for question, answer in architectural_info.items():
- if "architecture" in question.lower() or "style" in question.lower():
- arch_details.append(f"The architecture appears to be {answer}")
- elif "period" in question.lower():
- arch_details.append(f"from the {answer} period")
-
- if arch_details:
- enhanced_parts.append(" ".join(arch_details[:2]))
-
- if object_info:
- obj_details = []
- for question, answer in object_info.items():
- if "time of day" in question.lower():
- obj_details.append(f"taken during {answer}")
- elif "weather" in question.lower():
- obj_details.append(f"in {answer} weather")
- elif "season" in question.lower():
- obj_details.append(f"during {answer}")
-
- if obj_details:
- enhanced_parts.append(" ".join(obj_details[:2]))
-
- return ". ".join(enhanced_parts) + "."
-
  def main():
  # Header
  st.markdown("""
  <div class="main-header">
- <h1>🚀 Advanced BLIP-2 Caption Generator</h1>
- <p>Upload an image and get comprehensive AI analysis including locations, landmarks, and objects!</p>
  </div>
  """, unsafe_allow_html=True)

@@ -267,34 +106,23 @@ def main():
  with st.sidebar:
  st.header("🔧 Settings")
  st.markdown("### Model Information")
- st.info("Using **BLIP-2** + **BLIP-VQA** for comprehensive analysis")

- # Analysis options
- st.markdown("### Analysis Options")
- include_location = st.checkbox("🌍 Location Analysis", value=True)
- include_objects = st.checkbox("🎯 Object Detection", value=True)
- include_architecture = st.checkbox("🏛️ Architecture Analysis", value=True)
-
- # Custom questions
- st.markdown("### Custom Questions")
- custom_question = st.text_input(
- "Ask about the image:",
- placeholder="e.g., What time of day is this?"
  )

  st.markdown("### About")
  st.markdown("""
- This enhanced app uses multiple AI models:

  **Features:**
- - 🖼️ Basic image captioning
- - 🌍 Country & city recognition
- - 🏛️ Landmark identification
- - 🎯 Object detection
- - 🏗️ Architecture analysis
- - ❓ Custom Q&A
- - 📍 State/Province detection
- - 🌆 Neighborhood analysis
  """)

  # Main content
@@ -307,13 +135,13 @@ def main():
  uploaded_file = st.file_uploader(
  "Choose an image file",
  type=["jpg", "jpeg", "png", "bmp", "tiff"],
- help="Upload an image for comprehensive analysis"
  )

  if uploaded_file is not None:
  # Display uploaded image
  image = Image.open(uploaded_file)
- st.image(image, caption="Uploaded Image", use_container_width=True)

  # Image info
  st.markdown(f"""
@@ -324,131 +152,46 @@ def main():
  """)

  with col2:
- st.markdown("### 🔮 AI Analysis Results")

  if uploaded_file is not None:
- # Load models
- with st.spinner("Loading AI models..."):
- blip2_processor, blip2_model, blip_processor, blip_model, device = load_models()

- if all([blip2_processor, blip2_model, blip_processor, blip_model]):
- # Analyze button
- if st.button("🚀 Analyze Image", type="primary"):
- with st.spinner("Performing comprehensive analysis..."):
  start_time = time.time()

- # Generate basic caption
- basic_caption = generate_basic_caption(
- image, blip2_processor, blip2_model, device
- )
-
- # Analyze for locations and objects
- location_info, object_info, architectural_info = analyze_location_and_objects(
- image, blip_processor, blip_model, device
  )

- # Custom question
- custom_answer = None
- if custom_question:
- custom_answer = ask_visual_question(
- image, custom_question, blip_processor, blip_model, device
- )
-
  end_time = time.time()

- if basic_caption:
- # Basic Caption
  st.markdown(f"""
  <div class="caption-box">
- <h4>📝 Basic Caption:</h4>
- <p style="font-size: 16px; font-weight: 500;">{basic_caption}</p>
- </div>
- """, unsafe_allow_html=True)
-
- # Location Analysis
- if include_location and location_info:
- st.markdown("""
- <div class="location-box">
- <h4>🌍 Location Analysis:</h4>
- </div>
- """, unsafe_allow_html=True)
-
- for question, answer in location_info.items():
- st.write(f"**{question}** {answer}")
-
- # Object Analysis
- if include_objects and object_info:
- st.markdown("""
- <div class="objects-box">
- <h4>🎯 Object Analysis:</h4>
- </div>
- """, unsafe_allow_html=True)
-
- for question, answer in object_info.items():
- st.write(f"**{question}** {answer}")
-
- # Architecture Analysis
- if include_architecture and architectural_info:
- st.markdown("""
- <div class="analysis-box">
- <h4>🏛️ Architecture Analysis:</h4>
- </div>
- """, unsafe_allow_html=True)
-
- for question, answer in architectural_info.items():
- st.write(f"**{question}** {answer}")
-
- # Custom Question Answer
- if custom_answer:
- st.markdown(f"""
- <div class="analysis-box">
- <h4>❓ Custom Question:</h4>
- <p><strong>Q:</strong> {custom_question}</p>
- <p><strong>A:</strong> {custom_answer}</p>
- </div>
- """, unsafe_allow_html=True)
-
- # Enhanced Caption
- enhanced_caption = generate_enhanced_caption(
- basic_caption, location_info, object_info, architectural_info
- )
-
- st.markdown(f"""
- <div class="caption-box" style="border-left-color: #28a745;">
- <h4>✨ Enhanced Caption:</h4>
- <p style="font-size: 16px; font-weight: 500;">{enhanced_caption}</p>
  </div>
  """, unsafe_allow_html=True)

  # Performance info
- st.success(f"Analysis completed in {end_time - start_time:.2f} seconds")
-
- # Copy caption to clipboard
- st.code(enhanced_caption, language=None)
-
- # Export options
- analysis_data = {
- "basic_caption": basic_caption,
- "enhanced_caption": enhanced_caption,
- "location_info": location_info if include_location else {},
- "object_info": object_info if include_objects else {},
- "architectural_info": architectural_info if include_architecture else {},
- "custom_qa": {"question": custom_question, "answer": custom_answer} if custom_answer else None
- }

- st.download_button(
- label="📄 Download Analysis (JSON)",
- data=json.dumps(analysis_data, indent=2),
- file_name=f"image_analysis_{int(time.time())}.json",
- mime="application/json"
- )
  else:
- st.error("Failed to load the models. Please try refreshing the page.")
  else:
  st.markdown("""
  <div class="upload-section">
  <h3>👆 Upload an image to get started</h3>
- <p>Get comprehensive AI analysis including locations, landmarks, and objects!</p>
  <p>Supported formats: JPG, PNG, BMP, TIFF</p>
  </div>
  """, unsafe_allow_html=True)
@@ -457,8 +200,8 @@ def main():
  st.markdown("---")
  st.markdown("""
  <div style="text-align: center; color: #666;">
- <p>Built with ❤️ using <strong>Streamlit</strong> and <strong>Hugging Face Transformers</strong></p>
- <p>Powered by <strong>BLIP-2</strong> and <strong>BLIP-VQA</strong> for comprehensive image understanding</p>
  </div>
  """, unsafe_allow_html=True)

  import streamlit as st
  import torch
  from PIL import Image
+ from transformers import Blip2Processor, Blip2ForConditionalGeneration
  import io
  import time

  # Set page config
  st.set_page_config(
+ page_title="🚀 BLIP-2 Caption Generator",
  page_icon="🚀",
  layout="wide",
  initial_sidebar_state="expanded"

  border-radius: 5px;
  margin: 1rem 0;
  }
  </style>
  """, unsafe_allow_html=True)

  @st.cache_resource
+ def load_model():
+ """Load and cache the BLIP-2 model and processor"""
  try:
  device = "cuda" if torch.cuda.is_available() else "cpu"

+ # Use the smaller BLIP-2 model for better performance on Hugging Face Spaces
+ model_name = "Salesforce/blip2-opt-2.7b"
+
+ processor = Blip2Processor.from_pretrained(model_name)
+ model = Blip2ForConditionalGeneration.from_pretrained(
+ model_name,
  torch_dtype=torch.float16 if device == "cuda" else torch.float32,
  device_map="auto" if device == "cuda" else None
  )

  if device == "cpu":
+ model = model.to(device)

+ return processor, model, device
  except Exception as e:
+ st.error(f"Error loading model: {str(e)}")
+ return None, None, None

+ def generate_caption(image, processor, model, device, prompt=""):
+ """Generate caption for the uploaded image"""
  try:
+ # Prepare inputs
  if prompt:
  inputs = processor(image, text=prompt, return_tensors="pt").to(device)
  else:
  inputs = processor(image, return_tensors="pt").to(device)

+ # Generate caption
  with torch.no_grad():
  generated_ids = model.generate(
  **inputs,
+ max_length=50,
  num_beams=5,
  temperature=0.7,
  do_sample=True,
  early_stopping=True
  )

+ # Decode the generated caption
  caption = processor.decode(generated_ids[0], skip_special_tokens=True)
  return caption
+
  except Exception as e:
  st.error(f"Error generating caption: {str(e)}")
  return None

  def main():
  # Header
  st.markdown("""
  <div class="main-header">
+ <h1>🚀 BLIP-2 Caption Generator</h1>
+ <p>Upload an image and get AI-generated captions instantly!</p>
  </div>
  """, unsafe_allow_html=True)

  with st.sidebar:
  st.header("🔧 Settings")
  st.markdown("### Model Information")
+ st.info("Using **BLIP-2** (Salesforce/blip2-opt-2.7b)")

+ # Custom prompt option
+ custom_prompt = st.text_input(
+ "Custom Prompt (Optional):",
+ placeholder="e.g., 'Question: What is in this image? Answer:'"
  )

  st.markdown("### About")
  st.markdown("""
+ This app uses the **BLIP-2** model to generate natural language descriptions of images.

  **Features:**
+ - 🖼️ Upload any image format
+ - 🤖 AI-powered captioning
+ - ⚡ Fast inference
+ - 🎯 Optional custom prompts
  """)

  # Main content

  uploaded_file = st.file_uploader(
  "Choose an image file",
  type=["jpg", "jpeg", "png", "bmp", "tiff"],
+ help="Upload an image to generate a caption"
  )

  if uploaded_file is not None:
  # Display uploaded image
  image = Image.open(uploaded_file)
+ st.image(image, caption="Uploaded Image", use_column_width=True)

  # Image info
  st.markdown(f"""

  """)

  with col2:
+ st.markdown("### 🔮 Generated Caption")

  if uploaded_file is not None:
+ # Load model
+ with st.spinner("Loading BLIP-2 model..."):
+ processor, model, device = load_model()

+ if processor is not None and model is not None:
+ # Generate caption button
+ if st.button("🎯 Generate Caption", type="primary"):
+ with st.spinner("Generating caption..."):
  start_time = time.time()

+ # Generate caption
+ caption = generate_caption(
+ image, processor, model, device, custom_prompt
  )

  end_time = time.time()

+ if caption:
+ # Display caption
  st.markdown(f"""
  <div class="caption-box">
+ <h4>📝 Caption:</h4>
+ <p style="font-size: 16px; font-weight: 500;">{caption}</p>
  </div>
  """, unsafe_allow_html=True)

  # Performance info
+ st.success(f"Caption generated in {end_time - start_time:.2f} seconds")

+ # Copy to clipboard button
+ st.code(caption, language=None)
  else:
+ st.error("Failed to load the model. Please try refreshing the page.")
  else:
  st.markdown("""
  <div class="upload-section">
  <h3>👆 Upload an image to get started</h3>
  <p>Supported formats: JPG, PNG, BMP, TIFF</p>
  </div>
  """, unsafe_allow_html=True)

  st.markdown("---")
  st.markdown("""
  <div style="text-align: center; color: #666;">
+ <p>Built with <strong>Streamlit</strong> and <strong>Hugging Face Transformers</strong></p>
+ <p>Powered by <strong>BLIP-2</strong> - Bootstrapping Language-Image Pre-training</p>
  </div>
  """, unsafe_allow_html=True)
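
For reference, the captioning path that the updated app.py wraps in Streamlit can be exercised on its own. A minimal sketch, assuming a local test image named example.jpg and CPU-only inference; the file name and the simplified generate() arguments are illustrative, not part of the commit:

import torch
from PIL import Image
from transformers import Blip2Processor, Blip2ForConditionalGeneration

# Same checkpoint the app loads in load_model()
model_name = "Salesforce/blip2-opt-2.7b"
processor = Blip2Processor.from_pretrained(model_name)
model = Blip2ForConditionalGeneration.from_pretrained(model_name, torch_dtype=torch.float32)

# example.jpg is a placeholder; the app gets its image from st.file_uploader
image = Image.open("example.jpg").convert("RGB")
inputs = processor(image, return_tensors="pt")

with torch.no_grad():
    # The app additionally passes temperature=0.7, do_sample=True, early_stopping=True
    generated_ids = model.generate(**inputs, max_length=50, num_beams=5)

caption = processor.decode(generated_ids[0], skip_special_tokens=True)
print(caption)

Inside the app, the same two steps are cached with @st.cache_resource and triggered from the "Generate Caption" button; running streamlit run app.py starts the UI.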