TypedMessagePassingRepresentation
- class TypedMessagePassingRepresentation(triples_factory, **kwargs)
Bases: MessagePassingRepresentation
A message passing representation that uses categorical relation type information.
The message passing layers of this module, e.g., torch_geometric.nn.conv.RGCNConv or torch_geometric.nn.conv.RGATConv, internally handle the categorical relation type information via an edge_type input.

The following example creates a one-layer RGCN using the basis decomposition:
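```python
from pykeen.datasets import get_dataset
from pykeen.nn.pyg import TypedMessagePassingRepresentation

embedding_dim = 64
dataset = get_dataset(dataset="nations")
r = TypedMessagePassingRepresentation(
    triples_factory=dataset.training,
    # shape of the base entity representations that are refined by message passing
    base_kwargs=dict(shape=embedding_dim),
    # a single RGCN layer
    layers="rgcn",
    layers_kwargs=dict(
        in_channels=embedding_dim,
        out_channels=embedding_dim,
        # basis decomposition: relation weights are mixtures of 2 shared bases
        num_bases=2,
        num_relations=dataset.num_relations,
    ),
)
```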
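To make the role of edge_type more concrete, here is a NumPy sketch of one simplified R-GCN step with basis decomposition. Each relation's weight matrix W_r is a learned mixture of a few shared basis matrices V_b (num_bases in the example sets how many), and the edge type of each edge selects which W_r transforms its message. The sketch follows the published R-GCN formulation with per-relation mean aggregation and a self-connection; it is not the code of torch_geometric.nn.conv.RGCNConv, and the function rgcn_basis_layer and its arguments are hypothetical.

```python
import numpy as np


def rgcn_basis_layer(x, edge_index, edge_type, bases, coefficients, root):
    """Hypothetical helper: one simplified R-GCN step with basis decomposition.

    x: (n, d_in) node features
    edge_index: (2, num_edges) source and target node indices
    edge_type: (num_edges,) relation index of each edge
    bases: (num_bases, d_in, d_out) shared basis matrices V_b
    coefficients: (num_relations, num_bases) mixing weights a_rb
    root: (d_in, d_out) self-connection weight
    """
    # W_r = sum_b a_rb * V_b, giving shape (num_relations, d_in, d_out)
    weights = np.einsum("rb,bio->rio", coefficients, bases)
    out = x @ root  # self-connection term
    source, target = edge_index
    for r in range(coefficients.shape[0]):
        mask = edge_type == r
        if not mask.any():
            continue
        # Transform the source features with the relation-specific weight.
        messages = x[source[mask]] @ weights[r]
        # Sum incoming messages per target node ...
        summed = np.zeros((x.shape[0], weights.shape[2]))
        np.add.at(summed, target[mask], messages)
        # ... and divide by the in-degree for this relation (mean aggregation).
        degree = np.bincount(target[mask], minlength=x.shape[0])
        out = out + summed / np.maximum(degree, 1)[:, None]
    return out


rng = np.random.default_rng(0)
n, d, num_relations, num_bases = 3, 4, 2, 2
x = rng.normal(size=(n, d))
edge_index = np.array([[0, 1, 2, 0], [1, 2, 0, 2]])
edge_type = np.array([0, 1, 0, 1])
bases = rng.normal(size=(num_bases, d, d))
coefficients = rng.normal(size=(num_relations, num_bases))
root = rng.normal(size=(d, d))
h = rgcn_basis_layer(x, edge_index, edge_type, bases, coefficients, root)
print(h.shape)  # (3, 4)
```

Sharing the bases across relations keeps the parameter count at num_bases weight matrices plus a small coefficient table, instead of one full matrix per relation.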
Initialize the representation.
- Parameters:
  - triples_factory (CoreTriplesFactory) – the factory comprising the training triples used for message passing
  - kwargs – additional keyword-based parameters passed to MessagePassingRepresentation.__init__()
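Conceptually, the triples factory supplies both the graph structure and the relation types used by the typed layers. The NumPy sketch below, built on a small hypothetical set of ID-based (head, relation, tail) triples, shows one plausible way to derive a PyTorch Geometric-style edge_index of shape (2, num_edges) together with a matching edge_type vector. It illustrates the idea under that assumption and is not PyKEEN's own code.

```python
import numpy as np

# Hypothetical ID-based triples, one (head, relation, tail) row per fact,
# analogous to the mapped triples held by a triples factory.
mapped_triples = np.array(
    [
        [0, 0, 1],
        [1, 1, 2],
        [2, 0, 0],
        [0, 1, 2],
    ]
)

# Assumed direction: messages flow from head to tail, shape (2, num_edges).
edge_index = mapped_triples[:, [0, 2]].T
# The categorical relation type of each edge, shape (num_edges,).
edge_type = mapped_triples[:, 1]

print(edge_index.tolist())  # [[0, 1, 2, 0], [1, 2, 0, 2]]
print(edge_type.tolist())  # [0, 1, 0, 1]
```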
Methods Summary
pass_messages(x, edge_index[, edge_mask]) – Perform the message passing steps.
Methods Documentation
- pass_messages(x, edge_index, edge_mask=None)
Perform the message passing steps.
- Parameters:
  - x (FloatTensor) – shape: (n, d_in) the base entity representations
  - edge_index (LongTensor) – shape: (2, num_selected_edges) the edge index (which may already be a selection of the full edge index)
  - edge_mask (Optional[BoolTensor]) – shape: (num_edges,) an optional edge mask, if message passing is restricted to a subset of the edges
- Return type: FloatTensor
- Returns: shape: (n, d_out) the enriched entity representations
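The edge mask matters for the typed variant because relation types are stored per edge: when edge_index is already a selection of the full edge index, the edge types presumably have to be selected the same way so that both stay aligned. The NumPy sketch below, using a small hypothetical graph, shows this alignment. It assumes, rather than quotes, that the implementation filters its stored edge types with edge_mask before handing them to each layer.

```python
import numpy as np

# Hypothetical full graph: edge_index has shape (2, num_edges) and edge_type
# has shape (num_edges,), with one relation type per edge.
edge_index = np.array([[0, 1, 2, 0], [1, 2, 0, 2]])
edge_type = np.array([0, 1, 0, 1])

# Restrict message passing to a subset of edges, e.g. those leading into node 2.
edge_mask = edge_index[1] == 2  # shape (num_edges,)

# Filter both arrays with the same mask so that the i-th selected column of
# edge_index lines up with the i-th selected entry of edge_type.
selected_edge_index = edge_index[:, edge_mask]  # shape (2, num_selected_edges)
selected_edge_type = edge_type[edge_mask]  # shape (num_selected_edges,)

print(selected_edge_index.tolist())  # [[1, 0], [2, 2]]
print(selected_edge_type.tolist())  # [1, 1]
```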