"""Factories that translate CLI-style arguments into MLDFT runtime components.
The helpers in this module encapsulate how user-supplied configuration is converted into concrete
model checkpoints, optimizers, and input molecule lists.
"""
import os
from collections.abc import Sequence
from dataclasses import dataclass
from pathlib import Path
from typing import cast

import torch  # type: ignore[import]

from mldft.ofdft.optimizer import GradientDescent, TorchOptimizer, VectorAdam
from mldft.ofdft.run_density_optimization import SampleGenerator

DEFAULT_DEVICE = "cuda" if torch.cuda.is_available() else "cpu"


@dataclass
class BaseConfig:
    """Configuration required to construct the base OFDFT inputs."""

    xyzfile: tuple[str | Path, ...]
    charge: int = 0
    initialization: str = "minao"
    normalize_initial_guess: bool = True
    save_result: bool = True
    ks_basis: str | None = None
    proj_minao_module: str | None = None
    sad_guess_kwargs: dict | None = None
    disable_printing: bool | None = None

    def __post_init__(self) -> None:
        self.xyzfile = tuple(Path(path) for path in self.xyzfile)
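
# A minimal construction sketch ("water.xyz" is an illustrative filename, not a
# bundled file); __post_init__ normalizes every entry to a Path:
#
#     base = BaseConfig(xyzfile=("water.xyz",), charge=0)
#     base.xyzfile  # -> (PosixPath("water.xyz"),)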


@dataclass
class ModelConfig:
    """Configuration describing the trained MLDFT model to use."""

    model: str
    use_last_ckpt: bool = True
    device: str = DEFAULT_DEVICE
    transform_device: str = "cpu"
    negative_integrated_density_penalty_weight: float = 0.0


@dataclass
class OptimizerConfig:
    """Configuration controlling the density optimization routine."""

    optimizer: str = "gradient-descent-torch"
    max_cycle: int = 10000
    convergence_tolerance: float = 1e-4
    lr: float = 1e-3
    momentum: float = 0.9
    betas: Sequence[float] = (0.9, 0.999)

    def __post_init__(self) -> None:
        self.betas = tuple(self.betas)
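
# For example, the default optimizer name selects PyTorch SGD
# ("gradient-descent-torch"), and __post_init__ normalizes betas to a tuple
# regardless of the sequence type passed in:
#
#     config = OptimizerConfig(optimizer="vector-adam", betas=[0.9, 0.99])
#     config.betas  # -> (0.9, 0.99)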


def get_runpath(name: str) -> Path:
    """Get the path to a named model."""
    dft_models = os.environ.get("DFT_MODELS")
    if dft_models is None:
        raise ValueError(
            "Environment variable DFT_MODELS not set. Please set it to the directory containing the models."
        )
    return Path(dft_models) / "train/runs" / name
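
# A usage sketch (the "my-run" name and /data path are illustrative):
#
#     os.environ["DFT_MODELS"] = "/data/dft_models"
#     get_runpath("my-run")  # -> Path("/data/dft_models/train/runs/my-run")
#
# Note that NAMED_MODELS below calls get_runpath at import time, so DFT_MODELS
# must already be set when this module is imported.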

NAMED_MODELS = {
    "str25_qm9": get_runpath("trained-on-qm9"),
    "str25_qmugs": get_runpath("trained-on-qmugs"),
}


def get_gradient_descent_optimizer(optimizer_args: OptimizerConfig) -> GradientDescent:
    """Instantiate a simple gradient descent optimizer."""
    return GradientDescent(
        learning_rate=optimizer_args.lr,
        max_cycle=optimizer_args.max_cycle,
        convergence_tolerance=optimizer_args.convergence_tolerance,
    )


def get_gradient_descent_torch_optimizer(
    optimizer_args: OptimizerConfig,
) -> TorchOptimizer:
    """Instantiate a gradient descent optimizer using PyTorch's SGD."""
    return TorchOptimizer(
        torch.optim.SGD,
        lr=optimizer_args.lr,
        momentum=optimizer_args.momentum,
        max_cycle=optimizer_args.max_cycle,
        convergence_tolerance=optimizer_args.convergence_tolerance,
    )


def get_vector_adam_optimizer(optimizer_args: OptimizerConfig) -> VectorAdam:
    """Instantiate a Vector Adam optimizer."""
    return VectorAdam(
        learning_rate=optimizer_args.lr,
        betas=cast("tuple[float, float]", tuple(optimizer_args.betas)),
        max_cycle=optimizer_args.max_cycle,
        convergence_tolerance=optimizer_args.convergence_tolerance,
    )
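
# A sketch of direct factory use (the keyword values are illustrative; the
# constructor signature is taken from the call above):
#
#     adam = get_vector_adam_optimizer(
#         OptimizerConfig(optimizer="vector-adam", lr=5e-3, betas=(0.9, 0.99))
#     )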

OPTIMIZER_CHOICES = {
    "gradient-descent": get_gradient_descent_optimizer,
    "gradient-descent-torch": get_gradient_descent_torch_optimizer,
    "vector-adam": get_vector_adam_optimizer,
}


def get_optimizer_from_optimizer_args(
    optimizer_args: OptimizerConfig,
) -> GradientDescent | TorchOptimizer | VectorAdam:
    """Instantiate an optimizer from optimizer arguments."""
    optimizer_name = optimizer_args.optimizer
    if optimizer_name not in OPTIMIZER_CHOICES:
        raise ValueError(
            f"Unknown optimizer {optimizer_name}. Choose from {list(OPTIMIZER_CHOICES.keys())}."
        )
    return OPTIMIZER_CHOICES[optimizer_name](optimizer_args)
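
# Dispatch is purely by string name, e.g.:
#
#     optimizer = get_optimizer_from_optimizer_args(OptimizerConfig())
#     optimizer = get_optimizer_from_optimizer_args(
#         OptimizerConfig(optimizer="vector-adam")
#     )
#
# An unrecognized name such as "lbfgs" raises a ValueError listing the valid
# choices.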


def get_sample_generator_from_model_args(
    model_args: ModelConfig,
) -> SampleGenerator:
    """Instantiate a SampleGenerator from model arguments."""
    model_name = model_args.model
    if model_name in NAMED_MODELS:
        run_path = NAMED_MODELS[model_name]
    else:
        # Fall back to interpreting the model string as a run directory path.
        run_path = Path(model_name)
    return SampleGenerator.from_run_path(
        run_path,
        device=model_args.device,
        transform_device=model_args.transform_device,
        negative_integrated_density_penalty_weight=(
            model_args.negative_integrated_density_penalty_weight
        ),
        use_last_ckpt=model_args.use_last_ckpt,
    )
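
# A minimal sketch: "str25_qm9" resolves through NAMED_MODELS, while any other
# string is treated as a run directory path ("/path/to/run" is a placeholder).
# Loading requires the run's checkpoint files to exist on disk:
#
#     generator = get_sample_generator_from_model_args(ModelConfig(model="str25_qm9"))
#     generator = get_sample_generator_from_model_args(ModelConfig(model="/path/to/run"))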


def get_xyzfiles_from_base_args(base_args: BaseConfig) -> list[Path]:
    """Get a list of XYZ files from base arguments."""
    xyzfiles = []
    for path in base_args.xyzfile:
        if not path.exists():
            raise FileNotFoundError(f"Path {path} does not exist.")
        if path.is_file() and path.suffix == ".xyz":
            xyzfiles.append(path)
        elif path.is_dir():
            # Collect every .xyz file in the directory in sorted order.
            xyzfiles.extend(sorted(path.glob("*.xyz")))
        # Existing files without a .xyz suffix are silently ignored.
    return xyzfiles
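
# For example, with a directory of geometries (the "geoms/" directory is
# illustrative):
#
#     base = BaseConfig(xyzfile=("geoms/",))
#     xyz_paths = get_xyzfiles_from_base_args(base)  # sorted geoms/*.xyz paths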