# -*- coding: utf-8 -*-
"""Implementation of the RotatE model."""
from typing import Optional
import numpy as np
import torch
import torch.autograd
from torch.nn import functional
from ..base import EntityRelationEmbeddingModel
from ..init import embedding_xavier_uniform_
from ...losses import Loss
from ...regularizers import Regularizer
from ...triples import TriplesFactory
# Only the model class is exported by ``from <module> import *``.
__all__ = ['RotatE']
class RotatE(EntityRelationEmbeddingModel):
    r"""An implementation of RotatE from [sun2019]_.

    RotatE models relations as rotations from head to tail entities in complex space:

    .. math::

        \textbf{e}_t= \textbf{e}_h \odot \textbf{r}_r

    where $\textbf{e}, \textbf{r} \in \mathbb{C}^{d}$ and the complex elements of
    $\textbf{r}_r$ are restricted to have a modulus of one ($\|\textbf{r}_r\| = 1$). The
    interaction model is then defined as:

    .. math::

        f(h,r,t) = -\|\textbf{e}_h \odot \textbf{r}_r - \textbf{e}_t\|

    which allows modeling symmetry, antisymmetry, inversion, and composition.

    Complex vectors are stored as real tensors: the base class is created with
    ``2 * embedding_dim`` real dimensions, and every method views an embedding as
    ``(..., real_embedding_dim, 2)`` where the last axis holds ``(real, imag)``.

    .. seealso::

       - Authors' `implementation of RotatE
         <https://github.com/DeepGraphLearning/KnowledgeGraphEmbedding/blob/master/codes/model.py#L200-L228>`_
    """

    #: The default strategy for optimizing the model's hyper-parameters
    hpo_default = dict(
        embedding_dim=dict(type=int, low=125, high=1000, q=100),
    )

    def __init__(
        self,
        triples_factory: TriplesFactory,
        embedding_dim: int = 200,
        automatic_memory_optimization: Optional[bool] = None,
        loss: Optional[Loss] = None,
        preferred_device: Optional[str] = None,
        random_seed: Optional[int] = None,
        regularizer: Optional[Regularizer] = None,
    ) -> None:
        """Initialize the model.

        :param triples_factory: Forwarded unchanged to the base class constructor.
        :param embedding_dim: The number of *complex* dimensions. The base class receives
            ``2 * embedding_dim`` because each complex number is stored as a (real, imag) pair;
            the complex dimension itself is kept in ``self.real_embedding_dim``.
        :param automatic_memory_optimization: Forwarded unchanged to the base class constructor.
        :param loss: Forwarded unchanged to the base class constructor.
        :param preferred_device: Forwarded unchanged to the base class constructor.
        :param random_seed: Forwarded unchanged to the base class constructor.
        :param regularizer: Forwarded unchanged to the base class constructor.
        """
        super().__init__(
            triples_factory=triples_factory,
            embedding_dim=2 * embedding_dim,
            loss=loss,
            automatic_memory_optimization=automatic_memory_optimization,
            preferred_device=preferred_device,
            random_seed=random_seed,
            regularizer=regularizer,
        )
        self.real_embedding_dim = embedding_dim

        # Finalize initialization
        self.reset_parameters_()

    def _reset_parameters_(self):
        """Re-initialize the entity embeddings and draw random unit-modulus relation embeddings."""
        embedding_xavier_uniform_(self.entity_embeddings)

        # phases randomly between 0 and 2 pi
        phases = 2 * np.pi * torch.rand(self.num_relations, self.real_embedding_dim, device=self.device)
        # (cos, sin) of a phase is a point on the unit circle, i.e. a complex number of modulus one.
        relations = torch.stack([torch.cos(phases), torch.sin(phases)], dim=-1).detach()
        # Sanity check: every (real, imag) pair has unit modulus.
        assert torch.allclose(torch.norm(relations, p=2, dim=-1), phases.new_ones(size=(1, 1)))
        # Flatten the (real, imag) pairs back into the storage layout of the embedding weight.
        self.relation_embeddings.weight.data = relations.view(self.num_relations, self.embedding_dim)

    def post_parameter_update(self):
        r"""Normalize the length of relation vectors, if the forward constraint has not been applied yet.

        The relation weights are viewed as ``(num_relations, real_embedding_dim, 2)`` and each
        (real, imag) pair is $l_2$-normalized along the last dimension, i.e. every complex element
        is rescaled by its modulus.

        The `modulus of complex number <https://en.wikipedia.org/wiki/Absolute_value#Complex_numbers>`_ is given as:

        .. math::

            |a + ib| = \sqrt{a^2 + b^2}

        $l_2$ norm of complex vector $x \in \mathbb{C}^d$:

        .. math::

            \|x\|^2 = \sum_{i=1}^d |x_i|^2
                     = \sum_{i=1}^d \left(\operatorname{Re}(x_i)^2 + \operatorname{Im}(x_i)^2\right)
                     = \left(\sum_{i=1}^d \operatorname{Re}(x_i)^2\right) + \left(\sum_{i=1}^d \operatorname{Im}(x_i)^2\right)
                     = \|\operatorname{Re}(x)\|^2 + \|\operatorname{Im}(x)\|^2
                     = \| [\operatorname{Re}(x); \operatorname{Im}(x)] \|^2
        """
        # Make sure to call super first
        super().post_parameter_update()

        # Normalize relation embeddings
        rel = self.relation_embeddings.weight.data.view(self.num_relations, self.real_embedding_dim, 2)
        rel = functional.normalize(rel, p=2, dim=-1)
        self.relation_embeddings.weight.data = rel.view(self.num_relations, self.embedding_dim)

    @staticmethod
    def interaction_function(
        h: torch.FloatTensor,
        r: torch.FloatTensor,
        t: torch.FloatTensor,
    ) -> torch.FloatTensor:
        """Evaluate the interaction function of RotatE for given embeddings.

        The embeddings have to be in a broadcastable shape.

        WARNING: No forward constraints are applied.

        :param h: shape: (..., e, 2)
            Head embeddings. Last dimension corresponds to (real, imag).
        :param r: shape: (..., e, 2)
            Relation embeddings. Last dimension corresponds to (real, imag).
        :param t: shape: (..., e, 2)
            Tail embeddings. Last dimension corresponds to (real, imag).

        :return: shape: (...)
            The scores.
        """
        # Decompose into real and imaginary part
        h_re = h[..., 0]
        h_im = h[..., 1]
        r_re = r[..., 0]
        r_im = r[..., 1]

        # Rotate (=Hadamard product in complex space).
        # (a + ib)(c + id) = (ac - bd) + i(ad + bc)
        rot_h = torch.stack(
            [
                h_re * r_re - h_im * r_im,
                h_re * r_im + h_im * r_re,
            ],
            dim=-1,
        )
        # Workaround until https://github.com/pytorch/pytorch/issues/30704 is fixed:
        # merge the trailing (e, 2) axes so the norm is taken over a single dimension.
        diff = rot_h - t
        scores = -torch.norm(diff.view(diff.shape[:-2] + (-1,)), dim=-1)

        return scores

    def score_hrt(self, hrt_batch: torch.LongTensor) -> torch.FloatTensor:
        """Score each (head, relation, tail) triple of a batch.

        :param hrt_batch: shape: (batch_size, 3)
            Columns 0, 1 and 2 are used as head, relation and tail indices.

        :return: shape: (batch_size, 1)
            The scores.
        """
        # Get embeddings
        h = self.entity_embeddings(hrt_batch[:, 0]).view(-1, self.real_embedding_dim, 2)
        r = self.relation_embeddings(hrt_batch[:, 1]).view(-1, self.real_embedding_dim, 2)
        t = self.entity_embeddings(hrt_batch[:, 2]).view(-1, self.real_embedding_dim, 2)

        # Compute scores
        scores = self.interaction_function(h=h, r=r, t=t).view(-1, 1)

        # Embedding Regularization
        self.regularize_if_necessary(h.view(-1, self.embedding_dim), t.view(-1, self.embedding_dim))

        return scores

    def score_t(self, hr_batch: torch.LongTensor) -> torch.FloatTensor:
        """Score every entity as a tail candidate for each (head, relation) pair.

        :param hr_batch: shape: (batch_size, 2)
            Columns 0 and 1 are used as head and relation indices.

        :return: shape: (batch_size, n)
            The scores, where n is the number of rows of the entity embedding weight.
        """
        # Get embeddings
        h = self.entity_embeddings(hr_batch[:, 0]).view(-1, 1, self.real_embedding_dim, 2)
        r = self.relation_embeddings(hr_batch[:, 1]).view(-1, 1, self.real_embedding_dim, 2)

        # Rank against all entities
        t = self.entity_embeddings.weight.view(1, -1, self.real_embedding_dim, 2)

        # Compute scores
        scores = self.interaction_function(h=h, r=r, t=t)

        # Embedding Regularization
        self.regularize_if_necessary(h.view(-1, self.embedding_dim), t.view(-1, self.embedding_dim))

        return scores

    def score_h(self, rt_batch: torch.LongTensor) -> torch.FloatTensor:
        """Score every entity as a head candidate for each (relation, tail) pair.

        :param rt_batch: shape: (batch_size, 2)
            Columns 0 and 1 are used as relation and tail indices.

        :return: shape: (batch_size, n)
            The scores, where n is the number of rows of the entity embedding weight.
        """
        # Get embeddings
        r = self.relation_embeddings(rt_batch[:, 0]).view(-1, 1, self.real_embedding_dim, 2)
        t = self.entity_embeddings(rt_batch[:, 1]).view(-1, 1, self.real_embedding_dim, 2)

        # r expresses a rotation in complex plane.
        # The inverse rotation is expressed by the complex conjugate of r.
        # The score is computed as the distance of the relation-rotated head to the tail.
        # Equivalently, we can rotate the tail by the inverse relation, and measure the distance to the head, i.e.
        # |h * r - t| = |h - conj(r) * t|
        r_inv = torch.stack([r[:, :, :, 0], -r[:, :, :, 1]], dim=-1)

        # Rank against all entities
        h = self.entity_embeddings.weight.view(1, -1, self.real_embedding_dim, 2)

        # Compute scores (roles swapped on purpose: the tail is rotated by conj(r) and compared to each head)
        scores = self.interaction_function(h=t, r=r_inv, t=h)

        # Embedding Regularization
        self.regularize_if_necessary(h.view(-1, self.embedding_dim), t.view(-1, self.embedding_dim))

        return scores