"""MobileCLIP2 Zero-Shot Classification Demo""" import torch import open_clip import gradio as gr from mobileclip.modules.common.mobileone import reparameterize_model import spaces ################################################################################ # Model Configuration ################################################################################ AVAILABLE_MODELS = { "MobileCLIP2-B": ("MobileCLIP2-B", "dfndr2b"), "MobileCLIP2-S0": ("MobileCLIP2-S0", "dfndr2b"), "MobileCLIP2-S2": ("MobileCLIP2-S2", "dfndr2b"), "MobileCLIP2-S3": ("MobileCLIP2-S3", "dfndr2b"), "MobileCLIP2-S4": ("MobileCLIP2-S4", "dfndr2b"), "MobileCLIP2-L-14": ("MobileCLIP2-L-14", "dfndr2b"), } # Cache for loaded models model_cache = {} ################################################################################ # Model Loading ################################################################################ def load_model(model_name): """Load and cache MobileCLIP2 model""" if model_name in model_cache: return model_cache[model_name] model_id, pretrained = AVAILABLE_MODELS[model_name] # Create model and preprocessing transforms model, _, preprocess = open_clip.create_model_and_transforms( model_id, pretrained=pretrained ) tokenizer = open_clip.get_tokenizer(model_id) # Reparameterize model for inference model = reparameterize_model(model.eval()) # Cache the model components model_cache[model_name] = { "model": model, "preprocess": preprocess, "tokenizer": tokenizer } return model_cache[model_name] ################################################################################ # Inference ################################################################################ @spaces.GPU(duration=120) def classify_image(image, candidate_labels, model_name): """ Classify image using selected MobileCLIP2 model Args: image: PIL Image candidate_labels: comma-separated string of labels model_name: selected model from dropdown Returns: Dictionary of label probabilities """ if image is None: return {} # Parse labels labels = [label.strip() for label in candidate_labels.split(",") if label.strip()] if not labels: return {} # Load model components model_components = load_model(model_name) model = model_components["model"] preprocess = model_components["preprocess"] tokenizer = model_components["tokenizer"] # Preprocess image image_tensor = preprocess(image.convert('RGB')).unsqueeze(0) # Tokenize text text_tokens = tokenizer(labels) # Run inference with torch.no_grad(), torch.cuda.amp.autocast(): image_features = model.encode_image(image_tensor) text_features = model.encode_text(text_tokens) # Normalize features image_features /= image_features.norm(dim=-1, keepdim=True) text_features /= text_features.norm(dim=-1, keepdim=True) # Compute similarity and probabilities text_probs = (100.0 * image_features @ text_features.T).softmax(dim=-1) # Format output as dictionary output = {labels[i]: float(text_probs[0][i]) for i in range(len(labels))} return output ################################################################################ # Gradio Interface ################################################################################ with gr.Blocks() as demo: gr.Markdown("# MobileCLIP2 Zero-Shot Image Classification") gr.Markdown( "Classify images using MobileCLIP2 models. Select a model, upload an image, " "and provide comma-separated class labels to get predictions." 
################################################################################
# Gradio Interface
################################################################################

with gr.Blocks() as demo:
    gr.Markdown("# MobileCLIP2 Zero-Shot Image Classification")
    gr.Markdown(
        "Classify images using MobileCLIP2 models. Select a model, upload an image, "
        "and provide comma-separated class labels to get predictions."
    )
    gr.Markdown(
        "See the [MobileCLIP2 model collection]"
        "(https://huggingface.co/collections/apple/mobileclip2-68ac947dcb035c54bcd20c47)."
    )

    with gr.Row():
        with gr.Column():
            model_dropdown = gr.Dropdown(
                choices=list(AVAILABLE_MODELS.keys()),
                value="MobileCLIP2-S2",
                label="Select MobileCLIP2 Model",
                info="Choose which model to use for classification",
            )
            image_input = gr.Image(type="pil", label="Upload Image")
            text_input = gr.Textbox(
                label="Class Labels (comma separated)",
                placeholder="e.g., a cat, a dog, a bird",
            )
            run_button = gr.Button("Classify", variant="primary")
        with gr.Column():
            output_label = gr.Label(
                label="Classification Results",
                num_top_classes=5,
            )

    # Examples, with columns ordered to match classify_image's
    # (image, candidate_labels, model_name) signature
    examples = [
        ["./baklava.jpg", "dessert on a plate, a serving of baklava, a plate and spoon", "MobileCLIP2-S2"],
        ["./cat.jpg", "a cat, two cats, three cats", "MobileCLIP2-S2"],
        ["./cat.jpg", "two sleeping cats, two cats playing, three cats laying down", "MobileCLIP2-S2"],
    ]
    gr.Examples(
        examples=examples,
        inputs=[image_input, text_input, model_dropdown],
        outputs=[output_label],
        fn=classify_image,
        cache_examples=False,
    )

    # Connect button
    run_button.click(
        fn=classify_image,
        inputs=[image_input, text_input, model_dropdown],
        outputs=[output_label],
    )

if __name__ == "__main__":
    demo.launch()
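# A minimal sketch of exercising the pipeline without the UI, kept as a
# comment so it never runs alongside the server. It assumes a local "cat.jpg"
# exists; on a ZeroGPU Space the call must go through the @spaces.GPU-decorated
# function, which classify_image already is:
#
#   from PIL import Image
#   predictions = classify_image(
#       Image.open("cat.jpg"), "a cat, a dog, a bird", "MobileCLIP2-S2"
#   )
#   print(predictions)  # {label: probability} for each candidate label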