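# fairchem_leaderboard / evaluator.py
# The remaining header lines are residue from the web file viewer this
# source was copied from; they are not part of the module:
#   mshuaibi's picture
#   better error handling
#   7c3b81b
#   raw
#   history blame
#   6.8 kB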
import logging
from pathlib import Path
from typing import Dict, List, Tuple
import numpy as np
import torch
import json
from fairchem.data.omol.modules.evaluator import (
ligand_pocket,
ligand_strain,
geom_conformers,
protonation_energies,
unoptimized_ie_ea,
distance_scaling,
unoptimized_spin_gap,
)
class SubmissionLoadError(Exception):
    """Signal that a submission file could not be read or parsed.

    The loaders in this module raise it (chained to the underlying error)
    when opening or decoding the submitted file fails.
    """
# Dispatch table: evaluation name (the `eval_type` string accepted by
# `evaluate` / `omol_evaluations`) -> evaluator function imported from
# `fairchem.data.omol.modules.evaluator`. Each function is invoked as
# `fn(annotations_data, submission_data)`.
OMOL_EVAL_FUNCTIONS = {
    "Ligand pocket": ligand_pocket,
    "Ligand strain": ligand_strain,
    "Conformers": geom_conformers,
    "Protonation": protonation_energies,
    "IE_EA": unoptimized_ie_ea,
    "Distance scaling": distance_scaling,
    "Spin gap": unoptimized_spin_gap,
}
# Subset name -> the `data_ids` values (as stored in the annotations file)
# that belong to that subset. `s2ef_metrics` uses this to build the
# per-subset boolean masks.
OMOL_DATA_ID_MAPPING = {
    "metal_complexes": ["metal_complexes"],
    "electrolytes": ["elytes"],
    "biomolecules": ["biomolecules"],
    "neutral_organics": [
        "ani2x",
        "orbnet_denali",
        "geom_orca6",
        "trans1x",
        "rgd",
    ],
}
def reorder(ref: np.ndarray, to_reorder: np.ndarray) -> np.ndarray:
    """
    Compute the index permutation that aligns `to_reorder` with `ref`.

    After the call, `to_reorder[ordering]` equals `ref` element by element.

    Example::

        ref        = [c, a, b]
        to_reorder = [b, a, c]
        ordering   = reorder(ref, to_reorder)   # [2, 1, 0]
        assert ref == to_reorder[ordering]

    Parameters
    ----------
    ref : np.ndarray
        Reference array. Must not contain duplicates.
    to_reorder : np.ndarray
        Array to re-order. Must not contain duplicates and must hold
        exactly the same items as `ref`.

    Returns
    -------
    np.ndarray
        The ordering to apply on `to_reorder`.

    Raises
    ------
    AssertionError
        If either input contains duplicates or the two inputs do not hold
        the same set of items.
    """
    # Uniqueness is checked for `ref` first, then `to_reorder`.
    for candidate in (ref, to_reorder):
        assert len(candidate) == len(set(candidate))
    assert set(ref) == set(to_reorder)

    # Position of every item inside `to_reorder`.
    position_of = dict(zip(to_reorder, range(len(to_reorder))))
    return np.array([position_of[key] for key in ref])
def get_order(path_submission: Path, path_annotations: Path):
    """
    Return the permutation that sorts the submission into annotation order.

    Both files are `.npz` archives containing an `ids` array. The result is
    the index array `order` such that `submission_ids[order]` equals the
    annotation `ids`.

    Parameters
    ----------
    path_submission : Path
        Submission archive. Loaded without pickle support, so an
        object-dtype `ids` array is rejected.
    path_annotations : Path
        Annotation archive. Loaded with `allow_pickle=True`.

    Returns
    -------
    np.ndarray
        Ordering to apply to arrays stored in submission order.

    Raises
    ------
    SubmissionLoadError
        If the submission archive or its `ids` entry cannot be loaded.
    Exception
        If the submission and annotation IDs are not the same set.
    AssertionError
        If the submission contains duplicate IDs.
    """
    try:
        # No allow_pickle here: the submission is external input.
        with np.load(path_submission) as data:
            submission_ids = data["ids"]
    except Exception as e:
        raise SubmissionLoadError(
            "Error loading submission file. 'ids' must not be object types."
        ) from e

    # NOTE(review): allow_pickle=True unpickles arbitrary objects; this is
    # only acceptable because the annotations file is presumably produced by
    # the maintainers rather than uploaded by users -- confirm.
    with np.load(path_annotations, allow_pickle=True) as data:
        annotations_ids = data["ids"]

    # Use sets for faster comparison
    submission_set = set(submission_ids)
    annotations_set = set(annotations_ids)

    if submission_set != annotations_set:
        missing_ids = annotations_set - submission_set
        unexpected_ids = submission_set - annotations_set
        details = (
            f"{len(missing_ids)} missing IDs: ({list(missing_ids)[:3]}, ...)\n"
            f"{len(unexpected_ids)} unexpected IDs: ({list(unexpected_ids)[:3]}, ...)"
        )
        raise Exception(f"IDs don't match.\n{details}")

    # Raised explicitly instead of via `assert` so the check still runs
    # under `python -O`; AssertionError is kept so existing handlers see the
    # same exception type as before.
    if len(submission_ids) != len(submission_set):
        unique_ids, counts = np.unique(submission_ids, return_counts=True)
        duplicated = unique_ids[counts > 1].tolist()
        raise AssertionError(
            "Duplicate IDs found in submission. "
            f"{len(duplicated)} duplicated IDs: ({duplicated[:3]}, ...)"
        )

    return reorder(annotations_ids, submission_ids)
def s2ef_metrics(
    annotations_path: Path,
    submission_filename: Path,
    subsets: list = ["all"],
) -> Dict[str, float]:
    """
    Compute energy and forces mean absolute errors for each requested subset.

    The submission archive must provide `ids`, `energy`, `natoms` and a
    single concatenated `forces` array. Forces are split per system using
    the cumulative sum of `natoms`, and both energies and forces are
    re-ordered to match the annotation `ids`.

    Parameters
    ----------
    annotations_path : Path
        `.npz` archive with `ids`, `forces`, `energy` and `data_ids`.
    submission_filename : Path
        `.npz` archive with the predictions.
    subsets : list
        Subset names to report. `"all"` selects every system; any other
        name is looked up in `OMOL_DATA_ID_MAPPING` and selects the systems
        whose `data_ids` value is listed there. The default list is only
        read, never mutated.

    Returns
    -------
    Dict[str, float]
        `"<subset>_energy_mae"` and `"<subset>_forces_mae"` for every
        subset. The forces MAE is the summed absolute error divided by
        `3 * natoms`. Both values are NaN for a subset that selects no
        systems.

    Raises
    ------
    SubmissionLoadError
        If the submission arrays cannot be loaded, split or re-ordered.
    Exception
        If the submitted energies contain infinite values, or for any
        error raised by `get_order`.
    ValueError
        If a system's submitted forces do not have the same number of
        elements as the target forces.
    """
    order = get_order(submission_filename, annotations_path)

    try:
        with np.load(submission_filename) as data:
            forces = data["forces"]
            energy = data["energy"][order]
            # Split the concatenated forces into one chunk per system.
            forces = np.array(
                np.split(forces, np.cumsum(data["natoms"])[:-1]), dtype=object
            )[order]
    except Exception as e:
        raise SubmissionLoadError(
            "Error loading submission data. Make sure you concatenated your "
            "forces and there are no object types."
        ) from e

    # Positions (in annotation order) of infinite energies, computed once.
    inf_energy_ids = np.unique(np.where(np.isinf(energy))[0]).tolist()
    if inf_energy_ids:
        raise Exception(
            f"Inf values found in `energy` for IDs: ({inf_energy_ids[:3]}, ...)"
        )

    with np.load(annotations_path, allow_pickle=True) as data:
        target_forces = data["forces"]
        target_energy = data["energy"]
        target_data_ids = data["data_ids"]

    metrics = {}
    for subset in subsets:
        if subset == "all":
            subset_mask = np.ones(len(target_data_ids), dtype=bool)
        else:
            allowed_ids = set(OMOL_DATA_ID_MAPPING.get(subset, []))
            # dtype=bool keeps this a valid mask even when the list is
            # empty (an empty list would otherwise become a float array).
            subset_mask = np.array(
                [data_id in allowed_ids for data_id in target_data_ids],
                dtype=bool,
            )

        # Nothing selected: report NaN instead of dividing by zero below.
        if not subset_mask.any():
            metrics[f"{subset}_energy_mae"] = float("nan")
            metrics[f"{subset}_forces_mae"] = float("nan")
            continue

        sub_energy = energy[subset_mask]
        sub_target_energy = target_energy[subset_mask]
        metrics[f"{subset}_energy_mae"] = np.mean(
            np.abs(sub_target_energy - sub_energy)
        )

        abs_force_error = 0
        natoms = 0
        for sub_forces, sub_target_forces in zip(
            forces[subset_mask], target_forces[subset_mask]
        ):
            # Arrays of different sizes can still broadcast (e.g. (1, 3)
            # against (N, 3)) and would silently corrupt the error sum.
            if np.size(sub_forces) != np.size(sub_target_forces):
                raise ValueError(
                    f"Forces size mismatch in subset '{subset}': submission "
                    f"has shape {np.shape(sub_forces)}, target has shape "
                    f"{np.shape(sub_target_forces)}."
                )
            abs_force_error += np.sum(np.abs(sub_target_forces - sub_forces))
            natoms += sub_forces.shape[0]

        metrics[f"{subset}_forces_mae"] = (
            abs_force_error / (3 * natoms) if natoms else float("nan")
        )

    return metrics
def _load_json_with_duplicates(path):
    """Parse the JSON file at `path`; return `(data, repeated_top_level_keys)`."""
    repeated_keys = []

    def _collect_repeats(pairs):
        seen = set()
        repeats = []
        for key, _ in pairs:
            if key in seen:
                repeats.append(key)
            seen.add(key)
        # The decoder calls this hook once per JSON object, nested objects
        # first and the outermost object last, so after parsing the list
        # describes the top-level object.
        repeated_keys[:] = repeats
        return dict(pairs)

    with open(path, encoding="utf-8") as f:
        data = json.load(f, object_pairs_hook=_collect_repeats)
    return data, repeated_keys


def omol_evaluations(
    annotations_path: Path,
    submission_filename: Path,
    eval_type: str,
) -> Dict[str, float]:
    """
    Run one of the `OMOL_EVAL_FUNCTIONS` evaluations on a JSON submission.

    Parameters
    ----------
    annotations_path : Path
        JSON file holding the reference entries.
    submission_filename : Path
        JSON file holding the submitted entries. Its top-level keys must be
        exactly the top-level keys of the annotations file.
    eval_type : str
        Key of `OMOL_EVAL_FUNCTIONS` selecting the evaluation to run.

    Returns
    -------
    Dict[str, float]
        Whatever the selected evaluation function returns for
        `(annotations_data, submission_data)`.

    Raises
    ------
    SubmissionLoadError
        If the submission cannot be read or parsed, or its top-level value
        is not a JSON object.
    ValueError
        If the submission and annotation entries differ, if the submission
        repeats a top-level entry, or if `eval_type` is unknown.
    """
    try:
        submission_data, duplicate_entries = _load_json_with_duplicates(
            submission_filename
        )
    except Exception as e:
        raise SubmissionLoadError("Error loading submission file") from e

    # A list or scalar payload would otherwise fail later with an
    # AttributeError on `.keys()`.
    if not isinstance(submission_data, dict):
        raise SubmissionLoadError(
            "Error loading submission file: expected a JSON object at the "
            f"top level, got {type(submission_data).__name__}."
        )

    with open(annotations_path, encoding="utf-8") as f:
        annotations_data = json.load(f)

    submission_entries = set(submission_data.keys())
    annotation_entries = set(annotations_data.keys())

    if submission_entries != annotation_entries:
        missing = annotation_entries - submission_entries
        unexpected = submission_entries - annotation_entries
        raise ValueError(
            f"Submission and annotations entries do not match.\n"
            f"Missing entries in submission: {missing}\n"
            f"Unexpected entries in submission: {unexpected}"
        )

    # The parsed dict can never contain repeated keys (the JSON decoder
    # keeps only the last value), so duplicates are detected during parsing.
    if duplicate_entries:
        raise ValueError(
            "Duplicate entries found in submission. "
            f"{len(duplicate_entries)} repeated keys: "
            f"({duplicate_entries[:3]}, ...)"
        )

    eval_fn = OMOL_EVAL_FUNCTIONS.get(eval_type)
    if eval_fn is None:
        # Same error as `evaluate` instead of "'NoneType' is not callable".
        raise ValueError(f"Unknown eval_type: {eval_type}")

    return eval_fn(annotations_data, submission_data)
def evaluate(
    annotations_path: Path,
    submission_filename: Path,
    eval_type: str,
):
    """
    Dispatch a submission to the evaluation matching `eval_type`.

    Parameters
    ----------
    annotations_path : Path
        Reference file, forwarded unchanged to the selected evaluation.
    submission_filename : Path
        Submitted file, forwarded unchanged to the selected evaluation.
    eval_type : str
        `"Validation"` or `"Test"` runs `s2ef_metrics` over the `"all"`
        subset plus the four named subsets; any key of
        `OMOL_EVAL_FUNCTIONS` runs `omol_evaluations`.

    Returns
    -------
    The metrics returned by the selected evaluation.

    Raises
    ------
    ValueError
        If `eval_type` is none of the above.
    """
    # Energy/forces evaluation over every subset.
    if eval_type in ("Validation", "Test"):
        s2ef_subsets = [
            "all",
            "metal_complexes",
            "electrolytes",
            "biomolecules",
            "neutral_organics",
        ]
        return s2ef_metrics(
            annotations_path,
            submission_filename,
            subsets=s2ef_subsets,
        )

    # Named OMol evaluation task.
    if eval_type in OMOL_EVAL_FUNCTIONS:
        return omol_evaluations(
            annotations_path,
            submission_filename,
            eval_type,
        )

    raise ValueError(f"Unknown eval_type: {eval_type}")