g3d_stack

Module to wrap a stack of G3DLayer.

class G3DStack(g3d_class: G3DLayer, n_layers: int, energy_readout_every: int | None = None, **g3d_kwargs)

Module to wrap a stack of G3D layers.

__init__(g3d_class: G3DLayer, n_layers: int, energy_readout_every: int | None = None, **g3d_kwargs) -> None

Initializes a G3DStack made of n_layers G3D layers.

Parameters:
  • g3d_class (G3DLayer) – The G3D layer class used to build each layer of the stack.

  • n_layers (int) – Number of G3D layers.

  • energy_readout_every (int, optional) – How often, in layers, the energy is read out. If None, it defaults to n_layers, so only the output of the last layer is read out. (default: None)

  • **g3d_kwargs – Keyword arguments passed to G3DLayer.
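The readout schedule implied by energy_readout_every can be sketched as below. This is a hypothetical helper, not part of the module: it assumes a readout happens after every energy_readout_every-th layer (counting from 1) and that None means n_layers. The documentation does not say what happens when n_layers is not a multiple of energy_readout_every; this sketch simply skips the trailing remainder.

```python
from __future__ import annotations


def readout_layers(n_layers: int, energy_readout_every: int | None = None) -> list[int]:
    """Return the 1-based indices of layers whose output is read out.

    Hypothetical helper (assumption): a readout occurs after every
    `energy_readout_every`-th layer, and None is treated as `n_layers`.
    """
    if n_layers < 1:
        raise ValueError("n_layers must be positive")
    every = n_layers if energy_readout_every is None else energy_readout_every
    if every < 1:
        raise ValueError("energy_readout_every must be positive")
    return [i for i in range(1, n_layers + 1) if i % every == 0]


print(readout_layers(6))     # [6]  (None: only the last layer)
print(readout_layers(6, 2))  # [2, 4, 6]
print(readout_layers(6, 3))  # [3, 6]
```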

forward(x: Tensor, edge_index: Tensor, batch: Tensor, lframes: LFrames = None, edge_attr: Tensor = None, length: Tensor = None) -> Tensor

Computes the forward pass of the module by applying each G3D layer in turn.

Parameters:
  • x (Tensor) – The input node features.

  • edge_index (Tensor) – The edge indices.

  • batch (Tensor) – Batch vector assigning each node to the graph it belongs to.

  • lframes (LFrames, optional) – The LFrames object containing the local frames. (default: None)

  • edge_attr (Tensor, optional) – The edge features. (default: None)

  • length (Tensor, optional) – The edge lengths, ordered as in edge_index. (default: None)
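The graph inputs above can be illustrated with plain Python lists. This sketch assumes the PyTorch Geometric convention (edge_index as two rows of source and target node indices, batch as one graph index per node), which the documentation does not state explicitly; it also assumes length is the Euclidean distance between edge endpoints. The helper names batch_graphs and edge_lengths are hypothetical.

```python
from __future__ import annotations

import math


def batch_graphs(graphs: list[tuple[int, list[tuple[int, int]]]]) -> tuple[list[list[int]], list[int]]:
    """Merge graphs given as (num_nodes, edge list) into one disjoint graph.

    Returns edge_index as [sources, targets] and a batch vector that maps
    each node to the index of the graph it came from.
    """
    sources: list[int] = []
    targets: list[int] = []
    batch: list[int] = []
    offset = 0
    for graph_id, (num_nodes, edges) in enumerate(graphs):
        for src, dst in edges:
            # Shift local node ids so they index into the merged node list.
            sources.append(src + offset)
            targets.append(dst + offset)
        batch.extend([graph_id] * num_nodes)
        offset += num_nodes
    return [sources, targets], batch


def edge_lengths(pos: list[tuple[float, float, float]], edge_index: list[list[int]]) -> list[float]:
    """Euclidean length of each edge, in the same order as edge_index (assumption)."""
    return [math.dist(pos[s], pos[t]) for s, t in zip(edge_index[0], edge_index[1])]


edge_index, batch = batch_graphs([(2, [(0, 1), (1, 0)]), (3, [(0, 2)])])
pos = [(0, 0, 0), (3, 4, 0), (0, 0, 0), (1, 0, 0), (0, 0, 2)]
lengths = edge_lengths(pos, edge_index)
print(edge_index)  # [[0, 1, 2], [1, 0, 4]]
print(batch)       # [0, 0, 1, 1, 1]
print(lengths)     # [5.0, 5.0, 2.0]
```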

Returns:

The output node features.

Return type:

Tensor
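The stacking pattern described above (n_layers instances of one layer class, each built with the same keyword arguments and applied in order) can be sketched in plain Python. ToyLayer and ToyStack are hypothetical stand-ins, not the module's implementation: node features are a list of floats, and the graph inputs (edge_index, batch, lframes, edge_attr, length) are omitted because the documentation does not say how G3DStack passes them to each layer.

```python
from __future__ import annotations

from typing import Any


class ToyLayer:
    """Hypothetical stand-in for a G3DLayer: scales and shifts node features."""

    def __init__(self, scale: float = 1.0, shift: float = 0.0) -> None:
        self.scale = scale
        self.shift = shift

    def __call__(self, x: list[float]) -> list[float]:
        return [self.scale * v + self.shift for v in x]


class ToyStack:
    """Hypothetical sketch: n_layers separate instances of one layer class,
    all built with the same keyword arguments and applied in order."""

    def __init__(self, layer_class: type, n_layers: int, **layer_kwargs: Any) -> None:
        # Each iteration creates a new, independent layer instance.
        self.layers = [layer_class(**layer_kwargs) for _ in range(n_layers)]

    def __call__(self, x: list[float]) -> list[float]:
        # Each layer consumes the previous layer's output node features.
        for layer in self.layers:
            x = layer(x)
        return x


stack = ToyStack(ToyLayer, n_layers=3, scale=2.0, shift=1.0)
print(len(stack.layers))  # 3
print(stack([0.0, 1.0]))  # [7.0, 15.0]
```

Building the layers in a loop (rather than reusing one instance) gives each layer its own parameters; in a PyTorch module the same list would typically be held in a torch.nn.ModuleList so the parameters are registered.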