# BrainIAC-Brainage-V0: src/BrainIAC/load_brainiac.py
# Divyanshu Tak — initial commit of BrainIAC Docker application (f5288df)
import torch
from model import ResNet50_3D
import argparse
def load_brainiac(checkpoint_path, device='cuda'):
    """
    Load the ResNet50 model and BrainIAC checkpoint.

    Builds a ``ResNet50_3D`` model and reads the checkpoint at
    ``checkpoint_path`` with its tensors mapped to ``device``. It keeps only
    the entries whose key contains ``'backbone'``, loads them into the model
    (strict key matching) and moves the model to ``device``.

    Args:
        checkpoint_path (str): Path to the model checkpoint. Either a dict
            holding the weights under a ``"state_dict"`` entry, or a bare
            state dict.
        device (str): Device to load the model on ('cuda' or 'cpu').

    Returns:
        model: Loaded model with checkpoint weights, placed on ``device``.

    Raises:
        RuntimeError: raised by ``load_state_dict`` when the filtered keys do
            not exactly match the model's parameters.
    """
    # spinup the model
    model = ResNet50_3D()

    # Load brainiac weights; map_location lets a checkpoint saved on GPU be
    # read onto the requested device (e.g. 'cpu').
    checkpoint = torch.load(checkpoint_path, map_location=device)

    # Weights are normally nested under "state_dict"; if that entry is absent,
    # treat the loaded object itself as the state dict.
    if isinstance(checkpoint, dict) and "state_dict" in checkpoint:
        state_dict = checkpoint["state_dict"]
    else:
        state_dict = checkpoint

    # Keep only backbone parameters. NOTE(review): presumably the checkpoint
    # also carries non-backbone modules that ResNet50_3D does not define, and
    # ResNet50_3D names its parameters "backbone.*" — confirm against model.py.
    filtered_state_dict = {
        key: value for key, value in state_dict.items() if 'backbone' in key
    }
    model.load_state_dict(filtered_state_dict)

    # map_location only places the checkpoint tensors; load_state_dict copies
    # them into the model's existing parameters, which stay where the model
    # was built (CPU by default). Move the model explicitly to honor `device`.
    model = model.to(device)

    print("BrainIAC Loaded!!")
    return model
if __name__ == "__main__":
    # Parse args
    parser = argparse.ArgumentParser(description='Load backbone model from checkpoint')
    parser.add_argument('--checkpoint', type=str, required=True,
                        help='Path to the model checkpoint')
    # Default to CUDA only when it is actually available, so the script also
    # works on CPU-only machines without an explicit --device flag
    # (torch.load(map_location='cuda') fails when no GPU is present).
    parser.add_argument('--device', type=str,
                        default='cuda' if torch.cuda.is_available() else 'cpu',
                        help='Device to load the model on (cuda or cpu); '
                             'defaults to cuda when available, else cpu')
    args = parser.parse_args()

    # Load model
    model = load_brainiac(args.checkpoint, args.device)
    print(f"Model loaded successfully from {args.checkpoint}!")