FassikaF commited on
Commit
16fc1a6
Β·
verified Β·
1 Parent(s): df9016a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +8 -7
app.py CHANGED
@@ -2,8 +2,9 @@ import torch
2
  import tempfile
3
  import os
4
  import shutil
5
- import surya # Import to register custom 'surya' model_type with transformers
6
- from transformers import AutoModelForImageSegmentation, AutoProcessor
 
7
  from peft import PeftModel # For LoRA adapter
8
  from sunpy.net import Fido, attrs as a
9
  import astropy.units as u
@@ -28,19 +29,19 @@ def cleanup_temp():
28
  shutil.rmtree(dir_path, ignore_errors=True)
29
  cleanup_temp() # Run once at start
30
 
31
- # Surya model setup (base + LoRA adapter; registration via 'import surya')
32
  BASE_MODEL_ID = "nasa-ibm-ai4science/Surya-1.0"
33
  ADAPTER_MODEL_ID = "nasa-ibm-ai4science/ar_segmentation_surya"
34
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
35
 
36
- # Load base model (now recognized after surya import)
37
- base_model = AutoModelForImageSegmentation.from_pretrained(BASE_MODEL_ID).to(device)
38
 
39
  # Load LoRA adapter for AR segmentation
40
  model = PeftModel.from_pretrained(base_model, ADAPTER_MODEL_ID).to(device)
41
 
42
- # Load processor (from base; assumes standard image processor)
43
- processor = AutoProcessor.from_pretrained(BASE_MODEL_ID)
44
 
45
  # Historical observed values for May 2024 Gannon Storm
46
  HISTORICAL_DENSITY_INCREASE = 6.0 # Up to 6x at 400 km
 
2
  import tempfile
3
  import os
4
  import shutil
5
+ import surya # Import to register custom modules
6
+ from surya.models.surya_model import SuryaModel # Custom model class
7
+ from surya.processors import SuryaProcessor # Custom processor
8
  from peft import PeftModel # For LoRA adapter
9
  from sunpy.net import Fido, attrs as a
10
  import astropy.units as u
 
29
  shutil.rmtree(dir_path, ignore_errors=True)
30
  cleanup_temp() # Run once at start
31
 
32
+ # Surya model setup (custom class + LoRA adapter)
33
  BASE_MODEL_ID = "nasa-ibm-ai4science/Surya-1.0"
34
  ADAPTER_MODEL_ID = "nasa-ibm-ai4science/ar_segmentation_surya"
35
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
36
 
37
+ # Load base model using custom SuryaModel class
38
+ base_model = SuryaModel.from_pretrained(BASE_MODEL_ID).to(device)
39
 
40
  # Load LoRA adapter for AR segmentation
41
  model = PeftModel.from_pretrained(base_model, ADAPTER_MODEL_ID).to(device)
42
 
43
+ # Load custom processor
44
+ processor = SuryaProcessor.from_pretrained(BASE_MODEL_ID)
45
 
46
  # Historical observed values for May 2024 Gannon Storm
47
  HISTORICAL_DENSITY_INCREASE = 6.0 # Up to 6x at 400 km