Source code for metatrain.utils.transfer
from typing import Dict, List, Optional, Tuple
import torch
from metatensor.torch import TensorMap
from metatomic.torch import System
from . import torch_jit_script_unless_coverage
@torch_jit_script_unless_coverage
def batch_to(
    systems: List[System],
    targets: Dict[str, TensorMap],
    extra_data: Optional[Dict[str, TensorMap]] = None,
    dtype: Optional[torch.dtype] = None,
    device: Optional[torch.device] = None,
) -> Tuple[List[System], Dict[str, TensorMap], Optional[Dict[str, TensorMap]]]:
"""
Changes the systems and targets to the specified floating point data type.
:param systems: List of systems.
:param targets: Dictionary of targets.
:param extra_data: Optional dictionary of extra data.
:param dtype: Desired floating point data type.
:param device: Device to move the data to.
:return: The systems, targets, and extra data with moved to the specified device
and with the desired data type.
"""
    # non-blocking transfers are only used for CUDA devices; in other cases
    # they can cause bugs
    non_blocking = (device.type == "cuda") if (device is not None) else False
    systems = [
        system.to(dtype=dtype, device=device, non_blocking=non_blocking)
        for system in systems
    ]
    targets = {
        key: value.to(dtype=dtype, device=device, non_blocking=non_blocking)
        for key, value in targets.items()
    }
    if extra_data is not None:
        # torch.dtype is represented as an int inside TorchScript, hence the
        # List[Optional[int]] annotation
        new_dtypes: List[Optional[int]] = []
        for key in extra_data.keys():
            if key.endswith("_mask"):  # masks should always be boolean
                new_dtypes.append(torch.bool)
            else:
                new_dtypes.append(dtype)
        extra_data = {
            key: value.to(dtype=_dtype, device=device, non_blocking=non_blocking)
            for (key, value), _dtype in zip(
                extra_data.items(), new_dtypes, strict=True
            )
        }
    return systems, targets, extra_data
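

# --------------------------------------------------------------------------
# Usage sketch (not part of the original module): a minimal, hypothetical
# example of calling ``batch_to`` on a single toy system and a toy "energy"
# target. Real metatrain targets carry more metadata; the atom types, shapes,
# and label names below are illustrative assumptions only.
if __name__ == "__main__":
    from metatensor.torch import Labels, TensorBlock

    # one water-like system with three atoms, non-periodic (zero cell)
    system = System(
        types=torch.tensor([8, 1, 1], dtype=torch.int32),
        positions=torch.zeros(3, 3, dtype=torch.float64),
        cell=torch.zeros((3, 3), dtype=torch.float64),
        pbc=torch.tensor([False, False, False]),
    )

    # a single-block TensorMap holding one scalar "energy" value
    energy = TensorMap(
        keys=Labels.single(),
        blocks=[
            TensorBlock(
                values=torch.tensor([[1.0]], dtype=torch.float64),
                samples=Labels.range("system", 1),
                components=[],
                properties=Labels.range("energy", 1),
            )
        ],
    )

    if torch.cuda.is_available():
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")

    systems, targets, _ = batch_to(
        [system], {"energy": energy}, dtype=torch.float32, device=device
    )
    print(systems[0].positions.dtype, targets["energy"].block().values.device)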