# Source code for pykeen.models.unimodal.ermlp

# -*- coding: utf-8 -*-

"""Implementation of ERMLP."""

from typing import Any, ClassVar, Mapping, Optional

from torch.nn.init import uniform_

from ..nbase import ERModel
from ...constants import DEFAULT_EMBEDDING_HPO_EMBEDDING_DIM_RANGE
from ...nn import ERMLPInteraction
from ...typing import Hint, Initializer

__all__ = [
    "ERMLP",
]


class ERMLP(ERModel):
    r"""An implementation of ERMLP from [dong2014]_.

    ERMLP is a multi-layer perceptron based approach that uses a single hidden layer and represents entities and
    relations as vectors. In the input-layer, for each triple the embeddings of head, relation, and tail are
    concatenated and passed to the hidden layer. The output-layer consists of a single neuron that computes the
    plausibility score of the triple:

    .. math::

        f(h,r,t) = \textbf{w}^{T} g(\textbf{W} [\textbf{h}; \textbf{r}; \textbf{t}]),

    where $\textbf{W} \in \mathbb{R}^{k \times 3d}$ represents the weight matrix of the hidden layer,
    $\textbf{w} \in \mathbb{R}^{k}$, the weights of the output layer, and $g$ denotes an activation function such
    as the hyperbolic tangent.
    ---
    name: ER-MLP
    citation:
        author: Dong
        year: 2014
        link: https://dl.acm.org/citation.cfm?id=2623623
    """

    #: The default strategy for optimizing the model's hyper-parameters
    hpo_default: ClassVar[Mapping[str, Any]] = dict(
        embedding_dim=DEFAULT_EMBEDDING_HPO_EMBEDDING_DIM_RANGE,
    )

    def __init__(
        self,
        *,
        embedding_dim: int = 64,
        hidden_dim: Optional[int] = None,
        entity_initializer: Hint[Initializer] = uniform_,
        relation_initializer: Hint[Initializer] = uniform_,
        **kwargs,
    ) -> None:
        """Initialize the model.

        :param embedding_dim:
            The embedding dimension. It is used as the shape of both the entity and the relation
            representations, and is also passed to the interaction function.
        :param hidden_dim:
            The hidden dimension passed to the interaction function. If ``None``, it falls back to
            ``embedding_dim``.
        :param entity_initializer:
            The initializer passed to the entity representations.
        :param relation_initializer:
            The initializer passed to the relation representations.
        :param kwargs:
            Additional keyword-based arguments forwarded unchanged to the parent class' ``__init__``.
        """
        # input normalization: an unspecified hidden dimension defaults to the embedding dimension
        if hidden_dim is None:
            hidden_dim = embedding_dim

        super().__init__(
            interaction=ERMLPInteraction,
            interaction_kwargs=dict(
                embedding_dim=embedding_dim,
                hidden_dim=hidden_dim,
            ),
            entity_representations_kwargs=dict(
                shape=embedding_dim,
                initializer=entity_initializer,
            ),
            relation_representations_kwargs=dict(
                shape=embedding_dim,
                initializer=relation_initializer,
            ),
            **kwargs,
        )