"""Functional forms of interaction methods.
These implementations allow for an arbitrary number of batch dimensions,
as well as broadcasting and thus naturally support slicing and 1:n scoring.
"""
from __future__ import annotations
import torch
from torch import broadcast_tensors, nn
from .compute_kernel import batched_dot
from ..typing import FloatTensor
from ..utils import (
clamp_norm,
einsum,
negative_norm,
negative_norm_of_sum,
project_entity,
tensor_product,
tensor_sum,
)
__all__ = [
"multilinear_tucker_interaction",
"mure_interaction",
"ntn_interaction",
"pair_re_interaction",
"proje_interaction",
"rescal_interaction",
"simple_interaction",
"se_interaction",
"transd_interaction",
"transe_interaction",
"transf_interaction",
"transh_interaction",
"transr_interaction",
"transformer_interaction",
"triple_re_interaction",
"tucker_interaction",
"um_interaction",
"linea_re_interaction",
]
def _apply_optional_bn_to_tensor(
x: FloatTensor,
output_dropout: nn.Dropout,
batch_norm: nn.BatchNorm1d | None = None,
) -> FloatTensor:
"""Apply optional batch normalization and dropout layer. Supports multiple batch dimensions."""
if batch_norm is not None:
shape = x.shape
x = x.reshape(-1, shape[-1])
x = batch_norm(x)
x = x.view(*shape)
return output_dropout(x)
def circular_correlation(
a: FloatTensor,
b: FloatTensor,
) -> FloatTensor:
"""
Compute the circular correlation between two vectors.
.. note ::
The implementation uses FFT.
:param a: shape: (`*batch_dims`, dim)
The tensor with the first vectors.
:param b: shape: (`*batch_dims`, dim)
The tensor with the second vectors.
:return:
The circular correlation between the vectors.
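
Example (a minimal sketch; the batch size and vector dimension below are arbitrary):

>>> import torch
>>> a = torch.rand(2, 5)
>>> b = torch.rand(2, 5)
>>> circular_correlation(a, b).shape
torch.Size([2, 5])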
"""
# Circular correlation of entity embeddings
a_fft = torch.fft.rfft(a, dim=-1)
b_fft = torch.fft.rfft(b, dim=-1)
# complex conjugate
a_fft = torch.conj(a_fft)
# Hadamard product in frequency domain
p_fft = a_fft * b_fft
# inverse real FFT
return torch.fft.irfft(p_fft, n=a.shape[-1], dim=-1)
def ntn_interaction(
h: FloatTensor,
t: FloatTensor,
w: FloatTensor,
vh: FloatTensor,
vt: FloatTensor,
b: FloatTensor,
u: FloatTensor,
activation: nn.Module,
) -> FloatTensor:
r"""Evaluate the NTN interaction function.
.. math::
f(h,r,t) = u_r^T act(h W_r t + V_r h + V_r' t + b_r)
:param h: shape: (`*batch_dims`, dim)
The head representations.
:param w: shape: (`*batch_dims`, k, dim, dim)
The relation specific transformation matrix W_r.
:param vh: shape: (`*batch_dims`, k, dim)
The head transformation matrix V_h.
:param vt: shape: (`*batch_dims`, k, dim)
The tail transformation matrix V_t.
:param b: shape: (`*batch_dims`, k)
The relation specific offset b_r.
:param u: shape: (`*batch_dims`, k)
The relation specific final linear transformation u_r.
:param t: shape: (`*batch_dims`, dim)
The tail representations.
:param activation:
The activation function.
:return: shape: batch_dims
The scores.
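
Example (a minimal sketch with random representations; the batch size, ``dim``, and ``k`` are arbitrary):

>>> import torch
>>> from torch import nn
>>> h, t = torch.rand(2, 3, 5).unbind(dim=0)
>>> w = torch.rand(3, 4, 5, 5)
>>> vh, vt = torch.rand(2, 3, 4, 5).unbind(dim=0)
>>> b, u = torch.rand(2, 3, 4).unbind(dim=0)
>>> ntn_interaction(h=h, t=t, w=w, vh=vh, vt=vt, b=b, u=u, activation=nn.Tanh()).shape
torch.Size([3])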
"""
return (
u
* activation(
tensor_sum(
einsum("...d,...kde,...e->...k", h, w, t), # shape: (*batch_dims, k)
einsum("...d, ...kd->...k", h, vh),
einsum("...d, ...kd->...k", t, vt),
b,
)
)
).sum(dim=-1)
def proje_interaction(
h: FloatTensor,
r: FloatTensor,
t: FloatTensor,
d_e: FloatTensor,
d_r: FloatTensor,
b_c: FloatTensor,
b_p: FloatTensor,
activation: nn.Module,
) -> FloatTensor:
r"""Evaluate the ProjE interaction function.
.. math::
f(h, r, t) = g(t z(D_e h + D_r r + b_c) + b_p)
:param h: shape: (`*batch_dims`, dim)
The head representations.
:param r: shape: (`*batch_dims`, dim)
The relation representations.
:param t: shape: (`*batch_dims`, dim)
The tail representations.
:param d_e: shape: (dim,)
Global entity projection.
:param d_r: shape: (dim,)
Global relation projection.
:param b_c: shape: (dim,)
Global combination bias.
:param b_p: shape: scalar
The final score bias.
:param activation:
The activation function.
:return: shape: batch_dims
The scores.
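
Example (a minimal sketch with random representations; the batch size and dimension are arbitrary):

>>> import torch
>>> from torch import nn
>>> h, r, t = torch.rand(3, 2, 5).unbind(dim=0)
>>> d_e, d_r, b_c = torch.rand(3, 5).unbind(dim=0)
>>> proje_interaction(h=h, r=r, t=t, d_e=d_e, d_r=d_r, b_c=b_c, b_p=torch.rand(()), activation=nn.Tanh()).shape
torch.Size([2])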
"""
# global projections
h = einsum("...d, d -> ...d", h, d_e)
r = einsum("...d, d -> ...d", r, d_r)
# combination, shape: (*batch_dims, d)
x = activation(tensor_sum(h, r, b_c))
# dot product with t
return einsum("...d, ...d -> ...", x, t) + b_p
def rescal_interaction(
h: FloatTensor,
r: FloatTensor,
t: FloatTensor,
) -> FloatTensor:
"""Evaluate the RESCAL interaction function.
:param h: shape: (`*batch_dims`, dim)
The head representations.
:param r: shape: (`*batch_dims`, dim, dim)
The relation representations.
:param t: shape: (`*batch_dims`, dim)
The tail representations.
:return: shape: batch_dims
The scores.
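
Example (a minimal sketch; the batch size and dimension are arbitrary):

>>> import torch
>>> h = torch.rand(2, 5)
>>> r = torch.rand(2, 5, 5)
>>> t = torch.rand(2, 5)
>>> rescal_interaction(h, r, t).shape
torch.Size([2])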
"""
return einsum("...d,...de,...e->...", h, r, t)
def simple_interaction(
h: FloatTensor,
r: FloatTensor,
t: FloatTensor,
h_inv: FloatTensor,
r_inv: FloatTensor,
t_inv: FloatTensor,
clamp: tuple[float, float] | None = None,
) -> FloatTensor:
"""Evaluate the SimplE interaction function.
:param h: shape: (`*batch_dims`, dim)
The head representations.
:param r: shape: (`*batch_dims`, dim)
The relation representations.
:param t: shape: (`*batch_dims`, dim)
The tail representations.
:param h_inv: shape: (`*batch_dims`, dim)
The inverse head representations.
:param r_inv: shape: (`*batch_dims`, dim)
The inverse relation representations.
:param t_inv: shape: (`*batch_dims`, dim)
The inverse tail representations.
:param clamp:
Clamp the scores to the given range.
:return: shape: batch_dims
The scores.
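
Example (a minimal sketch with random representations; the batch size and dimension are arbitrary):

>>> import torch
>>> h, r, t, h_inv, r_inv, t_inv = torch.rand(6, 2, 5).unbind(dim=0)
>>> simple_interaction(h, r, t, h_inv, r_inv, t_inv, clamp=(-20.0, 20.0)).shape
torch.Size([2])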
"""
scores = 0.5 * (tensor_product(h, r, t).sum(dim=-1) + tensor_product(h_inv, r_inv, t_inv).sum(dim=-1))
# Note: In the code in their repository, the score is clamped to [-20, 20].
# That is not mentioned in the paper, so it is made optional here.
if clamp:
min_, max_ = clamp
scores = scores.clamp(min=min_, max=max_)
return scores
def se_interaction(
h: FloatTensor,
r_h: FloatTensor,
r_t: FloatTensor,
t: FloatTensor,
p: int,
power_norm: bool = False,
) -> FloatTensor:
r"""Evaluate the Structured Embedding interaction function.
.. math ::
f(h, r, t) = -\|R_h h - R_t t\|
:param h: shape: (`*batch_dims`, dim)
The head representations.
:param r_h: shape: (`*batch_dims`, rel_dim, dim)
The relation-specific head projection.
:param r_t: shape: (`*batch_dims`, rel_dim, dim)
The relation-specific tail projection.
:param t: shape: (`*batch_dims`, dim)
The tail representations.
:param p:
The p for the norm. cf. :func:`torch.linalg.vector_norm`.
:param power_norm:
Whether to return the powered norm.
:return: shape: batch_dims
The scores.
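
Example (a minimal sketch; the batch size, ``dim``, and ``rel_dim`` are arbitrary):

>>> import torch
>>> h, t = torch.rand(2, 2, 5).unbind(dim=0)
>>> r_h, r_t = torch.rand(2, 2, 4, 5).unbind(dim=0)
>>> se_interaction(h=h, r_h=r_h, r_t=r_t, t=t, p=2).shape
torch.Size([2])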
"""
return negative_norm(
einsum("...rd,...d->...r", r_h, h) - einsum("...rd,...d->...r", r_t, t),
p=p,
power_norm=power_norm,
)
def toruse_interaction(
h: FloatTensor,
r: FloatTensor,
t: FloatTensor,
p: int | str = 2,
power_norm: bool = False,
) -> FloatTensor:
"""Evaluate the TorusE interaction function from [ebisu2018].
.. note ::
This only implements the two L_p norm based variants.
:param h: shape: (`*batch_dims`, dim)
The head representations.
:param r: shape: (`*batch_dims`, dim)
The relation representations.
:param t: shape: (`*batch_dims`, dim)
The tail representations.
:param p:
The p for the norm.
:param power_norm:
Whether to return the powered norm.
:return: shape: batch_dims
The scores.
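
Example (a minimal sketch; the batch size and dimension are arbitrary):

>>> import torch
>>> h, r, t = torch.rand(3, 2, 5).unbind(dim=0)
>>> toruse_interaction(h, r, t).shape
torch.Size([2])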
"""
d = tensor_sum(h, r, -t)
d = d - torch.floor(d)
d = torch.minimum(d, 1.0 - d)
return negative_norm(d, p=p, power_norm=power_norm)
def transd_interaction(
h: FloatTensor,
r: FloatTensor,
t: FloatTensor,
h_p: FloatTensor,
r_p: FloatTensor,
t_p: FloatTensor,
p: int,
power_norm: bool = False,
) -> FloatTensor:
"""Evaluate the TransD interaction function.
:param h: shape: (`*batch_dims`, d_e)
The head representations.
:param r: shape: (`*batch_dims`, d_r)
The relation representations.
:param t: shape: (`*batch_dims`, d_e)
The tail representations.
:param h_p: shape: (`*batch_dims`, d_e)
The head projections.
:param r_p: shape: (`*batch_dims`, d_r)
The relation projections.
:param t_p: shape: (`*batch_dims`, d_e)
The tail projections.
:param p:
The parameter p for selecting the norm.
:param power_norm:
Whether to return the powered norm instead.
:return: shape: batch_dims
The scores.
"""
# Project entities
h_bot = project_entity(
e=h,
e_p=h_p,
r_p=r_p,
)
t_bot = project_entity(
e=t,
e_p=t_p,
r_p=r_p,
)
return negative_norm_of_sum(h_bot, r, -t_bot, p=p, power_norm=power_norm)
def transe_interaction(
h: FloatTensor,
r: FloatTensor,
t: FloatTensor,
p: int | str = 2,
power_norm: bool = False,
) -> FloatTensor:
"""Evaluate the TransE interaction function.
:param h: shape: (`*batch_dims`, dim)
The head representations.
:param r: shape: (`*batch_dims`, dim)
The relation representations.
:param t: shape: (`*batch_dims`, dim)
The tail representations.
:param p:
The p for the norm.
:param power_norm:
Whether to return the powered norm.
:return: shape: batch_dims
The scores.
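
Example (a minimal sketch; sizes are arbitrary, and the second call illustrates broadcast-based 1:n scoring):

>>> import torch
>>> h, r, t = torch.rand(3, 2, 5).unbind(dim=0)
>>> transe_interaction(h, r, t, p=2).shape
torch.Size([2])
>>> transe_interaction(h[:, None, :], r[:, None, :], torch.rand(7, 5), p=2).shape
torch.Size([2, 7])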
"""
return negative_norm_of_sum(h, r, -t, p=p, power_norm=power_norm)
def transf_interaction(
h: FloatTensor,
r: FloatTensor,
t: FloatTensor,
) -> FloatTensor:
"""Evaluate the TransF interaction function.
:param h: shape: (`*batch_dims`, dim)
The head representations.
:param r: shape: (`*batch_dims`, dim)
The relation representations.
:param t: shape: (`*batch_dims`, dim)
The tail representations.
:return: shape: batch_dims
The scores.
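
Example (a minimal sketch; the batch size and dimension are arbitrary):

>>> import torch
>>> h, r, t = torch.rand(3, 2, 5).unbind(dim=0)
>>> transf_interaction(h, r, t).shape
torch.Size([2])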
"""
return batched_dot(h + r, t) + batched_dot(h, t - r)
def transh_interaction(
h: FloatTensor,
w_r: FloatTensor,
d_r: FloatTensor,
t: FloatTensor,
p: int,
power_norm: bool = False,
) -> FloatTensor:
"""Evaluate the DistMult interaction function.
:param h: shape: (`*batch_dims`, dim)
The head representations.
:param w_r: shape: (`*batch_dims`, dim)
The relation normal vector representations.
:param d_r: shape: (`*batch_dims`, dim)
The relation difference vector representations.
:param t: shape: (`*batch_dims`, dim)
The tail representations.
:param p:
The p for the norm. cf. :func:`torch.linalg.vector_norm`.
:param power_norm:
Whether to return $|x-y|_p^p$.
:return: shape: batch_dims
The scores.
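
Example (a minimal sketch; the batch size and dimension are arbitrary):

>>> import torch
>>> h, w_r, d_r, t = torch.rand(4, 2, 5).unbind(dim=0)
>>> transh_interaction(h=h, w_r=w_r, d_r=d_r, t=t, p=2).shape
torch.Size([2])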
"""
return negative_norm_of_sum(
# h projection to hyperplane
h,
-einsum("...i, ...i, ...j -> ...j", h, w_r, w_r),
# r
d_r,
# -t projection to hyperplane
-t,
einsum("...i, ...i, ...j -> ...j", t, w_r, w_r),
p=p,
power_norm=power_norm,
)
def transr_interaction(
h: FloatTensor,
r: FloatTensor,
t: FloatTensor,
m_r: FloatTensor,
p: int,
power_norm: bool = True,
) -> FloatTensor:
"""Evaluate the TransR interaction function.
:param h: shape: (`*batch_dims`, d_e)
Head embeddings.
:param r: shape: (`*batch_dims`, d_r)
Relation embeddings.
:param m_r: shape: (`*batch_dims`, d_e, d_r)
The relation specific linear transformations.
:param t: shape: (`*batch_dims`, d_e)
Tail embeddings.
:param p:
The parameter p for selecting the norm.
:param power_norm:
Whether to return the powered norm instead.
:return: shape: batch_dims
The scores.
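
Example (a minimal sketch; the batch size, ``d_e``, and ``d_r`` are arbitrary):

>>> import torch
>>> h, t = torch.rand(2, 2, 5).unbind(dim=0)
>>> r = torch.rand(2, 3)
>>> m_r = torch.rand(2, 5, 3)
>>> transr_interaction(h=h, r=r, t=t, m_r=m_r, p=2).shape
torch.Size([2])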
"""
# project to relation specific subspace
h_bot = einsum("...e, ...er -> ...r", h, m_r)
t_bot = einsum("...e, ...er -> ...r", t, m_r)
# ensure constraints
h_bot = clamp_norm(h_bot, p=2, dim=-1, maxnorm=1.0)
t_bot = clamp_norm(t_bot, p=2, dim=-1, maxnorm=1.0)
return negative_norm_of_sum(h_bot, r, -t_bot, p=p, power_norm=power_norm)
def tucker_interaction(
h: FloatTensor,
r: FloatTensor,
t: FloatTensor,
core_tensor: FloatTensor,
do_h: nn.Dropout,
do_r: nn.Dropout,
do_hr: nn.Dropout,
bn_h: nn.BatchNorm1d | None,
bn_hr: nn.BatchNorm1d | None,
) -> FloatTensor:
r"""Evaluate the TuckEr interaction function.
Compute scoring function W x_1 h x_2 r x_3 t as in the official implementation, i.e. as
.. math ::
DO_{hr}(BN_{hr}(DO_h(BN_h(h)) x_1 DO_r(W x_2 r))) x_3 t
where BN denotes BatchNorm and DO denotes Dropout
:param h: shape: (`*batch_dims`, d_e)
The head representations.
:param r: shape: (`*batch_dims`, d_r)
The relation representations.
:param t: shape: (`*batch_dims`, d_e)
The tail representations.
:param core_tensor: shape: (d_e, d_r, d_e)
The core tensor.
:param do_h:
The dropout layer for the head representations.
:param do_r:
The first hidden dropout.
:param do_hr:
The second hidden dropout.
:param bn_h:
The first batch normalization layer.
:param bn_hr:
The second batch normalization layer.
:return: shape: batch_dims
The scores.
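
Example (a minimal sketch; the batch size and dimensions are arbitrary, and batch normalization is disabled by passing ``None``):

>>> import torch
>>> from torch import nn
>>> h, t = torch.rand(2, 2, 4).unbind(dim=0)
>>> r = torch.rand(2, 3)
>>> core = torch.rand(4, 3, 4)
>>> drop = nn.Dropout(p=0.0)
>>> tucker_interaction(h, r, t, core, do_h=drop, do_r=drop, do_hr=drop, bn_h=None, bn_hr=None).shape
torch.Size([2])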
"""
return (
_apply_optional_bn_to_tensor(
x=einsum(
# x_1 contraction
"...ik,...i->...k",
_apply_optional_bn_to_tensor(
x=einsum(
# x_2 contraction
"ijk,...j->...ik",
core_tensor,
r,
),
output_dropout=do_r,
),
_apply_optional_bn_to_tensor(
x=h,
batch_norm=bn_h,
output_dropout=do_h,
),
),
batch_norm=bn_hr,
output_dropout=do_hr,
)
* t
).sum(dim=-1)
def mure_interaction(
h: FloatTensor,
b_h: FloatTensor,
r_vec: FloatTensor,
r_mat: FloatTensor,
t: FloatTensor,
b_t: FloatTensor,
p: int | float | str = 2,
power_norm: bool = False,
) -> FloatTensor:
r"""Evaluate the MuRE interaction function from [balazevic2019b]_.
.. math ::
-\|Rh + r - t\| + b_h + b_t
:param h: shape: (`*batch_dims`, dim)
The head representations.
:param b_h: shape: batch_dims
The head entity bias.
:param r_vec: shape: (`*batch_dims`, dim)
The relation vector.
:param r_mat: shape: (`*batch_dims`, dim)
The diagonal relation matrix.
:param t: shape: (`*batch_dims`, dim)
The tail representations.
:param b_t: shape: batch_dims
The tail entity bias.
:param p:
The parameter p for selecting the norm, cf. :func:`torch.linalg.vector_norm`.
:param power_norm:
Whether to return the powered norm instead.
:return: shape: batch_dims
The scores.
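
Example (a minimal sketch; the batch size and dimension are arbitrary):

>>> import torch
>>> h, r_vec, r_mat, t = torch.rand(4, 2, 5).unbind(dim=0)
>>> b_h, b_t = torch.rand(2, 2).unbind(dim=0)
>>> mure_interaction(h=h, b_h=b_h, r_vec=r_vec, r_mat=r_mat, t=t, b_t=b_t).shape
torch.Size([2])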
"""
return (
negative_norm_of_sum(
h * r_mat,
r_vec,
-t,
p=p,
power_norm=power_norm,
)
+ b_h
+ b_t
)
def um_interaction(
h: FloatTensor,
t: FloatTensor,
p: int,
power_norm: bool = True,
) -> FloatTensor:
"""Evaluate the SimplE interaction function.
:param h: shape: (`*batch_dims`, dim)
The head representations.
:param t: shape: (`*batch_dims`, dim)
The tail representations.
:param p:
The parameter p for selecting the norm.
:param power_norm:
Whether to return the powered norm instead.
:return: shape: batch_dims
The scores.
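
Example (a minimal sketch; the batch size and dimension are arbitrary):

>>> import torch
>>> h, t = torch.rand(2, 2, 5).unbind(dim=0)
>>> um_interaction(h, t, p=2).shape
torch.Size([2])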
"""
return negative_norm(h - t, p=p, power_norm=power_norm)
def pair_re_interaction(
h: FloatTensor,
t: FloatTensor,
r_h: FloatTensor,
r_t: FloatTensor,
p: int | str = 2,
power_norm: bool = True,
) -> FloatTensor:
r"""Evaluate the PairRE interaction function.
.. math ::
-\|h \odot r_h - t \odot r_t \|
:param h: shape: (`*batch_dims`, dim)
The head representations.
:param t: shape: (`*batch_dims`, dim)
The tail representations.
:param r_h: shape: (`*batch_dims`, dim)
The head part of the relation representations.
:param r_t: shape: (`*batch_dims`, dim)
The tail part of the relation representations.
:param p:
The parameter p for selecting the norm.
:param power_norm:
Whether to return the powered norm instead.
:return: shape: batch_dims
The scores.
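
Example (a minimal sketch; the batch size and dimension are arbitrary):

>>> import torch
>>> h, t, r_h, r_t = torch.rand(4, 2, 5).unbind(dim=0)
>>> pair_re_interaction(h=h, t=t, r_h=r_h, r_t=r_t).shape
torch.Size([2])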
"""
return negative_norm_of_sum(
h * r_h,
-t * r_t,
p=p,
power_norm=power_norm,
)
def quat_e_interaction(
h: FloatTensor,
r: FloatTensor,
t: FloatTensor,
table: FloatTensor,
) -> FloatTensor:
"""Evaluate the interaction function of QuatE for given embeddings.
The embeddings have to be in a broadcastable shape.
:param h: shape: (`*batch_dims`, dim, 4)
The head representations.
:param r: shape: (`*batch_dims`, dim, 4)
The relation representations.
:param t: shape: (`*batch_dims`, dim, 4)
The tail representations.
:param table: shape: (4, 4, 4)
The quaternion multiplication table.
:return: shape: batch_dims
The scores.
"""
# TODO: this sign is in the official code, too, but why do we need it?
return -einsum("...di, ...dj, ...dk, ijk -> ...", h, r, t, table)
def triple_re_interaction(
# head
h: FloatTensor,
# relation
r_head: FloatTensor,
r_mid: FloatTensor,
r_tail: FloatTensor,
# tail
t: FloatTensor,
# version 2: relation factor offset
u: float | None = None,
# extension: negative (power) norm
p: int = 2,
power_norm: bool = False,
) -> FloatTensor:
r"""Evaluate the TripleRE interaction function.
.. seealso ::
:class:`pykeen.nn.modules.TripleREInteraction` for the stateful interaction module
:param h: shape: (`*batch_dims`, rank, dim)
The head representations.
:param r_head: shape: (`*batch_dims`, rank, dim)
The relation-specific head multiplier representations.
:param r_mid: shape: (`*batch_dims`, rank, dim)
The relation representations.
:param r_tail: shape: (`*batch_dims`, rank, dim)
The relation-specific tail multiplier representations.
:param t: shape: (`*batch_dims`, rank, dim)
The tail representations.
:param u:
The relation factor offset. If not None, this corresponds to TripleREv2; otherwise, the original TripleRE interaction is used.
:param p:
The p for the norm. cf. :func:`torch.linalg.vector_norm`.
:param power_norm:
Whether to return the powered norm.
:return: shape: batch_dims
The scores.
"""
# note: normalization should be done from the representations
# cf. https://github.com/LongYu-360/TripleRE-Add-NodePiece/blob/994216dcb1d718318384368dd0135477f852c6a4/TripleRE%2BNodepiece/ogb_wikikg2/model.py#L317-L328 # noqa: E501
# version 2
if u is not None:
# r_head = r_head + u * torch.ones_like(r_head)
# r_tail = r_tail + u * torch.ones_like(r_tail)
r_head = r_head + u
r_tail = r_tail + u
return negative_norm_of_sum(
h * r_head,
-t * r_tail,
r_mid,
p=p,
power_norm=power_norm,
)
def multilinear_tucker_interaction(
h: FloatTensor,
r: FloatTensor,
t: FloatTensor,
core_tensor: FloatTensor,
) -> FloatTensor:
r"""Evaluate the (original) multi-linear TuckEr interaction function.
.. math ::
score(h, r, t) = \sum W_{ijk} h_i r_j t_k
:param h: shape: (`*batch_dims`, d_e)
The head representations.
:param r: shape: (`*batch_dims`, d_r)
The relation representations.
:param t: shape: (`*batch_dims`, d_e)
The tail representations.
:param core_tensor: shape: (d_h, d_r, d_t)
The core tensor.
:return: shape: batch_dims
The scores.
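
Example (a minimal sketch; the batch size and dimensions are arbitrary):

>>> import torch
>>> h = torch.rand(2, 4)
>>> r = torch.rand(2, 3)
>>> t = torch.rand(2, 5)
>>> core = torch.rand(4, 3, 5)
>>> multilinear_tucker_interaction(h, r, t, core_tensor=core).shape
torch.Size([2])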
"""
return einsum("ijk,...i,...j,...k->...", core_tensor, h, r, t)
def linea_re_interaction(
# head
h: FloatTensor,
# relation
r_head: FloatTensor,
r_mid: FloatTensor,
r_tail: FloatTensor,
# tail
t: FloatTensor,
# extension: negative (power) norm
p: int = 2,
power_norm: bool = False,
) -> FloatTensor:
"""Evaluate the LineaRE interaction function.
.. note ::
The interaction is equivalent to the TripleRE interaction without the `u` term.
:param h: shape: (`*batch_dims`, rank, dim)
The head representations.
:param r_head: shape: (`*batch_dims`, rank, dim)
The relation-specific head multiplier representations.
:param r_mid: shape: (`*batch_dims`, rank, dim)
The relation representations.
:param r_tail: shape: (`*batch_dims`, rank, dim)
The relation-specific tail multiplier representations.
:param t: shape: (`*batch_dims`, rank, dim)
The tail representations.
:param p:
The p for the norm. cf. :func:`negative_norm_of_sum`.
:param power_norm:
Whether to return the powered norm. cf. :func:`negative_norm_of_sum`.
:return: shape: batch_dims
The scores.
"""
return triple_re_interaction(
h=h, r_head=r_head, r_mid=r_mid, r_tail=r_tail, t=t, u=None, p=p, power_norm=power_norm
)