Source code for pykeen.models.unimodal.trans_r

# -*- coding: utf-8 -*-

"""Implementation of TransR."""

from functools import partial
from typing import Any, ClassVar, Mapping, Optional

import torch
import torch.autograd
import torch.nn.init
from torch.nn import functional

from ..base import EntityRelationEmbeddingModel
from ...constants import DEFAULT_EMBEDDING_HPO_EMBEDDING_DIM_RANGE
from ...losses import Loss
from ...nn import Embedding
from ...nn.init import xavier_uniform_
from ...regularizers import Regularizer
from ...triples import TriplesFactory
from ...typing import DeviceHint
from ...utils import clamp_norm, compose

__all__ = [
    'TransR',
]


def _projection_initializer(
    x: torch.FloatTensor,
    num_relations: int,
    embedding_dim: int,
    relation_dim: int,
) -> torch.FloatTensor:
    """Initialize by Glorot."""
    return torch.nn.init.xavier_uniform_(x.view(num_relations, embedding_dim, relation_dim)).view(x.shape)


[docs]class TransR(EntityRelationEmbeddingModel): r"""An implementation of TransR from [lin2015]_. TransR is an extension of :class:`pykeen.models.TransH` that explicitly considers entities and relations as different objects and therefore represents them in different vector spaces. For a triple $(h,r,t) \in \mathbb{K}$, the entity embeddings, $\textbf{e}_h, \textbf{e}_t \in \mathbb{R}^d$, are first projected into the relation space by means of a relation-specific projection matrix $\textbf{M}_{r} \in \mathbb{R}^{k \times d}$. With relation embedding $\textbf{r}_r \in \mathbb{R}^k$, the interaction model is defined similarly to TransE with: .. math:: f(h,r,t) = -\|\textbf{M}_{r}\textbf{e}_h + \textbf{r}_r - \textbf{M}_{r}\textbf{e}_t\|_{p}^2 The following constraints are applied: * $\|\textbf{e}_h\|_2 \leq 1$ * $\|\textbf{r}_r\|_2 \leq 1$ * $\|\textbf{e}_t\|_2 \leq 1$ * $\|\textbf{M}_{r}\textbf{e}_h\|_2 \leq 1$ * $\|\textbf{M}_{r}\textbf{e}_t\|_2 \leq 1$ .. seealso:: - OpenKE `TensorFlow implementation of TransR <https://github.com/thunlp/OpenKE/blob/master/models/TransR.py>`_ - OpenKE `PyTorch implementation of TransR <https://github.com/thunlp/OpenKE/blob/OpenKE-PyTorch/models/TransR.py>`_ """ #: The default strategy for optimizing the model's hyper-parameters hpo_default: ClassVar[Mapping[str, Any]] = dict( embedding_dim=DEFAULT_EMBEDDING_HPO_EMBEDDING_DIM_RANGE, relation_dim=DEFAULT_EMBEDDING_HPO_EMBEDDING_DIM_RANGE, scoring_fct_norm=dict(type=int, low=1, high=2), ) def __init__( self, triples_factory: TriplesFactory, embedding_dim: int = 50, relation_dim: int = 30, scoring_fct_norm: int = 1, loss: Optional[Loss] = None, preferred_device: DeviceHint = None, random_seed: Optional[int] = None, regularizer: Optional[Regularizer] = None, ) -> None: """Initialize the model.""" super().__init__( triples_factory=triples_factory, embedding_dim=embedding_dim, relation_dim=relation_dim, loss=loss, preferred_device=preferred_device, random_seed=random_seed, regularizer=regularizer, entity_initializer=xavier_uniform_, entity_constrainer=clamp_norm, entity_constrainer_kwargs=dict(maxnorm=1., p=2, dim=-1), relation_initializer=compose( xavier_uniform_, functional.normalize, ), relation_constrainer=clamp_norm, relation_constrainer_kwargs=dict(maxnorm=1., p=2, dim=-1), ) self.scoring_fct_norm = scoring_fct_norm # TODO: Initialize from TransE # embeddings self.relation_projections = Embedding.init_with_device( num_embeddings=triples_factory.num_relations, embedding_dim=relation_dim * embedding_dim, device=self.device, initializer=partial( _projection_initializer, num_relations=self.num_relations, embedding_dim=self.embedding_dim, relation_dim=self.relation_dim, ), ) def _reset_parameters_(self): # noqa: D102 super()._reset_parameters_() self.relation_projections.reset_parameters()
[docs] @staticmethod def interaction_function( h: torch.FloatTensor, r: torch.FloatTensor, t: torch.FloatTensor, m_r: torch.FloatTensor, ) -> torch.FloatTensor: """Evaluate the interaction function for given embeddings. The embeddings have to be in a broadcastable shape. :param h: shape: (batch_size, num_entities, d_e) Head embeddings. :param r: shape: (batch_size, num_entities, d_r) Relation embeddings. :param t: shape: (batch_size, num_entities, d_e) Tail embeddings. :param m_r: shape: (batch_size, num_entities, d_e, d_r) The relation specific linear transformations. :return: shape: (batch_size, num_entities) The scores. """ # project to relation specific subspace, shape: (b, e, d_r) h_bot = h @ m_r t_bot = t @ m_r # ensure constraints h_bot = clamp_norm(h_bot, p=2, dim=-1, maxnorm=1.) t_bot = clamp_norm(t_bot, p=2, dim=-1, maxnorm=1.) # evaluate score function, shape: (b, e) return -torch.norm(h_bot + r - t_bot, dim=-1) ** 2
[docs] def score_hrt(self, hrt_batch: torch.LongTensor) -> torch.FloatTensor: # noqa: D102 # Get embeddings h = self.entity_embeddings(indices=hrt_batch[:, 0]).unsqueeze(dim=1) r = self.relation_embeddings(indices=hrt_batch[:, 1]).unsqueeze(dim=1) t = self.entity_embeddings(indices=hrt_batch[:, 2]).unsqueeze(dim=1) m_r = self.relation_projections(indices=hrt_batch[:, 1]).view(-1, self.embedding_dim, self.relation_dim) return self.interaction_function(h=h, r=r, t=t, m_r=m_r).view(-1, 1)
[docs] def score_t(self, hr_batch: torch.LongTensor) -> torch.FloatTensor: # noqa: D102 # Get embeddings h = self.entity_embeddings(indices=hr_batch[:, 0]).unsqueeze(dim=1) r = self.relation_embeddings(indices=hr_batch[:, 1]).unsqueeze(dim=1) t = self.entity_embeddings(indices=None).unsqueeze(dim=0) m_r = self.relation_projections(indices=hr_batch[:, 1]).view(-1, self.embedding_dim, self.relation_dim) return self.interaction_function(h=h, r=r, t=t, m_r=m_r)
[docs] def score_h(self, rt_batch: torch.LongTensor) -> torch.FloatTensor: # noqa: D102 # Get embeddings h = self.entity_embeddings(indices=None).unsqueeze(dim=0) r = self.relation_embeddings(indices=rt_batch[:, 0]).unsqueeze(dim=1) t = self.entity_embeddings(indices=rt_batch[:, 1]).unsqueeze(dim=1) m_r = self.relation_projections(indices=rt_batch[:, 0]).view(-1, self.embedding_dim, self.relation_dim) return self.interaction_function(h=h, r=r, t=t, m_r=m_r)