Source code for mldft.utils.log_utils.logging_mixin

"""Implements :class:`LoggingMixin`, which can be used to easily log values from any :class:`torch.nn.Module`.

Example:

    >>> class MyLayer(torch.nn.Module, LoggingMixin):
    >>>     def forward(self, x, y):
    >>>         z = x + y
    >>>         # easily log intermediate value
    >>>         self.log(x=x, z_intermediate=z)
    >>>         z = z ** 2
    >>>         self.log(z_final=z)
    >>>         return z
"""

from numbers import Number

import numpy as np
import torch
from torch import Tensor


class LoggingMixin:
    """A mixin class for logging.

    Use in your ``nn.Module`` to log values during training and validation.

    Logging is off until :meth:`activate_logging` is called. While active,
    every call to :meth:`log` merges its (processed) values into
    :attr:`log_dict`; a key logged twice keeps the most recent value.
    :meth:`deactivate_logging` switches logging off again and discards
    everything collected so far.
    """

    def activate_logging(self) -> None:
        """Activate logging."""
        self._logging_active = True

    def deactivate_logging(self) -> None:
        """Deactivate logging and discard all values collected so far."""
        self._logging_active = False
        # clear log_dict
        self._log_dict = dict()

    @property
    def log_dict(self) -> dict:
        """A dictionary of values to log.

        The backing attribute is created lazily so the mixin works without
        requiring an ``__init__`` call.
        """
        if not hasattr(self, "_log_dict"):
            self._log_dict = dict()
        return self._log_dict

    def _process_log_dict(self, log_dict: dict) -> dict:
        """Reduce a dict of raw values to a dict of loggable scalars.

        Plain numbers are passed through unchanged and single-element
        arrays/tensors are unwrapped with ``.item()``; both keep their key.
        Arrays/tensors with any other number of elements are replaced by
        three entries: ``mean_<key>``, ``std_<key>`` and ``abs-max_<key>``.

        Args:
            log_dict: Mapping from name to a number, ``np.ndarray`` or
                ``torch.Tensor``.

        Returns:
            A new dict containing only Python scalars.

        Raises:
            ValueError: If a value is not a number, array or tensor.
        """
        processed_log_dict = {}
        for key, value in log_dict.items():
            if not isinstance(value, (Number, np.ndarray, Tensor)):
                raise ValueError(
                    "Currently, only the logging of numbers, arrays and tensors is implemented"
                )

            if isinstance(value, Number):
                processed_log_dict[key] = value
                continue

            if isinstance(value, np.ndarray):
                value = torch.from_numpy(value)

            if value.numel() == 1:
                processed_log_dict[key] = value.item()
                continue

            # value is a tensor with multiple elements
            value = value.detach()
            # mean()/std() are only defined for floating point and complex
            # tensors; cast integer and boolean inputs (counts, indices,
            # masks, ...) so that logging them does not raise.
            if not (value.is_floating_point() or value.is_complex()):
                value = value.float()

            processed_log_dict[f"mean_{key}"] = value.mean().item()
            processed_log_dict[f"std_{key}"] = value.std().item()
            processed_log_dict[f"abs-max_{key}"] = value.abs().max().item()
        return processed_log_dict

    def log(self, **log_dict: "Number | np.ndarray | Tensor") -> None:
        """Log values, if logging was activated.

        Does nothing (the values are not even inspected) while logging is
        inactive.

        Args:
            **log_dict: Values to log, keyed by the name to log them under.

        Raises:
            ValueError: If logging is active and a value is not a number,
                array or tensor.
        """
        if not getattr(self, "_logging_active", False):
            return
        self.log_dict.update(self._process_log_dict(log_dict))
def _locate_logging_mixins_generator(module: torch.nn.Module, prefix: str = ""):
    """Walk ``module`` depth-first and yield every :class:`LoggingMixin` found.

    Args:
        module: The module at which the walk starts; it is itself a candidate.
        prefix: Dotted path of ``module`` relative to the root of the search.
            Empty for the root itself.

    Yields:
        ``(path, instance)`` pairs, where ``path`` joins the names of the
        nested children with dots (the root has the empty path).
    """
    if isinstance(module, LoggingMixin):
        yield prefix, module

    for child_name, submodule in module.named_children():
        # Only insert the separator once there is something to separate from.
        child_path = f"{prefix}.{child_name}" if prefix else child_name
        yield from _locate_logging_mixins_generator(submodule, child_path)
def locate_logging_mixins(module: torch.nn.Module) -> dict:
    """Collect every LoggingMixin instance inside a module and its (named) children.

    Args:
        module: The root module of the search.

    Returns:
        A dictionary whose keys are the dotted 'paths' within the module and
        whose values are the LoggingMixin instances found at those paths.
    """
    return {path: mixin for path, mixin in _locate_logging_mixins_generator(module)}