SimpleMessagePassingRepresentation

class SimpleMessagePassingRepresentation(triples_factory: CoreTriplesFactory, layers: str | None | type[None] | Sequence[str | None | type[None]], layers_kwargs: Mapping[str, Any] | None | Sequence[Mapping[str, Any] | None] = None, base: str | Representation | type[Representation] | None = None, base_kwargs: Mapping[str, Any] | None = None, max_id: int | None = None, shape: int | Sequence[int] | None = None, activations: str | Module | type[Module] | None | Sequence[str | Module | type[Module] | None] = None, activations_kwargs: Mapping[str, Any] | None | Sequence[Mapping[str, Any] | None] = None, restrict_k_hop: bool = False, **kwargs)

Bases: MessagePassingRepresentation

A representation with message passing not making use of the relation type.

By using only the connectivity information and ignoring the relation type, this module can utilize message passing layers defined on uni-relational graphs, which make up the majority of the layers available in the PyTorch Geometric library.
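To illustrate what "using only the connectivity information" means, the following plain-Python sketch drops the relation column from ID-based (head, relation, tail) triples and keeps a PyTorch Geometric-style edge index with two rows (sources and targets). The helper name `to_edge_index` is hypothetical, and PyKEEN itself works on tensors rather than Python lists; this is an illustration of the idea, not the library's code.

```python
from typing import List, Sequence, Tuple

Triple = Tuple[int, int, int]  # (head_id, relation_id, tail_id)


def to_edge_index(mapped_triples: Sequence[Triple]) -> List[List[int]]:
    """Hypothetical helper: discard relation IDs, keep only head -> tail edges.

    Returns a two-row edge index: row 0 holds source (head) IDs and
    row 1 holds target (tail) IDs, mirroring PyG's (2, num_edges) layout.
    """
    sources = [head for head, _relation, _tail in mapped_triples]
    targets = [tail for _head, _relation, tail in mapped_triples]
    return [sources, targets]


triples = [(0, 0, 1), (1, 2, 2), (0, 1, 2)]
edge_index = to_edge_index(triples)
print(edge_index)  # [[0, 1, 0], [1, 2, 2]]
```

Two triples that differ only in their relation become identical edges here, which is exactly the information a uni-relational layer never sees.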

Here, we create a two-layer torch_geometric.nn.conv.GCNConv on top of an Embedding.

"""An example for using simple message passing, ignoring edge types.

We create a two-layer GCN on top of an Embedding.
"""

from pykeen.datasets import get_dataset
from pykeen.models import ERModel
from pykeen.nn.pyg import SimpleMessagePassingRepresentation
from pykeen.pipeline import pipeline

embedding_dim = 64
dataset = get_dataset(dataset="nations")
entities = SimpleMessagePassingRepresentation(
    triples_factory=dataset.training,
    base_kwargs=dict(shape=embedding_dim),
    layers=["gcn"] * 2,
    layers_kwargs=dict(in_channels=embedding_dim, out_channels=embedding_dim),
)
result = pipeline(
    dataset=dataset,
    # compose a model with distmult interaction function
    model=ERModel(
        triples_factory=dataset.training,
        entity_representations=entities,
        relation_representations_kwargs=dict(embedding_dim=embedding_dim),  # use embedding with same dimension
        interaction="DistMult",
    ),
)
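In the example, the GCN-enriched entity vectors are combined with plain relation embeddings through DistMult, which scores a triple as the sum over the element-wise product of the head, relation and tail vectors. A minimal plain-Python sketch of that score follows; the function name `distmult_score` is illustrative and not part of PyKEEN's API.

```python
from typing import Sequence


def distmult_score(h: Sequence[float], r: Sequence[float], t: Sequence[float]) -> float:
    """DistMult: trilinear product sum_i h_i * r_i * t_i (illustrative sketch)."""
    if not (len(h) == len(r) == len(t)):
        raise ValueError("head, relation and tail must share the same dimension")
    return sum(hi * ri * ti for hi, ri, ti in zip(h, r, t))


# toy 3-dimensional vectors (the example above uses embedding_dim = 64)
head = [1.0, 2.0, 0.5]
relation = [0.5, -1.0, 2.0]
tail = [2.0, 1.0, 4.0]
print(distmult_score(head, relation, tail))  # 3.0
```

The element-wise product is why the relation embeddings in the example use the same embedding_dim as the output of the last GCN layer.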

Initialize the representation.

Parameters:
  • triples_factory (CoreTriplesFactory) – The factory comprising the training triples used for message passing.

  • layers (OneOrManyHintOrType[MessagePassing]) – The message passing layer(s) or hints thereof.

  • layers_kwargs (OneOrManyOptionalKwargs) – Additional keyword-based parameters passed to the layers upon instantiation.

  • base (HintOrType[Representation]) – The base representations for entities, or a hint thereof.

  • base_kwargs (OptionalKwargs) – Additional keyword-based parameters passed to the base representations upon instantiation.

  • shape (tuple[int, ...]) – The output shape. Defaults to the base representation shape. Has to match the output shape of the last message passing layer.

  • max_id (int) – The number of representations. If provided, has to match the base representation’s max_id.

  • activations (OneOrManyHintOrType[nn.Module]) – The activation(s), or hints thereof.

  • activations_kwargs (OneOrManyOptionalKwargs) – Additional keyword-based parameters passed to the activations upon instantiation.

  • restrict_k_hop (bool) – Whether to restrict message passing to the k-hop neighborhood of the requested indices when only some indices are requested. This utilizes torch_geometric.utils.k_hop_subgraph().

  • kwargs – Additional keyword-based parameters passed to Representation.
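With restrict_k_hop=True, only edges that can influence the requested indices within k message passing steps need to be kept. The plain-Python sketch below illustrates that idea, assuming messages flow from source to target (PyG's default `flow="source_to_target"`) and assuming k equals the number of message passing layers. The name `k_hop_edge_mask` is hypothetical; the class itself delegates this work to torch_geometric.utils.k_hop_subgraph().

```python
from typing import List, Sequence, Set


def k_hop_edge_mask(edge_index: Sequence[Sequence[int]], seeds: Sequence[int], k: int) -> List[bool]:
    """Hypothetical sketch: mark the edges inside the k-hop neighborhood of `seeds`.

    Messages are assumed to flow from edge_index[0] (source) to edge_index[1]
    (target), so edges are walked backwards, collecting every node whose
    information can reach a seed within k steps.
    """
    sources, targets = edge_index
    subset: Set[int] = set(seeds)
    frontier: Set[int] = set(seeds)
    for _ in range(k):
        # nodes that send a message to the current frontier
        frontier = {s for s, t in zip(sources, targets) if t in frontier}
        subset |= frontier
    # keep an edge only if both endpoints lie inside the collected subset
    return [s in subset and t in subset for s, t in zip(sources, targets)]


chain = [[0, 1, 2, 3], [1, 2, 3, 4]]  # edges 0->1, 1->2, 2->3, 3->4
print(k_hop_edge_mask(chain, seeds=[4], k=2))  # [False, False, True, True]
```

With two layers, entity 4 can only receive information from entities at most two hops upstream, so the edges 0->1 and 1->2 can be skipped without changing its representation.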

Raises:
  • ImportError – If PyTorch Geometric is not installed.

  • ValueError – If the number of activations does not match the number of message passing layers (after input normalization).
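Several parameters accept either a single value or one value per layer (the OneOrMany... types). The sketch below shows one plausible normalization scheme and how a length mismatch leads to the ValueError listed above. The helper `broadcast_to_layers` and its exact broadcasting rule (lists and tuples are per-layer, anything else is repeated) are assumptions for illustration, not PyKEEN's internal code.

```python
from typing import List, Sequence, TypeVar, Union

T = TypeVar("T")


def broadcast_to_layers(value: Union[T, Sequence[T]], num_layers: int, name: str) -> List[T]:
    """Hypothetical sketch: turn a single value into one value per layer.

    A list or tuple is treated as "one entry per layer" and must have exactly
    `num_layers` entries; anything else (None, a string, a dict, ...) is
    repeated for every layer.
    """
    if isinstance(value, (list, tuple)):
        if len(value) != num_layers:
            raise ValueError(f"got {len(value)} {name} for {num_layers} layers")
        return list(value)
    return [value] * num_layers


layers = ["gcn", "gcn"]
# a single kwargs dict is shared by both layers, as in the example above
layers_kwargs = broadcast_to_layers(dict(in_channels=64, out_channels=64), len(layers), "layers_kwargs")
activations = broadcast_to_layers("relu", len(layers), "activations")
print(activations)  # ['relu', 'relu']
```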

Note

Three resolvers are used in this function.

  • The parameter pair (layers, layers_kwargs) is used for pykeen.nn.pyg.layer_resolver.

  • The parameter pair (base, base_kwargs) is used for pykeen.nn.representation_resolver.

  • The parameter pair (activations, activations_kwargs) is used for class_resolver.contrib.torch.activation_resolver.

An explanation of resolvers and how to use them is given in https://class-resolver.readthedocs.io/en/latest/.
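A resolver maps a hint (a string such as "gcn", a class, or an existing instance) plus optional keyword arguments to an instantiated object. The following plain-Python sketch mimics that behaviour with a toy registry. `ToyResolver`, `ToyGCNLayer` and the name normalization (lower-casing, dropping underscores) are illustrative assumptions and do not reproduce the class_resolver package's actual implementation.

```python
from typing import Any, Dict, Mapping, Optional, Union


class ToyGCNLayer:
    """Stand-in for a message passing layer; it only records its constructor arguments."""

    def __init__(self, in_channels: int, out_channels: int) -> None:
        self.in_channels = in_channels
        self.out_channels = out_channels


class ToyResolver:
    """Hypothetical mini-resolver: hint (name, class or instance) + kwargs -> instance."""

    def __init__(self, registry: Mapping[str, type]) -> None:
        # store keys normalized so that e.g. "GCN" and "gcn" both match
        self._registry: Dict[str, type] = {self._normalize(k): v for k, v in registry.items()}

    @staticmethod
    def _normalize(name: str) -> str:
        return name.lower().replace("_", "")

    def make(self, hint: Union[str, type, object], kwargs: Optional[Mapping[str, Any]] = None) -> object:
        kwargs = dict(kwargs or {})
        if isinstance(hint, str):
            cls = self._registry[self._normalize(hint)]  # KeyError for unknown names
            return cls(**kwargs)
        if isinstance(hint, type):
            return hint(**kwargs)
        return hint  # already an instance: use it as-is


resolver = ToyResolver({"gcn": ToyGCNLayer})
layer = resolver.make("GCN", dict(in_channels=64, out_channels=64))
print(type(layer).__name__, layer.out_channels)  # ToyGCNLayer 64
```

Presumably the (layers, layers_kwargs) pair is resolved in this fashion once per layer, which is how layers=["gcn"] * 2 in the example above yields two GCN layers.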

Methods Summary

pass_messages(x, edge_index[, edge_mask]) – Perform the message passing steps.

Methods Documentation

pass_messages(x: Tensor, edge_index: Tensor, edge_mask: Tensor | None = None) -> Tensor

Perform the message passing steps.

Parameters:
  • x (Tensor) – shape: (n, d_in) the base entity representations

  • edge_index (Tensor) – shape: (2, num_selected_edges) the edge index (which may already be a selection of the full edge index)

  • edge_mask (Tensor | None) – shape: (num_edges,) an edge mask if message passing is restricted

Returns:

shape: (n, d_out) the enriched entity representations

Return type:

Tensor
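Conceptually, the message passing steps amount to running each layer on the (possibly already selected) edge index, followed by its activation. The plain-Python sketch below imitates that loop. A simple mean over a node's own vector and its incoming neighbours stands in for a real PyG layer such as GCNConv (which uses degree-normalized aggregation instead), and ReLU serves as the activation; `pass_messages_sketch` and `mean_aggregate` are hypothetical names, and the real method operates on torch tensors.

```python
from typing import Callable, List, Sequence

Matrix = List[List[float]]  # shape: (n, d)
EdgeIndex = Sequence[Sequence[int]]  # two rows: sources, targets


def mean_aggregate(x: Matrix, edge_index: EdgeIndex) -> Matrix:
    """Toy layer: each node averages its own vector and those of its incoming neighbours."""
    sources, targets = edge_index
    out: Matrix = []
    for node, own in enumerate(x):
        neighbours = [x[s] for s, t in zip(sources, targets) if t == node]
        group = [own] + neighbours  # self-loop plus incoming messages
        out.append([sum(values) / len(group) for values in zip(*group)])
    return out


def relu(x: Matrix) -> Matrix:
    return [[max(0.0, v) for v in row] for row in x]


def pass_messages_sketch(
    x: Matrix,
    edge_index: EdgeIndex,
    layers: Sequence[Callable[[Matrix, EdgeIndex], Matrix]],
    activations: Sequence[Callable[[Matrix], Matrix]],
) -> Matrix:
    """Hypothetical sketch: alternate message passing layers and activations."""
    for layer, activation in zip(layers, activations):
        x = activation(layer(x, edge_index))
    return x


x = [[1.0, -2.0], [3.0, 4.0], [-1.0, 0.0]]  # n = 3 entities, d = 2
edge_index = [[0, 1], [1, 2]]  # edges 0 -> 1 and 1 -> 2
print(pass_messages_sketch(x, edge_index, [mean_aggregate] * 2, [relu] * 2))
# [[1.0, 0.0], [1.5, 0.5], [1.5, 1.5]]
```

Because the toy layer keeps the dimension unchanged, d_out equals d_in here; with real layers, d_out is the output size of the last layer.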