Source code for pykeen.models.unimodal.quate

# -*- coding: utf-8 -*-

"""Implementation of the QuatE model."""

from typing import Any, ClassVar, Mapping, Optional, Type

import torch
from torch.nn import functional

from ..nbase import ERModel
from ...losses import BCEWithLogitsLoss, Loss
from ...nn.init import init_quaternions
from ...nn.modules import QuatEInteraction
from ...regularizers import LpRegularizer, Regularizer
from ...typing import Constrainer, Hint, Initializer
from ...utils import get_expected_norm

__all__ = [

def quaternion_normalizer(x: torch.FloatTensor) -> torch.FloatTensor:
    Normalize the length of relation vectors, if the forward constraint has not been applied yet.

    Absolute value of a quaternion

    .. math::

        |a + bi + cj + dk| = \sqrt{a^2 + b^2 + c^2 + d^2}

    L2 norm of quaternion vector:

    .. math::
        \|x\|^2 = \sum_{i=1}^d |x_i|^2
                 = \sum_{i=1}^d (^2 + x_i.im_1^2 + x_i.im_2^2 + x_i.im_3^2)
    :param x:
        The vector.

        The normalized vector.
    # Normalize relation embeddings
    shape = x.shape
    x = x.view(*shape[:-1], -1, 4)
    x = functional.normalize(x, p=2, dim=-1)
    return x.view(*shape)

[docs]class QuatE(ERModel): r"""An implementation of QuatE from [zhang2019]_. QuatE uses hypercomplex valued representations for the entities and relations. Entities and relations are represented as vectors $\textbf{e}_i, \textbf{r}_i \in \mathbb{H}^d$, and the plausibility score is computed using the quaternion inner product. .. seealso :: Official implementation: --- citation: author: Zhang year: 2019 link: github: cheungdaven/quate """ #: The default strategy for optimizing the model's hyper-parameters hpo_default: ClassVar[Mapping[str, Any]] = dict( embedding_dim=DEFAULT_EMBEDDING_HPO_EMBEDDING_DIM_RANGE, ) #: The default loss function class loss_default: ClassVar[Type[Loss]] = BCEWithLogitsLoss #: The default parameters for the default loss function class loss_default_kwargs: ClassVar[Mapping[str, Any]] = dict(reduction="mean") #: The LP settings used by [zhang2019]_ for QuatE. regularizer_default_kwargs: ClassVar[Mapping[str, Any]] = dict( weight=0.3 / get_expected_norm(p=2, d=100), p=2.0, normalize=True, ) def __init__( self, *, embedding_dim: int = 100, entity_initializer: Hint[Initializer] = init_quaternions, entity_regularizer: Hint[Regularizer] = LpRegularizer, entity_regularizer_kwargs: Optional[Mapping[str, Any]] = None, relation_initializer: Hint[Initializer] = init_quaternions, relation_regularizer: Hint[Regularizer] = LpRegularizer, relation_regularizer_kwargs: Optional[Mapping[str, Any]] = None, relation_normalizer: Hint[Constrainer] = quaternion_normalizer, **kwargs, ) -> None: """Initialize QuatE. .. note :: The default parameters correspond to the first setting for FB15k-237 described from [zhang2019]_. :param embedding_dim: The embedding dimensionality of the entity embeddings. .. note :: The number of parameter per entity is `4 * embedding_dim`, since quaternion are used. :param entity_initializer: The initializer to use for the entity embeddings. :param entity_regularizer: The regularizer to use for the entity embeddings. :param entity_regularizer_kwargs: The keyword arguments passed to the entity regularizer. Defaults to :data:`QuatE.regularizer_default_kwargs` if not specified. :param relation_initializer: The initializer to use for the relation embeddings. :param relation_regularizer: The regularizer to use for the relation embeddings. :param relation_regularizer_kwargs: The keyword arguments passed to the relation regularizer. Defaults to :data:`QuatE.regularizer_default_kwargs` if not specified. :param relation_normalizer: The normalizer to use for the relation embeddings. :param kwargs: Additional keyword based arguments passed to :class:`pykeen.models.ERModel`. Must not contain "interaction", "entity_representations", or "relation_representations". """ super().__init__( interaction=QuatEInteraction, entity_representations_kwargs=dict( shape=(embedding_dim, 4), # quaternions initializer=entity_initializer, dtype=torch.float, regularizer=entity_regularizer, regularizer_kwargs=entity_regularizer_kwargs or self.regularizer_default_kwargs, ), relation_representations_kwargs=dict( shape=(embedding_dim, 4), # quaternions initializer=relation_initializer, normalizer=relation_normalizer, dtype=torch.float, regularizer=relation_regularizer, regularizer_kwargs=relation_regularizer_kwargs or self.regularizer_default_kwargs, ), **kwargs, )