Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 24 additions & 14 deletions spinedb_api/import_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,18 +16,23 @@
but the syntax is a little more compact.
"""
from collections import defaultdict
from collections.abc import Callable, Iterable, Iterator
from collections.abc import Callable, Iterable, Iterator, Sequence
from contextlib import suppress
from typing import Any, Optional
from typing import Any, Optional, TypeAlias
from . import DatabaseMapping, SpineDBAPIError
from .helpers import _parse_metadata
from .parameter_value import Value, fancy_type_to_type_and_rank, get_conflict_fixer, to_database
from .parameter_value import ConflictResolution, Value, fancy_type_to_type_and_rank, get_conflict_fixer, to_database

UnparseCallable: TypeAlias = Callable[[Value], tuple[bytes, Optional[str]]]
ParameterValue: TypeAlias = (
tuple[str, str | tuple[str, ...], str, Any] | tuple[str, str | tuple[str, ...], str, Any, str]
)


def import_data(
db_map: DatabaseMapping,
unparse_value: Callable[[Value], tuple[bytes, Optional[str]]] = to_database,
on_conflict: str = "merge",
unparse_value: UnparseCallable = to_database,
on_conflict: ConflictResolution = "merge",
**kwargs,
) -> tuple[int, list[str]]:
"""Imports data into a Spine database using a standard format.
Expand Down Expand Up @@ -99,9 +104,9 @@ def import_data(


def get_data_for_import(
db_map,
all_errors,
unparse_value=to_database,
db_map: DatabaseMapping,
all_errors: list[str],
unparse_value: UnparseCallable = to_database,
fix_value_conflict=get_conflict_fixer("merge"),
entity_classes=(),
entities=(),
Expand Down Expand Up @@ -369,18 +374,23 @@ def import_parameter_types(db_map, data, unparse_value=to_database):
return import_data(db_map, parameter_types=data, unparse_value=unparse_value)


def import_parameter_values(db_map, data, unparse_value=to_database, on_conflict="merge"):
def import_parameter_values(
db_map: DatabaseMapping,
data: Iterable[ParameterValue],
unparse_value=to_database,
on_conflict: ConflictResolution = "merge",
) -> tuple[int, list[str]]:
"""Imports parameter values into a Spine database using a standard format.

Args:
db_map (DatabaseMapping): database mapping
data (Iterable of Sequence):
db_map: database mapping
data:
tuples of (class name [str], entity name [str] or byname [tuple[str]], parameter definition name [str], value, [alternative_name [str]])
unparse_value (Callable): function to parse parameter values
on_conflict (str): Conflict resolution strategy; options: "keep", "replace", "merge"
unparse_value: function to parse parameter values
on_conflict: Conflict resolution strategy; options: "keep", "replace", "merge"

Returns:
tuple: tuple of (number of items imported, list of errors)
tuple of (number of items imported, list of errors)
"""
return import_data(db_map, parameter_values=data, unparse_value=unparse_value, on_conflict=on_conflict)

Expand Down
19 changes: 14 additions & 5 deletions spinedb_api/parameter_value.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,9 +88,8 @@
from itertools import takewhile
import json
from json.decoder import JSONDecodeError
import math
import re
from typing import Any, Optional, SupportsFloat, Type, Union
from typing import Any, Literal, Optional, SupportsFloat, Type, TypeAlias, Union
import dateutil.parser
from dateutil.relativedelta import relativedelta
import numpy as np
Expand All @@ -111,6 +110,12 @@
STRING_VALUE_TYPE = "str"


ConflictResolution: TypeAlias = Literal["keep", "replace", "merge"]
ConflictResolutionCallable: TypeAlias = Callable[
[tuple[bytes, Optional[str]], tuple[bytes, Optional[str]]], tuple[bytes, Optional[str]]
]


def from_database(value: bytes, type_: Optional[str]) -> Optional[Value]:
"""
Converts a parameter value from the DB into a Python object.
Expand Down Expand Up @@ -368,12 +373,16 @@ def merge_parsed(parsed_value: Optional[Value], parsed_other: Optional[Value]) -
return parsed_value.merge(parsed_other)


_MERGE_FUNCTIONS = {"keep": lambda new, old: old, "replace": lambda new, old: new, "merge": merge}
_MERGE_FUNCTIONS: dict[ConflictResolution:ConflictResolutionCallable] = {
"keep": lambda new, old: old,
"replace": lambda new, old: new,
"merge": merge,
}


def get_conflict_fixer(
on_conflict: str,
) -> Callable[[tuple[bytes, Optional[str]], tuple[bytes, Optional[str]]], tuple[bytes, Optional[str]]]:
on_conflict: ConflictResolution,
) -> ConflictResolutionCallable:
"""
:meta private:
Returns parameter value conflict resolution function.
Expand Down
115 changes: 114 additions & 1 deletion tests/test_db_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,19 @@
from tempfile import TemporaryDirectory
import threading
import unittest
from spinedb_api import Array, DateTime, Duration, Map, TimePattern, TimeSeriesVariableResolution, to_database
from spinedb_api.db_mapping import DatabaseMapping
from spinedb_api.spine_db_client import SpineDBClient
from spinedb_api.spine_db_server import DBHandler, closing_spine_db_server, db_server_manager


class TestDBServer(unittest.TestCase):
def _assert_import(self, result):
self.assertIn("result", result)
count, errors = result["result"]
self.assertEqual(errors, [])
return count

def test_use_id_from_server_response(self):
with TemporaryDirectory() as temp_dir:
db_url = "sqlite:///" + os.path.join(temp_dir, "database.sqlite")
Expand All @@ -44,7 +51,7 @@ def test_ordering(self):
def _import_entity_class(server_url, class_name):
client = SpineDBClient.from_server_url(server_url)
client.db_checkin()
_answer = client.import_data({"entity_classes": [(class_name, ())]}, f"Import {class_name}")
self._assert_import(client.import_data({"entity_classes": [(class_name, ())]}, f"Import {class_name}"))
client.db_checkout()

with TemporaryDirectory() as temp_dir:
Expand Down Expand Up @@ -129,3 +136,109 @@ def test_query_with_data(self):
}
},
)

def test_export_parameter_values(self):
with closing_spine_db_server("sqlite://") as server_url:
client = SpineDBClient.from_server_url(server_url)
self._assert_import(
client.import_data(
{"entity_classes": [("Object",)], "parameter_definitions": [("Object", "X")]},
"Import basic structure.",
)
)
self._assert_import(
client.import_data(
{
"entities": [
("Object", "float"),
("Object", "string"),
("Object", "boolean"),
("Object", "none"),
("Object", "date time"),
("Object", "duration"),
("Object", "array"),
("Object", "time pattern"),
("Object", "time series"),
("Object", "map"),
]
},
"Import entities.",
)
)
self._assert_import(
client.import_data(
{
"parameter_values": [
("Object", ("float",), "X", to_database(2.3)),
("Object", ("string",), "X", to_database("oh my")),
("Object", ("boolean",), "X", to_database(False)),
("Object", ("none",), "X", to_database(None)),
("Object", ("date time",), "X", to_database(DateTime("2025-09-02T13:45"))),
("Object", ("duration",), "X", to_database(Duration("33m"))),
("Object", ("array",), "X", to_database(Array([2.3]))),
("Object", ("time pattern",), "X", to_database(TimePattern(["M1-12"], [2.3]))),
(
"Object",
("time series",),
"X",
to_database(
TimeSeriesVariableResolution(
["2025-09-02T13:50"], [2.3], ignore_year=False, repeat=True
)
),
),
("Object", ("map",), "X", to_database(Map([DateTime("2025-09-02T13:50")], [2.3]))),
]
},
"Import values.",
)
)
result = client.export_data()
self.assertEqual(len(result), 1)
result_data = result["result"]
self.assertEqual(len(result_data), 5)
self.assertEqual(result_data["alternatives"], [["Base", "Base alternative"]])
self.assertEqual(result_data["entity_classes"], [["Object", [], None, None, True]])
self.assertEqual(result_data["parameter_definitions"], [["Object", "X", [None, None], None, None]])
self.assertCountEqual(
result_data["entities"],
[
["Object", "float", None],
["Object", "string", None],
["Object", "boolean", None],
["Object", "none", None],
["Object", "date time", None],
["Object", "duration", None],
["Object", "array", None],
["Object", "time pattern", None],
["Object", "time series", None],
["Object", "map", None],
],
)
self.assertCountEqual(
result_data["parameter_values"],
[
["Object", "float", "X", list(to_database(2.3)), "Base"],
["Object", "string", "X", list(to_database("oh my")), "Base"],
["Object", "boolean", "X", list(to_database(False)), "Base"],
["Object", "none", "X", list(to_database(None)), "Base"],
["Object", "date time", "X", list(to_database(DateTime("2025-09-02T13:45"))), "Base"],
["Object", "duration", "X", list(to_database(Duration("33m"))), "Base"],
["Object", "array", "X", list(to_database(Array([2.3]))), "Base"],
["Object", "time pattern", "X", list(to_database(TimePattern(["M1-12"], [2.3]))), "Base"],
[
"Object",
"time series",
"X",
list(
to_database(
TimeSeriesVariableResolution(
["2025-09-02T13:50"], [2.3], ignore_year=False, repeat=True
)
)
),
"Base",
],
["Object", "map", "X", list(to_database(Map([DateTime("2025-09-02T13:50")], [2.3]))), "Base"],
],
)
43 changes: 41 additions & 2 deletions tests/test_import_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,9 +43,13 @@
import_scenarios,
)
from spinedb_api.parameter_value import (
Array,
DateTime,
Duration,
Map,
TimePattern,
TimeSeriesFixedResolution,
TimeSeriesVariableResolution,
dump_db_value,
from_database,
to_database,
Expand Down Expand Up @@ -2134,5 +2138,40 @@ def test_import_single_entity_class_display_mode(self):
)


if __name__ == "__main__":
unittest.main()
def _identity(x):
return x


class TestImportWithDatabaseValue:
def test_all_value_types(self):
values = [
2.3,
"a string",
True,
DateTime("2025-09-02T11:15"),
Duration("2 years"),
Array([DateTime("2025-09-02T12:00")]),
TimePattern(["D1-7"], [2.3]),
TimeSeriesFixedResolution("2025-09-02T12:00", "7D", [2.3], ignore_year=True, repeat=False),
TimeSeriesVariableResolution(["2025-09-02T12:00"], [2.3], ignore_year=False, repeat=True),
Map([Duration("5 hours")], [2.3]),
]
with DatabaseMapping("sqlite://", create=True) as db_map:
db_map.add_entity_class(name="Object")
db_map.add_parameter_definition(entity_class_name="Object", name="X")
db_map.add_entity(entity_class_name="Object", name="widget")
for value in values:
assert_imports(
import_parameter_values(
db_map, [("Object", "widget", "X", to_database(value))], unparse_value=_identity
)
)
assert (
db_map.parameter_value(
entity_class_name="Object",
entity_byname=("widget",),
parameter_definition_name="X",
alternative_name="Base",
)["parsed_value"]
== value
)
Loading