# BrainIAC-Brainage-V0: src/BrainIAC/preprocessing/mri_preprocess_3d_simple.py
# Divyanshu Tak -- "Initial commit of BrainIAC Docker application" (commit f5288df)
import sys
import os
import glob
import SimpleITK as sitk
from tqdm import tqdm
import random
from HD_BET.hd_bet import hd_bet
import argparse
import torch
def brain_extraction(input_dir, output_dir, device):
    """Skull-strip the images in a directory with HD-BET.

    Thin wrapper around ``hd_bet`` run in 'fast' mode with test-time
    augmentation disabled (``tta=0``). Progress messages and the final
    contents of the output directory are printed to stdout.

    Args:
        input_dir {path} -- directory of registered images to skull-strip
        output_dir {path} -- directory that hd_bet writes its results to
        device {str} -- device identifier forwarded unchanged to hd_bet
    Returns:
        None -- results are written to output_dir by hd_bet
    """
    for banner in ("Running brain extraction...",
                   f"Input directory: {input_dir}",
                   f"Output directory: {output_dir}"):
        print(banner)

    # HD-BET reads input_dir and writes straight into output_dir.
    hd_bet(input_dir, output_dir, device=device, mode='fast', tta=0)

    print('Brain extraction complete!')
    print("\nContents of output directory after brain extraction:")
    produced = os.listdir(output_dir)
    print(produced)
def _image_id(path):
    """Return the file name of *path* with a trailing ``.nii.gz`` removed.

    Only the final extension is stripped, so a dotted name such as
    ``sub.01_T1.nii.gz`` keeps its full identifier (``sub.01_T1``) instead of
    being truncated at the first dot, where it could collide with other files.
    """
    name = os.path.basename(path)
    suffix = '.nii.gz'
    return name[:-len(suffix)] if name.endswith(suffix) else name


def _resolve_interpolator(interp_type):
    """Map an interpolation name to a SimpleITK interpolator constant.

    Accepts 'linear', 'bspline' or 'nearest_neighbor'. A non-string value is
    assumed to already be a SimpleITK interpolator constant and is returned
    unchanged.

    Raises:
        ValueError -- if interp_type is a string that is not a known name
    """
    if not isinstance(interp_type, str):
        return interp_type
    interpolators = {
        'linear': sitk.sitkLinear,
        'bspline': sitk.sitkBSpline,
        'nearest_neighbor': sitk.sitkNearestNeighbor,
    }
    try:
        return interpolators[interp_type]
    except KeyError:
        raise ValueError(
            f"Unknown interp_type {interp_type!r}; expected one of "
            f"{sorted(interpolators)}") from None


def _resample_isotropic(image, interpolator, spacing=(1, 1, 1)):
    """Resample *image* onto a grid with voxel size *spacing*.

    Origin, direction and physical extent of *image* are kept; the output is
    float32 and any voxel falling outside the source image is set to 0.
    """
    new_size = [
        int(round((size * old_sp) / float(new_sp)))
        for size, old_sp, new_sp in zip(image.GetSize(), image.GetSpacing(), spacing)
    ]
    resample = sitk.ResampleImageFilter()
    resample.SetOutputSpacing(spacing)
    resample.SetSize(new_size)
    resample.SetOutputOrigin(image.GetOrigin())
    resample.SetOutputDirection(image.GetDirection())
    resample.SetInterpolator(interpolator)
    # Fill with background 0. GetPixelIDValue() is the pixel-type enum value,
    # not an intensity, so it must not be used as the fill value.
    resample.SetDefaultPixelValue(0.0)
    resample.SetOutputPixelType(sitk.sitkFloat32)
    return resample.Execute(image)


def _rigid_register(fixed_img, moving_img):
    """Rigidly register *moving_img* to *fixed_img* with a 3D Euler transform.

    Uses Mattes mutual information and multi-resolution gradient descent.
    Returns the final transform, suitable for
    ``sitk.Resample(moving_img, fixed_img, transform, ...)``.
    """
    initial_transform = sitk.CenteredTransformInitializer(
        fixed_img,
        moving_img,
        sitk.Euler3DTransform(),
        sitk.CenteredTransformInitializerFilter.GEOMETRY)

    method = sitk.ImageRegistrationMethod()
    method.SetMetricAsMattesMutualInformation(numberOfHistogramBins=50)
    method.SetMetricSamplingStrategy(method.RANDOM)
    # NOTE(review): no seed is passed, so the random 1% sample (and thus the
    # result) may differ slightly between runs -- confirm if that matters.
    method.SetMetricSamplingPercentage(0.01)
    method.SetInterpolator(sitk.sitkLinear)
    method.SetOptimizerAsGradientDescent(
        learningRate=1.0,
        numberOfIterations=100,
        convergenceMinimumValue=1e-6,
        convergenceWindowSize=10)
    method.SetOptimizerScalesFromPhysicalShift()
    method.SetShrinkFactorsPerLevel(shrinkFactors=[4, 2, 1])
    method.SetSmoothingSigmasPerLevel(smoothingSigmas=[2, 1, 0])
    method.SmoothingSigmasAreSpecifiedInPhysicalUnitsOn()
    method.SetInitialTransform(initial_transform)
    return method.Execute(fixed_img, moving_img)


def registration(input_dir, output_dir, temp_img, interp_type='linear'):
    """MRI registration with SimpleITK.

    Every ``*.nii.gz`` in input_dir (except files whose ID contains
    ``_mask``) is N4 bias-field corrected, rigidly registered to the template
    resampled to 1 mm isotropic spacing, and written to output_dir as
    ``<ID>_0000.nii.gz`` (the suffix HD-BET expects). Files that fail to
    load or register are reported and skipped.

    Args:
        input_dir {path} -- Directory containing input images
        output_dir {path} -- Directory to save registered images
        temp_img {str} -- Registration image template
        interp_type {str} -- Interpolator used to resample the template:
            'linear', 'bspline' or 'nearest_neighbor'
    Returns:
        bool -- True if at least one image was registered and saved
    Raises:
        ValueError -- if interp_type is not a recognised name
    """
    interpolator = _resolve_interpolator(interp_type)

    # The template is the same for every image, so resample it once up front.
    fixed_img = _resample_isotropic(
        sitk.ReadImage(temp_img, sitk.sitkFloat32), interpolator)

    count = 0
    print("Registering images...")
    pattern = os.path.join(glob.escape(input_dir), '*.nii.gz')
    for img_path in tqdm(sorted(glob.glob(pattern))):
        ID = _image_id(img_path)
        if "_mask" in ID:
            continue

        try:
            moving_img = sitk.ReadImage(img_path, sitk.sitkFloat32)
        except Exception as e:
            print(f"Error loading {ID}: {e}")
            print(f'Skipping problematic file: {ID}')
            continue

        print(f"Processing image {count + 1}: {ID}")
        try:
            moving_img = sitk.N4BiasFieldCorrection(moving_img)
            final_transform = _rigid_register(fixed_img, moving_img)
            moving_img_resampled = sitk.Resample(
                moving_img,
                fixed_img,
                final_transform,
                sitk.sitkLinear,
                0.0,
                moving_img.GetPixelID())
            # Save with _0000 suffix as required by HD-BET
            output_filename = os.path.join(output_dir, f"{ID}_0000.nii.gz")
            sitk.WriteImage(moving_img_resampled, output_filename)
        except Exception as e:
            # Best effort: one bad image must not abort the whole batch.
            print(f"Error processing {ID}: {e}")
            continue

        print(f"Saved registered image to: {output_filename}")
        count += 1

    print(f"Successfully registered {count} images.")
    print(f"Contents of output directory {output_dir}:")
    print(os.listdir(output_dir))
    return count > 0
def main(temp_img, input_dir, output_dir):
    """Register and skull-strip the brain MRI images in a directory.

    Images are first registered into a scratch directory created inside
    output_dir, then skull-stripped into output_dir. The scratch directory is
    always removed afterwards: on success, on early exit after a failed
    registration, and when an exception propagates.

    Args:
        temp_img {str} -- Path to template image
        input_dir {str} -- Path to input directory containing images
        output_dir {str} -- Path to output directory for results
    """
    import shutil
    import tempfile

    os.makedirs(output_dir, exist_ok=True)
    # GPU 0 when CUDA is available, otherwise CPU.
    device = "0" if torch.cuda.is_available() else "cpu"

    # A uniquely named scratch directory. A fixed name could be left over,
    # with stale images, from an interrupted earlier run, and HD-BET would
    # presumably process those stale files as well.
    temp_reg_dir = tempfile.mkdtemp(prefix='temp_registered_', dir=output_dir)
    try:
        print("Starting brain MRI preprocessing...")

        print("\nStep 1: Image Registration")
        success = registration(
            input_dir=input_dir,
            output_dir=temp_reg_dir,
            temp_img=temp_img
        )
        if not success:
            print("Registration failed! No images were processed successfully.")
            return

        print("\nChecking temporary directory contents:")
        print(os.listdir(temp_reg_dir))

        print("\nStep 2: Brain Extraction")
        brain_extraction(
            input_dir=temp_reg_dir,
            output_dir=output_dir,
            device=device
        )
    finally:
        # ignore_errors so a cleanup failure cannot mask an in-flight exception.
        shutil.rmtree(temp_reg_dir, ignore_errors=True)

    print("\nPreprocessing complete! Final results saved in:", output_dir)
    print("Final preprocessed files:")
    print(os.listdir(output_dir))
if __name__ == '__main__':
    # Command-line entry point: every option is a required string path.
    cli = argparse.ArgumentParser(description="Process brain MRI registration and skull stripping.")
    cli_options = (
        ("--temp_img", "Path to the atlas template image."),
        ("--input_dir", "Path to the input images directory."),
        ("--output_dir", "Path to save the processed images."),
    )
    for flag, help_text in cli_options:
        cli.add_argument(flag, type=str, required=True, help=help_text)
    parsed = cli.parse_args()
    # The argparse dests (temp_img, input_dir, output_dir) match main()'s keywords.
    main(**vars(parsed))