Source code for mldft.utils.plotting.axes
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
import numpy as np
import pyscf
from matplotlib import transforms
from numpy.typing import ArrayLike
from pyscf.lib.parameters import ANGULAR
from mldft.ml.data.components.basis_info import BasisInfo
[docs]
def format_basis_func_xaxis(
    ax: plt.Axes,
    basis_info: BasisInfo,
    atom_types: None | ArrayLike = None,
) -> None:
    """Formats the x-axis of ``ax`` in-place such that the basis function indices are shown, and
    the atom types are indicated below.

    The main x-axis gets unlabeled minor ticks at every shell boundary and unlabeled major ticks
    at every change of angular momentum. The angular momentum symbols (s, p, d, ...) are written
    centered below each angular momentum block. A second x-axis (created with ``ax.twiny()`` and
    moved below the first one) labels each atom block with its element symbol and carries the
    axis label. Vertical guide lines are drawn at:

    * the start of non-s shells (faint), unless an angular momentum line is already there;
    * angular momentum changes (dashed);
    * atom boundaries (solid).

    Args:
        ax: The axes object to format.
        basis_info: The basis info object.
        atom_types: The atomic numbers of the atoms to show on the x-axis, in the given order
            (repetitions allowed). Atomic numbers that are not part of ``basis_info`` are
            skipped. If None, each element in the basis is shown once.

    Raises:
        ValueError: If ``atom_types`` is given but none of its entries is part of the basis.
    """
    # Use the figure that actually contains ``ax``; plt.gcf() may be a different figure.
    fig = ax.figure

    show_all_atoms = atom_types is None
    if show_all_atoms:  # show atom types as they appear in the basis
        atom_types = basis_info.atomic_numbers
    else:  # show only the specified atom types (possibly multiple times)
        atom_types = np.asarray(
            [
                atomic_number
                for atomic_number in atom_types
                if atomic_number in basis_info.atomic_numbers
            ],
            dtype=np.int8,
        )
        if atom_types.size == 0:
            raise ValueError("None of the requested atom types is part of the basis.")
    atom_symbols = [pyscf.data.elements.ELEMENTS[atomic_number] for atomic_number in atom_types]
    # indices of the shown atoms in the per-atom arrays of the basis info object
    atom_ind = np.array(
        [basis_info.atomic_number_to_atom_index[atomic_number] for atomic_number in atom_types],
        dtype=int,
    )

    if show_all_atoms:
        basis_dim_per_atom = basis_info.basis_dim_per_atom
        l_per_shell = basis_info.l_per_shell
        shell_boundaries = basis_info.shell_to_first_basis_func
    else:
        basis_dim_per_atom = basis_info.basis_dim_per_atom[atom_ind]
        # Split the flat per-shell angular momenta into one array per atom, then re-assemble
        # them in the order (and multiplicity) of the requested atoms.
        l_per_shell_per_atom = np.split(
            basis_info.l_per_shell, np.cumsum(basis_info.n_shells_per_atom)[:-1]
        )
        l_per_shell = np.concatenate([l_per_shell_per_atom[i] for i in atom_ind])
        # first basis function index of every shell (a shell holds 2l+1 functions)
        shell_boundaries = np.concatenate([[0], np.cumsum(2 * l_per_shell + 1)])[:-1]

    n_basis = int(np.sum(basis_info.basis_dim_per_atom[atom_ind]))
    irreps_per_atom = [basis_info.irreps_per_atom[i] for i in atom_ind]
    ax.set_xlim(0, n_basis)

    shell_sizes = []  # number of coeffs per shell, i.e. always 2l+1
    l_sizes = []  # number of coeffs per l, e.g. 30 if there are 10 p shells
    l_symbols = []
    for irrep in irreps_per_atom:
        ls, counts = np.unique(irrep.ls, return_counts=True)
        shell_sizes.append(np.repeat(2 * ls + 1, counts))
        l_sizes.append((2 * ls + 1) * counts)
        l_symbols.extend(ANGULAR[l] for l in ls)

    # unlabeled minor ticks at every shell boundary
    shell_tick_positions = np.cumsum(np.concatenate([[0]] + shell_sizes))
    ax.set_xticks(shell_tick_positions, labels=[""] * len(shell_tick_positions), minor=True)
    # unlabeled major ticks at every change of angular momentum (includes 0 and n_basis)
    l_change_positions = np.cumsum(np.concatenate([[0]] + l_sizes))
    ax.set_xticks(
        l_change_positions,
        labels=[""] * len(l_change_positions),
        minor=False,
        horizontalalignment="center",
    )

    # The vertical offsets below are specified in display pixels. ``inch_per_pixel`` (= 1/dpi)
    # converts them to inches; dividing by the axes height in inches gives axes-fraction units.
    inch_per_pixel = fig.dpi_scale_trans.inverted().transform([1, 0])[0]
    ax_height = ax.get_window_extent().transformed(fig.dpi_scale_trans.inverted()).height
    # Tick lengths are given in points: 1700 * inch_per_pixel pt = 1700 / 72 px ~ 23.6 px, so
    # the major ticks reach down to the second axis placed 24 px below.
    major_tick_length = 1700 * inch_per_pixel
    ax.tick_params(axis="x", which="major", length=major_tick_length)

    # add text below the x-axis (12 px down) to indicate the l values using their symbols
    l_center_positions = (l_change_positions[1:] + l_change_positions[:-1]) / 2
    trans = transforms.blended_transform_factory(ax.transData, ax.transAxes)  # x: data, y: axes
    for center, symbol in zip(l_center_positions, l_symbols):
        ax.text(
            center,
            -12 * inch_per_pixel / ax_height,
            symbol,
            ha="center",
            va="center",
            transform=trans,
            fontsize="small",
        )

    # Second x-axis, moved 24 px below the first one, labelling the atoms
    ax2 = ax.twiny()
    ax2.set_xlim(ax.get_xlim())
    ax2.spines["bottom"].set_position(("axes", -24 * inch_per_pixel / ax_height))
    ax2.tick_params("both", length=0, width=0, which="minor")
    ax2.tick_params("both", direction="in", length=major_tick_length, which="major")
    ax2.xaxis.set_ticks_position("bottom")
    ax2.xaxis.set_label_position("bottom")

    atom_boundaries = np.concatenate([[0], np.cumsum(basis_dim_per_atom)])
    # Faint lines at the start of every non-s shell. Positions that already get an angular
    # momentum line (which includes every atom boundary) are skipped.
    non_s_shell_starts = shell_boundaries[l_per_shell > 0]
    for boundary in np.setdiff1d(non_s_shell_starts, l_change_positions):
        ax.axvline(boundary, c="k", lw=0.75, ls="--", alpha=0.1)
    # dashed lines at interior changes of angular momentum
    for boundary in l_change_positions[1:-1]:
        ax.axvline(boundary, c="k", lw=0.75, ls="--", alpha=0.3)
    # solid lines at interior atom boundaries
    for boundary in atom_boundaries[1:-1]:
        ax.axvline(boundary, c="k", lw=0.75, ls="-")

    # major ticks of the second axis mirror the l changes; element symbols sit at atom centers
    ax2.set_xticks(l_change_positions)
    ax2.xaxis.set_major_formatter(ticker.NullFormatter())
    atom_centers = (atom_boundaries[1:] + atom_boundaries[:-1]) / 2
    ax2.xaxis.set_minor_locator(ticker.FixedLocator(atom_centers))
    ax2.xaxis.set_minor_formatter(ticker.FixedFormatter(atom_symbols))
    ax2.set_xlabel("basis function index")