ming committed on
Commit 3570bfd · 1 Parent(s): 87d9e3a

Fix HF streaming crash: enforce batch size = 1 for TextIteratorStreamer


- Normalize tokenizer outputs to dict format (handles tensor returns from apply_chat_template)
- Enforce batch size == 1 for all input tensors (add batch dim if 1D, trim if > 1)
- Add num_return_sequences=1 to gen_kwargs for streamer safety
- Add regression test to verify batch size enforcement

Fixes crash: 'TextStreamer only supports batch size 1'
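
For context on the first two bullets: in recent transformers versions, apply_chat_template(..., tokenize=True, return_tensors="pt") returns a bare input_ids tensor unless return_dict=True is passed, while a plain tokenizer call returns the dict-like BatchEncoding. The normalization can be sketched in isolation roughly as follows (a hypothetical standalone helper, not the repo's code; note that BatchEncoding subclasses UserDict rather than dict, so collections.abc.Mapping is the safer check for dict-like outputs):

    # Minimal sketch of the normalization, assuming torch tensors as values.
    # normalize_inputs() is a hypothetical helper name, not from this repo.
    from collections.abc import Mapping

    def normalize_inputs(inputs_raw, device):
        """Coerce tokenizer output (Mapping/BatchEncoding or bare tensor)
        into a batch-size-1 dict of tensors, as TextIteratorStreamer requires."""
        # BatchEncoding is a UserDict, not a dict subclass, so test Mapping
        # rather than dict to catch it.
        if isinstance(inputs_raw, Mapping):
            inputs = dict(inputs_raw)
        else:
            # apply_chat_template(tokenize=True) returns a bare input_ids tensor
            inputs = {"input_ids": inputs_raw}
        inputs = {k: v.to(device) for k, v in inputs.items()}
        for k, v in list(inputs.items()):
            if v.dim() == 1:                      # [seq] -> [1, seq]
                inputs[k] = v.unsqueeze(0)
            elif v.dim() >= 2 and v.size(0) > 1:  # [B, ...] -> [1, ...]
                inputs[k] = v[:1]
        return inputs

    # Hypothetical usage:
    # inputs = normalize_inputs(tokenizer(text, return_tensors="pt"), model.device)
    # assert all(v.size(0) == 1 for v in inputs.values())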

app/services/hf_streaming_summarizer.py CHANGED
@@ -172,57 +172,40 @@ class HFStreamingSummarizer:
         temperature = temperature or settings.hf_temperature
         top_p = top_p or settings.hf_top_p
 
-        # Check if model is t5 (doesn't use chat templates)
+        # --- Build tokenized inputs robustly ---
         if "t5" in settings.hf_model_id.lower():
-            # t5 models use simple prompt format for summarization
             full_prompt = f"summarize: {text}"
-            inputs = self.tokenizer(
-                full_prompt,
-                return_tensors="pt",
-                max_length=512,
-                truncation=True
-            )
+            inputs_raw = self.tokenizer(full_prompt, return_tensors="pt", max_length=512, truncation=True)
         elif "bart" in settings.hf_model_id.lower():
-            # BART models (including DistilBART) expect direct text input
-            # No prefixes or chat templates needed
-            inputs = self.tokenizer(
-                text,
-                return_tensors="pt",
-                max_length=1024,
-                truncation=True
-            )
+            inputs_raw = self.tokenizer(text, return_tensors="pt", max_length=1024, truncation=True)
         else:
-            # Other models use chat template
             messages = [
                 {"role": "system", "content": prompt},
-                {"role": "user", "content": text}
+                {"role": "user", "content": text},
             ]
-
-            # Apply chat template if available, otherwise use simple prompt
             if hasattr(self.tokenizer, "apply_chat_template") and self.tokenizer.chat_template:
-                inputs = self.tokenizer.apply_chat_template(
-                    messages,
-                    tokenize=True,
-                    add_generation_prompt=True,
-                    return_tensors="pt"
+                inputs_raw = self.tokenizer.apply_chat_template(
+                    messages, tokenize=True, add_generation_prompt=True, return_tensors="pt"
                 )
             else:
-                # Fallback to simple prompt format
                 full_prompt = f"{prompt}\n\n{text}"
-                inputs = self.tokenizer(full_prompt, return_tensors="pt")
-
-        inputs = inputs.to(self.model.device)
-
-        # CRITICAL FIX: Ensure batch size is 1 for TextIteratorStreamer
-        # The streamer only works with batch size 1, so we need to ensure
-        # that all input tensors have batch dimension of 1
-        for key, tensor in inputs.items():
-            if tensor.dim() > 1 and tensor.size(0) > 1:
-                # If batch size > 1, take only the first sample
-                inputs[key] = tensor[:1]
-            elif tensor.dim() == 1:
-                # If tensor is 1D, add batch dimension
-                inputs[key] = tensor.unsqueeze(0)
+                inputs_raw = self.tokenizer(full_prompt, return_tensors="pt")
+
+        # Normalize to dict regardless of tokenizer return type
+        if isinstance(inputs_raw, dict):
+            inputs = inputs_raw
+        else:
+            inputs = {"input_ids": inputs_raw}
+
+        # Move to model device
+        inputs = {k: v.to(self.model.device) for k, v in inputs.items()}
+
+        # Enforce batch size == 1 for streamer safety
+        for k, v in list(inputs.items()):
+            if v.dim() == 1:
+                inputs[k] = v.unsqueeze(0)  # [seq] -> [1, seq]
+            elif v.dim() >= 2 and v.size(0) > 1:
+                inputs[k] = v[:1]  # [B, ...] -> [1, ...]
 
         # Create streamer for token-by-token output
         streamer = TextIteratorStreamer(
@@ -241,6 +224,8 @@ class HFStreamingSummarizer:
             "top_p": top_p,
             "pad_token_id": self.tokenizer.pad_token_id or self.tokenizer.eos_token_id,
         }
+        # Streamer only supports a single sequence
+        gen_kwargs["num_return_sequences"] = 1
 
         # Run generation in background thread
         generation_thread = threading.Thread(
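
The quoted error comes from transformers' TextStreamer.put(), which raises as soon as it receives a tensor whose leading batch dimension exceeds 1; TextIteratorStreamer inherits that check. A minimal repro of the pre-fix failure mode (the model id here is illustrative, not necessarily the one this service is configured with):

    # Repro of the crash this commit guards against (illustrative model id).
    import torch
    from transformers import AutoTokenizer, TextIteratorStreamer

    tokenizer = AutoTokenizer.from_pretrained("sshleifer/distilbart-cnn-12-6")
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

    batched = torch.ones((2, 8), dtype=torch.long)  # batch size 2
    streamer.put(batched)  # ValueError: TextStreamer only supports batch size 1

Pinning num_return_sequences=1 in gen_kwargs closes the same hole from the generation side, since extra return sequences would reach the streamer as a batch.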
tests/test_hf_streaming.py CHANGED
@@ -119,6 +119,30 @@ class TestHFStreamingSummarizer:
             # Expected when torch is not available
             pass
 
+    @pytest.mark.asyncio
+    async def test_streaming_single_batch(self):
+        """Test that streaming enforces batch size = 1 and completes successfully."""
+        service = HFStreamingSummarizer()
+
+        # Skip if model not initialized (transformers not available)
+        if not service.model or not service.tokenizer:
+            pytest.skip("Transformers not available")
+
+        chunks = []
+        async for chunk in service.summarize_text_stream(
+            text="This is a short test article about New Zealand tech news.",
+            max_new_tokens=32,
+            temperature=0.7,
+            top_p=0.9,
+            prompt="Summarize:"
+        ):
+            chunks.append(chunk)
+
+        # Should complete without ValueError and have a final done=True
+        assert len(chunks) > 0
+        assert any(c.get("done") for c in chunks)
+        assert all("error" not in c or c.get("error") is None for c in chunks if not c.get("done"))
+
 
 class TestHFStreamingServiceIntegration:
     """Test the global HF streaming service instance."""