# Source code for pykeen.models.unimodal.trans_e

# -*- coding: utf-8 -*-

"""TransE."""

from typing import Optional

import torch
from torch.nn import functional

from ..base import EntityRelationEmbeddingModel
from ..init import embedding_xavier_uniform_
from ...losses import Loss
from ...regularizers import Regularizer
from ...triples import TriplesFactory

# Names exported by a star-import of this module.
__all__ = ['TransE']

class TransE(EntityRelationEmbeddingModel):
    r"""TransE models relations as a translation from head to tail entities in :math:`\textbf{e}` [bordes2013]_.

    .. math::

        \textbf{e}_h + \textbf{e}_r \approx \textbf{e}_t

    This equation is rearranged and the :math:`l_p` norm is applied to create the TransE interaction function.

    .. math::

        f(h, r, t) = - \|\textbf{e}_h + \textbf{e}_r - \textbf{e}_t\|_{p}

    While this formulation is computationally efficient, it inherently cannot model one-to-many, many-to-one, and
    many-to-many relationships. For triples :math:`(h,r,t_1), (h,r,t_2) \in \mathcal{K}` where :math:`t_1 \neq t_2`,
    the model adapts the embeddings in order to ensure :math:`\textbf{e}_h + \textbf{e}_r \approx \textbf{e}_{t_1}`
    and :math:`\textbf{e}_h + \textbf{e}_r \approx \textbf{e}_{t_2}` which results in
    :math:`\textbf{e}_{t_1} \approx \textbf{e}_{t_2}`.
    """

    #: The default strategy for optimizing the model's hyper-parameters
    hpo_default = dict(
        embedding_dim=dict(type=int, low=50, high=300, q=50),
        scoring_fct_norm=dict(type=int, low=1, high=2),
    )

    def __init__(
        self,
        triples_factory: TriplesFactory,
        embedding_dim: int = 50,
        automatic_memory_optimization: Optional[bool] = None,
        scoring_fct_norm: int = 1,
        loss: Optional[Loss] = None,
        preferred_device: Optional[str] = None,
        random_seed: Optional[int] = None,
        regularizer: Optional[Regularizer] = None,
    ) -> None:
        r"""Initialize TransE.

        :param triples_factory: Forwarded unchanged to the base class constructor.
        :param embedding_dim: The entity embedding dimension $d$. Is usually $d \in [50, 300]$.
        :param automatic_memory_optimization: Forwarded unchanged to the base class constructor.
        :param scoring_fct_norm: The :math:`l_p` norm applied in the interaction function. Is usually 1 or 2.
        :param loss: Forwarded unchanged to the base class constructor.
        :param preferred_device: Forwarded unchanged to the base class constructor.
        :param random_seed: Forwarded unchanged to the base class constructor.
        :param regularizer: Forwarded unchanged to the base class constructor.

        .. seealso::

           - `OpenKE implementation of TransE
             <https://github.com/thunlp/OpenKE/blob/OpenKE-PyTorch/models/TransE.py>`_
        """
        super().__init__(
            triples_factory=triples_factory,
            embedding_dim=embedding_dim,
            automatic_memory_optimization=automatic_memory_optimization,
            loss=loss,
            preferred_device=preferred_device,
            random_seed=random_seed,
            regularizer=regularizer,
        )
        # The order p of the norm used by all score_* methods.
        self.scoring_fct_norm = scoring_fct_norm

        # Finalize initialization
        self.reset_parameters_()

    def _reset_parameters_(self):
        """Re-initialize the entity and relation embeddings in place."""
        embedding_xavier_uniform_(self.entity_embeddings)
        embedding_xavier_uniform_(self.relation_embeddings)
        # Initialise relation embeddings to unit length. ``out=`` writes the result back into the
        # same storage, so the parameter tensor itself is modified rather than replaced.
        functional.normalize(self.relation_embeddings.weight.data, out=self.relation_embeddings.weight.data)

    def post_parameter_update(self) -> None:
        """Re-normalize the entity embeddings after the parameters have been updated."""
        # Make sure to call super first
        super().post_parameter_update()

        # Normalize entity embeddings in place (``functional.normalize`` defaults: L2 norm along dim 1).
        functional.normalize(self.entity_embeddings.weight.data, out=self.entity_embeddings.weight.data)

    def score_hrt(self, hrt_batch: torch.LongTensor) -> torch.FloatTensor:
        """Score each (head, relation, tail) triple of the batch.

        :param hrt_batch: 2-D index tensor; column 0 is used to look up head entities, column 1 relations
            and column 2 tail entities.

        :return: The negative ``scoring_fct_norm``-norm of ``h + r - t`` taken over the last dimension, with
            that dimension kept (``keepdim=True``), i.e. one score per row of ``hrt_batch``.
        """
        # Get embeddings
        h = self.entity_embeddings(hrt_batch[:, 0])
        r = self.relation_embeddings(hrt_batch[:, 1])
        t = self.entity_embeddings(hrt_batch[:, 2])

        return -torch.norm(h + r - t, dim=-1, p=self.scoring_fct_norm, keepdim=True)

    def score_t(self, hr_batch: torch.LongTensor) -> torch.FloatTensor:
        """Score every entity as tail candidate for each (head, relation) pair of the batch.

        :param hr_batch: 2-D index tensor; column 0 is used to look up head entities and column 1 relations.

        :return: The negative ``scoring_fct_norm``-norm of ``h + r - t`` over the last dimension, where ``t``
            ranges over all rows of the entity embedding matrix. Row ``i``, column ``j`` is the score of
            batch element ``i`` paired with entity ``j``.
        """
        # Get embeddings
        h = self.entity_embeddings(hr_batch[:, 0])
        r = self.relation_embeddings(hr_batch[:, 1])
        # All entity embeddings serve as tail candidates.
        t = self.entity_embeddings.weight

        # Insert singleton axes so that batch elements broadcast against all entities.
        return -torch.norm(h[:, None, :] + r[:, None, :] - t[None, :, :], dim=-1, p=self.scoring_fct_norm)

    def score_h(self, rt_batch: torch.LongTensor) -> torch.FloatTensor:
        """Score every entity as head candidate for each (relation, tail) pair of the batch.

        :param rt_batch: 2-D index tensor; column 0 is used to look up relations and column 1 tail entities.

        :return: The negative ``scoring_fct_norm``-norm of ``h + r - t`` over the last dimension, where ``h``
            ranges over all rows of the entity embedding matrix. Row ``i``, column ``j`` is the score of
            batch element ``i`` paired with entity ``j``.
        """
        # Get embeddings
        # All entity embeddings serve as head candidates.
        h = self.entity_embeddings.weight
        r = self.relation_embeddings(rt_batch[:, 0])
        t = self.entity_embeddings(rt_batch[:, 1])

        # Insert singleton axes so that all entities broadcast against the batch elements.
        return -torch.norm(h[None, :, :] + r[:, None, :] - t[:, None, :], dim=-1, p=self.scoring_fct_norm)