Decomposition
- class Decomposition(num_relations, input_dim=32, output_dim=None)
Bases: Module, ExtraReprMixin, ABC
Base module for relation-specific message passing.
A decomposition module reduces the number of parameters compared to learning an independent \(d \times d\) matrix, i.e., \(d^2\) parameters, for each relation. R-GCN proposes two such variants, basis decomposition and block-diagonal decomposition. They are treated as hyper-parameters, and which one performs best differs between datasets.
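To make the savings concrete, the sketch below counts the weight parameters of a single layer for full per-relation matrices and for R-GCN's two decompositions. It assumes square \(d \times d\) relation matrices, `num_bases` shared basis matrices with one scalar coefficient per (relation, basis) pair, and `num_blocks` equal square blocks that divide \(d\). The helper functions are hypothetical and not part of PyKEEN.

```python
def full_params(num_relations: int, dim: int) -> int:
    # One independent d x d matrix per relation.
    return num_relations * dim * dim


def basis_params(num_relations: int, dim: int, num_bases: int) -> int:
    # Shared d x d bases plus one coefficient per (relation, basis) pair.
    return num_bases * dim * dim + num_relations * num_bases


def block_params(num_relations: int, dim: int, num_blocks: int) -> int:
    # Each relation owns num_blocks square blocks of size (d / num_blocks).
    if dim % num_blocks != 0:
        raise ValueError("num_blocks must divide dim")
    block_size = dim // num_blocks
    return num_relations * num_blocks * block_size * block_size


if __name__ == "__main__":
    num_relations, dim = 100, 64
    print(full_params(num_relations, dim))                 # 409600
    print(basis_params(num_relations, dim, num_bases=4))   # 16784
    print(block_params(num_relations, dim, num_blocks=8))  # 51200
```

With these illustrative sizes, both decompositions need far fewer parameters than the full variant; how many bases or blocks to use is itself a hyper-parameter.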
The decomposition module itself does not compute the full matrix from the factors, but rather provides efficient means to compute the product of the factorized matrix with the source nodes’ latent features to construct the messages. This is usually more efficient than constructing the full matrices.
For intuition, consider a simple low-rank matrix factorization of rank 1, where \(W = w w^T\) for a \(d\)-dimensional vector \(w\). Computing \(Wv\) as \((w w^T) v\) produces an intermediate result of size \(d \times d\), whereas computing \(w(w^T v)\) produces only a scalar intermediate.
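The rank-1 intuition can be checked numerically. This small PyTorch sketch (function names are illustrative, inputs are random float64 vectors) computes \(Wv\) both ways; only the naive version materializes the \(d \times d\) matrix, yet both give the same result.

```python
import torch


def rank1_matvec_naive(w: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
    # Materializes W = w w^T, a d x d intermediate, before multiplying.
    W = torch.outer(w, w)
    return W @ v


def rank1_matvec_factorized(w: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
    # Computes w (w^T v): the only intermediate is the scalar w^T v.
    return w * torch.dot(w, v)


torch.manual_seed(0)
w = torch.randn(5, dtype=torch.float64)
v = torch.randn(5, dtype=torch.float64)
print(torch.allclose(rank1_matvec_naive(w, v), rank1_matvec_factorized(w, v)))  # True
```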
The implementations use the efficient version based on adjacency tensor stacking from [thanapalasingam2021]. The adjacency tensor is reshaped into a sparse matrix so that message passing can be performed by a single sparse matrix multiplication, cf. pykeen.nn.utils.adjacency_tensor_to_stacked_matrix().
Note
This module handles neither the self-loop nor the activation function.
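The stacking idea can be illustrated in plain PyTorch. The sketch below is not PyKEEN's implementation: `stack_adjacency` and the three message-passing functions are hypothetical names, and it assumes one undecomposed weight matrix \(W_r\) per relation, unit edge weights, sum aggregation, and the convention that rows index target nodes and columns index source nodes. Both stacked variants reproduce an explicit per-edge loop.

```python
import torch


def stack_adjacency(source, target, edge_type, num_nodes, num_relations, horizontal):
    """Hypothetical helper: build a stacked sparse adjacency (rows = targets, columns = sources)."""
    values = torch.ones(source.shape[0])
    offset = edge_type * num_nodes
    if horizontal:
        # shape: (num_nodes, num_relations * num_nodes); edge (s, r, t) -> entry [t, r * n + s]
        indices = torch.stack([target, offset + source])
        size = (num_nodes, num_relations * num_nodes)
    else:
        # shape: (num_relations * num_nodes, num_nodes); edge (s, r, t) -> entry [r * n + t, s]
        indices = torch.stack([offset + target, source])
        size = (num_relations * num_nodes, num_nodes)
    # coalesce() sums duplicate edges, matching the per-edge loop below.
    return torch.sparse_coo_tensor(indices, values, size=size).coalesce()


def message_passing_loop(x, weight, source, target, edge_type):
    # Reference: out[t] += x[s] @ W_r for every edge (s, r, t).
    messages = torch.einsum("ei,eio->eo", x[source], weight[edge_type])
    out = torch.zeros(x.shape[0], weight.shape[2])
    return out.index_add_(0, target, messages)


def message_passing_horizontal(x, weight, adj_h):
    # Transform first: block r of the stacked features holds x @ W_r ...
    transformed = torch.einsum("ni,rio->rno", x, weight).reshape(-1, weight.shape[2])
    # ... then a single sparse multiplication sums the messages per target.
    return torch.sparse.mm(adj_h, transformed)


def message_passing_vertical(x, weight, adj_v):
    # Aggregate first: one sparse multiplication yields per-relation neighbor sums ...
    num_relations, num_nodes = weight.shape[0], x.shape[0]
    aggregated = torch.sparse.mm(adj_v, x).reshape(num_relations, num_nodes, x.shape[1])
    # ... then apply W_r to each relation's sum and add over relations.
    return torch.einsum("rni,rio->no", aggregated, weight)


torch.manual_seed(0)
num_nodes, num_relations, input_dim, output_dim = 6, 3, 4, 2
source = torch.tensor([0, 1, 2, 3, 4, 5, 0])
target = torch.tensor([1, 2, 3, 4, 5, 0, 1])  # edge 0 -> 1 with relation 0 appears twice
edge_type = torch.tensor([0, 1, 2, 0, 1, 2, 0])
x = torch.randn(num_nodes, input_dim)
weight = torch.randn(num_relations, input_dim, output_dim)  # one full W_r per relation

expected = message_passing_loop(x, weight, source, target, edge_type)
adj_h = stack_adjacency(source, target, edge_type, num_nodes, num_relations, horizontal=True)
adj_v = stack_adjacency(source, target, edge_type, num_nodes, num_relations, horizontal=False)
print(torch.allclose(message_passing_horizontal(x, weight, adj_h), expected, atol=1e-4))  # True
print(torch.allclose(message_passing_vertical(x, weight, adj_v), expected, atol=1e-4))  # True
```

In this sketch, the horizontal variant materializes an intermediate of shape (num_relations * num_nodes, output_dim) and the vertical one an intermediate of shape (num_relations * num_nodes, input_dim), so which is cheaper depends on the two dimensions.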
Initialize the layer.
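- Parameters:
num_relations (int) – The number of relations.
input_dim (int) – The input dimension.
output_dim (Optional[int]) – The output dimension. If None, the input dimension is used.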
Methods Summary
forward(x, source, target, edge_type[, ...]) – Relation-specific message passing from source to target.
forward_horizontally_stacked(x, adj) – Forward pass for horizontally stacked adjacency.
forward_vertically_stacked(x, adj) – Forward pass for vertically stacked adjacency.
Iterate over components for extra_repr.
Reset the layer's parameters.
Methods Documentation
- forward(x, source, target, edge_type, edge_weights=None, accumulator=None)
Relation-specific message passing from source to target.
- Parameters:
x (FloatTensor) – shape: (num_nodes, input_dim). The node representations.
source (LongTensor) – shape: (num_edges,). The source indices.
target (LongTensor) – shape: (num_edges,). The target indices.
edge_type (LongTensor) – shape: (num_edges,). The edge types.
edge_weights (Optional[FloatTensor]) – shape: (num_edges,). Precomputed edge weights.
accumulator (Optional[FloatTensor]) – shape: (num_nodes, output_dim). A pre-allocated output accumulator, which may be used if multiple different message passing steps are performed and their results summed. If None is given, an accumulator filled with zeros is created.
- Return type:
FloatTensor
- Returns:
shape: (num_nodes, output_dim) The enriched node embeddings.
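To spell out what the parameters above mean, here is a hedged per-edge reference version of this message passing in plain PyTorch. It is not PyKEEN's implementation (which uses the stacked sparse multiplication): `reference_forward` is a hypothetical name, and it assumes an undecomposed weight tensor of shape (num_relations, input_dim, output_dim), scales each message by its edge weight, and sums messages into the target rows of the accumulator.

```python
from typing import Optional

import torch


def reference_forward(
    x: torch.Tensor,
    weight: torch.Tensor,
    source: torch.Tensor,
    target: torch.Tensor,
    edge_type: torch.Tensor,
    edge_weights: Optional[torch.Tensor] = None,
    accumulator: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Hypothetical per-edge reference; weight has shape (num_relations, input_dim, output_dim)."""
    if accumulator is None:
        # Default described above: an accumulator filled with zeros.
        accumulator = torch.zeros(x.shape[0], weight.shape[2], dtype=x.dtype)
    # Message along edge (s, r, t) is x[s] @ W_r; shape: (num_edges, output_dim).
    messages = torch.einsum("ei,eio->eo", x[source], weight[edge_type])
    if edge_weights is not None:
        # Scale each message by its precomputed edge weight.
        messages = messages * edge_weights.unsqueeze(-1)
    # Sum messages into the rows of their target nodes (in place on the accumulator).
    return accumulator.index_add_(0, target, messages)


torch.manual_seed(0)
x = torch.randn(4, 3)  # num_nodes=4, input_dim=3
weight = torch.randn(2, 3, 5)  # num_relations=2, output_dim=5
source = torch.tensor([0, 1, 2])
target = torch.tensor([1, 2, 3])
edge_type = torch.tensor([0, 1, 0])

# Two message passing steps (original and inverse direction) summed into one accumulator.
acc = reference_forward(x, weight, source, target, edge_type)
acc = reference_forward(x, weight, target, source, edge_type, accumulator=acc)
print(acc.shape)  # torch.Size([4, 5])
```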
- abstract forward_horizontally_stacked(x, adj)
Forward pass for horizontally stacked adjacency.
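- Parameters:
x (FloatTensor) – shape: (num_entities, input_dim). The input entity representations.
adj (Tensor) – shape: (num_entities, num_relations * num_entities). The horizontally stacked sparse adjacency matrix.
- Return type:
FloatTensor
- Returns:
shape: (num_entities, output_dim). The updated entity representations.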
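For a concrete decomposition, the sketch below shows how a horizontally stacked forward pass could look for a basis decomposition, \(W_r = \sum_b a_{rb} V_b\), one of the two R-GCN variants. The class name, parameter names, random initialization, and the convention that rows of adj index target entities are assumptions for illustration; this is not PyKEEN's implementation.

```python
import torch
from torch import nn


class BasisDecompositionSketch(nn.Module):
    """Hypothetical layer with W_r = sum_b a[r, b] * V_b; only the stacked forward pass is shown."""

    def __init__(self, num_relations: int, input_dim: int, output_dim: int, num_bases: int):
        super().__init__()
        self.bases = nn.Parameter(torch.randn(num_bases, input_dim, output_dim))  # V_b
        self.coefficients = nn.Parameter(torch.randn(num_relations, num_bases))  # a[r, b]

    def forward_horizontally_stacked(self, x: torch.Tensor, adj: torch.Tensor) -> torch.Tensor:
        # x: (num_entities, input_dim); adj: sparse, (num_entities, num_relations * num_entities)
        # Project onto every basis once: (num_bases, num_entities, output_dim).
        projected = torch.einsum("ni,bio->bno", x, self.bases)
        # Mix the projections per relation, i.e. x @ W_r without forming W_r.
        per_relation = torch.einsum("rb,bno->rno", self.coefficients, projected)
        # Row r * num_entities + s of the stacked matrix holds x[s] @ W_r.
        stacked = per_relation.reshape(-1, per_relation.shape[-1])
        # A single sparse multiplication sums the incoming messages of every entity.
        return torch.sparse.mm(adj, stacked)


torch.manual_seed(0)
num_entities, num_relations = 4, 3
layer = BasisDecompositionSketch(num_relations, input_dim=5, output_dim=2, num_bases=2)
x = torch.randn(num_entities, 5)
# Edges (source, relation, target) mapped to entries [target, relation * num_entities + source].
source = torch.tensor([0, 1, 3])
edge_type = torch.tensor([0, 2, 1])
target = torch.tensor([1, 1, 2])
indices = torch.stack([target, edge_type * num_entities + source])
size = (num_entities, num_relations * num_entities)
adj = torch.sparse_coo_tensor(indices, torch.ones(3), size=size).coalesce()
out = layer.forward_horizontally_stacked(x, adj)
print(out.shape)  # torch.Size([4, 2])
```

Projecting onto each basis once and then mixing per relation means the num_relations full matrices \(W_r\) are never materialized in this sketch.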