CombinedRepresentation

class CombinedRepresentation(max_id: int, shape: int | Sequence[int] | None = None, base: str | Representation | type[Representation] | None | Sequence[str | Representation | type[Representation] | None] = None, base_kwargs: Mapping[str, Any] | None | Sequence[Mapping[str, Any] | None] = None, combination: str | Combination | type[Combination] | None = None, combination_kwargs: Mapping[str, Any] | None = None, **kwargs)

Bases: Representation

A combined representation.

Initialize the representation.

Parameters:
  • max_id (int) – the number of representations.

  • shape (int | Sequence[int] | None) – the shape of an individual representation

  • base (str | Representation | type[Representation] | None, or a sequence thereof) – the base representations, or hints thereof

  • base_kwargs (OneOrManyOptionalKwargs) – keyword-based parameters for the instantiation of base representations

  • combination (str | Combination | type[Combination] | None) – the combination, or a hint thereof

  • combination_kwargs (OptionalKwargs) – additional keyword-based parameters used to instantiate the combination

  • kwargs – additional keyword-based parameters passed to Representation.__init__. May not contain any of {max_id, shape, unique}.

Raises:

ValueError – if the max_id values of the base representations do not match
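The consistency requirement behind this ValueError can be illustrated with a minimal, library-free sketch. ToyRepresentation and check_max_ids are hypothetical names introduced only for this illustration; they are not part of PyKEEN, and the actual check inside CombinedRepresentation may differ in detail (for example in its error message or in how it treats an empty sequence).

```python
from dataclasses import dataclass
from typing import Sequence


@dataclass
class ToyRepresentation:
    """Hypothetical stand-in for a base representation (not a PyKEEN class)."""

    max_id: int  # number of representations this base can look up
    dim: int  # size of each individual representation


def check_max_ids(base: Sequence[ToyRepresentation]) -> int:
    """Return the max_id shared by all bases, or raise ValueError.

    Assumption: an empty sequence is also rejected, since there is
    no shared max_id to return in that case.
    """
    max_ids = sorted({b.max_id for b in base})
    if len(max_ids) != 1:
        raise ValueError(f"max_id of the base representations does not match: {max_ids}")
    return max_ids[0]


# Both bases cover 5 IDs, so they can be combined row by row.
shared = check_max_ids([ToyRepresentation(max_id=5, dim=2), ToyRepresentation(max_id=5, dim=3)])
```

The idea is that each index must resolve to one row in every base, so bases that cover a different number of IDs cannot be combined.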

Methods Summary

combine(combination, base[, indices])

Combine base representations for the given indices.

Methods Documentation

static combine(combination: Module, base: Sequence[Representation], indices: Tensor | None = None) → Tensor

Combine base representations for the given indices.

Parameters:
  • combination (Module) – the combination

  • base (Sequence[Representation]) – the base representations

  • indices (Tensor | None) – the indices, as given to Representation._plain_forward()

Returns:

the combined representations for the given indices

Return type:

Tensor
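The data flow described for combine can be sketched without any dependencies, using nested lists in place of tensors. Everything below is a toy under stated assumptions: ToyLookup, concat_combination, and this combine function are hypothetical and are not PyKEEN's implementation. In particular, it is assumed here that indices=None selects all rows, and that the combination receives the per-base results as one sequence.

```python
from typing import Callable, List, Optional, Sequence

Vector = List[float]


class ToyLookup:
    """Hypothetical base representation: a table with one vector per ID."""

    def __init__(self, rows: List[Vector]) -> None:
        self.rows = rows

    def plain_forward(self, indices: Optional[Sequence[int]] = None) -> List[Vector]:
        # Assumption: None means "all IDs, in order".
        if indices is None:
            return list(self.rows)
        return [self.rows[i] for i in indices]


def concat_combination(xs: Sequence[List[Vector]]) -> List[Vector]:
    """Toy combination: concatenate the vectors of all bases, row by row."""
    return [sum(parts, []) for parts in zip(*xs)]


def combine(
    combination: Callable[[Sequence[List[Vector]]], List[Vector]],
    base: Sequence[ToyLookup],
    indices: Optional[Sequence[int]] = None,
) -> List[Vector]:
    """Look up the same indices in every base, then merge the results."""
    return combination([b.plain_forward(indices) for b in base])


a = ToyLookup([[1.0], [2.0], [3.0]])
b = ToyLookup([[10.0, 11.0], [20.0, 21.0], [30.0, 31.0]])
result = combine(concat_combination, [a, b], indices=[2, 0])
# result == [[3.0, 30.0, 31.0], [1.0, 10.0, 11.0]]
```

Because the function takes the combination and the bases as explicit arguments and keeps no state of its own, it fits the static method signature documented above.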