TransformedRepresentation

class TransformedRepresentation(transformation: Module, max_id: int | None = None, shape: int | Sequence[int] | None = None, base: str | Representation | type[Representation] | None = None, base_kwargs: Mapping[str, Any] | None = None, **kwargs)

Bases: Representation

A (learnable) transformation upon base representations.

In the following example, we create representations by applying a trainable transformation, here a 2-layer MLP, to fixed (non-trainable) random walk positional encoding features.

"""Using transformed representations."""

from torch import nn

from pykeen.datasets import get_dataset
from pykeen.nn import TransformedRepresentation, init

dataset = get_dataset(dataset="nations")

# Create random walk features
# We pass dim + 1 to the RWPE initializer because, by default, it drops the first dimension, which is all zeros.
# That is, in the default setup, dim=33 returns a 32-dimensional vector.
dim = 32
initializer = init.RandomWalkPositionalEncodingInitializer(
    triples_factory=dataset.training,
    dim=dim + 1,
)

# build an MLP
hidden = 64
mlp = nn.Sequential(
    nn.Linear(in_features=dim, out_features=hidden),
    nn.ReLU(),
    nn.Linear(in_features=hidden, out_features=dim),
)
r = TransformedRepresentation(
    transformation=mlp,
    # note: this will create an Embedding base representation
    base_kwargs=dict(
        max_id=dataset.num_entities,
        shape=(dim,),
        initializer=initializer,
        trainable=False,
    ),
)
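The comment about `dim + 1` in the example above can be illustrated with a minimal, dependency-free sketch of random walk positional encodings. Each node's k-th feature is its probability of returning to itself after k steps of a random walk, i.e., the diagonal of the k-th power of the row-normalized adjacency matrix. Without self-loops, the one-step return probability is always zero, which is why that first column carries no information and is dropped. This is a simplified illustration under assumptions (a plain adjacency matrix instead of a triples factory; the function `random_walk_pe` and its `skip_first_power` flag are hypothetical), not PyKEEN's implementation.

```python
def matmul(a: list[list[float]], b: list[list[float]]) -> list[list[float]]:
    """Multiply two square matrices given as nested lists."""
    n = len(a)
    return [[sum(a[i][k] * b[k][j] for k in range(n)) for j in range(n)] for i in range(n)]


def random_walk_pe(
    adjacency: list[list[float]], num_powers: int, skip_first_power: bool = True
) -> list[list[float]]:
    """Hypothetical sketch: diagonals of T^1 .. T^num_powers, with T the random-walk matrix."""
    n = len(adjacency)
    # Row-normalize the adjacency matrix: T[i][j] = A[i][j] / deg(i).
    transition = []
    for row in adjacency:
        degree = sum(row)
        transition.append([value / degree if degree else 0.0 for value in row])
    features: list[list[float]] = [[] for _ in range(n)]
    power = transition  # T^1
    for _ in range(num_powers):
        # The diagonal of T^k holds each node's probability of returning after k steps.
        for i in range(n):
            features[i].append(power[i][i])
        power = matmul(power, transition)
    if skip_first_power:
        # Without self-loops, T^1 has a zero diagonal, so this column is dropped.
        features = [row[1:] for row in features]
    return features


# Triangle graph without self-loops.
triangle = [[0.0, 1.0, 1.0], [1.0, 0.0, 1.0], [1.0, 1.0, 0.0]]
dim = 2
pe = random_walk_pe(triangle, num_powers=dim + 1)
print(pe)  # [[0.5, 0.25], [0.5, 0.25], [0.5, 0.25]]
```

Requesting `dim + 1` powers and skipping the first yields exactly `dim` features per node, mirroring the reasoning in the example's comment.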

Initialize the representation.

Parameters:
  • transformation (nn.Module) – the transformation

  • max_id (int | None) – the number of representations. If provided, it must match the base representation's max_id.

  • shape (OneOrSequence[int] | None) – the individual representations’ shape. If provided, it must match the output shape of the transformation.

  • base (HintOrType[Representation]) – the base representation, or a hint thereof, cf. representation_resolver

  • base_kwargs (OptionalKwargs) – keyword-based parameters used to instantiate the base representation

  • kwargs – additional keyword-based parameters passed to Representation.__init__().
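Conceptually, the parameters combine as follows: the base representation supplies vectors for the requested indices, and `transformation` is applied to them, so each output has the transformation's output shape. The dependency-free sketch below mimics this with a frozen lookup table and a hand-written 2-layer MLP; the class `ToyTransformedRepresentation`, the helpers, and all weights are hypothetical stand-ins for `base` and `transformation`, not PyKEEN internals.

```python
from typing import Callable, Sequence

Vector = list[float]


def linear(weights: list[Vector], bias: Vector) -> Callable[[Vector], Vector]:
    """Return a function computing weights @ x + bias."""
    def apply(x: Vector) -> Vector:
        return [sum(w * xi for w, xi in zip(row, x)) + b for row, b in zip(weights, bias)]
    return apply


def relu(x: Vector) -> Vector:
    return [max(0.0, v) for v in x]


class ToyTransformedRepresentation:
    """Hypothetical sketch: a frozen base table followed by a transformation."""

    def __init__(self, transformation: Callable[[Vector], Vector], base: list[Vector]) -> None:
        self.transformation = transformation
        self.base = base  # stands in for a non-trainable base representation
        self.max_id = len(base)  # taken from the base

    def __call__(self, indices: Sequence[int]) -> list[Vector]:
        # Look up the base vectors first, then transform each of them.
        return [self.transformation(self.base[i]) for i in indices]


# A 2-layer MLP mapping 2 -> 3 -> 2 dimensions (weights chosen arbitrarily).
first = linear([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]], [0.0, 0.0, -1.0])
second = linear([[1.0, 0.0, 1.0], [0.0, 1.0, 0.0]], [0.0, 0.0])


def mlp(x: Vector) -> Vector:
    return second(relu(first(x)))


r = ToyTransformedRepresentation(transformation=mlp, base=[[1.0, 2.0], [-1.0, 0.5]])
print(r([0, 1]))  # [[3.0, 2.0], [0.0, 0.5]]
```

Only the transformation's weights would be trained in a real setup; the base table stays fixed, matching `trainable=False` in the example.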

Raises:

ValueError – if max_id does not match the base representation's max_id, or if shape does not match the transformation's output shape
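The consistency rules for `max_id` and `shape` can be summarized in a small helper. This is a hedged sketch: the function `check_transformed` and its arguments are hypothetical, and it assumes that omitted values are taken from the base representation (for `max_id`) and from the transformation's output (for `shape`).

```python
from typing import Optional, Sequence, Union

OneOrSequence = Union[int, Sequence[int]]


def _as_tuple(shape: OneOrSequence) -> tuple[int, ...]:
    # Normalize "int or sequence of ints" to a tuple, e.g. 32 -> (32,).
    return (shape,) if isinstance(shape, int) else tuple(shape)


def check_transformed(
    base_max_id: int,
    output_shape: OneOrSequence,
    max_id: Optional[int] = None,
    shape: Optional[OneOrSequence] = None,
) -> tuple[int, tuple[int, ...]]:
    """Hypothetical helper: validate or infer max_id and shape."""
    if max_id is None:
        max_id = base_max_id
    elif max_id != base_max_id:
        raise ValueError(f"max_id={max_id} does not match the base max_id={base_max_id}")
    inferred = _as_tuple(output_shape)
    if shape is not None and _as_tuple(shape) != inferred:
        raise ValueError(f"shape={shape} does not match the transformation output shape {inferred}")
    return max_id, inferred


print(check_transformed(base_max_id=10, output_shape=32))  # (10, (32,))
print(check_transformed(base_max_id=10, output_shape=(32,), shape=32))  # (10, (32,))
```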

Note

The parameter pair (base, base_kwargs) is passed to pykeen.nn.representation_resolver to instantiate the base representation.

An explanation of resolvers and how to use them is given in https://class-resolver.readthedocs.io/en/latest/.
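A resolver turns a hint (a name, a class, an instance, or None for a default) plus optional keyword arguments into an instance. The sketch below is a hypothetical, stripped-down illustration of that idea, not the class-resolver API. The classes `Representation`, `Embedding`, and `ConstantRepresentation` here are toy stand-ins, the default of `Embedding` follows the comment in the example above, and the name normalization (lowercasing and dropping a trailing "representation") is an assumption.

```python
from typing import Any, Mapping, Optional, Union


class Representation:
    """Toy stand-in for a base representation."""

    def __init__(self, max_id: int, shape: tuple[int, ...] = (1,)) -> None:
        self.max_id = max_id
        self.shape = shape


class Embedding(Representation):
    pass


class ConstantRepresentation(Representation):
    pass


def _normalize(name: str) -> str:
    # Assumption: case-insensitive names with an optional "representation" suffix.
    name = name.lower()
    return name[: -len("representation")] if name.endswith("representation") else name


REGISTRY = {_normalize(cls.__name__): cls for cls in (Embedding, ConstantRepresentation)}


def resolve(
    query: Union[str, Representation, type[Representation], None],
    pos_kwargs: Optional[Mapping[str, Any]] = None,
    default: type[Representation] = Embedding,
) -> Representation:
    """Hypothetical resolver: instances pass through, classes and names are instantiated."""
    if isinstance(query, Representation):
        return query  # already instantiated; pos_kwargs are ignored in this sketch
    if query is None:
        cls = default
    elif isinstance(query, str):
        cls = REGISTRY[_normalize(query)]
    else:
        cls = query
    return cls(**(pos_kwargs or {}))


base = resolve(None, dict(max_id=10, shape=(32,)))
print(type(base).__name__, base.max_id)  # Embedding 10
```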