Decomposition

class Decomposition(num_relations: int, input_dim: int = 32, output_dim: int | None = None)

Bases: Module, ExtraReprMixin, ABC

Base module for relation-specific message passing.

A decomposition module reduces the number of parameters compared to learning an independent \(d \times d\) matrix, i.e., \(d^2\) parameters, for each relation. R-GCN proposes two such variants, basis decomposition and block-diagonal decomposition, and treats the choice between them as a hyper-parameter, since which one performs best depends on the dataset.
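To make the savings concrete, the sketch below counts parameters for one full matrix per relation and for the two R-GCN variants. It assumes square \(d \times d\) relation matrices, \(B\) shared bases for the basis decomposition (\(W_r = \sum_b a_{rb} V_b\)), and \(B\) equally sized diagonal blocks for the block-diagonal decomposition. The function names are hypothetical and not part of PyKEEN.

```python
def full_params(num_relations: int, dim: int) -> int:
    # One independent d x d matrix per relation.
    return num_relations * dim * dim


def basis_params(num_relations: int, dim: int, num_bases: int) -> int:
    # W_r = sum_b a_rb V_b: num_bases shared d x d bases,
    # plus one scalar coefficient per (relation, basis) pair.
    return num_bases * dim * dim + num_relations * num_bases


def block_diagonal_params(num_relations: int, dim: int, num_blocks: int) -> int:
    # W_r = diag(Q_r1, ..., Q_rB): each relation owns num_blocks blocks of size (d/B) x (d/B).
    if dim % num_blocks != 0:
        raise ValueError("num_blocks must divide dim")
    block_size = dim // num_blocks
    return num_relations * num_blocks * block_size * block_size


R, d, B = 100, 32, 4
print(full_params(R, d))               # 102400
print(basis_params(R, d, B))           # 4496
print(block_diagonal_params(R, d, B))  # 25600
```

The basis decomposition shares most parameters across relations, while the block-diagonal variant keeps relations independent but restricts each matrix to \(d^2 / B\) free entries.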

The decomposition module itself does not compute the full matrix from the factors, but rather provides efficient means to compute the product of the factorized matrix with the source nodes’ latent features to construct the messages. This is usually more efficient than constructing the full matrices.

For intuition, consider a simple rank-1 matrix factorization, \(W = w w^T\) for a \(d\)-dimensional vector \(w\). Computing \(Wv\) as \((w w^T) v\) produces an intermediate result of size \(d \times d\), whereas computing \(w(w^Tv)\) produces an intermediate result that is just a scalar.
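The rank-1 argument can be checked numerically. This is a small illustrative sketch in PyTorch with arbitrary random vectors; both evaluation orders give the same product, but only the first materializes the \(d \times d\) matrix.

```python
import torch

d = 4
torch.manual_seed(0)
w = torch.randn(d, 1, dtype=torch.float64)  # column vector w
v = torch.randn(d, 1, dtype=torch.float64)  # input vector v

# Naive order: materialize W = w w^T (a d x d intermediate), then multiply by v.
W = w @ w.T
naive = W @ v

# Factorized order: w^T v is a 1 x 1 intermediate (a scalar), which then scales w.
scalar = w.T @ v
factorized = w @ scalar

print(W.shape, scalar.shape)                # torch.Size([4, 4]) torch.Size([1, 1])
print(torch.allclose(naive, factorized))    # True
```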

The implementations use the efficient version based on adjacency tensor stacking from [thanapalasingam2021]. The adjacency tensor is reshaped into a sparse matrix to support message passing by a single sparse matrix multiplication, cf. pykeen.nn.utils.adjacency_tensor_to_stacked_matrix().
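A minimal sketch of the stacking idea, assuming edges are given as parallel source, target, and edge-type index tensors, and assuming that rows of the stacked matrix index target nodes. This convention is an illustration and is not necessarily the one used by pykeen.nn.utils.adjacency_tensor_to_stacked_matrix(); the helper stack_adjacency is hypothetical.

```python
import torch


def stack_adjacency(
    num_entities: int,
    num_relations: int,
    source: torch.Tensor,
    target: torch.Tensor,
    edge_type: torch.Tensor,
    horizontal: bool,
) -> torch.Tensor:
    """Flatten the (relations x entities x entities) adjacency tensor into one sparse matrix.

    Hypothetical helper; the assumed convention is that rows index target nodes.
    """
    # Each relation occupies its own block of num_entities rows or columns.
    offset = edge_type * num_entities
    if horizontal:
        # Shape (N, R * N): entry [t, r * N + s] marks the edge (s, r, t).
        indices = torch.stack([target, offset + source])
        size = (num_entities, num_relations * num_entities)
    else:
        # Shape (R * N, N): entry [r * N + t, s] marks the edge (s, r, t).
        indices = torch.stack([offset + target, source])
        size = (num_relations * num_entities, num_entities)
    values = torch.ones(source.shape[0])
    return torch.sparse_coo_tensor(indices, values, size)


source = torch.tensor([0, 1, 2])
target = torch.tensor([1, 2, 0])
edge_type = torch.tensor([0, 1, 1])
h = stack_adjacency(3, 2, source, target, edge_type, horizontal=True)
v = stack_adjacency(3, 2, source, target, edge_type, horizontal=False)
print(h.shape, v.shape)  # torch.Size([3, 6]) torch.Size([6, 3])
```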

Note

This module handles neither the self-loop nor the application of an activation function.

Initialize the layer.

Parameters:
  • num_relations (int) – The number of relations. Must be > 0.

  • input_dim (int) – The input dimension. Must be > 0.

  • output_dim (int | None) – The output dimension. Must be > 0. If None is given, defaults to input_dim.

Methods Summary

forward(x, source, target, edge_type[, ...])

Relation-specific message passing from source to target.

forward_horizontally_stacked(x, adj)

Forward pass for horizontally stacked adjacency.

forward_vertically_stacked(x, adj)

Forward pass for vertically stacked adjacency.

iter_extra_repr()

Iterate over components for extra_repr.

reset_parameters()

Reset the layer's parameters.

Methods Documentation

forward(x: Tensor, source: Tensor, target: Tensor, edge_type: Tensor, edge_weights: Tensor | None = None, accumulator: Tensor | None = None) → Tensor

Relation-specific message passing from source to target.

Parameters:
  • x (Tensor) – shape: (num_nodes, input_dim) The node representations.

  • source (Tensor) – shape: (num_edges,) The source indices.

  • target (Tensor) – shape: (num_edges,) The target indices.

  • edge_type (Tensor) – shape: (num_edges,) The edge types.

  • edge_weights (Tensor | None) – shape: (num_edges,) Precomputed edge weights.

  • accumulator (Tensor | None) – shape: (num_nodes, output_dim) A pre-allocated output accumulator, useful when the results of several message passing steps are summed. If None is given, an accumulator filled with zeros is created.

Returns:

shape: (num_nodes, output_dim) The enriched node embeddings.

Return type:

Tensor
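As a reference for what forward computes, here is a naive sketch that assumes each relation has one full weight matrix and that the message along an edge (s, r, t) is \(x_s W_r\), optionally scaled by the edge weight and summed at the target. The function naive_forward is hypothetical; it materializes one matrix per edge, which is exactly the cost a decomposition is meant to avoid.

```python
from typing import Optional

import torch


def naive_forward(
    x: torch.Tensor,          # (num_nodes, input_dim)
    weight: torch.Tensor,     # (num_relations, input_dim, output_dim), one full matrix per relation
    source: torch.Tensor,     # (num_edges,)
    target: torch.Tensor,     # (num_edges,)
    edge_type: torch.Tensor,  # (num_edges,)
    edge_weights: Optional[torch.Tensor] = None,  # (num_edges,)
    accumulator: Optional[torch.Tensor] = None,   # (num_nodes, output_dim)
) -> torch.Tensor:
    # Per-edge message x_s @ W_r; weight[edge_type] has shape (num_edges, input_dim, output_dim).
    messages = torch.einsum("ei,eio->eo", x[source], weight[edge_type])
    if edge_weights is not None:
        messages = messages * edge_weights.unsqueeze(-1)
    if accumulator is None:
        accumulator = torch.zeros(x.shape[0], weight.shape[-1], dtype=x.dtype)
    # Sum messages per target node; index_add is out-of-place, so the caller's accumulator is untouched.
    return accumulator.index_add(0, target, messages)


torch.manual_seed(0)
x = torch.randn(3, 4)
weight = torch.randn(2, 4, 5)
source = torch.tensor([0, 1, 2])
target = torch.tensor([1, 2, 0])
edge_type = torch.tensor([0, 1, 1])
out = naive_forward(x, weight, source, target, edge_type)
print(out.shape)  # torch.Size([3, 5])
```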

abstract forward_horizontally_stacked(x: Tensor, adj: Tensor) → Tensor

Forward pass for horizontally stacked adjacency.

Parameters:
  • x (Tensor) – shape: (num_entities, input_dim) The input entity representations.

  • adj (Tensor) – shape: (num_entities, num_relations * num_entities), sparse The horizontally stacked adjacency matrix.

Returns:

shape: (num_entities, output_dim) The updated entity representations.

Return type:

Tensor
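A sketch of how a horizontally stacked adjacency can be consumed: transform first, then aggregate with one sparse-dense product. It assumes the convention that entry [t, r * N + s] marks the edge (s, r, t), and it uses full per-relation matrices for illustration only; a concrete subclass would compute the transformed block from its factorized parameters instead.

```python
import torch

torch.manual_seed(0)
N, R, d_in, d_out = 3, 2, 4, 5
x = torch.randn(N, d_in)
weight = torch.randn(R, d_in, d_out)  # full per-relation matrices, illustration only
source = torch.tensor([0, 1, 2])
target = torch.tensor([1, 2, 0])
edge_type = torch.tensor([0, 1, 1])

# Horizontally stacked adjacency, shape (N, R * N); assumed convention: entry [t, r * N + s] = 1.
adj = torch.sparse_coo_tensor(
    torch.stack([target, edge_type * N + source]),
    torch.ones(source.shape[0]),
    (N, R * N),
)

# Transform first: row r * N + s holds x_s @ W_r, giving shape (R * N, d_out).
transformed = torch.einsum("ni,rio->rno", x, weight).reshape(R * N, d_out)
# Then aggregate with a single sparse-dense product: (N, R * N) @ (R * N, d_out).
out = torch.sparse.mm(adj, transformed)
print(out.shape)  # torch.Size([3, 5])
```

The dense intermediate here has R * N * output_dim entries, so this ordering is a natural fit when output_dim is smaller than input_dim.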

abstract forward_vertically_stacked(x: Tensor, adj: Tensor) → Tensor

Forward pass for vertically stacked adjacency.

Parameters:
  • x (Tensor) – shape: (num_entities, input_dim) The input entity representations.

  • adj (Tensor) – shape: (num_entities * num_relations, num_entities), sparse The vertically stacked adjacency matrix.

Returns:

shape: (num_entities, output_dim) The updated entity representations.

Return type:

Tensor
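The vertical counterpart aggregates first and transforms afterwards. This sketch assumes the convention that entry [r * N + t, s] marks the edge (s, r, t) and again uses full per-relation matrices purely for illustration.

```python
import torch

torch.manual_seed(0)
N, R, d_in, d_out = 3, 2, 4, 5
x = torch.randn(N, d_in)
weight = torch.randn(R, d_in, d_out)  # full per-relation matrices, illustration only
source = torch.tensor([0, 1, 2])
target = torch.tensor([1, 2, 0])
edge_type = torch.tensor([0, 1, 1])

# Vertically stacked adjacency, shape (R * N, N); assumed convention: entry [r * N + t, s] = 1.
adj = torch.sparse_coo_tensor(
    torch.stack([edge_type * N + target, source]),
    torch.ones(source.shape[0]),
    (R * N, N),
)

# Aggregate first: (R * N, N) @ (N, d_in) gives per-relation neighbor sums, viewed as (R, N, d_in).
aggregated = torch.sparse.mm(adj, x).reshape(R, N, d_in)
# Then transform and sum over relations: out[t] = sum_r aggregated[r, t] @ W_r.
out = torch.einsum("rni,rio->no", aggregated, weight)
print(out.shape)  # torch.Size([3, 5])
```

Here the dense intermediate has R * N * input_dim entries, so this ordering suits the case where input_dim is the smaller dimension.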

iter_extra_repr() → Iterable[str]

Iterate over components for extra_repr.

Return type:

Iterable[str]

reset_parameters()

Reset the layer’s parameters.