Source code for mldft.ml.callbacks.mesh_logging

import logging
from functools import partial
from typing import Any, Callable

import lightning as pl
import numpy as np
import torch
from lightning.pytorch.utilities.types import STEP_OUTPUT
from pyscf import dft

from mldft.ml.callbacks.base import OnStepCallbackWithTiming
from mldft.ml.data.components.basis_info import BasisInfo
from mldft.ml.data.components.of_data import Representation
from mldft.ml.models.mldft_module import MLDFTLitModule
from mldft.utils.cube_files import DataCube
from mldft.utils.molecules import build_molecule_ofdata
from mldft.utils.visualize_3d import (
    ATOM_COLORS,
    find_isosurface_value,
    get_sticks_mesh_dict,
)

logger = logging.getLogger(__name__)


def hex_to_rgb(hex_color: str) -> tuple:
    """Converts a hex color string to an rgb tuple.

    Args:
        hex_color: The hex color string, e.g. ``#FF0000``.

    Returns:
        tuple: The rgb tuple, e.g. ``(255, 0, 0)``.
    """
    hex_color = hex_color.lstrip("#")
    h_len = len(hex_color)
    return tuple(int(hex_color[i : i + h_len // 3], 16) for i in range(0, h_len, h_len // 3))
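
# Illustrative usage sketch (not part of the original module): hex_to_rgb splits the six
# hex digits into three pairs and parses each pair as an integer, so a standard six-digit
# color string maps to the familiar 0-255 RGB triple.
#
#     >>> hex_to_rgb("#FF0000")
#     (255, 0, 0)
#     >>> hex_to_rgb("#00FF80")
#     (0, 255, 128)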


def combine_mesh_dicts(mesh_dicts: list[dict]) -> dict:
    """Combines multiple mesh dicts into one.

    Args:
        mesh_dicts: A list of mesh dicts. Each dict must have exactly the keys
            ``vertices``, ``faces``, ``colors``.

    Returns:
        dict: The combined mesh dict.
    """
    for mesh_dict in mesh_dicts:
        assert mesh_dict.keys() == {
            "vertices",
            "faces",
            "colors",
        }, f'Keys of mesh dicts must be "vertices", "faces", "colors". Got {mesh_dict.keys()}'
    vertex_ind_offsets = np.cumsum(
        [0] + [mesh_dict["vertices"].shape[1] for mesh_dict in mesh_dicts]
    )
    vertices = np.concatenate([mesh_dict["vertices"] for mesh_dict in mesh_dicts], axis=1)
    faces = np.concatenate(
        [
            vertex_ind_offset + mesh_dict["faces"]
            for mesh_dict, vertex_ind_offset in zip(mesh_dicts, vertex_ind_offsets)
        ],
        axis=1,
    )
    colors = np.concatenate([mesh_dict["colors"] for mesh_dict in mesh_dicts], axis=1)
    return dict(vertices=vertices, faces=faces, colors=colors)
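
# Illustrative sketch (not part of the original module): combining two single-triangle
# meshes. The leading axis of size 1 is the batch axis expected by
# ``SummaryWriter.add_mesh``; the face indices of the second mesh are offset by the number
# of vertices of the first, so both index into the concatenated vertex array.
#
#     tri_a = dict(
#         vertices=np.zeros((1, 3, 3)),
#         faces=np.array([[[0, 1, 2]]]),
#         colors=np.full((1, 3, 3), 255, dtype=np.int32),
#     )
#     tri_b = dict(
#         vertices=np.ones((1, 3, 3)),
#         faces=np.array([[[0, 1, 2]]]),
#         colors=np.zeros((1, 3, 3), dtype=np.int32),
#     )
#     combined = combine_mesh_dicts([tri_a, tri_b])
#     assert combined["vertices"].shape == (1, 6, 3)
#     assert (combined["faces"][0, 1] == [3, 4, 5]).all()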


class LogMolecule(OnStepCallbackWithTiming):
    """Logs the molecule to tensorboard as a mesh."""

    def __init__(
        self,
        max_molecule_resolution: int = 10,
        max_isosurface_resolution: float = 0.2,
        basis_transformation: str | Callable | None = "auto",
        log_initial_guess: bool = True,
        log_gradient: bool = True,
        log_random_basis_functions: bool = False,
        **super_kwargs,
    ) -> None:
        """
        Args:
            max_molecule_resolution: The maximum "resolution" of the sticks representation
                of the molecule.
            max_isosurface_resolution: The maximum grid resolution used to compute the
                isosurface, in Bohr.
            basis_transformation: The basis transformation to be applied to the molecule.
                Can be either ``None``, 'auto', or a callable, e.g. a value of
                :const:`BASIS_TRANSFORMATIONS`. If 'auto', the basis transformation is
                chosen based on the label_subdir of the datamodule. ``None`` means no
                transformation is applied. Defaults to 'auto'.
            log_initial_guess: Whether to log the initial guess. Defaults to True.
            log_gradient: Whether to log the gradient. Defaults to True.
            log_random_basis_functions: Whether to log some random basis functions.
                Defaults to False. See :meth:`LogMolecule.get_mesh_kwargs` for details.
            **super_kwargs: Keyword arguments passed to :class:`OnStepCallbackWithTiming`.
        """
        super().__init__(**super_kwargs)
        self.max_molecule_resolution = max_molecule_resolution
        self.max_isosurface_resolution = max_isosurface_resolution
        self.basis_transformation = basis_transformation
        self.log_initial_guess = log_initial_guess
        self.log_gradient = log_gradient
        self.log_random_basis_functions = log_random_basis_functions
        # Camera and scene configuration.
        # for documentation, see https://threejs.org/docs
        self.config_dict = {
            # use a relatively small fov in order to be able to tell when things are
            # parallel / perpendicular
            "camera": {"cls": "PerspectiveCamera", "fov": 10},
            "lights": [
                {
                    "cls": "AmbientLight",
                    "color": "#ffffff",
                    "intensity": 0.75,
                },
                {
                    "cls": "DirectionalLight",
                    "color": "#ffffff",
                    "intensity": 0.75,
                    "position": [0, -1, 2],
                },
                {
                    "cls": "DirectionalLight",
                    "color": "#ffffff",
                    "intensity": 0.5,
                    "position": [0, 1, -2],
                },
            ],
            "material": {
                "cls": "MeshStandardMaterial",
                "roughness": 0.4,
                "metalness": 0.0,
                "transparent": True,
                "opacity": 0.75,
            },
        }
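
    # Illustrative instantiation sketch (not part of the original module): a coarser
    # isosurface grid and random-basis-function logging enabled; any remaining keyword
    # arguments are forwarded unchanged to OnStepCallbackWithTiming.
    #
    #     callback = LogMolecule(
    #         max_isosurface_resolution=0.4,
    #         log_random_basis_functions=True,
    #     )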

    def get_molecule_dict(self, mol: Any) -> dict:
        """Get a sticks representation of the molecule as a mesh dict."""
        # get the sticks mesh, adapting the resolution to get at most 10000 points
        resolution = self.max_molecule_resolution
        sticks_mesh_dict = get_sticks_mesh_dict(mol, resolution=resolution)
        while resolution > 1 and sticks_mesh_dict["mesh"].n_points > 10000:
            resolution //= 2
            sticks_mesh_dict = get_sticks_mesh_dict(mol, resolution=resolution)
        mesh = sticks_mesh_dict["mesh"]
        mesh = mesh.triangulate()
        assert mesh.is_all_triangles, "Need a mesh with all triangles"
        vertices = mesh.points[None]
        # first index is the number of vertices per face, here always 3
        faces = mesh.faces.reshape((-1, 4))[None, :, 1:]
        assert np.allclose(
            mesh["color_ids"], np.round(mesh["color_ids"])
        ), "color_ids must be integers"
        atom_colors = np.array(
            [hex_to_rgb(hex_color) for hex_color in ATOM_COLORS.values()], dtype=np.int32
        )
        face_vertex_colors = atom_colors[None, np.round(mesh["color_ids"]).astype(np.int32)]
        # there is one color per occurrence of a vertex (so multiple if it appears in
        # multiple faces). we need a single color per vertex, so we take an arbitrary one.
        # if a vertex is not in any face, it stays black
        colors = np.zeros((1, vertices.shape[1], 3), dtype=np.int32)
        # iterate over the indices of the corners of (all) the triangles
        for i in range(3):
            colors[:, faces[0, :, i]] = face_vertex_colors
        return dict(
            vertices=vertices.copy(),
            faces=faces.copy(),
            colors=colors.copy(),
        )
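
    # Worked sketch of the pyvista face layout assumed above (not part of the original
    # module): for an all-triangle mesh, ``mesh.faces`` is a flat array of the form
    # ``[3, i0, i1, i2, 3, j0, j1, j2, ...]``, i.e. each face is prefixed by its vertex
    # count. Reshaping to ``(-1, 4)`` and dropping the first column therefore yields one
    # ``(i, j, k)`` index triple per triangle:
    #
    #     flat_faces = np.array([3, 0, 1, 2, 3, 2, 1, 3])
    #     triangles = flat_faces.reshape((-1, 4))[:, 1:]
    #     # triangles == [[0, 1, 2], [2, 1, 3]]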

    def get_mesh_kwargs(
        self,
        trainer: pl.Trainer,
        pl_module: MLDFTLitModule,
        batch: Any,
        outputs: STEP_OUTPUT,
        basis_info: BasisInfo,
    ) -> dict:
        """Create a dictionary of meshes to be logged.

        Args:
            trainer: The lightning trainer.
            pl_module: The lightning module.
            batch: The batch.
            outputs: The outputs of the lightning module.
            basis_info: The :class:`BasisInfo` object.

        Returns:
            dict: A dict of dicts where each value is a dict with arguments to
                ``SummaryWriter.add_mesh``.
        """
        result_dict = {}
        logging_keys = []
        sample = batch.to_data_list()[0].detach().cpu()
        transform = trainer.val_dataloaders.dataset.transforms
        # construct the molecule
        mol = build_molecule_ofdata(ofdata=sample, basis=basis_info.basis_dict)
        try:
            mol_mesh_dict = self.get_molecule_dict(mol)
        except Exception as e:
            # this can happen e.g. if rdkit is unhappy with the molecule
            logger.error(f"Could not build molecule: {e}")
            return dict()
        mol_mesh_dict_with_scf_iter = mol_mesh_dict.copy()
        mol_mesh_dict_with_scf_iter["scalars"] = dict(scf_iteration=sample.scf_iteration.item())

        if self.log_initial_guess:
            # this works only for the first sample!
            pred_coeffs = (
                outputs["model_outputs"]["pred_diff"][: len(sample.coeffs)].detach().cpu()
            )
            gt_coeffs = sample.coeffs - sample.ground_state_coeffs
            coeff_diff = pred_coeffs - gt_coeffs
            result_dict["initial_guess/0_mol"] = mol_mesh_dict_with_scf_iter
            sample.add_item("initial_guess/1_gt", gt_coeffs, Representation.VECTOR)
            sample.add_item("initial_guess/2_pred", pred_coeffs, Representation.VECTOR)
            sample.add_item("initial_guess/3_diff", coeff_diff, Representation.VECTOR)
            logging_keys.extend(
                ["initial_guess/1_gt", "initial_guess/2_pred", "initial_guess/3_diff"]
            )

        if self.log_gradient:
            gradient_label = sample.gradient_label.detach().cpu()
            pred_gradients = (
                outputs["model_outputs"]["pred_gradients"][: len(sample.coeffs)].detach().cpu()
            )
            projected_gradient_difference = (
                outputs["projected_gradient_difference"][: len(sample.coeffs)].detach().cpu()
            )
            sample.add_item("gradient/1_gt", gradient_label, Representation.GRADIENT)
            sample.add_item("gradient/2_pred", pred_gradients, Representation.GRADIENT)
            sample.add_item(
                "gradient/3_diff",
                projected_gradient_difference,
                Representation.GRADIENT,
            )
            logging_keys.extend(["gradient/1_gt", "gradient/2_pred", "gradient/3_diff"])
            result_dict["gradient/0_mol"] = mol_mesh_dict_with_scf_iter

        if self.log_random_basis_functions:
            basis_func_to_l = basis_info.l_per_basis_func[sample.basis_function_ind.cpu().numpy()]
            basis_func_to_m = basis_info.m_per_basis_func[sample.basis_function_ind.cpu().numpy()]
            logging_keys = []
            # plot some random s, p, d, f, and g functions, to check that their orientation
            # is correct
            for i, (l, n) in enumerate(((0, 1), (1, 1), (1, 3), (1, 3), (2, 2), (3, 2), (4, 2))):
                debug_local_frames_coeffs = np.zeros_like(sample.coeffs)
                ind = np.random.choice(
                    np.argwhere(np.bitwise_and(basis_func_to_l == l, basis_func_to_m == 0))[
                        :, 0
                    ],  # only m == 0
                    # np.argwhere(basis_func_to_l == l)[:, 0],  # random m
                    size=n,
                    replace=False,
                )
                debug_local_frames_coeffs[ind] = 1
                debug_local_frames_coeffs_torch = torch.as_tensor(debug_local_frames_coeffs)
                sample.add_item(
                    f"basis_functions/{i}",
                    debug_local_frames_coeffs_torch,
                    Representation.VECTOR,
                )
                logging_keys.append(f"basis_functions/{i}")

        def get_value(coords: np.ndarray, coeffs: np.ndarray) -> np.ndarray:
            """Evaluate the quantity the isosurfaces are computed from on a given array of
            coords (shape (N, 3))."""
            ao = dft.numint.eval_ao(mol, coords, deriv=0)
            return np.dot(ao, coeffs)

        sample = transform.invert_basis_transform(sample)

        for key in logging_keys:
            coeffs = sample[key].float().numpy()
            resolution = self.max_isosurface_resolution
            while True:
                # 1. compute the value on the grid
                cube = DataCube.from_function(
                    mol=mol,
                    func=partial(get_value, coeffs=coeffs),
                    resolution=resolution,
                )
                pyvista_grid = cube.to_pyvista()
                # 2. compute the isosurface
                quantiles = np.array([0.5])  # hardcoded for now
                isosurface_value = find_isosurface_value(pyvista_grid["data"], quantiles, p=1)
                # add negative isosurface values
                isosurface_values = np.concatenate([-isosurface_value, isosurface_value])
                isosurface = pyvista_grid.contour(isosurfaces=isosurface_values)
                iso_mesh = isosurface.extract_geometry()
                if iso_mesh.n_points <= 50000:
                    break
                else:
                    resolution *= 2
            if iso_mesh.n_points > 0:
                iso_mesh_colors = np.empty((iso_mesh.n_points, 3), dtype=np.int32)
                # assign colors based on isosurface value
                positive_mask = iso_mesh["data"] > 0
                iso_mesh_colors[positive_mask] = np.array([255, 0, 0])  # red
                iso_mesh_colors[~positive_mask] = np.array([0, 0, 255])  # blue
                # orient the negative faces correctly (otherwise they are see-through from
                # the outside)
                iso_faces = iso_mesh.faces.reshape((-1, 4))[:, 1:].copy()
                negative_face_mask = iso_mesh["data"][iso_faces[:, 0]] < 0
                iso_faces[negative_face_mask] = iso_faces[negative_face_mask][:, ::-1]
                # 3. add the isosurface to the vertices, faces, colors
                mesh_dict = combine_mesh_dicts(
                    [
                        mol_mesh_dict,
                        dict(
                            vertices=iso_mesh.points[None],
                            faces=iso_faces[None],
                            colors=iso_mesh_colors[None],
                        ),
                    ]
                )
                mesh_dict["scalars"] = dict(isovalue=float(isosurface_value[0]))
            else:
                logger.debug("No isosurface found")
                mesh_dict = mol_mesh_dict
            result_dict[f"_{key}"] = mesh_dict
        return result_dict
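
    # Illustrative sketch (not part of the original module) of the grid evaluation that
    # ``get_value`` performs: ``dft.numint.eval_ao`` returns the basis functions evaluated
    # at each coordinate with shape (n_coords, n_basis), so a dot product with a coefficient
    # vector gives the represented scalar field at those points. The water geometry, the
    # sto-3g basis, and the all-ones coefficients below are placeholder choices, not the
    # basis or coefficients used by the callback.
    #
    #     from pyscf import gto
    #
    #     demo_mol = gto.M(atom="O 0 0 0; H 0 0 0.96; H 0.93 0 -0.24", basis="sto-3g")
    #     demo_coords = np.random.default_rng(0).normal(size=(100, 3))
    #     demo_ao = dft.numint.eval_ao(demo_mol, demo_coords, deriv=0)  # (100, demo_mol.nao)
    #     demo_values = demo_ao @ np.ones(demo_mol.nao)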

    def execute(
        self,
        trainer: pl.Trainer,
        pl_module: MLDFTLitModule,
        outputs: STEP_OUTPUT,
        batch: Any,
        split: str,
    ) -> None:
        """Logs a mesh to tensorboard at the end of a training batch, if the timing matches."""
        logger.debug(f"Logging training mesh to tensorboard from {self.__class__.__name__}")
        basis_info = pl_module.basis_info
        tb_logger = pl_module.tensorboard_logger
        for key, mesh_dict in self.get_mesh_kwargs(
            trainer, pl_module, batch, outputs, basis_info
        ).items():
            if "scalars" in mesh_dict:
                for scalar_name, scalar_value in mesh_dict.pop("scalars").items():
                    tb_logger.add_scalar(
                        # add "y_" for ordering of tags in tensorboard
                        # (below everything but logging mixins, which start with "z_")
                        tag=f"y_mesh_{split}_{self.name}{key}_{scalar_name}",
                        scalar_value=scalar_value,
                        global_step=trainer.global_step,
                    )
            tb_logger.add_mesh(
                f"{split}_{self.name}{key}",
                **mesh_dict,
                config_dict=self.config_dict,
                global_step=trainer.global_step,
            )
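
# Illustrative sketch (not part of the original module) of the TensorBoard call that
# ``execute`` ultimately makes: ``SummaryWriter.add_mesh`` expects batched ``(B, N, 3)``
# vertices/colors and ``(B, M, 3)`` faces, which is the layout produced by
# ``get_molecule_dict`` and ``combine_mesh_dicts``. The log directory and tag below are
# placeholders.
#
#     from torch.utils.tensorboard import SummaryWriter
#
#     writer = SummaryWriter("runs/mesh_demo")
#     writer.add_mesh(
#         "demo_triangle",
#         vertices=torch.tensor([[[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [0.0, 1.0, 0.0]]]),
#         faces=torch.tensor([[[0, 1, 2]]]),
#         colors=torch.tensor([[[255, 0, 0], [0, 255, 0], [0, 0, 255]]]),
#         global_step=0,
#     )
#     writer.close()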