Source code for mldft.api.instantiate_from_args

"""Factories that translate CLI-style arguments into MLDFT runtime components.

The helpers in this module encapsulate how user-supplied configuration is converted into concrete
model checkpoints, optimizers, and input molecule lists.
"""

import os
from dataclasses import dataclass
from pathlib import Path
from typing import Sequence, cast

import torch  # type: ignore[import]

from mldft.ofdft.optimizer import GradientDescent, TorchOptimizer, VectorAdam
from mldft.ofdft.run_density_optimization import SampleGenerator

DEFAULT_DEVICE = "cuda" if torch.cuda.is_available() else "cpu"


@dataclass
class BaseConfig:
    """Configuration required to construct the base OFDFT inputs."""

    xyzfile: tuple[str | Path, ...]
    charge: int = 0
    initialization: str = "minao"
    normalize_initial_guess: bool = True
    save_result: bool = True
    ks_basis: str | None = None
    proj_minao_module: str | None = None
    sad_guess_kwargs: dict | None = None
    disable_printing: bool | None = None

    def __post_init__(self) -> None:
        self.xyzfile = tuple(Path(path) for path in self.xyzfile)


@dataclass
class ModelConfig:
    """Configuration describing the trained MLDFT model to use."""

    model: str
    use_last_ckpt: bool = True
    device: str = DEFAULT_DEVICE
    transform_device: str = "cpu"
    negative_integrated_density_penalty_weight: float = 0.0


@dataclass
class OptimizerConfig:
    """Configuration controlling the density optimization routine."""

    optimizer: str = "gradient-descent-torch"
    max_cycle: int = 10000
    convergence_tolerance: float = 1e-4
    lr: float = 1e-3
    momentum: float = 0.9
    betas: Sequence[float] = (0.9, 0.999)

    def __post_init__(self) -> None:
        self.betas = tuple(self.betas)
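

# Example usage (illustrative sketch; the .xyz path is hypothetical):
#
#     base_args = BaseConfig(xyzfile=("molecules/water.xyz",), charge=0)
#     model_args = ModelConfig(model="str25_qm9", device="cpu")
#     optimizer_args = OptimizerConfig(optimizer="vector-adam", lr=5e-3)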


def get_runpath(name: str) -> Path:
    """Get the path to a named model."""
    dft_models = os.environ.get("DFT_MODELS")
    if dft_models is None:
        raise ValueError(
            "Environment variable DFT_MODELS not set. Please set it to the directory containing the models."
        )
    return Path(dft_models) / "train/runs" / name


NAMED_MODELS = {
    "str25_qm9": get_runpath("trained-on-qm9"),
    "str25_qmugs": get_runpath("trained-on-qmugs"),
}
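
# Note: NAMED_MODELS invokes get_runpath at import time, so DFT_MODELS must be
# set in the environment before this module is imported. Illustrative
# resolution, assuming a hypothetical DFT_MODELS=/data/mldft-models:
#
#     get_runpath("trained-on-qm9")
#     # -> Path("/data/mldft-models/train/runs/trained-on-qm9")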


def get_gradient_descent_optimizer(optimizer_args: OptimizerConfig) -> GradientDescent:
    """Instantiate a simple gradient descent optimizer."""
    return GradientDescent(
        learning_rate=optimizer_args.lr,
        max_cycle=optimizer_args.max_cycle,
        convergence_tolerance=optimizer_args.convergence_tolerance,
    )


def get_gradient_descent_torch_optimizer(
    optimizer_args: OptimizerConfig,
) -> TorchOptimizer:
    """Instantiate a gradient descent optimizer using PyTorch's SGD."""
    return TorchOptimizer(
        torch.optim.SGD,
        lr=optimizer_args.lr,
        momentum=optimizer_args.momentum,
        max_cycle=optimizer_args.max_cycle,
        convergence_tolerance=optimizer_args.convergence_tolerance,
    )


def get_vector_adam_optimizer(optimizer_args: OptimizerConfig) -> VectorAdam:
    """Instantiate a Vector Adam optimizer."""
    return VectorAdam(
        learning_rate=optimizer_args.lr,
        betas=cast("tuple[float, float]", tuple(optimizer_args.betas)),
        max_cycle=optimizer_args.max_cycle,
        convergence_tolerance=optimizer_args.convergence_tolerance,
    )


OPTIMIZER_CHOICES = {
    "gradient-descent": get_gradient_descent_optimizer,
    "gradient-descent-torch": get_gradient_descent_torch_optimizer,
    "vector-adam": get_vector_adam_optimizer,
}


def get_optimizer_from_optimizer_args(
    optimizer_args: OptimizerConfig,
) -> GradientDescent | TorchOptimizer | VectorAdam:
    """Instantiate an optimizer from optimizer arguments."""
    optimizer_name = optimizer_args.optimizer
    if optimizer_name not in OPTIMIZER_CHOICES:
        raise ValueError(
            f"Unknown optimizer {optimizer_name}. Choose from {list(OPTIMIZER_CHOICES.keys())}."
        )
    return OPTIMIZER_CHOICES[optimizer_name](optimizer_args)
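

# Example (illustrative sketch): the optimizer string selects a factory from
# OPTIMIZER_CHOICES; any other name raises a ValueError listing the valid keys.
#
#     optimizer = get_optimizer_from_optimizer_args(
#         OptimizerConfig(optimizer="gradient-descent-torch", lr=1e-3, momentum=0.9)
#     )  # -> TorchOptimizer wrapping torch.optim.SGD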


def get_sample_generator_from_model_args(
    model_args: ModelConfig,
) -> SampleGenerator:
    """Instantiate a SampleGenerator from model arguments."""
    model_name = model_args.model
    if model_name in NAMED_MODELS:
        run_path = NAMED_MODELS[model_name]
    else:
        run_path = Path(model_name)
    return SampleGenerator.from_run_path(
        run_path,
        device=model_args.device,
        transform_device=model_args.transform_device,
        negative_integrated_density_penalty_weight=(
            model_args.negative_integrated_density_penalty_weight
        ),
        use_last_ckpt=model_args.use_last_ckpt,
    )
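

# Example (illustrative sketch): `model` may be a NAMED_MODELS key or a path to
# a run directory; the local path below is hypothetical.
#
#     generator = get_sample_generator_from_model_args(ModelConfig(model="str25_qmugs"))
#     generator = get_sample_generator_from_model_args(
#         ModelConfig(model="runs/my_local_run", use_last_ckpt=False)
#     )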


def get_xyzfiles_from_base_args(base_args: BaseConfig) -> list[Path]:
    """Get a list of XYZ files from base arguments."""
    xyzfiles = []
    for path in base_args.xyzfile:
        if not path.exists():
            raise FileNotFoundError(f"Path {path} does not exist.")
        if path.is_file() and path.suffix == ".xyz":
            xyzfiles.append(path)
        elif path.is_dir():
            xyzfiles.extend(sorted(path.glob("*.xyz")))
    return xyzfiles
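

# Example (illustrative sketch; paths are hypothetical): plain files must carry
# the ".xyz" suffix (an existing file with another suffix is silently skipped),
# while directories expand to their sorted "*.xyz" contents.
#
#     base_args = BaseConfig(xyzfile=("water.xyz", "structures/"))
#     xyzfiles = get_xyzfiles_from_base_args(base_args)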