Source code for pykeen.sampling.negative_sampler

# -*- coding: utf-8 -*-

"""Basic structure for a negative sampler."""

from abc import ABC, abstractmethod
from typing import Any, ClassVar, Mapping, Optional, Tuple

import torch
from class_resolver import HintOrType, normalize_string

from .filtering import Filterer, filterer_resolver
from ..triples import CoreTriplesFactory

__all__ = [
    "NegativeSampler",
]


class NegativeSampler(ABC):
    """A negative sampler."""

    #: The default strategy for optimizing the negative sampler's hyper-parameters
    hpo_default: ClassVar[Mapping[str, Mapping[str, Any]]] = dict(
        num_negs_per_pos=dict(type=int, low=1, high=100, log=True),
    )

    #: A filterer for negative batches
    filterer: Optional[Filterer]

    num_entities: int
    num_relations: int
    num_negs_per_pos: int

    def __init__(
        self,
        triples_factory: CoreTriplesFactory,
        num_negs_per_pos: Optional[int] = None,
        filtered: bool = False,
        filterer: HintOrType[Filterer] = None,
        filterer_kwargs: Optional[Mapping[str, Any]] = None,
    ) -> None:
        """Initialize the negative sampler with the given entities.

        :param triples_factory: The factory holding the positive training triples
        :param num_negs_per_pos: Number of negative samples to make per positive triple. Defaults to 1.
        :param filtered: Whether proposed corrupted triples that are in the training data should be filtered.
            Defaults to False. See explanation in :func:`filter_negative_triples` for why this is a
            reasonable default.
        :param filterer: If filtered is set to True, this can be used to choose which filter module from
            :mod:`pykeen.sampling.filtering` is used.
        :param filterer_kwargs:
            Additional keyword-based arguments passed to the filterer upon construction.
        """
        self.num_entities = triples_factory.num_entities
        self.num_relations = triples_factory.num_relations
        self.num_negs_per_pos = num_negs_per_pos if num_negs_per_pos is not None else 1
        self.filterer = (
            filterer_resolver.make(
                filterer,
                pos_kwargs=filterer_kwargs,
                mapped_triples=triples_factory.mapped_triples,
            )
            if filterer is not None or filtered
            else None
        )

    @classmethod
    def get_normalized_name(cls) -> str:
        """Get the normalized name of the negative sampler."""
        return normalize_string(cls.__name__, suffix=NegativeSampler.__name__)

    def sample(self, positive_batch: torch.LongTensor) -> Tuple[torch.LongTensor, Optional[torch.BoolTensor]]:
        """
        Generate negative samples from the positive batch.

        :param positive_batch: shape: (batch_size, 3)
            The positive triples.

        :return:
            A pair (negative_batch, filter_mask) where

            1. negative_batch: shape: (batch_size, num_negatives, 3)
               The negative batch. ``negative_batch[i, :, :]`` contains the negative examples generated from
               ``positive_batch[i, :]``.
            2. filter_mask: shape: (batch_size, num_negatives)
               An optional filter mask. True where negative samples are valid.
        """
        # create unfiltered negative batch by corruption
        negative_batch = self.corrupt_batch(positive_batch=positive_batch)

        if self.filterer is None:
            return negative_batch, None

        # If filtering is activated, all negative triples that are positive in the training dataset will be removed
        return negative_batch, self.filterer(negative_batch=negative_batch)

    @abstractmethod
    def corrupt_batch(self, positive_batch: torch.LongTensor) -> torch.LongTensor:
        """
        Generate negative samples from the positive batch without application of any filter.

        :param positive_batch: shape: (batch_size, 3)
            The positive triples.

        :return: shape: (batch_size, num_negs_per_pos, 3)
            The negative triples. ``result[i, :, :]`` contains the negative examples generated from
            ``positive_batch[i, :]``.
        """
        raise NotImplementedError
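Since corrupt_batch is the only abstract method, a concrete sampler just has to return a (batch_size, num_negs_per_pos, 3) tensor of corrupted triples. The following is an illustrative sketch, not part of the module: the class name UniformTailCorruptionSampler is hypothetical, and it corrupts only the tail entity by drawing replacement IDs uniformly at random.

class UniformTailCorruptionSampler(NegativeSampler):  # hypothetical example, not in pykeen
    """Corrupt the tail entity of each positive triple uniformly at random."""

    def corrupt_batch(self, positive_batch: torch.LongTensor) -> torch.LongTensor:
        batch_size = positive_batch.shape[0]
        # repeat each positive triple num_negs_per_pos times -> (batch_size, num_negs_per_pos, 3);
        # repeat() copies, so the in-place write below does not modify the positive batch
        negative_batch = positive_batch.unsqueeze(dim=1).repeat(1, self.num_negs_per_pos, 1)
        # overwrite the tail column (index 2) with uniformly drawn entity IDs
        negative_batch[:, :, 2] = torch.randint(
            high=self.num_entities,
            size=(batch_size, self.num_negs_per_pos),
        )
        return negative_batch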
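For orientation, here is a usage sketch of the public API with pykeen's own BasicNegativeSampler and the small Nations dataset; the batch slice, the number of negatives, and the filtered=True choice are illustrative assumptions rather than prescribed usage.

from pykeen.datasets import Nations
from pykeen.sampling import BasicNegativeSampler

training = Nations().training  # a triples factory exposing mapped_triples

sampler = BasicNegativeSampler(
    triples_factory=training,
    num_negs_per_pos=4,
    filtered=True,  # drop corrupted triples that also appear in the training data
)

# the normalized name strips the "NegativeSampler" suffix and lower-cases the rest
assert sampler.get_normalized_name() == "basic"

positive_batch = training.mapped_triples[:32]  # shape: (32, 3)
negative_batch, filter_mask = sampler.sample(positive_batch=positive_batch)
# negative_batch has shape (32, 4, 3); filter_mask has shape (32, 4) because filtering is enabled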