Source code for pykeen.evaluation.rank_based_evaluator

# -*- coding: utf-8 -*-

"""Implementation of the rank-based evaluator."""

import itertools as itt
import logging
from collections import defaultdict
from dataclasses import dataclass, field, fields
from typing import DefaultDict, Dict, Iterable, List, Optional, Sequence, Tuple, Union

import numpy as np
import pandas as pd
import torch
from dataclasses_json import dataclass_json
from scipy import stats

from .evaluator import Evaluator, MetricResults
from ..typing import MappedTriples
from ..utils import fix_dataclass_init_docs

# Public API of this module, i.e., the names exported by ``from ... import *``.
__all__ = ['compute_rank_from_scores', 'RankBasedEvaluator', 'RankBasedMetricResults']

logger = logging.getLogger(__name__)

# Sides of a triple for which ranks are computed; "both" refers to the pooled head and tail ranks.
SIDE_HEAD = 'head'
SIDE_TAIL = 'tail'
SIDE_BOTH = 'both'
SIDES = {SIDE_HEAD, SIDE_TAIL, SIDE_BOTH}

# Strategies for placing the true triple among candidates that received the same score.
RANK_OPTIMISTIC = 'optimistic'
RANK_PESSIMISTIC = 'pessimistic'
RANK_REALISTIC = 'realistic'
RANK_TYPES = {RANK_OPTIMISTIC, RANK_PESSIMISTIC, RANK_REALISTIC}

# For each rank type, the name of the rank type holding its chance-level (random scoring) expectation,
# or None if no such expectation is implemented.
RANK_EXPECTED_REALISTIC = 'expected_realistic'
EXPECTED_RANKS = {
    RANK_REALISTIC: RANK_EXPECTED_REALISTIC,
    RANK_OPTIMISTIC: None,  # TODO - research problem
    RANK_PESSIMISTIC: None,  # TODO - research problem
}

# Names of the aggregated rank metrics.
ARITHMETIC_MEAN_RANK = 'arithmetic_mean_rank'  # also known as mean rank (MR)
GEOMETRIC_MEAN_RANK = 'geometric_mean_rank'
HARMONIC_MEAN_RANK = 'harmonic_mean_rank'
MEDIAN_RANK = 'median_rank'
INVERSE_ARITHMETIC_MEAN_RANK = 'inverse_arithmetic_mean_rank'
INVERSE_GEOMETRIC_MEAN_RANK = 'inverse_geometric_mean_rank'
INVERSE_HARMONIC_MEAN_RANK = 'inverse_harmonic_mean_rank'  # also known as mean reciprocal rank (MRR)
INVERSE_MEDIAN_RANK = 'inverse_median_rank'

# Names of the dispersion statistics over ranks.
RANK_STD = 'rank_std'
RANK_VARIANCE = 'rank_var'
RANK_MAD = 'rank_mad'


def _reciprocal_of(aggregate):
    """Wrap an aggregation function so that it returns the reciprocal of the aggregated value."""

    def _inverted(ranks):
        return np.reciprocal(aggregate(ranks))

    return _inverted


# Metric name -> function mapping an array of ranks to a scalar.
all_type_funcs = {
    ARITHMETIC_MEAN_RANK: np.mean,  # MR
    HARMONIC_MEAN_RANK: stats.hmean,
    GEOMETRIC_MEAN_RANK: stats.gmean,
    MEDIAN_RANK: np.median,
    INVERSE_ARITHMETIC_MEAN_RANK: _reciprocal_of(np.mean),
    INVERSE_GEOMETRIC_MEAN_RANK: _reciprocal_of(stats.gmean),
    # 1 / harmonic mean == arithmetic mean of reciprocal ranks, i.e. MRR
    INVERSE_HARMONIC_MEAN_RANK: _reciprocal_of(stats.hmean),
    INVERSE_MEDIAN_RANK: _reciprocal_of(np.median),
    # dispersion statistics
    RANK_STD: np.std,
    RANK_VARIANCE: np.var,
    RANK_MAD: stats.median_abs_deviation,
}

# Chance-adjusted metrics; these are only computed for the realistic rank type.
ADJUSTED_ARITHMETIC_MEAN_RANK = 'adjusted_arithmetic_mean_rank'
ADJUSTED_ARITHMETIC_MEAN_RANK_INDEX = 'adjusted_arithmetic_mean_rank_index'
TYPES_REALISTIC_ONLY = {ADJUSTED_ARITHMETIC_MEAN_RANK, ADJUSTED_ARITHMETIC_MEAN_RANK_INDEX}

# Alternative / legacy metric names -> canonical metric names.
METRIC_SYNONYMS = dict(
    adjusted_mean_rank=ADJUSTED_ARITHMETIC_MEAN_RANK,
    adjusted_mean_rank_index=ADJUSTED_ARITHMETIC_MEAN_RANK_INDEX,
    igmr=INVERSE_GEOMETRIC_MEAN_RANK,
    mr=ARITHMETIC_MEAN_RANK,
    mean_rank=ARITHMETIC_MEAN_RANK,
    mrr=INVERSE_HARMONIC_MEAN_RANK,
    mean_reciprocal_rank=INVERSE_HARMONIC_MEAN_RANK,
)


def compute_rank_from_scores(
    true_score: torch.FloatTensor,
    all_scores: torch.FloatTensor,
) -> Dict[str, torch.FloatTensor]:
    """Compute rank and adjusted rank given scores.

    :param true_score: torch.Tensor, shape: (batch_size, 1)
        The score of the true triple. A flat tensor of shape (batch_size,) is accepted as well and is reshaped
        to (batch_size, 1) before comparison.
    :param all_scores: torch.Tensor, shape: (batch_size, num_entities)
        The scores of all corrupted triples (including the true triple). Entries which should be ignored are
        expected to be NaN: comparisons with NaN are false, so they never count towards the optimistic or
        pessimistic rank, and non-finite entries are excluded from the number of options.
    :return: a dictionary
        {
            'optimistic': optimistic_rank,
            'pessimistic': pessimistic_rank,
            'realistic': realistic_rank,
            'expected_realistic': expected_realistic_rank,
        }

        where

        optimistic_rank: shape: (batch_size,), integer dtype
            The optimistic rank is the rank when assuming all options with an equal score are placed behind the current
            test triple.
        pessimistic_rank: shape: (batch_size,), integer dtype
            The pessimistic rank is the rank when assuming all options with an equal score are placed in front of
            current test triple.
        realistic_rank: shape: (batch_size,), float dtype
            The realistic rank is the average of the optimistic and pessimistic rank, and hence the expected rank
            over all permutations of the elements with the same score as the currently considered option.
        expected_realistic_rank: shape: (batch_size,), float dtype
            The expected rank a random scoring would achieve, which is (#number_of_options + 1)/2
    """
    # Make sure the true scores form a column vector, so that the comparisons below broadcast row-wise against
    # ``all_scores``. A flat (batch_size,) tensor would otherwise be broadcast along the entity dimension, which
    # either fails (batch_size != num_entities) or silently yields wrong ranks (batch_size == num_entities).
    true_score = true_score.reshape(-1, 1)

    # The optimistic rank is the rank when assuming all options with an equal score are placed behind the currently
    # considered. Hence, the rank is the number of options with better scores, plus one, as the rank is one-based.
    optimistic_rank = (all_scores > true_score).sum(dim=1) + 1

    # The pessimistic rank is the rank when assuming all options with an equal score are placed in front of the
    # currently considered. Hence, the rank is the number of options which have at least the same score minus one
    # (as the currently considered option is included in all options). As the rank is one-based, we have to add 1,
    # which nullifies the "minus 1" from before.
    pessimistic_rank = (all_scores >= true_score).sum(dim=1)

    # The realistic rank is the average of the optimistic and pessimistic rank, and hence the expected rank over
    # all permutations of the elements with the same score as the currently considered option.
    realistic_rank = (optimistic_rank + pessimistic_rank).float() * 0.5

    # We set values which should be ignored to NaN, hence the number of options which should be considered is given
    # by the number of finite scores per row (note: this also excludes +/-inf scores).
    number_of_options = torch.isfinite(all_scores).sum(dim=1).float()

    # The expected rank of a random scoring
    expected_realistic_rank = 0.5 * (number_of_options + 1)

    return {
        RANK_OPTIMISTIC: optimistic_rank,
        RANK_PESSIMISTIC: pessimistic_rank,
        RANK_REALISTIC: realistic_rank,
        RANK_EXPECTED_REALISTIC: expected_realistic_rank,
    }


@fix_dataclass_init_docs
@dataclass_json
@dataclass
class RankBasedMetricResults(MetricResults):
    """Results from computing metrics.

    Each metric field is a nested dictionary indexed as ``field[side][rank_type]``, where ``side`` is one of
    "head", "tail", "both" and ``rank_type`` is one of "optimistic", "pessimistic", "realistic".
    ``hits_at_k`` has an additional innermost level indexed by ``k``.
    """

    arithmetic_mean_rank: Dict[str, Dict[str, float]] = field(metadata=dict(
        name="Mean Rank (MR)",
        doc='The arithmetic mean over all ranks on, [1, inf). Lower is better.',
    ))

    geometric_mean_rank: Dict[str, Dict[str, float]] = field(metadata=dict(
        name="Geometric Mean Rank (GMR)",
        doc='The geometric mean over all ranks, on [1, inf). Lower is better.',
    ))

    median_rank: Dict[str, Dict[str, float]] = field(metadata=dict(
        name="Median Rank",
        doc='The median over all ranks, on [1, inf). Lower is better.',
    ))

    harmonic_mean_rank: Dict[str, Dict[str, float]] = field(metadata=dict(
        name="Harmonic Mean Rank (HMR)",
        doc='The harmonic mean over all ranks, on [1, inf). Lower is better.',
    ))

    inverse_arithmetic_mean_rank: Dict[str, Dict[str, float]] = field(metadata=dict(
        name="Inverse Arithmetic Mean Rank (IAMR)",
        doc='The inverse of the arithmetic mean over all ranks, on (0, 1]. Higher is better.',
    ))

    inverse_geometric_mean_rank: Dict[str, Dict[str, float]] = field(metadata=dict(
        name="Inverse Geometric Mean Rank (IGMR)",
        doc='The inverse of the geometric mean over all ranks, on (0, 1]. Higher is better.',
    ))

    inverse_harmonic_mean_rank: Dict[str, Dict[str, float]] = field(metadata=dict(
        name="Mean Reciprocal Rank (MRR)",
        doc='The inverse of the harmonic mean over all ranks, on (0, 1]. Higher is better.',
    ))

    inverse_median_rank: Dict[str, Dict[str, float]] = field(metadata=dict(
        name="Inverse Median Rank",
        doc='The inverse of the median over all ranks, on (0, 1]. Higher is better.',
    ))

    rank_std: Dict[str, Dict[str, float]] = field(metadata=dict(
        name="Rank Standard Deviation",
        doc='The standard deviation over all ranks on, [0, inf). Lower is better.',
    ))

    rank_var: Dict[str, Dict[str, float]] = field(metadata=dict(
        name="Rank Variance",
        doc='The variance over all ranks on, [0, inf). Lower is better.',
    ))

    rank_mad: Dict[str, Dict[str, float]] = field(metadata=dict(
        name="Rank Median Absolute Deviation",
        doc='The median absolute deviation over all ranks on, [0, inf). Lower is better.',
    ))

    hits_at_k: Dict[str, Dict[str, Dict[Union[int, float], float]]] = field(metadata=dict(
        name='Hits @ K',
        doc='The relative frequency of ranks not larger than a given k, on [0, 1]. Higher is better',
    ))

    adjusted_arithmetic_mean_rank: Dict[str, Dict[str, float]] = field(metadata=dict(
        name='Adjusted Arithmetic Mean Rank (AAMR)',
        doc='The mean over all chance-adjusted ranks, on (0, 2). Lower is better.',
    ))

    adjusted_arithmetic_mean_rank_index: Dict[str, Dict[str, float]] = field(metadata=dict(
        name='Adjusted Arithmetic Mean Rank Index (AAMRI)',
        doc='The re-indexed adjusted mean rank (AAMR), on [-1, 1]. Higher is better.',
    ))

    def get_metric(self, name: str) -> float:
        """Get the rank-based metric.

        :param name: The name of the metric, created by concatenating three parts:

            1. The side (one of "head", "tail", or "both"). Most publications exclusively report "both".
            2. The type (one of "optimistic", "pessimistic", "realistic")
            3. The metric name ("adjusted_mean_rank_index", "adjusted_mean_rank", "mean_rank", "mean_reciprocal_rank",
               "inverse_geometric_mean_rank",
               or "hits@k" where k defaults to 10 but can be substituted for an integer. By default, 1, 3, 5, and 10
               are available. Other K's can be calculated by setting the appropriate variable in the
               ``evaluation_kwargs`` in the :func:`pykeen.pipeline.pipeline` or setting ``ks`` in the
               :class:`pykeen.evaluation.RankBasedEvaluator`.

            In general, all metrics are available for all combinations of sides/types except AMR and AMRI, which
            are only calculated for the realistic type. This is because the calculation of the expected MR in the
            optimistic and pessimistic case scenarios is still an active area of research and therefore has no
            implementation yet.
        :return: The value for the metric
        :raises ValueError: if an invalid name is given.

        Get the average MR

        >>> metric_results.get_metric('both.realistic.mean_rank')

        If you only give a metric name, it assumes that it's for "both" sides and "realistic" type.

        >>> metric_results.get_metric('adjusted_mean_rank_index')

        This function will do its best to infer what's going on if you only specify one part.

        >>> metric_results.get_metric('head.mean_rank')
        >>> metric_results.get_metric('optimistic.mean_rank')

        Get the default Hits @ K (where $k=10$)

        >>> metric_results.get_metric('hits@k')

        Get a given Hits @ K

        >>> metric_results.get_metric('hits@5')
        """
        dot_count = name.count('.')
        if 0 == dot_count:  # assume both sides and the realistic rank type by default
            side, rank_type, metric = SIDE_BOTH, RANK_REALISTIC, name
        elif 1 == dot_count:
            # The single prefix may denote either a side or a rank type
            side_or_ranktype, metric = name.split('.')
            if side_or_ranktype in SIDES:
                side = side_or_ranktype
                rank_type = RANK_REALISTIC
            else:
                side = SIDE_BOTH
                rank_type = side_or_ranktype
        elif 2 == dot_count:
            side, rank_type, metric = name.split('.')
        else:
            raise ValueError(f'Malformed metric name: {name}')

        # update old names for metrics and handle spaces
        metric = metric.lower().replace(' ', '_')
        metric = METRIC_SYNONYMS.get(metric, metric)

        if side not in SIDES:
            raise ValueError(f'Invalid side: {side}. Allowed sides: {SIDES}')
        # NOTE: compare with ``!=``; ``not in`` on a string would perform a substring test
        if rank_type != RANK_REALISTIC and metric in TYPES_REALISTIC_ONLY:
            raise ValueError(f'Invalid rank type for {metric}: {rank_type}. Allowed type: {RANK_REALISTIC}')
        elif rank_type not in RANK_TYPES:
            raise ValueError(f'Invalid rank type: {rank_type}. Allowed types: {RANK_TYPES}')

        if not metric.startswith('hits'):
            return getattr(self, metric)[side][rank_type]

        # otherwise, assume is hits@k, which is handled differently
        rank_type_hits_at_k = self.hits_at_k[side][rank_type]
        for prefix in ('hits_at_', 'hits@'):
            if not metric.startswith(prefix):
                continue
            k = metric[len(prefix):]
            k_int = 10 if k == 'k' else int(k)
            return rank_type_hits_at_k[k_int]

        raise ValueError(f'Invalid metric name: {name}')

    def to_flat_dict(self):  # noqa: D102
        return {
            f'{side}.{rank_type}.{metric_name}': value
            for side, rank_type, metric_name, value in self._iter_rows()
        }

    def to_df(self) -> pd.DataFrame:
        """Output the metrics as a pandas dataframe."""
        return pd.DataFrame(list(self._iter_rows()), columns=['Side', 'Type', 'Metric', 'Value'])

    def _iter_rows(self) -> Iterable[Tuple[str, str, str, float]]:
        """Yield ``(side, rank_type, metric_name, value)`` for every available metric value.

        Side / rank type combinations without any stored value (e.g., because no ranks were collected for a side)
        are skipped instead of raising a :class:`KeyError`.
        """
        for side, rank_type in itt.product(SIDES, RANK_TYPES):
            for k, v in self.hits_at_k.get(side, {}).get(rank_type, {}).items():
                yield side, rank_type, f'hits_at_{k}', v
            for f in fields(self):
                if f.name == 'hits_at_k':
                    continue
                side_data = getattr(self, f.name).get(side, {})
                if rank_type in side_data:
                    yield side, rank_type, f.name, side_data[rank_type]
class RankBasedEvaluator(Evaluator):
    r"""A rank-based evaluator for KGE models.

    Calculates the following metrics:

    - Mean Rank (MR) with range $[1, \infty)$ where closer to 1 is better
    - Adjusted Mean Rank (AMR; [berrendorf2020]_) with range $(0, 2)$ where closer to 0 is better
    - Adjusted Mean Rank Index (AMRI; [berrendorf2020]_) with range $[-1, 1]$ where closer to 1 is better
    - Mean Reciprocal Rank (MRR) with range $(0, 1]$ where closer to 1 is better
    - Hits @ K with range $[0, 1]$ where closer to 1 is better.

    .. [berrendorf2020] Berrendorf, *et al.* (2020) `Interpretable and Fair
        Comparison of Link Prediction or Entity Alignment Methods with Adjusted Mean Rank
        <https://arxiv.org/abs/2002.06914>`_.
    """

    ks: Sequence[Union[int, float]]

    def __init__(
        self,
        ks: Optional[Iterable[Union[int, float]]] = None,
        filtered: bool = True,
        **kwargs,
    ):
        """Initialize rank-based evaluator.

        :param ks:
            The values for which to calculate hits@k. Defaults to {1,3,5,10}. Integers are absolute rank thresholds;
            floats are relative thresholds (a fraction of the number of entities) and must lie strictly between 0
            and 1. NumPy integer / floating scalars are converted to builtin ``int`` / ``float``.
        :param filtered:
            Whether to use the filtered evaluation protocol. If enabled, ranking another true triple higher than the
            currently considered one will not decrease the score.
        :param kwargs: Additional keyword arguments that are passed to the base class.
        :raises ValueError: if a float k is not strictly between 0 and 1.
        """
        super().__init__(
            filtered=filtered,
            requires_positive_mask=False,
            **kwargs,
        )
        self.ks = tuple(self._normalize_k(k) for k in ks) if ks is not None else (1, 3, 5, 10)
        for k in self.ks:
            if isinstance(k, float) and not (0 < k < 1):
                raise ValueError(
                    'If k is a float, it should represent a relative rank, i.e. a value between 0 and 1 (excl.)',
                )
        # Collected ranks, keyed by (side, rank type); filled by _update_ranks_ and cleared by finalize.
        self.ranks: Dict[Tuple[str, str], List[float]] = defaultdict(list)
        # Number of candidates (columns of the score matrix) of the most recently processed batch.
        self.num_entities = None

    @staticmethod
    def _normalize_k(k):
        """Convert NumPy scalar types to the corresponding builtin ``int`` / ``float``.

        ``finalize`` distinguishes absolute (``int``) from relative thresholds via ``isinstance(k, int)``, which is
        ``False`` for e.g. ``np.int64``; such a k would otherwise be misinterpreted as a relative threshold.
        """
        if isinstance(k, np.integer):
            return int(k)
        if isinstance(k, np.floating):
            return float(k)
        return k

    def _update_ranks_(
        self,
        true_scores: torch.FloatTensor,
        all_scores: torch.FloatTensor,
        side: str,
    ) -> None:
        """Shared code for updating the stored ranks for head/tail scores.

        :param true_scores: shape: (batch_size, 1), as expected by ``compute_rank_from_scores``
        :param all_scores: shape: (batch_size, num_entities)
        :param side: the side the scores belong to, used as part of the storage key
        """
        batch_ranks = compute_rank_from_scores(
            true_score=true_scores,
            all_scores=all_scores,
        )
        self.num_entities = all_scores.shape[1]
        for k, v in batch_ranks.items():
            self.ranks[side, k].extend(v.detach().cpu().tolist())

    def process_tail_scores_(
        self,
        hrt_batch: MappedTriples,
        true_scores: torch.FloatTensor,
        scores: torch.FloatTensor,
        dense_positive_mask: Optional[torch.FloatTensor] = None,
    ) -> None:  # noqa: D102
        self._update_ranks_(true_scores=true_scores, all_scores=scores, side=SIDE_TAIL)

    def process_head_scores_(
        self,
        hrt_batch: MappedTriples,
        true_scores: torch.FloatTensor,
        scores: torch.FloatTensor,
        dense_positive_mask: Optional[torch.FloatTensor] = None,
    ) -> None:  # noqa: D102
        self._update_ranks_(true_scores=true_scores, all_scores=scores, side=SIDE_HEAD)

    def _get_ranks(self, side, rank_type) -> np.ndarray:
        """Return the stored ranks for a side (pooling head and tail for "both") as a float64 array."""
        sides = (SIDE_HEAD, SIDE_TAIL) if side == SIDE_BOTH else (side,)
        values: List[float] = list(itt.chain.from_iterable(
            self.ranks.get((_side, rank_type), []) for _side in sides
        ))
        return np.asarray(values, dtype=np.float64)

    def finalize(self) -> RankBasedMetricResults:  # noqa: D102
        if self.num_entities is None:
            raise ValueError('No scores have been processed; nothing to finalize.')
        hits_at_k: DefaultDict[str, Dict[str, Dict[Union[int, float], float]]] = defaultdict(dict)
        asr: DefaultDict[str, DefaultDict[str, Dict[str, float]]] = defaultdict(lambda: defaultdict(dict))

        for side, rank_type in itt.product(SIDES, RANK_TYPES):
            ranks = self._get_ranks(side=side, rank_type=rank_type)
            if len(ranks) < 1:
                continue

            # An int k is an absolute rank threshold, a float k a fraction of the number of entities
            hits_at_k[side][rank_type] = {
                k: np.mean(ranks <= (k if isinstance(k, int) else int(self.num_entities * k))).item()
                for k in self.ks
            }
            for metric_name, metric_func in all_type_funcs.items():
                asr[metric_name][side][rank_type] = metric_func(ranks).item()

            expected_rank_type = EXPECTED_RANKS.get(rank_type)
            if expected_rank_type:
                expected_ranks = self._get_ranks(side=side, rank_type=expected_rank_type)
                if 0 < len(expected_ranks):
                    mean_rank = asr[ARITHMETIC_MEAN_RANK][side][rank_type]
                    expected_mean_rank = float(np.mean(expected_ranks))
                    # Adjusted mean rank: observed mean rank relative to the mean rank of a random scoring
                    asr[ADJUSTED_ARITHMETIC_MEAN_RANK][side][rank_type] = mean_rank / expected_mean_rank
                    # The index is undefined when the expected mean rank is exactly 1 (a single option per
                    # ranking); report NaN instead of raising ZeroDivisionError
                    asr[ADJUSTED_ARITHMETIC_MEAN_RANK_INDEX][side][rank_type] = (
                        1.0 - (mean_rank - 1) / (expected_mean_rank - 1)
                        if expected_mean_rank != 1.0
                        else float('nan')
                    )

        # Clear buffers
        self.ranks.clear()

        return RankBasedMetricResults(
            arithmetic_mean_rank=dict(asr[ARITHMETIC_MEAN_RANK]),
            geometric_mean_rank=dict(asr[GEOMETRIC_MEAN_RANK]),
            harmonic_mean_rank=dict(asr[HARMONIC_MEAN_RANK]),
            median_rank=dict(asr[MEDIAN_RANK]),
            inverse_arithmetic_mean_rank=dict(asr[INVERSE_ARITHMETIC_MEAN_RANK]),
            inverse_geometric_mean_rank=dict(asr[INVERSE_GEOMETRIC_MEAN_RANK]),
            inverse_harmonic_mean_rank=dict(asr[INVERSE_HARMONIC_MEAN_RANK]),
            inverse_median_rank=dict(asr[INVERSE_MEDIAN_RANK]),
            rank_std=dict(asr[RANK_STD]),
            rank_mad=dict(asr[RANK_MAD]),
            rank_var=dict(asr[RANK_VARIANCE]),
            adjusted_arithmetic_mean_rank=dict(asr[ADJUSTED_ARITHMETIC_MEAN_RANK]),
            adjusted_arithmetic_mean_rank_index=dict(asr[ADJUSTED_ARITHMETIC_MEAN_RANK_INDEX]),
            hits_at_k=dict(hits_at_k),
        )