BasesDecomposition

class BasesDecomposition(num_bases=None, **kwargs)

Bases: Decomposition

Represent relation-weights as a linear combination of base transformation matrices.

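The basis decomposition represents each relation-specific transformation matrix as a weighted combination of \(B\) base matrices \(\{\mathbf{B}_b^l\}_{b=1}^{B}\) that are shared across all relations, i.e.,

\[\mathbf{W}_r^l = \sum\limits_{b=1}^{B} \alpha_{rb} \mathbf{B}^l_b\]

where \(\alpha_{rb}\) is the scalar weight of base \(b\) for relation \(r\), and \(l\) indexes the layer.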
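The composition of relation matrices from bases can be written as a single tensor contraction. The following is a minimal NumPy sketch of the formula only, not PyKEEN's code; it assumes the coefficients are stored as an array of shape (num_relations, num_bases) and the bases as (num_bases, input_dim, output_dim), and the function name `compose_relation_weights` is hypothetical.

```python
import numpy as np


def compose_relation_weights(alpha: np.ndarray, bases: np.ndarray) -> np.ndarray:
    """Hypothetical helper: build W_r = sum_b alpha[r, b] * bases[b] for every relation r.

    alpha: shape (num_relations, num_bases), the coefficients alpha_{rb} (assumed layout)
    bases: shape (num_bases, input_dim, output_dim), the base matrices B_b (assumed layout)
    returns: shape (num_relations, input_dim, output_dim)
    """
    # Contract over the base index b; i/o are the input/output feature dimensions.
    return np.einsum("rb,bio->rio", alpha, bases)


rng = np.random.default_rng(0)
num_relations, num_bases, input_dim, output_dim = 5, 2, 3, 4
alpha = rng.normal(size=(num_relations, num_bases))
bases = rng.normal(size=(num_bases, input_dim, output_dim))
W = compose_relation_weights(alpha, bases)
print(W.shape)  # (5, 3, 4)
```

Only `num_bases` base matrices are learned; every `W[r]` is derived from them, which is where the parameter savings of the decomposition come from.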

The implementation reshapes the three-dimensional adjacency tensor (one adjacency matrix per relation) into a single two-dimensional sparse matrix, so that message passing reduces to one sparse matrix multiplication, cf. [thanapalasingam2021].
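The two stacked layouts used by the forward methods are plain reshapes of a (num_relations, num_entities, num_entities) adjacency tensor. The sketch below uses dense NumPy arrays as a stand-in for the sparse tensors the library works with; the block ordering (relation block `r` occupies indices `r * n` to `(r + 1) * n - 1`), the edge convention `adj[r, i, j]`, and the helper name `stack_adjacency` are assumptions for illustration.

```python
import numpy as np


def stack_adjacency(adj: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
    """Hypothetical helper: flatten a (num_relations, n, n) tensor into two 2D layouts.

    Assumed convention: adj[r, i, j] is the weight of the edge between i and j under relation r.
    horizontal[i, r * n + j] == adj[r, i, j]  -> shape (n, num_relations * n)
    vertical[r * n + i, j]   == adj[r, i, j]  -> shape (num_relations * n, n)
    """
    num_relations, n, _ = adj.shape
    # Move the relation axis between the two entity axes, then merge it with the column axis.
    horizontal = adj.transpose(1, 0, 2).reshape(n, num_relations * n)
    # Merge the relation axis with the row axis.
    vertical = adj.reshape(num_relations * n, n)
    return horizontal, vertical


# Toy (i, r, j) index triples; dense arrays stand in for sparse matrices.
triples = [(0, 0, 1), (1, 1, 2), (2, 0, 0)]
n, num_relations = 3, 2
adj = np.zeros((num_relations, n, n))
for i, r, j in triples:
    adj[r, i, j] = 1.0
horizontal, vertical = stack_adjacency(adj)
print(horizontal.shape, vertical.shape)  # (3, 6) (6, 3)
```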

Initialize the bases decomposition.

Parameters:
  • num_bases (Optional[int]) – the number of base matrices, \(B\)

  • kwargs – additional keyword-based parameters passed to Decomposition.__init__()
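To see what `num_bases` controls, one can count the parameters implied by the formula above: a full parametrization stores num_relations matrices of size input_dim × output_dim, while the basis decomposition stores num_bases such matrices plus a num_relations × num_bases coefficient table. This is arithmetic on the formula, not a count taken from the library (which may hold additional parameters such as biases); `parameter_counts` is a hypothetical name.

```python
def parameter_counts(num_relations: int, num_bases: int, input_dim: int, output_dim: int) -> tuple[int, int]:
    """Hypothetical helper: weight-parameter counts with and without the basis decomposition."""
    # One independent matrix per relation.
    full = num_relations * input_dim * output_dim
    # Shared base matrices plus one coefficient alpha_rb per (relation, base) pair.
    decomposed = num_bases * input_dim * output_dim + num_relations * num_bases
    return full, decomposed


print(parameter_counts(num_relations=100, num_bases=4, input_dim=64, output_dim=64))  # (409600, 16784)
```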

Attributes Summary

base_weights

Return the base weights.

bases

Return the base representations.

Methods Summary

forward_horizontally_stacked(x, adj)

Forward pass for horizontally stacked adjacency.

forward_vertically_stacked(x, adj)

Forward pass for vertically stacked adjacency.

iter_extra_repr()

Iterate over components for extra_repr.

reset_parameters()

Reset the layer's parameters.

Attributes Documentation

base_weights

Return the base weights.

Return type:

Tensor

bases

Return the base representations.

Return type:

Tensor

Methods Documentation

forward_horizontally_stacked(x, adj)

Forward pass for horizontally stacked adjacency.

Parameters:
  • x (Tensor) – shape: (num_entities, input_dim). The input entity representations.

  • adj (Tensor) – shape: (num_entities, num_relations * num_entities), sparse. The horizontally stacked adjacency matrix.

Return type:

Tensor

Returns:

shape: (num_entities, output_dim). The updated entity representations.
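With a horizontally stacked adjacency, one way to use a single matrix product is to first transform every entity with every relation matrix, stack the per-relation results vertically so they line up with the adjacency's column blocks, and then aggregate. The following NumPy sketch illustrates that idea under stated assumptions: dense arrays replace sparse tensors, column `r * n + j` of `adj_h` belongs to relation `r` and entity `j`, and `forward_horizontal`, `alpha`, and `bases` are hypothetical names, not the library's attributes.

```python
import numpy as np


def forward_horizontal(x, adj_h, alpha, bases):
    """Hypothetical sketch of sum_r A_r x W_r using a horizontally stacked adjacency.

    x:     (n, input_dim) entity representations
    adj_h: (n, num_relations * n), dense stand-in for the sparse stacked adjacency
    alpha: (num_relations, num_bases) base weights alpha_rb
    bases: (num_bases, input_dim, output_dim) base matrices B_b
    """
    num_relations, n = alpha.shape[0], x.shape[0]
    # x W_r for every relation, with W_r = sum_b alpha_rb B_b: (num_relations, n, output_dim)
    messages = np.einsum("ni,rb,bio->rno", x, alpha, bases)
    # Row r * n + j now matches column r * n + j of adj_h; one product aggregates all relations.
    return adj_h @ messages.reshape(num_relations * n, -1)


rng = np.random.default_rng(0)
n, num_relations, num_bases, d_in, d_out = 4, 3, 2, 5, 6
x = rng.normal(size=(n, d_in))
adj = (rng.random((num_relations, n, n)) < 0.3).astype(float)  # toy adjacency tensor
alpha = rng.normal(size=(num_relations, num_bases))
bases = rng.normal(size=(num_bases, d_in, d_out))
adj_h = adj.transpose(1, 0, 2).reshape(n, num_relations * n)

out = forward_horizontal(x, adj_h, alpha, bases)
# Reference: compute sum_r A_r x W_r one relation at a time.
W = np.einsum("rb,bio->rio", alpha, bases)
expected = sum(adj[r] @ x @ W[r] for r in range(num_relations))
print(np.allclose(out, expected))  # True
```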

forward_vertically_stacked(x, adj)

Forward pass for vertically stacked adjacency.

Parameters:
  • x (Tensor) – shape: (num_entities, input_dim). The input entity representations.

  • adj (Tensor) – shape: (num_entities * num_relations, num_entities), sparse. The vertically stacked adjacency matrix.

Return type:

Tensor

Returns:

shape: (num_entities, output_dim). The updated entity representations.
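With a vertically stacked adjacency, the order can be reversed: aggregate neighbors for all relations with one product, split the result back into per-relation blocks, and only then apply the relation matrices. The NumPy sketch below assumes dense arrays in place of sparse ones, that row `r * n + i` of `adj_v` belongs to relation `r` and entity `i`, and uses hypothetical names (`forward_vertical`, `alpha`, `bases`).

```python
import numpy as np


def forward_vertical(x, adj_v, alpha, bases):
    """Hypothetical sketch of sum_r A_r x W_r using a vertically stacked adjacency.

    x:     (n, input_dim) entity representations
    adj_v: (num_relations * n, n), dense stand-in for the sparse stacked adjacency
    alpha: (num_relations, num_bases) base weights alpha_rb
    bases: (num_bases, input_dim, output_dim) base matrices B_b
    """
    num_relations, n = alpha.shape[0], x.shape[0]
    # Aggregate neighbors for every relation at once: (num_relations * n, input_dim)
    aggregated = adj_v @ x
    # Split rows back into per-relation blocks: (num_relations, n, input_dim)
    aggregated = aggregated.reshape(num_relations, n, -1)
    # Apply W_r = sum_b alpha_rb B_b to each block and sum over relations: (n, output_dim)
    return np.einsum("rb,bio,rni->no", alpha, bases, aggregated)


rng = np.random.default_rng(1)
n, num_relations, num_bases, d_in, d_out = 4, 3, 2, 5, 6
x = rng.normal(size=(n, d_in))
adj = (rng.random((num_relations, n, n)) < 0.3).astype(float)  # toy adjacency tensor
alpha = rng.normal(size=(num_relations, num_bases))
bases = rng.normal(size=(num_bases, d_in, d_out))
adj_v = adj.reshape(num_relations * n, n)

out = forward_vertical(x, adj_v, alpha, bases)
W = np.einsum("rb,bio->rio", alpha, bases)
expected = sum(adj[r] @ x @ W[r] for r in range(num_relations))
print(np.allclose(out, expected))  # True
```

In this sketch the intermediate array has input_dim columns, whereas the horizontal variant's intermediate has output_dim columns, so which layout is cheaper plausibly depends on the relative sizes of the two dimensions.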

iter_extra_repr()

Iterate over components for extra_repr.

Return type:

Iterable[str]

reset_parameters()

Reset the layer’s parameters.