Source code for mldft.utils.create_subset

import argparse
import pickle
import shutil
from pathlib import Path

import numpy as np
import yaml
from loguru import logger

from mldft.utils.environ import get_mldft_data_path


def _chmod_best_effort(path: Path, mode: int = 0o770) -> None:
    """Try to set ``mode`` on ``path``; log a warning instead of failing."""
    try:
        path.chmod(mode)
    except OSError as e:
        # Permission changes are a convenience (shared data directories);
        # a failure here must not abort the subset creation.
        logger.warning(f"Could not change permissions of {path}: {e}")


def _copy_sample_labels(
    dataset_path: Path, subset_path: Path, label_dirs: list, relative_label_path
) -> bool:
    """Copy one sample's label file from every label directory that has it.

    Returns:
        bool: True if at least one label file was copied for this sample.
    """
    copied_any = False
    for label_dir in label_dirs:
        label_path = dataset_path / label_dir / relative_label_path
        if not label_path.exists():
            logger.warning(f"{label_path} does not exist")
            continue
        new_label_path = subset_path / label_dir / relative_label_path
        new_label_path.parent.mkdir(parents=True, exist_ok=True)
        shutil.copy(label_path, new_label_path)
        copied_any = True
    return copied_any


def _copy_optional_dir(
    dataset_path: Path, subset_path: Path, dir_name: str, description: str
) -> None:
    """Copy ``dataset_path / dir_name`` into the subset if it exists, else warn."""
    source_dir = dataset_path / dir_name
    if source_dir.exists():
        # dirs_exist_ok lets a re-run (e.g. with --override) refresh an
        # already existing subset instead of raising FileExistsError.
        shutil.copytree(source_dir, subset_path / dir_name, dirs_exist_ok=True)
        logger.info(f"Copied {description}.")
    else:
        logger.warning(f"No {description} found.")


def create_subset(
    dataset_path: Path,
    subset_path: Path,
    label_dirs: list,
    split_file_name: str,
    reduction: int | float = 100,
):
    """Create a subset of a dataset.

    For each of the ``train``, ``val`` and ``test`` splits, a fixed-seed random
    selection of ``len(split) / reduction`` samples is drawn. The label files of
    the selected samples are copied from every directory in ``label_dirs``, and
    a new split file (YAML and pickle) listing the copied samples is written to
    ``subset_path``. The ``basis_transformations`` and ``dataset_statistics``
    directories are copied as well if they exist.

    Args:
        dataset_path (Path): Path to the dataset to create a subset of.
        subset_path (Path): Path to the subset directory.
        label_dirs (list): List of label directories in the dataset, from these
            the labels will be copied.
        split_file_name (str): Name of the split file (without the ``.pkl``
            suffix).
        reduction (int | float, optional): Factor to reduce the dataset size by.
            Defaults to 100.

    Raises:
        ValueError: If ``reduction`` is not positive.
    """
    if reduction <= 0:
        raise ValueError(f"reduction must be positive, got {reduction}.")

    # NOTE: pickle.load executes arbitrary code; only use on trusted dataset files.
    with open(dataset_path / f"{split_file_name}.pkl", "rb") as f:
        split = pickle.load(f)

    new_splits = {}
    train_val_test = ["train", "val", "test"]
    for key in train_val_test:
        samples = split[key]
        new_splits[key] = []
        n_samples = len(samples)
        logger.info(f"Old {key} set has {n_samples} geometries.")
        # The generator is deliberately re-created with the same seed for every
        # split so the selection is reproducible and independent of the others.
        generator = np.random.Generator(np.random.PCG64(1))
        # Never request more samples than available (reduction < 1).
        new_size = min(int(n_samples / reduction), n_samples)
        logger.info(f"Sampling {new_size} of those {n_samples} geometries.")
        new_sample_ids = generator.choice(np.arange(n_samples), size=new_size, replace=False)
        for sample_id in new_sample_ids:
            sample = samples[sample_id]
            # sample[1] is used as the label file path relative to a label dir.
            # Register the sample exactly once, no matter how many label
            # directories contain it, so the new split has no duplicates.
            if _copy_sample_labels(dataset_path, subset_path, label_dirs, sample[1]):
                new_splits[key].append(sample)
        logger.info(f"New {key} set has {len(new_splits[key])} geometries.")

    logger.info("Saving split file.")
    # The subset directory may not exist yet if no label file was copied.
    subset_path.mkdir(parents=True, exist_ok=True)

    new_split_yaml = subset_path / split_file_name
    with new_split_yaml.open("w") as f:
        yaml.dump(new_splits, f, sort_keys=False, default_flow_style=None)
    _chmod_best_effort(new_split_yaml)

    pickle_path = subset_path / f"{split_file_name}.pkl"
    with pickle_path.open("wb") as f:
        pickle.dump(new_splits, f, pickle.HIGHEST_PROTOCOL)
    _chmod_best_effort(pickle_path)

    _copy_optional_dir(
        dataset_path, subset_path, "basis_transformations", "basis transformations"
    )
    _copy_optional_dir(dataset_path, subset_path, "dataset_statistics", "dataset statistics")


if __name__ == "__main__":
    # Command-line entry point: build "<dataset>subset" next to the original
    # dataset inside the mldft data directory.
    cli = argparse.ArgumentParser()
    cli.add_argument("dataset", type=str, help="Name of the dataset to create a subset of.")
    cli.add_argument(
        "--reduction",
        type=float,
        default=100,
        help="Factor to reduce the dataset size by.",
    )
    cli.add_argument(
        "--override",
        action="store_true",
        help="Whether to override the existing subset.",
    )
    options = cli.parse_args()

    data_root = get_mldft_data_path()
    source_dir = data_root / options.dataset
    target_dir = data_root / f"{options.dataset}subset"

    # Refuse to touch an existing subset unless explicitly allowed.
    if not options.override and target_dir.exists():
        raise FileExistsError(f"Subset {target_dir} already exists.")

    logger.info(f"Creating subset of {options.dataset} dataset at {target_dir}.")
    logger.info(f"Reducing dataset size by factor {options.reduction}.")

    # Every directory whose name starts with "labels" is treated as a label dir.
    label_dir_names = [entry.name for entry in source_dir.glob("labels*") if entry.is_dir()]
    label_dir_names.sort()

    create_subset(
        source_dir,
        target_dir,
        label_dir_names,
        "train_val_test_split",
        options.reduction,
    )