import streamlit as st
import torch
from PIL import Image
from transformers import Blip2Processor, Blip2ForConditionalGeneration, BlipProcessor, BlipForQuestionAnswering
import time
import json
# Set page config
st.set_page_config(
    page_title="🚀 Advanced BLIP-2 Caption Generator",
    page_icon="🚀",
    layout="wide",
    initial_sidebar_state="expanded"
)

# Custom CSS for better styling
st.markdown("""
<style>
    .main-header {
        text-align: center;
        padding: 2rem 0;
        background: linear-gradient(90deg, #667eea 0%, #764ba2 100%);
        color: white;
        border-radius: 10px;
        margin-bottom: 2rem;
    }
    .upload-section {
        border: 2px dashed #ccc;
        border-radius: 10px;
        padding: 2rem;
        text-align: center;
        margin: 1rem 0;
    }
    .caption-box {
        background-color: #f0f2f6;
        border-left: 4px solid #667eea;
        padding: 1rem;
        border-radius: 5px;
        margin: 1rem 0;
    }
    .analysis-box {
        background-color: #f8f9fa;
        border: 1px solid #dee2e6;
        border-radius: 8px;
        padding: 1rem;
        margin: 0.5rem 0;
    }
    .location-box {
        background-color: #e8f5e8;
        border-left: 4px solid #28a745;
        padding: 1rem;
        border-radius: 5px;
        margin: 1rem 0;
    }
    .objects-box {
        background-color: #fff3cd;
        border-left: 4px solid #ffc107;
        padding: 1rem;
        border-radius: 5px;
        margin: 1rem 0;
    }
</style>
""", unsafe_allow_html=True)
# Cache the models across Streamlit reruns so they are loaded only once
@st.cache_resource
def load_models():
    """Load and cache the BLIP-2 captioning model and the BLIP VQA model."""
    try:
        device = "cuda" if torch.cuda.is_available() else "cpu"

        # Load BLIP-2 for general captioning
        blip2_model_name = "Salesforce/blip2-opt-2.7b"
        blip2_processor = Blip2Processor.from_pretrained(blip2_model_name)
        blip2_model = Blip2ForConditionalGeneration.from_pretrained(
            blip2_model_name,
            torch_dtype=torch.float16 if device == "cuda" else torch.float32,
            device_map="auto" if device == "cuda" else None
        )

        # Load BLIP for Visual Question Answering
        blip_model_name = "Salesforce/blip-vqa-base"
        blip_processor = BlipProcessor.from_pretrained(blip_model_name)
        blip_model = BlipForQuestionAnswering.from_pretrained(
            blip_model_name,
            torch_dtype=torch.float16 if device == "cuda" else torch.float32
        )

        # BLIP-2 is already placed on the GPU via device_map="auto"; move it manually only on CPU.
        if device == "cpu":
            blip2_model = blip2_model.to(device)
        # The VQA model is loaded without device_map, so it must always be moved to the target device.
        blip_model = blip_model.to(device)

        return blip2_processor, blip2_model, blip_processor, blip_model, device
    except Exception as e:
        st.error(f"Error loading models: {str(e)}")
        return None, None, None, None, None
def generate_basic_caption(image, processor, model, device, prompt=""):
    """Generate a basic caption for the uploaded image."""
    try:
        # Cast floating-point inputs (pixel values) to the model's dtype so fp16 GPU runs
        # don't hit an input/weight dtype mismatch.
        if prompt:
            inputs = processor(image, text=prompt, return_tensors="pt").to(device, model.dtype)
        else:
            inputs = processor(image, return_tensors="pt").to(device, model.dtype)

        with torch.no_grad():
            generated_ids = model.generate(
                **inputs,
                max_length=100,
                num_beams=5,
                temperature=0.7,
                do_sample=True,
                early_stopping=True
            )

        caption = processor.decode(generated_ids[0], skip_special_tokens=True)
        return caption
    except Exception as e:
        st.error(f"Error generating caption: {str(e)}")
        return None
def ask_visual_question(image, question, processor, model, device):
    """Ask a specific question about the image using BLIP VQA."""
    try:
        # Match the model's dtype (fp16 on GPU) so pixel values and weights agree
        inputs = processor(image, question, return_tensors="pt").to(device, model.dtype)
        with torch.no_grad():
            out = model.generate(**inputs, max_length=50, num_beams=3)
        answer = processor.decode(out[0], skip_special_tokens=True)
        return answer
    except Exception:
        return "Unable to determine"
def analyze_location_and_objects(image, blip_processor, blip_model, device):
    """Analyze the image for locations, landmarks, and objects."""
    location_questions = [
        "What country is this?",
        "What city is this?",
        "What landmark is this?",
        "Where is this place?",
        "What famous building is this?",
        "What monument is this?",
        "What geographical location is shown?",
        "What tourist attraction is this?",
        "What state or province is this?",
        "What region is this?",
        "What continent is this in?",
        "What neighborhood is this?",
        "What district is this?",
        "What area is this?"
    ]
    object_questions = [
        "What objects can you see in this image?",
        "What are the main things in this picture?",
        "What vehicles are in this image?",
        "What buildings are visible?",
        "What natural features are shown?",
        "What people are doing in this image?",
        "What animals are in this picture?",
        "What food items can you see?",
        "What clothing can you see?",
        "What activities are happening?",
        "What weather is shown?",
        "What time of day is it?",
        "What season does this appear to be?",
        "What colors dominate this image?"
    ]
    architectural_questions = [
        "What type of architecture is this?",
        "What style of building is this?",
        "What historical period does this represent?",
        "What cultural elements are visible?",
        "What materials is this building made of?",
        "What architectural features are prominent?",
        "What type of structure is this?",
        "What design style is shown?"
    ]

    location_info = {}
    object_info = {}
    architectural_info = {}

    # Analyze locations
    for question in location_questions:
        answer = ask_visual_question(image, question, blip_processor, blip_model, device)
        if answer and answer.lower() not in ["no", "none", "unable to determine", "unknown", "unanswerable"]:
            location_info[question] = answer

    # Analyze objects
    for question in object_questions:
        answer = ask_visual_question(image, question, blip_processor, blip_model, device)
        if answer and answer.lower() not in ["no", "none", "unable to determine", "unknown", "unanswerable"]:
            object_info[question] = answer

    # Analyze architecture
    for question in architectural_questions:
        answer = ask_visual_question(image, question, blip_processor, blip_model, device)
        if answer and answer.lower() not in ["no", "none", "unable to determine", "unknown", "unanswerable"]:
            architectural_info[question] = answer

    return location_info, object_info, architectural_info
def generate_enhanced_caption(basic_caption, location_info, object_info, architectural_info):
    """Generate an enhanced caption combining all analysis results."""
    enhanced_parts = [basic_caption]

    if location_info:
        location_details = []
        for question, answer in location_info.items():
            if "country" in question.lower():
                location_details.append(f"Located in {answer}")
            elif "city" in question.lower():
                location_details.append(f"in {answer}")
            elif "landmark" in question.lower() or "monument" in question.lower():
                location_details.append(f"showing {answer}")
            elif "building" in question.lower():
                location_details.append(f"featuring {answer}")
            elif "state" in question.lower() or "province" in question.lower():
                location_details.append(f"in {answer}")
            elif "region" in question.lower():
                location_details.append(f"in the {answer} region")
        if location_details:
            enhanced_parts.append(" ".join(location_details[:3]))  # Limit to avoid overly long captions

    if architectural_info:
        arch_details = []
        for question, answer in architectural_info.items():
            if "architecture" in question.lower() or "style" in question.lower():
                arch_details.append(f"The architecture appears to be {answer}")
            elif "period" in question.lower():
                arch_details.append(f"from the {answer} period")
        if arch_details:
            enhanced_parts.append(" ".join(arch_details[:2]))

    if object_info:
        obj_details = []
        for question, answer in object_info.items():
            if "time of day" in question.lower():
                obj_details.append(f"taken during {answer}")
            elif "weather" in question.lower():
                obj_details.append(f"in {answer} weather")
            elif "season" in question.lower():
                obj_details.append(f"during {answer}")
        if obj_details:
            enhanced_parts.append(" ".join(obj_details[:2]))

    return ". ".join(enhanced_parts) + "."
def main():
    # Header
    st.markdown("""
    <div class="main-header">
        <h1>🚀 Advanced BLIP-2 Caption Generator</h1>
        <p>Upload an image and get comprehensive AI analysis including locations, landmarks, and objects!</p>
    </div>
    """, unsafe_allow_html=True)

    # Sidebar
    with st.sidebar:
        st.header("🔧 Settings")

        st.markdown("### Model Information")
        st.info("Using **BLIP-2** + **BLIP-VQA** for comprehensive analysis")

        # Analysis options
        st.markdown("### Analysis Options")
        include_location = st.checkbox("🌍 Location Analysis", value=True)
        include_objects = st.checkbox("🎯 Object Detection", value=True)
        include_architecture = st.checkbox("🏛️ Architecture Analysis", value=True)

        # Custom questions
        st.markdown("### Custom Questions")
        custom_question = st.text_input(
            "Ask about the image:",
            placeholder="e.g., What time of day is this?"
        )
| st.markdown("### About") | |
| st.markdown(""" | |
| This enhanced app uses multiple AI models: | |
| **Features:** | |
| - 🖼️ Basic image captioning | |
| - 🌍 Country & city recognition | |
| - 🏛️ Landmark identification | |
| - 🎯 Object detection | |
| - 🏗️ Architecture analysis | |
| - ❓ Custom Q&A | |
| - 📍 State/Province detection | |
| - 🌆 Neighborhood analysis | |
| """) | |
    # Main content
    col1, col2 = st.columns([1, 1])

    with col1:
        st.markdown("### 📤 Upload Image")

        # File uploader
        uploaded_file = st.file_uploader(
            "Choose an image file",
            type=["jpg", "jpeg", "png", "bmp", "tiff"],
            help="Upload an image for comprehensive analysis"
        )

        if uploaded_file is not None:
            # Display uploaded image
            image = Image.open(uploaded_file)
            st.image(image, caption="Uploaded Image", use_container_width=True)

            # Image info
            st.markdown(f"""
            **Image Info:**
            - Size: {image.size[0]} x {image.size[1]} pixels
            - Format: {image.format}
            - Mode: {image.mode}
            """)
    with col2:
        st.markdown("### 🔮 AI Analysis Results")

        if uploaded_file is not None:
            # Load models
            with st.spinner("Loading AI models..."):
                blip2_processor, blip2_model, blip_processor, blip_model, device = load_models()

            if all([blip2_processor, blip2_model, blip_processor, blip_model]):
                # Analyze button
                if st.button("🚀 Analyze Image", type="primary"):
                    with st.spinner("Performing comprehensive analysis..."):
                        start_time = time.time()

                        # Generate basic caption
                        basic_caption = generate_basic_caption(
                            image, blip2_processor, blip2_model, device
                        )

                        # Analyze for locations and objects
                        location_info, object_info, architectural_info = analyze_location_and_objects(
                            image, blip_processor, blip_model, device
                        )

                        # Custom question
                        custom_answer = None
                        if custom_question:
                            custom_answer = ask_visual_question(
                                image, custom_question, blip_processor, blip_model, device
                            )

                        end_time = time.time()

                    if basic_caption:
                        # Basic Caption
                        st.markdown(f"""
                        <div class="caption-box">
                            <h4>📝 Basic Caption:</h4>
                            <p style="font-size: 16px; font-weight: 500;">{basic_caption}</p>
                        </div>
                        """, unsafe_allow_html=True)

                        # Location Analysis
                        if include_location and location_info:
                            st.markdown("""
                            <div class="location-box">
                                <h4>🌍 Location Analysis:</h4>
                            </div>
                            """, unsafe_allow_html=True)
                            for question, answer in location_info.items():
                                st.write(f"**{question}** {answer}")

                        # Object Analysis
                        if include_objects and object_info:
                            st.markdown("""
                            <div class="objects-box">
                                <h4>🎯 Object Analysis:</h4>
                            </div>
                            """, unsafe_allow_html=True)
                            for question, answer in object_info.items():
                                st.write(f"**{question}** {answer}")

                        # Architecture Analysis
                        if include_architecture and architectural_info:
                            st.markdown("""
                            <div class="analysis-box">
                                <h4>🏛️ Architecture Analysis:</h4>
                            </div>
                            """, unsafe_allow_html=True)
                            for question, answer in architectural_info.items():
                                st.write(f"**{question}** {answer}")

                        # Custom Question Answer
                        if custom_answer:
                            st.markdown(f"""
                            <div class="analysis-box">
                                <h4>❓ Custom Question:</h4>
                                <p><strong>Q:</strong> {custom_question}</p>
                                <p><strong>A:</strong> {custom_answer}</p>
                            </div>
                            """, unsafe_allow_html=True)

                        # Enhanced Caption
                        enhanced_caption = generate_enhanced_caption(
                            basic_caption, location_info, object_info, architectural_info
                        )
                        st.markdown(f"""
                        <div class="caption-box" style="border-left-color: #28a745;">
                            <h4>✨ Enhanced Caption:</h4>
                            <p style="font-size: 16px; font-weight: 500;">{enhanced_caption}</p>
                        </div>
                        """, unsafe_allow_html=True)

                        # Performance info
                        st.success(f"Analysis completed in {end_time - start_time:.2f} seconds")

                        # Copy caption to clipboard
                        st.code(enhanced_caption, language=None)

                        # Export options
                        analysis_data = {
                            "basic_caption": basic_caption,
                            "enhanced_caption": enhanced_caption,
                            "location_info": location_info if include_location else {},
                            "object_info": object_info if include_objects else {},
                            "architectural_info": architectural_info if include_architecture else {},
                            "custom_qa": {"question": custom_question, "answer": custom_answer} if custom_answer else None
                        }
                        st.download_button(
                            label="📄 Download Analysis (JSON)",
                            data=json.dumps(analysis_data, indent=2),
                            file_name=f"image_analysis_{int(time.time())}.json",
                            mime="application/json"
                        )
            else:
                st.error("Failed to load the models. Please try refreshing the page.")
        else:
            st.markdown("""
            <div class="upload-section">
                <h3>👆 Upload an image to get started</h3>
                <p>Get comprehensive AI analysis including locations, landmarks, and objects!</p>
                <p>Supported formats: JPG, PNG, BMP, TIFF</p>
            </div>
            """, unsafe_allow_html=True)

    # Footer
    st.markdown("---")
    st.markdown("""
    <div style="text-align: center; color: #666;">
        <p>Built with ❤️ using <strong>Streamlit</strong> and <strong>Hugging Face Transformers</strong></p>
        <p>Powered by <strong>BLIP-2</strong> and <strong>BLIP-VQA</strong> for comprehensive image understanding</p>
    </div>
    """, unsafe_allow_html=True)


if __name__ == "__main__":
    main()