File size: 3,733 Bytes
5a169ab
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
import torch
import numpy as np
import pandas as pd
import random
import yaml
import os
import argparse
from tqdm import tqdm
from torch.utils.data import DataLoader
from dataset2 import MedicalImageDatasetBalancedIntensity3D
from load_brainiac import load_brainiac

# Fixed seed shared by every RNG source this script relies on, for reproducible runs.
seed = 42
torch.manual_seed(seed)
np.random.seed(seed)
random.seed(seed)


# Expose only the first GPU to this process. This is assigned before the first
# CUDA availability query below.
os.environ['CUDA_VISIBLE_DEVICES'] = "0"
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
if device.type == "cuda":
    torch.cuda.manual_seed_all(seed)

# Define custom collate function for data loading
def custom_collate(batch):
    images = [item['image'] for item in batch]
    labels = [item['label'] for item in batch]

    max_len = 1
    padded_images = []

    for img in images:
        pad_size = max_len - img.shape[0]
        if pad_size > 0:
            padding = torch.zeros((pad_size,) + img.shape[1:])
            img_padded = torch.cat([img, padding], dim=0)
            padded_images.append(img_padded)
        else:
            padded_images.append(img)

    return {"image": torch.stack(padded_images, dim=0), "label": torch.stack(labels)}

#=========================
# Inference function
#=========================

def infer(model, test_loader):
    """Run ``model`` over every batch of ``test_loader`` and tabulate its outputs.

    The model is put in eval mode and run under ``torch.no_grad()``. Inputs
    are moved to the module-level ``device``. Model outputs and labels are
    brought back to the CPU before being placed in the table.

    Args:
        model: Feature extractor. Its output is treated as 2-D
            ``(batch, n_features)`` (assumed for BrainIAC; TODO confirm),
            because its second dimension defines the feature columns.
        test_loader: Iterable of batch dicts with ``'image'`` and ``'label'``
            tensors, e.g. a ``DataLoader`` using ``custom_collate``.

    Returns:
        A DataFrame with columns ``Feature_0`` .. ``Feature_{n-1}`` plus
        ``GroundTruthClassLabel``, with rows in loader order. The label column
        holds the labels cast to float and flattened, so one label per sample
        is assumed. Returns ``None`` if the loader yields no batches.
    """
    model.eval()
    batch_frames = []

    with torch.no_grad():
        for sample in tqdm(test_loader, desc="Inference", unit="batch"):
            inputs = sample['image'].to(device)
            # Labels never reach the model, so keep them on the CPU instead of
            # round-tripping them through the accelerator.
            class_labels = sample['label'].float().cpu()

            features_numpy = model(inputs).cpu().numpy()

            # One column per feature dimension of the model output.
            feature_columns = [f'Feature_{i}' for i in range(features_numpy.shape[1])]
            batch_features = pd.DataFrame(features_numpy, columns=feature_columns)
            batch_features['GroundTruthClassLabel'] = class_labels.numpy().flatten()
            batch_frames.append(batch_features)

    if not batch_frames:
        return None
    # Concatenate once at the end. Concatenating per batch re-copies the
    # accumulated table every time (quadratic in the number of batches).
    return pd.concat(batch_frames, ignore_index=True)

#=========================
# Main inference pipeline
#=========================

def main(argv=None):
    """Command-line entry point: extract BrainIAC features for a CSV of images.

    Builds the dataset and loader from the CLI arguments, loads the model
    checkpoint onto the module-level ``device``, runs ``infer`` and writes
    the resulting table to ``--output_csv``. The output directory is created
    if it does not exist.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) makes
            argparse read ``sys.argv``.

    Raises:
        SystemExit: On invalid arguments (raised by argparse), or when the
            loader yields no batches and there is nothing to save.
    """
    parser = argparse.ArgumentParser(description='Extract BrainIAC features from images')
    parser.add_argument('--checkpoint', type=str, required=True,
                        help='Path to the BrainIAC model checkpoint')
    parser.add_argument('--input_csv', type=str, required=True,
                        help='Path to the input CSV file containing image paths')
    parser.add_argument('--output_csv', type=str, required=True,
                        help='Path to save the output features CSV')
    parser.add_argument('--root_dir', type=str, required=True,
                        help='Root directory containing the image data')
    args = parser.parse_args(argv)

    # Dataset and loader: ordered, one sample per batch.
    test_dataset = MedicalImageDatasetBalancedIntensity3D(
        csv_path=args.input_csv,
        root_dir=args.root_dir
    )
    test_loader = DataLoader(
        test_dataset,
        batch_size=1,
        shuffle=False,
        collate_fn=custom_collate,
        num_workers=1
    )

    # Load BrainIAC and make sure it lives on the selected device.
    model = load_brainiac(args.checkpoint, device)
    model = model.to(device)

    features_df = infer(model, test_loader)
    if features_df is None:
        # infer returns None when the loader produced no batches. Fail with a
        # clear message instead of an AttributeError on to_csv.
        raise SystemExit(f"No samples were loaded from {args.input_csv}; nothing to save.")

    # to_csv does not create missing parent directories itself.
    output_dir = os.path.dirname(os.path.abspath(args.output_csv))
    os.makedirs(output_dir, exist_ok=True)
    features_df.to_csv(args.output_csv, index=False)
    print(f"Features saved to {args.output_csv}")

if __name__ == "__main__":
    main()