from typing import Dict, Union
import lightning as pl
import torch
from lightning import Callback
from lightning.pytorch.utilities import grad_norm
from loguru import logger
from torch.nn import Module
from torch.optim import Optimizer
class LogGradientNorm(Callback):
    """Log total and per-parameter gradient norm at every training step."""
    # maybe a good idea for the future; for now, just log every iteration
    # def __init__(self):
    #     super().__init__()
    #     self.total_grad_norm_history = []
    #     self.last_detailed_logging_step = None
    # def log_everything_this_step(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", optimizer: Optimizer):
    #     """Whether all individual gradient norms should be logged this step.
    #
    #     Returns True in the first step, if the total gradient norm is at least 10% larger
    #     than the running maximum over the last 100 steps, or if that was the case in the
    #     previous step or 10 steps ago.
    #     """
    #     if len(self.total_grad_norm_history) == 1:
    #         return True
    #     if trainer.global_step - self.last_detailed_logging_step in [1, 10]:
    #         return True
    #     if self.total_grad_norm_history[-1] > 1.1 * max(self.total_grad_norm_history[-100:-1]):
    #         return True
    #     return False
    def on_before_optimizer_step(
        self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", optimizer: Optimizer
    ) -> None:
        """Log the gradients before the optimizer step."""
        with torch.no_grad():
            norm_dict = grad_norm(pl_module, norm_type=2)
        # self.total_grad_norm_history.append(norm_dict["grad_2.0_norm_total"].item())
        tb_logger = pl_module.tensorboard_logger
        if tb_logger is None:
            logger.warning("No tensorboard logger found. Skipping logging of gradient norms.")
            return
        # log total gradient norm in every step
        tb_logger.add_scalar(
            "grad_2.0_norm_total",
            norm_dict["grad_2.0_norm_total"],
            global_step=trainer.global_step,
        )
        if True:  # self.log_everything_this_step(trainer, pl_module, optimizer):
            norm_dict.pop("grad_2.0_norm_total")  # already logged above
            for key, value in norm_dict.items():
                tb_logger.add_scalar(key, value, global_step=trainer.global_step)

            # self.last_detailed_logging_step = trainer.global_step
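
    # What this callback writes each step, sketched: `grad_norm(pl_module, norm_type=2)`
    # returns a dict of per-parameter gradient norms plus a total (the tag format follows
    # Lightning's `grad_norm`; the parameter names and values below are illustrative), and
    # every entry is forwarded to TensorBoard under the same tag:
    #
    #     {
    #         "grad_2.0_norm/encoder.weight": tensor(1.2345),
    #         "grad_2.0_norm/encoder.bias": tensor(0.0678),
    #         "grad_2.0_norm_total": tensor(1.2364),
    #     }
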
def parameter_norm(
    module: Module,
    norm_type: Union[float, int, str],
    group_separator: str = "/",
    learnable_only: bool = True,
) -> Dict[str, float]:
    """Compute each parameter's norm and their overall norm.
    The overall norm is computed over all parameters together, as if they
    were concatenated into a single vector.
    Based on :class:`lightning.pytorch.utilities.grad_norm`.
    Args:
        module: :class:`torch.nn.Module` to inspect.
        norm_type: The type of the used p-norm, cast to float if necessary.
            Can be ``'inf'`` for infinity norm.
        group_separator: The separator string used by the logger to group
            the parameter norms in their own subfolder instead of the logs one.
        learnable_only: Whether to only consider parameters that have a gradient.
    Return:
        norms: The dictionary of p-norms of each parameter and
            a special entry for the total p-norm of the parameters viewed
            as a single vector.
    """
    with torch.no_grad():
        norm_type = float(norm_type)
        if norm_type <= 0:
            raise ValueError(
                f"`norm_type` must be a positive number or 'inf' (infinity norm). Got {norm_type}"
            )
        norms = {
            f"parameter_{norm_type}_norm{group_separator}{name}": p.data.norm(norm_type)
            for name, p in module.named_parameters()
            if ((p.grad is not None) if learnable_only else True)
        }
        if norms:
            total_norm = torch.tensor(list(norms.values())).norm(norm_type)
            norms[f"parameter_{norm_type}_norm_total"] = total_norm
    return norms 
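
# Usage sketch for `parameter_norm` (doctest-style, values illustrative): on a freshly
# constructed module no parameter has a gradient yet, so pass ``learnable_only=False``
# to include all parameters.
#
#     >>> import torch.nn as nn
#     >>> net = nn.Linear(4, 2)
#     >>> norms = parameter_norm(net, norm_type=2, learnable_only=False)
#     >>> sorted(norms)
#     ['parameter_2.0_norm/bias', 'parameter_2.0_norm/weight', 'parameter_2.0_norm_total']
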
class LogParameterNorm(Callback):
    """Log total and per-parameter norm at every training step."""
    def on_before_optimizer_step(
        self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", optimizer: Optimizer
    ) -> None:
        """Log the parameter norms before the optimizer step."""
        tb_logger = pl_module.tensorboard_logger
        if tb_logger is None:
            logger.warning("No tensorboard logger found. Skipping logging of parameter norms.")
            return
        parameter_norm_dict = parameter_norm(pl_module, norm_type=2)
        for key, value in parameter_norm_dict.items():
            tb_logger.add_scalar(key, value, global_step=trainer.global_step)
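
# A minimal wiring sketch for both callbacks. They assume the LightningModule exposes
# a `tensorboard_logger` attribute (e.g. a `torch.utils.tensorboard.SummaryWriter`);
# that attribute is a convention of this codebase, not something Lightning provides.
# `MyLightningModule`, `train_loader`, and the log directory are placeholders.
#
#     from torch.utils.tensorboard import SummaryWriter
#
#     model = MyLightningModule()
#     model.tensorboard_logger = SummaryWriter(log_dir="runs/norm_logging")
#     trainer = pl.Trainer(callbacks=[LogGradientNorm(), LogParameterNorm()])
#     trainer.fit(model, train_dataloaders=train_loader)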