Source code for metatrain.utils.data.writers.diskdataset
import zipfile
from pathlib import Path
from typing import Dict, List, Literal, Optional, Union
import metatomic.torch as mta
import numpy as np
import torch
from metatensor.torch import TensorMap
from metatomic.torch import ModelCapabilities, System
from .writers import Writer, _split_tensormaps
[docs]
class DiskDatasetWriter(Writer):
"""
Write systems and predictions to a zip file, each system in a separate folder inside
the zip.
:param path: Path to the output zip file.
:param capabilities: Model capabilities.
:param append: If True, open the zip file in append mode.
"""
def __init__(
self,
path: Union[str, Path],
capabilities: Optional[
ModelCapabilities
] = None, # unused, but matches base signature
append: Optional[bool] = False, # if True, open zip in append mode
):
super().__init__(filename=path, capabilities=capabilities, append=append)
mode: Literal["w", "a"] = "a" if append else "w"
self.zip_file = zipfile.ZipFile(path, mode)
self.index = 0
[docs]
def write(self, systems: List[System], predictions: Dict[str, TensorMap]) -> None:
"""
Write a single (system, predictions) into the zip under
a new folder "<index>/".
:param systems: List of systems to write.
:param predictions: Dictionary of TensorMaps with predictions for the systems.
"""
if len(systems) == 1:
# Avoid reindexing samples
split_predictions = [predictions]
else:
split_predictions = _split_tensormaps(
systems, predictions, istart_system=self.index
)
for system, preds in zip(systems, split_predictions, strict=True):
# system
with self.zip_file.open(f"{self.index}/system.mta", "w") as f:
mta.save(f, system.to("cpu").to(torch.float64))
# each target
for target_name, tensor_map in preds.items():
with self.zip_file.open(f"{self.index}/{target_name}.mts", "w") as f:
buf = tensor_map.to("cpu").to(torch.float64)
# metatensor.torch.save_buffer returns a torch.Tensor buffer
buffer = buf.save_buffer()
np.save(f, buffer.numpy())
self.index += 1
[docs]
def finish(self) -> None:
"""
Close the zip file.
"""
self.zip_file.close()