Skip to content
3 changes: 3 additions & 0 deletions doc/whats-new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@ New Features
- :py:func:`merge` and :py:func:`concat` now support :py:class:`DataTree`
objects (:issue:`9790`, :issue:`9778`).
By `Stephan Hoyer <https://github.com/shoyer>`_.
- :py:class:`DataTree` now supports indexing by lists of paths, similar to
:py:class:`Dataset` (:pull:`10854`).
By `Stephan Hoyer <https://github.com/shoyer>`_.
- The ``h5netcdf`` engine has support for pseudo ``NETCDF4_CLASSIC`` files, meaning variables and attributes are cast to supported types. Note that the saved files won't be recognized as genuine ``NETCDF4_CLASSIC`` files until ``h5netcdf`` adds support with version 1.7.0. (:issue:`10676`, :pull:`10686`).
By `David Huard <https://github.com/huard>`_.

Expand Down
4 changes: 2 additions & 2 deletions xarray/backends/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,9 +236,9 @@ def _decode_variable_name(name):


def _iter_nc_groups(root, parent="/"):
from xarray.core.treenode import NodePath
from xarray.core.treenode import TreePath

parent = NodePath(parent)
parent = TreePath(parent)
yield str(parent)
for path, group in root.groups.items():
gpath = parent / path
Expand Down
10 changes: 5 additions & 5 deletions xarray/backends/h5netcdf_.py
Original file line number Diff line number Diff line change
Expand Up @@ -636,7 +636,7 @@ def open_groups_as_dict(
**kwargs,
) -> dict[str, Dataset]:
from xarray.backends.common import _iter_nc_groups
from xarray.core.treenode import NodePath
from xarray.core.treenode import TreePath
from xarray.core.utils import close_on_error

# Keep this message for some versions
Expand All @@ -658,9 +658,9 @@ def open_groups_as_dict(

# Check for a group and make it a parent if it exists
if group:
parent = NodePath("/") / NodePath(group)
parent = TreePath("/") / TreePath(group)
else:
parent = NodePath("/")
parent = TreePath("/")

manager = store._manager
groups_dict = {}
Expand All @@ -680,9 +680,9 @@ def open_groups_as_dict(
)

if group:
group_name = str(NodePath(path_group).relative_to(parent))
group_name = str(TreePath(path_group).relative_to(parent))
else:
group_name = str(NodePath(path_group))
group_name = str(TreePath(path_group))
groups_dict[group_name] = group_ds

# only warn if phony_dims exist in file
Expand Down
10 changes: 5 additions & 5 deletions xarray/backends/netCDF4_.py
Original file line number Diff line number Diff line change
Expand Up @@ -847,7 +847,7 @@ def open_groups_as_dict(
**kwargs,
) -> dict[str, Dataset]:
from xarray.backends.common import _iter_nc_groups
from xarray.core.treenode import NodePath
from xarray.core.treenode import TreePath

filename_or_obj = _normalize_path(filename_or_obj)
store = NetCDF4DataStore.open(
Expand All @@ -864,9 +864,9 @@ def open_groups_as_dict(

# Check for a group and make it a parent if it exists
if group:
parent = NodePath("/") / NodePath(group)
parent = TreePath("/") / TreePath(group)
else:
parent = NodePath("/")
parent = TreePath("/")

manager = store._manager
groups_dict = {}
Expand All @@ -885,9 +885,9 @@ def open_groups_as_dict(
decode_timedelta=decode_timedelta,
)
if group:
group_name = str(NodePath(path_group).relative_to(parent))
group_name = str(TreePath(path_group).relative_to(parent))
else:
group_name = str(NodePath(path_group))
group_name = str(TreePath(path_group))
groups_dict[group_name] = group_ds

return groups_dict
Expand Down
12 changes: 6 additions & 6 deletions xarray/backends/pydap_.py
Original file line number Diff line number Diff line change
Expand Up @@ -330,7 +330,7 @@ def open_groups_as_dict(
verify=None,
user_charset=None,
) -> dict[str, Dataset]:
from xarray.core.treenode import NodePath
from xarray.core.treenode import TreePath

filename_or_obj = _normalize_path(filename_or_obj)
store = PydapDataStore.open(
Expand All @@ -344,9 +344,9 @@ def open_groups_as_dict(

# Check for a group and make it a parent if it exists
if group:
parent = str(NodePath("/") / NodePath(group))
parent = str(TreePath("/") / TreePath(group))
else:
parent = str(NodePath("/"))
parent = str(TreePath("/"))

groups_dict = {}
group_names = [parent]
Expand Down Expand Up @@ -384,7 +384,7 @@ def group_fqn(store, path=None, g_fqn=None) -> dict[str, str]:

Groups = group_fqn(store.ds)
group_names += [
str(NodePath(path_to_group) / NodePath(group))
str(TreePath(path_to_group) / TreePath(group))
for group, path_to_group in Groups.items()
]
for path_group in group_names:
Expand All @@ -403,9 +403,9 @@ def group_fqn(store, path=None, g_fqn=None) -> dict[str, str]:
decode_timedelta=decode_timedelta,
)
if group:
group_name = str(NodePath(path_group).relative_to(parent))
group_name = str(TreePath(path_group).relative_to(parent))
else:
group_name = str(NodePath(path_group))
group_name = str(TreePath(path_group))
groups_dict[group_name] = group_ds

return groups_dict
Expand Down
16 changes: 8 additions & 8 deletions xarray/backends/zarr.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
)
from xarray.backends.store import StoreBackendEntrypoint
from xarray.core import indexing
from xarray.core.treenode import NodePath
from xarray.core.treenode import TreePath
from xarray.core.types import ZarrWriteModes
from xarray.core.utils import (
FrozenDict,
Expand Down Expand Up @@ -1752,9 +1752,9 @@ def open_groups_as_dict(

# Check for a group and make it a parent if it exists
if group:
parent = str(NodePath("/") / NodePath(group))
parent = str(TreePath("/") / TreePath(group))
else:
parent = str(NodePath("/"))
parent = str(TreePath("/"))

stores = ZarrStore.open_store(
filename_or_obj,
Expand Down Expand Up @@ -1785,18 +1785,18 @@ def open_groups_as_dict(
decode_timedelta=decode_timedelta,
)
if group:
group_name = str(NodePath(path_group).relative_to(parent))
group_name = str(TreePath(path_group).relative_to(parent))
else:
group_name = str(NodePath(path_group))
group_name = str(TreePath(path_group))
groups_dict[group_name] = group_ds
return groups_dict


def _iter_zarr_groups(root: ZarrGroup, parent: str = "/") -> Iterable[str]:
    """Yield the path of ``parent`` followed by the paths of all nested groups.

    Parameters
    ----------
    root : ZarrGroup
        Group whose ``groups()`` are walked recursively.
    parent : str, default: "/"
        Path under which ``root`` itself is reported.

    Yields
    ------
    str
        ``parent`` first, then the path of every descendant group, visiting
        each child's subtree before moving on to the next child.
    """
    # Build the path object once; it is both yielded and used as the prefix
    # for every child path.
    parent_path = TreePath(parent)
    yield str(parent_path)
    for name, group in root.groups():
        yield from _iter_zarr_groups(group, parent=str(parent_path / name))


Expand Down
62 changes: 39 additions & 23 deletions xarray/core/datatree.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@
from xarray.core.indexes import Index, Indexes
from xarray.core.options import OPTIONS as XR_OPTS
from xarray.core.options import _get_keep_attrs
from xarray.core.treenode import NamedNode, NodePath, zip_subtrees
from xarray.core.treenode import NamedNode, TreePath, zip_subtrees
from xarray.core.types import Self
from xarray.core.utils import (
Default,
Expand Down Expand Up @@ -114,7 +114,7 @@
# """


T_Path = Union[str, NodePath]
T_Path = Union[str, TreePath]
T = TypeVar("T")
P = ParamSpec("P")

Expand Down Expand Up @@ -188,7 +188,7 @@ def check_alignment(
base_ds = node_ds

for child_name, child in children.items():
child_path = str(NodePath(path) / child_name)
child_path = str(TreePath(path) / child_name)
child_ds = child.to_dataset(inherit=False)
check_alignment(child_path, child_ds, base_ds, child.children)

Expand Down Expand Up @@ -566,7 +566,7 @@ def _pre_attach(self: DataTree, parent: DataTree, name: str) -> None:
raise KeyError(
f"parent {parent.name} already contains a variable named {name}"
)
path = str(NodePath(parent.path) / name)
path = str(TreePath(parent.path) / name)
node_ds = self.to_dataset(inherit=False)
parent_ds = parent._to_dataset_view(rebuild_dims=False, inherit=True)
check_alignment(path, node_ds, parent_ds, self.children)
Expand Down Expand Up @@ -943,17 +943,38 @@ def get( # type: ignore[override]
else:
return default

def __getitem__(self: DataTree, key: str) -> DataTree | DataArray:
def _copy_listed(self, keys: list[str]) -> Self:
    """Return a new tree containing only the items named in ``keys``.

    Parameters
    ----------
    keys : list of str
        Names or paths, each looked up with ``self._get_item``.

    Returns
    -------
    Self
        Tree built by ``from_dict`` from the selected items, keyed by their
        path relative to this node and carrying this node's name.

    Raises
    ------
    IndexError
        If a key's path cannot be expressed relative to this node
        (presumably because it points at a parent or sibling of this node).
    """
    root = TreePath(self.path)
    selected: dict[str, DataTree | DataArray] = {}
    for key in keys:
        # Joined outside the ``try`` so only ``relative_to`` failures are
        # translated into IndexError.
        target = root / key
        try:
            relative = target.relative_to(root)
        except ValueError as err:
            raise IndexError(f"cannot subset items from parent nodes: {key}") from err
        selected[str(relative)] = self._get_item(key)
    return self.from_dict(selected, name=self.name)

@overload
def __getitem__(self, key: list[str]) -> Self: ...

@overload
def __getitem__(self, key: str) -> Self | DataArray: ...

def __getitem__(self, key: str | list[str]) -> Self | DataArray:
"""
Access child nodes, variables, or coordinates stored anywhere in this tree.

Returned object will be either a DataTree or DataArray object depending on whether the key given points to a
child or variable.
Returned object will be either a DataTree or DataArray object depending on
whether the key given points to a child or variable.

Parameters
----------
key : str or list of str
Name of variable / child within this node, or unix-like path to variable / child within another node.
Name of variable / child within this node, unix-like path to variable
/ child within another node, or a list of names/paths.

Returns
-------
Expand All @@ -967,14 +988,9 @@ def __getitem__(self: DataTree, key: str) -> DataTree | DataArray:
elif isinstance(key, str):
# TODO should possibly deal with hashables in general?
# path-like: a name of a node/variable, or path to a node/variable
path = NodePath(key)
return self._get_item(path)
elif utils.is_list_like(key):
# iterable of variable names
raise NotImplementedError(
"Selecting via tags is deprecated, and selecting multiple items should be "
"implemented via .subset"
)
return self._get_item(key)
elif isinstance(key, list):
return self._copy_listed(key)
else:
raise ValueError(f"Invalid format for key: {key}")

Expand Down Expand Up @@ -1015,7 +1031,7 @@ def __setitem__(
elif isinstance(key, str):
# TODO should possibly deal with hashables in general?
# path-like: a name of a node/variable, or path to a node/variable
path = NodePath(key)
path = TreePath(key)
if isinstance(value, Dataset):
value = DataTree(dataset=value)
return self._set_item(path, value, new_nodes_along_path=True)
Expand Down Expand Up @@ -1341,17 +1357,17 @@ def from_dict(
data_items,
((k, _CoordWrapper(v)) for k, v in coords_items),
)
nodes: dict[NodePath, _CoordWrapper | FromDictDataValue] = {}
nodes: dict[TreePath, _CoordWrapper | FromDictDataValue] = {}
for key, value in flat_data_and_coords:
path = NodePath(key).absolute()
path = TreePath(key).absolute()
if path in nodes:
raise ValueError(
f"multiple entries found corresponding to node {str(path)!r}"
)
nodes[path] = value

# Merge nodes corresponding to DataArrays into Datasets
dataset_args: defaultdict[NodePath, _DatasetArgs] = defaultdict(_DatasetArgs)
dataset_args: defaultdict[TreePath, _DatasetArgs] = defaultdict(_DatasetArgs)
for path in list(nodes):
node = nodes[path]
if node is not None and not isinstance(node, Dataset | DataTree):
Expand All @@ -1378,7 +1394,7 @@ def from_dict(
) from e

# Create the root node
root_data = nodes.pop(NodePath("/"), None)
root_data = nodes.pop(TreePath("/"), None)
if isinstance(root_data, cls):
# use cls so type-checkers understand this method returns Self
obj = root_data.copy()
Expand All @@ -1391,7 +1407,7 @@ def from_dict(
f"or DataTree, got {type(root_data)}"
)

def depth(item: tuple[NodePath, object]) -> int:
def depth(item: tuple[TreePath, object]) -> int:
node_path, _ = item
return len(node_path.parts)

Expand Down Expand Up @@ -1745,7 +1761,7 @@ def match(self, pattern: str) -> DataTree:
matching_nodes = {
path: node.dataset
for path, node in self.subtree_with_keys
if NodePath(node.path).match(pattern)
if TreePath(node.path).match(pattern)
}
return DataTree.from_dict(matching_nodes, name=self.name)

Expand Down
Loading
Loading