EdgeWeighting

class EdgeWeighting(**kwargs)

Bases: Module

Base class for edge weightings.

Initialize the module.

Parameters:

kwargs – Ignored keyword arguments.
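To make the interface concrete, here is a minimal sketch of a stand-in that mirrors the signature documented on this page. It is written against plain PyTorch so that it runs without the library. The names EdgeWeightingSketch and UniformWeighting are hypothetical. The ABC mixin is an addition of this sketch, since the documented class lists only Module as its base. The sketch illustrates the constructor behaviour: extra keyword arguments are accepted and discarded.

```python
from abc import ABC, abstractmethod
from typing import ClassVar, Optional

import torch
from torch import nn


class EdgeWeightingSketch(nn.Module, ABC):
    """Hypothetical stand-in mirroring the documented EdgeWeighting interface."""

    # Subclasses set this to True if forward() must receive the messages.
    needs_message: ClassVar[bool] = False

    def __init__(self, **kwargs):
        # Keyword arguments are accepted and ignored, so subclasses with
        # different parameters can all be built from the same call site.
        super().__init__()

    @abstractmethod
    def forward(
        self,
        source: torch.Tensor,
        target: torch.Tensor,
        message: Optional[torch.Tensor] = None,
        x_e: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        raise NotImplementedError


class UniformWeighting(EdgeWeightingSketch):
    """Hypothetical subclass that gives every message weight 1."""

    needs_message: ClassVar[bool] = True

    def forward(self, source, target, message=None, x_e=None):
        if message is None:
            raise ValueError("UniformWeighting expects the messages")
        # Weight 1 for every edge: the messages pass through unchanged.
        return message


# Extra keyword arguments such as dim or unused_flag are silently ignored.
w = UniformWeighting(dim=32, unused_flag=True)
```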

Attributes Summary

needs_message

Whether the edge weighting needs access to the message.

Methods Summary

forward(source, target[, message, x_e])

Compute edge weights.

Attributes Documentation

needs_message: ClassVar[bool] = False

Whether the edge weighting needs access to the message.
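The flag lets calling code decide whether the messages have to be handed to forward. The following is a minimal sketch built on a hypothetical helper, apply_weighting, and two hypothetical weighting classes. It shows only a dispatch on needs_message and does not describe how the library itself consumes the flag.

```python
from typing import Any, ClassVar, Dict


class IndexOnlyWeighting:
    """Hypothetical weighting that works from the edge indices alone."""

    needs_message: ClassVar[bool] = False

    def __call__(self, source, target, message=None, x_e=None) -> Dict[str, Any]:
        return {"saw_message": message is not None}


class MessageAwareWeighting:
    """Hypothetical weighting (e.g. attention-style) that reads the messages."""

    needs_message: ClassVar[bool] = True

    def __call__(self, source, target, message=None, x_e=None) -> Dict[str, Any]:
        if message is None:
            raise ValueError("this weighting requires the messages")
        return {"saw_message": True}


def apply_weighting(weighting, source, target, message, x_e=None) -> Dict[str, Any]:
    # Pass the messages only when the weighting declares that it needs them;
    # otherwise call it with the indices (and node states) alone.
    if weighting.needs_message:
        return weighting(source, target, message=message, x_e=x_e)
    return weighting(source, target, x_e=x_e)
```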

Methods Documentation

abstract forward(source: Tensor, target: Tensor, message: Tensor | None = None, x_e: Tensor | None = None) → Tensor

Compute edge weights.

Parameters:
  • source (Tensor) – shape: (num_edges,) The source indices.

  • target (Tensor) – shape: (num_edges,) The target indices.

  • message (Tensor | None) – shape: (num_edges, dim) The messages to weight.

  • x_e (Tensor | None) – shape: (num_nodes, dim) The node states up to the weighting point.

Returns:

shape: (num_edges, dim) The messages, weighted with the edge weights.

Return type:

Tensor
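To illustrate what a concrete forward might compute, here is a hypothetical weighting that scales each message by the inverse in-degree of its target node. It subclasses torch.nn.Module directly so that it runs without the library; in real code it would subclass EdgeWeighting. The sketch returns the bare per-edge weights, with shape (num_edges,), when no message is given. That behaviour is an assumption of the sketch and is not part of the contract documented above.

```python
from typing import Optional

import torch
from torch import nn


class InverseInDegreeSketch(nn.Module):
    """Hypothetical weighting: scale each message by 1 / in-degree of its target."""

    needs_message = False  # the weights depend on the indices only

    def __init__(self, **kwargs):
        super().__init__()  # extra kwargs are ignored, as in the base class

    def forward(
        self,
        source: torch.Tensor,
        target: torch.Tensor,
        message: Optional[torch.Tensor] = None,
        x_e: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        # counts[k]: number of edges pointing at the k-th unique target node.
        # inverse[i]: position of edge i's target within the unique targets.
        _, inverse, counts = torch.unique(target, return_inverse=True, return_counts=True)
        weight = counts[inverse].float().reciprocal()  # shape: (num_edges,)
        if message is None:
            # Assumption of this sketch: without messages, return the raw weights.
            return weight
        # Broadcast (num_edges, 1) against (num_edges, dim).
        return message * weight.unsqueeze(-1)
```

Consider the target indices [1, 1, 2, 1]. Node 1 receives three edges and node 2 receives one, so the per-edge weights are [1/3, 1/3, 1, 1/3]. Each row of the message tensor is scaled by its edge's weight.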