Source code for mldft.ml.callbacks.timing

"""Provides `CallbackTiming` to configure when to call a callback.

The `CallbackTiming` class is used to configure when a callback should be called.
This is useful for expensive callbacks that should not be called every step, such as image logging.
"""


[docs] class CallbackTiming: """Base class for callback timings."""
[docs] def call_now(self, global_step: int) -> bool: """Returns True if the callback should be called now. Args: global_step: The current global step. Returns: True if the callback should be called now. """ raise NotImplementedError
[docs] class EveryNSteps(CallbackTiming): """Calls the callback every `n_steps` steps."""
[docs] def __init__(self, n_steps: int) -> None: """Initializes the EveryNSteps object. Args: n_steps: The number of steps between calls. """ self.n_steps = n_steps self.next_log = 0
[docs] def call_now(self, global_step: int) -> bool: """Returns True if the callback should be called now, which is the case if the global step is larger than `self.next_log`, the next scheduled call. .. Note:: If `call_now` is not called every step, the callback might not be called at all the scheduled steps at even intervals. This happens e.g. when using this at validation time. Args: global_step: The current global step. Returns: True if the callback should be called now. """ if global_step >= self.next_log: self.next_log = global_step + self.n_steps return True else: return False
[docs] class EveryIncreasingInterval(CallbackTiming): """Callback timing with an exponentially increasing interval between calls."""
[docs] def __init__( self, initial_log_step=0, initial_interval: int = 10, slow_down_factor: float = 1.1 ) -> None: """Calls the callback every `log_interval` steps, increasing the interval by `slow_down_factor` every time. This results in exponentially increasing intervals between calls, and log files that grow logarithmically. Args: initial_log_step: The first step at which the callback should be called. initial_interval: The initial number of steps between calls. slow_down_factor: The factor by which the interval is increased after each call. """ self.log_next = initial_log_step self.log_interval = initial_interval self.slow_down_factor = slow_down_factor
[docs] def call_now(self, global_step: int) -> bool: """Returns True if the callback should be called now, which is the case if the global step is larger than `self.next_log`, the next scheduled call. .. Note:: If `call_now` is not called every step, the callback might not be called at all the scheduled steps at exponentially increasing intervals. This happens e.g. when using this at validation time. Args: global_step: The current global step. Returns: True if the callback should be called now. """ if global_step >= self.log_next: self.log_next = global_step + self.log_interval self.log_interval = self.log_interval * self.slow_down_factor return True else: return False