Source code for pykeen.models.multimodal.base

# -*- coding: utf-8 -*-

"""Base classes for multi-modal models."""

from class_resolver.utils import OneOrManyHintOrType, OneOrManyOptionalKwargs

from ..nbase import ERModel
from ...nn.init import PretrainedInitializer
from ...nn.modules import LiteralInteraction
from ...nn.representation import Embedding, Representation
from ...triples import TriplesNumericLiteralsFactory
from ...typing import HeadRepresentation, RelationRepresentation, TailRepresentation
from ...utils import upgrade_to_sequence

__all__ = [
    "LiteralModel",
]


class LiteralModel(ERModel[HeadRepresentation, RelationRepresentation, TailRepresentation], autoreset=False):
    """Base class for models with entity literals that uses combinations from :class:`pykeen.nn.combinations`."""

    def __init__(
        self,
        triples_factory: TriplesNumericLiteralsFactory,
        interaction: LiteralInteraction,
        entity_representations: OneOrManyHintOrType[Representation] = None,
        entity_representations_kwargs: OneOrManyOptionalKwargs = None,
        relation_representations: OneOrManyHintOrType[Representation] = None,
        relation_representations_kwargs: OneOrManyOptionalKwargs = None,
        **kwargs,
    ) -> None:
        """Initialize the model.

        The user-specified entity representations are extended by one additional,
        non-trainable :class:`Embedding` that is initialized from the numeric literals
        of ``triples_factory``.

        :param triples_factory:
            the triples factory; its numeric literals are obtained via
            ``get_numeric_literals_tensor()``.
        :param interaction:
            the literal interaction, forwarded unchanged to :class:`ERModel`.
        :param entity_representations:
            one or many entity representation hints. The literal embedding is appended
            after these.
        :param entity_representations_kwargs:
            one or many keyword-argument dicts for ``entity_representations``. A single
            entry (including the default ``None``) is broadcast to all given
            representations before the literal representation's kwargs are appended.
        :param relation_representations:
            relation representation hints, forwarded unchanged to :class:`ERModel`.
        :param relation_representations_kwargs:
            keyword arguments for ``relation_representations``, forwarded unchanged.
        :param kwargs:
            additional keyword-based parameters passed to :class:`ERModel`.
        """
        literals = triples_factory.get_numeric_literals_tensor()
        # The leading dimension indexes entities; ERModel fills in max_id itself,
        # so only the trailing dimensions are passed on as the embedding shape.
        _max_id, *shape = literals.shape

        base_representations = tuple(upgrade_to_sequence(entity_representations))
        base_kwargs = tuple(upgrade_to_sequence(entity_representations_kwargs))
        # Broadcast a single hint / single kwargs entry across the other sequence
        # *before* appending the literal representation. Otherwise, e.g., two hints
        # combined with the default kwargs=None (upgraded to ``(None,)``) would end up
        # as 3 representations vs. 2 kwargs once the literal entry is appended, and
        # the one-vs-many broadcasting downstream can no longer apply.
        if len(base_kwargs) == 1 and len(base_representations) > 1:
            base_kwargs = base_kwargs * len(base_representations)
        elif len(base_representations) == 1 and len(base_kwargs) > 1:
            base_representations = base_representations * len(base_kwargs)

        entity_representations = base_representations + (Embedding,)
        entity_representations_kwargs = base_kwargs + (
            dict(
                # max_id is intentionally omitted; it will be added by ERModel
                shape=shape,
                # initialize with the literal values and keep them frozen
                initializer=PretrainedInitializer(tensor=literals),
                trainable=False,
            ),
        )
        super().__init__(
            triples_factory=triples_factory,
            interaction=interaction,
            entity_representations=entity_representations,
            entity_representations_kwargs=entity_representations_kwargs,
            relation_representations=relation_representations,
            relation_representations_kwargs=relation_representations_kwargs,
            **kwargs,
        )