import os import torch import nibabel as nib from flask import Flask, request, render_template, redirect, url_for, flash, jsonify import tempfile import yaml import traceback # For detailed error printing import zipfile import dicom2nifti import shutil import subprocess # To run unzip command import SimpleITK as sitk import itk import numpy as np from scipy.signal import medfilt import skimage.filters import cv2 # For Gaussian Blur import io # For saving plots to memory import base64 # For encoding plots import uuid # For unique IDs # Configure Matplotlib for non-GUI backend *before* importing pyplot import matplotlib matplotlib.use('Agg') import matplotlib.pyplot as plt # --- Preprocessing Imports --- try: # Adjust import path based on Docker structure # Assumes HD_BET is now at /app/BrainIAC/HD_BET from HD_BET.run import run_hd_bet # Import MONAI saliency visualizer from monai.visualize.gradient_based import GuidedBackpropSmoothGrad except ImportError as e: print(f"Could not import HD_BET or MONAI visualize: {e}. Advanced features might fail.") run_hd_bet = None GuidedBackpropSmoothGrad = None # Import necessary components from your existing modules from model import Backbone, SingleScanModel, Classifier # Removed: from dataset2 import NormalSynchronizedTransform3D # Import specific MONAI transforms needed from monai.transforms import Resized, ScaleIntensityd # Removed ToTensord, will handle manually app = Flask(__name__) app.secret_key = 'supersecretkey' # Needed for flashing messages # --- Constants for Preprocessing --- APP_DIR = os.path.dirname(__file__) TEMPLATE_DIR = os.path.join(APP_DIR, "golden_image", "mni_templates") PARAMS_RIGID_PATH = os.path.join(APP_DIR, "golden_image", "mni_templates", "Parameters_Rigid.txt") DEFAULT_TEMPLATE_PATH = os.path.join(TEMPLATE_DIR, "nihpd_asym_13.0-18.5_t1w.nii") # Using adult template as default HD_BET_CONFIG_PATH = os.path.join(APP_DIR, "HD_BET", "config.py") HD_BET_MODEL_DIR = os.path.join(APP_DIR, "hdbet_model") # Path to copied models # --- Configuration Loading --- def load_config(): # Assuming config.yml is in the same directory as app.py config_path = os.path.join(APP_DIR, 'config.yml') try: with open(config_path, 'r') as file: config = yaml.safe_load(file) # Add default image_size if not present in config if 'data' not in config: config['data'] = {} if 'image_size' not in config['data']: config['data']['image_size'] = [128, 128, 128] except FileNotFoundError: print(f"Error: Configuration file not found at {config_path}") # Provide default config or handle error appropriately config = { 'gpu': {'device': 'cpu'}, 'infer': {'checkpoints': 'checkpoints/brainage_model_latest.pt'}, 'data': {'image_size': [128, 128, 128]} # Default image size } return config config = load_config() # Ensure image_size is available, e.g., from config or a default DEFAULT_IMAGE_SIZE = (128, 128, 128) image_size_cfg = config.get('data', {}).get('image_size', DEFAULT_IMAGE_SIZE) # Validate image_size format if not isinstance(image_size_cfg, (list, tuple)) or len(image_size_cfg) != 3: print(f"Warning: Invalid image_size in config ({image_size_cfg}). Using default {DEFAULT_IMAGE_SIZE}.") image_size = DEFAULT_IMAGE_SIZE else: image_size = tuple(image_size_cfg) # Ensure it's a tuple for transforms # --- Model Loading --- def load_model(device, checkpoint_path): backbone = Backbone() classifier = Classifier(d_model=2048) # Make sure d_model matches your trained model model = SingleScanModel(backbone, classifier) try: # Construct absolute path if checkpoint_path is relative relative_path = config.get('infer', {}).get('checkpoints', 'checkpoints/brainage_model_latest.pt') # Use path relative to app.py location checkpoint_path_abs = os.path.join(APP_DIR, relative_path) checkpoint = torch.load(checkpoint_path_abs, map_location=device) # Adjust key if necessary based on how model was saved if 'model_state_dict' in checkpoint: model.load_state_dict(checkpoint['model_state_dict']) else: model.load_state_dict(checkpoint) model.to(device) model.eval() print(f"Model loaded successfully from {checkpoint_path_abs} onto {device}.") return model except FileNotFoundError: print(f"Error: Checkpoint file not found at {checkpoint_path_abs}") return None except Exception as e: print(f"Error loading model checkpoint: {e}") traceback.print_exc() return None device = torch.device(config.get('gpu', {}).get('device', 'cpu')) # Default to CPU model = load_model(device, config) # Pass full config for path finding # --- Preprocessing Functions from preprocess_utils.py --- def bias_field_correction(img_array): """Performs N4 bias field correction using SimpleITK.""" image = sitk.GetImageFromArray(img_array) # Ensure image is float32 for N4 if image.GetPixelID() != sitk.sitkFloat32: image = sitk.Cast(image, sitk.sitkFloat32) maskImage = sitk.OtsuThreshold(image, 0, 1, 200) corrector = sitk.N4BiasFieldCorrectionImageFilter() numberFittingLevels = 4 # Define iterations per level more robustly max_iters = [min(50 * (2**i), 200) for i in range(numberFittingLevels)] corrector.SetMaximumNumberOfIterations(max_iters) # Set convergence threshold (optional, can speed up) # corrector.SetConvergenceThreshold(1e-6) print(" Running N4 Bias Field Correction...") corrected_image = corrector.Execute(image, maskImage) print(" N4 Correction finished.") return sitk.GetArrayFromImage(corrected_image) def denoise(volume, kernel_size=3): """Applies median filter for denoising.""" print(f" Applying median filter denoising (kernel={kernel_size})...") return medfilt(volume, kernel_size) def rescale_intensity(volume, percentils=[0.5, 99.5], bins_num=256): """Rescales intensity after removing background via Otsu.""" print(" Rescaling intensity...") # Ensure input is float for Otsu and calculations volume_float = volume.astype(np.float32) try: t = skimage.filters.threshold_otsu(volume_float, nbins=256) print(f" Otsu threshold found: {t}") volume_masked = np.copy(volume_float) volume_masked[volume_masked < t] = 0 # Apply mask based on original values obj_volume = volume_masked[np.where(volume_masked > 0)] except ValueError: # Handle cases with near-uniform intensity print(" Otsu failed (likely uniform image), skipping background mask.") obj_volume = volume_float.flatten() if obj_volume.size == 0: print(" Warning: No foreground voxels found after Otsu. Scaling full volume.") obj_volume = volume_float.flatten() # Fallback to full volume min_value = np.min(obj_volume) max_value = np.max(obj_volume) else: min_value = np.percentile(obj_volume, percentils[0]) max_value = np.percentile(obj_volume, percentils[1]) print(f" Intensity range used for scaling: [{min_value:.2f}, {max_value:.2f}]") # Avoid division by zero if max == min denominator = max_value - min_value if denominator < 1e-6: denominator = 1e-6 # Create a copy to modify for output output_volume = np.copy(volume_float) # Apply scaling only to the object volume identified (or full volume as fallback) if bins_num == 0: # Scale to 0-1 (float) output_volume = (volume_float - min_value) / denominator output_volume = np.clip(output_volume, 0.0, 1.0) # Clip results to [0, 1] else: # Scale and bin output_volume = np.round((volume_float - min_value) / denominator * (bins_num - 1)) output_volume = np.clip(output_volume, 0, bins_num - 1) # Ensure within bin range # Ensure output is float32 for consistency return output_volume.astype(np.float32) def equalize_hist(volume, bins_num=256): """Performs histogram equalization on non-zero voxels.""" print(" Performing histogram equalization...") # Create a mask of non-zero voxels mask = volume > 1e-6 # Use a small epsilon for float comparison obj_volume = volume[mask] if obj_volume.size == 0: print(" Warning: No non-zero voxels found for histogram equalization. Skipping.") return volume # Return original volume if no foreground # Compute histogram and CDF on the non-zero voxels hist, bins = np.histogram(obj_volume, bins_num, range=(obj_volume.min(), obj_volume.max())) cdf = hist.cumsum() # Normalize CDF cdf_normalized = (bins_num - 1) * cdf / float(cdf[-1]) # Interpolate new values for the object volume equalized_obj_volume = np.interp(obj_volume, bins[:-1], cdf_normalized) # Create a copy of the original volume to put the results back equalized_volume = np.copy(volume) equalized_volume[mask] = equalized_obj_volume # Ensure output is float32 return equalized_volume.astype(np.float32) def enhance(img_array, run_bias_correction=True, kernel_size=3, percentils=[0.5, 99.5], bins_num=256, run_equalize_hist=True): """Full enhancement pipeline from preprocess_utils.""" print("Starting enhancement pipeline...") volume = img_array.astype(np.float32) # Ensure float input try: if run_bias_correction: volume = bias_field_correction(volume) volume = denoise(volume, kernel_size) volume = rescale_intensity(volume, percentils, bins_num) if run_equalize_hist: volume = equalize_hist(volume, bins_num) print("Enhancement pipeline finished.") return volume except Exception as e: print(f"Error during enhancement: {e}") traceback.print_exc() raise RuntimeError(f"Failed enhancing image: {e}") # Re-raise to stop processing # --- Registration Function (modified enhance call) --- def register_image(input_nifti_path, output_nifti_path): """Registers input NIfTI to the default template using Elastix.""" print(f"Registering {input_nifti_path} to {DEFAULT_TEMPLATE_PATH}") if not os.path.exists(PARAMS_RIGID_PATH): raise FileNotFoundError(f"Elastix parameter file not found at {PARAMS_RIGID_PATH}") if not os.path.exists(DEFAULT_TEMPLATE_PATH): raise FileNotFoundError(f"Default template file not found at {DEFAULT_TEMPLATE_PATH}") fixed_image = itk.imread(DEFAULT_TEMPLATE_PATH, itk.F) moving_image = itk.imread(input_nifti_path, itk.F) parameter_object = itk.ParameterObject.New() parameter_object.AddParameterFile(PARAMS_RIGID_PATH) result_image, _ = itk.elastix_registration_method( fixed_image, moving_image, parameter_object=parameter_object, log_to_console=False # Keep console clean ) itk.imwrite(result_image, output_nifti_path) print(f"Registration output saved to {output_nifti_path}") # --- Enhanced Image Function (calls actual enhance) --- def run_enhance_on_file(input_nifti_path, output_nifti_path): """Reads NIfTI, runs enhance pipeline, saves NIfTI.""" print(f"Running full enhancement on {input_nifti_path}") img_sitk = sitk.ReadImage(input_nifti_path) img_array = sitk.GetArrayFromImage(img_sitk) # Run the actual enhancement pipeline enhanced_array = enhance(img_array, run_bias_correction=True) # Assuming N4 is desired enhanced_img_sitk = sitk.GetImageFromArray(enhanced_array) enhanced_img_sitk.CopyInformation(img_sitk) # Preserve metadata sitk.WriteImage(enhanced_img_sitk, output_nifti_path) print(f"Enhanced image saved to {output_nifti_path}") # --- Skull Stripping Function (Set Environment Variable) --- def run_skull_stripping(input_nifti_path, output_dir): """Runs HD-BET skull stripping.""" print(f"Running HD-BET skull stripping on {input_nifti_path}") if run_hd_bet is None: raise RuntimeError("HD-BET module could not be imported. Cannot perform skull stripping.") # Removed environment variable setting as path is fixed in HD_BET/paths.py # # Set environment variable *before* calling run_hd_bet # # Ensure the target directory exists # if not os.path.isdir(HD_BET_MODEL_DIR): # raise FileNotFoundError(f"HD-BET model directory not found at specified path: {HD_BET_MODEL_DIR}") # print(f"Setting HD_BET_MODELS environment variable to: {HD_BET_MODEL_DIR}") # os.environ['HD_BET_MODELS'] = HD_BET_MODEL_DIR # Check config path if not os.path.exists(HD_BET_CONFIG_PATH): alt_config_path = os.path.join(APP_DIR, "HD_BET", "HD_BET", "config.py") if os.path.exists(alt_config_path): print(f"Warning: Using alternative HD-BET config path: {alt_config_path}") config_to_use = alt_config_path else: raise FileNotFoundError(f"HD-BET config file not found at {HD_BET_CONFIG_PATH} or {alt_config_path}") else: config_to_use = HD_BET_CONFIG_PATH # Define output paths base_name = os.path.basename(input_nifti_path).replace(".nii.gz", "").replace(".nii", "") output_file_path = os.path.join(output_dir, f"{base_name}_bet.nii.gz") output_mask_path = os.path.join(output_dir, f"{base_name}_bet_mask.nii.gz") # Make sure output directory exists os.makedirs(output_dir, exist_ok=True) # Run HD-BET run_hd_bet(input_nifti_path, output_file_path, mode="fast", device='cpu', config_file=config_to_use, postprocess=False, do_tta=False, keep_mask=True, overwrite=True) # Unset environment variable after use (optional, good practice) # del os.environ['HD_BET_MODELS'] if not os.path.exists(output_file_path): raise RuntimeError(f"HD-BET did not produce the expected output file: {output_file_path}") print(f"Skull stripping output saved to {output_file_path}") return output_file_path, output_mask_path # --- Image Preprocessing --- # Define necessary MONAI transforms directly # Keys must match the dictionary keys we create later ('image') resize_transform = Resized(keys=["image"], spatial_size=image_size) scale_transform = ScaleIntensityd(keys=["image"], minv=0.0, maxv=1.0) def preprocess_nifti(nifti_path): """Loads and preprocesses a NIfTI file, returning a 5D tensor.""" print(f"Preprocessing NIfTI: {nifti_path}") scan_data = nib.load(nifti_path).get_fdata() print(f" Loaded scan data shape: {scan_data.shape}") scan_tensor = torch.tensor(scan_data, dtype=torch.float32).unsqueeze(0) # Add C dim print(f" Shape after tensor+channel: {scan_tensor.shape}") sample = {"image": scan_tensor} sample_resized = resize_transform(sample) print(f" Shape after resize: {sample_resized['image'].shape}") sample_scaled = scale_transform(sample_resized) print(f" Shape after scaling: {sample_scaled['image'].shape}") input_tensor = sample_scaled["image"].unsqueeze(0).to(device) # Add B dim print(f" Final shape for model: {input_tensor.shape}") if input_tensor.dim() != 5: raise ValueError(f"Preprocessing resulted in incorrect shape: {input_tensor.shape}. Expected 5D.") return input_tensor # --- Final NIfTI Preprocessing for Model --- def preprocess_nifti_for_model(nifti_path): """Loads final NIfTI and prepares 5D tensor for the model.""" # ... (Same as previous preprocess_nifti function) ... print(f"Preprocessing NIfTI for model: {nifti_path}") scan_data = nib.load(nifti_path).get_fdata() print(f" Loaded scan data shape: {scan_data.shape}") scan_tensor = torch.tensor(scan_data, dtype=torch.float32).unsqueeze(0) # Add C dim print(f" Shape after tensor+channel: {scan_tensor.shape}") sample = {"image": scan_tensor} sample_resized = resize_transform(sample) print(f" Shape after resize: {sample_resized['image'].shape}") sample_scaled = scale_transform(sample_resized) print(f" Shape after scaling: {sample_scaled['image'].shape}") input_tensor = sample_scaled["image"].unsqueeze(0).to(device) # Add B dim print(f" Final shape for model: {input_tensor.shape}") if input_tensor.dim() != 5: raise ValueError(f"Preprocessing resulted in incorrect shape: {input_tensor.shape}. Expected 5D.") return input_tensor # --- Saliency Map Generation --- def generate_saliency(model, input_tensor_5d): """Generates saliency map using GuidedBackpropSmoothGrad.""" if GuidedBackpropSmoothGrad is None: raise ImportError("MONAI visualize components not imported. Cannot generate saliency map.") if model is None: raise ValueError("Model not loaded. Cannot generate saliency map.") print("Generating saliency map...") input_tensor_5d.requires_grad_(True) # Use the backbone for saliency as in the original script # Ensure model and backbone are on the correct device (CPU in this case) visualizer = GuidedBackpropSmoothGrad(model=model.backbone.to(device), stdev_spread=0.15, n_samples=10, magnitude=True) try: with torch.enable_grad(): saliency_map_5d = visualizer(input_tensor_5d.to(device)) print("Saliency map generated.") # Detach, move to CPU, remove Batch and Channel dims for processing/plotting -> (D, H, W) input_3d = input_tensor_5d.squeeze().cpu().detach().numpy() saliency_3d = saliency_map_5d.squeeze().cpu().detach().numpy() return input_3d, saliency_3d except Exception as e: print(f"Error during saliency map generation: {e}") traceback.print_exc() # Return None or empty arrays if generation fails return None, None finally: # Ensure requires_grad is turned off if it was modified input_tensor_5d.requires_grad_(False) # --- Plotting Function for Single Slice --- def create_plot_images_for_slice(mri_data_3d, saliency_data_3d, slice_index): """Creates base64 encoded PNGs for a specific axial slice index.""" print(f" Generating plots for slice index: {slice_index}") if mri_data_3d is None or saliency_data_3d is None: print(" Input or Saliency data is None, cannot generate plot.") return None if slice_index < 0 or slice_index >= mri_data_3d.shape[2]: print(f" Error: Slice index {slice_index} out of bounds (0-{mri_data_3d.shape[2]-1}).") return None # Function to save plot to base64 string (copied from previous version) def save_plot_to_base64(fig): buf = io.BytesIO() fig.savefig(buf, format='png', bbox_inches='tight', pad_inches=0, dpi=75) plt.close(fig) # Close the figure immediately buf.seek(0) img_str = base64.b64encode(buf.read()).decode('utf-8') buf.close() return img_str try: mri_slice = mri_data_3d[:, :, slice_index] saliency_slice_orig = saliency_data_3d[:, :, slice_index] # --- Normalize MRI Slice (using volume stats if available, otherwise slice stats) --- # For consistency, ideally pass volume stats, but recalculating per slice is fallback p1_vol, p99_vol = np.percentile(mri_data_3d, (1, 99)) mri_norm_denom = p99_vol - p1_vol if mri_norm_denom < 1e-6: mri_norm_denom = 1e-6 mri_slice_norm = np.clip(mri_slice, p1_vol, p99_vol) mri_slice_norm = (mri_slice_norm - p1_vol) / mri_norm_denom # --- Process Saliency Slice --- saliency_slice = np.copy(saliency_slice_orig) saliency_slice[saliency_slice < 0] = 0 # Ensure non-negative saliency_slice_blurred = cv2.GaussianBlur(saliency_slice, (15, 15), 0) # Use volume max for normalization if possible, fallback to slice max s_max_vol = np.max(saliency_data_3d[saliency_data_3d >= 0]) # Max of non-negative values in volume if s_max_vol < 1e-6: s_max_vol = 1e-6 # --- Add logging for the calculated global max --- print(f" Calculated Global Max Saliency (s_max_vol) for normalization: {s_max_vol:.4f}") # -------------------------------------------------- saliency_slice_norm = saliency_slice_blurred / s_max_vol threshold_value = 0.0 saliency_slice_thresholded = np.where(saliency_slice_norm > threshold_value, saliency_slice_norm, 0) # --- Generate Plots --- slice_plots = {} # Plot 1: Input Slice fig1, ax1 = plt.subplots(figsize=(3, 3)) ax1.imshow(mri_slice_norm, cmap='gray', interpolation='none', origin='lower') ax1.axis('off') slice_plots['input_slice'] = save_plot_to_base64(fig1) # Plot 2: Saliency Heatmap fig2, ax2 = plt.subplots(figsize=(3, 3)) ax2.imshow(saliency_slice_thresholded, cmap='magma', interpolation='none', origin='lower') ax2.axis('off') slice_plots['heatmap_slice'] = save_plot_to_base64(fig2) # Plot 3: Overlay fig3, ax3 = plt.subplots(figsize=(3, 3)) ax3.imshow(mri_slice_norm, cmap='gray', interpolation='none', origin='lower') if np.max(saliency_slice_thresholded) > 0: # Remove fixed levels to let contour auto-determine levels based on slice data ax3.contour(saliency_slice_thresholded, cmap='magma', origin='lower', linewidths=1.0) ax3.axis('off') slice_plots['overlay_slice'] = save_plot_to_base64(fig3) print(f" Generated plots successfully for slice {slice_index}.") return slice_plots except Exception as e: print(f"Error generating plots for slice {slice_index}: {e}") traceback.print_exc() return None # --- Flask Routes --- @app.route('/', methods=['GET']) def index(): return render_template('index.html') @app.route('/predict', methods=['POST']) def predict(): if model is None: flash('Model not loaded. Cannot perform prediction.', 'error') return redirect(url_for('index')) # Get form data file_type = request.form.get('file_type') run_preprocess_flag = request.form.get('preprocess') == 'yes' generate_saliency_flag = request.form.get('generate_saliency') == 'yes' # Get saliency flag file = request.files.get('scan_file') # --- Basic Input Validation --- if not file_type: flash('Please select a file type (NIfTI or DICOM).', 'error') return redirect(url_for('index')) if not file or file.filename == '': flash('No scan file selected', 'error') return redirect(url_for('index')) print(f"Received upload: type='{file_type}', filename='{file.filename}', preprocess={run_preprocess_flag}, saliency={generate_saliency_flag}") # --- Setup Temporary Directory --- # temp_dir_obj = tempfile.TemporaryDirectory() # <--- PROBLEM: Cleans up automatically # Use mkdtemp to create a persistent temporary directory # NOTE: Requires a manual cleanup strategy later! try: temp_dir = tempfile.mkdtemp() except Exception as e: print(f"Error creating temporary directory: {e}") flash("Server error: Could not create temporary directory.", "error") return redirect(url_for('index')) # Generate a unique ID based on the temp directory name unique_id = os.path.basename(temp_dir) print(f"Created persistent temp directory: {temp_dir} (ID: {unique_id})") nifti_for_preprocessing_path = None # Path to the NIfTI before optional preprocessing try: # --- Handle Upload and DICOM Conversion --- # --- Handle NIfTI Upload --- if file_type == 'nifti': if not file.filename.endswith('.nii.gz'): flash('Invalid file type for NIfTI selection. Please upload .nii.gz', 'error') # temp_dir_obj.cleanup() # No object to cleanup, need manual rmtree shutil.rmtree(temp_dir, ignore_errors=True) return redirect(url_for('index')) uploaded_file_path = os.path.join(temp_dir, "uploaded_scan.nii.gz") file.save(uploaded_file_path) print(f"Saved uploaded NIfTI file to: {uploaded_file_path}") nifti_for_preprocessing_path = uploaded_file_path # --- Handle DICOM Upload --- elif file_type == 'dicom': if not file.filename.endswith('.zip'): flash('Invalid file type for DICOM selection. Please upload a .zip file.', 'error') # temp_dir_obj.cleanup() shutil.rmtree(temp_dir, ignore_errors=True) return redirect(url_for('index')) uploaded_zip_path = os.path.join(temp_dir, "dicom_files.zip") file.save(uploaded_zip_path) print(f"Saved uploaded DICOM zip to: {uploaded_zip_path}") dicom_input_dir = os.path.join(temp_dir, "dicom_input") nifti_output_dir = os.path.join(temp_dir, "nifti_output") os.makedirs(dicom_input_dir, exist_ok=True) os.makedirs(nifti_output_dir, exist_ok=True) try: # Use shutil.unpack_archive for better cross-platform compatibility potentially shutil.unpack_archive(uploaded_zip_path, dicom_input_dir) print(f"Unzip successful.") except Exception as e: print(f"Unzip failed: {e}") flash(f'Error unzipping DICOM file: {e}', 'error') # temp_dir_obj.cleanup() shutil.rmtree(temp_dir, ignore_errors=True) return redirect(url_for('index')) try: dicom2nifti.convert_directory(dicom_input_dir, nifti_output_dir, compression=True, reorient=True) nifti_files = [f for f in os.listdir(nifti_output_dir) if f.endswith('.nii.gz')] if not nifti_files: raise RuntimeError("dicom2nifti did not produce a .nii.gz file.") nifti_for_preprocessing_path = os.path.join(nifti_output_dir, nifti_files[0]) print(f"DICOM conversion successful. NIfTI file: {nifti_for_preprocessing_path}") except Exception as e: print(f"DICOM to NIfTI conversion failed: {e}") flash(f'Error converting DICOM to NIfTI: {e}', 'error') # temp_dir_obj.cleanup() shutil.rmtree(temp_dir, ignore_errors=True) return redirect(url_for('index')) else: flash('Invalid file type selected.', 'error') # temp_dir_obj.cleanup() shutil.rmtree(temp_dir, ignore_errors=True) return redirect(url_for('index')) if not nifti_for_preprocessing_path or not os.path.exists(nifti_for_preprocessing_path): flash('Error: Could not find the NIfTI file for processing.', 'error') # temp_dir_obj.cleanup() shutil.rmtree(temp_dir, ignore_errors=True) return redirect(url_for('index')) # --- Optional Preprocessing Steps --- nifti_to_predict_path = nifti_for_preprocessing_path if run_preprocess_flag: print("--- Running Optional Preprocessing Pipeline ---") try: registered_path = os.path.join(temp_dir, "registered.nii.gz") register_image(nifti_for_preprocessing_path, registered_path) enhanced_path = os.path.join(temp_dir, "enhanced.nii.gz") run_enhance_on_file(registered_path, enhanced_path) skullstrip_output_dir = os.path.join(temp_dir, "skullstripped") skullstripped_path, _ = run_skull_stripping(enhanced_path, skullstrip_output_dir) nifti_to_predict_path = skullstripped_path print("--- Optional Preprocessing Pipeline Complete ---") except Exception as e: print(f"Error during optional preprocessing pipeline: {e}") traceback.print_exc() flash(f'Error during preprocessing: {e}', 'error') # temp_dir_obj.cleanup() shutil.rmtree(temp_dir, ignore_errors=True) return redirect(url_for('index')) else: print("--- Skipping Optional Preprocessing Pipeline ---") # --- Final Preprocessing for Model & Prediction --- input_tensor_5d = preprocess_nifti_for_model(nifti_to_predict_path) print("Performing prediction...") with torch.no_grad(): output = model(input_tensor_5d) predicted_age = output.item() predicted_age_years = predicted_age / 12 # Adjust if needed print(f"Prediction successful: {predicted_age_years:.2f} years") # --- Saliency Data Handling (Generate, Save, Get Initial Plot) --- saliency_output_for_template = None # Initialize if generate_saliency_flag: print("--- Generating & Saving Saliency Data ---") try: input_3d_for_plot, saliency_3d = generate_saliency(model, input_tensor_5d) if input_3d_for_plot is not None and saliency_3d is not None: num_slices = input_3d_for_plot.shape[2] center_slice_index = num_slices // 2 # Save the numpy arrays for the dynamic route input_array_path = os.path.join(temp_dir, f"{unique_id}_input.npy") saliency_array_path = os.path.join(temp_dir, f"{unique_id}_saliency.npy") np.save(input_array_path, input_3d_for_plot) np.save(saliency_array_path, saliency_3d) print(f"Saved input array to {input_array_path}") print(f"Saved saliency array to {saliency_array_path}") # Generate ONLY the center slice plots for the initial view center_slice_plots = create_plot_images_for_slice(input_3d_for_plot, saliency_3d, center_slice_index) if center_slice_plots: # Prepare data structure for the template saliency_output_for_template = { 'center_slice_plots': center_slice_plots, 'num_slices': num_slices, 'center_slice_index': center_slice_index, 'unique_id': unique_id, # Pass the ID for filenames 'temp_dir_path': temp_dir # Pass the full path for lookup } print("--- Saliency Data Saved & Initial Plot Generated ---") else: print("--- Center Slice Plotting Failed ---") flash('Failed to generate initial saliency plot.', 'warning') else: print("--- Saliency Generation Failed --- ") flash('Saliency map generation failed.', 'warning') except Exception as e: print(f"Error during saliency processing/saving: {e}") traceback.print_exc() flash('Could not generate or save saliency maps due to an error.', 'error') # Render result, passing prediction and potentially the NEW saliency structure return render_template('index.html', prediction=f"{predicted_age_years:.2f} years", saliency_info=saliency_output_for_template) # Pass the new dict except Exception as e: flash(f'Error processing file: {e}', 'error') print(f"Caught Exception during prediction process: {e}") traceback.print_exc() # Ensure cleanup happens even if exception occurs mid-process # temp_dir_obj.cleanup() if temp_dir and os.path.exists(temp_dir): shutil.rmtree(temp_dir, ignore_errors=True) # Manual cleanup on general error return redirect(url_for('index')) # NOTE: Temporary directory created with mkdtemp is NOT automatically cleaned. # Need a separate mechanism (e.g., cron job, background task) to remove old directories # from the system's temporary location (e.g., /tmp) based on age. # Leaving the directory here so /get_slice can access the files. # --- New Route for Dynamic Slice Loading --- @app.route('/get_slice//') def get_slice(unique_id, slice_index): # Get the actual temporary directory path from query parameter temp_dir_path = request.args.get('path') if not temp_dir_path: print("Error: 'path' query parameter missing in /get_slice request") return jsonify({"error": "Required path information missing."}), 400 # Construct paths using the provided directory path and unique ID input_array_path = os.path.join(temp_dir_path, f"{unique_id}_input.npy") saliency_array_path = os.path.join(temp_dir_path, f"{unique_id}_saliency.npy") print(f"Attempting to load slice {slice_index} for ID {unique_id} from actual path: {temp_dir_path}") try: # Check using the exact paths constructed above if not os.path.exists(input_array_path) or not os.path.exists(saliency_array_path): print(f"Error: .npy files not found for ID {unique_id} at {temp_dir_path}") return jsonify({"error": "Saliency data not found. It might have expired or failed to save."}), 404 input_3d = np.load(input_array_path) saliency_3d = np.load(saliency_array_path) print(f"Loaded arrays for ID {unique_id}. Input shape: {input_3d.shape}, Saliency shape: {saliency_3d.shape}") # Generate plots for the requested slice using the helper function slice_plots = create_plot_images_for_slice(input_3d, saliency_3d, slice_index) if slice_plots: return jsonify(slice_plots) # Return plot data as JSON else: return jsonify({"error": f"Failed to generate plots for slice {slice_index}."}), 500 except Exception as e: print(f"Error in /get_slice for ID {unique_id}, slice {slice_index}: {e}") traceback.print_exc() return jsonify({"error": "An internal error occurred while fetching the slice data."}), 500 if __name__ == '__main__': # Use '0.0.0.0' to make it accessible outside the container app.run(host='0.0.0.0', port=5000, debug=False) # Turn off debug for production/docker