diff --git a/flow/record/adapter/sqlite.py b/flow/record/adapter/sqlite.py index a48c83de..e446234c 100644 --- a/flow/record/adapter/sqlite.py +++ b/flow/record/adapter/sqlite.py @@ -204,6 +204,11 @@ def __iter__(self) -> Iterator[Record]: if match_record_with_context(record, selector, ctx): yield record + def close(self) -> None: + if self.con: + self.con.close() + self.con = None + class SqliteWriter(AbstractWriter): """SQLite writer.""" diff --git a/tests/adapter/test_sqlite_duckdb.py b/tests/adapter/test_sqlite_duckdb.py index 0538c407..148d807a 100644 --- a/tests/adapter/test_sqlite_duckdb.py +++ b/tests/adapter/test_sqlite_duckdb.py @@ -1,6 +1,7 @@ from __future__ import annotations import sqlite3 +from contextlib import closing from datetime import datetime, timezone from typing import TYPE_CHECKING, Any, NamedTuple @@ -136,7 +137,7 @@ def test_write_to_sqlite(tmp_path: Path, count: int, db: Database) -> None: writer.write(record) record_count = 0 - with db.connector.connect(str(db_path)) as con: + with closing(db.connector.connect(str(db_path))) as con: cursor = con.execute("SELECT COUNT(*) FROM 'test/record'") record_count = cursor.fetchone()[0] @@ -157,7 +158,7 @@ def test_read_from_sqlite(tmp_path: Path, db: Database) -> None: """Tests basic reading from a SQLite database.""" # Generate a SQLite database db_path = tmp_path / "records.db" - with db.connector.connect(str(db_path)) as con: + with closing(db.connector.connect(str(db_path))) as con: con.execute( """ CREATE TABLE 'test/record' ( @@ -176,6 +177,7 @@ def test_read_from_sqlite(tmp_path: Path, db: Database) -> None: """, (f"record{i}", f"foobar{i}".encode(), dt_isoformat, 3.14 + i), ) + con.commit() # Read the SQLite database using flow.record with RecordReader(f"{db.scheme}://{db_path}") as reader: @@ -251,7 +253,7 @@ def test_write_zero_records(tmp_path: Path, db: Database) -> None: assert writer # test if it's a valid database - with db.connector.connect(str(db_path)) as con: + with closing(db.connector.connect(str(db_path))) as con: assert con.execute("SELECT * FROM sqlite_master").fetchall() == [] @@ -272,9 +274,10 @@ def test_write_zero_records(tmp_path: Path, db: Database) -> None: def test_non_strict_sqlite_fields(tmp_path: Path, sqlite_coltype: str, sqlite_value: Any, expected_value: Any) -> None: """SQLite by default is non strict, meaning that the value could be of different type than the column type.""" db = tmp_path / "records.db" - with sqlite3.connect(db) as con: + with closing(sqlite3.connect(db)) as con: con.execute(f"CREATE TABLE 'strict-test' (field {sqlite_coltype})") con.execute("INSERT INTO 'strict-test' VALUES(?)", (sqlite_value,)) + con.commit() with RecordReader(f"sqlite://{db}") as reader: record = next(iter(reader)) @@ -294,10 +297,11 @@ def test_invalid_table_names_quoting(tmp_path: Path, invalid_table_name: str) -> # Creating the tables with these invalid_table_names in SQLite is no problem db = tmp_path / "records.db" - with sqlite3.connect(db) as con: + with closing(sqlite3.connect(db)) as con: con.execute(f"CREATE TABLE [{invalid_table_name}] (field TEXT, field2 TEXT)") con.execute(f"INSERT INTO [{invalid_table_name}] VALUES(?, ?)", ("hello", "world")) con.execute(f"INSERT INTO [{invalid_table_name}] VALUES(?, ?)", ("goodbye", "planet")) + con.commit() # However, these invalid_table_names should raise an exception when reading with ( @@ -320,10 +324,11 @@ def test_invalid_field_names_quoting(tmp_path: Path, invalid_field_name: str) -> # Creating the table with invalid field name in SQLite is no problem db = tmp_path / "records.db" - with sqlite3.connect(db) as con: + with closing(sqlite3.connect(db)) as con: con.execute(f"CREATE TABLE [test] (field TEXT, [{invalid_field_name}] TEXT)") con.execute("INSERT INTO [test] VALUES(?, ?)", ("hello", "world")) con.execute("INSERT INTO [test] VALUES(?, ?)", ("goodbye", "planet")) + con.commit() # However, these field names are invalid in flow.record and should raise an exception with ( @@ -365,7 +370,7 @@ def test_batch_size( writer.write(next(records)) # test count of records in table (no flush yet if batch_size > 1) - with db.connector.connect(str(db_path)) as con: + with closing(db.connector.connect(str(db_path))) as con: x = con.execute('SELECT COUNT(*) FROM "test/record"') assert x.fetchone()[0] is expected_first @@ -374,7 +379,7 @@ def test_batch_size( writer.write(next(records)) # test count of records in table after flush - with db.connector.connect(str(db_path)) as con: + with closing(db.connector.connect(str(db_path))) as con: x = con.execute('SELECT COUNT(*) FROM "test/record"') assert x.fetchone()[0] == expected_second