PyG Message Passing
PyTorch Geometric based representation modules.
These modules provide entity representations that are enriched with the representations of their graph neighbors.
CompGCN and R-GCN offer similar representations. In contrast to those fixed architectures, this module provides generic
classes that combine many of the message passing layers from PyTorch Geometric with arbitrary base representations.
A summary of the available message passing layers can be found at torch_geometric.nn.conv.
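To make "linked to their graph neighbors' representations" concrete, here is a minimal pure-Python sketch of one round of neighbor aggregation. It is an illustration only: the helper `mean_aggregate`, the toy features, and the plain mean are assumptions made for this example, and actual PyTorch Geometric layers such as GCNConv use their own normalization and learned weights.

```python
from collections import defaultdict


def mean_aggregate(features, edges):
    """Average each node's vector with the vectors of its incoming neighbors.

    Hypothetical helper for illustration; not part of PyKEEN or PyTorch Geometric.
    """
    incoming = defaultdict(list)
    for source, target in edges:
        incoming[target].append(source)
    updated = []
    for node, own in enumerate(features):
        # Stack the node's own base vector with its neighbors' base vectors.
        stack = [own] + [features[s] for s in incoming[node]]
        updated.append([sum(column) / len(stack) for column in zip(*stack)])
    return updated


# Toy base representations for three entities and two directed edges into entity 1.
base = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
edges = [(0, 1), (2, 1)]
enriched = mean_aggregate(base, edges)
```

Entities without incoming edges keep their base vector, while entity 1 ends up with a mixture of its own vector and those of entities 0 and 2.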
The three classes differ in how they make use of the relation type information:
SimpleMessagePassingRepresentation only uses the connectivity information from the training triples and ignores the relation type, e.g., torch_geometric.nn.conv.GCNConv.
TypedMessagePassingRepresentation is for message passing layers that internally handle the categorical relation type information via an edge_type input, e.g., torch_geometric.nn.conv.RGCNConv.
FeaturizedMessagePassingRepresentation is for message passing layers that can use edge attributes via the parameter edge_attr, e.g., torch_geometric.nn.conv.GMMConv.
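The difference between the three variants can be sketched by looking at which inputs are derived from ID-based triples. The snippet below is a plain-Python illustration under the assumption that edges are taken from the head and tail columns of the training triples. The triples and the `relation_features` table are made up for this example, and the real classes operate on tensors rather than lists.

```python
# Hypothetical ID-based triples (head, relation, tail); not taken from any dataset.
triples = [(0, 0, 1), (1, 1, 2), (2, 0, 0)]

# Connectivity only: source and target entity IDs per edge.
# This is all that a layer without relation awareness receives.
edge_index = [[h for h, _, _ in triples], [t for _, _, t in triples]]

# Categorical relation type per edge, as consumed via an edge_type input.
edge_type = [r for _, r, _ in triples]

# Per-edge feature vectors looked up from a (hypothetical) relation feature table,
# as consumed via an edge_attr parameter.
relation_features = [[0.5, -0.5], [1.0, 2.0]]
edge_attr = [relation_features[r] for r in edge_type]
```

All three inputs share the same edge ordering, so the i-th entry of `edge_type` and `edge_attr` describes the i-th column of `edge_index`.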
We can also easily utilize these representations with pykeen.models.ERModel. Here, we showcase how to combine
static label-based entity features with a trainable GCN encoder for entity representations, with learned embeddings for
relation representations and a DistMult interaction function.
from pykeen.datasets import get_dataset
from pykeen.models import ERModel
from pykeen.nn.init import LabelBasedInitializer
from pykeen.pipeline import pipeline

dataset = get_dataset(dataset="nations", dataset_kwargs=dict(create_inverse_triples=True))

# Static entity features derived from the entity labels.
entity_initializer = LabelBasedInitializer.from_triples_factory(
    triples_factory=dataset.training,
    for_entities=True,
)
# The feature dimension of the label-based features determines the embedding dimension.
(embedding_dim,) = entity_initializer.tensor.shape[1:]

r = pipeline(
    dataset=dataset,
    model=ERModel,
    model_kwargs=dict(
        interaction="distmult",
        entity_representations="SimpleMessagePassing",
        entity_representations_kwargs=dict(
            triples_factory=dataset.training,
            # Base representations: the label-based features, kept fixed during training.
            base_kwargs=dict(
                shape=embedding_dim,
                initializer=entity_initializer,
                trainable=False,
            ),
            # Two GCN layers that keep the dimension unchanged.
            layers=["GCN"] * 2,
            layers_kwargs=dict(in_channels=embedding_dim, out_channels=embedding_dim),
        ),
        # Relations use learned embeddings of the same dimension.
        relation_representations_kwargs=dict(
            shape=embedding_dim,
        ),
    ),
)
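For readers unfamiliar with the DistMult interaction named above, the following pure-Python sketch shows the scoring formula it is based on: the sum over the element-wise product of head, relation, and tail vectors. The function name `distmult_score` and the example vectors are hypothetical, and PyKEEN's own implementation works on batched tensors.

```python
def distmult_score(head, relation, tail):
    """Sum of the element-wise product of the three vectors (illustrative helper)."""
    return sum(h * r * t for h, r, t in zip(head, relation, tail))


# Made-up two-dimensional vectors for one triple.
head = [1.0, 2.0]
relation = [0.5, 1.0]
tail = [2.0, 3.0]
score = distmult_score(head, relation, tail)  # 1.0 * 0.5 * 2.0 + 2.0 * 1.0 * 3.0 = 7.0
```

Because the product is commutative, swapping head and tail yields the same score. In the example above, the head and tail vectors would be the outputs of the GCN encoder, and the relation vector a learned embedding.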
Classes
MessagePassingRepresentation | An abstract representation class utilizing PyTorch Geometric message passing layers.
SimpleMessagePassingRepresentation | A representation with message passing not making use of the relation type.
FeaturizedMessagePassingRepresentation | A representation with message passing which uses edge features obtained from relation representations.
TypedMessagePassingRepresentation | A representation with message passing which uses categorical relation type information.