"""This module contains classes for weighing samples in a batch.
Currently, this is used in loss functions.
"""
import torch
from torch import Tensor
from torch_geometric.nn import global_max_pool
from mldft.ml.data.components.of_data import OFData
class SampleWeigher:
    """Base class for objects that assign a weight to every sample in a batch.

    Subclasses are expected to override :meth:`get_weights` and :meth:`summary_string`;
    both raise ``NotImplementedError`` in this base class. Two weighers can be combined
    with the ``*`` operator, which produces a ``ProductSampleWeigher``.
    """

    def get_weights(self, batch: OFData) -> Tensor:
        """Returns the weights for the given batch.

        Args:
            batch: The batch containing the data.

        Returns:
            weights: The weights for the samples in the given batch.

        Raises:
            NotImplementedError: Always; subclasses must override this method.
        """
        raise NotImplementedError

    def __mul__(self, other):
        """Combines this weigher with another one via the ``*`` operator.

        Args:
            other: The right-hand operand, expected to be a :class:`SampleWeigher`.

        Returns:
            A ``ProductSampleWeigher`` wrapping ``self`` and ``other``. If ``other`` is not
            a :class:`SampleWeigher`, ``NotImplemented`` is returned so that Python can try
            the reflected operation and otherwise raise a ``TypeError`` right at the
            multiplication.
        """
        # Reject unsupported operands up front instead of building a product that would
        # only break once weights are requested from the non-weigher member.
        if not isinstance(other, SampleWeigher):
            return NotImplemented
        return ProductSampleWeigher(self, other)

    def __str__(self):
        """Returns a string representation of the sample weigher.

        The format is ``<ClassName>(<summary_string()>)``.
        """
        return f"{self.__class__.__name__}({self.summary_string()})"

    def summary_string(self):
        """Returns a summary string of the sample weigher, used e.g. as key in the dataset
        statistics.

        Raises:
            NotImplementedError: Always; subclasses must override this method.
        """
        raise NotImplementedError
[docs]
class ProductSampleWeigher(SampleWeigher):
    """Sample weigher whose weights are the element-wise product of several weighers."""

    def __init__(self, *sample_weighers: SampleWeigher):
        """Stores the sample weighers whose weights will be multiplied together.

        Args:
            sample_weighers: The sample weighers to be combined.
        """
        super().__init__()
        self.sample_weighers = sample_weighers

    def get_weights(self, batch: OFData) -> Tensor:
        """Returns the product of the weights of all wrapped weighers for the given batch.

        Args:
            batch: The batch containing the data.

        Returns:
            A tensor created like ``batch.energy_label`` (initialised to ones) into which
            the weights of every wrapped weigher have been multiplied in place.
        """
        combined = torch.ones_like(batch.energy_label)
        for member in self.sample_weighers:
            # In-place multiply keeps the result in the tensor allocated above.
            combined.mul_(member.get_weights(batch))
        return combined

    def summary_string(self):
        """Returns a summary string of the sample weigher, used e.g. as key in the dataset
        statistics.

        The string representations of the wrapped weighers are joined with ``_times_``.
        """
        return "_times_".join(map(str, self.sample_weighers))
class ConstantSampleWeigher(SampleWeigher):
    """Sample weigher that assigns the same constant weight to every sample."""

    def __init__(self, weight: float = 1):
        """Stores the constant weight.

        Args:
            weight: The constant weight to be used for all samples.
        """
        super().__init__()
        self.weight = weight

    def get_weights(self, batch: OFData) -> Tensor:
        """Returns a float tensor filled with the constant weight.

        Args:
            batch: The batch containing the data; ``batch.energy_label`` serves as the
                template for the returned tensor.

        Returns:
            A tensor of ones like ``batch.energy_label``, scaled by ``self.weight`` and
            converted with ``.float()``.
        """
        template = torch.ones_like(batch.energy_label)
        scaled = template * self.weight
        return scaled.float()

    def summary_string(self):
        """Returns a summary string of the sample weigher, used e.g. as key in the dataset
        statistics.

        The weight is appended to the key unless it equals 1.
        """
        if self.weight != 1:
            return f"constant_{self.weight}"
        return "constant"
[docs]
class InitialGuessOnlySampleWeigher(SampleWeigher):
    """Sample weigher that gives weight 1 to samples with ``scf_iteration == 0`` and weight 0
    to all other samples."""

    def get_weights(self, batch: OFData) -> Tensor:
        """Returns 1.0 where ``batch.scf_iteration`` equals 0 and 0.0 elsewhere.

        Args:
            batch: The batch containing the data.

        Returns:
            The comparison mask converted with ``.float()``.
        """
        is_initial_guess = batch.scf_iteration == 0
        return is_initial_guess.float()

    def summary_string(self):
        """Returns a summary string of the sample weigher, used e.g. as key in the dataset
        statistics."""
        return "initial_guess_only"
[docs]
class MinSCFIterationSampleWeigher(SampleWeigher):
    """Sample weigher that gives weight 1 to samples with ``scf_iteration >=
    min_scf_iteration`` and weight 0 to all other samples."""

    def __init__(self, min_scf_iteration: int):
        """Stores the SCF iteration threshold.

        Args:
            min_scf_iteration: The minimum SCF iteration to assign a weight of 1.
        """
        super().__init__()
        self.min_scf_iteration = min_scf_iteration

    def get_weights(self, batch: OFData) -> Tensor:
        """Returns 1.0 where ``batch.scf_iteration`` reaches the threshold and 0.0 elsewhere.

        Args:
            batch: The batch containing the data.

        Returns:
            The comparison mask converted with ``.float()``.
        """
        reached_threshold = batch.scf_iteration >= self.min_scf_iteration
        return reached_threshold.float()

    def summary_string(self):
        """Returns a summary string of the sample weigher, used e.g. as key in the dataset
        statistics.

        The configured threshold is part of the key.
        """
        threshold = self.min_scf_iteration
        return f"min_scf_iteration_{threshold}"
[docs]
class HasEnergyLabelSampleWeigher(SampleWeigher):
    """Sample weigher that gives weight 1 to samples that have an energy label and weight 0
    to all other samples."""

    def get_weights(self, batch: OFData) -> Tensor:
        """Returns ``batch.has_energy_label`` converted with ``.float()``.

        Args:
            batch: The batch containing the data.

        Returns:
            The float version of the ``has_energy_label`` attribute of the batch.
        """
        label_mask = batch.has_energy_label
        return label_mask.float()

    def summary_string(self):
        """Returns a summary string of the sample weigher, used e.g. as key in the dataset
        statistics."""
        return "has_energy_label"
[docs]
class GroundStateOnlySampleWeigher(SampleWeigher):
    """Sample weigher that gives weight 1 to the samples that are the ground state and weight
    0 to all other samples."""

    def get_weights(self, batch: OFData) -> Tensor:
        """Returns 1 for samples whose coefficients all equal the ground-state coefficients.

        Args:
            batch: The batch containing the data. ``coeffs`` is compared element-wise (exact
                equality) with ``ground_state_coeffs``, and ``coeffs_batch`` is the index
                passed to ``global_max_pool`` to group the coefficients.

        Returns:
            One minus the max-pooled mismatch mask.
        """
        # 1.0 wherever a coefficient differs from its ground-state counterpart.
        mismatch = (~batch.coeffs.eq(batch.ground_state_coeffs)).float()
        # Max-pooling over each group flags the group as soon as one entry differs.
        any_mismatch = global_max_pool(mismatch, batch.coeffs_batch)
        return 1 - any_mismatch

    def summary_string(self):
        """Returns a summary string of the sample weigher, used e.g. as key in the dataset
        statistics."""
        return "ground_state_only"