import functools
import os
from abc import ABC, abstractmethod
from collections.abc import Callable
from typing import Type
import numpy as np
import torch
from scipy.optimize import LinearConstraint, minimize
from torch_geometric.nn import global_add_pool
from tqdm import tqdm
from mldft.ml.data.components.of_data import OFData
from mldft.ml.models.components.loss_function import project_gradient
from mldft.ofdft.energies import Energies
class Optimizer(ABC):
    """Abstract interface shared by the density optimization algorithms."""

    @abstractmethod
    def optimize(
        self,
        sample: OFData,
        energy_functional: Callable[[OFData], tuple[Energies, torch.Tensor]],
        callback: Callable | None = None,
        disable_pbar: bool = False,
    ) -> tuple[Energies, bool]:
        """Run the density optimization.

        Args:
            sample: The OFData holding the initial coefficients.
            energy_functional: Callable returning the energy and the gradient vector for a
                sample.
            callback: Optional callback function.
            disable_pbar: Whether to disable the progress bar.

        Returns:
            A tuple of the final energy and a boolean flag. The concrete optimizers in this
            module use the flag to report whether the gradient norm dropped below their
            convergence tolerance.
        """
        raise NotImplementedError

    def __str__(self) -> str:
        """Return ``ClassName(attr=value, ...)`` built from the instance attributes."""
        pairs = [f"{key}={value}" for key, value in vars(self).items()]
        joined = ", ".join(pairs)
        return f"{type(self).__name__}({joined})"
def get_pbar_str(sample: OFData, energy: Energies, gradient_norm: float) -> str:
    """Build the description string shown in the tqdm progress bar.

    Args:
        sample: The OFData containing the current coefficients. If it holds the ground state
            electronic energy label, the difference to that energy is reported instead of the
            absolute electronic energy.
        energy: The current energy.
        gradient_norm: The norm of the gradient vector.

    Returns:
        A string for the tqdm progress bar.
    """
    label_key = "of_labels/energies/ground_state_e_electron"

    # No reference available: report the absolute electronic energy.
    if label_key not in sample:
        return f"E_elec={energy.electronic_energy:.6f} Ha, grad_norm={gradient_norm:.3e}"

    reference = sample[label_key]
    if isinstance(reference, torch.Tensor):
        reference = reference.item()
    # The factor 1e3 converts the difference from Ha to mHa.
    delta_e = (energy.electronic_energy - reference) * 1e3
    return f"ΔE={delta_e:.3e} mHa, grad_norm={gradient_norm:.3e}"
class GradientDescent(Optimizer):
    """Simple gradient descent optimizer."""

    def __init__(self, learning_rate: float, convergence_tolerance: float, max_cycle: int):
        """Initialize the gradient descent optimizer.

        Args:
            learning_rate: The learning rate.
            convergence_tolerance: Optimization stops if the gradient norm is below this value.
            max_cycle: Maximum number of optimization cycles.
        """
        self.learning_rate = learning_rate
        self.convergence_tolerance = convergence_tolerance
        self.max_cycle = max_cycle

    def optimize(
        self,
        sample: OFData,
        energy_functional: Callable,
        callback: Callable | None = None,
        disable_pbar: bool = False,
    ) -> tuple[Energies, bool]:
        """Perform gradient descent optimization.

        The coefficients in ``sample`` are updated in place with the projected gradient.

        Args:
            sample: The OFData containing the initial coefficients.
            energy_functional: Callable which returns the energy and gradient vector.
            callback: Optional callback. It is called once per cycle with the dictionary of
                local variables of this method.
            disable_pbar: Whether to disable the progress bar.

        Returns:
            The energy of the last functional evaluation and whether the gradient norm fell
            below the convergence tolerance. If the cycle limit is hit, that evaluation took
            place before the final coefficient update.
        """
        converged = False
        # Stays None only if the loop body never runs (max_cycle <= 0).
        energy = None
        # The context manager closes the progress bar even if an evaluation raises.
        with tqdm(
            range(self.max_cycle),
            leave=False,
            dynamic_ncols=True,
            position=int(os.getenv("DENOP_PID", 0)),
            disable=disable_pbar,
        ) as pbar:
            for cycle in pbar:
                energy, gradient_vector = energy_functional(sample)
                projected_gradient = project_gradient(gradient_vector, sample)
                gradient_norm = torch.norm(projected_gradient).item()
                pbar.set_description(get_pbar_str(sample, energy, gradient_norm))
                if callable(callback):
                    # The callback receives locals(); ``coeffs`` and ``learning_rate`` are
                    # bound only so that they appear in that dictionary.
                    coeffs = sample.coeffs
                    learning_rate = self.learning_rate
                    callback(locals())
                if gradient_norm < self.convergence_tolerance:
                    converged = True
                    break
                sample.coeffs -= self.learning_rate * projected_gradient
        if energy is None:
            # No cycle was executed: report the energy of the unchanged coefficients
            # instead of failing on an unbound variable.
            energy = energy_functional(sample)[0]
        return energy, converged
class TorchOptimizer(Optimizer):
    """Wrapper for torch optimizers to be used in the optimization loop."""

    def __init__(
        self,
        torch_optimizer: Type[torch.optim.Optimizer],
        convergence_tolerance: float,
        max_cycle: int,
        **optimizer_kwargs,
    ):
        """Initialize the torch optimizer.

        Args:
            torch_optimizer: The torch optimizer to use. To be able to apply the optimizer with
                hydra, the class is partially applied without any arguments.
            convergence_tolerance: Optimization stops if the gradient norm is below this value.
            max_cycle: Maximum number of optimization cycles.
            optimizer_kwargs: Additional keyword arguments for the optimizer.
        """
        self.torch_optimizer = torch_optimizer
        self.convergence_tolerance = convergence_tolerance
        self.max_cycle = max_cycle
        self.optimizer_kwargs = optimizer_kwargs

    def optimize(
        self,
        sample: OFData,
        energy_functional: Callable,
        callback: Callable | None = None,
        disable_pbar: bool = False,
    ) -> tuple[Energies, bool]:
        """Optimization loop for a torch optimizer.

        The torch optimizer steps on a private copy of the coefficients. The resulting update
        is projected and then applied in place to ``sample.coeffs``.

        Args:
            sample: The OFData containing the initial coefficients.
            energy_functional: Callable which returns the energy and gradient vector.
            callback: Optional callback. It is called once per cycle with the dictionary of
                local variables of this method.
            disable_pbar: Whether to disable the progress bar.

        Returns:
            The energy of the last functional evaluation and whether the gradient norm fell
            below the convergence tolerance. If the cycle limit is hit, that evaluation took
            place before the final coefficient update.
        """
        parameters = sample.coeffs.clone()
        optimizer = self.torch_optimizer([parameters], **self.optimizer_kwargs)
        converged = False
        # Stays None only if the loop body never runs (max_cycle <= 0).
        energy = None
        # The context manager closes the progress bar even if an evaluation raises.
        with tqdm(
            range(self.max_cycle),
            leave=False,
            dynamic_ncols=True,
            position=int(os.getenv("DENOP_PID", 0)),
            disable=disable_pbar,
        ) as pbar:
            for cycle in pbar:
                energy, gradient_vector = energy_functional(sample)
                projected_gradient = project_gradient(gradient_vector, sample)
                gradient_norm = torch.norm(projected_gradient).item()
                pbar.set_description(get_pbar_str(sample, energy, gradient_norm))
                if callable(callback):
                    # The callback receives locals(); ``coeffs`` and ``learning_rate`` are
                    # bound only so that they appear in that dictionary.
                    coeffs = sample.coeffs
                    # Reported as 0 when no "lr" keyword was passed to this wrapper.
                    learning_rate = self.optimizer_kwargs.get("lr", 0)
                    callback(locals())
                if gradient_norm < self.convergence_tolerance:
                    converged = True
                    break
                optimizer.zero_grad()
                parameters.grad = projected_gradient  # FIXME: could also use normal gradient
                optimizer.step()
                update = parameters - sample.coeffs
                # projecting the update is important for, e.g., the Adam optimizer
                sample.coeffs += project_gradient(update, sample)
                # update parameters but clone to not update sample.coeffs in optimizer step
                parameters.data = sample.coeffs.detach().clone()
        if energy is None:
            # No cycle was executed: report the energy of the unchanged coefficients
            # instead of failing on an unbound variable.
            energy = energy_functional(sample)[0]
        return energy, converged

    def __str__(self) -> str:
        """Return a string representation of the optimizer.

        This method is overwritten since the torch optimizer can either be a class via direct
        instantiation or via hydra with a partial. Only the keyword arguments given to this
        wrapper are listed as settings.
        """
        factory = self.torch_optimizer
        if isinstance(factory, functools.partial):  # if called with config
            name = factory.func.__name__
        else:
            # A class passed directly carries its own ``__name__``; going through
            # ``__class__`` would yield the metaclass name ("type") instead.
            name = getattr(factory, "__name__", type(factory).__name__)
        settings = ", ".join(f"{k}={v}" for k, v in self.optimizer_kwargs.items())
        return f"{name}({settings})"
class VectorAdam(Optimizer):
    """Equivariant version of the Adam optimizer."""

    def __init__(
        self,
        max_cycle: int,
        learning_rate: float,
        convergence_tolerance: float,
        betas: tuple[float, float] = (0.9, 0.999),
        epsilon: float = 1e-8,
    ):
        """Initialize the equivariant version of the Adam optimizer.

        Args:
            max_cycle: Maximum number of optimization cycles.
            learning_rate: The learning rate, used as a constant scale of the update.
            convergence_tolerance: Optimization stops if the gradient norm is below this value.
            betas: Exponential decay rates for the moment estimates.
            epsilon: Small value to avoid division by zero.
        """
        self.max_cycle = max_cycle
        self.learning_rate = learning_rate
        self.convergence_tolerance = convergence_tolerance
        self.betas = betas
        self.epsilon = epsilon

    def optimize(
        self,
        sample: OFData,
        energy_functional: Callable,
        callback: Callable | None = None,
        disable_pbar: bool = False,
    ) -> tuple[Energies, bool]:
        """Perform equivariant VectorAdam optimization.

        The second moment estimate is shared by all coefficients of a shell. The coefficients
        in ``sample`` are updated in place.

        Args:
            sample: The OFData containing the initial coefficients.
            energy_functional: Callable which returns the energy and gradient vector.
            callback: Optional callback. It is called once per cycle with the dictionary of
                local variables of this method.
            disable_pbar: Whether to disable the progress bar.

        Returns:
            The energy of the last functional evaluation and whether the gradient norm fell
            below the convergence tolerance. If the cycle limit is hit, that evaluation took
            place before the final coefficient update.
        """
        converged = False
        # Stays None only if the loop body never runs (max_cycle <= 0).
        energy = None
        m = torch.zeros_like(sample.coeffs)
        v = torch.zeros_like(sample.coeffs)
        beta1, beta2 = self.betas
        shell_beginning_mask = sample.basis_info.shell_beginning_mask
        shell_beginning_mask = torch.as_tensor(shell_beginning_mask, device=sample.coeffs.device)
        # Running count of shell beginnings, shifted to start at 0: one shell index per
        # coefficient.
        shell_indices = torch.cumsum(shell_beginning_mask[sample.basis_function_ind], dim=0) - 1
        # NOTE(review): unlike the other optimizers in this module, ``leave=False`` is not
        # passed here -- confirm whether that is intended.
        # The context manager closes the progress bar even if an evaluation raises.
        with tqdm(
            range(self.max_cycle),
            dynamic_ncols=True,
            position=int(os.getenv("DENOP_PID", 0)),
            disable=disable_pbar,
        ) as pbar:
            for cycle in pbar:
                energy, gradient_vector = energy_functional(sample)
                projected_gradient = project_gradient(gradient_vector, sample)
                gradient_norm = torch.norm(projected_gradient).item()
                m = beta1 * m + (1 - beta1) * projected_gradient
                # squared gradient per shell
                v_per_irrep = global_add_pool(projected_gradient**2, shell_indices)
                v = beta2 * v + (1 - beta2) * v_per_irrep[shell_indices]
                # bias-corrected moment estimates
                m_hat = m / (1 - beta1 ** (cycle + 1))
                v_hat = v / (1 - beta2 ** (cycle + 1))
                update = self.learning_rate * m_hat / (torch.sqrt(v_hat) + self.epsilon)
                update = project_gradient(update, sample)
                pbar.set_description(get_pbar_str(sample, energy, gradient_norm))
                if callable(callback):
                    # The callback receives locals(); ``coeffs`` and ``learning_rate`` are
                    # bound only so that they appear in that dictionary.
                    coeffs = sample.coeffs
                    learning_rate = self.learning_rate
                    callback(locals())
                if gradient_norm < self.convergence_tolerance:
                    converged = True
                    break
                sample.coeffs -= update
        if energy is None:
            # No cycle was executed: report the energy of the unchanged coefficients
            # instead of failing on an unbound variable.
            energy = energy_functional(sample)[0]
        return energy, converged
def scipy_functional(
    coeffs: np.ndarray,
    sample: OFData,
    energy_functional: Callable,
    convergence_tolerance: float,
    use_projected_gradient: bool,
    pbar: tqdm,
    callback: Callable | None = None,
) -> tuple[float, np.ndarray]:
    """Functional for scipy optimizers.

    Make the energy functional compatible with scipy optimizers. The given coefficients are
    written to ``sample.coeffs`` as a float64 tensor before the functional is evaluated.

    Args:
        coeffs: Input coefficients.
        sample: OFData containing the basis functions and integrals.
        energy_functional: Callable which returns the energy and gradient vector.
        convergence_tolerance: Optimization stops if the gradient norm is below this value.
        use_projected_gradient: Whether to use the projected gradient for the optimization step.
        pbar: tqdm progress bar.
        callback: Optional callback function. It is called with the dictionary of local
            variables of this function.

    Returns:
        Energy and gradient vector.

    Raises:
        StopIteration: If the norm of the projected gradient is below
            ``convergence_tolerance``. The calling optimizers catch it to end the scipy
            optimization.
    """
    sample.coeffs = torch.tensor(coeffs, dtype=torch.float64, device=sample.coeffs.device)
    energy, gradient = energy_functional(sample)
    projected_gradient = project_gradient(gradient, sample)
    # Convergence is always judged on the projected gradient, independent of which gradient
    # is handed back to scipy.
    gradient_norm = torch.norm(projected_gradient).item()
    if callable(callback):
        # The callback receives locals(); ``coeffs`` is rebound to a tensor and
        # ``learning_rate`` is bound only so that they appear in that dictionary.
        coeffs = torch.tensor(coeffs)
        learning_rate = 0  # needed for callback
        callback(locals())
    pbar.set_description(get_pbar_str(sample, energy, gradient_norm))
    if gradient_norm < convergence_tolerance:
        # Close the bar before signalling convergence; it is not used afterwards.
        pbar.close()
        raise StopIteration
    if use_projected_gradient:
        gradient_np = projected_gradient.cpu().numpy()
    else:
        gradient_np = gradient.cpu().numpy()
    return energy.electronic_energy, gradient_np
class SLSQP(Optimizer):
    """Wrapper for the SLSQP (Sequential Least Squares Programming) optimizer from scipy."""

    def __init__(
        self,
        max_cycle: int,
        convergence_tolerance: float,
        grad_scale: float,
        use_projected_gradient: bool = True,
    ):
        """Initialize the SLSQP optimizer.

        Args:
            max_cycle: Maximum number of optimization cycles. Note that this is not the same as the
                number of functional evaluations.
            convergence_tolerance: Optimization stops if the gradient norm is below this value.
            grad_scale: Scaling factor for the gradient vector.
            use_projected_gradient: Whether to use the projected gradient for the optimization
                step.
        """
        self.max_cycle = max_cycle
        self.convergence_tolerance = convergence_tolerance
        self.grad_scale = grad_scale
        self.use_projected_gradient = use_projected_gradient

    def optimize(
        self,
        sample: OFData,
        energy_functional: Callable,
        callback: Callable | None = None,
        disable_pbar: bool = False,
    ) -> tuple[Energies, bool]:
        """Perform density optimization using the SLSQP optimizer.

        The electron number is enforced as an equality constraint on the coefficients.

        Args:
            sample: The OFData containing the initial coefficients.
            energy_functional: Callable which returns the energy and gradient vector.
            callback: Optional callback, forwarded to ``scipy_functional``.
            disable_pbar: Whether to disable the progress bar.

        Returns:
            The energy evaluated at the coefficients left in ``sample`` and whether the
            gradient norm fell below the convergence tolerance.
        """
        converged = False
        normalization_vector = sample.dual_basis_integrals.cpu().numpy()
        n_electron = sample.n_electron
        pbar = tqdm(
            range(self.max_cycle),
            leave=False,
            dynamic_ncols=True,
            position=int(os.getenv("DENOP_PID", 0)),
            disable=disable_pbar,
        )
        iterator = iter(pbar)
        functional = functools.partial(
            scipy_functional,
            sample=sample,
            energy_functional=energy_functional,
            convergence_tolerance=self.convergence_tolerance,
            use_projected_gradient=self.use_projected_gradient,
            pbar=pbar,
            callback=callback,
        )

        # decrease the gradient to make the optimization more stable
        def functional_scaled(coeffs: np.ndarray) -> tuple[float, np.ndarray]:
            energy, gradient = functional(coeffs)
            return energy, self.grad_scale * gradient

        def electron_number_constraint(coeffs: np.ndarray) -> float:
            return np.dot(coeffs, normalization_vector) - n_electron

        def derivative_electron_number_constraint(_) -> np.ndarray:
            return normalization_vector

        def advance_pbar(_) -> None:
            # The default keeps an exhausted progress bar from raising StopIteration,
            # which would otherwise be mistaken for convergence below.
            next(iterator, None)

        constraint = {
            "type": "eq",
            "fun": electron_number_constraint,
            "jac": derivative_electron_number_constraint,
        }
        try:
            minimize(
                functional_scaled,
                x0=sample.coeffs.cpu().numpy(),
                method="SLSQP",
                jac=True,
                constraints=constraint,
                options={"maxiter": self.max_cycle, "ftol": 0},  # disable tolerance
                callback=advance_pbar,
            )
        except StopIteration:
            # StopIteration is raised when the gradient norm is below the tolerance
            converged = True
        finally:
            # Close the bar on every exit path, including unexpected errors.
            pbar.close()
        return energy_functional(sample)[0], converged
class TrustRegionConstrained(Optimizer):
    """Wrapper for the trust region constrained optimizer from scipy."""

    def __init__(
        self,
        max_cycle: int,
        convergence_tolerance: float,
        initial_tr_radius: float,
        initial_constr_penalty: float,
        use_projected_gradient: bool = True,
    ):
        """Initialize the trust region constrained optimizer.

        Args:
            max_cycle: Maximum number of optimization cycles.
            convergence_tolerance: Optimization stops if the gradient norm is below this value.
            initial_tr_radius: Initial trust radius. Affects the size of the first steps.
            initial_constr_penalty: Initial constraint penalty.
            use_projected_gradient: Whether to use the projected gradient for the optimization
                step.
        """
        self.max_cycle = max_cycle
        self.convergence_tolerance = convergence_tolerance
        self.initial_tr_radius = initial_tr_radius
        self.initial_constr_penalty = initial_constr_penalty
        self.use_projected_gradient = use_projected_gradient

    def optimize(
        self,
        sample: OFData,
        energy_functional: Callable,
        callback: Callable | None = None,
        disable_pbar: bool = False,
    ) -> tuple[Energies, bool]:
        """Perform trust region constrained optimization.

        The electron number is enforced through a linear constraint whose lower and upper
        bounds are both set to ``sample.n_electron``.

        Args:
            sample: The OFData containing the initial coefficients.
            energy_functional: Callable which returns the energy and gradient vector.
            callback: Optional callback, forwarded to ``scipy_functional``.
            disable_pbar: Whether to disable the progress bar.

        Returns:
            The energy evaluated at the coefficients left in ``sample`` and whether the
            gradient norm fell below the convergence tolerance.
        """
        converged = False
        normalization_vector = sample.dual_basis_integrals.cpu().numpy()
        n_electron = sample.n_electron
        pbar = tqdm(
            range(self.max_cycle),
            leave=False,
            dynamic_ncols=True,
            position=int(os.getenv("DENOP_PID", 0)),
            disable=disable_pbar,
        )
        iterator = iter(pbar)
        functional = functools.partial(
            scipy_functional,
            sample=sample,
            energy_functional=energy_functional,
            convergence_tolerance=self.convergence_tolerance,
            use_projected_gradient=self.use_projected_gradient,
            pbar=pbar,
            callback=callback,
        )
        constraint = LinearConstraint(normalization_vector, n_electron, n_electron)

        def scipy_callback(*_):
            # The default keeps an exhausted progress bar from raising StopIteration,
            # which would otherwise be mistaken for convergence below.
            next(iterator, None)

        try:
            minimize(
                functional,
                x0=sample.coeffs.cpu().numpy(),
                method="trust-constr",
                jac=True,
                tol=0,  # disable tolerance
                constraints=constraint,
                options={
                    "maxiter": self.max_cycle,
                    "initial_tr_radius": self.initial_tr_radius,
                    "initial_constr_penalty": self.initial_constr_penalty,
                },
                callback=scipy_callback,
            )
        except StopIteration:
            # StopIteration is raised when the gradient norm is below the tolerance
            converged = True
        finally:
            # Close the bar on every exit path, including unexpected errors.
            pbar.close()
        return energy_functional(sample)[0], converged
# Module-level default optimizer: torch SGD with momentum, wrapped in TorchOptimizer.
DEFAULT_DENSITY_OPTIMIZER = TorchOptimizer(
    torch_optimizer=torch.optim.SGD,
    max_cycle=1000,
    convergence_tolerance=1e-4,
    # forwarded unchanged to the torch optimizer
    lr=1e-3,
    momentum=0.9,
)