ComplExInteraction

class ComplExInteraction(*args, **kwargs)

Bases: FunctionalInteraction[FloatTensor, FloatTensor, FloatTensor]

The ComplEx interaction proposed by [trouillon2016].

ComplEx operates on complex-valued entity and relation representations, i.e., \(\textbf{e}_i, \textbf{r}_i \in \mathbb{C}^d\), and calculates the plausibility score as the real part of the summed Hadamard product of the head, relation, and conjugated tail representations:

\[f(h,r,t) = Re\left(\sum_{k=1}^{d} \left(\mathbf{e}_h\odot\mathbf{r}_r\odot\bar{\mathbf{e}}_t\right)_k\right)\]

Expanding the real part into real and imaginary components gives:

\[f(h,r,t) = \left\langle Re(\mathbf{e}_h),Re(\mathbf{r}_r),Re(\mathbf{e}_t)\right\rangle + \left\langle Im(\mathbf{e}_h),Re(\mathbf{r}_r),Im(\mathbf{e}_t)\right\rangle + \left\langle Re(\mathbf{e}_h),Im(\mathbf{r}_r),Im(\mathbf{e}_t)\right\rangle - \left\langle Im(\mathbf{e}_h),Im(\mathbf{r}_r),Re(\mathbf{e}_t)\right\rangle\]

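where \(Re(\textbf{x})\) and \(Im(\textbf{x})\) denote the real and imaginary parts of the complex-valued vector \(\textbf{x}\). Because the tail representation enters the score conjugated, the score is in general not symmetric in \(h\) and \(t\): swapping head and tail is equivalent to replacing \(\mathbf{r}_r\) by its conjugate \(\bar{\mathbf{r}}_r\), which generally changes the score when \(Im(\mathbf{r}_r) \neq 0\). Hence, in contrast to DistMult, whose score is always symmetric in head and tail, ComplEx can model anti-symmetric relations.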
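The following sketch checks the expansion and the asymmetry argument numerically. It assumes PyTorch complex tensors (`torch.cfloat`); `complex_score` and `expanded_score` are illustrative names, not part of the PyKEEN API, and the toy values are arbitrary.

```python
import torch


def complex_score(h: torch.Tensor, r: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
    # Direct form: Re(sum_k (h * r * conj(t))_k), summed over the last dimension.
    return torch.real((h * r * torch.conj(t)).sum(dim=-1))


def expanded_score(h: torch.Tensor, r: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
    # Expanded form: four real trilinear products, matching the formula term by term.
    h_re, h_im = h.real, h.imag
    r_re, r_im = r.real, r.imag
    t_re, t_im = t.real, t.imag
    return (
        (h_re * r_re * t_re).sum(dim=-1)
        + (h_im * r_re * t_im).sum(dim=-1)
        + (h_re * r_im * t_im).sum(dim=-1)
        - (h_im * r_im * t_re).sum(dim=-1)
    )


# 1) Hand-picked toy embeddings (d = 2): both forms give 1.0.
h_toy = torch.tensor([1 + 2j, 0 + 1j])
r_toy = torch.tensor([2 + 0j, 1 - 1j])
t_toy = torch.tensor([1 - 1j, 3 + 0j])
print(complex_score(h_toy, r_toy, t_toy).item(), expanded_score(h_toy, r_toy, t_toy).item())  # → 1.0 1.0

# 2) Random embeddings: both forms agree up to floating-point error.
torch.manual_seed(0)
h, r, t = (torch.randn(16, dtype=torch.cfloat) for _ in range(3))
print(torch.allclose(complex_score(h, r, t), expanded_score(h, r, t), atol=1e-5))  # → True

# 3) Swapping head and tail is the same as conjugating the relation ...
forward = complex_score(h, r, t)
backward = complex_score(t, r, h)
print(torch.allclose(backward, complex_score(h, torch.conj(r), t), atol=1e-5))  # → True
print(torch.allclose(forward, backward, atol=1e-5))  # differs for a generic complex r

# ... so a purely real relation yields a symmetric score, as in DistMult.
r_real = r.real.to(torch.cfloat)
print(torch.allclose(complex_score(h, r_real, t), complex_score(t, r_real, h), atol=1e-5))  # → True
```

The toy values were chosen so the four-term expansion can be checked by hand: \(2 - 4 + 0 + 3 = 1\).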

See also

Official implementation: https://github.com/ttrouill/complex/

Note

This method generally expects all tensors to be of a complex dtype, i.e., torch.is_complex(x) evaluates to True. However, for backwards compatibility and convenience, you can also pass real tensors whose shape is compatible with torch.view_as_complex() (i.e., a trailing dimension of size 2 holding the real and imaginary parts), cf. pykeen.utils.ensure_complex().
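The sketch below illustrates the real-to-complex reinterpretation described in the note, using only PyTorch. The helper `as_complex_tensor` is hypothetical: it is not `pykeen.utils.ensure_complex()`, whose exact behavior is not reproduced here.

```python
import torch


def as_complex_tensor(x: torch.Tensor) -> torch.Tensor:
    # Hypothetical helper: pass complex tensors through unchanged, otherwise
    # reinterpret a trailing size-2 dimension as (real, imaginary) pairs.
    if torch.is_complex(x):
        return x
    return torch.view_as_complex(x)


# Real input: shape (3, 2) holds three complex numbers as (real, imag) pairs.
x_real = torch.tensor([[1.0, 2.0], [3.0, -1.0], [0.0, 0.5]])
x_complex = as_complex_tensor(x_real)

print(torch.is_complex(x_real), torch.is_complex(x_complex))  # → False True
print(tuple(x_real.shape), tuple(x_complex.shape))             # → (3, 2) (3,)
# view_as_complex returns a view, so view_as_real recovers the original layout.
print(torch.equal(torch.view_as_real(x_complex), x_real))      # → True
```

Because `torch.view_as_complex` returns a view rather than a copy, no memory is duplicated; it does require a floating-point input whose last dimension has size 2 and stride 1.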


Attributes Summary

is_complex

whether the interaction is defined on complex input

Methods Summary

func(h, r, t)

Evaluate the interaction function.

Attributes Documentation

is_complex: ClassVar[bool] = True

whether the interaction is defined on complex input

Methods Documentation

static func(h, r, t)

Evaluate the interaction function.

Parameters:
  • h (FloatTensor) – shape: (*batch_dims, dim) The complex head representations.

  • r (FloatTensor) – shape: (*batch_dims, dim) The complex relation representations.

  • t (FloatTensor) – shape: (*batch_dims, dim) The complex tail representations.

Return type:

FloatTensor

Returns:

shape: batch_dims The scores.
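To illustrate the documented shape contract, here is a hypothetical stand-alone function that follows the same formula; it is not PyKEEN's implementation and does not call `ComplExInteraction.func` itself. It scores a batch of (head, relation) queries against every candidate tail by relying on PyTorch broadcasting over `*batch_dims`.

```python
import torch


def complex_interaction(h: torch.Tensor, r: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
    """Hypothetical stand-alone version of the documented ``func`` contract.

    h, r, t: complex tensors of shape (*batch_dims, dim), broadcastable against each other.
    Returns real-valued scores of shape batch_dims.
    """
    return torch.real((h * r * torch.conj(t)).sum(dim=-1))


batch_size, num_entities, dim = 4, 10, 8
torch.manual_seed(0)
h = torch.randn(batch_size, 1, dim, dtype=torch.cfloat)    # one head per query
r = torch.randn(batch_size, 1, dim, dtype=torch.cfloat)    # one relation per query
t = torch.randn(1, num_entities, dim, dtype=torch.cfloat)  # every candidate tail

scores = complex_interaction(h, r, t)  # batch_dims = (batch_size, num_entities)
print(scores.shape, scores.dtype)      # → torch.Size([4, 10]) torch.float32
best_tails = scores.argmax(dim=-1)     # highest-scoring tail index per query
```

Inserting singleton dimensions (size 1) lets one tensor of tail representations be shared across all queries without materializing a copy per query; only the elementwise product has the full `(batch_size, num_entities, dim)` shape.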