import warnings
from functools import wraps
from importlib.util import find_spec
from socket import gethostname
from typing import Any, Callable, Dict, Optional, Tuple
import rich
import torch
from omegaconf import DictConfig
from mldft.utils import rich_utils
from mldft.utils.log_utils import pylogger
from mldft.utils.log_utils.config_in_tensorboard import dict_to_tree
log = pylogger.RankedLogger(__name__, rank_zero_only=True)
def task_wrapper(task_func: Callable) -> Callable:
"""Optional decorator that controls the failure behavior when executing the task function.
This wrapper can be used to:
- make sure loggers are closed even if the task function raises an exception (prevents multirun failure)
- save the exception to a `.log` file
- mark the run as failed with a dedicated file in the `logs/` folder (so we can find and rerun it later)
- etc. (adjust depending on your needs)
Example:
.. code-block:: python
@utils.task_wrapper
def train(cfg: DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]:
...
return metric_dict, object_dict
:param task_func: The task function to be wrapped.
:return: The wrapped task function.
"""
    @wraps(task_func)
    def wrap(cfg: DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]:
# execute the task
try:
metric_dict, object_dict = task_func(cfg=cfg)
# things to do if exception occurs
except Exception as ex:
# save exception to `.log` file
log.exception("")
# some hyperparameter combinations might be invalid or cause out-of-memory errors
# so when using hparam search plugins like Optuna, you might want to disable
# raising the below exception to avoid multirun failure
raise ex
# things to always do after either success or exception
finally:
# display output dir path in terminal
log.info(f"Output dir: {cfg.paths.output_dir}")
# always close wandb run (even if exception occurs so multirun won't fail)
if find_spec("wandb"): # check if wandb is installed
import wandb
if wandb.run:
log.info("Closing wandb!")
wandb.finish()
return metric_dict, object_dict
return wrap
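
# --- Usage sketch (illustrative, not part of this module) ---
# A minimal example of decorating a task function with `task_wrapper`, assuming a Hydra-style
# config that provides `paths.output_dir` (the only config key the wrapper itself reads).
# The function name `train` and the metric names are hypothetical.
#
#     from omegaconf import OmegaConf
#
#     @task_wrapper
#     def train(cfg: DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]:
#         metric_dict = {"val/loss": torch.tensor(0.1)}
#         object_dict = {"cfg": cfg}
#         return metric_dict, object_dict
#
#     cfg = OmegaConf.create({"paths": {"output_dir": "logs/example_run"}})
#     metric_dict, object_dict = train(cfg=cfg)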
def set_default_torch_dtype(dtype: torch.dtype) -> Callable:
    """Decorator to set the default data type for torch operations.

    Sets the given dtype as the default for the duration of the decorated function and restores
    the previous default afterwards.

    Args:
        dtype: The default data type to set. Either torch.float32 or torch.float64.
    """
def decorator(func):
@wraps(func)
def wrapper(*args, **kwargs):
# Set default data type at the beginning
previous_dtype = torch.get_default_dtype()
torch.set_default_dtype(dtype)
try:
result = func(*args, **kwargs)
return result
finally:
# Set default data type back to the original at the end
torch.set_default_dtype(previous_dtype)
return wrapper
return decorator
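
# --- Usage sketch (illustrative, not part of this module) ---
# `make_identity` is a hypothetical example function; tensors created inside it use the
# requested default dtype, and the previous global default is restored once the call returns
# (even if the function raises).
#
#     @set_default_torch_dtype(torch.float64)
#     def make_identity(n: int) -> torch.Tensor:
#         return torch.eye(n)  # created with the default dtype -> torch.float64 here
#
#     assert make_identity(3).dtype == torch.float64
#     # Outside the call, the default is whatever it was before (torch.float32 by default).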
def get_metric_value(metric_dict: Dict[str, Any], metric_name: Optional[str]) -> Optional[float]:
"""Safely retrieves value of the metric logged in LightningModule.
:param metric_dict: A dict containing metric values.
:param metric_name: If provided, the name of the metric to retrieve.
    :return: If a metric name was provided, the value of the metric; otherwise None.
"""
if not metric_name:
log.info("Metric name is None! Skipping metric value retrieval...")
return None
if metric_name not in metric_dict:
raise KeyError(
f"Metric value not found! {metric_name=} is not in available metric names: \n "
f"{list(metric_dict.keys())}\n"
"Make sure metric name logged in LightningModule is correct!\n"
"Make sure `optimized_metric` name in `hparams_search` config is correct!"
)
metric_value = metric_dict[metric_name].item()
log.info(f"Retrieved metric value! <{metric_name}={metric_value}>")
return metric_value
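
# --- Usage sketch (illustrative, not part of this module) ---
# `metric_dict` mirrors what a LightningModule typically logs: tensor-valued metrics keyed by
# name. The metric names below are hypothetical.
#
#     metric_dict = {"val/loss": torch.tensor(0.25), "val/acc": torch.tensor(0.9)}
#     get_metric_value(metric_dict, "val/loss")  # -> 0.25
#     get_metric_value(metric_dict, None)        # -> None (retrieval skipped)
#     get_metric_value(metric_dict, "val/f1")    # raises KeyError listing available names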