TypedMessagePassingRepresentation
- class TypedMessagePassingRepresentation(triples_factory: CoreTriplesFactory, **kwargs)
Bases: MessagePassingRepresentation
A representation with message passing that uses categorical relation type information.
The message passing layers of this module, e.g., torch_geometric.nn.conv.RGCNConv or torch_geometric.nn.conv.RGATConv, internally handle the categorical relation type information via an edge_type input.
The following example creates a one-layer RGCN using the basis decomposition:
from pykeen.datasets import get_dataset
from pykeen.nn.pyg import TypedMessagePassingRepresentation

embedding_dim = 64
dataset = get_dataset(dataset="nations")
r = TypedMessagePassingRepresentation(
    triples_factory=dataset.training,
    base_kwargs=dict(shape=embedding_dim),
    layers="rgcn",
    layers_kwargs=dict(
        in_channels=embedding_dim,
        out_channels=embedding_dim,
        num_bases=2,
        num_relations=dataset.num_relations,
    ),
)
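The basis decomposition requested with num_bases=2 shares a small set of basis matrices V_b across all relations and gives each relation r only the coefficients a_rb, so that W_r = sum_b a_rb * V_b. The plain-Python sketch below applies one simplified R-GCN layer built this way. It assumes mean aggregation within each relation plus a self-connection ("root") weight, and it omits bias and activation, so it illustrates the idea rather than reproducing torch_geometric.nn.conv.RGCNConv exactly. All function and variable names are hypothetical.

```python
from typing import Dict, List, Tuple

Matrix = List[List[float]]
Vector = List[float]


def matvec(w: Matrix, x: Vector) -> Vector:
    """Multiply an (out_dim x in_dim) matrix with an in_dim vector."""
    return [sum(w_ij * x_j for w_ij, x_j in zip(row, x)) for row in w]


def basis_weights(bases: List[Matrix], coeffs: List[List[float]]) -> List[Matrix]:
    """Basis decomposition: W_r = sum_b coeffs[r][b] * bases[b]."""
    out_dim, in_dim = len(bases[0]), len(bases[0][0])
    return [
        [
            [sum(a_rb * v[i][j] for a_rb, v in zip(a_r, bases)) for j in range(in_dim)]
            for i in range(out_dim)
        ]
        for a_r in coeffs
    ]


def rgcn_layer(
    x: List[Vector],
    edge_index: List[List[int]],
    edge_type: List[int],
    weights: List[Matrix],
    root: Matrix,
) -> List[Vector]:
    """Simplified R-GCN layer (hypothetical sketch, no bias, no activation):
    h_i' = root @ x_i + sum_r mean_{j in N_r(i)} W_r @ x_j
    """
    out = [matvec(root, x_i) for x_i in x]  # self-connection term
    # Group transformed neighbor messages by (target node, relation).
    groups: Dict[Tuple[int, int], List[Vector]] = {}
    for src, dst, r in zip(edge_index[0], edge_index[1], edge_type):
        groups.setdefault((dst, r), []).append(matvec(weights[r], x[src]))
    # Mean-aggregate within each relation, then sum over relations.
    for (dst, _), msgs in groups.items():
        mean = [sum(vals) / len(msgs) for vals in zip(*msgs)]
        out[dst] = [o + m for o, m in zip(out[dst], mean)]
    return out


# Two shared 2x2 bases; each relation stores only num_bases = 2 coefficients.
bases = [[[1.0, 0.0], [0.0, 1.0]], [[0.0, 1.0], [1.0, 0.0]]]
coeffs = [[1.0, 0.0], [0.0, 2.0]]  # relation 0 -> identity, relation 1 -> 2 * swap
weights = basis_weights(bases, coeffs)
root = [[1.0, 0.0], [0.0, 1.0]]

x = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]  # 3 nodes, 2 features each
edge_index = [[0, 2, 1], [2, 1, 0]]  # edges 0->2, 2->1, 1->0
edge_type = [1, 0, 1]

result = rgcn_layer(x, edge_index, edge_type, weights, root)
print(result)  # [[3.0, 0.0], [1.0, 2.0], [1.0, 3.0]]
```

With num_bases=2, each relation contributes 2 trainable scalars instead of a full embedding_dim × embedding_dim matrix, which is what makes the decomposition attractive when there are many relations.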
Initialize the representation.
- Parameters:
triples_factory (CoreTriplesFactory) – the factory comprising the training triples used for message passing
kwargs – additional keyword-based parameters passed to MessagePassingRepresentation.__init__()
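The triples_factory supplies the (head, relation, tail) triples that define the message passing graph. As a rough sketch, assuming the triples are already mapped to integer ids, they can be split into the (2, num_edges) edge_index layout used by PyTorch Geometric and an aligned per-edge edge_type list. triples_to_edges is a hypothetical plain-Python helper for illustration, not part of PyKEEN.

```python
from typing import List, Tuple

Triple = Tuple[int, int, int]  # (head id, relation id, tail id)


def triples_to_edges(triples: List[Triple]) -> Tuple[List[List[int]], List[int]]:
    """Split (h, r, t) triples into an edge_index and an aligned edge_type list.

    Hypothetical helper for illustration only; edges point from head to tail.
    """
    sources = [h for h, _, _ in triples]
    targets = [t for _, _, t in triples]
    edge_index = [sources, targets]  # row 0: source nodes, row 1: target nodes
    edge_type = [r for _, r, _ in triples]  # one relation id per edge
    return edge_index, edge_type


triples = [(0, 1, 2), (2, 0, 1), (1, 1, 0)]
edge_index, edge_type = triples_to_edges(triples)
print(edge_index)  # [[0, 2, 1], [2, 1, 0]]
print(edge_type)   # [1, 0, 1]
```

The i-th entry of edge_type describes the i-th column of edge_index; typed layers such as RGCNConv rely on this alignment to pick the relation-specific transformation for each edge.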
Methods Summary
pass_messages(x, edge_index[, edge_mask]) – Perform the message passing steps.
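The summary does not spell out what edge_mask does. Assuming it is a boolean mask marking which edges take part in message passing (for instance, after edge dropout), a typed layer has to filter the per-edge relation ids with the same mask so that they stay aligned with the retained edges. The sketch below illustrates that alignment in plain Python; mask_edges is a hypothetical helper, not PyKEEN's implementation of pass_messages.

```python
from typing import List, Tuple


def mask_edges(
    edge_index: List[List[int]], edge_type: List[int], edge_mask: List[bool]
) -> Tuple[List[List[int]], List[int]]:
    """Keep only edges whose mask entry is True, filtering edge_type identically."""
    if not (len(edge_index[0]) == len(edge_index[1]) == len(edge_type) == len(edge_mask)):
        raise ValueError("edge_index, edge_type and edge_mask need one entry per edge")
    keep = [i for i, m in enumerate(edge_mask) if m]
    masked_index = [[row[i] for i in keep] for row in edge_index]
    masked_type = [edge_type[i] for i in keep]
    return masked_index, masked_type


edge_index = [[0, 2, 1], [2, 1, 0]]
edge_type = [1, 0, 1]
edge_mask = [True, False, True]  # e.g., drop the second edge

masked_index, masked_type = mask_edges(edge_index, edge_type, edge_mask)
print(masked_index)  # [[0, 1], [2, 0]]
print(masked_type)   # [1, 1]
```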
Methods Documentation