"""Source code for metatrain.utils.data.writers.metatensor."""
from pathlib import Path
from typing import Dict, List, Optional, Union
import metatensor.torch as mts
import torch
from metatensor.torch import Labels, TensorBlock, TensorMap
from metatomic.torch import ModelCapabilities, System
from .writers import Writer
class MetatensorWriter(Writer):
    """
    Write systems and predictions to Metatensor files (.mts).

    Predictions are accumulated in memory by :meth:`write` and only flushed to
    disk by :meth:`finish`, so the per-batch sample labels can be renumbered
    into one consistent set before saving.

    :param filename: Base filename for the output files. Each target will be saved
        in a separate file with the target name appended.
    :param capabilities: Model capabilities.
    :param append: Whether to append to existing files, unused here but kept for
        compatibility with the base class.
    """

    def __init__(
        self,
        filename: Union[str, Path],
        capabilities: Optional[ModelCapabilities] = None,
        append: Optional[bool] = False,  # unused, but matches base signature
    ) -> None:
        super().__init__(filename, capabilities, append)
        # Buffers filled by write() and consumed by finish().
        self._systems: List[System] = []
        self._preds: List[Dict[str, TensorMap]] = []

    def write(self, systems: List[System], predictions: Dict[str, TensorMap]) -> None:
        """
        Accumulate systems and predictions to write them all at once in ``finish``.

        :param systems: List of systems to write.
        :param predictions: Dictionary of TensorMaps with predictions for the systems.
        """
        # just accumulate; nothing is written to disk here
        self._systems.extend(systems)
        self._preds.append(predictions)

    def finish(self) -> None:
        """
        Write all accumulated systems and predictions to Metatensor files.

        One ``<base>_<target>.mts`` file is written per target, with values
        moved to CPU and converted to float64.
        """
        # concatenate per-sample TensorMaps into full ones
        predictions = _concatenate_tensormaps(self._preds)
        # Strip only the extension, keeping any parent directory of the
        # requested filename (``.stem`` alone would drop the directory and
        # write the files into the current working directory).
        filename_base = Path(self.filename).with_suffix("")
        for prediction_name, prediction_tmap in predictions.items():
            mts.save(
                f"{filename_base}_{prediction_name}.mts",
                prediction_tmap.to("cpu").to(torch.float64),
            )
def _concatenate_tensormaps(
    tensormap_dict_list: List[Dict[str, TensorMap]],
) -> Dict[str, TensorMap]:
    """
    Concatenate per-batch prediction dictionaries into one TensorMap per target.

    Concatenating TensorMaps is tricky, because the model does not know the
    "number" of the system it is predicting. For example, if a model predicts
    3 batches of 4 atoms each, the system labels will be [0, 1, 2, 3],
    [0, 1, 2, 3], [0, 1, 2, 3] for the three batches, respectively. Due
    to this, the join operation would not achieve the desired result
    ([0, 1, 2, ..., 11, 12]). Here, we fix this by renaming the system labels.

    :param tensormap_dict_list: One ``{target_name: TensorMap}`` dict per batch.
    :return: ``{target_name: TensorMap}`` with samples joined across batches.
    """
    if not tensormap_dict_list:
        # nothing accumulated: avoid IndexError on the [0] access below
        return {}

    system_counter = 0
    n_systems = 0
    tensormaps_shifted_systems = []
    for tensormap_dict in tensormap_dict_list:
        tensormap_dict_shifted = {}
        for name, tensormap in tensormap_dict.items():
            new_keys = []
            new_blocks = []
            for key, block in tensormap.items():
                where_system = block.samples.names.index("system")
                # number of systems in this batch (assumes 0-based numbering)
                n_systems = int(torch.max(block.samples.column("system"))) + 1
                # clone before shifting: ``block.samples.values`` is the
                # caller's tensor, and an in-place ``+=`` would corrupt the
                # accumulated predictions (and double-shift shared Labels)
                new_samples_values = block.samples.values.clone()
                new_samples_values[:, where_system] += system_counter
                new_block = TensorBlock(
                    values=block.values,
                    samples=Labels(
                        block.samples.names,
                        values=new_samples_values,
                        assume_unique=True,
                    ),
                    components=block.components,
                    properties=block.properties,
                )
                # gradient samples index into the parent block's samples by
                # position, so they can be copied over unchanged
                for gradient_name, gradient_block in block.gradients():
                    new_block.add_gradient(
                        gradient_name,
                        gradient_block,
                    )
                new_keys.append(key)
                new_blocks.append(new_block)
            tensormap_dict_shifted[name] = TensorMap(
                keys=Labels(
                    names=tensormap.keys.names,
                    values=torch.stack([new_key.values for new_key in new_keys]),
                ),
                blocks=new_blocks,
            )
        tensormaps_shifted_systems.append(tensormap_dict_shifted)
        # advance by the size of the batch just processed; assumes every
        # target in the batch covers the same systems
        system_counter += n_systems
    return {
        target: mts.join(
            [pred[target] for pred in tensormaps_shifted_systems],
            axis="samples",
            remove_tensor_name=True,
        )
        for target in tensormaps_shifted_systems[0].keys()
    }