TransformedRepresentation
- class TransformedRepresentation(transformation, max_id=None, shape=None, base=None, base_kwargs=None, **kwargs)
Bases: Representation
A (learnable) transformation upon base representations.
In the following example, we create representations which are obtained from a trainable transformation of fixed random walk encoding features. We first load the dataset, here Nations:
>>> from pykeen.datasets import get_dataset
>>> dataset = get_dataset(dataset="nations")
Next, we create a random-walk positional encoding of dimension 32:
>>> from pykeen.nn import init
>>> dim = 32
>>> initializer = init.RandomWalkPositionalEncoding(triples_factory=dataset.training, dim=dim+1)
We use dim+1 for the RWPE initialization because, by default, it does not return the first dimension, which is all zeros. That is, in the default setup, dim = 33 yields a 32-dimensional vector.
For the transformation, we use a simple 2-layer MLP:
>>> from torch import nn
>>> hidden = 64
>>> mlp = nn.Sequential(
...     nn.Linear(in_features=dim, out_features=hidden),
...     nn.ReLU(),
...     nn.Linear(in_features=hidden, out_features=dim),
... )
Finally, the transformed representation is given as:
>>> from pykeen.nn import TransformedRepresentation
>>> r = TransformedRepresentation(
...     transformation=mlp,
...     base_kwargs=dict(max_id=dataset.num_entities, shape=(dim,), initializer=initializer, trainable=False),
... )
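The idea behind this setup, a frozen base whose vectors pass through a trainable transformation, can be illustrated with plain PyTorch. The following is a minimal sketch and not PyKEEN's implementation: the sizes are illustrative, the random feature matrix is only a stand-in for the random-walk positional encoding, and nn.Embedding stands in for the base representation.

```python
import torch
from torch import nn

torch.manual_seed(0)

# Illustrative sizes (assumptions, not read from the dataset)
num_entities, dim, hidden = 14, 32, 64

# Stand-in for fixed features such as a random-walk positional encoding
features = torch.rand(num_entities, dim)

# Frozen base: the lookup table is excluded from gradient updates
base = nn.Embedding.from_pretrained(features, freeze=True)

# Trainable transformation: the same 2-layer MLP as in the example above
mlp = nn.Sequential(
    nn.Linear(in_features=dim, out_features=hidden),
    nn.ReLU(),
    nn.Linear(in_features=hidden, out_features=dim),
)

# Look up the base vectors for some indices, then transform them
indices = torch.arange(5)
x = mlp(base(indices))  # shape: (5, dim)

# Gradients reach only the MLP parameters, not the frozen base
x.sum().backward()
```

Only the MLP weights receive gradients here, which corresponds to passing trainable=False to the base representation in the example above.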
Initialize the representation.
- Parameters:
  transformation (Module) – the transformation
  max_id (Optional[int]) – the number of representations. If provided, must match the base max id
  shape (Union[int, Sequence[int], None]) – the individual representations’ shape. If provided, must match the output shape of the transformation
  base (Union[str, Representation, Type[Representation], None]) – the base representation, or a hint thereof, cf. representation_resolver
  base_kwargs (Optional[Mapping[str, Any]]) – keyword-based parameters used to instantiate the base representation
  kwargs – additional keyword-based parameters passed to Representation.__init__().
- Raises:
ValueError – if the max_id or shape does not match
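The consistency rule described by the parameters and the ValueError can be pictured roughly as follows. This is a hedged sketch only: check_consistency and its arguments are hypothetical names introduced for illustration and do not reflect PyKEEN's actual code.

```python
from typing import Optional, Sequence, Tuple, Union


def check_consistency(
    max_id: Optional[int],
    shape: Union[int, Sequence[int], None],
    base_max_id: int,
    output_shape: Sequence[int],
) -> Tuple[int, Tuple[int, ...]]:
    """Hypothetical helper: validate optional max_id/shape against inferred values."""
    output_shape = tuple(output_shape)
    # An explicitly given max_id must agree with the base representation
    if max_id is not None and max_id != base_max_id:
        raise ValueError(f"max_id={max_id} does not match base max_id={base_max_id}")
    if shape is not None:
        # Normalize an integer shape to a 1-tuple before comparing
        normalized = (shape,) if isinstance(shape, int) else tuple(shape)
        if normalized != output_shape:
            raise ValueError(f"shape={normalized} does not match output shape {output_shape}")
    # When omitted, both values are taken from the base and the transformation
    return base_max_id, output_shape
```

In this sketch, omitting max_id and shape is always valid, since both are then taken from the base and from the transformation's output.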