import streamlit as st
import torch
from PIL import Image
from transformers import Blip2Processor, Blip2ForConditionalGeneration, BlipProcessor, BlipForQuestionAnswering
import time
import json
# Set page config
st.set_page_config(
page_title="🚀 Advanced BLIP-2 Caption Generator",
page_icon="🚀",
layout="wide",
initial_sidebar_state="expanded"
)
# Custom CSS for better styling
st.markdown("""
<style>
.main-header {
text-align: center;
padding: 2rem 0;
background: linear-gradient(90deg, #667eea 0%, #764ba2 100%);
color: white;
border-radius: 10px;
margin-bottom: 2rem;
}
.upload-section {
border: 2px dashed #ccc;
border-radius: 10px;
padding: 2rem;
text-align: center;
margin: 1rem 0;
}
.caption-box {
background-color: #f0f2f6;
border-left: 4px solid #667eea;
padding: 1rem;
border-radius: 5px;
margin: 1rem 0;
}
.analysis-box {
background-color: #f8f9fa;
border: 1px solid #dee2e6;
border-radius: 8px;
padding: 1rem;
margin: 0.5rem 0;
}
.location-box {
background-color: #e8f5e8;
border-left: 4px solid #28a745;
padding: 1rem;
border-radius: 5px;
margin: 1rem 0;
}
.objects-box {
background-color: #fff3cd;
border-left: 4px solid #ffc107;
padding: 1rem;
border-radius: 5px;
margin: 1rem 0;
}
</style>
""", unsafe_allow_html=True)
@st.cache_resource
def load_models():
"""Load and cache the BLIP-2 model and BLIP VQA model"""
try:
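        # BLIP-2 (OPT-2.7B) produces open-ended captions; BLIP-VQA answers targeted questions.
        # On GPU, fp16 weights and device_map="auto" keep memory usage manageable.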
device = "cuda" if torch.cuda.is_available() else "cpu"
# Load BLIP-2 for general captioning
blip2_model_name = "Salesforce/blip2-opt-2.7b"
blip2_processor = Blip2Processor.from_pretrained(blip2_model_name)
blip2_model = Blip2ForConditionalGeneration.from_pretrained(
blip2_model_name,
torch_dtype=torch.float16 if device == "cuda" else torch.float32,
device_map="auto" if device == "cuda" else None
)
# Load BLIP for Visual Question Answering
blip_model_name = "Salesforce/blip-vqa-base"
blip_processor = BlipProcessor.from_pretrained(blip_model_name)
blip_model = BlipForQuestionAnswering.from_pretrained(
blip_model_name,
torch_dtype=torch.float16 if device == "cuda" else torch.float32
)
        if device == "cpu":
            blip2_model = blip2_model.to(device)
        # The VQA model is loaded without device_map, so place it on the target device explicitly
        blip_model = blip_model.to(device)
return blip2_processor, blip2_model, blip_processor, blip_model, device
except Exception as e:
st.error(f"Error loading models: {str(e)}")
return None, None, None, None, None
def generate_basic_caption(image, processor, model, device, prompt=""):
"""Generate basic caption for the uploaded image"""
    try:
        # On GPU the model weights are fp16, so cast the (floating-point) inputs to match
        dtype = torch.float16 if device == "cuda" else torch.float32
        if prompt:
            inputs = processor(image, text=prompt, return_tensors="pt").to(device, dtype)
        else:
            inputs = processor(image, return_tensors="pt").to(device, dtype)
        with torch.no_grad():
            generated_ids = model.generate(
                **inputs,
                max_new_tokens=100,  # cap the generated caption length, independent of prompt length
                num_beams=5,
                temperature=0.7,
                do_sample=True,
                early_stopping=True
            )
        caption = processor.decode(generated_ids[0], skip_special_tokens=True).strip()
        return caption
except Exception as e:
st.error(f"Error generating caption: {str(e)}")
return None
def ask_visual_question(image, question, processor, model, device):
"""Ask specific questions about the image using BLIP VQA"""
    try:
        # Match the input dtype to the VQA model (fp16 on GPU, fp32 on CPU)
        dtype = torch.float16 if device == "cuda" else torch.float32
        inputs = processor(image, question, return_tensors="pt").to(device, dtype)
        with torch.no_grad():
            out = model.generate(**inputs, max_length=50, num_beams=3)
        answer = processor.decode(out[0], skip_special_tokens=True)
        return answer
    except Exception:
        return "Unable to determine"
def analyze_location_and_objects(image, blip_processor, blip_model, device):
"""Analyze image for locations, landmarks, and objects"""
location_questions = [
"What country is this?",
"What city is this?",
"What landmark is this?",
"Where is this place?",
"What famous building is this?",
"What monument is this?",
"What geographical location is shown?",
"What tourist attraction is this?",
"What state or province is this?",
"What region is this?",
"What continent is this in?",
"What neighborhood is this?",
"What district is this?",
"What area is this?"
]
object_questions = [
"What objects can you see in this image?",
"What are the main things in this picture?",
"What vehicles are in this image?",
"What buildings are visible?",
"What natural features are shown?",
"What people are doing in this image?",
"What animals are in this picture?",
"What food items can you see?",
"What clothing can you see?",
"What activities are happening?",
"What weather is shown?",
"What time of day is it?",
"What season does this appear to be?",
"What colors dominate this image?"
]
architectural_questions = [
"What type of architecture is this?",
"What style of building is this?",
"What historical period does this represent?",
"What cultural elements are visible?",
"What materials is this building made of?",
"What architectural features are prominent?",
"What type of structure is this?",
"What design style is shown?"
]
location_info = {}
object_info = {}
architectural_info = {}
# Analyze locations
for question in location_questions:
answer = ask_visual_question(image, question, blip_processor, blip_model, device)
if answer and answer.lower() not in ["no", "none", "unable to determine", "unknown", "unanswerable"]:
location_info[question] = answer
# Analyze objects
for question in object_questions:
answer = ask_visual_question(image, question, blip_processor, blip_model, device)
if answer and answer.lower() not in ["no", "none", "unable to determine", "unknown", "unanswerable"]:
object_info[question] = answer
# Analyze architecture
for question in architectural_questions:
answer = ask_visual_question(image, question, blip_processor, blip_model, device)
if answer and answer.lower() not in ["no", "none", "unable to determine", "unknown", "unanswerable"]:
architectural_info[question] = answer
return location_info, object_info, architectural_info
def generate_enhanced_caption(basic_caption, location_info, object_info, architectural_info):
"""Generate enhanced caption combining all analysis"""
enhanced_parts = [basic_caption]
if location_info:
location_details = []
for question, answer in location_info.items():
if "country" in question.lower():
location_details.append(f"Located in {answer}")
elif "city" in question.lower():
location_details.append(f"in {answer}")
elif "landmark" in question.lower() or "monument" in question.lower():
location_details.append(f"showing {answer}")
elif "building" in question.lower():
location_details.append(f"featuring {answer}")
elif "state" in question.lower() or "province" in question.lower():
location_details.append(f"in {answer}")
elif "region" in question.lower():
location_details.append(f"in the {answer} region")
if location_details:
enhanced_parts.append(" ".join(location_details[:3])) # Limit to avoid too long captions
if architectural_info:
arch_details = []
for question, answer in architectural_info.items():
if "architecture" in question.lower() or "style" in question.lower():
arch_details.append(f"The architecture appears to be {answer}")
elif "period" in question.lower():
arch_details.append(f"from the {answer} period")
if arch_details:
enhanced_parts.append(" ".join(arch_details[:2]))
if object_info:
obj_details = []
for question, answer in object_info.items():
if "time of day" in question.lower():
obj_details.append(f"taken during {answer}")
elif "weather" in question.lower():
obj_details.append(f"in {answer} weather")
elif "season" in question.lower():
obj_details.append(f"during {answer}")
if obj_details:
enhanced_parts.append(" ".join(obj_details[:2]))
return ". ".join(enhanced_parts) + "."
def main():
# Header
st.markdown("""
<div class="main-header">
<h1>🚀 Advanced BLIP-2 Caption Generator</h1>
<p>Upload an image and get comprehensive AI analysis including locations, landmarks, and objects!</p>
</div>
""", unsafe_allow_html=True)
# Sidebar
with st.sidebar:
st.header("🔧 Settings")
st.markdown("### Model Information")
st.info("Using **BLIP-2** + **BLIP-VQA** for comprehensive analysis")
# Analysis options
st.markdown("### Analysis Options")
include_location = st.checkbox("🌍 Location Analysis", value=True)
include_objects = st.checkbox("🎯 Object Detection", value=True)
include_architecture = st.checkbox("🏛️ Architecture Analysis", value=True)
# Custom questions
st.markdown("### Custom Questions")
custom_question = st.text_input(
"Ask about the image:",
placeholder="e.g., What time of day is this?"
)
st.markdown("### About")
st.markdown("""
        This enhanced app combines two AI models: BLIP-2 for captioning and BLIP-VQA for visual question answering.

**Features:**
- 🖼️ Basic image captioning
- 🌍 Country & city recognition
- 🏛️ Landmark identification
- 🎯 Object detection
- 🏗️ Architecture analysis
- ❓ Custom Q&A
- 📍 State/Province detection
- 🌆 Neighborhood analysis
""")
# Main content
col1, col2 = st.columns([1, 1])
with col1:
st.markdown("### 📤 Upload Image")
# File uploader
uploaded_file = st.file_uploader(
"Choose an image file",
type=["jpg", "jpeg", "png", "bmp", "tiff"],
help="Upload an image for comprehensive analysis"
)
if uploaded_file is not None:
# Display uploaded image
image = Image.open(uploaded_file)
st.image(image, caption="Uploaded Image", use_container_width=True)
# Image info
st.markdown(f"""
**Image Info:**
- Size: {image.size[0]} x {image.size[1]} pixels
- Format: {image.format}
- Mode: {image.mode}
""")
with col2:
st.markdown("### 🔮 AI Analysis Results")
if uploaded_file is not None:
# Load models
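            # (st.cache_resource means the models are loaded once; later reruns reuse them)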
with st.spinner("Loading AI models..."):
blip2_processor, blip2_model, blip_processor, blip_model, device = load_models()
if all([blip2_processor, blip2_model, blip_processor, blip_model]):
# Analyze button
if st.button("🚀 Analyze Image", type="primary"):
with st.spinner("Performing comprehensive analysis..."):
start_time = time.time()
# Generate basic caption
basic_caption = generate_basic_caption(
image, blip2_processor, blip2_model, device
)
# Analyze for locations and objects
location_info, object_info, architectural_info = analyze_location_and_objects(
image, blip_processor, blip_model, device
)
# Custom question
custom_answer = None
if custom_question:
custom_answer = ask_visual_question(
image, custom_question, blip_processor, blip_model, device
)
end_time = time.time()
if basic_caption:
# Basic Caption
st.markdown(f"""
<div class="caption-box">
<h4>📝 Basic Caption:</h4>
<p style="font-size: 16px; font-weight: 500;">{basic_caption}</p>
</div>
""", unsafe_allow_html=True)
# Location Analysis
if include_location and location_info:
st.markdown("""
<div class="location-box">
<h4>🌍 Location Analysis:</h4>
</div>
""", unsafe_allow_html=True)
for question, answer in location_info.items():
st.write(f"**{question}** {answer}")
# Object Analysis
if include_objects and object_info:
st.markdown("""
<div class="objects-box">
<h4>🎯 Object Analysis:</h4>
</div>
""", unsafe_allow_html=True)
for question, answer in object_info.items():
st.write(f"**{question}** {answer}")
# Architecture Analysis
if include_architecture and architectural_info:
st.markdown("""
<div class="analysis-box">
<h4>🏛️ Architecture Analysis:</h4>
</div>
""", unsafe_allow_html=True)
for question, answer in architectural_info.items():
st.write(f"**{question}** {answer}")
# Custom Question Answer
if custom_answer:
st.markdown(f"""
<div class="analysis-box">
<h4>❓ Custom Question:</h4>
<p><strong>Q:</strong> {custom_question}</p>
<p><strong>A:</strong> {custom_answer}</p>
</div>
""", unsafe_allow_html=True)
# Enhanced Caption
enhanced_caption = generate_enhanced_caption(
basic_caption, location_info, object_info, architectural_info
)
st.markdown(f"""
<div class="caption-box" style="border-left-color: #28a745;">
<h4>✨ Enhanced Caption:</h4>
<p style="font-size: 16px; font-weight: 500;">{enhanced_caption}</p>
</div>
""", unsafe_allow_html=True)
# Performance info
st.success(f"Analysis completed in {end_time - start_time:.2f} seconds")
                            # Show the enhanced caption in a code block (Streamlit adds a copy button)
                            st.code(enhanced_caption, language=None)
# Export options
analysis_data = {
"basic_caption": basic_caption,
"enhanced_caption": enhanced_caption,
"location_info": location_info if include_location else {},
"object_info": object_info if include_objects else {},
"architectural_info": architectural_info if include_architecture else {},
"custom_qa": {"question": custom_question, "answer": custom_answer} if custom_answer else None
}
st.download_button(
label="📄 Download Analysis (JSON)",
data=json.dumps(analysis_data, indent=2),
file_name=f"image_analysis_{int(time.time())}.json",
mime="application/json"
)
else:
st.error("Failed to load the models. Please try refreshing the page.")
else:
st.markdown("""
<div class="upload-section">
<h3>👆 Upload an image to get started</h3>
<p>Get comprehensive AI analysis including locations, landmarks, and objects!</p>
<p>Supported formats: JPG, PNG, BMP, TIFF</p>
</div>
""", unsafe_allow_html=True)
# Footer
st.markdown("---")
st.markdown("""
<div style="text-align: center; color: #666;">
<p>Built with ❤️ using <strong>Streamlit</strong> and <strong>Hugging Face Transformers</strong></p>
<p>Powered by <strong>BLIP-2</strong> and <strong>BLIP-VQA</strong> for comprehensive image understanding</p>
</div>
""", unsafe_allow_html=True)
if __name__ == "__main__":
main()
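
# Run locally with:
#   streamlit run app.py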