TypedMessagePassingRepresentation

class TypedMessagePassingRepresentation(triples_factory: CoreTriplesFactory, **kwargs)

Bases: MessagePassingRepresentation

A message passing representation that uses categorical relation type information.

The message passing layers of this module, e.g., torch_geometric.nn.conv.RGCNConv or torch_geometric.nn.conv.RGATConv, handle the categorical relation type information internally via an edge_type input.
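To make the role of edge_type concrete, here is a minimal sketch in plain PyTorch of a single relation-typed message passing step in the spirit of R-GCN. It is not the RGCNConv implementation. It assumes one weight matrix per relation, sum aggregation of incoming messages (real R-GCN layers typically normalize, e.g. average, the messages) and an additive self-loop transform. The helper names triples_to_graph and typed_message_passing are hypothetical, and deriving edge_index = (h, t) and edge_type = r from ID-based triples (h, r, t) is an assumption made for illustration.

```python
import torch


def triples_to_graph(mapped_triples: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    """Hypothetical helper: split ID-based triples (E, 3) into edge_index (2, E) and edge_type (E,)."""
    # heads act as message sources, tails as message targets
    edge_index = mapped_triples[:, [0, 2]].t()
    # the relation ID becomes the categorical edge type
    edge_type = mapped_triples[:, 1]
    return edge_index, edge_type


def typed_message_passing(
    x: torch.Tensor,            # (n, d_in) entity representations
    edge_index: torch.Tensor,   # (2, E) source / target entity IDs
    edge_type: torch.Tensor,    # (E,) relation ID of each edge
    weight: torch.Tensor,       # (num_relations, d_in, d_out) one matrix per relation
    self_weight: torch.Tensor,  # (d_in, d_out) self-loop transform
) -> torch.Tensor:
    source, target = edge_index
    # transform each source representation with the matrix of its edge's relation
    messages = torch.einsum("ei,eio->eo", x[source], weight[edge_type])
    # sum the incoming messages per target entity
    out = torch.zeros(x.shape[0], weight.shape[-1], dtype=x.dtype)
    out.index_add_(0, target, messages)
    # add the self-loop contribution
    return out + x @ self_weight


# toy example: three triples over three entities and two relations
triples = torch.tensor([[0, 0, 1], [1, 1, 2], [2, 0, 0]])
edge_index, edge_type = triples_to_graph(triples)
x = torch.eye(3)
weight = torch.stack([torch.eye(3), 2 * torch.eye(3)])  # relation 0: identity, relation 1: doubling
self_weight = torch.zeros(3, 3)
out = typed_message_passing(x, edge_index, edge_type, weight, self_weight)
```

Because the weight matrix is looked up per edge via weight[edge_type], two edges between the same kinds of entities can transform their messages differently depending only on their relation type.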

The following example creates a one-layer RGCN using the basis decomposition:

from pykeen.datasets import get_dataset
from pykeen.nn.pyg import TypedMessagePassingRepresentation

embedding_dim = 64
dataset = get_dataset(dataset="nations")
r = TypedMessagePassingRepresentation(
    triples_factory=dataset.training,
    base_kwargs=dict(shape=embedding_dim),
    layers="rgcn",
    layers_kwargs=dict(
        in_channels=embedding_dim,
        out_channels=embedding_dim,
        num_bases=2,
        num_relations=dataset.num_relations,
    ),
)
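The num_bases=2 argument refers to the basis decomposition of R-GCN, in which every relation-specific weight matrix is a relation-specific linear combination of a few shared basis matrices, W_r = sum_b a_rb V_b. The following is a minimal sketch of only this weight construction in plain PyTorch; the tensor names bases and coefficients and the small sizes are illustrative assumptions and do not correspond to parameter names inside RGCNConv.

```python
import torch

torch.manual_seed(0)
num_relations, num_bases, d_in, d_out = 4, 2, 3, 3

# shared basis matrices V_b, shape: (num_bases, d_in, d_out)
bases = torch.randn(num_bases, d_in, d_out)
# relation-specific coefficients a_rb, shape: (num_relations, num_bases)
coefficients = torch.randn(num_relations, num_bases)

# W_r = sum_b a_rb * V_b, shape: (num_relations, d_in, d_out)
weight = torch.einsum("rb,bio->rio", coefficients, bases)

# parameters stored with the decomposition vs. one full matrix per relation
num_params_basis = bases.numel() + coefficients.numel()
num_params_full = num_relations * d_in * d_out
```

Only the num_bases coefficients per relation are relation-specific, so the saving over one full matrix per relation grows with the number of relations.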

Initialize the representation.

Parameters:
  • triples_factory (CoreTriplesFactory) – the factory comprising the training triples used for message passing

  • kwargs – additional keyword-based parameters passed to MessagePassingRepresentation.__init__()

Methods Summary

pass_messages(x, edge_index[, edge_mask])

Perform the message passing steps.

Methods Documentation

pass_messages(x: Tensor, edge_index: Tensor, edge_mask: Tensor | None = None) → Tensor

Perform the message passing steps.

Parameters:
  • x (Tensor) – shape: (n, d_in); the base entity representations

  • edge_index (Tensor) – shape: (2, num_selected_edges); the edge index (which may already be a selection of the full edge index)

  • edge_mask (Tensor | None) – shape: (num_edges,); an edge mask selecting the edges contained in edge_index, if message passing is restricted

Returns:

shape: (n, d_out); the enriched entity representations

Return type:

Tensor
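The parameters above suggest how the mask and the relation types interact: when edge_index is already restricted to a subset of edges, the boolean edge_mask over all num_edges edges selects the matching relation types, so that edge_type stays aligned with the columns of edge_index before the typed layers run. The following is a minimal sketch under that assumption, not the library's code; pass_messages_sketch, layers, activations and toy_layer are hypothetical names, and each layer is assumed to accept (x, edge_index, edge_type) in the way RGCNConv does.

```python
from typing import Callable, Optional, Sequence

import torch

Layer = Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor]


def pass_messages_sketch(
    x: torch.Tensor,           # (n, d_in) base entity representations
    edge_index: torch.Tensor,  # (2, num_selected_edges)
    edge_type: torch.Tensor,   # (num_edges,) relation IDs of ALL edges
    layers: Sequence[Layer],
    activations: Sequence[Callable[[torch.Tensor], torch.Tensor]],
    edge_mask: Optional[torch.Tensor] = None,  # (num_edges,) boolean mask
) -> torch.Tensor:
    if edge_mask is not None:
        # keep only the relation types of the selected edges
        edge_type = edge_type[edge_mask]
    # edge_type must now have one entry per column of edge_index
    assert edge_type.shape[0] == edge_index.shape[1]
    for layer, activation in zip(layers, activations):
        x = activation(layer(x, edge_index, edge_type))
    return x


# toy graph with three edges: (0 -r0-> 1), (1 -r1-> 2), (2 -r0-> 0)
edge_index_full = torch.tensor([[0, 1, 2], [1, 2, 0]])
edge_type_full = torch.tensor([0, 1, 0])
edge_mask = torch.tensor([True, False, True])
edge_index = edge_index_full[:, edge_mask]  # keep edges 0 and 2


def toy_layer(x, edge_index, edge_type):
    # adds (relation ID + 1) of every incoming edge to the target's single feature
    messages = (edge_type + 1).to(x.dtype).unsqueeze(-1)
    return x.index_add(0, edge_index[1], messages)


identity = torch.nn.Identity()
x0 = torch.zeros(3, 1)
out_masked = pass_messages_sketch(x0, edge_index, edge_type_full, [toy_layer], [identity], edge_mask=edge_mask)
out_full = pass_messages_sketch(x0, edge_index_full, edge_type_full, [toy_layer], [identity])
```

Without the mask, the three relation types would not line up with the two remaining columns of the restricted edge_index, which is why the selection happens before any layer is applied.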