import os

import torch
import torch.nn as nn
from monai.networks.nets import ViT


class ViTBackboneNet(nn.Module):
    """MONAI ViT feature extractor, optionally initialised from a SimCLR checkpoint.

    The backbone is a ViT configured for single-channel 96x96x96 inputs with
    16x16x16 patches, hidden size 768, 12 layers and 12 heads.

    Args:
        simclr_ckpt_path: Path to a checkpoint file. Only entries whose key
            starts with ``"backbone."`` are loaded (with that prefix removed).
            If the path is empty or does not exist, the backbone keeps its
            random initialisation and a warning is printed.
    """

    # Prefix that marks backbone parameters inside the checkpoint's state dict.
    _CKPT_PREFIX = "backbone."

    def __init__(self, simclr_ckpt_path: str):
        super().__init__()
        self.backbone = ViT(
            in_channels=1,
            img_size=(96, 96, 96),
            patch_size=(16, 16, 16),
            hidden_size=768,
            mlp_dim=3072,
            num_layers=12,
            num_heads=12,
            save_attn=True,
        )

        # Load pretrained weights from SimCLR checkpoint if provided
        if simclr_ckpt_path and os.path.exists(simclr_ckpt_path):
            # SECURITY: weights_only=False lets torch.load unpickle arbitrary
            # Python objects. Only use checkpoints from a trusted source;
            # switch to weights_only=True if the checkpoint format allows it.
            ckpt = torch.load(simclr_ckpt_path, map_location="cpu", weights_only=False)
            # Use the nested "state_dict" entry when present, otherwise treat
            # the loaded object itself as the state dict.
            state_dict = ckpt.get("state_dict", ckpt)

            prefix = self._CKPT_PREFIX
            backbone_state_dict = {
                key[len(prefix):]: value
                for key, value in state_dict.items()
                if key.startswith(prefix)
            }

            if backbone_state_dict:
                # strict=False: tolerate keys that exist on only one side and
                # report how many there were instead of raising.
                missing, unexpected = self.backbone.load_state_dict(backbone_state_dict, strict=False)
                print(f"Loaded SimCLR backbone weights. Missing: {len(missing)}, Unexpected: {len(unexpected)}")
            else:
                # Loading an empty dict with strict=False would silently do
                # nothing, so say so rather than claiming weights were loaded.
                print(
                    f"Warning: no '{prefix}'-prefixed keys found in SimCLR checkpoint. "
                    "Using randomly initialized backbone."
                )
        else:
            print("Warning: SimCLR checkpoint not found or not provided. Using randomly initialized backbone.")

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the token at sequence index 0 of the backbone's first output.

        Args:
            x: Input batch passed unchanged to the ViT backbone
                (assumed (batch, 1, 96, 96, 96) given the ViT config - TODO confirm).

        Returns:
            ``backbone(x)[0][:, 0]``: one vector per batch element.
        """
        features = self.backbone(x)
        # NOTE(review): the ViT above is built without `classification=True`;
        # in MONAI that appears to mean no learnable class token is prepended,
        # so index 0 would be the first patch token rather than a true CLS
        # token. Behaviour is kept as-is because pretrained weights and
        # downstream heads may depend on it - confirm against pretraining.
        cls_token = features[0][:, 0]
        return cls_token


class Classifier(nn.Module):
    """Single linear layer mapping a feature vector to class logits.

    Args:
        d_model: Size of the input feature vector.
        num_classes: Number of output units.
    """

    def __init__(self, d_model: int = 768, num_classes: int = 1):
        super().__init__()
        self.fc = nn.Linear(d_model, num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the linear layer to ``x``."""
        return self.fc(x)


class SingleScanModelBP(nn.Module):
    """Encode each scan with a shared backbone, average the features, classify.

    Args:
        backbone: Module mapping one scan batch to a per-sample feature tensor.
        classifier: Module mapping the averaged features to the output.
    """

    def __init__(self, backbone: nn.Module, classifier: nn.Module):
        super().__init__()
        self.backbone = backbone
        self.classifier = classifier
        self.dropout = nn.Dropout(p=0.2)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the backbone on every scan along dim 1 and classify their mean.

        Args:
            x: Stacked scans, shape (batch_size, 2, C, D, H, W) per the
                original author's note; the code itself iterates over however
                many entries dim 1 holds.

        Returns:
            Classifier output for the dropout-regularised mean feature.
        """
        # unbind(dim=1) yields one (batch_size, C, D, H, W) tensor per scan,
        # i.e. the same slices as split(1, dim=1) followed by squeeze(1).
        scan_features = [self.backbone(scan) for scan in x.unbind(dim=1)]
        stacked_features = torch.stack(scan_features, dim=1)
        merged_features = stacked_features.mean(dim=1)
        merged_features = self.dropout(merged_features)
        return self.classifier(merged_features)