"""MAE of energy, MAE for gradient, MAE for proj_minao."""
import torch
from torch_geometric.data import Batch
from torch_geometric.nn import global_add_pool
from torchmetrics import Metric
from mldft.ml.data.components.of_data import OFData
from mldft.ml.models.components.sample_weighers import (
    ConstantSampleWeigher,
    HasEnergyLabelSampleWeigher,
    SampleWeigher,
)
class PerSampleAbsoluteErrorMetric(Metric):
"""Base class for metrics that calculate the absolute error per sample."""
AVERAGING_MODES = ["per molecule", "per electron"]
DEFAULT_SAMPLE_WEIGHER = None
    def __init__(
        self, mode="per molecule", sample_weigher: SampleWeigher | str | None = "default"
    ):
        """
        Args:
            mode: The averaging mode, either "per molecule" or "per electron".
            sample_weigher (SampleWeigher, optional): Sample weigher to be used.
                Defaults to :attr:`DEFAULT_SAMPLE_WEIGHER`.
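
        Example (construction sketch, using the :class:`MAEEnergy` subclass defined below)::

            metric = MAEEnergy()                     # uses the subclass default sample weigher
            metric = MAEEnergy(mode="per electron")  # average the error per electron instead
            metric = MAEEnergy(sample_weigher=ConstantSampleWeigher())  # weigh all samples equally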
"""
super().__init__()
assert mode in self.AVERAGING_MODES, f"mode must be one of {self.AVERAGING_MODES}"
self.mode = mode
if sample_weigher == "default":
assert (
self.DEFAULT_SAMPLE_WEIGHER is not None
), f"DEFAULT_SAMPLE_WEIGHER is not set for class {self.__class__.__name__}"
sample_weigher = self.DEFAULT_SAMPLE_WEIGHER
self.sample_weigher = (
sample_weigher if sample_weigher is not None else ConstantSampleWeigher()
)
self.add_state("weighted_molecule_count", default=torch.tensor(0.0), dist_reduce_fx="sum")
self.add_state("sum_absolute_error", default=torch.tensor(0.0), dist_reduce_fx="sum")
self.add_state(
"sum_per_electron_absolute_error", default=torch.tensor(0.0), dist_reduce_fx="sum"
)
    def update_with_errors(self, batch: Batch, errors: torch.Tensor):
        """Updates the hidden states for a batch and the corresponding errors.

        Args:
            batch: The OFData object, used for updating the molecule and electron count.
            errors: The errors to accumulate. The shape should be either (batch_size,) or
                (n_basis,); in the latter case, the errors per basis function are summed up to
                obtain the error per molecule.
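
        Example (illustrative sketch of the per-basis reduction, using toy tensors; the index
        tensor plays the role of ``batch.coeffs_batch``)::

            import torch
            from torch_geometric.nn import global_add_pool

            per_basis_errors = torch.tensor([0.1, 0.2, 0.3, 0.4])
            coeffs_batch = torch.tensor([0, 0, 1, 1])  # maps each basis function to its molecule
            per_molecule_errors = global_add_pool(per_basis_errors, coeffs_batch)
            # one entry per molecule: [0.1 + 0.2, 0.3 + 0.4] = [0.3, 0.7]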
"""
assert errors.ndim == 1, f"{errors.shape}"
assert isinstance(batch, OFData)
errors = errors.abs()
if errors.shape == (batch.n_basis,):
# we got an error per basis function, have to sum it up to get the error per molecule
errors = global_add_pool(errors, batch.coeffs_batch)
assert errors.shape[0] == batch.batch_size, f"{errors.size} != {batch.batch_size}"
sample_weights = self.sample_weigher.get_weights(batch)
self.sum_absolute_error += (errors * sample_weights).sum()
self.weighted_molecule_count += sample_weights.sum()
# this assumes neutral molecules, as this is technically the number of protons
electrons_per_molecule = global_add_pool(batch.atomic_numbers, batch.atomic_numbers_batch)
self.sum_per_electron_absolute_error += (
errors * sample_weights / electrons_per_molecule
).sum()
    def compute_per_molecule(self):
        """Calculates the mean error per molecule."""
        return self.sum_absolute_error / self.weighted_molecule_count
    def compute_per_electron(self):
        """Calculates the mean error per electron."""
        return self.sum_per_electron_absolute_error / self.weighted_molecule_count
    def compute(self):
        """Calculates the mean error, based on the averaging mode."""
        if self.mode == "per molecule":
            return self.compute_per_molecule()
        elif self.mode == "per electron":
            return self.compute_per_electron()
        else:
            raise ValueError(f"Invalid mode: {self.mode}")
class MAEGradient(PerSampleAbsoluteErrorMetric):
    r"""Mean absolute error of the kinetic energy gradient, averaged over the number of molecules.

    For the default ``mode="per molecule"``, the error is averaged as

    .. math::

        \text{MAE\_Gradient} = \frac{1}{\text{n\_molecules}}\sum_d\sum_k ||
        \left( \textbf{I}-\frac{\textbf{w}^{(d)}{\textbf{w}^{(d)}}^T}{{\textbf{w}^{(d)}}^T
        \textbf{w}^{(d)}}\right)\left(\nabla_\textbf{p} T_{S,\Theta}(\textbf{p}^{(d,k)},\mathcal{M}^{(d)})-
        \nabla_\textbf{p} T_S^{(d,k)}\right)||.

    For ``mode="per electron"``, the error is averaged as

    .. math::

        \text{WMAE\_Gradient} = \frac{1}{\text{n\_molecules}}\sum_d \frac{1}{\text{n\_electrons}^{(d)}}\sum_k ||
        \left( \textbf{I}-\frac{\textbf{w}^{(d)}{\textbf{w}^{(d)}}^T}{{\textbf{w}^{(d)}}^T
        \textbf{w}^{(d)}}\right)\left(\nabla_\textbf{p} T_{S,\Theta}(\textbf{p}^{(d,k)},\mathcal{M}^{(d)})
        -\nabla_\textbf{p} T_S^{(d,k)}\right)||.
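
    Example (illustrative sketch of the projection in the formulas above, using toy tensors;
    the projected gradient difference itself is computed outside of this metric and passed to
    :meth:`update`)::

        import torch

        w = torch.tensor([1.0, 2.0, 2.0])            # plays the role of w^(d) in the formula
        grad_diff = torch.tensor([0.5, -1.0, 0.25])  # predicted gradient minus label gradient
        projector = torch.eye(3) - torch.outer(w, w) / w.dot(w)
        per_sample_error = (projector @ grad_diff).norm()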
"""
DEFAULT_SAMPLE_WEIGHER = (
HasEnergyLabelSampleWeigher()
) # has gradient label <=> has energy label
    def update(
        self,
        batch: Batch,
        pred_energy: torch.Tensor,
        projected_gradient_difference: torch.Tensor,
        pred_diff: torch.Tensor,
    ):
        """Computes the error and updates the hidden states."""
        error = projected_gradient_difference
        self.update_with_errors(batch, error)
class MAEEnergy(PerSampleAbsoluteErrorMetric):
    r"""Metric for the mean absolute error of the kinetic energy.

    When ``mode="per molecule"``, the error is averaged as

    .. math::

        \text{MAE\_Energy} = \frac{1}{\text{n\_molecules}}\sum_d\sum_k
        | T_{S,\Theta}(\textbf{p}^{(d,k)},\mathcal{M}^{(d)})-T_S^{(d,k)}|.
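
    Example (minimal usage sketch; assumes ``batch`` is an :class:`OFData` batch with an
    ``energy_label`` attribute and that the remaining arguments are the corresponding model
    outputs, of which only ``pred_energy`` is used by this metric)::

        metric = MAEEnergy(mode="per molecule")
        metric.update(batch, pred_energy, projected_gradient_difference, pred_diff)
        mae_energy = metric.compute()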
"""
DEFAULT_SAMPLE_WEIGHER = HasEnergyLabelSampleWeigher()
    def update(
        self,
        batch: Batch,
        pred_energy: torch.Tensor,
        projected_gradient_difference: torch.Tensor,
        pred_diff: torch.Tensor,
    ):
        """Computes the error and updates the hidden states."""
        error = pred_energy - batch.energy_label
        self.update_with_errors(batch, error)
class MAEInitialGuess(PerSampleAbsoluteErrorMetric):
    """MAE of the initial guess delta coefficients."""

    DEFAULT_SAMPLE_WEIGHER = ConstantSampleWeigher()
    def update(
        self,
        batch: Batch,
        pred_energy: torch.Tensor,
        projected_gradient_difference: torch.Tensor,
        pred_diff: torch.Tensor,
    ):
        """Computes the error and updates the hidden states."""
        error = pred_diff - (batch.coeffs - batch.ground_state_coeffs)
        self.update_with_errors(batch, error)