"""
Custom model class for LLM2Vec4CXR that properly handles latent attention pooling.
"""
from llm2vec.models.bidirectional_llama import LlamaBiModel
from transformers import PreTrainedModel
from transformers.models.llama.configuration_llama import LlamaConfig
# from llm2vec.pooling import LatentAttentionPooling
from .pooling_latent import LatentAttentionPooling
from transformers import AutoTokenizer
import torch
import torch.nn as nn
import torch.nn.functional as F

class LLM2Vec4CXRModel(PreTrainedModel):
    """
    Wrapper model that combines LlamaBiModel with latent attention pooling.
    The module structure matches the saved checkpoint: ``self.model`` + ``self.latent_attn``.
    """

    config_class = LlamaConfig

    def __init__(self, config, **kwargs):
        super().__init__(config, **kwargs)
        # Wrap the bidirectional LLaMA backbone
        self.model = LlamaBiModel(config)
        # Initialize latent attention pooling
        self.latent_attn = LatentAttentionPooling(
            d_model=config.hidden_size,
            num_heads=8,      # standard for this model size
            num_latents=512,  # standard for LLM2Vec
        )

    def forward(self, input_ids, attention_mask=None, embed_mask=None, **kwargs):
        """
        Forward pass that properly handles latent attention pooling.
        Returns pooled embeddings of shape [batch_size, hidden_size].
        """
        # Get base model output
        outputs = self.model(input_ids, attention_mask=attention_mask, **kwargs)
        # Apply latent attention pooling
        if embed_mask is not None:
            # Use embed_mask for instruction-following tasks
            pooled_output = self.latent_attn(outputs.last_hidden_state, embed_mask)
        else:
            # Use attention_mask for simple encoding
            pooled_output = self.latent_attn(outputs.last_hidden_state, attention_mask)
        return pooled_output

    # --- Convenience tokenizer (lazy) --------------------------------------
    def _get_tokenizer(self):
        if not hasattr(self, "_hf_tokenizer"):
            tok = AutoTokenizer.from_pretrained(
                getattr(self.config, "_name_or_path", "lukeingawesome/llm2vec4cxr")
            )
            if tok.pad_token is None:
                tok.pad_token = tok.eos_token
            tok.padding_side = "left"
            self._hf_tokenizer = tok
        return self._hf_tokenizer

    # --- Ensure latent_attn follows .to(device/dtype) -----------------------
    def to(self, *args, **kwargs):
        m = super().to(*args, **kwargs)
        if hasattr(self, "latent_attn") and self.latent_attn is not None:
            # Align latent_attn with the base weights' device & dtype
            try:
                device = next(p.device for p in self.parameters() if p is not None)
                dtype = next((p.dtype for p in self.parameters() if p.is_floating_point()), None)
                self.latent_attn = self.latent_attn.to(device=device, dtype=dtype)
            except StopIteration:
                pass
        return m
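
    # Illustrative note: a call such as `model.to("cuda", dtype=torch.bfloat16)`
    # moves both the backbone and `latent_attn`, since the override above
    # re-aligns the pooler with the backbone's device and dtype.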

    # --- Simple text encoding (no instruction) ------------------------------
    @torch.no_grad()
    def encode_text(self, texts, max_length: int = 512):
        tok = self._get_tokenizer()
        enc = tok(texts, return_tensors="pt", padding=True, truncation=True, max_length=max_length)
        # For simple encoding we embed over all non-pad tokens
        enc["embed_mask"] = enc["attention_mask"].clone()
        dev = next(self.parameters()).device
        enc = {k: v.to(dev) for k, v in enc.items()}
        return self(input_ids=enc["input_ids"], attention_mask=enc["attention_mask"], embed_mask=enc["embed_mask"])
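
    # Illustrative call (the report text below is a placeholder, not from any dataset):
    #   embs = model.encode_text(["No acute cardiopulmonary process."])
    #   embs.shape -> [1, hidden_size]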

    # --- Instruction/text encoding with separator ---------------------------
    def _build_separator_inputs(self, texts, max_length: int, separator: str):
        tok = self._get_tokenizer()
        # Split into [instruction | text]; we embed only the trailing "text" part.
        # If no separator is present, embed the entire text.
        parts_after_sep = []
        original = []
        for t in texts:
            parts = t.split(separator)
            # If no separator is found, use the entire text (not an empty string)
            parts_after_sep.append(parts[1] if len(parts) > 1 else parts[0])
            original.append("".join(parts))
        tokenized = tok(original, return_tensors="pt", padding=True, truncation=True, max_length=max_length)
        # Build an embed_mask that lights up only the trailing "text" span
        embed_mask = None
        for i, t in enumerate(parts_after_sep):
            sub = tok([t], return_tensors="pt", padding=True, truncation=True, max_length=max_length, add_special_tokens=False)
            m = torch.zeros_like(tokenized["attention_mask"][i])
            if len(sub["input_ids"][0]) > 0:
                # With left padding, the "text" tokens sit at the end of the sequence
                m[-len(sub["input_ids"][0]):] = 1
            else:
                # If tokenization produced 0 tokens, fall back to the attention_mask (embed everything)
                m = tokenized["attention_mask"][i].clone()
            embed_mask = m.unsqueeze(0) if embed_mask is None else torch.cat([embed_mask, m.unsqueeze(0)], dim=0)
        tokenized["embed_mask"] = embed_mask
        return tokenized

    @torch.no_grad()
    def encode_with_separator(self, texts, separator: str = "!@#$%^&*()", max_length: int = 512):
        enc = self._build_separator_inputs(texts, max_length=max_length, separator=separator)
        dev = next(self.parameters()).device
        enc = {k: v.to(dev) for k, v in enc.items()}
        return self(input_ids=enc["input_ids"], attention_mask=enc["attention_mask"], embed_mask=enc["embed_mask"])
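
    # Illustrative call (instruction and report text are placeholders): each input
    # is "<instruction><separator><report text>", and only the report text after
    # the separator contributes to embed_mask.
    #   model.encode_with_separator(
    #       ["Summarize the findings." + "!@#$%^&*()" + "Mild cardiomegaly without edema."]
    #   )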

    # --- One-liner cosine similarity over instruction+text ------------------
    @torch.no_grad()
    def compute_similarities(self, query_text: str, candidate_texts, separator: str = "!@#$%^&*()", max_length: int = 512):
        all_texts = [query_text] + list(candidate_texts)
        embs = self.encode_with_separator(all_texts, separator=separator, max_length=max_length)
        # embs: [N, 2048]; compare query vs candidates
        return F.cosine_similarity(embs[0], embs[1:], dim=1)
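

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the model definition). It assumes the
# checkpoint stores weights under `model.*` and `latent_attn.*` as the class
# docstring states, and that the hub id used in `_get_tokenizer()` is valid;
# the instruction/report strings are illustrative placeholders. Because of the
# relative import above, run this with `python -m <package>.<module>`.
if __name__ == "__main__":
    model = LLM2Vec4CXRModel.from_pretrained("lukeingawesome/llm2vec4cxr")
    model = model.to("cuda" if torch.cuda.is_available() else "cpu").eval()

    sep = "!@#$%^&*()"
    instruction = "Determine the change or the status of the pleural effusion."
    query = instruction + sep + "There is a small left pleural effusion."
    candidates = [
        instruction + sep + "Pleural effusion is present.",
        instruction + sep + "No pleural effusion is seen.",
    ]
    # Higher cosine similarity means the candidate is semantically closer to the query.
    print(model.compute_similarities(query, candidates, separator=sep))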