Source code for pykeen.nn.text.encoder

"""Modules for text encoding."""

import logging
import string
from abc import abstractmethod
from collections.abc import Sequence
from typing import TYPE_CHECKING, Callable, Optional, Union

import torch
from class_resolver import ClassResolver, Hint, HintOrType
from class_resolver.contrib.torch import aggregation_resolver
from more_itertools import chunked
from torch import nn
from torch_max_mem import maximize_memory_utilization
from tqdm.auto import tqdm

from ...typing import FloatTensor
from ...utils import determine_maximum_batch_size, get_preferred_device, resolve_device, upgrade_to_sequence

if TYPE_CHECKING:
    from ..representation import Representation

__all__ = [
    # abstract
    "TextEncoder",
    "text_encoder_resolver",
    # concrete
    "CharacterEmbeddingTextEncoder",
    "TransformerTextEncoder",
]

logger = logging.getLogger(__name__)


@maximize_memory_utilization(keys="encoder")
def _encode_all_memory_utilization_optimized(
    encoder: "TextEncoder",
    labels: Sequence[str],
    batch_size: int,
) -> torch.Tensor:
    """
    Encode all labels with the given batch-size.

    Wrapped by memory utilization maximizer to automatically reduce the batch size if needed.

    :param encoder:
        the encoder
    :param labels:
        the labels to encode
    :param batch_size:
        the batch size to use. Will automatically be reduced if necessary.

    :return: shape: `(len(labels), dim)`
        the encoded labels
    """
    return torch.cat(
        [encoder(batch) for batch in chunked(tqdm(map(str, labels), leave=False), batch_size)],
        dim=0,
    )


[docs] class TextEncoder(nn.Module): """An encoder for text."""
[docs] def forward(self, labels: Union[str, Sequence[str]]) -> FloatTensor: """ Encode a batch of text. :param labels: length: b the texts :return: shape: `(b, dim)` an encoding of the texts """ labels = upgrade_to_sequence(labels) labels = list(map(str, labels)) return self.forward_normalized(texts=labels)
[docs] @abstractmethod def forward_normalized(self, texts: Sequence[str]) -> FloatTensor: """ Encode a batch of text. :param texts: length: b the texts :return: shape: `(b, dim)` an encoding of the texts """ raise NotImplementedError
[docs] @torch.inference_mode() def encode_all( self, labels: Sequence[str], batch_size: Optional[int] = None, ) -> FloatTensor: """Encode all labels (inference mode & batched). :param labels: a sequence of strings to encode :param batch_size: the batch size to use for encoding the labels. ``batch_size=1`` means that the labels are encoded one-by-one, while ``batch_size=len(labels)`` would correspond to encoding all at once. Larger batch sizes increase memory requirements, but may be computationally more efficient. `batch_size` can also be set to `None` to enable automatic batch size maximization for the employed hardware. :returns: shape: (len(labels), dim) a tensor representing the encodings for all labels """ batch_size = determine_maximum_batch_size( batch_size=batch_size, device=get_preferred_device(self), maximum_batch_size=len(labels) ) return _encode_all_memory_utilization_optimized(encoder=self, labels=labels, batch_size=batch_size).detach()
[docs] class CharacterEmbeddingTextEncoder(TextEncoder): """ A simple character-based text encoder. This encoder uses base representations for each character from a given alphabet, as well as two special tokens for unknown character and padding. To encoder a sentence, it converts it to a sequence of characters, obtains the invidual characters representations and aggregates these representations to a single one. With :class:`pykeen.nn.representation.Embedding` character representation and :func:`torch.mean` aggregation, this encoder is similar to a bag-of-characters model with trainable character embeddings. Therefore, it is invariant to the ordering of characters: >>> from pykeen.nn.text import CharacterEmbeddingTextEncoder >>> encoder = CharacterEmbeddingTextEncoder() >>> import torch >>> torch.allclose(encoder("seal"), encoder("sale")) True """ def __init__( self, dim: int = 32, character_representation: HintOrType["Representation"] = None, vocabulary: str = string.printable, aggregation: Hint[Callable[..., FloatTensor]] = None, ) -> None: """Initialize the encoder. :param dim: the embedding dimension :param character_representation: the character representation or a hint thereof :param vocabulary: the vocabulary, i.e., the allowed characters :param aggregation: the aggregation to use to pool the character embeddings """ super().__init__() from .. import representation_resolver self.aggregation = aggregation_resolver.make(aggregation, dim=-2) self.vocabulary = vocabulary self.token_to_id = {c: i for i, c in enumerate(vocabulary)} num_real_tokens = len(self.vocabulary) self.unknown_idx = num_real_tokens self.padding_idx = num_real_tokens + 1 self.character_embedding = representation_resolver.make( character_representation, max_id=num_real_tokens + 2, shape=dim ) # docstr-coverage: inherited
[docs] def forward_normalized(self, texts: Sequence[str]) -> FloatTensor: # noqa: D102 # tokenize token_ids = [[self.token_to_id.get(c, self.unknown_idx) for c in text] for text in texts] # pad max_length = max(map(len, token_ids)) indices = torch.full(size=(len(texts), max_length), fill_value=self.padding_idx) for i, ids in enumerate(token_ids): indices[i, : len(ids)] = torch.as_tensor(ids, dtype=torch.long) # get character embeddings x = self.character_embedding(indices=indices) # pool x = self.aggregation(x) if not torch.is_tensor(x): x = x.values return x
[docs] class TransformerTextEncoder(TextEncoder): """A combination of a tokenizer and a model.""" def __init__( self, pretrained_model_name_or_path: str = "bert-base-cased", max_length: int = 512, ): """ Initialize the encoder using :class:`transformers.AutoModel`. :param pretrained_model_name_or_path: the name of the pretrained model, or a path, cf. :meth:`transformers.AutoModel.from_pretrained` :param max_length: >0, default: 512 the maximum number of tokens to pad/trim the labels to :raises ImportError: if the :mod:`transformers` library could not be imported """ super().__init__() try: from transformers import AutoModel, AutoTokenizer, PreTrainedTokenizer except ImportError as error: raise ImportError( "Please install the `transformers` library, use the _transformers_ extra" " for PyKEEN with `pip install pykeen[transformers] when installing, or " " see the PyKEEN installation docs at https://pykeen.readthedocs.io/en/stable/installation.html" " for more information." ) from error self.tokenizer: PreTrainedTokenizer = AutoTokenizer.from_pretrained( pretrained_model_name_or_path=pretrained_model_name_or_path ) self.model = AutoModel.from_pretrained(pretrained_model_name_or_path=pretrained_model_name_or_path).to( resolve_device() ) self.max_length = max_length or 512 # docstr-coverage: inherited
[docs] def forward_normalized(self, texts: Sequence[str]) -> FloatTensor: # noqa: D102 return self.model( **self.tokenizer( texts, return_tensors="pt", padding=True, truncation=True, max_length=self.max_length, ).to(get_preferred_device(self.model)) ).pooler_output
#: A resolver for text encoders. By default, can use 'characterembedding' #: for :class:`CharacterEmbeddingTextEncoder` or 'transformer' for #: :class:`TransformerTextEncoder`. text_encoder_resolver: ClassResolver[TextEncoder] = ClassResolver.from_subclasses( base=TextEncoder, default=CharacterEmbeddingTextEncoder, )