# -*- coding: utf-8 -*-

"""Utilities for neural network components."""

import logging
from typing import Iterable, Optional, Sequence, Union

import torch
from more_itertools import chunked
from torch import nn
from torch_max_mem import MemoryUtilizationMaximizer
from tqdm.auto import tqdm

from ..utils import get_preferred_device, resolve_device, upgrade_to_sequence

__all__ = [
    "TransformerEncoder",
    "safe_diagonal",
]

logger = logging.getLogger(__name__)

memory_utilization_maximizer = MemoryUtilizationMaximizer()


@memory_utilization_maximizer
def _encode_all_memory_utilization_optimized(
encoder: "TransformerEncoder",
labels: Sequence[str],
batch_size: int,
) -> torch.Tensor:
"""
Encode all labels with the given batch-size.
Wrapped by memory utilization maximizer to automatically reduce the batch size if needed.
:param encoder:
the encoder
:param labels:
the labels to encode
:param batch_size:
the batch size to use. Will automatically be reduced if necessary.
:return: shape: `(len(labels), dim)`
the encoded labels
"""
return torch.cat(
[encoder(batch) for batch in chunked(tqdm(map(str, labels), leave=False), batch_size)],
dim=0,
)
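

# Hedged usage note: thanks to the wrapper, callers may pass an optimistic initial
# ``batch_size`` (e.g. ``len(labels)``); on a CUDA out-of-memory error, ``torch_max_mem``
# retries with a smaller batch size and re-uses the largest working value for
# subsequent calls. ``more_itertools.chunked`` yields successive fixed-size slices,
# so the concatenation preserves the input order.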


class TransformerEncoder(nn.Module):
    """A combination of a tokenizer and a model."""

def __init__(
self,
pretrained_model_name_or_path: str,
max_length: Optional[int] = None,
):
"""
Initialize the encoder.
:param pretrained_model_name_or_path:
the name of the pretrained model, or a path, cf. :func:`transformers.AutoModel.from_pretrained`
:param max_length: >0, default: 512
the maximum number of tokens to pad/trim the labels to
:raises ImportError:
if the :mod:`transformers` library could not be imported
"""
super().__init__()
try:
from transformers import AutoModel, AutoTokenizer, PreTrainedTokenizer
except ImportError as error:
            raise ImportError(
                "Please install the `transformers` library, use the `transformers` extra"
                " for PyKEEN with `pip install pykeen[transformers]` when installing, or"
                " see the PyKEEN installation docs at https://pykeen.readthedocs.io/en/stable/installation.html"
                " for more information."
            ) from error
self.tokenizer: PreTrainedTokenizer = AutoTokenizer.from_pretrained(
pretrained_model_name_or_path=pretrained_model_name_or_path
)
self.model = AutoModel.from_pretrained(pretrained_model_name_or_path=pretrained_model_name_or_path).to(
resolve_device()
)
self.max_length = max_length or 512

    def forward(self, labels: Union[str, Sequence[str]]) -> torch.FloatTensor:
"""Encode labels via the provided model and tokenizer."""
labels = upgrade_to_sequence(labels)
labels = list(map(str, labels))
return self.model(
**self.tokenizer(
labels,
return_tensors="pt",
padding=True,
truncation=True,
max_length=self.max_length,
).to(get_preferred_device(self.model))
).pooler_output

    @torch.inference_mode()
def encode_all(
self,
labels: Sequence[str],
batch_size: Optional[int] = None,
) -> torch.FloatTensor:
"""Encode all labels (inference mode & batched).
:param labels:
a sequence of strings to encode
:param batch_size:
the batch size to use for encoding the labels. ``batch_size=1``
means that the labels are encoded one-by-one, while ``batch_size=len(labels)``
would correspond to encoding all at once.
            Larger batch sizes increase memory requirements, but may be computationally
            more efficient. ``batch_size`` can also be set to ``None`` to enable automatic
            batch size maximization for the employed hardware.
        :return: shape: `(len(labels), dim)`
            a tensor representing the encodings for all labels
"""
return _encode_all_memory_utilization_optimized(
encoder=self, labels=labels, batch_size=batch_size or len(labels)
).detach()
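

# Hedged usage sketch (illustration only, not part of the public API): the model name
# below is an assumption; any identifier accepted by
# :func:`transformers.AutoModel.from_pretrained` should work.
def _example_transformer_encoder() -> torch.FloatTensor:  # pragma: no cover
    """Encode a handful of labels with a small pretrained transformer."""
    encoder = TransformerEncoder(pretrained_model_name_or_path="bert-base-cased", max_length=32)
    # encode_all runs in inference mode and batches the labels internally
    return encoder.encode_all(labels=["first entity", "second entity", "third entity"], batch_size=2)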


def iter_matrix_power(matrix: torch.Tensor, max_iter: int) -> Iterable[torch.Tensor]:
    """
    Iterate over matrix powers.

:param matrix: shape: `(n, n)`
the square matrix
:param max_iter:
the maximum number of iterations.
:yields: increasing matrix powers
"""
yield matrix
a = matrix
for _ in range(max_iter - 1):
# if the sparsity becomes too low, convert to a dense matrix
# note: this heuristic is based on the memory consumption,
# for a sparse matrix, we store 3 values per nnz (row index, column index, value)
# performance-wise, it likely makes sense to switch even earlier
# `torch.sparse.mm` can also deal with dense 2nd argument
if a.is_sparse and a._nnz() >= a.numel() // 4:
a = a.to_dense()
# note: torch.sparse.mm only works for COO matrices;
# @ only works for CSR matrices
if matrix.is_sparse_csr:
a = matrix @ a
else:
a = torch.sparse.mm(matrix, a)
yield a
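

# Hedged usage sketch: the adjacency matrix below is made up for illustration.
def _example_iter_matrix_power() -> torch.Tensor:  # pragma: no cover
    """Accumulate the first three powers of a small sparse adjacency matrix."""
    adjacency = torch.sparse_coo_tensor(
        indices=torch.tensor([[0, 1, 2], [1, 2, 0]]),
        values=torch.ones(3),
        size=(3, 3),
    ).coalesce()
    # sum A + A^2 + A^3, densifying any still-sparse powers before summing
    return sum(p.to_dense() if p.is_sparse else p for p in iter_matrix_power(matrix=adjacency, max_iter=3))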


def safe_diagonal(matrix: torch.Tensor) -> torch.Tensor:
"""
Extract diagonal from a potentially sparse matrix.
.. note ::
this is a work-around as long as `torch.diagonal` does not work for sparse tensors
:param matrix: shape: `(n, n)`
the matrix
:return: shape: `(n,)`
the diagonal values.
"""
if not matrix.is_sparse:
return torch.diagonal(matrix)
# convert to COO, if necessary
if matrix.is_sparse_csr:
matrix = matrix.to_sparse_coo()
n = matrix.shape[0]
# we need to use indices here, since there may be zero diagonal entries
indices = matrix._indices()
mask = indices[0] == indices[1]
diagonal_values = matrix._values()[mask]
diagonal_indices = indices[0][mask]
return torch.zeros(n, device=matrix.device).scatter_add(dim=0, index=diagonal_indices, src=diagonal_values)
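

# Hedged usage sketch: the values are made up for illustration; note that the input
# needs to be coalesced, since ``_indices`` is used internally.
def _example_safe_diagonal() -> torch.Tensor:  # pragma: no cover
    """Extract the diagonal of a small sparse COO matrix with a zero diagonal entry."""
    matrix = torch.sparse_coo_tensor(
        indices=torch.tensor([[0, 1, 1], [0, 1, 2]]),
        values=torch.tensor([2.0, 3.0, 5.0]),
        size=(3, 3),
    ).coalesce()
    # entries: (0, 0)=2, (1, 1)=3, (1, 2)=5 -> the diagonal is [2., 3., 0.]
    return safe_diagonal(matrix=matrix)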