MessagePassingRepresentation

class MessagePassingRepresentation(triples_factory, layers, layers_kwargs=None, base=None, base_kwargs=None, max_id=None, shape=None, activations=None, activations_kwargs=None, restrict_k_hop=False, **kwargs)

Bases: Representation, ABC

An abstract representation class utilizing PyTorch Geometric message passing layers.

It comprises:
  • base (entity) representations, which can also be passed as hints

  • a sequence of message passing layers. They are used in the abstract method MessagePassingRepresentation.pass_messages() to enrich the base representations with neighborhood information.

  • a sequence of activation layers in between the message passing layers.

  • an edge_index buffer, which stores the edge index and is moved to the device alongside the module.
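The composition listed above can be sketched in plain PyTorch. This is a minimal, hypothetical stand-in (ToyMessagePassingRepresentation is not part of PyKEEN): it replaces PyTorch Geometric layers with nn.Linear transforms of mean-aggregated neighbor features, and it assumes edge_index has shape (2, num_edges) with rows (source, target). It only shows how the base representations, the layer/activation pairs, and the edge_index buffer fit together.

```python
import torch
from torch import nn


class ToyMessagePassingRepresentation(nn.Module):
    """Hypothetical sketch, not PyKEEN's implementation.

    Linear layers over mean-aggregated neighbors stand in for
    PyTorch Geometric message passing layers.
    """

    def __init__(self, num_entities: int, dim: int, edge_index: torch.Tensor, num_layers: int = 2):
        super().__init__()
        # base (entity) representations
        self.base = nn.Embedding(num_entities, dim)
        # message passing layers (here: linear transforms of aggregated messages)
        self.layers = nn.ModuleList(nn.Linear(dim, dim) for _ in range(num_layers))
        # one activation per message passing layer, applied after that layer
        self.activations = nn.ModuleList(nn.ReLU() for _ in range(num_layers))
        # registered as a buffer: not trainable, but moved by .to(device) with the module
        self.register_buffer("edge_index", edge_index)

    def pass_messages(self, x: torch.Tensor, edge_index: torch.Tensor) -> torch.Tensor:
        source, target = edge_index
        # in-degree of each node, clamped so isolated nodes do not divide by zero
        degree = torch.bincount(target, minlength=x.shape[0]).clamp_min(1).unsqueeze(-1)
        for layer, activation in zip(self.layers, self.activations):
            # sum the source representations arriving at each target, then average
            aggregated = torch.zeros_like(x).index_add_(0, target, x[source]) / degree
            x = activation(layer(aggregated))
        return x

    def forward(self) -> torch.Tensor:
        # enrich all base entity representations with neighborhood information
        return self.pass_messages(self.base.weight, self.edge_index)


# usage: a directed 3-cycle 0 -> 1 -> 2 -> 0
edge_index = torch.tensor([[0, 1, 2], [1, 2, 0]])
model = ToyMessagePassingRepresentation(num_entities=3, dim=4, edge_index=edge_index)
print(model().shape)  # torch.Size([3, 4])
print("edge_index" in dict(model.named_buffers()))  # True
```

Registering edge_index as a buffer rather than a plain attribute is what lets model.to(device) move it together with the parameters, while keeping it out of the optimizer's parameter list.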

Initialize the representation.

Parameters:
Raises:
  • ImportError – if PyTorch Geometric is not installed

  • ValueError – if the number of activations does not match the number of message passing layers (after input normalization)
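The ValueError above is raised after the activations argument has been normalized. The sketch below illustrates one plausible normalization and the resulting length check. The helper normalize_activations and its rules (None becomes one identity per layer, a single module is reused for every layer) are assumptions for illustration only; PyKEEN's actual handling of activations and activations_kwargs may differ.

```python
from typing import List, Sequence, Union

from torch import nn


def normalize_activations(
    activations: Union[None, nn.Module, Sequence[nn.Module]],
    num_layers: int,
) -> List[nn.Module]:
    """Hypothetical input normalization; PyKEEN's actual rules may differ."""
    if activations is None:
        # no activation given (assumption): fall back to the identity after every layer
        return [nn.Identity() for _ in range(num_layers)]
    if isinstance(activations, nn.Module):
        # a single activation module is reused for every layer (assumption)
        return [activations] * num_layers
    activations = list(activations)
    if len(activations) != num_layers:
        # mismatch after normalization -> ValueError, as documented above
        raise ValueError(
            f"Got {len(activations)} activations for {num_layers} message passing layers."
        )
    return activations
```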

Methods Summary

pass_messages(x, edge_index[, edge_mask])

Perform the message passing steps.

Methods Documentation

abstract pass_messages(x, edge_index, edge_mask=None)

Perform the message passing steps.

Parameters:
  • x (FloatTensor) – shape: (n, d_in) the base entity representations

  • edge_index (LongTensor) – shape: (2, num_selected_edges) the edge index (which may already be a selection of the full edge index)

  • edge_mask (Optional[BoolTensor]) – shape: (num_edges,) a mask over the full edge index marking the selected edges, if message passing is restricted

Return type:

FloatTensor

Returns:

shape: (n, d_out) the enriched entity representations
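To make the shape contract of pass_messages concrete, here is a minimal sketch of a relation-aware implementation written as a standalone function. The name relational_pass_messages and the extra arguments edge_type (one relation ID per edge of the full graph) and relation_weights are hypothetical; a real subclass would presumably keep such per-edge data as attributes. The sketch assumes edge_index has already been restricted to the selected edges and uses edge_mask only to pick the matching entries of the per-edge data.

```python
from typing import Optional

import torch


def relational_pass_messages(
    x: torch.Tensor,                   # shape: (n, d_in)
    edge_index: torch.Tensor,          # shape: (2, num_selected_edges)
    edge_type: torch.Tensor,           # shape: (num_edges,), hypothetical per-edge relation IDs
    relation_weights: torch.Tensor,    # shape: (num_relations, d_in, d_out), hypothetical
    edge_mask: Optional[torch.Tensor] = None,  # shape: (num_edges,), bool
) -> torch.Tensor:
    """Hypothetical relation-aware mean aggregation; not PyKEEN's implementation."""
    # edge_index is already a selection; align the full-graph edge data with it
    if edge_mask is not None:
        edge_type = edge_type[edge_mask]
    source, target = edge_index
    # transform each source representation with its relation-specific matrix
    messages = torch.bmm(x[source].unsqueeze(1), relation_weights[edge_type]).squeeze(1)
    # sum messages per target node, then divide by the in-degree (mean aggregation)
    out = x.new_zeros(x.shape[0], relation_weights.shape[-1])
    out.index_add_(0, target, messages)
    degree = torch.bincount(target, minlength=x.shape[0]).clamp_min(1).unsqueeze(-1)
    return out / degree  # shape: (n, d_out)


# usage: restrict message passing to the first two edges of a 4-cycle
torch.manual_seed(0)
x = torch.randn(4, 3)                                    # n = 4, d_in = 3
full_edge_index = torch.tensor([[0, 1, 2, 3], [1, 2, 3, 0]])
edge_type = torch.tensor([0, 1, 0, 1])
relation_weights = torch.randn(2, 3, 2)                  # d_out = 2
edge_mask = torch.tensor([True, True, False, False])
out = relational_pass_messages(x, full_edge_index[:, edge_mask], edge_type, relation_weights, edge_mask)
print(out.shape)  # torch.Size([4, 2])
```

Nodes that receive no selected edge (here 0 and 3) end up with zero vectors, since their clamped degree is 1 and nothing was added to them.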