diff --git a/src/ldlite/_database/__init__.py b/src/ldlite/_database/__init__.py
index ffd72f6..9decad5 100644
--- a/src/ldlite/_database/__init__.py
+++ b/src/ldlite/_database/__init__.py
@@ -98,10 +98,11 @@ def record_history(self, history: LoadHistory) -> None: ...
 class TypedDatabase(Database, Generic[DB]):
     def __init__(self, conn_factory: Callable[[], DB]):
         self._conn_factory = conn_factory
-        with closing(self._conn_factory()) as conn, conn.cursor() as cur:
+        with closing(self._conn_factory()) as conn:
             try:
-                cur.execute('CREATE SCHEMA IF NOT EXISTS "ldlite_system";')
-                cur.execute("""
+                with conn.cursor() as cur:
+                    cur.execute('CREATE SCHEMA IF NOT EXISTS "ldlite_system";')
+                    cur.execute("""
 CREATE TABLE IF NOT EXISTS "ldlite_system"."load_history" (
     "table_name" TEXT UNIQUE
     ,"path" TEXT
@@ -112,6 +113,8 @@ def __init__(self, conn_factory: Callable[[], DB]):
     ,"index_complete_utc" TIMESTAMP
     ,"row_count" INTEGER
 );""")
+
+                self._setup_jfuncs(conn)
             except psycopg.errors.UniqueViolation:
                 # postgres throws this when multiple threads try to create
                 # the same resource even if CREATE IF NOT EXISTS was used
@@ -122,6 +125,16 @@ def __init__(self, conn_factory: Callable[[], DB]):
             else:
                 conn.commit()
 
+    @staticmethod
+    @abstractmethod
+    def _setup_jfuncs(conn: DB) -> None:
+        """Create the backend's JSON helper functions on ``conn``.
+
+        ``__init__`` calls this after the ``load_history`` DDL and before it
+        commits.
+        """
+        ...
+
     @property
     @abstractmethod
     def _default_schema(self) -> str: ...
diff --git a/src/ldlite/_database/duckdb.py b/src/ldlite/_database/duckdb.py
index 0676875..c9b994a 100644
--- a/src/ldlite/_database/duckdb.py
+++ b/src/ldlite/_database/duckdb.py
@@ -8,6 +8,91 @@
 
 
 class DuckDbDatabase(TypedDatabase[duckdb.DuckDBPyConnection]):
+    @staticmethod
+    def _setup_jfuncs(conn: duckdb.DuckDBPyConnection) -> None:
+        """Create the ``ldlite_system`` JSON helper functions on ``conn``.
+
+        On DuckDB 1.3 or newer the single-arrow lambda syntax is enabled
+        first; ``jextract`` passes ``x -> ...`` lambdas to ``list_filter``.
+        """
+        with conn.cursor() as cur:
+            cur.execute("SELECT string_split(ltrim(version(),'v'), '.') AS has_lambda;")
+            if ver := cur.fetchone():
+                # Only major and minor matter; ignore any trailing components.
+                ma, mi, *_ = ver[0]
+                # Compare as a tuple so 1.3+ and every later major qualify;
+                # testing each component on its own skips 1.x and 2.0-2.2.
+                if (int(ma), int(mi)) >= (1, 3):
+                    cur.execute("SET lambda_syntax = 'ENABLE_SINGLE_ARROW';")
+
+        with conn.cursor() as cur:
+            cur.execute(
+                r"""
+CREATE OR REPLACE FUNCTION ldlite_system.jtype_of(j) AS
+    CASE coalesce(main.json_type(j), 'NULL')
+        WHEN 'VARCHAR' THEN 'string'
+        WHEN 'BIGINT' THEN 'number'
+        WHEN 'DOUBLE' THEN 'number'
+        WHEN 'UBIGINT' THEN 'number'
+        WHEN 'OBJECT' THEN 'object'
+        WHEN 'BOOLEAN' THEN 'boolean'
+        WHEN 'ARRAY' THEN 'array'
+        WHEN 'NULL' THEN 'null'
+        ELSE main.json_type(j)
+    END
+;
+
+CREATE OR REPLACE FUNCTION ldlite_system.jextract(j, p) AS
+    CASE ldlite_system.jtype_of(main.json_extract(j, p))
+        WHEN 'string' THEN
+            CASE
+                WHEN lower(main.json_extract_string(j, p)) = 'null' THEN 'null'::JSON
+                WHEN length(main.json_extract_string(j, p)) = 0 THEN 'null'::JSON
+                ELSE main.json_extract(j, p)
+            END
+        WHEN 'object' THEN
+            CASE
+                WHEN main.json_extract_string(j, p) = '{}' THEN 'null'::JSON
+                ELSE main.json_extract(j, p)
+            END
+        WHEN 'array' THEN
+            CASE
+                WHEN length(list_filter((main.json_extract(j, p))::JSON[], x -> x != 'null'::JSON)) = 0 THEN 'null'::JSON
+                ELSE list_filter((main.json_extract(j, p))::JSON[], x -> x != 'null'::JSON)
+            END
+        ELSE coalesce(main.json_extract(j, p), 'null'::JSON)
+    END
+;
+
+CREATE OR REPLACE FUNCTION ldlite_system.jextract_string(j, p) AS
+    main.json_extract_string(ldlite_system.jextract(j, p), '$')
+;
+
+CREATE OR REPLACE FUNCTION ldlite_system.jobject_keys(j) AS
+    unnest(main.json_keys(j))
+;
+
+CREATE OR REPLACE FUNCTION ldlite_system.jis_uuid(j) AS
+    CASE ldlite_system.jtype_of(j)
+        WHEN 'string' THEN regexp_full_match(main.json_extract_string(j, '$'), '^[a-fA-F0-9]{8}-[a-fA-F0-9]{4}-[1-5][a-fA-F0-9]{3}-[89abAB][a-fA-F0-9]{3}-[a-fA-F0-9]{12}$')
+        ELSE FALSE
+    END
+;
+
+CREATE OR REPLACE FUNCTION ldlite_system.jis_datetime(j) AS
+    CASE ldlite_system.jtype_of(j)
+        WHEN 'string' THEN regexp_full_match(main.json_extract_string(j, '$'), '^\d{4}-[01]\d-[0123]\dT[012]\d:[012345]\d:[012345]\d\.\d{3}(\+\d{2}:\d{2})?$')
+        ELSE FALSE
+    END
+;
+
+CREATE OR REPLACE FUNCTION ldlite_system.jis_float(j) AS
+    coalesce(main.json_type(j), 'NULL')='DOUBLE'
+;
+
+""",  # noqa: E501
+            )
+
     @property
     def _default_schema(self) -> str:
         return "main"
diff --git a/src/ldlite/_database/postgres.py b/src/ldlite/_database/postgres.py
index 479018d..7a78477 100644
--- a/src/ldlite/_database/postgres.py
+++ b/src/ldlite/_database/postgres.py
@@ -13,6 +13,87 @@ def __init__(self, dsn: str):
         # same sql between duckdb and postgres
         super().__init__(lambda: psycopg.connect(dsn, cursor_factory=psycopg.RawCursor))
 
+    @staticmethod
+    def _setup_jfuncs(conn: psycopg.Connection) -> None:
+        """Create the ``ldlite_system`` JSON helper functions on ``conn``.
+
+        All seven PL/pgSQL functions are sent in a single ``execute`` call.
+        """
+        with conn.cursor() as cur:
+            cur.execute(
+                r"""
+CREATE OR REPLACE FUNCTION ldlite_system.jtype_of(j JSONB) RETURNS TEXT AS $$
+BEGIN
+    RETURN jsonb_typeof(j);
+END
+$$ LANGUAGE plpgsql;
+
+CREATE OR REPLACE FUNCTION ldlite_system.jextract(j JSONB, p TEXT) RETURNS JSONB AS $$
+BEGIN
+    RETURN CASE
+        WHEN ldlite_system.jtype_of(j->p) = 'string' THEN
+            CASE
+                WHEN lower(j->>p) = 'null' THEN 'null'::JSONB
+                WHEN length(j->>p) = 0 THEN 'null'::JSONB
+                ELSE j->p
+            END
+        WHEN ldlite_system.jtype_of(j->p) = 'array' THEN
+            CASE
+                WHEN jsonb_array_length(jsonb_path_query_array(j->p, '$[*] ? (@ != null)')) = 0 THEN 'null'::JSONB
+                ELSE jsonb_path_query_array(j->p, '$[*] ? (@ != null)')
+            END
+        WHEN ldlite_system.jtype_of(j->p) = 'object' THEN
+            CASE
+                WHEN j->>p = '{}' THEN 'null'::JSONB
+                ELSE j->p
+            END
+        ELSE j->p
+    END;
+END
+$$ LANGUAGE plpgsql;
+
+CREATE OR REPLACE FUNCTION ldlite_system.jextract_string(j JSONB, p TEXT) RETURNS TEXT AS $$
+BEGIN
+    RETURN ldlite_system.jextract(j, p) ->> 0;
+END
+$$ LANGUAGE plpgsql;
+
+CREATE OR REPLACE FUNCTION ldlite_system.jobject_keys(j JSONB) RETURNS SETOF TEXT AS $$
+BEGIN
+    RETURN QUERY SELECT jsonb_object_keys(j);
+END
+$$ LANGUAGE plpgsql;
+
+CREATE OR REPLACE FUNCTION ldlite_system.jis_uuid(j JSONB) RETURNS BOOLEAN AS $$
+BEGIN
+    RETURN CASE
+        WHEN ldlite_system.jtype_of(j) = 'string' THEN j->>0 ~ '^[a-fA-F0-9]{8}-[a-fA-F0-9]{4}-[1-5][a-fA-F0-9]{3}-[89abAB][a-fA-F0-9]{3}-[a-fA-F0-9]{12}$'
+        ELSE FALSE
+    END;
+END
+$$ LANGUAGE plpgsql;
+
+CREATE OR REPLACE FUNCTION ldlite_system.jis_datetime(j JSONB) RETURNS BOOLEAN AS $$
+BEGIN
+    RETURN CASE
+        WHEN ldlite_system.jtype_of(j) = 'string' THEN j->>0 ~ '^\d{4}-[01]\d-[0123]\dT[012]\d:[012345]\d:[012345]\d\.\d{3}(\+\d{2}:\d{2})?$'
+        ELSE FALSE
+    END;
+END
+$$ LANGUAGE plpgsql;
+
+
+CREATE OR REPLACE FUNCTION ldlite_system.jis_float(j JSONB) RETURNS BOOLEAN AS $$
+BEGIN
+    RETURN CASE
+        WHEN ldlite_system.jtype_of(j) = 'number' THEN j->>0 LIKE '%.%'
+        ELSE FALSE
+    END;
+END
+$$ LANGUAGE plpgsql;
+""",  # noqa: E501
+            )
+
     @property
     def _default_schema(self) -> str:
         return "public"
diff --git a/tests/test_json_operators.py b/tests/test_json_operators.py
new file mode 100644
index 0000000..5eec23e
--- /dev/null
+++ b/tests/test_json_operators.py
@@ -0,0 +1,296 @@
+from collections.abc import Callable, Iterator
+from contextlib import closing
+from dataclasses import dataclass
+from typing import TYPE_CHECKING, Any, cast
+from uuid import uuid4
+
+import duckdb
+import psycopg
+import pytest
+from pytest_cases import parametrize, parametrize_with_cases
+
+if TYPE_CHECKING:
+    from _typeshed import dbapi
+
+
+def _db() -> str:
+    """Return a fresh database name: ``db`` plus the first block of a uuid4."""
+    db = "db" + str(uuid4()).split("-")[0]
+    print(db)  # noqa: T201
+    return db
+
+
+@dataclass
+class JsonTC:
+    """One SQL test case, run by ``_assert`` against both backends."""
+
+    # SQL template; ``{assertion}`` and ``{jtype}`` are filled in by ``_assert``.
+    query: str
+    # Parameters referenced by ``query`` itself.
+    query_params: tuple[Any, ...]
+    # SQL fragment substituted for ``{assertion}``; may use ``{jtype}``.
+    assertion: str
+    # Extra parameters referenced only by ``assertion``.
+    assertion_params: tuple[Any, ...]
+
+
+@parametrize(
+    p=[
+        ("str", '"str_val"'),
+        ("str_empty", "null"),
+        ("num", "12"),
+        ("float", "16.3"),
+        ("bool", "true"),
+        ("obj", '{"k1":"v1","k2":"v2"}'),
+        ("obj_some", '{"k1":"v1","k2":null}'),
+        ("obj_empty", "null"),
+        ("arr_zero", "null"),
+        ("arr_str", '["s1","s2","s3"]'),
+        ("arr_str_some", '["s1","s2"]'),
+        ("arr_obj_some", '[{"k1":"v1"}]'),
+        ("na", "null"),
+        ("na_str1", "null"),
+        ("na_str2", "null"),
+    ],
+)
+def case_jextract(p: tuple[Any, ...]) -> JsonTC:
+    return JsonTC(
+        """SELECT ldlite_system.jextract(jc, $1){assertion} FROM j;""",
+        p[:1],
+        """= $2::{jtype}""",
+        p[1:],
+    )
+
+
+@parametrize(
+    p=[
+        ("str", "str_val"),
+        ("num", "12"),
+        ("float", "16.3"),
+        ("bool", "true"),
+        ("na",),
+        ("na_str1",),
+        ("na_str2",),
+    ],
+)
+def case_jextract_string(p: tuple[Any, ...]) -> JsonTC:
+    return JsonTC(
+        """SELECT ldlite_system.jextract_string(jc, $1){assertion} FROM j;""",
+        p[:1],
+        """ = $2""" if len(p) == 2 else """ IS NULL""",
+        p[1:],
+    )
+
+
+def case_jobject_keys() -> JsonTC:
+    return JsonTC(
+        """
+{assertion}
+(SELECT e.jkey, a.jkey
+FROM (SELECT 'k1' jkey UNION SELECT 'k2' jkey) as e
+FULL OUTER JOIN (SELECT ldlite_system.jobject_keys(jc->'obj') jkey FROM j) as a
+    USING (jkey)
+WHERE e.jkey IS NULL or a.jkey IS NULL) as q;""",
+        (),
+        "SELECT COUNT(1) = 0 FROM ",
+        (),
+    )
+
+
+@parametrize(
+    p=[
+        ("str", "string"),
+        ("num", "number"),
+        ("float", "number"),
+        ("bool", "boolean"),
+        ("obj", "object"),
+        ("arr_str", "array"),
+        ("arr_obj", "array"),
+        ("na", "null"),
+    ],
+)
+def case_jtypeof(p: tuple[Any, ...]) -> JsonTC:
+    return JsonTC(
+        """
+SELECT ldlite_system.jtype_of(jc->$1){assertion}
+FROM j;""",
+        p[:1],
+        """ = $2""",
+        p[1:],
+    )
+
+
+@parametrize(
+    p=[
+        ("str", False),
+        ("str_empty", False),
+        ("num", False),
+        ("na", False),
+        ("na_str1", False),
+        ("na_str2", False),
+        ("uuid_nof", False),
+        ("uuid", True),
+    ],
+)
+def case_jis_uuid(p: tuple[Any, ...]) -> JsonTC:
+    return JsonTC(
+        """
+SELECT {assertion}ldlite_system.jis_uuid(jc->$1)
+FROM j;""",
+        p[:1],
+        "" if (p[1]) else """ NOT """,
+        (),
+    )
+
+
+@parametrize(
+    p=[
+        ("str", False),
+        ("str_empty", False),
+        ("num", False),
+        ("na", False),
+        ("na_str1", False),
+        ("na_str2", False),
+        ("uuid_nof", False),
+        ("uuid", False),
+        ("dt", True),
+    ],
+)
+def case_jis_datetime(p: tuple[Any, ...]) -> JsonTC:
+    return JsonTC(
+        """
+SELECT {assertion}ldlite_system.jis_datetime(jc->$1)
+FROM j;""",
+        p[:1],
+        "" if (p[1]) else """ NOT """,
+        (),
+    )
+
+
+@parametrize(
+    p=[
+        ("str", False),
+        ("str_empty", False),
+        ("num", False),
+        ("na", False),
+        ("na_str1", False),
+        ("na_str2", False),
+        ("uuid_nof", False),
+        ("uuid", False),
+        ("dt", False),
+        ("float", True),
+    ],
+)
+def case_jis_float(p: tuple[Any, ...]) -> JsonTC:
+    return JsonTC(
+        """
+SELECT {assertion}ldlite_system.jis_float(jc->$1)
+FROM j;""",
+        p[:1],
+        "" if (p[1]) else """ NOT """,
+        (),
+    )
+
+
+def _assert(conn: "dbapi.DBAPIConnection", jtype: str, tc: JsonTC) -> None:
+    """Run ``tc`` on ``conn`` and fail unless its first column is truthy.
+
+    ``jtype`` is substituted for ``{jtype}`` in the query and the assertion.
+    On failure the query is re-run without the assertion and the rows it
+    returns become the failure message.
+    """
+    with closing(conn.cursor()) as cur:
+        query = tc.query.format(assertion="", jtype=jtype)
+        assertion = tc.query.format(
+            assertion=tc.assertion.format(jtype=jtype),
+            jtype=jtype,
+        )
+
+        cur.execute(assertion, (*tc.query_params, *tc.assertion_params))
+        actual = cur.fetchone()
+        assert actual is not None
+
+        # ``not None`` is also true, so one check covers NULL and false.
+        if not actual[0]:
+            cur.execute(query, tc.query_params)
+            pytest.fail("".join(f"{r}\n" for r in cur.fetchall()))
+
+
+def _arrange(conn: "dbapi.DBAPIConnection") -> None:
+    """Insert the one JSON fixture document all test cases query into ``j``."""
+    with closing(conn.cursor()) as cur:
+        cur.execute(
+            "INSERT INTO j VALUES "
+            "('{"
+            """ "str": "str_val","""
+            """ "str_empty": "","""
+            """ "num": 12,"""
+            """ "float": 16.3,"""
+            """ "bool": true,"""
+            """ "uuid": "5b285d03-5490-1111-8888-52b2003b475c","""
+            """ "uuid_nof": "5b285d03-5490-FFFF-0000-52b2003b475c","""
+            """ "obj": {"k1": "v1", "k2": "v2"},"""
+            """ "obj_some": {"k1": "v1", "k2": null},"""
+            """ "obj_empty": {},"""
+            """ "arr_zero": [],"""
+            """ "arr_str": ["s1", "s2", "s3"],"""
+            """ "arr_str_some": ["s1", "s2", null],"""
+            """ "arr_obj": [{"k1": "v1"}, {"k2": "v2"}],"""
+            """ "arr_obj_some": [{"k1": "v1"}, null],"""
+            """ "dt": "2022-04-21T18:47:33.581+00:00","""
+            """ "na": null,"""
+            """ "na_str1": "null", """
+            """ "na_str2": "NULL" """
+            " }')",
+        )
+
+
+@pytest.fixture(scope="session")
+def duckdb_jop_dsn() -> Iterator[str]:
+    """Yield a ``:memory:<name>`` DuckDB DSN whose database holds table ``j``."""
+    dsn = f":memory:{_db()}"
+
+    with duckdb.connect(dsn) as conn:
+        conn.execute("CREATE TABLE j (jc JSON)")
+        _arrange(cast("dbapi.DBAPIConnection", conn))
+
+        # NOTE(review): yielding inside the ``with`` presumably keeps the
+        # in-memory database alive for the whole session - confirm.
+        yield dsn
+
+
+@parametrize_with_cases("tc", cases=".")
+def test_duckdb(duckdb_jop_dsn: str, tc: JsonTC) -> None:
+    from ldlite import LDLite
+
+    # NOTE(review): presumably connecting is what installs the helper functions.
+    ld = LDLite()
+    ld.connect_db(duckdb_jop_dsn)
+
+    with duckdb.connect(duckdb_jop_dsn) as conn:
+        _assert(cast("dbapi.DBAPIConnection", conn), "JSON", tc)
+
+
+@pytest.fixture(scope="session")
+def pg_jop_dsn(pg_dsn: None | Callable[[str], str]) -> str:
+    """Return a Postgres DSN whose database holds table ``j``, or skip."""
+    if pg_dsn is None:
+        pytest.skip("Specify the pg host using --pg-host to run")
+
+    dsn = pg_dsn(_db())
+    with psycopg.connect(dsn) as conn, conn.cursor() as cur:
+        cur.execute("CREATE TABLE j (jc JSONB)")
+        _arrange(cast("dbapi.DBAPIConnection", conn))
+    return dsn
+
+
+@parametrize_with_cases("tc", cases=".")
+def test_postgres(pg_jop_dsn: str, tc: JsonTC) -> None:
+    from ldlite import LDLite
+
+    # NOTE(review): presumably connecting is what installs the helper functions.
+    ld = LDLite()
+    ld.connect_db_postgresql(pg_jop_dsn)
+
+    with psycopg.connect(pg_jop_dsn, cursor_factory=psycopg.RawCursor) as conn:
+        _assert(cast("dbapi.DBAPIConnection", conn), "JSONB", tc)