Source code for pykeen.models.unimodal.rotate

# -*- coding: utf-8 -*-

"""Implementation of the RotatE model."""

from typing import Any, ClassVar, Mapping, Optional

import torch
import torch.autograd

from ..base import EntityRelationEmbeddingModel
from ...losses import Loss
from ...nn import EmbeddingSpecification
from ...nn.init import init_phases, xavier_uniform_
from ...nn.norm import complex_normalize
from ...regularizers import Regularizer
from ...triples import TriplesFactory
from ...typing import Constrainer, DeviceHint, Hint, Initializer

__all__ = [
    "RotatE",
]

class RotatE(EntityRelationEmbeddingModel):
    r"""An implementation of RotatE from [sun2019]_.

    RotatE models relations as rotations from head to tail entities in complex space:

    .. math::

        \textbf{e}_t = \textbf{e}_h \odot \textbf{r}_r

    where $\textbf{e}, \textbf{r} \in \mathbb{C}^{d}$ and every complex element of $\textbf{r}_r$ is
    restricted to have a modulus of one ($|[\textbf{r}_r]_i| = 1$ for all $i$). The interaction model
    is then defined as:

    .. math::

        f(h,r,t) = -\|\textbf{e}_h \odot \textbf{r}_r - \textbf{e}_t\|

    which allows it to model symmetry, antisymmetry, inversion, and composition.

    .. seealso::

       - Authors' `implementation of RotatE
         <https://github.com/DeepGraphLearning/KnowledgeGraphEmbedding>`_
    """

    #: The default strategy for optimizing the model's hyper-parameters
    hpo_default: ClassVar[Mapping[str, Any]] = dict(
        embedding_dim=dict(type=int, low=32, high=1024, q=16),
    )

    def __init__(
        self,
        triples_factory: TriplesFactory,
        embedding_dim: int = 200,
        loss: Optional[Loss] = None,
        preferred_device: DeviceHint = None,
        random_seed: Optional[int] = None,
        regularizer: Optional[Regularizer] = None,
        entity_initializer: Hint[Initializer] = xavier_uniform_,
        relation_initializer: Hint[Initializer] = init_phases,
        relation_constrainer: Hint[Constrainer] = complex_normalize,
    ) -> None:
        """Initialize the RotatE model.

        :param triples_factory: The triples factory providing the training triples.
        :param embedding_dim: The number of complex components of each entity / relation embedding.
        :param loss: The loss to use; ``None`` lets the base class choose.
        :param preferred_device: The device on which the model should be placed.
        :param random_seed: An optional random seed.
        :param regularizer: An optional regularizer, applied to head/tail embeddings in the score methods.
        :param entity_initializer: Initializer for the (complex) entity embeddings.
        :param relation_initializer: Initializer for the (complex) relation embeddings.
        :param relation_constrainer: Constrainer applied to the relation embeddings
            (defaults to ``complex_normalize``).
        """
        super().__init__(
            triples_factory=triples_factory,
            loss=loss,
            preferred_device=preferred_device,
            random_seed=random_seed,
            regularizer=regularizer,
            entity_representations=EmbeddingSpecification(
                embedding_dim=embedding_dim,
                initializer=entity_initializer,
                dtype=torch.cfloat,
            ),
            relation_representations=EmbeddingSpecification(
                embedding_dim=embedding_dim,
                initializer=relation_initializer,
                constrainer=relation_constrainer,
                dtype=torch.cfloat,
            ),
        )
        # Number of complex components per embedding. The score methods view every embedding
        # as (..., real_embedding_dim, 2), i.e. as (real, imag) pairs.
        self.real_embedding_dim = embedding_dim

    @staticmethod
    def interaction_function(
        h: torch.FloatTensor,
        r: torch.FloatTensor,
        t: torch.FloatTensor,
    ) -> torch.FloatTensor:
        """Evaluate the RotatE interaction function for given embeddings.

        The embeddings have to be in a broadcastable shape.

        WARNING: No forward constraints are applied.

        :param h: shape: (..., e, 2)
            Head embeddings. Last dimension corresponds to (real, imag).
        :param r: shape: (..., e, 2)
            Relation embeddings. Last dimension corresponds to (real, imag).
        :param t: shape: (..., e, 2)
            Tail embeddings. Last dimension corresponds to (real, imag).

        :return: shape: (...)
            The scores, i.e. the negative Euclidean distance between the rotated head and the tail.
        """
        # Decompose into real and imaginary part
        h_re = h[..., 0]
        h_im = h[..., 1]
        r_re = r[..., 0]
        r_im = r[..., 1]

        # Rotate (=Hadamard product in complex space).
        rot_h = torch.stack(
            [
                h_re * r_re - h_im * r_im,
                h_re * r_im + h_im * r_re,
            ],
            dim=-1,
        )

        # Flatten the trailing (e, 2) dimensions and take a single-dimension norm instead of a
        # norm over two dimensions. The 2-norm over all real/imag parts equals the norm of the
        # complex difference vector. (The original comment referenced a PyTorch issue as the
        # reason for this workaround; that reference was lost.)
        diff = rot_h - t
        scores = -torch.norm(diff.flatten(start_dim=-2), dim=-1)
        return scores

    def score_hrt(self, hrt_batch: torch.LongTensor) -> torch.FloatTensor:  # noqa: D102
        # Get embeddings, viewed as (batch, real_embedding_dim, 2) (real, imag) pairs
        h = self.entity_embeddings(indices=hrt_batch[:, 0]).view(-1, self.real_embedding_dim, 2)
        r = self.relation_embeddings(indices=hrt_batch[:, 1]).view(-1, self.real_embedding_dim, 2)
        t = self.entity_embeddings(indices=hrt_batch[:, 2]).view(-1, self.real_embedding_dim, 2)

        # Compute scores; shape: (batch, 1)
        scores = self.interaction_function(h=h, r=r, t=t).view(-1, 1)

        # Embedding Regularization (head and tail only)
        self.regularize_if_necessary(h.view(-1, self.embedding_dim), t.view(-1, self.embedding_dim))

        return scores

    def score_t(self, hr_batch: torch.LongTensor) -> torch.FloatTensor:  # noqa: D102
        # Get embeddings; the extra size-1 axis broadcasts against all candidate tails
        h = self.entity_embeddings(indices=hr_batch[:, 0]).view(-1, 1, self.real_embedding_dim, 2)
        r = self.relation_embeddings(indices=hr_batch[:, 1]).view(-1, 1, self.real_embedding_dim, 2)

        # Rank against all entities; shape: (1, num_entities, real_embedding_dim, 2)
        t = self.entity_embeddings(indices=None).view(1, -1, self.real_embedding_dim, 2)

        # Compute scores; shape: (batch, num_entities)
        scores = self.interaction_function(h=h, r=r, t=t)

        # Embedding Regularization (note: includes the embeddings of all entities)
        self.regularize_if_necessary(h.view(-1, self.embedding_dim), t.view(-1, self.embedding_dim))

        return scores

    def score_h(self, rt_batch: torch.LongTensor) -> torch.FloatTensor:  # noqa: D102
        # Get embeddings; the extra size-1 axis broadcasts against all candidate heads
        r = self.relation_embeddings(indices=rt_batch[:, 0]).view(-1, 1, self.real_embedding_dim, 2)
        t = self.entity_embeddings(indices=rt_batch[:, 1]).view(-1, 1, self.real_embedding_dim, 2)

        # r expresses a rotation in complex plane.
        # The inverse rotation is expressed by the complex conjugate of r.
        # The score is computed as the distance of the relation-rotated head to the tail.
        # Equivalently, we can rotate the tail by the inverse relation, and measure the distance to the head, i.e.
        # |h * r - t| = |h - conj(r) * t|
        # (this equality relies on every component of r having modulus one).
        r_inv = torch.stack([r[..., 0], -r[..., 1]], dim=-1)

        # Rank against all entities; shape: (1, num_entities, real_embedding_dim, 2)
        h = self.entity_embeddings(indices=None).view(1, -1, self.real_embedding_dim, 2)

        # Compute scores; the tail takes the "head" role and the candidate heads the "tail" role.
        # shape: (batch, num_entities)
        scores = self.interaction_function(h=t, r=r_inv, t=h)

        # Embedding Regularization (note: includes the embeddings of all entities)
        self.regularize_if_necessary(h.view(-1, self.embedding_dim), t.view(-1, self.embedding_dim))

        return scores