ming committed on
Commit
6e01ea3
·
1 Parent(s): 8aac05c

feat: Add text streaming support with SSE

Browse files

- Add summarize_text_stream() async generator method to OllamaService
- Implement /api/v1/summarize/stream endpoint with Server-Sent Events
- Add StreamChunk schema for streaming response documentation
- Comprehensive test coverage with TDD approach (11 new tests)
- Android-friendly SSE format for real-time text streaming
- Maintains backward compatibility with existing non-streaming endpoint
- Proper error handling with SSE error events
- Manual verification with curl confirms working implementation

Closes: Streaming text summarization feature request

app/api/v1/schemas.py CHANGED
@@ -42,6 +42,14 @@ class HealthResponse(BaseModel):
42
  ollama: Optional[str] = Field(None, description="Ollama service status")
43
 
44
 
 
 
 
 
 
 
 
 
45
  class ErrorResponse(BaseModel):
46
  """Error response schema."""
47
 
 
42
  ollama: Optional[str] = Field(None, description="Ollama service status")
43
 
44
 
45
class StreamChunk(BaseModel):
    """Schema for streaming response chunks.

    One instance describes a single Server-Sent-Events `data:` payload
    emitted by the /summarize/stream endpoint; the keys mirror the dicts
    yielded by OllamaService.summarize_text_stream().
    """

    # Text fragment produced by the model for this event (empty string on error events).
    content: str = Field(..., description="Content chunk from the stream")
    # True on the final chunk of the stream.
    done: bool = Field(..., description="Whether this is the final chunk")
    # Cumulative token count reported by Ollama (eval_count); None if absent.
    tokens_used: Optional[int] = Field(None, description="Number of tokens used so far")
51
+
52
+
53
  class ErrorResponse(BaseModel):
54
  """Error response schema."""
55
 
app/api/v1/summarize.py CHANGED
@@ -1,7 +1,9 @@
1
  """
2
  Summarization endpoints.
3
  """
 
4
  from fastapi import APIRouter, HTTPException
 
5
  import httpx
6
  from app.api.v1.schemas import SummarizeRequest, SummarizeResponse
7
  from app.services.summarizer import ollama_service
@@ -33,3 +35,60 @@ async def summarize(payload: SummarizeRequest) -> SummarizeResponse:
33
  raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")
34
 
35
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  """
2
  Summarization endpoints.
3
  """
4
+ import json
5
  from fastapi import APIRouter, HTTPException
6
+ from fastapi.responses import StreamingResponse
7
  import httpx
8
  from app.api.v1.schemas import SummarizeRequest, SummarizeResponse
9
  from app.services.summarizer import ollama_service
 
35
  raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")
36
 
37
 
38
async def _stream_generator(payload: SummarizeRequest):
    """Yield Server-Sent-Events frames for a streaming summarization request.

    Wraps ollama_service.summarize_text_stream(); each chunk dict is JSON
    encoded into an SSE frame (``data: {...}\\n\\n``).  Once streaming has
    begun the HTTP status is already committed, so failures cannot surface
    as error status codes — they are delivered in-band as a final SSE event
    whose JSON carries an "error" key and ``done: true``.
    """

    def _error_event(message: str) -> str:
        # Terminal SSE frame shared by all failure branches (previously three
        # copy-pasted blocks).
        body = json.dumps({"content": "", "done": True, "error": message})
        return f"data: {body}\n\n"

    try:
        async for chunk in ollama_service.summarize_text_stream(
            text=payload.text,
            max_tokens=payload.max_tokens or 256,
            prompt=payload.prompt or "Summarize the following text concisely:",
        ):
            # Format as SSE event
            yield f"data: {json.dumps(chunk)}\n\n"

    # Order matters: TimeoutException is a subclass of HTTPError in httpx.
    except httpx.TimeoutException:
        yield _error_event(
            "Request timeout. The text may be too long or complex. "
            "Try reducing the text length or max_tokens."
        )
    except httpx.HTTPError as e:
        yield _error_event(f"Summarization failed: {str(e)}")
    except Exception as e:
        yield _error_event(f"Internal server error: {str(e)}")
80
+
81
+
82
@router.post("/stream")
async def summarize_stream(payload: SummarizeRequest):
    """Stream text summarization using Server-Sent Events (SSE).

    Responds 200 with media type ``text/event-stream``; the body is produced
    incrementally by ``_stream_generator``.  Errors are reported as SSE error
    events, never as HTTP error status codes.
    """
    # Headers that tell proxies/clients to treat this as a live event stream.
    sse_headers = {
        "Cache-Control": "no-cache",
        "Connection": "keep-alive",
    }
    event_source = _stream_generator(payload)
    return StreamingResponse(
        event_source,
        media_type="text/event-stream",
        headers=sse_headers,
    )
93
+
94
+
app/services/summarizer.py CHANGED
@@ -1,8 +1,9 @@
1
  """
2
  Ollama service integration for text summarization.
3
  """
 
4
  import time
5
- from typing import Dict, Any
6
  from urllib.parse import urljoin
7
 
8
  import httpx
@@ -123,6 +124,101 @@ class OllamaService:
123
  # Present a consistent error type to callers
124
  raise httpx.HTTPError(f"Ollama API error: {e}") from e
125
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
126
  async def check_health(self) -> bool:
127
  """
128
  Verify Ollama is reachable and (optionally) that the model exists.
 
1
  """
2
  Ollama service integration for text summarization.
3
  """
4
+ import json
5
  import time
6
+ from typing import Dict, Any, AsyncGenerator
7
  from urllib.parse import urljoin
8
 
9
  import httpx
 
124
  # Present a consistent error type to callers
125
  raise httpx.HTTPError(f"Ollama API error: {e}") from e
126
 
127
+ async def summarize_text_stream(
128
+ self,
129
+ text: str,
130
+ max_tokens: int = 100,
131
+ prompt: str = "Summarize concisely:",
132
+ ) -> AsyncGenerator[Dict[str, Any], None]:
133
+ """
134
+ Stream text summarization using Ollama.
135
+ Yields chunks as they arrive from Ollama.
136
+ Raises httpx.HTTPError (and subclasses) on failure.
137
+ """
138
+ start_time = time.time()
139
+
140
+ # Optimized timeout: base + 3s per extra 1000 chars (cap 90s)
141
+ text_length = len(text)
142
+ dynamic_timeout = min(self.timeout + max(0, (text_length - 1000) // 1000 * 3), 90)
143
+
144
+ # Preprocess text to reduce input size for faster processing
145
+ if text_length > 4000:
146
+ # Truncate very long texts and add note
147
+ text = text[:4000] + "\n\n[Text truncated for faster processing]"
148
+ text_length = len(text)
149
+ logger.info(f"Text truncated from {len(text)} to {text_length} chars for faster processing")
150
+
151
+ logger.info(f"Processing text of {text_length} chars with timeout {dynamic_timeout}s")
152
+
153
+ full_prompt = f"{prompt}\n\n{text}"
154
+
155
+ payload = {
156
+ "model": self.model,
157
+ "prompt": full_prompt,
158
+ "stream": True, # Enable streaming
159
+ "options": {
160
+ "num_predict": max_tokens,
161
+ "temperature": 0.1, # Lower temperature for faster, more focused output
162
+ "top_p": 0.9, # Nucleus sampling for efficiency
163
+ "top_k": 40, # Limit vocabulary for speed
164
+ "repeat_penalty": 1.1, # Prevent repetition
165
+ "num_ctx": 2048, # Limit context window for speed
166
+ },
167
+ }
168
+
169
+ generate_url = urljoin(self.base_url, "api/generate")
170
+ logger.info(f"POST {generate_url} (streaming)")
171
+
172
+ try:
173
+ async with httpx.AsyncClient(timeout=dynamic_timeout) as client:
174
+ async with client.stream("POST", generate_url, json=payload) as response:
175
+ response.raise_for_status()
176
+
177
+ async for line in response.aiter_lines():
178
+ line = line.strip()
179
+ if not line:
180
+ continue
181
+
182
+ try:
183
+ data = json.loads(line)
184
+ chunk = {
185
+ "content": data.get("response", ""),
186
+ "done": data.get("done", False),
187
+ "tokens_used": data.get("eval_count", 0),
188
+ }
189
+ yield chunk
190
+
191
+ # Break if this is the final chunk
192
+ if data.get("done", False):
193
+ break
194
+
195
+ except json.JSONDecodeError:
196
+ # Skip malformed JSON lines
197
+ logger.warning(f"Skipping malformed JSON line: {line[:100]}")
198
+ continue
199
+
200
+ except httpx.TimeoutException:
201
+ logger.error(
202
+ f"Timeout calling Ollama after {dynamic_timeout}s "
203
+ f"(chars={text_length}, url={generate_url})"
204
+ )
205
+ raise
206
+ except httpx.RequestError as e:
207
+ # Network / connection errors (DNS, refused, TLS, etc.)
208
+ logger.error(f"Request error calling Ollama at {generate_url}: {e}")
209
+ raise
210
+ except httpx.HTTPStatusError as e:
211
+ # Non-2xx responses
212
+ body = e.response.text if e.response is not None else ""
213
+ logger.error(
214
+ f"HTTP {e.response.status_code if e.response else '??'} from Ollama at {generate_url}: {body[:400]}"
215
+ )
216
+ raise
217
+ except Exception as e:
218
+ logger.error(f"Unexpected error calling Ollama at {generate_url}: {e}")
219
+ # Present a consistent error type to callers
220
+ raise httpx.HTTPError(f"Ollama API error: {e}") from e
221
+
222
  async def check_health(self) -> bool:
223
  """
224
  Verify Ollama is reachable and (optionally) that the model exists.
tests/test_api.py CHANGED
@@ -1,8 +1,9 @@
1
  """
2
  Integration tests for API endpoints.
3
  """
 
4
  import pytest
5
- from unittest.mock import patch
6
  from starlette.testclient import TestClient
7
  from app.main import app
8
 
@@ -96,4 +97,203 @@ def test_summarize_endpoint_large_text_handling():
96
  mock_client.assert_called_once()
97
  call_args = mock_client.call_args
98
  expected_timeout = 60 + (5000 - 1000) // 1000 * 5 # 80 seconds
99
- assert call_args[1]['timeout'] == expected_timeout
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  """
2
  Integration tests for API endpoints.
3
  """
4
+ import json
5
  import pytest
6
+ from unittest.mock import patch, MagicMock
7
  from starlette.testclient import TestClient
8
  from app.main import app
9
 
 
97
  mock_client.assert_called_once()
98
  call_args = mock_client.call_args
99
  expected_timeout = 60 + (5000 - 1000) // 1000 * 5 # 80 seconds
100
+ assert call_args[1]['timeout'] == expected_timeout
101
+
102
+
103
+ # Tests for Streaming Endpoint
104
@pytest.mark.integration
def test_summarize_stream_endpoint_success(sample_text):
    """Test successful streaming summarization via API endpoint."""
    # Mock streaming response data: NDJSON lines as Ollama emits in stream mode.
    mock_stream_data = [
        '{"response": "This", "done": false, "eval_count": 1}\n',
        '{"response": " is", "done": false, "eval_count": 2}\n',
        '{"response": " a", "done": false, "eval_count": 3}\n',
        '{"response": " test", "done": true, "eval_count": 4}\n'
    ]

    class MockStreamResponse:
        # Mimics the response object entered via `async with client.stream(...)`.
        def __init__(self, data):
            self.data = data

        async def aiter_lines(self):
            for line in self.data:
                yield line

        def raise_for_status(self):
            pass

    class MockStreamContextManager:
        # client.stream(...) returns an async context manager, not a response.
        def __init__(self, response):
            self.response = response

        async def __aenter__(self):
            return self.response

        async def __aexit__(self, exc_type, exc, tb):
            return False

    class MockStreamClient:
        # Stand-in for httpx.AsyncClient; itself used as `async with`.
        async def __aenter__(self):
            return self

        async def __aexit__(self, exc_type, exc, tb):
            return False

        def stream(self, method, url, **kwargs):
            return MockStreamContextManager(MockStreamResponse(mock_stream_data))

    with patch('httpx.AsyncClient', return_value=MockStreamClient()):
        resp = client.post(
            "/api/v1/summarize/stream",
            json={"text": sample_text, "max_tokens": 128}
        )
        assert resp.status_code == 200
        assert resp.headers["content-type"] == "text/event-stream; charset=utf-8"

        # Parse SSE response
        lines = resp.text.strip().split('\n')
        data_lines = [line for line in lines if line.startswith('data: ')]

        # One SSE event per mocked Ollama NDJSON line.
        assert len(data_lines) == 4

        # Parse first chunk
        first_chunk = json.loads(data_lines[0][6:])  # Remove 'data: ' prefix
        assert first_chunk["content"] == "This"
        assert first_chunk["done"] is False
        assert first_chunk["tokens_used"] == 1

        # Parse last chunk
        last_chunk = json.loads(data_lines[-1][6:])  # Remove 'data: ' prefix
        assert last_chunk["content"] == " test"
        assert last_chunk["done"] is True
        assert last_chunk["tokens_used"] == 4
171
+
172
+
173
@pytest.mark.integration
def test_summarize_stream_endpoint_validation_error():
    """Empty input text must be rejected with 422 before any streaming starts."""
    response = client.post("/api/v1/summarize/stream", json={"text": ""})
    assert response.status_code == 422
181
+
182
+
183
@pytest.mark.integration
def test_summarize_stream_endpoint_timeout_error():
    """Test that timeout errors in streaming return proper error."""
    import httpx

    class MockStreamClient:
        # Minimal async-context stand-in for httpx.AsyncClient whose
        # stream() call fails immediately with a timeout.
        async def __aenter__(self):
            return self

        async def __aexit__(self, exc_type, exc, tb):
            return False

        def stream(self, method, url, **kwargs):
            raise httpx.TimeoutException("Timeout")

    with patch('httpx.AsyncClient', return_value=MockStreamClient()):
        resp = client.post(
            "/api/v1/summarize/stream",
            json={"text": "Test text that will timeout"}
        )
        # Status is committed before the failure; the error travels in-band.
        assert resp.status_code == 200  # SSE returns 200 even with errors
        assert resp.headers["content-type"] == "text/event-stream; charset=utf-8"

        # Parse SSE response
        lines = resp.text.strip().split('\n')
        data_lines = [line for line in lines if line.startswith('data: ')]

        # Exactly one terminal error event is emitted.
        assert len(data_lines) == 1
        error_chunk = json.loads(data_lines[0][6:])  # Remove 'data: ' prefix
        assert error_chunk["done"] is True
        assert "timeout" in error_chunk["error"].lower()
214
+
215
+
216
@pytest.mark.integration
def test_summarize_stream_endpoint_http_error():
    """Test that HTTP errors in streaming return proper error."""
    import httpx

    # Pre-built error raised from inside the mocked client.
    http_error = httpx.HTTPStatusError("Bad Request", request=MagicMock(), response=MagicMock())

    class MockStreamClient:
        # Async-context stand-in for httpx.AsyncClient that fails on stream().
        async def __aenter__(self):
            return self

        async def __aexit__(self, exc_type, exc, tb):
            return False

        def stream(self, method, url, **kwargs):
            raise http_error

    with patch('httpx.AsyncClient', return_value=MockStreamClient()):
        resp = client.post(
            "/api/v1/summarize/stream",
            json={"text": "Test text"}
        )
        # Status is committed before the failure; the error travels in-band.
        assert resp.status_code == 200  # SSE returns 200 even with errors
        assert resp.headers["content-type"] == "text/event-stream; charset=utf-8"

        # Parse SSE response
        lines = resp.text.strip().split('\n')
        data_lines = [line for line in lines if line.startswith('data: ')]

        assert len(data_lines) == 1
        error_chunk = json.loads(data_lines[0][6:])  # Remove 'data: ' prefix
        assert error_chunk["done"] is True
        assert "Summarization failed" in error_chunk["error"]
249
+
250
+
251
@pytest.mark.integration
def test_summarize_stream_endpoint_sse_format():
    """Test that streaming endpoint returns proper SSE format."""
    mock_stream_data = ['{"response": "Summary", "done": true, "eval_count": 1}\n']

    class MockStreamResponse:
        # Mimics the response object entered via `async with client.stream(...)`.
        def __init__(self, data):
            self.data = data

        async def aiter_lines(self):
            for line in self.data:
                yield line

        def raise_for_status(self):
            pass

    class MockStreamContextManager:
        # client.stream(...) returns an async context manager, not a response.
        def __init__(self, response):
            self.response = response

        async def __aenter__(self):
            return self.response

        async def __aexit__(self, exc_type, exc, tb):
            return False

    class MockStreamClient:
        # Stand-in for httpx.AsyncClient; itself used as `async with`.
        async def __aenter__(self):
            return self

        async def __aexit__(self, exc_type, exc, tb):
            return False

        def stream(self, method, url, **kwargs):
            return MockStreamContextManager(MockStreamResponse(mock_stream_data))

    with patch('httpx.AsyncClient', return_value=MockStreamClient()):
        resp = client.post(
            "/api/v1/summarize/stream",
            json={"text": "Test text"}
        )
        assert resp.status_code == 200
        assert resp.headers["content-type"] == "text/event-stream; charset=utf-8"
        # Headers the endpoint sets explicitly for SSE clients/proxies.
        assert resp.headers["cache-control"] == "no-cache"
        assert resp.headers["connection"] == "keep-alive"

        # Check SSE format
        lines = resp.text.strip().split('\n')
        assert any(line.startswith('data: ') for line in lines)
tests/test_services.py CHANGED
@@ -224,3 +224,258 @@ class TestOllamaService:
224
  error_message = str(exc_info.value)
225
  assert f"timeout after {expected_timeout}s" in error_message
226
  assert "Text may be too long or complex" in error_message
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
224
  error_message = str(exc_info.value)
225
  assert f"timeout after {expected_timeout}s" in error_message
226
  assert "Text may be too long or complex" in error_message
227
+
228
+ # Tests for Streaming Functionality
229
    @pytest.mark.asyncio
    async def test_summarize_text_stream_success(self, ollama_service):
        """Test successful text streaming."""
        # Mock streaming response data: NDJSON lines as Ollama emits in stream mode.
        mock_stream_data = [
            '{"response": "This", "done": false, "eval_count": 1}\n',
            '{"response": " is", "done": false, "eval_count": 2}\n',
            '{"response": " a", "done": false, "eval_count": 3}\n',
            '{"response": " test", "done": true, "eval_count": 4}\n'
        ]

        class MockStreamResponse:
            # Mimics the response object entered via `async with client.stream(...)`.
            def __init__(self, data):
                self.data = data
                self._index = 0  # not read by aiter_lines; retained scaffolding

            async def aiter_lines(self):
                for line in self.data:
                    yield line

            def raise_for_status(self):
                # Mock successful response
                pass

        mock_response = MockStreamResponse(mock_stream_data)

        class MockStreamContextManager:
            # client.stream(...) returns an async context manager, not a response.
            def __init__(self, response):
                self.response = response

            async def __aenter__(self):
                return self.response

            async def __aexit__(self, exc_type, exc, tb):
                return False

        class MockStreamClient:
            # Stand-in for httpx.AsyncClient; itself used as `async with`.
            async def __aenter__(self):
                return self

            async def __aexit__(self, exc_type, exc, tb):
                return False

            def stream(self, method, url, **kwargs):
                # Return an async context manager
                return MockStreamContextManager(mock_response)

        with patch('httpx.AsyncClient', return_value=MockStreamClient()):
            chunks = []
            async for chunk in ollama_service.summarize_text_stream("Test text"):
                chunks.append(chunk)

        # One translated chunk per mocked NDJSON line, in order.
        assert len(chunks) == 4
        assert chunks[0]["content"] == "This"
        assert chunks[0]["done"] is False
        assert chunks[0]["tokens_used"] == 1
        assert chunks[-1]["content"] == " test"
        assert chunks[-1]["done"] is True
        assert chunks[-1]["tokens_used"] == 4
288
+
289
    @pytest.mark.asyncio
    async def test_summarize_text_stream_with_custom_params(self, ollama_service):
        """Test streaming with custom parameters."""
        mock_stream_data = ['{"response": "Summary", "done": true, "eval_count": 1}\n']

        class MockStreamResponse:
            # Mimics the response object entered via `async with client.stream(...)`.
            def __init__(self, data):
                self.data = data

            async def aiter_lines(self):
                for line in self.data:
                    yield line

            def raise_for_status(self):
                # Mock successful response
                pass

        mock_response = MockStreamResponse(mock_stream_data)
        # Filled in by MockStreamClient.stream() with the JSON body it receives.
        captured_payload = {}

        class MockStreamContextManager:
            # client.stream(...) returns an async context manager, not a response.
            def __init__(self, response):
                self.response = response

            async def __aenter__(self):
                return self.response

            async def __aexit__(self, exc_type, exc, tb):
                return False

        class MockStreamClient:
            async def __aenter__(self):
                return self

            async def __aexit__(self, exc_type, exc, tb):
                return False

            def stream(self, method, url, **kwargs):
                # Record the request payload so assertions can inspect it.
                captured_payload.update(kwargs.get('json', {}))
                return MockStreamContextManager(mock_response)

        with patch('httpx.AsyncClient', return_value=MockStreamClient()):
            chunks = []
            async for chunk in ollama_service.summarize_text_stream(
                "Test text",
                max_tokens=512,
                prompt="Custom prompt"
            ):
                chunks.append(chunk)

        # Verify captured payload
        assert captured_payload["stream"] is True
        assert captured_payload["options"]["num_predict"] == 512
        assert "Custom prompt" in captured_payload["prompt"]
343
+
344
+ @pytest.mark.asyncio
345
+ async def test_summarize_text_stream_timeout(self, ollama_service):
346
+ """Test streaming timeout handling."""
347
+ class MockStreamClient:
348
+ async def __aenter__(self):
349
+ return self
350
+
351
+ async def __aexit__(self, exc_type, exc, tb):
352
+ return False
353
+
354
+ def stream(self, method, url, **kwargs):
355
+ raise httpx.TimeoutException("Timeout")
356
+
357
+ with patch('httpx.AsyncClient', return_value=MockStreamClient()):
358
+ with pytest.raises(httpx.TimeoutException):
359
+ chunks = []
360
+ async for chunk in ollama_service.summarize_text_stream("Test text"):
361
+ chunks.append(chunk)
362
+
363
+ @pytest.mark.asyncio
364
+ async def test_summarize_text_stream_http_error(self, ollama_service):
365
+ """Test streaming HTTP error handling."""
366
+ http_error = httpx.HTTPStatusError("Bad Request", request=MagicMock(), response=MagicMock())
367
+
368
+ class MockStreamClient:
369
+ async def __aenter__(self):
370
+ return self
371
+
372
+ async def __aexit__(self, exc_type, exc, tb):
373
+ return False
374
+
375
+ def stream(self, method, url, **kwargs):
376
+ raise http_error
377
+
378
+ with patch('httpx.AsyncClient', return_value=MockStreamClient()):
379
+ with pytest.raises(httpx.HTTPStatusError):
380
+ chunks = []
381
+ async for chunk in ollama_service.summarize_text_stream("Test text"):
382
+ chunks.append(chunk)
383
+
384
    @pytest.mark.asyncio
    async def test_summarize_text_stream_empty_response(self, ollama_service):
        """Test streaming with empty response."""
        # No NDJSON lines at all — the stream ends immediately.
        mock_stream_data = []

        class MockStreamResponse:
            # Mimics the response object entered via `async with client.stream(...)`.
            def __init__(self, data):
                self.data = data

            async def aiter_lines(self):
                for line in self.data:
                    yield line

            def raise_for_status(self):
                # Mock successful response
                pass

        mock_response = MockStreamResponse(mock_stream_data)

        class MockStreamContextManager:
            # client.stream(...) returns an async context manager, not a response.
            def __init__(self, response):
                self.response = response

            async def __aenter__(self):
                return self.response

            async def __aexit__(self, exc_type, exc, tb):
                return False

        class MockStreamClient:
            async def __aenter__(self):
                return self

            async def __aexit__(self, exc_type, exc, tb):
                return False

            def stream(self, method, url, **kwargs):
                return MockStreamContextManager(mock_response)

        with patch('httpx.AsyncClient', return_value=MockStreamClient()):
            chunks = []
            async for chunk in ollama_service.summarize_text_stream("Test text"):
                chunks.append(chunk)

        # An empty upstream stream yields no chunks and raises nothing.
        assert len(chunks) == 0
429
+
430
    @pytest.mark.asyncio
    async def test_summarize_text_stream_malformed_json(self, ollama_service):
        """Test streaming with malformed JSON response."""
        # Middle line is invalid JSON; the service should skip it, not abort.
        mock_stream_data = [
            '{"response": "Valid", "done": false, "eval_count": 1}\n',
            'invalid json line\n',
            '{"response": "End", "done": true, "eval_count": 2}\n'
        ]

        class MockStreamResponse:
            # Mimics the response object entered via `async with client.stream(...)`.
            def __init__(self, data):
                self.data = data

            async def aiter_lines(self):
                for line in self.data:
                    yield line

            def raise_for_status(self):
                # Mock successful response
                pass

        mock_response = MockStreamResponse(mock_stream_data)

        class MockStreamContextManager:
            # client.stream(...) returns an async context manager, not a response.
            def __init__(self, response):
                self.response = response

            async def __aenter__(self):
                return self.response

            async def __aexit__(self, exc_type, exc, tb):
                return False

        class MockStreamClient:
            async def __aenter__(self):
                return self

            async def __aexit__(self, exc_type, exc, tb):
                return False

            def stream(self, method, url, **kwargs):
                return MockStreamContextManager(mock_response)

        with patch('httpx.AsyncClient', return_value=MockStreamClient()):
            chunks = []
            async for chunk in ollama_service.summarize_text_stream("Test text"):
                chunks.append(chunk)

        # Should skip malformed JSON and continue with valid chunks
        assert len(chunks) == 2
        assert chunks[0]["content"] == "Valid"
        assert chunks[1]["content"] == "End"