Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ documentation = "https://astropenguin.github.io/xarray-accessors/"

[tool.poetry.dependencies]
python = "^3.7"
typing-extensions = "^3.7"
xarray = ">=0.15, <1.0"

[tool.poetry.dev-dependencies]
Expand Down
30 changes: 30 additions & 0 deletions tests/test_accessor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
# third-party packages
import xarray as xr


# submodule
from xarray_accessors.accessor import AccessorBase


# accessors
class Accessor(AccessorBase):
pass


def func(dataarray: xr.DataArray) -> int:
return 1


Accessor.func = func
Accessor.sub.subsub.func = func # type: ignore


# test functions
def test_accessor_attrs() -> None:
assert Accessor.func is func
assert Accessor.sub.subsub.func is func # type: ignore


def test_accessor_call() -> None:
assert Accessor(xr.DataArray()).func() == 1 # type: ignore
assert Accessor(xr.DataArray()).sub.subsub.func() == 1 # type: ignore
6 changes: 4 additions & 2 deletions xarray_accessors/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
# metadata
# flake8: noqa
# type: ignore
__author__ = "Akio Taniguchi"
__version__ = "0.1.0"


# submodules
from . import utils # noqa
from . import accessor
from . import utils
96 changes: 96 additions & 0 deletions xarray_accessors/accessor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
# standard library
from functools import wraps
from inspect import isclass
from types import FunctionType, MethodType
from typing import Any, Callable, ClassVar, Dict, List, Type, Union


# third-party packages
import xarray as xr


# type hints
Accessor = Type["AccessorBase"]
Function = Callable[..., Any]


# constants
RESERVED_NAMES = ("_accessed", "_accessors", "_functions")


# runtime classes
class AccessorMeta(type):
"""Metaclass for DataArray and Dataset accessors."""

_accessors: Dict[str, Accessor] #: Nested accessors.
_functions: Dict[str, Function] #: Data functions.

def __init__(cls, *args: Any, **kwargs: Any) -> None:
"""Initialize a class by creating initial attributes."""
cls._accessors = {}
cls._functions = {}

def __dir__(cls) -> List[str]:
"""Return the union namespace of accessors and functions."""
return list(set(cls._accessors) | set(cls._functions))

def __getattr__(cls, name: str) -> Union[Accessor, Function]:
"""Return an accessor class or a function."""
if name in cls._accessors:
return cls._accessors[name]

if name in cls._functions:
return cls._functions[name]

setattr(cls, name, type(name, (AccessorBase,), {}))
return cls._accessors[name]

def __setattr__(cls, name: str, value: Union[Accessor, Function]) -> None:
"""Set an accessor class or a function to the instance."""
if name in RESERVED_NAMES:
return super().__setattr__(name, value)

if isclass(value) and issubclass(value, AccessorBase):
return cls._accessors.update({name: value})

if isinstance(value, FunctionType):
return cls._functions.update({name: value})

raise TypeError("Value must be either an accessor or a function.")


class AccessorBase(metaclass=AccessorMeta):
"""Base class for DataArray and Dataset accessors."""

_accessed: Union[xr.DataArray, xr.Dataset] #: Accessed data.
_accessors: ClassVar[Dict[str, Accessor]] #: Nested accessors.
_functions: ClassVar[Dict[str, Function]] #: Data functions.

def __init__(self, data: Union[xr.DataArray, xr.Dataset]) -> None:
"""Initialize an instance by binding data."""
super().__setattr__("_accessed", data)

def __dir__(self) -> List[str]:
"""Return the union namespace of accessors and functions."""
return list(set(self._accessors) | set(self._functions))

def __getattr__(self, name: str) -> Union["AccessorBase", MethodType]:
"""Return an accessor class instance or a bound function."""
if name in self._accessors:
return self._accessors[name](self._accessed)

if name in self._functions:
function = self._functions[name]

@wraps(function)
def method(self: AccessorBase, *args: Any, **kwargs: Any) -> Any:
return function(self._accessed, *args, **kwargs)

return MethodType(method, self)

cname = type(self).__name__
raise AttributeError(f"{cname!r} object has no attribute {name!r}.")

def __setattr__(self, name: str, value: Any) -> None:
"""Disallow setting a value to the instance."""
raise AttributeError("Cannot set a value to the instance.")