File size: 12,556 Bytes
9884884
 
 
2ed2bd7
29ed661
2ed2bd7
9884884
2ed2bd7
29ed661
 
 
 
9884884
 
 
 
2ed2bd7
9884884
 
 
 
2ed2bd7
9884884
 
2ed2bd7
9884884
 
 
 
2ed2bd7
9884884
 
 
 
 
2ed2bd7
9884884
 
 
 
2ed2bd7
9884884
 
 
 
 
2ed2bd7
 
9884884
 
 
 
 
 
 
 
2ed2bd7
9884884
 
 
 
 
 
 
2ed2bd7
9884884
 
 
2ed2bd7
9884884
 
2ed2bd7
 
 
 
 
9884884
2ed2bd7
9884884
2ed2bd7
9884884
2ed2bd7
 
 
 
 
9884884
 
2ed2bd7
 
 
 
 
9884884
 
2ed2bd7
9884884
 
2ed2bd7
9884884
 
 
2ed2bd7
9884884
 
 
2ed2bd7
9884884
 
 
2ed2bd7
9884884
 
 
 
2ed2bd7
9884884
2ed2bd7
9884884
 
2ed2bd7
9884884
 
 
 
 
2ed2bd7
9884884
 
2ed2bd7
9884884
 
 
2ed2bd7
9884884
 
 
 
 
 
 
2ed2bd7
9884884
 
 
2ed2bd7
29ed661
 
 
 
 
 
2ed2bd7
29ed661
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2ed2bd7
9884884
 
 
 
 
 
 
2ed2bd7
9884884
 
 
2ed2bd7
29ed661
 
 
 
 
 
2ed2bd7
29ed661
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2ed2bd7
9884884
 
 
2ed2bd7
9884884
 
 
 
2ed2bd7
9884884
2ed2bd7
9884884
2ed2bd7
9884884
 
2ed2bd7
 
 
 
 
9884884
 
2ed2bd7
9884884
 
 
 
 
 
2ed2bd7
9884884
 
 
 
 
2ed2bd7
9884884
 
2ed2bd7
 
 
 
 
9884884
 
2ed2bd7
9884884
 
 
 
 
 
 
 
 
 
2ed2bd7
9884884
 
 
 
2ed2bd7
9884884
 
 
 
2ed2bd7
9884884
2ed2bd7
9884884
 
2ed2bd7
9884884
 
 
2ed2bd7
9884884
 
 
 
2ed2bd7
9884884
 
 
 
2ed2bd7
9884884
 
 
 
 
 
2ed2bd7
9884884
 
 
2ed2bd7
29ed661
 
 
 
 
 
2ed2bd7
29ed661
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
"""
Tests for HuggingFace streaming summarizer improvements.
"""

from unittest.mock import MagicMock, patch

import pytest

from app.services.hf_streaming_summarizer import (
    HFStreamingSummarizer,
    _split_into_chunks,
)


class TestSplitIntoChunks:
    """Unit tests for the `_split_into_chunks` text chunking helper."""

    def test_split_short_text(self):
        """A text shorter than the chunk size comes back untouched as one chunk."""
        sample = "This is a short text."
        result = _split_into_chunks(sample, chunk_chars=100, overlap=20)

        assert len(result) == 1
        assert result[0] == sample

    def test_split_long_text(self):
        """A long text is broken into several chunks, each bounded and non-empty."""
        sample = "This is a longer text. " * 50  # ~1000 chars
        result = _split_into_chunks(sample, chunk_chars=200, overlap=50)

        assert len(result) > 1
        # Every chunk must respect the size cap and carry content.
        assert all(0 < len(piece) <= 200 for piece in result)

    def test_chunk_overlap(self):
        """Chunks produced with an overlap setting are all non-empty."""
        sample = "This is a test text for overlap testing. " * 20  # ~800 chars
        result = _split_into_chunks(sample, chunk_chars=200, overlap=50)

        if len(result) > 1:
            # Neighbouring chunks should each contain text so the overlap
            # region has something to share.
            for left, right in zip(result, result[1:]):
                assert len(left) > 0
                assert len(right) > 0

    def test_empty_text(self):
        """Splitting the empty string yields no chunks at all."""
        result = _split_into_chunks("", chunk_chars=100, overlap=20)
        assert not result  # Empty text returns empty list


class TestHFStreamingSummarizerImprovements:
    """Test improvements to HFStreamingSummarizer."""

    @pytest.fixture
    def mock_summarizer(self):
        """Build an HFStreamingSummarizer whose model and tokenizer are mocks."""
        instance = HFStreamingSummarizer()
        instance.model = MagicMock()
        instance.tokenizer = MagicMock()
        return instance

    @pytest.mark.asyncio
    async def test_recursive_summarization_long_text(self, mock_summarizer):
        """Recursive summarization streams per-chunk output for long input."""

        async def fake_single_chunk(text, max_tokens, temp, top_p, prompt):
            # Emit one content event per chunk, then a terminal done event.
            yield {
                "content": f"Summary of: {text[:50]}...",
                "done": False,
                "tokens_used": 10,
            }
            yield {"content": "", "done": True, "tokens_used": 10}

        mock_summarizer._single_chunk_summarize = fake_single_chunk

        # Text long enough (>1500 chars) to force the recursive path.
        long_text = (
            "This is a very long text that should trigger recursive summarization. "
            * 30
        )  # ~2000+ chars

        events = [
            event
            async for event in mock_summarizer._recursive_summarize(
                long_text,
                max_new_tokens=100,
                temperature=0.3,
                top_p=0.9,
                prompt="Test prompt",
            )
        ]

        # One event per text chunk plus the final summary / done signal.
        assert len(events) > 2

        # At least one streamed event must carry actual content.
        streamed = [e for e in events if e.get("content") and not e.get("done")]
        assert len(streamed) > 0

        # The stream terminates with a done marker.
        assert events[-1].get("done") is True

    @pytest.mark.asyncio
    async def test_recursive_summarization_single_chunk(self, mock_summarizer):
        """Recursive summarization also handles text that fits one chunk."""

        async def fake_single_chunk(text, max_tokens, temp, top_p, prompt):
            yield {"content": "Single chunk summary", "done": False, "tokens_used": 5}
            yield {"content": "", "done": True, "tokens_used": 5}

        mock_summarizer._single_chunk_summarize = fake_single_chunk

        # Medium-length text expected to land in a single chunk after splitting.
        sample = "This is a medium length text. " * 20  # ~600 chars

        events = [
            event
            async for event in mock_summarizer._recursive_summarize(
                sample, max_new_tokens=100, temperature=0.3, top_p=0.9, prompt="Test prompt"
            )
        ]

        # At minimum a content event followed by a done event.
        assert len(events) >= 2

        # The stream terminates with a done marker.
        assert events[-1].get("done") is True

    @pytest.mark.asyncio
    async def test_single_chunk_summarize_parameters(self, mock_summarizer):
        """_single_chunk_summarize forwards explicit generation parameters."""
        # Give the mocked tokenizer the attributes the generation path reads.
        mock_summarizer.tokenizer.model_max_length = 1024
        mock_summarizer.tokenizer.pad_token_id = 0
        mock_summarizer.tokenizer.eos_token_id = 1

        # Streamer mock that yields a fixed pair of tokens.
        fake_streamer = MagicMock()
        fake_streamer.__iter__ = MagicMock(return_value=iter(["test", "summary"]))

        with (
            patch(
                "app.services.hf_streaming_summarizer.TextIteratorStreamer",
                return_value=fake_streamer,
            ),
            patch("app.services.hf_streaming_summarizer.settings") as fake_settings,
        ):
            fake_settings.hf_model_id = "test-model"

            events = [
                event
                async for event in mock_summarizer._single_chunk_summarize(
                    "Test text",
                    max_new_tokens=80,
                    temperature=0.3,
                    top_p=0.9,
                    prompt="Test prompt",
                )
            ]

            # Content events plus the terminal done event.
            assert len(events) >= 2

            # Generation must run once with the values passed through unchanged.
            mock_summarizer.model.generate.assert_called_once()
            generate_kwargs = mock_summarizer.model.generate.call_args[1]

            assert generate_kwargs["max_new_tokens"] == 80
            assert generate_kwargs["temperature"] == 0.3
            assert generate_kwargs["top_p"] == 0.9
            assert generate_kwargs["length_penalty"] == 1.0  # Should be neutral
            assert generate_kwargs["min_new_tokens"] <= 50  # Should be conservative

    @pytest.mark.asyncio
    async def test_single_chunk_summarize_defaults(self, mock_summarizer):
        """_single_chunk_summarize falls back to sane generation defaults."""
        # Give the mocked tokenizer the attributes the generation path reads.
        mock_summarizer.tokenizer.model_max_length = 1024
        mock_summarizer.tokenizer.pad_token_id = 0
        mock_summarizer.tokenizer.eos_token_id = 1

        # Streamer mock that yields a fixed pair of tokens.
        fake_streamer = MagicMock()
        fake_streamer.__iter__ = MagicMock(return_value=iter(["test", "summary"]))

        with (
            patch(
                "app.services.hf_streaming_summarizer.TextIteratorStreamer",
                return_value=fake_streamer,
            ),
            patch("app.services.hf_streaming_summarizer.settings") as fake_settings,
        ):
            fake_settings.hf_model_id = "test-model"

            async for _ in mock_summarizer._single_chunk_summarize(
                "Test text",
                max_new_tokens=None,
                temperature=None,
                top_p=None,
                prompt="Test prompt",
            ):
                pass

            # With all parameters None, documented defaults must be used.
            mock_summarizer.model.generate.assert_called_once()
            generate_kwargs = mock_summarizer.model.generate.call_args[1]

            assert generate_kwargs["max_new_tokens"] == 80  # Default
            assert generate_kwargs["temperature"] == 0.3  # Default
            assert generate_kwargs["top_p"] == 0.9  # Default

    @pytest.mark.asyncio
    async def test_recursive_summarization_error_handling(self, mock_summarizer):
        """A chunk-level failure surfaces as a single terminal error event."""

        async def failing_single_chunk(text, max_tokens, temp, top_p, prompt):
            raise Exception("Test error")
            yield  # Unreachable; present only to make this an async generator.

        mock_summarizer._single_chunk_summarize = failing_single_chunk

        long_text = "This is a long text. " * 30

        events = [
            event
            async for event in mock_summarizer._recursive_summarize(
                long_text,
                max_new_tokens=100,
                temperature=0.3,
                top_p=0.9,
                prompt="Test prompt",
            )
        ]

        # Exactly one event: the error report, flagged done.
        assert len(events) == 1
        failure = events[0]
        assert failure.get("done") is True
        assert "error" in failure
        assert "Test error" in failure["error"]

    @pytest.mark.asyncio
    async def test_single_chunk_summarize_error_handling(self, mock_summarizer):
        """A model.generate failure surfaces as a single terminal error event."""
        mock_summarizer.model.generate.side_effect = Exception("Generation error")

        events = [
            event
            async for event in mock_summarizer._single_chunk_summarize(
                "Test text",
                max_new_tokens=80,
                temperature=0.3,
                top_p=0.9,
                prompt="Test prompt",
            )
        ]

        # Exactly one event: the error report, flagged done.
        assert len(events) == 1
        failure = events[0]
        assert failure.get("done") is True
        assert "error" in failure
        assert "Generation error" in failure["error"]


class TestHFStreamingSummarizerIntegration:
    """Integration tests for HFStreamingSummarizer improvements."""

    @pytest.mark.asyncio
    async def test_summarize_text_stream_long_text_detection(self):
        """Long input routes summarize_text_stream through the recursive path."""
        summarizer = HFStreamingSummarizer()

        async def fake_recursive(text, max_tokens, temp, top_p, prompt):
            # Stand-in for the recursive path: one content event, then done.
            yield {"content": "Recursive summary", "done": False, "tokens_used": 10}
            yield {"content": "", "done": True, "tokens_used": 10}

        summarizer._recursive_summarize = fake_recursive

        # Input above the 1500-char threshold for recursive summarization.
        long_text = "This is a very long text. " * 60  # ~1500+ chars

        events = [event async for event in summarizer.summarize_text_stream(long_text)]

        # The mocked recursive summarizer must have produced the stream.
        assert len(events) >= 2
        assert events[0]["content"] == "Recursive summary"
        assert events[-1]["done"] is True

    @pytest.mark.asyncio
    async def test_summarize_text_stream_short_text_normal_flow(self):
        """Short input takes the direct (non-recursive) streaming flow."""
        summarizer = HFStreamingSummarizer()

        # Give the summarizer mocked model/tokenizer with the attributes
        # the generation path reads.
        summarizer.model = MagicMock()
        summarizer.tokenizer = MagicMock()
        summarizer.tokenizer.model_max_length = 1024
        summarizer.tokenizer.pad_token_id = 0
        summarizer.tokenizer.eos_token_id = 1

        # Streamer mock yielding two tokens in order.
        fake_streamer = MagicMock()
        fake_streamer.__iter__ = MagicMock(return_value=iter(["short", "summary"]))

        with (
            patch(
                "app.services.hf_streaming_summarizer.TextIteratorStreamer",
                return_value=fake_streamer,
            ),
            patch("app.services.hf_streaming_summarizer.settings") as fake_settings,
        ):
            fake_settings.hf_model_id = "test-model"
            fake_settings.hf_temperature = 0.3
            fake_settings.hf_top_p = 0.9

            # Input well below the 1500-char recursion threshold.
            short_text = "This is a short text."

            events = [
                event async for event in summarizer.summarize_text_stream(short_text)
            ]

            # Tokens stream through in order, followed by a done event —
            # proving the direct flow (not the recursive path) ran.
            assert len(events) >= 2
            assert events[0]["content"] == "short"
            assert events[1]["content"] == "summary"
            assert events[-1]["done"] is True