Source code for pykeen.sampling.basic_negative_sampler

# -*- coding: utf-8 -*-

"""Negative sampling algorithm based on the work of of Bordes *et al.*."""

import torch

from .negative_sampler import NegativeSampler

__all__ = [
    'BasicNegativeSampler',
]


[docs]class BasicNegativeSampler(NegativeSampler): """A basic negative sampler.""" #: The default strategy for optimizing the negative sampler's hyper-parameters hpo_default = dict( num_negs_per_pos=dict(type=int, low=1, high=100, q=10), )
[docs] def sample(self, positive_batch: torch.LongTensor) -> torch.LongTensor: """Generate negative samples from the positive batch.""" if self.num_negs_per_pos > 1: positive_batch = positive_batch.repeat(self.num_negs_per_pos, 1) # Bind number of negatives to sample num_negs = positive_batch.shape[0] # Equally corrupt head and tail split_idx = num_negs // 2 # Copy positive batch for corruption. # Do not detach, as no gradients should flow into the indices. negative_batch = positive_batch.clone() # Sample random entities as replacement negative_entities = torch.randint(high=self.num_entities - 1, size=(num_negs,), device=positive_batch.device) # Replace heads – To make sure we don't replace the head by the original value # we shift all values greater or equal than the original value by one up # for that reason we choose the random value from [0, num_entities -1] filter_same_head = (negative_entities[:split_idx] >= positive_batch[:split_idx, 0]) negative_batch[:split_idx, 0] = negative_entities[:split_idx] + filter_same_head.long() # Corrupt tails filter_same_tail = (negative_entities[split_idx:] >= positive_batch[split_idx:, 2]) negative_batch[split_idx:, 2] = negative_entities[split_idx:] + filter_same_tail.long() return negative_batch