diff --git a/poetry.lock b/poetry.lock index 3aca066..e0c7c65 100644 --- a/poetry.lock +++ b/poetry.lock @@ -715,7 +715,7 @@ python-versions = "*" name = "typing-extensions" version = "3.10.0.0" description = "Backported and Experimental Type Hints for Python 3.5+" -category = "dev" +category = "main" optional = false python-versions = "*" @@ -775,7 +775,7 @@ testing = ["pytest (>=4.6)", "pytest-checkdocs (>=2.4)", "pytest-flake8", "pytes [metadata] lock-version = "1.1" python-versions = "^3.7" -content-hash = "d69e3b7ada6670f9eb624783e459e041ca9f08546234bb696dfb6f3443b0b1f9" +content-hash = "338dfd3d7e34368dd0c5bcddaf6bfcc9611e6288cc7752bc9ae4cd3bff9c787e" [metadata.files] alabaster = [ diff --git a/pyproject.toml b/pyproject.toml index 7567a8d..4b48dae 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -10,6 +10,7 @@ documentation = "https://astropenguin.github.io/xarray-accessors/" [tool.poetry.dependencies] python = "^3.7" +typing-extensions = "^3.7" xarray = ">=0.15, <1.0" [tool.poetry.dev-dependencies] diff --git a/tests/test_accessor.py b/tests/test_accessor.py new file mode 100644 index 0000000..15e3356 --- /dev/null +++ b/tests/test_accessor.py @@ -0,0 +1,30 @@ +# third-party packages +import xarray as xr + + +# submodule +from xarray_accessors.accessor import AccessorBase + + +# accessors +class Accessor(AccessorBase): + pass + + +def func(dataarray: xr.DataArray) -> int: + return 1 + + +Accessor.func = func +Accessor.sub.subsub.func = func # type: ignore + + +# test functions +def test_accessor_attrs() -> None: + assert Accessor.func is func + assert Accessor.sub.subsub.func is func # type: ignore + + +def test_accessor_call() -> None: + assert Accessor(xr.DataArray()).func() == 1 # type: ignore + assert Accessor(xr.DataArray()).sub.subsub.func() == 1 # type: ignore diff --git a/xarray_accessors/__init__.py b/xarray_accessors/__init__.py index 86c91ff..2fddb88 100644 --- a/xarray_accessors/__init__.py +++ b/xarray_accessors/__init__.py @@ -1,7 +1,9 @@ -# metadata +# flake8: noqa +# type: ignore __author__ = "Akio Taniguchi" __version__ = "0.1.0" # submodules -from . import utils # noqa +from . import accessor +from . import utils diff --git a/xarray_accessors/accessor.py b/xarray_accessors/accessor.py new file mode 100644 index 0000000..88e49aa --- /dev/null +++ b/xarray_accessors/accessor.py @@ -0,0 +1,96 @@ +# standard library +from functools import wraps +from inspect import isclass +from types import FunctionType, MethodType +from typing import Any, Callable, ClassVar, Dict, List, Type, Union + + +# third-party packages +import xarray as xr + + +# type hints +Accessor = Type["AccessorBase"] +Function = Callable[..., Any] + + +# constants +RESERVED_NAMES = ("_accessed", "_accessors", "_functions") + + +# runtime classes +class AccessorMeta(type): + """Metaclass for DataArray and Dataset accessors.""" + + _accessors: Dict[str, Accessor] #: Nested accessors. + _functions: Dict[str, Function] #: Data functions. + + def __init__(cls, *args: Any, **kwargs: Any) -> None: + """Initialize a class by creating initial attributes.""" + cls._accessors = {} + cls._functions = {} + + def __dir__(cls) -> List[str]: + """Return the union namespace of accessors and functions.""" + return list(set(cls._accessors) | set(cls._functions)) + + def __getattr__(cls, name: str) -> Union[Accessor, Function]: + """Return an accessor class or a function.""" + if name in cls._accessors: + return cls._accessors[name] + + if name in cls._functions: + return cls._functions[name] + + setattr(cls, name, type(name, (AccessorBase,), {})) + return cls._accessors[name] + + def __setattr__(cls, name: str, value: Union[Accessor, Function]) -> None: + """Set an accessor class or a function to the instance.""" + if name in RESERVED_NAMES: + return super().__setattr__(name, value) + + if isclass(value) and issubclass(value, AccessorBase): + return cls._accessors.update({name: value}) + + if isinstance(value, FunctionType): + return cls._functions.update({name: value}) + + raise TypeError("Value must be either an accessor or a function.") + + +class AccessorBase(metaclass=AccessorMeta): + """Base class for DataArray and Dataset accessors.""" + + _accessed: Union[xr.DataArray, xr.Dataset] #: Accessed data. + _accessors: ClassVar[Dict[str, Accessor]] #: Nested accessors. + _functions: ClassVar[Dict[str, Function]] #: Data functions. + + def __init__(self, data: Union[xr.DataArray, xr.Dataset]) -> None: + """Initialize an instance by binding data.""" + super().__setattr__("_accessed", data) + + def __dir__(self) -> List[str]: + """Return the union namespace of accessors and functions.""" + return list(set(self._accessors) | set(self._functions)) + + def __getattr__(self, name: str) -> Union["AccessorBase", MethodType]: + """Return an accessor class instance or a bound function.""" + if name in self._accessors: + return self._accessors[name](self._accessed) + + if name in self._functions: + function = self._functions[name] + + @wraps(function) + def method(self: AccessorBase, *args: Any, **kwargs: Any) -> Any: + return function(self._accessed, *args, **kwargs) + + return MethodType(method, self) + + cname = type(self).__name__ + raise AttributeError(f"{cname!r} object has no attribute {name!r}.") + + def __setattr__(self, name: str, value: Any) -> None: + """Disallow setting a value to the instance.""" + raise AttributeError("Cannot set a value to the instance.")