Spaces:
Sleeping
Sleeping
| import torch | |
| from model import ResNet50_3D | |
| import argparse | |
| def load_brainiac(checkpoint_path, device='cuda'): | |
| """ | |
| Load the ResNet50 model and BrainIAC checkpoint. | |
| Args: | |
| checkpoint_path (str): Path to the model checkpoint | |
| device (str): Device to load the model on ('cuda' or 'cpu') | |
| Returns: | |
| model: Loaded model with checkpoint weights | |
| """ | |
| # spinup the model | |
| model = ResNet50_3D() | |
| # Load brainiac weights | |
| checkpoint = torch.load(checkpoint_path, map_location=device) | |
| state_dict = checkpoint["state_dict"] | |
| filtered_state_dict = {key: value for key, value in state_dict.items() if 'backbone' in key} | |
| model.load_state_dict(filtered_state_dict) | |
| print("BrainIAC Loaded!!") | |
| return model | |
| if __name__ == "__main__": | |
| # Parse args | |
| parser = argparse.ArgumentParser(description='Load backbone model from checkpoint') | |
| parser.add_argument('--checkpoint', type=str, required=True, | |
| help='Path to the model checkpoint') | |
| parser.add_argument('--device', type=str, default='cuda', | |
| help='Device to load the model on (cuda or cpu)') | |
| args = parser.parse_args() | |
| # Load model | |
| model = load_brainiac(args.checkpoint, args.device) | |
| print(f"Model loaded successfully from {args.checkpoint}!") |