TransFInteraction

class TransFInteraction(*args, **kwargs)

Bases: Interaction[Tensor, Tensor, Tensor]

The stateless TransF interaction function, built from inner products of the head, relation, and tail representations.

It is given by

\[f(\mathbf{h}, \mathbf{r}, \mathbf{t}) = (\mathbf{h} + \mathbf{r})^T \mathbf{t} + \mathbf{h}^T (\mathbf{t} - \mathbf{r})\]

for head entity, relation, and tail entity representations \(\mathbf{h}, \mathbf{r}, \mathbf{t} \in \mathbb{R}^d\). The interaction function can be simplified as

\[\begin{split}f(\mathbf{h}, \mathbf{r}, \mathbf{t}) &= (\mathbf{h} + \mathbf{r})^T \mathbf{t} + \mathbf{h}^T (\mathbf{t} - \mathbf{r}) \\ &= \langle \mathbf{h}, \mathbf{t}\rangle + \langle \mathbf{r}, \mathbf{t}\rangle + \langle \mathbf{h}, \mathbf{t}\rangle - \langle \mathbf{h}, \mathbf{r}\rangle \\ &= 2 \cdot \langle \mathbf{h}, \mathbf{t}\rangle + \langle \mathbf{r}, \mathbf{t}\rangle - \langle \mathbf{h}, \mathbf{r}\rangle\end{split}\]
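As a quick sanity check of the expansion, the sketch below evaluates the score both in its original form (h + r)^T t + h^T (t - r) and in the simplified form 2 <h, t> + <r, t> - <h, r>, first on a small hand-checkable example and then on random vectors. It uses plain Python lists rather than tensors; the helper names `dot`, `transf_direct`, and `transf_simplified` are hypothetical and not part of any library.

```python
import math
import random
from typing import Sequence


def dot(a: Sequence[float], b: Sequence[float]) -> float:
    """Inner product <a, b> of two vectors of the same dimension d."""
    if len(a) != len(b):
        raise ValueError("vectors must have the same dimension d")
    return sum(x * y for x, y in zip(a, b))


def transf_direct(h: Sequence[float], r: Sequence[float], t: Sequence[float]) -> float:
    """Balanced TransF score in its original form: (h + r)^T t + h^T (t - r)."""
    h_plus_r = [hi + ri for hi, ri in zip(h, r)]
    t_minus_r = [ti - ri for ti, ri in zip(t, r)]
    return dot(h_plus_r, t) + dot(h, t_minus_r)


def transf_simplified(h: Sequence[float], r: Sequence[float], t: Sequence[float]) -> float:
    """The same score after expansion: 2 <h, t> + <r, t> - <h, r>."""
    return 2 * dot(h, t) + dot(r, t) - dot(h, r)


# Small hand-checkable example with d = 3.
h = [1.0, 0.0, 2.0]
r = [0.5, 1.0, -1.0]
t = [2.0, 1.0, 1.0]
print(transf_direct(h, r, t))      # → 10.5
print(transf_simplified(h, r, t))  # → 10.5

# Both forms agree (up to floating-point rounding) on random vectors with d = 16.
rng = random.Random(0)
for _ in range(1000):
    h_rand, r_rand, t_rand = ([rng.gauss(0.0, 1.0) for _ in range(16)] for _ in range(3))
    assert math.isclose(
        transf_direct(h_rand, r_rand, t_rand),
        transf_simplified(h_rand, r_rand, t_rand),
        rel_tol=1e-9,
        abs_tol=1e-9,
    )
```

By hand: (h + r)^T t = 5.0 and h^T (t - r) = 5.5, while 2 <h, t> = 8.0, <r, t> = 1.0, and <h, r> = -1.5, so both routes give 10.5.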

Note

This is the balanced variant proposed in the original TransF paper.

Todo

Implement the unbalanced version, too: \(f(\mathbf{h}, \mathbf{r}, \mathbf{t}) = (\mathbf{h} + \mathbf{r})^T \mathbf{t}\)
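To make the difference between the two variants concrete, the sketch below treats the unbalanced score (h + r)^T t from the Todo as the first term of the balanced score (h + r)^T t + h^T (t - r), so the balanced variant simply adds h^T (t - r). Both functions are hypothetical plain-Python helpers written for illustration; they are not an implementation provided by this class.

```python
from typing import Sequence


def dot(a: Sequence[float], b: Sequence[float]) -> float:
    """Inner product <a, b> of two equal-length vectors."""
    return sum(x * y for x, y in zip(a, b))


def transf_unbalanced(h: Sequence[float], r: Sequence[float], t: Sequence[float]) -> float:
    """Hypothetical unbalanced TransF score: (h + r)^T t."""
    return dot([hi + ri for hi, ri in zip(h, r)], t)


def transf_balanced(h: Sequence[float], r: Sequence[float], t: Sequence[float]) -> float:
    """Balanced TransF score: the unbalanced term plus h^T (t - r)."""
    return transf_unbalanced(h, r, t) + dot(h, [ti - ri for ti, ri in zip(t, r)])


h = [1.0, 0.0, 2.0]
r = [0.5, 1.0, -1.0]
t = [2.0, 1.0, 1.0]
print(transf_unbalanced(h, r, t))  # → 5.0
print(transf_balanced(h, r, t))    # → 10.5
```

Only the unbalanced term scores the tail against the translated head h + r; the balanced version additionally scores the head against the "back-translated" tail t - r.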


Methods Summary

forward(h, r, t)

Evaluate the interaction function.

Methods Documentation

forward(h: Tensor, r: Tensor, t: Tensor) → Tensor

Evaluate the interaction function.

See also

Interaction.forward for a detailed description of the generic batched form of the interaction function.

Parameters:
  • h (Tensor) – shape: (*batch_dims, d) The head representations.

  • r (Tensor) – shape: (*batch_dims, d) The relation representations.

  • t (Tensor) – shape: (*batch_dims, d) The tail representations.

Returns:

shape: (*batch_dims,) The scores.

Return type:

Tensor
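The shapes above mean that forward reduces the last (feature) dimension d and keeps the broadcast batch dimensions. The following is a minimal sketch of that behavior in plain PyTorch, not PyKEEN's own code; the function name `transf_interaction` is hypothetical. It assumes the inputs broadcast against each other, here scoring each (h, r) pair against every tail candidate at once.

```python
import torch


def transf_interaction(h: torch.Tensor, r: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
    """Hypothetical batched balanced TransF score.

    h, r, t broadcast to a common shape (*batch_dims, d); summing over the
    last dimension computes the inner products, leaving shape batch_dims.
    """
    # (h + r)^T t  +  h^T (t - r), computed elementwise then reduced over d
    return ((h + r) * t).sum(dim=-1) + (h * (t - r)).sum(dim=-1)


torch.manual_seed(0)
batch_size, num_tails, d = 2, 5, 4
h = torch.randn(batch_size, 1, d)  # one head per batch element
r = torch.randn(batch_size, 1, d)  # one relation per batch element
t = torch.randn(1, num_tails, d)   # all tail candidates, shared across the batch

scores = transf_interaction(h, r, t)
print(scores.shape)  # torch.Size([2, 5])
```

Because the reduction is over the last dimension only, the same function handles single triples (shape (d,)), aligned batches (shape (n, d)), and the broadcast "one pair against many tails" layout without any change.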