Source code for mldft.ml.callbacks.custom_scalars

from collections import defaultdict

import lightning as pl
from lightning import Callback
from loguru import logger

from mldft.ml.models.components.loss_function import WeightedLoss


class AddMetricAndLossCustomScalars(Callback):
    """Add custom scalars to the tensorboard logger, for comparison of train and val metrics and losses.

    On fit start, one ``"Multiline"`` chart is registered per train metric and per loss term, each
    plotting the ``train`` key together with its ``val`` counterpart.
    """

    @staticmethod
    def _get_group(metric_key: str) -> str:
        """Map a metric key to its custom-scalar group name.

        e.g. 'train_metrics_per_electron/total_loss' -> 'metrics per electron'.
        """
        # Only drop the first "train" (the split prefix); a blanket replace would also mangle
        # prefixes that merely contain the substring "train".
        return metric_key.split("/")[0].replace("_", " ").replace("train", "", 1).strip()

    @staticmethod
    def _to_val_key(train_key: str) -> str:
        """Return the val counterpart of a train key, e.g. 'train_loss/x' -> 'val_loss/x'."""
        # Replace only the first occurrence, so names such as 'train_metrics/constrained_loss'
        # map to 'val_metrics/constrained_loss' instead of 'val_metrics/consvaled_loss'.
        # NOTE(review): assumes val keys mirror train keys with the "train" prefix swapped.
        return train_key.replace("train", "val", 1)

    def on_fit_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None:
        """Adds the custom scalars.

        If ``pl_module.tensorboard_logger`` is ``None``, a warning is logged and nothing is added.

        Args:
            trainer: The Lightning trainer (not used here).
            pl_module: The module being fit; this method reads its ``tensorboard_logger``,
                ``train_metrics`` and ``loss_function`` attributes.
        """
        tb_logger = pl_module.tensorboard_logger
        if tb_logger is None:
            logger.warning("No tensorboard logger found. Skipping addition of custom scalars.")
            return

        layout = defaultdict(dict)

        # metrics: one chart per train metric, grouped by the (train-stripped) key prefix
        for key in pl_module.train_metrics.keys():
            layout[self._get_group(key)][key.split("/")[-1]] = [
                "Multiline",
                [key, self._to_val_key(key)],
            ]

        # losses: the total loss, plus one chart per term of a weighted loss
        layout["losses"]["total"] = ["Multiline", ["train_loss/total", "val_loss/total"]]
        if isinstance(pl_module.loss_function, WeightedLoss):
            for key in pl_module.loss_function.weight_dict:
                layout["losses"][key] = ["Multiline", [f"train_loss/{key}", f"val_loss/{key}"]]

        # hand over a plain dict so the consumer cannot auto-create groups by lookup
        tb_logger.add_custom_scalars(dict(layout))