import os
import random
import logging
import tempfile
import traceback
from datetime import datetime
from pathlib import Path

import streamlit as st
import torch
import numpy as np
import yaml
import xarray as xr
import matplotlib.pyplot as plt
import plotly.graph_objects as go
import cartopy.crs as ccrs

from Prithvi import *  # Must provide PrithviWxC, preproc, and the (currently commented-out) scaler helpers used below
from aurora import Aurora, Batch, Metadata, rollout

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Save uploaded files to temporary files and record their paths in session_state
def save_uploaded_files(uploaded_files):
    # Reset on each call so Streamlit reruns don't accumulate duplicate paths
    st.session_state.temp_file_paths = []
    for uploaded_file in uploaded_files:
        suffix = os.path.splitext(uploaded_file.name)[1]
        temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=suffix)
        temp_file.write(uploaded_file.read())
        temp_file.close()
        st.session_state.temp_file_paths.append(temp_file.name)
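# The temporary files above are never deleted. A minimal cleanup sketch,
# assuming the paths recorded in session_state are the only ones to remove
# (call it once results are no longer needed):
#
#   def cleanup_temp_files():
#       for path in st.session_state.get("temp_file_paths", []):
#           try:
#               os.remove(path)
#           except OSError as e:
#               logger.warning("Could not delete temp file %s: %s", path, e)
#       st.session_state.temp_file_paths = []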
# Load one or more NetCDF files into a single dataset (not cached; see the sketch below)
def load_dataset(file_paths):
    try:
        ds = xr.open_mfdataset(file_paths, combine='by_coords').load()
        return ds
    except Exception as e:
        st.error("Error loading dataset:")
        st.error(traceback.format_exc())
        return None
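# A cached variant (sketch): st.cache_data memoizes on the file-path list, so
# Streamlit reruns skip the disk read. This assumes the loaded dataset pickles
# cleanly, which holds here because .load() pulls everything into memory.
#
#   @st.cache_data(show_spinner=False)
#   def load_dataset_cached(file_paths):
#       return xr.open_mfdataset(file_paths, combine="by_coords").load()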
# Set page configuration
st.set_page_config(
    page_title="Weather Data Processor",
    layout="wide",
    initial_sidebar_state="expanded",
)
# Header with two columns: one for the title, one for the model selector
header_col1, header_col2 = st.columns([4, 1])  # Adjust the ratio as needed
with header_col1:
    st.title("🌦️ Weather & Climate Data Processor and Forecaster")
with header_col2:
    st.markdown("### Select a Model")
    selected_model = st.selectbox(
        "Model",
        options=["Aurora", "Climax", "Prithvi", "LSTM"],
        index=0,
        key="model_selector",
        label_visibility="collapsed",  # The markdown heading above serves as the label
        help="Select the model you want to use for processing the data.",
    )
st.write("---")  # Horizontal separator
# --- Layout: Two Columns ---
left_col, right_col = st.columns([1, 2])  # Adjust column ratios as needed
with left_col:
    st.header("🔧 Configuration")

    # --- Dynamic Configuration Based on Selected Model ---
    def get_model_configuration(model_name):
        if model_name == "Prithvi":
            st.subheader("Prithvi Model Configuration")
            # Prithvi-specific configuration inputs
            param1 = st.number_input("Prithvi Parameter 1", value=10, step=1)
            param2 = st.text_input("Prithvi Parameter 2", value="default_prithvi")
            # Add other Prithvi-specific parameters here
            config = {
                "param1": param1,
                "param2": param2,
                # Include other parameters as needed
            }
            # --- Prithvi-Specific File Uploads ---
            st.markdown("### Upload Data Files for Prithvi Model")
            # File uploader for surface data
            uploaded_surface_files = st.file_uploader(
                "Upload Surface Data Files",
                type=["nc", "netcdf"],
                accept_multiple_files=True,
                key="surface_uploader",
            )
            # File uploader for vertical data
            uploaded_vertical_files = st.file_uploader(
                "Upload Vertical Data Files",
                type=["nc", "netcdf"],
                accept_multiple_files=True,
                key="vertical_uploader",
            )

            # Handle climatology files
            st.markdown("### Upload Climatology Files (If Missing)")
            # Default climatology file paths
            default_clim_dir = Path("Prithvi-WxC/examples/climatology")
            surf_in_scal_path = default_clim_dir / "musigma_surface.nc"
            vert_in_scal_path = default_clim_dir / "musigma_vertical.nc"
            surf_out_scal_path = default_clim_dir / "anomaly_variance_surface.nc"
            vert_out_scal_path = default_clim_dir / "anomaly_variance_vertical.nc"
            # Check whether all climatology files exist
            clim_files_exist = all(
                [
                    surf_in_scal_path.exists(),
                    vert_in_scal_path.exists(),
                    surf_out_scal_path.exists(),
                    vert_out_scal_path.exists(),
                ]
            )
            if not clim_files_exist:
                st.warning("Climatology files are missing.")
                uploaded_clim_surface = st.file_uploader(
                    "Upload Climatology Surface File",
                    type=["nc", "netcdf"],
                    key="clim_surface_uploader",
                )
                uploaded_clim_vertical = st.file_uploader(
                    "Upload Climatology Vertical File",
                    type=["nc", "netcdf"],
                    key="clim_vertical_uploader",
                )
                # Process uploaded climatology files
                if uploaded_clim_surface and uploaded_clim_vertical:
                    clim_temp_dir = tempfile.mkdtemp()
                    clim_surf_path = Path(clim_temp_dir) / uploaded_clim_surface.name
                    with open(clim_surf_path, "wb") as f:
                        f.write(uploaded_clim_surface.getbuffer())
                    clim_vert_path = Path(clim_temp_dir) / uploaded_clim_vertical.name
                    with open(clim_vert_path, "wb") as f:
                        f.write(uploaded_clim_vertical.getbuffer())
                    st.success("Climatology files uploaded and saved.")
                else:
                    st.warning("Please upload both climatology surface and vertical files.")
                    # Keep the names defined so the return below doesn't raise a NameError
                    clim_surf_path = None
                    clim_vert_path = None
            else:
                clim_surf_path = surf_in_scal_path
                clim_vert_path = vert_in_scal_path
            # Optional: upload config.yaml
            uploaded_config = st.file_uploader(
                "Upload config.yaml",
                type=["yaml", "yml"],
                key="config_uploader",
            )
            if uploaded_config:
                # NamedTemporaryFile avoids the race condition of the deprecated tempfile.mktemp
                with tempfile.NamedTemporaryFile(delete=False, suffix=".yaml") as f:
                    f.write(uploaded_config.getbuffer())
                    config_path = Path(f.name)
                st.success("config.yaml uploaded and saved.")
            else:
                # Use the default config.yaml path
                config_path = Path("Prithvi-WxC/examples/config.yaml")
                if not config_path.exists():
                    st.error("Default config.yaml not found. Please upload a config file.")
                    st.stop()
            # Optional: upload model weights
            uploaded_weights = st.file_uploader(
                "Upload Model Weights (.pt)",
                type=["pt"],
                key="weights_uploader",
            )
            if uploaded_weights:
                with tempfile.NamedTemporaryFile(delete=False, suffix=".pt") as f:
                    f.write(uploaded_weights.getbuffer())
                    weights_path = Path(f.name)
                st.success("Model weights uploaded and saved.")
            else:
                # Use the default weights path
                weights_path = Path("Prithvi-WxC/examples/weights/prithvi.wxc.2300m.v1.pt")
                if not weights_path.exists():
                    st.error("Default model weights not found. Please upload model weights.")
                    st.stop()
            return config, uploaded_surface_files, uploaded_vertical_files, clim_surf_path, clim_vert_path, config_path, weights_path
        else:
            # For other models, provide a simple file uploader
            st.subheader(f"{model_name} Model Data Upload")
            st.markdown("### Drag and Drop Your Data Files Here")
            uploaded_files = st.file_uploader(
                f"Upload Data Files for {model_name}",
                accept_multiple_files=True,
                key=f"{model_name.lower()}_uploader",
                type=["nc", "netcdf", "nc4"],
            )
            return uploaded_files

    # Retrieve model-specific configuration and files
    if selected_model == "Prithvi":
        config, uploaded_surface_files, uploaded_vertical_files, clim_surf_path, clim_vert_path, config_path, weights_path = get_model_configuration(selected_model)
    else:
        uploaded_files = get_model_configuration(selected_model)

st.write("---")  # Horizontal separator
# --- Run Inference Button ---
if st.button("🚀 Run Inference"):
    with right_col:
        st.header("📈 Inference Progress & Visualization")

        # Initialize device
        try:
            torch.jit.enable_onednn_fusion(True)
            if torch.cuda.is_available():
                device = torch.device("cuda")
                st.write(f"Using device: **{torch.cuda.get_device_name()}**")
                # benchmark=True would trade reproducibility for speed; since
                # seeds are fixed below, prefer deterministic kernels
                torch.backends.cudnn.benchmark = False
                torch.backends.cudnn.deterministic = True
            else:
                device = torch.device("cpu")
                st.write("Using device: **CPU**")
        except Exception as e:
            st.error("Error initializing device:")
            st.error(traceback.format_exc())
            st.stop()
        # Set random seeds for reproducibility
        try:
            random.seed(42)
            if torch.cuda.is_available():
                torch.cuda.manual_seed(42)
            torch.manual_seed(42)
            np.random.seed(42)
        except Exception as e:
            st.error("Error setting random seeds:")
            st.error(traceback.format_exc())
            st.stop()
        # # Define variables and parameters based on dataset type
        # if dataset_type == "MERRA2":
        #     surface_vars = [
        #         "EFLUX", "GWETROOT", "HFLUX", "LAI", "LWGAB", "LWGEM", "LWTUP",
        #         "PS", "QV2M", "SLP", "SWGNT", "SWTNT", "T2M", "TQI", "TQL",
        #         "TQV", "TS", "U10M", "V10M", "Z0M",
        #     ]
        #     static_surface_vars = ["FRACI", "FRLAND", "FROCEAN", "PHIS"]
        #     vertical_vars = ["CLOUD", "H", "OMEGA", "PL", "QI", "QL", "QV", "T", "U", "V"]
        #     levels = [
        #         34.0, 39.0, 41.0, 43.0, 44.0, 45.0, 48.0, 51.0, 53.0, 56.0,
        #         63.0, 68.0, 71.0, 72.0,
        #     ]
        # elif dataset_type == "GEOS5":
        #     # Define GEOS5-specific variables
        #     surface_vars = [
        #         "GEOS5_EFLUX", "GEOS5_GWETROOT", "GEOS5_HFLUX", "GEOS5_LAI",
        #         "GEOS5_LWGAB", "GEOS5_LWGEM", "GEOS5_LWTUP", "GEOS5_PS",
        #         "GEOS5_QV2M", "GEOS5_SLP", "GEOS5_SWGNT", "GEOS5_SWTNT",
        #         "GEOS5_T2M", "GEOS5_TQI", "GEOS5_TQL", "GEOS5_TQV", "GEOS5_TS",
        #         "GEOS5_U10M", "GEOS5_V10M", "GEOS5_Z0M",
        #     ]
        #     static_surface_vars = ["GEOS5_FRACI", "GEOS5_FRLAND", "GEOS5_FROCEAN", "GEOS5_PHIS"]
        #     vertical_vars = [
        #         "GEOS5_CLOUD", "GEOS5_H", "GEOS5_OMEGA", "GEOS5_PL", "GEOS5_QI",
        #         "GEOS5_QL", "GEOS5_QV", "GEOS5_T", "GEOS5_U", "GEOS5_V",
        #     ]
        #     levels = [
        #         # Define levels specific to GEOS5 if different
        #         10.0, 20.0, 30.0, 40.0, 50.0, 60.0, 70.0, 80.0,
        #     ]
        # else:
        #     st.error("Unsupported dataset type selected.")
        #     st.stop()
        # Prithvi-WxC model/rollout settings (taken from the Prithvi-WxC examples)
        padding = {"level": [0, 0], "lat": [0, -1], "lon": [0, 0]}
        residual = "climate"
        masking_mode = "local"
        decoder_shifting = True
        masking_ratio = 0.99
        positional_encoding = "fourier"
        # --- Initialize Dataset ---
        try:
            with st.spinner("Initializing dataset..."):
                if selected_model == "Prithvi":
                    # NOTE: Prithvi dataset initialization is currently disabled;
                    # restore the commented code below to enable it.
                    pass
                    # # Validate climatology files
                    # if not clim_files_exist and not (uploaded_clim_surface and uploaded_clim_vertical):
                    #     st.error("Climatology files are missing. Please upload both climatology surface and vertical files.")
                    #     st.stop()
                    # dataset = Merra2Dataset(
                    #     time_range=time_range,
                    #     lead_times=lead_times,
                    #     input_times=input_times,
                    #     data_path_surface=surf_dir,
                    #     data_path_vertical=vert_dir,
                    #     climatology_path_surface=clim_surf_path,
                    #     climatology_path_vertical=clim_vert_path,
                    #     surface_vars=surface_vars,
                    #     static_surface_vars=static_surface_vars,
                    #     vertical_vars=vertical_vars,
                    #     levels=levels,
                    #     positional_encoding=positional_encoding,
                    # )
                    # assert len(dataset) > 0, "There doesn't seem to be any valid data."
                elif selected_model == "Aurora":
                    # TODO: temporary ad-hoc loading; replace with a proper dataset class
                    if uploaded_files:
                        try:
                            # Save each uploaded file to a temporary file
                            save_uploaded_files(uploaded_files)
                            # Open all files as one xarray dataset
                            ds = load_dataset(st.session_state.temp_file_paths)
                            if ds:
                                st.success("Files successfully loaded!")
                                st.session_state.ds_subset = ds
                                ds = ds.fillna(ds.mean())
                                desired_levels = [50, 100, 150, 200, 250, 300, 400, 500, 600, 700, 850, 925, 1000]
                                # Ensure that the 'lev' (pressure level) dimension exists
                                if 'lev' not in ds.dims:
                                    raise ValueError("The dataset does not contain a 'lev' (pressure level) dimension.")
                                def _prepare(x: np.ndarray, i: int) -> torch.Tensor:
                                    """Stack time steps i-6 and i (Aurora expects two input times) and add a batch axis."""
                                    # Select the step six indices earlier and the current step
                                    selected = x[[i - 6, i]]
                                    # Add a batch dimension
                                    selected = selected[None]
                                    # Ensure the array is contiguous before conversion
                                    selected = selected.copy()
                                    # Convert to a PyTorch tensor
                                    return torch.from_numpy(selected)
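                                # Worked example (shapes only): for a surface field x of shape
                                # (T, H, W) and i = 6, x[[0, 6]] has shape (2, H, W) and the added
                                # batch axis gives (1, 2, H, W) -- the (batch, history, lat, lon)
                                # layout Aurora's Batch expects for surface variables. For an
                                # atmospheric field of shape (T, C, H, W) the result is (1, 2, C, H, W).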
                                # Adjust coordinates: Aurora expects decreasing latitudes and
                                # longitudes in [0, 360). Negating assumes a symmetric latitude grid.
                                lat = ds.lat.values * -1
                                lon = ds.lon.values + 180
                                # Subset the dataset to the desired pressure levels
                                ds_subset = ds.sel(lev=desired_levels, method="nearest")
                                # Verify that all desired levels are present
                                present_levels = ds_subset.lev.values
                                missing_levels = set(desired_levels) - set(present_levels)
                                if missing_levels:
                                    raise ValueError(f"The following desired pressure levels are missing in the dataset: {missing_levels}")
                                # Extract pressure levels after subsetting (in hPa)
                                lev = ds_subset.lev.values
                                # Locate the 1000 hPa level for the surface proxies below
                                try:
                                    lev_index_1000 = np.where(lev == 1000)[0][0]
                                except IndexError:
                                    raise ValueError("1000 hPa level not found in the 'lev' dimension after subsetting.")
                                T_surface = ds_subset.T.isel(lev=lev_index_1000).compute()
                                U_surface = ds_subset.U.isel(lev=lev_index_1000).compute()
                                V_surface = ds_subset.V.isel(lev=lev_index_1000).compute()
                                SLP = ds_subset.SLP.compute()
                                # Static variable: select the first time index to drop the time dimension
                                PHIS = ds_subset.PHIS.isel(time=0).compute()
                                # Atmospheric variables at all desired pressure levels except 1000 hPa
                                atmos_levels = [int(level) for level in lev if level != 1000]
                                T_atm = ds_subset.T.sel(lev=atmos_levels).compute()
                                U_atm = ds_subset.U.sel(lev=atmos_levels).compute()
                                V_atm = ds_subset.V.sel(lev=atmos_levels).compute()
                                # Select the time index; _prepare needs step i-6, so i must be >= 6
                                num_times = ds_subset.time.size
                                i = 6  # Adjust as needed (6 <= i < num_times)
                                if i < 6 or i >= num_times:
                                    raise IndexError("Time index i is out of bounds (need 6 <= i < num_times).")
                                time_values = ds_subset.time.values
                                current_time = np.datetime64(time_values[i]).astype('datetime64[s]').astype(datetime)
                                # Surface variables (1000 hPa fields used as stand-ins for the
                                # near-surface quantities Aurora expects)
                                surf_vars = {
                                    "2t": _prepare(T_surface.values, i),   # Two-meter temperature proxy
                                    "10u": _prepare(U_surface.values, i),  # Ten-meter eastward wind proxy
                                    "10v": _prepare(V_surface.values, i),  # Ten-meter northward wind proxy
                                    "msl": _prepare(SLP.values, i),        # Mean sea-level pressure
                                }
                                # Static variables (2D tensors)
                                static_vars = {
                                    "z": torch.from_numpy(PHIS.values.copy()),  # Surface geopotential (h, w)
                                    # Add 'lsm' and 'slt' if available and needed
                                }
                                # Atmospheric variables at the desired levels
                                atmos_vars = {
                                    "t": _prepare(T_atm.values, i),  # Temperature
                                    "u": _prepare(U_atm.values, i),  # Eastward wind
                                    "v": _prepare(V_atm.values, i),  # Northward wind
                                }
                                # Define metadata
                                metadata = Metadata(
                                    lat=torch.from_numpy(lat.copy()),
                                    lon=torch.from_numpy(lon.copy()),
                                    time=(current_time,),
                                    atmos_levels=tuple(atmos_levels),  # Only the desired atmospheric levels
                                )
                                # Create the Batch object and keep it for the inference step
                                batch = Batch(
                                    surf_vars=surf_vars,
                                    static_vars=static_vars,
                                    atmos_vars=atmos_vars,
                                    metadata=metadata,
                                )
                                st.session_state['batch'] = batch
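                                # Optional sanity-check sketch before inference, assuming the
                                # Aurora input conventions (decreasing latitudes, longitudes
                                # in [0, 360)):
                                #
                                #   assert (np.diff(lat) < 0).all(), "latitudes must be decreasing"
                                #   assert lon.min() >= 0 and lon.max() < 360, "longitudes must lie in [0, 360)"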
                        except Exception as e:
                            st.error(f"An error occurred: {e}")
                        # finally:
                        #     # Clean up: remove temporary files
                        #     for path in st.session_state.temp_file_paths:
                        #         try:
                        #             os.remove(path)
                        #         except Exception as e:
                        #             st.warning(f"Could not delete temp file {path}: {e}")
                else:
                    # Placeholder: replace with dataset initialization for other models
                    dataset = None
                    st.warning("Dataset initialization for this model is not implemented yet.")
                    st.stop()
| st.success("Dataset initialized successfully.") | |
| except Exception as e: | |
| st.error("Error initializing dataset:") | |
| st.error(traceback.format_exc()) | |
| st.stop() | |
        # --- Load Scalers ---
        try:
            with st.spinner("Loading scalers..."):
                if selected_model == "Prithvi":
                    # NOTE: Prithvi scaler loading is currently disabled; the commented
                    # code below must be restored before Prithvi inference can work.
                    # Placeholders keep the names defined in the meantime.
                    in_mu, in_sig = None, None
                    output_sig = None
                    static_mu, static_sig = None, None
                    # # Assuming the scaler paths are the same as the climatology paths
                    # surf_in_scal_path = clim_surf_path
                    # vert_in_scal_path = clim_vert_path
                    # surf_out_scal_path = Path(clim_surf_path.parent) / "anomaly_variance_surface.nc"
                    # vert_out_scal_path = Path(clim_vert_path.parent) / "anomaly_variance_vertical.nc"
                    # # Check if output scaler files exist
                    # if not surf_out_scal_path.exists() or not vert_out_scal_path.exists():
                    #     st.error("Anomaly variance scaler files are missing.")
                    #     st.stop()
                    # in_mu, in_sig = input_scalers(
                    #     surface_vars,
                    #     vertical_vars,
                    #     levels,
                    #     surf_in_scal_path,
                    #     vert_in_scal_path,
                    # )
                    # output_sig = output_scalers(
                    #     surface_vars,
                    #     vertical_vars,
                    #     levels,
                    #     surf_out_scal_path,
                    #     vert_out_scal_path,
                    # )
                    # static_mu, static_sig = static_input_scalers(
                    #     surf_in_scal_path,
                    #     static_surface_vars,
                    # )
                else:
                    # Placeholder: replace with scaler loading for other models if applicable
                    in_mu, in_sig = None, None
                    output_sig = None
                    static_mu, static_sig = None, None
            st.success("Scalers loaded successfully.")
        except Exception as e:
            st.error("Error loading scalers:")
            st.error(traceback.format_exc())
            st.stop()
        # --- Load Configuration ---
        try:
            with st.spinner("Loading configuration..."):
                if selected_model == "Prithvi":
                    with open(config_path, "r") as f:
                        config = yaml.safe_load(f)
                    # Validate the config
                    required_params = [
                        "in_channels", "input_size_time", "in_channels_static",
                        "input_scalers_epsilon", "static_input_scalers_epsilon",
                        "n_lats_px", "n_lons_px", "patch_size_px",
                        "mask_unit_size_px", "embed_dim", "n_blocks_encoder",
                        "n_blocks_decoder", "mlp_multiplier", "n_heads",
                        "dropout", "drop_path", "parameter_dropout",
                    ]
                    missing_params = [param for param in required_params if param not in config.get("params", {})]
                    if missing_params:
                        st.error(f"Missing configuration parameters: {missing_params}")
                        st.stop()
                else:
                    # Placeholder: replace with configuration loading for other models
                    config = {}
            st.success("Configuration loaded successfully.")
        except Exception as e:
            st.error("Error loading configuration:")
            st.error(traceback.format_exc())
            st.stop()
        # --- Initialize the Model ---
        try:
            with st.spinner("Initializing model..."):
                if selected_model == "Prithvi":
                    # NOTE: this requires real scalers; until the scaler loading above
                    # is restored, output_sig is None and this constructor will fail.
                    model = PrithviWxC(
                        in_channels=config["params"]["in_channels"],
                        input_size_time=config["params"]["input_size_time"],
                        in_channels_static=config["params"]["in_channels_static"],
                        input_scalers_mu=in_mu,
                        input_scalers_sigma=in_sig,
                        input_scalers_epsilon=config["params"]["input_scalers_epsilon"],
                        static_input_scalers_mu=static_mu,
                        static_input_scalers_sigma=static_sig,
                        static_input_scalers_epsilon=config["params"]["static_input_scalers_epsilon"],
                        output_scalers=output_sig**0.5,
                        n_lats_px=config["params"]["n_lats_px"],
                        n_lons_px=config["params"]["n_lons_px"],
                        patch_size_px=config["params"]["patch_size_px"],
                        mask_unit_size_px=config["params"]["mask_unit_size_px"],
                        mask_ratio_inputs=masking_ratio,
                        embed_dim=config["params"]["embed_dim"],
                        n_blocks_encoder=config["params"]["n_blocks_encoder"],
                        n_blocks_decoder=config["params"]["n_blocks_decoder"],
                        mlp_multiplier=config["params"]["mlp_multiplier"],
                        n_heads=config["params"]["n_heads"],
                        dropout=config["params"]["dropout"],
                        drop_path=config["params"]["drop_path"],
                        parameter_dropout=config["params"]["parameter_dropout"],
                        residual=residual,
                        masking_mode=masking_mode,
                        decoder_shifting=decoder_shifting,
                        positional_encoding=positional_encoding,
                        checkpoint_encoder=[],
                        checkpoint_decoder=[],
                    )
                elif selected_model == "Aurora":
                    pass  # The Aurora model is constructed in the inference step below
                else:
                    # Placeholder: replace with model initialization for other models
                    model = None
                    st.warning("Model initialization for this model is not implemented yet.")
                    st.stop()
            st.success("Model initialized successfully.")
        except Exception as e:
            st.error("Error initializing model:")
            st.error(traceback.format_exc())
            st.stop()
        # --- Load Model Weights ---
        try:
            with st.spinner("Loading model weights..."):
                if selected_model == "Prithvi":
                    state_dict = torch.load(weights_path, map_location=device)
                    if "model_state" in state_dict:
                        state_dict = state_dict["model_state"]
                    model.load_state_dict(state_dict, strict=True)
                    model.to(device)
                else:
                    # Placeholder: replace with weight loading for other models if applicable
                    pass
            st.success("Model weights loaded successfully.")
        except Exception as e:
            st.error("Error loading model weights:")
            st.error(traceback.format_exc())
            st.stop()
        # --- Prepare Data Batch ---
        try:
            with st.spinner("Preparing data batch..."):
                if selected_model == "Prithvi":
                    # NOTE: requires the Merra2Dataset initialization above to be restored
                    data = next(iter(dataset))
                    batch = preproc([data], padding)
                    for k, v in batch.items():
                        if isinstance(v, torch.Tensor):
                            batch[k] = v.to(device)
                elif selected_model == "Aurora":
                    # Regrid to 0.25 degrees to match the pretrained Aurora checkpoint
                    batch = batch.regrid(res=0.25)
                else:
                    # Placeholder: replace with data preparation for other models
                    batch = None
            st.success("Data batch prepared successfully.")
        except Exception as e:
            st.error("Error preparing data batch:")
            st.error(traceback.format_exc())
            st.stop()
        # --- Run Inference ---
        try:
            with st.spinner("Running model inference..."):
                if selected_model == "Prithvi":
                    model.eval()
                    with torch.no_grad():
                        out = model(batch)
                elif selected_model == "Aurora":
                    model = Aurora(use_lora=False)
                    model.load_checkpoint("microsoft/aurora", "aurora-0.25-pretrained.ckpt")
                    model.eval()
                    # model = model.to("cuda")  # Uncomment if using a GPU
                    with torch.inference_mode():
                        out = [pred.to("cpu") for pred in rollout(model, batch, steps=2)]
                    model = model.to("cpu")
                    st.session_state.model = model
                else:
                    # Placeholder: replace with inference code for other models
                    out = torch.randn(1, 10, 180, 360)  # Dummy tensor
            st.success("Model inference completed successfully.")
            st.session_state['out'] = out
        except Exception as e:
            st.error("Error during model inference:")
            st.error(traceback.format_exc())
            st.stop()
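        # The forecast horizon is controlled by the `steps` argument to rollout();
        # each Aurora step advances the prediction by six hours. A sketch for a
        # 24-hour forecast (4 steps) under the same memory-saving pattern:
        #
        #   with torch.inference_mode():
        #       out = [pred.to("cpu") for pred in rollout(model, batch, steps=4)]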
        # --- Visualization Settings ---
        st.markdown("## 📊 Visualization Settings")
        if 'out' in st.session_state and 'batch' in st.session_state and selected_model == "Prithvi":
            # Display the shape of the output tensor
            out_tensor = st.session_state['out']
            st.write(f"**Output tensor shape:** {out_tensor.shape}")
            # Ensure the output tensor has at least 4 dimensions (batch, variables, lat, lon)
            if out_tensor.ndim < 4:
                st.error("The output tensor does not have the expected number of dimensions (batch, variables, lat, lon).")
                st.stop()
            # Get the number of variables
            num_variables = out_tensor.shape[1]
            # Define variable names (update with your actual variable names)
            variable_names = [f"Variable_{i}" for i in range(num_variables)]
            # Visualization settings
            col1, col2 = st.columns(2)
            with col1:
                # Select variable to plot
                selected_variable_name = st.selectbox(
                    "Select Variable to Plot",
                    options=variable_names,
                    index=0,
                    help="Choose the variable you want to visualize.",
                )
                # Select plot type
                plot_type = st.selectbox(
                    "Select Plot Type",
                    options=["Contour", "Heatmap"],
                    index=0,
                    help="Choose the type of plot to display.",
                )
            with col2:
                # Select color map
                cmap = st.selectbox(
                    "Select Color Map",
                    options=plt.colormaps(),
                    index=plt.colormaps().index("viridis"),
                    help="Choose the color map for the plot.",
                )
                # Set the number of levels (contour plots only)
                if plot_type == "Contour":
                    num_levels = st.slider(
                        "Number of Contour Levels",
                        min_value=5,
                        max_value=100,
                        value=20,
                        step=5,
                        help="Set the number of contour levels.",
                    )
                else:
                    num_levels = None
            # Find the index of the selected variable
            variable_index = variable_names.index(selected_variable_name)
            # Extract the selected variable
            selected_variable = out_tensor[0, variable_index].cpu().numpy()
            # Generate latitude and longitude arrays (assumes a regular global grid)
            lat = np.linspace(-90, 90, selected_variable.shape[0])
            lon = np.linspace(-180, 180, selected_variable.shape[1])
            X, Y = np.meshgrid(lon, lat)
            # Plot the selected variable
            st.markdown(f"### Plot of {selected_variable_name}")
            # Matplotlib figure
            fig, ax = plt.subplots(figsize=(10, 6))
            if plot_type == "Contour":
                contour = ax.contourf(X, Y, selected_variable, levels=num_levels, cmap=cmap)
            elif plot_type == "Heatmap":
                contour = ax.imshow(selected_variable, extent=[-180, 180, -90, 90], cmap=cmap, origin='lower', aspect='auto')
            # Add a color bar
            cbar = plt.colorbar(contour, ax=ax)
            cbar.set_label(f'{selected_variable_name}', fontsize=12)
            # Labels and title
            ax.set_xlabel("Longitude", fontsize=12)
            ax.set_ylabel("Latitude", fontsize=12)
            ax.set_title(f"{selected_variable_name}", fontsize=14)
            # Display the plot in Streamlit
            st.pyplot(fig)
            # Optional: interactive Plotly plot
            st.markdown("#### Interactive Plot")
            # NOTE: Plotly accepts only some matplotlib colormap names (e.g. 'viridis');
            # unsupported names will raise an error here.
            if plot_type == "Contour":
                fig_plotly = go.Figure(data=go.Contour(
                    z=selected_variable,
                    x=lon,
                    y=lat,
                    colorscale=cmap,
                    contours=dict(
                        coloring='fill',
                        showlabels=True,
                        labelfont=dict(size=12, color='white'),
                    ),
                    ncontours=num_levels,  # ncontours is a top-level Contour attribute, not part of the contours dict
                ))
            elif plot_type == "Heatmap":
                fig_plotly = go.Figure(data=go.Heatmap(
                    z=selected_variable,
                    x=lon,
                    y=lat,
                    colorscale=cmap,
                ))
            fig_plotly.update_layout(
                xaxis_title="Longitude",
                yaxis_title="Latitude",
                autosize=False,
                width=800,
                height=600,
            )
            st.plotly_chart(fig_plotly)
        elif 'out' in st.session_state and selected_model == "Aurora" and st.session_state['out'] is not None:
            preds = st.session_state['out']
            ds_subset = st.session_state.get('ds_subset', None)
            batch = st.session_state.get('batch', None)
            # Determine the available vertical levels from the prediction tensors
            try:
                # atmos_vars tensors have shape (batch, time, level, lat, lon)
                levels = preds[0].atmos_vars["t"].shape[2]
                level_indices = list(range(levels))
            except Exception as e:
                st.error("Error determining available levels:")
                st.error(traceback.format_exc())
                levels = None  # Set to None if levels cannot be determined
            if levels is not None:
                # Slider for level selection
                selected_level = st.slider(
                    'Select Level',
                    min_value=0,
                    max_value=levels - 1,
                    value=min(11, levels - 1),  # Default level index, clamped to the valid range
                    step=1,
                    help="Select the vertical level for plotting.",
                )
                # Loop through predictions and ground truths
                for idx in range(len(preds)):
                    pred = preds[idx]
                    pred_time = pred.metadata.time[0]
                    # Display the prediction time
                    st.write(f"### Prediction Time: {pred_time}")
                    # Extract data at the selected level (Kelvin -> Celsius).
                    # NOTE: ds_subset holds the full-level dataset saved before level
                    # subsetting, so its lev index may not line up exactly with the
                    # prediction's level set.
                    try:
                        pred_data = pred.atmos_vars["t"][0][0][selected_level].numpy() - 273.15
                        truth_data = ds_subset.T.isel(lev=selected_level)[idx].values - 273.15
                    except Exception as e:
                        st.error("Error extracting data for plotting:")
                        st.error(traceback.format_exc())
                        continue
                    # Extract latitude and longitude (assumed 1D)
                    try:
                        lat = np.array(pred.metadata.lat)
                        lon = np.array(pred.metadata.lon)
                    except Exception as e:
                        st.error("Error extracting latitude and longitude:")
                        st.error(traceback.format_exc())
                        continue
                    # Create a meshgrid for plotting
                    lon_grid, lat_grid = np.meshgrid(lon, lat)
                    # Matplotlib figure with a Cartopy projection (two panels: truth and prediction)
                    fig, axes = plt.subplots(
                        1, 2, figsize=(12, 6),
                        subplot_kw={'projection': ccrs.PlateCarree()},
                    )
                    # Ground truth plot
                    im1 = axes[0].imshow(
                        truth_data,
                        extent=[lon.min(), lon.max(), lat.min(), lat.max()],
                        origin='lower',
                        cmap='coolwarm',
                        transform=ccrs.PlateCarree(),
                    )
                    axes[0].set_title(f"Ground Truth at Level {selected_level} - {pred_time}")
                    axes[0].set_xlabel('Longitude')
                    axes[0].set_ylabel('Latitude')
                    plt.colorbar(im1, ax=axes[0], orientation='horizontal', pad=0.05)
                    # Prediction plot
                    im2 = axes[1].imshow(
                        pred_data,
                        extent=[lon.min(), lon.max(), lat.min(), lat.max()],
                        origin='lower',
                        cmap='coolwarm',
                        transform=ccrs.PlateCarree(),
                    )
                    axes[1].set_title(f"Prediction at Level {selected_level} - {pred_time}")
                    axes[1].set_xlabel('Longitude')
                    axes[1].set_ylabel('Latitude')
                    plt.colorbar(im2, ax=axes[1], orientation='horizontal', pad=0.05)
                    plt.tight_layout()
                    # Display the plot in Streamlit
                    st.pyplot(fig)
            else:
                st.error("Could not determine the available levels in the data.")
        else:
            st.warning("No output available to display or visualization is not implemented for this model.")
# --- End of Inference Button ---
else:
    with right_col:
        st.header("🖥️ Visualization & Progress")
        st.info("Awaiting inference to display results.")