Source code for mldft.ml.callbacks.image_logging
import logging
from typing import Any
import lightning as pl
import matplotlib.pyplot as plt
import numpy as np
import torch
from lightning.pytorch.utilities.types import STEP_OUTPUT
from matplotlib.figure import Figure
from matplotlib.patches import Patch
from pyscf.data.elements import ELEMENTS
from mldft.ml.callbacks.base import OnStepCallbackWithTiming
from mldft.ml.data.components.basis_info import BasisInfo
from mldft.ml.models.mldft_module import MLDFTLitModule
from mldft.utils.plotting.axes import format_basis_func_xaxis
from mldft.utils.plotting.scatter import comparison_scatter
logger = logging.getLogger(__name__)
[docs]
class LogMatplotlibToTensorboard(OnStepCallbackWithTiming):
"""Base Class to log matplotlib figures to tensorboard."""
[docs]
def get_figure(
self,
pl_module: MLDFTLitModule,
batch: Any,
outputs: STEP_OUTPUT,
basis_info: BasisInfo,
) -> Figure:
"""Create the figure to be plotted.
Args:
pl_module: The lightning module.
batch: The batch.
outputs: The outputs of the lightning module.
basis_info: The :class:`BasisInfo` object.
Returns:
Figure: The figure to be plotted.
"""
raise NotImplementedError
[docs]
def execute(
self,
trainer: pl.Trainer,
pl_module: MLDFTLitModule,
outputs: STEP_OUTPUT,
batch: Any,
split: str,
) -> None:
"""Generates a figure using :meth:`get_figure` and logs it to tensorboard."""
logger.debug(f"Logging training figure to tensorboard from {self.__class__.__name__}")
basis_info = pl_module.basis_info
tb_logger = pl_module.tensorboard_logger
tb_logger.add_figure(
f"{split}_{self.name}",
self.get_figure(pl_module, batch, outputs, basis_info),
global_step=trainer.global_step,
)
[docs]
class LogTargetPredScatters(LogMatplotlibToTensorboard):
"""Logs scatter plots of the target and predicted energies, gradients and initial guess
deltas."""
[docs]
def __init__(self, with_atom_ref: bool | str = "auto", **super_kwargs):
"""
Args:
with_atom_ref (bool | str): Whether to add two additional plots of energy / gradient minus their AtomRef
values. If 'auto', it is checked whether the model has a property ``atom_ref_module``.
Defaults to 'auto'.
**super_kwargs: Additional kwargs for the superclass.
"""
super().__init__(**super_kwargs)
assert with_atom_ref in [
True,
False,
"auto",
], f'with_atom_ref must be one of [True, False, "auto"], got {with_atom_ref}'
self.with_atom_ref = with_atom_ref
[docs]
def get_figure(
self,
pl_module: MLDFTLitModule,
batch: Any,
outputs: STEP_OUTPUT,
basis_info: BasisInfo,
) -> Figure:
"""Create two scatter plots: One for the energies, and one for the gradients.
Args:
pl_module: The lightning module.
batch: The batch.
outputs: The outputs of the lightning module.
basis_info: The :class:`BasisInfo` object.
Returns:
Figure: The figure to be plotted.
"""
# mask out samples without energy labels for energy and gradient plots
mask = batch.has_energy_label.detach().cpu().numpy()
# gradients
gradient_label = batch.gradient_label.detach().cpu().numpy()
# pred_gradients = outputs["model_outputs"]["pred_gradients"].detach().cpu().numpy()
projected_gradient_difference = (
outputs["projected_gradient_difference"].detach().cpu().numpy()
)
# energies
energy_labels = batch.energy_label.detach().cpu().numpy()
pred_energies = outputs["model_outputs"]["pred_energy"].detach().cpu().float().numpy()
# initial guess delta
pred_diff = outputs["model_outputs"]["pred_diff"].detach().cpu().float().numpy()
gt_diff = (batch.coeffs - batch.ground_state_coeffs).detach().cpu().float().numpy()
# prepare to color by scf iteration
scf_iteration = batch.scf_iteration.cpu().numpy()
scf_iteration_per_basis_func = scf_iteration[batch.coeffs_batch.cpu().numpy()]
n_scf_iteration_colors = 7 # iterations 0-5, 6+ is the same color
scf_iteration_colors = np.array([f"C{i}" for i in range(n_scf_iteration_colors)])
scf_iteration_colors_per_coeff = scf_iteration_colors[
np.clip(scf_iteration_per_basis_func, 0, len(scf_iteration_colors) - 1)
]
scf_iteration_handles = [
Patch(
facecolor=color,
label=i if i < len(scf_iteration_colors) - 1 else f">{i-1}",
)
for i, color in enumerate(scf_iteration_colors)
]
scf_iteration_legend_kwargs = dict(
handles=scf_iteration_handles,
title="SCF Iteration",
loc="lower right",
fontsize="small",
)
# prepare to color by atom type
atom_ind = batch.atom_ind.cpu().numpy()
atom_ind_per_coeff = atom_ind[batch.coeff_ind_to_node_ind.cpu().numpy().flatten()]
atom_ind_colors = np.array([f"C{str(i).zfill(2)}" for i in range(basis_info.n_types)])
atom_ind_colors_per_coeff = atom_ind_colors[atom_ind_per_coeff]
atom_ind_handles = [
Patch(facecolor=color, label=ELEMENTS[atomic_number])
for color, atomic_number in zip(atom_ind_colors, basis_info.atomic_numbers)
]
atom_ind_legend_kwargs = dict(
handles=atom_ind_handles,
title="Element",
loc="lower right",
fontsize="small",
)
# prepare to color by angular momentum
basis_func_to_l = basis_info.l_per_basis_func[batch.basis_function_ind.cpu().numpy()]
max_l = np.max(basis_func_to_l)
orbital_labels = np.array(["s", "p", "d", "f", "g", "h"])
l_colors = np.array(
[
"tab:blue",
"tab:orange",
"tab:green",
"tab:red",
"tab:purple",
"tab:brown",
]
)
l_colors_per_coeff = l_colors[basis_func_to_l]
l_colors_handles = [
Patch(facecolor=color, label=label)
for color, label in zip(l_colors[: max_l + 1], orbital_labels)
]
l_color_legend_kwargs = dict(
handles=l_colors_handles,
title="Angular Momentum",
loc="lower right",
fontsize="small",
)
if self.with_atom_ref == "auto":
with_atom_ref = hasattr(pl_module.net, "atom_ref_module")
else:
with_atom_ref = self.with_atom_ref
if with_atom_ref:
try:
atom_ref = pl_module.net.atom_ref_module
except Exception:
raise ValueError(
"Could not get the atom_ref_module, but with_atom_ref is set to True."
)
with torch.enable_grad():
# batch.coeffs.requires_grad_()
pred_energies_atom_ref = atom_ref.sample_forward(batch)
pred_gradient_atom_ref = (
torch.autograd.grad(pred_energies_atom_ref.sum(), batch.coeffs)[0]
.detach()
.cpu()
.numpy()
)
pred_energies_atom_ref = pred_energies_atom_ref.detach().cpu().numpy()
else:
pred_energies_atom_ref = None
pred_gradient_atom_ref = None
n_rows = 3
n_cols = 4 if with_atom_ref else 3
plot_size = 6
fig, axs = plt.subplots(n_rows, n_cols, figsize=(n_cols * plot_size, n_rows * plot_size))
ax = axs[0, 0]
comparison_scatter(
ax,
energy_labels[mask],
pred_energies[mask],
c=scf_iteration_colors[np.clip(scf_iteration, 0, len(scf_iteration_colors) - 1)][mask],
s=12,
)
ax.set_title("Target vs. Predicted Energies")
ax.set_xlabel("Target Energy")
ax.set_ylabel("Predicted Energy")
ax.legend(**scf_iteration_legend_kwargs)
if with_atom_ref:
ax = axs[1, 0]
comparison_scatter(
ax,
(energy_labels - pred_energies_atom_ref)[mask],
(pred_energies - pred_energies_atom_ref)[mask],
c=scf_iteration_colors[np.clip(scf_iteration, 0, len(scf_iteration_colors) - 1)][
mask
],
s=12,
)
ax.set_title("Target vs. Predicted Energies minus AtomRef")
ax.set_xlabel("Target Energy minus AtomRef")
ax.set_ylabel("Predicted Energy minus AtomRef")
ax.legend(**scf_iteration_legend_kwargs)
mask_per_coeff = mask[batch.coeffs_batch.cpu().numpy()]
for row, c, legend_kwargs in zip(
[0, 1, 2],
[
scf_iteration_colors_per_coeff,
l_colors_per_coeff,
atom_ind_colors_per_coeff,
],
[
scf_iteration_legend_kwargs,
l_color_legend_kwargs,
atom_ind_legend_kwargs,
],
):
ax = axs[row, 1]
comparison_scatter(
ax,
gradient_label[mask_per_coeff],
(gradient_label + projected_gradient_difference)[mask_per_coeff],
c=c[mask_per_coeff],
)
ax.set_title("Target Gradients vs.\nTarget Gradients + Projected gradient difference")
ax.set_xlabel("Target Gradient")
ax.set_ylabel("Predicted Gradient")
ax.legend(**legend_kwargs)
if with_atom_ref:
ax = axs[row, 2]
comparison_scatter(
ax,
(gradient_label - pred_gradient_atom_ref)[mask_per_coeff],
(gradient_label + projected_gradient_difference - pred_gradient_atom_ref)[
mask_per_coeff
],
c=c[mask_per_coeff],
)
ax.set_title(
"Target Gradients minus AtomRef vs.\n"
"Target Gradients + Projected gradient difference - AtomRef"
)
ax.set_xlabel("Target Gradient minus AtomRef")
ax.set_ylabel("Predicted Gradient minus AtomRef")
ax.legend(**legend_kwargs)
ax = axs[row, 3 if with_atom_ref else 2]
comparison_scatter(
ax,
gt_diff,
pred_diff,
c=c,
)
ax.set_title("Target vs. Predicted Initial Guess Delta")
ax.set_xlabel("Target Initial Guess Delta")
ax.set_ylabel("Predicted Initial Guess Delta")
ax.legend(**legend_kwargs)
fig.tight_layout()
# hide for now
axs[2, 0].set_visible(False)
if not with_atom_ref:
axs[1, 0].set_visible(False)
return fig
[docs]
class LogGradientScatter(LogMatplotlibToTensorboard):
"""Logs a scatter plot of the target and predicted gradients per basis function."""
[docs]
def get_figure(
self,
pl_module: MLDFTLitModule,
batch: Any,
outputs: STEP_OUTPUT,
basis_info: BasisInfo,
) -> Figure:
"""Create the figure to be plotted: A scatter plot of the target and predicted gradients per
basis function, as well as their projected difference.
Args:
pl_module: The lightning module.
batch: The batch.
outputs: The outputs of the lightning module.
basis_info: The :class:`BasisInfo` object.
Returns:
Figure: The figure to be plotted.
"""
gradient_label = batch.gradient_label.detach().cpu().numpy()
pred_gradients = outputs["model_outputs"]["pred_gradients"].detach().cpu().numpy()
projected_gradient_difference = (
outputs["projected_gradient_difference"].detach().cpu().numpy()
)
fig, ax = plt.subplots(1, figsize=(15, 5))
scatter_kwargs = dict(marker=".", alpha=0.1, s=1)
x = batch.basis_function_ind.cpu().numpy() + 0.5
ax.scatter(x, gradient_label, **scatter_kwargs, label="True Gradient")
ax.scatter(x, pred_gradients, **scatter_kwargs, label="Predicted Gradient")
ax.scatter(
x,
projected_gradient_difference,
**scatter_kwargs,
label="Projected Difference",
)
handles = [
Patch(color="C0", label="True Gradient"),
Patch(color="C1", label="Predicted Gradient"),
Patch(color="C2", label="Projected Difference"),
]
ax.legend(handles=handles, loc="upper right")
format_basis_func_xaxis(ax, basis_info)
fig.tight_layout()
return fig
[docs]
class LogDistanceEmbeddings(LogMatplotlibToTensorboard):
"""Logs a line plot of the distance embeddings for a range of distances."""
[docs]
def __init__(self, max_distance: float = 5.0, n_distances: int = 1000, **super_kwargs):
"""Plot the distance embeddings for a range of distances.
Args:
max_distance (float): The maximum distance to consider.
n_distances (int): The number of distances to consider.
Returns:
plt.Figure: The plot.
"""
super().__init__(**super_kwargs)
self.max_distance = max_distance
self.n_distances = n_distances
[docs]
def get_figure(
self,
pl_module: MLDFTLitModule,
batch: Any,
outputs: STEP_OUTPUT,
basis_info: BasisInfo,
) -> Figure:
"""Create the figure to be plotted: A line plot of the distance embeddings"""
if hasattr(pl_module.net, "plot_distance_embeddings"):
return pl_module.net.plot_distance_embeddings()
else:
fig, ax = plt.subplots(1, 1, figsize=(5, 5))
ax.set_title("Module has no method plot_distance_embeddings.")
return fig