# ============================================================
# DELCAP – Medical Image Captioning (Hugging Face Space)
# ============================================================
# ------------------------------
# Install dependencies (if needed)
# ------------------------------
#!pip install torch torchvision --quiet
#!pip install huggingface_hub --quiet
#!pip install nltk --quiet
#!pip install gradio --quiet
import torch
import torch.nn as nn
import torchvision.models as models
import torchvision.transforms as transforms
import json
import nltk
from PIL import Image
from collections import Counter
from huggingface_hub import hf_hub_download
import gradio as gr
# Ensure the punkt tokenizer data is available (newer NLTK releases look for
# "punkt_tab"; requesting an id the installed NLTK doesn't know just logs an error)
nltk.download("punkt", quiet=True)
nltk.download("punkt_tab", quiet=True)
# ============================================================
# Configuration
# ============================================================
class Config:
    IMG_SIZE = 224
    EMBED_SIZE = 256
    HIDDEN_SIZE = 512
    NUM_LSTM_LAYERS = 1
    DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    MAX_CAPTION_LENGTH = 50

config = Config()
# ============================================================
# Tokenization
# ============================================================
def tokenize_caption(text):
    return nltk.word_tokenize(text.lower())
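# Example (sketch): tokenize_caption("No focal consolidation.")
# -> ['no', 'focal', 'consolidation', '.']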
# ============================================================
# Vocabulary
# ============================================================
class Vocabulary:
    def __init__(self, freq_threshold=1):
        self.itos = {
            0: "<pad>",
            1: "<unk>",
            2: "<sos>",
            3: "<eos>"
        }
        self.stoi = {v: k for k, v in self.itos.items()}
        self.freq_threshold = freq_threshold
        self.vocab_size = len(self.itos)

    def __len__(self):
        return self.vocab_size

    @classmethod
    def from_json(cls, json_data):
        vocab_obj = cls()
        vocab_obj.stoi = json_data['stoi']
        vocab_obj.itos = {int(k): v for k, v in json_data['itos'].items()}
        vocab_obj.vocab_size = len(vocab_obj.stoi)
        return vocab_obj

    def idx_to_word(self, idx):
        return self.itos.get(idx, "<unk>")
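# Expected vocab.json layout (assumed from from_json above; JSON object keys
# are always strings, hence the int(k) cast for "itos"):
# {
#   "stoi": {"<pad>": 0, "<unk>": 1, "<sos>": 2, "<eos>": 3, ...},
#   "itos": {"0": "<pad>", "1": "<unk>", "2": "<sos>", "3": "<eos>", ...}
# }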
# ============================================================
# Encoder
# ============================================================
class EncoderCNN(nn.Module):
    def __init__(self, embed_size):
        super().__init__()
        densenet = models.densenet121(weights=models.DenseNet121_Weights.DEFAULT)
        self.densenet_features = densenet.features
        # Freeze the pretrained backbone; only the projection layer is trainable
        for param in self.densenet_features.parameters():
            param.requires_grad_(False)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.embed = nn.Linear(1024, embed_size)

    def forward(self, images):
        features = self.densenet_features(images)
        features = self.avgpool(features)
        features = features.view(features.size(0), -1)
        features = self.embed(features)
        return features
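# Shape trace for a 224x224 input batch (a sketch; DenseNet-121's final
# feature map has 1024 channels at 1/32 resolution):
#   images   (B, 3, 224, 224)
#   features (B, 1024, 7, 7) -> avgpool -> (B, 1024, 1, 1)
#   view     (B, 1024)       -> embed   -> (B, EMBED_SIZE)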
# ============================================================
# Decoder
# ============================================================
class DecoderRNN(nn.Module):
    def __init__(self, embed_size, hidden_size, vocab_size, num_layers=1):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, embed_size)
        self.lstm = nn.LSTM(embed_size, hidden_size, num_layers, batch_first=True)
        self.linear = nn.Linear(hidden_size, vocab_size)
        self.dropout = nn.Dropout(0.5)
        self.num_layers = num_layers
        self.hidden_size = hidden_size
        self.feature_to_hidden_state = nn.Linear(embed_size, hidden_size)

    def sample(self, features, max_len=20, vocab=None):
        """Greedy decoding: start from <sos>, feed back the argmax token each step."""
        self.eval()
        with torch.no_grad():
            sampled_ids = []
            # Project the image embedding into the LSTM's initial hidden/cell state
            initial_hidden = self.feature_to_hidden_state(features)
            h = initial_hidden.unsqueeze(0).repeat(self.num_layers, 1, 1)
            c = initial_hidden.unsqueeze(0).repeat(self.num_layers, 1, 1)
            hidden = (h, c)
            start_token = torch.tensor([vocab.stoi["<sos>"]], device=features.device)
            inputs = self.embed(start_token).unsqueeze(1)
            for _ in range(max_len):
                output, hidden = self.lstm(inputs, hidden)
                logits = self.linear(self.dropout(output.squeeze(1)))
                _, predicted = logits.max(1)
                sampled_ids.append(predicted)
                if predicted.item() == vocab.stoi["<eos>"]:
                    break
                inputs = self.embed(predicted).unsqueeze(1)
            return torch.stack(sampled_ids)
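# sample() returns a (T, 1) LongTensor of greedy token ids, including the
# <eos> id when it is produced within max_len steps. Note that the image
# conditions the decoder only through the initial (h, c) state, not as a
# first input token.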
# ============================================================
# Load Vocabulary & Models
# ============================================================
vocab_path = hf_hub_download("hackergeek/delcap", "vocab.json")
with open(vocab_path, "r") as f:
    vocab_data = json.load(f)
vocab = Vocabulary.from_json(vocab_data)
encoder_path = hf_hub_download("hackergeek/delcap", "encoder.pth")
decoder_path = hf_hub_download("hackergeek/delcap", "decoder.pth")
encoder = EncoderCNN(config.EMBED_SIZE).to(config.DEVICE)
encoder.load_state_dict(torch.load(encoder_path, map_location=config.DEVICE))
decoder_state = torch.load(decoder_path, map_location=config.DEVICE)
# Infer the vocabulary size from the checkpoint's output head, so the decoder
# is always built with the shape it was trained with
vocab_size = decoder_state["linear.weight"].shape[0]
decoder = DecoderRNN(config.EMBED_SIZE, config.HIDDEN_SIZE, vocab_size).to(config.DEVICE)
decoder.load_state_dict(decoder_state)
encoder.eval()
decoder.eval()
# ============================================================
# Image Preprocessing
# ============================================================
transform = transforms.Compose([
    transforms.Resize((config.IMG_SIZE, config.IMG_SIZE)),
    transforms.ToTensor(),
    # ImageNet statistics, matching the pretrained DenseNet backbone
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])
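# e.g. transform(img).shape -> torch.Size([3, 224, 224]) for any RGB PIL image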
# ============================================================
# Caption Generation
# ============================================================
def generate_caption(image: Image.Image):
    # Medical images are often grayscale; make sure we feed 3 channels
    image = image.convert("RGB")
    image_tensor = transform(image).unsqueeze(0).to(config.DEVICE)
    with torch.no_grad():
        features = encoder(image_tensor)
        sampled_ids = decoder.sample(features, max_len=config.MAX_CAPTION_LENGTH, vocab=vocab)
    caption = []
    for token in sampled_ids.cpu().numpy():
        word = vocab.idx_to_word(token.item())
        if word in ["<sos>", "<pad>"]:
            continue
        if word == "<eos>":
            break
        caption.append(word)
    return " ".join(caption)
# ============================================================
# Gradio Interface
# ============================================================
iface = gr.Interface(
    fn=generate_caption,
    inputs=gr.Image(type="pil"),
    outputs=gr.Textbox(label="Generated Caption"),
    title="DELCAP – Medical Image Captioning",
    description="Upload a medical image and get a generated caption."
)
iface.launch()