shrink_gate_module
Shrink gate.
- class PerBasisFuncShrinkGateModule(embed_dim: int, lambda_co: float = 10.0, lambda_mul: float = 0.02)
Shrink gate module with a separate pair of scaling parameters for each basis function.
- __init__(embed_dim: int, lambda_co: float = 10.0, lambda_mul: float = 0.02) → None
Initializes the two shrink-gate parameters, lambda_co and lambda_mul, with one value per basis function.
- Parameters:
embed_dim (int) – Number of basis functions; equal to basis_info.n_basis or sum(basis_dim_per_atom).
lambda_co (float, optional) – Scaling parameter of the output. Defaults to 10.0.
lambda_mul (float, optional) – Scaling parameter inside the tanh. Defaults to 0.02.
- forward(coeffs: Tensor)
Computes the forward pass of the module. Writing x for coeffs, each component is transformed as x_ij -> lambda_co_j * tanh(lambda_mul_j * x_ij), where i is the index of the atom in the batch and j is the basis function index.
- Parameters:
coeffs (Tensor) – Atom-hot embedded coefficients of shape (n_atoms, embed_dim).
- Returns:
The shrunk tensor of shape (n_atoms, embed_dim).
- Return type:
Tensor
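A minimal PyTorch sketch of the per-basis-function transform described above, assuming that lambda_co and lambda_mul are stored as trainable nn.Parameter vectors of length embed_dim, initialized to the default values. The class name PerBasisShrinkGateSketch is hypothetical; this is an illustration of the formula, not the library's implementation.

```python
import torch
from torch import nn


class PerBasisShrinkGateSketch(nn.Module):
    """Hypothetical sketch: x_ij -> lambda_co_j * tanh(lambda_mul_j * x_ij)."""

    def __init__(self, embed_dim: int, lambda_co: float = 10.0, lambda_mul: float = 0.02) -> None:
        super().__init__()
        # Assumption: one trainable value per basis function j, initialized to the defaults.
        self.lambda_co = nn.Parameter(torch.full((embed_dim,), lambda_co))
        self.lambda_mul = nn.Parameter(torch.full((embed_dim,), lambda_mul))

    def forward(self, coeffs: torch.Tensor) -> torch.Tensor:
        # coeffs has shape (n_atoms, embed_dim); the (embed_dim,) parameter
        # vectors broadcast over the atom index i.
        return self.lambda_co * torch.tanh(self.lambda_mul * coeffs)


gate = PerBasisShrinkGateSketch(embed_dim=4)
coeffs = torch.tensor([[0.0, 1.0, -1.0, 1000.0],
                       [0.5, -0.5, 2.0, -1000.0]])
out = gate(coeffs)
print(out.shape)  # torch.Size([2, 4])
```

With the default values the gate is close to linear for small inputs, with slope lambda_co * lambda_mul = 0.2, and saturates at ±lambda_co = ±10 for large |x|, so the magnitude of every output coefficient is bounded by lambda_co.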
- class ShrinkGateModule(lambda_co: float = 10.0, lambda_mul: float = 0.02)
Shrink gate module whose two scaling parameters are shared by all basis functions.
- __init__(lambda_co: float = 10.0, lambda_mul: float = 0.02) → None
Initializes the two parameters needed for the shrink gate.
- Parameters:
lambda_co (float, optional) – Scaling parameter of the output. Defaults to 10.0.
lambda_mul (float, optional) – Scaling parameter inside the tanh. Defaults to 0.02.
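ShrinkGateModule takes no embed_dim, so its parameters cannot depend on the basis-function index. The forward pass of this class is not documented here; the sketch below assumes it applies the same transform, x -> lambda_co * tanh(lambda_mul * x), elementwise with a single scalar lambda_co and lambda_mul. The class name ShrinkGateSketch is hypothetical.

```python
import torch
from torch import nn


class ShrinkGateSketch(nn.Module):
    """Hypothetical sketch: x -> lambda_co * tanh(lambda_mul * x), elementwise."""

    def __init__(self, lambda_co: float = 10.0, lambda_mul: float = 0.02) -> None:
        super().__init__()
        # Assumption: two scalar trainable parameters shared by every component.
        self.lambda_co = nn.Parameter(torch.tensor(lambda_co))
        self.lambda_mul = nn.Parameter(torch.tensor(lambda_mul))

    def forward(self, coeffs: torch.Tensor) -> torch.Tensor:
        # The parameters are 0-d tensors, so this works for any input shape.
        return self.lambda_co * torch.tanh(self.lambda_mul * coeffs)


torch.manual_seed(0)
gate = ShrinkGateSketch()
coeffs = torch.randn(3, 5) * 100  # large values to exercise the saturation
out = gate(coeffs)
```

Compared with the per-basis variant, this version has two parameters in total instead of 2 * embed_dim, at the cost of using the same saturation level and slope for every basis function.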