dataset_statistics
Module to compute and store statistics for the AtomRef and DimensionWiseRescaling modules.
- class DatasetStatistics(path: str | Path, create_store: bool = False)
Class to store statistics as needed by AtomRef and DimensionWiseRescaling. See also mldft.ml.compute_dataset_statistics for how to compute them. The fields are computed per atom type and concatenated. They can be split using split_field_by_atom_type(), with a corresponding BasisInfo object.
- __init__(path: str | Path, create_store: bool = False) -> None
Initialize the dataset statistics.
- Parameters:
path – Path to the .zarr file.
create_store – Whether to create the store if it does not exist (if it exists, it is not overwritten).
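To illustrate the per-atom-type layout described above, here is a minimal NumPy sketch. It assumes, hypothetically, that a field is the concatenation of one block per atom type, in the order of a list of atomic numbers, with block lengths given by per-type basis sizes. `split_by_type`, `atomic_numbers` and `basis_dims` are illustrative names, not the split_field_by_atom_type() API, and the basis sizes in the example are made up.

```python
import numpy as np


def split_by_type(
    field: np.ndarray, atomic_numbers: list[int], basis_dims: list[int]
) -> dict[int, np.ndarray]:
    """Split a concatenated per-atom-type field into one array per atom type.

    Hypothetical stand-in for split_field_by_atom_type(): assumes one block of
    length basis_dims[i] per atom type, concatenated in the order of atomic_numbers.
    """
    if field.shape[0] != sum(basis_dims):
        raise ValueError(f"expected {sum(basis_dims)} entries, got {field.shape[0]}")
    # Indices where one per-type block ends and the next begins.
    split_points = np.cumsum(basis_dims)[:-1]
    chunks = np.split(field, split_points)
    return dict(zip(atomic_numbers, chunks))


# Made-up sizes: H with 3 basis functions, C with 5, O with 4.
field = np.arange(12, dtype=float)
per_type = split_by_type(field, atomic_numbers=[1, 6, 8], basis_dims=[3, 5, 4])
for z, chunk in per_type.items():
    print(z, chunk)
```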
- __repr__()
Returns a string representation of the dataset statistics, with all keys and shapes.
- static from_dataset(path: str | Path, dataset: OFDataset | OFLoader, basis_info, **kwargs) -> DatasetStatistics
Compute statistics from a dataset or its loader.
- Parameters:
path – Path to save the dataset statistics to.
dataset – The dataset to compute the statistics from, or its loader. If the dataset is given directly, a loader with batch size 128 and 8 workers is used.
basis_info – Basis info object for the dataset.
**kwargs – Additional keyword arguments to pass to the DatasetStatisticsFitter.
- load_statistic(weigher_key: str, statistic_key: str) -> ndarray
Load a statistic from the zarr store.
- Parameters:
weigher_key – The key of the weigher, e.g. ‘has_energy_label’.
statistic_key – The statistic to get, e.g. ‘coeffs/mean’.
- Returns:
The requested statistic as a numpy array.
- save_statistic(weigher_key: str, statistic_key: str, data: ndarray, overwrite: bool = False) -> None
Save a statistic to the zarr store. If the statistic is already present, the stored data is compared with the new data: if they are identical, nothing is done; if they differ, either a warning or an error is raised, depending on the value of the overwrite flag.
- Parameters:
weigher_key – The key of the weigher, e.g. ‘has_energy_label’.
statistic_key – The statistic to save, e.g. ‘coeffs/mean’.
data – The data to save.
overwrite – Whether to overwrite an existing statistic.
- save_to(new_path: str | Path, overwrite: bool = False) -> None
Save the statistics to a new path.
If statistics already exist at the new path, the statistics are added to them using save_statistic().
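The load, save and merge semantics above can be sketched with an in-memory dictionary in place of the zarr store. This is a minimal sketch under assumptions: `InMemoryStatistics` is a hypothetical class, statistics are keyed by the pair (weigher_key, statistic_key), and the mapping of overwrite=False to an error and overwrite=True to a warning followed by replacement is assumed, since the description only says that one or the other is raised.

```python
import warnings

import numpy as np


class InMemoryStatistics:
    """Hypothetical dict-backed stand-in for the zarr store (not the mldft class)."""

    def __init__(self) -> None:
        # Statistics are addressed by the pair (weigher_key, statistic_key).
        self._data: dict[tuple[str, str], np.ndarray] = {}

    def load_statistic(self, weigher_key: str, statistic_key: str) -> np.ndarray:
        return self._data[(weigher_key, statistic_key)]

    def save_statistic(
        self, weigher_key: str, statistic_key: str, data: np.ndarray, overwrite: bool = False
    ) -> None:
        key = (weigher_key, statistic_key)
        data = np.asarray(data)
        if key in self._data:
            if np.array_equal(self._data[key], data):
                return  # same data already stored: nothing to do
            if not overwrite:
                # Assumed mapping: overwrite=False raises an error ...
                raise ValueError(f"statistic {key} already exists with different data")
            # ... and overwrite=True warns, then replaces the stored data.
            warnings.warn(f"overwriting statistic {key}")
        self._data[key] = data

    def save_to(self, target: "InMemoryStatistics", overwrite: bool = False) -> None:
        # Merge every statistic into the target through save_statistic(), so
        # entries already present in the target go through the same check.
        for (weigher_key, statistic_key), data in self._data.items():
            target.save_statistic(weigher_key, statistic_key, data, overwrite=overwrite)


stats = InMemoryStatistics()
stats.save_statistic("has_energy_label", "coeffs/mean", np.zeros(3))
stats.save_statistic("has_energy_label", "coeffs/mean", np.zeros(3))  # identical: no-op
merged = InMemoryStatistics()
stats.save_to(merged)
print(merged.load_statistic("has_energy_label", "coeffs/mean"))
```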
- class DatasetStatisticsFitter(basis_info: BasisInfo, atom_ref_fit_sample_weighers: list[SampleWeigher], initial_guess_sample_weighers: list[SampleWeigher], with_bias: bool = True, min_atoms_per_type: int = 10, n_batches: int | None = None)
Computes and fits the statistics needed for AtomRef and DimensionWiseRescaling for a dataset.
- __init__(basis_info: BasisInfo, atom_ref_fit_sample_weighers: list[SampleWeigher], initial_guess_sample_weighers: list[SampleWeigher], with_bias: bool = True, min_atoms_per_type: int = 10, n_batches: int | None = None)
Initialize the DatasetStatisticsFitter.
- Parameters:
basis_info – Basis info object for the dataset.
atom_ref_fit_sample_weighers – Sample weighers to use for fitting the linear model for AtomRef.
initial_guess_sample_weighers – Sample weighers to use for initial guess statistics.
with_bias – Whether to include a T_global bias in the fit.
min_atoms_per_type – Minimum number of atoms per atom type in the dataset.
n_batches – Number of batches to use for fitting the statistics. If None, all batches are used. Useful for debugging.
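The exact AtomRef model is not spelled out here, so the following is only a rough sketch of a weighted linear fit with an optional bias. It assumes, hypothetically, that energies are regressed on per-sample features (in the example, the number of atoms of each type), that the sample weighers reduce to one non-negative weight per sample (0 for samples without an energy label), and that with_bias appends a constant column whose coefficient plays the role of the global bias. `fit_linear_atom_ref` is an illustrative name; the real AtomRef features may differ, e.g. involve density coefficients.

```python
import numpy as np


def fit_linear_atom_ref(
    features: np.ndarray,
    energies: np.ndarray,
    sample_weights: np.ndarray,
    with_bias: bool = True,
) -> tuple[np.ndarray, float]:
    """Weighted least-squares fit of energies ~ features @ coef (+ bias)."""
    X = np.asarray(features, dtype=float)
    if with_bias:
        # Extra column of ones: its coefficient acts as a global bias.
        X = np.hstack([X, np.ones((X.shape[0], 1))])
    # Scaling rows by sqrt(weight) turns weighted least squares into ordinary
    # least squares; samples with weight 0 drop out of the fit entirely.
    sqrt_w = np.sqrt(np.asarray(sample_weights, dtype=float))
    solution, *_ = np.linalg.lstsq(X * sqrt_w[:, None], energies * sqrt_w, rcond=None)
    if with_bias:
        return solution[:-1], float(solution[-1])
    return solution, 0.0


# Columns: number of H and C atoms per molecule (made-up data).
counts = np.array([[2, 0], [4, 1], [2, 2], [6, 2], [1, 1]])
# Energies generated from H = -0.5, C = -38.0 and bias = 1.0; the last sample
# gets a nonsense energy but weight 0, so the fit ignores it.
energies = counts @ np.array([-0.5, -38.0]) + 1.0
energies[-1] = 999.0
weights = np.array([1.0, 1.0, 1.0, 1.0, 0.0])
coef, bias = fit_linear_atom_ref(counts, energies, weights, with_bias=True)
print(coef, bias)
```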
- fit(path: str | Path, loader: OFLoader, device: str | DeviceObjType = 'auto') -> DatasetStatistics
Calculate dataset statistics from a data loader and fit a linear model to the target energies.
- Parameters:
path – Path to save the dataset statistics to.
loader – The data loader to fit the statistics to.
device – Device to compute the statistics on. If “auto”, CUDA is used if available.
- Returns:
The fitted dataset statistics.
- Return type:
DatasetStatistics
- static parse_field_info(field_info: str | tuple[str, callable], sample: OFData) -> tuple[str, Tensor]
Parse the field info to get the field name and data from the sample.
- Parameters:
field_info – The field info to parse. Can be the name of the attribute of the sample, or a tuple with the field name and a callable to get the field data from the sample.
sample – The sample to get the data from.
- Returns:
The field name and the data from the sample.
- Return type:
tuple[str, torch.Tensor]
- static parse_field_name(field_info: str | tuple[str, callable]) -> str
Parse the field info to get the field name.
- Parameters:
field_info – The field info to parse. Can be the name of the attribute of the sample, or a tuple with the field name and a callable to get the field data from the sample.
- Returns:
The field name.
- Return type:
str
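Both static helpers accept either an attribute name or a (name, callable) pair. A minimal sketch of that dispatch follows; the hypothetical `ToySample` dataclass stands in for OFData, and the functions mirror the documented behavior rather than reproducing the library source.

```python
from dataclasses import dataclass
from typing import Any, Callable, Union

# A field is either an attribute name or a (field name, getter) pair.
FieldInfo = Union[str, tuple[str, Callable[[Any], Any]]]


def parse_field_name(field_info: FieldInfo) -> str:
    # A plain string is both the field name and the attribute to read.
    if isinstance(field_info, str):
        return field_info
    name, _getter = field_info
    return name


def parse_field_info(field_info: FieldInfo, sample: Any) -> tuple[str, Any]:
    if isinstance(field_info, str):
        return field_info, getattr(sample, field_info)
    # Otherwise compute the field data by calling the getter on the sample.
    name, getter = field_info
    return name, getter(sample)


@dataclass
class ToySample:  # hypothetical stand-in for OFData
    coeffs: list[float]


sample = ToySample(coeffs=[1.0, -2.0, 3.0])
print(parse_field_info("coeffs", sample))
print(parse_field_info(("abs_coeffs", lambda s: [abs(c) for c in s.coeffs]), sample))
```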
- sample_weights_to_atom_weights(sample: OFData, sample_mask: Tensor) -> Tensor
Convert a sample mask to an atom mask.
- Parameters:
sample – The sample to get the atom mask from.
sample_mask – The mask for the samples.
- Returns:
The mask for the atoms corresponding to the samples.
- Return type:
torch.Tensor
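Converting a per-sample mask to a per-atom one amounts to repeating each sample's value once per atom. The NumPy sketch below assumes, hypothetically, that the atoms of a batch are stored contiguously, sample by sample, and that the number of atoms per sample is known; `sample_to_atom_weights` and `n_atoms_per_sample` are illustrative names, not OFData attributes.

```python
import numpy as np


def sample_to_atom_weights(sample_weights: np.ndarray, n_atoms_per_sample: np.ndarray) -> np.ndarray:
    """Broadcast one weight per sample to every atom of that sample."""
    # Atoms are assumed to be stored sample by sample, so repeating each
    # sample's weight n_atoms times lines the result up with the atom axis.
    return np.repeat(sample_weights, n_atoms_per_sample)


sample_mask = np.array([1.0, 0.0, 1.0])  # e.g. only samples 0 and 2 have a label
n_atoms = np.array([2, 3, 1])
print(sample_to_atom_weights(sample_mask, n_atoms))  # [1. 1. 0. 0. 0. 1.]
```

With PyTorch tensors, `torch.repeat_interleave(sample_mask, n_atoms)` performs the same repetition.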
- sample_weights_to_basis_function_weights(sample: OFData, sample_mask: Tensor) -> Tensor
Convert a sample mask to a basis function mask.
- Parameters:
sample – The sample to get the basis function mask from.
sample_mask – The mask for the samples.
- Returns:
The mask for the basis functions corresponding to the samples.
- Return type:
torch.Tensor
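The basis-function variant adds one step: each atom contributes as many basis functions as its atom type has. A minimal sketch, assuming contiguous per-sample storage of atoms and basis functions, plus hypothetical per-atom type indices and per-type basis sizes (information a BasisInfo object would presumably provide), could look like this:

```python
import numpy as np


def sample_to_basis_function_weights(
    sample_weights: np.ndarray,
    n_atoms_per_sample: np.ndarray,
    atom_type_index: np.ndarray,
    basis_dim_per_type: np.ndarray,
) -> np.ndarray:
    """Broadcast one weight per sample to every basis function of that sample."""
    # Step 1: one weight per atom, repeating each sample's weight per atom.
    atom_weights = np.repeat(sample_weights, n_atoms_per_sample)
    # Step 2: each atom owns as many basis functions as its atom type provides.
    n_basis_per_atom = basis_dim_per_type[atom_type_index]
    return np.repeat(atom_weights, n_basis_per_atom)


# Two samples: the first has atoms of type 0 and 1, the second one atom of type 0.
weights = sample_to_basis_function_weights(
    sample_weights=np.array([1.0, 0.0]),
    n_atoms_per_sample=np.array([2, 1]),
    atom_type_index=np.array([0, 1, 0]),
    basis_dim_per_type=np.array([2, 3]),  # made-up basis sizes per type
)
print(weights)
```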