Divyanshu Tak
Add BrainIAC IDH Classification app with Vision Transformer model
65bee5d
import torch
import torch.nn as nn
from monai.networks.nets import ViT
import os
class ViTBackboneNet(nn.Module):
    """MONAI ViT feature extractor, optionally initialised from a SimCLR checkpoint.

    The wrapped ViT is configured for single-channel 96x96x96 inputs with
    16x16x16 patches, hidden size 768, MLP size 3072, 12 layers and 12 heads.
    ``forward`` returns the token at sequence index 0 of the ViT output.
    """

    def __init__(self, simclr_ckpt_path: str):
        """Build the ViT and, when possible, load pretrained backbone weights.

        Args:
            simclr_ckpt_path: Path to a checkpoint file. A falsy value (empty
                string or ``None``) or a non-existent path leaves the backbone
                randomly initialised and prints a warning instead of raising.
        """
        super().__init__()
        self.backbone = ViT(
            in_channels=1,
            img_size=(96, 96, 96),
            patch_size=(16, 16, 16),
            hidden_size=768,
            mlp_dim=3072,
            num_layers=12,
            num_heads=12,
            save_attn=True,
        )

        # Loading is best-effort: a missing checkpoint is reported, never fatal.
        if not (simclr_ckpt_path and os.path.exists(simclr_ckpt_path)):
            print("Warning: SimCLR checkpoint not found or not provided. Using randomly initialized backbone.")
            return

        # SECURITY NOTE: weights_only=False unpickles arbitrary Python objects.
        # Only point this at checkpoints from a trusted source.
        ckpt = torch.load(simclr_ckpt_path, map_location="cpu", weights_only=False)
        # Accept either a wrapper dict with a "state_dict" entry or a bare state dict.
        state_dict = ckpt.get("state_dict", ckpt)

        # Keep only the backbone weights and strip the prefix so the keys
        # line up with the parameter names of ``self.backbone``.
        prefix = "backbone."
        backbone_state_dict = {
            key[len(prefix):]: value
            for key, value in state_dict.items()
            if key.startswith(prefix)
        }

        # With strict=False an empty dict would "load" silently; say so
        # explicitly instead of reporting a successful load.
        if not backbone_state_dict:
            print(
                "Warning: no 'backbone.'-prefixed weights found in SimCLR checkpoint. "
                "Using randomly initialized backbone."
            )
            return

        missing, unexpected = self.backbone.load_state_dict(backbone_state_dict, strict=False)
        print(f"Loaded SimCLR backbone weights. Missing: {len(missing)}, Unexpected: {len(unexpected)}")

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the backbone and return the token at sequence index 0.

        Args:
            x: Input passed straight to the ViT (presumably
                ``(batch, 1, 96, 96, 96)`` given the constructor settings;
                verify against the caller).

        Returns:
            ``features[0][:, 0]``: index 0 along the second axis of the first
            element returned by the backbone.
        """
        features = self.backbone(x)
        # The first element of the backbone output is indexed as (batch, token, ...).
        # NOTE(review): the original names this "cls_token"; with MONAI's default
        # classification=False no class token may be prepended, in which case this
        # is the first patch token. Confirm against the pretraining code.
        first_token = features[0][:, 0]
        return first_token
class Classifier(nn.Module):
    """Single linear layer mapping ``d_model`` features to ``num_classes`` outputs."""

    def __init__(self, d_model: int = 768, num_classes: int = 1):
        """Create the linear head.

        Args:
            d_model: Size of the incoming feature vector.
            num_classes: Number of output units.
        """
        super().__init__()
        # The attribute name "fc" is part of the state_dict keys; keep it stable.
        self.fc = nn.Linear(in_features=d_model, out_features=num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the linear layer; no activation is applied here."""
        scores = self.fc(x)
        return scores
class SingleScanModelBP(nn.Module):
    """Encode each scan along dim 1, average the features, then classify.

    Pipeline: backbone per scan -> mean over scans -> dropout (p=0.2) -> classifier.
    """

    def __init__(self, backbone: nn.Module, classifier: nn.Module):
        """Store the feature extractor and the classification head.

        Args:
            backbone: Module applied independently to every scan.
            classifier: Module applied to the averaged, dropped-out features.
        """
        super().__init__()
        self.backbone = backbone
        self.classifier = classifier
        self.dropout = nn.Dropout(p=0.2)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return classifier outputs for a batch of stacked scans.

        Args:
            x: Tensor whose dim 1 indexes the scans; the original author's
                note gives the shape as ``(batch_size, 2, C, D, H, W)``.

        Returns:
            The classifier's output for the scan-averaged features.
        """
        # Cut dim 1 into size-1 slices, drop that axis, and encode each scan.
        per_scan = [
            self.backbone(chunk.squeeze(1))
            for chunk in x.split(1, dim=1)
        ]
        # Re-assemble along a fresh scan axis and average it away.
        pooled = torch.stack(per_scan, dim=1).mean(dim=1)
        return self.classifier(self.dropout(pooled))