import os

import torch
import torch.nn as nn
from monai.networks.nets import ViT

class ViTBackboneNet(nn.Module):
    """3D ViT feature extractor, optionally initialized from a SimCLR checkpoint."""

    def __init__(self, simclr_ckpt_path: str):
        super().__init__()
        self.backbone = ViT(
            in_channels=1,
            img_size=(96, 96, 96),
            patch_size=(16, 16, 16),
            hidden_size=768,
            mlp_dim=3072,
            num_layers=12,
            num_heads=12,
            save_attn=True,
        )
        # Load pretrained weights from the SimCLR checkpoint if one is provided.
        if simclr_ckpt_path and os.path.exists(simclr_ckpt_path):
            ckpt = torch.load(simclr_ckpt_path, map_location="cpu", weights_only=False)
            state_dict = ckpt.get("state_dict", ckpt)
            # Keep only backbone weights and strip the "backbone." prefix so the
            # keys match this module's ViT.
            backbone_state_dict = {}
            for key, value in state_dict.items():
                if key.startswith("backbone."):
                    new_key = key[len("backbone."):]
                    backbone_state_dict[new_key] = value
            missing, unexpected = self.backbone.load_state_dict(backbone_state_dict, strict=False)
            print(f"Loaded SimCLR backbone weights. Missing: {len(missing)}, Unexpected: {len(unexpected)}")
        else:
            print("Warning: SimCLR checkpoint not found or not provided. Using a randomly initialized backbone.")

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # MONAI's ViT returns (token_sequence, hidden_states). With classification left
        # at its default (False) there is no dedicated CLS token, so the first token of
        # the sequence is used as the global representation.
        features = self.backbone(x)
        cls_token = features[0][:, 0]
        return cls_token

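# Illustrative shape check for the backbone alone (a minimal sketch; no checkpoint
# is assumed here, so the randomly initialized ViT weights are used):
#
#   net = ViTBackboneNet(simclr_ckpt_path="")
#   feats = net(torch.randn(1, 1, 96, 96, 96))   # one single-channel 96^3 volume
#   print(feats.shape)                           # torch.Size([1, 768])
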
class Classifier(nn.Module):
    """Linear classification head on top of the backbone features."""

    def __init__(self, d_model: int = 768, num_classes: int = 1):
        super().__init__()
        self.fc = nn.Linear(d_model, num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.fc(x)

class SingleScanModelBP(nn.Module):
    """Encodes each scan with a shared backbone, averages the features, and classifies."""

    def __init__(self, backbone: nn.Module, classifier: nn.Module):
        super().__init__()
        self.backbone = backbone
        self.classifier = classifier
        self.dropout = nn.Dropout(p=0.2)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x shape: (batch_size, 2, C, D, H, W) -- two scans per subject.
        scan_features_list = []
        for scan_tensor_with_extra_dim in x.split(1, dim=1):
            # (batch_size, 1, C, D, H, W) -> (batch_size, C, D, H, W)
            squeezed_scan_tensor = scan_tensor_with_extra_dim.squeeze(1)
            feature = self.backbone(squeezed_scan_tensor)
            scan_features_list.append(feature)
        # (batch_size, num_scans, d_model) -> mean over scans -> (batch_size, d_model)
        stacked_features = torch.stack(scan_features_list, dim=1)
        merged_features = torch.mean(stacked_features, dim=1)
        merged_features = self.dropout(merged_features)
        output = self.classifier(merged_features)
        return output
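

if __name__ == "__main__":
    # Minimal usage sketch. "checkpoints/simclr_vit.pt" is a hypothetical path and the
    # batch size / input shape are illustrative; adapt them to the real data pipeline.
    backbone = ViTBackboneNet(simclr_ckpt_path="checkpoints/simclr_vit.pt")
    model = SingleScanModelBP(backbone, Classifier(d_model=768, num_classes=1))
    # Two single-channel 96^3 scans per subject: (batch_size, 2, C, D, H, W).
    dummy = torch.randn(2, 2, 1, 96, 96, 96)
    with torch.no_grad():
        logits = model(dummy)
    print(logits.shape)  # torch.Size([2, 1])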