from functools import wraps
from importlib.util import find_spec
from typing import Any, Callable, Dict, Optional, Tuple

import torch
from omegaconf import DictConfig

from mldft.utils.log_utils import pylogger

log = pylogger.RankedLogger(__name__, rank_zero_only=True)


def task_wrapper(task_func: Callable) -> Callable:
    """Optional decorator that controls the failure behavior when executing the task function.
    This wrapper can be used to:
        - make sure loggers are closed even if the task function raises an exception (prevents multirun failure)
        - save the exception to a `.log` file
        - mark the run as failed with a dedicated file in the `logs/` folder (so we can find and rerun it later)
        - etc. (adjust depending on your needs)
    Example:
    .. code-block:: python
        @utils.task_wrapper
        def train(cfg: DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]:
            ...
            return metric_dict, object_dict
    :param task_func: The task function to be wrapped.
    :return: The wrapped task function.
    """

    @wraps(task_func)
    def wrap(cfg: DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]:
        # execute the task
        try:
            metric_dict, object_dict = task_func(cfg=cfg)
        # things to do if exception occurs
        except Exception as ex:
            # save exception to `.log` file
            log.exception("")
            # some hyperparameter combinations might be invalid or cause out-of-memory errors
            # so when using hparam search plugins like Optuna, you might want to disable
            # raising the below exception to avoid multirun failure
            raise
        # things to always do after either success or exception
        finally:
            # display output dir path in terminal
            log.info(f"Output dir: {cfg.paths.output_dir}")
            # always close wandb run (even if exception occurs so multirun won't fail)
            if find_spec("wandb"):  # check if wandb is installed
                import wandb
                if wandb.run:
                    log.info("Closing wandb!")
                    wandb.finish()
        return metric_dict, object_dict
    return wrap 
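

# Usage sketch for `task_wrapper` (illustrative only: `train`, the config fields,
# and the metric values below are hypothetical, not part of this module):
#
#     @task_wrapper
#     def train(cfg: DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]:
#         metric_dict = {"val/loss": torch.tensor(0.5)}
#         object_dict = {"cfg": cfg}
#         return metric_dict, object_dict
#
# The wrapped function logs the output dir and closes any open wandb run even
# when `train` raises, so a multirun can continue with the next job.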


def set_default_torch_dtype(dtype):
    """Decorator to set default data type for torch operations.
    Sets the default data type for the decorated function and unsets it afterwards to the previous default.
    Args:
        dtype: The default data type to set. Either torch.float32 or torch.float64.
    """
    def decorator(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            # Set default data type at the beginning
            previous_dtype = torch.get_default_dtype()
            torch.set_default_dtype(dtype)
            try:
                return func(*args, **kwargs)
            finally:
                # Set default data type back to the original at the end
                torch.set_default_dtype(previous_dtype)
        return wrapper
    return decorator 
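

# Usage sketch for `set_default_torch_dtype` (hypothetical example, not part of
# this module):
#
#     @set_default_torch_dtype(torch.float64)
#     def build_identity(n: int) -> torch.Tensor:
#         # tensors created here default to float64
#         return torch.eye(n)
#
#     build_identity(4).dtype        # torch.float64
#     torch.get_default_dtype()      # back to the previous default afterwards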


def get_metric_value(metric_dict: Dict[str, Any], metric_name: Optional[str]) -> Optional[float]:
    """Safely retrieves value of the metric logged in LightningModule.
    :param metric_dict: A dict containing metric values.
    :param metric_name: If provided, the name of the metric to retrieve.
    :return: If a metric name was provided, the value of the metric.
    """
    if not metric_name:
        log.info("Metric name is None! Skipping metric value retrieval...")
        return None
    if metric_name not in metric_dict:
        raise KeyError(
            f"Metric value not found! {metric_name=} is not in available metric names: \n "
            f"{list(metric_dict.keys())}\n"
            "Make sure metric name logged in LightningModule is correct!\n"
            "Make sure `optimized_metric` name in `hparams_search` config is correct!"
        )
    # Lightning logs metrics as tensors, so convert to a plain Python float
    metric_value = metric_dict[metric_name].item()
    log.info(f"Retrieved metric value! <{metric_name}={metric_value}>")
    return metric_value
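

# Usage sketch for `get_metric_value` (hypothetical values; in practice the dict
# typically comes from `trainer.callback_metrics` after a Lightning run):
#
#     metric_dict = {"val/loss": torch.tensor(0.5)}
#     get_metric_value(metric_dict, "val/loss")   # -> 0.5
#     get_metric_value(metric_dict, None)         # -> None, retrieval skipped
#     get_metric_value(metric_dict, "val/acc")    # -> raises KeyError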