# Source code for pykeen.models.unimodal.ermlp

"""Implementation of ERMLP."""

from collections.abc import Mapping
from typing import Any, ClassVar

from class_resolver import HintOrType, OptionalKwargs, ResolverKey, update_docstring_with_resolver_keys
from torch import nn

from ..nbase import ERModel
from ...constants import DEFAULT_EMBEDDING_HPO_EMBEDDING_DIM_RANGE
from ...nn import ERMLPInteraction
from ...typing import FloatTensor, Hint, Initializer

__all__ = [
    "ERMLP",
]


class ERMLP(ERModel[FloatTensor, FloatTensor, FloatTensor]):
    r"""An implementation of ERMLP from [dong2014]_.

    This model represents both entities and relations as $d$-dimensional vectors stored in an
    :class:`~pykeen.nn.representation.Embedding` matrix. The representations are then passed to the
    :class:`~pykeen.nn.modules.ERMLPInteraction` function to obtain scores.

    ---
    name: ER-MLP
    citation:
        author: Dong
        year: 2014
        link: https://dl.acm.org/citation.cfm?id=2623623
    """

    #: The default strategy for optimizing the model's hyper-parameters
    hpo_default: ClassVar[Mapping[str, Any]] = dict(
        embedding_dim=DEFAULT_EMBEDDING_HPO_EMBEDDING_DIM_RANGE,
    )

    @update_docstring_with_resolver_keys(
        ResolverKey(name="activation", resolver="class_resolver.contrib.torch.activation_resolver")
    )
    def __init__(
        self,
        *,
        embedding_dim: int = 64,
        hidden_dim: int | None = None,
        activation: HintOrType[nn.Module] = nn.ReLU,
        activation_kwargs: OptionalKwargs = None,
        entity_initializer: Hint[Initializer] = nn.init.uniform_,
        relation_initializer: Hint[Initializer] = nn.init.uniform_,
        **kwargs,
    ) -> None:
        """
        Initialize the model.

        :param embedding_dim:
            The embedding vector dimension for entities and relations.
        :param hidden_dim:
            The hidden dimension of the MLP. Defaults to `embedding_dim`.
        :param activation:
            The activation function or a hint thereof.
        :param activation_kwargs:
            Additional keyword-based parameters passed to the activation's constructor, if the activation is not
            pre-instantiated.
        :param entity_initializer:
            The method to initialize the entity embeddings.
        :param relation_initializer:
            The method to initialize the relation embeddings.
        :param kwargs:
            Additional keyword-based parameters passed to :class:`pykeen.models.ERModel`.
        """
        super().__init__(
            # All MLP-related hyper-parameters are forwarded unchanged to the interaction module.
            interaction=ERMLPInteraction,
            interaction_kwargs=dict(
                embedding_dim=embedding_dim,
                hidden_dim=hidden_dim,
                activation=activation,
                activation_kwargs=activation_kwargs,
            ),
            # Entities and relations both use vectors of the same size, `embedding_dim`;
            # only their initializers can be chosen independently.
            entity_representations_kwargs=dict(
                shape=embedding_dim,
                initializer=entity_initializer,
            ),
            relation_representations_kwargs=dict(
                shape=embedding_dim,
                initializer=relation_initializer,
            ),
            **kwargs,
        )