#!/usr/bin/env python3
"""MAE ViT-Base waste classifier for inference."""

import torch
import torch.nn.functional as F
from torchvision import transforms
from PIL import Image
import timm
import os
import json
from huggingface_hub import hf_hub_download

class MAEWasteClassifier:
    """Waste classifier using a finetuned MAE ViT-Base model."""

    def __init__(self, model_path=None, hf_model_id="ysfad/mae-waste-classifier", device=None):
        self.device = device or ('cuda' if torch.cuda.is_available() else 'cpu')
        self.hf_model_id = hf_model_id

        # Resolve the checkpoint: prefer an explicit local path, then the HF Hub,
        # then the default local training output directory.
        if model_path and os.path.exists(model_path):
            self.model_path = model_path
            print(f"Using local model: {model_path}")
        else:
            # Try to download from HF Hub
            try:
                print(f"Downloading model from HF Hub: {hf_model_id}")
                self.model_path = hf_hub_download(
                    repo_id=hf_model_id,
                    filename="best_model.pth",
                    cache_dir="./hf_cache"
                )
                print(f"Downloaded model to: {self.model_path}")
            except Exception as e:
                print(f"Warning: could not download from HF Hub: {e}")
                # Fall back to the local training output path
                self.model_path = "output_simple_mae/best_model.pth"
                if not os.path.exists(self.model_path):
                    raise FileNotFoundError(
                        f"Model not found locally at {self.model_path} and could not download from HF Hub"
                    )

        # Class names from training
        self.class_names = [
            'Cardboard', 'Food Organics', 'Glass', 'Metal',
            'Miscellaneous Trash', 'Paper', 'Plastic',
            'Textile Trash', 'Vegetation'
        ]

        # Disposal instructions per class
        self.disposal_instructions = {
            "Cardboard": "Flatten and place in recycling bin. Remove any tape or staples.",
            "Food Organics": "Compost in organic waste bin or home composter.",
            "Glass": "Rinse and place in glass recycling. Remove lids and caps.",
            "Metal": "Rinse aluminum/steel cans and place in recycling bin.",
            "Miscellaneous Trash": "Dispose in general waste bin. Cannot be recycled.",
            "Paper": "Place clean paper in recycling. Remove plastic windows from envelopes.",
            "Plastic": "Check recycling number. Rinse containers before recycling.",
            "Textile Trash": "Donate if reusable, otherwise dispose in textile recycling.",
            "Vegetation": "Compost in organic waste or use for mulch in garden."
        }

        # Load model
        self.model = self._load_model()

        # Image preprocessing: 224x224 input with ImageNet normalization to match ViT-Base
        self.transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])

        print(f"MAE Waste Classifier loaded on {self.device}")
        print(f"Model: ViT-Base MAE, Classes: {len(self.class_names)}")

    def _load_model(self):
        """Load the finetuned MAE model."""
        try:
            # Create a ViT-Base/16 backbone with a fresh classification head
            model = timm.create_model(
                'vit_base_patch16_224',
                pretrained=False,
                num_classes=len(self.class_names)
            )

            # Load checkpoint
            checkpoint = torch.load(self.model_path, map_location=self.device)

            # The checkpoint may hold the weights directly or under 'model_state_dict'
            if 'model_state_dict' in checkpoint:
                model.load_state_dict(checkpoint['model_state_dict'])
            else:
                model.load_state_dict(checkpoint)

            model.to(self.device)
            model.eval()

            print(f"Loaded finetuned MAE model from {self.model_path}")
            return model
        except Exception as e:
            print(f"Error loading model: {e}")
            raise

    def classify_image(self, image, top_k=5):
        """
        Classify a waste image.

        Args:
            image: PIL Image or path to an image file
            top_k: Number of top predictions to return

        Returns:
            dict: Classification results
        """
        try:
            # Load and validate the input image
            if isinstance(image, str):
                image = Image.open(image).convert('RGB')
            elif not isinstance(image, Image.Image):
                raise ValueError("Image must be a PIL Image or a path string")

            # Preprocess
            input_tensor = self.transform(image).unsqueeze(0).to(self.device)

            # Inference
            with torch.no_grad():
                outputs = self.model(input_tensor)
                probabilities = F.softmax(outputs, dim=1)

            # Get top predictions
            top_probs, top_indices = torch.topk(probabilities, k=min(top_k, len(self.class_names)))

            top_predictions = []
            for prob, idx in zip(top_probs[0], top_indices[0]):
                top_predictions.append({
                    'class': self.class_names[idx.item()],
                    'confidence': prob.item()
                })

            # Best prediction
            best_pred = top_predictions[0]

            return {
                'success': True,
                'predicted_class': best_pred['class'],
                'confidence': best_pred['confidence'],
                'top_predictions': top_predictions
            }
        except Exception as e:
            return {
                'success': False,
                'error': str(e)
            }

    def get_disposal_instructions(self, class_name):
        """Get disposal instructions for a waste class."""
        return self.disposal_instructions.get(class_name, "No specific instructions available.")

    def get_model_info(self):
        """Get information about the loaded model."""
        return {
            'model_name': 'ViT-Base MAE',
            'architecture': 'Vision Transformer (ViT-Base)',
            'pretrained': 'MAE (Masked Autoencoder)',
            'num_classes': len(self.class_names),
            'device': self.device,
            'model_path': self.model_path
        }


# Test the classifier
if __name__ == "__main__":
    print("Testing MAE Waste Classifier...")
    try:
        # Initialize classifier
        classifier = MAEWasteClassifier()

        # Test with a sample image if available
        test_images = [
            "fail_images/image.webp",
            "fail_images/IMG_9501.webp"
        ]

        for img_path in test_images:
            if os.path.exists(img_path):
                print(f"\nTesting with {img_path}")
                result = classifier.classify_image(img_path)

                if result['success']:
                    print(f"Predicted: {result['predicted_class']} ({result['confidence']:.3f})")
                    print(f"Instructions: {classifier.get_disposal_instructions(result['predicted_class'])}")
                    print("\nTop predictions:")
                    for i, pred in enumerate(result['top_predictions'][:3], 1):
                        print(f"  {i}. {pred['class']}: {pred['confidence']:.3f}")
                else:
                    print(f"Error: {result['error']}")
                break
        else:
            # for/else: runs only if no test image was found (loop finished without break)
            print("No test images found, but classifier loaded successfully!")

        # Print model info
        info = classifier.get_model_info()
        print("\nModel Info:")
        for key, value in info.items():
            print(f"  {key}: {value}")

        print("\nSuccess!")
    except Exception as e:
        print(f"Error: {e}")
        import traceback
        traceback.print_exc()
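

# ---------------------------------------------------------------------------
# Usage sketch (not part of the original module): since this file backs a
# Hugging Face Space, the classifier could be exposed through a small Gradio
# app along the lines below. The `predict` function and `demo` object are
# illustrative names, and the snippet assumes `gradio` is installed.
#
#   import gradio as gr
#
#   classifier = MAEWasteClassifier()
#
#   def predict(img):
#       result = classifier.classify_image(img)
#       if not result['success']:
#           raise gr.Error(result['error'])
#       return {p['class']: p['confidence'] for p in result['top_predictions']}
#
#   demo = gr.Interface(
#       fn=predict,
#       inputs=gr.Image(type="pil"),
#       outputs=gr.Label(num_top_classes=3),
#   )
#   demo.launch()
# ---------------------------------------------------------------------------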