Source code for mldft.utils.utils

import warnings
from functools import wraps
from importlib.util import find_spec
from socket import gethostname
from typing import Any, Callable, Dict, Optional, Tuple

import rich
import torch
from omegaconf import DictConfig

from mldft.utils import rich_utils
from mldft.utils.log_utils import pylogger
from mldft.utils.log_utils.config_in_tensorboard import dict_to_tree

log = pylogger.RankedLogger(__name__, rank_zero_only=True)


def extras(cfg: DictConfig) -> None:
    """Applies optional utilities before the task is started.

    Utilities:

    - Ignoring python warnings
    - Setting tags from command line
    - Rich config printing

    :param cfg: A DictConfig object containing the config tree.
    """
    # return if no `extras` config
    if not cfg.get("extras"):
        log.warning("Extras config not found! <cfg.extras=null>")
        return

    # disable python warnings
    if cfg.extras.get("ignore_warnings"):
        log.info("Disabling python warnings! <cfg.extras.ignore_warnings=True>")
        warnings.filterwarnings("ignore")

    # prompt user to input tags from command line if none are provided in the config
    if cfg.extras.get("enforce_tags"):
        log.info("Enforcing tags! <cfg.extras.enforce_tags=True>")
        rich_utils.enforce_tags(cfg, save_to_file=True)

    if cfg.extras.get("hostname"):
        log.info("Adding hostname to config! <cfg.extras.hostname=True>")
        cfg["extras"]["hostname"] = gethostname()

    # pretty print config tree using Rich library
    if cfg.extras.get("print_config"):
        log.info("Printing config tree with Rich! <cfg.extras.print_config=True>")
        # rich_utils.print_config_tree(cfg, resolve=True, save_to_file=True)
        rich.print(dict_to_tree(cfg, guide_style="dim"))
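
# A minimal usage sketch for `extras` (illustrative config values; in practice the
# DictConfig comes from Hydra):
#
#     from omegaconf import OmegaConf
#
#     cfg = OmegaConf.create(
#         {"extras": {"ignore_warnings": True, "hostname": True, "print_config": True}}
#     )
#     extras(cfg)  # silences warnings, records the hostname, pretty-prints the config
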
def task_wrapper(task_func: Callable) -> Callable:
    """Optional decorator that controls the failure behavior when executing the task function.

    This wrapper can be used to:

    - make sure loggers are closed even if the task function raises an exception (prevents multirun failure)
    - save the exception to a `.log` file
    - mark the run as failed with a dedicated file in the `logs/` folder (so we can find and rerun it later)
    - etc. (adjust depending on your needs)

    Example:

    .. code-block:: python

        @utils.task_wrapper
        def train(cfg: DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]:
            ...
            return metric_dict, object_dict

    :param task_func: The task function to be wrapped.
    :return: The wrapped task function.
    """

    def wrap(cfg: DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]:
        # execute the task
        try:
            metric_dict, object_dict = task_func(cfg=cfg)

        # things to do if exception occurs
        except Exception as ex:
            # save exception to `.log` file
            log.exception("")

            # some hyperparameter combinations might be invalid or cause out-of-memory errors
            # so when using hparam search plugins like Optuna, you might want to disable
            # raising the below exception to avoid multirun failure
            raise ex

        # things to always do after either success or exception
        finally:
            # display output dir path in terminal
            log.info(f"Output dir: {cfg.paths.output_dir}")

            # always close wandb run (even if exception occurs so multirun won't fail)
            if find_spec("wandb"):  # check if wandb is installed
                import wandb

                if wandb.run:
                    log.info("Closing wandb!")
                    wandb.finish()

        return metric_dict, object_dict

    return wrap
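
# A minimal sketch of how a wrapped task behaves (hypothetical `train` task; the
# config is assumed to contain `paths.output_dir`, which the wrapper logs in its
# `finally` block even when the task raises):
#
#     @task_wrapper
#     def train(cfg: DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]:
#         return {"val/loss": torch.tensor(0.1)}, {}
#
#     metric_dict, object_dict = train(cfg=cfg)
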
def set_default_torch_dtype(dtype):
    """Decorator to set the default data type for torch operations.

    Sets the default data type for the decorated function and restores the previous
    default afterwards.

    Args:
        dtype: The default data type to set. Either torch.float32 or torch.float64.
    """

    def decorator(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            # Set default data type at the beginning
            previous_dtype = torch.get_default_dtype()
            torch.set_default_dtype(dtype)
            try:
                result = func(*args, **kwargs)
                return result
            finally:
                # Set default data type back to the original at the end
                torch.set_default_dtype(previous_dtype)

        return wrapper

    return decorator
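
# A minimal usage sketch: tensors created inside the decorated function default to
# float64, and the previous default is restored afterwards, even if the function
# raises. (`make_identity` is a hypothetical example function; the final assert
# assumes the process-wide default was float32 to begin with.)
#
#     @set_default_torch_dtype(torch.float64)
#     def make_identity(n: int) -> torch.Tensor:
#         return torch.eye(n)  # created with the temporary float64 default
#
#     assert make_identity(3).dtype == torch.float64
#     assert torch.get_default_dtype() == torch.float32  # previous default restored
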
def get_metric_value(metric_dict: Dict[str, Any], metric_name: Optional[str]) -> Optional[float]:
    """Safely retrieves the value of a metric logged in a LightningModule.

    :param metric_dict: A dict containing metric values.
    :param metric_name: If provided, the name of the metric to retrieve.
    :return: If a metric name was provided, the value of the metric.
    """
    if not metric_name:
        log.info("Metric name is None! Skipping metric value retrieval...")
        return None

    if metric_name not in metric_dict:
        raise KeyError(
            f"Metric value not found! {metric_name=} is not in available metric names: \n"
            f"{list(metric_dict.keys())}\n"
            "Make sure metric name logged in LightningModule is correct!\n"
            "Make sure `optimized_metric` name in `hparams_search` config is correct!"
        )

    metric_value = metric_dict[metric_name].item()
    log.info(f"Retrieved metric value! <{metric_name}={metric_value}>")

    return metric_value
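
# A minimal usage sketch (hypothetical metric dict; values are tensors, as produced
# by a LightningModule's logged metrics):
#
#     metric_dict = {"val/loss": torch.tensor(0.42)}
#     get_metric_value(metric_dict, "val/loss")  # -> 0.42
#     get_metric_value(metric_dict, None)        # -> None (retrieval skipped)
#     get_metric_value(metric_dict, "val/acc")   # raises KeyError
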