"""Embedding weight initialization routines."""
import functools
import logging
import math
from collections.abc import Sequence
from typing import Any, Optional, cast
import numpy as np
import torch
import torch.nn
import torch.nn.init
import torch_ppr.utils
from class_resolver import FunctionResolver, Hint, HintOrType, OptionalKwargs
from more_itertools import last
from torch.nn import functional
from .text import TextEncoder, text_encoder_resolver
from .utils import iter_matrix_power, safe_diagonal
from ..triples import CoreTriplesFactory, TriplesFactory
from ..typing import FloatTensor, Initializer, LongTensor, MappedTriples, OneOrSequence
from ..utils import compose, get_edge_index, iter_weisfeiler_lehman, upgrade_to_sequence
__all__ = [
"xavier_uniform_",
"xavier_uniform_norm_",
"xavier_normal_",
"xavier_normal_norm_",
"uniform_norm_",
"uniform_norm_p1_",
"normal_norm_",
"init_phases",
# Classes
"PretrainedInitializer",
"LabelBasedInitializer",
"WeisfeilerLehmanInitializer",
"RandomWalkPositionalEncodingInitializer",
# Resolver
"initializer_resolver",
]
logger = logging.getLogger(__name__)
def xavier_normal_(tensor: torch.Tensor, gain: float = 1.0) -> torch.Tensor:
r"""Initialize weights of the tensor similarly to Glorot/Xavier initialization.
Proceed as if it was a linear layer with a `fan_in` of zero and a `fan_out` of `prod(tensor.shape[1:])`, to which
Xavier normal initialization is applied, i.e., fill the weight of the input `tensor` with values
sampled from :math:`\mathcal{N}(0, \text{std}^2)` where
.. math::
\text{std} = \text{gain} \times \sqrt{\frac{2}{\text{fan_out}}}
Example usage:
>>> import torch, pykeen.nn.init
>>> w = torch.empty(3, 5)
>>> pykeen.nn.init.xavier_normal_(w, gain=torch.nn.init.calculate_gain("relu"))
.. seealso::
:func:`torch.nn.init.xavier_normal_`
:param tensor:
a tensor to initialize
:param gain:
an optional scaling factor, defaults to 1.0.
:return:
tensor with weights set by this initializer.
"""
fan_out = np.prod(tensor.shape[1:])
std = gain * math.sqrt(2.0 / float(fan_out))
return torch.nn.init.normal_(tensor, mean=0.0, std=std)
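# Note: ``xavier_uniform_`` is listed in ``__all__`` and registered with the resolver below, but its definition is
# missing from this excerpt. The following is a reconstruction that mirrors ``xavier_normal_`` (treating ``fan_in``
# as zero and drawing from a uniform instead of a normal distribution); it is a sketch, not necessarily the exact
# upstream implementation.
def xavier_uniform_(tensor: torch.Tensor, gain: float = 1.0) -> torch.Tensor:
    r"""Initialize weights of the tensor similarly to Glorot/Xavier initialization.

    Proceed as if it was a linear layer with a `fan_in` of zero and a `fan_out` of `prod(tensor.shape[1:])`, to which
    Xavier uniform initialization is applied, i.e., fill the weight of the input `tensor` with values sampled from
    :math:`\mathcal{U}(-a, a)` where

    .. math::

        a = \text{gain} \times \sqrt{3} \times \sqrt{\frac{2}{\text{fan_out}}}

    :param tensor:
        a tensor to initialize
    :param gain:
        an optional scaling factor, defaults to 1.0.
    :return:
        tensor with weights set by this initializer.
    """
    fan_out = np.prod(tensor.shape[1:])
    # same standard deviation as in ``xavier_normal_`` ...
    std = gain * math.sqrt(2.0 / float(fan_out))
    # ... converted to the half-width of a uniform distribution with that standard deviation
    a = math.sqrt(3.0) * std
    return torch.nn.init.uniform_(tensor, -a, a)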
def init_phases(x: torch.Tensor) -> torch.Tensor:
r"""
Generate random phases between 0 and :math:`2\pi`.
.. note::
This method works on the canonical torch real representation of complex tensors, cf.
https://pytorch.org/docs/stable/complex_numbers.html
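Example usage (a sketch; the trailing dimension of 2 corresponds to the canonical real view of complex tensors):
>>> import torch, pykeen.nn.init
>>> w = pykeen.nn.init.init_phases(torch.empty(3, 5, 2))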
:param x:
a tensor to initialize
:return:
tensor with weights set by this initializer
.. seealso::
:func:`torch.view_as_real`
"""
# backwards compatibility
if x.shape[-1] != 2:
new_shape = (*x.shape[:-1], -1, 2)
logger.warning(
f"The input tensor shape, {tuple(x.shape)}, does not comply with the canonical complex tensor shape, "
f"(..., 2), cf. https://pytorch.org/docs/stable/complex_numbers.html. We'll try to reshape to {new_shape}",
)
x = x.view(*new_shape)
phases = 2 * np.pi * torch.rand_like(torch.view_as_complex(x).real)
return torch.view_as_real(torch.complex(real=phases.cos(), imag=phases.sin())).detach()
xavier_uniform_norm_ = compose(
xavier_uniform_,
functional.normalize,
name="xavier_uniform_norm_",
)
xavier_normal_norm_ = compose(
xavier_normal_,
functional.normalize,
name="xavier_normal_norm_",
)
uniform_norm_ = compose(
torch.nn.init.uniform_,
functional.normalize,
name="uniform_norm_",
)
normal_norm_ = compose(
torch.nn.init.normal_,
functional.normalize,
name="normal_norm_",
)
uniform_norm_p1_: Initializer = cast(
Initializer,
compose(
torch.nn.init.uniform_,
functools.partial(functional.normalize, p=1),
name="uniform_norm_p1_",
),
)
def init_quaternions(x: FloatTensor) -> FloatTensor:
"""
Initialize quaternion.
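Example usage (a sketch; the trailing dimension of 4 holds the four quaternion components):
>>> import torch
>>> from pykeen.nn.init import init_quaternions
>>> q = init_quaternions(torch.empty(3, 5, 4))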
:param x: shape: (..., d, 4)
the quaternions
:raises ValueError:
if the shape's last dimension is not 4.
:return: shape: (..., d, 4)
uniform quaternions
"""
if x.ndim < 2 or x.shape[-1] != 4:
raise ValueError(f"shape must be (..., 4) but is {x.shape}.")
*shape, dim = x.shape[:-1]
device = x.device
num_elements = math.prod(shape)
# scaling factor
s = 1.0 / math.sqrt(2 * num_elements)
# modulus ~ Uniform[-s, s]
modulus = 2 * s * torch.rand(num_elements, dim, device=device) - s
# phase ~ Uniform[0, 2*pi]
phase = 2 * math.pi * torch.rand(num_elements, dim, device=device)
# real part
real = (modulus * phase.cos()).unsqueeze(dim=-1)
# purely imaginary quaternions unitary
imag = torch.rand(num_elements, dim, 3, device=device)
imag = functional.normalize(imag, p=2, dim=-1)
imag = imag * (modulus * phase.sin()).unsqueeze(dim=-1)
return torch.cat([real, imag], dim=-1)
class PretrainedInitializer:
"""
Initialize tensor with pretrained weights.
Example usage:
.. code-block::
import torch
from pykeen.pipeline import pipeline
from pykeen.nn.init import PretrainedInitializer
# this is usually loaded from somewhere else
# the shape must match, as well as the entity-to-id mapping
pretrained_embedding_tensor = torch.rand(14, 128)
result = pipeline(
dataset="nations",
model="transe",
model_kwargs=dict(
embedding_dim=pretrained_embedding_tensor.shape[-1],
entity_initializer=PretrainedInitializer(tensor=pretrained_embedding_tensor),
),
)
"""
def __init__(self, tensor: FloatTensor) -> None:
"""
Initialize the initializer.
:param tensor:
the tensor of pretrained embeddings.
"""
self.tensor = tensor
def __call__(self, x: torch.Tensor) -> torch.Tensor:
"""Initialize the tensor with the given tensor."""
if x.shape != self.tensor.shape:
raise ValueError(f"shape does not match: expected {self.tensor.shape} but got {x.shape}")
return self.tensor.to(device=x.device, dtype=x.dtype)
def as_embedding(self, **kwargs: Any):
"""Get a static embedding from this pre-trained initializer.
:param kwargs: Keyword arguments to pass to :class:`pykeen.nn.representation.Embedding`
:returns: An embedding
:rtype: pykeen.nn.representation.Embedding
"""
from .representation import Embedding
max_id, *shape = self.tensor.shape
return Embedding(max_id=max_id, shape=shape, initializer=self, trainable=False, **kwargs)
class LabelBasedInitializer(PretrainedInitializer):
"""
An initializer using pretrained models from the `transformers` library to encode labels.
Example Usage:
Initialize entity representations as Transformer encodings of their labels. Afterwards,
the parameters are detached from the labels, and trained on the KGE task without any
further connection to the Transformer model.
.. code-block:: python
from pykeen.datasets import get_dataset
from pykeen.nn.init import LabelBasedInitializer
from pykeen.models import ERMLPE
dataset = get_dataset(dataset="nations")
entity_initializer = LabelBasedInitializer.from_triples_factory(
triples_factory=dataset.training,
encoder="transformer",
)
model = ERMLPE(
triples_factory=dataset.training,
embedding_dim=entity_initializer.tensor.shape[-1], # 768 for BERT base
entity_initializer=entity_initializer,
# note: we explicitly need to provide a relation initializer here,
# since ERMLPE shares initializers between entities and relations by default
relation_initializer="uniform",
)
"""
def __init__(
self,
labels: Sequence[str],
encoder: HintOrType[TextEncoder] = None,
encoder_kwargs: OptionalKwargs = None,
batch_size: Optional[int] = None,
):
"""
Initialize the initializer.
:param labels:
the labels
:param encoder:
the text encoder to use, cf. `text_encoder_resolver`
:param encoder_kwargs:
additional keyword-based parameters passed to the encoder
:param batch_size: >0
the (maximum) batch size to use while encoding. If None, use `len(labels)`, i.e., only a single batch.
"""
super().__init__(
tensor=text_encoder_resolver.make(encoder, encoder_kwargs)
.encode_all(
labels=labels,
batch_size=batch_size,
)
# must be cloned if we want to do backprop
.clone(),
)
@classmethod
def from_triples_factory(
cls,
triples_factory: TriplesFactory,
for_entities: bool = True,
**kwargs,
) -> "LabelBasedInitializer":
"""
Prepare a label-based initializer with labels from a triples factory.
:param triples_factory:
the triples factory
:param for_entities:
whether to create the initializer for entities (or relations)
:param kwargs:
additional keyword-based arguments passed to :func:`LabelBasedInitializer.__init__`
:returns:
A label-based initializer
:raise ImportError:
if the transformers library could not be imported
"""
id_to_label = triples_factory.entity_id_to_label if for_entities else triples_factory.relation_id_to_label
labels = [id_to_label[i] for i in sorted(id_to_label.keys())]
return cls(
labels=labels,
**kwargs,
)
class WeisfeilerLehmanInitializer(PretrainedInitializer):
"""An initializer based on an encoding of categorical colors from the Weisfeiler-Lehman algorithm."""
def __init__(
self,
*,
# the color initializer
color_initializer: Hint[Initializer] = None,
color_initializer_kwargs: OptionalKwargs = None,
shape: OneOrSequence[int] = 32,
# variants for the edge index
edge_index: Optional[LongTensor] = None,
num_entities: Optional[int] = None,
mapped_triples: Optional[LongTensor] = None,
triples_factory: Optional[CoreTriplesFactory] = None,
# additional parameters for iter_weisfeiler_lehman
**kwargs,
) -> None:
"""
Initialize the initializer.
:param color_initializer:
the initializer for initialization color representations, or a hint thereof
:param color_initializer_kwargs:
additional keyword-based parameters for the color initializer
:param shape:
the shape to use for the color representations
:param edge_index: shape: (2, m)
the edge index
:param num_entities:
the number of entities. can be inferred
:param mapped_triples: shape: (m, 3)
the Id-based triples
:param triples_factory:
the triples factory
:param kwargs:
additional keyword-based parameters passed to :func:`pykeen.utils.iter_weisfeiler_lehman`
"""
# normalize shape
shape = upgrade_to_sequence(shape)
# get coloring
colors = last(
iter_weisfeiler_lehman(
edge_index=get_edge_index(
triples_factory=triples_factory, mapped_triples=mapped_triples, edge_index=edge_index
),
num_nodes=num_entities,
**kwargs,
)
)
# make color initializer
color_initializer = initializer_resolver.make(color_initializer, pos_kwargs=color_initializer_kwargs)
# initialize color representations
num_colors = colors.max().item() + 1
# note: this could be a representation?
color_representation = color_initializer(colors.new_empty(num_colors, *shape, dtype=torch.get_default_dtype()))
# init entity representations according to the color
super().__init__(tensor=color_representation[colors])
class RandomWalkPositionalEncodingInitializer(PretrainedInitializer):
r"""
Initialize nodes via random-walk positional encoding.
The random walk positional encoding is given as
.. math::
\mathbf{x}_i = [\mathbf{R}_{i, i}, \mathbf{R}^{2}_{i, i}, \ldots, \mathbf{R}^{d}_{i, i}] \in \mathbb{R}^{d}
where $\mathbf{R} := \mathbf{A}\mathbf{D}^{-1}$ is the random walk matrix, with the diagonal degree matrix
$\mathbf{D}_{i, i} := \sum_j \mathbf{A}_{i, j}$.
.. seealso::
https://arxiv.org/abs/2110.07875
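Example usage (a sketch following the other initializers in this module; the "nations" dataset is only illustrative):
.. code-block:: python
    from pykeen.datasets import get_dataset
    from pykeen.nn.init import RandomWalkPositionalEncodingInitializer

    dataset = get_dataset(dataset="nations")
    entity_initializer = RandomWalkPositionalEncodingInitializer(
        triples_factory=dataset.training,
        dim=32,
    )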
"""
def __init__(
self,
*,
triples_factory: Optional[CoreTriplesFactory] = None,
mapped_triples: Optional[MappedTriples] = None,
edge_index: Optional[torch.Tensor] = None,
dim: int,
num_entities: Optional[int] = None,
space_dim: int = 0,
skip_first_power: bool = True,
) -> None:
"""
Initialize the positional encoding.
One of `triples_factory`, `mapped_triples` or `edge_index` will be used.
The preference order is:
1. `triples_factory`
2. `mapped_triples`
3. `edge_index`
:param triples_factory:
the triples factory
:param mapped_triples: shape: `(m, 3)`
the mapped triples
:param edge_index: shape: `(2, m)`
the edge index
:param dim:
the dimensionality
:param num_entities:
the number of entities. If `None`, it will be inferred from `edge_index`
:param space_dim:
estimated dimensionality of the space. It is used to correct the random-walk diagonal by a factor
`k^(space_dim/2)`. In Euclidean space, this correction keeps the height of the Gaussian distribution
almost constant across the number of steps, if `space_dim` is the dimension of the Euclidean space.
:param skip_first_power:
in most cases, the diagonal values of the adjacency matrix will be zero, since reflexive edges (self-loops) are rare.
This flag enables skipping the first matrix power.
"""
edge_index = get_edge_index(
triples_factory=triples_factory, mapped_triples=mapped_triples, edge_index=edge_index
)
# create random walk matrix
rw = torch_ppr.utils.prepare_page_rank_adjacency(edge_index=edge_index, num_nodes=num_entities)
# TODO replace iter_matrix_power and safe_diagonal with torch_ppr functions?
# stack diagonal entries of powers of rw
tensor = torch.stack(
[
(i ** (space_dim / 2.0)) * safe_diagonal(matrix=power)
for i, power in enumerate(iter_matrix_power(matrix=rw, max_iter=dim), start=1)
if not skip_first_power or i > 1
],
dim=-1,
)
super().__init__(tensor=tensor)
# TODO: replace by automatically generated list
#: A resolver for initializers, including elements from :mod:`pykeen.nn.init`
#:
#: - :func:`pykeen.nn.init.init_phases`
#: - :func:`pykeen.nn.init.init_quaternions`
#: - :func:`pykeen.nn.init.normal_norm_`
#: - :func:`pykeen.nn.init.uniform_norm_`
#: - :func:`pykeen.nn.init.xavier_uniform_`
#: - :func:`pykeen.nn.init.xavier_uniform_norm_`
#: - :func:`pykeen.nn.init.xavier_normal_`
#: - :func:`pykeen.nn.init.xavier_normal_norm_`
#:
#: as well as initializers from :mod:`torch.nn.init`.
initializer_resolver: FunctionResolver[Initializer] = FunctionResolver(
[
getattr(torch.nn.init, func)
for func in dir(torch.nn.init)
if not func.startswith("_") and func.endswith("_") and func not in {"xavier_normal_", "xavier_uniform_"}
],
default=torch.nn.init.normal_,
location="pykeen.nn.init.initializer_resolver",
)
for func in [
xavier_normal_,
xavier_uniform_,
init_phases,
init_quaternions,
xavier_normal_norm_,
xavier_uniform_norm_,
normal_norm_,
uniform_norm_,
]:
initializer_resolver.register(func)
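# Example resolver usage (a sketch): an initializer can be looked up by name, e.g.,
#   initializer_resolver.make("xavier_uniform_norm_")
# while passing ``None`` yields the default, ``torch.nn.init.normal_``.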