Decomposition
- class Decomposition(input_dim, num_relations, output_dim=None)[source]
Bases: torch.nn.modules.module.Module, abc.ABC
Base module for relation-specific message passing.
A decomposition module reduces the number of parameters needed for relation-specific transformations: instead of learning an independent matrix with \(d^2\) parameters for each relation, the relation matrices are constructed from a smaller set of shared factors. In R-GCN, the choice between the two proposed variants is treated as a hyper-parameter, and different decompositions perform best on different datasets.
The decomposition module itself does not compute the full matrix from the factors, but rather provides efficient means to compute the product of the factorized matrix with the source nodes’ latent features to construct the messages. This is usually more efficient than constructing the full matrices.
For an intuition, consider a simple low-rank matrix factorization of rank 1, where \(W = w w^T\) for a \(d\)-dimensional vector \(w\). Computing \(Wv\) as \((w w^T) v\) materializes an intermediate result of size \(d \times d\), whereas computing \(w (w^T v)\) requires only a scalar intermediate.
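The associativity trick above can be checked directly; this is a minimal illustration of the intuition, not library code:

```python
import torch

d = 1024
w = torch.randn(d)  # rank-1 factor vector
v = torch.randn(d)  # node feature vector

# Naive order: materialize the full d x d matrix W = w w^T, then multiply.
W = torch.outer(w, w)      # intermediate of size (d, d)
naive = W @ v

# Factorized order: w^T v is a scalar, so no d x d intermediate is built.
factorized = w * (w @ v)

assert torch.allclose(naive, factorized, atol=1e-4)
```

Both orders give the same result, but the factorized order needs O(d) memory for intermediates instead of O(d^2); decomposition modules exploit the same idea for relation-specific weight matrices.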
Initialize the layer.
- Parameters
input_dim (int) – The input dimension.
num_relations (int) – The number of relations.
output_dim (int) – The output dimension; defaults to the input dimension if None.
Methods Summary
forward(x, node_keep_mask, source, target, ...)
Relation-specific message passing from source to target.
reset_parameters()
Reset the parameters of this layer.
Methods Documentation
- abstract forward(x, node_keep_mask, source, target, edge_type, edge_weights=None)[source]
Relation-specific message passing from source to target.
- Parameters
x (FloatTensor) – shape: (num_nodes, input_dim) The node representations.
node_keep_mask (Optional[BoolTensor]) – shape: (num_nodes,) The node-keep mask for self-loop dropout.
source (LongTensor) – shape: (num_edges,) The source indices.
target (LongTensor) – shape: (num_edges,) The target indices.
edge_type (LongTensor) – shape: (num_edges,) The edge types.
edge_weights (Optional[FloatTensor]) – shape: (num_edges,) Precomputed edge weights.
- Return type
FloatTensor
- Returns
shape: (num_nodes, output_dim) The enriched node embeddings.
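To make the forward() contract concrete, here is a hypothetical, unfactorized subclass sketch: it keeps one full weight matrix per relation (so it saves no parameters, unlike a real decomposition) and omits self-loop dropout, but it follows the signature and shapes documented above. The class name `NaiveDecomposition` is an illustration, not part of the library.

```python
import torch
from torch import nn


class NaiveDecomposition(nn.Module):
    """Illustrative only: one full (input_dim, output_dim) matrix per relation.

    A real decomposition would build these matrices from shared factors and
    multiply in factorized order instead of storing them explicitly.
    """

    def __init__(self, input_dim, num_relations, output_dim=None):
        super().__init__()
        output_dim = output_dim or input_dim
        self.weights = nn.Parameter(
            0.01 * torch.randn(num_relations, input_dim, output_dim)
        )

    def forward(self, x, node_keep_mask, source, target, edge_type, edge_weights=None):
        # Transform each source node's features by its edge's relation matrix.
        messages = torch.einsum("ei,eio->eo", x[source], self.weights[edge_type])
        if edge_weights is not None:
            messages = messages * edge_weights.unsqueeze(-1)
        # Sum-aggregate the messages at their target nodes.
        out = torch.zeros(x.shape[0], messages.shape[1], dtype=x.dtype)
        out.index_add_(0, target, messages)
        # node_keep_mask (self-loop dropout) is ignored in this sketch,
        # since no self-loop messages are added here.
        return out


# Usage on a toy graph with 5 nodes, 3 edges, and 2 relation types.
dec = NaiveDecomposition(input_dim=4, num_relations=2, output_dim=3)
x = torch.randn(5, 4)
source = torch.tensor([0, 1, 2])
target = torch.tensor([1, 2, 3])
edge_type = torch.tensor([0, 1, 0])
out = dec(x, None, source, target, edge_type)
print(out.shape)  # (num_nodes, output_dim) = (5, 3)
```

Node 4 receives no incoming edges, so its output row stays zero; concrete decompositions differ only in how `self.weights[edge_type]` is realized without storing the full matrices.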