Source code for pykeen.training.lcwa

# -*- coding: utf-8 -*-

"""Training KGE models based on the LCWA."""

import logging
from math import ceil
from typing import Optional

import torch

from .training_loop import TrainingLoop
from ..triples import CoreTriplesFactory, Instances
from ..triples.instances import LCWABatchType, LCWASampleType

__all__ = [
    'LCWATrainingLoop',
]

logger = logging.getLogger(__name__)


class LCWATrainingLoop(TrainingLoop[LCWASampleType, LCWABatchType]):
    """A training loop that uses the local closed world assumption training approach."""

    def _create_instances(self, triples_factory: CoreTriplesFactory) -> Instances:  # noqa: D102
        return triples_factory.create_lcwa_instances()

    @staticmethod
    def _get_batch_size(batch: LCWABatchType) -> int:  # noqa: D102
        return batch[0].shape[0]

    def _process_batch(
        self,
        batch: LCWABatchType,
        start: int,
        stop: int,
        label_smoothing: float = 0.0,
        slice_size: Optional[int] = None,
    ) -> torch.FloatTensor:  # noqa: D102
        # Split batch components
        batch_pairs, batch_labels_full = batch

        # Send batch to device
        batch_pairs = batch_pairs[start:stop].to(device=self.device)
        batch_labels_full = batch_labels_full[start:stop].to(device=self.device)

        if slice_size is None:
            predictions = self.model.score_t(hr_batch=batch_pairs)
        else:
            predictions = self.model.score_t(hr_batch=batch_pairs, slice_size=slice_size)  # type: ignore

        return self.loss.process_lcwa_scores(
            predictions=predictions,
            labels=batch_labels_full,
            label_smoothing=label_smoothing,
            num_entities=self.model.num_entities,
        ) + self.model.collect_regularization_term()

    def _slice_size_search(
        self,
        *,
        triples_factory: CoreTriplesFactory,
        training_instances: Instances,
        batch_size: int,
        sub_batch_size: int,
        supports_sub_batching: bool,
    ) -> int:  # noqa: D102
        self._check_slicing_availability(supports_sub_batching)
        reached_max = False
        evaluated_once = False
        logger.info("Trying slicing now.")
        # Since the batch_size search must already have failed with size 1, i.e. one tuple
        # ((h, r) or (r, t)) scored against all entities, before the slice_size search
        # starts, we begin by trying half the entities.
        slice_size = ceil(self.model.num_entities / 2)
        while True:
            try:
                logger.debug(f'Trying slice size {slice_size} now.')
                self._train(
                    triples_factory=triples_factory,
                    training_instances=training_instances,
                    num_epochs=1,
                    batch_size=batch_size,
                    sub_batch_size=sub_batch_size,
                    slice_size=slice_size,
                    only_size_probing=True,
                )
            except RuntimeError as e:
                self._free_graph_and_cache()
                if 'CUDA out of memory.' not in e.args[0]:
                    raise e
                if evaluated_once:
                    # The previous, smaller slice size fit into memory; return to it and stop.
                    slice_size //= 2
                    logger.info(f'Concluded search with slice_size {slice_size}.')
                    break
                if slice_size == 1:
                    raise MemoryError(
                        f"Even slice_size={slice_size} doesn't fit into your memory with these"
                        f" parameters.",
                    ) from e

                logger.debug(
                    f'The slice_size {slice_size} was too big, trying less now.',
                )
                slice_size //= 2
                reached_max = True
            else:
                self._free_graph_and_cache()
                if reached_max:
                    logger.info(f'Concluded search with slice_size {slice_size}.')
                    break
                # The current slice size fit; double it and probe again.
                slice_size *= 2
                evaluated_once = True

        return slice_size

    def _check_slicing_availability(self, supports_sub_batching: bool):
        if self.model.can_slice_t:
            return
        elif supports_sub_batching:
            report = (
                "This model supports sub-batching, but it also requires slicing,"
                " which is not implemented for this model yet."
            )
        else:
            report = (
                "This model doesn't support sub-batching and slicing is not"
                " implemented for this model yet."
            )
        logger.warning(report)
        raise MemoryError("The current model can't be trained on this hardware with these parameters.")
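
# ---------------------------------------------------------------------------
# A minimal, hypothetical sketch (not part of the module) of the
# (pairs, labels) batch layout that ``_process_batch`` consumes, assuming a
# toy graph with three entities and one relation. The tensors and numbers
# below are illustrative assumptions, not PyKEEN internals.
import torch

# The triples (h=0, r=0, t=1) and (h=0, r=0, t=2) collapse into one (h, r)
# pair whose label row marks both observed tails as positive.
batch_pairs = torch.as_tensor([[0, 0]])            # shape: (batch_size, 2), columns (h, r)
batch_labels = torch.as_tensor([[0.0, 1.0, 1.0]])  # shape: (batch_size, num_entities)

# ``_get_batch_size`` reads the batch size from the first batch component.
assert batch_pairs.shape[0] == batch_labels.shape[0] == 1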
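
# ---------------------------------------------------------------------------
# A minimal usage sketch for this training loop, assuming the PyKEEN 1.x API
# that matches this module (explicit optimizer, ``triples_factory`` passed to
# both the constructor and ``train``); the dataset, model, and hyperparameter
# choices are illustrative.
from torch.optim import Adam

from pykeen.datasets import Nations
from pykeen.models import TransE
from pykeen.training import LCWATrainingLoop

dataset = Nations()
model = TransE(triples_factory=dataset.training)
optimizer = Adam(params=model.get_grad_params())

training_loop = LCWATrainingLoop(
    model=model,
    triples_factory=dataset.training,
    optimizer=optimizer,
)
# Trains for five epochs and returns the per-epoch loss values.
losses = training_loop.train(
    triples_factory=dataset.training,
    num_epochs=5,
    batch_size=256,
)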