Source code for mldft.utils.rich_utils

from functools import wraps
from io import StringIO
from pathlib import Path
from typing import Any, Callable, Dict, List, Sequence, Tuple

import rich
import rich.syntax
import rich.tree
from hydra.core.hydra_config import HydraConfig
from lightning_utilities.core.rank_zero import rank_zero_only
from omegaconf import DictConfig, OmegaConf, open_dict
from rich.console import Console
from rich.prompt import Prompt
from rich.table import Table

from mldft.utils.log_utils import pylogger

# Module-level logger used by the functions below.
# NOTE(review): ``rank_zero_only=True`` presumably restricts output to the rank-0
# process in distributed runs — confirm against mldft.utils.log_utils.RankedLogger.
log = pylogger.RankedLogger(__name__, rank_zero_only=True)


def rich_to_str(rich_object: Table | Any, *, width: int = 120) -> str:
    """Converts a Rich object to a string.

    The object is printed through a :class:`rich.console.Console` that writes into
    an in-memory buffer, and the buffer's contents are returned.

    Args:
        rich_object: Any object that ``Console.print`` accepts, e.g. a
            :class:`rich.table.Table`.
        width: Console width in characters used to lay out the output.
            Default is ``120``.

    Returns:
        The text written by ``Console.print`` for ``rich_object``.
    """
    buffer = StringIO()
    console = Console(file=buffer, width=width)
    console.print(rich_object)
    return buffer.getvalue()
def add_as_string_option(func: Callable[..., Any]) -> Callable[..., Any]:
    """Decorator that adds an ``as_string`` option to a function that returns a Rich object.

    The decorated function accepts an extra keyword argument ``as_string``
    (default ``False``). When it is true, the object returned by ``func`` is
    rendered with :func:`rich_to_str` and that string is returned instead.

    Args:
        func: The function to wrap; expected to return a Rich-printable object.

    Returns:
        The wrapped function. ``functools.wraps`` preserves ``func``'s name and
        docstring so introspection and generated docs still see the original.
    """

    @wraps(func)
    def wrapper(*args, **kwargs):
        """Call ``func``, optionally rendering its result to a string."""
        # Pop the flag so it is never forwarded to ``func``, which does not accept it.
        as_string = kwargs.pop("as_string", False)
        result = func(*args, **kwargs)
        if as_string:
            return rich_to_str(result)
        return result

    return wrapper
@add_as_string_option
def format_table_rich(
    *cols: Tuple[str, List[str]],
    col_kwargs: List[Dict[str, Any]] | None = None,
    title: str | None = None,
) -> Table:
    """Formats a table using the Rich library.

    Args:
        *cols: One tuple per column, of the form ``(title, values)`` or
            ``(title, values_section_1, values_section_2, ..)``. All columns must
            have the same number of sections, and within a section all columns
            must have the same number of values.
        col_kwargs: Optional list with one dictionary of keyword arguments per
            column, forwarded to ``Table.add_column``. Must have exactly one entry
            per column.
        title: The title of the table.
        as_string: Whether to return the table as a string. Default is ``False``.
            Handled by the ``add_as_string_option`` decorator.

    Returns:
        A Rich Table object (or its string rendering when ``as_string=True``).

    Raises:
        ValueError: If ``col_kwargs`` does not have one entry per column, if the
            columns have different numbers of sections, or if the columns of a
            section have different lengths.
    """
    table = Table(title=title, style="dim", title_style="dim")
    if not cols:
        # No columns: nothing to lay out (indexing cols[0] below would fail).
        return table

    if col_kwargs is None:
        col_kwargs = [{} for _ in cols]
    elif len(col_kwargs) != len(cols):
        # zip() would silently drop the surplus column headers otherwise.
        raise ValueError(
            f"col_kwargs must have one entry per column: got {len(col_kwargs)} "
            f"for {len(cols)} columns."
        )
    for col, kwargs in zip(cols, col_kwargs):
        table.add_column(col[0], **kwargs)

    n_parts = [len(col) for col in cols]
    if any(n != n_parts[0] for n in n_parts):
        raise ValueError(
            f"All columns must have the same number of sections! Got {[n - 1 for n in n_parts]}."
        )

    # Index 0 of each column is its title; every following entry is one section.
    for sec in range(1, n_parts[0]):
        table.add_section()
        section = [col[sec] for col in cols]
        sec_lengths = [len(values) for values in section]
        if any(length != sec_lengths[0] for length in sec_lengths):
            raise ValueError(f"All sections must have the same length! Got {sec_lengths}.")
        # Transpose the per-column value lists into rows.
        for row in zip(*section):
            table.add_row(*row)
    return table
@rank_zero_only
def enforce_tags(cfg: DictConfig, save_to_file: bool = False) -> None:
    """Prompts user to input tags from command line if no tags are provided in config.

    Args:
        cfg: A DictConfig composed by Hydra. If ``cfg.tags`` is missing or empty,
            it is set to the comma-separated tags entered by the user.
        save_to_file: Whether to export tags to ``tags.log`` in
            ``cfg.paths.output_dir``. Default is ``False``.

    Raises:
        ValueError: If no tags are configured and the Hydra job config contains an
            ``id`` entry, which this function treats as a multirun.
    """
    if not cfg.get("tags"):
        if "id" in HydraConfig().cfg.hydra.job:
            raise ValueError("Specify tags before launching a multirun!")

        log.warning("No tags provided in config. Prompting user to input tags...")
        raw_tags = Prompt.ask("Enter a list of comma separated tags", default="dev")
        # Strip before filtering so inputs like "a, ,b" or "a, " yield no empty tags.
        tags = [tag for tag in (part.strip() for part in raw_tags.split(",")) if tag]

        # open_dict temporarily lifts struct mode so the ``tags`` key can be set.
        with open_dict(cfg):
            cfg.tags = tags

        log.info(f"Tags: {cfg.tags}")

    if save_to_file:
        with open(Path(cfg.paths.output_dir, "tags.log"), "w", encoding="utf-8") as file:
            rich.print(cfg.tags, file=file)