DataEngEval / test /test_system.py
uparekh01151's picture
Initial commit for DataEngEval
acd8e16
raw
history blame
6.04 kB
"""
Test script to verify the NL→SQL Leaderboard system works correctly.
"""
import os
import sys
import time
# Add src to path for imports
sys.path.append('src')
from evaluator import evaluator, DatasetManager
from models_registry import models_registry
from scoring import scoring_engine
def test_dataset_discovery():
    """Check that the dataset manager reports the NYC Taxi dataset.

    Prints the names returned by ``DatasetManager().get_datasets()`` and
    returns True when ``"nyc_taxi_small"`` is among them, False otherwise.
    """
    print("Testing dataset discovery...")
    available = DatasetManager().get_datasets()
    print(f"Found datasets: {list(available.keys())}")

    found = "nyc_taxi_small" in available
    # Report the outcome first, then hand the same flag back to the caller.
    print("✓ NYC Taxi dataset found" if found else "✗ NYC Taxi dataset not found")
    return found
def test_models_loading():
    """Check that the models registry returns at least one model.

    Prints the ``name`` of every entry from ``models_registry.get_models()``
    and returns True when the collection is non-empty, False otherwise.
    """
    print("\nTesting models loading...")
    loaded = models_registry.get_models()
    names = [entry.name for entry in loaded]
    print(f"Found models: {names}")

    # Guard clause: an empty registry is the only failure condition here.
    if len(loaded) <= 0:
        print("✗ No models found")
        return False

    print("✓ Models loaded successfully")
    return True
def test_database_creation():
    """Check that a database file can be built for the NYC Taxi dataset.

    Calls ``DatasetManager().create_database("nyc_taxi_small")`` and verifies
    that the returned path exists on disk. The file is deleted again after a
    successful check. Any exception is reported and treated as a failure.

    Returns:
        True when the database file was created, False otherwise.
    """
    print("\nTesting database creation...")
    try:
        created_path = DatasetManager().create_database("nyc_taxi_small")

        if not os.path.exists(created_path):
            print("✗ Database file not created")
            return False

        print("✓ Database created successfully")
        # Remove the artifact so the check leaves nothing behind.
        os.remove(created_path)
        return True
    except Exception as e:
        print(f"✗ Database creation failed: {e}")
        return False
def test_cases_loading():
    """Check that test cases can be loaded for the NYC Taxi dataset.

    Calls ``DatasetManager().load_cases("nyc_taxi_small")`` and prints how
    many cases came back. Any exception is reported and treated as a failure.

    Returns:
        True when at least one case was loaded, False otherwise.
    """
    print("\nTesting cases loading...")
    try:
        loaded_cases = DatasetManager().load_cases("nyc_taxi_small")
        case_count = len(loaded_cases)
        print(f"Found {case_count} test cases")

        if case_count <= 0:
            print("✗ No test cases found")
            return False

        print("✓ Test cases loaded successfully")
        return True
    except Exception as e:
        print(f"✗ Cases loading failed: {e}")
        return False
def test_prompt_templates():
    """Check that a prompt template file exists for every dialect.

    Looks for ``prompts/template_<dialect>.txt`` (relative to the current
    working directory) for presto, bigquery and snowflake, printing one
    status line per dialect.

    Returns:
        True when all three template files exist, False otherwise.
    """
    print("\nTesting prompt templates...")
    missing_count = 0
    for name in ("presto", "bigquery", "snowflake"):
        if not os.path.exists(f"prompts/template_{name}.txt"):
            print(f"✗ {name} template not found")
            missing_count += 1
            continue
        print(f"✓ {name} template found")

    # Every dialect is always checked; the result only summarizes the misses.
    return missing_count == 0
def test_scoring_engine():
    """Check that the scoring engine yields a composite score in [0, 1].

    Builds a ``Metrics`` instance from fixed sample values, passes it to
    ``scoring_engine.compute_composite_score`` and verifies the range of the
    result. Any exception (including a failed import) is reported and treated
    as a failure.

    Returns:
        True when the score lies between 0.0 and 1.0 inclusive, else False.
    """
    print("\nTesting scoring engine...")
    try:
        from scoring import Metrics

        # Fixed sample values used purely to exercise the computation.
        sample_values = {
            "correctness_exact": 1.0,
            "result_match_f1": 0.8,
            "exec_success": 1.0,
            "latency_ms": 100.0,
            "readability": 0.9,
            "dialect_ok": 1.0,
        }
        composite = scoring_engine.compute_composite_score(Metrics(**sample_values))
        print(f"✓ Composite score computed: {composite}")

        if not (0.0 <= composite <= 1.0):
            print("✗ Score is out of valid range")
            return False

        print("✓ Score is in valid range")
        return True
    except Exception as e:
        print(f"✗ Scoring engine test failed: {e}")
        return False
def test_sql_execution():
    """Check that DuckDB can create a table, insert rows and run a query.

    Opens an in-memory DuckDB database, creates a two-row table and runs a
    ``COUNT(*)`` query whose result is fetched as a DataFrame. Any exception
    (including a failed ``duckdb`` import) is reported and treated as a
    failure.

    Returns:
        True when every statement executed successfully, False otherwise.
    """
    print("\nTesting SQL execution...")
    try:
        import duckdb

        # Create a simple test database
        conn = duckdb.connect(":memory:")
        try:
            conn.execute("CREATE TABLE test (id INTEGER, name VARCHAR(10))")
            conn.execute("INSERT INTO test VALUES (1, 'Alice'), (2, 'Bob')")

            # Test query
            result = conn.execute("SELECT COUNT(*) FROM test").fetchdf()
            print(f"✓ SQL execution successful: {result.iloc[0, 0]} rows")
        finally:
            # Close the connection on the failure path too; previously a
            # failing statement skipped close() and leaked the connection.
            conn.close()
        return True
    except Exception as e:
        print(f"✗ SQL execution failed: {e}")
        return False
def test_sqlglot_transpilation():
    """Check that sqlglot can parse a query and render it per dialect.

    Parses a fixed ``SELECT COUNT(*)`` statement once and renders it for
    presto, bigquery and snowflake, printing each rendering. Any exception
    (including a failed ``sqlglot`` import) is reported and treated as a
    failure.

    Returns:
        True when all renderings succeeded, False otherwise.
    """
    print("\nTesting SQL transpilation...")
    try:
        import sqlglot

        # Parse once; the same tree is rendered for every target dialect.
        tree = sqlglot.parse_one("SELECT COUNT(*) FROM trips")

        for target in ("presto", "bigquery", "snowflake"):
            rendered = tree.sql(dialect=target)
            print(f"✓ {target} transpilation: {rendered}")
    except Exception as e:
        print(f"✗ SQL transpilation failed: {e}")
        return False
    return True
def main():
    """Run every system check in order and print a summary.

    Each check counts as passed only when it returns a truthy value. A check
    that raises is reported and counted as failed, and the remaining checks
    still run.

    Returns:
        True when every check passed, False otherwise.
    """
    print("NL→SQL Leaderboard System Test")
    print("=" * 40)

    checks = (
        test_dataset_discovery,
        test_models_loading,
        test_database_creation,
        test_cases_loading,
        test_prompt_templates,
        test_scoring_engine,
        test_sql_execution,
        test_sqlglot_transpilation,
    )

    passed = 0
    for check in checks:
        try:
            succeeded = bool(check())
        except Exception as e:
            print(f"✗ Test {check.__name__} failed with exception: {e}")
            continue
        if succeeded:
            passed += 1

    total = len(checks)
    print("\n" + "=" * 40)
    print(f"Test Results: {passed}/{total} tests passed")

    all_passed = passed == total
    if all_passed:
        print("🎉 All tests passed! The system is ready to use.")
    else:
        print("❌ Some tests failed. Please check the issues above.")
    return all_passed
if __name__ == "__main__":
    # Map the overall result to a process exit status: 0 on success, 1 otherwise.
    exit_status = 0 if main() else 1
    sys.exit(exit_status)