BlockDecomposition

class BlockDecomposition(num_blocks=None, **kwargs)

Bases: Decomposition

Represent relation-specific weight matrices via block-diagonal matrices.

The block-diagonal decomposition restricts each relation-specific transformation matrix to be a block-diagonal matrix, i.e.,

\[\mathbf{W}_r^l = \operatorname{diag}(\mathbf{B}_{r,1}^l, \ldots, \mathbf{B}_{r,B}^l)\]

where \(\mathbf{B}_{r,i}^l \in \mathbb{R}^{(d^{(l)} / B) \times (d^{(l)} / B)}\) and \(B\) is the number of blocks.
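To see why the block-diagonal structure is cheap to apply, the following minimal sketch builds one dense \(\mathbf{W}_r^l\) with `torch.block_diag` and checks that multiplying by it equals splitting each input row into \(B\) chunks and multiplying each chunk by its own block. It assumes the row-vector convention \(x \mathbf{W}_r\) and illustrative sizes; the variable names are hypothetical.

```python
import torch

torch.manual_seed(0)

d, num_blocks = 6, 3            # illustrative sizes; d must be divisible by num_blocks
block_size = d // num_blocks

# Blocks B_{r,1}, ..., B_{r,B} of a single relation r, each of shape (d/B, d/B)
blocks = torch.randn(num_blocks, block_size, block_size)

# Dense block-diagonal matrix W_r = diag(B_{r,1}, ..., B_{r,B}), shape (d, d)
w = torch.block_diag(*blocks)

x = torch.randn(4, d)           # 4 toy entity representations (row vectors)

# Reference: multiply by the full, mostly zero, d x d matrix
y_dense = x @ w

# Blockwise: split each row into B chunks of size d/B and multiply chunk i by block i
y_blocks = torch.einsum(
    "nbi,bij->nbj", x.reshape(4, num_blocks, block_size), blocks
).reshape(4, d)
```

The blockwise form never materialises the zero off-diagonal blocks, which is where the savings in memory and compute come from.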

The implementation is based on the efficient version of [thanapalasingam2021], which reshapes the adjacency tensor into a single sparse matrix (stacked either horizontally or vertically) so that message passing requires only one sparse matrix multiplication.
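The reshaping can be illustrated on a toy graph. This sketch assumes that a triple (h, r, t) sets \(A_r[h, t] = 1\) (messages flow from tail to head) and that relations are stacked in relation-major order; both are assumptions for illustration, and the triples are hypothetical. It shows that one sparse product with either stacked matrix reproduces the per-relation products \(A_r x\).

```python
import torch

num_entities, num_relations, dim = 4, 2, 3

# Hypothetical toy triples (head, relation, tail)
triples = torch.tensor([[0, 0, 1], [1, 0, 2], [2, 1, 3], [3, 1, 0], [0, 1, 2]])
h, r, t = triples.t()
values = torch.ones(triples.shape[0])

# Horizontal stacking [A_1 | ... | A_R]: relation r occupies columns r*n .. (r+1)*n - 1
adj_h = torch.sparse_coo_tensor(
    torch.stack([h, r * num_entities + t]), values, (num_entities, num_relations * num_entities)
)
# Vertical stacking [A_1; ...; A_R]: relation r occupies rows r*n .. (r+1)*n - 1
adj_v = torch.sparse_coo_tensor(
    torch.stack([r * num_entities + h, t]), values, (num_relations * num_entities, num_entities)
)

x = torch.randn(num_entities, dim)

# Reference: dense per-relation products A_r x, shape (R, n, dim)
dense = torch.zeros(num_relations, num_entities, num_entities)
dense[r, h, t] = 1.0
per_relation = torch.einsum("rht,td->rhd", dense, x)

# Vertical: one sparse matmul yields all A_r x at once, stacked relation-major
out_v = torch.sparse.mm(adj_v, x).reshape(num_relations, num_entities, dim)

# Horizontal: one sparse matmul against relation-major stacked inputs yields sum_r A_r x_r
# (every x_r = x here, because no relation-specific transform is applied yet)
out_h = torch.sparse.mm(adj_h, x.repeat(num_relations, 1))
```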

Initialize the layer.

Parameters:
  • num_blocks (Optional[int]) – the number of blocks.

  • kwargs – keyword-based parameters passed to Decomposition.__init__().
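The formula above implies that each \(\mathbf{W}_r^l\) holds \(B\) blocks of size \((d^{(l)}/B) \times (d^{(l)}/B)\), i.e. \((d^{(l)})^2 / B\) weights instead of \((d^{(l)})^2\) for a full matrix. A small sketch of that count, assuming equal square blocks so that the dimension must be divisible by num_blocks; both helper names are hypothetical and not part of the class.

```python
def block_decomposition_param_count(dim: int, num_blocks: int, num_relations: int) -> int:
    """Weights of R block-diagonal (dim x dim) matrices with num_blocks equal square blocks."""
    block_size, remainder = divmod(dim, num_blocks)
    if remainder:
        # Equal square blocks only tile the diagonal exactly when num_blocks divides dim
        raise ValueError(f"num_blocks={num_blocks} must divide dim={dim}")
    return num_relations * num_blocks * block_size**2  # = R * dim**2 / num_blocks


def full_param_count(dim: int, num_relations: int) -> int:
    """Weights of R unconstrained (dim x dim) matrices."""
    return num_relations * dim**2


print(full_param_count(64, 10))                    # 40960
print(block_decomposition_param_count(64, 4, 10))  # 10240
```

With 4 blocks, the relation-specific weights shrink by a factor of 4.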

Methods Summary

forward_horizontally_stacked(x, adj)

Forward pass for horizontally stacked adjacency.

forward_vertically_stacked(x, adj)

Forward pass for vertically stacked adjacency.

iter_extra_repr()

Iterate over components for extra_repr.

reset_parameters()

Reset the layer's parameters.

Methods Documentation

forward_horizontally_stacked(x, adj)

Forward pass for horizontally stacked adjacency.

Parameters:
  • x (Tensor) – shape: (num_entities, input_dim). The input entity representations.

  • adj (Tensor) – shape: (num_entities, num_relations * num_entities), sparse. The horizontally stacked adjacency matrix.

Return type:

Tensor

Returns:

shape: (num_entities, output_dim). The updated entity representations.
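A hedged sketch of what a horizontally stacked forward pass can look like: transform first, computing \(x \mathbf{W}_r\) blockwise for every relation, stack the results relation-major, then aggregate with one sparse product. This is a hypothetical re-implementation for illustration, not PyKEEN's code; it assumes blocks of shape (R, B, b, b), equal input and output dimensions, and relation-major column order in adj.

```python
import torch


def forward_horizontally_stacked_sketch(x, adj, blocks):
    """Hypothetical illustration, not PyKEEN's implementation.

    x: (n, d_in) entity representations
    adj: sparse (n, R * n), horizontally stacked [A_1 | ... | A_R]
    blocks: (R, B, b, b) with d_in = d_out = B * b
    """
    num_relations, num_blocks, block_size, _ = blocks.shape
    n = x.shape[0]
    # Split each input row into B chunks of size b: (n, B, b)
    x_chunks = x.reshape(n, num_blocks, block_size)
    # Transform first: x W_r for every relation, applied blockwise -> (R, n, B, b)
    msgs = torch.einsum("nbi,rbij->rnbj", x_chunks, blocks)
    # Stack relation-major to (R * n, d_out), matching the column layout of adj
    msgs = msgs.reshape(num_relations * n, num_blocks * block_size)
    # Then a single sparse matmul computes sum_r A_r (x W_r): (n, d_out)
    return torch.sparse.mm(adj, msgs)


# Usage on random toy data, checked against a dense per-relation loop
torch.manual_seed(0)
n, R, B, b = 5, 3, 2, 4
x = torch.randn(n, B * b)
blocks = torch.randn(R, B, b, b)
dense = (torch.rand(R, n, n) < 0.3).float()        # per-relation adjacency A_r
adj_h = torch.cat(list(dense), dim=1).to_sparse()  # (n, R * n)

out = forward_horizontally_stacked_sketch(x, adj_h, blocks)
ref = sum(dense[r] @ (x @ torch.block_diag(*blocks[r])) for r in range(R))
```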

forward_vertically_stacked(x, adj)

Forward pass for vertically stacked adjacency.

Parameters:
  • x (Tensor) – shape: (num_entities, input_dim). The input entity representations.

  • adj (Tensor) – shape: (num_entities * num_relations, num_entities), sparse. The vertically stacked adjacency matrix.

Return type:

Tensor

Returns:

shape: (num_entities, output_dim). The updated entity representations.
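The vertically stacked variant swaps the order of operations: aggregate with one sparse product first, then apply each \(\mathbf{W}_r\) blockwise and sum over relations. As before, this is a hypothetical sketch rather than PyKEEN's code, assuming blocks of shape (R, B, b, b), equal input and output dimensions, and relation-major row order in adj.

```python
import torch


def forward_vertically_stacked_sketch(x, adj, blocks):
    """Hypothetical illustration, not PyKEEN's implementation.

    x: (n, d_in) entity representations
    adj: sparse (R * n, n), vertically stacked [A_1; ...; A_R]
    blocks: (R, B, b, b) with d_in = d_out = B * b
    """
    num_relations, num_blocks, block_size, _ = blocks.shape
    n = x.shape[0]
    # Aggregate first: one sparse matmul gives [A_1 x; ...; A_R x], shape (R * n, d_in)
    agg = torch.sparse.mm(adj, x)
    # Rows are relation-major; split into relations and input chunks: (R, n, B, b)
    agg = agg.reshape(num_relations, n, num_blocks, block_size)
    # Apply each W_r blockwise and sum over relations r: (n, B, b)
    out = torch.einsum("rnbi,rbij->nbj", agg, blocks)
    return out.reshape(n, num_blocks * block_size)


# Usage on random toy data, checked against a dense per-relation loop
torch.manual_seed(0)
n, R, B, b = 5, 3, 2, 4
x = torch.randn(n, B * b)
blocks = torch.randn(R, B, b, b)
dense = (torch.rand(R, n, n) < 0.3).float()   # per-relation adjacency A_r
adj_v = dense.reshape(R * n, n).to_sparse()   # (R * n, n)

out = forward_vertically_stacked_sketch(x, adj_v, blocks)
ref = sum((dense[r] @ x) @ torch.block_diag(*blocks[r]) for r in range(R))
```

Both orderings compute \(\sum_r A_r x \mathbf{W}_r\) by associativity; in these sketches each materialises an intermediate with R · n rows, only at a different step.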

iter_extra_repr()

Iterate over components for extra_repr.

Return type:

Iterable[str]
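This reference only states that iter_extra_repr yields string components for extra_repr. A minimal sketch of how such an iterator can feed torch.nn.Module.extra_repr, assuming the components are joined with ", "; the class ToyBlockLayer and its fields are hypothetical.

```python
from typing import Iterable

import torch
from torch import nn


class ToyBlockLayer(nn.Module):
    """Hypothetical layer illustrating an iter_extra_repr -> extra_repr pattern."""

    def __init__(self, dim: int, num_blocks: int):
        super().__init__()
        self.num_blocks = num_blocks
        self.blocks = nn.Parameter(torch.empty(num_blocks, dim // num_blocks, dim // num_blocks))

    def iter_extra_repr(self) -> Iterable[str]:
        # Yield one "key=value" component at a time
        yield f"num_blocks={self.num_blocks}"
        yield f"block_size={self.blocks.shape[-1]}"

    def extra_repr(self) -> str:
        # nn.Module.__repr__ calls extra_repr(); join the components into one line
        return ", ".join(self.iter_extra_repr())


print(ToyBlockLayer(dim=8, num_blocks=2))  # ToyBlockLayer(num_blocks=2, block_size=4)
```

Yielding components separately makes it easy for subclasses to extend the representation without re-implementing the string formatting.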

reset_parameters()

Reset the layer’s parameters.