|
|
import torch |
|
|
import numpy as np |
|
|
import pandas as pd |
|
|
import random |
|
|
import yaml |
|
|
import os |
|
|
import argparse |
|
|
from tqdm import tqdm |
|
|
from torch.utils.data import DataLoader |
|
|
from dataset2 import MedicalImageDatasetBalancedIntensity3D |
|
|
from load_brainiac import load_brainiac |
|
|
|
|
|
|
|
|
# Single seed shared by every random number generator this script touches,
# so repeated runs see the same random state.
seed = 42
torch.manual_seed(seed)
np.random.seed(seed)
random.seed(seed)

# Make only GPU 0 visible to CUDA. This overwrites any value already present
# in the environment and is done before the first CUDA availability query.
os.environ['CUDA_VISIBLE_DEVICES'] = "0"

# Pick the compute device once; when CUDA is usable, seed its generators too.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
    torch.cuda.manual_seed_all(seed)
else:
    device = torch.device("cpu")
|
|
|
|
|
|
|
|
def custom_collate(batch, max_len=1):
    """Collate dataset samples into one batch, zero-padding images along dim 0.

    Args:
        batch: Sequence of samples, each a mapping with an ``'image'`` tensor
            and a ``'label'`` tensor.
        max_len: Minimum size of the leading image dimension. Images whose
            first dimension is shorter are zero-padded at the end up to this
            length; images that are already at least this long are passed
            through unchanged (they are never truncated). Defaults to 1, the
            value that was previously hard-coded.

    Returns:
        A dict with two entries:
        ``'image'`` - the (possibly padded) images stacked along a new
        leading batch dimension;
        ``'label'`` - the labels stacked along a new leading dimension.

    Raises:
        RuntimeError: Propagated from ``torch.stack`` when ``batch`` is empty
            or when the images still differ in shape after padding.
    """
    padded_images = []
    for item in batch:
        img = item['image']
        pad_size = max_len - img.shape[0]
        if pad_size > 0:
            # new_zeros inherits dtype and device from the image itself, so
            # concatenation neither promotes the dtype nor fails on a
            # CPU/GPU mismatch (plain torch.zeros is always default-dtype CPU).
            padding = img.new_zeros((pad_size, *img.shape[1:]))
            img = torch.cat([img, padding], dim=0)
        padded_images.append(img)

    labels = [item['label'] for item in batch]
    return {"image": torch.stack(padded_images, dim=0), "label": torch.stack(labels)}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def infer(model, test_loader):
    """Run ``model`` over every batch and collect its outputs as a feature table.

    The model is switched to eval mode and all forward passes run under
    ``torch.no_grad()``.

    Args:
        model: Torch module called as ``model(inputs)``. Its output is used
            as a 2-D array (rows = samples, columns = features) when the
            per-batch DataFrame is built.
        test_loader: Iterable yielding dicts with an ``'image'`` tensor and a
            ``'label'`` tensor per batch.

    Returns:
        A ``pandas.DataFrame`` with columns ``Feature_0 .. Feature_{k-1}``
        followed by ``GroundTruthClassLabel``, one row per sample, indexed
        0..n-1; or ``None`` when ``test_loader`` yields no batches.
    """
    model.eval()
    batch_frames = []

    with torch.no_grad():
        for sample in tqdm(test_loader, desc="Inference", unit="batch"):
            inputs = sample['image'].to(device)
            features_numpy = model(inputs).cpu().numpy()

            feature_columns = [f'Feature_{i}' for i in range(features_numpy.shape[1])]
            batch_features = pd.DataFrame(features_numpy, columns=feature_columns)

            # Labels are only written into the table, never fed to the model,
            # so they are converted on the host instead of being copied to
            # the compute device and straight back.
            class_labels = sample['label'].float().cpu().numpy()
            batch_features['GroundTruthClassLabel'] = class_labels.flatten()

            batch_frames.append(batch_features)

    if not batch_frames:
        return None

    # One concat at the end instead of one per batch: concatenating inside
    # the loop re-copies every previously collected row on each iteration.
    return pd.concat(batch_frames, ignore_index=True)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def main():
    """Command-line entry point: extract model features and write them to CSV.

    Parses ``--checkpoint``, ``--input_csv``, ``--output_csv`` and
    ``--root_dir`` (all required), builds the dataset and a batch-size-1,
    unshuffled DataLoader, loads the model onto the module-level ``device``,
    runs ``infer`` and saves the resulting table without its index.

    Raises:
        RuntimeError: If inference produced no rows (the loader yielded no
            batches), so there is nothing to save.
        OSError: If the output directory cannot be created.
    """
    parser = argparse.ArgumentParser(description='Extract BrainIAC features from images')
    parser.add_argument('--checkpoint', type=str, required=True,
                        help='Path to the BrainIAC model checkpoint')
    parser.add_argument('--input_csv', type=str, required=True,
                        help='Path to the input CSV file containing image paths')
    parser.add_argument('--output_csv', type=str, required=True,
                        help='Path to save the output features CSV')
    parser.add_argument('--root_dir', type=str, required=True,
                        help='Root directory containing the image data')
    args = parser.parse_args()

    # Create the destination directory before any heavy work, so an unusable
    # output location fails immediately rather than after a full inference
    # run. A bare filename has an empty dirname and needs no directory.
    output_dir = os.path.dirname(args.output_csv)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    test_dataset = MedicalImageDatasetBalancedIntensity3D(
        csv_path=args.input_csv,
        root_dir=args.root_dir,
    )
    test_loader = DataLoader(
        test_dataset,
        batch_size=1,
        shuffle=False,  # keep output rows in the dataset's order
        collate_fn=custom_collate,
        num_workers=1,
    )

    # NOTE(review): load_brainiac is defined in another module; presumably it
    # builds the network and restores the checkpoint weights - confirm there.
    model = load_brainiac(args.checkpoint, device)
    model = model.to(device)

    features_df = infer(model, test_loader)
    if features_df is None:
        # infer returns None when the loader yielded no batches; fail with a
        # clear message instead of an AttributeError on None.to_csv.
        raise RuntimeError(
            f"No samples were loaded from {args.input_csv}; nothing to save"
        )

    features_df.to_csv(args.output_csv, index=False)
    print(f"Features saved to {args.output_csv}")


if __name__ == "__main__":
    main()