SimpleMessagePassingRepresentation

class SimpleMessagePassingRepresentation(triples_factory, layers, layers_kwargs=None, base=None, base_kwargs=None, max_id=None, shape=None, activations=None, activations_kwargs=None, restrict_k_hop=False, **kwargs)

Bases: MessagePassingRepresentation

A message passing representation that does not make use of relation types.

Because it uses only the connectivity information and ignores relation types, this module can use message passing layers defined on uni-relational graphs. These make up the majority of the layers available in the PyTorch Geometric library.
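Conceptually, ignoring relation types means each (head, relation, tail) triple is reduced to a directed edge from head to tail. The pure-Python sketch below illustrates that reduction with a hypothetical helper `to_edge_index`. It is not PyKEEN's code. It assumes integer ID triples and produces the `[sources, targets]` layout (shape `(2, num_edges)`) that PyTorch Geometric layers expect.

```python
# Minimal sketch (not PyKEEN's code): turn (head, relation, tail) ID triples
# into a uni-relational edge index by dropping the relation column.
from typing import List, Tuple

Triple = Tuple[int, int, int]


def to_edge_index(triples: List[Triple]) -> List[List[int]]:
    """Hypothetical helper: return [sources, targets], i.e. shape (2, num_edges)."""
    sources = [h for h, _r, _t in triples]
    targets = [t for _h, _r, t in triples]
    return [sources, targets]


# The first two triples differ only in their relation; after the reduction
# they become parallel edges 0 -> 1 and the relation type is no longer visible.
triples = [(0, 0, 1), (0, 5, 1), (1, 2, 2)]
edge_index = to_edge_index(triples)
print(edge_index)  # [[0, 0, 1], [1, 1, 2]]
```

Keeping the parallel edges, rather than deduplicating them, is a design choice of this sketch. Whether duplicates are kept in practice depends on the implementation.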

Here, we create a two-layer torch_geometric.nn.conv.GCNConv on top of a pykeen.nn.representation.Embedding:

from pykeen.datasets import get_dataset
from pykeen.nn.pyg import SimpleMessagePassingRepresentation

embedding_dim = 64
dataset = get_dataset(dataset="nations")
r = SimpleMessagePassingRepresentation(
    triples_factory=dataset.training,
    base_kwargs=dict(shape=embedding_dim),
    layers=["gcn"] * 2,
    layers_kwargs=dict(in_channels=embedding_dim, out_channels=embedding_dim),
)
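To see what the two stacked GCN layers do with the base embeddings, the pure-Python sketch below applies the GCN propagation rule of Kipf and Welling. That rule is what torch_geometric.nn.conv.GCNConv implements with its default settings: a self-loop is added to every node and each message is scaled by 1 / sqrt(deg(source) * deg(target)). This is a simplified illustration under stated assumptions. The learnable weight matrix and bias are omitted (identity weights), no activation is applied between the layers, and the helper `gcn_propagate` and the toy graph are hypothetical.

```python
# Sketch of GCN propagation in pure Python. Assumptions: no learnable weights
# or bias, no activation between layers, edges directed source -> target.
import math
from typing import List

Matrix = List[List[float]]


def gcn_propagate(x: Matrix, edge_index: List[List[int]]) -> Matrix:
    """Hypothetical helper: one GCN propagation step without weights or bias."""
    n = len(x)
    # Add a self-loop to every node so it retains part of its own features.
    sources = list(edge_index[0]) + list(range(n))
    targets = list(edge_index[1]) + list(range(n))
    # In-degree of every node, counting its self-loop.
    deg = [0] * n
    for t in targets:
        deg[t] += 1
    out = [[0.0] * len(x[0]) for _ in range(n)]
    for s, t in zip(sources, targets):
        # Symmetric normalization 1 / sqrt(deg(s) * deg(t)).
        norm = 1.0 / math.sqrt(deg[s] * deg[t])
        for k, value in enumerate(x[s]):
            out[t][k] += norm * value
    return out


# Undirected edge 0 - 1 stored in both directions; node 2 is isolated.
edge_index = [[0, 1], [1, 0]]
x = [[1.0, 0.0], [0.0, 1.0], [2.0, 2.0]]  # stand-in for the base Embedding
h1 = gcn_propagate(x, edge_index)   # first layer
h2 = gcn_propagate(h1, edge_index)  # second layer
print(h1)  # [[0.5, 0.5], [0.5, 0.5], [2.0, 2.0]]
```

After the first layer, nodes 0 and 1 hold the normalized mix of their own and their neighbour's features. The isolated node 2 only sees its self-loop and keeps its representation. In this small example the second layer changes nothing further.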

Initialize the representation.

Parameters:
Raises:
  • ImportError – if PyTorch Geometric is not installed

  • ValueError – if the number of activations and message passing layers do not match (after input normalization)
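The ValueError condition refers to matching activations to layers "after input normalization". The sketch below shows one plausible form of that normalization using a hypothetical helper `normalize_activations`, with activations given as string hints. It assumes that a single hint (or None) is broadcast to every layer and that a list must have exactly one entry per layer. PyKEEN's actual normalization rules may differ.

```python
# Hypothetical sketch of activation normalization; not PyKEEN's code.
from typing import List, Optional, Sequence, Union

ActivationHint = Optional[str]


def normalize_activations(
    activations: Union[ActivationHint, Sequence[ActivationHint]],
    num_layers: int,
) -> List[ActivationHint]:
    """Return exactly one activation hint per message passing layer."""
    # Upgrade a single hint (including None) to a one-element list.
    if activations is None or isinstance(activations, str):
        activations = [activations]
    hints = list(activations)
    # Assumption: a single hint is shared by all layers.
    if len(hints) == 1:
        hints = hints * num_layers
    if len(hints) != num_layers:
        raise ValueError(
            f"got {len(hints)} activations for {num_layers} message passing layers"
        )
    return hints


print(normalize_activations("relu", num_layers=2))  # ['relu', 'relu']
```

Under these assumptions, passing three activations for the two GCN layers in the example above would raise the ValueError.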

Methods Summary

pass_messages(x, edge_index[, edge_mask])

Perform the message passing steps.

Methods Documentation

pass_messages(x, edge_index, edge_mask=None)

Perform the message passing steps.

Parameters:
  • x (FloatTensor) – shape: (n, d_in) the base entity representations

  • edge_index (LongTensor) – shape: (2, num_selected_edges) the edge index in COO format, i.e., the source and target entity IDs of each edge (which may already be a selection of the full edge index)

  • edge_mask (Optional[BoolTensor]) – shape: (num_edges,) a mask over the full edge list marking the selected edges, if message passing is restricted

Return type:

FloatTensor

Returns:

shape: (n, d_out) the enriched entity representations
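The sketch below is a pure-Python version of pass_messages, assuming that each message passing layer is followed by its activation. `sum_neighbors` and `relu` are hypothetical stand-ins for a PyTorch Geometric layer and a torch activation. Since this variant uses no relation types, there is no per-edge data to select. The sketch therefore assumes that edge_mask can be accepted and ignored, with any restriction already reflected in edge_index.

```python
# Hypothetical sketch of the layer/activation loop; not PyKEEN's code.
from typing import Callable, List, Optional, Sequence

Matrix = List[List[float]]
EdgeIndex = List[List[int]]  # [sources, targets], shape (2, num_selected_edges)
Layer = Callable[[Matrix, EdgeIndex], Matrix]
Activation = Callable[[float], float]


def pass_messages(
    x: Matrix,
    edge_index: EdgeIndex,
    layers: Sequence[Layer],
    activations: Sequence[Activation],
    edge_mask: Optional[Sequence[bool]] = None,
) -> Matrix:
    # Assumption: edge_mask is unused because there is no per-edge data
    # (such as relation types) that would need selecting.
    for layer, activation in zip(layers, activations):
        x = [[activation(v) for v in row] for row in layer(x, edge_index)]
    return x


def sum_neighbors(x: Matrix, edge_index: EdgeIndex) -> Matrix:
    # Toy layer: each node keeps its own features and adds its in-neighbours'.
    out = [list(row) for row in x]
    for s, t in zip(*edge_index):
        for k, value in enumerate(x[s]):
            out[t][k] += value
    return out


def relu(v: float) -> float:
    return max(0.0, v)


x = [[1.0], [-2.0], [4.0]]  # shape (n, d_in) with n = 3, d_in = 1
edge_index = [[0, 1, 2], [1, 2, 0]]  # directed cycle 0 -> 1 -> 2 -> 0
h = pass_messages(x, edge_index, [sum_neighbors] * 2, [relu] * 2)
print(h)  # [[7.0], [5.0], [2.0]]
```

If the caller has already restricted the graph, e.g. to edge_index = [[0, 2], [1, 0]], the same loop runs on the smaller edge set unchanged.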