ConcatAggregationCombination

class ConcatAggregationCombination(aggregation: str | Callable[[Tensor], Tensor] | None = None, aggregation_kwargs: Mapping[str, Any] | None = None, dim: int = -1)

Bases: ConcatCombination

Combine representations by concatenating them along one dimension and then aggregating (reducing) along that same dimension.

Initialize the combination.

Parameters:
  • aggregation (str | Callable[[Tensor], Tensor] | None) – The aggregation function, or a hint for it (e.g., its name) that is resolved to a callable.

  • aggregation_kwargs (Mapping[str, Any] | None) – Additional keyword arguments used when resolving the aggregation.

  • dim (int) – The dimension along which the representations are concatenated and subsequently reduced.

Note

The parameter pair (aggregation, aggregation_kwargs) is resolved with class_resolver.contrib.torch.aggregation_resolver.

An explanation of resolvers and how to use them is given in https://class-resolver.readthedocs.io/en/latest/.
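The following is a minimal sketch of what resolving the (aggregation, aggregation_kwargs) pair amounts to. It does not use class_resolver: the lookup table _AGGREGATIONS and the function resolve_aggregation are hypothetical stand-ins, the names they accept are chosen for illustration only, and falling back to the mean when aggregation is None is an assumption of this sketch rather than the resolver's documented default. torch.amax is used instead of torch.max because torch.max(x, dim=...) returns a (values, indices) pair rather than a single tensor.

```python
import functools
from typing import Any, Callable, Mapping, Optional, Union

import torch

# Hypothetical lookup table standing in for aggregation_resolver;
# the real resolver may register different names.
_AGGREGATIONS: dict[str, Callable[..., torch.Tensor]] = {
    "sum": torch.sum,
    "mean": torch.mean,
    "amax": torch.amax,
    "logsumexp": torch.logsumexp,
}


def resolve_aggregation(
    aggregation: Union[str, Callable[[torch.Tensor], torch.Tensor], None] = None,
    aggregation_kwargs: Optional[Mapping[str, Any]] = None,
    dim: int = -1,
) -> Callable[[torch.Tensor], torch.Tensor]:
    """Turn an aggregation hint into a callable that reduces along ``dim``."""
    if callable(aggregation):
        # In this sketch a callable is used unchanged; it must reduce
        # along the intended dimension by itself.
        return aggregation
    # Assumption of this sketch: fall back to the mean when no hint is given.
    func = _AGGREGATIONS[(aggregation or "mean").lower()]
    # Bind the reduction dimension plus any extra keyword arguments.
    return functools.partial(func, dim=dim, **dict(aggregation_kwargs or {}))


x = torch.tensor([[1.0, 2.0, 3.0]])
print(resolve_aggregation("sum")(x))  # tensor([6.])
print(resolve_aggregation("amax", {"keepdim": True})(x))  # tensor([[3.]])
```

Binding dim at resolution time is what lets the aggregation reduce along exactly the axis used for concatenation.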

Methods Summary

forward(xs)

Combine a sequence of individual representations.

iter_extra_repr()

Iterate over the components of the extra_repr().

Methods Documentation

forward(xs: Sequence[Tensor]) → Tensor

Combine a sequence of individual representations.

Parameters:

xs (Sequence[Tensor]) – the individual representations, where the i-th one has shape (*batch_dims, *input_dims_i)

Returns:

a combined representation of shape (*batch_dims, *output_dims)

Return type:

Tensor
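As a rough illustration of the shapes involved, the sketch below reproduces the concatenate-then-reduce step with plain PyTorch operations. concat_aggregate is a hypothetical helper, not part of the documented API, and it assumes the aggregation is called with the same dim that was used for concatenation, which removes that dimension from the result. Giving each input a trailing dimension along which the inputs are stacked therefore yields an element-wise pooling over the inputs.

```python
from typing import Callable, Sequence

import torch


def concat_aggregate(
    xs: Sequence[torch.Tensor],
    reduce: Callable[..., torch.Tensor] = torch.mean,
    dim: int = -1,
) -> torch.Tensor:
    """Hypothetical helper: concatenate along ``dim``, then reduce along ``dim``."""
    # All inputs must agree on every dimension except ``dim``.
    stacked = torch.cat(list(xs), dim=dim)
    # Reducing along the concatenation dimension removes it from the output.
    return reduce(stacked, dim=dim)


# Two representations for a batch of 2 items with 4 features each,
# carrying 1 and 2 entries respectively along the last dimension.
a = torch.randn(2, 4, 1)
b = torch.randn(2, 4, 2)
out = concat_aggregate([a, b], reduce=torch.amax, dim=-1)
print(out.shape)  # torch.Size([2, 4])
```

Here each output entry is the maximum over the three values gathered from a and b at the same (batch, feature) position.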

iter_extra_repr() → Iterable[str]

Iterate over the components of the extra_repr().

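This method is typically overridden in subclasses. A common pattern is:

from collections.abc import Iterable

def iter_extra_repr(self) -> Iterable[str]:
    # Keep the components contributed by the parent classes ...
    yield from super().iter_extra_repr()
    # ... then append this class's own "key=value" components.
    yield "<key1>=<value1>"
    yield "<key2>=<value2>"

Returns: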

an iterable over individual components of the extra_repr()

Return type:

Iterable[str]
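To show how this pattern composes across a class hierarchy, the following stdlib-only sketch mimics it without PyTorch. ExtraReprSketch, ConcatSketch and ConcatAggregationSketch are hypothetical classes, and joining the components with ", " is an assumption about how extra_repr() is assembled from them. In a real torch.nn.Module, the string returned by extra_repr() is included in the module's repr().

```python
from collections.abc import Iterable


class ExtraReprSketch:
    """Hypothetical minimal mixin that builds extra_repr() from iter_extra_repr()."""

    def iter_extra_repr(self) -> Iterable[str]:
        # The base class contributes no components of its own.
        return iter(())

    def extra_repr(self) -> str:
        # Assumption: components are joined with a comma and a space.
        return ", ".join(self.iter_extra_repr())


class ConcatSketch(ExtraReprSketch):
    def __init__(self, dim: int = -1) -> None:
        self.dim = dim

    def iter_extra_repr(self) -> Iterable[str]:
        yield from super().iter_extra_repr()
        yield f"dim={self.dim}"


class ConcatAggregationSketch(ConcatSketch):
    def __init__(self, aggregation: str = "mean", dim: int = -1) -> None:
        super().__init__(dim=dim)
        self.aggregation = aggregation

    def iter_extra_repr(self) -> Iterable[str]:
        # Parent components come first, then this class's own.
        yield from super().iter_extra_repr()
        yield f"aggregation={self.aggregation}"


print(ConcatAggregationSketch("max").extra_repr())  # dim=-1, aggregation=max
```

Because every override starts with yield from super().iter_extra_repr(), each subclass only has to describe its own parameters.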