| """ | |
| Script for an iterative scheme. | |
| Assumptions: | |
| - complete pariwise comparisons available, i.e. evaluations are cheap | |
| - | |
| """ | |
import logging
from typing import Callable, List, Optional

import numpy as np
import pandas as pd
from tqdm import tqdm

from .metrics import mapk, rank_biased_overlap
from .plots import plot_ranks

logger = logging.getLogger(__name__)

tol = 0.001  # convergence tolerance for the reputation-score update
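
# Illustrative sketch (not part of the original module): SelfRank's `evaluator`
# callable is assumed to take keyword arguments a, b, c (model column names, with
# c acting as the judge) plus the benchmark DataFrame, and to return a value in
# [0, 1] expressing how strongly judge c prefers a's outputs over b's. The
# agreement-based rule below is a hypothetical example of such a callable.
def example_agreement_evaluator(a: str, b: str, c: str, df: pd.DataFrame) -> float:
    """Fraction of decided instances where judge c's output matches a rather than b."""
    agrees_a = (df[c] == df[a]) & (df[c] != df[b])
    agrees_b = (df[c] == df[b]) & (df[c] != df[a])
    decided = agrees_a | agrees_b
    if decided.sum() == 0:
        return 0.5  # the judge cannot separate a and b: treat as a tie
    return float(agrees_a.sum() / decided.sum())
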
class SelfRank:
    def __init__(self, MODELS: List, evaluator: Callable, true_ranking: Optional[List] = None, show_progress: Optional[bool] = False):
        self.MODELS = MODELS
        self.N = len(MODELS)
        self.evaluate = evaluator
        self.true_ranking = true_ranking
        self.show_progress = show_progress

    def fit(self, df: pd.DataFrame):
        """
        df: DataFrame where each row is a benchmark instance,
            and there is a column with the output of each model.
        """
        assert set(self.MODELS) == set(df.columns), "Benchmark data models inconsistent with models to be ranked."

        # Build a pairwise preference tensor: y[i, j, k] is judge k's preference for model i over model j
        if self.show_progress:
            pbar = tqdm(total=self.N**3, position=0, leave=False, desc="Evaluations")

        y = np.empty((self.N, self.N, self.N))
        for i, a in enumerate(self.MODELS):
            for j, b in enumerate(self.MODELS):
                for k, c in enumerate(self.MODELS):  # judge
                    # Some checks to limit evaluations
                    if a == b:  # a model compared with itself is a tie
                        y[i, j, k] = 0.5
                        y[j, i, k] = 0.5
                    elif a == c:  # the judge is model a, which prefers itself
                        y[i, j, k] = 1
                        y[j, i, k] = 0
                    elif b == c:  # the judge is model b, which prefers itself
                        y[i, j, k] = 0
                        y[j, i, k] = 1
                    elif j > i:  # evaluate each unordered pair only once
                        y[i, j, k] = self.evaluate(a=a, b=b, c=c, df=df)
                        y[j, i, k] = 1 - y[i, j, k]  # complement in the other direction

                    if self.show_progress:
                        pbar.update(1)
        if self.show_progress:
            pbar.close()

        # Estimate the ranks iteratively
        r = np.ones((self.N,))  # initial reputation score: every judge weighted equally
        n_iter = 0
        while True:
            # Weighted mean over judges k, weighting each judge by its current reputation
            m = np.einsum('ijk,k->ij', y, r) / self.N
            # Aggregate preferences using majority voting
            y_p = np.zeros_like(m)
            for i in np.arange(self.N):
                for j in np.arange(self.N):
                    if j > i:
                        if m[i, j] >= m[j, i]:
                            y_p[i, j] = 1.
                            y_p[j, i] = 0.
                        else:
                            y_p[i, j] = 0.
                            y_p[j, i] = 1.

            # Update the reputation score by normalized number of wins
            wins = y_p.sum(axis=1)
            r_k = wins / wins.max()

            # Terminate when the reputation score converges
            delta = np.sum(np.abs(r - r_k))
            logger.info(f"Iteration {n_iter}: delta={delta}")
            logger.info(f"Reputation score: {r_k}")
            if delta <= tol:
                break
            else:
                n_iter += 1
                r = r_k

        # Get the ranked list from the reputation score
        idx = np.argsort(r_k)[::-1]
        self.ranking = np.array(self.MODELS)[idx].tolist()

        logger.info(f"Estimated ranks (best to worst): {self.ranking}")
        if self.true_ranking is not None:
            logger.info(f"True ranking: {self.true_ranking}")
            logger.info(f"RBO measure: {self.measure()}")
        return self.ranking  # best to worst

    def measure(self, metric='rbo', k=5, p=0.95) -> float:
        """
        Report a metric comparing the estimated ranking against the true ranking.
        """
        if metric not in ['rbo', 'mapk']:
            raise ValueError(f"Metric {metric} not supported (use 'rbo'/'mapk').")
        if not hasattr(self, 'ranking'):
            raise ValueError("Ranking not estimated. Run 'fit' first.")
        if self.true_ranking is None:
            raise ValueError("True ranking not available for metric calculation.")

        if metric == 'mapk':
            if k > len(self.true_ranking):
                logger.warning(f"Requested k={k} exceeds the true ranking length; MAP@k is effectively computed at k={len(self.true_ranking)}.")
            actual = [self.true_ranking[:k]]
            pred = [self.ranking[:k]]
            return mapk(actual, pred, k=k)
        else:  # metric == 'rbo'
            return rank_biased_overlap(self.true_ranking, self.ranking, p=p)

    def plot(self, caselabel="output"):
        if hasattr(self, 'ranking') and (self.true_ranking is not None):
            return plot_ranks(self.true_ranking, self.ranking, "actual", "estimated", caselabel)
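
# Minimal usage sketch under assumptions (not part of the original module): build a toy
# benchmark DataFrame with one output column per model, rank the models with SelfRank
# using the hypothetical `example_agreement_evaluator` above, and score the estimate
# against the construction-time ordering. Because of the relative imports above, run
# this with `python -m <package>.<this_module>` rather than as a standalone script.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    rng = np.random.default_rng(0)

    models = ["model_a", "model_b", "model_c"]
    truth = rng.integers(0, 2, size=200)
    # Each model answers correctly with a different probability, so the true ranking
    # (best to worst) is model_a, model_b, model_c by construction.
    benchmark = pd.DataFrame({
        "model_a": np.where(rng.random(200) < 0.9, truth, 1 - truth),
        "model_b": np.where(rng.random(200) < 0.7, truth, 1 - truth),
        "model_c": np.where(rng.random(200) < 0.5, truth, 1 - truth),
    })

    ranker = SelfRank(MODELS=models,
                      evaluator=example_agreement_evaluator,
                      true_ranking=models,
                      show_progress=True)
    ranker.fit(benchmark)
    print("Estimated ranking:", ranker.ranking)
    print("RBO vs. true ranking:", ranker.measure(metric='rbo'))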