ComplEx

class ComplEx(*, embedding_dim=200, entity_initializer=<function normal_>, relation_initializer=<function normal_>, **kwargs)

Bases: pykeen.models.base.EntityRelationEmbeddingModel

An implementation of ComplEx [trouillon2016].

ComplEx is an extension of pykeen.models.DistMult that uses complex-valued representations for the entities and relations. Entities and relations are represented as vectors \(\textbf{e}_i, \textbf{r}_i \in \mathbb{C}^d\), and the plausibility score is the real part of the sum over the entries of the Hadamard product of the head embedding, the relation embedding, and the complex conjugate of the tail embedding:

\[f(h,r,t) = Re\left(\sum_{k=1}^{d} (\mathbf{e}_h\odot\mathbf{r}_r\odot\bar{\mathbf{e}}_t)_k\right)\]

This expands to:

\[f(h,r,t) = \left\langle Re(\mathbf{e}_h),Re(\mathbf{r}_r),Re(\mathbf{e}_t)\right\rangle + \left\langle Im(\mathbf{e}_h),Re(\mathbf{r}_r),Im(\mathbf{e}_t)\right\rangle + \left\langle Re(\mathbf{e}_h),Im(\mathbf{r}_r),Im(\mathbf{e}_t)\right\rangle - \left\langle Im(\mathbf{e}_h),Im(\mathbf{r}_r),Re(\mathbf{e}_t)\right\rangle\]

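where \(Re(\textbf{x})\) and \(Im(\textbf{x})\) denote the real and imaginary parts of the complex-valued vector \(\textbf{x}\), and \(\left\langle \mathbf{a},\mathbf{b},\mathbf{c}\right\rangle = \sum_{k} a_k b_k c_k\) denotes the trilinear dot product. The Hadamard product is commutative in the complex space as well; the asymmetry comes from conjugating the tail embedding. Swapping head and tail is equivalent to conjugating the relation embedding, i.e. \(f(t,r,h) = Re(\sum_{k} (\mathbf{e}_h\odot\bar{\mathbf{r}}_r\odot\bar{\mathbf{e}}_t)_k)\), so the score is symmetric in \(h\) and \(t\) only if \(Im(\mathbf{r}_r) = \mathbf{0}\). A purely imaginary relation embedding gives \(f(t,r,h) = -f(h,r,t)\), which is how ComplEx models anti-symmetric relations, in contrast to DistMult, whose score is always symmetric in \(h\) and \(t\).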
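Both the four-term expansion and the effect of swapping head and tail can be checked numerically. The sketch below is an illustration, not PyKEEN's implementation: it assumes embeddings stored as PyTorch complex tensors (torch.cfloat), and the helper complex_score is a hypothetical name.

```python
import torch


def complex_score(h: torch.Tensor, r: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: Re(sum_k h_k * r_k * conj(t_k)) over the last axis."""
    return (h * r * t.conj()).real.sum(dim=-1)


torch.manual_seed(0)
d = 8
h, r, t = (torch.randn(d, dtype=torch.cfloat) for _ in range(3))

# Four-term expansion written with real and imaginary parts only.
expanded = (
    (h.real * r.real * t.real).sum()
    + (h.imag * r.real * t.imag).sum()
    + (h.real * r.imag * t.imag).sum()
    - (h.imag * r.imag * t.real).sum()
)
assert torch.allclose(complex_score(h, r, t), expanded, atol=1e-5)

# Swapping head and tail is the same as conjugating the relation.
assert torch.allclose(complex_score(t, r, h), complex_score(h, r.conj(), t), atol=1e-5)

# A purely imaginary relation embedding makes the score anti-symmetric.
r_imag = torch.complex(torch.zeros(d), torch.randn(d))
print(complex_score(h, r_imag, t).item(), complex_score(t, r_imag, h).item())
```

The two printed values have the same magnitude and opposite signs. With a purely real relation embedding they would instead be equal, which is the DistMult-like symmetric case.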

See also

Official implementation: https://github.com/ttrouill/complex/

Initialize ComplEx.

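Parameters
  • embedding_dim (int) – The embedding dimension d. Defaults to 200.

  • entity_initializer – Initializer applied to the entity embeddings. Defaults to normal_.

  • relation_initializer – Initializer applied to the relation embeddings. Defaults to normal_.

  • kwargs – Remaining keyword arguments, passed on to the base class.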

Attributes Summary

hpo_default

The default strategy for optimizing the model’s hyper-parameters

loss_default_kwargs

The default parameters for the default loss function class

regularizer_default_kwargs

The Lp regularizer settings used by [trouillon2016] for ComplEx.

Methods Summary

forward(h_indices, r_indices, t_indices)

Unified score function.

interaction_function(h, r, t)

Evaluate the interaction function of ComplEx for given embeddings.

score_h(rt_batch)

Forward pass using left side (head) prediction.

score_hrt(hrt_batch)

Forward pass.

score_r(ht_batch)

Forward pass using middle (relation) prediction.

score_t(hr_batch)

Forward pass using right side (tail) prediction.

Attributes Documentation

hpo_default: ClassVar[Mapping[str, Any]] = {'embedding_dim': {'high': 256, 'low': 16, 'q': 16, 'type': <class 'int'>}}

The default strategy for optimizing the model’s hyper-parameters
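Read as an integer search space, the embedding_dim entry spans 16 to 256 in steps of q = 16. The sketch below enumerates that grid. It assumes that q is the quantization step and that both bounds are inclusive; the helper name is hypothetical, and how a particular HPO backend samples from the range is not specified here.

```python
# Assumption: 'q' is the step size and both 'low' and 'high' are inclusive.
hpo_default = {"embedding_dim": {"high": 256, "low": 16, "q": 16, "type": int}}


def enumerate_int_range(spec: dict) -> list:
    """Hypothetical helper: list every value an integer range spec allows."""
    return list(range(spec["low"], spec["high"] + 1, spec["q"]))


dims = enumerate_int_range(hpo_default["embedding_dim"])
print(len(dims), dims[0], dims[-1])  # 16 16 256
```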

loss_default_kwargs: ClassVar[Mapping[str, Any]] = {'reduction': 'mean'}

The default parameters for the default loss function class

regularizer_default_kwargs: ClassVar[Mapping[str, Any]] = {'normalize': True, 'p': 2.0, 'weight': 0.01}

The Lp regularizer settings used by [trouillon2016] for ComplEx.
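As a rough illustration of what p, weight, and normalize control, the sketch below computes a weighted Lp penalty over a batch of embedding vectors. It is not PyKEEN's regularizer: the function name is hypothetical, and the effect of normalize=True is assumed here to be division by d^(1/p), so that the penalty does not grow with the embedding dimension.

```python
import torch


def lp_penalty(x: torch.Tensor, p: float = 2.0, weight: float = 0.01, normalize: bool = True) -> torch.Tensor:
    """Hypothetical weighted Lp penalty: weight * mean over rows of ||x_i||_p."""
    norms = x.norm(p=p, dim=-1)  # one Lp norm per embedding vector
    if normalize:
        # Assumption: scale by d ** (1 / p); PyKEEN's exact normalization may differ.
        norms = norms / x.shape[-1] ** (1.0 / p)
    return weight * norms.mean()


x = torch.ones(4, 9)  # 4 vectors of dimension 9, all entries equal to 1
# Each row has L2 norm 3; dividing by 9 ** 0.5 = 3 gives 1, then weight 0.01 scales it.
print(lp_penalty(x).item())
```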

Methods Documentation

forward(h_indices, r_indices, t_indices)

Unified score function.

Return type

FloatTensor

static interaction_function(h, r, t)

Evaluate the interaction function of ComplEx for given embeddings.

The embeddings have to be in a broadcastable shape.

Parameters
  • h (FloatTensor) – Head embeddings.

  • r (FloatTensor) – Relation embeddings.

  • t (FloatTensor) – Tail embeddings.

Return type

FloatTensor

Returns

shape: (…). The scores.
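A minimal sketch of such a broadcasting interaction function, assuming the embeddings are PyTorch complex tensors (torch.cfloat); PyKEEN's own storage layout may differ, and the function name is illustrative. Because the product is element-wise and the reduction runs only over the last axis, inserting singleton dimensions lets one call score every (head, relation) pair against every candidate tail.

```python
import torch


def complex_interaction(h: torch.Tensor, r: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
    """Illustrative ComplEx interaction: Re(sum_k h_k * r_k * conj(t_k)).

    h, r, t are complex tensors whose shapes broadcast to (..., d);
    the result has the broadcast shape without the last axis.
    """
    return (h * r * t.conj()).real.sum(dim=-1)


batch_size, num_entities, d = 3, 5, 4
h = torch.randn(batch_size, 1, d, dtype=torch.cfloat)    # one head per batch row
r = torch.randn(batch_size, 1, d, dtype=torch.cfloat)    # one relation per batch row
t = torch.randn(1, num_entities, d, dtype=torch.cfloat)  # every candidate tail

scores = complex_interaction(h, r, t)
print(scores.shape)  # torch.Size([3, 5])
```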

score_h(rt_batch)

Forward pass using left side (head) prediction.

This method calculates the score for all possible heads for each (relation, tail) pair.

Parameters

rt_batch (LongTensor) – shape: (batch_size, 2), dtype: long. The indices of (relation, tail) pairs.

Return type

FloatTensor

Returns

shape: (batch_size, num_entities), dtype: float. For each r-t pair, the scores for all possible heads.
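One way to obtain these all-heads scores is sketched below. It assumes hypothetical complex embedding tables entity_emb and relation_emb and mirrors the documented shapes, but it is not PyKEEN's code. Since the score is Re(Σ_k h_k (r_k t̄_k)), the relation and conjugated tail can be combined once per pair and then scored against every entity with a single matrix product.

```python
import torch

torch.manual_seed(0)
num_entities, num_relations, d = 6, 2, 4
# Hypothetical embedding tables (complex-valued, randomly initialized).
entity_emb = torch.randn(num_entities, d, dtype=torch.cfloat)
relation_emb = torch.randn(num_relations, d, dtype=torch.cfloat)


def score_h_sketch(rt_batch: torch.LongTensor) -> torch.Tensor:
    """Scores of all heads for each (relation, tail) pair; shape (batch_size, num_entities)."""
    r = relation_emb[rt_batch[:, 0]]  # (batch_size, d)
    t = entity_emb[rt_batch[:, 1]]    # (batch_size, d)
    # (batch_size, d) @ (d, num_entities): sums r_k * conj(t_k) * h_k for every head h.
    return ((r * t.conj()) @ entity_emb.T).real


rt_batch = torch.tensor([[0, 3], [1, 5]])
print(score_h_sketch(rt_batch).shape)  # torch.Size([2, 6])
```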

score_hrt(hrt_batch)

Forward pass.

This method takes head, relation and tail of each triple and calculates the corresponding score.

Parameters

hrt_batch (LongTensor) – shape: (batch_size, 3), dtype: long. The indices of (head, relation, tail) triples.

Raises

NotImplementedError – If the method was not implemented for this class.

Return type

FloatTensor

Returns

shape: (batch_size, 1), dtype: float. The score for each triple.
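For complete triples the score reduces to a row-wise sum. In the sketch below, entity_emb and relation_emb are hypothetical complex embedding tables, and the function only illustrates the documented input and output shapes; it is not PyKEEN's implementation.

```python
import torch

torch.manual_seed(0)
num_entities, num_relations, d = 6, 2, 4
# Hypothetical embedding tables (complex-valued, randomly initialized).
entity_emb = torch.randn(num_entities, d, dtype=torch.cfloat)
relation_emb = torch.randn(num_relations, d, dtype=torch.cfloat)


def score_hrt_sketch(hrt_batch: torch.LongTensor) -> torch.Tensor:
    """One score per (head, relation, tail) triple; shape (batch_size, 1)."""
    h = entity_emb[hrt_batch[:, 0]]
    r = relation_emb[hrt_batch[:, 1]]
    t = entity_emb[hrt_batch[:, 2]]
    # keepdim=True keeps the trailing singleton axis of the documented output shape.
    return (h * r * t.conj()).real.sum(dim=-1, keepdim=True)


hrt_batch = torch.tensor([[0, 1, 2], [4, 0, 3], [5, 1, 5]])
print(score_hrt_sketch(hrt_batch).shape)  # torch.Size([3, 1])
```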

score_r(ht_batch)

Forward pass using middle (relation) prediction.

This method calculates the score for all possible relations for each (head, tail) pair.

Parameters

ht_batch (LongTensor) – shape: (batch_size, 2), dtype: long. The indices of (head, tail) pairs.

Return type

FloatTensor

Returns

shape: (batch_size, num_relations), dtype: float. For each h-t pair, the scores for all possible relations.
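For relation prediction, the head and the conjugated tail can be multiplied first and then matched against the whole relation table. The sketch assumes hypothetical complex tables entity_emb and relation_emb and illustrates the documented shapes only; it is not PyKEEN's code.

```python
import torch

torch.manual_seed(0)
num_entities, num_relations, d = 6, 3, 4
# Hypothetical embedding tables (complex-valued, randomly initialized).
entity_emb = torch.randn(num_entities, d, dtype=torch.cfloat)
relation_emb = torch.randn(num_relations, d, dtype=torch.cfloat)


def score_r_sketch(ht_batch: torch.LongTensor) -> torch.Tensor:
    """Scores of all relations for each (head, tail) pair; shape (batch_size, num_relations)."""
    h = entity_emb[ht_batch[:, 0]]  # (batch_size, d)
    t = entity_emb[ht_batch[:, 1]]  # (batch_size, d)
    # (batch_size, d) @ (d, num_relations): sums h_k * conj(t_k) * r_k for every relation r.
    return ((h * t.conj()) @ relation_emb.T).real


ht_batch = torch.tensor([[1, 4], [0, 0]])
print(score_r_sketch(ht_batch).shape)  # torch.Size([2, 3])
```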

score_t(hr_batch)

Forward pass using right side (tail) prediction.

This method calculates the score for all possible tails for each (head, relation) pair.

Parameters

hr_batch (LongTensor) – shape: (batch_size, 2), dtype: long. The indices of (head, relation) pairs.

Return type

FloatTensor

Returns

shape: (batch_size, num_entities), dtype: float. For each h-r pair, the scores for all possible tails.
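Tail prediction can be sketched with broadcasting instead of a matrix product: each combined head-relation vector is multiplied against the conjugate of every entity embedding and summed over the embedding axis. The tables entity_emb and relation_emb are hypothetical, and this is an illustration of the documented shapes rather than PyKEEN's implementation.

```python
import torch

torch.manual_seed(0)
num_entities, num_relations, d = 6, 2, 4
# Hypothetical embedding tables (complex-valued, randomly initialized).
entity_emb = torch.randn(num_entities, d, dtype=torch.cfloat)
relation_emb = torch.randn(num_relations, d, dtype=torch.cfloat)


def score_t_sketch(hr_batch: torch.LongTensor) -> torch.Tensor:
    """Scores of all tails for each (head, relation) pair; shape (batch_size, num_entities)."""
    h = entity_emb[hr_batch[:, 0]]           # (batch_size, d)
    r = relation_emb[hr_batch[:, 1]]         # (batch_size, d)
    hr = (h * r).unsqueeze(1)                # (batch_size, 1, d)
    tails = entity_emb.conj().unsqueeze(0)   # (1, num_entities, d)
    return (hr * tails).real.sum(dim=-1)     # (batch_size, num_entities)


hr_batch = torch.tensor([[2, 1], [0, 0]])
print(score_t_sketch(hr_batch).shape)  # torch.Size([2, 6])
```

Broadcasting materializes a (batch_size, num_entities, d) intermediate, so for large entity sets a matrix product against the conjugated entity table is the more memory-friendly choice.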