FeaturizedMessagePassingRepresentation

class FeaturizedMessagePassingRepresentation(triples_factory, relation_representation=None, relation_representation_kwargs=None, relation_transformation=None, **kwargs)

Bases: TypedMessagePassingRepresentation

A representation based on message passing which uses edge features obtained from relation representations.

It (re-)uses a representation layer for relations to obtain edge features, which are then consumed by message passing layers that support edge features, e.g., torch_geometric.nn.conv.GMMConv or torch_geometric.nn.conv.GATConv. Optionally, a (shared) transformation of the edge features can be applied between layers.
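To make the data flow concrete, here is a minimal pure-Python sketch of a single featurized message passing step. It assumes a toy layer (not GMMConv or GATConv) whose message from a source entity is its feature vector plus the feature vector of the edge's relation, summed at the target entity. The helper names `edge_features` and `toy_layer` are hypothetical; only the idea of looking up one edge feature per edge from the relation representations mirrors the description above.

```python
from typing import List

Vector = List[float]


def edge_features(relation_repr: List[Vector], edge_type: List[int]) -> List[Vector]:
    """Look up one feature vector per edge from the relation representations."""
    return [relation_repr[r] for r in edge_type]


def toy_layer(x: List[Vector], edge_index: List[List[int]], edge_attr: List[Vector]) -> List[Vector]:
    """Hypothetical edge-conditioned layer: sum (x[source] + edge_attr) at each target."""
    sources, targets = edge_index
    dim = len(x[0])
    out = [[0.0] * dim for _ in x]
    for s, t, e in zip(sources, targets, edge_attr):
        for k in range(dim):
            out[t][k] += x[s][k] + e[k]
    return out


# 3 entities with 2-dimensional features, 2 relations
x = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
relation_repr = [[0.5, 0.5], [-1.0, 2.0]]
# edge_index has shape (2, num_edges): first row sources, second row targets
edge_index = [[0, 1, 2], [1, 2, 0]]
edge_type = [0, 1, 0]

attr = edge_features(relation_repr, edge_type)
print(toy_layer(x, edge_index, attr))
```

Real edge-aware layers compute far richer messages (e.g., attention weights in GAT), but the lookup `relation_repr[edge_type]` is the part that turns relation representations into edge features.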

The following example creates a one-layer GAT on top of the base representations:

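from pykeen.datasets import get_dataset
from pykeen.nn.pyg import FeaturizedMessagePassingRepresentation

embedding_dim = 64
dataset = get_dataset(dataset="nations")
r = FeaturizedMessagePassingRepresentation(
    triples_factory=dataset.training,
    # base (entity) representations
    base_kwargs=dict(shape=embedding_dim),
    # relation representations, used as edge features
    relation_representation_kwargs=dict(
        shape=embedding_dim,
    ),
    # a single GAT layer, resolved to torch_geometric.nn.conv.GATConv
    layers="gat",
    layers_kwargs=dict(
        in_channels=embedding_dim,
        out_channels=embedding_dim,
        edge_dim=embedding_dim,  # must match the relation representation dimension
    ),
)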

Initialize the representation.

Parameters:
  • triples_factory (CoreTriplesFactory) – the factory comprising the training triples used for message passing

  • relation_representation (Union[str, Representation, Type[Representation], None]) – the base representations for relations, or a hint thereof

  • relation_representation_kwargs (Optional[Mapping[str, Any]]) – additional keyword-based parameters passed to the relation representation upon instantiation

  • relation_transformation (Optional[Module]) – an optional transformation to apply to the relation representations after each message passing step. If None, do not modify the representations.

  • kwargs – additional keyword-based parameters passed to TypedMessagePassingRepresentation.__init__(), except the triples_factory
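The interplay between the message passing layers and relation_transformation can be sketched as a loop: each layer consumes edge features looked up from the current relation representations, and the same (shared) transformation is then applied to the relation representations before the next layer; None leaves them unchanged. This is a simplified pure-Python sketch assuming layers are plain callables taking (x, edge_index, edge_attr); `run_layers`, `add_edge_sum` and `halve` are hypothetical names, not part of the PyKEEN API.

```python
from typing import Callable, List, Optional, Sequence

Vector = List[float]
Layer = Callable[[List[Vector], List[List[int]], List[Vector]], List[Vector]]


def run_layers(
    x: List[Vector],
    edge_index: List[List[int]],
    edge_type: List[int],
    relation_repr: List[Vector],
    layers: Sequence[Layer],
    relation_transformation: Optional[Callable[[Vector], Vector]] = None,
) -> List[Vector]:
    """Apply each layer with edge features taken from the current relation representations."""
    for layer in layers:
        # gather one edge feature vector per edge from its relation type
        edge_attr = [relation_repr[r] for r in edge_type]
        x = layer(x, edge_index, edge_attr)
        # the same transformation is shared by all layers; None means "keep as-is"
        if relation_transformation is not None:
            relation_repr = [relation_transformation(v) for v in relation_repr]
    return x


def add_edge_sum(x: List[Vector], edge_index: List[List[int]], edge_attr: List[Vector]) -> List[Vector]:
    """Hypothetical layer: each target keeps its features and adds its incoming edge features."""
    out = [list(v) for v in x]
    for t, e in zip(edge_index[1], edge_attr):
        out[t] = [a + b for a, b in zip(out[t], e)]
    return out


def halve(v: Vector) -> Vector:
    """Hypothetical shared relation transformation."""
    return [0.5 * a for a in v]


x = [[0.0], [0.0]]
edge_index = [[0], [1]]  # a single edge 0 -> 1
edge_type = [0]
relation_repr = [[4.0]]
layers = [add_edge_sum, add_edge_sum]
print(run_layers(x, edge_index, edge_type, relation_repr, layers))
print(run_layers(x, edge_index, edge_type, relation_repr, layers, relation_transformation=halve))
```

Without a transformation, both layers see the edge feature 4.0; with halve, the second layer sees 2.0, which shows how the transformation changes the edge features between layers.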

Methods Summary

pass_messages(x, edge_index[, edge_mask]) – Perform the message passing steps.

Methods Documentation

pass_messages(x, edge_index, edge_mask=None)

Perform the message passing steps.

Parameters:
  • x (FloatTensor) – shape: (n, d_in) the base entity representations

  • edge_index (LongTensor) – shape: (2, num_selected_edges) the edge index (which may already be a selection of the full edge index)

  • edge_mask (Optional[BoolTensor]) – shape: (num_edges,) an edge mask if message passing is restricted

Return type:

FloatTensor

Returns:

shape: (n, d_out) the enriched entity representations
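The relation between edge_mask and a selected edge index can be illustrated with a small pure-Python sketch. Following the PyTorch Geometric convention, an edge index stores a row of source indices and a row of target indices, i.e., it has shape (2, num_edges); a boolean mask of length num_edges keeps a subset of its columns. In this sketch, the per-edge relation types are filtered with the same mask so that edge features stay aligned with the selected edges. The helper `select_edges` is hypothetical.

```python
from typing import List, Tuple


def select_edges(
    edge_index: List[List[int]], edge_type: List[int], edge_mask: List[bool]
) -> Tuple[List[List[int]], List[int]]:
    """Restrict a (2, num_edges) edge index and its edge types to the masked edges."""
    if not (len(edge_index[0]) == len(edge_index[1]) == len(edge_type) == len(edge_mask)):
        raise ValueError("edge_index, edge_type and edge_mask must cover the same edges")
    keep = [i for i, m in enumerate(edge_mask) if m]
    # keep the same columns in both rows (sources and targets)
    selected_index = [[row[i] for i in keep] for row in edge_index]
    selected_type = [edge_type[i] for i in keep]
    return selected_index, selected_type


edge_index = [[0, 1, 2, 2], [1, 2, 0, 1]]  # shape (2, num_edges=4)
edge_type = [0, 1, 0, 1]
edge_mask = [True, False, True, False]  # shape (num_edges,)
sub_index, sub_type = select_edges(edge_index, edge_type, edge_mask)
print(sub_index, sub_type)
```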