class Decomposition(input_dim, num_relations, output_dim=None)

Bases: torch.nn.modules.module.Module, abc.ABC

Base module for relation-specific message passing.

A decomposition module reduces the number of parameters compared to learning an independent \(d \times d\) matrix (i.e., \(d^2\) parameters) for each relation. R-GCN proposes two variants, basis decomposition and block-diagonal decomposition, and treats the choice between them as a hyper-parameter, since which one performs better depends on the dataset.
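For a sense of scale, the sketch below counts parameters for a full per-relation matrix and for the two R-GCN variants, using their usual formulations: basis decomposition \(W_r = \sum_{b=1}^{B} a_{rb} V_b\) with \(B\) shared bases, and block-diagonal decomposition where each \(W_r\) consists of \(B\) dense diagonal blocks. The function names and example sizes are illustrative assumptions, not part of this API.

```python
def full_params(num_relations: int, input_dim: int, output_dim: int) -> int:
    """One independent input_dim x output_dim matrix per relation."""
    return num_relations * input_dim * output_dim


def basis_params(num_relations: int, input_dim: int, output_dim: int, num_bases: int) -> int:
    """Shared basis matrices V_b plus one coefficient a_rb per (relation, basis) pair."""
    return num_bases * input_dim * output_dim + num_relations * num_bases


def block_diagonal_params(num_relations: int, input_dim: int, output_dim: int, num_blocks: int) -> int:
    """Per relation, num_blocks dense blocks placed along the diagonal."""
    if input_dim % num_blocks or output_dim % num_blocks:
        raise ValueError("input_dim and output_dim must be divisible by num_blocks")
    block_in, block_out = input_dim // num_blocks, output_dim // num_blocks
    return num_relations * num_blocks * block_in * block_out


R, d = 100, 200  # illustrative sizes
print(full_params(R, d, d))                          # 4000000
print(basis_params(R, d, d, num_bases=2))            # 80200
print(block_diagonal_params(R, d, d, num_blocks=5))  # 800000
```

The basis variant shares most parameters across relations, while the block-diagonal variant keeps parameters relation-specific but sparsifies each matrix by a factor of \(B\).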

The decomposition module itself does not compute the full matrices from the factors. Instead, it provides efficient means to multiply the factorized matrices with the source nodes’ latent features to construct the messages, which is usually more efficient than materializing the full matrices.
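As an illustration, the sketch below uses the R-GCN basis decomposition \(W_r = \sum_b a_{rb} V_b\) with a row-vector convention (the message along an edge is \(x_s^\top W_r\)). It compares gathering a materialized \(W_r\) for every edge against projecting node features onto each basis once and then mixing the projections with the relation coefficients. Tensor names and sizes are made up for the example.

```python
import torch

torch.manual_seed(0)

num_nodes, num_relations, num_bases, d_in, d_out = 5, 3, 2, 4, 6
x = torch.randn(num_nodes, d_in)                      # node features
bases = torch.randn(num_bases, d_in, d_out)           # V_b
coefficients = torch.randn(num_relations, num_bases)  # a_rb
source = torch.tensor([0, 1, 2, 3, 4, 0])
target = torch.tensor([1, 2, 3, 4, 0, 2])
edge_type = torch.tensor([0, 1, 2, 0, 1, 2])

# Naive: materialize every W_r, then gather one (d_in x d_out) matrix per edge.
full_weights = torch.einsum("rb,bio->rio", coefficients, bases)
naive = torch.einsum("ei,eio->eo", x[source], full_weights[edge_type])

# Factorized: project each node onto each basis once, then mix per edge.
projected = torch.einsum("ni,bio->nbo", x, bases)  # (num_nodes, num_bases, d_out)
factorized = torch.einsum("eb,ebo->eo", coefficients[edge_type], projected[source])

print(torch.allclose(naive, factorized, atol=1e-5))  # True
```

The factorized path never forms any \(W_r\); its largest intermediate has shape (num_nodes, num_bases, d_out) instead of (num_edges, d_in, d_out).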

For intuition, consider a simple low-rank factorization of rank 1, \(W = w w^T\) for a \(d\)-dimensional vector \(w\). Computing \(Wv\) as \((w w^T) v\) produces an intermediate matrix of size \(d \times d\), whereas computing \(w (w^T v)\) produces only a scalar intermediate.
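The rank-1 example can be checked directly in PyTorch; the sizes below are arbitrary. Both orders give the same vector, but the first costs \(O(d^2)\) memory and time, the second \(O(d)\).

```python
import torch

d = 1000
w = torch.randn(d, dtype=torch.float64)
v = torch.randn(d, dtype=torch.float64)

# (w w^T) v: materializes a d x d matrix (10^6 entries here).
outer = torch.outer(w, w)
slow = outer @ v

# w (w^T v): the only intermediate is a scalar.
scalar = torch.dot(w, v)
fast = w * scalar

print(outer.shape, scalar.shape)  # torch.Size([1000, 1000]) torch.Size([])
print(torch.allclose(slow, fast))  # True
```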

Initialize the layer.

Parameters

  • input_dim (int) – The input dimension; must be > 0.

  • num_relations (int) – The number of relations; must be > 0.

  • output_dim (Optional[int]) – The output dimension; must be > 0. If None is given, defaults to input_dim.
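Because the class derives from both torch.nn.Module and abc.ABC, it cannot be instantiated until a subclass implements both abstract methods. The sketch below mirrors the documented constructor with a hypothetical stand-in, DecompositionSketch; the explicit > 0 checks and the placeholder subclass ZeroMessages are assumptions for illustration, not the library's code.

```python
import abc
from typing import Optional

import torch
from torch import nn


class DecompositionSketch(nn.Module, abc.ABC):
    """Hypothetical stand-in mirroring the documented constructor and abstract methods."""

    def __init__(self, input_dim: int, num_relations: int, output_dim: Optional[int] = None):
        super().__init__()
        if output_dim is None:
            output_dim = input_dim  # documented default
        # Assumed validation of the documented "> 0" constraints.
        for name, value in (("input_dim", input_dim), ("num_relations", num_relations), ("output_dim", output_dim)):
            if value <= 0:
                raise ValueError(f"{name} must be > 0, got {value}")
        self.input_dim = input_dim
        self.num_relations = num_relations
        self.output_dim = output_dim

    @abc.abstractmethod
    def forward(self, x, node_keep_mask, source, target, edge_type, edge_weights=None) -> torch.Tensor:
        """Relation-specific message passing (left to subclasses)."""

    @abc.abstractmethod
    def reset_parameters(self) -> None:
        """Reset the parameters (left to subclasses)."""


class ZeroMessages(DecompositionSketch):
    """Placeholder subclass: returns zeros of the documented output shape."""

    def forward(self, x, node_keep_mask, source, target, edge_type, edge_weights=None) -> torch.Tensor:
        return x.new_zeros(x.shape[0], self.output_dim)

    def reset_parameters(self) -> None:
        pass  # no parameters to reset


layer = ZeroMessages(input_dim=8, num_relations=3)
print(layer.output_dim)  # 8, defaulted from input_dim
```

Calling DecompositionSketch(8, 3) directly raises a TypeError, because forward and reset_parameters are still abstract.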

Methods Summary

forward(x, node_keep_mask, source, target, ...)

Relation-specific message passing from source to target.

reset_parameters()

Reset the parameters of this layer.

Methods Documentation

abstract forward(x, node_keep_mask, source, target, edge_type, edge_weights=None)

Relation-specific message passing from source to target.

Parameters

  • x (FloatTensor) – shape: (num_nodes, input_dim) The node representations.

  • node_keep_mask (Optional[BoolTensor]) – shape: (num_nodes,) The node-keep mask for self-loop dropout.

  • source (LongTensor) – shape: (num_edges,) The source indices.

  • target (LongTensor) – shape: (num_edges,) The target indices.

  • edge_type (LongTensor) – shape: (num_edges,) The edge types.

  • edge_weights (Optional[FloatTensor]) – shape: (num_edges,) Precomputed edge weights.

Return type

FloatTensor

Returns

shape: (num_nodes, output_dim) The enriched node embeddings.
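A minimal sketch of a concrete forward with this signature, assuming the R-GCN basis decomposition: compute one message per edge from the source features and the edge's relation, scale it by edge_weights when given, and sum the messages at each target node. The class name BasesSketch and its initialization are hypothetical; node_keep_mask is accepted but ignored here, since how self-loop dropout is applied is not specified above.

```python
from typing import Optional

import torch
from torch import nn


class BasesSketch(nn.Module):
    """Hypothetical basis decomposition: W_r = sum_b a_rb V_b (R-GCN style)."""

    def __init__(self, input_dim: int, num_relations: int, num_bases: int, output_dim: Optional[int] = None):
        super().__init__()
        self.output_dim = input_dim if output_dim is None else output_dim
        self.bases = nn.Parameter(torch.randn(num_bases, input_dim, self.output_dim) * input_dim ** -0.5)
        self.coefficients = nn.Parameter(torch.randn(num_relations, num_bases))

    def forward(
        self,
        x: torch.Tensor,                              # (num_nodes, input_dim)
        node_keep_mask: Optional[torch.Tensor],       # (num_nodes,), ignored in this sketch
        source: torch.Tensor,                         # (num_edges,)
        target: torch.Tensor,                         # (num_edges,)
        edge_type: torch.Tensor,                      # (num_edges,)
        edge_weights: Optional[torch.Tensor] = None,  # (num_edges,)
    ) -> torch.Tensor:
        # Project every node onto every basis once: (num_nodes, num_bases, output_dim).
        projected = torch.einsum("ni,bio->nbo", x, self.bases)
        # Mix per edge with the coefficients of that edge's relation: (num_edges, output_dim).
        messages = torch.einsum("eb,ebo->eo", self.coefficients[edge_type], projected[source])
        if edge_weights is not None:
            messages = messages * edge_weights.unsqueeze(-1)
        # Sum incoming messages at each target node: (num_nodes, output_dim).
        out = x.new_zeros(x.shape[0], self.output_dim)
        return out.index_add(0, target, messages)


torch.manual_seed(0)
layer = BasesSketch(input_dim=4, num_relations=3, num_bases=2, output_dim=6)
x = torch.randn(5, 4)
source = torch.tensor([0, 1, 2, 3])
target = torch.tensor([1, 1, 2, 0])
edge_type = torch.tensor([0, 2, 1, 0])
out = layer(x, None, source, target, edge_type)
print(out.shape)  # torch.Size([5, 6])
```

Nodes 3 and 4 receive no incoming edges in this toy graph, so their rows stay zero.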

abstract reset_parameters()

Reset the parameters of this layer.
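The initialization scheme is left to each subclass. As a hypothetical illustration, reset_parameters() can re-draw the factors in place, so the same nn.Parameter objects (and any optimizer holding references to them) stay valid. The module name TinyBases and the choice of Xavier initialization are assumptions, not taken from this class.

```python
import torch
from torch import nn

torch.manual_seed(0)


class TinyBases(nn.Module):
    """Hypothetical module holding basis-decomposition factors."""

    def __init__(self, input_dim: int, num_relations: int, num_bases: int):
        super().__init__()
        # Allocate storage only; values are filled in by reset_parameters().
        self.bases = nn.Parameter(torch.empty(num_bases, input_dim, input_dim))
        self.coefficients = nn.Parameter(torch.empty(num_relations, num_bases))
        self.reset_parameters()

    def reset_parameters(self) -> None:
        # Re-draw values in place: the nn.Parameter objects themselves are kept,
        # so references held elsewhere (e.g. by an optimizer) stay valid.
        for basis in self.bases.data:
            nn.init.xavier_uniform_(basis)
        nn.init.xavier_uniform_(self.coefficients)


module = TinyBases(input_dim=4, num_relations=3, num_bases=2)
param_before = module.bases
values_before = module.bases.detach().clone()
module.reset_parameters()
print(module.bases is param_before)                       # True
print(torch.equal(values_before, module.bases.detach()))  # False
```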