From 6137d4190ed03a3ca56228a9f5e26fddeac72f4e Mon Sep 17 00:00:00 2001 From: jessekrubin Date: Fri, 7 Feb 2025 12:19:05 -0800 Subject: [PATCH 1/2] feat: conn for each --- src/pool.rs | 33 ++++++++++++++++++++++++++ tests/tests.rs | 63 ++++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 96 insertions(+) diff --git a/src/pool.rs b/src/pool.rs index e80c48d..e8c133e 100644 --- a/src/pool.rs +++ b/src/pool.rs @@ -238,4 +238,37 @@ impl Pool { let n = self.state.counter.fetch_add(1, Relaxed); &self.state.clients[n as usize % self.state.clients.len()] } + + /// Runs a function on all connections in the pool asynchronously. + /// + /// The function is executed on each connection concurrently. + pub async fn conn_for_each(&self, func: F) -> Vec> + where + F: Fn(&Connection) -> Result + Send + Sync + 'static, + T: Send + 'static, + { + let func = Arc::new(func); + let futures = self.state.clients.iter().map(|client| { + let func = func.clone(); + async move { client.conn(move |conn| func(conn)).await } + }); + join_all(futures).await + } + + /// Runs a function on all connections in the pool, blocking the current thread. + pub fn conn_for_each_blocking(&self, func: F) -> Vec> + where + F: Fn(&Connection) -> Result + Send + Sync + 'static, + T: Send + 'static, + { + let func = Arc::new(func); + self.state + .clients + .iter() + .map(|client| { + let func = func.clone(); + client.conn_blocking(move |conn| func(conn)) + }) + .collect() + } } diff --git a/tests/tests.rs b/tests/tests.rs index 74c0095..c0fd31f 100644 --- a/tests/tests.rs +++ b/tests/tests.rs @@ -1,3 +1,5 @@ +use std::env::temp_dir; + use async_sqlite::{ClientBuilder, Error, JournalMode, PoolBuilder}; #[test] @@ -83,6 +85,7 @@ macro_rules! async_test { async_test!(test_journal_mode); async_test!(test_concurrency); async_test!(test_pool); +async_test!(test_pool_conn_for_each); async fn test_journal_mode() { let tmp_dir = tempfile::tempdir().unwrap(); @@ -166,3 +169,63 @@ async fn test_pool() { .collect::>() .expect("collecting query results"); } + +async fn test_pool_conn_for_each() { + // make dummy db + let tmp_dir = tempfile::tempdir().unwrap(); + { + let client = ClientBuilder::new() + .journal_mode(JournalMode::Wal) + .path(tmp_dir.path().join("sqlite.db")) + .open_blocking() + .expect("client unable to be opened"); + + client + .conn_blocking(|conn| { + conn.execute( + "CREATE TABLE testing (id INTEGER PRIMARY KEY, val TEXT NOT NULL)", + (), + )?; + conn.execute("INSERT INTO testing VALUES (1, ?)", ["value1"]) + }) + .expect("writing schema and seed data"); + } + + let pool = PoolBuilder::new() + .path(tmp_dir.path().join("another-sqlite.db")) + .num_conns(2) + .open() + .await + .expect("pool unable to be opened"); + + let dummy_db_path = tmp_dir.path().join("sqlite.db"); + let attach_fn = move |conn: &rusqlite::Connection| { + conn.execute( + "ATTACH DATABASE ? AS dummy", + [dummy_db_path.to_str().unwrap()], + ) + }; + // attach to the dummy db via conn_for_each + pool.conn_for_each(attach_fn).await; + + // check that the dummy db is attached + fn check_fn(conn: &rusqlite::Connection) -> Result, rusqlite::Error> { + let mut stmt = conn + .prepare_cached("SELECT name FROM dummy.sqlite_master WHERE type='table'") + .unwrap(); + let names = stmt + .query_map([], |row| row.get(0)) + .unwrap() + .map(|r| r.unwrap()) + .collect::>(); + + Ok(names) + } + let res = pool.conn_for_each(check_fn).await; + for r in res { + assert_eq!(r.unwrap(), vec!["testing"]); + } + + // cleanup + pool.close().await.expect("closing client conn"); +} From 66f88255790ac93299846511a05e4c2154acc18e Mon Sep 17 00:00:00 2001 From: jessekrubin Date: Fri, 7 Feb 2025 12:19:38 -0800 Subject: [PATCH 2/2] conn for each --- tests/tests.rs | 2 -- 1 file changed, 2 deletions(-) diff --git a/tests/tests.rs b/tests/tests.rs index c0fd31f..f96195f 100644 --- a/tests/tests.rs +++ b/tests/tests.rs @@ -1,5 +1,3 @@ -use std::env::temp_dir; - use async_sqlite::{ClientBuilder, Error, JournalMode, PoolBuilder}; #[test]