ming committed · Commit 6e01ea3 · Parent(s): 8aac05c
feat: Add text streaming support with SSE
- Add summarize_text_stream() async generator method to OllamaService
- Implement /api/v1/summarize/stream endpoint with Server-Sent Events
- Add StreamChunk schema for streaming response documentation
- Comprehensive test coverage with TDD approach (11 new tests)
- Android-friendly SSE format for real-time text streaming
- Maintains backward compatibility with existing non-streaming endpoint
- Proper error handling with SSE error events
- Manual verification with curl confirms working implementation
Closes: Streaming text summarization feature request
- app/api/v1/schemas.py +8 -0
- app/api/v1/summarize.py +59 -0
- app/services/summarizer.py +97 -1
- tests/test_api.py +202 -2
- tests/test_services.py +255 -0
app/api/v1/schemas.py CHANGED

@@ -42,6 +42,14 @@ class HealthResponse(BaseModel):
     ollama: Optional[str] = Field(None, description="Ollama service status")
 
 
+class StreamChunk(BaseModel):
+    """Schema for streaming response chunks."""
+
+    content: str = Field(..., description="Content chunk from the stream")
+    done: bool = Field(..., description="Whether this is the final chunk")
+    tokens_used: Optional[int] = Field(None, description="Number of tokens used so far")
+
+
 class ErrorResponse(BaseModel):
     """Error response schema."""
 
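For orientation, each chunk the streaming endpoint emits is one SSE data event whose JSON body matches StreamChunk. A minimal sketch of that wire format (hypothetical values; the endpoint serializes a plain dict with json.dumps, as shown in summarize.py below):

import json

# One streaming chunk as the /summarize/stream endpoint would emit it (hypothetical values).
chunk = {"content": " test", "done": True, "tokens_used": 4}
sse_event = f"data: {json.dumps(chunk)}\n\n"
print(sse_event)
# data: {"content": " test", "done": true, "tokens_used": 4}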
app/api/v1/summarize.py CHANGED

@@ -1,7 +1,9 @@
 """
 Summarization endpoints.
 """
+import json
 from fastapi import APIRouter, HTTPException
+from fastapi.responses import StreamingResponse
 import httpx
 from app.api.v1.schemas import SummarizeRequest, SummarizeResponse
 from app.services.summarizer import ollama_service
@@ -33,3 +35,60 @@ async def summarize(payload: SummarizeRequest) -> SummarizeResponse:
         raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")
 
 
+async def _stream_generator(payload: SummarizeRequest):
+    """Generator function for streaming SSE responses."""
+    try:
+        async for chunk in ollama_service.summarize_text_stream(
+            text=payload.text,
+            max_tokens=payload.max_tokens or 256,
+            prompt=payload.prompt or "Summarize the following text concisely:",
+        ):
+            # Format as SSE event
+            sse_data = json.dumps(chunk)
+            yield f"data: {sse_data}\n\n"
+
+    except httpx.TimeoutException as e:
+        # Send error event in SSE format
+        error_chunk = {
+            "content": "",
+            "done": True,
+            "error": "Request timeout. The text may be too long or complex. Try reducing the text length or max_tokens."
+        }
+        sse_data = json.dumps(error_chunk)
+        yield f"data: {sse_data}\n\n"
+        return  # Don't raise exception in streaming context
+    except httpx.HTTPError as e:
+        # Send error event in SSE format
+        error_chunk = {
+            "content": "",
+            "done": True,
+            "error": f"Summarization failed: {str(e)}"
+        }
+        sse_data = json.dumps(error_chunk)
+        yield f"data: {sse_data}\n\n"
+        return  # Don't raise exception in streaming context
+    except Exception as e:
+        # Send error event in SSE format
+        error_chunk = {
+            "content": "",
+            "done": True,
+            "error": f"Internal server error: {str(e)}"
+        }
+        sse_data = json.dumps(error_chunk)
+        yield f"data: {sse_data}\n\n"
+        return  # Don't raise exception in streaming context
+
+
+@router.post("/stream")
+async def summarize_stream(payload: SummarizeRequest):
+    """Stream text summarization using Server-Sent Events (SSE)."""
+    return StreamingResponse(
+        _stream_generator(payload),
+        media_type="text/event-stream",
+        headers={
+            "Cache-Control": "no-cache",
+            "Connection": "keep-alive",
+        }
+    )
+
+
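For reference, a minimal Python client sketch for this endpoint (not part of the commit; it assumes the API is served at http://localhost:8000 and that httpx is installed):

import asyncio
import json

import httpx


async def consume_summary_stream() -> None:
    async with httpx.AsyncClient(timeout=None) as client:
        async with client.stream(
            "POST",
            "http://localhost:8000/api/v1/summarize/stream",  # assumed local address
            json={"text": "Long article text...", "max_tokens": 128},
        ) as response:
            async for line in response.aiter_lines():
                if not line.startswith("data: "):
                    continue  # skip blank separator lines between events
                chunk = json.loads(line[len("data: "):])
                print(chunk["content"], end="", flush=True)
                if chunk["done"]:
                    break


asyncio.run(consume_summary_stream())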
app/services/summarizer.py CHANGED

@@ -1,8 +1,9 @@
 """
 Ollama service integration for text summarization.
 """
+import json
 import time
-from typing import Dict, Any
+from typing import Dict, Any, AsyncGenerator
 from urllib.parse import urljoin
 
 import httpx
@@ -123,6 +124,101 @@ class OllamaService:
             # Present a consistent error type to callers
             raise httpx.HTTPError(f"Ollama API error: {e}") from e
 
+    async def summarize_text_stream(
+        self,
+        text: str,
+        max_tokens: int = 100,
+        prompt: str = "Summarize concisely:",
+    ) -> AsyncGenerator[Dict[str, Any], None]:
+        """
+        Stream text summarization using Ollama.
+        Yields chunks as they arrive from Ollama.
+        Raises httpx.HTTPError (and subclasses) on failure.
+        """
+        start_time = time.time()
+
+        # Dynamic timeout: base + 3s per extra 1000 chars (capped at 90s)
+        text_length = len(text)
+        dynamic_timeout = min(self.timeout + max(0, (text_length - 1000) // 1000 * 3), 90)
+
+        # Preprocess text to reduce input size for faster processing
+        if text_length > 4000:
+            # Truncate very long texts and add a note
+            text = text[:4000] + "\n\n[Text truncated for faster processing]"
+            logger.info(f"Text truncated from {text_length} to {len(text)} chars for faster processing")
+            text_length = len(text)
+
+        logger.info(f"Processing text of {text_length} chars with timeout {dynamic_timeout}s")
+
+        full_prompt = f"{prompt}\n\n{text}"
+
+        payload = {
+            "model": self.model,
+            "prompt": full_prompt,
+            "stream": True,  # Enable streaming
+            "options": {
+                "num_predict": max_tokens,
+                "temperature": 0.1,  # Lower temperature for faster, more focused output
+                "top_p": 0.9,  # Nucleus sampling for efficiency
+                "top_k": 40,  # Limit vocabulary for speed
+                "repeat_penalty": 1.1,  # Prevent repetition
+                "num_ctx": 2048,  # Limit context window for speed
+            },
+        }
+
+        generate_url = urljoin(self.base_url, "api/generate")
+        logger.info(f"POST {generate_url} (streaming)")
+
+        try:
+            async with httpx.AsyncClient(timeout=dynamic_timeout) as client:
+                async with client.stream("POST", generate_url, json=payload) as response:
+                    response.raise_for_status()
+
+                    async for line in response.aiter_lines():
+                        line = line.strip()
+                        if not line:
+                            continue
+
+                        try:
+                            data = json.loads(line)
+                            chunk = {
+                                "content": data.get("response", ""),
+                                "done": data.get("done", False),
+                                "tokens_used": data.get("eval_count", 0),
+                            }
+                            yield chunk
+
+                            # Break if this is the final chunk
+                            if data.get("done", False):
+                                break
+
+                        except json.JSONDecodeError:
+                            # Skip malformed JSON lines
+                            logger.warning(f"Skipping malformed JSON line: {line[:100]}")
+                            continue
+
+        except httpx.TimeoutException:
+            logger.error(
+                f"Timeout calling Ollama after {dynamic_timeout}s "
+                f"(chars={text_length}, url={generate_url})"
+            )
+            raise
+        except httpx.RequestError as e:
+            # Network / connection errors (DNS, refused, TLS, etc.)
+            logger.error(f"Request error calling Ollama at {generate_url}: {e}")
+            raise
+        except httpx.HTTPStatusError as e:
+            # Non-2xx responses
+            body = e.response.text if e.response is not None else ""
+            logger.error(
+                f"HTTP {e.response.status_code if e.response else '??'} from Ollama at {generate_url}: {body[:400]}"
+            )
+            raise
+        except Exception as e:
+            logger.error(f"Unexpected error calling Ollama at {generate_url}: {e}")
+            # Present a consistent error type to callers
+            raise httpx.HTTPError(f"Ollama API error: {e}") from e
+
     async def check_health(self) -> bool:
         """
         Verify Ollama is reachable and (optionally) that the model exists.
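As a usage sketch (assuming the module-level ollama_service instance that summarize.py imports), the generator is consumed with async for; the final chunk carries done=True and the running eval_count:

import asyncio

from app.services.summarizer import ollama_service


async def main() -> None:
    parts = []
    async for chunk in ollama_service.summarize_text_stream(
        "Some long input text...", max_tokens=128
    ):
        parts.append(chunk["content"])  # partial output, in arrival order
        if chunk["done"]:
            print(f"\ntokens used: {chunk['tokens_used']}")
    print("".join(parts))


asyncio.run(main())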
tests/test_api.py CHANGED

@@ -1,8 +1,9 @@
 """
 Integration tests for API endpoints.
 """
+import json
 import pytest
-from unittest.mock import patch
+from unittest.mock import patch, MagicMock
 from starlette.testclient import TestClient
 from app.main import app
 
@@ -96,4 +97,203 @@ def test_summarize_endpoint_large_text_handling():
     mock_client.assert_called_once()
     call_args = mock_client.call_args
     expected_timeout = 60 + (5000 - 1000) // 1000 * 5  # 80 seconds
-    assert call_args[1]['timeout'] == expected_timeout
+    assert call_args[1]['timeout'] == expected_timeout
+
+
+# Tests for Streaming Endpoint
+@pytest.mark.integration
+def test_summarize_stream_endpoint_success(sample_text):
+    """Test successful streaming summarization via API endpoint."""
+    # Mock streaming response data
+    mock_stream_data = [
+        '{"response": "This", "done": false, "eval_count": 1}\n',
+        '{"response": " is", "done": false, "eval_count": 2}\n',
+        '{"response": " a", "done": false, "eval_count": 3}\n',
+        '{"response": " test", "done": true, "eval_count": 4}\n'
+    ]
+
+    class MockStreamResponse:
+        def __init__(self, data):
+            self.data = data
+
+        async def aiter_lines(self):
+            for line in self.data:
+                yield line
+
+        def raise_for_status(self):
+            pass
+
+    class MockStreamContextManager:
+        def __init__(self, response):
+            self.response = response
+
+        async def __aenter__(self):
+            return self.response
+
+        async def __aexit__(self, exc_type, exc, tb):
+            return False
+
+    class MockStreamClient:
+        async def __aenter__(self):
+            return self
+
+        async def __aexit__(self, exc_type, exc, tb):
+            return False
+
+        def stream(self, method, url, **kwargs):
+            return MockStreamContextManager(MockStreamResponse(mock_stream_data))
+
+    with patch('httpx.AsyncClient', return_value=MockStreamClient()):
+        resp = client.post(
+            "/api/v1/summarize/stream",
+            json={"text": sample_text, "max_tokens": 128}
+        )
+        assert resp.status_code == 200
+        assert resp.headers["content-type"] == "text/event-stream; charset=utf-8"
+
+        # Parse SSE response
+        lines = resp.text.strip().split('\n')
+        data_lines = [line for line in lines if line.startswith('data: ')]
+
+        assert len(data_lines) == 4
+
+        # Parse first chunk
+        first_chunk = json.loads(data_lines[0][6:])  # Remove 'data: ' prefix
+        assert first_chunk["content"] == "This"
+        assert first_chunk["done"] is False
+        assert first_chunk["tokens_used"] == 1
+
+        # Parse last chunk
+        last_chunk = json.loads(data_lines[-1][6:])  # Remove 'data: ' prefix
+        assert last_chunk["content"] == " test"
+        assert last_chunk["done"] is True
+        assert last_chunk["tokens_used"] == 4
+
+
+@pytest.mark.integration
+def test_summarize_stream_endpoint_validation_error():
+    """Test validation error for empty text in streaming endpoint."""
+    resp = client.post(
+        "/api/v1/summarize/stream",
+        json={"text": ""}
+    )
+    assert resp.status_code == 422
+
+
+@pytest.mark.integration
+def test_summarize_stream_endpoint_timeout_error():
+    """Test that timeout errors in streaming return proper error."""
+    import httpx
+
+    class MockStreamClient:
+        async def __aenter__(self):
+            return self
+
+        async def __aexit__(self, exc_type, exc, tb):
+            return False
+
+        def stream(self, method, url, **kwargs):
+            raise httpx.TimeoutException("Timeout")
+
+    with patch('httpx.AsyncClient', return_value=MockStreamClient()):
+        resp = client.post(
+            "/api/v1/summarize/stream",
+            json={"text": "Test text that will timeout"}
+        )
+        assert resp.status_code == 200  # SSE returns 200 even with errors
+        assert resp.headers["content-type"] == "text/event-stream; charset=utf-8"
+
+        # Parse SSE response
+        lines = resp.text.strip().split('\n')
+        data_lines = [line for line in lines if line.startswith('data: ')]
+
+        assert len(data_lines) == 1
+        error_chunk = json.loads(data_lines[0][6:])  # Remove 'data: ' prefix
+        assert error_chunk["done"] is True
+        assert "timeout" in error_chunk["error"].lower()
+
+
+@pytest.mark.integration
+def test_summarize_stream_endpoint_http_error():
+    """Test that HTTP errors in streaming return proper error."""
+    import httpx
+
+    http_error = httpx.HTTPStatusError("Bad Request", request=MagicMock(), response=MagicMock())
+
+    class MockStreamClient:
+        async def __aenter__(self):
+            return self
+
+        async def __aexit__(self, exc_type, exc, tb):
+            return False
+
+        def stream(self, method, url, **kwargs):
+            raise http_error
+
+    with patch('httpx.AsyncClient', return_value=MockStreamClient()):
+        resp = client.post(
+            "/api/v1/summarize/stream",
+            json={"text": "Test text"}
+        )
+        assert resp.status_code == 200  # SSE returns 200 even with errors
+        assert resp.headers["content-type"] == "text/event-stream; charset=utf-8"
+
+        # Parse SSE response
+        lines = resp.text.strip().split('\n')
+        data_lines = [line for line in lines if line.startswith('data: ')]
+
+        assert len(data_lines) == 1
+        error_chunk = json.loads(data_lines[0][6:])  # Remove 'data: ' prefix
+        assert error_chunk["done"] is True
+        assert "Summarization failed" in error_chunk["error"]
+
+
+@pytest.mark.integration
+def test_summarize_stream_endpoint_sse_format():
+    """Test that streaming endpoint returns proper SSE format."""
+    mock_stream_data = ['{"response": "Summary", "done": true, "eval_count": 1}\n']
+
+    class MockStreamResponse:
+        def __init__(self, data):
+            self.data = data
+
+        async def aiter_lines(self):
+            for line in self.data:
+                yield line
+
+        def raise_for_status(self):
+            pass
+
+    class MockStreamContextManager:
+        def __init__(self, response):
+            self.response = response
+
+        async def __aenter__(self):
+            return self.response
+
+        async def __aexit__(self, exc_type, exc, tb):
+            return False
+
+    class MockStreamClient:
+        async def __aenter__(self):
+            return self
+
+        async def __aexit__(self, exc_type, exc, tb):
+            return False
+
+        def stream(self, method, url, **kwargs):
+            return MockStreamContextManager(MockStreamResponse(mock_stream_data))
+
+    with patch('httpx.AsyncClient', return_value=MockStreamClient()):
+        resp = client.post(
+            "/api/v1/summarize/stream",
+            json={"text": "Test text"}
+        )
+        assert resp.status_code == 200
+        assert resp.headers["content-type"] == "text/event-stream; charset=utf-8"
+        assert resp.headers["cache-control"] == "no-cache"
+        assert resp.headers["connection"] == "keep-alive"
+
+        # Check SSE format
+        lines = resp.text.strip().split('\n')
+        assert any(line.startswith('data: ') for line in lines)
tests/test_services.py CHANGED

@@ -224,3 +224,258 @@ class TestOllamaService:
         error_message = str(exc_info.value)
         assert f"timeout after {expected_timeout}s" in error_message
         assert "Text may be too long or complex" in error_message
+
+    # Tests for Streaming Functionality
+    @pytest.mark.asyncio
+    async def test_summarize_text_stream_success(self, ollama_service):
+        """Test successful text streaming."""
+        # Mock streaming response data
+        mock_stream_data = [
+            '{"response": "This", "done": false, "eval_count": 1}\n',
+            '{"response": " is", "done": false, "eval_count": 2}\n',
+            '{"response": " a", "done": false, "eval_count": 3}\n',
+            '{"response": " test", "done": true, "eval_count": 4}\n'
+        ]
+
+        class MockStreamResponse:
+            def __init__(self, data):
+                self.data = data
+                self._index = 0
+
+            async def aiter_lines(self):
+                for line in self.data:
+                    yield line
+
+            def raise_for_status(self):
+                # Mock successful response
+                pass
+
+        mock_response = MockStreamResponse(mock_stream_data)
+
+        class MockStreamContextManager:
+            def __init__(self, response):
+                self.response = response
+
+            async def __aenter__(self):
+                return self.response
+
+            async def __aexit__(self, exc_type, exc, tb):
+                return False
+
+        class MockStreamClient:
+            async def __aenter__(self):
+                return self
+
+            async def __aexit__(self, exc_type, exc, tb):
+                return False
+
+            def stream(self, method, url, **kwargs):
+                # Return an async context manager
+                return MockStreamContextManager(mock_response)
+
+        with patch('httpx.AsyncClient', return_value=MockStreamClient()):
+            chunks = []
+            async for chunk in ollama_service.summarize_text_stream("Test text"):
+                chunks.append(chunk)
+
+        assert len(chunks) == 4
+        assert chunks[0]["content"] == "This"
+        assert chunks[0]["done"] is False
+        assert chunks[0]["tokens_used"] == 1
+        assert chunks[-1]["content"] == " test"
+        assert chunks[-1]["done"] is True
+        assert chunks[-1]["tokens_used"] == 4
+
+    @pytest.mark.asyncio
+    async def test_summarize_text_stream_with_custom_params(self, ollama_service):
+        """Test streaming with custom parameters."""
+        mock_stream_data = ['{"response": "Summary", "done": true, "eval_count": 1}\n']
+
+        class MockStreamResponse:
+            def __init__(self, data):
+                self.data = data
+
+            async def aiter_lines(self):
+                for line in self.data:
+                    yield line
+
+            def raise_for_status(self):
+                # Mock successful response
+                pass
+
+        mock_response = MockStreamResponse(mock_stream_data)
+        captured_payload = {}
+
+        class MockStreamContextManager:
+            def __init__(self, response):
+                self.response = response
+
+            async def __aenter__(self):
+                return self.response
+
+            async def __aexit__(self, exc_type, exc, tb):
+                return False
+
+        class MockStreamClient:
+            async def __aenter__(self):
+                return self
+
+            async def __aexit__(self, exc_type, exc, tb):
+                return False
+
+            def stream(self, method, url, **kwargs):
+                captured_payload.update(kwargs.get('json', {}))
+                return MockStreamContextManager(mock_response)
+
+        with patch('httpx.AsyncClient', return_value=MockStreamClient()):
+            chunks = []
+            async for chunk in ollama_service.summarize_text_stream(
+                "Test text",
+                max_tokens=512,
+                prompt="Custom prompt"
+            ):
+                chunks.append(chunk)
+
+        # Verify captured payload
+        assert captured_payload["stream"] is True
+        assert captured_payload["options"]["num_predict"] == 512
+        assert "Custom prompt" in captured_payload["prompt"]
+
+    @pytest.mark.asyncio
+    async def test_summarize_text_stream_timeout(self, ollama_service):
+        """Test streaming timeout handling."""
+        class MockStreamClient:
+            async def __aenter__(self):
+                return self
+
+            async def __aexit__(self, exc_type, exc, tb):
+                return False
+
+            def stream(self, method, url, **kwargs):
+                raise httpx.TimeoutException("Timeout")
+
+        with patch('httpx.AsyncClient', return_value=MockStreamClient()):
+            with pytest.raises(httpx.TimeoutException):
+                chunks = []
+                async for chunk in ollama_service.summarize_text_stream("Test text"):
+                    chunks.append(chunk)
+
+    @pytest.mark.asyncio
+    async def test_summarize_text_stream_http_error(self, ollama_service):
+        """Test streaming HTTP error handling."""
+        http_error = httpx.HTTPStatusError("Bad Request", request=MagicMock(), response=MagicMock())
+
+        class MockStreamClient:
+            async def __aenter__(self):
+                return self
+
+            async def __aexit__(self, exc_type, exc, tb):
+                return False
+
+            def stream(self, method, url, **kwargs):
+                raise http_error
+
+        with patch('httpx.AsyncClient', return_value=MockStreamClient()):
+            with pytest.raises(httpx.HTTPStatusError):
+                chunks = []
+                async for chunk in ollama_service.summarize_text_stream("Test text"):
+                    chunks.append(chunk)
+
+    @pytest.mark.asyncio
+    async def test_summarize_text_stream_empty_response(self, ollama_service):
+        """Test streaming with empty response."""
+        mock_stream_data = []
+
+        class MockStreamResponse:
+            def __init__(self, data):
+                self.data = data
+
+            async def aiter_lines(self):
+                for line in self.data:
+                    yield line
+
+            def raise_for_status(self):
+                # Mock successful response
+                pass
+
+        mock_response = MockStreamResponse(mock_stream_data)
+
+        class MockStreamContextManager:
+            def __init__(self, response):
+                self.response = response
+
+            async def __aenter__(self):
+                return self.response
+
+            async def __aexit__(self, exc_type, exc, tb):
+                return False
+
+        class MockStreamClient:
+            async def __aenter__(self):
+                return self
+
+            async def __aexit__(self, exc_type, exc, tb):
+                return False
+
+            def stream(self, method, url, **kwargs):
+                return MockStreamContextManager(mock_response)
+
+        with patch('httpx.AsyncClient', return_value=MockStreamClient()):
+            chunks = []
+            async for chunk in ollama_service.summarize_text_stream("Test text"):
+                chunks.append(chunk)
+
+        assert len(chunks) == 0
+
+    @pytest.mark.asyncio
+    async def test_summarize_text_stream_malformed_json(self, ollama_service):
+        """Test streaming with malformed JSON response."""
+        mock_stream_data = [
+            '{"response": "Valid", "done": false, "eval_count": 1}\n',
+            'invalid json line\n',
+            '{"response": "End", "done": true, "eval_count": 2}\n'
+        ]
+
+        class MockStreamResponse:
+            def __init__(self, data):
+                self.data = data
+
+            async def aiter_lines(self):
+                for line in self.data:
+                    yield line
+
+            def raise_for_status(self):
+                # Mock successful response
+                pass
+
+        mock_response = MockStreamResponse(mock_stream_data)
+
+        class MockStreamContextManager:
+            def __init__(self, response):
+                self.response = response
+
+            async def __aenter__(self):
+                return self.response
+
+            async def __aexit__(self, exc_type, exc, tb):
+                return False
+
+        class MockStreamClient:
+            async def __aenter__(self):
+                return self
+
+            async def __aexit__(self, exc_type, exc, tb):
+                return False
+
+            def stream(self, method, url, **kwargs):
+                return MockStreamContextManager(mock_response)
+
+        with patch('httpx.AsyncClient', return_value=MockStreamClient()):
+            chunks = []
+            async for chunk in ollama_service.summarize_text_stream("Test text"):
+                chunks.append(chunk)
+
+        # Should skip malformed JSON and continue with valid chunks
+        assert len(chunks) == 2
+        assert chunks[0]["content"] == "Valid"
+        assert chunks[1]["content"] == "End"