lukeingawesome committed
Commit 50d2262 · verified · Parent: 6f83848

Fix embed_mask for texts without separator (use entire text, not empty)
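Background: Python's str.split returns a single-element list when the separator is absent, so the old fallback of "" tokenized to zero tokens and left embed_mask all zeros for such texts. A minimal sketch of the before/after behavior (the separator string below is a hypothetical stand-in, not the model's actual separator):

```python
# Hypothetical separator; the real one is defined in modeling_llm2vec4cxr.py.
separator = "!@#$"

with_sep = "Instruction: compare.!@#$Findings: no acute disease."
without_sep = "Findings: no acute disease."

print(with_sep.split(separator))
# ['Instruction: compare.', 'Findings: no acute disease.']
print(without_sep.split(separator))
# ['Findings: no acute disease.']  -> there is no parts[1]

# Old fallback: "" tokenizes to zero tokens, so embed_mask stays all
# zeros and nothing is pooled. New fallback: parts[0], the entire text.
```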

Files changed (1):
1. modeling_llm2vec4cxr.py  +6 -1
modeling_llm2vec4cxr.py CHANGED
@@ -88,11 +88,13 @@ class LLM2Vec4CXRModel(PreTrainedModel):
     def _build_separator_inputs(self, texts, max_length: int, separator: str):
         tok = self._get_tokenizer()
         # Split into [instruction | text]; we embed only the trailing "text" part.
+        # If no separator, embed the entire text.
         parts_after_sep = []
         original = []
         for t in texts:
             parts = t.split(separator)
-            parts_after_sep.append(parts[1] if len(parts) > 1 else "")
+            # If no separator found, use the entire text (not empty string)
+            parts_after_sep.append(parts[1] if len(parts) > 1 else parts[0])
             original.append("".join(parts))
 
         tokenized = tok(original, return_tensors="pt", padding=True, truncation=True, max_length=max_length)
@@ -103,6 +105,9 @@ class LLM2Vec4CXRModel(PreTrainedModel):
             m = torch.zeros_like(tokenized["attention_mask"][i])
             if len(sub["input_ids"][0]) > 0:
                 m[-len(sub["input_ids"][0]):] = 1
+            else:
+                # If tokenization resulted in 0 tokens, use attention_mask (embed everything)
+                m = tokenized["attention_mask"][i].clone()
             embed_mask = m.unsqueeze(0) if embed_mask is None else torch.cat([embed_mask, m.unsqueeze(0)], dim=0)
 
         tokenized["embed_mask"] = embed_mask