"""Basic structure for a negative sampler."""
from abc import abstractmethod
from collections.abc import Mapping
from typing import Any, ClassVar
from class_resolver import HintOrType, normalize_string
from torch import nn
from .filtering import Filterer, filterer_resolver
from ..typing import BoolTensor, LongTensor, MappedTriples
# Names exported by ``from <module> import *``; the base class is the only public API here.
__all__ = ["NegativeSampler"]
class NegativeSampler(nn.Module):
    """A negative sampler.

    Subclasses provide the corruption strategy by implementing :meth:`corrupt_batch`. :meth:`sample` calls it and,
    if a filterer is configured, additionally computes a filter mask for the generated negatives.
    """

    #: The default strategy for optimizing the negative sampler's hyper-parameters
    hpo_default: ClassVar[Mapping[str, Mapping[str, Any]]] = {
        "num_negs_per_pos": {"type": int, "low": 1, "high": 100, "log": True},
    }

    #: A filterer for negative batches, or None if filtering is disabled
    filterer: Filterer | None

    #: The number of entities (given explicitly, or inferred as the largest ID in columns 0 and 2 plus one)
    num_entities: int
    #: The number of relations (given explicitly, or inferred as the largest ID in column 1 plus one)
    num_relations: int
    #: The number of negative samples generated per positive triple
    num_negs_per_pos: int

    def __init__(
        self,
        *,
        mapped_triples: MappedTriples,
        num_entities: int | None = None,
        num_relations: int | None = None,
        num_negs_per_pos: int | None = None,
        filtered: bool = False,
        filterer: HintOrType[Filterer] = None,
        filterer_kwargs: Mapping[str, Any] | None = None,
    ) -> None:
        """Initialize the negative sampler with the given entities.

        :param mapped_triples: the positive training triples
        :param num_entities: the number of entities. If None, will be inferred from the triples as the largest ID in
            columns 0 and 2, plus one.
        :param num_relations: the number of relations. If None, will be inferred from the triples as the largest ID
            in column 1, plus one.
        :param num_negs_per_pos: number of negative samples to make per positive triple. Defaults to 1.
        :param filtered: Whether proposed corrupted triples that are in the training data should be filtered. Defaults
            to False. See explanation in :func:`filter_negative_triples` for why this is a reasonable default.
        :param filterer: The filter module from :mod:`pykeen.sampling.filtering` to use. Note that passing a
            non-None value enables filtering even if ``filtered`` is False. If None while ``filtered`` is True, the
            filterer resolver is queried with None (i.e., its default is used).
        :param filterer_kwargs: Additional keyword-based arguments passed to the filterer upon construction.
        """
        super().__init__()
        # Compare against None explicitly instead of relying on truthiness: inference should only happen when the
        # value was actually omitted (as documented), consistent with the handling of num_negs_per_pos below.
        if num_entities is None:
            # entity IDs occur in both column 0 and column 2 of the triples
            num_entities = mapped_triples[:, [0, 2]].max().item() + 1
        if num_relations is None:
            num_relations = mapped_triples[:, 1].max().item() + 1
        self.num_entities = num_entities
        self.num_relations = num_relations
        self.num_negs_per_pos = 1 if num_negs_per_pos is None else num_negs_per_pos
        # A filterer is built if one was requested explicitly via `filterer`, or if filtering was switched on.
        if filterer is not None or filtered:
            self.filterer = filterer_resolver.make(
                filterer,
                pos_kwargs=filterer_kwargs,
                mapped_triples=mapped_triples,
            )
        else:
            self.filterer = None

    @classmethod
    def get_normalized_name(cls) -> str:
        """Get the normalized name of the negative sampler."""
        # strips the common "NegativeSampler" suffix from the (sub)class name before normalizing
        return normalize_string(cls.__name__, suffix=NegativeSampler.__name__)

    def sample(self, positive_batch: LongTensor) -> tuple[LongTensor, BoolTensor | None]:
        """Generate negative samples from the positive batch.

        :param positive_batch: shape: (batch_size, 3) The positive triples.
        :returns: A pair `(negative_batch, filter_mask)` where

            1. `negative_batch`: shape: (batch_size, num_negatives, 3) The negative batch. ``negative_batch[i, :, :]``
               contains the negative examples generated from ``positive_batch[i, :]``.
            2. filter_mask: shape: (batch_size, num_negatives) An optional filter mask. True where negative samples are
               valid. None if no filterer is configured.
        """
        # create unfiltered negative batch by corruption
        negative_batch = self.corrupt_batch(positive_batch=positive_batch)
        if self.filterer is None:
            return negative_batch, None
        # The negative batch itself is returned unchanged; the filterer only produces a mask, and applying it (e.g.
        # dropping negatives that are positives in the training data) is left to the caller.
        return negative_batch, self.filterer(negative_batch=negative_batch)

    # NOTE: nn.Module does not use ABCMeta, so @abstractmethod is not enforced at instantiation time; the
    # NotImplementedError below is what actually signals a missing implementation in a subclass.
    @abstractmethod
    def corrupt_batch(self, positive_batch: LongTensor) -> LongTensor:
        """Generate negative samples from the positive batch without application of any filter.

        :param positive_batch: shape: `(*batch_dims, 3)` The positive triples.
        :returns: shape: `(*batch_dims, num_negs_per_pos, 3)` The negative triples. ``result[*bi, :, :]`` contains the
            negative examples generated from ``positive_batch[*bi, :]``.
        """
        raise NotImplementedError