TransformedRepresentation

class TransformedRepresentation(transformation: Module, max_id: int | None = None, shape: int | Sequence[int] | None = None, base: str | Representation | type[Representation] | None = None, base_kwargs: Mapping[str, Any] | None = None, **kwargs)

Bases: Representation

A (learnable) transformation upon base representations.

In the following example, we create representations that are obtained from a trainable transformation of fixed random-walk positional encoding (RWPE) features. We first load the dataset, here Nations:

>>> from pykeen.datasets import get_dataset
>>> dataset = get_dataset(dataset="nations")

Next, we create a random-walk positional encoding of dimension 32:

>>> from pykeen.nn import init
>>> dim = 32
>>> initializer = init.RandomWalkPositionalEncoding(triples_factory=dataset.training, dim=dim+1)

We pass dim + 1 to the RWPE initializer because, by default, it does not return the first dimension, which typically consists of zeros. That is, in the default setup, dim=33 yields a 32-dimensional vector.
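To make the note above concrete, here is a minimal plain-PyTorch sketch of random-walk positional encoding, assuming the standard definition: the diagonal of the first k powers of the row-normalized adjacency matrix, i.e., the probability of returning to a node after k steps. It is not PyKEEN's implementation; the function rwpe_sketch and its skip_first flag are hypothetical. Without self-loops, the diagonal of the first power is all zeros, so dropping it and requesting dim + 1 powers leaves dim informative features.

```python
import torch


def rwpe_sketch(edge_index: torch.Tensor, num_nodes: int, dim: int, skip_first: bool = True) -> torch.Tensor:
    """Hypothetical helper: random-walk positional encoding for a small undirected graph."""
    # Dense symmetric adjacency matrix (only suitable for small graphs)
    adj = torch.zeros(num_nodes, num_nodes)
    adj[edge_index[0], edge_index[1]] = 1.0
    adj[edge_index[1], edge_index[0]] = 1.0
    # Row-normalize to get the random-walk transition matrix; clamp avoids division by zero
    walk = adj / adj.sum(dim=1, keepdim=True).clamp_min(1.0)
    # Diagonal of walk^k = probability of returning to the start node after k steps
    features = []
    power = torch.eye(num_nodes)
    for _ in range(dim):
        power = power @ walk
        features.append(torch.diagonal(power))
    x = torch.stack(features, dim=-1)  # shape: (num_nodes, dim)
    # The first power's diagonal is all zeros when the graph has no self-loops
    return x[:, 1:] if skip_first else x


# Toy graph: a triangle 0-1-2 plus a pendant node 3 attached to node 2
edge_index = torch.tensor([[0, 1, 2, 2], [1, 2, 0, 3]])
full = rwpe_sketch(edge_index, num_nodes=4, dim=33, skip_first=False)
x = rwpe_sketch(edge_index, num_nodes=4, dim=33)
print(full.shape, x.shape)  # torch.Size([4, 33]) torch.Size([4, 32])
print(full[:, 0])  # all zeros: no self-loops
```

The dense matrix powers keep the sketch short; a practical implementation would use sparse operations for larger graphs.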

For the transformation, we use a simple two-layer MLP, which maps dim-dimensional inputs back to dim dimensions:

>>> from torch import nn
>>> hidden = 64
>>> mlp = nn.Sequential(
...     nn.Linear(in_features=dim, out_features=hidden),
...     nn.ReLU(),
...     nn.Linear(in_features=hidden, out_features=dim),
... )

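Finally, the transformed representation is given as follows. The base representation is instantiated from base_kwargs, with the RWPE features kept fixed via trainable=False: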

>>> from pykeen.nn import TransformedRepresentation
>>> r = TransformedRepresentation(
...     transformation=mlp,
...     base_kwargs=dict(
...         max_id=dataset.num_entities,
...         shape=(dim,),
...         initializer=initializer,
...         trainable=False,
...     ),
... )
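Conceptually, the resulting representation looks up the fixed base features for the requested indices and passes them through the trainable transformation. The following plain-PyTorch sketch illustrates this idea under that assumption; the class FrozenBaseTransform is hypothetical and is not PyKEEN's implementation. Random features stand in for the RWPE, and 14 entities is an arbitrary toy size.

```python
from typing import Optional

import torch
from torch import nn


class FrozenBaseTransform(nn.Module):
    """Hypothetical sketch: trainable transformation on top of fixed base features."""

    def __init__(self, features: torch.Tensor, transformation: nn.Module) -> None:
        super().__init__()
        # Fixed base features, stored as a non-trainable embedding table
        self.base = nn.Embedding.from_pretrained(features, freeze=True)
        self.transformation = transformation

    def forward(self, indices: Optional[torch.Tensor] = None) -> torch.Tensor:
        # Without indices, return the representations of all ids
        if indices is None:
            indices = torch.arange(self.base.num_embeddings)
        return self.transformation(self.base(indices))


num_entities, dim, hidden = 14, 32, 64
mlp = nn.Sequential(nn.Linear(dim, hidden), nn.ReLU(), nn.Linear(hidden, dim))
r = FrozenBaseTransform(features=torch.rand(num_entities, dim), transformation=mlp)

out = r(torch.tensor([0, 3, 5]))
print(out.shape)  # torch.Size([3, 32])
out.sum().backward()
print(r.base.weight.grad)  # None: the base features receive no gradient
```

Only the MLP's parameters are updated during training, which is the point of combining a fixed encoding with a learnable transformation.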

Initialize the representation.

Parameters:
  • transformation (nn.Module) – the (learnable) transformation applied to the base representations

  • max_id (int | None) – the number of representations. If provided, it must match the base representation's max_id.

  • shape (OneOrSequence[int] | None) – the individual representations’ shape. If provided, it must match the output shape of the transformation.

  • base (HintOrType[Representation]) – the base representation, or a hint thereof, cf. representation_resolver.

  • base_kwargs (OptionalKwargs) – keyword-based parameters used to instantiate the base representation.

  • kwargs – additional keyword-based parameters passed to Representation.__init__().

Raises:
  • ValueError – if max_id or shape is provided and does not match the base representation's max_id or the transformation's output shape, respectively
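The checks behind this ValueError can be sketched as follows, assuming the output shape is inferred by passing a single base representation through the transformation. The function resolve_max_id_and_shape is hypothetical and not part of PyKEEN; it works on a plain tensor of base features rather than on a Representation.

```python
from typing import Optional, Sequence, Tuple, Union

import torch
from torch import nn


def resolve_max_id_and_shape(
    base_features: torch.Tensor,
    transformation: nn.Module,
    max_id: Optional[int] = None,
    shape: Union[int, Sequence[int], None] = None,
) -> Tuple[int, Tuple[int, ...]]:
    """Hypothetical helper: infer max_id and shape, raising ValueError on mismatch."""
    base_max_id = base_features.shape[0]
    if max_id is not None and max_id != base_max_id:
        raise ValueError(f"max_id={max_id} does not match the base max_id={base_max_id}")
    # Infer the output shape by transforming a single base representation
    with torch.no_grad():
        out_shape = tuple(transformation(base_features[:1]).shape[1:])
    if shape is not None:
        # Normalize an int shape to a 1-tuple before comparing
        shape = (shape,) if isinstance(shape, int) else tuple(shape)
        if shape != out_shape:
            raise ValueError(f"shape={shape} does not match the transformation output shape {out_shape}")
    return base_max_id, out_shape


features = torch.rand(14, 32)
proj = nn.Linear(32, 16)
print(resolve_max_id_and_shape(features, proj))  # (14, (16,))
try:
    resolve_max_id_and_shape(features, proj, shape=32)
except ValueError as error:
    print(error)
```

Inferring the shape from a forward pass avoids having to inspect the internals of an arbitrary nn.Module.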