Spaces:
Running
Running
ming
committed on
Commit
·
87d9e3a
1
Parent(s):
6e48ad3
Fix TextStreamer batch size error in V2 API
Browse files
- Add batch size validation to ensure TextIteratorStreamer receives batch size 1
- Handle 1D tensors by adding batch dimension
- Handle oversized batches by taking first sample only
- Maintains compatibility with all model types (T5, BART, etc.)
- Fixes 'TextStreamer only supports batch size 1' error
app/services/hf_streaming_summarizer.py
CHANGED
|
@@ -213,6 +213,17 @@ class HFStreamingSummarizer:
|
|
| 213 |
|
| 214 |
inputs = inputs.to(self.model.device)
|
| 215 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 216 |
# Create streamer for token-by-token output
|
| 217 |
streamer = TextIteratorStreamer(
|
| 218 |
self.tokenizer,
|
|
|
|
| 213 |
|
| 214 |
inputs = inputs.to(self.model.device)
|
| 215 |
|
| 216 |
+
# CRITICAL FIX: Ensure batch size is 1 for TextIteratorStreamer
|
| 217 |
+
# The streamer only works with batch size 1, so we need to ensure
|
| 218 |
+
# that all input tensors have batch dimension of 1
|
| 219 |
+
for key, tensor in inputs.items():
|
| 220 |
+
if tensor.dim() > 1 and tensor.size(0) > 1:
|
| 221 |
+
# If batch size > 1, take only the first sample
|
| 222 |
+
inputs[key] = tensor[:1]
|
| 223 |
+
elif tensor.dim() == 1:
|
| 224 |
+
# If tensor is 1D, add batch dimension
|
| 225 |
+
inputs[key] = tensor.unsqueeze(0)
|
| 226 |
+
|
| 227 |
# Create streamer for token-by-token output
|
| 228 |
streamer = TextIteratorStreamer(
|
| 229 |
self.tokenizer,
|