Commit 65bee5d
Parent(s): 20cb642
Divyanshu Tak committed:
Add BrainIAC IDH Classification app with Vision Transformer model
Browse files
- .gitignore +3 -0
- Dockerfile +37 -0
- README.md +24 -6
- requirements.txt +18 -0
- src/IDH/HD_BET/HD_BET/config.py +121 -0
- src/IDH/HD_BET/HD_BET/data_loading.py +121 -0
- src/IDH/HD_BET/HD_BET/hd_bet.py +119 -0
- src/IDH/HD_BET/HD_BET/network_architecture.py +213 -0
- src/IDH/HD_BET/HD_BET/paths.py +6 -0
- src/IDH/HD_BET/HD_BET/predict_case.py +126 -0
- src/IDH/HD_BET/HD_BET/run.py +117 -0
- src/IDH/HD_BET/HD_BET/utils.py +115 -0
- src/IDH/HD_BET/config.py +121 -0
- src/IDH/HD_BET/data_loading.py +121 -0
- src/IDH/HD_BET/hd_bet.py +119 -0
- src/IDH/HD_BET/network_architecture.py +213 -0
- src/IDH/HD_BET/paths.py +6 -0
- src/IDH/HD_BET/predict_case.py +126 -0
- src/IDH/HD_BET/run.py +117 -0
- src/IDH/HD_BET/utils.py +115 -0
- src/IDH/app_gradio.py +867 -0
- src/IDH/checkpoints/idh_model.ckpt +3 -0
- src/IDH/config.yml +7 -0
- src/IDH/golden_image/mni_templates/Parameters_Rigid.txt +141 -0
- src/IDH/golden_image/mni_templates/nihpd_asym_04.5-18.5_t2w.nii +3 -0
- src/IDH/golden_image/mni_templates/nihpd_asym_13.0-18.5_t1w.nii +3 -0
- src/IDH/hdbet_model/0.model +3 -0
- src/IDH/hdbet_model/hdbet_model/0.model +3 -0
- src/IDH/model.py +67 -0
- src/IDH/static/images/brainage.jpeg +3 -0
- src/IDH/static/images/brainiac.jpeg +3 -0
- src/IDH/static/images/kannlab.png +3 -0
.gitignore
ADDED
@@ -0,0 +1,3 @@
+__pycache__/
+*.pyc
+.DS_Store
Dockerfile
ADDED
@@ -0,0 +1,37 @@
+# Use an official Python runtime as a parent image
+FROM python:3.10-slim
+
+# Set the working directory in the container
+WORKDIR /app
+
+# Install necessary system dependencies (kept minimal)
+RUN apt-get update && \
+    apt-get install -y --no-install-recommends \
+    git \
+    libgl1 \
+    libglib2.0-0 \
+    && rm -rf /var/lib/apt/lists/*
+
+# Copy the requirements file first to leverage Docker cache
+COPY requirements.txt ./
+
+# Install Python packages
+RUN pip install --no-cache-dir -r requirements.txt
+
+# Copy the entire project
+COPY . /app/
+
+# Create a non-root user (HF Spaces requirement)
+RUN useradd -m -u 1000 user
+USER user
+
+# Make sure the user owns the app directory
+COPY --chown=user:user . /app/
+
+# Expose Gradio default port
+EXPOSE 7860
+
+ENV PYTHONUNBUFFERED=1
+
+# Run the app from the src/IDH directory
+CMD ["python", "src/IDH/app_gradio.py"]
README.md
CHANGED
@@ -1,11 +1,29 @@
 ---
-title: IDH Classification
-emoji:
-colorFrom:
-colorTo:
+title: BrainIAC IDH Classification
+emoji: 🧠
+colorFrom: blue
+colorTo: purple
 sdk: docker
 pinned: false
-license:
+license: mit
 ---
 
-
+# BrainIAC: IDH Classification
+
+A Vision Transformer model for predicting IDH mutation status from dual MRI sequences (FLAIR + T1c).
+
+## Features
+- Upload FLAIR and T1c MRI NIfTI files
+- Optional preprocessing (registration + enhancement + skull stripping)
+- Vision Transformer attention map visualization
+- Interactive slice-by-slice attention viewing
+- Real-time IDH mutation prediction with confidence scores
+
+## Usage
+1. Upload FLAIR and T1c MRI scans (.nii or .nii.gz)
+2. Optionally enable preprocessing
+3. Enable saliency maps for attention visualization
+4. Adjust prediction threshold
+5. View results and attention maps
+
+*Research use only*
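
For reference, a minimal sketch of calling the running app programmatically with gradio_client. The endpoint name, argument order, and URL here are assumptions for illustration, not taken from app_gradio.py; the Space's "Use via API" page shows the real signature.

# Hypothetical client call -- api_name and arguments are assumed, not verified.
from gradio_client import Client, handle_file

client = Client("http://localhost:7860")  # or the Space URL
result = client.predict(
    handle_file("flair.nii.gz"),   # FLAIR scan (placeholder file name)
    handle_file("t1c.nii.gz"),     # T1c scan (placeholder file name)
    True,                          # enable preprocessing (assumed flag)
    api_name="/predict",           # assumed endpoint name
)
print(result)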
requirements.txt
ADDED
@@ -0,0 +1,18 @@
+monai==1.3.2
+nibabel==5.2.1
+numpy==1.23.5
+pydicom
+PyYAML
+pytorch-lightning==2.3.3
+scipy==1.10.1
+SimpleITK==2.4.0
+torch==2.6.0
+tqdm
+gradio
+pandas
+scikit-image==0.21.0
+opencv-python
+itk-elastix
+dicom2nifti
+einops
+matplotlib
src/IDH/HD_BET/HD_BET/config.py
ADDED
@@ -0,0 +1,121 @@
+import numpy as np
+import torch
+from HD_BET.utils import SetNetworkToVal, softmax_helper
+from abc import abstractmethod
+from HD_BET.network_architecture import Network
+
+
+class BaseConfig(object):
+    def __init__(self):
+        pass
+
+    @abstractmethod
+    def get_split(self, fold, random_state=12345):
+        pass
+
+    @abstractmethod
+    def get_network(self, mode="train"):
+        pass
+
+    @abstractmethod
+    def get_basic_generators(self, fold):
+        pass
+
+    @abstractmethod
+    def get_data_generators(self, fold):
+        pass
+
+    def preprocess(self, data):
+        return data
+
+    def __repr__(self):
+        res = ""
+        for v in vars(self):
+            if not v.startswith("__") and not v.startswith("_") and v != 'dataset':
+                res += (v + ": " + str(self.__getattribute__(v)) + "\n")
+        return res
+
+
+class HD_BET_Config(BaseConfig):
+    def __init__(self):
+        super(HD_BET_Config, self).__init__()
+
+        self.EXPERIMENT_NAME = self.__class__.__name__  # just a generic experiment name
+
+        # network parameters
+        self.net_base_num_layers = 21
+        self.BATCH_SIZE = 2
+        self.net_do_DS = True
+        self.net_dropout_p = 0.0
+        self.net_use_inst_norm = True
+        self.net_conv_use_bias = True
+        self.net_norm_use_affine = True
+        self.net_leaky_relu_slope = 1e-1
+
+        # hyperparameters
+        self.INPUT_PATCH_SIZE = (128, 128, 128)
+        self.num_classes = 2
+        self.selected_data_channels = range(1)
+
+        # data augmentation
+        self.da_mirror_axes = (2, 3, 4)
+
+        # validation
+        self.val_use_DO = False
+        self.val_use_train_mode = False  # for dropout sampling
+        self.val_num_repeats = 1  # only useful if dropout sampling
+        self.val_batch_size = 1  # only useful if dropout sampling
+        self.val_save_npz = True
+        self.val_do_mirroring = True  # test time data augmentation via mirroring
+        self.val_write_images = True
+        self.net_input_must_be_divisible_by = 16  # we could make a network class that has this as a property
+        self.val_min_size = self.INPUT_PATCH_SIZE
+        self.val_fn = None
+
+        # CAREFUL! THIS IS A HACK TO MAKE PYTORCH 0.3 STATE DICTS COMPATIBLE WITH PYTORCH 0.4 (setting keep_runnings_
+        # stats=True but not using them in validation. keep_runnings_stats was True before 0.3 but unused and defaults
+        # to false in 0.4)
+        self.val_use_moving_averages = False
+
+    def get_network(self, train=True, pretrained_weights=None):
+        net = Network(self.num_classes, len(self.selected_data_channels), self.net_base_num_layers,
+                      self.net_dropout_p, softmax_helper, self.net_leaky_relu_slope, self.net_conv_use_bias,
+                      self.net_norm_use_affine, True, self.net_do_DS)
+
+        if pretrained_weights is not None:
+            net.load_state_dict(
+                torch.load(pretrained_weights, map_location=lambda storage, loc: storage))
+
+        if train:
+            net.train(True)
+        else:
+            net.train(False)
+            net.apply(SetNetworkToVal(self.val_use_DO, self.val_use_moving_averages))
+            net.do_ds = False
+
+        optimizer = None
+        self.lr_scheduler = None
+        return net, optimizer
+
+    def get_data_generators(self, fold):
+        pass
+
+    def get_split(self, fold, random_state=12345):
+        pass
+
+    def get_basic_generators(self, fold):
+        pass
+
+    def on_epoch_end(self, epoch):
+        pass
+
+    def preprocess(self, data):
+        data = np.copy(data)
+        for c in range(data.shape[0]):
+            data[c] -= data[c].mean()
+            data[c] /= data[c].std()
+        return data
+
+
+config = HD_BET_Config
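
HD_BET_Config.preprocess z-scores each input channel independently; a minimal standalone sketch of that normalization on dummy data:

import numpy as np

# Per-channel z-score as in HD_BET_Config.preprocess: zero mean, unit variance.
data = np.random.rand(1, 8, 8, 8).astype(np.float32)  # (channels, x, y, z)
data = np.copy(data)
for c in range(data.shape[0]):
    data[c] -= data[c].mean()
    data[c] /= data[c].std()
print(float(data[0].mean()), float(data[0].std()))  # ~0.0, ~1.0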
src/IDH/HD_BET/HD_BET/data_loading.py
ADDED
@@ -0,0 +1,121 @@
+import SimpleITK as sitk
+import numpy as np
+from skimage.transform import resize
+
+
+def resize_image(image, old_spacing, new_spacing, order=3):
+    new_shape = (int(np.round(old_spacing[0]/new_spacing[0]*float(image.shape[0]))),
+                 int(np.round(old_spacing[1]/new_spacing[1]*float(image.shape[1]))),
+                 int(np.round(old_spacing[2]/new_spacing[2]*float(image.shape[2]))))
+    return resize(image, new_shape, order=order, mode='edge', cval=0, anti_aliasing=False)
+
+
+def preprocess_image(itk_image, is_seg=False, spacing_target=(1, 0.5, 0.5)):
+    spacing = np.array(itk_image.GetSpacing())[[2, 1, 0]]
+    image = sitk.GetArrayFromImage(itk_image).astype(float)
+
+    assert len(image.shape) == 3, "The image has unsupported number of dimensions. Only 3D images are allowed"
+
+    if not is_seg:
+        if np.any([[i != j] for i, j in zip(spacing, spacing_target)]):
+            image = resize_image(image, spacing, spacing_target).astype(np.float32)
+
+        image -= image.mean()
+        image /= image.std()
+    else:
+        new_shape = (int(np.round(spacing[0] / spacing_target[0] * float(image.shape[0]))),
+                     int(np.round(spacing[1] / spacing_target[1] * float(image.shape[1]))),
+                     int(np.round(spacing[2] / spacing_target[2] * float(image.shape[2]))))
+        image = resize_segmentation(image, new_shape, 1)
+    return image
+
+
+def load_and_preprocess(mri_file):
+    images = {}
+    # t1
+    images["T1"] = sitk.ReadImage(mri_file)
+
+    properties_dict = {
+        "spacing": images["T1"].GetSpacing(),
+        "direction": images["T1"].GetDirection(),
+        "size": images["T1"].GetSize(),
+        "origin": images["T1"].GetOrigin()
+    }
+
+    for k in images.keys():
+        images[k] = preprocess_image(images[k], is_seg=False, spacing_target=(1.5, 1.5, 1.5))
+
+    properties_dict['size_before_cropping'] = images["T1"].shape
+
+    imgs = []
+    for seq in ['T1']:
+        imgs.append(images[seq][None])
+    all_data = np.vstack(imgs)
+    print("image shape after preprocessing: ", str(all_data[0].shape))
+    return all_data, properties_dict
+
+
+def save_segmentation_nifti(segmentation, dct, out_fname, order=1):
+    '''
+    segmentation must have the same spacing as the original nifti (for now). segmentation may have been cropped out
+    of the original image
+
+    dct:
+    size_before_cropping
+    brain_bbox
+    size -> this is the original size of the dataset, if the image was not resampled, this is the same as size_before_cropping
+    spacing
+    origin
+    direction
+
+    :param segmentation:
+    :param dct:
+    :param out_fname:
+    :return:
+    '''
+    old_size = dct.get('size_before_cropping')
+    bbox = dct.get('brain_bbox')
+    if bbox is not None:
+        seg_old_size = np.zeros(old_size)
+        for c in range(3):
+            bbox[c][1] = np.min((bbox[c][0] + segmentation.shape[c], old_size[c]))
+        seg_old_size[bbox[0][0]:bbox[0][1],
+                     bbox[1][0]:bbox[1][1],
+                     bbox[2][0]:bbox[2][1]] = segmentation
+    else:
+        seg_old_size = segmentation
+    if np.any(np.array(seg_old_size.shape) != np.array(dct['size'])[[2, 1, 0]]):
+        seg_old_spacing = resize_segmentation(seg_old_size, np.array(dct['size'])[[2, 1, 0]], order=order)
+    else:
+        seg_old_spacing = seg_old_size
+    seg_resized_itk = sitk.GetImageFromArray(seg_old_spacing.astype(np.int32))
+    seg_resized_itk.SetSpacing(np.array(dct['spacing'])[[0, 1, 2]])
+    seg_resized_itk.SetOrigin(dct['origin'])
+    seg_resized_itk.SetDirection(dct['direction'])
+    sitk.WriteImage(seg_resized_itk, out_fname)
+
+
+def resize_segmentation(segmentation, new_shape, order=3, cval=0):
+    '''
+    Taken from batchgenerators (https://github.com/MIC-DKFZ/batchgenerators) to prevent dependency
+
+    Resizes a segmentation map. Supports all orders (see skimage documentation). Will transform segmentation map to one
+    hot encoding which is resized and transformed back to a segmentation map.
+    This prevents interpolation artifacts ([0, 0, 2] -> [0, 1, 2])
+    :param segmentation:
+    :param new_shape:
+    :param order:
+    :return:
+    '''
+    tpe = segmentation.dtype
+    unique_labels = np.unique(segmentation)
+    assert len(segmentation.shape) == len(new_shape), "new shape must have same dimensionality as segmentation"
+    if order == 0:
+        return resize(segmentation, new_shape, order, mode="constant", cval=cval, clip=True, anti_aliasing=False).astype(tpe)
+    else:
+        reshaped = np.zeros(new_shape, dtype=segmentation.dtype)
+
+        for i, c in enumerate(unique_labels):
+            reshaped_multihot = resize((segmentation == c).astype(float), new_shape, order, mode="edge", clip=True, anti_aliasing=False)
+            reshaped[reshaped_multihot >= 0.5] = c
+        return reshaped
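
A short sketch of why resize_segmentation above resizes one label mask at a time: interpolating integer labels directly can invent values that were never present (a stray 1 between a 0 and a 2), while thresholding per-label masks cannot.

import numpy as np
from skimage.transform import resize

seg = np.array([[0, 0, 2, 2],
                [0, 0, 2, 2]], dtype=np.uint8)

# Naive interpolation of the raw labels produces intermediate values.
naive = resize(seg.astype(float), (4, 8), order=1, mode="edge", anti_aliasing=False)
print(np.unique(np.round(naive, 1)))   # values between 0 and 2 appear

# Per-label resize + threshold keeps the label set intact.
safe = np.zeros((4, 8), dtype=seg.dtype)
for c in np.unique(seg):
    mask = resize((seg == c).astype(float), (4, 8), order=1, mode="edge", anti_aliasing=False)
    safe[mask >= 0.5] = c
print(np.unique(safe))                 # only 0 and 2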
src/IDH/HD_BET/HD_BET/hd_bet.py
ADDED
@@ -0,0 +1,119 @@
+#!/usr/bin/env python
+
+import os
+import sys
+sys.path.append("/mnt/93E8-0534/AIDAN/HDBET/")
+from HD_BET.run import run_hd_bet
+from HD_BET.utils import maybe_mkdir_p, subfiles
+import HD_BET
+
+def hd_bet(input_file_or_dir, output_file_or_dir, mode, device, tta, pp=1, save_mask=0, overwrite_existing=1):
+
+    if output_file_or_dir is None:
+        output_file_or_dir = os.path.join(os.path.dirname(input_file_or_dir),
+                                          os.path.basename(input_file_or_dir).split(".")[0] + "_bet")
+
+    params_file = os.path.join(HD_BET.__path__[0], "model_final.py")
+    config_file = os.path.join(HD_BET.__path__[0], "config.py")
+
+    assert os.path.abspath(input_file_or_dir) != os.path.abspath(output_file_or_dir), "output must be different from input"
+
+    if device == 'cpu':
+        pass
+    else:
+        device = int(device)
+
+    if os.path.isdir(input_file_or_dir):
+        maybe_mkdir_p(output_file_or_dir)
+        input_files = subfiles(input_file_or_dir, suffix='_0000.nii.gz', join=False)
+
+        if len(input_files) == 0:
+            raise RuntimeError("input is a folder but no nifti files (.nii.gz) were found in here")
+
+        output_files = [os.path.join(output_file_or_dir, i) for i in input_files]
+        input_files = [os.path.join(input_file_or_dir, i) for i in input_files]
+    else:
+        if not output_file_or_dir.endswith('.nii.gz'):
+            output_file_or_dir += '.nii.gz'
+            assert os.path.abspath(input_file_or_dir) != os.path.abspath(output_file_or_dir), "output must be different from input"
+
+        output_files = [output_file_or_dir]
+        input_files = [input_file_or_dir]
+
+    if tta == 0:
+        tta = False
+    elif tta == 1:
+        tta = True
+    else:
+        raise ValueError("Unknown value for tta: %s. Expected: 0 or 1" % str(tta))
+
+    if overwrite_existing == 0:
+        overwrite_existing = False
+    elif overwrite_existing == 1:
+        overwrite_existing = True
+    else:
+        raise ValueError("Unknown value for overwrite_existing: %s. Expected: 0 or 1" % str(overwrite_existing))
+
+    if pp == 0:
+        pp = False
+    elif pp == 1:
+        pp = True
+    else:
+        raise ValueError("Unknown value for pp: %s. Expected: 0 or 1" % str(pp))
+
+    if save_mask == 0:
+        save_mask = False
+    elif save_mask == 1:
+        save_mask = True
+    else:
+        raise ValueError("Unknown value for save_mask: %s. Expected: 0 or 1" % str(save_mask))
+
+    run_hd_bet(input_files, output_files, mode, config_file, device, pp, tta, save_mask, overwrite_existing)
+
+
+if __name__ == "__main__":
+    print("\n########################")
+    print("If you are using hd-bet, please cite the following paper:")
+    print("Isensee F, Schell M, Tursunova I, Brugnara G, Bonekamp D, Neuberger U, Wick A, Schlemmer HP, Heiland S, Wick W, "
+          "Bendszus M, Maier-Hein KH, Kickingereder P. Automated brain extraction of multi-sequence MRI using artificial "
+          "neural networks. arXiv preprint arXiv:1901.11341, 2019.")
+    print("########################\n")
+
+    import argparse
+    parser = argparse.ArgumentParser()
+    parser.add_argument('-i', '--input', help='input. Can be either a single file name or an input folder. If file: must be '
+                                              'nifti (.nii.gz) and can only be 3D. No support for 4d images, use fslsplit to '
+                                              'split 4d sequences into 3d images. If folder: all files ending with .nii.gz '
+                                              'within that folder will be brain extracted.', required=True, type=str)
+    parser.add_argument('-o', '--output', help='output. Can be either a filename or a folder. If it does not exist, the folder'
+                                               ' will be created', required=False, type=str)
+    parser.add_argument('-mode', type=str, default='accurate', help='can be either \'fast\' or \'accurate\'. Fast will '
+                                                                    'use only one set of parameters whereas accurate will '
+                                                                    'use the five sets of parameters that resulted from '
+                                                                    'our cross-validation as an ensemble. Default: '
+                                                                    'accurate',
+                        required=False)
+    parser.add_argument('-device', default='0', type=str, help='used to set on which device the prediction will run. '
+                                                               'Must be either int or str. Use int for GPU id or '
+                                                               '\'cpu\' to run on CPU. When using CPU you should '
+                                                               'consider disabling tta. Default for -device is: 0',
+                        required=False)
+    parser.add_argument('-tta', default=1, required=False, type=int, help='whether to use test time data augmentation '
+                                                                          '(mirroring). 1= True, 0=False. Disable this '
+                                                                          'if you are using CPU to speed things up! '
+                                                                          'Default: 1')
+    parser.add_argument('-pp', default=1, type=int, required=False, help='set to 0 to disable postprocessing (remove all'
+                                                                         ' but the largest connected component in '
+                                                                         'the prediction. Default: 1')
+    parser.add_argument('-s', '--save_mask', default=1, type=int, required=False, help='if set to 0 the segmentation '
+                                                                                       'mask will not be '
+                                                                                       'saved')
+    parser.add_argument('--overwrite_existing', default=1, type=int, required=False, help="set this to 0 if you don't "
+                                                                                          "want to overwrite existing "
+                                                                                          "predictions")
+
+    args = parser.parse_args()
+
+    hd_bet(args.input, args.output, args.mode, args.device, args.tta, args.pp, args.save_mask, args.overwrite_existing)
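
A minimal usage sketch of the hd_bet wrapper above (file names are placeholders; the model weights must be present under folder_with_parameter_files for this to run):

from HD_BET.hd_bet import hd_bet

# Accurate mode ensembles five parameter sets; TTA is disabled for CPU speed.
hd_bet("subject_t1.nii.gz",      # placeholder input NIfTI
       "subject_t1_bet.nii.gz",  # placeholder output path
       mode="accurate",
       device="cpu",
       tta=0)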
src/IDH/HD_BET/HD_BET/network_architecture.py
ADDED
@@ -0,0 +1,213 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from HD_BET.utils import softmax_helper
+
+
+class EncodingModule(nn.Module):
+    def __init__(self, in_channels, out_channels, filter_size=3, dropout_p=0.3, leakiness=1e-2, conv_bias=True,
+                 inst_norm_affine=True, lrelu_inplace=True):
+        nn.Module.__init__(self)
+        self.dropout_p = dropout_p
+        self.lrelu_inplace = lrelu_inplace
+        self.inst_norm_affine = inst_norm_affine
+        self.conv_bias = conv_bias
+        self.leakiness = leakiness
+        self.bn_1 = nn.InstanceNorm3d(in_channels, affine=self.inst_norm_affine, track_running_stats=True)
+        self.conv1 = nn.Conv3d(in_channels, out_channels, filter_size, 1, (filter_size - 1) // 2, bias=self.conv_bias)
+        self.dropout = nn.Dropout3d(dropout_p)
+        self.bn_2 = nn.InstanceNorm3d(in_channels, affine=self.inst_norm_affine, track_running_stats=True)
+        self.conv2 = nn.Conv3d(out_channels, out_channels, filter_size, 1, (filter_size - 1) // 2, bias=self.conv_bias)
+
+    def forward(self, x):
+        skip = x
+        x = F.leaky_relu(self.bn_1(x), negative_slope=self.leakiness, inplace=self.lrelu_inplace)
+        x = self.conv1(x)
+        if self.dropout_p is not None and self.dropout_p > 0:
+            x = self.dropout(x)
+        x = F.leaky_relu(self.bn_2(x), negative_slope=self.leakiness, inplace=self.lrelu_inplace)
+        x = self.conv2(x)
+        x = x + skip
+        return x
+
+
+class Upsample(nn.Module):
+    def __init__(self, size=None, scale_factor=None, mode='nearest', align_corners=True):
+        super(Upsample, self).__init__()
+        self.align_corners = align_corners
+        self.mode = mode
+        self.scale_factor = scale_factor
+        self.size = size
+
+    def forward(self, x):
+        return nn.functional.interpolate(x, size=self.size, scale_factor=self.scale_factor, mode=self.mode,
+                                         align_corners=self.align_corners)
+
+
+class LocalizationModule(nn.Module):
+    def __init__(self, in_channels, out_channels, leakiness=1e-2, conv_bias=True, inst_norm_affine=True,
+                 lrelu_inplace=True):
+        nn.Module.__init__(self)
+        self.lrelu_inplace = lrelu_inplace
+        self.inst_norm_affine = inst_norm_affine
+        self.conv_bias = conv_bias
+        self.leakiness = leakiness
+        self.conv1 = nn.Conv3d(in_channels, in_channels, 3, 1, 1, bias=self.conv_bias)
+        self.bn_1 = nn.InstanceNorm3d(in_channels, affine=self.inst_norm_affine, track_running_stats=True)
+        self.conv2 = nn.Conv3d(in_channels, out_channels, 1, 1, 0, bias=self.conv_bias)
+        self.bn_2 = nn.InstanceNorm3d(out_channels, affine=self.inst_norm_affine, track_running_stats=True)
+
+    def forward(self, x):
+        x = F.leaky_relu(self.bn_1(self.conv1(x)), negative_slope=self.leakiness, inplace=self.lrelu_inplace)
+        x = F.leaky_relu(self.bn_2(self.conv2(x)), negative_slope=self.leakiness, inplace=self.lrelu_inplace)
+        return x
+
+
+class UpsamplingModule(nn.Module):
+    def __init__(self, in_channels, out_channels, leakiness=1e-2, conv_bias=True, inst_norm_affine=True,
+                 lrelu_inplace=True):
+        nn.Module.__init__(self)
+        self.lrelu_inplace = lrelu_inplace
+        self.inst_norm_affine = inst_norm_affine
+        self.conv_bias = conv_bias
+        self.leakiness = leakiness
+        self.upsample = Upsample(scale_factor=2, mode="trilinear", align_corners=True)
+        self.upsample_conv = nn.Conv3d(in_channels, out_channels, 3, 1, 1, bias=self.conv_bias)
+        self.bn = nn.InstanceNorm3d(out_channels, affine=self.inst_norm_affine, track_running_stats=True)
+
+    def forward(self, x):
+        x = F.leaky_relu(self.bn(self.upsample_conv(self.upsample(x))), negative_slope=self.leakiness,
+                         inplace=self.lrelu_inplace)
+        return x
+
+
+class DownsamplingModule(nn.Module):
+    def __init__(self, in_channels, out_channels, leakiness=1e-2, conv_bias=True, inst_norm_affine=True,
+                 lrelu_inplace=True):
+        nn.Module.__init__(self)
+        self.lrelu_inplace = lrelu_inplace
+        self.inst_norm_affine = inst_norm_affine
+        self.conv_bias = conv_bias
+        self.leakiness = leakiness
+        self.bn = nn.InstanceNorm3d(in_channels, affine=self.inst_norm_affine, track_running_stats=True)
+        self.downsample = nn.Conv3d(in_channels, out_channels, 3, 2, 1, bias=self.conv_bias)
+
+    def forward(self, x):
+        x = F.leaky_relu(self.bn(x), negative_slope=self.leakiness, inplace=self.lrelu_inplace)
+        b = self.downsample(x)
+        return x, b
+
+
+class Network(nn.Module):
+    def __init__(self, num_classes=4, num_input_channels=4, base_filters=16, dropout_p=0.3,
+                 final_nonlin=softmax_helper, leakiness=1e-2, conv_bias=True, inst_norm_affine=True,
+                 lrelu_inplace=True, do_ds=True):
+        super(Network, self).__init__()
+
+        self.do_ds = do_ds
+        self.lrelu_inplace = lrelu_inplace
+        self.inst_norm_affine = inst_norm_affine
+        self.conv_bias = conv_bias
+        self.leakiness = leakiness
+        self.final_nonlin = final_nonlin
+        self.init_conv = nn.Conv3d(num_input_channels, base_filters, 3, 1, 1, bias=self.conv_bias)
+
+        self.context1 = EncodingModule(base_filters, base_filters, 3, dropout_p, leakiness=1e-2, conv_bias=True,
+                                       inst_norm_affine=True, lrelu_inplace=True)
+        self.down1 = DownsamplingModule(base_filters, base_filters * 2, leakiness=1e-2, conv_bias=True,
+                                        inst_norm_affine=True, lrelu_inplace=True)
+
+        self.context2 = EncodingModule(2 * base_filters, 2 * base_filters, 3, dropout_p, leakiness=1e-2, conv_bias=True,
+                                       inst_norm_affine=True, lrelu_inplace=True)
+        self.down2 = DownsamplingModule(2 * base_filters, base_filters * 4, leakiness=1e-2, conv_bias=True,
+                                        inst_norm_affine=True, lrelu_inplace=True)
+
+        self.context3 = EncodingModule(4 * base_filters, 4 * base_filters, 3, dropout_p, leakiness=1e-2, conv_bias=True,
+                                       inst_norm_affine=True, lrelu_inplace=True)
+        self.down3 = DownsamplingModule(4 * base_filters, base_filters * 8, leakiness=1e-2, conv_bias=True,
+                                        inst_norm_affine=True, lrelu_inplace=True)
+
+        self.context4 = EncodingModule(8 * base_filters, 8 * base_filters, 3, dropout_p, leakiness=1e-2, conv_bias=True,
+                                       inst_norm_affine=True, lrelu_inplace=True)
+        self.down4 = DownsamplingModule(8 * base_filters, base_filters * 16, leakiness=1e-2, conv_bias=True,
+                                        inst_norm_affine=True, lrelu_inplace=True)
+
+        self.context5 = EncodingModule(16 * base_filters, 16 * base_filters, 3, dropout_p, leakiness=1e-2,
+                                       conv_bias=True, inst_norm_affine=True, lrelu_inplace=True)
+
+        self.bn_after_context5 = nn.InstanceNorm3d(16 * base_filters, affine=self.inst_norm_affine, track_running_stats=True)
+        self.up1 = UpsamplingModule(16 * base_filters, 8 * base_filters, leakiness=1e-2, conv_bias=True,
+                                    inst_norm_affine=True, lrelu_inplace=True)
+
+        self.loc1 = LocalizationModule(16 * base_filters, 8 * base_filters, leakiness=1e-2, conv_bias=True,
+                                       inst_norm_affine=True, lrelu_inplace=True)
+        self.up2 = UpsamplingModule(8 * base_filters, 4 * base_filters, leakiness=1e-2, conv_bias=True,
+                                    inst_norm_affine=True, lrelu_inplace=True)
+
+        self.loc2 = LocalizationModule(8 * base_filters, 4 * base_filters, leakiness=1e-2, conv_bias=True,
+                                       inst_norm_affine=True, lrelu_inplace=True)
+        self.loc2_seg = nn.Conv3d(4 * base_filters, num_classes, 1, 1, 0, bias=False)
+        self.up3 = UpsamplingModule(4 * base_filters, 2 * base_filters, leakiness=1e-2, conv_bias=True,
+                                    inst_norm_affine=True, lrelu_inplace=True)
+
+        self.loc3 = LocalizationModule(4 * base_filters, 2 * base_filters, leakiness=1e-2, conv_bias=True,
+                                       inst_norm_affine=True, lrelu_inplace=True)
+        self.loc3_seg = nn.Conv3d(2 * base_filters, num_classes, 1, 1, 0, bias=False)
+        self.up4 = UpsamplingModule(2 * base_filters, 1 * base_filters, leakiness=1e-2, conv_bias=True,
+                                    inst_norm_affine=True, lrelu_inplace=True)
+
+        self.end_conv_1 = nn.Conv3d(2 * base_filters, 2 * base_filters, 3, 1, 1, bias=self.conv_bias)
+        self.end_conv_1_bn = nn.InstanceNorm3d(2 * base_filters, affine=self.inst_norm_affine, track_running_stats=True)
+        self.end_conv_2 = nn.Conv3d(2 * base_filters, 2 * base_filters, 3, 1, 1, bias=self.conv_bias)
+        self.end_conv_2_bn = nn.InstanceNorm3d(2 * base_filters, affine=self.inst_norm_affine, track_running_stats=True)
+        self.seg_layer = nn.Conv3d(2 * base_filters, num_classes, 1, 1, 0, bias=False)
+
+    def forward(self, x):
+        seg_outputs = []
+
+        x = self.init_conv(x)
+        x = self.context1(x)
+
+        skip1, x = self.down1(x)
+        x = self.context2(x)
+
+        skip2, x = self.down2(x)
+        x = self.context3(x)
+
+        skip3, x = self.down3(x)
+        x = self.context4(x)
+
+        skip4, x = self.down4(x)
+        x = self.context5(x)
+
+        x = F.leaky_relu(self.bn_after_context5(x), negative_slope=self.leakiness, inplace=self.lrelu_inplace)
+        x = self.up1(x)
+
+        x = torch.cat((skip4, x), dim=1)
+        x = self.loc1(x)
+        x = self.up2(x)
+
+        x = torch.cat((skip3, x), dim=1)
+        x = self.loc2(x)
+        loc2_seg = self.final_nonlin(self.loc2_seg(x))
+        seg_outputs.append(loc2_seg)
+        x = self.up3(x)
+
+        x = torch.cat((skip2, x), dim=1)
+        x = self.loc3(x)
+        loc3_seg = self.final_nonlin(self.loc3_seg(x))
+        seg_outputs.append(loc3_seg)
+        x = self.up4(x)
+
+        x = torch.cat((skip1, x), dim=1)
+        x = F.leaky_relu(self.end_conv_1_bn(self.end_conv_1(x)), negative_slope=self.leakiness,
+                         inplace=self.lrelu_inplace)
+        x = F.leaky_relu(self.end_conv_2_bn(self.end_conv_2(x)), negative_slope=self.leakiness,
+                         inplace=self.lrelu_inplace)
+        x = self.final_nonlin(self.seg_layer(x))
+        seg_outputs.append(x)
+
+        if self.do_ds:
+            return seg_outputs[::-1]
+        else:
+            return seg_outputs[-1]
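
A quick shape sanity-check for the Network above, a 3D residual U-Net with deep supervision. This is a sketch: any spatial size divisible by 16 works, since there are four stride-2 downsampling stages.

import torch
from HD_BET.network_architecture import Network

net = Network(num_classes=2, num_input_channels=1)
x = torch.zeros(1, 1, 64, 64, 64)
outs = net(x)  # do_ds=True by default -> list of heads, full resolution first
print([tuple(o.shape) for o in outs])
# [(1, 2, 64, 64, 64), (1, 2, 32, 32, 32), (1, 2, 16, 16, 16)]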
src/IDH/HD_BET/HD_BET/paths.py
ADDED
@@ -0,0 +1,6 @@
+import os
+
+# please refer to the readme on where to get the parameters. Save them in this folder:
+# Original Path: "/media/sdb/divyanshu/divyanshu/aidan_segmentation/nnUNet_pLGG/home/divyanshu/hd-bet_params"
+# Updated path for Docker container:
+folder_with_parameter_files = "/app/IDH/hdbet_model"
src/IDH/HD_BET/HD_BET/predict_case.py
ADDED
@@ -0,0 +1,126 @@
+import torch
+import numpy as np
+
+
+def pad_patient_3D(patient, shape_must_be_divisible_by=16, min_size=None):
+    if not (isinstance(shape_must_be_divisible_by, list) or isinstance(shape_must_be_divisible_by, tuple)):
+        shape_must_be_divisible_by = [shape_must_be_divisible_by] * 3
+    shp = patient.shape
+    new_shp = [shp[0] + shape_must_be_divisible_by[0] - shp[0] % shape_must_be_divisible_by[0],
+               shp[1] + shape_must_be_divisible_by[1] - shp[1] % shape_must_be_divisible_by[1],
+               shp[2] + shape_must_be_divisible_by[2] - shp[2] % shape_must_be_divisible_by[2]]
+    for i in range(len(shp)):
+        if shp[i] % shape_must_be_divisible_by[i] == 0:
+            new_shp[i] -= shape_must_be_divisible_by[i]
+    if min_size is not None:
+        new_shp = np.max(np.vstack((np.array(new_shp), np.array(min_size))), 0)
+    return reshape_by_padding_upper_coords(patient, new_shp, 0), shp
+
+
+def reshape_by_padding_upper_coords(image, new_shape, pad_value=None):
+    shape = tuple(list(image.shape))
+    new_shape = tuple(np.max(np.concatenate((shape, new_shape)).reshape((2, len(shape))), axis=0))
+    if pad_value is None:
+        if len(shape) == 2:
+            pad_value = image[0, 0]
+        elif len(shape) == 3:
+            pad_value = image[0, 0, 0]
+        else:
+            raise ValueError("Image must be either 2 or 3 dimensional")
+    res = np.ones(list(new_shape), dtype=image.dtype) * pad_value
+    if len(shape) == 2:
+        res[0:0+int(shape[0]), 0:0+int(shape[1])] = image
+    elif len(shape) == 3:
+        res[0:0+int(shape[0]), 0:0+int(shape[1]), 0:0+int(shape[2])] = image
+    return res
+
+
+def predict_case_3D_net(net, patient_data, do_mirroring, num_repeats, BATCH_SIZE=None,
+                        new_shape_must_be_divisible_by=16, min_size=None, main_device=0, mirror_axes=(2, 3, 4)):
+    with torch.no_grad():
+        pad_res = []
+        for i in range(patient_data.shape[0]):
+            t, old_shape = pad_patient_3D(patient_data[i], new_shape_must_be_divisible_by, min_size)
+            pad_res.append(t[None])
+
+        patient_data = np.vstack(pad_res)
+
+        new_shp = patient_data.shape
+
+        data = np.zeros(tuple([1] + list(new_shp)), dtype=np.float32)
+
+        data[0] = patient_data
+
+        if BATCH_SIZE is not None:
+            data = np.vstack([data] * BATCH_SIZE)
+
+        a = torch.rand(data.shape).float()
+
+        if main_device == 'cpu':
+            pass
+        else:
+            a = a.cuda(main_device)
+
+        if do_mirroring:
+            x = 8
+        else:
+            x = 1
+        all_preds = []
+        for i in range(num_repeats):
+            for m in range(x):
+                data_for_net = np.array(data)
+                do_stuff = False
+                if m == 0:
+                    do_stuff = True
+                    pass
+                if m == 1 and (4 in mirror_axes):
+                    do_stuff = True
+                    data_for_net = data_for_net[:, :, :, :, ::-1]
+                if m == 2 and (3 in mirror_axes):
+                    do_stuff = True
+                    data_for_net = data_for_net[:, :, :, ::-1, :]
+                if m == 3 and (4 in mirror_axes) and (3 in mirror_axes):
+                    do_stuff = True
+                    data_for_net = data_for_net[:, :, :, ::-1, ::-1]
+                if m == 4 and (2 in mirror_axes):
+                    do_stuff = True
+                    data_for_net = data_for_net[:, :, ::-1, :, :]
+                if m == 5 and (2 in mirror_axes) and (4 in mirror_axes):
+                    do_stuff = True
+                    data_for_net = data_for_net[:, :, ::-1, :, ::-1]
+                if m == 6 and (2 in mirror_axes) and (3 in mirror_axes):
+                    do_stuff = True
+                    data_for_net = data_for_net[:, :, ::-1, ::-1, :]
+                if m == 7 and (2 in mirror_axes) and (3 in mirror_axes) and (4 in mirror_axes):
+                    do_stuff = True
+                    data_for_net = data_for_net[:, :, ::-1, ::-1, ::-1]
+
+                if do_stuff:
+                    _ = a.data.copy_(torch.from_numpy(np.copy(data_for_net)))
+                    p = net(a)  # np.copy is necessary because ::-1 creates just a view i think
+                    p = p.data.cpu().numpy()
+
+                    if m == 0:
+                        pass
+                    if m == 1 and (4 in mirror_axes):
+                        p = p[:, :, :, :, ::-1]
+                    if m == 2 and (3 in mirror_axes):
+                        p = p[:, :, :, ::-1, :]
+                    if m == 3 and (4 in mirror_axes) and (3 in mirror_axes):
+                        p = p[:, :, :, ::-1, ::-1]
+                    if m == 4 and (2 in mirror_axes):
+                        p = p[:, :, ::-1, :, :]
+                    if m == 5 and (2 in mirror_axes) and (4 in mirror_axes):
+                        p = p[:, :, ::-1, :, ::-1]
+                    if m == 6 and (2 in mirror_axes) and (3 in mirror_axes):
+                        p = p[:, :, ::-1, ::-1, :]
+                    if m == 7 and (2 in mirror_axes) and (3 in mirror_axes) and (4 in mirror_axes):
+                        p = p[:, :, ::-1, ::-1, ::-1]
+                    all_preds.append(p)
+
+        stacked = np.vstack(all_preds)[:, :, :old_shape[0], :old_shape[1], :old_shape[2]]
+        predicted_segmentation = stacked.mean(0).argmax(0)
+        uncertainty = stacked.var(0)
+        bayesian_predictions = stacked
+        softmax_pred = stacked.mean(0)
+        return predicted_segmentation, bayesian_predictions, softmax_pred, uncertainty
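
The eight mirror cases in predict_case_3D_net enumerate every subset of the spatial axes (2, 3, 4); a compact sketch of the same test-time-augmentation idea, with a stand-in predict function:

import itertools
import numpy as np

def tta_predict(predict, data, mirror_axes=(2, 3, 4)):
    # Flip, predict, flip the prediction back, then average all 8 variants.
    preds = []
    for r in range(len(mirror_axes) + 1):
        for axes in itertools.combinations(mirror_axes, r):
            flipped = np.flip(data, axis=axes) if axes else data
            p = predict(np.ascontiguousarray(flipped))
            preds.append(np.flip(p, axis=axes) if axes else p)
    return np.mean(preds, axis=0)

# With the identity as "network", the result is the average of all flips.
out = tta_predict(lambda x: x, np.random.rand(1, 1, 4, 4, 4))
print(out.shape)  # (1, 1, 4, 4, 4)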
src/IDH/HD_BET/HD_BET/run.py
ADDED
@@ -0,0 +1,117 @@
+import torch
+import numpy as np
+import SimpleITK as sitk
+from HD_BET.data_loading import load_and_preprocess, save_segmentation_nifti
+from HD_BET.predict_case import predict_case_3D_net
+import imp
+from HD_BET.utils import postprocess_prediction, SetNetworkToVal, get_params_fname, maybe_download_parameters
+import os
+import HD_BET
+
+
+def apply_bet(img, bet, out_fname):
+    img_itk = sitk.ReadImage(img)
+    img_npy = sitk.GetArrayFromImage(img_itk)
+    img_bet = sitk.GetArrayFromImage(sitk.ReadImage(bet))
+    img_npy[img_bet == 0] = 0
+    out = sitk.GetImageFromArray(img_npy)
+    out.CopyInformation(img_itk)
+    sitk.WriteImage(out, out_fname)
+
+
+def run_hd_bet(mri_fnames, output_fnames, mode="accurate", config_file=os.path.join(HD_BET.__path__[0], "config.py"), device=0,
+               postprocess=False, do_tta=True, keep_mask=True, overwrite=True):
+    """
+
+    :param mri_fnames: str or list/tuple of str
+    :param output_fnames: str or list/tuple of str. If list: must have the same length as output_fnames
+    :param mode: fast or accurate
+    :param config_file: config.py
+    :param device: either int (for device id) or 'cpu'
+    :param postprocess: whether to do postprocessing or not. Postprocessing here consists of simply discarding all
+    but the largest predicted connected component. Default False
+    :param do_tta: whether to do test time data augmentation by mirroring along all axes. Default: True. If you use
+    CPU you may want to turn that off to speed things up
+    :return:
+    """
+
+    list_of_param_files = []
+
+    if mode == 'fast':
+        params_file = get_params_fname(0)
+        maybe_download_parameters(0)
+
+        list_of_param_files.append(params_file)
+    elif mode == 'accurate':
+        for i in range(5):
+            params_file = get_params_fname(i)
+            maybe_download_parameters(i)
+
+            list_of_param_files.append(params_file)
+    else:
+        raise ValueError("Unknown value for mode: %s. Expected: fast or accurate" % mode)
+
+    assert all([os.path.isfile(i) for i in list_of_param_files]), "Could not find parameter files"
+
+    cf = imp.load_source('cf', config_file)
+    cf = cf.config()
+
+    net, _ = cf.get_network(cf.val_use_train_mode, None)
+    if device == "cpu":
+        net = net.cpu()
+    else:
+        net.cuda(device)
+
+    if not isinstance(mri_fnames, (list, tuple)):
+        mri_fnames = [mri_fnames]
+
+    if not isinstance(output_fnames, (list, tuple)):
+        output_fnames = [output_fnames]
+
+    assert len(mri_fnames) == len(output_fnames), "mri_fnames and output_fnames must have the same length"
+
+    params = []
+    for p in list_of_param_files:
+        params.append(torch.load(p, map_location=lambda storage, loc: storage))
+
+    for in_fname, out_fname in zip(mri_fnames, output_fnames):
+        mask_fname = out_fname[:-7] + "_mask.nii.gz"
+        if overwrite or (not (os.path.isfile(mask_fname) and keep_mask) or not os.path.isfile(out_fname)):
+            print("File:", in_fname)
+            print("preprocessing...")
+            try:
+                data, data_dict = load_and_preprocess(in_fname)
+            except RuntimeError:
+                print("\nERROR\nCould not read file", in_fname, "\n")
+                continue
+            except AssertionError as e:
+                print(e)
+                continue
+
+            softmax_preds = []
+
+            print("prediction (CNN id)...")
+            for i, p in enumerate(params):
+                print(i)
+                net.load_state_dict(p)
+                net.eval()
+                net.apply(SetNetworkToVal(False, False))
+                _, _, softmax_pred, _ = predict_case_3D_net(net, data, do_tta, cf.val_num_repeats,
+                                                            cf.val_batch_size, cf.net_input_must_be_divisible_by,
+                                                            cf.val_min_size, device, cf.da_mirror_axes)
+                softmax_preds.append(softmax_pred[None])
+
+            seg = np.argmax(np.vstack(softmax_preds).mean(0), 0)
+
+            if postprocess:
+                seg = postprocess_prediction(seg)
+
+            print("exporting segmentation...")
+            save_segmentation_nifti(seg, data_dict, mask_fname)
+
+            apply_bet(in_fname, mask_fname, out_fname)
+
+            if not keep_mask:
+                os.remove(mask_fname)
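
The ensembling step in run_hd_bet stacks the per-fold softmax maps, averages them over folds, and argmaxes over the class axis; a small sketch with random stand-in predictions:

import numpy as np

softmax_preds = [np.random.rand(1, 2, 4, 4, 4) for _ in range(5)]  # 5 folds
seg = np.argmax(np.vstack(softmax_preds).mean(0), 0)               # class per voxel
print(seg.shape)  # (4, 4, 4)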
src/IDH/HD_BET/HD_BET/utils.py
ADDED
@@ -0,0 +1,115 @@
+from urllib.request import urlopen
+import torch
+from torch import nn
+import numpy as np
+from skimage.morphology import label
+import os
+from HD_BET.paths import folder_with_parameter_files
+
+
+def get_params_fname(fold):
+    return os.path.join(folder_with_parameter_files, "%d.model" % fold)
+
+
+def maybe_download_parameters(fold=0, force_overwrite=False):
+    """
+    Downloads the parameters for some fold if it is not present yet.
+    :param fold:
+    :param force_overwrite: if True the old parameter file will be deleted (if present) prior to download
+    :return:
+    """
+
+    assert 0 <= fold <= 4, "fold must be between 0 and 4"
+
+    if not os.path.isdir(folder_with_parameter_files):
+        maybe_mkdir_p(folder_with_parameter_files)
+
+    out_filename = get_params_fname(fold)
+
+    if force_overwrite and os.path.isfile(out_filename):
+        os.remove(out_filename)
+
+    if not os.path.isfile(out_filename):
+        url = "https://zenodo.org/record/2540695/files/%d.model?download=1" % fold
+        print("Downloading", url, "...")
+        data = urlopen(url).read()
+        #out_filename = "/media/sdb/divyanshu/divyanshu/aidan_segmentation/nnUNet_pLGG/home/divyanshu/hd-bet_params/0.model"
+        with open(out_filename, 'wb') as f:
+            f.write(data)
+
+
+def init_weights(module):
+    if isinstance(module, nn.Conv3d):
+        nn.init.kaiming_normal_(module.weight, a=1e-2)
+        if module.bias is not None:
+            nn.init.constant_(module.bias, 0)
+
+
+def softmax_helper(x):
+    rpt = [1 for _ in range(len(x.size()))]
+    rpt[1] = x.size(1)
+    x_max = x.max(1, keepdim=True)[0].repeat(*rpt)
+    e_x = torch.exp(x - x_max)
+    return e_x / e_x.sum(1, keepdim=True).repeat(*rpt)
+
+
+class SetNetworkToVal(object):
+    def __init__(self, use_dropout_sampling=False, norm_use_average=True):
+        self.norm_use_average = norm_use_average
+        self.use_dropout_sampling = use_dropout_sampling
+
+    def __call__(self, module):
+        if isinstance(module, nn.Dropout3d) or isinstance(module, nn.Dropout2d) or isinstance(module, nn.Dropout):
+            module.train(self.use_dropout_sampling)
+        elif isinstance(module, nn.InstanceNorm3d) or isinstance(module, nn.InstanceNorm2d) or \
+                isinstance(module, nn.InstanceNorm1d) \
+                or isinstance(module, nn.BatchNorm2d) or isinstance(module, nn.BatchNorm3d) or \
+                isinstance(module, nn.BatchNorm1d):
+            module.train(not self.norm_use_average)
+
+
+def postprocess_prediction(seg):
+    # basically look for connected components and choose the largest one, delete everything else
+    print("running postprocessing... ")
+    mask = seg != 0
+    lbls = label(mask, connectivity=mask.ndim)
+    lbls_sizes = [np.sum(lbls == i) for i in np.unique(lbls)]
+    largest_region = np.argmax(lbls_sizes[1:]) + 1
+    seg[lbls != largest_region] = 0
+    return seg
+
+
+def subdirs(folder, join=True, prefix=None, suffix=None, sort=True):
+    if join:
+        l = os.path.join
+    else:
+        l = lambda x, y: y
+    res = [l(folder, i) for i in os.listdir(folder) if os.path.isdir(os.path.join(folder, i))
+           and (prefix is None or i.startswith(prefix))
+           and (suffix is None or i.endswith(suffix))]
+    if sort:
+        res.sort()
+    return res
+
+
+def subfiles(folder, join=True, prefix=None, suffix=None, sort=True):
+    if join:
+        l = os.path.join
+    else:
+        l = lambda x, y: y
+    res = [l(folder, i) for i in os.listdir(folder) if os.path.isfile(os.path.join(folder, i))
+           and (prefix is None or i.startswith(prefix))
+           and (suffix is None or i.endswith(suffix))]
+    if sort:
+        res.sort()
+    return res
+
+
+subfolders = subdirs  # I am tired of confusing those
+
+
+def maybe_mkdir_p(directory):
+    splits = directory.split("/")[1:]
+    for i in range(0, len(splits)):
+        if not os.path.isdir(os.path.join("", *splits[:i+1])):
+            os.mkdir(os.path.join("", *splits[:i+1]))
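
softmax_helper above is a numerically stabilized softmax over the channel dimension; a quick check that it matches torch's built-in softmax along dim=1:

import torch
import torch.nn.functional as F
from HD_BET.utils import softmax_helper

x = torch.randn(2, 3, 4, 4, 4)
print(torch.allclose(softmax_helper(x), F.softmax(x, dim=1)))  # True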
src/IDH/HD_BET/config.py
ADDED
@@ -0,0 +1,121 @@
+import numpy as np
+import torch
+from HD_BET.utils import SetNetworkToVal, softmax_helper
+from abc import abstractmethod
+from HD_BET.network_architecture import Network
+
+
+class BaseConfig(object):
+    def __init__(self):
+        pass
+
+    @abstractmethod
+    def get_split(self, fold, random_state=12345):
+        pass
+
+    @abstractmethod
+    def get_network(self, mode="train"):
+        pass
+
+    @abstractmethod
+    def get_basic_generators(self, fold):
+        pass
+
+    @abstractmethod
+    def get_data_generators(self, fold):
+        pass
+
+    def preprocess(self, data):
+        return data
+
+    def __repr__(self):
+        res = ""
+        for v in vars(self):
+            if not v.startswith("__") and not v.startswith("_") and v != 'dataset':
+                res += (v + ": " + str(self.__getattribute__(v)) + "\n")
+        return res
+
+
+class HD_BET_Config(BaseConfig):
+    def __init__(self):
+        super(HD_BET_Config, self).__init__()
+
+        self.EXPERIMENT_NAME = self.__class__.__name__  # just a generic experiment name
+
+        # network parameters
+        self.net_base_num_layers = 21
+        self.BATCH_SIZE = 2
+        self.net_do_DS = True
+        self.net_dropout_p = 0.0
+        self.net_use_inst_norm = True
+        self.net_conv_use_bias = True
+        self.net_norm_use_affine = True
+        self.net_leaky_relu_slope = 1e-1
+
+        # hyperparameters
+        self.INPUT_PATCH_SIZE = (128, 128, 128)
+        self.num_classes = 2
+        self.selected_data_channels = range(1)
+
+        # data augmentation
+        self.da_mirror_axes = (2, 3, 4)
+
+        # validation
+        self.val_use_DO = False
+        self.val_use_train_mode = False  # for dropout sampling
+        self.val_num_repeats = 1  # only useful if dropout sampling
+        self.val_batch_size = 1  # only useful if dropout sampling
+        self.val_save_npz = True
+        self.val_do_mirroring = True  # test time data augmentation via mirroring
+        self.val_write_images = True
+        self.net_input_must_be_divisible_by = 16  # we could make a network class that has this as a property
+        self.val_min_size = self.INPUT_PATCH_SIZE
+        self.val_fn = None
+
+        # CAREFUL! THIS IS A HACK TO MAKE PYTORCH 0.3 STATE DICTS COMPATIBLE WITH PYTORCH 0.4 (setting
+        # keep_runnings_stats=True but not using them in validation. keep_runnings_stats was True before 0.3 but
+        # unused and defaults to false in 0.4)
+        self.val_use_moving_averages = False
+
+    def get_network(self, train=True, pretrained_weights=None):
+        net = Network(self.num_classes, len(self.selected_data_channels), self.net_base_num_layers,
+                      self.net_dropout_p, softmax_helper, self.net_leaky_relu_slope, self.net_conv_use_bias,
+                      self.net_norm_use_affine, True, self.net_do_DS)
+
+        if pretrained_weights is not None:
+            net.load_state_dict(
+                torch.load(pretrained_weights, map_location=lambda storage, loc: storage))
+
+        if train:
+            net.train(True)
+        else:
+            net.train(False)
+            net.apply(SetNetworkToVal(self.val_use_DO, self.val_use_moving_averages))
+            net.do_ds = False
+
+        optimizer = None
+        self.lr_scheduler = None
+        return net, optimizer
+
+    def get_data_generators(self, fold):
+        pass
+
+    def get_split(self, fold, random_state=12345):
+        pass
+
+    def get_basic_generators(self, fold):
+        pass
+
+    def on_epoch_end(self, epoch):
+        pass
+
+    def preprocess(self, data):
+        data = np.copy(data)
+        for c in range(data.shape[0]):
+            data[c] -= data[c].mean()
+            data[c] /= data[c].std()
+        return data
+
+
+config = HD_BET_Config
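Note: `config = HD_BET_Config` exports the class itself; `run.py` below instantiates it via `cf.config()`. A minimal sketch of how this config is consumed (illustrative only; the array is dummy data):

    import numpy as np
    from HD_BET.config import config

    cf = config()  # instantiate HD_BET_Config
    net, optimizer = cf.get_network(train=False)  # eval-mode network; optimizer is None
    dummy = np.random.rand(1, 128, 128, 128).astype(np.float32)  # (channels, x, y, z)
    normalized = cf.preprocess(dummy)  # per-channel zero-mean / unit-variance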
src/IDH/HD_BET/data_loading.py
ADDED
@@ -0,0 +1,121 @@
+import SimpleITK as sitk
+import numpy as np
+from skimage.transform import resize
+
+
+def resize_image(image, old_spacing, new_spacing, order=3):
+    new_shape = (int(np.round(old_spacing[0]/new_spacing[0]*float(image.shape[0]))),
+                 int(np.round(old_spacing[1]/new_spacing[1]*float(image.shape[1]))),
+                 int(np.round(old_spacing[2]/new_spacing[2]*float(image.shape[2]))))
+    return resize(image, new_shape, order=order, mode='edge', cval=0, anti_aliasing=False)
+
+
+def preprocess_image(itk_image, is_seg=False, spacing_target=(1, 0.5, 0.5)):
+    spacing = np.array(itk_image.GetSpacing())[[2, 1, 0]]
+    image = sitk.GetArrayFromImage(itk_image).astype(float)
+
+    assert len(image.shape) == 3, "The image has unsupported number of dimensions. Only 3D images are allowed"
+
+    if not is_seg:
+        if np.any([[i != j] for i, j in zip(spacing, spacing_target)]):
+            image = resize_image(image, spacing, spacing_target).astype(np.float32)
+
+        image -= image.mean()
+        image /= image.std()
+    else:
+        new_shape = (int(np.round(spacing[0] / spacing_target[0] * float(image.shape[0]))),
+                     int(np.round(spacing[1] / spacing_target[1] * float(image.shape[1]))),
+                     int(np.round(spacing[2] / spacing_target[2] * float(image.shape[2]))))
+        image = resize_segmentation(image, new_shape, 1)
+    return image
+
+
+def load_and_preprocess(mri_file):
+    images = {}
+    # t1
+    images["T1"] = sitk.ReadImage(mri_file)
+
+    properties_dict = {
+        "spacing": images["T1"].GetSpacing(),
+        "direction": images["T1"].GetDirection(),
+        "size": images["T1"].GetSize(),
+        "origin": images["T1"].GetOrigin()
+    }
+
+    for k in images.keys():
+        images[k] = preprocess_image(images[k], is_seg=False, spacing_target=(1.5, 1.5, 1.5))
+
+    properties_dict['size_before_cropping'] = images["T1"].shape
+
+    imgs = []
+    for seq in ['T1']:
+        imgs.append(images[seq][None])
+    all_data = np.vstack(imgs)
+    print("image shape after preprocessing: ", str(all_data[0].shape))
+    return all_data, properties_dict
+
+
+def save_segmentation_nifti(segmentation, dct, out_fname, order=1):
+    '''
+    segmentation must have the same spacing as the original nifti (for now). segmentation may have been cropped out
+    of the original image
+
+    dct:
+    size_before_cropping
+    brain_bbox
+    size -> this is the original size of the dataset, if the image was not resampled, this is the same as size_before_cropping
+    spacing
+    origin
+    direction
+
+    :param segmentation:
+    :param dct:
+    :param out_fname:
+    :return:
+    '''
+    old_size = dct.get('size_before_cropping')
+    bbox = dct.get('brain_bbox')
+    if bbox is not None:
+        seg_old_size = np.zeros(old_size)
+        for c in range(3):
+            bbox[c][1] = np.min((bbox[c][0] + segmentation.shape[c], old_size[c]))
+        seg_old_size[bbox[0][0]:bbox[0][1],
+                     bbox[1][0]:bbox[1][1],
+                     bbox[2][0]:bbox[2][1]] = segmentation
+    else:
+        seg_old_size = segmentation
+    if np.any(np.array(seg_old_size) != np.array(dct['size'])[[2, 1, 0]]):
+        seg_old_spacing = resize_segmentation(seg_old_size, np.array(dct['size'])[[2, 1, 0]], order=order)
+    else:
+        seg_old_spacing = seg_old_size
+    seg_resized_itk = sitk.GetImageFromArray(seg_old_spacing.astype(np.int32))
+    seg_resized_itk.SetSpacing(np.array(dct['spacing'])[[0, 1, 2]])
+    seg_resized_itk.SetOrigin(dct['origin'])
+    seg_resized_itk.SetDirection(dct['direction'])
+    sitk.WriteImage(seg_resized_itk, out_fname)
+
+
+def resize_segmentation(segmentation, new_shape, order=3, cval=0):
+    '''
+    Taken from batchgenerators (https://github.com/MIC-DKFZ/batchgenerators) to prevent dependency
+
+    Resizes a segmentation map. Supports all orders (see skimage documentation). Will transform segmentation map to one
+    hot encoding which is resized and transformed back to a segmentation map.
+    This prevents interpolation artifacts ([0, 0, 2] -> [0, 1, 2])
+    :param segmentation:
+    :param new_shape:
+    :param order:
+    :return:
+    '''
+    tpe = segmentation.dtype
+    unique_labels = np.unique(segmentation)
+    assert len(segmentation.shape) == len(new_shape), "new shape must have same dimensionality as segmentation"
+    if order == 0:
+        return resize(segmentation, new_shape, order, mode="constant", cval=cval, clip=True, anti_aliasing=False).astype(tpe)
+    else:
+        reshaped = np.zeros(new_shape, dtype=segmentation.dtype)
+
+        for i, c in enumerate(unique_labels):
+            reshaped_multihot = resize((segmentation == c).astype(float), new_shape, order, mode="edge", clip=True, anti_aliasing=False)
+            reshaped[reshaped_multihot >= 0.5] = c
+        return reshaped
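Note: the one-hot trick in `resize_segmentation` prevents interpolation from inventing intermediate labels. A quick check of the behavior its docstring describes (a sketch, not part of this commit):

    import numpy as np
    from HD_BET.data_loading import resize_segmentation

    seg = np.zeros((8, 8, 8), dtype=np.int32)
    seg[4:] = 2  # labels {0, 2}; naive linear resizing could produce label 1
    resized = resize_segmentation(seg, (16, 16, 16), order=1)
    assert set(np.unique(resized)) <= {0, 2}  # only original labels survive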
src/IDH/HD_BET/hd_bet.py
ADDED
@@ -0,0 +1,119 @@
+#!/usr/bin/env python
+
+import os
+import sys
+sys.path.append("/mnt/93E8-0534/AIDAN/HDBET/")
+from HD_BET.run import run_hd_bet
+from HD_BET.utils import maybe_mkdir_p, subfiles
+import HD_BET
+
+
+def hd_bet(input_file_or_dir, output_file_or_dir, mode, device, tta, pp=1, save_mask=0, overwrite_existing=1):
+
+    if output_file_or_dir is None:
+        output_file_or_dir = os.path.join(os.path.dirname(input_file_or_dir),
+                                          os.path.basename(input_file_or_dir).split(".")[0] + "_bet")
+
+    params_file = os.path.join(HD_BET.__path__[0], "model_final.py")
+    config_file = os.path.join(HD_BET.__path__[0], "config.py")
+
+    assert os.path.abspath(input_file_or_dir) != os.path.abspath(output_file_or_dir), "output must be different from input"
+
+    if device == 'cpu':
+        pass
+    else:
+        device = int(device)
+
+    if os.path.isdir(input_file_or_dir):
+        maybe_mkdir_p(output_file_or_dir)
+        input_files = subfiles(input_file_or_dir, suffix='_0000.nii.gz', join=False)
+
+        if len(input_files) == 0:
+            raise RuntimeError("input is a folder but no nifti files (.nii.gz) were found in here")
+
+        output_files = [os.path.join(output_file_or_dir, i) for i in input_files]
+        input_files = [os.path.join(input_file_or_dir, i) for i in input_files]
+    else:
+        if not output_file_or_dir.endswith('.nii.gz'):
+            output_file_or_dir += '.nii.gz'
+            assert os.path.abspath(input_file_or_dir) != os.path.abspath(output_file_or_dir), "output must be different from input"
+
+        output_files = [output_file_or_dir]
+        input_files = [input_file_or_dir]
+
+    if tta == 0:
+        tta = False
+    elif tta == 1:
+        tta = True
+    else:
+        raise ValueError("Unknown value for tta: %s. Expected: 0 or 1" % str(tta))
+
+    if overwrite_existing == 0:
+        overwrite_existing = False
+    elif overwrite_existing == 1:
+        overwrite_existing = True
+    else:
+        raise ValueError("Unknown value for overwrite_existing: %s. Expected: 0 or 1" % str(overwrite_existing))
+
+    if pp == 0:
+        pp = False
+    elif pp == 1:
+        pp = True
+    else:
+        raise ValueError("Unknown value for pp: %s. Expected: 0 or 1" % str(pp))
+
+    if save_mask == 0:
+        save_mask = False
+    elif save_mask == 1:
+        save_mask = True
+    else:
+        raise ValueError("Unknown value for save_mask: %s. Expected: 0 or 1" % str(save_mask))
+
+    run_hd_bet(input_files, output_files, mode, config_file, device, pp, tta, save_mask, overwrite_existing)
+
+
+if __name__ == "__main__":
+    print("\n########################")
+    print("If you are using hd-bet, please cite the following paper:")
+    print("Isensee F, Schell M, Tursunova I, Brugnara G, Bonekamp D, Neuberger U, Wick A, Schlemmer HP, Heiland S, Wick W, "
+          "Bendszus M, Maier-Hein KH, Kickingereder P. Automated brain extraction of multi-sequence MRI using artificial "
+          "neural networks. arXiv preprint arXiv:1901.11341, 2019.")
+    print("########################\n")
+
+    import argparse
+    parser = argparse.ArgumentParser()
+    parser.add_argument('-i', '--input', help='input. Can be either a single file name or an input folder. If file: must be '
+                                              'nifti (.nii.gz) and can only be 3D. No support for 4d images, use fslsplit to '
+                                              'split 4d sequences into 3d images. If folder: all files ending with .nii.gz '
+                                              'within that folder will be brain extracted.', required=True, type=str)
+    parser.add_argument('-o', '--output', help='output. Can be either a filename or a folder. If it does not exist, the folder'
+                                               ' will be created', required=False, type=str)
+    parser.add_argument('-mode', type=str, default='accurate', help='can be either \'fast\' or \'accurate\'. Fast will '
+                                                                    'use only one set of parameters whereas accurate will '
+                                                                    'use the five sets of parameters that resulted from '
+                                                                    'our cross-validation as an ensemble. Default: '
+                                                                    'accurate',
+                        required=False)
+    parser.add_argument('-device', default='0', type=str, help='used to set on which device the prediction will run. '
+                                                               'Must be either int or str. Use int for GPU id or '
+                                                               '\'cpu\' to run on CPU. When using CPU you should '
+                                                               'consider disabling tta. Default for -device is: 0',
+                        required=False)
+    parser.add_argument('-tta', default=1, required=False, type=int, help='whether to use test time data augmentation '
+                                                                          '(mirroring). 1=True, 0=False. Disable this '
+                                                                          'if you are using CPU to speed things up! '
+                                                                          'Default: 1')
+    parser.add_argument('-pp', default=1, type=int, required=False, help='set to 0 to disable postprocessing (remove all'
+                                                                         ' but the largest connected component in '
+                                                                         'the prediction). Default: 1')
+    parser.add_argument('-s', '--save_mask', default=1, type=int, required=False, help='if set to 0 the segmentation '
+                                                                                       'mask will not be '
+                                                                                       'saved')
+    parser.add_argument('--overwrite_existing', default=1, type=int, required=False, help="set this to 0 if you don't "
+                                                                                          "want to overwrite existing "
+                                                                                          "predictions")
+
+    args = parser.parse_args()
+
+    hd_bet(args.input, args.output, args.mode, args.device, args.tta, args.pp, args.save_mask, args.overwrite_existing)
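Note: based on the argparse options above, typical invocations look like the following (paths are placeholders; this assumes the vendored HD_BET package is importable, e.g. when run from src/IDH; folder mode only picks up files ending in _0000.nii.gz):

    python -m HD_BET.hd_bet -i /data/scan.nii.gz -o /data/scan_bet.nii.gz -device cpu -mode fast -tta 0
    python -m HD_BET.hd_bet -i /data/nifti_folder -o /data/bet_out -device 0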
src/IDH/HD_BET/network_architecture.py
ADDED
@@ -0,0 +1,213 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from HD_BET.utils import softmax_helper
+
+
+class EncodingModule(nn.Module):
+    def __init__(self, in_channels, out_channels, filter_size=3, dropout_p=0.3, leakiness=1e-2, conv_bias=True,
+                 inst_norm_affine=True, lrelu_inplace=True):
+        nn.Module.__init__(self)
+        self.dropout_p = dropout_p
+        self.lrelu_inplace = lrelu_inplace
+        self.inst_norm_affine = inst_norm_affine
+        self.conv_bias = conv_bias
+        self.leakiness = leakiness
+        self.bn_1 = nn.InstanceNorm3d(in_channels, affine=self.inst_norm_affine, track_running_stats=True)
+        self.conv1 = nn.Conv3d(in_channels, out_channels, filter_size, 1, (filter_size - 1) // 2, bias=self.conv_bias)
+        self.dropout = nn.Dropout3d(dropout_p)
+        self.bn_2 = nn.InstanceNorm3d(in_channels, affine=self.inst_norm_affine, track_running_stats=True)
+        self.conv2 = nn.Conv3d(out_channels, out_channels, filter_size, 1, (filter_size - 1) // 2, bias=self.conv_bias)
+
+    def forward(self, x):
+        skip = x
+        x = F.leaky_relu(self.bn_1(x), negative_slope=self.leakiness, inplace=self.lrelu_inplace)
+        x = self.conv1(x)
+        if self.dropout_p is not None and self.dropout_p > 0:
+            x = self.dropout(x)
+        x = F.leaky_relu(self.bn_2(x), negative_slope=self.leakiness, inplace=self.lrelu_inplace)
+        x = self.conv2(x)
+        x = x + skip
+        return x
+
+
+class Upsample(nn.Module):
+    def __init__(self, size=None, scale_factor=None, mode='nearest', align_corners=True):
+        super(Upsample, self).__init__()
+        self.align_corners = align_corners
+        self.mode = mode
+        self.scale_factor = scale_factor
+        self.size = size
+
+    def forward(self, x):
+        return nn.functional.interpolate(x, size=self.size, scale_factor=self.scale_factor, mode=self.mode,
+                                         align_corners=self.align_corners)
+
+
+class LocalizationModule(nn.Module):
+    def __init__(self, in_channels, out_channels, leakiness=1e-2, conv_bias=True, inst_norm_affine=True,
+                 lrelu_inplace=True):
+        nn.Module.__init__(self)
+        self.lrelu_inplace = lrelu_inplace
+        self.inst_norm_affine = inst_norm_affine
+        self.conv_bias = conv_bias
+        self.leakiness = leakiness
+        self.conv1 = nn.Conv3d(in_channels, in_channels, 3, 1, 1, bias=self.conv_bias)
+        self.bn_1 = nn.InstanceNorm3d(in_channels, affine=self.inst_norm_affine, track_running_stats=True)
+        self.conv2 = nn.Conv3d(in_channels, out_channels, 1, 1, 0, bias=self.conv_bias)
+        self.bn_2 = nn.InstanceNorm3d(out_channels, affine=self.inst_norm_affine, track_running_stats=True)
+
+    def forward(self, x):
+        x = F.leaky_relu(self.bn_1(self.conv1(x)), negative_slope=self.leakiness, inplace=self.lrelu_inplace)
+        x = F.leaky_relu(self.bn_2(self.conv2(x)), negative_slope=self.leakiness, inplace=self.lrelu_inplace)
+        return x
+
+
+class UpsamplingModule(nn.Module):
+    def __init__(self, in_channels, out_channels, leakiness=1e-2, conv_bias=True, inst_norm_affine=True,
+                 lrelu_inplace=True):
+        nn.Module.__init__(self)
+        self.lrelu_inplace = lrelu_inplace
+        self.inst_norm_affine = inst_norm_affine
+        self.conv_bias = conv_bias
+        self.leakiness = leakiness
+        self.upsample = Upsample(scale_factor=2, mode="trilinear", align_corners=True)
+        self.upsample_conv = nn.Conv3d(in_channels, out_channels, 3, 1, 1, bias=self.conv_bias)
+        self.bn = nn.InstanceNorm3d(out_channels, affine=self.inst_norm_affine, track_running_stats=True)
+
+    def forward(self, x):
+        x = F.leaky_relu(self.bn(self.upsample_conv(self.upsample(x))), negative_slope=self.leakiness,
+                         inplace=self.lrelu_inplace)
+        return x
+
+
+class DownsamplingModule(nn.Module):
+    def __init__(self, in_channels, out_channels, leakiness=1e-2, conv_bias=True, inst_norm_affine=True,
+                 lrelu_inplace=True):
+        nn.Module.__init__(self)
+        self.lrelu_inplace = lrelu_inplace
+        self.inst_norm_affine = inst_norm_affine
+        self.conv_bias = conv_bias
+        self.leakiness = leakiness
+        self.bn = nn.InstanceNorm3d(in_channels, affine=self.inst_norm_affine, track_running_stats=True)
+        self.downsample = nn.Conv3d(in_channels, out_channels, 3, 2, 1, bias=self.conv_bias)
+
+    def forward(self, x):
+        x = F.leaky_relu(self.bn(x), negative_slope=self.leakiness, inplace=self.lrelu_inplace)
+        b = self.downsample(x)
+        return x, b
+
+
+class Network(nn.Module):
+    def __init__(self, num_classes=4, num_input_channels=4, base_filters=16, dropout_p=0.3,
+                 final_nonlin=softmax_helper, leakiness=1e-2, conv_bias=True, inst_norm_affine=True,
+                 lrelu_inplace=True, do_ds=True):
+        super(Network, self).__init__()
+
+        self.do_ds = do_ds
+        self.lrelu_inplace = lrelu_inplace
+        self.inst_norm_affine = inst_norm_affine
+        self.conv_bias = conv_bias
+        self.leakiness = leakiness
+        self.final_nonlin = final_nonlin
+        self.init_conv = nn.Conv3d(num_input_channels, base_filters, 3, 1, 1, bias=self.conv_bias)
+
+        self.context1 = EncodingModule(base_filters, base_filters, 3, dropout_p, leakiness=1e-2, conv_bias=True,
+                                       inst_norm_affine=True, lrelu_inplace=True)
+        self.down1 = DownsamplingModule(base_filters, base_filters * 2, leakiness=1e-2, conv_bias=True,
+                                        inst_norm_affine=True, lrelu_inplace=True)
+
+        self.context2 = EncodingModule(2 * base_filters, 2 * base_filters, 3, dropout_p, leakiness=1e-2, conv_bias=True,
+                                       inst_norm_affine=True, lrelu_inplace=True)
+        self.down2 = DownsamplingModule(2 * base_filters, base_filters * 4, leakiness=1e-2, conv_bias=True,
+                                        inst_norm_affine=True, lrelu_inplace=True)
+
+        self.context3 = EncodingModule(4 * base_filters, 4 * base_filters, 3, dropout_p, leakiness=1e-2, conv_bias=True,
+                                       inst_norm_affine=True, lrelu_inplace=True)
+        self.down3 = DownsamplingModule(4 * base_filters, base_filters * 8, leakiness=1e-2, conv_bias=True,
+                                        inst_norm_affine=True, lrelu_inplace=True)
+
+        self.context4 = EncodingModule(8 * base_filters, 8 * base_filters, 3, dropout_p, leakiness=1e-2, conv_bias=True,
+                                       inst_norm_affine=True, lrelu_inplace=True)
+        self.down4 = DownsamplingModule(8 * base_filters, base_filters * 16, leakiness=1e-2, conv_bias=True,
+                                        inst_norm_affine=True, lrelu_inplace=True)
+
+        self.context5 = EncodingModule(16 * base_filters, 16 * base_filters, 3, dropout_p, leakiness=1e-2,
+                                       conv_bias=True, inst_norm_affine=True, lrelu_inplace=True)
+
+        self.bn_after_context5 = nn.InstanceNorm3d(16 * base_filters, affine=self.inst_norm_affine, track_running_stats=True)
+        self.up1 = UpsamplingModule(16 * base_filters, 8 * base_filters, leakiness=1e-2, conv_bias=True,
+                                    inst_norm_affine=True, lrelu_inplace=True)
+
+        self.loc1 = LocalizationModule(16 * base_filters, 8 * base_filters, leakiness=1e-2, conv_bias=True,
+                                       inst_norm_affine=True, lrelu_inplace=True)
+        self.up2 = UpsamplingModule(8 * base_filters, 4 * base_filters, leakiness=1e-2, conv_bias=True,
+                                    inst_norm_affine=True, lrelu_inplace=True)
+
+        self.loc2 = LocalizationModule(8 * base_filters, 4 * base_filters, leakiness=1e-2, conv_bias=True,
+                                       inst_norm_affine=True, lrelu_inplace=True)
+        self.loc2_seg = nn.Conv3d(4 * base_filters, num_classes, 1, 1, 0, bias=False)
+        self.up3 = UpsamplingModule(4 * base_filters, 2 * base_filters, leakiness=1e-2, conv_bias=True,
+                                    inst_norm_affine=True, lrelu_inplace=True)
+
+        self.loc3 = LocalizationModule(4 * base_filters, 2 * base_filters, leakiness=1e-2, conv_bias=True,
+                                       inst_norm_affine=True, lrelu_inplace=True)
+        self.loc3_seg = nn.Conv3d(2 * base_filters, num_classes, 1, 1, 0, bias=False)
+        self.up4 = UpsamplingModule(2 * base_filters, 1 * base_filters, leakiness=1e-2, conv_bias=True,
+                                    inst_norm_affine=True, lrelu_inplace=True)
+
+        self.end_conv_1 = nn.Conv3d(2 * base_filters, 2 * base_filters, 3, 1, 1, bias=self.conv_bias)
+        self.end_conv_1_bn = nn.InstanceNorm3d(2 * base_filters, affine=self.inst_norm_affine, track_running_stats=True)
+        self.end_conv_2 = nn.Conv3d(2 * base_filters, 2 * base_filters, 3, 1, 1, bias=self.conv_bias)
+        self.end_conv_2_bn = nn.InstanceNorm3d(2 * base_filters, affine=self.inst_norm_affine, track_running_stats=True)
+        self.seg_layer = nn.Conv3d(2 * base_filters, num_classes, 1, 1, 0, bias=False)
+
+    def forward(self, x):
+        seg_outputs = []
+
+        x = self.init_conv(x)
+        x = self.context1(x)
+
+        skip1, x = self.down1(x)
+        x = self.context2(x)
+
+        skip2, x = self.down2(x)
+        x = self.context3(x)
+
+        skip3, x = self.down3(x)
+        x = self.context4(x)
+
+        skip4, x = self.down4(x)
+        x = self.context5(x)
+
+        x = F.leaky_relu(self.bn_after_context5(x), negative_slope=self.leakiness, inplace=self.lrelu_inplace)
+        x = self.up1(x)
+
+        x = torch.cat((skip4, x), dim=1)
+        x = self.loc1(x)
+        x = self.up2(x)
+
+        x = torch.cat((skip3, x), dim=1)
+        x = self.loc2(x)
+        loc2_seg = self.final_nonlin(self.loc2_seg(x))
+        seg_outputs.append(loc2_seg)
+        x = self.up3(x)
+
+        x = torch.cat((skip2, x), dim=1)
+        x = self.loc3(x)
+        loc3_seg = self.final_nonlin(self.loc3_seg(x))
+        seg_outputs.append(loc3_seg)
+        x = self.up4(x)
+
+        x = torch.cat((skip1, x), dim=1)
+        x = F.leaky_relu(self.end_conv_1_bn(self.end_conv_1(x)), negative_slope=self.leakiness,
+                         inplace=self.lrelu_inplace)
+        x = F.leaky_relu(self.end_conv_2_bn(self.end_conv_2(x)), negative_slope=self.leakiness,
+                         inplace=self.lrelu_inplace)
+        x = self.final_nonlin(self.seg_layer(x))
+        seg_outputs.append(x)
+
+        if self.do_ds:
+            return seg_outputs[::-1]
+        else:
+            return seg_outputs[-1]
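Note: `Network` is a 3D U-Net-style encoder/decoder with residual context blocks and optional deep supervision (`do_ds`). A forward-pass smoke test with randomly initialized weights (a sketch, not part of this commit; spatial dimensions must be divisible by 16, matching `net_input_must_be_divisible_by` in config.py):

    import torch
    from HD_BET.network_architecture import Network

    net = Network(num_classes=2, num_input_channels=1, base_filters=21, dropout_p=0.0, do_ds=False)
    net.eval()
    with torch.no_grad():
        out = net(torch.randn(1, 1, 64, 64, 64))  # softmax probabilities over 2 classes
    print(out.shape)  # torch.Size([1, 2, 64, 64, 64])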
src/IDH/HD_BET/paths.py
ADDED
@@ -0,0 +1,6 @@
+import os
+
+# please refer to the readme on where to get the parameters. Save them in this folder:
+# Original Path: "/media/sdb/divyanshu/divyanshu/aidan_segmentation/nnUNet_pLGG/home/divyanshu/hd-bet_params"
+# Updated path for Docker container:
+folder_with_parameter_files = "/app/IDH/hdbet_model"
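Note: with this Docker-oriented path, `get_params_fname` in utils.py resolves fold checkpoints to the weights shipped in this commit (src/IDH/hdbet_model/0.model):

    from HD_BET.utils import get_params_fname
    print(get_params_fname(0))  # /app/IDH/hdbet_model/0.model inside the container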
src/IDH/HD_BET/predict_case.py
ADDED
@@ -0,0 +1,126 @@
+import torch
+import numpy as np
+
+
+def pad_patient_3D(patient, shape_must_be_divisible_by=16, min_size=None):
+    if not (isinstance(shape_must_be_divisible_by, list) or isinstance(shape_must_be_divisible_by, tuple)):
+        shape_must_be_divisible_by = [shape_must_be_divisible_by] * 3
+    shp = patient.shape
+    new_shp = [shp[0] + shape_must_be_divisible_by[0] - shp[0] % shape_must_be_divisible_by[0],
+               shp[1] + shape_must_be_divisible_by[1] - shp[1] % shape_must_be_divisible_by[1],
+               shp[2] + shape_must_be_divisible_by[2] - shp[2] % shape_must_be_divisible_by[2]]
+    for i in range(len(shp)):
+        if shp[i] % shape_must_be_divisible_by[i] == 0:
+            new_shp[i] -= shape_must_be_divisible_by[i]
+    if min_size is not None:
+        new_shp = np.max(np.vstack((np.array(new_shp), np.array(min_size))), 0)
+    return reshape_by_padding_upper_coords(patient, new_shp, 0), shp
+
+
+def reshape_by_padding_upper_coords(image, new_shape, pad_value=None):
+    shape = tuple(list(image.shape))
+    new_shape = tuple(np.max(np.concatenate((shape, new_shape)).reshape((2, len(shape))), axis=0))
+    if pad_value is None:
+        if len(shape) == 2:
+            pad_value = image[0, 0]
+        elif len(shape) == 3:
+            pad_value = image[0, 0, 0]
+        else:
+            raise ValueError("Image must be either 2 or 3 dimensional")
+    res = np.ones(list(new_shape), dtype=image.dtype) * pad_value
+    if len(shape) == 2:
+        res[0:0+int(shape[0]), 0:0+int(shape[1])] = image
+    elif len(shape) == 3:
+        res[0:0+int(shape[0]), 0:0+int(shape[1]), 0:0+int(shape[2])] = image
+    return res
+
+
+def predict_case_3D_net(net, patient_data, do_mirroring, num_repeats, BATCH_SIZE=None,
+                        new_shape_must_be_divisible_by=16, min_size=None, main_device=0, mirror_axes=(2, 3, 4)):
+    with torch.no_grad():
+        pad_res = []
+        for i in range(patient_data.shape[0]):
+            t, old_shape = pad_patient_3D(patient_data[i], new_shape_must_be_divisible_by, min_size)
+            pad_res.append(t[None])
+
+        patient_data = np.vstack(pad_res)
+
+        new_shp = patient_data.shape
+
+        data = np.zeros(tuple([1] + list(new_shp)), dtype=np.float32)
+
+        data[0] = patient_data
+
+        if BATCH_SIZE is not None:
+            data = np.vstack([data] * BATCH_SIZE)
+
+        a = torch.rand(data.shape).float()
+
+        if main_device == 'cpu':
+            pass
+        else:
+            a = a.cuda(main_device)
+
+        if do_mirroring:
+            x = 8
+        else:
+            x = 1
+        all_preds = []
+        for i in range(num_repeats):
+            for m in range(x):
+                data_for_net = np.array(data)
+                do_stuff = False
+                if m == 0:
+                    do_stuff = True
+                    pass
+                if m == 1 and (4 in mirror_axes):
+                    do_stuff = True
+                    data_for_net = data_for_net[:, :, :, :, ::-1]
+                if m == 2 and (3 in mirror_axes):
+                    do_stuff = True
+                    data_for_net = data_for_net[:, :, :, ::-1, :]
+                if m == 3 and (4 in mirror_axes) and (3 in mirror_axes):
+                    do_stuff = True
+                    data_for_net = data_for_net[:, :, :, ::-1, ::-1]
+                if m == 4 and (2 in mirror_axes):
+                    do_stuff = True
+                    data_for_net = data_for_net[:, :, ::-1, :, :]
+                if m == 5 and (2 in mirror_axes) and (4 in mirror_axes):
+                    do_stuff = True
+                    data_for_net = data_for_net[:, :, ::-1, :, ::-1]
+                if m == 6 and (2 in mirror_axes) and (3 in mirror_axes):
+                    do_stuff = True
+                    data_for_net = data_for_net[:, :, ::-1, ::-1, :]
+                if m == 7 and (2 in mirror_axes) and (3 in mirror_axes) and (4 in mirror_axes):
+                    do_stuff = True
+                    data_for_net = data_for_net[:, :, ::-1, ::-1, ::-1]
+
+                if do_stuff:
+                    _ = a.data.copy_(torch.from_numpy(np.copy(data_for_net)))
+                    p = net(a)  # np.copy is necessary because ::-1 creates just a view i think
+                    p = p.data.cpu().numpy()
+
+                    if m == 0:
+                        pass
+                    if m == 1 and (4 in mirror_axes):
+                        p = p[:, :, :, :, ::-1]
+                    if m == 2 and (3 in mirror_axes):
+                        p = p[:, :, :, ::-1, :]
+                    if m == 3 and (4 in mirror_axes) and (3 in mirror_axes):
+                        p = p[:, :, :, ::-1, ::-1]
+                    if m == 4 and (2 in mirror_axes):
+                        p = p[:, :, ::-1, :, :]
+                    if m == 5 and (2 in mirror_axes) and (4 in mirror_axes):
+                        p = p[:, :, ::-1, :, ::-1]
+                    if m == 6 and (2 in mirror_axes) and (3 in mirror_axes):
+                        p = p[:, :, ::-1, ::-1, :]
+                    if m == 7 and (2 in mirror_axes) and (3 in mirror_axes) and (4 in mirror_axes):
+                        p = p[:, :, ::-1, ::-1, ::-1]
+                    all_preds.append(p)
+
+        stacked = np.vstack(all_preds)[:, :, :old_shape[0], :old_shape[1], :old_shape[2]]
+        predicted_segmentation = stacked.mean(0).argmax(0)
+        uncertainty = stacked.var(0)
+        bayesian_predictions = stacked
+        softmax_pred = stacked.mean(0)
+        return predicted_segmentation, bayesian_predictions, softmax_pred, uncertainty
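Note: the eight `m` branches above enumerate every subset of the spatial axes (2, 3, 4) for mirror-based test-time augmentation, then undo each flip before averaging. A compact equivalent of that enumeration (a sketch, not the code the app runs):

    import itertools
    import numpy as np

    def mirror_variants(data, mirror_axes=(2, 3, 4)):
        # yields (axes, flipped_copy) for all 2**len(mirror_axes) combinations;
        # applying np.flip with the same axes again inverts the transform
        for k in range(len(mirror_axes) + 1):
            for axes in itertools.combinations(mirror_axes, k):
                yield axes, (np.flip(data, axes).copy() if axes else data.copy())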
src/IDH/HD_BET/run.py
ADDED
@@ -0,0 +1,117 @@
+import torch
+import numpy as np
+import SimpleITK as sitk
+from HD_BET.data_loading import load_and_preprocess, save_segmentation_nifti
+from HD_BET.predict_case import predict_case_3D_net
+import imp
+from HD_BET.utils import postprocess_prediction, SetNetworkToVal, get_params_fname, maybe_download_parameters
+import os
+import HD_BET
+
+
+def apply_bet(img, bet, out_fname):
+    img_itk = sitk.ReadImage(img)
+    img_npy = sitk.GetArrayFromImage(img_itk)
+    img_bet = sitk.GetArrayFromImage(sitk.ReadImage(bet))
+    img_npy[img_bet == 0] = 0
+    out = sitk.GetImageFromArray(img_npy)
+    out.CopyInformation(img_itk)
+    sitk.WriteImage(out, out_fname)
+
+
+def run_hd_bet(mri_fnames, output_fnames, mode="accurate", config_file=os.path.join(HD_BET.__path__[0], "config.py"), device=0,
+               postprocess=False, do_tta=True, keep_mask=True, overwrite=True):
+    """
+    :param mri_fnames: str or list/tuple of str
+    :param output_fnames: str or list/tuple of str. If list: must have the same length as mri_fnames
+    :param mode: fast or accurate
+    :param config_file: config.py
+    :param device: either int (for device id) or 'cpu'
+    :param postprocess: whether to do postprocessing or not. Postprocessing here consists of simply discarding all
+    but the largest predicted connected component. Default False
+    :param do_tta: whether to do test time data augmentation by mirroring along all axes. Default: True. If you use
+    CPU you may want to turn that off to speed things up
+    :return:
+    """
+
+    list_of_param_files = []
+
+    if mode == 'fast':
+        params_file = get_params_fname(0)
+        maybe_download_parameters(0)
+
+        list_of_param_files.append(params_file)
+    elif mode == 'accurate':
+        for i in range(5):
+            params_file = get_params_fname(i)
+            maybe_download_parameters(i)
+
+            list_of_param_files.append(params_file)
+    else:
+        raise ValueError("Unknown value for mode: %s. Expected: fast or accurate" % mode)
+
+    assert all([os.path.isfile(i) for i in list_of_param_files]), "Could not find parameter files"
+
+    cf = imp.load_source('cf', config_file)
+    cf = cf.config()
+
+    net, _ = cf.get_network(cf.val_use_train_mode, None)
+    if device == "cpu":
+        net = net.cpu()
+    else:
+        net.cuda(device)
+
+    if not isinstance(mri_fnames, (list, tuple)):
+        mri_fnames = [mri_fnames]
+
+    if not isinstance(output_fnames, (list, tuple)):
+        output_fnames = [output_fnames]
+
+    assert len(mri_fnames) == len(output_fnames), "mri_fnames and output_fnames must have the same length"
+
+    params = []
+    for p in list_of_param_files:
+        params.append(torch.load(p, map_location=lambda storage, loc: storage))
+
+    for in_fname, out_fname in zip(mri_fnames, output_fnames):
+        mask_fname = out_fname[:-7] + "_mask.nii.gz"
+        if overwrite or (not (os.path.isfile(mask_fname) and keep_mask) or not os.path.isfile(out_fname)):
+            print("File:", in_fname)
+            print("preprocessing...")
+            try:
+                data, data_dict = load_and_preprocess(in_fname)
+            except RuntimeError:
+                print("\nERROR\nCould not read file", in_fname, "\n")
+                continue
+            except AssertionError as e:
+                print(e)
+                continue
+
+            softmax_preds = []
+
+            print("prediction (CNN id)...")
+            for i, p in enumerate(params):
+                print(i)
+                net.load_state_dict(p)
+                net.eval()
+                net.apply(SetNetworkToVal(False, False))
+                _, _, softmax_pred, _ = predict_case_3D_net(net, data, do_tta, cf.val_num_repeats,
+                                                            cf.val_batch_size, cf.net_input_must_be_divisible_by,
+                                                            cf.val_min_size, device, cf.da_mirror_axes)
+                softmax_preds.append(softmax_pred[None])
+
+            seg = np.argmax(np.vstack(softmax_preds).mean(0), 0)
+
+            if postprocess:
+                seg = postprocess_prediction(seg)
+
+            print("exporting segmentation...")
+            save_segmentation_nifti(seg, data_dict, mask_fname)
+
+            apply_bet(in_fname, mask_fname, out_fname)
+
+            if not keep_mask:
+                os.remove(mask_fname)
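Note: a minimal Python-level call of `run_hd_bet`, assuming the fold checkpoints exist in `folder_with_parameter_files` (or can be downloaded); file paths are placeholders:

    from HD_BET.run import run_hd_bet

    run_hd_bet("/data/t1.nii.gz", "/data/t1_bet.nii.gz",
               mode="fast", device="cpu", postprocess=True, do_tta=False)
    # writes /data/t1_bet.nii.gz plus /data/t1_bet_mask.nii.gz (keep_mask=True by default)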
src/IDH/HD_BET/utils.py
ADDED
@@ -0,0 +1,115 @@
+from urllib.request import urlopen
+import torch
+from torch import nn
+import numpy as np
+from skimage.morphology import label
+import os
+from HD_BET.paths import folder_with_parameter_files
+
+
+def get_params_fname(fold):
+    return os.path.join(folder_with_parameter_files, "%d.model" % fold)
+
+
+def maybe_download_parameters(fold=0, force_overwrite=False):
+    """
+    Downloads the parameters for some fold if it is not present yet.
+    :param fold:
+    :param force_overwrite: if True the old parameter file will be deleted (if present) prior to download
+    :return:
+    """
+
+    assert 0 <= fold <= 4, "fold must be between 0 and 4"
+
+    if not os.path.isdir(folder_with_parameter_files):
+        maybe_mkdir_p(folder_with_parameter_files)
+
+    out_filename = get_params_fname(fold)
+
+    if force_overwrite and os.path.isfile(out_filename):
+        os.remove(out_filename)
+
+    if not os.path.isfile(out_filename):
+        url = "https://zenodo.org/record/2540695/files/%d.model?download=1" % fold
+        print("Downloading", url, "...")
+        data = urlopen(url).read()
+        #out_filename = "/media/sdb/divyanshu/divyanshu/aidan_segmentation/nnUNet_pLGG/home/divyanshu/hd-bet_params/0.model"
+        with open(out_filename, 'wb') as f:
+            f.write(data)
+
+
+def init_weights(module):
+    if isinstance(module, nn.Conv3d):
+        module.weight = nn.init.kaiming_normal(module.weight, a=1e-2)
+        if module.bias is not None:
+            module.bias = nn.init.constant(module.bias, 0)
+
+
+def softmax_helper(x):
+    rpt = [1 for _ in range(len(x.size()))]
+    rpt[1] = x.size(1)
+    x_max = x.max(1, keepdim=True)[0].repeat(*rpt)
+    e_x = torch.exp(x - x_max)
+    return e_x / e_x.sum(1, keepdim=True).repeat(*rpt)
+
+
+class SetNetworkToVal(object):
+    def __init__(self, use_dropout_sampling=False, norm_use_average=True):
+        self.norm_use_average = norm_use_average
+        self.use_dropout_sampling = use_dropout_sampling
+
+    def __call__(self, module):
+        if isinstance(module, nn.Dropout3d) or isinstance(module, nn.Dropout2d) or isinstance(module, nn.Dropout):
+            module.train(self.use_dropout_sampling)
+        elif isinstance(module, nn.InstanceNorm3d) or isinstance(module, nn.InstanceNorm2d) or \
+                isinstance(module, nn.InstanceNorm1d) \
+                or isinstance(module, nn.BatchNorm2d) or isinstance(module, nn.BatchNorm3d) or \
+                isinstance(module, nn.BatchNorm1d):
+            module.train(not self.norm_use_average)
+
+
+def postprocess_prediction(seg):
+    # basically look for connected components and choose the largest one, delete everything else
+    print("running postprocessing... ")
+    mask = seg != 0
+    lbls = label(mask, connectivity=mask.ndim)
+    lbls_sizes = [np.sum(lbls == i) for i in np.unique(lbls)]
+    largest_region = np.argmax(lbls_sizes[1:]) + 1
+    seg[lbls != largest_region] = 0
+    return seg
+
+
+def subdirs(folder, join=True, prefix=None, suffix=None, sort=True):
+    if join:
+        l = os.path.join
+    else:
+        l = lambda x, y: y
+    res = [l(folder, i) for i in os.listdir(folder) if os.path.isdir(os.path.join(folder, i))
+           and (prefix is None or i.startswith(prefix))
+           and (suffix is None or i.endswith(suffix))]
+    if sort:
+        res.sort()
+    return res
+
+
+def subfiles(folder, join=True, prefix=None, suffix=None, sort=True):
+    if join:
+        l = os.path.join
+    else:
+        l = lambda x, y: y
+    res = [l(folder, i) for i in os.listdir(folder) if os.path.isfile(os.path.join(folder, i))
+           and (prefix is None or i.startswith(prefix))
+           and (suffix is None or i.endswith(suffix))]
+    if sort:
+        res.sort()
+    return res
+
+
+subfolders = subdirs  # I am tired of confusing those
+
+
+def maybe_mkdir_p(directory):
+    splits = directory.split("/")[1:]
+    for i in range(0, len(splits)):
+        if not os.path.isdir(os.path.join("", *splits[:i+1])):
+            os.mkdir(os.path.join("", *splits[:i+1]))
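Note: `softmax_helper` is a numerically stabilized softmax over the channel dimension; on current PyTorch it agrees with `torch.softmax(x, dim=1)` (a quick check, not part of this commit):

    import torch
    from HD_BET.utils import softmax_helper

    x = torch.randn(2, 4, 8, 8, 8)
    assert torch.allclose(softmax_helper(x), torch.softmax(x, dim=1), atol=1e-6)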
src/IDH/app_gradio.py
ADDED
|
@@ -0,0 +1,867 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import yaml
|
| 3 |
+
import torch
|
| 4 |
+
import nibabel as nib
|
| 5 |
+
import numpy as np
|
| 6 |
+
import gradio as gr
|
| 7 |
+
from typing import Tuple
|
| 8 |
+
import tempfile
|
| 9 |
+
import shutil
|
| 10 |
+
import matplotlib.pyplot as plt
|
| 11 |
+
import matplotlib
|
| 12 |
+
matplotlib.use('Agg') # Use non-interactive backend
|
| 13 |
+
import cv2 # For Gaussian Blur
|
| 14 |
+
import io # For saving plots to memory
|
| 15 |
+
import base64 # For encoding plots
|
| 16 |
+
import uuid # For unique IDs
|
| 17 |
+
import traceback # For detailed error printing
|
| 18 |
+
|
| 19 |
+
import SimpleITK as sitk
|
| 20 |
+
import itk
|
| 21 |
+
from scipy.signal import medfilt
|
| 22 |
+
import skimage.filters
|
| 23 |
+
|
| 24 |
+
from monai.transforms import Compose, LoadImaged, EnsureChannelFirstd, Resized, NormalizeIntensityd, ToTensord
|
| 25 |
+
|
| 26 |
+
from model import ViTBackboneNet, Classifier, SingleScanModelBP
|
| 27 |
+
|
| 28 |
+
# Optional HD-BET import (packaged locally like in MCI app)
|
| 29 |
+
try:
|
| 30 |
+
from HD_BET.run import run_hd_bet
|
| 31 |
+
from HD_BET.hd_bet import hd_bet
|
| 32 |
+
except Exception as e:
|
| 33 |
+
print(f"Warning: HD_BET not available: {e}")
|
| 34 |
+
run_hd_bet = None
|
| 35 |
+
hd_bet = None
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
APP_DIR = os.path.dirname(__file__)
|
| 39 |
+
TEMPLATE_DIR = os.path.join(APP_DIR, "golden_image", "mni_templates")
|
| 40 |
+
PARAMS_RIGID_PATH = os.path.join(TEMPLATE_DIR, "Parameters_Rigid.txt")
|
| 41 |
+
DEFAULT_TEMPLATE_PATH = os.path.join(TEMPLATE_DIR, "temp_head.nii.gz")
|
| 42 |
+
FLAIR_TEMPLATE_PATH = os.path.join(TEMPLATE_DIR, "nihpd_asym_04.5-18.5_t2w.nii")
|
| 43 |
+
T1C_TEMPLATE_PATH = os.path.join(TEMPLATE_DIR, "nihpd_asym_13.0-18.5_t1w.nii")
|
| 44 |
+
HD_BET_CONFIG_PATH = os.path.join(APP_DIR, "HD_BET", "config.py")
|
| 45 |
+
HD_BET_MODEL_DIR = os.path.join(APP_DIR, "hdbet_model")
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
def load_config() -> dict:
|
| 49 |
+
cfg_path = os.path.join(APP_DIR, "config.yml")
|
| 50 |
+
if os.path.exists(cfg_path):
|
| 51 |
+
with open(cfg_path, "r") as f:
|
| 52 |
+
return yaml.safe_load(f)
|
| 53 |
+
# Defaults
|
| 54 |
+
return {
|
| 55 |
+
"gpu": {"device": "cpu"},
|
| 56 |
+
"infer": {
|
| 57 |
+
"checkpoints": "./checkpoints/idh_model.ckpt",
|
| 58 |
+
"simclr_checkpoint": "./checkpoints/simclr_vitb.ckpt",
|
| 59 |
+
"threshold": 0.5,
|
| 60 |
+
"image_size": [96, 96, 96],
|
| 61 |
+
},
|
| 62 |
+
}
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
def build_model(cfg: dict):
|
| 66 |
+
device = torch.device(cfg.get("gpu", {}).get("device", "cpu"))
|
| 67 |
+
infer_cfg = cfg.get("infer", {})
|
| 68 |
+
simclr_path = os.path.join(APP_DIR, infer_cfg.get("simclr_checkpoint", ""))
|
| 69 |
+
ckpt_path = os.path.join(APP_DIR, infer_cfg.get("checkpoints", ""))
|
| 70 |
+
|
| 71 |
+
backbone = ViTBackboneNet(simclr_ckpt_path=simclr_path)
|
| 72 |
+
classifier = Classifier(d_model=768, num_classes=1)
|
| 73 |
+
model = SingleScanModelBP(backbone, classifier)
|
| 74 |
+
|
| 75 |
+
# Load finetuned checkpoint (Lightning or plain state_dict)
|
| 76 |
+
if os.path.exists(ckpt_path):
|
| 77 |
+
checkpoint = torch.load(ckpt_path, map_location="cpu", weights_only=False)
|
| 78 |
+
if "state_dict" in checkpoint:
|
| 79 |
+
state_dict = checkpoint["state_dict"]
|
| 80 |
+
new_state_dict = {}
|
| 81 |
+
for key, value in state_dict.items():
|
| 82 |
+
if key.startswith("model."):
|
| 83 |
+
new_state_dict[key[len("model."):]] = value
|
| 84 |
+
else:
|
| 85 |
+
new_state_dict[key] = value
|
| 86 |
+
else:
|
| 87 |
+
new_state_dict = checkpoint
|
| 88 |
+
model.load_state_dict(new_state_dict, strict=False)
|
| 89 |
+
else:
|
| 90 |
+
print(f"Warning: Finetuned checkpoint not found at {ckpt_path}. Model will use backbone-only weights.")
|
| 91 |
+
|
| 92 |
+
model.to(device)
|
| 93 |
+
model.eval()
|
| 94 |
+
return model, device
|
| 95 |
+
|
| 96 |
+
|
# ---------------- Preprocessing (Registration + Enhancement + Skull Stripping) ----------------

def bias_field_correction(img_array: np.ndarray) -> np.ndarray:
    image = sitk.GetImageFromArray(img_array.astype(np.float32))
    if image.GetPixelID() != sitk.sitkFloat32:
        image = sitk.Cast(image, sitk.sitkFloat32)
    maskImage = sitk.OtsuThreshold(image, 0, 1, 200)
    corrector = sitk.N4BiasFieldCorrectionImageFilter()
    numberFittingLevels = 4
    max_iters = [min(50 * (2 ** i), 200) for i in range(numberFittingLevels)]
    corrector.SetMaximumNumberOfIterations(max_iters)
    corrected_image = corrector.Execute(image, maskImage)
    return sitk.GetArrayFromImage(corrected_image)


def denoise(volume: np.ndarray, kernel_size: int = 3) -> np.ndarray:
    return medfilt(volume, kernel_size)


def rescale_intensity(volume: np.ndarray, percentiles=(0.5, 99.5), bins_num=256) -> np.ndarray:
    volume_float = volume.astype(np.float32)
    try:
        t = skimage.filters.threshold_otsu(volume_float, nbins=256)
        volume_masked = np.copy(volume_float)
        volume_masked[volume_masked < t] = 0
        obj_volume = volume_masked[np.where(volume_masked > 0)]
    except ValueError:
        obj_volume = volume_float.flatten()
    if obj_volume.size == 0:
        obj_volume = volume_float.flatten()
        min_value = np.min(obj_volume)
        max_value = np.max(obj_volume)
    else:
        min_value = np.percentile(obj_volume, percentiles[0])
        max_value = np.percentile(obj_volume, percentiles[1])
    denom = max_value - min_value
    if denom < 1e-6:
        denom = 1e-6
    if bins_num == 0:
        output_volume = (volume_float - min_value) / denom
        output_volume = np.clip(output_volume, 0.0, 1.0)
    else:
        output_volume = np.round((volume_float - min_value) / denom * (bins_num - 1))
        output_volume = np.clip(output_volume, 0, bins_num - 1)
    return output_volume.astype(np.float32)


def equalize_hist(volume: np.ndarray, bins_num=256) -> np.ndarray:
    mask = volume > 1e-6
    obj_volume = volume[mask]
    if obj_volume.size == 0:
        return volume
    hist, bins = np.histogram(obj_volume, bins_num, range=(obj_volume.min(), obj_volume.max()))
    cdf = hist.cumsum()
    cdf_normalized = (bins_num - 1) * cdf / float(cdf[-1])
    equalized_obj_volume = np.interp(obj_volume, bins[:-1], cdf_normalized)
    equalized_volume = np.copy(volume)
    equalized_volume[mask] = equalized_obj_volume
    return equalized_volume.astype(np.float32)

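A quick sanity check of `rescale_intensity` with its defaults (Otsu-masked 0.5-99.5 percentile window, 256 bins); the random volume is illustrative only:

```python
import numpy as np

rng = np.random.default_rng(0)
vol = rng.normal(100.0, 20.0, size=(32, 32, 32)).astype(np.float32)
out = rescale_intensity(vol)  # defined above
print(out.min(), out.max())   # expected: 0.0 and 255.0 after rounding/clipping
```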
def run_enhance_on_file(input_nifti_path: str, output_nifti_path: str):
    """
    Simplified enhancement - just copy the file, since N4 bias correction is now
    applied during registration. This keeps the existing preprocessing pipeline
    interface intact.
    """
    print(f"Enhancement step (N4 already applied during registration): {input_nifti_path}")
    # N4 bias correction is handled in registration, so this step is a passthrough copy.
    shutil.copy2(input_nifti_path, output_nifti_path)
    print(f"Enhancement complete (passthrough): {output_nifti_path}")


def register_image_sitk(input_nifti_path: str, output_nifti_path: str, template_path: str, interp_type='linear'):
    """
    MRI registration with SimpleITK matching the provided script approach.

    Args:
        input_nifti_path: Path to input NIfTI file
        output_nifti_path: Path to save registered output
        template_path: Path to template image
        interp_type: Interpolation type ('linear', 'bspline', 'nearest_neighbor')
    """
    print(f"Registering {input_nifti_path} to template {template_path}")

    # Read template and moving images
    fixed_img = sitk.ReadImage(template_path, sitk.sitkFloat32)
    moving_img = sitk.ReadImage(input_nifti_path, sitk.sitkFloat32)

    # Apply N4 bias correction to moving image
    moving_img = sitk.N4BiasFieldCorrection(moving_img)

    # Resample fixed image to 1mm isotropic
    old_size = fixed_img.GetSize()
    old_spacing = fixed_img.GetSpacing()
    new_spacing = (1, 1, 1)
    new_size = [
        int(round((old_size[0] * old_spacing[0]) / float(new_spacing[0]))),
        int(round((old_size[1] * old_spacing[1]) / float(new_spacing[1]))),
        int(round((old_size[2] * old_spacing[2]) / float(new_spacing[2])))
    ]

    # Set interpolation type
    if interp_type == 'linear':
        interp_type = sitk.sitkLinear
    elif interp_type == 'bspline':
        interp_type = sitk.sitkBSpline
    elif interp_type == 'nearest_neighbor':
        interp_type = sitk.sitkNearestNeighbor
    else:
        interp_type = sitk.sitkLinear

    # Resample fixed image
    resample = sitk.ResampleImageFilter()
    resample.SetOutputSpacing(new_spacing)
    resample.SetSize(new_size)
    resample.SetOutputOrigin(fixed_img.GetOrigin())
    resample.SetOutputDirection(fixed_img.GetDirection())
    resample.SetInterpolator(interp_type)
    resample.SetDefaultPixelValue(fixed_img.GetPixelIDValue())
    resample.SetOutputPixelType(sitk.sitkFloat32)
    fixed_img = resample.Execute(fixed_img)

    # Initialize transform
    transform = sitk.CenteredTransformInitializer(
        fixed_img,
        moving_img,
        sitk.Euler3DTransform(),
        sitk.CenteredTransformInitializerFilter.GEOMETRY)

    # Set up registration method
    registration_method = sitk.ImageRegistrationMethod()
    registration_method.SetMetricAsMattesMutualInformation(numberOfHistogramBins=50)
    registration_method.SetMetricSamplingStrategy(registration_method.RANDOM)
    registration_method.SetMetricSamplingPercentage(0.01)
    registration_method.SetInterpolator(sitk.sitkLinear)
    registration_method.SetOptimizerAsGradientDescent(
        learningRate=1.0,
        numberOfIterations=100,
        convergenceMinimumValue=1e-6,
        convergenceWindowSize=10)
    registration_method.SetOptimizerScalesFromPhysicalShift()
    registration_method.SetShrinkFactorsPerLevel(shrinkFactors=[4, 2, 1])
    registration_method.SetSmoothingSigmasPerLevel(smoothingSigmas=[2, 1, 0])
    registration_method.SmoothingSigmasAreSpecifiedInPhysicalUnitsOn()
    registration_method.SetInitialTransform(transform)

    # Execute registration
    final_transform = registration_method.Execute(fixed_img, moving_img)

    # Apply transform and save registered image
    moving_img_resampled = sitk.Resample(
        moving_img,
        fixed_img,
        final_transform,
        sitk.sitkLinear,
        0.0,
        moving_img.GetPixelID())

    sitk.WriteImage(moving_img_resampled, output_nifti_path)
    print(f"Registration complete. Saved to: {output_nifti_path}")


def register_image(input_nifti_path: str, output_nifti_path: str):
    """Wrapper to maintain compatibility - now uses SimpleITK registration."""
    if not os.path.exists(DEFAULT_TEMPLATE_PATH):
        raise FileNotFoundError(f"Template file missing: {DEFAULT_TEMPLATE_PATH}")
    register_image_sitk(input_nifti_path, output_nifti_path, DEFAULT_TEMPLATE_PATH)

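The isotropic resampling above preserves physical extent: `new_size = round(old_size * old_spacing / new_spacing)` per axis. A quick worked check with an assumed 0.5 mm, 512-voxel axis:

```python
old_size, old_spacing, new_spacing = 512, 0.5, 1.0
new_size = int(round(old_size * old_spacing / new_spacing))
print(new_size)  # 256 voxels at 1 mm -> still 256 mm of coverage
```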
def run_skull_stripping(input_nifti_path: str, output_dir: str):
    """
    Brain extraction using HD-BET direct integration matching the script approach.

    Args:
        input_nifti_path: Path to input NIfTI file
        output_dir: Directory to save skull-stripped output

    Returns:
        tuple: (output_file_path, output_mask_path)
    """
    print(f"Running HD-BET skull stripping on {input_nifti_path}")

    if hd_bet is None:
        raise RuntimeError("HD-BET not available. Please include HD_BET and hdbet_model in src/IDH.")

    if not os.path.exists(HD_BET_MODEL_DIR):
        raise FileNotFoundError(f"HD-BET models not found at {HD_BET_MODEL_DIR}")

    os.makedirs(output_dir, exist_ok=True)

    # Get base filename and prepare HD-BET compatible naming
    base_name = os.path.basename(input_nifti_path).replace('.nii.gz', '').replace('.nii', '')

    # HD-BET expects files with _0000 suffix - create temporary file if needed
    temp_input_dir = os.path.join(output_dir, "temp_input")
    os.makedirs(temp_input_dir, exist_ok=True)

    # Copy input file with _0000 suffix for HD-BET
    temp_input_path = os.path.join(temp_input_dir, f"{base_name}_0000.nii.gz")
    shutil.copy2(input_nifti_path, temp_input_path)

    # Set device
    device = "0" if torch.cuda.is_available() else "cpu"

    try:
        # Also try setting the specific model file path
        model_file = os.path.join(HD_BET_MODEL_DIR, '0.model')

        if os.path.exists(model_file):
            print(f"Local model file exists at: {model_file}")
        else:
            print(f"Warning: Model file not found at: {model_file}")
            # List directory contents for debugging
            if os.path.exists(HD_BET_MODEL_DIR):
                print(f"Contents of {HD_BET_MODEL_DIR}: {os.listdir(HD_BET_MODEL_DIR)}")
            else:
                print(f"Directory {HD_BET_MODEL_DIR} does not exist")

        # Run HD-BET directly on the temporary directory
        print(f"Running hd_bet with input_dir: {temp_input_dir}, output_dir: {output_dir}")
        hd_bet(temp_input_dir, output_dir, device=device, mode='fast', tta=0)

        # HD-BET outputs files with original naming convention
        output_file_path = os.path.join(output_dir, f"{base_name}_0000.nii.gz")
        output_mask_path = os.path.join(output_dir, f"{base_name}_0000_mask.nii.gz")

        # Rename to expected format for compatibility
        final_output_path = os.path.join(output_dir, f"{base_name}_bet.nii.gz")
        final_mask_path = os.path.join(output_dir, f"{base_name}_bet_mask.nii.gz")

        if os.path.exists(output_file_path):
            shutil.move(output_file_path, final_output_path)
        if os.path.exists(output_mask_path):
            shutil.move(output_mask_path, final_mask_path)

        # Clean up temporary directory
        shutil.rmtree(temp_input_dir, ignore_errors=True)

        if not os.path.exists(final_output_path):
            raise RuntimeError(f"HD-BET did not produce output file: {final_output_path}")

        print(f"Skull stripping complete. Output saved to: {final_output_path}")
        return final_output_path, final_mask_path

    except Exception as e:
        # Clean up on error
        shutil.rmtree(temp_input_dir, ignore_errors=True)
        raise RuntimeError(f"HD-BET skull stripping failed: {str(e)}")

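HD-BET's nnU-Net-style loader expects inputs with a `_0000` modality suffix, which is why the function copies, runs, and then renames. The naming round trip in isolation:

```python
import os

base_name = os.path.basename("flair_enhanced.nii.gz").replace(".nii.gz", "").replace(".nii", "")
print(f"{base_name}_0000.nii.gz")      # temporary HD-BET input name
print(f"{base_name}_bet.nii.gz")       # final skull-stripped volume
print(f"{base_name}_bet_mask.nii.gz")  # final brain mask
```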
# ---------------- Saliency Generation ----------------

def extract_attention_map(vit_model, image, layer_idx=-1, img_size=(96, 96, 96), patch_size=16):
    """
    Extracts the attention map from a Vision Transformer (ViT) model.

    This function wraps the attention blocks of the ViT to capture the attention
    weights during a forward pass. It then processes these weights to generate
    a 3D saliency map corresponding to the model's focus on the input image.
    """
    attention_maps = {}
    original_attns = {}

    # A wrapper class to intercept and store attention weights from a ViT block.
    class AttentionWithWeights(torch.nn.Module):
        def __init__(self, original_attn_module):
            super().__init__()
            self.original_attn_module = original_attn_module
            self.attn_weights = None

        def forward(self, x):
            # The original implementation of the attention module may not return
            # the attention weights. This wrapper recalculates them to ensure they
            # are captured. This is based on the standard ViT attention mechanism.
            output = self.original_attn_module(x)
            if hasattr(self.original_attn_module, 'qkv'):
                qkv = self.original_attn_module.qkv(x)
                batch_size, seq_len, _ = x.shape
                # Assuming qkv has been fused and has shape (batch_size, seq_len, 3 * num_heads * head_dim)
                qkv = qkv.reshape(batch_size, seq_len, 3, self.original_attn_module.num_heads, -1)
                qkv = qkv.permute(2, 0, 3, 1, 4)
                q, k, v = qkv[0], qkv[1], qkv[2]
                attn = (q @ k.transpose(-2, -1)) * self.original_attn_module.scale
                self.attn_weights = attn.softmax(dim=-1)
            return output

    # Store original attention modules and replace with wrappers
    for i, block in enumerate(vit_model.blocks):
        if hasattr(block, 'attn'):
            original_attns[i] = block.attn
            block.attn = AttentionWithWeights(block.attn)

    try:
        # Perform a forward pass to execute the wrapped modules and capture weights
        with torch.no_grad():
            _ = vit_model(image)

        # Collect the captured attention weights from each block
        for i, block in enumerate(vit_model.blocks):
            if hasattr(block.attn, 'attn_weights') and block.attn.attn_weights is not None:
                attention_maps[f"layer_{i}"] = block.attn.attn_weights.detach()

    finally:
        # Restore original attention modules
        for i, original_attn in original_attns.items():
            vit_model.blocks[i].attn = original_attn

    if not attention_maps:
        raise RuntimeError("Could not extract any attention maps. Please check the ViT model structure.")

    # Select the attention map from the specified layer
    if layer_idx < 0:
        layer_idx = len(attention_maps) + layer_idx
    layer_name = f"layer_{layer_idx}"
    if layer_name not in attention_maps:
        raise ValueError(f"Layer {layer_idx} not found. Available layers: {list(attention_maps.keys())}")

    layer_attn = attention_maps[layer_name]
    # Average attention across all heads
    head_attn = layer_attn[0].mean(dim=0)
    # Get attention from the [CLS] token to all other image patches
    cls_attn = head_attn[0, 1:]

    # Reshape the 1D attention vector into a 3D volume
    patches_per_dim = img_size[0] // patch_size
    total_patches = patches_per_dim ** 3

    # Pad or truncate if the number of patches doesn't align
    if cls_attn.shape[0] != total_patches:
        if cls_attn.shape[0] > total_patches:
            cls_attn = cls_attn[:total_patches]
        else:
            padded = torch.zeros(total_patches, device=cls_attn.device)
            padded[:cls_attn.shape[0]] = cls_attn
            cls_attn = padded

    cls_attn_3d = cls_attn.reshape(patches_per_dim, patches_per_dim, patches_per_dim)
    cls_attn_3d = cls_attn_3d.unsqueeze(0).unsqueeze(0)  # Add batch and channel dims

    # Upsample the attention map to the full image resolution
    upsampled_attn = torch.nn.functional.interpolate(
        cls_attn_3d,
        size=img_size,
        mode='trilinear',
        align_corners=False
    ).squeeze()

    # Normalize the map to [0, 1] for visualization (guard against a constant map)
    upsampled_attn = upsampled_attn.cpu().numpy()
    denom = max(upsampled_attn.max() - upsampled_attn.min(), 1e-6)
    upsampled_attn = (upsampled_attn - upsampled_attn.min()) / denom
    return upsampled_attn

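Worth verifying the reshape arithmetic above: a 96³ volume with 16³ patches gives 6 patches per axis, i.e. 216 tokens, which is why the function pads or truncates when the captured sequence length differs by one (for instance when a [CLS] slot shifts it):

```python
img_size, patch_size = 96, 16
patches_per_dim = img_size // patch_size   # 6
total_patches = patches_per_dim ** 3       # 216
print(patches_per_dim, total_patches)
```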
def generate_saliency_dual(model, input_tensor, layer_idx=-1):
    """
    Generate saliency maps for dual-input IDH model.

    Args:
        model: The complete IDH model
        input_tensor: Dual input tensor (batch_size, 2, C, D, H, W)
        layer_idx: ViT layer to visualize

    Returns:
        tuple: (flair_input_3d, t1c_input_3d, flair_saliency_3d)
    """
    print("Generating saliency maps for dual input...")

    try:
        # Extract individual images from dual input
        # input_tensor shape: [batch_size, 2, C, D, H, W]
        flair_tensor = input_tensor[:, 0]  # [batch, C, D, H, W]
        t1c_tensor = input_tensor[:, 1]    # [batch, C, D, H, W]

        # Get the ViT backbone
        vit_model = model.backbone.backbone

        # Generate attention map only for FLAIR
        flair_attn = extract_attention_map(vit_model, flair_tensor, layer_idx)

        # Convert input tensors to numpy for visualization
        flair_input_3d = flair_tensor.squeeze().cpu().detach().numpy()
        t1c_input_3d = t1c_tensor.squeeze().cpu().detach().numpy()

        print("Saliency maps generated successfully.")
        return flair_input_3d, t1c_input_3d, flair_attn

    except Exception as e:
        print(f"Error during saliency generation: {e}")
        traceback.print_exc()
        return None, None, None

# ---------------- Visualization Functions ----------------

def create_slice_plots_dual(flair_data_3d, t1c_data_3d, flair_saliency_3d, slice_index):
    """Create slice plots for simplified dual input visualization: T1c, FLAIR, FLAIR attention."""
    print(f"Generating plots for slice index: {slice_index}")

    if any(data is None for data in [flair_data_3d, t1c_data_3d, flair_saliency_3d]):
        return None, None, None

    # Check bounds - using axis 2 for axial slices
    if not (0 <= slice_index < flair_data_3d.shape[2]):
        print(f"Error: Slice index {slice_index} out of bounds (0-{flair_data_3d.shape[2]-1}).")
        return None, None, None

    def save_plot_to_numpy(fig):
        with io.BytesIO() as buf:
            fig.savefig(buf, format='png', bbox_inches='tight', pad_inches=0, dpi=75)
            plt.close(fig)
            buf.seek(0)
            img_arr = plt.imread(buf, format='png')
        return (img_arr * 255).astype(np.uint8)

    try:
        # Extract axial slices - using axis 2 (last dimension)
        flair_slice = flair_data_3d[:, :, slice_index]
        t1c_slice = t1c_data_3d[:, :, slice_index]
        flair_saliency_slice = flair_saliency_3d[:, :, slice_index]

        # Normalize input slices
        def normalize_slice(slice_data, volume_data):
            p1, p99 = np.percentile(volume_data, (1, 99))
            denom = max(p99 - p1, 1e-6)
            return np.clip((slice_data - p1) / denom, 0, 1)

        flair_slice_norm = normalize_slice(flair_slice, flair_data_3d)
        t1c_slice_norm = normalize_slice(t1c_slice, t1c_data_3d)

        # Process saliency slice
        def process_saliency_slice(saliency_slice, saliency_volume):
            saliency_slice = np.copy(saliency_slice)
            saliency_slice[saliency_slice < 0] = 0
            saliency_slice_blurred = cv2.GaussianBlur(saliency_slice, (15, 15), 0)
            s_max = max(np.max(saliency_volume[saliency_volume >= 0]), 1e-6)
            saliency_slice_norm = saliency_slice_blurred / s_max
            return np.where(saliency_slice_norm > 0.0, saliency_slice_norm, 0)

        flair_sal_processed = process_saliency_slice(flair_saliency_slice, flair_saliency_3d)

        # Create plots
        plots = []

        # T1c Input
        fig1, ax1 = plt.subplots(figsize=(6, 6))
        ax1.imshow(t1c_slice_norm, cmap='gray', interpolation='none', origin='lower')
        ax1.axis('off')
        ax1.set_title('T1c Input', fontsize=14, color='white', pad=10)
        plots.append(save_plot_to_numpy(fig1))

        # FLAIR Input
        fig2, ax2 = plt.subplots(figsize=(6, 6))
        ax2.imshow(flair_slice_norm, cmap='gray', interpolation='none', origin='lower')
        ax2.axis('off')
        ax2.set_title('FLAIR Input', fontsize=14, color='white', pad=10)
        plots.append(save_plot_to_numpy(fig2))

        # FLAIR Attention
        fig3, ax3 = plt.subplots(figsize=(6, 6))
        ax3.imshow(flair_sal_processed, cmap='magma', interpolation='none', origin='lower', vmin=0)
        ax3.axis('off')
        ax3.set_title('FLAIR Attention', fontsize=14, color='white', pad=10)
        plots.append(save_plot_to_numpy(fig3))

        print(f"Generated 3 plots successfully for axial slice {slice_index}.")
        return tuple(plots)

    except Exception as e:
        print(f"Error generating plots for slice {slice_index}: {e}")
        traceback.print_exc()
        return tuple([None] * 3)

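The `save_plot_to_numpy` helper is a figure-to-array round trip through an in-memory PNG. The same trick in isolation (the Agg backend is an assumption for running this outside the app, where no display is available):

```python
import io
import matplotlib
matplotlib.use("Agg")  # headless-safe backend
import matplotlib.pyplot as plt
import numpy as np

fig, ax = plt.subplots()
ax.imshow(np.zeros((8, 8)), cmap="gray")
with io.BytesIO() as buf:
    fig.savefig(buf, format="png", bbox_inches="tight", pad_inches=0, dpi=75)
    plt.close(fig)
    buf.seek(0)
    arr = (plt.imread(buf, format="png") * 255).astype(np.uint8)
print(arr.shape, arr.dtype)  # (H, W, 4) uint8 RGBA
```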
# ---------------- Inference ----------------

def get_dual_validation_transform(image_size: Tuple[int, int, int]):
    return Compose([
        LoadImaged(keys=["image1", "image2"]),
        EnsureChannelFirstd(keys=["image1", "image2"]),
        Resized(keys=["image1", "image2"], spatial_size=tuple(image_size), mode="trilinear"),
        NormalizeIntensityd(keys=["image1", "image2"], nonzero=True, channel_wise=True),
        ToTensord(keys=["image1", "image2"]),
    ])


def preprocess_dual_nifti(flair_path: str, t1c_path: str, image_size: Tuple[int, int, int], device: torch.device) -> torch.Tensor:
    transform = get_dual_validation_transform(image_size)
    sample = {"image1": flair_path, "image2": t1c_path}
    sample = transform(sample)
    img1 = sample["image1"]  # (C, D, H, W)
    img2 = sample["image2"]  # (C, D, H, W)
    images = torch.stack([img1, img2], dim=0).unsqueeze(0).to(device)
    return images

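After stacking, the dual-modality tensor has shape (batch, scan, channel, D, H, W) = (1, 2, 1, 96, 96, 96), which is the contract `SingleScanModelBP.forward` expects. A shape-only sketch with dummy tensors:

```python
import torch

img1 = torch.zeros(1, 96, 96, 96)  # (C, D, H, W) after the MONAI transforms
img2 = torch.zeros(1, 96, 96, 96)
images = torch.stack([img1, img2], dim=0).unsqueeze(0)
print(images.shape)  # torch.Size([1, 2, 1, 96, 96, 96])
```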
def predict_idh(flair_file, t1c_file, threshold: float, do_preprocess: bool, generate_saliency: bool, cfg: dict, model, device):
    try:
        if flair_file is None or t1c_file is None:
            return {"error": "Please upload both FLAIR and T1c NIfTI files (.nii or .nii.gz)."}, None, None, None, gr.Slider(visible=False), {"input_paths": None, "saliency_paths": None, "num_slices": 0}

        flair_path = flair_file.name if hasattr(flair_file, 'name') else flair_file
        t1c_path = t1c_file.name if hasattr(t1c_file, 'name') else t1c_file

        if not (flair_path.endswith(".nii") or flair_path.endswith(".nii.gz")):
            return {"error": "FLAIR must be a NIfTI file (.nii or .nii.gz)."}, None, None, None, gr.Slider(visible=False), {"input_paths": None, "saliency_paths": None, "num_slices": 0}
        if not (t1c_path.endswith(".nii") or t1c_path.endswith(".nii.gz")):
            return {"error": "T1c must be a NIfTI file (.nii or .nii.gz)."}, None, None, None, gr.Slider(visible=False), {"input_paths": None, "saliency_paths": None, "num_slices": 0}

        work_dir = tempfile.mkdtemp()
        flair_final_path, t1c_final_path = flair_path, t1c_path

        try:
            # Optional preprocessing pipeline
            if do_preprocess:
                # Registration (use modality-specific templates)
                flair_reg = os.path.join(work_dir, "flair_registered.nii.gz")
                t1c_reg = os.path.join(work_dir, "t1c_registered.nii.gz")
                register_image_sitk(flair_path, flair_reg, FLAIR_TEMPLATE_PATH)
                register_image_sitk(t1c_path, t1c_reg, T1C_TEMPLATE_PATH)
                # Enhancement
                flair_enh = os.path.join(work_dir, "flair_enhanced.nii.gz")
                t1c_enh = os.path.join(work_dir, "t1c_enhanced.nii.gz")
                run_enhance_on_file(flair_reg, flair_enh)
                run_enhance_on_file(t1c_reg, t1c_enh)
                # Skull stripping
                skullstrip_dir = os.path.join(work_dir, "skullstripped")
                flair_bet, _ = run_skull_stripping(flair_enh, skullstrip_dir)
                t1c_bet, _ = run_skull_stripping(t1c_enh, skullstrip_dir)
                flair_final_path, t1c_final_path = flair_bet, t1c_bet

            # Prediction
            image_size = cfg.get("infer", {}).get("image_size", [96, 96, 96])
            input_tensor = preprocess_dual_nifti(flair_final_path, t1c_final_path, image_size, device)

            with torch.no_grad():
                logits = model(input_tensor)
                prob = torch.sigmoid(logits).cpu().numpy().flatten()[0].item()
                predicted_class = int(prob >= threshold)

            prediction_result = {
                "IDH_mutant_probability": float(prob),
                "threshold": float(threshold),
                "predicted_class": int(predicted_class),
                "preprocessing": bool(do_preprocess),
                "class_label": "IDH-mutant" if predicted_class == 1 else "IDH-wildtype"
            }

            # Initialize saliency outputs
            t1c_input_img = flair_input_img = flair_attn_img = None
            slider_update = gr.Slider(visible=False)
            saliency_state = {"input_paths": None, "saliency_paths": None, "num_slices": 0}

            # Generate saliency maps if requested
            if generate_saliency:
                print("--- Generating Saliency Maps ---")
                try:
                    flair_input_3d, t1c_input_3d, flair_saliency_3d = generate_saliency_dual(model, input_tensor, layer_idx=-1)

                    if all(data is not None for data in [flair_input_3d, t1c_input_3d, flair_saliency_3d]):
                        num_slices = flair_input_3d.shape[2]  # Use axis 2 for axial slices
                        center_slice_index = num_slices // 2

                        # Save numpy arrays for slider callback
                        unique_id = str(uuid.uuid4())
                        temp_paths = []
                        for name, data in [("flair_input", flair_input_3d), ("t1c_input", t1c_input_3d),
                                           ("flair_saliency", flair_saliency_3d)]:
                            path = os.path.join(work_dir, f"{unique_id}_{name}.npy")
                            np.save(path, data)
                            temp_paths.append(path)

                        # Generate initial plots for center slice
                        plots = create_slice_plots_dual(flair_input_3d, t1c_input_3d, flair_saliency_3d, center_slice_index)
                        if plots and all(p is not None for p in plots):
                            t1c_input_img, flair_input_img, flair_attn_img = plots

                            # Update state and slider
                            saliency_state = {
                                "input_paths": temp_paths[:2],    # [flair_input, t1c_input]
                                "saliency_paths": temp_paths[2:],  # [flair_saliency]
                                "num_slices": num_slices
                            }
                            slider_update = gr.Slider(value=center_slice_index, minimum=0, maximum=num_slices-1, step=1, label="Select Slice", visible=True)
                            print("--- Saliency Generation Complete ---")
                    else:
                        print("Warning: Saliency generation failed - some outputs were None")

                except Exception as e:
                    print(f"Error during saliency generation: {e}")
                    traceback.print_exc()

            return (prediction_result, t1c_input_img, flair_input_img, flair_attn_img, slider_update, saliency_state)

        except Exception as e:
            shutil.rmtree(work_dir, ignore_errors=True)
            return {"error": f"Processing failed: {str(e)}"}, None, None, None, gr.Slider(visible=False), {"input_paths": None, "saliency_paths": None, "num_slices": 0}

    except Exception as e:
        return {"error": str(e)}, None, None, None, gr.Slider(visible=False), {"input_paths": None, "saliency_paths": None, "num_slices": 0}

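Classification reduces to `sigmoid(logit) >= threshold`. With the default threshold of 0.5, for example:

```python
import torch

logit = torch.tensor([0.8])
prob = torch.sigmoid(logit).item()  # ~0.69
predicted_class = int(prob >= 0.5)  # 1 -> "IDH-mutant"
print(round(prob, 3), predicted_class)
```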
def update_slice_viewer_dual(slice_index, current_state):
    """Update slice viewer for dual input saliency visualization."""
    input_paths = current_state.get("input_paths", [])
    saliency_paths = current_state.get("saliency_paths", [])

    if not input_paths or not saliency_paths or len(input_paths) != 2 or len(saliency_paths) != 1:
        print(f"Warning: Invalid state for slice viewer update: {current_state}")
        return None, None, None

    try:
        # Load numpy arrays
        flair_input_3d = np.load(input_paths[0])
        t1c_input_3d = np.load(input_paths[1])
        flair_saliency_3d = np.load(saliency_paths[0])

        # Validate slice index
        slice_index = int(slice_index)
        if not (0 <= slice_index < flair_input_3d.shape[2]):  # Use axis 2 for axial slices
            print(f"Warning: Invalid slice index {slice_index}")
            return None, None, None

        # Generate new plots
        plots = create_slice_plots_dual(flair_input_3d, t1c_input_3d, flair_saliency_3d, slice_index)
        return plots if plots else tuple([None] * 3)

    except Exception as e:
        print(f"Error updating slice viewer for index {slice_index}: {e}")
        traceback.print_exc()
        return tuple([None] * 3)

def build_interface():
    cfg = load_config()
    model, device = build_model(cfg)
    default_threshold = float(cfg.get("infer", {}).get("threshold", 0.5))

    with gr.Blocks(title="BrainIAC: IDH Classification", css="""
        #header-row {
            min-height: 150px;
            align-items: center;
        }
        .logo-img img {
            height: 150px;
            object-fit: contain;
        }
    """) as demo:
        # --- Header with Logos ---
        with gr.Row(elem_id="header-row"):
            with gr.Column(scale=1):
                gr.Image(os.path.join(APP_DIR, "static/images/kannlab.png"),
                         show_label=False, interactive=False,
                         show_download_button=False,
                         container=False,
                         elem_classes=["logo-img"])
            with gr.Column(scale=3):
                gr.Markdown(
                    "<h1 style='text-align: center; margin-bottom: 2.5rem'>"
                    "BrainIAC: IDH Classification"
                    "</h1>"
                )
            with gr.Column(scale=1):
                gr.Image(os.path.join(APP_DIR, "static/images/brainiac.jpeg"),
                         show_label=False, interactive=False,
                         show_download_button=False,
                         container=False,
                         elem_classes=["logo-img"])

        # --- Add model description section ---
        with gr.Accordion("ℹ️ Model Details and Usage Guide", open=False):
            gr.Markdown("""
            ### 🧠 BrainIAC: IDH Classification

            **Model Description**
            A Vision Transformer (ViT) model with BrainIAC as the pretrained backbone, designed to predict IDH mutation status from dual MRI sequences (FLAIR + T1c).

            **Training Dataset**
            - **Subjects**: Trained on FLAIR and T1c MRI scans from glioma patients in the UCSF-PDGM dataset
            - **Imaging Modalities**: FLAIR and T1c (contrast-enhanced T1-weighted)
            - **Preprocessing**: N4 bias correction, MNI registration, and skull stripping (HD-BET)

            **Input**
            - Format: NIfTI (.nii or .nii.gz)
            - Required sequences: FLAIR and T1c (both required)
            - Image size: Automatically resized to 96×96×96 voxels

            **Output**
            - Binary classification: IDH-mutant or IDH-wildtype
            - Probability score for IDH mutation
            - Attention map visualization

            **Intended Use**
            - Research use only!

            **NOTE**
            - Requires both FLAIR and T1c sequences
            - Not validated on other MRI sequences
            - Not validated for brain pathologies other than gliomas
            - Upload PHI data at your own risk!
            - The model is hosted on a cloud-based CPU instance
            - The data is not stored, shared, or collected for any purpose!

            **Preprocessing Pipeline**
            When enabled, preprocessing performs:
            1. **Registration**: SimpleITK-based registration to template space with a mutual information metric and 1mm isotropic resampling
            2. **N4 Bias Correction**: Applied during the registration step
            3. **Skull Stripping**: Removes non-brain tissue using HD-BET direct integration

            **Attention Maps**
            When enabled, generates ViT attention maps showing which brain regions the model focuses on for its prediction.

            """)

        # Use gr.State to store paths to numpy arrays for the slider callback
        saliency_state = gr.State({"input_paths": None, "saliency_paths": None, "num_slices": 0})

        # Main Content
        gr.Markdown("**Upload FLAIR and T1c NIfTI volumes** — Optional preprocessing performs registration to MNI, enhancement, and skull stripping.")

        with gr.Row():
            with gr.Column(scale=1):
                with gr.Group():
                    gr.Markdown("### Controls")
                    flair_input = gr.File(label="FLAIR (.nii or .nii.gz)")
                    t1c_input = gr.File(label="T1c (.nii or .nii.gz)")
                    preprocess_checkbox = gr.Checkbox(value=False, label="Preprocess NIfTI (registration + enhancement + skull stripping)")
                    generate_saliency_checkbox = gr.Checkbox(value=True, label="Generate Attention Maps")
                    threshold_input = gr.Slider(minimum=0.0, maximum=1.0, value=default_threshold, step=0.01, label="Decision Threshold")
                    predict_btn = gr.Button("Predict IDH Status", variant="primary")

            with gr.Column(scale=2):
                with gr.Group():
                    gr.Markdown("### Classification Result")
                    output_json = gr.JSON(label="Prediction")

                # Saliency visualization section
                with gr.Group():
                    gr.Markdown("### Attention Map Viewer (Axial Slice)")
                    slice_slider = gr.Slider(label="Select Slice", minimum=0, maximum=0, step=1, value=0, visible=False)

                    with gr.Row():
                        with gr.Column():
                            gr.Markdown("<p style='text-align: center;'>T1c Input</p>")
                            t1c_input_img = gr.Image(label="T1c Input", type="numpy", show_label=False)
                        with gr.Column():
                            gr.Markdown("<p style='text-align: center;'>FLAIR Input</p>")
                            flair_input_img = gr.Image(label="FLAIR Input", type="numpy", show_label=False)
                        with gr.Column():
                            gr.Markdown("<p style='text-align: center;'>FLAIR Attention</p>")
                            flair_attn_img = gr.Image(label="Attention Mask", type="numpy", show_label=False)

        # Wire components
        predict_btn.click(
            fn=lambda f, t, prep, gen_sal, thr: predict_idh(f, t, thr, prep, gen_sal, cfg, model, device),
            inputs=[flair_input, t1c_input, preprocess_checkbox, generate_saliency_checkbox, threshold_input],
            outputs=[output_json, t1c_input_img, flair_input_img, flair_attn_img, slice_slider, saliency_state],
        )

        slice_slider.change(
            fn=update_slice_viewer_dual,
            inputs=[slice_slider, saliency_state],
            outputs=[t1c_input_img, flair_input_img, flair_attn_img]
        )

    return demo


if __name__ == "__main__":
    iface = build_interface()
    iface.launch(server_name="0.0.0.0", server_port=7860)
src/IDH/checkpoints/idh_model.ckpt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:d4fbbddda1f0a1c38dd3c6241725b3ed4806705085e1b812b19464867518b5bf
size 353461323
src/IDH/config.yml
ADDED
@@ -0,0 +1,7 @@
gpu:
  device: cpu
infer:
  checkpoints: ./checkpoints/idh_model.ckpt
  simclr_checkpoint: ./checkpoints/simclr_vitb.ckpt
  threshold: 0.5
  image_size: [96, 96, 96]
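`load_config` is defined earlier in app_gradio.py (outside this excerpt); a plausible minimal reader of this file, a `yaml.safe_load` relative to the app directory, might look like the sketch below (an assumption, not the repo's exact code). Note also that `./checkpoints/simclr_vitb.ckpt` is not part of this commit, so `ViTBackboneNet` will fall back to its random-init warning unless that checkpoint is added separately.

```python
import os
import yaml

APP_DIR = os.path.dirname(os.path.abspath(__file__))

def load_config(path: str = "config.yml") -> dict:
    # Hypothetical reader; the real load_config lives earlier in app_gradio.py.
    with open(os.path.join(APP_DIR, path)) as f:
        return yaml.safe_load(f)
```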
src/IDH/golden_image/mni_templates/Parameters_Rigid.txt
ADDED
@@ -0,0 +1,141 @@
// Example parameter file for rotation registration
// C-style comments: //

// The internal pixel type, used for internal computations
// Leave to float in general.
// NB: this is not the type of the input images! The pixel
// type of the input images is automatically read from the
// images themselves.
// This setting can be changed to "short" to save some memory
// in case of very large 3D images.
(FixedInternalImagePixelType "float")
(MovingInternalImagePixelType "float")

// **************** Main Components **************************

// The following components should usually be left as they are:
(Registration "MultiResolutionRegistration")
(Interpolator "BSplineInterpolator")
(ResampleInterpolator "FinalBSplineInterpolator")
(Resampler "DefaultResampler")

// These may be changed to Fixed/MovingSmoothingImagePyramid.
// See the manual.
(FixedImagePyramid "FixedRecursiveImagePyramid")
(MovingImagePyramid "MovingRecursiveImagePyramid")

// The following components are most important:
// The optimizer AdaptiveStochasticGradientDescent (ASGD) works
// quite ok in general. The Transform and Metric are important
// and need to be chosen careful for each application. See manual.
(Optimizer "AdaptiveStochasticGradientDescent")
(Transform "EulerTransform")
(Metric "AdvancedMattesMutualInformation")

// ***************** Transformation **************************

// Scales the rotations compared to the translations, to make
// sure they are in the same range. In general, it's best to
// use automatic scales estimation:
(AutomaticScalesEstimation "true")

// Automatically guess an initial translation by aligning the
// geometric centers of the fixed and moving.
(AutomaticTransformInitialization "true")

// Whether transforms are combined by composition or by addition.
// In generally, Compose is the best option in most cases.
// It does not influence the results very much.
(HowToCombineTransforms "Compose")

// ******************* Similarity measure *********************

// Number of grey level bins in each resolution level,
// for the mutual information. 16 or 32 usually works fine.
// You could also employ a hierarchical strategy:
//(NumberOfHistogramBins 16 32 64)
(NumberOfHistogramBins 32)

// If you use a mask, this option is important.
// If the mask serves as region of interest, set it to false.
// If the mask indicates which pixels are valid, then set it to true.
// If you do not use a mask, the option doesn't matter.
(ErodeMask "false")

// ******************** Multiresolution **********************

// The number of resolutions. 1 Is only enough if the expected
// deformations are small. 3 or 4 mostly works fine. For large
// images and large deformations, 5 or 6 may even be useful.
(NumberOfResolutions 4)

// The downsampling/blurring factors for the image pyramids.
// By default, the images are downsampled by a factor of 2
// compared to the next resolution.
// So, in 2D, with 4 resolutions, the following schedule is used:
//(ImagePyramidSchedule 8 8 4 4 2 2 1 1 )
// And in 3D:
//(ImagePyramidSchedule 8 8 8 4 4 4 2 2 2 1 1 1 )
// You can specify any schedule, for example:
//(ImagePyramidSchedule 4 4 4 3 2 1 1 1 )
// Make sure that the number of elements equals the number
// of resolutions times the image dimension.

// ******************* Optimizer ****************************

// Maximum number of iterations in each resolution level:
// 200-500 works usually fine for rigid registration.
// For more robustness, you may increase this to 1000-2000.
(MaximumNumberOfIterations 250)

// The step size of the optimizer, in mm. By default the voxel size is used.
// which usually works well. In case of unusual high-resolution images
// (eg histology) it is necessary to increase this value a bit, to the size
// of the "smallest visible structure" in the image:
//(MaximumStepLength 1.0)

// **************** Image sampling **********************

// Number of spatial samples used to compute the mutual
// information (and its derivative) in each iteration.
// With an AdaptiveStochasticGradientDescent optimizer,
// in combination with the two options below, around 2000
// samples may already suffice.
(NumberOfSpatialSamples 2048)

// Refresh these spatial samples in every iteration, and select
// them randomly. See the manual for information on other sampling
// strategies.
(NewSamplesEveryIteration "true")
(ImageSampler "Random")

// ************* Interpolation and Resampling ****************

// Order of B-Spline interpolation used during registration/optimisation.
// It may improve accuracy if you set this to 3. Never use 0.
// An order of 1 gives linear interpolation. This is in most
// applications a good choice.
(BSplineInterpolationOrder 1)

// Order of B-Spline interpolation used for applying the final
// deformation.
// 3 gives good accuracy; recommended in most cases.
// 1 gives worse accuracy (linear interpolation)
// 0 gives worst accuracy, but is appropriate for binary images
// (masks, segmentations); equivalent to nearest neighbor interpolation.
(FinalBSplineInterpolationOrder 3)

//Default pixel value for pixels that come from outside the picture:
(DefaultPixelValue 0)

// Choose whether to generate the deformed moving image.
// You can save some time by setting this to false, if you are
// only interested in the final (nonrigidly) deformed moving image
// for example.
(WriteResultImage "true")

// The pixel type and format of the resulting deformed moving image
(ResultImagePixelType "short")
(ResultImageFormat "mhd")
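This Elastix-style rigid parameter file ships alongside the templates but appears unused by the registration code above, which drives SimpleITK directly. If you wanted to inspect it programmatically, a small hedged parser for the `(Key value ...)` syntax:

```python
import re

def parse_elastix_params(path: str) -> dict:
    # Minimal sketch: collects (Key value ...) pairs, ignoring // comments.
    params = {}
    with open(path) as f:
        for line in f:
            line = line.split("//")[0].strip()
            m = re.match(r"\((\S+)\s+(.*)\)$", line)
            if m:
                params[m.group(1)] = m.group(2).strip('"')
    return params

# Run from the mni_templates directory:
print(parse_elastix_params("Parameters_Rigid.txt").get("Transform"))  # EulerTransform
```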
src/IDH/golden_image/mni_templates/nihpd_asym_04.5-18.5_t2w.nii
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:0bd55fefd8a74deadca777c362bbdb7db9ae15ee135dd9fe045c8838af58d3ee
size 17350930
src/IDH/golden_image/mni_templates/nihpd_asym_13.0-18.5_t1w.nii
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:f10804664000688f0ddc124b39ce3ae27f2c339a40583bd6ff916727e97b77d0
size 17350930
src/IDH/hdbet_model/0.model
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:6f75233753c4750672815e2b7a86db754995ae44b8f1cd77bccfc37becd2d83c
size 65443735
src/IDH/hdbet_model/hdbet_model/0.model
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:6f75233753c4750672815e2b7a86db754995ae44b8f1cd77bccfc37becd2d83c
size 65443735
src/IDH/model.py
ADDED
@@ -0,0 +1,67 @@
import torch
import torch.nn as nn
from monai.networks.nets import ViT
import os


class ViTBackboneNet(nn.Module):
    def __init__(self, simclr_ckpt_path: str):
        super().__init__()
        self.backbone = ViT(
            in_channels=1,
            img_size=(96, 96, 96),
            patch_size=(16, 16, 16),
            hidden_size=768,
            mlp_dim=3072,
            num_layers=12,
            num_heads=12,
            save_attn=True,
        )
        # Load pretrained weights from SimCLR checkpoint if provided
        if simclr_ckpt_path and os.path.exists(simclr_ckpt_path):
            ckpt = torch.load(simclr_ckpt_path, map_location="cpu", weights_only=False)
            state_dict = ckpt.get("state_dict", ckpt)
            backbone_state_dict = {}
            for key, value in state_dict.items():
                if key.startswith("backbone."):
                    new_key = key[len("backbone."):]
                    backbone_state_dict[new_key] = value
            missing, unexpected = self.backbone.load_state_dict(backbone_state_dict, strict=False)
            print(f"Loaded SimCLR backbone weights. Missing: {len(missing)}, Unexpected: {len(unexpected)}")
        else:
            print("Warning: SimCLR checkpoint not found or not provided. Using randomly initialized backbone.")

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        features = self.backbone(x)
        cls_token = features[0][:, 0]
        return cls_token


class Classifier(nn.Module):
    def __init__(self, d_model: int = 768, num_classes: int = 1):
        super().__init__()
        self.fc = nn.Linear(d_model, num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.fc(x)


class SingleScanModelBP(nn.Module):
    def __init__(self, backbone: nn.Module, classifier: nn.Module):
        super().__init__()
        self.backbone = backbone
        self.classifier = classifier
        self.dropout = nn.Dropout(p=0.2)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x shape: (batch_size, 2, C, D, H, W)
        scan_features_list = []
        for scan_tensor_with_extra_dim in x.split(1, dim=1):
            squeezed_scan_tensor = scan_tensor_with_extra_dim.squeeze(1)
            feature = self.backbone(squeezed_scan_tensor)
            scan_features_list.append(feature)
        stacked_features = torch.stack(scan_features_list, dim=1)
        merged_features = torch.mean(stacked_features, dim=1)
        merged_features = self.dropout(merged_features)
        output = self.classifier(merged_features)
        return output
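A quick smoke test of the dual-scan wiring (random weights, no checkpoints; run from src/IDH so `model` is importable), showing the (batch, scan, channel, D, H, W) contract and the averaged-feature head:

```python
import torch
from model import ViTBackboneNet, Classifier, SingleScanModelBP

backbone = ViTBackboneNet(simclr_ckpt_path="")  # empty path -> random init with a warning
model = SingleScanModelBP(backbone, Classifier(d_model=768, num_classes=1)).eval()

x = torch.randn(1, 2, 1, 96, 96, 96)  # one subject: FLAIR + T1c volumes
with torch.no_grad():
    logits = model(x)
print(logits.shape)  # torch.Size([1, 1]) -> one IDH logit per subject
```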
src/IDH/static/images/brainage.jpeg
ADDED (binary, Git LFS)

src/IDH/static/images/brainiac.jpeg
ADDED (binary, Git LFS)

src/IDH/static/images/kannlab.png
ADDED (binary, Git LFS)