Fix embed_mask for texts without separator (use entire text, not empty)
modeling_llm2vec4cxr.py  CHANGED  (+6 -1)
@@ -88,11 +88,13 @@ class LLM2Vec4CXRModel(PreTrainedModel):
     def _build_separator_inputs(self, texts, max_length: int, separator: str):
         tok = self._get_tokenizer()
         # Split into [instruction | text]; we embed only the trailing "text" part.
+        # If no separator, embed the entire text.
         parts_after_sep = []
         original = []
         for t in texts:
             parts = t.split(separator)
-            parts_after_sep.append(parts[1] if len(parts) > 1 else "")
+            # If no separator found, use the entire text (not empty string)
+            parts_after_sep.append(parts[1] if len(parts) > 1 else parts[0])
             original.append("".join(parts))
 
         tokenized = tok(original, return_tensors="pt", padding=True, truncation=True, max_length=max_length)
@@ -103,6 +105,9 @@ class LLM2Vec4CXRModel(PreTrainedModel):
             m = torch.zeros_like(tokenized["attention_mask"][i])
             if len(sub["input_ids"][0]) > 0:
                 m[-len(sub["input_ids"][0]):] = 1
+            else:
+                # If tokenization resulted in 0 tokens, use attention_mask (embed everything)
+                m = tokenized["attention_mask"][i].clone()
             embed_mask = m.unsqueeze(0) if embed_mask is None else torch.cat([embed_mask, m.unsqueeze(0)], dim=0)
 
         tokenized["embed_mask"] = embed_mask
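For illustration only, a rough sketch of how the corrected fallback behaves. The separator string, the example sentences, and the word-count stand-in for the tokenizer are assumptions made for this sketch, not taken from the model code:

import torch

# Assumed separator and example texts (hypothetical, for illustration only).
separator = "!@#$%^&*()"
texts = [
    "Determine the change of pleural effusion." + separator + "No effusion is seen.",
    "No effusion is seen.",  # no separator present in the text
]

for t in texts:
    parts = t.split(separator)
    # Fixed behavior: fall back to the whole text when the separator is absent
    # (previously this produced "" and therefore an all-zero mask).
    trailing = parts[1] if len(parts) > 1 else parts[0]

    # Word counts stand in for token counts, just to show the mask shape.
    full_len = len("".join(parts).split())
    tail_len = len(trailing.split())

    m = torch.zeros(full_len, dtype=torch.long)
    if tail_len > 0:
        m[-tail_len:] = 1   # mark only the trailing "text" part
    else:
        m = torch.ones(full_len, dtype=torch.long)  # mirrors the new zero-token fallback
    print(m.tolist())

With the old "" fallback, a text without the separator produced an all-zero embed_mask, presumably leaving pooling with no tokens to embed; after this change the mask covers the entire input, and the second hunk adds the same safety net when the trailing part tokenizes to zero tokens.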