TransformedRepresentation
- class TransformedRepresentation(transformation, max_id=None, shape=None, base=None, base_kwargs=None, **kwargs)
Bases: Representation
A (learnable) transformation applied to base representations.
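Conceptually, a transformed representation looks up the base representations for the requested indices and passes them through the transformation. The following is a minimal pure-Python sketch of that composition, assuming a fixed lookup table as the base and a plain callable as the transformation. The names `transformed_lookup`, `base`, and the toy transformation are hypothetical and not part of PyKEEN; the sketch also ignores training entirely, whereas in PyKEEN the transformation is a torch Module that can carry learnable parameters.

```python
from typing import Callable, List, Sequence

Vector = List[float]


def transformed_lookup(
    base_table: Sequence[Vector],
    transform: Callable[[Vector], Vector],
    indices: Sequence[int],
) -> List[Vector]:
    """Return transform(base[i]) for every requested index i (hypothetical helper)."""
    # Step 1: look up the (fixed) base representations by index.
    base_vectors = [list(base_table[i]) for i in indices]
    # Step 2: apply the transformation to each base vector independently.
    return [transform(v) for v in base_vectors]


# A toy base with three 2-dimensional vectors and a toy transformation (scale by 2, shift by 1).
base = [[0.0, 1.0], [2.0, 3.0], [4.0, 5.0]]
out = transformed_lookup(base, lambda v: [2.0 * x + 1.0 for x in v], indices=[2, 0])
print(out)  # [[9.0, 11.0], [1.0, 3.0]]
```

Here the base decides how many representations exist, while the transformation decides what each output looks like, which is why the class checks `max_id` against the base and `shape` against the transformation's output.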
In the following example, we create representations obtained from a trainable transformation of fixed random-walk positional encoding features. We first load the dataset, here Nations:
>>> from pykeen.datasets import get_dataset
>>> dataset = get_dataset(dataset="nations")
Next, we create a random-walk positional encoding of dimension 32:
>>> from pykeen.nn import init
>>> dim = 32
>>> initializer = init.RandomWalkPositionalEncoding(triples_factory=dataset.training, dim=dim+1)

We pass dim+1 to the RWPE initializer because, by default, it drops the first dimension, which consists only of zeros. That is, in the default setup, dim=33 yields a 32-dimensional vector.
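To see why one extra dimension is requested, the following pure-Python sketch computes a random-walk positional encoding under simplifying assumptions: edges are treated as undirected, the random-walk matrix is the row-normalized adjacency matrix T = D^-1 A, and the feature of node i at step k is the return probability (T^k)_ii. Without self-loops, (T^1)_ii is always 0, so dropping the first power turns dim requested powers into dim - 1 informative features. The function `rwpe` and its `drop_first` flag are illustrative only and are not PyKEEN's implementation.

```python
from typing import List, Sequence, Tuple

Matrix = List[List[float]]


def matmul(a: Matrix, b: Matrix) -> Matrix:
    """Multiply two square matrices given as nested lists."""
    n = len(a)
    return [[sum(a[i][k] * b[k][j] for k in range(n)) for j in range(n)] for i in range(n)]


def rwpe(num_nodes: int, edges: Sequence[Tuple[int, int]], dim: int, drop_first: bool = True) -> Matrix:
    """Return per-node return probabilities (T^k)_ii for k = 1..dim, optionally dropping k = 1."""
    # Symmetric adjacency matrix (edges are treated as undirected).
    adj = [[0.0] * num_nodes for _ in range(num_nodes)]
    for h, t in edges:
        adj[h][t] = adj[t][h] = 1.0
    # Row-normalize to obtain the random-walk transition matrix T = D^-1 A.
    trans = []
    for row in adj:
        degree = sum(row)
        trans.append([x / degree if degree else 0.0 for x in row])
    # Collect the diagonal of T^1, T^2, ..., T^dim.
    features: Matrix = [[] for _ in range(num_nodes)]
    power = trans
    for _ in range(dim):
        for i in range(num_nodes):
            features[i].append(power[i][i])
        power = matmul(power, trans)
    # Without self-loops, the diagonal of T^1 is all zeros and carries no information.
    if drop_first:
        features = [f[1:] for f in features]
    return features


# Triangle graph: requesting dim=3 powers yields 2 features per node after dropping T^1.
enc = rwpe(num_nodes=3, edges=[(0, 1), (1, 2), (0, 2)], dim=3)
print(enc[0])  # [0.5, 0.25]
```

By the same logic, requesting 33 powers in the example above leaves 32 features, matching the base representation's shape=(dim,).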
For the transformation, we use a simple 2-layer MLP:
>>> from torch import nn
>>> hidden = 64
>>> mlp = nn.Sequential(
...     nn.Linear(in_features=dim, out_features=hidden),
...     nn.ReLU(),
...     nn.Linear(in_features=hidden, out_features=dim),
... )
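The MLP maps each dim-dimensional base vector to a hidden layer of size hidden and back to dim, so its input width must equal the base representation's dimension, and its output width determines the transformed shape. A minimal pure-Python sketch of the same Linear, ReLU, Linear forward pass, assuming small hand-picked weights in place of PyTorch's random initialization (the helpers `linear`, `relu`, and `mlp` are hypothetical):

```python
from typing import List

Vector = List[float]
Matrix = List[List[float]]


def linear(x: Vector, weight: Matrix, bias: Vector) -> Vector:
    """Compute weight @ x + bias, with weight of shape (out_features, in_features)."""
    return [sum(w * xi for w, xi in zip(row, x)) + b for row, b in zip(weight, bias)]


def relu(x: Vector) -> Vector:
    """Element-wise max(0, x)."""
    return [max(0.0, xi) for xi in x]


def mlp(x: Vector, w1: Matrix, b1: Vector, w2: Matrix, b2: Vector) -> Vector:
    """Two-layer MLP: Linear(dim -> hidden), ReLU, Linear(hidden -> dim)."""
    return linear(relu(linear(x, w1, b1)), w2, b2)


# dim = 2, hidden = 3: w1 has shape (3, 2) and w2 has shape (2, 3).
w1 = [[1.0, 0.0], [0.0, 1.0], [-1.0, -1.0]]
b1 = [0.0, 0.0, 0.0]
w2 = [[1.0, 1.0, 1.0], [1.0, -1.0, 0.0]]
b2 = [0.0, 0.5]
y = mlp([2.0, 3.0], w1, b1, w2, b2)
print(y)  # [5.0, -0.5]
```

Because the output has the same length as the input, the transformed representations keep the shape (dim,) of the base.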
Finally, the transformed representation is created as follows:
>>> from pykeen.nn import TransformedRepresentation
>>> r = TransformedRepresentation(
...     transformation=mlp,
...     base_kwargs=dict(max_id=dataset.num_entities, shape=(dim,), initializer=initializer, trainable=False),
... )
Initialize the representation.
- Parameters:
transformation (Module) – the transformation
max_id (Optional[int]) – the number of representations. If provided, must match the base max id
shape (Union[int, Sequence[int], None]) – the individual representations’ shape. If provided, must match the output shape of the transformation
base (Union[str, Representation, Type[Representation], None]) – the base representation, or a hint thereof, cf. representation_resolver
base_kwargs (Optional[Mapping[str, Any]]) – keyword-based parameters used to instantiate the base representation
kwargs – additional keyword-based parameters passed to Representation.__init__().
- Raises:
ValueError – if the max_id or shape does not match
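The documented consistency checks can be sketched as follows. This is a hypothetical illustration of the described behavior (a ValueError when a given max_id or shape disagrees with the base or the transformation), not PyKEEN's source; the function `check_consistency`, the helper `_as_shape`, and the fallback to the base's max_id and the transformation's output shape when nothing is given are assumptions of this sketch.

```python
from typing import Optional, Sequence, Tuple, Union


def _as_shape(shape: Union[int, Sequence[int]]) -> Tuple[int, ...]:
    """Normalize an int or a sequence of ints to a tuple, mirroring the accepted `shape` types."""
    return (shape,) if isinstance(shape, int) else tuple(shape)


def check_consistency(
    base_max_id: int,
    transformed_shape: Tuple[int, ...],
    max_id: Optional[int] = None,
    shape: Union[int, Sequence[int], None] = None,
) -> Tuple[int, Tuple[int, ...]]:
    """Return the resolved (max_id, shape), raising ValueError on a mismatch (hypothetical)."""
    if max_id is not None and max_id != base_max_id:
        raise ValueError(f"max_id={max_id} does not match the base max_id={base_max_id}")
    if shape is not None and _as_shape(shape) != tuple(transformed_shape):
        raise ValueError(f"shape={shape} does not match the transformation's output shape {transformed_shape}")
    # Assumption: unspecified values fall back to the base's max_id and the transformation's output shape.
    return base_max_id, tuple(transformed_shape)


print(check_consistency(base_max_id=10, transformed_shape=(32,), shape=32))  # (10, (32,))
```

Passing `max_id=5` against a base with 10 representations, or `shape=16` against a 32-dimensional output, raises ValueError in this sketch.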