# Source code for pykeen.sampling.filtering

# -*- coding: utf-8 -*-

r"""Consider the following properties of relation $r$. Because the corruption operations (see Corruption_)
are applied independently of triples, the resulting candidate corrupt triples could overlap with known positive
triples in $\mathcal{K}$.

=====================  ============================================  ==============================================================
Property of :math:r  Example pair of triples                       Implications
=====================  ============================================  ==============================================================
one-to-many            :math:(h,r,t_1), (h,r,t_2) \in \mathcal{K}  :math:(h,r,t_2) \in T(h,r,t_1) \cup (h,r,t_1) \in T(h,r,t_2)
multiple               :math:(h,r_1,t), (h,r_2,t) \in \mathcal{K}  :math:(h,r_2,t) \in R(h,r_1,t) \cup (h,r_1,t) \in R(h,r_2,t)
many-to-one            :math:(h_1,r,t), (h_2,r,t) \in \mathcal{K}  :math:(h_2,r,t) \in H(h_1,r,t) \cup (h_1,r,t) \in H(h_2,r,t)
=====================  ============================================  ==============================================================

If no relations in $\mathcal{K}$ satisfy any of the relevant properties for the corruption schema chosen in negative
sampling, then there is guaranteed to be no overlap between $\mathcal{N}$ and $\mathcal{K}$ such that
$\mathcal{N} \cap \mathcal{K} \neq \emptyset$. However, this scenario is very unlikely for real-world knowledge graphs.

The known positive triples that appear in $\mathcal{N}$ are known false negatives. Hence, we know that these are
incorrect (negative) training examples, and might want to exclude them to reduce the training noise.

.. warning::

It should be taken into account that also a corrupted triple that is *not part*
of the knowledge graph can represent a true fact. These "unknown" false negatives can
not be removed *a priori* in the filtered setting. The philosophy of the methodology again relies
on the low number of unknown false negatives such that learning can take place.

However, in practice, $|\mathcal{N}| \gg |\mathcal{K}|$, so the likelihood of generating a false negative is rather low.
Therefore, the additional filter step is often omitted to lower computational cost. This general observation might not
hold for all entities; e.g., for a hub entity which is connected to many other entities, there may be a considerable
number of false negatives without filtering.

Identifying False Negatives During Training
-------------------------------------------
By default, PyKEEN does *not* filter false negatives from $\mathcal{N}$ during training. To enable filtering of
negative examples during training, the filtered keyword can be given to negative_sampler_kwargs like in:

.. code-block:: python

results = pipeline(
dataset='YAGO3-10',
model='PairRE',
training_loop='sLCWA',
negative_sampler='basic',
negative_sampler_kwargs=dict(
filtered=True,
),
)

PyKEEN implements several algorithms for filtering with different properties that can be chosen using the
filterer keyword argument in negative_sampler_kwargs. By default, an fast and approximate algorithm is used in
:class:pykeen.sampling.filtering.BloomFilterer, which is based on
bloom filters <https://en.wikipedia.org/wiki/Bloom_filter>_. The bloom filterer also has a configurable desired error
rate, which can be further lowered at the cost of increase in memory and computation costs.

.. code-block:: python

from pykeen.pipeline import pipeline

results = pipeline(
dataset='YAGO3-10',
model='PairRE',
training_loop='sLCWA',
negative_sampler='basic',
negative_sampler_kwargs=dict(
filtered=True,
filterer='bloom',
filterer_kwargs=dict(
error_rate=0.0001,
),
),
)

If you want to have a guarantee that all known false negatives are filtered, you can use a slower implementation based
on Python's built-in sets, the :class:pykeen.sampling.filtering.PythonSetFilterer. It can be activated with:

.. code-block:: python

from pykeen.pipeline import pipeline

results = pipeline(
dataset='YAGO3-10',
model='PairRE',
training_loop='sLCWA',
negative_sampler='basic',
negative_sampler_kwargs=dict(
filtered=True,
filterer='python-set',
),
)

Identifying False Negatives During Evaluation
---------------------------------------------
In contrast to training, PyKEEN **does** filter false negatives from $\mathcal{N}$ during evaluation by default.
To disable the "filtered setting" during evaluation, the filtered keyword can be given to evaluator_kwargs
like in:

.. code-block:: python

from pykeen.pipeline import pipeline

results = pipeline(
dataset='YAGO3-10',
model='PairRE',
evaluator_kwargs=dict(
filtered=False,
),
)

Filtering during evaluation is implemented differently than in negative sampling:

First, there are no choices between an exact or approximate algorithm via a
:class:pykeen.sampling.filtering.Filterer. Instead, the evaluation filtering can modify the
scores in-place and does so instead of selecting only the non-filtered entries. The reason is
mainly that evaluation always is done in 1:n scoring, and thus, we gain some efficiently here
by keeping the tensor in "dense" shape (batch_size, num_entities).

Second, filtering during evaluation has to be correct, and is crucial for reproducing results
from the filtered setting. For evaluation it makes sense to use all information we have to get
as solid evaluation results as possible.
"""  # noqa

import math
from abc import abstractmethod
from typing import Iterable

import torch
from class_resolver import Resolver
from torch import nn

from ..typing import MappedTriples

__all__ = [
"filterer_resolver",
"Filterer",
"BloomFilterer",
"PythonSetFilterer",
]

[docs]class Filterer(nn.Module):
"""An interface for filtering methods for negative triples."""

[docs]    def forward(
self,
negative_batch: MappedTriples,
) -> torch.BoolTensor:
"""Filter all proposed negative samples that are positive in the training dataset.

Normally there is a low probability that proposed negative samples are positive in the training datasets and
thus act as false negatives. This is expected to act as a kind of regularization, since it adds noise signal to
the training data. However, the degree of regularization is hard to control since the added noise signal depends
on the ratio of true triples for a given entity relation or entity entity pair. Therefore, the effects are hard
to control and a researcher might want to exclude the possibility of having false negatives in the proposed
negative triples.

.. note ::
Filtering is a very expensive task, since every proposed negative sample has to be checked against the
entire training dataset.

:param negative_batch: shape: (batch_size, num_negatives, 3)
The batch of negative triples.

:return: shape: (batch_size, num_negatives)
A mask, where True indicates that the negative sample is valid.
"""
return ~self.contains(batch=negative_batch)

[docs]    @abstractmethod
def contains(self, batch: MappedTriples) -> torch.BoolTensor:
"""
Check whether a triple is contained.

Supports batching.

:param batch: shape (batch_size, 3)
The batch of triples.

:return: shape: (batch_size,)
Whether the triples are contained in the training triples.
"""
raise NotImplementedError

[docs]class PythonSetFilterer(Filterer):
"""A filterer using Python sets for filtering.

This filterer is expected to be rather slow due to the conversion from torch long tensors to Python tuples. It can
still serve as a baseline for performance comparison.
"""

def __init__(self, mapped_triples: MappedTriples):
"""Initialize the filterer.

:param mapped_triples:
The ID-based triples.
"""
super().__init__()
# store set of triples
self.triples = set(map(tuple, mapped_triples.tolist()))

[docs]    def contains(self, batch: MappedTriples) -> torch.BoolTensor:  # noqa: D102
data=[
tuple(triple) in self.triples
for triple in batch.view(-1, 3).tolist()
],
dtype=torch.bool,
device=batch.device,
).view(*batch.shape[:-1])

[docs]class BloomFilterer(Filterer):
"""
A filterer for negative triples based on the Bloom filter.

Pure PyTorch, a proper module which can be moved to GPU, and support batch-wise computation.

.. seealso ::
* https://github.com/hiway/python-bloom-filter/ - for calculation of sizes, and rough structure of code
* https://github.com/skeeto/hash-prospector#two-round-functions - for parts of the hash function
"""

#: some prime numbers for tuple hashing
mersenne: torch.LongTensor

#: The bit-array for the Bloom filter data structure
bit_array: torch.BoolTensor

def __init__(self, mapped_triples: MappedTriples, error_rate: float = 0.001):
"""
Initialize the Bloom filter based filterer.

:param mapped_triples:
The ID-based triples.
:param error_rate:
The desired error rate.
"""
super().__init__()

# Allocate bit array
self.ideal_num_elements = mapped_triples.shape
size = self.num_bits(num=self.ideal_num_elements, error_rate=error_rate)
self.register_buffer(name="bit_array", tensor=torch.zeros(size, dtype=torch.bool))
self.register_buffer(
name="mersenne",
tensor=torch.as_tensor(
data=[2 ** x - 1 for x in [17, 19, 31]],
dtype=torch.long,
).unsqueeze(dim=0),
)

# calculate number of hashing rounds
self.rounds = self.num_probes(num_elements=self.ideal_num_elements, num_bits=size)

# index triples

# Store some meta-data
self.error_rate = error_rate

def __repr__(self):  # noqa:D105
return (
f"{self.__class__.__name__}("
f"error_rate={self.error_rate}, "
f"size={self.bit_array.shape}, "
f"rounds={self.rounds}, "
f"ideal_num_elements={self.ideal_num_elements}, "
f")"
)

[docs]    @staticmethod
def num_bits(num: int, error_rate: float = 0.01) -> int:
"""
Determine the required number of bits.

:param num:
The number of elements the Bloom filter shall store.
:param error_rate:
The desired error rate.

:return:
The required number of bits.
"""
numerator = -1 * num * math.log(error_rate)
denominator = math.log(2) ** 2
real_num_bits_m = numerator / denominator
return int(math.ceil(real_num_bits_m))

[docs]    @staticmethod
def num_probes(num_elements: int, num_bits: int):
"""
Determine the number of probes / hashing rounds.

:param num_elements:
The number of elements.
:param num_bits:
The number of bits, i.e., the size of the Bloom filter.

:return:
The number of hashing rounds.
"""
num_bits = num_bits
real_num_probes_k = (num_bits / num_elements) * math.log(2)
return int(math.ceil(real_num_probes_k))

[docs]    def probe(
self,
batch: MappedTriples,
) -> Iterable[torch.LongTensor]:
"""
Iterate over indices from the probes.

:param batch: shape: (batch_size, 3)
A batch of elements.

:yields: Indices of the k-th round, shape: (batch_size,).
"""
# pre-hash
x = (self.mersenne * batch).sum(dim=-1)
for _ in range(self.rounds):
# cf. https://github.com/skeeto/hash-prospector#two-round-functions
x = x ^ (x >> 16)
x = x * 0x7feb352d
x = x ^ (x >> 15)
x = x * 0x846ca68b
x = x ^ (x >> 16)
yield x % self.bit_array.shape

[docs]    def add(self, triples: MappedTriples) -> None:
"""Add triples to the Bloom filter."""
for i in self.probe(batch=triples):
self.bit_array[i] = True

[docs]    def contains(self, batch: MappedTriples) -> torch.BoolTensor:
"""
Check whether a triple is contained.

:param batch: shape (batch_size, 3)
The batch of triples.

:return: shape: (batch_size,)
The result. False guarantees that the element was not contained in the indexed triples. True can be
erroneous.
"""
result = batch.new_ones(batch.shape[:-1], dtype=torch.bool)
for i in self.probe(batch):
result &= self.bit_array[i]
return result

filterer_resolver = Resolver.from_subclasses(
base=Filterer,
default=BloomFilterer,
)