# Source code for metatrain.utils.scaler.remove
from typing import Callable, Dict, List, Tuple
import torch
from metatensor.torch import TensorMap
from metatomic.torch import System
from .scaler import Scaler
# [docs]
def remove_scale(
    systems: List[System],
    targets: Dict[str, TensorMap],
    scaler: torch.nn.Module,
) -> Dict[str, TensorMap]:
    """
    Run ``scaler`` on the targets in removal mode.

    This is a thin wrapper: it calls ``scaler(systems, targets, remove=True)``
    and returns whatever the scaler returns. The numerical effect of
    ``remove=True`` is defined by the scaler itself; presumably it brings the
    targets to a standard deviation of one (as an earlier version of this
    docstring stated) -- confirm against the scaler implementation.

    :param systems: List of systems corresponding to the targets.
    :param targets: Dictionary containing the targets to be processed.
    :param scaler: The scaler module. It must be callable as
        ``scaler(systems, targets, remove=True)``.
    :return: The targets as returned by the scaler.
    """
    return scaler(systems, targets, remove=True)
# [docs]
def get_remove_scale_transform(scaler: Scaler) -> Callable:
    """
    Build a transform that removes the scaling from the targets.

    The returned function closes over ``scaler`` and forwards the targets to
    :func:`remove_scale`; the systems and the extra data are passed through
    as they were received.

    :param scaler: The scaler used to scale the targets.
    :return: A function that removes the scaling from the targets.
    """

    def transform(
        systems: List[System],
        targets: Dict[str, TensorMap],
        extra: Dict[str, TensorMap],
    ) -> Tuple[List[System], Dict[str, TensorMap], Dict[str, TensorMap]]:
        """
        Apply :func:`remove_scale` to ``targets`` with the captured scaler.

        :param systems: List of systems.
        :param targets: Dictionary containing the targets corresponding to the
            systems.
        :param extra: Dictionary containing any extra data.
        :return: The systems, updated targets and extra data.
        """
        # Only the targets are replaced; ``systems`` and ``extra`` are handed
        # back as the same objects that came in.
        new_targets = remove_scale(systems, targets, scaler)
        return systems, new_targets, extra

    return transform