diff --git a/spinedb_api/import_functions.py b/spinedb_api/import_functions.py index fc76e7f6..e268fef2 100644 --- a/spinedb_api/import_functions.py +++ b/spinedb_api/import_functions.py @@ -16,18 +16,23 @@ but the syntax is a little more compact. """ from collections import defaultdict -from collections.abc import Callable, Iterable, Iterator +from collections.abc import Callable, Iterable, Iterator, Sequence from contextlib import suppress -from typing import Any, Optional +from typing import Any, Optional, TypeAlias from . import DatabaseMapping, SpineDBAPIError from .helpers import _parse_metadata -from .parameter_value import Value, fancy_type_to_type_and_rank, get_conflict_fixer, to_database +from .parameter_value import ConflictResolution, Value, fancy_type_to_type_and_rank, get_conflict_fixer, to_database + +UnparseCallable: TypeAlias = Callable[[Value], tuple[bytes, Optional[str]]] +ParameterValue: TypeAlias = ( + tuple[str, str | tuple[str, ...], str, Any] | tuple[str, str | tuple[str, ...], str, Any, str] +) def import_data( db_map: DatabaseMapping, - unparse_value: Callable[[Value], tuple[bytes, Optional[str]]] = to_database, - on_conflict: str = "merge", + unparse_value: UnparseCallable = to_database, + on_conflict: ConflictResolution = "merge", **kwargs, ) -> tuple[int, list[str]]: """Imports data into a Spine database using a standard format. @@ -99,9 +104,9 @@ def import_data( def get_data_for_import( - db_map, - all_errors, - unparse_value=to_database, + db_map: DatabaseMapping, + all_errors: list[str], + unparse_value: UnparseCallable = to_database, fix_value_conflict=get_conflict_fixer("merge"), entity_classes=(), entities=(), @@ -369,18 +374,23 @@ def import_parameter_types(db_map, data, unparse_value=to_database): return import_data(db_map, parameter_types=data, unparse_value=unparse_value) -def import_parameter_values(db_map, data, unparse_value=to_database, on_conflict="merge"): +def import_parameter_values( + db_map: DatabaseMapping, + data: Iterable[ParameterValue], + unparse_value=to_database, + on_conflict: ConflictResolution = "merge", +) -> tuple[int, list[str]]: """Imports parameter values into a Spine database using a standard format. Args: - db_map (DatabaseMapping): database mapping - data (Iterable of Sequence): + db_map: database mapping + data: tuples of (class name [str], entity name [str] or byname [tuple[str]], parameter definition name [str], value, [alternative_name [str]]) - unparse_value (Callable): function to parse parameter values - on_conflict (str): Conflict resolution strategy; options: "keep", "replace", "merge" + unparse_value: function to parse parameter values + on_conflict: Conflict resolution strategy; options: "keep", "replace", "merge" Returns: - tuple: tuple of (number of items imported, list of errors) + tuple of (number of items imported, list of errors) """ return import_data(db_map, parameter_values=data, unparse_value=unparse_value, on_conflict=on_conflict) diff --git a/spinedb_api/parameter_value.py b/spinedb_api/parameter_value.py index 3a0f77c8..1bfccaa6 100644 --- a/spinedb_api/parameter_value.py +++ b/spinedb_api/parameter_value.py @@ -88,9 +88,8 @@ from itertools import takewhile import json from json.decoder import JSONDecodeError -import math import re -from typing import Any, Optional, SupportsFloat, Type, Union +from typing import Any, Literal, Optional, SupportsFloat, Type, TypeAlias, Union import dateutil.parser from dateutil.relativedelta import relativedelta import numpy as np @@ -111,6 +110,12 @@ STRING_VALUE_TYPE = "str" +ConflictResolution: TypeAlias = Literal["keep", "replace", "merge"] +ConflictResolutionCallable: TypeAlias = Callable[ + [tuple[bytes, Optional[str]], tuple[bytes, Optional[str]]], tuple[bytes, Optional[str]] +] + + def from_database(value: bytes, type_: Optional[str]) -> Optional[Value]: """ Converts a parameter value from the DB into a Python object. @@ -368,12 +373,16 @@ def merge_parsed(parsed_value: Optional[Value], parsed_other: Optional[Value]) - return parsed_value.merge(parsed_other) -_MERGE_FUNCTIONS = {"keep": lambda new, old: old, "replace": lambda new, old: new, "merge": merge} +_MERGE_FUNCTIONS: dict[ConflictResolution:ConflictResolutionCallable] = { + "keep": lambda new, old: old, + "replace": lambda new, old: new, + "merge": merge, +} def get_conflict_fixer( - on_conflict: str, -) -> Callable[[tuple[bytes, Optional[str]], tuple[bytes, Optional[str]]], tuple[bytes, Optional[str]]]: + on_conflict: ConflictResolution, +) -> ConflictResolutionCallable: """ :meta private: Returns parameter value conflict resolution function. diff --git a/tests/test_db_server.py b/tests/test_db_server.py index 29c80837..c88cc6db 100644 --- a/tests/test_db_server.py +++ b/tests/test_db_server.py @@ -14,12 +14,19 @@ from tempfile import TemporaryDirectory import threading import unittest +from spinedb_api import Array, DateTime, Duration, Map, TimePattern, TimeSeriesVariableResolution, to_database from spinedb_api.db_mapping import DatabaseMapping from spinedb_api.spine_db_client import SpineDBClient from spinedb_api.spine_db_server import DBHandler, closing_spine_db_server, db_server_manager class TestDBServer(unittest.TestCase): + def _assert_import(self, result): + self.assertIn("result", result) + count, errors = result["result"] + self.assertEqual(errors, []) + return count + def test_use_id_from_server_response(self): with TemporaryDirectory() as temp_dir: db_url = "sqlite:///" + os.path.join(temp_dir, "database.sqlite") @@ -44,7 +51,7 @@ def test_ordering(self): def _import_entity_class(server_url, class_name): client = SpineDBClient.from_server_url(server_url) client.db_checkin() - _answer = client.import_data({"entity_classes": [(class_name, ())]}, f"Import {class_name}") + self._assert_import(client.import_data({"entity_classes": [(class_name, ())]}, f"Import {class_name}")) client.db_checkout() with TemporaryDirectory() as temp_dir: @@ -129,3 +136,109 @@ def test_query_with_data(self): } }, ) + + def test_export_parameter_values(self): + with closing_spine_db_server("sqlite://") as server_url: + client = SpineDBClient.from_server_url(server_url) + self._assert_import( + client.import_data( + {"entity_classes": [("Object",)], "parameter_definitions": [("Object", "X")]}, + "Import basic structure.", + ) + ) + self._assert_import( + client.import_data( + { + "entities": [ + ("Object", "float"), + ("Object", "string"), + ("Object", "boolean"), + ("Object", "none"), + ("Object", "date time"), + ("Object", "duration"), + ("Object", "array"), + ("Object", "time pattern"), + ("Object", "time series"), + ("Object", "map"), + ] + }, + "Import entities.", + ) + ) + self._assert_import( + client.import_data( + { + "parameter_values": [ + ("Object", ("float",), "X", to_database(2.3)), + ("Object", ("string",), "X", to_database("oh my")), + ("Object", ("boolean",), "X", to_database(False)), + ("Object", ("none",), "X", to_database(None)), + ("Object", ("date time",), "X", to_database(DateTime("2025-09-02T13:45"))), + ("Object", ("duration",), "X", to_database(Duration("33m"))), + ("Object", ("array",), "X", to_database(Array([2.3]))), + ("Object", ("time pattern",), "X", to_database(TimePattern(["M1-12"], [2.3]))), + ( + "Object", + ("time series",), + "X", + to_database( + TimeSeriesVariableResolution( + ["2025-09-02T13:50"], [2.3], ignore_year=False, repeat=True + ) + ), + ), + ("Object", ("map",), "X", to_database(Map([DateTime("2025-09-02T13:50")], [2.3]))), + ] + }, + "Import values.", + ) + ) + result = client.export_data() + self.assertEqual(len(result), 1) + result_data = result["result"] + self.assertEqual(len(result_data), 5) + self.assertEqual(result_data["alternatives"], [["Base", "Base alternative"]]) + self.assertEqual(result_data["entity_classes"], [["Object", [], None, None, True]]) + self.assertEqual(result_data["parameter_definitions"], [["Object", "X", [None, None], None, None]]) + self.assertCountEqual( + result_data["entities"], + [ + ["Object", "float", None], + ["Object", "string", None], + ["Object", "boolean", None], + ["Object", "none", None], + ["Object", "date time", None], + ["Object", "duration", None], + ["Object", "array", None], + ["Object", "time pattern", None], + ["Object", "time series", None], + ["Object", "map", None], + ], + ) + self.assertCountEqual( + result_data["parameter_values"], + [ + ["Object", "float", "X", list(to_database(2.3)), "Base"], + ["Object", "string", "X", list(to_database("oh my")), "Base"], + ["Object", "boolean", "X", list(to_database(False)), "Base"], + ["Object", "none", "X", list(to_database(None)), "Base"], + ["Object", "date time", "X", list(to_database(DateTime("2025-09-02T13:45"))), "Base"], + ["Object", "duration", "X", list(to_database(Duration("33m"))), "Base"], + ["Object", "array", "X", list(to_database(Array([2.3]))), "Base"], + ["Object", "time pattern", "X", list(to_database(TimePattern(["M1-12"], [2.3]))), "Base"], + [ + "Object", + "time series", + "X", + list( + to_database( + TimeSeriesVariableResolution( + ["2025-09-02T13:50"], [2.3], ignore_year=False, repeat=True + ) + ) + ), + "Base", + ], + ["Object", "map", "X", list(to_database(Map([DateTime("2025-09-02T13:50")], [2.3]))), "Base"], + ], + ) diff --git a/tests/test_import_functions.py b/tests/test_import_functions.py index 0f900be9..6864802e 100644 --- a/tests/test_import_functions.py +++ b/tests/test_import_functions.py @@ -43,9 +43,13 @@ import_scenarios, ) from spinedb_api.parameter_value import ( + Array, + DateTime, + Duration, Map, TimePattern, TimeSeriesFixedResolution, + TimeSeriesVariableResolution, dump_db_value, from_database, to_database, @@ -2134,5 +2138,40 @@ def test_import_single_entity_class_display_mode(self): ) -if __name__ == "__main__": - unittest.main() +def _identity(x): + return x + + +class TestImportWithDatabaseValue: + def test_all_value_types(self): + values = [ + 2.3, + "a string", + True, + DateTime("2025-09-02T11:15"), + Duration("2 years"), + Array([DateTime("2025-09-02T12:00")]), + TimePattern(["D1-7"], [2.3]), + TimeSeriesFixedResolution("2025-09-02T12:00", "7D", [2.3], ignore_year=True, repeat=False), + TimeSeriesVariableResolution(["2025-09-02T12:00"], [2.3], ignore_year=False, repeat=True), + Map([Duration("5 hours")], [2.3]), + ] + with DatabaseMapping("sqlite://", create=True) as db_map: + db_map.add_entity_class(name="Object") + db_map.add_parameter_definition(entity_class_name="Object", name="X") + db_map.add_entity(entity_class_name="Object", name="widget") + for value in values: + assert_imports( + import_parameter_values( + db_map, [("Object", "widget", "X", to_database(value))], unparse_value=_identity + ) + ) + assert ( + db_map.parameter_value( + entity_class_name="Object", + entity_byname=("widget",), + parameter_definition_name="X", + alternative_name="Base", + )["parsed_value"] + == value + )