MessagePassingRepresentation
- class MessagePassingRepresentation(triples_factory, layers, layers_kwargs=None, base=None, base_kwargs=None, max_id=None, shape=None, activations=None, activations_kwargs=None, restrict_k_hop=False, **kwargs)
Bases: Representation, ABC
An abstract representation class utilizing PyTorch Geometric message passing layers.
- It comprises:
base (entity) representations, which can also be passed as hints;
a sequence of message passing layers, which the abstract MessagePassingRepresentation.pass_messages() uses to enrich the base representations with neighborhood information;
a sequence of activation layers between the message passing layers;
an edge_index buffer, which stores the edge index and is moved to the device alongside the module.
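The following plain-Python sketch (no PyTorch or PyTorch Geometric) mirrors the structure described above: base vectors are enriched by alternating message passing layers and activations over an edge index stored as (sources, targets). The names `enrich`, `mean_aggregation_layer` and `relu` are hypothetical, and mean aggregation with a self-loop is an assumed stand-in for a real PyG layer; this is not PyKEEN's implementation.

```python
from typing import Callable, List, Sequence, Tuple

Vector = List[float]
EdgeIndex = Tuple[List[int], List[int]]  # (sources, targets), like PyG's 2 x num_edges layout
Layer = Callable[[List[Vector], EdgeIndex], List[Vector]]
Activation = Callable[[float], float]


def mean_aggregation_layer(x: List[Vector], edge_index: EdgeIndex) -> List[Vector]:
    """Hypothetical layer: each node averages itself with its incoming neighbours."""
    sources, targets = edge_index
    sums = [list(v) for v in x]  # start with the node's own vector (self-loop)
    counts = [1] * len(x)
    for s, t in zip(sources, targets):
        # message flows from source s to target t
        sums[t] = [a + b for a, b in zip(sums[t], x[s])]
        counts[t] += 1
    return [[a / c for a in v] for v, c in zip(sums, counts)]


def relu(value: float) -> float:
    return max(0.0, value)


def enrich(
    base: List[Vector],
    edge_index: EdgeIndex,
    layers: Sequence[Layer],
    activations: Sequence[Activation],
) -> List[Vector]:
    """Apply each layer, then its activation element-wise, in turn."""
    if len(layers) != len(activations):
        raise ValueError("need exactly one activation per message passing layer")
    x = base
    for layer, activation in zip(layers, activations):
        x = layer(x, edge_index)
        x = [[activation(a) for a in v] for v in x]
    return x
```

For example, with edges 0→1 and 1→2, `enrich([[1.0, -2.0], [3.0, 4.0], [0.0, 0.0]], ([0, 1], [1, 2]), [mean_aggregation_layer], [relu])` gives `[[1.0, 0.0], [2.0, 1.0], [1.5, 2.0]]`: node 0 has no incoming edges and only passes through the activation.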
Initialize the representation.
- Parameters:
triples_factory (CoreTriplesFactory) – the factory comprising the training triples used for message passing
layers (Union[str, None, Type[None], Sequence[Union[str, None, Type[None]]]]) – the message passing layer(s) or hints thereof
layers_kwargs (Union[Mapping[str, Any], None, Sequence[Optional[Mapping[str, Any]]]]) – additional keyword-based parameters passed to the layers upon instantiation
base (Union[str, Representation, Type[Representation], None]) – the base representations for entities, or a hint thereof
base_kwargs (Optional[Mapping[str, Any]]) – additional keyword-based parameters passed to the base representations upon instantiation
shape (Union[int, Sequence[int], None]) – the output shape. Defaults to the base representation shape. Has to match the output shape of the last message passing layer.
max_id (Optional[int]) – the number of representations. If provided, has to match the base representation's max_id.
activations (Union[str, Module, Type[Module], None, Sequence[Union[str, Module, Type[Module], None]]]) – the activation(s), or hints thereof
activations_kwargs (Union[Mapping[str, Any], None, Sequence[Optional[Mapping[str, Any]]]]) – additional keyword-based parameters passed to the activations upon instantiation
restrict_k_hop (bool) – whether to restrict message passing to the k-hop neighborhood of the requested indices when only a subset of indices is requested. This utilizes torch_geometric.utils.k_hop_subgraph().
kwargs – additional keyword-based parameters passed to Representation.__init__()
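Restricting to the k-hop neighborhood loses nothing for the requested nodes: with k message passing layers, a node's output can only depend on nodes at most k hops upstream along the edges. The sketch below is a simplified plain-Python version of that selection, with messages flowing from sources to targets; `k_hop_subset` is a hypothetical helper, not `torch_geometric.utils.k_hop_subgraph()` itself, which operates on tensors and can optionally relabel nodes.

```python
from typing import List, Set, Tuple

EdgeIndex = Tuple[List[int], List[int]]  # (sources, targets)


def k_hop_subset(edge_index: EdgeIndex, targets: List[int], num_hops: int) -> Tuple[List[int], List[bool]]:
    """Collect all nodes whose messages can reach `targets` within `num_hops` steps.

    Returns the sorted node subset and a mask over all edges that keeps
    exactly those edges whose two endpoints lie inside the subset.
    """
    sources, dests = edge_index
    subset: Set[int] = set(targets)
    frontier: Set[int] = set(targets)
    for _ in range(num_hops):
        # one hop further upstream: sources of edges pointing into the frontier
        frontier = {s for s, d in zip(sources, dests) if d in frontier}
        subset |= frontier
    edge_mask = [s in subset and d in subset for s, d in zip(sources, dests)]
    return sorted(subset), edge_mask
```

For the chain 4→0→1→2→3, requesting node 3 with two hops keeps nodes 1, 2 and 3 and only the edges 1→2 and 2→3.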
- Raises:
ImportError – if PyTorch Geometric is not installed
ValueError – if the numbers of activations and message passing layers do not match (after input normalization)
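The ValueError above refers to input normalization of the `layers` and `activations` hints. Below is a minimal sketch of one plausible normalization. It assumes that a single hint (including None) is wrapped into a one-element list and that a single activation hint is repeated for every layer. `as_list` and `normalize_hints` are hypothetical, the hint strings are placeholders, and PyKEEN's actual rules may differ in detail.

```python
from typing import Any, List, Tuple


def as_list(hint: Any) -> List[Any]:
    """Wrap a single hint (including None) into a one-element list; copy sequences."""
    if isinstance(hint, (list, tuple)):
        return list(hint)
    return [hint]


def normalize_hints(layers: Any, activations: Any) -> Tuple[List[Any], List[Any]]:
    """Return equally long lists of layer and activation hints."""
    layer_list = as_list(layers)
    activation_list = as_list(activations)
    # assumption: a single activation hint is repeated for every layer
    if len(activation_list) == 1:
        activation_list = activation_list * len(layer_list)
    if len(activation_list) != len(layer_list):
        raise ValueError(
            f"got {len(activation_list)} activations for {len(layer_list)} message passing layers"
        )
    return layer_list, activation_list
```

Broadcasting a single activation keeps the common case short, while an explicit sequence of the wrong length is rejected instead of being silently truncated.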
Methods Summary
pass_messages(x, edge_index[, edge_mask]) – Perform the message passing steps.
Methods Documentation
- abstract pass_messages(x, edge_index, edge_mask=None)
Perform the message passing steps.
- Parameters:
x (FloatTensor) – shape: (n, d_in); the base entity representations
edge_index (LongTensor) – shape: (2, num_selected_edges); the edge index (which may already be a selection of the full edge index)
edge_mask (Optional[BoolTensor]) – shape: (num_edges,); a mask over all edges marking the selected ones if message passing is restricted, else None
- Return type:
FloatTensor
- Returns:
shape: (n, d_out); the enriched entity representations
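The `edge_mask` matters for implementations that keep per-edge attributes aligned with the full edge index, such as relation types. The sketch below shows one way a concrete `pass_messages` could use it. When message passing is restricted, the mask selects the matching attribute entries so that they line up with the already-filtered `edge_index`. The class `TypedSumPassing`, its `edge_type` and `relation_scale` attributes, and the scaled-sum update are hypothetical. Plain lists stand in for tensors, and node relabeling is ignored for brevity.

```python
from typing import List, Optional, Sequence, Tuple

Vector = List[float]
EdgeIndex = Tuple[List[int], List[int]]  # (sources, targets)


class TypedSumPassing:
    """Hypothetical pass_messages: add neighbour vectors scaled by a per-relation weight."""

    def __init__(self, edge_type: Sequence[int], relation_scale: Sequence[float]):
        self.edge_type = list(edge_type)  # one relation id per edge of the FULL graph
        self.relation_scale = list(relation_scale)  # one scalar weight per relation id

    def pass_messages(
        self, x: List[Vector], edge_index: EdgeIndex, edge_mask: Optional[List[bool]] = None
    ) -> List[Vector]:
        # edge_index may already be a selection; align per-edge attributes with it
        if edge_mask is None:
            edge_type = self.edge_type
        else:
            edge_type = [t for t, keep in zip(self.edge_type, edge_mask) if keep]
        out = [list(v) for v in x]  # keep each node's own vector (self-loop)
        for s, d, r in zip(edge_index[0], edge_index[1], edge_type):
            weight = self.relation_scale[r]
            out[d] = [a + weight * b for a, b in zip(out[d], x[s])]
        return out
```

Without the mask, the filtered edges would be zipped with the wrong relation ids, silently applying incorrect weights.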