from typing import Any, Optional
import pytorch_lightning as pl
from pytorch_lightning import Callback
from pytorch_lightning.utilities.types import STEP_OUTPUT
from mldft.ml.callbacks.timing import CallbackTiming, EveryIncreasingInterval
from mldft.ml.models.mldft_module import MLDFTLitModule


def default_log_timing():
    """Returns a default log timing, which logs with exponentially increasing intervals."""
    return EveryIncreasingInterval() 
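# Rough sketch of how a timing object is consumed by the hooks below; that
# ``call_now(step)`` returns a boolean is taken from its usage in this file,
# while the exact exponential schedule is an assumption based on the docstring:
#
#     timing = default_log_timing()
#     fired_steps = [step for step in range(1000) if timing.call_now(step)]
#     # With exponentially increasing intervals, ``fired_steps`` grows sparser.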
class OnStepCallbackWithTiming(Callback):
    """Base class for timed callbacks, which execute at certain intervals.
    They execute on step, during training and validation, at independently specified timings.
    """
[docs]
    def __init__(
        self,
        train_timing: Optional[CallbackTiming] = None,
        val_timing: Optional[CallbackTiming] = None,
        name: str = "auto",
    ) -> None:
        """Initializes the callback. The name is used for logging.
        Args:
            train_timing: The :class:`CallbackTiming` object specifying how often the callback is to be
                called during training.
            val_timing: The :class:`CallbackTiming` object specifying how often the callback is to be
                called during validation.
            name: The name of the callback. If ``"auto"``, the class name is used, minus "Log" if the
                class name starts with that.
        """
        super().__init__()
        self.val_log_timing = val_timing if val_timing is not None else default_log_timing()
        self.train_log_timing = train_timing if train_timing is not None else default_log_timing()
        if name == "auto":
            self.name = self.__class__.__name__
            if self.name.startswith("Log"):
                self.name = self.name[3:]
        else:
            self.name = name

    def execute(
        self,
        trainer: pl.Trainer,
        pl_module: MLDFTLitModule,
        outputs: STEP_OUTPUT,
        batch: Any,
        split: str,
    ) -> None:
        """Executes the callback, e.g. logs a figure to tensorboard.
        Args:
            pl_module: The lightning module.
            trainer: The lightning trainer.
            outputs: The outputs of the lightning module.
            batch: The batch data.
            split: The split, either ``"train"`` or ``"val"``.
        """
        raise NotImplementedError

    def on_train_batch_end(
        self,
        trainer: pl.Trainer,
        pl_module: MLDFTLitModule,
        outputs: STEP_OUTPUT,
        batch: Any,
        batch_idx: int,
    ) -> None:
        """Executes the callback via :meth:`execute`, if the timing matches."""
        if not self.train_log_timing.call_now(trainer.global_step):
            return
        self.execute(
            pl_module=pl_module,
            trainer=trainer,
            outputs=outputs,
            batch=batch,
            split="train",
        )

    def on_validation_batch_end(
        self,
        trainer: pl.Trainer,
        pl_module: MLDFTLitModule,
        outputs: STEP_OUTPUT,
        batch: Any,
        batch_idx: int,
        **kwargs,
    ) -> None:
        """Executes the callback via :meth:`execute`, if the timing matches.
        Can happen only for the first validation batch.
        """
        # only log on the first validation batch, as tensorboard does not show multiple images for the same step anyway
        if batch_idx != 0:
            return
        if not self.val_log_timing.call_now(trainer.global_step):
            return
        self.execute(
            pl_module=pl_module,
            trainer=trainer,
            outputs=outputs,
            batch=batch,
            split="val",
        )
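# Usage sketch (illustration only, not part of this module): ``LogParamNorm`` is
# a hypothetical subclass showing how :meth:`execute` is typically overridden.
# The logged quantity (global parameter norm) and the metric name are assumptions
# made for this example, not names defined elsewhere in mldft.
import torch  # imported here only for the sketch


class LogParamNorm(OnStepCallbackWithTiming):
    """Example subclass: logs the global L2 norm of all model parameters."""

    def execute(
        self,
        trainer: pl.Trainer,
        pl_module: MLDFTLitModule,
        outputs: STEP_OUTPUT,
        batch: Any,
        split: str,
    ) -> None:
        # ``self.name`` resolves to "ParamNorm" here, since the leading "Log" is stripped.
        norm = torch.norm(torch.stack([p.detach().norm() for p in pl_module.parameters()]))
        pl_module.log(f"{split}/{self.name}", norm)


# The subclass is registered like any other Lightning callback, e.g.:
#
#     trainer = pl.Trainer(callbacks=[LogParamNorm()])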