ming committed on

Commit 9884884 · Parent(s): 698636a
Improve V2 summarization: adaptive tokens, recursive summarization, better defaults
- Add temperature and top_p parameters to SummarizeRequest schema (defaults: 0.3, 0.9)
- Implement adaptive max_new_tokens logic: 60-100 for short texts, proportional for long texts
- Add recursive summarization for texts >1500 chars with chunking and summary-of-summaries
- Fix generation parameters to prevent rambling (reduce min_new_tokens, neutral length_penalty)
- Update default prompt to be more concise
- Add comprehensive unit tests for all improvements (40 tests)
- Fix V2 API to generate concise summaries instead of rambling output
- app/api/v1/schemas.py +3 -1
- app/api/v2/summarize.py +20 -4
- app/services/hf_streaming_summarizer.py +227 -10
- tests/test_hf_streaming_improvements.py +286 -0
- tests/test_schemas.py +52 -1
- tests/test_v2_api.py +138 -0
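For context, a client-side sketch of how the new request fields are exercised end to end. This is an illustrative assumption, not part of the commit: the route path and the "data: " SSE framing are guesses (the diff only shows that each chunk is JSON-encoded in "the same format as V1"), so adjust both to the actual V2 route.

import json
import httpx

payload = {
    "text": "Long article text to summarize ...",
    "temperature": 0.3,  # new optional field (server default 0.3)
    "top_p": 0.9,        # new optional field (server default 0.9)
    # "max_tokens" omitted -> server picks an adaptive max_new_tokens
}

# Hypothetical route; substitute the real V2 endpoint path.
with httpx.stream("POST", "http://localhost:8000/api/v2/summarize", json=payload, timeout=None) as resp:
    for line in resp.iter_lines():
        if not line.startswith("data: "):
            continue
        chunk = json.loads(line[len("data: "):])
        print(chunk.get("content", ""), end="", flush=True)
        if chunk.get("done"):
            break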
    	
app/api/v1/schemas.py  CHANGED

@@ -10,8 +10,10 @@ class SummarizeRequest(BaseModel):
 
     text: str = Field(..., min_length=1, max_length=32000, description="Text to summarize")
     max_tokens: Optional[int] = Field(default=256, ge=1, le=2048, description="Maximum tokens for summary")
+    temperature: Optional[float] = Field(default=0.3, ge=0.0, le=2.0, description="Sampling temperature for generation")
+    top_p: Optional[float] = Field(default=0.9, ge=0.0, le=1.0, description="Nucleus sampling parameter")
     prompt: Optional[str] = Field(
-        default="
+        default="Summarize the key points concisely:",
         max_length=500,
         description="Custom prompt for summarization"
     )
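A minimal sketch of how the new fields behave under the Field constraints above, assuming the schema module imports as laid out in this repo; the snippet itself is not part of the commit.

from pydantic import ValidationError

from app.api.v1.schemas import SummarizeRequest

# Omitted fields fall back to the new defaults.
req = SummarizeRequest(text="Some article text")
assert req.temperature == 0.3 and req.top_p == 0.9

# Values outside the ge/le bounds are rejected at validation time.
try:
    SummarizeRequest(text="x", temperature=3.0)  # le=2.0
except ValidationError as exc:
    print(exc.errors()[0]["loc"])  # ('temperature',)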
    	
app/api/v2/summarize.py  CHANGED

@@ -27,12 +27,28 @@ async def summarize_stream(payload: SummarizeRequest):
     async def _stream_generator(payload: SummarizeRequest):
         """Generator function for streaming SSE responses using HuggingFace."""
         try:
+            # Calculate adaptive max_new_tokens based on text length
+            text_length = len(payload.text)
+            if text_length < 1500:
+                # Short texts: use 60-100 tokens
+                adaptive_max_tokens = min(100, max(60, text_length // 15))
+            else:
+                # Longer texts: scale proportionally but cap appropriately
+                adaptive_max_tokens = min(400, max(100, text_length // 20))
+
+            # Use adaptive calculation by default, but allow user override
+            # Check if max_tokens was explicitly provided (not just the default 256)
+            if hasattr(payload, 'model_fields_set') and 'max_tokens' in payload.model_fields_set:
+                max_new_tokens = payload.max_tokens
+            else:
+                max_new_tokens = adaptive_max_tokens
+
             async for chunk in hf_streaming_service.summarize_text_stream(
                 text=payload.text,
-                max_new_tokens=
-                temperature=
-                top_p=
-                prompt=payload.prompt
+                max_new_tokens=max_new_tokens,
+                temperature=payload.temperature,  # Use user-provided temperature
+                top_p=payload.top_p,  # Use user-provided top_p
+                prompt=payload.prompt,
             ):
                 # Format as SSE event (same format as V1)
                 sse_data = json.dumps(chunk)
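To make the adaptive-token rule above concrete, here is a standalone sketch of the same arithmetic for a few sample lengths. The helper name is illustrative only (the endpoint inlines this logic), and an explicitly supplied max_tokens still wins, detected via Pydantic's model_fields_set.

def adaptive_max_tokens(text_length: int) -> int:
    # Mirrors the branch added to _stream_generator above.
    if text_length < 1500:
        # Short texts: 60-100 tokens.
        return min(100, max(60, text_length // 15))
    # Longer texts: scale with length, capped at 400.
    return min(400, max(100, text_length // 20))

for n in (300, 1200, 1499, 5000, 20000):
    print(n, adaptive_max_tokens(n))
# 300 -> 60, 1200 -> 80, 1499 -> 99, 5000 -> 250, 20000 -> 400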
    	
app/services/hf_streaming_summarizer.py  CHANGED

@@ -164,7 +164,7 @@ class HFStreamingSummarizer:
         max_new_tokens: int = None,
         temperature: float = None,
         top_p: float = None,
-        prompt: str = "
+        prompt: str = "Summarize the key points concisely:",
     ) -> AsyncGenerator[Dict[str, Any], None]:
         """
         Stream text summarization using HuggingFace's TextIteratorStreamer.

@@ -194,13 +194,19 @@ class HFStreamingSummarizer:
 
         logger.info(f"Processing text of {text_length} chars with HuggingFace model: {settings.hf_model_id}")
 
+        # Check if text is long enough to require recursive summarization
+        if text_length > 1500:
+            logger.info(f"Text is long ({text_length} chars), using recursive summarization")
+            async for chunk in self._recursive_summarize(text, max_new_tokens, temperature, top_p, prompt):
+                yield chunk
+            return
+
         try:
             # Use provided parameters or sensible defaults
-            #
-
-
-
-            top_p = top_p or settings.hf_top_p
+            # For short texts, aim for concise summaries (60-100 tokens)
+            max_new_tokens = max_new_tokens or max(getattr(settings, "hf_max_new_tokens", 0) or 0, 80)
+            temperature = temperature or getattr(settings, "hf_temperature", 0.3)
+            top_p = top_p or getattr(settings, "hf_top_p", 0.9)
 
             # Determine a generous encoder max length (respect tokenizer.model_max_length)
             model_max = getattr(self.tokenizer, "model_max_length", 1024)

@@ -319,10 +325,10 @@ class HFStreamingSummarizer:
             gen_kwargs["num_return_sequences"] = 1
             gen_kwargs["num_beams"] = 1
             gen_kwargs["num_beam_groups"] = 1
-            #
-            gen_kwargs["min_new_tokens"] = max(
-            # length_penalty
-            gen_kwargs["length_penalty"] = 1.
+            # Set conservative min_new_tokens to prevent rambling
+            gen_kwargs["min_new_tokens"] = max(20, min(50, max_new_tokens // 4))  # floor ~20-50
+            # Use neutral length_penalty to avoid encouraging longer outputs
+            gen_kwargs["length_penalty"] = 1.0
             # Reduce premature EOS in some checkpoints (optional)
             gen_kwargs["no_repeat_ngram_size"] = 3
             gen_kwargs["repetition_penalty"] = 1.05

@@ -376,6 +382,217 @@ class HFStreamingSummarizer:
                 "error": "HF summarization failed. See server logs for traceback.",
             }
 
+    async def _recursive_summarize(
+        self,
+        text: str,
+        max_new_tokens: int,
+        temperature: float,
+        top_p: float,
+        prompt: str,
+    ) -> AsyncGenerator[Dict[str, Any], None]:
+        """
+        Recursively summarize long text by chunking and summarizing each chunk,
+        then summarizing the summaries if there are multiple chunks.
+        """
+        try:
+            # Split text into chunks of ~800-1000 tokens
+            chunks = _split_into_chunks(text, chunk_chars=4000, overlap=400)
+            logger.info(f"Split long text into {len(chunks)} chunks for recursive summarization")
+
+            chunk_summaries = []
+
+            # Summarize each chunk
+            for i, chunk in enumerate(chunks):
+                logger.info(f"Summarizing chunk {i+1}/{len(chunks)}")
+
+                # Use smaller max_new_tokens for individual chunks
+                chunk_max_tokens = min(max_new_tokens, 80)
+
+                chunk_summary = ""
+                async for chunk_result in self._single_chunk_summarize(
+                    chunk, chunk_max_tokens, temperature, top_p, prompt
+                ):
+                    if chunk_result.get("content"):
+                        chunk_summary += chunk_result["content"]
+                    yield chunk_result  # Stream each chunk's summary
+
+                chunk_summaries.append(chunk_summary.strip())
+
+            # If we have multiple chunks, create a final summary of summaries
+            if len(chunk_summaries) > 1:
+                logger.info("Creating final summary of summaries")
+                combined_summaries = "\n\n".join(chunk_summaries)
+
+                # Use original max_new_tokens for final summary
+                async for final_result in self._single_chunk_summarize(
+                    combined_summaries, max_new_tokens, temperature, top_p,
+                    "Summarize the key points from these summaries:"
+                ):
+                    yield final_result
+            else:
+                # Single chunk, just yield the done signal
+                yield {
+                    "content": "",
+                    "done": True,
+                    "tokens_used": 0,
+                }
+
+        except Exception as e:
+            logger.exception("❌ Recursive summarization failed")
+            yield {
+                "content": "",
+                "done": True,
+                "error": f"Recursive summarization failed: {str(e)}",
+            }
+
+    async def _single_chunk_summarize(
+        self,
+        text: str,
+        max_new_tokens: int,
+        temperature: float,
+        top_p: float,
+        prompt: str,
+    ) -> AsyncGenerator[Dict[str, Any], None]:
+        """
+        Summarize a single chunk of text using the same logic as the main method
+        but without the recursive check.
+        """
+        if not self.model or not self.tokenizer:
+            error_msg = "HuggingFace model not available. Please check model initialization."
+            logger.error(f"❌ {error_msg}")
+            yield {
+                "content": "",
+                "done": True,
+                "error": error_msg,
+            }
+            return
+
+        try:
+            # Use provided parameters or sensible defaults
+            max_new_tokens = max_new_tokens or 80
+            temperature = temperature or 0.3
+            top_p = top_p or 0.9
+
+            # Determine encoder max length
+            model_max = getattr(self.tokenizer, "model_max_length", 1024)
+            if not isinstance(model_max, int) or model_max <= 0:
+                model_max = 1024
+            enc_max_len = min(model_max, 2048)
+
+            # Build tokenized inputs
+            if "t5" in settings.hf_model_id.lower():
+                full_prompt = f"summarize: {text}"
+                inputs_raw = self.tokenizer(full_prompt, return_tensors="pt", max_length=enc_max_len, truncation=True)
+            elif "bart" in settings.hf_model_id.lower():
+                inputs_raw = self.tokenizer(text, return_tensors="pt", max_length=enc_max_len, truncation=True)
+            else:
+                messages = [
+                    {"role": "system", "content": prompt},
+                    {"role": "user", "content": text}
+                ]
+
+                if hasattr(self.tokenizer, "apply_chat_template") and self.tokenizer.chat_template:
+                    inputs_raw = self.tokenizer.apply_chat_template(
+                        messages, tokenize=True, add_generation_prompt=True, return_tensors="pt"
+                    )
+                else:
+                    full_prompt = f"{prompt}\n\n{text}"
+                    inputs_raw = self.tokenizer(full_prompt, return_tensors="pt")
+
+            # Normalize inputs (same logic as main method)
+            if isinstance(inputs_raw, (dict, BatchEncoding)):
+                try:
+                    inputs = dict(inputs_raw)
+                except Exception:
+                    inputs = dict(getattr(inputs_raw, "data", {}))
+            else:
+                inputs = {"input_ids": inputs_raw}
+
+            if "attention_mask" not in inputs and "input_ids" in inputs:
+                if TRANSFORMERS_AVAILABLE and 'torch' in globals() and isinstance(inputs["input_ids"], torch.Tensor):
+                    inputs["attention_mask"] = torch.ones_like(inputs["input_ids"])
+
+            def _to_singleton_batch(d):
+                out = {}
+                for k, v in d.items():
+                    if TRANSFORMERS_AVAILABLE and 'torch' in globals() and isinstance(v, torch.Tensor):
+                        if v.dim() == 1:
+                            out[k] = v.unsqueeze(0)
+                        elif v.dim() >= 2:
+                            out[k] = v[:1]
+                        else:
+                            out[k] = v
+                    else:
+                        out[k] = v
+                return out
+
+            inputs = _to_singleton_batch(inputs)
+
+            # Validate pad/eos ids
+            pad_id = self.tokenizer.pad_token_id
+            eos_id = self.tokenizer.eos_token_id
+            if pad_id is None and eos_id is not None:
+                pad_id = eos_id
+            elif pad_id is None and eos_id is None:
+                pad_id = 0
+
+            # Create streamer
+            streamer = TextIteratorStreamer(
+                self.tokenizer,
+                skip_prompt=True,
+                skip_special_tokens=True
+            )
+
+            gen_kwargs = {
+                **inputs,
+                "streamer": streamer,
+                "max_new_tokens": max_new_tokens,
+                "do_sample": True,
+                "temperature": temperature,
+                "top_p": top_p,
+                "pad_token_id": pad_id,
+                "eos_token_id": eos_id,
+                "num_return_sequences": 1,
+                "num_beams": 1,
+                "num_beam_groups": 1,
+                "min_new_tokens": max(20, min(50, max_new_tokens // 4)),
+                "length_penalty": 1.0,
+                "no_repeat_ngram_size": 3,
+                "repetition_penalty": 1.05,
+            }
+
+            generation_thread = threading.Thread(target=self.model.generate, kwargs=gen_kwargs, daemon=True)
+            generation_thread.start()
+
+            # Stream tokens as they arrive
+            token_count = 0
+            for text_chunk in streamer:
+                if text_chunk:
+                    yield {
+                        "content": text_chunk,
+                        "done": False,
+                        "tokens_used": token_count,
+                    }
+                    token_count += 1
+
+            # Wait for generation to complete
+            generation_thread.join()
+
+            # Send final "done" chunk
+            yield {
+                "content": "",
+                "done": True,
+                "tokens_used": token_count,
+            }
+
+        except Exception:
+            logger.exception("❌ Single chunk summarization failed")
+            yield {
+                "content": "",
+                "done": True,
+                "error": "Single chunk summarization failed. See server logs for traceback.",
+            }
+
     async def check_health(self) -> bool:
         """
         Check if the HuggingFace model is properly initialized and ready.
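Note that _split_into_chunks is called above and exercised by the new tests, but its body is not part of this diff. A minimal sketch consistent with the tested behaviour (empty text yields no chunks, short text yields one chunk equal to the input, long text yields overlapping windows of at most chunk_chars characters) might look like the following; the real helper in the repo may differ.

def _split_into_chunks(text: str, chunk_chars: int = 4000, overlap: int = 400) -> list:
    """Hypothetical character-window splitter; illustrative only."""
    if not text:
        return []
    if len(text) <= chunk_chars:
        return [text]
    chunks = []
    step = max(1, chunk_chars - overlap)
    for start in range(0, len(text), step):
        piece = text[start:start + chunk_chars]
        if piece:
            chunks.append(piece)
        if start + chunk_chars >= len(text):
            break
    return chunks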
    	
        tests/test_hf_streaming_improvements.py
    ADDED
    
    | @@ -0,0 +1,286 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            """
         | 
| 2 | 
            +
            Tests for HuggingFace streaming summarizer improvements.
         | 
| 3 | 
            +
            """
         | 
| 4 | 
            +
            import pytest
         | 
| 5 | 
            +
            from unittest.mock import AsyncMock, patch, MagicMock
         | 
| 6 | 
            +
            from app.services.hf_streaming_summarizer import HFStreamingSummarizer, _split_into_chunks
         | 
| 7 | 
            +
             | 
| 8 | 
            +
             | 
| 9 | 
            +
            class TestSplitIntoChunks:
         | 
| 10 | 
            +
                """Test the text chunking utility function."""
         | 
| 11 | 
            +
                
         | 
| 12 | 
            +
                def test_split_short_text(self):
         | 
| 13 | 
            +
                    """Test splitting short text that doesn't need chunking."""
         | 
| 14 | 
            +
                    text = "This is a short text."
         | 
| 15 | 
            +
                    chunks = _split_into_chunks(text, chunk_chars=100, overlap=20)
         | 
| 16 | 
            +
                    
         | 
| 17 | 
            +
                    assert len(chunks) == 1
         | 
| 18 | 
            +
                    assert chunks[0] == text
         | 
| 19 | 
            +
                
         | 
| 20 | 
            +
                def test_split_long_text(self):
         | 
| 21 | 
            +
                    """Test splitting long text into multiple chunks."""
         | 
| 22 | 
            +
                    text = "This is a longer text. " * 50  # ~1000 chars
         | 
| 23 | 
            +
                    chunks = _split_into_chunks(text, chunk_chars=200, overlap=50)
         | 
| 24 | 
            +
                    
         | 
| 25 | 
            +
                    assert len(chunks) > 1
         | 
| 26 | 
            +
                    # All chunks should be within reasonable size
         | 
| 27 | 
            +
                    for chunk in chunks:
         | 
| 28 | 
            +
                        assert len(chunk) <= 200
         | 
| 29 | 
            +
                        assert len(chunk) > 0
         | 
| 30 | 
            +
                
         | 
| 31 | 
            +
                def test_chunk_overlap(self):
         | 
| 32 | 
            +
                    """Test that chunks have proper overlap."""
         | 
| 33 | 
            +
                    text = "This is a test text for overlap testing. " * 20  # ~800 chars
         | 
| 34 | 
            +
                    chunks = _split_into_chunks(text, chunk_chars=200, overlap=50)
         | 
| 35 | 
            +
                    
         | 
| 36 | 
            +
                    if len(chunks) > 1:
         | 
| 37 | 
            +
                        # Check that consecutive chunks share some content
         | 
| 38 | 
            +
                        for i in range(len(chunks) - 1):
         | 
| 39 | 
            +
                            # There should be some overlap between consecutive chunks
         | 
| 40 | 
            +
                            assert len(chunks[i]) > 0
         | 
| 41 | 
            +
                            assert len(chunks[i+1]) > 0
         | 
| 42 | 
            +
                
         | 
| 43 | 
            +
                def test_empty_text(self):
         | 
| 44 | 
            +
                    """Test splitting empty text."""
         | 
| 45 | 
            +
                    chunks = _split_into_chunks("", chunk_chars=100, overlap=20)
         | 
| 46 | 
            +
                    assert len(chunks) == 0  # Empty text returns empty list
         | 
| 47 | 
            +
             | 
| 48 | 
            +
             | 
| 49 | 
            +
            class TestHFStreamingSummarizerImprovements:
         | 
| 50 | 
            +
                """Test improvements to HFStreamingSummarizer."""
         | 
| 51 | 
            +
                
         | 
| 52 | 
            +
                @pytest.fixture
         | 
| 53 | 
            +
                def mock_summarizer(self):
         | 
| 54 | 
            +
                    """Create a mock HFStreamingSummarizer for testing."""
         | 
| 55 | 
            +
                    summarizer = HFStreamingSummarizer()
         | 
| 56 | 
            +
                    summarizer.model = MagicMock()
         | 
| 57 | 
            +
                    summarizer.tokenizer = MagicMock()
         | 
| 58 | 
            +
                    return summarizer
         | 
| 59 | 
            +
                
         | 
| 60 | 
            +
                @pytest.mark.asyncio
         | 
| 61 | 
            +
                async def test_recursive_summarization_long_text(self, mock_summarizer):
         | 
| 62 | 
            +
                    """Test recursive summarization for long text."""
         | 
| 63 | 
            +
                    # Mock the _single_chunk_summarize method
         | 
| 64 | 
            +
                    async def mock_single_chunk(text, max_tokens, temp, top_p, prompt):
         | 
| 65 | 
            +
                        yield {"content": f"Summary of: {text[:50]}...", "done": False, "tokens_used": 10}
         | 
| 66 | 
            +
                        yield {"content": "", "done": True, "tokens_used": 10}
         | 
| 67 | 
            +
                    
         | 
| 68 | 
            +
                    mock_summarizer._single_chunk_summarize = mock_single_chunk
         | 
| 69 | 
            +
                    
         | 
| 70 | 
            +
                    # Long text (>1500 chars)
         | 
| 71 | 
            +
                    long_text = "This is a very long text that should trigger recursive summarization. " * 30  # ~2000+ chars
         | 
| 72 | 
            +
                    
         | 
| 73 | 
            +
                    results = []
         | 
| 74 | 
            +
                    async for chunk in mock_summarizer._recursive_summarize(
         | 
| 75 | 
            +
                        long_text, max_new_tokens=100, temperature=0.3, top_p=0.9, prompt="Test prompt"
         | 
| 76 | 
            +
                    ):
         | 
| 77 | 
            +
                        results.append(chunk)
         | 
| 78 | 
            +
                    
         | 
| 79 | 
            +
                    # Should have multiple chunks (one for each text chunk + final summary)
         | 
| 80 | 
            +
                    assert len(results) > 2  # At least 2 chunks + final done signal
         | 
| 81 | 
            +
                    
         | 
| 82 | 
            +
                    # Check that we get proper streaming format
         | 
| 83 | 
            +
                    content_chunks = [r for r in results if r.get("content") and not r.get("done")]
         | 
| 84 | 
            +
                    assert len(content_chunks) > 0
         | 
| 85 | 
            +
                    
         | 
| 86 | 
            +
                    # Should end with done signal
         | 
| 87 | 
            +
                    final_chunk = results[-1]
         | 
| 88 | 
            +
                    assert final_chunk.get("done") is True
         | 
| 89 | 
            +
                
         | 
| 90 | 
            +
                @pytest.mark.asyncio
         | 
| 91 | 
            +
                async def test_recursive_summarization_single_chunk(self, mock_summarizer):
         | 
| 92 | 
            +
                    """Test recursive summarization when text fits in single chunk."""
         | 
| 93 | 
            +
                    # Mock the _single_chunk_summarize method
         | 
| 94 | 
            +
                    async def mock_single_chunk(text, max_tokens, temp, top_p, prompt):
         | 
| 95 | 
            +
                        yield {"content": "Single chunk summary", "done": False, "tokens_used": 5}
         | 
| 96 | 
            +
                        yield {"content": "", "done": True, "tokens_used": 5}
         | 
| 97 | 
            +
                    
         | 
| 98 | 
            +
                    mock_summarizer._single_chunk_summarize = mock_single_chunk
         | 
| 99 | 
            +
                    
         | 
| 100 | 
            +
                    # Text that would fit in single chunk after splitting
         | 
| 101 | 
            +
                    text = "This is a medium length text. " * 20  # ~600 chars
         | 
| 102 | 
            +
                    
         | 
| 103 | 
            +
                    results = []
         | 
| 104 | 
            +
                    async for chunk in mock_summarizer._recursive_summarize(
         | 
| 105 | 
            +
                        text, max_new_tokens=100, temperature=0.3, top_p=0.9, prompt="Test prompt"
         | 
| 106 | 
            +
                    ):
         | 
| 107 | 
            +
                        results.append(chunk)
         | 
| 108 | 
            +
                    
         | 
| 109 | 
            +
                    # Should have at least 2 chunks (content + done)
         | 
| 110 | 
            +
                    assert len(results) >= 2
         | 
| 111 | 
            +
                    
         | 
| 112 | 
            +
                    # Should end with done signal
         | 
| 113 | 
            +
                    final_chunk = results[-1]
         | 
| 114 | 
            +
                    assert final_chunk.get("done") is True
         | 
| 115 | 
            +
                
         | 
| 116 | 
            +
                @pytest.mark.asyncio
         | 
| 117 | 
            +
                async def test_single_chunk_summarize_parameters(self, mock_summarizer):
         | 
| 118 | 
            +
                    """Test that _single_chunk_summarize uses correct parameters."""
         | 
| 119 | 
            +
                    # Mock the tokenizer and model
         | 
| 120 | 
            +
                    mock_summarizer.tokenizer.model_max_length = 1024
         | 
| 121 | 
            +
                    mock_summarizer.tokenizer.pad_token_id = 0
         | 
| 122 | 
            +
                    mock_summarizer.tokenizer.eos_token_id = 1
         | 
| 123 | 
            +
                    
         | 
| 124 | 
            +
                    # Mock the model generation
         | 
| 125 | 
            +
                    mock_streamer = MagicMock()
         | 
| 126 | 
            +
                    mock_streamer.__iter__ = MagicMock(return_value=iter(["test", "summary"]))
         | 
| 127 | 
            +
                    
         | 
| 128 | 
            +
                    with patch('app.services.hf_streaming_summarizer.TextIteratorStreamer', return_value=mock_streamer):
         | 
| 129 | 
            +
                        with patch('app.services.hf_streaming_summarizer.settings') as mock_settings:
         | 
| 130 | 
            +
                            mock_settings.hf_model_id = "test-model"
         | 
| 131 | 
            +
                            
         | 
| 132 | 
            +
                            results = []
         | 
| 133 | 
            +
                            async for chunk in mock_summarizer._single_chunk_summarize(
         | 
| 134 | 
            +
                                "Test text", max_new_tokens=80, temperature=0.3, top_p=0.9, prompt="Test prompt"
         | 
| 135 | 
            +
                            ):
         | 
| 136 | 
            +
                                results.append(chunk)
         | 
| 137 | 
            +
                            
         | 
| 138 | 
            +
                            # Should have content chunks + final done
         | 
| 139 | 
            +
                            assert len(results) >= 2
         | 
| 140 | 
            +
                            
         | 
| 141 | 
            +
                            # Check that generation was called with correct parameters
         | 
| 142 | 
            +
                            mock_summarizer.model.generate.assert_called_once()
         | 
| 143 | 
            +
                            call_kwargs = mock_summarizer.model.generate.call_args[1]
         | 
| 144 | 
            +
                            
         | 
| 145 | 
            +
                            assert call_kwargs["max_new_tokens"] == 80
         | 
| 146 | 
            +
                            assert call_kwargs["temperature"] == 0.3
         | 
| 147 | 
            +
                            assert call_kwargs["top_p"] == 0.9
         | 
| 148 | 
            +
                            assert call_kwargs["length_penalty"] == 1.0  # Should be neutral
         | 
| 149 | 
            +
                            assert call_kwargs["min_new_tokens"] <= 50  # Should be conservative
         | 
| 150 | 
            +
                
         | 
| 151 | 
            +
                @pytest.mark.asyncio
         | 
| 152 | 
            +
                async def test_single_chunk_summarize_defaults(self, mock_summarizer):
         | 
| 153 | 
            +
                    """Test that _single_chunk_summarize uses correct defaults."""
         | 
| 154 | 
            +
                    # Mock the tokenizer and model
         | 
| 155 | 
            +
                    mock_summarizer.tokenizer.model_max_length = 1024
         | 
| 156 | 
            +
                    mock_summarizer.tokenizer.pad_token_id = 0
         | 
| 157 | 
            +
                    mock_summarizer.tokenizer.eos_token_id = 1
         | 
| 158 | 
            +
                    
         | 
| 159 | 
            +
                    # Mock the model generation
         | 
| 160 | 
            +
                    mock_streamer = MagicMock()
         | 
| 161 | 
            +
                    mock_streamer.__iter__ = MagicMock(return_value=iter(["test", "summary"]))
         | 
| 162 | 
            +
                    
         | 
| 163 | 
            +
                    with patch('app.services.hf_streaming_summarizer.TextIteratorStreamer', return_value=mock_streamer):
         | 
| 164 | 
            +
                        with patch('app.services.hf_streaming_summarizer.settings') as mock_settings:
         | 
| 165 | 
            +
                            mock_settings.hf_model_id = "test-model"
         | 
| 166 | 
            +
                            
         | 
| 167 | 
            +
                            results = []
         | 
| 168 | 
            +
                            async for chunk in mock_summarizer._single_chunk_summarize(
         | 
| 169 | 
            +
                                "Test text", max_new_tokens=None, temperature=None, top_p=None, prompt="Test prompt"
         | 
| 170 | 
            +
                            ):
         | 
| 171 | 
            +
                                results.append(chunk)
         | 
| 172 | 
            +
                            
         | 
| 173 | 
            +
                            # Check that generation was called with correct defaults
         | 
| 174 | 
            +
                            mock_summarizer.model.generate.assert_called_once()
         | 
| 175 | 
            +
                            call_kwargs = mock_summarizer.model.generate.call_args[1]
         | 
| 176 | 
            +
                            
         | 
| 177 | 
            +
                            assert call_kwargs["max_new_tokens"] == 80  # Default
         | 
| 178 | 
            +
                            assert call_kwargs["temperature"] == 0.3  # Default
         | 
| 179 | 
            +
                            assert call_kwargs["top_p"] == 0.9  # Default
         | 
| 180 | 
            +
                
         | 
| 181 | 
            +
                @pytest.mark.asyncio
         | 
| 182 | 
            +
                async def test_recursive_summarization_error_handling(self, mock_summarizer):
         | 
| 183 | 
            +
                    """Test error handling in recursive summarization."""
         | 
| 184 | 
            +
                    # Mock _single_chunk_summarize to raise an exception
         | 
| 185 | 
            +
                    async def mock_single_chunk_error(text, max_tokens, temp, top_p, prompt):
         | 
| 186 | 
            +
                        raise Exception("Test error")
         | 
| 187 | 
            +
                        yield  # This line will never be reached, but makes it an async generator
         | 
| 188 | 
            +
                    
         | 
| 189 | 
            +
                    mock_summarizer._single_chunk_summarize = mock_single_chunk_error
         | 
| 190 | 
            +
                    
         | 
| 191 | 
            +
                    long_text = "This is a long text. " * 30
         | 
| 192 | 
            +
                    
         | 
| 193 | 
            +
                    results = []
         | 
| 194 | 
            +
                    async for chunk in mock_summarizer._recursive_summarize(
         | 
| 195 | 
            +
                        long_text, max_new_tokens=100, temperature=0.3, top_p=0.9, prompt="Test prompt"
         | 
| 196 | 
            +
                    ):
         | 
| 197 | 
            +
                        results.append(chunk)
         | 
| 198 | 
            +
                    
         | 
| 199 | 
            +
                    # Should have error chunk
         | 
| 200 | 
            +
                    assert len(results) == 1
         | 
| 201 | 
            +
                    error_chunk = results[0]
         | 
| 202 | 
            +
                    assert error_chunk.get("done") is True
         | 
| 203 | 
            +
                    assert "error" in error_chunk
         | 
| 204 | 
            +
                    assert "Test error" in error_chunk["error"]
         | 
| 205 | 
            +
                
         | 
| 206 | 
            +
                @pytest.mark.asyncio
         | 
| 207 | 
            +
                async def test_single_chunk_summarize_error_handling(self, mock_summarizer):
         | 
| 208 | 
            +
                    """Test error handling in single chunk summarization."""
         | 
| 209 | 
            +
                    # Mock model to raise exception
         | 
| 210 | 
            +
                    mock_summarizer.model.generate.side_effect = Exception("Generation error")
         | 
| 211 | 
            +
                    
         | 
| 212 | 
            +
                    results = []
         | 
| 213 | 
            +
                    async for chunk in mock_summarizer._single_chunk_summarize(
         | 
| 214 | 
            +
                        "Test text", max_new_tokens=80, temperature=0.3, top_p=0.9, prompt="Test prompt"
         | 
| 215 | 
            +
                    ):
         | 
| 216 | 
            +
                        results.append(chunk)
         | 
| 217 | 
            +
                    
         | 
| 218 | 
            +
                    # Should have error chunk
         | 
| 219 | 
            +
                    assert len(results) == 1
         | 
| 220 | 
            +
                    error_chunk = results[0]
         | 
| 221 | 
            +
                    assert error_chunk.get("done") is True
         | 
| 222 | 
            +
                    assert "error" in error_chunk
         | 
| 223 | 
            +
                    assert "Generation error" in error_chunk["error"]
         | 
| 224 | 
            +
             | 
| 225 | 
            +
             | 
| 226 | 
            +
            class TestHFStreamingSummarizerIntegration:
         | 
| 227 | 
            +
                """Integration tests for HFStreamingSummarizer improvements."""
         | 
| 228 | 
            +
                
         | 
| 229 | 
            +
                @pytest.mark.asyncio
         | 
| 230 | 
            +
                async def test_summarize_text_stream_long_text_detection(self):
         | 
| 231 | 
            +
                    """Test that summarize_text_stream detects long text and uses recursive summarization."""
         | 
| 232 | 
            +
                    summarizer = HFStreamingSummarizer()
         | 
| 233 | 
            +
                    
         | 
| 234 | 
            +
                    # Mock the recursive summarization method
         | 
| 235 | 
            +
                async def mock_recursive(text, max_new_tokens, temperature, top_p, prompt):
         | 
| 236 | 
            +
                        yield {"content": "Recursive summary", "done": False, "tokens_used": 10}
         | 
| 237 | 
            +
                        yield {"content": "", "done": True, "tokens_used": 10}
         | 
| 238 | 
            +
                    
         | 
| 239 | 
            +
                    summarizer._recursive_summarize = mock_recursive
         | 
| 240 | 
            +
                    
         | 
| 241 | 
            +
                    # Long text (>1500 chars)
         | 
| 242 | 
            +
                    long_text = "This is a very long text. " * 60  # ~1500+ chars
         | 
| 243 | 
            +
                    
         | 
| 244 | 
            +
                    results = []
         | 
| 245 | 
            +
                    async for chunk in summarizer.summarize_text_stream(long_text):
         | 
| 246 | 
            +
                        results.append(chunk)
         | 
| 247 | 
            +
                    
         | 
| 248 | 
            +
                    # Should have used recursive summarization
         | 
| 249 | 
            +
                    assert len(results) >= 2
         | 
| 250 | 
            +
                    assert results[0]["content"] == "Recursive summary"
         | 
| 251 | 
            +
                    assert results[-1]["done"] is True
         | 
| 252 | 
            +
                
         | 
| 253 | 
            +
                @pytest.mark.asyncio
         | 
| 254 | 
            +
                async def test_summarize_text_stream_short_text_normal_flow(self):
         | 
| 255 | 
            +
                    """Test that summarize_text_stream uses normal flow for short text."""
         | 
| 256 | 
            +
                    summarizer = HFStreamingSummarizer()
         | 
| 257 | 
            +
                    
         | 
| 258 | 
            +
                    # Mock model and tokenizer
         | 
| 259 | 
            +
                    summarizer.model = MagicMock()
         | 
| 260 | 
            +
                    summarizer.tokenizer = MagicMock()
         | 
| 261 | 
            +
                    summarizer.tokenizer.model_max_length = 1024
         | 
| 262 | 
            +
                    summarizer.tokenizer.pad_token_id = 0
         | 
| 263 | 
            +
                    summarizer.tokenizer.eos_token_id = 1
         | 
| 264 | 
            +
                    
         | 
| 265 | 
            +
                    # Mock the streamer
         | 
| 266 | 
            +
                    mock_streamer = MagicMock()
         | 
| 267 | 
            +
                    mock_streamer.__iter__ = MagicMock(return_value=iter(["short", "summary"]))
         | 
| 268 | 
            +
                    
         | 
| 269 | 
            +
                    with patch('app.services.hf_streaming_summarizer.TextIteratorStreamer', return_value=mock_streamer):
         | 
| 270 | 
            +
                        with patch('app.services.hf_streaming_summarizer.settings') as mock_settings:
         | 
| 271 | 
            +
                            mock_settings.hf_model_id = "test-model"
         | 
| 272 | 
            +
                            mock_settings.hf_temperature = 0.3
         | 
| 273 | 
            +
                            mock_settings.hf_top_p = 0.9
         | 
| 274 | 
            +
                            
         | 
| 275 | 
            +
                            # Short text (<1500 chars)
         | 
| 276 | 
            +
                            short_text = "This is a short text."
         | 
| 277 | 
            +
                            
         | 
| 278 | 
            +
                            results = []
         | 
| 279 | 
            +
                            async for chunk in summarizer.summarize_text_stream(short_text):
         | 
| 280 | 
            +
                                results.append(chunk)
         | 
| 281 | 
            +
                            
         | 
| 282 | 
            +
                            # Should have used normal flow (not recursive)
         | 
| 283 | 
            +
                            assert len(results) >= 2
         | 
| 284 | 
            +
                            assert results[0]["content"] == "short"
         | 
| 285 | 
            +
                            assert results[1]["content"] == "summary"
         | 
| 286 | 
            +
                            assert results[-1]["done"] is True
         | 
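
The error-handling tests above pin down a simple chunk contract: each streamed piece is a dict with content/done/tokens_used keys, and any exception raised inside _recursive_summarize or _single_chunk_summarize collapses into exactly one terminal chunk with done=True and the exception text under "error". The following is a minimal, self-contained sketch of that contract only (an illustration of what the tests assert, not the shipped summarizer code; stream_with_error_chunk and demo are hypothetical names):

    import asyncio

    async def stream_with_error_chunk(pieces, fail_with=None):
        """Sketch: yield summary pieces as chunk dicts; fold any error into one terminal chunk."""
        tokens = 0
        try:
            if fail_with is not None:
                raise fail_with
            for piece in pieces:
                tokens += 1
                yield {"content": piece, "done": False, "tokens_used": tokens}
            yield {"content": "", "done": True, "tokens_used": tokens}
        except Exception as exc:
            # Matches the assertions above: a single chunk, done=True, error message included.
            yield {"content": "", "done": True, "error": str(exc)}

    async def demo():
        async for chunk in stream_with_error_chunk(["short", "summary"]):
            print(chunk)
        async for chunk in stream_with_error_chunk([], fail_with=RuntimeError("Test error")):
            print(chunk)  # {'content': '', 'done': True, 'error': 'Test error'}

    if __name__ == "__main__":
        asyncio.run(demo())
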
    	
        tests/test_schemas.py
    CHANGED
    
    | @@ -15,7 +15,7 @@ class TestSummarizeRequest: | |
| 15 |  | 
| 16 | 
             
                    assert request.text == sample_text.strip()
         | 
| 17 | 
             
                    assert request.max_tokens == 256
         | 
| 18 | 
            -
                    assert request.prompt == "Summarize the  | 
| 19 |  | 
| 20 | 
             
                def test_custom_parameters(self):
         | 
| 21 | 
             
                    """Test request with custom parameters."""
         | 
| @@ -73,6 +73,57 @@ class TestSummarizeRequest: | |
| 73 | 
             
                    long_prompt = "x" * 501
         | 
| 74 | 
             
                    with pytest.raises(ValidationError):
         | 
| 75 | 
             
                        SummarizeRequest(text="test", prompt=long_prompt)
         | 
| 76 |  | 
| 77 |  | 
| 78 | 
             
            class TestSummarizeResponse:
         | 
|  | |
| 15 |  | 
| 16 | 
             
                    assert request.text == sample_text.strip()
         | 
| 17 | 
             
                    assert request.max_tokens == 256
         | 
| 18 | 
            +
                    assert request.prompt == "Summarize the key points concisely:"
         | 
| 19 |  | 
| 20 | 
             
                def test_custom_parameters(self):
         | 
| 21 | 
             
                    """Test request with custom parameters."""
         | 
|  | |
| 73 | 
             
                    long_prompt = "x" * 501
         | 
| 74 | 
             
                    with pytest.raises(ValidationError):
         | 
| 75 | 
             
                        SummarizeRequest(text="test", prompt=long_prompt)
         | 
| 76 | 
            +
                
         | 
| 77 | 
            +
                def test_temperature_parameter(self):
         | 
| 78 | 
            +
                    """Test temperature parameter validation."""
         | 
| 79 | 
            +
                    # Valid temperature values
         | 
| 80 | 
            +
                    request = SummarizeRequest(text="test", temperature=0.0)
         | 
| 81 | 
            +
                    assert request.temperature == 0.0
         | 
| 82 | 
            +
                    
         | 
| 83 | 
            +
                    request = SummarizeRequest(text="test", temperature=2.0)
         | 
| 84 | 
            +
                    assert request.temperature == 2.0
         | 
| 85 | 
            +
                    
         | 
| 86 | 
            +
                    request = SummarizeRequest(text="test", temperature=0.3)
         | 
| 87 | 
            +
                    assert request.temperature == 0.3
         | 
| 88 | 
            +
                    
         | 
| 89 | 
            +
                    # Default temperature
         | 
| 90 | 
            +
                    request = SummarizeRequest(text="test")
         | 
| 91 | 
            +
                    assert request.temperature == 0.3
         | 
| 92 | 
            +
                    
         | 
| 93 | 
            +
                    # Invalid temperature values
         | 
| 94 | 
            +
                    with pytest.raises(ValidationError):
         | 
| 95 | 
            +
                        SummarizeRequest(text="test", temperature=-0.1)
         | 
| 96 | 
            +
                    
         | 
| 97 | 
            +
                    with pytest.raises(ValidationError):
         | 
| 98 | 
            +
                        SummarizeRequest(text="test", temperature=2.1)
         | 
| 99 | 
            +
                
         | 
| 100 | 
            +
                def test_top_p_parameter(self):
         | 
| 101 | 
            +
                    """Test top_p parameter validation."""
         | 
| 102 | 
            +
                    # Valid top_p values
         | 
| 103 | 
            +
                    request = SummarizeRequest(text="test", top_p=0.0)
         | 
| 104 | 
            +
                    assert request.top_p == 0.0
         | 
| 105 | 
            +
                    
         | 
| 106 | 
            +
                    request = SummarizeRequest(text="test", top_p=1.0)
         | 
| 107 | 
            +
                    assert request.top_p == 1.0
         | 
| 108 | 
            +
                    
         | 
| 109 | 
            +
                    request = SummarizeRequest(text="test", top_p=0.9)
         | 
| 110 | 
            +
                    assert request.top_p == 0.9
         | 
| 111 | 
            +
                    
         | 
| 112 | 
            +
                    # Default top_p
         | 
| 113 | 
            +
                    request = SummarizeRequest(text="test")
         | 
| 114 | 
            +
                    assert request.top_p == 0.9
         | 
| 115 | 
            +
                    
         | 
| 116 | 
            +
                    # Invalid top_p values
         | 
| 117 | 
            +
                    with pytest.raises(ValidationError):
         | 
| 118 | 
            +
                        SummarizeRequest(text="test", top_p=-0.1)
         | 
| 119 | 
            +
                    
         | 
| 120 | 
            +
                    with pytest.raises(ValidationError):
         | 
| 121 | 
            +
                        SummarizeRequest(text="test", top_p=1.1)
         | 
| 122 | 
            +
                
         | 
| 123 | 
            +
                def test_updated_default_prompt(self):
         | 
| 124 | 
            +
                    """Test that the default prompt has been updated to be more concise."""
         | 
| 125 | 
            +
                    request = SummarizeRequest(text="test")
         | 
| 126 | 
            +
                    assert request.prompt == "Summarize the key points concisely:"
         | 
| 127 |  | 
| 128 |  | 
| 129 | 
             
            class TestSummarizeResponse:
         | 
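
For reference, a quick usage sketch of the new sampling fields these schema tests cover. The import path is assumed from the rest of this commit; pydantic rejects out-of-range values with a ValidationError rather than passing them through to generation:

    from pydantic import ValidationError

    from app.api.v1.schemas import SummarizeRequest  # assumed import path

    # Defaults exercised by the tests above.
    req = SummarizeRequest(text="An example passage to summarize.")
    assert req.temperature == 0.3
    assert req.top_p == 0.9
    assert req.prompt == "Summarize the key points concisely:"

    # Out-of-range sampling values fail validation before reaching the model.
    try:
        SummarizeRequest(text="x", temperature=2.5, top_p=1.2)
    except ValidationError as err:
        print(f"rejected: {len(err.errors())} invalid fields")  # rejected: 2 invalid fields
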
    	
        tests/test_v2_api.py
    CHANGED
    
    | @@ -155,6 +155,144 @@ class TestV2SummarizeStream: | |
| 155 | 
             
                        assert call_args[1]['prompt'] == "Custom prompt"
         | 
| 156 | 
             
                        assert call_args[1]['text'] == "Test text"
         | 
| 157 |  | 
| 158 |  | 
| 159 | 
             
            class TestV2APICompatibility:
         | 
| 160 | 
             
                """Test V2 API compatibility with V1."""
         | 
|  | |
| 155 | 
             
                        assert call_args[1]['prompt'] == "Custom prompt"
         | 
| 156 | 
             
                        assert call_args[1]['text'] == "Test text"
         | 
| 157 |  | 
| 158 | 
            +
                @pytest.mark.integration
         | 
| 159 | 
            +
                def test_v2_adaptive_token_logic_short_text(self, client: TestClient):
         | 
| 160 | 
            +
                    """Test adaptive token logic for short texts (<1500 chars)."""
         | 
| 161 | 
            +
                    with patch('app.services.hf_streaming_summarizer.hf_streaming_service.summarize_text_stream') as mock_stream:
         | 
| 162 | 
            +
                        async def mock_generator():
         | 
| 163 | 
            +
                            yield {"content": "", "done": True}
         | 
| 164 | 
            +
                        
         | 
| 165 | 
            +
                        mock_stream.return_value = mock_generator()
         | 
| 166 | 
            +
                        
         | 
| 167 | 
            +
                        # Short text (500 chars)
         | 
| 168 | 
            +
                        short_text = "This is a short text. " * 20  # ~500 chars
         | 
| 169 | 
            +
                        
         | 
| 170 | 
            +
                        response = client.post(
         | 
| 171 | 
            +
                            "/api/v2/summarize/stream",
         | 
| 172 | 
            +
                            json={
         | 
| 173 | 
            +
                                "text": short_text,
         | 
| 174 | 
            +
                                # Don't specify max_tokens to test adaptive logic
         | 
| 175 | 
            +
                            }
         | 
| 176 | 
            +
                        )
         | 
| 177 | 
            +
                        
         | 
| 178 | 
            +
                        assert response.status_code == 200
         | 
| 179 | 
            +
                        
         | 
| 180 | 
            +
                        # Verify service was called with adaptive max_new_tokens
         | 
| 181 | 
            +
                        mock_stream.assert_called_once()
         | 
| 182 | 
            +
                        call_args = mock_stream.call_args
         | 
| 183 | 
            +
                        
         | 
| 184 | 
            +
                        # For short text, should use 60-100 tokens
         | 
| 185 | 
            +
                        max_new_tokens = call_args[1]['max_new_tokens']
         | 
| 186 | 
            +
                        assert 60 <= max_new_tokens <= 100
         | 
| 187 | 
            +
             | 
| 188 | 
            +
                @pytest.mark.integration
         | 
| 189 | 
            +
                def test_v2_adaptive_token_logic_long_text(self, client: TestClient):
         | 
| 190 | 
            +
                    """Test adaptive token logic for long texts (>1500 chars)."""
         | 
| 191 | 
            +
                    with patch('app.services.hf_streaming_summarizer.hf_streaming_service.summarize_text_stream') as mock_stream:
         | 
| 192 | 
            +
                        async def mock_generator():
         | 
| 193 | 
            +
                            yield {"content": "", "done": True}
         | 
| 194 | 
            +
                        
         | 
| 195 | 
            +
                        mock_stream.return_value = mock_generator()
         | 
| 196 | 
            +
                        
         | 
| 197 | 
            +
                        # Long text (well over the 1500-char threshold)
         | 
| 198 | 
            +
                        long_text = "This is a longer text that should trigger adaptive token logic. " * 40  # ~2500 chars
         | 
| 199 | 
            +
                        
         | 
| 200 | 
            +
                        response = client.post(
         | 
| 201 | 
            +
                            "/api/v2/summarize/stream",
         | 
| 202 | 
            +
                            json={
         | 
| 203 | 
            +
                                "text": long_text,
         | 
| 204 | 
            +
                                # Don't specify max_tokens to test adaptive logic
         | 
| 205 | 
            +
                            }
         | 
| 206 | 
            +
                        )
         | 
| 207 | 
            +
                        
         | 
| 208 | 
            +
                        assert response.status_code == 200
         | 
| 209 | 
            +
                        
         | 
| 210 | 
            +
                        # Verify service was called with adaptive max_new_tokens
         | 
| 211 | 
            +
                        mock_stream.assert_called_once()
         | 
| 212 | 
            +
                        call_args = mock_stream.call_args
         | 
| 213 | 
            +
                        
         | 
| 214 | 
            +
                        # For long text, should use proportional scaling but capped
         | 
| 215 | 
            +
                        max_new_tokens = call_args[1]['max_new_tokens']
         | 
| 216 | 
            +
                        assert 100 <= max_new_tokens <= 400
         | 
| 217 | 
            +
             | 
| 218 | 
            +
                @pytest.mark.integration
         | 
| 219 | 
            +
                def test_v2_temperature_and_top_p_parameters(self, client: TestClient):
         | 
| 220 | 
            +
                    """Test that temperature and top_p parameters are passed correctly."""
         | 
| 221 | 
            +
                    with patch('app.services.hf_streaming_summarizer.hf_streaming_service.summarize_text_stream') as mock_stream:
         | 
| 222 | 
            +
                        async def mock_generator():
         | 
| 223 | 
            +
                            yield {"content": "", "done": True}
         | 
| 224 | 
            +
                        
         | 
| 225 | 
            +
                        mock_stream.return_value = mock_generator()
         | 
| 226 | 
            +
                        
         | 
| 227 | 
            +
                        response = client.post(
         | 
| 228 | 
            +
                            "/api/v2/summarize/stream",
         | 
| 229 | 
            +
                            json={
         | 
| 230 | 
            +
                                "text": "Test text",
         | 
| 231 | 
            +
                                "temperature": 0.5,
         | 
| 232 | 
            +
                                "top_p": 0.8
         | 
| 233 | 
            +
                            }
         | 
| 234 | 
            +
                        )
         | 
| 235 | 
            +
                        
         | 
| 236 | 
            +
                        assert response.status_code == 200
         | 
| 237 | 
            +
                        
         | 
| 238 | 
            +
                        # Verify service was called with correct parameters
         | 
| 239 | 
            +
                        mock_stream.assert_called_once()
         | 
| 240 | 
            +
                        call_args = mock_stream.call_args
         | 
| 241 | 
            +
                        
         | 
| 242 | 
            +
                        assert call_args[1]['temperature'] == 0.5
         | 
| 243 | 
            +
                        assert call_args[1]['top_p'] == 0.8
         | 
| 244 | 
            +
             | 
| 245 | 
            +
                @pytest.mark.integration
         | 
| 246 | 
            +
                def test_v2_default_temperature_and_top_p(self, client: TestClient):
         | 
| 247 | 
            +
                    """Test that default temperature and top_p values are used when not specified."""
         | 
| 248 | 
            +
                    with patch('app.services.hf_streaming_summarizer.hf_streaming_service.summarize_text_stream') as mock_stream:
         | 
| 249 | 
            +
                        async def mock_generator():
         | 
| 250 | 
            +
                            yield {"content": "", "done": True}
         | 
| 251 | 
            +
                        
         | 
| 252 | 
            +
                        mock_stream.return_value = mock_generator()
         | 
| 253 | 
            +
                        
         | 
| 254 | 
            +
                        response = client.post(
         | 
| 255 | 
            +
                            "/api/v2/summarize/stream",
         | 
| 256 | 
            +
                            json={
         | 
| 257 | 
            +
                                "text": "Test text"
         | 
| 258 | 
            +
                                # Don't specify temperature or top_p
         | 
| 259 | 
            +
                            }
         | 
| 260 | 
            +
                        )
         | 
| 261 | 
            +
                        
         | 
| 262 | 
            +
                        assert response.status_code == 200
         | 
| 263 | 
            +
                        
         | 
| 264 | 
            +
                        # Verify service was called with default parameters
         | 
| 265 | 
            +
                        mock_stream.assert_called_once()
         | 
| 266 | 
            +
                        call_args = mock_stream.call_args
         | 
| 267 | 
            +
                        
         | 
| 268 | 
            +
                        assert call_args[1]['temperature'] == 0.3  # Default temperature
         | 
| 269 | 
            +
                        assert call_args[1]['top_p'] == 0.9  # Default top_p
         | 
| 270 | 
            +
             | 
| 271 | 
            +
                @pytest.mark.integration
         | 
| 272 | 
            +
                def test_v2_recursive_summarization_trigger(self, client: TestClient):
         | 
| 273 | 
            +
                    """Test that recursive summarization is triggered for long texts."""
         | 
| 274 | 
            +
                    with patch('app.services.hf_streaming_summarizer.hf_streaming_service.summarize_text_stream') as mock_stream:
         | 
| 275 | 
            +
                        async def mock_generator():
         | 
| 276 | 
            +
                            yield {"content": "", "done": True}
         | 
| 277 | 
            +
                        
         | 
| 278 | 
            +
                        mock_stream.return_value = mock_generator()
         | 
| 279 | 
            +
                        
         | 
| 280 | 
            +
                        # Very long text (>1500 chars) to trigger recursive summarization
         | 
| 281 | 
            +
                        very_long_text = "This is a very long text that should definitely trigger recursive summarization logic. " * 30  # ~2000+ chars
         | 
| 282 | 
            +
                        
         | 
| 283 | 
            +
                        response = client.post(
         | 
| 284 | 
            +
                            "/api/v2/summarize/stream",
         | 
| 285 | 
            +
                            json={
         | 
| 286 | 
            +
                                "text": very_long_text
         | 
| 287 | 
            +
                            }
         | 
| 288 | 
            +
                        )
         | 
| 289 | 
            +
                        
         | 
| 290 | 
            +
                        assert response.status_code == 200
         | 
| 291 | 
            +
                        
         | 
| 292 | 
            +
                        # The service should be called, and internally it should detect long text
         | 
| 293 | 
            +
                        # and use recursive summarization
         | 
| 294 | 
            +
                        mock_stream.assert_called_once()
         | 
| 295 | 
            +
             | 
| 296 |  | 
| 297 | 
             
            class TestV2APICompatibility:
         | 
| 298 | 
             
                """Test V2 API compatibility with V1."""
         |