# maria355's picture
# Update app.py
# dada616 verified
import html
import io
import time

import streamlit as st
import torch
from PIL import Image
from transformers import Blip2Processor, Blip2ForConditionalGeneration
# Configure the browser tab (title and icon) and the page layout. Kept ahead
# of every other Streamlit call in this script.
st.set_page_config(
    # The title and icon were mojibake (UTF-8 bytes decoded with the wrong
    # code page); restored to the intended rocket emoji.
    page_title="🚀 BLIP-2 Caption Generator",
    page_icon="🚀",
    layout="wide",
    initial_sidebar_state="expanded",
)
# Page-wide stylesheet for the header banner (.main-header), the upload
# placeholder (.upload-section) and the caption card (.caption-box).
_PAGE_STYLES = """
<style>
.main-header {
text-align: center;
padding: 2rem 0;
background: linear-gradient(90deg, #667eea 0%, #764ba2 100%);
color: white;
border-radius: 10px;
margin-bottom: 2rem;
}
.upload-section {
border: 2px dashed #ccc;
border-radius: 10px;
padding: 2rem;
text-align: center;
margin: 1rem 0;
}
.caption-box {
background-color: #f0f2f6;
border-left: 4px solid #667eea;
padding: 1rem;
border-radius: 5px;
margin: 1rem 0;
}
</style>
"""

# Rendered as raw HTML so the <style> element is emitted rather than shown as text.
st.markdown(_PAGE_STYLES, unsafe_allow_html=True)
@st.cache_resource
def _load_model_cached():
    """Load the BLIP-2 processor and model; the result is cached by Streamlit.

    Errors are allowed to propagate instead of being turned into a return
    value: only return values are stored by ``st.cache_resource``, so a failed
    attempt is not cached and the next call tries again.

    Returns:
        A ``(processor, model, device)`` tuple, where ``device`` is ``"cuda"``
        when CUDA is available and ``"cpu"`` otherwise.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    on_gpu = device == "cuda"
    # Use the smaller BLIP-2 model for better performance on Hugging Face Spaces
    model_name = "Salesforce/blip2-opt-2.7b"
    processor = Blip2Processor.from_pretrained(model_name)
    model = Blip2ForConditionalGeneration.from_pretrained(
        model_name,
        # Half precision on GPU, full precision on CPU.
        torch_dtype=torch.float16 if on_gpu else torch.float32,
        device_map="auto" if on_gpu else None,
    )
    if not on_gpu:
        # Without a device_map the model has to be placed explicitly.
        model = model.to(device)
    return processor, model, device


def load_model():
    """Return the cached BLIP-2 ``(processor, model, device)`` tuple.

    The loading itself lives in ``_load_model_cached``. Catching the exception
    here, outside the cached function, keeps a failure from being cached as
    ``(None, None, None)`` and returned on every later call.

    Returns:
        ``(processor, model, device)`` on success. On any loading error the
        message is shown with ``st.error`` and ``(None, None, None)`` is
        returned.
    """
    try:
        return _load_model_cached()
    except Exception as e:
        st.error(f"Error loading model: {str(e)}")
        return None, None, None


# Keep the ``load_model.clear()`` handle that the cache decorator used to
# provide directly on this name.
load_model.clear = _load_model_cached.clear
def generate_caption(image, processor, model, device, prompt=""):
    """Generate a caption for ``image``.

    Args:
        image: The image handed to ``processor`` (a PIL image in this app).
        processor: Callable that builds the model inputs and whose ``decode``
            turns generated token ids back into text.
        model: Object whose ``generate`` method produces the token ids and
            whose ``dtype`` tells which floating-point type the inputs need.
        device: Device the inputs are moved to.
        prompt: Optional text prompt. When empty, only the image is encoded.

    Returns:
        The decoded caption with surrounding whitespace removed, or ``None``
        when any step fails (the error is reported through ``st.error``).
    """
    try:
        # Prepare inputs; the text prompt is only passed when one was given.
        processor_kwargs = {"return_tensors": "pt"}
        if prompt:
            processor_kwargs["text"] = prompt
        # Move the inputs AND cast them to the model's dtype. Moving alone
        # leaves the pixel values in float32, which fails with a dtype
        # mismatch when the model runs in float16.
        inputs = processor(image, **processor_kwargs).to(device, model.dtype)

        # Generate caption
        with torch.no_grad():
            generated_ids = model.generate(
                **inputs,
                max_length=50,
                num_beams=5,
                temperature=0.7,
                do_sample=True,
                early_stopping=True,
            )

        # Decode the generated caption
        caption = processor.decode(generated_ids[0], skip_special_tokens=True)
        return caption.strip()
    except Exception as e:
        st.error(f"Error generating caption: {str(e)}")
        return None
def _open_uploaded_image(uploaded_file):
    """Open an uploaded file as a PIL image.

    Returns the image, or ``None`` after showing an error when the file cannot
    be opened as an image.
    """
    try:
        return Image.open(uploaded_file)
    except (OSError, ValueError) as e:
        st.error(f"Could not read the uploaded image: {e}")
        return None


def main():
    """Render the app: header, settings sidebar, upload and caption columns, footer.

    The left column takes the upload and previews it; the right column loads
    the model and, when the button is pressed, shows the generated caption.
    """
    # Header
    st.markdown("""
<div class="main-header">
<h1>🚀 BLIP-2 Caption Generator</h1>
<p>Upload an image and get AI-generated captions instantly!</p>
</div>
""", unsafe_allow_html=True)

    # Sidebar
    with st.sidebar:
        st.header("🔧 Settings")
        st.markdown("### Model Information")
        st.info("Using **BLIP-2** (Salesforce/blip2-opt-2.7b)")
        # Custom prompt option
        custom_prompt = st.text_input(
            "Custom Prompt (Optional):",
            placeholder="e.g., 'Question: What is in this image? Answer:'",
        )
        st.markdown("### About")
        st.markdown("""
This app uses the **BLIP-2** model to generate natural language descriptions of images.
**Features:**
- 🖼️ Upload any image format
- 🤖 AI-powered captioning
- ⚡ Fast inference
- 🎯 Optional custom prompts
""")

    # Main content
    col1, col2 = st.columns([1, 1])
    # Stays None until an upload has been opened successfully.
    image = None

    with col1:
        st.markdown("### 📤 Upload Image")
        # File uploader
        uploaded_file = st.file_uploader(
            "Choose an image file",
            type=["jpg", "jpeg", "png", "bmp", "tiff"],
            help="Upload an image to generate a caption",
        )
        if uploaded_file is not None:
            image = _open_uploaded_image(uploaded_file)
        if image is not None:
            # Display uploaded image
            st.image(image, caption="Uploaded Image", use_container_width=True)
            # Image info
            st.markdown(f"""
**Image Info:**
- Size: {image.size[0]} x {image.size[1]} pixels
- Format: {image.format}
- Mode: {image.mode}
""")

    with col2:
        st.markdown("### 🔮 Generated Caption")
        if image is not None:
            # Load model
            with st.spinner("Loading BLIP-2 model..."):
                processor, model, device = load_model()
            if processor is not None and model is not None:
                # Generate caption button
                if st.button("🎯 Generate Caption", type="primary"):
                    with st.spinner("Generating caption..."):
                        # Monotonic clock: elapsed time cannot be skewed by
                        # system clock adjustments.
                        start_time = time.perf_counter()
                        caption = generate_caption(
                            image, processor, model, device, custom_prompt
                        )
                        end_time = time.perf_counter()
                    if caption:
                        # The caption is model output (and may echo the user's
                        # prompt), so escape it before embedding it in raw HTML.
                        safe_caption = html.escape(caption)
                        st.markdown(f"""
<div class="caption-box">
<h4>📝 Caption:</h4>
<p style="font-size: 16px; font-weight: 500;">{safe_caption}</p>
</div>
""", unsafe_allow_html=True)
                        # Performance info
                        st.success(f"Caption generated in {end_time - start_time:.2f} seconds")
                        # Plain-text copy of the caption in a code block
                        st.code(caption, language=None)
            else:
                st.error("Failed to load the model. Please try refreshing the page.")
        else:
            st.markdown("""
<div class="upload-section">
<h3>👆 Upload an image to get started</h3>
<p>Supported formats: JPG, PNG, BMP, TIFF</p>
</div>
""", unsafe_allow_html=True)

    # Footer
    st.markdown("---")
    st.markdown("""
<div style="text-align: center; color: #666;">
<p>Built with <strong>Streamlit</strong> and <strong>Hugging Face Transformers</strong></p>
<p>Powered by <strong>BLIP-2</strong> - Bootstrapping Language-Image Pre-training</p>
</div>
""", unsafe_allow_html=True)


# Run the app only when this file is executed as a script.
if __name__ == "__main__":
    main()