SimpleMessagePassingRepresentation
- class SimpleMessagePassingRepresentation(triples_factory, layers, layers_kwargs=None, base=None, base_kwargs=None, max_id=None, shape=None, activations=None, activations_kwargs=None, restrict_k_hop=False, **kwargs)
Bases: MessagePassingRepresentation
A message passing representation that does not use the relation type.
Because it uses only the connectivity information and ignores the relation types, this module can use message passing layers defined on uni-relational graphs, which make up the majority of the layers available in the PyTorch Geometric library.
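To make "ignoring the relation type" concrete, here is a minimal pure-Python sketch of how (head, relation, tail) triples can be reduced to a uni-relational edge index in the two-row [sources, targets] layout that PyTorch Geometric layers consume. The helper name triples_to_edge_index is hypothetical, the sketch assumes triples are given as integer ID tuples, and it is not PyKEEN's internal code.

```python
from typing import List, Sequence, Tuple

# (head_id, relation_id, tail_id)
Triple = Tuple[int, int, int]


def triples_to_edge_index(triples: Sequence[Triple]) -> List[List[int]]:
    """Drop the relation ID of each triple (hypothetical helper).

    Returns a 2 x num_edges edge index, i.e. [sources, targets], mirroring
    the (2, num_edges) layout used by PyTorch Geometric. Because relation IDs
    are discarded, triples that differ only in their relation become
    parallel, indistinguishable edges.
    """
    sources = [head for head, _relation, _tail in triples]
    targets = [tail for _head, _relation, tail in triples]
    return [sources, targets]


# Three triples over four entities and two relation types.
triples = [(0, 0, 1), (0, 1, 1), (2, 1, 3)]
print(triples_to_edge_index(triples))  # [[0, 0, 2], [1, 1, 3]]
```

The two triples linking entities 0 and 1 collapse into two identical edges: this is the information the class gives up in exchange for compatibility with uni-relational layers.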
Here, we create a two-layer torch_geometric.nn.conv.GCNConv on top of a pykeen.nn.representation.Embedding:

    from pykeen.datasets import get_dataset
    from pykeen.nn.pyg import SimpleMessagePassingRepresentation

    embedding_dim = 64
    dataset = get_dataset(dataset="nations")
    r = SimpleMessagePassingRepresentation(
        triples_factory=dataset.training,
        base_kwargs=dict(shape=embedding_dim),
        layers=["gcn"] * 2,
        layers_kwargs=dict(in_channels=embedding_dim, out_channels=embedding_dim),
    )
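For intuition about what each of the two GCN layers does, here is a minimal pure-Python sketch of the symmetric-normalized propagation rule used by GCN layers: self-loops are added, and a message sent from s to t is scaled by 1 / sqrt(deg(s) * deg(t)), where deg counts incoming edges plus the self-loop. It deliberately omits the learnable weight matrix, bias, and activation, so it illustrates the aggregation only and is not a reimplementation of torch_geometric.nn.conv.GCNConv. The helper name gcn_propagate is hypothetical.

```python
import math
from typing import List, Sequence


def gcn_propagate(
    x: Sequence[Sequence[float]], edge_index: Sequence[Sequence[int]]
) -> List[List[float]]:
    """One weight-free GCN propagation step (hypothetical helper).

    x: n feature vectors of equal length d.
    edge_index: [sources, targets]; messages flow source -> target.
    """
    n = len(x)
    d = len(x[0]) if n else 0
    # Append one self-loop per node so every node keeps part of its own signal.
    sources = list(edge_index[0]) + list(range(n))
    targets = list(edge_index[1]) + list(range(n))
    # Degree = number of incoming edges, self-loop included (always >= 1).
    deg = [0] * n
    for t in targets:
        deg[t] += 1
    out = [[0.0] * d for _ in range(n)]
    for s, t in zip(sources, targets):
        norm = 1.0 / math.sqrt(deg[s] * deg[t])
        for k in range(d):
            out[t][k] += norm * x[s][k]
    return out


# Path graph 0 -> 1 -> 2; only node 0 starts with a non-zero feature.
x = [[1.0], [0.0], [0.0]]
edge_index = [[0, 1], [1, 2]]
h1 = gcn_propagate(x, edge_index)   # first layer: signal reaches node 1
h2 = gcn_propagate(h1, edge_index)  # second layer: signal reaches node 2
print(round(h1[2][0], 4), round(h2[2][0], 4))  # 0.0 0.3536
```

Node 2 lies two hops from node 0, so it only receives signal after the second layer; stacking two layers therefore gives each entity a two-hop receptive field.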
Initialize the representation.
- Parameters:
triples_factory (CoreTriplesFactory) – the factory comprising the training triples used for message passing
layers (Union[str, None, Type[None], Sequence[Union[str, None, Type[None]]]]) – the message passing layer(s), or hints thereof
layers_kwargs (Union[Mapping[str, Any], None, Sequence[Optional[Mapping[str, Any]]]]) – additional keyword-based parameters passed to the layers upon instantiation
base (Union[str, Representation, Type[Representation], None]) – the base representations for entities, or a hint thereof
base_kwargs (Optional[Mapping[str, Any]]) – additional keyword-based parameters passed to the base representations upon instantiation
shape (Union[int, Sequence[int], None]) – the output shape. Defaults to the base representation's shape. Must match the output shape of the last message passing layer.
max_id (Optional[int]) – the number of representations. If provided, it must match the base representation's max_id.
activations (Union[str, Module, Type[Module], None, Sequence[Union[str, Module, Type[Module], None]]]) – the activation(s), or hints thereof
activations_kwargs (Union[Mapping[str, Any], None, Sequence[Optional[Mapping[str, Any]]]]) – additional keyword-based parameters passed to the activations upon instantiation
restrict_k_hop (bool) – whether to restrict message passing to the k-hop neighborhood when only some indices are requested. This uses torch_geometric.utils.k_hop_subgraph().
kwargs – additional keyword-based parameters passed to Representation.__init__()
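The idea behind restrict_k_hop is presumably that k stacked message passing layers only propagate information k edges away, so the outputs for a few requested entities depend only on their k-hop neighborhood. The following minimal pure-Python sketch extracts such a neighborhood, written to mimic, as we understand it, the default behaviour of torch_geometric.utils.k_hop_subgraph() with messages flowing from source to target. The helper name and its simplified return value (a node set and an edge mask, without node relabeling) are our own assumptions; this is not the function PyKEEN calls.

```python
from typing import List, Sequence, Set, Tuple


def k_hop_subgraph(
    node_idx: Sequence[int],
    num_hops: int,
    edge_index: Sequence[Sequence[int]],
) -> Tuple[Set[int], List[bool]]:
    """Collect all nodes that can reach node_idx in at most num_hops steps.

    Hypothetical pure-Python helper, not the PyTorch Geometric function.
    Messages flow source -> target, so edges are walked backwards from the
    requested nodes (a breadth-first search on reversed edges). Returns the
    node subset and a mask over edges whose endpoints both lie in it.
    """
    sources, targets = edge_index
    subset = set(node_idx)
    frontier = set(node_idx)
    for _ in range(num_hops):
        # Every not-yet-seen source with an edge into the frontier is one hop further out.
        frontier = {s for s, t in zip(sources, targets) if t in frontier} - subset
        subset |= frontier
    edge_mask = [s in subset and t in subset for s, t in zip(sources, targets)]
    return subset, edge_mask


# Edges: 0 -> 1, 1 -> 2, 3 -> 2, 4 -> 0
edges = [[0, 1, 3, 4], [1, 2, 2, 0]]
subset, mask = k_hop_subgraph([2], 1, edges)
print(sorted(subset), mask)  # [1, 2, 3] [False, True, True, False]
```

With two hops, node 0 joins the subset as well, while node 4 (three hops away) is still excluded, so a two-layer model would never need its features when only node 2 is requested.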
- Raises:
ImportError – if PyTorch Geometric is not installed
ValueError – if the number of activations and message passing layers do not match (after input normalization)
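The ValueError above depends on how the layer and activation inputs are normalized. As a hypothetical sketch, assuming a single activation hint is reused for every layer (PyKEEN's exact broadcasting rules may differ), the check could look like the following; as_list and normalize_hints are illustrative names, not PyKEEN functions.

```python
from typing import Any, List, Tuple


def as_list(hint: Any) -> List[Any]:
    """Wrap a single hint (including a string or None) in a list; keep lists and tuples."""
    if isinstance(hint, (list, tuple)):
        return list(hint)
    return [hint]


def normalize_hints(layers: Any, activations: Any) -> List[Tuple[Any, Any]]:
    """Pair each message passing layer hint with an activation hint (hypothetical).

    Assumption: a single activation hint is broadcast to every layer; any other
    length mismatch after this step raises ValueError.
    """
    layer_list = as_list(layers)
    activation_list = as_list(activations)
    if len(activation_list) == 1:
        activation_list = activation_list * len(layer_list)
    if len(activation_list) != len(layer_list):
        raise ValueError(
            f"got {len(layer_list)} layers but {len(activation_list)} activations"
        )
    return list(zip(layer_list, activation_list))


print(normalize_hints(["gcn"] * 2, "relu"))  # [('gcn', 'relu'), ('gcn', 'relu')]
```

Passing, say, two layers and three activations fails with ValueError, because no broadcasting rule can pair them up.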
Methods Summary
pass_messages(x, edge_index[, edge_mask]) – Perform the message passing steps.
Methods Documentation
- pass_messages(x, edge_index, edge_mask=None)
Perform the message passing steps.
- Parameters:
x (FloatTensor) – shape: (n, d_in); the base entity representations
edge_index (LongTensor) – shape: (2, num_selected_edges); the edge index (which may already be a selection of the full edge index)
edge_mask (Optional[BoolTensor]) – shape: (num_edges,); an edge mask if message passing is restricted
- Return type: FloatTensor
- Returns: shape: (n, d_out); the enriched entity representations
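As a rough illustration of the message passing steps, here is a minimal pure-Python sketch that applies each (layer, activation) pair in turn, presumably in the way the class chains its configured layers. The real method operates on tensors with PyTorch Geometric layers stored on the module; here plain lists, a toy sum-aggregation layer, and a ReLU stand in, and the layers and activations are passed explicitly. All names (sum_layer, relu, and this pass_messages signature) are hypothetical, and edge_mask is omitted for brevity.

```python
from typing import Callable, List, Sequence

Matrix = List[List[float]]
EdgeIndex = Sequence[Sequence[int]]
Layer = Callable[[Matrix, EdgeIndex], Matrix]
Activation = Callable[[Matrix], Matrix]


def sum_layer(x: Matrix, edge_index: EdgeIndex) -> Matrix:
    """Toy uni-relational layer: each node keeps its features and adds its in-neighbours'."""
    out = [row[:] for row in x]
    for s, t in zip(*edge_index):
        for k, value in enumerate(x[s]):
            out[t][k] += value
    return out


def relu(x: Matrix) -> Matrix:
    """Element-wise max(0, v)."""
    return [[max(0.0, v) for v in row] for row in x]


def pass_messages(
    x: Matrix,
    edge_index: EdgeIndex,
    layers: Sequence[Layer],
    activations: Sequence[Activation],
) -> Matrix:
    """Apply each (layer, activation) pair in order (hypothetical sketch)."""
    for layer, activation in zip(layers, activations):
        x = activation(layer(x, edge_index))
    return x


x = [[1.0], [-2.0], [0.5]]
edge_index = [[0, 1], [1, 2]]  # 0 -> 1, 1 -> 2
print(pass_messages(x, edge_index, [sum_layer, sum_layer], [relu, relu]))
# [[1.0], [1.0], [0.0]]
```

Because the activation is applied after every layer, negative intermediate values are clipped before the next round of messages, which is why node 2 ends at 0.0 even though it is reachable from node 0 within two hops.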