Source code for pykeen.evaluation.rank_based_evaluator

# -*- coding: utf-8 -*-

"""Implementation of ranked based evaluator."""

import itertools as itt
import logging
import re
from collections import defaultdict
from dataclasses import dataclass, field, fields
from typing import DefaultDict, Dict, Iterable, List, NamedTuple, Optional, Sequence, Tuple, Union

import numpy as np
import pandas as pd
import torch
from dataclasses_json import dataclass_json
from scipy import stats

from .evaluator import Evaluator, MetricResults
from ..typing import MappedTriples
from ..utils import fix_dataclass_init_docs

__all__ = [
    "compute_rank_from_scores",
    "RankBasedEvaluator",
    "RankBasedMetricResults",
    "MetricKey",
    "resolve_metric_name",
]

logger = logging.getLogger(__name__)

SIDE_HEAD = "head"
SIDE_TAIL = "tail"
SIDE_BOTH = "both"
SIDES = {SIDE_HEAD, SIDE_TAIL, SIDE_BOTH}

RANK_OPTIMISTIC = "optimistic"
RANK_PESSIMISTIC = "pessimistic"
RANK_REALISTIC = "realistic"
RANK_TYPES = {RANK_OPTIMISTIC, RANK_PESSIMISTIC, RANK_REALISTIC}

RANK_EXPECTED_REALISTIC = "expected_realistic"
EXPECTED_RANKS = {
    RANK_REALISTIC: RANK_EXPECTED_REALISTIC,
    RANK_OPTIMISTIC: None,  # TODO - research problem
    RANK_PESSIMISTIC: None,  # TODO - research problem
}

ARITHMETIC_MEAN_RANK = "arithmetic_mean_rank"  # also known as mean rank (MR)
GEOMETRIC_MEAN_RANK = "geometric_mean_rank"
HARMONIC_MEAN_RANK = "harmonic_mean_rank"
MEDIAN_RANK = "median_rank"
INVERSE_ARITHMETIC_MEAN_RANK = "inverse_arithmetic_mean_rank"
INVERSE_GEOMETRIC_MEAN_RANK = "inverse_geometric_mean_rank"
INVERSE_HARMONIC_MEAN_RANK = "inverse_harmonic_mean_rank"  # also known as mean reciprocal rank (MRR)
INVERSE_MEDIAN_RANK = "inverse_median_rank"

RANK_STD = "rank_std"
RANK_VARIANCE = "rank_var"
RANK_MAD = "rank_mad"
RANK_COUNT = "rank_count"

all_type_funcs = {
    ARITHMETIC_MEAN_RANK: np.mean,  # This is MR
    HARMONIC_MEAN_RANK: stats.hmean,
    GEOMETRIC_MEAN_RANK: stats.gmean,
    MEDIAN_RANK: np.median,
    INVERSE_ARITHMETIC_MEAN_RANK: lambda x: np.reciprocal(np.mean(x)),
    INVERSE_GEOMETRIC_MEAN_RANK: lambda x: np.reciprocal(stats.gmean(x)),
    INVERSE_HARMONIC_MEAN_RANK: lambda x: np.reciprocal(stats.hmean(x)),  # This is MRR
    INVERSE_MEDIAN_RANK: lambda x: np.reciprocal(np.median(x)),
    # Extra stats stuff
    RANK_STD: np.std,
    RANK_VARIANCE: np.var,
    RANK_MAD: stats.median_abs_deviation,
    RANK_COUNT: lambda x: np.asarray(x.size),
}
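
# A small illustration (toy rank values, not part of the library) of how the
# aggregations above are applied to an array of ranks:
#
#     ranks = np.asarray([1.0, 2.0, 4.0, 8.0])
#     all_type_funcs[ARITHMETIC_MEAN_RANK](ranks)        # 3.75    (MR)
#     all_type_funcs[INVERSE_HARMONIC_MEAN_RANK](ranks)  # 0.46875 (MRR)
#     all_type_funcs[RANK_COUNT](ranks)                  # array(4)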

ADJUSTED_ARITHMETIC_MEAN_RANK = "adjusted_arithmetic_mean_rank"
ADJUSTED_ARITHMETIC_MEAN_RANK_INDEX = "adjusted_arithmetic_mean_rank_index"
TYPES_REALISTIC_ONLY = {ADJUSTED_ARITHMETIC_MEAN_RANK, ADJUSTED_ARITHMETIC_MEAN_RANK_INDEX}

METRIC_SYNONYMS = {
    "adjusted_mean_rank": ADJUSTED_ARITHMETIC_MEAN_RANK,
    "adjusted_mean_rank_index": ADJUSTED_ARITHMETIC_MEAN_RANK_INDEX,
    "amr": ADJUSTED_ARITHMETIC_MEAN_RANK,
    "aamr": ADJUSTED_ARITHMETIC_MEAN_RANK,
    "amri": ADJUSTED_ARITHMETIC_MEAN_RANK_INDEX,
    "aamri": ADJUSTED_ARITHMETIC_MEAN_RANK_INDEX,
    "igmr": INVERSE_GEOMETRIC_MEAN_RANK,
    "iamr": INVERSE_ARITHMETIC_MEAN_RANK,
    "mr": ARITHMETIC_MEAN_RANK,
    "mean_rank": ARITHMETIC_MEAN_RANK,
    "mrr": INVERSE_HARMONIC_MEAN_RANK,
    "mean_reciprocal_rank": INVERSE_HARMONIC_MEAN_RANK,
}


class MetricKey(NamedTuple):
    """A key for the kind of metric to resolve."""

    name: str
    side: str
    rank_type: str
    k: Optional[int]

    def __str__(self) -> str:  # noqa: D105
        components = [self.name, self.side, self.rank_type]
        if self.k:
            components.append(str(self.k))
        return ".".join(components)


def compute_rank_from_scores(
    true_score: torch.FloatTensor,
    all_scores: torch.FloatTensor,
) -> Dict[str, torch.FloatTensor]:
    """Compute rank and adjusted rank given scores.

    :param true_score: torch.Tensor, shape: (batch_size, 1)
        The score of the true triple.
    :param all_scores: torch.Tensor, shape: (batch_size, num_entities)
        The scores of all corrupted triples (including the true triple).
    :return: a dictionary
        {
            'optimistic': optimistic_rank,
            'pessimistic': pessimistic_rank,
            'realistic': realistic_rank,
            'expected_realistic': expected_realistic_rank,
        }

        where

        optimistic_rank: shape: (batch_size,)
            The optimistic rank is the rank when assuming all options with an equal score are placed behind the current
            test triple.
        pessimistic_rank:
            The pessimistic rank is the rank when assuming all options with an equal score are placed in front of
            the current test triple.
        realistic_rank:
            The realistic rank is the average of the optimistic and pessimistic rank, and hence the expected rank
            over all permutations of the elements with the same score as the currently considered option.
        expected_realistic_rank: shape: (batch_size,)
            The expected rank a random scoring would achieve, which is (#number_of_options + 1)/2
    """
    # The optimistic rank is the rank when assuming all options with an equal score are placed behind the currently
    # considered. Hence, the rank is the number of options with better scores, plus one, as the rank is one-based.
    optimistic_rank = (all_scores > true_score).sum(dim=1) + 1

    # The pessimistic rank is the rank when assuming all options with an equal score are placed in front of the
    # currently considered. Hence, the rank is the number of options which have at least the same score minus one
    # (as the currently considered option is included in all options). As the rank is one-based, we have to add 1,
    # which nullifies the "minus 1" from before.
    pessimistic_rank = (all_scores >= true_score).sum(dim=1)

    # The realistic rank is the average of the optimistic and pessimistic rank, and hence the expected rank over
    # all permutations of the elements with the same score as the currently considered option.
    realistic_rank = (optimistic_rank + pessimistic_rank).float() * 0.5

    # Values which should be ignored were set to NaN, so the number of options to consider equals the number of finite scores
    number_of_options = torch.isfinite(all_scores).sum(dim=1).float()

    # The expected rank of a random scoring
    expected_realistic_rank = 0.5 * (number_of_options + 1)

    return {
        RANK_OPTIMISTIC: optimistic_rank,
        RANK_PESSIMISTIC: pessimistic_rank,
        RANK_REALISTIC: realistic_rank,
        RANK_EXPECTED_REALISTIC: expected_realistic_rank,
    }

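# For illustration only (toy tensors, not part of the module): how the different
# rank types behave for a batch of two score rows.
#
#     true_score = torch.tensor([[0.7], [0.2]])
#     all_scores = torch.tensor([
#         [0.9, 0.7, 0.7, 0.1],  # one strictly better score plus two ties
#         [0.2, 0.2, 0.2, 0.2],  # every score ties with the true one
#     ])
#     ranks = compute_rank_from_scores(true_score=true_score, all_scores=all_scores)
#     ranks["optimistic"]          # tensor([2, 1])
#     ranks["pessimistic"]         # tensor([3, 4])
#     ranks["realistic"]           # tensor([2.5000, 2.5000])
#     ranks["expected_realistic"]  # tensor([2.5000, 2.5000])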

RANK_TYPE_SYNONYMS = {
    "best": RANK_OPTIMISTIC,
    "worst": RANK_PESSIMISTIC,
    "avg": RANK_REALISTIC,
    "average": RANK_REALISTIC,
}

_SIDE_PATTERN = "|".join(SIDES)
_TYPE_PATTERN = "|".join(itt.chain(RANK_TYPES, RANK_TYPE_SYNONYMS.keys()))
METRIC_PATTERN = re.compile(
    rf"(?P<name>[\w@]+)(\.(?P<side>{_SIDE_PATTERN}))?(\.(?P<type>{_TYPE_PATTERN}))?(\.(?P<k>\d+))?",
)
HITS_PATTERN = re.compile(r"(hits_at_|hits@|h@)(?P<k>\d+)")
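
# For illustration only: ``HITS_PATTERN`` accepts the spellings ``"hits_at_10"``,
# ``"hits@10"``, and ``"h@10"``; ``resolve_metric_name`` below normalizes all of
# them to the canonical name ``"hits_at_k"`` with ``k=10``.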


def resolve_metric_name(name: str) -> MetricKey:
    """Functional metric name normalization."""
    match = METRIC_PATTERN.match(name)
    if not match:
        raise ValueError(f"Invalid metric name: {name}")
    k: Union[None, str, int]
    name, side, rank_type, k = [match.group(key) for key in ("name", "side", "type", "k")]

    # normalize metric name
    if not name:
        raise ValueError("A metric name must be provided.")
    # handle spaces and case
    name = name.lower().replace(" ", "_")

    # special case for hits_at_k
    match = HITS_PATTERN.match(name)
    if match:
        name = "hits_at_k"
        k = match.group("k")
    if name == "hits_at_k":
        if k is None:
            k = 10
        # TODO: Fractional?
        try:
            k = int(k)
        except ValueError as error:
            raise ValueError(f"Invalid k={k} for hits_at_k") from error
        if k < 0:
            raise ValueError(f"For hits_at_k, you must provide a positive value of k, but found {k}.")
    assert k is None or isinstance(k, int)

    # synonym normalization
    name = METRIC_SYNONYMS.get(name, name)

    # normalize side
    side = side or SIDE_BOTH
    side = side.lower()
    if side not in SIDES:
        raise ValueError(f"Invalid side: {side}. Allowed are {SIDES}.")

    # normalize rank type
    rank_type = rank_type or RANK_REALISTIC
    rank_type = rank_type.lower()
    rank_type = RANK_TYPE_SYNONYMS.get(rank_type, rank_type)
    if rank_type not in RANK_TYPES:
        raise ValueError(f"Invalid rank type: {rank_type}. Allowed are {RANK_TYPES}.")
    elif rank_type != RANK_REALISTIC and name in TYPES_REALISTIC_ONLY:
        raise ValueError(f"Invalid rank type for {name}: {rank_type}. Allowed type: {RANK_REALISTIC}")

    return MetricKey(name, side, rank_type, k)

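# For illustration only (not part of the module): what ``resolve_metric_name`` returns
# for a few common inputs, with defaults filled in for omitted parts.
#
#     resolve_metric_name("mrr")
#     # MetricKey(name='inverse_harmonic_mean_rank', side='both', rank_type='realistic', k=None)
#     resolve_metric_name("hits@5")
#     # MetricKey(name='hits_at_k', side='both', rank_type='realistic', k=5)
#     resolve_metric_name("mean_rank.head.best")
#     # MetricKey(name='arithmetic_mean_rank', side='head', rank_type='optimistic', k=None)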

@fix_dataclass_init_docs
@dataclass_json
@dataclass
class RankBasedMetricResults(MetricResults):
    """Results from computing metrics."""

    arithmetic_mean_rank: Dict[str, Dict[str, float]] = field(
        metadata=dict(
            name="Mean Rank (MR)",
            increasing=False,
            range="[1, inf)",
            doc="The arithmetic mean over all ranks.",
            link="https://pykeen.readthedocs.io/en/stable/tutorial/understanding_evaluation.html#mean-rank",
        )
    )

    geometric_mean_rank: Dict[str, Dict[str, float]] = field(
        metadata=dict(
            name="Geometric Mean Rank (GMR)",
            increasing=False,
            range="[1, inf)",
            doc="The geometric mean over all ranks.",
            link="https://cthoyt.com/2021/04/19/pythagorean-mean-ranks.html",
        )
    )

    median_rank: Dict[str, Dict[str, float]] = field(
        metadata=dict(
            name="Median Rank",
            increasing=False,
            range="[1, inf)",
            doc="The median over all ranks.",
            link="https://cthoyt.com/2021/04/19/pythagorean-mean-ranks.html",
        )
    )

    harmonic_mean_rank: Dict[str, Dict[str, float]] = field(
        metadata=dict(
            name="Harmonic Mean Rank (HMR)",
            increasing=False,
            range="[1, inf)",
            doc="The harmonic mean over all ranks.",
            link="https://cthoyt.com/2021/04/19/pythagorean-mean-ranks.html",
        )
    )

    inverse_arithmetic_mean_rank: Dict[str, Dict[str, float]] = field(
        metadata=dict(
            name="Inverse Arithmetic Mean Rank (IAMR)",
            increasing=True,
            range="(0, 1]",
            doc="The inverse of the arithmetic mean over all ranks.",
            link="https://cthoyt.com/2021/04/19/pythagorean-mean-ranks.html",
        )
    )

    inverse_geometric_mean_rank: Dict[str, Dict[str, float]] = field(
        metadata=dict(
            name="Inverse Geometric Mean Rank (IGMR)",
            increasing=True,
            range="(0, 1]",
            doc="The inverse of the geometric mean over all ranks.",
            link="https://cthoyt.com/2021/04/19/pythagorean-mean-ranks.html",
        )
    )

    inverse_harmonic_mean_rank: Dict[str, Dict[str, float]] = field(
        metadata=dict(
            name="Mean Reciprocal Rank (MRR)",
            increasing=True,
            range="(0, 1]",
            doc="The inverse of the harmonic mean over all ranks.",
            link="https://en.wikipedia.org/wiki/Mean_reciprocal_rank",
        )
    )

    inverse_median_rank: Dict[str, Dict[str, float]] = field(
        metadata=dict(
            name="Inverse Median Rank",
            increasing=True,
            range="(0, 1]",
            doc="The inverse of the median over all ranks.",
            link="https://cthoyt.com/2021/04/19/pythagorean-mean-ranks.html",
        )
    )

    rank_count: Dict[str, int] = field(
        metadata=dict(
            name="Rank Count",
            doc="The number of considered ranks, a non-negative number. Low numbers may indicate unreliable results.",
        )
    )

    rank_std: Dict[str, Dict[str, float]] = field(
        metadata=dict(
            name="Rank Standard Deviation",
            range="[0, inf)",
            increasing=False,
            doc="The standard deviation over all ranks.",
        )
    )

    rank_var: Dict[str, Dict[str, float]] = field(
        metadata=dict(
            name="Rank Variance",
            range="[0, inf)",
            increasing=False,
            doc="The variance over all ranks.",
        )
    )

    rank_mad: Dict[str, Dict[str, float]] = field(
        metadata=dict(
            name="Rank Median Absolute Deviation",
            range="[0, inf)",
            doc="The median absolute deviation over all ranks.",
        )
    )

    hits_at_k: Dict[str, Dict[str, Dict[Union[int, float], float]]] = field(
        metadata=dict(
            name="Hits @ K",
            range="[0, 1]",
            increasing=True,
            doc="The relative frequency of ranks not larger than a given k.",
            link="https://pykeen.readthedocs.io/en/stable/tutorial/understanding_evaluation.html#hits-k",
        )
    )

    adjusted_arithmetic_mean_rank: Dict[str, Dict[str, float]] = field(
        metadata=dict(
            name="Adjusted Arithmetic Mean Rank (AAMR)",
            increasing=False,
            range="(0, 2)",
            doc="The mean over all chance-adjusted ranks.",
            link="https://arxiv.org/abs/2002.06914",
        )
    )

    adjusted_arithmetic_mean_rank_index: Dict[str, Dict[str, float]] = field(
        metadata=dict(
            name="Adjusted Arithmetic Mean Rank Index (AAMRI)",
            increasing=True,
            range="[-1, 1]",
            doc="The re-indexed adjusted mean rank (AAMR)",
            link="https://arxiv.org/abs/2002.06914",
        )
    )

    def get_metric(self, name: str) -> float:
        """Get the rank-based metric.

        :param name: The name of the metric, created by concatenating three parts:

            1. The side (one of "head", "tail", or "both"). Most publications exclusively report "both".
            2. The type (one of "optimistic", "pessimistic", or "realistic").
            3. The metric name ("adjusted_mean_rank_index", "adjusted_mean_rank", "mean_rank",
               "mean_reciprocal_rank", "inverse_geometric_mean_rank", or "hits@k", where k defaults to 10 but can
               be substituted with any integer). By default, 1, 3, 5, and 10 are available. Other k's can be
               calculated by setting the appropriate variable in the ``evaluation_kwargs`` in the
               :func:`pykeen.pipeline.pipeline` or setting ``ks`` in the
               :class:`pykeen.evaluation.RankBasedEvaluator`.

            In general, all metrics are available for all combinations of sides/types except AMR and AMRI, which
            are only calculated for the realistic rank type. This is because the calculation of the expected MR in
            the optimistic and pessimistic cases is still an active area of research and therefore has no
            implementation yet.
        :return: The value for the metric
        :raises ValueError: if an invalid name is given.

        Get the average MR

        >>> metric_results.get_metric('both.realistic.mean_rank')

        If you only give a metric name, it assumes that it's for "both" sides and the "realistic" type.

        >>> metric_results.get_metric('adjusted_mean_rank_index')

        This function will do its best to infer what's going on if you only specify one part.

        >>> metric_results.get_metric('head.mean_rank')
        >>> metric_results.get_metric('optimistic.mean_rank')

        Get the default Hits @ K (where $k=10$)

        >>> metric_results.get_metric('hits@k')

        Get a given Hits @ K

        >>> metric_results.get_metric('hits@5')
        """
        metric, side, rank_type, k = resolve_metric_name(name)
        if not metric.startswith("hits"):
            return getattr(self, metric)[side][rank_type]
        assert k is not None
        return self.hits_at_k[side][rank_type][k]

    def to_flat_dict(self):  # noqa: D102
        return {f"{side}.{rank_type}.{metric_name}": value for side, rank_type, metric_name, value in self._iter_rows()}

    def to_df(self) -> pd.DataFrame:
        """Output the metrics as a pandas dataframe."""
        return pd.DataFrame(list(self._iter_rows()), columns=["Side", "Type", "Metric", "Value"])

    def _iter_rows(self) -> Iterable[Tuple[str, str, str, float]]:
        for side, rank_type in itt.product(SIDES, RANK_TYPES):
            for k, v in self.hits_at_k[side][rank_type].items():
                yield side, rank_type, f"hits_at_{k}", v
            for f in fields(self):
                if f.name == "hits_at_k":
                    continue
                side_data = getattr(self, f.name)[side]
                if rank_type in side_data:
                    yield side, rank_type, f.name, side_data[rank_type]


class RankBasedEvaluator(Evaluator):
    r"""A rank-based evaluator for KGE models.

    Calculates the following metrics:

    - Mean Rank (MR) with range $[1, \infty)$ where closer to 1 is better
    - Adjusted Mean Rank (AMR; [berrendorf2020]_) with range $(0, 2)$ where closer to 0 is better
    - Adjusted Mean Rank Index (AMRI; [berrendorf2020]_) with range $[-1, 1]$ where closer to 1 is better
    - Mean Reciprocal Rank (MRR) with range $(0, 1]$ where closer to 1 is better
    - Hits @ K with range $[0, 1]$ where closer to 1 is better.

    .. [berrendorf2020] Berrendorf, *et al.* (2020) `Interpretable and Fair Comparison of Link Prediction or
       Entity Alignment Methods with Adjusted Mean Rank <https://arxiv.org/abs/2002.06914>`_.
    """

    ks: Sequence[Union[int, float]]

    def __init__(
        self,
        ks: Optional[Iterable[Union[int, float]]] = None,
        filtered: bool = True,
        **kwargs,
    ):
        """Initialize rank-based evaluator.

        :param ks:
            The values for which to calculate hits@k. Defaults to {1, 3, 5, 10}.
        :param filtered:
            Whether to use the filtered evaluation protocol. If enabled, ranking another true triple higher than the
            currently considered one will not decrease the score.
        :param kwargs:
            Additional keyword arguments that are passed to the base class.
        """
        super().__init__(
            filtered=filtered,
            requires_positive_mask=False,
            **kwargs,
        )
        self.ks = tuple(ks) if ks is not None else (1, 3, 5, 10)
        for k in self.ks:
            if isinstance(k, float) and not (0 < k < 1):
                raise ValueError(
                    "If k is a float, it should represent a relative rank, i.e. a value between 0 and 1 (excl.)",
                )
        self.ranks: Dict[Tuple[str, str], List[float]] = defaultdict(list)
        self.num_entities = None

    def _update_ranks_(
        self,
        true_scores: torch.FloatTensor,
        all_scores: torch.FloatTensor,
        side: str,
    ) -> None:
        """Shared code for updating the stored ranks for head/tail scores.

        :param true_scores: shape: (batch_size,)
        :param all_scores: shape: (batch_size, num_entities)
        """
        batch_ranks = compute_rank_from_scores(
            true_score=true_scores,
            all_scores=all_scores,
        )
        self.num_entities = all_scores.shape[1]
        for k, v in batch_ranks.items():
            self.ranks[side, k].extend(v.detach().cpu().tolist())

    def process_tail_scores_(
        self,
        hrt_batch: MappedTriples,
        true_scores: torch.FloatTensor,
        scores: torch.FloatTensor,
        dense_positive_mask: Optional[torch.FloatTensor] = None,
    ) -> None:  # noqa: D102
        self._update_ranks_(true_scores=true_scores, all_scores=scores, side=SIDE_TAIL)

    def process_head_scores_(
        self,
        hrt_batch: MappedTriples,
        true_scores: torch.FloatTensor,
        scores: torch.FloatTensor,
        dense_positive_mask: Optional[torch.FloatTensor] = None,
    ) -> None:  # noqa: D102
        self._update_ranks_(true_scores=true_scores, all_scores=scores, side=SIDE_HEAD)

    def _get_ranks(self, side, rank_type) -> np.ndarray:
        if side == SIDE_BOTH:
            values: List[float] = sum(
                (self.ranks.get((_side, rank_type), []) for _side in (SIDE_HEAD, SIDE_TAIL)),
                [],
            )
        else:
            values = self.ranks.get((side, rank_type), [])
        return np.asarray(values, dtype=np.float64)

    def finalize(self) -> RankBasedMetricResults:  # noqa: D102
        if self.num_entities is None:
            raise ValueError

        hits_at_k: DefaultDict[str, Dict[str, Dict[Union[int, float], float]]] = defaultdict(dict)
        asr: DefaultDict[str, DefaultDict[str, Dict[str, float]]] = defaultdict(lambda: defaultdict(dict))

        for side, rank_type in itt.product(SIDES, RANK_TYPES):
            ranks = self._get_ranks(side=side, rank_type=rank_type)
            if len(ranks) < 1:
                continue
            hits_at_k[side][rank_type] = {
                k: np.mean(ranks <= (k if isinstance(k, int) else int(self.num_entities * k))).item() for k in self.ks
            }
            for metric_name, metric_func in all_type_funcs.items():
                asr[metric_name][side][rank_type] = metric_func(ranks).item()

            expected_rank_type = EXPECTED_RANKS.get(rank_type)
            if expected_rank_type:
                expected_ranks = self._get_ranks(side=side, rank_type=expected_rank_type)
                if 0 < len(expected_ranks):
                    # Adjusted mean rank calculation
                    expected_mean_rank = float(np.mean(expected_ranks))
                    asr[ADJUSTED_ARITHMETIC_MEAN_RANK][side][rank_type] = (
                        asr[ARITHMETIC_MEAN_RANK][side][rank_type] / expected_mean_rank
                    )
                    asr[ADJUSTED_ARITHMETIC_MEAN_RANK_INDEX][side][rank_type] = 1.0 - (
                        asr[ARITHMETIC_MEAN_RANK][side][rank_type] - 1
                    ) / (expected_mean_rank - 1)

        # Clear buffers
        self.ranks.clear()

        return RankBasedMetricResults(
            arithmetic_mean_rank=dict(asr[ARITHMETIC_MEAN_RANK]),
            geometric_mean_rank=dict(asr[GEOMETRIC_MEAN_RANK]),
            harmonic_mean_rank=dict(asr[HARMONIC_MEAN_RANK]),
            median_rank=dict(asr[MEDIAN_RANK]),
            inverse_arithmetic_mean_rank=dict(asr[INVERSE_ARITHMETIC_MEAN_RANK]),
            inverse_geometric_mean_rank=dict(asr[INVERSE_GEOMETRIC_MEAN_RANK]),
            inverse_harmonic_mean_rank=dict(asr[INVERSE_HARMONIC_MEAN_RANK]),
            inverse_median_rank=dict(asr[INVERSE_MEDIAN_RANK]),
            rank_count=dict(asr[RANK_COUNT]),
            rank_std=dict(asr[RANK_STD]),
            rank_mad=dict(asr[RANK_MAD]),
            rank_var=dict(asr[RANK_VARIANCE]),
            adjusted_arithmetic_mean_rank=dict(asr[ADJUSTED_ARITHMETIC_MEAN_RANK]),
            adjusted_arithmetic_mean_rank_index=dict(asr[ADJUSTED_ARITHMETIC_MEAN_RANK_INDEX]),
            hits_at_k=dict(hits_at_k),
        )
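

# A minimal end-to-end sketch (illustrative; assumes the standard pykeen pipeline API
# and the bundled "Nations" dataset) of how this evaluator is typically used:
#
#     from pykeen.evaluation import RankBasedEvaluator
#     from pykeen.pipeline import pipeline
#
#     result = pipeline(
#         dataset="Nations",
#         model="TransE",
#         evaluator=RankBasedEvaluator,
#         evaluator_kwargs=dict(ks=(1, 3, 5, 10)),
#     )
#     metric_results = result.metric_results
#     metric_results.get_metric("mrr")      # inverse harmonic mean rank, both sides, realistic
#     metric_results.get_metric("hits@10")  # Hits @ 10, both sides, realistic
#     metric_results.to_df()                # all metrics as a pandas DataFrame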