Skip to content

Commit 66b55a4

Browse files
committed
Add some type hints.
1 parent 1593e47 commit 66b55a4

File tree

1 file changed

+39
-15
lines changed

1 file changed

+39
-15
lines changed

cf_xarray/accessor.py

Lines changed: 39 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import functools
22
import inspect
33
from collections import ChainMap
4-
from typing import Any, List, Optional, Set, Union
4+
from typing import Callable, List, Mapping, MutableMapping, Optional, Set, Tuple, Union
55

66
import xarray as xr
77
from xarray import DataArray, Dataset
@@ -27,7 +27,7 @@
2727
# Define the criteria for coordinate matches
2828
# Copied from metpy
2929
# Internally we only use X, Y, Z, T
30-
coordinate_criteria: dict = {
30+
coordinate_criteria: MutableMapping[str, MutableMapping[str, Tuple]] = {
3131
"standard_name": {
3232
"T": ("time",),
3333
"time": ("time",),
@@ -90,15 +90,28 @@
9090
# },
9191
}
9292

93+
# "vertical" is just an alias for "Z"
9394
coordinate_criteria["standard_name"]["vertical"] = coordinate_criteria["standard_name"][
9495
"Z"
9596
]
97+
# "long_name" and "standard_name" criteria are the same. For convenience.
9698
coordinate_criteria["long_name"] = coordinate_criteria["standard_name"]
9799

100+
# Type for Mapper functions
101+
Mapper = Callable[
102+
[Union[xr.DataArray, xr.Dataset], str, bool, str],
103+
Union[Optional[str], List[Optional[str]], DataArray], # this sucks
104+
]
105+
98106

99-
def _get_axis_coord_single(var, key, *args):
107+
def _get_axis_coord_single(
108+
var: Union[xr.DataArray, xr.Dataset],
109+
key: str,
110+
error: bool = True,
111+
default: str = None,
112+
) -> Optional[str]:
100113
""" Helper method for when we really want only one result per key. """
101-
results = _get_axis_coord(var, key, *args)
114+
results = _get_axis_coord(var, key, error, default)
102115
if len(results) > 1:
103116
raise ValueError(
104117
f"Multiple results for {key!r} found: {results!r}. Is this valid CF? Please open an issue."
@@ -111,7 +124,7 @@ def _get_axis_coord(
111124
var: Union[xr.DataArray, xr.Dataset],
112125
key: str,
113126
error: bool = True,
114-
default: Optional[str] = None,
127+
default: str = None,
115128
) -> List[Optional[str]]:
116129
"""
117130
Translate from axis or coord name to variable name
@@ -176,13 +189,15 @@ def _get_axis_coord(
176189

177190

178191
def _get_measure_variable(
179-
da: xr.DataArray, key: str, error: bool = True, default: Any = None
192+
da: xr.DataArray, key: str, error: bool = True, default: str = None
180193
) -> DataArray:
181194
""" tiny wrapper since xarray does not support providing str for weights."""
182195
return da[_get_measure(da, key, error, default)]
183196

184197

185-
def _get_measure(da: xr.DataArray, key: str, error: bool = True, default: Any = None):
198+
def _get_measure(
199+
da: xr.DataArray, key: str, error: bool = True, default: str = None
200+
) -> Optional[str]:
186201
"""
187202
Interprets 'cell_measures'.
188203
"""
@@ -213,18 +228,21 @@ def _get_measure(da: xr.DataArray, key: str, error: bool = True, default: Any =
213228
return measures[key]
214229

215230

216-
_DEFAULT_KEY_MAPPERS: dict = dict.fromkeys(
217-
("dim", "coord", "group"), _get_axis_coord_single
218-
)
219-
_DEFAULT_KEY_MAPPERS["weights"] = _get_measure_variable
231+
#: Default mappers for common keys.
232+
_DEFAULT_KEY_MAPPERS: Mapping[str, Mapper] = {
233+
"dim": _get_axis_coord_single,
234+
"coord": _get_axis_coord_single,
235+
"group": _get_axis_coord_single,
236+
"weights": _get_measure_variable, # type: ignore
237+
}
220238

221239

222240
def _getattr(
223241
obj: Union[DataArray, Dataset],
224242
attr: str,
225243
accessor: "CFAccessor",
226-
key_mappers: dict,
227-
wrap_classes=False,
244+
key_mappers: Mapping[str, Mapper],
245+
wrap_classes: bool = False,
228246
):
229247
"""
230248
Common getattr functionality.
@@ -235,11 +253,13 @@ def _getattr(
235253
obj : DataArray, Dataset
236254
attr : Name of attribute in obj that will be shadowed.
237255
accessor : High level accessor object: CFAccessor
256+
key_mappers : dict
257+
dict(key_name: mapper)
238258
wrap_classes: bool
239259
Should we wrap the return value with _CFWrappedClass?
240260
Only True for the high level CFAccessor.
241261
Facilitates code reuse for _CFWrappedClass and _CFWrapppedPlotMethods
242-
For both of thos, wrap_classes is False.
262+
For both of those, wrap_classes is False.
243263
"""
244264
func = getattr(obj, attr)
245265

@@ -258,10 +278,10 @@ def wrapper(*args, **kwargs):
258278
class _CFWrappedClass:
259279
def __init__(self, towrap, accessor: "CFAccessor"):
260280
"""
281+
This class is used to wrap any class in _WRAPPED_CLASSES.
261282
262283
Parameters
263284
----------
264-
obj : DataArray, Dataset
265285
towrap : Resample, GroupBy, Coarsen, Rolling, Weighted
266286
Instance of xarray class that is being wrapped.
267287
accessor : CFAccessor
@@ -282,6 +302,10 @@ def __getattr__(self, attr):
282302

283303

284304
class _CFWrappedPlotMethods:
305+
"""
306+
This class wraps DataArray.plot
307+
"""
308+
285309
def __init__(self, obj, accessor):
286310
self._obj = obj
287311
self.accessor = accessor

0 commit comments

Comments
 (0)