TypedMessagePassingRepresentation

class TypedMessagePassingRepresentation(triples_factory, **kwargs)[source]

Bases: MessagePassingRepresentation

A message passing representation which uses categorical relation type information.

The message passing layers of this module internally handle the categorical relation type information via an edge_type input, e.g., torch_geometric.nn.conv.RGCNConv, or torch_geometric.nn.conv.RGATConv.
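To illustrate what handling relation types via an edge_type input means, the following is a minimal sketch in plain Python lists. It is not the implementation of RGCNConv or RGATConv, which add normalization, self-loops, bias terms, and optional weight decompositions. The function name typed_message_passing and the toy data are hypothetical; the only idea carried over is that each edge has an integer relation ID which selects a relation-specific weight matrix.

```python
def typed_message_passing(x, edge_index, edge_type, weights):
    """Sum relation-specific messages into each target node.

    x:          list of n input vectors, each of length d_in
    edge_index: two lists (sources, targets), one entry per edge
    edge_type:  one integer relation ID per edge
    weights:    one (d_in x d_out) matrix per relation (hypothetical toy weights)
    """
    d_in = len(x[0])
    d_out = len(weights[0][0])
    out = [[0.0] * d_out for _ in range(len(x))]
    for src, dst, rel in zip(edge_index[0], edge_index[1], edge_type):
        w = weights[rel]  # the relation ID picks the transformation
        for j in range(d_out):
            out[dst][j] += sum(x[src][i] * w[i][j] for i in range(d_in))
    return out


# Toy graph with 3 entities, 3 edges, and 2 relation types.
x = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
edge_index = [[0, 1, 2], [1, 2, 0]]
edge_type = [0, 1, 0]
weights = [
    [[1.0, 0.0], [0.0, 1.0]],  # relation 0: identity
    [[2.0, 0.0], [0.0, 2.0]],  # relation 1: scale by 2
]
enriched = typed_message_passing(x, edge_index, edge_type, weights)
```

The edge from entity 1 to entity 2 has relation type 1, so its message is scaled by 2, while the other two edges pass their source vectors through unchanged.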

The following example creates a one-layer RGCN using the basis decomposition:

from pykeen.datasets import get_dataset
from pykeen.nn.pyg import TypedMessagePassingRepresentation

embedding_dim = 64
dataset = get_dataset(dataset="nations")
r = TypedMessagePassingRepresentation(
    triples_factory=dataset.training,
    base_kwargs=dict(shape=embedding_dim),
    layers="rgcn",
    layers_kwargs=dict(
        in_channels=embedding_dim,
        out_channels=embedding_dim,
        num_bases=2,
        num_relations=dataset.num_relations,
    ),
)

Initialize the representation.

Parameters:
  • triples_factory (CoreTriplesFactory) – the factory comprising the training triples used for message passing

  • kwargs – additional keyword-based parameters passed to MessagePassingRepresentation.__init__()

Methods Summary

pass_messages(x, edge_index[, edge_mask])

Perform the message passing steps.

Methods Documentation

pass_messages(x, edge_index, edge_mask=None)[source]

Perform the message passing steps.

Parameters:
  • x (FloatTensor) – shape: (n, d_in) the base entity representations

  • edge_index (LongTensor) – shape: (2, num_selected_edges) the edge index (which may already be a selection of the full edge index)

  • edge_mask (Optional[BoolTensor]) – shape: (num_edges,) an edge mask if message passing is restricted

Return type:

FloatTensor

Returns:

shape: (n, d_out) the enriched entity representations
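The shapes above suggest that edge_index may already be restricted to the selected edges, while edge_mask still refers to the full edge set. One plausible reading, sketched below with plain Python lists, is that the mask serves to select the matching per-edge relation types so that they line up with the selected edge index. This is an assumption about the internals, not confirmed behavior, and the helper select and all variable names are hypothetical. The sketch also assumes the edge index is laid out as two rows (sources and targets), following the PyTorch Geometric convention.

```python
def select(values, mask):
    """Keep the entries of values whose mask entry is True."""
    return [v for v, keep in zip(values, mask) if keep]


# Full graph: 4 edges, each with a relation ID.
full_edge_index = [[0, 1, 2, 0], [1, 2, 0, 2]]  # shape: (2, num_edges)
full_edge_type = [0, 1, 0, 1]                   # shape: (num_edges,)
edge_mask = [True, False, True, True]           # shape: (num_edges,)

# The caller restricts the edge index ...
edge_index = [select(row, edge_mask) for row in full_edge_index]
# ... and the same mask aligns the relation types with it (assumed behavior).
edge_type = select(full_edge_type, edge_mask)
```

After the selection, both edge_index and edge_type describe the same three edges, which is the precondition for a layer that takes one relation ID per edge.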