TransformedRepresentation
- class TransformedRepresentation(transformation: Module, max_id: int | None = None, shape: int | Sequence[int] | None = None, base: str | Representation | type[Representation] | None = None, base_kwargs: Mapping[str, Any] | None = None, **kwargs)
Bases: Representation
A (learnable) transformation upon base representations.
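Conceptually, the representation of an index i is transformation(base(i)): a lookup in the base representation followed by the (possibly learnable) transformation. The following pure-Python sketch illustrates only that composition. The class name ToyTransformedRepresentation and the toy transformation are hypothetical; PyKEEN applies an nn.Module to batched tensors rather than a Python function to individual rows.

```python
from typing import Callable, List, Sequence

Vector = List[float]


class ToyTransformedRepresentation:
    """Hypothetical stand-in: look up rows of a base table, then transform them."""

    def __init__(self, base: Sequence[Vector], transformation: Callable[[Vector], Vector]) -> None:
        # copy the base table; here it plays the role of fixed (non-trainable) features
        self.base = [list(row) for row in base]
        # the transformation is the part that would be learned in practice
        self.transformation = transformation

    def __call__(self, indices: Sequence[int]) -> List[Vector]:
        # representation(i) = transformation(base(i)) for every requested index
        return [self.transformation(self.base[i]) for i in indices]


def double(x: Vector) -> Vector:
    """Toy transformation: scale every feature by two."""
    return [2.0 * v for v in x]


rep = ToyTransformedRepresentation(base=[[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]], transformation=double)
print(rep([2, 0]))  # [[2.0, 2.0], [2.0, 0.0]]
```

Keeping the base and the transformation separate is what allows fixed features (such as the random-walk encoding used below) to be combined with a trainable mapping.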
In the following example, we create representations which are obtained from a trainable transformation of fixed random walk encoding features. We first load the dataset, here Nations:
>>> from pykeen.datasets import get_dataset
>>> dataset = get_dataset(dataset="nations")
Next, we create a random-walk positional encoding of dimension 32:
>>> from pykeen.nn import init
>>> dim = 32
>>> initializer = init.RandomWalkPositionalEncoding(triples_factory=dataset.training, dim=dim+1)
We use dim+1 for the RWPE initialization because, by default, it omits the first dimension, which consists of 0's. That is, in the default setup, dim=33 returns a 32-dimensional vector.
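To see why one extra dimension is requested, here is a minimal pure-Python sketch of a random-walk positional encoding: each node's features are the diagonal entries of the first dim powers of the row-normalized adjacency matrix P = D^-1 A. Without self-loops, the diagonal of P itself is zero, so dropping that first power (the hypothetical skip_first_power flag below) turns dim=33 powers into 32 features. This is an illustration under assumptions, not PyKEEN's implementation; in particular, treating every triple as an undirected edge and the helper names are hypothetical.

```python
from typing import List, Sequence, Tuple

Matrix = List[List[float]]


def matmul(a: Matrix, b: Matrix) -> Matrix:
    """Plain matrix product of two dense square matrices."""
    n = len(a)
    return [[sum(a[i][k] * b[k][j] for k in range(n)) for j in range(n)] for i in range(n)]


def random_walk_matrix(edges: Sequence[Tuple[int, int]], num_nodes: int) -> Matrix:
    """Row-normalized adjacency P = D^-1 A, treating edges as undirected (assumption)."""
    adj = [[0.0] * num_nodes for _ in range(num_nodes)]
    for h, t in edges:
        adj[h][t] = 1.0
        adj[t][h] = 1.0
    # divide each row by the node degree; isolated nodes keep an all-zero row
    return [[v / sum(row) if sum(row) else 0.0 for v in row] for row in adj]


def rwpe(edges: Sequence[Tuple[int, int]], num_nodes: int, dim: int, skip_first_power: bool = True) -> Matrix:
    """Diagonals of P^1 .. P^dim per node, optionally dropping the (usually zero) P^1 column."""
    p = random_walk_matrix(edges, num_nodes)
    power = p
    features: Matrix = [[] for _ in range(num_nodes)]
    for _ in range(dim):
        for i in range(num_nodes):
            # return probability of a random walk of the current length
            features[i].append(power[i][i])
        power = matmul(power, p)
    if skip_first_power:
        features = [row[1:] for row in features]
    return features


triangle = [(0, 1), (1, 2), (2, 0)]
print(rwpe(triangle, num_nodes=3, dim=3)[0])  # [0.5, 0.25]
print(len(rwpe(triangle, num_nodes=3, dim=33)[0]))  # 32
```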
For the transformation, we use a simple 2-layer MLP:
>>> from torch import nn
>>> hidden = 64
>>> mlp = nn.Sequential(
...     nn.Linear(in_features=dim, out_features=hidden),
...     nn.ReLU(),
...     nn.Linear(in_features=hidden, out_features=dim),
... )
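The MLP maps dim -> hidden -> dim, so it preserves the 32-dimensional shape of the base features. Because the base representation in this example is created with trainable=False, the MLP's weights and biases should be the only trainable parameters. A quick count, using the fact that nn.Linear(in_features, out_features) stores a weight of shape (out_features, in_features) plus a bias of shape (out_features,); the helper linear_params is hypothetical:

```python
def linear_params(in_features: int, out_features: int, bias: bool = True) -> int:
    """Number of parameters of a fully connected layer (weight matrix plus optional bias)."""
    return in_features * out_features + (out_features if bias else 0)


dim, hidden = 32, 64
# the ReLU between the two layers has no parameters
mlp_params = linear_params(dim, hidden) + linear_params(hidden, dim)
print(mlp_params)  # 4192
```

That is 2048 + 64 parameters for the first layer and 2048 + 32 for the second.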
Finally, the transformed representation is given by:
>>> from pykeen.nn import TransformedRepresentation
>>> r = TransformedRepresentation(
...     transformation=mlp,
...     base_kwargs=dict(max_id=dataset.num_entities, shape=(dim,), initializer=initializer, trainable=False),
... )
Initialize the representation.
- Parameters:
transformation (nn.Module) – the transformation
max_id (int | None) – the number of representations. If provided, it must match the base representation's max_id
shape (OneOrSequence[int] | None) – the individual representations’ shape. If provided, it must match the output shape of the transformation
base (HintOrType[Representation]) – the base representation, or a hint thereof, cf. representation_resolver
base_kwargs (OptionalKwargs) – keyword-based parameters used to instantiate the base representation
kwargs – additional keyword-based parameters passed to Representation.__init__().
- Raises:
ValueError – if the max_id or shape does not match
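The documented checks on max_id and shape can be sketched as follows. This is a hypothetical helper, not PyKEEN code: how PyKEEN determines the transformation's output shape is not described here, so the sketch simply receives it as transformed_shape. A single int for shape is treated as a one-dimensional shape, matching the OneOrSequence[int] type.

```python
from typing import Optional, Sequence, Tuple, Union

Shape = Tuple[int, ...]


def resolve_max_id_and_shape(
    base_max_id: int,
    transformed_shape: Sequence[int],
    max_id: Optional[int] = None,
    shape: Union[int, Sequence[int], None] = None,
) -> Tuple[int, Shape]:
    """Hypothetical helper mirroring the documented max_id / shape validation."""
    # max_id defaults to the base's number of representations and must agree with it
    if max_id is None:
        max_id = base_max_id
    elif max_id != base_max_id:
        raise ValueError(f"max_id={max_id} does not match base max_id={base_max_id}")
    expected: Shape = tuple(transformed_shape)
    if shape is None:
        return max_id, expected
    # normalize a single int to a one-element tuple before comparing
    resolved: Shape = (shape,) if isinstance(shape, int) else tuple(shape)
    if resolved != expected:
        raise ValueError(f"shape={resolved} does not match transformation output shape {expected}")
    return max_id, resolved


print(resolve_max_id_and_shape(14, (32,), shape=32))  # (14, (32,))
```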