maria355 committed
Commit b889520 · verified · 1 Parent(s): bd9009b

Update app.py

Files changed (1)
  1. app.py +313 -50
app.py CHANGED
@@ -1,13 +1,16 @@
import streamlit as st
import torch
from PIL import Image
- from transformers import Blip2Processor, Blip2ForConditionalGeneration
import io
import time

# Set page config
st.set_page_config(
-     page_title="🚀 BLIP-2 Caption Generator",
    page_icon="🚀",
    layout="wide",
    initial_sidebar_state="expanded"
@@ -38,67 +41,225 @@ st.markdown("""
    border-radius: 5px;
    margin: 1rem 0;
}
</style>
""", unsafe_allow_html=True)

@st.cache_resource
- def load_model():
-     """Load and cache the BLIP-2 model and processor"""
    try:
        device = "cuda" if torch.cuda.is_available() else "cpu"

-         # Use the smaller BLIP-2 model for better performance on Hugging Face Spaces
-         model_name = "Salesforce/blip2-opt-2.7b"
-
-         processor = Blip2Processor.from_pretrained(model_name)
-         model = Blip2ForConditionalGeneration.from_pretrained(
-             model_name,
            torch_dtype=torch.float16 if device == "cuda" else torch.float32,
            device_map="auto" if device == "cuda" else None
        )

        if device == "cpu":
-             model = model.to(device)

-         return processor, model, device
    except Exception as e:
-         st.error(f"Error loading model: {str(e)}")
-         return None, None, None

- def generate_caption(image, processor, model, device, prompt=""):
-     """Generate caption for the uploaded image"""
    try:
-         # Prepare inputs
        if prompt:
            inputs = processor(image, text=prompt, return_tensors="pt").to(device)
        else:
            inputs = processor(image, return_tensors="pt").to(device)

-         # Generate caption
        with torch.no_grad():
            generated_ids = model.generate(
                **inputs,
-                 max_length=50,
                num_beams=5,
                temperature=0.7,
                do_sample=True,
                early_stopping=True
            )

-         # Decode the generated caption
        caption = processor.decode(generated_ids[0], skip_special_tokens=True)
        return caption
-
    except Exception as e:
        st.error(f"Error generating caption: {str(e)}")
        return None

def main():
    # Header
    st.markdown("""
    <div class="main-header">
-         <h1>🚀 BLIP-2 Caption Generator</h1>
-         <p>Upload an image and get AI-generated captions instantly!</p>
    </div>
    """, unsafe_allow_html=True)
@@ -106,17 +267,34 @@ def main():
    with st.sidebar:
        st.header("🔧 Settings")
        st.markdown("### Model Information")
-         st.info("Using **BLIP-2** (Salesforce/blip2-opt-2.7b)")

        st.markdown("### About")
        st.markdown("""
-         This app uses the **BLIP-2** model to generate natural language descriptions of images.

        **Features:**
-         - 🖼️ Upload any image format
-         - 🤖 AI-powered captioning
-         - ⚡ Fast inference
-         - 🎯 Optional custom prompts
        """)

    # Main content
@@ -129,7 +307,7 @@ def main():
        uploaded_file = st.file_uploader(
            "Choose an image file",
            type=["jpg", "jpeg", "png", "bmp", "tiff"],
-             help="Upload an image to generate a caption"
        )

        if uploaded_file is not None:
@@ -146,46 +324,131 @@ def main():
            """)

    with col2:
-         st.markdown("### 🔮 Generated Caption")

        if uploaded_file is not None:
-             # Load model
-             with st.spinner("Loading BLIP-2 model..."):
-                 processor, model, device = load_model()

-             if processor is not None and model is not None:
-                 # Generate caption button
-                 if st.button("🎯 Generate Caption", type="primary"):
-                     with st.spinner("Generating caption..."):
                        start_time = time.time()

-                         # Generate caption
-                         caption = generate_caption(
-                             image, processor, model, device
                        )

                        end_time = time.time()

-                         if caption:
-                             # Display caption
                            st.markdown(f"""
                            <div class="caption-box">
-                                 <h4>📝 Caption:</h4>
-                                 <p style="font-size: 16px; font-weight: 500;">{caption}</p>
                            </div>
                            """, unsafe_allow_html=True)

                            # Performance info
-                             st.success(f"Caption generated in {end_time - start_time:.2f} seconds")

-                             # Copy to clipboard button
-                             st.code(caption, language=None)
            else:
-                 st.error("Failed to load the model. Please try refreshing the page.")
        else:
            st.markdown("""
            <div class="upload-section">
                <h3>👆 Upload an image to get started</h3>
                <p>Supported formats: JPG, PNG, BMP, TIFF</p>
            </div>
            """, unsafe_allow_html=True)
@@ -195,7 +458,7 @@ def main():
    st.markdown("""
    <div style="text-align: center; color: #666;">
        <p>Built with ❤️ using <strong>Streamlit</strong> and <strong>Hugging Face Transformers</strong></p>
-         <p>Powered by <strong>BLIP-2</strong> - Bootstrapping Language-Image Pre-training</p>
    </div>
    """, unsafe_allow_html=True)

@@ -1,13 +1,16 @@
import streamlit as st
import torch
from PIL import Image
+ from transformers import Blip2Processor, Blip2ForConditionalGeneration, BlipProcessor, BlipForQuestionAnswering
import io
import time
+ import requests
+ from typing import List, Dict
+ import json

# Set page config
st.set_page_config(
+     page_title="🚀 Advanced BLIP-2 Caption Generator",
    page_icon="🚀",
    layout="wide",
    initial_sidebar_state="expanded"
@@ -38,67 +41,225 @@ st.markdown("""
    border-radius: 5px;
    margin: 1rem 0;
}
+ .analysis-box {
+     background-color: #f8f9fa;
+     border: 1px solid #dee2e6;
+     border-radius: 8px;
+     padding: 1rem;
+     margin: 0.5rem 0;
+ }
+ .location-box {
+     background-color: #e8f5e8;
+     border-left: 4px solid #28a745;
+     padding: 1rem;
+     border-radius: 5px;
+     margin: 1rem 0;
+ }
+ .objects-box {
+     background-color: #fff3cd;
+     border-left: 4px solid #ffc107;
+     padding: 1rem;
+     border-radius: 5px;
+     margin: 1rem 0;
+ }
</style>
""", unsafe_allow_html=True)

@st.cache_resource
+ def load_models():
+     """Load and cache the BLIP-2 model and BLIP VQA model"""
    try:
        device = "cuda" if torch.cuda.is_available() else "cpu"

+         # Load BLIP-2 for general captioning
+         blip2_model_name = "Salesforce/blip2-opt-2.7b"
+         blip2_processor = Blip2Processor.from_pretrained(blip2_model_name)
+         blip2_model = Blip2ForConditionalGeneration.from_pretrained(
+             blip2_model_name,
            torch_dtype=torch.float16 if device == "cuda" else torch.float32,
            device_map="auto" if device == "cuda" else None
        )

+         # Load BLIP for Visual Question Answering
+         blip_model_name = "Salesforce/blip-vqa-base"
+         blip_processor = BlipProcessor.from_pretrained(blip_model_name)
+         blip_model = BlipForQuestionAnswering.from_pretrained(
+             blip_model_name,
+             torch_dtype=torch.float16 if device == "cuda" else torch.float32
+         )
+
        if device == "cpu":
+             blip2_model = blip2_model.to(device)
+             blip_model = blip_model.to(device)

+         return blip2_processor, blip2_model, blip_processor, blip_model, device
    except Exception as e:
+         st.error(f"Error loading models: {str(e)}")
+         return None, None, None, None, None

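The loader above pulls roughly 3.7 billion parameters into memory when Salesforce/blip2-opt-2.7b runs in float32 on CPU, which can be tight on a small Hugging Face Space. As a hedged sketch (not part of this commit, and the checkpoint choice is an assumption), a lighter BLIP captioning model could stand in for the BLIP-2 captioner on constrained hardware:

```python
# Sketch: a smaller captioning fallback for CPU-only Spaces (assumed checkpoint,
# not the one used in this app).
from transformers import BlipProcessor, BlipForConditionalGeneration

light_name = "Salesforce/blip-image-captioning-base"
light_processor = BlipProcessor.from_pretrained(light_name)
light_model = BlipForConditionalGeneration.from_pretrained(light_name)

def light_caption(image):
    """Return a short caption for a PIL image using the lighter BLIP model."""
    inputs = light_processor(image, return_tensors="pt")
    out = light_model.generate(**inputs, max_new_tokens=50)
    return light_processor.decode(out[0], skip_special_tokens=True)
```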
+ def generate_basic_caption(image, processor, model, device, prompt=""):
+     """Generate basic caption for the uploaded image"""
    try:
        if prompt:
            inputs = processor(image, text=prompt, return_tensors="pt").to(device)
        else:
            inputs = processor(image, return_tensors="pt").to(device)

        with torch.no_grad():
            generated_ids = model.generate(
                **inputs,
+                 max_length=100,
                num_beams=5,
                temperature=0.7,
                do_sample=True,
                early_stopping=True
            )

        caption = processor.decode(generated_ids[0], skip_special_tokens=True)
        return caption
    except Exception as e:
        st.error(f"Error generating caption: {str(e)}")
        return None

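Because `do_sample=True` is passed together with `num_beams=5`, generation above uses beam-search multinomial sampling, and `temperature=0.7` only has an effect because sampling is enabled. A minimal deterministic variant (an editorial sketch, not in this commit) drops sampling and caps only the newly generated tokens:

```python
# Sketch: deterministic beam search for the generate() call in generate_basic_caption.
# Assumes the same `inputs` and `model` prepared above; parameter values are illustrative.
generated_ids = model.generate(
    **inputs,
    max_new_tokens=60,    # limit on generated tokens, independent of the prompt length
    num_beams=5,
    do_sample=False,      # temperature is ignored when sampling is disabled
    early_stopping=True,
)
```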
+ def ask_visual_question(image, question, processor, model, device):
+     """Ask specific questions about the image using BLIP VQA"""
+     try:
+         inputs = processor(image, question, return_tensors="pt").to(device)
+
+         with torch.no_grad():
+             out = model.generate(**inputs, max_length=50, num_beams=3)
+
+         answer = processor.decode(out[0], skip_special_tokens=True)
+         return answer
+     except Exception as e:
+         return "Unable to determine"

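For reference, a standalone call to the VQA helper above might look like this (the file name and question are illustrative; `load_models()` is the cached loader defined earlier in this file):

```python
# Sketch: one-off visual question answering using the helpers defined in this file.
from PIL import Image

blip2_processor, blip2_model, blip_processor, blip_model, device = load_models()
img = Image.open("example.jpg").convert("RGB")  # hypothetical local image
print(ask_visual_question(img, "What city is this?", blip_processor, blip_model, device))
```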
+ def analyze_location_and_objects(image, blip_processor, blip_model, device):
+     """Analyze image for locations, landmarks, and objects"""
+     location_questions = [
+         "What country is this?",
+         "What city is this?",
+         "What landmark is this?",
+         "Where is this place?",
+         "What famous building is this?",
+         "What monument is this?",
+         "What geographical location is shown?",
+         "What tourist attraction is this?",
+         "What state or province is this?",
+         "What region is this?",
+         "What continent is this in?",
+         "What neighborhood is this?",
+         "What district is this?",
+         "What area is this?"
+     ]
+
+     object_questions = [
+         "What objects can you see in this image?",
+         "What are the main things in this picture?",
+         "What vehicles are in this image?",
+         "What buildings are visible?",
+         "What natural features are shown?",
+         "What people are doing in this image?",
+         "What animals are in this picture?",
+         "What food items can you see?",
+         "What clothing can you see?",
+         "What activities are happening?",
+         "What weather is shown?",
+         "What time of day is it?",
+         "What season does this appear to be?",
+         "What colors dominate this image?"
+     ]
+
+     architectural_questions = [
+         "What type of architecture is this?",
+         "What style of building is this?",
+         "What historical period does this represent?",
+         "What cultural elements are visible?",
+         "What materials is this building made of?",
+         "What architectural features are prominent?",
+         "What type of structure is this?",
+         "What design style is shown?"
+     ]
+
+     location_info = {}
+     object_info = {}
+     architectural_info = {}
+
+     # Analyze locations
+     for question in location_questions:
+         answer = ask_visual_question(image, question, blip_processor, blip_model, device)
+         if answer and answer.lower() not in ["no", "none", "unable to determine", "unknown", "unanswerable"]:
+             location_info[question] = answer
+
+     # Analyze objects
+     for question in object_questions:
+         answer = ask_visual_question(image, question, blip_processor, blip_model, device)
+         if answer and answer.lower() not in ["no", "none", "unable to determine", "unknown", "unanswerable"]:
+             object_info[question] = answer
+
+     # Analyze architecture
+     for question in architectural_questions:
+         answer = ask_visual_question(image, question, blip_processor, blip_model, device)
+         if answer and answer.lower() not in ["no", "none", "unable to determine", "unknown", "unanswerable"]:
+             architectural_info[question] = answer
+
+     return location_info, object_info, architectural_info

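The analysis above runs one `generate()` call per question (36 in total), which dominates latency on CPU. A hedged sketch of batching several questions through the BLIP VQA processor in a single forward pass (it assumes the processor accepts parallel lists of images and questions with `padding=True`, and that memory allows the batch):

```python
# Sketch: batched VQA (illustrative; relies on the module-level torch import above).
def ask_visual_questions_batched(image, questions, processor, model, device):
    """Answer a list of questions about one PIL image in a single generate() call."""
    inputs = processor(
        images=[image] * len(questions),  # repeat the image once per question
        text=list(questions),
        padding=True,
        return_tensors="pt",
    ).to(device)
    with torch.no_grad():
        out = model.generate(**inputs, max_length=50, num_beams=3)
    answers = processor.batch_decode(out, skip_special_tokens=True)
    return dict(zip(questions, answers))
```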
+ def generate_enhanced_caption(basic_caption, location_info, object_info, architectural_info):
+     """Generate enhanced caption combining all analysis"""
+     enhanced_parts = [basic_caption]
+
+     if location_info:
+         location_details = []
+         for question, answer in location_info.items():
+             if "country" in question.lower():
+                 location_details.append(f"Located in {answer}")
+             elif "city" in question.lower():
+                 location_details.append(f"in {answer}")
+             elif "landmark" in question.lower() or "monument" in question.lower():
+                 location_details.append(f"showing {answer}")
+             elif "building" in question.lower():
+                 location_details.append(f"featuring {answer}")
+             elif "state" in question.lower() or "province" in question.lower():
+                 location_details.append(f"in {answer}")
+             elif "region" in question.lower():
+                 location_details.append(f"in the {answer} region")
+
+         if location_details:
+             enhanced_parts.append(" ".join(location_details[:3]))  # Limit to avoid too long captions
+
+     if architectural_info:
+         arch_details = []
+         for question, answer in architectural_info.items():
+             if "architecture" in question.lower() or "style" in question.lower():
+                 arch_details.append(f"The architecture appears to be {answer}")
+             elif "period" in question.lower():
+                 arch_details.append(f"from the {answer} period")
+
+         if arch_details:
+             enhanced_parts.append(" ".join(arch_details[:2]))
+
+     if object_info:
+         obj_details = []
+         for question, answer in object_info.items():
+             if "time of day" in question.lower():
+                 obj_details.append(f"taken during {answer}")
+             elif "weather" in question.lower():
+                 obj_details.append(f"in {answer} weather")
+             elif "season" in question.lower():
+                 obj_details.append(f"during {answer}")
+
+         if obj_details:
+             enhanced_parts.append(" ".join(obj_details[:2]))
+
+     return ". ".join(enhanced_parts) + "."

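To make the string composition above concrete, here is a small worked example with hypothetical answers (none of these values come from the model or this repository):

```python
# Sketch: what generate_enhanced_caption produces for made-up inputs.
generate_enhanced_caption(
    "a large stone tower beside a river",
    {"What country is this?": "france", "What city is this?": "paris"},
    {"What time of day is it?": "evening"},
    {"What type of architecture is this?": "gothic"},
)
# -> 'a large stone tower beside a river. Located in france in paris. The architecture appears to be gothic. taken during evening.'
```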
def main():
    # Header
    st.markdown("""
    <div class="main-header">
+         <h1>🚀 Advanced BLIP-2 Caption Generator</h1>
+         <p>Upload an image and get comprehensive AI analysis including locations, landmarks, and objects!</p>
    </div>
    """, unsafe_allow_html=True)

@@ -106,17 +267,34 @@ def main():
    with st.sidebar:
        st.header("🔧 Settings")
        st.markdown("### Model Information")
+         st.info("Using **BLIP-2** + **BLIP-VQA** for comprehensive analysis")
+
+         # Analysis options
+         st.markdown("### Analysis Options")
+         include_location = st.checkbox("🌍 Location Analysis", value=True)
+         include_objects = st.checkbox("🎯 Object Detection", value=True)
+         include_architecture = st.checkbox("🏛️ Architecture Analysis", value=True)
+
+         # Custom questions
+         st.markdown("### Custom Questions")
+         custom_question = st.text_input(
+             "Ask about the image:",
+             placeholder="e.g., What time of day is this?"
+         )

        st.markdown("### About")
        st.markdown("""
+         This enhanced app uses multiple AI models:

        **Features:**
+         - 🖼️ Basic image captioning
+         - 🌍 Country & city recognition
+         - 🏛️ Landmark identification
+         - 🎯 Object detection
+         - 🏗️ Architecture analysis
+         - ❓ Custom Q&A
+         - 📍 State/Province detection
+         - 🌆 Neighborhood analysis
        """)

    # Main content
@@ -129,7 +307,7 @@ def main():
        uploaded_file = st.file_uploader(
            "Choose an image file",
            type=["jpg", "jpeg", "png", "bmp", "tiff"],
+             help="Upload an image for comprehensive analysis"
        )

        if uploaded_file is not None:

@@ -146,46 +324,131 @@ def main():
            """)

    with col2:
+         st.markdown("### 🔮 AI Analysis Results")

        if uploaded_file is not None:
+             # Load models
+             with st.spinner("Loading AI models..."):
+                 blip2_processor, blip2_model, blip_processor, blip_model, device = load_models()

+             if all([blip2_processor, blip2_model, blip_processor, blip_model]):
+                 # Analyze button
+                 if st.button("🚀 Analyze Image", type="primary"):
+                     with st.spinner("Performing comprehensive analysis..."):
                        start_time = time.time()

+                         # Generate basic caption
+                         basic_caption = generate_basic_caption(
+                             image, blip2_processor, blip2_model, device
+                         )
+
+                         # Analyze for locations and objects
+                         location_info, object_info, architectural_info = analyze_location_and_objects(
+                             image, blip_processor, blip_model, device
                        )

+                         # Custom question
+                         custom_answer = None
+                         if custom_question:
+                             custom_answer = ask_visual_question(
+                                 image, custom_question, blip_processor, blip_model, device
+                             )
+
                        end_time = time.time()

360
+ # Basic Caption
361
  st.markdown(f"""
362
  <div class="caption-box">
363
+ <h4>๐Ÿ“ Basic Caption:</h4>
364
+ <p style="font-size: 16px; font-weight: 500;">{basic_caption}</p>
365
+ </div>
366
+ """, unsafe_allow_html=True)
367
+
368
+ # Location Analysis
369
+ if include_location and location_info:
370
+ st.markdown("""
371
+ <div class="location-box">
372
+ <h4>๐ŸŒ Location Analysis:</h4>
373
+ </div>
374
+ """, unsafe_allow_html=True)
375
+
376
+ for question, answer in location_info.items():
377
+ st.write(f"**{question}** {answer}")
378
+
379
+ # Object Analysis
380
+ if include_objects and object_info:
381
+ st.markdown("""
382
+ <div class="objects-box">
383
+ <h4>๐ŸŽฏ Object Analysis:</h4>
384
+ </div>
385
+ """, unsafe_allow_html=True)
386
+
387
+ for question, answer in object_info.items():
388
+ st.write(f"**{question}** {answer}")
389
+
390
+ # Architecture Analysis
391
+ if include_architecture and architectural_info:
392
+ st.markdown("""
393
+ <div class="analysis-box">
394
+ <h4>๐Ÿ›๏ธ Architecture Analysis:</h4>
395
+ </div>
396
+ """, unsafe_allow_html=True)
397
+
398
+ for question, answer in architectural_info.items():
399
+ st.write(f"**{question}** {answer}")
400
+
401
+ # Custom Question Answer
402
+ if custom_answer:
403
+ st.markdown(f"""
404
+ <div class="analysis-box">
405
+ <h4>โ“ Custom Question:</h4>
406
+ <p><strong>Q:</strong> {custom_question}</p>
407
+ <p><strong>A:</strong> {custom_answer}</p>
408
+ </div>
409
+ """, unsafe_allow_html=True)
410
+
411
+ # Enhanced Caption
412
+ enhanced_caption = generate_enhanced_caption(
413
+ basic_caption, location_info, object_info, architectural_info
414
+ )
415
+
416
+ st.markdown(f"""
417
+ <div class="caption-box" style="border-left-color: #28a745;">
418
+ <h4>โœจ Enhanced Caption:</h4>
419
+ <p style="font-size: 16px; font-weight: 500;">{enhanced_caption}</p>
420
  </div>
421
  """, unsafe_allow_html=True)
422
 
423
  # Performance info
424
+ st.success(f"Analysis completed in {end_time - start_time:.2f} seconds")
425
+
426
+ # Copy caption to clipboard
427
+ st.code(enhanced_caption, language=None)
428
+
429
+ # Export options
430
+ analysis_data = {
431
+ "basic_caption": basic_caption,
432
+ "enhanced_caption": enhanced_caption,
433
+ "location_info": location_info if include_location else {},
434
+ "object_info": object_info if include_objects else {},
435
+ "architectural_info": architectural_info if include_architecture else {},
436
+ "custom_qa": {"question": custom_question, "answer": custom_answer} if custom_answer else None
437
+ }
438
 
439
+ st.download_button(
440
+ label="๐Ÿ“„ Download Analysis (JSON)",
441
+ data=json.dumps(analysis_data, indent=2),
442
+ file_name=f"image_analysis_{int(time.time())}.json",
443
+ mime="application/json"
444
+ )
445
  else:
446
+ st.error("Failed to load the models. Please try refreshing the page.")
447
  else:
448
  st.markdown("""
449
  <div class="upload-section">
450
  <h3>๐Ÿ‘† Upload an image to get started</h3>
451
+ <p>Get comprehensive AI analysis including locations, landmarks, and objects!</p>
452
  <p>Supported formats: JPG, PNG, BMP, TIFF</p>
453
  </div>
454
  """, unsafe_allow_html=True)
 
458
@@ -195,7 +458,7 @@ def main():
    st.markdown("""
    <div style="text-align: center; color: #666;">
        <p>Built with ❤️ using <strong>Streamlit</strong> and <strong>Hugging Face Transformers</strong></p>
+         <p>Powered by <strong>BLIP-2</strong> and <strong>BLIP-VQA</strong> for comprehensive image understanding</p>
    </div>
    """, unsafe_allow_html=True)