MLPTransformedRepresentation

class MLPTransformedRepresentation(*, base: str | Representation | type[Representation] | None = None, base_kwargs: Mapping[str, Any] | None = None, output_dim: int | None = None, mlp_dropout: float = 0.1, ratio: int | float = 2, **kwargs)

Bases: TransformedRepresentation

A representation that transforms the output of a base representation with a learnable two-layer MLP.

In the following example, we show how to construct a feature-enriched embedding.

"""Demonstrate applying a learnable MLP transformation on top of a representation."""

import torch

from pykeen.models import ERModel
from pykeen.nn import Embedding, MLPTransformedRepresentation
from pykeen.triples.generation import generate_triples_factory
from pykeen.typing import FloatTensor

n_entities = 15
n_relations = 3
n_triples = 100
features_dim = 256
target_dim = 32

# mock some triples
triples_factory = generate_triples_factory(n_entities, n_relations, n_triples)

# mock some feature tensor
features = torch.rand(n_entities, features_dim)

base_representation = Embedding.from_pretrained(features)

# this embedding is learned on top of the base representation
entity_representation = MLPTransformedRepresentation(base=base_representation, output_dim=target_dim)

# we're going to use DistMult as the interaction, so
# we need a relation representation of the same size
relation_representation = Embedding(max_id=n_relations, embedding_dim=target_dim)

model = ERModel[FloatTensor, FloatTensor, FloatTensor](
    triples_factory=triples_factory,
    interaction="DistMult",
    entity_representations=entity_representation,
    relation_representations=relation_representation,
)

Initialize the representation.

Parameters:
  • base (HintOrType[Representation]) – the base representation, or a hint thereof, cf. representation_resolver

  • base_kwargs (OptionalKwargs) – keyword-based parameters used to instantiate the base representation

  • output_dim (int | None) – the output dimension; defaults to the input dimension

  • mlp_dropout (float) –

    the dropout probability applied to the hidden layer.

    Warning

    Do not confuse this with the optional keyword argument for the representation’s own dropout.

  • ratio (int | float) – the ratio of the output dimension to the hidden layer size.

  • kwargs – keyword arguments forwarded to the parent’s constructor
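
To make the roles of output_dim, mlp_dropout, and ratio more concrete, the following is a minimal plain-PyTorch sketch of a two-layer MLP with dropout on the hidden layer. It is not PyKEEN's implementation: the helper make_two_layer_mlp is hypothetical, the ReLU activation is an assumption, and sizing the hidden layer as int(ratio * output_dim) is only one possible reading of the ratio parameter.

```python
import torch
from torch import nn


def make_two_layer_mlp(input_dim, output_dim=None, ratio=2, mlp_dropout=0.1):
    """Hypothetical helper: build a two-layer MLP (not PyKEEN's own code)."""
    # output_dim falls back to the input dimension, as in the parameter list
    output_dim = input_dim if output_dim is None else output_dim
    # ASSUMPTION: the hidden layer is `ratio` times as wide as the output
    hidden_dim = int(ratio * output_dim)
    return nn.Sequential(
        nn.Linear(input_dim, hidden_dim),
        nn.ReLU(),  # ASSUMPTION: the activation is not specified above
        nn.Dropout(mlp_dropout),  # dropout on the hidden layer
        nn.Linear(hidden_dim, output_dim),
    )


# same sizes as in the example above: 15 entities, 256 features, 32 outputs
features = torch.rand(15, 256)
mlp = make_two_layer_mlp(input_dim=256, output_dim=32)
transformed = mlp(features)
print(tuple(transformed.shape))
```

Under these assumptions, each 256-dimensional feature vector is mapped to a 32-dimensional output, which is why the relation representation in the example needs to have target_dim dimensions for DistMult.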