# Source code for pykeen.trackers.mlflow

"""An adapter for MLflow."""

from collections.abc import Mapping
from typing import Any

from .base import ResultTracker
from ..utils import flatten_dictionary

__all__ = [
    "MLFlowResultTracker",
]


class MLFlowResultTracker(ResultTracker):
    """A tracker for MLflow."""

    def __init__(
        self,
        tracking_uri: str | None = None,
        experiment_id: int | str | None = None,
        experiment_name: str | None = None,
        tags: dict[str, Any] | None = None,
    ):
        """Initialize result tracking via MLflow.

        :param tracking_uri:
            The tracking URI. If ``None``, MLflow's tracking URI is left untouched. That includes any URI
            configured earlier through :func:`mlflow.set_tracking_uri`, or MLflow's own default resolution.
        :param experiment_id:
            The experiment ID. If given, this has to be the ID of an existing experiment in MLflow. It has
            priority over ``experiment_name``. MLflow identifies experiments by string IDs. An ``int`` is
            still accepted for backwards compatibility and is converted with :func:`str`.
        :param experiment_name:
            The experiment name. If an experiment with this name exists, the current run is added to it.
            Otherwise an experiment of the given name is created.
        :param tags:
            Additional run details, logged as tags when a run is started.
        """
        import mlflow as _mlflow

        self.mlflow = _mlflow
        self.tags = tags

        # Only touch MLflow's global tracking URI when one was explicitly requested.
        # Calling set_tracking_uri(None) would reset a URI configured earlier by the user.
        if tracking_uri is not None:
            self.mlflow.set_tracking_uri(tracking_uri)

        if experiment_id is not None:
            # MLflow experiment IDs are strings. Some tracking stores validate the ID with a
            # regular expression and fail on non-string input, so normalize it here.
            experiment = self.mlflow.get_experiment(experiment_id=str(experiment_id))
            # Resolve the ID to its name so a single set_experiment call below selects it.
            experiment_name = experiment.name
        if experiment_name is not None:
            self.mlflow.set_experiment(experiment_name)

    # docstr-coverage: inherited
    def start_run(self, run_name: str | None = None) -> None:  # noqa: D102
        self.mlflow.start_run(run_name=run_name)
        # Tags given at construction time are attached to every run started by this tracker.
        if self.tags is not None:
            self.mlflow.set_tags(tags=self.tags)

    # docstr-coverage: inherited
    def log_metrics(
        self,
        metrics: Mapping[str, float],
        step: int | None = None,
        prefix: str | None = None,
    ) -> None:  # noqa: D102
        # MLflow expects a flat key -> value mapping, so nested metric dictionaries are flattened first.
        metrics = flatten_dictionary(dictionary=metrics, prefix=prefix)
        self.mlflow.log_metrics(metrics=metrics, step=step)

    # docstr-coverage: inherited
    def log_params(self, params: Mapping[str, Any], prefix: str | None = None) -> None:  # noqa: D102
        # MLflow expects a flat key -> value mapping, so nested parameter dictionaries are flattened first.
        params = flatten_dictionary(dictionary=params, prefix=prefix)
        self.mlflow.log_params(params=params)

    # docstr-coverage: inherited
    def end_run(self, success: bool = True) -> None:  # noqa: D102
        run_status = self.mlflow.entities.RunStatus
        status = run_status.FINISHED if success else run_status.FAILED
        # mlflow.end_run takes the status as its string name, e.g. "FINISHED" or "FAILED".
        self.mlflow.end_run(status=run_status.to_string(status))