MessagePassingRepresentation
- class MessagePassingRepresentation(triples_factory: CoreTriplesFactory, layers: str | MessagePassing | type[MessagePassing] | None | Sequence[str | MessagePassing | type[MessagePassing] | None], layers_kwargs: Mapping[str, Any] | None | Sequence[Mapping[str, Any] | None] = None, base: str | Representation | type[Representation] | None = None, base_kwargs: Mapping[str, Any] | None = None, max_id: int | None = None, shape: int | Sequence[int] | None = None, activations: str | Module | type[Module] | None | Sequence[str | Module | type[Module] | None] = None, activations_kwargs: Mapping[str, Any] | None | Sequence[Mapping[str, Any] | None] = None, restrict_k_hop: bool = False, **kwargs)
Bases: Representation, ABC
An abstract representation class utilizing PyTorch Geometric message passing layers.
- It comprises:
  - Base (entity) representations, which can also be passed as hints.
  - A sequence of message passing layers. They are utilized in the abstract method pass_messages() to enrich the base representations with neighborhood information.
  - A sequence of activation layers in between the message passing layers.
  - An edge_index buffer, which stores the edge index and is moved to the device alongside the module.
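For readers new to PyTorch Geometric: an edge index is an array of shape (2, num_edges) whose first row holds message sources and whose second row holds message targets. The sketch below assumes the edge_index buffer is derived from the head and tail columns of the training triples, with relation IDs kept separately. build_edge_index is a hypothetical helper, and plain lists stand in for the torch tensor that the real class keeps as a module buffer (PyTorch buffers are what follow a module when it is moved with .to(device)).

```python
# Hypothetical sketch: derive a PyG-style edge index (2 rows: sources, targets)
# from (head, relation, tail) ID triples. Plain lists replace torch tensors.
from typing import List, Tuple

Triple = Tuple[int, int, int]


def build_edge_index(triples: List[Triple]) -> Tuple[List[List[int]], List[int]]:
    """Return ([sources, targets], edge_type) for the given ID triples."""
    sources = [h for h, _, _ in triples]  # row 0: head IDs send messages
    targets = [t for _, _, t in triples]  # row 1: tail IDs receive messages
    edge_type = [r for _, r, _ in triples]  # relation ID of each edge
    return [sources, targets], edge_type


triples = [(0, 0, 1), (1, 1, 2), (2, 0, 0)]
edge_index, edge_type = build_edge_index(triples)
print(edge_index)  # [[0, 1, 2], [1, 2, 0]]
print(edge_type)   # [0, 1, 0]
```

Keeping the relation IDs in a separate list mirrors the PyG convention of a plain (2, num_edges) edge index plus optional per-edge attributes.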
Initialize the representation.
- Parameters:
triples_factory (CoreTriplesFactory) – The factory comprising the training triples used for message passing.
layers (OneOrManyHintOrType[MessagePassing]) – The message passing layer(s) or hints thereof.
layers_kwargs (OneOrManyOptionalKwargs) – Additional keyword-based parameters passed to the layers upon instantiation.
base (HintOrType[Representation]) – The base representations for entities, or a hint thereof.
base_kwargs (OptionalKwargs) – Additional keyword-based parameters passed to the base representations upon instantiation.
shape (tuple[int, ...]) – The output shape. Defaults to the base representation shape. Has to match the output shape of the last message passing layer.
max_id (int) – The number of representations. If provided, has to match the base representation’s max_id.
activations (OneOrManyHintOrType[nn.Module]) – The activation(s), or hints thereof.
activations_kwargs (OneOrManyOptionalKwargs) – Additional keyword-based parameters passed to the activations upon instantiation.
restrict_k_hop (bool) – Whether to restrict the message passing to the k-hop neighborhood when only some indices are requested. This utilizes torch_geometric.utils.k_hop_subgraph().
kwargs – Additional keyword-based parameters passed to Representation.
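To illustrate what restrict_k_hop buys: each message passing layer propagates information by one hop, so with k layers no node farther than k hops upstream can affect the requested outputs. The dependency-free sketch below mimics the semantics of torch_geometric.utils.k_hop_subgraph() under the assumption of its default settings (flow="source_to_target", directed=False): it collects the nodes that can send messages to the requested nodes within num_hops steps and keeps every edge whose endpoints both lie in that set. k_hop_subset is a hypothetical name, not the PyG function.

```python
# Hypothetical sketch of k-hop restriction (mimics PyG defaults; not PyG itself).
from typing import List, Set, Tuple


def k_hop_subset(
    node_idx: List[int], num_hops: int, edge_index: List[List[int]]
) -> Tuple[Set[int], List[bool]]:
    """Return (nodes within num_hops upstream of node_idx, per-edge keep mask)."""
    sources, targets = edge_index
    subset = set(node_idx)
    frontier = set(node_idx)
    for _ in range(num_hops):
        # messages flow source -> target, so walk each edge backwards
        frontier = {s for s, t in zip(sources, targets) if t in frontier}
        subset |= frontier
    # keep every edge whose source and target are both inside the subset
    edge_mask = [s in subset and t in subset for s, t in zip(sources, targets)]
    return subset, edge_mask


chain = [[0, 1, 2], [1, 2, 3]]  # edges 0 -> 1 -> 2 -> 3
subset, edge_mask = k_hop_subset([3], 1, chain)
print(sorted(subset), edge_mask)  # [2, 3] [False, False, True]
```

Requesting node 3 with one layer only needs node 2 and the edge 2 -> 3; the rest of the graph can be skipped without changing the result.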
- Raises:
ImportError – If PyTorch Geometric is not installed.
ValueError – If the numbers of activations and message passing layers do not match (after input normalization).
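The ValueError above is raised after "input normalization", i.e. after single hints have been turned into sequences. The sketch below shows one plausible shape of that step; to_list and check_layers_and_activations are hypothetical helpers, and this sketch does not broadcast a single activation to several layers, so it may be stricter than PyKEEN's actual normalization.

```python
# Hypothetical sketch of "one or many" normalization plus the length check.
# Not PyKEEN's code; the real normalization may differ (e.g. broadcasting).
from typing import Any, List


def to_list(value: Any) -> List[Any]:
    """Wrap a single hint in a list; convert lists/tuples to lists."""
    if isinstance(value, (list, tuple)):
        return list(value)
    return [value]  # strings, classes, instances and None count as one hint


def check_layers_and_activations(layers: Any, activations: Any) -> None:
    layers, activations = to_list(layers), to_list(activations)
    if len(activations) != len(layers):
        raise ValueError(
            f"Got {len(layers)} message passing layer(s) but "
            f"{len(activations)} activation(s)."
        )


check_layers_and_activations("gcn", "relu")  # one layer, one activation: OK
```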
Note
Three resolvers are used in this function.
- The parameter pair (layers, layers_kwargs) is used for pykeen.nn.pyg.layer_resolver.
- The parameter pair (base, base_kwargs) is used for pykeen.nn.representation_resolver.
- The parameter pair (activations, activations_kwargs) is used for class_resolver.contrib.torch.activation_resolver.
An explanation of resolvers and how to use them is given in https://class-resolver.readthedocs.io/en/latest/.
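Each (hint, kwargs) pair follows the same pattern: the hint may be a name, a class, an instance, or None (meaning "use the default"), and the kwargs are only used when a class has to be instantiated. The toy resolver below is a hypothetical, dependency-free illustration of that pattern; MiniResolver and its toy activation classes are not part of class_resolver or PyKEEN, and the real Resolver offers more (name normalization, suffix handling, make_many, and so on).

```python
# Hypothetical sketch of the (hint, kwargs) resolver pattern.
from typing import Any, Dict, Mapping, Optional, Sequence


class ReLU:
    """Toy activation: max(0, x)."""

    def __call__(self, x: float) -> float:
        return max(0.0, x)


class LeakyReLU:
    """Toy activation with a configurable negative slope."""

    def __init__(self, negative_slope: float = 0.01) -> None:
        self.negative_slope = negative_slope

    def __call__(self, x: float) -> float:
        return x if x >= 0 else self.negative_slope * x


class MiniResolver:
    """Turns a hint plus optional kwargs into an instance."""

    def __init__(self, classes: Sequence[type], default: type) -> None:
        # register classes under their lower-cased names
        self.lookup: Dict[str, type] = {c.__name__.lower(): c for c in classes}
        self.default = default

    def make(self, hint: Any = None, kwargs: Optional[Mapping[str, Any]] = None) -> Any:
        if hint is None:
            hint = self.default  # no hint: fall back to the default class
        if isinstance(hint, str):
            hint = self.lookup[hint.lower()]  # name -> class
        if isinstance(hint, type):
            return hint(**(kwargs or {}))  # class -> instance, using kwargs
        return hint  # already an instance: returned as-is (kwargs ignored here)


mini_activation_resolver = MiniResolver([ReLU, LeakyReLU], default=ReLU)
act = mini_activation_resolver.make("LeakyReLU", {"negative_slope": 0.1})
print(act(-2.0))  # -0.2
```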
Methods Summary
pass_messages(x, edge_index[, edge_mask]) – Perform the message passing steps.
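A minimal sketch of what a concrete pass_messages could look like, assuming one activation is applied after each message passing layer and assuming edge_mask is an optional boolean mask selecting which edges take part (for example the edges kept by a k-hop restriction). Unlike the real abstract method, which uses the module's own layers, this standalone version receives layers and activations as arguments; sum_aggregate is a hypothetical toy layer on scalar features standing in for a PyG MessagePassing module.

```python
# Hypothetical sketch of a message passing loop on scalar node features.
from typing import Callable, List, Optional, Sequence

Layer = Callable[[List[float], List[List[int]]], List[float]]


def sum_aggregate(x: List[float], edge_index: List[List[int]]) -> List[float]:
    """Toy layer: out[i] = x[i] + sum of x[j] over all edges j -> i."""
    out = list(x)
    for source, target in zip(*edge_index):
        out[target] += x[source]
    return out


def relu(v: float) -> float:
    return max(0.0, v)


def pass_messages(
    x: List[float],
    edge_index: List[List[int]],
    layers: Sequence[Layer],
    activations: Sequence[Callable[[float], float]],
    edge_mask: Optional[List[bool]] = None,
) -> List[float]:
    if edge_mask is not None:
        # drop the edges that the mask excludes
        edge_index = [[v for v, keep in zip(row, edge_mask) if keep] for row in edge_index]
    for layer, activation in zip(layers, activations):
        # one message passing step, then the activation paired with it
        x = [activation(v) for v in layer(x, edge_index)]
    return x


x = [1.0, -3.0, 2.0]
edges = [[0, 1], [1, 2]]  # 0 -> 1 -> 2
print(pass_messages(x, edges, [sum_aggregate] * 2, [relu] * 2))  # [1.0, 1.0, 0.0]
print(pass_messages(x, edges, [sum_aggregate] * 2, [relu] * 2, [True, False]))  # [1.0, 1.0, 2.0]
```

Pairing layers and activations with zip is why their counts must agree, which is what the ValueError documented above guards against.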
Methods Documentation