"""Implementation of ProjE."""
from collections.abc import Mapping
from typing import Any, ClassVar
from class_resolver import HintOrType, OptionalKwargs, ResolverKey, update_docstring_with_resolver_keys
from torch import nn
from ..nbase import ERModel
from ...constants import DEFAULT_EMBEDDING_HPO_EMBEDDING_DIM_RANGE
from ...losses import BCEWithLogitsLoss, Loss
from ...nn.init import xavier_uniform_
from ...nn.modules import ProjEInteraction
from ...typing import FloatTensor, Hint, Initializer
__all__ = [
"ProjE",
]
class ProjE(ERModel[FloatTensor, FloatTensor, FloatTensor]):
    r"""An implementation of ProjE from [shi2017]_.

    ProjE represents entities and relations using a $d$-dimensional embedding vector stored in an
    :class:`~pykeen.nn.representation.Embedding`. On top of these representations, this model uses the
    :class:`~pykeen.nn.modules.ProjEInteraction` to calculate scores.

    .. seealso::

        - Official Implementation: https://github.com/nddsg/ProjE
    ---
    citation:
        author: Shi
        year: 2017
        link: https://www.aaai.org/ocs/index.php/AAAI/AAAI17/paper/view/14279
        github: nddsg/ProjE
    """

    #: The default strategy for optimizing the model's hyper-parameters
    hpo_default: ClassVar[Mapping[str, Any]] = dict(
        embedding_dim=DEFAULT_EMBEDDING_HPO_EMBEDDING_DIM_RANGE,
    )
    #: The default loss function class
    loss_default: ClassVar[type[Loss]] = BCEWithLogitsLoss
    #: The default parameters for the default loss function class
    loss_default_kwargs: ClassVar[Mapping[str, Any]] = dict(reduction="mean")

    @update_docstring_with_resolver_keys(
        ResolverKey(name="inner_non_linearity", resolver="class_resolver.contrib.torch.activation_resolver")
    )
    def __init__(
        self,
        *,
        embedding_dim: int = 50,
        inner_non_linearity: HintOrType[nn.Module] = None,
        inner_non_linearity_kwargs: OptionalKwargs = None,
        entity_initializer: Hint[Initializer] = xavier_uniform_,
        relation_initializer: Hint[Initializer] = xavier_uniform_,
        **kwargs,
    ) -> None:
        """
        Initialize the model.

        :param embedding_dim:
            the embedding dimension; it is used as the shape of both the entity and the relation
            representations and is also passed on to the interaction function.
        :param inner_non_linearity:
            the inner non-linearity, or a hint thereof. cf. :class:`pykeen.nn.modules.ProjEInteraction`
        :param inner_non_linearity_kwargs:
            additional keyword-based parameters used to instantiate the non-linearity.
        :param entity_initializer:
            the entity representation initializer, defaults to :func:`~pykeen.nn.init.xavier_uniform_`.
        :param relation_initializer:
            the relation representation initializer, defaults to :func:`~pykeen.nn.init.xavier_uniform_`.
        :param kwargs:
            additional keyword-based parameters passed to :class:`~pykeen.models.ERModel`
        """
        super().__init__(
            interaction=ProjEInteraction,
            # The interaction calls these arguments ``inner_activation*``, whereas the model
            # exposes them under the ``inner_non_linearity*`` names.
            interaction_kwargs=dict(
                embedding_dim=embedding_dim,
                inner_activation=inner_non_linearity,
                inner_activation_kwargs=inner_non_linearity_kwargs,
            ),
            # Entities and relations share the same dimensionality; only the initializer
            # can be chosen independently for each of them.
            entity_representations_kwargs=dict(
                shape=embedding_dim,
                initializer=entity_initializer,
            ),
            relation_representations_kwargs=dict(
                shape=embedding_dim,
                initializer=relation_initializer,
            ),
            **kwargs,
        )