BasesDecomposition
- class BasesDecomposition(num_bases: int | None = None, **kwargs)
Bases: Decomposition
Represent relation-weights as a linear combination of base transformation matrices.
The basis decomposition represents the relation-specific transformation matrices as a weighted combination of base matrices, \(\{\mathbf{B}_b^l\}_{b=1}^{B}\), i.e.,
\[\mathbf{W}_r^l = \sum \limits_{b=1}^B \alpha_{rb} \mathbf{B}^l_b\]
The implementation reshapes the adjacency tensor into a sparse matrix so that message passing reduces to a single sparse matrix multiplication, cf. [thanapalasingam2021].
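The weighted combination above can be illustrated with a minimal NumPy sketch. It is not the library's implementation: the array names (`bases`, `alpha`, `relation_weights`) and all sizes are chosen here for illustration only, and the layer index \(l\) is dropped.

```python
import numpy as np

rng = np.random.default_rng(0)
# Illustrative sizes (assumptions, not library defaults)
num_relations, num_bases, input_dim, output_dim = 4, 2, 3, 5

# B_b: one (input_dim, output_dim) matrix per base, shared by all relations
bases = rng.normal(size=(num_bases, input_dim, output_dim))
# alpha_{rb}: one mixing coefficient per (relation, base) pair
alpha = rng.normal(size=(num_relations, num_bases))

# W_r = sum_b alpha_{rb} * B_b, computed for all relations at once
relation_weights = np.einsum("rb,bio->rio", alpha, bases)

# The same sum written out explicitly for a single relation
r = 1
w_r = sum(alpha[r, b] * bases[b] for b in range(num_bases))
```

With `num_bases` smaller than `num_relations`, the parameter count grows with the number of bases plus one coefficient vector per relation, rather than with one full matrix per relation.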
Initialize the bases decomposition.
- Parameters:
num_bases (int | None) – the number of bases
kwargs – additional keyword-based parameters passed to Decomposition.__init__()
Attributes Summary
base_weights – Return the base weights.
bases – Return the base representations.
Methods Summary
forward_horizontally_stacked(x, adj) – Forward pass for horizontally stacked adjacency.
forward_vertically_stacked(x, adj) – Forward pass for vertically stacked adjacency.
Iterate over components for extra_repr.
Reset the layer's parameters.
Attributes Documentation
- base_weights
Return the base weights.
- bases
Return the base representations.
Methods Documentation
- forward_horizontally_stacked(x: Tensor, adj: Tensor) → Tensor
Forward pass for horizontally stacked adjacency.
- Parameters:
x (Tensor)
adj (Tensor)
- Returns:
shape: (num_entities, output_dim) the updated entity representations.
- Return type:
Tensor
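To make the idea of horizontal stacking concrete, here is a dense NumPy sketch. It assumes the per-relation adjacency matrices are concatenated side by side into a matrix of shape (num_entities, num_relations * num_entities) in relation-major order; that layout, the variable names, and the use of dense arrays are assumptions for illustration and may differ from what the library does internally.

```python
import numpy as np

rng = np.random.default_rng(0)
# Illustrative sizes (assumptions)
num_entities, num_relations, input_dim, output_dim = 4, 3, 2, 5

x = rng.normal(size=(num_entities, input_dim))
weights = rng.normal(size=(num_relations, input_dim, output_dim))
# One dense 0/1 adjacency matrix per relation (sparse in practice)
adj_per_relation = (
    rng.random(size=(num_relations, num_entities, num_entities)) < 0.3
).astype(float)

# Horizontal stacking: [A_1 | A_2 | ... | A_R], shape (N, R * N)
adj_h = np.concatenate(list(adj_per_relation), axis=1)

# Transform first: X W_r for every relation, stacked to shape (R * N, output_dim)
xw = np.einsum("ni,rio->rno", x, weights).reshape(
    num_relations * num_entities, output_dim
)
# A single matrix product then aggregates messages over all relations
out_h = adj_h @ xw

# Reference: explicit sum over relations, sum_r A_r X W_r
reference = sum(adj_per_relation[r] @ x @ weights[r] for r in range(num_relations))
```

In this arrangement the relation-specific transformation is applied before the multiplication with the adjacency matrix, so the intermediate tensor has `output_dim` columns.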
- forward_vertically_stacked(x: Tensor, adj: Tensor) → Tensor
Forward pass for vertically stacked adjacency.
- Parameters:
x (Tensor)
adj (Tensor)
- Returns:
shape: (num_entities, output_dim) the updated entity representations.
- Return type:
Tensor
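Vertical stacking can be sketched the same way. The example below assumes the per-relation adjacency matrices are concatenated on top of each other into a matrix of shape (num_relations * num_entities, num_entities) in relation-major order; the layout, names, and dense arrays are illustrative assumptions rather than the library's internals.

```python
import numpy as np

rng = np.random.default_rng(0)
# Illustrative sizes (assumptions)
num_entities, num_relations, input_dim, output_dim = 4, 3, 2, 5

x = rng.normal(size=(num_entities, input_dim))
weights = rng.normal(size=(num_relations, input_dim, output_dim))
# One dense 0/1 adjacency matrix per relation (sparse in practice)
adj_per_relation = (
    rng.random(size=(num_relations, num_entities, num_entities)) < 0.3
).astype(float)

# Vertical stacking: [A_1; A_2; ...; A_R], shape (R * N, N)
adj_v = np.concatenate(list(adj_per_relation), axis=0)

# Aggregate first: a single matrix product yields A_r X for every relation
ax = (adj_v @ x).reshape(num_relations, num_entities, input_dim)
# Then apply the relation-specific weights and sum over relations
out_v = np.einsum("rni,rio->no", ax, weights)

# Reference: explicit sum over relations, sum_r A_r X W_r
reference = sum(adj_per_relation[r] @ x @ weights[r] for r in range(num_relations))
```

Here the adjacency multiplication happens before the transformation, so the intermediate tensor has `input_dim` columns; which of the two orderings is cheaper depends on the relative sizes of the input and output dimensions.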