diff --git a/src/pool.rs b/src/pool.rs index 36f8e43..7bd44ce 100644 --- a/src/pool.rs +++ b/src/pool.rs @@ -204,9 +204,9 @@ impl Pool { /// After this method returns, all calls to `self::conn()` or /// `self::conn_mut()` will return an [`Error::Closed`] error. pub async fn close(&self) -> Result<(), Error> { - for client in self.state.clients.iter() { - client.close().await?; - } + let closes = self.state.clients.iter().map(|client| client.close()); + let res = join_all(closes).await; + res.into_iter().collect::, Error>>()?; Ok(()) } diff --git a/tests/tests.rs b/tests/tests.rs index 0027220..4eaf002 100644 --- a/tests/tests.rs +++ b/tests/tests.rs @@ -84,6 +84,7 @@ async_test!(test_journal_mode); async_test!(test_concurrency); async_test!(test_pool); async_test!(test_pool_conn_for_each); +async_test!(test_pool_close_concurrent); async_test!(test_pool_num_conns_zero_clamps); async fn test_journal_mode() { @@ -229,6 +230,27 @@ async fn test_pool_conn_for_each() { pool.close().await.expect("closing client conn"); } +async fn test_pool_close_concurrent() { + let tmp_dir = tempfile::tempdir().unwrap(); + let pool = PoolBuilder::new() + .path(tmp_dir.path().join("sqlite.db")) + .num_conns(2) + .open() + .await + .expect("pool unable to be opened"); + + let c1 = pool.close(); + let c2 = pool.close(); + futures_util::future::join_all([c1, c2]) + .await + .into_iter() + .collect::, Error>>() + .expect("closing concurrently"); + + let res = pool.conn(|c| c.execute("SELECT 1", ())).await; + assert!(matches!(res, Err(Error::Closed))); +} + async fn test_pool_num_conns_zero_clamps() { let tmp_dir = tempfile::tempdir().unwrap(); let pool = PoolBuilder::new()