"""Callback mechanism for hydra jobs.
Adapted from https://github.com/paquiteau/hydra-callbacks
"""
from __future__ import annotations

import logging
import os
from typing import Any

from hydra.experimental.callback import Callback
from omegaconf import DictConfig, open_dict
# Logger shared by every callback in this module. Its name sits under the
# "hydra" logger hierarchy, so records propagate to handlers attached there.
callback_logger = logging.getLogger("hydra.callbacks")
def dummy_run(config: DictConfig, **kwargs: Any) -> None:
    """Do nothing.

    No-op stand-in for a callback hook. It accepts the same arguments as the
    hooks (a config plus arbitrary keyword arguments) and ignores them all.
    """
class AnyRunCallback(Callback):
    """Abstract callback that executes on any run (single run or multirun).

    Subclasses override ``_on_anyrun_start`` and/or ``_on_anyrun_end``. The
    public Hydra hooks of this class forward both the single-run and the
    multirun variants to those two methods.

    Parameters
    ----------
    enabled
        If False, both ``_on_anyrun_*`` hooks are replaced by a no-op on this
        instance, so the callback does nothing.
    """

    def __init__(self, enabled: bool = True):
        callback_logger.debug("Init %s", self.__class__.__name__)
        self.enabled = enabled
        if not self.enabled:
            # Shadow the hook methods with a no-op on this instance only. An
            # instance attribute takes precedence over the (possibly
            # subclass-overridden) method, so subclasses need no enabled check.
            self._on_anyrun_start = dummy_run  # type: ignore
            self._on_anyrun_end = dummy_run  # type: ignore

    def on_run_start(self, config: DictConfig, **kwargs: Any) -> None:
        """Execute before a single run."""
        callback_logger.debug("run start callback %s", self.__class__.__name__)
        self._on_anyrun_start(config, **kwargs)

    def on_multirun_start(self, config: DictConfig, **kwargs: Any) -> None:
        """Execute before a multirun."""
        callback_logger.debug("(multi)run start callback %s", self.__class__.__name__)
        self._on_anyrun_start(config, **kwargs)

    def _on_anyrun_start(self, config: DictConfig, **kwargs: Any) -> None:
        """Execute before any run; no-op by default, override in subclasses."""

    def on_run_end(self, config: DictConfig, **kwargs: Any) -> None:
        """Execute after a single run."""
        callback_logger.debug("run end callback %s", self.__class__.__name__)
        self._on_anyrun_end(config, **kwargs)

    def on_multirun_end(self, config: DictConfig, **kwargs: Any) -> None:
        """Execute after a multirun."""
        callback_logger.debug("(multi)run end callback %s", self.__class__.__name__)
        self._on_anyrun_end(config, **kwargs)

    def _on_anyrun_end(self, config: DictConfig, **kwargs: Any) -> None:
        """Execute after any run; no-op by default, override in subclasses."""
class GitInfo(AnyRunCallback):
    """Callback that checks git info, logs it, and stores it in the config.

    At the start of any run it logs the current commit sha, branch name and
    dirty state. It then adds them to the config under ``config.git``.

    Parameters
    ----------
    clean
        If True, terminate the process when the repository has uncommitted
        changes or untracked files.
    enabled
        If False, the callback does nothing.
    """

    def __init__(self, clean: bool = False, enabled: bool = True):
        super().__init__(enabled=enabled)
        self.clean = clean

    def _on_anyrun_start(self, config: DictConfig, **kwargs: Any) -> None:
        """Log git sha/branch/dirty state and record them in ``config.git``.

        This is best effort. If GitPython is not installed, the working
        directory is not inside a git repository, or the repository has no
        commit yet, an error is logged and the callback returns without
        touching the config.

        On a detached HEAD the branch is recorded as ``None``.

        If ``clean`` is True and the repository is dirty, the process is
        terminated with exit status 1.
        """
        try:
            import git
        except ImportError:
            callback_logger.error("GitPython is not installed, aborting")
            return

        try:
            repo = git.Repo(search_parent_directories=True)
        except (git.InvalidGitRepositoryError, git.NoSuchPathError):
            callback_logger.error("Not inside a git repository, aborting")
            return

        try:
            sha = repo.head.object.hexsha
        except ValueError:
            # HEAD points to a reference that does not exist yet (no commit).
            callback_logger.error("Git repository has no commit, aborting")
            return

        is_dirty = repo.is_dirty(untracked_files=True)

        try:
            branch_name = repo.active_branch.name
        except TypeError:
            # GitPython raises TypeError when HEAD is detached, e.g. when a
            # specific commit or tag is checked out (typical on CI).
            branch_name = None

        callback_logger.warning(
            "Git sha: %s, branch: %s, dirty: %s", sha, branch_name, is_dirty
        )

        if is_dirty and self.clean:
            callback_logger.error("Repo is dirty, aborting")
            # os._exit terminates the process at once, skipping atexit
            # handlers and finally blocks.
            # NOTE(review): presumably chosen over sys.exit because Hydra may
            # catch the SystemExit raised inside a callback; confirm before
            # changing.
            os._exit(1)

        # Add git info to config (open_dict lifts struct mode so a new key
        # can be created).
        with open_dict(config):
            config.git = {"sha": sha, "branch": branch_name, "is_dirty": is_dirty}