| """ | |
| Test script to verify the NL→SQL Leaderboard system works correctly. | |
| """ | |
| import os | |
| import sys | |
| import time | |
| # Add src to path for imports | |
| sys.path.append('src') | |
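# Importing the project modules below doubles as a basic smoke test:
# an ImportError here means the modules under src/ are missing or broken.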
from evaluator import evaluator, DatasetManager
from models_registry import models_registry
from scoring import scoring_engine


def test_dataset_discovery():
    """Test that datasets are discovered correctly."""
    print("Testing dataset discovery...")
    dataset_manager = DatasetManager()
    datasets = dataset_manager.get_datasets()
    print(f"Found datasets: {list(datasets.keys())}")
    if "nyc_taxi_small" in datasets:
        print("✓ NYC Taxi dataset found")
        return True
    else:
        print("✗ NYC Taxi dataset not found")
        return False


def test_models_loading():
    """Test that models are loaded correctly."""
    print("\nTesting model loading...")
    models = models_registry.get_models()
    print(f"Found models: {[model.name for model in models]}")
    if len(models) > 0:
        print("✓ Models loaded successfully")
        return True
    else:
        print("✗ No models found")
        return False


def test_database_creation():
    """Test database creation for the NYC Taxi dataset."""
    print("\nTesting database creation...")
    try:
        dataset_manager = DatasetManager()
        db_path = dataset_manager.create_database("nyc_taxi_small")
        if os.path.exists(db_path):
            print("✓ Database created successfully")
            # Clean up
            os.remove(db_path)
            return True
        else:
            print("✗ Database file not created")
            return False
    except Exception as e:
        print(f"✗ Database creation failed: {e}")
        return False


def test_cases_loading():
    """Test loading test cases."""
    print("\nTesting test case loading...")
    try:
        dataset_manager = DatasetManager()
        cases = dataset_manager.load_cases("nyc_taxi_small")
        print(f"Found {len(cases)} test cases")
        if len(cases) > 0:
            print("✓ Test cases loaded successfully")
            return True
        else:
            print("✗ No test cases found")
            return False
    except Exception as e:
        print(f"✗ Test case loading failed: {e}")
        return False


def test_prompt_templates():
    """Test that prompt templates exist."""
    print("\nTesting prompt templates...")
    dialects = ["presto", "bigquery", "snowflake"]
    all_exist = True
    for dialect in dialects:
        template_path = f"prompts/template_{dialect}.txt"
        if os.path.exists(template_path):
            print(f"✓ {dialect} template found")
        else:
            print(f"✗ {dialect} template not found")
            all_exist = False
    return all_exist


def test_scoring_engine():
    """Test the scoring engine."""
    print("\nTesting scoring engine...")
    try:
        from scoring import Metrics

        # Test with sample metrics
        metrics = Metrics(
            correctness_exact=1.0,
            result_match_f1=0.8,
            exec_success=1.0,
            latency_ms=100.0,
            readability=0.9,
            dialect_ok=1.0
        )
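        # The engine folds the six metrics into a single composite score;
        # the contract checked below is that it lands in [0.0, 1.0].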
        score = scoring_engine.compute_composite_score(metrics)
        print(f"✓ Composite score computed: {score}")
        if 0.0 <= score <= 1.0:
            print("✓ Score is in valid range")
            return True
        else:
            print("✗ Score is out of valid range")
            return False
    except Exception as e:
        print(f"✗ Scoring engine test failed: {e}")
        return False


def test_sql_execution():
    """Test SQL execution with DuckDB."""
    print("\nTesting SQL execution...")
    try:
        import duckdb

        # Create a simple test database
        conn = duckdb.connect(":memory:")
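        # ":memory:" is an ephemeral in-memory database, so this test
        # leaves no files behind.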
        conn.execute("CREATE TABLE test (id INTEGER, name VARCHAR(10))")
        conn.execute("INSERT INTO test VALUES (1, 'Alice'), (2, 'Bob')")
        # Test query
        result = conn.execute("SELECT COUNT(*) FROM test").fetchdf()
        print(f"✓ SQL execution successful: {result.iloc[0, 0]} rows")
        conn.close()
        return True
    except Exception as e:
        print(f"✗ SQL execution failed: {e}")
        return False


def test_sqlglot_transpilation():
    """Test SQL transpilation with sqlglot."""
    print("\nTesting SQL transpilation...")
    try:
        import sqlglot

        # Test simple query
        sql = "SELECT COUNT(*) FROM trips"
        parsed = sqlglot.parse_one(sql)
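        # parse_one builds a dialect-agnostic AST; each .sql(dialect=...)
        # call below renders it in the target dialect's syntax.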
        # Transpile to different dialects
        dialects = ["presto", "bigquery", "snowflake"]
        for dialect in dialects:
            transpiled = parsed.sql(dialect=dialect)
            print(f"✓ {dialect} transpilation: {transpiled}")
        return True
    except Exception as e:
        print(f"✗ SQL transpilation failed: {e}")
        return False


def main():
    """Run all tests."""
    print("NL→SQL Leaderboard System Test")
    print("=" * 40)
    tests = [
        test_dataset_discovery,
        test_models_loading,
        test_database_creation,
        test_cases_loading,
        test_prompt_templates,
        test_scoring_engine,
        test_sql_execution,
        test_sqlglot_transpilation,
    ]
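    # Each test returns True/False; exceptions are caught here so a single
    # failing test cannot abort the whole run.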
    passed = 0
    total = len(tests)
    for test in tests:
        try:
            if test():
                passed += 1
        except Exception as e:
            print(f"✗ Test {test.__name__} failed with exception: {e}")

    print("\n" + "=" * 40)
    print(f"Test Results: {passed}/{total} tests passed")
    if passed == total:
        print("🎉 All tests passed! The system is ready to use.")
        return True
    else:
        print("❌ Some tests failed. Please check the issues above.")
        return False


if __name__ == "__main__":
    success = main()
    sys.exit(0 if success else 1)