EmbeddingBagRepresentation

class EmbeddingBagRepresentation(assignment: Tensor, max_id: int | None = None, mode: Literal['sum', 'mean', 'max'] = 'mean', **kwargs: Any)

Bases: Representation

An embedding bag representation.

EmbeddingBag is similar to a TokenRepresentation followed by an aggregation along the num_tokens dimension.

Its main differences are:

  • It fuses the token look-up and aggregation steps into a single torch call.

  • It only allows a limited set of non-parametric aggregations: sum(), mean(), or max().

  • It handles a sparse or variable number of tokens per input more naturally.

  • It always uses an Embedding layer instead of permitting an arbitrary Representation.
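The fused look-up and aggregation can be illustrated directly with torch. The following is a minimal sketch, not this class's implementation: it compares a two-step Embedding look-up followed by a mean with a single torch.nn.EmbeddingBag call that shares the same weights. The names bags, two_step, and fused, as well as the sizes, are illustrative assumptions.

```python
import torch

torch.manual_seed(0)

num_tokens, dim = 10, 4  # illustrative sizes
embedding = torch.nn.Embedding(num_tokens, dim)
# share the weights so that both variants are comparable
bag = torch.nn.EmbeddingBag.from_pretrained(
    embedding.weight.detach().clone(), mode="mean"
)

# a variable number of tokens per input
bags = [[1, 2, 4], [5], [0, 9]]

# two-step variant: look up each bag, then aggregate over its tokens
two_step = torch.stack(
    [embedding(torch.as_tensor(b)).mean(dim=0) for b in bags]
).detach()

# fused variant: flat token indices plus the start offset of each bag
flat = torch.as_tensor([t for b in bags for t in b])
lengths = torch.as_tensor([len(b) for b in bags])
offsets = torch.cumsum(lengths, dim=0) - lengths
fused = bag(flat, offsets)
```

Both variants produce one vector per bag; the fused call avoids padding the bags to a common length.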

If you have a boolean feature vector, for example, from a chemical fingerprint, you can construct an embedding bag as follows:

features: torch.BoolTensor = ...

representation = EmbeddingBagRepresentation.from_iter(
    # indices of the non-zero entries of each feature vector, as plain ints
    feature.nonzero().flatten().tolist()
    for feature in features
)

Let \(nnz(i)\) denote the set of non-zero indices of the feature vector of molecule \(i\). For mode="sum", we then build the following representation \(\mathbf{x}_i\) (the default mode="mean" additionally divides by \(|nnz(i)|\)):

\[\mathbf{x}_i := \sum \limits_{j \in nnz(i)} \mathbf{y}_j\]

where \(\mathbf{y}_j\) is the embedding for the substructure represented by dimension \(j\) of the feature vector. In a sense, this is very similar to multiplying the 0/1 feature vectors with an embedding matrix; it is just implemented more efficiently by exploiting the sparsity.
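The relation to a dense matrix product can be checked numerically. The sketch below uses plain torch rather than this class, and assumes sum aggregation; features, weight, dense, and sparse are illustrative names.

```python
import torch

torch.manual_seed(0)

# three toy "fingerprints" with five substructure dimensions each
features = torch.tensor(
    [
        [1, 0, 1, 0, 0],
        [0, 0, 0, 1, 0],
        [1, 1, 0, 0, 1],
    ],
    dtype=torch.bool,
)
# one embedding y_j per feature dimension j
weight = torch.randn(features.shape[1], 3)

# dense variant: 0/1 matrix times embedding matrix
dense = features.float() @ weight

# sparse variant: only touch the embeddings of the non-zero dimensions
_, cols = features.nonzero(as_tuple=True)  # row-major order, grouped by row
counts = features.sum(dim=1)
offsets = torch.cumsum(counts, dim=0) - counts  # start of each row's bag
sparse = torch.nn.functional.embedding_bag(cols, weight, offsets, mode="sum")
```

For sparse fingerprints, the second variant reads only as many embedding rows as there are non-zero entries.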

Initialize the representation.

Parameters:
  • assignment (Tensor) – shape: (nnz, 2) The assignment between indices and tokens, in edge-list format. assignment[:, 0] denotes the indices for the representation, assignment[:, 1] the index of the token.

  • max_id (int | None) – The maximum ID (exclusive). Valid IDs range from 0 to max_id - 1. If None, it is inferred from the assignment tensor.

  • mode (Literal['sum', 'mean', 'max']) – The aggregation mode for EmbeddingBag.

  • kwargs (Any) – Additional keyword-based parameters passed to Representation.
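To make the edge-list format of assignment concrete, here is a small sketch in plain torch that aggregates token embeddings per index with mean semantics. It only illustrates the documented format and is not the class's actual code path; in particular, inferring max_id as the largest index in column 0 plus one is an assumption, and weight is a stand-in for the token embedding matrix.

```python
import torch

torch.manual_seed(0)

# edge list: column 0 is the representation index, column 1 the token index
assignment = torch.as_tensor(
    [[0, 1], [0, 2], [0, 4], [1, 5], [2, 0], [2, 9]]
)
index, token = assignment[:, 0], assignment[:, 1]

# assumption: the smallest max_id that covers all indices in column 0
max_id = int(index.max()) + 1

# stand-in token embeddings (10 tokens, dimension 4)
weight = torch.randn(10, 4)

# sum the token embeddings per index, then divide by the bag sizes
summed = torch.zeros(max_id, weight.shape[1]).index_add_(0, index, weight[token])
counts = torch.bincount(index, minlength=max_id).clamp(min=1).unsqueeze(-1)
mean = summed / counts
```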

Methods Summary

from_iter(xss, **kwargs)

Instantiate from an iterable of indices.

Methods Documentation

classmethod from_iter(xss: Iterable[Iterable[int]], **kwargs: Any) → Self

Instantiate from an iterable of indices.

Parameters:
  • xss (Iterable[Iterable[int]]) – An iterable over the indices, where each element is an iterable over the token indices for the given index.

  • kwargs (Any) – Additional keyword-based parameters passed to __init__().

Returns:

A corresponding representation.

Return type:

Self
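The relation between the nested iterable accepted by from_iter() and the edge-list format expected by assignment can be sketched in pure Python. The helper to_edge_list below is hypothetical and only shows one plausible conversion, assuming the position of each inner iterable serves as the representation index.

```python
from collections.abc import Iterable


def to_edge_list(xss: Iterable[Iterable[int]]) -> list[tuple[int, int]]:
    """Flatten nested token indices into (index, token) pairs.

    Hypothetical helper for illustration; not part of the library.
    """
    return [(i, x) for i, xs in enumerate(xss) for x in xs]


pairs = to_edge_list([[1, 2, 4], [5], [0, 9]])
```

Each pair corresponds to one row of a (nnz, 2) assignment tensor.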