Source code for pykeen.models.unimodal.rotate

# -*- coding: utf-8 -*-

"""Implementation of the RotatE model."""

from typing import Any, ClassVar, Mapping

import torch
import torch.autograd
from torch import linalg

from ..base import EntityRelationEmbeddingModel
from ...nn.emb import EmbeddingSpecification
from ...nn.init import init_phases, xavier_uniform_
from ...typing import Constrainer, Hint, Initializer
from ...utils import complex_normalize

# Names exported by ``from <this module> import *``: only the RotatE model class.
__all__ = ["RotatE"]


class RotatE(EntityRelationEmbeddingModel):
    r"""An implementation of RotatE from [sun2019]_.

    RotatE models relations as rotations from head to tail entities in complex space:

    .. math::

        \textbf{e}_t= \textbf{e}_h \odot \textbf{r}_r

    where $\textbf{e}, \textbf{r} \in \mathbb{C}^{d}$ and every complex element of $\textbf{r}_r$ is
    restricted to have a modulus of one, i.e., $|(\textbf{r}_r)_i| = 1$ for all $i$, so that the
    Hadamard product with $\textbf{r}_r$ is an element-wise rotation. The interaction model is then
    defined as:

    .. math::

        f(h,r,t) = -\|\textbf{e}_h \odot \textbf{r}_r - \textbf{e}_t\|

    which allows modeling symmetry, antisymmetry, inversion, and composition.

    .. seealso::

       - Authors' `implementation of RotatE
         <https://github.com/DeepGraphLearning/KnowledgeGraphEmbedding/blob/master/codes/model.py#L200-L228>`_
    ---
    citation:
        author: Sun
        year: 2019
        link: https://arxiv.org/abs/1902.10197v1
        github: DeepGraphLearning/KnowledgeGraphEmbedding
    """

    #: The default strategy for optimizing the model's hyper-parameters
    hpo_default: ClassVar[Mapping[str, Any]] = dict(
        embedding_dim=dict(type=int, low=32, high=1024, q=16),
    )

    def __init__(
        self,
        *,
        embedding_dim: int = 200,
        entity_initializer: Hint[Initializer] = xavier_uniform_,
        relation_initializer: Hint[Initializer] = init_phases,
        relation_constrainer: Hint[Constrainer] = complex_normalize,
        **kwargs,
    ) -> None:
        """Initialize the RotatE model.

        :param embedding_dim: The number of complex dimensions of entity and relation embeddings.
        :param entity_initializer: Initializer for the (complex-valued) entity embeddings.
        :param relation_initializer: Initializer for the (complex-valued) relation embeddings.
        :param relation_constrainer: Constrainer applied to the relation embeddings. The default,
            ``complex_normalize``, is presumably what enforces the unit-modulus restriction described
            in the class docstring.
        :param kwargs: Remaining keyword arguments, forwarded to the base class constructor.
        """
        super().__init__(
            entity_representations=EmbeddingSpecification(
                embedding_dim=embedding_dim,
                initializer=entity_initializer,
                dtype=torch.cfloat,
            ),
            relation_representations=EmbeddingSpecification(
                embedding_dim=embedding_dim,
                initializer=relation_initializer,
                constrainer=relation_constrainer,
                dtype=torch.cfloat,
            ),
            **kwargs,
        )
        # Number of complex dimensions d. The scoring methods view embeddings as (..., d, 2)
        # tensors of (real, imag) pairs. They also flatten them with ``self.embedding_dim``, so
        # presumably self.embedding_dim == 2 * d (TODO: confirm against the base class).
        self.real_embedding_dim = embedding_dim

    @staticmethod
    def interaction_function(
        h: torch.FloatTensor,
        r: torch.FloatTensor,
        t: torch.FloatTensor,
    ) -> torch.FloatTensor:
        """Evaluate the interaction function of RotatE for given embeddings.

        Computes ``-||h * r - t||``, where ``*`` is the element-wise complex product and the norm is
        the L2 norm over all complex dimensions.

        The embeddings have to be in a broadcastable shape.

        WARNING: No forward constraints are applied.

        :param h: shape: (..., e, 2)
            Head embeddings. Last dimension corresponds to (real, imag).
        :param r: shape: (..., e, 2)
            Relation embeddings. Last dimension corresponds to (real, imag).
        :param t: shape: (..., e, 2)
            Tail embeddings. Last dimension corresponds to (real, imag).

        :return: shape: (...)
            The scores.
        """
        # Decompose into real and imaginary part
        h_re = h[..., 0]
        h_im = h[..., 1]
        r_re = r[..., 0]
        r_im = r[..., 1]

        # Rotate (=Hadamard product in complex space):
        # (a + bi)(c + di) = (ac - bd) + (ad + bc)i
        rot_h = torch.stack(
            [
                h_re * r_re - h_im * r_im,
                h_re * r_im + h_im * r_re,
            ],
            dim=-1,
        )
        # Workaround until https://github.com/pytorch/pytorch/issues/30704 is fixed: the complex
        # L2 norm is computed on the real (real, imag) representation instead. Taking the vector
        # norm jointly over the last two dims yields sqrt(sum_i re_i^2 + im_i^2), i.e. the complex norm.
        diff = rot_h - t
        scores = -linalg.vector_norm(diff, dim=(-2, -1))
        return scores

    def score_hrt(self, hrt_batch: torch.LongTensor) -> torch.FloatTensor:  # noqa: D102
        # Get embeddings, viewed as (batch, d, 2) (real, imag) pairs
        h = self.entity_embeddings(indices=hrt_batch[:, 0]).view(-1, self.real_embedding_dim, 2)
        r = self.relation_embeddings(indices=hrt_batch[:, 1]).view(-1, self.real_embedding_dim, 2)
        t = self.entity_embeddings(indices=hrt_batch[:, 2]).view(-1, self.real_embedding_dim, 2)

        # Compute scores; shape: (batch, 1)
        scores = self.interaction_function(h=h, r=r, t=t).view(-1, 1)

        # Embedding Regularization (entities only; relations are handled by the constrainer)
        self.regularize_if_necessary(h.view(-1, self.embedding_dim), t.view(-1, self.embedding_dim))

        return scores

    def score_t(self, hr_batch: torch.LongTensor) -> torch.FloatTensor:  # noqa: D102
        # Get embeddings; the singleton axis broadcasts against all candidate tails
        h = self.entity_embeddings(indices=hr_batch[:, 0]).view(-1, 1, self.real_embedding_dim, 2)
        r = self.relation_embeddings(indices=hr_batch[:, 1]).view(-1, 1, self.real_embedding_dim, 2)

        # Rank against all entities
        t = self.entity_embeddings(indices=None).view(1, -1, self.real_embedding_dim, 2)

        # Compute scores; shape: (batch, num_entities)
        scores = self.interaction_function(h=h, r=r, t=t)

        # Embedding Regularization
        self.regularize_if_necessary(h.view(-1, self.embedding_dim), t.view(-1, self.embedding_dim))

        return scores

    def score_h(self, rt_batch: torch.LongTensor) -> torch.FloatTensor:  # noqa: D102
        # Get embeddings; the singleton axis broadcasts against all candidate heads
        r = self.relation_embeddings(indices=rt_batch[:, 0]).view(-1, 1, self.real_embedding_dim, 2)
        t = self.entity_embeddings(indices=rt_batch[:, 1]).view(-1, 1, self.real_embedding_dim, 2)

        # r expresses a rotation in complex plane.
        # The inverse rotation is expressed by the complex conjugate of r.
        # The score is computed as the distance of the relation-rotated head to the tail.
        # Equivalently, we can rotate the tail by the inverse relation, and measure the distance to the head, i.e.
        # |h * r - t| = |h - conj(r) * t|
        # This equality requires |r_i| = 1 for every element; otherwise these scores would differ
        # from score_hrt.
        r_inv = torch.stack([r[..., 0], -r[..., 1]], dim=-1)

        # Rank against all entities
        h = self.entity_embeddings(indices=None).view(1, -1, self.real_embedding_dim, 2)

        # Compute scores; shape: (batch, num_entities)
        scores = self.interaction_function(h=t, r=r_inv, t=h)

        # Embedding Regularization
        self.regularize_if_necessary(h.view(-1, self.embedding_dim), t.view(-1, self.embedding_dim))

        return scores