utils

extras(cfg: DictConfig) → None

Applies optional utilities before the task is started.

Utilities:
  • Ignoring Python warnings

  • Setting tags from the command line

  • Rich config printing

Parameters:

cfg – A DictConfig object containing the config tree.
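The following is a minimal sketch of the switch-driven pattern described above. It is not the actual implementation. It assumes the optional switches live in a plain dict under a hypothetical "extras" key, with hypothetical "ignore_warnings" and "print_config" flags. The real function takes a DictConfig, and its config layout is not documented here. Tag handling is omitted, and the standard-library pprint stands in for Rich printing.

```python
import pprint
import warnings
from typing import Any, Dict


def extras(cfg: Dict[str, Any]) -> None:
    """Sketch: apply optional utilities before the task starts."""
    # Hypothetical layout: optional switches live under cfg["extras"].
    flags = cfg.get("extras") or {}
    if not flags:
        return  # no switches configured, nothing to apply

    if flags.get("ignore_warnings"):
        # Silence all Python warnings for the rest of the process.
        warnings.filterwarnings("ignore")

    if flags.get("print_config"):
        # Stand-in for Rich printing: pretty-print with the standard library.
        pprint.pprint(cfg, sort_dicts=False)
```

Each utility is gated by its own flag, so an empty or missing "extras" section makes the call a no-op.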

get_metric_value(metric_dict: Dict[str, Any], metric_name: str | None) → float | None

Safely retrieves the value of a metric logged in a LightningModule.

Parameters:
  • metric_dict – A dict containing metric values.

  • metric_name – If provided, the name of the metric to retrieve.

Returns:

If a metric name was provided, the value of that metric; otherwise None.
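The sketch below illustrates the retrieval logic. It assumes each metric value is either a plain number or a 0-d tensor that exposes .item(), which is typical of metrics logged through Lightning. Raising KeyError for a name that is missing from the dict is an assumption about how "safely" is handled, not documented behavior.

```python
from typing import Any, Dict, Optional


def get_metric_value(metric_dict: Dict[str, Any], metric_name: Optional[str]) -> Optional[float]:
    """Sketch: return the named metric as a float, or None if no name is given."""
    # No metric requested (e.g. no optimization target): nothing to retrieve.
    if not metric_name:
        return None
    # Assumption: a requested but missing metric is treated as an error.
    if metric_name not in metric_dict:
        raise KeyError(f"Metric {metric_name!r} not found; available: {sorted(metric_dict)}")
    value = metric_dict[metric_name]
    # Tensors expose .item(); plain numbers are converted directly.
    return float(value.item()) if hasattr(value, "item") else float(value)
```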

set_default_torch_dtype(dtype)

Decorator that sets the default data type for torch operations.

Sets the default data type while the decorated function runs, then restores the previous default afterwards.

Parameters:

dtype – The default data type to set. Either torch.float32 or torch.float64.
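The following sketches such a decorator. It assumes set_default_torch_dtype is a decorator factory called with the dtype, as in @set_default_torch_dtype(torch.float64). It relies on torch.get_default_dtype and torch.set_default_dtype and restores the previous default in a finally block, so the restore also happens when the decorated function raises. The explicit check that only accepts the two listed dtypes is an assumption.

```python
import functools
from typing import Any, Callable

import torch


def set_default_torch_dtype(dtype: torch.dtype) -> Callable:
    """Sketch: run the decorated function with `dtype` as torch's default."""
    # Assumption: only the two dtypes named in the docs are accepted.
    if dtype not in (torch.float32, torch.float64):
        raise ValueError(f"Unsupported default dtype: {dtype}")

    def decorator(func: Callable) -> Callable:
        @functools.wraps(func)
        def wrapper(*args: Any, **kwargs: Any) -> Any:
            previous = torch.get_default_dtype()
            torch.set_default_dtype(dtype)
            try:
                return func(*args, **kwargs)
            finally:
                # Restore the previous default even if func raises.
                torch.set_default_dtype(previous)

        return wrapper

    return decorator


@set_default_torch_dtype(torch.float64)
def make_zeros() -> torch.Tensor:
    # Factory functions such as torch.zeros pick up the default dtype.
    return torch.zeros(3)
```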

task_wrapper(task_func: Callable) → Callable

Optional decorator that controls the failure behavior when executing the task function.

This wrapper can be used to:
  • make sure loggers are closed even if the task function raises an exception (prevents multirun failure)

  • save the exception to a .log file

  • mark the run as failed with a dedicated file in the logs/ folder (so we can find and rerun it later)

  • etc. (adjust depending on your needs)

Example:

@utils.task_wrapper
def train(cfg: DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]:
    ...
    return metric_dict, object_dict

Parameters:

task_func – The task function to be wrapped.

Returns:

The wrapped task function.
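The sketch below shows one way to implement such a wrapper. It is not the actual implementation. It assumes the task function takes a single config argument, and that the config is a plain dict with a hypothetical "output_dir" key. On failure, it writes the traceback to exception.log and creates a FAILED marker file (both file names are hypothetical), then re-raises. The finally block stands in for closing loggers through a hypothetical close_loggers helper, which here only flushes the root logger's handlers.

```python
import functools
import logging
import traceback
from pathlib import Path
from typing import Any, Callable, Dict, Tuple


def close_loggers() -> None:
    # Hypothetical stand-in for closing experiment loggers;
    # here it only flushes the root logger's stdlib handlers.
    for handler in logging.getLogger().handlers:
        handler.flush()


def task_wrapper(task_func: Callable) -> Callable:
    """Sketch: record failures on disk and always clean up."""

    @functools.wraps(task_func)
    def wrap(cfg: Dict[str, Any]) -> Any:
        # Hypothetical config key: directory for logs and failure markers.
        output_dir = Path(cfg["output_dir"])
        try:
            return task_func(cfg)
        except Exception:
            output_dir.mkdir(parents=True, exist_ok=True)
            # Save the full traceback to a .log file.
            (output_dir / "exception.log").write_text(traceback.format_exc())
            # Mark the run as failed so it can be found and rerun later.
            (output_dir / "FAILED").touch()
            raise
        finally:
            # Runs on success and failure alike.
            close_loggers()

    return wrap


@task_wrapper
def train(cfg: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, Any]]:
    return {"val/acc": 0.9}, {}
```

Because the wrapper re-raises after recording the failure, the caller still sees the original exception. The on-disk markers only add a way to locate failed runs afterwards.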