# Source code for pykeen.models.multimodal.distmult_literal

# -*- coding: utf-8 -*-

"""Implementation of the DistMultLiteral model."""

from typing import Any, ClassVar, Mapping, Optional

import torch
import torch.nn as nn

from ..base import MultimodalModel
from ..unimodal.distmult import DistMult
from ...constants import DEFAULT_DROPOUT_HPO_RANGE, DEFAULT_EMBEDDING_HPO_EMBEDDING_DIM_RANGE
from ...losses import Loss
from ...nn import Embedding
from ...triples import TriplesNumericLiteralsFactory
from ...typing import DeviceHint


class DistMultLiteral(DistMult, MultimodalModel):
    """An implementation of DistMultLiteral from [agustinus2018]_.

    Each entity representation is built in three steps:

    1. Concatenate the entity's learned embedding with its numeric literal vector.
    2. Project the result back to ``embedding_dim`` with a linear layer.
    3. Apply dropout.

    The DistMult interaction function then scores the head/tail representations
    against the relation embedding.
    """

    #: The default strategy for optimizing the model's hyper-parameters
    hpo_default: ClassVar[Mapping[str, Any]] = dict(
        embedding_dim=DEFAULT_EMBEDDING_HPO_EMBEDDING_DIM_RANGE,
        input_dropout=DEFAULT_DROPOUT_HPO_RANGE,
    )
    #: The default parameters for the default loss function class
    loss_default_kwargs: ClassVar[Mapping[str, Any]] = dict(margin=0.0)

    def __init__(
        self,
        triples_factory: TriplesNumericLiteralsFactory,
        embedding_dim: int = 50,
        input_dropout: float = 0.0,
        loss: Optional[Loss] = None,
        preferred_device: DeviceHint = None,
        random_seed: Optional[int] = None,
    ) -> None:
        """Initialize the model.

        :param triples_factory: Factory providing the triples, ``num_entities`` and the
            ``numeric_literals`` matrix (one row per entity, one column per literal).
        :param embedding_dim: Dimension of the entity and relation embeddings.
        :param input_dropout: Dropout probability applied to the combined entity
            representations.
        :param loss: Loss function; forwarded to :class:`DistMult`.
        :param preferred_device: Device preference; forwarded to :class:`DistMult`.
        :param random_seed: Random seed; forwarded to :class:`DistMult`.
        """
        super().__init__(
            triples_factory=triples_factory,
            embedding_dim=embedding_dim,
            loss=loss,
            preferred_device=preferred_device,
            random_seed=random_seed,
        )

        # Literal matrix: num_entities x num_literals.
        # The factory may hold it as a NumPy array, which cannot be assigned to a
        # parameter's ``.data`` directly, so convert it to a float tensor once.
        # Only this tensor (not the whole factory) is captured by the initializer.
        literals = torch.as_tensor(triples_factory.numeric_literals, dtype=torch.float)

        def _load_literals(x: torch.Tensor) -> torch.Tensor:
            # Copy the literals into ``x`` in place; ``copy_`` casts dtype and device
            # to match the weight. Return the same tensor, so this works whether the
            # caller uses the return value or relies on the in-place update.
            with torch.no_grad():
                return x.copy_(literals)

        self.numeric_literals = Embedding(
            num_embeddings=triples_factory.num_entities,
            embedding_dim=literals.shape[-1],
            initializer=_load_literals,
        )
        # The literal values are input data, not something to learn: exclude them
        # from gradient updates so the optimizer cannot overwrite them ...
        for parameter in self.numeric_literals.parameters():
            parameter.requires_grad_(False)
        # ... and load them now instead of relying on a later model-wide reset.
        # Re-running the initializer later yields the same values.
        self.numeric_literals.reset_parameters()

        # Number of columns corresponds to number of literals
        self.num_of_literals = self.numeric_literals.embedding_dim
        self.linear_transformation = nn.Linear(self.embedding_dim + self.num_of_literals, self.embedding_dim)
        self.inp_drop = torch.nn.Dropout(input_dropout)

    def _get_entity_representations(
        self,
        idx: Optional[torch.LongTensor],
    ) -> torch.FloatTensor:
        """Combine entity embeddings with their numeric literals.

        :param idx: Entity indices, forwarded unchanged to ``get_in_canonical_shape`` of
            both the entity embeddings and the literal embedding. ``forward`` may pass
            ``None`` here; what that selects is decided by ``get_in_canonical_shape``
            (presumably all entities -- confirm).
        :return: ``inp_drop(linear_transformation(cat([embedding, literals], dim=-1)))``.
        """
        emb = self.entity_embeddings.get_in_canonical_shape(indices=idx)
        lit = self.numeric_literals.get_in_canonical_shape(indices=idx)
        x = self.linear_transformation(torch.cat([emb, lit], dim=-1))
        return self.inp_drop(x)

    def forward(
        self,
        h_indices: Optional[torch.LongTensor],
        r_indices: Optional[torch.LongTensor],
        t_indices: Optional[torch.LongTensor],
    ) -> torch.FloatTensor:
        """Score triples with DistMult on literal-enriched entity representations.

        :param h_indices: Head entity indices, passed to ``_get_entity_representations``.
        :param r_indices: Relation indices, passed to
            ``relation_embeddings.get_in_canonical_shape``.
        :param t_indices: Tail entity indices, passed to ``_get_entity_representations``.
        :return: The result of ``self.interaction_function(h=h, r=r, t=t)``.
        """
        # TODO: this is very similar to ComplExLiteral, except a few dropout differences
        h = self._get_entity_representations(idx=h_indices)
        r = self.relation_embeddings.get_in_canonical_shape(indices=r_indices)
        t = self._get_entity_representations(idx=t_indices)
        return self.interaction_function(h=h, r=r, t=t)