DistMultInteraction

class DistMultInteraction(*args, **kwargs)

Bases: FunctionalInteraction[Tensor, Tensor, Tensor]

The stateless DistMult interaction function.

This interaction is given by

\[f(\mathbf{h}, \mathbf{r}, \mathbf{t}) = \sum \limits_{i} \mathbf{h}_i \cdot \mathbf{r}_{i} \cdot \mathbf{t}_i\]

where \(\mathbf{h}, \mathbf{r}, \mathbf{t} \in \mathbb{R}^{d}\) are the representations of the head entity, the relation, and the tail entity, respectively.
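As an illustration of the formula, the following minimal sketch scores a single triple using plain Python lists in place of tensors. The function name `distmult_score` is hypothetical and is not part of the library's API; it simply multiplies the three vectors elementwise and sums the result in one pass over the \(d\) entries.

```python
from typing import Sequence


def distmult_score(h: Sequence[float], r: Sequence[float], t: Sequence[float]) -> float:
    """Hypothetical helper: f(h, r, t) = sum_i h_i * r_i * t_i for one triple."""
    if not (len(h) == len(r) == len(t)):
        raise ValueError("h, r and t must all have the same dimension d")
    # One multiply-and-add per coordinate, i.e. a single O(d) pass.
    return sum(h_i * r_i * t_i for h_i, r_i, t_i in zip(h, r, t))


h = [1.0, 2.0, 3.0]
r = [0.5, -1.0, 2.0]
t = [4.0, 0.5, 1.0]
print(distmult_score(h, r, t))  # 1*0.5*4 + 2*(-1)*0.5 + 3*2*1 = 7.0
```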

For a single triple of \(d\)-dimensional vectors, computing the score takes \(\mathcal{O}(d)\) time: one elementwise product of the three vectors followed by a sum over \(d\) entries.

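The interaction function is symmetric in the entities, because each term \(\mathbf{h}_i \cdot \mathbf{r}_i \cdot \mathbf{t}_i\) is unchanged when head and tail are swapped, i.e.,

\[f(\mathbf{h}, \mathbf{r}, \mathbf{t}) = f(\mathbf{t}, \mathbf{r}, \mathbf{h})\]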
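The symmetry property can be checked numerically. This sketch draws random Gaussian vectors with Python's standard `random` module and compares the scores with head and tail swapped; `distmult_score` and `check_symmetry` are hypothetical helper names. Because floating-point multiplication is not associative, `(h_i * r_i) * t_i` and `(t_i * r_i) * h_i` may round differently, so the comparison uses `math.isclose` rather than exact equality.

```python
import math
import random


def distmult_score(h, r, t):
    """Hypothetical helper: sum_i h_i * r_i * t_i."""
    return sum(h_i * r_i * t_i for h_i, r_i, t_i in zip(h, r, t))


def check_symmetry(d: int = 16, trials: int = 100, seed: int = 0) -> bool:
    """Return True if f(h, r, t) and f(t, r, h) agree on every random trial."""
    rng = random.Random(seed)
    for _ in range(trials):
        h, r, t = ([rng.gauss(0.0, 1.0) for _ in range(d)] for _ in range(3))
        forward = distmult_score(h, r, t)
        backward = distmult_score(t, r, h)
        # Tolerances absorb per-term rounding differences only.
        if not math.isclose(forward, backward, rel_tol=1e-9, abs_tol=1e-9):
            return False
    return True


print(check_symmetry())  # True
```

A practical consequence is that DistMult assigns the same score to \((h, r, t)\) and \((t, r, h)\) for every relation, so it cannot rank one direction of a relation above the other.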

The constructor initializes the internal Module state, shared by both torch.nn.Module and torch.jit.ScriptModule.

Methods Summary

func(h, r, t)

Evaluate the interaction function.

Methods Documentation

static func(h: Tensor, r: Tensor, t: Tensor) → Tensor

Evaluate the interaction function.

See also

Interaction.forward for a detailed description of the generic batched form of the interaction function.

Parameters:
  • h (Tensor) – shape: (*batch_dims, d) The head representations.

  • r (Tensor) – shape: (*batch_dims, d) The relation representations.

  • t (Tensor) – shape: (*batch_dims, d) The tail representations.

Returns:

shape: batch_dims The scores.

Return type:

Tensor
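To illustrate the batched shapes listed above, here is a sketch that uses NumPy arrays as a stand-in for torch tensors; it is not the library's own implementation of `func`. It assumes that the leading `batch_dims` of `h`, `r` and `t` broadcast against each other, as they would in elementwise torch operations, so that an elementwise product followed by a sum over the last axis maps inputs of shape `(*batch_dims, d)` to scores of shape `batch_dims`. The name `distmult_func` is hypothetical; with torch, the analogous expression would be `(h * r * t).sum(dim=-1)`.

```python
import numpy as np


def distmult_func(h: np.ndarray, r: np.ndarray, t: np.ndarray) -> np.ndarray:
    """Hypothetical batched DistMult: (*batch_dims, d) inputs -> (*batch_dims) scores."""
    # The elementwise product broadcasts the batch dimensions;
    # summing over the last axis collapses the embedding dimension d.
    return (h * r * t).sum(axis=-1)


rng = np.random.default_rng(0)
d, num_entities = 4, 5

# Score one (head, relation) pair against every candidate tail.
h = rng.normal(size=(1, d))
r = rng.normal(size=(1, d))
t = rng.normal(size=(num_entities, d))
scores = distmult_func(h, r, t)
print(scores.shape)  # (5,)

# Three (head, relation) pairs against all tails at once.
h_batch = rng.normal(size=(3, 1, d))
r_batch = rng.normal(size=(3, 1, d))
t_all = rng.normal(size=(1, num_entities, d))
print(distmult_func(h_batch, r_batch, t_all).shape)  # (3, 5)
```

Broadcasting lets a single call score many triples without materializing repeated copies of the head and relation vectors.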