diff --git a/.gitignore b/.gitignore index 40e54f1a..4f430847 100644 --- a/.gitignore +++ b/.gitignore @@ -1,6 +1,7 @@ # Compiled python modules. *.pyc /.idea/ +/.spyproject/ /docs/build/ /docs/source/autoapi/ /docs/source/db_mapping_schema.rst diff --git a/README_dev.md b/README_dev.md new file mode 100644 index 00000000..60ab394b --- /dev/null +++ b/README_dev.md @@ -0,0 +1,24 @@ +# Developing Data Transition + +## Testing `alembic` Migration + +1. Edit `./spinedb_api/alembic.ini`, point `sqlalchemy.url` to a (copy of a) SQLite test database. ⚠️ Its data will be altered by the migration script. +1. Edit `./spinedb_api/alembic/versions/a973ab537da2_reencode_parameter_values.py` and temporarily change + ```python + new_value = transition_data(old_value) + ``` + to + ```python + new_value = b'prepend_me ' + old_value + ``` +1. Within the `./spinedb_api` folder, execute + ```bash + alembic upgrade head + ``` +1. Open your SQLite test database in a database editor and check for changed `parameter_value`s. + +## Developing the Data Transition Module + +1. Edit `./spinedb_api/compat/data_transition.py` for development. +1. In a Python REPL, call its function `transition_data(old_json_bytes)` and check for correct output of our test cases. +1. Once this works, revert the changes of `./spinedb_api/alembic/versions/a973ab537da2_reencode_parameter_values.py` and test the above `alembic` migration again. diff --git a/pyproject.toml b/pyproject.toml index c066f4eb..31e5a5ed 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -25,7 +25,8 @@ dependencies = [ "chardet >=4.0.0", "PyMySQL[rsa] >=1.0.2", "psycopg2-binary", - "pyarrow >= 19.0", + "pyarrow >= 20.0", + "pydantic >= 2", "pandas >= 2.2.3", ] diff --git a/spinedb_api/alembic/versions/8b0eff478bcb_add_active_by_default_to_entity_class.py b/spinedb_api/alembic/versions/8b0eff478bcb_add_active_by_default_to_entity_class.py index b9d43160..ed8fe9ec 100644 --- a/spinedb_api/alembic/versions/8b0eff478bcb_add_active_by_default_to_entity_class.py +++ b/spinedb_api/alembic/versions/8b0eff478bcb_add_active_by_default_to_entity_class.py @@ -9,7 +9,7 @@ from alembic import op import sqlalchemy as sa import sqlalchemy.orm -from spinedb_api.compatibility import convert_tool_feature_method_to_active_by_default +from spinedb_api.compat.compatibility_transformations import convert_tool_feature_method_to_active_by_default # revision identifiers, used by Alembic. revision = "8b0eff478bcb" diff --git a/spinedb_api/alembic/versions/989fccf80441_replace_values_with_reference_to_list_.py b/spinedb_api/alembic/versions/989fccf80441_replace_values_with_reference_to_list_.py index e11dfd32..9830f6fd 100644 --- a/spinedb_api/alembic/versions/989fccf80441_replace_values_with_reference_to_list_.py +++ b/spinedb_api/alembic/versions/989fccf80441_replace_values_with_reference_to_list_.py @@ -12,7 +12,8 @@ from sqlalchemy.sql.expression import bindparam from spinedb_api.exception import SpineIntegrityError from spinedb_api.helpers import group_concat -from spinedb_api.parameter_value import ParameterValueFormatError, dump_db_value, from_database +from spinedb_api.incomplete_values import dump_db_value +from spinedb_api.parameter_value import ParameterValueFormatError, from_database # revision identifiers, used by Alembic.
revision = "989fccf80441" diff --git a/spinedb_api/alembic/versions/a973ab537da2_reencode_parameter_values.py b/spinedb_api/alembic/versions/a973ab537da2_reencode_parameter_values.py new file mode 100644 index 00000000..446acda6 --- /dev/null +++ b/spinedb_api/alembic/versions/a973ab537da2_reencode_parameter_values.py @@ -0,0 +1,74 @@ +"""reencode parameter_values + +Revision ID: a973ab537da2 +Revises: e9f2c2330cf8 +Create Date: 2025-05-21 12:49:16.861670 + +""" + +from typing import Any, Optional, SupportsFloat +from alembic import op +import sqlalchemy as sa +from spinedb_api.compat.converters import parse_duration +from spinedb_api.parameter_value import DateTime, Duration, from_dict, to_database +from spinedb_api.value_support import load_db_value + +# revision identifiers, used by Alembic. +revision = "a973ab537da2" +down_revision = "e9f2c2330cf8" +branch_labels = None +depends_on = None + + +TYPES_TO_REENCODE = {"duration", "date_time", "time_pattern", "time_series", "array", "map"} + + +def upgrade(): + conn = op.get_bind() + metadata = sa.MetaData() + metadata.reflect(bind=conn) + _upgrade_table_types(metadata.tables["parameter_definition"], "default_value", "default_type", conn) + _upgrade_table_types(metadata.tables["parameter_value"], "value", "type", conn) + _upgrade_table_types(metadata.tables["list_value"], "value", "type", conn) + + +def downgrade(): + pass + + +def _upgrade_table_types(table, value_label, type_label, connection): + value_column = getattr(table.c, value_label) + type_column = getattr(table.c, type_label) + update_statement = ( + table.update() + .where(table.c.id == sa.bindparam("id")) + .values( + { + "id": sa.bindparam("id"), + value_label: sa.bindparam(value_label), + } + ) + ) + batch_data = [] + for id_, type_, old_blob in connection.execute( + sa.select(table.c.id, type_column, value_column).where(type_column.in_(TYPES_TO_REENCODE)) + ): + legacy_value = _from_database_legacy(old_blob, type_) + new_blob, _ = to_database(legacy_value) + batch_data.append({"id": id_, value_label: new_blob}) + if len(batch_data) == 100: + connection.execute(update_statement, batch_data) + batch_data.clear() + if batch_data: + connection.execute(update_statement, batch_data) + + +def _from_database_legacy(value: bytes, type_: Optional[str]) -> Optional[Any]: + parsed = load_db_value(value) + if isinstance(parsed, dict): + return from_dict(parsed, type_) + if type_ == DateTime.TYPE: + return DateTime(parsed) + if type_ == Duration.TYPE: + return Duration(parse_duration(parsed)) + raise RuntimeError(f"migration for {type_} missing") diff --git a/spinedb_api/alembic/versions/ca7a13da8ff6_add_type_info_for_scalars.py b/spinedb_api/alembic/versions/ca7a13da8ff6_add_type_info_for_scalars.py index 8e9edafd..a40b57a4 100644 --- a/spinedb_api/alembic/versions/ca7a13da8ff6_add_type_info_for_scalars.py +++ b/spinedb_api/alembic/versions/ca7a13da8ff6_add_type_info_for_scalars.py @@ -9,7 +9,7 @@ import json from alembic import op import sqlalchemy as sa -from spinedb_api.parameter_value import type_for_scalar +from spinedb_api.incomplete_values import type_for_scalar # revision identifiers, used by Alembic. 
revision = "ca7a13da8ff6" @@ -40,7 +40,10 @@ def _update_scalar_type_info(table, value_label, type_label, connection): parsed_value = json.loads(value) if parsed_value is None: continue - value_type = type_for_scalar(parsed_value) + if isinstance(parsed_value, dict): + value_type = parsed_value["type"] + else: + value_type = type_for_scalar(parsed_value) connection.execute(update_statement.where(table.c.id == id_), {type_label: value_type}) diff --git a/spinedb_api/alembic/versions/ce9faa82ed59_create_entity_alternative_table.py b/spinedb_api/alembic/versions/ce9faa82ed59_create_entity_alternative_table.py index 0e1d5947..f62c17e3 100644 --- a/spinedb_api/alembic/versions/ce9faa82ed59_create_entity_alternative_table.py +++ b/spinedb_api/alembic/versions/ce9faa82ed59_create_entity_alternative_table.py @@ -8,7 +8,7 @@ from alembic import op import sqlalchemy as sa -from spinedb_api.compatibility import convert_tool_feature_method_to_entity_alternative +from spinedb_api.compat.compatibility_transformations import convert_tool_feature_method_to_entity_alternative # revision identifiers, used by Alembic. revision = "ce9faa82ed59" diff --git a/spinedb_api/arrow_value.py b/spinedb_api/arrow_value.py index 7dae65e9..7d80cc5a 100644 --- a/spinedb_api/arrow_value.py +++ b/spinedb_api/arrow_value.py @@ -18,283 +18,93 @@ This is highly experimental API. """ -from collections import defaultdict -from collections.abc import Callable, Iterable import datetime -from typing import Any, Optional, SupportsFloat, Union +import json +from typing import Any, SupportsFloat, TypeAlias from dateutil import relativedelta -import numpy import pyarrow -from .parameter_value import ( - NUMPY_DATETIME_DTYPE, - TIME_SERIES_DEFAULT_RESOLUTION, - TIME_SERIES_DEFAULT_START, - ParameterValueFormatError, - duration_to_relativedelta, - load_db_value, -) +from .compat.converters import parse_duration +from .exception import SpineDBAPIError +from .helpers import time_period_format_specification, time_series_metadata +from .models import AllArrays, ArrayAsDict, Metadata, dict_to_array +from .value_support import load_db_value, to_union_array, validate_time_period -_DATA_TYPE_TO_ARROW_TYPE = { - "date_time": pyarrow.timestamp("s"), - "duration": pyarrow.duration("us"), - "float": pyarrow.float64(), - "str": pyarrow.string(), - "null": pyarrow.null(), -} +Value: TypeAlias = float | str | bool | datetime.datetime | relativedelta.relativedelta | pyarrow.RecordBatch | None -_ARROW_TYPE_TO_DATA_TYPE = dict(zip(_DATA_TYPE_TO_ARROW_TYPE.values(), _DATA_TYPE_TO_ARROW_TYPE.keys())) -_DATA_CONVERTER = { - "date_time": lambda data: numpy.array(data, dtype="datetime64[s]"), -} - - -def from_database(db_value: bytes, value_type: str) -> Any: +def from_database(db_value: bytes, value_type: str) -> Value: """Parses a database value.""" if db_value is None: return None - loaded = load_db_value(db_value, value_type) - if isinstance(loaded, dict): - return from_dict(loaded, value_type) + loaded = load_db_value(db_value) + if isinstance(loaded, list) and len(loaded) > 0 and isinstance(loaded[0], dict): + return to_record_batch(loaded) + if value_type == "duration": + return parse_duration(loaded) + if value_type == "date_time": + return datetime.datetime.fromisoformat(loaded) if isinstance(loaded, SupportsFloat) and not isinstance(loaded, bool): return float(loaded) return loaded -def from_dict(loaded_value: dict, value_type: str) -> pyarrow.RecordBatch: - """Converts a value dict to parsed value.""" - if value_type == "array": - data_type = 
loaded_value.get("value_type", "float") - data = loaded_value["data"] - if data_type in _DATA_CONVERTER: - data = _DATA_CONVERTER[data_type](data) - arrow_type = _DATA_TYPE_TO_ARROW_TYPE[data_type] - y_array = pyarrow.array(data, type=arrow_type) - x_array = pyarrow.array(range(0, len(y_array)), type=pyarrow.int64()) - return pyarrow.RecordBatch.from_arrays([x_array, y_array], names=[loaded_value.get("index_name", "i"), "value"]) - if value_type == "map": - return crawled_to_record_batch(crawl_map_uneven, loaded_value) - if value_type == "time_series": - return crawled_to_record_batch(crawl_time_series, loaded_value) - raise NotImplementedError(f"unknown value type {value_type}") - - -def to_database(parsed_value: Any) -> tuple[bytes, str]: - """Converts parsed value into database value.""" - raise NotImplementedError() - - -def type_of_loaded(loaded_value: Any) -> str: - """Infer the type of loaded value.""" - if isinstance(loaded_value, dict): - return loaded_value["type"] - elif isinstance(loaded_value, str): - return "str" - elif isinstance(loaded_value, bool): - return "bool" - elif isinstance(loaded_value, SupportsFloat): - return "float" - elif isinstance(loaded_value, datetime.datetime): - return "date_time" - elif loaded_value is None: - return "null" - raise RuntimeError(f"unknown type") - +def with_column_as_time_period(record_batch: pyarrow.RecordBatch, column: int | str) -> pyarrow.RecordBatch: + """Creates a shallow copy of record_batch with additional metadata marking a column's data type as time_period. -CrawlTuple = tuple[list, list, list, dict[str, dict[str, str]], int] + Also, validates that the column contains strings compatible with the time period specification. + """ + for period in record_batch.column(column): + validate_time_period(period.as_py()) + return with_field_metadata(time_period_format_specification(), record_batch, column) -def crawled_to_record_batch( - crawl: Callable[[dict, Optional[list[tuple[str, Any]]], Optional[list[str]]], CrawlTuple], loaded_value: dict +def with_column_as_time_stamps( + record_batch: pyarrow.RecordBatch, column: int | str, ignore_year: bool, repeat: bool ) -> pyarrow.RecordBatch: - typed_xs, ys, index_names, index_metadata, depth = crawl(loaded_value) - if not ys: - return pyarrow.RecordBatch.from_arrays( - [ - pyarrow.array([], _DATA_TYPE_TO_ARROW_TYPE[loaded_value["index_type"]]), - pyarrow.array([], pyarrow.null()), - ], - names=index_names + ["value"], - ) - x_arrays = [] - for i in range(depth): - x_arrays.append(build_x_array(typed_xs, i)) - arrays = x_arrays + [build_y_array(ys)] - array_names = index_names + ["value"] - return pyarrow.RecordBatch.from_arrays(arrays, schema=make_schema(arrays, array_names, index_metadata)) - - -def make_schema( - arrays: Iterable[pyarrow.Array], array_names: Iterable[str], array_metadata: dict[str, dict[str, str]] -) -> pyarrow.Schema: - fields = [] - for array, name in zip(arrays, array_names): - fields.append(pyarrow.field(name, array.type, metadata=array_metadata.get(name))) - return pyarrow.schema(fields) - - -def crawl_map_uneven( - loaded_value: dict, root_index: Optional[list[tuple[str, Any]]] = None, root_index_names: Optional[list[str]] = None -) -> CrawlTuple: - if root_index is None: - root_index = [] - root_index_names = [] - depth = len(root_index) + 1 - typed_xs = [] - ys = [] - max_nested_depth = 0 - index_names = root_index_names + [loaded_value.get("index_name", f"col_{depth}")] - index_metadata = {} - deepest_nested_index_names = [] - index_type = 
loaded_value["index_type"] - data = loaded_value["data"] - if isinstance(data, dict): - data = data.items() - for x, y in data: - index = root_index + [(index_type, x)] - if isinstance(y, dict): - y_is_scalar = False - y_type = y["type"] - if y_type == "date_time": - y = datetime.datetime.fromisoformat(y["data"]) - y_is_scalar = True - if not y_is_scalar: - if y_type == "map": - crawl_nested = crawl_map_uneven - elif y_type == "time_series": - crawl_nested = crawl_time_series - else: - raise RuntimeError(f"unknown nested type {y_type}") - nested_xs, nested_ys, nested_index_names, nested_index_metadata, nested_depth = crawl_nested( - y, index, index_names - ) - typed_xs += nested_xs - ys += nested_ys - deepest_nested_index_names = collect_nested_index_names(nested_index_names, deepest_nested_index_names) - index_metadata.update(nested_index_metadata) - max_nested_depth = max(max_nested_depth, nested_depth) - continue - typed_xs.append(index) - ys.append(y) - index_names = index_names if not deepest_nested_index_names else deepest_nested_index_names - return typed_xs, ys, index_names, index_metadata, depth if max_nested_depth == 0 else max_nested_depth - - -def crawl_time_series( - loaded_value: dict, root_index: Optional[list[tuple[str, Any]]] = None, root_index_names: Optional[list[str]] = None -) -> CrawlTuple: - if root_index is None: - root_index = [] - root_index_names = [] - typed_xs = [] - ys = [] - data = loaded_value["data"] - index_name = loaded_value.get("index_name", "t") - if isinstance(data, list) and data and not isinstance(data[0], list): - loaded_index = loaded_value.get("index", {}) - start = numpy.datetime64(loaded_index.get("start", TIME_SERIES_DEFAULT_START)) - resolution = loaded_index.get("resolution", TIME_SERIES_DEFAULT_RESOLUTION) - data = zip(time_stamps(start, resolution, len(data)), data) - for x, y in data: - index = root_index + [("date_time", x)] - typed_xs.append(index) - ys.append(y) - ignore_year = loaded_index.get("ignore_year", False) - repeat = loaded_index.get("repeat", False) - else: - if isinstance(data, dict): - data = data.items() - for x, y in data: - index = root_index + [("date_time", datetime.datetime.fromisoformat(x))] - typed_xs.append(index) - ys.append(y) - ignore_year = False - repeat = False - metadata = { - index_name: { - "ignore_year": "true" if ignore_year else "false", - "repeat": "true" if repeat else "false", - } - } - index_names = root_index_names + [index_name] - return typed_xs, ys, index_names, metadata, len(root_index) + 1 + if not pyarrow.types.is_timestamp(record_batch.column(column).type): + raise SpineDBAPIError("column is not time stamp column") + return with_field_metadata(time_series_metadata(ignore_year, repeat), record_batch, column) -def time_series_resolution(resolution: Union[str, list[str]]) -> list[relativedelta]: - """Parses time series resolution string.""" - if isinstance(resolution, str): - resolution = [duration_to_relativedelta(resolution)] - else: - resolution = list(map(duration_to_relativedelta, resolution)) - if not resolution: - raise ParameterValueFormatError("Resolution cannot be empty or zero.") - return resolution - - -def time_stamps(start, resolution, count): - resolution_as_deltas = time_series_resolution(resolution) - cycle_count = -(-count // len(resolution_as_deltas)) - deltas = [start.tolist()] + (cycle_count * resolution_as_deltas)[: count - 1] - np_deltas = numpy.array(deltas) - return np_deltas.cumsum().astype(NUMPY_DATETIME_DTYPE) - - -def collect_nested_index_names(index_names1, 
index_names2): - if len(index_names1) > len(index_names2): - longer = index_names1 - else: - longer = index_names2 - for name1, name2 in zip(index_names1, index_names2): - if name1 != name2: - raise RuntimeError(f"index name mismatch") - return longer - - -def build_x_array(uneven_data, i): - by_type = defaultdict(list) - types_and_offsets = [] - for row in uneven_data: - try: - data_type, x = row[i] - except IndexError: - x = None - data_type = "null" - x_list = by_type[data_type] - x_list.append(x) - types_and_offsets.append((data_type, len(x_list) - 1)) - return union_array(by_type, types_and_offsets) - - -def build_y_array(y_list): - by_type = defaultdict(list) - types_and_offsets = [] - for y in y_list: - data_type = type_of_loaded(y) - y_list = by_type[data_type] - y_list.append(y) - types_and_offsets.append((data_type, len(y_list) - 1)) - return union_array(by_type, types_and_offsets) - - -def union_array(by_type, types_and_offsets): - if len(by_type) == 1: - data_type, data = next(iter(by_type.items())) - if data_type in _DATA_CONVERTER: - data = _DATA_CONVERTER[data_type](data) - return pyarrow.array(data, type=_DATA_TYPE_TO_ARROW_TYPE[data_type]) - arrays = [] - for type_, ys in by_type.items(): - if type_ in _DATA_CONVERTER: - ys = _DATA_CONVERTER[type_](ys) - arrow_type = _DATA_TYPE_TO_ARROW_TYPE[type_] - array = pyarrow.array(ys, type=arrow_type) - arrays.append(array) - type_index = {y_type: i for i, y_type in enumerate(by_type)} - type_ids = [] - value_offsets = [] - for type_, offset in types_and_offsets: - type_ids.append(type_index[type_]) - value_offsets.append(offset) - types = pyarrow.array(type_ids, type=pyarrow.int8()) - offsets = pyarrow.array(value_offsets, type=pyarrow.int32()) - return pyarrow.UnionArray.from_dense(types, offsets, arrays, field_names=list(by_type)) +def with_field_metadata( + metadata: Metadata | dict[str, Any], record_batch: pyarrow.RecordBatch, column: int | str +) -> pyarrow.RecordBatch: + column_i = column if isinstance(column, int) else record_batch.column_names.index(column) + new_fields = [] + for i in range(record_batch.num_columns): + field = record_batch.field(i) + if i == column_i: + field = field.with_metadata({key: json.dumps(value) for key, value in metadata.items()}) + new_fields.append(field) + return pyarrow.record_batch(record_batch.columns, schema=pyarrow.schema(new_fields)) + + +def load_field_metadata(field: pyarrow.Field) -> dict[str, Any] | None: + metadata = field.metadata + if metadata is None: + return None + return {key.decode(): json.loads(value) for key, value in metadata.items()} + + +def to_record_batch(loaded_value: list[ArrayAsDict]) -> pyarrow.RecordBatch: + columns = list(map(dict_to_array, loaded_value)) + arrow_columns = {column.name: to_arrow(column) for column in columns} + record_batch = pyarrow.record_batch(arrow_columns) + for column in columns: + if column.metadata: + record_batch = with_field_metadata(column.metadata, record_batch, column.name) + return record_batch + + +def to_arrow(column: AllArrays) -> pyarrow.Array: + match column.type: + case "array" | "array_index": + return pyarrow.array(column.values) + case "dict_encoded_array" | "dict_encoded_index": + return pyarrow.DictionaryArray.from_arrays(column.indices, column.values) + case "run_end_array" | "run_end_index": + return pyarrow.RunEndEncodedArray.from_arrays(column.run_end, column.values) + case "any_array": + return to_union_array(column.values) + case _: + raise NotImplementedError(f"{column.type}: column type") diff --git 
a/spinedb_api/compatibility.py b/spinedb_api/compat/compatibility_transformations.py similarity index 100% rename from spinedb_api/compatibility.py rename to spinedb_api/compat/compatibility_transformations.py diff --git a/spinedb_api/compat/converters.py b/spinedb_api/compat/converters.py new file mode 100644 index 00000000..d3ff4588 --- /dev/null +++ b/spinedb_api/compat/converters.py @@ -0,0 +1,152 @@ +###################################################################################################################### +# Copyright (C) 2017-2022 Spine project consortium +# Copyright Spine Database API contributors +# This file is part of Spine Database API. +# Spine Database API is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser +# General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your +# option) any later version. This program is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; +# without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General +# Public License for more details. You should have received a copy of the GNU Lesser General Public License along with +# this program. If not, see . +###################################################################################################################### +from datetime import timedelta +import re + +from dateutil.relativedelta import relativedelta +import pandas as pd +import pyarrow as pa + +# Regex pattern to identify a number encoded as a string +freq = r"([0-9]+)" +# Regex patterns that match partial duration strings +DATE_PAT = re.compile(r"".join(rf"({freq}{unit})?" for unit in "YMD")) +TIME_PAT = re.compile(r"".join(rf"({freq}{unit})?"
for unit in "HMS")) +WEEK_PAT = re.compile(rf"{freq}W") + + +def parse_duration(value: str) -> relativedelta: + """Parse a ISO 8601 duration format string to a `relativedelta`.""" + value = value.lstrip("P") + if m0 := WEEK_PAT.match(value): + weeks = m0.groups()[0] + return relativedelta(weeks=int(weeks)) + + # unpack to variable number of args to handle absence of timestamp + date, *_time = value.split("T") + time = _time[0] if _time else "" + delta = relativedelta() + + def parse_num(token: str) -> int: + return int(token) if token else 0 + + if m1 := DATE_PAT.match(date): + years = parse_num(m1.groups()[1]) + months = parse_num(m1.groups()[3]) + days = parse_num(m1.groups()[5]) + delta += relativedelta(years=years, months=months, days=days) + + if m2 := TIME_PAT.match(time): + hours = parse_num(m2.groups()[1]) + minutes = parse_num(m2.groups()[3]) + seconds = parse_num(m2.groups()[5]) + delta += relativedelta(hours=hours, minutes=minutes, seconds=seconds) + + return delta + + +def _normalise_delta(years=0, months=0, days=0, hours=0, minutes=0, seconds=0, microseconds=0, nanoseconds=0) -> dict: + microseconds += nanoseconds // 1_000 + + seconds += microseconds // 1_000_000 + + minutes += seconds // 60 + seconds = seconds % 60 + + hours += minutes // 60 + minutes = minutes % 60 + + days += hours // 24 + hours = hours % 24 + + years += months // 12 + months = months % 12 + + units = ("years", "months", "days", "hours", "minutes", "seconds") + values = (years, months, days, hours, minutes, seconds) + res = {unit: value for unit, value in zip(units, values) if value > 0} + return res + + +def _delta_as_dict(delta: relativedelta | pd.DateOffset | timedelta | pa.MonthDayNano) -> dict: + match delta: + case pa.MonthDayNano(): + return _normalise_delta(months=delta.months, days=delta.days, nanoseconds=delta.nanoseconds) + case timedelta(): + return _normalise_delta(days=delta.days, seconds=delta.seconds, microseconds=delta.microseconds) + case relativedelta() | pd.DateOffset(): + return {k: v for k, v in vars(delta).items() if not k.startswith("_") and k.endswith("s") and v} + case _: + raise TypeError(f"{delta}: unknown type {type(delta)}") + + +def to_relativedelta(offset: str | pd.DateOffset | timedelta | pa.MonthDayNano) -> relativedelta: + """Convert various compatible time offset formats to `relativedelta`. + + Compatible formats: + - JSON string in "duration" format + - `pandas.DateOffset` + - `datetime.timedelta` + + Everyone should use this instead of trying to convert themselves. + + """ + match offset: + case str(): + return parse_duration(offset) + case _: + return relativedelta(**_delta_as_dict(offset)) + + +def to_dateoffset(delta: relativedelta) -> pd.DateOffset: + """Convert `relativedelta` to `pandas.DateOffset`.""" + return pd.DateOffset(**_delta_as_dict(delta)) + + +_duration_abbrevs = { + "years": "Y", + "months": "M", + "days": "D", + "sentinel": "T", + "hours": "H", + "minutes": "M", + "seconds": "S", +} + + +_ZERO_DURATION = "P0D" + + +def to_duration(delta: relativedelta | pd.DateOffset | timedelta | pa.MonthDayNano) -> str: + """Convert various compatible time offset objects to JSON string + in "duration" format. + + Compatible formats: + - `relativedelta` + - `pandas.DateOffset` + - `datetime.timedelta` + + Use this for any kind of serialisation. 
+ + """ + kwargs = _delta_as_dict(delta) + duration = "P" + for unit, abbrev in _duration_abbrevs.items(): + match unit, kwargs.get(unit): + case "sentinel", _: + duration += abbrev + case _, None: + pass + case _, num: + duration += f"{num}{abbrev}" + duration = duration.rstrip("T") + return duration if duration != "P" else _ZERO_DURATION diff --git a/spinedb_api/compat/data_transition.py b/spinedb_api/compat/data_transition.py new file mode 100644 index 00000000..b0b0d0ac --- /dev/null +++ b/spinedb_api/compat/data_transition.py @@ -0,0 +1,330 @@ +###################################################################################################################### +# Copyright (C) 2017-2022 Spine project consortium +# Copyright Spine Database API contributors +# This file is part of Spine Database API. +# Spine Database API is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser +# General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your +# option) any later version. This program is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; +# without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General +# Public License for more details. You should have received a copy of the GNU Lesser General Public License along with +# this program. If not, see . +###################################################################################################################### +"""Reencode old map type JSON to record arrays or dictionary columns""" + +from collections import defaultdict +import json +import re +from typing import Any, Callable, Iterable, TypeAlias +from warnings import warn +from dateutil.relativedelta import relativedelta +import numpy as np +import pandas as pd +from pydantic import RootModel +from ..models import Table, TimePeriod +from .encode import convert_records_to_columns, to_table + +# Regex pattern to indentify numerical sequences encoded as string +SEQ_PAT = re.compile(r"^(t|p)([0-9]+)$") +# Regex pattern to identify a number encoded as a string +FREQ_PAT = re.compile("^[0-9]+$") +# Regex pattern to duration strings +DUR_PAT = re.compile(r"([0-9]+) *(Y|M|W|D|h|min|s)") + + +def _normalise_freq(freq: int | str): + """Normalise integer/string to frequency. + + The frequency value is as understood by `pandas.Timedelta`. Note + that ambiguous values such as month or year are still retained + with the intention to handle later in the pipeline. + + """ + if isinstance(freq, int): + return str(freq) + "min" + if FREQ_PAT.match(freq): + # If frequency is an integer, the implied unit is "minutes" + return freq + "min" + # not very robust yet + return ( + freq.replace("years", "Y") + .replace("year", "Y") + .replace("months", "M") + .replace("month", "M") + .replace("weeks", "W") + .replace("week", "W") + .replace("days", "D") + .replace("day", "D") + .replace("hours", "h") + .replace("hour", "h") + .replace("minutes", "min") + .replace("minute", "min") + .replace("seconds", "s") + .replace("second", "s") + ) + + +_to_numpy_time_units = { + "Y": "Y", + "M": "M", + "W": "W", + "D": "D", + "h": "h", + "min": "m", + "s": "s", +} + + +def _low_res_datetime(start: str, freq: str, periods: int) -> pd.DatetimeIndex: + """Create pd.DatetimeIndex with lower time resolution. + + The default resolution of pd.date_time is [ns], which puts + boundaries on allowed start- and end-dates due to limited storage + capacity. 
Choosing a resolution of [s] instead opens up that range + considerably. + + "For nanosecond resolution, the time span that can be represented + using a 64-bit integer is limited to approximately 584 years." - + https://pandas.pydata.org/pandas-docs/stable/user_guide/timeseries.html#timestamp-limitations + + You can check the available ranges with `pd.Timestamp.min` and + `pd.Timestamp.max`. + + """ + if re_match := DUR_PAT.match(_normalise_freq(freq)): + number_str, unit = re_match.groups() + else: + raise ValueError(f"invalid frequency: {freq!r}") + + start_date_np = np.datetime64(start, "s") + freq_np = np.timedelta64(int(number_str), _to_numpy_time_units[unit]) + freq_pd = pd.Timedelta(freq_np) + + date_array = np.arange(start_date_np, start_date_np + periods * freq_np, freq_np) + date_array_with_frequency = pd.DatetimeIndex(date_array, freq=freq_pd, dtype="datetime64[s]") + + return date_array_with_frequency + + +def _to_relativedelta(val: str) -> relativedelta: + if (m := DUR_PAT.match(val)) is None: + raise ValueError(f"{val}: bad duration value") + num_str, freq = m.groups() + num = int(num_str) + match freq: + case "Y": + return relativedelta(years=num) + case "M": + return relativedelta(months=num) + case "W": + return relativedelta(weeks=num) + case "D": + return relativedelta(days=num) + case "h": + return relativedelta(hours=num) + case "min": + return relativedelta(minutes=num) + case "s": + return relativedelta(seconds=num) + case _: + # should not get here + raise ValueError(f"{val}: unknown duration") + + +def _atoi(val: str) -> int | str: + """Convert string to number if it matches `t0001` or `p2001`. + + If a match is found, also override the name to "time" or "period" + respectively. + + """ + if m := SEQ_PAT.match(val): + return int(m.group(2)) + else: + return val + + +_FmtIdx: TypeAlias = Callable[[str, str | Any], dict[str, Any]] + + +def _formatter(index_type: str) -> _FmtIdx: + """Get a function that formats the values of a name value pair. + + The name is the column name. The function returned depends on the + `index_type`. An unknown `index_type` returns a noop formatter, + but it also issues a warning. A noop formatter can be requested + explicitly by passing the type "noop"; no warning is issued in + this case. + + Index types: + ============ + + - "date_time" :: converts value to `datetime` + + - "duration" :: converts string to `relativedelta`; this + allows for ambiguous units like month or year. + + - "str" :: convert the value to integer if it matches `t0001` or + `p2002`, and the name to "time" and "period" respectively; + without a match it is a noop. 
+ + - "float" | "noop" :: noop + + - fallback :: noop with a warning + + """ + match index_type: + case "date_time" | "datetime": + return lambda name, key: {name: pd.Timestamp(key)} + case "duration": + return lambda name, key: {name: _to_relativedelta(_normalise_freq(key))} + case "str": + # don't use lambda, can't add type hints + def _atoi_dict(name: str, val: str) -> dict[str, int | str]: + return {name: _atoi(val)} + + return _atoi_dict + case "float" | "noop": + return lambda name, key: {name: key} + case "time_pattern" | "timepattern" | "time-pattern": + return lambda name, key: {name: TimePeriod(key)} + case _: # fallback to noop w/ a warning + warn(f"{index_type}: unknown type, fallback to noop formatter") + return lambda name, key: {name: key} + + +def make_records( + json_doc: dict | int | float | str, + idx_lvls: dict, + res: list[dict], + *, + lvlname_base: str = "col_", +) -> list[dict]: + """Parse parameter value into a list of records + + Spine db stores parameter_value as JSON. After the JSON blob has + been decoded to a Python dict, this function can transform it into + a list of records (dict) like a table. These records can then be + consumed by Pandas to create a dataframe. + + The parsing logic works recursively by traversing depth first. + Each call incrementally accumulates a cell/level of a record in + the `idx_lvls` dictionary, once the traversal reaches a leaf node, + the final record is appended to the list `res`. The final result + is also returned by the function, allowing for composition. + + If at any level, the index level name is missing, a default base + name can be provided by setting a default `lvlname_base`. The + level name is derived by concatenating the base name with depth + level. + + """ + lvlname = lvlname_base + str(len(idx_lvls)) + + # NOTE: The private functions below are closures, defined early in + # the function such that they have the original arguments to + # `make_records` available to them, but nothing more. They either + # help with some computation, raise a warning, or are helpers to + # append to the result. + _msg_assert = "for the type checker: rest of the function expects `json_doc` to be a dict" + + def _uniquify_index_name(default: str) -> str: + assert isinstance(json_doc, dict), _msg_assert + index_name = json_doc.get("index_name", default) + return index_name + f"{len(idx_lvls)}" if index_name in idx_lvls else index_name + + def _from_pairs(data: Iterable[Iterable], fmt: _FmtIdx): + index_name = _uniquify_index_name(lvlname) + for key, val in data: + _lvls = {**idx_lvls, **fmt(index_name, key)} + make_records(val, _lvls, res, lvlname_base=lvlname_base) + + def _deprecated(var: str, val: Any): + assert isinstance(json_doc, dict), _msg_assert + index_name = json_doc.get("index_name", lvlname) + msg = f"{index_name}: {var}={val} is deprecated, handle in model, defaulting to time index from 0001-01-01." + warn(msg, DeprecationWarning) + + def _time_index(idx: dict, length: int): + start = idx.get("start", "0001-01-01T00:00:00") + resolution = idx.get("resolution", "1h") + return _low_res_datetime(start=start, freq=resolution, periods=length) + + def _append_arr(arr: Iterable, fmt: _FmtIdx): + index_name = _uniquify_index_name("i") + for value in arr: + res.append({**idx_lvls, **fmt(index_name, value)}) + + match json_doc: + # maps + case {"data": dict() as data, "type": "map"}: + # NOTE: is "index_type" mandatory? 
In case it's not, we + # check for it separately, and fallback in a way that + # raises a warning but doesn't crash; same for the + # 2-column array variant below. + index_type = json_doc.get("index_type", "undefined-index_type-in-map") + _from_pairs(data.items(), _formatter(index_type)) + case {"data": dict() as data, "index_type": index_type}: + # NOTE: relies on other types not having "index_type"; + # same for the 2-column array variant below. + _from_pairs(data.items(), _formatter(index_type)) + case {"data": [[_, _], *_] as data, "type": "map"}: + index_type = json_doc.get("index_type", "undefined-index_type-in-map") + _from_pairs(data, _formatter(index_type)) + case {"data": [[_, _], *_] as data, "index_type": index_type}: + _from_pairs(data, _formatter(index_type)) + # time series + case {"data": dict() as data, "type": "time_series"}: + _from_pairs(data.items(), _formatter("date_time")) + case {"data": [[str(), float() | int()], *_] as data, "type": "time_series"}: + _from_pairs(data, _formatter("date_time")) + case { + "data": [float() | int(), *_] as data, + "type": "time_series", + "index": dict() as idx, + }: + match idx: + case {"ignore_year": ignore_year}: + _deprecated("ignore_year", ignore_year) + case {"repeat": repeat}: + _deprecated("repeat", repeat) + + index = _time_index(idx, len(data)) + _from_pairs(zip(index, data), _formatter("noop")) + case {"type": "time_series", "data": [float() | int(), *_] as data}: + msg = "array-like 'time_series' without time-stamps, relies on 'ignore_year' and 'repeat' implicitly" + warn(msg, DeprecationWarning) + updated = {**json_doc, "index": {"ignore_year": True, "repeat": True}} + make_records(updated, idx_lvls, res, lvlname_base=lvlname_base) + # time_pattern + case {"type": "time_pattern", "data": dict() as data}: + _from_pairs(data.items(), _formatter("time_pattern")) + # arrays + case { + "type": "array", + "value_type": value_type, + "data": [str() | float() | int(), *_] as data, + }: + _append_arr(data, _formatter(value_type)) + case {"type": "array", "data": [float() | int(), *_] as data}: + _append_arr(data, _formatter("float")) + # date_time | duration + case { + "type": "date_time" | "duration" as data_t, + "data": str() | int() as data, + }: + _fmt = _formatter(data_t) + res.append({**idx_lvls, **_fmt(data_t, data)}) + # values + case int() | float() | str() | bool() as data: + _fmt = _formatter("noop") + res.append({**idx_lvls, **_fmt("value", data)}) + case _: + raise ValueError(f"match not found: {json_doc}") + return res + + +def transition_data(old_json_bytes: bytes) -> bytes: + records = make_records(json.loads(old_json_bytes), {}, []) + columns = convert_records_to_columns(records) + table = to_table(columns) + return RootModel[Table](table).model_dump_json().encode() diff --git a/spinedb_api/compat/encode.py b/spinedb_api/compat/encode.py new file mode 100644 index 00000000..d6b9b34c --- /dev/null +++ b/spinedb_api/compat/encode.py @@ -0,0 +1,107 @@ +###################################################################################################################### +# Copyright (C) 2017-2022 Spine project consortium +# Copyright Spine Database API contributors +# This file is part of Spine Database API. +# Spine Database API is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser +# General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your +# option) any later version. 
This program is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; +# without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General +# Public License for more details. You should have received a copy of the GNU Lesser General Public License along with +# this program. If not, see . +###################################################################################################################### +"""Encode Python sequences into Array types supported by JSON blobs""" + +from datetime import datetime +import enum +from itertools import chain +from types import NoneType +from typing import Any, Sequence, TypeVar +from dateutil.relativedelta import relativedelta +import pandas as pd +from ..models import ( + AnyArray, + Array, + ArrayIndex, + DictEncodedArray, + DictEncodedIndex, + RunEndArray, + RunEndIndex, + Table, + type_map, +) + + +def convert_records_to_columns(recs: list[dict[str, Any]]) -> dict[str, list]: + nrows = len(recs) + columns: dict[str, list] = {k: [None] * nrows for k in chain.from_iterable(recs)} + for i, rec in enumerate(recs): + for col in rec: + columns[col][i] = rec[col] + return columns + + +_sentinel = enum.Enum("_sentinel", "value") +SENTINEL = _sentinel.value + +re_t = TypeVar("re_t", RunEndArray, RunEndIndex) + + +def re_encode(name: str, vals: list, array_t: type[re_t]) -> re_t: + last = SENTINEL + values, run_end = [], [] + for idx, val in enumerate(vals, start=1): + if last != val: + values.append(val) + run_end.append(idx) + else: + run_end[-1] = idx + last = val + return array_t(name=name, values=values, run_end=run_end) + + +de_t = TypeVar("de_t", DictEncodedArray, DictEncodedIndex) + + +def de_encode(name: str, value_type: str, vals: list, array_t: type[de_t]) -> de_t: + # not using list(set(...)) to preserve order + values = list(dict.fromkeys(vals)) + indices = list(map(values.index, vals)) + return array_t(name=name, value_type=value_type, values=values, indices=indices) + + +def is_any_w_none(arr: Sequence) -> tuple[bool, bool]: + all_types = set(map(type, arr)) + has_none = NoneType in all_types + return len(all_types - {NoneType}) > 1, has_none + + +def to_array(name: str, col: list): + any_type, has_none = is_any_w_none(col) + if any_type: + return AnyArray(name=name, values=col) + + match name, col, has_none: + case "value", list(), _: + return Array(name=name, value_type=type_map[type(col[0])], values=col) + case _, [float(), *_], _: + return Array(name=name, value_type="float", values=col) + case _, [bool(), *_], _: + return Array(name=name, value_type="bool", values=col) + case _, [int(), *_], True: + return Array(name=name, value_type="int", values=col) + case _, [int(), *_], False: + return ArrayIndex(name=name, value_type="int", values=col) + case _, [pd.Timestamp() | datetime(), *_], False: + return ArrayIndex(name=name, value_type="date_time", values=col) + case _, [relativedelta(), *_], False: + return ArrayIndex(name=name, value_type="duration", values=col) + case _, [str(), *_], True: + return de_encode(name, "str", col, DictEncodedArray) + case _, [str(), *_], False: + return de_encode(name, "str", col, DictEncodedIndex) + case _, _, _: + raise NotImplementedError(f"{name}: unknown column type {type(col[0])} ({has_none=})") + + +def to_table(columns: dict[str, list]) -> Table: + return [to_array(name, col) for name, col in columns.items()] diff --git a/spinedb_api/db_mapping.py b/spinedb_api/db_mapping.py index 5017d6fb..53c52f3a 100644 --- 
a/spinedb_api/db_mapping.py +++ b/spinedb_api/db_mapping.py @@ -32,7 +32,7 @@ from sqlalchemy.exc import ArgumentError, DatabaseError, DBAPIError from sqlalchemy.orm import Session from sqlalchemy.pool import NullPool, StaticPool -from .compatibility import CompatibilityTransformations, compatibility_transformations +from .compat.compatibility_transformations import CompatibilityTransformations, compatibility_transformations from .db_mapping_base import DatabaseMappingBase, MappedItemBase, MappedTable, PublicItem from .db_mapping_commit_mixin import DatabaseMappingCommitMixin from .db_mapping_query_mixin import DatabaseMappingQueryMixin @@ -534,7 +534,7 @@ def item(self, mapped_table: MappedTable, **kwargs) -> PublicItem: def get_or_add_by_type(self, item_type: str, **kwargs) -> PublicItem: return self.get_or_add(self.mapped_table(item_type), **kwargs) - + def get_or_add(self, mapped_table: MappedTable, **kwargs) -> PublicItem: try: return self.item(mapped_table, **kwargs) diff --git a/spinedb_api/db_mapping_helpers.py b/spinedb_api/db_mapping_helpers.py index 0f2c7daf..f7bcde3c 100644 --- a/spinedb_api/db_mapping_helpers.py +++ b/spinedb_api/db_mapping_helpers.py @@ -13,9 +13,10 @@ This module defines functions, classes and other utilities that may be useful with :class:`.db_mapping.DatabaseMapping`. """ -from spinedb_api.db_mapping_base import PublicItem -from spinedb_api.mapped_items import ParameterDefinitionItem -from spinedb_api.parameter_value import UNPARSED_NULL_VALUE, Map, from_database_to_dimension_count, type_for_value +from .db_mapping_base import PublicItem +from .incomplete_values import from_database_to_dimension_count +from .mapped_items import ParameterDefinitionItem +from .parameter_value import UNPARSED_NULL_VALUE, Map, type_and_rank_for_value # Here goes stuff that depends on `database_mapping`, `mapped_items` etc. # and thus cannot go to `helpers` due to circular imports. @@ -74,4 +75,4 @@ def is_parameter_type_valid(parameter_types, database_value, value, value_type): return True rank = from_database_to_dimension_count(database_value, value_type) return any(rank == type_and_rank[1] for type_and_rank in parameter_types if type_and_rank[0] == Map.TYPE) - return any(type_for_value(value) == type_and_rank for type_and_rank in parameter_types) + return any(type_and_rank_for_value(value) == type_and_rank for type_and_rank in parameter_types) diff --git a/spinedb_api/export_mapping/export_mapping.py b/spinedb_api/export_mapping/export_mapping.py index 3037ca3b..36087bfa 100644 --- a/spinedb_api/export_mapping/export_mapping.py +++ b/spinedb_api/export_mapping/export_mapping.py @@ -21,12 +21,12 @@ from sqlalchemy.orm import Query from sqlalchemy.sql.expression import CacheKey from .. 
import DatabaseMapping +from ..incomplete_values import from_database_to_dimension_count from ..mapping import Mapping, Position, is_pivoted, is_regular, unflatten from ..parameter_value import ( IndexedValue, convert_containers_to_maps, from_database, - from_database_to_dimension_count, from_database_to_single_value, type_for_scalar, ) diff --git a/spinedb_api/helpers.py b/spinedb_api/helpers.py index a9703837..46e1fe3e 100644 --- a/spinedb_api/helpers.py +++ b/spinedb_api/helpers.py @@ -17,7 +17,7 @@ import json from operator import itemgetter import os -from typing import Any +from typing import Any, Literal import warnings from alembic.config import Config from alembic.environment import EnvironmentContext @@ -57,6 +57,7 @@ from sqlalchemy.ext.compiler import compiles from sqlalchemy.sql.expression import FunctionElement, bindparam, cast from sqlalchemy.sql.selectable import SelectBase +from typing_extensions import TypedDict from .exception import SpineDBAPIError, SpineDBVersionError SUPPORTED_DIALECTS = { @@ -936,6 +937,23 @@ def string_to_bool(string: str) -> bool: raise ValueError(string) +class FormatMetadata(TypedDict): + format: Literal["time_period"] + + +def time_period_format_specification() -> FormatMetadata: + return {"format": "time_period"} + + +class TimeSeriesMetadata(TypedDict): + ignore_year: bool + repeat: bool + + +def time_series_metadata(ignore_year: bool, repeat: bool) -> TimeSeriesMetadata: + return {"ignore_year": ignore_year, "repeat": repeat} + + @enum.unique class DisplayStatus(enum.Enum): """Custom enum for entity class display status.""" diff --git a/spinedb_api/import_mapping/generator.py b/spinedb_api/import_mapping/generator.py index cd6dc8ff..16f27d89 100644 --- a/spinedb_api/import_mapping/generator.py +++ b/spinedb_api/import_mapping/generator.py @@ -16,10 +16,13 @@ """ from collections.abc import Callable from copy import deepcopy +import json from operator import itemgetter from typing import Any, Optional +from .. import arrow_value from ..exception import ParameterValueFormatError from ..helpers import string_to_bool +from ..incomplete_values import split_value_and_type from ..mapping import Position, is_pivoted from ..parameter_value import ( Array, @@ -28,10 +31,10 @@ TimeSeriesVariableResolution, convert_leaf_maps_to_specialized_containers, from_database, - split_value_and_type, ) from .import_mapping import ImportMapping, check_validity from .import_mapping_compat import import_mapping_from_dict +from .type_conversion import JSONObject _NO_VALUE = object() @@ -386,11 +389,11 @@ def _make_value(row, value_pos): if "data" not in value: return _NO_VALUE return _parameter_value_from_dict(value) - if isinstance(value, str): + if isinstance(value, JSONObject): try: - return from_database(*split_value_and_type(value)) - except ParameterValueFormatError: - pass + return arrow_value.from_database(*split_value_and_type(value.json_string)) + except (json.JSONDecodeError, TypeError, KeyError, ParameterValueFormatError): + value = value.json_string return value diff --git a/spinedb_api/import_mapping/type_conversion.py b/spinedb_api/import_mapping/type_conversion.py index daf48742..0029d141 100644 --- a/spinedb_api/import_mapping/type_conversion.py +++ b/spinedb_api/import_mapping/type_conversion.py @@ -11,7 +11,7 @@ ###################################################################################################################### """ Type conversion functions. 
""" - +import json import re from spinedb_api.helpers import string_to_bool from spinedb_api.parameter_value import DateTime, Duration, ParameterValueFormatError @@ -27,7 +27,8 @@ def value_to_convert_spec(value): "float": FloatConvertSpec, "string": StringConvertSpec, "boolean": BooleanConvertSpec, - }.get(value) + "json": JSONConvertSpec, + }[value] return spec() if isinstance(value, dict): start_datetime = DateTime(value.get("start_datetime")) @@ -78,7 +79,16 @@ class BooleanConvertSpec(ConvertSpec): RETURN_TYPE = bool def __call__(self, value): - return self.RETURN_TYPE(string_to_bool(str(value))) + return string_to_bool(str(value)) + +class JSONObject: + def __init__(self, string: str): + self.json_string = string + + +class JSONConvertSpec(ConvertSpec): + DISPLAY_NAME = "JSON" + RETURN_TYPE = JSONObject class IntegerSequenceDateTimeConvertSpec(ConvertSpec): diff --git a/spinedb_api/incomplete_values.py b/spinedb_api/incomplete_values.py new file mode 100644 index 00000000..3005bf90 --- /dev/null +++ b/spinedb_api/incomplete_values.py @@ -0,0 +1,94 @@ +###################################################################################################################### +# Copyright (C) 2017-2022 Spine project consortium +# Copyright Spine Database API contributors +# This file is part of Spine Database API. +# Spine Database API is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser +# General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your +# option) any later version. This program is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; +# without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General +# Public License for more details. You should have received a copy of the GNU Lesser General Public License along with +# this program. If not, see . +###################################################################################################################### +"""This module contains utilities that deal with value blobs or JSON representations.""" + +import json +from typing import Optional +from .parameter_value import RANK_1_TYPES, TABLE_TYPE, Map, from_dict, to_database, type_for_scalar +from .value_support import JSONValue, load_db_value + + +def dump_db_value(parsed_value: JSONValue) -> tuple[bytes, str]: + """ + Unparses a JSON object into a binary blob and type string. + + If the given object is a dict, extracts the "type" property from it. + + :meta private: + + Args: + parsed_value: A JSON object, typically obtained by calling :func:`load_db_value`. + + Returns: + database representation (value and type). + """ + if isinstance(parsed_value, dict): + value_type = parsed_value.pop("type") + value = from_dict(parsed_value, value_type) + return to_database(value) + if isinstance(parsed_value, list): + value_type = TABLE_TYPE + else: + value_type = type_for_scalar(parsed_value) + db_value = json.dumps(parsed_value).encode("UTF8") + return db_value, value_type + + +def from_database_to_dimension_count(database_value: bytes, value_type: Optional[str]) -> int: + """ + Counts the dimensions in a database representation of a parameter value (value and type). 
+ + :meta private: + + Args: + database_value: the database value + value_type: the value type + + Returns: + number of dimensions + """ + if value_type in RANK_1_TYPES: + return 1 + if value_type == Map.TYPE or value_type == TABLE_TYPE: + parsed = load_db_value(database_value) + return len(parsed) - 1 + return 0 + + +def join_value_and_type(db_value: bytes, db_type: Optional[str]) -> str: + """Joins value blob and type into list and dumps it into JSON string. + + Args: + db_value: database value + db_type: value type + + Returns: + JSON string. + """ + return json.dumps([db_value.decode(), db_type]) + + +def split_value_and_type(value_and_type: str) -> tuple[bytes, str]: + """Splits the given JSON string into value blob and type. + + Args: + value_and_type: a string joining value and type, as obtained by calling :func:`join_value_and_type`. + + Returns: + value blob and type. + """ + parsed = json.loads(value_and_type) + if isinstance(parsed, dict): + # legacy + value_dict = parsed + return to_database(from_dict(value_dict, value_dict["type"])) + return parsed[0].encode(), parsed[1] diff --git a/spinedb_api/mapped_items.py b/spinedb_api/mapped_items.py index f9489426..e4106097 100644 --- a/spinedb_api/mapped_items.py +++ b/spinedb_api/mapped_items.py @@ -10,7 +10,7 @@ # this program. If not, see . ###################################################################################################################### from __future__ import annotations -from collections.abc import Iterator +from collections.abc import Iterable, Iterator from contextlib import suppress import inspect from operator import itemgetter @@ -668,14 +668,11 @@ def __getitem__(self, key): return self._arrow_value return super().__getitem__(key) - def merge(self, other): - merged, updated_fields = super().merge(other) - if not merged: - return merged, updated_fields - if self.value_key in merged: + def update(self, other): + super().update(other) + if self.value_key in other: self._parsed_value = None self._arrow_value = None - return merged, updated_fields def _strip_equal_fields(self, other): undefined = object() @@ -1116,7 +1113,7 @@ class ListValueItem(ParsedValueBase): "parsed_value": {"type": ParameterValue, "value": "The value.", "optional": True}, "index": {"type": int, "value": "The value index.", "optional": True}, } - unique_keys = (("parameter_value_list_name", "value_and_type"), ("parameter_value_list_name", "index")) + unique_keys = (("parameter_value_list_name", "parsed_value", "type"), ("parameter_value_list_name", "index")) required_key_combinations = ( ("parameter_value_list_name", "parameter_value_list_id"), ( @@ -1136,7 +1133,7 @@ class ListValueItem(ParsedValueBase): def __getitem__(self, key): if key == "value_and_type": - return (self["value"], self["type"]) + return (super().__getitem__("value"), super().__getitem__("type")) return super().__getitem__(key) diff --git a/spinedb_api/models.py b/spinedb_api/models.py new file mode 100755 index 00000000..4312169a --- /dev/null +++ b/spinedb_api/models.py @@ -0,0 +1,392 @@ +###################################################################################################################### +# Copyright (C) 2017-2022 Spine project consortium +# Copyright Spine Database API contributors +# This file is part of Spine Database API.
+# Spine Database API is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser +# General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your +# option) any later version. This program is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; +# without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General +# Public License for more details. You should have received a copy of the GNU Lesser General Public License along with +# this program. If not, see . +###################################################################################################################### + +"""Write JSON schema for JSON blob in SpineDB""" + +from datetime import datetime, timedelta +from types import NoneType +from typing import Annotated, Literal, TypeAlias, TypedDict +from dateutil.relativedelta import relativedelta +import numpy as np +import pandas as pd +from pydantic import ( + BeforeValidator, + Field, + PlainSerializer, + PlainValidator, + RootModel, + TypeAdapter, + WithJsonSchema, + model_validator, +) +from pydantic.dataclasses import dataclass +from typing_extensions import NotRequired, Self +from .compat.converters import to_duration, to_relativedelta +from .helpers import FormatMetadata, TimeSeriesMetadata + + +def from_timestamp(ts: str | pd.Timestamp | datetime) -> datetime: + match ts: + # NOTE: `pd.Timestamp` is a subtype of `datetime`, so it must be matched first + case pd.Timestamp(): + return ts.to_pydatetime() + case datetime(): + return ts + case str(): + return datetime.fromisoformat(ts) + case _: + raise ValueError(f"{ts}: could not coerce to `datetime`") + + +def validate_relativedelta(value: str | pd.DateOffset | timedelta | relativedelta) -> relativedelta: + match value: + case relativedelta(): + return value + case str() | pd.DateOffset() | timedelta(): + return to_relativedelta(value) + case _: + raise ValueError(f"{value}: cannot coerce `{type(value)}` to `relativedelta`") + + +# types +class TimePeriod(str): + """Wrapper type necessary for data migration. + + This is necessary to discriminate from regular strings + during DB migration. In the future, once the migration script + no longer needs to be supported, this type can be removed, and the + `TimePeriod_` annotation below can just use `str`. Something like + this: + + ..
sourcecode:: python + + TimePeriod_: TypeAlias = Annotated[ + str, + WithJsonSchema( + {"type": "string", "format": "time_period"}, + mode="serialization" + ), + ] + + """ + + def __init__(self, value) -> None: + if not isinstance(value, str): + raise ValueError(f"{type(value)}: non-string values cannot be a TimePeriod") + super().__init__() + + +# annotations for validation +Datetime: TypeAlias = Annotated[datetime, BeforeValidator(from_timestamp)] +RelativeDelta: TypeAlias = Annotated[ + relativedelta, + PlainValidator(validate_relativedelta), + PlainSerializer(to_duration, when_used="json"), + WithJsonSchema({"type": "string", "format": "duration"}, mode="serialization"), +] +TimePeriod_: TypeAlias = Annotated[ + TimePeriod, + PlainValidator(TimePeriod), + PlainSerializer(str), + WithJsonSchema({"type": "string", "format": "time_period"}, mode="serialization"), +] +Metadata = FormatMetadata | TimeSeriesMetadata + +# non-nullable arrays +Floats: TypeAlias = list[float] +Integers: TypeAlias = list[int] +Strings: TypeAlias = list[str] +Booleans: TypeAlias = list[bool] +Datetimes: TypeAlias = list[Datetime] +Durations: TypeAlias = list[RelativeDelta] +TimePeriods_: TypeAlias = list[TimePeriod_] + +# nullable variant of arrays +NullableIntegers: TypeAlias = list[int | None] +NullableFloats: TypeAlias = list[float | None] +NullableStrings: TypeAlias = list[str | None] +NullableBooleans: TypeAlias = list[bool | None] +NullableDatetimes: TypeAlias = list[Datetime | None] +NullableDurations: TypeAlias = list[RelativeDelta | None] +NullableTimePeriods_: TypeAlias = list[TimePeriod_ | None] + +# sets of types used to define array schemas below +NullableTypes: TypeAlias = ( + NullableIntegers + | NullableFloats + | NullableStrings + | NullableBooleans + | NullableDatetimes + | NullableDurations + | NullableTimePeriods_ +) + +# names of types used in the schema +NullTypeName: TypeAlias = Literal["null"] +TypeNames: TypeAlias = Literal["int", "float", "str", "bool", "date_time", "duration", "time_period"] + +typename_map: dict[NullTypeName | TypeNames, type] = { + "str": str, + "int": int, + "float": float, + "bool": bool, + "date_time": Datetime, + "duration": RelativeDelta, + "time_period": TimePeriod, + "null": NoneType, +} +type_map: dict[type, TypeNames] = { + str: "str", + int: "int", + np.int8: "int", + np.int16: "int", + np.int32: "int", + np.int64: "int", + float: "float", + np.float16: "float", + np.float32: "float", + np.float64: "float", + # np.float128: "float", # not available on macos + bool: "bool", + np.bool: "bool", + datetime: "date_time", + pd.Timestamp: "date_time", + timedelta: "duration", + pd.Timedelta: "duration", + relativedelta: "duration", + pd.DateOffset: "duration", +} + + +class TypeAdapterMixin: + @model_validator(mode="after") + def convert_to_final_type(self) -> Self: + value_type = getattr(self, "value_type") + values = getattr(self, "values") + if value_type in ("date_time", "duration", "time_period"): + adapter = TypeAdapter(typename_map[value_type]) + for i in range(len(values)): + value = values[i] + values[i] = adapter.validate_python(values[i]) if value is not None else None + return self + + +@dataclass(frozen=True) +class RunLengthIndex(TypeAdapterMixin): + """Run length encoded array + + NOTE: this is not supported by PyArrow, if we use it, we will have + to convert to a supported format. 
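+    Presumably each entry of ``values`` repeats ``run_len`` times here, e.g. ``run_len=[2, 3]``
+    with ``values=[1.0, 2.0]`` would stand for the expanded column ``[1.0, 1.0, 2.0, 2.0, 2.0]``.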
+ + """ + + name: str + run_len: Integers + values: NullableTypes + value_type: TypeNames + metadata: Metadata | None = None + type: Literal["run_length_index"] = "run_length_index" + + +@dataclass(frozen=True) +class RunLengthArray(TypeAdapterMixin): + """Run length encoded array + + NOTE: this is not supported by PyArrow, if we use it, we will have + to convert to a supported format. + + """ + + name: str + run_len: Integers + values: NullableTypes + value_type: TypeNames + metadata: Metadata | None = None + type: Literal["run_length_array"] = "run_length_array" + + +@dataclass(frozen=True) +class RunEndIndex(TypeAdapterMixin): + """Run end encoded array""" + + name: str + run_end: Integers + values: NullableTypes + value_type: TypeNames + metadata: Metadata | None = None + type: Literal["run_end_index"] = "run_end_index" + + +@dataclass(frozen=True) +class RunEndArray(TypeAdapterMixin): + """Run end encoded array""" + + name: str + run_end: Integers + values: NullableTypes + value_type: TypeNames + metadata: Metadata | None = None + type: Literal["run_end_array"] = "run_end_array" + + +@dataclass(frozen=True) +class DictEncodedIndex(TypeAdapterMixin): + """Dictionary encoded array""" + + name: str + indices: Integers + values: NullableTypes + value_type: TypeNames + metadata: Metadata | None = None + type: Literal["dict_encoded_index"] = "dict_encoded_index" + + +@dataclass(frozen=True) +class DictEncodedArray(TypeAdapterMixin): + """Dictionary encoded array""" + + name: str + indices: NullableIntegers + values: NullableTypes + value_type: TypeNames + metadata: Metadata | None = None + type: Literal["dict_encoded_array"] = "dict_encoded_array" + + +@dataclass(frozen=True) +class ArrayIndex(TypeAdapterMixin): + """Any array that is an index, e.g. 
a sequence, timestamps, labels""" + + name: str + values: NullableTypes + value_type: TypeNames + metadata: Metadata | None = None + type: Literal["array_index"] = "array_index" + + @model_validator(mode="after") + def convert_to_final_type(self) -> Self: + if self.value_type in ("date_time", "duration", "time_period"): + adapter = TypeAdapter(typename_map[self.value_type]) + for i in range(len(self.values)): + value = self.values[i] + self.values[i] = adapter.validate_python(value) if value is not None else None + return self + + +@dataclass(frozen=True) +class Array(TypeAdapterMixin): + """Array""" + + name: str + values: NullableTypes + value_type: TypeNames + metadata: Metadata | None = None + type: Literal["array"] = "array" + + +# NOTE: anyarray excludes "time_period" +AnyType: TypeAlias = str | int | float | bool | RelativeDelta | Datetime +NullableAnyTypes: TypeAlias = list[AnyType | None] +AnyTypeNames: TypeAlias = Literal["int", "float", "str", "bool", "date_time", "duration"] + + +@dataclass(frozen=True) +class AnyArray: + """Array with mixed types""" + + name: str + values: NullableAnyTypes + value_types: list[AnyTypeNames | NullTypeName] + metadata: Metadata | None = None + type: Literal["any_array"] = "any_array" + + @model_validator(mode="after") + def convert_to_final_type(self) -> Self: + if len(self.values) != len(self.value_types): + raise ValueError("mismatching values and value_types") + + for i in range(len(self.values)): + val = self.values[i] + typ = self.value_types[i] + self.values[i] = TypeAdapter(typename_map[typ]).validate_python(val) + return self + + +# NOTE: To add run-length encoding to the schema, add it to the +# following type union following which, we need to implement a +# converter to a compatible pyarrow array type +AllArrays: TypeAlias = RunEndIndex | DictEncodedIndex | ArrayIndex | RunEndArray | DictEncodedArray | Array | AnyArray +Table: TypeAlias = list[Annotated[AllArrays, Field(discriminator="type")]] + + +def from_json(json_str: str, type_: type[Table | AllArrays] = Table): + """Generic wrapper for JSON parsing.""" + return TypeAdapter(type_).validate_json(json_str) + + +def from_dict(value: dict, type_: type[Table | AllArrays] = Table): + """Generic wrapper for converting from a dictionary.""" + return TypeAdapter(type_).validate_python(value) + + +def to_json(obj: Table | AllArrays) -> str: + """Generic wrapper to serialise to JSON.""" + # FIXME: check why the equivalent: TypeAdapter(obj).dump_json() isn't working + return RootModel[type(obj)](obj).model_dump_json() + + +class ArrayAsDict(TypedDict): + name: str + type: str + values: list + value_type: NotRequired[TypeNames] + value_types: NotRequired[list[AnyTypeNames | NullTypeName]] + metadata: NotRequired[Metadata | None] + indices: NotRequired[list] + run_end: NotRequired[Integers] + run_len: NotRequired[Integers] + + +def dict_to_array(data: ArrayAsDict) -> AllArrays: + """Wrapper to read structured dictionary as an array.""" + match data["type"]: + case "array": + type_ = Array + case "array_index": + type_ = ArrayIndex + case "dict_encoded_array": + type_ = DictEncodedArray + case "dict_encoded_index": + type_ = DictEncodedIndex + case "run_end_array": + type_ = RunEndArray + case "run_end_index": + type_ = RunEndIndex + case "any_array": + type_ = AnyArray + case _: + raise ValueError(f"{data['type']}: unknown array type") + + return TypeAdapter(type_).validate_python(data) + + +if __name__ == "__main__": + from argparse import ArgumentParser + import json + from pathlib import Path 
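+    # Illustrative usage sketch (editor addition, assuming the package is importable as
+    # ``spinedb_api``): the helpers above are meant for round-tripping column dictionaries,
+    # e.g.
+    #
+    #     arr = dict_to_array({"name": "i", "type": "array_index",
+    #                          "values": [1, 2, 3], "value_type": "int"})
+    #     restored = from_json(to_json(arr), type_=ArrayIndex)  # should equal ``arr``
+    #
+    # Running ``python -m spinedb_api.models schema.json`` then writes the serialization
+    # JSON schema of the ``Table`` model to the given file.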
+ + parser = ArgumentParser(__doc__) + parser.add_argument("json_file", help="Path of JSON schema file to write") + opts = parser.parse_args() + + schema = TypeAdapter(Table).json_schema(mode="serialization") + Path(opts.json_file).write_text(json.dumps(schema)) diff --git a/spinedb_api/parameter_value.py b/spinedb_api/parameter_value.py index 1bfccaa6..51b403e5 100644 --- a/spinedb_api/parameter_value.py +++ b/spinedb_api/parameter_value.py @@ -84,17 +84,33 @@ from __future__ import annotations from collections.abc import Callable, Iterable, Sequence from copy import copy +from dataclasses import dataclass from datetime import datetime from itertools import takewhile import json from json.decoder import JSONDecodeError import re -from typing import Any, Literal, Optional, SupportsFloat, Type, TypeAlias, Union +from typing import Any, ClassVar, Counter, Literal, Optional, SupportsFloat, Type, TypeAlias, Union import dateutil.parser from dateutil.relativedelta import relativedelta import numpy as np import numpy.typing as nptyping -from .exception import ParameterValueFormatError +import pyarrow +from pyarrow import compute +from .arrow_value import ( + load_field_metadata, + to_record_batch, + with_column_as_time_period, + with_column_as_time_stamps, + with_field_metadata, +) +from .compat.converters import parse_duration, to_duration +from .exception import ParameterValueFormatError, SpineDBAPIError +from .helpers import TimeSeriesMetadata, time_series_metadata +from .models import NullTypeName, TypeNames +from .value_support import JSONValue, load_db_value, to_union_array, validate_time_period + +TABLE_TYPE: Literal["table"] = "table" # Defaulting to seconds precision in numpy. NUMPY_DATETIME_DTYPE = "datetime64[s]" @@ -106,6 +122,7 @@ # Default unit if resolution is given as a number instead of a string. _TIME_SERIES_PLAIN_INDEX_UNIT = "m" FLOAT_VALUE_TYPE = "float" +INT_VALUE_TYPE = "int" BOOLEAN_VALUE_TYPE = "bool" STRING_VALUE_TYPE = "str" @@ -116,7 +133,7 @@ ] -def from_database(value: bytes, type_: Optional[str]) -> Optional[Value]: +def from_database(value: bytes, type_: Optional[str]) -> Optional[Value | ParameterValue]: """ Converts a parameter value from the DB into a Python object. @@ -127,17 +144,30 @@ def from_database(value: bytes, type_: Optional[str]) -> Optional[Value]: Returns: A Python object representing the value. """ - parsed = load_db_value(value, type_) + parsed = load_db_value(value) if isinstance(parsed, dict): - return from_dict(parsed) - if isinstance(parsed, bool): - return parsed - if isinstance(parsed, SupportsFloat): + return from_dict(parsed, type_) + if type_ == TABLE_TYPE: + return to_record_batch(parsed) + if type_ == DateTime.TYPE: + return DateTime(parsed) + if type_ == Duration.TYPE: + return Duration(parse_duration(parsed)) + if type_ == Array.TYPE: + return Array.from_arrow(to_record_batch(parsed)) + if type_ == TimePattern.TYPE: + return TimePattern.from_arrow(to_record_batch(parsed)) + if type_ == TimeSeries.TYPE: + return TimeSeriesVariableResolution.from_arrow(to_record_batch(parsed)) + if type_ == Map.TYPE: + return Map.from_arrow(to_record_batch(parsed)) + if isinstance(parsed, int) and type_ == FLOAT_VALUE_TYPE: + # json.dumps() writes floats without decimals as ints. 
return float(parsed) return parsed -def to_database(parsed_value: Optional[Value]) -> tuple[bytes, Optional[str]]: +def to_database(parsed_value: Optional[Value | ParameterValue | dict]) -> tuple[bytes, Optional[str]]: """ Converts a Python object representing a parameter value into its DB representation. @@ -147,10 +177,23 @@ def to_database(parsed_value: Optional[Value]) -> tuple[bytes, Optional[str]]: Returns: The value as a binary blob and its type string. """ - if hasattr(parsed_value, "to_database"): - return parsed_value.to_database() + if isinstance(parsed_value, Duration): + return json.dumps(to_duration(parsed_value.value)).encode("UTF8"), parsed_value.TYPE + if isinstance(parsed_value, relativedelta): + return json.dumps(to_duration(parsed_value)).encode("UTF8"), Duration.TYPE + if isinstance(parsed_value, DateTime): + return json.dumps(parsed_value.value.isoformat()).encode("UTF8"), parsed_value.TYPE + if isinstance(parsed_value, datetime): + return json.dumps(parsed_value.isoformat()).encode("UTF8"), DateTime.TYPE + if isinstance(parsed_value, IndexedValue): + return json.dumps(to_list(parsed_value.as_arrow())).encode("UTF8"), parsed_value.TYPE + if isinstance(parsed_value, pyarrow.RecordBatch): + return json.dumps(to_list(parsed_value)).encode(), TABLE_TYPE + if isinstance(parsed_value, dict): + db_type = parsed_value.pop("type") + else: + db_type = type_for_scalar(parsed_value) db_value = json.dumps(parsed_value).encode("UTF8") - db_type = type_for_scalar(parsed_value) return db_value, db_type @@ -227,54 +270,6 @@ def relativedelta_to_duration(delta: relativedelta) -> str: return "0h" -JSONValue = Union[bool, float, str, dict] - - -def load_db_value(db_value: bytes, type_: Optional[str]) -> Optional[JSONValue]: - """ - Parses a binary blob into a JSON object. - - If the result is a dict, adds the "type" property to it. - - :meta private: - - Args: - db_value: The binary blob. - type_: The value type. - - Returns: - The parsed parameter value. - """ - if db_value is None: - return None - try: - parsed = json.loads(db_value) - except JSONDecodeError as err: - raise ParameterValueFormatError(f"Could not decode the value: {err}") from err - if isinstance(parsed, dict): - parsed["type"] = type_ - return parsed - - -def dump_db_value(parsed_value: JSONValue) -> tuple[bytes, str]: - """ - Unparses a JSON object into a binary blob and type string. - - If the given object is a dict, extracts the "type" property from it. - - :meta private: - - Args: - parsed_value: A JSON object, typically obtained by calling :func:`load_db_value`. - - Returns: - database representation (value and type). - """ - value_type = parsed_value["type"] if isinstance(parsed_value, dict) else type_for_scalar(parsed_value) - db_value = json.dumps(parsed_value).encode("UTF8") - return db_value, value_type - - def from_database_to_single_value(database_value: bytes, value_type: Optional[str]) -> Union[str, Optional[Value]]: """ Same as :func:`from_database`, but in the case of indexed types returns just the type as a string. @@ -293,43 +288,19 @@ def from_database_to_single_value(database_value: bytes, value_type: Optional[st return value_type -def from_database_to_dimension_count(database_value: bytes, value_type: Optional[str]) -> int: - """ - Counts the dimensions in a database representation of a parameter value (value and type). 
- - :meta private: - - Args: - database_value: the database value - value_type: the value type - - Returns: - number of dimensions - """ - if value_type in RANK_1_TYPES: - return 1 - if value_type == Map.TYPE: - parsed = load_db_value(database_value, value_type) - if "rank" in parsed: - return parsed["rank"] - map_value = from_dict(parsed) - return map_dimensions(map_value) - return 0 - - -def from_dict(value: dict) -> Optional[Value]: +def from_dict(value: dict, value_type: str) -> Optional[Value]: """ Converts a dictionary representation of a parameter value into an encoded parameter value. :meta private: Args: - value: the value dictionary including the "type" key. + value: the value dictionary. + value_type: value's type Returns: the encoded parameter value. """ - value_type = value["type"] try: if value_type == DateTime.TYPE: return _datetime_from_database(value["data"]) @@ -696,7 +667,7 @@ def _map_values_from_database(values_in_db: Iterable[Optional[Union[bool, float, return [] values = [] for value_in_db in values_in_db: - value = from_dict(value_in_db) if isinstance(value_in_db, dict) else value_in_db + value = from_dict(value_in_db, value_in_db["type"]) if isinstance(value_in_db, dict) else value_in_db if isinstance(value, int): value = float(value) elif value is not None and not isinstance(value, (float, bool, Duration, IndexedValue, str, DateTime)): @@ -761,6 +732,15 @@ def to_database(self) -> tuple[bytes, str]: """ return json.dumps(self.to_dict()).encode("UTF8"), self.TYPE + def as_arrow(self) -> Any: + """Returns an Apache Arrow compatible representation of the value.""" + raise NotImplementedError() + + @classmethod + def from_arrow(cls, arrow_value: pyarrow.RecordBatch | datetime | relativedelta) -> ParameterValue: + """Converts an Arrow value to legacy parameter value.""" + return cls(arrow_value) + class DateTime(ParameterValue): """A parameter value of type 'date_time'. A point in time.""" @@ -768,7 +748,7 @@ class DateTime(ParameterValue): VALUE_TYPE = "single value" TYPE = "date_time" - def __init__(self, value: Optional[Union[str, DateTime, datetime]] = None): + def __init__(self, value: Optional[Union[str, DateTime, datetime, pyarrow.TimestampScalar]] = None): """ Args: The `date_time` value. @@ -785,6 +765,8 @@ def __init__(self, value: Optional[Union[str, DateTime, datetime]] = None): raise ParameterValueFormatError(f'Could not parse datetime from "{value}"') from error elif isinstance(value, DateTime): value = copy(value._value) + elif isinstance(value, pyarrow.TimestampScalar): + value = value.as_py() elif not isinstance(value, datetime): raise ParameterValueFormatError(f'"{type(value).__name__}" cannot be converted to DateTime.') self._value = value @@ -819,6 +801,9 @@ def to_dict(self) -> dict: def value(self) -> datetime: return self._value + def as_arrow(self) -> datetime: + return self._value + class Duration(ParameterValue): """ @@ -828,7 +813,7 @@ class Duration(ParameterValue): VALUE_TYPE = "single value" TYPE = "duration" - def __init__(self, value: Optional[Union[str, relativedelta, Duration]] = None): + def __init__(self, value: Optional[Union[str, relativedelta, Duration, pyarrow.MonthDayNanoIntervalScalar]] = None): """ Args: value: the `duration` value. 
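For illustration, the Arrow-scalar branches added here let the legacy wrapper types consume PyArrow scalars directly; a minimal sketch (editor addition, not part of the patch):

    from datetime import datetime
    import pyarrow
    from spinedb_api.parameter_value import DateTime, Duration

    # A timestamp scalar is unwrapped with as_py() into a plain datetime.
    stamp = pyarrow.scalar(datetime(2025, 5, 21, 12, 0), type=pyarrow.timestamp("s"))
    assert DateTime(stamp).value == datetime(2025, 5, 21, 12, 0)

    # A month/day/nanosecond interval becomes a relativedelta; nanoseconds are floored to microseconds.
    interval = pyarrow.scalar(pyarrow.MonthDayNano([1, 2, 500_000_000]), type=pyarrow.month_day_nano_interval())
    duration = Duration(interval)  # roughly relativedelta(months=+1, days=+2, microseconds=+500000)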
@@ -841,6 +826,9 @@ def __init__(self, value: Optional[Union[str, relativedelta, Duration]] = None): value = copy(value._value) elif isinstance(value, relativedelta): value = value.normalized() + elif isinstance(value, pyarrow.MonthDayNanoIntervalScalar): + months, days, nanoseconds = value.as_py() + value = relativedelta(months=months, days=days, microseconds=nanoseconds // 1000) else: raise ParameterValueFormatError(f'Could not parse duration from "{value}"') self._value = value @@ -870,8 +858,11 @@ def to_dict(self) -> dict: def value(self) -> relativedelta: return self._value + def as_arrow(self) -> relativedelta: + return self._value + -ScalarValue = Union[bool, float, str, DateTime, Duration] +ScalarValue = Union[bool, float, str, datetime, relativedelta] class _Indexes(np.ndarray): @@ -1025,7 +1016,7 @@ def _merge(value, other): return self -Value = Union[ScalarValue, IndexedValue] +Value = Union[ScalarValue, pyarrow.RecordBatch] class Array(IndexedValue): @@ -1035,7 +1026,7 @@ class Array(IndexedValue): TYPE = "array" DEFAULT_INDEX_NAME = "i" - def __init__(self, values: Sequence[ScalarValue], value_type: Optional[Type] = None, index_name: str = ""): + def __init__(self, values: Sequence[ArrayValue], value_type: Optional[Type] = None, index_name: str = ""): """ Args: values: the array values. @@ -1063,9 +1054,17 @@ def __eq__(self, other): if not isinstance(other, Array): return NotImplemented try: - return np.array_equal(self._values, other._values, equal_nan=True) and self.index_name == other.index_name + return ( + (self._values or self._value_type == other._value_type) + and np.array_equal(self._values, other._values, equal_nan=True) + and self.index_name == other.index_name + ) except TypeError: - return np.array_equal(self._values, other._values) and self.index_name == other.index_name + return ( + (self._values or self._value_type == other._value_type) + and np.array_equal(self._values, other._values) + and self.index_name == other.index_name + ) def to_dict(self): try: @@ -1086,6 +1085,30 @@ def to_dict(self): value_dict["index_name"] = self.index_name return value_dict + def as_arrow(self) -> pyarrow.RecordBatch: + """Returns an Apache Arrow compatible representation of the value.""" + indexes = pyarrow.array(range(1, len(self._values) + 1), type=pyarrow.int64()) + array = pyarrow.array( + self._values if not issubclass(self._value_type, ParameterValue) else [x.as_arrow() for x in self._values] + ) + return pyarrow.record_batch({self.index_name: indexes, "value": array}) + + @classmethod + def from_arrow(cls, arrow_value: pyarrow.RecordBatch) -> Array: + value_column_index = arrow_value.schema.get_field_index("value") + value_column = arrow_value[value_column_index] + index_name = arrow_value.column_names[0] + if len(value_column) == 0: + return cls([], value_type=_arrow_type_to_python(value_column.type), index_name=index_name) + if pyarrow.types.is_interval(arrow_value.schema.types[value_column_index]): + return cls(list(Duration(d) for d in value_column), index_name=index_name) + if pyarrow.types.is_timestamp(arrow_value.schema.types[value_column_index]): + return cls(list(DateTime(t) for t in value_column), index_name=index_name) + return cls(value_column.to_pylist(), index_name=index_name) + + +ArrayValue: TypeAlias = float | bool | str | Duration | DateTime + class _TimePatternIndexes(_Indexes): """An array of *checked* time pattern indexes.""" @@ -1104,35 +1127,7 @@ def _check_index(union_str: str) -> None: if not union_str: # We accept empty strings so we can add 
empty rows in the parameter value editor UI return - union_dlm = "," - intersection_dlm = ";" - range_dlm = "-" - regexp = r"(Y|M|D|WD|h|m|s)" - for intersection_str in union_str.split(union_dlm): - for interval_str in intersection_str.split(intersection_dlm): - m = re.match(regexp, interval_str) - if m is None: - raise ParameterValueFormatError( - f"Invalid interval {interval_str}, it should start with either Y, M, D, WD, h, m, or s." - ) - key = m.group(0) - lower_upper_str = interval_str[len(key) :] - lower_upper = lower_upper_str.split(range_dlm) - if len(lower_upper) != 2: - raise ParameterValueFormatError( - f"Invalid interval bounds {lower_upper_str}, it should be two integers separated by dash (-)." - ) - lower_str, upper_str = lower_upper - try: - lower = int(lower_str) - except Exception as error: - raise ParameterValueFormatError(f"Invalid lower bound {lower_str}, must be an integer.") from error - try: - upper = int(upper_str) - except Exception as error: - raise ParameterValueFormatError(f"Invalid upper bound {upper_str}, must be an integer.") from error - if lower > upper: - raise ParameterValueFormatError(f"Lower bound {lower} can't be higher than upper bound {upper}.") + validate_time_period(union_str) def __array_finalize__(self, obj): """Checks indexes when building the array.""" @@ -1192,6 +1187,31 @@ def to_dict(self): value_dict["index_name"] = self.index_name return value_dict + def as_arrow(self) -> pyarrow.RecordBatch: + values = pyarrow.array(self._values) + indexes = pyarrow.array(self._indexes) + record_batch = pyarrow.record_batch({self.index_name: indexes, "value": values}) + return with_column_as_time_period(record_batch, 0) + + @classmethod + def from_arrow(cls, arrow_value: pyarrow.RecordBatch) -> TimePattern: + value_column_index = arrow_value.schema.get_field_index("value") + for column_i in range(arrow_value.num_columns - 1): + try: + column_metadata = load_field_metadata(arrow_value.field(column_i)) + except json.decoder.JSONDecodeError: + continue + if column_metadata is not None and column_metadata.get("format") == "time_period": + index_column_index = column_i + break + else: + raise RuntimeError("value's field metadata doesn't indicate that it contains a time period column") + return cls( + arrow_value[index_column_index].to_pylist(), + arrow_value[value_column_index], + index_name=arrow_value.column_names[index_column_index], + ) + class TimeSeries(IndexedValue): """Base for all classes representing 'time_series' parameter values.""" @@ -1270,7 +1290,9 @@ class TimeSeriesFixedResolution(TimeSeries): other than having getters for their values. 
""" - _memoized_indexes: dict[tuple[np.datetime64, tuple[relativedelta, ...], int], nptyping.NDArray[np.datetime64]] = {} + _memoized_indexes: ClassVar[ + dict[tuple[np.datetime64, tuple[relativedelta, ...], int], nptyping.NDArray[np.datetime64]] + ] = {} def __init__( self, @@ -1405,6 +1427,76 @@ def set_value(self, index: np.datetime64, value: float) -> None: if pos is not None: self._values[pos] = value + def as_arrow(self) -> pyarrow.RecordBatch: + values = pyarrow.array(self._values) + indexes = pyarrow.array(self.indexes, type=pyarrow.timestamp(NUMPY_DATETIME64_UNIT)) + record_batch = pyarrow.record_batch({self.index_name: indexes, "value": values}) + return with_column_as_time_stamps(record_batch, 0, ignore_year=self._ignore_year, repeat=self._repeat) + + @classmethod + def from_arrow(cls, arrow_value: pyarrow.RecordBatch) -> TimeSeriesFixedResolution: + time_stamp_column_name, ignore_year, repeat = _find_column_with_time_series_metadata(arrow_value) + time_stamp_column = arrow_value[time_stamp_column_name] + intervals = compute.month_day_nano_interval_between( + time_stamp_column.slice(length=len(time_stamp_column) - 1), time_stamp_column.slice(offset=1) + ) + resolution = _resolve_resolution(intervals) + return cls( + time_stamp_column[0].as_py(), + resolution, + arrow_value["value"], + ignore_year=ignore_year, + repeat=repeat, + index_name=time_stamp_column_name, + ) + + +def _resolve_resolution(intervals: pyarrow.MonthDayNanoIntervalArray) -> relativedelta | list[relativedelta]: + if intervals.value_counts() == 1: + months, days, nanoseconds = intervals[0].as_py() + return relativedelta(months=months, days=days, microseconds=nanoseconds // 1000) + interval_length = len(intervals) + pattern_length = 2 + pattern = intervals.slice(0, pattern_length) + pattern_found = False + while not pattern_found: + sub_pattern_mismatch = False + for interval_i in range(pattern_length, interval_length, pattern_length): + for pattern_i in range(pattern_length): + sub_interval_i = interval_i + pattern_i + if sub_interval_i == interval_length: + break + if intervals[interval_i + pattern_i] != pattern[pattern_i]: + sub_pattern_mismatch = True + break + if sub_pattern_mismatch: + pattern_length += 1 + pattern = intervals.slice(0, pattern_length) + break + else: + pattern_found = True + resolution = [] + for interval in pattern: + months, days, nanoseconds = interval.as_py() + resolution.append(relativedelta(months=months, days=days, microseconds=nanoseconds // 1000)) + return resolution + + +def _find_column_with_time_series_metadata(arrow_value: pyarrow.RecordBatch) -> tuple[str, bool, bool]: + for column_i in range(arrow_value.num_columns - 1): + try: + column_metadata = load_field_metadata(arrow_value.field(column_i)) + except json.decoder.JSONDecodeError: + continue + if column_metadata and "ignore_year" in column_metadata and "repeat" in column_metadata: + time_stamp_column_name = arrow_value.column_names[column_i] + ignore_year = column_metadata["ignore_year"] + repeat = column_metadata["repeat"] + break + else: + raise RuntimeError("value's field metadata doesn't indicate that it contains a time stamp column") + return time_stamp_column_name, ignore_year, repeat + class TimeSeriesVariableResolution(TimeSeries): """A parameter value of type 'time_series'. 
@@ -1468,6 +1560,23 @@ def to_dict(self): value_dict["index_name"] = self.index_name return value_dict + def as_arrow(self) -> pyarrow.RecordBatch: + values = pyarrow.array(self._values) + indexes = pyarrow.array(self._indexes, type=pyarrow.timestamp(NUMPY_DATETIME64_UNIT)) + record_batch = pyarrow.record_batch({self.index_name: indexes, "value": values}) + return with_column_as_time_stamps(record_batch, 0, ignore_year=self._ignore_year, repeat=self._repeat) + + @classmethod + def from_arrow(cls, arrow_value: pyarrow.RecordBatch) -> TimeSeriesVariableResolution: + time_stamp_column_name, ignore_year, repeat = _find_column_with_time_series_metadata(arrow_value) + return cls( + arrow_value[time_stamp_column_name].to_pylist(), + arrow_value["value"], + ignore_year=ignore_year, + repeat=repeat, + index_name=time_stamp_column_name, + ) + class Map(IndexedValue): """A parameter value of type 'map'. A mapping from key to value, where the values can be other instances @@ -1481,7 +1590,7 @@ class Map(IndexedValue): def __init__( self, indexes: Sequence[MapIndex], - values: Sequence[Value], + values: Sequence[MapValue], index_type: Optional[Type] = None, index_name: str = "", ): @@ -1538,9 +1647,116 @@ def to_dict(self): value_dict["index_name"] = self.index_name return value_dict + def as_arrow(self) -> pyarrow.RecordBatch: + return map_as_arrow(self) + + @classmethod + def from_arrow(cls, arrow_value: pyarrow.RecordBatch) -> Map: + if arrow_value.num_rows == 0: + return Map( + [], + [], + index_type=_arrow_type_to_python(arrow_value.column(0).type), + index_name=arrow_value.column_names[0], + ) + map_value = _map_from_arrow(arrow_value) + for column_i in range(arrow_value.num_columns - 1): + try: + metadata = load_field_metadata(arrow_value.field(column_i)) + except JSONDecodeError: + continue + if metadata and "ignore_year" in metadata and "repeat" in metadata: + map_value = convert_leaf_maps_to_specialized_containers(map_value, time_series_kwargs=metadata) + break + for column_i in range(arrow_value.num_columns - 1): + if pyarrow.types.is_integer(arrow_value.column(column_i).type): + map_value = convert_leaf_maps_to_specialized_containers(map_value) + break + return map_value + + +def _map_from_arrow(arrow_value: pyarrow.RecordBatch) -> Map: + values_by_index_path = [] + value_column = [_python_value_as_spine(x.as_py()) for x in arrow_value.column(arrow_value.num_columns - 1)] + for row_i in range(arrow_value.num_rows): + index_path = [] + for column_i in range(arrow_value.num_columns - 1): + y = arrow_value.column(column_i)[row_i] + if y.is_valid: + index_path.append((arrow_value.column_names[column_i], _python_value_as_spine(y.as_py()))) + values_by_index_path.append((index_path, value_column[row_i])) + root_dict = {} + for index_path, value in values_by_index_path: + root_index = index_path[0] + root_dict.setdefault(root_index, []).append(value if len(index_path) == 1 else (index_path[1:], value)) + collected_dict = _crawl_map_dict_to_collect_all_dimensions(root_dict) + return _map_dict_to_map(collected_dict) + + +def _crawl_map_dict_to_collect_all_dimensions(root_dict: dict) -> dict: + root_dict = _collect_nested_dimensions(root_dict) + for index, values in root_dict.items(): + for i, value in enumerate(values): + if isinstance(value, dict): + new_value = _crawl_map_dict_to_collect_all_dimensions(value) + values[i] = new_value + return root_dict + + +def _collect_nested_dimensions(root_dict: dict) -> dict: + new_root_dict = {} + for index, values in root_dict.items(): + for value in values: 
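+            # A plain entry is a leaf value for this index.  A tuple entry is a
+            # (remaining_index_path, value) pair pointing one dimension deeper: it is grouped
+            # into a nested dict keyed by the next (index_name, index) pair, reusing a sibling
+            # dict that already starts with the same index name or creating a new one, while
+            # deeper path tails stay as tuples for _crawl_map_dict_to_collect_all_dimensions
+            # to unfold recursively.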
+ if not isinstance(value, tuple): + new_root_dict.setdefault(index, []).append(value) + continue + index_path, value = value + nested_index = index_path[0] + existing_values = new_root_dict.setdefault(index, []) + for existing_value in existing_values: + if isinstance(existing_value, dict): + existing_index = next(iter(existing_value)) + if existing_index[0] == nested_index[0]: + nested_dict = existing_value + break + else: + nested_dict = {} + existing_values.append(nested_dict) + nested_dict.setdefault(nested_index, []).append(value if len(index_path) == 1 else (index_path[1:], value)) + return new_root_dict + + +def _map_dict_to_map(map_dict: dict) -> Map: + map_indexes = [] + map_values = [] + index_name = None + for (current_index_name, index), values in map_dict.items(): + if index_name is None: + index_name = current_index_name + if current_index_name != index_name: + raise RuntimeError("logic error: index name must not change") + for value in values: + map_indexes.append(index) + if isinstance(value, dict): + map_values.append(_map_dict_to_map(value)) + else: + map_values.append(value) + return Map(map_indexes, map_values, index_name=index_name) + + +def _python_value_as_spine(x: Any) -> Any: + match x: + case pyarrow.lib.MonthDayNano(): + return Duration(relativedelta(months=x.months, days=x.days, microseconds=x.nanoseconds // 1000)) + case datetime(): + return DateTime(x) + case _: + return x -MapIndex = Union[float, str, DateTime, Duration] -MapIndexType = Union[Type[float], Type[str], Type[DateTime], Type[Duration]] + +MapValue: TypeAlias = float | str | bool | Duration | IndexedValue +MapIndex: TypeAlias = Union[float, str, DateTime, Duration] +MapIndexType: TypeAlias = Union[Type[float], Type[str], Type[DateTime], Type[Duration]] _MAP_INDEX_TYPES = { STRING_VALUE_TYPE: str, DateTime.TYPE: DateTime, @@ -1569,7 +1785,9 @@ def map_dimensions(map_: Map) -> int: return 1 + nested -def convert_leaf_maps_to_specialized_containers(map_: Map) -> IndexedValue: +def convert_leaf_maps_to_specialized_containers( + map_: Map, time_series_kwargs: Optional[dict[str, Any]] = None +) -> IndexedValue: """ Converts leafs to specialized containers. @@ -1582,17 +1800,21 @@ def convert_leaf_maps_to_specialized_containers(map_: Map) -> IndexedValue: Args: map_: a map to process. + time_series_kwargs: additional keyword arguments passed to time series constructor Returns: a new map with leaves converted. """ - converted_container = _try_convert_to_container(map_) + converted_container = _try_convert_to_time_series(map_, time_series_kwargs) + if converted_container is not None: + return converted_container + converted_container = _try_convert_to_array(map_) if converted_container is not None: return converted_container new_values = [] for _, value in zip(map_.indexes, map_.values): if isinstance(value, Map): - converted = convert_leaf_maps_to_specialized_containers(value) + converted = convert_leaf_maps_to_specialized_containers(value, time_series_kwargs) new_values.append(converted) else: new_values.append(value) @@ -1685,12 +1907,15 @@ def convert_map_to_dict(map_: Map) -> dict: return d -def _try_convert_to_container(map_: Map) -> Optional[TimeSeriesVariableResolution]: +def _try_convert_to_time_series( + map_: Map, time_series_kwargs: Optional[dict[str, Any]] = None +) -> Optional[TimeSeriesVariableResolution]: """ - Tries to convert a map to corresponding specialized container. + Tries to convert a map to time series. 
Args: map_: a map to convert + time_series_kwargs: keyword arguments passed to time series constructor Returns: converted Map or None if the map couldn't be converted @@ -1704,7 +1929,175 @@ def _try_convert_to_container(map_: Map) -> Optional[TimeSeriesVariableResolutio return None stamps.append(index) values.append(value) - return TimeSeriesVariableResolution(stamps, values, False, False, index_name=map_.index_name) + if time_series_kwargs is None: + time_series_kwargs = {"ignore_year": False, "repeat": False} + return TimeSeriesVariableResolution(stamps, values, index_name=map_.index_name, **time_series_kwargs) + + +def _try_convert_to_array(map_: Map) -> Array | None: + """ + Tries to convert a map to array. + + Args: + map_: a map to convert + + Returns: + converted Map or None if the map couldn't be converted + """ + if not map_: + return None + values = [] + for i, (index, value) in enumerate(zip(map_.indexes, map_.values)): + if index != i: + return None + values.append(value) + return Array(values, index_name=map_.index_name) + + +def map_as_arrow(map_: Map) -> pyarrow.RecordBatch: + header, index_rows, value_column, metadata_by_column_index = _map_as_table_for_arrow(map_) + _uniquefy_header_names(header) + return _table_to_record_batch(header, index_rows, value_column, metadata_by_column_index) + + +@dataclass +class _MapHeader: + name: str + type: Type + + +def _map_as_table_for_arrow( + map_: Map, + header: Optional[list[_MapHeader]] = None, + index_rows: Optional[list[list]] = None, + base_index: Optional[list] = None, + value_column: Optional[list] = None, + metadata_by_column_index: Optional[dict[int, dict[str, Any]]] = None, +) -> tuple[list[_MapHeader], list[list], list, dict[int, dict[str, Any]]]: + if value_column is None: + value_column = [] + if header is None: + header = [] + if base_index is None: + base_index = [] + depth = len(base_index) + if depth == len(header): + header.append(_MapHeader(map_.index_name, map_.index_type)) + elif map_.index_type != header[depth].type: + raise SpineDBAPIError("different index types at the same depth are not supported") + if index_rows is None: + index_rows = [] + if metadata_by_column_index is None: + metadata_by_column_index = {} + for index, value in zip(map_.indexes, map_.values): + index_row = base_index + ([index.as_arrow()] if isinstance(index, ParameterValue) else [index]) + if isinstance(value, Map): + _map_as_table_for_arrow(value, header, index_rows, index_row, value_column, metadata_by_column_index) + elif isinstance(value, IndexedValue): + _unroll_nested_indexed_value(value, header, index_rows, index_row, value_column, metadata_by_column_index) + else: + value_column.append(value) + index_rows.append(index_row) + return header, index_rows, value_column, metadata_by_column_index + + +def _unroll_nested_indexed_value( + value: IndexedValue, + header: list[_MapHeader], + index_rows: list[list], + base_index: list, + value_column: list, + metadata_by_column_index: dict[int, dict[str, Any] | TimeSeriesMetadata], +) -> None: + depth = len(base_index) + if depth == len(header): + match value: + case Array(): + index_type = int + case TimeSeries(): + index_type = np.datetime64 + case _: + raise SpineDBAPIError("unsupported indexed value type") + header.append(_MapHeader(value.index_name, index_type)) + index_i = depth + else: + for i, column_header in enumerate(header): + if value.index_name == column_header.name: + index_i = i + break + else: + index_i = depth + if index_i > len(base_index): + base_index = base_index + 
(index_i - len(base_index)) * [None] + depth = index_i + for x, y in zip(value.indexes, value.values): + value_column.append(y) + index = x.as_arrow() if isinstance(x, ParameterValue) else x + index_row = [base_index[i] if i != index_i else index for i in range(depth + 1)] + index_rows.append(index_row) + if isinstance(value, TimeSeries) and index_i not in metadata_by_column_index: + metadata_by_column_index[index_i] = time_series_metadata(value.ignore_year, value.repeat) + + +def _table_to_record_batch( + header: list[_MapHeader], index_rows: list[list], value_column: list, metadata: dict[int, dict[str, Any]] +) -> pyarrow.RecordBatch: + max_depth = len(header) + index_columns = [[] for _ in range(max_depth)] + for index_row in index_rows: + index_row = index_row + (max_depth - len(index_row)) * [None] + for column, index in zip(index_columns, index_row): + column.append(index) + arrays = {} + for i, (h, column) in enumerate(zip(header, index_columns)): + arrow_type = _python_type_to_arrow(h.type) + if i == len(index_columns) - 1: + arrays[h.name] = pyarrow.array(column, type=arrow_type) + else: + run_values = [column[0]] + run_ends = [1] + for x in column[1:]: + if x == run_values[-1]: + run_ends[-1] = run_ends[-1] + 1 + else: + run_values.append(x) + run_ends.append(run_ends[-1] + 1) + array_type = pyarrow.run_end_encoded(pyarrow.int64(), arrow_type) + arrays[h.name] = pyarrow.RunEndEncodedArray.from_arrays(run_ends, run_values, type=array_type) + arrays["value"] = _build_value_array(value_column) + record_batch = pyarrow.record_batch(arrays) + if metadata: + for column_i, field_metadata in metadata.items(): + record_batch = with_field_metadata(field_metadata, record_batch, column_i) + return record_batch + + +def _uniquefy_header_names(header: list[_MapHeader]): + for i, h in enumerate(header): + if h.name == Map.DEFAULT_INDEX_NAME: + h.name = f"col_{i + 1}" + names = Counter() + for h in header: + names[h.name] += 1 + for name, count in names.items(): + if count != 1: + clashing_headers = [h for h in header if h.name == name] + for i, h in clashing_headers: + h.name = h.name + f" ({i + 1})" + + +def _build_value_array(value_column: list) -> pyarrow.Array: + arrow_values = [] + if not value_column: + return pyarrow.array([], type=pyarrow.float64()) + first_type = type(value_column[0]) + if all(isinstance(y, first_type) for y in value_column[1:]): + return pyarrow.array(value_column) + for y in value_column: + if isinstance(y, ParameterValue): + y = y.as_arrow() + arrow_values.append(y) + return to_union_array(arrow_values) # Value types that are supported by spinedb_api @@ -1721,7 +2114,7 @@ def _try_convert_to_container(map_: Map) -> Optional[TimeSeriesVariableResolutio } RANK_1_TYPES: set[str] = {Array.TYPE, TimePattern.TYPE, TimeSeries.TYPE} -NON_ZERO_RANK_TYPES: set[str] = RANK_1_TYPES | {Map.TYPE} +NON_ZERO_RANK_TYPES: set[str] = RANK_1_TYPES | {Map.TYPE, TABLE_TYPE} def type_and_rank_to_fancy_type(value_type: str, rank: int) -> str: @@ -1738,46 +2131,6 @@ def fancy_type_to_type_and_rank(fancy_type: str) -> tuple[str, int]: return fancy_type, 0 -def join_value_and_type(db_value: bytes, db_type: Optional[str]) -> str: - """Joins database value and type into a string. - The resulting string is a JSON string. - In case of complex types (duration, date_time, time_series, time_pattern, array, map), - the type is just added as top-level key. 
- - :meta private: - - Args: - db_value: database value - db_type: value type - - Returns: - parameter value as JSON with an additional ``type`` field. - """ - try: - parsed = load_db_value(db_value, db_type) - except ParameterValueFormatError: - parsed = None - return json.dumps(parsed) - - -def split_value_and_type(value_and_type: str) -> tuple[bytes, str]: - """Splits the given string into value and type. - - :meta private: - - Args: - value_and_type: a string joining value and type, as obtained by calling :func:`join_value_and_type`. - - Returns: - database value and type. - """ - try: - parsed = json.loads(value_and_type) - except (TypeError, json.JSONDecodeError): - parsed = value_and_type - return dump_db_value(parsed) - - def deep_copy_value(value: Optional[Value]) -> Optional[Value]: """Copies a value. The operation is deep meaning that nested Maps will be copied as well. @@ -1829,7 +2182,7 @@ def deep_copy_map(value: Map) -> Map: return Map(xs, ys, index_type=value.index_type, index_name=value.index_name) -def type_for_value(value: Value) -> tuple[str, int]: +def type_and_rank_for_value(value: Value) -> tuple[str, int]: """Declares value's database type and rank. Args: @@ -1838,6 +2191,8 @@ def type_for_value(value: Value) -> tuple[str, int]: Returns: type and rank """ + if isinstance(value, pyarrow.RecordBatch): + return TABLE_TYPE, value.num_columns - 1 if isinstance(value, Map): return Map.TYPE, map_dimensions(value) if isinstance(value, ParameterValue): @@ -1847,7 +2202,7 @@ def type_for_value(value: Value) -> tuple[str, int]: return type_for_scalar(value), 0 -def type_for_scalar(parsed_value: JSONValue) -> Optional[str]: +def type_for_scalar(parsed_value: bool | float | int | str | datetime | relativedelta | None) -> Optional[str]: """Declares scalar value's database type. 
Args: @@ -1858,15 +2213,222 @@ def type_for_scalar(parsed_value: JSONValue) -> Optional[str]: """ if parsed_value is None: return None - if isinstance(parsed_value, dict): - return parsed_value["type"] if isinstance(parsed_value, bool): return BOOLEAN_VALUE_TYPE - if isinstance(parsed_value, SupportsFloat): + if isinstance(parsed_value, float): return FLOAT_VALUE_TYPE + if isinstance(parsed_value, int): + return INT_VALUE_TYPE if isinstance(parsed_value, str): return STRING_VALUE_TYPE + if isinstance(parsed_value, datetime): + return DateTime.TYPE + if isinstance(parsed_value, relativedelta): + return Duration.TYPE raise ParameterValueFormatError(f"Values of type {type(parsed_value).__name__} not supported.") UNPARSED_NULL_VALUE: bytes = to_database(None)[0] + + +def _arrow_type_to_python(arrow_type: pyarrow.DataType) -> Type: + if ( + pyarrow.types.is_floating(arrow_type) + or pyarrow.types.is_integer(arrow_type) + or pyarrow.types.is_null(arrow_type) + ): + return float + if pyarrow.types.is_string(arrow_type): + return str + if pyarrow.types.is_boolean(arrow_type): + return bool + if pyarrow.types.is_interval(arrow_type): + return relativedelta + raise RuntimeError(f"unknown Arrow type {arrow_type}") + + +def _python_type_to_arrow(python_type: Type) -> pyarrow.DataType: + if python_type is float: + return pyarrow.float64() + if python_type is int: + return pyarrow.int64() + if python_type is str: + return pyarrow.string() + if python_type is bool: + return pyarrow.bool_() + if python_type is DateTime or python_type is np.datetime64: + return pyarrow.timestamp(NUMPY_DATETIME64_UNIT) + if python_type is Duration: + return pyarrow.month_day_nano_interval() + raise RuntimeError(f"unknown data type {python_type}") + + +def to_list(loaded_value: pyarrow.RecordBatch) -> list[dict]: + arrays = [] + for i_column, (name, column) in enumerate(zip(loaded_value.column_names, loaded_value.columns)): + is_value_column = i_column == loaded_value.num_columns - 1 + base_data = { + "name": name, + } + try: + metadata = load_field_metadata(loaded_value.field(i_column)) + except JSONDecodeError: + pass + else: + if metadata is not None: + base_data["metadata"] = metadata + match column: + case pyarrow.RunEndEncodedArray(): + arrays.append( + { + **base_data, + "type": "run_end_array" if is_value_column else "run_end_index", + "run_end": column.run_ends.to_pylist(), + "values": _array_values_to_list(column.values, column.type.value_type), + "value_type": _arrow_data_type_to_value_type(column.type.value_type), + } + ) + case pyarrow.DictionaryArray(): + arrays.append( + { + **base_data, + "type": "dict_encoded_array" if is_value_column else "dict_encoded_index", + "indices": column.indices.to_pylist(), + "values": _array_values_to_list(column.dictionary, column.type.value_type), + "value_type": _arrow_data_type_to_value_type(column.type.value_type), + } + ) + case pyarrow.UnionArray(): + if not is_value_column: + raise SpineDBAPIError("union array cannot be index") + value_list, type_list = _union_array_values_to_list(column) + arrays.append( + { + **base_data, + "type": "any_array", + "values": value_list, + "value_types": type_list, + } + ) + case pyarrow.TimestampArray(): + arrays.append( + { + **base_data, + "type": "array" if is_value_column else "array_index", + "values": [t.isoformat() if t is not None else t for t in column.to_pylist()], + "value_type": "date_time", + } + ) + case pyarrow.MonthDayNanoIntervalArray(): + arrays.append( + { + **base_data, + "type": "array" if is_value_column else 
"array_index", + "values": [ + _month_day_nano_interval_to_iso_duration(dt) if dt.is_valid else None for dt in column + ], + "value_type": "duration", + } + ) + case _: + arrays.append( + { + **base_data, + "type": "array" if is_value_column else "array_index", + "values": column.to_pylist(), + "value_type": _array_value_type(column, is_value_column), + } + ) + return arrays + + +def _array_values_to_list(values: pyarrow.Array, value_type: pyarrow.DataType) -> list: + if pyarrow.types.is_timestamp(value_type): + return [t.isoformat() if t is not None else t for t in values.to_pylist()] + if pyarrow.types.is_interval(value_type): + return [_month_day_nano_interval_to_iso_duration(dt) if dt.is_valid else None for dt in values] + return values.to_pylist() + + +def _arrow_data_type_to_value_type(data_type: pyarrow.DataType) -> TypeNames | NullTypeName: + if pyarrow.types.is_floating(data_type): + return "float" + if pyarrow.types.is_integer(data_type): + return "int" + if pyarrow.types.is_string(data_type): + return "str" + if pyarrow.types.is_timestamp(data_type): + return "date_time" + if pyarrow.types.is_interval(data_type): + return "duration" + if pyarrow.types.is_boolean(data_type): + return "bool" + if pyarrow.types.is_null(data_type): + return "null" + raise SpineDBAPIError(f"unknown Arrow data type {data_type}") + + +def _union_array_values_to_list(column: pyarrow.UnionArray) -> tuple[list, list[TypeNames | NullTypeName]]: + values = [] + types = [] + for i, x in enumerate(column): + types.append(_arrow_data_type_to_value_type(x.type[x.type_code].type)) + match x.value: + case pyarrow.MonthDayNanoIntervalScalar(): + values.append(_month_day_nano_interval_to_iso_duration(x)) + case _: + values.append(x.as_py()) + return values, types + + +_ZERO_DURATION = "P0D" + + +def _month_day_nano_interval_to_iso_duration(dt: pyarrow.MonthDayNanoIntervalScalar) -> str: + duration = "P" + months, days, nanoseconds = dt.as_py() + years = months // 12 + if years: + duration = duration + f"{years}Y" + months -= years * 12 + if months: + duration = duration + f"{months}M" + if days: + duration = duration + f"{days}D" + if not nanoseconds: + return duration if duration != "P" else _ZERO_DURATION + duration = duration + "T" + seconds = nanoseconds // 1000000000 + hours = seconds // 3600 + if hours: + duration = duration + f"{hours}H" + seconds -= hours * 3600 + minutes = seconds // 60 + if minutes: + duration = duration + f"{minutes}M" + seconds -= minutes * 60 + if seconds: + duration += f"{seconds}S" + return duration if duration != "PT" else _ZERO_DURATION + + +def _array_value_type(column: pyarrow.Array, is_value_column: bool) -> str: + match column: + case pyarrow.FloatingPointArray(): + return "float" + case pyarrow.IntegerArray(): + return "int" + case pyarrow.StringArray() | pyarrow.LargeStringArray(): + return "str" + case pyarrow.BooleanArray(): + if not is_value_column: + raise SpineDBAPIError("boolean array cannot be index") + return "bool" + case pyarrow.MonthDayNanoIntervalArray(): + if is_value_column: + raise SpineDBAPIError("duration array cannot be value") + return "duration" + case pyarrow.NullArray(): + return "float" + case _: + raise SpineDBAPIError(f"unsupported column type {type(column).__name__}") diff --git a/spinedb_api/spine_db_client.py b/spinedb_api/spine_db_client.py index 3d9e4107..86c9aac6 100644 --- a/spinedb_api/spine_db_client.py +++ b/spinedb_api/spine_db_client.py @@ -19,7 +19,7 @@ from sqlalchemy.engine.url import URL from .server_client_helpers import 
ReceiveAllMixing, decode, encode -client_version = 8 +client_version = 9 class SpineDBClient(ReceiveAllMixing): diff --git a/spinedb_api/spine_db_server.py b/spinedb_api/spine_db_server.py index 6e3ddd0f..fa14f16a 100644 --- a/spinedb_api/spine_db_server.py +++ b/spinedb_api/spine_db_server.py @@ -115,11 +115,11 @@ def _import_entity_class(server_url, class_name): from .filters.scenario_filter import scenario_filter_config from .filters.tools import apply_filter_stack, clear_filter_configs from .import_functions import import_data -from .parameter_value import dump_db_value +from .incomplete_values import dump_db_value from .server_client_helpers import ReceiveAllMixing, decode, encode from .spine_db_client import SpineDBClient -_current_server_version = 8 +_current_server_version = 9 class OrderingDict(TypedDict): @@ -211,7 +211,7 @@ def _do_work(self): for server_address in list(self._servers): self._shutdown_server(server_address) - def _start_server(self, db_url, upgrade, memory, ordering): + def _start_server(self, db_url, ordering, upgrade, memory): host = "127.0.0.1" commit_lock = self._get_commit_lock(db_url) while True: @@ -293,8 +293,10 @@ def _run_request_on_manager(request, server_manager_queue, *args, **kwargs): return output_queue.get() -def start_spine_db_server(server_manager_queue, db_url, upgrade: bool = False, memory: bool = False, ordering=None): - return _run_request_on_manager("start_server", server_manager_queue, db_url, upgrade, memory, ordering) +def start_spine_db_server( + server_manager_queue, db_url, ordering: OrderingDict, upgrade: bool = False, memory: bool = False +): + return _run_request_on_manager("start_server", server_manager_queue, db_url, ordering, upgrade, memory) def shutdown_spine_db_server(server_manager_queue, server_address): @@ -673,7 +675,7 @@ def closing_spine_db_server( "precursors": set(), "part_count": 0, } - server_address = start_spine_db_server(server_manager_queue, db_url, memory=memory, ordering=ordering) + server_address = start_spine_db_server(server_manager_queue, db_url, ordering, upgrade=upgrade, memory=memory) host, port = server_address try: yield urlunsplit(("http", f"{host}:{port}", "", "", "")) diff --git a/spinedb_api/spine_io/exporters/excel.py b/spinedb_api/spine_io/exporters/excel.py index fa888fd7..0dbc7e0a 100644 --- a/spinedb_api/spine_io/exporters/excel.py +++ b/spinedb_api/spine_io/exporters/excel.py @@ -34,7 +34,7 @@ ScenarioMapping, ) from spinedb_api.export_mapping.group_functions import GroupOneOrNone -from ...parameter_value import from_database_to_dimension_count +from ...incomplete_values import from_database_to_dimension_count from .excel_writer import ExcelWriter from .writer import write diff --git a/spinedb_api/value_support.py b/spinedb_api/value_support.py new file mode 100644 index 00000000..a3caff15 --- /dev/null +++ b/spinedb_api/value_support.py @@ -0,0 +1,103 @@ +###################################################################################################################### +# Copyright (C) 2017-2022 Spine project consortium +# Copyright Spine Database API contributors +# This file is part of Spine Database API. +# Spine Database API is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser +# General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your +# option) any later version. 
This program is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; +# without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General +# Public License for more details. You should have received a copy of the GNU Lesser General Public License along with +# this program. If not, see . +###################################################################################################################### +from collections import defaultdict +from collections.abc import Sequence +import json +import re +from typing import Any, Literal, Optional, TypeAlias +import pyarrow +from .exception import ParameterValueFormatError + +JSONValue = bool | float | str | list | dict + +_INTERVAL_REGEXP = re.compile(r"(Y|M|D|WD|h|m|s)") + + +def load_db_value(db_value: bytes) -> Optional[JSONValue]: + """ + Parses a binary blob into a JSON object. + + If the result is a dict, adds the "type" property to it. + + :meta private: + + Args: + db_value: The binary blob. + + Returns: + The parsed parameter value. + """ + if db_value is None: + return None + try: + parsed = json.loads(db_value) + except json.JSONDecodeError as err: + raise ParameterValueFormatError(f"Could not decode the value: {err}") from err + return parsed + + +def validate_time_period(time_period: str) -> None: + """ + Checks if a time period has the right format. + + Args: + time_period: The time period to check. Generally assumed to be a union of interval intersections. + + Raises: + ParameterValueFormatError: If the given string doesn't comply with time period spec. + """ + union_dlm = "," + intersection_dlm = ";" + range_dlm = "-" + for intersection_str in time_period.split(union_dlm): + for interval_str in intersection_str.split(intersection_dlm): + m = _INTERVAL_REGEXP.match(interval_str) + if m is None: + raise ParameterValueFormatError( + f"Invalid interval {interval_str}, it should start with either Y, M, D, WD, h, m, or s." + ) + key = m.group(0) + lower_upper_str = interval_str[len(key) :] + lower_upper = lower_upper_str.split(range_dlm) + if len(lower_upper) != 2: + raise ParameterValueFormatError( + f"Invalid interval bounds {lower_upper_str}, it should be two integers separated by dash (-)." 
+ ) + lower_str, upper_str = lower_upper + try: + lower = int(lower_str) + except Exception as error: + raise ParameterValueFormatError(f"Invalid lower bound {lower_str}, must be an integer.") from error + try: + upper = int(upper_str) + except Exception as error: + raise ParameterValueFormatError(f"Invalid upper bound {upper_str}, must be an integer.") from error + if lower > upper: + raise ParameterValueFormatError(f"Lower bound {lower} can't be higher than upper bound {upper}.") + + +def to_union_array(arr: Sequence[Any | None]): + type_map = defaultdict(list) + offsets = [] + for item in arr: + item_t = type(item) + offsets.append(len(type_map[item_t])) + type_map[item_t].append(item) + + _types = list(type_map) + types = pyarrow.array((_types.index(type(i)) for i in arr), type=pyarrow.int8()) + uarr = pyarrow.UnionArray.from_dense( + types, + pyarrow.array(offsets, type=pyarrow.int32()), + list(map(pyarrow.array, type_map.values())), + ) + return uarr diff --git a/tests/filters/test_value_transformer.py b/tests/filters/test_value_transformer.py index 03f502fb..87ade2cf 100644 --- a/tests/filters/test_value_transformer.py +++ b/tests/filters/test_value_transformer.py @@ -129,7 +129,7 @@ def test_negate_manipulator_with_nested_map(self): url = append_filter_config(str(db_url), config) with DatabaseMapping(url) as db_map: values = [from_database(row.value, row.type) for row in db_map.query(db_map.parameter_value_sq)] - expected = Map(["A"], [Map(["1"], [-2.3])]) + expected = Map(["A"], [Map(["1"], [-2.3], index_name="col_2")], index_name="col_1") self.assertEqual(values, [expected]) db_map.engine.dispose() @@ -211,6 +211,6 @@ def test_index_generator_on_time_series(self): url = append_filter_config(str(db_url), config) with DatabaseMapping(url) as db_map: values = [from_database(row.value, row.type) for row in db_map.query(db_map.parameter_value_sq)] - expected = Map([1.0, 2.0], [-5.0, -2.3]) + expected = Map([1.0, 2.0], [-5.0, -2.3], index_name="col_1") self.assertEqual(values, [expected]) db_map.engine.dispose() diff --git a/tests/import_mapping/test_generator.py b/tests/import_mapping/test_generator.py index a906e04c..3068b32e 100644 --- a/tests/import_mapping/test_generator.py +++ b/tests/import_mapping/test_generator.py @@ -11,10 +11,25 @@ ###################################################################################################################### """ Contains unit tests for the generator module. 
""" +import datetime +import json import unittest -from spinedb_api import Array, DateTime, Duration, Map +from dateutil.relativedelta import relativedelta +import pyarrow +from spinedb_api import ( + Array, + DateTime, + Duration, + Map, + TimePattern, + TimeSeries, + TimeSeriesFixedResolution, + TimeSeriesVariableResolution, + to_database, +) from spinedb_api.import_mapping.generator import get_mapped_data from spinedb_api.import_mapping.type_conversion import value_to_convert_spec +from spinedb_api.incomplete_values import join_value_and_type class TestGetMappedData(unittest.TestCase): @@ -1078,6 +1093,165 @@ def test_missing_entity_alternative_does_not_prevent_importing_of_values(self): }, ) + def test_json_converter_with_legacy_values(self): + header = ["Entity", "Value"] + values = [ + DateTime("2025-09-05T09:11"), + Duration("4 hours"), + Array([2.3]), + TimePattern(["WD1-7"], [2.3]), + TimeSeriesFixedResolution("2025-09-05T09:16", "1h", [2.3], ignore_year=True, repeat=False), + TimeSeriesVariableResolution(["2025-09-05T09:16"], [2.3], ignore_year=False, repeat=True), + Map(["a"], ["b"]), + ] + data = [] + expected_imported_parameter_values = [] + for i, value in enumerate(values): + value_dict = value.to_dict() + value_dict["type"] = value.TYPE + entity = value.TYPE + str(i) + data.append([entity, json.dumps(value_dict)]) + expected_value = value.as_arrow() + if isinstance(value, TimeSeries): + new_fields = [] + for field_i in range(len(expected_value.schema)): + field = expected_value.schema.field(field_i) + if pyarrow.types.is_timestamp(field.type): + field = field.with_type(pyarrow.timestamp("us")) + new_fields.append((field.name, field.type)) + expected_value = expected_value.cast(pyarrow.schema(new_fields)) + expected_imported_parameter_values.append(["Object", entity, "y", expected_value, "Base"]) + data_source = iter(data) + mappings = [ + [ + {"map_type": "EntityClass", "position": "hidden", "value": "Object"}, + {"map_type": "Entity", "position": 0}, + {"map_type": "ParameterDefinition", "position": "hidden", "value": "y"}, + {"map_type": "Alternative", "position": "hidden", "value": "Base"}, + {"map_type": "ParameterValue", "position": 1}, + ] + ] + convert_function_specs = {0: "string", 1: "json"} + convert_functions = {column: value_to_convert_spec(spec) for column, spec in convert_function_specs.items()} + mapped_data, errors = get_mapped_data(data_source, mappings, header, column_convert_fns=convert_functions) + self.assertEqual(errors, []) + self.assertEqual( + mapped_data, + { + "entity_classes": [ + ("Object",), + ], + "entities": [ + ("Object", "date_time0"), + ("Object", "duration1"), + ("Object", "array2"), + ("Object", "time_pattern3"), + ("Object", "time_series4"), + ("Object", "time_series5"), + ("Object", "map6"), + ], + "parameter_definitions": [("Object", "y")], + "alternatives": {"Base"}, + "parameter_values": expected_imported_parameter_values, + }, + ) + + def test_json_converter(self): + header = ["Entity", "Value"] + values = [ + 2.3, + 23, + "a string", + True, + datetime.datetime(year=2025, month=9, day=5, hour=10, minute=4), + relativedelta(hours=7), + pyarrow.record_batch({"col_1": pyarrow.array(["a"]), "value": pyarrow.array("b")}), + ] + data = [] + expected_imported_parameter_values = [] + for value in values: + blob, value_type = to_database(value) + value_and_type = join_value_and_type(blob, value_type) + entity = value_type + data.append([entity, value_and_type]) + expected_imported_parameter_values.append(["Object", entity, "y", value, 
"Base"]) + data_source = iter(data) + mappings = [ + [ + {"map_type": "EntityClass", "position": "hidden", "value": "Object"}, + {"map_type": "Entity", "position": 0}, + {"map_type": "ParameterDefinition", "position": "hidden", "value": "y"}, + {"map_type": "Alternative", "position": "hidden", "value": "Base"}, + {"map_type": "ParameterValue", "position": 1}, + ] + ] + convert_function_specs = {0: "string", 1: "json"} + convert_functions = {column: value_to_convert_spec(spec) for column, spec in convert_function_specs.items()} + mapped_data, errors = get_mapped_data(data_source, mappings, header, column_convert_fns=convert_functions) + self.assertEqual(errors, []) + self.assertEqual( + mapped_data, + { + "entity_classes": [ + ("Object",), + ], + "entities": [ + ("Object", "float"), + ("Object", "int"), + ("Object", "str"), + ("Object", "bool"), + ("Object", "date_time"), + ("Object", "duration"), + ("Object", "table"), + ], + "parameter_definitions": [("Object", "y")], + "alternatives": {"Base"}, + "parameter_values": expected_imported_parameter_values, + }, + ) + + def test_json_converter_with_unrecognized_json_imports_string_as_is(self): + header = ["Entity", "Value"] + values = [ + 2.3, + 23, + "a string", + True, + datetime.datetime(year=2025, month=9, day=5, hour=10, minute=4), + relativedelta(hours=7), + pyarrow.record_batch({"col_1": pyarrow.array(["a"]), "value": pyarrow.array("b")}), + ] + data = [["non-compatible json", json.dumps({"my_data": [11]})], ["total gibberish", "abc"]] + data_source = iter(data) + mappings = [ + [ + {"map_type": "EntityClass", "position": "hidden", "value": "Object"}, + {"map_type": "Entity", "position": 0}, + {"map_type": "ParameterDefinition", "position": "hidden", "value": "y"}, + {"map_type": "Alternative", "position": "hidden", "value": "Base"}, + {"map_type": "ParameterValue", "position": 1}, + ] + ] + convert_function_specs = {0: "string", 1: "json"} + convert_functions = {column: value_to_convert_spec(spec) for column, spec in convert_function_specs.items()} + mapped_data, errors = get_mapped_data(data_source, mappings, header, column_convert_fns=convert_functions) + self.assertEqual(errors, []) + self.assertEqual( + mapped_data, + { + "entity_classes": [ + ("Object",), + ], + "entities": [("Object", "non-compatible json"), ("Object", "total gibberish")], + "parameter_definitions": [("Object", "y")], + "alternatives": {"Base"}, + "parameter_values": [ + ["Object", "non-compatible json", "y", json.dumps({"my_data": [11]}), "Base"], + ["Object", "total gibberish", "y", "abc", "Base"], + ], + }, + ) + if __name__ == "__main__": unittest.main() diff --git a/tests/import_mapping/test_import_mapping.py b/tests/import_mapping/test_import_mapping.py index 78073d43..0eb34b4f 100644 --- a/tests/import_mapping/test_import_mapping.py +++ b/tests/import_mapping/test_import_mapping.py @@ -11,6 +11,7 @@ ###################################################################################################################### """ Unit tests for import Mappings. 
""" +import json import unittest from unittest.mock import Mock from spinedb_api.exception import InvalidMapping @@ -39,11 +40,17 @@ parameter_mapping_from_dict, parameter_value_mapping_from_dict, ) -from spinedb_api.import_mapping.type_conversion import BooleanConvertSpec, FloatConvertSpec, StringConvertSpec +from spinedb_api.import_mapping.type_conversion import ( + BooleanConvertSpec, + FloatConvertSpec, + JSONConvertSpec, + StringConvertSpec, +) +from spinedb_api.incomplete_values import join_value_and_type from spinedb_api.mapping import Position from spinedb_api.mapping import to_dict as mapping_to_dict from spinedb_api.mapping import unflatten -from spinedb_api.parameter_value import Array, DateTime, Map, TimeSeriesVariableResolution +from spinedb_api.parameter_value import Array, DateTime, Map, TimeSeriesVariableResolution, to_database class TestConvertFunctions(unittest.TestCase): @@ -81,7 +88,7 @@ def test_convert_functions_str(self): expected = { "entity_classes": [("a",)], "entities": [("a", "obj")], - "parameter_definitions": [("a", "param", "1111.2222")], + "parameter_definitions": [("a", "param", '"1111.2222"')], } self.assertEqual(mapped_data, expected) @@ -104,6 +111,25 @@ def test_convert_functions_bool(self): } self.assertEqual(mapped_data, expected) + def test_json(self): + data = [["a", join_value_and_type(*to_database(2.3))]] + column_convert_fns = {0: str, 1: JSONConvertSpec()} + mapping = import_mapping_from_dict({"map_type": "ObjectClass"}) + mapping.position = 0 + mapping.child.value = "obj" + mapping.flatten()[-1].child = param_def_mapping = parameter_mapping_from_dict( + {"map_type": "ParameterDefinition"} + ) + param_def_mapping.value = "param" + param_def_mapping.flatten()[-1].position = 1 + mapped_data, _ = get_mapped_data(data, [mapping], column_convert_fns=column_convert_fns) + expected = { + "entity_classes": [("a",)], + "entities": [("a", "obj")], + "parameter_definitions": [("a", "param", 2.3)], + } + self.assertEqual(mapped_data, expected) + def test_convert_functions_with_error(self): data = [["a", "not a float"]] column_convert_fns = {0: str, 1: FloatConvertSpec()} diff --git a/tests/spine_io/exporters/test_excel_writer.py b/tests/spine_io/exporters/test_excel_writer.py index a4d56cfb..49341c26 100644 --- a/tests/spine_io/exporters/test_excel_writer.py +++ b/tests/spine_io/exporters/test_excel_writer.py @@ -145,8 +145,8 @@ def test_value_indexes_remain_unsorted(self): try: self.assertEqual(workbook.sheetnames, ["Sheet1"]) expected = [ - ["z", "o1", "Base", "1d_map", "x", "T02", 1.1], - ["z", "o1", "Base", "1d_map", "x", "T01", 1.2], + ["z", "o1", "Base", "1d_map", "col_1", "T02", 1.1], + ["z", "o1", "Base", "1d_map", "col_1", "T01", 1.2], ] self.check_sheet(workbook, "Sheet1", expected) finally: diff --git a/tests/test_DatabaseMapping.py b/tests/test_DatabaseMapping.py index 5a7e032c..3ec052dc 100644 --- a/tests/test_DatabaseMapping.py +++ b/tests/test_DatabaseMapping.py @@ -3337,6 +3337,30 @@ def test_add_relationship_after_purge(self): relationship = db_map.add_entity(entity_byname=("thing", "thing"), entity_class_name="Object__Object") self.assertEqual(relationship["name"], "thing__thing") + def test_update_parameter_value_from_float_to_duration(self): + with DatabaseMapping("sqlite://", create=True) as db_map: + db_map.add_entity_class(name="Object") + db_map.add_parameter_definition(entity_class_name="Object", name="Y") + db_map.add_entity(entity_class_name="Object", name="widget") + value_item = db_map.add_parameter_value( + 
entity_class_name="Object", + entity_byname=("widget",), + parameter_definition_name="Y", + alternative_name="Base", + parsed_value=2.3, + ) + value, value_type = to_database(Duration("3 hours")) + value_item.update(value=value, type=value_type) + self.assertEqual(value_item["parsed_value"], Duration("3h")) + self.assertEqual(value_item["arrow_value"], relativedelta(hours=3)) + + def test_update_list_value_by_parsed_value(self): + with DatabaseMapping("sqlite://", create=True) as db_map: + db_map.add_parameter_value_list(name="my list") + list_value = db_map.add_list_value(parameter_value_list_name="my list", parsed_value="original", index=0) + list_value.update(parsed_value="new") + self.assertEqual(list_value["parsed_value"], "new") + def test_entity_class_with_id_that_replaces_a_removed_id_is_found_by_fetch_all(self): with TemporaryDirectory() as temp_dir: url = "sqlite:///" + os.path.join(temp_dir, "db.sqlite") diff --git a/tests/test_alembic_migration.py b/tests/test_alembic_migration.py index 12d16224..d6dc61c0 100644 --- a/tests/test_alembic_migration.py +++ b/tests/test_alembic_migration.py @@ -20,7 +20,6 @@ Map, SpineDBAPIError, TimePattern, - TimeSeriesFixedResolution, TimeSeriesVariableResolution, ) @@ -133,8 +132,15 @@ def _assert_parameter_definitions(db_map): entity_class_name="Widget", name="time_series_fixed_resolution" ) assert time_series_fixed_resolution_definition["description"] == "Parameter with time series values." - assert time_series_fixed_resolution_definition["parsed_value"] == TimeSeriesFixedResolution( - "2020-04-22 00:00:00", "3h", [1.1, 2.2, 3.3], ignore_year=True, repeat=False + assert time_series_fixed_resolution_definition["parsed_value"] == TimeSeriesVariableResolution( + [ + "2020-04-22 00:00:00", + "2020-04-22 03:00:00", + "2020-04-22 06:00:00", + ], + [1.1, 2.2, 3.3], + ignore_year=True, + repeat=False, ) assert time_series_fixed_resolution_definition["parameter_value_list_name"] is None time_series_variable_resolution_definition = db_map.parameter_definition( @@ -151,7 +157,9 @@ def _assert_parameter_definitions(db_map): map_definition = db_map.parameter_definition(entity_class_name="Widget", name="map") assert map_definition["description"] == "Parameter with map values." 
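+    # After the re-encoding migration, maps are read back with explicit index names ("col_1", "col_2", ...),
+    # which is why the expected values below now specify index_name.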
assert map_definition["parsed_value"] == Map( - ["A", "B"], [Map(["T00", "T01"], [1.1, 2.2]), Map(["T00", "T01"], [3.3, 4.4])] + ["A", "B"], + [Map(["T00", "T01"], [1.1, 2.2], index_name="col_2"), Map(["T00", "T01"], [3.3, 4.4], index_name="col_2")], + index_name="col_1", ) assert map_definition["parameter_value_list_name"] is None @@ -222,8 +230,14 @@ def _assert_parameter_values(db_map): entity_byname=("clock",), alternative_name="Base", ) - assert time_series_fixed_resolution_value["parsed_value"] == TimeSeriesFixedResolution( - "2025-09-23 00:00:00", "6h", [-1.1, -2.2], ignore_year=False, repeat=False + assert time_series_fixed_resolution_value["parsed_value"] == TimeSeriesVariableResolution( + [ + "2025-09-23T00:00:00", + "2025-09-23T06:00:00", + ], + [-1.1, -2.2], + ignore_year=False, + repeat=False, ) time_series_variable_resolution_value = db_map.parameter_value( entity_class_name="Widget", @@ -244,7 +258,9 @@ def _assert_parameter_values(db_map): alternative_name="Base", ) assert map_value["parsed_value"] == Map( - ["A", "A", "B", "B"], [-1.1, Map(["a"], [-2.2]), -3.3, Map(["b"], [-4.4])] + ["A", "A", "B", "B"], + [-1.1, Map(["a"], [-2.2], index_name="col_2"), -3.3, Map(["b"], [-4.4], index_name="col_2")], + index_name="col_1", ) diff --git a/tests/test_arrow_value.py b/tests/test_arrow_value.py index ee758bfa..ceaa80db 100644 --- a/tests/test_arrow_value.py +++ b/tests/test_arrow_value.py @@ -11,261 +11,96 @@ ###################################################################################################################### """ Unit tests for the `arrow_value` module. """ import datetime -import unittest +from dateutil.relativedelta import relativedelta import pyarrow -from spinedb_api import arrow_value, parameter_value - - -class DatabaseUsingTest(unittest.TestCase): - def _assert_success(self, result): - item, error = result - self.assertIsNone(error) - return item - - -class TestFromDatabaseForArrays(unittest.TestCase): - def test_empty_array(self): - value, value_type = parameter_value.to_database(parameter_value.Array([])) - record_batch = arrow_value.from_database(value, value_type) - self.assertEqual(len(record_batch), 0) - self.assertEqual(record_batch.column_names, ["i", "value"]) - self.assertEqual(record_batch.column("i").type, pyarrow.int64()) - self.assertEqual(record_batch.column("value").type, pyarrow.float64()) - - def test_floats_with_index_name(self): - value, value_type = parameter_value.to_database(parameter_value.Array([2.3], index_name="my index")) - record_batch = arrow_value.from_database(value, value_type) - self.assertEqual(len(record_batch), 1) - self.assertEqual(record_batch.column_names, ["my index", "value"]) - indices = record_batch.column("my index") - self.assertEqual(indices.type, pyarrow.int64()) - self.assertEqual(indices, pyarrow.array([0])) - ys = record_batch.column("value") - self.assertEqual(ys.type, pyarrow.float64()) - self.assertEqual(ys, pyarrow.array([2.3])) - - def test_date_times_with_index_name(self): - value, value_type = parameter_value.to_database( - parameter_value.Array([parameter_value.DateTime("2024-09-02T05:51:00")], index_name="my index") - ) - record_batch = arrow_value.from_database(value, value_type) - self.assertEqual(len(record_batch), 1) - self.assertEqual(record_batch.column_names, ["my index", "value"]) - indices = record_batch.column("my index") - self.assertEqual(indices.type, pyarrow.int64()) - self.assertEqual(indices, pyarrow.array([0])) - ys = record_batch.column("value") - self.assertEqual(ys.type, 
pyarrow.timestamp("s")) - self.assertEqual(ys.tolist(), [datetime.datetime(2024, 9, 2, 5, 51)]) - - -class TestFromDatabaseForMaps(unittest.TestCase): - def test_empty_map(self): - value, value_type = parameter_value.to_database(parameter_value.Map([], [], str)) - map_ = arrow_value.from_database(value, value_type) - self.assertEqual(len(map_), 0) - self.assertEqual(map_.column_names, ["col_1", "value"]) - self.assertEqual(map_.column("col_1").type, pyarrow.string()) - self.assertEqual(map_.column("value").type, pyarrow.null()) - - def test_string_to_string_map_with_index_name(self): - value, value_type = parameter_value.to_database(parameter_value.Map(["key"], ["value"], index_name="Keys")) - map_ = arrow_value.from_database(value, value_type) - self.assertEqual(len(map_), 1) - self.assertEqual(map_.column_names, ["Keys", "value"]) - self.assertEqual(map_.column("Keys").type, pyarrow.string()) - self.assertEqual(map_.column("Keys")[0].as_py(), "key") - self.assertEqual(map_.column("value").type, pyarrow.string()) - self.assertEqual(map_.column("value")[0].as_py(), "value") - - def test_date_time_to_different_simple_types_map_with_index_name(self): - value, value_type = parameter_value.to_database( - parameter_value.Map( - [parameter_value.DateTime("2024-02-09T10:00"), parameter_value.DateTime("2024-02-09T11:00")], - ["value", 2.3], - index_name="timestamps", - ) - ) - map_ = arrow_value.from_database(value, value_type) - self.assertEqual(len(map_), 2) - self.assertEqual(map_.column_names, ["timestamps", "value"]) - self.assertEqual(map_.column("timestamps").type, pyarrow.timestamp("s")) - self.assertEqual( - map_.column("timestamps").to_pylist(), - [datetime.datetime(2024, 2, 9, 10), datetime.datetime(2024, 2, 9, 11)], - ) - self.assertEqual( - map_.column("value").type, - pyarrow.dense_union([pyarrow.field("str", pyarrow.string()), pyarrow.field("float", pyarrow.float64())]), - ) - self.assertEqual(map_.column("value").to_pylist(), ["value", 2.3]) - - def test_nested_maps(self): - string_map = parameter_value.Map([11.0], ["value"], index_name="nested index") - float_map = parameter_value.Map(["key"], [22.0], index_name="nested index") - value, value_type = parameter_value.to_database( - parameter_value.Map(["strings", "floats"], [string_map, float_map], index_name="main index") - ) - map_ = arrow_value.from_database(value, value_type) - self.assertEqual(len(map_), 2) - self.assertEqual(map_.column_names, ["main index", "nested index", "value"]) - self.assertEqual(map_.column("main index").type, pyarrow.string()) - self.assertEqual(map_.column("main index").to_pylist(), ["strings", "floats"]) - self.assertEqual( - map_.column("nested index").type, - pyarrow.dense_union([pyarrow.field("float", pyarrow.float64()), pyarrow.field("str", pyarrow.string())]), - ) - self.assertEqual(map_.column("nested index").to_pylist(), [11.0, "key"]) - self.assertEqual( - map_.column("value").type, - pyarrow.dense_union([pyarrow.field("str", pyarrow.string()), pyarrow.field("float", pyarrow.float64())]), - ) - self.assertEqual(map_.column("value").to_pylist(), ["value", 22.0]) - - def test_unevenly_nested_map_with_fixed_resolution_time_series(self): - string_map = parameter_value.Map([11.0], ["value"], index_name="nested index") - float_map = parameter_value.Map(["key"], [22.0], index_name="nested index") - time_series = parameter_value.TimeSeriesFixedResolution( - "2025-02-26T09:00:00", "1h", [2.3, 23.0], ignore_year=False, repeat=False - ) - time_series_map = 
parameter_value.Map([parameter_value.DateTime("2024-02-26T16:45:00")], [time_series]) - nested_time_series_map = parameter_value.Map( - ["ts", "no ts"], [time_series_map, "empty"], index_name="nested index" - ) - value, value_type = parameter_value.to_database( - parameter_value.Map( - ["not nested", "strings", "time series", "floats"], - ["none", string_map, nested_time_series_map, float_map], - index_name="main index", - ) - ) - map_ = arrow_value.from_database(value, value_type) - self.assertEqual(len(map_), 6) - self.assertEqual(map_.column_names, ["main index", "nested index", "col_3", "t", "value"]) - self.assertEqual(map_.column("main index").type, pyarrow.string()) - self.assertEqual( - map_.column("main index").to_pylist(), - ["not nested", "strings", "time series", "time series", "time series", "floats"], - ) - self.assertEqual(map_.column("nested index").to_pylist(), [None, 11.0, "ts", "ts", "no ts", "key"]) - self.assertEqual( - map_.column("col_3").to_pylist(), - [ - None, - None, - datetime.datetime.fromisoformat("2024-02-26T16:45:00"), - datetime.datetime.fromisoformat("2024-02-26T16:45:00"), - None, - None, - ], - ) - self.assertEqual( - map_.column("t").to_pylist(), - [ - None, - None, - datetime.datetime.fromisoformat("2025-02-26T09:00:00"), - datetime.datetime.fromisoformat("2025-02-26T10:00:00"), - None, - None, - ], - ) - self.assertEqual(map_.column("value").to_pylist(), ["none", "value", 2.3, 23.0, "empty", 22.0]) - - def test_unevenly_nested_map(self): - string_map = parameter_value.Map([11.0], ["value"], index_name="nested index") - float_map = parameter_value.Map(["key"], [22.0], index_name="nested index") - datetime_map = parameter_value.Map(["time of my life"], [parameter_value.DateTime("2024-02-26T16:45:00")]) - another_string_map = parameter_value.Map([parameter_value.DateTime("2024-02-26T17:45:00")], ["future"]) - nested_map = parameter_value.Map( - ["date time", "more date time", "non nested"], - [datetime_map, another_string_map, "empty"], - index_name="nested index", - ) - value, value_type = parameter_value.to_database( - parameter_value.Map( - ["not nested", "strings", "date times", "floats"], - ["none", string_map, nested_map, float_map], - index_name="main index", - ) - ) - map_ = arrow_value.from_database(value, value_type) - self.assertEqual(len(map_), 6) - self.assertEqual(map_.column_names, ["main index", "nested index", "col_3", "value"]) - self.assertEqual(map_.column("main index").type, pyarrow.string()) - self.assertEqual( - map_.column("main index").to_pylist(), - ["not nested", "strings", "date times", "date times", "date times", "floats"], - ) - self.assertEqual( - map_.column("nested index").to_pylist(), [None, 11.0, "date time", "more date time", "non nested", "key"] - ) - self.assertEqual( - map_.column("col_3").to_pylist(), - [None, None, "time of my life", datetime.datetime.fromisoformat("2024-02-26T17:45:00"), None, None], - ) - self.assertEqual( - map_.column("value").to_pylist(), - ["none", "value", datetime.datetime.fromisoformat("2024-02-26T16:45:00"), "future", "empty", 22.0], - ) - - -class TestFromDatabaseForTimeSeries(unittest.TestCase): - def test_fixed_resolution_series(self): - value, value_type = parameter_value.to_database( - parameter_value.TimeSeriesFixedResolution( - "2025-02-05T09:59", "15m", [1.1, 1.2], ignore_year=False, repeat=False - ) - ) - fixed_resolution = arrow_value.from_database(value, value_type) - self.assertEqual(fixed_resolution.column_names, ["t", "value"]) - self.assertEqual( - 
fixed_resolution.column("t").to_pylist(), - [datetime.datetime(2025, 2, 5, 9, 59), datetime.datetime(2025, 2, 5, 10, 14)], - ) - self.assertEqual(fixed_resolution.schema.field("t").metadata, {b"ignore_year": b"false", b"repeat": b"false"}) - self.assertEqual(fixed_resolution.column("value").to_pylist(), [1.1, 1.2]) - - def test_ignore_year(self): - value, value_type = parameter_value.to_database( - parameter_value.TimeSeriesFixedResolution( - "2025-02-05T09:59", "15m", [1.1, 1.2], ignore_year=True, repeat=False - ) - ) - fixed_resolution = arrow_value.from_database(value, value_type) - self.assertEqual(fixed_resolution.schema.field("t").metadata, {b"ignore_year": b"true", b"repeat": b"false"}) - - def test_repeat(self): - value, value_type = parameter_value.to_database( - parameter_value.TimeSeriesFixedResolution( - "2025-02-05T09:59", "15m", [1.1, 1.2], ignore_year=False, repeat=True - ) - ) - fixed_resolution = arrow_value.from_database(value, value_type) - self.assertEqual(fixed_resolution.schema.field("t").metadata, {b"ignore_year": b"false", b"repeat": b"true"}) - - def test_variable_resolution_series(self): - value, value_type = parameter_value.to_database( - parameter_value.TimeSeriesVariableResolution( - ["2025-02-05T09:59", "2025-02-05T10:14", "2025-02-05T11:31"], - [1.1, 1.2, 1.3], - ignore_year=False, - repeat=False, - ) - ) - fixed_resolution = arrow_value.from_database(value, value_type) - self.assertEqual(fixed_resolution.column_names, ["t", "value"]) - self.assertEqual( - fixed_resolution.column("t").to_pylist(), - [ - datetime.datetime(2025, 2, 5, 9, 59), - datetime.datetime(2025, 2, 5, 10, 14), - datetime.datetime(2025, 2, 5, 11, 31), - ], - ) - self.assertEqual(fixed_resolution.schema.field("t").metadata, {b"ignore_year": b"false", b"repeat": b"false"}) - self.assertEqual(fixed_resolution.column("value").to_pylist(), [1.1, 1.2, 1.3]) - - -if __name__ == "__main__": - unittest.main() +import pytest +from spinedb_api import SpineDBAPIError, to_database +from spinedb_api.arrow_value import ( + from_database, + load_field_metadata, + with_column_as_time_period, + with_column_as_time_stamps, +) + + +class TestFromDatabase: + def test_string(self): + value = "this is a string" + assert from_database(*to_database(value)) == value + + def test_boolean(self): + value = False + assert from_database(*to_database(value)) == value + + def test_float(self): + value = 2.3 + assert from_database(*to_database(value)) == value + + def test_date_time(self): + value = datetime.datetime(year=2025, month=8, day=29, hour=16, minute=40) + assert from_database(*to_database(value)) == value + + def test_relativedelta(self): + value = relativedelta(minutes=23) + assert from_database(*to_database(value)) == value + + def test_record_batch(self): + value = pyarrow.record_batch( + { + "index": pyarrow.array(["a", "b"]), + "value": pyarrow.array([2.3, 3.2]), + } + ) + assert from_database(*to_database(value)) == value + + +class TestWithColumnAsTimePeriod: + def test_column_given_by_name(self): + column = pyarrow.array(["M1-4,M9-12", "M5-8"]) + record_batch = pyarrow.record_batch({"data": column}) + as_time_period = with_column_as_time_period(record_batch, "data") + column_metadata = load_field_metadata(as_time_period.field("data")) + assert column_metadata == {"format": "time_period"} + + def test_column_given_by_index(self): + column = pyarrow.array(["M1-4,M9-12", "M5-8"]) + record_batch = pyarrow.record_batch({"data": column}) + as_time_period = with_column_as_time_period(record_batch, 0) + 
column_metadata = load_field_metadata(as_time_period.field("data")) + assert column_metadata == {"format": "time_period"} + + def test_raises_when_column_data_is_invalid(self): + column = pyarrow.array(["gibberish"]) + record_batch = pyarrow.record_batch({"data": column}) + with pytest.raises( + SpineDBAPIError, match="^Invalid interval gibberish, it should start with either Y, M, D, WD, h, m, or s.$" + ): + with_column_as_time_period(record_batch, 0) + + +class TestWithColumnAsTimeStamps: + def test_column_given_by_name(self): + column = pyarrow.array([datetime.datetime(year=2025, month=7, day=25, hour=9, minute=48)]) + record_batch = pyarrow.record_batch({"stamps": column}) + as_time_stamps_with_year_ignored = with_column_as_time_stamps(record_batch, "stamps", True, False) + as_time_stamps_with_repeat = with_column_as_time_stamps(record_batch, "stamps", False, True) + assert load_field_metadata(as_time_stamps_with_year_ignored.field("stamps")) == { + "ignore_year": True, + "repeat": False, + } + assert load_field_metadata(as_time_stamps_with_repeat.field("stamps")) == { + "ignore_year": False, + "repeat": True, + } + + def test_column_given_by_index(self): + column = pyarrow.array([datetime.datetime(year=2025, month=7, day=25, hour=9, minute=48)]) + record_batch = pyarrow.record_batch({"stamps": column}) + as_time_stamps = with_column_as_time_stamps(record_batch, 0, True, True) + assert load_field_metadata(as_time_stamps.field("stamps")) == {"ignore_year": True, "repeat": True} + + def test_raises_when_column_type_is_wrong(self): + column = pyarrow.array(["A"]) + record_batch = pyarrow.record_batch({"stamps": column}) + with pytest.raises(SpineDBAPIError, match="^column is not time stamp column$"): + with_column_as_time_stamps(record_batch, 0, False, False) diff --git a/tests/test_check_integrity.py b/tests/test_check_integrity.py index 94378512..73af8ffc 100644 --- a/tests/test_check_integrity.py +++ b/tests/test_check_integrity.py @@ -50,8 +50,8 @@ def setUp(self): "list_value", {"id": 1, **_val_dict(True), "index": 0, "parameter_value_list_id": 1}, {"id": 2, **_val_dict(False), "index": 1, "parameter_value_list_id": 1}, - {"id": 3, **_val_dict(42), "index": 0, "parameter_value_list_id": 2}, - {"id": 4, **_val_dict(-2), "index": 1, "parameter_value_list_id": 2}, + {"id": 3, **_val_dict(42.0), "index": 0, "parameter_value_list_id": 2}, + {"id": 4, **_val_dict(-2.0), "index": 1, "parameter_value_list_id": 2}, {"id": 5, **_val_dict("foo"), "index": 0, "parameter_value_list_id": 3}, {"id": 6, **_val_dict("Bar"), "index": 1, "parameter_value_list_id": 3}, {"id": 7, **_val_dict("BAZ"), "index": 2, "parameter_value_list_id": 3}, @@ -77,21 +77,22 @@ def get_item(id_: int, val: bytes, type_: str, entity_id: int): def test_parameter_values_and_default_values_with_list_references(self): # regression test for spine-tools/Spine-Toolbox#1878 - for type_, fail, pass_ in self.data: - id_ = self.value_type[type_] # setup: parameter definition/value list ids are equal - for k, value in enumerate(fail): - with self.subTest(type=type_, value=value): - item = self.get_item(id_, value, type_, 1) - with self.db_map: + with self.db_map: + for type_, fail, pass_ in self.data: + id_ = self.value_type[type_] # setup: parameter definition/value list ids are equal + for k, value in enumerate(fail): + with self.subTest(type=type_, value=value): + item = self.get_item(id_, value, type_, 1) _, errors = self.db_map.add_items("parameter_value", item) - self.assertEqual(len(errors), 1) - parsed_value = 
json.loads(value.decode("utf8")) - if isinstance(parsed_value, Number): - parsed_value = float(parsed_value) - self.assertEqual(errors[0], f"value {parsed_value} of par{id_} for ('Tom',) is not in list{id_}") - for k, value in enumerate(pass_): - with self.subTest(type=type_, value=value): - item = self.get_item(id_, value, type_, k + 1) - with self.db_map: + self.assertEqual(len(errors), 1) + parsed_value = json.loads(value.decode("utf8")) + if isinstance(parsed_value, Number): + parsed_value = float(parsed_value) + self.assertEqual( + errors[0], f"value {parsed_value} of par{id_} for ('Tom',) is not in list{id_}" + ) + for k, value in enumerate(pass_): + with self.subTest(type=type_, value=value): + item = self.get_item(id_, value, type_, k + 1) _, errors = self.db_map.add_items("parameter_value", item) - self.assertEqual(errors, []) + self.assertEqual(errors, []) diff --git a/tests/test_dataframes.py b/tests/test_dataframes.py index e48476fe..09585d5d 100644 --- a/tests/test_dataframes.py +++ b/tests/test_dataframes.py @@ -128,7 +128,7 @@ def test_time_series_value(self): "Object": pd.Series(["fork", "fork"], dtype="string"), "parameter_definition_name": pd.Series(["y", "y"], dtype="category"), "alternative_name": pd.Series(["Base", "Base"], dtype="category"), - "t": np.array(["2025-02-05T12:30", "2025-02-05T12:45"], dtype="datetime64[s]"), + "t": np.array(["2025-02-05T12:30", "2025-02-05T12:45"], dtype="datetime64[us]"), "value": [1.1, 1.2], } ) @@ -169,7 +169,7 @@ def test_time_series_value_of_multidimensional_entity(self): "Subject": pd.Series(2 * ["spoon"], dtype="string"), "parameter_definition_name": pd.Series(["y", "y"], dtype="category"), "alternative_name": pd.Series(["Base", "Base"], dtype="category"), - "t": np.array(["2025-02-05T12:30", "2025-02-05T12:45"], dtype="datetime64[s]"), + "t": np.array(["2025-02-05T12:30", "2025-02-05T12:45"], dtype="datetime64[us]"), "value": [1.1, 1.2], } ) diff --git a/tests/test_helpers.py b/tests/test_helpers.py index 0af84687..37dcaf24 100644 --- a/tests/test_helpers.py +++ b/tests/test_helpers.py @@ -27,6 +27,8 @@ name_from_elements, remove_credentials_from_url, string_to_bool, + time_period_format_specification, + time_series_metadata, vacuum, ) from tests.mock_helpers import AssertSuccessTestCase @@ -82,7 +84,7 @@ def test_password_with_special_characters(self): class TestGetHeadAlembicVersion(unittest.TestCase): def test_returns_latest_version(self): # This test must be updated each time new migration script is added. 
- self.assertEqual(get_head_alembic_version(), "e9f2c2330cf8") + self.assertEqual(get_head_alembic_version(), "a973ab537da2") class TestStringToBool(unittest.TestCase): @@ -162,5 +164,13 @@ def test_grouping(self): self.assertEqual(list(group_consecutive((1, 2, 6, 3, 7, 10))), [(1, 3), (6, 7), (10, 10)]) -if __name__ == "__main__": - unittest.main() +class TestTimePeriodFormatSpecification: + def test_correctness(self): + specification = time_period_format_specification() + assert specification == {"format": "time_period"} + + +class TestTimeSeriesMetadata: + def test_correctness(self): + assert time_series_metadata(True, False) == {"ignore_year": True, "repeat": False} + assert time_series_metadata(False, True) == {"ignore_year": False, "repeat": True} diff --git a/tests/test_import_functions.py b/tests/test_import_functions.py index 6864802e..b80c2dae 100644 --- a/tests/test_import_functions.py +++ b/tests/test_import_functions.py @@ -42,6 +42,7 @@ import_scenario_alternatives, import_scenarios, ) +from spinedb_api.incomplete_values import dump_db_value from spinedb_api.parameter_value import ( Array, DateTime, @@ -50,7 +51,6 @@ TimePattern, TimeSeriesFixedResolution, TimeSeriesVariableResolution, - dump_db_value, from_database, to_database, ) @@ -1271,7 +1271,7 @@ def test_import_existing_map_on_conflict_merge(self): alternative_name="Base", ) self.assertTrue(value_item) - merged_value = Map(["T1", "T2", "T3", "T4"], [1.1, 1.2, 1.3, 1.4]) + merged_value = Map(["T1", "T2", "T3", "T4"], [1.1, 1.2, 1.3, 1.4], index_name="col_1") self.assertEqual(value_item["parsed_value"], merged_value) def test_import_duplicate_object_parameter_value(self): @@ -1576,8 +1576,18 @@ def test_unparse_value_imports_fields_correctly(self): self.assertEqual(value.entity_name, "aa") time_series = from_database(value.value, value.type) - expected_result = TimeSeriesFixedResolution( - "2000-01-01 00:00:00", "1h", [0.0, 1.0, 2.0, 4.0, 8.0, 0.0], False, False + expected_result = TimeSeriesVariableResolution( + [ + "2000-01-01 00:00:00", + "2000-01-01 01:00:00", + "2000-01-01 02:00:00", + "2000-01-01 03:00:00", + "2000-01-01 04:00:00", + "2000-01-01 05:00:00", + ], + [0.0, 1.0, 2.0, 4.0, 8.0, 0.0], + False, + False, ) self.assertEqual(time_series, expected_result) @@ -2166,12 +2176,18 @@ def test_all_value_types(self): db_map, [("Object", "widget", "X", to_database(value))], unparse_value=_identity ) ) - assert ( - db_map.parameter_value( - entity_class_name="Object", - entity_byname=("widget",), - parameter_definition_name="X", - alternative_name="Base", - )["parsed_value"] - == value - ) + db_value = db_map.parameter_value( + entity_class_name="Object", + entity_byname=("widget",), + parameter_definition_name="X", + alternative_name="Base", + )["parsed_value"] + if isinstance(value, TimeSeriesFixedResolution): + assert db_value == TimeSeriesVariableResolution( + ["2025-09-02T12:00"], [2.3], ignore_year=True, repeat=False + ) + elif isinstance(value, Map): + value.index_name = "col_1" + assert db_value == value + else: + assert db_value == value diff --git a/tests/test_incomplete_values.py b/tests/test_incomplete_values.py new file mode 100644 index 00000000..7bd9c023 --- /dev/null +++ b/tests/test_incomplete_values.py @@ -0,0 +1,124 @@ +###################################################################################################################### +# Copyright (C) 2017-2022 Spine project consortium +# Copyright Spine Database API contributors +# This file is part of Spine Database API. 
+# Spine Database API is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser +# General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your +# option) any later version. This program is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; +# without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General +# Public License for more details. You should have received a copy of the GNU Lesser General Public License along with +# this program. If not, see . +###################################################################################################################### +from datetime import datetime +import json +import pyarrow +from spinedb_api import ( + Array, + DateTime, + Duration, + Map, + TimePattern, + TimeSeriesFixedResolution, + TimeSeriesVariableResolution, + duration_to_relativedelta, + to_database, +) +from spinedb_api.incomplete_values import ( + dump_db_value, + from_database_to_dimension_count, + join_value_and_type, + split_value_and_type, +) +from spinedb_api.parameter_value import to_list + + +class TestDumpDbValue: + def test_legacy_values(self): + values = [ + Duration("23 minutes"), + DateTime("2025-09-04T15:57"), + Array([Duration("5Y")]), + TimePattern(["M1-12"], [2.3]), + TimeSeriesFixedResolution("2025-09-04T15:59", "6h", [2.3], ignore_year=False, repeat=True), + TimeSeriesVariableResolution( + ["2025-09-04T15:59", "2025-09-04T16:00"], [2.3, 3.2], ignore_year=True, repeat=False + ), + Map(["A", "B"], [2.3, 3.2]), + ] + for value in values: + value_dict = value.to_dict() + value_dict["type"] = value.TYPE + assert dump_db_value(value_dict) == to_database(value) + + def test_scalars(self): + values = ["a string", False, 2.3, 5.0, 23, None] + types = ["str", "bool", "float", "float", "int", None] + for value, value_type in zip(values, types): + assert dump_db_value(value) == to_database(value) + + def test_record_batch(self): + index_array = pyarrow.array(["a", "b"]) + value_array = pyarrow.array([2.3, 3.2]) + value = pyarrow.record_batch({"col_1": index_array, "value": value_array}) + assert dump_db_value(to_list(value)) == to_database(value) + + +class TestFromDatabaseToDimensionCount: + def test_zero_dimensional_types(self): + assert from_database_to_dimension_count(*to_database(None)) == 0 + assert from_database_to_dimension_count(*to_database("a string")) == 0 + assert from_database_to_dimension_count(*to_database(5)) == 0 + assert from_database_to_dimension_count(*to_database(2.3)) == 0 + assert from_database_to_dimension_count(*to_database(True)) == 0 + assert ( + from_database_to_dimension_count(*to_database(datetime(year=2025, month=8, day=25, hour=15, minute=15))) + == 0 + ) + assert from_database_to_dimension_count(*to_database(duration_to_relativedelta("5 years"))) == 0 + + def test_one_dimensional_types(self): + assert from_database_to_dimension_count(*to_database(Array([2.3, 3.2]))) == 1 + assert from_database_to_dimension_count(*to_database(TimePattern(["WD1-7"], [23.0]))) == 1 + assert ( + from_database_to_dimension_count( + *to_database( + TimeSeriesFixedResolution("2025-08-25T15:15", "1h", [2.3], ignore_year=False, repeat=False) + ) + ) + == 1 + ) + assert ( + from_database_to_dimension_count( + *to_database( + TimeSeriesVariableResolution( + ["2025-08-25T15:15", "2025-08-25T16:15"], [2.3, 3.2], ignore_year=False, repeat=False + ) + ) + ) + == 1 + ) + + def 
test_variable_dimensional_types(self): + assert from_database_to_dimension_count(*to_database(Map(["a"], [2.3]))) == 1 + assert from_database_to_dimension_count(*to_database(Map(["a"], [Map(["A"], [2.3])]))) == 2 + indexes_1 = pyarrow.array(["A", "B"]) + values = pyarrow.array([2.3, 3.3]) + record_batch = pyarrow.record_batch({"category": indexes_1, "value": values}) + assert from_database_to_dimension_count(*to_database(record_batch)) == 1 + indexes_2 = pyarrow.array(["a", "a"]) + record_batch = pyarrow.record_batch({"category": indexes_1, "subcategory": indexes_2, "value": values}) + assert from_database_to_dimension_count(*to_database(record_batch)) == 2 + + +class TestJoinValueAndType: + def test_correctness(self): + blob, value_type = to_database(2.3) + assert json.loads(join_value_and_type(*to_database(2.3))) == [blob.decode(), value_type] + + +class TestSplitValueAndType: + def test_with_join_value_and_type(self): + blob, value_type = to_database(2.3) + assert split_value_and_type(join_value_and_type(blob, value_type)) == to_database(2.3) + blob, value_type = to_database(DateTime("2025-09-04T16:20")) + assert split_value_and_type(join_value_and_type(blob, value_type)) == (blob, value_type) diff --git a/tests/test_parameter_value.py b/tests/test_parameter_value.py index 446493a6..9df0d41a 100644 --- a/tests/test_parameter_value.py +++ b/tests/test_parameter_value.py @@ -19,6 +19,12 @@ from dateutil.relativedelta import relativedelta import numpy as np import numpy.testing +import pyarrow +import pytest +from spinedb_api import ParameterValueFormatError, SpineDBAPIError +from spinedb_api.arrow_value import load_field_metadata, with_column_as_time_period, with_column_as_time_stamps +from spinedb_api.compat.converters import parse_duration +from spinedb_api.helpers import time_period_format_specification, time_series_metadata from spinedb_api.parameter_value import ( Array, DateTime, @@ -28,6 +34,7 @@ TimeSeries, TimeSeriesFixedResolution, TimeSeriesVariableResolution, + _month_day_nano_interval_to_iso_duration, convert_containers_to_maps, convert_leaf_maps_to_specialized_containers, convert_map_to_table, @@ -37,8 +44,8 @@ from_database, relativedelta_to_duration, to_database, + type_and_rank_for_value, type_and_rank_to_fancy_type, - type_for_value, ) @@ -145,41 +152,32 @@ def test_relativedelta_to_duration_years(self): def test_from_database_plain_number(self): database_value = b"23.0" value = from_database(database_value, type_="float") - self.assertTrue(isinstance(value, float)) + self.assertIsInstance(value, float) self.assertEqual(value, 23.0) + def test_from_database_int_like_float(self): + database_value = b"23" + value = from_database(database_value, type_="float") + self.assertIsInstance(value, float) + self.assertEqual(value, 23.0) + + def test_from_database_int(self): + database_value = b"23" + value = from_database(database_value, type_="int") + self.assertIsInstance(value, int) + self.assertEqual(value, 23) + def test_from_database_boolean(self): database_value = b"true" value = from_database(database_value, type_="boolean") - self.assertTrue(isinstance(value, bool)) - self.assertEqual(value, True) - - def test_to_database_plain_number(self): - value = 23.0 - database_value, value_type = to_database(value) - value_as_float = json.loads(database_value) - self.assertEqual(value_as_float, value) - self.assertEqual(value_type, "float") - - def test_to_database_DateTime(self): - value = DateTime(datetime(year=2019, month=6, day=26, hour=12, minute=50, second=13)) - 
database_value, value_type = to_database(value) - value_as_dict = json.loads(database_value) - self.assertEqual(value_as_dict, {"data": "2019-06-26T12:50:13"}) - self.assertEqual(value_type, "date_time") + self.assertIsInstance(value, bool) + self.assertTrue(value) def test_from_database_DateTime(self): database_value = b'{"data": "2019-06-01T22:15:00+01:00"}' value = from_database(database_value, type_="date_time") self.assertEqual(value.value, dateutil.parser.parse("2019-06-01T22:15:00+01:00")) - def test_DateTime_to_database(self): - value = DateTime(datetime(year=2019, month=6, day=26, hour=10, minute=50, second=34)) - database_value, value_type = value.to_database() - value_dict = json.loads(database_value) - self.assertEqual(value_dict, {"data": "2019-06-26T10:50:34"}) - self.assertEqual(value_type, "date_time") - def test_from_database_Duration(self): database_value = b'{"data": "4 seconds"}' value = from_database(database_value, type_="duration") @@ -196,13 +194,6 @@ def test_from_database_Duration_legacy_list_format_converted_to_Array(self): expected = Array([Duration("1h"), Duration("1h"), Duration("1h"), Duration("2h")]) self.assertEqual(value, expected) - def test_Duration_to_database(self): - value = Duration(duration_to_relativedelta("8 years")) - database_value, value_type = value.to_database() - value_as_dict = json.loads(database_value) - self.assertEqual(value_as_dict, {"data": "8Y"}) - self.assertEqual(value_type, "duration") - def test_from_database_TimePattern(self): database_value = b""" { @@ -232,28 +223,6 @@ def test_from_database_TimePattern_with_index_name(self): numpy.testing.assert_equal(value.values, numpy.array([300.0])) self.assertEqual(value.index_name, "index") - def test_TimePattern_to_database(self): - value = TimePattern(["M1-4,M9-12", "M5-8"], numpy.array([300.0, 221.5])) - database_value, value_type = value.to_database() - value_as_dict = json.loads(database_value) - self.assertEqual(value_as_dict, {"data": {"M1-4,M9-12": 300.0, "M5-8": 221.5}}) - self.assertEqual(value_type, "time_pattern") - - def test_TimePattern_to_database_with_integer_values(self): - value = TimePattern(["M1-4,M9-12", "M5-8"], [300, 221]) - database_value, value_type = value.to_database() - value_as_dict = json.loads(database_value) - self.assertEqual(value_as_dict, {"data": {"M1-4,M9-12": 300.0, "M5-8": 221.0}}) - self.assertEqual(value_type, "time_pattern") - - def test_TimePattern_to_database_with_index_name(self): - value = TimePattern(["M1-12"], [300.0]) - value.index_name = "index" - database_value, value_type = value.to_database() - value_as_dict = json.loads(database_value) - self.assertEqual(value_as_dict, {"index_name": "index", "data": {"M1-12": 300.0}}) - self.assertEqual(value_type, "time_pattern") - def test_TimePattern_index_length_is_not_limited(self): value = TimePattern(["M1-4", "M5-12"], [300, 221]) value.indexes[0] = "M1-2,M3-4,M5-6,M7-8,M9-10,M11-12" @@ -277,7 +246,7 @@ def test_from_database_TimeSeriesVariableResolution_as_dictionary(self): ), ) self.assertEqual(len(time_series), 3) - self.assertTrue(isinstance(time_series.values, numpy.ndarray)) + self.assertIsInstance(time_series.values, numpy.ndarray) numpy.testing.assert_equal(time_series.values, numpy.array([4, 5, 6])) self.assertEqual(time_series.index_name, "t") @@ -336,39 +305,6 @@ def test_from_database_TimeSeriesFixedResolution_default_repeat(self): self.assertTrue(time_series.ignore_year) self.assertFalse(time_series.repeat) - def test_TimeSeriesVariableResolution_to_database(self): - dates = 
numpy.array(["1999-05-19", "2002-05-16", "2005-05-19"], dtype="datetime64[D]") - episodes = numpy.array([1, 2, 3], dtype=float) - value = TimeSeriesVariableResolution(dates, episodes, False, False) - db_value, value_type = value.to_database() - releases = json.loads(db_value) - self.assertEqual(releases, {"data": {"1999-05-19": 1, "2002-05-16": 2, "2005-05-19": 3}}) - self.assertEqual(value_type, "time_series") - - def test_TimeSeriesVariableResolution_to_database_with_index_name(self): - dates = numpy.array(["2002-05-16", "2005-05-19"], dtype="datetime64[D]") - episodes = numpy.array([1, 2], dtype=float) - value = TimeSeriesVariableResolution(dates, episodes, False, False, "index") - db_value, value_type = value.to_database() - releases = json.loads(db_value) - self.assertEqual(releases, {"index_name": "index", "data": {"2002-05-16": 1, "2005-05-19": 2}}) - self.assertEqual(value_type, "time_series") - - def test_TimeSeriesVariableResolution_to_database_with_ignore_year_and_repeat(self): - dates = numpy.array(["1999-05-19", "2002-05-16", "2005-05-19"], dtype="datetime64[D]") - episodes = numpy.array([1, 2, 3], dtype=float) - value = TimeSeriesVariableResolution(dates, episodes, True, True) - db_value, value_type = value.to_database() - releases = json.loads(db_value) - self.assertEqual( - releases, - { - "data": {"1999-05-19": 1, "2002-05-16": 2, "2005-05-19": 3}, - "index": {"ignore_year": True, "repeat": True}, - }, - ) - self.assertEqual(value_type, "time_series") - def test_from_database_TimeSeriesFixedResolution(self): days_of_our_lives = b"""{ "index": { @@ -388,7 +324,7 @@ def test_from_database_TimeSeriesFixedResolution(self): dtype="datetime64[s]", ), ) - self.assertTrue(isinstance(time_series.values, numpy.ndarray)) + self.assertIsInstance(time_series.values, numpy.ndarray) numpy.testing.assert_equal(time_series.values, numpy.array([7.0, 5.0, 8.1])) self.assertEqual(time_series.start, dateutil.parser.parse("2019-03-23")) self.assertEqual(len(time_series.resolution), 1) @@ -527,79 +463,24 @@ def test_from_database_TimeSeriesFixedResolution_default_ignore_year(self): time_series = from_database(database_value, type_="time_series") self.assertTrue(time_series.ignore_year) - def test_TimeSeriesFixedResolution_to_database(self): - values = numpy.array([3, 2, 4], dtype=float) - resolution = [duration_to_relativedelta("1 months")] - start = datetime(year=2007, month=6, day=1) - value = TimeSeriesFixedResolution(start, resolution, values, True, True) - db_value, value_type = value.to_database() - releases = json.loads(db_value) - self.assertEqual( - releases, - { - "index": {"start": "2007-06-01 00:00:00", "resolution": "1M", "ignore_year": True, "repeat": True}, - "data": [3, 2, 4], - }, - ) - self.assertEqual(value_type, "time_series") - - def test_TimeSeriesFixedResolution_to_database_with_index_type(self): - values = numpy.array([3, 2, 4], dtype=float) - resolution = [duration_to_relativedelta("1 months")] - start = datetime(year=2007, month=6, day=1) - value = TimeSeriesFixedResolution(start, resolution, values, True, True, "index") - db_value, value_type = value.to_database() - releases = json.loads(db_value) - self.assertEqual( - releases, - { - "index_name": "index", - "index": {"start": "2007-06-01 00:00:00", "resolution": "1M", "ignore_year": True, "repeat": True}, - "data": [3, 2, 4], - }, - ) - self.assertEqual(value_type, "time_series") - - def test_TimeSeriesFixedResolution_resolution_list_to_database(self): - start = datetime(year=2007, month=1, day=1) - resolutions = ["1 
month", "1 year"] - resolutions = [duration_to_relativedelta(r) for r in resolutions] - values = numpy.array([3.0, 2.0, 4.0]) - value = TimeSeriesFixedResolution(start, resolutions, values, True, True) - db_value, value_type = value.to_database() - releases = json.loads(db_value) - self.assertEqual( - releases, - { - "index": { - "start": "2007-01-01 00:00:00", - "resolution": ["1M", "1Y"], - "ignore_year": True, - "repeat": True, - }, - "data": [3.0, 2.0, 4.0], - }, - ) - self.assertEqual(value_type, "time_series") - def test_TimeSeriesFixedResolution_init_conversions(self): series = TimeSeriesFixedResolution("2019-01-03T00:30:33", "1D", [3.0, 2.0, 1.0], False, False) - self.assertTrue(isinstance(series.start, datetime)) - self.assertTrue(isinstance(series.resolution, list)) + self.assertIsInstance(series.start, datetime) + self.assertIsInstance(series.resolution, list) for element in series.resolution: - self.assertTrue(isinstance(element, relativedelta)) - self.assertTrue(isinstance(series.values, numpy.ndarray)) + self.assertIsInstance(element, relativedelta) + self.assertIsInstance(series.values, numpy.ndarray) series = TimeSeriesFixedResolution("2019-01-03T00:30:33", ["2h", "4h"], [3.0, 2.0, 1.0], False, False) - self.assertTrue(isinstance(series.resolution, list)) + self.assertIsInstance(series.resolution, list) for element in series.resolution: - self.assertTrue(isinstance(element, relativedelta)) + self.assertIsInstance(element, relativedelta) def test_TimeSeriesVariableResolution_init_conversion(self): series = TimeSeriesVariableResolution(["2008-07-08T03:00", "2008-08-08T13:30"], [3.3, 4.4], True, True) - self.assertTrue(isinstance(series.indexes, np.ndarray)) + self.assertIsInstance(series.indexes, np.ndarray) for index in series.indexes: - self.assertTrue(isinstance(index, np.datetime64)) - self.assertTrue(isinstance(series.values, np.ndarray)) + self.assertIsInstance(index, np.datetime64) + self.assertIsInstance(series.values, np.ndarray) def test_from_database_Map_with_index_name(self): database_value = b'{"index_type":"str", "index_name": "index", "data":[["a", 1.1]]}' @@ -676,114 +557,6 @@ def test_from_database_Map_with_TimePattern_values(self): self.assertEqual(value.indexes, [2.3]) self.assertEqual(value.values, [TimePattern(["M1-2", "M3-12"], [-9.3, -3.9])]) - def test_Map_to_database(self): - map_value = Map(["a", "b"], [1.1, 2.2]) - db_value, value_type = to_database(map_value) - raw = json.loads(db_value) - self.assertEqual(raw, {"index_type": "str", "rank": 1, "data": [["a", 1.1], ["b", 2.2]]}) - self.assertEqual(value_type, "map") - - def test_Map_to_database_with_index_names(self): - nested_map = Map(["a"], [0.3]) - nested_map.index_name = "nested index" - map_value = Map(["A"], [nested_map]) - map_value.index_name = "index" - db_value, value_type = to_database(map_value) - raw = json.loads(db_value) - self.assertEqual( - raw, - { - "index_type": "str", - "index_name": "index", - "rank": 2, - "data": [ - [ - "A", - { - "type": "map", - "index_type": "str", - "index_name": "nested index", - "rank": 1, - "data": [["a", 0.3]], - }, - ] - ], - }, - ) - self.assertEqual(value_type, "map") - - def test_Map_to_database_with_TimeSeries_values(self): - time_series1 = TimeSeriesVariableResolution(["2020-01-01T12:00", "2020-01-02T12:00"], [2.3, 4.5], False, False) - time_series2 = TimeSeriesVariableResolution( - ["2020-01-01T12:00", "2020-01-02T12:00"], [-4.5, -2.3], False, False - ) - map_value = Map(["a", "b"], [time_series1, time_series2]) - db_value, value_type = 
to_database(map_value) - raw = json.loads(db_value) - expected = { - "index_type": "str", - "rank": 2, - "data": [ - ["a", {"type": "time_series", "data": {"2020-01-01T12:00:00": 2.3, "2020-01-02T12:00:00": 4.5}}], - ["b", {"type": "time_series", "data": {"2020-01-01T12:00:00": -4.5, "2020-01-02T12:00:00": -2.3}}], - ], - } - self.assertEqual(raw, expected) - self.assertEqual(value_type, "map") - - def test_Map_to_database_nested_maps(self): - nested_map = Map([Duration("2 months")], [Duration("5 days")]) - map_value = Map([DateTime("2020-01-01T13:00")], [nested_map]) - db_value, value_type = to_database(map_value) - raw = json.loads(db_value) - self.assertEqual( - raw, - { - "index_type": "date_time", - "rank": 2, - "data": [ - [ - "2020-01-01T13:00:00", - { - "type": "map", - "index_type": "duration", - "rank": 1, - "data": [["2M", {"type": "duration", "data": "5D"}]], - }, - ] - ], - }, - ) - self.assertEqual(value_type, "map") - - def test_Array_of_floats_to_database(self): - array = Array([-1.1, -2.2, -3.3]) - db_value, value_type = to_database(array) - raw = json.loads(db_value) - self.assertEqual(raw, {"value_type": "float", "data": [-1.1, -2.2, -3.3]}) - self.assertEqual(value_type, "array") - - def test_Array_of_strings_to_database(self): - array = Array(["a", "b"]) - db_value, value_type = to_database(array) - raw = json.loads(db_value) - self.assertEqual(raw, {"value_type": "str", "data": ["a", "b"]}) - self.assertEqual(value_type, "array") - - def test_Array_of_DateTimes_to_database(self): - array = Array([DateTime("2020-01-01T13:00")]) - db_value, value_type = to_database(array) - raw = json.loads(db_value) - self.assertEqual(raw, {"value_type": "date_time", "data": ["2020-01-01T13:00:00"]}) - self.assertEqual(value_type, "array") - - def test_Array_of_Durations_to_database(self): - array = Array([Duration("4 months")]) - db_value, value_type = to_database(array) - raw = json.loads(db_value) - self.assertEqual(raw, {"value_type": "duration", "data": ["4M"]}) - self.assertEqual(value_type, "array") - def test_Array_of_floats_from_database(self): database_value = b"""{ "value_type": "float", @@ -1096,23 +869,169 @@ def test_deep_copy_map(self): self.assertIsNot(x.get_value("T1"), copy_of_x.get_value("T1")) -class TestTimeSeriesVariableResolution: - def test_get_value(self): - time_series = TimeSeriesVariableResolution( - ["2025-07-01T15:45", "2025-07-01T16:45"], [2.3, 3.2], ignore_year=False, repeat=False, index_name="y" +class TestInt: + def test_from_database(self): + value = from_database(*to_database(23)) + assert isinstance(value, int) + assert value == 23 + + +class TestFloat: + def test_from_database(self): + value = from_database(*to_database(2.3)) + assert isinstance(value, float) + assert value == 2.3 + value = from_database(*to_database(5.0)) + assert isinstance(value, float) + assert value == 5.0 + + +class TestDuration: + def test_as_arrow(self): + duration = Duration("15h") + assert duration.as_arrow() == duration_to_relativedelta("15h") + + def test_from_arrow(self): + assert Duration.from_arrow(duration_to_relativedelta("15h")) == Duration("15h") + + +class TestDateTime: + def test_as_arrow(self): + date_time = DateTime("2025-08-06T14:37") + assert date_time.as_arrow() == datetime.fromisoformat("2025-08-06T14:37") + + def test_from_arrow(self): + assert DateTime.from_arrow(datetime.fromisoformat("2025-08-21T11:11")) == DateTime("2025-08-21T11:11") + + +class TestArray: + def test_as_arrow_with_float(self): + array = Array([2.3, 3.2]) + record_batch = 
array.as_arrow() + expected = pyarrow.record_batch({"i": pyarrow.array([1, 2]), "value": pyarrow.array([2.3, 3.2])}) + assert record_batch == expected + assert record_batch.schema.metadata == expected.schema.metadata + + def test_as_arrow_with_string(self): + array = Array(["a", "b"]) + record_batch = array.as_arrow() + expected = pyarrow.record_batch({"i": pyarrow.array([1, 2]), "value": pyarrow.array(["a", "b"])}) + assert record_batch == expected + assert record_batch.schema.metadata == expected.schema.metadata + + def test_as_arrow_with_bool(self): + array = Array([False, True]) + record_batch = array.as_arrow() + expected = pyarrow.record_batch({"i": pyarrow.array([1, 2]), "value": pyarrow.array([False, True])}) + assert record_batch == expected + assert record_batch.schema.metadata == expected.schema.metadata + + def test_as_arrow_with_duration(self): + array = Array([Duration("3D"), Duration("7m")]) + record_batch = array.as_arrow() + expected = pyarrow.record_batch( + { + "i": pyarrow.array([1, 2]), + "value": pyarrow.array([duration_to_relativedelta("3 days"), duration_to_relativedelta("7 minutes")]), + } ) - assert time_series.get_value(np.datetime64("2025-07-01T15:45")) == 2.3 - assert time_series.get_value(np.datetime64("2025-07-01T16:45")) == 3.2 - assert time_series.get_value(np.datetime64("2025-07-01T16:00")) is None + assert record_batch == expected + assert record_batch.schema.metadata == expected.schema.metadata + + def test_from_database_with_empty_array(self): + value, value_type = to_database(Array([])) + array = from_database(value, value_type) + assert array == Array([]) + + def test_from_database_with_floats_and_index_name(self): + value, value_type = to_database(Array([2.3], index_name="my index")) + array = from_database(value, value_type) + assert array == Array([2.3], index_name="my index") + + def test_from_database_with_duration(self): + value, value_type = to_database(Array([Duration("4 months")])) + array = from_database(value, value_type) + assert array == Array([Duration("4 months")]) + + def test_from_arrow_with_float(self): + record_batch = pyarrow.record_batch({"i": pyarrow.array([1, 2]), "value": pyarrow.array([2.3, 3.2])}) + array = Array.from_arrow(record_batch) + assert array == Array([2.3, 3.2]) + + def test_from_arrow_with_float_and_index_name(self): + record_batch = pyarrow.record_batch({"my index": pyarrow.array([1, 2]), "value": pyarrow.array([2.3, 3.2])}) + array = Array.from_arrow(record_batch) + assert array == Array([2.3, 3.2], index_name="my index") + + def test_from_arrow_with_string(self): + record_batch = pyarrow.record_batch({"i": pyarrow.array([1, 2]), "value": pyarrow.array(["P", "Q"])}) + array = Array.from_arrow(record_batch) + assert array == Array(["P", "Q"]) + + def test_from_arrow_with_bool(self): + record_batch = pyarrow.record_batch({"i": pyarrow.array([1, 2]), "value": pyarrow.array([True, False])}) + array = Array.from_arrow(record_batch) + assert array == Array([True, False]) + + def test_from_arrow_with_duration(self): + record_batch = pyarrow.record_batch( + { + "i": pyarrow.array([1, 2]), + "value": pyarrow.array([duration_to_relativedelta("3 days"), duration_to_relativedelta("7 minutes")]), + } + ) + array = Array.from_arrow(record_batch) + assert array == Array([Duration("3 days"), Duration("7 minutes")]) - def test_set_value(self): - time_series = TimeSeriesVariableResolution( - ["2025-07-01T15:45", "2025-07-01T16:45"], [2.3, 3.2], ignore_year=False, repeat=False, index_name="y" + def 
test_from_arrow_with_duration_and_index_name(self): + record_batch = pyarrow.record_batch( + { + "my index": pyarrow.array([1, 2]), + "value": pyarrow.array([duration_to_relativedelta("3 days"), duration_to_relativedelta("7 minutes")]), + } ) - time_series.set_value(np.datetime64("2025-07-01T15:45"), -2.3) - assert time_series.get_value(np.datetime64("2025-07-01T15:45")) == -2.3 - time_series.set_value(np.datetime64("2025-07-01T16:45"), -3.2) - assert time_series.get_value(np.datetime64("2025-07-01T16:45")) == -3.2 + array = Array.from_arrow(record_batch) + assert array == Array([Duration("3 days"), Duration("7 minutes")], index_name="my index") + + def test_from_arrow_with_empty_record_batch(self): + record_batch = pyarrow.record_batch({"i": [], "value": pyarrow.array([], type=pyarrow.float64())}) + array = Array.from_arrow(record_batch) + assert array == Array([], value_type=float) + record_batch = pyarrow.record_batch({"my index": [], "value": pyarrow.array([], type=pyarrow.float64())}) + array = Array.from_arrow(record_batch) + assert array == Array([], value_type=float, index_name="my index") + record_batch = pyarrow.record_batch({"i": [], "value": pyarrow.array([], type=pyarrow.bool_())}) + array = Array.from_arrow(record_batch) + assert array == Array([], value_type=bool) + record_batch = pyarrow.record_batch({"i": [], "value": pyarrow.array([], type=pyarrow.string())}) + array = Array.from_arrow(record_batch) + assert array == Array([], value_type=str) + record_batch = pyarrow.record_batch( + {"i": [], "value": pyarrow.array([], type=pyarrow.month_day_nano_interval())} + ) + array = Array.from_arrow(record_batch) + assert array == Array([], value_type=relativedelta) + + +class TestTimePattern: + def test_as_arrow(self): + pattern = TimePattern(["WD1-7"], [2.3]) + record_batch = pattern.as_arrow() + index_array = pyarrow.array(["WD1-7"]) + value_array = pyarrow.array([2.3]) + expected = with_column_as_time_period(pyarrow.record_batch({"p": index_array, "value": value_array}), "p") + assert record_batch == expected + assert load_field_metadata(record_batch.field("p")) == time_period_format_specification() + + def test_from_arrow(self): + index_array = pyarrow.array(["WD1-7"]) + value_array = pyarrow.array([2.3]) + record_batch = with_column_as_time_period( + pyarrow.record_batch({"pattern": index_array, "value": value_array}), "pattern" + ) + time_pattern = TimePattern.from_arrow(record_batch) + assert time_pattern == TimePattern(["WD1-7"], [2.3], index_name="pattern") + assert time_pattern.get_value("WD1-7") == 2.3 class TestTimeSeriesFixedResolution: @@ -1156,6 +1075,858 @@ def test_set_value_with_multiresolution(self): time_series.set_value(np.datetime64("2025-07-01T18:00"), -4.4) assert time_series.get_value(np.datetime64("2025-07-01T18:00")) == -4.4 + def test_as_arrow(self): + time_series = TimeSeriesFixedResolution("2025-08-18T17:11", "4h", [2.3, 3.2], ignore_year=False, repeat=False) + record_batch = time_series.as_arrow() + index_array = pyarrow.array( + [ + datetime(year=2025, month=8, day=18, hour=17, minute=11), + datetime(year=2025, month=8, day=18, hour=21, minute=11), + ], + type=pyarrow.timestamp("s"), + ) + value_array = pyarrow.array([2.3, 3.2]) + expected = with_column_as_time_stamps( + pyarrow.record_batch({"t": index_array, "value": value_array}), "t", ignore_year=False, repeat=False + ) + assert record_batch == expected + assert load_field_metadata(record_batch.field("t")) == time_series_metadata(ignore_year=False, repeat=False) + + def 
test_as_arrow_with_index_name(self): + time_series = TimeSeriesFixedResolution( + "2025-08-18T17:11", "4h", [2.3, 3.2], ignore_year=True, repeat=True, index_name="stamps" + ) + record_batch = time_series.as_arrow() + index_array = pyarrow.array( + [ + datetime(year=2025, month=8, day=18, hour=17, minute=11), + datetime(year=2025, month=8, day=18, hour=21, minute=11), + ], + type=pyarrow.timestamp("s"), + ) + value_array = pyarrow.array([2.3, 3.2]) + expected = with_column_as_time_stamps( + pyarrow.record_batch({"stamps": index_array, "value": value_array}), "stamps", ignore_year=True, repeat=True + ) + assert record_batch == expected + assert load_field_metadata(record_batch.field("stamps")) == time_series_metadata(ignore_year=True, repeat=True) + + def test_from_database_with_fixed_resolution_series(self): + original = TimeSeriesFixedResolution("2025-02-05T09:59", "15m", [1.1, 1.2], ignore_year=False, repeat=False) + value, value_type = to_database(original) + deserialized = from_database(value, value_type) + assert deserialized == TimeSeriesVariableResolution( + ["2025-02-05T09:59", "2025-02-05T10:14"], [1.1, 1.2], ignore_year=False, repeat=False + ) + + def test_from_database_with_ignore_year(self): + original = TimeSeriesFixedResolution("2025-02-05T09:59", "15m", [1.1, 1.2], ignore_year=True, repeat=False) + value, value_type = to_database(original) + deserialized = from_database(value, value_type) + assert deserialized == TimeSeriesVariableResolution( + ["2025-02-05T09:59", "2025-02-05T10:14"], [1.1, 1.2], ignore_year=True, repeat=False + ) + + def test_from_database_with_repeat(self): + original = TimeSeriesFixedResolution("2025-02-05T09:59", "15m", [1.1, 1.2], ignore_year=False, repeat=True) + value, value_type = to_database(original) + serialized = from_database(value, value_type) + assert serialized == TimeSeriesVariableResolution( + ["2025-02-05T09:59", "2025-02-05T10:14"], [1.1, 1.2], ignore_year=False, repeat=True + ) + + def test_from_arrow(self): + index_array = pyarrow.array( + [ + datetime(year=2025, month=8, day=18, hour=17, minute=11), + datetime(year=2025, month=8, day=18, hour=21, minute=11), + ], + ) + value_array = pyarrow.array([2.3, 3.2]) + record_batch = with_column_as_time_stamps( + pyarrow.record_batch({"stamp": index_array, "value": value_array}), "stamp", ignore_year=True, repeat=False + ) + time_series = TimeSeriesFixedResolution.from_arrow(record_batch) + assert time_series == TimeSeriesFixedResolution( + "2025-08-18T17:11", "4h", [2.3, 3.2], ignore_year=True, repeat=False, index_name="stamp" + ) + + def test_from_arrow_with_variable_resolution(self): + index_array = pyarrow.array( + [ + datetime(year=2025, month=8, day=22, hour=1, minute=0), + datetime(year=2025, month=8, day=22, hour=3, minute=0), + datetime(year=2025, month=8, day=22, hour=6, minute=0), + datetime(year=2025, month=8, day=22, hour=10, minute=0), + datetime(year=2025, month=8, day=22, hour=12, minute=0), + datetime(year=2025, month=8, day=22, hour=15, minute=0), + ], + ) + value_array = pyarrow.array([2.3, 3.3, 4.3, 5.3, 6.3, 7.3]) + record_batch = with_column_as_time_stamps( + pyarrow.record_batch({"t": index_array, "value": value_array}), "t", ignore_year=False, repeat=True + ) + time_series = TimeSeriesFixedResolution.from_arrow(record_batch) + assert time_series == TimeSeriesFixedResolution( + "2025-08-22T01:00", + ["2h", "3h", "4h"], + [2.3, 3.3, 4.3, 5.3, 6.3, 7.3], + ignore_year=False, + repeat=True, + index_name="t", + ) + + +class TestTimeSeriesVariableResolution: + def 
test_get_value(self): + time_series = TimeSeriesVariableResolution( + ["2025-07-01T15:45", "2025-07-01T16:45"], [2.3, 3.2], ignore_year=False, repeat=False, index_name="y" + ) + assert time_series.get_value(np.datetime64("2025-07-01T15:45")) == 2.3 + assert time_series.get_value(np.datetime64("2025-07-01T16:45")) == 3.2 + assert time_series.get_value(np.datetime64("2025-07-01T16:00")) is None + + def test_set_value(self): + time_series = TimeSeriesVariableResolution( + ["2025-07-01T15:45", "2025-07-01T16:45"], [2.3, 3.2], ignore_year=False, repeat=False, index_name="y" + ) + time_series.set_value(np.datetime64("2025-07-01T15:45"), -2.3) + assert time_series.get_value(np.datetime64("2025-07-01T15:45")) == -2.3 + time_series.set_value(np.datetime64("2025-07-01T16:45"), -3.2) + assert time_series.get_value(np.datetime64("2025-07-01T16:45")) == -3.2 + + def test_as_arrow(self): + time_series = TimeSeriesVariableResolution( + ["2025-08-21T09:10", "2025-08-21T09:14"], [2.3, 3.2], ignore_year=False, repeat=False + ) + record_batch = time_series.as_arrow() + index_array = pyarrow.array( + [ + datetime(year=2025, month=8, day=21, hour=9, minute=10), + datetime(year=2025, month=8, day=21, hour=9, minute=14), + ], + type=pyarrow.timestamp("s"), + ) + value_array = pyarrow.array([2.3, 3.2]) + expected = with_column_as_time_stamps( + pyarrow.record_batch({"t": index_array, "value": value_array}), "t", ignore_year=False, repeat=False + ) + assert record_batch == expected + assert load_field_metadata(record_batch.field("t")) == time_series_metadata(ignore_year=False, repeat=False) + + def test_as_arrow_with_index_name(self): + time_series = TimeSeriesVariableResolution( + ["2025-08-21T09:10", "2025-08-21T09:14"], [2.3, 3.2], ignore_year=True, repeat=True, index_name="stamp" + ) + record_batch = time_series.as_arrow() + index_array = pyarrow.array( + [ + datetime(year=2025, month=8, day=21, hour=9, minute=10), + datetime(year=2025, month=8, day=21, hour=9, minute=14), + ], + type=pyarrow.timestamp("s"), + ) + value_array = pyarrow.array([2.3, 3.2]) + expected = with_column_as_time_stamps( + pyarrow.record_batch({"stamp": index_array, "value": value_array}), "stamp", ignore_year=True, repeat=True + ) + assert record_batch == expected + assert load_field_metadata(record_batch.field("stamp")) == time_series_metadata(ignore_year=True, repeat=True) + + def test_from_database(self): + original = TimeSeriesVariableResolution( + ["2025-02-05T09:59", "2025-02-05T10:14", "2025-02-05T11:31"], + [1.1, 1.2, 1.3], + ignore_year=False, + repeat=False, + ) + value, value_type = to_database(original) + deserialized = from_database(value, value_type) + assert deserialized == original + + def test_from_arrow(self): + index_array = pyarrow.array( + [ + datetime(year=2025, month=8, day=21, hour=9, minute=10), + datetime(year=2025, month=8, day=21, hour=9, minute=14), + ], + ) + value_array = pyarrow.array([2.3, 3.2]) + record_batch = with_column_as_time_stamps( + pyarrow.record_batch({"stamp": index_array, "value": value_array}), "stamp", ignore_year=False, repeat=True + ) + time_series = TimeSeriesVariableResolution.from_arrow(record_batch) + assert time_series == TimeSeriesVariableResolution( + ["2025-08-21T09:10", "2025-08-21T09:14"], [2.3, 3.2], ignore_year=False, repeat=True, index_name="stamp" + ) + + +class TestMap: + def test_as_arrow(self): + map_value = Map(["a", "b"], [2.3, 3.2]) + record_batch = map_value.as_arrow() + expected = pyarrow.record_batch({"col_1": pyarrow.array(["a", "b"]), "value": pyarrow.array([2.3, 
3.2])}) + assert record_batch == expected + assert record_batch.schema.metadata == expected.schema.metadata + + def test_as_arrow_with_empty_map(self): + map_value = Map([], [], index_type=str) + record_batch = map_value.as_arrow() + expected = pyarrow.record_batch( + {"col_1": pyarrow.array([], type=pyarrow.string()), "value": pyarrow.array([], type=pyarrow.float64())} + ) + assert record_batch == expected + assert record_batch.schema.metadata == expected.schema.metadata + + def test_as_arrow_with_index_name(self): + map_value = Map( + [DateTime("2025-08-05T15:30"), DateTime("2025-08-05T15:45")], [2.3, 3.2], index_name="my indexes" + ) + record_batch = map_value.as_arrow() + expected = pyarrow.record_batch( + { + "my indexes": pyarrow.array( + [ + datetime(year=2025, month=8, day=5, hour=15, minute=30), + datetime(year=2025, month=8, day=5, hour=15, minute=45), + ], + type=pyarrow.timestamp("s"), + ), + "value": pyarrow.array([2.3, 3.2]), + } + ) + assert record_batch == expected + assert record_batch.schema.metadata == expected.schema.metadata + + def test_as_arrow_with_nested_map(self): + map_value = Map(["a", "b", "c"], ["yes", Map([Duration("1h"), Duration("2h")], [True, False]), "no"]) + record_batch = map_value.as_arrow() + indexes_1 = pyarrow.RunEndEncodedArray.from_arrays([1, 3, 4], ["a", "b", "c"]) + indexes_2 = pyarrow.array([None, duration_to_relativedelta("1h"), duration_to_relativedelta("2h"), None]) + str_values = pyarrow.array(["yes", "no"]) + bool_values = pyarrow.array([True, False]) + values = pyarrow.UnionArray.from_dense( + pyarrow.array([0, 1, 1, 0], type=pyarrow.int8()), + pyarrow.array([0, 0, 1, 1], type=pyarrow.int32()), + [str_values, bool_values], + ) + expected = pyarrow.record_batch({"col_1": indexes_1, "col_2": indexes_2, "value": values}) + assert record_batch == expected + assert record_batch.schema.metadata is None + + def test_as_arrow_with_time_series(self): + map_value = Map( + [-1.0, -2.0], + [ + TimeSeriesFixedResolution("2025-08-05T17:00", "15m", [2.3, 3.2], ignore_year=True, repeat=True), + TimeSeriesVariableResolution( + ["2025-08-05T17:32", "2025-08-05T17:41"], [-2.3, -3.2], ignore_year=True, repeat=True + ), + ], + ) + record_batch = map_value.as_arrow() + indexes_1 = pyarrow.RunEndEncodedArray.from_arrays([2, 4], [-1.0, -2.0]) + indexes_2 = pyarrow.array( + [ + datetime(year=2025, month=8, day=5, hour=17), + datetime(year=2025, month=8, day=5, hour=17, minute=15), + datetime(year=2025, month=8, day=5, hour=17, minute=32), + datetime(year=2025, month=8, day=5, hour=17, minute=41), + ], + type=pyarrow.timestamp("s"), + ) + values = pyarrow.array([2.3, 3.2, -2.3, -3.2]) + expected = with_column_as_time_stamps( + pyarrow.record_batch({"col_1": indexes_1, "t": indexes_2, "value": values}), + "t", + ignore_year=True, + repeat=True, + ) + assert record_batch == expected + assert record_batch.schema.metadata is None + assert load_field_metadata(record_batch.field("t")) == time_series_metadata(ignore_year=True, repeat=True) + + def test_as_arrow_with_uneven_time_series(self): + forecast_value = Map( + [DateTime("2025-09-17T11:00")], + [ + TimeSeriesVariableResolution( + ["2025-09-17T11:00", "2025-09-17T12:00"], [-2.3, -3.2], ignore_year=False, repeat=False + ), + ], + ) + map_value = Map( + ["forecast", "realization"], + [ + forecast_value, + TimeSeriesVariableResolution( + ["2025-09-17T11:00", "2025-09-17T12:00"], [-2.2, -3.1], ignore_year=False, repeat=False + ), + ], + ) + record_batch = map_value.as_arrow() + indexes_1 = 
pyarrow.RunEndEncodedArray.from_arrays([2, 4], ["forecast", "realization"]) + indexes_2 = pyarrow.RunEndEncodedArray.from_arrays( + [2, 4], + [datetime(year=2025, month=9, day=17, hour=11), None], + type=pyarrow.run_end_encoded(pyarrow.int64(), pyarrow.timestamp("s")), + ) + indexes_3 = pyarrow.array( + [ + datetime(year=2025, month=9, day=17, hour=11), + datetime(year=2025, month=9, day=17, hour=12), + datetime(year=2025, month=9, day=17, hour=11), + datetime(year=2025, month=9, day=17, hour=12), + ], + type=pyarrow.timestamp("s"), + ) + values = pyarrow.array([-2.3, -3.2, -2.2, -3.1]) + expected = with_column_as_time_stamps( + pyarrow.record_batch({"col_1": indexes_1, "col_2": indexes_2, "t": indexes_3, "value": values}), + "t", + ignore_year=False, + repeat=False, + ) + assert record_batch == expected + assert record_batch.schema.metadata is None + assert load_field_metadata(record_batch.field("t")) == time_series_metadata(ignore_year=False, repeat=False) + + def test_as_arrow_with_nested_time_series(self): + map_value = Map( + ["realization", "forecast_1", "forecast_tail"], + [ + Map( + [DateTime("2000-01-01T00:00:00")], + [ + TimeSeriesFixedResolution( + "2000-01-01T00:00:00", "1h", [0.73, 0.66], ignore_year=False, repeat=False + ) + ], + ), + Map( + [DateTime("2000-01-01T00:00:00")], + [ + TimeSeriesFixedResolution( + "2000-01-01T00:00:00", "1h", [0.63, 0.61], ignore_year=False, repeat=False + ) + ], + ), + Map( + [DateTime("2000-01-01T00:00:00")], + [ + TimeSeriesFixedResolution( + "2000-01-01T00:00:00", "1h", [0.68, 0.64], ignore_year=False, repeat=False + ) + ], + ), + ], + ) + record_batch = map_value.as_arrow() + indexes_1 = pyarrow.RunEndEncodedArray.from_arrays([2, 4, 6], ["realization", "forecast_1", "forecast_tail"]) + indexes_2 = pyarrow.RunEndEncodedArray.from_arrays( + [6], + [ + datetime(year=2000, month=1, day=1, hour=00), + ], + type=pyarrow.run_end_encoded(pyarrow.int64(), pyarrow.timestamp("s")), + ) + indexes_3 = pyarrow.array( + [ + datetime(year=2000, month=1, day=1, hour=00), + datetime(year=2000, month=1, day=1, hour=1), + datetime(year=2000, month=1, day=1, hour=00), + datetime(year=2000, month=1, day=1, hour=1), + datetime(year=2000, month=1, day=1, hour=00), + datetime(year=2000, month=1, day=1, hour=1), + ], + type=pyarrow.timestamp("s"), + ) + values = pyarrow.array([0.73, 0.66, 0.63, 0.61, 0.68, 0.64]) + expected = with_column_as_time_stamps( + pyarrow.record_batch({"col_1": indexes_1, "col_2": indexes_2, "t": indexes_3, "value": values}), + "t", + ignore_year=False, + repeat=False, + ) + assert record_batch == expected + assert record_batch.schema.metadata is None + assert load_field_metadata(record_batch.field("t")) == time_series_metadata(ignore_year=False, repeat=False) + + def test_as_arrow_uneven_with_array_leafs(self): + row_lookup = Map(["columns", "values"], [Array(["A"]), Array(["!area-!other_area"])]) + cells = Map(["column", "row_lookup"], ["C", row_lookup]) + path_patterns = Array(["Lines", "Reference.xlsx"]) + map_value = Map( + ["cells", "path_patterns", "factor", "sheet", "type"], + [cells, path_patterns, 1.0, "2025", "single_value_lookup"], + ) + record_batch = map_value.as_arrow() + indexes_1 = pyarrow.compute.run_end_encode( + pyarrow.array(["cells", "cells", "cells", "path_patterns", "path_patterns", "factor", "sheet", "type"]), + run_end_type=pyarrow.int64(), + ) + indexes_2 = pyarrow.compute.run_end_encode( + pyarrow.array(["column", "row_lookup", "row_lookup", None, None, None, None, None]), + run_end_type=pyarrow.int64(), + ) + 
indexes_3 = pyarrow.compute.run_end_encode( + pyarrow.array([None, "columns", "values", None, None, None, None, None]), run_end_type=pyarrow.int64() + ) + indexes_4 = pyarrow.array( + [None, 0, 0, 0, 1, None, None, None], + ) + str_values = pyarrow.array( + ["C", "A", "!area-!other_area", "Lines", "Reference.xlsx", "2025", "single_value_lookup"] + ) + float_values = pyarrow.array([1.0]) + values = pyarrow.UnionArray.from_dense( + pyarrow.array([0, 0, 0, 0, 0, 1, 0, 0], type=pyarrow.int8()), + pyarrow.array([0, 1, 2, 3, 4, 0, 5, 6], type=pyarrow.int32()), + [str_values, float_values], + ) + expected = pyarrow.record_batch( + {"col_1": indexes_1, "col_2": indexes_2, "col_3": indexes_3, "i": indexes_4, "value": values} + ) + assert record_batch == expected + + def test_from_database(self): + original_map = Map(["A", "B"], [2.3, 3.2]) + blob, value_type = to_database(original_map) + deserialized = from_database(blob, value_type) + assert deserialized == Map(["A", "B"], [2.3, 3.2], index_name="col_1") + + def test_from_database_with_empty_map(self): + original_map = Map([], [], str) + value, value_type = to_database(original_map) + deserialized = from_database(value, value_type) + assert deserialized == Map([], [], str, index_name="col_1") + + def test_from_database_with_string_to_string_map_with_index_name(self): + original_map = Map(["key"], ["value"], index_name="Keys") + value, value_type = to_database(original_map) + deserialized = from_database(value, value_type) + assert deserialized == original_map + + def test_from_database_with_date_time_to_different_simple_types_map_with_index_name(self): + original_map = Map( + [DateTime("2024-02-09T10:00"), DateTime("2024-02-09T11:00")], + ["value", 2.3], + index_name="timestamps", + ) + value, value_type = to_database(original_map) + deserialized = from_database(value, value_type) + assert deserialized == original_map + + def test_from_database_with_nested_maps_of_different_index_types_raises(self): + string_map = Map([11.0], ["value"], index_name="nested index") + float_map = Map(["key"], [22.0], index_name="nested index") + original_map = Map(["strings", "floats"], [string_map, float_map], index_name="main index") + with pytest.raises(SpineDBAPIError, match="^different index types at the same depth are not supported$"): + to_database(original_map) + + def test_from_database_with_unevenly_nested_map_with_fixed_resolution_time_series(self): + float_map = Map(["key"], [22.0], index_name="nested index") + time_series = TimeSeriesVariableResolution( + ["2025-02-26T09:00:00", "2025-02-26T10:00:00"], [2.3, 23.0], ignore_year=False, repeat=False + ) + time_series_map = Map([DateTime("2024-02-26T16:45:00")], [time_series]) + nested_time_series_map = Map(["ts", "no ts"], [time_series_map, "empty"], index_name="nested index") + original_map = Map( + ["not nested", "time series", "floats"], + ["none", nested_time_series_map, float_map], + index_name="main index", + ) + value, value_type = to_database(original_map) + deserialized = from_database(value, value_type) + expected = deep_copy_value(original_map) + expected.get_value("time series").get_value("ts").index_name = "col_3" + assert deserialized == expected + + def test_from_database_with_unevenly_nested_map(self): + string_map = Map(["first"], ["value"], index_name="nested index") + float_map = Map(["key"], [22.0], index_name="nested index") + duration_map = Map([Duration("12h")], [Duration("9M")]) + another_string_map = Map([Duration("11h")], ["future"]) + nested_map = Map( + ["nested durations", 
"duration to string", "non nested"], + [duration_map, another_string_map, "empty"], + index_name="nested index", + ) + original_map = Map( + ["not nested", "strings", "durations", "floats"], + ["none", string_map, nested_map, float_map], + index_name="main index", + ) + value, value_type = to_database(original_map) + deserialized = from_database(value, value_type) + expected = deep_copy_value(original_map) + expected.get_value("durations").get_value("nested durations").index_name = "col_3" + expected.get_value("durations").get_value("duration to string").index_name = "col_3" + assert deserialized == expected + + def test_from_database_uneven_with_array_leafs(self): + row_lookup = Map(["columns", "values"], [Array(["A"]), Array(["!area-!other_area"])]) + cells = Map(["column", "row_lookup"], ["C", row_lookup]) + path_patterns = Array(["Lines", "Reference.xlsx"]) + original = Map( + ["cells", "path_patterns", "factor", "sheet", "type"], + [cells, path_patterns, 1.0, "2025", "single_value_lookup"], + ) + + blob, value_type = to_database(original) + deserialized = from_database(blob, value_type) + expected = deep_copy_value(original) + expected.index_name = "col_1" + expected.get_value("cells").index_name = "col_2" + expected.get_value("cells").get_value("row_lookup").index_name = "col_3" + assert deserialized == expected + + def test_from_arrow_with_empty_record_batch(self): + record_batch = pyarrow.record_batch( + {"col_1": pyarrow.array([], type=pyarrow.string()), "value": pyarrow.array([], type=pyarrow.float64())} + ) + map_value = Map.from_arrow(record_batch) + assert map_value == Map([], [], index_type=str, index_name="col_1") + + def test_from_arrow_with_string_indices(self): + record_batch = pyarrow.record_batch({"col_1": pyarrow.array(["A", "B"]), "value": pyarrow.array([2.3, 3.2])}) + map_value = Map.from_arrow(record_batch) + assert map_value == Map(["A", "B"], [2.3, 3.2], index_type=str, index_name="col_1") + + def test_from_arrow_with_duration_values(self): + record_batch = pyarrow.record_batch( + { + "col_1": pyarrow.array(["A", "B"]), + "value": pyarrow.array([relativedelta(months=3), relativedelta(months=7)]), + } + ) + map_value = Map.from_arrow(record_batch) + assert map_value == Map( + ["A", "B"], [Duration("3 months"), Duration("7 months")], index_type=str, index_name="col_1" + ) + + def test_from_arrow_with_datetime_indices(self): + record_batch = pyarrow.record_batch( + { + "col_1": pyarrow.array( + [ + datetime(year=2025, month=8, day=22, hour=16, minute=40), + datetime(year=2025, month=8, day=22, hour=18, minute=0), + ] + ), + "value": pyarrow.array([2.3, 3.2]), + } + ) + map_value = Map.from_arrow(record_batch) + assert map_value == Map( + [DateTime("2025-08-22T16:40"), DateTime("2025-08-22T18:00")], [2.3, 3.2], index_name="col_1" + ) + + def test_from_arrow_with_nested_structure(self): + indexes_1 = pyarrow.RunEndEncodedArray.from_arrays([1, 3, 4], ["a", "b", "c"]) + indexes_2 = pyarrow.array([None, duration_to_relativedelta("1h"), duration_to_relativedelta("2h"), None]) + str_values = pyarrow.array(["yes", "no"]) + bool_values = pyarrow.array([True, False]) + values = pyarrow.UnionArray.from_dense( + pyarrow.array([0, 1, 1, 0], type=pyarrow.int8()), + pyarrow.array([0, 0, 1, 1], type=pyarrow.int32()), + [str_values, bool_values], + ) + record_batch = pyarrow.record_batch({"col_1": indexes_1, "col_2": indexes_2, "value": values}) + map_value = Map.from_arrow(record_batch) + assert map_value == Map( + ["a", "b", "c"], + ["yes", Map([Duration("1h"), Duration("2h")], 
[True, False], index_name="col_2"), "no"], + index_name="col_1", + ) + + def test_from_arrow_with_unevenly_nested_structure(self): + indexes_1 = pyarrow.array(["A", "A"]) + indexes_2 = pyarrow.array([None, "a"]) + values = pyarrow.array(["non-nested", "nested"]) + record_batch = pyarrow.record_batch({"col_1": indexes_1, "col_2": indexes_2, "value": values}) + map_value = Map.from_arrow(record_batch) + assert map_value == Map( + ["A", "A"], + ["non-nested", Map(["a"], ["nested"], index_name="col_2")], + index_name="col_1", + ) + indexes_2 = pyarrow.array(["a", None]) + values = pyarrow.array(["nested", "non-nested"]) + record_batch = pyarrow.record_batch({"col_1": indexes_1, "col_2": indexes_2, "value": values}) + map_value = Map.from_arrow(record_batch) + assert map_value == Map( + ["A", "A"], + [Map(["a"], ["nested"], index_name="col_2"), "non-nested"], + index_name="col_1", + ) + indexes_1 = pyarrow.array(["A", "A", "A"]) + indexes_2 = pyarrow.array(["a", None, "b"]) + values = pyarrow.array(["nested1", "non-nested", "nested2"]) + record_batch = pyarrow.record_batch({"col_1": indexes_1, "col_2": indexes_2, "value": values}) + map_value = Map.from_arrow(record_batch) + assert map_value == Map( + ["A", "A"], + [Map(["a", "b"], ["nested1", "nested2"], index_name="col_2"), "non-nested"], + index_name="col_1", + ) + + def test_from_arrow_with_nested_time_series(self): + indexes_1 = pyarrow.array(["a", "a", "b", "b"]) + indexes_2 = pyarrow.array( + [ + datetime(year=2025, month=8, day=25, hour=8, minute=25), + datetime(year=2025, month=8, day=25, hour=8, minute=35), + datetime(year=2025, month=8, day=25, hour=8, minute=35), + datetime(year=2025, month=8, day=25, hour=8, minute=45), + ] + ) + values = pyarrow.array([2.3, 3.2, -2.3, -3.2]) + metadata = {"stamp": json.dumps(time_series_metadata(ignore_year=True, repeat=False))} + record_batch = with_column_as_time_stamps( + pyarrow.record_batch({"choice": indexes_1, "stamp": indexes_2, "value": values}, metadata=metadata), + "stamp", + ignore_year=True, + repeat=False, + ) + map_value = Map.from_arrow(record_batch) + assert map_value == Map( + ["a", "b"], + [ + TimeSeriesVariableResolution( + [ + "2025-08-25T08:25", + "2025-08-25T08:35", + ], + [2.3, 3.2], + ignore_year=True, + repeat=False, + index_name="stamp", + ), + TimeSeriesVariableResolution( + [ + "2025-08-25T08:35", + "2025-08-25T08:45", + ], + [-2.3, -3.2], + ignore_year=True, + repeat=False, + index_name="stamp", + ), + ], + index_name="choice", + ) + + def test_from_arrow_with_uneven_nested_time_series(self): + indexes_1 = pyarrow.RunEndEncodedArray.from_arrays([2, 4], ["forecast", "realization"]) + indexes_2 = pyarrow.RunEndEncodedArray.from_arrays( + [2, 4], + [datetime(year=2025, month=9, day=17, hour=11), None], + type=pyarrow.run_end_encoded(pyarrow.int64(), pyarrow.timestamp("s")), + ) + indexes_3 = pyarrow.array( + [ + datetime(year=2025, month=9, day=17, hour=11), + datetime(year=2025, month=9, day=17, hour=12), + datetime(year=2025, month=9, day=17, hour=11), + datetime(year=2025, month=9, day=17, hour=12), + ], + type=pyarrow.timestamp("s"), + ) + values = pyarrow.array([-2.3, -3.2, -2.2, -3.1]) + record_batch = with_column_as_time_stamps( + pyarrow.record_batch({"col_1": indexes_1, "col_2": indexes_2, "t": indexes_3, "value": values}), + "t", + ignore_year=False, + repeat=False, + ) + map_value = Map.from_arrow(record_batch) + expected = Map( + ["forecast", "realization"], + [ + Map( + [DateTime("2025-09-17T11:00")], + [ + TimeSeriesVariableResolution( + 
["2025-09-17T11:00", "2025-09-17T12:00"], [-2.3, -3.2], ignore_year=False, repeat=False + ) + ], + index_name="col_2", + ), + TimeSeriesVariableResolution( + ["2025-09-17T11:00", "2025-09-17T12:00"], [-2.2, -3.1], ignore_year=False, repeat=False + ), + ], + index_name="col_1", + ) + assert map_value == expected + + +class TestToDatabaseForRecordBatches: + def test_strings_as_run_end_encoded(self): + index_array = pyarrow.RunEndEncodedArray.from_arrays([3, 5], ["A", "B"]) + value_array = pyarrow.RunEndEncodedArray.from_arrays([2, 5], ["a", "b"]) + record_batch = pyarrow.RecordBatch.from_arrays([index_array, value_array], ["Indexes", "Values"]) + blob, value_type = to_database(record_batch) + deserialized = from_database(blob, value_type) + assert deserialized == record_batch + assert deserialized.schema.metadata == record_batch.schema.metadata + + def test_date_times_as_run_end_encoded_index(self): + index_array = pyarrow.RunEndEncodedArray.from_arrays( + [3, 5], + [ + datetime(year=2025, month=7, day=24, hour=11, minute=41), + datetime(year=2025, month=7, day=24, hour=18, minute=41), + ], + ) + value_array = pyarrow.RunEndEncodedArray.from_arrays([2, 5], ["a", "b"]) + record_batch = pyarrow.RecordBatch.from_arrays([index_array, value_array], ["Indexes", "Values"]) + blob, value_type = to_database(record_batch) + deserialized = from_database(blob, value_type) + assert deserialized == record_batch + assert deserialized.schema.metadata == record_batch.schema.metadata + + def test_duration_as_run_end_encoded(self): + index_array = pyarrow.RunEndEncodedArray.from_arrays([3, 5], [parse_duration("PT30M"), parse_duration("PT45M")]) + value_array = pyarrow.RunEndEncodedArray.from_arrays([2, 5], [parse_duration("P3Y"), parse_duration("P2Y")]) + record_batch = pyarrow.RecordBatch.from_arrays([index_array, value_array], ["Indexes", "Values"]) + blob, value_type = to_database(record_batch) + deserialized = from_database(blob, value_type) + assert deserialized == record_batch + assert deserialized.schema.metadata == record_batch.schema.metadata + + def test_dictionary(self): + index_array = pyarrow.DictionaryArray.from_arrays([0, 1, 0], ["A", "B"]) + value_array = pyarrow.DictionaryArray.from_arrays([1, 0, 2], [2.3, 3.2, -2.3]) + record_batch = pyarrow.RecordBatch.from_arrays([index_array, value_array], ["Indexes", "Values"]) + blob, value_type = to_database(record_batch) + deserialized = from_database(blob, value_type) + assert deserialized == record_batch + assert deserialized.schema.metadata == record_batch.schema.metadata + + def test_date_times_in_dictionary(self): + index_array = pyarrow.DictionaryArray.from_arrays( + [0, 1, 0], + [ + datetime(year=2025, month=7, day=24, hour=13, minute=20), + datetime(year=2025, month=7, day=24, hour=13, minute=20), + ], + ) + value_array = pyarrow.DictionaryArray.from_arrays([1, 0, 1], [True, False]) + record_batch = pyarrow.RecordBatch.from_arrays([index_array, value_array], ["Indexes", "Values"]) + blob, value_type = to_database(record_batch) + deserialized = from_database(blob, value_type) + assert deserialized == record_batch + assert deserialized.schema.metadata == record_batch.schema.metadata + + def test_durations_in_dictionary(self): + index_array = pyarrow.DictionaryArray.from_arrays([0, 1, 0], [parse_duration("P23D"), parse_duration("P5M")]) + value_array = pyarrow.DictionaryArray.from_arrays([2, 1, 0], ["a", "b", "c"]) + record_batch = pyarrow.RecordBatch.from_arrays([index_array, value_array], ["Indexes", "Values"]) + blob, value_type = 
to_database(record_batch) + deserialized = from_database(blob, value_type) + assert deserialized == record_batch + assert deserialized.schema.metadata == record_batch.schema.metadata + + def test_union(self): + index_array = pyarrow.array(["integer", "float_generic", "float_int_like", "string", "boolean", "duration"]) + int_array = pyarrow.array([23]) + float_array = pyarrow.array([2.3, 5.0]) + str_array = pyarrow.array(["A"]) + boolean_array = pyarrow.array([True]) + duration_array = pyarrow.array([parse_duration("PT5H")]) + value_type_array = pyarrow.array([0, 1, 1, 2, 3, 4], type=pyarrow.int8()) + value_index_array = pyarrow.array([0, 0, 1, 0, 0, 0], type=pyarrow.int32()) + value_array = pyarrow.UnionArray.from_dense( + value_type_array, value_index_array, [int_array, float_array, str_array, boolean_array, duration_array] + ) + record_batch = pyarrow.RecordBatch.from_arrays([index_array, value_array], ["Indexes", "Values"]) + blob, value_type = to_database(record_batch) + deserialized = from_database(blob, value_type) + assert deserialized == record_batch + assert deserialized.schema.metadata == record_batch.schema.metadata + + def test_union_as_index_raises(self): + int_array = pyarrow.array([23]) + value_type_array = pyarrow.array([0], type=pyarrow.int8()) + value_index_array = pyarrow.array([0], type=pyarrow.int32()) + value_array = pyarrow.UnionArray.from_dense(value_type_array, value_index_array, [int_array]) + record_batch = pyarrow.RecordBatch.from_arrays([value_array, value_array], ["Indexes", "Values"]) + with pytest.raises(SpineDBAPIError, match="union array cannot be index"): + to_database(record_batch) + + def test_float(self): + index_array = pyarrow.array([1.1, 2.2]) + value_array = pyarrow.array([2.3, 3.2]) + record_batch = pyarrow.RecordBatch.from_arrays([index_array, value_array], ["Indexes", "Values"]) + blob, value_type = to_database(record_batch) + deserialized = from_database(blob, value_type) + assert deserialized == record_batch + assert deserialized.schema.metadata == record_batch.schema.metadata + + def test_str(self): + index_array = pyarrow.array(["T01", "T02"]) + value_array = pyarrow.array(["high", "low"]) + record_batch = pyarrow.RecordBatch.from_arrays([index_array, value_array], ["Indexes", "Values"]) + blob, value_type = to_database(record_batch) + deserialized = from_database(blob, value_type) + assert deserialized == record_batch + assert deserialized.schema.metadata == record_batch.schema.metadata + + def test_int(self): + index_array = pyarrow.array([23, 55]) + value_array = pyarrow.array([-2, -4]) + record_batch = pyarrow.RecordBatch.from_arrays([index_array, value_array], ["Indexes", "Values"]) + blob, value_type = to_database(record_batch) + deserialized = from_database(blob, value_type) + assert deserialized == record_batch + assert deserialized.schema.metadata == record_batch.schema.metadata + + def test_date_time(self): + index_array = pyarrow.array([datetime(year=2025, month=7, day=21, hour=15, minute=30)]) + value_array = pyarrow.array([2.3]) + record_batch = pyarrow.RecordBatch.from_arrays([index_array, value_array], ["Indexes", "Values"]) + blob, value_type = to_database(record_batch) + deserialized = from_database(blob, value_type) + assert deserialized == record_batch + assert deserialized.schema.metadata == record_batch.schema.metadata + + def test_duration(self): + index_array = pyarrow.array([parse_duration("P3D")]) + value_array = pyarrow.array([parse_duration("PT5H")]) + record_batch = pyarrow.RecordBatch.from_arrays([index_array, 
value_array], ["Indexes", "Values"]) + blob, value_type = to_database(record_batch) + deserialized = from_database(blob, value_type) + assert deserialized == record_batch + assert deserialized.schema.metadata == record_batch.schema.metadata + + def test_time_pattern(self): + index_array = pyarrow.array(["M1-4,M9-12", "M5-8"]) + value_array = pyarrow.array([3.0, -2.0]) + record_batch = pyarrow.RecordBatch.from_arrays([index_array, value_array], ["Indexes", "Values"]) + record_batch = with_column_as_time_period(record_batch, "Indexes") + blob, value_type = to_database(record_batch) + deserialized = from_database(blob, value_type) + assert deserialized == record_batch + assert deserialized.schema.metadata == record_batch.schema.metadata + + def test_bool(self): + index_array = pyarrow.array(["T001", "T002"]) + value_array = pyarrow.array([False, True]) + record_batch = pyarrow.RecordBatch.from_arrays([index_array, value_array], ["Indexes", "Values"]) + blob, value_type = to_database(record_batch) + deserialized = from_database(blob, value_type) + assert deserialized == record_batch + assert deserialized.schema.metadata == record_batch.schema.metadata + + def test_bool_as_index_raises(self): + index_array = pyarrow.array([True, False]) + value_array = pyarrow.array([False, True]) + record_batch = pyarrow.RecordBatch.from_arrays([index_array, value_array], ["Indexes", "Values"]) + with pytest.raises(SpineDBAPIError, match="boolean array cannot be index"): + to_database(record_batch) + class TestPickling(unittest.TestCase): def test_array_is_picklable(self): @@ -1200,20 +1971,36 @@ def test_function_works(self): class TestTypeForValue(unittest.TestCase): def test_function_works(self): - self.assertEqual(type_for_value(None), (None, 0)) - self.assertEqual(type_for_value(2.3), ("float", 0)) - self.assertEqual(type_for_value("debug"), ("str", 0)) - self.assertEqual(type_for_value(False), ("bool", 0)) - self.assertEqual(type_for_value(DateTime("2024-07-25T11:00:00")), ("date_time", 0)) - self.assertEqual(type_for_value(Duration("23D")), ("duration", 0)) - self.assertEqual(type_for_value(Array(["a", "b"])), ("array", 1)) - self.assertEqual(type_for_value(TimePattern(["D1-7"], [8.0])), ("time_pattern", 1)) + self.assertEqual(type_and_rank_for_value(None), (None, 0)) + self.assertEqual(type_and_rank_for_value(2.3), ("float", 0)) + self.assertEqual(type_and_rank_for_value("debug"), ("str", 0)) + self.assertEqual(type_and_rank_for_value(False), ("bool", 0)) + self.assertEqual(type_and_rank_for_value(DateTime("2024-07-25T11:00:00")), ("date_time", 0)) + self.assertEqual(type_and_rank_for_value(Duration("23D")), ("duration", 0)) + self.assertEqual(type_and_rank_for_value(Array(["a", "b"])), ("array", 1)) + self.assertEqual(type_and_rank_for_value(TimePattern(["D1-7"], [8.0])), ("time_pattern", 1)) self.assertEqual( - type_for_value(TimeSeriesVariableResolution(["2024-07-24T11:00:00"], [2.3], False, False)), + type_and_rank_for_value(TimeSeriesVariableResolution(["2024-07-24T11:00:00"], [2.3], False, False)), ("time_series", 1), ) - self.assertEqual(type_for_value(Map(["a", "b"], [Map(["i"], [2.3]), 23.0])), ("map", 2)) - - -if __name__ == "__main__": - unittest.main() + self.assertEqual(type_and_rank_for_value(Map(["a", "b"], [Map(["i"], [2.3]), 23.0])), ("map", 2)) + + +class TestMonthDayNanoIntervalToIsoDuration: + def test_seconds(self): + durations = ["PT0S", "PT23S", "PT120S", "PT145S", "PT7200S", "PT7310S", "PT86400S", "PT86460S"] + intervals = pyarrow.array([parse_duration(d) for d in durations]) + 
converted = [_month_day_nano_interval_to_iso_duration(dt) for dt in intervals] + assert converted == ["P0D", "PT23S", "PT2M", "PT2M25S", "PT2H", "PT2H1M50S", "P1D", "P1DT1M"] + + def test_days(self): + durations = ["P0D", "P12D", "P1DT4H"] + intervals = pyarrow.array([parse_duration(d) for d in durations]) + converted = [_month_day_nano_interval_to_iso_duration(dt) for dt in intervals] + assert converted == ["P0D", "P12D", "P1DT4H"] + + def test_months(self): + durations = ["P0M", "P5M", "P12M", "P17M"] + intervals = pyarrow.array([parse_duration(d) for d in durations]) + converted = [_month_day_nano_interval_to_iso_duration(dt) for dt in intervals] + assert converted == ["P0D", "P5M", "P1Y", "P1Y5M"]
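
The `TestMonthDayNanoIntervalToIsoDuration` cases above fully pin down the expected normalization: nanoseconds are rolled up into seconds, minutes, hours and whole days, months are rolled up into years, and an all-zero interval renders as `"P0D"`. The sketch below reproduces exactly that behaviour so the expectations can be followed without the module at hand. The helper name `iso_duration_from_interval` is made up for illustration, it assumes `pyarrow.MonthDayNano` exposes `.months`, `.days` and `.nanoseconds`, and it is not the `_month_day_nano_interval_to_iso_duration` implementation shipped in `spinedb_api.parameter_value`.

```python
# Illustrative sketch only: mirrors the expectations of TestMonthDayNanoIntervalToIsoDuration,
# not the spinedb_api implementation.
import pyarrow
from dateutil.relativedelta import relativedelta


def iso_duration_from_interval(scalar: "pyarrow.MonthDayNanoIntervalScalar") -> str:
    """Render a month_day_nano interval scalar as a normalized ISO 8601 duration string."""
    interval = scalar.value  # assumed MonthDayNano with months/days/nanoseconds fields
    years, months = divmod(interval.months, 12)  # e.g. 17 months -> 1 year 5 months
    seconds = interval.nanoseconds // 1_000_000_000
    extra_days, seconds = divmod(seconds, 86_400)  # e.g. 86460 s -> 1 day 60 s
    days = interval.days + extra_days
    hours, seconds = divmod(seconds, 3_600)
    minutes, seconds = divmod(seconds, 60)
    date_part = "".join(f"{n}{unit}" for n, unit in ((years, "Y"), (months, "M"), (days, "D")) if n)
    time_part = "".join(f"{n}{unit}" for n, unit in ((hours, "H"), (minutes, "M"), (seconds, "S")) if n)
    if not date_part and not time_part:
        return "P0D"  # zero durations collapse to "P0D", as in test_seconds and test_months
    return f"P{date_part}" + (f"T{time_part}" if time_part else "")


# Usage: pyarrow infers the month_day_nano_interval type from dateutil relativedeltas.
intervals = pyarrow.array([relativedelta(months=17), relativedelta(seconds=7310), relativedelta()])
assert [iso_duration_from_interval(s) for s in intervals] == ["P1Y5M", "PT2H1M50S", "P0D"]
```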