# Source code for pykeen.models.unimodal.distmult

# -*- coding: utf-8 -*-

"""Implementation of DistMult."""

from typing import Any, ClassVar, Mapping, Type

from class_resolver import HintOrType, OptionalKwargs
from torch.nn import functional

from ..nbase import ERModel
from ...constants import DEFAULT_EMBEDDING_HPO_EMBEDDING_DIM_RANGE
from ...nn.init import xavier_normal_norm_, xavier_uniform_
from ...nn.modules import DistMultInteraction
from ...regularizers import LpRegularizer, Regularizer
from ...typing import Constrainer, Hint, Initializer

# Names exported by ``from <this module> import *``: only the DistMult model class.
__all__ = [
"DistMult",
]

class DistMult(ERModel):
    r"""An implementation of DistMult from [yang2014]_.

    This model simplifies RESCAL by restricting the matrices representing relations to diagonal matrices.

    DistMult is a simplification of :class:`pykeen.models.RESCAL` where the relation matrices
    $\textbf{W}_{r} \in \mathbb{R}^{d \times d}$ are restricted to diagonal matrices:

    .. math::

        f(h,r,t) = \textbf{e}_h^{T} \textbf{W}_r \textbf{e}_t = \sum_{i=1}^{d}(\textbf{e}_h)_i \cdot
        diag(\textbf{W}_r)_i \cdot (\textbf{e}_t)_i

    Because of its restriction to diagonal matrices, DistMult is more computationally efficient than RESCAL,
    but at the same time it is less expressive. For instance, it is not able to model anti-symmetric relations,
    since $f(h,r, t) = f(t,r,h)$. This can alternatively be formulated with relation vectors
    $\textbf{r}_r \in \mathbb{R}^d$ and the Hadamard operator and the $l_1$ norm.

    .. note::

        DistMult uses a hard constraint on the embedding norm, but applies a (soft) regularization term on the
        relation vector norms

    .. math::

        f(h,r,t) = \|\textbf{e}_h \odot \textbf{r}_r \odot \textbf{e}_t\|_1

    Note:
      - For FB15k, Yang *et al.* report 2 negatives per each positive.

    .. seealso::

       - OpenKE `implementation of DistMult <https://github.com/thunlp/OpenKE/blob/master/models/DistMult.py>`_
    ---
    citation:
        author: Yang
        year: 2014
    """

    #: The default strategy for optimizing the model's hyper-parameters
    hpo_default: ClassVar[Mapping[str, Any]] = dict(
        embedding_dim=DEFAULT_EMBEDDING_HPO_EMBEDDING_DIM_RANGE,
    )
    #: The regularizer used by [yang2014]_ for DistMult
    #: In the paper, they use weight of 0.0001, mini-batch-size of 10, and dimensionality of vector 100
    #: Thus, when we use normalized regularization weight, the normalization factor is 10*sqrt(100) = 100, which is
    #: why the weight has to be increased by a factor of 100 to have the same configuration as in the paper.
    regularizer_default: ClassVar[Type[Regularizer]] = LpRegularizer
    #: The LP settings used by [yang2014]_ for DistMult
    regularizer_default_kwargs: ClassVar[Mapping[str, Any]] = dict(
        # NOTE(review): the derivation in the comment on ``regularizer_default`` gives 0.0001 * 100 = 0.01,
        # whereas the value used here is 0.1 -- confirm which of the two is intended.
        weight=0.1,
        p=2.0,
        normalize=True,
    )

    def __init__(
        self,
        *,
        embedding_dim: int = 50,
        entity_initializer: Hint[Initializer] = xavier_uniform_,
        entity_constrainer: Hint[Constrainer] = functional.normalize,
        relation_initializer: Hint[Initializer] = xavier_normal_norm_,
        regularizer: HintOrType[Regularizer] = LpRegularizer,
        regularizer_kwargs: OptionalKwargs = None,
        **kwargs,
    ) -> None:
        r"""Initialize DistMult.

        :param embedding_dim:
            The embedding dimension $d$, used as the shape of both the entity and the relation
            representations. Is usually $d \in [50, 300]$.
        :param entity_initializer:
            The initializer for the entity representations. Default: Xavier uniform.
        :param entity_constrainer:
            The constrainer for the entity representations.
            Default: constrain entity embeddings to unit length.
        :param relation_initializer:
            The initializer for the relation representations.
            Default: relations are initialized to unit length (but not constrained).
        :param regularizer:
            the *relation* representation regularizer
        :param regularizer_kwargs:
            additional keyword-based parameters. defaults to :attr:`DistMult.regularizer_default_kwargs` for the
            default regularizer
        :param kwargs:
            Remaining keyword arguments to forward to :class:`pykeen.models.ERModel`
        """
        # The paper's LP settings are only applied when the caller kept the default regularizer
        # *class* and did not pass explicit keyword arguments; any other regularizer choice
        # is forwarded untouched.
        if regularizer is LpRegularizer and regularizer_kwargs is None:
            regularizer_kwargs = DistMult.regularizer_default_kwargs
        super().__init__(
            interaction=DistMultInteraction,
            entity_representations_kwargs=dict(
                shape=embedding_dim,
                initializer=entity_initializer,
                constrainer=entity_constrainer,
                # note: DistMult only regularizes the relation embeddings;
                #       entity embeddings are hard constrained instead
            ),
            relation_representations_kwargs=dict(
                shape=embedding_dim,
                initializer=relation_initializer,
                regularizer=regularizer,
                regularizer_kwargs=regularizer_kwargs,
            ),
            **kwargs,
        )