Source code for pykeen.sampling.bernoulli_negative_sampler

# -*- coding: utf-8 -*-

"""Negative sampling algorithm based on the work of [wang2014]_."""

from typing import Optional

import torch

from .negative_sampler import NegativeSampler
from ..triples import TriplesFactory

# Names exported by ``from <this module> import *``.
__all__ = [
    'BernoulliNegativeSampler',
]


class BernoulliNegativeSampler(NegativeSampler):
    """An implementation of the bernoulli negative sampling approach proposed by [wang2014]_.

    For every relation ``r`` two statistics are computed from the mapped triples:

    - ``tph``: the average number of tail entities per head entity,
    - ``hpt``: the average number of head entities per tail entity.

    When a triple with relation ``r`` is corrupted, its head is replaced with probability
    ``tph / (tph + hpt)``; otherwise its tail is replaced.
    """

    #: The default strategy for optimizing the negative sampler's hyper-parameters
    hpo_default = dict(
        num_negs_per_pos=dict(type=int, low=1, high=100, q=10),
    )

    def __init__(
        self,
        triples_factory: TriplesFactory,
        num_negs_per_pos: Optional[int] = None,
    ) -> None:
        """Initialize the sampler and pre-compute the per-relation head corruption probabilities.

        :param triples_factory: The factory providing ``mapped_triples``, ``num_relations`` and
            ``num_entities``.
        :param num_negs_per_pos: Number of negative samples to generate per positive triple; passed
            through unchanged to the parent class.
        """
        super().__init__(
            triples_factory=triples_factory,
            num_negs_per_pos=num_negs_per_pos,
        )
        # Preprocessing: Compute corruption probabilities
        triples = self.triples_factory.mapped_triples
        num_relations = self.triples_factory.num_relations

        # Distinct (head, relation) and (relation, tail) pairs occurring in the triples.
        head_rel_uniq = torch.unique(triples[:, :2], dim=0)
        rel_tail_uniq = torch.unique(triples[:, 1:], dim=0)

        # Per-relation counts, computed in one pass each instead of one scan per relation.
        triples_per_relation = torch.bincount(triples[:, 1], minlength=num_relations).float()
        heads_per_relation = torch.bincount(head_rel_uniq[:, 1], minlength=num_relations).float()
        tails_per_relation = torch.bincount(rel_tail_uniq[:, 0], minlength=num_relations).float()

        # The per-pair triple counts of a relation sum up to the relation's number of triples, so
        # the mean count per pair is simply (#triples) / (#distinct pairs).
        # tph, i.e. the average number of tail entities per head
        tph = triples_per_relation / heads_per_relation
        # hpt, i.e. the average number of head entities per tail
        hpt = triples_per_relation / tails_per_relation

        # Set parameter for Bernoulli distribution
        corrupt_head_probability = tph / (tph + hpt)
        # Relations without any triple yield 0 / 0 = NaN above. Fall back to corrupting head and
        # tail with equal probability instead of storing NaN (which would compare as "never
        # corrupt the head").
        corrupt_head_probability[torch.isnan(corrupt_head_probability)] = 0.5
        self.corrupt_head_probability = corrupt_head_probability

    def sample(self, positive_batch: torch.LongTensor) -> torch.LongTensor:
        """Sample a negative batch based on the bern approach.

        :param positive_batch: A two-dimensional tensor of positive triples whose columns are
            (head, relation, tail) indices.
        :return: A tensor of corrupted triples on the same device as ``positive_batch``. If
            ``num_negs_per_pos > 1`` the positive batch is repeated that many times along the first
            dimension before corruption. Each row differs from its positive triple in exactly one
            of head or tail.
        """
        if self.num_negs_per_pos > 1:
            positive_batch = positive_batch.repeat(self.num_negs_per_pos, 1)

        # Bind number of negatives to sample
        num_negs = positive_batch.shape[0]
        device = positive_batch.device

        # Copy positive batch for corruption.
        # Do not detach, as no gradients should flow into the indices.
        negative_batch = positive_batch.clone()

        # Decide whether to corrupt head or tail. The probability table lives on the device of the
        # mapped triples, which may differ from the batch's device: look the values up on the
        # table's device first and only then move the result to the batch's device.
        probabilities = self.corrupt_head_probability
        relation_ids = positive_batch[:, 1].to(device=probabilities.device)
        head_corruption_probability = probabilities[relation_ids].to(device=device)
        head_mask = torch.rand(num_negs, device=device) < head_corruption_probability

        # Tails are corrupted if heads are not corrupted
        tail_mask = ~head_mask

        # Randomly sample corruption. Only ``num_entities - 1`` candidates are drawn, i.e. values
        # from [0, num_entities - 2], because the original entity is excluded below.
        negative_entities = torch.randint(
            self.triples_factory.num_entities - 1,
            size=(num_negs,),
            device=device,
        )

        # Replace heads - To make sure we don't replace the head by the original value
        # we shift all values greater or equal than the original value by one up.
        head_candidates = negative_entities[head_mask]
        filter_same_head = head_candidates >= positive_batch[head_mask, 0]
        negative_batch[head_mask, 0] = head_candidates + filter_same_head.long()

        # Replace tails, using the same shifting trick.
        tail_candidates = negative_entities[tail_mask]
        filter_same_tail = tail_candidates >= positive_batch[tail_mask, 2]
        negative_batch[tail_mask, 2] = tail_candidates + filter_same_tail.long()

        return negative_batch