utils
- extras(cfg: DictConfig) -> None
Applies optional utilities before the task is started.
- Utilities:
Ignoring Python warnings
Setting tags from the command line
Rich config printing
- Parameters:
cfg – A DictConfig object containing the config tree.
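A minimal sketch of how such an optional-utilities hook can be structured, assuming a plain dict in place of a DictConfig and a hypothetical `extras` section with the boolean keys `ignore_warnings`, `set_tags` and `print_config`. The `cli_tags` argument is also hypothetical, and a JSON dump stands in for Rich printing. This is an illustration of the pattern, not the template's actual implementation.

```python
import json
import warnings
from typing import Any, Dict, List, Optional


def extras_sketch(cfg: Dict[str, Any], cli_tags: Optional[List[str]] = None) -> List[str]:
    """Apply optional utilities selected by a hypothetical cfg["extras"] section.

    Returns the names of the utilities that ran, in order.
    """
    extras = cfg.get("extras") or {}
    applied: List[str] = []

    if extras.get("ignore_warnings"):
        # Suppress every Python warning for the rest of the process.
        warnings.filterwarnings("ignore")
        applied.append("ignore_warnings")

    if extras.get("set_tags") and cli_tags:
        # Hypothetical: tags parsed from the command line are stored in the config.
        cfg["tags"] = list(cli_tags)
        applied.append("set_tags")

    if extras.get("print_config"):
        # Plain JSON dump as a stand-in for Rich tree printing.
        print(json.dumps(cfg, indent=2, default=str))
        applied.append("print_config")

    return applied
```

Each utility is opt-in, so an empty or missing `extras` section leaves the run untouched.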
- get_metric_value(metric_dict: Dict[str, Any], metric_name: str | None) -> float | None
Safely retrieves the value of a metric logged in a LightningModule.
- Parameters:
metric_dict – A dict containing metric values.
metric_name – If provided, the name of the metric to retrieve.
- Returns:
If a metric name was provided, the value of the metric; otherwise None.
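A minimal sketch of the retrieval logic described above. It assumes that logged values are either plain numbers or tensor-like objects exposing `.item()`, and that a missing metric name is an error; raising `KeyError` here is an assumption of this sketch, not documented behavior.

```python
from typing import Any, Dict, Optional


def get_metric_value_sketch(metric_dict: Dict[str, Any], metric_name: Optional[str]) -> Optional[float]:
    """Return metric_dict[metric_name] as a float, or None when no name is given."""
    if not metric_name:
        # No metric requested (e.g. no hyperparameter search): nothing to return.
        return None
    if metric_name not in metric_dict:
        # Fail loudly so a misspelled metric name is not silently ignored.
        raise KeyError(f"Metric {metric_name!r} not found; available: {sorted(metric_dict)}")
    value = metric_dict[metric_name]
    # Tensors expose .item(); plain Python numbers are converted directly.
    return float(value.item()) if hasattr(value, "item") else float(value)
```

Returning None for a missing name lets callers that optimize nothing skip the lookup without special-casing.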
- set_default_torch_dtype(dtype)
Decorator that sets the default data type for torch operations.
Sets the default data type while the decorated function runs and restores the previous default afterwards.
- Parameters:
dtype – The default data type to set. Either torch.float32 or torch.float64.
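The save, set and restore pattern behind this decorator can be sketched as follows. To keep the example runnable without PyTorch, this hypothetical `with_temporary_default` generalizes the documented single `dtype` argument into a getter, a setter and a value; with PyTorch you would pass `torch.get_default_dtype`, `torch.set_default_dtype` and `torch.float64`.

```python
import functools
from typing import Any, Callable, TypeVar

F = TypeVar("F", bound=Callable[..., Any])


def with_temporary_default(
    getter: Callable[[], Any], setter: Callable[[Any], None], value: Any
) -> Callable[[F], F]:
    """Decorator factory: set a global default while the function runs, then restore it."""

    def decorator(func: F) -> F:
        @functools.wraps(func)
        def wrapper(*args: Any, **kwargs: Any) -> Any:
            previous = getter()  # remember the current default
            setter(value)
            try:
                return func(*args, **kwargs)
            finally:
                # Restore the previous default even if func raises.
                setter(previous)

        return wrapper  # type: ignore[return-value]

    return decorator


# With PyTorch (not executed here), the dtype case would look like:
# @with_temporary_default(torch.get_default_dtype, torch.set_default_dtype, torch.float64)
# def compute(): ...
```

The `try`/`finally` is the key design choice: without it, an exception inside the decorated function would leave the global default changed for all later code.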
- task_wrapper(task_func: Callable) -> Callable
Optional decorator that controls the failure behavior when executing the task function.
- This wrapper can be used to:
make sure loggers are closed even if the task function raises an exception (prevents multirun failure)
save the exception to a .log file
mark the run as failed with a dedicated file in the logs/ folder (so we can find and rerun it later)
etc. (adjust depending on your needs)
Example:
    @utils.task_wrapper
    def train(cfg: DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]:
        ...
        return metric_dict, object_dict
- Parameters:
task_func – The task function to be wrapped.
- Returns:
The wrapped task function.
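A minimal sketch of a wrapper with the failure behavior listed above, assuming a plain dict config with a hypothetical `output_dir` key. On failure it writes the traceback to `exception.log`, creates a `FAILED` marker file (both file names are hypothetical), and re-raises; the `finally` block is where a real implementation would close experiment loggers, and here it only logs a message.

```python
import functools
import logging
import traceback
from pathlib import Path
from typing import Any, Callable, Dict, Tuple

log = logging.getLogger(__name__)


def task_wrapper_sketch(task_func: Callable) -> Callable:
    """Wrap a task so failures are recorded on disk before being re-raised."""

    @functools.wraps(task_func)
    def wrap(cfg: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, Any]]:
        output_dir = Path(cfg["output_dir"])  # hypothetical config key
        try:
            return task_func(cfg=cfg)
        except Exception:
            output_dir.mkdir(parents=True, exist_ok=True)
            # Save the full traceback so the failure can be inspected later.
            (output_dir / "exception.log").write_text(traceback.format_exc())
            # Marker file that makes failed runs easy to find and rerun.
            (output_dir / "FAILED").touch()
            raise  # keep the failure visible to the caller
        finally:
            # Runs on success and on failure; a real implementation
            # would close experiment loggers here.
            log.info("Task finished, closing loggers.")

    return wrap
```

Re-raising after recording the failure keeps the exit status honest, while the files on disk let a multirun sweep identify which jobs to repeat.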