# Source code for pykeen.sampling.bernoulli_negative_sampler

# -*- coding: utf-8 -*-

"""Negative sampling algorithm based on the work of [wang2014]_."""

import torch

from .negative_sampler import NegativeSampler
from ..triples import CoreTriplesFactory

__all__ = [
"BernoulliNegativeSampler",
]

class BernoulliNegativeSampler(NegativeSampler):
r"""An implementation of the Bernoulli negative sampling approach proposed by [wang2014]_.

The probability of corrupting the head $h$ or tail $t$ in a relation $(h,r,t) \in \mathcal{K}$
is determined by global properties of the relation $r$:

- $r$ is *one-to-many* (e.g. *motherOf*): a higher probability is assigned to replace $h$
- $r$ is *many-to-one* (e.g. *bornIn*): a higher probability is assigned to replace $t$.

More precisely, for each relation $r \in \mathcal{R}$, the average number of tails per head
(tph) and heads per tail (hpt) are first computed.

Then, the head corruption probability $p_r$ is defined as $p_r = \frac{tph}{tph + hpt}$.
The tail corruption probability is defined as $1 - p_r = \frac{hpt}{tph + hpt}$.

For each triple $(h,r,t) \in \mathcal{K}$, the head is corrupted with probability $p_r$ and the tail is
corrupted with probability $1 - p_r$.

If filtered is set to True, all proposed corrupted triples that also exist as
actual positive triples $(h,r,t) \in \mathcal{K}$ will be removed.
"""

def __init__(
    self,
    *,
    triples_factory: CoreTriplesFactory,
    **kwargs,
) -> None:
    """Initialize the Bernoulli negative sampler and precompute corruption probabilities.

    For every relation, the average number of tails per head (tph) and heads per tail (hpt)
    are computed from the mapped training triples, and the probability of corrupting the
    head is stored as ``tph / (tph + hpt)`` in ``self.corrupt_head_probability``.

    :param triples_factory:
        The factory holding the positive training triples
    :param kwargs:
        Additional keyword based arguments passed to :class:`pykeen.sampling.NegativeSampler`.
    """
    super().__init__(triples_factory=triples_factory, **kwargs)

    # Preprocessing: Compute corruption probabilities
    triples = triples_factory.mapped_triples
    num_relations = triples_factory.num_relations

    # Each unique (head, relation) pair is counted once per distinct triple, i.e. per tail.
    head_rel_uniq, tail_count = torch.unique(triples[:, :2], return_counts=True, dim=0)
    # Each unique (relation, tail) pair is counted once per distinct triple, i.e. per head.
    rel_tail_uniq, head_count = torch.unique(triples[:, 1:], return_counts=True, dim=0)

    # One head-corruption probability per relation, kept on the same device as the triples.
    self.corrupt_head_probability = torch.empty(
        num_relations,
        device=triples.device,
    )

    # NOTE: a relation ID that does not occur in the triples selects no rows below, so the
    # mean of the empty selection (and therefore its probability) is NaN.
    for r in range(num_relations):
        # compute tph, i.e. the average number of tail entities per head
        # (the relation is the second column of the (head, relation) pairs)
        mask = head_rel_uniq[:, 1] == r
        tph = tail_count[mask].float().mean()

        # compute hpt, i.e. the average number of head entities per tail
        # (the relation is the first column of the (relation, tail) pairs)
        mask = rel_tail_uniq[:, 0] == r
        hpt = head_count[mask].float().mean()

        # Set parameter for Bernoulli distribution
        self.corrupt_head_probability[r] = tph / (tph + hpt)

[docs]    def corrupt_batch(self, positive_batch: torch.LongTensor) -> torch.LongTensor:  # noqa: D102
if self.num_negs_per_pos > 1:
positive_batch = positive_batch.repeat_interleave(repeats=self.num_negs_per_pos, dim=0)

# Bind number of negatives to sample
num_negs = positive_batch.shape

# Copy positive batch for corruption.
# Do not detach, as no gradients should flow into the indices.
negative_batch = positive_batch.clone()

device = positive_batch.device
# Decide whether to corrupt head or tail

# Tails are corrupted if heads are not corrupted

# We at least make sure to not replace the triples by the original value
# See below for explanation of why this is on a range of [0, num_entities - 1]
index_max = self.num_entities - 1

# Randomly sample corruption.
negative_entities = torch.randint(
index_max,
size=(num_negs,),
device=positive_batch.device,
)