Spaces:
Sleeping
Sleeping
| """ | |
| Test script for Advanced RAG features | |
| Demonstrates new capabilities: multiple texts/images indexing and advanced RAG chat | |
| """ | |
| import requests | |
| import json | |
| from typing import List, Optional | |
class AdvancedRAGTester:
    """Test client for the Advanced RAG API.

    Each ``test_*`` method sends one HTTP request, prints a human-readable
    report to stdout and returns the decoded JSON body, or ``None`` when the
    request fails.

    Attributes:
        base_url: Root URL of the API; endpoint paths are appended to it.
        timeout: Seconds to wait for the server on each request, or ``None``
            to wait indefinitely.
    """

    # Horizontal rule printed around section titles.
    _RULE = "=" * 60

    def __init__(self, base_url: str = "http://localhost:8000", timeout: Optional[float] = 60.0):
        """Store the connection settings; no request is made here.

        Args:
            base_url: Root URL of the API.
            timeout: Per-request timeout in seconds. Without one, a stalled
                server would block the script forever.
        """
        self.base_url = base_url
        self.timeout = timeout

    @staticmethod
    def _print_request_error(exc: Exception) -> None:
        """Print a failed request's error and, if the server replied, its body."""
        print(f"\n✗ ERROR: {exc}")
        response = getattr(exc, 'response', None)
        # Compare against None explicitly: a requests.Response is falsy for
        # 4xx/5xx statuses, which is exactly when the body is worth showing.
        if response is not None:
            print(f" Response: {response.text}")

    def test_multiple_index(self, doc_id: str, texts: List[str], image_paths: Optional[List[str]] = None) -> Optional[dict]:
        """
        Test indexing with multiple texts and images

        POSTs a multipart form to ``{base_url}/index``.

        Args:
            doc_id: Document ID
            texts: List of texts (max 10; extra items are dropped with a warning)
            image_paths: List of image file paths (max 10; extra items are
                dropped with a warning, unreadable files are skipped)

        Returns:
            The decoded JSON response, or ``None`` if the request failed or the
            body was not valid JSON.

        Raises:
            KeyError: If the response JSON lacks the ``id`` or ``message`` key.
        """
        print(f"\n{self._RULE}")
        print(f"TEST: Indexing document '{doc_id}' with multiple texts/images")
        print(f"{self._RULE}")

        # Prepare form data
        data = {'id': doc_id}
        if texts:
            if len(texts) > 10:
                print("WARNING: Maximum 10 texts allowed. Taking first 10.")
                texts = texts[:10]
            data['texts'] = texts
            print(f"✓ Texts: {len(texts)} items")

        files = []
        # Files are opened inside the try so the finally clause closes every
        # handle even if opening a later file or the request itself fails.
        try:
            if image_paths:
                if len(image_paths) > 10:
                    print("WARNING: Maximum 10 images allowed. Taking first 10.")
                    image_paths = image_paths[:10]
                for img_path in image_paths:
                    try:
                        files.append(('images', open(img_path, 'rb')))
                    except FileNotFoundError:
                        print(f"WARNING: Image not found: {img_path}")
                    except OSError as exc:
                        # e.g. permission denied, or the path is a directory
                        print(f"WARNING: Cannot open image {img_path}: {exc}")
                print(f"✓ Images: {len(files)} files")

            # Make request
            response = requests.post(
                f"{self.base_url}/index",
                data=data,
                files=files,
                timeout=self.timeout,
            )
            response.raise_for_status()
            result = response.json()
            print("\n✓ SUCCESS")
            print(f" - Document ID: {result['id']}")
            print(f" - Message: {result['message']}")
            return result
        except (requests.exceptions.RequestException, ValueError) as e:
            # ValueError covers a non-JSON body on Requests releases where the
            # decode error is not a RequestException subclass.
            self._print_request_error(e)
            return None
        finally:
            # Close file handles
            for _, file_obj in files:
                file_obj.close()

    def test_advanced_rag_chat(
        self,
        message: str,
        hf_token: Optional[str] = None,
        use_advanced_rag: bool = True,
        use_reranking: bool = True,
        use_compression: bool = True,
        top_k: int = 3,
        score_threshold: float = 0.5
    ) -> Optional[dict]:
        """
        Test advanced RAG chat

        POSTs a JSON payload to ``{base_url}/chat`` with ``use_rag`` always on,
        then prints the answer, the retrieved context and, when present, the
        pipeline statistics.

        Args:
            message: User question
            hf_token: Hugging Face token (optional; omitted from the payload when falsy)
            use_advanced_rag: Use advanced RAG pipeline
            use_reranking: Enable reranking
            use_compression: Enable context compression
            top_k: Number of documents to retrieve
            score_threshold: Minimum relevance score

        Returns:
            The decoded JSON response, or ``None`` if the request failed or the
            body was not valid JSON.

        Raises:
            KeyError: If the response JSON lacks a key this report reads
                (``response``, ``context_used``, or per-context ``id``,
                ``confidence``, ``metadata``).
        """
        print(f"\n{self._RULE}")
        print("TEST: Advanced RAG Chat")
        print(f"{self._RULE}")
        print(f"Question: {message}")
        print(f"Advanced RAG: {use_advanced_rag}")
        print(f"Reranking: {use_reranking}")
        print(f"Compression: {use_compression}")

        payload = {
            'message': message,
            'use_rag': True,
            'use_advanced_rag': use_advanced_rag,
            'use_reranking': use_reranking,
            'use_compression': use_compression,
            'top_k': top_k,
            'score_threshold': score_threshold,
        }
        if hf_token:
            payload['hf_token'] = hf_token

        try:
            response = requests.post(
                f"{self.base_url}/chat",
                json=payload,
                timeout=self.timeout,
            )
            response.raise_for_status()
            result = response.json()
        except (requests.exceptions.RequestException, ValueError) as e:
            # ValueError covers a non-JSON body on Requests releases where the
            # decode error is not a RequestException subclass.
            self._print_request_error(e)
            return None

        print("\n✓ SUCCESS")
        print("\n--- Answer ---")
        print(result['response'])

        print(f"\n--- Retrieved Context ({len(result['context_used'])} documents) ---")
        for i, ctx in enumerate(result['context_used'], 1):
            print(f"{i}. [{ctx['id']}] Confidence: {ctx['confidence']:.2%}")
            text_preview = ctx['metadata'].get('text', '')[:100]
            print(f" Text: {text_preview}...")

        if result.get('rag_stats'):
            print("\n--- RAG Pipeline Statistics ---")
            stats = result['rag_stats']
            print(f" Original query: {stats.get('original_query')}")
            print(f" Expanded queries: {stats.get('expanded_queries')}")
            print(f" Initial results: {stats.get('initial_results')}")
            print(f" After reranking: {stats.get('after_rerank')}")
            print(f" After compression: {stats.get('after_compression')}")
        return result

    def compare_basic_vs_advanced_rag(self, message: str, hf_token: Optional[str] = None) -> None:
        """Compare basic RAG vs advanced RAG side by side

        Sends the same question twice through ``test_advanced_rag_chat``, first
        with ``use_advanced_rag=False`` and then with ``True``, and prints a
        summary. The summary body is printed only when both calls succeeded.

        Args:
            message: User question
            hf_token: Hugging Face token (optional)
        """
        print(f"\n{self._RULE}")
        print("COMPARISON: Basic RAG vs Advanced RAG")
        print(f"{self._RULE}")
        print(f"Question: {message}\n")

        # Test Basic RAG
        print("\n--- BASIC RAG ---")
        basic_result = self.test_advanced_rag_chat(
            message=message,
            hf_token=hf_token,
            use_advanced_rag=False
        )

        # Test Advanced RAG
        print("\n--- ADVANCED RAG ---")
        advanced_result = self.test_advanced_rag_chat(
            message=message,
            hf_token=hf_token,
            use_advanced_rag=True
        )

        # Compare
        print(f"\n{self._RULE}")
        print("COMPARISON SUMMARY")
        print(f"{self._RULE}")
        if basic_result and advanced_result:
            print("Basic RAG:")
            print(f" - Retrieved docs: {len(basic_result['context_used'])}")
            print("\nAdvanced RAG:")
            print(f" - Retrieved docs: {len(advanced_result['context_used'])}")
            if advanced_result.get('rag_stats'):
                stats = advanced_result['rag_stats']
                print(f" - Query expansion: {len(stats.get('expanded_queries', []))} variants")
                print(f" - Initial retrieval: {stats.get('initial_results', 0)} docs")
                print(f" - After reranking: {stats.get('after_rerank', 0)} docs")
def main():
    """Run the demo: index two documents, then exercise the chat endpoint.

    Five numbered steps are printed in order: two indexing calls, a basic
    chat, an advanced chat, and a basic-vs-advanced comparison. No Hugging
    Face token is passed to any of them.
    """
    rule = "=" * 60
    tester = AdvancedRAGTester()

    print(rule)
    print("ADVANCED RAG FEATURE TESTS")
    print(rule)

    # Step 1: a festival document made of several short text fragments
    # (no images in this demo).
    festival_facts = [
        "Festival âm nhạc quốc tế Hà Nội 2025",
        "Thời gian: 15-17 tháng 11 năm 2025",
        "Địa điểm: Công viên Thống Nhất, Hà Nội",
        "Line-up: Sơn Tùng MTP, Đen Vâu, Hoàng Thùy Linh, Mỹ Tâm",
        "Giá vé: Early bird 500.000đ, VIP 2.000.000đ",
        "Dự kiến 50.000 khán giả tham dự",
        "3 sân khấu chính, 5 food court, khu vực cắm trại",
    ]
    print("\n\n### TEST 1: Index Multiple Texts ###")
    tester.test_multiple_index(
        doc_id="event_music_festival_2025",
        texts=festival_facts,
    )

    # Step 2: a second, unrelated document.
    safety_rules = [
        "Vũ khí và đồ vật nguy hiểm bị cấm mang vào sự kiện",
        "Dao, kiếm, súng và các loại vũ khí nguy hiểm nghiêm cấm",
        "An ninh sẽ kiểm tra tất cả túi xách và đồ mang theo",
        "Vi phạm sẽ bị tịch thu và có thể bị trục xuất khỏi sự kiện",
    ]
    print("\n\n### TEST 2: Index Another Document ###")
    tester.test_multiple_index(
        doc_id="safety_guidelines",
        texts=safety_rules,
    )

    # Step 3: chat with the advanced pipeline switched off.
    print("\n\n### TEST 3: Basic RAG Chat (No LLM) ###")
    tester.test_advanced_rag_chat(
        message="Festival Hà Nội diễn ra khi nào?",
        use_advanced_rag=False,
    )

    # Step 4: chat with the advanced pipeline, reranking and compression on.
    print("\n\n### TEST 4: Advanced RAG Chat (No LLM) ###")
    tester.test_advanced_rag_chat(
        message="Festival Hà Nội diễn ra khi nào và có những nghệ sĩ nào?",
        use_advanced_rag=True,
        use_reranking=True,
        use_compression=True,
    )

    # Step 5: the same question through both pipelines.
    print("\n\n### TEST 5: Comparison Test ###")
    tester.compare_basic_vs_advanced_rag(
        message="Dao có được mang vào sự kiện không?",
    )

    print("\n\n" + rule)
    print("ALL TESTS COMPLETED")
    print(rule)
    print("\nNOTE: To test with actual LLM responses, add your Hugging Face token:")
    print(" tester.test_advanced_rag_chat(message='...', hf_token='hf_xxxxx')")


# Run the demo only when executed as a script, not when imported.
if __name__ == "__main__":
    main()