Flexible Weight Checkpoints

This module contains methods for deciding when to write and clear checkpoints.

Warning

While this module provides a flexible and modular way to describe a desired checkpoint behavior, it currently only stores the model’s weights (more precisely, its torch.nn.Module.state_dict()). Thus, it does not yet replace the full training loop checkpointing mechanism described in Regular Checkpoints.

It consists of two main components. Checkpoint schedules decide whether to write a checkpoint at a given epoch. Since training may produce multiple checkpoints, keep strategies (checkpoint keepers) decide which of them to keep and which to discard. For both, we provide a set of basic rules as well as a way to combine them via union. These should be sufficient to model most desired checkpointing behaviours.
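As a library-independent illustration of the two roles, the sketch below assumes that a schedule can be modeled as a predicate on the step (epoch) number, and a keeper as a function that maps the list of existing checkpoint steps to the ones worth keeping. The names `Schedule`, `Keeper`, `union_schedule`, and `union_keeper` are hypothetical and are not PyKEEN's API.

```python
from typing import Callable, Iterable, Sequence

# Hypothetical stand-ins: a schedule answers "write a checkpoint at this step?",
# a keeper answers "which of these existing checkpoints should survive?"
Schedule = Callable[[int], bool]
Keeper = Callable[[Sequence[int]], Iterable[int]]


def union_schedule(*bases: Schedule) -> Schedule:
    """Write a checkpoint whenever at least one base schedule asks for it."""
    return lambda step: any(base(step) for base in bases)


def union_keeper(*bases: Keeper) -> Keeper:
    """Keep a checkpoint if at least one base keeper wants to keep it."""

    def keep(steps: Sequence[int]) -> list[int]:
        kept: set[int] = set()
        for base in bases:
            kept.update(base(steps))
        return sorted(kept)

    return keep


def every_4(step: int) -> bool:
    return step % 4 == 0


def at_6(step: int) -> bool:
    return step == 6


def last_one(steps: Sequence[int]) -> list[int]:
    return list(steps[-1:])


def even(steps: Sequence[int]) -> list[int]:
    return [step for step in steps if step % 2 == 0]


schedule = union_schedule(every_4, at_6)
print([step for step in range(1, 13) if schedule(step)])  # [4, 6, 8, 12]
print(union_keeper(last_one, even)([1, 2, 3, 4, 5]))  # [2, 4, 5]
```

Keeping both roles as plain callables makes the union combinator trivial: it only needs `any` for schedules and a set union for keepers.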

Examples

Below you can find a few examples of how to use them inside the training pipeline. To check how (static) checkpoint schedules behave before running an actual training, take a look at pykeen.checkpoints.final_checkpoints() and pykeen.checkpoints.simulate_checkpoints().
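For intuition about what such a simulation computes, here is a rough sketch. It assumes that after every epoch the schedule decides whether to write a checkpoint and the keeper then prunes the existing checkpoints. `simulate_final_checkpoints` is a hypothetical helper and does not reproduce the signature of pykeen.checkpoints.final_checkpoints().

```python
from typing import Callable, Iterable, Sequence


def simulate_final_checkpoints(
    num_epochs: int,
    schedule: Callable[[int], bool],
    keeper: Callable[[Sequence[int]], Iterable[int]],
) -> list[int]:
    """Return the epochs whose checkpoints survive until the end of training.

    Hypothetical re-implementation for illustration only.
    """
    existing: list[int] = []
    for epoch in range(1, num_epochs + 1):
        if schedule(epoch):
            existing.append(epoch)
        # the keeper decides which existing checkpoints to retain;
        # all others would be deleted at this point
        keep = set(keeper(existing))
        existing = [step for step in existing if step in keep]
    return existing


# write a checkpoint every 10 epochs, but only ever retain the most recent one
print(simulate_final_checkpoints(100, lambda e: e % 10 == 0, lambda steps: steps[-1:]))  # [100]
```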

To reduce the number of necessary imports, the examples all use dictionaries/strings to specify components instead of passing classes or actual instances. You can find more information about resolution in general at Using Resolvers. The resolver for the schedule component is pykeen.checkpoints.schedule.schedule_resolver, and for the keeper component it is pykeen.checkpoints.keeper_resolver.
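The string keys are matched against class names. As a rough illustration, assuming the resolver lowercases class names and strips the suffix shared with the base class (the class_resolver package that PyKEEN builds on works along these lines), "every" would select EveryCheckpointSchedule. The classes below are local stand-ins, and `normalize` and `make` are hypothetical helpers, not the resolver's actual code.

```python
class CheckpointSchedule:
    """Local stand-in for the base class (not the PyKEEN class)."""


class EveryCheckpointSchedule(CheckpointSchedule):
    def __init__(self, frequency: int = 10) -> None:
        self.frequency = frequency


class ExplicitCheckpointSchedule(CheckpointSchedule):
    def __init__(self, steps) -> None:
        self.steps = tuple(steps)


def normalize(name: str, suffix: str) -> str:
    """Lowercase a name and drop the base-class suffix, e.g. EveryCheckpointSchedule -> every."""
    name, suffix = name.lower(), suffix.lower()
    return name[: -len(suffix)] if name.endswith(suffix) else name


def make(query: str, base: type, **kwargs):
    """Instantiate the subclass of ``base`` whose normalized name matches ``query``."""
    lookup = {normalize(cls.__name__, base.__name__): cls for cls in base.__subclasses__()}
    return lookup[normalize(query, base.__name__)](**kwargs)


schedule = make("every", CheckpointSchedule, frequency=10)
print(type(schedule).__name__, schedule.frequency)  # EveryCheckpointSchedule 10
```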

Example 1

"""Write a checkpoint every 10 steps and keep them all."""

from pykeen.pipeline import pipeline

result = pipeline(
    dataset="nations",
    model="mure",
    training_kwargs=dict(
        num_epochs=100,
        callbacks="checkpoint",
        # create one checkpoint every 10 epochs
        callbacks_kwargs=dict(
            schedule="every",
            schedule_kwargs=dict(
                frequency=10,
            ),
        ),
    ),
)
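Assuming the "every" schedule fires at each epoch that is a multiple of frequency, with epochs counted from 1 (an assumption about EveryCheckpointSchedule, not taken from its source), Example 1 writes ten checkpoints. They can be enumerated without running any training:

```python
num_epochs = 100
frequency = 10
# assumption: epoch e is checkpointed iff e % frequency == 0
epochs = [epoch for epoch in range(1, num_epochs + 1) if epoch % frequency == 0]
print(epochs)  # [10, 20, 30, 40, 50, 60, 70, 80, 90, 100]
```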

Example 2

"""Write a checkpoint at epochs 1, 7, and 10 and keep them all."""

from pykeen.pipeline import pipeline

result = pipeline(
    dataset="nations",
    model="mure",
    training_kwargs=dict(
        num_epochs=10,
        callbacks="checkpoint",
        # create checkpoints at epochs 1, 7, and 10
        callbacks_kwargs=dict(
            schedule="explicit",
            schedule_kwargs=dict(steps=(1, 7, 10)),
        ),
    ),
)

Example 3

"""Write a checkpoint every 5 epochs, but also at epoch 7."""

from pykeen.pipeline import pipeline

result = pipeline(
    dataset="nations",
    model="mure",
    training_kwargs=dict(
        num_epochs=10,
        callbacks="checkpoint",
        callbacks_kwargs=dict(
            schedule="union",
            # create checkpoints every 5 epochs, and at epoch 7
            schedule_kwargs=dict(
                bases=["every", "explicit"],
                bases_kwargs=[dict(frequency=5), dict(steps=[7])],
            ),
        ),
    ),
)
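Under the same assumption about the "every" schedule, the union schedule in Example 3 writes checkpoints at epochs 5, 7, and 10: the union of {5, 10} from "every" and {7} from "explicit". As a plain set computation:

```python
num_epochs = 10
every = {epoch for epoch in range(1, num_epochs + 1) if epoch % 5 == 0}  # {5, 10}
explicit = {7}
# a union schedule fires whenever at least one base schedule fires
print(sorted(every | explicit))  # [5, 7, 10]
```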

Example 4

"""Write a checkpoint whenever a metric improves (here, just the training loss)."""

from pykeen.checkpoints import MetricSelection
from pykeen.pipeline import pipeline
from pykeen.trackers import tracker_resolver

# create a default result tracker (or use a proper one)
result_tracker = tracker_resolver.make(None)
result = pipeline(
    dataset="nations",
    model="mure",
    training_kwargs=dict(
        num_epochs=10,
        callbacks="checkpoint",
        callbacks_kwargs=dict(
            schedule="best",
            schedule_kwargs=dict(
                result_tracker=result_tracker,
                # in this example, we just use the training loss
                metric_selection=MetricSelection(
                    metric="loss",
                    maximize=False,
                ),
            ),
        ),
    ),
    # Important: use the same result tracker instance as in the checkpoint callback
    result_tracker=result_tracker,
)
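The rule "checkpoint whenever the metric improves" can be sketched as follows. This assumes the schedule compares the latest tracked value with the best value seen so far, respects the direction given by maximize, and treats the first value as an improvement. `improving_epochs` is a hypothetical helper, and the loss values are made-up toy input.

```python
from typing import Optional, Sequence


def improving_epochs(values: Sequence[float], maximize: bool = False) -> list[int]:
    """Return the (1-based) epochs at which the metric beat its best value so far."""
    best: Optional[float] = None
    epochs = []
    for epoch, value in enumerate(values, start=1):
        improved = best is None or (value > best if maximize else value < best)
        if improved:
            best = value
            epochs.append(epoch)
    return epochs


toy_losses = [0.9, 0.7, 0.75, 0.6, 0.6, 0.65, 0.5]  # made-up values for illustration
print(improving_epochs(toy_losses))  # [1, 2, 4, 7]
```

Note that a tie (epoch 5) does not count as an improvement in this sketch; whether the real schedule uses a strict comparison is not stated here.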

Example 5

"""Write a checkpoint every 10 steps, but keep only the last one and one every 50 steps."""

from pykeen.pipeline import pipeline

result = pipeline(
    dataset="nations",
    model="mure",
    training_kwargs=dict(
        num_epochs=100,
        callbacks="checkpoint",
        # create one checkpoint every 10 epochs
        callbacks_kwargs=dict(
            schedule="every",
            schedule_kwargs=dict(
                frequency=10,
            ),
            keeper="union",
            keeper_kwargs=dict(
                bases=["modulo", "last"],
                bases_kwargs=[dict(divisor=50), None],
            ),
        ),
    ),
)
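Tracing Example 5 by hand shows how the two keepers interact. The sketch assumes that "last" keeps exactly one checkpoint by default (consistent with passing None as its kwargs) and that "modulo" keeps the steps divisible by divisor; it is an illustration, not PyKEEN's implementation.

```python
existing: list[int] = []
for epoch in range(10, 101, 10):  # the "every" schedule with frequency=10
    existing.append(epoch)
    last = set(existing[-1:])  # "last" keeper: the most recent checkpoint
    modulo = {step for step in existing if step % 50 == 0}  # "modulo" keeper, divisor=50
    # union keeper: retain a checkpoint if any base keeper wants it, delete the rest
    existing = sorted(last | modulo)
    print(epoch, existing)
# the final iteration prints: 100 [50, 100]
```

Between epochs 60 and 90, two checkpoints exist at any time: epoch 50 plus the most recent one.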

Functions

save_model(model, file)

Save a model to a file.

simulate_checkpoints([num_epochs, schedule, ...])

Simulate a checkpoint schedule and print information about checkpointing.

final_checkpoints([num_epochs, schedule, ...])

Simulate a checkpoint schedule and return the set of epochs for which a checkpoint remains.

Classes

CheckpointSchedule()

Interface for checkpoint schedules.

EveryCheckpointSchedule([frequency])

Create a checkpoint every \(n\) steps.

ExplicitCheckpointSchedule(steps)

Create a checkpoint for explicitly chosen steps.

BestCheckpointSchedule(result_tracker, ...)

Create a checkpoint whenever a metric improves.

UnionCheckpointSchedule(bases[, bases_kwargs])

Create a checkpoint whenever one of the base schedules requires it.

CheckpointKeeper()

A checkpoint cleanup interface.

LastCheckpointKeeper([keep])

Keep the last \(n\) checkpoints.

ModuloCheckpointKeeper([divisor])

Keep checkpoints if the step is divisible by the given divisor.

ExplicitCheckpointKeeper(keep)

Keep checkpoints at explicit steps.

BestCheckpointKeeper(result_tracker, ...)

Keep checkpoints for steps that achieved the best value for a metric.

UnionCheckpointKeeper(bases[, bases_kwargs])

Keep a checkpoint if at least one of the base keepers' criteria is met.

MetricSelection(metric[, prefix, maximize])

The selection of the metric to monitor.
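To illustrate how a best-value keeper might combine with a metric selection, the sketch below picks, among the checkpointed steps, those whose recorded metric value is the best one. `MetricSelectionSketch` mirrors the documented fields (metric, prefix, maximize), but it is a hypothetical stand-in whose defaults are assumptions; `best_steps`, the metric name, and the values are illustrative only.

```python
from dataclasses import dataclass
from typing import Mapping, Optional


@dataclass
class MetricSelectionSketch:
    """Hypothetical stand-in mirroring MetricSelection(metric, prefix, maximize)."""

    metric: str
    prefix: Optional[str] = None  # assumed default
    maximize: bool = False  # assumed default


def best_steps(values: Mapping[int, float], selection: MetricSelectionSketch) -> list[int]:
    """Return all checkpointed steps whose metric value equals the best one (ties are kept)."""
    if not values:
        return []
    best = max(values.values()) if selection.maximize else min(values.values())
    return sorted(step for step, value in values.items() if value == best)


toy_hits = {10: 0.41, 20: 0.47, 30: 0.45}  # made-up values for illustration
print(best_steps(toy_hits, MetricSelectionSketch(metric="hits_at_10", maximize=True)))  # [20]
```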

Class Inheritance Diagram

[Figure: inheritance diagram of the checkpoint schedule classes (pykeen.checkpoints.schedule), the checkpoint keeper classes (pykeen.checkpoints.keeper), and pykeen.checkpoints.utils.MetricSelection.]