FeaturizedMessagePassingRepresentation

class FeaturizedMessagePassingRepresentation(triples_factory: CoreTriplesFactory, relation_representation: str | Representation | type[Representation] | None = None, relation_representation_kwargs: Mapping[str, Any] | None = None, relation_transformation: Module | None = None, **kwargs)

Bases: TypedMessagePassingRepresentation

A message-passing representation that uses edge features obtained from relation representations.

It (re-)uses a representation layer for relations to obtain edge features, which are then consumed by message-passing layers that support edge features, e.g., torch_geometric.nn.conv.GMMConv or torch_geometric.nn.conv.GATConv. Optionally, a (shared) transformation is applied to the edge features between layers.
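To make the data flow concrete, here is a toy, pure-Python sketch of one featurized message-passing step. It is not PyKEEN's code: the names `edge_features` and `featurized_layer` are hypothetical, the message for an edge (h, r, t) is assumed to be the element-wise sum of the source representation and the relation's feature vector, and messages are mean-aggregated at the target entity. Real layers such as GATConv use the edge features differently (inside their attention computation).

```python
from typing import List, Sequence, Tuple

Vector = List[float]


def edge_features(relation_repr: Sequence[Vector], edge_type: Sequence[int]) -> List[Vector]:
    # One feature vector per edge, looked up from the relation representation.
    return [list(relation_repr[r]) for r in edge_type]


def featurized_layer(
    x: Sequence[Vector],
    edge_index: Tuple[Sequence[int], Sequence[int]],
    edge_attr: Sequence[Vector],
) -> List[Vector]:
    sources, targets = edge_index
    dim = len(x[0])
    sums = [[0.0] * dim for _ in x]
    counts = [0] * len(x)
    for h, t, w in zip(sources, targets, edge_attr):
        # Hypothetical message function: source representation plus edge feature.
        for i in range(dim):
            sums[t][i] += x[h][i] + w[i]
        counts[t] += 1
    # Mean aggregation; entities without incoming edges keep their input vector.
    return [
        [s / c for s in row] if c else list(x_row)
        for row, c, x_row in zip(sums, counts, x)
    ]


x = [[1.0, 0.0], [0.0, 1.0], [0.0, 0.0]]
relations = [[0.5, 0.5], [-1.0, 0.0]]
edge_index = ([0, 1], [2, 2])  # edges 0 -> 2 and 1 -> 2
edge_type = [0, 1]
attr = edge_features(relations, edge_type)
out = featurized_layer(x, edge_index, attr)
print(out)  # [[1.0, 0.0], [0.0, 1.0], [0.25, 0.75]]
```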

The following example creates a two-layer GAT on top of the base representations:

"""Message passing using relation features."""

from pykeen.datasets import get_dataset
from pykeen.models.nbase import ERModel
from pykeen.nn.pyg import FeaturizedMessagePassingRepresentation
from pykeen.nn.representation import Embedding
from pykeen.pipeline import pipeline

embedding_dim = 64
dataset = get_dataset(dataset="nations")
# create embedding matrix for relation representations
relations = Embedding(max_id=dataset.num_relations, embedding_dim=embedding_dim)
entities = FeaturizedMessagePassingRepresentation(
    triples_factory=dataset.training,
    base_kwargs=dict(shape=embedding_dim),
    relation_representation=relations,  # re-use relation representation here
    layers=["gat"] * 2,  # two GAT layers; the layers_kwargs below are shared by both
    layers_kwargs=dict(
        in_channels=embedding_dim,
        out_channels=embedding_dim,
        edge_dim=embedding_dim,  # should match relation dim
    ),
)
result = pipeline(
    dataset=dataset,
    # compose a model with distmult interaction function
    model=ERModel(
        triples_factory=dataset.training,
        entity_representations=entities,
        relation_representations=relations,
        interaction="DistMult",
    ),
)
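The comment `edge_dim=embedding_dim,  # should match relation dim` encodes a shape constraint that is easy to break when changing dimensions. The hypothetical helper `check_layer_dims` below checks it together with the usual chaining rule between stacked layers. It assumes each layer maps `in_channels` to `out_channels` (true for GATConv with its default single head) and that no dimension-changing `relation_transformation` is used.

```python
from typing import Mapping, Sequence


def check_layer_dims(
    base_dim: int,
    relation_dim: int,
    layers_kwargs: Sequence[Mapping[str, int]],
) -> int:
    """Return the final output dimension, raising ValueError on a mismatch."""
    current = base_dim
    for i, kw in enumerate(layers_kwargs):
        # Each layer must accept the output dimension of the previous one.
        if kw["in_channels"] != current:
            raise ValueError(f"layer {i}: in_channels={kw['in_channels']} != {current}")
        # Edge features come from the relation representation (assumed unchanged).
        if kw.get("edge_dim") is not None and kw["edge_dim"] != relation_dim:
            raise ValueError(f"layer {i}: edge_dim={kw['edge_dim']} != {relation_dim}")
        current = kw["out_channels"]
    return current


embedding_dim = 64
gat_kwargs = dict(in_channels=embedding_dim, out_channels=embedding_dim, edge_dim=embedding_dim)
# Two layers sharing the same keyword arguments, mirroring the two-layer setup described above.
final_dim = check_layer_dims(embedding_dim, embedding_dim, [gat_kwargs, gat_kwargs])
print(final_dim)  # 64
```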

Initialize the representation.

Parameters:
  • triples_factory (CoreTriplesFactory) – The factory comprising the training triples used for message passing.

  • relation_representation (str | Representation | type[Representation] | None) – The base representations for relations, or a hint thereof.

  • relation_representation_kwargs (OptionalKwargs) – Additional keyword-based parameters passed to the base representations upon instantiation.

  • relation_transformation (nn.Module | None) – An optional transformation to apply to the relation representations after each message passing step. If None, do not modify the representations.

  • kwargs – Additional keyword-based parameters passed to pykeen.nn.pyg.TypedMessagePassingRepresentation, except for triples_factory, which is passed explicitly.
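The role of relation_transformation can be sketched as a loop over layers in which the same callable is applied to the edge features after every step, while `None` leaves them unchanged. This is a minimal pure-Python sketch under stated assumptions: `run_layers` and `toy_layer` are hypothetical, plain lists stand in for tensors, and the toy layer simply adds the sum of all edge features to every entity.

```python
from typing import Callable, List, Optional, Sequence

Vector = List[float]
Layer = Callable[[List[Vector], List[Vector]], List[Vector]]


def run_layers(
    x: List[Vector],
    edge_attr: List[Vector],
    layers: Sequence[Layer],
    relation_transformation: Optional[Callable[[Vector], Vector]] = None,
) -> List[Vector]:
    for layer in layers:
        x = layer(x, edge_attr)
        # The same (shared) transformation is applied after each step; None means identity.
        if relation_transformation is not None:
            edge_attr = [relation_transformation(w) for w in edge_attr]
    return x


def toy_layer(x: List[Vector], edge_attr: List[Vector]) -> List[Vector]:
    # Hypothetical layer: add the sum of all edge features to every entity.
    total = [sum(col) for col in zip(*edge_attr)]
    return [[xi + ti for xi, ti in zip(row, total)] for row in x]


def double(w: Vector) -> Vector:
    return [2.0 * v for v in w]


print(run_layers([[0.0]], [[1.0]], [toy_layer] * 2, double))  # [[3.0]]
print(run_layers([[0.0]], [[1.0]], [toy_layer] * 2))  # [[2.0]]
```

The transformation applied after the last layer feeds no further layer, so it does not affect the returned representations.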

Note

The parameter pair (relation_representation, relation_representation_kwargs) is passed to pykeen.nn.representation_resolver.

An explanation of resolvers and how to use them is given in https://class-resolver.readthedocs.io/en/latest/.
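A resolver turns a hint plus optional keyword arguments into an instance. The toy version below is a hypothetical stand-in, not class-resolver's implementation (which, among other things, normalizes names more thoroughly). In this sketch an existing instance is returned unchanged, mirroring how the example re-uses `relations`; a string or class is instantiated with the keyword arguments; and `None` falls back to a default class. `Representation`, `Embedding`, `REGISTRY`, and `make` here are local placeholders, not PyKEEN's classes.

```python
from typing import Any, Dict, Mapping, Optional, Type, Union


class Representation:
    """Placeholder base class (not pykeen.nn.Representation)."""

    def __init__(self, max_id: int = 1, shape: int = 1) -> None:
        self.max_id = max_id
        self.shape = shape


class Embedding(Representation):
    """Placeholder subclass (not pykeen.nn.Embedding)."""


REGISTRY: Dict[str, Type[Representation]] = {"embedding": Embedding}


def make(
    hint: Union[None, str, Representation, Type[Representation]],
    kwargs: Optional[Mapping[str, Any]] = None,
    default: Type[Representation] = Embedding,
) -> Representation:
    """Toy resolver: instances pass through, names and classes are instantiated."""
    if isinstance(hint, Representation):
        # An already-built instance is reused as-is; this toy ignores kwargs here.
        return hint
    if hint is None:
        cls = default
    elif isinstance(hint, str):
        cls = REGISTRY[hint.lower()]
    else:
        cls = hint
    return cls(**(kwargs or {}))


shared = Embedding(max_id=55, shape=64)
print(make(shared) is shared)  # True
print(make("Embedding", {"max_id": 7, "shape": 64}).max_id)  # 7
```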

Methods Summary

pass_messages(x, edge_index[, edge_mask])

Perform the message passing steps.

Methods Documentation

pass_messages(x: Tensor, edge_index: Tensor, edge_mask: Tensor | None = None) → Tensor

Perform the message passing steps.

Parameters:
  • x (Tensor) – shape: (n, d_in) the base entity representations

  • edge_index (Tensor) – shape: (2, num_selected_edges) the edge index (which may already be a selection of the full edge index)

  • edge_mask (Tensor | None) – shape: (num_edges,) an edge mask if message passing is restricted

Returns:

shape: (n, d_out) the enriched entity representations

Return type:

Tensor
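Because edge_index may already be a selection of the full edge index, the edge features have to be computed for the same subset of edges. A minimal sketch, assuming the relation type of every edge is stored in a sequence of length num_edges and edge_mask is a boolean mask over those edges; `selected_edge_features` is a hypothetical helper, and plain lists stand in for tensors (the two lists of `edge_index` play the role of its two rows).

```python
from typing import List, Optional, Sequence

Vector = List[float]


def selected_edge_features(
    relation_repr: Sequence[Vector],
    edge_type: Sequence[int],
    edge_mask: Optional[Sequence[bool]] = None,
) -> List[Vector]:
    """Edge features aligned with a (possibly restricted) edge index."""
    if edge_mask is not None:
        if len(edge_mask) != len(edge_type):
            raise ValueError("edge_mask must have shape (num_edges,)")
        # Keep only the relation types of the selected edges, in the same order.
        edge_type = [r for r, keep in zip(edge_type, edge_mask) if keep]
    return [list(relation_repr[r]) for r in edge_type]


# Full graph: three edges with relation types 0, 1, 0.
full_sources, full_targets = [0, 1, 2], [1, 2, 0]
edge_type = [0, 1, 0]
relation_repr = [[1.0, 0.0], [0.0, 1.0]]

# Restrict message passing to the last two edges.
edge_mask = [False, True, True]
edge_index = (
    [s for s, keep in zip(full_sources, edge_mask) if keep],
    [t for t, keep in zip(full_targets, edge_mask) if keep],
)  # two rows, i.e. shape (2, num_selected_edges)
edge_attr = selected_edge_features(relation_repr, edge_type, edge_mask)
print(edge_index)  # ([1, 2], [2, 0])
print(edge_attr)  # [[0.0, 1.0], [1.0, 0.0]]
```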