# -*- coding: utf-8 -*-
"""Implementation of combinations for the :class:`pykeen.models.LiteralModel`."""
from abc import ABC, abstractmethod
from typing import Any, Mapping, Optional
import torch
from class_resolver import HintOrType
from torch import nn
from ..utils import activation_resolver, combine_complex, split_complex
# Public names of this module.
__all__ = [
# Base and mid-level classes
'Combination',
'RealCombination',
'ParameterizedRealCombination',
'ComplexCombination',
'ParameterizedComplexCombination',
# Concrete classes
'LinearDropout',
'DistMultCombination',
'ComplExLiteralCombination',
]
class Combination(nn.Module, ABC):
    """Base class for combinations.

    A combination is a :class:`torch.nn.Module` that takes a representation and a literal
    and merges them. Subclasses must implement :meth:`forward`.
    """

    # Declared abstract (like ``score``/``score_real``/``score_imag`` on the mid-level base
    # classes) so that a subclass which forgets to implement it fails at instantiation
    # instead of at the first forward pass.
    @abstractmethod
    def forward(self, x: torch.FloatTensor, literal: torch.FloatTensor) -> torch.FloatTensor:
        """Combine the representation and literal then score.

        :param x: The representation to combine.
        :param literal: The literal to combine with ``x``.
        :return: The result of combining and scoring ``x`` and ``literal``.
        :raises NotImplementedError: Always; concrete subclasses override this method.
        """
        raise NotImplementedError
class RealCombination(Combination, ABC):
    """Intermediate base class for combinations that operate on real-valued vectors."""

    def forward(self, x: torch.FloatTensor, literal: torch.FloatTensor) -> torch.FloatTensor:
        """Concatenate the entity representation with its literal and score the result.

        :param x: The entity representation.
        :param literal: The literal; it is concatenated to ``x`` along the last dimension.
        :return: The output of :meth:`score` applied to the concatenation.
        """
        combined = torch.cat((x, literal), dim=-1)
        return self.score(combined)

    @abstractmethod
    def score(self, x: torch.FloatTensor) -> torch.FloatTensor:
        """Score the concatenated entity representation and literals.

        :param x: The concatenation of the entity representation and the literals.
        :return: The scored combination.
        :raises NotImplementedError: Always; concrete subclasses override this method.
        """
        raise NotImplementedError
class ParameterizedRealCombination(RealCombination):
    """A real-valued combination whose scoring is delegated to a wrapped module."""

    def __init__(self, module: nn.Module):
        """Store the scoring module.

        :param module: The module applied to the concatenated entity representation and literals.
        """
        super().__init__()
        # Kept under the attribute name ``module``; it is what :meth:`score` calls.
        self.module = module

    def score(self, x: torch.FloatTensor) -> torch.FloatTensor:
        """Apply the wrapped module to the concatenated entity representation and literals.

        :param x: The concatenation of the entity representation and the literals.
        :return: The output of the wrapped module.
        """
        scoring_module = self.module
        return scoring_module(x)
class ComplexCombination(Combination, ABC):
    """Intermediate base class for combinations that operate on complex-valued vectors."""

    def forward(self, x: torch.FloatTensor, literal: torch.FloatTensor) -> torch.FloatTensor:
        """Score the real and imaginary parts separately, each joined with the literal.

        The input is split with ``split_complex``; each part is concatenated with ``literal``
        along the last dimension and scored by :meth:`score_real` / :meth:`score_imag`; the two
        results are merged again with ``combine_complex``.

        :param x: The complex-valued entity representation.
        :param literal: The literal; the same tensor is appended to both parts.
        :return: The recombined result of both scoring functions.
        """
        real_part, imag_part = split_complex(x)
        # Real part first, then imaginary part.
        real_scored = self.score_real(torch.cat((real_part, literal), dim=-1))
        imag_scored = self.score_imag(torch.cat((imag_part, literal), dim=-1))
        return combine_complex(x_re=real_scored, x_im=imag_scored)

    @abstractmethod
    def score_real(self, x: torch.FloatTensor) -> torch.FloatTensor:
        """Score the real part of the entity representation concatenated with the literals.

        :param x: The concatenation of the real part and the literals.
        :return: The scored real part.
        :raises NotImplementedError: Always; concrete subclasses override this method.
        """
        raise NotImplementedError

    @abstractmethod
    def score_imag(self, x: torch.FloatTensor) -> torch.FloatTensor:
        """Score the imaginary part of the entity representation concatenated with the literals.

        :param x: The concatenation of the imaginary part and the literals.
        :return: The scored imaginary part.
        :raises NotImplementedError: Always; concrete subclasses override this method.
        """
        raise NotImplementedError
class ParameterizedComplexCombination(ComplexCombination):
    """A complex combination parametrized by a real scoring module and an imaginary scoring module."""

    def __init__(self, real_module: nn.Module, imag_module: nn.Module):
        """Store the two scoring modules.

        :param real_module: The module applied to the concatenation of the real part of the
            entity representation and the literals.
        :param imag_module: The module applied to the concatenation of the imaginary part of the
            entity representation and the literals.
        """
        super().__init__()
        # Attribute names differ from the parameter names; the score methods read these.
        self.real_mod = real_module
        self.imag_mod = imag_module

    def score_real(self, x: torch.FloatTensor) -> torch.FloatTensor:
        """Apply the real scoring module to the concatenated real part and literals.

        :param x: The concatenation of the real part and the literals.
        :return: The output of the real scoring module.
        """
        real_scorer = self.real_mod
        return real_scorer(x)

    def score_imag(self, x: torch.FloatTensor) -> torch.FloatTensor:
        """Apply the imaginary scoring module to the concatenated imaginary part and literals.

        :param x: The concatenation of the imaginary part and the literals.
        :return: The output of the imaginary scoring module.
        """
        imag_scorer = self.imag_mod
        return imag_scorer(x)
class LinearDropout(nn.Sequential):
    """A sequential module made of a linear layer, a dropout layer, and an optional activation layer."""

    def __init__(
        self,
        entity_embedding_dim: int,
        literal_embedding_dim: int,
        input_dropout: float = 0.0,
        activation: HintOrType[nn.Module] = None,
        activation_kwargs: Optional[Mapping[str, Any]] = None,
    ) -> None:
        """Instantiate the :class:`torch.nn.Sequential`.

        The linear layer maps ``entity_embedding_dim + literal_embedding_dim`` features to
        ``entity_embedding_dim`` features. Dropout is applied to the output of the linear layer,
        and the activation (if any) comes last.

        :param entity_embedding_dim: The dimension of the entity representations to which literals are concatenated
        :param literal_embedding_dim: The dimension of the literals that are concatenated
        :param input_dropout: The dropout probability of an element to be zeroed.
        :param activation: An optional activation hint that is passed to ``activation_resolver.make``.
            If it is falsy (e.g. ``None``), no activation layer is appended.
        :param activation_kwargs: Optional keyword arguments passed to ``activation_resolver.make``
            together with ``activation``. Unused when no activation layer is appended.
        """
        # Collect the layers in application order, then hand them to nn.Sequential in one go.
        layers = [
            nn.Linear(entity_embedding_dim + literal_embedding_dim, entity_embedding_dim),
            nn.Dropout(input_dropout),
        ]
        if activation:
            layers.append(activation_resolver.make(activation, activation_kwargs))
        super().__init__(*layers)
class DistMultCombination(ParameterizedRealCombination):
    """The linear/dropout combination used in :class:`pykeen.models.DistMultLiteral`."""

    def __init__(
        self,
        entity_embedding_dim: int,
        literal_embedding_dim: int,
        input_dropout: float = 0.0,
    ) -> None:
        """Instantiate the :class:`ParameterizedRealCombination` with a :class:`LinearDropout`.

        No activation is passed to the :class:`LinearDropout`, as described by [kristiadi2018]_.

        :param entity_embedding_dim: The dimension of the entity representations to which literals are concatenated
        :param literal_embedding_dim: The dimension of the literals that are concatenated
        :param input_dropout: The dropout probability of an element to be zeroed.
        """
        linear_dropout = LinearDropout(
            entity_embedding_dim=entity_embedding_dim,
            literal_embedding_dim=literal_embedding_dim,
            input_dropout=input_dropout,
        )
        super().__init__(module=linear_dropout)
class ComplExLiteralCombination(ParameterizedComplexCombination):
    """The linear/dropout/tanh combination used in :class:`pykeen.models.ComplExLiteral`."""

    def __init__(
        self,
        entity_embedding_dim: int,
        literal_embedding_dim: int,
        input_dropout: float = 0.0,
        activation: HintOrType[nn.Module] = 'tanh',
    ) -> None:
        """Instantiate the :class:`ParameterizedComplexCombination` with a :class:`LinearDropout` for real and imaginary.

        The default activation hint is ``'tanh'``, as described by [kristiadi2018]_.

        :param entity_embedding_dim: The dimension of the entity representations to which literals are concatenated
        :param literal_embedding_dim: The dimension of the literals that are concatenated
        :param input_dropout: The dropout probability of an element to be zeroed.
        :param activation: The activation function, resolved by :data:`pykeen.utils.activation_resolver`.
        """
        # Both parts are configured identically, but each gets its own LinearDropout instance.
        shared_kwargs = dict(
            entity_embedding_dim=entity_embedding_dim,
            literal_embedding_dim=literal_embedding_dim,
            input_dropout=input_dropout,
            activation=activation,
        )
        # The real module is built before the imaginary one.
        real_module = LinearDropout(**shared_kwargs)
        imag_module = LinearDropout(**shared_kwargs)
        super().__init__(real_module=real_module, imag_module=imag_module)