"""Training KGE models based on the sLCWA."""
import logging
import warnings
from typing import Any, Optional
from class_resolver import HintOrType, OptionalKwargs
from torch.utils.data import DataLoader, Dataset
from .training_loop import TrainingLoop
from ..losses import Loss
from ..models.base import Model
from ..sampling import NegativeSampler
from ..triples import CoreTriplesFactory
from ..triples.instances import BatchedSLCWAInstances, SLCWABatch, SLCWASampleType, SubGraphSLCWAInstances
from ..typing import FloatTensor, InductiveMode
__all__ = [
"SLCWATrainingLoop",
]
logger = logging.getLogger(__name__)

class SLCWATrainingLoop(TrainingLoop[SLCWASampleType, SLCWABatch]):
    """A training loop that uses the stochastic local closed world assumption (sLCWA) training approach.

    [ruffinelli2020]_ call the sLCWA ``NegSamp`` in their work.
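
    A minimal sketch of typical usage; the ``TransE`` model, the ``Nations`` dataset, and the
    ``train()`` arguments shown here are illustrative assumptions rather than a prescribed setup:

    .. code-block:: python

        from pykeen.datasets import Nations
        from pykeen.models import TransE
        from pykeen.training import SLCWATrainingLoop

        dataset = Nations()
        model = TransE(triples_factory=dataset.training)
        training_loop = SLCWATrainingLoop(model=model, triples_factory=dataset.training)
        losses = training_loop.train(triples_factory=dataset.training, num_epochs=5, batch_size=256)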
"""
def __init__(
self,
negative_sampler: HintOrType[NegativeSampler] = None,
negative_sampler_kwargs: OptionalKwargs = None,
**kwargs,
):
"""Initialize the training loop.
:param negative_sampler: The class, instance, or name of the negative sampler
:param negative_sampler_kwargs: Keyword arguments to pass to the negative sampler class on instantiation
for every positive one
:param kwargs:
Additional keyword-based parameters passed to TrainingLoop.__init__
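
        A hypothetical sketch of selecting a sampler by resolver name; the ``"bernoulli"`` key and the
        ``num_negs_per_pos`` keyword are assumptions about the installed sampler implementations:

        .. code-block:: python

            training_loop = SLCWATrainingLoop(
                model=model,
                triples_factory=training,
                negative_sampler="bernoulli",
                negative_sampler_kwargs=dict(num_negs_per_pos=16),
            )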
"""
super().__init__(**kwargs)
self.negative_sampler = negative_sampler
self.negative_sampler_kwargs = negative_sampler_kwargs

    # docstr-coverage: inherited
    def _create_training_data_loader(
        self, triples_factory: CoreTriplesFactory, sampler: Optional[str], batch_size: int, drop_last: bool, **kwargs
    ) -> DataLoader[SLCWABatch]:  # noqa: D102
        assert "batch_sampler" not in kwargs
        return DataLoader(
            dataset=create_slcwa_instances(
                triples_factory,
                batch_size=batch_size,
                shuffle=kwargs.pop("shuffle", True),
                drop_last=drop_last,
                negative_sampler=self.negative_sampler,
                negative_sampler_kwargs=self.negative_sampler_kwargs,
                sampler=sampler,
            ),
            # disable automatic batching; the dataset already yields complete batches
            batch_size=None,
            batch_sampler=None,
            **kwargs,
        )

    @staticmethod
    # docstr-coverage: inherited
    def _get_batch_size(batch: SLCWABatch) -> int:  # noqa: D102
        return batch[0].shape[0]

    @staticmethod
    def _process_batch_static(
        model: Model,
        loss: Loss,
        mode: Optional[InductiveMode],
        batch: SLCWABatch,
        start: Optional[int],
        stop: Optional[int],
        label_smoothing: float = 0.0,
        slice_size: Optional[int] = None,
    ) -> FloatTensor:
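        """Compute the sLCWA loss for (a slice of) one batch.

        This slices the positive and negative triples to ``[start:stop]``, moves them to the model's
        device, restricts the negatives by the positive filter if one is present, scores both sides via
        ``score_hrt``, and returns the sLCWA loss plus the model's regularization term.
        """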
        # Slicing is not possible in sLCWA training loops
        if slice_size is not None:
            raise AttributeError("Slicing is not possible for sLCWA training loops.")

        # split the batch into positive triples, negative triples, and the optional filter mask
        positive_batch, negative_batch, positive_filter = batch

        # send to device
        positive_batch = positive_batch[start:stop].to(device=model.device)
        negative_batch = negative_batch[start:stop]
        if positive_filter is not None:
            positive_filter = positive_filter[start:stop]
            negative_batch = negative_batch[positive_filter]
            positive_filter = positive_filter.to(model.device)

        # Make the negative batch broadcastable (required for num_negs_per_pos > 1).
        negative_score_shape = negative_batch.shape[:-1]
        negative_batch = negative_batch.view(-1, 3)

        # Ensure the negatives reside on the device (this should already hold for most simple negative
        # samplers, e.g., BasicNegativeSampler and BernoulliNegativeSampler).
        negative_batch = negative_batch.to(model.device)

        # Compute positive and negative scores
        positive_scores = model.score_hrt(positive_batch, mode=mode)
        negative_scores = model.score_hrt(negative_batch, mode=mode).view(*negative_score_shape)

        return (
            loss.process_slcwa_scores(
                positive_scores=positive_scores,
                negative_scores=negative_scores,
                label_smoothing=label_smoothing,
                batch_filter=positive_filter,
                num_entities=model._get_entity_len(mode=mode),
            )
            + model.collect_regularization_term()
        )

    # docstr-coverage: inherited
    def _process_batch(
        self,
        batch: SLCWABatch,
        start: int,
        stop: int,
        label_smoothing: float = 0.0,
        slice_size: Optional[int] = None,
    ) -> FloatTensor:  # noqa: D102
        return self._process_batch_static(
            model=self.model,
            loss=self.loss,
            mode=self.mode,
            batch=batch,
            start=start,
            stop=stop,
            label_smoothing=label_smoothing,
            slice_size=slice_size,
        )

    # docstr-coverage: inherited
    def _slice_size_search(
        self,
        *,
        triples_factory: CoreTriplesFactory,
        batch_size: int,
        sub_batch_size: int,
        supports_sub_batching: bool,
    ):  # noqa: D102
        # Slicing is not possible for sLCWA
        if supports_sub_batching:
            report = "This model supports sub-batching, but it also requires slicing, which is not possible for sLCWA"
        else:
            report = "This model doesn't support sub-batching and slicing is not possible for sLCWA"
        logger.warning(report)
        raise MemoryError("The current model can't be trained on this hardware with these parameters.")


def create_slcwa_instances(
    triples_factory: CoreTriplesFactory,
    *,
    sampler: Optional[str] = None,
    **kwargs: Any,
) -> Dataset:
    """Create sLCWA instances for this factory's triples.
    cls = BatchedSLCWAInstances if sampler is None else SubGraphSLCWAInstances
    if "shuffle" in kwargs:
        if kwargs.pop("shuffle"):
            warnings.warn("Training instances are always shuffled.", DeprecationWarning, stacklevel=2)
        else:
            raise AssertionError("If shuffle is provided, it must be True.")
    return cls(
        mapped_triples=triples_factory._add_inverse_triples_if_necessary(
            mapped_triples=triples_factory.mapped_triples
        ),
        num_entities=triples_factory.num_entities,
        num_relations=triples_factory.num_relations,
        **kwargs,
    )