NTNInteraction
- class NTNInteraction(activation: str | Module | type[Module] | None = None, activation_kwargs: Mapping[str, Any] | None = None)
Bases: Interaction[Tensor, tuple[Tensor, Tensor, Tensor, Tensor, Tensor], Tensor]
The state-less Neural Tensor Network (NTN) interaction function.
It is given by
\[\mathbf{r}_{u}^{T} \cdot \sigma( \mathbf{h} \mathbf{R}_{3} \mathbf{t} + \mathbf{R}_{2} [\mathbf{h};\mathbf{t}] + \mathbf{r}_1 )\]
with \(\mathbf{R}_3 \in \mathbb{R}^{k \times d \times d}\), \(\mathbf{R}_2 \in \mathbb{R}^{k \times 2d}\), the bias vector \(\mathbf{r}_1 \in \mathbb{R}^k\), the final projection \(\mathbf{r}_u \in \mathbb{R}^k\), and a non-linear activation function \(\sigma\) (which defaults to Tanh). Here, \(\mathbf{h} \mathbf{R}_{3} \mathbf{t} \in \mathbb{R}^k\) denotes the bilinear tensor product whose \(i\)-th entry is \(\mathbf{h}^{T} \mathbf{R}_{3}^{(i)} \mathbf{t}\), where \(\mathbf{R}_{3}^{(i)} \in \mathbb{R}^{d \times d}\) is the \(i\)-th slice of \(\mathbf{R}_3\).
It can be seen as an extension of a two-layer MLP with relation-specific weights and an additional bilinear tensor in the input layer. A separately parameterized neural network for each relation makes the model very expressive, but also computationally expensive (\(\mathcal{O}(kd^2)\) per triple, due to the bilinear term).
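To make the formula concrete, here is a minimal plain-Python sketch that scores a single triple. It is not PyKEEN's implementation: the function name `ntn_score` and the nested-list layout are assumptions for illustration, with \(\mathbf{R}_3\) stored as k slices of size d × d (matching the `'kdd'` relation shape) and \(\mathbf{R}_2\) kept as one k × 2d matrix exactly as written in the formula. The double loop over d inside the loop over k makes the \(\mathcal{O}(kd^2)\) cost visible.

```python
import math
from typing import Callable, List, Sequence

Vector = Sequence[float]
Matrix = Sequence[Sequence[float]]


def ntn_score(
    h: Vector,
    t: Vector,
    R3: Sequence[Matrix],  # k slices, each d x d (layout assumed for this sketch)
    R2: Matrix,            # k x 2d, applied to the concatenation [h; t]
    r1: Vector,            # bias vector, length k
    ru: Vector,            # final projection, length k
    activation: Callable[[float], float] = math.tanh,
) -> float:
    """Sketch of r_u^T . sigma(h R3 t + R2 [h; t] + r1) for one triple."""
    ht = list(h) + list(t)  # concatenation [h; t] of length 2d
    hidden: List[float] = []
    for i in range(len(r1)):
        # i-th entry of the bilinear term: h^T R3[i] t, d * d multiplications
        bilinear = sum(
            h[a] * R3[i][a][b] * t[b] for a in range(len(h)) for b in range(len(t))
        )
        # i-th entry of the linear term R2 [h; t]
        linear = sum(w * x for w, x in zip(R2[i], ht))
        hidden.append(activation(bilinear + linear + r1[i]))
    # project the k hidden units onto a scalar score with r_u
    return sum(u * z for u, z in zip(ru, hidden))


# Toy example with d = 2, k = 1: only the bilinear term is non-zero.
score = ntn_score(
    h=[1.0, 0.0],
    t=[0.0, 1.0],
    R3=[[[0.0, 1.0], [0.0, 0.0]]],
    R2=[[0.0, 0.0, 0.0, 0.0]],
    r1=[0.0],
    ru=[1.0],
)
print(score)  # tanh(1.0), approximately 0.7616
```

Passing `activation=lambda x: x` removes the non-linearity, which is handy for checking the bilinear and linear terms by hand.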
Note
We split the original \(\mathbf{R}_2\) matrix of shape \(k \times 2d\) into two parts of shape \(k \times d\) to support more efficient 1:n scoring, e.g., in the score_h() or score_t() setting.
Initialize NTN with the given non-linear activation function.
- Parameters:
activation (HintOrType[nn.Module]) – A non-linear activation function. Defaults to the hyperbolic tangent torch.nn.Tanh if None.
activation_kwargs (Mapping[str, Any] | None) – If the activation is passed as a class, these keyword arguments are used during its instantiation.
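The contract of these two parameters can be illustrated with a small stand-in resolver. This is a hypothetical plain-Python sketch, not the class_resolver API: `resolve_activation`, `REGISTRY`, and the `Tanh`/`LeakyReLU` classes are illustrative stand-ins for the torch modules. It covers four cases: None falls back to the default, a string is looked up by name, a class is instantiated with `activation_kwargs`, and an instance is returned unchanged.

```python
import math
from typing import Any, Callable, Mapping, Optional, Union


class Tanh:
    """Hypothetical stand-in for torch.nn.Tanh, acting on a single float."""

    def __call__(self, x: float) -> float:
        return math.tanh(x)


class LeakyReLU:
    """Hypothetical stand-in for torch.nn.LeakyReLU, acting on a single float."""

    def __init__(self, negative_slope: float = 0.01) -> None:
        self.negative_slope = negative_slope

    def __call__(self, x: float) -> float:
        return x if x >= 0 else self.negative_slope * x


# Hypothetical name -> class registry with lowercase keys.
REGISTRY = {"tanh": Tanh, "leakyrelu": LeakyReLU}


def resolve_activation(
    activation: Union[str, type, Callable[[float], float], None] = None,
    activation_kwargs: Optional[Mapping[str, Any]] = None,
) -> Callable[[float], float]:
    if activation is None:
        activation = Tanh  # mirrors the documented default (hyperbolic tangent)
    if isinstance(activation, str):
        activation = REGISTRY[activation.lower()]  # look up the class by name
    if isinstance(activation, type):
        # only classes are instantiated, using the keyword arguments
        return activation(**(activation_kwargs or {}))
    return activation  # already an instance; this sketch ignores activation_kwargs


leaky = resolve_activation("LeakyReLU", {"negative_slope": 0.2})
print(leaky(-1.0))  # -0.2
```

The real resolver supports more lookup variants; the point here is only that `activation_kwargs` matters when a class (or a name resolving to one) is given.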
Note
The parameter pair (activation, activation_kwargs) is used for class_resolver.contrib.torch.activation_resolver.
An explanation of resolvers and how to use them is given in https://class-resolver.readthedocs.io/en/latest/.
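Returning to the note on splitting \(\mathbf{R}_2\): the split relies on the identity \(\mathbf{R}_2 [\mathbf{h};\mathbf{t}] = \mathbf{R}_2^{(h)} \mathbf{h} + \mathbf{R}_2^{(t)} \mathbf{t}\), where \(\mathbf{R}_2^{(h)}\) and \(\mathbf{R}_2^{(t)}\) are the first and last d columns. The plain-Python sketch below (the helpers `matvec` and `split_columns` are hypothetical) checks this identity and shows why it helps 1:n tail scoring: the head contribution is computed once and reused for every candidate tail.

```python
import math
from typing import List, Sequence, Tuple

Matrix = Sequence[Sequence[float]]


def matvec(M: Matrix, v: Sequence[float]) -> List[float]:
    """Plain matrix-vector product."""
    return [sum(m * x for m, x in zip(row, v)) for row in M]


def split_columns(R2: Matrix, d: int) -> Tuple[Matrix, Matrix]:
    """Split a k x 2d matrix into its head half and tail half (each k x d)."""
    return [row[:d] for row in R2], [row[d:] for row in R2]


d = 2
R2 = [[1.0, 2.0, 3.0, 4.0], [0.5, -1.0, 2.0, 0.0]]  # k = 2
h = [1.0, -1.0]
t = [2.0, 0.5]

R2_h, R2_t = split_columns(R2, d)
joint = matvec(R2, h + t)  # R2 [h; t]
split = [a + b for a, b in zip(matvec(R2_h, h), matvec(R2_t, t))]
assert all(math.isclose(a, b) for a, b in zip(joint, split))

# 1:n tail scoring: the head contribution is computed once ...
head_part = matvec(R2_h, h)
candidate_tails = [[2.0, 0.5], [0.0, 1.0], [-1.0, 3.0]]
# ... and only the tail contribution is recomputed per candidate.
linear_terms = [
    [a + b for a, b in zip(head_part, matvec(R2_t, tail))] for tail in candidate_tails
]
print(linear_terms)  # [[7.0, 5.5], [3.0, 1.5], [8.0, -0.5]]
```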
Attributes Summary
relation_shape – The symbolic shapes for relation representations.
Methods Summary
forward(h, r, t) – Evaluate the interaction function.
Attributes Documentation
- relation_shape: Sequence[str] = ('kdd', 'kd', 'kd', 'k', 'k')
The symbolic shapes of the five relation representation tensors.
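Each letter in these symbolic shapes names a dimension: d is the embedding dimension and k the number of slices (hidden units). The minimal sketch below substitutes concrete sizes and counts the per-relation parameters. The helper `expand_shapes` is hypothetical, and the labeling of the entries as (\(\mathbf{R}_3\), head half of \(\mathbf{R}_2\), tail half of \(\mathbf{R}_2\), \(\mathbf{r}_1\), \(\mathbf{r}_u\)) is an assumption; the counts do not depend on it.

```python
from math import prod
from typing import List, Sequence, Tuple

RELATION_SHAPE = ("kdd", "kd", "kd", "k", "k")
# Assumed meaning of each entry (the order is an assumption of this sketch):
# R_3, head half of R_2, tail half of R_2, bias r_1, projection r_u.


def expand_shapes(symbolic: Sequence[str], **sizes: int) -> List[Tuple[int, ...]]:
    """Replace each symbolic dimension letter by its concrete size."""
    return [tuple(sizes[letter] for letter in shape) for shape in symbolic]


shapes = expand_shapes(RELATION_SHAPE, k=3, d=4)
print(shapes)  # [(3, 4, 4), (3, 4), (3, 4), (3,), (3,)]

# Parameters per relation: k*d*d + 2*k*d + 2*k, dominated by the k x d x d tensor.
n_params = sum(prod(shape) for shape in shapes)
print(n_params)  # 78
```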
Methods Documentation
- forward(h: Tensor, r: tuple[Tensor, Tensor, Tensor, Tensor, Tensor], t: Tensor) → Tensor
Evaluate the interaction function.
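In the generic batched form, every input carries leading batch dimensions in front of its feature dimensions, and the scores keep only the batch dimensions. Assuming these batch dimensions combine under the usual PyTorch/NumPy broadcasting rules, the output shape can be predicted with the hypothetical helpers below. The example scores 5 heads against 7 candidate tails with one relation per head, a typical 1:n layout.

```python
from typing import Sequence, Tuple

Shape = Tuple[int, ...]
# Number of trailing (non-batch) dimensions of 'kdd', 'kd', 'kd', 'k', 'k'.
RELATION_FEATURE_NDIMS = (3, 2, 2, 1, 1)


def broadcast_shapes(*shapes: Shape) -> Shape:
    """Broadcast shapes with NumPy/PyTorch rules: align right, size 1 stretches."""
    ndim = max(len(s) for s in shapes)
    padded = [(1,) * (ndim - len(s)) + tuple(s) for s in shapes]
    result = []
    for sizes in zip(*padded):
        non_one = {n for n in sizes if n != 1}
        if len(non_one) > 1:
            raise ValueError(f"cannot broadcast batch sizes {sizes}")
        result.append(non_one.pop() if non_one else 1)
    return tuple(result)


def ntn_score_shape(h: Shape, r: Sequence[Shape], t: Shape) -> Shape:
    """Predict the score shape for inputs of the given shapes (sketch)."""
    batch_shapes = [h[:-1], t[:-1]]  # drop the trailing d
    # drop the trailing feature dims of each relation tensor
    batch_shapes += [s[: len(s) - n] for s, n in zip(r, RELATION_FEATURE_NDIMS)]
    return broadcast_shapes(*batch_shapes)


k, d = 3, 4
h_shape = (5, 1, d)  # 5 heads
t_shape = (1, 7, d)  # 7 candidate tails
r_shapes = [(5, 1, k, d, d), (5, 1, k, d), (5, 1, k, d), (5, 1, k), (5, 1, k)]
print(ntn_score_shape(h_shape, r_shapes, t_shape))  # (5, 7)
```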
See also
Interaction.forward for a detailed description of the generic batched form of the interaction function.
- Parameters:
h (Tensor) – shape: (*batch_dims, d). The head representations.
r (tuple[Tensor, Tensor, Tensor, Tensor, Tensor]) – shapes: (*batch_dims, k, d, d), (*batch_dims, k, d), (*batch_dims, k, d), (*batch_dims, k), and (*batch_dims, k). The relation representations.
t (Tensor) – shape: (*batch_dims, d). The tail representations.
- Returns:
shape: batch_dims
The scores.
- Return type: