# -*- coding: utf-8 -*-
"""Result trackers in PyKEEN."""
import os
from typing import Any, Dict, Mapping, Optional, TYPE_CHECKING, Type, Union
from .utils import flatten_dictionary, get_cls, normalize_string
__all__ = [
'get_result_tracker_cls',
'ResultTracker',
'MLFlowResultTracker',
'WANDBResultTracker',
]
if TYPE_CHECKING:
import wandb.wandb_run
class ResultTracker:
    """A class that tracks the results from a pipeline run.

    Every hook defined here is a no-op; subclasses override the hooks they
    need in order to forward information to a concrete result store.
    """

    def start_run(self, run_name: Optional[str] = None) -> None:
        """Start a run with an optional name.

        :param run_name: An optional name for the run.
        """

    def log_params(self, params: Dict[str, Any], prefix: Optional[str] = None) -> None:
        """Log parameters to result store.

        :param params: The parameters to log.
        :param prefix: An optional prefix to prepend to every key in params.
        """

    def log_metrics(
        self,
        metrics: Dict[str, float],
        step: Optional[int] = None,
        prefix: Optional[str] = None,
    ) -> None:
        """Log metrics to result store.

        :param metrics: The metrics to log.
        :param step: An optional step to attach the metrics to (e.g. the epoch).
        :param prefix: An optional prefix to prepend to every key in metrics.
        """

    def end_run(self) -> None:
        """End a run.

        HAS to be called after the experiment is finished.
        """
class MLFlowResultTracker(ResultTracker):
    """A tracker for MLFlow."""

    def __init__(
        self,
        tracking_uri: Optional[str] = None,
        experiment_id: Optional[int] = None,
        experiment_name: Optional[str] = None,
    ):
        """Initialize result tracking via MLFlow.

        :param tracking_uri:
            The tracking uri. It is handed to ``mlflow.set_tracking_uri`` unconditionally, i.e. also when it is None.
        :param experiment_id:
            The experiment ID. If given, this has to be the ID of an existing experiment in MLflow. Has priority over
            experiment_name.
        :param experiment_name:
            The experiment name. If this experiment name exists, add the current run to this experiment. Otherwise
            create an experiment of the given name.
        """
        # Imported here rather than at module level, so importing this module does not require mlflow.
        import mlflow as _mlflow

        self.mlflow = _mlflow
        self.mlflow.set_tracking_uri(tracking_uri)
        # An explicit ID wins: it is resolved to the experiment's name, which replaces any name passed in.
        # NOTE(review): MLflow documents experiment IDs as strings - confirm int IDs are accepted by the backend.
        if experiment_id is not None:
            experiment = self.mlflow.get_experiment(experiment_id=experiment_id)
            experiment_name = experiment.name
        if experiment_name is not None:
            self.mlflow.set_experiment(experiment_name)

    def start_run(self, run_name: Optional[str] = None) -> None:
        """Start an MLflow run with an optional name.

        :param run_name: An optional name, forwarded to ``mlflow.start_run``.
        """
        self.mlflow.start_run(run_name=run_name)

    def log_metrics(
        self,
        metrics: Dict[str, float],
        step: Optional[int] = None,
        prefix: Optional[str] = None,
    ) -> None:
        """Flatten the metrics and log them via ``mlflow.log_metrics``.

        :param metrics: The metrics to log.
        :param step: An optional step to attach the metrics to (e.g. the epoch).
        :param prefix: An optional prefix, handed to ``flatten_dictionary`` together with the metrics.
        """
        metrics = flatten_dictionary(dictionary=metrics, prefix=prefix)
        self.mlflow.log_metrics(metrics=metrics, step=step)

    def log_params(self, params: Dict[str, Any], prefix: Optional[str] = None) -> None:
        """Flatten the parameters and log them via ``mlflow.log_params``.

        :param params: The parameters to log.
        :param prefix: An optional prefix, handed to ``flatten_dictionary`` together with the parameters.
        """
        params = flatten_dictionary(dictionary=params, prefix=prefix)
        self.mlflow.log_params(params=params)

    def end_run(self) -> None:
        """End the MLflow run via ``mlflow.end_run``."""
        self.mlflow.end_run()
class WANDBResultTracker(ResultTracker):
    """A tracker for Weights and Biases.

    Note that you have to perform wandb login beforehand.
    """

    #: The WANDB run
    run: 'wandb.wandb_run.Run'

    def __init__(
        self,
        project: str,
        experiment: Optional[str] = None,
        offline: bool = False,
        **kwargs,
    ):
        """Initialize result tracking via WANDB.

        :param project:
            project name your WANDB login has access to.
        :param experiment:
            The experiment name to appear on the website. If not given, WANDB will generate a random name.
        :param offline:
            If true, the environment variable named by ``wandb.env.MODE`` is set to ``'dryrun'`` before the run
            is initialized. The variable is set process-wide and is not restored afterwards.
        :param kwargs:
            Additional keyword arguments forwarded to ``wandb.init``.

        :raises ValueError: If the project is None.
        """
        # Imported here rather than at module level, so importing this module does not require wandb.
        import wandb as _wandb

        self.wandb = _wandb
        if project is None:
            raise ValueError('Weights & Biases requires a project name.')
        self.project = project

        # The mode has to be in the environment before wandb.init is called below.
        if offline:
            os.environ[self.wandb.env.MODE] = 'dryrun'

        self.run = self.wandb.init(project=self.project, name=experiment, **kwargs)

    def log_metrics(
        self,
        metrics: Dict[str, float],
        step: Optional[int] = None,
        prefix: Optional[str] = None,
    ) -> None:
        """Flatten the metrics and log them via ``wandb.log``.

        :param metrics: The metrics to log.
        :param step: An optional step to attach the metrics to (e.g. the epoch).
        :param prefix: An optional prefix, handed to ``flatten_dictionary`` together with the metrics.
        """
        metrics = flatten_dictionary(dictionary=metrics, prefix=prefix)
        self.wandb.log(metrics, step=step)

    def log_params(self, params: Dict[str, Any], prefix: Optional[str] = None) -> None:
        """Flatten the parameters and store them via ``wandb.config.update``.

        :param params: The parameters to log.
        :param prefix: An optional prefix, handed to ``flatten_dictionary`` together with the parameters.
        """
        params = flatten_dictionary(dictionary=params, prefix=prefix)
        self.wandb.config.update(params)
#: Registry of tracker implementations, keyed by the normalized name of each class.
#: Built from the direct subclasses of ResultTracker that exist when this statement runs.
trackers: Mapping[str, Type[ResultTracker]] = {
    normalize_string(tracker_cls.__name__, suffix='ResultTracker'): tracker_cls
    for tracker_cls in ResultTracker.__subclasses__()
}
def get_result_tracker_cls(query: Union[None, str, Type[ResultTracker]]) -> Type[ResultTracker]:
    """Get the tracker class.

    Resolution is delegated to ``get_cls``, using ``trackers`` as the lookup table, ``ResultTracker`` as the
    required base class and ``ResultTracker`` as the default.

    :param query:
        None, the name of a tracker, or a tracker class.
        NOTE(review): presumably None selects the default and a string is looked up in ``trackers`` -
        confirm against ``get_cls``.

    :return: The tracker class returned by ``get_cls``.
    """
    return get_cls(
        query,
        base=ResultTracker,
        lookup_dict=trackers,
        default=ResultTracker,
    )