# Source code for pykeen.nn.utils
# -*- coding: utf-8 -*-
"""Utilities for neural network components."""
from typing import Optional, Sequence, Union
import torch
from more_itertools import chunked
from torch import nn
from tqdm.auto import tqdm
# Explicitly declare the public API of this module.
__all__ = [
"TransformerEncoder",
]
class TransformerEncoder(nn.Module):
    """A combination of a tokenizer and a model."""

    def __init__(
        self,
        pretrained_model_name_or_path: str,
        max_length: Optional[int] = None,
    ):
        """
        Initialize the encoder.

        :param pretrained_model_name_or_path:
            the name of the pretrained model, or a path, cf. :func:`transformers.AutoModel.from_pretrained`
        :param max_length: >0, default: 512
            the maximum number of tokens to pad/trim the labels to
        :raises ImportError:
            if the :mod:`transformers` library could not be imported
        """
        super().__init__()
        # Import lazily so that the heavy `transformers` dependency is only
        # required when this encoder is actually instantiated.
        try:
            from transformers import AutoModel, AutoTokenizer
        except ImportError as error:
            raise ImportError(
                "Please install the `transformers` library, use the _transformers_ extra"
                " for PyKEEN with `pip install pykeen[transformers]` when installing, or "
                " see the PyKEEN installation docs at https://pykeen.readthedocs.io/en/stable/installation.html"
                " for more information."
            ) from error
        self.tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path=pretrained_model_name_or_path)
        self.model = AutoModel.from_pretrained(pretrained_model_name_or_path=pretrained_model_name_or_path)
        # Fall back to the common transformer maximum of 512 tokens.
        self.max_length = max_length or 512

    def forward(self, labels: Union[str, Sequence[str]]) -> torch.FloatTensor:
        """Encode labels via the provided model and tokenizer.

        :param labels:
            a single label, or a sequence of labels
        :returns: shape: (len(labels), dim)
            the model's pooled output for each label
        """
        # Normalize a single string into a one-element batch.
        if isinstance(labels, str):
            labels = [labels]
        return self.model(
            **self.tokenizer(
                labels,
                return_tensors="pt",
                padding=True,
                truncation=True,
                max_length=self.max_length,
            )
        ).pooler_output

    @torch.inference_mode()
    def encode_all(
        self,
        labels: Sequence[str],
        batch_size: int = 1,
    ) -> torch.FloatTensor:
        """Encode all labels (inference mode & batched).

        :param labels:
            a sequence of strings to encode
        :param batch_size:
            the batch size to use for encoding the labels. ``batch_size=1``
            means that the labels are encoded one-by-one, while ``batch_size=len(labels)``
            would correspond to encoding all at once.
            Larger batch sizes increase memory requirements, but may be computationally
            more efficient.
        :returns: shape: (len(labels), dim)
            a tensor representing the encodings for all labels
        """
        # `tqdm` provides a progress bar; `chunked` yields successive batches.
        return torch.cat(
            [self(batch) for batch in chunked(tqdm(labels), batch_size)],
            dim=0,
        )