ming committed
Commit 87d9e3a · 1 Parent(s): 6e48ad3

Fix TextStreamer batch size error in V2 API

- Add batch size validation to ensure TextIteratorStreamer receives batch size 1
- Handle 1D tensors by adding a batch dimension
- Handle oversized batches by taking the first sample only
- Maintain compatibility with all model types (T5, BART, etc.)
- Fix 'TextStreamer only supports batch size 1' error
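The normalization described above can be sketched as a standalone function, assuming PyTorch tensors as produced by a Hugging Face tokenizer. The function name `ensure_batch_size_one` is hypothetical and introduced here for illustration, not part of the repository:

```python
# Minimal sketch of the batch-size normalization this commit adds.
# Assumption: `inputs` maps strings to torch.Tensor, as returned by a
# Hugging Face tokenizer with return_tensors="pt".
import torch

def ensure_batch_size_one(inputs: dict) -> dict:
    """Force every tensor in `inputs` to a leading batch dimension of 1,
    since TextIteratorStreamer rejects batch sizes greater than 1."""
    for key, tensor in inputs.items():
        if tensor.dim() > 1 and tensor.size(0) > 1:
            # Oversized batch: keep only the first sample.
            inputs[key] = tensor[:1]
        elif tensor.dim() == 1:
            # 1D tensor: add the missing batch dimension.
            inputs[key] = tensor.unsqueeze(0)
    return inputs

# Example: a 1D input_ids tensor gains a batch dimension, and a
# batch of 3 attention masks is truncated to a single sample.
batch = {
    "input_ids": torch.tensor([101, 2023, 102]),
    "attention_mask": torch.ones(3, 3, dtype=torch.long),
}
batch = ensure_batch_size_one(batch)
print(batch["input_ids"].shape)       # torch.Size([1, 3])
print(batch["attention_mask"].shape)  # torch.Size([1, 3])
```

Note that truncating to the first sample silently drops the rest of the batch; callers that need to stream multiple inputs would have to loop over samples instead.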

app/services/hf_streaming_summarizer.py CHANGED
@@ -213,6 +213,17 @@ class HFStreamingSummarizer:
 
         inputs = inputs.to(self.model.device)
 
+        # CRITICAL FIX: Ensure batch size is 1 for TextIteratorStreamer
+        # The streamer only works with batch size 1, so we need to ensure
+        # that all input tensors have batch dimension of 1
+        for key, tensor in inputs.items():
+            if tensor.dim() > 1 and tensor.size(0) > 1:
+                # If batch size > 1, take only the first sample
+                inputs[key] = tensor[:1]
+            elif tensor.dim() == 1:
+                # If tensor is 1D, add batch dimension
+                inputs[key] = tensor.unsqueeze(0)
+
         # Create streamer for token-by-token output
         streamer = TextIteratorStreamer(
             self.tokenizer,