Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
|
@@ -2,8 +2,9 @@ import torch
|
|
| 2 |
import tempfile
|
| 3 |
import os
|
| 4 |
import shutil
|
| 5 |
-
import surya # Import to register custom
|
| 6 |
-
from
|
|
|
|
| 7 |
from peft import PeftModel # For LoRA adapter
|
| 8 |
from sunpy.net import Fido, attrs as a
|
| 9 |
import astropy.units as u
|
|
@@ -28,19 +29,19 @@ def cleanup_temp():
|
|
| 28 |
shutil.rmtree(dir_path, ignore_errors=True)
|
| 29 |
cleanup_temp() # Run once at start
|
| 30 |
|
| 31 |
-
# Surya model setup (
|
| 32 |
BASE_MODEL_ID = "nasa-ibm-ai4science/Surya-1.0"
|
| 33 |
ADAPTER_MODEL_ID = "nasa-ibm-ai4science/ar_segmentation_surya"
|
| 34 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 35 |
|
| 36 |
-
# Load base model
|
| 37 |
-
base_model =
|
| 38 |
|
| 39 |
# Load LoRA adapter for AR segmentation
|
| 40 |
model = PeftModel.from_pretrained(base_model, ADAPTER_MODEL_ID).to(device)
|
| 41 |
|
| 42 |
-
# Load
|
| 43 |
-
processor =
|
| 44 |
|
| 45 |
# Historical observed values for May 2024 Gannon Storm
|
| 46 |
HISTORICAL_DENSITY_INCREASE = 6.0 # Up to 6x at 400 km
|
|
|
|
| 2 |
import tempfile
|
| 3 |
import os
|
| 4 |
import shutil
|
| 5 |
+
import surya # Import to register custom modules
|
| 6 |
+
from surya.models.surya_model import SuryaModel # Custom model class
|
| 7 |
+
from surya.processors import SuryaProcessor # Custom processor
|
| 8 |
from peft import PeftModel # For LoRA adapter
|
| 9 |
from sunpy.net import Fido, attrs as a
|
| 10 |
import astropy.units as u
|
|
|
|
| 29 |
shutil.rmtree(dir_path, ignore_errors=True)
|
| 30 |
cleanup_temp() # Run once at start
|
| 31 |
|
| 32 |
+
# Surya model setup (custom class + LoRA adapter)
|
| 33 |
BASE_MODEL_ID = "nasa-ibm-ai4science/Surya-1.0"
|
| 34 |
ADAPTER_MODEL_ID = "nasa-ibm-ai4science/ar_segmentation_surya"
|
| 35 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 36 |
|
| 37 |
+
# Load base model using custom SuryaModel class
|
| 38 |
+
base_model = SuryaModel.from_pretrained(BASE_MODEL_ID).to(device)
|
| 39 |
|
| 40 |
# Load LoRA adapter for AR segmentation
|
| 41 |
model = PeftModel.from_pretrained(base_model, ADAPTER_MODEL_ID).to(device)
|
| 42 |
|
| 43 |
+
# Load custom processor
|
| 44 |
+
processor = SuryaProcessor.from_pretrained(BASE_MODEL_ID)
|
| 45 |
|
| 46 |
# Historical observed values for May 2024 Gannon Storm
|
| 47 |
HISTORICAL_DENSITY_INCREASE = 6.0 # Up to 6x at 400 km
|