shrink_gate_module

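Shrink gate modules that pass their input through a scaled tanh, x -> lambda_co * tanh(lambda_mul * x), which bounds every output component.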

class PerBasisFuncShrinkGateModule(embed_dim: int, lambda_co: float = 10.0, lambda_mul: float = 0.02)

Shrink gate with a separate pair of scaling parameters (lambda_co_j, lambda_mul_j) for each basis function j.

__init__(embed_dim: int, lambda_co: float = 10.0, lambda_mul: float = 0.02) → None

Initializes the two parameters of the shrink gate, lambda_co and lambda_mul, each holding one value per basis function.

Parameters:
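  • embed_dim (int) – Total number of basis functions; equal to basis_info.n_basis, or equivalently sum(basis_dim_per_atom).

  • lambda_co (float, optional) – Output scaling parameter that multiplies the tanh and therefore bounds the output magnitude. Defaults to 10.0.

  • lambda_mul (float, optional) – Input scaling parameter applied inside the tanh; smaller values keep the gate in its near-linear regime over a wider input range. Defaults to 0.02.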
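The roles of the two parameters follow directly from the forward formula documented for this class. The short derivation below uses only that formula (with y_ij denoting the output component):

$$
y_{ij} = \lambda_{\mathrm{co},j}\,\tanh\!\left(\lambda_{\mathrm{mul},j}\,x_{ij}\right),
\qquad |y_{ij}| < |\lambda_{\mathrm{co},j}|,
\qquad \left.\frac{\partial y_{ij}}{\partial x_{ij}}\right|_{x_{ij}=0} = \lambda_{\mathrm{co},j}\,\lambda_{\mathrm{mul},j}
$$

With the defaults, outputs are confined to (-10, 10), inputs near zero are scaled by about 10 × 0.02 = 0.2, and saturation sets in once |x_ij| is comparable to 1/0.02 = 50.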

forward(coeffs: Tensor)

Calculates the forward pass of the module. Each component x_ij of coeffs is transformed as x_ij -> lambda_co_j * tanh(lambda_mul_j * x_ij), where i is the index of the atom in the batch and j is the basis function index.

Parameters:

coeffs (Tensor) – Atom-hot embedded coefficients of shape (n_atoms, embed_dim)

Returns:

The shrunk tensor of shape (n_atoms, embed_dim)

Return type:

Tensor
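The per-basis-function transformation can be sketched in PyTorch as follows. This is a minimal sketch, not the module's actual source: it assumes lambda_co and lambda_mul are stored as learnable nn.Parameter vectors of length embed_dim filled with the scalar defaults, and the class name PerBasisFuncShrinkGateSketch is hypothetical.

```python
import torch
from torch import nn


class PerBasisFuncShrinkGateSketch(nn.Module):
    """Hypothetical sketch: one (lambda_co_j, lambda_mul_j) pair per basis function j."""

    def __init__(self, embed_dim: int, lambda_co: float = 10.0, lambda_mul: float = 0.02) -> None:
        super().__init__()
        # Assumption: both parameters are learnable vectors of length embed_dim,
        # every entry initialised to the given scalar.
        self.lambda_co = nn.Parameter(torch.full((embed_dim,), float(lambda_co)))
        self.lambda_mul = nn.Parameter(torch.full((embed_dim,), float(lambda_mul)))

    def forward(self, coeffs: torch.Tensor) -> torch.Tensor:
        # coeffs has shape (n_atoms, embed_dim); the (embed_dim,) parameter
        # vectors broadcast over the atom axis i, so column j uses lambda_*_j.
        return self.lambda_co * torch.tanh(self.lambda_mul * coeffs)


torch.manual_seed(0)
gate = PerBasisFuncShrinkGateSketch(embed_dim=4)
coeffs = torch.randn(3, 4) * 100.0  # deliberately large inputs
out = gate(coeffs)
print(out.shape)  # torch.Size([3, 4])
```

Because the parameter vectors have shape (embed_dim,), broadcasting makes every atom share lambda_co_j and lambda_mul_j for a given basis function j, while different basis functions can be scaled independently.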

class ShrinkGateModule(lambda_co: float = 10.0, lambda_mul: float = 0.02)

Shrink gate with a single pair of scalar scaling parameters shared by all components of the input.

__init__(lambda_co: float = 10.0, lambda_mul: float = 0.02) → None

Initializes the two scalar parameters of the shrink gate, lambda_co and lambda_mul.

Parameters:
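  • lambda_co (float, optional) – Output scaling parameter that multiplies the tanh and therefore bounds the output magnitude. Defaults to 10.0.

  • lambda_mul (float, optional) – Input scaling parameter applied inside the tanh; smaller values keep the gate in its near-linear regime over a wider input range. Defaults to 0.02.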

forward(x: Tensor) → Tensor

Calculates the forward pass of the module. Each component of x is transformed as x -> lambda_co * tanh(lambda_mul * x), with the same scalar lambda_co and lambda_mul applied to every component.

Parameters:

x (Tensor) – Input tensor. The transformation is element-wise, so any shape is accepted.

Returns:

The transformed tensor, with the same shape as x.

Return type:

Tensor
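A minimal PyTorch sketch of the scalar variant follows. It is not the module's actual source: it assumes the two parameters are learnable 0-dimensional nn.Parameter tensors, and the class name ShrinkGateSketch is hypothetical.

```python
import torch
from torch import nn


class ShrinkGateSketch(nn.Module):
    """Hypothetical sketch: x -> lambda_co * tanh(lambda_mul * x), element-wise."""

    def __init__(self, lambda_co: float = 10.0, lambda_mul: float = 0.02) -> None:
        super().__init__()
        # Assumption: two learnable scalars shared by every component of the input.
        self.lambda_co = nn.Parameter(torch.tensor(float(lambda_co)))
        self.lambda_mul = nn.Parameter(torch.tensor(float(lambda_mul)))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Scalars broadcast against x, so the output has the same shape as x.
        return self.lambda_co * torch.tanh(self.lambda_mul * x)


gate = ShrinkGateSketch()
x = torch.tensor([0.0, 1.0, -1.0, 1e4])
y = gate(x)  # zero stays zero, small inputs shrink, the large input saturates near 10
```

Unlike the per-basis variant, this version needs no embed_dim argument, because its two scalars broadcast against an input of any shape.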