gbf_module
Compute edge features by applying Gaussian basis functions to the distances between nodes.
- class GBFModule(num_gaussians: int = 10, normalized: bool = False)
Module that computes edge attributes by expanding the distance between two connected nodes in a set of Gaussian basis functions.
- __init__(num_gaussians: int = 10, normalized: bool = False) → None
Initialize the module with the given number of Gaussians and normalization setting. The shift (mean) and scale (width) of each Gaussian are created as learnable parameters.
- Parameters:
num_gaussians (int, optional) – Number of Gaussian functions used to compute the edge features. Defaults to 10.
normalized (bool, optional) – Whether the Gaussians are normalized. Defaults to False.
- forward(sample: OFData) → Tensor
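Compute the forward pass. For an edge between nodes i and j, the k-th edge attribute is
\[e_{ij}^{k} = \exp\left(-\frac{1}{2} \left(\frac{\|r_i - r_j\| - \mu^k}{\sigma^k}\right)^2\right),\]
where \(r_i\) and \(r_j\) are the node positions and \(\mu^k\) and \(\sigma^k\) are the learnable shift and scale of the k-th Gaussian. If normalized is True, a normalization factor is additionally applied to each Gaussian.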
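The formula can be checked numerically for a single distance. The sketch below uses only the Python standard library; the shift and scale values are arbitrary example numbers standing in for the learnable parameters, and the normalization factor \(1/(\sigma\sqrt{2\pi})\) is an assumption (the standard Gaussian density prefactor), since the exact factor used by the module is not stated here.

```python
import math

def gaussian_basis(distance, mus, sigmas, normalized=False):
    """Expand one distance into len(mus) Gaussian features.

    mus/sigmas stand in for the learnable shift and scale of each
    Gaussian; here they are plain floats for illustration.
    """
    features = []
    for mu, sigma in zip(mus, sigmas):
        z = (distance - mu) / sigma
        value = math.exp(-0.5 * z * z)
        if normalized:
            # Assumed normalization: standard Gaussian density prefactor.
            value /= sigma * math.sqrt(2.0 * math.pi)
        features.append(value)
    return features

# Three Gaussians centred at 0.0, 1.0 and 2.0, all with unit width.
feats = gaussian_basis(1.0, mus=[0.0, 1.0, 2.0], sigmas=[1.0, 1.0, 1.0])
print([round(f, 4) for f in feats])  # [0.6065, 1.0, 0.6065]
```

The Gaussian whose shift equals the distance responds with 1.0, and its neighbours one width away respond with \(e^{-1/2}\).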
- Parameters:
sample (OFData) – The input data containing the node positions and the edge_index.
- Returns:
The edge attributes. Shape: (E, num_gaussians), where E is the number of edges.
- Return type:
Tensor
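The output shape can be illustrated with a dependency-free sketch. As a hypothetical stand-in for OFData, it assumes positions are a list of coordinate tuples and edge_index is a pair of source/target lists (the usual 2 x E layout); the normalization option is omitted. Each edge yields one row of num_gaussians values, giving an (E, num_gaussians) result.

```python
import math

def edge_gaussian_features(pos, edge_index, mus, sigmas):
    """Return an E x K nested list of Gaussian edge features.

    pos        -- list of node coordinates, e.g. [(x, y, z), ...]
    edge_index -- pair of equal-length lists [sources, targets] (2 x E)
    mus/sigmas -- K shifts and scales, one per Gaussian
    """
    sources, targets = edge_index
    features = []
    for i, j in zip(sources, targets):
        # Euclidean distance ||r_i - r_j|| between the two endpoints.
        d = math.dist(pos[i], pos[j])
        features.append(
            [math.exp(-0.5 * ((d - mu) / sigma) ** 2) for mu, sigma in zip(mus, sigmas)]
        )
    return features

pos = [(0.0, 0.0, 0.0), (1.0, 0.0, 0.0), (0.0, 2.0, 0.0)]
# Two bonds, each stored in both directions, so E = 4.
edge_index = [[0, 1, 0, 2], [1, 0, 2, 0]]
mus = [0.0, 1.0, 2.0, 3.0]
sigmas = [0.5] * 4
out = edge_gaussian_features(pos, edge_index, mus, sigmas)
print(len(out), len(out[0]))  # 4 4
```

Because the distance is symmetric, the two directions of the same bond produce identical rows here; the module's learnable shifts and scales are shared across all edges.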
- class GaussianLayer(basis_info: BasisInfo, num_gaussians: int = 128, normalized: bool = True, directed=True, init_radius_range=(0, 3))
- __init__(basis_info: BasisInfo, num_gaussians: int = 128, normalized: bool = True, directed=True, init_radius_range=(0, 3)) → None
Initialize the GaussianLayer.
- Parameters:
basis_info (BasisInfo) – Used to determine the number of atom types.
num_gaussians (int, optional) – Number of Gaussians to use. Defaults to 128.
normalized (bool, optional) – Whether to normalize the Gaussians. Defaults to True.
directed (bool, optional) – Whether the learned means and standard deviations are directed. If True, a C-H edge and the reverse H-C edge get separate parameters; if False, they share the same embedding. Defaults to True.
init_radius_range (Tuple[float, float], optional) – Range used to initialize the means. Defaults to (0, 3).
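One way to picture the directed option is to index the learned parameters by the ordered pair of atom types at the two ends of an edge. The sketch below is hypothetical: the index scheme, the helper names pair_index and init_means, and the evenly spaced initialization of the means over init_radius_range are assumptions for illustration, not the layer's actual implementation.

```python
def pair_index(type_i, type_j, num_types, directed=True):
    """Map the atom types of an edge's endpoints to a parameter row.

    Hypothetical scheme: with directed=True the ordered pair (i, j) is used,
    so C-H and H-C get different rows; with directed=False the pair is
    sorted first, so both directions share one row.
    """
    if not directed:
        type_i, type_j = sorted((type_i, type_j))
    return type_i * num_types + type_j


def init_means(num_gaussians, init_radius_range=(0.0, 3.0)):
    """Hypothetical initialization: evenly spaced means over the range."""
    lo, hi = init_radius_range
    if num_gaussians == 1:
        return [lo]
    step = (hi - lo) / (num_gaussians - 1)
    return [lo + k * step for k in range(num_gaussians)]


H, C = 0, 1  # toy atom-type indices, so num_types = 2
print(pair_index(C, H, 2), pair_index(H, C, 2))  # 2 1
print(pair_index(C, H, 2, directed=False), pair_index(H, C, 2, directed=False))  # 1 1
print(init_means(4))  # [0.0, 1.0, 2.0, 3.0]
```

In this scheme the undirected case leaves part of a num_types x num_types table unused; that trade-off keeps the indexing a single multiply-add.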