# Source code for pykeen.nn.modules

# -*- coding: utf-8 -*-

"""Stateful interaction functions."""

from __future__ import annotations

import logging
import math
from abc import ABC, abstractmethod
from typing import (
    Any, Callable, Generic, Mapping, MutableMapping, Optional, Sequence, Tuple, Union,
    cast,
)

import torch
from torch import FloatTensor, nn

from . import functional as pkf
from ..typing import HeadRepresentation, RelationRepresentation, TailRepresentation
from ..utils import CANONICAL_DIMENSIONS, convert_to_canonical_shape, ensure_tuple, upgrade_to_sequence

# Public API of this module: the abstract bases first, then every concrete interaction module defined below.
__all__ = [
    # Base Classes
    'Interaction',
    'FunctionalInteraction',
    'TranslationalInteraction',
    # Concrete Classes
    'ComplExInteraction',
    'ConvEInteraction',
    'ConvKBInteraction',
    'DistMultInteraction',
    'ERMLPInteraction',
    'ERMLPEInteraction',
    'HolEInteraction',
    'KG2EInteraction',
    'NTNInteraction',
    'ProjEInteraction',
    'RESCALInteraction',
    'RotatEInteraction',
    'SimplEInteraction',
    'StructuredEmbeddingInteraction',
    'TransDInteraction',
    'TransEInteraction',
    'TransHInteraction',
    'TransRInteraction',
    'TuckerInteraction',
    'UnstructuredModelInteraction',
]

logger = logging.getLogger(__name__)

def _get_batches(z, slice_size):
    """Yield consecutive slices of the representation(s) ``z``.

    Every tensor in ``ensure_tuple(z)[0]`` is split into chunks of (at most) ``slice_size`` along dimension 1.
    The chunks of all tensors are then yielded in lockstep.

    :param z:
        A single tensor, or a sequence of tensors.
    :param slice_size:
        The chunk size passed to ``split``.

    :yields: The bare chunk if there is only one tensor, otherwise a tuple with one chunk per tensor.
    """
    tensors = ensure_tuple(z)[0]
    chunked = [tensor.split(slice_size, dim=1) for tensor in tensors]
    for parts in zip(*chunked):
        # a single representation is passed on without the wrapping tuple
        yield parts[0] if len(parts) == 1 else parts

class Interaction(nn.Module, Generic[HeadRepresentation, RelationRepresentation, TailRepresentation], ABC):
    """Base class for interaction functions.

    Subclasses implement :meth:`forward`, which scores representations that are already given in the canonical,
    broadcastable layout ``(batch_size, num_heads, num_relations, num_tails, *)``. The ``score*`` methods bring
    representations into that layout and optionally compute the scores slice by slice.
    """

    #: The symbolic shapes for entity representations
    entity_shape: Sequence[str] = ("d",)

    #: The symbolic shapes for entity representations for tail entities, if different.
    #: This is only relevant for ConvE.
    tail_entity_shape: Optional[Sequence[str]] = None

    #: The symbolic shapes for relation representations
    relation_shape: Sequence[str] = ("d",)

    @abstractmethod
    def forward(
        self,
        h: HeadRepresentation,
        r: RelationRepresentation,
        t: TailRepresentation,
    ) -> torch.FloatTensor:
        """Compute broadcasted triple scores given broadcasted representations for head, relation and tails.

        :param h: shape: (batch_size, num_heads, 1, 1, ``*``)
            The head representations.
        :param r: shape: (batch_size, 1, num_relations, 1, ``*``)
            The relation representations.
        :param t: shape: (batch_size, 1, 1, num_tails, ``*``)
            The tail representations.

        :return: shape: (batch_size, num_heads, num_relations, num_tails)
            The scores.
        """

    def score(
        self,
        h: HeadRepresentation,
        r: RelationRepresentation,
        t: TailRepresentation,
        slice_size: Optional[int] = None,
        slice_dim: Optional[str] = None,
    ) -> torch.FloatTensor:
        """Compute broadcasted triple scores with optional slicing.

        .. note ::
            At most one of the slice sizes may be not None.

        :param h: shape: (batch_size, num_heads, 1, 1, ``*``)
            The head representations.
        :param r: shape: (batch_size, 1, num_relations, 1, ``*``)
            The relation representations.
        :param t: shape: (batch_size, 1, 1, num_tails, ``*``)
            The tail representations.
        :param slice_size:
            The slice size.
        :param slice_dim:
            The dimension along which to slice. From {"h", "r", "t"}

        :return: shape: (batch_size, num_heads, num_relations, num_tails)
            The scores.
        """
        return self._forward_slicing_wrapper(h=h, r=r, t=t, slice_size=slice_size, slice_dim=slice_dim)

    def _score(
        self,
        h: HeadRepresentation,
        r: RelationRepresentation,
        t: TailRepresentation,
        slice_size: Optional[int] = None,
        slice_dim: Optional[str] = None,
    ) -> torch.FloatTensor:
        """Compute scores for the score_* methods outside of models.

        TODO: merge this with the Model utilities?

        The representation selected by ``slice_dim`` is expected without a batch dimension, i.e. as ``(n, *)``;
        a batch dimension of size one is prepended to it. All other representations are expected as ``(b, *)``.

        :param h: shape: (b, *), or (n, *) if ``slice_dim == "h"``
        :param r: shape: (b, *), or (n, *) if ``slice_dim == "r"``
        :param t: shape: (b, *), or (n, *) if ``slice_dim == "t"``
        :param slice_size:
            The slice size.
        :param slice_dim:
            The dimension along which to slice. From {"h", "r", "t"}

        :return: shape: (b, h, r, t)
        """
        args = []
        for key, x in zip("hrt", (h, r, t)):
            value = []
            for xx in upgrade_to_sequence(x):
                # bring to (b, n, *)
                xx = xx.unsqueeze(dim=1 if key != slice_dim else 0)
                # bring to (b, h, r, t, *)
                xx = convert_to_canonical_shape(
                    x=xx,
                    dim=key,
                    num=xx.shape[1],
                    batch_size=xx.shape[0],
                    suffix_shape=xx.shape[2:],
                )
                value.append(xx)
            # unpack singleton
            if len(value) == 1:
                value = value[0]
            args.append(value)
        h, r, t = cast(Tuple[HeadRepresentation, RelationRepresentation, TailRepresentation], args)
        return self._forward_slicing_wrapper(h=h, r=r, t=t, slice_dim=slice_dim, slice_size=slice_size)

    @staticmethod
    def _iter_slices(x, slice_size: int, dim: int):
        """Yield slices of the representation(s) ``x`` of (at most) ``slice_size`` entries along axis ``dim``.

        :param x:
            A single tensor, or a sequence of tensors, in canonical shape.
        :param slice_size:
            The maximum number of entries per slice.
        :param dim:
            The axis to split along.

        :yields: A tensor if ``x`` holds a single representation, otherwise a tuple with one slice per tensor.
        """
        chunked = [xx.split(slice_size, dim=dim) for xx in ensure_tuple(x)[0]]
        for batch in zip(*chunked):
            yield batch[0] if len(batch) == 1 else batch

    def _forward_slicing_wrapper(
        self,
        h: Union[torch.FloatTensor, Tuple[torch.FloatTensor, ...]],
        r: Union[torch.FloatTensor, Tuple[torch.FloatTensor, ...]],
        t: Union[torch.FloatTensor, Tuple[torch.FloatTensor, ...]],
        slice_size: Optional[int],
        slice_dim: Optional[str],
    ) -> torch.FloatTensor:
        """Compute broadcasted triple scores with optional slicing for representations in canonical shape.

        .. note ::
            Depending on the interaction function, there may be more than one representation for h/r/t. In that
            case, a tuple of at least two tensors is passed.

        :param h: shape: (batch_size, num_heads, 1, 1, ``*``)
            The head representations.
        :param r: shape: (batch_size, 1, num_relations, 1, ``*``)
            The relation representations.
        :param t: shape: (batch_size, 1, 1, num_tails, ``*``)
            The tail representations.
        :param slice_size:
            The slice size.
        :param slice_dim:
            The dimension along which to slice. From {"h", "r", "t"}

        :return: shape: (batch_size, num_heads, num_relations, num_tails)
            The scores.

        :raises ValueError:
            If slice_dim is invalid.
        """
        if slice_size is None:
            return self(h=h, r=r, t=t)

        if slice_dim not in ("h", "r", "t"):
            raise ValueError(f'Invalid slice_dim: {slice_dim}')

        # Split and concatenate along the *same* axis: the canonical axis of the sliced representation.
        # Splitting along a fixed axis 1 would only ever slice heads, since r/t have size one there.
        dim = CANONICAL_DIMENSIONS[slice_dim]
        representations = dict(h=h, r=r, t=t)
        full = representations[slice_dim]
        partial_scores = []
        for batch in self._iter_slices(full, slice_size=slice_size, dim=dim):
            representations[slice_dim] = batch
            partial_scores.append(self(**representations))
        return torch.cat(partial_scores, dim=dim)

    def score_hrt(
        self,
        h: HeadRepresentation,
        r: RelationRepresentation,
        t: TailRepresentation,
    ) -> torch.FloatTensor:
        """Score a batch of triples.

        :param h: shape: (batch_size, d_e)
            The head representations.
        :param r: shape: (batch_size, d_r)
            The relation representations.
        :param t: shape: (batch_size, d_e)
            The tail representations.

        :return: shape: (batch_size, 1)
            The scores.
        """
        # indexing with None re-adds the trailing dimension of size one
        return self._score(h=h, r=r, t=t)[:, 0, 0, 0, None]

    def score_h(
        self,
        all_entities: HeadRepresentation,
        r: RelationRepresentation,
        t: TailRepresentation,
        slice_size: Optional[int] = None,
    ) -> torch.FloatTensor:
        """Score all head entities.

        :param all_entities: shape: (num_entities, d_e)
            The head representations.
        :param r: shape: (batch_size, d_r)
            The relation representations.
        :param t: shape: (batch_size, d_e)
            The tail representations.
        :param slice_size:
            The slice size.

        :return: shape: (batch_size, num_entities)
            The scores.
        """
        return self._score(h=all_entities, r=r, t=t, slice_dim="h", slice_size=slice_size)[:, :, 0, 0]

    def score_r(
        self,
        h: HeadRepresentation,
        all_relations: RelationRepresentation,
        t: TailRepresentation,
        slice_size: Optional[int] = None,
    ) -> torch.FloatTensor:
        """Score all relations.

        :param h: shape: (batch_size, d_e)
            The head representations.
        :param all_relations: shape: (num_relations, d_r)
            The relation representations.
        :param t: shape: (batch_size, d_e)
            The tail representations.
        :param slice_size:
            The slice size.

        :return: shape: (batch_size, num_relations)
            The scores.
        """
        return self._score(h=h, r=all_relations, t=t, slice_dim="r", slice_size=slice_size)[:, 0, :, 0]

    def score_t(
        self,
        h: HeadRepresentation,
        r: RelationRepresentation,
        all_entities: TailRepresentation,
        slice_size: Optional[int] = None,
    ) -> torch.FloatTensor:
        """Score all tail entities.

        :param h: shape: (batch_size, d_e)
            The head representations.
        :param r: shape: (batch_size, d_r)
            The relation representations.
        :param all_entities: shape: (num_entities, d_e)
            The tail representations.
        :param slice_size:
            The slice size.

        :return: shape: (batch_size, num_entities)
            The scores.
        """
        return self._score(h=h, r=r, t=all_entities, slice_dim="t", slice_size=slice_size)[:, 0, 0, :]

    def reset_parameters(self):
        """Reset parameters the interaction function may have."""
        for mod in self.modules():
            # ``modules()`` also yields the module itself; skip it to avoid infinite recursion
            if mod is self:
                continue
            if hasattr(mod, 'reset_parameters'):
                mod.reset_parameters()
class FunctionalInteraction(Interaction, Generic[HeadRepresentation, RelationRepresentation, TailRepresentation]):
    """Base class for interaction modules which delegate the computation to a function stored in :attr:`func`.

    The keyword arguments for :attr:`func` are assembled from the h/r/t representations
    (:meth:`_prepare_hrt_for_functional`) and from the module's own state (:meth:`_prepare_state_for_functional`).
    """

    #: The functional interaction form
    func: Callable[..., torch.FloatTensor]

    def forward(
        self,
        h: HeadRepresentation,
        r: RelationRepresentation,
        t: TailRepresentation,
    ) -> torch.FloatTensor:
        """Compute broadcasted triple scores given broadcasted representations for head, relation and tails.

        :param h: shape: (batch_size, num_heads, 1, 1, ``*``)
            The head representations.
        :param r: shape: (batch_size, 1, num_relations, 1, ``*``)
            The relation representations.
        :param t: shape: (batch_size, 1, 1, num_tails, ``*``)
            The tail representations.

        :return: shape: (batch_size, num_heads, num_relations, num_tails)
            The scores.
        """
        functional_kwargs = self._prepare_for_functional(h=h, r=r, t=t)
        # look the function up on the class, so that it is called without an implicit ``self``
        return self.__class__.func(**functional_kwargs)

    def _prepare_for_functional(
        self,
        h: HeadRepresentation,
        r: RelationRepresentation,
        t: TailRepresentation,
    ) -> Mapping[str, torch.FloatTensor]:
        """Assemble all keyword arguments for the functional form: representations first, then module state."""
        merged = self._prepare_hrt_for_functional(h=h, r=r, t=t)
        state = self._prepare_state_for_functional()
        merged.update(state)
        return merged

    @staticmethod
    def _prepare_hrt_for_functional(
        h: HeadRepresentation,
        r: RelationRepresentation,
        t: TailRepresentation,
    ) -> MutableMapping[str, torch.FloatTensor]:
        """Map the h/r/t representations to keyword arguments; by default each has to be a single tensor."""
        assert torch.is_tensor(h) and torch.is_tensor(r) and torch.is_tensor(t)
        return {"h": h, "r": r, "t": t}

    def _prepare_state_for_functional(self) -> MutableMapping[str, Any]:
        """Map the module's state to keyword arguments; by default there is none."""
        return {}
class TranslationalInteraction(
    FunctionalInteraction,
    Generic[HeadRepresentation, RelationRepresentation, TailRepresentation],
    ABC,
):
    """The translational interaction function shared by the TransE, TransR, TransH, and other Trans<X> models."""

    def __init__(self, p: int, power_norm: bool = False):
        """Store the norm configuration.

        :param p:
            The norm used with :func:`torch.norm`. Typically is 1 or 2.
        :param power_norm:
            Whether to use the p-th power of the L_p norm. It has the advantage of being differentiable around 0,
            and numerically more stable.
        """
        super().__init__()
        self.power_norm = power_norm
        self.p = p

    def _prepare_state_for_functional(self) -> MutableMapping[str, Any]:  # noqa: D102
        # the norm configuration is the only state which is forwarded to the functional form
        return {"p": self.p, "power_norm": self.power_norm}
class TransEInteraction(TranslationalInteraction[FloatTensor, FloatTensor, FloatTensor]):
    """The TransE interaction function as a stateful module.

    Head, relation and tail are each given as a single tensor.

    .. seealso:: :func:`pykeen.nn.functional.transe_interaction`
    """

    func = pkf.transe_interaction
class ComplExInteraction(FunctionalInteraction[FloatTensor, FloatTensor, FloatTensor]):
    """The stateless ComplEx interaction function, wrapped as a module.

    Head, relation and tail are each given as a single tensor; the module holds no state of its own.

    .. seealso:: :func:`pykeen.nn.functional.complex_interaction`
    """

    func = pkf.complex_interaction
def _calculate_missing_shape_information( embedding_dim: int, input_channels: Optional[int] = None, width: Optional[int] = None, height: Optional[int] = None, ) -> Tuple[int, int, int]: """Automatically calculates missing dimensions for ConvE. :param embedding_dim: The embedding dimension. :param input_channels: The number of input channels for the convolution. :param width: The width of the embedding "image". :param height: The height of the embedding "image". :return: (input_channels, width, height), such that `embedding_dim = input_channels * width * height` :raises ValueError: If no factorization could be found. """ # Store initial input for error message original = (input_channels, width, height) # All are None -> try and make closest to square if input_channels is None and width is None and height is None: input_channels = 1 result_sqrt = math.floor(math.sqrt(embedding_dim)) height = max(factor for factor in range(1, result_sqrt + 1) if embedding_dim % factor == 0) width = embedding_dim // height # Only input channels is None elif input_channels is None and width is not None and height is not None: input_channels = embedding_dim // (width * height) # Only width is None elif input_channels is not None and width is None and height is not None: width = embedding_dim // (height * input_channels) # Only height is none elif height is None and width is not None and input_channels is not None: height = embedding_dim // (width * input_channels) # Width and input_channels are None -> set input_channels to 1 and calculage height elif input_channels is None and height is None and width is not None: input_channels = 1 height = embedding_dim // width # Width and input channels are None -> set input channels to 1 and calculate width elif input_channels is None and height is not None and width is None: input_channels = 1 width = embedding_dim // height if input_channels * width * height != embedding_dim: # type: ignore raise ValueError(f'Could not resolve {original} to a 
valid factorization of {embedding_dim}.') return input_channels, width, height # type: ignore
class ConvEInteraction(
    FunctionalInteraction[torch.FloatTensor, torch.FloatTensor, Tuple[torch.FloatTensor, torch.FloatTensor]],
):
    """A stateful module for the ConvE interaction function.

    The module owns the two encoders applied to the head-relation combination (:attr:`hr2d` and :attr:`hr1d`)
    and the dimensions used to reshape embeddings into 2D "images". The tail is represented by two tensors;
    the second one is passed to the functional form as ``t_bias``.

    .. seealso:: :func:`pykeen.nn.functional.conve_interaction`
    """

    tail_entity_shape = ("d", "k")  # with k=1

    #: The head-relation encoder operating on 2D "images"
    hr2d: nn.Module

    #: The head-relation encoder operating on the 1D flattened version
    hr1d: nn.Module

    #: The interaction function
    func = pkf.conve_interaction

    def __init__(
        self,
        input_channels: Optional[int] = None,
        output_channels: int = 32,
        embedding_height: Optional[int] = None,
        embedding_width: Optional[int] = None,
        kernel_height: int = 3,
        kernel_width: int = 3,
        input_dropout: float = 0.2,
        output_dropout: float = 0.3,
        feature_map_dropout: float = 0.2,
        embedding_dim: Optional[int] = 200,
        apply_batch_normalization: bool = True,
    ):
        """Initialize the interaction function.

        The parameters need to fulfil ``input_channels * embedding_height * embedding_width = embedding_dim``.
        Missing values among them are inferred automatically where possible.

        :param input_channels:
            The number of input channels of the convolution.
        :param output_channels:
            The number of output channels of the convolution.
        :param embedding_height:
            The height of the embedding "image".
        :param embedding_width:
            The width of the embedding "image".
        :param kernel_height:
            The height of the convolution kernel.
        :param kernel_width:
            The width of the convolution kernel.
        :param input_dropout:
            The dropout rate applied in the 2D encoder before the convolution.
        :param output_dropout:
            The dropout rate applied in the 1D encoder after the linear layer.
        :param feature_map_dropout:
            The (2D) dropout rate applied to the feature maps after the convolution.
        :param embedding_dim:
            The embedding dimension. If None, it is computed from ``input_channels``, ``embedding_width`` and
            ``embedding_height``, which all have to be given in that case.
        :param apply_batch_normalization:
            Whether to add batch normalization layers to both encoders.

        :raises ValueError:
            If the dimensions cannot be resolved to a valid factorization of the embedding dimension.
        """
        super().__init__()

        # Automatic calculation of remaining dimensions
        logger.info(
            'Resolving %s * %s * %s = %s.', input_channels, embedding_width, embedding_height, embedding_dim,
        )
        if embedding_dim is None:
            if input_channels is None or embedding_width is None or embedding_height is None:
                raise ValueError(
                    'If embedding_dim is None, input_channels, embedding_width and embedding_height are required.',
                )
            embedding_dim = input_channels * embedding_width * embedding_height

        # Parameter need to fulfil:
        #   input_channels * embedding_height * embedding_width = embedding_dim
        input_channels, embedding_width, embedding_height = _calculate_missing_shape_information(
            embedding_dim=embedding_dim,
            input_channels=input_channels,
            width=embedding_width,
            height=embedding_height,
        )
        logger.info(
            'Resolved to %s * %s * %s = %s.', input_channels, embedding_width, embedding_height, embedding_dim,
        )

        if input_channels * embedding_height * embedding_width != embedding_dim:
            raise ValueError(
                f'Product of input channels ({input_channels}), height ({embedding_height}), and width '
                f'({embedding_width}) does not equal target embedding dimension ({embedding_dim})',
            )

        # encoders
        # 1: 2D encoder: BN?, DO, Conv, BN?, Act, DO
        hr2d_layers = [
            nn.BatchNorm2d(input_channels) if apply_batch_normalization else None,
            nn.Dropout(input_dropout),
            nn.Conv2d(
                in_channels=input_channels,
                out_channels=output_channels,
                kernel_size=(kernel_height, kernel_width),
                stride=1,
                padding=0,
                bias=True,
            ),
            nn.BatchNorm2d(output_channels) if apply_batch_normalization else None,
            nn.ReLU(),
            nn.Dropout2d(feature_map_dropout),
        ]
        self.hr2d = nn.Sequential(*(layer for layer in hr2d_layers if layer is not None))

        # 2: 1D encoder: FC, DO, BN?, Act
        # The convolution input is twice as high as a single embedding "image".
        # NOTE(review): presumably head and relation are stacked along the height in the functional form.
        num_in_features = (
            output_channels
            * (2 * embedding_height - kernel_height + 1)
            * (embedding_width - kernel_width + 1)
        )
        hr1d_layers = [
            nn.Linear(num_in_features, embedding_dim),
            nn.Dropout(output_dropout),
            nn.BatchNorm1d(embedding_dim) if apply_batch_normalization else None,
            nn.ReLU(),
        ]
        self.hr1d = nn.Sequential(*(layer for layer in hr1d_layers if layer is not None))

        # store reshaping dimensions
        self.embedding_height = embedding_height
        self.embedding_width = embedding_width
        self.input_channels = input_channels

    @staticmethod
    def _prepare_hrt_for_functional(
        h: HeadRepresentation,
        r: RelationRepresentation,
        t: TailRepresentation,
    ) -> MutableMapping[str, torch.FloatTensor]:  # noqa: D102
        # the tail consists of two representations: the embedding, and the bias
        return dict(h=h, r=r, t=t[0], t_bias=t[1])

    def _prepare_state_for_functional(self) -> MutableMapping[str, Any]:  # noqa: D102
        return dict(
            input_channels=self.input_channels,
            embedding_height=self.embedding_height,
            embedding_width=self.embedding_width,
            hr2d=self.hr2d,
            hr1d=self.hr1d,
        )
class ConvKBInteraction(FunctionalInteraction[FloatTensor, FloatTensor, FloatTensor]):
    """The ConvKB interaction function as a stateful module.

    The module owns the convolution, the activation, the dropout and the final linear layer, all of which are
    handed to the functional form.

    .. seealso:: :func:`pykeen.nn.functional.convkb_interaction`
    """

    func = pkf.convkb_interaction

    def __init__(
        self,
        hidden_dropout_rate: float = 0.,
        embedding_dim: int = 200,
        num_filters: int = 400,
    ):
        """Create the layers of the interaction model.

        :param hidden_dropout_rate:
            The dropout rate of the dropout layer.
        :param embedding_dim:
            The embedding dimension.
        :param num_filters:
            The number of convolution filters, i.e. output channels.
        """
        super().__init__()
        self.embedding_dim = embedding_dim
        self.num_filters = num_filters

        # The interaction model
        self.conv = nn.Conv2d(in_channels=1, out_channels=num_filters, kernel_size=(1, 3), bias=True)
        self.activation = nn.ReLU()
        self.hidden_dropout = nn.Dropout(p=hidden_dropout_rate)
        self.linear = nn.Linear(embedding_dim * num_filters, 1, bias=True)

    def reset_parameters(self):  # noqa: D102
        # Linear layer: Xavier initialization (with the gain for ReLU) for the weight, zero for the bias
        relu_gain = nn.init.calculate_gain('relu')
        nn.init.xavier_uniform_(self.linear.weight, gain=relu_gain)
        nn.init.zeros_(self.linear.bias)

        # Convolution: every filter starts as [0.1, 0.1, -0.1], with a zero bias
        filters = self.conv.weight
        nn.init.constant_(filters[..., :2], 0.1)
        nn.init.constant_(filters[..., 2], -0.1)
        nn.init.zeros_(self.conv.bias)

    def _prepare_state_for_functional(self) -> MutableMapping[str, Any]:  # noqa: D102
        return {
            "conv": self.conv,
            "activation": self.activation,
            "hidden_dropout": self.hidden_dropout,
            "linear": self.linear,
        }
class DistMultInteraction(FunctionalInteraction[FloatTensor, FloatTensor, FloatTensor]):
    """The stateless DistMult interaction function, wrapped as a module.

    Head, relation and tail are each given as a single tensor; the module holds no state of its own.

    .. seealso:: :func:`pykeen.nn.functional.distmult_interaction`
    """

    func = pkf.distmult_interaction
class ERMLPInteraction(FunctionalInteraction[FloatTensor, FloatTensor, FloatTensor]):
    """The ER-MLP interaction as a stateful module.

    .. seealso:: :func:`pykeen.nn.functional.ermlp_interaction`

    .. math ::

        f(h, r, t) = W_2 ReLU(W_1 cat(h, r, t) + b_1) + b_2
    """

    func = pkf.ermlp_interaction

    def __init__(
        self,
        embedding_dim: int,
        hidden_dim: int,
    ):
        """Create the layers of the multi-layer perceptron.

        :param embedding_dim:
            The embedding vector dimension.
        :param hidden_dim:
            The hidden dimension of the MLP.
        """
        super().__init__()
        # The multi-layer perceptron consists of an input layer with 3 * embedding_dim neurons (the concatenated
        # head, relation and tail embeddings), a hidden layer with hidden_dim neurons, and an output layer with
        # a single neuron.
        self.hidden = nn.Linear(in_features=3 * embedding_dim, out_features=hidden_dim, bias=True)
        self.activation = nn.ReLU()
        self.hidden_to_score = nn.Linear(in_features=hidden_dim, out_features=1, bias=True)

    def _prepare_state_for_functional(self) -> MutableMapping[str, Any]:  # noqa: D102
        # the output layer is called ``final`` by the functional form
        return {
            "hidden": self.hidden,
            "activation": self.activation,
            "final": self.hidden_to_score,
        }

    def reset_parameters(self):  # noqa: D102
        # Initialize biases with zero
        for layer in (self.hidden, self.hidden_to_score):
            nn.init.zeros_(layer.bias)

        # Xavier initialization for both weights; the output layer uses the gain of the activation function,
        # whose name is derived from the activation's lower-cased class name
        nn.init.xavier_uniform_(self.hidden.weight)
        activation_name = type(self.activation).__name__.lower()
        nn.init.xavier_uniform_(self.hidden_to_score.weight, gain=nn.init.calculate_gain(activation_name))
class ERMLPEInteraction(FunctionalInteraction[FloatTensor, FloatTensor, FloatTensor]):
    """The ER-MLP (E) interaction function as a stateful module.

    The module owns a two-layer perceptron, which is handed to the functional form as ``mlp``.

    .. seealso:: :func:`pykeen.nn.functional.ermlpe_interaction`
    """

    func = pkf.ermlpe_interaction

    def __init__(
        self,
        hidden_dim: int = 300,
        input_dropout: float = 0.2,
        hidden_dropout: float = 0.3,
        embedding_dim: int = 200,
    ):
        """Create the multi-layer perceptron.

        :param hidden_dim:
            The number of hidden units.
        :param input_dropout:
            The dropout rate applied to the input of the MLP.
        :param hidden_dropout:
            The dropout rate applied after each of the two linear layers.
        :param embedding_dim:
            The embedding dimension; the MLP maps from ``2 * embedding_dim`` to ``embedding_dim``.
        """
        super().__init__()
        layers = [
            nn.Dropout(input_dropout),
            # first block: Linear, DO, BN, Act
            nn.Linear(2 * embedding_dim, hidden_dim),
            nn.Dropout(hidden_dropout),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(),
            # second block: Linear, DO, BN, Act
            nn.Linear(hidden_dim, embedding_dim),
            nn.Dropout(hidden_dropout),
            nn.BatchNorm1d(embedding_dim),
            nn.ReLU(),
        ]
        self.mlp = nn.Sequential(*layers)

    def _prepare_state_for_functional(self) -> MutableMapping[str, Any]:  # noqa: D102
        return {"mlp": self.mlp}
class TransRInteraction(
    TranslationalInteraction[
        torch.FloatTensor,
        Tuple[torch.FloatTensor, torch.FloatTensor],
        torch.FloatTensor,
    ],
):
    """The TransR interaction function as a stateful module.

    The relation is represented by two tensors; the second one is passed to the functional form as ``m_r``.

    .. seealso:: :func:`pykeen.nn.functional.transr_interaction`
    """

    relation_shape = ("e", "de")
    func = pkf.transr_interaction

    def __init__(self, p: int, power_norm: bool = True):
        """Store the norm configuration; in contrast to the base class, ``power_norm`` defaults to True.

        :param p:
            The norm to use.
        :param power_norm:
            Whether to use the p-th power of the norm.
        """
        super().__init__(p=p, power_norm=power_norm)

    @staticmethod
    def _prepare_hrt_for_functional(
        h: HeadRepresentation,
        r: RelationRepresentation,
        t: TailRepresentation,
    ) -> MutableMapping[str, torch.FloatTensor]:  # noqa: D102
        return {"h": h, "r": r[0], "t": t, "m_r": r[1]}
class RotatEInteraction(FunctionalInteraction[FloatTensor, FloatTensor, FloatTensor]):
    """The stateless RotatE interaction function, wrapped as a module.

    Head, relation and tail are each given as a single tensor; the module holds no state of its own.

    .. seealso:: :func:`pykeen.nn.functional.rotate_interaction`
    """

    func = pkf.rotate_interaction
class HolEInteraction(FunctionalInteraction[FloatTensor, FloatTensor, FloatTensor]):
    """The stateless HolE interaction function, wrapped as a module.

    Head, relation and tail are each given as a single tensor; the module holds no state of its own.

    .. seealso:: :func:`pykeen.nn.functional.hole_interaction`
    """

    func = pkf.hole_interaction
class ProjEInteraction(FunctionalInteraction[FloatTensor, FloatTensor, FloatTensor]):
    """The ProjE interaction function as a stateful module.

    The module owns four global parameters (``d_e``, ``d_r``, ``b_c``, ``b_p``) and the inner non-linearity.

    .. seealso:: :func:`pykeen.nn.functional.proje_interaction`
    """

    func = pkf.proje_interaction

    def __init__(
        self,
        embedding_dim: int = 50,
        inner_non_linearity: Optional[nn.Module] = None,
    ):
        """Create the (uninitialized) global parameters.

        :param embedding_dim:
            The embedding dimension.
        :param inner_non_linearity:
            The inner non-linearity. Defaults to :class:`torch.nn.Tanh`.
        """
        super().__init__()

        # Global entity projection
        self.d_e = nn.Parameter(torch.empty(embedding_dim), requires_grad=True)

        # Global relation projection
        self.d_r = nn.Parameter(torch.empty(embedding_dim), requires_grad=True)

        # Global combination bias
        self.b_c = nn.Parameter(torch.empty(embedding_dim), requires_grad=True)

        # Second global bias, a single scalar
        self.b_p = nn.Parameter(torch.empty(1), requires_grad=True)

        self.inner_non_linearity = nn.Tanh() if inner_non_linearity is None else inner_non_linearity

    def reset_parameters(self):  # noqa: D102
        # all parameters are drawn uniformly from [-sqrt(6) / d, sqrt(6) / d]
        dim = self.d_e.size(0)
        limit = math.sqrt(6) / dim
        for parameter in self.parameters():
            nn.init.uniform_(parameter, a=-limit, b=limit)

    def _prepare_state_for_functional(self) -> MutableMapping[str, Any]:  # noqa: D102
        return {
            "d_e": self.d_e,
            "d_r": self.d_r,
            "b_c": self.b_c,
            "b_p": self.b_p,
            "activation": self.inner_non_linearity,
        }
class RESCALInteraction(FunctionalInteraction[FloatTensor, FloatTensor, FloatTensor]):
    """The stateless RESCAL interaction function, wrapped as a module.

    In contrast to the default, the relation representation has the symbolic shape ``"dd"``.

    .. seealso:: :func:`pykeen.nn.functional.rescal_interaction`
    """

    func = pkf.rescal_interaction
    relation_shape = ("dd",)
class StructuredEmbeddingInteraction(
    TranslationalInteraction[
        torch.FloatTensor,
        Tuple[torch.FloatTensor, torch.FloatTensor],
        torch.FloatTensor,
    ],
):
    """The Structured Embedding (SE) interaction function as a stateful module.

    The relation is represented by two tensors, which are passed to the functional form as ``r_h`` and ``r_t``.

    .. seealso:: :func:`pykeen.nn.functional.structured_embedding_interaction`
    """

    func = pkf.structured_embedding_interaction
    relation_shape = ("dd", "dd")

    @staticmethod
    def _prepare_hrt_for_functional(
        h: HeadRepresentation,
        r: RelationRepresentation,
        t: TailRepresentation,
    ) -> MutableMapping[str, torch.FloatTensor]:  # noqa: D102
        return {"h": h, "t": t, "r_h": r[0], "r_t": r[1]}
class TuckerInteraction(FunctionalInteraction[FloatTensor, FloatTensor, FloatTensor]):
    """The Tucker interaction function as a stateful module.

    The module owns the core tensor, the dropout layers, and the optional batch normalization layers.

    .. seealso:: :func:`pykeen.nn.functional.tucker_interaction`
    """

    func = pkf.tucker_interaction

    def __init__(
        self,
        embedding_dim: int = 200,
        relation_dim: Optional[int] = None,
        head_dropout: float = 0.3,
        relation_dropout: float = 0.4,
        head_relation_dropout: float = 0.5,
        apply_batch_normalization: bool = True,
    ):
        """Initialize the Tucker interaction function.

        :param embedding_dim:
            The entity embedding dimension.
        :param relation_dim:
            The relation embedding dimension. Defaults to the entity embedding dimension.
        :param head_dropout:
            The dropout rate applied to the head representations.
        :param relation_dropout:
            The dropout rate applied to the relation representations.
        :param head_relation_dropout:
            The dropout rate applied to the combined head and relation representations.
        :param apply_batch_normalization:
            Whether to use batch normalization on head representations and the combination of head and relation.
        """
        super().__init__()

        relation_dim = embedding_dim if relation_dim is None else relation_dim

        # Core tensor
        # Note: we use a different dimension permutation as in the official implementation to match the paper.
        core_shape = (embedding_dim, relation_dim, embedding_dim)
        self.core_tensor = nn.Parameter(torch.empty(*core_shape), requires_grad=True)

        # Dropout
        self.head_dropout = nn.Dropout(head_dropout)
        self.relation_dropout = nn.Dropout(relation_dropout)
        self.head_relation_dropout = nn.Dropout(head_relation_dropout)

        # Batch normalization (optional)
        if not apply_batch_normalization:
            self.head_batch_norm = None
            self.head_relation_batch_norm = None
        else:
            self.head_batch_norm = nn.BatchNorm1d(embedding_dim)
            self.head_relation_batch_norm = nn.BatchNorm1d(embedding_dim)

    def reset_parameters(self):  # noqa:D102
        # Initialize core tensor uniformly in [-1, 1]
        nn.init.uniform_(self.core_tensor, -1., 1.)
        # Original note: batch norm gets reset automatically, since it defines reset_parameters.
        # NOTE(review): this override does not reset the batch norm layers itself - confirm which caller does.

    def _prepare_state_for_functional(self) -> MutableMapping[str, Any]:  # noqa: D102
        return {
            "core_tensor": self.core_tensor,
            "do_h": self.head_dropout,
            "do_r": self.relation_dropout,
            "do_hr": self.head_relation_dropout,
            "bn_h": self.head_batch_norm,
            "bn_hr": self.head_relation_batch_norm,
        }
class UnstructuredModelInteraction(
    TranslationalInteraction[torch.FloatTensor, None, torch.FloatTensor],
):
    """The UnstructuredModel interaction function as a stateful module.

    There is no relation representation; only head and tail are passed to the functional form.

    .. seealso:: :func:`pykeen.nn.functional.unstructured_model_interaction`
    """

    # shapes
    relation_shape: Sequence[str] = ()

    func = pkf.unstructured_model_interaction

    def __init__(self, p: int, power_norm: bool = True):
        """Store the norm configuration; in contrast to the base class, ``power_norm`` defaults to True.

        :param p:
            The norm to use.
        :param power_norm:
            Whether to use the p-th power of the norm.
        """
        super().__init__(p=p, power_norm=power_norm)

    @staticmethod
    def _prepare_hrt_for_functional(
        h: HeadRepresentation,
        r: RelationRepresentation,
        t: TailRepresentation,
    ) -> MutableMapping[str, torch.FloatTensor]:  # noqa: D102
        # the relation is deliberately dropped
        return {"h": h, "t": t}
class TransDInteraction(
    TranslationalInteraction[
        Tuple[torch.FloatTensor, torch.FloatTensor],
        Tuple[torch.FloatTensor, torch.FloatTensor],
        Tuple[torch.FloatTensor, torch.FloatTensor],
    ],
):
    """The TransD interaction function as a stateful module.

    Head, relation and tail are each represented by a pair of tensors; the second tensor of each pair is passed
    to the functional form as ``h_p`` / ``r_p`` / ``t_p``.

    .. seealso:: :func:`pykeen.nn.functional.transd_interaction`
    """

    entity_shape = ("d", "d")
    relation_shape = ("e", "e")
    func = pkf.transd_interaction

    def __init__(self, p: int = 2, power_norm: bool = True):
        """Store the norm configuration.

        :param p:
            The norm to use.
        :param power_norm:
            Whether to use the p-th power of the norm.
        """
        super().__init__(p=p, power_norm=power_norm)

    @staticmethod
    def _prepare_hrt_for_functional(
        h: Tuple[torch.FloatTensor, torch.FloatTensor],
        r: Tuple[torch.FloatTensor, torch.FloatTensor],
        t: Tuple[torch.FloatTensor, torch.FloatTensor],
    ) -> MutableMapping[str, torch.FloatTensor]:  # noqa: D102
        head, head_projection = h
        relation, relation_projection = r
        tail, tail_projection = t
        return {
            "h": head,
            "r": relation,
            "t": tail,
            "h_p": head_projection,
            "r_p": relation_projection,
            "t_p": tail_projection,
        }
class NTNInteraction(
    FunctionalInteraction[
        torch.FloatTensor,
        Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor],
        torch.FloatTensor,
    ],
):
    """The NTN interaction function as a stateful module.

    The relation is represented by five tensors, in the order ``(w, vh, vt, b, u)``. The module itself only
    holds the non-linearity.

    .. seealso:: :func:`pykeen.nn.functional.ntn_interaction`
    """

    relation_shape = ("kdd", "kd", "kd", "k", "k")
    func = pkf.ntn_interaction

    def __init__(self, non_linearity: Optional[nn.Module] = None):
        """Store the non-linearity.

        :param non_linearity:
            The non-linearity. Defaults to :class:`torch.nn.Tanh`.
        """
        super().__init__()
        self.non_linearity = nn.Tanh() if non_linearity is None else non_linearity

    @staticmethod
    def _prepare_hrt_for_functional(
        h: torch.FloatTensor,
        r: Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor],
        t: torch.FloatTensor,
    ) -> MutableMapping[str, torch.FloatTensor]:  # noqa: D102
        # the order of the relation representations is given by relation_shape
        w, vh, vt, b, u = r
        return {"h": h, "t": t, "w": w, "b": b, "u": u, "vh": vh, "vt": vt}

    def _prepare_state_for_functional(self) -> MutableMapping[str, Any]:  # noqa: D102
        return {"activation": self.non_linearity}
class KG2EInteraction(
    FunctionalInteraction[
        Tuple[torch.FloatTensor, torch.FloatTensor],
        Tuple[torch.FloatTensor, torch.FloatTensor],
        Tuple[torch.FloatTensor, torch.FloatTensor],
    ],
):
    """The KG2E interaction function as a stateful module.

    Head, relation and tail are each represented by a pair of tensors, passed to the functional form as
    ``*_mean`` and ``*_var``. The module stores the similarity name and the ``exact`` flag.

    .. seealso:: :func:`pykeen.nn.functional.kg2e_interaction`
    """

    entity_shape = ("d", "d")
    relation_shape = ("d", "d")
    similarity: str
    exact: bool
    func = pkf.kg2e_interaction

    def __init__(self, similarity: Optional[str] = None, exact: bool = True):
        """Store the configuration.

        :param similarity:
            The name of the similarity, passed on to the functional form. Defaults to ``'KL'``.
        :param exact:
            A flag which is passed on to the functional form unchanged.
        """
        super().__init__()
        self.similarity = 'KL' if similarity is None else similarity
        self.exact = exact

    @staticmethod
    def _prepare_hrt_for_functional(
        h: Tuple[torch.FloatTensor, torch.FloatTensor],
        r: Tuple[torch.FloatTensor, torch.FloatTensor],
        t: Tuple[torch.FloatTensor, torch.FloatTensor],
    ) -> MutableMapping[str, torch.FloatTensor]:  # noqa: D102
        (h_mean, h_var), (r_mean, r_var), (t_mean, t_var) = h, r, t
        return {
            "h_mean": h_mean,
            "h_var": h_var,
            "r_mean": r_mean,
            "r_var": r_var,
            "t_mean": t_mean,
            "t_var": t_var,
        }

    def _prepare_state_for_functional(self) -> MutableMapping[str, Any]:  # noqa: D102
        return {"similarity": self.similarity, "exact": self.exact}
class TransHInteraction(TranslationalInteraction[FloatTensor, Tuple[FloatTensor, FloatTensor], FloatTensor]):
    """The TransH interaction function as a stateful module.

    The relation is represented by two tensors, which are passed to the functional form as ``w_r`` and ``d_r``.

    .. seealso:: :func:`pykeen.nn.functional.transh_interaction`
    """

    func = pkf.transh_interaction
    relation_shape = ("d", "d")

    @staticmethod
    def _prepare_hrt_for_functional(
        h: HeadRepresentation,
        r: RelationRepresentation,
        t: TailRepresentation,
    ) -> MutableMapping[str, torch.FloatTensor]:  # noqa: D102
        return {"h": h, "w_r": r[0], "d_r": r[1], "t": t}
class SimplEInteraction(
    FunctionalInteraction[
        Tuple[torch.FloatTensor, torch.FloatTensor],
        Tuple[torch.FloatTensor, torch.FloatTensor],
        Tuple[torch.FloatTensor, torch.FloatTensor],
    ],
):
    """A module wrapper for the SimplE interaction function.

    Head, relation and tail are each represented by a pair of tensors; the second tensor of each pair is passed
    to the functional form as ``h_inv`` / ``r_inv`` / ``t_inv``.

    .. seealso:: :func:`pykeen.nn.functional.simple_interaction`
    """

    func = pkf.simple_interaction
    entity_shape = ("d", "d")
    relation_shape = ("d", "d")

    def __init__(self, clamp_score: Union[None, float, Tuple[float, float]] = None):
        """Store the clamping configuration.

        :param clamp_score:
            Either None, a pair ``(low, high)``, or a single number ``c``, which is expanded to the symmetric
            pair ``(-c, c)``. The result is passed to the functional form as ``clamp``.
        """
        super().__init__()
        # Accept integers as well as floats (e.g., clamp_score=20), so that a scalar is always expanded to a
        # pair. bool is a subclass of int, but not a meaningful bound, hence it is left untouched.
        if isinstance(clamp_score, (int, float)) and not isinstance(clamp_score, bool):
            clamp_score = (-clamp_score, clamp_score)
        self.clamp_score = clamp_score

    def _prepare_state_for_functional(self) -> MutableMapping[str, Any]:  # noqa: D102
        return dict(clamp=self.clamp_score)

    @staticmethod
    def _prepare_hrt_for_functional(
        h: HeadRepresentation,
        r: RelationRepresentation,
        t: TailRepresentation,
    ) -> MutableMapping[str, torch.FloatTensor]:  # noqa: D102
        return dict(h=h[0], h_inv=h[1], r=r[0], r_inv=r[1], t=t[0], t_inv=t[1])