from functools import partial
import numpy as np
import pyscf
import pyvista as pv
from matplotlib.pyplot import rcParams
from pyscf import dft
from mldft.utils.conversions import pyscf_to_rdkit
from mldft.utils.cube_files import DataCube
[docs]
def find_isosurface_value(
    cube_array: np.ndarray, quantile: float | np.ndarray = 0.9, p: int = 2
) -> float | np.ndarray:
    """Find an isosurface value for a cube array, such that the isosurface contains a given
    fraction of the total mass. The mass is computed as the sum of the absolute values of the cube
    array raised to the power p.
    Args:
        cube_array: The cube array.
        quantile: The fraction (or array of fractions) of the total mass to be contained in the isosurface.
        p: The power to raise the cube array to. Use p=1 for electron density and p=2 for orbitals.
    Returns:
        The isosurface value.
    """
    cube_array = np.abs(cube_array).flatten()
    mass = cube_array**p
    total_mass = mass.sum()
    mass_sorted = np.sort(mass)[::-1]
    ind = np.searchsorted(np.cumsum(mass_sorted), total_mass * quantile)
    isosurface_value = mass_sorted[ind] ** (1 / p)
    return isosurface_value 
[docs]
def _fix_cube_data(cube_data: str) -> str:
    """Repair the number formatting of a cube file string written by pyscf.

    pyscf can emit numbers with three-digit exponents without separating whitespace, e.g.
    "-2.52899E-114-1.81539E-116-9.75640E-119". This inserts a space before every minus sign and
    then removes it again where the minus sign belongs to an exponent ("E-"), which yields
    " -2.52899E-114 -1.81539E-116 -9.75640E-119".

    Args:
        cube_data: The raw cube file contents.

    Returns:
        The cube file contents with the numbers separated by whitespace.
    """
    # Joining the pieces between minus signs with " -" equals replacing every "-" by " -".
    spaced = " -".join(cube_data.split("-"))
    # Re-attach minus signs that belong to an exponent.
    return spaced.replace("E -", "E-")
[docs]
def _eval_orbital(coords: np.ndarray, mol: pyscf.gto.Mole, coeff: np.ndarray) -> np.ndarray:
    """Evaluate a molecular orbital on a set of points.

    Args:
        coords: Points at which to evaluate the orbital, shape (n, 3).
        mol: The molecule providing the atomic-orbital basis.
        coeff: Coefficients of the orbital in the atomic-orbital basis of ``mol``
            (one entry per basis function).

    Returns:
        The orbital values at ``coords``, shape (n).
    """
    # Basis-function values at the points; the last axis is contracted with ``coeff`` below.
    basis_on_grid = mol.eval_gto("GTOval", coords)
    return np.einsum("...i,i", basis_on_grid, coeff)
[docs]
def visualize_orbital(
    mol: pyscf.gto.Mole,
    coeff: np.ndarray,
    resolution: float = 0.3,
    margin: float = 3.0,
    **plot_orbital_kwargs,
):
    """Sample an orbital on a grid around its molecule and plot it with `plot_orbital`.

    Args:
        mol: The molecule the orbital belongs to.
        coeff: The orbital coefficients in the atomic-orbital basis of ``mol``.
        resolution: The grid spacing of the sampled cube (presumably in Bohr; see `DataCube`).
        margin: The margin of the cube around the molecule (presumably in Bohr; see `DataCube`).
        plot_orbital_kwargs: Further keyword arguments forwarded to `plot_orbital`.

    Returns:
        Whatever `plot_orbital` returns: the pyvista plotter if one was passed via ``plotter``,
        otherwise the result of showing the plot.
    """
    orbital_on_points = partial(_eval_orbital, coeff=coeff, mol=mol)
    cube = DataCube.from_function(
        mol=mol, func=orbital_on_points, resolution=resolution, margin=margin
    )
    return plot_orbital(cube, **plot_orbital_kwargs)
[docs]
def _eval_density(coords: np.ndarray, mol: pyscf.gto.Mole, dm: np.ndarray) -> np.ndarray:
    """Evaluate the electron density of a density matrix on a set of points.

    Args:
        coords: Points at which to evaluate the density, shape (n, 3).
        mol: The molecule providing the atomic-orbital basis.
        dm: The density matrix in the atomic-orbital basis of ``mol``.

    Returns:
        The electron density at ``coords``, shape (n).
    """
    # Basis-function values at the points, then contract with the density matrix via pyscf.
    basis_on_grid = mol.eval_gto("GTOval", coords)
    return dft.numint.eval_rho(mol, basis_on_grid, dm)
[docs]
def visualize_density(
    mol: pyscf.gto.Mole,
    density_matrix: np.ndarray,
    resolution: float = 0.3,
    margin: float = 3.0,
    **plot_density_kwargs,
):
    """Sample the electron density on a grid around its molecule and plot it with `plot_density`.

    Args:
        mol: The molecule.
        density_matrix: The density matrix in the atomic orbital basis.
        resolution: The grid spacing of the sampled cube (presumably in Bohr; see `DataCube`).
        margin: The margin of the cube around the molecule (presumably in Bohr; see `DataCube`).
        plot_density_kwargs: Further keyword arguments forwarded to `plot_density`.

    Returns:
        Whatever `plot_density` returns: the pyvista plotter if one was passed via ``plotter``,
        otherwise the result of showing the plot.
    """
    density_on_points = partial(_eval_density, dm=density_matrix, mol=mol)
    cube = DataCube.from_function(
        mol=mol, func=density_on_points, resolution=resolution, margin=margin
    )
    return plot_density(cube, **plot_density_kwargs)
# Hex colors of atoms (and their half-bonds) in the sticks representation, keyed by element
# symbol. The insertion order defines the integer color ids in ATOM_IDS; "G" is the fallback
# used for elements that are not listed.
ATOM_COLORS = {
    "C": "#c8c8c8",  # carbon
    "H": "#ffffff",  # hydrogen
    "N": "#8f8fff",  # nitrogen
    "S": "#ffc832",  # sulfur
    "O": "#f00000",  # oxygen
    "F": "#ffff00",  # fluorine
    "P": "#ffa500",  # phosphorus
    "K": "#42f4ee",  # potassium
    "G": "#3f3f3f",  # generic fallback
}
# Element symbol -> position in ATOM_COLORS, used as the scalar value mapped onto the colormap.
ATOM_IDS = dict(zip(ATOM_COLORS, range(len(ATOM_COLORS))))
[docs]
def get_sticks_mesh_dict(
    mol: pyscf.gto.Mole,
    bond_radius: float = 0.2,
    atom_radius: float = None,
    resolution=20,
) -> dict:
    """Build a 'sticks' mesh of a pyscf molecule; a larger `atom_radius` gives balls-and-sticks.

    Every atom becomes a sphere, and every bond becomes two open half-cylinders, one starting at
    each bonded atom. Each sphere and half-cylinder is colored by its own atom's element
    (see ATOM_COLORS).

    Args:
        mol: The molecule.
        bond_radius: The radius of the cylinders representing the bonds.
        atom_radius: The radius of the spheres representing the atoms. If None, use `bond_radius`.
        resolution: The resolution of the cylinders and spheres.

    Returns:
        A dictionary with keyword arguments to pass to pyvista.Plotter.add_mesh.
    """
    # rdkit gives positions in Angstrom; convert to Bohr.
    angstrom_to_bohr = 1.8897259886
    rdkit_mol = pyscf_to_rdkit(mol)
    positions = rdkit_mol.GetConformer().GetPositions() * angstrom_to_bohr
    if atom_radius is None:
        atom_radius = bond_radius

    meshes = []
    for idx, atom in enumerate(rdkit_mol.GetAtoms()):
        here = positions[idx]
        color_id = ATOM_IDS.get(atom.GetSymbol(), ATOM_IDS["G"])

        ball = pv.Sphere(
            center=here,
            radius=atom_radius,
            phi_resolution=resolution,
            theta_resolution=resolution,
        )
        ball["color_ids"] = np.ones(ball.n_cells) * color_id
        meshes.append(ball)

        for bond in atom.GetBonds():
            begin, end = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
            other = begin if end == idx else end  # the bonded partner of this atom
            there = positions[other]
            half_bond = pv.Cylinder(
                # centered a quarter of the way towards the partner: covers this atom's half
                center=(3 * here + there) / 4,
                direction=there - here,
                radius=bond_radius,
                height=np.linalg.norm(here - there) / 2,
                capping=False,
                resolution=resolution,
            )
            half_bond["color_ids"] = np.ones(half_bond.n_cells) * color_id
            meshes.append(half_bond)

    merged = pv.MultiBlock(meshes).combine().extract_surface()
    return {
        "mesh": merged,
        "smooth_shading": True,
        "diffuse": 0.5,
        "specular": 0.5,
        "ambient": 0.5,
        "clim": (0, len(ATOM_COLORS)),
        "cmap": list(ATOM_COLORS.values()),
        "show_scalar_bar": False,
    }
# Colors of the basis vectors in the local frame visualization, indexed by the position of the
# vector within its frame (used as the colormap in get_local_frames_mesh_dict).
AXES_COLORS = [
    "#ff0000",  # first basis vector: red
    "#00ff00",  # second basis vector: green
    "#0000ff",  # third basis vector: blue
]
[docs]
def get_local_frames_mesh_dict(
    origins: np.ndarray,
    bases: np.ndarray,
    scale: float = 1,
    axes_radius_scale: float = 0.03,
    cone_scale: float = 2.0,
    cone_aspect_ratio: float = 2,
    resolution: int = 20,
) -> dict:
    """Create a mesh dict for a set of local frames, given by their origin and basis vectors.

    Each basis vector is drawn as an arrow that starts at its frame's origin. The arrow is a
    cylindrical shaft plus a cone tip, and its total length is ``scale * |basis_vector|``.
    Arrows shorter than the cone get a shortened cone and no shaft. Zero-length basis vectors
    are skipped. Arrows are colored by the index of the vector within its frame (AXES_COLORS).

    Args:
        origins: The origins of the local frames, shape (n, 3).
        bases: The basis vectors of the local frames, shape (n, 3=n_vectors, 3=n_axes).
        scale: Factor applied to the basis vector lengths (and to the arrow radii).
        axes_radius_scale: The radius of the shafts, relative to ``scale``.
        cone_scale: The radius of the cone tips, relative to the shaft radius.
        cone_aspect_ratio: The ratio of cone height to cone diameter.
        resolution: The resolution of the cylinders and cones.

    Returns:
        A dictionary with keyword arguments to pass to :meth:`pyvista.Plotter.add_mesh`.
    """
    assert (
        bases.ndim == 3 and bases.shape[2] == 3 and 0 < bases.shape[1] <= 3
    ), f"Invalid shape of bases: {bases.shape}. Must be (n, 1|2|3, 3)."
    assert (
        origins.ndim == 2 and origins.shape[1] == 3
    ), f"Invalid shape of origins: {origins.shape}. Must be (n, 3)."
    assert origins.shape[0] == bases.shape[0], (
        f"Incompatible shapes of origins and bases: {origins.shape} and {bases.shape}. "
        f"Must be (n, 3) and (n, 1|2|3, 3)."
    )
    mesh_elements = []
    axes_radius = scale * axes_radius_scale
    cone_radius = axes_radius * cone_scale
    cone_height = cone_radius * cone_aspect_ratio * 2  # aspect ratio is height / diameter
    for origin, frame in zip(origins, bases):
        for color_id, basis_vector in enumerate(frame):
            basis_vector_length = np.linalg.norm(basis_vector)
            if basis_vector_length == 0:
                # A zero vector has no direction; drawing it would produce NaN geometry.
                continue
            direction = basis_vector / basis_vector_length
            arrow_length = scale * basis_vector_length
            # The shaft fills whatever part of the arrow the cone does not cover.
            cylinder_height = max(arrow_length - cone_height, 0.0)
            # Short arrows get a shortened cone, so the arrow keeps its total length.
            this_cone_height = min(cone_height, arrow_length)
            if cylinder_height > 0:
                cylinder = pv.Cylinder(
                    center=origin + direction * (cylinder_height / 2),  # midpoint of the shaft
                    direction=direction,
                    radius=axes_radius,
                    height=cylinder_height,
                    resolution=resolution,
                )
                cylinder["color_ids"] = np.ones(cylinder.n_points) * color_id
                mesh_elements.append(cylinder)
            # add a cone at the tip of the arrow
            cone = pv.Cone(
                center=origin + direction * (cylinder_height + this_cone_height / 2),
                direction=direction,
                height=this_cone_height,
                radius=cone_radius,
                resolution=resolution,
            )
            cone = cone.extract_geometry()
            cone["color_ids"] = np.ones(cone.n_points) * color_id
            mesh_elements.append(cone)
    merged_mesh = pv.MultiBlock(mesh_elements).combine().extract_surface()
    add_mesh_kwargs = dict(
        mesh=merged_mesh,
        smooth_shading=True,
        diffuse=0.5,
        specular=0.5,
        ambient=0.5,
        clim=(0, len(AXES_COLORS) - 1),
        cmap=AXES_COLORS,
        show_scalar_bar=False,
    )
    return add_mesh_kwargs
[docs]
def plot_molecule(
    mol: pyscf.gto.Mole,
    plotter: pv.Plotter = None,
    figsize: tuple[int, int] = None,
    title: str = None,
) -> pv.Plotter:
    """Plot a molecule in the sticks representation.

    Args:
        mol: The pyscf molecule.
        plotter: A pyvista plotter to use. If None, a new plotter is created and shown.
        figsize: The figure size in inches (converted to pixels with matplotlib's
            ``figure.dpi``; float sizes are allowed).
        title: The title of the plot.

    Returns:
        The pyvista plotter if ``plotter`` was given. Otherwise the plot is shown and the return
        value of ``pv.Plotter.show`` is returned.
    """
    if plotter is not None:
        pl = plotter
    else:
        plotter_kwargs = dict(image_scale=2)
        if figsize is not None:
            dpi = int(rcParams["figure.dpi"])
            # The render window needs integer pixel counts; figsize may be floats (matplotlib style).
            plotter_kwargs["window_size"] = [int(round(s * dpi)) for s in figsize]
        pl = pv.Plotter(**plotter_kwargs)
    pl.add_mesh(**get_sticks_mesh_dict(mol))
    if title is not None:
        pl.add_title(title)
    pl.camera_position = "iso"
    if plotter is None:  # if no plotter was passed, show the plot
        return pl.show()
    else:  # otherwise, return the plotter, e.g. to save the image next
        return pl
[docs]
def plot_orbital(
    cube: DataCube,
    mode: str = "auto",
    plot_molecule: bool = True,
    isosurface_quantile: float = None,
    plotter: pv.Plotter = None,
    figsize: tuple[int, int] = None,
    title: str = None,
) -> pv.Plotter:
    """Plot an electron orbital using pyvista. By default, the orbital is plotted as a volume.

    Negative parts of the orbital are drawn in red, positive parts in blue.

    Args:
        cube: The `DataCube` holding the orbital values.
        mode: The mode to use for plotting. One of 'auto', 'volume', 'isosurface',
            'nested_isosurfaces'. 'auto' selects 'isosurface' if ``isosurface_quantile`` is
            given and 'volume' otherwise.
        plot_molecule: Whether to plot the molecule.
        isosurface_quantile: The quantile of the total mass (sum of squared values) to be
            contained in the isosurface, for mode 'isosurface'. Defaults to 0.9. Must not be
            given for mode 'nested_isosurfaces'.
        plotter: A pyvista plotter to use. If None, a new plotter is created and shown.
        figsize: The figure size in inches (converted to pixels with matplotlib's
            ``figure.dpi``; float sizes are allowed).
        title: The title of the plot.

    Returns:
        The pyvista plotter if ``plotter`` was given. Otherwise the plot is shown and the return
        value of ``pv.Plotter.show`` is returned.
    """
    if mode == "auto":
        if isosurface_quantile is None:
            mode = "volume"
        else:
            mode = "isosurface"
    assert mode in [
        "volume",
        "isosurface",
        "nested_isosurfaces",
    ], f"Invalid mode: {mode}. Must be one of 'volume', 'isosurface', 'nested_isosurfaces'."
    if plotter is not None:
        pl = plotter
    else:
        plotter_kwargs = dict(image_scale=2)
        if figsize is not None:
            dpi = int(rcParams["figure.dpi"])
            # The render window needs integer pixel counts; figsize may be floats (matplotlib style).
            plotter_kwargs["window_size"] = [int(round(s * dpi)) for s in figsize]
        pl = pv.Plotter(**plotter_kwargs)
    orbital = cube.to_pyvista()
    if mode == "volume":
        neg_mask = orbital["data"] < 0
        rgba = np.zeros((orbital.n_points, 4), np.uint8)
        rgba[neg_mask, 0] = 170  # negative values: red
        rgba[~neg_mask, 2] = 170  # non-negative values: blue
        # Normalize the opacity such that the points which together carry the largest 25% of
        # the summed absolute values are fully opaque.
        opac = np.abs(orbital["data"])
        opacity_norm = find_isosurface_value(opac, 0.25, p=1)
        if opacity_norm > 0:  # an all-zero orbital would otherwise give NaN opacities
            opac = opac / opacity_norm
        opac = np.clip(opac, 0, 1)
        rgba[:, -1] = opac * 255
        orbital["plot_scalars"] = rgba
        vol = pl.add_volume(
            orbital,
            opacity_unit_distance=10,
            scalars="plot_scalars",
        )
        vol.prop.interpolation_type = "linear"
    if plot_molecule:
        mol = cube.mol
        pl.add_mesh(**get_sticks_mesh_dict(mol))
    if mode in ["isosurface", "nested_isosurfaces"]:
        if isosurface_quantile is not None:
            # Accept any real number (int, numpy scalar), not only Python floats.
            assert (
                mode == "isosurface"
            ), f'Invalid mode: {mode}. Must be "isosurface" for a single isosurface value'
            quantiles = np.array([isosurface_quantile], dtype=float)
        elif mode == "nested_isosurfaces":
            quantiles = np.linspace(0.1, 0.99, 10)
        else:
            isosurface_quantile = 0.9
            quantiles = np.array([isosurface_quantile])
        isosurface_values = find_isosurface_value(orbital["data"], quantiles, p=2)
        # add negative isosurface values; negative lobes get negative quantiles
        isosurface_values = np.concatenate([-isosurface_values, isosurface_values])
        quantiles = np.concatenate([-quantiles, quantiles])
        isosurface = orbital.contour(isosurfaces=isosurface_values)
        iso_mesh = isosurface.extract_geometry()
        iso_mesh["quantile"] = np.zeros(iso_mesh.n_points)
        iso_mesh["opacity"] = np.zeros(iso_mesh.n_points)
        for quantile, isosurface_value in zip(quantiles, isosurface_values):
            # Tolerant comparison: the contour scalars may be stored at a different precision
            # than the requested isovalue.
            mask = np.isclose(iso_mesh["data"], isosurface_value, rtol=1e-5, atol=0.0)
            iso_mesh["quantile"][mask] = quantile
            iso_mesh["opacity"][mask] = 1 - np.abs(quantile)
        if mode == "nested_isosurfaces":
            # add a set of nested, transparent isosurfaces;
            # squaring looks good with 'seismic' colormap
            iso_mesh["quantile"] = (
                -np.sign(iso_mesh["quantile"]) * (1 - np.abs(iso_mesh["quantile"])) ** 2
            )
            pl.add_mesh(
                iso_mesh,
                scalars="quantile",
                clim=(-1, 1),
                cmap="seismic",
                opacity="opacity",
                smooth_shading=True,
                ambient=0.5,
                diffuse=0.5,
                show_scalar_bar=False,
            )
        else:
            # add a single, opaque and shiny isosurface
            pl.add_mesh(
                iso_mesh,
                opacity=1,
                scalars="quantile",
                clim=(-isosurface_quantile, isosurface_quantile),
                cmap="bwr_r",
                smooth_shading=True,
                ambient=0.3,
                diffuse=0.7,
                specular=0.5,
                show_scalar_bar=False,
            )
    if title is not None:
        pl.add_title(title)
    pl.camera_position = "iso"
    if plotter is None:  # if no plotter was passed, show the plot
        return pl.show()
    else:  # otherwise, return the plotter, e.g. to save the image next
        return pl
[docs]
def plot_density(
    cube: DataCube = None,
    mode: str = "auto",
    plot_molecule: bool = True,
    isosurface_quantile: float = None,
    isosurface_opacity: float = 0.4,
    plotter: pv.Plotter = None,
    figsize: tuple[int, int] = None,
    title: str = None,
    cmap: str = None,
):
    """Plot an electron density using pyvista. By default, a single isosurface is plotted.

    Args:
        cube: The `DataCube` holding the density values (required).
        mode: The mode to use for plotting. One of 'auto', 'isosurface', 'nested_isosurfaces'
            ('volume' is accepted but not implemented). 'auto' selects 'isosurface'.
        plot_molecule: Whether to plot the molecule.
        isosurface_quantile: For mode 'isosurface', the quantile of the total mass to be
            contained in the isosurface. Defaults to 0.9. Must not be given for mode
            'nested_isosurfaces'.
        isosurface_opacity: For mode 'isosurface', the opacity of the isosurface.
        plotter: A pyvista plotter to use. If None, a new plotter is created and shown.
        figsize: The figure size in inches (converted to pixels with matplotlib's
            ``figure.dpi``; float sizes are allowed).
        title: The title of the plot.
        cmap: The colormap to use for mode 'nested_isosurfaces'.

    Returns:
        The pyvista plotter if ``plotter`` was given. Otherwise the plot is shown and the return
        value of ``pv.Plotter.show`` is returned.

    Raises:
        NotImplementedError: If ``mode`` is 'volume'.
        ValueError: If ``cube`` is None.
    """
    if mode == "auto":
        # Volume rendering is not implemented for densities, so 'auto' always uses an isosurface.
        mode = "isosurface"
    assert mode in [
        "volume",
        "isosurface",
        "nested_isosurfaces",
    ], f"Invalid mode: {mode}. Must be one of 'volume', 'isosurface', 'nested_isosurfaces'."
    if mode == "volume":
        raise NotImplementedError(
            "Volume rendering of densities is not implemented; "
            "use mode='isosurface' or 'nested_isosurfaces'."
        )
    if cube is None:
        raise ValueError("plot_density requires a DataCube.")
    if plotter is not None:
        pl = plotter
    else:
        plotter_kwargs = dict(image_scale=2)
        if figsize is not None:
            dpi = int(rcParams["figure.dpi"])
            # The render window needs integer pixel counts; figsize may be floats (matplotlib style).
            plotter_kwargs["window_size"] = [int(round(s * dpi)) for s in figsize]
        pl = pv.Plotter(**plotter_kwargs)
    density = cube.to_pyvista()
    if plot_molecule:
        mol = cube.mol
        pl.add_mesh(**get_sticks_mesh_dict(mol))
    if isosurface_quantile is not None:
        # Accept any real number (int, numpy scalar), not only Python floats.
        assert (
            mode == "isosurface"
        ), f'Invalid mode: {mode}. Must be "isosurface" for a single isosurface value'
        quantiles = np.array([isosurface_quantile], dtype=float)
    elif mode == "nested_isosurfaces":
        n_surfaces = 20
        # evenly spaced quantiles, excluding the trivial 0 and 1
        quantiles = np.linspace(0, 1, n_surfaces + 2)[1:-1]
    else:
        isosurface_quantile = 0.9
        quantiles = np.array([isosurface_quantile])
    isosurface_values = find_isosurface_value(density["data"], quantiles, p=1)
    isosurface = density.contour(isosurfaces=isosurface_values)
    iso_mesh = isosurface.extract_geometry()
    # NOTE(review): the 'quantile' and 'opacity' point arrays are constant placeholders
    # (per-surface coloring is not wired up); they are kept so the mesh arrays stay as before.
    iso_mesh["quantile"] = np.zeros(iso_mesh.n_points)
    iso_mesh["opacity"] = np.ones(iso_mesh.n_points) * 0.05
    if mode == "nested_isosurfaces":
        # add a set of nested, transparent isosurfaces
        iso_mesh["quantile"] = -np.sign(iso_mesh["quantile"]) * (
            1 - np.abs(iso_mesh["quantile"])
        )
        pl.add_mesh(
            iso_mesh,
            cmap=cmap,
            opacity=0.04,
            smooth_shading=True,
            ambient=0.5,
            diffuse=0.5,
            show_scalar_bar=False,
        )
    else:
        # add a single, transparent and shiny isosurface
        pl.add_mesh(
            iso_mesh,
            opacity=isosurface_opacity,
            color="white",
            smooth_shading=True,
            ambient=0.4,
            diffuse=0.7,
            specular=0.3,
            show_scalar_bar=False,
        )
    if title is not None:
        pl.add_title(title)
    pl.camera_position = "iso"
    if plotter is None:  # if no plotter was passed, show the plot
        return pl.show()
    else:  # otherwise, return the plotter, e.g. to save the image next
        return pl