AttentionEdgeWeighting

class AttentionEdgeWeighting(message_dim, num_heads=8, dropout=0.1)[source]

Bases: EdgeWeighting

Message weighting by attention.

Initialize the module.

Parameters:
  • message_dim (int) – >0 the message dimension; has to be divisible by num_heads

  • num_heads (int) – >0 the number of attention heads

  • dropout (float) – the dropout rate used by the attention mechanism

Raises:

ValueError – If message_dim is not divisible by num_heads
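
A minimal construction sketch. The import path (pykeen.nn.weighting) is an assumption and may need to be adjusted for your installation:

    from pykeen.nn.weighting import AttentionEdgeWeighting

    # message_dim must be divisible by num_heads: 64 / 8 = 8 dimensions per head
    weighting = AttentionEdgeWeighting(message_dim=64, num_heads=8, dropout=0.1)

    # this combination would raise a ValueError, since 60 is not divisible by 8
    # AttentionEdgeWeighting(message_dim=60, num_heads=8)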

Attributes Summary

needs_message

Whether the edge weighting needs access to the message.

Methods Summary

forward(source, target[, message, x_e])

Compute edge weights.

Attributes Documentation

needs_message: ClassVar[bool] = True

Whether the edge weighting needs access to the message.

Methods Documentation

forward(source, target, message=None, x_e=None)[source]

Compute edge weights.

Parameters:
  • source (LongTensor) – shape: (num_edges,) The source indices.

  • target (LongTensor) – shape: (num_edges,) The target indices.

  • message (Optional[FloatTensor]) – shape: (num_edges, dim) The actual messages to weight. Required here, since needs_message is True.

  • x_e (Optional[FloatTensor]) – shape: (num_nodes, dim) The node states up to the weighting point.

Return type:

FloatTensor

Returns:

shape: (num_edges, dim) Messages weighted with the edge weights.
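
A hedged end-to-end sketch of calling forward() on random data. All sizes and values are illustrative, and the import path is an assumption as above:

    import torch

    from pykeen.nn.weighting import AttentionEdgeWeighting

    num_nodes, num_edges, dim = 10, 30, 64
    weighting = AttentionEdgeWeighting(message_dim=dim, num_heads=8)

    source = torch.randint(num_nodes, size=(num_edges,))  # source node index per edge
    target = torch.randint(num_nodes, size=(num_edges,))  # target node index per edge
    message = torch.rand(num_edges, dim)                  # one message per edge
    x_e = torch.rand(num_nodes, dim)                      # node states before weighting

    # needs_message is True, so the actual messages have to be passed
    weighted = weighting(source, target, message=message, x_e=x_e)
    assert weighted.shape == (num_edges, dim)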