BaseBatchedSLCWAInstances

class BaseBatchedSLCWAInstances(mapped_triples: Tensor, batch_size: int = 1, drop_last: bool = True, num_entities: int | None = None, num_relations: int | None = None, negative_sampler: str | NegativeSampler | type[NegativeSampler] | None = None, negative_sampler_kwargs: Mapping[str, Any] | None = None, loss_weighter: str | LossWeighter | type[LossWeighter] | None = None, loss_weighter_kwargs: Mapping[str, Any] | None = None)

Bases: Instances[SLCWABatch], IterableDataset[SLCWABatch]

Pre-batched training instances for the sLCWA training loop.

Note

This class is intended to be used with automatic batching disabled, i.e., both parameters batch_size and batch_sampler of torch.utils.data.DataLoader are set to None.

Initialize the dataset.

Parameters:
  • mapped_triples (MappedTriples) – shape: (num_triples, 3); the ID-based triples, one (head, relation, tail) per row

  • batch_size (int) – the batch size

  • drop_last (bool) – whether to drop the last (incomplete) batch

  • num_entities (int | None) – the number of entities (must be > 0), passed to the negative sampler

  • num_relations (int | None) – the number of relations (must be > 0), passed to the negative sampler

  • negative_sampler (HintOrType[NegativeSampler]) – the negative sampler, or a hint thereof

  • negative_sampler_kwargs (OptionalKwargs) – additional keyword-based parameters used to instantiate the negative sampler

  • loss_weighter (HintOrType[LossWeighter]) – the method used to determine sample weights, or a hint thereof

  • loss_weighter_kwargs (OptionalKwargs) – additional keyword-based parameters used to instantiate the loss weighter
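The interplay of batch_size and drop_last can be illustrated with a minimal pure-Python sketch. iter_batches is a hypothetical helper, not part of PyKEEN, and it assumes batches are formed from consecutive triple indices; concrete subclasses may order triples differently.

```python
from collections.abc import Iterator


def iter_batches(num_triples: int, batch_size: int = 1, drop_last: bool = True) -> Iterator[list[int]]:
    """Yield batches of positive-triple indices (hypothetical helper, not PyKEEN API)."""
    if batch_size < 1:
        raise ValueError("batch_size must be positive")
    ids = list(range(num_triples))
    for start in range(0, num_triples, batch_size):
        batch = ids[start : start + batch_size]
        # only the final batch can be shorter than batch_size
        if drop_last and len(batch) < batch_size:
            return
        yield batch


print(list(iter_batches(7, batch_size=3)))                   # [[0, 1, 2], [3, 4, 5]]
print(list(iter_batches(7, batch_size=3, drop_last=False)))  # [[0, 1, 2], [3, 4, 5], [6]]
```

With drop_last=True every yielded batch has exactly batch_size triples, at the cost of skipping up to batch_size - 1 triples per pass.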

Note

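Two resolvers are used in this function: one for the pair (negative_sampler, negative_sampler_kwargs) and one for the pair (loss_weighter, loss_weighter_kwargs).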

An explanation of resolvers and how to use them is given in https://class-resolver.readthedocs.io/en/latest/.

Methods Summary

  • from_triples_factory(tf, **kwargs) – Create sLCWA instances for a triples factory.

  • iter_triple_ids() – Iterate over batches of IDs of positive triples.

Methods Documentation

classmethod from_triples_factory(tf: CoreTriplesFactory, **kwargs) → Self

Create sLCWA instances for a triples factory.

Parameters:

tf (CoreTriplesFactory) – the triples factory

Return type:

Self
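A common shape for such an alternative constructor is sketched below with hypothetical stand-in classes (ToyTriplesFactory, ToyInstances). It assumes that the mapped triples and the entity and relation counts are taken from the factory and that any remaining keyword arguments (e.g., batch_size) are forwarded to the constructor; this is an assumption for illustration, not PyKEEN's verified source.

```python
from dataclasses import dataclass


@dataclass
class ToyTriplesFactory:
    """Hypothetical stand-in for CoreTriplesFactory."""

    mapped_triples: list[tuple[int, int, int]]
    num_entities: int
    num_relations: int


class ToyInstances:
    """Hypothetical stand-in for an sLCWA instances class."""

    def __init__(self, mapped_triples, batch_size=1, drop_last=True, num_entities=None, num_relations=None):
        self.mapped_triples = mapped_triples
        self.batch_size = batch_size
        self.drop_last = drop_last
        self.num_entities = num_entities
        self.num_relations = num_relations

    @classmethod
    def from_triples_factory(cls, tf: ToyTriplesFactory, **kwargs) -> "ToyInstances":
        # assumed: data and counts come from the factory, everything else from kwargs
        return cls(
            mapped_triples=tf.mapped_triples,
            num_entities=tf.num_entities,
            num_relations=tf.num_relations,
            **kwargs,
        )


tf = ToyTriplesFactory([(0, 0, 1), (1, 0, 2)], num_entities=3, num_relations=1)
instances = ToyInstances.from_triples_factory(tf, batch_size=2, drop_last=False)
```

Calling cls(...) rather than a hard-coded class name means a subclass that calls from_triples_factory receives an instance of itself, which is what the Self return type expresses.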

abstractmethod iter_triple_ids() → Iterable[list[int]]

Iterate over batches of IDs of positive triples.

Return type:

Iterable[list[int]]