# Source code for pykeen.nn.emb

# -*- coding: utf-8 -*-

"""Embedding modules."""

import functools
from typing import Any, Mapping, Optional

import torch
import torch.nn
from torch import nn

# Public API of this module: the representation base class and its
# trainable embedding implementation.
__all__ = [
    'RepresentationModule',
    'Embedding',
]

from pykeen.typing import Constrainer, Initializer, Normalizer

class RepresentationModule(nn.Module):
    """A base class for obtaining representations for entities/relations.

    Subclasses must override :meth:`forward`. :meth:`reset_parameters` and
    :meth:`post_parameter_update` are optional hooks which do nothing by default.
    """

    def forward(
        self,
        indices: Optional[torch.LongTensor] = None,
    ) -> torch.FloatTensor:
        """Get representations for indices.

        :param indices: shape: (m,)
            The indices, or None. If None, return all representations.

        :return: shape: (m, d)
            The representations.

        :raises NotImplementedError:
            Always, in this base class; subclasses provide the implementation.
        """
        raise NotImplementedError

    def reset_parameters(self) -> None:
        """Reset the module's parameters.

        The base implementation is a no-op.
        """

    def post_parameter_update(self) -> None:
        """Apply constraints which should not be included in gradients.

        The base implementation is a no-op.
        """
class Embedding(RepresentationModule):
    """Trainable embeddings.

    This class provides the same interface as :class:`torch.nn.Embedding` and
    can be used throughout PyKEEN as a more fully featured drop-in replacement.
    """

    def __init__(
        self,
        num_embeddings: int,
        embedding_dim: int,
        initializer: Optional[Initializer] = None,
        initializer_kwargs: Optional[Mapping[str, Any]] = None,
        normalizer: Optional[Normalizer] = None,
        normalizer_kwargs: Optional[Mapping[str, Any]] = None,
        constrainer: Optional[Constrainer] = None,
        constrainer_kwargs: Optional[Mapping[str, Any]] = None,
        trainable: bool = True,
    ):
        """Instantiate an embedding with extended functionality.

        :param num_embeddings: >0
            The number of embeddings.
        :param embedding_dim: >0
            The embedding dimensionality.
        :param initializer:
            An optional initializer, which takes an uninitialized (num_embeddings, embedding_dim) tensor as input,
            and returns an initialized tensor of same shape and dtype (which may be the same, i.e. the
            initialization may be in-place). Defaults to :func:`torch.nn.init.normal_`.
        :param initializer_kwargs:
            Additional keyword arguments passed to the initializer
        :param normalizer:
            A normalization function, which is applied in every forward pass.
        :param normalizer_kwargs:
            Additional keyword arguments passed to the normalizer
        :param constrainer:
            A function which is applied to the weights after each parameter update, without tracking gradients.
            It may be used to enforce model constraints outside of gradient-based training. The function does not
            need to be in-place, but the weight tensor is modified in-place.
        :param constrainer_kwargs:
            Additional keyword arguments passed to the constrainer
        :param trainable:
            Passed to ``requires_grad_`` of the wrapped :class:`torch.nn.Embedding`.
        """
        super().__init__()

        if initializer is None:
            initializer = nn.init.normal_

        # Bind the keyword arguments once, so the stored callables only take the tensor.
        if initializer_kwargs:
            self.initializer = functools.partial(initializer, **initializer_kwargs)
        else:
            self.initializer = initializer  # type: ignore

        if constrainer is not None and constrainer_kwargs:
            self.constrainer = functools.partial(constrainer, **constrainer_kwargs)
        else:
            self.constrainer = constrainer  # type: ignore

        if normalizer is not None and normalizer_kwargs:
            self.normalizer = functools.partial(normalizer, **normalizer_kwargs)
        else:
            self.normalizer = normalizer  # type: ignore

        self._embeddings = torch.nn.Embedding(
            num_embeddings=num_embeddings,
            embedding_dim=embedding_dim,
        )
        self._embeddings.requires_grad_(trainable)

    @classmethod
    def init_with_device(
        cls,
        num_embeddings: int,
        embedding_dim: int,
        device: torch.device,
        initializer: Optional[Initializer] = None,
        initializer_kwargs: Optional[Mapping[str, Any]] = None,
        normalizer: Optional[Normalizer] = None,
        normalizer_kwargs: Optional[Mapping[str, Any]] = None,
        constrainer: Optional[Constrainer] = None,
        constrainer_kwargs: Optional[Mapping[str, Any]] = None,
        trainable: bool = True,
    ) -> 'Embedding':
        """Create an embedding object on the given device by wrapping :func:`__init__`.

        This method is a hotfix for not being able to pass a device during initialization of
        :class:`torch.nn.Embedding`. Instead the weight is always initialized on CPU and has to be moved to GPU
        afterwards.

        :param num_embeddings: >0
            The number of embeddings.
        :param embedding_dim: >0
            The embedding dimensionality.
        :param device:
            The device the created module is moved to.
        :param initializer:
            Forwarded to :func:`__init__`.
        :param initializer_kwargs:
            Forwarded to :func:`__init__`.
        :param normalizer:
            Forwarded to :func:`__init__`.
        :param normalizer_kwargs:
            Forwarded to :func:`__init__`.
        :param constrainer:
            Forwarded to :func:`__init__`.
        :param constrainer_kwargs:
            Forwarded to :func:`__init__`.
        :param trainable:
            Forwarded to :func:`__init__`.

        :return:
            The embedding.
        """
        return cls(
            num_embeddings=num_embeddings,
            embedding_dim=embedding_dim,
            initializer=initializer,
            initializer_kwargs=initializer_kwargs,
            normalizer=normalizer,
            normalizer_kwargs=normalizer_kwargs,
            constrainer=constrainer,
            constrainer_kwargs=constrainer_kwargs,
            trainable=trainable,
        ).to(device=device)

    @property
    def num_embeddings(self) -> int:  # noqa: D401
        """The total number of representations (i.e. the maximum ID)."""
        return self._embeddings.num_embeddings

    @property
    def embedding_dim(self) -> int:  # noqa: D401
        """The representation dimension."""
        return self._embeddings.embedding_dim

    def reset_parameters(self) -> None:
        """Re-initialize the embedding weights using the configured initializer."""
        # initialize weights in-place
        # NOTE(review): the assignment target and the call argument were lost in the extracted
        # source; they are reconstructed here as the wrapped embedding's weight data, following
        # the initializer contract documented in __init__ - confirm against upstream.
        self._embeddings.weight.data = self.initializer(self._embeddings.weight.data)

    def post_parameter_update(self):
        """Apply the configured constrainer (if any) to the embedding weights."""
        # apply constraints in-place
        # NOTE(review): target/argument reconstructed as for reset_parameters - confirm against upstream.
        if self.constrainer is not None:
            self._embeddings.weight.data = self.constrainer(self._embeddings.weight.data)

    def forward(
        self,
        indices: Optional[torch.LongTensor] = None,
    ) -> torch.FloatTensor:
        """Get (optionally normalized) embeddings for indices.

        :param indices:
            The indices, or None. If None, the full weight matrix is used.

        :return:
            The selected embeddings, passed through the normalizer if one is configured.
        """
        if indices is None:
            x = self._embeddings.weight
        else:
            x = self._embeddings(indices)
        if self.normalizer is not None:
            x = self.normalizer(x)
        return x

    def get_in_canonical_shape(
        self,
        indices: Optional[torch.LongTensor] = None,
    ) -> torch.FloatTensor:
        """Get embedding in canonical shape.

        :param indices:
            The indices. If None, return all embeddings.

        :return: shape: (batch_size, num_embeddings, d)
            The result of the forward pass with one additional axis of size 1: inserted at position 0
            if ``indices`` is None, and at position 1 otherwise.
        """
        x = self(indices=indices)
        if indices is None:
            return x.unsqueeze(dim=0)
        return x.unsqueeze(dim=1)