Skip to content

Commit c612d05

Browse files
committed
Implement Connection.close(), improve Cursor.close()
Fixes #90
1 parent 138ae16 commit c612d05

File tree

2 files changed

+92
-21
lines changed

2 files changed

+92
-21
lines changed

src/lib.rs

Lines changed: 75 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -160,10 +160,10 @@ fn _connect_core(
160160
let conn = db.connect().map_err(to_py_err)?;
161161
Ok(Connection {
162162
db,
163-
conn: Arc::new(ConnectionGuard {
163+
conn: RefCell::new(Some(Arc::new(ConnectionGuard {
164164
conn: Some(conn),
165165
handle: rt.clone(),
166-
}),
166+
}))),
167167
isolation_level,
168168
autocommit,
169169
})
@@ -199,7 +199,7 @@ impl Drop for ConnectionGuard {
199199
#[pyclass]
200200
pub struct Connection {
201201
db: libsql_core::Database,
202-
conn: Arc<ConnectionGuard>,
202+
conn: RefCell<Option<Arc<ConnectionGuard>>>,
203203
isolation_level: Option<String>,
204204
autocommit: i32,
205205
}
@@ -209,10 +209,15 @@ unsafe impl Send for Connection {}
209209

210210
#[pymethods]
211211
impl Connection {
212+
fn close(self_: PyRef<'_, Self>, py: Python<'_>) -> PyResult<()> {
213+
self_.conn.replace(None);
214+
Ok(())
215+
}
216+
212217
fn cursor(&self) -> PyResult<Cursor> {
213218
Ok(Cursor {
214219
arraysize: 1,
215-
conn: self.conn.clone(),
220+
conn: RefCell::new(Some(self.conn.borrow().as_ref().unwrap().clone())),
216221
stmt: RefCell::new(None),
217222
rows: RefCell::new(None),
218223
rowcount: RefCell::new(0),
@@ -235,18 +240,34 @@ impl Connection {
235240

236241
fn commit(self_: PyRef<'_, Self>) -> PyResult<()> {
237242
// TODO: Switch to libSQL transaction API
238-
if !self_.conn.is_autocommit() {
239-
rt().block_on(async { self_.conn.execute("COMMIT", ()).await })
240-
.map_err(to_py_err)?;
243+
if !self_.conn.borrow().as_ref().unwrap().is_autocommit() {
244+
rt().block_on(async {
245+
self_
246+
.conn
247+
.borrow()
248+
.as_ref()
249+
.unwrap()
250+
.execute("COMMIT", ())
251+
.await
252+
})
253+
.map_err(to_py_err)?;
241254
}
242255
Ok(())
243256
}
244257

245258
fn rollback(self_: PyRef<'_, Self>) -> PyResult<()> {
246259
// TODO: Switch to libSQL transaction API
247-
if !self_.conn.is_autocommit() {
248-
rt().block_on(async { self_.conn.execute("ROLLBACK", ()).await })
249-
.map_err(to_py_err)?;
260+
if !self_.conn.borrow().as_ref().unwrap().is_autocommit() {
261+
rt().block_on(async {
262+
self_
263+
.conn
264+
.borrow()
265+
.as_ref()
266+
.unwrap()
267+
.execute("ROLLBACK", ())
268+
.await
269+
})
270+
.map_err(to_py_err)?;
250271
}
251272
Ok(())
252273
}
@@ -276,7 +297,15 @@ impl Connection {
276297

277298
fn executescript(self_: PyRef<'_, Self>, script: String) -> PyResult<()> {
278299
let _ = rt()
279-
.block_on(async { self_.conn.execute_batch(&script).await })
300+
.block_on(async {
301+
self_
302+
.conn
303+
.borrow()
304+
.as_ref()
305+
.unwrap()
306+
.execute_batch(&script)
307+
.await
308+
})
280309
.map_err(to_py_err);
281310
Ok(())
282311
}
@@ -290,9 +319,11 @@ impl Connection {
290319
fn in_transaction(self_: PyRef<'_, Self>) -> PyResult<bool> {
291320
#[cfg(Py_3_12)]
292321
{
293-
return Ok(!self_.conn.is_autocommit() || self_.autocommit == 0);
322+
return Ok(
323+
!self_.conn.borrow().as_ref().unwrap().is_autocommit() || self_.autocommit == 0
324+
);
294325
}
295-
Ok(!self_.conn.is_autocommit())
326+
Ok(!self_.conn.borrow().as_ref().unwrap().is_autocommit())
296327
}
297328

298329
#[getter]
@@ -318,7 +349,7 @@ impl Connection {
318349
pub struct Cursor {
319350
#[pyo3(get, set)]
320351
arraysize: usize,
321-
conn: Arc<ConnectionGuard>,
352+
conn: RefCell<Option<Arc<ConnectionGuard>>>,
322353
stmt: RefCell<Option<libsql_core::Statement>>,
323354
rows: RefCell<Option<libsql_core::Rows>>,
324355
rowcount: RefCell<i64>,
@@ -332,6 +363,8 @@ unsafe impl Send for Cursor {}
332363

333364
impl Drop for Cursor {
334365
fn drop(&mut self) {
366+
let _enter = rt().enter();
367+
self.conn.replace(None);
335368
self.stmt.replace(None);
336369
self.rows.replace(None);
337370
}
@@ -342,6 +375,7 @@ impl Cursor {
342375
fn close(self_: PyRef<'_, Self>) -> PyResult<()> {
343376
rt().block_on(async {
344377
let cursor: &Cursor = &self_;
378+
cursor.conn.replace(None);
345379
cursor.stmt.replace(None);
346380
cursor.rows.replace(None);
347381
});
@@ -373,8 +407,16 @@ impl Cursor {
373407
self_: PyRef<'a, Self>,
374408
script: String,
375409
) -> PyResult<pyo3::PyRef<'a, Self>> {
376-
rt().block_on(async { self_.conn.execute_batch(&script).await })
377-
.map_err(to_py_err)?;
410+
rt().block_on(async {
411+
self_
412+
.conn
413+
.borrow()
414+
.as_ref()
415+
.unwrap()
416+
.execute_batch(&script)
417+
.await
418+
})
419+
.map_err(to_py_err)?;
378420
Ok(self_)
379421
}
380422

@@ -481,7 +523,9 @@ impl Cursor {
481523
fn lastrowid(self_: PyRef<'_, Self>) -> PyResult<Option<i64>> {
482524
let stmt = self_.stmt.borrow();
483525
match stmt.as_ref() {
484-
Some(_) => Ok(Some(self_.conn.last_insert_rowid())),
526+
Some(_) => Ok(Some(
527+
self_.conn.borrow().as_ref().unwrap().last_insert_rowid(),
528+
)),
485529
None => Ok(None),
486530
}
487531
}
@@ -498,10 +542,13 @@ async fn begin_transaction(conn: &libsql_core::Connection) -> PyResult<()> {
498542
}
499543

500544
async fn execute(cursor: &Cursor, sql: String, parameters: Option<&PyTuple>) -> PyResult<()> {
545+
if cursor.conn.borrow().as_ref().is_none() {
546+
return Err(PyValueError::new_err("Connection already closed"));
547+
}
501548
let stmt_is_dml = stmt_is_dml(&sql);
502549
let autocommit = determine_autocommit(cursor);
503-
if !autocommit && stmt_is_dml && cursor.conn.is_autocommit() {
504-
begin_transaction(&cursor.conn).await?;
550+
if !autocommit && stmt_is_dml && cursor.conn.borrow().as_ref().unwrap().is_autocommit() {
551+
begin_transaction(&cursor.conn.borrow().as_ref().unwrap()).await?;
505552
}
506553
let params = match parameters {
507554
Some(parameters) => {
@@ -526,11 +573,18 @@ async fn execute(cursor: &Cursor, sql: String, parameters: Option<&PyTuple>) ->
526573
}
527574
None => libsql_core::params::Params::None,
528575
};
529-
let mut stmt = cursor.conn.prepare(&sql).await.map_err(to_py_err)?;
576+
let mut stmt = cursor
577+
.conn
578+
.borrow()
579+
.as_ref()
580+
.unwrap()
581+
.prepare(&sql)
582+
.await
583+
.map_err(to_py_err)?;
530584
let rows = stmt.query(params).await.map_err(to_py_err)?;
531585
if stmt_is_dml {
532586
let mut rowcount = cursor.rowcount.borrow_mut();
533-
*rowcount += cursor.conn.changes() as i64;
587+
*rowcount += cursor.conn.borrow().as_ref().unwrap().changes() as i64;
534588
} else {
535589
cursor.rowcount.replace(-1);
536590
}

tests/test_suite.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,11 @@
66
import pytest
77

88

9+
@pytest.mark.parametrize("provider", ["libsql", "sqlite"])
10+
def test_connection_close(provider):
11+
conn = connect(provider, ":memory:")
12+
conn.close()
13+
914
@pytest.mark.parametrize("provider", ["libsql", "sqlite"])
1015
def test_execute(provider):
1116
conn = connect(provider, ":memory:")
@@ -24,6 +29,18 @@ def test_cursor_execute(provider):
2429
res = cur.execute("SELECT * FROM users")
2530
assert (1, "alice@example.com") == res.fetchone()
2631

32+
@pytest.mark.parametrize("provider", ["libsql", "sqlite"])
33+
def test_cursor_close(provider):
34+
conn = connect(provider, ":memory:")
35+
cur = conn.cursor()
36+
cur.execute("CREATE TABLE users (id INTEGER, email TEXT)")
37+
cur.execute("INSERT INTO users VALUES (1, 'alice@example.com')")
38+
cur.execute("INSERT INTO users VALUES (2, 'bob@example.com')")
39+
res = cur.execute("SELECT * FROM users")
40+
assert [(1, "alice@example.com"), (2, "bob@example.com")] == res.fetchall()
41+
cur.close()
42+
with pytest.raises(Exception):
43+
cur.execute("SELECT * FROM users")
2744

2845
@pytest.mark.parametrize("provider", ["libsql", "sqlite"])
2946
def test_executemany(provider):

0 commit comments

Comments
 (0)