# Source code for pykeen.models.unimodal.complex

# -*- coding: utf-8 -*-

"""Implementation of the ComplEx model."""

from typing import Optional

import torch
import torch.nn as nn

from ..base import EntityRelationEmbeddingModel
from ...losses import Loss, SoftplusLoss
from ...regularizers import LpRegularizer, Regularizer
from ...triples import TriplesFactory
from ...utils import get_embedding_in_canonical_shape, split_complex

__all__ = [
    'ComplEx',
]


class ComplEx(EntityRelationEmbeddingModel):
    r"""An implementation of ComplEx [trouillon2016]_.

    ComplEx is an extension of :class:`pykeen.models.DistMult` that uses complex valued representations for the
    entities and relations. Entities and relations are represented as vectors
    $\textbf{e}_i, \textbf{r}_i \in \mathbb{C}^d$, and the plausibility score is the real part of the trilinear
    product of the head, the relation and the *complex conjugate* of the tail:

    .. math::

        f(h,r,t) = Re(\left\langle\mathbf{e}_h,\mathbf{r}_r,\overline{\mathbf{e}_t}\right\rangle)

    Which expands to:

    .. math::

        f(h,r,t) = \left\langle Re(\mathbf{e}_h),Re(\mathbf{r}_r),Re(\mathbf{e}_t)\right\rangle
        + \left\langle Re(\mathbf{e}_h),Im(\mathbf{r}_r),Im(\mathbf{e}_t)\right\rangle
        + \left\langle Im(\mathbf{e}_h),Re(\mathbf{r}_r),Im(\mathbf{e}_t)\right\rangle
        - \left\langle Im(\mathbf{e}_h),Im(\mathbf{r}_r),Re(\mathbf{e}_t)\right\rangle

    where $Re(\textbf{x})$ and $Im(\textbf{x})$ denote the real and imaginary parts of the complex valued vector
    $\textbf{x}$. Because the tail is conjugated, the score is not symmetric in head and tail, so ComplEx can model
    anti-symmetric relations in contrast to DistMult.

    .. seealso ::

        Official implementation: https://github.com/ttrouill/complex/
    """

    #: The default strategy for optimizing the model's hyper-parameters
    hpo_default = dict(
        embedding_dim=dict(type=int, low=50, high=300, q=50),
    )
    #: The default loss function class
    loss_default = SoftplusLoss
    #: The default parameters for the default loss function class
    loss_default_kwargs = dict(reduction='mean')
    #: The regularizer used by [trouillon2016]_ for ComplEx.
    regularizer_default = LpRegularizer
    #: The LP settings used by [trouillon2016]_ for ComplEx.
    regularizer_default_kwargs = dict(
        weight=0.01,
        p=2.0,
        normalize=True,
    )

    def __init__(
        self,
        triples_factory: TriplesFactory,
        embedding_dim: int = 200,
        automatic_memory_optimization: Optional[bool] = None,
        loss: Optional[Loss] = None,
        preferred_device: Optional[str] = None,
        random_seed: Optional[int] = None,
        regularizer: Optional[Regularizer] = None,
    ) -> None:
        """Initialize ComplEx.

        :param triples_factory:
            The triple factory connected to the model.
        :param embedding_dim:
            The (complex) embedding dimensionality. The base class is given twice this value, since every
            complex component is stored as two real numbers.
        :param automatic_memory_optimization:
            Whether to automatically optimize the sub-batch size during training and batch size during
            evaluation with regards to the hardware at hand.
        :param loss:
            The loss to use. Defaults to SoftplusLoss.
        :param preferred_device:
            The default device where the model is located.
        :param random_seed:
            An optional random seed to set before the initialization of weights.
        :param regularizer:
            The regularizer to use.
        """
        super().__init__(
            triples_factory=triples_factory,
            embedding_dim=2 * embedding_dim,  # complex embeddings
            automatic_memory_optimization=automatic_memory_optimization,
            loss=loss,
            preferred_device=preferred_device,
            random_seed=random_seed,
            regularizer=regularizer,
        )

        # Finalize initialization
        self.reset_parameters_()

    def _reset_parameters_(self):  # noqa: D102
        # initialize with entity and relation embeddings with standard normal distribution, cf.
        # https://github.com/ttrouill/complex/blob/dc4eb93408d9a5288c986695b58488ac80b1cc17/efe/models.py#L481-L487
        nn.init.normal_(tensor=self.entity_embeddings.weight, mean=0., std=1.)
        nn.init.normal_(tensor=self.relation_embeddings.weight, mean=0., std=1.)

    @staticmethod
    def interaction_function(
        h: torch.FloatTensor,
        r: torch.FloatTensor,
        t: torch.FloatTensor,
    ) -> torch.FloatTensor:
        """Evaluate the interaction function of ComplEx for given embeddings.

        The embeddings have to be in a broadcastable shape. The score is the real part of the trilinear product
        of the head, the relation and the complex conjugate of the tail.

        :param h:
            Head embeddings.
        :param r:
            Relation embeddings.
        :param t:
            Tail embeddings.

        :return: shape: (...)
            The scores.
        """
        # split into real and imaginary part
        h_re, h_im = split_complex(x=h)
        r_re, r_im = split_complex(x=r)
        t_re, t_im = split_complex(x=t)

        # ComplEx space bilinear product
        # *: Elementwise multiplication
        # The last term is subtracted: it comes from i * i = -1 combined with the conjugated tail. Adding it
        # instead would make the score symmetric in h and t.
        return (
            (h_re * r_re * t_re).sum(dim=-1)
            + (h_re * r_im * t_im).sum(dim=-1)
            + (h_im * r_re * t_im).sum(dim=-1)
            - (h_im * r_im * t_re).sum(dim=-1)
        )

    def score_hrt(self, hrt_batch: torch.LongTensor) -> torch.FloatTensor:
        """Score a batch of triples.

        :param hrt_batch:
            A two-dimensional index tensor whose columns 0, 1 and 2 hold the head, relation and tail indices.

        :return:
            The scores computed by :meth:`interaction_function` for the looked-up embeddings.
        """
        # get embeddings
        h = get_embedding_in_canonical_shape(embedding=self.entity_embeddings, ind=hrt_batch[:, 0])
        r = get_embedding_in_canonical_shape(embedding=self.relation_embeddings, ind=hrt_batch[:, 1])
        t = get_embedding_in_canonical_shape(embedding=self.entity_embeddings, ind=hrt_batch[:, 2])

        # Compute scores
        scores = self.interaction_function(h=h, r=r, t=t)

        # Regularization
        self.regularize_if_necessary(h, r, t)

        return scores