
# -*- coding: utf-8 -*-

"""A wrapper which combines an interaction function with NodePiece entity representations."""

import logging
from typing import Any, Callable, ClassVar, Mapping, Optional, Sequence, Union

import torch
from class_resolver import Hint, HintOrType
from torch import nn

from ..nbase import ERModel
from ...constants import DEFAULT_EMBEDDING_HPO_EMBEDDING_DIM_RANGE
from ...nn.emb import EmbeddingSpecification, NodePieceRepresentation, SubsetRepresentationModule
from ...nn.modules import DistMultInteraction, Interaction
from ...triples.triples_factory import CoreTriplesFactory

__all__ = [
    "NodePiece",
]

logger = logging.getLogger(__name__)


class _ConcatMLP(nn.Sequential):
    """A 2-layer MLP with ReLU activation and dropout applied to the concatenation of token representations.

    This is for conveniently choosing a configuration similar to the paper. For more complex aggregation mechanisms,
    pass an arbitrary callable instead.

    .. seealso:: https://github.com/migalkin/NodePiece/blob/d731c9990/lp_rp/pykeen105/nodepiece_rotate.py#L57-L65
    """

    def __init__(
        self,
        num_tokens: int,
        embedding_dim: int,
        dropout: float = 0.1,
        ratio: Union[int, float] = 2,
    ):
        """
        Initialize the module.

        :param num_tokens:
            the number of tokens
        :param embedding_dim:
            the embedding dimension for a single token
        :param dropout:
            the dropout value on the hidden layer
        :param ratio:
            the ratio of the hidden layer size to the embedding dimension.
        """
        hidden_dim = int(ratio * embedding_dim)
        super().__init__(
            nn.Linear(num_tokens * embedding_dim, hidden_dim),
            nn.Dropout(dropout),
            nn.ReLU(),
            nn.Linear(hidden_dim, embedding_dim),
        )

    def forward(self, xs: torch.FloatTensor, dim: int) -> torch.FloatTensor:  # noqa: D102
        # dim is only a parameter to match the signature of torch.mean / torch.sum;
        # this helper class is not intended to be used from outside of this module
        assert dim == -2
        return super().forward(xs.view(*xs.shape[:-2], -1))
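
# Illustrative sketch (not part of the library API): how the concat-MLP aggregation above is applied
# to the token representations of an entity. Token embeddings of shape
# ``(..., num_tokens, embedding_dim)`` are flattened over the last two dimensions and projected back
# to a single ``embedding_dim``-sized vector, e.g.:
#
#     mlp = _ConcatMLP(num_tokens=2, embedding_dim=64)
#     tokens = torch.rand(3, 2, 64)  # (batch_size=3, num_tokens=2, embedding_dim=64)
#     pooled = mlp(tokens, dim=-2)   # -> shape (3, 64)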


class NodePiece(ERModel):
    """A wrapper which combines an interaction function with NodePiece entity representations from [galkin2021]_.

    This model uses the :class:`pykeen.nn.emb.NodePieceRepresentation` instead of a typical
    :class:`pykeen.nn.emb.Embedding` to more efficiently store representations.
    ---
    citation:
        author: Galkin
        year: 2021
        link: https://arxiv.org/abs/2106.12144
        github: https://github.com/migalkin/NodePiece
    """

    hpo_default: ClassVar[Mapping[str, Any]] = dict(
        embedding_dim=DEFAULT_EMBEDDING_HPO_EMBEDDING_DIM_RANGE,
    )

    def __init__(
        self,
        *,
        triples_factory: CoreTriplesFactory,
        num_tokens: int = 2,
        embedding_dim: int = 64,
        embedding_specification: Optional[EmbeddingSpecification] = None,
        interaction: HintOrType[Interaction] = DistMultInteraction,
        aggregation: Hint[Callable[[torch.Tensor, int], torch.Tensor]] = None,
        shape: Optional[Sequence[int]] = None,
        **kwargs,
    ) -> None:
        """
        Initialize the model.

        :param triples_factory:
            the triples factory. Must have ``create_inverse_triples`` set to ``True``.
        :param num_tokens:
            the number of relations to use to represent each entity, cf.
            :class:`pykeen.nn.emb.NodePieceRepresentation`.
        :param embedding_dim:
            the embedding dimension. Only used if ``embedding_specification`` is not given.
        :param embedding_specification:
            the embedding specification.
        :param interaction:
            the interaction module, or a hint for it.
        :param aggregation:
            aggregation of multiple token representations to a single entity representation. By default,
            this uses :func:`torch.mean`. If a string is provided, the module assumes that this refers to a
            top-level torch function, e.g. "mean" for :func:`torch.mean`, or "sum" for :func:`torch.sum`.
            An aggregation can also have trainable parameters, e.g., ``MLP(mean(MLP(tokens)))``
            (cf. DeepSets from [zaheer2017]_). In this case, the module has to be created outside of this
            component. Moreover, we support providing "mlp" as a shortcut to use the MLP aggregation version
            from [galkin2021]_.

            We could also have aggregations which result in differently shaped output, e.g. a concatenation of
            all token embeddings resulting in shape ``(num_tokens * d,)``. In this case, `shape` must be
            provided.

            The aggregation takes two arguments: the (batched) tensor of token representations, in shape
            ``(*, num_tokens, *dt)``, and the index along which to aggregate.
        :param shape:
            the shape of an individual representation. Only necessary if the aggregation results in a change of
            dimensions. This will only be the case if the aggregation is an *ad hoc* function.
        :param kwargs:
            additional keyword-based arguments passed to :meth:`ERModel.__init__`
        :raises ValueError:
            if the triples factory does not create inverse triples
        """
        if not triples_factory.create_inverse_triples:
            raise ValueError(
                "The provided triples factory does not create inverse triples. However, for the node piece "
                "representations inverse relation representations are required.",
            )

        embedding_specification = embedding_specification or EmbeddingSpecification(
            shape=(embedding_dim,),
        )

        # create an MLP for the "mlp" string aggregation shortcut
        if aggregation == "mlp":
            aggregation = _ConcatMLP(
                num_tokens=num_tokens,
                embedding_dim=embedding_dim,
            )

        # always create representations for normal and inverse relations and padding
        relation_representations = embedding_specification.make(
            num_embeddings=2 * triples_factory.real_num_relations + 1,
        )
        entity_representations = NodePieceRepresentation(
            triples_factory=triples_factory,
            token_representation=relation_representations,
            aggregation=aggregation,
            shape=shape,
            num_tokens=num_tokens,
        )
        super().__init__(
            triples_factory=triples_factory,
            interaction=interaction,
            entity_representations=entity_representations,
            relation_representations=SubsetRepresentationModule(  # hide padding relation
                relation_representations,
                max_id=triples_factory.num_relations,
            ),
            **kwargs,
        )
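
# Illustrative usage sketch (assumes the pykeen.datasets module is available; this snippet is not
# part of the module itself): NodePiece requires a triples factory that creates inverse triples, and
# passing aggregation="mlp" selects the _ConcatMLP aggregation defined above.
#
#     from pykeen.datasets import Nations
#
#     dataset = Nations(create_inverse_triples=True)
#     model = NodePiece(
#         triples_factory=dataset.training,
#         num_tokens=2,
#         embedding_dim=64,
#         aggregation="mlp",
#     )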