"""Applying a basis transformation to the dataset.
To convert a dataset to a new basis, this script can be run with:
.. code-block::
python mldft/datagen/transform_dataset.py data=qm9 data/transforms=local_frames
You can adapt ``data`` and ``data/transforms`` to your needs. The ``use_cached_data`` flag has to be set to ```false```.
"""
import multiprocessing
from collections import defaultdict
from pathlib import Path
import hydra
import numpy as np
import torch
import zarr
from loguru import logger
from omegaconf import DictConfig
from tqdm import tqdm
import mldft.utils.omegaconf_resolvers # noqa
from mldft.ml.data.components.basis_transforms import MasterTransformation
from mldft.ml.data.components.convert_transforms import ToNumpy, ToTorch
from mldft.ml.data.components.of_data import OFData
from mldft.utils.environ import get_mldft_data_path
from mldft.utils.multiprocess import configure_processes_and_threads
from mldft.utils.utils import set_default_torch_dtype
def is_valid_label(path: Path, spatial_keys_to_check: list[str]) -> bool:
"""Check if a label file is valid. Does not guarantee that the file is not broken, but checks
that most of the necessary data is present.
Args:
path (Path): Path to the label file.
Returns:
bool: True if the label file is valid, False otherwise.
"""
try:
with zarr.open(path, mode="r") as root:
if "of_labels" in root and "ks_labels" in root and "geometry" in root:
if (
"basis" in root["of_labels"]
and "n_scf_steps" in root["of_labels"]
and "spatial" in root["of_labels"]
and "energies" in root["of_labels"]
):
if all(
key in root["of_labels/spatial"].keys() for key in spatial_keys_to_check
):
return True
except Exception as e:
logger.exception(f"Error while checking {path}: {e}")
return False
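# Example (hypothetical path): checking only the top-level groups, as done for the old
# files in ``remove_broken_files`` below, amounts to passing an empty key list:
#
#     is_valid_label(Path("labels/molecule_0001.zarr"), spatial_keys_to_check=[])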
def remove_broken_files(path_list: list[Path], old_label_dir: Path, new_label_dir: Path):
"""Scan the list of zarr files and remove those that are broken, eg. can not be opened or have
no of_labels.
Args:
path_list (list): List of paths to zarr files.
"""
for path in tqdm(path_list, desc="Searching for broken files", leave=False):
relative_path = path.relative_to(new_label_dir)
old_path = old_label_dir / relative_path
if not is_valid_label(old_path, spatial_keys_to_check=[]):
raise ValueError(f"No valid labels in old path {old_path}. Consider removing it.")
# all spatial keys present in the old file should be present in the new file
root = zarr.open(str(old_path), mode="r")
old_spatial_keys = root["of_labels/spatial"].keys()
if not is_valid_label(path, spatial_keys_to_check=old_spatial_keys):
logger.warning(
f"Invalid label in {path}, but found valid label in {old_path}. Removing {path}."
)
path.unlink()
def convert_zarr_file(args):
"""Convert one label file with the new transforms.
Args:
args: Tuple of arguments for the convert function.
"""
(
path,
new_path,
basis_info,
transform,
) = args
logger.info(f"Processing {path.name}")
root = zarr.open(path, mode="r")
    # These checks are not strictly necessary, as we also demand the integrals to be present
    # during data loading, but they are good for peace of mind.
assert (
"dual_basis_integrals" in root["of_labels/spatial"]
), f"dual_basis_integrals not in {path}. Probably the file was generated with an outdated version of mldft"
scf_iterations = root["of_labels/n_scf_steps"][()]
# Copy labels to memory store
temp_store = zarr.MemoryStore()
new_root_temp = zarr.open(temp_store, mode="w")
zarr.copy_all(source=root, dest=new_root_temp)
spatial_keys = root["of_labels/spatial"].keys()
# Delete copied spatial labels, to make sure none are incorrectly left untransformed
of_spatial = new_root_temp.create_group("of_labels/spatial", overwrite=True)
new_of_spatial = defaultdict(list)
for i in range(scf_iterations):
# Load sample, apply transforms, save in list
sample = OFData.from_file_with_all_gradients(path, i, basis_info)
sample = transform(sample)
for key in spatial_keys:
new_of_spatial[key].append(sample[key])
# Stack the lists to tensors and save them
new_of_spatial = {key: np.stack(value) for key, value in new_of_spatial.items()}
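    # One SCF iteration per chunk: presumably this lets downstream loaders read single
    # steps without decompressing the whole trajectory (an assumption about access patterns).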
for key, value in new_of_spatial.items():
of_spatial.create_dataset(key, data=value, chunks=(1, value.shape[-1]), dtype=np.float64)
    # Create the new zip store and move everything from the temp store to the new store.
    # Writing the complete file in one step makes partially written, broken files much less likely.
with zarr.ZipStore(new_path, mode="w") as new_store:
new_root = zarr.open(new_store, mode="w")
zarr.copy_all(source=new_root_temp, dest=new_root)
temp_store.clear()
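# NOTE: ``main`` below calls ``transform_dataset``, which is missing from this file as
# given. The function below is a minimal, hypothetical sketch of the orchestration
# implied by the imports and helpers above; all ``cfg.data.*`` keys, the directory
# layout, and the exact signatures of the ``mldft`` utilities are assumptions, not the
# project's actual implementation.
def transform_dataset(cfg: DictConfig) -> None:
    """Transform all label files of a dataset with the configured transforms (sketch)."""
    set_default_torch_dtype(torch.float64)  # assumed signature; float64 matches the datasets above
    configure_processes_and_threads()  # assumption: limits thread oversubscription in workers

    data_path = get_mldft_data_path()
    old_label_dir = data_path / cfg.data.dataset_name / "labels"  # assumed layout
    new_label_dir = data_path / cfg.data.dataset_name / "labels_transformed"  # hypothetical name
    new_label_dir.mkdir(parents=True, exist_ok=True)

    # Assumption: ``cfg.data.transforms`` instantiates the full pipeline (presumably a
    # MasterTransformation composing ToTorch, the basis transforms, and ToNumpy), and
    # ``cfg.data.basis_info`` the basis info passed to ``OFData.from_file_with_all_gradients``.
    transform = hydra.utils.instantiate(cfg.data.transforms)
    basis_info = hydra.utils.instantiate(cfg.data.basis_info)

    # Drop partially written output files from earlier, interrupted runs, then convert
    # every file that does not yet have a transformed counterpart.
    remove_broken_files(sorted(new_label_dir.rglob("*.zarr")), old_label_dir, new_label_dir)
    tasks = [
        (path, new_label_dir / path.relative_to(old_label_dir), basis_info, transform)
        for path in sorted(old_label_dir.rglob("*.zarr"))
        if not (new_label_dir / path.relative_to(old_label_dir)).exists()
    ]
    with multiprocessing.Pool(processes=cfg.data.num_workers) as pool:
        for _ in tqdm(pool.imap_unordered(convert_zarr_file, tasks), total=len(tasks)):
            pass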
@hydra.main(version_base="1.3", config_path="../../configs/ml", config_name="train.yaml")
def main(cfg: DictConfig) -> None:
"""Hydra entry point for the script."""
transform_dataset(cfg)
if __name__ == "__main__":
main()