FeaturizedMessagePassingRepresentation
- class FeaturizedMessagePassingRepresentation(triples_factory, relation_representation=None, relation_representation_kwargs=None, relation_transformation=None, **kwargs)
Bases: TypedMessagePassingRepresentation
A message passing representation that uses edge features obtained from relation representations.
It (re-)uses a representation layer for relations to obtain edge features, which are then consumed by message passing layers that support edge features, e.g., torch_geometric.nn.conv.GMMConv or torch_geometric.nn.conv.GATConv. A (shared) transformation of the edge features between layers is also supported.

The following example creates a GAT message passing layer on top of the base representations:
    from pykeen.datasets import get_dataset
    from pykeen.nn.pyg import FeaturizedMessagePassingRepresentation

    embedding_dim = 64
    dataset = get_dataset(dataset="nations")
    r = FeaturizedMessagePassingRepresentation(
        triples_factory=dataset.training,
        base_kwargs=dict(shape=embedding_dim),
        relation_representation_kwargs=dict(
            shape=embedding_dim,
        ),
        layers="gat",
        layers_kwargs=dict(
            in_channels=embedding_dim,
            out_channels=embedding_dim,
            edge_dim=embedding_dim,  # should match relation dim
        ),
    )
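Conceptually, the edge feature of each edge is the representation of that edge's relation, so all edges sharing a relation (re-)use the same vector. The pure-Python sketch below illustrates only this lookup step, assuming relations are indexed by integer edge types; `edge_features` is a hypothetical helper, not part of PyKEEN. It also checks the constraint from the example's comment that a layer's `edge_dim` must equal the relation dimension.

```python
from typing import List, Sequence

Vector = List[float]


def edge_features(
    relation_table: Sequence[Vector],
    edge_type: Sequence[int],
    edge_dim: int,
) -> List[Vector]:
    """Hypothetical helper: look up one feature vector per edge from its relation."""
    for row in relation_table:
        # a layer configured with edge_dim expects edge features of exactly that size
        if len(row) != edge_dim:
            raise ValueError(f"relation dim {len(row)} does not match edge_dim {edge_dim}")
    # edge e receives the representation of relation edge_type[e];
    # edges of the same type share (re-use) the same relation vector
    return [list(relation_table[r]) for r in edge_type]


# three relations with 2-dimensional representations
relations = [[1.0, 0.0], [0.0, 1.0], [0.5, 0.5]]
# four edges; edges 0 and 3 share relation 1
print(edge_features(relations, edge_type=[1, 0, 2, 1], edge_dim=2))
# [[0.0, 1.0], [1.0, 0.0], [0.5, 0.5], [0.0, 1.0]]
```

In the real class, the lookup is done by a PyKEEN relation representation module rather than a Python list, but the alignment between edges and relation vectors is the point being illustrated.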
Initialize the representation.
- Parameters:
  - triples_factory (CoreTriplesFactory) – the factory comprising the training triples used for message passing
  - relation_representation (Union[str, Representation, Type[Representation], None]) – the base representations for relations, or a hint thereof
  - relation_representation_kwargs (Optional[Mapping[str, Any]]) – additional keyword-based parameters passed to the base representations upon instantiation
  - relation_transformation (Optional[Module]) – an optional transformation to apply to the relation representations after each message passing step. If None, do not modify the representations.
  - kwargs – additional keyword-based parameters passed to TypedMessagePassingRepresentation.__init__(), except the triples_factory
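To illustrate how relation_transformation interacts with the layer stack, the following pure-Python sketch runs several message passing steps and, after each step, applies one shared transformation to the relation features, leaving them untouched when it is None. The message function (source state multiplied element-wise by the edge's relation vector, summed at the target, plus a self-loop) is an assumption chosen for brevity; it is not the GAT or GMM update used by the actual layers, and `pass_messages_sketch` is a hypothetical name.

```python
from typing import Callable, List, Optional, Sequence, Tuple

Vector = List[float]


def pass_messages_sketch(
    x: Sequence[Vector],
    edges: Sequence[Tuple[int, int]],
    edge_type: Sequence[int],
    relations: Sequence[Vector],
    num_layers: int,
    relation_transformation: Optional[Callable[[Vector], Vector]] = None,
) -> List[Vector]:
    """Hypothetical sketch of featurized message passing with an optional shared transformation."""
    x = [list(v) for v in x]
    rel = [list(r) for r in relations]
    for _ in range(num_layers):
        # every node starts from its own state (a self-loop)
        out = [list(v) for v in x]
        for (source, target), r in zip(edges, edge_type):
            # message = source state modulated element-wise by the edge feature
            message = [a * b for a, b in zip(x[source], rel[r])]
            out[target] = [a + b for a, b in zip(out[target], message)]
        x = out
        # the same transformation is applied after every layer; None keeps relations as-is
        if relation_transformation is not None:
            rel = [relation_transformation(r) for r in rel]
    return x


x0 = [[1.0, 1.0], [0.0, 0.0]]
print(pass_messages_sketch(x0, [(0, 1)], [0], [[2.0, 3.0]], num_layers=2))
# [[1.0, 1.0], [4.0, 6.0]]
print(pass_messages_sketch(x0, [(0, 1)], [0], [[2.0, 3.0]], num_layers=2,
                           relation_transformation=lambda r: [0.5 * v for v in r]))
# [[1.0, 1.0], [3.0, 4.5]]
```

The second call differs from the first only in the second layer, because the transformation is applied after the first step and therefore changes the edge features seen by the next layer.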
Methods Summary
pass_messages(x, edge_index[, edge_mask]) – Perform the message passing steps.
Methods Documentation
- pass_messages(x, edge_index, edge_mask=None)
Perform the message passing steps.
- Parameters:
  - x (FloatTensor) – shape: (n, d_in) the base entity representations
  - edge_index (LongTensor) – shape: (2, num_selected_edges) the edge index (which may already be a selection of the full edge index)
  - edge_mask (Optional[BoolTensor]) – shape: (num_edges,) an edge mask if message passing is restricted
- Return type: FloatTensor
- Returns: shape: (n, d_out) the enriched entity representations
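The relationship between a pre-selected edge_index and edge_mask can be sketched as follows, under the assumption that the mask is used to pick out the relation types belonging to the selected edges, so that each column of the (2, num_selected_edges) edge index stays aligned with the right edge feature. This pure-Python version is illustrative only; `edge_types_for_selection` is a hypothetical helper, not a PyKEEN function.

```python
from typing import List, Optional, Sequence


def edge_types_for_selection(
    edge_index: Sequence[Sequence[int]],
    edge_type: Sequence[int],
    edge_mask: Optional[Sequence[bool]] = None,
) -> List[int]:
    """Hypothetical helper: relation types aligned with a (possibly pre-selected) edge index."""
    num_selected = len(edge_index[0])
    if edge_mask is None:
        # no restriction: edge_index is the full edge index
        selected = list(edge_type)
    else:
        if len(edge_mask) != len(edge_type):
            raise ValueError("edge_mask must have shape (num_edges,)")
        selected = [t for t, keep in zip(edge_type, edge_mask) if keep]
    # each column of edge_index must have exactly one relation type
    if len(selected) != num_selected:
        raise ValueError("edge_index columns and selected edge types disagree")
    return selected


full_edge_index = [[0, 1, 2, 0], [1, 2, 0, 2]]  # row 0: sources, row 1: targets
edge_type = [0, 1, 1, 2]
edge_mask = [True, False, True, True]
# the caller may already have applied the mask to the edge index
selected_edge_index = [[v for v, keep in zip(row, edge_mask) if keep] for row in full_edge_index]
print(selected_edge_index)  # [[0, 2, 0], [1, 0, 2]]
print(edge_types_for_selection(selected_edge_index, edge_type, edge_mask))  # [0, 1, 2]
```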