# Source code for pykeen.models.unimodal.quate

"""Implementation of the QuatE model."""

from collections.abc import Mapping
from typing import Any, ClassVar, Optional

import torch

from ..nbase import ERModel
from ...constants import DEFAULT_EMBEDDING_HPO_EMBEDDING_DIM_RANGE
from ...losses import BCEWithLogitsLoss, Loss
from ...nn import quaternion
from ...nn.init import init_quaternions
from ...nn.modules import QuatEInteraction
from ...regularizers import LpRegularizer, Regularizer
from ...typing import Constrainer, Hint, Initializer
from ...utils import get_expected_norm

# Public API of this module: only the QuatE model class is exported.
__all__ = ["QuatE"]


class QuatE(ERModel):
    r"""An implementation of QuatE from [zhang2019]_.

    QuatE uses hypercomplex valued representations for the
    entities and relations. Entities and relations are represented as vectors
    $\textbf{e}_i, \textbf{r}_i \in \mathbb{H}^d$, and the plausibility score is computed using the
    quaternion inner product.

    The representations are stored in an :class:`~pykeen.nn.representation.Embedding`.
    Scores are calculated with :class:`~pykeen.nn.modules.QuatEInteraction`.

    .. seealso ::

        Official implementation: https://github.com/cheungdaven/QuatE/blob/master/models/QuatE.py
    ---
    citation:
        author: Zhang
        year: 2019
        arxiv: 1904.10281
        link: https://arxiv.org/abs/1904.10281
        github: cheungdaven/quate
    """

    #: The default strategy for optimizing the model's hyper-parameters
    hpo_default: ClassVar[Mapping[str, Any]] = dict(
        embedding_dim=DEFAULT_EMBEDDING_HPO_EMBEDDING_DIM_RANGE,
    )
    #: The default loss function class
    loss_default: ClassVar[type[Loss]] = BCEWithLogitsLoss
    #: The default parameters for the default loss function class
    loss_default_kwargs: ClassVar[Mapping[str, Any]] = dict(reduction="mean")
    #: The Lp regularization settings used by [zhang2019]_ for QuatE.
    regularizer_default_kwargs: ClassVar[Mapping[str, Any]] = dict(
        # 0.3 rescaled by get_expected_norm(p=2, d=100). Note that d is pinned to 100 (the
        # default embedding_dim) and does not track the embedding_dim passed to __init__.
        weight=0.3 / get_expected_norm(p=2, d=100),
        p=2.0,
        normalize=True,
    )

    def __init__(
        self,
        *,
        embedding_dim: int = 100,
        entity_initializer: Hint[Initializer] = init_quaternions,
        entity_regularizer: Hint[Regularizer] = LpRegularizer,
        entity_regularizer_kwargs: Optional[Mapping[str, Any]] = None,
        relation_initializer: Hint[Initializer] = init_quaternions,
        relation_regularizer: Hint[Regularizer] = LpRegularizer,
        relation_regularizer_kwargs: Optional[Mapping[str, Any]] = None,
        relation_normalizer: Hint[Constrainer] = quaternion.normalize,
        **kwargs,
    ) -> None:
        """Initialize QuatE.

        .. note ::
            The default parameters correspond to the first setting for FB15k-237 described in [zhang2019]_.

        :param embedding_dim:
            The embedding dimensionality of the entity and relation embeddings.

            .. note ::
                The number of parameters per entity is `4 * embedding_dim`, since quaternions are used.

        :param entity_initializer:
            The initializer to use for the entity embeddings.
        :param entity_regularizer:
            The regularizer to use for the entity embeddings.
        :param entity_regularizer_kwargs:
            The keyword arguments passed to the entity regularizer. Defaults to
            :data:`QuatE.regularizer_default_kwargs` if ``None``. An explicitly passed mapping,
            even an empty one, is forwarded unchanged.
        :param relation_initializer:
            The initializer to use for the relation embeddings.
        :param relation_regularizer:
            The regularizer to use for the relation embeddings.
        :param relation_regularizer_kwargs:
            The keyword arguments passed to the relation regularizer. Defaults to
            :data:`QuatE.regularizer_default_kwargs` if ``None``. An explicitly passed mapping,
            even an empty one, is forwarded unchanged.
        :param relation_normalizer:
            The normalizer to use for the relation embeddings.
        :param kwargs:
            Additional keyword based arguments passed to :class:`pykeen.models.ERModel`. Must not
            contain "interaction", "entity_representations", "relation_representations", or the
            corresponding "*_representations_kwargs" keys. "interaction" and the "*_kwargs" keys are
            already passed here, so repeating them raises a :class:`TypeError`.
        """
        # Compare against None instead of using ``or``. Otherwise an explicitly passed empty
        # mapping would be silently replaced by the class defaults.
        if entity_regularizer_kwargs is None:
            entity_regularizer_kwargs = self.regularizer_default_kwargs
        if relation_regularizer_kwargs is None:
            relation_regularizer_kwargs = self.regularizer_default_kwargs
        super().__init__(
            interaction=QuatEInteraction,
            entity_representations_kwargs=dict(
                shape=(embedding_dim, 4),  # quaternions: 4 real components per position
                initializer=entity_initializer,
                dtype=torch.float,
                regularizer=entity_regularizer,
                regularizer_kwargs=entity_regularizer_kwargs,
            ),
            relation_representations_kwargs=dict(
                shape=(embedding_dim, 4),  # quaternions: 4 real components per position
                initializer=relation_initializer,
                # Only relations get a normalizer here; no normalizer is passed for entities.
                normalizer=relation_normalizer,
                dtype=torch.float,
                regularizer=relation_regularizer,
                regularizer_kwargs=relation_regularizer_kwargs,
            ),
            **kwargs,
        )