lukeingawesome committed on
Commit 1880e8e · verified · 1 Parent(s): b5f02cc

Create llm2vec_wrapper.py

Files changed (1):
  llm2vec_wrapper.py  +423 -0
llm2vec_wrapper.py ADDED
@@ -0,0 +1,423 @@
from llm2vec import LLM2Vec
from peft import PeftModel
from transformers import (
    AutoConfig,
    PretrainedConfig,
    AutoTokenizer,
)
import torch
import logging
import json
import os

logger = logging.getLogger(__name__)


class LLM2VecWrapper(LLM2Vec):
    def __init__(self, *args, **kwargs):
        super(LLM2VecWrapper, self).__init__(*args, **kwargs)

    def to(self, device_or_dtype):
        """Override to() to ensure all modules are properly moved."""
        result = super().to(device_or_dtype)

        # Ensure the latent attention pooling module is also moved
        if hasattr(result, 'latent_attn') and result.latent_attn is not None:
            result.latent_attn = result.latent_attn.to(device_or_dtype)

        return result

    def prepare_for_tokenization(self, text):
        text = (
            "<|start_header_id|>user<|end_header_id|>\n\n"
            + text.strip()
            + "<|eot_id|>"
        )
        return text

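    # Illustration (hypothetical input; the markers follow the Llama-3-style chat template used above):
    #   prepare_for_tokenization("Summarize the findings.")
    #   -> "<|start_header_id|>user<|end_header_id|>\n\nSummarize the findings.<|eot_id|>"
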
    def encode_text(self, text, max_length=None):
        """
        Encode text to embeddings with proper embed_mask handling.

        Args:
            text (str or list): Text(s) to encode
            max_length (int, optional): Maximum sequence length

        Returns:
            torch.Tensor: Text embeddings
        """
        if max_length is None:
            max_length = getattr(self, 'max_length', 512)

        inputs = self.tokenizer(
            text,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=max_length,
        )

        # Add embed_mask (same as attention_mask for simple text encoding)
        inputs["embed_mask"] = inputs["attention_mask"].clone()

        # Move to the same device as the model
        model_device = next(self.parameters()).device
        inputs = {k: v.to(model_device) for k, v in inputs.items()}

        with torch.no_grad():
            embeddings = self(inputs)

        return embeddings

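    # Usage sketch (hypothetical texts; `l2v` is an already-loaded LLM2VecWrapper):
    #   emb = l2v.encode_text(["first report", "second report"])
    # `emb` is one pooled vector per input (e.g. roughly shape (2, hidden_size));
    # here embed_mask simply mirrors attention_mask.
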
    def tokenize_with_separator(self, texts, max_length=None, separator='!@#$%^&*()'):
        """
        Tokenize texts with special handling for separator-based splitting.
        This is useful for instruction-following tasks.

        Args:
            texts (list): List of texts to tokenize
            max_length (int, optional): Maximum sequence length
            separator (str): Separator to split instruction from text

        Returns:
            dict: Tokenized inputs with attention masks and embed masks
        """
        if max_length is None:
            max_length = getattr(self, 'max_length', 512)

        texts_2 = []
        original_texts = []

        for text in texts:
            parts = text.split(separator)
            texts_2.append(parts[1] if len(parts) > 1 else "")
            original_texts.append("".join(parts))

        # Tokenize original texts
        tokenized = self.tokenizer(
            original_texts,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=max_length,
        )

        # Create embedding masks for the separated parts
        embed_mask = None
        for t_i, t in enumerate(texts_2):
            ids = self.tokenizer(
                [t],
                return_tensors="pt",
                padding=True,
                truncation=True,
                max_length=max_length,
                add_special_tokens=False,
            )

            e_m = torch.zeros_like(tokenized["attention_mask"][t_i])
            if len(ids["input_ids"][0]) > 0:
                e_m[-len(ids["input_ids"][0]):] = torch.ones(len(ids["input_ids"][0]))

            if embed_mask is None:
                embed_mask = e_m.unsqueeze(0)
            else:
                embed_mask = torch.cat((embed_mask, e_m.unsqueeze(0)), dim=0)

        tokenized["embed_mask"] = embed_mask
        return tokenized

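    # Input format sketch (the example sentence is illustrative, not from this repository):
    #   "Retrieve a similar report: !@#$%^&*()Small left pleural effusion is noted."
    # Everything before the separator is treated as the instruction; embed_mask is 1 only
    # over the trailing text tokens, so pooling can ignore the instruction part.
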
    def encode_with_instruction(self, texts, max_length=None, separator='!@#$%^&*()'):
        """
        Encode texts with instruction-following using separator-based processing.

        Args:
            texts (list): List of texts with instructions separated by `separator`
            max_length (int, optional): Maximum sequence length
            separator (str): Separator between instruction and text

        Returns:
            torch.Tensor: Text embeddings
        """
        tokenized = self.tokenize_with_separator(texts, max_length, separator)

        # Move to the same device as the model
        model_device = next(self.parameters()).device
        tokenized = {k: v.to(model_device) for k, v in tokenized.items()}

        with torch.no_grad():
            embeddings = self(tokenized)

        return embeddings

    def encode_with_separator(self, texts, device=None, max_length=None, separator='!@#$%^&*()'):
        """
        Encode texts with special separator-based handling for instruction/text pairs.

        Args:
            texts (list): List of texts to encode (with separator for instruction/text pairs)
            device: Device to run on (auto-detected if None)
            max_length: Maximum sequence length (512 if None)
            separator: Separator string for instruction/text pairs

        Returns:
            torch.Tensor: Embeddings for the texts
        """
        if device is None:
            device = next(self.parameters()).device
        if max_length is None:
            max_length = 512

        # Ensure the model is on the right device
        self = self.to(device)

        # Process texts with separator
        texts_2 = []
        original_texts = []

        for text in texts:
            parts = text.split(separator)
            texts_2.append(parts[1] if len(parts) > 1 else "")
            original_texts.append("".join(parts))

        # Tokenize original texts
        tokenized = self.tokenizer(
            original_texts,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=max_length,
        )

        # Create embedding masks
        embed_mask = None
        for t_i, t in enumerate(texts_2):
            ids = self.tokenizer(
                [t],
                return_tensors="pt",
                padding=True,
                truncation=True,
                max_length=max_length,
                add_special_tokens=False,
            )

            e_m = torch.zeros_like(tokenized["attention_mask"][t_i])
            if len(ids["input_ids"][0]) > 0:
                e_m[-len(ids["input_ids"][0]):] = torch.ones(len(ids["input_ids"][0]))

            if embed_mask is None:
                embed_mask = e_m.unsqueeze(0)
            else:
                embed_mask = torch.cat((embed_mask, e_m.unsqueeze(0)), dim=0)

        tokenized["embed_mask"] = embed_mask

        # Move to device (matching the model dtype for floating-point tensors) and compute embeddings
        tokenized = {k: v.to(device) for k, v in tokenized.items()}
        tokenized = {k: v.to(self.model.dtype) if v.dtype.is_floating_point else v
                     for k, v in tokenized.items()}

        with torch.no_grad():
            embeddings = self(tokenized)

        return embeddings

    def compute_similarities(self, query_text, candidate_texts, device=None, separator='!@#$%^&*()'):
        """
        Compute cosine similarity scores between a query text and candidate texts.

        Args:
            query_text (str): The query text (with separator for instruction/text pairs)
            candidate_texts (list): List of candidate texts to compare against
            device: Device to run on (auto-detected if None)
            separator: Separator string for instruction/text pairs

        Returns:
            torch.Tensor: Similarity scores for each candidate
        """
        import torch.nn.functional as F

        if device is None:
            device = next(self.parameters()).device

        # Combine query and candidates
        all_texts = [query_text] + candidate_texts

        # Get embeddings
        embeddings = self.encode_with_separator(all_texts, device=device, separator=separator)

        # Compute similarities between the query (first embedding) and the candidates
        similarities = F.cosine_similarity(embeddings[0], embeddings[1:], dim=1)

        return similarities

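    # Usage sketch (hypothetical texts; `l2v` is an already-loaded LLM2VecWrapper):
    #   scores = l2v.compute_similarities(
    #       "Retrieve a similar report: !@#$%^&*()query report text",
    #       ["candidate report A", "candidate report B"],
    #   )
    # `scores` holds one cosine similarity per candidate.
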
    def _load_latent_attention_weights(self, model_path, use_safetensors=True):
        """
        Automatically load latent attention weights from model files.

        Args:
            model_path: Path to model (local directory or HuggingFace repo)
            use_safetensors: Whether to use safetensors format
        """
        if os.path.isdir(model_path):
            # Local directory - try pytorch_model.bin first
            pytorch_model_path = os.path.join(model_path, "pytorch_model.bin")
            if os.path.exists(pytorch_model_path):
                print(f"Loading latent attention weights from {pytorch_model_path}")
                try:
                    state_dict = torch.load(pytorch_model_path, weights_only=True)
                    latent_attn_weights = {k: v for k, v in state_dict.items() if k.startswith('latent_attn.')}

                    if latent_attn_weights:
                        missing_keys, unexpected_keys = self.latent_attn.load_state_dict(
                            {k.replace('latent_attn.', ''): v for k, v in latent_attn_weights.items()},
                            strict=False,
                        )
                        if not missing_keys and not unexpected_keys:
                            print(f"✅ Successfully loaded {len(latent_attn_weights)} latent attention weights")
                        else:
                            print(f"⚠️ Partial loading: missing={missing_keys}, unexpected={unexpected_keys}")
                    else:
                        print("⚠️ No latent attention weights found in the model file")
                except Exception as e:
                    print(f"❌ Error loading latent attention weights: {e}")
        else:
            # HuggingFace repository - load from safetensors
            if use_safetensors:
                print("Loading latent attention weights from HuggingFace safetensors...")
                try:
                    from safetensors.torch import load_file
                    from huggingface_hub import hf_hub_download

                    # Download the safetensors file
                    safetensors_path = hf_hub_download(repo_id=model_path, filename="model.safetensors")

                    # Load weights from safetensors
                    safetensors_weights = load_file(safetensors_path)

                    # Extract latent attention weights
                    latent_attn_weights = {k: v for k, v in safetensors_weights.items() if k.startswith('latent_attn.')}

                    if latent_attn_weights:
                        print(f"Found {len(latent_attn_weights)} latent attention weights in safetensors")

                        # Load the weights into the latent attention module
                        missing_keys, unexpected_keys = self.latent_attn.load_state_dict(
                            {k.replace('latent_attn.', ''): v for k, v in latent_attn_weights.items()},
                            strict=False,
                        )

                        if not missing_keys and not unexpected_keys:
                            print(f"✅ Successfully loaded {len(latent_attn_weights)} latent attention weights from safetensors")
                        else:
                            print(f"⚠️ Partial loading: missing={missing_keys}, unexpected={unexpected_keys}")
                    else:
                        print("⚠️ No latent attention weights found in safetensors file")

                except Exception as e:
                    print(f"❌ Error loading latent attention weights from safetensors: {e}")

    @classmethod
    def from_pretrained(
        cls,
        base_model_name_or_path,
        peft_model_name_or_path=None,
        merge_peft=False,
        enable_bidirectional=True,
        extra_model_name_or_path=None,
        **kwargs,
    ):
        # Pop out encoder args
        keys = ["pooling_mode", "max_length", "doc_max_length", "skip_instruction"]
        encoder_args = {
            key: kwargs.pop(key, None) for key in keys if kwargs.get(key) is not None
        }

        tokenizer = AutoTokenizer.from_pretrained(base_model_name_or_path)
        tokenizer.pad_token = tokenizer.eos_token
        tokenizer.padding_side = "left"

        config = AutoConfig.from_pretrained(base_model_name_or_path)
        config_class_name = config.__class__.__name__

        model_class = cls._get_model_class(
            config_class_name, enable_bidirectional=enable_bidirectional
        )
        model = model_class.from_pretrained(base_model_name_or_path, **kwargs)

        if os.path.isdir(base_model_name_or_path) and os.path.exists(
            f"{base_model_name_or_path}/config.json"
        ):
            with open(f"{base_model_name_or_path}/config.json", "r") as fIn:
                config_dict = json.load(fIn)
            config = PretrainedConfig.from_dict(config_dict)
            model.config._name_or_path = config._name_or_path

        # For the special case where config.json and adapter weights are in the same directory
        if hasattr(model, "peft_config"):
            model = PeftModel.from_pretrained(
                model,
                base_model_name_or_path,
            )
            model = model.merge_and_unload()

        if peft_model_name_or_path is not None:
            model = PeftModel.from_pretrained(
                model,
                peft_model_name_or_path,
            )
            if merge_peft:
                model = model.merge_and_unload()
            if extra_model_name_or_path is not None:
                logger.info(f"Loading extra model from {extra_model_name_or_path}")
                if not merge_peft:
                    model = model.merge_and_unload()
                if isinstance(extra_model_name_or_path, str):
                    model = PeftModel.from_pretrained(
                        model,
                        extra_model_name_or_path,
                    )
                    model = model.merge_and_unload()
                elif isinstance(extra_model_name_or_path, list):
                    for extra_model in extra_model_name_or_path:
                        model = PeftModel.from_pretrained(
                            model,
                            extra_model,
                        )
                        peft_model_name_or_path = extra_model
                        model = model.merge_and_unload()
                else:
                    raise ValueError(
                        "extra_model_name_or_path should be a string or a list of strings."
                    )

        config = {}
        config_addr = (
            peft_model_name_or_path
            if peft_model_name_or_path is not None
            else base_model_name_or_path
        )
        if os.path.exists(f"{config_addr}/llm2vec_config.json"):
            with open(f"{config_addr}/llm2vec_config.json", "r") as fIn:
                llm2vec_config = json.load(fIn)
            config.update(llm2vec_config)

        for key, value in encoder_args.items():
            config[key] = value

        llm2vec_model = cls(model=model, tokenizer=tokenizer, **config)

        # Auto-load latent attention weights if using latent_attention pooling
        if (hasattr(llm2vec_model, 'latent_attn') and
                llm2vec_model.latent_attn is not None and
                llm2vec_model.pooling_mode == "latent_attention"):
            llm2vec_model._load_latent_attention_weights(
                base_model_name_or_path, kwargs.get('use_safetensors', True)
            )

        # Ensure the entire model is converted to the requested dtype
        if 'torch_dtype' in kwargs and kwargs['torch_dtype'] is not None:
            llm2vec_model = llm2vec_model.to(kwargs['torch_dtype'])

        return llm2vec_model
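
A minimal end-to-end sketch of how this wrapper could be used, assuming llm2vec_wrapper.py is importable; the model paths, adapter, pooling mode, dtype, and texts below are placeholders, not part of this commit:

import torch
from llm2vec_wrapper import LLM2VecWrapper

# Placeholder paths/IDs -- substitute the actual base model and (optional) PEFT adapter.
l2v = LLM2VecWrapper.from_pretrained(
    "path/or/repo-of-base-model",
    peft_model_name_or_path="path/or/repo-of-adapter",
    merge_peft=True,
    pooling_mode="mean",
    max_length=512,
    torch_dtype=torch.bfloat16,
)

# Instruction and text are joined by the default separator '!@#$%^&*()'.
scores = l2v.compute_similarities(
    "Retrieve a similar report: !@#$%^&*()query report text",
    ["candidate report A", "candidate report B"],
)
print(scores)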