# -*- coding: utf-8 -*-

"""Training KGE models based on the LCWA."""

import logging
from math import ceil
from typing import Optional, Tuple

import torch

from .training_loop import TrainingLoop
from .utils import apply_label_smoothing
from ..triples import LCWAInstances
from ..typing import MappedTriples

__all__ = [
    'LCWATrainingLoop',
]

logger = logging.getLogger(__name__)

class LCWATrainingLoop(TrainingLoop):
    """A training loop that uses the local closed world assumption training approach."""

    def _create_instances(self, use_tqdm: Optional[bool] = None) -> LCWAInstances:  # noqa: D102
        return self.triples_factory.create_lcwa_instances(use_tqdm=use_tqdm)

    @staticmethod
    def _get_batch_size(batch: Tuple[MappedTriples, torch.FloatTensor]) -> int:  # noqa: D102
        return batch[0].shape[0]

    def _process_batch(
        self,
        batch: Tuple[MappedTriples, torch.FloatTensor],
        start: int,
        stop: int,
        label_smoothing: float = 0.0,
        slice_size: Optional[int] = None,
    ) -> torch.FloatTensor:  # noqa: D102
        # Split batch components
        batch_pairs, batch_labels_full = batch

        # Send batch to device
        batch_pairs = batch_pairs[start:stop].to(device=self.device)
        batch_labels_full = batch_labels_full[start:stop].to(device=self.device)

        if slice_size is None:
            predictions = self.model.score_t(hr_batch=batch_pairs)
        else:
            predictions = self.model.score_t(hr_batch=batch_pairs, slice_size=slice_size)

        loss = self._loss_helper(
            predictions,
            batch_labels_full,
            label_smoothing,
        )
        return loss

    def _label_loss_helper(
        self,
        predictions: torch.FloatTensor,
        labels: torch.FloatTensor,
        label_smoothing: float,
    ) -> torch.FloatTensor:
        # Apply label smoothing
        if label_smoothing > 0.:
            labels = apply_label_smoothing(
                labels=labels,
                epsilon=label_smoothing,
                num_classes=self.model.num_entities,
            )

        return self.model._compute_loss(
            tensor_1=predictions,
            tensor_2=labels,
        )

    def _mr_loss_helper(
        self,
        predictions: torch.FloatTensor,
        labels: torch.FloatTensor,
        _label_smoothing=None,
    ) -> torch.FloatTensor:
        # This shows how often one row has to be repeated
        repeat_rows = (labels == 1).nonzero(as_tuple=False)[:, 0]
        # Create boolean indices for negative labels in the repeated rows
        labels_negative = labels[repeat_rows] == 0
        # Repeat the predictions and filter for negative labels
        negative_scores = predictions[repeat_rows][labels_negative]

        # This tells us how often each true label should be repeated
        repeat_true_labels = (labels[repeat_rows] == 0).nonzero(as_tuple=False)[:, 0]
        # First filter the predictions for true labels and then repeat them based on the repeat vector
        positive_scores = predictions[labels == 1][repeat_true_labels]

        return self.model.compute_mr_loss(
            positive_scores=positive_scores,
            negative_scores=negative_scores,
        )

    def _self_adversarial_negative_sampling_loss_helper(
        self,
        predictions: torch.FloatTensor,
        labels: torch.FloatTensor,
        _label_smoothing=None,
    ) -> torch.FloatTensor:
        """Compute self adversarial negative sampling loss."""
        # Split positive and negative scores
        positive_scores = predictions[labels == 1]
        negative_scores = predictions[labels == 0]

        return self.model.compute_self_adversarial_negative_sampling_loss(
            positive_scores=positive_scores,
            negative_scores=negative_scores,
        )

    def _slice_size_search(
        self,
        batch_size: int,
        sub_batch_size: int,
        supports_sub_batching: bool,
    ) -> int:  # noqa: D102
        self._check_slicing_availability(supports_sub_batching)
        reached_max = False
        evaluated_once = False
        logger.info("Trying slicing now.")
        # Since the batch_size search with size 1, i.e. one tuple ((h, r) or (r, t)) scored on all entities,
        # must have failed to start slice_size search, we start with trying half the entities.
        slice_size = ceil(self.model.num_entities / 2)
        while True:
            try:
                logger.debug(f'Trying slice size {slice_size} now.')
                self._train(
                    num_epochs=1,
                    batch_size=batch_size,
                    sub_batch_size=sub_batch_size,
                    slice_size=slice_size,
                    only_size_probing=True,
                )
            except RuntimeError as e:
                self._free_graph_and_cache()
                if 'CUDA out of memory.' not in e.args[0]:
                    raise e
                if evaluated_once:
                    slice_size //= 2
                    logger.info(f'Concluded search with slice_size {slice_size}.')
                    break
                if slice_size == 1:
                    raise MemoryError(
                        f"Even slice_size={slice_size} doesn't fit into your memory with these"
                        f" parameters.",
                    ) from e

                logger.debug(
                    f'The slice_size {slice_size} was too big, trying less now.',
                )
                slice_size //= 2
                reached_max = True
            else:
                self._free_graph_and_cache()
                if reached_max:
                    logger.info(f'Concluded search with slice_size {slice_size}.')
                    break
                slice_size *= 2
                evaluated_once = True

        return slice_size

    def _check_slicing_availability(self, supports_sub_batching: bool):
        if self.model.can_slice_t:
            return
        elif supports_sub_batching:
            report = (
                "This model supports sub-batching, but it also requires slicing,"
                " which is not implemented for this model yet."
            )
        else:
            report = (
                "This model doesn't support sub-batching and slicing is not"
                " implemented for this model yet."
            )
        logger.warning(report)
        raise MemoryError("The current model can't be trained on this hardware with these parameters.")
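Usage sketch (not part of the module above): a minimal, hypothetical example of driving this training loop end-to-end. The dataset, model, optimizer wiring, and the `train` keyword arguments shown here are assumptions based on a PyKEEN 1.0-era API and may differ in other versions; treat them as illustrative rather than a definitive reference.

# Hypothetical usage sketch, assuming a PyKEEN 1.0-era API.
from torch.optim import Adam

from pykeen.datasets import Nations
from pykeen.models import TransE
from pykeen.training import LCWATrainingLoop

dataset = Nations()
model = TransE(triples_factory=dataset.training)
optimizer = Adam(params=model.get_grad_params())

# Under the LCWA, each training instance is an (h, r) pair labelled with all
# observed tails, so score_t evaluates every entity as a candidate tail.
training_loop = LCWATrainingLoop(model=model, optimizer=optimizer)
losses = training_loop.train(
    num_epochs=5,
    batch_size=256,
    label_smoothing=0.1,  # forwarded to _process_batch / _label_loss_helper
)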