LossWeighter

class LossWeighter

Bases: ABC

Determine loss weights for triples.
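This page does not specify how the weights enter the loss. One common choice is a normalized weighted mean of per-triple losses, shown below as an illustration only. The helper `weighted_mean` is hypothetical and not part of this API. It assumes only that the weights broadcast against the losses.

```python
import torch


def weighted_mean(per_triple_loss: torch.Tensor, weights: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: reduce per-triple losses using sample weights."""
    # broadcast the weights to the loss shape, e.g. (batch, 1) -> (batch, num_negatives)
    weights = weights.expand_as(per_triple_loss)
    # weighted sum normalized by the total weight; clamp guards against all-zero weights
    return (weights * per_triple_loss).sum() / weights.sum().clamp_min(1e-12)


loss = torch.tensor([1.0, 2.0, 3.0])
weights = torch.tensor([1.0, 1.0, 2.0])
print(weighted_mean(loss, weights))  # tensor(2.2500)
```

Dividing by the total weight rather than by the number of samples keeps the loss scale independent of the overall magnitude of the weights. Either normalization is a design choice.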

Methods Summary

__call__(h, r, t)

Calculate the sample weights for the given triples.

weight_triples(mapped_triples)

Calculate the sample weights for the given batch of triples.

Methods Documentation

abstractmethod __call__(h: Tensor | None, r: Tensor | None, t: Tensor | None) → Tensor

Calculate the sample weights for the given triples.

Supports broadcasting semantics.

Parameters:
  • h (Tensor | None) – The head indices, or None to denote all of them.

  • r (Tensor | None) – The relation indices, or None to denote all of them.

  • t (Tensor | None) – The tail indices, or None to denote all of them.

Returns:

The sample weights.

Return type:

Tensor
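The following minimal sketch makes the None and broadcasting semantics concrete. The hypothetical `EntityProductWeighter` computes each weight as the product of a per-entity weight for the head and one for the tail, and it ignores relations. It does not subclass PyKEEN's class, so it runs without PyKEEN installed. The code labels two assumptions: None expands to the full entity table along the last dimension, and the case where h and t are both None is not handled.

```python
from __future__ import annotations

import torch
from torch import Tensor


class EntityProductWeighter:
    """Hypothetical weighter: weight(h, r, t) = w[h] * w[t]; relations are ignored."""

    def __init__(self, entity_weights: Tensor) -> None:
        # one weight per entity id, shape (num_entities,)
        self.entity_weights = entity_weights

    def _lookup(self, idx: Tensor | None) -> Tensor:
        # assumption: None means "all entities", laid out along the last dimension
        if idx is None:
            return self.entity_weights
        # indexing keeps the shape of idx, e.g. (batch, 1) -> (batch, 1)
        return self.entity_weights[idx]

    def __call__(self, h: Tensor | None, r: Tensor | None, t: Tensor | None) -> Tensor:
        if h is None and t is None:
            # assumption: this sketch does not build an all-pairs (num_entities, num_entities) result
            raise ValueError("h and t must not both be None in this sketch")
        # standard torch broadcasting combines the head and tail factors
        return self._lookup(h) * self._lookup(t)


weighter = EntityProductWeighter(torch.tensor([1.0, 2.0, 0.5, 4.0]))

# heads of shape (2, 1) against all tails: result has shape (2, 4)
all_tails = weighter(h=torch.tensor([[0], [1]]), r=None, t=None)

# element-wise pairs of shape (2,): result has shape (2,)
pairs = weighter(h=torch.tensor([0, 3]), r=None, t=torch.tensor([1, 2]))
```

Broadcastable index shapes such as h of shape (b, 1) and t of shape (1, n) likewise produce a (b, n) weight matrix without any explicit expansion.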

weight_triples(mapped_triples: Tensor) → Tensor

Calculate the sample weights for the given batch of triples.

Parameters:

mapped_triples (Tensor) – The ID-based triples, of shape (…, 3).

Returns:

The sample weights.

Return type:

Tensor
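This page does not state how weight_triples relates to __call__. The sketch below assumes it splits the trailing dimension of size 3 into head, relation, and tail indices and delegates to __call__. Under that assumption, the weights have the leading shape (…). `LossWeighterSketch` and `RelationWeighter` are hypothetical local stand-ins, not PyKEEN classes.

```python
from __future__ import annotations

from abc import ABC, abstractmethod

import torch
from torch import Tensor


class LossWeighterSketch(ABC):
    """Hypothetical stand-in mirroring the documented interface (not imported from PyKEEN)."""

    @abstractmethod
    def __call__(self, h: Tensor | None, r: Tensor | None, t: Tensor | None) -> Tensor:
        raise NotImplementedError

    def weight_triples(self, mapped_triples: Tensor) -> Tensor:
        # assumption: split (..., 3) into three (...)-shaped index tensors and delegate
        h, r, t = mapped_triples.unbind(dim=-1)
        return self(h=h, r=r, t=t)


class RelationWeighter(LossWeighterSketch):
    """Hypothetical subclass: one weight per relation id."""

    def __init__(self, relation_weights: Tensor) -> None:
        self.relation_weights = relation_weights

    def __call__(self, h: Tensor | None, r: Tensor | None, t: Tensor | None) -> Tensor:
        # None means "all relations"; h and t do not affect the weight
        return self.relation_weights if r is None else self.relation_weights[r]


weighter = RelationWeighter(torch.tensor([1.0, 0.5, 2.0]))
triples = torch.tensor([[0, 0, 1], [1, 2, 3], [2, 1, 0]])  # shape (3, 3)
print(weighter.weight_triples(triples))  # tensor([1.0000, 2.0000, 0.5000])
```

Because the default only touches the last dimension, a batch of shape (2, 3, 3) yields weights of shape (2, 3) with no extra code in the subclass.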