# ============================================================
# DELCAP - Medical Image Captioning (Hugging Face Space)
# ============================================================

# ------------------------------
# Install dependencies (if needed)
# ------------------------------
#!pip install torch torchvision --quiet
#!pip install huggingface_hub --quiet
#!pip install nltk --quiet
#!pip install gradio --quiet
import torch
import torch.nn as nn
import torchvision.models as models
import torchvision.transforms as transforms
import json
import nltk
from PIL import Image
from collections import Counter
from huggingface_hub import hf_hub_download
import gradio as gr

# Ensure the punkt tokenizer is available
nltk.download("punkt")
# ============================================================
# Configuration
# ============================================================
class Config:
    IMG_SIZE = 224
    EMBED_SIZE = 256
    HIDDEN_SIZE = 512
    NUM_LSTM_LAYERS = 1
    DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    MAX_CAPTION_LENGTH = 50

config = Config()
# ============================================================
# Tokenization
# ============================================================
def tokenize_caption(text):
    return nltk.word_tokenize(text.lower())
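# Illustrative example of the tokenizer's behaviour (the input string is an
# assumption for demonstration, not taken from the dataset):
#   tokenize_caption("No acute cardiopulmonary abnormality.")
#   -> ['no', 'acute', 'cardiopulmonary', 'abnormality', '.']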
# ============================================================
# Vocabulary
# ============================================================
class Vocabulary:
    def __init__(self, freq_threshold=1):
        self.itos = {
            0: "<pad>",
            1: "<unk>",
            2: "<sos>",
            3: "<eos>",
        }
        self.stoi = {v: k for k, v in self.itos.items()}
        self.freq_threshold = freq_threshold
        self.vocab_size = len(self.itos)

    def __len__(self):
        return self.vocab_size

    @classmethod
    def from_json(cls, json_data):
        vocab_obj = cls()
        vocab_obj.stoi = json_data["stoi"]
        vocab_obj.itos = {int(k): v for k, v in json_data["itos"].items()}
        vocab_obj.vocab_size = len(vocab_obj.stoi)
        return vocab_obj

    def idx_to_word(self, idx):
        return self.itos.get(idx, "<unk>")
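# from_json assumes vocab.json follows the layout sketched below (illustrative
# entries only; JSON object keys are always strings, hence the int(k) cast above):
#   {
#     "stoi": {"<pad>": 0, "<unk>": 1, "<sos>": 2, "<eos>": 3, "chest": 4, ...},
#     "itos": {"0": "<pad>", "1": "<unk>", "2": "<sos>", "3": "<eos>", "4": "chest", ...}
#   }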
# ============================================================
# Encoder
# ============================================================
class EncoderCNN(nn.Module):
    def __init__(self, embed_size):
        super().__init__()
        densenet = models.densenet121(weights=models.DenseNet121_Weights.DEFAULT)
        self.densenet_features = densenet.features
        # Freeze the pretrained backbone; only the projection layer remains trainable
        for param in self.densenet_features.parameters():
            param.requires_grad_(False)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.embed = nn.Linear(1024, embed_size)

    def forward(self, images):
        features = self.densenet_features(images)
        features = self.avgpool(features)
        features = features.view(features.size(0), -1)
        features = self.embed(features)
        return features
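# Note: DenseNet-121's final feature map has 1024 channels, which is why the
# projection above is Linear(1024, embed_size). For a [B, 3, 224, 224] batch,
# forward returns image embeddings of shape [B, embed_size].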
# ============================================================
# Decoder
# ============================================================
class DecoderRNN(nn.Module):
    def __init__(self, embed_size, hidden_size, vocab_size, num_layers=1):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, embed_size)
        self.lstm = nn.LSTM(embed_size, hidden_size, num_layers, batch_first=True)
        self.linear = nn.Linear(hidden_size, vocab_size)
        self.dropout = nn.Dropout(0.5)
        self.num_layers = num_layers
        self.hidden_size = hidden_size
        self.feature_to_hidden_state = nn.Linear(embed_size, hidden_size)

    def sample(self, features, max_len=20, vocab=None):
        """Greedy decoding: initialise the LSTM state from the image features
        and pick the most likely token at each step until <eos> or max_len."""
        self.eval()
        with torch.no_grad():
            sampled_ids = []
            # Project the image embedding into the LSTM's hidden and cell states
            initial_hidden = self.feature_to_hidden_state(features)
            h = initial_hidden.unsqueeze(0).repeat(self.num_layers, 1, 1)
            c = initial_hidden.unsqueeze(0).repeat(self.num_layers, 1, 1)
            hidden = (h, c)
            start_token = torch.tensor([vocab.stoi["<sos>"]], device=features.device)
            inputs = self.embed(start_token).unsqueeze(1)
            for _ in range(max_len):
                output, hidden = self.lstm(inputs, hidden)
                logits = self.linear(self.dropout(output.squeeze(1)))
                _, predicted = logits.max(1)
                sampled_ids.append(predicted)
                if predicted.item() == vocab.stoi["<eos>"]:
                    break
                inputs = self.embed(predicted).unsqueeze(1)
        return torch.stack(sampled_ids)
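# Usage sketch (assumed tensors, for illustration only; the Space itself calls
# sample from generate_caption further below):
#   feats = torch.zeros(1, config.EMBED_SIZE)             # stand-in image feature
#   ids = decoder.sample(feats, max_len=10, vocab=vocab)  # [T, 1] token indices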
# ============================================================
# Load Vocabulary & Models
# ============================================================
vocab_path = hf_hub_download("hackergeek/delcap", "vocab.json")
with open(vocab_path, "r") as f:
    vocab_data = json.load(f)
vocab = Vocabulary.from_json(vocab_data)

encoder_path = hf_hub_download("hackergeek/delcap", "encoder.pth")
decoder_path = hf_hub_download("hackergeek/delcap", "decoder.pth")

encoder = EncoderCNN(config.EMBED_SIZE).to(config.DEVICE)
encoder.load_state_dict(torch.load(encoder_path, map_location=config.DEVICE))

# Infer the vocabulary size from the checkpoint so the decoder's output layer
# matches the saved weights exactly
decoder_state = torch.load(decoder_path, map_location=config.DEVICE)
vocab_size = decoder_state["linear.weight"].shape[0]
decoder = DecoderRNN(config.EMBED_SIZE, config.HIDDEN_SIZE, vocab_size).to(config.DEVICE)
decoder.load_state_dict(decoder_state)

encoder.eval()
decoder.eval()
# ============================================================
# Image Preprocessing
# ============================================================
transform = transforms.Compose([
    transforms.Resize((config.IMG_SIZE, config.IMG_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])
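# The normalization constants are the standard ImageNet statistics expected by
# the pretrained DenseNet-121 backbone. Applied to a 3-channel PIL image,
# transform yields a float tensor of shape [3, 224, 224].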
# ============================================================
# Caption Generation
# ============================================================
def generate_caption(image: Image.Image):
    # Grayscale/RGBA uploads are converted to RGB so they match the 3-channel
    # normalization above
    image = image.convert("RGB")
    image_tensor = transform(image).unsqueeze(0).to(config.DEVICE)
    with torch.no_grad():
        features = encoder(image_tensor)
        sampled_ids = decoder.sample(features, max_len=config.MAX_CAPTION_LENGTH, vocab=vocab)
    caption = []
    for token in sampled_ids.cpu().numpy():
        word = vocab.idx_to_word(token.item())
        if word in ["<sos>", "<pad>"]:
            continue
        if word == "<eos>":
            break
        caption.append(word)
    return " ".join(caption)
# ============================================================
# Gradio Interface
# ============================================================
iface = gr.Interface(
    fn=generate_caption,
    inputs=gr.Image(type="pil"),
    outputs=gr.Textbox(label="Generated Caption"),
    title="DELCAP - Medical Image Captioning",
    description="Upload a medical image and get a generated caption.",
)

iface.launch()