TransformedRepresentation

class TransformedRepresentation(transformation, max_id=None, shape=None, base=None, base_kwargs=None, **kwargs)

Bases: Representation

A (learnable) transformation upon base representations.

In the following example, we create representations obtained by applying a trainable transformation to fixed random-walk positional encoding features. We first load the dataset, here Nations:

>>> from pykeen.datasets import get_dataset
>>> dataset = get_dataset(dataset="nations")

Next, we create a random-walk positional encoding of dimension 32:

>>> from pykeen.nn import init
>>> dim = 32
>>> initializer = init.RandomWalkPositionalEncoding(triples_factory=dataset.training, dim=dim+1)

We pass dim+1 to the initializer because, by default, it does not return the first dimension, which is all zeros. That is, in the default setup, dim=33 yields 32-dimensional vectors.

For the transformation, we use a simple 2-layer MLP:

>>> from torch import nn
>>> hidden = 64
>>> mlp = nn.Sequential(
...     nn.Linear(in_features=dim, out_features=hidden),
...     nn.ReLU(),
...     nn.Linear(in_features=hidden, out_features=dim),
... )

Finally, the transformed representation is given as:

>>> from pykeen.nn import TransformedRepresentation
>>> r = TransformedRepresentation(
...     transformation=mlp,
...     base_kwargs=dict(max_id=dataset.num_entities, shape=(dim,), initializer=initializer, trainable=False),
... )
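
To illustrate the idea behind this setup, the following is a minimal plain-PyTorch sketch of a frozen base lookup followed by a trainable transformation. It is not PyKEEN's implementation: the class name TransformedLookup is hypothetical, the random feature matrix is only a placeholder for the random-walk positional encodings, and the sizes are illustrative values.

```python
import torch
from torch import nn


class TransformedLookup(nn.Module):
    """Hypothetical stand-in: fixed base features followed by a trainable transformation."""

    def __init__(self, base_features: torch.Tensor, transformation: nn.Module):
        super().__init__()
        # freeze=True keeps the base features fixed during training
        self.base = nn.Embedding.from_pretrained(base_features, freeze=True)
        self.transformation = transformation

    def forward(self, indices=None):
        # without indices, return the transformed representations of all ids
        if indices is None:
            indices = torch.arange(self.base.num_embeddings)
        return self.transformation(self.base(indices))


# illustrative sizes (assumptions, not read from a dataset)
num_entities, dim, hidden = 14, 32, 64
features = torch.rand(num_entities, dim)  # placeholder for precomputed features
mlp = nn.Sequential(
    nn.Linear(in_features=dim, out_features=hidden),
    nn.ReLU(),
    nn.Linear(in_features=hidden, out_features=dim),
)
model = TransformedLookup(features, mlp)
x = model(torch.tensor([0, 3, 5]))
```

In this sketch, only the MLP's weights and biases receive gradients; the base features stay fixed.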

Initialize the representation.

Parameters:
  • transformation (Module) – the transformation

  • max_id (Optional[int]) – the number of representations. If provided, must match the base max id

  • shape (Union[int, Sequence[int], None]) – the individual representations’ shape. If provided, must match the output shape of the transformation

  • base (Union[str, Representation, Type[Representation], None]) – the base representation, or a hint thereof, cf. representation_resolver

  • base_kwargs (Optional[Mapping[str, Any]]) – keyword-based parameters used to instantiate the base representation

  • kwargs – additional keyword-based parameters passed to Representation.__init__().

Raises:

ValueError – if the provided max_id does not match the base max id, or the provided shape does not match the output shape of the transformation
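
The consistency rules for max_id and shape described above can be sketched in plain Python as follows. This is an assumption-laden illustration, not PyKEEN's actual code: the helper check_consistency and its arguments base_max_id and output_shape are hypothetical names introduced only for this example.

```python
from typing import Optional, Sequence, Tuple, Union


def check_consistency(
    max_id: Optional[int],
    shape: Union[int, Sequence[int], None],
    base_max_id: int,
    output_shape: Tuple[int, ...],
) -> Tuple[int, Tuple[int, ...]]:
    """Hypothetical helper: validate optional max_id/shape against known values."""
    # an explicitly given max_id must agree with the base representation
    if max_id is not None and max_id != base_max_id:
        raise ValueError(f"max_id={max_id} does not match base max_id={base_max_id}")
    if shape is not None:
        # normalize an integer shape to a 1-tuple before comparing
        shape = (shape,) if isinstance(shape, int) else tuple(shape)
        if shape != output_shape:
            raise ValueError(f"shape={shape} does not match output shape={output_shape}")
    # omitted values are simply taken from the base and the transformation output
    return base_max_id, output_shape


resolved = check_consistency(max_id=None, shape=32, base_max_id=14, output_shape=(32,))
```

Leaving max_id and shape as None skips the corresponding check, which mirrors the "if provided" wording of the parameter descriptions.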