Source code for pykeen.models.unimodal.simple

# -*- coding: utf-8 -*-

"""Implementation of SimplE."""

from typing import Optional, Tuple, Union

import torch.autograd

from ..base import EntityRelationEmbeddingModel
from ...losses import Loss, SoftplusLoss
from ...regularizers import PowerSumRegularizer, Regularizer
from ...triples import TriplesFactory
from ...utils import get_embedding, get_embedding_in_canonical_shape

__all__ = [
    'SimplE',
]


class SimplE(EntityRelationEmbeddingModel):
    r"""An implementation of SimplE [kazemi2018]_.

    SimplE is an extension of canonical polyadic (CP), an early tensor factorization approach in which each
    entity $e \in \mathcal{E}$ is represented by two vectors $\textbf{h}_e, \textbf{t}_e \in \mathbb{R}^d$ and
    each relation by a single vector $\textbf{r}_r \in \mathbb{R}^d$. Depending on whether an entity participates
    in a triple as the head or tail entity, either $\textbf{h}$ or $\textbf{t}$ is used. Both entity
    representations are learned independently, i.e. observing a triple $(h,r,t)$, the method only updates
    $\textbf{h}_h$ and $\textbf{t}_t$. In contrast to CP, SimplE introduces for each relation $\textbf{r}_r$ the
    inverse relation $\textbf{r'}_r$, and formulates the interaction model based on both:

    .. math::

        f(h,r,t) = \frac{1}{2}\left(\left\langle\textbf{h}_h, \textbf{r}_r, \textbf{t}_t\right\rangle
        + \left\langle\textbf{h}_t, \textbf{r'}_r, \textbf{t}_h\right\rangle\right)

    Therefore, for each triple $(h,r,t) \in \mathbb{K}$, both $\textbf{h}_h$ and $\textbf{h}_t$ as well as
    $\textbf{t}_h$ and $\textbf{t}_t$ are updated.

    .. seealso::

       - Official implementation: https://github.com/Mehran-k/SimplE
       - Improved implementation in pytorch: https://github.com/baharefatemi/SimplE
    """

    #: The default strategy for optimizing the model's hyper-parameters
    hpo_default = dict(
        embedding_dim=dict(type=int, low=50, high=350, q=25),
    )
    #: The default loss function class
    loss_default = SoftplusLoss
    #: The default parameters for the default loss function class
    loss_default_kwargs = {}
    #: The regularizer used by [trouillon2016]_ for SimplE
    #: In the paper, they use weight of 0.1, and do not normalize the
    #: regularization term by the number of elements, which is 200.
    regularizer_default = PowerSumRegularizer
    #: The power sum settings used by [trouillon2016]_ for SimplE
    regularizer_default_kwargs = dict(
        weight=20,
        p=2.0,
        normalize=True,
    )

    def __init__(
        self,
        triples_factory: TriplesFactory,
        embedding_dim: int = 200,
        automatic_memory_optimization: Optional[bool] = None,
        loss: Optional[Loss] = None,
        preferred_device: Optional[str] = None,
        random_seed: Optional[int] = None,
        regularizer: Optional[Regularizer] = None,
        clamp_score: Optional[Union[float, Tuple[float, float]]] = None,
    ) -> None:
        super().__init__(
            triples_factory=triples_factory,
            embedding_dim=embedding_dim,
            automatic_memory_optimization=automatic_memory_optimization,
            loss=loss,
            preferred_device=preferred_device,
            random_seed=random_seed,
            regularizer=regularizer,
        )

        # extra embeddings
        self.tail_entity_embeddings = get_embedding(
            num_embeddings=triples_factory.num_entities,
            embedding_dim=embedding_dim,
            device=self.device,
        )
        self.inverse_relation_embeddings = get_embedding(
            num_embeddings=triples_factory.num_relations,
            embedding_dim=embedding_dim,
            device=self.device,
        )

        if isinstance(clamp_score, float):
            clamp_score = (-clamp_score, clamp_score)
        self.clamp = clamp_score

        # Finalize initialization
        self.reset_parameters_()

    def _reset_parameters_(self):  # noqa: D102
        for emb in [
            self.entity_embeddings,
            self.tail_entity_embeddings,
            self.relation_embeddings,
            self.inverse_relation_embeddings,
        ]:
            emb.reset_parameters()

    def _score(self, h_ind: torch.LongTensor, r_ind: torch.LongTensor, t_ind: torch.LongTensor) -> torch.FloatTensor:
        # forward model: <h_h, r_r, t_t>
        h = get_embedding_in_canonical_shape(embedding=self.entity_embeddings, ind=h_ind)
        r = get_embedding_in_canonical_shape(embedding=self.relation_embeddings, ind=r_ind)
        t = get_embedding_in_canonical_shape(embedding=self.tail_entity_embeddings, ind=t_ind)
        scores = (h * r * t).sum(dim=-1)

        # Regularization
        self.regularize_if_necessary(h, r, t)

        # backward model: <h_t, r'_r, t_h>
        h = get_embedding_in_canonical_shape(embedding=self.entity_embeddings, ind=t_ind)
        r = get_embedding_in_canonical_shape(embedding=self.inverse_relation_embeddings, ind=r_ind)
        t = get_embedding_in_canonical_shape(embedding=self.tail_entity_embeddings, ind=h_ind)
        scores = 0.5 * (scores + (h * r * t).sum(dim=-1))

        # Regularization
        self.regularize_if_necessary(h, r, t)

        # Note: In the code in their repository, the score is clamped to [-20, 20].
        #       Since that is not mentioned in the paper, clamping is optional here and
        #       only applied when clamp_score is given.
        if self.clamp is not None:
            min_, max_ = self.clamp
            scores = scores.clamp(min=min_, max=max_)

        return scores
    def score_hrt(self, hrt_batch: torch.LongTensor) -> torch.FloatTensor:  # noqa: D102
        return self._score(h_ind=hrt_batch[:, 0], r_ind=hrt_batch[:, 1], t_ind=hrt_batch[:, 2]).view(-1, 1)

    def score_t(self, hr_batch: torch.LongTensor) -> torch.FloatTensor:  # noqa: D102
        return self._score(h_ind=hr_batch[:, 0], r_ind=hr_batch[:, 1], t_ind=None)

    def score_h(self, rt_batch: torch.LongTensor) -> torch.FloatTensor:  # noqa: D102
        return self._score(h_ind=None, r_ind=rt_batch[:, 0], t_ind=rt_batch[:, 1])
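
For illustration only (not part of the module above): a minimal, self-contained sketch of the symmetric SimplE interaction described in the class docstring, assuming plain torch tensors for the six embedding vectors involved; the name simple_interaction is hypothetical.

import torch

def simple_interaction(
    h_h: torch.FloatTensor,    # head representation of the head entity
    r: torch.FloatTensor,      # relation representation
    t_t: torch.FloatTensor,    # tail representation of the tail entity
    h_t: torch.FloatTensor,    # head representation of the tail entity
    r_inv: torch.FloatTensor,  # inverse-relation representation
    t_h: torch.FloatTensor,    # tail representation of the head entity
) -> torch.FloatTensor:
    # f(h, r, t) = 1/2 * (<h_h, r_r, t_t> + <h_t, r'_r, t_h>)
    forward = (h_h * r * t_t).sum(dim=-1)
    backward = (h_t * r_inv * t_h).sum(dim=-1)
    return 0.5 * (forward + backward)

# Example: random 200-dimensional embeddings for a single triple
vectors = [torch.randn(200) for _ in range(6)]
score = simple_interaction(*vectors)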
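
A brief usage sketch, assuming PyKEEN's pykeen.pipeline.pipeline entry point and the built-in Nations dataset; the hyper-parameter values are illustrative only, not recommended settings.

from pykeen.pipeline import pipeline

# Train and evaluate SimplE end-to-end (illustrative settings).
result = pipeline(
    dataset='Nations',
    model='SimplE',
    model_kwargs=dict(embedding_dim=200, clamp_score=20.0),
    training_kwargs=dict(num_epochs=100),
)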