MessagePassingRepresentation
- class MessagePassingRepresentation(triples_factory, layers, layers_kwargs=None, base=None, base_kwargs=None, max_id=None, shape=None, activations=None, activations_kwargs=None, restrict_k_hop=False, **kwargs)
Bases: Representation, ABC
An abstract representation class utilizing PyTorch Geometric message passing layers.
It comprises:
- base (entity) representations, which can also be passed as hints;
- a sequence of message passing layers, which are applied in the abstract method MessagePassingRepresentation.pass_messages() to enrich the base representations with neighborhood information;
- a sequence of activation layers applied in between the message passing layers;
- an edge_index buffer, which stores the edge index and is moved to the device alongside the module.
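The composition described above can be illustrated with a small, dependency-free sketch. It is not PyKEEN's implementation: plain Python lists stand in for tensors, and the names `edge_index_from_triples`, `sum_neighbors`, `relu` and `forward` are hypothetical. It assumes, as in PyTorch Geometric, that the edge index has shape (2, num_edges), with head (source) indices in the first row and tail (target) indices in the second, and it pairs each message passing layer with one activation.

```python
from typing import Callable, List, Sequence, Tuple

Matrix = List[List[float]]               # shape: (n, d), one row per entity
EdgeIndex = Tuple[List[int], List[int]]  # shape: (2, num_edges): (sources, targets)


def edge_index_from_triples(triples: Sequence[Tuple[int, int, int]]) -> EdgeIndex:
    """Drop the relation column and keep (head, tail) pairs as (sources, targets)."""
    sources = [h for h, _, _ in triples]
    targets = [t for _, _, t in triples]
    return sources, targets


def sum_neighbors(x: Matrix, edge_index: EdgeIndex) -> Matrix:
    """Toy message passing layer: each node adds the features of its in-neighbors."""
    out = [row[:] for row in x]  # start from each node's own features (self-loop)
    for s, t in zip(*edge_index):
        out[t] = [a + b for a, b in zip(out[t], x[s])]
    return out


def relu(x: Matrix) -> Matrix:
    return [[max(0.0, v) for v in row] for row in x]


def forward(
    base: Matrix,
    edge_index: EdgeIndex,
    layers: Sequence[Callable[[Matrix, EdgeIndex], Matrix]],
    activations: Sequence[Callable[[Matrix], Matrix]],
) -> Matrix:
    """Enrich the base representations: each layer is followed by its activation."""
    x = base
    for layer, activation in zip(layers, activations):
        x = activation(layer(x, edge_index))
    return x


triples = [(0, 0, 1), (1, 0, 2), (2, 1, 0)]  # (head, relation, tail)
edge_index = edge_index_from_triples(triples)  # ([0, 1, 2], [1, 2, 0])
base = [[1.0], [-2.0], [3.0]]
print(forward(base, edge_index, [sum_neighbors, sum_neighbors], [relu, relu]))
# [[5.0], [4.0], [1.0]]
```

Because every layer reads only the previous layer's output, stacking two layers lets information travel two edges, which is why the number of layers bounds the neighborhood that influences each entity.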
Initialize the representation.
- Parameters
  - triples_factory (CoreTriplesFactory) – the factory comprising the training triples used for message passing
  - layers (Union[str, None, Type[None], Sequence[Union[str, None, Type[None]]]]) – the message passing layer(s), or hints thereof
  - layers_kwargs (Union[Mapping[str, Any], None, Sequence[Optional[Mapping[str, Any]]]]) – additional keyword-based parameters passed to the layers upon instantiation
  - base (Union[str, Representation, Type[Representation], None]) – the base representations for entities, or a hint thereof
  - base_kwargs (Optional[Mapping[str, Any]]) – additional keyword-based parameters passed to the base representations upon instantiation
  - shape (Union[int, Sequence[int], None]) – the output shape. Defaults to the base representation's shape. Has to match the output shape of the last message passing layer.
  - max_id (Optional[int]) – the number of representations. If provided, it has to match the base representation's max_id.
  - activations (Union[str, Module, Type[Module], None, Sequence[Union[str, Module, Type[Module], None]]]) – the activation(s), or hints thereof
  - activations_kwargs (Union[Mapping[str, Any], None, Sequence[Optional[Mapping[str, Any]]]]) – additional keyword-based parameters passed to the activations upon instantiation
  - restrict_k_hop (bool) – whether to restrict message passing to the k-hop neighborhood when only some indices are requested. This utilizes torch_geometric.utils.k_hop_subgraph().
  - kwargs – additional keyword-based parameters passed to Representation.__init__()
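To illustrate what restrict_k_hop selects, here is a pure-Python sketch of a k-hop neighborhood computation. It is modeled on the default behavior of torch_geometric.utils.k_hop_subgraph (flow="source_to_target"): starting from the requested node indices, it repeatedly adds the sources of edges that point into the current node set, then keeps every edge whose two endpoints both lie in that set. The function name `k_hop_edge_mask` and the list-based inputs are assumptions for illustration; the real function operates on tensors and returns additional outputs.

```python
from typing import List, Sequence, Set, Tuple


def k_hop_edge_mask(
    node_idx: Sequence[int],
    num_hops: int,
    edge_index: Tuple[Sequence[int], Sequence[int]],
) -> Tuple[Set[int], List[bool]]:
    """Return the k-hop node subset around node_idx and a boolean mask over all edges."""
    sources, targets = edge_index
    subset = set(node_idx)
    for _ in range(num_hops):
        # add every node that sends a message into the current subset
        subset |= {s for s, t in zip(sources, targets) if t in subset}
    # keep an edge only if both of its endpoints were selected
    edge_mask = [s in subset and t in subset for s, t in zip(sources, targets)]
    return subset, edge_mask


# edges: 0->1, 1->2, 2->3, 4->3
edge_index = ([0, 1, 2, 4], [1, 2, 3, 3])
subset, mask = k_hop_edge_mask([3], num_hops=1, edge_index=edge_index)
print(sorted(subset), mask)  # [2, 3, 4] [False, False, True, True]
```

Since each message passing layer propagates information by one edge, one hop per layer is the natural choice for num_hops: entities outside that neighborhood cannot influence the requested representations, so their edges can be skipped.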
- Raises
  - ImportError – if PyTorch Geometric is not installed
  - ValueError – if the numbers of activations and message passing layers do not match (after input normalization)
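The ValueError above concerns the lengths after normalization. The following sketch shows one plausible reading, under the assumption that normalization wraps a single hint into a one-element sequence before the lengths are compared. The helper names `to_list` and `normalize_layers_and_activations` are hypothetical, and the strings "gcn" and "relu" are placeholder hints, not verified PyKEEN hint names. The sketch deliberately omits any broadcasting of a single hint to several layers and any default-filling for None, since the documentation above does not specify them.

```python
from typing import Any, List, Tuple


def to_list(value: Any) -> List[Any]:
    """Wrap a single hint in a one-element list; turn lists/tuples into lists.

    Strings are deliberately treated as single hints, not as sequences of characters.
    """
    if isinstance(value, (list, tuple)):
        return list(value)
    return [value]


def normalize_layers_and_activations(layers: Any, activations: Any) -> Tuple[List[Any], List[Any]]:
    layers = to_list(layers)
    activations = to_list(activations)
    if len(activations) != len(layers):
        raise ValueError(
            f"Got {len(activations)} activation(s) for {len(layers)} message passing layer(s)."
        )
    return layers, activations


print(normalize_layers_and_activations("gcn", "relu"))  # (['gcn'], ['relu'])
try:
    normalize_layers_and_activations(["gcn", "gcn"], "relu")
except ValueError as error:
    print(error)  # Got 1 activation(s) for 2 message passing layer(s).
```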
Methods Summary
pass_messages(x, edge_index[, edge_mask]) – Perform the message passing steps.
Methods Documentation
- abstract pass_messages(x, edge_index, edge_mask=None)
Perform the message passing steps.
- Parameters
  - x (FloatTensor) – shape: (n, d_in); the base entity representations
  - edge_index (LongTensor) – shape: (2, num_selected_edges); the edge index (which may already be a selection of the full edge index)
  - edge_mask (Optional[BoolTensor]) – shape: (num_edges,); an edge mask if message passing is restricted
- Return type
  FloatTensor
- Returns
  shape: (n, d_out); the enriched entity representations
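Subclasses must implement pass_messages. The dependency-free sketch below shows one way a concrete implementation might use edge_mask. Going by the documented shapes, edge_index is already restricted to the selected edges, while edge_mask ranges over all edges; the mask can therefore be used to select per-edge data (here, a weight per edge) so that it lines up with the restricted edge_index. The class `ToyWeightedMessagePassing` is hypothetical and not a PyKEEN class; lists stand in for tensors, and weighted sum aggregation with a self-loop was chosen only for brevity.

```python
from typing import List, Optional, Sequence, Tuple

Matrix = List[List[float]]                         # shape: (n, d)
EdgeIndex = Tuple[Sequence[int], Sequence[int]]    # shape: (2, num_selected_edges)


class ToyWeightedMessagePassing:
    """Hypothetical stand-in for a concrete subclass; not part of PyKEEN."""

    def __init__(self, edge_weight: Sequence[float]) -> None:
        # one weight per edge of the *full* graph, analogous to a registered buffer
        self.edge_weight = list(edge_weight)

    def pass_messages(
        self,
        x: Matrix,
        edge_index: EdgeIndex,
        edge_mask: Optional[Sequence[bool]] = None,
    ) -> Matrix:
        weight = self.edge_weight
        if edge_mask is not None:
            # edge_index is already a selection; select the matching per-edge data
            weight = [w for w, keep in zip(weight, edge_mask) if keep]
        sources, targets = edge_index
        if len(weight) != len(sources):
            raise ValueError("per-edge data and edge index are not aligned")
        out = [row[:] for row in x]  # self-loop: start from each node's own features
        for s, t, w in zip(sources, targets, weight):
            out[t] = [a + w * b for a, b in zip(out[t], x[s])]
        return out


layer = ToyWeightedMessagePassing(edge_weight=[1.0, 10.0, 100.0])
x = [[1.0], [2.0], [3.0]]
full_edge_index = ([0, 1, 2], [1, 2, 0])  # edges 0->1, 1->2, 2->0
print(layer.pass_messages(x, full_edge_index))  # [[301.0], [3.0], [23.0]]
mask = [True, False, True]                      # drop edge 1->2
selected = ([0, 2], [1, 0])
print(layer.pass_messages(x, selected, edge_mask=mask))  # [[301.0], [3.0], [3.0]]
```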