# ============================================================
# DELCAP — Medical Image Captioning (Hugging Face Space)
# ============================================================

# ------------------------------
# Install dependencies (if needed)
# ------------------------------
#!pip install torch torchvision --quiet
#!pip install huggingface_hub --quiet
#!pip install nltk --quiet
#!pip install gradio --quiet

import torch
import torch.nn as nn
import torchvision.models as models
import torchvision.transforms as transforms
import json
import nltk
from PIL import Image
from collections import Counter
from huggingface_hub import hf_hub_download
import gradio as gr

# Ensure punkt tokenizer is available
nltk.download("punkt")

# ============================================================
# Configuration
# ============================================================
class Config:
    IMG_SIZE = 224
    EMBED_SIZE = 256
    HIDDEN_SIZE = 512
    NUM_LSTM_LAYERS = 1
    DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    MAX_CAPTION_LENGTH = 50

config = Config()

# ============================================================
# Tokenization
# ============================================================
def tokenize_caption(text):
    return nltk.word_tokenize(text.lower())

# ============================================================
# Vocabulary
# ============================================================
class Vocabulary:
    def __init__(self, freq_threshold=1):
        # Default special tokens. These names are assumed to match the tokens
        # stored in vocab.json; adjust them if the repository uses different ones.
        self.itos = {
            0: "<pad>",
            1: "<start>",
            2: "<end>",
            3: "<unk>"
        }
        self.stoi = {v: k for k, v in self.itos.items()}
        self.freq_threshold = freq_threshold
        self.vocab_size = len(self.itos)

    def __len__(self):
        return self.vocab_size

    @classmethod
    def from_json(cls, json_data):
        vocab_obj = cls()
        vocab_obj.stoi = json_data['stoi']
        vocab_obj.itos = {int(k): v for k, v in json_data['itos'].items()}
        vocab_obj.vocab_size = len(vocab_obj.stoi)
        return vocab_obj

    def idx_to_word(self, idx):
        return self.itos.get(idx, "<unk>")

# ============================================================
# Encoder
# ============================================================
class EncoderCNN(nn.Module):
    def __init__(self, embed_size):
        super().__init__()
        # Frozen DenseNet-121 backbone; only the embedding layer stays trainable.
        densenet = models.densenet121(weights=models.DenseNet121_Weights.DEFAULT)
        self.densenet_features = densenet.features
        for param in self.densenet_features.parameters():
            param.requires_grad_(False)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.embed = nn.Linear(1024, embed_size)

    def forward(self, images):
        features = self.densenet_features(images)
        features = self.avgpool(features)
        features = features.view(features.size(0), -1)
        features = self.embed(features)
        return features

# ============================================================
# Decoder
# ============================================================
class DecoderRNN(nn.Module):
    def __init__(self, embed_size, hidden_size, vocab_size, num_layers=1):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, embed_size)
        self.lstm = nn.LSTM(embed_size, hidden_size, num_layers, batch_first=True)
        self.linear = nn.Linear(hidden_size, vocab_size)
        self.dropout = nn.Dropout(0.5)
        self.num_layers = num_layers
        self.hidden_size = hidden_size
        self.feature_to_hidden_state = nn.Linear(embed_size, hidden_size)

    def sample(self, features, max_len=20, vocab=None):
        """Greedy decoding: the image features seed the LSTM state, then the
        highest-probability token is chosen at each step."""
        self.eval()
        with torch.no_grad():
            sampled_ids = []
            initial_hidden = self.feature_to_hidden_state(features)
            h = initial_hidden.unsqueeze(0).repeat(self.num_layers, 1, 1)
            c = initial_hidden.unsqueeze(0).repeat(self.num_layers, 1, 1)
            hidden = (h, c)
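            # Decoding loop: start from the <start> token and emit one word per
            # step until the <end> token (assumed token names) or max_len.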
            start_token = torch.tensor([vocab.stoi["<start>"]], device=features.device)
            inputs = self.embed(start_token).unsqueeze(1)

            for _ in range(max_len):
                output, hidden = self.lstm(inputs, hidden)
                logits = self.linear(self.dropout(output.squeeze(1)))
                _, predicted = logits.max(1)
                sampled_ids.append(predicted)
                if predicted.item() == vocab.stoi["<end>"]:
                    break
                inputs = self.embed(predicted).unsqueeze(1)

            return torch.stack(sampled_ids)

# ============================================================
# Load Vocabulary & Models
# ============================================================
vocab_path = hf_hub_download("hackergeek/delcap", "vocab.json")
with open(vocab_path, "r") as f:
    vocab_data = json.load(f)
vocab = Vocabulary.from_json(vocab_data)

encoder_path = hf_hub_download("hackergeek/delcap", "encoder.pth")
decoder_path = hf_hub_download("hackergeek/delcap", "decoder.pth")

encoder = EncoderCNN(config.EMBED_SIZE).to(config.DEVICE)
encoder.load_state_dict(torch.load(encoder_path, map_location=config.DEVICE))

# Infer the vocabulary size from the decoder checkpoint so the architecture
# matches the trained weights.
decoder_state = torch.load(decoder_path, map_location=config.DEVICE)
vocab_size = decoder_state["linear.weight"].shape[0]
decoder = DecoderRNN(config.EMBED_SIZE, config.HIDDEN_SIZE, vocab_size).to(config.DEVICE)
decoder.load_state_dict(decoder_state)

encoder.eval()
decoder.eval()

# ============================================================
# Image Preprocessing
# ============================================================
transform = transforms.Compose([
    transforms.Resize((config.IMG_SIZE, config.IMG_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])

# ============================================================
# Caption Generation
# ============================================================
def generate_caption(image: Image.Image):
    image = image.convert("RGB")  # ensure 3 channels for the DenseNet encoder
    image_tensor = transform(image).unsqueeze(0).to(config.DEVICE)

    with torch.no_grad():
        features = encoder(image_tensor)
        sampled_ids = decoder.sample(features, max_len=config.MAX_CAPTION_LENGTH, vocab=vocab)

    caption = []
    for token in sampled_ids.cpu().numpy():
        word = vocab.idx_to_word(token.item())
        # Skip or stop on the special tokens (assumed names; see Vocabulary).
        if word in ["<start>", "<pad>"]:
            continue
        if word == "<end>":
            break
        caption.append(word)

    return " ".join(caption)

# ============================================================
# Gradio Interface
# ============================================================
iface = gr.Interface(
    fn=generate_caption,
    inputs=gr.Image(type="pil"),
    outputs=gr.Textbox(label="Generated Caption"),
    title="DELCAP — Medical Image Captioning",
    description="Upload a medical image and get a generated caption."
)

iface.launch()
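
# ------------------------------
# Optional local sanity check (a hypothetical example, not part of the Space
# itself): uncomment to caption a local image without launching the Gradio UI.
# "sample_image.png" is a placeholder path.
# ------------------------------
# sample = Image.open("sample_image.png").convert("RGB")
# print(generate_caption(sample))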