from typing import Dict, Union
import lightning as pl
import torch
from lightning import Callback
from lightning.pytorch.utilities import grad_norm
from loguru import logger
from torch.nn import Module
from torch.optim import Optimizer
class LogGradientNorm(Callback):
"""Log total and per-parameter gradient norm at every training step."""
    # maybe a good idea for the future; for now, just log every iteration
    # def __init__(self):
    #     super().__init__()
    #     self.total_grad_norm_history = []
    #     self.last_detailed_logging_step = None
    #
    # def log_everything_this_step(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", optimizer: Optimizer):
    #     """Whether all individual gradient norms should be logged this step.
    #
    #     Returns True in the first step, if the total gradient norm is at least 10% larger than the
    #     running maximum over the last 100 steps, or if that was the case in the previous step or 10 steps ago.
    #     """
    #     if len(self.total_grad_norm_history) == 1:
    #         return True
    #     if trainer.global_step - self.last_detailed_logging_step in [1, 10]:
    #         return True
    #     if self.total_grad_norm_history[-1] > 1.1 * max(self.total_grad_norm_history[-100:-1]):
    #         return True
    #     return False
def on_before_optimizer_step(
self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", optimizer: Optimizer
) -> None:
"""Log the gradients before the optimizer step."""
with torch.no_grad():
norm_dict = grad_norm(pl_module, norm_type=2)
# self.total_grad_norm_history.append(norm_dict["grad_2.0_norm_total"].item())
tb_logger = pl_module.tensorboard_logger
if tb_logger is None:
logger.warning("No tensorboard logger found. Skipping logging of gradient norms.")
return
# log total gradient norm in every step
tb_logger.add_scalar(
"grad_2.0_norm_total",
norm_dict["grad_2.0_norm_total"],
global_step=trainer.global_step,
)
if True: # self.log_everything_this_step(trainer, pl_module, optimizer):
norm_dict.pop("grad_2.0_norm_total") # logged this already above
for key, value in norm_dict.items():
tb_logger.add_scalar(key, value, global_step=trainer.global_step)
# self.last_detailed_logging_step = trainer.global_step
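
# Usage sketch (illustrative; ``MyLightningModule`` is hypothetical): ``LogGradientNorm`` assumes the
# LightningModule exposes a ``tensorboard_logger`` attribute (e.g. a ``SummaryWriter`` such as
# ``trainer.logger.experiment``); this is not a standard Lightning attribute.
#
#     model = MyLightningModule()  # sets ``self.tensorboard_logger`` somewhere, e.g. in ``setup``
#     trainer = pl.Trainer(callbacks=[LogGradientNorm()])
#     trainer.fit(model)
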
def parameter_norm(
module: Module,
norm_type: Union[float, int, str],
group_separator: str = "/",
    learnable_only: bool = True,
) -> Dict[str, float]:
"""Compute each parameter's norm and their overall norm.
The overall norm is computed over all parameters together, as if they
were concatenated into a single vector.
    Based on :func:`lightning.pytorch.utilities.grad_norm`.
Args:
module: :class:`torch.nn.Module` to inspect.
norm_type: The type of the used p-norm, cast to float if necessary.
Can be ``'inf'`` for infinity norm.
group_separator: The separator string used by the logger to group
the parameter norms in their own subfolder instead of the logs one.
learnable_only: Whether to only consider parameters that have a gradient.
Return:
norms: The dictionary of p-norms of each parameter and
a special entry for the total p-norm of the parameters viewed
as a single vector.
"""
with torch.no_grad():
norm_type = float(norm_type)
if norm_type <= 0:
raise ValueError(
f"`norm_type` must be a positive number or 'inf' (infinity norm). Got {norm_type}"
)
norms = {
f"parameter_{norm_type}_norm{group_separator}{name}": p.data.norm(norm_type)
for name, p in module.named_parameters()
if ((p.grad is not None) if learnable_only else True)
}
if norms:
total_norm = torch.tensor(list(norms.values())).norm(norm_type)
norms[f"parameter_{norm_type}_norm_total"] = total_norm
return norms
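
# Note (illustration): for any p-norm, the norm of the concatenated parameter vector equals the
# p-norm of the vector of per-parameter p-norms, which is why the total above can be computed from
# ``list(norms.values())``. For example, a bare ``torch.nn.Linear(3, 2)`` with populated gradients
# (or with ``learnable_only=False``) yields the keys ``parameter_2.0_norm/weight``,
# ``parameter_2.0_norm/bias`` and ``parameter_2.0_norm_total``.
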
class LogParameterNorm(Callback):
"""Log total and per-parameter norm at every training step."""
def on_before_optimizer_step(
self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", optimizer: Optimizer
) -> None:
"""Log the parameter norms before the optimizer step."""
tb_logger = pl_module.tensorboard_logger
if tb_logger is None:
logger.warning("No tensorboard logger found. Skipping logging of parameter norms.")
return
parameter_norm_dict = parameter_norm(pl_module, norm_type=2)
for key, value in parameter_norm_dict.items():
tb_logger.add_scalar(key, value, global_step=trainer.global_step)
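
# Minimal illustrative smoke test (a sketch using a toy model): run the module directly to see the
# keys and values that ``parameter_norm`` produces.
if __name__ == "__main__":
    _model = torch.nn.Sequential(torch.nn.Linear(4, 3), torch.nn.ReLU(), torch.nn.Linear(3, 1))
    # Populate gradients first: with the default ``learnable_only=True`` and no backward pass,
    # ``p.grad is None`` for every parameter and the returned dict would be empty.
    _model(torch.randn(2, 4)).sum().backward()
    for _key, _value in parameter_norm(_model, norm_type=2).items():
        print(f"{_key}: {float(_value):.4f}")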