Splitting Equistore TensorMaps

Module for splitting lists of TensorMap objects into multiple TensorMap objects along a given axis.

equisolve.utils.split_data.split_data(tensors: List[TensorMap] | TensorMap, axis: str, names: List[str] | str, n_groups: int, group_sizes: List[int] | List[float] | None = None, seed: int | None = None) -> Tuple[List[List[TensorMap]], List[Labels]]

Splits a list of TensorMap objects into multiple TensorMap objects along a given axis.

For either the “samples” or “properties” axis, the unique indices for the specified metadata name are found. If seed is set, the indices are shuffled. Then, they are divided into n_groups, where the sizes of the groups are specified by the group_sizes argument.

These grouped indices are then used to split the list of input tensors. The split tensors, along with the grouped labels, are returned. The tensors are returned as a list of lists of TensorMap objects.

Each list in the returned list of lists corresponds to the split TensorMap at the same position in the input tensors list. Each nested list contains TensorMap objects that share no common indices for the specified axis and names. However, the metadata on all other axes (including the keys) will be equivalent.

All TensorMap objects passed in tensors must have the same set of unique indices for the specified axis and names. For instance, when passing an input and an output tensor for splitting (e.g. as used in supervised machine learning), the output tensor must have structure indices 0 -> 10 if the input tensor does.
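The grouping procedure described above can be sketched on plain integers. This is a minimal illustration, not the equisolve implementation: the function name group_indices_sketch is hypothetical, it takes absolute group sizes only, and it uses Python's standard random module rather than the numpy generator, so a given seed will not reproduce the library's ordering.

```python
import random
from typing import List, Optional, Sequence


def group_indices_sketch(
    indices: Sequence[int],
    group_sizes: Sequence[int],
    seed: Optional[int] = None,
) -> List[List[int]]:
    """Hypothetical helper: group unique indices into consecutive chunks."""
    # Find the unique indices, in sorted order
    unique = sorted(set(indices))
    if sum(group_sizes) > len(unique):
        raise ValueError("group sizes exceed the number of unique indices")
    # Shuffle only when a seed is given (assumption: stdlib RNG, not numpy)
    if seed is not None:
        random.Random(seed).shuffle(unique)
    # Take consecutive slices of the (possibly shuffled) indices
    groups, start = [], 0
    for size in group_sizes:
        groups.append(unique[start:start + size])
        start += size
    return groups


# Without a seed, the groups follow the sorted order of the indices
print(group_indices_sketch(range(10), [5, 5]))
# [[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]]
```

With a seed, the same call returns the same groups on every run, but their contents are a permutation of the indices rather than contiguous ranges.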

Parameters:
  • tensors – a TensorMap or a list of TensorMap objects, each of which will be split into n_groups new TensorMap objects.

  • axis – a str equal to either “samples” or “properties”. This is the axis along which the input TensorMap objects will be split.

  • names – a str or a list of str indicating the samples/properties names by which the tensors will be split.

  • n_groups – an int indicating how many new TensorMap objects each of the tensors passed in tensors will be split into. If group_sizes is None (default), n_groups is used to split the data into n_groups evenly sized groups according to the unique metadata for the specified axis and names, with group sizes rounded to the nearest integer.

  • group_sizes – an ordered list of int or float giving the sizes of the groups to split each input TensorMap into. A list of int is interpreted as absolute group sizes, whereas a list of float is interpreted as relative sizes. In the former case, the sum of this list must be <= the total number of unique indices present in the input tensors for the chosen axis and names. In the latter, the sum of this list must be <= 1.

  • seed – an int that seeds the numpy random number generator. Used to control shuffling of the unique indices, which dictate the data that ends up in each of the split output tensors. If None (default), no shuffling of the indices occurs. If an int, the indices are shuffled with the random seed set to this value.

Return split_tensors:

list of lists of TensorMap. The ith element in the list contains n_groups TensorMap objects corresponding to the split ith TensorMap of the input list tensors.

Return grouped_labels:

list of Labels corresponding to the unique indices according to the specified axis and names that are present in each of the returned groups of TensorMap. The length of this list is n_groups.
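The interaction between n_groups and group_sizes can be sketched as a conversion to absolute counts. This is a hedged illustration only: resolve_group_sizes_sketch is a hypothetical name, and the rounding rules (spreading any remainder over the first groups, rounding each fraction to the nearest int) and the check that group_sizes has n_groups entries are assumptions, since the documentation states only that sizes are made "to the nearest integer".

```python
from typing import List, Optional, Sequence, Union


def resolve_group_sizes_sketch(
    n_unique: int,
    n_groups: int,
    group_sizes: Optional[Sequence[Union[int, float]]] = None,
) -> List[int]:
    """Hypothetical helper: turn the size arguments into absolute counts."""
    if group_sizes is None:
        # Even split; assumption: the remainder goes to the first groups
        base, remainder = divmod(n_unique, n_groups)
        return [base + (1 if i < remainder else 0) for i in range(n_groups)]
    if len(group_sizes) != n_groups:
        # Assumption: one size is expected per group
        raise ValueError("group_sizes must have n_groups entries")
    if all(isinstance(size, int) for size in group_sizes):
        # Absolute sizes: the sum must not exceed the number of unique indices
        if sum(group_sizes) > n_unique:
            raise ValueError("absolute group sizes exceed the unique indices")
        return list(group_sizes)
    # Relative sizes: the sum must not exceed 1 (small float tolerance)
    if sum(group_sizes) > 1.0 + 1e-9:
        raise ValueError("relative group sizes must sum to at most 1")
    sizes = [int(round(fraction * n_unique)) for fraction in group_sizes]
    if sum(sizes) > n_unique:
        raise ValueError("rounded group sizes exceed the unique indices")
    return sizes


print(resolve_group_sizes_sketch(10, 2))              # [5, 5]
print(resolve_group_sizes_sketch(10, 2, [0.8, 0.2]))  # [8, 2]
print(resolve_group_sizes_sketch(10, 3, [7, 2, 1]))   # [7, 2, 1]
```

The three calls mirror the three examples in the next section: an even split, a relative 80:20 split, and an absolute 7, 2, 1 split of 10 unique indices.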

Examples

Split a TensorMap tensor into 2 new TensorMaps along the “samples” axis for the “structure” metadata. Without specifying group_sizes, the data will be split equally by structure index. If the number of unique structure indices present in the input data is not exactly divisible by n_groups, the group sizes will be rounded to the nearest int. Without specifying seed, no shuffling of the structure indices will occur and they will be grouped in lexicographical order. For instance, if the input tensor has structure indices 0 -> 9 (inclusive), the first new tensor will contain only structure indices 0 -> 4 (inc.) and the second will contain only 5 -> 9 (inc.).

from equisolve.utils import split_data

[[new_tensor_1, new_tensor_2]], grouped_labels = split_data(
    tensors=tensor,
    axis="samples",
    names=["structure"],
    n_groups=2,
)

Split 2 tensors corresponding to input and output data into train and test data, with a relative 80:20 ratio. As we want to specify relative group sizes, we pass group_sizes as a list of float. If both input and output tensors contain structure indices 0 -> 9 (inclusive), the in_train and out_train tensors will each contain 8 structure indices and the in_test and out_test tensors the remaining 2. Without shuffling, these would be structure indices 0 -> 7 (inc.) and 8 -> 9 (inc.), respectively. Because seed is specified here, the structure indices are shuffled before the groups are made, so which indices land in each group depends on the seed.

from equisolve.utils import split_data

[[in_train, in_test], [out_train, out_test]], grouped_labels = split_data(
    tensors=[input, output],
    axis="samples",
    names=["structure"],
    n_groups=2,  # for train-test split
    group_sizes=[0.8, 0.2],  # relative; an 80%/20% train-test split
    seed=100,
)

Split 2 tensors corresponding to input and output data into train, test, and validation data. If input and output tensors have the same 10 structure indices, we can split such that the train, test, and val tensors contain 7, 2, and 1 structures, respectively. We want to specify absolute group sizes, so we pass a list of int. Specifying the seed will shuffle the structure indices before they are grouped.

import metatensor
from equisolve.utils import split_data

# Find the unique structure indices in the input tensor
unique_structure_indices = metatensor.unique_metadata(
    tensor=input, axis="samples", names=["structure"],
)
# They run from 0 -> 9 (inclusive)
unique_structure_indices
>>> Labels(
    [(0,), (1,), (2,), (3,), (4,), (5,), (6,), (7,), (8,), (9,)],
    dtype=[('structure', '<i4')]
)
# Verify that the output has the same unique structure indices
# (the assert passes silently when they match)
assert unique_structure_indices == metatensor.unique_metadata(
    tensor=output, axis="samples", names=["structure"],
)

# Split the data by structure index, with an absolute split of 7, 2, 1
# for the train, test, and validation tensors, respectively
(
    [
        [in_train, in_test, in_val],
        [out_train, out_test, out_val]
    ]
), grouped_labels = split_data(
    tensors=[input, output],
    axis="samples",
    names=["structure"],
    n_groups=3,  # for train-test-validation
    group_sizes=[7, 2, 1],  # absolute; 7, 2, 1 for train, test, val
    seed=100,
)
# Inspect the grouped structure indices
grouped_labels
>>> [
    Labels(
        [(3,), (7,), (1,), (8,), (0,), (9,), (2,)],
        dtype=[('structure', '<i4')]
    ),
    Labels([(4,), (6,)], dtype=[('structure', '<i4')]),
    Labels([(5,)], dtype=[('structure', '<i4')]),
]
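
The property that the groups share no common indices can be checked directly on the grouped labels. The sketch below uses plain tuples transcribed from the output above rather than Labels objects, and groups_are_disjoint is a hypothetical helper, not part of equisolve or metatensor.

```python
# Structure indices transcribed from the grouped_labels output above,
# written as plain tuples (assumption: stand-ins for Labels entries)
grouped = [
    [(3,), (7,), (1,), (8,), (0,), (9,), (2,)],
    [(4,), (6,)],
    [(5,)],
]


def groups_are_disjoint(groups) -> bool:
    """Hypothetical helper: True if no entry appears in more than one place."""
    seen = set()
    for group in groups:
        for entry in group:
            if entry in seen:
                return False
            seen.add(entry)
    return True


print(groups_are_disjoint(grouped))       # True
print([len(group) for group in grouped])  # [7, 2, 1]
```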