from typing import Any, Optional

import pytorch_lightning as pl
from lightning import Callback
from pytorch_lightning.utilities.types import STEP_OUTPUT

from mldft.ml.callbacks.timing import CallbackTiming, EveryIncreasingInterval
from mldft.ml.models.mldft_module import MLDFTLitModule
[docs]
def default_log_timing():
    """Build the fallback timing for callbacks that were not given an explicit one.

    Returns:
        A newly constructed :class:`EveryIncreasingInterval`, which logs with
        exponentially increasing intervals. A fresh object is created on every
        call, so callers never share a timing instance through this helper.
    """
    timing = EveryIncreasingInterval()
    return timing
[docs]
class OnStepCallbackWithTiming(Callback):
    """Base class for timed callbacks, which execute at certain intervals.

    They execute on step, during training and validation, at independently specified timings.
    Subclasses implement :meth:`execute`. The batch-end hooks below decide, based on the
    configured timings, whether to call it.
    """

    def __init__(
        self,
        train_timing: Optional[CallbackTiming] = None,
        val_timing: Optional[CallbackTiming] = None,
        name: str = "auto",
    ) -> None:
        """Initializes the callback. The name is used for logging.

        Args:
            train_timing: The :class:`CallbackTiming` object specifying how often the callback is to be
                called during training. If ``None``, :func:`default_log_timing` is used.
            val_timing: The :class:`CallbackTiming` object specifying how often the callback is to be
                called during validation. If ``None``, :func:`default_log_timing` is used.
            name: The name of the callback. If ``"auto"``, the class name is used, minus "Log" if the
                class name starts with that.
        """
        super().__init__()
        # default_log_timing() is called separately for each split, so the two
        # default timings are distinct objects.
        self.val_log_timing = val_timing if val_timing is not None else default_log_timing()
        self.train_log_timing = train_timing if train_timing is not None else default_log_timing()
        if name == "auto":
            name = self.__class__.__name__
            # Strip the conventional "Log" prefix, e.g. "LogFoo" -> "Foo".
            if name.startswith("Log"):
                name = name[3:]
        self.name = name

    def execute(
        self,
        trainer: pl.Trainer,
        pl_module: MLDFTLitModule,
        outputs: STEP_OUTPUT,
        batch: Any,
        split: str,
    ) -> None:
        """Executes the callback, e.g. logs a figure to tensorboard.

        Must be overridden by subclasses.

        Args:
            trainer: The lightning trainer.
            pl_module: The lightning module.
            outputs: The outputs of the lightning module.
            batch: The batch data.
            split: The split, either ``"train"`` or ``"val"``.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError

    def on_train_batch_end(
        self,
        trainer: pl.Trainer,
        pl_module: MLDFTLitModule,
        outputs: STEP_OUTPUT,
        batch: Any,
        batch_idx: int,
    ) -> None:
        """Executes the callback via :meth:`execute`, if the training timing matches.

        The decision is delegated to ``self.train_log_timing.call_now`` with the
        trainer's current ``global_step``.
        """
        if not self.train_log_timing.call_now(trainer.global_step):
            return
        self.execute(
            pl_module=pl_module,
            trainer=trainer,
            outputs=outputs,
            batch=batch,
            split="train",
        )

    def on_validation_batch_end(
        self,
        trainer: pl.Trainer,
        pl_module: MLDFTLitModule,
        outputs: STEP_OUTPUT,
        batch: Any,
        batch_idx: int,
        **kwargs,
    ) -> None:
        """Executes the callback via :meth:`execute`, if the validation timing matches.

        Can happen only for the first validation batch (``batch_idx == 0``). Extra keyword
        arguments passed by Lightning (e.g. a dataloader index) are accepted and ignored.
        """
        # only log on the first validation batch, as tensorboard does not show multiple images for the same step anyway
        if batch_idx != 0:
            return
        if not self.val_log_timing.call_now(trainer.global_step):
            return
        self.execute(
            pl_module=pl_module,
            trainer=trainer,
            outputs=outputs,
            batch=batch,
            split="val",
        )