mldft_module

Lightning module for the neural network.

class MLDFTLitModule(net: torch.nn.Module, basis_info: mldft.ml.data.components.basis_info.BasisInfo, optimizer: torch.optim.Optimizer, scheduler: torch.optim.lr_scheduler, target_key: str, loss_function: torch.nn.Module, compile: bool, variational: bool = True, metric_interval: int = 10, logging_mixin_interval: int | None = None, show_logging_mixins_in_progress_bar: bool = False)

The MLDFTLitModule class is a LightningModule that wraps the neural network in the PyTorch Lightning framework.

__init__(net: torch.nn.Module, basis_info: mldft.ml.data.components.basis_info.BasisInfo, optimizer: torch.optim.Optimizer, scheduler: torch.optim.lr_scheduler, target_key: str, loss_function: torch.nn.Module, compile: bool, variational: bool = True, metric_interval: int = 10, logging_mixin_interval: int | None = None, show_logging_mixins_in_progress_bar: bool = False) → None

Initializes the MLDFTLitModule object.

Parameters:
  • net (torch.nn.Module) – the neural network

  • basis_info (BasisInfo) – the basis info object used for the dataset. Added here for logging purposes.

  • optimizer (torch.optim.Optimizer) – the optimizer to use for training

  • scheduler (torch.optim.lr_scheduler) – the learning rate scheduler to use for training

  • loss_function (torch.nn.Module) – the loss function to use for training

  • target_key (str) – the name of the target key. Added here to easily determine which targets were used for training afterward.

  • compile (bool) – whether to compile the model with torch.compile()

  • variational (bool) – whether the model is variational. If True, the model is assumed to predict two outputs, the energy and the coefficient difference to the ground state (for proj minao), and the gradient is obtained by backpropagating the energy. If False, the model is assumed to predict the gradient directly, in a non-variational manner, so three outputs are expected: the energy, the gradient and the coefficient difference.

  • metric_interval (int) – the interval (in steps) at which the metrics are calculated and logged.

  • logging_mixin_interval (int | None) – the interval (in steps) at which the logging mixins are called. Defaults to None, which means that the logging mixins are not called.

  • show_logging_mixins_in_progress_bar (bool) – whether to show the values logged using the logging mixins in the progress bar. Defaults to False.
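The interval semantics of metric_interval and logging_mixin_interval can be illustrated with a small helper. This is a minimal sketch, assuming a step is "due" when the global step is a multiple of the interval and that None disables the action entirely; the helper name is_due is hypothetical, and the module may count steps differently (for example from 1).

```python
from typing import Optional


def is_due(global_step: int, interval: Optional[int]) -> bool:
    """Return True if an action scheduled every `interval` steps should run now.

    Hypothetical helper: assumes steps are counted from 0 and that an
    interval of None means the action never runs (as documented for
    logging_mixin_interval).
    """
    if interval is None:
        return False
    if interval <= 0:
        raise ValueError("interval must be a positive integer or None")
    return global_step % interval == 0


# metric_interval=10: metrics would be computed on steps 0, 10, 20, ...
metric_steps = [step for step in range(25) if is_due(step, 10)]
print(metric_steps)  # [0, 10, 20]

# logging_mixin_interval=None: the logging mixins never run
print(any(is_due(step, None) for step in range(25)))  # False
```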

activate_logging_mixins() → None

Activates the logging mixins.

configure_optimizers() → Dict[str, Any]

Choose what optimizers and learning-rate schedulers to use in your optimization.

Returns:

A dict containing the configured optimizers and learning-rate schedulers to be used for training.

Return type:

Dict[str, Any]
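The returned dictionary follows the format PyTorch Lightning accepts from configure_optimizers: an "optimizer" entry plus an optional "lr_scheduler" entry, itself a dict whose "scheduler" key holds the scheduler. The sketch below assumes that optimizer and scheduler arrive as partially applied factories (in the style of Hydra's `_partial_: true`), which would explain why the scheduler is annotated with the torch.optim.lr_scheduler module rather than a class. The stub classes, the "val/loss" monitor key and the "epoch" interval are hypothetical, and plain-Python stubs replace the torch objects so the example runs without torch.

```python
from functools import partial
from typing import Any, Callable, Dict, Iterable, Optional


class StubOptimizer:
    """Stand-in for a torch.optim.Optimizer (hypothetical)."""

    def __init__(self, params: Iterable[Any], lr: float) -> None:
        self.params = list(params)
        self.lr = lr


class StubScheduler:
    """Stand-in for a learning-rate scheduler (hypothetical)."""

    def __init__(self, optimizer: StubOptimizer, factor: float) -> None:
        self.optimizer = optimizer
        self.factor = factor


def configure_optimizers_sketch(
    params: Iterable[Any],
    optimizer_factory: Callable[..., Any],
    scheduler_factory: Optional[Callable[..., Any]] = None,
) -> Dict[str, Any]:
    # Bind the model parameters to the partially applied optimizer.
    optimizer = optimizer_factory(params=params)
    if scheduler_factory is None:
        return {"optimizer": optimizer}
    # The scheduler wraps the optimizer that was just created.
    scheduler = scheduler_factory(optimizer=optimizer)
    return {
        "optimizer": optimizer,
        "lr_scheduler": {
            "scheduler": scheduler,
            "monitor": "val/loss",  # hypothetical metric name
            "interval": "epoch",
            "frequency": 1,
        },
    }


config = configure_optimizers_sketch(
    params=[1.0, 2.0],
    optimizer_factory=partial(StubOptimizer, lr=1e-3),
    scheduler_factory=partial(StubScheduler, factor=0.5),
)
print(sorted(config))  # ['lr_scheduler', 'optimizer']
```

Passing factories rather than instances lets the optimizer be created only once the model parameters exist, and lets the scheduler be bound to that specific optimizer.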

deactivate_logging_mixins() → None

Deactivates the logging mixins.

forward(batch: Batch) → Tuple[Tensor, Tensor, Tensor]

Applies the forward pass of the model to the batch. In the variational case, the energy gradients are then computed via backpropagation; otherwise, the model predicts them directly.

Parameters:

batch (Batch) – Batch object containing the data

Returns:

A tuple containing (in order) the loss, the predicted energy, the predicted gradients and the predicted coefficient differences

Return type:

Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]
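The difference between the two modes can be sketched in plain Python. In the variational case the network returns only the energy and the coefficient difference, and the gradient is derived from the energy; in the non-variational case the gradient is a third network output. This sketch assumes list-valued coefficients and toy models, assumes the gradient is taken with respect to the input coefficients, and uses central finite differences as a stand-in for the backpropagation (autograd) the real module performs. The loss returned by forward is omitted.

```python
from typing import Callable, List, Sequence, Tuple

Vector = List[float]


def finite_difference_gradient(
    energy_fn: Callable[[Sequence[float]], float],
    coeffs: Sequence[float],
    eps: float = 1e-6,
) -> Vector:
    """Central finite differences; a stand-in for autograd/backpropagation."""
    grad = []
    for i in range(len(coeffs)):
        plus, minus = list(coeffs), list(coeffs)
        plus[i] += eps
        minus[i] -= eps
        grad.append((energy_fn(plus) - energy_fn(minus)) / (2 * eps))
    return grad


def forward_sketch(
    model: Callable[[Sequence[float]], tuple],
    coeffs: Sequence[float],
    variational: bool,
) -> Tuple[float, Vector, Vector]:
    if variational:
        # Two outputs: energy and coefficient difference.
        energy, diff = model(coeffs)
        # The gradient is derived from the energy, not predicted.
        grad = finite_difference_gradient(lambda c: model(c)[0], coeffs)
    else:
        # Three outputs: energy, gradient and coefficient difference.
        energy, grad, diff = model(coeffs)
    return energy, grad, diff


# Toy models with energy E(c) = 0.5 * sum(c_i^2), so dE/dc_i = c_i.
def variational_model(c):
    return 0.5 * sum(x * x for x in c), [-x for x in c]


def direct_model(c):
    return 0.5 * sum(x * x for x in c), list(c), [-x for x in c]


coeffs = [1.0, -2.0, 0.5]
_, grad_v, _ = forward_sketch(variational_model, coeffs, variational=True)
_, grad_d, _ = forward_sketch(direct_model, coeffs, variational=False)
print([round(g, 4) for g in grad_v], grad_d)
```

A variational model guarantees that the gradient is consistent with the predicted energy, at the cost of an extra differentiation pass; the direct variant avoids that pass but offers no such guarantee.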

on_test_epoch_end() → None

Lightning hook that is called when a test epoch ends.

on_train_epoch_end() → None

Lightning hook that is called when a training epoch ends.

on_train_start() → None

Lightning hook that is called when training begins.

on_validation_epoch_end() → None

Lightning hook that is called when a validation epoch ends.

on_validation_epoch_start() → None

Lightning hook that is called when a validation epoch starts.
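A common use of epoch-level hooks is to reduce per-step metrics to one value per epoch and then reset the accumulator for the next epoch. The sketch below shows that pattern with a hypothetical RunningMean accumulator weighted by batch size; which metrics this module aggregates, and whether it weights by batch size, are assumptions (in practice the same pattern is often implemented with torchmetrics.MeanMetric).

```python
from dataclasses import dataclass


@dataclass
class RunningMean:
    """Weighted running mean, reset at the end of every epoch (hypothetical)."""

    total: float = 0.0
    weight: float = 0.0

    def update(self, value: float, weight: float = 1.0) -> None:
        self.total += value * weight
        self.weight += weight

    def compute(self) -> float:
        if self.weight == 0:
            raise ValueError("no values accumulated this epoch")
        return self.total / self.weight

    def reset(self) -> None:
        self.total = 0.0
        self.weight = 0.0


val_loss = RunningMean()

# Per step: each validation step would add its batch loss, weighted by batch size.
for batch_loss, batch_size in [(0.5, 4), (0.2, 2)]:
    val_loss.update(batch_loss, weight=batch_size)

# Epoch end: on_validation_epoch_end would log the mean and reset the accumulator.
epoch_value = val_loss.compute()
val_loss.reset()
print(round(epoch_value, 4))  # 0.4
```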

sample_forward(sample: OFData) → OFData

Applies the forward pass of the model to a single sample.

Parameters:

sample (OFData) – the sample to apply the model to.

Returns:

The sample with the model predictions added

Return type:

OFData
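One plausible way to run a batch-level model on a single sample is to wrap the sample in a batch of size one, run the forward pass, and copy the predictions back onto the sample. The sketch below assumes dict-based samples, a hypothetical collate function and hypothetical prediction keys (pred_energy, pred_gradient); it does not use the real OFData or Batch classes.

```python
from typing import Any, Callable, Dict, List

Sample = Dict[str, Any]


def collate(samples: List[Sample]) -> Dict[str, List[Any]]:
    """Stack per-sample fields into per-batch lists (hypothetical collate)."""
    return {key: [s[key] for s in samples] for key in samples[0]}


def sample_forward_sketch(
    batch_forward: Callable[[Dict[str, List[Any]]], Dict[str, List[Any]]],
    sample: Sample,
) -> Sample:
    batch = collate([sample])  # batch of size one
    predictions = batch_forward(batch)
    result = dict(sample)  # copy, so the input sample is left unmodified
    for key, values in predictions.items():
        result[key] = values[0]  # unwrap the single batch entry
    return result


def toy_batch_forward(batch: Dict[str, List[Any]]) -> Dict[str, List[Any]]:
    # Hypothetical model: energy = sum of coefficients, gradient = all ones.
    return {
        "pred_energy": [sum(c) for c in batch["coeffs"]],
        "pred_gradient": [[1.0] * len(c) for c in batch["coeffs"]],
    }


out = sample_forward_sketch(toy_batch_forward, {"coeffs": [0.5, 1.5]})
print(out["pred_energy"])  # 2.0
```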

setup(stage: str) → None

Lightning hook that is called at the beginning of fit (train + validate), validate, test, or predict.

Parameters:

stage (str) – Either “fit”, “validate”, “test”, or “predict”.
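Because setup receives the stage name, stage-specific initialization can be gated on it. The sketch below assumes that the compile flag is applied in setup and only for the "fit" stage, a pattern used in the Lightning-Hydra project template; this is an assumption about this module, and a stub compile_fn stands in for torch.compile so the example runs without torch. The class name SetupSketch is hypothetical.

```python
from typing import Callable

VALID_STAGES = ("fit", "validate", "test", "predict")


class SetupSketch:
    """Hypothetical stand-in showing stage-dependent setup logic."""

    def __init__(self, net: object, compile: bool,
                 compile_fn: Callable[[object], object]) -> None:
        self.net = net
        # `compile` mirrors the documented parameter name (it shadows the builtin here).
        self.compile = compile
        self.compile_fn = compile_fn  # stands in for torch.compile

    def setup(self, stage: str) -> None:
        if stage not in VALID_STAGES:
            raise ValueError(f"unknown stage: {stage!r}")
        # Compile only when training.
        if self.compile and stage == "fit":
            self.net = self.compile_fn(self.net)


module = SetupSketch(net="net", compile=True, compile_fn=lambda n: f"compiled({n})")
module.setup("test")
print(module.net)  # net
module.setup("fit")
print(module.net)  # compiled(net)
```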

property tensorboard_logger

Returns the TensorBoard logger attached to the trainer.
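Lightning exposes the loggers attached to a trainer as trainer.loggers, so a property like this can search that list for a TensorBoardLogger. The sketch below uses stub classes in place of lightning.pytorch.loggers.TensorBoardLogger and the real Trainer so it runs without Lightning; raising an error when no TensorBoard logger is attached is an assumption, and the real property may behave differently (for example, return None).

```python
from typing import List


class Logger:
    """Stub base class for loggers."""


class TensorBoardLogger(Logger):
    """Stub standing in for lightning.pytorch.loggers.TensorBoardLogger."""


class CSVLogger(Logger):
    """Stub standing in for another logger type."""


class StubTrainer:
    def __init__(self, loggers: List[Logger]) -> None:
        self.loggers = loggers


class LitModuleSketch:
    def __init__(self, trainer: StubTrainer) -> None:
        self.trainer = trainer

    @property
    def tensorboard_logger(self) -> TensorBoardLogger:
        # Return the first TensorBoard logger attached to the trainer.
        for logger in self.trainer.loggers:
            if isinstance(logger, TensorBoardLogger):
                return logger
        raise ValueError("no TensorBoardLogger attached to the trainer")


tb = TensorBoardLogger()
module = LitModuleSketch(StubTrainer([CSVLogger(), tb]))
print(module.tensorboard_logger is tb)  # True
```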

test_step(batch: Batch) → None

Performs a single test step on a batch of data from the test set.

Parameters:

batch (Batch) – A batch of data

training_step(batch: Batch) → dict

Performs a single training step on a batch of data from the training set.

Parameters:

batch (Batch) – A batch of data

Returns:

A dict containing the loss, the model predictions, and the projected gradient differences.

Return type:

dict
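Under automatic optimization, Lightning requires a dict returned from training_step to contain a "loss" key; other keys are passed through. The sketch below builds such a dict with a projected gradient difference, assuming that "projected" means removing the component along a vector w of basis-function integrals, i.e. applying (I - w wᵀ / wᵀw), so that the projected gradient leaves the electron number unchanged. The vector w, the helper names and the key names other than "loss" are hypothetical.

```python
from typing import Dict, List, Sequence

Vector = List[float]


def dot(a: Sequence[float], b: Sequence[float]) -> float:
    return sum(x * y for x, y in zip(a, b))


def project(grad: Sequence[float], w: Sequence[float]) -> Vector:
    """Remove the component of grad along w: (I - w w^T / w^T w) grad.

    Assumption: w holds the integrals of the basis functions, so the
    projected gradient does not change the electron number.
    """
    scale = dot(grad, w) / dot(w, w)
    return [g - scale * wi for g, wi in zip(grad, w)]


def training_step_output(loss: float, pred_grad: Sequence[float],
                         target_grad: Sequence[float],
                         w: Sequence[float]) -> Dict[str, object]:
    # Difference of the projected predicted and target gradients.
    proj_diff = [p - t for p, t in zip(project(pred_grad, w), project(target_grad, w))]
    # Lightning requires a "loss" entry when training_step returns a dict.
    return {
        "loss": loss,
        "pred_gradient": list(pred_grad),  # hypothetical key
        "projected_gradient_difference": proj_diff,  # hypothetical key
    }


w = [1.0, 1.0]
out = training_step_output(0.1, pred_grad=[3.0, 1.0], target_grad=[1.0, 1.0], w=w)
print(out["projected_gradient_difference"])
```

Because the projection is linear, the difference of projections equals the projection of the difference, and the result is orthogonal to w.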

validation_step(batch: Batch, batch_idx: int) → dict

Performs a single validation step on a batch of data from the validation set.

Parameters:
  • batch (Batch) – A batch of data

  • batch_idx (int) – The index of the batch

Returns:

A dict containing the loss, the model predictions, and the projected gradient differences.

Return type:

dict