Scaler model

class metatrain.utils.scaler.scaler.Scaler(hypers: Dict, dataset_info: DatasetInfo)

Bases: Module

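Scales model outputs by factors fitted on the training targets. The scales are stored per target and can depend on the atomic type (see fixed_weights in train_model()).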

Parameters:
  • hypers (Dict) – Hyperparameters for the scaler. Should be an empty dictionary.

  • dataset_info (DatasetInfo) – Information about the dataset used to initialize the scaler.


outputs: Dict[str, ModelOutput]
train_model(datasets: List[Dataset | Subset], additive_models: List[Module], batch_size: int, is_distributed: bool, fixed_weights: Dict[str, float | Dict[int, float]] | None = None) → None

Fit the scales of the targets on the given training datasets.

Parameters:
  • datasets (List[Dataset | Subset]) – List of datasets to use for training the scaler.

  • additive_models (List[Module]) – List of additive models whose predictions are subtracted from the targets before accumulating the quantities needed to fit the scales.

  • batch_size (int) – Batch size to use for the dataloader.

  • is_distributed (bool) – Whether to use distributed sampling or not.

  • fixed_weights (Dict[str, float | Dict[int, float]] | None) – Optional dict of fixed weights to apply to the scales of each target. The keys of the dict are the target names, and the values are either a single float value to be applied to all atomic types, or a dict mapping atomic type (int) to weight (float). If not provided, all scales will be computed based on the accumulated quantities.

Return type:

None
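The fitting step above can be illustrated with a simplified sketch. It is not metatrain's implementation: it assumes a single per-atom target given as (atomic_type, value) pairs, treats each additive model as a hypothetical callable returning its prediction for an atom of a given type, takes each scale to be the root-mean-square of the residuals, and reads the fixed_weights entry of one target as overriding the fitted scales. fit_scales, Sample and AdditiveModel are hypothetical names.

```python
import math
from typing import Callable, Dict, List, Optional, Tuple, Union

# Hypothetical simplified layout: one (atomic_type, target_value) pair per atom.
Sample = Tuple[int, float]
# Hypothetical additive model: returns its prediction for an atom of a given type.
AdditiveModel = Callable[[int], float]


def fit_scales(
    samples: List[Sample],
    additive_models: List[AdditiveModel],
    fixed_weights: Optional[Union[float, Dict[int, float]]] = None,
) -> Dict[int, float]:
    """Fit one scale per atomic type as the RMS of the residual targets."""
    sum_sq: Dict[int, float] = {}
    counts: Dict[int, int] = {}
    for atomic_type, value in samples:
        # Subtract the predictions of all additive models from the target.
        residual = value - sum(model(atomic_type) for model in additive_models)
        # Accumulate only what is needed later: sum of squares and count.
        sum_sq[atomic_type] = sum_sq.get(atomic_type, 0.0) + residual**2
        counts[atomic_type] = counts.get(atomic_type, 0) + 1

    scales = {t: math.sqrt(sum_sq[t] / counts[t]) for t in sum_sq}

    if isinstance(fixed_weights, (int, float)):
        # A single number fixes the scale of every atomic type.
        scales = {t: float(fixed_weights) for t in scales}
    elif isinstance(fixed_weights, dict):
        # A dict fixes only the atomic types it lists; the others stay fitted.
        scales.update({t: float(w) for t, w in fixed_weights.items()})
    return scales


if __name__ == "__main__":
    # Baseline additive model: a per-type offset (3.0 for type 1, 12.0 otherwise).
    def baseline(atomic_type: int) -> float:
        return 3.0 if atomic_type == 1 else 12.0

    samples = [(1, 2.0), (1, 4.0), (8, 10.0), (8, 14.0)]
    print(fit_scales(samples, [baseline]))                          # {1: 1.0, 8: 2.0}
    print(fit_scales(samples, [baseline], fixed_weights={8: 0.5}))  # {1: 1.0, 8: 0.5}
    print(fit_scales(samples, [baseline], fixed_weights=3.0))       # {1: 3.0, 8: 3.0}
```

Keeping only running sums of squares and counts is one way to let such statistics be accumulated batch by batch over a dataloader instead of holding all targets in memory.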

restart(dataset_info: DatasetInfo) → Scaler

Restart the model with new dataset information.

Parameters:

dataset_info (DatasetInfo) – New dataset information to be used.

Returns:

The restarted Scaler.

Return type:

Scaler

forward(systems: List[System], outputs: Dict[str, TensorMap], remove: bool = False) → Dict[str, TensorMap]

Scales the outputs based on the stored standard deviations.

Parameters:
  • systems (List[System]) – List of systems for which the outputs were computed.

  • outputs (Dict[str, TensorMap]) – Dictionary containing the output TensorMaps.

  • remove (bool) – If True, removes the scaling (i.e., divides by the scales). If False, applies the scaling (i.e., multiplies by the scales).

Returns:

A dictionary with the scaled outputs.

Raises:

ValueError – If no scales have been computed, or if outputs contains keys that are not supported.

Return type:

Dict[str, TensorMap]
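The scaling rule of forward() can be illustrated with a minimal sketch. It assumes one scalar scale per output and uses plain lists of floats in place of TensorMaps; apply_scales is a hypothetical helper, not part of metatrain. It also ignores systems, which a per-atomic-type variant would likely need to look up the type of each atom.

```python
from typing import Dict, List, Optional


def apply_scales(
    outputs: Dict[str, List[float]],
    scales: Optional[Dict[str, float]],
    remove: bool = False,
) -> Dict[str, List[float]]:
    """Multiply each output by its scale, or divide by it when remove=True."""
    if not scales:
        raise ValueError("no scales have been computed")
    unsupported = sorted(set(outputs) - set(scales))
    if unsupported:
        raise ValueError(f"unsupported outputs: {unsupported}")

    scaled: Dict[str, List[float]] = {}
    for name, values in outputs.items():
        factor = scales[name]
        # remove=True undoes the scaling; remove=False applies it.
        scaled[name] = [v / factor if remove else v * factor for v in values]
    return scaled


if __name__ == "__main__":
    scales = {"energy": 2.0}
    raw = {"energy": [0.5, -1.0]}
    scaled = apply_scales(raw, scales)
    print(scaled)                                      # {'energy': [1.0, -2.0]}
    print(apply_scales(scaled, scales, remove=True))  # {'energy': [0.5, -1.0]}
```

Because remove=True divides by exactly the factor that remove=False multiplies by, the two calls invert each other, as the round trip in the usage lines shows.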

supported_outputs() → Dict[str, ModelOutput]

Return type:

Dict[str, ModelOutput]

scales_to(device: device, dtype: dtype) → None

Parameters:
  • device (device) – Device to move the scales to.

  • dtype (dtype) – Data type to convert the scales to.

Return type:

None

sync_tensor_maps() → None

Return type:

None