MessagePassingRepresentation

class MessagePassingRepresentation(triples_factory: CoreTriplesFactory, layers: str | MessagePassing | type[MessagePassing] | None | Sequence[str | MessagePassing | type[MessagePassing] | None], layers_kwargs: Mapping[str, Any] | None | Sequence[Mapping[str, Any] | None] = None, base: str | Representation | type[Representation] | None = None, base_kwargs: Mapping[str, Any] | None = None, max_id: int | None = None, shape: int | Sequence[int] | None = None, activations: str | Module | type[Module] | None | Sequence[str | Module | type[Module] | None] = None, activations_kwargs: Mapping[str, Any] | None | Sequence[Mapping[str, Any] | None] = None, restrict_k_hop: bool = False, **kwargs)

Bases: Representation, ABC

An abstract representation class utilizing PyTorch Geometric message passing layers.

It comprises:
  • Base (entity) representations, which can also be passed as hints.

  • A sequence of message passing layers. They are used by the abstract method pass_messages() to enrich the base representations with neighborhood information.

  • A sequence of activation layers in between the message passing layers.

  • An edge_index buffer, which stores the edge index and is moved to the device alongside the module.
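The edge_index buffer follows PyTorch Geometric's convention of an index with shape (2, num_edges). Source nodes are in the first row and target nodes in the second. The sketch below uses plain Python lists instead of tensors to show how such an index could be derived from ID-based (head, relation, tail) triples. The helper name build_edge_index is hypothetical. Two choices are assumptions for illustration, not a description of the class's code: the relation column is dropped, and edges point from head to tail.

```python
from typing import List, Sequence, Tuple

# (head_id, relation_id, tail_id), as in ID-based training triples
Triple = Tuple[int, int, int]


def build_edge_index(mapped_triples: Sequence[Triple]) -> List[List[int]]:
    """Hypothetical helper: return [sources, targets], a list analogue of a (2, E) tensor.

    Assumption: each triple becomes one edge from head to tail; the relation is dropped.
    """
    sources = [head for head, _, _ in mapped_triples]
    targets = [tail for _, _, tail in mapped_triples]
    return [sources, targets]


triples = [(0, 0, 1), (1, 1, 2), (2, 0, 0)]
edge_index = build_edge_index(triples)
print(edge_index)  # [[0, 1, 2], [1, 2, 0]]
```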

Initialize the representation.

Parameters:
  • triples_factory (CoreTriplesFactory) – The factory comprising the training triples used for message passing.

  • layers (OneOrManyHintOrType[MessagePassing]) – The message passing layer(s) or hints thereof.

  • layers_kwargs (OneOrManyOptionalKwargs) – Additional keyword-based parameters passed to the layers upon instantiation.

  • base (HintOrType[Representation]) – The base representations for entities, or a hint thereof.

  • base_kwargs (OptionalKwargs) – Additional keyword-based parameters passed to the base representations upon instantiation.

  • shape (tuple[int, ...]) – The output shape. Defaults to the base representation shape. Has to match the output shape of the last message passing layer.

  • max_id (int) – The number of representations. If provided, has to match the base representation’s max_id.

  • activations (OneOrManyHintOrType[nn.Module]) – The activation(s), or hints thereof.

  • activations_kwargs (OneOrManyOptionalKwargs) – Additional keyword-based parameters passed to the activations upon instantiation.

  • restrict_k_hop (bool) – Whether to restrict message passing to the k-hop neighborhood of the requested indices when only some indices are requested. This uses torch_geometric.utils.k_hop_subgraph().

  • kwargs – Additional keyword-based parameters passed to Representation.
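restrict_k_hop relies on a property of standard message passing. After k rounds, a node's output can only depend on nodes at most k hops away, so edges outside that neighborhood can be masked out. The plain-Python sketch below is modelled on the idea behind torch_geometric.utils.k_hop_subgraph(). It assumes source-to-target flow and keeps every edge whose two endpoints both lie in the subset. It is a simplified stand-in rather than PyG's implementation, and the name k_hop_edge_mask is hypothetical. Setting k equal to the number of message passing layers is also an assumption made here.

```python
from typing import List, Sequence, Set, Tuple


def k_hop_edge_mask(
    node_idx: Sequence[int],
    num_hops: int,
    edge_index: Sequence[Sequence[int]],
) -> Tuple[Set[int], List[bool]]:
    """Return the k-hop node subset around node_idx and a boolean mask over all edges."""
    sources, targets = edge_index
    subset = set(node_idx)
    frontier = set(node_idx)
    for _ in range(num_hops):
        # messages flow source -> target, so step backwards along incoming edges
        frontier = {s for s, t in zip(sources, targets) if t in frontier}
        subset |= frontier
    # keep every edge whose endpoints both lie inside the subset
    edge_mask = [s in subset and t in subset for s, t in zip(sources, targets)]
    return subset, edge_mask


# chain 4 -> 0 -> 1 -> 2 -> 3, stored as [sources, targets]
edge_index = [[0, 1, 2, 4], [1, 2, 3, 0]]
subset, edge_mask = k_hop_edge_mask([3], num_hops=2, edge_index=edge_index)
print(sorted(subset))  # [1, 2, 3]
print(edge_mask)  # [False, True, True, False]
```

Only the masked edges would then take part in message passing for the requested entity 3. With two hops, entity 0 and the edge 4 -> 0 cannot influence its output.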

Raises:
  • ImportError – If PyTorch Geometric is not installed.

  • ValueError – If the number of activations does not match the number of message passing layers (after input normalization).
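The ValueError above is raised only after one-or-many inputs have been normalized. The sketch below illustrates that kind of normalization in plain Python. It assumes that a single activation (or None) is reused after every layer and that any other count must equal the number of layers. The helper names upgrade_to_list and normalize_activations are hypothetical and are not part of PyKEEN or class_resolver.

```python
from typing import Any, List


def upgrade_to_list(value: Any) -> List[Any]:
    """Hypothetical helper: wrap a single hint (or None) in a list; copy lists/tuples."""
    if isinstance(value, (list, tuple)):
        return list(value)
    return [value]


def normalize_activations(layers: Any, activations: Any) -> List[Any]:
    """Pair one activation with each layer, broadcasting a single activation."""
    layers = upgrade_to_list(layers)
    activations = upgrade_to_list(activations)
    if len(activations) == 1:
        # assumption: a single (or missing) activation is reused after every layer
        activations = activations * len(layers)
    if len(activations) != len(layers):
        raise ValueError(f"Got {len(activations)} activations for {len(layers)} layers.")
    return activations


print(normalize_activations(["gcn", "gcn"], "relu"))  # ['relu', 'relu']
try:
    normalize_activations(["gcn", "gcn", "gcn"], ["relu", "tanh"])
except ValueError as error:
    print(error)  # Got 2 activations for 3 layers.
```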

Note

Three resolvers are used in this function:

  • The parameter pair (layers, layers_kwargs) is used for pykeen.nn.pyg.layer_resolver

  • The parameter pair (base, base_kwargs) is used for pykeen.nn.representation_resolver

  • The parameter pair (activations, activations_kwargs) is used for class_resolver.contrib.torch.activation_resolver

An explanation of resolvers and how to use them is given in https://class-resolver.readthedocs.io/en/latest/.
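Each (hint, kwargs) pair above follows the resolver pattern. A hint may be a name, a class, or an already-built instance, and the kwargs are used only when something has to be instantiated. The toy resolver below illustrates that pattern in plain Python. It is not the class_resolver API, which handles more cases such as defaults. The classes ToyResolver, ReLU, and LeakyReLU here are made up for illustration.

```python
from typing import Any, Dict, Iterable, Mapping, Optional


class ToyResolver:
    """Hypothetical, minimal resolver: turns a hint plus kwargs into an instance."""

    def __init__(self, classes: Iterable[type]) -> None:
        self.lookup: Dict[str, type] = {self.normalize(cls.__name__): cls for cls in classes}

    @staticmethod
    def normalize(name: str) -> str:
        # "LeakyReLU", "leaky_relu" and "leaky-relu" all map to "leakyrelu"
        return name.lower().replace("_", "").replace("-", "")

    def make(self, hint: Any, kwargs: Optional[Mapping[str, Any]] = None) -> Any:
        kwargs = dict(kwargs or {})
        if isinstance(hint, str):
            return self.lookup[self.normalize(hint)](**kwargs)  # name -> class -> instance
        if isinstance(hint, type):
            return hint(**kwargs)  # class -> instance
        return hint  # already an instance: used as-is, kwargs are ignored


class ReLU:
    def __call__(self, x: float) -> float:
        return max(0.0, x)


class LeakyReLU:
    def __init__(self, negative_slope: float = 0.01) -> None:
        self.negative_slope = negative_slope

    def __call__(self, x: float) -> float:
        return x if x >= 0 else self.negative_slope * x


resolver = ToyResolver([ReLU, LeakyReLU])
activation = resolver.make("leaky_relu", {"negative_slope": 0.1})
print(activation(-2.0))  # -0.2
print(resolver.make(ReLU)(-2.0))  # 0.0
```

Accepting names, classes, and instances through one code path is what allows parameters such as layers or activations to be given either as strings or as ready-made modules.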

Methods Summary

pass_messages(x, edge_index[, edge_mask])

Perform the message passing steps.

Methods Documentation

abstractmethod pass_messages(x: Tensor, edge_index: Tensor, edge_mask: Tensor | None = None) -> Tensor

Perform the message passing steps.

Parameters:
  • x (Tensor) – shape: (n, d_in) the base entity representations

  • edge_index (Tensor) – shape: (2, num_selected_edges) the edge index (which may already be a selection of the full edge index)

  • edge_mask (Tensor | None) – shape: (num_edges,) an edge mask if message passing is restricted

Returns:

shape: (n, d_out) the enriched entity representations

Return type:

Tensor
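A concrete subclass implements pass_messages() against the contract above. x has one row per entity, and edge_index may already be restricted. edge_mask, when given, refers to the full edge list, so any other per-edge data has to be selected with it to stay aligned with edge_index. The plain-Python sketch below illustrates that contract, with lists standing in for tensors and d_in = d_out = 1 for brevity. The class ToyWeightedMessagePassing, its edge_weight attribute, and the weighted-mean layer are hypothetical.

```python
from typing import List, Optional, Sequence

Matrix = List[List[float]]  # shape: (n, d)
EdgeIndex = Sequence[Sequence[int]]  # [sources, targets], shape: (2, num_selected_edges)


def weighted_mean(x: Matrix, edge_index: EdgeIndex, weights: Sequence[float]) -> Matrix:
    """Toy layer: each node averages its own vector with weighted incoming messages."""
    sums = [row[:] for row in x]  # self-loop with weight 1
    totals = [1.0] * len(x)
    for source, target, weight in zip(edge_index[0], edge_index[1], weights):
        sums[target] = [a + weight * b for a, b in zip(sums[target], x[source])]
        totals[target] += weight
    return [[value / total for value in row] for row, total in zip(sums, totals)]


def relu(x: Matrix) -> Matrix:
    return [[max(0.0, value) for value in row] for row in x]


class ToyWeightedMessagePassing:
    """Hypothetical stand-in for a subclass that implements pass_messages()."""

    def __init__(self, edge_weight: Sequence[float], num_layers: int = 2) -> None:
        self.edge_weight = list(edge_weight)  # one weight per edge of the FULL graph
        self.num_layers = num_layers

    def pass_messages(
        self, x: Matrix, edge_index: EdgeIndex, edge_mask: Optional[Sequence[bool]] = None
    ) -> Matrix:
        weights = self.edge_weight
        if edge_mask is not None:
            # edge_index is already restricted; select the matching per-edge weights
            weights = [w for w, keep in zip(weights, edge_mask) if keep]
        for _ in range(self.num_layers):
            x = relu(weighted_mean(x, edge_index, weights))  # layer, then activation
        return x


model = ToyWeightedMessagePassing(edge_weight=[2.0, 0.5, 1.0], num_layers=1)
# full edges 0->1, 1->2, 2->0; the mask drops the first one
out = model.pass_messages(
    [[1.0], [3.0], [5.0]], edge_index=[[1, 2], [2, 0]], edge_mask=[False, True, True]
)
print([round(v, 3) for (v,) in out])  # [3.0, 3.0, 4.333]
```

If the weights were not filtered with edge_mask, the first selected edge (1 -> 2) would wrongly pick up the weight 2.0 of the dropped edge 0 -> 1.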