Source code for mldft.utils.log_utils.config_in_tensorboard

from io import StringIO

import rich
from lightning import Trainer
from lightning.pytorch.loggers import TensorBoardLogger
from omegaconf import DictConfig
from rich.console import Console
from rich.markup import escape
from rich.tree import Tree


def dict_to_tree(
    data: dict | DictConfig, tree: Tree | None = None, name: str = "Config", **tree_kwargs
) -> Tree:
    """Convert a nested dictionary to a Rich Tree.

    Nested mappings (``dict`` or ``DictConfig``) become branches labelled with their key in
    cyan. Every other value becomes a leaf of the form ``key: value``, with the key in yellow.

    Keys and values are passed through :func:`rich.markup.escape` before being embedded in
    the markup labels. Text such as ``"model[small]"`` or ``"[/tmp]"`` in a config is
    therefore shown literally. Without escaping, Rich would parse it as a style tag and either
    drop it from the output or reject it with a ``MarkupError``.

    Args:
        data: The (possibly nested) mapping to convert.
        tree: Existing tree to attach the entries to. If ``None``, a new root tree is created.
        name: Label of the new root tree. Only used when ``tree`` is ``None``.
        **tree_kwargs: Extra keyword arguments forwarded to :class:`rich.tree.Tree` when the
            root tree is created. Ignored when ``tree`` is given.

    Returns:
        The tree the entries were added to: ``tree`` itself, or the newly created root.
    """
    if tree is None:
        tree = Tree(name, **tree_kwargs)
    for key, value in data.items():
        label = escape(str(key))
        if isinstance(value, (dict, DictConfig)):
            # The new branch node becomes the `tree` of the recursive call.
            dict_to_tree(value, tree.add(f"[cyan]{label}[/cyan]"))
        else:
            tree.add(f"[yellow]{label}[/yellow]: {escape(str(value))}")
    return tree


def tree_to_string(tree: Tree, width: int | None = None) -> str:
    """Convert a Rich Tree to a string.

    The tree is rendered by a :class:`rich.console.Console` that writes into an in-memory
    ``StringIO`` buffer instead of the terminal. Because the buffer is not a TTY, Rich
    normally emits no ANSI colour codes.

    Args:
        tree: The tree to render.
        width: Line width, in characters, used for rendering. Rich wraps longer lines.
            If ``None`` (the default, matching the previous behaviour), Rich chooses the
            width itself: from the terminal attached to the current process or the
            ``COLUMNS`` environment variable, falling back to 80. The output can therefore
            differ between runs.

    Returns:
        The rendered tree, terminated by a newline.
    """
    # getvalue() must be called before the context manager closes the buffer.
    with StringIO() as buffer:
        Console(file=buffer, width=width).print(tree)
        return buffer.getvalue()


def log_config_text_to_tensorboard(cfg: DictConfig, trainer: Trainer) -> None:
    """Log the config tree to TensorBoard as text.

    The config is rendered as a Rich tree and wrapped in a Markdown code fence. It is then
    added under the tag ``"hparams"`` to every :class:`TensorBoardLogger` in
    ``trainer.loggers``. Loggers of any other type are not touched.

    Args:
        cfg: A DictConfig composed by Hydra.
        trainer: The Lightning trainer.
    """
    tensorboard_loggers = [
        logger for logger in trainer.loggers if isinstance(logger, TensorBoardLogger)
    ]
    if not tensorboard_loggers:
        # No TensorBoard logger attached: skip rendering the config entirely.
        return

    # Render once and reuse the same text for every TensorBoard logger.
    tree_text = tree_to_string(dict_to_tree(cfg))
    # The code fence keeps the tree's indentation and glyphs intact in TensorBoard's text view.
    hparams_text = f"```\n{tree_text}\n```"
    for tensorboard_logger in tensorboard_loggers:
        tensorboard_logger.experiment.add_text("hparams", hparams_text)