# Source code for mldft.utils.plotting.summary_density_optimization

from functools import partial
from pathlib import Path
from typing import Union

import numpy as np
import seaborn as sns
import torch
import yaml
from loguru import logger
from matplotlib import lines as mlines
from matplotlib import pyplot as plt
from matplotlib.patches import Patch
from pyscf import dft
from torch_scatter import scatter
from tqdm import tqdm

from mldft.ml.data.components.of_data import BasisInfo, OFData
from mldft.utils.grids import compute_density_density_basis
from mldft.utils.molecules import build_molecule_ofdata
from mldft.utils.pdf_utils import HierarchicalPlotPDF
from mldft.utils.plotting.axes import format_basis_func_xaxis
from mldft.utils.plotting.symlog_locater import MinorSymLogLocator


def get_runwise_density_optimization_data(
    sample_dir: Path | str,
    n_molecules: int = None,
    energy_names: tuple[str] | None = None,
    plot_l1_norm: bool = True,
    l1_grid_level: int = 3,
    l1_grid_prune: str = "nwchem_prune",
) -> dict:
    """Get the data for a set of molecules for the density optimization process from a directory
    of saved samples.

    Args:
        sample_dir: Path to the directory containing the samples. These samples are expected to
            result from the :mod:`mldft.ofdft.run_ofdft.py` script.
        n_molecules: Number of molecules to plot. If None, all molecules in the directory are
            plotted.
        energy_names: Names of the energies to be plotted. If None, all available energies are
            plotted.
        plot_l1_norm: Whether to additionally compute the L1 norm of the density difference
            on a numerical grid for every sample.
        l1_grid_level: Grid level passed to the L1-norm computation.
        l1_grid_prune: Grid pruning scheme passed to the L1-norm computation.

    Returns:
        run_data_dict: Dictionary containing the stopping indices, energy trajectories, gradient
            norms, coefficient trajectories and density differences for the set of molecules.
    """
    if isinstance(sample_dir, str):
        sample_dir = Path(sample_dir)
    assert (
        sample_dir.is_dir()
    ), f"Density Optimization sample directory does not exist at {sample_dir}"
    if isinstance(n_molecules, int):
        assert n_molecules <= len(
            list(sample_dir.iterdir())
        ), "Number of molecules to plot exceeds the number of molecules in the directory"
    elif n_molecules is None:
        # Count the number of samples (molecules) in the directory
        n_molecules = len(list(sample_dir.iterdir()))
        # basis_info.pt is not a molecule sample, so exclude it from the count
        if (sample_dir / "basis_info.pt").is_file():
            n_molecules -= 1
    else:
        raise ValueError(
            f"n_molecules must be an integer or None but is of type {type(n_molecules)}"
        )
    stopping_indices = torch.zeros(n_molecules, dtype=int)
    num_atoms = torch.zeros(n_molecules, dtype=int)
    num_electrons = torch.zeros(n_molecules, dtype=int)
    # Peek at the first sample to determine which energy names are available
    energy_trajectories_dict, energy_ground_state_dict = initialize_energy_dicts(
        sample_dir=sample_dir,
        n_molecules=n_molecules,
        energy_names=energy_names,
    )
    # Per-molecule accumulators; filled in sample order (sorted glob below)
    gradient_norms = []
    coeffs_trajectories = []
    coeffs_ground_state = []
    coeffs_pred_ground_state = []
    density_differences_l2 = []
    density_differences_l1 = []
    run_times = []
    try:
        basis_info = torch.load(
            sample_dir / "basis_info.pt", map_location=torch.device("cpu"), weights_only=False
        )
        basis_function_indices = torch.as_tensor(
            np.concatenate(basis_info.atom_ind_to_basis_function_ind)
        )
        # to match trajectory and basis_function_inds dtype
        cumulative_coeff_error = torch.zeros_like(basis_function_indices, dtype=torch.float64)
        cumulative_counts = torch.zeros_like(basis_function_indices, dtype=torch.int64)
    except Exception as e:
        # NOTE(review): if this branch is taken, basis_function_indices,
        # cumulative_coeff_error and cumulative_counts stay undefined, and the
        # cumulate_coeff_error call and run_data_dict below will raise NameError
        # — confirm whether basis_info.pt is guaranteed to exist in practice.
        logger.warning(f"{e} Plots requiring basis_info will skipped")
        basis_info = None
    i = 0
    for file_path in tqdm(
        sorted(sample_dir.glob("sample*.pt")),
        desc="Loading samples",
        dynamic_ncols=True,
        leave=False,
    ):
        sample: OFData = torch.load(
            file_path, map_location=torch.device("cpu"), weights_only=False
        )
        assert sample.stopping_index is not None
        stopping_indices[i] = sample.stopping_index
        num_atoms[i] = sample.n_atom
        num_electrons[i] = sample.n_electron
        # energy names are set correctly in the dict initialization
        for energy_name in energy_trajectories_dict.keys():
            # Energies are converted from Ha to mHa (factor 1e3) here
            energy_trajectories_dict[energy_name].append(
                1e3 * (sample[f"trajectory_energy_{energy_name}"].detach().cpu())
            )
            if f"ground_state_energy_{energy_name}" in sample.keys():
                if isinstance(sample[f"ground_state_energy_{energy_name}"], torch.Tensor):
                    energy_ground_state_dict[energy_name][i] = 1e3 * (
                        sample[f"ground_state_energy_{energy_name}"].detach().cpu()
                    )
                elif isinstance(sample[f"ground_state_energy_{energy_name}"], float):
                    energy_ground_state_dict[energy_name][i] = torch.as_tensor(
                        1e3 * sample[f"ground_state_energy_{energy_name}"], dtype=torch.float64
                    )
            else:
                logger.warning(f"{energy_name} not in ground state energies but in trajectories.")
        gradient_norms.append(sample.trajectory_gradient_norm.detach().cpu())
        coeffs_trajectories.append(sample.trajectory_coeffs.detach().cpu())
        coeffs_ground_state.append(sample.ground_state_coeffs.detach().cpu())
        # new samples from run_density_optimization have predicted_ground_state_coeffs
        if hasattr(sample, "predicted_ground_state_coeffs"):
            coeffs_pred_ground_state.append(sample.predicted_ground_state_coeffs.detach().cpu())
        # old samples from run_density_optimization have sample.coeffs
        elif not hasattr(sample, "transformation_matrix"):
            coeffs_pred_ground_state.append(sample.coeffs.detach().cpu())
            # Add to sample for l1 norm calculation
            sample.add_item("predicted_ground_state_coeffs", sample.coeffs, "vector")
        # Samples from run_ofdft have the full trajectory_coeffs
        else:
            predicted_ground_state_coeffs = sample.trajectory_coeffs[sample.stopping_index]
            coeffs_pred_ground_state.append(predicted_ground_state_coeffs.detach().cpu())
            # Add to sample for l1 norm calculation
            sample.add_item(
                "predicted_ground_state_coeffs", predicted_ground_state_coeffs, "vector"
            )
        # done inside this loop as we need the samples basis_function_ind
        cumulative_coeff_error, cumulative_counts = cumulate_coeff_error(
            sample=sample,
            pred_ground_state_coeffs=coeffs_pred_ground_state[-1],
            cumulative_error=cumulative_coeff_error,
            cumulative_counts=cumulative_counts,
        )
        # get density difference right away as this needs the sample's overlap matrix
        if hasattr(sample, "trajectory_l2_norm"):
            density_differences_l2.append(sample.trajectory_l2_norm.detach().cpu())
        else:
            density_differences_l2.append(get_density_difference_l2_norm(sample))
        if hasattr(sample, "time") and sample.time is not None:
            run_times.append(sample.time)
        else:
            # Missing timings are recorded as NaN so np.mean propagates them visibly
            run_times.append(np.nan)
        if plot_l1_norm:
            density_differences_l1.append(
                get_density_difference_l1_norm(sample, basis_info, l1_grid_level, l1_grid_prune)
            )
        i += 1
    # sanity check
    assert n_molecules == i, "Number of molecules is ambiguous"
    if basis_info is None:
        logger.warning(
            "Basis info not found in sample directory. Plots requiring basis info will be skipped."
        )
    run_data_dict = {
        "n_molecules": n_molecules,
        "basis_info": basis_info,
        "basis_function_indices": basis_function_indices,
        # `sample` is the last loaded sample here; its basis_function_ind is only
        # representative when all molecules share the same number of AOs
        "sample_basis_function_ind": sample.basis_function_ind,  # only for fixed number of aos
        "stopping_indices": stopping_indices,
        "energy_trajectories_dict": energy_trajectories_dict,
        "energy_ground_state_dict": energy_ground_state_dict,
        "gradient_norms": gradient_norms,
        "coeffs_trajectories": coeffs_trajectories,
        "coeffs_ground_state": coeffs_ground_state,
        "coeffs_pred_ground_state": coeffs_pred_ground_state,
        "density_differences_l2": density_differences_l2,
        "cumulative_coeff_error": cumulative_coeff_error,
        "cumulative_counts": cumulative_counts,
        "num_atoms": num_atoms,
        "num_electrons": num_electrons,
        "run_times": run_times,
    }
    if plot_l1_norm:
        run_data_dict["density_differences_l1"] = density_differences_l1
    return run_data_dict
def save_density_optimization_metrics(
    output_path: Path,
    run_data_dict: dict,
):
    """Save the density optimization metrics for a set of molecules to a file.

    Args:
        output_path: Path where to save the metrics yaml file.
        run_data_dict: Dictionary containing the n_molecules, basis_info, basis_function_indices,
            stopping indices, energy trajectories, gradient norms, coefficient trajectories and
            density differences for the set of molecules.
    """
    num_atoms = run_data_dict["num_atoms"]
    num_electrons = run_data_dict["num_electrons"]
    stopping_indices = run_data_dict["stopping_indices"]
    # These keys will be at the top of the yaml file
    benchmark_keys = [
        "n_molecules",
        "mean_total_energy_error[mHa]",
        "mean_per_atom_total_energy_error[mHa]",
        "mean_per_electron_density_error_l1[%]",
        "mean_gradient_norm",
        "nonconverged_molecules_ratio[%]",
    ]
    # Pre-seed with NaN so the benchmark keys keep their position even if not computed
    metric_dict = {key: np.nan for key in benchmark_keys}
    metric_dict["n_molecules"] = run_data_dict["n_molecules"]
    energy_errors, absolute_energy_errors = get_energy_ground_state_errors(
        energy_trajectories_dict=run_data_dict["energy_trajectories_dict"],
        energy_ground_state_dict=run_data_dict["energy_ground_state_dict"],
        stopping_indices=run_data_dict["stopping_indices"],
    )
    energy_keys = ["total"]
    for energy_key in energy_keys:
        # NOTE(review): this rebinding shadows the signed `energy_errors` from above,
        # which is otherwise unused — presumably intentional (all metrics below are
        # absolute errors), but worth confirming.
        energy_errors = absolute_energy_errors[energy_key]
        metric_dict[f"mean_{energy_key}_energy_error[mHa]"] = torch.mean(energy_errors).item()
        metric_dict[f"mean_per_atom_{energy_key}_energy_error[mHa]"] = torch.mean(
            energy_errors / num_atoms
        ).item()
        # NOTE(review): key expands to "median_total_total_..." for energy_key="total";
        # looks like a typo for f"median_{energy_key}_..." — kept as-is since consumers
        # may rely on the existing key.
        metric_dict[f"median_total_{energy_key}_energy_error[mHa]"] = torch.median(
            energy_errors
        ).item()
        metric_dict[f"median_per_atom_{energy_key}_energy_error[mHa]"] = torch.median(
            energy_errors / num_atoms
        ).item()
        # Errors are in mHa, so "< 1" means below 1 mHa
        metric_dict[f"ratio_of_molecules_below_1mHa_{energy_key}_energy_error"] = torch.sum(
            energy_errors < 1
        ).item() / len(absolute_energy_errors["total"])
        metric_dict[
            f"ratio_of_molecules_below_1mHa_per_atom_{energy_key}_energy_error"
        ] = torch.sum(energy_errors / num_atoms < 1).item() / len(absolute_energy_errors["total"])
        metric_dict[f"90th_percentile_{energy_key}_energy_error[mHa]"] = torch.quantile(
            energy_errors, 0.9
        ).item()
    # L2 density error evaluated at each molecule's stopping iteration
    stopping_density_differences_l2 = torch.stack(
        [
            tensor[index]
            for tensor, index in zip(run_data_dict["density_differences_l2"], stopping_indices)
        ]
    )
    metric_dict["mean_density_error_l2"] = torch.mean(stopping_density_differences_l2).item()
    metric_dict["mean_per_electron_density_error_l2[%]"] = (
        torch.mean(stopping_density_differences_l2 / num_electrons).item() * 100
    )
    metric_dict["median_density_error_l2"] = torch.median(stopping_density_differences_l2).item()
    metric_dict["median_per_electron_density_error_l2[%]"] = (
        torch.median(stopping_density_differences_l2 / num_electrons).item() * 100
    )
    if "density_differences_l1" in run_data_dict:
        density_differences_l1 = run_data_dict["density_differences_l1"]
        assert all(
            d.ndim == 0 for d in density_differences_l1
        ), "Density differences must be 0D to compute mean L1 norm"
        stopping_density_differences_l1 = torch.stack(density_differences_l1)
        metric_dict["mean_density_error_l1"] = torch.mean(stopping_density_differences_l1).item()
        metric_dict["mean_per_electron_density_error_l1[%]"] = (
            torch.mean(stopping_density_differences_l1 / num_electrons).item() * 100
        )
        metric_dict["median_density_error_l1"] = torch.median(
            stopping_density_differences_l1
        ).item()
        metric_dict["median_per_electron_density_error_l1[%]"] = (
            torch.median(stopping_density_differences_l1 / num_electrons).item() * 100
        )
    stopping_gradient_norms = torch.stack(
        [tensor[index] for tensor, index in zip(run_data_dict["gradient_norms"], stopping_indices)]
    )
    metric_dict["mean_gradient_norm"] = torch.mean(stopping_gradient_norms).item()
    metric_dict["median_gradient_norm"] = torch.median(stopping_gradient_norms).item()
    # A molecule counts as non-converged if its stopping gradient norm exceeds 1e-4
    metric_dict["nonconverged_molecules_ratio[%]"] = (
        torch.sum(stopping_gradient_norms > 1e-4).item() / len(stopping_gradient_norms) * 100
    )
    metric_dict["average_run_time[s]"] = float(
        np.mean(run_data_dict["run_times"])
    )  # yaml can't handle np.float64
    logger.info("Successfully computed metrics:")
    # Round values for nicer formatting
    for key, value in metric_dict.items():
        if isinstance(value, float):
            metric_dict[key] = round(value, 6)
        print(f"{key}: {metric_dict[key]}")
    with open(output_path, "w") as f:
        # sort_keys=False keeps the benchmark keys at the top of the file
        yaml.safe_dump(metric_dict, f, sort_keys=False)
    return
def density_optimization_summary_pdf_plot(
    out_pdf_path: Path | str,
    run_data_dict: dict,
    matplotlib_backend: str = "pdf",
    subsample: float = 1.0,
):
    """Create a summary page for the density optimization for a set of molecules.

    Each page of the resulting PDF is produced by one of the plot_* helpers in this
    module, all driven by the same ``run_data_dict``.

    Args:
        out_pdf_path: Path to the output PDF file.
        run_data_dict: Dictionary containing the n_molecules, basis_info, basis_function_indices,
            stopping indices, energy trajectories, gradient norms, coefficient trajectories and
            density differences for the set of molecules.
        matplotlib_backend: Matplotlib backend to use for plotting. By default the pdf backend is
            used to create vectorized PDF plots, while for instance 'agg' could be used to create
            rasterized PNG plots.
        subsample: Fraction of molecules to plot in the swarm plots (individual molecule
            trajectories). By default (1.0), all molecules are plotted. Only takes effect, if
            n_molecules is larger than the number of molecules to plot and n_molecules > 2.
    """
    if matplotlib_backend is not None:
        plt.switch_backend(matplotlib_backend)
    with HierarchicalPlotPDF(
        out_pdf_path=out_pdf_path,
    ) as summary_pdf:
        # Page 1: run summary (boxplots, scatter comparisons, coefficient MAE)
        plot_ofdft_run_summary(**run_data_dict)
        plt.suptitle("Density Optimization Run Summary")
        # Using constrained layout, we have to set the bbox_inches and pad_inches to replace rect
        summary_pdf.savefig(
            "Density Optimization Run Summary", bbox_inches="tight", pad_inches=0.8
        )
        plt.close()
        # Page 2: histograms of energy errors
        plot_ofdft_energy_distribution(**run_data_dict)
        plt.suptitle("Energy Error Distribution")
        summary_pdf.savefig("Energy Error Distribution")
        plt.close()
        # Page 3: mean trajectories over optimization iterations
        plot_density_optimization_trajectory_means(**run_data_dict)
        plt.tight_layout(rect=[0, 0.03, 1, 0.97])
        plt.suptitle("Density Optimization Mean Trajectories")
        summary_pdf.savefig("Density Optimization Mean Trajectories")
        plt.close()
        # Page 4: per-molecule swarm line plots
        density_optimization_swarm_plot(**run_data_dict, subsample=subsample)
        plt.tight_layout(rect=[0, 0.03, 1, 0.97])
        plt.suptitle("Individual Molecule Density Optimization Trajectories")
        summary_pdf.savefig("Individual Molecule Density Optimization Trajectories")
        plt.close()
        # Page 5: stopping vs. initial energy error scatter
        plot_energy_summary_scatter(**run_data_dict)
        plt.tight_layout(rect=[0, 0.03, 1, 0.97])
        plt.suptitle("Stopping vs Initial Energy Errors")
        summary_pdf.savefig("Stopping vs Initial Energy Errors")
        plt.close()
def density_optimization_swarm_plot(
    energy_trajectories_dict: dict[str, torch.Tensor],
    energy_ground_state_dict: dict[str, torch.Tensor | float],
    gradient_norms: torch.Tensor | list[torch.Tensor],
    density_differences_l2: torch.Tensor | list[torch.Tensor],
    stopping_indices: torch.Tensor,
    n_molecules: int = None,
    energy_names: tuple[str] | str = None,
    subsample: float = 1.0,
    **_,
):
    """Summarize the density optimization process for a set of molecules, by plotting a line for
    every molecule, showing the energy error, gradient norm and L2 norm of the density error."""
    n_available = len(stopping_indices)
    if n_molecules is None:
        # default: draw every available molecule
        n_molecules = n_available
    else:
        assert n_molecules <= n_available, (
            f"n_molecules to plot, {n_molecules}, exceeds the number of available molecules,"
            f" {n_available}"
        )
    if energy_names is None:
        energy_names = list(energy_ground_state_dict.keys())

    energy_errors_dict, energy_absolute_error_dict = get_energy_errors_dict(
        energy_trajectories_dict=energy_trajectories_dict,
        energy_ground_state_dict=energy_ground_state_dict,
        energy_names=energy_names,
    )

    # One plotting callable per panel: the energy panels first, then the
    # gradient-norm panel, then the L2 density-difference panel.
    panel_plotters = []
    for energy_name in energy_names:
        panel_plotters.append(
            partial(
                energy_error_swarm_line_plot,
                energy_errors_dict=energy_errors_dict,
                energy_name=energy_name,
                subsample=subsample,
            )
        )
    panel_plotters.append(
        partial(gradient_norm_swarm_line_plot, gradient_norms=gradient_norms, subsample=subsample)
    )
    panel_plotters.append(
        partial(
            density_differences_swarm_line_plot,
            density_differences_l2=density_differences_l2,
            subsample=subsample,
        )
    )

    n_panels = len(panel_plotters)
    fig, axs = plt.subplots(n_panels, 1, figsize=(9, 5 * n_panels), sharex=True)
    for ax, draw_panel in zip(axs, panel_plotters):
        draw_panel(ax=ax, stopping_indices=stopping_indices, n_molecules=n_molecules)
        ax.grid(True)

    # share x-axis with energy plot
    axs[-1].set_xlabel("Iteration")

    # Legend explaining the line styles used before/after the stopping index
    style_handles = [
        mlines.Line2D([], [], color="black", linestyle="-", label="Before stopping index"),
        mlines.Line2D([], [], color="black", linestyle="--", label="After stopping index"),
    ]
    axs[-1].legend(handles=style_handles)
    fig.tight_layout()
def initialize_energy_dicts(
    sample_dir: Path | str,
    n_molecules: int = None,
    energy_names: tuple[str] = None,
):
    """Initialize the energy trajectories dictionary for a set of molecules.

    We load the first sample in the given directory. As of yet, this seems necessary in order to
    check for available energy names (ground state and trajectories) and not do it within the
    loop over the sample directory itself. If the energy names are not provided, all available
    energies are extracted from the first sample.

    Returns:
        A tuple ``(energy_trajectories, energy_ground_state_dict)`` where the first maps each
        energy name to an empty list (to be filled with per-molecule trajectories) and the
        second maps each energy name with a ground-state value to a zero tensor of length
        ``n_molecules``. In both dicts, "total" (if present) is moved to the front and the
        remaining keys are sorted alphabetically.
    """
    energy_trajectories = {}
    energy_ground_state_dict = {}
    # create a dummy iterator to get the first sample and not exhaust the later used iterator
    dummy_iterator = iter(sample_dir.iterdir())
    sample = torch.load(next(dummy_iterator), map_location="cpu", weights_only=False)
    # Skip non-sample files (e.g. basis_info.pt) until an OFData sample is found
    while not isinstance(sample, OFData):
        sample = torch.load(next(dummy_iterator), map_location="cpu", weights_only=False)
    if energy_names is None:
        # Discover all energies from the first sample's trajectory keys
        energy_names = []
        for key in sample.keys():
            if (
                key.startswith("trajectory_energy")
                and key != "trajectory_energy_nuclear_repulsion"
            ):
                # extract what comes after 'trajectory_energy'
                energy_name = key[len("trajectory_energy_") :]
                if f"ground_state_energy_{energy_name}" not in sample.keys():
                    logger.warning(
                        f" {energy_name} Ground State Energy not found in sample "
                        f"but corresponding energy trajectories are present"
                    )
                else:
                    energy_ground_state_dict[energy_name] = torch.zeros(n_molecules)
                energy_names.append(energy_name)
                energy_trajectories[energy_name] = []
    else:
        # Validate the user-provided energy names against the first sample
        for energy_name in energy_names:
            if f"trajectory_energy_{energy_name}" not in sample.keys():
                logger.warning(
                    f"Trajectory Energy '{energy_name}' not found in sample but provided "
                    "in energy names. Energy trajectory will not be plotted"
                )
                continue
            else:
                energy_trajectories[energy_name] = []
            if f"ground_state_energy_{energy_name}" not in sample.keys():
                logger.warning(
                    f" {energy_name} Ground State Energy not found in sample "
                    f"but corresponding energy trajectories are present"
                )
            else:
                energy_ground_state_dict[energy_name] = torch.zeros(n_molecules)
    # move the total energy to the start of the dict keys
    if "total" in energy_trajectories:
        total = energy_trajectories.pop("total")
        energy_trajectories = {"total": total, **dict(sorted(energy_trajectories.items()))}
    if "total" in energy_ground_state_dict:
        total = energy_ground_state_dict.pop("total")
        energy_ground_state_dict = {
            "total": total,
            **dict(sorted(energy_ground_state_dict.items())),
        }
    return energy_trajectories, energy_ground_state_dict
def plot_density_optimization_trajectory_means(
    energy_trajectories_dict: dict[str, torch.Tensor],
    energy_ground_state_dict: dict[str, torch.Tensor | float],
    coeffs_trajectories: torch.Tensor | list[torch.Tensor],
    coeffs_ground_state: torch.Tensor | list[torch.Tensor],
    gradient_norms: torch.Tensor | list[torch.Tensor],
    density_differences_l2: torch.Tensor | list[torch.Tensor],
    stopping_indices: torch.Tensor,
    num_electrons: torch.Tensor,
    n_molecules: int,
    energy_names: tuple[str] | str = None,
    **kwargs,
):
    """Summarize the density optimization process for a set of molecules, by plotting the mean
    energy difference, gradient norm and density difference."""
    molecule_indices = torch.arange(n_molecules)  # simply a range of 0 to n_molecules
    # Extract the L2 density difference at each molecule's stopping iteration.
    if isinstance(density_differences_l2, torch.Tensor):
        stopping_density_differences_l2 = density_differences_l2[
            molecule_indices, stopping_indices
        ]
    else:
        try:
            # Fast path: all trajectories have equal length and can be stacked
            density_differences_l2 = torch.stack(density_differences_l2)
            stopping_density_differences_l2 = density_differences_l2[
                molecule_indices, stopping_indices
            ]
        except Exception as e:
            # Ragged trajectories: index each molecule's trajectory individually
            stopping_density_differences_l2 = torch.stack(
                [
                    tensor[stopping_index]
                    for tensor, stopping_index in zip(density_differences_l2, stopping_indices)
                ]
            )
    if energy_names is None:
        energy_names = list(energy_ground_state_dict.keys())
    energy_errors_dict, energy_absolute_errors_dict = get_energy_errors_dict(
        energy_trajectories_dict=energy_trajectories_dict,
        energy_ground_state_dict=energy_ground_state_dict,
        energy_names=energy_names,
    )
    # Absolute energy error at the stopping iteration, per energy name
    stopping_energy_absolute_errors_dict = {}
    for energy_name, energy_absolute_errors in energy_absolute_errors_dict.items():
        stopping_energy_absolute_errors_dict[energy_name] = torch.as_tensor(
            [
                molecule_energy_absolute_errors[stopping_index]
                for molecule_energy_absolute_errors, stopping_index in zip(
                    energy_absolute_errors, stopping_indices
                )
            ]
        )
    # One panel per energy MAE, then gradient norm, density difference and
    # a histogram of the stopping indices
    plots = [
        partial(
            _add_energy_mae,
            energy_maes=energy_absolute_errors_dict[energy_name],
            energy_stopping_maes=stopping_energy_absolute_errors_dict[energy_name],
            energy_name=energy_name,
        )
        for energy_name in energy_names
    ] + [
        partial(
            _add_mean_gradient_norm,
            gradient_norms=gradient_norms,
            stopping_indices=stopping_indices,
        ),
        partial(
            _add_mean_density_differences_l2,
            density_differences_l2=density_differences_l2,
            stopping_density_differences_l2=stopping_density_differences_l2,
        ),
        partial(_add_stopping_index_histogram, stopping_indices=stopping_indices),
    ]
    fig, axs = plt.subplots(len(plots), 1, figsize=(9, 4 * len(plots)), sharex=True)
    for plot, ax in zip(plots, axs):
        plot(ax)
        # Mark the median stopping index on every panel for orientation
        ax.axvline(
            torch.median(stopping_indices),
            color="black",
            linestyle="--",
            label=f"Median stopping index {torch.median(stopping_indices)}",
        )
        ax.grid(True)
    axs[-1].legend()
    plt.tight_layout()
def plot_ofdft_run_summary(
    energy_trajectories_dict: dict[str, list[torch.Tensor]],
    energy_ground_state_dict: dict[str, torch.Tensor | float],
    coeffs_trajectories: torch.Tensor | list[torch.Tensor],
    coeffs_ground_state: torch.Tensor | list[torch.Tensor],
    coeffs_pred_ground_state: torch.Tensor | list[torch.Tensor],
    gradient_norms: torch.Tensor | list[torch.Tensor],
    density_differences_l2: torch.Tensor | list[torch.Tensor],
    stopping_indices: torch.Tensor,
    n_molecules: int,
    energy_names: tuple[str] | str = None,
    basis_info: BasisInfo = None,
    basis_function_indices: torch.Tensor = None,
    sample_basis_function_ind: torch.Tensor = None,
    cumulative_coeff_error: torch.Tensor = None,
    cumulative_counts: torch.Tensor = None,
    **_,
):
    """Create a summary page for the density optimization process for a set of molecules.

    Lays out three rows of subfigures: energy-error boxplots (signed and absolute),
    initial-vs-stopping scatter comparisons (L2 density difference and gradient norm),
    and a per-basis-function coefficient MAE scatter.
    """
    # L2 density differences at the first and at the stopping iteration
    initial_density_differences_l2 = torch.stack([tensor[0] for tensor in density_differences_l2])
    stopping_density_differences_l2 = torch.stack(
        [tensor[index] for tensor, index in zip(density_differences_l2, stopping_indices)]
    )
    (
        stopping_energy_errors_dict,
        stopping_energy_absolute_errors_dict,
    ) = get_energy_ground_state_errors(
        energy_trajectories_dict,
        energy_ground_state_dict,
        stopping_indices,
    )
    # Get the last dimension of the first tensor in the list
    first_n_basis = coeffs_trajectories[0].shape[-1]
    if all(
        coeff_trajectory.shape[-1] == first_n_basis for coeff_trajectory in coeffs_trajectories
    ):
        # All molecules share one basis size, so the per-coefficient errors can be
        # stacked into a single array.
        # coeffs_trajectories are of shape (n_molecules, n_cycles, n_basis_functions)
        stopping_coeffs_errors = np.asarray(coeffs_pred_ground_state) - np.asarray(
            coeffs_ground_state
        )
        coeff_mae_scatter = partial(
            _single_shape_coeff_mae_scatter,
            coeff_errors=stopping_coeffs_errors,
            basis_info=basis_info,
            basis_function_ind=sample_basis_function_ind,
        )
    else:
        # Mixed basis sizes: fall back to the cumulative per-basis-function error
        # accumulated while loading the samples.
        coeff_mae_scatter = partial(
            _multi_shape_coeff_mae_scatter,
            cumulative_coeff_errors=cumulative_coeff_error,
            cumulative_counts=cumulative_counts,
            basis_function_indices=basis_function_indices,
            basis_info=basis_info,
        )
    fig = plt.figure(figsize=(15, 20), layout="constrained")
    subfigs = fig.subfigures(3, 1, wspace=0.10, hspace=0.07, height_ratios=[0.80, 0.80, 0.55])
    # Row 1: energy error boxplots (signed and absolute)
    (energy_ax1, energy_ax2) = subfigs[0].subplots(1, 2)
    _add_energy_boxplot(
        energy_ax1,
        energy_errors_dict=stopping_energy_errors_dict,
        energy_names=energy_names,
        title="Signed Energy Contribution Errors",
    )
    _add_energy_boxplot(
        energy_ax2,
        energy_errors_dict=stopping_energy_absolute_errors_dict,
        energy_names=energy_names,
        title="Absolute Energy Contribution Errors",
    )
    # Row 2: initial vs. stopping comparisons
    (density_ax, gradient_ax) = subfigs[1].subplots(1, 2)
    _compare_initial_to_stopping_l2_norms(
        density_ax,
        initial_l2_norms=initial_density_differences_l2,
        stopping_l2_norms=stopping_density_differences_l2,
        ground_state_energy_dict=energy_ground_state_dict,
    )
    _compare_initial_to_stopping_gradient_norms(
        gradient_ax,
        gradient_norms=gradient_norms,
        stopping_indices=stopping_indices,
        ground_state_energy_dict=energy_ground_state_dict,
    )
    # Row 3: coefficient MAE per basis function
    coeff_ax = subfigs[2].subplots(1, 1)
    coeff_mae_scatter(coeff_ax)
    plt.show()
def plot_ofdft_energy_distribution(
    energy_trajectories_dict: dict[str, torch.Tensor],
    energy_ground_state_dict: dict[str, torch.Tensor | float],
    stopping_indices: torch.Tensor,
    energy_names: tuple[str] | str = None,
    num_atoms: torch.Tensor = None,
    **_,
):
    """Create a energy distribution summary page for the density optimization process for a set
    of molecules.

    For each energy name, plots the absolute error over the number of atoms plus a pair of
    histograms (signed and absolute errors) evaluated at the stopping iteration.
    """
    if energy_names is None:
        energy_names = list(energy_ground_state_dict.keys())
    (
        stopping_energy_errors_dict,
        stopping_energy_absolute_errors_dict,
    ) = get_energy_ground_state_errors(
        energy_trajectories_dict,
        energy_ground_state_dict,
        stopping_indices,
    )
    fig = plt.figure(figsize=(15, 6 + len(energy_names) * 10), layout="constrained")
    # height ratios adjusted to the variable number of energy names
    height_ratios = [0.55, 0.80] * len(energy_names)
    # Two subfigure rows per energy name: error-over-atoms, then the histograms
    subfigs = fig.subfigures(
        2 * len(energy_names), 1, wspace=0.10, hspace=0.07, height_ratios=height_ratios
    )
    for i, energy_name in zip(np.arange(0, 2 * len(energy_names), 2), energy_names):
        eom_ax = subfigs[i].subplots(1, 1)
        _add_energy_error_over_n_atoms(
            eom_ax,
            num_atoms=num_atoms,
            energy_errors_dict=stopping_energy_absolute_errors_dict,
            energy_name=energy_name,
        )
        (energy_ax1, energy_ax2) = subfigs[i + 1].subplots(1, 2)
        _add_energy_histogram(
            energy_ax1,
            energy_errors_dict=stopping_energy_errors_dict,
            energy_name=energy_name,
            title="Signed Energy Contribution Errors",
        )
        _add_energy_histogram(
            energy_ax2,
            energy_errors_dict=stopping_energy_absolute_errors_dict,
            energy_name=energy_name,
            title="Absolute Energy Contribution Errors",
        )
    plt.show()
def plot_energy_summary_scatter(
    energy_trajectories_dict: dict[str, torch.Tensor],
    energy_ground_state_dict: dict[str, torch.Tensor | float],
    stopping_indices: torch.Tensor,
    energy_names: tuple[str] | str = None,
    n_molecules: int = None,
    **_,
):
    """Scatter plots of the stopping over initial energy errors for a set of molecules.

    Args:
        energy_trajectories_dict: Dictionary containing the energy trajectories for each
            molecule. Keys are energy names and values are tensors of shape
            (n_molecules, n_cycles).
        energy_ground_state_dict: Dictionary containing the ground state energies for each
            molecule. Keys are energy names and values are tensors of shape (n_molecules,).
        stopping_indices: Tensor containing the stopping indices for each molecule determined by
            the convergence criteria of the density optimization run.
        energy_names: Tuple of energy names to be plotted.
        n_molecules: Number of molecules to be plotted.
    """
    if n_molecules is None:
        n_molecules = len(stopping_indices)
    energy_errors_dict, energy_absolute_error_dict = get_energy_errors_dict(
        energy_trajectories_dict=energy_trajectories_dict,
        energy_ground_state_dict=energy_ground_state_dict,
        energy_names=energy_names,
    )
    # energy_names are set correctly in the dict initialization
    energy_names = list(energy_errors_dict.keys())
    # Absolute errors at the first and at the stopping iteration, per energy name
    initial_energy_absolute_errors_dict = {}
    stopping_energy_absolute_errors_dict = {}
    for energy_name, absolute_energy_errors in energy_absolute_error_dict.items():
        initial_energy_absolute_errors_dict[energy_name] = torch.as_tensor(
            [molecule_energy_errors[0] for molecule_energy_errors in absolute_energy_errors]
        )
        stopping_energy_absolute_errors_dict[energy_name] = torch.as_tensor(
            [
                absolute_molecule_energy_errors[stopping_index]
                for absolute_molecule_energy_errors, stopping_index in zip(
                    absolute_energy_errors, stopping_indices
                )
            ]
        )
    fig, axs = plt.subplots(len(energy_names), 1, figsize=(10, 8 * len(energy_names)))
    # plt.subplots returns a bare Axes (not an array) for a single subplot
    if len(energy_names) == 1:
        axs = [axs]
    for ax, energy_name in zip(axs, energy_names):
        initial_energy_error = initial_energy_absolute_errors_dict[energy_name]
        stopping_energy_error = stopping_energy_absolute_errors_dict[energy_name]
        ground_state_energy = energy_ground_state_dict[energy_name]
        scatter = ax.scatter(
            initial_energy_error,
            stopping_energy_error,
            c=ground_state_energy,  # set color by ground state energy
            cmap="seismic",
            s=40 if n_molecules < 50 else 10,
        )
        # add a diagonal line
        min_energy_error = min(min(initial_energy_error), min(stopping_energy_error))
        max_energy_error = max(max(initial_energy_error), max(stopping_energy_error))
        ax.plot(
            [min_energy_error, max_energy_error],
            [min_energy_error, max_energy_error],
            color="grey",
            alpha=0.6,
        )
        ax.set_xlabel(
            rf"$ \vert \Delta E_\mathrm{{{energy_name}, initial}} - \Delta E_\mathrm{{{energy_name}, target}} \vert$ [mHa] (initial)",
            fontsize=14,
        )
        ax.set_ylabel(
            rf"$ \vert \Delta E_\mathrm{{{energy_name}, pred}} - \Delta E_\mathrm{{{energy_name}, target}} \vert$ [mHa] (stopping)",
            fontsize=14,
        )
        cbar = plt.colorbar(scatter, ax=ax, pad=0.1)
        cbar.set_label(rf"$E_\mathrm{{{energy_name}, target}}$ [mHa]", fontsize=12)
        title = f"Absolute {energy_name} energy error of initial and stopping state"
        ax.set_title(title, fontsize=15)
def _add_energy_boxplot(
    ax: plt.Axes,
    energy_errors_dict: dict[str, torch.Tensor],
    energy_names: tuple[str] = None,
    title="Energy Contribution Errors",
    plot_electronic: bool = False,
):
    """Add a boxplot of the energy errors of different contributions for a set of molecules.

    Args:
        ax: Axes to draw the boxplot on.
        energy_errors_dict: Maps energy name to the per-molecule errors (in mHa).
        energy_names: Subset/order of energies to show; defaults to all dict keys.
        title: Plot title.
        plot_electronic: Whether to include the "electronic" contribution.
    """
    if not plot_electronic and "electronic" in energy_errors_dict.keys():
        # Copy before deleting so the caller's dict is not mutated
        energy_errors_dict = energy_errors_dict.copy()
        del energy_errors_dict["electronic"]
    if energy_names is None:
        energy_names = list(energy_errors_dict.keys())
    if not plot_electronic and "electronic" in energy_names:
        energy_names = list(energy_names)
        energy_names.remove("electronic")
    # "tot" is a duplicate of "total"; keep only one
    if "total" in energy_errors_dict.keys() and "tot" in energy_errors_dict.keys():
        del energy_errors_dict["tot"]
        del energy_names[energy_names.index("tot")]
    # seaborn works better with numpy arrays
    if isinstance(list(energy_errors_dict.values())[0], torch.Tensor):
        energy_errors_dict = {key: value.numpy() for key, value in energy_errors_dict.items()}
    sns.boxplot(
        data=[energy_errors_dict[energy_name] for energy_name in energy_names],
        ax=ax,
        flierprops=dict(marker=".", markersize=4),
    )
    ax.set_xticks([i for i in range(len(energy_names))], energy_names, fontsize=13)
    ax.set_ylabel(r" $\Delta E$ [mHa]", fontsize=14)
    ax.set_title(title, fontsize=15)
    ax.axhline(0, color="gray", alpha=0.5)  # Add a line for 0 difference
    # The symlog linear threshold is derived from the mean absolute total error
    total_energy_error_mean = np.mean(np.abs(energy_errors_dict["total"]))
    linthresh = custom_round(total_energy_error_mean)
    # Calculate the minimum and maximum values in the data with some padding
    min_value = min(min(values) for values in energy_errors_dict.values()) - (linthresh * 10e-2)
    max_value = max(max(values) for values in energy_errors_dict.values()) + (linthresh * 10e-2)
    # restrict y_axis to data range
    ax.set_ylim(min_value, max_value)
    # when all data is positive, set y-axis to log scale
    if min_value > 0:
        ax.set_yscale("log")
    else:
        ax.set_yscale("symlog", linthresh=linthresh)
        ax.yaxis.set_minor_locator(MinorSymLogLocator(linthresh))
        # Highlight linear region with solid lines
        ax.axhline(-linthresh, color="gray", linestyle="-", linewidth=0.7)
        ax.axhline(linthresh, color="gray", linestyle="-", linewidth=0.7)
    # custom gridlines
    ax.grid(True, which="major", linestyle="--", linewidth=0.5, alpha=0.5)
    # Add mean as text labels next to mean y-value
    for i, (key, values) in enumerate(energy_errors_dict.items()):
        mean_value = np.mean(values)
        ax.text(
            x=i,
            y=mean_value + 0.001,
            s=f"{mean_value:.2f}",
            ha="center",
            va="bottom",
            weight="bold",
            fontfamily="monospace",
            bbox=dict(
                facecolor="white", alpha=0.15, edgecolor="none", boxstyle="round,pad=0.1"
            ),  # Text background
        )
        # Draw a small horizontal line at the mean value
        ax.hlines(y=mean_value, xmin=i - 0.015, xmax=i + 0.015, color="darkgray", linewidth=2)
    # rotate x labels
    plt.setp(ax.get_xticklabels(), rotation=30, horizontalalignment="center")
[docs] def _compare_initial_to_stopping_l2_norms( ax: plt.Axes, initial_l2_norms: torch.Tensor | list[torch.Tensor], stopping_l2_norms: torch.Tensor | list[torch.Tensor], ground_state_energy_dict: dict[str, torch.Tensor | float], energy_name: str = "total", ): """Plot the initial density difference against the stopping density difference measured by the L2 norm.""" ground_state_energy = ground_state_energy_dict[energy_name] scatter = ax.scatter( initial_l2_norms, stopping_l2_norms, c=ground_state_energy, cmap="seismic", s=40 if len(initial_l2_norms) < 50 else 10, ) # add a diagonal line min_l2_norm = min(min(initial_l2_norms), min(stopping_l2_norms)) max_l2_norm = max(max(initial_l2_norms), max(stopping_l2_norms)) ax.plot([min_l2_norm, max_l2_norm], [min_l2_norm, max_l2_norm], color="grey", alpha=0.6) ax.set_xlabel( r"$\Vert\rho_\mathrm{pred} - \rho_\mathrm{target}\Vert_2$ [electrons] (initial)", fontsize=14, ) ax.set_ylabel( r"$\Vert\rho_\mathrm{init} - \rho_\mathrm{target}\Vert_2$ [electrons] (final)", fontsize=14 ) cbar = plt.colorbar(scatter, ax=ax, pad=0.008) cbar.set_label(rf"$E_\mathrm{{{energy_name}, target}}$ [mHa]", fontsize=13) title = "Initial vs. stopping L2 density difference" ax.set_title(title, fontsize=15)
[docs] def _compare_initial_to_stopping_gradient_norms( ax: plt.Axes, gradient_norms: torch.tensor, stopping_indices: torch.tensor, ground_state_energy_dict: dict[str, torch.Tensor | float], energy_name: str = "total", ): """Scatter plot the initial gradient norm against the stopping gradient norm.""" n_molecules = len(stopping_indices) if not isinstance(gradient_norms, torch.Tensor): try: gradient_norms = torch.stack(gradient_norms) initial_gradient_norms = gradient_norms[:, 0] stopping_gradient_norms = gradient_norms[torch.arange(n_molecules), stopping_indices] except Exception as e: initial_gradient_norms = [gradient_norm[0] for gradient_norm in gradient_norms] stopping_gradient_norms = [ gradient_norm[stopping_index] for gradient_norm, stopping_index in zip(gradient_norms, stopping_indices) ] else: initial_gradient_norms = gradient_norms[:, 0] stopping_gradient_norms = gradient_norms[torch.arange(n_molecules), stopping_indices] ground_state_energy = ground_state_energy_dict[energy_name] scatter = ax.scatter( initial_gradient_norms, stopping_gradient_norms, c=ground_state_energy, cmap="seismic", s=40 if n_molecules < 50 else 10, ) # add a diagonal line min_gradient_norm = min(min(initial_gradient_norms), min(stopping_gradient_norms)) max_gradient_norm = max(max(initial_gradient_norms), max(stopping_gradient_norms)) ax.plot( [min_gradient_norm, max_gradient_norm], [min_gradient_norm, max_gradient_norm], color="gray", alpha=0.6, ) ax.set_xlabel(r"Initial gradient norm", fontsize=14) ax.set_ylabel(r"Stopping gradient norm", fontsize=14) cbar = plt.colorbar(scatter, ax=ax, pad=0.008) cbar.set_label(rf"$E_\mathrm{{{energy_name}, target}}$ [mHa]", fontsize=13) title = "Initial vs. stopping gradient norm" ax.set_title(title, fontsize=15)
[docs] def _add_stopping_index_histogram(ax, stopping_indices): """Plot a histogram of the stopping indices.""" sns.histplot(stopping_indices, ax=ax) title = f"Histogram of stopping indices with median {torch.median(stopping_indices.detach().cpu())}" ax.set_title(title, fontsize=15)
[docs] def _add_energy_mae( ax: plt.Axes, energy_maes: torch.Tensor, energy_stopping_maes: torch.Tensor, energy_name: str = "total", color: str = "blue", ): """Add a plot of the mean energy absolute errors.""" plot_quantiles_data(energy_maes, ax, color=color) ax.axhline(0, color="black") ax.set_yticks([0, 1, 3, 6, 10, 100]) ax.set_ylim(-1e-1, max([energy_mae.max() for energy_mae in energy_maes]) * 1.05) ax.set_ylabel(r"$\Delta E$ [mHa]", fontsize=14) linthresh = 10 ax.set_yscale("symlog", linthresh=linthresh) ax.yaxis.set_minor_locator(MinorSymLogLocator(linthresh)) title = ( r"$\Delta E_\mathrm{" + energy_name.replace("_", r"\_") + "}$ with mean stopping index MAE " + f"{torch.mean(energy_stopping_maes).item():.4g} mHa" ) ax.set_xlabel("Iteration", fontsize=14) ax.set_title(title, fontsize=15)
[docs] def _add_mean_gradient_norm( ax: plt.Axes, gradient_norms: torch.Tensor, stopping_indices: torch.Tensor, color="blue" ): """Add a plot of the mean gradient norms.""" n_molecules = len(stopping_indices) if not isinstance(gradient_norms, torch.Tensor): try: gradient_norms = torch.stack(gradient_norms) stopping_gradient_norms = gradient_norms[torch.arange(n_molecules), stopping_indices] except Exception as e: stopping_gradient_norms = torch.stack( [ gradient_norm[stopping_index] for gradient_norm, stopping_index in zip(gradient_norms, stopping_indices) ] ) plot_quantiles_data(gradient_norms, ax, color=color) ax.set_ylabel("Gradient norm [Ha]", fontsize=14) ax.set_yscale("log") title = ( "Average stopping Gradient norm " + f"{torch.mean(stopping_gradient_norms).item():.4g}" ) ax.set_xlabel("Iteration", fontsize=14) ax.set_title(title, fontsize=15)
[docs] def _add_mean_density_differences_l2( ax: plt.Axes, density_differences_l2: torch.Tensor, stopping_density_differences_l2: torch.Tensor, color: str = "blue", ): """Add a plot of the mean density differences.""" plot_quantiles_data(density_differences_l2, ax, color=color) ax.axhline(0, color="black") ax.set_ylabel( r"$\Vert\rho_\mathrm{pred} - \rho_\mathrm{target}\Vert_2$ [electrons]", fontsize=14 ) ax.set_yscale("log") ax.grid() title = ( "Mean Density differences with mean stopping L2 difference " + f"{torch.mean(stopping_density_differences_l2).item():.4g} electrons" ) ax.set_xlabel("Iteration", fontsize=14) ax.set_title(title, fontsize=15)
[docs] def _add_mean_relative_density_differences_l1( ax: plt.Axes, num_electrons: torch.Tensor, density_differences_l1: torch.Tensor, stopping_density_differences_l1: torch.Tensor, color: str = "blue", ): """Add a plot of the mean density difference l1 norm divided by the number of electrons.""" if isinstance(density_differences_l1, torch.Tensor): relative_density_difference_l1 = density_differences_l1 / num_electrons.unsqueeze(-1) else: relative_density_difference_l1 = [ l1_diff / e for l1_diff, e in zip(density_differences_l1, num_electrons) ] stopping_relative_density_difference_l1 = stopping_density_differences_l1 / num_electrons plot_quantiles_data(relative_density_difference_l1, ax, color=color) ax.axhline(0, color="black") ax.set_ylabel( r"$\Vert\rho_\mathrm{pred} - \rho_\mathrm{target}\Vert_1$ [electrons]", fontsize=14 ) ax.set_yscale("log") ax.grid() title = ( "Mean Density differences l1 norm over number of electrons, stop. ind: " + f"{torch.mean(stopping_relative_density_difference_l1).item():.4g} electrons" ) ax.set_xlabel("Iteration", fontsize=14) ax.set_title(title, fontsize=15)
[docs] def stopping_index_line_plot( ax: plt.Axes, data_array: torch.Tensor | np.ndarray, stopping_index: int, color="blue", **kwargs, ): """Plot the data array as solid line up to the stopping index and dashed line afterwards.""" ax.plot( range(stopping_index), data_array[:stopping_index], linestyle="-", color=color, **kwargs, ) ax.plot( range(stopping_index - 1 if stopping_index != 0 else 0, len(data_array)), data_array[stopping_index - 1 if stopping_index != 0 else 0 :], linestyle="--", color=color, )
[docs] def get_density_difference_l2_norm(sample: OFData): """Get the L2 norm of the difference between the predicted and target densities for a single sample.""" coeffs_delta = ( sample.trajectory_coeffs.detach().cpu() - sample.ground_state_coeffs.detach().cpu() ) # assume that the overlap matrix is correctly transformed if hasattr(sample, "overlap_matrix"): overlap = sample.overlap_matrix.detach().cpu() l2_norm = torch.sqrt(torch.einsum("ni,ij,nj->n", coeffs_delta, overlap, coeffs_delta)) return l2_norm else: logger.warning( "Sample does not have an overlap matrix. Density difference plotting is skipped" ) return
[docs] def get_density_difference_l1_norm( sample: OFData, basis_info: BasisInfo, l1_grid_level: int = 3, l1_grid_prune: str = "nwchem_prune", ) -> torch.Tensor: """Get the L1 norm of the difference between the predicted and target densities for a single sample.""" coeffs_delta = ( sample.predicted_ground_state_coeffs.detach().cpu() - sample.ground_state_coeffs.detach().cpu() ) mol = build_molecule_ofdata(sample, basis_info.basis_dict) # build grid grid = dft.Grids(mol) grid.level = l1_grid_level grid.prune = l1_grid_prune grid.build() density_difference = compute_density_density_basis(mol, grid, coeffs_delta) grid_weights = torch.as_tensor(grid.weights) l1_norm_difference = torch.dot(torch.abs(density_difference), grid_weights) return l1_norm_difference
def _single_shape_coeff_mae_scatter(
    ax: plt.Axes,
    coeff_errors: torch.Tensor | list[torch.Tensor],
    basis_info: BasisInfo,
    basis_function_ind: torch.Tensor,
):
    """Plots the mean absolute density coefficient error over basis dimensions.

    This simple version can be called, if all molecules display the same coefficient shape.

    Args:
        ax: Axes to draw on.
        coeff_errors: Per-molecule coefficient errors, stackable to shape
            (n_molecules, n_basis_functions).
        basis_info: The dataset's basis info, used to format the x-axis. If None, the plot
            is skipped with an info message.
        basis_function_ind: Basis function index per coefficient dimension, used as x
            positions. If None, the plot is skipped with an info message.
    """
    if basis_info is None:
        logger.info(
            "Basis info not found in sample directory. Skipping ao scatter plot of coefficient "
            "mean absolute error."
        )
        return
    if basis_function_ind is None:
        logger.info(
            "Basis function indices not found in sample directory. Skipping ao scatter plot of "
            "coefficient mean absolute error."
        )
        return
    if not isinstance(coeff_errors, torch.Tensor):
        try:
            coeff_errors = torch.stack(coeff_errors)
        except Exception as e:
            # Differently shaped coefficient vectors cannot be stacked; the multi-shape
            # variant must be used instead.
            logger.info(
                f"""Could not convert coeff errors to tensor: {e}.
                Coeff mean absolute error scatter plot will be skipped."""
            )
            return
    coeffs_absolute_error = torch.abs(coeff_errors)
    # median returns a tuple of val and indices
    # NOTE(review): despite the "MAE" naming, the central value plotted is the *median*
    # absolute error over molecules (as the legend patch below states).
    coeffs_mae = torch.median(coeffs_absolute_error, dim=0)[0]
    coeffs_mae_lower_bound = []
    coeffs_mae_upper_bound = []
    # Did not find a way to make this work without a loop
    # Index-based 10%/90% bounds: pick the sorted element at 0.1*n / 0.9*n per dimension.
    for i in range(coeffs_absolute_error.shape[1]):
        sorted_errors = coeffs_absolute_error[:, i].sort().values
        lower_bound = sorted_errors[int(0.1 * sorted_errors.shape[0])]
        upper_bound = sorted_errors[int(0.9 * sorted_errors.shape[0])]
        coeffs_mae_lower_bound.append(lower_bound)
        coeffs_mae_upper_bound.append(upper_bound)
    coeffs_mae_lower_bound = torch.stack(coeffs_mae_lower_bound)
    coeffs_mae_upper_bound = torch.stack(coeffs_mae_upper_bound)
    # Shift by 0.5 to center markers within each basis function slot.
    x = basis_function_ind.detach().cpu() + 0.5
    # Use fill_between with sorted data
    ax.errorbar(
        x,
        coeffs_mae,
        # Asymmetric error bars spanning the 10-90 interquantile range around the median.
        yerr=[coeffs_mae - coeffs_mae_lower_bound, coeffs_mae_upper_bound - coeffs_mae],
        fmt=".",
        markersize=4,
        elinewidth=0.5,
        ecolor="C1",
        label="Mean Absolute Error",
    )
    handles = [
        Patch(color="C0", label="Median Absolute Error"),
        Patch(color="C1", label="Interquantile Range (10-90)"),
    ]
    ax.legend(handles=handles, loc="upper right")
    format_basis_func_xaxis(ax, basis_info)
    title = "Mean Absolute Error of Coefficients"
    ax.set_ylabel(
        r"$\left \langle \vert \rho_{\mathrm{pred}} - \rho_{\mathrm{target}} \vert \right \rangle$ [electrons]",
        fontsize=14,
    )
    ax.set_yscale("log")
    # horizontal line for 0 error
    ax.axhline(0, color="grey", alpha=0.5)
    ax.set_title(title, fontsize=15)
[docs] def _multi_shape_coeff_mae_scatter( ax: plt.Axes, cumulative_coeff_errors: torch.Tensor, cumulative_counts: torch.Tensor, basis_function_indices: torch.Tensor | np.ndarray, basis_info: BasisInfo, ): """Plots the mean absolute coefficient error over basis dimensions for possibly varying number of basis dimensions. Args: cumulative_coeff_errors (torch.Tensor): Cumulative sum of coefficient absolute errors retrieved by :func:`cumulate_coeff_error`. cumulative_counts (torch.Tensor): Counts of basis dimension appearances retrieved by :func:`cumulate_coeff_error`. basis_function_indices (torch.Tensor | np.ndarray): A tensor of all basis dimensions present in the basis_info. Differs from a sample.basis_function_ind. basis_info (BasisInfo): The dataset's basis_info. """ coeff_maes = cumulative_coeff_errors / cumulative_counts x = basis_function_indices.detach().cpu() + 0.5 scatter_kwargs = dict(s=40, alpha=0.6, marker=".", color="C0") ax.scatter(x, coeff_maes, **scatter_kwargs) format_basis_func_xaxis(ax, basis_info) title = "Mean Absolute Error of Coefficients" ax.set_ylabel( r"$\left \langle \vert \rho_{\mathrm{pred}} - \rho_{\mathrm{target}} \vert \right \rangle$ [electrons]", fontsize=14, ) ax.set_yscale("log") # horizontal line for 0 error ax.axhline(0, color="grey", alpha=0.5) ax.set_title(title, fontsize=15)
[docs] def plot_mean_and_fill_between(data: torch.Tensor, ax: plt.Axes, color: str = "blue"): """Plot the mean of the data and fill the area between mean - std and mean + std.""" # Calculate mean and standard deviation if isinstance(data, torch.Tensor): data_means = torch.mean(data, axis=0) data_stds = torch.std(data, axis=0) elif isinstance(data, np.ndarray): data_means = np.mean(data, axis=0) data_stds = np.std(data, axis=0) else: try: data = torch.stack(data) data_means = torch.mean(data, axis=0) data_stds = torch.std(data, axis=0) except Exception as _: data_means, data_stds = get_overlapping_mean(data=data) ax.plot(data_means, color=color) # Plot confidence interval ax.fill_between( range(len(data_means)), data_means - data_stds, data_means + data_stds, color=color, alpha=0.2, )
[docs] def plot_quantiles_data( data: torch.Tensor | np.ndarray, ax: plt.Axes, color: str = "blue", quantiles=[0.0, 0.1, 0.9, 1.0], quantile_colors=["lightgreen", "blue", "lightcoral"], ): """Plot ensemble data on a given Axes object. Parameters: - data: 2D array of shape (num_datasets, num_points) - ax: matplotlib Axes object to plot on - color: color for the mean line (default: 'blue') - quantiles: list of quantiles to plot (default: [0.,0.1,0.9,1.]) - quantile_colors: list of colors for quantile areas (default: ["lightgreen","blue", "lightcoral"]) """ # Calculate mean and quantiles if isinstance(data, torch.Tensor): if data.is_cuda: data = data.cpu() data_y = data.detach().numpy() mean_y = np.mean(data_y, axis=0) quantile_y = np.quantile(data_y, quantiles, axis=0) elif isinstance(data, np.ndarray): data_y = data mean_y = np.mean(data_y, axis=0) quantile_y = np.quantile(data_y, quantiles, axis=0) else: try: data_y = np.asarray(data) mean_y = np.mean(data_y, axis=0) quantile_y = np.quantile(data_y, quantiles, axis=0) except Exception as e: mean_y, quantile_y = get_overlapping_quantiles(data=data, quantiles=quantiles) mean_y = mean_y.detach().numpy() quantile_y = quantile_y.detach().numpy() mean_color = color # Create the plot x = np.arange(1, len(mean_y) + 1) # Plot mean line ax.plot(x, mean_y, label="Mean", color=mean_color, linewidth=2) # Plot quantile areas for i in range(len(quantiles) - 1): ax.fill_between( x, quantile_y[i], quantile_y[i + 1], alpha=0.3, color=quantile_colors[i], label=f"{quantiles[i] * 100}% - {quantiles[(i + 1)] * 100}% Quantile", ) # set y-limit to the data and quantile range min_y = np.min(quantile_y[0] * 0.9) max_y = np.max(quantile_y[-1] * 1.1) ax.set_ylim(min_y, max_y) ax.legend() ax.grid(True, alpha=0.3)
[docs] def get_energy_error(sample: OFData, energy_name: str): """Get the (signed) energy error of 'energy_name' for a single sample in mHa.""" predicted_energy = sample[f"trajectory_energy_{energy_name}"].detach().cpu() energy_label = sample[f"ground_state_energy_{energy_name}"] energy_error = 1e3 * (predicted_energy - energy_label) return energy_error
[docs] def get_energy_errors_dict( energy_trajectories_dict: dict[str, list[torch.Tensor]], energy_ground_state_dict: dict[str, torch.Tensor], energy_names: tuple[str] = None, ) -> tuple[dict[str, list], dict[str, list]]: """Calculate the (signed) energy errors for a set of molecules from the energy trajectories and their ground state values.""" if energy_names is None: energy_names = list(energy_ground_state_dict.keys()) energy_errors_dict = {} energy_absolute_errors_dict = {} for energy_name in energy_names: predicted_energies = energy_trajectories_dict[energy_name] try: energy_labels = energy_ground_state_dict[energy_name] except KeyError: logger.warning( f"ground state energy {energy_name} not found in sample." f"Energy errors for {energy_name} will not be plotted" ) continue # predicted_energies is a list of length n_molecules with energies of possibly different # shape, allowing the number of cycles to vary energy_errors, energy_absolute_errors = [], [] for predicted_energy, energy_label in zip(predicted_energies, energy_labels): error = predicted_energy - energy_label energy_errors.append(error) energy_absolute_errors.append(torch.abs(error)) energy_errors_dict[energy_name] = energy_errors energy_absolute_errors_dict[energy_name] = energy_absolute_errors return energy_errors_dict, energy_absolute_errors_dict
def get_energy_ground_state_errors( energy_trajectories_dict: dict[str, list[torch.Tensor]], energy_ground_state_dict: dict[str, torch.Tensor], stopping_indices: torch.Tensor, energy_names: tuple[str] = None, ): energy_ground_state_errors, energy_ground_state_absolute_errors = {}, {} if energy_names is None: energy_names = list(energy_ground_state_dict.keys()) for energy_name in energy_names: energy_ground_state_errors[energy_name] = torch.empty( len(energy_ground_state_dict[energy_name]), dtype=torch.float64 ) energy_ground_state_absolute_errors[energy_name] = torch.empty( len(energy_ground_state_dict[energy_name]), dtype=torch.float64 ) for i, (predicted_energy, energy_label) in enumerate( zip(energy_trajectories_dict[energy_name], energy_ground_state_dict[energy_name]) ): stopping_index = stopping_indices[i] ground_state_error = predicted_energy[stopping_index] - energy_label energy_ground_state_errors[energy_name][i] = ground_state_error.item() energy_ground_state_absolute_errors[energy_name][i] = torch.abs( ground_state_error ).item() return energy_ground_state_errors, energy_ground_state_absolute_errors
def energy_error_swarm_line_plot(
    ax: plt.Axes,
    stopping_indices: torch.Tensor,
    energy_name: str,
    energy_errors_dict: dict[str, torch.Tensor] = None,
    energy_errors: torch.Tensor = None,
    n_molecules: int = None,
    subsample: float = 1.0,
):
    """Plot the energy error for a set of molecules as a swarm plot with a line for each molecule.

    Args:
        ax: Axes to draw on.
        stopping_indices: Stopping iteration per molecule; each trajectory is drawn solid up
            to its index and dashed afterwards (see :func:`stopping_index_line_plot`).
        energy_name: Energy contribution name, used for the title and as key into
            ``energy_errors_dict``.
        energy_errors_dict: Maps energy names to per-molecule error trajectories. Exactly one
            of ``energy_errors_dict`` and ``energy_errors`` must be provided.
        energy_errors: Per-molecule error trajectories in mHa.
        n_molecules: Number of molecules to draw. Defaults to ``len(stopping_indices)``.
        subsample: Fraction of molecules to draw; the trajectories with the largest and
            smallest mean error are always kept and highlighted.

    Raises:
        ValueError: If none or both of ``energy_errors_dict`` and ``energy_errors`` are given.
    """
    if n_molecules is None:
        n_molecules = len(stopping_indices)
    if energy_errors_dict is None and energy_errors is None:
        raise ValueError("Either energy_errors_dict or energy_errors must be provided.")
    if energy_errors_dict is not None and energy_errors is not None:
        raise ValueError("Only one of energy_errors_dict or energy_errors must be provided.")
    elif energy_errors_dict is not None:
        energy_errors = energy_errors_dict[energy_name]
    # else energy_errors is already provided
    ax.axhline(0, color="black")  # 0 line for reference
    ax.set_ylabel(r"$\Delta E$ [mHa]", fontsize=14)
    ax.set_yscale("symlog", linthresh=10)
    title = r"$\Delta E_\mathrm{" + energy_name.replace("_", r"\_") + "}$"
    ax.set_title(title, fontsize=15)
    ax.grid(True)
    # NOTE: We do not use the subsample function as we sample in dependence of max and min energy error
    if subsample < 1.0 and n_molecules > 2:
        n_samples = len(energy_errors)
        # always get the trajectory with the max and min energy error
        # energy_errors is a list of length n_molecules, with tensors of varying number of denop iterations
        mean_energy_errors = torch.zeros(len(energy_errors))
        for i, mol_error_trajectory in enumerate(energy_errors):
            mean_energy_errors[i] = torch.mean(mol_error_trajectory)
        max_mean_energy_error_idx = torch.argmax(mean_energy_errors)
        min_mean_energy_error_idx = torch.argmin(mean_energy_errors)
        # plot min and max first
        stopping_index_line_plot(
            ax=ax,
            data_array=energy_errors[max_mean_energy_error_idx],
            stopping_index=stopping_indices[max_mean_energy_error_idx],
            color="darkred",
            label="Max mean energy error",
        )
        stopping_index_line_plot(
            ax=ax,
            data_array=energy_errors[min_mean_energy_error_idx],
            stopping_index=stopping_indices[min_mean_energy_error_idx],
            color="darkgreen",
            label="Min mean energy error",
        )
        ax.legend()
        # Then subsample the rest omitting the max and min
        subsample_indices = torch.randperm(n_samples)
        # remove the max and min indices
        subsample_indices = subsample_indices[
            (subsample_indices != max_mean_energy_error_idx)
            & (subsample_indices != min_mean_energy_error_idx)
        ]
        # -2 accounts for the two extreme trajectories that were already drawn.
        subsample_indices = subsample_indices[: int(subsample * n_samples - 2)]
        energy_errors = [energy_errors[subsample_idx] for subsample_idx in subsample_indices]
        stopping_indices = stopping_indices[subsample_indices]
        # overwrite n_molecules if subsampled
        n_molecules = len(subsample_indices)
    for i, energy_error_trajectory in enumerate(energy_errors[: n_molecules + 1]):
        stopping_index = stopping_indices[i]
        # rainbow colormap spreads molecules over distinct hues; epsilon avoids 0/0 for a
        # single molecule
        color = plt.get_cmap("rainbow")(i / (n_molecules - 1 + 1e-6))
        stopping_index_line_plot(
            ax=ax,
            data_array=energy_error_trajectory,
            stopping_index=stopping_index,
            color=color,
        )
[docs] def gradient_norm_swarm_line_plot( ax: plt.Axes, stopping_indices: torch.Tensor, gradient_norms: torch.Tensor, n_molecules: int = None, subsample: float = 1.0, ): """Plot the gradient norms for a set of molecules as a swarm plot with a line for each molecule.""" if n_molecules is None: n_molecules = len(stopping_indices) ax.axhline(0, color="black") ax.set_yscale("log") ax.set_title("Gradient norm", fontsize=15) ax.grid(True) if subsample < 1.0 and n_molecules > 2: gradient_norms, stopping_indices, n_molecules = subsample_swarm( ax=ax, n_molecules=n_molecules, subsample=subsample, data_array=gradient_norms, stopping_indices=stopping_indices, label="gradient norm", ) for i, gradient_norm_trajectory in enumerate(gradient_norms[: n_molecules + 1]): stopping_index = stopping_indices[i] color = plt.get_cmap("rainbow")(i / (n_molecules - 1 + 1e-6)) stopping_index_line_plot( ax=ax, data_array=gradient_norm_trajectory, stopping_index=stopping_index, color=color, )
[docs] def density_differences_swarm_line_plot( ax: plt.Axes, stopping_indices: torch.Tensor, density_differences_l2: torch.Tensor, n_molecules: int = None, subsample: float = 1.0, ): """Plot the energy error for a set of molecules as a swarm plot with a line for each molecule.""" if n_molecules is None: n_molecules = len(stopping_indices) ax.axhline(0, color="black") ax.set_yscale("symlog", linthresh=1) ax.set_title("Density difference", fontsize=15) ax.set_ylabel( r"$\Vert\rho_\mathrm{pred} - \rho_\mathrm{target}\Vert_2$ [electrons]", fontsize=14 ) ax.grid(True) if subsample < 1.0 and n_molecules > 2: density_differences_l2, stopping_indices, n_molecules = subsample_swarm( ax=ax, n_molecules=n_molecules, subsample=subsample, data_array=density_differences_l2, stopping_indices=stopping_indices, label="density difference", ) for i, density_difference_trajectory in enumerate(density_differences_l2[: n_molecules + 1]): stopping_index = stopping_indices[i] color = plt.get_cmap("rainbow")(i / (n_molecules - 1 + 1e-6)) stopping_index_line_plot( ax=ax, data_array=density_difference_trajectory, stopping_index=stopping_index, color=color, )
[docs] def subsample_swarm( ax: plt.Axes, n_molecules: int, subsample: float, data_array: torch.Tensor, stopping_indices: torch.Tensor, label: str = "error", ) -> tuple[torch.Tensor, torch.Tensor, int]: """Subsample the data array and stopping indices and plot the min and max on given axis.""" if not isinstance(data_array, torch.Tensor): try: data_array = torch.stack(data_array) except Exception as e: raise NotImplementedError( f"Could not convert data to tensor: {e}. List handling not yet implemented for swarm line plot." ) n_samples = len(data_array) if subsample < 1.0 and n_molecules > 2: # get the trajectory with the max and min energy error max_mean_data_idx = torch.argmax(data_array.mean(dim=1)) min_mean_data_idx = torch.argmin(data_array.mean(dim=1)) # plot min and max first stopping_index_line_plot( ax=ax, data_array=data_array[max_mean_data_idx], stopping_index=stopping_indices[max_mean_data_idx], color="darkred", label=f"Max mean {label}", ) stopping_index_line_plot( ax=ax, data_array=data_array[min_mean_data_idx], stopping_index=stopping_indices[min_mean_data_idx], color="darkgreen", label=f"Min mean {label}", ) ax.legend() # Then subsample the rest omitting the max and min subsample_indices = torch.randperm(n_samples) # remove the max and min indices subsample_indices = subsample_indices[ (subsample_indices != max_mean_data_idx) & (subsample_indices != min_mean_data_idx) ] subsample_indices = subsample_indices[: int(subsample * n_samples - 2)] data_array = data_array[subsample_indices] stopping_indices = stopping_indices[subsample_indices] # overwrite n_molecules if subsampled n_molecules = len(subsample_indices) return data_array, stopping_indices, n_molecules
[docs] def get_overlapping_mean(data: list[Union[torch.Tensor, np.ndarray]]): """Compute the mean of a list of 1D tensors/arrays with different length for overlapping regions.""" sorted_data = sorted(data, key=len) # ascending in length overlapping_mean = torch.zeros(size=(len(sorted_data[-1]),)) overlapping_std = torch.zeros(size=(len(sorted_data[-1]),)) tensor_lengths = torch.tensor([len(tensor) for tensor in sorted_data]) previous_ending_index = 0 for i, length in enumerate(tensor_lengths): query_tensors = torch.stack( [tensor[previous_ending_index:length] for tensor in sorted_data[i:]] ) overlapping_mean[previous_ending_index:length] = torch.mean(query_tensors, axis=0) overlapping_std[previous_ending_index:length] = torch.std(query_tensors, axis=0) previous_ending_index = length return overlapping_mean, overlapping_std
[docs] def get_overlapping_quantiles(data: list[Union[torch.Tensor, np.ndarray]], quantiles: list[float]): """Compute the quantiles of a list of 1D tensors/arrays with different length for overlapping regions.""" sorted_data = sorted(data, key=len) overlapping_quantiles = torch.zeros(size=(len(quantiles), len(sorted_data[-1]))) overlapping_mean = torch.zeros(size=(len(sorted_data[-1]),)) tensor_lengths = torch.tensor([len(tensor) for tensor in sorted_data]) previous_ending_index = 0 for i, length in enumerate(tensor_lengths): if length - previous_ending_index == 0: continue query_tensors = torch.stack( [tensor[previous_ending_index:length] for tensor in sorted_data[i:]] ) overlapping_mean[previous_ending_index:length] = torch.mean(query_tensors, axis=0) overlapping_quantiles[:, previous_ending_index:length] = torch.quantile( query_tensors, torch.tensor(quantiles, dtype=query_tensors.dtype), dim=0 ) previous_ending_index = length return overlapping_mean, overlapping_quantiles
[docs] def cumulate_coeff_error( sample: OFData, pred_ground_state_coeffs: torch.Tensor, cumulative_error: torch.Tensor, cumulative_counts: torch.Tensor, ) -> tuple[torch.Tensor, torch.Tensor]: """Cumulates the absolute error of coefficients by adding the (stopping) basis_function wise error onto the cumulative error. This is necessary if density coefficient dimensions vary across molecules. The cumulative counts are returned as well in order to build the average afterwards. """ coeffs_ground_state = sample.ground_state_coeffs.detach().cpu() absolute_error = torch.abs(pred_ground_state_coeffs - coeffs_ground_state) # For compatibility with old samples saved from the M-OFDFT code if isinstance(sample.basis_function_ind, np.ndarray): sample.basis_function_ind = torch.tensor(sample.basis_function_ind, dtype=torch.long) counts = torch.ones_like(sample.basis_function_ind) scatter( src=absolute_error, index=sample.basis_function_ind, out=cumulative_error, reduce="sum" ) scatter(src=counts, index=sample.basis_function_ind, out=cumulative_counts, reduce="sum") return cumulative_error, cumulative_counts
[docs] def custom_round(value: float): """Determine value's order of magnitude, find the next highest integer in that magnitude and add one integer.""" order_of_magnitude = 10 ** np.floor(np.log10(value)) next_highest = np.ceil(value / order_of_magnitude) * order_of_magnitude rounded_value = next_highest + order_of_magnitude return rounded_value
def _add_energy_histogram(
    ax: plt.Axes,
    energy_errors_dict: dict,
    energy_name="total",
    title="Stopping Energy Error Distribution",
    cut_off=10**3,
):
    """Add a histogram of the total energy error for a set of molecules.

    Args:
        ax: Axes to draw on.
        energy_errors_dict: Maps energy names to per-molecule stopping errors in mHa
            (torch tensors).
        energy_name: Which entry of ``energy_errors_dict`` to plot.
        title: Suffix of the plot title.
        cut_off: Range limit; outliers are clamped onto the range boundary and counted in
            the legend. NOTE(review): passing a falsy ``cut_off`` leaves ``outside_count``
            (and the cutoffs) undefined and would raise a NameError below — confirm callers
            always use a truthy value.
    """
    energy_errors_dict = {key: value.numpy() for key, value in energy_errors_dict.items()}
    filtered_data = energy_errors_dict[energy_name]
    # if all values are positive, we can use a log scale
    if np.min(filtered_data) > 0:
        if cut_off:
            # All-positive data: clamp to the range [1/cut_off, cut_off].
            lower_cutoff = max(1 / cut_off, np.min(filtered_data))
            upper_cutoff = min(cut_off, np.max(filtered_data))
            # map all outliers to the cutoffs
            filtered_data = [
                lower_cutoff if x < lower_cutoff else upper_cutoff if x > upper_cutoff else x
                for x in filtered_data
            ]
            # count the data points outside the cutoffs
            outside_count = len(
                [
                    x
                    for x in energy_errors_dict[energy_name]
                    if x < lower_cutoff or x > upper_cutoff
                ]
            )
        # Linear threshold of the symlog scale is derived from the mean absolute error.
        total_energy_error_mean = np.mean(np.abs(filtered_data))
        ax.set_xlim(left=lower_cutoff, right=upper_cutoff)
        linthresh = custom_round(total_energy_error_mean)
        ax.set_xscale("symlog", linthresh=linthresh)
        ax.xaxis.set_minor_locator(MinorSymLogLocator(linthresh))
        ax.axvline(linthresh, color="gray", linestyle="-", linewidth=0.7)
        sns.histplot(
            filtered_data,
            ax=ax,
            label=f"{energy_name} (outside range: {outside_count})",
            stat="density",
            bins="auto",
        )
    else:
        if cut_off:
            # Signed data: clamp symmetrically to [-cut_off, cut_off].
            lower_cutoff = max(-cut_off, np.min(filtered_data))
            upper_cutoff = min(cut_off, np.max(filtered_data))
            ax.set_xlim(lower_cutoff, upper_cutoff)
            filtered_data = [
                lower_cutoff if x < lower_cutoff else upper_cutoff if x > upper_cutoff else x
                for x in filtered_data
            ]
            outside_count = len(
                [
                    x
                    for x in energy_errors_dict[energy_name]
                    if x < lower_cutoff or x > upper_cutoff
                ]
            )
        total_energy_error_mean = np.mean(np.abs(filtered_data))
        linthresh = custom_round(total_energy_error_mean)
        ax.set_xscale("symlog", linthresh=linthresh)
        ax.xaxis.set_minor_locator(MinorSymLogLocator(linthresh))
        # Highlight linear regions
        ax.axvline(-linthresh, color="gray", linestyle="-", linewidth=0.7)
        ax.axvline(linthresh, color="gray", linestyle="-", linewidth=0.7)
        sns.histplot(
            filtered_data,
            ax=ax,
            label=f"{energy_name} (outside range: {outside_count})",
            stat="density",
            bins="auto",
        )
    ax.set_xlabel(r"$\Delta E$ [mHa]", fontsize=14)
    ax.set_ylabel("Density", fontsize=14)
    ax.set_title(f"{energy_name} " + title, fontsize=15)
    ax.legend()
[docs] def _add_energy_error_over_n_atoms( ax: plt.Axes, energy_errors_dict: dict, num_atoms: torch.Tensor, energy_name="total", title="Absolute Stopping Energy Error vs Number of Atoms", ): """Add a scatter plot of the total energy error vs the number of atoms in the molecule.""" energy_errors = energy_errors_dict[energy_name].numpy() num_atoms = num_atoms.cpu().numpy() # Add a small amount of horizontal jitter to the num_atoms values jitter_strength = 0.05 # Adjust this value as needed to get the desired amount of jitter jitter = np.random.uniform(-jitter_strength, jitter_strength, num_atoms.size) num_atoms_jittered = num_atoms + jitter ax.scatter(num_atoms_jittered, energy_errors, alpha=0.5, s=40 if len(num_atoms) < 50 else 10) ax.set_ylabel(r"$\Delta E$ [mHa]", fontsize=14) ax.set_xlabel(r"Number of Atoms", fontsize=14) ax.set_title(energy_name.capitalize() + " " + title, fontsize=15) # get x-tick labels for all n_atoms in the dataset x_tick_labels = np.unique(num_atoms) ax.set_xticks(x_tick_labels) ax.set_yscale("log")