# Source code for pykeen.models.unimodal.rotate

# -*- coding: utf-8 -*-

"""Implementation of the RotatE model."""

from typing import Any, ClassVar, Mapping

import torch
import torch.autograd

from ..base import EntityRelationEmbeddingModel
from ...nn.emb import EmbeddingSpecification
from ...nn.init import init_phases, xavier_uniform_
from ...typing import Constrainer, Hint, Initializer
from ...utils import complex_normalize

__all__ = [
    'RotatE',
]


class RotatE(EntityRelationEmbeddingModel):
    r"""An implementation of RotatE from [sun2019]_.

    RotatE models relations as rotations from head to tail entities in complex space:

    .. math::

        \textbf{e}_t= \textbf{e}_h \odot \textbf{r}_r

    where $\textbf{e}, \textbf{r} \in \mathbb{C}^{d}$ and each complex element of $\textbf{r}_r$ is
    restricted to have a modulus of one ($|\textbf{r}_{r,i}| = 1$ for every $i$).
    The interaction model is then defined as:

    .. math::

        f(h,r,t) = -\|\textbf{e}_h \odot \textbf{r}_r - \textbf{e}_t\|

    which allows modeling symmetry, antisymmetry, inversion, and composition.

    .. seealso::

       - Authors' `implementation of RotatE
         <https://github.com/DeepGraphLearning/KnowledgeGraphEmbedding/blob/master/codes/model.py#L200-L228>`_
    ---
    citation:
        author: Sun
        year: 2019
        link: https://arxiv.org/abs/1902.10197v1
        github: DeepGraphLearning/KnowledgeGraphEmbedding
    """

    #: The default strategy for optimizing the model's hyper-parameters
    hpo_default: ClassVar[Mapping[str, Any]] = dict(
        embedding_dim=dict(type=int, low=32, high=1024, q=16),
    )

    def __init__(
        self,
        *,
        embedding_dim: int = 200,
        entity_initializer: Hint[Initializer] = xavier_uniform_,
        relation_initializer: Hint[Initializer] = init_phases,
        relation_constrainer: Hint[Constrainer] = complex_normalize,
        **kwargs,
    ) -> None:
        """Initialize RotatE.

        :param embedding_dim: The number of complex dimensions $d$ of entity and relation embeddings.
        :param entity_initializer: Initializer for the (complex) entity embeddings.
        :param relation_initializer: Initializer for the (complex) relation embeddings.
        :param relation_constrainer: Constrainer applied to the relation embeddings. The default is
            presumably what keeps relation elements at unit modulus, which ``score_h`` relies on.
        :param kwargs: Remaining keyword arguments forwarded to :class:`EntityRelationEmbeddingModel`.
        """
        super().__init__(
            entity_representations=EmbeddingSpecification(
                embedding_dim=embedding_dim,
                initializer=entity_initializer,
                dtype=torch.cfloat,
            ),
            relation_representations=EmbeddingSpecification(
                embedding_dim=embedding_dim,
                initializer=relation_initializer,
                constrainer=relation_constrainer,
                dtype=torch.cfloat,
            ),
            **kwargs,
        )
        # Number of complex dimensions; the scoring methods view each embedding as
        # shape (..., real_embedding_dim, 2) with the last axis being (real, imag).
        self.real_embedding_dim = embedding_dim

    @staticmethod
    def interaction_function(
        h: torch.FloatTensor,
        r: torch.FloatTensor,
        t: torch.FloatTensor,
    ) -> torch.FloatTensor:
        """Evaluate the interaction function of RotatE for given embeddings.

        The embeddings have to be in a broadcastable shape.

        WARNING: No forward constraints are applied.

        :param h: shape: (..., e, 2)
            Head embeddings. Last dimension corresponds to (real, imag).
        :param r: shape: (..., e, 2)
            Relation embeddings. Last dimension corresponds to (real, imag).
        :param t: shape: (..., e, 2)
            Tail embeddings. Last dimension corresponds to (real, imag).

        :return: shape: (...)
            The scores, i.e. the negative L2 distance between the rotated head and the tail.
        """
        # Decompose into real and imaginary part
        h_re = h[..., 0]
        h_im = h[..., 1]
        r_re = r[..., 0]
        r_im = r[..., 1]

        # Rotate (=Hadamard product in complex space).
        rot_h = torch.stack(
            [
                h_re * r_re - h_im * r_im,
                h_re * r_im + h_im * r_re,
            ],
            dim=-1,
        )
        # Workaround until https://github.com/pytorch/pytorch/issues/30704 is fixed:
        # flatten the trailing (e, 2) dimensions and take a single vector norm over them.
        # reshape (instead of view) is identical for contiguous data, but copies instead of
        # raising when a broadcast result is not contiguous.
        diff = rot_h - t
        scores = -torch.norm(diff.reshape(diff.shape[:-2] + (-1,)), dim=-1)
        return scores

    def score_hrt(self, hrt_batch: torch.LongTensor) -> torch.FloatTensor:  # noqa: D102
        # Get embeddings
        h = self.entity_embeddings(indices=hrt_batch[:, 0]).view(-1, self.real_embedding_dim, 2)
        r = self.relation_embeddings(indices=hrt_batch[:, 1]).view(-1, self.real_embedding_dim, 2)
        t = self.entity_embeddings(indices=hrt_batch[:, 2]).view(-1, self.real_embedding_dim, 2)

        # Compute scores
        scores = self.interaction_function(h=h, r=r, t=t).view(-1, 1)

        # Embedding Regularization
        self.regularize_if_necessary(h.view(-1, self.embedding_dim), t.view(-1, self.embedding_dim))

        return scores

    def score_t(self, hr_batch: torch.LongTensor) -> torch.FloatTensor:  # noqa: D102
        # Get embeddings
        h = self.entity_embeddings(indices=hr_batch[:, 0]).view(-1, 1, self.real_embedding_dim, 2)
        r = self.relation_embeddings(indices=hr_batch[:, 1]).view(-1, 1, self.real_embedding_dim, 2)

        # Rank against all entities
        t = self.entity_embeddings(indices=None).view(1, -1, self.real_embedding_dim, 2)

        # Compute scores; broadcasting (batch, 1, ...) against (1, num_entities, ...)
        scores = self.interaction_function(h=h, r=r, t=t)

        # Embedding Regularization
        self.regularize_if_necessary(h.view(-1, self.embedding_dim), t.view(-1, self.embedding_dim))

        return scores

    def score_h(self, rt_batch: torch.LongTensor) -> torch.FloatTensor:  # noqa: D102
        # Get embeddings
        r = self.relation_embeddings(indices=rt_batch[:, 0]).view(-1, 1, self.real_embedding_dim, 2)
        t = self.entity_embeddings(indices=rt_batch[:, 1]).view(-1, 1, self.real_embedding_dim, 2)

        # r expresses a rotation in the complex plane; the inverse rotation is given by conj(r).
        # Rather than rotating every candidate head by r, rotate the tail once by the inverse
        # relation and measure the distance to each head:
        #   |h * r - t| = |h - conj(r) * t|
        # NOTE: this identity only holds if every element of r has modulus one. Otherwise score_h
        # disagrees with score_hrt. The default relation constrainer is presumably what enforces it.
        r_inv = torch.stack([r[..., 0], -r[..., 1]], dim=-1)

        # Rank against all entities
        h = self.entity_embeddings(indices=None).view(1, -1, self.real_embedding_dim, 2)

        # Compute scores: the tail is passed in the "head" slot and the candidate heads in the "tail" slot
        scores = self.interaction_function(h=t, r=r_inv, t=h)

        # Embedding Regularization
        self.regularize_if_necessary(h.view(-1, self.embedding_dim), t.view(-1, self.embedding_dim))

        return scores