Spaces:
Sleeping
Sleeping
| import torch | |
| import numpy as np | |
| import random | |
| import yaml | |
| import os | |
| import argparse | |
| from tqdm import tqdm | |
| from torch.utils.data import DataLoader | |
| import nibabel as nib | |
| from monai.visualize.gradient_based import SmoothGrad, GuidedBackpropSmoothGrad | |
| from dataset2 import MedicalImageDatasetBalancedIntensity3D | |
| from load_brainiac import load_brainiac | |
# Seed every random number generator this script touches (Python, NumPy,
# PyTorch) with the same fixed value so repeated runs are reproducible.
seed = 42
for _seed_rng in (random.seed, np.random.seed, torch.manual_seed):
    _seed_rng(seed)
# Collate function (unnecessary for single-timepoint input)
def custom_collate(batch):
    """Collate dataset samples into a single batch dictionary.

    Each sample is a mapping with the keys ``'image'``, ``'label'`` and
    ``'pat_id'``. An image whose leading dimension is shorter than
    ``max_len`` (fixed at 1 for single-scan input) is zero-padded along
    that dimension; every other image passes through unchanged. The images
    are then stacked along a new leading batch dimension.

    Args:
        batch: Sequence of sample dicts.

    Returns:
        A dict with ``"image"`` (the stacked tensor), ``"label"`` (list of
        labels) and ``"pat_id"`` (list of patient ids), all in the order of
        ``batch``.

    Raises:
        RuntimeError: Raised by ``torch.stack`` when ``batch`` is empty or
            the images do not share a common shape after padding.
    """
    images = [item['image'] for item in batch]
    labels = [item['label'] for item in batch]
    patids = [item['pat_id'] for item in batch]

    max_len = 1  # single-scan input

    padded_images = []
    for img in images:
        pad_size = max_len - img.shape[0]
        if pad_size > 0:
            # new_zeros creates the padding with the image's own dtype and
            # device, so the concatenation neither promotes the dtype nor
            # fails on a device mismatch.
            padding = img.new_zeros((pad_size, *img.shape[1:]))
            img = torch.cat([img, padding], dim=0)
        padded_images.append(img)

    return {
        "image": torch.stack(padded_images, dim=0),
        "label": labels,
        "pat_id": patids,
    }
def generate_saliency_maps(model, data_loader, output_dir, device):
    """Write a guided-backprop SmoothGrad saliency map for every batch.

    The model is switched to eval mode and its ``backbone`` is wrapped in a
    ``GuidedBackpropSmoothGrad`` visualizer. For each batch the input image
    and the resulting saliency map are saved into ``output_dir`` as
    ``<pat_id>_image.nii.gz`` and ``<pat_id>_saliencymap.nii.gz``, both with
    an identity affine.

    Args:
        model: Model exposing a ``backbone`` attribute.
        data_loader: Iterable yielding dicts with ``'image'`` and
            ``"pat_id"`` entries.
        output_dir: Directory the NIfTI files are written into.
        device: Device the input tensor is moved to before attribution.

    Note:
        Only the first entry of ``"pat_id"`` is used to name the files.
    """
    model.eval()
    smooth_grad = GuidedBackpropSmoothGrad(
        model=model.backbone,
        stdev_spread=0.15,
        n_samples=10,
        magnitude=True,
    )

    def as_volume(tensor):
        # Drop singleton dimensions and convert to a NumPy array on the CPU.
        return tensor.squeeze().cpu().detach().numpy()

    for batch in tqdm(data_loader, desc="Generating saliency maps"):
        # Gradients with respect to the input are what the visualizer needs.
        input_tensor = batch['image'].requires_grad_(True).to(device)
        imagename = batch["pat_id"][0]

        with torch.enable_grad():
            saliency_map = smooth_grad(input_tensor)

        # Build both NIfTI volumes first, then write them out.
        outputs = [
            (
                f"{imagename}_image.nii.gz",
                nib.Nifti1Image(as_volume(input_tensor), np.eye(4)),
            ),
            (
                f"{imagename}_saliencymap.nii.gz",
                nib.Nifti1Image(as_volume(saliency_map), np.eye(4)),
            ),
        ]
        for filename, nifti in outputs:
            nib.save(nifti, os.path.join(output_dir, filename))
def main():
    """Command-line entry point: generate saliency maps on the CPU.

    Parses the four required path arguments, creates the output directory,
    builds the dataset and a single-sample data loader, loads the model
    from the checkpoint, and runs ``generate_saliency_maps`` over the data.
    """
    parser = argparse.ArgumentParser(description='Generate saliency maps for medical images')
    # All options are required string paths; register them from one table.
    required_paths = (
        ('--checkpoint', 'Path to the model checkpoint'),
        ('--input_csv', 'Path to the input CSV file containing image paths'),
        ('--output_dir', 'Directory to save saliency maps'),
        ('--root_dir', 'Root directory containing the image data'),
    )
    for flag, help_text in required_paths:
        parser.add_argument(flag, type=str, required=True, help=help_text)
    args = parser.parse_args()

    device = torch.device("cpu")

    # Make sure the destination exists before anything is written to it.
    os.makedirs(args.output_dir, exist_ok=True)

    dataset = MedicalImageDatasetBalancedIntensity3D(
        csv_path=args.input_csv,
        root_dir=args.root_dir,
    )
    loader_options = {
        "batch_size": 1,
        "shuffle": False,
        "collate_fn": custom_collate,
        "num_workers": 1,
    }
    dataloader = DataLoader(dataset, **loader_options)

    # Load the model and explicitly place it, and its backbone, on the CPU.
    model = load_brainiac(args.checkpoint, device).to(device)
    model.backbone = model.backbone.to(device)

    generate_saliency_maps(model, dataloader, args.output_dir, device)
    print(f"Saliency maps generated and saved to {args.output_dir}")
# Run the command-line entry point only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == "__main__":
    main()