File size: 1,345 Bytes
5a169ab
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
import torch
from model import ResNet50_3D
import argparse

def load_brainiac(checkpoint_path, device='cuda'):
    """
    Load the ResNet50 model and BrainIAC checkpoint.

    Only the checkpoint entries whose key contains ``'backbone'`` are loaded
    into the model; every other entry is discarded. The returned model is
    moved to ``device``.

    Args:
        checkpoint_path (str): Path to the model checkpoint. May be a dict
            holding the weights under ``"state_dict"``, or a bare state dict.
        device (str): Device to load the model on ('cuda' or 'cpu')

    Returns:
        model: Loaded model with checkpoint weights, placed on ``device``

    Raises:
        RuntimeError: raised by ``load_state_dict`` (strict mode) when the
            filtered keys do not exactly match the model's parameters.
    """
    # spin up the model
    model = ResNet50_3D()

    # Load BrainIAC weights.
    # NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
    checkpoint = torch.load(checkpoint_path, map_location=device)

    # Weights are usually nested under "state_dict"; fall back to treating the
    # loaded object itself as the state dict so bare weight files also work.
    if isinstance(checkpoint, dict) and "state_dict" in checkpoint:
        state_dict = checkpoint["state_dict"]
    else:
        state_dict = checkpoint

    # Keep only backbone entries. Assumes ResNet50_3D's own parameter names
    # also contain 'backbone' (strict loading below enforces an exact match)
    # -- TODO confirm against model.ResNet50_3D.
    filtered_state_dict = {key: value for key, value in state_dict.items() if 'backbone' in key}
    model.load_state_dict(filtered_state_dict)

    # load_state_dict copies values into the model's existing parameters, so
    # map_location alone does not relocate the model; move it explicitly.
    model = model.to(device)
    print("BrainIAC Loaded!!")

    return model

if __name__ == "__main__":
    # Command-line entry point: build the CLI, then load the backbone from
    # the requested checkpoint onto the requested device.
    cli = argparse.ArgumentParser(description='Load backbone model from checkpoint')
    cli.add_argument(
        '--checkpoint',
        type=str,
        required=True,
        help='Path to the model checkpoint',
    )
    cli.add_argument(
        '--device',
        type=str,
        default='cuda',
        help='Device to load the model on (cuda or cpu)',
    )
    opts = cli.parse_args()

    ckpt_path = opts.checkpoint
    target_device = opts.device
    model = load_brainiac(ckpt_path, target_device)
    print(f"Model loaded successfully from {ckpt_path}!")