# Streamlit app: upload an image and caption it with the BLIP-2 model.

import html
import io
import time

import streamlit as st
import torch
from PIL import Image
from transformers import Blip2Processor, Blip2ForConditionalGeneration

# Set page config
st.set_page_config(
    page_title="🚀 BLIP-2 Caption Generator",
    page_icon="🚀",
    layout="wide",
    initial_sidebar_state="expanded",
)

# Custom CSS for better styling
st.markdown(""" """, unsafe_allow_html=True)


@st.cache_resource
def _load_model_cached():
    """Load the BLIP-2 processor and model once per server process.

    Errors are deliberately NOT caught here: ``st.cache_resource`` stores
    whatever this function returns, so returning a failure marker would pin
    the failure in the cache. Raising leaves the cache empty and lets the
    next call retry.

    Returns:
        A ``(processor, model, device)`` tuple, where ``device`` is the
        string ``"cuda"`` or ``"cpu"``.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Use the smaller BLIP-2 model for better performance on Hugging Face Spaces
    model_name = "Salesforce/blip2-opt-2.7b"

    processor = Blip2Processor.from_pretrained(model_name)
    model = Blip2ForConditionalGeneration.from_pretrained(
        model_name,
        # Half precision on GPU, full precision on CPU.
        torch_dtype=torch.float16 if device == "cuda" else torch.float32,
        device_map="auto" if device == "cuda" else None,
    )

    if device == "cpu":
        model = model.to(device)

    return processor, model, device


def load_model():
    """Return the cached BLIP-2 processor, model and device.

    Returns:
        ``(processor, model, device)`` on success, or ``(None, None, None)``
        after reporting the error with ``st.error`` when loading fails.
        Failures are not cached, so a later call tries to load again.
    """
    try:
        return _load_model_cached()
    except Exception as e:
        st.error(f"Error loading model: {str(e)}")
        return None, None, None


def generate_caption(image, processor, model, device, prompt=""):
    """Generate a caption for the uploaded image.

    Args:
        image: PIL image to describe; converted to RGB when it is in another
            mode.
        processor: BLIP-2 processor used to build model inputs and decode ids.
        model: BLIP-2 conditional-generation model.
        device: Device string the inputs are moved to.
        prompt: Optional text prompt; empty or whitespace-only means plain
            captioning without a prompt.

    Returns:
        The decoded caption with surrounding whitespace removed, or ``None``
        after reporting the error with ``st.error`` when generation fails.
    """
    try:
        # Hand the processor a three-channel image regardless of the upload's
        # original mode (RGBA, palette, grayscale, CMYK, ...).
        if image.mode != "RGB":
            image = image.convert("RGB")

        # Prepare inputs; a whitespace-only prompt is treated as "no prompt".
        if prompt and prompt.strip():
            inputs = processor(image, text=prompt, return_tensors="pt")
        else:
            inputs = processor(image, return_tensors="pt")

        # Cast floating-point tensors (pixel values) to the model's dtype so a
        # float16 model on GPU is not fed float32 inputs; integer token ids
        # are only moved to the device.
        inputs = inputs.to(device, model.dtype)

        # Generate caption
        with torch.no_grad():
            generated_ids = model.generate(
                **inputs,
                max_length=50,
                num_beams=5,
                temperature=0.7,
                do_sample=True,
                early_stopping=True,
            )

        # Decode the generated caption; strip stray leading/trailing
        # whitespace so a blank result is falsy for the caller.
        caption = processor.decode(generated_ids[0], skip_special_tokens=True)
        return caption.strip()

    except Exception as e:
        st.error(f"Error generating caption: {str(e)}")
        return None


def main():
    """Render the page: header, settings sidebar, upload and caption columns, footer."""
    # Header
    st.markdown("""

🚀 BLIP-2 Caption Generator

Upload an image and get AI-generated captions instantly!

""", unsafe_allow_html=True)

    # Sidebar
    with st.sidebar:
        st.header("🔧 Settings")

        st.markdown("### Model Information")
        st.info("Using **BLIP-2** (Salesforce/blip2-opt-2.7b)")

        # Custom prompt option
        custom_prompt = st.text_input(
            "Custom Prompt (Optional):",
            placeholder="e.g., 'Question: What is in this image? Answer:'",
        )

        st.markdown("### About")
        st.markdown("""
        This app uses the **BLIP-2** model to generate natural language descriptions of images.

        **Features:**
        - 🖼️ Upload any image format
        - 🤖 AI-powered captioning
        - ⚡ Fast inference
        - 🎯 Optional custom prompts
        """)

    # Main content
    col1, col2 = st.columns([1, 1])

    # Stays None until an upload has been decoded successfully.
    image = None

    with col1:
        st.markdown("### 📤 Upload Image")

        # File uploader
        uploaded_file = st.file_uploader(
            "Choose an image file",
            type=["jpg", "jpeg", "png", "bmp", "tiff"],
            help="Upload an image to generate a caption",
        )

        if uploaded_file is not None:
            try:
                image = Image.open(uploaded_file)
                # Force a full decode now so a corrupt or truncated file
                # fails here instead of later during display or inference.
                image.load()
            except (OSError, ValueError) as e:
                image = None
                st.error(f"Could not read the uploaded image: {e}")
            else:
                # Display uploaded image
                st.image(image, caption="Uploaded Image", use_container_width=True)

                # Image info
                st.markdown(f"""
                **Image Info:**
                - Size: {image.size[0]} x {image.size[1]} pixels
                - Format: {image.format}
                - Mode: {image.mode}
                """)

    with col2:
        st.markdown("### 🔮 Generated Caption")

        if uploaded_file is None:
            st.markdown("""

👆 Upload an image to get started

Supported formats: JPG, PNG, BMP, TIFF

""", unsafe_allow_html=True)
        elif image is not None:
            # Load model
            with st.spinner("Loading BLIP-2 model..."):
                processor, model, device = load_model()

            if processor is not None and model is not None:
                # Generate caption button
                if st.button("🎯 Generate Caption", type="primary"):
                    with st.spinner("Generating caption..."):
                        # perf_counter is monotonic, so the elapsed time
                        # cannot be skewed by wall-clock adjustments.
                        start_time = time.perf_counter()

                        # Generate caption
                        caption = generate_caption(
                            image, processor, model, device, custom_prompt
                        )

                        end_time = time.perf_counter()

                    if caption:
                        # The block below is rendered with raw HTML enabled,
                        # so escape the model output before interpolating it.
                        safe_caption = html.escape(caption)

                        # Display caption
                        st.markdown(f"""

📝 Caption:

{safe_caption}

""", unsafe_allow_html=True)

                        # Performance info
                        st.success(f"Caption generated in {end_time - start_time:.2f} seconds")

                        # Copy to clipboard button
                        st.code(caption, language=None)
            else:
                st.error("Failed to load the model. Please try refreshing the page.")

    # Footer
    st.markdown("---")
    st.markdown("""

Built with Streamlit and Hugging Face Transformers

Powered by BLIP-2 - Bootstrapping Language-Image Pre-training

""", unsafe_allow_html=True)


if __name__ == "__main__":
    main()