dataset_statistics

Module to compute and store statistics for the AtomRef and DimensionWiseRescaling modules.

class DatasetStatistics(path: str | Path, create_store: bool = False)

Stores the statistics needed by AtomRef and DimensionWiseRescaling. See mldft.ml.compute_dataset_statistics for how to compute them.

The fields are computed per atom type and concatenated. They can be split using split_field_by_atom_type(), with a corresponding BasisInfo object.
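To illustrate the concatenated layout, the following sketch splits a flat per-basis-function array back into per-atom-type blocks. It is not split_field_by_atom_type() itself. The helper split_by_atom_type and the basis_dims mapping (atom type to number of basis functions, in concatenation order) are hypothetical stand-ins for what a BasisInfo object would provide, and the sizes are made up.

```python
import numpy as np


def split_by_atom_type(field: np.ndarray, basis_dims: dict[str, int]) -> dict[str, np.ndarray]:
    """Split a concatenated per-atom-type field into one array per atom type.

    Hypothetical stand-in for split_field_by_atom_type(); basis_dims maps each
    atom type to its number of basis functions, in concatenation order.
    """
    total = sum(basis_dims.values())
    if field.shape[0] != total:
        raise ValueError(f"expected {total} entries, got {field.shape[0]}")
    # Cumulative sizes give the end of each block; the last end is the array length.
    boundaries = np.cumsum(list(basis_dims.values()))[:-1]
    chunks = np.split(field, boundaries)
    return dict(zip(basis_dims.keys(), chunks))


# Example: H with 3 basis functions, C with 5 (made-up sizes).
mean = np.arange(8.0)
parts = split_by_atom_type(mean, {"H": 3, "C": 5})
print(parts["H"].tolist())  # [0.0, 1.0, 2.0]
print(parts["C"].tolist())  # [3.0, 4.0, 5.0, 6.0, 7.0]
```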

__init__(path: str | Path, create_store: bool = False) → None

Initialize the dataset statistics.

Parameters:
  • path – Path to the .zarr file.

  • create_store – Whether to create the store if it does not exist (if it exists, it will not be overwritten).

__repr__()

Returns a string representation of the dataset statistics, with all keys and shapes.

static from_dataset(path: str | Path, dataset: OFDataset | OFLoader, basis_info, **kwargs) → DatasetStatistics

Compute statistics from a dataset or its loader and save them to path.

Parameters:
  • path – Path to save the dataset statistics to.

  • dataset – The dataset to compute the statistics from, or its loader. If the dataset is given directly, a loader with batch size 128 and 8 workers is used.

  • basis_info – Basis info object for the dataset.

  • **kwargs – Additional keyword arguments to pass to the DatasetStatisticsFitter.

keys_recursive()

Get all keys in the zarr store recursively.
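A minimal sketch of recursive key collection, assuming the zarr groups are mirrored by nested Python dicts whose leaves are arrays and that keys are joined with '/'. Pairing the same traversal with each array's shape is one way to produce a listing like the one __repr__() returns. The names iter_leaves and describe are hypothetical.

```python
import numpy as np


def iter_leaves(group: dict, prefix: str = ""):
    """Yield (path, array) pairs for every leaf in a nested mapping."""
    for name, value in group.items():
        path = f"{prefix}/{name}" if prefix else name
        if isinstance(value, dict):
            yield from iter_leaves(value, path)  # descend into a sub-group
        else:
            yield path, value  # leaf: an array


def keys_recursive(group: dict) -> list[str]:
    """All '/'-joined leaf paths, in insertion order."""
    return [path for path, _ in iter_leaves(group)]


def describe(group: dict) -> str:
    """Repr-style listing of all keys and their shapes."""
    return "\n".join(f"{path}: {np.shape(arr)}" for path, arr in iter_leaves(group))


store = {
    "has_energy_label": {"coeffs": {"mean": np.zeros(8), "std": np.ones(8)}},
}
print(describe(store))
# has_energy_label/coeffs/mean: (8,)
# has_energy_label/coeffs/std: (8,)
```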

load_statistic(weigher_key: str, statistic_key: str) → ndarray

Load a single statistic from the zarr store.

Parameters:
  • weigher_key – The key of the weigher, e.g. ‘has_energy_label’.

  • statistic_key – The statistic to get, e.g. ‘coeffs/mean’.

Returns:

The requested statistic as a numpy array.
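The two keys can be read as a single path into the store. The following is a minimal sketch of that lookup. It assumes the statistic lives at weigher_key/statistic_key and models the zarr store as nested dicts. Both are assumptions, and the function is a stand-in, not the method itself.

```python
import numpy as np


def load_statistic(store: dict, weigher_key: str, statistic_key: str) -> np.ndarray:
    """Follow 'weigher_key/statistic_key' through a nested-dict stand-in for the store."""
    path = f"{weigher_key}/{statistic_key}"
    node = store
    for part in path.split("/"):
        if not isinstance(node, dict) or part not in node:
            raise KeyError(f"{path!r} not found (missing {part!r})")
        node = node[part]
    if isinstance(node, dict):
        raise KeyError(f"{path!r} is a group, not a statistic")
    return np.asarray(node)


store = {"has_energy_label": {"coeffs": {"mean": np.array([0.1, -0.2, 0.3])}}}
mean = load_statistic(store, "has_energy_label", "coeffs/mean")
```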

save_statistic(weigher_key: str, statistic_key: str, data: ndarray, overwrite: bool = False) → None

Save a statistic to the zarr store. If the statistic is already present, the stored data is compared with the new data: if they are identical, nothing is done; if they differ, either a warning or an error is raised, depending on the value of the overwrite flag.

Parameters:
  • weigher_key – The key of the weigher, e.g. ‘has_energy_label’.

  • statistic_key – The statistic to save, e.g. ‘coeffs/mean’.

  • data – The data to save.

  • overwrite – Whether to overwrite an existing statistic.
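The consistency check can be sketched as follows. This sketch assumes that overwrite=True turns a conflict into a warning followed by an overwrite, and that overwrite=False turns it into an error. The description above says only that the outcome depends on the flag. The store is modelled as a flat dict keyed by 'weigher_key/statistic_key', which is also an assumption.

```python
import warnings

import numpy as np


def save_statistic(store: dict, weigher_key: str, statistic_key: str,
                   data: np.ndarray, overwrite: bool = False) -> None:
    """Save data under 'weigher_key/statistic_key', checking any existing entry first."""
    key = f"{weigher_key}/{statistic_key}"
    existing = store.get(key)
    if existing is not None:
        if np.array_equal(existing, data):
            return  # identical data already stored: nothing to do
        if not overwrite:
            raise ValueError(f"Statistic {key!r} already exists with different data.")
        # Assumed behaviour: with overwrite=True, warn and then replace the data.
        warnings.warn(f"Overwriting statistic {key!r} with different data.")
    store[key] = np.array(data, copy=True)
```

Saving identical data twice is a no-op. Saving different data without overwrite=True raises, which protects previously computed statistics from being silently replaced.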

save_to(new_path: str | Path, overwrite: bool = False) → None

Save the statistics to a new path.

If a statistics store already exists at the new path, the statistics are added to it using save_statistic().

class DatasetStatisticsFitter(basis_info: BasisInfo, atom_ref_fit_sample_weighers: list[SampleWeigher], initial_guess_sample_weighers: list[SampleWeigher], with_bias: bool = True, min_atoms_per_type: int = 10, n_batches: int | None = None)

Computes and fits the statistics needed by AtomRef and DimensionWiseRescaling for a dataset.

__init__(basis_info: BasisInfo, atom_ref_fit_sample_weighers: list[SampleWeigher], initial_guess_sample_weighers: list[SampleWeigher], with_bias: bool = True, min_atoms_per_type: int = 10, n_batches: int | None = None)

Initialize the DatasetStatisticsFitter.

Parameters:
  • basis_info – Basis info object for the dataset.

  • atom_ref_fit_sample_weighers – Sample weighers to use for fitting the linear model for AtomRef.

  • initial_guess_sample_weighers – Sample weighers to use for initial guess statistics.

  • with_bias – Whether to include a T_global bias in the fit.

  • min_atoms_per_type – Minimum number of atoms per atom type in the dataset.

  • n_batches – Number of batches to use for fitting the statistics. If None, all batches are used. Useful for debugging.
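One common way to fit such a model is weighted least squares of each sample's target energy on its per-atom-type atom counts, with an optional constant column for the global bias. The sketch below assumes that formulation. The function fit_atom_ref, the count-matrix input, the way sample weights enter and the form of the min_atoms_per_type check are illustrative assumptions, not the fitter's actual code. Scaling each row by the square root of its weight makes ordinary least squares minimise the weighted squared error.

```python
from __future__ import annotations

import numpy as np


def fit_atom_ref(atom_counts: np.ndarray, energies: np.ndarray,
                 sample_weights: np.ndarray | None = None,
                 with_bias: bool = True, min_atoms_per_type: int = 10):
    """Fit energies ~ atom_counts @ per_type_energy (+ bias) by weighted least squares.

    atom_counts has shape (n_samples, n_types): number of atoms of each type per sample.
    """
    if np.any(atom_counts.sum(axis=0) < min_atoms_per_type):
        raise ValueError("some atom type has fewer than min_atoms_per_type atoms")
    X = atom_counts.astype(float)
    if with_bias:
        X = np.hstack([X, np.ones((X.shape[0], 1))])  # extra column for the global bias
    w = np.ones(len(energies)) if sample_weights is None else np.asarray(sample_weights, dtype=float)
    sqrt_w = np.sqrt(w)
    # Scaling rows by sqrt(w) turns weighted least squares into ordinary least squares.
    coef, *_ = np.linalg.lstsq(X * sqrt_w[:, None], energies * sqrt_w, rcond=None)
    if with_bias:
        return coef[:-1], coef[-1]
    return coef, 0.0


# Synthetic check with made-up per-type energies and bias.
rng = np.random.default_rng(0)
counts = rng.integers(0, 6, size=(50, 2))
true_per_type, true_bias = np.array([-0.5, -37.8]), 1.0
energies = counts @ true_per_type + true_bias
per_type, bias = fit_atom_ref(counts, energies)
```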

fit(path: str | Path, loader: OFLoader, device: str | DeviceObjType = 'auto') → DatasetStatistics

Calculate dataset statistics from a data loader and fit a linear model to the target energies.

Parameters:
  • path – Path to save the dataset statistics to.

  • loader – The data loader to fit the statistics to.

  • device – Device to use for computing the statistics. If “auto”, CUDA is used if available.

Returns:

The fitted dataset statistics.

Return type:

DatasetStatistics
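Statistics such as per-basis-function means and standard deviations can be accumulated one batch at a time, so the whole dataset never needs to be held in memory. Stopping after n_batches batches gives the debugging shortcut described for the fitter. The sketch below assumes each batch yields a (values, weights) pair of NumPy arrays; the function name and batch format are hypothetical. It keeps running sums of x and of x squared. This is simple but less numerically robust than Welford-style updates.

```python
from itertools import islice

import numpy as np


def streaming_mean_std(batches, n_batches=None):
    """Accumulate a weighted mean and standard deviation over an iterable of batches.

    Each batch is (values, weights) with values of shape (n, d) and weights of shape (n,).
    With n_batches=None, islice passes every batch through.
    """
    total_w = 0.0
    sum_x = None
    sum_x2 = None
    for values, weights in islice(batches, n_batches):
        w = weights[:, None]
        bx = (w * values).sum(axis=0)
        bx2 = (w * values**2).sum(axis=0)
        sum_x = bx if sum_x is None else sum_x + bx
        sum_x2 = bx2 if sum_x2 is None else sum_x2 + bx2
        total_w += float(weights.sum())
    if sum_x is None or total_w == 0:
        raise ValueError("no weighted data seen")
    mean = sum_x / total_w
    var = np.maximum(sum_x2 / total_w - mean**2, 0.0)  # clip small negatives from round-off
    return mean, np.sqrt(var)


rng = np.random.default_rng(0)
data = rng.normal(size=(100, 3))
batches = [(data[i:i + 10], np.ones(10)) for i in range(0, 100, 10)]
mean, std = streaming_mean_std(iter(batches))
```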

static parse_field_info(field_info: str | tuple[str, callable], sample: OFData) → tuple[str, Tensor]

Parse the field info to get the field name and data from the sample.

Parameters:
  • field_info – The field info to parse. Can be the name of the attribute of the sample, or a tuple with the field name and a callable to get the field data from the sample.

  • sample – The sample to get the data from.

Returns:

The field name and the data from the sample.

Return type:

tuple[str, torch.Tensor]

static parse_field_name(field_info: str | tuple[str, callable]) → str

Parse the field info to get the field name.

Parameters:

field_info – The field info to parse. Can be the name of the attribute of the sample, or a tuple with the field name and a callable to get the field data from the sample.

Returns:

The field name.

Return type:

str
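Both parsers accept either an attribute name or a (name, callable) pair. The following minimal sketch shows that dispatch. It uses types.SimpleNamespace as a stand-in for an OFData sample and plain lists in place of tensors. The field name "coeffs_squared" and its lambda are made-up examples.

```python
from types import SimpleNamespace
from typing import Any, Callable, Union

FieldInfo = Union[str, tuple[str, Callable[[Any], Any]]]


def parse_field_name(field_info: FieldInfo) -> str:
    """The name is the string itself, or the first element of the pair."""
    return field_info if isinstance(field_info, str) else field_info[0]


def parse_field_info(field_info: FieldInfo, sample: Any) -> tuple[str, Any]:
    """Return (name, data): an attribute lookup for a string, a call for a pair."""
    if isinstance(field_info, str):
        return field_info, getattr(sample, field_info)  # plain attribute of the sample
    name, getter = field_info
    return name, getter(sample)  # field computed from the sample


sample = SimpleNamespace(coeffs=[1.0, 2.0, 3.0])
print(parse_field_info("coeffs", sample))
# ('coeffs', [1.0, 2.0, 3.0])
print(parse_field_info(("coeffs_squared", lambda s: [c * c for c in s.coeffs]), sample))
# ('coeffs_squared', [1.0, 4.0, 9.0])
```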

sample_weights_to_atom_weights(sample: OFData, sample_mask: Tensor) → Tensor

Convert a sample mask to an atom mask.

Parameters:
  • sample – The sample to get the atom mask from.

  • sample_mask – The mask for the samples.

Returns:

The mask for the atoms corresponding to the samples.

Return type:

Tensor
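If the atoms of a batch are stored contiguously, sample by sample, the conversion amounts to repeating each sample's value once per atom. The following is a minimal sketch under that assumption. Here n_atoms_per_sample is a hypothetical stand-in for the per-sample atom counts the real method would read from sample. The same code works for 0/1 masks and for real-valued weights.

```python
import numpy as np


def sample_weights_to_atom_weights(sample_weights: np.ndarray,
                                   n_atoms_per_sample: np.ndarray) -> np.ndarray:
    """Repeat each sample's weight once for every atom belonging to that sample."""
    return np.repeat(sample_weights, n_atoms_per_sample)


atom_w = sample_weights_to_atom_weights(np.array([1.0, 0.0, 0.5]), np.array([2, 3, 1]))
print(atom_w.tolist())  # [1.0, 1.0, 0.0, 0.0, 0.0, 0.5]
```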

sample_weights_to_basis_function_weights(sample: OFData, sample_mask: Tensor) → Tensor

Convert a sample mask to a basis function mask.

Parameters:
  • sample – The sample to get the basis function mask from.

  • sample_mask – The mask for the samples.

Returns:

The mask for the basis functions corresponding to the samples.

Return type:

Tensor
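For basis functions, each atom's value is additionally repeated once per basis function of that atom, and that count depends on the atom type. The sketch assumes atoms and basis functions are stored contiguously, sample by sample. Its inputs are hypothetical. atom_types holds integer indices into basis_dim_per_type, which gives the number of basis functions per atom type (the kind of information a BasisInfo object holds).

```python
import numpy as np


def sample_weights_to_basis_function_weights(sample_weights: np.ndarray,
                                             n_atoms_per_sample: np.ndarray,
                                             atom_types: np.ndarray,
                                             basis_dim_per_type: np.ndarray) -> np.ndarray:
    """Expand per-sample weights to per-basis-function weights via per-atom weights."""
    atom_weights = np.repeat(sample_weights, n_atoms_per_sample)  # sample -> atom
    n_basis_per_atom = basis_dim_per_type[atom_types]  # number of basis functions of each atom
    return np.repeat(atom_weights, n_basis_per_atom)  # atom -> basis function


# Two samples (2 atoms and 1 atom); type 0 has 2 basis functions, type 1 has 3 (made up).
bf_w = sample_weights_to_basis_function_weights(
    np.array([1.0, 0.0]), np.array([2, 1]), np.array([0, 1, 0]), np.array([2, 3])
)
print(bf_w.tolist())  # [1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0]
```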