PyG Message Passing

PyTorch Geometric-based representation modules.

These modules provide entity representations that are linked to the representations of their graph neighbors. CompGCN and R-GCN provide similar representations; this module, however, offers generic building blocks that combine many of the numerous message passing layers from PyTorch Geometric with base representations. A summary of the available message passing layers can be found at torch_geometric.nn.conv.
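To make the idea concrete, here is a minimal sketch in plain PyTorch (no PyTorch Geometric) of relation-agnostic message passing on top of a base embedding. The layer class `MeanMessagePassing` and the toy triples are hypothetical. The layer averages neighbor and self features and then applies a linear map. This is a simplification in the spirit of GCN, not a reimplementation of PyG's GCNConv, which uses symmetric degree normalization.

```python
import torch
from torch import nn


class MeanMessagePassing(nn.Module):
    """Hypothetical relation-agnostic layer: mean of neighbor and self features, then a linear map."""

    def __init__(self, in_channels: int, out_channels: int):
        super().__init__()
        self.linear = nn.Linear(in_channels, out_channels)

    def forward(self, x: torch.Tensor, edge_index: torch.Tensor) -> torch.Tensor:
        num_nodes = x.shape[0]
        # add self-loops so that every node also keeps its own features
        loops = torch.arange(num_nodes, device=x.device)
        source = torch.cat([edge_index[0], loops])
        target = torch.cat([edge_index[1], loops])
        # sum the incoming messages (the source node's features) per target node
        summed = torch.zeros_like(x).index_add_(0, target, x[source])
        # count incoming messages per node (at least 1 thanks to the self-loops)
        ones = torch.ones_like(target, dtype=x.dtype)
        count = torch.zeros(num_nodes, dtype=x.dtype, device=x.device).index_add_(0, target, ones)
        return self.linear(summed / count.unsqueeze(-1))


# toy (head, relation, tail) triples; the relation column is ignored here
triples = torch.tensor([[0, 0, 1], [1, 1, 2], [2, 0, 0], [3, 1, 2]])
edge_index = triples[:, [0, 2]].t()  # shape (2, num_edges): row 0 = heads, row 1 = tails
num_entities, dim = 4, 8

base = nn.Embedding(num_entities, dim)  # base representation
layers = nn.ModuleList([MeanMessagePassing(dim, dim) for _ in range(2)])

x = base.weight
for i, layer in enumerate(layers):
    x = layer(x, edge_index)
    if i < len(layers) - 1:
        x = torch.relu(x)  # nonlinearity between layers (an assumption of this sketch)
print(x.shape)  # torch.Size([4, 8])
```

Each entity's final vector now depends on its two-hop neighborhood in the triples, regardless of which relation connects the entities.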

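The three classes differ in how they make use of the relation type information:

- SimpleMessagePassingRepresentation only uses the connectivity of the training triples and ignores the relation type, e.g., torch_geometric.nn.conv.GCNConv.
- TypedMessagePassingRepresentation is for message passing layers that handle categorical relation type information internally via an edge_type input, e.g., torch_geometric.nn.conv.RGCNConv.
- FeaturizedMessagePassingRepresentation is for message passing layers that accept edge features via an edge_attr input; the edge features are obtained from relation representations.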

We can also easily utilize these representations with pykeen.models.ERModel. Here, we showcase how to combine static label-based entity features with a trainable GCN encoder for entity representations, with learned embeddings for relation representations and a DistMult interaction function.

from pykeen.datasets import get_dataset
from pykeen.models import ERModel
from pykeen.nn.init import LabelBasedInitializer
from pykeen.pipeline import pipeline

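# load the Nations dataset, adding inverse triples
dataset = get_dataset(dataset="nations", dataset_kwargs=dict(create_inverse_triples=True))
# static entity features derived from the entity labels
entity_initializer = LabelBasedInitializer.from_triples_factory(
    triples_factory=dataset.training,
    for_entities=True,
)
# use the feature dimension of the label encodings as the embedding dimension
(embedding_dim,) = entity_initializer.tensor.shape[1:]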
r = pipeline(
    dataset=dataset,
    model=ERModel,
    model_kwargs=dict(
        interaction="distmult",
        entity_representations="SimpleMessagePassing",
        entity_representations_kwargs=dict(
            triples_factory=dataset.training,
            base_kwargs=dict(
                shape=embedding_dim,
                initializer=entity_initializer,
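                trainable=False,  # keep the label-based features fixed
            ),
            # two GCN layers, each mapping embedding_dim -> embedding_dim
            layers=["GCN"] * 2,
            layers_kwargs=dict(in_channels=embedding_dim, out_channels=embedding_dim),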
        ),
        relation_representations_kwargs=dict(
            shape=embedding_dim,
        ),
    ),
)
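The combination of frozen label-based features and a trainable encoder follows a common PyTorch pattern, sketched below under stated assumptions: random tensors stand in for the label encodings, and a single `nn.Linear` stands in for the GCN layers. We do not claim that PyKEEN implements `trainable=False` this way; the sketch only illustrates the intended effect, namely that gradients reach the encoder but not the base features.

```python
import torch
from torch import nn

torch.manual_seed(0)
# hypothetical stand-in for pre-computed label encodings: 5 entities, 16 features each
label_features = torch.randn(5, 16)

# frozen base representation: features are looked up but never receive gradients
base = nn.Embedding.from_pretrained(label_features, freeze=True)
# trainable encoder on top (a single linear layer stands in for the GCN layers)
encoder = nn.Linear(16, 16)
model = nn.ModuleDict({"base": base, "encoder": encoder})

# only the encoder's parameters require gradients
trainable = [name for name, p in model.named_parameters() if p.requires_grad]
print(trainable)  # ['encoder.weight', 'encoder.bias']

# one backward pass: the encoder receives gradients, the frozen features do not
loss = encoder(base(torch.tensor([0, 2]))).pow(2).sum()
loss.backward()
print(base.weight.grad is None, encoder.weight.grad is not None)  # True True
```

Freezing the base features keeps the number of trainable parameters independent of the number of entities, while the encoder can still adapt the features to the link prediction task.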

Classes

MessagePassingRepresentation(...[, ...])

An abstract representation class utilizing PyTorch Geometric message passing layers.

SimpleMessagePassingRepresentation(...[, ...])

A representation with message passing not making use of the relation type.

FeaturizedMessagePassingRepresentation(...)

A representation with message passing which uses edge features obtained from relation representations.

TypedMessagePassingRepresentation(...)

A representation with message passing which uses categorical relation type information.
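The difference between the typed and the featurized variants lies in what accompanies each edge: an integer relation id (edge_type) or a feature vector (edge_attr). The plain PyTorch sketch below contrasts the two. `TypedLayer` (one weight matrix per relation, in the style of R-GCN) and `FeaturizedLayer` (source features multiplied elementwise with the edge's relation vector) are hypothetical illustrations, not the PyG layers PyKEEN wraps.

```python
import torch
from torch import nn


def aggregate_sum(messages: torch.Tensor, target: torch.Tensor, num_nodes: int) -> torch.Tensor:
    # sum all messages that arrive at the same target node
    out = torch.zeros(num_nodes, messages.shape[-1], dtype=messages.dtype)
    return out.index_add_(0, target, messages)


class TypedLayer(nn.Module):
    """Hypothetical R-GCN-style layer: one weight matrix per categorical relation type."""

    def __init__(self, num_relations: int, in_channels: int, out_channels: int):
        super().__init__()
        self.weight = nn.Parameter(0.1 * torch.randn(num_relations, in_channels, out_channels))

    def forward(self, x, edge_index, edge_type):
        source, target = edge_index
        # select each edge's relation-specific matrix and transform the source features
        messages = torch.einsum("ei,eio->eo", x[source], self.weight[edge_type])
        return aggregate_sum(messages, target, x.shape[0])


class FeaturizedLayer(nn.Module):
    """Hypothetical layer using relation representations as edge features."""

    def forward(self, x, edge_index, edge_attr):
        source, target = edge_index
        # modulate each source feature vector by its edge's feature vector
        return aggregate_sum(x[source] * edge_attr, target, x.shape[0])


edge_index = torch.tensor([[0, 1, 2], [1, 2, 0]])  # edges 0->1, 1->2, 2->0
edge_type = torch.tensor([0, 1, 0])                # categorical relation id per edge
x = torch.randn(3, 4)

typed = TypedLayer(num_relations=2, in_channels=4, out_channels=4)
relation_embedding = nn.Embedding(2, 4)            # relation representations
featurized = FeaturizedLayer()

print(typed(x, edge_index, edge_type).shape)                            # torch.Size([3, 4])
print(featurized(x, edge_index, relation_embedding(edge_type)).shape)   # torch.Size([3, 4])
```

In the typed case the relation only selects parameters inside the layer; in the featurized case the relation contributes a vector that can itself be a trainable representation.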

Class Inheritance Diagram

Inheritance hierarchy (indentation denotes subclassing):

Representation (bases: Module, ExtraReprMixin, ABC)
    MessagePassingRepresentation (bases: Representation, ABC)
        SimpleMessagePassingRepresentation
        TypedMessagePassingRepresentation
            FeaturizedMessagePassingRepresentation