"""Implementation of combinations for the :class:`pykeen.models.LiteralModel`."""
import logging
from abc import ABC, abstractmethod
from collections.abc import Callable, Iterable, Mapping, Sequence
from typing import Any
import torch
from class_resolver import (
ClassResolver,
Hint,
HintOrType,
OptionalKwargs,
ResolverKey,
update_docstring_with_resolver_keys,
)
from class_resolver.contrib.torch import activation_resolver, aggregation_resolver
from torch import nn
from ..typing import FloatTensor
from ..utils import ExtraReprMixin, combine_complex, split_complex
# Names exported by ``from <module> import *``: the abstract base class, the resolver,
# and the concrete combination classes.
__all__ = [
"Combination",
"combination_resolver",
# Concrete classes
"ComplexSeparatedCombination",
"ConcatCombination",
"ConcatAggregationCombination",
"ConcatProjectionCombination",
"GatedCombination",
]
# Module-level logger named after this module; used for the output-shape fallback warning.
logger = logging.getLogger(__name__)
class Combination(nn.Module, ExtraReprMixin, ABC):
    """Abstract base class for modules which merge several representations into a single one."""

    @abstractmethod
    def forward(self, xs: Sequence[FloatTensor]) -> FloatTensor:
        """
        Combine a sequence of individual representations.

        :param xs: shape: `(*batch_dims, *input_dims_i)`
            the individual representations, one tensor per input

        :return: shape: `(*batch_dims, *output_dims)`
            the combined representation
        """
        raise NotImplementedError

    def output_shape(self, input_shapes: Sequence[tuple[int, ...]]) -> tuple[int, ...]:
        """
        Calculate the output shape for the given input shapes.

        .. note ::
            This default implementation does not compute the shape symbolically. It logs a warning,
            runs a single forward pass on placeholder tensors of the requested shapes, and reports
            the shape of the result. Subclasses may override it.

        :param input_shapes:
            the input shapes without the batch dimensions

        :return:
            the output shape
        """
        logger.warning("No symbolic computation of output shape.")
        # only the shapes matter here, so uninitialized tensors are sufficient
        placeholders = [torch.empty(size=input_shape) for input_shape in input_shapes]
        combined = self(xs=placeholders)
        return combined.shape
class ConcatCombination(Combination):
    """Combine representations by concatenating them along one dimension."""

    def __init__(self, dim: int = -1) -> None:
        """
        Initialize the combination.

        :param dim:
            the dimension along which the inputs are concatenated
        """
        super().__init__()
        self.dim = dim

    # docstr-coverage: inherited
    def forward(self, xs: Sequence[FloatTensor]) -> FloatTensor:  # noqa: D102
        concatenated = torch.cat(xs, dim=self.dim)
        return concatenated
# docstr-coverage: inherited
class ConcatProjectionCombination(ConcatCombination):
    """Combine representations by concatenation followed by a linear projection, dropout and activation."""

    def __init__(
        self,
        input_dims: Sequence[int],
        output_dim: int | None = None,
        bias: bool = True,
        dropout: float = 0.0,
        activation: HintOrType[nn.Module] = nn.Identity,
        activation_kwargs: OptionalKwargs = None,
    ) -> None:
        """
        Initialize the combination.

        :param input_dims:
            the dimensions of the individual inputs; their sum is the input size of the linear projection
        :param output_dim:
            the output dimension. If it is `None` (or otherwise falsy), the first input dimension is used
        :param bias:
            whether the linear projection has a bias term
        :param dropout:
            the dropout rate; dropout is applied after the linear projection and before the activation
        :param activation:
            the activation, or a hint thereof
        :param activation_kwargs:
            additional keyword-based parameters used to instantiate the activation

        :raises ValueError:
            if `input_dims` is empty
        """
        super().__init__()
        if not input_dims:
            raise ValueError("Cannot provide empty input dimensions")
        if not output_dim:
            output_dim = input_dims[0]
        # the layers are created in the order in which they are applied
        linear = nn.Linear(sum(input_dims), output_dim, bias=bias)
        drop = nn.Dropout(dropout)
        act = activation_resolver.make(activation, activation_kwargs)
        self.projection = nn.Sequential(linear, drop, act)

    # docstr-coverage: inherited
    def forward(self, xs: Sequence[FloatTensor]) -> FloatTensor:  # noqa: D102
        # concatenate via the parent class, then project
        concatenated = super().forward(xs)
        return self.projection(concatenated)
class ConcatAggregationCombination(ConcatCombination):
    """Combine representations by concatenation followed by an aggregation along the same axis."""

    @update_docstring_with_resolver_keys(
        ResolverKey("aggregation", resolver="class_resolver.contrib.torch.aggregation_resolver"),
    )
    def __init__(
        self,
        aggregation: Hint[Callable[[FloatTensor], FloatTensor]] = None,
        aggregation_kwargs: OptionalKwargs = None,
        dim: int = -1,
    ) -> None:
        """
        Initialize the combination.

        :param aggregation:
            The aggregation, or a hint thereof.
        :param aggregation_kwargs:
            Additional keyword-based parameters.
        :param dim:
            The dimension along which the inputs are concatenated and subsequently reduced.
        """
        # the parent constructor already stores `dim` as `self.dim`, so it is not assigned again here
        super().__init__(dim=dim)
        self.aggregation = aggregation_resolver.make(aggregation, aggregation_kwargs)

    # docstr-coverage: inherited
    def forward(self, xs: Sequence[FloatTensor]) -> FloatTensor:  # noqa: D102
        # concatenate along `self.dim`, then reduce along that very same dimension
        concatenated = super().forward(xs=xs)
        return self.aggregation(concatenated, dim=self.dim)
# docstr-coverage: inherited
class ComplexSeparatedCombination(Combination):
    """A combination for mixed complex & real representations, which treats real and imaginary parts separately."""

    def __init__(
        self,
        combination: HintOrType[Combination] = None,
        combination_kwargs: OptionalKwargs = None,
        imag_combination: HintOrType[Combination] = None,
        imag_combination_kwargs: OptionalKwargs = None,
    ):
        """
        Initialize the combination.

        .. note ::
            if non-instantiated combinations are passed, separate instances will be created for real and
            imaginary parts

        :param combination:
            the combination for the real part, or a hint thereof
        :param combination_kwargs:
            keyword-based parameters for the real combination
        :param imag_combination:
            the combination for the imaginary part, or a hint thereof. If None, `combination` (together with
            `combination_kwargs`) is used for both parts.
        :param imag_combination_kwargs:
            keyword-based parameters for the imaginary combination; only used if `imag_combination` is not None
        """
        super().__init__()
        # without an explicit imaginary combination, re-use the configuration of the real one
        if imag_combination is None:
            imag_combination = combination
            imag_combination_kwargs = combination_kwargs
        # resolve both parts individually (real part first)
        self.real_combination = combination_resolver.make(combination, combination_kwargs)
        self.imag_combination = combination_resolver.make(imag_combination, imag_combination_kwargs)

    # docstr-coverage: inherited
    def forward(self, xs: Sequence[FloatTensor]) -> FloatTensor:  # noqa: D102
        if all(not x.is_complex() for x in xs):
            raise ValueError(
                f"{self.__class__} is a combination for complex representations, but none of the inputs was of "
                f"complex data type."
            )
        real_parts = []
        imag_parts = []
        for x in xs:
            if x.is_complex():
                # complex input: separate into real and imaginary part
                re, im = split_complex(x)
            else:
                # real input: the same tensor is fed to both combinations
                re = im = x
            real_parts.append(re)
            imag_parts.append(im)
        # combine real and imaginary parts independently of each other
        x_re = self.real_combination(tuple(real_parts))
        x_im = self.imag_combination(tuple(imag_parts))
        # merge both parts into a single complex tensor
        return combine_complex(x_re=x_re, x_im=x_im)

    # docstr-coverage: inherited
    def output_shape(self, input_shapes: Sequence[tuple[int, ...]]) -> tuple[int, ...]:  # noqa: D102
        # delegate to the real combination's shape computation; this avoids dtype issues,
        # and only the real part needs to be considered for the shape
        real_combination = self.real_combination
        return real_combination.output_shape(input_shapes=input_shapes)
class GatedCombination(Combination):
    r"""A module that implements a gated linear transformation for the combination of entities and literals.

    Compared to the other Combinations, this combination makes use of a gating mechanism commonly found in RNNs.
    The main goal of this gating mechanism is to learn which parts of the additional literal information are
    useful or not and act accordingly, by incorporating them into the new combined embedding or discarding them.

    For given entity representation $\mathbf{x}_e \in \mathbb{R}^{d_e}$ and literal representation
    $\mathbf{x}_l \in \mathbb{R}^{d_l}$, the module calculates

    .. math ::

        z = f_{gate}(\mathbf{W}_e x_e + \mathbf{W}_l x_l + \mathbf{b})

        h = f_{hidden}(\mathbf{W} [x_e; x_l] + \mathbf{b}_h)

        y = Dropout(z \odot h + (1 - z) \odot x_e)

    where $\mathbf{W}_e \in \mathbb{R}^{d_e \times d_e}$, $\mathbf{W}_l \in \mathbb{R}^{d_l \times d_e}$,
    $\mathbf{W} \in \mathbb{R}^{(d_e + d_l) \times d_e}$, and $\mathbf{b}, \mathbf{b}_h \in \mathbb{R}^{d_e}$ are
    trainable parameters, $f_{gate}$ and $f_{hidden}$ are activation functions, defaulting to sigmoid and tanh,
    $\odot$ denotes the element-wise multiplication, and $[x_e; x_l]$ the concatenation operation.

    .. note ::
        We can alternatively express the gate

        .. math ::

            z = f_{gate}(\mathbf{W}_e x_e + \mathbf{W}_l x_l + \mathbf{b})

        as

        .. math ::

            z = f_{gate}(\mathbf{W}_{el} [x_e; x_l] + \mathbf{b})

        with $\mathbf{W}_{el} \in \mathbb{R}^{(d_e + d_l) \times d_e}$.

    Implementation based on https://github.com/SmartDataAnalytics/LiteralE/blob/master/model.py Gate class.
    """

    def __init__(
        self,
        entity_dim: int = 32,
        literal_dim: int | None = None,
        input_dropout: float = 0.0,
        gate_activation: HintOrType[nn.Module] = nn.Sigmoid,
        gate_activation_kwargs: Mapping[str, Any] | None = None,
        hidden_activation: HintOrType[nn.Module] = nn.Tanh,
        hidden_activation_kwargs: Mapping[str, Any] | None = None,
    ) -> None:
        """Instantiate the module.

        :param entity_dim:
            the dimension of the entity representations.
        :param literal_dim:
            the dimension of the literals; defaults to entity_dim
        :param input_dropout:
            the dropout rate; despite the name, dropout is applied to the gated output
        :param gate_activation:
            the activation to use on the gate, or a hint thereof
        :param gate_activation_kwargs:
            the keyword arguments to be used to instantiate the `gate_activation` if
            a class or name is given instead of a pre-instantiated activation module
        :param hidden_activation:
            the activation to use in the hidden layer, or a hint thereof
        :param hidden_activation_kwargs:
            the keyword arguments to be used to instantiate the hidden activation if
            a class or name is given instead of a pre-instantiated activation module
        """
        super().__init__()
        literal_dim = literal_dim or entity_dim
        self.dropout = nn.Dropout(input_dropout)
        # the gate: decides per dimension how much of the hidden representation is used
        self.gate = ConcatProjectionCombination(
            input_dims=[entity_dim, literal_dim],
            output_dim=entity_dim,
            bias=True,
            activation=gate_activation,
            activation_kwargs=gate_activation_kwargs,
        )
        # the combination: computes the hidden representation from entity and literal
        self.combination = ConcatProjectionCombination(
            input_dims=[entity_dim, literal_dim],
            output_dim=entity_dim,
            bias=True,
            activation=hidden_activation,
            activation_kwargs=hidden_activation_kwargs,
        )

    def forward(self, xs: Sequence[FloatTensor]) -> FloatTensor:
        """
        Combine an entity representation with a literal representation.

        :param xs:
            exactly two representations: the entity representation first, then the literal representation

        :return:
            the gated combination, with dropout applied

        :raises ValueError:
            if the number of inputs is not exactly two
        """
        # explicit check instead of `assert`, which is removed when Python runs with `-O`
        if len(xs) != 2:
            raise ValueError(
                f"{self.__class__.__name__} requires exactly two inputs (entity and literal representation), "
                f"but got {len(xs)}."
            )
        z = self.gate(xs)
        h = self.combination(xs)
        # interpolate between the hidden representation and the unchanged entity representation
        return self.dropout(z * h + (1 - z) * xs[0])
#: Resolver for combinations, built from the subclasses of :class:`Combination`,
#: with :class:`ConcatCombination` registered as the default.
combination_resolver: ClassResolver[Combination] = ClassResolver.from_subclasses(
    base=Combination,
    default=ConcatCombination,
    location="pykeen.nn.combination.combination_resolver",
)