BlockDecomposition

class BlockDecomposition(num_blocks: int | None = None, **kwargs)

Bases: Decomposition

Represent relation-specific weight matrices via block-diagonal matrices.

The block-diagonal decomposition restricts each relation-specific transformation matrix to be block-diagonal, i.e.,

\[\mathbf{W}_r^l = \operatorname{diag}(\mathbf{B}_{r,1}^l, \ldots, \mathbf{B}_{r,B}^l)\]

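where \(\mathbf{B}_{r,i}^l \in \mathbb{R}^{(d^{(l)}/B) \times (d^{(l)}/B)}\) for \(i = 1, \ldots, B\), and \(B\) is the number of blocks. The dimension \(d^{(l)}\) must therefore be divisible by \(B\), and each relation needs only \((d^{(l)})^2 / B\) parameters instead of \((d^{(l)})^2\).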
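To make the block structure concrete, the following sketch assembles \(\mathbf{W}_r^l\) from \(B\) blocks and checks that multiplying by the full matrix equals multiplying each \(d/B\)-sized chunk of the input by its own block. It uses NumPy in place of PyTorch, assumes the row-vector convention x @ W, and the helpers block_diag and blockwise_matmul are hypothetical illustrations, not part of PyKEEN.

```python
import numpy as np


def block_diag(blocks: np.ndarray) -> np.ndarray:
    """Assemble a (B, b, b) stack of blocks into a (B*b, B*b) block-diagonal matrix."""
    num_blocks, block_size, _ = blocks.shape
    dim = num_blocks * block_size
    w = np.zeros((dim, dim), dtype=blocks.dtype)
    for i in range(num_blocks):
        s = slice(i * block_size, (i + 1) * block_size)
        w[s, s] = blocks[i]  # everything off the diagonal blocks stays zero
    return w


def blockwise_matmul(x: np.ndarray, blocks: np.ndarray) -> np.ndarray:
    """Compute x @ block_diag(blocks) without materializing the full matrix."""
    num_blocks, block_size, _ = blocks.shape
    # split the feature dimension into B chunks of size d / B: (n, B, b)
    x_chunks = x.reshape(x.shape[0], num_blocks, block_size)
    # multiply chunk i by block i: (n, B, b) x (B, b, b) -> (n, B, b)
    out = np.einsum("nbi,bij->nbj", x_chunks, blocks)
    return out.reshape(x.shape[0], num_blocks * block_size)


rng = np.random.default_rng(0)
d, B = 6, 3                      # d must be divisible by B
blocks = rng.normal(size=(B, d // B, d // B))
x = rng.normal(size=(4, d))      # 4 entities with d features each

W = block_diag(blocks)
assert np.allclose(x @ W, blockwise_matmul(x, blocks))
# A full d x d matrix has d**2 entries; the blocks hold only d**2 / B of them.
print(W.size, blocks.size)  # 36 12
```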

The implementation follows the efficient version of [thanapalasingam2021], which reshapes the three-dimensional adjacency tensor into a sparse matrix, by stacking the relation-specific adjacency matrices either horizontally or vertically, so that message passing requires only a single sparse matrix multiplication.
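A minimal sketch of the two stacked layouts, built from (head, relation, tail) triples. It assumes a hypothetical helper stacked_adjacencies, dense NumPy arrays in place of sparse tensors, relation-major ordering of the stacked axis (index r * num_entities + e), and the convention that a triple (h, r, t) sends a message from h to t; this page does not confirm any of these choices.

```python
import numpy as np


def stacked_adjacencies(triples: np.ndarray, num_entities: int, num_relations: int):
    """Build horizontally and vertically stacked adjacency matrices (hypothetical helper).

    Assumed convention: triple (h, r, t) sets A_r[t, h] = 1.
    """
    h, r, t = triples[:, 0], triples[:, 1], triples[:, 2]
    n = num_entities
    # horizontal stack [A_1 | ... | A_R]: shape (n, R * n), column index r * n + h
    horizontal = np.zeros((n, num_relations * n))
    horizontal[t, r * n + h] = 1.0
    # vertical stack [A_1; ...; A_R]: shape (R * n, n), row index r * n + t
    vertical = np.zeros((num_relations * n, n))
    vertical[r * n + t, h] = 1.0
    return horizontal, vertical


n, R = 3, 2
triples = np.array([[0, 0, 1], [1, 1, 2], [2, 0, 0]])  # (head, relation, tail)
A_h, A_v = stacked_adjacencies(triples, n, R)

# both layouts hold the same per-relation matrices A_r, only arranged differently
A_from_v = A_v.reshape(R, n, n)
A_from_h = A_h.reshape(n, R, n).transpose(1, 0, 2)
assert np.array_equal(A_from_v, A_from_h)

# one matmul with the vertical stack aggregates neighbors for every relation at once
x = np.arange(n * 2, dtype=float).reshape(n, 2)
per_relation = (A_v @ x).reshape(R, n, 2)
assert np.allclose(per_relation, np.stack([A_from_v[k] @ x for k in range(R)]))
```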

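Initialize the layer.

Parameters:
  • num_blocks (int | None) – The number of blocks B.

  • kwargs – Additional keyword-based parameters passed to Decomposition.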

Methods Summary

  • forward_horizontally_stacked(x, adj) – Forward pass for horizontally stacked adjacency.

  • forward_vertically_stacked(x, adj) – Forward pass for vertically stacked adjacency.

  • iter_extra_repr() – Iterate over components for extra_repr.

  • reset_parameters() – Reset the layer's parameters.

Methods Documentation

forward_horizontally_stacked(x: Tensor, adj: Tensor) → Tensor

Forward pass for horizontally stacked adjacency.

Parameters:
  • x (Tensor) – shape: (num_entities, input_dim). The input entity representations.

  • adj (Tensor) – shape: (num_entities, num_relations * num_entities), sparse. The horizontally stacked adjacency matrix.

Returns:

shape: (num_entities, output_dim). The updated entity representations.

Return type:

Tensor
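The following sketch shows one way a forward pass over the horizontally stacked adjacency can combine with block-diagonal weights: transform the input with every relation's blocks first, stack the R results vertically, then aggregate with one matrix product. The function forward_horizontal, the (R, B, d/B, d/B) layout of blocks, and the use of dense NumPy arrays are assumptions for illustration, not PyKEEN's implementation; the result is checked against an explicit loop over relations.

```python
import numpy as np


def forward_horizontal(x: np.ndarray, adj_h: np.ndarray, blocks: np.ndarray) -> np.ndarray:
    """Block-diagonal message passing with a horizontally stacked adjacency (sketch).

    x:      (n, d) entity representations
    adj_h:  (n, R * n) horizontally stacked adjacency [A_1 | ... | A_R]
    blocks: (R, B, d // B, d // B) relation-specific diagonal blocks (hypothetical layout)
    """
    n = x.shape[0]
    R, B, b, _ = blocks.shape
    # transform x with every relation's blocks, chunk by chunk: (R, n, B, b)
    xw = np.einsum("nbi,rbij->rnbj", x.reshape(n, B, b), blocks)
    # flatten to (R * n, d): row r * n + j holds (x W_r)[j], matching column
    # r * n + j of adj_h, so a single matmul also sums over relations
    return adj_h @ xw.reshape(R * n, B * b)


def dense_block_diag(blocks_r: np.ndarray) -> np.ndarray:
    """Materialize one relation's (B, b, b) blocks as a full (B*b, B*b) matrix."""
    B, b, _ = blocks_r.shape
    w = np.zeros((B * b, B * b))
    for i in range(B):
        w[i * b:(i + 1) * b, i * b:(i + 1) * b] = blocks_r[i]
    return w


rng = np.random.default_rng(1)
n, R, d, B = 5, 3, 6, 2
A = (rng.random((R, n, n)) < 0.3).astype(float)   # random per-relation adjacency A_r
adj_h = A.transpose(1, 0, 2).reshape(n, R * n)     # [A_1 | ... | A_R]
x = rng.normal(size=(n, d))
blocks = rng.normal(size=(R, B, d // B, d // B))

# reference: explicit loop over relations with materialized W_r
reference = sum(A[k] @ x @ dense_block_diag(blocks[k]) for k in range(R))
out = forward_horizontal(x, adj_h, blocks)
assert out.shape == (n, d)
assert np.allclose(out, reference)
```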

forward_vertically_stacked(x: Tensor, adj: Tensor) → Tensor

Forward pass for vertically stacked adjacency.

Parameters:
  • x (Tensor) – shape: (num_entities, input_dim). The input entity representations.

  • adj (Tensor) – shape: (num_entities * num_relations, num_entities), sparse. The vertically stacked adjacency matrix.

Returns:

shape: (num_entities, output_dim). The updated entity representations.

Return type:

Tensor
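The vertically stacked variant reverses the order: one matrix product aggregates the neighbors of every entity for all relations at once, and the relation-specific blocks are applied afterwards. As before, forward_vertical, the relation-major row order of adj_v, the (R, B, d/B, d/B) block layout, and the dense NumPy arrays are illustrative assumptions rather than PyKEEN's code.

```python
import numpy as np


def forward_vertical(x: np.ndarray, adj_v: np.ndarray, blocks: np.ndarray) -> np.ndarray:
    """Block-diagonal message passing with a vertically stacked adjacency (sketch).

    x:      (n, d) entity representations
    adj_v:  (R * n, n) vertically stacked adjacency [A_1; ...; A_R]
    blocks: (R, B, d // B, d // B) relation-specific diagonal blocks (hypothetical layout)
    """
    n = x.shape[0]
    R, B, b, _ = blocks.shape
    # one matmul aggregates neighbors for all relations: row r * n + i is (A_r x)[i]
    ax = (adj_v @ x).reshape(R, n, B, b)
    # apply relation r's blocks to its aggregate and sum over r: (n, B, b)
    out = np.einsum("rnbi,rbij->nbj", ax, blocks)
    return out.reshape(n, B * b)


def dense_block_diag(blocks_r: np.ndarray) -> np.ndarray:
    """Materialize one relation's (B, b, b) blocks as a full (B*b, B*b) matrix."""
    B, b, _ = blocks_r.shape
    w = np.zeros((B * b, B * b))
    for i in range(B):
        w[i * b:(i + 1) * b, i * b:(i + 1) * b] = blocks_r[i]
    return w


rng = np.random.default_rng(2)
n, R, d, B = 5, 3, 6, 3
A = (rng.random((R, n, n)) < 0.3).astype(float)   # random per-relation adjacency A_r
adj_v = A.reshape(R * n, n)                         # [A_1; ...; A_R]
x = rng.normal(size=(n, d))
blocks = rng.normal(size=(R, B, d // B, d // B))

reference = sum(A[k] @ x @ dense_block_diag(blocks[k]) for k in range(R))
out = forward_vertical(x, adj_v, blocks)
assert out.shape == (n, d)
assert np.allclose(out, reference)
```

Both stacking orders compute the same sum over relations of A_r x W_r; they differ only in whether the relation-specific transformation happens before or after the single sparse multiplication.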

iter_extra_repr() → Iterable[str]

Iterate over components for extra_repr.

Return type:

Iterable[str]
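The components yielded here presumably feed PyTorch's extra_repr() hook, which torch.nn.Module calls when printing a module. The plain-Python sketch below illustrates why a generator is convenient: a subclass can extend its parent's components with yield from. The class names are hypothetical, and joining the components with ", " is an assumption about the formatting.

```python
from typing import Iterable


class LayerBase:
    """Hypothetical stand-in for a module base class."""

    def __init__(self, input_dim: int):
        self.input_dim = input_dim

    def iter_extra_repr(self) -> Iterable[str]:
        # each component is one "key=value" piece of the representation
        yield f"input_dim={self.input_dim}"

    def extra_repr(self) -> str:
        # joining with ", " is an assumed formatting choice
        return ", ".join(self.iter_extra_repr())


class BlockLayer(LayerBase):
    """Hypothetical subclass that adds its own component."""

    def __init__(self, input_dim: int, num_blocks: int):
        super().__init__(input_dim)
        self.num_blocks = num_blocks

    def iter_extra_repr(self) -> Iterable[str]:
        # reuse the parent's components, then append the subclass-specific one
        yield from super().iter_extra_repr()
        yield f"num_blocks={self.num_blocks}"


print(BlockLayer(input_dim=6, num_blocks=3).extra_repr())  # input_dim=6, num_blocks=3
```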

reset_parameters()

Reset the layer’s parameters.