"""
EASI Severity Prediction REST API
==================================
FastAPI-based REST API for predicting EASI scores from dermatological images.
Optimized for Hugging Face Spaces deployment.

Endpoints:
    - POST /predict     - Upload image and get EASI predictions
    - GET  /health      - Health check endpoint
    - GET  /conditions  - Get list of available conditions
    - GET  /docs        - Interactive API documentation

Installation:
    pip install fastapi uvicorn python-multipart pillow tensorflow numpy pandas huggingface-hub

Run locally:
    uvicorn api:app --host 0.0.0.0 --port 8000 --reload

Deploy to HF Spaces:
    1. Create Space with Docker SDK
    2. Upload this file + Dockerfile + requirements.txt + trained_model/
    3. Accept terms for google/derm-foundation
    4. Space auto-builds!
"""

import gc
import logging
import os
import pickle
import traceback
import warnings
from io import BytesIO
from pathlib import Path
from typing import Any, Dict, List, Optional

# Suppress warnings / TensorFlow log noise. The environment variables only take
# effect when they are set before tensorflow is imported below.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
os.environ['TF_ENABLE_ONEDNN_OPTS'] = '0'
os.environ['MLIR_CRASH_REPRODUCER_DIRECTORY'] = ''
warnings.filterwarnings('ignore')
logging.getLogger('absl').setLevel(logging.ERROR)

import tensorflow as tf  # must come after the environment setup above

tf.get_logger().setLevel('ERROR')
tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)

import numpy as np
import pandas as pd
from fastapi import FastAPI, File, HTTPException, UploadFile, status
from fastapi.encoders import jsonable_encoder
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from huggingface_hub import hf_hub_download
from PIL import Image
from pydantic import BaseModel, Field
from starlette.exceptions import HTTPException as StarletteHTTPException

# Initialize FastAPI app
app = FastAPI(
    title="EASI Severity Prediction API",
    description="REST API for predicting EASI scores from skin images. Deployed on Hugging Face Spaces.",
    version="2.0.0",
    docs_url="/docs",
    redoc_url="/redoc",
)

# CORS middleware for Flutter web/mobile.
# NOTE(review): a wildcard origin together with allow_credentials=True lets any
# website issue credentialed requests to this API.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # In production, specify your Flutter app domain
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Configuration
HF_REPO_ID = "google/derm-foundation"
DERM_FOUNDATION_PATH = "./derm_foundation/"
EASI_MODEL_PATH = './trained_model/easi_severity_model_derm_foundation_individual.pkl'

# Files of the Derm Foundation TensorFlow SavedModel inside HF_REPO_ID.
DERM_FOUNDATION_FILES = (
    "saved_model.pb",
    "variables/variables.data-00000-of-00001",
    "variables/variables.index",
)

# Side length (pixels) of the square image passed to Derm Foundation.
PROCESSED_IMAGE_SIZE = 448


def _get_hf_token() -> Optional[str]:
    """Return the Hugging Face token from HF_TOKEN or HUGGINGFACE_TOKEN (or None)."""
    return os.environ.get("HF_TOKEN") or os.environ.get("HUGGINGFACE_TOKEN")


# HF Spaces automatically injects HF_TOKEN for authenticated users
HF_TOKEN = _get_hf_token()


# Response Models
class ConditionPrediction(BaseModel):
    condition: str
    probability: float = Field(..., ge=0, le=1)
    confidence: float = Field(..., ge=0)
    weight: float = Field(..., ge=0)
    easi_category: Optional[str] = None
    easi_contribution: int = Field(..., ge=0, le=3)


class EASIComponent(BaseModel):
    name: str
    score: int = Field(..., ge=0, le=3)
    contributing_conditions: List[Dict[str, Any]]


class PredictionResponse(BaseModel):
    success: bool
    total_easi_score: int = Field(..., ge=0, le=12)
    severity_interpretation: str
    easi_components: Dict[str, EASIComponent]
    predicted_conditions: List[ConditionPrediction]
    summary_statistics: Dict[str, float]
    image_info: Dict[str, Any]


class HealthResponse(BaseModel):
    status: str
    models_loaded: Dict[str, bool]
    available_conditions: int
    hf_token_configured: bool
    deployment_platform: str
    space_info: Optional[Dict[str, str]] = None


class ConditionsResponse(BaseModel):
    """Payload of GET /conditions: the condition names and how many there are."""
    conditions: List[str]
    total_count: int


class ErrorResponse(BaseModel):
    success: bool = False
    error: str
    detail: Optional[str] = None


# Model wrapper class
class DermFoundationNeuralNetwork:
    """Wrapper around the trained EASI Keras model and its preprocessing artifacts.

    The model consumes a Derm Foundation embedding and produces (per the keys
    read in `predict`) three outputs: 'conditions', 'individual_confidences'
    and 'individual_weights'. The scalers and label binarizer come from the
    training-time pickle.
    """

    # Probability above which a condition is reported as "predicted".
    CONDITION_THRESHOLD = 0.3

    def __init__(self):
        self.model = None
        self.mlb = None
        self.embedding_scaler = None
        self.confidence_scaler = None
        self.weighted_scaler = None

    def load_model(self, filepath):
        """Load the pickled artifacts and the Keras model they reference.

        Args:
            filepath: Path to the training pickle. It must contain the keys
                'mlb', 'embedding_scaler', 'confidence_scaler',
                'weighted_scaler' and 'keras_model_path'.

        Returns:
            True on success; False on any failure (the error is printed).
        """
        try:
            # SECURITY NOTE: pickle.load can execute arbitrary code. Only load
            # trusted model files shipped with the deployment.
            with open(filepath, 'rb') as f:
                model_data = pickle.load(f)

            self.mlb = model_data['mlb']
            self.embedding_scaler = model_data['embedding_scaler']
            self.confidence_scaler = model_data['confidence_scaler']
            self.weighted_scaler = model_data['weighted_scaler']

            keras_model_path = self._resolve_keras_model_path(
                model_data['keras_model_path'], filepath
            )
            if keras_model_path is None:
                return False

            # Load the keras model
            self.model = tf.keras.models.load_model(keras_model_path)
            print("✓ Keras model loaded successfully")
            return True

        except Exception as e:
            print(f"Error loading model: {e}")
            traceback.print_exc()
            return False

    @staticmethod
    def _resolve_keras_model_path(keras_model_path, pickle_path):
        """Return a usable path to the Keras model, or None if it cannot be found.

        The pickle stores the path used at training time, which may be a
        Windows path from another machine. If that path does not exist, a file
        or directory with the same name next to the pickle is tried instead.
        """
        if os.path.exists(keras_model_path):
            print(f"✓ Found keras model at original path: {keras_model_path}")
            return keras_model_path

        print(f"Original keras path not found: {keras_model_path}")
        pickle_dir = os.path.dirname(os.path.abspath(pickle_path))

        # Normalise Windows separators and drop a trailing separator so the
        # last path component is extracted correctly on any OS.
        normalized_path = keras_model_path.replace('\\', '/').rstrip('/')
        keras_filename = normalized_path.split('/')[-1]
        print(f"Extracted filename: {keras_filename}")

        alternative_path = os.path.join(pickle_dir, keras_filename)
        print(f"Trying alternative path: {alternative_path}")
        if os.path.exists(alternative_path):
            print(f"✓ Found keras model at: {alternative_path}")
            return alternative_path

        print("✗ Keras model not found at alternative path either")
        print(f"Files in {pickle_dir}:")
        try:
            print(os.listdir(pickle_dir))
        except OSError as listdir_error:
            # Directory listing is diagnostic only; report and carry on.
            print(f"  (could not list directory: {listdir_error})")
        return None

    @staticmethod
    def _inverse_scale_positive(scaler, values):
        """Map scaled model outputs back to their original scale.

        Entries whose scaled value is <= 0 become 0.0; inverse-scaled results
        are clipped at 0. All entries are passed through the scaler in one
        call instead of one call per element.

        Args:
            scaler: Fitted single-feature scaler with `inverse_transform`.
            values: 1-D array of scaled values.

        Returns:
            1-D float ndarray with the same length as `values`.
        """
        values = np.asarray(values).ravel()
        restored = np.zeros(values.shape, dtype=float)
        positive = values > 0
        if positive.any():
            # Column vector: one sample per row for the single-feature scaler.
            unscaled = scaler.inverse_transform(values[positive].reshape(-1, 1))[:, 0]
            restored[positive] = np.maximum(unscaled, 0.0)
        return restored

    def predict(self, embedding):
        """Run the EASI model on one Derm Foundation embedding.

        Args:
            embedding: 1-D embedding vector, or a 2-D array whose first row is
                used.

        Returns:
            None if the model is not loaded. Otherwise a dict with:
            predicted condition names, confidences and weights (probability
            above CONDITION_THRESHOLD), per-class probabilities, confidences
            and weights for every class, and the threshold used.
        """
        if self.model is None:
            return None

        embedding = np.asarray(embedding)
        if embedding.ndim == 1:
            embedding = embedding.reshape(1, -1)

        embedding_scaled = self.embedding_scaler.transform(embedding)
        predictions = self.model.predict(embedding_scaled, verbose=0)

        condition_probs = np.asarray(predictions['conditions'][0])
        confidences_orig = self._inverse_scale_positive(
            self.confidence_scaler, predictions['individual_confidences'][0]
        )
        weights_orig = self._inverse_scale_positive(
            self.weighted_scaler, predictions['individual_weights'][0]
        )

        classes = list(self.mlb.classes_)
        all_condition_probs = {name: float(p) for name, p in zip(classes, condition_probs)}
        all_confidences = {name: float(c) for name, c in zip(classes, confidences_orig)}
        all_weights = {name: float(w) for name, w in zip(classes, weights_orig)}

        condition_threshold = self.CONDITION_THRESHOLD
        predicted_indices = np.flatnonzero(condition_probs > condition_threshold)
        predicted_conditions = [classes[i] for i in predicted_indices]
        predicted_confidences = [float(confidences_orig[i]) for i in predicted_indices]
        predicted_weights_dict = {classes[i]: float(weights_orig[i]) for i in predicted_indices}

        return {
            'dermatologist_skin_condition_on_label_name': predicted_conditions,
            'dermatologist_skin_condition_confidence': predicted_confidences,
            'weighted_skin_condition_label': predicted_weights_dict,
            'all_condition_probabilities': all_condition_probs,
            'all_individual_confidences': all_confidences,
            'all_individual_weights': all_weights,
            'condition_threshold': condition_threshold,
        }


# Helper function to download from Hugging Face
def download_derm_foundation_from_hf(output_dir):
    """Download the Derm Foundation SavedModel files from the Hugging Face Hub.

    Args:
        output_dir: Local directory the files are written into (created if
            missing).

    Returns:
        True if every file in DERM_FOUNDATION_FILES was downloaded, else False.
        Errors are printed together with troubleshooting hints, not raised.
    """
    try:
        # Get token - on HF Spaces it's auto-injected
        hf_token = _get_hf_token()

        print("=" * 80)
        print("DOWNLOADING DERM FOUNDATION MODEL FROM HUGGING FACE")
        print("=" * 80)

        if hf_token:
            print(f"✓ HF Token found (length: {len(hf_token)})")
        else:
            print("⚠ No HF Token found - attempting anonymous download")
            print(" Note: If this fails, you need to:")
            print(" 1. Accept terms at https://huggingface.co/google/derm-foundation")
            print(" 2. Add HF_TOKEN to Space secrets")

        os.makedirs(output_dir, exist_ok=True)

        for file_path in DERM_FOUNDATION_FILES:
            print(f"\n📥 Downloading: {file_path}")
            try:
                # resume_download / local_dir_use_symlinks are deprecated in
                # recent huggingface_hub releases and therefore not passed.
                downloaded_path = hf_hub_download(
                    repo_id=HF_REPO_ID,
                    filename=file_path,
                    token=hf_token,
                    local_dir=output_dir,
                )
            except Exception as download_error:
                print(f"✗ Failed to download {file_path}")
                print(f" Error: {str(download_error)}")
                raise

            # Verify file exists and get size
            if not os.path.exists(downloaded_path):
                print(f"✗ File not found after download: {downloaded_path}")
                return False
            file_size_mb = os.path.getsize(downloaded_path) / (1024 * 1024)
            print(f"✓ Downloaded successfully ({file_size_mb:.2f} MB)")

        print("\n" + "=" * 80)
        print("✓ DERM FOUNDATION MODEL DOWNLOADED SUCCESSFULLY")
        print("=" * 80)
        return True

    except Exception as e:
        print("\n" + "=" * 80)
        print("✗ ERROR DOWNLOADING MODEL")
        print("=" * 80)
        print(f"Error: {str(e)}")
        print("\nTroubleshooting steps:")
        print("1. Ensure you've accepted the model terms at:")
        print(" https://huggingface.co/google/derm-foundation")
        print("2. Add HF_TOKEN to your Space secrets (Settings → Repository secrets)")
        print("3. Make sure your token has 'Read access to gated repos' permission")
        traceback.print_exc()
        return False


# EASI calculation
# Maps each EASI component key to its display name and the model condition
# labels that feed it. A label may appear in more than one component
# (e.g. 'Psoriasis').
EASI_CATEGORIES = {
    'erythema': {
        'name': 'Erythema (Redness)',
        'conditions': frozenset({
            'Post-Inflammatory hyperpigmentation', 'Erythema ab igne',
            'Erythema annulare centrifugum', 'Erythema elevatum diutinum',
            'Erythema gyratum repens', 'Erythema multiforme', 'Erythema nodosum',
            'Flagellate erythema', 'Annular erythema', 'Drug Rash',
            'Allergic Contact Dermatitis', 'Irritant Contact Dermatitis',
            'Contact dermatitis', 'Acute dermatitis', 'Chronic dermatitis',
            'Acute and chronic dermatitis', 'Sunburn', 'Photodermatitis',
            'Phytophotodermatitis', 'Rosacea', 'Seborrheic Dermatitis',
            'Stasis Dermatitis', 'Perioral Dermatitis',
            'Burn erythema of abdominal wall', 'Burn erythema of back of hand',
            'Burn erythema of lower leg', 'Cellulitis', 'Infection of skin',
            'Viral Exanthem', 'Infected eczema', 'Crusted eczematous dermatitis',
            'Inflammatory dermatosis', 'Vasculitis of the skin',
            'Leukocytoclastic Vasculitis', 'Cutaneous lupus',
            'CD - Contact dermatitis', 'Acute dermatitis, NOS', 'Herpes Simplex',
            'Hypersensitivity', 'Impetigo', 'Pigmented purpuric eruption',
            'Pityriasis rosea', 'Tinea', 'Tinea Versicolor',
        }),
    },
    'induration': {
        'name': 'Induration/Papulation (Swelling/Bumps)',
        'conditions': frozenset({
            'Prurigo nodularis', 'Urticaria', 'Granuloma annulare', 'Morphea',
            'Scleroderma', 'Lichen Simplex Chronicus', 'Lichen planus',
            'lichenoid eruption', 'Lichen nitidus', 'Lichen spinulosus',
            'Lichen striatus', 'Keratosis pilaris', 'Molluscum Contagiosum',
            'Verruca vulgaris', 'Folliculitis', 'Acne', 'Hidradenitis',
            'Nodular vasculitis', 'Sweet syndrome', 'Necrobiosis lipoidica',
            'Basal Cell Carcinoma', 'SCC', 'SCCIS', 'SK', 'ISK',
            'Cutaneous T Cell Lymphoma', 'Skin cancer', 'Adnexal neoplasm',
            'Insect Bite', 'Milia', 'Miliaria', 'Xanthoma', 'Psoriasis',
            'Lichen planus/lichenoid eruption',
        }),
    },
    'excoriation': {
        'name': 'Excoriation (Scratching Damage)',
        'conditions': frozenset({
            'Inflicted skin lesions', 'Scabies', 'Abrasion', 'Abrasion of wrist',
            'Superficial wound of body region', 'Scrape', 'Animal bite - wound',
            'Pruritic dermatitis', 'Prurigo', 'Atopic dermatitis', 'Scab',
        }),
    },
    'lichenification': {
        'name': 'Lichenification (Skin Thickening)',
        'conditions': frozenset({
            'Lichenified eczematous dermatitis', 'Acanthosis nigricans',
            'Hyperkeratosis of skin', 'HK - Hyperkeratosis', 'Keratoderma',
            'Ichthyosis', 'Ichthyosiform dermatosis', 'Chronic eczema',
            'Psoriasis', 'Xerosis',
        }),
    },
}

# Maximum score per EASI component.
_MAX_COMPONENT_SCORE = 3


def _probability_to_score(prob):
    """Map a condition probability to an EASI intensity score in 0..3."""
    if prob < 0.171:
        return 0
    if prob < 0.238:
        return 1
    if prob < 0.421:
        return 2
    return 3


def calculate_easi_scores(predictions):
    """Compute per-component EASI scores from model predictions.

    Args:
        predictions: Dict returned by DermFoundationNeuralNetwork.predict;
            only 'all_condition_probabilities' is read.

    Returns:
        (easi_results, total_easi). `easi_results` maps each component key to
        a dict with its 'name', 'score' (sum of member scores, capped at 3)
        and 'contributing_conditions' (members scoring > 0, highest
        probability first). `total_easi` is the sum of component scores.
    """
    all_condition_probs = predictions['all_condition_probabilities']
    easi_results = {}

    for component, category_info in EASI_CATEGORIES.items():
        members = category_info['conditions']
        category_conditions = []
        for condition_name, probability in all_condition_probs.items():
            # The generic 'eczema' label is never counted toward a component.
            if condition_name.lower() == 'eczema' or condition_name not in members:
                continue
            score = _probability_to_score(probability)
            if score > 0:
                category_conditions.append({
                    'condition': condition_name,
                    'probability': probability,
                    'individual_score': score,
                })

        category_conditions.sort(key=lambda c: c['probability'], reverse=True)
        component_score = min(
            sum(c['individual_score'] for c in category_conditions),
            _MAX_COMPONENT_SCORE,
        )

        easi_results[component] = {
            'name': category_info['name'],
            'score': component_score,
            'contributing_conditions': category_conditions,
        }

    total_easi = sum(result['score'] for result in easi_results.values())
    return easi_results, total_easi


def get_severity_interpretation(total_easi):
    """Return a human-readable severity band for a total EASI score (0-12)."""
    if total_easi == 0:
        return "No significant EASI features detected"
    elif total_easi <= 3:
        return "Mild EASI severity"
    elif total_easi <= 6:
        return "Moderate EASI severity"
    elif total_easi <= 9:
        return "Severe EASI severity"
    else:
        return "Very Severe EASI severity"


def _index_easi_contributions(easi_results):
    """Map each contributing condition to (component display name, individual score).

    A condition can contribute to several components (e.g. 'Psoriasis'). The
    first component in EASI_CATEGORIES order is reported.
    """
    lookup = {}
    for result in easi_results.values():
        for contrib in result['contributing_conditions']:
            lookup.setdefault(
                contrib['condition'], (result['name'], contrib['individual_score'])
            )
    return lookup


# Image processing functions
def smart_crop_to_square(image):
    """Center-crop a PIL image to a square whose side is its shorter dimension."""
    width, height = image.size
    if width == height:
        return image
    size = min(width, height)
    left = (width - size) // 2
    top = (height - size) // 2
    return image.crop((left, top, left + size, top + size))


def generate_derm_foundation_embedding(model, image):
    """Compute the Derm Foundation embedding for a PIL image.

    The image is re-encoded as JPEG and wrapped in a serialized
    tf.train.Example under 'image/encoded', which is passed to the model's
    'serving_default' signature.

    Returns:
        Flattened numpy embedding. The 'embedding' output is used when
        present, otherwise the first output key.

    Raises:
        HTTPException(500): if encoding or inference fails.
    """
    try:
        if image.mode != 'RGB':
            image = image.convert('RGB')

        buf = BytesIO()
        image.save(buf, format='JPEG')
        image_bytes = buf.getvalue()

        input_tensor = tf.train.Example(features=tf.train.Features(
            feature={'image/encoded': tf.train.Feature(
                bytes_list=tf.train.BytesList(value=[image_bytes]))
            })).SerializeToString()

        infer = model.signatures["serving_default"]
        output = infer(inputs=tf.constant([input_tensor]))

        if 'embedding' in output:
            embedding_vector = output['embedding'].numpy().flatten()
        else:
            key = list(output.keys())[0]
            embedding_vector = output[key].numpy().flatten()

        return embedding_vector

    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error generating embedding: {str(e)}")


def _derm_foundation_files_present(model_dir):
    """Return True only if every Derm Foundation SavedModel file exists in model_dir.

    Checking the individual variable files (not just the 'variables'
    directory) makes an interrupted download get retried on the next start.
    """
    return all(
        os.path.exists(os.path.join(model_dir, relative_path))
        for relative_path in DERM_FOUNDATION_FILES
    )


# Global model instances (populated by the startup hook)
derm_model = None
easi_model = None
deployment_platform = "huggingface_spaces"


@app.on_event("startup")
async def load_models():
    """Download (if needed) and load both models when the server starts.

    Failures are printed rather than raised, so the server still starts. The
    endpoints then report a degraded state or respond with 503.
    """
    global derm_model, easi_model, deployment_platform

    # Force garbage collection before starting
    gc.collect()

    print("\n" + "=" * 80)
    print("🚀 STARTING EASI API ON HUGGING FACE SPACES")
    print("=" * 80)

    # Detect if running on HF Spaces
    space_id = os.environ.get("SPACE_ID")
    space_author = os.environ.get("SPACE_AUTHOR_NAME")
    space_host = os.environ.get("SPACE_HOST")

    if space_id:
        deployment_platform = f"huggingface_spaces ({space_id})"
        print(f"📍 Space: {space_id}")
        print(f"👤 Author: {space_author}")
        print(f"🌐 Host: {space_host}")
    else:
        deployment_platform = "local"
        print("📍 Running locally")

    print("=" * 80)

    # Check HF Token
    hf_token = _get_hf_token()
    if hf_token:
        print(f"✓ HF Token configured (length: {len(hf_token)})")
    else:
        print("⚠ No HF Token found")
        print(" If model download fails, add HF_TOKEN to Space secrets")

    print("=" * 80)

    # Make sure the Derm Foundation files are present, downloading if needed.
    derm_files_ready = _derm_foundation_files_present(DERM_FOUNDATION_PATH)
    if derm_files_ready:
        print("\n✓ Derm Foundation model found locally (using cache)")
    else:
        print("\n📥 Derm Foundation model not found - downloading from Hugging Face...")
        derm_files_ready = download_derm_foundation_from_hf(DERM_FOUNDATION_PATH)
        if not derm_files_ready:
            # Keep going: the EASI model is loaded independently below, so
            # /health and /conditions still report accurately.
            print("\n❌ CRITICAL: Failed to download Derm Foundation model!")
            print(" API will not function correctly.")

    # Load Derm Foundation model
    if derm_files_ready:
        print("\n" + "=" * 80)
        print("📦 LOADING DERM FOUNDATION MODEL")
        print("=" * 80)

        try:
            print(f"Loading from: {DERM_FOUNDATION_PATH}")
            gc.collect()
            derm_model = tf.saved_model.load(DERM_FOUNDATION_PATH)
            print("✓ Derm Foundation model loaded successfully!")
            gc.collect()
        except Exception as e:
            print(f"✗ Failed to load Derm Foundation model: {str(e)}")
            traceback.print_exc()

    # Load EASI model
    print("\n" + "=" * 80)
    print("📦 LOADING EASI PREDICTION MODEL")
    print("=" * 80)

    if os.path.exists(EASI_MODEL_PATH):
        easi_model = DermFoundationNeuralNetwork()
        if easi_model.load_model(EASI_MODEL_PATH):
            print(f"✓ EASI model loaded from: {EASI_MODEL_PATH}")
            print(f" Available conditions: {len(easi_model.mlb.classes_)}")
        else:
            print("✗ Failed to load EASI model")
            easi_model = None
    else:
        print(f"✗ EASI model not found at: {EASI_MODEL_PATH}")
        print(" Make sure trained_model/ folder is included in your Space")

    # Final status
    derm_ok = derm_model is not None
    easi_ok = easi_model is not None
    print("\n" + "=" * 80)
    print("🏁 STARTUP COMPLETE")
    print("=" * 80)
    print(f"Derm Foundation Model: {'✓ Loaded' if derm_ok else '✗ Failed'}")
    print(f"EASI Prediction Model: {'✓ Loaded' if easi_ok else '✗ Failed'}")
    print(f"Platform: {deployment_platform}")
    print("=" * 80)

    if derm_ok and easi_ok:
        print("✅ All systems ready! API is operational.")
    else:
        print("⚠️ WARNING: Some models failed to load. API may not work correctly.")

    print("=" * 80 + "\n")


# API Endpoints
@app.get("/")
async def root():
    """Root endpoint with API information"""
    space_info = {
        "space_id": os.environ.get("SPACE_ID", "local"),
        "space_author": os.environ.get("SPACE_AUTHOR_NAME", "unknown"),
        "space_host": os.environ.get("SPACE_HOST", "localhost"),
    }

    models_ready = derm_model is not None and easi_model is not None
    return {
        "message": "EASI Severity Prediction API",
        "version": "2.0.0",
        "platform": deployment_platform,
        "space_info": space_info,
        "status": "operational" if models_ready else "degraded",
        "endpoints": {
            "health": "/health",
            "predict": "/predict",
            "conditions": "/conditions",
            "docs": "/docs",
            "redoc": "/redoc",
        },
        "documentation": "Visit /docs for interactive API documentation",
    }


@app.get("/health", response_model=HealthResponse)
async def health_check():
    """Health check endpoint"""
    space_info = None
    if os.environ.get("SPACE_ID"):
        space_info = {
            "space_id": os.environ.get("SPACE_ID"),
            "space_author": os.environ.get("SPACE_AUTHOR_NAME"),
            "space_host": os.environ.get("SPACE_HOST"),
        }

    return {
        "status": "healthy" if (derm_model is not None and easi_model is not None) else "degraded",
        "models_loaded": {
            "derm_foundation": derm_model is not None,
            "easi_model": easi_model is not None,
        },
        "available_conditions": len(easi_model.mlb.classes_) if easi_model is not None else 0,
        "hf_token_configured": bool(_get_hf_token()),
        "deployment_platform": deployment_platform,
        "space_info": space_info,
    }


# The previous response_model (Dict[str, List[str]]) rejected the integer
# 'total_count', so every call failed response validation with a 500.
@app.get("/conditions", response_model=ConditionsResponse)
async def get_conditions():
    """Get list of available conditions"""
    if easi_model is None:
        raise HTTPException(
            status_code=503,
            detail="EASI model not loaded. Check server logs or /health endpoint.",
        )

    return {
        "conditions": easi_model.mlb.classes_.tolist(),
        "total_count": len(easi_model.mlb.classes_),
    }


@app.post("/predict", response_model=PredictionResponse)
async def predict_easi(
    file: UploadFile = File(..., description="Skin image file (JPG, JPEG, PNG)")
):
    """
    Predict EASI scores from uploaded skin image.

    - **file**: Image file (JPG, JPEG, PNG)
    - Returns: EASI scores, component breakdown, and condition predictions
    """
    # NOTE(review): inference below is CPU-bound and runs on the event loop,
    # so concurrent requests are processed one at a time.

    # Validate models loaded
    if derm_model is None or easi_model is None:
        error_detail = []
        if derm_model is None:
            error_detail.append("Derm Foundation model not loaded")
        if easi_model is None:
            error_detail.append("EASI model not loaded")

        raise HTTPException(
            status_code=503,
            detail=f"Models not available: {', '.join(error_detail)}. Check /health endpoint for details.",
        )

    # Validate file type (client-declared Content-Type)
    if not file.content_type or not file.content_type.startswith('image/'):
        raise HTTPException(
            status_code=400,
            detail="File must be an image (JPG, JPEG, PNG). Received: " + str(file.content_type),
        )

    try:
        # Read and decode image. An empty, corrupt or unsupported upload is a
        # client error; PIL reports these as OSError (incl. UnidentifiedImageError).
        image_bytes = await file.read()
        try:
            original_image = Image.open(BytesIO(image_bytes)).convert('RGB')
        except (OSError, Image.DecompressionBombError) as image_error:
            raise HTTPException(
                status_code=400,
                detail=f"Could not decode image: {image_error}",
            ) from image_error
        original_size = original_image.size

        # Process to a fixed-size square
        cropped_img = smart_crop_to_square(original_image)
        processed_img = cropped_img.resize(
            (PROCESSED_IMAGE_SIZE, PROCESSED_IMAGE_SIZE), Image.Resampling.LANCZOS
        )

        # Generate embedding and predict
        embedding = generate_derm_foundation_embedding(derm_model, processed_img)
        predictions = easi_model.predict(embedding)
        if predictions is None:
            raise HTTPException(status_code=500, detail="Prediction failed - model returned None")

        # Calculate EASI scores
        easi_results, total_easi = calculate_easi_scores(predictions)
        severity = get_severity_interpretation(total_easi)
        easi_lookup = _index_easi_contributions(easi_results)

        # Format predicted conditions
        condition_names = predictions['dermatologist_skin_condition_on_label_name']
        condition_confidences = predictions['dermatologist_skin_condition_confidence']
        condition_weights = predictions['weighted_skin_condition_label']

        predicted_conditions = []
        for condition, conf in zip(condition_names, condition_confidences):
            easi_category, easi_contribution = easi_lookup.get(condition, (None, 0))
            predicted_conditions.append(ConditionPrediction(
                condition=condition,
                probability=float(predictions['all_condition_probabilities'][condition]),
                confidence=float(conf),
                weight=float(condition_weights[condition]),
                easi_category=easi_category,
                easi_contribution=easi_contribution,
            ))

        # Summary statistics
        summary_stats = {
            "total_conditions": len(predicted_conditions),
            "average_confidence": float(np.mean(condition_confidences)) if predicted_conditions else 0.0,
            "average_weight": float(np.mean(list(condition_weights.values()))) if predicted_conditions else 0.0,
            "total_weight": float(sum(condition_weights.values())),
        }

        # Format EASI components
        easi_components_formatted = {
            component: EASIComponent(
                name=result['name'],
                score=result['score'],
                contributing_conditions=result['contributing_conditions'],
            )
            for component, result in easi_results.items()
        }

        return PredictionResponse(
            success=True,
            total_easi_score=total_easi,
            severity_interpretation=severity,
            easi_components=easi_components_formatted,
            predicted_conditions=predicted_conditions,
            summary_statistics=summary_stats,
            image_info={
                "original_size": f"{original_size[0]}x{original_size[1]}",
                "processed_size": f"{PROCESSED_IMAGE_SIZE}x{PROCESSED_IMAGE_SIZE}",
                "filename": file.filename,
            },
        )

    except HTTPException:
        raise
    except Exception as e:
        print(f"Error processing image: {str(e)}")
        print(traceback.format_exc())
        raise HTTPException(
            status_code=500,
            detail=f"Error processing image: {str(e)}",
        )


# Registered for Starlette's HTTPException (FastAPI's subclasses it) so that
# framework-raised errors such as 404/405 also use the ErrorResponse format.
@app.exception_handler(StarletteHTTPException)
async def http_exception_handler(request, exc):
    """Custom HTTP exception handler"""
    # ErrorResponse.error is a str; coerce any structured detail.
    error = exc.detail if isinstance(exc.detail, str) else str(exc.detail)
    return JSONResponse(
        status_code=exc.status_code,
        content=jsonable_encoder(ErrorResponse(error=error, detail=str(exc))),
        # Preserve headers attached to the exception (e.g. WWW-Authenticate).
        headers=getattr(exc, "headers", None),
    )


@app.exception_handler(Exception)
async def general_exception_handler(request, exc):
    """General exception handler for unexpected errors"""
    print(f"Unexpected error: {str(exc)}")
    print("".join(traceback.format_exception(type(exc), exc, exc.__traceback__)))

    # NOTE(review): str(exc) is returned to the client and may expose internals.
    return JSONResponse(
        status_code=500,
        content=jsonable_encoder(ErrorResponse(error="Internal server error", detail=str(exc))),
    )


if __name__ == "__main__":
    import uvicorn

    print("=" * 80)
    print("🚀 Starting EASI API Server")
    print("=" * 80)
    print("Access the API at: http://localhost:8000")
    print("Interactive docs: http://localhost:8000/docs")
    print("=" * 80)

    uvicorn.run(
        app,
        host="0.0.0.0",
        port=8000,
        log_level="info",
    )