TransformedRepresentation
- class TransformedRepresentation(transformation: Module, max_id: int | None = None, shape: int | Sequence[int] | None = None, base: str | Representation | type[Representation] | None = None, base_kwargs: Mapping[str, Any] | None = None, **kwargs)
Bases: Representation
A (learnable) transformation upon base representations.
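As a rough mental model, the idea can be sketched without any framework: look up the base vectors for the requested IDs, then apply the transformation to each of them. The class `TransformedLookup` and the function `double_then_relu` below are hypothetical illustrations, not part of PyKEEN. In PyKEEN the base is a `Representation` and the transformation an `nn.Module` operating on tensors.

```python
from typing import Callable, List, Sequence

Vector = List[float]


class TransformedLookup:
    """Hypothetical, framework-free sketch: output = transformation(base[index])."""

    def __init__(self, base: Sequence[Vector], transformation: Callable[[Vector], Vector]) -> None:
        # the base "representation": a fixed lookup table with one vector per ID
        self.base = [list(v) for v in base]
        self.transformation = transformation

    @property
    def max_id(self) -> int:
        # the number of representations is inherited from the base table
        return len(self.base)

    def __call__(self, indices: Sequence[int]) -> List[Vector]:
        # look up the base vectors, then apply the transformation to each one
        return [self.transformation(self.base[i]) for i in indices]


def double_then_relu(x: Vector) -> Vector:
    # stand-in for a learnable module such as an MLP
    return [max(0.0, 2.0 * v) for v in x]


lookup = TransformedLookup(
    base=[[1.0, -1.0], [0.5, 0.25], [-2.0, 3.0]],
    transformation=double_then_relu,
)
print(lookup.max_id)   # 3
print(lookup([0, 2]))  # [[2.0, 0.0], [0.0, 6.0]]
```

Because the transformation only sees the looked-up vectors, the number of representations always comes from the base, while their shape is whatever the transformation outputs.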
In the following example, we create fixed (non-trainable) random walk positional encoding (RWPE) features for the entities of the Nations dataset and transform them with a trainable 2-layer MLP.
```python
"""Using transformed representations."""

from torch import nn

from pykeen.datasets import get_dataset
from pykeen.nn import TransformedRepresentation, init

dataset = get_dataset(dataset="nations")

# Create random walk features
# We used dim+1 for the RWPE initialization as by default it doesn't return the first dimension of 0's
# That is, in the default setup, dim = 33 would return a 32d vector
dim = 32
initializer = init.RandomWalkPositionalEncodingInitializer(
    triples_factory=dataset.training,
    dim=dim + 1,
)

# build an MLP
hidden = 64
mlp = nn.Sequential(
    nn.Linear(in_features=dim, out_features=hidden),
    nn.ReLU(),
    nn.Linear(in_features=hidden, out_features=dim),
)

r = TransformedRepresentation(
    transformation=mlp,
    # note: this will create an Embedding base representation
    base_kwargs=dict(
        max_id=dataset.num_entities,
        shape=(dim,),
        initializer=initializer,
        trainable=False,
    ),
)
```
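The comment about `dim + 1` can be made concrete with a plain-Python sketch of random walk positional encodings. It assumes the common definition, where node i's k-th feature is the probability (T^k)[i][i] that a k-step random walk starting at i returns to i, with T the row-normalized adjacency matrix. It is not PyKEEN's implementation: the helpers `matmul` and `random_walk_pe` are hypothetical, and how PyKEEN builds the graph from triples is not reproduced here. Without self-loops, the first power has an all-zero diagonal, so dropping it turns `dim` powers into `dim - 1` features.

```python
from typing import List

Matrix = List[List[float]]


def matmul(a: Matrix, b: Matrix) -> Matrix:
    # naive dense matrix product, sufficient for a tiny illustration
    n, m, p = len(a), len(b), len(b[0])
    return [[sum(a[i][k] * b[k][j] for k in range(m)) for j in range(p)] for i in range(n)]


def random_walk_pe(adjacency: Matrix, dim: int, skip_first_power: bool = True) -> List[List[float]]:
    """Hypothetical sketch: per node i, the return probabilities (T^k)[i][i] for k = 1..dim.

    With skip_first_power=True, k = 1 is dropped, so each node gets dim - 1 features.
    """
    n = len(adjacency)
    degrees = [sum(row) for row in adjacency]
    # row-normalize; isolated nodes (degree 0) keep an all-zero row
    t = [[adjacency[i][j] / degrees[i] if degrees[i] else 0.0 for j in range(n)] for i in range(n)]
    features: List[List[float]] = [[] for _ in range(n)]
    power = t
    for k in range(1, dim + 1):
        if k > 1:
            power = matmul(power, t)
        if skip_first_power and k == 1:
            continue  # without self-loops, diag(T^1) is all zeros
        for i in range(n):
            features[i].append(power[i][i])
    return features


# a triangle graph: three nodes, every pair connected, no self-loops
triangle = [[0, 1, 1], [1, 0, 1], [1, 1, 0]]
print(random_walk_pe(triangle, dim=3))  # [[0.5, 0.25], [0.5, 0.25], [0.5, 0.25]]
print(len(random_walk_pe(triangle, dim=33)[0]))  # 32
```

Requesting `dim=33` yields 32 features per node in this sketch, which matches the 32-dimensional input expected by the first `nn.Linear` layer in the example.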
Initialize the representation.
- Parameters:
transformation (nn.Module) – the transformation applied to the outputs of the base representation
max_id (int | None) – the number of representations. If provided, must match the base representation's max_id
shape (OneOrSequence[int] | None) – the individual representations’ shape. If provided, must match the output shape of the transformation
base (HintOrType[Representation]) – the base representation, or a hint thereof, cf. representation_resolver
base_kwargs (OptionalKwargs) – keyword-based parameters used to instantiate the base representation
kwargs – additional keyword-based parameters passed to Representation.__init__().
- Raises:
ValueError – if the max_id or shape does not match
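The interplay between `max_id`, `shape`, and the `ValueError` above can be sketched as follows, assuming that omitted values are taken from the base representation and from the transformation's output. The helper `check_max_id_and_shape` is hypothetical, and here the transformation's output shape is simply passed in; how PyKEEN actually determines it is not shown.

```python
from typing import Optional, Sequence, Tuple, Union

OneOrSequence = Union[int, Sequence[int]]


def normalize_shape(shape: OneOrSequence) -> Tuple[int, ...]:
    # accept a bare int (e.g. 32) or a sequence (e.g. (32,)) and return a tuple
    if isinstance(shape, int):
        return (shape,)
    return tuple(shape)


def check_max_id_and_shape(
    base_max_id: int,
    output_shape: OneOrSequence,
    max_id: Optional[int] = None,
    shape: Optional[OneOrSequence] = None,
) -> Tuple[int, Tuple[int, ...]]:
    """Hypothetical helper: infer omitted values, reject mismatching ones."""
    expected_shape = normalize_shape(output_shape)
    if max_id is None:
        max_id = base_max_id
    elif max_id != base_max_id:
        raise ValueError(f"max_id={max_id} does not match the base max_id={base_max_id}")
    if shape is None:
        return max_id, expected_shape
    given_shape = normalize_shape(shape)
    if given_shape != expected_shape:
        raise ValueError(f"shape={given_shape} does not match the transformation output shape {expected_shape}")
    return max_id, given_shape


print(check_max_id_and_shape(10, 32))                         # (10, (32,))
print(check_max_id_and_shape(10, (32,), max_id=10, shape=32))  # (10, (32,))
```

Normalizing both shapes to tuples first means `32` and `(32,)` are treated as the same shape, in line with the `OneOrSequence[int]` type of `shape`.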
Note
The parameter pair (base, base_kwargs) is used for pykeen.nn.representation_resolver.
An explanation of resolvers and how to use them is given in https://class-resolver.readthedocs.io/en/latest/.
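The (hint, kwargs) pattern behind `base` and `base_kwargs` can be illustrated with a toy resolver. Everything below (`ToyResolver` and the stand-in classes `Representation`, `Embedding`, and `Constant`) is hypothetical and deliberately simpler than the class-resolver library, which, for example, also normalizes names. Defaulting to `Embedding` mirrors the example's comment that omitting `base` creates an Embedding base representation.

```python
from typing import Any, Mapping, Optional, Sequence, Type, Union


class Representation:
    """Toy stand-in, not pykeen.nn.Representation."""

    def __init__(self, max_id: int) -> None:
        self.max_id = max_id


class Embedding(Representation):
    """Toy stand-in for the default base representation."""


class Constant(Representation):
    """Another toy stand-in, used only to show lookup by name or class."""


Hint = Union[None, str, Representation, Type[Representation]]


class ToyResolver:
    """Hypothetical sketch of resolving a (hint, kwargs) pair; not the class-resolver API."""

    def __init__(self, classes: Sequence[Type[Representation]], default: Type[Representation]) -> None:
        # index classes by their lower-cased name, e.g. "embedding"
        self.lookup = {cls.__name__.lower(): cls for cls in classes}
        self.default = default

    def make(self, hint: Hint, kwargs: Optional[Mapping[str, Any]] = None) -> Representation:
        if isinstance(hint, Representation):
            # an already-built instance is used as-is; kwargs are not applied
            return hint
        if hint is None:
            cls = self.default
        elif isinstance(hint, str):
            cls = self.lookup[hint.lower()]
        else:
            cls = hint
        return cls(**dict(kwargs or {}))


resolver = ToyResolver([Embedding, Constant], default=Embedding)
base = resolver.make(None, dict(max_id=5))
print(type(base).__name__, base.max_id)  # Embedding 5
```

Passing `base` as a string, a class, or an instance selects the base representation, while `base_kwargs` only matters when a new instance has to be constructed.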