From 3c14b1bec069cf556771f72b3f22617334196bdb Mon Sep 17 00:00:00 2001 From: Stephan Hoyer Date: Mon, 13 Oct 2025 17:47:52 -0700 Subject: [PATCH 1/2] support combine_nested on DataTree objects --- xarray/structure/combine.py | 80 +++++++++++++++++++++++++++--------- xarray/tests/test_combine.py | 26 ++++++++++++ 2 files changed, 87 insertions(+), 19 deletions(-) diff --git a/xarray/structure/combine.py b/xarray/structure/combine.py index 4ff8354015c..5e07e815537 100644 --- a/xarray/structure/combine.py +++ b/xarray/structure/combine.py @@ -2,13 +2,14 @@ from collections import Counter, defaultdict from collections.abc import Callable, Hashable, Iterable, Iterator, Sequence -from typing import TYPE_CHECKING, Literal, TypeVar, Union, cast +from typing import TYPE_CHECKING, Literal, TypeAlias, TypeVar, cast, overload import pandas as pd from xarray.core import dtypes from xarray.core.dataarray import DataArray from xarray.core.dataset import Dataset +from xarray.core.datatree import DataTree from xarray.core.utils import iterate_nested from xarray.structure.alignment import AlignmentError from xarray.structure.concat import concat @@ -96,7 +97,7 @@ def _ensure_same_types(series, dim): raise TypeError(error_msg) -def _infer_concat_order_from_coords(datasets): +def _infer_concat_order_from_coords(datasets: list[Dataset] | list[DataTree]): concat_dims = [] tile_ids = [() for ds in datasets] @@ -163,10 +164,16 @@ def _infer_concat_order_from_coords(datasets): ] if len(datasets) > 1 and not concat_dims: - raise ValueError( - "Could not find any dimension coordinates to use to " - "order the datasets for concatenation" - ) + if any(isinstance(data, DataTree) for data in datasets): + raise ValueError( + "Did not find any dimension coordinates at root nodes " + "to order the DataTree objects for concatenation" + ) + else: + raise ValueError( + "Could not find any dimension coordinates to use to " + "order the Dataset objects for concatenation" + ) combined_ids = 
dict(zip(tile_ids, datasets, strict=True)) @@ -224,7 +231,7 @@ def _combine_nd( Parameters ---------- - combined_ids : Dict[Tuple[int, ...]], xarray.Dataset] + combined_ids : Dict[Tuple[int, ...]], xarray.Dataset | xarray.DataTree] Structure containing all datasets to be concatenated with "tile_IDs" as keys, which specify position within the desired final combined result. concat_dims : sequence of str @@ -235,7 +242,7 @@ def _combine_nd( Returns ------- - combined_ds : xarray.Dataset + combined_ds : xarray.Dataset | xarray.DataTree """ example_tile_id = next(iter(combined_ids.keys())) @@ -399,12 +406,39 @@ def _nested_combine( return combined -# Define type for arbitrarily-nested list of lists recursively: -DATASET_HYPERCUBE = Union[Dataset, Iterable["DATASET_HYPERCUBE"]] +# Define types for arbitrarily-nested list of lists +DatasetHyperCube: TypeAlias = Dataset | Sequence["DatasetHyperCube"] +DataTreeHyperCube: TypeAlias = DataTree | Sequence["DataTreeHyperCube"] +@overload def combine_nested( - datasets: DATASET_HYPERCUBE, + datasets: DatasetHyperCube, + concat_dim: str | DataArray | Sequence[str | DataArray | pd.Index | None] | None, + compat: str | CombineKwargDefault = ..., + data_vars: str | CombineKwargDefault = ..., + coords: str | CombineKwargDefault = ..., + fill_value: object = ..., + join: JoinOptions | CombineKwargDefault = ..., + combine_attrs: CombineAttrsOptions = ..., +) -> Dataset: ... + + +@overload +def combine_nested( + datasets: DataTreeHyperCube, + concat_dim: str | DataArray | Sequence[str | DataArray | pd.Index | None] | None, + compat: str | CombineKwargDefault = ..., + data_vars: str | CombineKwargDefault = ..., + coords: str | CombineKwargDefault = ..., + fill_value: object = ..., + join: JoinOptions | CombineKwargDefault = ..., + combine_attrs: CombineAttrsOptions = ..., +) -> DataTree: ... 
+ + +def combine_nested( + datasets: DatasetHyperCube | DataTreeHyperCube, concat_dim: str | DataArray | Sequence[str | DataArray | pd.Index | None] | None, compat: str | CombineKwargDefault = _COMPAT_DEFAULT, data_vars: str | CombineKwargDefault = _DATA_VARS_DEFAULT, @@ -412,7 +446,7 @@ def combine_nested( fill_value: object = dtypes.NA, join: JoinOptions | CombineKwargDefault = _JOIN_DEFAULT, combine_attrs: CombineAttrsOptions = "drop", -) -> Dataset: +) -> Dataset | DataTree: """ Explicitly combine an N-dimensional grid of datasets into one by using a succession of concat and merge operations along each dimension of the grid. @@ -433,7 +467,7 @@ def combine_nested( Parameters ---------- - datasets : list or nested list of Dataset + datasets : list or nested list of Dataset or DataTree Dataset objects to combine. If concatenation or merging along more than one dimension is desired, then datasets must be supplied in a nested list-of-lists. @@ -527,7 +561,7 @@ def combine_nested( Returns ------- - combined : xarray.Dataset + combined : xarray.Dataset or xarray.DataTree Examples -------- @@ -621,15 +655,19 @@ def combine_nested( concat merge """ - mixed_datasets_and_arrays = any( - isinstance(obj, Dataset) for obj in iterate_nested(datasets) - ) and any( + any_datasets = any(isinstance(obj, Dataset) for obj in iterate_nested(datasets)) + any_unnamed_arrays = any( isinstance(obj, DataArray) and obj.name is None for obj in iterate_nested(datasets) ) - if mixed_datasets_and_arrays: + if any_datasets and any_unnamed_arrays: raise ValueError("Can't combine datasets with unnamed arrays.") + any_datatrees = any(isinstance(obj, DataTree) for obj in iterate_nested(datasets)) + all_datatrees = all(isinstance(obj, DataTree) for obj in iterate_nested(datasets)) + if any_datatrees and not all_datatrees: + raise ValueError("Can't combine a mix of DataTree and non-DataTree objects.") + if isinstance(concat_dim, str | DataArray) or concat_dim is None: concat_dim = [concat_dim] @@ 
-988,6 +1026,10 @@ def combine_by_coords( Finally, if you attempt to combine a mix of unnamed DataArrays with either named DataArrays or Datasets, a ValueError will be raised (as this is an ambiguous operation). """ + if any(isinstance(data_object, DataTree) for data_object in data_objects): + raise NotImplementedError( + "combine_by_coords() does not yet support DataTree objects." + ) if not data_objects: return Dataset() @@ -1018,7 +1060,7 @@ def combine_by_coords( # Must be a mix of unnamed dataarrays with either named dataarrays or with datasets # Can't combine these as we wouldn't know whether to merge or concatenate the arrays raise ValueError( - "Can't automatically combine unnamed DataArrays with either named DataArrays or Datasets." + "Can't automatically combine unnamed DataArrays with named DataArrays or Datasets." ) else: # Promote any named DataArrays to single-variable Datasets to simplify combining diff --git a/xarray/tests/test_combine.py b/xarray/tests/test_combine.py index c7c2a60010f..8dfc3039130 100644 --- a/xarray/tests/test_combine.py +++ b/xarray/tests/test_combine.py @@ -1,5 +1,6 @@ from __future__ import annotations +import re from itertools import product import numpy as np @@ -8,6 +9,7 @@ from xarray import ( DataArray, Dataset, + DataTree, MergeError, combine_by_coords, combine_nested, @@ -766,6 +768,20 @@ def test_nested_combine_mixed_datasets_arrays(self): ): combine_nested(objs, "x") + def test_nested_combine_mixed_datatrees_and_datasets(self): + objs = [DataTree.from_dict({"foo": 0}), Dataset({"foo": 1})] + with pytest.raises( + ValueError, + match=r"Can't combine a mix of DataTree and non-DataTree objects.", + ): + combine_nested(objs, concat_dim="x") + + def test_datatree(self): + objs = [DataTree.from_dict({"foo": 0}), DataTree.from_dict({"foo": 1})] + expected = DataTree.from_dict({"foo": ("x", [0, 1])}) + actual = combine_nested(objs, concat_dim="x") + assert expected.identical(actual) + class TestCombineDatasetsbyCoords: def 
test_combine_by_coords(self): @@ -1210,6 +1226,16 @@ def test_combine_by_coords_all_dataarrays_with_the_same_name(self): expected = merge([named_da1, named_da2], compat="no_conflicts", join="outer") assert_identical(expected, actual) + def test_combine_by_coords_datatree(self): + tree = DataTree.from_dict({"/nested/foo": ("x", [10])}, coords={"x": [1]}) + with pytest.raises( + NotImplementedError, + match=re.escape( + "combine_by_coords() does not yet support DataTree objects." + ), + ): + combine_by_coords([tree]) + class TestNewDefaults: def test_concat_along_existing_dim(self): From 28f5c356cc12c2ab6f48d1fd1fbac31020482443 Mon Sep 17 00:00:00 2001 From: Stephan Hoyer Date: Mon, 13 Oct 2025 19:12:27 -0700 Subject: [PATCH 2/2] make mypy happy --- doc/whats-new.rst | 5 ++- xarray/structure/combine.py | 77 +++++++++++++++++++++++++----------- xarray/tests/test_combine.py | 10 ++--- 3 files changed, 62 insertions(+), 30 deletions(-) diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 3f59657c0b6..5782ccffda6 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -13,8 +13,9 @@ v2025.10.2 (unreleased) New Features ~~~~~~~~~~~~ -- :py:func:`merge` and :py:func:`concat` now support :py:class:`DataTree` - objects (:issue:`9790`, :issue:`9778`). +- :py:func:`merge`, :py:func:`concat` and :py:func:`combine_nested` now + support :py:class:`DataTree` objects (:issue:`9790`, :issue:`9778`, + :pull:`10849`). By `Stephan Hoyer <https://github.com/shoyer>`_.
Breaking Changes diff --git a/xarray/structure/combine.py b/xarray/structure/combine.py index 5e07e815537..1738fb681a9 100644 --- a/xarray/structure/combine.py +++ b/xarray/structure/combine.py @@ -99,7 +99,7 @@ def _ensure_same_types(series, dim): def _infer_concat_order_from_coords(datasets: list[Dataset] | list[DataTree]): concat_dims = [] - tile_ids = [() for ds in datasets] + tile_ids: list[tuple[int, ...]] = [() for ds in datasets] # All datasets have same variables because they've been grouped as such ds0 = datasets[0] @@ -107,17 +107,18 @@ def _infer_concat_order_from_coords(datasets: list[Dataset] | list[DataTree]): # Check if dim is a coordinate dimension if dim in ds0: # Need to read coordinate values to do ordering - indexes = [ds._indexes.get(dim) for ds in datasets] - if any(index is None for index in indexes): - error_msg = ( - f"Every dimension requires a corresponding 1D coordinate " - f"and index for inferring concatenation order but the " - f"coordinate '{dim}' has no corresponding index" - ) - raise ValueError(error_msg) - - # TODO (benbovy, flexible indexes): support flexible indexes? - indexes = [index.to_pandas_index() for index in indexes] + indexes: list[pd.Index] = [] + for ds in datasets: + index = ds._indexes.get(dim) + if index is None: + error_msg = ( + f"Every dimension requires a corresponding 1D coordinate " + f"and index for inferring concatenation order but the " + f"coordinate '{dim}' has no corresponding index" + ) + raise ValueError(error_msg) + # TODO (benbovy, flexible indexes): support flexible indexes? 
+ indexes.append(index.to_pandas_index()) # If dimension coordinate values are same on every dataset then # should be leaving this dimension alone (it's just a "bystander") @@ -154,7 +155,7 @@ def _infer_concat_order_from_coords(datasets: list[Dataset] | list[DataTree]): rank = series.rank( method="dense", ascending=ascending, numeric_only=False ) - order = rank.astype(int).values - 1 + order = (rank.astype(int).values - 1).tolist() # Append positions along extra dimension to structure which # encodes the multi-dimensional concatenation order @@ -406,15 +407,34 @@ def _nested_combine( return combined -# Define types for arbitrarily-nested list of lists -DatasetHyperCube: TypeAlias = Dataset | Sequence["DatasetHyperCube"] -DataTreeHyperCube: TypeAlias = DataTree | Sequence["DataTreeHyperCube"] +# Define types for arbitrarily-nested list of lists. +# Mypy doesn't seem to handle overloads properly with recursive types, so we +# explicitly expand the first handful of levels of recursion. +DatasetLike: TypeAlias = DataArray | Dataset +DatasetHyperCube: TypeAlias = ( + DatasetLike + | Sequence[DatasetLike] + | Sequence[Sequence[DatasetLike]] + | Sequence[Sequence[Sequence[DatasetLike]]] + | Sequence[Sequence[Sequence[Sequence[DatasetLike]]]] +) +DataTreeHyperCube: TypeAlias = ( + DataTree + | Sequence[DataTree] + | Sequence[Sequence[DataTree]] + | Sequence[Sequence[Sequence[DataTree]]] + | Sequence[Sequence[Sequence[Sequence[DataTree]]]] +) @overload def combine_nested( datasets: DatasetHyperCube, - concat_dim: str | DataArray | Sequence[str | DataArray | pd.Index | None] | None, + concat_dim: str + | DataArray + | list[str] + | Sequence[str | DataArray | pd.Index | None] + | None, compat: str | CombineKwargDefault = ..., data_vars: str | CombineKwargDefault = ..., coords: str | CombineKwargDefault = ..., @@ -427,7 +447,11 @@ def combine_nested( @overload def combine_nested( datasets: DataTreeHyperCube, - concat_dim: str | DataArray | Sequence[str | DataArray | pd.Index 
| None] | None, + concat_dim: str + | DataArray + | list[str] + | Sequence[str | DataArray | pd.Index | None] + | None, compat: str | CombineKwargDefault = ..., data_vars: str | CombineKwargDefault = ..., coords: str | CombineKwargDefault = ..., @@ -439,7 +463,11 @@ def combine_nested( def combine_nested( datasets: DatasetHyperCube | DataTreeHyperCube, - concat_dim: str | DataArray | Sequence[str | DataArray | pd.Index | None] | None, + concat_dim: str + | DataArray + | list[str] + | Sequence[str | DataArray | pd.Index | None] + | None, compat: str | CombineKwargDefault = _COMPAT_DEFAULT, data_vars: str | CombineKwargDefault = _DATA_VARS_DEFAULT, coords: str | CombineKwargDefault = _COORDS_DEFAULT, @@ -467,7 +495,7 @@ def combine_nested( Parameters ---------- - datasets : list or nested list of Dataset or DataTree + datasets : list or nested list of Dataset, DataArray or DataTree Dataset objects to combine. If concatenation or merging along more than one dimension is desired, then datasets must be supplied in a nested list-of-lists. 
@@ -668,13 +696,16 @@ def combine_nested( if any_datatrees and not all_datatrees: raise ValueError("Can't combine a mix of DataTree and non-DataTree objects.") - if isinstance(concat_dim, str | DataArray) or concat_dim is None: - concat_dim = [concat_dim] + concat_dims = ( + [concat_dim] + if isinstance(concat_dim, str | DataArray) or concat_dim is None + else concat_dim + ) # The IDs argument tells _nested_combine that datasets aren't yet sorted return _nested_combine( datasets, - concat_dims=concat_dim, + concat_dims=concat_dims, compat=compat, data_vars=data_vars, coords=coords, diff --git a/xarray/tests/test_combine.py b/xarray/tests/test_combine.py index 8dfc3039130..2883187e096 100644 --- a/xarray/tests/test_combine.py +++ b/xarray/tests/test_combine.py @@ -626,8 +626,8 @@ def test_auto_combine_2d_combine_attrs_kwarg(self): datasets, concat_dim=["dim1", "dim2"], data_vars="all", - combine_attrs=combine_attrs, # type: ignore[arg-type] - ) + combine_attrs=combine_attrs, + ) # type: ignore[call-overload] assert_identical(result, expected) def test_combine_nested_missing_data_new_dim(self): @@ -766,7 +766,7 @@ def test_nested_combine_mixed_datasets_arrays(self): with pytest.raises( ValueError, match=r"Can't combine datasets with unnamed arrays." ): - combine_nested(objs, "x") + combine_nested(objs, "x") # type: ignore[arg-type] def test_nested_combine_mixed_datatrees_and_datasets(self): objs = [DataTree.from_dict({"foo": 0}), Dataset({"foo": 1})] @@ -774,7 +774,7 @@ def test_nested_combine_mixed_datatrees_and_datasets(self): ValueError, match=r"Can't combine a mix of DataTree and non-DataTree objects.", ): - combine_nested(objs, concat_dim="x") + combine_nested(objs, concat_dim="x") # type: ignore[arg-type] def test_datatree(self): objs = [DataTree.from_dict({"foo": 0}), DataTree.from_dict({"foo": 1})] @@ -1234,7 +1234,7 @@ def test_combine_by_coords_datatree(self): "combine_by_coords() does not yet support DataTree objects." 
), ): - combine_by_coords([tree]) + combine_by_coords([tree]) # type: ignore[list-item] class TestNewDefaults: