from pathlib import Path
import numpy as np
import torch
from loguru import logger
from matplotlib import pyplot as plt
from matplotlib.patches import Patch
from pyscf import dft
from mldft.ml.data.components.basis_transforms import transform_tensor_with_sample
from mldft.ml.data.components.of_data import OFData, Representation
from mldft.ofdft.callbacks import BasicCallback
from mldft.ofdft.energies import Energies
[docs]
def plot_density_optimization(
    callback: BasicCallback,
    energies_label: Energies,
    coeffs_label: torch.Tensor,
    sample: OFData,
    stopping_index: int | None = None,
    basis_l: torch.Tensor | None = None,
    figure_path: Path | str | None = None,
    enable_grid_operations: bool = False,
):
    """Plot the density optimization.

    One panel is drawn per diagnostic, stacked vertically: energy differences,
    density difference with gradient norm, optionally the integrated negative
    density, and the coefficient differences. If a figure path is given, the
    figure is saved to that path.

    Args:
        callback: The callback object.
        energies_label: The label energies.
        coeffs_label: The label coefficients.
        sample: The sample used for transformations.
        stopping_index: The index at which the optimization stopped. Optional.
        basis_l: Angular momentum of basis functions. Optional.
        figure_path: The path to save the figure. If None, the figure is not saved.
        enable_grid_operations: If False, plots requiring computations on the grid
            (the integrated negative density plot) are not shown.

    Returns:
        The created matplotlib figure. It is not closed here, so it can still be
        shown; callers producing many figures should close it (``plt.close(fig)``).
    """
    # Other available panels: _add_density_difference, _add_gradient_norm,
    # _add_coefficient_differences_pixels.
    plots = [_add_all_energy_differences, _add_density_and_gradient]
    if enable_grid_operations:
        # Needs computations on a real-space grid, hence opt-in.
        plots.append(_add_integrated_negative_density)
    plots.append(_add_coefficient_differences_lines)

    # squeeze=False keeps `axes` a 2D array regardless of the number of panels.
    fig, axes = plt.subplots(
        len(plots), 1, figsize=(9, 4 * len(plots)), sharex=False, squeeze=False
    )
    coeffs_label = coeffs_label.detach()
    for plot, ax in zip(plots, axes[:, 0]):
        plot(
            ax=ax,
            callback=callback,
            energies_label=energies_label,
            coeffs_label=coeffs_label,
            basis_l=basis_l,
            stopping_index=stopping_index,
            sample=sample,
        )
    # Operate on the figure created above rather than whatever is "current".
    fig.tight_layout()
    if figure_path is not None:
        fig.savefig(figure_path)
    return fig
[docs]
def _add_density_difference(
    ax: plt.Axes,
    callback: BasicCallback,
    energies_label: Energies,
    coeffs_label: torch.Tensor,
    sample: OFData,
    stopping_index: int = None,
    **kwargs,
):
    """Draw the per-iteration density error from ``callback.l2_norm``.

    The values are plotted on a logarithmic y-axis. If a stopping index is
    given, a dashed vertical line marks it and the value at that iteration is
    appended to the title.

    Args:
        ax: Axis that receives the curve.
        callback: Callback providing the ``l2_norm`` sequence.
        energies_label: Unused; accepted for the uniform plot interface.
        coeffs_label: Unused; accepted for the uniform plot interface.
        sample: Unused; accepted for the uniform plot interface.
        stopping_index: The index at which the optimization stopped. Optional.
        **kwargs: Additional arguments for uniform plot interface.
    """
    errors = callback.l2_norm
    # NOTE(review): the legend label says "coeffs delta" although the curve is
    # the density-difference norm; kept unchanged here.
    ax.plot(errors, label="coeffs delta")
    ax.set_ylabel(r"$\Vert\rho_\mathrm{pred} - \rho_\mathrm{target}\Vert_2$ [electrons]")
    ax.set_yscale("log")
    ax.grid()

    suffix = ""
    if stopping_index is not None:
        ax.axvline(stopping_index, color="black", linestyle="--")
        suffix = f". Error at stopping index: {errors[stopping_index]:.2g} electrons"
    ax.set_title("Density difference" + suffix)
    ax.set_xlabel("Iteration")
[docs]
def _add_all_energy_differences(
    ax: plt.Axes,
    callback: BasicCallback,
    energies_label: Energies,
    stopping_index: int | None = None,
    linthresh: float = 10,
    **kwargs,
):
    """Add all energy differences to the given axis.

    Plots, per iteration, ``1e3 * (E_pred - E_label)`` (labelled in mHa) for the
    total energy and for every term in ``energies_label.energies_dict`` except
    ``"nuclear_repulsion"``. Terms missing from the callback are skipped with a
    warning. The y-axis is symmetric-log, linear within ``[-linthresh, linthresh]``.

    Args:
        ax: The axis to add the energy differences to.
        callback: The callback object.
        energies_label: The label energies.
        stopping_index: The index at which the optimization stopped. Optional.
        linthresh: Half-width (in mHa) of the linear region of the symlog axis.
        **kwargs: Additional arguments for uniform plot interface.
    """
    # All energies are gathered in float64: the differences are tiny compared to
    # the absolute energies, and float32 would cost precision at the mHa scale.
    total_energy = torch.as_tensor(
        [e.total_energy for e in callback.energy], dtype=torch.float64
    )
    ax.plot(1e3 * (total_energy - energies_label.total_energy), label=r"$\Delta E_\mathrm{tot}$")
    for energy_name in energies_label.energies_dict:
        if energy_name == "nuclear_repulsion":
            continue
        try:
            energy = torch.as_tensor(
                [e[energy_name] for e in callback.energy], dtype=torch.float64
            )
        except KeyError:
            logger.warning(f"Energy '{energy_name}' not found in callback")
            continue
        label = r"$\Delta E_\mathrm{" + energy_name.replace("_", r"\_") + "}$"
        ax.plot(1e3 * (energy - energies_label[energy_name]), label=label)
    ax.axhline(0, color="black")
    if stopping_index is not None:
        ax.axvline(stopping_index, color="black", linestyle="--", label="stopping index")
    ax.set_xlabel("Iteration")
    ax.set_ylabel(r"$\Delta E$ [mHa]")
    ax.set_yscale("symlog", linthresh=linthresh)
    title = "Energy differences"
    if stopping_index is not None:
        title += (
            f". Error at stopping index: "
            f"{1e3 * (total_energy[stopping_index] - energies_label.total_energy):.2f} mHa"
        )
    ax.set_title(title)
    ax.legend()
    ax.grid(which="both")
    # Faint guide lines inside the linear region, where symlog has no minor grid.
    for h in np.linspace(-linthresh, linthresh, 21)[1:-1]:
        ax.axhline(h, color="black", alpha=0.1, linestyle="--", linewidth=0.5)
[docs]
def _add_density_and_gradient(
    ax: plt.Axes,
    callback: BasicCallback,
    energies_label: Energies,
    coeffs_label: torch.Tensor,
    sample: OFData,
    stopping_index: int = None,
    **kwargs,
):
    """Overlay the density difference and the gradient norm on one axis.

    The density difference is drawn on ``ax`` via ``_add_density_difference``.
    The gradient norm is drawn via ``_add_gradient_norm`` on a twin y-axis that
    shares the iteration axis. Both curves are listed in a single legend.

    Args:
        ax: The axis to add the density and gradient to.
        callback: The callback object.
        energies_label: The label energies, forwarded to the density plot.
        coeffs_label: The label coefficients, forwarded to the density plot.
        sample: The sample used for transformations, forwarded to the density plot.
        stopping_index: The index at which the optimization stopped. Optional.
        **kwargs: Additional arguments for uniform plot interface.
    """
    _add_density_difference(
        ax,
        callback=callback,
        energies_label=energies_label,
        coeffs_label=coeffs_label,
        sample=sample,
        stopping_index=stopping_index,
    )
    gradient_ax = ax.twinx()
    _add_gradient_norm(gradient_ax, callback, stopping_index)
    gradient_ax.grid(False)  # the primary axis already draws the grid

    # Merge legend entries from both y-axes into one legend on the primary axis.
    handles, labels = [], []
    for source_ax in (ax, gradient_ax):
        source_handles, source_labels = source_ax.get_legend_handles_labels()
        handles += source_handles
        labels += source_labels
    ax.legend(handles, labels, loc="upper center")
    ax.set_xlabel("Iteration")
[docs]
def _add_gradient_norm(
    ax: plt.Axes,
    callback: BasicCallback,
    stopping_index: int = None,
    **kwargs,
):
    """Draw ``callback.gradient_norm`` per iteration on a logarithmic y-axis.

    Args:
        ax: Axis that receives the curve.
        callback: Callback providing the ``gradient_norm`` sequence.
        stopping_index: The index at which the optimization stopped. If given,
            a dashed vertical line is drawn there. Optional.
        **kwargs: Additional arguments for uniform plot interface.
    """
    gradient_norms = callback.gradient_norm
    ax.plot(gradient_norms, label="gradient norm", color="tab:orange")
    ax.set_xlabel("Iteration")
    ax.set_ylabel("gradient norm")
    ax.set_yscale("log")
    ax.grid()
    if stopping_index is None:
        return
    ax.axvline(stopping_index, color="black", linestyle="--")
[docs]
def _add_coefficient_differences_pixels(
    ax: plt.Axes,
    callback: BasicCallback,
    coeffs_label: torch.Tensor,
    stopping_index: int | None = None,
    **kwargs,
):
    """Add the coefficient differences to the given axis as a pixel image.

    Rows of the image are basis functions and columns are iterations. A diverging
    colormap is used, symmetric around zero.

    Args:
        ax: The axis to add the coefficient difference to.
        callback: The callback object.
        coeffs_label: The label coefficients.
        stopping_index: The index at which the optimization stopped. Optional.
        **kwargs: Additional arguments for uniform plot interface.
    """
    # Move to the label's device first (as the line plot does) to avoid a
    # device mismatch when the optimization ran on a different device.
    coeffs = torch.stack(callback.coeffs).to(coeffs_label.device)
    coeffs_delta = (coeffs - coeffs_label).detach().cpu()
    coeffs_delta_max = coeffs_delta.abs().max().item()
    img = ax.imshow(
        coeffs_delta.T.numpy(),
        aspect="auto",
        cmap="seismic",
        interpolation="none",
        vmin=-coeffs_delta_max,
        vmax=coeffs_delta_max,
    )
    ax.set_xlabel("Iteration")
    ax.set_ylabel(r"basis function")
    plt.colorbar(img, ax=ax, location="top")
    if stopping_index is not None:
        ax.axvline(stopping_index, color="black", linestyle="--")
[docs]
def _add_coefficient_differences_lines(
    ax: plt.Axes,
    callback: BasicCallback,
    coeffs_label: torch.Tensor,
    basis_l: torch.Tensor | np.ndarray | None = None,
    stopping_index: int | None = None,
    **kwargs,
):
    """Add the coefficient differences to the given axis as a plot.

    One line per basis function shows ``coeff - coeff_label`` over the
    iterations. If ``basis_l`` is given, lines are colored by angular momentum
    (s through h) and drawn in a fixed shuffled order.

    Args:
        ax: The axis to add the coefficient difference to.
        callback: The callback object.
        coeffs_label: The label coefficients.
        basis_l: Angular momentum of basis functions (torch tensor or array-like).
        stopping_index: The index at which the optimization stopped. Optional.
        **kwargs: Additional arguments for uniform plot interface.
    """
    coeffs = torch.stack(callback.coeffs).to(coeffs_label.device)
    coeffs_delta = (coeffs - coeffs_label).detach().cpu().numpy()
    if basis_l is None:
        ax.plot(coeffs_delta, lw=0.3)
    else:
        # np.max() on a torch.Tensor calls Tensor.max(axis=None, out=None), which
        # torch rejects; convert to a plain integer array (also handles GPU tensors).
        if isinstance(basis_l, torch.Tensor):
            basis_l = basis_l.detach().cpu().numpy()
        basis_l = np.asarray(basis_l, dtype=int)
        max_l = int(basis_l.max())
        orbital_labels = ["s", "p", "d", "f", "g", "h"]
        l_colors = ["tab:blue", "tab:orange", "tab:green", "tab:red", "tab:purple", "tab:brown"]
        l_colors_handles = [
            Patch(facecolor=color, label=label)
            for color, label in zip(l_colors[: max_l + 1], orbital_labels)
        ]
        ax.legend(handles=l_colors_handles, title="Angular Momentum", fontsize="small")
        # Shuffle the z order so no angular momentum is systematically drawn on top.
        # A private RandomState(0) yields the same permutation as np.random.seed(0)
        # without resetting the global NumPy RNG as a side effect of plotting.
        permutation = np.random.RandomState(0).permutation(basis_l.shape[0])
        for i in permutation:
            ax.plot(coeffs_delta[:, i], lw=0.3, c=l_colors[basis_l[i]])
    ax.set_xlabel("Iteration")
    ax.set_ylabel(r"$\Delta p$")
    ax.grid()
    if stopping_index is not None:
        ax.axvline(stopping_index, color="black", linestyle="--")
[docs]
def _add_integrated_negative_density(
    ax: plt.Axes,
    callback: BasicCallback,
    energies_label: Energies,
    sample: OFData,
    stopping_index: int | None = None,
    grid_level: int = 3,
    **kwargs,
):
    """Add the integrated negative density to the given axis.

    For each iteration the density is evaluated on a PySCF DFT grid built for
    ``sample.mol``. Its positive part is discarded, and the remainder is
    integrated with the grid weights. The plotted values are therefore <= 0,
    up to negative grid weights.

    Args:
        ax: The axis to add the integrated negative density to.
        callback: The callback object providing the coefficients per iteration.
        energies_label: The label energies. Unused; kept for the uniform interface.
        sample: The sample used for transformations.
        stopping_index: The index at which the optimization stopped. Optional.
        grid_level: PySCF grid level used for the integration grid.
        **kwargs: Additional arguments for uniform plot interface.
    """
    mol = sample.mol
    grid = dft.Grids(mol)
    grid.level = grid_level
    grid.prune = None  # no angular-grid pruning
    grid.build()
    weights = torch.as_tensor(grid.weights, dtype=torch.float64)

    ao = dft.numint.eval_ao(mol, grid.coords, deriv=0)
    ao = torch.as_tensor(ao, dtype=torch.float64, device=sample.coeffs.device)
    ao = transform_tensor_with_sample(sample, ao, Representation.AO)
    ao = ao.to("cpu")

    # Bring the coefficients to the same device/dtype as `ao` and `weights`
    # (CPU, float64); detach so the result can be converted for plotting.
    coeffs = torch.stack(callback.coeffs).detach().to(device="cpu", dtype=torch.float64)
    density = coeffs @ ao.T
    negative_density = torch.clamp(density, max=0.0)
    integrated_negative_density = negative_density @ weights

    ax.plot(integrated_negative_density.numpy(), label="integrated negative density")
    ax.set_xlabel("Iteration")
    ax.set_ylabel("int. neg. density [electrons]")
    ax.grid()
    if stopping_index is not None:
        ax.axvline(stopping_index, color="black", linestyle="--")