"""
Custom model class for LLM2Vec4CXR that properly handles latent attention pooling.
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoTokenizer, PreTrainedModel
from transformers.models.llama.configuration_llama import LlamaConfig

from llm2vec.models.bidirectional_llama import LlamaBiModel

# from llm2vec.pooling import LatentAttentionPooling  # upstream version; local copy used instead
from .pooling_latent import LatentAttentionPooling


class LLM2Vec4CXRModel(PreTrainedModel):
    """
    Wrapper model that includes LlamaBiModel and latent attention pooling.
    Structure matches the saved checkpoint: self.model + self.latent_attn
    """
    config_class = LlamaConfig
    
    def __init__(self, config, **kwargs):
        super().__init__(config, **kwargs)
        
        # Wrap the LlamaBiModel
        self.model = LlamaBiModel(config)
        
        # Initialize latent attention pooling
        self.latent_attn = LatentAttentionPooling(
            d_model=config.hidden_size,
            num_heads=8,  # Standard for this model size
            num_latents=512  # Standard for LLM2Vec
        )
    
    def forward(self, input_ids, attention_mask=None, embed_mask=None, **kwargs):
        """
        Forward pass that properly handles latent attention pooling.
        """
        # Get base model output
        outputs = self.model(input_ids, attention_mask=attention_mask, **kwargs)
        
        # Apply latent attention pooling
        if embed_mask is not None:
            # Use embed_mask for instruction-following tasks
            pooled_output = self.latent_attn(outputs.last_hidden_state, embed_mask)
        else:
            # Use attention_mask for simple encoding
            pooled_output = self.latent_attn(outputs.last_hidden_state, attention_mask)
        
        return pooled_output

    # --- Convenience tokenizer (lazy) -------------------------------------
    def _get_tokenizer(self):
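        """Lazily load and cache the tokenizer for this checkpoint.

        Left padding is set because the embed_mask logic assumes the text
        tokens sit at the end of each padded sequence.
        """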
        if not hasattr(self, "_hf_tokenizer"):
            tok = AutoTokenizer.from_pretrained(getattr(self.config, "_name_or_path", "lukeingawesome/llm2vec4cxr"))
            if tok.pad_token is None:
                tok.pad_token = tok.eos_token
            tok.padding_side = "left"
            self._hf_tokenizer = tok
        return self._hf_tokenizer

    # --- Ensure latent_attn follows .to(device/dtype) ----------------------
    def to(self, *args, **kwargs):
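        """Delegate to ``nn.Module.to`` and then re-align ``latent_attn`` with
        the device and dtype of the base model weights."""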
        m = super().to(*args, **kwargs)
        if hasattr(self, "latent_attn") and self.latent_attn is not None:
            # Align latent_attn with the base weights' device & dtype
            try:
                device = next(p.device for p in self.parameters())
                dtype = next((p.dtype for p in self.parameters() if p.is_floating_point()), None)
                self.latent_attn = self.latent_attn.to(device=device, dtype=dtype)
            except StopIteration:
                pass
        return m

    # --- Simple text encoding (no instruction) ----------------------------
    @torch.no_grad()
    def encode_text(self, texts, max_length: int = 512):
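        """Encode plain texts (no instruction prefix).

        Every non-pad token contributes to the pooled embedding.
        Returns a tensor of shape ``[len(texts), hidden_size]``.
        """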
        tok = self._get_tokenizer()
        enc = tok(texts, return_tensors="pt", padding=True, truncation=True, max_length=max_length)
        # For simple encoding we embed over all non‑pad tokens
        enc["embed_mask"] = enc["attention_mask"].clone()
        dev = next(self.parameters()).device
        enc = {k: v.to(dev) for k, v in enc.items()}
        return self(input_ids=enc["input_ids"], attention_mask=enc["attention_mask"], embed_mask=enc["embed_mask"])

    # --- Instruction/text encoding with separator -------------------------
    def _build_separator_inputs(self, texts, max_length: int, separator: str):
        tok = self._get_tokenizer()
        # Split into [instruction | text]; we embed only the trailing "text" part.
        # If no separator, embed the entire text.
        parts_after_sep = []
        original = []
        for t in texts:
            parts = t.split(separator)
            # If no separator found, use the entire text (not empty string)
            parts_after_sep.append(parts[1] if len(parts) > 1 else parts[0])
            original.append("".join(parts))

        tokenized = tok(original, return_tensors="pt", padding=True, truncation=True, max_length=max_length)
        # Build an embed_mask that lights up only the trailing "text" span.
        # The tokenizer pads on the left, so the text tokens occupy the last
        # positions of each padded sequence.
        masks = []
        for i, t in enumerate(parts_after_sep):
            sub = tok([t], return_tensors="pt", padding=True, truncation=True, max_length=max_length, add_special_tokens=False)
            m = torch.zeros_like(tokenized["attention_mask"][i])
            n_text_tokens = sub["input_ids"].shape[1]
            if n_text_tokens > 0:
                m[-n_text_tokens:] = 1
            else:
                # Tokenizing the text part yielded no tokens; embed every non-pad token instead.
                m = tokenized["attention_mask"][i].clone()
            masks.append(m)

        tokenized["embed_mask"] = torch.stack(masks, dim=0)
        return tokenized

    @torch.no_grad()
    def encode_with_separator(self, texts, separator: str = "!@#$%^&*()", max_length: int = 512):
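        """Encode ``"instruction{separator}text"`` strings.

        Only the portion after the separator contributes to the pooled
        embedding; texts without the separator are embedded in full.
        """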
        enc = self._build_separator_inputs(texts, max_length=max_length, separator=separator)
        dev = next(self.parameters()).device
        enc = {k: v.to(dev) for k, v in enc.items()}
        return self(input_ids=enc["input_ids"], attention_mask=enc["attention_mask"], embed_mask=enc["embed_mask"])

    # --- One‑liner cosine similarity over instruction+text ----------------
    @torch.no_grad()
    def compute_similarities(self, query_text: str, candidate_texts, separator: str = "!@#$%^&*()", max_length: int = 512):
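        """Cosine similarity between one query and a list of candidate texts.

        Returns a tensor of shape ``[len(candidate_texts)]``.
        """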
        all_texts = [query_text] + list(candidate_texts)
        embs = self.encode_with_separator(all_texts, separator=separator, max_length=max_length)
        # embs: [N, hidden_size]; row 0 is the query, the rest are the candidates
        return F.cosine_similarity(embs[0], embs[1:], dim=1)
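

# --- Example usage (illustrative sketch) --------------------------------------
# Hypothetical quick check: the checkpoint name matches the default used in
# _get_tokenizer(); the instruction/report strings below are placeholders, not
# values prescribed by the model.
if __name__ == "__main__":
    model = LLM2Vec4CXRModel.from_pretrained("lukeingawesome/llm2vec4cxr")
    model = model.to("cuda" if torch.cuda.is_available() else "cpu").eval()

    separator = "!@#$%^&*()"
    instruction = "Summarize the change in the pleural effusion."  # placeholder instruction
    report = "There is a small left pleural effusion, unchanged from the prior study."  # placeholder report
    candidates = [
        "Pleural effusion is improving.",
        "Pleural effusion is stable.",
        "Pleural effusion is worsening.",
    ]

    # The query carries instruction + separator + report; the candidates have no
    # separator, so encode_with_separator embeds them in full.
    sims = model.compute_similarities(
        query_text=instruction + separator + report,
        candidate_texts=candidates,
        separator=separator,
    )
    print(sims)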