diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml
index 13c46eef..3578cbcf 100644
--- a/.github/workflows/lint.yml
+++ b/.github/workflows/lint.yml
@@ -24,11 +24,11 @@ jobs:
       - name: Setup Python version
         uses: actions/setup-python@v5
         with:
-          python-version: '3.13'
+          python-version: '3.12'

      - name: Install dependencies
         run: |
-          pip install -r scripts/github_actions/requirements.txt
+          pip install black==24.10.0

      - name: Run black on Python files
         run: |
@@ -38,22 +38,21 @@ jobs:
         run: |
           black --include '\.pyi$' --check --verbose .

-  # typing:
-  #   runs-on: ubuntu-latest
-  #   steps:
-  #   - name: Check out code
-  #     uses: actions/checkout@v4
-
-  #   - name: Setup Python version
-  #     uses: actions/setup-python@v5
-  #     with:
-  #       python-version: '3.13'
-
-  #   - name: Install dependencies
-  #     run: |
-  #       pip install -r scripts/github_actions/requirements.txt
-  #       pip install -r dev-requirements.txt
-
-  #   - name: Run mypy
-  #     run: |
-  #       mypy --check quasardb
+  typing:
+    runs-on: ubuntu-22.04
+    steps:
+      - name: Check out code
+        uses: actions/checkout@v4
+
+      - name: Setup Python version
+        uses: actions/setup-python@v5
+        with:
+          python-version: '3.7'
+
+      - name: Install dependencies
+        run: |
+          pip install -r dev-requirements.txt
+
+      - name: Run mypy
+        run: |
+          mypy --check quasardb
diff --git a/.gitignore b/.gitignore
index 1804c981..08014c09 100644
--- a/.gitignore
+++ b/.gitignore
@@ -110,9 +110,9 @@ celerybeat-schedule
 *.sage.py

 # Environments
-.env
-.venv
-venv
+.env*/
+.venv*/
+venv*/

 # Spyder project settings
 .spyderproject
diff --git a/dev-requirements.txt b/dev-requirements.txt
index f7c0385b..d42d61dd 100644
--- a/dev-requirements.txt
+++ b/dev-requirements.txt
@@ -37,8 +37,9 @@ setuptools-git == 1.2

 # Linting
 black==24.10.0; python_version >= '3.9'
-black == 23.3.0; python_version < '3.9'
+black==23.3.0; python_version < '3.9'

 # Stubs
 mypy
 pybind11-stubgen
+pandas-stubs
diff --git a/pyproject.toml b/pyproject.toml
index db220b9c..6bd09571 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -25,5 +25,6 @@ xfail_strict = true
 filterwarnings = []
 testpaths = ["tests"]

-# [tool.mypy]
-# python_version = "3.9"
\ No newline at end of file
+[tool.mypy]
+python_version = "3.7"
+disallow_untyped_defs = true
\ No newline at end of file
diff --git a/quasardb/extensions/__init__.py b/quasardb/extensions/__init__.py
index ddea07d8..2c3f29a3 100644
--- a/quasardb/extensions/__init__.py
+++ b/quasardb/extensions/__init__.py
@@ -1,8 +1,9 @@
-from .writer import extend_writer
+from typing import Any, List

+from .writer import extend_writer

-__all__ = []
+__all__: List[Any] = []


-def extend_module(m):
-    m.Writer = extend_writer(m.Writer)
+def extend_module(m: Any) -> None:
+    extend_writer(m.Writer)
diff --git a/quasardb/extensions/writer.py b/quasardb/extensions/writer.py
index 11d98270..ecf24d3c 100644
--- a/quasardb/extensions/writer.py
+++ b/quasardb/extensions/writer.py
@@ -1,12 +1,15 @@
 import copy
-import quasardb
+from typing import Any, Callable, Dict, List, Optional
+
 import numpy as np
 import numpy.ma as ma

-__all__ = []
+import quasardb
+
+__all__: List[Any] = []


-def _ensure_ctype(self, idx, ctype):
+def _ensure_ctype(self: Any, idx: int, ctype: quasardb.ColumnType) -> None:
     assert "table" in self._legacy_state
     infos = self._legacy_state["table"].list_columns()
     cinfo = infos[idx]
@@ -24,7 +27,7 @@ def _ensure_ctype(self, idx, ctype):
     raise quasardb.IncompatibleTypeError()


-def _legacy_next_row(self, table):
+def _legacy_next_row(self: Any, table: Any) -> Dict[str, Any]:
     if "pending" not in
self._legacy_state: self._legacy_state["pending"] = [] @@ -37,37 +40,37 @@ def _legacy_next_row(self, table): return self._legacy_state["pending"][-1] -def _legacy_current_row(self): +def _legacy_current_row(self: Any) -> Dict[str, Any]: return self._legacy_state["pending"][-1] -def _legacy_start_row(self, table, x): +def _legacy_start_row(self: Any, table: Any, x: np.datetime64) -> None: row = _legacy_next_row(self, table) assert "$timestamp" not in row row["$timestamp"] = x -def _legacy_set_double(self, idx, x): +def _legacy_set_double(self: Any, idx: int, x: float) -> None: _ensure_ctype(self, idx, quasardb.ColumnType.Double) assert isinstance(x, float) assert idx not in _legacy_current_row(self)["by_index"] _legacy_current_row(self)["by_index"][idx] = x -def _legacy_set_int64(self, idx, x): +def _legacy_set_int64(self: Any, idx: int, x: int) -> None: _ensure_ctype(self, idx, quasardb.ColumnType.Int64) assert isinstance(x, int) assert idx not in _legacy_current_row(self)["by_index"] _legacy_current_row(self)["by_index"][idx] = x -def _legacy_set_timestamp(self, idx, x): +def _legacy_set_timestamp(self: Any, idx: int, x: np.datetime64) -> None: _ensure_ctype(self, idx, quasardb.ColumnType.Timestamp) assert idx not in _legacy_current_row(self)["by_index"] _legacy_current_row(self)["by_index"][idx] = x -def _legacy_set_string(self, idx, x): +def _legacy_set_string(self: Any, idx: int, x: str) -> None: _ensure_ctype(self, idx, quasardb.ColumnType.String) assert isinstance(x, str) assert idx not in _legacy_current_row(self)["by_index"] @@ -75,7 +78,7 @@ def _legacy_set_string(self, idx, x): _legacy_current_row(self)["by_index"][idx] = x -def _legacy_set_blob(self, idx, x): +def _legacy_set_blob(self: Any, idx: int, x: bytes) -> None: _ensure_ctype(self, idx, quasardb.ColumnType.Blob) assert isinstance(x, bytes) assert idx not in _legacy_current_row(self)["by_index"] @@ -83,10 +86,10 @@ def _legacy_set_blob(self, idx, x): _legacy_current_row(self)["by_index"][idx] = x -def _legacy_push(self): +def _legacy_push(self: Any) -> Optional[quasardb.WriterData]: if "pending" not in self._legacy_state: # Extremely likely default case, no "old" rows - return + return None assert "table" in self._legacy_state table = self._legacy_state["table"] @@ -109,7 +112,7 @@ def _legacy_push(self): all_idx = set(ctype_by_idx.keys()) # Prepare data structure - pivoted = {"$timestamp": [], "by_index": {}} + pivoted: Dict[str, Any] = {"$timestamp": [], "by_index": {}} for i in all_idx: pivoted["by_index"][i] = [] @@ -140,7 +143,6 @@ def _legacy_push(self): mask = [x is None for x in xs] - xs_ = [] if all(mask): xs_ = ma.masked_all(len(xs), dtype=dtype) else: @@ -159,9 +161,11 @@ def _legacy_push(self): return push_data -def _wrap_fn(old_fn, replace_fn): +def _wrap_fn( + old_fn: Callable[..., Any], replace_fn: Callable[..., Optional[quasardb.WriterData]] +) -> Callable[..., Any]: - def wrapped(self, *args, **kwargs): + def wrapped(self: Any, *args: Any, **kwargs: Any) -> Any: data = replace_fn(self) if data: return old_fn(self, data, *args, **kwargs) @@ -171,7 +175,7 @@ def wrapped(self, *args, **kwargs): return wrapped -def extend_writer(x): +def extend_writer(x: Any) -> None: """ Extends the writer with the "old", batch inserter API. 
This is purely a backwards compatibility layer, and we want to avoid having to maintain that diff --git a/quasardb/firehose.py b/quasardb/firehose.py index d4801c1b..f831116a 100644 --- a/quasardb/firehose.py +++ b/quasardb/firehose.py @@ -1,22 +1,27 @@ -import time -import quasardb import logging +import time +from typing import Any, Dict, Iterator, List, Optional, Tuple + import numpy as np +from quasardb import Cluster + FIREHOSE_TABLE = "$qdb.firehose" POLL_INTERVAL = 0.1 logger = logging.getLogger("quasardb.firehose") -def _init(): +def _init() -> Dict[str, Any]: """ Initialize our internal state. """ return {"last": None, "seen": set()} -def _get_transactions_since(conn, table_name, last): +def _get_transactions_since( + conn: Cluster, table_name: str, last: Optional[Dict[str, Any]] +) -> List[Dict[str, Any]]: """ Retrieve all transactions since a certain timestamp. `last` is expected to be a dict firehose row with at least a $timestamp attached. @@ -33,7 +38,9 @@ def _get_transactions_since(conn, table_name, last): return conn.query(q) -def _get_transaction_data(conn, table_name, begin, end): +def _get_transaction_data( + conn: Cluster, table_name: str, begin: str, end: str +) -> List[Dict[str, Any]]: """ Gets all data from a certain table. """ @@ -41,7 +48,9 @@ def _get_transaction_data(conn, table_name, begin, end): return conn.query(q) -def _get_next(conn, table_name, state): +def _get_next( + conn: Cluster, table_name: str, state: Dict[str, Any] +) -> Tuple[Dict[str, Any], List[Dict[str, Any]]]: # Our flow to retrieve new data is as follows: # 1. Based on the state's last processed transaction, retrieve all transactions @@ -52,7 +61,7 @@ def _get_next(conn, table_name, state): txs = _get_transactions_since(conn, table_name, state["last"]) - xs = list() + xs: List[Dict[str, Any]] = [] for tx in txs: txid = tx["transaction_id"] @@ -83,7 +92,7 @@ def _get_next(conn, table_name, state): return (state, xs) -def subscribe(conn, table_name): +def subscribe(conn: Cluster, table_name: str) -> Iterator[Dict[str, Any]]: state = _init() while True: diff --git a/quasardb/numpy/__init__.py b/quasardb/numpy/__init__.py index bbecdfab..eeb20616 100644 --- a/quasardb/numpy/__init__.py +++ b/quasardb/numpy/__init__.py @@ -26,15 +26,17 @@ # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS # SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. # +from __future__ import annotations import logging import time import warnings -from typing import Dict, List, Optional, Tuple, Union -from numpy.typing import DTypeLike +from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple, Type, Union import quasardb import quasardb.table_cache as table_cache +from quasardb.quasardb import Table, Writer +from quasardb.typing import DType, MaskedArrayAny, NDArrayAny, NDArrayTime logger = logging.getLogger("quasardb.numpy") @@ -64,7 +66,13 @@ class IncompatibleDtypeError(TypeError): Exception raised when a provided dtype is not the expected dtype. 
""" - def __init__(self, cname=None, ctype=None, expected=None, provided=None) -> None: + def __init__( + self, + cname: Optional[str] = None, + ctype: Optional[quasardb.ColumnType] = None, + expected: Optional[List[DType]] = None, + provided: Optional[DType] = None, + ): self.cname = cname self.ctype = ctype self.expected = expected @@ -82,7 +90,7 @@ class IncompatibleDtypeErrors(TypeError): Wraps multiple dtype errors """ - def __init__(self, xs) -> None: + def __init__(self, xs: List[IncompatibleDtypeError]): self.xs = xs super().__init__(self.msg()) @@ -95,7 +103,7 @@ class InvalidDataCardinalityError(ValueError): Raised when the provided data arrays doesn't match the table's columns. """ - def __init__(self, data, cinfos) -> None: + def __init__(self, data: List[Any], cinfos: List[Any]) -> None: self.data = data self.cinfos = cinfos super().__init__(self.msg()) @@ -109,7 +117,7 @@ def msg(self) -> str: # Based on QuasarDB column types, which dtype do we accept? # First entry will always be the 'preferred' dtype, other ones # those that we can natively convert in native code. -_ctype_to_dtype = { +_ctype_to_dtype: Dict[quasardb.ColumnType, List[DType]] = { quasardb.ColumnType.String: [np.dtype("U")], quasardb.ColumnType.Symbol: [np.dtype("U")], quasardb.ColumnType.Int64: [np.dtype("i8"), np.dtype("i4"), np.dtype("i2")], @@ -119,7 +127,7 @@ def msg(self) -> str: } -def _best_dtype_for_ctype(ctype: quasardb.ColumnType): +def _best_dtype_for_ctype(ctype: quasardb.ColumnType) -> DType: """ Returns the 'best' DType for a certain column type. For example, for blobs, even though we accept py::bytes, prefer bytestrings (as they are faster to read in c++). @@ -132,9 +140,11 @@ def _best_dtype_for_ctype(ctype: quasardb.ColumnType): def _coerce_dtype( - dtype: Union[DTypeLike, Dict[str, DTypeLike], List[DTypeLike]], - columns: List[Tuple[str, quasardb.ColumnInfo]], -) -> List[DTypeLike]: + dtype: Optional[ + Union[DType, Dict[str, Optional[DType]], Sequence[Optional[DType]]] + ], + columns: List[Tuple[str, quasardb.ColumnType]], +) -> List[Optional[DType]]: if dtype is None: dtype = [None] * len(columns) @@ -143,7 +153,7 @@ def _coerce_dtype( if type(dtype) is dict: # Conveniently look up column index by label - offsets = {} + offsets: Dict[str, int] = {} for i in range(len(columns)): (cname, _) = columns[i] offsets[cname] = i @@ -152,7 +162,7 @@ def _coerce_dtype( # the relative offset within the table. # # Any columns not provided will have a 'None' dtype. - dtype_ = [None] * len(columns) + dtype_: List[Optional[DType]] = [None] * len(columns) for k, dt in dtype.items(): if not k in offsets: @@ -175,17 +185,17 @@ def _coerce_dtype( if len(dtype) is not len(columns): raise ValueError( - "Expected exactly one dtype for each column, but %d dtypes were provided for %d columns".format( + "Expected exactly one dtype for each column, but {} dtypes were provided for {} columns".format( len(dtype), len(columns) ) ) - return dtype + return list(dtype) def _add_desired_dtypes( - dtype: List[DTypeLike], columns: List[Tuple[str, quasardb.ColumnInfo]] -) -> List[DTypeLike]: + dtype: List[Optional[DType]], columns: List[Tuple[str, quasardb.ColumnType]] +) -> List[Optional[DType]]: """ When infer_types=True, this function sets the 'desired' dtype for each of the columns. 
`dtype` is expected to be the output of `_coerce_dtype`, that is, a list-like with an @@ -209,7 +219,7 @@ def _add_desired_dtypes( return dtype -def _is_all_masked(xs): +def _is_all_masked(xs: Any) -> bool: if ma.isMA(xs): return ma.size(xs) == ma.count_masked(xs) @@ -239,7 +249,7 @@ def _is_all_masked(xs): return all(x is None for x in xs) -def dtypes_equal(lhs, rhs): +def dtypes_equal(lhs: DType, rhs: DType) -> bool: if lhs.kind == "U" or lhs.kind == "S": # Unicode and string data has variable length encoding, which means their itemsize # can be anything. @@ -250,7 +260,7 @@ def dtypes_equal(lhs, rhs): return lhs == rhs -def _dtype_found(needle, haystack): +def _dtype_found(needle: DType, haystack: List[DType]) -> bool: """ Returns True if one of the dtypes in `haystack` matches that of `needle`. """ @@ -261,19 +271,17 @@ def _dtype_found(needle, haystack): return False -def _validate_dtypes(data, columns): +def _validate_dtypes( + data: List[Any], columns: List[Tuple[str, quasardb.ColumnType]] +) -> None: errors = list() - def _error_to_msg(e): - (cname, ctype, provided_dtype, expected_dtype) = e - return - for data_, (cname, ctype) in zip(data, columns): expected_ = _ctype_to_dtype[ctype] logger.debug("data_.dtype = %s, expected_ = %s", data_.dtype, expected_) - if _dtype_found(data_.dtype, expected_) == False: + if not _dtype_found(data_.dtype, expected_): errors.append( IncompatibleDtypeError( cname=cname, ctype=ctype, provided=data_.dtype, expected=expected_ @@ -284,11 +292,15 @@ def _error_to_msg(e): raise IncompatibleDtypeErrors(errors) -def _coerce_deduplicate(deduplicate, deduplication_mode, columns): +def _coerce_deduplicate( + deduplicate: Union[bool, str, List[str]], + deduplication_mode: str, + columns: List[Tuple[str, quasardb.ColumnType]], +) -> Union[bool, List[str]]: """ Throws an error when 'deduplicate' options are incorrect. """ - cnames = [cname for (cname, ctype) in columns] + cnames = [cname for (cname, _) in columns] if deduplication_mode not in ["drop", "upsert"]: raise RuntimeError( @@ -308,7 +320,7 @@ def _coerce_deduplicate(deduplicate, deduplication_mode, columns): if not isinstance(deduplicate, list): raise TypeError( "drop_duplicates should be either a bool or a list, got: " - + type(deduplicate) + + str(type(deduplicate)) ) for column_name in deduplicate: @@ -322,7 +334,7 @@ def _coerce_deduplicate(deduplicate, deduplication_mode, columns): return deduplicate -def _clean_nulls(xs, dtype): +def _clean_nulls(xs: MaskedArrayAny, dtype: DType) -> MaskedArrayAny: """ Numpy's masked arrays have a downside that in case they're not able to convert a (masked!) value to the desired dtype, they raise an error. So, for example, if I have a masked array of objects that @@ -343,7 +355,7 @@ def _clean_nulls(xs, dtype): if xs.dtype is not np.dtype("object"): return xs - fill_value = None + fill_value: Any = None if dtype == np.float64 or dtype == np.float32 or dtype == np.float16: fill_value = float("nan") elif dtype == np.int64 or dtype == np.int32 or dtype == np.int16: @@ -357,7 +369,9 @@ def _clean_nulls(xs, dtype): return ma.array(xs_, mask=mask) -def _coerce_data(data, dtype): +def _coerce_data( + data: List[MaskedArrayAny], dtype: List[Optional[DType]] +) -> List[MaskedArrayAny]: """ Coerces each numpy array of `data` to the dtype present in `dtype`. 
""" @@ -386,7 +400,7 @@ def _coerce_data(data, dtype): logger.debug("data of data[%d] after: %s", i, data_) try: - data[i] = data_.astype(dtype_) + data[i] = ma.masked_array(data_.astype(dtype_)) except TypeError as err: # One 'bug' is that, if everything is masked, the underlying data type can be # pretty much anything. @@ -418,7 +432,9 @@ def _coerce_data(data, dtype): return data -def _probe_length(xs): +def _probe_length( + xs: Union[Dict[Any, NDArrayAny], Iterable[NDArrayAny]] +) -> Optional[int]: """ Returns the length of the first non-null array in `xs`, or None if all arrays are null. @@ -433,7 +449,10 @@ def _probe_length(xs): return None -def _ensure_list(xs, cinfos): +def _ensure_list( + xs: Union[List[Any], Dict[Any, Any], NDArrayAny], + cinfos: List[Tuple[str, quasardb.ColumnType]], +) -> List[Any]: """ If input data is a dict, ensures it's converted to a list with the correct offsets. @@ -479,7 +498,9 @@ def _ensure_list(xs, cinfos): return ret -def _coerce_retries(retries) -> quasardb.RetryOptions: +def _coerce_retries( + retries: Optional[Union[int, quasardb.RetryOptions]] +) -> quasardb.RetryOptions: if retries is None: return quasardb.RetryOptions() elif isinstance(retries, int): @@ -489,13 +510,17 @@ def _coerce_retries(retries) -> quasardb.RetryOptions: else: raise TypeError( "retries should either be an integer or quasardb.RetryOptions, got: " - + type(retries) + + str(type(retries)) ) def _kwarg_deprecation_warning( - old_kwarg, old_value, new_kwargs, new_values, stacklevel -): + old_kwarg: str, + old_value: Any, + new_kwargs: List[str], + new_values: List[Any], + stacklevel: int, +) -> None: new_declaration = ", ".join( f"{new_kwarg}={new_value}" for new_kwarg, new_value in zip(new_kwargs, new_values) @@ -508,23 +533,25 @@ def _kwarg_deprecation_warning( ) -def _type_check(var, var_name, target_type, raise_error=True, allow_none=True): +def _type_check( + var: Any, + var_name: str, + target_type: Type, + raise_error: bool = True, + allow_none: bool = True, +) -> bool: if allow_none and var is None: return True if not isinstance(var, target_type): if raise_error: - raise quasardb.quasardb.InvalidArgumentError( + raise quasardb.InvalidArgumentError( f"Invalid '{var_name}' type, expected: {target_type}, got: {type(var)}" ) return False return True -def ensure_ma(xs, dtype=None): - if isinstance(dtype, list): - assert isinstance(xs, list) == True - return [ensure_ma(xs_, dtype_) for (xs_, dtype_) in zip(xs, dtype)] - +def _ensure_ma(xs: Any, dtype: Optional[DType] = None) -> MaskedArrayAny: # Don't bother if we're already a masked array if ma.isMA(xs): return xs @@ -545,14 +572,26 @@ def ensure_ma(xs, dtype=None): return ma.masked_invalid(xs, copy=False) -def read_array(table=None, column=None, ranges=None): +def ensure_ma( + xs: Any, dtype: Optional[Union[DType, List[Optional[DType]]]] = None +) -> Union[List[MaskedArrayAny], MaskedArrayAny]: + if isinstance(dtype, list): + assert isinstance(xs, list) == True + return [_ensure_ma(xs_, dtype_) for (xs_, dtype_) in zip(xs, dtype)] + + return _ensure_ma(xs, dtype) + + +def read_array( + table: Optional[Table] = None, column: Optional[str] = None, ranges: Any = None +) -> Tuple[NDArrayTime, MaskedArrayAny]: if table is None: raise RuntimeError("A table is required.") if column is None: raise RuntimeError("A column is required.") - kwargs = {"column": column} + kwargs: Dict[str, Any] = {"column": column} if ranges is not None: kwargs["ranges"] = ranges @@ -573,8 +612,13 @@ def read_array(table=None, column=None, 
ranges=None): def write_array( - data=None, index=None, table=None, column=None, dtype=None, infer_types=True -): + data: Any = None, + index: Optional[NDArrayTime] = None, + table: Optional[Table] = None, + column: Optional[str] = None, + dtype: Optional[DType] = None, + infer_types: bool = True, +) -> None: """ Write a Numpy array to a single column. @@ -625,16 +669,16 @@ def write_array( # write_arrays(). cinfos = [(column, ctype)] - dtype_ = [dtype] + dtype_: List[Optional[DType]] = [dtype] - dtype = _coerce_dtype(dtype_, cinfos) + dtype_ = _coerce_dtype(dtype_, cinfos) if infer_types is True: - dtype = _add_desired_dtypes(dtype, cinfos) + dtype_ = _add_desired_dtypes(dtype_, cinfos) # data_ = an array of [data] data_ = [data] - data_ = _coerce_data(data_, dtype) + data_ = _coerce_data(data_, dtype_) _validate_dtypes(data_, cinfos) # No functions that assume array-of-data anymore, let's put it back @@ -662,27 +706,29 @@ def write_array( def write_arrays( - data, - cluster, - table=None, + data: Any, + cluster: quasardb.Cluster, + table: Optional[Union[str, Table]] = None, *, - dtype=None, - index=None, + dtype: Optional[ + Union[DType, Dict[str, Optional[DType]], List[Optional[DType]]] + ] = None, + index: Optional[NDArrayTime] = None, # TODO: Set the default push_mode after removing _async, fast and truncate - push_mode=None, - _async=False, - fast=False, - truncate=False, - truncate_range=None, - deduplicate=False, - deduplication_mode="drop", - infer_types=True, - writer=None, - write_through=True, - retries=3, + push_mode: Optional[quasardb.WriterPushMode] = None, + _async: bool = False, + fast: bool = False, + truncate: Union[bool, Tuple[Any, ...]] = False, + truncate_range: Optional[Tuple[Any, ...]] = None, + deduplicate: Union[bool, str, List[str]] = False, + deduplication_mode: str = "drop", + infer_types: bool = True, + writer: Optional[Writer] = None, + write_through: bool = True, + retries: Union[int, quasardb.RetryOptions] = 3, # We accept additional kwargs that will be passed through the writer.push() methods - **kwargs, -): + **kwargs: Any, +) -> List[Table]: """ Write multiple aligned numpy arrays to a table. @@ -841,7 +887,7 @@ def write_arrays( if kwarg_value: if push_mode and push_mode != mode: - raise quasardb.quasardb.InvalidArgumentError( + raise quasardb.InvalidArgumentError( f"Found '{kwarg}' in kwargs, but push mode is already set to {push_mode}" ) push_mode = mode @@ -857,20 +903,20 @@ def write_arrays( if writer is None: writer = cluster.writer() - ret = [] + ret: List[Table] = [] n_rows = 0 push_data = quasardb.WriterData() - for table, data_ in data: - # Acquire reference to table if string is provided - if isinstance(table, str): - table = table_cache.lookup(table, cluster) + for table_, data_ in data: + # Acquire reference to table_ if string is provided + if isinstance(table_, str): + table_ = table_cache.lookup(table_, cluster) - cinfos = [(x.name, x.type) for x in table.list_columns()] - dtype = _coerce_dtype(dtype, cinfos) + cinfos = [(x.name, x.type) for x in table_.list_columns()] + dtype_ = _coerce_dtype(dtype, cinfos) - assert type(dtype) is list - assert len(dtype) is len(cinfos) + assert type(dtype_) is list + assert len(dtype_) is len(cinfos) if index is None and isinstance(data_, dict) and "$timestamp" in data_: # Create shallow copy of `data_` so that we don't modify the reference, i.e. @@ -880,6 +926,11 @@ def write_arrays( # side-effects. 
data_ = data_.copy() index_ = data_.pop("$timestamp") + + if ma.isMA(index_): + # Index might be a masked array + index_ = index_.data + assert "$timestamp" not in data_ elif index is not None: index_ = index @@ -889,15 +940,16 @@ def write_arrays( assert index_ is not None if infer_types is True: - dtype = _add_desired_dtypes(dtype, cinfos) + dtype_ = _add_desired_dtypes(dtype_, cinfos) data_ = _ensure_list(data_, cinfos) if len(data_) != len(cinfos): raise InvalidDataCardinalityError(data_, cinfos) - data_ = ensure_ma(data_, dtype=dtype) - data_ = _coerce_data(data_, dtype) + data_ = ensure_ma(data_, dtype=dtype_) + assert isinstance(data_, list) + data_ = _coerce_data(data_, dtype_) # Just some additional friendly information about incorrect dtypes, we'd # prefer to have this information thrown from Python instead of native @@ -912,10 +964,10 @@ def write_arrays( for i in range(len(data_)): assert len(data_[i]) == len(index_) - push_data.append(table, index_, data_) + push_data.append(table_, index_, data_) n_rows += len(index_) - ret.append(table) + ret.append(table_) retries = _coerce_retries(retries) @@ -943,9 +995,13 @@ def write_arrays( return ret -def _xform_query_results(xs, index, dict): +def _xform_query_results( + xs: Sequence[Tuple[str, MaskedArrayAny]], + index: Optional[Union[str, int]], + dict: bool, +) -> Tuple[NDArrayAny, Union[Dict[str, MaskedArrayAny], List[MaskedArrayAny]]]: if len(xs) == 0: - return (np.array([], np.dtype("datetime64[ns]")), np.array([])) + return (np.array([], np.dtype("datetime64[ns]")), {} if dict else []) n = None for x in xs: @@ -959,7 +1015,11 @@ def _xform_query_results(xs, index, dict): if index is None: # Generate a range, put it in the front of the result list, # recurse and tell the function to use that index. - xs_ = [("$index", np.arange(n))] + xs + assert isinstance(n, int) + xs_: Sequence[Tuple[str, MaskedArrayAny]] = [ + ("$index", ma.masked_array(np.arange(n))) + ] + list(xs) + return _xform_query_results(xs_, "$index", dict) if isinstance(index, str): @@ -983,10 +1043,11 @@ def _xform_query_results(xs, index, dict): ) ) + assert isinstance(xs, list) idx = xs[index][1] del xs[index] - # Our index *may* be a masked array, but if it is, there should be no + # Our index *must* be a masked array, and there should be no # masked items: we cannot not have an index for a certain row. if ma.isMA(idx): if ma.count_masked(idx) > 0: @@ -997,17 +1058,18 @@ def _xform_query_results(xs, index, dict): assert isinstance(idx.data, np.ndarray) idx = idx.data - xs_ = None - - if dict is True: - xs_ = {x[0]: x[1] for x in xs} + if dict: + return idx, {x[0]: x[1] for x in xs} else: - xs_ = [x[1] for x in xs] - - return (idx, xs_) + return idx, [x[1] for x in xs] -def query(cluster, query, index=None, dict=False): +def query( + cluster: quasardb.Cluster, + query: str, + index: Optional[Union[str, int]] = None, + dict: bool = False, +) -> Tuple[NDArrayAny, Union[Dict[str, MaskedArrayAny], List[MaskedArrayAny]]]: """ Execute a query and return the results as numpy arrays. 
The shape of the return value is always: @@ -1039,7 +1101,6 @@ def query(cluster, query, index=None, dict=False): """ - m = {} xs = cluster.query_numpy(query) return _xform_query_results(xs, index, dict) diff --git a/quasardb/pandas/__init__.py b/quasardb/pandas/__init__.py index 224d0a9f..b2a7bc49 100644 --- a/quasardb/pandas/__init__.py +++ b/quasardb/pandas/__init__.py @@ -26,16 +26,18 @@ # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS # SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. # +from __future__ import annotations import logging import warnings -from datetime import datetime -from functools import partial +from datetime import timedelta +from typing import Any, Dict, Iterator, List, Optional, Tuple, Union import quasardb -import quasardb.table_cache as table_cache import quasardb.numpy as qdbnp - +import quasardb.table_cache as table_cache +from quasardb.quasardb import Cluster, Table, Writer +from quasardb.typing import DType, MaskedArrayAny, Range, RangeSet logger = logging.getLogger("quasardb.pandas") @@ -54,7 +56,7 @@ class PandasRequired(ImportError): import numpy.ma as ma import pandas as pd from pandas.core.api import DataFrame, Series - from pandas.core.base import PandasObject + from pandas.core.base import PandasObject # type: ignore[attr-defined] except ImportError: raise PandasRequired("The pandas library is required to handle pandas data formats") @@ -63,7 +65,7 @@ class PandasRequired(ImportError): # Constant mapping of numpy dtype to QuasarDB column type # TODO(leon): support this natively in qdb C api ? we have everything we need # to understand dtypes. -_dtype_map = { +_dtype_map: Dict[Any, quasardb.ColumnType] = { np.dtype("int64"): quasardb.ColumnType.Int64, np.dtype("int32"): quasardb.ColumnType.Int64, np.dtype("float64"): quasardb.ColumnType.Double, @@ -84,8 +86,13 @@ class PandasRequired(ImportError): "datetime64": quasardb.ColumnType.Timestamp, } +# Type hint for TableLike parameter +TableLike = Union[str, Table] + -def read_series(table, col_name, ranges=None): +def read_series( + table: Table, col_name: str, ranges: Optional[RangeSet] = None +) -> pd.Series: """ Read a Pandas Timeseries from a single column. @@ -110,7 +117,7 @@ def read_series(table, col_name, ranges=None): quasardb.ColumnType.Symbol: table.string_get_ranges, } - kwargs = {"column": col_name} + kwargs: Dict[str, Any] = {"column": col_name} if ranges is not None: kwargs["ranges"] = ranges @@ -124,10 +131,16 @@ def read_series(table, col_name, ranges=None): res = (read_with[t])(**kwargs) - return Series(res[1], index=res[0]) + return pd.Series(res[1], index=res[0]) -def write_series(series, table, col_name, infer_types=True, dtype=None): +def write_series( + series: pd.Series, + table: Table, + col_name: str, + infer_types: bool = True, + dtype: Optional[DType] = None, +) -> None: """ Writes a Pandas Timeseries to a single column. @@ -177,12 +190,12 @@ def write_series(series, table, col_name, infer_types=True, dtype=None): def query( - cluster: quasardb.Cluster, + cluster: Cluster, query: str, - index: str = None, + index: Optional[str] = None, blobs: bool = False, numpy: bool = True, -): +) -> pd.DataFrame: """ Execute *query* and return the result as a pandas DataFrame. 
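Since `query()` now declares `-> pd.DataFrame` and takes `index: Optional[str]`, here is a minimal sketch of the typed call site; the cluster URI, table name and query text are placeholders and not part of this changeset:

```python
import quasardb
import quasardb.pandas as qdbpd

# Hypothetical local cluster and table; any existing table works the same way.
with quasardb.Cluster("qdb://127.0.0.1:2836") as conn:
    # index= is Optional[str]; when omitted, the default index is used.
    df = qdbpd.query(conn, "SELECT * FROM stocks LIMIT 5")
    print(df.dtypes)  # mypy now sees df as a pandas DataFrame
```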
@@ -229,13 +242,13 @@ def query( def stream_dataframes( - conn: quasardb.Cluster, - tables: list, + conn: Cluster, + tables: List[TableLike], *, - batch_size: int = 2**16, - column_names: list = None, - ranges: list = None, -): + batch_size: Optional[int] = 2**16, + column_names: Optional[List[str]] = None, + ranges: Optional[RangeSet] = None, +) -> Iterator[pd.DataFrame]: """ Read a Pandas Dataframe from a QuasarDB Timeseries table. Returns a generator with dataframes of size `batch_size`, which is useful when traversing a large dataset which does not fit into memory. @@ -267,7 +280,7 @@ def stream_dataframes( """ # Sanitize batch_size - if batch_size == None: + if batch_size is None: batch_size = 2**16 elif not isinstance(batch_size, int): raise TypeError( @@ -276,7 +289,7 @@ def stream_dataframes( ) ) - kwargs = {"batch_size": batch_size} + kwargs: Dict[str, Any] = {"batch_size": batch_size} if column_names: kwargs["column_names"] = column_names @@ -298,22 +311,43 @@ def stream_dataframes( yield df -def stream_dataframe(conn: quasardb.Cluster, table, **kwargs): +def stream_dataframe( + conn: Cluster, + table: TableLike, + *, + batch_size: Optional[int] = 2**16, + column_names: Optional[List[str]] = None, + ranges: Optional[RangeSet] = None, +) -> Iterator[pd.DataFrame]: """ Read a single table and return a stream of dataframes. This is a convenience function that wraps around `stream_dataframes`. """ - kwargs["tables"] = [table] - # For backwards compatibility, we drop the `$table` column returned: this is not strictly # necessary, but it also is somewhat reasonable to drop it when we're reading from a single # table, which is the case here. clean_df_fn = lambda df: df.drop(columns=["$table"]) - return (clean_df_fn(df) for df in stream_dataframes(conn, **kwargs)) + return ( + clean_df_fn(df) + for df in stream_dataframes( + conn, + [table], + batch_size=batch_size, + column_names=column_names, + ranges=ranges, + ) + ) -def read_dataframe(conn: quasardb.Cluster, table, **kwargs): +def read_dataframe( + conn: Cluster, + table: TableLike, + *, + batch_size: Optional[int] = 2**16, + column_names: Optional[List[str]] = None, + ranges: Optional[RangeSet] = None, +) -> pd.DataFrame: """ Read a Pandas Dataframe from a QuasarDB Timeseries table. Wraps around stream_dataframes(), and returns everything as a single dataframe. batch_size is always explicitly set to 0. @@ -331,29 +365,27 @@ def read_dataframe(conn: quasardb.Cluster, table, **kwargs): """ - if ( - "batch_size" in kwargs - and kwargs["batch_size"] != 0 - and kwargs["batch_size"] != None - ): - logger.warn( + if batch_size is not None and batch_size != 0: + logger.warning( "Providing a batch size with read_dataframe is unsupported, overriding batch_size to 65536." ) - logger.warn( + logger.warning( "If you wish to traverse the data in smaller batches, please use: stream_dataframe()." ) - kwargs["batch_size"] = 2**16 + batch_size = 2**16 # Note that this is *lazy*, dfs is a generator, not a list -- as such, dataframes will be # fetched on-demand, which means that an error could occur in the middle of processing # dataframes. 
- dfs = stream_dataframe(conn, table, **kwargs) + dfs = stream_dataframe( + conn, table, batch_size=batch_size, column_names=column_names, ranges=ranges + ) # if result of stream_dataframe is empty this could result in ValueError on pd.concat() # as stream_dataframe is a generator there is no easy way to check for this condition without evaluation # the most simple way is to catch the ValueError and return an empty DataFrame try: - return pd.concat(dfs, copy=False) + return pd.concat(dfs, copy=False) # type: ignore[call-overload] except ValueError as e: logger.error( "Error while concatenating dataframes. This can happen if result set is empty. Returning empty dataframe. Error: %s", @@ -362,7 +394,9 @@ def read_dataframe(conn: quasardb.Cluster, table, **kwargs): return pd.DataFrame() -def _extract_columns(df, cinfos): +def _extract_columns( + df: pd.DataFrame, cinfos: List[Tuple[str, quasardb.ColumnType]] +) -> Dict[str, MaskedArrayAny]: """ Converts dataframe to a number of numpy arrays, one for each column. @@ -370,13 +404,12 @@ def _extract_columns(df, cinfos): If a table column is not present in the dataframe, it it have a None entry. If a dataframe column is not present in the table, it will be ommitted. """ - ret = {} + ret: Dict[str, MaskedArrayAny] = {} # Grab all columns from the DataFrame in the order of table columns, # put None if not present in df. for i in range(len(cinfos)): - (cname, ctype) = cinfos[i] - xs = None + (cname, _) = cinfos[i] if cname in df.columns: arr = df[cname].array @@ -385,7 +418,32 @@ def _extract_columns(df, cinfos): return ret -def write_dataframes(dfs, cluster, *, create=False, shard_size=None, **kwargs): +def write_dataframes( + dfs: Union[ + Dict[TableLike, pd.DataFrame], + List[tuple[TableLike, pd.DataFrame]], + ], + cluster: quasardb.Cluster, + *, + create: bool = False, + shard_size: Optional[timedelta] = None, + # numpy.write_arrays passthrough options + dtype: Optional[ + Union[DType, Dict[str, Optional[DType]], List[Optional[DType]]] + ] = None, + push_mode: Optional[quasardb.WriterPushMode] = None, + _async: bool = False, + fast: bool = False, + truncate: Union[bool, Range] = False, + truncate_range: Optional[Range] = None, + deduplicate: Union[bool, str, List[str]] = False, + deduplication_mode: str = "drop", + infer_types: bool = True, + writer: Optional[Writer] = None, + write_through: bool = True, + retries: Union[int, quasardb.RetryOptions] = 3, + **kwargs: Any, +) -> List[Table]: """ Store dataframes into a table. Any additional parameters not documented here are passed to numpy.write_arrays(). Please consult the pydoc of that function @@ -410,7 +468,7 @@ def write_dataframes(dfs, cluster, *, create=False, shard_size=None, **kwargs): # If dfs is a dict, we convert it to a list of tuples. if isinstance(dfs, dict): - dfs = dfs.items() + dfs = list(dfs.items()) if shard_size is not None and create == False: raise ValueError("Invalid argument: shard size provided while create is False") @@ -436,7 +494,7 @@ def write_dataframes(dfs, cluster, *, create=False, shard_size=None, **kwargs): cinfos = [(x.name, x.type) for x in table.list_columns()] if not df.index.is_monotonic_increasing: - logger.warn( + logger.warning( "dataframe index is unsorted, resorting dataframe based on index" ) df = df.sort_index().reindex() @@ -447,36 +505,141 @@ def write_dataframes(dfs, cluster, *, create=False, shard_size=None, **kwargs): # is sparse, most notably forcing sparse integer arrays to floating points. 
data = _extract_columns(df, cinfos) - data["$timestamp"] = df.index.to_numpy(copy=False, dtype="datetime64[ns]") + data["$timestamp"] = ma.masked_array( + df.index.to_numpy(copy=False, dtype="datetime64[ns]") + ) # We cast to masked_array to enforce typing compliance data_by_table.append((table, data)) kwargs["deprecation_stacklevel"] = kwargs.get("deprecation_stacklevel", 1) + 1 - return qdbnp.write_arrays(data_by_table, cluster, table=None, index=None, **kwargs) + return qdbnp.write_arrays( + data_by_table, + cluster, + table=None, + index=None, + dtype=dtype, + push_mode=push_mode, + _async=_async, + fast=fast, + truncate=truncate, + truncate_range=truncate_range, + deduplicate=deduplicate, + deduplication_mode=deduplication_mode, + infer_types=infer_types, + writer=writer, + write_through=write_through, + retries=retries, + **kwargs, + ) -def write_dataframe(df, cluster, table, **kwargs): +def write_dataframe( + df: pd.DataFrame, + cluster: quasardb.Cluster, + table: TableLike, + *, + create: bool = False, + shard_size: Optional[timedelta] = None, + # numpy.write_arrays passthrough options + dtype: Optional[ + Union[DType, Dict[str, Optional[DType]], List[Optional[DType]]] + ] = None, + push_mode: Optional[quasardb.WriterPushMode] = None, + _async: bool = False, + fast: bool = False, + truncate: Union[bool, Range] = False, + truncate_range: Optional[Range] = None, + deduplicate: Union[bool, str, List[str]] = False, + deduplication_mode: str = "drop", + infer_types: bool = True, + writer: Optional[Writer] = None, + write_through: bool = True, + retries: Union[int, quasardb.RetryOptions] = 3, + **kwargs: Any, +) -> List[Table]: """ Store a single dataframe into a table. Takes the same arguments as `write_dataframes`, except only a single df/table combination. """ kwargs["deprecation_stacklevel"] = kwargs.get("deprecation_stacklevel", 1) + 1 - write_dataframes([(table, df)], cluster, **kwargs) + return write_dataframes( + [(table, df)], + cluster, + create=create, + shard_size=shard_size, + dtype=dtype, + push_mode=push_mode, + _async=_async, + fast=fast, + truncate=truncate, + truncate_range=truncate_range, + deduplicate=deduplicate, + deduplication_mode=deduplication_mode, + infer_types=infer_types, + writer=writer, + write_through=write_through, + retries=retries, + **kwargs, + ) -def write_pinned_dataframe(*args, **kwargs): +def write_pinned_dataframe( + df: pd.DataFrame, + cluster: quasardb.Cluster, + table: TableLike, + *, + create: bool = False, + shard_size: Optional[timedelta] = None, + # numpy.write_arrays passthrough options + dtype: Optional[ + Union[DType, Dict[str, Optional[DType]], List[Optional[DType]]] + ] = None, + push_mode: Optional[quasardb.WriterPushMode] = None, + _async: bool = False, + fast: bool = False, + truncate: Union[bool, Range] = False, + truncate_range: Optional[Range] = None, + deduplicate: Union[bool, str, List[str]] = False, + deduplication_mode: str = "drop", + infer_types: bool = True, + writer: Optional[Writer] = None, + write_through: bool = True, + retries: Union[int, quasardb.RetryOptions] = 3, + **kwargs: Any, +) -> List[Table]: """ Legacy wrapper around write_dataframe() """ - logger.warn( + logger.warning( "write_pinned_dataframe is deprecated and will be removed in a future release." 
     )
-    logger.warn("Please use write_dataframe directly instead")
+    logger.warning("Please use write_dataframe directly instead")
     kwargs["deprecation_stacklevel"] = 2
-    return write_dataframe(*args, **kwargs)
+    return write_dataframe(
+        df,
+        cluster,
+        table,
+        create=create,
+        shard_size=shard_size,
+        dtype=dtype,
+        push_mode=push_mode,
+        _async=_async,
+        fast=fast,
+        truncate=truncate,
+        truncate_range=truncate_range,
+        deduplicate=deduplicate,
+        deduplication_mode=deduplication_mode,
+        infer_types=infer_types,
+        writer=writer,
+        write_through=write_through,
+        retries=retries,
+        **kwargs,
+    )


-def _create_table_from_df(df, table, shard_size=None):
+def _create_table_from_df(
+    df: pd.DataFrame, table: Table, shard_size: Optional[timedelta] = None
+) -> Table:
     cols = list()
     dtypes = _get_inferred_dtypes(df)
@@ -498,14 +661,14 @@ def _create_table_from_df(df, table, shard_size=None):
             table.create(cols)
         else:
             table.create(cols, shard_size)
-    except quasardb.quasardb.AliasAlreadyExistsError:
+    except quasardb.AliasAlreadyExistsError:
         # TODO(leon): warn? how?
         pass

     return table


-def _dtype_to_column_type(dt, inferred):
+def _dtype_to_column_type(dt: Any, inferred: Any) -> quasardb.ColumnType:
     res = _dtype_map.get(inferred, None)
     if res is None:
         res = _dtype_map.get(dt, None)
@@ -516,8 +679,8 @@ def _dtype_to_column_type(dt, inferred):
     return res


-def _get_inferred_dtypes(df):
-    dtypes = dict()
+def _get_inferred_dtypes(df: pd.DataFrame) -> Dict[str, str]:
+    dtypes = {}
     for i in range(len(df.columns)):
         c = df.columns[i]
         dt = pd.api.types.infer_dtype(df[c].values)
@@ -526,7 +689,7 @@ def _get_inferred_dtypes(df):
     return dtypes


-def _get_inferred_dtypes_indexed(df):
+def _get_inferred_dtypes_indexed(df: pd.DataFrame) -> List[str]:
     dtypes = _get_inferred_dtypes(df)
     # Performance improvement: avoid a expensive dict lookups by indexing
     # the column types by relative offset within the df.
diff --git a/quasardb/pool.py b/quasardb/pool.py
index d7994c17..59b2c189 100644
--- a/quasardb/pool.py
+++ b/quasardb/pool.py
@@ -1,22 +1,30 @@
+from __future__ import annotations
+
+import functools
 import logging
-import quasardb
 import threading
-import functools
 import weakref
+from types import TracebackType
+from typing import Any, Callable, Optional, Type, TypeVar, Union
+
+# import quasardb
+from quasardb import Cluster

 logger = logging.getLogger("quasardb.pool")

+T = TypeVar("T")
+

-def _create_conn(**kwargs):
-    return quasardb.Cluster(**kwargs)
+def _create_conn(**kwargs: Any) -> Cluster:
+    return Cluster(**kwargs)


-class SessionWrapper(object):
-    def __init__(self, pool, conn):
+class SessionWrapper:
+    def __init__(self, pool: Pool, conn: Cluster):
         self._conn = conn
         self._pool = pool

-    def __getattr__(self, attr):
+    def __getattr__(self, attr: Any) -> Any:
         # This hack copies all the quasardb.Cluster() properties, functions and
         # whatnot, and pretends that this class is actually a quasardb.Cluster.
         #
@@ -41,14 +49,19 @@ def __getattr__(self, attr):
         return getattr(self._conn, attr)

-    def __enter__(self):
+    def __enter__(self: T) -> T:
         return self

-    def __exit__(self, exc_type, exc_val, exc_tb):
+    def __exit__(
+        self,
+        exc_type: Optional[Type[BaseException]],
+        exc_val: Optional[BaseException],
+        exc_tb: Optional[TracebackType],
+    ) -> None:
         self._pool.release(self._conn)


-class Pool(object):
+class Pool:
     """
     A connection pool. This class should not be initialized directly, but
     rather the subclass `SingletonPool` should be initialized.
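The hunks below also retype `Pool.connect()` to return `SessionWrapper` rather than `quasardb.Cluster`, which matches what callers actually receive. A minimal sketch of the typed pool flow, with the URI and query text as placeholders (not taken from this diff):

```python
import quasardb.pool as pool

# Assumed: initialize() forwards its kwargs to quasardb.Cluster(), as _create_conn() above does.
pool.initialize(uri="qdb://127.0.0.1:2836")

# instance() -> SingletonPool, connect() -> SessionWrapper; __exit__ releases the
# connection back to the pool, so the with-block covers the whole lifecycle.
with pool.instance().connect() as conn:
    rows = conn.query("SELECT * FROM stocks LIMIT 5")  # forwarded to the wrapped Cluster
```

Because `SessionWrapper.__getattr__` is typed `-> Any`, the forwarded `Cluster` methods remain unchecked by mypy; that is the trade-off of keeping the duck-typed wrapper.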
@@ -88,8 +101,10 @@ def my_qdb_connection_create(): """ - def __init__(self, get_conn=None, **kwargs): - self._all_connections = [] + def __init__( + self, get_conn: Optional[Callable[..., Cluster]] = None, **kwargs: Any + ): + self._all_connections: list[SessionWrapper] = [] if get_conn is None: get_conn = functools.partial(_create_conn, **kwargs) @@ -99,16 +114,21 @@ def __init__(self, get_conn=None, **kwargs): self._get_conn = get_conn - def __enter__(self): + def __enter__(self: T) -> T: return self - def __exit__(self, exc_type, exc_val, exc_tb): + def __exit__( + self, + exc_type: Optional[Type[BaseException]], + exc_val: Optional[BaseException], + exc_tb: Optional[TracebackType], + ) -> None: self.close() - def _create_conn(self): + def _create_conn(self) -> SessionWrapper: return SessionWrapper(self, self._get_conn()) - def close(self): + def close(self) -> None: """ Close this connection pool, and all associated connections. This function is automatically invoked when used in a with-block or when using the global @@ -118,10 +138,10 @@ def close(self): logger.debug("closing connection {}".format(conn)) conn.close() - def _do_connect(self): + def _do_connect(self) -> SessionWrapper: raise NotImplementedError - def connect(self) -> quasardb.Cluster: + def connect(self) -> SessionWrapper: """ Acquire a new connection from the pool. Returned connection must either be explicitly released using `pool.release()`, or should be wrapped in a @@ -134,10 +154,10 @@ def connect(self) -> quasardb.Cluster: logger.info("Acquiring connection from pool") return self._do_connect() - def _do_release(self): + def _do_release(self, conn: Cluster) -> None: raise NotImplementedError - def release(self, conn: quasardb.Cluster): + def release(self, conn: Cluster) -> None: """ Put a connection back into the pool """ @@ -175,11 +195,11 @@ class SingletonPool(Pool): ``` """ - def __init__(self, **kwargs): + def __init__(self, **kwargs: Any): Pool.__init__(self, **kwargs) self._conn = threading.local() - def _do_connect(self): + def _do_connect(self) -> SessionWrapper: try: c = self._conn.current() if c: @@ -193,7 +213,7 @@ def _do_connect(self): return c - def _do_release(self, conn): + def _do_release(self, conn: Cluster) -> None: # Thread-local connections do not have to be 'released'. pass @@ -201,7 +221,7 @@ def _do_release(self, conn): __instance = None -def initialize(*args, **kwargs): +def initialize(*args: Any, **kwargs: Any) -> None: """ Initialize a new global SingletonPool. Forwards all arguments to the constructor of `SingletonPool()`. @@ -244,7 +264,12 @@ def instance() -> SingletonPool: return __instance -def _inject_conn_arg(conn, arg, args, kwargs): +def _inject_conn_arg( + conn: SessionWrapper, + arg: Union[str, int], + args: tuple[Any, ...], + kwargs: dict[str, Any], +) -> tuple[tuple[Any, ...], dict[str, Any]]: """ Decorator utility function. Takes the argument provided to the decorator that configures how we should inject the pool into the args to the callback @@ -256,9 +281,9 @@ def _inject_conn_arg(conn, arg, args, kwargs): # # Because positional args are always a tuple, and tuples don't have an # easy 'insert into position' function, we just cast to and from a list. 
- args = list(args) - args.insert(arg, conn) - args = tuple(args) + args_list = list(args) + args_list.insert(arg, conn) + args = tuple(args_list) else: assert isinstance(arg, str) == True # If not a number, we assume it's a kwarg, which makes things easier @@ -267,7 +292,9 @@ def _inject_conn_arg(conn, arg, args, kwargs): return (args, kwargs) -def with_conn(_fn=None, *, arg=0): +def with_conn( + _fn: Optional[Callable[..., Any]] = None, *, arg: Union[str, int] = 0 +) -> Callable[..., Any]: """ Decorator function that handles connection assignment, release and invocation for you. Should be used in conjuction with the global singleton accessor, see also: `initialize()`. @@ -295,8 +322,8 @@ def myfunction(arg1, arg2, conn=None): ``` """ - def inner(fn): - def wrapper(*args, **kwargs): + def inner(fn: Callable[..., Any]) -> Callable[..., Any]: + def wrapper(*args: Any, **kwargs: Any) -> Any: pool = instance() with pool.connect() as conn: diff --git a/quasardb/quasardb/_batch_inserter.pyi b/quasardb/quasardb/_batch_inserter.pyi index f1fb6a23..fdb1a89f 100644 --- a/quasardb/quasardb/_batch_inserter.pyi +++ b/quasardb/quasardb/_batch_inserter.pyi @@ -1,3 +1,5 @@ +from typing import Any + class TimeSeriesBatch: def push(self) -> None: """ @@ -14,7 +16,7 @@ class TimeSeriesBatch: Fast, in-place batch push that is efficient when doing lots of small, incremental pushes. """ - def push_truncate(self, **kwargs) -> None: + def push_truncate(self, **kwargs: Any) -> None: """ Before inserting data, truncates any existing data. This is useful when you want your insertions to be idempotent, e.g. in case of a retry. """ @@ -23,8 +25,8 @@ class TimeSeriesBatch: def set_double(self, index: int, double: float) -> None: ... def set_int64(self, index: int, int64: int) -> None: ... def set_string(self, index: int, string: str) -> None: ... - def set_timestamp(self, index: int, timestamp: object) -> None: ... - def start_row(self, ts: object) -> None: + def set_timestamp(self, index: int, timestamp: Any) -> None: ... + def start_row(self, ts: Any) -> None: """ Calling this function marks the beginning of processing a new row. """ diff --git a/quasardb/quasardb/_cluster.pyi b/quasardb/quasardb/_cluster.pyi index 095b3a41..f39d54a3 100644 --- a/quasardb/quasardb/_cluster.pyi +++ b/quasardb/quasardb/_cluster.pyi @@ -1,9 +1,10 @@ from __future__ import annotations import datetime +from types import TracebackType +from typing import Any, Optional, Type -import numpy as np - +from ..typing import MaskedArrayAny, RangeSet from ._batch_column import BatchColumnInfo from ._batch_inserter import TimeSeriesBatch from ._blob import Blob @@ -27,7 +28,12 @@ class Cluster: """ def __enter__(self) -> Cluster: ... - def __exit__(self, exc_type: object, exc_value: object, exc_tb: object) -> None: ... + def __exit__( + self, + exc_type: Optional[Type[BaseException]], + exc_value: Optional[BaseException], + exc_tb: Optional[TracebackType], + ) -> None: ... def __init__( self, uri: str, @@ -39,7 +45,7 @@ class Cluster: timeout: datetime.timedelta = datetime.timedelta(minutes=1), do_version_check: bool = False, enable_encryption: bool = False, - compression_mode: Options.Compression = ..., # balanced + compression_mode: Options.Compression = Options.Compression.Balanced, client_max_parallelism: int = 0, ) -> None: ... def blob(self, alias: str) -> Blob: ... @@ -55,9 +61,9 @@ class Cluster: def integer(self, alias: str) -> Integer: ... def is_open(self) -> bool: ... def node(self, uri: str) -> Node: ... 
- def node_config(self, uri: str) -> dict[str, object]: ... - def node_status(self, uri: str) -> dict[str, object]: ... - def node_topology(self, uri: str) -> dict[str, object]: ... + def node_config(self, uri: str) -> dict[str, Any]: ... + def node_status(self, uri: str) -> dict[str, Any]: ... + def node_topology(self, uri: str) -> dict[str, Any]: ... def options(self) -> Options: ... def perf(self) -> Perf: ... def pinned_writer(self) -> Writer: ... @@ -68,20 +74,20 @@ class Cluster: def purge_cache(self, timeout: datetime.timedelta) -> None: ... def query( self, query: str, blobs: bool | list[str] = False - ) -> list[dict[str, object]]: ... + ) -> list[dict[str, Any]]: ... def query_continuous_full( self, query: str, pace: datetime.timedelta, blobs: bool | list[str] = False ) -> QueryContinuous: ... def query_continuous_new_values( self, query: str, pace: datetime.timedelta, blobs: bool | list[str] = False ) -> QueryContinuous: ... - def query_numpy(self, query: str) -> list[tuple[str, np.ma.MaskedArray]]: ... + def query_numpy(self, query: str) -> list[tuple[str, MaskedArrayAny]]: ... def reader( self, table_names: list[str], column_names: list[str] = [], batch_size: int = 0, - ranges: list[tuple] = [], + ranges: RangeSet = [], ) -> Reader: ... def string(self, alias: str) -> String: ... def suffix_count(self, suffix: str) -> int: ... diff --git a/quasardb/quasardb/_continuous.pyi b/quasardb/quasardb/_continuous.pyi index fe4c026b..6f9f7e87 100644 --- a/quasardb/quasardb/_continuous.pyi +++ b/quasardb/quasardb/_continuous.pyi @@ -1,12 +1,14 @@ from __future__ import annotations +from typing import Any + # import datetime class QueryContinuous: def __iter__(self) -> QueryContinuous: ... - def __next__(self) -> list[dict[str, object]]: ... - def probe_results(self) -> list[dict[str, object]]: ... - def results(self) -> list[dict[str, object]]: ... + def __next__(self) -> list[dict[str, Any]]: ... + def probe_results(self) -> list[dict[str, Any]]: ... + def results(self) -> list[dict[str, Any]]: ... # def run( # self, # mode: qdb_query_continuous_mode_type_t, diff --git a/quasardb/quasardb/_entry.pyi b/quasardb/quasardb/_entry.pyi index c5c80710..d1a11188 100644 --- a/quasardb/quasardb/_entry.pyi +++ b/quasardb/quasardb/_entry.pyi @@ -1,6 +1,7 @@ from __future__ import annotations import datetime +from typing import Any class Entry: class Metadata: ... @@ -15,27 +16,27 @@ class Entry: Stream: Entry.Type # value = Timeseries: Entry.Type # value = __members__: dict[str, Entry.Type] - def __and__(self, other: object) -> object: ... - def __eq__(self, other: object) -> bool: ... - def __ge__(self, other: object) -> bool: ... + def __and__(self, other: Any) -> Any: ... + def __eq__(self, other: Any) -> bool: ... + def __ge__(self, other: Any) -> bool: ... def __getstate__(self) -> int: ... - def __gt__(self, other: object) -> bool: ... + def __gt__(self, other: Any) -> bool: ... def __hash__(self) -> int: ... def __index__(self) -> int: ... def __init__(self, value: int) -> None: ... def __int__(self) -> int: ... - def __invert__(self) -> object: ... - def __le__(self, other: object) -> bool: ... - def __lt__(self, other: object) -> bool: ... - def __ne__(self, other: object) -> bool: ... - def __or__(self, other: object) -> object: ... - def __rand__(self, other: object) -> object: ... + def __invert__(self) -> Any: ... + def __le__(self, other: Any) -> bool: ... + def __lt__(self, other: Any) -> bool: ... + def __ne__(self, other: Any) -> bool: ... + def __or__(self, other: Any) -> Any: ... 
+ def __rand__(self, other: Any) -> Any: ... def __repr__(self) -> str: ... - def __ror__(self, other: object) -> object: ... - def __rxor__(self, other: object) -> object: ... + def __ror__(self, other: Any) -> Any: ... + def __rxor__(self, other: Any) -> Any: ... def __setstate__(self, state: int) -> None: ... def __str__(self) -> str: ... - def __xor__(self, other: object) -> object: ... + def __xor__(self, other: Any) -> Any: ... @property def name(self) -> str: ... @property diff --git a/quasardb/quasardb/_options.pyi b/quasardb/quasardb/_options.pyi index 764dc90b..fa6d7eaf 100644 --- a/quasardb/quasardb/_options.pyi +++ b/quasardb/quasardb/_options.pyi @@ -1,6 +1,7 @@ from __future__ import annotations import datetime +from typing import Any class Options: class Compression: @@ -8,27 +9,27 @@ class Options: Best: Options.Compression # value = Balanced: Options.Compression # value = __members__: dict[str, Options.Compression] - def __and__(self, other: object) -> object: ... - def __eq__(self, other: object) -> bool: ... - def __ge__(self, other: object) -> bool: ... + def __and__(self, other: Any) -> Any: ... + def __eq__(self, other: Any) -> bool: ... + def __ge__(self, other: Any) -> bool: ... def __getstate__(self) -> int: ... - def __gt__(self, other: object) -> bool: ... + def __gt__(self, other: Any) -> bool: ... def __hash__(self) -> int: ... def __index__(self) -> int: ... def __init__(self, value: int) -> None: ... def __int__(self) -> int: ... - def __invert__(self) -> object: ... - def __le__(self, other: object) -> bool: ... - def __lt__(self, other: object) -> bool: ... - def __ne__(self, other: object) -> bool: ... - def __or__(self, other: object) -> object: ... - def __rand__(self, other: object) -> object: ... + def __invert__(self) -> Any: ... + def __le__(self, other: Any) -> bool: ... + def __lt__(self, other: Any) -> bool: ... + def __ne__(self, other: Any) -> bool: ... + def __or__(self, other: Any) -> Any: ... + def __rand__(self, other: Any) -> Any: ... def __repr__(self) -> str: ... - def __ror__(self, other: object) -> object: ... - def __rxor__(self, other: object) -> object: ... + def __ror__(self, other: Any) -> Any: ... + def __rxor__(self, other: Any) -> Any: ... def __setstate__(self, state: int) -> None: ... def __str__(self) -> str: ... - def __xor__(self, other: object) -> object: ... + def __xor__(self, other: Any) -> Any: ... @property def name(self) -> str: ... @property @@ -38,27 +39,27 @@ class Options: Disabled: Options.Encryption # value = AES256GCM: Options.Encryption # value = __members__: dict[str, Options.Encryption] - def __and__(self, other: object) -> object: ... - def __eq__(self, other: object) -> bool: ... - def __ge__(self, other: object) -> bool: ... + def __and__(self, other: Any) -> Any: ... + def __eq__(self, other: Any) -> bool: ... + def __ge__(self, other: Any) -> bool: ... def __getstate__(self) -> int: ... - def __gt__(self, other: object) -> bool: ... + def __gt__(self, other: Any) -> bool: ... def __hash__(self) -> int: ... def __index__(self) -> int: ... def __init__(self, value: int) -> None: ... def __int__(self) -> int: ... - def __invert__(self) -> object: ... - def __le__(self, other: object) -> bool: ... - def __lt__(self, other: object) -> bool: ... - def __ne__(self, other: object) -> bool: ... - def __or__(self, other: object) -> object: ... - def __rand__(self, other: object) -> object: ... + def __invert__(self) -> Any: ... + def __le__(self, other: Any) -> bool: ... 
+ def __lt__(self, other: Any) -> bool: ... + def __ne__(self, other: Any) -> bool: ... + def __or__(self, other: Any) -> Any: ... + def __rand__(self, other: Any) -> Any: ... def __repr__(self) -> str: ... - def __ror__(self, other: object) -> object: ... - def __rxor__(self, other: object) -> object: ... + def __ror__(self, other: Any) -> Any: ... + def __rxor__(self, other: Any) -> Any: ... def __setstate__(self, state: int) -> None: ... def __str__(self) -> str: ... - def __xor__(self, other: object) -> object: ... + def __xor__(self, other: Any) -> Any: ... @property def name(self) -> str: ... @property diff --git a/quasardb/quasardb/_perf.pyi b/quasardb/quasardb/_perf.pyi index dd26ec93..7a3af57e 100644 --- a/quasardb/quasardb/_perf.pyi +++ b/quasardb/quasardb/_perf.pyi @@ -1,5 +1,7 @@ +from typing import Any + class Perf: def clear(self) -> None: ... def disable(self) -> None: ... def enable(self) -> None: ... - def get(self, flame: bool = False, outfile: str = "") -> object: ... + def get(self, flame: bool = False, outfile: str = "") -> Any: ... diff --git a/quasardb/quasardb/_reader.pyi b/quasardb/quasardb/_reader.pyi index fd0ad7e8..c52d6512 100644 --- a/quasardb/quasardb/_reader.pyi +++ b/quasardb/quasardb/_reader.pyi @@ -1,9 +1,15 @@ from __future__ import annotations -import typing +from types import TracebackType +from typing import Any, Iterator, Optional, Type class Reader: def __enter__(self) -> Reader: ... - def __exit__(self, exc_type: object, exc_value: object, exc_tb: object) -> None: ... - def __iter__(self) -> typing.Iterator[dict[str, object]]: ... + def __exit__( + self, + exc_type: Optional[Type[BaseException]], + exc_value: Optional[BaseException], + exc_tb: Optional[TracebackType], + ) -> None: ... + def __iter__(self) -> Iterator[dict[str, Any]]: ... def get_batch_size(self) -> int: ... diff --git a/quasardb/quasardb/_table.pyi b/quasardb/quasardb/_table.pyi index ddd5831f..00c792c4 100644 --- a/quasardb/quasardb/_table.pyi +++ b/quasardb/quasardb/_table.pyi @@ -2,10 +2,10 @@ from __future__ import annotations import datetime import typing - -import numpy +from typing import Any, Optional, Union from quasardb.quasardb._reader import Reader +from quasardb.typing import MaskedArrayAny, NDArrayAny, NDArrayTime, RangeSet from ._entry import Entry @@ -18,27 +18,27 @@ class ColumnType: String: ColumnType # value = Symbol: ColumnType # value = __members__: dict[str, ColumnType] - def __and__(self, other: object) -> object: ... - def __eq__(self, other: object) -> bool: ... - def __ge__(self, other: object) -> bool: ... + def __and__(self, other: Any) -> Any: ... + def __eq__(self, other: Any) -> bool: ... + def __ge__(self, other: Any) -> bool: ... def __getstate__(self) -> int: ... - def __gt__(self, other: object) -> bool: ... + def __gt__(self, other: Any) -> bool: ... def __hash__(self) -> int: ... def __index__(self) -> int: ... def __init__(self, value: int) -> None: ... def __int__(self) -> int: ... - def __invert__(self) -> object: ... - def __le__(self, other: object) -> bool: ... - def __lt__(self, other: object) -> bool: ... - def __ne__(self, other: object) -> bool: ... - def __or__(self, other: object) -> object: ... - def __rand__(self, other: object) -> object: ... + def __invert__(self) -> Any: ... + def __le__(self, other: Any) -> bool: ... + def __lt__(self, other: Any) -> bool: ... + def __ne__(self, other: Any) -> bool: ... + def __or__(self, other: Any) -> Any: ... + def __rand__(self, other: Any) -> Any: ... def __repr__(self) -> str: ... 
- def __ror__(self, other: object) -> object: ... - def __rxor__(self, other: object) -> object: ... + def __ror__(self, other: Any) -> Any: ... + def __rxor__(self, other: Any) -> Any: ... def __setstate__(self, state: int) -> None: ... def __str__(self) -> str: ... - def __xor__(self, other: object) -> object: ... + def __xor__(self, other: Any) -> Any: ... @property def name(self) -> str: ... @property @@ -69,10 +69,13 @@ class IndexedColumnInfo: class Table(Entry): def __repr__(self) -> str: ... def blob_get_ranges( - self, column: str, ranges: object = None - ) -> tuple[numpy.ndarray, numpy.ma.MaskedArray]: ... + self, column: str, ranges: Optional[RangeSet] = None + ) -> tuple[NDArrayTime, MaskedArrayAny]: ... def blob_insert( - self, column: str, timestamps: numpy.ndarray, values: numpy.ma.MaskedArray + self, + column: str, + timestamps: NDArrayTime, + values: Union[MaskedArrayAny, NDArrayAny], ) -> None: ... def column_id_by_index(self, index: int) -> str: ... def column_index_by_id(self, alias: str) -> int: ... @@ -86,40 +89,52 @@ class Table(Entry): ttl: datetime.timedelta = ..., ) -> None: ... def double_get_ranges( - self, column: str, ranges: object = None - ) -> tuple[numpy.ndarray, numpy.ma.MaskedArray]: ... + self, column: str, ranges: Optional[RangeSet] = None + ) -> tuple[NDArrayTime, MaskedArrayAny]: ... def double_insert( - self, column: str, timestamps: numpy.ndarray, values: numpy.ma.MaskedArray + self, + column: str, + timestamps: NDArrayTime, + values: Union[MaskedArrayAny, NDArrayAny], ) -> None: ... - def erase_ranges(self, column: str, ranges: object) -> int: ... + def erase_ranges(self, column: str, ranges: RangeSet) -> int: ... def get_shard_size(self) -> datetime.timedelta: ... def get_ttl(self) -> datetime.timedelta: ... def has_ttl(self) -> bool: ... def insert_columns(self, columns: list[ColumnInfo]) -> None: ... def int64_get_ranges( - self, column: str, ranges: object = None - ) -> tuple[numpy.ndarray, numpy.ma.MaskedArray]: ... + self, column: str, ranges: Optional[RangeSet] = None + ) -> tuple[NDArrayTime, MaskedArrayAny]: ... def int64_insert( - self, column: str, timestamps: numpy.ndarray, values: numpy.ma.MaskedArray + self, + column: str, + timestamps: NDArrayTime, + values: Union[MaskedArrayAny, NDArrayAny], ) -> None: ... def list_columns(self) -> list[ColumnInfo]: ... def reader( self, column_names: list[str] = [], batch_size: int = 0, - ranges: list[tuple] = [], + ranges: RangeSet = [], ) -> Reader: ... def retrieve_metadata(self) -> None: ... def string_get_ranges( - self, column: str, ranges: object = None - ) -> tuple[numpy.ndarray, numpy.ma.MaskedArray]: ... + self, column: str, ranges: Optional[RangeSet] = None + ) -> tuple[NDArrayTime, MaskedArrayAny]: ... def string_insert( - self, column: str, timestamps: numpy.ndarray, values: numpy.ma.MaskedArray + self, + column: str, + timestamps: NDArrayTime, + values: Union[MaskedArrayAny, NDArrayAny], ) -> None: ... - def subscribe(self, conn: object) -> object: ... + def subscribe(self, conn: Any) -> Any: ... def timestamp_get_ranges( - self, column: str, ranges: object = None - ) -> tuple[numpy.ndarray, numpy.ma.MaskedArray]: ... + self, column: str, ranges: Optional[RangeSet] = None + ) -> tuple[NDArrayTime, MaskedArrayAny]: ... def timestamp_insert( - self, column: str, timestamps: numpy.ndarray, values: numpy.ma.MaskedArray + self, + column: str, + timestamps: NDArrayTime, + values: Union[MaskedArrayAny, NDArrayAny], ) -> None: ... 
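For orientation, here is a minimal usage sketch of the retyped Table range API. The URI, table name, and column name are placeholders; only the signatures shown in the stubs above are assumed, and the ranges value follows the RangeSet alias (an iterable of (begin, end) datetime64 pairs).

# Sketch only: "qdb://127.0.0.1:2836", "stocks", and "price" are placeholders.
import numpy as np
import numpy.ma as ma

import quasardb

conn = quasardb.Cluster("qdb://127.0.0.1:2836")
table = conn.table("stocks")

# NDArrayTime: a numpy array of datetime64 timestamps.
timestamps = np.array(
    ["2024-01-01T00:00:00", "2024-01-01T00:00:01"], dtype="datetime64[ns]"
)

# Values may be a plain ndarray or a masked array (Union[MaskedArrayAny, NDArrayAny]).
values = ma.masked_invalid(np.array([1.5, np.nan]))
table.double_insert("price", timestamps, values)

# RangeSet: an iterable of (begin, end) np.datetime64 pairs.
ranges = [(np.datetime64("2024-01-01", "ns"), np.datetime64("2024-01-02", "ns"))]
ts, vals = table.double_get_ranges("price", ranges)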
diff --git a/quasardb/quasardb/_writer.pyi b/quasardb/quasardb/_writer.pyi index ef7bc960..01abbbb2 100644 --- a/quasardb/quasardb/_writer.pyi +++ b/quasardb/quasardb/_writer.pyi @@ -1,11 +1,15 @@ from __future__ import annotations +from typing import Any, Iterable + +from quasardb.typing import Range + from ._table import Table class WriterData: def __init__(self) -> None: ... def append( - self, table: Table, index: list[object], column_data: list[list[object]] + self, table: Table, index: Iterable[Any], column_data: Iterable[Any] ) -> None: ... def empty(self) -> bool: ... @@ -15,27 +19,27 @@ class WriterPushMode: Fast: WriterPushMode # value = Async: WriterPushMode # value = __members__: dict[str, WriterPushMode] - def __and__(self, other: object) -> object: ... - def __eq__(self, other: object) -> bool: ... - def __ge__(self, other: object) -> bool: ... + def __and__(self, other: Any) -> Any: ... + def __eq__(self, other: Any) -> bool: ... + def __ge__(self, other: Any) -> bool: ... def __getstate__(self) -> int: ... - def __gt__(self, other: object) -> bool: ... + def __gt__(self, other: Any) -> bool: ... def __hash__(self) -> int: ... def __index__(self) -> int: ... def __init__(self, value: int) -> None: ... def __int__(self) -> int: ... - def __invert__(self) -> object: ... - def __le__(self, other: object) -> bool: ... - def __lt__(self, other: object) -> bool: ... - def __ne__(self, other: object) -> bool: ... - def __or__(self, other: object) -> object: ... - def __rand__(self, other: object) -> object: ... + def __invert__(self) -> Any: ... + def __le__(self, other: Any) -> bool: ... + def __lt__(self, other: Any) -> bool: ... + def __ne__(self, other: Any) -> bool: ... + def __or__(self, other: Any) -> Any: ... + def __rand__(self, other: Any) -> Any: ... def __repr__(self) -> str: ... - def __ror__(self, other: object) -> object: ... - def __rxor__(self, other: object) -> object: ... + def __ror__(self, other: Any) -> Any: ... + def __rxor__(self, other: Any) -> Any: ... def __setstate__(self, state: int) -> None: ... def __str__(self) -> str: ... - def __xor__(self, other: object) -> object: ... + def __xor__(self, other: Any) -> Any: ... @property def name(self) -> str: ... @property @@ -50,19 +54,18 @@ class Writer: deduplication_mode: str, deduplicate: str, retries: int, - range: tuple[object, ...], - **kwargs, + range: Range, + **kwargs: Any, ) -> None: ... 
def push_fast( self, data: WriterData, write_through: bool, - push_mode: WriterPushMode, deduplication_mode: str, deduplicate: str, retries: int, - range: tuple[object, ...], - **kwargs, + range: Range, + **kwargs: Any, ) -> None: """Deprecated: Use `writer.push()` instead.""" @@ -70,12 +73,11 @@ class Writer: self, data: WriterData, write_through: bool, - push_mode: WriterPushMode, deduplication_mode: str, deduplicate: str, retries: int, - range: tuple[object, ...], - **kwargs, + range: Range, + **kwargs: Any, ) -> None: """Deprecated: Use `writer.push()` instead.""" @@ -83,29 +85,28 @@ class Writer: self, data: WriterData, write_through: bool, - push_mode: WriterPushMode, deduplication_mode: str, deduplicate: str, retries: int, - range: tuple[object, ...], - **kwargs, + range: Range, + **kwargs: Any, ) -> None: """Deprecated: Use `writer.push()` instead3.""" - def start_row(self, table: object, x: object) -> None: + def start_row(self, table: Any, x: Any) -> None: """Legacy function""" - def set_double(self, idx: object, value: object) -> object: + def set_double(self, idx: Any, value: Any) -> Any: """Legacy function""" - def set_int64(self, idx: object, value: object) -> object: + def set_int64(self, idx: Any, value: Any) -> Any: """Legacy function""" - def set_string(self, idx: object, value: object) -> object: + def set_string(self, idx: Any, value: Any) -> Any: """Legacy function""" - def set_blob(self, idx: object, value: object) -> object: + def set_blob(self, idx: Any, value: Any) -> Any: """Legacy function""" - def set_timestamp(self, idx: object, value: object) -> object: + def set_timestamp(self, idx: Any, value: Any) -> Any: """Legacy function""" diff --git a/quasardb/quasardb/metrics/__init__.pyi b/quasardb/quasardb/metrics/__init__.pyi index 1dede248..ad28bf5d 100644 --- a/quasardb/quasardb/metrics/__init__.pyi +++ b/quasardb/quasardb/metrics/__init__.pyi @@ -4,6 +4,9 @@ Keep track of low-level performance metrics from __future__ import annotations +from types import TracebackType +from typing import Optional, Type + __all__ = ["Measure", "clear", "totals"] class Measure: @@ -12,7 +15,12 @@ class Measure: """ def __enter__(self) -> Measure: ... - def __exit__(self, exc_type: object, exc_value: object, exc_tb: object) -> None: ... + def __exit__( + self, + exc_type: Optional[Type[BaseException]], + exc_value: Optional[BaseException], + exc_tb: Optional[TracebackType], + ) -> None: ... def __init__(self) -> None: ... def get(self) -> dict[str, int]: ... diff --git a/quasardb/stats.py b/quasardb/stats.py index c07dac53..32dc90d1 100644 --- a/quasardb/stats.py +++ b/quasardb/stats.py @@ -1,10 +1,12 @@ -import re - -import quasardb import logging +import re from collections import defaultdict from datetime import datetime from enum import Enum +from typing import Any, DefaultDict, Dict, List, TypeVar, Union + +import quasardb +from quasardb.quasardb import Cluster, Node logger = logging.getLogger("quasardb.stats") @@ -17,11 +19,11 @@ user_clean_pattern = re.compile(r"\.uid_\d+") -def is_user_stat(s): +def is_user_stat(s: str) -> bool: return user_pattern.match(s) is not None -def is_cumulative_stat(s): +def is_cumulative_stat(s: str) -> bool: # NOTE(leon): It's quite difficult to express in Python that you don't want any # regex to _end_ with uid_[0-9]+, because Python's regex engine doesn't support # variable width look-behind. 
@@ -34,7 +36,7 @@ def is_cumulative_stat(s): return user_pattern.match(s) is None -def by_node(conn): +def by_node(conn: Cluster) -> Dict[str, Dict[str, Any]]: """ Returns statistic grouped by node URI. @@ -45,7 +47,7 @@ def by_node(conn): return {x: of_node(conn.node(x)) for x in conn.endpoints()} -def of_node(dconn): +def of_node(dconn: Node) -> Dict[str, Any]: """ Returns statistic for a single node. @@ -61,7 +63,10 @@ def of_node(dconn): idx = _index_keys(dconn, ks) raw = {k: _get_stat_value(dconn, k) for k in ks} - ret = {"by_uid": _by_uid(raw, idx), "cumulative": _cumulative(raw, idx)} + ret: Dict[str, Any] = { + "by_uid": _by_uid(raw, idx), + "cumulative": _cumulative(raw, idx), + } check_duration = datetime.now() - start @@ -87,7 +92,7 @@ def of_node(dconn): ) -def stat_type(stat_id): +def stat_type(stat_id: str) -> None: """ Returns the statistic type by a stat id. Returns one of: @@ -109,7 +114,7 @@ def stat_type(stat_id): return None -def _get_all_keys(dconn, n=1024): +def _get_all_keys(dconn: Node, n: int = 1024) -> List[str]: """ Returns all keys from a single node. @@ -171,15 +176,17 @@ class Unit(Enum): "seconds": Unit.SECONDS, } +T = TypeVar("T", Type, Unit) + -def _lookup_enum(dconn, k, m): +def _lookup_enum(dconn: Node, k: str, m: Dict[str, T]) -> T: """ Utility function to avoid code duplication: automatically looks up a key's value from QuasarDB and looks it up in provided dict. """ - x = dconn.blob(k).get() - x = _clean_blob(x) + _x = dconn.blob(k).get() + x = _clean_blob(_x) if x not in m: raise Exception(f"Unrecognized unit/type {x} from key {k}") @@ -187,7 +194,7 @@ def _lookup_enum(dconn, k, m): return m[x] -def _lookup_type(dconn, k): +def _lookup_type(dconn: Node, k: str) -> Type: """ Looks up and parses/validates the metric type. """ @@ -196,7 +203,7 @@ def _lookup_type(dconn, k): return _lookup_enum(dconn, k, _type_string_to_enum) -def _lookup_unit(dconn, k): +def _lookup_unit(dconn: Node, k: str) -> Unit: """ Looks up and parses/validates the metric type. """ @@ -205,7 +212,7 @@ def _lookup_unit(dconn, k): return _lookup_enum(dconn, k, _unit_string_to_enum) -def _index_keys(dconn, ks): +def _index_keys(dconn: Node, ks: List[str]) -> DefaultDict[str, Dict[str, Any]]: """ Takes all statistics keys that are retrieved, and "indexes" them in such a way that we end up with a dict of all statistic keys, their type and their unit. @@ -241,13 +248,17 @@ def _index_keys(dconn, ks): # In which case we'll store `requests.out_bytes` as the statistic type, and look up the type # and unit for those metrics and add a placeholder value. - ret = defaultdict(lambda: {"value": None, "type": None, "unit": None}) + ret: DefaultDict[str, Dict[str, Any]] = defaultdict( + lambda: {"value": None, "type": None, "unit": None} + ) for k in ks: # Remove any 'uid_[0-9]+' part from the string k_ = user_clean_pattern.sub("", k) matches = total_pattern.match(k_) + if matches is None: + continue parts = matches.groups()[0].rsplit(".", 1) metric_id = parts[0] @@ -267,7 +278,7 @@ def _index_keys(dconn, ks): return ret -def _clean_blob(x): +def _clean_blob(x: bytes) -> str: """ Utility function that decodes a blob as an UTF-8 string, as the direct node C API does not yet support 'string' types and as such all statistics are stored as blobs. 
@@ -278,7 +289,7 @@ def _clean_blob(x): return "".join(c for c in x_ if ord(c) != 0) -def _get_stat_value(dconn, k): +def _get_stat_value(dconn: Node, k: str) -> Union[int, str]: # Ugly, but works: try to retrieve as integer, if not an int, retrieve as # blob # @@ -288,16 +299,19 @@ def _get_stat_value(dconn, k): return dconn.integer(k).get() # Older versions of qdb api returned 'alias not found' - except quasardb.quasardb.AliasNotFoundError: + except quasardb.AliasNotFoundError: return _clean_blob(dconn.blob(k).get()) # Since ~ 3.14.2, it returns 'Incompatible Type' - except quasardb.quasardb.IncompatibleTypeError: + except quasardb.IncompatibleTypeError: return _clean_blob(dconn.blob(k).get()) -def _by_uid(stats, idx): - xs = {} +def _by_uid( + stats: Dict[str, Union[int, str]], idx: DefaultDict[str, Dict[str, Any]] +) -> Dict[int, Dict[str, Dict[str, Any]]]: + xs: Dict[int, Dict[str, Dict[str, Any]]] = {} + for k, v in stats.items(): matches = user_pattern.match(k) if is_user_stat(k) and matches: @@ -329,8 +343,10 @@ def _by_uid(stats, idx): return xs -def _cumulative(stats, idx): - xs = {} +def _cumulative( + stats: Dict[str, Union[int, str]], idx: DefaultDict[str, Dict[str, Any]] +) -> Dict[str, Dict[str, Any]]: + xs: Dict[str, Dict[str, Any]] = {} for k, v in stats.items(): matches = total_pattern.match(k) diff --git a/quasardb/table_cache.py b/quasardb/table_cache.py index 2867b6c9..b9a37718 100644 --- a/quasardb/table_cache.py +++ b/quasardb/table_cache.py @@ -1,11 +1,15 @@ import logging +from typing import Dict, Optional + +from quasardb.quasardb import Cluster, Table logger = logging.getLogger("quasardb.table_cache") -_cache = {} +_cache: Dict[str, Table] = {} -def clear(): +def clear() -> None: + global _cache logger.info("Clearing table cache") _cache = {} @@ -17,7 +21,7 @@ def exists(table_name: str) -> bool: return table_name in _cache -def store(table, table_name=None, force_retrieve_metadata=True): +def store(table: Table, table_name: Optional[str] = None) -> Table: """ Stores a table into the cache. Ensures metadata is retrieved. This is useful if you want to retrieve all table metadata at the beginning of a process, to avoid doing expensive @@ -29,7 +33,7 @@ def store(table, table_name=None, force_retrieve_metadata=True): table_name = table.get_name() if exists(table_name): - logger.warn("Table already in cache, overwriting: %s", table_name) + logger.warning("Table already in cache, overwriting: %s", table_name) logger.debug("Caching table %s", table_name) _cache[table_name] = table @@ -39,7 +43,7 @@ def store(table, table_name=None, force_retrieve_metadata=True): return table -def lookup(table_name: str, conn, force_retrieve_metadata=True): +def lookup(table_name: str, conn: Cluster) -> Table: """ Retrieves table from _cache if already exists. If it does not exist, looks up the table from `conn` and puts it in the cache. 
@@ -53,4 +57,4 @@ def lookup(table_name: str, conn, force_retrieve_metadata=True): logger.debug("table %s not yet found, looking up", table_name) table = conn.table(table_name) - return store(table, table_name, force_retrieve_metadata=force_retrieve_metadata) + return store(table, table_name) diff --git a/quasardb/typing.py b/quasardb/typing.py new file mode 100644 index 00000000..ad5012f1 --- /dev/null +++ b/quasardb/typing.py @@ -0,0 +1,23 @@ +from __future__ import annotations + +from typing import Iterable, Tuple + +import numpy as np + +# Numpy + +# # Modern typing (numpy >= 1.22.0, python >= 3.9) +# DType = np.dtype[Any] +# NDArrayAny = np.ndarray[Any, np.dtype[Any]] +# NDArrayTime = np.ndarray[Any, np.dtype[np.datetime64]] +# MaskedArrayAny = np.ma.MaskedArray[Any, Any] + +# Legacy fallback (numpy ~ 1.20.3, python 3.7) +DType = np.dtype +NDArrayAny = np.ndarray +NDArrayTime = np.ndarray +MaskedArrayAny = np.ma.MaskedArray + +# Qdb expressions +Range = Tuple[np.datetime64, np.datetime64] +RangeSet = Iterable[Range] diff --git a/scripts/github_actions/requirements.txt b/scripts/github_actions/requirements.txt deleted file mode 100644 index a5276e58..00000000 --- a/scripts/github_actions/requirements.txt +++ /dev/null @@ -1,2 +0,0 @@ -black[colorama]==24.10.0 -slack_sdk==3.34.0
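As a companion to the table_cache changes above, a small usage sketch of the retyped helpers. The URI and table names are placeholders; the lookup/store/exists/clear signatures follow the diff, and store() is assumed to retrieve metadata as its docstring states.

# Sketch only: the URI and table names are placeholders.
import quasardb
from quasardb import table_cache

conn = quasardb.Cluster("qdb://127.0.0.1:2836")

# First lookup fetches the table via conn.table() and caches it;
# subsequent lookups for the same name hit the module-level cache.
t = table_cache.lookup("stocks", conn)
assert table_cache.exists("stocks")

# store() accepts an explicit name; omitting it falls back to table.get_name().
table_cache.store(t, "stocks_alias")

# clear() resets the module-level cache.
table_cache.clear()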
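Finally, a hedged sketch of annotating downstream code with the new quasardb.typing aliases under the legacy numpy fallback. The one_day_ranges helper and its arguments are hypothetical and exist only for illustration.

# Sketch only: a user-side helper annotated with Range / RangeSet.
from typing import List

import numpy as np

from quasardb.typing import Range, RangeSet


def one_day_ranges(start: np.datetime64, days: int) -> List[Range]:
    """Builds consecutive one-day (begin, end) ranges starting at `start`."""
    step = np.timedelta64(1, "D")
    return [(start + i * step, start + (i + 1) * step) for i in range(days)]


# List[Range] satisfies RangeSet, which is Iterable[Range].
ranges: RangeSet = one_day_ranges(np.datetime64("2024-01-01"), 3)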