NTNInteraction

class NTNInteraction(activation: str | Module | type[Module] | None = None, activation_kwargs: Mapping[str, Any] | None = None)

Bases: Interaction[Tensor, tuple[Tensor, Tensor, Tensor, Tensor, Tensor], Tensor]

The stateless Neural Tensor Network (NTN) interaction function.

It is given by

\[\mathbf{r}_{u}^{T} \cdot \sigma( \mathbf{h} \mathbf{R}_{3} \mathbf{t} + \mathbf{R}_{2} [\mathbf{h};\mathbf{t}] + \mathbf{r}_1 )\]

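with the bilinear tensor \(\mathbf{R}_3 \in \mathbb{R}^{k \times d \times d}\), the matrix \(\mathbf{R}_2 \in \mathbb{R}^{k \times 2d}\), the bias vector \(\mathbf{r}_1 \in \mathbb{R}^k\), the final projection \(\mathbf{r}_u \in \mathbb{R}^k\), and a non-linear activation function \(\sigma\) (which defaults to Tanh). The bilinear term is the \(k\)-dimensional vector whose \(i\)-th entry is \(\mathbf{h}^{T} \mathbf{R}_3[i] \mathbf{t}\), and \([\mathbf{h};\mathbf{t}] \in \mathbb{R}^{2d}\) denotes the concatenation of head and tail.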
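To make the formula concrete, the following is a minimal NumPy sketch that scores a single triple. It is not PyKEEN's implementation; the function name `ntn_score` and its argument order are hypothetical, and \(\mathbf{R}_2\) is kept as one \(k \times 2d\) matrix applied to the concatenation \([\mathbf{h};\mathbf{t}]\), exactly as written in the equation.

```python
import numpy as np


def ntn_score(h, t, r3, r2, r1, ru, activation=np.tanh):
    """Score one triple with the NTN formula (hypothetical NumPy sketch).

    Shapes: h, t: (d,); r3: (k, d, d); r2: (k, 2d); r1, ru: (k,).
    """
    # Bilinear term: entry i is h^T R_3[i] t, a k-dimensional vector.
    bilinear = np.einsum("i,kij,j->k", h, r3, t)
    # Linear term R_2 [h; t] on the concatenated head and tail.
    linear = r2 @ np.concatenate([h, t])
    # Add the bias, apply the non-linearity, project to a scalar with r_u.
    return float(ru @ activation(bilinear + linear + r1))
```

For example, with all-ones \(\mathbf{h}, \mathbf{t}, \mathbf{R}_3, \mathbf{r}_u\), zero \(\mathbf{R}_2, \mathbf{r}_1\), \(d = 3\), \(k = 2\) and the identity as activation, every bilinear entry is \(d^2 = 9\), so the score is \(2 \cdot 9 = 18\).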

It can be seen as an extension of a two-layer MLP with relation-specific weights and an additional bilinear tensor in the input layer. Since each relation has its own separately parameterized network, the model is very expressive, but also computationally expensive: the bilinear term alone requires \(\mathcal{O}(kd^2)\) operations per triple and \(kd^2\) parameters per relation.
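The cost can be made concrete by counting the per-relation parameters implied by the shapes of \(\mathbf{R}_3\), \(\mathbf{R}_2\), \(\mathbf{r}_1\) and \(\mathbf{r}_u\). The helper below is a hypothetical illustration, not part of any library.

```python
def ntn_relation_params(d: int, k: int) -> int:
    """Count the parameters of one relation in the NTN formula (hypothetical helper)."""
    r3 = k * d * d  # bilinear tensor R_3, shape (k, d, d)
    r2 = k * 2 * d  # linear layer R_2 acting on [h; t], shape (k, 2d)
    r1 = k          # bias vector r_1
    ru = k          # output projection r_u
    return r3 + r2 + r1 + ru


print(ntn_relation_params(d=50, k=4))
```

With \(d = 50\) and \(k = 4\) this gives 10,408 parameters per relation, of which 10,000 belong to \(\mathbf{R}_3\), so the \(kd^2\) term dominates as soon as \(d\) is moderately large.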

Note

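We split the original \(k \times 2d\)-dimensional \(\mathbf{R}_2\) matrix into two parts \(\mathbf{R}_2^{h}, \mathbf{R}_2^{t}\) of shape \(k \times d\), one acting on the head and one on the tail, which is possible because \(\mathbf{R}_2 [\mathbf{h};\mathbf{t}] = \mathbf{R}_2^{h} \mathbf{h} + \mathbf{R}_2^{t} \mathbf{t}\). This supports more efficient 1:n scoring, e.g., in the score_h() or score_t() setting, since the term for the fixed entity can be computed once and reused for all candidates.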
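A NumPy sketch of the 1:n case, scoring one head and relation against \(n\) candidate tails. Here `r2_h` and `r2_t` are the head and tail halves of \(\mathbf{R}_2\) (the first and last \(d\) columns). Both helper names are hypothetical; this is not PyKEEN's score_t(). All head-dependent work is done once, outside the per-candidate computation.

```python
import numpy as np


def split_r2(r2):
    """Split R_2 of shape (k, 2d) into head and tail halves of shape (k, d)."""
    d = r2.shape[1] // 2
    return r2[:, :d], r2[:, d:]


def score_all_tails(h, tails, r3, r2_h, r2_t, r1, ru, activation=np.tanh):
    """Score one (h, r) pair against candidate tails of shape (n, d).

    Hypothetical helper for illustration; returns scores of shape (n,).
    """
    # Head-dependent terms, computed once and shared by every candidate:
    # row i of h_r3 is h^T R_3[i], giving a (k, d) matrix.
    h_r3 = np.einsum("i,kij->kj", h, r3)
    head_term = r2_h @ h + r1
    # Per-candidate terms, each of shape (n, k).
    bilinear = tails @ h_r3.T
    tail_term = tails @ r2_t.T
    # Broadcast the shared head term over the n candidates, then project.
    return activation(bilinear + tail_term + head_term) @ ru
```

With the unsplit matrix, every candidate would require evaluating \(\mathbf{R}_2\) on a freshly concatenated \([\mathbf{h};\mathbf{t}]\); after the split, only the \(\mathbf{R}_2^{t} \mathbf{t}\) part depends on the candidate.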

Initialize NTN with the given non-linear activation function.

Parameters:
  • activation (HintOrType[nn.Module]) – A non-linear activation function. Defaults to the hyperbolic tangent torch.nn.Tanh if None.

  • activation_kwargs (Mapping[str, Any] | None) – If the activation is passed as a class, these keyword arguments are used during its instantiation.

Note

The parameter pair (activation, activation_kwargs) is used for class_resolver.contrib.torch.activation_resolver.

An explanation of resolvers and how to use them is given in https://class-resolver.readthedocs.io/en/latest/.
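The way an (activation, activation_kwargs) pair is interpreted can be illustrated with a small pure-Python sketch of the resolver pattern. Everything below is hypothetical: `Tanh` and `LeakyReLU` are stand-ins for the torch.nn classes, and `resolve_activation` mimics the four accepted hint forms (None, a name, a class, an instance) without reproducing class_resolver's actual code or its name normalization.

```python
from typing import Any, Mapping, Optional


class Tanh:
    """Stand-in for torch.nn.Tanh (hypothetical, illustration only)."""


class LeakyReLU:
    """Stand-in for torch.nn.LeakyReLU (hypothetical, illustration only)."""

    def __init__(self, negative_slope: float = 0.01) -> None:
        self.negative_slope = negative_slope


# Hypothetical lookup table keyed by lower-cased class name.
_REGISTRY = {cls.__name__.lower(): cls for cls in (Tanh, LeakyReLU)}


def resolve_activation(activation: Any = None, activation_kwargs: Optional[Mapping[str, Any]] = None) -> Any:
    """Turn an (activation, activation_kwargs) pair into an activation instance."""
    if activation is None:
        activation = Tanh  # the default documented for NTNInteraction
    if isinstance(activation, str):
        activation = _REGISTRY[activation.lower()]  # look the name up
    if isinstance(activation, type):
        # Classes are instantiated with the given keyword arguments.
        return activation(**(activation_kwargs or {}))
    # Already an instance: this sketch simply returns it unchanged.
    return activation


act = resolve_activation("LeakyReLU", {"negative_slope": 0.2})
```

The keyword arguments only matter when the activation is given as a class or a name; an instance is already fully configured.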

Attributes Summary

relation_shape

The symbolic shapes for relation representations.

Methods Summary

forward(h, r, t)

Evaluate the interaction function.

Attributes Documentation

relation_shape: Sequence[str] = ('kdd', 'kd', 'kd', 'k', 'k')

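The symbolic shapes for relation representations: one \(k \times d \times d\) tensor, two \(k \times d\) matrices, and two \(k\)-dimensional vectors.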
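Each character of a symbolic shape names one dimension, so the tuple can be read by substituting concrete sizes for \(k\) and \(d\). The helper below is a hypothetical sketch of that reading, not a library function.

```python
def concrete_shapes(symbolic, **dims):
    """Expand symbolic shape strings such as 'kdd' into integer tuples."""
    # Each character is a dimension name looked up in dims.
    return [tuple(dims[ch] for ch in shape) for shape in symbolic]


relation_shape = ("kdd", "kd", "kd", "k", "k")
shapes = concrete_shapes(relation_shape, d=3, k=2)
```

For \(d = 3\) and \(k = 2\) this yields (2, 3, 3), (2, 3), (2, 3), (2,) and (2,), which are the per-relation shapes listed for the relation argument of forward().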

Methods Documentation

forward(h: Tensor, r: tuple[Tensor, Tensor, Tensor, Tensor, Tensor], t: Tensor) → Tensor

Evaluate the interaction function.

See also

Interaction.forward for a detailed description of the generic batched form of the interaction function.

Parameters:
  • h (Tensor) – shape: (*batch_dims, d) The head representations.

  • r (tuple[Tensor, Tensor, Tensor, Tensor, Tensor]) – shape: (*batch_dims, k, d, d), (*batch_dims, k, d), (*batch_dims, k, d), (*batch_dims, k), and (*batch_dims, k) The relation representations.

  • t (Tensor) – shape: (*batch_dims, d) The tail representations.

Returns:

shape: (*batch_dims,) The scores.

Return type:

Tensor
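The batched form can be sketched in NumPy with ellipsis einsum, which broadcasts the leading *batch_dims of all inputs against each other. This is not PyKEEN's forward(); in particular, the order of the relation tuple is an assumption here: (\(\mathbf{R}_3\), head half of \(\mathbf{R}_2\), tail half of \(\mathbf{R}_2\), \(\mathbf{r}_1\), \(\mathbf{r}_u\)), matching the shape order ('kdd', 'kd', 'kd', 'k', 'k').

```python
import numpy as np


def ntn_forward(h, r, t, activation=np.tanh):
    """Batched NTN scores (hypothetical NumPy sketch).

    h, t: (*batch_dims, d); r: shapes (*batch_dims, k, d, d), (*batch_dims, k, d),
    (*batch_dims, k, d), (*batch_dims, k), (*batch_dims, k).
    Batch dimensions broadcast against each other; returns (*batch_dims,).
    """
    # Assumed tuple order (hypothetical): R_3, R_2 head half, R_2 tail half, r_1, r_u.
    r3, r2_h, r2_t, r1, ru = r
    # Bilinear term h^T R_3[i] t for every slice i, shape (*batch_dims, k).
    bilinear = np.einsum("...i,...kij,...j->...k", h, r3, t)
    # Split linear term R_2^h h + R_2^t t, shape (*batch_dims, k).
    linear = np.einsum("...ki,...i->...k", r2_h, h) + np.einsum("...ki,...i->...k", r2_t, t)
    hidden = activation(bilinear + linear + r1)
    # Project the k hidden units onto r_u to get one score per batch element.
    return np.einsum("...k,...k->...", ru, hidden)
```

For instance, heads of shape (2, 1, d), tails of shape (1, 5, d) and a single relation with batch shape (1, 1) broadcast to a (2, 5) score matrix.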