MessagePassingRepresentation

class MessagePassingRepresentation(triples_factory: CoreTriplesFactory, layers: str | None | type[None] | Sequence[str | None | type[None]], layers_kwargs: Mapping[str, Any] | None | Sequence[Mapping[str, Any] | None] = None, base: str | Representation | type[Representation] | None = None, base_kwargs: Mapping[str, Any] | None = None, max_id: int | None = None, shape: int | Sequence[int] | None = None, activations: str | Module | type[Module] | None | Sequence[str | Module | type[Module] | None] = None, activations_kwargs: Mapping[str, Any] | None | Sequence[Mapping[str, Any] | None] = None, restrict_k_hop: bool = False, **kwargs)

Bases: Representation, ABC

An abstract representation class utilizing PyTorch Geometric message passing layers.

It comprises:
  • base (entity) representations, which can also be passed as hints

  • a sequence of message passing layers, which the abstract MessagePassingRepresentation.pass_messages() uses to enrich the base representations with neighborhood information.

  • a sequence of activation layers in between the message passing layers.

  • an edge_index buffer, which stores the edge index and is moved to the device alongside the module.
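The four components can be pictured with the following minimal sketch. It is written in plain PyTorch without PyKEEN or PyTorch Geometric, and the class names (ToyMessagePassingRepresentation, IgnoreGraph) as well as the linear "layers" are hypothetical placeholders, not the library's implementation. It assumes the PyTorch Geometric convention of an edge index with shape (2, num_edges), and mainly shows how register_buffer keeps edge_index on the same device as the module without making it trainable.

```python
from abc import ABC, abstractmethod

import torch
from torch import nn


class ToyMessagePassingRepresentation(nn.Module, ABC):
    """Hypothetical plain-PyTorch stand-in with the four components listed above."""

    def __init__(self, num_entities: int, dim: int, edge_index: torch.Tensor, num_layers: int = 2):
        super().__init__()
        # base (entity) representations
        self.base = nn.Embedding(num_entities, dim)
        # placeholder "message passing layers": plain linear maps
        self.layers = nn.ModuleList([nn.Linear(dim, dim) for _ in range(num_layers)])
        # one activation per layer, so both sequences have the same length
        self.activations = nn.ModuleList([nn.ReLU() for _ in range(num_layers)])
        # a buffer follows .to(device) like a parameter, but is not trained
        self.register_buffer("edge_index", edge_index)

    @abstractmethod
    def pass_messages(self, x: torch.Tensor, edge_index: torch.Tensor) -> torch.Tensor:
        """Enrich the base representations x with neighborhood information."""
        raise NotImplementedError

    def forward(self) -> torch.Tensor:
        # start from the base representations of all entities
        return self.pass_messages(self.base.weight, self.edge_index)


class IgnoreGraph(ToyMessagePassingRepresentation):
    """Trivial subclass that skips aggregation; it only makes the class instantiable."""

    def pass_messages(self, x: torch.Tensor, edge_index: torch.Tensor) -> torch.Tensor:
        for layer, activation in zip(self.layers, self.activations):
            x = activation(layer(x))
        return x


# row 0: source entities, row 1: target entities
edge_index = torch.tensor([[0, 1, 2], [1, 2, 0]])
rep = IgnoreGraph(num_entities=3, dim=4, edge_index=edge_index)
print(rep().shape)  # torch.Size([3, 4])
```

Because edge_index is a buffer rather than a parameter, it is saved in the state dict and moved by rep.to(device), but an optimizer built from rep.parameters() never updates it.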

Initialize the representation.

Parameters:
  • triples_factory (CoreTriplesFactory) – the factory comprising the training triples used for message passing

  • layers (str | None | type[None] | Sequence[str | None | type[None]]) – the message passing layer(s), or hints thereof

  • layers_kwargs (Mapping[str, Any] | None | Sequence[Mapping[str, Any] | None]) – additional keyword-based parameters passed to the layers upon instantiation

  • base (str | Representation | type[Representation] | None) – the base representations for entities, or a hint thereof

  • base_kwargs (Mapping[str, Any] | None) – additional keyword-based parameters passed to the base representations upon instantiation

  • shape (int | Sequence[int] | None) – the output shape. Defaults to the base representation’s shape. Has to match the output shape of the last message passing layer.

  • max_id (int | None) – the number of representations. If provided, has to match the base representation’s max_id.

  • activations (str | Module | type[Module] | None | Sequence[str | Module | type[Module] | None]) – the activation(s), or hints thereof

  • activations_kwargs (Mapping[str, Any] | None | Sequence[Mapping[str, Any] | None]) – additional keyword-based parameters passed to the activations upon instantiation

  • restrict_k_hop (bool) – whether to restrict message passing to the k-hop neighborhood when only some indices are requested. This utilizes torch_geometric.utils.k_hop_subgraph().

  • kwargs – additional keyword-based parameters passed to Representation.__init__()
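The idea behind restrict_k_hop can be illustrated with a small plain-PyTorch sketch. The helper k_hop_edge_mask is hypothetical and is not torch_geometric.utils.k_hop_subgraph() itself; it assumes an edge index of shape (2, num_edges) with messages flowing from source (row 0) to target (row 1), and it assumes that one hop per message passing layer is the relevant radius. Starting from the requested entities, it repeatedly adds the sources of edges pointing into the current node set, then keeps every edge whose two endpoints lie in that set. The result is a boolean mask of shape (num_edges,).

```python
import torch


def k_hop_edge_mask(indices: torch.Tensor, num_hops: int, edge_index: torch.Tensor, num_nodes: int) -> torch.Tensor:
    """Hypothetical helper: boolean mask over edges inside the num_hops neighborhood of indices.

    Assumes edge_index has shape (2, num_edges), row 0 = source, row 1 = target,
    and that messages flow from source to target.
    """
    source, target = edge_index
    node_mask = torch.zeros(num_nodes, dtype=torch.bool)
    node_mask[indices] = True
    for _ in range(num_hops):
        # edges ending inside the current node set deliver messages into it ...
        incoming = node_mask[target]
        # ... so their source nodes are one hop further out
        node_mask[source[incoming]] = True
    # keep edges whose endpoints both lie in the collected neighborhood
    return node_mask[source] & node_mask[target]


# a chain 0 -> 1 -> 2 -> 3 plus an unrelated edge 5 -> 4
edge_index = torch.tensor([[0, 1, 2, 5], [1, 2, 3, 4]])
mask = k_hop_edge_mask(torch.tensor([3]), num_hops=2, edge_index=edge_index, num_nodes=6)
print(mask.tolist())  # [False, True, True, False]
print(edge_index[:, mask].tolist())  # [[1, 2], [2, 3]]
```

With two layers, entity 3 can only receive information that travels along 1 -> 2 -> 3, so the edges 0 -> 1 and 5 -> 4 can be skipped when only entity 3 is requested.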

Raises:
  • ImportError – if PyTorch Geometric is not installed

  • ValueError – if the numbers of activations and message passing layers do not match (after input normalization)
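The phrase "after input normalization" suggests that single hints are expanded to one entry per layer before the lengths are compared. The following is a hypothetical sketch of such a normalization, not PyKEEN's actual code. It assumes that a single value (None, a string, a class, or a mapping) is reused for every layer, while explicit sequences must already have one entry per layer. Unlike the documented check, which concerns activations, this sketch checks the kwargs lists too. The hint strings "gcn" and "relu" are illustrative only.

```python
from collections.abc import Sequence
from typing import Any


def upgrade_to_list(value: Any) -> list:
    """Wrap a single hint (None, a string, a class, a mapping, ...) into a one-element list."""
    if isinstance(value, Sequence) and not isinstance(value, str):
        return list(value)
    return [value]


def normalize_layer_specs(layers, layers_kwargs=None, activations=None, activations_kwargs=None):
    """Hypothetical sketch of the input normalization; not PyKEEN's actual code."""
    layers = upgrade_to_list(layers)
    num_layers = len(layers)
    normalized = []
    for name, value in [
        ("layers_kwargs", layers_kwargs),
        ("activations", activations),
        ("activations_kwargs", activations_kwargs),
    ]:
        values = upgrade_to_list(value)
        # assumption: a single value is reused for every layer
        if len(values) == 1:
            values = values * num_layers
        if len(values) != num_layers:
            raise ValueError(f"got {len(values)} {name} for {num_layers} message passing layers")
        normalized.append(values)
    return (layers, *normalized)


layers, layers_kwargs, activations, activations_kwargs = normalize_layer_specs(
    ["gcn", "gcn"], layers_kwargs={"in_channels": 8, "out_channels": 8}, activations="relu"
)
print(activations)  # ['relu', 'relu']
```

Passing three layers together with two activations, e.g. normalize_layer_specs(["gcn"] * 3, activations=["relu", "tanh"]), raises a ValueError in this sketch, because a two-element sequence is neither a single value nor one entry per layer.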

Methods Summary

pass_messages(x, edge_index[, edge_mask])

Perform the message passing steps.

Methods Documentation

abstract pass_messages(x: Tensor, edge_index: Tensor, edge_mask: Tensor | None = None) → Tensor

Perform the message passing steps.

Parameters:
  • x (Tensor) – shape: (n, d_in) the base entity representations

  • edge_index (Tensor) – shape: (2, num_selected_edges) the edge index (which may already be a selection of the full edge index)

  • edge_mask (Tensor | None) – shape: (num_edges,) a boolean mask over the full edge index, if message passing is restricted

Returns:

shape: (n, d_out) the enriched entity representations

Return type:

Tensor
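To make the contract of pass_messages concrete, here is a hypothetical, free-standing function with the same inputs and output shapes. It uses self-loops plus a weighted sum of incoming messages, which is an arbitrary choice for illustration, not what any particular subclass does. It assumes the PyTorch Geometric convention of a (2, num_edges) edge index. It also assumes, as one plausible reading of the parameter descriptions, that edge_mask exists so an implementation can restrict per-edge attributes (here a hypothetical edge_weight tensor defined over all edges) in the same way edge_index was already restricted.

```python
from typing import Optional

import torch
from torch import nn


def weighted_sum_pass_messages(
    x: torch.Tensor,
    edge_index: torch.Tensor,
    layers: nn.ModuleList,
    activations: nn.ModuleList,
    edge_weight: torch.Tensor,
    edge_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Hypothetical pass_messages-style function with self-loops and weighted sum aggregation.

    x: (n, d_in); edge_index: (2, num_selected_edges); edge_weight: (num_edges,) over ALL edges;
    edge_mask: (num_edges,) boolean mask that produced edge_index, or None.
    """
    if edge_mask is not None:
        # edge_index is already restricted; restrict the per-edge attributes the same way
        edge_weight = edge_weight[edge_mask]
    source, target = edge_index
    for layer, activation in zip(layers, activations):
        # one message per edge: the source representation scaled by the edge weight
        messages = x[source] * edge_weight.unsqueeze(-1)
        # self-loop plus the sum of incoming messages per target entity (out-of-place)
        aggregated = x.index_add(0, target, messages)
        x = activation(layer(aggregated))
    return x


edge_index = torch.tensor([[0, 1, 2], [1, 2, 0]])
x = torch.randn(3, 4)
out = weighted_sum_pass_messages(
    x, edge_index, nn.ModuleList([nn.Linear(4, 2)]), nn.ModuleList([nn.ReLU()]), edge_weight=torch.ones(3)
)
print(out.shape)  # torch.Size([3, 2])
```

The output keeps one row per entity (n) while the last layer determines d_out, matching the documented shapes (n, d_in) → (n, d_out).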