Source code for pykeen.models.unimodal.proj_e

# -*- coding: utf-8 -*-

"""Implementation of ProjE."""

from typing import Optional

import numpy
import torch
import torch.autograd
from torch import nn

from ..base import EntityRelationEmbeddingModel
from ..init import embedding_xavier_uniform_
from ...losses import Loss
from ...regularizers import Regularizer
from ...triples import TriplesFactory

__all__ = [
    'ProjE',
]


class ProjE(EntityRelationEmbeddingModel):
    r"""An implementation of ProjE from [shi2017]_.

    ProjE is a neural network-based approach with a *combination* and a *projection* layer. The interaction model
    first combines $h$ and $r$ by the following combination operator:

    .. math::

        \textbf{h} \otimes \textbf{r} = \textbf{D}_e \textbf{h} + \textbf{D}_r \textbf{r} + \textbf{b}_c

    where $\textbf{D}_e, \textbf{D}_r \in \mathbb{R}^{k \times k}$ are diagonal matrices which are used as shared
    parameters among all entities and relations, and $\textbf{b}_c \in \mathbb{R}^{k}$ represents the candidate
    bias vector shared across all entities. Next, the score for the triple $(h,r,t) \in \mathbb{K}$ is computed:

    .. math::

        f(h, r, t) = g(\textbf{t}^T z(\textbf{h} \otimes \textbf{r}) + \textbf{b}_p)

    where $g$ and $z$ are activation functions, and $\textbf{b}_p$ represents the shared projection bias vector.

    .. seealso::

        - Official Implementation: https://github.com/nddsg/ProjE
    """

    #: The default strategy for optimizing the model's hyper-parameters
    hpo_default = dict(
        embedding_dim=dict(type=int, low=50, high=350, q=25),
    )
    #: The default loss function class
    loss_default = nn.BCEWithLogitsLoss
    #: The default parameters for the default loss function class
    loss_default_kwargs = dict(reduction='mean')

    def __init__(
        self,
        triples_factory: TriplesFactory,
        embedding_dim: int = 50,
        automatic_memory_optimization: Optional[bool] = None,
        loss: Optional[Loss] = None,
        preferred_device: Optional[str] = None,
        random_seed: Optional[int] = None,
        inner_non_linearity: Optional[nn.Module] = None,
        regularizer: Optional[Regularizer] = None,
    ) -> None:
        super().__init__(
            triples_factory=triples_factory,
            embedding_dim=embedding_dim,
            automatic_memory_optimization=automatic_memory_optimization,
            loss=loss,
            preferred_device=preferred_device,
            random_seed=random_seed,
            regularizer=regularizer,
        )

        # Global entity projection
        self.d_e = nn.Parameter(torch.empty(self.embedding_dim, device=self.device), requires_grad=True)

        # Global relation projection
        self.d_r = nn.Parameter(torch.empty(self.embedding_dim, device=self.device), requires_grad=True)

        # Global combination bias
        self.b_c = nn.Parameter(torch.empty(self.embedding_dim, device=self.device), requires_grad=True)

        # Global projection bias
        self.b_p = nn.Parameter(torch.empty(1, device=self.device), requires_grad=True)

        if inner_non_linearity is None:
            inner_non_linearity = nn.Tanh()
        self.inner_non_linearity = inner_non_linearity

        # Finalize initialization
        self.reset_parameters_()

    def _reset_parameters_(self):  # noqa: D102
        embedding_xavier_uniform_(self.entity_embeddings)
        embedding_xavier_uniform_(self.relation_embeddings)
        bound = numpy.sqrt(6) / self.embedding_dim
        nn.init.uniform_(self.d_e, a=-bound, b=bound)
        nn.init.uniform_(self.d_r, a=-bound, b=bound)
        nn.init.uniform_(self.b_c, a=-bound, b=bound)
        nn.init.uniform_(self.b_p, a=-bound, b=bound)
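    # NOTE: d_e and d_r store only the *diagonals* of the matrices D_e and D_r
    # from the docstring, so the products D_e h and D_r r in the scoring
    # methods below reduce to element-wise products d_e * h and d_r * r.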
    def score_hrt(self, hrt_batch: torch.LongTensor) -> torch.FloatTensor:  # noqa: D102
        # Get embeddings
        h = self.entity_embeddings(hrt_batch[:, 0])
        r = self.relation_embeddings(hrt_batch[:, 1])
        t = self.entity_embeddings(hrt_batch[:, 2])

        # Compute score
        hidden = self.inner_non_linearity(self.d_e[None, :] * h + self.d_r[None, :] * r + self.b_c[None, :])
        scores = torch.sum(hidden * t, dim=-1, keepdim=True) + self.b_p

        return scores
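    # The two methods below score against *all* entities at once (1-N scoring)
    # by broadcasting the combined representation over the full entity
    # embedding matrix instead of gathering a single tail/head per triple.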
    def score_t(self, hr_batch: torch.LongTensor) -> torch.FloatTensor:  # noqa: D102
        # Get embeddings
        h = self.entity_embeddings(hr_batch[:, 0])
        r = self.relation_embeddings(hr_batch[:, 1])
        t = self.entity_embeddings.weight

        # Rank against all entities
        hidden = self.inner_non_linearity(self.d_e[None, :] * h + self.d_r[None, :] * r + self.b_c[None, :])
        scores = torch.sum(hidden[:, None, :] * t[None, :, :], dim=-1) + self.b_p

        return scores
    def score_h(self, rt_batch: torch.LongTensor) -> torch.FloatTensor:  # noqa: D102
        # Get embeddings
        h = self.entity_embeddings.weight
        r = self.relation_embeddings(rt_batch[:, 0])
        t = self.entity_embeddings(rt_batch[:, 1])

        # Rank against all entities
        hidden = self.inner_non_linearity(
            self.d_e[None, None, :] * h[None, :, :]
            + (self.d_r[None, None, :] * r[:, None, :] + self.b_c[None, None, :]),
        )
        # hidden has shape (batch_size, num_entities, embedding_dim)
        scores = torch.sum(hidden * t[:, None, :], dim=-1) + self.b_p

        return scores
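
# ---------------------------------------------------------------------------
# Editorial sketch (not part of the PyKEEN source above): a minimal,
# self-contained illustration of the ProjE interaction function
# f(h, r, t) = g(t^T z(D_e h + D_r r + b_c) + b_p) from the class docstring,
# mirroring ``score_hrt``. Here z is tanh (the module's default
# ``inner_non_linearity``) and g is left as the identity, since the default
# ``BCEWithLogitsLoss`` applies the sigmoid internally. All names below are
# illustrative; only ``torch`` (imported above) is assumed.
# ---------------------------------------------------------------------------
def _proje_interaction_sketch(
    h: torch.FloatTensor,    # (batch_size, k) head embeddings
    r: torch.FloatTensor,    # (batch_size, k) relation embeddings
    t: torch.FloatTensor,    # (batch_size, k) tail embeddings
    d_e: torch.FloatTensor,  # (k,) diagonal of the entity projection D_e
    d_r: torch.FloatTensor,  # (k,) diagonal of the relation projection D_r
    b_c: torch.FloatTensor,  # (k,) shared combination bias
    b_p: torch.FloatTensor,  # (1,) shared projection bias
) -> torch.FloatTensor:
    # Combination layer: since D_e and D_r are diagonal, the matrix-vector
    # products reduce to element-wise multiplications.
    combined = torch.tanh(d_e[None, :] * h + d_r[None, :] * r + b_c[None, :])
    # Projection layer: dot product of the combined representation with the
    # tail embedding, plus the shared projection bias.
    return torch.sum(combined * t, dim=-1, keepdim=True) + b_p


if __name__ == '__main__':
    # Score two random triples with embedding dimension k = 4.
    k, batch_size = 4, 2
    h, r, t = (torch.randn(batch_size, k) for _ in range(3))
    scores = _proje_interaction_sketch(
        h, r, t,
        d_e=torch.randn(k),
        d_r=torch.randn(k),
        b_c=torch.randn(k),
        b_p=torch.randn(1),
    )
    print(scores.shape)  # torch.Size([2, 1]) -- one logit per triple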