Source code for pykeen.models.unimodal.complex

# -*- coding: utf-8 -*-

"""Implementation of the ComplEx model."""

from typing import Any, ClassVar, Mapping, Optional, Type

import torch
import torch.nn as nn

from ..base import EntityRelationEmbeddingModel
from ...constants import DEFAULT_EMBEDDING_HPO_EMBEDDING_DIM_RANGE
from ...losses import Loss, SoftplusLoss
from ...regularizers import LpRegularizer, Regularizer
from ...triples import TriplesFactory
from ...typing import DeviceHint
from ...utils import split_complex

__all__ = [
    'ComplEx',
]


class ComplEx(EntityRelationEmbeddingModel):
    r"""An implementation of ComplEx [trouillon2016]_.

    ComplEx is an extension of :class:`pykeen.models.DistMult` that uses complex valued representations for
    the entities and relations. Entities and relations are represented as vectors
    $\textbf{e}_i, \textbf{r}_i \in \mathbb{C}^d$, and the plausibility score is computed using the
    Hadamard product:

    .. math::

        f(h,r,t) = Re(\mathbf{e}_h\odot\mathbf{r}_r\odot\bar{\mathbf{e}}_t)

    Which expands to:

    .. math::

        f(h,r,t) = \left\langle Re(\mathbf{e}_h),Re(\mathbf{r}_r),Re(\mathbf{e}_t)\right\rangle
        + \left\langle Re(\mathbf{e}_h),Im(\mathbf{r}_r),Im(\mathbf{e}_t)\right\rangle
        + \left\langle Im(\mathbf{e}_h),Re(\mathbf{r}_r),Im(\mathbf{e}_t)\right\rangle
        - \left\langle Im(\mathbf{e}_h),Im(\mathbf{r}_r),Re(\mathbf{e}_t)\right\rangle

    where $Re(\textbf{x})$ and $Im(\textbf{x})$ denote the real and imaginary parts of the complex valued
    vector $\textbf{x}$, and $\bar{\textbf{e}}_t$ denotes the complex conjugate of $\textbf{e}_t$. Because
    the tail embedding is conjugated, the score is not symmetric in $h$ and $t$, so ComplEx can model
    anti-symmetric relations, in contrast to DistMult.

    .. seealso::

        Official implementation: https://github.com/ttrouill/complex/
    """

    #: The default strategy for optimizing the model's hyper-parameters
    hpo_default: ClassVar[Mapping[str, Any]] = dict(
        embedding_dim=DEFAULT_EMBEDDING_HPO_EMBEDDING_DIM_RANGE,
    )
    #: The default loss function class
    loss_default: ClassVar[Type[Loss]] = SoftplusLoss
    #: The default parameters for the default loss function class
    loss_default_kwargs: ClassVar[Mapping[str, Any]] = dict(reduction='mean')
    #: The regularizer used by [trouillon2016]_ for ComplEx.
    regularizer_default: ClassVar[Type[Regularizer]] = LpRegularizer
    #: The LP settings used by [trouillon2016]_ for ComplEx.
    regularizer_default_kwargs: ClassVar[Mapping[str, Any]] = dict(
        weight=0.01,
        p=2.0,
        normalize=True,
    )

    def __init__(
        self,
        triples_factory: TriplesFactory,
        embedding_dim: int = 200,
        loss: Optional[Loss] = None,
        regularizer: Optional[Regularizer] = None,
        preferred_device: DeviceHint = None,
        random_seed: Optional[int] = None,
        entity_initializer=nn.init.normal_,
        relation_initializer=nn.init.normal_,
    ) -> None:
        """Initialize ComplEx.

        :param triples_factory: The triples factory connected to the model.
        :param embedding_dim: The embedding dimensionality of the entity embeddings.
        :param loss: The loss to use. Defaults to :class:`SoftplusLoss`.
        :param regularizer: The regularizer to use.
        :param preferred_device: The default device where the model is located.
        :param random_seed: An optional random seed to set before the initialization of weights.
        :param entity_initializer: The initializer for the entity embeddings.
        :param relation_initializer: The initializer for the relation embeddings.
        """
        super().__init__(
            triples_factory=triples_factory,
            embedding_dim=2 * embedding_dim,  # complex embeddings
            loss=loss,
            preferred_device=preferred_device,
            random_seed=random_seed,
            regularizer=regularizer,
            # initialize entity and relation embeddings with the standard normal distribution, cf.
            # https://github.com/ttrouill/complex/blob/dc4eb93408d9a5288c986695b58488ac80b1cc17/efe/models.py#L481-L487
            entity_initializer=entity_initializer,
            relation_initializer=relation_initializer,
        )
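
    # Illustrative note on the storage convention: each d-dimensional complex vector is stored as a
    # real vector of dimensionality 2d, which is why ``__init__`` passes ``2 * embedding_dim`` above.
    # The layout assumed here, matching ``split_complex`` in this pykeen version, keeps the real part
    # in the first half of the last dimension and the imaginary part in the second half:
    #
    #   x = [Re(x)_1, ..., Re(x)_d, Im(x)_1, ..., Im(x)_d]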

    @staticmethod
    def interaction_function(
        h: torch.FloatTensor,
        r: torch.FloatTensor,
        t: torch.FloatTensor,
    ) -> torch.FloatTensor:
        """Evaluate the interaction function of ComplEx for given embeddings.

        The embeddings have to be in a broadcastable shape.

        :param h: Head embeddings.
        :param r: Relation embeddings.
        :param t: Tail embeddings.
        :return: shape: (...)
            The scores.
        """
        # split into real and imaginary part
        (h_re, h_im), (r_re, r_im), (t_re, t_im) = [split_complex(x=x) for x in (h, r, t)]

        # ComplEx space bilinear product
        # *: Elementwise multiplication; the sign pattern follows from Re(h ⊙ r ⊙ conj(t))
        return sum(
            sign * (hh * rr * tt).sum(dim=-1)
            for sign, hh, rr, tt in [
                (+1, h_re, r_re, t_re),
                (+1, h_re, r_im, t_im),
                (+1, h_im, r_re, t_im),
                (-1, h_im, r_im, t_re),
            ]
        )
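
    # A quick worked check of the sign pattern in ``interaction_function``: writing the components
    # as h = a + bi, r = c + di, and conj(t) = e - fi, we get
    #
    #   Re((a + bi)(c + di)(e - fi))
    #     = Re(((ac - bd) + (ad + bc)i)(e - fi))
    #     = (ac - bd)e + (ad + bc)f
    #     = ace + adf + bcf - bde,
    #
    # i.e. exactly the four (sign, h, r, t) combinations above, with the single negative term
    # on (Im(h), Im(r), Re(t)).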

    def forward(
        self,
        h_indices: Optional[torch.LongTensor],
        r_indices: Optional[torch.LongTensor],
        t_indices: Optional[torch.LongTensor],
    ) -> torch.FloatTensor:
        """Unified score function."""
        # get embeddings
        h = self.entity_embeddings.get_in_canonical_shape(indices=h_indices)
        r = self.relation_embeddings.get_in_canonical_shape(indices=r_indices)
        t = self.entity_embeddings.get_in_canonical_shape(indices=t_indices)

        # Regularization
        self.regularize_if_necessary(h, r, t)

        # Compute scores
        return self.interaction_function(h=h, r=r, t=t)
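
    # Note on ``forward`` (assumed behavior of ``get_in_canonical_shape`` in this pykeen version):
    # when an index tensor is ``None``, all embeddings are returned, broadcast against a singleton
    # batch dimension. The same ``interaction_function`` call therefore scores either a plain batch
    # of triples (all indices given) or a one-vs-all ranking (one index set to ``None``), which is
    # how ``score_h``, ``score_r``, and ``score_t`` below reuse it.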

    def score_hrt(self, hrt_batch: torch.LongTensor) -> torch.FloatTensor:  # noqa: D102
        return self(h_indices=hrt_batch[:, 0], r_indices=hrt_batch[:, 1], t_indices=hrt_batch[:, 2]).view(-1, 1)

    def score_t(self, hr_batch: torch.LongTensor) -> torch.FloatTensor:  # noqa: D102
        return self(h_indices=hr_batch[:, 0], r_indices=hr_batch[:, 1], t_indices=None)

    def score_r(self, ht_batch: torch.LongTensor) -> torch.FloatTensor:  # noqa: D102
        return self(h_indices=ht_batch[:, 0], r_indices=None, t_indices=ht_batch[:, 1])

    def score_h(self, rt_batch: torch.LongTensor) -> torch.FloatTensor:  # noqa: D102
        return self(h_indices=None, r_indices=rt_batch[:, 0], t_indices=rt_batch[:, 1])
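

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only): a minimal end-to-end example, assuming the ``Nations`` toy
# dataset shipped with pykeen and the high-level model API of this pykeen version; the scores come
# from an untrained model and only demonstrate the expected tensor shapes.
if __name__ == '__main__':
    from pykeen.datasets import Nations

    dataset = Nations()
    model = ComplEx(triples_factory=dataset.training, embedding_dim=50)

    # ``score_hrt`` expects a (batch_size, 3) LongTensor of (head, relation, tail) indices
    # and returns a (batch_size, 1) tensor of scores
    hrt_batch = dataset.training.mapped_triples[:1]
    print(model.score_hrt(hrt_batch))

    # rank all entities as tails for the same (head, relation) pair
    print(model.score_t(hrt_batch[:, :2]).shape)  # -> (1, num_entities)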