from llm2vec import LLM2Vec
from peft import PeftModel
from transformers import (
    AutoConfig,
    PretrainedConfig,
    AutoTokenizer,
)
import torch
import torch.nn.functional as F
import logging
import json
import os

logger = logging.getLogger(__name__)


class LLM2VecWrapper(LLM2Vec):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def to(self, device_or_dtype):
        """Override the .to() method to ensure all modules are properly moved."""
        result = super().to(device_or_dtype)

        # The latent attention pooling module is registered outside the base
        # model, so it has to be moved explicitly.
        if hasattr(result, 'latent_attn') and result.latent_attn is not None:
            result.latent_attn = result.latent_attn.to(device_or_dtype)

        return result

    def prepare_for_tokenization(self, text):
        """Wrap raw text in the Llama-3 user-turn chat template."""
        text = (
            "<|start_header_id|>user<|end_header_id|>\n\n"
            + text.strip()
            + "<|eot_id|>"
        )
        return text

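    # For example (illustrative input), prepare_for_tokenization("hello") returns:
    #   "<|start_header_id|>user<|end_header_id|>\n\nhello<|eot_id|>"
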
    def encode_text(self, text, max_length=None):
        """
        Encode text to embeddings with proper embed_mask handling.

        Args:
            text (str or list): Text(s) to encode
            max_length (int, optional): Maximum sequence length

        Returns:
            torch.Tensor: Text embeddings
        """
        if max_length is None:
            max_length = getattr(self, 'max_length', 512)

        inputs = self.tokenizer(
            text,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=max_length,
        )

        # With no instruction to skip, every attended token contributes to the
        # pooled embedding, so the embed mask mirrors the attention mask.
        inputs["embed_mask"] = inputs["attention_mask"].clone()

        # Move the inputs onto the same device as the model.
        model_device = next(self.parameters()).device
        inputs = {k: v.to(model_device) for k, v in inputs.items()}

        with torch.no_grad():
            embeddings = self(inputs)

        return embeddings

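    # Usage sketch (assumes a wrapper built via from_pretrained below; the
    # texts are illustrative). Returns one embedding vector per input text:
    #
    #   embs = model.encode_text(["first sentence", "second sentence"])
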
    def tokenize_with_separator(self, texts, max_length=None, separator='!@#$%^&*()'):
        """
        Tokenize texts with special handling for separator-based splitting.
        This is useful for instruction-following tasks.

        Args:
            texts (list): List of texts to tokenize
            max_length (int, optional): Maximum sequence length
            separator (str): Separator that splits the instruction from the text

        Returns:
            dict: Tokenized inputs with attention masks and embed masks
        """
        if max_length is None:
            max_length = getattr(self, 'max_length', 512)

        # Split each input into (instruction, text): the full string is fed to
        # the model, but only the text part should be pooled into the embedding.
        texts_2 = []
        original_texts = []

        for text in texts:
            parts = text.split(separator)
            texts_2.append(parts[1] if len(parts) > 1 else "")
            original_texts.append("".join(parts))

        tokenized = self.tokenizer(
            original_texts,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=max_length,
        )

        # Build the embed mask: with left padding, the text portion occupies
        # the final positions of each row, so mark the last len(text) tokens.
        embed_mask = None
        for t_i, t in enumerate(texts_2):
            ids = self.tokenizer(
                [t],
                return_tensors="pt",
                padding=True,
                truncation=True,
                max_length=max_length,
                add_special_tokens=False,
            )

            e_m = torch.zeros_like(tokenized["attention_mask"][t_i])
            if len(ids["input_ids"][0]) > 0:
                e_m[-len(ids["input_ids"][0]):] = torch.ones(len(ids["input_ids"][0]))

            if embed_mask is None:
                embed_mask = e_m.unsqueeze(0)
            else:
                embed_mask = torch.cat((embed_mask, e_m.unsqueeze(0)), dim=0)

        tokenized["embed_mask"] = embed_mask
        return tokenized

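    # Input format sketch: instruction and text are joined by the separator,
    # which just needs to be a string unlikely to occur in real input:
    #
    #   batch = model.tokenize_with_separator(
    #       ["Retrieve relevant passages:" + '!@#$%^&*()' + "What is attention?"]
    #   )
    #   # batch["embed_mask"] marks only the tokens of "What is attention?"
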
    def encode_with_instruction(self, texts, max_length=None, separator='!@#$%^&*()'):
        """
        Encode texts with instruction-following using separator-based processing.

        Args:
            texts (list): List of texts with instructions separated by separator
            max_length (int, optional): Maximum sequence length
            separator (str): Separator between instruction and text

        Returns:
            torch.Tensor: Text embeddings
        """
        tokenized = self.tokenize_with_separator(texts, max_length, separator)

        # Move the inputs onto the same device as the model.
        model_device = next(self.parameters()).device
        tokenized = {k: v.to(model_device) for k, v in tokenized.items()}

        with torch.no_grad():
            embeddings = self(tokenized)

        return embeddings

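    # Usage sketch (illustrative strings; inputs without a separator are
    # embedded whole, since their instruction part is then empty):
    #
    #   embs = model.encode_with_instruction(
    #       ["Retrieve relevant passages:" + '!@#$%^&*()' + "What is attention?",
    #        "Attention weighs token interactions."]
    #   )
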
    def encode_with_separator(self, texts, device=None, max_length=None, separator='!@#$%^&*()'):
        """
        Encode texts with special separator-based handling for instruction/text pairs.

        Args:
            texts (list): List of texts to encode (with separator for instruction/text pairs)
            device: Device to run on (auto-detected if None)
            max_length: Maximum sequence length (512 if None)
            separator: Separator string for instruction/text pairs

        Returns:
            torch.Tensor: Embeddings for the texts
        """
        if device is None:
            device = next(self.parameters()).device
        if max_length is None:
            max_length = 512

        # Make sure every submodule (including latent attention) is on the
        # target device before encoding.
        self = self.to(device)

        # Reuse the shared separator-aware tokenization defined above.
        tokenized = self.tokenize_with_separator(texts, max_length=max_length, separator=separator)

        # Move the inputs to the target device and cast any floating-point
        # tensors to the model's dtype; the integer masks are left untouched.
        tokenized = {k: v.to(device) for k, v in tokenized.items()}
        tokenized = {
            k: v.to(self.model.dtype) if v.dtype.is_floating_point else v
            for k, v in tokenized.items()
        }

        with torch.no_grad():
            embeddings = self(tokenized)

        return embeddings

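    # Usage sketch (device is optional and auto-detected; shown here only to
    # illustrate the parameter):
    #
    #   embs = model.encode_with_separator(texts, device=torch.device("cuda"))
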
    def compute_similarities(self, query_text, candidate_texts, device=None, separator='!@#$%^&*()'):
        """
        Compute similarity scores between a query text and candidate texts.

        Args:
            query_text (str): The query text (with separator for instruction/text pairs)
            candidate_texts (list): List of candidate texts to compare against
            device: Device to run on (auto-detected if None)
            separator: Separator string for instruction/text pairs

        Returns:
            torch.Tensor: Cosine-similarity score for each candidate
        """
        if device is None:
            device = next(self.parameters()).device

        # Encode the query and all candidates in a single batch.
        all_texts = [query_text] + candidate_texts
        embeddings = self.encode_with_separator(all_texts, device=device, separator=separator)

        # Row 0 is the query; broadcasting compares it against every candidate.
        similarities = F.cosine_similarity(embeddings[0].unsqueeze(0), embeddings[1:], dim=1)

        return similarities

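    # Usage sketch (illustrative strings; candidates without a separator are
    # embedded whole):
    #
    #   sep = '!@#$%^&*()'
    #   scores = model.compute_similarities(
    #       "Given a query, retrieve relevant passages:" + sep + "What is attention?",
    #       ["Attention weighs token interactions.", "Paris is the capital of France."],
    #   )
    #   best = scores.argmax().item()
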
    def _load_latent_attention_weights(self, model_path, use_safetensors=True):
        """
        Automatically load latent attention weights from model files.

        Args:
            model_path: Path to model (local directory or HuggingFace repo)
            use_safetensors: Whether to use the safetensors format for remote repos
        """
        if os.path.isdir(model_path):
            # Local checkpoint: read the weights from pytorch_model.bin.
            pytorch_model_path = os.path.join(model_path, "pytorch_model.bin")
            if os.path.exists(pytorch_model_path):
                print(f"Loading latent attention weights from {pytorch_model_path}")
                try:
                    # Load on CPU first; .to() moves the module later.
                    state_dict = torch.load(pytorch_model_path, map_location="cpu", weights_only=True)
                    latent_attn_weights = {k: v for k, v in state_dict.items() if k.startswith('latent_attn.')}

                    if latent_attn_weights:
                        missing_keys, unexpected_keys = self.latent_attn.load_state_dict(
                            {k.replace('latent_attn.', ''): v for k, v in latent_attn_weights.items()},
                            strict=False,
                        )
                        if not missing_keys and not unexpected_keys:
                            print(f"✅ Successfully loaded {len(latent_attn_weights)} latent attention weights")
                        else:
                            print(f"⚠️ Partial loading: missing={missing_keys}, unexpected={unexpected_keys}")
                    else:
                        print("⚠️ No latent attention weights found in the model file")
                except Exception as e:
                    print(f"❌ Error loading latent attention weights: {e}")
        else:
            # Remote checkpoint: fetch the safetensors file from the Hugging Face Hub.
            if use_safetensors:
                print("Loading latent attention weights from HuggingFace safetensors...")
                try:
                    from safetensors.torch import load_file
                    from huggingface_hub import hf_hub_download

                    safetensors_path = hf_hub_download(repo_id=model_path, filename="model.safetensors")
                    safetensors_weights = load_file(safetensors_path)

                    latent_attn_weights = {k: v for k, v in safetensors_weights.items() if k.startswith('latent_attn.')}

                    if latent_attn_weights:
                        print(f"Found {len(latent_attn_weights)} latent attention weights in safetensors")
                        missing_keys, unexpected_keys = self.latent_attn.load_state_dict(
                            {k.replace('latent_attn.', ''): v for k, v in latent_attn_weights.items()},
                            strict=False,
                        )
                        if not missing_keys and not unexpected_keys:
                            print(f"✅ Successfully loaded {len(latent_attn_weights)} latent attention weights from safetensors")
                        else:
                            print(f"⚠️ Partial loading: missing={missing_keys}, unexpected={unexpected_keys}")
                    else:
                        print("⚠️ No latent attention weights found in safetensors file")

                except Exception as e:
                    print(f"❌ Error loading latent attention weights from safetensors: {e}")

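    # Checkpoint layout note: only entries whose keys carry the "latent_attn."
    # prefix are consumed above, e.g. a parameter stored as
    # "latent_attn.<submodule>.<param>"; all other keys are ignored.
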
    @classmethod
    def from_pretrained(
        cls,
        base_model_name_or_path,
        peft_model_name_or_path=None,
        merge_peft=False,
        enable_bidirectional=True,
        extra_model_name_or_path=None,
        **kwargs,
    ):
        # Pop encoder-specific arguments so they are not forwarded to the
        # underlying transformers from_pretrained call.
        keys = ["pooling_mode", "max_length", "doc_max_length", "skip_instruction"]
        encoder_args = {
            key: kwargs.pop(key, None) for key in keys if kwargs.get(key) is not None
        }

        # Decoder-only models have no pad token; reuse EOS and pad on the left.
        tokenizer = AutoTokenizer.from_pretrained(base_model_name_or_path)
        tokenizer.pad_token = tokenizer.eos_token
        tokenizer.padding_side = "left"

        config = AutoConfig.from_pretrained(base_model_name_or_path)
        config_class_name = config.__class__.__name__

        # Resolve the (optionally bidirectional) model class for this architecture.
        model_class = cls._get_model_class(
            config_class_name, enable_bidirectional=enable_bidirectional
        )
        model = model_class.from_pretrained(base_model_name_or_path, **kwargs)

        if os.path.isdir(base_model_name_or_path) and os.path.exists(
            f"{base_model_name_or_path}/config.json"
        ):
            with open(f"{base_model_name_or_path}/config.json", "r") as fIn:
                config_dict = json.load(fIn)
            config = PretrainedConfig.from_dict(config_dict)
            model.config._name_or_path = config._name_or_path

        # If the checkpoint already carries PEFT adapters, merge them into the
        # base weights first.
        if hasattr(model, "peft_config"):
            model = PeftModel.from_pretrained(
                model,
                base_model_name_or_path,
            )
            model = model.merge_and_unload()

        if peft_model_name_or_path is not None:
            model = PeftModel.from_pretrained(
                model,
                peft_model_name_or_path,
            )
            if merge_peft:
                model = model.merge_and_unload()

        if extra_model_name_or_path is not None:
            logger.info(f"Loading extra model from {extra_model_name_or_path}")
            if not merge_peft:
                model = model.merge_and_unload()
            if isinstance(extra_model_name_or_path, str):
                model = PeftModel.from_pretrained(
                    model,
                    extra_model_name_or_path,
                )
                model = model.merge_and_unload()
            elif isinstance(extra_model_name_or_path, list):
                for extra_model in extra_model_name_or_path:
                    model = PeftModel.from_pretrained(
                        model,
                        extra_model,
                    )
                    peft_model_name_or_path = extra_model
                    model = model.merge_and_unload()
            else:
                raise ValueError(
                    "extra_model_name_or_path should be a string or a list of strings."
                )

        config = {}
        config_addr = (
            peft_model_name_or_path
            if peft_model_name_or_path is not None
            else base_model_name_or_path
        )
        if os.path.exists(f"{config_addr}/llm2vec_config.json"):
            with open(f"{config_addr}/llm2vec_config.json", "r") as fIn:
                llm2vec_config = json.load(fIn)
            config.update(llm2vec_config)

        # Explicit encoder arguments override anything from llm2vec_config.json.
        for key, value in encoder_args.items():
            config[key] = value

        llm2vec_model = cls(model=model, tokenizer=tokenizer, **config)

        # Latent attention weights live alongside the checkpoint rather than
        # inside the base model, so load them explicitly when that pooling
        # mode is active.
        if (hasattr(llm2vec_model, 'latent_attn') and
                llm2vec_model.latent_attn is not None and
                llm2vec_model.pooling_mode == "latent_attention"):
            llm2vec_model._load_latent_attention_weights(
                base_model_name_or_path, kwargs.get('use_safetensors', True)
            )

        # Honor an explicit torch_dtype so wrapper-owned modules match the
        # base model's precision.
        if 'torch_dtype' in kwargs and kwargs['torch_dtype'] is not None:
            llm2vec_model = llm2vec_model.to(kwargs['torch_dtype'])

        return llm2vec_model
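

if __name__ == "__main__":
    # End-to-end sketch. The checkpoint name and texts are illustrative; any
    # base model supported by llm2vec (optionally with PEFT adapters) works.
    model = LLM2VecWrapper.from_pretrained(
        "meta-llama/Meta-Llama-3-8B-Instruct",
        torch_dtype=torch.bfloat16,
        pooling_mode="mean",
        max_length=512,
    )
    sep = '!@#$%^&*()'
    scores = model.compute_similarities(
        "Given a question, retrieve relevant passages:" + sep + "What is attention?",
        [
            "Attention lets a transformer weigh interactions between tokens.",
            "The capital of France is Paris.",
        ],
    )
    print(scores)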