# Source code for mldft.utils.visualize_3d

from functools import partial

import numpy as np
import pyscf
import pyvista as pv
from matplotlib.pyplot import rcParams
from pyscf import dft

from mldft.utils.conversions import pyscf_to_rdkit
from mldft.utils.cube_files import DataCube


[docs] def find_isosurface_value( cube_array: np.ndarray, quantile: float | np.ndarray = 0.9, p: int = 2 ) -> float | np.ndarray: """Find an isosurface value for a cube array, such that the isosurface contains a given fraction of the total mass. The mass is computed as the sum of the absolute values of the cube array raised to the power p. Args: cube_array: The cube array. quantile: The fraction (or array of fractions) of the total mass to be contained in the isosurface. p: The power to raise the cube array to. Use p=1 for electron density and p=2 for orbitals. Returns: The isosurface value. """ cube_array = np.abs(cube_array).flatten() mass = cube_array**p total_mass = mass.sum() mass_sorted = np.sort(mass)[::-1] ind = np.searchsorted(np.cumsum(mass_sorted), total_mass * quantile) isosurface_value = mass_sorted[ind] ** (1 / p) return isosurface_value
[docs] def _fix_cube_data(cube_data: str) -> str: """Fix a cube file string produced by pyscf. Undoes a pyscf bug by adding spaces where needed: e.g. replace incorrect lines like "-2.52899E-114-1.81539E-116-9.75640E-119-3.92559E-121-1.18255E-123-2.66702E-126" with "-2.52899E-114 -1.81539E-116 -9.75640E-119 -3.92559E-121 -1.18255E-123 -2.66702E-126". Args: cube_data: The cube data to fix. Returns: The fixed cube data. """ cube_data = cube_data.replace("-", " -") cube_data = cube_data.replace("E -", "E-") return cube_data
def _eval_orbital(coords: np.ndarray, mol: pyscf.gto.Mole, coeff: np.ndarray) -> np.ndarray:
    """Evaluate a molecular orbital on a set of points.

    Args:
        coords: Points at which to evaluate the orbital, shape (n, 3).
        mol: The molecule whose basis functions are evaluated.
        coeff: Coefficients of the orbital, contracted with the last axis of the basis
            function values returned by ``mol.eval_gto("GTOval", coords)``.

    Returns:
        The orbital values at ``coords``, shape (n).
    """
    # Basis-function values at every point, contracted with the orbital coefficients.
    return np.einsum("...i,i", mol.eval_gto("GTOval", coords), coeff)
def visualize_orbital(
    mol: pyscf.gto.Mole,
    coeff: np.ndarray,
    resolution: float = 0.3,
    margin: float = 3.0,
    **plot_orbital_kwargs,
):
    """Sample an orbital on a grid around the molecule and plot it with `plot_orbital`.

    Args:
        mol: The molecule.
        coeff: The orbital coefficients.
        resolution: The grid spacing of the sampled cube.
        margin: The margin of the sampled cube around the molecule.
        plot_orbital_kwargs: Additional keyword arguments forwarded to `plot_orbital`.

    Returns:
        Whatever `plot_orbital` returns: the plotter if one was passed in
        ``plot_orbital_kwargs``, otherwise the result of showing a newly created one.
    """
    orbital_func = partial(_eval_orbital, coeff=coeff, mol=mol)
    cube = DataCube.from_function(
        mol=mol,
        func=orbital_func,
        resolution=resolution,
        margin=margin,
    )
    return plot_orbital(cube, **plot_orbital_kwargs)
def _eval_density(coords: np.ndarray, mol: pyscf.gto.Mole, dm: np.ndarray) -> np.ndarray:
    """Compute the electron density on a set of points from a density matrix.

    Args:
        coords: Points at which to evaluate the density, shape (n, 3).
        mol: The molecule whose basis functions are evaluated.
        dm: The density matrix.

    Returns:
        The electron density at ``coords``, shape (n).
    """
    basis_values = mol.eval_gto("GTOval", coords)
    return dft.numint.eval_rho(mol, basis_values, dm)
def visualize_density(
    mol: pyscf.gto.Mole,
    density_matrix: np.ndarray,
    resolution: float = 0.3,
    margin: float = 3.0,
    **plot_density_kwargs,
):
    """Sample the electron density on a grid around the molecule and plot it with `plot_density`.

    Args:
        mol: The molecule.
        density_matrix: The density matrix in the atomic orbital basis.
        resolution: The grid spacing of the sampled cube.
        margin: The margin of the sampled cube around the molecule.
        plot_density_kwargs: Additional keyword arguments forwarded to `plot_density`.

    Returns:
        Whatever `plot_density` returns: the plotter if one was passed in
        ``plot_density_kwargs``, otherwise the result of showing a newly created one.
    """
    density_func = partial(_eval_density, dm=density_matrix, mol=mol)
    cube = DataCube.from_function(
        mol=mol,
        func=density_func,
        resolution=resolution,
        margin=margin,
    )
    return plot_density(cube, **plot_density_kwargs)
# Hex color per element symbol for the sticks representation. "G" is the fallback used by
# get_sticks_mesh_dict for symbols without an entry of their own.
ATOM_COLORS = {
    "C": "#c8c8c8",
    "H": "#ffffff",
    "N": "#8f8fff",
    "S": "#ffc832",
    "O": "#f00000",
    "F": "#ffff00",
    "P": "#ffa500",
    "K": "#42f4ee",
    "G": "#3f3f3f",
}
# Position of each symbol in ATOM_COLORS; used as the scalar that selects its colormap entry.
ATOM_IDS = {symbol: idx for idx, symbol in enumerate(ATOM_COLORS)}
def get_sticks_mesh_dict(
    mol: pyscf.gto.Mole,
    bond_radius: float = 0.2,
    atom_radius: float = None,
    resolution=20,
) -> dict:
    """Build a 'sticks' representation of a pyscf molecule as a single colored mesh.

    Every atom becomes a sphere. Every bond becomes two half-cylinders; each half starts
    at one of the two atoms and carries that atom's color (see ``ATOM_COLORS``). Choosing
    a larger ``atom_radius`` than ``bond_radius`` gives a balls-and-sticks look.

    Args:
        mol: The molecule.
        bond_radius: The radius of the cylinders representing the bonds.
        atom_radius: The radius of the spheres representing the atoms. If None, use
            ``bond_radius``.
        resolution: The resolution of the cylinders and spheres.

    Returns:
        A dictionary with keyword arguments to pass to ``pyvista.Plotter.add_mesh``.
    """
    rdkit_mol = pyscf_to_rdkit(mol)
    # rdkit reports positions in Angstrom; 1.8897259886 converts them to Bohr.
    positions = rdkit_mol.GetConformer().GetPositions() * 1.8897259886
    if atom_radius is None:
        atom_radius = bond_radius

    pieces = []
    for i, atom in enumerate(rdkit_mol.GetAtoms()):
        pos_i = positions[i]
        color_id = ATOM_IDS.get(atom.GetSymbol(), ATOM_IDS["G"])

        sphere = pv.Sphere(
            center=pos_i,
            radius=atom_radius,
            phi_resolution=resolution,
            theta_resolution=resolution,
        )
        sphere["color_ids"] = np.ones(sphere.n_cells) * color_id
        pieces.append(sphere)

        for bond in atom.GetBonds():
            begin, end = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
            j = begin if end == i else end  # index of the bonded partner
            pos_j = positions[j]
            # The half of the bond nearest atom i: centered one quarter of the way from
            # i to j, with half the interatomic distance as height.
            half_bond = pv.Cylinder(
                center=(pos_i * 3 + pos_j) / 4,
                direction=pos_j - pos_i,
                radius=bond_radius,
                height=np.linalg.norm(pos_i - pos_j) / 2,
                capping=False,
                resolution=resolution,
            )
            half_bond["color_ids"] = np.ones(half_bond.n_cells) * color_id
            pieces.append(half_bond)

    merged_mesh = pv.MultiBlock(pieces).combine().extract_surface()
    return {
        "mesh": merged_mesh,
        "smooth_shading": True,
        "diffuse": 0.5,
        "specular": 0.5,
        "ambient": 0.5,
        "clim": (0, len(ATOM_COLORS)),
        "cmap": list(ATOM_COLORS.values()),
        "show_scalar_bar": False,
    }
# Colors of the first, second and third basis vector in the local-frame visualization
# (red, green, blue).
AXES_COLORS = [
    "#ff0000",
    "#00ff00",
    "#0000ff",
]
def get_local_frames_mesh_dict(
    origins: np.ndarray,
    bases: np.ndarray,
    scale: float = 1,
    axes_radius_scale: float = 0.03,
    cone_scale: float = 2.0,
    cone_aspect_ratio: float = 2,
    resolution: int = 20,
) -> dict:
    """Create a mesh dict for a set of local frames, given by their origin and basis vectors.

    Every basis vector is drawn as an arrow: a cylindrical shaft plus a conical tip. The
    arrow starts at the frame's origin and has total length ``scale * |basis_vector|``.
    The k-th basis vector of every frame is colored with ``AXES_COLORS[k]``. Zero-length
    basis vectors have no direction and are skipped.

    Args:
        origins: The origins of the local frames, shape (n, 3).
        bases: The basis vectors of the local frames, shape (n, k, 3) with 1 <= k <= 3
            vectors per frame.
        scale: Factor applied to the arrow lengths; the arrow thickness scales with it too.
        axes_radius_scale: Radius of the arrow shafts, relative to ``scale``.
        cone_scale: Radius of the arrow tips, relative to the shaft radius.
        cone_aspect_ratio: Height of the arrow tips, relative to their diameter.
        resolution: The resolution of the cylinders and cones.

    Returns:
        A dictionary with keyword arguments to pass to :meth:`pyvista.Plotter.add_mesh`.
    """
    origins = np.asarray(origins)
    bases = np.asarray(bases)
    assert (
        bases.ndim == 3 and bases.shape[2] == 3 and 0 < bases.shape[1] <= 3
    ), f"Invalid shape of bases: {bases.shape}. Must be (n, 1|2|3, 3)."
    assert (
        origins.ndim == 2 and origins.shape[1] == 3
    ), f"Invalid shape of origins: {origins.shape}. Must be (n, 3)."
    assert origins.shape[0] == bases.shape[0], (
        f"Incompatible shapes of origins and bases: {origins.shape} and {bases.shape}. "
        f"Must be (n, 3) and (n, 1|2|3, 3)."
    )

    axes_radius = scale * axes_radius_scale
    cone_radius = axes_radius * cone_scale
    cone_height = cone_radius * cone_aspect_ratio * 2

    mesh_elements = []
    for origin, frame in zip(origins, bases):
        for color_id, basis_vector in enumerate(frame):
            basis_vector_length = np.linalg.norm(basis_vector)
            if basis_vector_length == 0:
                continue  # no direction to draw an arrow in
            direction = basis_vector / basis_vector_length
            arrow_length = scale * basis_vector_length
            # The tip keeps its full height unless the whole arrow is shorter than that;
            # the shaft fills the remaining length.
            this_cone_height = min(cone_height, arrow_length)
            cylinder_height = arrow_length - this_cone_height

            if cylinder_height > 0:
                dest = origin + direction * cylinder_height
                cylinder = pv.Cylinder(
                    center=(origin + dest) / 2,  # midpoint of the shaft
                    direction=direction,
                    radius=axes_radius,
                    height=cylinder_height,
                    resolution=resolution,
                )
                cylinder["color_ids"] = np.ones(cylinder.n_points) * color_id
                mesh_elements.append(cylinder)

            # cone at the tip of the arrow, directly after the shaft
            cone = pv.Cone(
                center=origin + direction * (cylinder_height + this_cone_height / 2),
                direction=direction,
                height=this_cone_height,
                radius=cone_radius,
                resolution=resolution,
            )
            cone = cone.extract_geometry()
            cone["color_ids"] = np.ones(cone.n_points) * color_id
            mesh_elements.append(cone)

    merged_mesh = pv.MultiBlock(mesh_elements).combine().extract_surface()
    add_mesh_kwargs = dict(
        mesh=merged_mesh,
        smooth_shading=True,
        diffuse=0.5,
        specular=0.5,
        ambient=0.5,
        clim=(0, len(AXES_COLORS) - 1),
        cmap=AXES_COLORS,
        show_scalar_bar=False,
    )
    return add_mesh_kwargs
def plot_molecule(
    mol: pyscf.gto.Mole,
    plotter: pv.Plotter = None,
    figsize: tuple[int, int] = None,
    title: str = None,
) -> pv.Plotter:
    """Render a molecule in its sticks representation with pyvista.

    Args:
        mol: The pyscf molecule.
        plotter: A pyvista plotter to draw into. If None, a new plotter is created.
        figsize: The figure size in inches (converted to pixels via matplotlib's dpi).
        title: The title of the plot.

    Returns:
        The given ``plotter``, so that e.g. the image can be saved next. If no plotter
        was given, the new plotter is shown and the return value of its ``show()`` is
        returned.
    """
    owns_plotter = plotter is None
    if owns_plotter:
        plotter_kwargs = dict(image_scale=2)
        if figsize is not None:
            dpi = int(rcParams["figure.dpi"])
            plotter_kwargs["window_size"] = [size * dpi for size in figsize]
        pl = pv.Plotter(**plotter_kwargs)
    else:
        pl = plotter

    pl.add_mesh(**get_sticks_mesh_dict(mol))
    if title is not None:
        pl.add_title(title)
    pl.camera_position = "iso"

    return pl.show() if owns_plotter else pl
def plot_orbital(
    cube: DataCube,
    mode: str = "auto",
    plot_molecule: bool = True,
    isosurface_quantile: float = None,
    plotter: pv.Plotter = None,
    figsize: tuple[int, int] = None,
    title: str = None,
) -> pv.Plotter:
    """Plot an electron orbital using pyvista.

    With ``mode="auto"`` (the default), the orbital is rendered as a volume if no
    ``isosurface_quantile`` is given, and as a single isosurface otherwise. Negative and
    positive values of the orbital are drawn in different colors.

    Args:
        cube: The `DataCube` holding the orbital values.
        mode: The mode to use for plotting. Must be one of 'auto', 'volume', 'isosurface',
            'nested_isosurfaces'.
        plot_molecule: Whether to also plot the molecule (sticks representation).
        isosurface_quantile: For mode 'isosurface', the fraction of the total mass (sum of
            squared orbital values) to be contained in the isosurface. Defaults to 0.9.
            Must not be given for mode 'nested_isosurfaces'.
        plotter: A pyvista plotter to use. If None, a new plotter is created.
        figsize: The figure size in inches.
        title: The title of the plot.

    Returns:
        The given ``plotter``, so that e.g. the image can be saved next. If no plotter was
        given, the new plotter is shown and the return value of its ``show()`` is returned.
    """
    if mode == "auto":
        mode = "volume" if isosurface_quantile is None else "isosurface"
    assert mode in [
        "volume",
        "isosurface",
        "nested_isosurfaces",
    ], f"Invalid mode: {mode}. Must be one of 'volume', 'isosurface', 'nested_isosurfaces'."

    if plotter is not None:
        pl = plotter
    else:
        plotter_kwargs = dict(image_scale=2)
        if figsize is not None:
            dpi = int(rcParams["figure.dpi"])
            plotter_kwargs["window_size"] = [s * dpi for s in figsize]
        pl = pv.Plotter(**plotter_kwargs)

    orbital = cube.to_pyvista()

    if mode == "volume":
        values = orbital["data"]
        neg_mask = values < 0
        rgba = np.zeros((orbital.n_points, 4), np.uint8)
        rgba[neg_mask, 0] = 170  # negative values: red channel
        rgba[~neg_mask, 2] = 170  # non-negative values: blue channel
        # Normalize the opacity such that the voxels with the largest |value|, which
        # together hold 25% of the total absolute mass, are fully opaque.
        opac = np.abs(values)
        opac_reference = find_isosurface_value(opac, 0.25, p=1)
        if opac_reference > 0:  # an all-zero orbital would otherwise give 0 / 0 = NaN
            opac /= opac_reference
        opac = np.clip(opac, 0, 1)
        rgba[:, -1] = opac * 255
        orbital["plot_scalars"] = rgba
        vol = pl.add_volume(
            orbital,
            opacity_unit_distance=10,
            scalars="plot_scalars",
        )
        vol.prop.interpolation_type = "linear"

    if plot_molecule:
        pl.add_mesh(**get_sticks_mesh_dict(cube.mol))

    if mode in ["isosurface", "nested_isosurfaces"]:
        # Accept any real number, not only Python floats (e.g. ints or numpy floats),
        # instead of silently replacing it by the default.
        if isinstance(isosurface_quantile, (int, float, np.integer, np.floating)):
            assert (
                mode == "isosurface"
            ), f'Invalid mode: {mode}. Must be "isosurface" for a single isosurface value'
            isosurface_quantile = float(isosurface_quantile)
            quantiles = np.array([isosurface_quantile])
        elif mode == "nested_isosurfaces":
            quantiles = np.linspace(0.1, 0.99, 10)
        else:
            isosurface_quantile = 0.9
            quantiles = np.array([isosurface_quantile])

        isosurface_values = find_isosurface_value(orbital["data"], quantiles, p=2)
        # Mirror every level for the negative lobes; the sign of the stored quantile
        # records which lobe a surface belongs to.
        isosurface_values = np.concatenate([-isosurface_values, isosurface_values])
        quantiles = np.concatenate([-quantiles, quantiles])

        iso_mesh = orbital.contour(isosurfaces=isosurface_values).extract_geometry()

        # Tag every point of the contour mesh with the (signed) quantile and opacity of
        # the isosurface it lies on. A small relative tolerance instead of exact equality
        # keeps rounding of the contour values (e.g. float32 data) from leaving points
        # untagged.
        point_values = np.asarray(iso_mesh["data"])
        point_quantiles = np.zeros(iso_mesh.n_points)
        point_opacity = np.zeros(iso_mesh.n_points)
        for quantile, isosurface_value in zip(quantiles, isosurface_values):
            mask = np.isclose(point_values, isosurface_value, rtol=1e-6, atol=0)
            point_quantiles[mask] = quantile
            point_opacity[mask] = 1 - np.abs(quantile)
        if mode == "nested_isosurfaces":
            # Map to (-1, 1), with the outer (high-quantile) surfaces close to 0;
            # squaring looks good with the 'seismic' colormap.
            point_quantiles = -np.sign(point_quantiles) * (1 - np.abs(point_quantiles)) ** 2
        iso_mesh["quantile"] = point_quantiles
        iso_mesh["opacity"] = point_opacity

        if mode == "nested_isosurfaces":
            # a set of nested, transparent isosurfaces
            pl.add_mesh(
                iso_mesh,
                scalars="quantile",
                clim=(-1, 1),
                cmap="seismic",
                opacity="opacity",
                smooth_shading=True,
                ambient=0.5,
                diffuse=0.5,
                show_scalar_bar=False,
            )
        else:
            # a single, opaque and shiny isosurface
            pl.add_mesh(
                iso_mesh,
                opacity=1,
                scalars="quantile",
                clim=(-isosurface_quantile, isosurface_quantile),
                cmap="bwr_r",
                smooth_shading=True,
                ambient=0.3,
                diffuse=0.7,
                specular=0.5,
                show_scalar_bar=False,
            )

    if title is not None:
        pl.add_title(title)
    pl.camera_position = "iso"
    if plotter is None:
        # no plotter was passed: show the plot
        return pl.show()
    # otherwise, return the plotter, e.g. to save the image next
    return pl
def plot_density(
    cube: DataCube = None,
    mode: str = "auto",
    plot_molecule: bool = True,
    isosurface_quantile: float = None,
    isosurface_opacity: float = 0.4,
    plotter: pv.Plotter = None,
    figsize: tuple[int, int] = None,
    title: str = None,
    cmap: str = None,
):
    """Plot an electron density using pyvista.

    Volume rendering is not implemented for densities. Therefore ``mode="auto"`` (the
    default) draws a single isosurface, using a quantile of 0.9 if none is given.

    Args:
        cube: The `DataCube` holding the density values. Required.
        mode: The mode to use for plotting. Must be one of 'auto', 'volume' (not
            implemented), 'isosurface', 'nested_isosurfaces'.
        plot_molecule: Whether to also plot the molecule (sticks representation).
        isosurface_quantile: For mode 'isosurface', the fraction of the total mass (sum of
            density values) to be contained in the isosurface. Defaults to 0.9. Must not
            be given for mode 'nested_isosurfaces'.
        isosurface_opacity: For mode 'isosurface', the opacity of the isosurface.
        plotter: A pyvista plotter to use. If None, a new plotter is created.
        figsize: The figure size in inches.
        title: The title of the plot.
        cmap: The colormap to use for mode 'nested_isosurfaces'.

    Returns:
        The given ``plotter``, so that e.g. the image can be saved next. If no plotter was
        given, the new plotter is shown and the return value of its ``show()`` is returned.

    Raises:
        ValueError: If no ``cube`` is given.
        NotImplementedError: If ``mode == "volume"``.
    """
    if cube is None:
        raise ValueError("plot_density requires a DataCube holding the density.")
    if mode == "auto":
        # Volume rendering is not implemented for densities, so "auto" always selects an
        # isosurface (with the default quantile if none is given).
        mode = "isosurface"
    assert mode in [
        "volume",
        "isosurface",
        "nested_isosurfaces",
    ], f"Invalid mode: {mode}. Must be one of 'volume', 'isosurface', 'nested_isosurfaces'."
    # Fail before a plotter is created.
    if mode == "volume":
        raise NotImplementedError(
            "Volume rendering of densities is not implemented. "
            "Use mode='isosurface' or mode='nested_isosurfaces'."
        )

    if plotter is not None:
        pl = plotter
    else:
        plotter_kwargs = dict(image_scale=2)
        if figsize is not None:
            dpi = int(rcParams["figure.dpi"])
            plotter_kwargs["window_size"] = [s * dpi for s in figsize]
        pl = pv.Plotter(**plotter_kwargs)

    density = cube.to_pyvista()

    if plot_molecule:
        pl.add_mesh(**get_sticks_mesh_dict(cube.mol))

    # Accept any real number, not only Python floats, instead of silently using the default.
    if isinstance(isosurface_quantile, (int, float, np.integer, np.floating)):
        assert (
            mode == "isosurface"
        ), f'Invalid mode: {mode}. Must be "isosurface" for a single isosurface value'
        quantiles = np.array([float(isosurface_quantile)])
    elif mode == "nested_isosurfaces":
        n_surfaces = 20
        # n_surfaces evenly spaced quantiles, excluding the end points 0 and 1
        quantiles = np.linspace(0, 1, n_surfaces + 2)[1:-1]
    else:
        quantiles = np.array([0.9])

    isosurface_values = find_isosurface_value(density["data"], quantiles, p=1)
    iso_mesh = density.contour(isosurfaces=isosurface_values).extract_geometry()
    # NOTE(review): these point arrays are not referenced by the add_mesh calls below
    # (constant opacity / color are passed); they are kept so the mesh carries the same
    # arrays as before.
    iso_mesh["quantile"] = np.zeros(iso_mesh.n_points)
    iso_mesh["opacity"] = np.ones(iso_mesh.n_points) * 0.05

    if mode == "nested_isosurfaces":
        # a set of nested, transparent isosurfaces
        pl.add_mesh(
            iso_mesh,
            cmap=cmap,
            opacity=0.04,
            smooth_shading=True,
            ambient=0.5,
            diffuse=0.5,
            show_scalar_bar=False,
        )
    else:
        # a single, transparent and shiny isosurface
        pl.add_mesh(
            iso_mesh,
            opacity=isosurface_opacity,
            color="white",
            smooth_shading=True,
            ambient=0.4,
            diffuse=0.7,
            specular=0.3,
            show_scalar_bar=False,
        )

    if title is not None:
        pl.add_title(title)
    pl.camera_position = "iso"
    if plotter is None:
        # no plotter was passed: show the plot
        return pl.show()
    # otherwise, return the plotter, e.g. to save the image next
    return pl