Source code for pykeen.evaluation.classification_evaluator

# -*- coding: utf-8 -*-

"""Implementation of wrapper around sklearn metrics."""

from dataclasses import dataclass, field, fields, make_dataclass
from typing import Any, Dict, Optional, Tuple

import numpy as np
import torch
from dataclasses_json import dataclass_json

from .evaluator import Evaluator, MetricResults
from .rexmex_compat import classifier_annotator
from ..typing import MappedTriples
from ..utils import fix_dataclass_init_docs

__all__ = [
    "ClassificationEvaluator",
    "ClassificationMetricResults",
]

# Build one dataclass field per classification metric exposed by the annotator.
_fields = [
    (
        metadata.func.__name__,
        float,
        field(
            metadata=dict(
                name=metadata.name,
                doc=metadata.description or "",
                link=metadata.link,
                range=metadata.interval(),
                increasing=metadata.higher_is_better,
                f=metadata.func,
            )
        ),
    )
    for metadata in classifier_annotator.metrics.values()
]

# Dynamically create the metric results base class, with one float field per metric.
ClassificationMetricResultsBase = make_dataclass(
    "ClassificationMetricResultsBase",
    _fields,
    bases=(MetricResults,),
)
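
# Illustrative sketch (not part of the original module): hand-written, a single
# generated field of the base class would look roughly like the following,
# assuming a hypothetical AUC-ROC entry whose callable is
# ``sklearn.metrics.roc_auc_score`` and whose interval is the closed unit range:
#
#     @dataclass
#     class ClassificationMetricResultsBase(MetricResults):
#         roc_auc_score: float = field(
#             metadata=dict(
#                 name="AUC-ROC",
#                 doc="Area under the receiver operating characteristic curve.",
#                 link="https://scikit-learn.org/stable/modules/generated/sklearn.metrics.roc_auc_score.html",
#                 range=(0.0, 1.0),
#                 increasing=True,
#                 f=sklearn.metrics.roc_auc_score,
#             )
#         )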


@fix_dataclass_init_docs
@dataclass_json
@dataclass
class ClassificationMetricResults(ClassificationMetricResultsBase):  # type: ignore
    """Results from computing metrics."""

    @classmethod
    def from_scores(cls, y_true, y_score):
        """Return an instance of these metrics from a given set of true labels and scores."""
        return ClassificationMetricResults(**{f.name: f.metadata["f"](y_true, y_score) for f in fields(cls)})

    def get_metric(self, name: str) -> float:  # noqa: D102
        return getattr(self, name)
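
# Illustrative usage sketch (not part of the original module), assuming numpy
# arrays of binary labels and real-valued scores; ``from_scores`` fills every
# generated metric field by calling its stored callable on the same pair. The
# metric name ``roc_auc_score`` is a hypothetical example:
#
#     >>> import numpy as np
#     >>> y_true = np.array([1.0, 0.0, 1.0, 0.0])
#     >>> y_score = np.array([0.9, 0.3, 0.8, 0.1])
#     >>> results = ClassificationMetricResults.from_scores(y_true, y_score)
#     >>> results.get_metric("roc_auc_score")
#     1.0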


class ClassificationEvaluator(Evaluator):
    """An evaluator that uses classification metrics."""

    all_scores: Dict[Tuple[Any, ...], np.ndarray]
    all_positives: Dict[Tuple[Any, ...], np.ndarray]

    def __init__(self, **kwargs):
        super().__init__(
            filtered=False,
            requires_positive_mask=True,
            **kwargs,
        )
        self.all_scores = {}
        self.all_positives = {}

    def _process_scores(
        self,
        keys: torch.LongTensor,
        scores: torch.FloatTensor,
        positive_mask: torch.FloatTensor,
        head_side: bool,
    ) -> None:
        # Transfer to cpu and convert to numpy
        scores = scores.detach().cpu().numpy()
        positive_mask = positive_mask.detach().cpu().numpy()
        keys = keys.detach().cpu().numpy()

        # Ensure that each key gets counted only once
        for i in range(keys.shape[0]):
            # include head_side flag into key to differentiate between (h, r) and (r, t)
            key = (head_side,) + tuple(map(int, keys[i]))
            self.all_scores[key] = scores[i]
            self.all_positives[key] = positive_mask[i]
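
    # Illustrative sketch (not part of the original module) of the keying scheme
    # above, assuming a single (h, r, t) = (0, 1, 2) triple scored on the tail
    # side over five candidate entities:
    #
    #     >>> import torch
    #     >>> evaluator = ClassificationEvaluator()
    #     >>> keys = torch.as_tensor([[0, 1]])            # the (h, r) prefix
    #     >>> scores = torch.rand(1, 5)                   # one score per candidate tail
    #     >>> mask = torch.zeros(1, 5)
    #     >>> mask[0, 2] = 1.0                            # the true tail is entity 2
    #     >>> evaluator._process_scores(keys, scores, mask, head_side=False)
    #     >>> list(evaluator.all_scores)
    #     [(False, 0, 1)]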

    def process_tail_scores_(
        self,
        hrt_batch: MappedTriples,
        true_scores: torch.FloatTensor,
        scores: torch.FloatTensor,
        dense_positive_mask: Optional[torch.FloatTensor] = None,
    ) -> None:  # noqa: D102
        if dense_positive_mask is None:
            raise KeyError("Sklearn evaluators need the positive mask!")
        self._process_scores(keys=hrt_batch[:, :2], scores=scores, positive_mask=dense_positive_mask, head_side=False)

    def process_head_scores_(
        self,
        hrt_batch: MappedTriples,
        true_scores: torch.FloatTensor,
        scores: torch.FloatTensor,
        dense_positive_mask: Optional[torch.FloatTensor] = None,
    ) -> None:  # noqa: D102
        if dense_positive_mask is None:
            raise KeyError("Sklearn evaluators need the positive mask!")
        self._process_scores(keys=hrt_batch[:, 1:], scores=scores, positive_mask=dense_positive_mask, head_side=True)

    def finalize(self) -> ClassificationMetricResults:  # noqa: D102
        # Important: The order of the values of a dictionary is not guaranteed. Hence, we need to retrieve scores and
        # masks using the exact same key order.
        all_keys = list(self.all_scores.keys())

        # TODO how to define a cutoff on y_scores to make binary?
        # see: https://github.com/xptree/NetMF/blob/77286b826c4af149055237cef65e2a500e15631a/predict.py#L25-L33
        y_score = np.concatenate([self.all_scores[k] for k in all_keys], axis=0).flatten()
        y_true = np.concatenate([self.all_positives[k] for k in all_keys], axis=0).flatten()

        # Clear buffers
        self.all_positives.clear()
        self.all_scores.clear()

        return ClassificationMetricResults.from_scores(y_true, y_score)
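
# Illustrative end-to-end sketch (not part of the original module), assuming a
# trained PyKEEN ``model`` and ``mapped_triples`` for evaluation are in scope;
# ``evaluate`` is inherited from ``Evaluator``, and ``roc_auc_score`` stands in
# for whichever metric names the annotator actually generates:
#
#     >>> evaluator = ClassificationEvaluator()
#     >>> results = evaluator.evaluate(model=model, mapped_triples=mapped_triples)
#     >>> results.get_metric("roc_auc_score")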