mldft_module
Lightning module for the neural network.
- class MLDFTLitModule(net: torch.nn.Module, basis_info: mldft.ml.data.components.basis_info.BasisInfo, optimizer: torch.optim.Optimizer, scheduler: torch.optim.lr_scheduler, target_key: str, loss_function: torch.nn.Module, compile: bool, variational: bool = True, metric_interval: int = 10, logging_mixin_interval: int | None = None, show_logging_mixins_in_progress_bar: bool = False)
MLDFTLitModule is a LightningModule that wraps the neural network for use with the PyTorch Lightning framework.
- __init__(net: torch.nn.Module, basis_info: mldft.ml.data.components.basis_info.BasisInfo, optimizer: torch.optim.Optimizer, scheduler: torch.optim.lr_scheduler, target_key: str, loss_function: torch.nn.Module, compile: bool, variational: bool = True, metric_interval: int = 10, logging_mixin_interval: int | None = None, show_logging_mixins_in_progress_bar: bool = False) -> None
Initializes the MLDFTLitModule object.
- Parameters:
net (torch.nn.Module) – the neural network
basis_info (BasisInfo) – the basis info object used for the dataset. Added here for logging purposes.
optimizer (torch.optim.Optimizer) – the optimizer to use for training
scheduler (torch.optim.lr_scheduler) – the learning rate scheduler to use for training
loss_function (torch.nn.Module) – the loss function to use for training
target_key (str) – the name of the target key. Added here to easily determine which targets were used for training afterward.
compile (bool) – whether to compile the model with torch.compile()
variational (bool) – whether the model is variational or not. If True, the model is assumed to predict two outputs, the energy and the coefficient difference to the ground state (for proj minao). If False, the model is assumed to predict the gradient directly, in a non-variational manner, so three outputs are expected: the energy, the gradient and the coefficient difference.
metric_interval (int) – the interval (in steps) at which the metrics are calculated and logged.
logging_mixin_interval (int | None) – the interval (in steps) at which the logging mixins are called. Defaults to None, which means that the logging mixins are not called.
show_logging_mixins_in_progress_bar (bool) – whether to show the values logged using the logging mixins in the progress bar. Defaults to False.
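The two interval parameters can be read as simple step gates. The following is a minimal sketch, assuming the gate is a modulo test on the step counter. The helper `is_due` is hypothetical and not part of mldft, and the actual module may apply the intervals differently.

```python
from typing import Optional


def is_due(step: int, interval: Optional[int]) -> bool:
    """Return True if an action gated by `interval` should run at `step`.

    Hypothetical helper: assumes a modulo test on the step counter.
    An interval of None disables the action, mirroring the documented
    default of logging_mixin_interval.
    """
    if interval is None:
        return False
    return step % interval == 0


# With metric_interval=10, metrics would be computed every tenth step.
metric_steps = [s for s in range(35) if is_due(s, 10)]

# With logging_mixin_interval=None, the logging mixins are never called.
mixin_steps = [s for s in range(35) if is_due(s, None)]
```

Treating None as "disabled" keeps the gate to a single check, so callers do not need a separate on/off flag.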
- configure_optimizers() -> Dict[str, Any]
Choose what optimizers and learning-rate schedulers to use in your optimization.
- Returns:
A dict containing the configured optimizers and learning-rate schedulers to be used for training.
- Return type:
Dict[str, Any]
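For orientation, the general shape of a dict that Lightning accepts from `configure_optimizers` can be sketched as follows. This illustrates the Lightning convention, not this module's actual code: the function `build_optimizer_config`, the choice of Adam and ReduceLROnPlateau, and the monitored metric name `val/loss` are all assumptions.

```python
from typing import Any, Dict

import torch


def build_optimizer_config(model: torch.nn.Module) -> Dict[str, Any]:
    """Build an optimizer/scheduler dict in the layout Lightning expects.

    Hypothetical helper: the optimizer, scheduler and metric name are
    placeholders chosen for illustration only.
    """
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode="min", factor=0.5, patience=5
    )
    return {
        "optimizer": optimizer,
        "lr_scheduler": {
            "scheduler": scheduler,
            # Metric the plateau scheduler watches (assumed name).
            "monitor": "val/loss",
            # Step the scheduler once per epoch.
            "interval": "epoch",
            "frequency": 1,
        },
    }


config = build_optimizer_config(torch.nn.Linear(4, 1))
```

The `monitor` entry only matters for schedulers that react to a logged metric, such as ReduceLROnPlateau.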
- forward(batch: Batch) -> Tuple[Tensor, Tensor, Tensor]
Applies the forward pass of the model to the batch. In the variational case, the energy gradients are computed via backpropagation; otherwise the model predicts them directly.
- Parameters:
batch (Batch) – Batch object containing the data
- Returns:
A tuple containing (in order) the loss, the predicted energy, the predicted gradients and the predicted differences
- Return type:
Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]
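To illustrate what computing the energy gradients via backprop means in the variational case, here is a minimal sketch using `torch.autograd.grad`. The toy network `ToyEnergyNet`, the input shape, and the function `energy_and_gradient` are assumptions for illustration. The real module operates on `Batch` objects and its internals may differ.

```python
from typing import Tuple

import torch


class ToyEnergyNet(torch.nn.Module):
    """Hypothetical stand-in: maps coefficients to one scalar energy per sample."""

    def __init__(self, n_coeffs: int) -> None:
        super().__init__()
        self.mlp = torch.nn.Sequential(
            torch.nn.Linear(n_coeffs, 16),
            torch.nn.SiLU(),
            torch.nn.Linear(16, 1),
        )

    def forward(self, coeffs: torch.Tensor) -> torch.Tensor:
        return self.mlp(coeffs).squeeze(-1)


def energy_and_gradient(
    net: torch.nn.Module, coeffs: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Predict energies and obtain dE/dcoeffs by differentiating the network."""
    coeffs = coeffs.detach().requires_grad_(True)
    energy = net(coeffs)
    # Samples are independent, so the gradient of the summed energy gives
    # each sample's own gradient. create_graph=True keeps the graph so a
    # loss on the gradient can itself be backpropagated to the weights.
    (gradient,) = torch.autograd.grad(energy.sum(), coeffs, create_graph=True)
    return energy, gradient


torch.manual_seed(0)
net = ToyEnergyNet(n_coeffs=5)
coeffs = torch.randn(3, 5)
energy, gradient = energy_and_gradient(net, coeffs)
```

In the non-variational case described above, the gradient would instead be a separate output head of the network, and no call to `torch.autograd.grad` would be needed during the forward pass.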
- on_validation_epoch_start() -> None
Lightning hook that is called when a validation epoch starts.
- sample_forward(sample: OFData) -> OFData
Applies the forward pass of the model to a single sample.
- Parameters:
sample (OFData) – OFData object.
- Returns:
The sample with the model predictions added
- Return type:
OFData
- setup(stage: str) -> None
Lightning hook that is called at the beginning of fit (train + validate), validate, test, or predict.
- Parameters:
stage (str) – Either “fit”, “validate”, “test”, or “predict”.
- property tensorboard_logger
Get the tensorboard logger from the trainer.
- test_step(batch: Batch) -> None
Performs a single test step on a batch of data from the test set.
- Parameters:
batch (Batch) – A batch of data
- training_step(batch: Batch) -> dict
Performs a single training step on a batch of data from the training set.
- Parameters:
batch (Batch) – A batch of data
- Returns:
A dict containing the loss, the model predictions, and the projected gradient differences.
- Return type:
dict
- validation_step(batch: Batch, batch_idx: int) -> dict
Performs a single validation step on a batch of data from the validation set.
- Parameters:
batch (Batch) – A batch of data
batch_idx (int) – The index of the batch
- Returns:
A dict containing the loss, the model predictions, and the projected gradient differences.
- Return type:
dict