Source code for pykeen.models.inductive.inductive_nodepiece

# -*- coding: utf-8 -*-

"""A wrapper which combines an interaction function with NodePiece entity representations."""

import logging
from typing import Any, Callable, ClassVar, Mapping, Optional

import more_itertools
import torch
from class_resolver import Hint, HintOrType, OptionalKwargs

from .base import InductiveERModel
from ...nn import (
from ...nn.node_piece import RelationTokenizer
from ...triples.triples_factory import CoreTriplesFactory

__all__ = [

logger = logging.getLogger(__name__)

[docs]class InductiveNodePiece(InductiveERModel): """A wrapper which combines an interaction function with NodePiece entity representations from [galkin2021]_. This model uses the :class:`pykeen.nn.NodePieceRepresentation` instead of a typical :class:`pykeen.nn.Embedding` to more efficiently store representations. --- citation: author: Galkin year: 2021 link: github: """ hpo_default: ClassVar[Mapping[str, Any]] = dict( embedding_dim=DEFAULT_EMBEDDING_HPO_EMBEDDING_DIM_RANGE, ) def __init__( self, *, triples_factory: CoreTriplesFactory, inference_factory: CoreTriplesFactory, num_tokens: int = 2, embedding_dim: int = 64, relation_representations_kwargs: OptionalKwargs = None, interaction: HintOrType[Interaction] = DistMultInteraction, aggregation: Hint[Callable[[torch.Tensor, int], torch.Tensor]] = None, validation_factory: Optional[CoreTriplesFactory] = None, test_factory: Optional[CoreTriplesFactory] = None, **kwargs, ) -> None: """ Initialize the model. :param triples_factory: the triples factory of training triples. Must have create_inverse_triples set to True. :param inference_factory: the triples factory of inference triples. Must have create_inverse_triples set to True. :param validation_factory: the triples factory of validation triples. Must have create_inverse_triples set to True. :param test_factory: the triples factory of testing triples. Must have create_inverse_triples set to True. :param num_tokens: the number of relations to use to represent each entity, cf. :class:`pykeen.nn.NodePieceRepresentation`. :param embedding_dim: the embedding dimension. Only used if embedding_specification is not given. :param relation_representations_kwargs: the relation representation parameters :param interaction: the interaction module, or a hint for it. :param aggregation: aggregation of multiple token representations to a single entity representation. By default, this uses :func:`torch.mean`. If a string is provided, the module assumes that this refers to a top-level torch function, e.g. "mean" for :func:`torch.mean`, or "sum" for func:`torch.sum`. An aggregation can also have trainable parameters, .e.g., ``MLP(mean(MLP(tokens)))`` (cf. DeepSets from [zaheer2017]_). In this case, the module has to be created outside of this component. Moreover, we support providing "mlp" as a shortcut to use the MLP aggregation version from [galkin2021]_. We could also have aggregations which result in differently shapes output, e.g. a concatenation of all token embeddings resulting in shape ``(num_tokens * d,)``. In this case, `shape` must be provided. The aggregation takes two arguments: the (batched) tensor of token representations, in shape ``(*, num_tokens, *dt)``, and the index along which to aggregate. :param kwargs: additional keyword-based arguments passed to :meth:`ERModel.__init__` :raises ValueError: if the triples factory does not create inverse triples """ if not triples_factory.create_inverse_triples: raise ValueError( "The provided triples factory does not create inverse triples. However, for the node piece " "representations inverse relation representations are required.", ) # always create representations for normal and inverse relations and padding relation_representations = representation_resolver.make( query=None, pos_kwargs=relation_representations_kwargs, max_id=2 * triples_factory.real_num_relations + 1, shape=embedding_dim, ) if validation_factory is None: validation_factory = inference_factory super().__init__( triples_factory=triples_factory, interaction=interaction, entity_representations=NodePieceRepresentation, entity_representations_kwargs=dict( triples_factory=triples_factory, tokenizers=RelationTokenizer, token_representations=relation_representations, aggregation=aggregation, num_tokens=num_tokens, ), relation_representations=SubsetRepresentation( # hide padding relation max_id=triples_factory.num_relations, base=relation_representations, ), validation_factory=validation_factory, testing_factory=test_factory, **kwargs, ) # note: we need to share the aggregation across representations, since the aggregation may have # trainable parameters np: NodePieceRepresentation = self.entity_representations[0] for representations in self._mode_to_representations.values(): assert len(representations) == 1 np2 = representations[0] assert isinstance(np2, NodePieceRepresentation) np2.combination = np.combination
[docs] def create_entity_representation_for_new_triples( self, triples_factory: CoreTriplesFactory ) -> NodePieceRepresentation: """ Create NodePiece representations for a new triples factory. The representations are initialized such that the same relation representations are used, and the aggregation is shared. :param triples_factory: the triples factory used for relation tokenization; must share the same relation to ID mapping. :return: a new NodePiece entity representation with shared relation tokenization and aggregation. :raises ValueError: if the triples factory does not request inverse triples, or the number of relations differs. """ if not triples_factory.create_inverse_triples: raise ValueError("Must create a triples factory with inverse triples") if triples_factory.num_relations != self.num_relations: raise ValueError(f"{self.num_relations=} != {triples_factory.num_relations=} !") # note: we cannot ensure the mapping also matches... # get relation representations relation_repr = assert isinstance(relation_repr, SubsetRepresentation) relation_repr = relation_repr.base # get combination np = assert isinstance(np, NodePieceRepresentation) combination = np.combination assert isinstance(combination, ConcatAggregationCombination) # get token representations tr = assert isinstance(tr, TokenizationRepresentation) num_tokens = tr.num_tokens # relation representations are shared new = NodePieceRepresentation( triples_factory=triples_factory, tokenizers=RelationTokenizer, token_representations=relation_repr, aggregation=combination.aggregation, num_tokens=num_tokens, ) # share combination weights new.combination = np.combination return new