"""
Checkpoint cleanup methods.
The cleanup methods determine, for any given set of existing checkpoints, which of them can be pruned.
We provide a set of basic rules that can be easily combined into more complex logic.
"""
import abc
import dataclasses
from collections.abc import Collection, Iterator, Sequence
from class_resolver import ClassResolver, OneOrManyHintOrType, OneOrManyOptionalKwargs
from .utils import MetricSelection, ResultListenerAdapter
from ..trackers.base import ResultTracker
# the public names of this module: the keeper interface, its resolver, and the keeper implementations defined here
__all__ = [
    "CheckpointKeeper",
    "keeper_resolver",
    "LastCheckpointKeeper",
    "ModuloCheckpointKeeper",
    "ExplicitCheckpointKeeper",
    "BestCheckpointKeeper",
    "UnionCheckpointKeeper",
]
class CheckpointKeeper(abc.ABC):
    """Interface for rules that decide which checkpoints survive a cleanup.

    A keeper is a callable: it receives the steps of all existing checkpoints and iterates over those steps whose
    checkpoints should be retained.
    """

    @abc.abstractmethod
    def __call__(self, steps: Sequence[int]) -> Iterator[int]:
        """Select the steps whose checkpoints are retained.

        :param steps:
            the steps at which checkpoints were written, as a sorted sequence.

        :yields:
            each step whose checkpoint should be kept
        """
@dataclasses.dataclass
class LastCheckpointKeeper(CheckpointKeeper):
    """Keep the last $n$ checkpoints."""

    #: the number of checkpoints to keep; must be non-negative, and 0 keeps no checkpoint at all
    keep: int = 1

    def __post_init__(self) -> None:
        """Validate the configuration.

        :raises ValueError:
            if ``keep`` is negative
        """
        # a negative value would silently turn the slice in ``__call__`` into "drop the first ``|keep|`` steps"
        if self.keep < 0:
            raise ValueError(f"keep must be non-negative, but got {self.keep}.")

    def __call__(self, steps: Sequence[int]) -> Iterator[int]:
        """Iterate over the last ``keep`` steps.

        :param steps:
            the sorted list of steps at which checkpoints were written.

        :yields:
            the last ``keep`` entries of ``steps`` (all of them, if there are fewer than ``keep``)
        """
        # ``steps[-0:]`` is the same as ``steps[0:]``, i.e., *all* steps, so zero needs explicit handling
        if self.keep == 0:
            return
        yield from steps[-self.keep :]
@dataclasses.dataclass
class ModuloCheckpointKeeper(CheckpointKeeper):
    """Keep the checkpoints whose step leaves no remainder when divided by a fixed number."""

    #: the number by which a step has to be divisible for its checkpoint to be kept
    divisor: int = 10

    def __call__(self, steps: Sequence[int]) -> Iterator[int]:
        """Iterate over the steps that are divisible by ``divisor``.

        :param steps:
            the sorted list of steps at which checkpoints were written.

        :yields:
            every step with ``step % divisor == 0``, in the order of ``steps``
        """
        yield from (step for step in steps if step % self.divisor == 0)
@dataclasses.dataclass
class ExplicitCheckpointKeeper(CheckpointKeeper):
    """Keep checkpoints at explicit steps."""

    #: the steps whose checkpoints are kept; any collection is accepted, and it is stored as a set after initialization
    keep: Collection[int]

    def __post_init__(self) -> None:
        """Convert ``keep`` to a set once, so that membership tests in ``__call__`` are fast."""
        self.keep = set(self.keep)

    def __call__(self, steps: Sequence[int]) -> Iterator[int]:
        """Iterate over those steps which were explicitly requested to be kept.

        :param steps:
            the sorted list of steps at which checkpoints were written.

        :yields:
            the entries of ``steps`` which are contained in ``keep``, in the order of ``steps``
        """
        # filter ``steps`` by membership instead of building a set intersection: this avoids copying ``keep`` on
        # every call (``set(...)`` always creates a new set) and yields the steps in the order they were given
        # rather than in arbitrary set order
        yield from (step for step in steps if step in self.keep)
@dataclasses.dataclass
class BestCheckpointKeeper(CheckpointKeeper):
    """Keep checkpoints for steps that achieved the best value for a metric."""

    #: the result tracker which receives updates on metrics
    #: the very same tracker instance has to receive the results from the training loop, which is why an already
    #: instantiated tracker is required here instead of a hint that would be resolved to a new instance
    result_tracker: ResultTracker

    #: the metric selection
    metric_selection: MetricSelection

    # internal detail: the adapter is created in ``__post_init__`` and is not part of the constructor signature
    _adapter: ResultListenerAdapter = dataclasses.field(init=False)

    def __post_init__(self):
        """Create the adapter from the given result tracker and metric selection."""
        self._adapter = ResultListenerAdapter(
            self.result_tracker,
            metric_selection=self.metric_selection,
        )

    def __call__(self, steps: Sequence[int]) -> Iterator[int]:
        """Iterate over the steps which the adapter reports as best.

        :param steps:
            the sorted list of steps at which checkpoints were written.

        :return:
            a lazy iterator over those entries of ``steps`` for which the adapter's ``is_best`` is truthy
        """
        is_best = self._adapter.is_best
        return (step for step in steps if is_best(step))
@dataclasses.dataclass
class UnionCheckpointKeeper(CheckpointKeeper):
    """Keep a checkpoint as soon as at least one of several keepers wants to keep it."""

    #: the keepers to combine, given as hint(s) or type(s) which are passed to ``keeper_resolver.make_many``
    bases: OneOrManyHintOrType[CheckpointKeeper]

    #: the optional keyword-based parameters for the keepers, passed to ``keeper_resolver.make_many`` alongside ``bases``
    bases_kwargs: OneOrManyOptionalKwargs = None

    # internal detail: the resolved keepers, created in ``__post_init__``
    _bases: Sequence[CheckpointKeeper] = dataclasses.field(init=False)

    def __post_init__(self):
        """Resolve the keepers to combine."""
        self._bases = keeper_resolver.make_many(self.bases, self.bases_kwargs)

    def __call__(self, steps: Sequence[int]) -> Iterator[int]:
        """Iterate over the steps that any of the combined keepers keeps.

        :param steps:
            the sorted list of steps at which checkpoints were written.

        :yields:
            each step kept by at least one keeper, exactly once; the steps come from a set, so their order is not
            tied to the order of ``steps``
        """
        kept: set[int] = {step for keeper in self._bases for step in keeper(steps)}
        yield from kept
#: a resolver for checkpoint keepers
#: it is built from the subclasses of :class:`CheckpointKeeper`, and is also what :class:`UnionCheckpointKeeper` uses
#: to instantiate the keepers it combines
# NOTE(review): the default is the abstract base class itself, which cannot be instantiated - presumably callers are
# expected to always select a concrete keeper explicitly; confirm that this is intended.
keeper_resolver: ClassResolver[CheckpointKeeper] = ClassResolver.from_subclasses(
    CheckpointKeeper, default=CheckpointKeeper
)