Scaler model
- class metatrain.utils.scaler.scaler.Scaler(hypers: Dict, dataset_info: DatasetInfo)
Bases: Module
Scaler model that rescales model outputs using per-target scales (standard deviations) fitted on the training data.
- Parameters:
hypers (Dict) – Hyperparameters for the scaler. Should be an empty dictionary.
dataset_info (DatasetInfo) – Information about the dataset used to initialize the scaler.
- outputs: Dict[str, ModelOutput]
- train_model(datasets: List[Dataset | Subset], additive_models: List[Module], batch_size: int, is_distributed: bool, fixed_weights: Dict[str, float | Dict[int, float]] | None = None) → None
Fit the scales of each target from the given datasets.
- Parameters:
datasets (List[Dataset | Subset]) – List of datasets to use for training the scaler.
additive_models (List[Module]) – List of additive models to remove from the targets before accumulating the quantities needed for fitting the scales.
batch_size (int) – Batch size to use for the dataloader.
is_distributed (bool) – Whether to use distributed sampling or not.
fixed_weights (Dict[str, float | Dict[int, float]] | None) – Optional dict of fixed weights to apply to the scales of each target. The keys of the dict are the target names, and the values are either a single float value to be applied to all atomic types, or a dict mapping atomic type (int) to weight (float). If not provided, all scales will be computed based on the accumulated quantities.
- Return type:
None
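The fitting step can be illustrated with a small plain-Python sketch. This is a hypothetical illustration, not metatrain's implementation: `fit_scales` and its argument layout are made-up names, each target is assumed to be a list of per-atom `(atomic_type, value)` samples, additive models are assumed to be plain callables returning a baseline per target and atomic type, and the scale is taken as the root mean square of the residual once the additive contributions are removed (a standard deviation around that baseline). The sketch also assumes that a fixed weight replaces the fitted scale, which is one reading of the `fixed_weights` description; `batch_size` and `is_distributed` are not modeled.

```python
import math
from collections import defaultdict
from typing import Callable, Dict, List, Optional, Tuple, Union

# Hypothetical data layout for this sketch only: per-atom samples
# given as (atomic_type, value) pairs for each target.
Samples = List[Tuple[int, float]]
FixedWeight = Union[float, Dict[int, float]]


def fit_scales(
    targets: Dict[str, Samples],
    additive_models: List[Callable[[str, int], float]],
    fixed_weights: Optional[Dict[str, FixedWeight]] = None,
) -> Dict[str, Dict[int, float]]:
    """Return one scale per target and atomic type (illustrative only)."""
    fixed_weights = fixed_weights or {}
    scales: Dict[str, Dict[int, float]] = {}
    for name, samples in targets.items():
        # Accumulate the sum of squared residuals and the sample count per atomic type.
        sum_sq: Dict[int, float] = defaultdict(float)
        count: Dict[int, int] = defaultdict(int)
        for atomic_type, value in samples:
            # Remove the additive-model contributions before accumulating.
            baseline = sum(model(name, atomic_type) for model in additive_models)
            residual = value - baseline
            sum_sq[atomic_type] += residual**2
            count[atomic_type] += 1
        # Root mean square of the residuals is used as the fitted scale.
        scales[name] = {t: math.sqrt(sum_sq[t] / count[t]) for t in count}
        # Assumption: a fixed weight replaces the fitted scale, either for
        # every atomic type (single float) or only for the listed types (dict).
        fixed = fixed_weights.get(name)
        if isinstance(fixed, dict):
            scales[name].update(fixed)
        elif fixed is not None:
            scales[name] = {t: float(fixed) for t in scales[name]}
    return scales
```

For example, with a constant baseline of 1.0, the samples `[(1, 2.0), (1, 4.0), (8, 3.0)]` leave residuals 1.0 and 3.0 for type 1 and 2.0 for type 8, so the fitted scales are sqrt(5) and 2.0.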
- restart(dataset_info: DatasetInfo) → Scaler
Restart the model with a new dataset info.
- Parameters:
dataset_info (DatasetInfo) – New dataset information to be used.
- Returns:
The restarted Scaler.
- Return type:
Scaler
- forward(systems: List[System], outputs: Dict[str, TensorMap], remove: bool = False) → Dict[str, TensorMap]
Scales the outputs based on the stored standard deviations.
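- Parameters:
systems (List[System]) – Systems for which the outputs were computed.
outputs (Dict[str, TensorMap]) – Dictionary of outputs to scale, keyed by output name.
remove (bool) – If True, remove the scaling from the outputs instead of applying it.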
- Returns:
A dictionary with the scaled outputs.
- Raises:
ValueError – If no scales have been computed, or if outputs contains unsupported keys.
- Return type:
Dict[str, TensorMap]
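The behavior of `forward` can be pictured with the following minimal sketch. It is a hypothetical stand-in, not the metatrain code: `apply_scales` is a made-up name, each output is a plain list of floats instead of a metatensor `TensorMap`, and each target has a single scalar scale. It assumes that applying a scale means multiplying by it and that `remove=True` undoes this by dividing, which is an interpretation of the parameter name. The two error checks mirror the `ValueError` conditions documented above.

```python
from typing import Dict, List


def apply_scales(
    outputs: Dict[str, List[float]],
    scales: Dict[str, float],
    remove: bool = False,
) -> Dict[str, List[float]]:
    """Scale (or unscale) each output by its target's scale (illustrative only)."""
    # No fitted scales: nothing to apply.
    if not scales:
        raise ValueError("no scales have been computed")
    # Every requested output must have a known scale.
    unsupported = set(outputs) - set(scales)
    if unsupported:
        raise ValueError(f"unsupported outputs: {sorted(unsupported)}")
    scaled: Dict[str, List[float]] = {}
    for name, values in outputs.items():
        factor = scales[name]
        # Assumption: remove=True divides by the scale, otherwise multiply.
        scaled[name] = [v / factor if remove else v * factor for v in values]
    return scaled
```

Under these assumptions, calling `apply_scales` with `remove=True` on an already scaled output recovers the original values, so the same object can be used both to scale predictions and to bring targets back to the unscaled units.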
- supported_outputs() → Dict[str, ModelOutput]
Return the outputs supported by the scaler.
- Return type:
Dict[str, ModelOutput]