"""Training callbacks.
Training callbacks allow for arbitrary extension of the functionality of the :class:`pykeen.training.TrainingLoop`
without subclassing it. Each callback instance has a ``training_loop`` attribute that allows access to the parent training
loop and all of its attributes, including the model. The interaction points are similar to those of
`Keras <https://keras.io/guides/writing_your_own_callbacks/#an-overview-of-callback-methods>`_.
Examples
--------
The following are vignettes showing how PyKEEN's training loop can be arbitrarily extended
using callbacks. If you find that none of the hooks in the :class:`TrainingCallback`
help do what you want, feel free to open an issue.
Reporting Batch Loss
~~~~~~~~~~~~~~~~~~~~
It was suggested in `Issue #333 <https://github.com/pykeen/pykeen/issues/333>`_ that it might
be useful to log all batch losses. This could be accomplished with the following:
.. code-block:: python
from pykeen.training import TrainingCallback
class BatchLossReportCallback(TrainingCallback):
def on_batch(self, epoch: int, batch, batch_loss: float):
print(epoch, batch_loss)
Implementing Gradient Clipping
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
`Gradient
clipping <https://neptune.ai/blog/understanding-gradient-clipping-and-how-it-can-fix-exploding-gradients-problem>`_
is one technique used to avoid the exploding gradient problem. Despite being very simple, it has several
`theoretical implications <https://openreview.net/forum?id=BJgnXpVYwS>`_.
In order to reproduce the reference experiments on R-GCN performed by [schlichtkrull2018]_,
gradient clipping must be used before each step of the optimizer. The following example shows how
to implement a gradient clipping callback:
.. code-block:: python
from typing import Any
from torch.nn.utils import clip_grad_value_
from pykeen.training import TrainingCallback
class GradientClippingCallback(TrainingCallback):
def __init__(self, clip_value: float = 1.0):
super().__init__()
self.clip_value = clip_value
def pre_step(self, **kwargs: Any):
clip_grad_value_(self.model.parameters(), clip_value=self.clip_value)
"""
from __future__ import annotations
import logging
import pathlib
import uuid
from collections.abc import Mapping, Sequence
from typing import Any
import torch
from class_resolver import ClassResolver, HintOrType, OptionalKwargs
from torch import optim
from torch.nn.utils import clip_grad_norm_, clip_grad_value_
from torch_max_mem import maximize_memory_utilization
from .. import training # required for type annotations
from ..checkpoints import CheckpointKeeper, CheckpointSchedule, keeper_resolver, save_model, schedule_resolver
from ..constants import PYKEEN_CHECKPOINTS
from ..evaluation import Evaluator, evaluator_resolver
from ..evaluation.evaluation_loop import AdditionalFilterTriplesHint, LCWAEvaluationLoop
from ..losses import Loss
from ..models import Model
from ..stoppers import Stopper
from ..trackers import ResultTracker
from ..triples import CoreTriplesFactory
from ..typing import MappedTriples, OneOrSequence
from ..utils import determine_maximum_batch_size
logger = logging.getLogger(__name__)

# The public API of this module. The kwargs hint is exported alongside ``TrainingCallbackHint`` since both are needed
# to describe the (callbacks, callbacks_kwargs) parameter pairs, and the evaluation-loss callback is a user-facing
# callback (cf. its usage example) just like the other evaluation callbacks.
__all__ = [
    "callback_resolver",
    "TrainingCallbackHint",
    "TrainingCallbackKwargsHint",
    "TrainingCallback",
    "StopperTrainingCallback",
    "TrackerTrainingCallback",
    "EvaluationLoopTrainingCallback",
    "EvaluationTrainingCallback",
    "EvaluationLossTrainingCallback",
    "CheckpointTrainingCallback",
    "MultiTrainingCallback",
    "GradientNormClippingTrainingCallback",
    "GradientAbsClippingTrainingCallback",
]
class TrainingCallback:
    """An interface for training callbacks.

    All hooks are no-ops by default, so subclasses only override the hooks they are interested in. The parent
    training loop is attached with :meth:`register_training_loop`; afterwards it is available as
    :attr:`training_loop`, together with shortcuts to its model, loss, optimizer, and result tracker.
    """

    def __init__(self):
        """Initialize the callback."""
        # the parent training loop; stays unset until register_training_loop is called
        self._training_loop: training.TrainingLoop | None = None

    @property
    def training_loop(self) -> training.TrainingLoop:  # noqa:D401
        """The training loop.

        :raises ValueError: if no training loop has been registered yet
        """
        loop = self._training_loop
        if loop is not None:
            return loop
        raise ValueError("Callback was never initialized")

    @property
    def model(self) -> Model:  # noqa:D401
        """The model, accessed via the training loop."""
        loop = self.training_loop
        return loop.model

    @property
    def loss(self) -> Loss:  # noqa: D401
        """The loss, accessed via the training loop."""
        loop = self.training_loop
        return loop.loss

    @property
    def optimizer(self) -> optim.Optimizer:  # noqa:D401
        """The optimizer, accessed via the training loop."""
        loop = self.training_loop
        return loop.optimizer

    @property
    def result_tracker(self) -> ResultTracker:  # noqa: D401
        """The result tracker, accessed via the training loop (which must have one)."""
        tracker = self.training_loop.result_tracker
        assert tracker is not None
        return tracker

    def register_training_loop(self, training_loop: training.TrainingLoop) -> None:
        """Register the training loop."""
        self._training_loop = training_loop

    def pre_batch(self, **kwargs: Any) -> None:
        """Call before training batch."""

    def on_batch(self, epoch: int, batch, batch_loss: float, **kwargs: Any) -> None:
        """Call for training batches, with access to the batch loss."""

    def pre_step(self, **kwargs: Any) -> None:
        """Call before the optimizer's step."""

    def post_batch(self, epoch: int, batch, **kwargs: Any) -> None:
        """Call at the end of a training batch."""

    def post_epoch(self, epoch: int, epoch_loss: float, **kwargs: Any) -> None:
        """Call after epoch."""

    def post_train(self, losses: list[float], **kwargs: Any) -> None:
        """Call after training."""
class TrackerTrainingCallback(TrainingCallback):
    """
    An adapter for the :class:`pykeen.trackers.ResultTracker`.

    After every epoch, the epoch loss is logged to the training loop's result tracker under the key ``"loss"``,
    using the epoch as the step.
    """

    # docstr-coverage: inherited
    def post_epoch(self, epoch: int, epoch_loss: float, **kwargs: Any) -> None:  # noqa: D102
        metrics = {"loss": epoch_loss}
        self.result_tracker.log_metrics(metrics, step=epoch)
class GradientNormClippingTrainingCallback(TrainingCallback):
    """A callback for gradient clipping before stepping the optimizer with :func:`torch.nn.utils.clip_grad_norm_`."""

    def __init__(self, max_norm: float, norm_type: float | None = None):
        """
        Initialize the callback.

        :param max_norm:
            The maximum gradient norm for use with gradient clipping.
        :param norm_type:
            The gradient norm type to use for maximum gradient norm, cf. :func:`torch.nn.utils.clip_grad_norm_`.
            If None, 2.0 is used.
        """
        super().__init__()
        self.max_norm = max_norm
        # only fall back to the default for None; a truthiness check (`or`) would also silently replace an
        # explicitly requested norm type of 0.0
        self.norm_type = 2.0 if norm_type is None else norm_type

    # docstr-coverage: inherited
    def pre_step(self, **kwargs: Any) -> None:  # noqa: D102
        clip_grad_norm_(
            parameters=self.model.get_grad_params(),
            max_norm=self.max_norm,
            norm_type=self.norm_type,
            error_if_nonfinite=True,  # this will become default in future releases of pytorch
        )
class GradientAbsClippingTrainingCallback(TrainingCallback):
    """A callback for gradient clipping before stepping the optimizer with :func:`torch.nn.utils.clip_grad_value_`."""

    def __init__(self, clip_value: float | None):
        """
        Initialize the callback.

        :param clip_value:
            The maximum absolute value in gradients, cf. :func:`torch.nn.utils.clip_grad_value_`. If None, no
            gradient clipping will be used.
        """
        super().__init__()
        self.clip_value = clip_value

    # docstr-coverage: inherited
    def pre_step(self, **kwargs: Any) -> None:  # noqa: D102
        # honor the documented contract: None disables clipping (forwarding None to clip_grad_value_ would fail)
        if self.clip_value is None:
            return
        clip_grad_value_(self.model.get_grad_params(), clip_value=self.clip_value)
class EvaluationTrainingCallback(TrainingCallback):
    """
    A callback for regular evaluation.

    Every ``frequency`` epochs, the model is evaluated on a fixed set of triples with the configured evaluator, and
    the flattened results are logged to the training loop's result tracker.

    Example: evaluate training performance

    .. code-block:: python

        from pykeen.datasets import get_dataset
        from pykeen.pipeline import pipeline

        dataset = get_dataset(dataset="nations")
        result = pipeline(
            dataset=dataset,
            model="mure",
            training_loop_kwargs=dict(
                result_tracker="console",
            ),
            training_kwargs=dict(
                num_epochs=100,
                callbacks="evaluation",
                callback_kwargs=dict(
                    evaluation_triples=dataset.training.mapped_triples,
                    prefix="training",
                ),
            ),
        )
    """

    def __init__(
        self,
        *,
        evaluation_triples: MappedTriples,
        frequency: int = 1,
        evaluator: HintOrType[Evaluator] = None,
        evaluator_kwargs: OptionalKwargs = None,
        prefix: str | None = None,
        **kwargs,
    ):
        """
        Initialize the callback.

        :param evaluation_triples:
            the triples on which to evaluate
        :param frequency:
            the evaluation frequency in epochs, i.e., evaluation happens on epochs which are a multiple of it
        :param evaluator:
            the evaluator to use for evaluation, cf. `evaluator_resolver`
        :param evaluator_kwargs:
            additional keyword-based parameters for the evaluator
        :param prefix:
            the prefix to use for logging the metrics
        :param kwargs:
            additional keyword-based parameters passed to `evaluate`. A ``batch_size`` entry is taken out and only
            used when the evaluator does not define a batch size itself.
        """
        super().__init__()
        self.frequency = frequency
        self.evaluation_triples = evaluation_triples
        self.evaluator = evaluator_resolver.make(evaluator, evaluator_kwargs)
        self.prefix = prefix
        # the batch size is handled explicitly in post_epoch; everything else is forwarded verbatim
        self.batch_size = kwargs.pop("batch_size", None)
        self.kwargs = kwargs

    # docstr-coverage: inherited
    def post_epoch(self, epoch: int, epoch_loss: float, **kwargs: Any) -> None:  # noqa: D102
        # skip all epochs which are not a multiple of the evaluation frequency
        if epoch % self.frequency != 0:
            return
        # the evaluator's own batch size takes precedence over the one given to the callback
        batch_size = self.evaluator.batch_size or self.batch_size
        result = self.evaluator.evaluate(
            model=self.model,
            mapped_triples=self.evaluation_triples,
            device=self.training_loop.device,
            batch_size=batch_size,
            **self.kwargs,
        )
        self.result_tracker.log_metrics(metrics=result.to_flat_dict(), step=epoch, prefix=self.prefix)
class EvaluationLoopTrainingCallback(TrainingCallback):
    """A callback for regular evaluation using new-style evaluation loops.

    The underlying :class:`LCWAEvaluationLoop` is created on first use, i.e., once the training loop (and thus the
    model) is available.
    """

    def __init__(
        self,
        factory: CoreTriplesFactory,
        frequency: int = 1,
        prefix: str | None = None,
        evaluator: HintOrType[Evaluator] = None,
        evaluator_kwargs: OptionalKwargs = None,
        additional_filter_triples: AdditionalFilterTriplesHint = None,
        **kwargs,
    ):
        """
        Initialize the callback.

        :param factory:
            the triples factory comprising the evaluation triples
        :param frequency:
            the evaluation frequency, i.e., evaluation happens on epochs which are a multiple of it
        :param prefix:
            a prefix to use for logging (e.g., to distinguish between different splits)
        :param evaluator:
            the evaluator, or a hint thereof
        :param evaluator_kwargs:
            additional keyword-based parameters used for the evaluation instantiation
        :param additional_filter_triples:
            additional filter triples to use for creating the filter
        :param kwargs:
            additional keyword-based parameters passed to :meth:`EvaluationLoop.evaluate`
        """
        super().__init__()
        self.factory = factory
        self.frequency = frequency
        self.prefix = prefix
        self.evaluator = evaluator_resolver.make(evaluator, evaluator_kwargs)
        self.additional_filter_triples = additional_filter_triples
        self.kwargs = kwargs
        # created lazily by the evaluation_loop property
        self._evaluation_loop = None

    @property
    def evaluation_loop(self):
        """Return the evaluation loop instance (lazy-initialization)."""
        if self._evaluation_loop is not None:
            return self._evaluation_loop
        self._evaluation_loop = LCWAEvaluationLoop(
            triples_factory=self.factory,
            evaluator=self.evaluator,
            model=self.model,
            additional_filter_triples=self.additional_filter_triples,
        )
        return self._evaluation_loop

    # docstr-coverage: inherited
    def post_epoch(self, epoch: int, epoch_loss: float, **kwargs: Any) -> None:  # noqa: D102
        # skip all epochs which are not a multiple of the evaluation frequency
        if epoch % self.frequency != 0:
            return
        metrics = self.evaluation_loop.evaluate(**self.kwargs).to_flat_dict()
        self.result_tracker.log_metrics(metrics=metrics, step=epoch, prefix=self.prefix)
class StopperTrainingCallback(TrainingCallback):
    """An adapter for the :class:`pykeen.stoppers.Stopper`."""

    def __init__(
        self,
        stopper: Stopper,
        *,
        triples_factory: CoreTriplesFactory,
        last_best_epoch: int | None = None,
        best_epoch_model_file_path: pathlib.Path | None,
    ):
        """
        Initialize the callback.

        :param stopper:
            the stopper
        :param triples_factory:
            the triples factory used for saving the state
        :param last_best_epoch:
            the last best epoch, i.e., the best epoch for which the state has already been saved
        :param best_epoch_model_file_path:
            the path under which to store the best model checkpoint. If None, no checkpoint is stored.
        """
        super().__init__()
        self.stopper = stopper
        self.triples_factory = triples_factory
        self.last_best_epoch = last_best_epoch
        self.best_epoch_model_file_path = best_epoch_model_file_path

    # docstr-coverage: inherited
    def post_epoch(self, epoch: int, epoch_loss: float, **kwargs: Any) -> None:  # noqa: D102
        if self.stopper.should_evaluate(epoch):
            # TODO how to pass inductive mode
            if self.stopper.should_stop(epoch):
                self.training_loop._should_stop = True
            # Since the model is also used within the stopper, its graph and cache have to be cleared
            self.model._free_graph_and_cache()
        # When the stopper obtained a new best epoch, this model has to be saved for reconstruction
        best_epoch = self.stopper.best_epoch
        if best_epoch != self.last_best_epoch and self.best_epoch_model_file_path is not None:
            self.training_loop._save_state(path=self.best_epoch_model_file_path, triples_factory=self.triples_factory)
            # record the epoch reported by the stopper rather than the current epoch: if the two differ (e.g., when
            # resuming), recording the current epoch would make the comparison above succeed again in every
            # following epoch and repeatedly overwrite the "best" checkpoint with non-best states
            self.last_best_epoch = best_epoch
class OptimizerTrainingCallback(TrainingCallback):
    """Use optimizer to update parameters.

    Gradients are reset before each batch. After each batch, the pre-step callbacks are run, the optimizer takes a
    step (unless only probing sizes), and the model is notified about the parameter update.
    """

    # TODO: we may want to separate TrainingCallback from pre-step callbacks in the future
    def __init__(self, only_size_probing: bool = False, pre_step_callbacks: Sequence[TrainingCallback] | None = None):
        """Initialize the callback.

        :param only_size_probing:
            whether this is during size probing, where we do not want to apply weight changes
        :param pre_step_callbacks:
            callbacks to apply before making the step, e.g., for gradient clipping.
        """
        super().__init__()
        self.only_size_probing = only_size_probing
        self.pre_step_callbacks = tuple(pre_step_callbacks or [])

    # docstr-coverage: inherited
    def register_training_loop(self, training_loop: training.TrainingLoop) -> None:  # noqa: D102
        super().register_training_loop(training_loop)
        # the pre-step callbacks are invoked from here, and they access e.g. the model through their own
        # training-loop reference; hence, the training loop has to be propagated to them, too
        for callback in self.pre_step_callbacks:
            callback.register_training_loop(training_loop=training_loop)

    # docstr-coverage: inherited
    def pre_batch(self, **kwargs: Any) -> None:  # noqa: D102
        # Recall that torch *accumulates* gradients. Before passing in a
        # new instance, you need to zero out the gradients from the old instance
        # note: we want to run this step during size probing to cleanup any remaining grads
        self.optimizer.zero_grad(set_to_none=True)

    # docstr-coverage: inherited
    def post_batch(self, epoch: int, batch, **kwargs: Any) -> None:  # noqa: D102
        # pre-step callbacks
        for cb in self.pre_step_callbacks:
            cb.pre_step(epoch=epoch, **kwargs)

        # when called by batch_size_search(), the parameter update should not be applied.
        if not self.only_size_probing:
            # update parameters according to optimizer
            self.optimizer.step()

        # After changing applying the gradients to the embeddings, the model is notified that the forward
        # constraints are no longer applied
        # note: we want to apply this during size probing to properly account for the memory necessary for e.g.,
        # regularization
        self.model.post_parameter_update()
class LearningRateSchedulerTrainingCallback(TrainingCallback):
    """Update learning rate scheduler.

    After each epoch, the training loop's learning rate scheduler is stepped with the current epoch.
    """

    # docstr-coverage: inherited
    def post_epoch(self, epoch: int, epoch_loss: float, **kwargs: Any) -> None:  # noqa: D102
        scheduler = self.training_loop.lr_scheduler
        if scheduler is None:
            raise ValueError(f"{self} can only be called when a learning rate schedule is used.")
        # NOTE(review): recent PyTorch versions deprecate the `epoch` argument of `step()`; confirm that a plain
        # `step()` is equivalent for all supported schedulers before changing this call
        scheduler.step(epoch=epoch)
def _hasher(kwargs: Mapping[str, Any]) -> int:
    """Return a key identifying the training loop contained in the call's keyword arguments.

    Used as the ``hasher`` of :func:`maximize_memory_utilization` below. Keying on the object identity means that
    optimal parameters are not shared across different training loops.
    """
    loop = kwargs["training_loop"]
    return id(loop)
@maximize_memory_utilization(parameter_name=("batch_size", "slice_size"), hasher=_hasher)
@torch.inference_mode()
def _validation_loss_amo_wrapper(
    training_loop: training.TrainingLoop,
    triples_factory: CoreTriplesFactory,
    batch_size: int,
    slice_size: int,
    label_smoothing: float,
    epoch: int,
    callback: MultiTrainingCallback,
    **kwargs,
) -> float:
    """Calculate validation loss with automatic batch size optimization.

    ``batch_size`` and ``slice_size`` are tuned by :func:`maximize_memory_utilization`. The loss is computed in
    inference mode by running the training loop's epoch logic once over the given triples, without backward passes.

    :param training_loop:
        the training loop whose epoch logic is re-used
    :param triples_factory:
        the triples on which to compute the loss
    :param batch_size:
        the batch size
    :param slice_size:
        the slice size; only forwarded if the training loop supports slicing
    :param label_smoothing:
        the label smoothing
    :param epoch:
        the current epoch, passed on to the epoch logic
    :param callback:
        the callbacks to use during the loss computation
    :param kwargs:
        additional keyword-based parameters for creating the data loader
    :return:
        the value returned by the training loop's epoch logic
    """
    # todo: create dataset only once
    data_loader = training_loop._create_training_data_loader(
        triples_factory=triples_factory, batch_size=batch_size, drop_last=False, **kwargs
    )
    effective_slice_size = slice_size if training_loop.supports_slicing else None
    return training_loop._train_epoch(
        batches=data_loader,
        label_smoothing=label_smoothing,
        callbacks=callback,
        epoch=epoch,
        # no sub-batching (for evaluation, we can just reduce batch size without any effect)
        sub_batch_size=None,
        slice_size=effective_slice_size,
        # this is handled by the AMO wrapper
        only_size_probing=False,
        # no backward passes
        backward=False,
    )
class EvaluationLossTrainingCallback(TrainingCallback):
    """
    Calculate loss on an evaluation set.

    After each epoch, the training loop's epoch logic is run on the given triples in inference mode and without
    backward passes, and the resulting loss is logged to the result tracker.

    .. code-block :: python

        from pykeen.datasets import get_dataset
        from pykeen.pipeline import pipeline

        dataset = get_dataset(dataset="nations")
        pipeline(
            dataset=dataset,
            model="mure",
            training_kwargs=dict(
                callbacks="evaluation-loss",
                callback_kwargs=dict(triples_factory=dataset.validation, prefix="validation"),
            ),
            result_tracker="console",
        )
    """

    def __init__(
        self,
        triples_factory: CoreTriplesFactory,
        callbacks: TrainingCallbackHint = None,
        callbacks_kwargs: TrainingCallbackKwargsHint = None,
        maximum_batch_size: int | None = None,
        label_smoothing: float = 0.0,
        data_loader_kwargs: Mapping[str, Any] | None = None,
        prefix: str = "validation",
    ):
        """
        Initialize the callback.

        :param triples_factory:
            the evaluation triples factory
        :param callbacks:
            callbacks for the validation loss loop
        :param callbacks_kwargs:
            keyword-based parameters for the callbacks of the validation loss loop
        :param maximum_batch_size:
            the maximum batch size
        :param label_smoothing:
            the label smoothing to use; usually this should be matched with the training settings
        :param data_loader_kwargs:
            the keyword based parameters for the data loader. Defaults to ``dict(sampler=None)``.
        :param prefix:
            the prefix to use for logging
        """
        super().__init__()
        self.triples_factory = triples_factory
        self.prefix = prefix
        self.label_smoothing = label_smoothing
        if data_loader_kwargs is None:
            data_loader_kwargs = dict(sampler=None)
        self.data_loader_kwargs = data_loader_kwargs
        self.maximum_batch_size = maximum_batch_size
        self.callback = MultiTrainingCallback(callbacks=callbacks, callbacks_kwargs=callbacks_kwargs)

    # docstr-coverage: inherited
    def register_training_loop(self, training_loop: training.TrainingLoop) -> None:  # noqa: D102
        super().register_training_loop(training_loop)
        self.callback.register_training_loop(training_loop=training_loop)

    # docstr-coverage: inherited
    def post_epoch(self, epoch: int, epoch_loss: float, **kwargs: Any) -> None:  # noqa: D102
        from .lcwa import LCWATrainingLoop

        # remember the current mode, so that the model (and subsequent callbacks / training) is not left in
        # evaluation mode after computing the evaluation loss
        was_training = self.model.training
        # set to evaluation mode
        self.model.eval()
        try:
            # determine maximum batch size
            maximum_batch_size = determine_maximum_batch_size(
                batch_size=self.maximum_batch_size,
                device=self.model.device,
                # TODO: this should be num_instances rather than num_triples
                maximum_batch_size=self.triples_factory.num_triples,
            )
            loss = _validation_loss_amo_wrapper(
                training_loop=self.training_loop,
                triples_factory=self.triples_factory,
                batch_size=maximum_batch_size,
                # note: slicing is only effective for LCWA training
                slice_size=self.training_loop.num_targets if isinstance(self.training_loop, LCWATrainingLoop) else 1,
                label_smoothing=self.label_smoothing,
                callback=self.callback,
                epoch=epoch,
                **self.data_loader_kwargs,
            )
        finally:
            # restore the previous mode, also if the loss computation failed
            self.model.train(mode=was_training)
        self.result_tracker.log_metrics(metrics=dict(loss=loss), step=epoch, prefix=self.prefix)
#: A hint for the keyword-based parameters accompanying a :data:`TrainingCallbackHint`
TrainingCallbackKwargsHint = OneOrSequence[OptionalKwargs]
#: A hint for constructing a :class:`MultiTrainingCallback`
TrainingCallbackHint = OneOrSequence[HintOrType[TrainingCallback]]
class MultiTrainingCallback(TrainingCallback):
    """A wrapper for calling multiple training callbacks together.

    Every hook is forwarded to all wrapped callbacks, in the order in which they were added.
    """

    #: A collection of callbacks
    callbacks: list[TrainingCallback]

    def __init__(
        self,
        callbacks: TrainingCallbackHint = None,
        callbacks_kwargs: TrainingCallbackKwargsHint = None,
    ) -> None:
        """
        Initialize the callback.

        .. note ::
            the constructor allows "broadcasting" of callbacks, i.e., providing a single callback,
            but a list of callback kwargs. In this case, for each element of this list the given
            callback is instantiated.

        :param callbacks:
            the callbacks
        :param callbacks_kwargs:
            additional keyword-based parameters for instantiating the callbacks
        """
        super().__init__()
        if callbacks:
            self.callbacks = callback_resolver.make_many(callbacks, callbacks_kwargs)
        else:
            self.callbacks = []

    # docstr-coverage: inherited
    def register_training_loop(self, training_loop: training.TrainingLoop) -> None:  # noqa: D102
        super().register_training_loop(training_loop=training_loop)
        for child in self.callbacks:
            child.register_training_loop(training_loop=training_loop)

    def register_callback(self, callback: TrainingCallback) -> None:
        """Register a callback, attaching the training loop right away if it is already known."""
        self.callbacks.append(callback)
        loop = self._training_loop
        if loop is not None:
            callback.register_training_loop(loop)

    # docstr-coverage: inherited
    def pre_batch(self, **kwargs: Any) -> None:  # noqa: D102
        for child in self.callbacks:
            child.pre_batch(**kwargs)

    # docstr-coverage: inherited
    def on_batch(self, epoch: int, batch, batch_loss: float, **kwargs: Any) -> None:  # noqa: D102
        for child in self.callbacks:
            child.on_batch(epoch=epoch, batch=batch, batch_loss=batch_loss, **kwargs)

    # docstr-coverage: inherited
    def post_batch(self, epoch: int, batch, **kwargs: Any) -> None:  # noqa: D102
        for child in self.callbacks:
            child.post_batch(epoch=epoch, batch=batch, **kwargs)

    # docstr-coverage: inherited
    def pre_step(self, **kwargs: Any) -> None:  # noqa: D102
        for child in self.callbacks:
            child.pre_step(**kwargs)

    # docstr-coverage: inherited
    def post_epoch(self, epoch: int, epoch_loss: float, **kwargs: Any) -> None:  # noqa: D102
        for child in self.callbacks:
            child.post_epoch(epoch=epoch, epoch_loss=epoch_loss, **kwargs)

    # docstr-coverage: inherited
    def post_train(self, losses: list[float], **kwargs: Any) -> None:  # noqa: D102
        for child in self.callbacks:
            child.post_train(losses=losses, **kwargs)
class CheckpointTrainingCallback(TrainingCallback):
    """Save checkpoints at user-specified epochs."""

    def __init__(
        self,
        schedule: HintOrType[CheckpointSchedule] = None,
        schedule_kwargs: OptionalKwargs = None,
        keeper: HintOrType[CheckpointKeeper] = None,
        keeper_kwargs: OptionalKwargs = None,
        root: pathlib.Path | str | None = None,
        name_template: str = "checkpoint_{epoch:07d}.pt",
    ):
        """
        Create callback.

        :param schedule:
            a selection of the checkpoint schedule, cf. :const:`pykeen.checkpoints.schedule_resolver`
        :param schedule_kwargs:
            keyword-based parameters to instantiate the checkpoint schedule, if necessary,
            cf. :const:`pykeen.checkpoints.schedule_resolver`
        :param keeper:
            a selection of the checkpoint retention logic, cf. :const:`pykeen.checkpoints.keeper_resolver`.
            `None` corresponds to keeping all checkpoints (which were created).
        :param keeper_kwargs:
            keyword-based parameters to instantiate the retention policy, if necessary,
            cf. :const:`pykeen.checkpoints.keeper_resolver`
        :param root:
            the checkpoint root directory. Defaults to a fresh sub-directory of
            :const:`pykeen.constants.PYKEEN_CHECKPOINTS`
        :param name_template:
            a name template for the checkpoint file. Can contain a format key `{epoch}` which is replaced by the actual
            epoch. This callback does not take care of overwriting existing files, i.e., if you want to keep multiple
            checkpoints make sure to choose unique filenames.
        """
        super().__init__()
        self.schedule = schedule_resolver.make(schedule, schedule_kwargs)
        self.keeper = keeper_resolver.make_safe(keeper, keeper_kwargs)
        # maps (1-based) epochs to the files of the checkpoints saved for them; only maintained if there is a keeper
        self.checkpoint_store: dict[int, pathlib.Path] = dict()
        if root is None:
            # draw random sub-directory names until an unused one is found
            while (path := PYKEEN_CHECKPOINTS.joinpath(str(uuid.uuid4()))).exists():
                continue
            root = path
            logger.info(f"Inferred checkpoint {path= !s}")
        self.root = pathlib.Path(root)
        self.name_template = name_template
        self.root.mkdir(parents=True, exist_ok=True)

    # docstr-coverage: inherited
    def post_epoch(self, epoch: int, epoch_loss: float, **kwargs: Any) -> None:  # noqa: D102
        # use 1-based epochs
        epoch += 1
        if not self.schedule(epoch):
            return

        # save checkpoint
        path = self.root.joinpath(self.name_template.format(epoch=epoch))
        save_model(self.training_loop.model, path)
        logger.info(f"Saved checkpoint for {epoch=:_} to {path= !s}")

        # None corresponds to no clean-up
        if self.keeper is None:
            return

        # add newly saved checkpoint to the store
        if epoch in self.checkpoint_store:
            raise ValueError(
                f"Cannot add multiple checkpoints for a single {epoch=}, "
                f"but got {path= !s} and had already {self.checkpoint_store[epoch]= !s}."
            )
        self.checkpoint_store[epoch] = path

        # delete checkpoints which we do not want to keep
        keep = set(self.keeper(sorted(self.checkpoint_store)))
        for step in sorted(set(self.checkpoint_store).difference(keep)):
            path = self.checkpoint_store.pop(step)
            # if the name template does not contain `{epoch}`, several epochs share the same file; it must not be
            # removed while a remaining checkpoint (e.g., the one just written) still refers to it
            if path in self.checkpoint_store.values():
                continue
            # the file may already have been removed externally; this should not abort training
            path.unlink(missing_ok=True)
            logger.info(f"Deleted checkpoint for {step=:_} at {path= !s}")
#: A resolver for training callbacks; the :class:`MultiTrainingCallback`, which merely wraps other callbacks,
#: is not among the resolvable classes
callback_resolver: ClassResolver[TrainingCallback] = ClassResolver.from_subclasses(
    skip={MultiTrainingCallback},
    base=TrainingCallback,
)