DirectionAverageInteraction

class DirectionAverageInteraction(base: str | Interaction[Tensor, Tensor, Tensor] | type[Interaction[Tensor, Tensor, Tensor]], base_kwargs: Mapping[str, Any] | None = None)

Bases: Interaction[tuple[Tensor, Tensor], tuple[Tensor, Tensor], tuple[Tensor, Tensor]]

The directional average interaction module.

This can be considered a generalization of the SimplE interaction module that can be parametrized with any other interaction module, rather than only pykeen.nn.modules.DistMultInteraction.

Two separate representations are learned for each entity \(e \in \mathcal{E}\): \(\mathbf{e}_h \in \mathbb{R}^d\) for when it appears as the subject (head) of a triple, and \(\mathbf{e}_t \in \mathbb{R}^d\) for when it appears as the object (tail). Similarly, two representations are learned for each relation: \(\textbf{r}_{\rightarrow}\) for the forward triple and \(\textbf{r}_{\leftarrow}\) for the backward triple.
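The following sketch illustrates which role-specific vectors each direction of a single triple uses. The toy tables `entity_repr` and `relation_repr` and the helper `gather` are hypothetical and do not correspond to pykeen's own representation modules; it also assumes each entity stores its pair in the order (head role, tail role) and each relation in the order (forward, backward).

```python
from typing import Dict, List, Tuple

Vector = List[float]

# Hypothetical toy tables. Each entity stores (vector as head, vector as tail).
entity_repr: Dict[str, Tuple[Vector, Vector]] = {
    "alice": ([1.0, 0.0], [0.5, 0.5]),
    "bob": ([0.0, 1.0], [2.0, 1.0]),
}
# Each relation stores (forward vector, backward vector).
relation_repr: Dict[str, Tuple[Vector, Vector]] = {
    "knows": ([1.0, 2.0], [3.0, 1.0]),
}


def gather(h: str, r: str, t: str):
    """Select the vectors used by the forward and the backward term of one triple."""
    h_as_head, h_as_tail = entity_repr[h]
    t_as_head, t_as_tail = entity_repr[t]
    r_fwd, r_bwd = relation_repr[r]
    forward = (h_as_head, r_fwd, t_as_tail)   # (h_h, r_->, t_t)
    backward = (t_as_head, r_bwd, h_as_tail)  # (t_h, r_<-, h_t)
    return forward, backward


fwd, bwd = gather("alice", "knows", "bob")
print(fwd)  # ([1.0, 0.0], [1.0, 2.0], [2.0, 1.0])
print(bwd)  # ([0.0, 1.0], [3.0, 1.0], [0.5, 0.5])
```

In the backward term the tail entity takes the head role and the head entity takes the tail role, so every entity vector receives training signal from both positions.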

The score is then obtained by averaging the forward and the backward interaction function values:

\[\frac{ f(\textbf{h}_{h}, \textbf{r}_{\rightarrow}, \textbf{t}_{t}) + f(\textbf{t}_{h}, \textbf{r}_{\leftarrow}, \textbf{h}_{t}) }{2}\]

where f is the base interaction function. If pykeen.nn.modules.DistMultInteraction is used as the base, this becomes pykeen.nn.modules.SimplEInteraction.
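As a plain-Python sketch of the averaging formula (not pykeen's implementation), the base f below is any callable on three vectors. `distmult` is the trilinear product that DistMult uses, so plugging it in gives the SimplE-style score; `transe_l1` is a hypothetical TransE-style negative L1 distance included only to show that the same averaging works for a different base. The tuple orderings (head role, tail role) and (forward, backward) are assumptions of this sketch.

```python
from typing import Callable, List, Tuple

Vector = List[float]
BaseFn = Callable[[Vector, Vector, Vector], float]


def distmult(h: Vector, r: Vector, t: Vector) -> float:
    """Trilinear product sum_i h_i * r_i * t_i (the DistMult score)."""
    return sum(hi * ri * ti for hi, ri, ti in zip(h, r, t))


def transe_l1(h: Vector, r: Vector, t: Vector) -> float:
    """Negative L1 distance -||h + r - t||_1 (a TransE-style score)."""
    return -sum(abs(hi + ri - ti) for hi, ri, ti in zip(h, r, t))


def direction_average(
    f: BaseFn,
    h: Tuple[Vector, Vector],  # (h_h, h_t)
    r: Tuple[Vector, Vector],  # (r_fwd, r_bwd)
    t: Tuple[Vector, Vector],  # (t_h, t_t)
) -> float:
    """Average of f(h_h, r_fwd, t_t) and f(t_h, r_bwd, h_t)."""
    h_h, h_t = h
    r_fwd, r_bwd = r
    t_h, t_t = t
    return (f(h_h, r_fwd, t_t) + f(t_h, r_bwd, h_t)) / 2


h = ([1.0, 0.0], [0.5, 0.5])
r = ([1.0, 2.0], [3.0, 1.0])
t = ([0.0, 1.0], [2.0, 1.0])
print(direction_average(distmult, h, r, t))   # 1.25  = (2.0 + 0.5) / 2
print(direction_average(transe_l1, h, r, t))  # -2.5  = (-1.0 + -4.0) / 2
```

Only f changes between the two calls; the role-swapping and averaging logic stays the same, which is what makes the module a generalization of SimplE.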

Todo

Can we generalize the type annotations for this from FloatTensor to HeadRepresentation, etc.?

Initialize the interaction module.

Parameters:
  • base (LookupOrType[Interaction[FloatTensor, FloatTensor, FloatTensor]]) – the base interaction, given as an instance, a class, or a name to be resolved.

  • base_kwargs (OptionalKwargs) – keyword-based parameters used to instantiate the base interaction.

Note

The parameter pair (base, base_kwargs) is used for interaction_resolver.

An explanation of resolvers and how to use them is given in https://class-resolver.readthedocs.io/en/latest/.
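To illustrate what resolving a (base, base_kwargs) pair involves, here is a minimal, self-contained sketch. The classes, `REGISTRY`, and `make` are hypothetical stand-ins written for this example; they are not the class-resolver or pykeen API, and the name normalization shown (case-insensitive, optional "Interaction" suffix) is an assumption of the sketch.

```python
from typing import Any, Dict, Mapping, Optional, Type, Union


class Interaction:
    """Hypothetical stand-in for a base interaction class."""


class DistMultInteraction(Interaction):
    pass


class TransEInteraction(Interaction):
    def __init__(self, p: int = 1) -> None:
        self.p = p


# Hypothetical registry keyed by lower-cased class names without the "Interaction" suffix.
REGISTRY: Dict[str, Type[Interaction]] = {
    cls.__name__[: -len("Interaction")].lower(): cls
    for cls in (DistMultInteraction, TransEInteraction)
}


def make(
    query: Union[str, Interaction, Type[Interaction]],
    kwargs: Optional[Mapping[str, Any]] = None,
) -> Interaction:
    """Turn an instance, a class, or a name (plus optional kwargs) into an instance."""
    if isinstance(query, Interaction):
        return query  # already instantiated; this sketch ignores kwargs here
    if isinstance(query, str):
        key = query.lower()
        if key.endswith("interaction"):
            key = key[: -len("interaction")]
        query = REGISTRY[key]  # look up the class by its normalized name
    return query(**(kwargs or {}))


print(type(make("DistMult")).__name__)      # DistMultInteraction
print(make(TransEInteraction, {"p": 2}).p)  # 2
```

Accepting all three forms lets callers choose between a short string in a configuration file and a pre-built object in code; the exact normalization rules of the real resolver are described in the class-resolver documentation linked above.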

Methods Summary

forward(h, r, t)

Evaluate the interaction function.

Methods Documentation

forward(h: tuple[Tensor, Tensor], r: tuple[Tensor, Tensor], t: tuple[Tensor, Tensor]) → Tensor

Evaluate the interaction function.

See also

Interaction.forward for a detailed description of the generic batched form of the interaction function.

Parameters:
  • h (tuple[Tensor, Tensor]) – shape: (*batch_dims, d) and (*batch_dims, d). The head representations.

  • r (tuple[Tensor, Tensor]) – shape: (*batch_dims, d) and (*batch_dims, d). The relation representations.

  • t (tuple[Tensor, Tensor]) – shape: (*batch_dims, d) and (*batch_dims, d). The tail representations.

Returns:

shape: batch_dims. The scores.

Return type:

Tensor
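The batched signature can be sketched in plain Python for a single batch dimension: each of h, r, and t is a pair of (batch, d) lists, and the result holds one score per batch element, i.e. shape (batch,). This is an illustration, not pykeen's code. It assumes the pair orderings (head role, tail role) for entities and (forward, backward) for relations, and, unlike the generic form described under Interaction.forward, it does not broadcast: all six inputs must share the same batch size.

```python
from typing import Callable, List, Tuple

Vector = List[float]
Batch = List[Vector]  # shape: (batch, d)


def distmult(h: Vector, r: Vector, t: Vector) -> float:
    """Trilinear product sum_i h_i * r_i * t_i."""
    return sum(hi * ri * ti for hi, ri, ti in zip(h, r, t))


def forward(
    base: Callable[[Vector, Vector, Vector], float],
    h: Tuple[Batch, Batch],  # (h_h, h_t)
    r: Tuple[Batch, Batch],  # (r_fwd, r_bwd)
    t: Tuple[Batch, Batch],  # (t_h, t_t)
) -> List[float]:
    """Score a batch of triples; returns one score per batch element."""
    h_h, h_t = h
    r_fwd, r_bwd = r
    t_h, t_t = t
    n = len(h_h)
    if any(len(x) != n for x in (h_t, r_fwd, r_bwd, t_h, t_t)):
        raise ValueError("this sketch requires equal batch sizes (no broadcasting)")
    scores = []
    for i in range(n):
        forward_term = base(h_h[i], r_fwd[i], t_t[i])
        backward_term = base(t_h[i], r_bwd[i], h_t[i])
        scores.append(0.5 * (forward_term + backward_term))
    return scores


h = ([[1.0, 0.0], [1.0, 1.0]], [[0.5, 0.5], [1.0, 0.0]])
r = ([[1.0, 2.0], [1.0, 1.0]], [[3.0, 1.0], [2.0, 2.0]])
t = ([[0.0, 1.0], [0.0, 1.0]], [[2.0, 1.0], [1.0, 1.0]])
print(forward(distmult, h, r, t))  # [1.25, 1.0]
```

A tensor implementation would replace the Python loop with element-wise operations over all batch dimensions at once.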