import warnings
from collections.abc import Callable
from typing import IO

import numpy as np
import pyscf
from pyscf import gto
from pyscf.tools import cubegen
# Data cube: a field on a cartesian grid spanning a molecule.
class DataCube:
    """A data cube, containing a field on a cartesian grid, spanning a molecule.

    Attributes:
        mol: The pyscf molecule the field belongs to.
        box: 3x3 array whose rows are the vectors spanning the box. Both
            constructors in this class build it as ``grid step * number of
            points`` along each axis, and `to_pyvista` relies on that by using
            ``box[i, i] / data.shape[i]`` as the grid spacing.
        origin: The origin of the box.
        data: The field values, shape (nx, ny, nz).
    """

    def __init__(self, mol: pyscf.gto.Mole, box: np.ndarray, origin: np.ndarray, data: np.ndarray):
        """Initialize a DataCube.

        The arguments are stored as given; nothing is copied or validated.

        Args:
            mol: The pyscf molecule.
            box: The vectors spanning the box, defining x y and z directions.
            origin: The origin of the box.
            data: The data, shape (nx, ny, nz).
        """
        self.mol = mol
        self.box = box
        self.origin = origin
        self.data = data

    def __str__(self) -> str:
        """Return a short summary that shows the data shape instead of the data."""
        return f"DataCube(mol={self.mol}, box={self.box}, origin={self.origin}, data.shape={self.data.shape})"

    @classmethod
    def from_file(cls, filename: str | IO, is_tiling_unit_cell=False) -> "DataCube":
        """Read a cube file. Adapted from `pyscf.tools.cubegen.Cube.read`.

        For details on the format see
        https://h5cube-spec.readthedocs.io/en/latest/cubeformat.html#cubeformat-dset-ids

        Args:
            filename: Path of the cube file, or an already opened text stream.
            is_tiling_unit_cell: Forwarded unchanged to `from_fileobject`.

        Returns:
            A DataCube object.
        """
        # The signature admits open streams, but `open()` cannot take one, so
        # hand anything file-like directly to the stream parser. The caller
        # keeps ownership of such a stream; it is not closed here.
        if hasattr(filename, "readline"):
            return cls.from_fileobject(filename, is_tiling_unit_cell)
        with open(filename) as f:
            return cls.from_fileobject(f, is_tiling_unit_cell)

    @classmethod
    def from_fileobject(cls, f: IO, is_tiling_unit_cell=False) -> "DataCube":
        """Read a cube file. Adapted from `pyscf.tools.cubegen.Cube.read`.

        For details on the format see
        https://h5cube-spec.readthedocs.io/en/latest/cubeformat.html#cubeformat-dset-ids

        Args:
            f: The file object, positioned at the start of the cube data.
            is_tiling_unit_cell: Accepted for interface compatibility. It used to
                select between two fractional meshes that were computed and then
                discarded, so it has no influence on the returned object.

        Returns:
            A DataCube object.

        Raises:
            ValueError: If a field cannot be parsed as a number or a voxel
                count is negative.
            AssertionError: If the number of data values differs from nx*ny*nz.
        """
        # The first two lines are skipped unread.
        f.readline()
        f.readline()

        header = f.readline().split()
        natm = int(header[0])
        has_dset_ids = natm < 0  # this indicates another line after molecule description
        natm = abs(natm)
        # Take exactly three components: the cube format allows an optional
        # fifth field (NVAL) on this line, which is not part of the origin.
        box_origin = np.array([float(x) for x in header[1:4]])

        def parse_axis(line):
            """Return (box vector, point count) for one voxel line: count followed by the step vector."""
            fields = line.split()
            count = int(fields[0])
            if count < 0:
                # Negative counts were rejected with a ValueError before as well
                # (by np.linspace); keep that behavior explicit.
                raise ValueError(f"Number of voxels must be non-negative, got {count}.")
            step = np.array([float(x) for x in fields[1:]])
            return step * count, count

        # get box dimensions and resolution
        box = np.zeros((3, 3))
        box[0], nx = parse_axis(f.readline())
        box[1], ny = parse_axis(f.readline())
        box[2], nz = parse_axis(f.readline())

        # construct the molecule
        atoms = []
        for _ in range(natm):
            fields = f.readline().split()
            # The second column (the nuclear charge in the cube format) is not
            # needed; everything after it is taken as the position.
            atoms.append([int(fields[0]), [float(x) for x in fields[2:]]])
        mol = gto.M(atom=atoms, unit="Bohr")

        if has_dset_ids:
            warnings.warn("Cube file has DSET_IDS. Parsing these is not implemented, ignoring.")
            f.readline()

        # read and reshape cube data
        cube_data = np.array(f.read().split(), dtype=float)
        assert nx * ny * nz == len(cube_data), f"{nx*ny*nz=} != {len(cube_data)}"
        cube_data = cube_data.reshape([nx, ny, nz])

        return cls(
            mol=mol,
            box=box,
            origin=box_origin,
            data=cube_data,
        )

    @classmethod
    def from_function(
        cls,
        mol: pyscf.gto.Mole,
        func: Callable[[np.ndarray], np.ndarray],
        resolution: float = cubegen.RESOLUTION,
        margin: float = cubegen.BOX_MARGIN,
        block_size: int | None = 16384,
    ) -> "DataCube":
        """Create a DataCube by evaluating a given function on a cartesian grid spanning a
        molecule.

        Args:
            mol: The pyscf molecule.
            func: The function to evaluate on the grid. Must take an array of positions of shape (n, 3) as input.
            resolution: The resolution of the grid in Bohr.
            margin: The margin of the box.
            block_size: The blocksize for the evaluation. Defaults to 16384.
                ``None`` evaluates all grid points in a single call.

        Returns:
            A DataCube holding the function values on the grid.
        """
        cc = cubegen.Cube(mol, nx=None, ny=None, nz=None, resolution=resolution, margin=margin)
        coords = cc.get_coords()
        ngrids = cc.get_ngrids()
        block_size = ngrids if block_size is None else block_size

        # Evaluate block by block so `func` never sees more than `block_size` points at once.
        data = np.empty(ngrids)
        for ip0, ip1 in pyscf.lib.prange(0, ngrids, block_size):
            data[ip0:ip1] = func(coords[ip0:ip1])
        data = data.reshape(cc.nx, cc.ny, cc.nz)

        # Rescale each box vector by n / (n - 1) -- presumably because cubegen's
        # box spans n - 1 grid steps, while this class stores step * n (as
        # `from_fileobject` does). TODO confirm against pyscf.tools.cubegen.
        box = np.array([b * s / (s - 1) for b, s in zip(cc.box, data.shape)])
        return cls(mol, box, cc.boxorig, data)

    def to_pyvista(self):
        """Convert to a pyvista grid.

        Returns:
            The pyvista grid, with the field stored as point data under "data".

        Raises:
            AssertionError: If the box has non-zero off-diagonal entries.
        """
        # Imported here so pyvista is only required when this method is used.
        import pyvista as pv

        off_diagonal = self.box[~np.eye(3, dtype=bool)]
        assert np.max(np.abs(off_diagonal)) == 0, "Box must be diagonal."
        grid = pv.ImageData(
            dimensions=self.data.shape,
            origin=self.origin,
            spacing=[self.box[i, i] / self.data.shape[i] for i in range(3)],
        )
        # Fortran order flattens with the first (x) index varying fastest.
        grid.point_data["data"] = self.data.ravel(order="F")
        return grid
# def to_data_string(self, field, fname, comment=None):
#
# result = StringIO()
# cubegen.Cube.write()
#