TypedMessagePassingRepresentation

class TypedMessagePassingRepresentation(triples_factory: CoreTriplesFactory, **kwargs)

Bases: MessagePassingRepresentation

A message passing representation which uses categorical relation type information.

The message passing layers of this module internally handle the categorical relation type information via an edge_type input, e.g., torch_geometric.nn.conv.RGCNConv, or torch_geometric.nn.conv.RGATConv.
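To illustrate what handling relation types via an edge_type input means, here is a minimal sketch in plain Python. It is a toy illustration, not the torch_geometric implementation: the function name typed_message_passing and the list-based data layout are hypothetical, and normalization, self-connections, and basis decomposition are omitted for brevity. Each edge carries a relation id, and that id selects the weight matrix applied to the message.

```python
def typed_message_passing(x, edge_index, edge_type, weights):
    """Sum-aggregate relation-specific messages (toy sketch, hypothetical helper).

    x: list of n feature vectors, each of length d_in
    edge_index: pair (sources, targets), each a list of node ids
    edge_type: one relation id per edge
    weights: one d_in x d_out matrix (nested lists) per relation id
    """
    sources, targets = edge_index
    d_out = len(weights[0][0])
    out = [[0.0] * d_out for _ in x]
    for src, dst, rel in zip(sources, targets, edge_type):
        w = weights[rel]  # the edge's relation type picks the weight matrix
        # message = x[src] @ w
        message = [
            sum(x[src][i] * w[i][j] for i in range(len(w))) for j in range(d_out)
        ]
        # accumulate the message at the target node
        out[dst] = [a + b for a, b in zip(out[dst], message)]
    return out


x = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
edge_index = ([0, 1, 2], [2, 2, 0])  # edges 0->2, 1->2, 2->0
edge_type = [0, 1, 1]
weights = [
    [[1.0, 0.0], [0.0, 1.0]],  # relation 0: identity
    [[2.0, 0.0], [0.0, 2.0]],  # relation 1: scale by two
]
out = typed_message_passing(x, edge_index, edge_type, weights)
print(out)
```

Node 2 receives an unscaled message from node 0 (relation 0) and a doubled message from node 1 (relation 1), so the same neighbor features contribute differently depending on the edge's relation type.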

The following example creates a one-layer RGCN using the basis decomposition:

"""Example for message passing with type information.

Here, we use a one-layer RGCN using the basis decomposition.
"""

from pykeen.datasets import get_dataset
from pykeen.models import ERModel
from pykeen.nn.pyg import TypedMessagePassingRepresentation
from pykeen.pipeline import pipeline

embedding_dim = 64
dataset = get_dataset(dataset="nations")
entities = TypedMessagePassingRepresentation(
    triples_factory=dataset.training,
    base_kwargs=dict(shape=embedding_dim),
    layers="rgcn",
    layers_kwargs=dict(
        in_channels=embedding_dim,
        out_channels=embedding_dim,
        num_bases=2,
        num_relations=dataset.num_relations,
    ),
)
result = pipeline(
    dataset=dataset,
    # compose a model with distmult interaction function
    model=ERModel(
        triples_factory=dataset.training,
        entity_representations=entities,
        relation_representations_kwargs=dict(embedding_dim=embedding_dim),  # use embedding with same dimension
        interaction="DistMult",
    ),
)

Initialize the representation.

Parameters:

Methods Summary

pass_messages(x, edge_index[, edge_mask])

Perform the message passing steps.

Methods Documentation

pass_messages(x: Tensor, edge_index: Tensor, edge_mask: Tensor | None = None) → Tensor

Perform the message passing steps.

Parameters:
  • x (Tensor) – shape: (n, d_in) the base entity representations

  • edge_index (Tensor) – shape: (2, num_selected_edges) the edge index (which may already be a selection of the full edge index)

  • edge_mask (Tensor | None) – shape: (num_edges,) an edge mask if message passing is restricted

Returns:

shape: (n, d_out) the enriched entity representations

Return type:

Tensor
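The relationship between edge_index and edge_mask can be sketched as follows. This is a plain-Python illustration under stated assumptions, not the library's code: it assumes the module keeps one relation type per edge of the full graph, and the helper name select_edge_types is hypothetical. Because edge_index may already be filtered while edge_mask refers to the full edge set, the same mask has to be applied to the stored edge types so that both stay aligned.

```python
def select_edge_types(edge_type, edge_mask=None):
    """Return relation types aligned with an already filtered edge index.

    Hypothetical helper: edge_type has one entry per edge of the full graph,
    edge_mask (if given) has one boolean per edge of the full graph.
    """
    if edge_mask is None:
        # no restriction: every edge and its type is used
        return list(edge_type)
    return [t for t, keep in zip(edge_type, edge_mask) if keep]


# full graph with four typed edges: 0->1, 1->2, 2->0, 0->2
full_edge_index = ([0, 1, 2, 0], [1, 2, 0, 2])
full_edge_type = [0, 1, 1, 0]
edge_mask = [True, False, True, True]  # drop the second edge

# the caller is assumed to have filtered the edge index with the same mask
selected_edge_index = tuple(
    [v for v, keep in zip(row, edge_mask) if keep] for row in full_edge_index
)
selected_edge_type = select_edge_types(full_edge_type, edge_mask)

# both now describe the same three edges
print(selected_edge_index, selected_edge_type)
```

With the types selected this way, each remaining edge in the filtered edge index still has exactly one relation id, which is what a typed layer expects as its edge_type input.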