# Source code for mldft.ofdft.torch_functionals

"""Implementation of functionals in PyTorch.

The code is based on https://github.com/sail-sg/jax_xc and
https://gitlab.com/libxc/libxc/-/blob/master/src/gga_k_apbe.c
https://gitlab.com/libxc/libxc/-/blob/master/src/maple2c/gga_exc/gga_k_apbe.c
"""

import torch
from loguru import logger
from pyscf import dft, gto

from mldft.utils.coeffs_to_grid import coeffs_to_rho_and_derivatives
from mldft.utils.grids import compute_max_block_size, get_grid_blocks


@torch.jit.script
def cbrt(x: torch.Tensor) -> torch.Tensor:
    """Signed cube root, well-defined for negative inputs.

    ``x ** (1/3)`` is NaN for negative floats, so the magnitude and sign are
    handled separately.

    Args:
        x: input tensor

    Returns:
        torch.Tensor: element-wise cube root of x
    """
    magnitude = torch.abs(x) ** (1.0 / 3.0)
    return torch.sign(x) * magnitude


@torch.jit.script
def _unpolarized_gga_k_apbe(
    rho: torch.Tensor,
    sigma: torch.Tensor,
    kappa: torch.Tensor = torch.tensor(0.804),
    mu: torch.Tensor = torch.tensor(0.23889),
    zeta_threshold: torch.Tensor = torch.tensor(2.220446049250313e-16),
    dens_threshold: torch.Tensor = torch.tensor(1e-15),
) -> torch.Tensor:
    """Calculation for the kinetic energy density of the APBE functional.

    Args:
        rho: density tensor of shape (ngrid)
        sigma: squared density gradient tensor of shape (ngrid)
        kappa: parameter kappa of the APBE functional
        mu: parameter mu of the APBE functional
        zeta_threshold: threshold used for the spin-polarization pre-factor
        dens_threshold: density threshold, grid points with rho / 2 below it
            contribute 0

    Returns:
        torch.Tensor: kinetic energy density of the APBE functional
    """
    # Some pre-calculations using the zeta_threshold
    device = rho.device
    # ``Tensor.to`` is out-of-place: the returned tensors must be re-assigned,
    # otherwise the parameters silently stay on their original device.
    kappa = kappa.to(device)
    mu = mu.to(device)
    zeta_threshold = zeta_threshold.to(device)
    dens_threshold = dens_threshold.to(device)
    if 1.0 <= zeta_threshold:
        t12 = -zeta_threshold + 1.0
    else:
        t12 = torch.tensor(0.0, device=device)
    t13 = 1.0 + t12
    t16 = cbrt(zeta_threshold) ** 2
    t18 = cbrt(t13)
    t19 = t18**2
    if t13 <= zeta_threshold:
        pre_factor = t16 * zeta_threshold
    else:
        pre_factor = t19 * t13

    # Real calculation; pre_factor2 carries the Thomas-Fermi constant
    # 3/10 * (3 pi^2)^(2/3) (times the spin pre-factor, which is 1 here).
    pi = torch.tensor(torch.pi, device=device)
    pre_factor2 = (
        pre_factor * 2.0 * 3.0 / 20.0 * cbrt(torch.tensor(3, device=device)) ** 2 * pi ** (4 / 3)
    )
    # Grid points below the per-spin density threshold stay exactly zero.
    mask = rho / 0.2e1 >= dens_threshold
    res = torch.zeros_like(rho)
    res[mask] = pre_factor2 * (
        cbrt(rho[mask]) ** 2
        * (
            1.0
            + kappa
            * (
                1.0
                - kappa
                / (
                    kappa
                    + mu
                    * cbrt(torch.tensor(6, device=device))
                    * sigma[mask]
                    * cbrt(torch.tensor(2, device=device)) ** 2
                    / pi ** (4 / 3)
                    / cbrt(rho[mask]) ** 2
                    / rho[mask] ** 2
                    / 24.0
                )
            )
        )
    )
    return res


@torch.jit.script
def _unpolarized_gga_c_pbe(
    rho: torch.Tensor,
    sigma: torch.Tensor,
    beta: torch.Tensor = torch.tensor(0.06672455060314922),
    gamma: torch.Tensor = torch.tensor(0.031090690869654894),
    BB: torch.Tensor = torch.tensor(1.0),
    zeta_threshold: torch.Tensor = torch.tensor(2.220446049250313e-16),
    dens_threshold: torch.Tensor = torch.tensor(1e-12),
) -> torch.Tensor:
    """Calculation for the correlation energy density of the PBE functional.

    The ``t*`` intermediates appear to be taken 1:1 from maple-generated
    libxc code (cf. the module docstring), so they are kept as-is and only
    annotated coarsely below.

    Args:
        rho: density tensor of shape (ngrid)
        sigma: squared density gradient tensor of shape (ngrid)
        beta: gradient coefficient of the PBE correlation functional
        gamma: gamma parameter of the PBE correlation functional
        BB: prefactor of the second-order term of the gradient correction
        zeta_threshold: threshold used for the spin-polarization factors
        dens_threshold: grid points with rho below this threshold contribute 0

    Returns:
        torch.Tensor: correlation energy density, zero where rho < dens_threshold
    """
    # Only evaluate grid points with non-negligible density; the rest stay 0.
    result = torch.zeros_like(rho)
    mask = rho >= dens_threshold
    rho_masked = rho[mask]
    sigma_masked = sigma[mask]
    device = rho.device
    pi = torch.tensor(torch.pi, device=device)
    t1 = cbrt(torch.tensor(3, device=device))
    t3 = cbrt(1.0 / pi)
    t5 = cbrt(torch.tensor(4, device=device))
    t6 = t5**2
    t7 = cbrt(rho_masked)
    # t10 = 4 * rs with rs = (3 / (4 pi rho))**(1/3), the Wigner-Seitz radius
    t10 = t1 * t3 * t6 / t7
    t13 = torch.sqrt(t10)
    t16 = t10**0.15e1
    t18 = t1**2
    t19 = t3**2
    t21 = t7**2
    t24 = t18 * t19 * t5 / t21
    # t32: LDA correlation part (rational fit in powers of sqrt(rs))
    t30 = torch.log(
        1.0
        + 0.16081979498692535067e2
        / (0.379785e1 * t13 + 0.8969 * t10 + 0.204775 * t16 + 0.123235 * t24)
    )
    t32 = 0.621814e-1 * (0.1e1 + 0.53425e-1 * t10) * t30
    # Spin-polarization factors; for the default zeta_threshold (< 1) the
    # condition t33 is False and the unpolarized branch is taken.
    t33 = 0.1e1 <= zeta_threshold
    t34 = cbrt(zeta_threshold)
    if t33:
        t36 = t34 * zeta_threshold
    else:
        t36 = torch.tensor(1.0, device=device)
    t39 = cbrt(torch.tensor(2.0, device=device))
    t54 = torch.log(
        0.1e1
        + 0.29608749977793437516e2
        / (0.51785e1 * t13 + 0.905775 * t10 + 0.1100325 * t16 + 0.1241775 * t24)
    )
    t57 = (
        0.19751673498613801407e-1
        * (0.2e1 * t36 - 0.2e1)
        / (0.2e1 * t39 - 0.2e1)
        * (0.1e1 + 0.278125e-1 * t10)
        * t54
    )
    t58 = t34**2
    t59 = torch.where(t33, t58, 1)
    # t61: spin scaling factor phi**3 (equals 1 in the unpolarized branch)
    t60 = t59**2
    t61 = t60 * t59
    t63 = rho_masked**2
    t76 = 0.1e1 / gamma
    # t81/t83: presumably related to the A coefficient of the PBE H term,
    # A ~ 1 / (exp(-eps_c / (gamma * phi^3)) - 1) — verify against libxc
    t81 = torch.exp(-(-t32 + t57) * t76 / t61)
    t83 = 0.1e1 / (t81 - 0.1e1)
    t85 = sigma_masked**2
    t88 = t63**2
    t91 = t39**2
    t93 = t60**2
    # t102: reduced-gradient term plus the BB-scaled second-order contribution
    t102 = (
        sigma_masked / t7 / t63 * t39 / t60 * t18 / t3 * t5 / 0.96e2
        + BB * beta * t76 * t83 * t85 / t21 / t88 * t91 / t93 * t1 / t19 * t6 / 0.3072e4
    )
    # Gradient correction H; total correlation = H + LDA part (-t32 + t57)
    t112 = torch.log(0.1e1 + beta * t102 * t76 / (beta * t76 * t83 * t102 + 0.1e1))
    result[mask] = gamma * t61 * t112 - t32 + t57
    return result


@torch.jit.script
def _unpolarized_gga_x_pbe(
    rho: torch.Tensor,
    sigma: torch.Tensor,
    kappa: torch.Tensor = torch.tensor(0.804),
    mu: torch.Tensor = torch.tensor(0.2195149727645171),
    zeta_threshold: torch.Tensor = torch.tensor(2.220446049250313e-16),
    dens_threshold: torch.Tensor = torch.tensor(1e-15),
) -> torch.Tensor:
    """Calculation for the exchange energy density of the PBE functional.

    Args:
        rho: density tensor of shape (ngrid)
        sigma: squared density gradient tensor of shape (ngrid)
        kappa: parameter kappa of the PBE exchange functional
        mu: parameter mu of the PBE exchange functional
        zeta_threshold: threshold used for the spin-polarization factors
        dens_threshold: grid points with rho / 2 below this threshold contribute 0

    Returns:
        torch.Tensor: exchange energy density, zero below the density threshold
    """
    result = torch.zeros_like(rho)
    mask = rho / 2.0 > dens_threshold
    rho_masked = rho[mask]
    sigma_masked = sigma[mask]
    device = rho.device
    pi = torch.tensor(torch.pi, device=device)
    # ``Tensor.to`` is out-of-place: the returned tensors must be re-assigned,
    # otherwise the parameters silently stay on their original device.
    kappa = kappa.to(device)
    mu = mu.to(device)
    zeta_threshold = zeta_threshold.to(device)
    dens_threshold = dens_threshold.to(device)

    t3 = cbrt(torch.tensor(3, device=device))
    t4 = cbrt(pi)
    # Spin-polarization pre-factor; for the default threshold t7 is False and
    # t18 reduces to 1 (unpolarized case).
    t7 = 0.1e1 <= zeta_threshold
    t8 = zeta_threshold - 0.1e1
    t10 = torch.where(t7, -t8, 0)
    t11 = torch.where(t7, t8, t10)
    t12 = 0.1e1 + t11
    t14 = cbrt(zeta_threshold)
    t16 = cbrt(t12)
    t18 = torch.where(t12 <= zeta_threshold, t14 * zeta_threshold, t16 * t12)
    t19 = cbrt(rho_masked)
    t21 = cbrt(torch.tensor(6.0, device=device))
    t23 = pi**2
    t24 = cbrt(t23)
    t25 = t24**2
    t28 = cbrt(torch.tensor(2, device=device))
    t29 = t28**2
    t31 = rho_masked**2
    t32 = t19**2
    # LDA exchange -3/8 (3/pi)^(1/3) rho^(1/3) times the PBE enhancement factor
    t47 = (
        -0.3e1
        / 0.8e1
        * t3
        / t4
        * t18
        * t19
        * (
            0.1e1
            + kappa
            * (0.1e1 - kappa / (kappa + mu * t21 / t25 * sigma_masked * t29 / t32 / t31 / 0.24e2))
        )
    )
    result[mask] = 0.2e1 * t47
    return result


# Mapping from functional name to the list of torch implementations whose
# energy densities are summed.  "PBE" combines exchange and correlation,
# while the libxc-style keys select a single component.
str_to_torch_functionals = {
    "PBE": [_unpolarized_gga_x_pbe, _unpolarized_gga_c_pbe],
    "APBE": [_unpolarized_gga_k_apbe],
    "GGA_K_APBE": [_unpolarized_gga_k_apbe],
    "GGA_C_PBE": [_unpolarized_gga_c_pbe],
    "GGA_X_PBE": [_unpolarized_gga_x_pbe],
}


def torch_functional(
    density_and_gradient: torch.Tensor, functional: str, get_gradient: bool = True, **kwargs
) -> tuple[torch.Tensor, torch.Tensor | None, torch.Tensor | None]:
    r"""Wrapper for the computation of the energy density of a functional.

    Before calling the actual computation, gradient tracking is enabled and the
    squared gradient :math:`\sigma = |\nabla \rho|^2` is calculated.

    Args:
        density_and_gradient: density and density gradient tensor of shape (d, ngrid);
            row 0 is the density, rows 1-3 the Cartesian gradient components
        functional: functional to use, must be a key of ``str_to_torch_functionals``
        get_gradient: whether to return the gradient of the energy density
        kwargs: additional arguments for the functional

    Returns:
        tuple of the energy density of the functional divided by the density,
        gradient of the energy density of the functional wrt. the density rho
        on the grid and gradient of the energy density of the functional wrt.
        :math:`\sigma = |\nabla \rho|^2` on the grid
    """
    if get_gradient:
        density_and_gradient = density_and_gradient.requires_grad_(True)
    rho_torch = density_and_gradient[0]
    grad_torch = density_and_gradient[1:4]
    # sigma_j = sum_i grad_ij^2, the squared gradient norm per grid point
    sigma = torch.einsum("ij,ij->j", grad_torch, grad_torch)
    energy_density = sum(
        f(rho_torch, sigma, **kwargs) for f in str_to_torch_functionals[functional]
    )
    if not get_gradient:
        return energy_density, None, None
    # The functionals return the energy per particle; multiply by rho to obtain
    # the quantity that is integrated (and differentiated) on the grid.
    real_energy_density = energy_density * rho_torch
    # Single backward pass for both inputs instead of two separate
    # autograd.grad calls over the same graph.
    energy_density_gradient_rho, energy_density_gradient_s = torch.autograd.grad(
        outputs=real_energy_density.sum(),
        inputs=[rho_torch, sigma],
        retain_graph=True,
    )
    return energy_density, energy_density_gradient_rho, energy_density_gradient_s
def eval_torch_functionals(
    coeffs: torch.Tensor,
    ao: torch.Tensor,
    grid_weights: torch.Tensor,
    functionals: list[str],
) -> dict[str, tuple[torch.Tensor, torch.Tensor]]:
    """Computes the density and evaluates the given functionals on the grid.

    Args:
        coeffs: coefficients of the basis functions
        ao: atomic orbitals of the molecule in the basis
        grid_weights: weights of grid on which to evaluate the functionals, same as
            used for the ao calculation
        functionals: list of functionals to evaluate

    Returns:
        Dict mapping each functional name to a tuple of its energy and gradient.
    """
    results: dict[str, tuple[torch.Tensor, torch.Tensor]] = {}
    coeffs.requires_grad = True
    rho = coeffs_to_rho_and_derivatives(coeffs, ao, max_derivative_order=1)
    n_functionals = len(functionals)
    for index, name in enumerate(functionals):
        energy_density = torch_functional(rho, name, get_gradient=False)[0]
        # E = sum_i rho_i * eps_i * w_i over the grid points
        energy = torch.einsum("i,i,i->", rho[0], energy_density, grid_weights)
        # The graph only needs to survive while later functionals still reuse rho.
        keep_graph = index + 1 < n_functionals
        energy_gradient = torch.autograd.grad(energy, coeffs, retain_graph=keep_graph)[0]
        results[name] = (energy, energy_gradient)
        del energy_density
    return results
def eval_torch_functionals_blocked(
    mol: gto.Mole,
    grid: dft.Grids,
    coeffs: torch.Tensor,
    functionals: list[str],
    pre_computed_aos: torch.Tensor | None = None,
    max_memory: float = 4000.0,
) -> dict[str, tuple[torch.Tensor, torch.Tensor]]:
    """Evaluate torch functionals on the grid in a blocked fashion to reduce memory usage.

    Used in label generation, where based on memory usage the aos are saved or not.

    Args:
        mol: pyscf molecule used to evaluate the atomic orbitals
        grid: pyscf integration grid
        coeffs: coefficients of the basis functions
        functionals: list of functionals to evaluate
        pre_computed_aos: optional precomputed atomic orbitals; when None the aos
            are evaluated per block with ``dft.numint.eval_ao``
        max_memory: guess of the maximum memory the aos should take in MB

    Returns:
        Dict mapping each functional name to a tuple of its energy and gradient,
        summed over all grid blocks.
    """
    max_block_size = compute_max_block_size(max_memory, mol.nao, heuristic_multiplier=8)
    grid_chunks = get_grid_blocks(grid.size, max_block_size)
    logger.trace(
        f"Computing functionals in {len(grid_chunks)} grid blocks of size {max_block_size}"
    )
    # Initialize up front so an empty chunk list cannot leave ``output`` unbound.
    output: dict[str, tuple[torch.Tensor, torch.Tensor]] = {}
    for block_start, block_end in grid_chunks:
        if pre_computed_aos is None:
            ao_torch = dft.numint.eval_ao(mol, grid.coords[block_start:block_end], deriv=1)
            ao_torch = torch.as_tensor(ao_torch, dtype=torch.float64)
        else:
            ao_torch = pre_computed_aos[:, block_start:block_end]
        grid_weights = torch.as_tensor(grid.weights[block_start:block_end], dtype=torch.float64)
        functional_values = eval_torch_functionals(coeffs, ao_torch, grid_weights, functionals)
        # Accumulate (energy, gradient) tuples across blocks.
        for key, (energy, gradient) in functional_values.items():
            if key in output:
                total_energy, total_gradient = output[key]
                output[key] = (total_energy + energy, total_gradient + gradient)
            else:
                output[key] = (energy, gradient)
    return output
def eval_torch_functionals_blocked_fast(
    ao: torch.Tensor,
    grid_weights: torch.Tensor,
    coeffs: torch.Tensor,
    functionals: list[str],
    max_memory: float = 4000.0,
) -> dict[str, tuple[torch.Tensor, torch.Tensor]]:
    """Evaluate torch functionals on precomputed AOs in a blocked fashion to reduce memory usage.

    Used in density optimization, where the AOs are precomputed to achieve high speed.
    The estimated total size in MB is

        8 * mol.nao * grid.size * 8 / mega_byte

    where the first 8 was empirically determined and the second 8 comes from the size
    of a double. The maximum block size is then calculated to fit into the given
    max_memory.

    Args:
        ao: Atomic orbitals of the molecule in the basis as a tensor.
        grid_weights: Weights of the grid points on which to evaluate the functionals.
        coeffs: Coefficients of the basis functions.
        functionals: list of functionals to compute.
        max_memory: Guess of the maximum memory that should be taken by the aos in MB.
            Total usage might be higher. Defaults to the pyscf default of 4000MB.

    Returns:
        Dict mapping each functional name to a tuple of its energy and gradient,
        summed over all grid blocks.
    """
    max_block_size = compute_max_block_size(max_memory, ao.shape[-1], heuristic_multiplier=8)
    grid_size = grid_weights.shape[0]
    grid_chunks = get_grid_blocks(grid_size, max_block_size)
    # Initialize up front so an empty chunk list cannot leave ``output`` unbound.
    output: dict[str, tuple[torch.Tensor, torch.Tensor]] = {}
    for block_start, block_end in grid_chunks:
        ao_block = ao[:, block_start:block_end]
        grid_weights_block = grid_weights[block_start:block_end]
        functional_values = eval_torch_functionals(
            coeffs, ao_block, grid_weights_block, functionals
        )
        # Accumulate (energy, gradient) tuples across blocks.
        for key, (energy, gradient) in functional_values.items():
            if key in output:
                total_energy, total_gradient = output[key]
                output[key] = (total_energy + energy, total_gradient + gradient)
            else:
                output[key] = (energy, gradient)
    return output