ajsbsd commited on
Commit
3d488b6
·
verified ·
1 Parent(s): c41cdf5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +27 -24
app.py CHANGED
@@ -4,14 +4,11 @@ from neuralop.models import FNO
4
  import matplotlib.pyplot as plt
5
  import numpy as np
6
  import os
7
- # import requests # <--- NO LONGER NEEDED for Zenodo download
8
- # from tqdm import tqdm # <--- NO LONGER NEEDED for Zenodo download
9
  import spaces
10
- from huggingface_hub import hf_hub_download # <--- ADD THIS IMPORT
11
 
12
  # --- Configuration ---
13
  MODEL_PATH = "fno_ckpt_single_res" # This model file still needs to be in your Space's repo
14
- # Updated: Hugging Face Dataset/Model ID and filename
15
  HF_DATASET_REPO_ID = "ajsbsd/navier-stokes-2d-dataset" # Your new repo ID
16
  HF_DATASET_FILENAME = "navier_stokes_2d.pt"
17
 
@@ -19,7 +16,7 @@ HF_DATASET_FILENAME = "navier_stokes_2d.pt"
19
  MODEL = None
20
  FULL_DATASET_X = None
21
 
22
- # --- Function to Download Dataset (MODIFIED to use hf_hub_download) ---
23
  def download_file_from_hf_hub(repo_id, filename):
24
  """Downloads a file from Hugging Face Hub."""
25
  print(f"Downloading {filename} from {repo_id} on Hugging Face Hub...")
@@ -33,7 +30,7 @@ def download_file_from_hf_hub(repo_id, filename):
33
  raise gr.Error(f"Failed to download dataset from Hugging Face Hub: {e}")
34
 
35
 
36
- # --- 1. Model Loading Function (No change from last successful CUDA fix) ---
37
  def load_model():
38
  """Loads the pre-trained FNO model to CPU."""
39
  global MODEL
@@ -41,19 +38,18 @@ def load_model():
41
  print("Loading FNO model to CPU...")
42
  try:
43
  MODEL = torch.load(MODEL_PATH, weights_only=False, map_location='cpu')
44
- MODEL.eval()
45
  print("Model loaded successfully to CPU.")
46
  except Exception as e:
47
  print(f"Error loading model: {e}")
48
  raise gr.Error(f"Failed to load model: {e}")
49
  return MODEL
50
 
51
- # --- 2. Dataset Loading Function (MODIFIED) ---
52
  def load_dataset():
53
  """Downloads and loads the initial conditions dataset from HF Hub."""
54
  global FULL_DATASET_X
55
  if FULL_DATASET_X is None:
56
- # Call the new HF Hub download function
57
  local_dataset_path = download_file_from_hf_hub(HF_DATASET_REPO_ID, HF_DATASET_FILENAME)
58
  print("Loading dataset from local file...")
59
  try:
@@ -70,38 +66,43 @@ def load_dataset():
70
  raise gr.Error(f"Failed to load dataset from local file: {e}")
71
  return FULL_DATASET_X
72
 
73
- # --- 3. Inference Function for Gradio (No changes needed here) ---
74
  @spaces.GPU()
75
  def run_inference(sample_index: int):
76
  """
77
  Performs inference for a selected sample index from the dataset.
 
78
  Returns two Matplotlib figures: one for input, one for output.
79
  """
80
- model = load_model()
81
- dataset = load_dataset()
 
 
82
 
83
- if torch.cuda.is_available() and next(model.parameters()).device == torch.device('cpu'):
84
- model.cuda()
85
- print("Model moved to GPU within run_inference.")
 
 
 
 
86
 
87
  if not (0 <= sample_index < dataset.shape[0]):
88
  raise gr.Error(f"Sample index out of range. Please choose between 0 and {dataset.shape[0]-1}.")
89
 
90
- single_initial_condition = dataset[sample_index:sample_index+1, :, :].unsqueeze(1)
91
-
92
- if torch.cuda.is_available():
93
- single_initial_condition = single_initial_condition.cuda()
94
- print("Input moved to GPU.")
95
- else:
96
- print("CUDA not available. Input remains on CPU.")
97
 
98
  print(f"Running inference for sample index {sample_index}...")
99
- with torch.no_grad():
100
- predicted_solution = model(single_initial_condition)
101
 
 
102
  input_numpy = single_initial_condition.squeeze().cpu().numpy()
103
  output_numpy = predicted_solution.squeeze().cpu().numpy()
104
 
 
105
  fig_input, ax_input = plt.subplots()
106
  im_input = ax_input.imshow(input_numpy, cmap='viridis')
107
  ax_input.set_title(f"Initial Condition (Sample {sample_index})")
@@ -147,8 +148,10 @@ with gr.Blocks() as demo:
147
  )
148
 
149
  def load_initial_data_and_predict():
 
150
  load_model()
151
  load_dataset()
 
152
  return run_inference(0)
153
 
154
  demo.load(load_initial_data_and_predict, inputs=None, outputs=[input_image_plot, output_image_plot])
 
4
  import matplotlib.pyplot as plt
5
  import numpy as np
6
  import os
 
 
7
  import spaces
8
+ from huggingface_hub import hf_hub_download
9
 
10
  # --- Configuration ---
11
  MODEL_PATH = "fno_ckpt_single_res" # This model file still needs to be in your Space's repo
 
12
  HF_DATASET_REPO_ID = "ajsbsd/navier-stokes-2d-dataset" # Your new repo ID
13
  HF_DATASET_FILENAME = "navier_stokes_2d.pt"
14
 
 
16
  MODEL = None
17
  FULL_DATASET_X = None
18
 
19
+ # --- Function to Download Dataset from HF Hub ---
20
  def download_file_from_hf_hub(repo_id, filename):
21
  """Downloads a file from Hugging Face Hub."""
22
  print(f"Downloading {filename} from {repo_id} on Hugging Face Hub...")
 
30
  raise gr.Error(f"Failed to download dataset from Hugging Face Hub: {e}")
31
 
32
 
33
+ # --- 1. Model Loading Function (Loads to CPU, device transfer handled in run_inference) ---
34
  def load_model():
35
  """Loads the pre-trained FNO model to CPU."""
36
  global MODEL
 
38
  print("Loading FNO model to CPU...")
39
  try:
40
  MODEL = torch.load(MODEL_PATH, weights_only=False, map_location='cpu')
41
+ MODEL.eval() # Set to evaluation mode
42
  print("Model loaded successfully to CPU.")
43
  except Exception as e:
44
  print(f"Error loading model: {e}")
45
  raise gr.Error(f"Failed to load model: {e}")
46
  return MODEL
47
 
48
+ # --- 2. Dataset Loading Function ---
49
  def load_dataset():
50
  """Downloads and loads the initial conditions dataset from HF Hub."""
51
  global FULL_DATASET_X
52
  if FULL_DATASET_X is None:
 
53
  local_dataset_path = download_file_from_hf_hub(HF_DATASET_REPO_ID, HF_DATASET_FILENAME)
54
  print("Loading dataset from local file...")
55
  try:
 
66
  raise gr.Error(f"Failed to load dataset from local file: {e}")
67
  return FULL_DATASET_X
68
 
69
+ # --- 3. Inference Function for Gradio (MODIFIED: Explicit device handling) ---
70
  @spaces.GPU()
71
  def run_inference(sample_index: int):
72
  """
73
  Performs inference for a selected sample index from the dataset.
74
+ Ensures model and input are on the correct device (GPU).
75
  Returns two Matplotlib figures: one for input, one for output.
76
  """
77
+ # Determine the target device (GPU if available, else CPU)
78
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
79
+
80
+ model = load_model() # Model is initially loaded to CPU
81
 
82
+ # Move model to the correct device ONLY when inside the @spaces.GPU() decorated function
83
+ # and only if it's not already on the target device.
84
+ if next(model.parameters()).device != device:
85
+ model.to(device)
86
+ print(f"Model moved to {device} within run_inference.")
87
+
88
+ dataset = load_dataset()
89
 
90
  if not (0 <= sample_index < dataset.shape[0]):
91
  raise gr.Error(f"Sample index out of range. Please choose between 0 and {dataset.shape[0]-1}.")
92
 
93
+ # Move input tensor to the correct device directly
94
+ single_initial_condition = dataset[sample_index:sample_index+1, :, :].unsqueeze(1).to(device)
95
+ print(f"Input moved to {device}.")
 
 
 
 
96
 
97
  print(f"Running inference for sample index {sample_index}...")
98
+ with torch.no_grad(): # Disable gradient calculations for inference
99
+ predicted_solution = model(single_initial_condition) # This is where the error occurred before
100
 
101
+ # Move results back to CPU for plotting with Matplotlib
102
  input_numpy = single_initial_condition.squeeze().cpu().numpy()
103
  output_numpy = predicted_solution.squeeze().cpu().numpy()
104
 
105
+ # Create Matplotlib figures
106
  fig_input, ax_input = plt.subplots()
107
  im_input = ax_input.imshow(input_numpy, cmap='viridis')
108
  ax_input.set_title(f"Initial Condition (Sample {sample_index})")
 
148
  )
149
 
150
  def load_initial_data_and_predict():
151
+ # These functions are called during main process startup (CPU)
152
  load_model()
153
  load_dataset()
154
+ # The actual inference call here will ensure GPU utilization via @spaces.GPU()
155
  return run_inference(0)
156
 
157
  demo.load(load_initial_data_and_predict, inputs=None, outputs=[input_image_plot, output_image_plot])