g3d_stack
Module to wrap a stack of G3DLayer.
- class G3DStack(g3d_class: G3DLayer, n_layers: int, energy_readout_every: int | None = None, **g3d_kwargs)
Module to wrap a stack of G3D layers.
- __init__(g3d_class: G3DLayer, n_layers: int, energy_readout_every: int | None = None, **g3d_kwargs) -> None
The G3DStack module is a stack of n_layers G3DLayer modules.
- Parameters:
  - g3d_class (G3DLayer) – The G3D layer class from which each layer of the stack is built.
  - n_layers (int) – Number of G3D layers.
  - energy_readout_every (int, optional) – How often, in layers, the energy is read out. If None, it defaults to n_layers, so only the last layer is read out.
  - **g3d_kwargs – Keyword arguments passed to G3DLayer.
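The interplay between `n_layers` and `energy_readout_every` can be sketched in plain Python. This is a minimal sketch under stated assumptions, not the module's code: it assumes each layer is built by calling `g3d_class(**g3d_kwargs)` once per layer, and that the energy is read out after every `energy_readout_every`-th layer (counting from 1). `ToyLayer`, `build_stack` and `readout_layers` are hypothetical names used only for illustration.

```python
from typing import List, Optional


class ToyLayer:
    """Hypothetical stand-in for a G3DLayer; it only stores its kwargs."""

    def __init__(self, **kwargs):
        self.kwargs = kwargs


def build_stack(g3d_class, n_layers: int, **g3d_kwargs) -> list:
    # Assumption: every layer is a fresh instance created from the same
    # class with the same keyword arguments.
    return [g3d_class(**g3d_kwargs) for _ in range(n_layers)]


def readout_layers(n_layers: int, energy_readout_every: Optional[int] = None) -> List[int]:
    # None falls back to n_layers, so only the last layer is read out,
    # which matches the documented default.
    every = n_layers if energy_readout_every is None else energy_readout_every
    if every <= 0:
        raise ValueError("energy_readout_every must be positive")
    # Assumption: readout after layers every, 2*every, ... (1-based indices).
    return [i for i in range(1, n_layers + 1) if i % every == 0]


layers = build_stack(ToyLayer, 4, hidden_dim=64)
print(len(layers), readout_layers(4), readout_layers(6, 2))  # prints: 4 [4] [2, 4, 6]
```

Under this assumption, `readout_layers(5, 2)` gives `[2, 4]`, so the final layer is skipped when `n_layers` is not a multiple of `energy_readout_every`; this page does not say how the real module handles that case.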
- forward(x: Tensor, edge_index: Tensor, batch: Tensor, lframes: LFrames = None, edge_attr: Tensor = None, length: Tensor = None) -> Tensor
Computes the forward pass of the module by applying all the G3D layers in sequence.
- Parameters:
  - x (Tensor) – The input node features.
  - edge_index (Tensor) – The edge indices.
  - batch (Tensor) – Batch vector assigning each node to a specific graph.
  - lframes (LFrames, optional) – The LFrames object containing the local frames. (default: None)
  - edge_attr (Tensor, optional) – The edge features. (default: None)
  - length (Tensor, optional) – The edge lengths, ordered as in edge_index. (default: None)
- Returns:
The output node features.
- Return type:
Tensor
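Conceptually, a stacked forward pass threads the node features through the layers one after another, while every layer receives the same graph inputs. The plain-Python sketch below illustrates that control flow with lists instead of tensors. It is an assumption about how such a stack is typically wired, not the module's actual implementation, and it leaves out the energy readout. `stack_forward`, `add_one` and `double` are hypothetical.

```python
from typing import Callable, Sequence

# Hypothetical layer signature: (x, edge_index, batch, lframes, edge_attr, length) -> new x
Layer = Callable[..., list]


def stack_forward(
    layers: Sequence[Layer],
    x: list,
    edge_index: list,
    batch: list,
    lframes=None,
    edge_attr=None,
    length=None,
) -> list:
    # Each layer gets the current node features plus the unchanged graph
    # information; its output becomes the input of the next layer.
    for layer in layers:
        x = layer(x, edge_index, batch, lframes=lframes, edge_attr=edge_attr, length=length)
    return x


def add_one(x, edge_index, batch, lframes=None, edge_attr=None, length=None):
    # Toy layer: ignores the graph and adds 1 to every node feature.
    return [v + 1.0 for v in x]


def double(x, edge_index, batch, lframes=None, edge_attr=None, length=None):
    # Toy layer: ignores the graph and doubles every node feature.
    return [2.0 * v for v in x]


out = stack_forward([add_one, double], x=[0.0, 1.0], edge_index=[[0], [1]], batch=[0, 0])
print(out)  # prints: [2.0, 4.0]
```

Because the layers are applied in order, swapping them changes the result (`[double, add_one]` on `[1.0]` gives `[3.0]`), which is why the stack keeps its layers in a fixed sequence.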