Spaces:
Running
Running
| """ | |
| EASI Severity Prediction REST API with Batch Processing | |
| ======================================================== | |
| FastAPI-based REST API for predicting EASI scores from dermatological images. | |
| Now supports both single and batch image processing! | |
| New Features: | |
| - POST /predict/batch - Process multiple images in one request | |
| - Configurable max batch size and timeout | |
| - Parallel processing for faster batch predictions | |
| Endpoints: | |
| - POST /predict - Upload single image and get EASI predictions | |
| - POST /predict/batch - Upload multiple images (up to 10 at once) | |
| - GET /health - Health check endpoint | |
| - GET /conditions - Get list of available conditions | |
| - GET /docs - Interactive API documentation | |
| Installation: | |
| pip install fastapi uvicorn python-multipart pillow tensorflow numpy pandas huggingface-hub | |
| Run locally: | |
| uvicorn api:app --host 0.0.0.0 --port 8000 --reload | |
| """ | |
| import os | |
| import warnings | |
| import logging | |
| from typing import List, Dict, Any, Optional | |
| from io import BytesIO | |
| from pathlib import Path | |
| import asyncio | |
| from concurrent.futures import ThreadPoolExecutor | |
# Quieten TensorFlow / absl output.  These environment variables are only
# honoured if they are set *before* TensorFlow is imported, which is why the
# tensorflow import lives here instead of with the other third-party imports.
# NOTE(review): filterwarnings('ignore') silences every Python warning in the
# process (including library deprecation notices), not just TensorFlow's.
os.environ.update({
    'TF_CPP_MIN_LOG_LEVEL': '2',
    'TF_ENABLE_ONEDNN_OPTS': '0',
    'MLIR_CRASH_REPRODUCER_DIRECTORY': '',
})
warnings.filterwarnings('ignore')
logging.getLogger('absl').setLevel(logging.ERROR)

import tensorflow as tf

tf.get_logger().setLevel('ERROR')
tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
| from fastapi import FastAPI, File, UploadFile, HTTPException, status | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from fastapi.responses import JSONResponse | |
| from pydantic import BaseModel, Field | |
| import numpy as np | |
| from PIL import Image | |
| import pickle | |
| import pandas as pd | |
| from huggingface_hub import hf_hub_download | |
# Configuration
def _env_int(name, default):
    """Return environment variable *name* as a positive int, else *default*.

    Unset, empty, non-numeric or non-positive values fall back to *default*,
    so a typo in deployment settings cannot prevent the API from importing.
    """
    raw = os.environ.get(name, "").strip()
    try:
        value = int(raw)
    except ValueError:
        return default
    return value if value > 0 else default


# Maximum images per POST /predict/batch request (override with $MAX_BATCH_SIZE).
MAX_BATCH_SIZE = _env_int("MAX_BATCH_SIZE", 10)
# Seconds to wait for a whole batch before answering 504 (override with $BATCH_TIMEOUT).
BATCH_TIMEOUT = _env_int("BATCH_TIMEOUT", 300)
HF_REPO_ID = "google/derm-foundation"
DERM_FOUNDATION_PATH = "./derm_foundation/"
EASI_MODEL_PATH = './trained_model/easi_severity_model_derm_foundation_individual.pkl'
HF_TOKEN = os.environ.get("HF_TOKEN") or os.environ.get("HUGGINGFACE_TOKEN")
# The FastAPI application object.  Swagger UI is served at /docs and ReDoc at /redoc.
app = FastAPI(
    title="EASI Severity Prediction API",
    version="2.1.0",
    description=(
        "REST API for predicting EASI scores from skin images. "
        "Supports single and batch processing."
    ),
    docs_url="/docs",
    redoc_url="/redoc",
)
# CORS middleware.
# Allowed origins come from the comma-separated $CORS_ALLOW_ORIGINS and default
# to "*" (any origin).  Credentials are only enabled when origins are listed
# explicitly: combining allow_origins=["*"] with allow_credentials=True makes
# Starlette echo back *any* requesting Origin together with
# Access-Control-Allow-Credentials, letting every website send credentialed
# requests to this API.
_cors_origins = [
    origin.strip()
    for origin in os.environ.get("CORS_ALLOW_ORIGINS", "*").split(",")
    if origin.strip()
] or ["*"]
app.add_middleware(
    CORSMiddleware,
    allow_origins=_cors_origins,
    allow_credentials="*" not in _cors_origins,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Shared worker pool (4 threads) used to run blocking image/model work off the event loop.
executor = ThreadPoolExecutor(4)
# Response Models
class ConditionPrediction(BaseModel):
    # One condition reported for an image, with its scores and EASI mapping.
    condition: str
    probability: float = Field(default=..., ge=0, le=1)      # model probability in [0, 1]
    confidence: float = Field(default=..., ge=0)             # non-negative confidence value
    weight: float = Field(default=..., ge=0)                 # non-negative weight value
    easi_category: Optional[str] = None                      # EASI component display name, if any
    easi_contribution: int = Field(default=..., ge=0, le=3)  # 0-3 score contributed to that component
class EASIComponent(BaseModel):
    # Score for a single EASI sign together with the conditions behind it.
    name: str = Field(...)
    score: int = Field(default=..., ge=0, le=3)  # each EASI sign is capped at 3
    contributing_conditions: List[Dict[str, Any]] = Field(...)
class PredictionResponse(BaseModel):
    # Complete prediction for a single image.
    success: bool
    # Sum of four component scores of 0-3 each, hence bounded to 0-12.
    total_easi_score: int = Field(default=..., ge=0, le=12)
    severity_interpretation: str
    easi_components: Dict[str, EASIComponent]        # keyed by component id
    predicted_conditions: List[ConditionPrediction]
    summary_statistics: Dict[str, float]
    image_info: Dict[str, Any]
class BatchPredictionResponse(BaseModel):
    # Aggregate result of a batch request.  `results` and `errors` are parallel
    # lists: for each image one holds a value and the other holds None.
    success: bool
    total_images_processed: int
    successful_predictions: int
    failed_predictions: int
    results: List[Optional[PredictionResponse]] = Field(...)
    errors: List[Optional[str]] = Field(...)
    processing_time_seconds: float
class HealthResponse(BaseModel):
    # Mirrors the dict returned by the health_check() endpoint.
    status: str
    models_loaded: Dict[str, bool]
    available_conditions: int
    hf_token_configured: bool
    deployment_platform: str
    batch_processing_enabled: bool
    max_batch_size: int
    # The space_* values are read with os.environ.get and can be None even when
    # SPACE_ID is set, so individual values must be allowed to be None.
    space_info: Optional[Dict[str, Optional[str]]] = None
class ErrorResponse(BaseModel):
    # JSON error envelope built by the custom exception handlers.
    success: bool = False
    error: str = Field(...)
    detail: Optional[str] = Field(default=None)
# Model wrapper class
class DermFoundationNeuralNetwork:
    """Wrapper around the trained EASI Keras model and its pickled preprocessing.

    ``load_model`` restores a label binarizer (``mlb``), three scalers and the
    Keras model; ``predict`` turns one embedding vector into per-condition
    probabilities, confidences and weights.
    """

    # Probability above which a condition is listed as "predicted".
    CONDITION_THRESHOLD = 0.3

    def __init__(self):
        self.model = None
        self.mlb = None
        self.embedding_scaler = None
        self.confidence_scaler = None
        self.weighted_scaler = None

    @staticmethod
    def _resolve_keras_path(keras_model_path, pickle_path):
        """Return an existing path to the Keras model, or None if it cannot be found.

        The pickle stores the path used when the model was saved (possibly a
        Windows path).  If that path does not exist here, fall back to a file
        with the same name in the directory containing the pickle.
        """
        if os.path.exists(keras_model_path):
            print(f"β Found keras model at original path: {keras_model_path}")
            return keras_model_path
        print(f"Original keras path not found: {keras_model_path}")
        pickle_dir = os.path.dirname(os.path.abspath(pickle_path))
        # Normalise backslashes so Windows-style stored paths split correctly.
        keras_filename = keras_model_path.replace('\\', '/').split('/')[-1]
        print(f"Extracted filename: {keras_filename}")
        alternative_path = os.path.join(pickle_dir, keras_filename)
        print(f"Trying alternative path: {alternative_path}")
        if os.path.exists(alternative_path):
            print(f"β Found keras model at: {alternative_path}")
            return alternative_path
        print("β Keras model not found at alternative path either")
        return None

    def load_model(self, filepath):
        """Load preprocessing objects and the Keras model from the pickle at *filepath*.

        Returns:
            True on success, False on any failure (the error is printed).
        """
        try:
            # SECURITY: pickle.load can execute arbitrary code.  Only ever point
            # this at a trusted, locally produced model file, never user input.
            with open(filepath, 'rb') as f:
                model_data = pickle.load(f)
            self.mlb = model_data['mlb']
            self.embedding_scaler = model_data['embedding_scaler']
            self.confidence_scaler = model_data['confidence_scaler']
            self.weighted_scaler = model_data['weighted_scaler']
            keras_model_path = self._resolve_keras_path(model_data['keras_model_path'], filepath)
            if keras_model_path is None:
                return False
            self.model = tf.keras.models.load_model(keras_model_path)
            print("β Keras model loaded successfully")
            return True
        except Exception as e:
            print(f"Error loading model: {e}")
            import traceback
            traceback.print_exc()
            return False

    @staticmethod
    def _inverse_positive(scaler, raw_values):
        """Inverse-scale the strictly positive entries of *raw_values*.

        Non-positive raw entries map to 0.0 without being passed to the scaler,
        and any inverse-scaled value that is not > 0 is clipped to 0.0.  The
        scaler is assumed to take a single feature column (it is fed (k, 1) input).

        Returns:
            A float ndarray with the same length as *raw_values*.
        """
        raw_values = np.asarray(raw_values)
        restored = np.zeros(raw_values.shape[0], dtype=float)
        positive = raw_values > 0
        if positive.any():
            inverse = scaler.inverse_transform(raw_values[positive].reshape(-1, 1))
            restored[positive] = np.asarray(inverse)[:, 0]
        # `> 0` is False for NaN too, so NaN becomes 0.0 (same as max(0, nan)).
        return np.where(restored > 0, restored, 0.0)

    def predict(self, embedding):
        """Predict condition probabilities, confidences and weights for one embedding.

        Args:
            embedding: 1-D embedding vector (reshaped to a single-row batch) or
                an already 2-D array whose first row is used.

        Returns:
            None if no Keras model is loaded, otherwise a dict with:
            - 'dermatologist_skin_condition_on_label_name': conditions whose
              probability exceeds CONDITION_THRESHOLD
            - 'dermatologist_skin_condition_confidence': their confidences
              (same order)
            - 'weighted_skin_condition_label': {condition: weight} for them
            - 'all_condition_probabilities' / 'all_individual_confidences' /
              'all_individual_weights': the same values for every class
            - 'condition_threshold': the threshold used
        """
        if self.model is None:
            return None
        if len(embedding.shape) == 1:
            embedding = embedding.reshape(1, -1)
        embedding_scaled = self.embedding_scaler.transform(embedding)
        predictions = self.model.predict(embedding_scaled, verbose=0)
        condition_probs = predictions['conditions'][0]
        # Inverse-scale every class once; the "predicted" subset reuses these.
        confidences = self._inverse_positive(
            self.confidence_scaler, predictions['individual_confidences'][0])
        weights = self._inverse_positive(
            self.weighted_scaler, predictions['individual_weights'][0])

        classes = self.mlb.classes_
        all_condition_probs = {name: float(condition_probs[i]) for i, name in enumerate(classes)}
        all_confidences = {name: float(confidences[i]) for i, name in enumerate(classes)}
        all_weights = {name: float(weights[i]) for i, name in enumerate(classes)}

        condition_threshold = self.CONDITION_THRESHOLD
        predicted_indices = np.where(condition_probs > condition_threshold)[0]
        predicted_conditions = [classes[i] for i in predicted_indices]
        predicted_confidences = [float(confidences[i]) for i in predicted_indices]
        predicted_weights_dict = {classes[i]: float(weights[i]) for i in predicted_indices}

        return {
            'dermatologist_skin_condition_on_label_name': predicted_conditions,
            'dermatologist_skin_condition_confidence': predicted_confidences,
            'weighted_skin_condition_label': predicted_weights_dict,
            'all_condition_probabilities': all_condition_probs,
            'all_individual_confidences': all_confidences,
            'all_individual_weights': all_weights,
            'condition_threshold': condition_threshold
        }
# Helper function to download from Hugging Face
def download_derm_foundation_from_hf(output_dir):
    """Fetch the Derm Foundation SavedModel files from the Hugging Face Hub.

    Downloads ``saved_model.pb`` and the two ``variables/`` files of
    ``HF_REPO_ID`` into *output_dir* (created if needed).  HF_TOKEN or
    HUGGINGFACE_TOKEN is used for authentication when set.

    Returns:
        True when every file was downloaded and exists on disk, False otherwise.
        Errors are printed, not raised.
    """
    separator = "=" * 80
    required_files = (
        "saved_model.pb",
        "variables/variables.data-00000-of-00001",
        "variables/variables.index",
    )
    try:
        token = os.environ.get("HF_TOKEN") or os.environ.get("HUGGINGFACE_TOKEN")
        print(separator)
        print("DOWNLOADING DERM FOUNDATION MODEL FROM HUGGING FACE")
        print(separator)
        print(
            f"β HF Token found (length: {len(token)})"
            if token
            else "β No HF Token found - attempting anonymous download"
        )
        os.makedirs(output_dir, exist_ok=True)
        for relative_path in required_files:
            print(f"\nπ₯ Downloading: {relative_path}")
            try:
                local_path = hf_hub_download(
                    repo_id=HF_REPO_ID,
                    filename=relative_path,
                    token=token,
                    local_dir=output_dir,
                    local_dir_use_symlinks=False,
                    resume_download=True,
                )
                if not os.path.exists(local_path):
                    print(f"β File not found after download: {local_path}")
                    return False
                size_mb = os.path.getsize(local_path) / (1024 * 1024)
                print(f"β Downloaded successfully ({size_mb:.2f} MB)")
            except Exception as exc:
                # Report which file failed, then let the outer handler finish up.
                print(f"β Failed to download {relative_path}")
                print(f" Error: {str(exc)}")
                raise
        print("\n" + separator)
        print("β DERM FOUNDATION MODEL DOWNLOADED SUCCESSFULLY")
        print(separator)
        return True
    except Exception as e:
        print("\n" + separator)
        print("β ERROR DOWNLOADING MODEL")
        print(separator)
        print(f"Error: {str(e)}")
        import traceback
        traceback.print_exc()
        return False
# EASI calculation functions

# EASI component key -> display name and the model conditions that count
# towards it.  Built once at import time (not on every request), and stored as
# frozensets so membership tests are O(1).  A condition may belong to more than
# one component (e.g. 'Psoriasis' counts for induration and lichenification).
EASI_CATEGORIES = {
    'erythema': {
        'name': 'Erythema (Redness)',
        'conditions': frozenset({
            'Post-Inflammatory hyperpigmentation', 'Erythema ab igne', 'Erythema annulare centrifugum',
            'Erythema elevatum diutinum', 'Erythema gyratum repens', 'Erythema multiforme',
            'Erythema nodosum', 'Flagellate erythema', 'Annular erythema', 'Drug Rash',
            'Allergic Contact Dermatitis', 'Irritant Contact Dermatitis', 'Contact dermatitis',
            'Acute dermatitis', 'Chronic dermatitis', 'Acute and chronic dermatitis',
            'Sunburn', 'Photodermatitis', 'Phytophotodermatitis', 'Rosacea',
            'Seborrheic Dermatitis', 'Stasis Dermatitis', 'Perioral Dermatitis',
            'Burn erythema of abdominal wall', 'Burn erythema of back of hand',
            'Burn erythema of lower leg', 'Cellulitis', 'Infection of skin',
            'Viral Exanthem', 'Infected eczema', 'Crusted eczematous dermatitis',
            'Inflammatory dermatosis', 'Vasculitis of the skin', 'Leukocytoclastic Vasculitis',
            'Cutaneous lupus', 'CD - Contact dermatitis', 'Acute dermatitis, NOS',
            'Herpes Simplex', 'Hypersensitivity', 'Impetigo', 'Pigmented purpuric eruption',
            'Pityriasis rosea', 'Tinea', 'Tinea Versicolor',
        }),
    },
    'induration': {
        'name': 'Induration/Papulation (Swelling/Bumps)',
        'conditions': frozenset({
            'Prurigo nodularis', 'Urticaria', 'Granuloma annulare', 'Morphea',
            'Scleroderma', 'Lichen Simplex Chronicus', 'Lichen planus', 'lichenoid eruption',
            'Lichen nitidus', 'Lichen spinulosus', 'Lichen striatus', 'Keratosis pilaris',
            'Molluscum Contagiosum', 'Verruca vulgaris', 'Folliculitis', 'Acne',
            'Hidradenitis', 'Nodular vasculitis', 'Sweet syndrome', 'Necrobiosis lipoidica',
            'Basal Cell Carcinoma', 'SCC', 'SCCIS', 'SK', 'ISK',
            'Cutaneous T Cell Lymphoma', 'Skin cancer', 'Adnexal neoplasm',
            'Insect Bite', 'Milia', 'Miliaria', 'Xanthoma', 'Psoriasis',
            'Lichen planus/lichenoid eruption',
        }),
    },
    'excoriation': {
        'name': 'Excoriation (Scratching Damage)',
        'conditions': frozenset({
            'Inflicted skin lesions', 'Scabies', 'Abrasion', 'Abrasion of wrist',
            'Superficial wound of body region', 'Scrape', 'Animal bite - wound',
            'Pruritic dermatitis', 'Prurigo', 'Atopic dermatitis', 'Scab',
        }),
    },
    'lichenification': {
        'name': 'Lichenification (Skin Thickening)',
        'conditions': frozenset({
            'Lichenified eczematous dermatitis', 'Acanthosis nigricans',
            'Hyperkeratosis of skin', 'HK - Hyperkeratosis', 'Keratoderma',
            'Ichthyosis', 'Ichthyosiform dermatosis', 'Chronic eczema',
            'Psoriasis', 'Xerosis',
        }),
    },
}

# Exclusive upper probability bounds for per-condition scores 0, 1 and 2;
# anything at or above the last bound scores 3, the EASI per-sign maximum.
_EASI_SCORE_BOUNDS = (0.171, 0.238, 0.421)


def _probability_to_score(prob):
    """Map a condition probability to a 0-3 severity score (NaN maps to 3)."""
    for score, upper_bound in enumerate(_EASI_SCORE_BOUNDS):
        if prob < upper_bound:
            return score
    return 3


def calculate_easi_scores(predictions):
    """Score the four EASI signs from per-condition probabilities.

    Each condition in ``predictions['all_condition_probabilities']`` that
    belongs to a component gets a 0-3 score from its probability.  The generic
    'eczema' label is always skipped.  A component's score is the sum of its
    non-zero condition scores, capped at 3.

    Returns:
        (easi_results, total_easi).  ``easi_results`` maps each component key
        (in EASI_CATEGORIES order) to ``{'name', 'score',
        'contributing_conditions'}``.  Contributing conditions are dicts with
        'condition', 'probability' and 'individual_score', sorted by
        probability, highest first.  ``total_easi`` is the sum of the
        component scores (0-12).
    """
    all_condition_probs = predictions['all_condition_probabilities']
    easi_results = {}
    for component, category_info in EASI_CATEGORIES.items():
        members = category_info['conditions']
        contributing = []
        for condition_name, probability in all_condition_probs.items():
            if condition_name.lower() == 'eczema' or condition_name not in members:
                continue
            score = _probability_to_score(probability)
            if score > 0:
                contributing.append({
                    'condition': condition_name,
                    'probability': probability,
                    'individual_score': score
                })
        contributing.sort(key=lambda c: c['probability'], reverse=True)
        easi_results[component] = {
            'name': category_info['name'],
            'score': min(sum(c['individual_score'] for c in contributing), 3),
            'contributing_conditions': contributing
        }
    total_easi = sum(result['score'] for result in easi_results.values())
    return easi_results, total_easi
def get_severity_interpretation(total_easi):
    """Translate a total EASI score into a human-readable severity band.

    Exactly 0 means no features.  Other scores up to 3 are mild, up to 6
    moderate, up to 9 severe, and anything higher very severe.
    """
    if total_easi == 0:
        return "No significant EASI features detected"
    bands = (
        (3, "Mild EASI severity"),
        (6, "Moderate EASI severity"),
        (9, "Severe EASI severity"),
    )
    for upper_bound, label in bands:
        if total_easi <= upper_bound:
            return label
    return "Very Severe EASI severity"
# Image processing functions
def smart_crop_to_square(image):
    """Centre-crop *image* to a square whose side is its shorter dimension.

    Square inputs are returned unchanged (the same object).  Otherwise the
    excess on the longer axis is trimmed evenly.  When the excess is odd, the
    extra pixel comes off the right/bottom edge.
    """
    width, height = image.size
    side = min(width, height)
    if width == height:
        return image
    left = (width - side) // 2
    top = (height - side) // 2
    return image.crop((left, top, left + side, top + side))
def generate_derm_foundation_embedding(model, image):
    """Return the Derm Foundation embedding of *image* as a flat NumPy vector.

    The image is converted to RGB if needed and JPEG-encoded.  The bytes are
    wrapped in a serialized ``tf.train.Example`` under the ``image/encoded``
    feature and passed to the model's ``serving_default`` signature.  The
    ``embedding`` output is used when present, otherwise the first output.

    Raises:
        Exception: wrapping any underlying error.  The original exception is
            chained as ``__cause__`` so its traceback is not lost.
    """
    try:
        if image.mode != 'RGB':
            image = image.convert('RGB')
        buf = BytesIO()
        # NOTE(review): PIL's default JPEG settings are used here; presumably
        # they match how training embeddings were produced — confirm before
        # changing, since it would shift the embeddings.
        image.save(buf, format='JPEG')
        image_bytes = buf.getvalue()
        input_tensor = tf.train.Example(features=tf.train.Features(
            feature={'image/encoded': tf.train.Feature(
                bytes_list=tf.train.BytesList(value=[image_bytes]))
            })).SerializeToString()
        infer = model.signatures["serving_default"]
        output = infer(inputs=tf.constant([input_tensor]))
        if 'embedding' in output:
            embedding_vector = output['embedding'].numpy().flatten()
        else:
            key = list(output.keys())[0]
            embedding_vector = output[key].numpy().flatten()
        return embedding_vector
    except Exception as e:
        raise Exception(f"Error generating embedding: {str(e)}") from e
def _easi_category_lookup(easi_results):
    """Map each contributing condition to (component display name, individual score).

    Components are visited in ``easi_results`` order and the FIRST component a
    condition contributes to wins.  A condition counted under several EASI signs
    therefore gets one stable label instead of whichever component came last.
    """
    lookup = {}
    for cat_info in easi_results.values():
        for contrib in cat_info['contributing_conditions']:
            lookup.setdefault(
                contrib['condition'],
                (cat_info['name'], contrib['individual_score'])
            )
    return lookup


def process_single_image_sync(image_bytes: bytes, filename: str) -> Dict[str, Any]:
    """
    Synchronous function to process a single image.

    Pipeline: decode -> RGB -> centre-crop to square -> resize to 448x448 ->
    Derm Foundation embedding -> EASI model -> EASI scoring.  Reads the
    module-level ``derm_model`` / ``easi_model`` globals, so call it only after
    they are loaded.  The batch endpoint runs it on the thread pool.

    Returns dict with 'success', 'result', and 'error' keys.  On success,
    'result' is a PredictionResponse and 'error' is None.  On failure, 'result'
    is None and 'error' holds a message.  Exceptions are caught, printed and
    reported through 'error' rather than raised.
    """
    try:
        # Decode and close the underlying image handle once pixels are copied.
        with Image.open(BytesIO(image_bytes)) as source:
            original_image = source.convert('RGB')
        original_size = original_image.size

        # Process to 448x448
        cropped_img = smart_crop_to_square(original_image)
        processed_img = cropped_img.resize((448, 448), Image.Resampling.LANCZOS)

        embedding = generate_derm_foundation_embedding(derm_model, processed_img)
        predictions = easi_model.predict(embedding)
        if predictions is None:
            return {
                'success': False,
                'result': None,
                'error': "Prediction failed - model returned None"
            }

        easi_results, total_easi = calculate_easi_scores(predictions)
        severity = get_severity_interpretation(total_easi)

        # Format predicted conditions
        category_lookup = _easi_category_lookup(easi_results)
        condition_names = predictions['dermatologist_skin_condition_on_label_name']
        confidences = predictions['dermatologist_skin_condition_confidence']
        weights = predictions['weighted_skin_condition_label']
        all_probs = predictions['all_condition_probabilities']
        predicted_conditions = []
        for condition, conf in zip(condition_names, confidences):
            easi_category, easi_contribution = category_lookup.get(condition, (None, 0))
            predicted_conditions.append(ConditionPrediction(
                condition=condition,
                probability=float(all_probs[condition]),
                confidence=float(conf),
                weight=float(weights[condition]),
                easi_category=easi_category,
                easi_contribution=easi_contribution
            ))

        summary_stats = {
            "total_conditions": len(predicted_conditions),
            "average_confidence": float(np.mean(confidences)) if predicted_conditions else 0.0,
            "average_weight": float(np.mean(list(weights.values()))) if predicted_conditions else 0.0,
            "total_weight": float(sum(weights.values()))
        }

        easi_components_formatted = {
            component: EASIComponent(
                name=result['name'],
                score=result['score'],
                contributing_conditions=result['contributing_conditions']
            )
            for component, result in easi_results.items()
        }

        result = PredictionResponse(
            success=True,
            total_easi_score=total_easi,
            severity_interpretation=severity,
            easi_components=easi_components_formatted,
            predicted_conditions=predicted_conditions,
            summary_statistics=summary_stats,
            image_info={
                "original_size": f"{original_size[0]}x{original_size[1]}",
                "processed_size": "448x448",
                "filename": filename
            }
        )
        return {
            'success': True,
            'result': result,
            'error': None
        }
    except Exception as e:
        import traceback
        error_traceback = traceback.format_exc()
        print(f"Error processing image {filename}: {str(e)}")
        print(error_traceback)
        return {
            'success': False,
            'result': None,
            'error': str(e)
        }
# Global model instances: both start as None and are assigned by load_models().
derm_model = easi_model = None
# Overwritten by load_models() with the detected platform.
deployment_platform = "huggingface_spaces"
@app.on_event("startup")
async def load_models():
    """Load models on startup.

    Reports deployment details and makes sure the Derm Foundation SavedModel is
    on disk, downloading it from the Hub if missing.  It then loads that model
    and the EASI model from EASI_MODEL_PATH.  Results go into the module
    globals ``derm_model`` / ``easi_model`` (left as None on failure) and
    ``deployment_platform``.  Errors are printed, never raised, so the API
    still starts; missing models surface via /health and HTTP 503 responses.
    """
    global derm_model, easi_model, deployment_platform
    import gc
    gc.collect()

    rule = "=" * 80
    print("\n" + rule)
    print("π STARTING EASI API WITH BATCH PROCESSING")
    print(rule)

    # Detect if running on HF Spaces
    space_id = os.environ.get("SPACE_ID")
    space_author = os.environ.get("SPACE_AUTHOR_NAME")
    space_host = os.environ.get("SPACE_HOST")
    if space_id:
        deployment_platform = f"huggingface_spaces ({space_id})"
        print(f"π Space: {space_id}")
        print(f"π€ Author: {space_author}")
        print(f"π Host: {space_host}")
    else:
        deployment_platform = "local"
        print("π Running locally")
    print(f"π’ Max batch size: {MAX_BATCH_SIZE}")
    print(f"β±οΈ Batch timeout: {BATCH_TIMEOUT}s")
    print(rule)

    # Check HF Token
    hf_token = os.environ.get("HF_TOKEN") or os.environ.get("HUGGINGFACE_TOKEN")
    if hf_token:
        print(f"β HF Token configured (length: {len(hf_token)})")
    else:
        print("β No HF Token found")
    print(rule)

    # Make sure the Derm Foundation SavedModel is available locally.
    derm_files_ready = (
        os.path.exists(os.path.join(DERM_FOUNDATION_PATH, "saved_model.pb")) and
        os.path.exists(os.path.join(DERM_FOUNDATION_PATH, "variables"))
    )
    if derm_files_ready:
        print("\nβ Derm Foundation model found locally (using cache)")
    else:
        print("\nπ₯ Derm Foundation model not found - downloading from Hugging Face...")
        derm_files_ready = download_derm_foundation_from_hf(DERM_FOUNDATION_PATH)
        if not derm_files_ready:
            # Don't return here: the EASI model can still be loaded so that
            # /conditions works and the status report below is still printed.
            print("\nβ CRITICAL: Failed to download Derm Foundation model!")

    if derm_files_ready:
        print("\n" + rule)
        print("π¦ LOADING DERM FOUNDATION MODEL")
        print(rule)
        try:
            print(f"Loading from: {DERM_FOUNDATION_PATH}")
            gc.collect()
            derm_model = tf.saved_model.load(DERM_FOUNDATION_PATH)
            print("β Derm Foundation model loaded successfully!")
            gc.collect()
        except Exception as e:
            print(f"β Failed to load Derm Foundation model: {str(e)}")
            import traceback
            traceback.print_exc()

    # Load EASI model
    print("\n" + rule)
    print("π¦ LOADING EASI PREDICTION MODEL")
    print(rule)
    if os.path.exists(EASI_MODEL_PATH):
        # Load into a local first so the global only ever holds a fully loaded model.
        candidate = DermFoundationNeuralNetwork()
        if candidate.load_model(EASI_MODEL_PATH):
            easi_model = candidate
            print(f"β EASI model loaded from: {EASI_MODEL_PATH}")
            print(f" Available conditions: {len(easi_model.mlb.classes_)}")
        else:
            print(f"β Failed to load EASI model")
            easi_model = None
    else:
        print(f"β EASI model not found at: {EASI_MODEL_PATH}")

    # Final status
    derm_ok = derm_model is not None
    easi_ok = easi_model is not None
    print("\n" + rule)
    print("π STARTUP COMPLETE")
    print(rule)
    print(f"Derm Foundation Model: {'β Loaded' if derm_ok else 'β Failed'}")
    print(f"EASI Prediction Model: {'β Loaded' if easi_ok else 'β Failed'}")
    print(f"Batch Processing: β Enabled (max {MAX_BATCH_SIZE} images)")
    print(f"Platform: {deployment_platform}")
    print(rule)
    if derm_ok and easi_ok:
        print("β All systems ready! API is operational.")
    else:
        print("β οΈ WARNING: Some models failed to load.")
    print(rule + "\n")
# API Endpoints
@app.get("/")
async def root():
    """Root endpoint with API information.

    Reports the version, deployment platform, HF Space identifiers (with local
    fallbacks) and model status.  Status is "operational" only when both models
    are loaded.  Batch-processing limits and the endpoint list are included too.
    """
    space_info = {
        "space_id": os.environ.get("SPACE_ID", "local"),
        "space_author": os.environ.get("SPACE_AUTHOR_NAME", "unknown"),
        "space_host": os.environ.get("SPACE_HOST", "localhost")
    }
    # Same readiness rule as /health: both globals must be set.
    models_ready = derm_model is not None and easi_model is not None
    return {
        "message": "EASI Severity Prediction API with Batch Processing",
        # Take the version from the app metadata so there is one source of truth.
        "version": app.version,
        "platform": deployment_platform,
        "space_info": space_info,
        "status": "operational" if models_ready else "degraded",
        "batch_processing": {
            "enabled": True,
            "max_batch_size": MAX_BATCH_SIZE,
            "timeout_seconds": BATCH_TIMEOUT
        },
        "endpoints": {
            "health": "/health",
            "predict": "/predict (single image)",
            "predict_batch": "/predict/batch (multiple images)",
            "conditions": "/conditions",
            "docs": "/docs",
            "redoc": "/redoc"
        },
        "documentation": "Visit /docs for interactive API documentation"
    }
@app.get("/health")
async def health_check():
    """Health check endpoint.

    "healthy" when both models are loaded, otherwise "degraded".  Also reports
    the per-model load state, number of known conditions, whether an HF token
    is configured and the batch limits.  ``space_info`` is only filled when
    SPACE_ID is set; its individual values may still be None.
    """
    space_info = None
    if os.environ.get("SPACE_ID"):
        space_info = {
            "space_id": os.environ.get("SPACE_ID"),
            "space_author": os.environ.get("SPACE_AUTHOR_NAME"),
            "space_host": os.environ.get("SPACE_HOST")
        }
    derm_loaded = derm_model is not None
    easi_loaded = easi_model is not None
    return {
        "status": "healthy" if (derm_loaded and easi_loaded) else "degraded",
        "models_loaded": {
            "derm_foundation": derm_loaded,
            "easi_model": easi_loaded
        },
        "available_conditions": len(easi_model.mlb.classes_) if easi_model else 0,
        "hf_token_configured": (os.environ.get("HF_TOKEN") or os.environ.get("HUGGINGFACE_TOKEN")) is not None,
        "deployment_platform": deployment_platform,
        "batch_processing_enabled": True,
        "max_batch_size": MAX_BATCH_SIZE,
        "space_info": space_info
    }
@app.get("/conditions")
async def get_conditions():
    """Get list of available conditions.

    Returns the class labels known to the EASI model and their count.

    Raises:
        HTTPException: 503 when the EASI model is not loaded.
    """
    if easi_model is None:
        raise HTTPException(
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
            detail="EASI model not loaded. Check server logs or /health endpoint."
        )
    classes = easi_model.mlb.classes_
    return {
        "conditions": classes.tolist(),
        "total_count": len(classes)
    }
@app.post("/predict", response_model=PredictionResponse)
async def predict_easi(
    file: UploadFile = File(..., description="Skin image file (JPG, JPEG, PNG)")
):
    """
    Predict EASI scores from uploaded skin image.
    - **file**: Image file (JPG, JPEG, PNG)
    - Returns: EASI scores, component breakdown, and condition predictions
    """
    # Validate models loaded
    if derm_model is None or easi_model is None:
        error_detail = []
        if derm_model is None:
            error_detail.append("Derm Foundation model not loaded")
        if easi_model is None:
            error_detail.append("EASI model not loaded")
        raise HTTPException(
            status_code=503,
            detail=f"Models not available: {', '.join(error_detail)}. Check /health endpoint for details."
        )
    # Validate file type (based on the client-declared content type)
    if not file.content_type or not file.content_type.startswith('image/'):
        raise HTTPException(
            status_code=400,
            detail="File must be an image (JPG, JPEG, PNG). Received: " + str(file.content_type)
        )
    try:
        image_bytes = await file.read()
        # Run the blocking decode + TensorFlow inference on the shared worker
        # pool so the event loop keeps serving other requests meanwhile.
        loop = asyncio.get_running_loop()
        result = await loop.run_in_executor(
            executor, process_single_image_sync, image_bytes, file.filename
        )
        if not result['success']:
            raise HTTPException(
                status_code=500,
                detail=f"Error processing image: {result['error']}"
            )
        return result['result']
    except HTTPException:
        raise
    except Exception as e:
        import traceback
        error_traceback = traceback.format_exc()
        print(f"Error processing image: {str(e)}")
        print(error_traceback)
        raise HTTPException(
            status_code=500,
            detail=f"Error processing image: {str(e)}"
        )
@app.post("/predict/batch", response_model=BatchPredictionResponse)
async def predict_easi_batch(
    files: List[UploadFile] = File(..., description=f"Multiple skin image files (max {MAX_BATCH_SIZE})")
):
    """
    Predict EASI scores from multiple uploaded skin images in parallel.
    - **files**: List of image files (JPG, JPEG, PNG) - at most MAX_BATCH_SIZE
      images per request (10 by default)
    - Returns: Batch results with individual predictions and errors.
      `results` and `errors` follow the upload order; for each image one of
      them is filled and the other is null.  `success` is true whenever the
      batch itself ran, even if some images failed - check
      `failed_predictions`.

    **Example Usage (Python):**
    ```python
    import requests
    files = [
        ('files', open('image1.jpg', 'rb')),
        ('files', open('image2.jpg', 'rb')),
        ('files', open('image3.jpg', 'rb'))
    ]
    response = requests.post('http://localhost:8000/predict/batch', files=files)
    results = response.json()
    ```
    **Example Usage (cURL):**
    ```bash
    curl -X POST "http://localhost:8000/predict/batch" \\
      -F "files=@image1.jpg" \\
      -F "files=@image2.jpg" \\
      -F "files=@image3.jpg"
    ```
    """
    import time
    # perf_counter is monotonic, so the reported duration can't go negative or jump.
    start_time = time.perf_counter()
    # Validate models loaded
    if derm_model is None or easi_model is None:
        error_detail = []
        if derm_model is None:
            error_detail.append("Derm Foundation model not loaded")
        if easi_model is None:
            error_detail.append("EASI model not loaded")
        raise HTTPException(
            status_code=503,
            detail=f"Models not available: {', '.join(error_detail)}. Check /health endpoint."
        )
    # Validate batch size
    num_files = len(files)
    if num_files == 0:
        raise HTTPException(
            status_code=400,
            detail="No files provided. Please upload at least one image."
        )
    if num_files > MAX_BATCH_SIZE:
        raise HTTPException(
            status_code=400,
            detail=f"Too many files. Maximum batch size is {MAX_BATCH_SIZE}, received {num_files}."
        )
    print(f"\nπ Processing batch of {num_files} images...")

    # Validate file types and read all files before starting any inference,
    # so one bad upload rejects the whole request up front.
    image_data = []
    for idx, file in enumerate(files):
        if not file.content_type or not file.content_type.startswith('image/'):
            raise HTTPException(
                status_code=400,
                detail=f"File {idx+1} ('{file.filename}') is not an image. Received: {file.content_type}"
            )
        try:
            image_bytes = await file.read()
        except Exception as e:
            raise HTTPException(
                status_code=400,
                detail=f"Error reading file {idx+1} ('{file.filename}'): {str(e)}"
            )
        image_data.append({
            'bytes': image_bytes,
            'filename': file.filename,
            'index': idx
        })

    # Process images in parallel using thread pool
    try:
        loop = asyncio.get_running_loop()
        tasks = [
            loop.run_in_executor(
                executor,
                process_single_image_sync,
                img['bytes'],
                img['filename']
            )
            for img in image_data
        ]
        # NOTE: on timeout the awaiting futures are cancelled, but worker
        # threads already running inference cannot be interrupted and finish
        # in the background.
        results = await asyncio.wait_for(
            asyncio.gather(*tasks, return_exceptions=True),
            timeout=BATCH_TIMEOUT
        )
    except asyncio.TimeoutError:
        raise HTTPException(
            status_code=504,
            detail=f"Batch processing timeout after {BATCH_TIMEOUT} seconds. Try reducing batch size."
        )
    except Exception as e:
        import traceback
        traceback.print_exc()
        raise HTTPException(
            status_code=500,
            detail=f"Error during batch processing: {str(e)}"
        )

    # Collect results and errors (same order as the uploads)
    prediction_results = []
    error_messages = []
    successful_count = 0
    failed_count = 0
    for idx, result in enumerate(results):
        if isinstance(result, Exception):
            prediction_results.append(None)
            error_messages.append(f"Exception: {str(result)}")
            failed_count += 1
            print(f" β Image {idx+1} failed: {str(result)}")
        elif result['success']:
            prediction_results.append(result['result'])
            error_messages.append(None)
            successful_count += 1
            print(f" β Image {idx+1} processed successfully")
        else:
            prediction_results.append(None)
            error_messages.append(result['error'])
            failed_count += 1
            print(f" β Image {idx+1} failed: {result['error']}")

    processing_time = time.perf_counter() - start_time
    print(f"β Batch complete: {successful_count} successful, {failed_count} failed in {processing_time:.2f}s\n")
    return BatchPredictionResponse(
        success=True,
        total_images_processed=num_files,
        successful_predictions=successful_count,
        failed_predictions=failed_count,
        results=prediction_results,
        errors=error_messages,
        processing_time_seconds=round(processing_time, 2)
    )
@app.exception_handler(HTTPException)
async def http_exception_handler(request, exc):
    """Custom HTTP exception handler.

    Renders an HTTPException as an ErrorResponse JSON body, keeping its status
    code and any headers it carries.  ``detail`` is stringified because
    HTTPException accepts non-string details while ErrorResponse.error is a str.
    """
    body = ErrorResponse(
        error=str(exc.detail),
        detail=str(exc)
    )
    # Pydantic v2 renamed .dict() to .model_dump(); support both.
    payload = body.model_dump() if hasattr(body, "model_dump") else body.dict()
    return JSONResponse(
        status_code=exc.status_code,
        content=payload,
        headers=getattr(exc, "headers", None)
    )
@app.exception_handler(Exception)
async def general_exception_handler(request, exc):
    """General exception handler for unexpected errors.

    Prints the full traceback and answers 500 with an ErrorResponse body.
    NOTE(review): ``detail`` exposes str(exc) to clients; consider hiding it
    in production.
    """
    import traceback
    # Format from the exception object itself rather than relying on an
    # exception being "currently handled" when this coroutine runs.
    error_traceback = "".join(
        traceback.format_exception(type(exc), exc, exc.__traceback__)
    )
    print(f"Unexpected error: {str(exc)}")
    print(error_traceback)
    body = ErrorResponse(
        error="Internal server error",
        detail=str(exc)
    )
    payload = body.model_dump() if hasattr(body, "model_dump") else body.dict()
    return JSONResponse(
        status_code=500,
        content=payload
    )
if __name__ == "__main__":
    import uvicorn

    # The listening port can be overridden via $PORT; 8000 remains the default.
    port = int(os.environ.get("PORT", "8000"))
    print("=" * 80)
    print("π Starting EASI API Server with Batch Processing")
    print("=" * 80)
    print(f"Max batch size: {MAX_BATCH_SIZE} images")
    print(f"Batch timeout: {BATCH_TIMEOUT} seconds")
    print("=" * 80)
    print(f"Access the API at: http://localhost:{port}")
    print(f"Interactive docs: http://localhost:{port}/docs")
    print("=" * 80)
    uvicorn.run(
        app,
        host="0.0.0.0",
        port=port,
        log_level="info"
    )