# Source code for pykeen.models.unimodal.proj_e

# -*- coding: utf-8 -*-

"""Implementation of ProjE."""

from typing import Any, ClassVar, Mapping, Optional, Type

import numpy
import torch
from torch import nn

from ..base import EntityRelationEmbeddingModel
from ...constants import DEFAULT_EMBEDDING_HPO_EMBEDDING_DIM_RANGE
from ...losses import BCEWithLogitsLoss, Loss
from ...nn.emb import EmbeddingSpecification
from ...nn.init import xavier_uniform_
from ...typing import Hint, Initializer

# Public API of this module: only the model class is exported.
__all__ = ['ProjE']

class ProjE(EntityRelationEmbeddingModel):
    r"""An implementation of ProjE from [shi2017]_.

    ProjE is a neural network-based approach with a *combination* and a *projection* layer. The interaction model
    first combines $h$ and $r$ by the following combination operator:

    .. math::

        \textbf{h} \otimes \textbf{r} = \textbf{D}_e \textbf{h} + \textbf{D}_r \textbf{r} + \textbf{b}_c

    where $\textbf{D}_e, \textbf{D}_r \in \mathbb{R}^{k \times k}$ are diagonal matrices which are used as shared
    parameters among all entities and relations, and $\textbf{b}_c \in \mathbb{R}^{k}$ represents the candidate bias
    vector shared across all entities. Next, the score for the triple $(h,r,t) \in \mathbb{K}$ is computed:

    .. math::

        f(h, r, t) = g(\textbf{t}^{\top} z(\textbf{h} \otimes \textbf{r}) + b_p)

    where $g$ and $z$ are activation functions, and $b_p$ represents the shared projection bias.

    In this implementation the diagonal matrices are stored as the vectors ``d_e`` and ``d_r`` and applied by
    element-wise multiplication, $z$ is ``inner_non_linearity``, and the scoring methods return the value *before*
    the outer activation $g$ is applied (the default loss is :class:`BCEWithLogitsLoss`).

    .. seealso::

       - Official Implementation: https://github.com/nddsg/ProjE
    ---
    citation:
        author: Shi
        year: 2017
        github: nddsg/ProjE
    """

    #: The default strategy for optimizing the model's hyper-parameters
    hpo_default: ClassVar[Mapping[str, Any]] = dict(
        embedding_dim=DEFAULT_EMBEDDING_HPO_EMBEDDING_DIM_RANGE,
    )
    #: The default loss function class
    loss_default: ClassVar[Type[Loss]] = BCEWithLogitsLoss
    #: The default parameters for the default loss function class
    loss_default_kwargs: ClassVar[Mapping[str, Any]] = dict(reduction='mean')

    def __init__(
        self,
        *,
        embedding_dim: int = 50,
        inner_non_linearity: Optional[nn.Module] = None,
        entity_initializer: Hint[Initializer] = xavier_uniform_,
        relation_initializer: Hint[Initializer] = xavier_uniform_,
        **kwargs,
    ) -> None:
        """Initialize the model.

        :param embedding_dim:
            The dimension shared by the entity embeddings, the relation embeddings and the global
            combination parameters.
        :param inner_non_linearity:
            The activation applied to the combined head/relation vector. Defaults to :class:`torch.nn.Tanh`.
        :param entity_initializer:
            The initializer passed to the entity embedding specification.
        :param relation_initializer:
            The initializer passed to the relation embedding specification.
        :param kwargs:
            Remaining keyword arguments, forwarded unchanged to the base class constructor.
        """
        super().__init__(
            entity_representations=EmbeddingSpecification(
                embedding_dim=embedding_dim,
                initializer=entity_initializer,
            ),
            relation_representations=EmbeddingSpecification(
                embedding_dim=embedding_dim,
                initializer=relation_initializer,
            ),
            **kwargs,
        )

        # The tensors below are allocated uninitialised; their values are drawn in `_reset_parameters_`.
        # NOTE(review): this relies on the base class triggering `_reset_parameters_` after construction - confirm.

        # Global entity projection (diagonal of D_e, applied element-wise)
        self.d_e = nn.Parameter(torch.empty(embedding_dim), requires_grad=True)

        # Global relation projection (diagonal of D_r, applied element-wise)
        self.d_r = nn.Parameter(torch.empty(embedding_dim), requires_grad=True)

        # Global combination bias
        self.b_c = nn.Parameter(torch.empty(embedding_dim), requires_grad=True)

        # Global projection bias (a single value added to every score)
        self.b_p = nn.Parameter(torch.empty(1), requires_grad=True)

        if inner_non_linearity is None:
            inner_non_linearity = nn.Tanh()
        self.inner_non_linearity = inner_non_linearity

    def _reset_parameters_(self):
        """Re-draw the global parameters uniformly from ``[-bound, bound]``, after resetting the base class."""
        super()._reset_parameters_()
        # The same bound, sqrt(6) / embedding_dim, is used for all four global parameters.
        bound = numpy.sqrt(6) / self.embedding_dim
        nn.init.uniform_(self.d_e, a=-bound, b=bound)
        nn.init.uniform_(self.d_r, a=-bound, b=bound)
        nn.init.uniform_(self.b_c, a=-bound, b=bound)
        nn.init.uniform_(self.b_p, a=-bound, b=bound)

    def score_hrt(self, hrt_batch: torch.LongTensor) -> torch.FloatTensor:
        """Score each (head, relation, tail) index triple of the batch.

        :param hrt_batch:
            Index tensor whose columns 0, 1 and 2 hold head, relation and tail indices.

        :return:
            One pre-activation score per row, with the reduced dimension kept
            (presumably shape ``(batch_size, 1)`` - confirm against the embedding output shape).
        """
        # Get embeddings
        h = self.entity_embeddings(indices=hrt_batch[:, 0])
        r = self.relation_embeddings(indices=hrt_batch[:, 1])
        t = self.entity_embeddings(indices=hrt_batch[:, 2])

        # Combination layer: z(D_e h + D_r r + b_c), with the shared vectors broadcast over the batch
        hidden = self.inner_non_linearity(self.d_e[None, :] * h + self.d_r[None, :] * r + self.b_c[None, :])
        # Projection layer: dot product with the tail embedding plus the shared bias
        scores = torch.sum(hidden * t, dim=-1, keepdim=True) + self.b_p

        return scores

    def score_t(self, hr_batch: torch.LongTensor) -> torch.FloatTensor:
        """Score every entity as a tail candidate for each (head, relation) pair of the batch.

        :param hr_batch:
            Index tensor whose columns 0 and 1 hold head and relation indices.

        :return:
            Pre-activation scores, one row per batch element and one column per candidate entity
            (presumably shape ``(batch_size, num_entities)`` - confirm).
        """
        # Get embeddings; `indices=None` requests the tail candidates without index selection
        h = self.entity_embeddings(indices=hr_batch[:, 0])
        r = self.relation_embeddings(indices=hr_batch[:, 1])
        t = self.entity_embeddings(indices=None)

        # Rank against all entities: the combined vector is computed once per batch row and
        # then broadcast against every candidate tail.
        hidden = self.inner_non_linearity(self.d_e[None, :] * h + self.d_r[None, :] * r + self.b_c[None, :])
        scores = torch.sum(hidden[:, None, :] * t[None, :, :], dim=-1) + self.b_p

        return scores

    def score_h(self, rt_batch: torch.LongTensor) -> torch.FloatTensor:
        """Score every entity as a head candidate for each (relation, tail) pair of the batch.

        :param rt_batch:
            Index tensor whose columns 0 and 1 hold relation and tail indices.

        :return:
            Pre-activation scores, one row per batch element and one column per candidate entity
            (presumably shape ``(batch_size, num_entities)`` - confirm).
        """
        # Get embeddings; `indices=None` requests the head candidates without index selection
        h = self.entity_embeddings(indices=None)
        r = self.relation_embeddings(indices=rt_batch[:, 0])
        t = self.entity_embeddings(indices=rt_batch[:, 1])

        # Rank against all entities: the head enters the non-linearity, so the hidden vector
        # has to be formed for every (batch row, candidate head) combination.
        hidden = self.inner_non_linearity(
            self.d_e[None, None, :] * h[None, :, :]
            + (self.d_r[None, None, :] * r[:, None, :] + self.b_c[None, None, :]),
        )
        scores = torch.sum(hidden * t[:, None, :], dim=-1) + self.b_p

        return scores