Source code for pykeen.models.unimodal.proj_e

"""Implementation of ProjE."""

from collections.abc import Mapping
from typing import Any, ClassVar

from class_resolver import HintOrType, OptionalKwargs, ResolverKey, update_docstring_with_resolver_keys
from torch import nn

from ..nbase import ERModel
from ...constants import DEFAULT_EMBEDDING_HPO_EMBEDDING_DIM_RANGE
from ...losses import BCEWithLogitsLoss, Loss
from ...nn.init import xavier_uniform_
from ...nn.modules import ProjEInteraction
from ...typing import FloatTensor, Hint, Initializer

__all__ = [
    "ProjE",
]


[docs] class ProjE(ERModel[FloatTensor, FloatTensor, FloatTensor]): r"""An implementation of ProjE from [shi2017]_. ProjE represents entities and relations using a $d$-dimensional embedding vector stored in an :class:`~pykeen.nn.representation.Embedding`. On top of these representations, this model uses the :class:`~pykeen.nn.modules.ProjEInteraction` to calculate scores. .. seealso:: - Official Implementation: https://github.com/nddsg/ProjE --- citation: author: Shi year: 2017 link: https://www.aaai.org/ocs/index.php/AAAI/AAAI17/paper/view/14279 github: nddsg/ProjE """ #: The default strategy for optimizing the model's hyper-parameters hpo_default: ClassVar[Mapping[str, Any]] = { "embedding_dim": DEFAULT_EMBEDDING_HPO_EMBEDDING_DIM_RANGE, } #: The default loss function class loss_default: ClassVar[type[Loss]] = BCEWithLogitsLoss #: The default parameters for the default loss function class loss_default_kwargs = {"reduction": "mean"} @update_docstring_with_resolver_keys( ResolverKey(name="inner_non_linearity", resolver="class_resolver.contrib.torch.activation_resolver") ) def __init__( self, *, embedding_dim: int = 50, inner_non_linearity: HintOrType[nn.Module] = None, inner_non_linearity_kwargs: OptionalKwargs = None, entity_initializer: Hint[Initializer] = xavier_uniform_, relation_initializer: Hint[Initializer] = xavier_uniform_, **kwargs, ) -> None: """ Initialize the model. :param embedding_dim: the embedding dimension :param inner_non_linearity: the inner non-linearity, of a hint thereof. cf. :class:`pykeen.nn.modules.ProjEInteraction` :param inner_non_linearity_kwargs: additional keyword-based parameters used to instantiate the non-linearity. :param entity_initializer: the entity representation initializer, defaults to :func:`~pykeen.nn.init.xavier_uniform_`. :param relation_initializer: the relation representation initializer, defaults to :func:`~pykeen.nn.init.xavier_uniform_`. :param kwargs: additional keyword-based parameters passed to :class:`~pykeen.models.ERModel` """ super().__init__( interaction=ProjEInteraction, interaction_kwargs={ "embedding_dim": embedding_dim, "inner_activation": inner_non_linearity, "inner_activation_kwargs": inner_non_linearity_kwargs, }, entity_representations_kwargs={ "shape": embedding_dim, "initializer": entity_initializer, }, relation_representations_kwargs={ "shape": embedding_dim, "initializer": relation_initializer, }, **kwargs, )