RESCALInteraction

class RESCALInteraction(*args, **kwargs)[source]

Bases: Interaction[Tensor, Tensor, Tensor]

The state-less RESCAL interaction function.

For head and tail entity representations \(\mathbf{h}, \mathbf{t} \in \mathbb{R}^d\) and relation representation \(\mathbf{R} \in \mathbb{R}^{d \times d}\), the interaction function is given as

\[\mathbf{h}^T \mathbf{R} \mathbf{t} = \sum_{i=1}^{d} \sum_{j=1}^{d} \mathbf{h}_i \mathbf{R}_{i, j} \mathbf{t}_{j}\]

Thus, the relation matrices \(\mathbf{R}\) contain weights \(\mathbf{R}_{i, j}\) that capture the amount of interaction between the \(i\)-th latent factor of the head representation and the \(j\)-th latent factor of the tail representation.

The computational complexity is given by \(\mathcal{O}(d^2)\).
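The bilinear form above can be written out as an explicit double sum. The following is a minimal pure-Python sketch for a single triple, not the library's implementation; the function name `rescal_score` is hypothetical and chosen for illustration only. The two nested loops over \(d\) make the \(\mathcal{O}(d^2)\) cost visible.

```python
def rescal_score(h, R, t):
    """Compute h^T R t for one triple as an explicit double sum.

    h, t: sequences of length d; R: d x d nested sequence.
    `rescal_score` is a hypothetical helper, not part of any library.
    """
    d = len(h)
    score = 0.0
    for i in range(d):      # latent factor of the head
        for j in range(d):  # latent factor of the tail
            score += h[i] * R[i][j] * t[j]
    return score


# Only R[0][1] is non-zero, so the score couples the first head factor
# with the second tail factor.
h = [1.0, 2.0]
t = [3.0, 4.0]
R = [[0.0, 1.0],
     [0.0, 0.0]]
print(rescal_score(h, R, t))  # 4.0
```

With the identity matrix as \(\mathbf{R}\), the same function reduces to the plain dot product of \(\mathbf{h}\) and \(\mathbf{t}\).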

Initialize internal Module state, shared by both nn.Module and ScriptModule.

Attributes Summary

relation_shape

The symbolic shapes for relation representations

Methods Summary

forward(h, r, t)

Evaluate the interaction function.

Attributes Documentation

relation_shape: Sequence[str] = ('dd',)

The symbolic shapes for relation representations

Methods Documentation

forward(h: Tensor, r: Tensor, t: Tensor) → Tensor[source]

Evaluate the interaction function.

See also

Interaction.forward for a detailed description of the generic batched form of the interaction function.

Parameters:
  • h (Tensor) – shape: (*batch_dims, d) The head representations.

  • r (Tensor) – shape: (*batch_dims, d, d) The relation representations.

  • t (Tensor) – shape: (*batch_dims, d) The tail representations.

Returns:

shape: batch_dims The scores.

Return type:

Tensor
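As an illustration of the shapes involved, the batched computation can be sketched with `torch.einsum`. This is an assumption-laden sketch rather than the library's actual code: the helper name `rescal_forward` is hypothetical, and it assumes the relation tensor carries a trailing `(d, d)` matrix per batch element, in line with the `('dd',)` relation shape.

```python
import torch


def rescal_forward(h, r, t):
    """Batched h^T R t (hypothetical helper, not the library function).

    h: (*batch_dims, d), r: (*batch_dims, d, d), t: (*batch_dims, d)
    returns: tensor of shape batch_dims
    """
    # "..." stands for the shared leading batch dimensions.
    return torch.einsum("...i,...ij,...j->...", h, r, t)


torch.manual_seed(0)
batch_size, d = 4, 3
h = torch.randn(batch_size, d)
r = torch.randn(batch_size, d, d)
t = torch.randn(batch_size, d)

scores = rescal_forward(h, r, t)  # one score per batch element

# Cross-check against an explicit batched matrix product:
# (1 x d) @ (d x d) @ (d x 1) for every batch element.
reference = (h.unsqueeze(-2) @ r @ t.unsqueeze(-1)).squeeze(-1).squeeze(-1)
```

Both `scores` and `reference` have shape `(batch_size,)`; the einsum form avoids the explicit reshaping needed by the matrix-product variant.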