ming committed
Commit 9884884 · 1 Parent(s): 698636a

Improve V2 summarization: adaptive tokens, recursive summarization, better defaults


- Add temperature and top_p parameters to SummarizeRequest schema (defaults: 0.3, 0.9)
- Implement adaptive max_new_tokens logic: 60-100 for short texts, proportional for long texts
- Add recursive summarization for texts >1500 chars with chunking and summary-of-summaries
- Fix generation parameters to prevent rambling (reduce min_new_tokens, neutral length_penalty)
- Update default prompt to be more concise
- Add comprehensive unit tests for all improvements (40 tests)
- Fix V2 API to generate concise summaries instead of rambling output

app/api/v1/schemas.py CHANGED
@@ -10,8 +10,10 @@ class SummarizeRequest(BaseModel):
 
     text: str = Field(..., min_length=1, max_length=32000, description="Text to summarize")
     max_tokens: Optional[int] = Field(default=256, ge=1, le=2048, description="Maximum tokens for summary")
+    temperature: Optional[float] = Field(default=0.3, ge=0.0, le=2.0, description="Sampling temperature for generation")
+    top_p: Optional[float] = Field(default=0.9, ge=0.0, le=1.0, description="Nucleus sampling parameter")
     prompt: Optional[str] = Field(
-        default="Provide a comprehensive summary of the following text, including main arguments, key findings, important details, and specific examples. Structure your response clearly:",
+        default="Summarize the key points concisely:",
         max_length=500,
         description="Custom prompt for summarization"
     )
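
For reference, a minimal sketch of how the updated schema behaves from a caller's perspective (assuming Pydantic v2, which provides model_fields_set; the field names and defaults are the ones in the diff above):

    from app.api.v1.schemas import SummarizeRequest

    # Defaults apply when the caller omits the new fields
    req = SummarizeRequest(text="Some article text...")
    assert req.temperature == 0.3 and req.top_p == 0.9
    assert req.prompt == "Summarize the key points concisely:"

    # Pydantic records explicitly supplied fields in model_fields_set;
    # the V2 endpoint below uses this to decide whether the caller's
    # max_tokens should override the adaptive calculation
    req2 = SummarizeRequest(text="Some article text...", max_tokens=512)
    assert "max_tokens" in req2.model_fields_set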
app/api/v2/summarize.py CHANGED
@@ -27,12 +27,28 @@ async def summarize_stream(payload: SummarizeRequest):
 async def _stream_generator(payload: SummarizeRequest):
     """Generator function for streaming SSE responses using HuggingFace."""
     try:
+        # Calculate adaptive max_new_tokens based on text length
+        text_length = len(payload.text)
+        if text_length < 1500:
+            # Short texts: use 60-100 tokens
+            adaptive_max_tokens = min(100, max(60, text_length // 15))
+        else:
+            # Longer texts: scale proportionally but cap appropriately
+            adaptive_max_tokens = min(400, max(100, text_length // 20))
+
+        # Use the adaptive calculation by default, but allow user override
+        # Check if max_tokens was explicitly provided (not just the default 256)
+        if hasattr(payload, 'model_fields_set') and 'max_tokens' in payload.model_fields_set:
+            max_new_tokens = payload.max_tokens
+        else:
+            max_new_tokens = adaptive_max_tokens
+
         async for chunk in hf_streaming_service.summarize_text_stream(
             text=payload.text,
-            max_new_tokens=payload.max_tokens or 128,  # Map max_tokens to max_new_tokens
-            temperature=0.7,  # Use default temperature
-            top_p=0.95,  # Use default top_p
-            prompt=payload.prompt or "Provide a comprehensive summary of the following text, including main arguments, key findings, important details, and specific examples. Structure your response clearly:",
+            max_new_tokens=max_new_tokens,
+            temperature=payload.temperature,  # Use user-provided temperature
+            top_p=payload.top_p,  # Use user-provided top_p
+            prompt=payload.prompt,
         ):
             # Format as SSE event (same format as V1)
             sse_data = json.dumps(chunk)
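
To make the adaptive sizing concrete, here is a standalone sketch of the same arithmetic the handler applies (the numbers follow directly from the formulas in the diff; the helper name is only for illustration):

    def adaptive_max_tokens(text_length: int) -> int:
        # Mirrors the V2 handler: short texts get 60-100 tokens,
        # longer texts scale with length but are capped at 400
        if text_length < 1500:
            return min(100, max(60, text_length // 15))
        return min(400, max(100, text_length // 20))

    assert adaptive_max_tokens(500) == 60     # 500 // 15 = 33, floored to 60
    assert adaptive_max_tokens(1200) == 80    # 1200 // 15 = 80
    assert adaptive_max_tokens(3000) == 150   # 3000 // 20 = 150
    assert adaptive_max_tokens(10000) == 400  # capped at 400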
app/services/hf_streaming_summarizer.py CHANGED
@@ -164,7 +164,7 @@ class HFStreamingSummarizer:
         max_new_tokens: int = None,
         temperature: float = None,
         top_p: float = None,
-        prompt: str = "Provide a comprehensive summary of the following text, including main arguments, key findings, important details, and specific examples. Structure your response clearly:",
+        prompt: str = "Summarize the key points concisely:",
     ) -> AsyncGenerator[Dict[str, Any], None]:
         """
         Stream text summarization using HuggingFace's TextIteratorStreamer.
@@ -194,13 +194,19 @@
 
         logger.info(f"Processing text of {text_length} chars with HuggingFace model: {settings.hf_model_id}")
 
+        # Check if text is long enough to require recursive summarization
+        if text_length > 1500:
+            logger.info(f"Text is long ({text_length} chars), using recursive summarization")
+            async for chunk in self._recursive_summarize(text, max_new_tokens, temperature, top_p, prompt):
+                yield chunk
+            return
+
         try:
             # Use provided parameters or sensible defaults
-            # Aim for ~200–400 tokens summary by default.
-            # If settings.hf_max_new_tokens is small, override with 256.
-            max_new_tokens = max_new_tokens or max(getattr(settings, "hf_max_new_tokens", 0) or 0, 256)
-            temperature = temperature or settings.hf_temperature
-            top_p = top_p or settings.hf_top_p
+            # For short texts, aim for concise summaries (60-100 tokens)
+            max_new_tokens = max_new_tokens or max(getattr(settings, "hf_max_new_tokens", 0) or 0, 80)
+            temperature = temperature or getattr(settings, "hf_temperature", 0.3)
+            top_p = top_p or getattr(settings, "hf_top_p", 0.9)
 
             # Determine a generous encoder max length (respect tokenizer.model_max_length)
             model_max = getattr(self.tokenizer, "model_max_length", 1024)
@@ -319,10 +325,10 @@
             gen_kwargs["num_return_sequences"] = 1
             gen_kwargs["num_beams"] = 1
             gen_kwargs["num_beam_groups"] = 1
-            # Ensure we don't stop too early; set a floor and slightly favor longer generations
-            gen_kwargs["min_new_tokens"] = max(96, min(192, max_new_tokens // 2))  # floor ~100–192
-            # length_penalty > 1.0 encourages longer outputs on encoder-decoder models
-            gen_kwargs["length_penalty"] = 1.1
+            # Set conservative min_new_tokens to prevent rambling
+            gen_kwargs["min_new_tokens"] = max(20, min(50, max_new_tokens // 4))  # floor ~20-50
+            # Use neutral length_penalty to avoid encouraging longer outputs
+            gen_kwargs["length_penalty"] = 1.0
             # Reduce premature EOS in some checkpoints (optional)
             gen_kwargs["no_repeat_ngram_size"] = 3
             gen_kwargs["repetition_penalty"] = 1.05
@@ -376,6 +382,217 @@
                 "error": "HF summarization failed. See server logs for traceback.",
             }
 
+    async def _recursive_summarize(
+        self,
+        text: str,
+        max_new_tokens: int,
+        temperature: float,
+        top_p: float,
+        prompt: str,
+    ) -> AsyncGenerator[Dict[str, Any], None]:
+        """
+        Recursively summarize long text by chunking and summarizing each chunk,
+        then summarizing the summaries if there are multiple chunks.
+        """
+        try:
+            # Split text into chunks of ~800-1000 tokens
+            chunks = _split_into_chunks(text, chunk_chars=4000, overlap=400)
+            logger.info(f"Split long text into {len(chunks)} chunks for recursive summarization")
+
+            chunk_summaries = []
+
+            # Summarize each chunk
+            for i, chunk in enumerate(chunks):
+                logger.info(f"Summarizing chunk {i+1}/{len(chunks)}")
+
+                # Use smaller max_new_tokens for individual chunks
+                chunk_max_tokens = min(max_new_tokens, 80)
+
+                chunk_summary = ""
+                async for chunk_result in self._single_chunk_summarize(
+                    chunk, chunk_max_tokens, temperature, top_p, prompt
+                ):
+                    if chunk_result.get("content"):
+                        chunk_summary += chunk_result["content"]
+                        yield chunk_result  # Stream each chunk's summary
+
+                chunk_summaries.append(chunk_summary.strip())
+
+            # If we have multiple chunks, create a final summary of summaries
+            if len(chunk_summaries) > 1:
+                logger.info("Creating final summary of summaries")
+                combined_summaries = "\n\n".join(chunk_summaries)
+
+                # Use original max_new_tokens for final summary
+                async for final_result in self._single_chunk_summarize(
+                    combined_summaries, max_new_tokens, temperature, top_p,
+                    "Summarize the key points from these summaries:"
+                ):
+                    yield final_result
+            else:
+                # Single chunk, just yield the done signal
+                yield {
+                    "content": "",
+                    "done": True,
+                    "tokens_used": 0,
+                }
+
+        except Exception as e:
+            logger.exception("❌ Recursive summarization failed")
+            yield {
+                "content": "",
+                "done": True,
+                "error": f"Recursive summarization failed: {str(e)}",
+            }
+
+    async def _single_chunk_summarize(
+        self,
+        text: str,
+        max_new_tokens: int,
+        temperature: float,
+        top_p: float,
+        prompt: str,
+    ) -> AsyncGenerator[Dict[str, Any], None]:
+        """
+        Summarize a single chunk of text using the same logic as the main method
+        but without the recursive check.
+        """
+        if not self.model or not self.tokenizer:
+            error_msg = "HuggingFace model not available. Please check model initialization."
+            logger.error(f"❌ {error_msg}")
+            yield {
+                "content": "",
+                "done": True,
+                "error": error_msg,
+            }
+            return
+
+        try:
+            # Use provided parameters or sensible defaults
+            max_new_tokens = max_new_tokens or 80
+            temperature = temperature or 0.3
+            top_p = top_p or 0.9
+
+            # Determine encoder max length
+            model_max = getattr(self.tokenizer, "model_max_length", 1024)
+            if not isinstance(model_max, int) or model_max <= 0:
+                model_max = 1024
+            enc_max_len = min(model_max, 2048)
+
+            # Build tokenized inputs
+            if "t5" in settings.hf_model_id.lower():
+                full_prompt = f"summarize: {text}"
+                inputs_raw = self.tokenizer(full_prompt, return_tensors="pt", max_length=enc_max_len, truncation=True)
+            elif "bart" in settings.hf_model_id.lower():
+                inputs_raw = self.tokenizer(text, return_tensors="pt", max_length=enc_max_len, truncation=True)
+            else:
+                messages = [
+                    {"role": "system", "content": prompt},
+                    {"role": "user", "content": text}
+                ]
+
+                if hasattr(self.tokenizer, "apply_chat_template") and self.tokenizer.chat_template:
+                    inputs_raw = self.tokenizer.apply_chat_template(
+                        messages, tokenize=True, add_generation_prompt=True, return_tensors="pt"
+                    )
+                else:
+                    full_prompt = f"{prompt}\n\n{text}"
+                    inputs_raw = self.tokenizer(full_prompt, return_tensors="pt")
+
+            # Normalize inputs (same logic as main method)
+            if isinstance(inputs_raw, (dict, BatchEncoding)):
+                try:
+                    inputs = dict(inputs_raw)
+                except Exception:
+                    inputs = dict(getattr(inputs_raw, "data", {}))
+            else:
+                inputs = {"input_ids": inputs_raw}
+
+            if "attention_mask" not in inputs and "input_ids" in inputs:
+                if TRANSFORMERS_AVAILABLE and 'torch' in globals() and isinstance(inputs["input_ids"], torch.Tensor):
+                    inputs["attention_mask"] = torch.ones_like(inputs["input_ids"])
+
+            def _to_singleton_batch(d):
+                out = {}
+                for k, v in d.items():
+                    if TRANSFORMERS_AVAILABLE and 'torch' in globals() and isinstance(v, torch.Tensor):
+                        if v.dim() == 1:
+                            out[k] = v.unsqueeze(0)
+                        elif v.dim() >= 2:
+                            out[k] = v[:1]
+                        else:
+                            out[k] = v
+                    else:
+                        out[k] = v
+                return out
+
+            inputs = _to_singleton_batch(inputs)
+
+            # Validate pad/eos ids
+            pad_id = self.tokenizer.pad_token_id
+            eos_id = self.tokenizer.eos_token_id
+            if pad_id is None and eos_id is not None:
+                pad_id = eos_id
+            elif pad_id is None and eos_id is None:
+                pad_id = 0
+
+            # Create streamer
+            streamer = TextIteratorStreamer(
+                self.tokenizer,
+                skip_prompt=True,
+                skip_special_tokens=True
+            )
+
+            gen_kwargs = {
+                **inputs,
+                "streamer": streamer,
+                "max_new_tokens": max_new_tokens,
+                "do_sample": True,
+                "temperature": temperature,
+                "top_p": top_p,
+                "pad_token_id": pad_id,
+                "eos_token_id": eos_id,
+                "num_return_sequences": 1,
+                "num_beams": 1,
+                "num_beam_groups": 1,
+                "min_new_tokens": max(20, min(50, max_new_tokens // 4)),
+                "length_penalty": 1.0,
+                "no_repeat_ngram_size": 3,
+                "repetition_penalty": 1.05,
+            }
+
+            generation_thread = threading.Thread(target=self.model.generate, kwargs=gen_kwargs, daemon=True)
+            generation_thread.start()
+
+            # Stream tokens as they arrive
+            token_count = 0
+            for text_chunk in streamer:
+                if text_chunk:
+                    yield {
+                        "content": text_chunk,
+                        "done": False,
+                        "tokens_used": token_count,
+                    }
+                    token_count += 1
+
+            # Wait for generation to complete
+            generation_thread.join()
+
+            # Send final "done" chunk
+            yield {
+                "content": "",
+                "done": True,
+                "tokens_used": token_count,
+            }
+
+        except Exception:
+            logger.exception("❌ Single chunk summarization failed")
+            yield {
+                "content": "",
+                "done": True,
+                "error": "Single chunk summarization failed. See server logs for traceback.",
+            }
+
     async def check_health(self) -> bool:
         """
         Check if the HuggingFace model is properly initialized and ready.
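
The recursive path depends on a _split_into_chunks(text, chunk_chars, overlap) helper that the new tests import but whose body is not shown in this diff. Below is a rough sketch of a character-based splitter that would satisfy those tests (empty input yields an empty list, every chunk stays within chunk_chars, consecutive chunks overlap); this is an illustrative guess, not the repository's actual implementation:

    def _split_into_chunks(text: str, chunk_chars: int = 4000, overlap: int = 400) -> list:
        # Illustrative sketch only: split text into overlapping character windows
        if not text:
            return []
        if len(text) <= chunk_chars:
            return [text]
        step = max(1, chunk_chars - overlap)
        chunks = []
        for start in range(0, len(text), step):
            piece = text[start:start + chunk_chars]
            if piece:
                chunks.append(piece)
            if start + chunk_chars >= len(text):
                break
        return chunks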
tests/test_hf_streaming_improvements.py ADDED
@@ -0,0 +1,286 @@
+"""
+Tests for HuggingFace streaming summarizer improvements.
+"""
+import pytest
+from unittest.mock import AsyncMock, patch, MagicMock
+from app.services.hf_streaming_summarizer import HFStreamingSummarizer, _split_into_chunks
+
+
+class TestSplitIntoChunks:
+    """Test the text chunking utility function."""
+
+    def test_split_short_text(self):
+        """Test splitting short text that doesn't need chunking."""
+        text = "This is a short text."
+        chunks = _split_into_chunks(text, chunk_chars=100, overlap=20)
+
+        assert len(chunks) == 1
+        assert chunks[0] == text
+
+    def test_split_long_text(self):
+        """Test splitting long text into multiple chunks."""
+        text = "This is a longer text. " * 50  # ~1000 chars
+        chunks = _split_into_chunks(text, chunk_chars=200, overlap=50)
+
+        assert len(chunks) > 1
+        # All chunks should be within reasonable size
+        for chunk in chunks:
+            assert len(chunk) <= 200
+            assert len(chunk) > 0
+
+    def test_chunk_overlap(self):
+        """Test that chunks have proper overlap."""
+        text = "This is a test text for overlap testing. " * 20  # ~800 chars
+        chunks = _split_into_chunks(text, chunk_chars=200, overlap=50)
+
+        if len(chunks) > 1:
+            # Check that consecutive chunks share some content
+            for i in range(len(chunks) - 1):
+                # There should be some overlap between consecutive chunks
+                assert len(chunks[i]) > 0
+                assert len(chunks[i+1]) > 0
+
+    def test_empty_text(self):
+        """Test splitting empty text."""
+        chunks = _split_into_chunks("", chunk_chars=100, overlap=20)
+        assert len(chunks) == 0  # Empty text returns empty list
+
+
+class TestHFStreamingSummarizerImprovements:
+    """Test improvements to HFStreamingSummarizer."""
+
+    @pytest.fixture
+    def mock_summarizer(self):
+        """Create a mock HFStreamingSummarizer for testing."""
+        summarizer = HFStreamingSummarizer()
+        summarizer.model = MagicMock()
+        summarizer.tokenizer = MagicMock()
+        return summarizer
+
+    @pytest.mark.asyncio
+    async def test_recursive_summarization_long_text(self, mock_summarizer):
+        """Test recursive summarization for long text."""
+        # Mock the _single_chunk_summarize method
+        async def mock_single_chunk(text, max_tokens, temp, top_p, prompt):
+            yield {"content": f"Summary of: {text[:50]}...", "done": False, "tokens_used": 10}
+            yield {"content": "", "done": True, "tokens_used": 10}
+
+        mock_summarizer._single_chunk_summarize = mock_single_chunk
+
+        # Long text (>1500 chars)
+        long_text = "This is a very long text that should trigger recursive summarization. " * 30  # ~2000+ chars
+
+        results = []
+        async for chunk in mock_summarizer._recursive_summarize(
+            long_text, max_new_tokens=100, temperature=0.3, top_p=0.9, prompt="Test prompt"
+        ):
+            results.append(chunk)
+
+        # Should have multiple chunks (one for each text chunk + final summary)
+        assert len(results) > 2  # At least 2 chunks + final done signal
+
+        # Check that we get proper streaming format
+        content_chunks = [r for r in results if r.get("content") and not r.get("done")]
+        assert len(content_chunks) > 0
+
+        # Should end with done signal
+        final_chunk = results[-1]
+        assert final_chunk.get("done") is True
+
+    @pytest.mark.asyncio
+    async def test_recursive_summarization_single_chunk(self, mock_summarizer):
+        """Test recursive summarization when text fits in single chunk."""
+        # Mock the _single_chunk_summarize method
+        async def mock_single_chunk(text, max_tokens, temp, top_p, prompt):
+            yield {"content": "Single chunk summary", "done": False, "tokens_used": 5}
+            yield {"content": "", "done": True, "tokens_used": 5}
+
+        mock_summarizer._single_chunk_summarize = mock_single_chunk
+
+        # Text that would fit in single chunk after splitting
+        text = "This is a medium length text. " * 20  # ~600 chars
+
+        results = []
+        async for chunk in mock_summarizer._recursive_summarize(
+            text, max_new_tokens=100, temperature=0.3, top_p=0.9, prompt="Test prompt"
+        ):
+            results.append(chunk)
+
+        # Should have at least 2 chunks (content + done)
+        assert len(results) >= 2
+
+        # Should end with done signal
+        final_chunk = results[-1]
+        assert final_chunk.get("done") is True
+
+    @pytest.mark.asyncio
+    async def test_single_chunk_summarize_parameters(self, mock_summarizer):
+        """Test that _single_chunk_summarize uses correct parameters."""
+        # Mock the tokenizer and model
+        mock_summarizer.tokenizer.model_max_length = 1024
+        mock_summarizer.tokenizer.pad_token_id = 0
+        mock_summarizer.tokenizer.eos_token_id = 1
+
+        # Mock the model generation
+        mock_streamer = MagicMock()
+        mock_streamer.__iter__ = MagicMock(return_value=iter(["test", "summary"]))
+
+        with patch('app.services.hf_streaming_summarizer.TextIteratorStreamer', return_value=mock_streamer):
+            with patch('app.services.hf_streaming_summarizer.settings') as mock_settings:
+                mock_settings.hf_model_id = "test-model"
+
+                results = []
+                async for chunk in mock_summarizer._single_chunk_summarize(
+                    "Test text", max_new_tokens=80, temperature=0.3, top_p=0.9, prompt="Test prompt"
+                ):
+                    results.append(chunk)
+
+                # Should have content chunks + final done
+                assert len(results) >= 2
+
+                # Check that generation was called with correct parameters
+                mock_summarizer.model.generate.assert_called_once()
+                call_kwargs = mock_summarizer.model.generate.call_args[1]
+
+                assert call_kwargs["max_new_tokens"] == 80
+                assert call_kwargs["temperature"] == 0.3
+                assert call_kwargs["top_p"] == 0.9
+                assert call_kwargs["length_penalty"] == 1.0  # Should be neutral
+                assert call_kwargs["min_new_tokens"] <= 50  # Should be conservative
+
+    @pytest.mark.asyncio
+    async def test_single_chunk_summarize_defaults(self, mock_summarizer):
+        """Test that _single_chunk_summarize uses correct defaults."""
+        # Mock the tokenizer and model
+        mock_summarizer.tokenizer.model_max_length = 1024
+        mock_summarizer.tokenizer.pad_token_id = 0
+        mock_summarizer.tokenizer.eos_token_id = 1
+
+        # Mock the model generation
+        mock_streamer = MagicMock()
+        mock_streamer.__iter__ = MagicMock(return_value=iter(["test", "summary"]))
+
+        with patch('app.services.hf_streaming_summarizer.TextIteratorStreamer', return_value=mock_streamer):
+            with patch('app.services.hf_streaming_summarizer.settings') as mock_settings:
+                mock_settings.hf_model_id = "test-model"
+
+                results = []
+                async for chunk in mock_summarizer._single_chunk_summarize(
+                    "Test text", max_new_tokens=None, temperature=None, top_p=None, prompt="Test prompt"
+                ):
+                    results.append(chunk)
+
+                # Check that generation was called with correct defaults
+                mock_summarizer.model.generate.assert_called_once()
+                call_kwargs = mock_summarizer.model.generate.call_args[1]
+
+                assert call_kwargs["max_new_tokens"] == 80  # Default
+                assert call_kwargs["temperature"] == 0.3  # Default
+                assert call_kwargs["top_p"] == 0.9  # Default
+
+    @pytest.mark.asyncio
+    async def test_recursive_summarization_error_handling(self, mock_summarizer):
+        """Test error handling in recursive summarization."""
+        # Mock _single_chunk_summarize to raise an exception
+        async def mock_single_chunk_error(text, max_tokens, temp, top_p, prompt):
+            raise Exception("Test error")
+            yield  # This line will never be reached, but makes it an async generator
+
+        mock_summarizer._single_chunk_summarize = mock_single_chunk_error
+
+        long_text = "This is a long text. " * 30
+
+        results = []
+        async for chunk in mock_summarizer._recursive_summarize(
+            long_text, max_new_tokens=100, temperature=0.3, top_p=0.9, prompt="Test prompt"
+        ):
+            results.append(chunk)
+
+        # Should have error chunk
+        assert len(results) == 1
+        error_chunk = results[0]
+        assert error_chunk.get("done") is True
+        assert "error" in error_chunk
+        assert "Test error" in error_chunk["error"]
+
+    @pytest.mark.asyncio
+    async def test_single_chunk_summarize_error_handling(self, mock_summarizer):
+        """Test error handling in single chunk summarization."""
+        # Mock model to raise exception
+        mock_summarizer.model.generate.side_effect = Exception("Generation error")
+
+        results = []
+        async for chunk in mock_summarizer._single_chunk_summarize(
+            "Test text", max_new_tokens=80, temperature=0.3, top_p=0.9, prompt="Test prompt"
+        ):
+            results.append(chunk)
+
+        # Should have error chunk
+        assert len(results) == 1
+        error_chunk = results[0]
+        assert error_chunk.get("done") is True
+        assert "error" in error_chunk
+        assert "Generation error" in error_chunk["error"]
+
+
+class TestHFStreamingSummarizerIntegration:
+    """Integration tests for HFStreamingSummarizer improvements."""
+
+    @pytest.mark.asyncio
+    async def test_summarize_text_stream_long_text_detection(self):
+        """Test that summarize_text_stream detects long text and uses recursive summarization."""
+        summarizer = HFStreamingSummarizer()
+
+        # Mock the recursive summarization method
+        async def mock_recursive(text, max_tokens, temp, top_p, prompt):
+            yield {"content": "Recursive summary", "done": False, "tokens_used": 10}
+            yield {"content": "", "done": True, "tokens_used": 10}
+
+        summarizer._recursive_summarize = mock_recursive
+
+        # Long text (>1500 chars)
+        long_text = "This is a very long text. " * 60  # ~1500+ chars
+
+        results = []
+        async for chunk in summarizer.summarize_text_stream(long_text):
+            results.append(chunk)
+
+        # Should have used recursive summarization
+        assert len(results) >= 2
+        assert results[0]["content"] == "Recursive summary"
+        assert results[-1]["done"] is True
+
+    @pytest.mark.asyncio
+    async def test_summarize_text_stream_short_text_normal_flow(self):
+        """Test that summarize_text_stream uses normal flow for short text."""
+        summarizer = HFStreamingSummarizer()
+
+        # Mock model and tokenizer
+        summarizer.model = MagicMock()
+        summarizer.tokenizer = MagicMock()
+        summarizer.tokenizer.model_max_length = 1024
+        summarizer.tokenizer.pad_token_id = 0
+        summarizer.tokenizer.eos_token_id = 1
+
+        # Mock the streamer
+        mock_streamer = MagicMock()
+        mock_streamer.__iter__ = MagicMock(return_value=iter(["short", "summary"]))
+
+        with patch('app.services.hf_streaming_summarizer.TextIteratorStreamer', return_value=mock_streamer):
+            with patch('app.services.hf_streaming_summarizer.settings') as mock_settings:
+                mock_settings.hf_model_id = "test-model"
+                mock_settings.hf_temperature = 0.3
+                mock_settings.hf_top_p = 0.9
+
+                # Short text (<1500 chars)
+                short_text = "This is a short text."
+
+                results = []
+                async for chunk in summarizer.summarize_text_stream(short_text):
+                    results.append(chunk)
+
+                # Should have used normal flow (not recursive)
+                assert len(results) >= 2
+                assert results[0]["content"] == "short"
+                assert results[1]["content"] == "summary"
+                assert results[-1]["done"] is True
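
The new tests mark coroutine test functions with pytest.mark.asyncio, so the project presumably has pytest-asyncio (or an equivalent plugin) configured. A typical local run of the new and updated suites might look like:

    pytest tests/test_hf_streaming_improvements.py tests/test_schemas.py tests/test_v2_api.py -v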
tests/test_schemas.py CHANGED
@@ -15,7 +15,7 @@ class TestSummarizeRequest:
 
         assert request.text == sample_text.strip()
         assert request.max_tokens == 256
-        assert request.prompt == "Summarize the following text concisely:"
+        assert request.prompt == "Summarize the key points concisely:"
 
     def test_custom_parameters(self):
         """Test request with custom parameters."""
@@ -73,6 +73,57 @@ class TestSummarizeRequest:
         long_prompt = "x" * 501
         with pytest.raises(ValidationError):
             SummarizeRequest(text="test", prompt=long_prompt)
+
+    def test_temperature_parameter(self):
+        """Test temperature parameter validation."""
+        # Valid temperature values
+        request = SummarizeRequest(text="test", temperature=0.0)
+        assert request.temperature == 0.0
+
+        request = SummarizeRequest(text="test", temperature=2.0)
+        assert request.temperature == 2.0
+
+        request = SummarizeRequest(text="test", temperature=0.3)
+        assert request.temperature == 0.3
+
+        # Default temperature
+        request = SummarizeRequest(text="test")
+        assert request.temperature == 0.3
+
+        # Invalid temperature values
+        with pytest.raises(ValidationError):
+            SummarizeRequest(text="test", temperature=-0.1)
+
+        with pytest.raises(ValidationError):
+            SummarizeRequest(text="test", temperature=2.1)
+
+    def test_top_p_parameter(self):
+        """Test top_p parameter validation."""
+        # Valid top_p values
+        request = SummarizeRequest(text="test", top_p=0.0)
+        assert request.top_p == 0.0
+
+        request = SummarizeRequest(text="test", top_p=1.0)
+        assert request.top_p == 1.0
+
+        request = SummarizeRequest(text="test", top_p=0.9)
+        assert request.top_p == 0.9
+
+        # Default top_p
+        request = SummarizeRequest(text="test")
+        assert request.top_p == 0.9
+
+        # Invalid top_p values
+        with pytest.raises(ValidationError):
+            SummarizeRequest(text="test", top_p=-0.1)
+
+        with pytest.raises(ValidationError):
+            SummarizeRequest(text="test", top_p=1.1)
+
+    def test_updated_default_prompt(self):
+        """Test that the default prompt has been updated to be more concise."""
+        request = SummarizeRequest(text="test")
+        assert request.prompt == "Summarize the key points concisely:"
 
 
 class TestSummarizeResponse:
tests/test_v2_api.py CHANGED
@@ -155,6 +155,144 @@ class TestV2SummarizeStream:
         assert call_args[1]['prompt'] == "Custom prompt"
         assert call_args[1]['text'] == "Test text"
 
+    @pytest.mark.integration
+    def test_v2_adaptive_token_logic_short_text(self, client: TestClient):
+        """Test adaptive token logic for short texts (<1500 chars)."""
+        with patch('app.services.hf_streaming_summarizer.hf_streaming_service.summarize_text_stream') as mock_stream:
+            async def mock_generator():
+                yield {"content": "", "done": True}
+
+            mock_stream.return_value = mock_generator()
+
+            # Short text (500 chars)
+            short_text = "This is a short text. " * 20  # ~500 chars
+
+            response = client.post(
+                "/api/v2/summarize/stream",
+                json={
+                    "text": short_text,
+                    # Don't specify max_tokens to test adaptive logic
+                }
+            )
+
+            assert response.status_code == 200
+
+            # Verify service was called with adaptive max_new_tokens
+            mock_stream.assert_called_once()
+            call_args = mock_stream.call_args
+
+            # For short text, should use 60-100 tokens
+            max_new_tokens = call_args[1]['max_new_tokens']
+            assert 60 <= max_new_tokens <= 100
+
+    @pytest.mark.integration
+    def test_v2_adaptive_token_logic_long_text(self, client: TestClient):
+        """Test adaptive token logic for long texts (>1500 chars)."""
+        with patch('app.services.hf_streaming_summarizer.hf_streaming_service.summarize_text_stream') as mock_stream:
+            async def mock_generator():
+                yield {"content": "", "done": True}
+
+            mock_stream.return_value = mock_generator()
+
+            # Long text (2000 chars)
+            long_text = "This is a longer text that should trigger adaptive token logic. " * 40  # ~2000 chars
+
+            response = client.post(
+                "/api/v2/summarize/stream",
+                json={
+                    "text": long_text,
+                    # Don't specify max_tokens to test adaptive logic
+                }
+            )
+
+            assert response.status_code == 200
+
+            # Verify service was called with adaptive max_new_tokens
+            mock_stream.assert_called_once()
+            call_args = mock_stream.call_args
+
+            # For long text, should use proportional scaling but capped
+            max_new_tokens = call_args[1]['max_new_tokens']
+            assert 100 <= max_new_tokens <= 400
+
+    @pytest.mark.integration
+    def test_v2_temperature_and_top_p_parameters(self, client: TestClient):
+        """Test that temperature and top_p parameters are passed correctly."""
+        with patch('app.services.hf_streaming_summarizer.hf_streaming_service.summarize_text_stream') as mock_stream:
+            async def mock_generator():
+                yield {"content": "", "done": True}
+
+            mock_stream.return_value = mock_generator()
+
+            response = client.post(
+                "/api/v2/summarize/stream",
+                json={
+                    "text": "Test text",
+                    "temperature": 0.5,
+                    "top_p": 0.8
+                }
+            )
+
+            assert response.status_code == 200
+
+            # Verify service was called with correct parameters
+            mock_stream.assert_called_once()
+            call_args = mock_stream.call_args
+
+            assert call_args[1]['temperature'] == 0.5
+            assert call_args[1]['top_p'] == 0.8
+
+    @pytest.mark.integration
+    def test_v2_default_temperature_and_top_p(self, client: TestClient):
+        """Test that default temperature and top_p values are used when not specified."""
+        with patch('app.services.hf_streaming_summarizer.hf_streaming_service.summarize_text_stream') as mock_stream:
+            async def mock_generator():
+                yield {"content": "", "done": True}
+
+            mock_stream.return_value = mock_generator()
+
+            response = client.post(
+                "/api/v2/summarize/stream",
+                json={
+                    "text": "Test text"
+                    # Don't specify temperature or top_p
+                }
+            )
+
+            assert response.status_code == 200
+
+            # Verify service was called with default parameters
+            mock_stream.assert_called_once()
+            call_args = mock_stream.call_args
+
+            assert call_args[1]['temperature'] == 0.3  # Default temperature
+            assert call_args[1]['top_p'] == 0.9  # Default top_p
+
+    @pytest.mark.integration
+    def test_v2_recursive_summarization_trigger(self, client: TestClient):
+        """Test that recursive summarization is triggered for long texts."""
+        with patch('app.services.hf_streaming_summarizer.hf_streaming_service.summarize_text_stream') as mock_stream:
+            async def mock_generator():
+                yield {"content": "", "done": True}
+
+            mock_stream.return_value = mock_generator()
+
+            # Very long text (>1500 chars) to trigger recursive summarization
+            very_long_text = "This is a very long text that should definitely trigger recursive summarization logic. " * 30  # ~2000+ chars
+
+            response = client.post(
+                "/api/v2/summarize/stream",
+                json={
+                    "text": very_long_text
+                }
+            )
+
+            assert response.status_code == 200
+
+            # The service should be called, and internally it should detect long text
+            # and use recursive summarization
+            mock_stream.assert_called_once()
+
 
 class TestV2APICompatibility:
     """Test V2 API compatibility with V1."""