RezinWiz commited on
Commit
291b2db
Β·
verified Β·
1 Parent(s): 7b883e9

multiple images intake

Browse files
Files changed (1) hide show
  1. api.py +334 -134
api.py CHANGED
@@ -1,12 +1,17 @@
1
  """
2
- EASI Severity Prediction REST API
3
- ==================================
4
-
5
  FastAPI-based REST API for predicting EASI scores from dermatological images.
6
- Optimized for Hugging Face Spaces deployment.
 
 
 
 
 
7
 
8
  Endpoints:
9
- - POST /predict - Upload image and get EASI predictions
 
10
  - GET /health - Health check endpoint
11
  - GET /conditions - Get list of available conditions
12
  - GET /docs - Interactive API documentation
@@ -16,12 +21,6 @@ pip install fastapi uvicorn python-multipart pillow tensorflow numpy pandas hugg
16
 
17
  Run locally:
18
  uvicorn api:app --host 0.0.0.0 --port 8000 --reload
19
-
20
- Deploy to HF Spaces:
21
- 1. Create Space with Docker SDK
22
- 2. Upload this file + Dockerfile + requirements.txt + trained_model/
23
- 3. Accept terms for google/derm-foundation
24
- 4. Space auto-builds!
25
  """
26
 
27
  import os
@@ -30,6 +29,8 @@ import logging
30
  from typing import List, Dict, Any, Optional
31
  from io import BytesIO
32
  from pathlib import Path
 
 
33
 
34
  # Suppress warnings
35
  os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
@@ -52,31 +53,34 @@ import pickle
52
  import pandas as pd
53
  from huggingface_hub import hf_hub_download
54
 
 
 
 
 
 
 
 
 
55
  # Initialize FastAPI app
56
  app = FastAPI(
57
  title="EASI Severity Prediction API",
58
- description="REST API for predicting EASI scores from skin images. Deployed on Hugging Face Spaces.",
59
- version="2.0.0",
60
  docs_url="/docs",
61
  redoc_url="/redoc"
62
  )
63
 
64
- # CORS middleware for Flutter web/mobile
65
  app.add_middleware(
66
  CORSMiddleware,
67
- allow_origins=["*"], # In production, specify your Flutter app domain
68
  allow_credentials=True,
69
  allow_methods=["*"],
70
  allow_headers=["*"],
71
  )
72
 
73
- # Configuration
74
- HF_REPO_ID = "google/derm-foundation"
75
- DERM_FOUNDATION_PATH = "./derm_foundation/"
76
- EASI_MODEL_PATH = './trained_model/easi_severity_model_derm_foundation_individual.pkl'
77
-
78
- # HF Spaces automatically injects HF_TOKEN for authenticated users
79
- HF_TOKEN = os.environ.get("HF_TOKEN") or os.environ.get("HUGGINGFACE_TOKEN")
80
 
81
  # Response Models
82
  class ConditionPrediction(BaseModel):
@@ -101,12 +105,23 @@ class PredictionResponse(BaseModel):
101
  summary_statistics: Dict[str, float]
102
  image_info: Dict[str, Any]
103
 
 
 
 
 
 
 
 
 
 
104
  class HealthResponse(BaseModel):
105
  status: str
106
  models_loaded: Dict[str, bool]
107
  available_conditions: int
108
  hf_token_configured: bool
109
  deployment_platform: str
 
 
110
  space_info: Optional[Dict[str, str]] = None
111
 
112
  class ErrorResponse(BaseModel):
@@ -134,24 +149,14 @@ class DermFoundationNeuralNetwork:
134
  self.confidence_scaler = model_data['confidence_scaler']
135
  self.weighted_scaler = model_data['weighted_scaler']
136
 
137
- # Get the original keras model path from pickle
138
  keras_model_path = model_data['keras_model_path']
139
 
140
- # If the path doesn't exist, try looking in the same directory as the pickle file
141
  if not os.path.exists(keras_model_path):
142
  print(f"Original keras path not found: {keras_model_path}")
143
-
144
- # Get the directory where the pickle file is located
145
  pickle_dir = os.path.dirname(os.path.abspath(filepath))
146
-
147
- # Extract just the filename, handling both Windows and Unix paths
148
- # Replace backslashes with forward slashes first
149
  normalized_path = keras_model_path.replace('\\', '/')
150
  keras_filename = normalized_path.split('/')[-1]
151
-
152
  print(f"Extracted filename: {keras_filename}")
153
-
154
- # Try looking for it in the same directory as the pickle
155
  alternative_path = os.path.join(pickle_dir, keras_filename)
156
  print(f"Trying alternative path: {alternative_path}")
157
 
@@ -160,16 +165,10 @@ class DermFoundationNeuralNetwork:
160
  print(f"βœ“ Found keras model at: {keras_model_path}")
161
  else:
162
  print(f"βœ— Keras model not found at alternative path either")
163
- print(f"Files in {pickle_dir}:")
164
- try:
165
- print(os.listdir(pickle_dir))
166
- except:
167
- pass
168
  return False
169
  else:
170
  print(f"βœ“ Found keras model at original path: {keras_model_path}")
171
 
172
- # Load the keras model
173
  self.model = tf.keras.models.load_model(keras_model_path)
174
  print(f"βœ“ Keras model loaded successfully")
175
  return True
@@ -253,7 +252,6 @@ class DermFoundationNeuralNetwork:
253
  def download_derm_foundation_from_hf(output_dir):
254
  """Download Derm Foundation model from Hugging Face Hub"""
255
  try:
256
- # Get token - on HF Spaces it's auto-injected
257
  hf_token = os.environ.get("HF_TOKEN") or os.environ.get("HUGGINGFACE_TOKEN")
258
 
259
  print("=" * 80)
@@ -264,13 +262,9 @@ def download_derm_foundation_from_hf(output_dir):
264
  print(f"βœ“ HF Token found (length: {len(hf_token)})")
265
  else:
266
  print("⚠ No HF Token found - attempting anonymous download")
267
- print(" Note: If this fails, you need to:")
268
- print(" 1. Accept terms at https://huggingface.co/google/derm-foundation")
269
- print(" 2. Add HF_TOKEN to Space secrets")
270
 
271
  os.makedirs(output_dir, exist_ok=True)
272
 
273
- # Files to download
274
  files_to_download = [
275
  "saved_model.pb",
276
  "variables/variables.data-00000-of-00001",
@@ -290,7 +284,6 @@ def download_derm_foundation_from_hf(output_dir):
290
  resume_download=True
291
  )
292
 
293
- # Verify file exists and get size
294
  if os.path.exists(downloaded_path):
295
  file_size_mb = os.path.getsize(downloaded_path) / (1024 * 1024)
296
  print(f"βœ“ Downloaded successfully ({file_size_mb:.2f} MB)")
@@ -313,11 +306,6 @@ def download_derm_foundation_from_hf(output_dir):
313
  print("βœ— ERROR DOWNLOADING MODEL")
314
  print("=" * 80)
315
  print(f"Error: {str(e)}")
316
- print("\nTroubleshooting steps:")
317
- print("1. Ensure you've accepted the model terms at:")
318
- print(" https://huggingface.co/google/derm-foundation")
319
- print("2. Add HF_TOKEN to your Space secrets (Settings β†’ Repository secrets)")
320
- print("3. Make sure your token has 'Read access to gated repos' permission")
321
  import traceback
322
  traceback.print_exc()
323
  return False
@@ -477,7 +465,115 @@ def generate_derm_foundation_embedding(model, image):
477
 
478
  return embedding_vector
479
  except Exception as e:
480
- raise HTTPException(status_code=500, detail=f"Error generating embedding: {str(e)}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
481
 
482
 
483
  # Global model instances
@@ -491,12 +587,11 @@ async def load_models():
491
  """Load models on startup"""
492
  global derm_model, easi_model, deployment_platform
493
 
494
- # Force garbage collection before starting
495
  import gc
496
  gc.collect()
497
 
498
  print("\n" + "=" * 80)
499
- print("πŸš€ STARTING EASI API ON HUGGING FACE SPACES")
500
  print("=" * 80)
501
 
502
  # Detect if running on HF Spaces
@@ -513,6 +608,8 @@ async def load_models():
513
  deployment_platform = "local"
514
  print("πŸ“ Running locally")
515
 
 
 
516
  print("=" * 80)
517
 
518
  # Check HF Token
@@ -521,7 +618,6 @@ async def load_models():
521
  print(f"βœ“ HF Token configured (length: {len(hf_token)})")
522
  else:
523
  print("⚠ No HF Token found")
524
- print(" If model download fails, add HF_TOKEN to Space secrets")
525
 
526
  print("=" * 80)
527
 
@@ -537,7 +633,6 @@ async def load_models():
537
 
538
  if not success:
539
  print("\n❌ CRITICAL: Failed to download Derm Foundation model!")
540
- print(" API will not function correctly.")
541
  return
542
  else:
543
  print("\nβœ“ Derm Foundation model found locally (using cache)")
@@ -577,7 +672,6 @@ async def load_models():
577
  easi_model = None
578
  else:
579
  print(f"βœ— EASI model not found at: {EASI_MODEL_PATH}")
580
- print(" Make sure trained_model/ folder is included in your Space")
581
 
582
  # Final status
583
  print("\n" + "=" * 80)
@@ -585,13 +679,14 @@ async def load_models():
585
  print("=" * 80)
586
  print(f"Derm Foundation Model: {'βœ“ Loaded' if derm_model else 'βœ— Failed'}")
587
  print(f"EASI Prediction Model: {'βœ“ Loaded' if easi_model else 'βœ— Failed'}")
 
588
  print(f"Platform: {deployment_platform}")
589
  print("=" * 80)
590
 
591
  if derm_model and easi_model:
592
  print("βœ… All systems ready! API is operational.")
593
  else:
594
- print("⚠️ WARNING: Some models failed to load. API may not work correctly.")
595
 
596
  print("=" * 80 + "\n")
597
 
@@ -608,14 +703,20 @@ async def root():
608
  }
609
 
610
  return {
611
- "message": "EASI Severity Prediction API",
612
- "version": "2.0.0",
613
  "platform": deployment_platform,
614
  "space_info": space_info,
615
  "status": "operational" if (derm_model and easi_model) else "degraded",
 
 
 
 
 
616
  "endpoints": {
617
  "health": "/health",
618
- "predict": "/predict",
 
619
  "conditions": "/conditions",
620
  "docs": "/docs",
621
  "redoc": "/redoc"
@@ -644,6 +745,8 @@ async def health_check():
644
  "available_conditions": len(easi_model.mlb.classes_) if easi_model else 0,
645
  "hf_token_configured": (os.environ.get("HF_TOKEN") or os.environ.get("HUGGINGFACE_TOKEN")) is not None,
646
  "deployment_platform": deployment_platform,
 
 
647
  "space_info": space_info
648
  }
649
 
@@ -695,85 +798,19 @@ async def predict_easi(
695
  )
696
 
697
  try:
698
- # Read and process image
699
  image_bytes = await file.read()
700
- original_image = Image.open(BytesIO(image_bytes)).convert('RGB')
701
- original_size = original_image.size
702
 
703
- # Process to 448x448
704
- cropped_img = smart_crop_to_square(original_image)
705
- processed_img = cropped_img.resize((448, 448), Image.Resampling.LANCZOS)
706
 
707
- # Generate embedding
708
- embedding = generate_derm_foundation_embedding(derm_model, processed_img)
709
-
710
- # Make prediction
711
- predictions = easi_model.predict(embedding)
712
-
713
- if predictions is None:
714
- raise HTTPException(status_code=500, detail="Prediction failed - model returned None")
715
-
716
- # Calculate EASI scores
717
- easi_results, total_easi = calculate_easi_scores(predictions)
718
- severity = get_severity_interpretation(total_easi)
719
-
720
- # Format predicted conditions
721
- predicted_conditions = []
722
- for i, condition in enumerate(predictions['dermatologist_skin_condition_on_label_name']):
723
- prob = predictions['all_condition_probabilities'][condition]
724
- conf = predictions['dermatologist_skin_condition_confidence'][i]
725
- weight = predictions['weighted_skin_condition_label'][condition]
726
-
727
- # Find EASI category
728
- easi_category = None
729
- easi_contribution = 0
730
- for cat_key, cat_info in easi_results.items():
731
- for contrib in cat_info['contributing_conditions']:
732
- if contrib['condition'] == condition:
733
- easi_category = cat_info['name']
734
- easi_contribution = contrib['individual_score']
735
- break
736
-
737
- predicted_conditions.append(ConditionPrediction(
738
- condition=condition,
739
- probability=float(prob),
740
- confidence=float(conf),
741
- weight=float(weight),
742
- easi_category=easi_category,
743
- easi_contribution=easi_contribution
744
- ))
745
-
746
- # Summary statistics
747
- summary_stats = {
748
- "total_conditions": len(predicted_conditions),
749
- "average_confidence": float(np.mean(predictions['dermatologist_skin_condition_confidence'])) if predicted_conditions else 0.0,
750
- "average_weight": float(np.mean(list(predictions['weighted_skin_condition_label'].values()))) if predicted_conditions else 0.0,
751
- "total_weight": float(sum(predictions['weighted_skin_condition_label'].values()))
752
- }
753
-
754
- # Format EASI components
755
- easi_components_formatted = {
756
- component: EASIComponent(
757
- name=result['name'],
758
- score=result['score'],
759
- contributing_conditions=result['contributing_conditions']
760
  )
761
- for component, result in easi_results.items()
762
- }
763
 
764
- return PredictionResponse(
765
- success=True,
766
- total_easi_score=total_easi,
767
- severity_interpretation=severity,
768
- easi_components=easi_components_formatted,
769
- predicted_conditions=predicted_conditions,
770
- summary_statistics=summary_stats,
771
- image_info={
772
- "original_size": f"{original_size[0]}x{original_size[1]}",
773
- "processed_size": "448x448",
774
- "filename": file.filename
775
- }
776
- )
777
 
778
  except HTTPException:
779
  raise
@@ -789,6 +826,166 @@ async def predict_easi(
789
  )
790
 
791
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
792
  @app.exception_handler(HTTPException)
793
  async def http_exception_handler(request, exc):
794
  """Custom HTTP exception handler"""
@@ -822,7 +1019,10 @@ if __name__ == "__main__":
822
  import uvicorn
823
 
824
  print("=" * 80)
825
- print("πŸš€ Starting EASI API Server")
 
 
 
826
  print("=" * 80)
827
  print("Access the API at: http://localhost:8000")
828
  print("Interactive docs: http://localhost:8000/docs")
 
1
  """
2
+ EASI Severity Prediction REST API with Batch Processing
3
+ ========================================================
 
4
  FastAPI-based REST API for predicting EASI scores from dermatological images.
5
+ Now supports both single and batch image processing!
6
+
7
+ New Features:
8
+ - POST /predict/batch - Process multiple images in one request
9
+ - Configurable max batch size and timeout
10
+ - Parallel processing for faster batch predictions
11
 
12
  Endpoints:
13
+ - POST /predict - Upload single image and get EASI predictions
14
+ - POST /predict/batch - Upload multiple images (up to 10 at once)
15
  - GET /health - Health check endpoint
16
  - GET /conditions - Get list of available conditions
17
  - GET /docs - Interactive API documentation
 
21
 
22
  Run locally:
23
  uvicorn api:app --host 0.0.0.0 --port 8000 --reload
 
 
 
 
 
 
24
  """
25
 
26
  import os
 
29
  from typing import List, Dict, Any, Optional
30
  from io import BytesIO
31
  from pathlib import Path
32
+ import asyncio
33
+ from concurrent.futures import ThreadPoolExecutor
34
 
35
  # Suppress warnings
36
  os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
 
53
  import pandas as pd
54
  from huggingface_hub import hf_hub_download
55
 
56
+ # Configuration
57
+ MAX_BATCH_SIZE = 10 # Maximum images per batch request
58
+ BATCH_TIMEOUT = 300 # Timeout in seconds for batch processing
59
+ HF_REPO_ID = "google/derm-foundation"
60
+ DERM_FOUNDATION_PATH = "./derm_foundation/"
61
+ EASI_MODEL_PATH = './trained_model/easi_severity_model_derm_foundation_individual.pkl'
62
+ HF_TOKEN = os.environ.get("HF_TOKEN") or os.environ.get("HUGGINGFACE_TOKEN")
63
+
64
  # Initialize FastAPI app
65
  app = FastAPI(
66
  title="EASI Severity Prediction API",
67
+ description="REST API for predicting EASI scores from skin images. Supports single and batch processing.",
68
+ version="2.1.0",
69
  docs_url="/docs",
70
  redoc_url="/redoc"
71
  )
72
 
73
+ # CORS middleware
74
  app.add_middleware(
75
  CORSMiddleware,
76
+ allow_origins=["*"],
77
  allow_credentials=True,
78
  allow_methods=["*"],
79
  allow_headers=["*"],
80
  )
81
 
82
+ # Thread pool for parallel processing
83
+ executor = ThreadPoolExecutor(max_workers=4)
 
 
 
 
 
84
 
85
  # Response Models
86
  class ConditionPrediction(BaseModel):
 
105
  summary_statistics: Dict[str, float]
106
  image_info: Dict[str, Any]
107
 
108
+ class BatchPredictionResponse(BaseModel):
109
+ success: bool
110
+ total_images_processed: int
111
+ successful_predictions: int
112
+ failed_predictions: int
113
+ results: List[Optional[PredictionResponse]]
114
+ errors: List[Optional[str]]
115
+ processing_time_seconds: float
116
+
117
  class HealthResponse(BaseModel):
118
  status: str
119
  models_loaded: Dict[str, bool]
120
  available_conditions: int
121
  hf_token_configured: bool
122
  deployment_platform: str
123
+ batch_processing_enabled: bool
124
+ max_batch_size: int
125
  space_info: Optional[Dict[str, str]] = None
126
 
127
  class ErrorResponse(BaseModel):
 
149
  self.confidence_scaler = model_data['confidence_scaler']
150
  self.weighted_scaler = model_data['weighted_scaler']
151
 
 
152
  keras_model_path = model_data['keras_model_path']
153
 
 
154
  if not os.path.exists(keras_model_path):
155
  print(f"Original keras path not found: {keras_model_path}")
 
 
156
  pickle_dir = os.path.dirname(os.path.abspath(filepath))
 
 
 
157
  normalized_path = keras_model_path.replace('\\', '/')
158
  keras_filename = normalized_path.split('/')[-1]
 
159
  print(f"Extracted filename: {keras_filename}")
 
 
160
  alternative_path = os.path.join(pickle_dir, keras_filename)
161
  print(f"Trying alternative path: {alternative_path}")
162
 
 
165
  print(f"βœ“ Found keras model at: {keras_model_path}")
166
  else:
167
  print(f"βœ— Keras model not found at alternative path either")
 
 
 
 
 
168
  return False
169
  else:
170
  print(f"βœ“ Found keras model at original path: {keras_model_path}")
171
 
 
172
  self.model = tf.keras.models.load_model(keras_model_path)
173
  print(f"βœ“ Keras model loaded successfully")
174
  return True
 
252
  def download_derm_foundation_from_hf(output_dir):
253
  """Download Derm Foundation model from Hugging Face Hub"""
254
  try:
 
255
  hf_token = os.environ.get("HF_TOKEN") or os.environ.get("HUGGINGFACE_TOKEN")
256
 
257
  print("=" * 80)
 
262
  print(f"βœ“ HF Token found (length: {len(hf_token)})")
263
  else:
264
  print("⚠ No HF Token found - attempting anonymous download")
 
 
 
265
 
266
  os.makedirs(output_dir, exist_ok=True)
267
 
 
268
  files_to_download = [
269
  "saved_model.pb",
270
  "variables/variables.data-00000-of-00001",
 
284
  resume_download=True
285
  )
286
 
 
287
  if os.path.exists(downloaded_path):
288
  file_size_mb = os.path.getsize(downloaded_path) / (1024 * 1024)
289
  print(f"βœ“ Downloaded successfully ({file_size_mb:.2f} MB)")
 
306
  print("βœ— ERROR DOWNLOADING MODEL")
307
  print("=" * 80)
308
  print(f"Error: {str(e)}")
 
 
 
 
 
309
  import traceback
310
  traceback.print_exc()
311
  return False
 
465
 
466
  return embedding_vector
467
  except Exception as e:
468
+ raise Exception(f"Error generating embedding: {str(e)}")
469
+
470
+
471
+ def process_single_image_sync(image_bytes: bytes, filename: str) -> Dict[str, Any]:
472
+ """
473
+ Synchronous function to process a single image.
474
+ Returns dict with 'success', 'result', and 'error' keys.
475
+ """
476
+ try:
477
+ # Read and process image
478
+ original_image = Image.open(BytesIO(image_bytes)).convert('RGB')
479
+ original_size = original_image.size
480
+
481
+ # Process to 448x448
482
+ cropped_img = smart_crop_to_square(original_image)
483
+ processed_img = cropped_img.resize((448, 448), Image.Resampling.LANCZOS)
484
+
485
+ # Generate embedding
486
+ embedding = generate_derm_foundation_embedding(derm_model, processed_img)
487
+
488
+ # Make prediction
489
+ predictions = easi_model.predict(embedding)
490
+
491
+ if predictions is None:
492
+ return {
493
+ 'success': False,
494
+ 'result': None,
495
+ 'error': "Prediction failed - model returned None"
496
+ }
497
+
498
+ # Calculate EASI scores
499
+ easi_results, total_easi = calculate_easi_scores(predictions)
500
+ severity = get_severity_interpretation(total_easi)
501
+
502
+ # Format predicted conditions
503
+ predicted_conditions = []
504
+ for i, condition in enumerate(predictions['dermatologist_skin_condition_on_label_name']):
505
+ prob = predictions['all_condition_probabilities'][condition]
506
+ conf = predictions['dermatologist_skin_condition_confidence'][i]
507
+ weight = predictions['weighted_skin_condition_label'][condition]
508
+
509
+ # Find EASI category
510
+ easi_category = None
511
+ easi_contribution = 0
512
+ for cat_key, cat_info in easi_results.items():
513
+ for contrib in cat_info['contributing_conditions']:
514
+ if contrib['condition'] == condition:
515
+ easi_category = cat_info['name']
516
+ easi_contribution = contrib['individual_score']
517
+ break
518
+
519
+ predicted_conditions.append(ConditionPrediction(
520
+ condition=condition,
521
+ probability=float(prob),
522
+ confidence=float(conf),
523
+ weight=float(weight),
524
+ easi_category=easi_category,
525
+ easi_contribution=easi_contribution
526
+ ))
527
+
528
+ # Summary statistics
529
+ summary_stats = {
530
+ "total_conditions": len(predicted_conditions),
531
+ "average_confidence": float(np.mean(predictions['dermatologist_skin_condition_confidence'])) if predicted_conditions else 0.0,
532
+ "average_weight": float(np.mean(list(predictions['weighted_skin_condition_label'].values()))) if predicted_conditions else 0.0,
533
+ "total_weight": float(sum(predictions['weighted_skin_condition_label'].values()))
534
+ }
535
+
536
+ # Format EASI components
537
+ easi_components_formatted = {
538
+ component: EASIComponent(
539
+ name=result['name'],
540
+ score=result['score'],
541
+ contributing_conditions=result['contributing_conditions']
542
+ )
543
+ for component, result in easi_results.items()
544
+ }
545
+
546
+ result = PredictionResponse(
547
+ success=True,
548
+ total_easi_score=total_easi,
549
+ severity_interpretation=severity,
550
+ easi_components=easi_components_formatted,
551
+ predicted_conditions=predicted_conditions,
552
+ summary_statistics=summary_stats,
553
+ image_info={
554
+ "original_size": f"{original_size[0]}x{original_size[1]}",
555
+ "processed_size": "448x448",
556
+ "filename": filename
557
+ }
558
+ )
559
+
560
+ return {
561
+ 'success': True,
562
+ 'result': result,
563
+ 'error': None
564
+ }
565
+
566
+ except Exception as e:
567
+ import traceback
568
+ error_traceback = traceback.format_exc()
569
+ print(f"Error processing image {filename}: {str(e)}")
570
+ print(error_traceback)
571
+
572
+ return {
573
+ 'success': False,
574
+ 'result': None,
575
+ 'error': str(e)
576
+ }
577
 
578
 
579
  # Global model instances
 
587
  """Load models on startup"""
588
  global derm_model, easi_model, deployment_platform
589
 
 
590
  import gc
591
  gc.collect()
592
 
593
  print("\n" + "=" * 80)
594
+ print("πŸš€ STARTING EASI API WITH BATCH PROCESSING")
595
  print("=" * 80)
596
 
597
  # Detect if running on HF Spaces
 
608
  deployment_platform = "local"
609
  print("πŸ“ Running locally")
610
 
611
+ print(f"πŸ”’ Max batch size: {MAX_BATCH_SIZE}")
612
+ print(f"⏱️ Batch timeout: {BATCH_TIMEOUT}s")
613
  print("=" * 80)
614
 
615
  # Check HF Token
 
618
  print(f"βœ“ HF Token configured (length: {len(hf_token)})")
619
  else:
620
  print("⚠ No HF Token found")
 
621
 
622
  print("=" * 80)
623
 
 
633
 
634
  if not success:
635
  print("\n❌ CRITICAL: Failed to download Derm Foundation model!")
 
636
  return
637
  else:
638
  print("\nβœ“ Derm Foundation model found locally (using cache)")
 
672
  easi_model = None
673
  else:
674
  print(f"βœ— EASI model not found at: {EASI_MODEL_PATH}")
 
675
 
676
  # Final status
677
  print("\n" + "=" * 80)
 
679
  print("=" * 80)
680
  print(f"Derm Foundation Model: {'βœ“ Loaded' if derm_model else 'βœ— Failed'}")
681
  print(f"EASI Prediction Model: {'βœ“ Loaded' if easi_model else 'βœ— Failed'}")
682
+ print(f"Batch Processing: βœ“ Enabled (max {MAX_BATCH_SIZE} images)")
683
  print(f"Platform: {deployment_platform}")
684
  print("=" * 80)
685
 
686
  if derm_model and easi_model:
687
  print("βœ… All systems ready! API is operational.")
688
  else:
689
+ print("⚠️ WARNING: Some models failed to load.")
690
 
691
  print("=" * 80 + "\n")
692
 
 
703
  }
704
 
705
  return {
706
+ "message": "EASI Severity Prediction API with Batch Processing",
707
+ "version": "2.1.0",
708
  "platform": deployment_platform,
709
  "space_info": space_info,
710
  "status": "operational" if (derm_model and easi_model) else "degraded",
711
+ "batch_processing": {
712
+ "enabled": True,
713
+ "max_batch_size": MAX_BATCH_SIZE,
714
+ "timeout_seconds": BATCH_TIMEOUT
715
+ },
716
  "endpoints": {
717
  "health": "/health",
718
+ "predict": "/predict (single image)",
719
+ "predict_batch": "/predict/batch (multiple images)",
720
  "conditions": "/conditions",
721
  "docs": "/docs",
722
  "redoc": "/redoc"
 
745
  "available_conditions": len(easi_model.mlb.classes_) if easi_model else 0,
746
  "hf_token_configured": (os.environ.get("HF_TOKEN") or os.environ.get("HUGGINGFACE_TOKEN")) is not None,
747
  "deployment_platform": deployment_platform,
748
+ "batch_processing_enabled": True,
749
+ "max_batch_size": MAX_BATCH_SIZE,
750
  "space_info": space_info
751
  }
752
 
 
798
  )
799
 
800
  try:
801
+ # Read image bytes
802
  image_bytes = await file.read()
 
 
803
 
804
+ # Process image synchronously
805
+ result = process_single_image_sync(image_bytes, file.filename)
 
806
 
807
+ if not result['success']:
808
+ raise HTTPException(
809
+ status_code=500,
810
+ detail=f"Error processing image: {result['error']}"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
811
  )
 
 
812
 
813
+ return result['result']
 
 
 
 
 
 
 
 
 
 
 
 
814
 
815
  except HTTPException:
816
  raise
 
826
  )
827
 
828
 
829
+ @app.post("/predict/batch", response_model=BatchPredictionResponse)
830
+ async def predict_easi_batch(
831
+ files: List[UploadFile] = File(..., description=f"Multiple skin image files (max {MAX_BATCH_SIZE})")
832
+ ):
833
+ """
834
+ Predict EASI scores from multiple uploaded skin images in parallel.
835
+
836
+ - **files**: List of image files (JPG, JPEG, PNG) - max 10 images per request
837
+ - Returns: Batch results with individual predictions and errors
838
+
839
+ **Example Usage (Python):**
840
+ ```python
841
+ import requests
842
+
843
+ files = [
844
+ ('files', open('image1.jpg', 'rb')),
845
+ ('files', open('image2.jpg', 'rb')),
846
+ ('files', open('image3.jpg', 'rb'))
847
+ ]
848
+
849
+ response = requests.post('http://localhost:8000/predict/batch', files=files)
850
+ results = response.json()
851
+ ```
852
+
853
+ **Example Usage (cURL):**
854
+ ```bash
855
+ curl -X POST "http://localhost:8000/predict/batch" \
856
857
858
859
+ ```
860
+ """
861
+
862
+ import time
863
+ start_time = time.time()
864
+
865
+ # Validate models loaded
866
+ if derm_model is None or easi_model is None:
867
+ error_detail = []
868
+ if derm_model is None:
869
+ error_detail.append("Derm Foundation model not loaded")
870
+ if easi_model is None:
871
+ error_detail.append("EASI model not loaded")
872
+
873
+ raise HTTPException(
874
+ status_code=503,
875
+ detail=f"Models not available: {', '.join(error_detail)}. Check /health endpoint."
876
+ )
877
+
878
+ # Validate batch size
879
+ num_files = len(files)
880
+ if num_files == 0:
881
+ raise HTTPException(
882
+ status_code=400,
883
+ detail="No files provided. Please upload at least one image."
884
+ )
885
+
886
+ if num_files > MAX_BATCH_SIZE:
887
+ raise HTTPException(
888
+ status_code=400,
889
+ detail=f"Too many files. Maximum batch size is {MAX_BATCH_SIZE}, received {num_files}."
890
+ )
891
+
892
+ print(f"\nπŸ”„ Processing batch of {num_files} images...")
893
+
894
+ # Validate file types and read all files
895
+ image_data = []
896
+ for idx, file in enumerate(files):
897
+ if not file.content_type or not file.content_type.startswith('image/'):
898
+ raise HTTPException(
899
+ status_code=400,
900
+ detail=f"File {idx+1} ('{file.filename}') is not an image. Received: {file.content_type}"
901
+ )
902
+
903
+ try:
904
+ image_bytes = await file.read()
905
+ image_data.append({
906
+ 'bytes': image_bytes,
907
+ 'filename': file.filename,
908
+ 'index': idx
909
+ })
910
+ except Exception as e:
911
+ raise HTTPException(
912
+ status_code=400,
913
+ detail=f"Error reading file {idx+1} ('{file.filename}'): {str(e)}"
914
+ )
915
+
916
+ # Process images in parallel using thread pool
917
+ try:
918
+ loop = asyncio.get_event_loop()
919
+
920
+ # Create tasks for parallel processing
921
+ tasks = [
922
+ loop.run_in_executor(
923
+ executor,
924
+ process_single_image_sync,
925
+ img['bytes'],
926
+ img['filename']
927
+ )
928
+ for img in image_data
929
+ ]
930
+
931
+ # Wait for all tasks with timeout
932
+ results = await asyncio.wait_for(
933
+ asyncio.gather(*tasks, return_exceptions=True),
934
+ timeout=BATCH_TIMEOUT
935
+ )
936
+
937
+ except asyncio.TimeoutError:
938
+ raise HTTPException(
939
+ status_code=504,
940
+ detail=f"Batch processing timeout after {BATCH_TIMEOUT} seconds. Try reducing batch size."
941
+ )
942
+ except Exception as e:
943
+ import traceback
944
+ traceback.print_exc()
945
+ raise HTTPException(
946
+ status_code=500,
947
+ detail=f"Error during batch processing: {str(e)}"
948
+ )
949
+
950
+ # Collect results and errors
951
+ prediction_results = []
952
+ error_messages = []
953
+ successful_count = 0
954
+ failed_count = 0
955
+
956
+ for idx, result in enumerate(results):
957
+ if isinstance(result, Exception):
958
+ # Handle exception during processing
959
+ prediction_results.append(None)
960
+ error_messages.append(f"Exception: {str(result)}")
961
+ failed_count += 1
962
+ print(f" βœ— Image {idx+1} failed: {str(result)}")
963
+ elif result['success']:
964
+ prediction_results.append(result['result'])
965
+ error_messages.append(None)
966
+ successful_count += 1
967
+ print(f" βœ“ Image {idx+1} processed successfully")
968
+ else:
969
+ prediction_results.append(None)
970
+ error_messages.append(result['error'])
971
+ failed_count += 1
972
+ print(f" βœ— Image {idx+1} failed: {result['error']}")
973
+
974
+ processing_time = time.time() - start_time
975
+
976
+ print(f"βœ… Batch complete: {successful_count} successful, {failed_count} failed in {processing_time:.2f}s\n")
977
+
978
+ return BatchPredictionResponse(
979
+ success=True,
980
+ total_images_processed=num_files,
981
+ successful_predictions=successful_count,
982
+ failed_predictions=failed_count,
983
+ results=prediction_results,
984
+ errors=error_messages,
985
+ processing_time_seconds=round(processing_time, 2)
986
+ )
987
+
988
+
989
  @app.exception_handler(HTTPException)
990
  async def http_exception_handler(request, exc):
991
  """Custom HTTP exception handler"""
 
1019
  import uvicorn
1020
 
1021
  print("=" * 80)
1022
+ print("πŸš€ Starting EASI API Server with Batch Processing")
1023
+ print("=" * 80)
1024
+ print(f"Max batch size: {MAX_BATCH_SIZE} images")
1025
+ print(f"Batch timeout: {BATCH_TIMEOUT} seconds")
1026
  print("=" * 80)
1027
  print("Access the API at: http://localhost:8000")
1028
  print("Interactive docs: http://localhost:8000/docs")