# This file is using parts from https://github.com/atomicarchitects/equiformer_v2 licensed under the MIT License.
import os
from typing import Optional
import e3nn.o3._wigner as o3_wigner
import torch
from e3nn.o3 import Irreps
from torch import Tensor
# Borrowed from e3nn @ 0.4.0:
# https://github.com/e3nn/e3nn/blob/0.4.0/e3nn/o3/_wigner.py#L10
# _Jd is a list of tensors of shape (2l+1, 2l+1)
# Precomputed J matrices, indexed by the degree l, loaded from the "Jd.pt" file that lives next to
# this module. wigner_D below uses len(_Jd) - 1 as the largest supported degree.
# weights_only=True restricts torch.load to tensors / plain containers instead of arbitrary pickles.
_Jd = torch.load(os.path.join(os.path.dirname(__file__), "Jd.pt"), weights_only=True)
# Borrowed from e3nn @ 0.4.0:
# https://github.com/e3nn/e3nn/blob/0.4.0/e3nn/o3/_wigner.py#L37
#
# In 0.5.0, e3nn shifted to torch.matrix_exp which is significantly slower:
# https://github.com/e3nn/e3nn/blob/0.5.0/e3nn/o3/_wigner.py#L92
def wigner_D(l, alpha, beta, gamma):
    """Wigner D-matrix of degree ``l`` for the Euler angles ``alpha``, ``beta``, ``gamma``.

    The matrix is assembled as ``Z(alpha) @ J @ Z(beta) @ J @ Z(gamma)``, where ``Z`` is
    :func:`_z_rot_mat` and ``J`` is the tabulated matrix ``_Jd[l]`` (cast to the angles' dtype and
    device).

    Args:
        l: Degree of the representation. Must be smaller than ``len(_Jd)``.
        alpha, beta, gamma: Angles; they are broadcast against each other.

    Returns:
        Tensor of shape ``(*broadcast_shape, 2l+1, 2l+1)``.

    Raises:
        NotImplementedError: If ``l`` is not covered by the tabulated ``_Jd`` matrices.
    """
    n_tabulated = len(_Jd)
    if not l < n_tabulated:
        raise NotImplementedError(
            f"wigner D maximum l implemented is {len(_Jd) - 1}, send us an email to ask for more"
        )

    a, b, c = torch.broadcast_tensors(alpha, beta, gamma)
    j_mat = _Jd[l].to(dtype=a.dtype, device=a.device)
    # one z-rotation matrix per Euler angle, built in the order alpha, beta, gamma
    rot_a, rot_b, rot_c = (_z_rot_mat(angle, l) for angle in (a, b, c))
    return rot_a @ j_mat @ rot_b @ j_mat @ rot_c
def _z_rot_mat(angle, l):
shape, device, dtype = angle.shape, angle.device, angle.dtype
M = angle.new_zeros((*shape, 2 * l + 1, 2 * l + 1))
inds = torch.arange(0, 2 * l + 1, 1, device=device)
reversed_inds = torch.arange(2 * l, -1, -1, device=device)
frequencies = torch.arange(l, -l - 1, -1, dtype=dtype, device=device)
M[..., inds, reversed_inds] = torch.sin(frequencies * angle[..., None])
M[..., inds, inds] = torch.cos(frequencies * angle[..., None])
return M
# monkey patch the wigner_D function from e3nn to the faster version
# This rebinds the attribute on the e3nn.o3._wigner module at import time of this file, so code that
# looks up `o3_wigner.wigner_D` afterwards gets the implementation defined above.
# NOTE(review): names that were already bound elsewhere (e.g. `e3nn.o3.wigner_D`, re-exported by
# e3nn's package __init__) are not rebound by this assignment. Confirm that the e3nn code path used here
# (e.g. Irreps.D_from_matrix) actually resolves wigner_D through `e3nn.o3._wigner`.
o3_wigner.wigner_D = wigner_D
[docs]
def pyscf_to_e3nn_matrix(irreps: Irreps, dtype: torch.dtype = torch.float64) -> Tensor:
    """Calculates the transformation matrix which can be used to transform a coefficient vector in
    the pyscf convention to a coefficient vector in the e3nn convention.

    For details on the conventions, compare https://docs.e3nn.org/en/stable/guide/change_of_basis.html and
    https://pyscf.org/user/gto.html#ordering-of-basis-functions. The difference lies in the orientation
    of the spherical harmonics for l>1.

    A coefficient vector can be transformed by multiplying it with this matrix **on the left**.

    .. warning::
        As of now, the equivariance of the vector in the e3nn convention under e3nn transformations has not been
        checked explicitly.

    Args:
        irreps: The irreps of the coefficient vector in the pyscf convention. The returned matrix acts as
            the identity on every ``l == 1`` block.
        dtype: The dtype of the matrix.

    Returns:
        torch.Tensor: The transformation matrix. Shape (irreps.dim, irreps.dim).
    """
    # pyscf uses a different convention than e3nn for l>1 (the wikipedia convention).
    # To alleviate the difference, we transform the l>1 vectors into the e3nn convention first.
    # To keep the l=1 vectors fixed, each l=1 irrep is replaced by three l=0 scalars of the same total
    # dimension, on which the change of basis below acts trivially.
    irreps_l1_as_scalars = Irreps(
        "+".join(
            f"{3 * mul_ir.mul}x0e" if mul_ir.ir.l == 1 else str(mul_ir) for mul_ir in irreps
        )
    )
    change_of_coord = torch.tensor(
        [
            # this specifies the change of basis yzx -> xyz
            [0.0, 0.0, 1.0],
            [1.0, 0.0, 0.0],
            [0.0, 1.0, 0.0],
        ],
        dtype=dtype,
    )
    return irreps_l1_as_scalars.D_from_matrix(change_of_coord)
[docs]
def pyscf_to_e3nn_local_frames_matrix(basis: Tensor, irreps: Irreps) -> Tensor:
    """Calculates transformation matrix which can be used to transform a coefficient vector in the
    pyscf convention to a coefficient vector in local frames in the e3nn convention.

    A coefficient vector can be transformed by multiplying it with this matrix **on the left**.
    The matrix first applies :func:`pyscf_to_e3nn_matrix` (pyscf -> e3nn convention) and then the
    e3nn representation of the local-frame rotation given by ``basis``.

    Args:
        basis: The basis vectors of the local frame. Shape (3, 3).
        irreps: The irreps of the coefficient vector in the pyscf convention.

    Returns:
        torch.Tensor: The transformation matrix. Shape (irreps.dim, irreps.dim).
    """
    # No permutation of the basis axes is needed here: the local frames are built such that their
    # second axis (e3nn's distinguished y axis) already points towards the closest neighbor.

    # Transformation from the pyscf to the e3nn convention, in the same dtype as the basis.
    pyscf_to_e3nn_mat = pyscf_to_e3nn_matrix(irreps, dtype=basis.dtype)
    # Rotation into the local frame, represented on the given irreps (inside the e3nn convention).
    rot_mat = get_rotation_matrix_from_basis(basis)
    wigner_D_mat = irreps.D_from_matrix(rot_mat)
    # Rightmost factor acts first: convert the convention, then rotate into the local frame.
    return wigner_D_mat @ pyscf_to_e3nn_mat
[docs]
def get_rotation_matrix_from_basis(basis: Tensor) -> Tensor:
    """Return the rotation matrix that corresponds to a local-frame basis.

    Args:
        basis (Tensor): The basis of the local frame.

    Returns:
        Tensor: The rotation matrix to the new basis from the standard basis. Assumed to be multiplied as
        'vector @ rot_mat.T' or 'torch.matmul(rot_mat, vector)'. This is the very same tensor object as
        ``basis``; no copy is made.
    """
    rot_mat = basis  # the basis matrix already is the rotation matrix
    return rot_mat
[docs]
def local_frames_from_rel_positions(
    diff_vec_1: Tensor, diff_vec_2: Tensor, diff_vec_3: Optional[Tensor] = None
) -> Tensor:
    """Returns the local frames of the atoms in the molecule.

    Each frame is built from the unit vector ``x`` along ``diff_vec_1``, the unit vector
    ``z = normalize(x × diff_vec_2)`` and ``y = z × x``. The rows of each returned ``3 x 3`` matrix
    are ``(z, x, y)``, so the direction towards the first neighbor is the **second** axis.

    .. warning::
        Different to how it is described in [M-OFDFT]_, the axis pointing towards the
        closest neighbor is the second axis. This is because the second axis is the
        distinguished axis in e3nn.

    If ``diff_vec_3`` is given, the sign of ``z`` (the first row of the output) is chosen such that it
    has a non-negative overlap with ``diff_vec_3``. ``y`` is not recomputed after this sign flip, so a
    flipped frame has determinant -1. Passing ``None`` (or a zero vector) as ``diff_vec_3`` never
    flips and always yields the right-handed frame ``(z, x, y)``.

    Args:
        diff_vec_1 (Tensor): The relative position of the first neighbor atom. Shape (..., 3). Must be non-zero.
        diff_vec_2 (Tensor): The relative position of the second neighbor atom. Shape (..., 3)
        diff_vec_3 (Optional[Tensor]): The relative position of the third neighbor atom. Shape (..., 3).
            If given, it fixes the sign of the first output axis as described above.

    Returns:
        torch.Tensor: The local frames of the atoms in the molecule. Shape (..., 3, 3), one row per axis.

    Notes:
        In case ``diff_vec_1`` and ``diff_vec_2`` are parallel, the second vector is replaced by a random
        vector (repeatedly, until the cross product is non-zero), so the result is not deterministic then.
        Inputs do not need to be contiguous.
    """
    lengths = torch.norm(diff_vec_1, dim=-1)
    assert not torch.any(lengths == 0), f"zero distance, {lengths=}"
    out_shape = list(diff_vec_1.shape) + [3]

    # Flatten all leading (batch) dimensions so everything below works on (N, 3) tensors.
    # reshape (instead of view) also accepts non-contiguous inputs.
    basis_x = torch.div(diff_vec_1, lengths.unsqueeze(-1)).reshape(-1, 3)
    diff_vec_2 = diff_vec_2.reshape(-1, 3)
    assert not torch.any(torch.isnan(basis_x)), f"NaN in basis_x: {basis_x=}"

    # z axis: cross product of x with the second direction (normalized below).
    cross_z = torch.cross(basis_x, diff_vec_2, dim=-1)
    # If x and the second direction are parallel, the cross product vanishes; replace the second
    # direction by a random vector for those rows until every cross product is non-zero.
    degenerate = torch.norm(cross_z, dim=-1) == 0
    while torch.any(degenerate):
        diff_vec_2 = torch.where(degenerate[:, None], torch.randn_like(diff_vec_2), diff_vec_2)
        cross_z = torch.cross(basis_x, diff_vec_2, dim=-1)
        degenerate = torch.norm(cross_z, dim=-1) == 0
    basis_z = cross_z / torch.norm(cross_z, dim=-1).unsqueeze(-1)

    # y axis: no need to normalize because x and z are orthonormal.
    basis_y = torch.cross(basis_z, basis_x, dim=-1)

    if diff_vec_3 is not None:
        diff_vec_3 = diff_vec_3.reshape(-1, 3)
        # Orient z towards the third neighbor. NOTE: y is deliberately left untouched here (as before),
        # so flipped rows form a frame with determinant -1.
        basis_z = torch.where(
            torch.einsum("ij,ij->i", basis_z, diff_vec_3)[:, None] < 0, -basis_z, basis_z
        )

    return torch.stack((basis_z, basis_x, basis_y), dim=-2).reshape(out_shape)
[docs]
def local_frames_from_positions(
    pos: Tensor,
    neighbor1_pos: Tensor,
    neighbor2_pos: Tensor,
    neighbor3_pos: Optional[Tensor] = None,
) -> Tensor:
    """Returns the local frames of the atoms in the molecule.

    The neighbor positions are converted to positions relative to ``pos`` and passed on to
    :func:`local_frames_from_rel_positions`, which constructs the frames. All given neighbors,
    including the optional third one, are used.

    Args:
        pos (Tensor): The position of the atom. Shape (..., 3)
        neighbor1_pos (Tensor): The position of the first neighbor atom. Shape (..., 3)
        neighbor2_pos (Tensor): The position of the second neighbor atom. Shape (..., 3)
        neighbor3_pos (Optional[Tensor]): The position of the third neighbor atom. Shape (..., 3). If given,
            its relative position fixes the sign of the axis computed from the cross product (see
            :func:`local_frames_from_rel_positions`).

    Returns:
        torch.Tensor: The local frames of the atoms in the molecule. Shape (..., 3, 3), one row per axis.
    """
    # vectors pointing from the atom to its neighbors
    diff_vec_1 = neighbor1_pos - pos
    diff_vec_2 = neighbor2_pos - pos
    diff_vec_3 = None if neighbor3_pos is None else neighbor3_pos - pos
    return local_frames_from_rel_positions(diff_vec_1, diff_vec_2, diff_vec_3)