Utilities

Utilities for neural network components.

exception ShapeError(shape: Sequence[int], reference: Sequence[int])

An error for a mismatch in shapes.

Initialize the error.

Parameters:
  • shape (Sequence[int]) – the mismatching shape

  • reference (Sequence[int]) – the expected shape

Return type:

None

classmethod verify(shape: int | Sequence[int], reference: int | Sequence[int] | None) → Sequence[int]

Raise an exception if the shape does not match the reference.

This method normalizes the shapes first.

Parameters:
  • shape (int | Sequence[int]) – the shape to check

  • reference (int | Sequence[int] | None) – the reference shape. If None, the shape always matches.

Raises:

ShapeError – if the two shapes do not match.

Returns:

the normalized shape

Return type:

Sequence[int]
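
The normalize-then-compare behavior described above can be sketched in plain Python. This is an illustration only, not the library's implementation: the names `ShapeMismatch` and `normalize_shape` are hypothetical, and the sketch assumes that normalizing means turning a bare int into a one-element tuple.

```python
from typing import Optional, Sequence, Tuple, Union

Shape = Union[int, Sequence[int]]


def normalize_shape(shape: Shape) -> Tuple[int, ...]:
    """Turn an int or a sequence of ints into a tuple (assumed normalization)."""
    if isinstance(shape, int):
        return (shape,)
    return tuple(shape)


class ShapeMismatch(ValueError):
    """Hypothetical stand-in for an error that signals a shape mismatch."""

    def __init__(self, shape: Sequence[int], reference: Sequence[int]) -> None:
        self.shape = tuple(shape)
        self.reference = tuple(reference)
        super().__init__(f"shape {self.shape} does not match expected shape {self.reference}")

    @classmethod
    def verify(cls, shape: Shape, reference: Optional[Shape]) -> Tuple[int, ...]:
        """Return the normalized shape, raising if it differs from the reference."""
        normalized = normalize_shape(shape)
        # A reference of None matches every shape.
        if reference is not None:
            expected = normalize_shape(reference)
            if normalized != expected:
                raise cls(shape=normalized, reference=expected)
        return normalized


print(ShapeMismatch.verify(3, (3,)))     # int and one-element tuple agree after normalization
print(ShapeMismatch.verify([2, 4], None))  # no reference: always passes
```

Returning the normalized shape lets a caller validate and canonicalize a user-supplied shape in one step.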

adjacency_tensor_to_stacked_matrix(num_relations: int, num_entities: int, source: Tensor, target: Tensor, edge_type: Tensor, edge_weights: Tensor | None = None, horizontal: bool = True) → Tensor

Stack adjacency matrices as described in [thanapalasingam2021].

This method re-arranges the (sparse) adjacency tensor of shape (num_entities, num_relations, num_entities) to a sparse adjacency matrix of shape (num_entities, num_relations * num_entities) (horizontal stacking) or (num_entities * num_relations, num_entities) (vertical stacking). Thereby, we can perform the relation-specific message passing of R-GCN by a single sparse matrix multiplication (and some additional pre- and/or post-processing) of the inputs.

Parameters:
  • num_relations (int) – the number of relations

  • num_entities (int) – the number of entities

  • source (Tensor) – shape: (num_triples,) the source entity indices

  • target (Tensor) – shape: (num_triples,) the target entity indices

  • edge_type (Tensor) – shape: (num_triples,) the edge type, i.e., relation ID

  • edge_weights (Tensor | None) – shape: (num_triples,) scalar edge weights

  • horizontal (bool) – whether to use horizontal (True) or vertical (False) stacking

Returns:

shape: (num_entities, num_relations * num_entities) for horizontal stacking or (num_entities * num_relations, num_entities) for vertical stacking; the stacked adjacency matrix

Return type:

Tensor
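
The re-arrangement can be illustrated by computing where each triple lands in the stacked matrix. The sketch below uses plain Python lists rather than tensors; the function name `stacked_indices` is hypothetical, and it assumes a relation-major block layout in which the block for relation r starts at offset r * num_entities.

```python
from typing import List, Sequence, Tuple


def stacked_indices(
    num_relations: int,
    num_entities: int,
    source: Sequence[int],
    target: Sequence[int],
    edge_type: Sequence[int],
    horizontal: bool = True,
) -> Tuple[Tuple[int, int], List[Tuple[int, int]]]:
    """Map (source, relation, target) triples to (row, column) in the stacked matrix."""
    coordinates = []
    for s, t, r in zip(source, target, edge_type):
        offset = r * num_entities  # start of the block that belongs to relation r
        if horizontal:
            coordinates.append((s, offset + t))  # blocks side by side: shift the column
        else:
            coordinates.append((offset + s, t))  # blocks on top of each other: shift the row
    if horizontal:
        shape = (num_entities, num_relations * num_entities)
    else:
        shape = (num_relations * num_entities, num_entities)
    return shape, coordinates


shape, coords = stacked_indices(
    num_relations=2, num_entities=3, source=[0, 2], target=[1, 0], edge_type=[0, 1]
)
print(shape, coords)  # (3, 6) [(0, 1), (2, 3)]
```

With these coordinates (and optional edge weights as values), a sparse matrix of the given shape can be built once and reused for every message-passing step.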

apply_optional_bn(x: Tensor, batch_norm: BatchNorm1d | None = None) → Tensor

Apply optional batch normalization.

Supports multiple batch dimensions.

Parameters:
  • x (Tensor) – shape: (..., d) The input tensor.

  • batch_norm (BatchNorm1d | None) – An optional batch normalization layer.

Returns:

shape: (..., d) The normalized tensor.

Return type:

Tensor

safe_diagonal(matrix: Tensor) → Tensor

Extract diagonal from a potentially sparse matrix.

Note

This is a workaround for as long as torch.diagonal() does not support sparse tensors.

Parameters:

matrix (Tensor) – shape: (n, n) the matrix

Returns:

shape: (n,) the diagonal values.

Return type:

Tensor
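
One way to extract a diagonal from a sparse matrix is to keep only the stored entries whose row and column index coincide. The sketch below shows this idea on a coordinate-format (COO) matrix represented by plain Python lists. The function name `coo_diagonal` is hypothetical, and whether the library proceeds this way is an assumption.

```python
from typing import List, Sequence


def coo_diagonal(
    n: int, rows: Sequence[int], cols: Sequence[int], values: Sequence[float]
) -> List[float]:
    """Return the diagonal of an (n, n) matrix stored as COO triplets."""
    diagonal = [0.0] * n  # entries that are not stored are zero
    for row, col, value in zip(rows, cols, values):
        if row == col:
            # Duplicate coordinates are summed, as is usual for COO storage.
            diagonal[row] += value
    return diagonal


# Stored entries: (0, 0) = 1, (1, 2) = 5, (1, 1) = 2, (2, 2) = 3
print(coo_diagonal(3, rows=[0, 1, 1, 2], cols=[0, 2, 1, 2], values=[1.0, 5.0, 2.0, 3.0]))
# [1.0, 2.0, 3.0]
```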

use_horizontal_stacking(input_dim: int, output_dim: int) → bool

Determine a stacking direction based on the input and output dimension.

Vertical stacking suits low-dimensional input and high-dimensional output, because the sparse message passing runs on the low-dimensional input and the projection to the high dimension is done last. Horizontal stacking suits high-dimensional input and low-dimensional output, because the projection to the low dimension is done first.

Parameters:
  • input_dim (int) – the layer’s input dimension

  • output_dim (int) – the layer’s output dimension

Returns:

whether to use horizontal (True) or vertical stacking

Return type:

bool
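
A minimal sketch of a decision rule that is consistent with the description above. The function name `prefer_horizontal` is hypothetical, and the handling of equal dimensions (falling back to vertical stacking) is an assumption.

```python
def prefer_horizontal(input_dim: int, output_dim: int) -> bool:
    """Return True for horizontal stacking, False for vertical stacking."""
    # Assumption: ties (equal dimensions) fall back to vertical stacking.
    return input_dim > output_dim


print(prefer_horizontal(input_dim=512, output_dim=32))  # True: wide input, narrow output
print(prefer_horizontal(input_dim=16, output_dim=256))  # False: narrow input, wide output
```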

Utilities for quaternions.

hamiltonian_product(qa: Tensor, qb: Tensor) → Tensor

Compute the Hamilton product of two quaternions (which enables rotation).

Parameters:
  • qa (Tensor) – the first quaternion

  • qb (Tensor) – the second quaternion

Return type:

Tensor
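
For reference, the Hamilton product of two single quaternions can be written out component by component. The sketch below works on plain 4-tuples of (re, i, j, k) coefficients rather than tensors. The function name `hamilton` and the tuple layout are illustrative assumptions and say nothing about the tensor layout the library expects.

```python
from typing import Tuple

Quaternion = Tuple[float, float, float, float]  # (re, i, j, k) coefficients


def hamilton(qa: Quaternion, qb: Quaternion) -> Quaternion:
    """Multiply two quaternions given as (re, i, j, k) coefficient tuples."""
    a1, b1, c1, d1 = qa
    a2, b2, c2, d2 = qb
    return (
        a1 * a2 - b1 * b2 - c1 * c2 - d1 * d2,  # real part
        a1 * b2 + b1 * a2 + c1 * d2 - d1 * c2,  # i coefficient
        a1 * c2 - b1 * d2 + c1 * a2 + d1 * b2,  # j coefficient
        a1 * d2 + b1 * c2 - c1 * b2 + d1 * a2,  # k coefficient
    )


i = (0, 1, 0, 0)
j = (0, 0, 1, 0)
print(hamilton(i, j))  # i * j = k
print(hamilton(j, i))  # j * i = -k: the product is not commutative
```

Because the product is not commutative, the order of the two arguments matters.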

multiplication_table() → Tensor

Create the quaternion basis multiplication table.

Returns:

shape: (4, 4, 4) the table of products of basis elements.

Return type:

Tensor

See also

https://en.wikipedia.org/wiki/Quaternion#Multiplication_of_basis_elements
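
A (4, 4, 4) table of basis products can be sketched with nested Python lists. The index convention used here, where table[a][b][c] holds the coefficient of basis element c in the product of basis elements a and b, is an assumption, as are the names `basis_table` and `multiply`.

```python
from typing import List, Sequence

ONE, I, J, K = 0, 1, 2, 3  # positions of the basis elements 1, i, j, k


def basis_table() -> List[List[List[int]]]:
    """Build table[a][b][c]: the coefficient of e_c in the product e_a * e_b."""
    table = [[[0] * 4 for _ in range(4)] for _ in range(4)]
    products = [
        # multiplication by 1 leaves the other factor unchanged
        (ONE, ONE, ONE, 1), (ONE, I, I, 1), (ONE, J, J, 1), (ONE, K, K, 1),
        (I, ONE, I, 1), (J, ONE, J, 1), (K, ONE, K, 1),
        # i^2 = j^2 = k^2 = -1
        (I, I, ONE, -1), (J, J, ONE, -1), (K, K, ONE, -1),
        # cyclic products: ij = k, jk = i, ki = j
        (I, J, K, 1), (J, K, I, 1), (K, I, J, 1),
        # reversing the order flips the sign
        (J, I, K, -1), (K, J, I, -1), (I, K, J, -1),
    ]
    for a, b, c, sign in products:
        table[a][b][c] = sign
    return table


def multiply(qa: Sequence[float], qb: Sequence[float], table) -> List[float]:
    """Contract two coefficient lists with the table to obtain their product."""
    return [
        sum(qa[a] * qb[b] * table[a][b][c] for a in range(4) for b in range(4))
        for c in range(4)
    ]


print(multiply([1, 2, 3, 4], [5, 6, 7, 8], basis_table()))  # [-60, 12, 30, 24]
```

Expressing the product as a contraction with a fixed table makes it a single bilinear operation, which carries over directly to batched tensor code.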

normalize(x: Tensor) → Tensor

Normalize the length of relation vectors, if the forward constraint has not been applied yet.

Absolute value of a quaternion:

\[|a + bi + cj + dk| = \sqrt{a^2 + b^2 + c^2 + d^2}\]

L2 norm of quaternion vector:

\[\|x\|^2 = \sum_{i=1}^d |x_i|^2 = \sum_{i=1}^d (x_i.re^2 + x_i.im_1^2 + x_i.im_2^2 + x_i.im_3^2)\]

Parameters:

x (Tensor) – shape: (*batch_dims, 4 \cdot d) The vector in flat form.

Returns:

shape: (*batch_dims, 4 \cdot d) The normalized vector.

Return type:

Tensor
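
The two formulas above can be checked numerically with plain Python. The sketch below computes the L2 norm of a flat quaternion vector and rescales the whole vector to unit norm. The function names are hypothetical, and rescaling the whole vector (rather than each quaternion separately) is an assumption made only for illustration. The flat vector is read as consecutive groups of four components, although the norm itself does not depend on how the components are grouped.

```python
import math
from typing import List, Sequence


def quaternion_abs(a: float, b: float, c: float, d: float) -> float:
    """Absolute value |a + bi + cj + dk|."""
    return math.sqrt(a * a + b * b + c * c + d * d)


def quaternion_vector_norm(x: Sequence[float]) -> float:
    """L2 norm of a flat quaternion vector of length 4 * d."""
    if len(x) % 4 != 0:
        raise ValueError("length must be a multiple of 4")
    # Sum the squared absolute values of the d quaternions.
    return math.sqrt(sum(quaternion_abs(*x[i:i + 4]) ** 2 for i in range(0, len(x), 4)))


def rescale_to_unit_norm(x: Sequence[float]) -> List[float]:
    """Divide every component by the vector norm (assumes a non-zero vector)."""
    norm = quaternion_vector_norm(x)
    return [value / norm for value in x]


x = [1.0, 2.0, 2.0, 4.0, 0.0, 3.0, 0.0, 4.0]  # d = 2 quaternions, each with absolute value 5
print(quaternion_vector_norm(x))                        # approximately sqrt(50)
print(quaternion_vector_norm(rescale_to_unit_norm(x)))  # approximately 1.0
```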