# -*- coding: utf-8 -*-
"""Basic stoppers."""
import logging
import pathlib
from abc import ABC, abstractmethod
from typing import Any, List, Mapping, Optional, Union

import torch
# Public names exported by ``from <this module> import *``.
__all__ = [
'Stopper',
'NopStopper',
]
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
class Stopper(ABC):
    """A harness for stopping training.

    Subclasses must implement :meth:`should_stop` and :meth:`get_summary_dict`.
    :meth:`should_evaluate` is not marked abstract, but its base implementation
    raises :class:`NotImplementedError`, so it has to be overridden before use.
    """

    def __init__(self, *args, **kwargs) -> None:
        """Initialize the stopper.

        Positional and keyword arguments are accepted and ignored by the base class.
        """
        # Declared here (with an explicit type) so that static type checkers
        # know about the attribute.
        self.best_epoch: Optional[int] = None

    def should_evaluate(self, epoch: int) -> bool:
        """Check if the stopper should be evaluated on the given epoch.

        :param epoch:
            The epoch to check.
        :return:
            Whether the stopper should be evaluated on this epoch.
        :raises NotImplementedError:
            Always, in the base class.
        """
        raise NotImplementedError

    @abstractmethod
    def should_stop(self, epoch: int) -> bool:
        """Validate on validation set and check for termination condition.

        :param epoch:
            The current epoch.
        :return:
            Whether training should stop.
        """
        raise NotImplementedError

    @abstractmethod
    def get_summary_dict(self) -> Mapping[str, Any]:
        """Get a summary dict.

        :return:
            A mapping summarizing the stopper.
        """
        raise NotImplementedError

    def _write_from_summary_dict(
        self,
        frequency: int,
        patience: int,
        relative_delta: float,
        metric: str,
        larger_is_better: bool,
        results: List[float],
        stopped: bool,
        best_epoch: int,
        best_metric: float,
    ) -> None:
        """Do nothing; the base implementation ignores all arguments.

        NOTE(review): the parameter names suggest this is a hook for subclasses to
        restore their state from the entries of a summary dict -- confirm against
        the subclasses that override it.
        """

    @staticmethod
    def load_summary_dict_from_training_loop_checkpoint(path: Union[str, pathlib.Path]) -> Mapping[str, Any]:
        """Load the summary dict from a training loop checkpoint.

        :param path:
            Path of the checkpoint file to load the state from.
        :return:
            The summary dict of the stopper at the time of saving the checkpoint,
            i.e. the value stored under the ``'stopper_dict'`` key of the checkpoint.
        """
        # Lazy %-style arguments: the message is only rendered if the record is emitted.
        logger.info("=> loading stopper summary dict from training loop checkpoint in '%s'", path)
        # SECURITY NOTE: torch.load relies on pickle for deserialization; only load
        # checkpoints that come from a trusted source.
        checkpoint = torch.load(path)
        logger.info("=> loaded stopper summary dictionary from checkpoint in '%s'", path)
        return checkpoint['stopper_dict']
class NopStopper(Stopper):
    """A stopper that does nothing."""

    def should_evaluate(self, epoch: int) -> bool:
        """Return false; should never evaluate.

        :param epoch:
            The epoch to check (ignored).
        :return:
            Always ``False``.
        """
        return False

    def should_stop(self, epoch: int) -> bool:
        """Return false; should never stop.

        :param epoch:
            The current epoch (ignored).
        :return:
            Always ``False``.
        """
        return False

    def get_summary_dict(self) -> Mapping[str, Any]:
        """Return empty mapping, doesn't have any attributes.

        :return:
            A new, empty dictionary.
        """
        return {}