"""Source code for mldft.api.setup: interactive setup of MLDFT directories and Hugging Face downloads."""

import os
import sys
from pathlib import Path

from huggingface_hub import hf_hub_download, snapshot_download

# ------------------------------------------------------
# Configuration
# ------------------------------------------------------
# Hugging Face repository that hosts the pretrained models and statistics.
REPO_ID = "sciai-lab/structures25"

# Maps a human-readable dataset label to the model's folder in REPO_ID.
HUGGINGFACE_MODELS = {
    "QM9": "trained-on-qm9",
    "QMUGS": "trained-on-qmugs",
}

# Default directories; each can be overridden by the environment variable that
# `main` tells the user to export (DFT_DATA, DFT_MODELS, DFT_STATISTICS).
DEFAULT_DATA_DIR = Path(os.getenv("DFT_DATA", Path.home() / "dft_data"))
DEFAULT_MODELS_DIR = Path(os.getenv("DFT_MODELS", Path.home() / "dft_models"))
DEFAULT_STATISTICS_DIR = Path(
    os.getenv("DFT_STATISTICS", DEFAULT_MODELS_DIR / "dataset_statistics")
)


def query_yes_no(question, default="yes"):
    """Ask a yes/no question on stdin and return the answer as a bool.

    Args:
        question: Text shown to the user, followed by a ``[y/n]``-style hint.
        default: Answer assumed when the user just presses <Enter>. Must be
            ``"yes"`` (the default), ``"no"``, or ``None`` (meaning an explicit
            answer is required).

    Returns:
        True for "yes"/"ye"/"y", False for "no"/"n". Matching ignores case
        and surrounding whitespace. Any other answer re-prompts.

    Raises:
        ValueError: If ``default`` is not ``"yes"``, ``"no"`` or ``None``.
    """
    valid = {"yes": True, "y": True, "ye": True, "no": False, "n": False}
    if default is None:
        prompt = " [y/n] "
    elif default == "yes":
        prompt = " [Y/n] "
    elif default == "no":
        prompt = " [y/N] "
    else:
        raise ValueError(f"invalid default answer: '{default}'")

    while True:
        sys.stdout.write(question + prompt)
        # Strip so that stray spaces around the answer (e.g. " y") are accepted.
        choice = input().strip().lower()
        if default is not None and choice == "":
            return valid[default]
        if choice in valid:
            return valid[choice]
        sys.stdout.write("Please respond with 'yes' or 'no' (or 'y' or 'n').\n")


def ask_path(prompt: str, default: Path) -> Path:
    """Prompt for a directory path, create it, and return it as an absolute path.

    An empty answer selects ``default``. A leading ``~`` is expanded to the
    user's home directory, so typing ``~/dft_data`` does not create a literal
    ``~`` folder in the current working directory.

    Args:
        prompt: Text shown to the user; the default is appended in brackets.
        default: Directory used when the user just presses <Enter>.

    Returns:
        The chosen directory (created if missing, parents included), resolved.
    """
    user_input = input(f"{prompt} [{default}]: ").strip()
    chosen = Path(user_input) if user_input else Path(default)
    chosen = chosen.expanduser()
    chosen.mkdir(parents=True, exist_ok=True)
    return chosen.resolve()


def download_model(model_name: str, repo_id: str, target_dir: Path):
    """Download one model's hparams and checkpoint from the Hugging Face Hub.

    The repository files ``<model_name>/hparams.yaml`` and
    ``<model_name>/<model_name>.ckpt`` are downloaded with
    ``local_dir=<target_dir>/train/runs``. The checkpoint is then moved to
    ``<target_dir>/train/runs/<model_name>/checkpoints/last.ckpt``.

    Errors are caught and printed rather than raised, so the setup continues.

    Args:
        model_name: Model folder name inside the repository.
        repo_id: Hugging Face repository id.
        target_dir: Root models directory.
    """
    print(f"⬇️ Downloading {model_name} from {repo_id}...")
    model_dir = target_dir / "train" / "runs"
    model_dir.mkdir(parents=True, exist_ok=True)
    try:
        # The hparams file only needs to land on disk; its path is not used here.
        hf_hub_download(
            repo_id=repo_id,
            filename=f"{model_name}/hparams.yaml",
            local_dir=model_dir,
        )
        ckpt_file = hf_hub_download(
            repo_id=repo_id,
            filename=f"{model_name}/{model_name}.ckpt",
            local_dir=model_dir,
        )
        ckpt_dir = model_dir / model_name / "checkpoints"
        ckpt_dir.mkdir(parents=True, exist_ok=True)
        # Path.replace overwrites an existing last.ckpt on every platform
        # (Path.rename raises FileExistsError on Windows), so re-running works.
        Path(ckpt_file).replace(ckpt_dir / "last.ckpt")
        print(f"✅ {model_name} downloaded to {model_dir}")
    except Exception as e:  # best-effort: report and let setup continue
        print(f"❌ Failed to download {model_name}: {e}")


def download_dataset_statistics(repo_id: str, target_dir: Path):
    """Fetch the precomputed dataset-statistics Zarr store from the Hugging Face Hub.

    The download is restricted to the
    ``dataset_statistics_labels_no_basis_transforms_e_kin_plus_xc.zarr`` store
    inside the repository's ``sciai-test-mol/dataset_statistics`` folder. It
    includes the Zarr metadata files (``.zarray``, ``.zattrs``, ``.zgroup``)
    and entries whose names end in ``0``. Files are saved under
    ``target_dir``. Errors are caught and printed rather than raised.

    Args:
        repo_id: Hugging Face repository id.
        target_dir: Directory that receives the downloaded files.
    """
    store_prefix = (
        "sciai-test-mol/dataset_statistics/"
        "dataset_statistics_labels_no_basis_transforms_e_kin_plus_xc.zarr/"
    )
    wanted_files = [
        f"{store_prefix}{pattern}"
        for pattern in ("*0", "*.zarray", "*.zattrs", "*.zgroup")
    ]

    print(f"⬇️ Downloading statistics from {repo_id}...")
    try:
        print(f"Saving to {target_dir}")
        snapshot_download(
            repo_id=repo_id,
            allow_patterns=wanted_files,
            local_dir=target_dir,
        )
        print(f"✅ Dataset statistics downloaded to {target_dir}")
    except Exception as err:  # best-effort: report and let setup continue
        print(f"❌ Failed to download dataset statistics: {err}")


def ask_download_models(models_dir: Path, repo_id: str = REPO_ID):
    """Offer to download every model listed in ``HUGGINGFACE_MODELS``.

    The prompt names the models from the dict's keys, so it stays in sync
    when models are added or removed.

    Args:
        models_dir: Root models directory passed on to ``download_model``.
        repo_id: Hugging Face repository id to download from.
    """
    model_labels = " and ".join(HUGGINGFACE_MODELS)
    question = f"Would you like to download the {model_labels} models from Hugging Face?"
    if not query_yes_no(question, default="yes"):
        print("Skipping model downloads.")
        return
    for repo_folder in HUGGINGFACE_MODELS.values():
        download_model(repo_folder, repo_id, models_dir)


def ask_download_dataset_statistics(statistics_dir: Path, repo_id: str = REPO_ID):
    """Offer to download the dataset statistics used for the SAD guess.

    Args:
        statistics_dir: Directory the statistics are downloaded into.
        repo_id: Hugging Face repository id to download from.
    """
    question = "Would you like to download dataset statistics from Hugging Face for the SAD guess?"
    if not query_yes_no(question, default="yes"):
        print("Skipping dataset statistics downloads.")
        return
    download_dataset_statistics(repo_id, statistics_dir)


def main():
    """Run the interactive MLDFT setup.

    The setup:

    - prompts for the data, models and dataset-statistics directories,
      creating them if needed;
    - prints the matching ``export`` commands;
    - offers to download the pretrained models and dataset statistics from
      Hugging Face.
    """
    import shlex  # only needed to shell-quote the printed export commands

    print("🚀 MLDFT Package Setup\n--------------------")
    data_dir = ask_path("Enter data directory path", DEFAULT_DATA_DIR)
    models_dir = ask_path("Enter models directory path", DEFAULT_MODELS_DIR)
    statistics_dir = ask_path(
        "Enter dataset statistics directory path", DEFAULT_STATISTICS_DIR
    )

    env_vars = (
        ("DFT_DATA", data_dir),
        ("DFT_MODELS", models_dir),
        ("DFT_STATISTICS", statistics_dir),
    )
    print("\n📋 Using the following paths for setup:")
    for name, path in env_vars:
        print(f" - {name} = {path}")

    print("\nTo set these environment variables in your current shell, run:")
    for name, path in env_vars:
        # shlex.quote keeps the command valid for paths with spaces or quotes.
        print(f"export {name}={shlex.quote(str(path))}")
    print(
        "\nTo make them permanent, add the above lines to your shell startup "
        "file (e.g. ~/.bashrc or ~/.zshrc)."
    )

    ask_download_models(models_dir, repo_id=REPO_ID)
    ask_download_dataset_statistics(statistics_dir, repo_id=REPO_ID)
    print("\n✅ Setup complete! 🎉")


if __name__ == "__main__":
    try:
        main()
    except (KeyboardInterrupt, EOFError):
        # Ctrl-C, or Ctrl-D / closed stdin at a prompt, cancels the setup
        # cleanly instead of dumping a traceback.
        print("\nSetup cancelled.")
        sys.exit(1)