# Source code for pykeen.nn.quaternion

"""Utilities for quaternions."""

from functools import lru_cache

import torch

from ..typing import FloatTensor

__all__ = [
    "normalize",
    "hamiltonian_product",
    "multiplication_table",
]


def normalize(x: FloatTensor) -> FloatTensor:
    r"""
    Scale every quaternion contained in a flat vector to unit length.

    This is meant e.g. for relation vectors whose forward constraint has not been applied yet. The last
    dimension of ``x`` is read as ``d`` consecutive groups of four entries, each group holding the four
    components of one quaternion. Each quaternion is divided by its absolute value

    .. math::

        |a + bi + cj + dk| = \sqrt{a^2 + b^2 + c^2 + d^2}

    This is a *per-quaternion* normalization, not a normalization of the whole flat vector: afterwards the
    squared L2 norm of the quaternion vector

    .. math::

        \|x\|^2 = \sum_{i=1}^d |x_i|^2 = \sum_{i=1}^d (x_i.re^2 + x_i.im_1^2 + x_i.im_2^2 + x_i.im_3^2)

    equals ``d`` (when all quaternions are non-zero) rather than 1. All-zero quaternions stay zero, since
    :func:`torch.nn.functional.normalize` clamps the divisor at a small ``eps``.

    :param x: shape: ``(*batch_dims, 4 \cdot d)``
        The vector in flat form.

    :return: shape: ``(*batch_dims, 4 \cdot d)``
        The vector with every quaternion scaled to unit length. ``x`` itself is not modified.

    :raises RuntimeError:
        If the last dimension of ``x`` is not divisible by four (raised by :meth:`torch.Tensor.reshape`).
    """
    shape = x.shape
    # Split the flat last dimension into (d, 4): one row per quaternion. Unlike ``view``, ``reshape``
    # also accepts inputs whose memory layout is not view-compatible (it copies only when necessary).
    quaternions = x.reshape(*shape[:-1], -1, 4)
    # Normalize each quaternion (the trailing axis of size 4) independently.
    quaternions = torch.nn.functional.normalize(quaternions, p=2, dim=-1)
    # Merge the quaternion axis back into the original flat layout.
    return quaternions.reshape(shape)
def hamiltonian_product(qa: FloatTensor, qb: FloatTensor) -> FloatTensor:
    """
    Compute the Hamilton product ``qa * qb`` of two quaternions (which enables rotation).

    :param qa:
        The left factor. ``qa[0]``, ``qa[1]``, ``qa[2]`` and ``qa[3]`` are its real, ``i``, ``j`` and ``k``
        components; they are combined element-wise, so they may be tensors of any mutually broadcastable shape.
    :param qb:
        The right factor, indexed in the same way as ``qa``.

    :return:
        The four components (real, ``i``, ``j``, ``k``) of the product, stacked along a new last dimension.
    """
    a_r, a_i, a_j, a_k = (qa[idx] for idx in range(4))
    b_r, b_i, b_j, b_k = (qb[idx] for idx in range(4))

    # Each component expands (a_r + a_i i + a_j j + a_k k)(b_r + b_i i + b_j j + b_k k) using
    # i^2 = j^2 = k^2 = ijk = -1; the term order is kept so floating-point results are reproducible.
    real_part = a_r * b_r - a_i * b_i - a_j * b_j - a_k * b_k
    i_part = a_r * b_i + a_i * b_r + a_j * b_k - a_k * b_j
    j_part = a_r * b_j - a_i * b_k + a_j * b_r + a_k * b_i
    k_part = a_r * b_k + a_i * b_j - a_j * b_i + a_k * b_r

    return torch.stack([real_part, i_part, j_part, k_part], dim=-1)
@lru_cache(1)
def multiplication_table() -> FloatTensor:
    """
    Create the quaternion basis multiplication table.

    Entry ``table[p, q, r]`` is the coefficient of basis element ``r`` in the product of basis elements ``p``
    and ``q`` (in that order), using the index order ``1, i, j, k -> 0, 1, 2, 3``. Exactly one entry per
    ``(p, q)`` pair is non-zero (either ``+1`` or ``-1``).

    The result is memoized with :func:`functools.lru_cache`, so every call returns the same tensor object;
    callers should not modify it in place.

    :return: shape: (4, 4, 4)
        the table of products of basis elements.

    .. seealso::
        https://en.wikipedia.org/wiki/Quaternion#Multiplication_of_basis_elements
    """
    one, i, j, k = range(4)
    # (left factor, right factor) -> (resulting basis element, sign)
    products = {
        # 1 * x = x * 1 = x
        (one, one): (one, 1),
        (one, i): (i, 1),
        (one, j): (j, 1),
        (one, k): (k, 1),
        (i, one): (i, 1),
        (j, one): (j, 1),
        (k, one): (k, 1),
        # i^2 = j^2 = k^2 = -1
        (i, i): (one, -1),
        (j, j): (one, -1),
        (k, k): (one, -1),
        # i * j = k, i * k = -j
        (i, j): (k, 1),
        (i, k): (j, -1),
        # j * i = -k, j * k = i
        (j, i): (k, -1),
        (j, k): (i, 1),
        # k * i = j, k * j = -i
        (k, i): (j, 1),
        (k, j): (i, -1),
    }
    table = torch.zeros(4, 4, 4)
    for (left, right), (result, sign) in products.items():
        table[left, right, result] = sign
    return table