sample_weighers

This module contains classes for weighing samples in a batch.

Currently, these weighers are used to weight samples in the loss functions.
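The following sketch shows how per-sample weights might enter a loss. It assumes the loss is a weighted mean normalized by the total weight. The actual normalization used by the loss functions is not documented here. Plain Python lists stand in for torch Tensors, and `weighted_mean_loss` is a hypothetical helper, not part of this module.

```python
def weighted_mean_loss(per_sample_loss: list[float], weights: list[float]) -> float:
    """Average per-sample losses using sample weights (hypothetical helper)."""
    if len(per_sample_loss) != len(weights):
        raise ValueError("loss and weights must have the same length")
    total_weight = sum(weights)
    if total_weight == 0:
        return 0.0  # no sample selected; avoid dividing by zero
    return sum(l * w for l, w in zip(per_sample_loss, weights)) / total_weight


losses = [0.5, 2.0, 1.0]
weights = [1.0, 0.0, 1.0]  # e.g. a 0/1 mask produced by a sample weigher
print(weighted_mean_loss(losses, weights))  # 0.75
```

With 0/1 weights, the middle sample is simply excluded, and the mean is taken over the two selected samples.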

class GroundStateOnlySampleWeigher

Sample weigher that assigns a weight of 1 to samples that correspond to the ground state and 0 otherwise.

get_weights(batch: OFData) → Tensor

Returns the weights for the given batch.

summary_string()

Returns a summary string of the sample weigher, used, e.g., as a key in the dataset statistics.
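Here is a minimal, torch-free illustration of the ground-state mask. How a sample is recognized as the ground state is not described on this page. The sketch therefore assumes a hypothetical per-sample boolean field `is_ground_state` on a toy batch class that stands in for `OFData`. The summary key is also hypothetical.

```python
from dataclasses import dataclass


@dataclass
class ToyBatch:
    # Hypothetical stand-in for OFData: one flag per sample saying whether
    # the sample is the ground state.
    is_ground_state: list[bool]


class ToyGroundStateOnlySampleWeigher:
    """Illustrative re-implementation, not the module's source."""

    def get_weights(self, batch: ToyBatch) -> list[float]:
        # Weight 1 for ground-state samples, 0 for all others.
        return [1.0 if flag else 0.0 for flag in batch.is_ground_state]

    def summary_string(self) -> str:
        # Hypothetical key; the real string format is not documented here.
        return "ground_state_only"


weigher = ToyGroundStateOnlySampleWeigher()
print(weigher.get_weights(ToyBatch(is_ground_state=[False, False, True])))
# [0.0, 0.0, 1.0]
```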

class HasEnergyLabelSampleWeigher

Sample weigher that assigns a weight of 1 to the samples with an energy label and 0 otherwise.

get_weights(batch: OFData) → Tensor

Returns the weights for the given batch.

summary_string()

Returns a summary string of the sample weigher, used, e.g., as a key in the dataset statistics.
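The sketch below makes the energy-label rule concrete. It assumes that a missing label is represented as `None` in a hypothetical `energy_label` field. That representation is an illustration choice, not the actual `OFData` layout. Testing for `None`, rather than for truthiness, keeps a legitimate label of `0.0` from being dropped.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class ToyBatch:
    # Hypothetical stand-in for OFData: the energy label of each sample,
    # or None when the sample has no label.
    energy_label: list[Optional[float]]


class ToyHasEnergyLabelSampleWeigher:
    """Illustrative re-implementation, not the module's source."""

    def get_weights(self, batch: ToyBatch) -> list[float]:
        # Weight 1 if an energy label is present, 0 otherwise.
        return [0.0 if e is None else 1.0 for e in batch.energy_label]


weigher = ToyHasEnergyLabelSampleWeigher()
print(weigher.get_weights(ToyBatch(energy_label=[-76.4, None, -40.5])))
# [1.0, 0.0, 1.0]
```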

class InitialGuessOnlySampleWeigher

Sample weigher that assigns a weight of 1 to the samples with scf_iteration == 0 and 0 otherwise.

get_weights(batch: OFData) → Tensor

Returns the weights for the given batch.

summary_string()

Returns a summary string of the sample weigher, used, e.g., as a key in the dataset statistics.
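A minimal sketch of the `scf_iteration == 0` rule follows. It uses a toy batch with a hypothetical `scf_iteration` list in place of `OFData`, and plain lists in place of Tensors.

```python
from dataclasses import dataclass


@dataclass
class ToyBatch:
    # Hypothetical stand-in for OFData: the SCF iteration of each sample.
    scf_iteration: list[int]


class ToyInitialGuessOnlySampleWeigher:
    """Illustrative re-implementation, not the module's source."""

    def get_weights(self, batch: ToyBatch) -> list[float]:
        # Only samples from SCF iteration 0 (the initial guess) get weight 1.
        return [1.0 if it == 0 else 0.0 for it in batch.scf_iteration]


print(ToyInitialGuessOnlySampleWeigher().get_weights(ToyBatch([0, 3, 0, 7])))
# [1.0, 0.0, 1.0, 0.0]
```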

class MinSCFIterationSampleWeigher(min_scf_iteration: int)

Sample weigher that assigns a weight of 1 to the samples with scf_iteration >= min_scf_iteration and 0 otherwise.

__init__(min_scf_iteration: int)

Initializes the sample weigher.

Parameters:

min_scf_iteration – The smallest SCF iteration for which a sample receives a weight of 1.

get_weights(batch: OFData) → Tensor

Returns the weights for the given batch.

summary_string()

Returns a summary string of the sample weigher, used, e.g., as a key in the dataset statistics.
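The following sketch illustrates the threshold rule with a toy batch (the `scf_iteration` field is a hypothetical stand-in for `OFData`). The comparison is inclusive, so a sample at exactly `min_scf_iteration` is kept. The summary string shown is an assumed format. Embedding the threshold in the key would keep statistics for different thresholds apart.

```python
from dataclasses import dataclass


@dataclass
class ToyBatch:
    # Hypothetical stand-in for OFData: the SCF iteration of each sample.
    scf_iteration: list[int]


class ToyMinSCFIterationSampleWeigher:
    """Illustrative re-implementation, not the module's source."""

    def __init__(self, min_scf_iteration: int) -> None:
        self.min_scf_iteration = min_scf_iteration

    def get_weights(self, batch: ToyBatch) -> list[float]:
        # Weight 1 once the sample has reached min_scf_iteration, else 0.
        return [
            1.0 if it >= self.min_scf_iteration else 0.0
            for it in batch.scf_iteration
        ]

    def summary_string(self) -> str:
        # Hypothetical key format that records the threshold.
        return f"min_scf_iteration_{self.min_scf_iteration}"


weigher = ToyMinSCFIterationSampleWeigher(min_scf_iteration=2)
print(weigher.get_weights(ToyBatch([0, 1, 2, 5])))  # [0.0, 0.0, 1.0, 1.0]
```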

class ProductSampleWeigher(*sample_weighers: SampleWeigher)

Sample weigher that combines multiple sample weighers by multiplying their weights.

__init__(*sample_weighers: SampleWeigher)

Initializes the product sample weigher.

Parameters:

sample_weighers – The sample weighers to be combined.

get_weights(batch: OFData) → Tensor

Returns the weights for the given batch.

summary_string()

Returns a summary string of the sample weigher, used, e.g., as a key in the dataset statistics.
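For 0/1 weights, multiplying acts as a logical AND: a sample keeps weight 1 only if every component weigher selects it. The sketch below shows this with a hypothetical `ToySampleWeigher` base class that mirrors the `get_weights` / `summary_string` interface. A toy batch stands in for `OFData`, and plain lists stand in for Tensors. The joined summary key is an assumed format. The sketch also rejects an empty weigher list; how the actual class handles that case is not documented here.

```python
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Optional


@dataclass
class ToyBatch:
    # Hypothetical stand-in for OFData.
    scf_iteration: list[int]
    energy_label: list[Optional[float]]


class ToySampleWeigher(ABC):
    """Hypothetical base interface mirroring get_weights / summary_string."""

    @abstractmethod
    def get_weights(self, batch: ToyBatch) -> list[float]: ...

    @abstractmethod
    def summary_string(self) -> str: ...


class ToyMinSCF(ToySampleWeigher):
    def __init__(self, min_scf_iteration: int) -> None:
        self.min_scf_iteration = min_scf_iteration

    def get_weights(self, batch: ToyBatch) -> list[float]:
        return [1.0 if it >= self.min_scf_iteration else 0.0 for it in batch.scf_iteration]

    def summary_string(self) -> str:
        return f"min_scf_{self.min_scf_iteration}"


class ToyHasEnergyLabel(ToySampleWeigher):
    def get_weights(self, batch: ToyBatch) -> list[float]:
        return [0.0 if e is None else 1.0 for e in batch.energy_label]

    def summary_string(self) -> str:
        return "has_energy_label"


class ToyProductSampleWeigher(ToySampleWeigher):
    """Multiplies the weights of several weighers element-wise."""

    def __init__(self, *sample_weighers: ToySampleWeigher) -> None:
        if not sample_weighers:
            raise ValueError("need at least one sample weigher")
        self.sample_weighers = sample_weighers

    def get_weights(self, batch: ToyBatch) -> list[float]:
        # Start from the first weigher and multiply in the others one by one.
        weights = self.sample_weighers[0].get_weights(batch)
        for weigher in self.sample_weighers[1:]:
            weights = [w * v for w, v in zip(weights, weigher.get_weights(batch))]
        return weights

    def summary_string(self) -> str:
        # Hypothetical format: join the component summaries.
        return "*".join(w.summary_string() for w in self.sample_weighers)


batch = ToyBatch(scf_iteration=[0, 3, 5], energy_label=[-1.0, None, -2.0])
product = ToyProductSampleWeigher(ToyMinSCF(2), ToyHasEnergyLabel())
print(product.get_weights(batch))  # [0.0, 0.0, 1.0]
print(product.summary_string())  # min_scf_2*has_energy_label
```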