#!/usr/bin/env python3
"""Test script to verify the fixes work correctly."""
import sys
import os
import json
from pathlib import Path

# Add the src directory to the path
sys.path.insert(0, str(Path(__file__).parent / "src"))
def test_dataclass_creation():
    """Test that the AutoEvalColumn dataclass can be created successfully."""
    print("Testing AutoEvalColumn dataclass creation...")
    try:
        from src.display.utils import AutoEvalColumn, fields

        # Test that we can access the fields
        all_fields = fields(AutoEvalColumn)
        print(f"✓ Successfully created AutoEvalColumn with {len(all_fields)} fields")

        # Test that the average field exists
        assert hasattr(AutoEvalColumn, 'average'), "Missing 'average' field"
        print("✓ 'average' field exists")

        # Test that we can access field names
        field_names = [c.name for c in all_fields]
        assert 'average' in field_names, "Average field not in field names"
        print("✓ Average field accessible in field names")

        return True
    except Exception as e:
        print(f"✗ Error: {e}")
        return False
def test_precision_from_str():
    """Test that the Precision.from_str method works correctly."""
    print("Testing Precision.from_str method...")
    try:
        from src.display.utils import Precision

        # Test different precision values
        result1 = Precision.from_str("torch.float16")
        assert result1 == Precision.float16, f"Expected float16, got {result1}"
        print("✓ torch.float16 correctly parsed")

        result2 = Precision.from_str("float16")
        assert result2 == Precision.float16, f"Expected float16, got {result2}"
        print("✓ float16 correctly parsed")

        result3 = Precision.from_str("torch.bfloat16")
        assert result3 == Precision.bfloat16, f"Expected bfloat16, got {result3}"
        print("✓ torch.bfloat16 correctly parsed")

        result4 = Precision.from_str("unknown")
        assert result4 == Precision.Unknown, f"Expected Unknown, got {result4}"
        print("✓ Unknown precision correctly parsed")

        return True
    except Exception as e:
        print(f"✗ Error: {e}")
        return False
def test_eval_result_parsing():
    """Test that the EvalResult can parse JSON files correctly."""
    print("Testing EvalResult JSON parsing...")
    try:
        from src.leaderboard.read_evals import EvalResult
        from src.about import Tasks

        # Create a sample result file
        sample_result = {
            "config": {
                "model_name": "test/model",
                "model_dtype": "torch.float16",
                "model_sha": "abc123"
            },
            "results": {
                "emea_ner": {"f1": 0.85},
                "medline_ner": {"f1": 0.82}
            }
        }

        # Write to temp file
        temp_file = "/tmp/test_result.json"
        with open(temp_file, 'w') as f:
            json.dump(sample_result, f)

        # Test parsing
        result = EvalResult.init_from_json_file(temp_file)
        assert result.full_model == "test/model", f"Expected test/model, got {result.full_model}"
        assert result.org == "test", f"Expected test, got {result.org}"
        assert result.model == "model", f"Expected model, got {result.model}"
        assert result.revision == "abc123", f"Expected abc123, got {result.revision}"
        print("✓ JSON parsing works correctly")

        # Test with missing fields
        sample_result_minimal = {
            "config": {
                "model": "test/model2"
            },
            "results": {
                "emea_ner": {"f1": 0.75}
            }
        }
        temp_file_minimal = "/tmp/test_result_minimal.json"
        with open(temp_file_minimal, 'w') as f:
            json.dump(sample_result_minimal, f)

        result_minimal = EvalResult.init_from_json_file(temp_file_minimal)
        assert result_minimal.full_model == "test/model2", f"Expected test/model2, got {result_minimal.full_model}"
        print("✓ Minimal JSON parsing works correctly")

        # Clean up
        os.remove(temp_file)
        os.remove(temp_file_minimal)

        return True
    except Exception as e:
        print(f"✗ Error: {e}")
        return False
def test_to_dict():
    """Test that EvalResult.to_dict works correctly."""
    print("Testing EvalResult.to_dict method...")
    try:
        from src.leaderboard.read_evals import EvalResult
        from src.display.utils import Precision, ModelType, WeightType

        # Create a test EvalResult
        eval_result = EvalResult(
            eval_name="test_model_float16",
            full_model="test/model",
            org="test",
            model="model",
            revision="abc123",
            results={"emea_ner": 85.0, "medline_ner": 82.0},
            precision=Precision.float16,
            model_type=ModelType.FT,
            weight_type=WeightType.Original,
            architecture="BertForTokenClassification",
            license="MIT",
            likes=10,
            num_params=110,
            date="2023-01-01",
            still_on_hub=True
        )

        # Test to_dict conversion
        result_dict = eval_result.to_dict()

        # Check that all required fields are present
        assert "average" in result_dict, "Missing average field in dict"
        assert result_dict["average"] == 83.5, f"Expected average 83.5, got {result_dict['average']}"
        print("✓ to_dict method works correctly")
        print(f"  - Average: {result_dict['average']}")

        return True
    except Exception as e:
        print(f"✗ Error: {e}")
        return False
def main():
    """Run all tests."""
    print("Running bug fix tests...\n")

    tests = [
        test_dataclass_creation,
        test_precision_from_str,
        test_eval_result_parsing,
        test_to_dict,
    ]

    results = []
    for test in tests:
        print(f"\n{'='*50}")
        try:
            result = test()
            results.append(result)
        except Exception as e:
            print(f"✗ Test {test.__name__} failed with exception: {e}")
            results.append(False)

    print(f"\n{'='*50}")
    print(f"Test Results: {sum(results)}/{len(results)} tests passed")

    if all(results):
        print("🎉 All tests passed! The fixes are working correctly.")
        return 0
    else:
        print("✗ Some tests failed. Please check the output above.")
        return 1


if __name__ == "__main__":
    sys.exit(main())