|
|
import torch |
|
|
from model import ResNet50_3D |
|
|
import argparse |
|
|
|
|
|
def load_brainiac(checkpoint_path, device='cuda'): |
|
|
""" |
|
|
Load the ResNet50 model and BrainIAC checkpoint. |
|
|
|
|
|
Args: |
|
|
checkpoint_path (str): Path to the model checkpoint |
|
|
device (str): Device to load the model on ('cuda' or 'cpu') |
|
|
|
|
|
Returns: |
|
|
model: Loaded model with checkpoint weights |
|
|
""" |
|
|
|
|
|
model = ResNet50_3D() |
|
|
|
|
|
|
|
|
checkpoint = torch.load(checkpoint_path, map_location=device) |
|
|
state_dict = checkpoint["state_dict"] |
|
|
filtered_state_dict = {key: value for key, value in state_dict.items() if 'backbone' in key} |
|
|
model.load_state_dict(filtered_state_dict) |
|
|
print("BrainIAC Loaded!!") |
|
|
|
|
|
return model |
|
|
|
|
|
if __name__ == "__main__": |
|
|
|
|
|
parser = argparse.ArgumentParser(description='Load backbone model from checkpoint') |
|
|
parser.add_argument('--checkpoint', type=str, required=True, |
|
|
help='Path to the model checkpoint') |
|
|
parser.add_argument('--device', type=str, default='cuda', |
|
|
help='Device to load the model on (cuda or cpu)') |
|
|
args = parser.parse_args() |
|
|
|
|
|
|
|
|
model = load_brainiac(args.checkpoint, args.device) |
|
|
print(f"Model loaded successfully from {args.checkpoint}!") |