import logging
import random
import tempfile
import traceback
from datetime import datetime
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import streamlit as st
import torch
import yaml
from huggingface_hub import hf_hub_download

from Prithvi import *
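# The wildcard import supplies everything referenced below that is not defined
# in this file: Merra2Dataset, preproc, the input/output/static scaler helpers,
# and the PrithviWxC model class.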
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Set page configuration
st.set_page_config(
    page_title="MERRA2 Data Processor",
    layout="wide",
    initial_sidebar_state="expanded",
)

# NOTE: dataset_type is collected here but not yet used downstream.
dataset_type = st.sidebar.selectbox(
    "Select Dataset Type",
    options=["MERRA2", "GEOS5"],
    index=0,
)

st.title("MERRA2 Data Processor with PrithviWxC Model")
# Sidebar for file uploads
st.sidebar.header("Upload MERRA2 Data Files")

# File uploader for surface data
uploaded_surface_files = st.sidebar.file_uploader(
    "Upload Surface Data Files",
    type=["nc", "netcdf"],
    accept_multiple_files=True,
    key="surface_uploader",
)

# File uploader for vertical data
uploaded_vertical_files = st.sidebar.file_uploader(
    "Upload Vertical Data Files",
    type=["nc", "netcdf"],
    accept_multiple_files=True,
    key="vertical_uploader",
)

# Optional: upload config.yaml
uploaded_config = st.sidebar.file_uploader(
    "Upload config.yaml",
    type=["yaml", "yml"],
    key="config_uploader",
)

# Optional: upload model weights
uploaded_weights = st.sidebar.file_uploader(
    "Upload Model Weights (.pt)",
    type=["pt"],
    key="weights_uploader",
)
# Other configurations
st.sidebar.header("Task Configuration")

lead_times = st.sidebar.multiselect(
    "Select Lead Times (hours)",
    options=[12, 24, 36, 48],
    default=[12],
)

input_times = st.sidebar.multiselect(
    "Select Input Times (hours)",
    options=[-6, -12, -18, -24],
    default=[-6],
)

time_range_start = st.sidebar.text_input(
    "Start Time (e.g., 2020-01-01T00:00:00)",
    value="2020-01-01T00:00:00",
)
time_range_end = st.sidebar.text_input(
    "End Time (e.g., 2020-01-01T23:59:59)",
    value="2020-01-01T23:59:59",
)
time_range = (time_range_start, time_range_end)
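# Fail fast on malformed timestamps (a minimal sketch): the time range is
# given as ISO-8601 strings, so validate here rather than letting a parse
# error surface deep inside the dataloader.
try:
    datetime.fromisoformat(time_range_start)
    datetime.fromisoformat(time_range_end)
except ValueError:
    st.sidebar.error("Timestamps must be ISO-8601, e.g. 2020-01-01T00:00:00.")
    st.stop()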
# Function to save uploaded files
def save_uploaded_files(uploaded_files, folder_name, max_size_mb=1024):
    """Persist uploaded files to a temporary directory and return its path."""
    if not uploaded_files:
        st.warning(f"No {folder_name} files uploaded.")
        return None
    # Validate file sizes before writing anything
    for file in uploaded_files:
        if file.size > max_size_mb * 1024 * 1024:
            st.error(f"File {file.name} exceeds the maximum size of {max_size_mb} MB.")
            return None
    temp_dir = tempfile.mkdtemp()
    with st.spinner(f"Saving {folder_name} files..."):
        for uploaded_file in uploaded_files:
            file_path = Path(temp_dir) / uploaded_file.name
            with open(file_path, "wb") as f:
                f.write(uploaded_file.getbuffer())
    st.success(f"Saved {len(uploaded_files)} {folder_name} files.")
    return Path(temp_dir)
# Save uploaded files
surf_dir = save_uploaded_files(uploaded_surface_files, "surface")
vert_dir = save_uploaded_files(uploaded_vertical_files, "vertical")

# Display uploaded files
if surf_dir:
    st.sidebar.subheader("Surface Files Uploaded:")
    for file in surf_dir.iterdir():
        st.sidebar.write(file.name)

if vert_dir:
    st.sidebar.subheader("Vertical Files Uploaded:")
    for file in vert_dir.iterdir():
        st.sidebar.write(file.name)
# Handle climatology files
st.sidebar.header("Upload Climatology Files (If Missing)")

# Default climatology file paths
default_clim_dir = Path("Prithvi-WxC/examples/climatology")
surf_in_scal_path = default_clim_dir / "musigma_surface.nc"
vert_in_scal_path = default_clim_dir / "musigma_vertical.nc"
surf_out_scal_path = default_clim_dir / "anomaly_variance_surface.nc"
vert_out_scal_path = default_clim_dir / "anomaly_variance_vertical.nc"

# Check whether all climatology files exist
clim_files_exist = all(
    path.exists()
    for path in (
        surf_in_scal_path,
        vert_in_scal_path,
        surf_out_scal_path,
        vert_out_scal_path,
    )
)

if not clim_files_exist:
    st.sidebar.warning("Climatology files are missing.")
    uploaded_clim_surface = st.sidebar.file_uploader(
        "Upload Climatology Surface File",
        type=["nc", "netcdf"],
        key="clim_surface_uploader",
    )
    uploaded_clim_vertical = st.sidebar.file_uploader(
        "Upload Climatology Vertical File",
        type=["nc", "netcdf"],
        key="clim_vertical_uploader",
    )
    if uploaded_clim_surface and uploaded_clim_vertical:
        clim_temp_dir = tempfile.mkdtemp()
        clim_surf_path = Path(clim_temp_dir) / uploaded_clim_surface.name
        with open(clim_surf_path, "wb") as f:
            f.write(uploaded_clim_surface.getbuffer())
        clim_vert_path = Path(clim_temp_dir) / uploaded_clim_vertical.name
        with open(clim_vert_path, "wb") as f:
            f.write(uploaded_clim_vertical.getbuffer())
        st.success("Climatology files uploaded and saved.")
    else:
        st.warning("Please upload both climatology surface and vertical files.")
else:
    clim_surf_path = surf_in_scal_path
    clim_vert_path = vert_in_scal_path
# Save uploaded config.yaml (NamedTemporaryFile avoids the race-prone
# tempfile.mktemp)
if uploaded_config:
    with tempfile.NamedTemporaryFile(delete=False, suffix=".yaml") as f:
        f.write(uploaded_config.getbuffer())
        config_path = Path(f.name)
    st.sidebar.success("config.yaml uploaded and saved.")
else:
    # Fall back to the default config.yaml path
    config_path = Path("Prithvi-WxC/examples/config.yaml")
    if not config_path.exists():
        st.sidebar.error("Default config.yaml not found. Please upload a config file.")
        st.stop()
# Save uploaded model weights
if uploaded_weights:
    with tempfile.NamedTemporaryFile(delete=False, suffix=".pt") as f:
        f.write(uploaded_weights.getbuffer())
        weights_path = Path(f.name)
    st.sidebar.success("Model weights uploaded and saved.")
else:
    # Fall back to the default weights path
    weights_path = Path("Prithvi-WxC/examples/weights/prithvi.wxc.2300m.v1.pt")
    if not weights_path.exists():
        st.sidebar.error("Default model weights not found. Please upload model weights.")
        st.stop()
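# Alternative (a sketch, commented out): fetch the checkpoint from the Hugging
# Face Hub with the already-imported hf_hub_download instead of bundling it
# with the Space. The repo_id below is an assumption and must match the
# repository that actually hosts the weights.
# weights_path = Path(
#     hf_hub_download(
#         repo_id="ibm-nasa-geospatial/Prithvi-WxC-1.0-2300M",  # assumption
#         filename="prithvi.wxc.2300m.v1.pt",
#     )
# )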
# Button to run inference
if st.sidebar.button("Run Inference"):
    # Initialize device
    torch.jit.enable_onednn_fusion(True)
    if torch.cuda.is_available():
        device = torch.device("cuda")
        st.write(f"Using device: {torch.cuda.get_device_name()}")
        # cudnn.benchmark autotunes kernel choice and can break run-to-run
        # determinism, so it is disabled here in favour of reproducibility.
        torch.backends.cudnn.benchmark = False
        torch.backends.cudnn.deterministic = True
    else:
        device = torch.device("cpu")
        st.write("Using device: CPU")

    # Set random seeds for reproducibility
    random.seed(42)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(42)
    torch.manual_seed(42)
    np.random.seed(42)
    # Define variables and parameters
    surface_vars = [
        "EFLUX", "GWETROOT", "HFLUX", "LAI", "LWGAB", "LWGEM", "LWTUP",
        "PS", "QV2M", "SLP", "SWGNT", "SWTNT", "T2M", "TQI", "TQL", "TQV",
        "TS", "U10M", "V10M", "Z0M",
    ]
    static_surface_vars = ["FRACI", "FRLAND", "FROCEAN", "PHIS"]
    vertical_vars = ["CLOUD", "H", "OMEGA", "PL", "QI", "QL", "QV", "T", "U", "V"]
    levels = [
        34.0, 39.0, 41.0, 43.0, 44.0, 45.0, 48.0,
        51.0, 53.0, 56.0, 63.0, 68.0, 71.0, 72.0,
    ]
    padding = {"level": [0, 0], "lat": [0, -1], "lon": [0, 0]}

    residual = "climate"
    masking_mode = "local"
    decoder_shifting = True
    masking_ratio = 0.99
    positional_encoding = "fourier"
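    # These settings appear to mirror the PrithviWxC reconstruction examples:
    # with masking_ratio = 0.99 nearly all mask units are hidden from the
    # encoder, so the forward pass below is effectively an infilling task.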
    # Initialize dataset
    try:
        with st.spinner("Initializing dataset..."):
            # Validate climatology files
            if not clim_files_exist and not (uploaded_clim_surface and uploaded_clim_vertical):
                st.error("Climatology files are missing. Please upload both surface and vertical climatology files.")
                st.stop()
            # Use uploaded data when present, otherwise the bundled examples;
            # climatology paths stay on the bundled directory, which also
            # holds the per-day climatology files the "climate" residual needs.
            default_data_dir = Path("Prithvi-WxC/examples/merra-2")
            dataset = Merra2Dataset(
                time_range=time_range,
                lead_times=lead_times,
                input_times=input_times,
                data_path_surface=surf_dir or default_data_dir,
                data_path_vertical=vert_dir or default_data_dir,
                climatology_path_surface=Path("Prithvi-WxC/examples/climatology"),
                climatology_path_vertical=Path("Prithvi-WxC/examples/climatology"),
                surface_vars=surface_vars,
                static_surface_vars=static_surface_vars,
                vertical_vars=vertical_vars,
                levels=levels,
                positional_encoding=positional_encoding,
            )
            assert len(dataset) > 0, "There doesn't seem to be any valid data."
        st.success("Dataset initialized successfully.")
    except Exception:
        st.error("Error initializing dataset:")
        st.error(traceback.format_exc())
        st.stop()
    # Load scalers
    try:
        with st.spinner("Loading scalers..."):
            # The scaler files are assumed to live alongside the climatology files
            surf_in_scal_path = clim_surf_path
            vert_in_scal_path = clim_vert_path
            surf_out_scal_path = clim_surf_path.parent / "anomaly_variance_surface.nc"
            vert_out_scal_path = clim_vert_path.parent / "anomaly_variance_vertical.nc"

            # Check that the output scaler files exist
            if not surf_out_scal_path.exists() or not vert_out_scal_path.exists():
                st.error("Anomaly variance scaler files are missing.")
                st.stop()

            in_mu, in_sig = input_scalers(
                surface_vars,
                vertical_vars,
                levels,
                surf_in_scal_path,
                vert_in_scal_path,
            )
            output_sig = output_scalers(
                surface_vars,
                vertical_vars,
                levels,
                surf_out_scal_path,
                vert_out_scal_path,
            )
            static_mu, static_sig = static_input_scalers(
                surf_in_scal_path,
                static_surface_vars,
            )
        st.success("Scalers loaded successfully.")
    except Exception:
        st.error("Error loading scalers:")
        st.error(traceback.format_exc())
        st.stop()
    # Load configuration
    try:
        with st.spinner("Loading configuration..."):
            with open(config_path, "r") as f:
                config = yaml.safe_load(f)
            # Validate config
            required_params = [
                "in_channels", "input_size_time", "in_channels_static",
                "input_scalers_epsilon", "static_input_scalers_epsilon",
                "n_lats_px", "n_lons_px", "patch_size_px",
                "mask_unit_size_px", "embed_dim", "n_blocks_encoder",
                "n_blocks_decoder", "mlp_multiplier", "n_heads",
                "dropout", "drop_path", "parameter_dropout",
            ]
            missing_params = [p for p in required_params if p not in config.get("params", {})]
            if missing_params:
                st.error(f"Missing configuration parameters: {missing_params}")
                st.stop()
        st.success("Configuration loaded successfully.")
    except Exception:
        st.error("Error loading configuration:")
        st.error(traceback.format_exc())
        st.stop()
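    # For reference, the YAML is expected to nest everything under a top-level
    # "params" key (values are model-specific and not filled in here):
    #
    # params:
    #   in_channels: ...
    #   input_size_time: ...
    #   in_channels_static: ...
    #   n_lats_px: ...
    #   n_lons_px: ...
    #   ...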
    # Initialize the model
    try:
        with st.spinner("Initializing model..."):
            model = PrithviWxC(
                in_channels=config["params"]["in_channels"],
                input_size_time=config["params"]["input_size_time"],
                in_channels_static=config["params"]["in_channels_static"],
                input_scalers_mu=in_mu,
                input_scalers_sigma=in_sig,
                input_scalers_epsilon=config["params"]["input_scalers_epsilon"],
                static_input_scalers_mu=static_mu,
                static_input_scalers_sigma=static_sig,
                static_input_scalers_epsilon=config["params"]["static_input_scalers_epsilon"],
                output_scalers=output_sig**0.5,
                n_lats_px=config["params"]["n_lats_px"],
                n_lons_px=config["params"]["n_lons_px"],
                patch_size_px=config["params"]["patch_size_px"],
                mask_unit_size_px=config["params"]["mask_unit_size_px"],
                mask_ratio_inputs=masking_ratio,
                embed_dim=config["params"]["embed_dim"],
                n_blocks_encoder=config["params"]["n_blocks_encoder"],
                n_blocks_decoder=config["params"]["n_blocks_decoder"],
                mlp_multiplier=config["params"]["mlp_multiplier"],
                n_heads=config["params"]["n_heads"],
                dropout=config["params"]["dropout"],
                drop_path=config["params"]["drop_path"],
                parameter_dropout=config["params"]["parameter_dropout"],
                residual=residual,
                masking_mode=masking_mode,
                decoder_shifting=decoder_shifting,
                positional_encoding=positional_encoding,
                checkpoint_encoder=[],
                checkpoint_decoder=[],
            )
        st.success("Model initialized successfully.")
    except Exception:
        st.error("Error initializing model:")
        st.error(traceback.format_exc())
        st.stop()
    # Load model weights
    try:
        with st.spinner("Loading model weights..."):
            state_dict = torch.load(weights_path, map_location=device)
            # Training checkpoints wrap the weights in a "model_state" key
            if "model_state" in state_dict:
                state_dict = state_dict["model_state"]
            model.load_state_dict(state_dict, strict=True)
            model.to(device)
        st.success("Model weights loaded successfully.")
    except Exception:
        st.error("Error loading model weights:")
        st.error(traceback.format_exc())
        st.stop()
    # Prepare data batch
    try:
        with st.spinner("Preparing data batch..."):
            data = next(iter(dataset))
            batch = preproc([data], padding)
            # Move every tensor in the batch to the target device
            for k, v in batch.items():
                if isinstance(v, torch.Tensor):
                    batch[k] = v.to(device)
        st.success("Data batch prepared successfully.")
    except Exception:
        st.error("Error preparing data batch:")
        st.error(traceback.format_exc())
        st.stop()
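    # Optional debug aid (a small sketch): log tensor shapes to catch
    # geometry/padding mismatches before they surface inside the model.
    for k, v in batch.items():
        if isinstance(v, torch.Tensor):
            logger.info("batch[%s] shape: %s", k, tuple(v.shape))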
    # Run inference
    try:
        with st.spinner("Running model inference..."):
            # Capture the RNG state so the random input masking could be
            # reproduced later if needed
            rng_state_1 = torch.get_rng_state()
            with torch.no_grad():
                model.eval()
                out = model(batch)
        st.success("Model inference completed successfully.")
    except Exception:
        st.error("Error during model inference:")
        st.error(traceback.format_exc())
        st.stop()
    # Display output
    st.header("Inference Results")
    st.write(out)  # Adjust based on the structure of 'out'

    # Optionally, provide download links or visualizations.
    # For example, if 'out' contains tensors or dataframes:
    # st.write("Output Tensor:", out["some_key"].cpu().numpy())
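    # A minimal visualization sketch, assuming the model returns a tensor
    # shaped (batch, parameter, lat, lon); the channel index is illustrative.
    if isinstance(out, torch.Tensor):
        channel = 0  # hypothetical channel to display
        fig, ax = plt.subplots(figsize=(9, 4))
        im = ax.imshow(out[0, channel].cpu().numpy(), origin="lower")
        fig.colorbar(im, ax=ax)
        ax.set_title(f"Predicted field, channel {channel}")
        st.pyplot(fig)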
else:
    st.info("Please upload the necessary files and click 'Run Inference' to start.")