Source code for pykeen.trackers

# -*- coding: utf-8 -*-

"""Result trackers in PyKEEN."""

from typing import Mapping, Type, Union

from .base import ResultTracker
from .mlflow import MLFlowResultTracker
from .neptune import NeptuneResultTracker
from .wandb import WANDBResultTracker
from ..utils import get_cls, normalize_string

# Names exported by ``from pykeen.trackers import *``: the base tracker class,
# the class-lookup function, and the tracker implementations imported above.
# NOTE(review): the module-level ``trackers`` mapping is not listed here —
# confirm whether that omission is intentional.
__all__ = [
    'ResultTracker',
    'get_result_tracker_cls',
    'MLFlowResultTracker',
    'NeptuneResultTracker',
    'WANDBResultTracker',
]

#: Class-name suffix shared by the tracker classes; passed as ``suffix`` to the
#: name-normalization and class-lookup helpers (presumably stripped there —
#: confirm in ``..utils``).
_RESULT_TRACKER_SUFFIX = 'ResultTracker'

#: A mapping of trackers' names to their implementations. Built once at import
#: time from the direct subclasses of :class:`ResultTracker` that exist at that
#: point, keyed by the normalized class name.
trackers: Mapping[str, Type[ResultTracker]] = {
    normalize_string(tracker_cls.__name__, suffix=_RESULT_TRACKER_SUFFIX): tracker_cls
    for tracker_cls in ResultTracker.__subclasses__()
}


def get_result_tracker_cls(query: Union[None, str, Type[ResultTracker]]) -> Type[ResultTracker]:
    """Get the tracker class.

    :param query: ``None``, a tracker name, or a :class:`ResultTracker` class.
        It is forwarded unchanged to :func:`get_cls`.
    :return: The class resolved by :func:`get_cls`, which is given the
        module-level ``trackers`` mapping as its lookup table,
        :class:`ResultTracker` as both the required base and the default, and
        the shared ``ResultTracker`` class-name suffix.

    NOTE(review): how ``None`` and unknown names are handled is decided inside
    :func:`get_cls` (presumably ``None`` yields the default) — confirm there.
    """
    return get_cls(
        query,
        base=ResultTracker,
        lookup_dict=trackers,
        default=ResultTracker,
        suffix=_RESULT_TRACKER_SUFFIX,
    )