BasesDecomposition

class BasesDecomposition(input_dim, num_relations, num_bases=None, output_dim=None, memory_intense=False)

Bases: pykeen.nn.message_passing.Decomposition

Represent relation weights as linear combinations of base transformation matrices.

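The basis decomposition represents each relation-specific transformation matrix as a weighted combination of shared base matrices \(\{\mathbf{B}_b^l\}_{b=1}^{B}\), i.e.,

\[\mathbf{W}_r^l = \sum\limits_{b=1}^{B} \alpha_{rb} \mathbf{B}_b^l\]

where \(\alpha_{rb}\) is the scalar coefficient of base \(b\) for relation \(r\), and \(l\) indexes the layer.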
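The combination can be illustrated with a minimal NumPy sketch, assuming the bases are stored as one array of shape (num_bases, input_dim, output_dim) and the coefficients \(\alpha_{rb}\) as an array of shape (num_relations, num_bases). The names `combine_bases`, `alpha`, and `bases` are hypothetical and not part of PyKEEN's API.

```python
import numpy as np


def combine_bases(alpha: np.ndarray, bases: np.ndarray) -> np.ndarray:
    """Build relation-specific matrices W_r = sum_b alpha[r, b] * B_b.

    alpha: (num_relations, num_bases), hypothetical coefficient array.
    bases: (num_bases, input_dim, output_dim), hypothetical basis array.
    Returns an array of shape (num_relations, input_dim, output_dim).
    """
    # Contract over the basis index b.
    return np.einsum("rb,bio->rio", alpha, bases)


rng = np.random.default_rng(0)
num_relations, num_bases, input_dim, output_dim = 5, 2, 4, 3
alpha = rng.normal(size=(num_relations, num_bases))
bases = rng.normal(size=(num_bases, input_dim, output_dim))
weights = combine_bases(alpha, bases)

# Relation 0 matches the explicit sum over the two bases.
expected = alpha[0, 0] * bases[0] + alpha[0, 1] * bases[1]
assert np.allclose(weights[0], expected)
print(weights.shape)  # (5, 4, 3)
```

Only num_bases full matrices are learned; each relation adds just num_bases scalar coefficients.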

Initialize the layer.

Parameters
  • input_dim (int) – The input dimension. Must be > 0.

  • num_relations (int) – The number of relations. Must be > 0.

  • num_bases (Optional[int]) – The number of bases to use. Must be > 0.

  • output_dim (Optional[int]) – The output dimension. Must be > 0. If None is given, defaults to input_dim.

  • memory_intense (bool) – Enable a memory-intensive forward pass, which may be faster, in particular if the number of distinct relations is small.
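How the memory-intense pass is implemented is not specified here. The sketch below illustrates one plausible trade-off under that assumption: transforming every node with every relation matrix up front costs memory proportional to num_relations × num_nodes × output_dim, which stays cheap when there are few relations, while handling edges relation by relation avoids that allocation. Both functions (hypothetical names) compute the same per-edge messages.

```python
import numpy as np


def messages_precompute_all(x, source, edge_type, weights):
    # Memory-heavy: transform every node with every relation matrix up front,
    # allocating num_relations * num_nodes * output_dim values.
    transformed = np.einsum("ni,rio->rno", x, weights)
    # For each edge, pick its source node as transformed by its relation.
    return transformed[edge_type, source]


def messages_by_relation(x, source, edge_type, weights):
    # Memory-light: handle edges relation by relation, so only the messages
    # for the current relation's edges are computed at a time.
    out = np.zeros((source.shape[0], weights.shape[2]))
    for r in np.unique(edge_type):
        mask = edge_type == r
        out[mask] = x[source[mask]] @ weights[r]
    return out


rng = np.random.default_rng(1)
x = rng.normal(size=(6, 4))            # (num_nodes, input_dim)
weights = rng.normal(size=(3, 4, 2))   # (num_relations, input_dim, output_dim)
source = np.array([0, 1, 2, 3, 4, 5, 0])
edge_type = np.array([0, 1, 2, 0, 1, 2, 1])

a = messages_precompute_all(x, source, edge_type, weights)
b = messages_by_relation(x, source, edge_type, weights)
assert np.allclose(a, b)
```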

Raises

ValueError – If num_bases is greater than num_relations.
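The documented argument handling, and the parameter savings it buys, can be sketched as follows. This assumes explicit integer values: output_dim falls back to input_dim, and num_bases larger than num_relations is rejected. The count assumes one (input_dim × output_dim) matrix per basis plus one coefficient per (relation, basis) pair, as in the formula above. The helper `bases_parameter_count` is hypothetical, and because the default used when num_bases is None is not stated here, the sketch requires it.

```python
from typing import Optional


def bases_parameter_count(
    input_dim: int,
    num_relations: int,
    num_bases: int,
    output_dim: Optional[int] = None,
) -> int:
    """Count learnable weights of a basis decomposition (hypothetical helper)."""
    if output_dim is None:
        # Documented behaviour: output_dim defaults to input_dim.
        output_dim = input_dim
    if min(input_dim, num_relations, num_bases, output_dim) <= 0:
        raise ValueError("All dimensions must be > 0.")
    if num_bases > num_relations:
        # Documented behaviour: too many bases raise a ValueError.
        raise ValueError("num_bases must not exceed num_relations.")
    basis_weights = num_bases * input_dim * output_dim
    coefficients = num_relations * num_bases
    return basis_weights + coefficients


full = 100 * 64 * 64  # one independent 64x64 matrix per relation
print(full)                               # 409600
print(bases_parameter_count(64, 100, 4))  # 16784
```

With 100 relations and 64-dimensional embeddings, four bases need 16,784 weights instead of 409,600 for independent per-relation matrices.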

Methods Summary

  • forward(x, source, target, edge_type[, ...]) – Relation-specific message passing from source to target.

  • reset_parameters() – Reset the parameters of this layer.

Methods Documentation

forward(x, source, target, edge_type, edge_weights=None, accumulator=None)

Relation-specific message passing from source to target.

Parameters
  • x (FloatTensor) – shape: (num_nodes, input_dim) The node representations.

  • source (LongTensor) – shape: (num_edges,) The source indices.

  • target (LongTensor) – shape: (num_edges,) The target indices.

  • edge_type (LongTensor) – shape: (num_edges,) The edge types.

  • edge_weights (Optional[FloatTensor]) – shape: (num_edges,) Precomputed edge weights.

  • accumulator (Optional[FloatTensor]) – shape: (num_nodes, output_dim) A pre-allocated output accumulator. It may be used if multiple different message passing steps are performed and their results are summed. If None is given, an accumulator filled with zeros is created.

Return type

FloatTensor

Returns

shape: (num_nodes, output_dim) The enriched node embeddings.
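The forward semantics above can be illustrated with a minimal NumPy sketch, assuming the relation matrices have already been combined into one array of shape (num_relations, input_dim, output_dim). Each edge sends x[source] transformed by its relation's matrix, optionally scaled by its edge weight, and the messages are summed into the target rows of the accumulator. The function `forward_sketch` is hypothetical and ignores any normalization or self-loop handling the actual layer may apply.

```python
from typing import Optional

import numpy as np


def forward_sketch(
    x: np.ndarray,                              # (num_nodes, input_dim)
    source: np.ndarray,                         # (num_edges,)
    target: np.ndarray,                         # (num_edges,)
    edge_type: np.ndarray,                      # (num_edges,)
    weights: np.ndarray,                        # (num_relations, input_dim, output_dim)
    edge_weights: Optional[np.ndarray] = None,  # (num_edges,)
    accumulator: Optional[np.ndarray] = None,   # (num_nodes, output_dim)
) -> np.ndarray:
    if accumulator is None:
        # No accumulator given: start from zeros.
        accumulator = np.zeros((x.shape[0], weights.shape[2]))
    # Relation-specific message for every edge.
    messages = np.einsum("ei,eio->eo", x[source], weights[edge_type])
    if edge_weights is not None:
        messages = messages * edge_weights[:, None]
    # Sum messages into target rows in place; np.add.at handles repeated targets.
    np.add.at(accumulator, target, messages)
    return accumulator


x = np.eye(3)                                    # 3 nodes, input_dim = 3
weights = np.stack([np.eye(3), 2 * np.eye(3)])   # relation 0: identity, relation 1: doubling
source = np.array([0, 1])
target = np.array([2, 2])
edge_type = np.array([0, 1])
out = forward_sketch(x, source, target, edge_type, weights)
print(out[2])  # [1. 2. 0.]
```

Passing a previous result as `accumulator` adds a further message passing step onto it by sum; note that the sketch updates a given accumulator in place.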

reset_parameters()

Reset the parameters of this layer.