# Source code for pykeen.models.multimodal.distmult_literal

"""Implementation of the DistMultLiteral model."""

from collections.abc import Mapping
from typing import Any, ClassVar

import torch.nn as nn

from .base import LiteralModel
from ...constants import DEFAULT_DROPOUT_HPO_RANGE, DEFAULT_EMBEDDING_HPO_EMBEDDING_DIM_RANGE
from ...nn import ConcatProjectionCombination
from ...nn.modules import DistMultInteraction, Interaction
from ...triples import TriplesNumericLiteralsFactory

# Public API of this module: only the model class is exported by a star-import.
__all__ = [
    "DistMultLiteral",
]


class DistMultLiteral(LiteralModel):
    """An implementation of the LiteralE model with the DistMult interaction from [kristiadi2018]_.

    Entity embeddings and the numeric literals provided by the triples factory are merged by a
    :class:`~pykeen.nn.ConcatProjectionCombination`, and the resulting representations are scored
    with the :class:`~pykeen.nn.modules.DistMultInteraction`.

    ---
    name: DistMult Literal
    citation:
        author: Kristiadi
        year: 2018
        link: https://arxiv.org/abs/1802.00934
    """

    #: The default strategy for optimizing the model's hyper-parameters
    hpo_default: ClassVar[Mapping[str, Any]] = {
        "embedding_dim": DEFAULT_EMBEDDING_HPO_EMBEDDING_DIM_RANGE,
        "input_dropout": DEFAULT_DROPOUT_HPO_RANGE,
    }
    #: The default parameters for the default loss function class
    loss_default_kwargs: ClassVar[Mapping[str, Any]] = {"margin": 0.0}

    #: The interaction class handed to :class:`LiteralModel`; a class attribute so it is
    #: looked up via ``self`` at construction time
    interaction_cls: ClassVar[type[Interaction]] = DistMultInteraction

    def __init__(
        self,
        triples_factory: TriplesNumericLiteralsFactory,
        embedding_dim: int = 50,
        input_dropout: float = 0.0,
        **kwargs,
    ) -> None:
        """
        Initialize the model.

        :param triples_factory:
            the (training) triples factory; its ``literal_shape`` determines the size of the
            literal input of the combination
        :param embedding_dim:
            the embedding dimension, used for the entity and relation representations as well as
            for the output dimension of the combination
        :param input_dropout:
            the input dropout, passed as ``dropout`` to the combination,
            cf. :class:`~pykeen.nn.ConcatProjectionCombination`
        :param kwargs:
            additional keyword-based parameters passed to :meth:`LiteralModel.__init__`
        """
        super().__init__(
            triples_factory=triples_factory,
            interaction=self.interaction_cls,
            combination=ConcatProjectionCombination,
            combination_kwargs={
                # one input per representation that gets combined: the learned entity embedding
                # and the literal features (assumes ``literal_shape`` is 1-D, i.e. one literal
                # vector per entity -- TODO confirm against TriplesNumericLiteralsFactory)
                "input_dims": [embedding_dim, triples_factory.literal_shape[0]],
                # the combined representation has the same size as the relation embedding
                "output_dim": embedding_dim,
                "bias": True,
                "dropout": input_dropout,
                # no activation
                "activation": nn.Identity,
                "activation_kwargs": None,
            },
            entity_representations_kwargs=[
                {
                    "shape": embedding_dim,
                    "initializer": nn.init.xavier_normal_,
                },
            ],
            relation_representations_kwargs=[
                {
                    "shape": embedding_dim,
                    "initializer": nn.init.xavier_normal_,
                },
            ],
            **kwargs,
        )