EmbeddingBagRepresentation
- class EmbeddingBagRepresentation(assignment: Tensor, max_id: int | None = None, mode: Literal['sum', 'mean', 'max'] = 'mean', **kwargs: Any)
Bases: Representation
An embedding bag representation.
EmbeddingBag is similar to a TokenRepresentation followed by an aggregation along the num_tokens dimension. Its main differences are:
- It fuses the token look-up and the aggregation step into a single torch call.
- It only allows a limited set of non-parametric aggregations: sum(), mean(), or max().
- It handles a sparse or variable number of tokens per input more naturally.
- It always uses an Embedding layer instead of permitting an arbitrary Representation.
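The first and third differences can be illustrated with plain PyTorch. The following is a minimal sketch, not this class's implementation: it uses `torch.nn.Embedding` and `torch.nn.EmbeddingBag` directly, and the weight matrix and token indices are made-up values chosen for illustration.

```python
import torch

torch.manual_seed(0)
num_tokens_total, dim = 10, 4
weight = torch.randn(num_tokens_total, dim)  # illustrative token embeddings

# Two-step variant: look up token embeddings, then aggregate along num_tokens.
embedding = torch.nn.Embedding.from_pretrained(weight)
token_ids = torch.tensor([[1, 2, 4], [0, 3, 9]])  # shape: (batch, num_tokens)
two_step = embedding(token_ids).mean(dim=1)

# Fused variant: a single call performs look-up and aggregation.
bag = torch.nn.EmbeddingBag.from_pretrained(weight, mode="mean")
fused = bag(token_ids)

# Variable number of tokens per input: flat indices plus bag start offsets.
flat = torch.tensor([1, 2, 4, 0, 3])
offsets = torch.tensor([0, 3])  # bag 0 = [1, 2, 4], bag 1 = [0, 3]
ragged = bag(flat, offsets)
```

With a fixed number of tokens per input, both variants give the same result; the offsets form additionally covers inputs whose token counts differ, without padding.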
If you have a boolean feature vector, for example from a chemical fingerprint, you can construct an embedding bag with the following:
    features: torch.BoolTensor = ...
    representation = EmbeddingBagRepresentation.from_iter(
        list(feature.nonzero()) for feature in features
    )
Let \(nnz(i)\) denote the non-zero indices of the feature vector of molecule \(i\). We then build the following representation \(\mathbf{x}_i\):
\[\mathbf{x}_i := \sum \limits_{j \in nnz(i)} \mathbf{y}_j\]
where \(\mathbf{y}_j\) is the embedding for the substructure represented by dimension \(j\) in the signature. The sum shown here corresponds to the 'sum' mode; the 'mean' and 'max' modes replace it with the respective aggregation. In a sense, this is very similar to multiplying the 0/1 vectors with a matrix; it is just implemented more efficiently by exploiting the sparsity.
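The equivalence between the sparse sum and the dense 0/1 matrix product can be checked numerically. This is a sketch in plain PyTorch under stated assumptions: the matrix `y` and the boolean `features` are made-up values, and `torch.nn.EmbeddingBag` with `mode="sum"` stands in for the representation described here.

```python
import torch

torch.manual_seed(0)
num_features, dim = 6, 3
y = torch.randn(num_features, dim)  # y[j]: embedding of feature dimension j

# Illustrative boolean fingerprints for two molecules.
features = torch.tensor([
    [True, False, True, False, False, True],
    [False, True, False, False, False, False],
])

# Dense variant: multiply the 0/1 vectors with the embedding matrix.
dense = features.float() @ y

# Sparse variant: sum only the embeddings at the non-zero positions.
bag = torch.nn.EmbeddingBag.from_pretrained(y, mode="sum")
rows, cols = features.nonzero(as_tuple=True)  # ordered by row, then column
counts = features.sum(dim=1)  # number of non-zero entries per molecule
offsets = torch.cumsum(counts, dim=0) - counts  # start of each bag in cols
sparse = bag(cols, offsets)
```

Both tensors agree up to floating-point error, but the sparse variant only touches the rows of `y` selected by \(nnz(i)\).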
Initialize the representation.
- Parameters:
assignment (Tensor) – shape: (nnz, 2). The assignment between indices and tokens, in edge-list format. assignment[:, 0] denotes the indices for the representation, assignment[:, 1] the index of the token.
max_id (int) – The maximum ID (exclusive). Valid IDs range from 0 to max_id-1. Can be None to infer it from the assignment tensor.
mode (Literal['sum', 'mean', 'max']) – The aggregation mode for EmbeddingBag.
kwargs (Any) – Additional keyword-based parameters passed to Representation.
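To make the edge-list format of `assignment` concrete, the sketch below builds such a tensor from per-index token lists. The `bags` data is hypothetical, and the rule used to infer `max_id` (largest representation index plus one) is an assumption about what "infer it from the assignment tensor" means, not something confirmed by the documentation above.

```python
import torch

# Hypothetical token indices per representation index (variable length).
bags = [[1, 2, 4], [0, 3], [4]]

# Edge-list format: one (representation index, token index) row per entry.
assignment = torch.tensor(
    [(i, token) for i, tokens in enumerate(bags) for token in tokens]
)  # shape: (nnz, 2)

# Assumed inference rule when max_id is None: largest representation index + 1.
inferred_max_id = int(assignment[:, 0].max()) + 1

# Number of tokens assigned to each representation index.
tokens_per_id = torch.bincount(assignment[:, 0], minlength=inferred_max_id)
```

Here `assignment[:, 0]` holds the representation indices and `assignment[:, 1]` the token indices, matching the parameter description.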
Methods Summary
from_iter(xss, **kwargs) – Instantiate from an iterable of indices.
Methods Documentation