Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions flow/record/adapter/sqlite.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,11 @@ def __iter__(self) -> Iterator[Record]:
if match_record_with_context(record, selector, ctx):
yield record

def close(self) -> None:
if self.con:
self.con.close()
self.con = None


class SqliteWriter(AbstractWriter):
"""SQLite writer."""
Expand Down
21 changes: 13 additions & 8 deletions tests/adapter/test_sqlite_duckdb.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import sqlite3
from contextlib import closing
from datetime import datetime, timezone
from typing import TYPE_CHECKING, Any, NamedTuple

Expand Down Expand Up @@ -136,7 +137,7 @@ def test_write_to_sqlite(tmp_path: Path, count: int, db: Database) -> None:
writer.write(record)

record_count = 0
with db.connector.connect(str(db_path)) as con:
with closing(db.connector.connect(str(db_path))) as con:
cursor = con.execute("SELECT COUNT(*) FROM 'test/record'")
record_count = cursor.fetchone()[0]

Expand All @@ -157,7 +158,7 @@ def test_read_from_sqlite(tmp_path: Path, db: Database) -> None:
"""Tests basic reading from a SQLite database."""
# Generate a SQLite database
db_path = tmp_path / "records.db"
with db.connector.connect(str(db_path)) as con:
with closing(db.connector.connect(str(db_path))) as con:
con.execute(
"""
CREATE TABLE 'test/record' (
Expand All @@ -176,6 +177,7 @@ def test_read_from_sqlite(tmp_path: Path, db: Database) -> None:
""",
(f"record{i}", f"foobar{i}".encode(), dt_isoformat, 3.14 + i),
)
con.commit()

# Read the SQLite database using flow.record
with RecordReader(f"{db.scheme}://{db_path}") as reader:
Expand Down Expand Up @@ -251,7 +253,7 @@ def test_write_zero_records(tmp_path: Path, db: Database) -> None:
assert writer

# test if it's a valid database
with db.connector.connect(str(db_path)) as con:
with closing(db.connector.connect(str(db_path))) as con:
assert con.execute("SELECT * FROM sqlite_master").fetchall() == []


Expand All @@ -272,9 +274,10 @@ def test_write_zero_records(tmp_path: Path, db: Database) -> None:
def test_non_strict_sqlite_fields(tmp_path: Path, sqlite_coltype: str, sqlite_value: Any, expected_value: Any) -> None:
"""SQLite by default is non strict, meaning that the value could be of different type than the column type."""
db = tmp_path / "records.db"
with sqlite3.connect(db) as con:
with closing(sqlite3.connect(db)) as con:
con.execute(f"CREATE TABLE 'strict-test' (field {sqlite_coltype})")
con.execute("INSERT INTO 'strict-test' VALUES(?)", (sqlite_value,))
con.commit()

with RecordReader(f"sqlite://{db}") as reader:
record = next(iter(reader))
Expand All @@ -294,10 +297,11 @@ def test_invalid_table_names_quoting(tmp_path: Path, invalid_table_name: str) ->

# Creating the tables with these invalid_table_names in SQLite is no problem
db = tmp_path / "records.db"
with sqlite3.connect(db) as con:
with closing(sqlite3.connect(db)) as con:
con.execute(f"CREATE TABLE [{invalid_table_name}] (field TEXT, field2 TEXT)")
con.execute(f"INSERT INTO [{invalid_table_name}] VALUES(?, ?)", ("hello", "world"))
con.execute(f"INSERT INTO [{invalid_table_name}] VALUES(?, ?)", ("goodbye", "planet"))
con.commit()

# However, these invalid_table_names should raise an exception when reading
with (
Expand All @@ -320,10 +324,11 @@ def test_invalid_field_names_quoting(tmp_path: Path, invalid_field_name: str) ->

# Creating the table with invalid field name in SQLite is no problem
db = tmp_path / "records.db"
with sqlite3.connect(db) as con:
with closing(sqlite3.connect(db)) as con:
con.execute(f"CREATE TABLE [test] (field TEXT, [{invalid_field_name}] TEXT)")
con.execute("INSERT INTO [test] VALUES(?, ?)", ("hello", "world"))
con.execute("INSERT INTO [test] VALUES(?, ?)", ("goodbye", "planet"))
con.commit()

# However, these field names are invalid in flow.record and should raise an exception
with (
Expand Down Expand Up @@ -365,7 +370,7 @@ def test_batch_size(
writer.write(next(records))

# test count of records in table (no flush yet if batch_size > 1)
with db.connector.connect(str(db_path)) as con:
with closing(db.connector.connect(str(db_path))) as con:
x = con.execute('SELECT COUNT(*) FROM "test/record"')
assert x.fetchone()[0] is expected_first

Expand All @@ -374,7 +379,7 @@ def test_batch_size(
writer.write(next(records))

# test count of records in table after flush
with db.connector.connect(str(db_path)) as con:
with closing(db.connector.connect(str(db_path))) as con:
x = con.execute('SELECT COUNT(*) FROM "test/record"')
assert x.fetchone()[0] == expected_second

Expand Down
Loading