# -*- coding: utf-8 -*-
"""Utilities and base classes for PyKEEN tracker adapters."""
import logging
import re
from typing import Any, Mapping, Optional, Pattern, Union
from tqdm.auto import tqdm
from ..utils import flatten_dictionary
__all__ = [
'ResultTracker',
'ConsoleResultTracker',
]
[docs]class ResultTracker:
"""A class that tracks the results from a pipeline run."""
[docs] def start_run(self, run_name: Optional[str] = None) -> None:
"""Start a run with an optional name."""
[docs] def log_params(self, params: Mapping[str, Any], prefix: Optional[str] = None) -> None:
"""Log parameters to result store."""
[docs] def log_metrics(
self,
metrics: Mapping[str, float],
step: Optional[int] = None,
prefix: Optional[str] = None,
) -> None:
"""Log metrics to result store.
:param metrics: The metrics to log.
:param step: An optional step to attach the metrics to (e.g. the epoch).
:param prefix: An optional prefix to prepend to every key in metrics.
"""
[docs] def end_run(self) -> None:
"""End a run.
HAS to be called after the experiment is finished.
"""
[docs]class ConsoleResultTracker(ResultTracker):
"""A class that directly prints to console."""
def __init__(
self,
*,
track_parameters: bool = True,
parameter_filter: Union[None, str, Pattern[str]] = None,
track_metrics: bool = True,
metric_filter: Union[None, str, Pattern[str]] = None,
start_end_run: bool = False,
writer: str = 'tqdm',
):
"""
Initialize the tracker.
:param track_parameters:
Whether to print parameters.
:param parameter_filter:
A regular expression to filter parameters. If None, print all parameters.
:param track_metrics:
Whether to print metrics.
:param metric_filter:
A regular expression to filter metrics. If None, print all parameters.
:param start_end_run:
Whether to print start/end run messages.
:param writer:
The writer to use - one of "tqdm", "builtin", or "logger".
"""
self.start_end_run = start_end_run
self.track_parameters = track_parameters
if isinstance(parameter_filter, str):
parameter_filter = re.compile(parameter_filter)
self.parameter_filter = parameter_filter
self.track_metrics = track_metrics
if isinstance(metric_filter, str):
metric_filter = re.compile(metric_filter)
self.metric_filter = metric_filter
if writer == 'tqdm':
self.write = tqdm.write
elif writer == 'builtin':
self.write = print
elif writer == 'logging':
self.write = logging.getLogger('pykeen').info
[docs] def start_run(self, run_name: Optional[str] = None) -> None: # noqa: D102
if run_name is not None and self.start_end_run:
self.write(f"Starting run: {run_name}")
[docs] def log_params(self, params: Mapping[str, Any], prefix: Optional[str] = None) -> None: # noqa: D102
if not self.track_parameters:
return
for key, value in flatten_dictionary(dictionary=params).items():
if not self.parameter_filter or self.parameter_filter.match(key):
self.write(f"Parameter: {key} = {value}")
[docs] def log_metrics(
self,
metrics: Mapping[str, float],
step: Optional[int] = None,
prefix: Optional[str] = None,
) -> None: # noqa: D102
if not self.track_metrics:
return
self.write(f"Step: {step}")
for key, value in flatten_dictionary(dictionary=metrics, prefix=prefix).items():
if not self.metric_filter or self.metric_filter.match(key):
self.write(f"Parameter: {key} = {value}")
[docs] def end_run(self) -> None: # noqa: D102
if self.start_end_run:
self.write("Finished run.")