MessagePassingRepresentation
- class MessagePassingRepresentation(triples_factory: CoreTriplesFactory, layers: str | None | type[None] | Sequence[str | None | type[None]], layers_kwargs: Mapping[str, Any] | None | Sequence[Mapping[str, Any] | None] = None, base: str | Representation | type[Representation] | None = None, base_kwargs: Mapping[str, Any] | None = None, max_id: int | None = None, shape: int | Sequence[int] | None = None, activations: str | Module | type[Module] | None | Sequence[str | Module | type[Module] | None] = None, activations_kwargs: Mapping[str, Any] | None | Sequence[Mapping[str, Any] | None] = None, restrict_k_hop: bool = False, **kwargs)
Bases: Representation, ABC
An abstract representation class utilizing PyTorch Geometric message passing layers.
- It comprises:
base (entity) representations, which can also be passed as hints
a sequence of message passing layers. They are used in the abstract method MessagePassingRepresentation.pass_messages() to enrich the base representations with neighborhood information.
a sequence of activation layers in between the message passing layers.
an edge_index buffer, which stores the edge index and is moved to the device alongside the module.
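The composition listed above can be illustrated with a dependency-free toy sketch. It is not PyKEEN's implementation and does not use PyTorch Geometric: features are plain Python lists, `edge_index` is assumed to be a pair of parallel source/target lists (mirroring PyTorch Geometric's `2 x num_edges` layout), and `sum_aggregation_layer`, `relu`, and `ToyMessagePassingRepresentation` are hypothetical names. In this sketch the i-th activation is applied after the i-th layer, and `None` skips it.

```python
from typing import Callable, List, Optional, Sequence

Vector = List[float]
EdgeIndex = Sequence[Sequence[int]]  # [sources, targets]
Layer = Callable[[List[Vector], EdgeIndex], List[Vector]]
Activation = Callable[[List[Vector]], List[Vector]]


def sum_aggregation_layer(x: List[Vector], edge_index: EdgeIndex) -> List[Vector]:
    """Hypothetical layer: every target node adds its source neighbours' features to its own."""
    sources, targets = edge_index
    out = [list(v) for v in x]  # keep each node's own features (implicit self-loop)
    for s, t in zip(sources, targets):
        # messages are read from the input x, so all nodes update synchronously
        out[t] = [a + b for a, b in zip(out[t], x[s])]
    return out


def relu(x: List[Vector]) -> List[Vector]:
    """Element-wise ReLU activation."""
    return [[max(0.0, a) for a in v] for v in x]


class ToyMessagePassingRepresentation:
    """Hypothetical stand-in: base features, layers, activations and a stored edge index."""

    def __init__(
        self,
        base: List[Vector],
        layers: Sequence[Layer],
        activations: Sequence[Optional[Activation]],
        edge_index: EdgeIndex,
    ) -> None:
        if len(layers) != len(activations):
            raise ValueError("need exactly one activation (or None) per message passing layer")
        self.base = base
        self.layers = list(layers)
        self.activations = list(activations)
        self.edge_index = edge_index  # plays the role of the edge_index buffer

    def __call__(self) -> List[Vector]:
        x = [list(v) for v in self.base]  # start from the base (entity) representations
        for layer, activation in zip(self.layers, self.activations):
            x = layer(x, self.edge_index)  # enrich with neighborhood information
            if activation is not None:
                x = activation(x)
        return x


rep = ToyMessagePassingRepresentation(
    base=[[1.0], [2.0], [-5.0]],
    layers=[sum_aggregation_layer],
    activations=[relu],
    edge_index=[[0, 1], [1, 2]],  # edges 0 -> 1 and 1 -> 2
)
print(rep())  # [[1.0], [3.0], [0.0]]
```

Keeping the edge index on the object (rather than passing it on every call) is what the `edge_index` buffer achieves in the real module, where being a buffer also means it follows the module to the target device.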
Initialize the representation.
- Parameters:
triples_factory (CoreTriplesFactory) – the factory comprising the training triples used for message passing
layers (str | None | type[None] | Sequence[str | None | type[None]]) – the message passing layer(s) or hints thereof
layers_kwargs (Mapping[str, Any] | None | Sequence[Mapping[str, Any] | None]) – additional keyword-based parameters passed to the layers upon instantiation
base (str | Representation | type[Representation] | None) – the base representations for entities, or a hint thereof
base_kwargs (Mapping[str, Any] | None) – additional keyword-based parameters passed to the base representations upon instantiation
shape (int | Sequence[int] | None) – the output shape. Defaults to the base representation’s shape. Has to match the output shape of the last message passing layer.
max_id (int | None) – the number of representations. If provided, it has to match the base representation’s max_id.
activations (str | Module | type[Module] | None | Sequence[str | Module | type[Module] | None]) – the activation(s), or hints thereof
activations_kwargs (Mapping[str, Any] | None | Sequence[Mapping[str, Any] | None]) – additional keyword-based parameters passed to the activations upon instantiation
restrict_k_hop (bool) – whether to restrict message passing to the k-hop neighborhood when only some indices are requested. This utilizes torch_geometric.utils.k_hop_subgraph().
kwargs – additional keyword-based parameters passed to Representation.__init__()
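Restricting to the k-hop neighborhood is safe because, with k message passing layers, the output for a node can only depend on nodes at most k hops upstream of it (presumably k is the number of layers here). The sketch below is a dependency-free approximation of what torch_geometric.utils.k_hop_subgraph computes for its default source_to_target flow: the node subset and an edge mask keeping edges with both endpoints in the subset, without node relabelling. `k_hop_nodes` is a hypothetical helper, not the PyTorch Geometric function.

```python
from typing import List, Sequence, Set, Tuple


def k_hop_nodes(
    node_idx: Sequence[int],
    num_hops: int,
    edge_index: Sequence[Sequence[int]],
) -> Tuple[Set[int], List[bool]]:
    """Collect all nodes that can send messages to node_idx within num_hops hops.

    Messages flow from source to target, so each hop follows edges backwards.
    Returns the node subset and a boolean mask over edges with both endpoints in it.
    """
    sources, targets = edge_index
    subset = set(node_idx)
    frontier = set(node_idx)
    for _ in range(num_hops):
        # sources of edges pointing into the current frontier, excluding already visited nodes
        frontier = {s for s, t in zip(sources, targets) if t in frontier} - subset
        subset |= frontier
    edge_mask = [s in subset and t in subset for s, t in zip(sources, targets)]
    return subset, edge_mask


# chain 0 -> 1 -> 2 -> 3; with two layers, node 3 only "sees" nodes 1 and 2
subset, edge_mask = k_hop_nodes(node_idx=[3], num_hops=2, edge_index=[[0, 1, 2], [1, 2, 3]])
print(sorted(subset), edge_mask)  # [1, 2, 3] [False, True, True]
```

Only the edges selected by the mask then need to take part in message passing, which is where the savings come from when few indices are requested on a large graph.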
- Raises:
ImportError – if PyTorch Geometric is not installed
ValueError – if the number of activations and message passing layers do not match (after input normalization)
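The ValueError above refers to a check performed after the layer and activation hints have been normalized into sequences. A minimal sketch of such a normalization follows. The helper names are hypothetical, the hint strings are treated as opaque strings, and the rule that a single activation hint is broadcast to every layer is an assumption for illustration; PyKEEN's exact normalization rules may differ.

```python
from typing import Any, List, Sequence, Tuple, Union


def upgrade_to_sequence(hint: Union[Any, Sequence[Any]]) -> List[Any]:
    """Wrap a single hint (a string, a class, an instance, or None) into a one-element list."""
    if isinstance(hint, (list, tuple)):
        return list(hint)
    return [hint]


def normalize_layers_and_activations(
    layers: Union[Any, Sequence[Any]],
    activations: Union[Any, Sequence[Any]] = None,
) -> Tuple[List[Any], List[Any]]:
    """Hypothetical normalization: one activation hint per message passing layer."""
    layer_list = upgrade_to_sequence(layers)
    activation_list = upgrade_to_sequence(activations)
    # Assumed broadcasting rule: a single activation hint is reused for every layer.
    if len(activation_list) == 1:
        activation_list = activation_list * len(layer_list)
    if len(activation_list) != len(layer_list):
        raise ValueError(
            f"got {len(activation_list)} activations for {len(layer_list)} message passing layers"
        )
    return layer_list, activation_list


print(normalize_layers_and_activations("GCNConv"))  # (['GCNConv'], [None])
print(normalize_layers_and_activations(["GCNConv", "GCNConv"], "ReLU"))  # (['GCNConv', 'GCNConv'], ['ReLU', 'ReLU'])
```

With this rule, passing three layers together with two explicit activations is the kind of input that ends in a ValueError.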
Methods Summary
pass_messages(x, edge_index[, edge_mask]) – Perform the message passing steps.
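Since pass_messages is the abstract hook that concrete subclasses implement, the pattern can be sketched with a toy abstract base and one subclass. Here `edge_mask` is assumed to be a boolean mask over the edges (columns) of `edge_index`, for example a mask coming from the k-hop restriction; the source does not spell out its semantics. `ToyMessagePassingBase`, `ToyMeanPassing`, and `apply_edge_mask` are hypothetical names, and features are scalars for brevity.

```python
from abc import ABC, abstractmethod
from typing import List, Optional, Sequence


def apply_edge_mask(
    edge_index: Sequence[Sequence[int]], edge_mask: Optional[Sequence[bool]]
) -> List[List[int]]:
    """Keep only the edges whose mask entry is True (all edges if no mask is given)."""
    sources, targets = edge_index
    if edge_mask is None:
        return [list(sources), list(targets)]
    kept = [(s, t) for s, t, keep in zip(sources, targets, edge_mask) if keep]
    return [[s for s, _ in kept], [t for _, t in kept]]


class ToyMessagePassingBase(ABC):
    """Hypothetical abstract base: subclasses decide how messages are passed."""

    @abstractmethod
    def pass_messages(
        self,
        x: List[float],
        edge_index: Sequence[Sequence[int]],
        edge_mask: Optional[Sequence[bool]] = None,
    ) -> List[float]:
        """Perform the message passing steps."""


class ToyMeanPassing(ToyMessagePassingBase):
    """One step of mean aggregation over incoming edges; nodes without incoming edges keep their value."""

    def pass_messages(self, x, edge_index, edge_mask=None):
        sources, targets = apply_edge_mask(edge_index, edge_mask)
        sums = [0.0] * len(x)
        counts = [0] * len(x)
        for s, t in zip(sources, targets):
            sums[t] += x[s]
            counts[t] += 1
        return [sums[i] / counts[i] if counts[i] else x[i] for i in range(len(x))]


x = [1.0, 3.0, 5.0]
edge_index = [[0, 1, 0], [2, 2, 1]]  # edges 0 -> 2, 1 -> 2, 0 -> 1
model = ToyMeanPassing()
print(model.pass_messages(x, edge_index))  # [1.0, 1.0, 2.0]
print(model.pass_messages(x, edge_index, edge_mask=[False, True, True]))  # [1.0, 1.0, 3.0]
```

Because the base class declares pass_messages as abstract, instantiating ToyMessagePassingBase directly raises a TypeError; only subclasses that provide the message passing steps can be used.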
Methods Documentation