Divyanshu Tak commited on
Commit
65bee5d
·
1 Parent(s): 20cb642

Add BrainIAC IDH Classification app with Vision Transformer model

Browse files
.gitignore ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ __pycache__/
2
+ *.pyc
3
+ .DS_Store
Dockerfile ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Use an official Python runtime as a parent image
2
+ FROM python:3.10-slim
3
+
4
+ # Set the working directory in the container
5
+ WORKDIR /app
6
+
7
+ # Install necessary system dependencies (kept minimal)
8
+ RUN apt-get update && \
9
+ apt-get install -y --no-install-recommends \
10
+ git \
11
+ libgl1 \
12
+ libglib2.0-0 \
13
+ && rm -rf /var/lib/apt/lists/*
14
+
15
+ # Copy the requirements file first to leverage Docker cache
16
+ COPY requirements.txt ./
17
+
18
+ # Install Python packages
19
+ RUN pip install --no-cache-dir -r requirements.txt
20
+
21
+ # Copy the entire project
22
+ COPY . /app/
23
+
24
+ # Create a non-root user (HF Spaces requirement)
25
+ RUN useradd -m -u 1000 user
26
+ USER user
27
+
28
+ # Make sure the user owns the app directory
29
+ COPY --chown=user:user . /app/
30
+
31
+ # Expose Gradio default port
32
+ EXPOSE 7860
33
+
34
+ ENV PYTHONUNBUFFERED=1
35
+
36
+ # Run the app from the src/IDH directory
37
+ CMD ["python", "src/IDH/app_gradio.py"]
README.md CHANGED
@@ -1,11 +1,29 @@
1
  ---
2
- title: IDH Classification BrainIAC
3
- emoji: 🐠
4
- colorFrom: purple
5
- colorTo: red
6
  sdk: docker
7
  pinned: false
8
- license: cc-by-4.0
9
  ---
10
 
11
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  ---
2
+ title: BrainIAC IDH Classification
3
+ emoji: 🧠
4
+ colorFrom: blue
5
+ colorTo: purple
6
  sdk: docker
7
  pinned: false
8
+ license: mit
9
  ---
10
 
11
+ # BrainIAC: IDH Classification
12
+
13
+ A Vision Transformer model for predicting IDH mutation status from dual MRI sequences (FLAIR + T1c).
14
+
15
+ ## Features
16
+ - Upload FLAIR and T1c MRI NIfTI files
17
+ - Optional preprocessing (registration + enhancement + skull stripping)
18
+ - Vision Transformer attention map visualization
19
+ - Interactive slice-by-slice attention viewing
20
+ - Real-time IDH mutation prediction with confidence scores
21
+
22
+ ## Usage
23
+ 1. Upload FLAIR and T1c MRI scans (.nii or .nii.gz)
24
+ 2. Optionally enable preprocessing
25
+ 3. Enable saliency maps for attention visualization
26
+ 4. Adjust prediction threshold
27
+ 5. View results and attention maps
28
+
29
+ *Research use only*
requirements.txt ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ monai==1.3.2
2
+ nibabel==5.2.1
3
+ numpy==1.23.5
4
+ pydicom
5
+ PyYAML
6
+ pytorch-lightning==2.3.3
7
+ scipy==1.10.1
8
+ SimpleITK==2.4.0
9
+ torch==2.6.0
10
+ tqdm
11
+ gradio
12
+ pandas
13
+ scikit-image==0.21.0
14
+ opencv-python
15
+ itk-elastix
16
+ dicom2nifti
17
+ einops
18
+ matplotlib
src/IDH/HD_BET/HD_BET/config.py ADDED
@@ -0,0 +1,121 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import torch
3
+ from HD_BET.utils import SetNetworkToVal, softmax_helper
4
+ from abc import abstractmethod
5
+ from HD_BET.network_architecture import Network
6
+
7
+
8
+ class BaseConfig(object):
9
+ def __init__(self):
10
+ pass
11
+
12
+ @abstractmethod
13
+ def get_split(self, fold, random_state=12345):
14
+ pass
15
+
16
+ @abstractmethod
17
+ def get_network(self, mode="train"):
18
+ pass
19
+
20
+ @abstractmethod
21
+ def get_basic_generators(self, fold):
22
+ pass
23
+
24
+ @abstractmethod
25
+ def get_data_generators(self, fold):
26
+ pass
27
+
28
+ def preprocess(self, data):
29
+ return data
30
+
31
+ def __repr__(self):
32
+ res = ""
33
+ for v in vars(self):
34
+ if not v.startswith("__") and not v.startswith("_") and v != 'dataset':
35
+ res += (v + ": " + str(self.__getattribute__(v)) + "\n")
36
+ return res
37
+
38
+
39
+ class HD_BET_Config(BaseConfig):
40
+ def __init__(self):
41
+ super(HD_BET_Config, self).__init__()
42
+
43
+ self.EXPERIMENT_NAME = self.__class__.__name__ # just a generic experiment name
44
+
45
+ # network parameters
46
+ self.net_base_num_layers = 21
47
+ self.BATCH_SIZE = 2
48
+ self.net_do_DS = True
49
+ self.net_dropout_p = 0.0
50
+ self.net_use_inst_norm = True
51
+ self.net_conv_use_bias = True
52
+ self.net_norm_use_affine = True
53
+ self.net_leaky_relu_slope = 1e-1
54
+
55
+ # hyperparameters
56
+ self.INPUT_PATCH_SIZE = (128, 128, 128)
57
+ self.num_classes = 2
58
+ self.selected_data_channels = range(1)
59
+
60
+ # data augmentation
61
+ self.da_mirror_axes = (2, 3, 4)
62
+
63
+ # validation
64
+ self.val_use_DO = False
65
+ self.val_use_train_mode = False # for dropout sampling
66
+ self.val_num_repeats = 1 # only useful if dropout sampling
67
+ self.val_batch_size = 1 # only useful if dropout sampling
68
+ self.val_save_npz = True
69
+ self.val_do_mirroring = True # test time data augmentation via mirroring
70
+ self.val_write_images = True
71
+ self.net_input_must_be_divisible_by = 16 # we could make a network class that has this as a property
72
+ self.val_min_size = self.INPUT_PATCH_SIZE
73
+ self.val_fn = None
74
+
75
+ # CAREFUL! THIS IS A HACK TO MAKE PYTORCH 0.3 STATE DICTS COMPATIBLE WITH PYTORCH 0.4 (setting keep_runnings_
76
+ # stats=True but not using them in validation. keep_runnings_stats was True before 0.3 but unused and defaults
77
+ # to false in 0.4)
78
+ self.val_use_moving_averages = False
79
+
80
+ def get_network(self, train=True, pretrained_weights=None):
81
+ net = Network(self.num_classes, len(self.selected_data_channels), self.net_base_num_layers,
82
+ self.net_dropout_p, softmax_helper, self.net_leaky_relu_slope, self.net_conv_use_bias,
83
+ self.net_norm_use_affine, True, self.net_do_DS)
84
+
85
+ if pretrained_weights is not None:
86
+ net.load_state_dict(
87
+ torch.load(pretrained_weights, map_location=lambda storage, loc: storage))
88
+
89
+ if train:
90
+ net.train(True)
91
+ else:
92
+ net.train(False)
93
+ net.apply(SetNetworkToVal(self.val_use_DO, self.val_use_moving_averages))
94
+ net.do_ds = False
95
+
96
+ optimizer = None
97
+ self.lr_scheduler = None
98
+ return net, optimizer
99
+
100
+ def get_data_generators(self, fold):
101
+ pass
102
+
103
+ def get_split(self, fold, random_state=12345):
104
+ pass
105
+
106
+ def get_basic_generators(self, fold):
107
+ pass
108
+
109
+ def on_epoch_end(self, epoch):
110
+ pass
111
+
112
+ def preprocess(self, data):
113
+ data = np.copy(data)
114
+ for c in range(data.shape[0]):
115
+ data[c] -= data[c].mean()
116
+ data[c] /= data[c].std()
117
+ return data
118
+
119
+
120
+ config = HD_BET_Config
121
+
src/IDH/HD_BET/HD_BET/data_loading.py ADDED
@@ -0,0 +1,121 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import SimpleITK as sitk
2
+ import numpy as np
3
+ from skimage.transform import resize
4
+
5
+
6
+ def resize_image(image, old_spacing, new_spacing, order=3):
7
+ new_shape = (int(np.round(old_spacing[0]/new_spacing[0]*float(image.shape[0]))),
8
+ int(np.round(old_spacing[1]/new_spacing[1]*float(image.shape[1]))),
9
+ int(np.round(old_spacing[2]/new_spacing[2]*float(image.shape[2]))))
10
+ return resize(image, new_shape, order=order, mode='edge', cval=0, anti_aliasing=False)
11
+
12
+
13
+ def preprocess_image(itk_image, is_seg=False, spacing_target=(1, 0.5, 0.5)):
14
+ spacing = np.array(itk_image.GetSpacing())[[2, 1, 0]]
15
+ image = sitk.GetArrayFromImage(itk_image).astype(float)
16
+
17
+ assert len(image.shape) == 3, "The image has unsupported number of dimensions. Only 3D images are allowed"
18
+
19
+ if not is_seg:
20
+ if np.any([[i != j] for i, j in zip(spacing, spacing_target)]):
21
+ image = resize_image(image, spacing, spacing_target).astype(np.float32)
22
+
23
+ image -= image.mean()
24
+ image /= image.std()
25
+ else:
26
+ new_shape = (int(np.round(spacing[0] / spacing_target[0] * float(image.shape[0]))),
27
+ int(np.round(spacing[1] / spacing_target[1] * float(image.shape[1]))),
28
+ int(np.round(spacing[2] / spacing_target[2] * float(image.shape[2]))))
29
+ image = resize_segmentation(image, new_shape, 1)
30
+ return image
31
+
32
+
33
+ def load_and_preprocess(mri_file):
34
+ images = {}
35
+ # t1
36
+ images["T1"] = sitk.ReadImage(mri_file)
37
+
38
+ properties_dict = {
39
+ "spacing": images["T1"].GetSpacing(),
40
+ "direction": images["T1"].GetDirection(),
41
+ "size": images["T1"].GetSize(),
42
+ "origin": images["T1"].GetOrigin()
43
+ }
44
+
45
+ for k in images.keys():
46
+ images[k] = preprocess_image(images[k], is_seg=False, spacing_target=(1.5, 1.5, 1.5))
47
+
48
+ properties_dict['size_before_cropping'] = images["T1"].shape
49
+
50
+ imgs = []
51
+ for seq in ['T1']:
52
+ imgs.append(images[seq][None])
53
+ all_data = np.vstack(imgs)
54
+ print("image shape after preprocessing: ", str(all_data[0].shape))
55
+ return all_data, properties_dict
56
+
57
+
58
+ def save_segmentation_nifti(segmentation, dct, out_fname, order=1):
59
+ '''
60
+ segmentation must have the same spacing as the original nifti (for now). segmentation may have been cropped out
61
+ of the original image
62
+
63
+ dct:
64
+ size_before_cropping
65
+ brain_bbox
66
+ size -> this is the original size of the dataset, if the image was not resampled, this is the same as size_before_cropping
67
+ spacing
68
+ origin
69
+ direction
70
+
71
+ :param segmentation:
72
+ :param dct:
73
+ :param out_fname:
74
+ :return:
75
+ '''
76
+ old_size = dct.get('size_before_cropping')
77
+ bbox = dct.get('brain_bbox')
78
+ if bbox is not None:
79
+ seg_old_size = np.zeros(old_size)
80
+ for c in range(3):
81
+ bbox[c][1] = np.min((bbox[c][0] + segmentation.shape[c], old_size[c]))
82
+ seg_old_size[bbox[0][0]:bbox[0][1],
83
+ bbox[1][0]:bbox[1][1],
84
+ bbox[2][0]:bbox[2][1]] = segmentation
85
+ else:
86
+ seg_old_size = segmentation
87
+ if np.any(np.array(seg_old_size) != np.array(dct['size'])[[2, 1, 0]]):
88
+ seg_old_spacing = resize_segmentation(seg_old_size, np.array(dct['size'])[[2, 1, 0]], order=order)
89
+ else:
90
+ seg_old_spacing = seg_old_size
91
+ seg_resized_itk = sitk.GetImageFromArray(seg_old_spacing.astype(np.int32))
92
+ seg_resized_itk.SetSpacing(np.array(dct['spacing'])[[0, 1, 2]])
93
+ seg_resized_itk.SetOrigin(dct['origin'])
94
+ seg_resized_itk.SetDirection(dct['direction'])
95
+ sitk.WriteImage(seg_resized_itk, out_fname)
96
+
97
+
98
+ def resize_segmentation(segmentation, new_shape, order=3, cval=0):
99
+ '''
100
+ Taken from batchgenerators (https://github.com/MIC-DKFZ/batchgenerators) to prevent dependency
101
+
102
+ Resizes a segmentation map. Supports all orders (see skimage documentation). Will transform segmentation map to one
103
+ hot encoding which is resized and transformed back to a segmentation map.
104
+ This prevents interpolation artifacts ([0, 0, 2] -> [0, 1, 2])
105
+ :param segmentation:
106
+ :param new_shape:
107
+ :param order:
108
+ :return:
109
+ '''
110
+ tpe = segmentation.dtype
111
+ unique_labels = np.unique(segmentation)
112
+ assert len(segmentation.shape) == len(new_shape), "new shape must have same dimensionality as segmentation"
113
+ if order == 0:
114
+ return resize(segmentation, new_shape, order, mode="constant", cval=cval, clip=True, anti_aliasing=False).astype(tpe)
115
+ else:
116
+ reshaped = np.zeros(new_shape, dtype=segmentation.dtype)
117
+
118
+ for i, c in enumerate(unique_labels):
119
+ reshaped_multihot = resize((segmentation == c).astype(float), new_shape, order, mode="edge", clip=True, anti_aliasing=False)
120
+ reshaped[reshaped_multihot >= 0.5] = c
121
+ return reshaped
src/IDH/HD_BET/HD_BET/hd_bet.py ADDED
@@ -0,0 +1,119 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+
3
+ import os
4
+ import sys
5
+ sys.path.append("/mnt/93E8-0534/AIDAN/HDBET/")
6
+ from HD_BET.run import run_hd_bet
7
+ from HD_BET.utils import maybe_mkdir_p, subfiles
8
+ import HD_BET
9
+
10
+ def hd_bet(input_file_or_dir,output_file_or_dir,mode,device,tta,pp=1,save_mask=0,overwrite_existing=1):
11
+
12
+ if output_file_or_dir is None:
13
+ output_file_or_dir = os.path.join(os.path.dirname(input_file_or_dir),
14
+ os.path.basename(input_file_or_dir).split(".")[0] + "_bet")
15
+
16
+
17
+ params_file = os.path.join(HD_BET.__path__[0], "model_final.py")
18
+ config_file = os.path.join(HD_BET.__path__[0], "config.py")
19
+
20
+ assert os.path.abspath(input_file_or_dir) != os.path.abspath(output_file_or_dir), "output must be different from input"
21
+
22
+ if device == 'cpu':
23
+ pass
24
+ else:
25
+ device = int(device)
26
+
27
+ if os.path.isdir(input_file_or_dir):
28
+ maybe_mkdir_p(output_file_or_dir)
29
+ input_files = subfiles(input_file_or_dir, suffix='_0000.nii.gz', join=False)
30
+
31
+ if len(input_files) == 0:
32
+ raise RuntimeError("input is a folder but no nifti files (.nii.gz) were found in here")
33
+
34
+ output_files = [os.path.join(output_file_or_dir, i) for i in input_files]
35
+ input_files = [os.path.join(input_file_or_dir, i) for i in input_files]
36
+ else:
37
+ if not output_file_or_dir.endswith('.nii.gz'):
38
+ output_file_or_dir += '.nii.gz'
39
+ assert os.path.abspath(input_file_or_dir) != os.path.abspath(output_file_or_dir), "output must be different from input"
40
+
41
+ output_files = [output_file_or_dir]
42
+ input_files = [input_file_or_dir]
43
+
44
+ if tta == 0:
45
+ tta = False
46
+ elif tta == 1:
47
+ tta = True
48
+ else:
49
+ raise ValueError("Unknown value for tta: %s. Expected: 0 or 1" % str(tta))
50
+
51
+ if overwrite_existing == 0:
52
+ overwrite_existing = False
53
+ elif overwrite_existing == 1:
54
+ overwrite_existing = True
55
+ else:
56
+ raise ValueError("Unknown value for overwrite_existing: %s. Expected: 0 or 1" % str(overwrite_existing))
57
+
58
+ if pp == 0:
59
+ pp = False
60
+ elif pp == 1:
61
+ pp = True
62
+ else:
63
+ raise ValueError("Unknown value for pp: %s. Expected: 0 or 1" % str(pp))
64
+
65
+ if save_mask == 0:
66
+ save_mask = False
67
+ elif save_mask == 1:
68
+ save_mask = True
69
+ else:
70
+ raise ValueError("Unknown value for pp: %s. Expected: 0 or 1" % str(pp))
71
+
72
+ run_hd_bet(input_files, output_files, mode, config_file, device, pp, tta, save_mask, overwrite_existing)
73
+
74
+
75
+ if __name__ == "__main__":
76
+ print("\n########################")
77
+ print("If you are using hd-bet, please cite the following paper:")
78
+ print("Isensee F, Schell M, Tursunova I, Brugnara G, Bonekamp D, Neuberger U, Wick A, Schlemmer HP, Heiland S, Wick W,"
79
+ "Bendszus M, Maier-Hein KH, Kickingereder P. Automated brain extraction of multi-sequence MRI using artificial"
80
+ "neural networks. arXiv preprint arXiv:1901.11341, 2019.")
81
+ print("########################\n")
82
+
83
+ import argparse
84
+ parser = argparse.ArgumentParser()
85
+ parser.add_argument('-i', '--input', help='input. Can be either a single file name or an input folder. If file: must be '
86
+ 'nifti (.nii.gz) and can only be 3D. No support for 4d images, use fslsplit to '
87
+ 'split 4d sequences into 3d images. If folder: all files ending with .nii.gz '
88
+ 'within that folder will be brain extracted.', required=True, type=str)
89
+ parser.add_argument('-o', '--output', help='output. Can be either a filename or a folder. If it does not exist, the folder'
90
+ ' will be created', required=False, type=str)
91
+ parser.add_argument('-mode', type=str, default='accurate', help='can be either \'fast\' or \'accurate\'. Fast will '
92
+ 'use only one set of parameters whereas accurate will '
93
+ 'use the five sets of parameters that resulted from '
94
+ 'our cross-validation as an ensemble. Default: '
95
+ 'accurate',
96
+ required=False)
97
+ parser.add_argument('-device', default='0', type=str, help='used to set on which device the prediction will run. '
98
+ 'Must be either int or str. Use int for GPU id or '
99
+ '\'cpu\' to run on CPU. When using CPU you should '
100
+ 'consider disabling tta. Default for -device is: 0',
101
+ required=False)
102
+ parser.add_argument('-tta', default=1, required=False, type=int, help='whether to use test time data augmentation '
103
+ '(mirroring). 1= True, 0=False. Disable this '
104
+ 'if you are using CPU to speed things up! '
105
+ 'Default: 1')
106
+ parser.add_argument('-pp', default=1, type=int, required=False, help='set to 0 to disabe postprocessing (remove all'
107
+ ' but the largest connected component in '
108
+ 'the prediction. Default: 1')
109
+ parser.add_argument('-s', '--save_mask', default=1, type=int, required=False, help='if set to 0 the segmentation '
110
+ 'mask will not be '
111
+ 'saved')
112
+ parser.add_argument('--overwrite_existing', default=1, type=int, required=False, help="set this to 0 if you don't "
113
+ "want to overwrite existing "
114
+ "predictions")
115
+
116
+ args = parser.parse_args()
117
+
118
+ hd_bet(args.input,args.output,args.mode,args.device,args.tta,args.pp,args.save_mask,args.overwrite_existing)
119
+
src/IDH/HD_BET/HD_BET/network_architecture.py ADDED
@@ -0,0 +1,213 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from HD_BET.utils import softmax_helper
5
+
6
+
7
+ class EncodingModule(nn.Module):
8
+ def __init__(self, in_channels, out_channels, filter_size=3, dropout_p=0.3, leakiness=1e-2, conv_bias=True,
9
+ inst_norm_affine=True, lrelu_inplace=True):
10
+ nn.Module.__init__(self)
11
+ self.dropout_p = dropout_p
12
+ self.lrelu_inplace = lrelu_inplace
13
+ self.inst_norm_affine = inst_norm_affine
14
+ self.conv_bias = conv_bias
15
+ self.leakiness = leakiness
16
+ self.bn_1 = nn.InstanceNorm3d(in_channels, affine=self.inst_norm_affine, track_running_stats=True)
17
+ self.conv1 = nn.Conv3d(in_channels, out_channels, filter_size, 1, (filter_size - 1) // 2, bias=self.conv_bias)
18
+ self.dropout = nn.Dropout3d(dropout_p)
19
+ self.bn_2 = nn.InstanceNorm3d(in_channels, affine=self.inst_norm_affine, track_running_stats=True)
20
+ self.conv2 = nn.Conv3d(out_channels, out_channels, filter_size, 1, (filter_size - 1) // 2, bias=self.conv_bias)
21
+
22
+ def forward(self, x):
23
+ skip = x
24
+ x = F.leaky_relu(self.bn_1(x), negative_slope=self.leakiness, inplace=self.lrelu_inplace)
25
+ x = self.conv1(x)
26
+ if self.dropout_p is not None and self.dropout_p > 0:
27
+ x = self.dropout(x)
28
+ x = F.leaky_relu(self.bn_2(x), negative_slope=self.leakiness, inplace=self.lrelu_inplace)
29
+ x = self.conv2(x)
30
+ x = x + skip
31
+ return x
32
+
33
+
34
+ class Upsample(nn.Module):
35
+ def __init__(self, size=None, scale_factor=None, mode='nearest', align_corners=True):
36
+ super(Upsample, self).__init__()
37
+ self.align_corners = align_corners
38
+ self.mode = mode
39
+ self.scale_factor = scale_factor
40
+ self.size = size
41
+
42
+ def forward(self, x):
43
+ return nn.functional.interpolate(x, size=self.size, scale_factor=self.scale_factor, mode=self.mode,
44
+ align_corners=self.align_corners)
45
+
46
+
47
+ class LocalizationModule(nn.Module):
48
+ def __init__(self, in_channels, out_channels, leakiness=1e-2, conv_bias=True, inst_norm_affine=True,
49
+ lrelu_inplace=True):
50
+ nn.Module.__init__(self)
51
+ self.lrelu_inplace = lrelu_inplace
52
+ self.inst_norm_affine = inst_norm_affine
53
+ self.conv_bias = conv_bias
54
+ self.leakiness = leakiness
55
+ self.conv1 = nn.Conv3d(in_channels, in_channels, 3, 1, 1, bias=self.conv_bias)
56
+ self.bn_1 = nn.InstanceNorm3d(in_channels, affine=self.inst_norm_affine, track_running_stats=True)
57
+ self.conv2 = nn.Conv3d(in_channels, out_channels, 1, 1, 0, bias=self.conv_bias)
58
+ self.bn_2 = nn.InstanceNorm3d(out_channels, affine=self.inst_norm_affine, track_running_stats=True)
59
+
60
+ def forward(self, x):
61
+ x = F.leaky_relu(self.bn_1(self.conv1(x)), negative_slope=self.leakiness, inplace=self.lrelu_inplace)
62
+ x = F.leaky_relu(self.bn_2(self.conv2(x)), negative_slope=self.leakiness, inplace=self.lrelu_inplace)
63
+ return x
64
+
65
+
66
+ class UpsamplingModule(nn.Module):
67
+ def __init__(self, in_channels, out_channels, leakiness=1e-2, conv_bias=True, inst_norm_affine=True,
68
+ lrelu_inplace=True):
69
+ nn.Module.__init__(self)
70
+ self.lrelu_inplace = lrelu_inplace
71
+ self.inst_norm_affine = inst_norm_affine
72
+ self.conv_bias = conv_bias
73
+ self.leakiness = leakiness
74
+ self.upsample = Upsample(scale_factor=2, mode="trilinear", align_corners=True)
75
+ self.upsample_conv = nn.Conv3d(in_channels, out_channels, 3, 1, 1, bias=self.conv_bias)
76
+ self.bn = nn.InstanceNorm3d(out_channels, affine=self.inst_norm_affine, track_running_stats=True)
77
+
78
+ def forward(self, x):
79
+ x = F.leaky_relu(self.bn(self.upsample_conv(self.upsample(x))), negative_slope=self.leakiness,
80
+ inplace=self.lrelu_inplace)
81
+ return x
82
+
83
+
84
+ class DownsamplingModule(nn.Module):
85
+ def __init__(self, in_channels, out_channels, leakiness=1e-2, conv_bias=True, inst_norm_affine=True,
86
+ lrelu_inplace=True):
87
+ nn.Module.__init__(self)
88
+ self.lrelu_inplace = lrelu_inplace
89
+ self.inst_norm_affine = inst_norm_affine
90
+ self.conv_bias = conv_bias
91
+ self.leakiness = leakiness
92
+ self.bn = nn.InstanceNorm3d(in_channels, affine=self.inst_norm_affine, track_running_stats=True)
93
+ self.downsample = nn.Conv3d(in_channels, out_channels, 3, 2, 1, bias=self.conv_bias)
94
+
95
+ def forward(self, x):
96
+ x = F.leaky_relu(self.bn(x), negative_slope=self.leakiness, inplace=self.lrelu_inplace)
97
+ b = self.downsample(x)
98
+ return x, b
99
+
100
+
101
+ class Network(nn.Module):
102
+ def __init__(self, num_classes=4, num_input_channels=4, base_filters=16, dropout_p=0.3,
103
+ final_nonlin=softmax_helper, leakiness=1e-2, conv_bias=True, inst_norm_affine=True,
104
+ lrelu_inplace=True, do_ds=True):
105
+ super(Network, self).__init__()
106
+
107
+ self.do_ds = do_ds
108
+ self.lrelu_inplace = lrelu_inplace
109
+ self.inst_norm_affine = inst_norm_affine
110
+ self.conv_bias = conv_bias
111
+ self.leakiness = leakiness
112
+ self.final_nonlin = final_nonlin
113
+ self.init_conv = nn.Conv3d(num_input_channels, base_filters, 3, 1, 1, bias=self.conv_bias)
114
+
115
+ self.context1 = EncodingModule(base_filters, base_filters, 3, dropout_p, leakiness=1e-2, conv_bias=True,
116
+ inst_norm_affine=True, lrelu_inplace=True)
117
+ self.down1 = DownsamplingModule(base_filters, base_filters * 2, leakiness=1e-2, conv_bias=True,
118
+ inst_norm_affine=True, lrelu_inplace=True)
119
+
120
+ self.context2 = EncodingModule(2 * base_filters, 2 * base_filters, 3, dropout_p, leakiness=1e-2, conv_bias=True,
121
+ inst_norm_affine=True, lrelu_inplace=True)
122
+ self.down2 = DownsamplingModule(2 * base_filters, base_filters * 4, leakiness=1e-2, conv_bias=True,
123
+ inst_norm_affine=True, lrelu_inplace=True)
124
+
125
+ self.context3 = EncodingModule(4 * base_filters, 4 * base_filters, 3, dropout_p, leakiness=1e-2, conv_bias=True,
126
+ inst_norm_affine=True, lrelu_inplace=True)
127
+ self.down3 = DownsamplingModule(4 * base_filters, base_filters * 8, leakiness=1e-2, conv_bias=True,
128
+ inst_norm_affine=True, lrelu_inplace=True)
129
+
130
+ self.context4 = EncodingModule(8 * base_filters, 8 * base_filters, 3, dropout_p, leakiness=1e-2, conv_bias=True,
131
+ inst_norm_affine=True, lrelu_inplace=True)
132
+ self.down4 = DownsamplingModule(8 * base_filters, base_filters * 16, leakiness=1e-2, conv_bias=True,
133
+ inst_norm_affine=True, lrelu_inplace=True)
134
+
135
+ self.context5 = EncodingModule(16 * base_filters, 16 * base_filters, 3, dropout_p, leakiness=1e-2,
136
+ conv_bias=True, inst_norm_affine=True, lrelu_inplace=True)
137
+
138
+ self.bn_after_context5 = nn.InstanceNorm3d(16 * base_filters, affine=self.inst_norm_affine, track_running_stats=True)
139
+ self.up1 = UpsamplingModule(16 * base_filters, 8 * base_filters, leakiness=1e-2, conv_bias=True,
140
+ inst_norm_affine=True, lrelu_inplace=True)
141
+
142
+ self.loc1 = LocalizationModule(16 * base_filters, 8 * base_filters, leakiness=1e-2, conv_bias=True,
143
+ inst_norm_affine=True, lrelu_inplace=True)
144
+ self.up2 = UpsamplingModule(8 * base_filters, 4 * base_filters, leakiness=1e-2, conv_bias=True,
145
+ inst_norm_affine=True, lrelu_inplace=True)
146
+
147
+ self.loc2 = LocalizationModule(8 * base_filters, 4 * base_filters, leakiness=1e-2, conv_bias=True,
148
+ inst_norm_affine=True, lrelu_inplace=True)
149
+ self.loc2_seg = nn.Conv3d(4 * base_filters, num_classes, 1, 1, 0, bias=False)
150
+ self.up3 = UpsamplingModule(4 * base_filters, 2 * base_filters, leakiness=1e-2, conv_bias=True,
151
+ inst_norm_affine=True, lrelu_inplace=True)
152
+
153
+ self.loc3 = LocalizationModule(4 * base_filters, 2 * base_filters, leakiness=1e-2, conv_bias=True,
154
+ inst_norm_affine=True, lrelu_inplace=True)
155
+ self.loc3_seg = nn.Conv3d(2 * base_filters, num_classes, 1, 1, 0, bias=False)
156
+ self.up4 = UpsamplingModule(2 * base_filters, 1 * base_filters, leakiness=1e-2, conv_bias=True,
157
+ inst_norm_affine=True, lrelu_inplace=True)
158
+
159
+ self.end_conv_1 = nn.Conv3d(2 * base_filters, 2 * base_filters, 3, 1, 1, bias=self.conv_bias)
160
+ self.end_conv_1_bn = nn.InstanceNorm3d(2 * base_filters, affine=self.inst_norm_affine, track_running_stats=True)
161
+ self.end_conv_2 = nn.Conv3d(2 * base_filters, 2 * base_filters, 3, 1, 1, bias=self.conv_bias)
162
+ self.end_conv_2_bn = nn.InstanceNorm3d(2 * base_filters, affine=self.inst_norm_affine, track_running_stats=True)
163
+ self.seg_layer = nn.Conv3d(2 * base_filters, num_classes, 1, 1, 0, bias=False)
164
+
165
+ def forward(self, x):
166
+ seg_outputs = []
167
+
168
+ x = self.init_conv(x)
169
+ x = self.context1(x)
170
+
171
+ skip1, x = self.down1(x)
172
+ x = self.context2(x)
173
+
174
+ skip2, x = self.down2(x)
175
+ x = self.context3(x)
176
+
177
+ skip3, x = self.down3(x)
178
+ x = self.context4(x)
179
+
180
+ skip4, x = self.down4(x)
181
+ x = self.context5(x)
182
+
183
+ x = F.leaky_relu(self.bn_after_context5(x), negative_slope=self.leakiness, inplace=self.lrelu_inplace)
184
+ x = self.up1(x)
185
+
186
+ x = torch.cat((skip4, x), dim=1)
187
+ x = self.loc1(x)
188
+ x = self.up2(x)
189
+
190
+ x = torch.cat((skip3, x), dim=1)
191
+ x = self.loc2(x)
192
+ loc2_seg = self.final_nonlin(self.loc2_seg(x))
193
+ seg_outputs.append(loc2_seg)
194
+ x = self.up3(x)
195
+
196
+ x = torch.cat((skip2, x), dim=1)
197
+ x = self.loc3(x)
198
+ loc3_seg = self.final_nonlin(self.loc3_seg(x))
199
+ seg_outputs.append(loc3_seg)
200
+ x = self.up4(x)
201
+
202
+ x = torch.cat((skip1, x), dim=1)
203
+ x = F.leaky_relu(self.end_conv_1_bn(self.end_conv_1(x)), negative_slope=self.leakiness,
204
+ inplace=self.lrelu_inplace)
205
+ x = F.leaky_relu(self.end_conv_2_bn(self.end_conv_2(x)), negative_slope=self.leakiness,
206
+ inplace=self.lrelu_inplace)
207
+ x = self.final_nonlin(self.seg_layer(x))
208
+ seg_outputs.append(x)
209
+
210
+ if self.do_ds:
211
+ return seg_outputs[::-1]
212
+ else:
213
+ return seg_outputs[-1]
src/IDH/HD_BET/HD_BET/paths.py ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ # please refer to the readme on where to get the parameters. Save them in this folder:
4
+ # Original Path: "/media/sdb/divyanshu/divyanshu/aidan_segmentation/nnUNet_pLGG/home/divyanshu/hd-bet_params"
5
+ # Updated path for Docker container:
6
+ folder_with_parameter_files = "/app/IDH/hdbet_model"
src/IDH/HD_BET/HD_BET/predict_case.py ADDED
@@ -0,0 +1,126 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import numpy as np
3
+
4
+
5
+ def pad_patient_3D(patient, shape_must_be_divisible_by=16, min_size=None):
6
+ if not (isinstance(shape_must_be_divisible_by, list) or isinstance(shape_must_be_divisible_by, tuple)):
7
+ shape_must_be_divisible_by = [shape_must_be_divisible_by] * 3
8
+ shp = patient.shape
9
+ new_shp = [shp[0] + shape_must_be_divisible_by[0] - shp[0] % shape_must_be_divisible_by[0],
10
+ shp[1] + shape_must_be_divisible_by[1] - shp[1] % shape_must_be_divisible_by[1],
11
+ shp[2] + shape_must_be_divisible_by[2] - shp[2] % shape_must_be_divisible_by[2]]
12
+ for i in range(len(shp)):
13
+ if shp[i] % shape_must_be_divisible_by[i] == 0:
14
+ new_shp[i] -= shape_must_be_divisible_by[i]
15
+ if min_size is not None:
16
+ new_shp = np.max(np.vstack((np.array(new_shp), np.array(min_size))), 0)
17
+ return reshape_by_padding_upper_coords(patient, new_shp, 0), shp
18
+
19
+
20
+ def reshape_by_padding_upper_coords(image, new_shape, pad_value=None):
21
+ shape = tuple(list(image.shape))
22
+ new_shape = tuple(np.max(np.concatenate((shape, new_shape)).reshape((2,len(shape))), axis=0))
23
+ if pad_value is None:
24
+ if len(shape) == 2:
25
+ pad_value = image[0,0]
26
+ elif len(shape) == 3:
27
+ pad_value = image[0, 0, 0]
28
+ else:
29
+ raise ValueError("Image must be either 2 or 3 dimensional")
30
+ res = np.ones(list(new_shape), dtype=image.dtype) * pad_value
31
+ if len(shape) == 2:
32
+ res[0:0+int(shape[0]), 0:0+int(shape[1])] = image
33
+ elif len(shape) == 3:
34
+ res[0:0+int(shape[0]), 0:0+int(shape[1]), 0:0+int(shape[2])] = image
35
+ return res
36
+
37
+
38
+ def predict_case_3D_net(net, patient_data, do_mirroring, num_repeats, BATCH_SIZE=None,
39
+ new_shape_must_be_divisible_by=16, min_size=None, main_device=0, mirror_axes=(2, 3, 4)):
40
+ with torch.no_grad():
41
+ pad_res = []
42
+ for i in range(patient_data.shape[0]):
43
+ t, old_shape = pad_patient_3D(patient_data[i], new_shape_must_be_divisible_by, min_size)
44
+ pad_res.append(t[None])
45
+
46
+ patient_data = np.vstack(pad_res)
47
+
48
+ new_shp = patient_data.shape
49
+
50
+ data = np.zeros(tuple([1] + list(new_shp)), dtype=np.float32)
51
+
52
+ data[0] = patient_data
53
+
54
+ if BATCH_SIZE is not None:
55
+ data = np.vstack([data] * BATCH_SIZE)
56
+
57
+ a = torch.rand(data.shape).float()
58
+
59
+ if main_device == 'cpu':
60
+ pass
61
+ else:
62
+ a = a.cuda(main_device)
63
+
64
+ if do_mirroring:
65
+ x = 8
66
+ else:
67
+ x = 1
68
+ all_preds = []
69
+ for i in range(num_repeats):
70
+ for m in range(x):
71
+ data_for_net = np.array(data)
72
+ do_stuff = False
73
+ if m == 0:
74
+ do_stuff = True
75
+ pass
76
+ if m == 1 and (4 in mirror_axes):
77
+ do_stuff = True
78
+ data_for_net = data_for_net[:, :, :, :, ::-1]
79
+ if m == 2 and (3 in mirror_axes):
80
+ do_stuff = True
81
+ data_for_net = data_for_net[:, :, :, ::-1, :]
82
+ if m == 3 and (4 in mirror_axes) and (3 in mirror_axes):
83
+ do_stuff = True
84
+ data_for_net = data_for_net[:, :, :, ::-1, ::-1]
85
+ if m == 4 and (2 in mirror_axes):
86
+ do_stuff = True
87
+ data_for_net = data_for_net[:, :, ::-1, :, :]
88
+ if m == 5 and (2 in mirror_axes) and (4 in mirror_axes):
89
+ do_stuff = True
90
+ data_for_net = data_for_net[:, :, ::-1, :, ::-1]
91
+ if m == 6 and (2 in mirror_axes) and (3 in mirror_axes):
92
+ do_stuff = True
93
+ data_for_net = data_for_net[:, :, ::-1, ::-1, :]
94
+ if m == 7 and (2 in mirror_axes) and (3 in mirror_axes) and (4 in mirror_axes):
95
+ do_stuff = True
96
+ data_for_net = data_for_net[:, :, ::-1, ::-1, ::-1]
97
+
98
+ if do_stuff:
99
+ _ = a.data.copy_(torch.from_numpy(np.copy(data_for_net)))
100
+ p = net(a) # np.copy is necessary because ::-1 creates just a view i think
101
+ p = p.data.cpu().numpy()
102
+
103
+ if m == 0:
104
+ pass
105
+ if m == 1 and (4 in mirror_axes):
106
+ p = p[:, :, :, :, ::-1]
107
+ if m == 2 and (3 in mirror_axes):
108
+ p = p[:, :, :, ::-1, :]
109
+ if m == 3 and (4 in mirror_axes) and (3 in mirror_axes):
110
+ p = p[:, :, :, ::-1, ::-1]
111
+ if m == 4 and (2 in mirror_axes):
112
+ p = p[:, :, ::-1, :, :]
113
+ if m == 5 and (2 in mirror_axes) and (4 in mirror_axes):
114
+ p = p[:, :, ::-1, :, ::-1]
115
+ if m == 6 and (2 in mirror_axes) and (3 in mirror_axes):
116
+ p = p[:, :, ::-1, ::-1, :]
117
+ if m == 7 and (2 in mirror_axes) and (3 in mirror_axes) and (4 in mirror_axes):
118
+ p = p[:, :, ::-1, ::-1, ::-1]
119
+ all_preds.append(p)
120
+
121
+ stacked = np.vstack(all_preds)[:, :, :old_shape[0], :old_shape[1], :old_shape[2]]
122
+ predicted_segmentation = stacked.mean(0).argmax(0)
123
+ uncertainty = stacked.var(0)
124
+ bayesian_predictions = stacked
125
+ softmax_pred = stacked.mean(0)
126
+ return predicted_segmentation, bayesian_predictions, softmax_pred, uncertainty
src/IDH/HD_BET/HD_BET/run.py ADDED
@@ -0,0 +1,117 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import numpy as np
3
+ import SimpleITK as sitk
4
+ from HD_BET.data_loading import load_and_preprocess, save_segmentation_nifti
5
+ from HD_BET.predict_case import predict_case_3D_net
6
+ import imp
7
+ from HD_BET.utils import postprocess_prediction, SetNetworkToVal, get_params_fname, maybe_download_parameters
8
+ import os
9
+ import HD_BET
10
+
11
+
12
+ def apply_bet(img, bet, out_fname):
13
+ img_itk = sitk.ReadImage(img)
14
+ img_npy = sitk.GetArrayFromImage(img_itk)
15
+ img_bet = sitk.GetArrayFromImage(sitk.ReadImage(bet))
16
+ img_npy[img_bet == 0] = 0
17
+ out = sitk.GetImageFromArray(img_npy)
18
+ out.CopyInformation(img_itk)
19
+ sitk.WriteImage(out, out_fname)
20
+
21
+
22
+ def run_hd_bet(mri_fnames, output_fnames, mode="accurate", config_file=os.path.join(HD_BET.__path__[0], "config.py"), device=0,
23
+ postprocess=False, do_tta=True, keep_mask=True, overwrite=True):
24
+ """
25
+
26
+ :param mri_fnames: str or list/tuple of str
27
+ :param output_fnames: str or list/tuple of str. If list: must have the same length as output_fnames
28
+ :param mode: fast or accurate
29
+ :param config_file: config.py
30
+ :param device: either int (for device id) or 'cpu'
31
+ :param postprocess: whether to do postprocessing or not. Postprocessing here consists of simply discarding all
32
+ but the largest predicted connected component. Default False
33
+ :param do_tta: whether to do test time data augmentation by mirroring along all axes. Default: True. If you use
34
+ CPU you may want to turn that off to speed things up
35
+ :return:
36
+ """
37
+
38
+ list_of_param_files = []
39
+
40
+ if mode == 'fast':
41
+ params_file = get_params_fname(0)
42
+ maybe_download_parameters(0)
43
+
44
+ list_of_param_files.append(params_file)
45
+ elif mode == 'accurate':
46
+ for i in range(5):
47
+ params_file = get_params_fname(i)
48
+ maybe_download_parameters(i)
49
+
50
+ list_of_param_files.append(params_file)
51
+ else:
52
+ raise ValueError("Unknown value for mode: %s. Expected: fast or accurate" % mode)
53
+
54
+ assert all([os.path.isfile(i) for i in list_of_param_files]), "Could not find parameter files"
55
+
56
+ cf = imp.load_source('cf', config_file)
57
+ cf = cf.config()
58
+
59
+ net, _ = cf.get_network(cf.val_use_train_mode, None)
60
+ if device == "cpu":
61
+ net = net.cpu()
62
+ else:
63
+ net.cuda(device)
64
+
65
+ if not isinstance(mri_fnames, (list, tuple)):
66
+ mri_fnames = [mri_fnames]
67
+
68
+ if not isinstance(output_fnames, (list, tuple)):
69
+ output_fnames = [output_fnames]
70
+
71
+ assert len(mri_fnames) == len(output_fnames), "mri_fnames and output_fnames must have the same length"
72
+
73
+ params = []
74
+ for p in list_of_param_files:
75
+ params.append(torch.load(p, map_location=lambda storage, loc: storage))
76
+
77
+ for in_fname, out_fname in zip(mri_fnames, output_fnames):
78
+ mask_fname = out_fname[:-7] + "_mask.nii.gz"
79
+ if overwrite or (not (os.path.isfile(mask_fname) and keep_mask) or not os.path.isfile(out_fname)):
80
+ print("File:", in_fname)
81
+ print("preprocessing...")
82
+ try:
83
+ data, data_dict = load_and_preprocess(in_fname)
84
+ except RuntimeError:
85
+ print("\nERROR\nCould not read file", in_fname, "\n")
86
+ continue
87
+ except AssertionError as e:
88
+ print(e)
89
+ continue
90
+
91
+ softmax_preds = []
92
+
93
+ print("prediction (CNN id)...")
94
+ for i, p in enumerate(params):
95
+ print(i)
96
+ net.load_state_dict(p)
97
+ net.eval()
98
+ net.apply(SetNetworkToVal(False, False))
99
+ _, _, softmax_pred, _ = predict_case_3D_net(net, data, do_tta, cf.val_num_repeats,
100
+ cf.val_batch_size, cf.net_input_must_be_divisible_by,
101
+ cf.val_min_size, device, cf.da_mirror_axes)
102
+ softmax_preds.append(softmax_pred[None])
103
+
104
+ seg = np.argmax(np.vstack(softmax_preds).mean(0), 0)
105
+
106
+ if postprocess:
107
+ seg = postprocess_prediction(seg)
108
+
109
+ print("exporting segmentation...")
110
+ save_segmentation_nifti(seg, data_dict, mask_fname)
111
+
112
+ apply_bet(in_fname, mask_fname, out_fname)
113
+
114
+ if not keep_mask:
115
+ os.remove(mask_fname)
116
+
117
+
src/IDH/HD_BET/HD_BET/utils.py ADDED
@@ -0,0 +1,115 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from urllib.request import urlopen
2
+ import torch
3
+ from torch import nn
4
+ import numpy as np
5
+ from skimage.morphology import label
6
+ import os
7
+ from HD_BET.paths import folder_with_parameter_files
8
+
9
+
10
+ def get_params_fname(fold):
11
+ return os.path.join(folder_with_parameter_files, "%d.model" % fold)
12
+
13
+
14
+ def maybe_download_parameters(fold=0, force_overwrite=False):
15
+ """
16
+ Downloads the parameters for some fold if it is not present yet.
17
+ :param fold:
18
+ :param force_overwrite: if True the old parameter file will be deleted (if present) prior to download
19
+ :return:
20
+ """
21
+
22
+ assert 0 <= fold <= 4, "fold must be between 0 and 4"
23
+
24
+ if not os.path.isdir(folder_with_parameter_files):
25
+ maybe_mkdir_p(folder_with_parameter_files)
26
+
27
+ out_filename = get_params_fname(fold)
28
+
29
+ if force_overwrite and os.path.isfile(out_filename):
30
+ os.remove(out_filename)
31
+
32
+ if not os.path.isfile(out_filename):
33
+ url = "https://zenodo.org/record/2540695/files/%d.model?download=1" % fold
34
+ print("Downloading", url, "...")
35
+ data = urlopen(url).read()
36
+ #out_filename = "/media/sdb/divyanshu/divyanshu/aidan_segmentation/nnUNet_pLGG/home/divyanshu/hd-bet_params/0.model"
37
+ with open(out_filename, 'wb') as f:
38
+ f.write(data)
39
+
40
+
41
+ def init_weights(module):
42
+ if isinstance(module, nn.Conv3d):
43
+ module.weight = nn.init.kaiming_normal(module.weight, a=1e-2)
44
+ if module.bias is not None:
45
+ module.bias = nn.init.constant(module.bias, 0)
46
+
47
+
48
+ def softmax_helper(x):
49
+ rpt = [1 for _ in range(len(x.size()))]
50
+ rpt[1] = x.size(1)
51
+ x_max = x.max(1, keepdim=True)[0].repeat(*rpt)
52
+ e_x = torch.exp(x - x_max)
53
+ return e_x / e_x.sum(1, keepdim=True).repeat(*rpt)
54
+
55
+
56
+ class SetNetworkToVal(object):
57
+ def __init__(self, use_dropout_sampling=False, norm_use_average=True):
58
+ self.norm_use_average = norm_use_average
59
+ self.use_dropout_sampling = use_dropout_sampling
60
+
61
+ def __call__(self, module):
62
+ if isinstance(module, nn.Dropout3d) or isinstance(module, nn.Dropout2d) or isinstance(module, nn.Dropout):
63
+ module.train(self.use_dropout_sampling)
64
+ elif isinstance(module, nn.InstanceNorm3d) or isinstance(module, nn.InstanceNorm2d) or \
65
+ isinstance(module, nn.InstanceNorm1d) \
66
+ or isinstance(module, nn.BatchNorm2d) or isinstance(module, nn.BatchNorm3d) or \
67
+ isinstance(module, nn.BatchNorm1d):
68
+ module.train(not self.norm_use_average)
69
+
70
+
71
+ def postprocess_prediction(seg):
72
+ # basically look for connected components and choose the largest one, delete everything else
73
+ print("running postprocessing... ")
74
+ mask = seg != 0
75
+ lbls = label(mask, connectivity=mask.ndim)
76
+ lbls_sizes = [np.sum(lbls == i) for i in np.unique(lbls)]
77
+ largest_region = np.argmax(lbls_sizes[1:]) + 1
78
+ seg[lbls != largest_region] = 0
79
+ return seg
80
+
81
+
82
+ def subdirs(folder, join=True, prefix=None, suffix=None, sort=True):
83
+ if join:
84
+ l = os.path.join
85
+ else:
86
+ l = lambda x, y: y
87
+ res = [l(folder, i) for i in os.listdir(folder) if os.path.isdir(os.path.join(folder, i))
88
+ and (prefix is None or i.startswith(prefix))
89
+ and (suffix is None or i.endswith(suffix))]
90
+ if sort:
91
+ res.sort()
92
+ return res
93
+
94
+
95
+ def subfiles(folder, join=True, prefix=None, suffix=None, sort=True):
96
+ if join:
97
+ l = os.path.join
98
+ else:
99
+ l = lambda x, y: y
100
+ res = [l(folder, i) for i in os.listdir(folder) if os.path.isfile(os.path.join(folder, i))
101
+ and (prefix is None or i.startswith(prefix))
102
+ and (suffix is None or i.endswith(suffix))]
103
+ if sort:
104
+ res.sort()
105
+ return res
106
+
107
+
108
+ subfolders = subdirs # I am tired of confusing those
109
+
110
+
111
+ def maybe_mkdir_p(directory):
112
+ splits = directory.split("/")[1:]
113
+ for i in range(0, len(splits)):
114
+ if not os.path.isdir(os.path.join("", *splits[:i+1])):
115
+ os.mkdir(os.path.join("", *splits[:i+1]))
src/IDH/HD_BET/config.py ADDED
@@ -0,0 +1,121 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import torch
3
+ from HD_BET.utils import SetNetworkToVal, softmax_helper
4
+ from abc import abstractmethod
5
+ from HD_BET.network_architecture import Network
6
+
7
+
8
+ class BaseConfig(object):
9
+ def __init__(self):
10
+ pass
11
+
12
+ @abstractmethod
13
+ def get_split(self, fold, random_state=12345):
14
+ pass
15
+
16
+ @abstractmethod
17
+ def get_network(self, mode="train"):
18
+ pass
19
+
20
+ @abstractmethod
21
+ def get_basic_generators(self, fold):
22
+ pass
23
+
24
+ @abstractmethod
25
+ def get_data_generators(self, fold):
26
+ pass
27
+
28
+ def preprocess(self, data):
29
+ return data
30
+
31
+ def __repr__(self):
32
+ res = ""
33
+ for v in vars(self):
34
+ if not v.startswith("__") and not v.startswith("_") and v != 'dataset':
35
+ res += (v + ": " + str(self.__getattribute__(v)) + "\n")
36
+ return res
37
+
38
+
39
+ class HD_BET_Config(BaseConfig):
40
+ def __init__(self):
41
+ super(HD_BET_Config, self).__init__()
42
+
43
+ self.EXPERIMENT_NAME = self.__class__.__name__ # just a generic experiment name
44
+
45
+ # network parameters
46
+ self.net_base_num_layers = 21
47
+ self.BATCH_SIZE = 2
48
+ self.net_do_DS = True
49
+ self.net_dropout_p = 0.0
50
+ self.net_use_inst_norm = True
51
+ self.net_conv_use_bias = True
52
+ self.net_norm_use_affine = True
53
+ self.net_leaky_relu_slope = 1e-1
54
+
55
+ # hyperparameters
56
+ self.INPUT_PATCH_SIZE = (128, 128, 128)
57
+ self.num_classes = 2
58
+ self.selected_data_channels = range(1)
59
+
60
+ # data augmentation
61
+ self.da_mirror_axes = (2, 3, 4)
62
+
63
+ # validation
64
+ self.val_use_DO = False
65
+ self.val_use_train_mode = False # for dropout sampling
66
+ self.val_num_repeats = 1 # only useful if dropout sampling
67
+ self.val_batch_size = 1 # only useful if dropout sampling
68
+ self.val_save_npz = True
69
+ self.val_do_mirroring = True # test time data augmentation via mirroring
70
+ self.val_write_images = True
71
+ self.net_input_must_be_divisible_by = 16 # we could make a network class that has this as a property
72
+ self.val_min_size = self.INPUT_PATCH_SIZE
73
+ self.val_fn = None
74
+
75
+ # CAREFUL! THIS IS A HACK TO MAKE PYTORCH 0.3 STATE DICTS COMPATIBLE WITH PYTORCH 0.4 (setting keep_runnings_
76
+ # stats=True but not using them in validation. keep_runnings_stats was True before 0.3 but unused and defaults
77
+ # to false in 0.4)
78
+ self.val_use_moving_averages = False
79
+
80
+ def get_network(self, train=True, pretrained_weights=None):
81
+ net = Network(self.num_classes, len(self.selected_data_channels), self.net_base_num_layers,
82
+ self.net_dropout_p, softmax_helper, self.net_leaky_relu_slope, self.net_conv_use_bias,
83
+ self.net_norm_use_affine, True, self.net_do_DS)
84
+
85
+ if pretrained_weights is not None:
86
+ net.load_state_dict(
87
+ torch.load(pretrained_weights, map_location=lambda storage, loc: storage))
88
+
89
+ if train:
90
+ net.train(True)
91
+ else:
92
+ net.train(False)
93
+ net.apply(SetNetworkToVal(self.val_use_DO, self.val_use_moving_averages))
94
+ net.do_ds = False
95
+
96
+ optimizer = None
97
+ self.lr_scheduler = None
98
+ return net, optimizer
99
+
100
+ def get_data_generators(self, fold):
101
+ pass
102
+
103
+ def get_split(self, fold, random_state=12345):
104
+ pass
105
+
106
+ def get_basic_generators(self, fold):
107
+ pass
108
+
109
+ def on_epoch_end(self, epoch):
110
+ pass
111
+
112
+ def preprocess(self, data):
113
+ data = np.copy(data)
114
+ for c in range(data.shape[0]):
115
+ data[c] -= data[c].mean()
116
+ data[c] /= data[c].std()
117
+ return data
118
+
119
+
120
+ config = HD_BET_Config
121
+
src/IDH/HD_BET/data_loading.py ADDED
@@ -0,0 +1,121 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import SimpleITK as sitk
2
+ import numpy as np
3
+ from skimage.transform import resize
4
+
5
+
6
+ def resize_image(image, old_spacing, new_spacing, order=3):
7
+ new_shape = (int(np.round(old_spacing[0]/new_spacing[0]*float(image.shape[0]))),
8
+ int(np.round(old_spacing[1]/new_spacing[1]*float(image.shape[1]))),
9
+ int(np.round(old_spacing[2]/new_spacing[2]*float(image.shape[2]))))
10
+ return resize(image, new_shape, order=order, mode='edge', cval=0, anti_aliasing=False)
11
+
12
+
13
+ def preprocess_image(itk_image, is_seg=False, spacing_target=(1, 0.5, 0.5)):
14
+ spacing = np.array(itk_image.GetSpacing())[[2, 1, 0]]
15
+ image = sitk.GetArrayFromImage(itk_image).astype(float)
16
+
17
+ assert len(image.shape) == 3, "The image has unsupported number of dimensions. Only 3D images are allowed"
18
+
19
+ if not is_seg:
20
+ if np.any([[i != j] for i, j in zip(spacing, spacing_target)]):
21
+ image = resize_image(image, spacing, spacing_target).astype(np.float32)
22
+
23
+ image -= image.mean()
24
+ image /= image.std()
25
+ else:
26
+ new_shape = (int(np.round(spacing[0] / spacing_target[0] * float(image.shape[0]))),
27
+ int(np.round(spacing[1] / spacing_target[1] * float(image.shape[1]))),
28
+ int(np.round(spacing[2] / spacing_target[2] * float(image.shape[2]))))
29
+ image = resize_segmentation(image, new_shape, 1)
30
+ return image
31
+
32
+
33
+ def load_and_preprocess(mri_file):
34
+ images = {}
35
+ # t1
36
+ images["T1"] = sitk.ReadImage(mri_file)
37
+
38
+ properties_dict = {
39
+ "spacing": images["T1"].GetSpacing(),
40
+ "direction": images["T1"].GetDirection(),
41
+ "size": images["T1"].GetSize(),
42
+ "origin": images["T1"].GetOrigin()
43
+ }
44
+
45
+ for k in images.keys():
46
+ images[k] = preprocess_image(images[k], is_seg=False, spacing_target=(1.5, 1.5, 1.5))
47
+
48
+ properties_dict['size_before_cropping'] = images["T1"].shape
49
+
50
+ imgs = []
51
+ for seq in ['T1']:
52
+ imgs.append(images[seq][None])
53
+ all_data = np.vstack(imgs)
54
+ print("image shape after preprocessing: ", str(all_data[0].shape))
55
+ return all_data, properties_dict
56
+
57
+
58
+ def save_segmentation_nifti(segmentation, dct, out_fname, order=1):
59
+ '''
60
+ segmentation must have the same spacing as the original nifti (for now). segmentation may have been cropped out
61
+ of the original image
62
+
63
+ dct:
64
+ size_before_cropping
65
+ brain_bbox
66
+ size -> this is the original size of the dataset, if the image was not resampled, this is the same as size_before_cropping
67
+ spacing
68
+ origin
69
+ direction
70
+
71
+ :param segmentation:
72
+ :param dct:
73
+ :param out_fname:
74
+ :return:
75
+ '''
76
+ old_size = dct.get('size_before_cropping')
77
+ bbox = dct.get('brain_bbox')
78
+ if bbox is not None:
79
+ seg_old_size = np.zeros(old_size)
80
+ for c in range(3):
81
+ bbox[c][1] = np.min((bbox[c][0] + segmentation.shape[c], old_size[c]))
82
+ seg_old_size[bbox[0][0]:bbox[0][1],
83
+ bbox[1][0]:bbox[1][1],
84
+ bbox[2][0]:bbox[2][1]] = segmentation
85
+ else:
86
+ seg_old_size = segmentation
87
+ if np.any(np.array(seg_old_size) != np.array(dct['size'])[[2, 1, 0]]):
88
+ seg_old_spacing = resize_segmentation(seg_old_size, np.array(dct['size'])[[2, 1, 0]], order=order)
89
+ else:
90
+ seg_old_spacing = seg_old_size
91
+ seg_resized_itk = sitk.GetImageFromArray(seg_old_spacing.astype(np.int32))
92
+ seg_resized_itk.SetSpacing(np.array(dct['spacing'])[[0, 1, 2]])
93
+ seg_resized_itk.SetOrigin(dct['origin'])
94
+ seg_resized_itk.SetDirection(dct['direction'])
95
+ sitk.WriteImage(seg_resized_itk, out_fname)
96
+
97
+
98
+ def resize_segmentation(segmentation, new_shape, order=3, cval=0):
99
+ '''
100
+ Taken from batchgenerators (https://github.com/MIC-DKFZ/batchgenerators) to prevent dependency
101
+
102
+ Resizes a segmentation map. Supports all orders (see skimage documentation). Will transform segmentation map to one
103
+ hot encoding which is resized and transformed back to a segmentation map.
104
+ This prevents interpolation artifacts ([0, 0, 2] -> [0, 1, 2])
105
+ :param segmentation:
106
+ :param new_shape:
107
+ :param order:
108
+ :return:
109
+ '''
110
+ tpe = segmentation.dtype
111
+ unique_labels = np.unique(segmentation)
112
+ assert len(segmentation.shape) == len(new_shape), "new shape must have same dimensionality as segmentation"
113
+ if order == 0:
114
+ return resize(segmentation, new_shape, order, mode="constant", cval=cval, clip=True, anti_aliasing=False).astype(tpe)
115
+ else:
116
+ reshaped = np.zeros(new_shape, dtype=segmentation.dtype)
117
+
118
+ for i, c in enumerate(unique_labels):
119
+ reshaped_multihot = resize((segmentation == c).astype(float), new_shape, order, mode="edge", clip=True, anti_aliasing=False)
120
+ reshaped[reshaped_multihot >= 0.5] = c
121
+ return reshaped
src/IDH/HD_BET/hd_bet.py ADDED
@@ -0,0 +1,119 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+
3
+ import os
4
+ import sys
5
+ sys.path.append("/mnt/93E8-0534/AIDAN/HDBET/")
6
+ from HD_BET.run import run_hd_bet
7
+ from HD_BET.utils import maybe_mkdir_p, subfiles
8
+ import HD_BET
9
+
10
+ def hd_bet(input_file_or_dir,output_file_or_dir,mode,device,tta,pp=1,save_mask=0,overwrite_existing=1):
11
+
12
+ if output_file_or_dir is None:
13
+ output_file_or_dir = os.path.join(os.path.dirname(input_file_or_dir),
14
+ os.path.basename(input_file_or_dir).split(".")[0] + "_bet")
15
+
16
+
17
+ params_file = os.path.join(HD_BET.__path__[0], "model_final.py")
18
+ config_file = os.path.join(HD_BET.__path__[0], "config.py")
19
+
20
+ assert os.path.abspath(input_file_or_dir) != os.path.abspath(output_file_or_dir), "output must be different from input"
21
+
22
+ if device == 'cpu':
23
+ pass
24
+ else:
25
+ device = int(device)
26
+
27
+ if os.path.isdir(input_file_or_dir):
28
+ maybe_mkdir_p(output_file_or_dir)
29
+ input_files = subfiles(input_file_or_dir, suffix='_0000.nii.gz', join=False)
30
+
31
+ if len(input_files) == 0:
32
+ raise RuntimeError("input is a folder but no nifti files (.nii.gz) were found in here")
33
+
34
+ output_files = [os.path.join(output_file_or_dir, i) for i in input_files]
35
+ input_files = [os.path.join(input_file_or_dir, i) for i in input_files]
36
+ else:
37
+ if not output_file_or_dir.endswith('.nii.gz'):
38
+ output_file_or_dir += '.nii.gz'
39
+ assert os.path.abspath(input_file_or_dir) != os.path.abspath(output_file_or_dir), "output must be different from input"
40
+
41
+ output_files = [output_file_or_dir]
42
+ input_files = [input_file_or_dir]
43
+
44
+ if tta == 0:
45
+ tta = False
46
+ elif tta == 1:
47
+ tta = True
48
+ else:
49
+ raise ValueError("Unknown value for tta: %s. Expected: 0 or 1" % str(tta))
50
+
51
+ if overwrite_existing == 0:
52
+ overwrite_existing = False
53
+ elif overwrite_existing == 1:
54
+ overwrite_existing = True
55
+ else:
56
+ raise ValueError("Unknown value for overwrite_existing: %s. Expected: 0 or 1" % str(overwrite_existing))
57
+
58
+ if pp == 0:
59
+ pp = False
60
+ elif pp == 1:
61
+ pp = True
62
+ else:
63
+ raise ValueError("Unknown value for pp: %s. Expected: 0 or 1" % str(pp))
64
+
65
+ if save_mask == 0:
66
+ save_mask = False
67
+ elif save_mask == 1:
68
+ save_mask = True
69
+ else:
70
+ raise ValueError("Unknown value for pp: %s. Expected: 0 or 1" % str(pp))
71
+
72
+ run_hd_bet(input_files, output_files, mode, config_file, device, pp, tta, save_mask, overwrite_existing)
73
+
74
+
75
+ if __name__ == "__main__":
76
+ print("\n########################")
77
+ print("If you are using hd-bet, please cite the following paper:")
78
+ print("Isensee F, Schell M, Tursunova I, Brugnara G, Bonekamp D, Neuberger U, Wick A, Schlemmer HP, Heiland S, Wick W,"
79
+ "Bendszus M, Maier-Hein KH, Kickingereder P. Automated brain extraction of multi-sequence MRI using artificial"
80
+ "neural networks. arXiv preprint arXiv:1901.11341, 2019.")
81
+ print("########################\n")
82
+
83
+ import argparse
84
+ parser = argparse.ArgumentParser()
85
+ parser.add_argument('-i', '--input', help='input. Can be either a single file name or an input folder. If file: must be '
86
+ 'nifti (.nii.gz) and can only be 3D. No support for 4d images, use fslsplit to '
87
+ 'split 4d sequences into 3d images. If folder: all files ending with .nii.gz '
88
+ 'within that folder will be brain extracted.', required=True, type=str)
89
+ parser.add_argument('-o', '--output', help='output. Can be either a filename or a folder. If it does not exist, the folder'
90
+ ' will be created', required=False, type=str)
91
+ parser.add_argument('-mode', type=str, default='accurate', help='can be either \'fast\' or \'accurate\'. Fast will '
92
+ 'use only one set of parameters whereas accurate will '
93
+ 'use the five sets of parameters that resulted from '
94
+ 'our cross-validation as an ensemble. Default: '
95
+ 'accurate',
96
+ required=False)
97
+ parser.add_argument('-device', default='0', type=str, help='used to set on which device the prediction will run. '
98
+ 'Must be either int or str. Use int for GPU id or '
99
+ '\'cpu\' to run on CPU. When using CPU you should '
100
+ 'consider disabling tta. Default for -device is: 0',
101
+ required=False)
102
+ parser.add_argument('-tta', default=1, required=False, type=int, help='whether to use test time data augmentation '
103
+ '(mirroring). 1= True, 0=False. Disable this '
104
+ 'if you are using CPU to speed things up! '
105
+ 'Default: 1')
106
+ parser.add_argument('-pp', default=1, type=int, required=False, help='set to 0 to disabe postprocessing (remove all'
107
+ ' but the largest connected component in '
108
+ 'the prediction. Default: 1')
109
+ parser.add_argument('-s', '--save_mask', default=1, type=int, required=False, help='if set to 0 the segmentation '
110
+ 'mask will not be '
111
+ 'saved')
112
+ parser.add_argument('--overwrite_existing', default=1, type=int, required=False, help="set this to 0 if you don't "
113
+ "want to overwrite existing "
114
+ "predictions")
115
+
116
+ args = parser.parse_args()
117
+
118
+ hd_bet(args.input,args.output,args.mode,args.device,args.tta,args.pp,args.save_mask,args.overwrite_existing)
119
+
src/IDH/HD_BET/network_architecture.py ADDED
@@ -0,0 +1,213 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from HD_BET.utils import softmax_helper
5
+
6
+
7
+ class EncodingModule(nn.Module):
8
+ def __init__(self, in_channels, out_channels, filter_size=3, dropout_p=0.3, leakiness=1e-2, conv_bias=True,
9
+ inst_norm_affine=True, lrelu_inplace=True):
10
+ nn.Module.__init__(self)
11
+ self.dropout_p = dropout_p
12
+ self.lrelu_inplace = lrelu_inplace
13
+ self.inst_norm_affine = inst_norm_affine
14
+ self.conv_bias = conv_bias
15
+ self.leakiness = leakiness
16
+ self.bn_1 = nn.InstanceNorm3d(in_channels, affine=self.inst_norm_affine, track_running_stats=True)
17
+ self.conv1 = nn.Conv3d(in_channels, out_channels, filter_size, 1, (filter_size - 1) // 2, bias=self.conv_bias)
18
+ self.dropout = nn.Dropout3d(dropout_p)
19
+ self.bn_2 = nn.InstanceNorm3d(in_channels, affine=self.inst_norm_affine, track_running_stats=True)
20
+ self.conv2 = nn.Conv3d(out_channels, out_channels, filter_size, 1, (filter_size - 1) // 2, bias=self.conv_bias)
21
+
22
+ def forward(self, x):
23
+ skip = x
24
+ x = F.leaky_relu(self.bn_1(x), negative_slope=self.leakiness, inplace=self.lrelu_inplace)
25
+ x = self.conv1(x)
26
+ if self.dropout_p is not None and self.dropout_p > 0:
27
+ x = self.dropout(x)
28
+ x = F.leaky_relu(self.bn_2(x), negative_slope=self.leakiness, inplace=self.lrelu_inplace)
29
+ x = self.conv2(x)
30
+ x = x + skip
31
+ return x
32
+
33
+
34
+ class Upsample(nn.Module):
35
+ def __init__(self, size=None, scale_factor=None, mode='nearest', align_corners=True):
36
+ super(Upsample, self).__init__()
37
+ self.align_corners = align_corners
38
+ self.mode = mode
39
+ self.scale_factor = scale_factor
40
+ self.size = size
41
+
42
+ def forward(self, x):
43
+ return nn.functional.interpolate(x, size=self.size, scale_factor=self.scale_factor, mode=self.mode,
44
+ align_corners=self.align_corners)
45
+
46
+
47
+ class LocalizationModule(nn.Module):
48
+ def __init__(self, in_channels, out_channels, leakiness=1e-2, conv_bias=True, inst_norm_affine=True,
49
+ lrelu_inplace=True):
50
+ nn.Module.__init__(self)
51
+ self.lrelu_inplace = lrelu_inplace
52
+ self.inst_norm_affine = inst_norm_affine
53
+ self.conv_bias = conv_bias
54
+ self.leakiness = leakiness
55
+ self.conv1 = nn.Conv3d(in_channels, in_channels, 3, 1, 1, bias=self.conv_bias)
56
+ self.bn_1 = nn.InstanceNorm3d(in_channels, affine=self.inst_norm_affine, track_running_stats=True)
57
+ self.conv2 = nn.Conv3d(in_channels, out_channels, 1, 1, 0, bias=self.conv_bias)
58
+ self.bn_2 = nn.InstanceNorm3d(out_channels, affine=self.inst_norm_affine, track_running_stats=True)
59
+
60
+ def forward(self, x):
61
+ x = F.leaky_relu(self.bn_1(self.conv1(x)), negative_slope=self.leakiness, inplace=self.lrelu_inplace)
62
+ x = F.leaky_relu(self.bn_2(self.conv2(x)), negative_slope=self.leakiness, inplace=self.lrelu_inplace)
63
+ return x
64
+
65
+
66
+ class UpsamplingModule(nn.Module):
67
+ def __init__(self, in_channels, out_channels, leakiness=1e-2, conv_bias=True, inst_norm_affine=True,
68
+ lrelu_inplace=True):
69
+ nn.Module.__init__(self)
70
+ self.lrelu_inplace = lrelu_inplace
71
+ self.inst_norm_affine = inst_norm_affine
72
+ self.conv_bias = conv_bias
73
+ self.leakiness = leakiness
74
+ self.upsample = Upsample(scale_factor=2, mode="trilinear", align_corners=True)
75
+ self.upsample_conv = nn.Conv3d(in_channels, out_channels, 3, 1, 1, bias=self.conv_bias)
76
+ self.bn = nn.InstanceNorm3d(out_channels, affine=self.inst_norm_affine, track_running_stats=True)
77
+
78
+ def forward(self, x):
79
+ x = F.leaky_relu(self.bn(self.upsample_conv(self.upsample(x))), negative_slope=self.leakiness,
80
+ inplace=self.lrelu_inplace)
81
+ return x
82
+
83
+
84
+ class DownsamplingModule(nn.Module):
85
+ def __init__(self, in_channels, out_channels, leakiness=1e-2, conv_bias=True, inst_norm_affine=True,
86
+ lrelu_inplace=True):
87
+ nn.Module.__init__(self)
88
+ self.lrelu_inplace = lrelu_inplace
89
+ self.inst_norm_affine = inst_norm_affine
90
+ self.conv_bias = conv_bias
91
+ self.leakiness = leakiness
92
+ self.bn = nn.InstanceNorm3d(in_channels, affine=self.inst_norm_affine, track_running_stats=True)
93
+ self.downsample = nn.Conv3d(in_channels, out_channels, 3, 2, 1, bias=self.conv_bias)
94
+
95
+ def forward(self, x):
96
+ x = F.leaky_relu(self.bn(x), negative_slope=self.leakiness, inplace=self.lrelu_inplace)
97
+ b = self.downsample(x)
98
+ return x, b
99
+
100
+
101
+ class Network(nn.Module):
102
+ def __init__(self, num_classes=4, num_input_channels=4, base_filters=16, dropout_p=0.3,
103
+ final_nonlin=softmax_helper, leakiness=1e-2, conv_bias=True, inst_norm_affine=True,
104
+ lrelu_inplace=True, do_ds=True):
105
+ super(Network, self).__init__()
106
+
107
+ self.do_ds = do_ds
108
+ self.lrelu_inplace = lrelu_inplace
109
+ self.inst_norm_affine = inst_norm_affine
110
+ self.conv_bias = conv_bias
111
+ self.leakiness = leakiness
112
+ self.final_nonlin = final_nonlin
113
+ self.init_conv = nn.Conv3d(num_input_channels, base_filters, 3, 1, 1, bias=self.conv_bias)
114
+
115
+ self.context1 = EncodingModule(base_filters, base_filters, 3, dropout_p, leakiness=1e-2, conv_bias=True,
116
+ inst_norm_affine=True, lrelu_inplace=True)
117
+ self.down1 = DownsamplingModule(base_filters, base_filters * 2, leakiness=1e-2, conv_bias=True,
118
+ inst_norm_affine=True, lrelu_inplace=True)
119
+
120
+ self.context2 = EncodingModule(2 * base_filters, 2 * base_filters, 3, dropout_p, leakiness=1e-2, conv_bias=True,
121
+ inst_norm_affine=True, lrelu_inplace=True)
122
+ self.down2 = DownsamplingModule(2 * base_filters, base_filters * 4, leakiness=1e-2, conv_bias=True,
123
+ inst_norm_affine=True, lrelu_inplace=True)
124
+
125
+ self.context3 = EncodingModule(4 * base_filters, 4 * base_filters, 3, dropout_p, leakiness=1e-2, conv_bias=True,
126
+ inst_norm_affine=True, lrelu_inplace=True)
127
+ self.down3 = DownsamplingModule(4 * base_filters, base_filters * 8, leakiness=1e-2, conv_bias=True,
128
+ inst_norm_affine=True, lrelu_inplace=True)
129
+
130
+ self.context4 = EncodingModule(8 * base_filters, 8 * base_filters, 3, dropout_p, leakiness=1e-2, conv_bias=True,
131
+ inst_norm_affine=True, lrelu_inplace=True)
132
+ self.down4 = DownsamplingModule(8 * base_filters, base_filters * 16, leakiness=1e-2, conv_bias=True,
133
+ inst_norm_affine=True, lrelu_inplace=True)
134
+
135
+ self.context5 = EncodingModule(16 * base_filters, 16 * base_filters, 3, dropout_p, leakiness=1e-2,
136
+ conv_bias=True, inst_norm_affine=True, lrelu_inplace=True)
137
+
138
+ self.bn_after_context5 = nn.InstanceNorm3d(16 * base_filters, affine=self.inst_norm_affine, track_running_stats=True)
139
+ self.up1 = UpsamplingModule(16 * base_filters, 8 * base_filters, leakiness=1e-2, conv_bias=True,
140
+ inst_norm_affine=True, lrelu_inplace=True)
141
+
142
+ self.loc1 = LocalizationModule(16 * base_filters, 8 * base_filters, leakiness=1e-2, conv_bias=True,
143
+ inst_norm_affine=True, lrelu_inplace=True)
144
+ self.up2 = UpsamplingModule(8 * base_filters, 4 * base_filters, leakiness=1e-2, conv_bias=True,
145
+ inst_norm_affine=True, lrelu_inplace=True)
146
+
147
+ self.loc2 = LocalizationModule(8 * base_filters, 4 * base_filters, leakiness=1e-2, conv_bias=True,
148
+ inst_norm_affine=True, lrelu_inplace=True)
149
+ self.loc2_seg = nn.Conv3d(4 * base_filters, num_classes, 1, 1, 0, bias=False)
150
+ self.up3 = UpsamplingModule(4 * base_filters, 2 * base_filters, leakiness=1e-2, conv_bias=True,
151
+ inst_norm_affine=True, lrelu_inplace=True)
152
+
153
+ self.loc3 = LocalizationModule(4 * base_filters, 2 * base_filters, leakiness=1e-2, conv_bias=True,
154
+ inst_norm_affine=True, lrelu_inplace=True)
155
+ self.loc3_seg = nn.Conv3d(2 * base_filters, num_classes, 1, 1, 0, bias=False)
156
+ self.up4 = UpsamplingModule(2 * base_filters, 1 * base_filters, leakiness=1e-2, conv_bias=True,
157
+ inst_norm_affine=True, lrelu_inplace=True)
158
+
159
+ self.end_conv_1 = nn.Conv3d(2 * base_filters, 2 * base_filters, 3, 1, 1, bias=self.conv_bias)
160
+ self.end_conv_1_bn = nn.InstanceNorm3d(2 * base_filters, affine=self.inst_norm_affine, track_running_stats=True)
161
+ self.end_conv_2 = nn.Conv3d(2 * base_filters, 2 * base_filters, 3, 1, 1, bias=self.conv_bias)
162
+ self.end_conv_2_bn = nn.InstanceNorm3d(2 * base_filters, affine=self.inst_norm_affine, track_running_stats=True)
163
+ self.seg_layer = nn.Conv3d(2 * base_filters, num_classes, 1, 1, 0, bias=False)
164
+
165
+ def forward(self, x):
166
+ seg_outputs = []
167
+
168
+ x = self.init_conv(x)
169
+ x = self.context1(x)
170
+
171
+ skip1, x = self.down1(x)
172
+ x = self.context2(x)
173
+
174
+ skip2, x = self.down2(x)
175
+ x = self.context3(x)
176
+
177
+ skip3, x = self.down3(x)
178
+ x = self.context4(x)
179
+
180
+ skip4, x = self.down4(x)
181
+ x = self.context5(x)
182
+
183
+ x = F.leaky_relu(self.bn_after_context5(x), negative_slope=self.leakiness, inplace=self.lrelu_inplace)
184
+ x = self.up1(x)
185
+
186
+ x = torch.cat((skip4, x), dim=1)
187
+ x = self.loc1(x)
188
+ x = self.up2(x)
189
+
190
+ x = torch.cat((skip3, x), dim=1)
191
+ x = self.loc2(x)
192
+ loc2_seg = self.final_nonlin(self.loc2_seg(x))
193
+ seg_outputs.append(loc2_seg)
194
+ x = self.up3(x)
195
+
196
+ x = torch.cat((skip2, x), dim=1)
197
+ x = self.loc3(x)
198
+ loc3_seg = self.final_nonlin(self.loc3_seg(x))
199
+ seg_outputs.append(loc3_seg)
200
+ x = self.up4(x)
201
+
202
+ x = torch.cat((skip1, x), dim=1)
203
+ x = F.leaky_relu(self.end_conv_1_bn(self.end_conv_1(x)), negative_slope=self.leakiness,
204
+ inplace=self.lrelu_inplace)
205
+ x = F.leaky_relu(self.end_conv_2_bn(self.end_conv_2(x)), negative_slope=self.leakiness,
206
+ inplace=self.lrelu_inplace)
207
+ x = self.final_nonlin(self.seg_layer(x))
208
+ seg_outputs.append(x)
209
+
210
+ if self.do_ds:
211
+ return seg_outputs[::-1]
212
+ else:
213
+ return seg_outputs[-1]
src/IDH/HD_BET/paths.py ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ # please refer to the readme on where to get the parameters. Save them in this folder:
4
+ # Original Path: "/media/sdb/divyanshu/divyanshu/aidan_segmentation/nnUNet_pLGG/home/divyanshu/hd-bet_params"
5
+ # Updated path for Docker container:
6
+ folder_with_parameter_files = "/app/IDH/hdbet_model"
src/IDH/HD_BET/predict_case.py ADDED
@@ -0,0 +1,126 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import numpy as np
3
+
4
+
5
+ def pad_patient_3D(patient, shape_must_be_divisible_by=16, min_size=None):
6
+ if not (isinstance(shape_must_be_divisible_by, list) or isinstance(shape_must_be_divisible_by, tuple)):
7
+ shape_must_be_divisible_by = [shape_must_be_divisible_by] * 3
8
+ shp = patient.shape
9
+ new_shp = [shp[0] + shape_must_be_divisible_by[0] - shp[0] % shape_must_be_divisible_by[0],
10
+ shp[1] + shape_must_be_divisible_by[1] - shp[1] % shape_must_be_divisible_by[1],
11
+ shp[2] + shape_must_be_divisible_by[2] - shp[2] % shape_must_be_divisible_by[2]]
12
+ for i in range(len(shp)):
13
+ if shp[i] % shape_must_be_divisible_by[i] == 0:
14
+ new_shp[i] -= shape_must_be_divisible_by[i]
15
+ if min_size is not None:
16
+ new_shp = np.max(np.vstack((np.array(new_shp), np.array(min_size))), 0)
17
+ return reshape_by_padding_upper_coords(patient, new_shp, 0), shp
18
+
19
+
20
+ def reshape_by_padding_upper_coords(image, new_shape, pad_value=None):
21
+ shape = tuple(list(image.shape))
22
+ new_shape = tuple(np.max(np.concatenate((shape, new_shape)).reshape((2,len(shape))), axis=0))
23
+ if pad_value is None:
24
+ if len(shape) == 2:
25
+ pad_value = image[0,0]
26
+ elif len(shape) == 3:
27
+ pad_value = image[0, 0, 0]
28
+ else:
29
+ raise ValueError("Image must be either 2 or 3 dimensional")
30
+ res = np.ones(list(new_shape), dtype=image.dtype) * pad_value
31
+ if len(shape) == 2:
32
+ res[0:0+int(shape[0]), 0:0+int(shape[1])] = image
33
+ elif len(shape) == 3:
34
+ res[0:0+int(shape[0]), 0:0+int(shape[1]), 0:0+int(shape[2])] = image
35
+ return res
36
+
37
+
38
+ def predict_case_3D_net(net, patient_data, do_mirroring, num_repeats, BATCH_SIZE=None,
39
+ new_shape_must_be_divisible_by=16, min_size=None, main_device=0, mirror_axes=(2, 3, 4)):
40
+ with torch.no_grad():
41
+ pad_res = []
42
+ for i in range(patient_data.shape[0]):
43
+ t, old_shape = pad_patient_3D(patient_data[i], new_shape_must_be_divisible_by, min_size)
44
+ pad_res.append(t[None])
45
+
46
+ patient_data = np.vstack(pad_res)
47
+
48
+ new_shp = patient_data.shape
49
+
50
+ data = np.zeros(tuple([1] + list(new_shp)), dtype=np.float32)
51
+
52
+ data[0] = patient_data
53
+
54
+ if BATCH_SIZE is not None:
55
+ data = np.vstack([data] * BATCH_SIZE)
56
+
57
+ a = torch.rand(data.shape).float()
58
+
59
+ if main_device == 'cpu':
60
+ pass
61
+ else:
62
+ a = a.cuda(main_device)
63
+
64
+ if do_mirroring:
65
+ x = 8
66
+ else:
67
+ x = 1
68
+ all_preds = []
69
+ for i in range(num_repeats):
70
+ for m in range(x):
71
+ data_for_net = np.array(data)
72
+ do_stuff = False
73
+ if m == 0:
74
+ do_stuff = True
75
+ pass
76
+ if m == 1 and (4 in mirror_axes):
77
+ do_stuff = True
78
+ data_for_net = data_for_net[:, :, :, :, ::-1]
79
+ if m == 2 and (3 in mirror_axes):
80
+ do_stuff = True
81
+ data_for_net = data_for_net[:, :, :, ::-1, :]
82
+ if m == 3 and (4 in mirror_axes) and (3 in mirror_axes):
83
+ do_stuff = True
84
+ data_for_net = data_for_net[:, :, :, ::-1, ::-1]
85
+ if m == 4 and (2 in mirror_axes):
86
+ do_stuff = True
87
+ data_for_net = data_for_net[:, :, ::-1, :, :]
88
+ if m == 5 and (2 in mirror_axes) and (4 in mirror_axes):
89
+ do_stuff = True
90
+ data_for_net = data_for_net[:, :, ::-1, :, ::-1]
91
+ if m == 6 and (2 in mirror_axes) and (3 in mirror_axes):
92
+ do_stuff = True
93
+ data_for_net = data_for_net[:, :, ::-1, ::-1, :]
94
+ if m == 7 and (2 in mirror_axes) and (3 in mirror_axes) and (4 in mirror_axes):
95
+ do_stuff = True
96
+ data_for_net = data_for_net[:, :, ::-1, ::-1, ::-1]
97
+
98
+ if do_stuff:
99
+ _ = a.data.copy_(torch.from_numpy(np.copy(data_for_net)))
100
+ p = net(a) # np.copy is necessary because ::-1 creates just a view i think
101
+ p = p.data.cpu().numpy()
102
+
103
+ if m == 0:
104
+ pass
105
+ if m == 1 and (4 in mirror_axes):
106
+ p = p[:, :, :, :, ::-1]
107
+ if m == 2 and (3 in mirror_axes):
108
+ p = p[:, :, :, ::-1, :]
109
+ if m == 3 and (4 in mirror_axes) and (3 in mirror_axes):
110
+ p = p[:, :, :, ::-1, ::-1]
111
+ if m == 4 and (2 in mirror_axes):
112
+ p = p[:, :, ::-1, :, :]
113
+ if m == 5 and (2 in mirror_axes) and (4 in mirror_axes):
114
+ p = p[:, :, ::-1, :, ::-1]
115
+ if m == 6 and (2 in mirror_axes) and (3 in mirror_axes):
116
+ p = p[:, :, ::-1, ::-1, :]
117
+ if m == 7 and (2 in mirror_axes) and (3 in mirror_axes) and (4 in mirror_axes):
118
+ p = p[:, :, ::-1, ::-1, ::-1]
119
+ all_preds.append(p)
120
+
121
+ stacked = np.vstack(all_preds)[:, :, :old_shape[0], :old_shape[1], :old_shape[2]]
122
+ predicted_segmentation = stacked.mean(0).argmax(0)
123
+ uncertainty = stacked.var(0)
124
+ bayesian_predictions = stacked
125
+ softmax_pred = stacked.mean(0)
126
+ return predicted_segmentation, bayesian_predictions, softmax_pred, uncertainty
src/IDH/HD_BET/run.py ADDED
@@ -0,0 +1,117 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import numpy as np
3
+ import SimpleITK as sitk
4
+ from HD_BET.data_loading import load_and_preprocess, save_segmentation_nifti
5
+ from HD_BET.predict_case import predict_case_3D_net
6
+ import imp
7
+ from HD_BET.utils import postprocess_prediction, SetNetworkToVal, get_params_fname, maybe_download_parameters
8
+ import os
9
+ import HD_BET
10
+
11
+
12
+ def apply_bet(img, bet, out_fname):
13
+ img_itk = sitk.ReadImage(img)
14
+ img_npy = sitk.GetArrayFromImage(img_itk)
15
+ img_bet = sitk.GetArrayFromImage(sitk.ReadImage(bet))
16
+ img_npy[img_bet == 0] = 0
17
+ out = sitk.GetImageFromArray(img_npy)
18
+ out.CopyInformation(img_itk)
19
+ sitk.WriteImage(out, out_fname)
20
+
21
+
22
+ def run_hd_bet(mri_fnames, output_fnames, mode="accurate", config_file=os.path.join(HD_BET.__path__[0], "config.py"), device=0,
23
+ postprocess=False, do_tta=True, keep_mask=True, overwrite=True):
24
+ """
25
+
26
+ :param mri_fnames: str or list/tuple of str
27
+ :param output_fnames: str or list/tuple of str. If list: must have the same length as output_fnames
28
+ :param mode: fast or accurate
29
+ :param config_file: config.py
30
+ :param device: either int (for device id) or 'cpu'
31
+ :param postprocess: whether to do postprocessing or not. Postprocessing here consists of simply discarding all
32
+ but the largest predicted connected component. Default False
33
+ :param do_tta: whether to do test time data augmentation by mirroring along all axes. Default: True. If you use
34
+ CPU you may want to turn that off to speed things up
35
+ :return:
36
+ """
37
+
38
+ list_of_param_files = []
39
+
40
+ if mode == 'fast':
41
+ params_file = get_params_fname(0)
42
+ maybe_download_parameters(0)
43
+
44
+ list_of_param_files.append(params_file)
45
+ elif mode == 'accurate':
46
+ for i in range(5):
47
+ params_file = get_params_fname(i)
48
+ maybe_download_parameters(i)
49
+
50
+ list_of_param_files.append(params_file)
51
+ else:
52
+ raise ValueError("Unknown value for mode: %s. Expected: fast or accurate" % mode)
53
+
54
+ assert all([os.path.isfile(i) for i in list_of_param_files]), "Could not find parameter files"
55
+
56
+ cf = imp.load_source('cf', config_file)
57
+ cf = cf.config()
58
+
59
+ net, _ = cf.get_network(cf.val_use_train_mode, None)
60
+ if device == "cpu":
61
+ net = net.cpu()
62
+ else:
63
+ net.cuda(device)
64
+
65
+ if not isinstance(mri_fnames, (list, tuple)):
66
+ mri_fnames = [mri_fnames]
67
+
68
+ if not isinstance(output_fnames, (list, tuple)):
69
+ output_fnames = [output_fnames]
70
+
71
+ assert len(mri_fnames) == len(output_fnames), "mri_fnames and output_fnames must have the same length"
72
+
73
+ params = []
74
+ for p in list_of_param_files:
75
+ params.append(torch.load(p, map_location=lambda storage, loc: storage))
76
+
77
+ for in_fname, out_fname in zip(mri_fnames, output_fnames):
78
+ mask_fname = out_fname[:-7] + "_mask.nii.gz"
79
+ if overwrite or (not (os.path.isfile(mask_fname) and keep_mask) or not os.path.isfile(out_fname)):
80
+ print("File:", in_fname)
81
+ print("preprocessing...")
82
+ try:
83
+ data, data_dict = load_and_preprocess(in_fname)
84
+ except RuntimeError:
85
+ print("\nERROR\nCould not read file", in_fname, "\n")
86
+ continue
87
+ except AssertionError as e:
88
+ print(e)
89
+ continue
90
+
91
+ softmax_preds = []
92
+
93
+ print("prediction (CNN id)...")
94
+ for i, p in enumerate(params):
95
+ print(i)
96
+ net.load_state_dict(p)
97
+ net.eval()
98
+ net.apply(SetNetworkToVal(False, False))
99
+ _, _, softmax_pred, _ = predict_case_3D_net(net, data, do_tta, cf.val_num_repeats,
100
+ cf.val_batch_size, cf.net_input_must_be_divisible_by,
101
+ cf.val_min_size, device, cf.da_mirror_axes)
102
+ softmax_preds.append(softmax_pred[None])
103
+
104
+ seg = np.argmax(np.vstack(softmax_preds).mean(0), 0)
105
+
106
+ if postprocess:
107
+ seg = postprocess_prediction(seg)
108
+
109
+ print("exporting segmentation...")
110
+ save_segmentation_nifti(seg, data_dict, mask_fname)
111
+
112
+ apply_bet(in_fname, mask_fname, out_fname)
113
+
114
+ if not keep_mask:
115
+ os.remove(mask_fname)
116
+
117
+
src/IDH/HD_BET/utils.py ADDED
@@ -0,0 +1,115 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from urllib.request import urlopen
2
+ import torch
3
+ from torch import nn
4
+ import numpy as np
5
+ from skimage.morphology import label
6
+ import os
7
+ from HD_BET.paths import folder_with_parameter_files
8
+
9
+
10
+ def get_params_fname(fold):
11
+ return os.path.join(folder_with_parameter_files, "%d.model" % fold)
12
+
13
+
14
+ def maybe_download_parameters(fold=0, force_overwrite=False):
15
+ """
16
+ Downloads the parameters for some fold if it is not present yet.
17
+ :param fold:
18
+ :param force_overwrite: if True the old parameter file will be deleted (if present) prior to download
19
+ :return:
20
+ """
21
+
22
+ assert 0 <= fold <= 4, "fold must be between 0 and 4"
23
+
24
+ if not os.path.isdir(folder_with_parameter_files):
25
+ maybe_mkdir_p(folder_with_parameter_files)
26
+
27
+ out_filename = get_params_fname(fold)
28
+
29
+ if force_overwrite and os.path.isfile(out_filename):
30
+ os.remove(out_filename)
31
+
32
+ if not os.path.isfile(out_filename):
33
+ url = "https://zenodo.org/record/2540695/files/%d.model?download=1" % fold
34
+ print("Downloading", url, "...")
35
+ data = urlopen(url).read()
36
+ #out_filename = "/media/sdb/divyanshu/divyanshu/aidan_segmentation/nnUNet_pLGG/home/divyanshu/hd-bet_params/0.model"
37
+ with open(out_filename, 'wb') as f:
38
+ f.write(data)
39
+
40
+
41
+ def init_weights(module):
42
+ if isinstance(module, nn.Conv3d):
43
+ module.weight = nn.init.kaiming_normal(module.weight, a=1e-2)
44
+ if module.bias is not None:
45
+ module.bias = nn.init.constant(module.bias, 0)
46
+
47
+
48
+ def softmax_helper(x):
49
+ rpt = [1 for _ in range(len(x.size()))]
50
+ rpt[1] = x.size(1)
51
+ x_max = x.max(1, keepdim=True)[0].repeat(*rpt)
52
+ e_x = torch.exp(x - x_max)
53
+ return e_x / e_x.sum(1, keepdim=True).repeat(*rpt)
54
+
55
+
56
+ class SetNetworkToVal(object):
57
+ def __init__(self, use_dropout_sampling=False, norm_use_average=True):
58
+ self.norm_use_average = norm_use_average
59
+ self.use_dropout_sampling = use_dropout_sampling
60
+
61
+ def __call__(self, module):
62
+ if isinstance(module, nn.Dropout3d) or isinstance(module, nn.Dropout2d) or isinstance(module, nn.Dropout):
63
+ module.train(self.use_dropout_sampling)
64
+ elif isinstance(module, nn.InstanceNorm3d) or isinstance(module, nn.InstanceNorm2d) or \
65
+ isinstance(module, nn.InstanceNorm1d) \
66
+ or isinstance(module, nn.BatchNorm2d) or isinstance(module, nn.BatchNorm3d) or \
67
+ isinstance(module, nn.BatchNorm1d):
68
+ module.train(not self.norm_use_average)
69
+
70
+
71
+ def postprocess_prediction(seg):
72
+ # basically look for connected components and choose the largest one, delete everything else
73
+ print("running postprocessing... ")
74
+ mask = seg != 0
75
+ lbls = label(mask, connectivity=mask.ndim)
76
+ lbls_sizes = [np.sum(lbls == i) for i in np.unique(lbls)]
77
+ largest_region = np.argmax(lbls_sizes[1:]) + 1
78
+ seg[lbls != largest_region] = 0
79
+ return seg
80
+
81
+
82
+ def subdirs(folder, join=True, prefix=None, suffix=None, sort=True):
83
+ if join:
84
+ l = os.path.join
85
+ else:
86
+ l = lambda x, y: y
87
+ res = [l(folder, i) for i in os.listdir(folder) if os.path.isdir(os.path.join(folder, i))
88
+ and (prefix is None or i.startswith(prefix))
89
+ and (suffix is None or i.endswith(suffix))]
90
+ if sort:
91
+ res.sort()
92
+ return res
93
+
94
+
95
+ def subfiles(folder, join=True, prefix=None, suffix=None, sort=True):
96
+ if join:
97
+ l = os.path.join
98
+ else:
99
+ l = lambda x, y: y
100
+ res = [l(folder, i) for i in os.listdir(folder) if os.path.isfile(os.path.join(folder, i))
101
+ and (prefix is None or i.startswith(prefix))
102
+ and (suffix is None or i.endswith(suffix))]
103
+ if sort:
104
+ res.sort()
105
+ return res
106
+
107
+
108
+ subfolders = subdirs # I am tired of confusing those
109
+
110
+
111
+ def maybe_mkdir_p(directory):
112
+ splits = directory.split("/")[1:]
113
+ for i in range(0, len(splits)):
114
+ if not os.path.isdir(os.path.join("", *splits[:i+1])):
115
+ os.mkdir(os.path.join("", *splits[:i+1]))
src/IDH/app_gradio.py ADDED
@@ -0,0 +1,867 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import yaml
3
+ import torch
4
+ import nibabel as nib
5
+ import numpy as np
6
+ import gradio as gr
7
+ from typing import Tuple
8
+ import tempfile
9
+ import shutil
10
+ import matplotlib.pyplot as plt
11
+ import matplotlib
12
+ matplotlib.use('Agg') # Use non-interactive backend
13
+ import cv2 # For Gaussian Blur
14
+ import io # For saving plots to memory
15
+ import base64 # For encoding plots
16
+ import uuid # For unique IDs
17
+ import traceback # For detailed error printing
18
+
19
+ import SimpleITK as sitk
20
+ import itk
21
+ from scipy.signal import medfilt
22
+ import skimage.filters
23
+
24
+ from monai.transforms import Compose, LoadImaged, EnsureChannelFirstd, Resized, NormalizeIntensityd, ToTensord
25
+
26
+ from model import ViTBackboneNet, Classifier, SingleScanModelBP
27
+
28
+ # Optional HD-BET import (packaged locally like in MCI app)
29
+ try:
30
+ from HD_BET.run import run_hd_bet
31
+ from HD_BET.hd_bet import hd_bet
32
+ except Exception as e:
33
+ print(f"Warning: HD_BET not available: {e}")
34
+ run_hd_bet = None
35
+ hd_bet = None
36
+
37
+
38
+ APP_DIR = os.path.dirname(__file__)
39
+ TEMPLATE_DIR = os.path.join(APP_DIR, "golden_image", "mni_templates")
40
+ PARAMS_RIGID_PATH = os.path.join(TEMPLATE_DIR, "Parameters_Rigid.txt")
41
+ DEFAULT_TEMPLATE_PATH = os.path.join(TEMPLATE_DIR, "temp_head.nii.gz")
42
+ FLAIR_TEMPLATE_PATH = os.path.join(TEMPLATE_DIR, "nihpd_asym_04.5-18.5_t2w.nii")
43
+ T1C_TEMPLATE_PATH = os.path.join(TEMPLATE_DIR, "nihpd_asym_13.0-18.5_t1w.nii")
44
+ HD_BET_CONFIG_PATH = os.path.join(APP_DIR, "HD_BET", "config.py")
45
+ HD_BET_MODEL_DIR = os.path.join(APP_DIR, "hdbet_model")
46
+
47
+
48
+ def load_config() -> dict:
49
+ cfg_path = os.path.join(APP_DIR, "config.yml")
50
+ if os.path.exists(cfg_path):
51
+ with open(cfg_path, "r") as f:
52
+ return yaml.safe_load(f)
53
+ # Defaults
54
+ return {
55
+ "gpu": {"device": "cpu"},
56
+ "infer": {
57
+ "checkpoints": "./checkpoints/idh_model.ckpt",
58
+ "simclr_checkpoint": "./checkpoints/simclr_vitb.ckpt",
59
+ "threshold": 0.5,
60
+ "image_size": [96, 96, 96],
61
+ },
62
+ }
63
+
64
+
65
+ def build_model(cfg: dict):
66
+ device = torch.device(cfg.get("gpu", {}).get("device", "cpu"))
67
+ infer_cfg = cfg.get("infer", {})
68
+ simclr_path = os.path.join(APP_DIR, infer_cfg.get("simclr_checkpoint", ""))
69
+ ckpt_path = os.path.join(APP_DIR, infer_cfg.get("checkpoints", ""))
70
+
71
+ backbone = ViTBackboneNet(simclr_ckpt_path=simclr_path)
72
+ classifier = Classifier(d_model=768, num_classes=1)
73
+ model = SingleScanModelBP(backbone, classifier)
74
+
75
+ # Load finetuned checkpoint (Lightning or plain state_dict)
76
+ if os.path.exists(ckpt_path):
77
+ checkpoint = torch.load(ckpt_path, map_location="cpu", weights_only=False)
78
+ if "state_dict" in checkpoint:
79
+ state_dict = checkpoint["state_dict"]
80
+ new_state_dict = {}
81
+ for key, value in state_dict.items():
82
+ if key.startswith("model."):
83
+ new_state_dict[key[len("model."):]] = value
84
+ else:
85
+ new_state_dict[key] = value
86
+ else:
87
+ new_state_dict = checkpoint
88
+ model.load_state_dict(new_state_dict, strict=False)
89
+ else:
90
+ print(f"Warning: Finetuned checkpoint not found at {ckpt_path}. Model will use backbone-only weights.")
91
+
92
+ model.to(device)
93
+ model.eval()
94
+ return model, device
95
+
96
+
97
+ # ---------------- Preprocessing (Registration + Enhancement + Skull Stripping) ----------------
98
+
99
+ def bias_field_correction(img_array: np.ndarray) -> np.ndarray:
100
+ image = sitk.GetImageFromArray(img_array.astype(np.float32))
101
+ if image.GetPixelID() != sitk.sitkFloat32:
102
+ image = sitk.Cast(image, sitk.sitkFloat32)
103
+ maskImage = sitk.OtsuThreshold(image, 0, 1, 200)
104
+ corrector = sitk.N4BiasFieldCorrectionImageFilter()
105
+ numberFittingLevels = 4
106
+ max_iters = [min(50 * (2 ** i), 200) for i in range(numberFittingLevels)]
107
+ corrector.SetMaximumNumberOfIterations(max_iters)
108
+ corrected_image = corrector.Execute(image, maskImage)
109
+ return sitk.GetArrayFromImage(corrected_image)
110
+
111
+
112
+ def denoise(volume: np.ndarray, kernel_size: int = 3) -> np.ndarray:
113
+ return medfilt(volume, kernel_size)
114
+
115
+
116
+ def rescale_intensity(volume: np.ndarray, percentils=[0.5, 99.5], bins_num=256) -> np.ndarray:
117
+ volume_float = volume.astype(np.float32)
118
+ try:
119
+ t = skimage.filters.threshold_otsu(volume_float, nbins=256)
120
+ volume_masked = np.copy(volume_float)
121
+ volume_masked[volume_masked < t] = 0
122
+ obj_volume = volume_masked[np.where(volume_masked > 0)]
123
+ except ValueError:
124
+ obj_volume = volume_float.flatten()
125
+ if obj_volume.size == 0:
126
+ obj_volume = volume_float.flatten()
127
+ min_value = np.min(obj_volume)
128
+ max_value = np.max(obj_volume)
129
+ else:
130
+ min_value = np.percentile(obj_volume, percentils[0])
131
+ max_value = np.percentile(obj_volume, percentils[1])
132
+ denom = max_value - min_value
133
+ if denom < 1e-6:
134
+ denom = 1e-6
135
+ if bins_num == 0:
136
+ output_volume = (volume_float - min_value) / denom
137
+ output_volume = np.clip(output_volume, 0.0, 1.0)
138
+ else:
139
+ output_volume = np.round((volume_float - min_value) / denom * (bins_num - 1))
140
+ output_volume = np.clip(output_volume, 0, bins_num - 1)
141
+ return output_volume.astype(np.float32)
142
+
143
+
144
+ def equalize_hist(volume: np.ndarray, bins_num=256) -> np.ndarray:
145
+ mask = volume > 1e-6
146
+ obj_volume = volume[mask]
147
+ if obj_volume.size == 0:
148
+ return volume
149
+ hist, bins = np.histogram(obj_volume, bins_num, range=(obj_volume.min(), obj_volume.max()))
150
+ cdf = hist.cumsum()
151
+ cdf_normalized = (bins_num - 1) * cdf / float(cdf[-1])
152
+ equalized_obj_volume = np.interp(obj_volume, bins[:-1], cdf_normalized)
153
+ equalized_volume = np.copy(volume)
154
+ equalized_volume[mask] = equalized_obj_volume
155
+ return equalized_volume.astype(np.float32)
156
+
157
+
158
+ def run_enhance_on_file(input_nifti_path: str, output_nifti_path: str):
159
+ """
160
+ Simplified enhancement - just copy the file since N4 is now done in registration.
161
+ This maintains compatibility with the existing preprocessing pipeline.
162
+ """
163
+ print(f"Enhancement step (N4 already applied during registration): {input_nifti_path}")
164
+ # Since N4 bias correction is now handled in registration, just copy the file
165
+ import shutil
166
+ shutil.copy2(input_nifti_path, output_nifti_path)
167
+ print(f"Enhancement complete (passthrough): {output_nifti_path}")
168
+
169
+
170
+ def register_image_sitk(input_nifti_path: str, output_nifti_path: str, template_path: str, interp_type='linear'):
171
+ """
172
+ MRI registration with SimpleITK matching the provided script approach.
173
+
174
+ Args:
175
+ input_nifti_path: Path to input NIfTI file
176
+ output_nifti_path: Path to save registered output
177
+ template_path: Path to template image
178
+ interp_type: Interpolation type ('linear', 'bspline', 'nearest_neighbor')
179
+ """
180
+ print(f"Registering {input_nifti_path} to template {template_path}")
181
+
182
+ # Read template and moving images
183
+ fixed_img = sitk.ReadImage(template_path, sitk.sitkFloat32)
184
+ moving_img = sitk.ReadImage(input_nifti_path, sitk.sitkFloat32)
185
+
186
+ # Apply N4 bias correction to moving image
187
+ moving_img = sitk.N4BiasFieldCorrection(moving_img)
188
+
189
+ # Resample fixed image to 1mm isotropic
190
+ old_size = fixed_img.GetSize()
191
+ old_spacing = fixed_img.GetSpacing()
192
+ new_spacing = (1, 1, 1)
193
+ new_size = [
194
+ int(round((old_size[0] * old_spacing[0]) / float(new_spacing[0]))),
195
+ int(round((old_size[1] * old_spacing[1]) / float(new_spacing[1]))),
196
+ int(round((old_size[2] * old_spacing[2]) / float(new_spacing[2])))
197
+ ]
198
+
199
+ # Set interpolation type
200
+ if interp_type == 'linear':
201
+ interp_type = sitk.sitkLinear
202
+ elif interp_type == 'bspline':
203
+ interp_type = sitk.sitkBSpline
204
+ elif interp_type == 'nearest_neighbor':
205
+ interp_type = sitk.sitkNearestNeighbor
206
+ else:
207
+ interp_type = sitk.sitkLinear
208
+
209
+ # Resample fixed image
210
+ resample = sitk.ResampleImageFilter()
211
+ resample.SetOutputSpacing(new_spacing)
212
+ resample.SetSize(new_size)
213
+ resample.SetOutputOrigin(fixed_img.GetOrigin())
214
+ resample.SetOutputDirection(fixed_img.GetDirection())
215
+ resample.SetInterpolator(interp_type)
216
+ resample.SetDefaultPixelValue(fixed_img.GetPixelIDValue())
217
+ resample.SetOutputPixelType(sitk.sitkFloat32)
218
+ fixed_img = resample.Execute(fixed_img)
219
+
220
+ # Initialize transform
221
+ transform = sitk.CenteredTransformInitializer(
222
+ fixed_img,
223
+ moving_img,
224
+ sitk.Euler3DTransform(),
225
+ sitk.CenteredTransformInitializerFilter.GEOMETRY)
226
+
227
+ # Set up registration method
228
+ registration_method = sitk.ImageRegistrationMethod()
229
+ registration_method.SetMetricAsMattesMutualInformation(numberOfHistogramBins=50)
230
+ registration_method.SetMetricSamplingStrategy(registration_method.RANDOM)
231
+ registration_method.SetMetricSamplingPercentage(0.01)
232
+ registration_method.SetInterpolator(sitk.sitkLinear)
233
+ registration_method.SetOptimizerAsGradientDescent(
234
+ learningRate=1.0,
235
+ numberOfIterations=100,
236
+ convergenceMinimumValue=1e-6,
237
+ convergenceWindowSize=10)
238
+ registration_method.SetOptimizerScalesFromPhysicalShift()
239
+ registration_method.SetShrinkFactorsPerLevel(shrinkFactors=[4, 2, 1])
240
+ registration_method.SetSmoothingSigmasPerLevel(smoothingSigmas=[2, 1, 0])
241
+ registration_method.SmoothingSigmasAreSpecifiedInPhysicalUnitsOn()
242
+ registration_method.SetInitialTransform(transform)
243
+
244
+ # Execute registration
245
+ final_transform = registration_method.Execute(fixed_img, moving_img)
246
+
247
+ # Apply transform and save registered image
248
+ moving_img_resampled = sitk.Resample(
249
+ moving_img,
250
+ fixed_img,
251
+ final_transform,
252
+ sitk.sitkLinear,
253
+ 0.0,
254
+ moving_img.GetPixelID())
255
+
256
+ sitk.WriteImage(moving_img_resampled, output_nifti_path)
257
+ print(f"Registration complete. Saved to: {output_nifti_path}")
258
+
259
+
260
+ def register_image(input_nifti_path: str, output_nifti_path: str):
261
+ """Wrapper to maintain compatibility - now uses SimpleITK registration."""
262
+ if not os.path.exists(DEFAULT_TEMPLATE_PATH):
263
+ raise FileNotFoundError(f"Template file missing: {DEFAULT_TEMPLATE_PATH}")
264
+ register_image_sitk(input_nifti_path, output_nifti_path, DEFAULT_TEMPLATE_PATH)
265
+
266
+
267
+ def run_skull_stripping(input_nifti_path: str, output_dir: str):
268
+ """
269
+ Brain extraction using HD-BET direct integration matching the script approach.
270
+
271
+ Args:
272
+ input_nifti_path: Path to input NIfTI file
273
+ output_dir: Directory to save skull-stripped output
274
+
275
+ Returns:
276
+ tuple: (output_file_path, output_mask_path)
277
+ """
278
+ print(f"Running HD-BET skull stripping on {input_nifti_path}")
279
+
280
+ if hd_bet is None:
281
+ raise RuntimeError("HD-BET not available. Please include HD_BET and hdbet_model in src/IDH.")
282
+
283
+ if not os.path.exists(HD_BET_MODEL_DIR):
284
+ raise FileNotFoundError(f"HD-BET models not found at {HD_BET_MODEL_DIR}")
285
+
286
+ os.makedirs(output_dir, exist_ok=True)
287
+
288
+ # Get base filename and prepare HD-BET compatible naming
289
+ base_name = os.path.basename(input_nifti_path).replace('.nii.gz', '').replace('.nii', '')
290
+
291
+ # HD-BET expects files with _0000 suffix - create temporary file if needed
292
+ temp_input_dir = os.path.join(output_dir, "temp_input")
293
+ os.makedirs(temp_input_dir, exist_ok=True)
294
+
295
+ # Copy input file with _0000 suffix for HD-BET
296
+ temp_input_path = os.path.join(temp_input_dir, f"{base_name}_0000.nii.gz")
297
+ shutil.copy2(input_nifti_path, temp_input_path)
298
+
299
+ # Set device
300
+ device = "0" if torch.cuda.is_available() else "cpu"
301
+
302
+ try:
303
+ # Also try setting the specific model file path
304
+ model_file = os.path.join(HD_BET_MODEL_DIR, '0.model')
305
+
306
+ if os.path.exists(model_file):
307
+ print(f"Local model file exists at: {model_file}")
308
+ else:
309
+ print(f"Warning: Model file not found at: {model_file}")
310
+ # List directory contents for debugging
311
+ if os.path.exists(HD_BET_MODEL_DIR):
312
+ print(f"Contents of {HD_BET_MODEL_DIR}: {os.listdir(HD_BET_MODEL_DIR)}")
313
+ else:
314
+ print(f"Directory {HD_BET_MODEL_DIR} does not exist")
315
+
316
+ # Run HD-BET directly on the temporary directory
317
+ print(f"Running hd_bet with input_dir: {temp_input_dir}, output_dir: {output_dir}")
318
+ hd_bet(temp_input_dir, output_dir, device=device, mode='fast', tta=0)
319
+
320
+ # HD-BET outputs files with original naming convention
321
+ output_file_path = os.path.join(output_dir, f"{base_name}_0000.nii.gz")
322
+ output_mask_path = os.path.join(output_dir, f"{base_name}_0000_mask.nii.gz")
323
+
324
+ # Rename to expected format for compatibility
325
+ final_output_path = os.path.join(output_dir, f"{base_name}_bet.nii.gz")
326
+ final_mask_path = os.path.join(output_dir, f"{base_name}_bet_mask.nii.gz")
327
+
328
+ if os.path.exists(output_file_path):
329
+ shutil.move(output_file_path, final_output_path)
330
+ if os.path.exists(output_mask_path):
331
+ shutil.move(output_mask_path, final_mask_path)
332
+
333
+ # Clean up temporary directory
334
+ shutil.rmtree(temp_input_dir, ignore_errors=True)
335
+
336
+ if not os.path.exists(final_output_path):
337
+ raise RuntimeError(f"HD-BET did not produce output file: {final_output_path}")
338
+
339
+ print(f"Skull stripping complete. Output saved to: {final_output_path}")
340
+ return final_output_path, final_mask_path
341
+
342
+ except Exception as e:
343
+ # Clean up on error
344
+ shutil.rmtree(temp_input_dir, ignore_errors=True)
345
+ raise RuntimeError(f"HD-BET skull stripping failed: {str(e)}")
346
+
347
+
348
+ # ---------------- Saliency Generation ----------------
349
+
350
+ def extract_attention_map(vit_model, image, layer_idx=-1, img_size=(96, 96, 96), patch_size=16):
351
+ """
352
+ Extracts the attention map from a Vision Transformer (ViT) model.
353
+
354
+ This function wraps the attention blocks of the ViT to capture the attention
355
+ weights during a forward pass. It then processes these weights to generate
356
+ a 3D saliency map corresponding to the model's focus on the input image.
357
+ """
358
+ attention_maps = {}
359
+ original_attns = {}
360
+
361
+ # A wrapper class to intercept and store attention weights from a ViT block.
362
+ class AttentionWithWeights(torch.nn.Module):
363
+ def __init__(self, original_attn_module):
364
+ super().__init__()
365
+ self.original_attn_module = original_attn_module
366
+ self.attn_weights = None
367
+
368
+ def forward(self, x):
369
+ # The original implementation of the attention module may not return
370
+ # the attention weights. This wrapper recalculates them to ensure they
371
+ # are captured. This is based on the standard ViT attention mechanism.
372
+ output = self.original_attn_module(x)
373
+ if hasattr(self.original_attn_module, 'qkv'):
374
+ qkv = self.original_attn_module.qkv(x)
375
+ batch_size, seq_len, _ = x.shape
376
+ # Assuming qkv has been fused and has shape (batch_size, seq_len, 3 * num_heads * head_dim)
377
+ qkv = qkv.reshape(batch_size, seq_len, 3, self.original_attn_module.num_heads, -1)
378
+ qkv = qkv.permute(2, 0, 3, 1, 4)
379
+ q, k, v = qkv[0], qkv[1], qkv[2]
380
+ attn = (q @ k.transpose(-2, -1)) * self.original_attn_module.scale
381
+ self.attn_weights = attn.softmax(dim=-1)
382
+ return output
383
+
384
+ # Store original attention modules and replace with wrappers
385
+ for i, block in enumerate(vit_model.blocks):
386
+ if hasattr(block, 'attn'):
387
+ original_attns[i] = block.attn
388
+ block.attn = AttentionWithWeights(block.attn)
389
+
390
+ try:
391
+ # Perform a forward pass to execute the wrapped modules and capture weights
392
+ with torch.no_grad():
393
+ _ = vit_model(image)
394
+
395
+ # Collect the captured attention weights from each block
396
+ for i, block in enumerate(vit_model.blocks):
397
+ if hasattr(block.attn, 'attn_weights') and block.attn.attn_weights is not None:
398
+ attention_maps[f"layer_{i}"] = block.attn.attn_weights.detach()
399
+
400
+ finally:
401
+ # Restore original attention modules
402
+ for i, original_attn in original_attns.items():
403
+ vit_model.blocks[i].attn = original_attn
404
+
405
+ if not attention_maps:
406
+ raise RuntimeError("Could not extract any attention maps. Please check the ViT model structure.")
407
+
408
+ # Select the attention map from the specified layer
409
+ if layer_idx < 0:
410
+ layer_idx = len(attention_maps) + layer_idx
411
+ layer_name = f"layer_{layer_idx}"
412
+ if layer_name not in attention_maps:
413
+ raise ValueError(f"Layer {layer_idx} not found. Available layers: {list(attention_maps.keys())}")
414
+
415
+ layer_attn = attention_maps[layer_name]
416
+ # Average attention across all heads
417
+ head_attn = layer_attn[0].mean(dim=0)
418
+ # Get attention from the [CLS] token to all other image patches
419
+ cls_attn = head_attn[0, 1:]
420
+
421
+ # Reshape the 1D attention vector into a 3D volume
422
+ patches_per_dim = img_size[0] // patch_size
423
+ total_patches = patches_per_dim ** 3
424
+
425
+ # Pad or truncate if the number of patches doesn't align
426
+ if cls_attn.shape[0] != total_patches:
427
+ if cls_attn.shape[0] > total_patches:
428
+ cls_attn = cls_attn[:total_patches]
429
+ else:
430
+ padded = torch.zeros(total_patches, device=cls_attn.device)
431
+ padded[:cls_attn.shape[0]] = cls_attn
432
+ cls_attn = padded
433
+
434
+ cls_attn_3d = cls_attn.reshape(patches_per_dim, patches_per_dim, patches_per_dim)
435
+ cls_attn_3d = cls_attn_3d.unsqueeze(0).unsqueeze(0) # Add batch and channel dims
436
+
437
+ # Upsample the attention map to the full image resolution
438
+ upsampled_attn = torch.nn.functional.interpolate(
439
+ cls_attn_3d,
440
+ size=img_size,
441
+ mode='trilinear',
442
+ align_corners=False
443
+ ).squeeze()
444
+
445
+ # Normalize the map to [0, 1] for visualization
446
+ upsampled_attn = upsampled_attn.cpu().numpy()
447
+ upsampled_attn = (upsampled_attn - upsampled_attn.min()) / (upsampled_attn.max() - upsampled_attn.min())
448
+ return upsampled_attn
449
+
450
+
451
+ def generate_saliency_dual(model, input_tensor, layer_idx=-1):
452
+ """
453
+ Generate saliency maps for dual-input IDH model.
454
+
455
+ Args:
456
+ model: The complete IDH model
457
+ input_tensor: Dual input tensor (batch_size, 2, C, D, H, W)
458
+ layer_idx: ViT layer to visualize
459
+
460
+ Returns:
461
+ tuple: (flair_input_3d, t1c_input_3d, flair_saliency_3d)
462
+ """
463
+ print("Generating saliency maps for dual input...")
464
+
465
+ try:
466
+ # Extract individual images from dual input
467
+ # input_tensor shape: [batch_size, 2, C, D, H, W]
468
+ flair_tensor = input_tensor[:, 0] # [batch, C, D, H, W]
469
+ t1c_tensor = input_tensor[:, 1] # [batch, C, D, H, W]
470
+
471
+ # Get the ViT backbone
472
+ vit_model = model.backbone.backbone
473
+
474
+ # Generate attention map only for FLAIR
475
+ flair_attn = extract_attention_map(vit_model, flair_tensor, layer_idx)
476
+
477
+ # Convert input tensors to numpy for visualization
478
+ flair_input_3d = flair_tensor.squeeze().cpu().detach().numpy()
479
+ t1c_input_3d = t1c_tensor.squeeze().cpu().detach().numpy()
480
+
481
+ print("Saliency maps generated successfully.")
482
+ return flair_input_3d, t1c_input_3d, flair_attn
483
+
484
+ except Exception as e:
485
+ print(f"Error during saliency generation: {e}")
486
+ traceback.print_exc()
487
+ return None, None, None
488
+
489
+
490
+ # ---------------- Visualization Functions ----------------
491
+
492
+ def create_slice_plots_dual(flair_data_3d, t1c_data_3d, flair_saliency_3d, slice_index):
493
+ """Create slice plots for simplified dual input visualization: T1c, FLAIR, FLAIR attention."""
494
+ print(f"Generating plots for slice index: {slice_index}")
495
+
496
+ if any(data is None for data in [flair_data_3d, t1c_data_3d, flair_saliency_3d]):
497
+ return None, None, None
498
+
499
+ # Check bounds - using axis 2 for axial slices
500
+ if not (0 <= slice_index < flair_data_3d.shape[2]):
501
+ print(f"Error: Slice index {slice_index} out of bounds (0-{flair_data_3d.shape[2]-1}).")
502
+ return None, None, None
503
+
504
+ def save_plot_to_numpy(fig):
505
+ with io.BytesIO() as buf:
506
+ fig.savefig(buf, format='png', bbox_inches='tight', pad_inches=0, dpi=75)
507
+ plt.close(fig)
508
+ buf.seek(0)
509
+ img_arr = plt.imread(buf, format='png')
510
+ return (img_arr * 255).astype(np.uint8)
511
+
512
+ try:
513
+ # Extract axial slices - using axis 2 (last dimension)
514
+ flair_slice = flair_data_3d[:, :, slice_index]
515
+ t1c_slice = t1c_data_3d[:, :, slice_index]
516
+ flair_saliency_slice = flair_saliency_3d[:, :, slice_index]
517
+
518
+ # Normalize input slices
519
+ def normalize_slice(slice_data, volume_data):
520
+ p1, p99 = np.percentile(volume_data, (1, 99))
521
+ denom = max(p99 - p1, 1e-6)
522
+ return np.clip((slice_data - p1) / denom, 0, 1)
523
+
524
+ flair_slice_norm = normalize_slice(flair_slice, flair_data_3d)
525
+ t1c_slice_norm = normalize_slice(t1c_slice, t1c_data_3d)
526
+
527
+ # Process saliency slice
528
+ def process_saliency_slice(saliency_slice, saliency_volume):
529
+ saliency_slice = np.copy(saliency_slice)
530
+ saliency_slice[saliency_slice < 0] = 0
531
+ saliency_slice_blurred = cv2.GaussianBlur(saliency_slice, (15, 15), 0)
532
+ s_max = max(np.max(saliency_volume[saliency_volume >= 0]), 1e-6)
533
+ saliency_slice_norm = saliency_slice_blurred / s_max
534
+ return np.where(saliency_slice_norm > 0.0, saliency_slice_norm, 0)
535
+
536
+ flair_sal_processed = process_saliency_slice(flair_saliency_slice, flair_saliency_3d)
537
+
538
+ # Create plots
539
+ plots = []
540
+
541
+ # T1c Input
542
+ fig1, ax1 = plt.subplots(figsize=(6, 6))
543
+ ax1.imshow(t1c_slice_norm, cmap='gray', interpolation='none', origin='lower')
544
+ ax1.axis('off')
545
+ ax1.set_title('T1c Input', fontsize=14, color='white', pad=10)
546
+ plots.append(save_plot_to_numpy(fig1))
547
+
548
+ # FLAIR Input
549
+ fig2, ax2 = plt.subplots(figsize=(6, 6))
550
+ ax2.imshow(flair_slice_norm, cmap='gray', interpolation='none', origin='lower')
551
+ ax2.axis('off')
552
+ ax2.set_title('FLAIR Input', fontsize=14, color='white', pad=10)
553
+ plots.append(save_plot_to_numpy(fig2))
554
+
555
+ # FLAIR Attention
556
+ fig3, ax3 = plt.subplots(figsize=(6, 6))
557
+ ax3.imshow(flair_sal_processed, cmap='magma', interpolation='none', origin='lower', vmin=0)
558
+ ax3.axis('off')
559
+ ax3.set_title('FLAIR Attention', fontsize=14, color='white', pad=10)
560
+ plots.append(save_plot_to_numpy(fig3))
561
+
562
+ print(f"Generated 3 plots successfully for axial slice {slice_index}.")
563
+ return tuple(plots)
564
+
565
+ except Exception as e:
566
+ print(f"Error generating plots for slice {slice_index}: {e}")
567
+ traceback.print_exc()
568
+ return tuple([None] * 3)
569
+
570
+
571
+ # ---------------- Inference ----------------
572
+
573
+ def get_dual_validation_transform(image_size: Tuple[int, int, int]):
574
+ return Compose([
575
+ LoadImaged(keys=["image1", "image2"]),
576
+ EnsureChannelFirstd(keys=["image1", "image2"]),
577
+ Resized(keys=["image1", "image2"], spatial_size=tuple(image_size), mode="trilinear"),
578
+ NormalizeIntensityd(keys=["image1", "image2"], nonzero=True, channel_wise=True),
579
+ ToTensord(keys=["image1", "image2"]),
580
+ ])
581
+
582
+
583
+ def preprocess_dual_nifti(flair_path: str, t1c_path: str, image_size: Tuple[int, int, int], device: torch.device) -> torch.Tensor:
584
+ transform = get_dual_validation_transform(image_size)
585
+ sample = {"image1": flair_path, "image2": t1c_path}
586
+ sample = transform(sample)
587
+ img1 = sample["image1"] # (C, D, H, W)
588
+ img2 = sample["image2"] # (C, D, H, W)
589
+ images = torch.stack([img1, img2], dim=0).unsqueeze(0).to(device)
590
+ return images
591
+
592
+
593
+ def predict_idh(flair_file, t1c_file, threshold: float, do_preprocess: bool, generate_saliency: bool, cfg: dict, model, device):
594
+ try:
595
+ if flair_file is None or t1c_file is None:
596
+ return {"error": "Please upload both FLAIR and T1c NIfTI files (.nii.gz)."}, None, None, None, gr.Slider(visible=False), {"input_paths": None, "saliency_paths": None, "num_slices": 0}
597
+
598
+ flair_path = flair_file.name if hasattr(flair_file, 'name') else flair_file
599
+ t1c_path = t1c_file.name if hasattr(t1c_file, 'name') else t1c_file
600
+
601
+ if not (flair_path.endswith(".nii") or flair_path.endswith(".nii.gz")):
602
+ return {"error": "FLAIR must be a NIfTI file (.nii or .nii.gz)."}, None, None, None, gr.Slider(visible=False), {"input_paths": None, "saliency_paths": None, "num_slices": 0}
603
+ if not (t1c_path.endswith(".nii") or t1c_path.endswith(".nii.gz")):
604
+ return {"error": "T1c must be a NIfTI file (.nii or .nii.gz)."}, None, None, None, gr.Slider(visible=False), {"input_paths": None, "saliency_paths": None, "num_slices": 0}
605
+
606
+ work_dir = tempfile.mkdtemp()
607
+ flair_final_path, t1c_final_path = flair_path, t1c_path
608
+
609
+ try:
610
+ # Optional preprocessing pipeline
611
+ if do_preprocess:
612
+ # Registration (use modality-specific templates)
613
+ flair_reg = os.path.join(work_dir, "flair_registered.nii.gz")
614
+ t1c_reg = os.path.join(work_dir, "t1c_registered.nii.gz")
615
+ register_image_sitk(flair_path, flair_reg, FLAIR_TEMPLATE_PATH)
616
+ register_image_sitk(t1c_path, t1c_reg, T1C_TEMPLATE_PATH)
617
+ # Enhancement
618
+ flair_enh = os.path.join(work_dir, "flair_enhanced.nii.gz")
619
+ t1c_enh = os.path.join(work_dir, "t1c_enhanced.nii.gz")
620
+ run_enhance_on_file(flair_reg, flair_enh)
621
+ run_enhance_on_file(t1c_reg, t1c_enh)
622
+ # Skull stripping
623
+ skullstrip_dir = os.path.join(work_dir, "skullstripped")
624
+ flair_bet, _ = run_skull_stripping(flair_enh, skullstrip_dir)
625
+ t1c_bet, _ = run_skull_stripping(t1c_enh, skullstrip_dir)
626
+ flair_final_path, t1c_final_path = flair_bet, t1c_bet
627
+
628
+ # Prediction
629
+ image_size = cfg.get("infer", {}).get("image_size", [96, 96, 96])
630
+ input_tensor = preprocess_dual_nifti(flair_final_path, t1c_final_path, image_size, device)
631
+
632
+ with torch.no_grad():
633
+ logits = model(input_tensor)
634
+ prob = torch.sigmoid(logits).cpu().numpy().flatten()[0].item()
635
+ predicted_class = int(prob >= threshold)
636
+
637
+ prediction_result = {
638
+ "IDH_mutant_probability": float(prob),
639
+ "threshold": float(threshold),
640
+ "predicted_class": int(predicted_class),
641
+ "preprocessing": bool(do_preprocess),
642
+ "class_label": "IDH-mutant" if predicted_class == 1 else "IDH-wildtype"
643
+ }
644
+
645
+ # Initialize saliency outputs
646
+ t1c_input_img = flair_input_img = flair_attn_img = None
647
+ slider_update = gr.Slider(visible=False)
648
+ saliency_state = {"input_paths": None, "saliency_paths": None, "num_slices": 0}
649
+
650
+ # Generate saliency maps if requested
651
+ if generate_saliency:
652
+ print("--- Generating Saliency Maps ---")
653
+ try:
654
+ flair_input_3d, t1c_input_3d, flair_saliency_3d = generate_saliency_dual(model, input_tensor, layer_idx=-1)
655
+
656
+ if all(data is not None for data in [flair_input_3d, t1c_input_3d, flair_saliency_3d]):
657
+ num_slices = flair_input_3d.shape[2] # Use axis 2 for axial slices
658
+ center_slice_index = num_slices // 2
659
+
660
+ # Save numpy arrays for slider callback
661
+ unique_id = str(uuid.uuid4())
662
+ temp_paths = []
663
+ for name, data in [("flair_input", flair_input_3d), ("t1c_input", t1c_input_3d),
664
+ ("flair_saliency", flair_saliency_3d)]:
665
+ path = os.path.join(work_dir, f"{unique_id}_{name}.npy")
666
+ np.save(path, data)
667
+ temp_paths.append(path)
668
+
669
+ # Generate initial plots for center slice
670
+ plots = create_slice_plots_dual(flair_input_3d, t1c_input_3d, flair_saliency_3d, center_slice_index)
671
+ if plots and all(p is not None for p in plots):
672
+ t1c_input_img, flair_input_img, flair_attn_img = plots
673
+
674
+ # Update state and slider
675
+ saliency_state = {
676
+ "input_paths": temp_paths[:2], # [flair_input, t1c_input]
677
+ "saliency_paths": temp_paths[2:], # [flair_saliency]
678
+ "num_slices": num_slices
679
+ }
680
+ slider_update = gr.Slider(value=center_slice_index, minimum=0, maximum=num_slices-1, step=1, label="Select Slice", visible=True)
681
+ print("--- Saliency Generation Complete ---")
682
+ else:
683
+ print("Warning: Saliency generation failed - some outputs were None")
684
+
685
+ except Exception as e:
686
+ print(f"Error during saliency generation: {e}")
687
+ traceback.print_exc()
688
+
689
+ return (prediction_result, t1c_input_img, flair_input_img, flair_attn_img, slider_update, saliency_state)
690
+
691
+ except Exception as e:
692
+ shutil.rmtree(work_dir, ignore_errors=True)
693
+ return {"error": f"Processing failed: {str(e)}"}, None, None, None, gr.Slider(visible=False), {"input_paths": None, "saliency_paths": None, "num_slices": 0}
694
+
695
+ except Exception as e:
696
+ return {"error": str(e)}, None, None, None, gr.Slider(visible=False), {"input_paths": None, "saliency_paths": None, "num_slices": 0}
697
+
698
+
699
+ def update_slice_viewer_dual(slice_index, current_state):
700
+ """Update slice viewer for dual input saliency visualization."""
701
+ input_paths = current_state.get("input_paths", [])
702
+ saliency_paths = current_state.get("saliency_paths", [])
703
+
704
+ if not input_paths or not saliency_paths or len(input_paths) != 2 or len(saliency_paths) != 1:
705
+ print(f"Warning: Invalid state for slice viewer update: {current_state}")
706
+ return None, None, None
707
+
708
+ try:
709
+ # Load numpy arrays
710
+ flair_input_3d = np.load(input_paths[0])
711
+ t1c_input_3d = np.load(input_paths[1])
712
+ flair_saliency_3d = np.load(saliency_paths[0])
713
+
714
+ # Validate slice index
715
+ slice_index = int(slice_index)
716
+ if not (0 <= slice_index < flair_input_3d.shape[2]): # Use axis 2 for axial slices
717
+ print(f"Warning: Invalid slice index {slice_index}")
718
+ return None, None, None
719
+
720
+ # Generate new plots
721
+ plots = create_slice_plots_dual(flair_input_3d, t1c_input_3d, flair_saliency_3d, slice_index)
722
+ return plots if plots else tuple([None] * 3)
723
+
724
+ except Exception as e:
725
+ print(f"Error updating slice viewer for index {slice_index}: {e}")
726
+ traceback.print_exc()
727
+ return tuple([None] * 3)
728
+
729
+
730
+ def build_interface():
731
+ cfg = load_config()
732
+ model, device = build_model(cfg)
733
+ default_threshold = float(cfg.get("infer", {}).get("threshold", 0.5))
734
+
735
+ with gr.Blocks(title="BrainIAC: IDH Classification", css="""
736
+ #header-row {
737
+ min-height: 150px;
738
+ align-items: center;
739
+ }
740
+ .logo-img img {
741
+ height: 150px;
742
+ object-fit: contain;
743
+ }
744
+ """) as demo:
745
+ # --- Header with Logos ---
746
+ with gr.Row(elem_id="header-row"):
747
+ with gr.Column(scale=1):
748
+ gr.Image(os.path.join(APP_DIR, "static/images/kannlab.png"),
749
+ show_label=False, interactive=False,
750
+ show_download_button=False,
751
+ container=False,
752
+ elem_classes=["logo-img"])
753
+ with gr.Column(scale=3):
754
+ gr.Markdown(
755
+ "<h1 style='text-align: center; margin-bottom: 2.5rem'>"
756
+ "BrainIAC: IDH Classification"
757
+ "</h1>"
758
+ )
759
+ with gr.Column(scale=1):
760
+ gr.Image(os.path.join(APP_DIR, "static/images/brainiac.jpeg"),
761
+ show_label=False, interactive=False,
762
+ show_download_button=False,
763
+ container=False,
764
+ elem_classes=["logo-img"])
765
+
766
+ # --- Add model description section ---
767
+ with gr.Accordion("ℹ️ Model Details and Usage Guide", open=False):
768
+ gr.Markdown("""
769
+ ### 🧠 BrainIAC: IDH Classification
770
+
771
+ **Model Description**
772
+ A Vision Transformer (ViT) model with BrainIAC as pre-trained backbone designed to predict IDH mutation status from dual MRI sequences (FLAIR + T1c).
773
+
774
+ **Training Dataset**
775
+ - **Subjects**: Trained on FLAIR and T1c MRI scans from glioma patients from UCSF-PDGM dataset
776
+ - **Imaging Modalities**: FLAIR and T1c (contrast-enhanced T1-weighted)
777
+ - **Preprocessing**: N4 bias correction, MNI registration, and skull stripping (HD-BET)
778
+
779
+ **Input**
780
+ - Format: NIfTI (.nii or .nii.gz)
781
+ - Required sequences: FLAIR and T1c (both required)
782
+ - Image size: Automatically resized to 96×96×96 voxels
783
+
784
+ **Output**
785
+ - Binary classification: IDH-mutant or IDH-wildtype
786
+ - Probability score for IDH mutation
787
+ - Attention map visualization
788
+
789
+ **Intended Use**
790
+ - Research use only!
791
+
792
+ **NOTE**
793
+ - Requires both FLAIR and T1c sequences
794
+ - Not validated on other MRI sequences
795
+ - Not validated for other brain pathologies beyond gliomas
796
+ - Upload PHI data at own risk!
797
+ - The model is hosted on a cloud-based CPU instance
798
+ - The data is not stored, shared or collected for any purpose!
799
+
800
+ **Preprocessing Pipeline**
801
+ When enabled, the preprocessing performs:
802
+ 1. **Registration**: SimpleITK-based registration to template space with mutual information metric and 1mm isotropic resampling
803
+ 2. **N4 Bias Correction**: Applied during registration step
804
+ 3. **Skull Stripping**: Remove non-brain tissue using HD-BET direct integration
805
+
806
+ **Attention Maps**
807
+ When enabled, generates ViT attention maps showing which brain regions the model focuses on for prediction.
808
+
809
+ """)
810
+
811
+ # Use gr.State to store paths to numpy arrays for the slider callback
812
+ saliency_state = gr.State({"input_paths": None, "saliency_paths": None, "num_slices": 0})
813
+
814
+ # Main Content
815
+ gr.Markdown("**Upload FLAIR and T1c NIfTI volumes** — Optional preprocessing performs registration to MNI, enhancement, and skull stripping.")
816
+
817
+ with gr.Row():
818
+ with gr.Column(scale=1):
819
+ with gr.Group():
820
+ gr.Markdown("### Controls")
821
+ flair_input = gr.File(label="FLAIR (.nii or .nii.gz)")
822
+ t1c_input = gr.File(label="T1c (.nii or .nii.gz)")
823
+ preprocess_checkbox = gr.Checkbox(value=False, label="Preprocess NIfTI (registration + enhancement + skull stripping)")
824
+ generate_saliency_checkbox = gr.Checkbox(value=True, label="Generate Attention Maps")
825
+ threshold_input = gr.Slider(minimum=0.0, maximum=1.0, value=default_threshold, step=0.01, label="Decision Threshold")
826
+ predict_btn = gr.Button("Predict IDH Status", variant="primary")
827
+
828
+ with gr.Column(scale=2):
829
+ with gr.Group():
830
+ gr.Markdown("### Classification Result")
831
+ output_json = gr.JSON(label="Prediction")
832
+
833
+ # Saliency visualization section
834
+ with gr.Group():
835
+ gr.Markdown("### Attention Map Viewer (Axial Slice)")
836
+ slice_slider = gr.Slider(label="Select Slice", minimum=0, maximum=0, step=1, value=0, visible=False)
837
+
838
+ with gr.Row():
839
+ with gr.Column():
840
+ gr.Markdown("<p style='text-align: center;'>T1c Input</p>")
841
+ t1c_input_img = gr.Image(label="T1c Input", type="numpy", show_label=False)
842
+ with gr.Column():
843
+ gr.Markdown("<p style='text-align: center;'>FLAIR Input</p>")
844
+ flair_input_img = gr.Image(label="FLAIR Input", type="numpy", show_label=False)
845
+ with gr.Column():
846
+ gr.Markdown("<p style='text-align: center;'>FLAIR Attention</p>")
847
+ flair_attn_img = gr.Image(label="Attention Mask", type="numpy", show_label=False)
848
+
849
+ # Wire components
850
+ predict_btn.click(
851
+ fn=lambda f, t, prep, gen_sal, thr: predict_idh(f, t, thr, prep, gen_sal, cfg, model, device),
852
+ inputs=[flair_input, t1c_input, preprocess_checkbox, generate_saliency_checkbox, threshold_input],
853
+ outputs=[output_json, t1c_input_img, flair_input_img, flair_attn_img, slice_slider, saliency_state],
854
+ )
855
+
856
+ slice_slider.change(
857
+ fn=update_slice_viewer_dual,
858
+ inputs=[slice_slider, saliency_state],
859
+ outputs=[t1c_input_img, flair_input_img, flair_attn_img]
860
+ )
861
+
862
+ return demo
863
+
864
+
865
+ if __name__ == "__main__":
866
+ iface = build_interface()
867
+ iface.launch(server_name="0.0.0.0", server_port=7860)
src/IDH/checkpoints/idh_model.ckpt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d4fbbddda1f0a1c38dd3c6241725b3ed4806705085e1b812b19464867518b5bf
3
+ size 353461323
src/IDH/config.yml ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ gpu:
2
+ device: cpu
3
+ infer:
4
+ checkpoints: ./checkpoints/idh_model.ckpt
5
+ simclr_checkpoint: ./checkpoints/simclr_vitb.ckpt
6
+ threshold: 0.5
7
+ image_size: [96, 96, 96]
src/IDH/golden_image/mni_templates/Parameters_Rigid.txt ADDED
@@ -0,0 +1,141 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // Example parameter file for rotation registration
2
+ // C-style comments: //
3
+
4
+ // The internal pixel type, used for internal computations
5
+ // Leave to float in general.
6
+ // NB: this is not the type of the input images! The pixel
7
+ // type of the input images is automatically read from the
8
+ // images themselves.
9
+ // This setting can be changed to "short" to save some memory
10
+ // in case of very large 3D images.
11
+ (FixedInternalImagePixelType "float")
12
+ (MovingInternalImagePixelType "float")
13
+
14
+ // **************** Main Components **************************
15
+
16
+ // The following components should usually be left as they are:
17
+ (Registration "MultiResolutionRegistration")
18
+ (Interpolator "BSplineInterpolator")
19
+ (ResampleInterpolator "FinalBSplineInterpolator")
20
+ (Resampler "DefaultResampler")
21
+
22
+ // These may be changed to Fixed/MovingSmoothingImagePyramid.
23
+ // See the manual.
24
+ (FixedImagePyramid "FixedRecursiveImagePyramid")
25
+ (MovingImagePyramid "MovingRecursiveImagePyramid")
26
+
27
+ // The following components are most important:
28
+ // The optimizer AdaptiveStochasticGradientDescent (ASGD) works
29
+ // quite ok in general. The Transform and Metric are important
30
+ // and need to be chosen careful for each application. See manual.
31
+ (Optimizer "AdaptiveStochasticGradientDescent")
32
+ (Transform "EulerTransform")
33
+ (Metric "AdvancedMattesMutualInformation")
34
+
35
+ // ***************** Transformation **************************
36
+
37
+ // Scales the rotations compared to the translations, to make
38
+ // sure they are in the same range. In general, it's best to
39
+ // use automatic scales estimation:
40
+ (AutomaticScalesEstimation "true")
41
+
42
+ // Automatically guess an initial translation by aligning the
43
+ // geometric centers of the fixed and moving.
44
+ (AutomaticTransformInitialization "true")
45
+
46
+ // Whether transforms are combined by composition or by addition.
47
+ // In generally, Compose is the best option in most cases.
48
+ // It does not influence the results very much.
49
+ (HowToCombineTransforms "Compose")
50
+
51
+ // ******************* Similarity measure *********************
52
+
53
+ // Number of grey level bins in each resolution level,
54
+ // for the mutual information. 16 or 32 usually works fine.
55
+ // You could also employ a hierarchical strategy:
56
+ //(NumberOfHistogramBins 16 32 64)
57
+ (NumberOfHistogramBins 32)
58
+
59
+ // If you use a mask, this option is important.
60
+ // If the mask serves as region of interest, set it to false.
61
+ // If the mask indicates which pixels are valid, then set it to true.
62
+ // If you do not use a mask, the option doesn't matter.
63
+ (ErodeMask "false")
64
+
65
+ // ******************** Multiresolution **********************
66
+
67
+ // The number of resolutions. 1 Is only enough if the expected
68
+ // deformations are small. 3 or 4 mostly works fine. For large
69
+ // images and large deformations, 5 or 6 may even be useful.
70
+ (NumberOfResolutions 4)
71
+
72
+ // The downsampling/blurring factors for the image pyramids.
73
+ // By default, the images are downsampled by a factor of 2
74
+ // compared to the next resolution.
75
+ // So, in 2D, with 4 resolutions, the following schedule is used:
76
+ //(ImagePyramidSchedule 8 8 4 4 2 2 1 1 )
77
+ // And in 3D:
78
+ //(ImagePyramidSchedule 8 8 8 4 4 4 2 2 2 1 1 1 )
79
+ // You can specify any schedule, for example:
80
+ //(ImagePyramidSchedule 4 4 4 3 2 1 1 1 )
81
+ // Make sure that the number of elements equals the number
82
+ // of resolutions times the image dimension.
83
+
84
+ // ******************* Optimizer ****************************
85
+
86
+ // Maximum number of iterations in each resolution level:
87
+ // 200-500 works usually fine for rigid registration.
88
+ // For more robustness, you may increase this to 1000-2000.
89
+ (MaximumNumberOfIterations 250)
90
+
91
+ // The step size of the optimizer, in mm. By default the voxel size is used.
92
+ // which usually works well. In case of unusual high-resolution images
93
+ // (eg histology) it is necessary to increase this value a bit, to the size
94
+ // of the "smallest visible structure" in the image:
95
+ //(MaximumStepLength 1.0)
96
+
97
+ // **************** Image sampling **********************
98
+
99
+ // Number of spatial samples used to compute the mutual
100
+ // information (and its derivative) in each iteration.
101
+ // With an AdaptiveStochasticGradientDescent optimizer,
102
+ // in combination with the two options below, around 2000
103
+ // samples may already suffice.
104
+ (NumberOfSpatialSamples 2048)
105
+
106
+ // Refresh these spatial samples in every iteration, and select
107
+ // them randomly. See the manual for information on other sampling
108
+ // strategies.
109
+ (NewSamplesEveryIteration "true")
110
+ (ImageSampler "Random")
111
+
112
+ // ************* Interpolation and Resampling ****************
113
+
114
+ // Order of B-Spline interpolation used during registration/optimisation.
115
+ // It may improve accuracy if you set this to 3. Never use 0.
116
+ // An order of 1 gives linear interpolation. This is in most
117
+ // applications a good choice.
118
+ (BSplineInterpolationOrder 1)
119
+
120
+ // Order of B-Spline interpolation used for applying the final
121
+ // deformation.
122
+ // 3 gives good accuracy; recommended in most cases.
123
+ // 1 gives worse accuracy (linear interpolation)
124
+ // 0 gives worst accuracy, but is appropriate for binary images
125
+ // (masks, segmentations); equivalent to nearest neighbor interpolation.
126
+ (FinalBSplineInterpolationOrder 3)
127
+
128
+ //Default pixel value for pixels that come from outside the picture:
129
+ (DefaultPixelValue 0)
130
+
131
+ // Choose whether to generate the deformed moving image.
132
+ // You can save some time by setting this to false, if you are
133
+ // only interested in the final (nonrigidly) deformed moving image
134
+ // for example.
135
+ (WriteResultImage "true")
136
+
137
+ // The pixel type and format of the resulting deformed moving image
138
+ (ResultImagePixelType "short")
139
+ (ResultImageFormat "mhd")
140
+
141
+
src/IDH/golden_image/mni_templates/nihpd_asym_04.5-18.5_t2w.nii ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0bd55fefd8a74deadca777c362bbdb7db9ae15ee135dd9fe045c8838af58d3ee
3
+ size 17350930
src/IDH/golden_image/mni_templates/nihpd_asym_13.0-18.5_t1w.nii ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f10804664000688f0ddc124b39ce3ae27f2c339a40583bd6ff916727e97b77d0
3
+ size 17350930
src/IDH/hdbet_model/0.model ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6f75233753c4750672815e2b7a86db754995ae44b8f1cd77bccfc37becd2d83c
3
+ size 65443735
src/IDH/hdbet_model/hdbet_model/0.model ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6f75233753c4750672815e2b7a86db754995ae44b8f1cd77bccfc37becd2d83c
3
+ size 65443735
src/IDH/model.py ADDED
@@ -0,0 +1,67 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from monai.networks.nets import ViT
4
+ import os
5
+
6
+
7
+ class ViTBackboneNet(nn.Module):
8
+ def __init__(self, simclr_ckpt_path: str):
9
+ super().__init__()
10
+ self.backbone = ViT(
11
+ in_channels=1,
12
+ img_size=(96, 96, 96),
13
+ patch_size=(16, 16, 16),
14
+ hidden_size=768,
15
+ mlp_dim=3072,
16
+ num_layers=12,
17
+ num_heads=12,
18
+ save_attn=True,
19
+ )
20
+ # Load pretrained weights from SimCLR checkpoint if provided
21
+ if simclr_ckpt_path and os.path.exists(simclr_ckpt_path):
22
+ ckpt = torch.load(simclr_ckpt_path, map_location="cpu", weights_only=False)
23
+ state_dict = ckpt.get("state_dict", ckpt)
24
+ backbone_state_dict = {}
25
+ for key, value in state_dict.items():
26
+ if key.startswith("backbone."):
27
+ new_key = key[len("backbone."):]
28
+ backbone_state_dict[new_key] = value
29
+ missing, unexpected = self.backbone.load_state_dict(backbone_state_dict, strict=False)
30
+ print(f"Loaded SimCLR backbone weights. Missing: {len(missing)}, Unexpected: {len(unexpected)}")
31
+ else:
32
+ print("Warning: SimCLR checkpoint not found or not provided. Using randomly initialized backbone.")
33
+
34
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
35
+ features = self.backbone(x)
36
+ cls_token = features[0][:, 0]
37
+ return cls_token
38
+
39
+
40
+ class Classifier(nn.Module):
41
+ def __init__(self, d_model: int = 768, num_classes: int = 1):
42
+ super().__init__()
43
+ self.fc = nn.Linear(d_model, num_classes)
44
+
45
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
46
+ return self.fc(x)
47
+
48
+
49
+ class SingleScanModelBP(nn.Module):
50
+ def __init__(self, backbone: nn.Module, classifier: nn.Module):
51
+ super().__init__()
52
+ self.backbone = backbone
53
+ self.classifier = classifier
54
+ self.dropout = nn.Dropout(p=0.2)
55
+
56
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
57
+ # x shape: (batch_size, 2, C, D, H, W)
58
+ scan_features_list = []
59
+ for scan_tensor_with_extra_dim in x.split(1, dim=1):
60
+ squeezed_scan_tensor = scan_tensor_with_extra_dim.squeeze(1)
61
+ feature = self.backbone(squeezed_scan_tensor)
62
+ scan_features_list.append(feature)
63
+ stacked_features = torch.stack(scan_features_list, dim=1)
64
+ merged_features = torch.mean(stacked_features, dim=1)
65
+ merged_features = self.dropout(merged_features)
66
+ output = self.classifier(merged_features)
67
+ return output
src/IDH/static/images/brainage.jpeg ADDED

Git LFS Details

  • SHA256: 4b844af61b1dea2e772edfddcd8f8adb0453721f7972684b7b580e85ae2addf5
  • Pointer size: 130 Bytes
  • Size of remote file: 33.5 kB
src/IDH/static/images/brainiac.jpeg ADDED

Git LFS Details

  • SHA256: 4766658a13c4901d134196b1991b6d5707083edcb3cb8d5e77d2e459a82c2dd7
  • Pointer size: 131 Bytes
  • Size of remote file: 121 kB
src/IDH/static/images/kannlab.png ADDED

Git LFS Details

  • SHA256: ec1df75e65c6acda7654b0e8dfa27672a702f2c4e78609c48f59db037b07548e
  • Pointer size: 130 Bytes
  • Size of remote file: 56.1 kB