PyG Message Passing

PyTorch Geometric-based representation modules.

These modules provide entity representations that are linked to the representations of their graph neighbors. CompGCN and R-GCN provide similar representations. In contrast to these specific architectures, this module offers generic classes that combine many of the message passing layers from PyTorch Geometric with base representations. A summary of the available message passing layers can be found at torch_geometric.nn.conv.
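To illustrate what "linked to their graph neighbors" means, here is a minimal pure-Python sketch of one GCN-style propagation step. Each entity's new vector is a degree-normalized sum over itself and its neighbors. This is not PyG's GCNConv. The function name gcn_propagate and the toy graph are hypothetical. Edges are assumed to be undirected, and the learnable weight matrix and non-linearity are omitted.

```python
from math import sqrt

# Hypothetical toy type: entity id -> feature vector.
Features = dict[int, list[float]]


def gcn_propagate(features: Features, edges: list[tuple[int, int]]) -> Features:
    """One GCN-style step: h'_v = sum over u in N(v) plus v of h_u / sqrt(d_u * d_v).

    Simplifications (assumptions, not PyG's GCNConv): edges are treated as
    undirected, and the learnable weight matrix and non-linearity are omitted.
    """
    # Neighborhoods, including a self-loop for every node.
    neighbors: dict[int, set[int]] = {v: {v} for v in features}
    for source, target in edges:
        neighbors[source].add(target)
        neighbors[target].add(source)
    # Degree counts the self-loop, as in the usual GCN normalization.
    degree = {v: len(nbrs) for v, nbrs in neighbors.items()}

    dim = len(next(iter(features.values())))
    out: Features = {}
    for v, nbrs in neighbors.items():
        acc = [0.0] * dim
        for u in nbrs:
            norm = 1.0 / sqrt(degree[u] * degree[v])
            for i, x in enumerate(features[u]):
                acc[i] += norm * x
        out[v] = acc
    return out


features = {0: [1.0, 0.0], 1: [0.0, 1.0], 2: [1.0, 1.0]}
edges = [(0, 1), (1, 2)]
one_step = gcn_propagate(features, edges)
# Two stacked steps, loosely mirroring layers=["GCN"] * 2 in the pipeline example.
two_steps = gcn_propagate(one_step, edges)
print(one_step[0])  # roughly [0.5, 0.408]
```

Applying the step twice lets information travel two hops. Stacking two "GCN" layers does roughly the same thing, except that the real layers also apply learned transformations.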

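The three concrete classes differ in how they make use of the relation type information: SimpleMessagePassingRepresentation ignores it, FeaturizedMessagePassingRepresentation uses edge features obtained from relation representations, and TypedMessagePassingRepresentation uses the categorical relation type.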
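To make the distinction concrete, here is a toy sketch of three message functions for an edge (u, r, v), one per way of using the relation r. These are hypothetical functions, not PyKEEN or PyG code. The elementwise product for edge features and the single scalar per relation type are arbitrary simplifications chosen for brevity. For comparison, R-GCN-style layers use a full weight matrix per relation type.

```python
# Toy message functions for an edge (u, r, v); hypothetical, not PyKEEN/PyG code.
Vector = list[float]


def message_untyped(h_u: Vector) -> Vector:
    """Relation ignored: every edge sends the neighbor representation unchanged."""
    return list(h_u)


def message_featurized(h_u: Vector, edge_feature: Vector) -> Vector:
    """Relation enters as a dense edge feature, e.g. a relation embedding.

    The elementwise product is an arbitrary illustrative combination.
    """
    if len(h_u) != len(edge_feature):
        raise ValueError("dimension mismatch")
    return [x * e for x, e in zip(h_u, edge_feature)]


def message_typed(h_u: Vector, relation_id: int, type_scales: list[float]) -> Vector:
    """Relation enters as a categorical index selecting type-specific parameters.

    One scalar per relation type stands in for a per-type weight matrix.
    """
    scale = type_scales[relation_id]
    return [scale * x for x in h_u]


h_u = [1.0, 2.0]
print(message_untyped(h_u))                  # [1.0, 2.0]
print(message_featurized(h_u, [0.5, -1.0]))  # [0.5, -2.0]
print(message_typed(h_u, 1, [1.0, 3.0]))     # [3.0, 6.0]
```

The featurized variant can share information between similar relations through their vector representations. The typed variant learns separate parameters per relation index.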

We can also easily utilize these representations with ERModel. Here, we showcase how to combine static label-based entity features with a trainable GCN encoder for entity representations, with learned embeddings for relation representations and a DistMultInteraction function.

"""Example for using PyTorch Geometric.

Combine static label-based entity features with a trainable GCN encoder for entity representations, with learned
embeddings for relation representations and a DistMult interaction function.
"""

from pykeen.datasets import get_dataset
from pykeen.models import ERModel
from pykeen.nn.init import LabelBasedInitializer
from pykeen.pipeline import pipeline
from pykeen.triples.triples_factory import TriplesFactory

dataset = get_dataset(
    dataset="nations",
    dataset_kwargs={"create_inverse_triples": True},
)
triples_factory = dataset.training
# the label-based initializer needs a TriplesFactory, which carries the entity labels
assert isinstance(triples_factory, TriplesFactory)
# build initializer with encoding of entity labels
entity_initializer = LabelBasedInitializer.from_triples_factory(
    triples_factory=triples_factory,
    for_entities=True,
)
(embedding_dim,) = entity_initializer.tensor.shape[1:]
pipeline(
    dataset=dataset,
    model=ERModel,
    model_kwargs={
        "interaction": "distmult",
        "entity_representations": "SimpleMessagePassing",
        "entity_representations_kwargs": {
            "triples_factory": triples_factory,
            "base_kwargs": {
                "shape": embedding_dim,
                "initializer": entity_initializer,
                "trainable": False,
            },
            "layers": ["GCN"] * 2,
            "layers_kwargs": {
                "in_channels": embedding_dim,
                "out_channels": embedding_dim,
            },
        },
        "relation_representations_kwargs": {
            "shape": embedding_dim,
        },
    },
)
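In the example, the GCN encoder produces entity vectors of size embedding_dim, and relation_representations_kwargs gives the relation embeddings the same shape. The shapes must match because DistMult multiplies head, relation, and tail elementwise. The following pure-Python sketch shows the DistMult score on plain lists. It is not PyKEEN's DistMultInteraction, which works on batched tensors, and the function name distmult_score is hypothetical.

```python
def distmult_score(h: list[float], r: list[float], t: list[float]) -> float:
    """DistMult plausibility score: sum_i h_i * r_i * t_i."""
    if not len(h) == len(r) == len(t):
        raise ValueError("head, relation and tail must share one dimension")
    return sum(hi * ri * ti for hi, ri, ti in zip(h, r, t))


# Toy vectors; in the pipeline, h and t would come from the GCN encoder output
# and r from the learned relation embedding.
h = [1.0, 2.0, 3.0]
r = [0.5, 1.0, -1.0]
t = [2.0, 0.5, 1.0]
print(distmult_score(h, r, t))  # -1.0
print(distmult_score(t, r, h))  # -1.0 (DistMult is symmetric in head and tail)
```

Because the score is symmetric in head and tail, DistMult cannot by itself distinguish (h, r, t) from (t, r, h).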

Classes

MessagePassingRepresentation(...[, ...])

An abstract representation class utilizing PyTorch Geometric message passing layers.

SimpleMessagePassingRepresentation(...[, ...])

A representation with message passing that does not make use of the relation type.

FeaturizedMessagePassingRepresentation(...)

A representation with message passing which uses edge features obtained from relation representations.

TypedMessagePassingRepresentation(...)

A representation with message passing which uses categorical relation type information.

Class Inheritance Diagram

Inheritance diagram of pykeen.nn.pyg.MessagePassingRepresentation, pykeen.nn.pyg.SimpleMessagePassingRepresentation, pykeen.nn.pyg.FeaturizedMessagePassingRepresentation, pykeen.nn.pyg.TypedMessagePassingRepresentation