Source code for pykeen.nn.functional

# -*- coding: utf-8 -*-

"""Functional forms of interaction methods.

The functional forms always assume the general form of the interaction function, where head, relation and tail
representations are provided in shape (batch_size, num_heads, 1, 1, ``*``), (batch_size, 1, num_relations, 1, ``*``),
and (batch_size, 1, 1, num_tails, ``*``), and return a score tensor of shape
(batch_size, num_heads, num_relations, num_tails).
"""

from __future__ import annotations

import functools
from dataclasses import dataclass
from typing import Optional, Tuple, Union

import numpy
import torch
import torch.fft
from torch import nn

from .compute_kernel import _complex_native_complex
from .sim import KG2E_SIMILARITIES
from ..typing import GaussianDistribution
from ..utils import (
    broadcast_cat, clamp_norm, estimate_cost_of_sequence, extended_einsum, is_cudnn_error, negative_norm,
    negative_norm_of_sum, project_entity, tensor_product, tensor_sum, view_complex,
)

__all__ = [
    'complex_interaction',
    'conve_interaction',
    'convkb_interaction',
    'distmult_interaction',
    'ermlp_interaction',
    'ermlpe_interaction',
    'hole_interaction',
    'kg2e_interaction',
    'ntn_interaction',
    'proje_interaction',
    'rescal_interaction',
    'rotate_interaction',
    'simple_interaction',
    'structured_embedding_interaction',
    'transd_interaction',
    'transe_interaction',
    'transh_interaction',
    'transr_interaction',
    'tucker_interaction',
    'unstructured_model_interaction',
]


@dataclass
class SizeInformation:
    """Size information of generic score function."""

    #: The batch size of the head representations.
    bh: int

    #: The number of head representations per batch
    nh: int

    #: The batch size of the relation representations.
    br: int

    #: The number of relation representations per batch
    nr: int

    #: The batch size of the tail representations.
    bt: int

    #: The number of tail representations per batch
    nt: int

    @property
    def same(self) -> bool:
        """Whether all representations have the same shape."""
        return (
            self.bh == self.br
            and self.bh == self.bt
            and self.nh == self.nr
            and self.nh == self.nt
        )

    @classmethod
    def extract(
        cls,
        h: torch.Tensor,
        r: torch.Tensor,
        t: torch.Tensor,
    ) -> SizeInformation:
        """Extract size information from tensors."""
        bh, nh = h.shape[:2]
        br, nr = r.shape[:2]
        bt, nt = t.shape[:2]
        return cls(bh=bh, nh=nh, br=br, nr=nr, bt=bt, nt=nt)
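

# Illustrative sketch (not part of the original module): how SizeInformation reads the
# first two dimensions of each representation tensor. The shapes below are arbitrary
# example values in the broadcastable layout described in the module docstring.
def _example_size_information() -> None:
    h = torch.empty(2, 3, 1, 1, 8)
    r = torch.empty(2, 1, 5, 1, 8)
    t = torch.empty(2, 1, 1, 7, 8)
    sizes = SizeInformation.extract(h, r, t)
    # only dimensions 0 and 1 are inspected: the batch size and the count along dimension 1
    assert (sizes.bh, sizes.nh) == (2, 3)
    assert (sizes.br, sizes.nr) == (2, 1)
    assert (sizes.bt, sizes.nt) == (2, 1)
    # `same` holds only if all batch sizes and all counts coincide
    assert not sizes.same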


def _extract_sizes(
    h: torch.Tensor,
    r: torch.Tensor,
    t: torch.Tensor,
) -> Tuple[int, int, int, int, int]:
    """Extract size dimensions from head/relation/tail representations."""
    num_heads, num_relations, num_tails = [xx.shape[i] for i, xx in enumerate((h, r, t), start=1)]
    d_e = h.shape[-1]
    d_r = r.shape[-1]
    return num_heads, num_relations, num_tails, d_e, d_r


def _apply_optional_bn_to_tensor(
    x: torch.FloatTensor,
    output_dropout: nn.Dropout,
    batch_norm: Optional[nn.BatchNorm1d] = None,
) -> torch.FloatTensor:
    """Apply optional batch normalization and dropout layer. Supports multiple batch dimensions."""
    if batch_norm is not None:
        shape = x.shape
        x = x.reshape(-1, shape[-1])
        x = batch_norm(x)
        x = x.view(*shape)
    return output_dropout(x)
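

# Illustrative sketch (not part of the original module): _apply_optional_bn_to_tensor
# flattens all leading dimensions so that nn.BatchNorm1d, which expects a 2D input of
# shape (num_samples, num_features), can be applied to tensors with several batch
# dimensions. The shapes and modules below are arbitrary example values.
def _example_apply_optional_bn() -> None:
    x = torch.rand(2, 3, 4, 8)
    batch_norm = nn.BatchNorm1d(num_features=8)
    output_dropout = nn.Dropout(p=0.0)  # p=0 keeps the example deterministic
    y = _apply_optional_bn_to_tensor(x=x, output_dropout=output_dropout, batch_norm=batch_norm)
    # the original shape is restored after the internal (24, 8) reshape
    assert y.shape == x.shape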


def _add_cuda_warning(func):
    @functools.wraps(func)
    def wrapped(*args, **kwargs):
        try:
            return func(*args, **kwargs)
        except RuntimeError as e:
            if not is_cudnn_error(e):
                raise e
            raise RuntimeError(
                '\nThis code crash might have been caused by a CUDA bug, see '
                'https://github.com/allenai/allennlp/issues/2888, '
                'which causes the code to crash during evaluation mode.\n'
                'To avoid this error, the batch size has to be reduced.',
            ) from e

    return wrapped


def complex_interaction(
    h: torch.FloatTensor,
    r: torch.FloatTensor,
    t: torch.FloatTensor,
) -> torch.FloatTensor:
    r"""Evaluate the ComplEx interaction function.

    .. math ::
        Re(\langle h, r, conj(t) \rangle)

    :param h: shape: (batch_size, num_heads, 1, 1, 2*dim)
        The complex head representations.
    :param r: shape: (batch_size, 1, num_relations, 1, 2*dim)
        The complex relation representations.
    :param t: shape: (batch_size, 1, 1, num_tails, 2*dim)
        The complex tail representations.

    :return: shape: (batch_size, num_heads, num_relations, num_tails)
        The scores.
    """
    return _complex_native_complex(h, r, t)


@_add_cuda_warning
def conve_interaction(
    h: torch.FloatTensor,
    r: torch.FloatTensor,
    t: torch.FloatTensor,
    t_bias: torch.FloatTensor,
    input_channels: int,
    embedding_height: int,
    embedding_width: int,
    hr2d: nn.Module,
    hr1d: nn.Module,
) -> torch.FloatTensor:
    """Evaluate the ConvE interaction function.

    :param h: shape: (batch_size, num_heads, 1, 1, dim)
        The head representations.
    :param r: shape: (batch_size, 1, num_relations, 1, dim)
        The relation representations.
    :param t: shape: (batch_size, 1, 1, num_tails, dim)
        The tail representations.
    :param t_bias: shape: (batch_size, 1, 1, num_tails, 1)
        The tail entity bias.
    :param input_channels:
        The number of input channels.
    :param embedding_height:
        The height of the reshaped embedding.
    :param embedding_width:
        The width of the reshaped embedding.
    :param hr2d:
        The first module, transforming the 2D stacked head-relation "image".
    :param hr1d:
        The second module, transforming the 1D flattened output of the 2D module.

    :return: shape: (batch_size, num_heads, num_relations, num_tails)
        The scores.
    """
    # repeat if necessary, and concat head and relation, batch_size', num_input_channels, 2*height, width
    # with batch_size' = batch_size * num_heads * num_relations
    x = broadcast_cat(
        h.view(*h.shape[:-1], input_channels, embedding_height, embedding_width),
        r.view(*r.shape[:-1], input_channels, embedding_height, embedding_width),
        dim=-2,
    ).view(-1, input_channels, 2 * embedding_height, embedding_width)

    # batch_size', num_input_channels, 2*height, width
    x = hr2d(x)

    # batch_size', num_output_channels * (2 * height - kernel_height + 1) * (width - kernel_width + 1)
    x = x.view(-1, numpy.prod(x.shape[-3:]))
    x = hr1d(x)

    # reshape: (batch_size', embedding_dim) -> (b, h, r, 1, d)
    x = x.view(-1, h.shape[1], r.shape[2], 1, h.shape[-1])

    # For efficient calculation, each of the convolved [h, r] rows has only to be multiplied with one t row
    # output_shape: (batch_size, num_heads, num_relations, num_tails)
    t = t.transpose(-1, -2)
    x = (x @ t).squeeze(dim=-2)

    # add bias term
    return x + t_bias.squeeze(dim=-1)


def convkb_interaction(
    h: torch.FloatTensor,
    r: torch.FloatTensor,
    t: torch.FloatTensor,
    conv: nn.Conv2d,
    activation: nn.Module,
    hidden_dropout: nn.Dropout,
    linear: nn.Linear,
) -> torch.FloatTensor:
    r"""Evaluate the ConvKB interaction function.

    .. math::
        W_L drop(act(W_C \ast ([h; r; t]) + b_C)) + b_L

    :param h: shape: (batch_size, num_heads, 1, 1, dim)
        The head representations.
    :param r: shape: (batch_size, 1, num_relations, 1, dim)
        The relation representations.
    :param t: shape: (batch_size, 1, 1, num_tails, dim)
        The tail representations.
    :param conv:
        The 3x1 convolution.
    :param activation:
        The activation function.
    :param hidden_dropout:
        The dropout layer applied to the hidden activations.
    :param linear:
        The final linear layer.

    :return: shape: (batch_size, num_heads, num_relations, num_tails)
        The scores.
    """
    # decompose convolution for faster computation in 1-n case
    num_filters = conv.weight.shape[0]
    assert conv.weight.shape == (num_filters, 1, 1, 3)

    # compute conv(stack(h, r, t))
    # prepare input shapes for broadcasting
    # (b, h, r, t, 1, d)
    h = h.unsqueeze(dim=-2)
    r = r.unsqueeze(dim=-2)
    t = t.unsqueeze(dim=-2)

    # conv.weight.shape = (C_out, C_in, kernel_size[0], kernel_size[1])
    # here, kernel_size = (1, 3), C_in = 1, C_out = num_filters
    # -> conv_head, conv_rel, conv_tail shapes: (num_filters,)
    # reshape to (1, 1, 1, 1, f, 1)
    conv_head, conv_rel, conv_tail, conv_bias = [
        c.view(1, 1, 1, 1, num_filters, 1)
        for c in list(conv.weight[:, 0, 0, :].t()) + [conv.bias]
    ]

    # convolve -> output.shape: (*, num_filters, embedding_dim)
    h = conv_head @ h
    r = conv_rel @ r
    t = conv_tail @ t

    x = tensor_sum(conv_bias, h, r, t)
    x = activation(x)

    # Apply dropout, cf. https://github.com/daiquocnguyen/ConvKB/blob/master/model.py#L54-L56
    x = hidden_dropout(x)

    # Linear layer for final scores; use flattened representations, shape: (b, h, r, t, f * d)
    x = x.view(*x.shape[:-2], -1)
    x = linear(x)
    return x.squeeze(dim=-1)


def distmult_interaction(
    h: torch.FloatTensor,
    r: torch.FloatTensor,
    t: torch.FloatTensor,
) -> torch.FloatTensor:
    """Evaluate the DistMult interaction function.

    :param h: shape: (batch_size, num_heads, 1, 1, dim)
        The head representations.
    :param r: shape: (batch_size, 1, num_relations, 1, dim)
        The relation representations.
    :param t: shape: (batch_size, 1, 1, num_tails, dim)
        The tail representations.

    :return: shape: (batch_size, num_heads, num_relations, num_tails)
        The scores.
    """
    return tensor_product(h, r, t).sum(dim=-1)
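

# Illustrative sketch (not part of the original module): the broadcasting convention shared
# by all interaction functions in this module, demonstrated with DistMult. The batch size,
# the head/relation/tail counts (3, 5, 7), and the dimension are arbitrary example values;
# the reference computation assumes tensor_product is the broadcast element-wise product.
def _example_distmult_broadcasting() -> None:
    h = torch.rand(2, 3, 1, 1, 8)
    r = torch.rand(2, 1, 5, 1, 8)
    t = torch.rand(2, 1, 1, 7, 8)
    scores = distmult_interaction(h=h, r=r, t=t)
    # one score per (head, relation, tail) combination within each batch
    assert scores.shape == (2, 3, 5, 7)
    # DistMult is a tri-linear dot product: <h, r, t> = sum_i h_i * r_i * t_i
    assert torch.allclose(scores, (h * r * t).sum(dim=-1))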


def ermlp_interaction(
    h: torch.FloatTensor,
    r: torch.FloatTensor,
    t: torch.FloatTensor,
    hidden: nn.Linear,
    activation: nn.Module,
    final: nn.Linear,
) -> torch.FloatTensor:
    r"""Evaluate the ER-MLP interaction function.

    :param h: shape: (batch_size, num_heads, 1, 1, dim)
        The head representations.
    :param r: shape: (batch_size, 1, num_relations, 1, dim)
        The relation representations.
    :param t: shape: (batch_size, 1, 1, num_tails, dim)
        The tail representations.
    :param hidden:
        The first linear layer.
    :param activation:
        The activation function of the hidden layer.
    :param final:
        The second linear layer.

    :return: shape: (batch_size, num_heads, num_relations, num_tails)
        The scores.
    """
    sizes = SizeInformation.extract(h, r, t)

    # same shape
    if sizes.same:
        return final(
            activation(hidden(torch.cat([h, r, t], dim=-1).view(-1, 3 * h.shape[-1]))),
        ).view(sizes.bh, sizes.nh, sizes.nr, sizes.nt)

    hidden_dim = hidden.weight.shape[0]
    # split, shape: (embedding_dim, hidden_dim)
    head_to_hidden, rel_to_hidden, tail_to_hidden = hidden.weight.t().split(h.shape[-1])
    bias = hidden.bias.view(1, 1, 1, 1, -1)
    h = h @ head_to_hidden.view(1, 1, 1, h.shape[-1], hidden_dim)
    r = r @ rel_to_hidden.view(1, 1, 1, r.shape[-1], hidden_dim)
    t = t @ tail_to_hidden.view(1, 1, 1, t.shape[-1], hidden_dim)
    return final(activation(tensor_sum(bias, h, r, t))).squeeze(dim=-1)


def ermlpe_interaction(
    h: torch.FloatTensor,
    r: torch.FloatTensor,
    t: torch.FloatTensor,
    mlp: nn.Module,
) -> torch.FloatTensor:
    r"""Evaluate the ER-MLPE interaction function.

    :param h: shape: (batch_size, num_heads, 1, 1, dim)
        The head representations.
    :param r: shape: (batch_size, 1, num_relations, 1, dim)
        The relation representations.
    :param t: shape: (batch_size, 1, 1, num_tails, dim)
        The tail representations.
    :param mlp:
        The MLP.

    :return: shape: (batch_size, num_heads, num_relations, num_tails)
        The scores.
    """
    # repeat if necessary, and concat head and relation, (batch_size, num_heads, num_relations, 1, 2 * embedding_dim)
    x = broadcast_cat(h, r, dim=-1)

    # Predict t embedding, shape: (b, h, r, 1, d)
    shape = x.shape
    x = mlp(x.view(-1, shape[-1])).view(*shape[:-1], -1)

    # transpose t, (b, 1, 1, d, t)
    t = t.transpose(-2, -1)

    # dot product, (b, h, r, 1, t)
    return (x @ t).squeeze(dim=-2)


def hole_interaction(
    h: torch.FloatTensor,
    r: torch.FloatTensor,
    t: torch.FloatTensor,
) -> torch.FloatTensor:
    """Evaluate the HolE interaction function.

    :param h: shape: (batch_size, num_heads, 1, 1, dim)
        The head representations.
    :param r: shape: (batch_size, 1, num_relations, 1, dim)
        The relation representations.
    :param t: shape: (batch_size, 1, 1, num_tails, dim)
        The tail representations.

    :return: shape: (batch_size, num_heads, num_relations, num_tails)
        The scores.
    """
    # Circular correlation of entity embeddings
    a_fft = torch.fft.rfft(h, dim=-1)
    b_fft = torch.fft.rfft(t, dim=-1)

    # complex conjugate
    a_fft = torch.conj(a_fft)

    # Hadamard product in frequency domain
    p_fft = a_fft * b_fft

    # inverse real FFT, shape: (b, h, 1, t, d)
    composite = torch.fft.irfft(p_fft, n=h.shape[-1], dim=-1)

    # transpose composite: (b, h, 1, d, t)
    composite = composite.transpose(-2, -1)

    # inner product with relation embedding
    return (r @ composite).squeeze(dim=-2)
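

# Illustrative sketch (not part of the original module): the FFT-based composite computed
# above equals the circular correlation of h and t, whose k-th entry is
# sum_i h_i * t_{(i + k) mod d}. The naive reference below uses torch.roll on a single
# pair of vectors of arbitrary dimension 8.
def _example_hole_circular_correlation() -> None:
    torch.manual_seed(0)
    h = torch.rand(8)
    t = torch.rand(8)
    fft_corr = torch.fft.irfft(torch.conj(torch.fft.rfft(h)) * torch.fft.rfft(t), n=8)
    naive_corr = torch.stack([(h * torch.roll(t, shifts=-k)).sum() for k in range(8)])
    assert torch.allclose(fft_corr, naive_corr, atol=1.0e-05)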


def kg2e_interaction(
    h_mean: torch.FloatTensor,
    h_var: torch.FloatTensor,
    r_mean: torch.FloatTensor,
    r_var: torch.FloatTensor,
    t_mean: torch.FloatTensor,
    t_var: torch.FloatTensor,
    similarity: str = "KL",
    exact: bool = True,
) -> torch.FloatTensor:
    """Evaluate the KG2E interaction function.

    :param h_mean: shape: (batch_size, num_heads, 1, 1, d)
        The head entity distribution mean.
    :param h_var: shape: (batch_size, num_heads, 1, 1, d)
        The head entity distribution variance.
    :param r_mean: shape: (batch_size, 1, num_relations, 1, d)
        The relation distribution mean.
    :param r_var: shape: (batch_size, 1, num_relations, 1, d)
        The relation distribution variance.
    :param t_mean: shape: (batch_size, 1, 1, num_tails, d)
        The tail entity distribution mean.
    :param t_var: shape: (batch_size, 1, 1, num_tails, d)
        The tail entity distribution variance.
    :param similarity:
        The similarity measure for Gaussian distributions, from {"KL", "EL"}.
    :param exact:
        Whether to compute the exact similarity, i.e. including constant terms, instead of
        leaving them out to accelerate the computation.

    :return: shape: (batch_size, num_heads, num_relations, num_tails)
        The scores.
    """
    return KG2E_SIMILARITIES[similarity](
        h=GaussianDistribution(mean=h_mean, diagonal_covariance=h_var),
        r=GaussianDistribution(mean=r_mean, diagonal_covariance=r_var),
        t=GaussianDistribution(mean=t_mean, diagonal_covariance=t_var),
        exact=exact,
    )


def ntn_interaction(
    h: torch.FloatTensor,
    t: torch.FloatTensor,
    w: torch.FloatTensor,
    vh: torch.FloatTensor,
    vt: torch.FloatTensor,
    b: torch.FloatTensor,
    u: torch.FloatTensor,
    activation: nn.Module,
) -> torch.FloatTensor:
    r"""Evaluate the NTN interaction function.

    .. math::
        f(h,r,t) = u_r^T act(h W_r t + V_r h + V_r' t + b_r)

    :param h: shape: (batch_size, num_heads, 1, 1, dim)
        The head representations.
    :param t: shape: (batch_size, 1, 1, num_tails, dim)
        The tail representations.
    :param w: shape: (batch_size, 1, num_relations, 1, k, dim, dim)
        The relation specific transformation matrix W_r.
    :param vh: shape: (batch_size, 1, num_relations, 1, k, dim)
        The head transformation matrix V_h.
    :param vt: shape: (batch_size, 1, num_relations, 1, k, dim)
        The tail transformation matrix V_t.
    :param b: shape: (batch_size, 1, num_relations, 1, k)
        The relation specific offset b_r.
    :param u: shape: (batch_size, 1, num_relations, 1, k)
        The relation specific final linear transformation u_r.
    :param activation:
        The activation function.

    :return: shape: (batch_size, num_heads, num_relations, num_tails)
        The scores.
    """
    x = activation(tensor_sum(
        extended_einsum("bhrtd,bhrtkde,bhrte->bhrtk", h, w, t),
        (vh @ h.unsqueeze(dim=-1)).squeeze(dim=-1),
        (vt @ t.unsqueeze(dim=-1)).squeeze(dim=-1),
        b,
    ))
    u = u.transpose(-2, -1)
    return (x @ u).squeeze(dim=-1)


def proje_interaction(
    h: torch.FloatTensor,
    r: torch.FloatTensor,
    t: torch.FloatTensor,
    d_e: torch.FloatTensor,
    d_r: torch.FloatTensor,
    b_c: torch.FloatTensor,
    b_p: torch.FloatTensor,
    activation: nn.Module,
) -> torch.FloatTensor:
    r"""Evaluate the ProjE interaction function.

    .. math::
        f(h, r, t) = g(t z(D_e h + D_r r + b_c) + b_p)

    :param h: shape: (batch_size, num_heads, 1, 1, dim)
        The head representations.
    :param r: shape: (batch_size, 1, num_relations, 1, dim)
        The relation representations.
    :param t: shape: (batch_size, 1, 1, num_tails, dim)
        The tail representations.
    :param d_e: shape: (dim,)
        Global entity projection.
    :param d_r: shape: (dim,)
        Global relation projection.
    :param b_c: shape: (dim,)
        Global combination bias.
    :param b_p: shape: (1,)
        Final score bias.
    :param activation:
        The activation function.

    :return: shape: (batch_size, num_heads, num_relations, num_tails)
        The scores.
    """
    num_heads, num_relations, num_tails, dim, _ = _extract_sizes(h, r, t)

    # global projections
    h = h * d_e.view(1, 1, 1, 1, dim)
    r = r * d_r.view(1, 1, 1, 1, dim)

    # combination, shape: (b, h, r, 1, d)
    x = tensor_sum(h, r, b_c)
    x = activation(x)  # shape: (b, h, r, 1, d)

    # dot product with t, shape: (b, h, r, t)
    t = t.transpose(-2, -1)  # shape: (b, 1, 1, d, t)
    return (x @ t).squeeze(dim=-2) + b_p


def rescal_interaction(
    h: torch.FloatTensor,
    r: torch.FloatTensor,
    t: torch.FloatTensor,
) -> torch.FloatTensor:
    """Evaluate the RESCAL interaction function.

    :param h: shape: (batch_size, num_heads, 1, 1, dim)
        The head representations.
    :param r: shape: (batch_size, 1, num_relations, 1, dim, dim)
        The relation representations.
    :param t: shape: (batch_size, 1, 1, num_tails, dim)
        The tail representations.

    :return: shape: (batch_size, num_heads, num_relations, num_tails)
        The scores.
    """
    return extended_einsum("bhrtd,bhrtde,bhrte->bhrt", h, r, t)


def rotate_interaction(
    h: torch.FloatTensor,
    r: torch.FloatTensor,
    t: torch.FloatTensor,
) -> torch.FloatTensor:
    """Evaluate the RotatE interaction function.

    :param h: shape: (batch_size, num_heads, 1, 1, 2*dim)
        The head representations.
    :param r: shape: (batch_size, 1, num_relations, 1, 2*dim)
        The relation representations.
    :param t: shape: (batch_size, 1, 1, num_tails, 2*dim)
        The tail representations.

    :return: shape: (batch_size, num_heads, num_relations, num_tails)
        The scores.
    """
    # r expresses a rotation in the complex plane.
    h, r, t = [view_complex(x) for x in (h, r, t)]
    if estimate_cost_of_sequence(h.shape, r.shape) < estimate_cost_of_sequence(r.shape, t.shape):
        # rotate head by relation (= Hadamard product in complex space)
        h = h * r
    else:
        # rotate tail by inverse of relation
        # The inverse rotation is expressed by the complex conjugate of r.
        # The score is computed as the distance of the relation-rotated head to the tail.
        # Equivalently, we can rotate the tail by the inverse relation, and measure the distance to the head, i.e.
        # |h * r - t| = |h - conj(r) * t|
        t = t * torch.conj(r)

    # Workaround until https://github.com/pytorch/pytorch/issues/30704 is fixed
    return negative_norm(h - t, p=2, power_norm=False)
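

# Illustrative sketch (not part of the original module): the identity exploited in the
# else-branch above. When every relation coefficient has unit modulus, as RotatE enforces,
# rotating the tail by the conjugate relation yields the same distance as rotating the
# head by the relation, i.e. |h * r - t| = |h - conj(r) * t|.
def _example_rotate_conjugate_identity() -> None:
    torch.manual_seed(42)
    h = torch.randn(3, 8, dtype=torch.cfloat)
    t = torch.randn(3, 8, dtype=torch.cfloat)
    r = torch.polar(torch.ones(3, 8), 6.28 * torch.rand(3, 8))  # unit-modulus coefficients
    lhs = (h * r - t).abs().pow(2).sum(dim=-1)
    rhs = (h - torch.conj(r) * t).abs().pow(2).sum(dim=-1)
    assert torch.allclose(lhs, rhs, atol=1.0e-04)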


def simple_interaction(
    h: torch.FloatTensor,
    r: torch.FloatTensor,
    t: torch.FloatTensor,
    h_inv: torch.FloatTensor,
    r_inv: torch.FloatTensor,
    t_inv: torch.FloatTensor,
    clamp: Optional[Tuple[float, float]] = None,
) -> torch.FloatTensor:
    """Evaluate the SimplE interaction function.

    :param h: shape: (batch_size, num_heads, 1, 1, dim)
        The head representations.
    :param r: shape: (batch_size, 1, num_relations, 1, dim)
        The relation representations.
    :param t: shape: (batch_size, 1, 1, num_tails, dim)
        The tail representations.
    :param h_inv: shape: (batch_size, num_heads, 1, 1, dim)
        The inverse head representations.
    :param r_inv: shape: (batch_size, 1, num_relations, 1, dim)
        The inverse relation representations.
    :param t_inv: shape: (batch_size, 1, 1, num_tails, dim)
        The inverse tail representations.
    :param clamp:
        Clamp the scores to the given range.

    :return: shape: (batch_size, num_heads, num_relations, num_tails)
        The scores.
    """
    scores = 0.5 * (distmult_interaction(h=h, r=r, t=t) + distmult_interaction(h=h_inv, r=r_inv, t=t_inv))

    # Note: In the authors' reference implementation, the score is clamped to [-20, 20].
    # That is not mentioned in the paper, so it is made optional here.
    if clamp:
        min_, max_ = clamp
        scores = scores.clamp(min=min_, max=max_)

    return scores


def structured_embedding_interaction(
    h: torch.FloatTensor,
    r_h: torch.FloatTensor,
    r_t: torch.FloatTensor,
    t: torch.FloatTensor,
    p: int,
    power_norm: bool = False,
) -> torch.FloatTensor:
    r"""Evaluate the Structured Embedding interaction function.

    .. math ::
        f(h, r, t) = -\|R_h h - R_t t\|

    :param h: shape: (batch_size, num_heads, 1, 1, dim)
        The head representations.
    :param r_h: shape: (batch_size, 1, num_relations, 1, rel_dim, dim)
        The relation-specific head projection.
    :param r_t: shape: (batch_size, 1, num_relations, 1, rel_dim, dim)
        The relation-specific tail projection.
    :param t: shape: (batch_size, 1, 1, num_tails, dim)
        The tail representations.
    :param p:
        The p for the norm. cf. torch.norm.
    :param power_norm:
        Whether to return the powered norm.

    :return: shape: (batch_size, num_heads, num_relations, num_tails)
        The scores.
    """
    return negative_norm(
        (r_h @ h.unsqueeze(dim=-1) - r_t @ t.unsqueeze(dim=-1)).squeeze(dim=-1),
        p=p,
        power_norm=power_norm,
    )


def transd_interaction(
    h: torch.FloatTensor,
    r: torch.FloatTensor,
    t: torch.FloatTensor,
    h_p: torch.FloatTensor,
    r_p: torch.FloatTensor,
    t_p: torch.FloatTensor,
    p: int,
    power_norm: bool = False,
) -> torch.FloatTensor:
    """Evaluate the TransD interaction function.

    :param h: shape: (batch_size, num_heads, 1, 1, d_e)
        The head representations.
    :param r: shape: (batch_size, 1, num_relations, 1, d_r)
        The relation representations.
    :param t: shape: (batch_size, 1, 1, num_tails, d_e)
        The tail representations.
    :param h_p: shape: (batch_size, num_heads, 1, 1, d_e)
        The head projections.
    :param r_p: shape: (batch_size, 1, num_relations, 1, d_r)
        The relation projections.
    :param t_p: shape: (batch_size, 1, 1, num_tails, d_e)
        The tail projections.
    :param p:
        The parameter p for selecting the norm.
    :param power_norm:
        Whether to return the powered norm instead.

    :return: shape: (batch_size, num_heads, num_relations, num_tails)
        The scores.
    """
    # Project entities
    h_bot = project_entity(
        e=h,
        e_p=h_p,
        r_p=r_p,
    )
    t_bot = project_entity(
        e=t,
        e_p=t_p,
        r_p=r_p,
    )
    return negative_norm_of_sum(h_bot, r, -t_bot, p=p, power_norm=power_norm)


def transe_interaction(
    h: torch.FloatTensor,
    r: torch.FloatTensor,
    t: torch.FloatTensor,
    p: Union[int, str] = 2,
    power_norm: bool = False,
) -> torch.FloatTensor:
    """Evaluate the TransE interaction function.

    :param h: shape: (batch_size, num_heads, 1, 1, dim)
        The head representations.
    :param r: shape: (batch_size, 1, num_relations, 1, dim)
        The relation representations.
    :param t: shape: (batch_size, 1, 1, num_tails, dim)
        The tail representations.
    :param p:
        The p for the norm.
    :param power_norm:
        Whether to return the powered norm.

    :return: shape: (batch_size, num_heads, num_relations, num_tails)
        The scores.
    """
    return negative_norm_of_sum(h, r, -t, p=p, power_norm=power_norm)
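

# Illustrative sketch (not part of the original module): TransE scores are negative distances,
# assuming negative_norm_of_sum computes -||h + r - t||_p over the last dimension with
# broadcasting. Shapes below are arbitrary example values.
def _example_transe_scores() -> None:
    torch.manual_seed(0)
    h = torch.rand(2, 3, 1, 1, 8)
    r = torch.rand(2, 1, 5, 1, 8)
    t = torch.rand(2, 1, 1, 7, 8)
    scores = transe_interaction(h=h, r=r, t=t, p=2)
    assert scores.shape == (2, 3, 5, 7)
    assert torch.allclose(scores, -(h + r - t).norm(p=2, dim=-1))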


def transh_interaction(
    h: torch.FloatTensor,
    w_r: torch.FloatTensor,
    d_r: torch.FloatTensor,
    t: torch.FloatTensor,
    p: int,
    power_norm: bool = False,
) -> torch.FloatTensor:
    """Evaluate the TransH interaction function.

    :param h: shape: (batch_size, num_heads, 1, 1, dim)
        The head representations.
    :param w_r: shape: (batch_size, 1, num_relations, 1, dim)
        The relation normal vector representations.
    :param d_r: shape: (batch_size, 1, num_relations, 1, dim)
        The relation difference vector representations.
    :param t: shape: (batch_size, 1, 1, num_tails, dim)
        The tail representations.
    :param p:
        The p for the norm. cf. torch.norm.
    :param power_norm:
        Whether to return $|x-y|_p^p$.

    :return: shape: (batch_size, num_heads, num_relations, num_tails)
        The scores.
    """
    return negative_norm_of_sum(
        # h projection to hyperplane
        h,
        -(h * w_r).sum(dim=-1, keepdims=True) * w_r,
        # r
        d_r,
        # -t projection to hyperplane
        -t,
        (t * w_r).sum(dim=-1, keepdims=True) * w_r,
        p=p,
        power_norm=power_norm,
    )
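

# Illustrative sketch (not part of the original module): the five summands passed to
# negative_norm_of_sum above decompose -||(h - <h, w_r> w_r) + d_r - (t - <t, w_r> w_r)||_p,
# i.e. TransH first projects head and tail onto the relation-specific hyperplane with normal
# vector w_r and then applies a TransE-style translation by d_r. The reference computation
# assumes the same negative_norm_of_sum semantics as in the TransE example above.
def _example_transh_projection() -> None:
    torch.manual_seed(0)
    h = torch.rand(2, 3, 1, 1, 8)
    w_r = torch.rand(2, 1, 5, 1, 8)
    d_r = torch.rand(2, 1, 5, 1, 8)
    t = torch.rand(2, 1, 1, 7, 8)
    scores = transh_interaction(h=h, w_r=w_r, d_r=d_r, t=t, p=2)
    h_proj = h - (h * w_r).sum(dim=-1, keepdim=True) * w_r
    t_proj = t - (t * w_r).sum(dim=-1, keepdim=True) * w_r
    assert torch.allclose(scores, -(h_proj + d_r - t_proj).norm(p=2, dim=-1))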


def transr_interaction(
    h: torch.FloatTensor,
    r: torch.FloatTensor,
    t: torch.FloatTensor,
    m_r: torch.FloatTensor,
    p: int,
    power_norm: bool = True,
) -> torch.FloatTensor:
    """Evaluate the TransR interaction function.

    :param h: shape: (batch_size, num_heads, 1, 1, d_e)
        Head embeddings.
    :param r: shape: (batch_size, 1, num_relations, 1, d_r)
        Relation embeddings.
    :param m_r: shape: (batch_size, 1, num_relations, 1, d_e, d_r)
        The relation specific linear transformations.
    :param t: shape: (batch_size, 1, 1, num_tails, d_e)
        Tail embeddings.
    :param p:
        The parameter p for selecting the norm.
    :param power_norm:
        Whether to return the powered norm instead.

    :return: shape: (batch_size, num_heads, num_relations, num_tails)
        The scores.
    """
    # project to relation specific subspace and ensure constraints
    h_bot = clamp_norm((h.unsqueeze(dim=-2) @ m_r), p=2, dim=-1, maxnorm=1.).squeeze(dim=-2)
    t_bot = clamp_norm((t.unsqueeze(dim=-2) @ m_r), p=2, dim=-1, maxnorm=1.).squeeze(dim=-2)
    return negative_norm_of_sum(h_bot, r, -t_bot, p=p, power_norm=power_norm)


def tucker_interaction(
    h: torch.FloatTensor,
    r: torch.FloatTensor,
    t: torch.FloatTensor,
    core_tensor: torch.FloatTensor,
    do_h: nn.Dropout,
    do_r: nn.Dropout,
    do_hr: nn.Dropout,
    bn_h: Optional[nn.BatchNorm1d],
    bn_hr: Optional[nn.BatchNorm1d],
) -> torch.FloatTensor:
    r"""Evaluate the TuckER interaction function.

    Compute the scoring function W x_1 h x_2 r x_3 t as in the official implementation, i.e. as

    .. math ::

        DO_{hr}(BN_{hr}(DO_h(BN_h(h)) x_1 DO_r(W x_2 r))) x_3 t

    where BN denotes BatchNorm and DO denotes Dropout.

    :param h: shape: (batch_size, num_heads, 1, 1, d_e)
        The head representations.
    :param r: shape: (batch_size, 1, num_relations, 1, d_r)
        The relation representations.
    :param t: shape: (batch_size, 1, 1, num_tails, d_e)
        The tail representations.
    :param core_tensor: shape: (d_e, d_r, d_e)
        The core tensor.
    :param do_h:
        The dropout layer for the head representations.
    :param do_r:
        The first hidden dropout.
    :param do_hr:
        The second hidden dropout.
    :param bn_h:
        The first batch normalization layer.
    :param bn_hr:
        The second batch normalization layer.

    :return: shape: (batch_size, num_heads, num_relations, num_tails)
        The scores.
    """
    return extended_einsum(
        # x_3 contraction
        "bhrtk,bhrtk->bhrt",
        _apply_optional_bn_to_tensor(
            x=extended_einsum(
                # x_1 contraction
                "bhrtik,bhrti->bhrtk",
                _apply_optional_bn_to_tensor(
                    x=extended_einsum(
                        # x_2 contraction
                        "ijk,bhrtj->bhrtik",
                        core_tensor,
                        r,
                    ),
                    output_dropout=do_r,
                ),
                _apply_optional_bn_to_tensor(
                    x=h,
                    batch_norm=bn_h,
                    output_dropout=do_h,
                ),
            ),
            batch_norm=bn_hr,
            output_dropout=do_hr,
        ),
        t,
    )


def unstructured_model_interaction(
    h: torch.FloatTensor,
    t: torch.FloatTensor,
    p: int,
    power_norm: bool = True,
) -> torch.FloatTensor:
    """Evaluate the Unstructured Model interaction function.

    :param h: shape: (batch_size, num_heads, 1, 1, dim)
        The head representations.
    :param t: shape: (batch_size, 1, 1, num_tails, dim)
        The tail representations.
    :param p:
        The parameter p for selecting the norm.
    :param power_norm:
        Whether to return the powered norm instead.

    :return: shape: (batch_size, num_heads, num_relations, num_tails)
        The scores.
    """
    return negative_norm(h - t, p=p, power_norm=power_norm)
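

# Illustrative sketch (not part of the original module): the Unstructured Model ignores the
# relation entirely, so the relation axis of the score tensor stays singleton. With the
# default power_norm=True, the score is assumed to be the negative sum of the element-wise
# p-th powers of |h - t|.
def _example_unstructured_model_scores() -> None:
    torch.manual_seed(0)
    h = torch.rand(2, 3, 1, 1, 8)
    t = torch.rand(2, 1, 1, 7, 8)
    scores = unstructured_model_interaction(h=h, t=t, p=2)
    assert scores.shape == (2, 3, 1, 7)
    assert torch.allclose(scores, -(h - t).pow(2).sum(dim=-1))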