# Source code for metatrain.utils.data.writers.writers

from abc import ABC, abstractmethod
from pathlib import Path
from typing import Dict, List, Optional, Union

import metatensor.torch as mts
import torch
from metatensor.torch import Labels, TensorBlock, TensorMap
from metatomic.torch import ModelCapabilities, System


class Writer(ABC):
    """
    Abstract base class for writers of model predictions.

    Subclasses must implement both :py:meth:`write`, which receives a list of
    systems together with their predictions, and :py:meth:`finish`, which is
    called after all calls to :py:meth:`write`.

    :param filename: Destination to write to; stored as-is in
        ``self.filename`` (either a ``str`` or a :py:class:`pathlib.Path`).
    :param capabilities: Optional model capabilities, stored in
        ``self.capabilities`` for use by subclasses.
    :param append: Optional flag stored in ``self.append``; how it is
        interpreted is left to each subclass.
    """

    def __init__(
        self,
        filename: Union[str, Path],
        capabilities: Optional[ModelCapabilities] = None,
        append: Optional[bool] = None,
    ):
        self.filename = filename
        self.capabilities = capabilities
        self.append = append

    @abstractmethod
    def write(self, systems: List[System], predictions: Dict[str, TensorMap]) -> None:
        """
        Write a list of systems and their predictions.

        :param systems: List of systems to write.
        :param predictions: Dictionary of TensorMaps with predictions for the
            systems.
        """
        ...

    @abstractmethod
    def finish(self) -> None:
        """
        Finalize the output after all calls to :py:meth:`write`.

        This method is abstract: every subclass must implement it, even if the
        implementation does nothing.
        """
        ...
def _shift_system_labels(labels: Labels, offset: int) -> Labels:
    """
    Return ``labels`` with ``offset`` added to its ``"system"`` column.

    Labels without a ``"system"`` dimension are returned unchanged, and so is
    everything when ``offset`` is zero. Other columns are never modified.

    :param labels: Labels to shift.
    :param offset: Value added to every entry of the ``"system"`` column.
    :return: The shifted Labels (or ``labels`` itself if nothing changes).
    """
    names = labels.names
    if offset == 0 or "system" not in names:
        return labels
    values = labels.values.clone()
    values[:, names.index("system")] += offset
    # adding the same constant to a whole column keeps the entries unique
    return Labels(names, values, assume_unique=True)


def _shift_block_system(block: TensorBlock, offset: int) -> TensorBlock:
    """
    Copy ``block`` and its gradients, shifting the ``"system"`` sample column.

    :param block: Block to copy. Values, components and properties are reused
        without copying.
    :param offset: Value added to the ``"system"`` column of the samples.
    :return: A new TensorBlock with shifted samples and shifted gradients.
    """
    new_block = TensorBlock(
        samples=_shift_system_labels(block.samples, offset),
        components=block.components,
        properties=block.properties,
        values=block.values,
    )
    for gradient_name, gradient_block in block.gradients():
        # The first gradient sample dimension ("sample") indexes rows of the
        # parent block and must not be shifted; only a "system" column (if the
        # gradient samples have one) is offset, to stay consistent with the
        # parent block's samples.
        new_block.add_gradient(
            gradient_name,
            TensorBlock(
                samples=_shift_system_labels(gradient_block.samples, offset),
                components=gradient_block.components,
                properties=gradient_block.properties,
                values=gradient_block.values,
            ),
        )
    return new_block


def _split_tensormaps(
    systems: List[System],
    batch_predictions: Dict[str, TensorMap],
    istart_system: Optional[int] = 0,
) -> List[Dict[str, TensorMap]]:
    """
    Split batched predictions into one dictionary of TensorMaps per system.

    Each TensorMap in ``batch_predictions`` is split along its ``"system"``
    sample dimension. In the resulting TensorMaps, the ``"system"`` column of
    the samples (and of any gradient samples that have one) is shifted by
    ``istart_system``.

    :param systems: List of systems in the batch. Only its length is used.
    :param batch_predictions: Dictionary of TensorMaps with batch predictions.
        Their samples must contain a ``"system"`` dimension.
    :param istart_system: Offset added to the ``"system"`` indices (the index
        of the first system of the batch). ``None`` is treated as ``0``.
    :return: List with one entry per system; each entry maps the prediction
        names to the TensorMap restricted to that system.
    """
    offset = 0 if istart_system is None else int(istart_system)
    n_systems = len(systems)
    if n_systems == 0:
        return []
    if len(batch_predictions) == 0:
        # nothing to split, but still return one (empty) entry per system
        return [{} for _ in range(n_systems)]

    # create the selection Labels on the same device as the predictions
    device = next(iter(batch_predictions.values()))[0].values.device
    split_selection = [
        Labels("system", torch.tensor([[i]], device=device))
        for i in range(n_systems)
    ]
    batch_predictions_split = {
        name: mts.split(tensormap, "samples", split_selection)
        for name, tensormap in batch_predictions.items()
    }

    out_tensormaps: List[Dict[str, TensorMap]] = []
    for i in range(n_systems):
        per_system: Dict[str, TensorMap] = {}
        for name, split_tensormaps in batch_predictions_split.items():
            tensormap = split_tensormaps[i]
            per_system[name] = TensorMap(
                keys=tensormap.keys,
                blocks=[
                    _shift_block_system(block, offset)
                    for block in tensormap.blocks()
                ],
            )
        out_tensormaps.append(per_system)
    return out_tensormaps