Source code for metatrain.utils.transfer

from typing import Dict, List, Optional, Tuple

import torch
from metatensor.torch import TensorMap
from metatomic.torch import System

from . import torch_jit_script_unless_coverage


@torch_jit_script_unless_coverage
def batch_to(
    systems: List[System],
    targets: Dict[str, TensorMap],
    extra_data: Optional[Dict[str, TensorMap]] = None,
    dtype: Optional[torch.dtype] = None,
    device: Optional[torch.device] = None,
) -> Tuple[List[System], Dict[str, TensorMap], Optional[Dict[str, TensorMap]]]:
    """
    Convert a batch to the requested floating point data type and device.

    Every system and every target is converted with its ``to`` method. Entries
    of ``extra_data`` are converted the same way, except that entries whose key
    ends with ``"_mask"`` are always converted to ``torch.bool``.

    :param systems: List of systems.
    :param targets: Dictionary of targets.
    :param extra_data: Optional dictionary of extra data.
    :param dtype: Desired floating point data type.
    :param device: Device to move the data to.
    :return: The systems, targets, and extra data, moved to the specified
        device and converted to the desired data type. The extra data is
        ``None`` when no extra data was given.
    """
    # non-blocking transfers can cause bugs in other cases
    non_blocking = False
    if device is not None:
        non_blocking = device.type == "cuda"

    moved_systems: List[System] = []
    for system in systems:
        moved_systems.append(
            system.to(dtype=dtype, device=device, non_blocking=non_blocking)
        )

    moved_targets: Dict[str, TensorMap] = {}
    for target_name, target in targets.items():
        moved_targets[target_name] = target.to(
            dtype=dtype, device=device, non_blocking=non_blocking
        )

    moved_extra: Optional[Dict[str, TensorMap]] = None
    if extra_data is not None:
        converted: Dict[str, TensorMap] = {}
        for extra_name, extra in extra_data.items():
            # Annotated as ``Optional[int]`` like the original dtype list;
            # presumably required because this function may be compiled with
            # TorchScript, where dtypes are ints -- confirm before changing.
            entry_dtype: Optional[int] = dtype
            if extra_name.endswith("_mask"):
                # masks should always be boolean
                entry_dtype = torch.bool
            converted[extra_name] = extra.to(
                dtype=entry_dtype, device=device, non_blocking=non_blocking
            )
        moved_extra = converted

    return moved_systems, moved_targets, moved_extra