diff --git a/Cargo.lock b/Cargo.lock index a15482e..c381d55 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -449,7 +449,7 @@ dependencies = [ [[package]] name = "kovan" -version = "0.1.10" +version = "0.1.11" dependencies = [ "criterion", "crossbeam-epoch", @@ -461,7 +461,7 @@ dependencies = [ [[package]] name = "kovan-channel" -version = "0.1.10" +version = "0.1.11" dependencies = [ "crossbeam-utils", "futures", @@ -470,7 +470,7 @@ dependencies = [ [[package]] name = "kovan-map" -version = "0.1.10" +version = "0.1.11" dependencies = [ "criterion", "foldhash 0.2.0", @@ -479,17 +479,18 @@ dependencies = [ [[package]] name = "kovan-mvcc" -version = "0.1.10" +version = "0.1.11" dependencies = [ "kovan", "kovan-map", + "parking_lot", "rand", "uuid", ] [[package]] name = "kovan-queue" -version = "0.1.10" +version = "0.1.11" dependencies = [ "crossbeam-utils", "kovan", @@ -497,7 +498,7 @@ dependencies = [ [[package]] name = "kovan-stm" -version = "0.1.10" +version = "0.1.11" dependencies = [ "kovan", "kovan-map", @@ -521,6 +522,15 @@ version = "0.2.182" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6800badb6cb2082ffd7b6a67e6125bb39f18782f793520caee8cb8846be06112" +[[package]] +name = "lock_api" +version = "0.4.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "224399e74b87b5f3557511d98dff8b14089b3dadafcab6bb93eab67d3aace965" +dependencies = [ + "scopeguard", +] + [[package]] name = "log" version = "0.4.29" @@ -585,6 +595,29 @@ version = "11.1.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d6790f58c7ff633d8771f42965289203411a5e5c68388703c06e14f24770b41e" +[[package]] +name = "parking_lot" +version = "0.12.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "93857453250e3077bd71ff98b6a65ea6621a19bb0f559a85248955ac12c45a1a" +dependencies = [ + "lock_api", + "parking_lot_core", +] + +[[package]] +name = "parking_lot_core" +version = "0.9.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2621685985a2ebf1c516881c026032ac7deafcda1a2c9b7850dc81e3dfcb64c1" +dependencies = [ + "cfg-if", + "libc", + "redox_syscall", + "smallvec", + "windows-link", +] + [[package]] name = "pin-project-lite" version = "0.2.16" @@ -717,6 +750,15 @@ dependencies = [ "crossbeam-utils", ] +[[package]] +name = "redox_syscall" +version = "0.5.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ed2bf2547551a7053d6fdfafda3f938979645c44812fbfcda098faae3f1a362d" +dependencies = [ + "bitflags", +] + [[package]] name = "regex" version = "1.12.3" @@ -767,6 +809,12 @@ version = "1.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e1cf6437eb19a8f4a6cc0f7dca544973b0b78843adbfeb3683d1a94a0024a294" +[[package]] +name = "scopeguard" +version = "1.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49" + [[package]] name = "seize" version = "0.5.1" diff --git a/Cargo.toml b/Cargo.toml index 1666149..8f2e43c 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -10,7 +10,7 @@ members = [ ] [workspace.package] -version = "0.1.10" +version = "0.1.11" edition = "2024" rust-version = "1.90" authors = ["Theo M. Bulut "] @@ -20,12 +20,12 @@ readme = "README.md" exclude = ["model_chk"] [workspace.dependencies] -kovan = { version = "0.1.10", path = "kovan" } -kovan-channel = { version = "0.1.10", path = "kovan-channel" } -kovan-map = { version = "0.1.10", path = "kovan-map" } -kovan-mvcc = { version = "0.1.10", path = "kovan-mvcc" } -kovan-stm = { version = "0.1.10", path = "kovan-stm" } -kovan-queue = { version = "0.1.10", path = "kovan-queue" } +kovan = { version = "0.1.11", path = "kovan" } +kovan-channel = { version = "0.1.11", path = "kovan-channel" } +kovan-map = { version = "0.1.11", path = "kovan-map" } +kovan-mvcc = { version = "0.1.11", path = "kovan-mvcc" } +kovan-stm = { version = "0.1.11", path = "kovan-stm" } +kovan-queue = { version = "0.1.11", path = "kovan-queue" } [workspace.metadata.docs.rs] rustdoc-args = ["-C", "target-feature=+cmpxchg16b"] diff --git a/kovan-channel/src/flavors/bounded.rs b/kovan-channel/src/flavors/bounded.rs index 904099d..e95dbea 100644 --- a/kovan-channel/src/flavors/bounded.rs +++ b/kovan-channel/src/flavors/bounded.rs @@ -4,6 +4,7 @@ use crossbeam_utils::Backoff; use std::collections::LinkedList; use std::sync::atomic::{AtomicUsize, Ordering}; use std::sync::{Arc, Mutex}; +use std::thread; struct Channel { sender: unbounded::Sender, @@ -12,6 +13,10 @@ struct Channel { len: AtomicUsize, senders: Mutex>>, receivers: Mutex>>, + /// Number of live bounded Sender handles. + sender_count: AtomicUsize, + /// Set when all bounded senders are dropped. + disconnected: std::sync::atomic::AtomicBool, } /// The sending half of a bounded channel. @@ -21,12 +26,28 @@ pub struct Sender { impl Clone for Sender { fn clone(&self) -> Self { + self.inner.sender_count.fetch_add(1, Ordering::Relaxed); Self { inner: self.inner.clone(), } } } +impl Drop for Sender { + fn drop(&mut self) { + if self.inner.sender_count.fetch_sub(1, Ordering::AcqRel) == 1 { + self.inner.disconnected.store(true, Ordering::Release); + // Wake all blocked receivers + { + let mut receivers = self.inner.receivers.lock().unwrap(); + while let Some(signal) = receivers.pop_front() { + signal.notify(); + } + } + } + } +} + unsafe impl Send for Sender {} unsafe impl Sync for Sender {} @@ -56,6 +77,8 @@ impl Channel { len: AtomicUsize::new(0), senders: Mutex::new(LinkedList::new()), receivers: Mutex::new(LinkedList::new()), + sender_count: AtomicUsize::new(1), + disconnected: std::sync::atomic::AtomicBool::new(false), } } } @@ -242,30 +265,59 @@ impl Receiver { } } + /// Returns `true` if all senders have been dropped. + pub fn is_disconnected(&self) -> bool { + self.inner.disconnected.load(Ordering::Acquire) + } + /// Receives a message from the channel, blocking if empty. + /// + /// Returns `None` when the channel is empty **and** all senders have been dropped. pub fn recv(&self) -> Option { if let Some(msg) = self.try_recv() { return Some(msg); } + if self.is_disconnected() { + return self.try_recv(); + } + loop { let signal = Arc::new(Signal::new()); - // Register signal { let mut receivers = self.inner.receivers.lock().unwrap(); receivers.push_back(signal.clone()); } - // Re-check if let Some(msg) = self.try_recv() { return Some(msg); } - signal.wait(); + if self.is_disconnected() { + return self.try_recv(); + } + + // Wait, checking the queue on every wakeup. + loop { + if signal.is_notified() { + break; + } + thread::park(); + if let Some(msg) = self.try_recv() { + return Some(msg); + } + if self.is_disconnected() { + return self.try_recv(); + } + } if let Some(msg) = self.try_recv() { return Some(msg); } + + if self.is_disconnected() { + return None; + } } } diff --git a/kovan-channel/src/flavors/unbounded.rs b/kovan-channel/src/flavors/unbounded.rs index 9ce7460..1cfb478 100644 --- a/kovan-channel/src/flavors/unbounded.rs +++ b/kovan-channel/src/flavors/unbounded.rs @@ -1,7 +1,7 @@ use kovan::{Atomic, RetiredNode, Shared, pin, retire}; use std::ptr; use std::sync::Arc; -use std::sync::atomic::Ordering; +use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering}; use crate::signal::{AsyncSignal, Notifier, Signal}; use std::collections::LinkedList; @@ -28,6 +28,10 @@ pub(crate) struct Channel { head: Atomic>, tail: Atomic>, receivers: Mutex>>, + /// Number of live Sender handles. When this reaches 0, the channel is disconnected. + sender_count: AtomicUsize, + /// Set to true when all senders have been dropped. + disconnected: AtomicBool, } impl Channel { @@ -37,6 +41,16 @@ impl Channel { head: Atomic::new(sentinel), tail: Atomic::new(sentinel), receivers: Mutex::new(LinkedList::new()), + sender_count: AtomicUsize::new(1), // Starts at 1 for the initial Sender + disconnected: AtomicBool::new(false), + } + } + + /// Wake all blocked receivers (used when senders disconnect). + fn wake_all_receivers(&self) { + let mut receivers = self.receivers.lock().unwrap(); + while let Some(signal) = receivers.pop_front() { + signal.notify(); } } } @@ -63,11 +77,14 @@ impl Drop for Channel { while !curr.is_null() { let next = unsafe { curr.deref().next.load(Ordering::Relaxed, &guard) }; - // We can't just drop `curr` because of kovan. - // We should `retire` it. unsafe { retire(curr.as_raw()) }; curr = next; } + + // Force-flush retired nodes on this thread to prevent use-after-free + // during process teardown when kovan's global state is being destroyed. + drop(guard); + kovan::flush(); } } @@ -78,12 +95,23 @@ pub struct Sender { impl Clone for Sender { fn clone(&self) -> Self { + self.inner.sender_count.fetch_add(1, Ordering::Relaxed); Self { inner: self.inner.clone(), } } } +impl Drop for Sender { + fn drop(&mut self) { + if self.inner.sender_count.fetch_sub(1, Ordering::AcqRel) == 1 { + // Last sender dropped — mark channel as disconnected and wake all receivers. + self.inner.disconnected.store(true, Ordering::Release); + self.inner.wake_all_receivers(); + } + } +} + unsafe impl Send for Sender {} unsafe impl Sync for Sender {} @@ -172,6 +200,11 @@ impl Sender { } impl Receiver { + /// Returns `true` if all senders have been dropped. + pub fn is_disconnected(&self) -> bool { + self.inner.disconnected.load(Ordering::Acquire) + } + /// Attempts to receive a message from the channel without blocking. pub fn try_recv(&self) -> Option { let guard = pin(); @@ -218,7 +251,14 @@ impl Receiver { // We take the data from `next`. // `next` is now the sentinel. // Its data is logically gone. - let data = unsafe { ptr::read(data_ptr) }; + // + // CRITICAL: use `ptr::replace` (not `ptr::read`) to clear + // the source field. `ptr::read` leaves the bit pattern + // intact, so when this node is later retired and freed by + // kovan, `Node::drop` would drop `data: Option` again + // — a double-free. Writing `None` ensures the destructor + // sees an empty Option and skips the inner drop. + let data = unsafe { ptr::replace(data_ptr, None) }; return data; } Err(_) => continue, @@ -228,12 +268,18 @@ impl Receiver { } /// Receives a message from the channel, blocking if empty. - /// Receives a message from the channel, blocking if empty. + /// + /// Returns `None` when the channel is empty **and** all senders have been dropped. pub fn recv(&self) -> Option { if let Some(msg) = self.try_recv() { return Some(msg); } + // Fast path: already disconnected and empty + if self.is_disconnected() { + return self.try_recv(); // Drain remaining + } + loop { let signal = Arc::new(Signal::new()); // Register signal @@ -244,17 +290,25 @@ impl Receiver { // Re-check to avoid race if let Some(msg) = self.try_recv() { - // We got a message, remove signal if still there? - // It might have been popped by sender, but that's fine, notify is harmless. return Some(msg); } + // Check disconnection after registering signal but before parking + if self.is_disconnected() { + return self.try_recv(); // Drain remaining + } + signal.wait(); - // Woken up, try to receive + // Woken up — either a message arrived or senders disconnected if let Some(msg) = self.try_recv() { return Some(msg); } + + // Woken by disconnect with empty queue + if self.is_disconnected() { + return None; + } } } @@ -267,9 +321,6 @@ impl Receiver { head == tail && next.is_null() } - /// Registers a signal for notification when a message arrives. - /// - /// This is used for `select!` implementation. /// Registers a signal for notification when a message arrives. /// /// This is used for `select!` implementation. @@ -279,7 +330,9 @@ impl Receiver { } /// Receives a message from the channel asynchronously. - pub async fn recv_async(&self) -> T { + /// + /// Returns `None` when the channel is empty and all senders have been dropped. + pub async fn recv_async(&self) -> Option { use std::future::Future; use std::pin::Pin; use std::task::{Context, Poll}; @@ -292,25 +345,26 @@ impl Receiver { impl<'a, T: 'static> Unpin for RecvFuture<'a, T> {} impl<'a, T: 'static> Future for RecvFuture<'a, T> { - type Output = T; + type Output = Option; fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { let this = self.get_mut(); if let Some(msg) = this.receiver.try_recv() { - return Poll::Ready(msg); + return Poll::Ready(Some(msg)); + } + + // Disconnected and empty — done + if this.receiver.is_disconnected() { + return Poll::Ready(this.receiver.try_recv()); } if this.signal.is_notified() { - // We were notified but failed to get message (stolen). - // We need a new signal. this.signal = Arc::new(AsyncSignal::new()); } this.signal.register(cx.waker()); // Register signal - // Note: This might register duplicates if polled spuriously. - // But AsyncSignal handles multiple notifies. { let mut receivers = this.receiver.inner.receivers.lock().unwrap(); receivers.push_back(this.signal.clone()); @@ -318,7 +372,11 @@ impl Receiver { // Re-check if let Some(msg) = this.receiver.try_recv() { - return Poll::Ready(msg); + return Poll::Ready(Some(msg)); + } + + if this.receiver.is_disconnected() { + return Poll::Ready(None); } Poll::Pending diff --git a/kovan-channel/tests/async_test.rs b/kovan-channel/tests/async_test.rs index 48bf45b..71f52a5 100644 --- a/kovan-channel/tests/async_test.rs +++ b/kovan-channel/tests/async_test.rs @@ -12,8 +12,8 @@ fn test_unbounded_async() { s.send(1); s.send(2); - assert_eq!(r.recv_async().await, 1); - assert_eq!(r.recv_async().await, 2); + assert_eq!(r.recv_async().await, Some(1)); + assert_eq!(r.recv_async().await, Some(2)); let r_clone = r.clone(); thread::spawn(move || { @@ -21,7 +21,7 @@ fn test_unbounded_async() { s.send(3); }); - assert_eq!(r_clone.recv_async().await, 3); + assert_eq!(r_clone.recv_async().await, Some(3)); }); } @@ -52,7 +52,7 @@ fn test_mixed_async_blocking() { let (s, r) = unbounded(); s.send(1); - assert_eq!(r.recv_async().await, 1); + assert_eq!(r.recv_async().await, Some(1)); let r_clone = r.clone(); thread::spawn(move || { @@ -60,6 +60,6 @@ fn test_mixed_async_blocking() { s.send(2); }); - assert_eq!(r_clone.recv_async().await, 2); + assert_eq!(r_clone.recv_async().await, Some(2)); }); } diff --git a/kovan-channel/tests/unbounded_test.rs b/kovan-channel/tests/unbounded_test.rs index d42ea90..1cc0bad 100644 --- a/kovan-channel/tests/unbounded_test.rs +++ b/kovan-channel/tests/unbounded_test.rs @@ -79,6 +79,60 @@ fn test_receiver_clone() { assert_eq!(r2.recv(), Some(2)); } +/// Regression test for double-free in `try_recv`. +/// +/// `try_recv` uses `ptr::read` (now `ptr::replace`) to extract data from the +/// successor node. Without replacing the source with `None`, the node's +/// `Option` discriminant still says `Some` when kovan's reclamation later +/// frees the node, causing `Node::drop` to double-free the inner value. +/// +/// This test uses `String` (a heap-allocated, Drop type) to detect double-free. +/// With Copy types like `i32`, the double-free silently corrupts the heap but +/// rarely crashes; with String it reliably triggers "double free or corruption". +#[test] +#[cfg_attr(miri, ignore)] +fn test_no_double_free_on_recv() { + // Run on a dedicated thread so Handle cleanup (which calls try_retire and + // free_batch_list) happens deterministically at thread exit. + let handle = thread::spawn(|| { + for _ in 0..200 { + let (tx, rx) = unbounded::(); + tx.send("hello".to_string()); + tx.send("world".to_string()); + let a = rx.try_recv().unwrap(); + let b = rx.try_recv().unwrap(); + assert_eq!(a, "hello"); + assert_eq!(b, "world"); + // Channel drop retires remaining nodes (sentinel). + // Consumed nodes are retired on next try_recv or drop. + // If ptr::read left stale Some(value), kovan's destructor + // would double-free the strings here. + } + }); + handle.join().unwrap(); +} + +/// Test that dropping a channel with unconsumed heap-allocated messages +/// doesn't double-free. Channel::drop retires all remaining nodes; consumed +/// nodes must have their data cleared to None. +#[test] +#[cfg_attr(miri, ignore)] +fn test_no_double_free_on_drop_with_pending() { + let handle = thread::spawn(|| { + for _ in 0..200 { + let (tx, rx) = unbounded::(); + for i in 0..5 { + tx.send(format!("msg-{i}")); + } + // Consume only some messages — the rest are pending at drop time. + let _ = rx.try_recv(); + let _ = rx.try_recv(); + // Drop channel with 3 pending messages + 2 consumed sentinel nodes. + } + }); + handle.join().unwrap(); +} + /// Concurrent send/recv must not violate aliasing rules. /// /// Before the fix, `try_recv()` created `&mut (*next.as_raw()).data` — a mutable diff --git a/kovan-map/src/hashmap.rs b/kovan-map/src/hashmap.rs index d1b8acf..811c423 100644 --- a/kovan-map/src/hashmap.rs +++ b/kovan-map/src/hashmap.rs @@ -598,6 +598,10 @@ impl Drop for HashMap { } } } + + // Flush nodes previously retired by concurrent operations + drop(guard); + kovan::flush(); } } diff --git a/kovan-map/src/hopscotch.rs b/kovan-map/src/hopscotch.rs index 9847dec..92ea726 100644 --- a/kovan-map/src/hopscotch.rs +++ b/kovan-map/src/hopscotch.rs @@ -225,6 +225,10 @@ where /// Helper for get_or_insert logic. fn insert_impl(&self, key: K, value: V, only_if_absent: bool) -> Option { let hash = self.hasher.hash_one(&key); + // Track whether we already incremented count for a new insert across + // retry iterations. Prevents both under-count (which causes cascading + // resizes on Windows) and double-count. + let mut counted = false; loop { self.wait_for_resize(); @@ -249,18 +253,24 @@ where &guard, ) { InsertResult::Success(old_val) => { - // CRITICAL FIX: If a resize started while we were inserting, our update + // Count new inserts immediately so that concurrent removes + // cannot decrement count below the true entry count. + if old_val.is_none() && !counted { + self.count.fetch_add(1, Ordering::Relaxed); + counted = true; + } + + // If a resize started while we were inserting, our update // might have been missed by the migration. // We must retry to ensure we write to the new table. - // We check BOTH the resizing flag (active resize) AND the table pointer (completed resize). if self.resizing.load(Ordering::SeqCst) || self.table.load(Ordering::SeqCst, &guard) != table_ptr { continue; } - if old_val.is_none() { - let new_count = self.count.fetch_add(1, Ordering::Relaxed) + 1; + if counted { + let new_count = self.count.load(Ordering::Relaxed); let current_capacity = table.capacity; let load_factor = new_count as f64 / current_capacity as f64; @@ -289,21 +299,27 @@ where /// Returns the value corresponding to the key, or inserts the given value if the key is not present. /// - /// When multiple threads call this concurrently for the same key, all callers - /// are guaranteed to receive the same value (the one visible in the map). + /// When multiple threads call this concurrently for the same key (without + /// concurrent removes), all callers receive the same value. pub fn get_or_insert(&self, key: K, value: V) -> V { // Fast path: key already exists — no clone, no insert. if let Some(v) = self.get(&key) { return v; } - // Slow path: try to insert, then read back for concurrent consistency. - // The hop_info update is non-atomic with the slot CAS, so two threads can - // both "win" the insert. Reading back ensures all callers agree on one value. - // - // Note: `expect` has zero overhead vs `unwrap` on the success path — the - // static string literal is only materialized in the panic (unreachable) path. - let _ = self.insert_impl(key.clone(), value, true); - self.get(&key).expect("key was just inserted") + // Slow path: insert_if_absent and use the return value directly. + // We must NOT do insert-then-get because a concurrent remove between + // the two operations would cause get to return None. + let key2 = key.clone(); + match self.insert_impl(key, value.clone(), true) { + None => { + // We inserted, but concurrent inserts may have also placed + // the same key at a different offset (the CAS-then-hop-bit + // window allows duplicates). Re-get returns the canonical + // (lowest-offset) entry so every caller agrees on one value. + self.get(&key2).unwrap_or(value) + } + Some(existing) => existing, // Key already existed + } } /// Insert a key-value pair only if the key does not exist. @@ -363,15 +379,25 @@ where unsafe { retire(entry_ptr.as_raw()) }; - let new_count = self.count.fetch_sub(1, Ordering::Relaxed) - 1; - let current_capacity = table.capacity; - let load_factor = new_count as f64 / current_capacity as f64; - - if load_factor < SHRINK_THRESHOLD - && current_capacity > MIN_CAPACITY - { - drop(guard); - self.try_resize(current_capacity / 2); + // Saturating decrement: prevent count from wrapping + // to usize::MAX which would trigger catastrophic + // cascading resizes. + if let Ok(prev) = self.count.fetch_update( + Ordering::Relaxed, + Ordering::Relaxed, + |c| c.checked_sub(1), + ) { + let new_count = prev - 1; + let current_capacity = table.capacity; + let load_factor = + new_count as f64 / current_capacity as f64; + + if load_factor < SHRINK_THRESHOLD + && current_capacity > MIN_CAPACITY + { + drop(guard); + self.try_resize(current_capacity / 2); + } } return Some(old_value); @@ -957,6 +983,11 @@ impl Drop for HopscotchMap { unsafe { drop(Box::from_raw(table_ptr.as_raw())); } + + // Flush nodes previously retired by concurrent operations (insert/remove/resize) + // to prevent use-after-free during process teardown. + drop(guard); + kovan::flush(); } } @@ -1077,4 +1108,42 @@ mod tests { } } } + + /// Regression: get_or_insert must not panic when a concurrent remove + /// deletes the key between the internal insert and the return. + #[test] + fn test_hopscotch_get_or_insert_concurrent_remove() { + use alloc::sync::Arc; + extern crate std; + use std::sync::Barrier; + use std::thread; + + let map = Arc::new(HopscotchMap::::with_capacity(64)); + let barrier = Arc::new(Barrier::new(8)); + + let handles: Vec<_> = (0..8u64) + .map(|tid| { + let map = map.clone(); + let barrier = barrier.clone(); + thread::spawn(move || { + barrier.wait(); + for i in 0..5000u64 { + let key = i % 32; // Small key space forces heavy contention + if tid % 2 == 0 { + // Half the threads do get_or_insert + let _ = map.get_or_insert(key, tid * 1000 + i); + } else { + // Other half remove + let _ = map.remove(&key); + } + } + }) + }) + .collect(); + + for h in handles { + h.join() + .expect("Thread panicked during get_or_insert/remove race"); + } + } } diff --git a/kovan-mvcc/Cargo.toml b/kovan-mvcc/Cargo.toml index a210d52..6860a86 100644 --- a/kovan-mvcc/Cargo.toml +++ b/kovan-mvcc/Cargo.toml @@ -19,6 +19,7 @@ path = "src/lib.rs" uuid = { version = "1.18", features = ["v4"] } kovan = { workspace = true } kovan-map = { workspace = true } +parking_lot = "0.12" [dev-dependencies] rand = "0.9" diff --git a/kovan-mvcc/src/backoff.rs b/kovan-mvcc/src/backoff.rs new file mode 100644 index 0000000..dbdcf44 --- /dev/null +++ b/kovan-mvcc/src/backoff.rs @@ -0,0 +1,32 @@ +/// Action to take after a backoff attempt +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum BackoffAction { + /// Retry the operation + Retry, + /// Yield the current thread and retry + Yield, + /// Abort the operation + Abort, +} + +/// Pluggable backoff strategy for lock conflicts during reads +pub trait BackoffStrategy: Send + Sync { + /// Determine the action to take for a given attempt number (0-indexed) + fn backoff(&self, attempt: u32) -> BackoffAction; +} + +/// Default backoff: retry with yield for first 3 attempts, then short sleep, then abort at 8. +pub struct DefaultBackoff; + +impl BackoffStrategy for DefaultBackoff { + fn backoff(&self, attempt: u32) -> BackoffAction { + match attempt { + 0..3 => BackoffAction::Yield, + 3..8 => { + std::thread::sleep(std::time::Duration::from_micros(100 << (attempt - 3))); + BackoffAction::Retry + } + _ => BackoffAction::Abort, + } + } +} diff --git a/kovan-mvcc/src/error.rs b/kovan-mvcc/src/error.rs new file mode 100644 index 0000000..cff12d0 --- /dev/null +++ b/kovan-mvcc/src/error.rs @@ -0,0 +1,68 @@ +use std::fmt; + +/// Typed errors for MVCC operations +#[derive(Debug, Clone)] +pub enum MvccError { + /// Another transaction holds a lock on the key + LockConflict { key: String, holder_txn: u128 }, + /// A write conflict was detected (another txn committed after our start_ts) + WriteConflict { key: String, conflicting_ts: u64 }, + /// A rollback record exists for the key at or after our start_ts + RollbackRecord { key: String }, + /// Primary lock is missing during commit + PrimaryLockMissing { key: String }, + /// Primary lock belongs to a different transaction + PrimaryLockMismatch, + /// Serialization failure (SSI): a concurrent transaction modified a key we read + SerializationFailure { key: String, conflicting_ts: u64 }, + /// Storage layer error + StorageError(String), +} + +impl fmt::Display for MvccError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + MvccError::LockConflict { key, holder_txn } => { + write!( + f, + "Lock conflict on key '{}' held by txn {}", + key, holder_txn + ) + } + MvccError::WriteConflict { + key, + conflicting_ts, + } => { + write!( + f, + "Write conflict on key '{}' at ts {}", + key, conflicting_ts + ) + } + MvccError::RollbackRecord { key } => { + write!(f, "Rollback record exists for key '{}'", key) + } + MvccError::PrimaryLockMissing { key } => { + write!(f, "Primary lock missing for key '{}'", key) + } + MvccError::PrimaryLockMismatch => { + write!(f, "Primary lock mismatch") + } + MvccError::SerializationFailure { + key, + conflicting_ts, + } => { + write!( + f, + "Serialization failure: key '{}' was modified at ts {} by concurrent transaction", + key, conflicting_ts + ) + } + MvccError::StorageError(msg) => { + write!(f, "Storage error: {}", msg) + } + } + } +} + +impl std::error::Error for MvccError {} diff --git a/kovan-mvcc/src/lib.rs b/kovan-mvcc/src/lib.rs index 3ab1655..ee8f394 100644 --- a/kovan-mvcc/src/lib.rs +++ b/kovan-mvcc/src/lib.rs @@ -33,12 +33,17 @@ //! assert_eq!(val, b"value1"); //! ``` +pub mod backoff; +pub mod error; mod lock_table; pub mod percolator; pub mod storage; mod timestamp_oracle; // Export KovanMVCC and Txn directly from percolator +pub use crate::backoff::{BackoffAction, BackoffStrategy, DefaultBackoff}; +pub use crate::error::MvccError; pub use crate::lock_table::{LockInfo, LockTable, LockType}; -pub use crate::percolator::{KovanMVCC, Txn}; +pub use crate::percolator::{ActiveTxnRegistry, IsolationLevel, KovanMVCC, Txn}; +pub use crate::storage::{InMemoryStorage, Storage, Value, WriteInfo, WriteKind}; pub use crate::timestamp_oracle::{LocalTimestampOracle, MockTimestampOracle, TimestampOracle}; diff --git a/kovan-mvcc/src/lock_table.rs b/kovan-mvcc/src/lock_table.rs index 5971275..060457f 100644 --- a/kovan-mvcc/src/lock_table.rs +++ b/kovan-mvcc/src/lock_table.rs @@ -15,6 +15,7 @@ pub enum LockType { Delete, } +use crate::error::MvccError; use kovan_map::HashMap; /// Lock Table @@ -40,7 +41,7 @@ impl Default for LockTable { impl LockTable { /// Try to acquire a lock on a key /// Returns Ok(()) if successful, Err if key is already locked - pub fn try_lock(&self, key: &str, lock_info: LockInfo) -> Result<(), String> { + pub fn try_lock(&self, key: &str, lock_info: LockInfo) -> Result<(), MvccError> { match self .locks .insert_if_absent(key.to_string(), lock_info.clone()) @@ -52,7 +53,10 @@ impl LockTable { self.locks.insert(key.to_string(), lock_info); Ok(()) } else { - Err(format!("Key {} is locked", key)) + Err(MvccError::LockConflict { + key: key.to_string(), + holder_txn: existing.txn_id, + }) } } } diff --git a/kovan-mvcc/src/percolator.rs b/kovan-mvcc/src/percolator.rs index 04da7bb..dd5f07b 100644 --- a/kovan-mvcc/src/percolator.rs +++ b/kovan-mvcc/src/percolator.rs @@ -1,13 +1,111 @@ +use crate::backoff::{BackoffAction, BackoffStrategy, DefaultBackoff}; +use crate::error::MvccError; use crate::lock_table::{LockInfo, LockType}; use crate::storage::{InMemoryStorage, Storage, Value, WriteInfo, WriteKind}; use crate::timestamp_oracle::{LocalTimestampOracle, TimestampOracle}; use kovan_map::HopscotchMap; use std::sync::Arc; +/// Registry of active transactions, used for GC watermark computation. +pub struct ActiveTxnRegistry { + txns: kovan_map::HashMap, // txn_id -> start_ts +} + +impl ActiveTxnRegistry { + pub fn new() -> Self { + Self { + txns: kovan_map::HashMap::new(), + } + } + + /// Register a new active transaction. + pub fn register(&self, txn_id: u128, start_ts: u64) { + self.txns.insert(txn_id, start_ts); + } + + /// Unregister a transaction (on commit, rollback, or drop). + pub fn unregister(&self, txn_id: u128) { + self.txns.remove(&txn_id); + } + + /// Returns the minimum start_ts across all active transactions. + /// Returns None if no transactions are active. + pub fn min_active_ts(&self) -> Option { + let mut min = None; + for (_, ts) in self.txns.iter() { + match min { + None => min = Some(ts), + Some(current) if ts < current => min = Some(ts), + _ => {} + } + } + min + } + + /// Returns the minimum start_ts, ignoring transactions older than max_age. + /// This prevents long-running OLAP queries from blocking GC (anti-vicious-cycle). + pub fn min_active_ts_bounded(&self, current_ts: u64, max_age: u64) -> Option { + let cutoff = current_ts.saturating_sub(max_age); + let mut min = None; + for (_, ts) in self.txns.iter() { + if ts >= cutoff { + match min { + None => min = Some(ts), + Some(current) if ts < current => min = Some(ts), + _ => {} + } + } + } + min + } + + /// Returns the number of active transactions. + pub fn len(&self) -> usize { + self.txns.len() + } + + /// Returns true if there are no active transactions. + pub fn is_empty(&self) -> bool { + self.txns.is_empty() + } +} + +impl Default for ActiveTxnRegistry { + fn default() -> Self { + Self::new() + } +} + +/// Transaction isolation level. +#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)] +pub enum IsolationLevel { + /// Each statement sees a fresh snapshot (per-read freshness). PostgreSQL default. + #[default] + ReadCommitted, + /// Entire transaction uses one snapshot (standard SI behavior). + RepeatableRead, + /// SI + detection of serialization anomalies (write skew prevention). + Serializable, +} + /// KovanMVCC (Percolator-style) pub struct KovanMVCC { + // NOTE: Debug manually implemented below storage: Arc, ts_oracle: Arc, + backoff: Arc, + active_txns: Arc, + /// Serializes SSI validation + commit for Serializable transactions. + /// Ensures that when one Serializable txn commits, the next one sees it. + ssi_commit_lock: Arc>, +} + +impl std::fmt::Debug for KovanMVCC { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("KovanMVCC") + .field("active_txns", &self.active_txns.len()) + .finish() + } } impl KovanMVCC { @@ -27,6 +125,9 @@ impl KovanMVCC { Self { storage: Arc::new(InMemoryStorage::new()), ts_oracle, + backoff: Arc::new(DefaultBackoff), + active_txns: Arc::new(ActiveTxnRegistry::new()), + ssi_commit_lock: Arc::new(parking_lot::Mutex::new(())), } } @@ -34,20 +135,79 @@ impl KovanMVCC { Self { storage, ts_oracle: Arc::new(LocalTimestampOracle::new()), + backoff: Arc::new(DefaultBackoff), + active_txns: Arc::new(ActiveTxnRegistry::new()), + ssi_commit_lock: Arc::new(parking_lot::Mutex::new(())), } } + pub fn with_storage_and_oracle( + storage: Arc, + ts_oracle: Arc, + ) -> Self { + Self { + storage, + ts_oracle, + backoff: Arc::new(DefaultBackoff), + active_txns: Arc::new(ActiveTxnRegistry::new()), + ssi_commit_lock: Arc::new(parking_lot::Mutex::new(())), + } + } + + /// Set a custom backoff strategy. + pub fn set_backoff(&mut self, backoff: Arc) { + self.backoff = backoff; + } + + /// Get a reference to the active transaction registry. + pub fn active_txns(&self) -> &Arc { + &self.active_txns + } + + /// Get a reference to the timestamp oracle. + pub fn ts_oracle(&self) -> &Arc { + &self.ts_oracle + } + + /// Get a reference to the storage backend. + pub fn storage(&self) -> &Arc { + &self.storage + } + + /// Begin a transaction with the default isolation level (ReadCommitted). + /// + /// This matches the PostgreSQL default. Use `begin_with_isolation()` for + /// other levels (e.g., RepeatableRead for snapshot isolation). pub fn begin(&self) -> Txn { + self.begin_with_isolation(IsolationLevel::ReadCommitted) + } + + pub fn begin_with_isolation(&self, isolation_level: IsolationLevel) -> Txn { let start_ts = self.ts_oracle.get_timestamp(); + let txn_id = uuid::Uuid::new_v4().as_u128(); + + // Register in active transaction registry + self.active_txns.register(txn_id, start_ts); + + let read_set = if isolation_level == IsolationLevel::Serializable { + Some(HopscotchMap::new()) + } else { + None + }; Txn { - txn_id: uuid::Uuid::new_v4().as_u128(), + txn_id, start_ts, storage: self.storage.clone(), ts_oracle: self.ts_oracle.clone(), + backoff: self.backoff.clone(), + active_txns: self.active_txns.clone(), writes: HopscotchMap::new(), primary_key: None, committed: false, + isolation_level, + read_set, + ssi_commit_lock: self.ssi_commit_lock.clone(), } } } @@ -57,28 +217,72 @@ pub struct Txn { start_ts: u64, storage: Arc, ts_oracle: Arc, + backoff: Arc, + active_txns: Arc, /// Buffered writes: key -> (lock_type, value) writes: HopscotchMap)>, /// Primary key for 2PC primary_key: Option, /// Whether this transaction has been committed (prevents Drop from rolling back) committed: bool, + /// Isolation level for this transaction + isolation_level: IsolationLevel, + /// Read-set for Serializable: keys read during the transaction + read_set: Option>, + /// SSI commit lock (shared with KovanMVCC), serializes Serializable commits + ssi_commit_lock: Arc>, } impl Txn { + /// Get the transaction ID. + pub fn txn_id(&self) -> u128 { + self.txn_id + } + + /// Get the start timestamp. + pub fn start_ts(&self) -> u64 { + self.start_ts + } + + /// Get the isolation level. + pub fn isolation_level(&self) -> IsolationLevel { + self.isolation_level + } + /// Get operation (Snapshot Read) + /// + /// Behavior varies by isolation level: + /// - ReadCommitted: uses a fresh timestamp per read call + /// - RepeatableRead: uses start_ts (standard SI) + /// - Serializable: uses start_ts + tracks key in read_set pub fn read(&self, key: &str) -> Option> { - let mut attempts = 0; - loop { - attempts += 1; + // 0. Check local write buffer first (read-your-own-writes, even before prewrite) + if let Some((lock_type, value_opt)) = self.writes.get(key) { + // Track in read_set for Serializable even on local hits + if let Some(ref read_set) = self.read_set { + read_set.insert(key.to_string(), ()); + } + return match lock_type { + LockType::Put => value_opt.map(|arc| (*arc).clone()), + LockType::Delete => None, + }; + } + + // Determine read timestamp based on isolation level + let read_ts = match self.isolation_level { + IsolationLevel::ReadCommitted => self.ts_oracle.get_timestamp(), + IsolationLevel::RepeatableRead | IsolationLevel::Serializable => self.start_ts, + }; - // 1. Check for locks with start_ts <= self.start_ts + let mut attempts = 0u32; + loop { + // 1. Check for locks with start_ts <= read_ts if let Some(lock) = self.storage.get_lock(key) - && lock.start_ts <= self.start_ts + && lock.start_ts <= read_ts { // Key is locked by an active transaction that started before us. if lock.txn_id == self.txn_id { - // Read-your-own-writes + // Read-your-own-writes (already prewritten) return self .writes .get(key) @@ -86,23 +290,37 @@ impl Txn { } // Locked by another transaction. - // Backoff and retry. - if attempts < 5 { - std::thread::sleep(std::time::Duration::from_millis(1)); - continue; + // Use pluggable backoff strategy. + match self.backoff.backoff(attempts) { + BackoffAction::Retry => { + attempts += 1; + continue; + } + BackoffAction::Yield => { + std::thread::yield_now(); + attempts += 1; + continue; + } + BackoffAction::Abort => { + eprintln!( + "[READ_CONFLICT] key={} locked by txn={} at ts={}", + key, lock.txn_id, lock.start_ts + ); + // Track in read_set for Serializable even on misses + if let Some(ref read_set) = self.read_set { + read_set.insert(key.to_string(), ()); + } + return None; + } } - - eprintln!( - "[READ_CONFLICT] key={} locked by txn={} at ts={}", - key, lock.txn_id, lock.start_ts - ); - return None; // Or Err("Locked") } - // 2. Find latest non-rollback write in CF_WRITE with commit_ts <= self.start_ts - if let Some((_commit_ts, write_info)) = - self.storage.get_latest_commit(key, self.start_ts) - { + // 2. Find latest non-rollback write in CF_WRITE with commit_ts <= read_ts + if let Some((_commit_ts, write_info)) = self.storage.get_latest_commit(key, read_ts) { + // Track in read_set for Serializable + if let Some(ref read_set) = self.read_set { + read_set.insert(key.to_string(), ()); + } match write_info.kind { WriteKind::Put => { // 3. Retrieve data from CF_DATA using start_ts from WriteInfo @@ -119,11 +337,15 @@ impl Txn { } } + // Track in read_set for Serializable even on misses (phantom prevention) + if let Some(ref read_set) = self.read_set { + read_set.insert(key.to_string(), ()); + } return None; } } - pub fn write(&mut self, key: &str, value: Vec) -> Result<(), String> { + pub fn write(&mut self, key: &str, value: Vec) -> Result<(), MvccError> { self.writes .insert(key.to_string(), (LockType::Put, Some(Arc::new(value)))); if self.primary_key.is_none() { @@ -132,7 +354,7 @@ impl Txn { Ok(()) } - pub fn delete(&mut self, key: &str) -> Result<(), String> { + pub fn delete(&mut self, key: &str) -> Result<(), MvccError> { self.writes .insert(key.to_string(), (LockType::Delete, None)); if self.primary_key.is_none() { @@ -141,9 +363,12 @@ impl Txn { Ok(()) } - pub fn commit(mut self) -> Result { + pub fn commit(mut self) -> Result { self.committed = true; + // Unregister from active transactions + self.active_txns.unregister(self.txn_id); + if self.writes.is_empty() { return Ok(self.start_ts); } @@ -151,15 +376,41 @@ impl Txn { let primary_key = self .primary_key .as_ref() - .ok_or_else(|| "No primary key".to_string())? + .ok_or_else(|| MvccError::StorageError("No primary key".to_string()))? .clone(); - // Phase 1: Prewrite + // Phase 1: Prewrite (lock acquisition + write-write conflict check) if let Err(e) = self.prewrite(&primary_key) { self.rollback(); return Err(e); } + // For Serializable transactions: hold SSI commit lock during validation + commit. + // This serializes SSI commits so that when T1 commits, T2 sees it in its validation. + let _ssi_guard = if self.isolation_level == IsolationLevel::Serializable { + Some(self.ssi_commit_lock.lock()) + } else { + None + }; + + // SSI validation: check read-set for concurrent commits + if self.isolation_level == IsolationLevel::Serializable + && let Some(ref read_set) = self.read_set + { + for (key, _) in read_set.iter() { + if let Some((commit_ts, write_info)) = self.storage.get_latest_write(&key, u64::MAX) + && write_info.kind != WriteKind::Rollback + && commit_ts > self.start_ts + { + self.rollback(); + return Err(MvccError::SerializationFailure { + key: key.clone(), + conflicting_ts: commit_ts, + }); + } + } + } + // Get commit timestamp let commit_ts = self.ts_oracle.get_timestamp(); @@ -170,10 +421,11 @@ impl Txn { } self.commit_secondaries(&primary_key, commit_ts); + // _ssi_guard dropped here, releasing the lock Ok(commit_ts) } - fn prewrite(&mut self, primary_key: &str) -> Result<(), String> { + fn prewrite(&mut self, primary_key: &str) -> Result<(), MvccError> { // Sort keys to prevent deadlocks/livelocks let mut keys: Vec<_> = self.writes.keys().collect(); keys.sort(); @@ -218,17 +470,19 @@ impl Txn { for acquired_key in &acquired_locks { self.storage.delete_lock(acquired_key); } - return Err(format!( - "Prewrite rejected: rollback record exists for key {}", - key - )); + return Err(MvccError::RollbackRecord { + key: key.to_string(), + }); } if write_info.kind != WriteKind::Rollback && commit_ts >= self.start_ts { // Write conflict for acquired_key in &acquired_locks { self.storage.delete_lock(acquired_key); } - return Err(format!("Write conflict on key {}", key)); + return Err(MvccError::WriteConflict { + key: key.to_string(), + conflicting_ts: commit_ts, + }); } } } @@ -243,14 +497,16 @@ impl Txn { Ok(()) } - fn commit_primary(&self, primary_key: &str, commit_ts: u64) -> Result<(), String> { - let lock = self - .storage - .get_lock(primary_key) - .ok_or_else(|| format!("Primary lock missing for {}", primary_key))?; + fn commit_primary(&self, primary_key: &str, commit_ts: u64) -> Result<(), MvccError> { + let lock = + self.storage + .get_lock(primary_key) + .ok_or_else(|| MvccError::PrimaryLockMissing { + key: primary_key.to_string(), + })?; if lock.txn_id != self.txn_id { - return Err("Primary lock mismatch".to_string()); + return Err(MvccError::PrimaryLockMismatch); } // Write to CF_WRITE @@ -303,21 +559,22 @@ impl Txn { fn rollback(&self) { for (key, _) in &self.writes { - // Only remove our own locks + // Only clean up keys where we actually hold the lock. + // A transaction that failed lock acquisition has no data to rollback + // and must not write rollback records that could poison concurrent prewrites. if let Some(lock) = self.storage.get_lock(&key) && lock.txn_id == self.txn_id { self.storage.delete_lock(&key); - // Also delete data we wrote self.storage.delete_data(&key, self.start_ts); - } - // Write a Rollback record to CF_WRITE to prevent future prewrites at this start_ts - let rollback_info = WriteInfo { - start_ts: self.start_ts, - kind: WriteKind::Rollback, - }; - self.storage.put_write(&key, self.start_ts, rollback_info); + // Write a Rollback record to prevent resurrection of this start_ts + let rollback_info = WriteInfo { + start_ts: self.start_ts, + kind: WriteKind::Rollback, + }; + self.storage.put_write(&key, self.start_ts, rollback_info); + } } } } @@ -325,6 +582,8 @@ impl Txn { impl Drop for Txn { fn drop(&mut self) { if !self.committed { + // Unregister from active transactions + self.active_txns.unregister(self.txn_id); self.rollback(); } } diff --git a/kovan-mvcc/src/storage.rs b/kovan-mvcc/src/storage.rs index 903b91e..5cd137a 100644 --- a/kovan-mvcc/src/storage.rs +++ b/kovan-mvcc/src/storage.rs @@ -1,3 +1,4 @@ +use crate::error::MvccError; use crate::lock_table::LockInfo; use std::collections::BTreeMap; @@ -22,7 +23,7 @@ pub struct WriteInfo { pub type Value = Arc>; use kovan_map::HashMap; -use std::sync::Mutex; +use parking_lot::Mutex; /// Storage Trait /// Defines the interface for the underlying storage engine. @@ -30,7 +31,7 @@ pub trait Storage: Send + Sync { /// Get a lock for a key fn get_lock(&self, key: &str) -> Option; /// Acquire a lock for a key - fn put_lock(&self, key: &str, lock: LockInfo) -> Result<(), String>; + fn put_lock(&self, key: &str, lock: LockInfo) -> Result<(), MvccError>; /// Release a lock for a key fn delete_lock(&self, key: &str); @@ -47,6 +48,23 @@ pub trait Storage: Send + Sync { fn put_data(&self, key: &str, start_ts: u64, value: Value); /// Delete data for a key at a specific start_ts fn delete_data(&self, key: &str, start_ts: u64); + + /// GC: Remove write records with commit_ts < watermark, keeping the latest visible version. + /// Returns the number of records removed. + fn gc_writes(&self, _key: &str, _watermark: u64) -> usize { + 0 + } + + /// GC: Remove data versions with start_ts < watermark that are no longer referenced. + /// Returns the number of versions removed. + fn gc_data(&self, _key: &str, _watermark: u64) -> usize { + 0 + } + + /// GC: Scan all keys that have write records (for incremental GC cursor). + fn scan_write_keys(&self) -> Vec { + vec![] + } } /// In-Memory Storage Implementation @@ -83,7 +101,7 @@ impl Storage for InMemoryStorage { self.locks.get(key) } - fn put_lock(&self, key: &str, lock: LockInfo) -> Result<(), String> { + fn put_lock(&self, key: &str, lock: LockInfo) -> Result<(), MvccError> { match self.locks.insert_if_absent(key.to_string(), lock.clone()) { None => Ok(()), // Acquired Some(existing) => { @@ -92,7 +110,10 @@ impl Storage for InMemoryStorage { self.locks.insert(key.to_string(), lock); Ok(()) } else { - Err(format!("Key {} is already locked", key)) + Err(MvccError::LockConflict { + key: key.to_string(), + holder_txn: existing.txn_id, + }) } } } @@ -116,14 +137,14 @@ impl Storage for InMemoryStorage { } }; - let mut map = map_mutex.lock().unwrap(); + let mut map = map_mutex.lock(); map.insert(commit_ts, info); } /// Find the latest write with commit_ts <= ts fn get_latest_write(&self, key: &str, ts: u64) -> Option<(u64, WriteInfo)> { if let Some(map_mutex) = self.writes.get(key) { - let map = map_mutex.lock().unwrap(); + let map = map_mutex.lock(); // range(..=ts) gives all entries with key <= ts // next_back() gives the largest key <= ts map.range(..=ts).next_back().map(|(k, v)| (*k, v.clone())) @@ -135,7 +156,7 @@ impl Storage for InMemoryStorage { /// Find the latest Put or Delete write with commit_ts <= ts, skipping Rollback records fn get_latest_commit(&self, key: &str, ts: u64) -> Option<(u64, WriteInfo)> { if let Some(map_mutex) = self.writes.get(key) { - let map = map_mutex.lock().unwrap(); + let map = map_mutex.lock(); for (k, v) in map.range(..=ts).rev() { if v.kind != WriteKind::Rollback { return Some((*k, v.clone())); @@ -158,13 +179,13 @@ impl Storage for InMemoryStorage { } }; - let mut map = map_mutex.lock().unwrap(); + let mut map = map_mutex.lock(); map.insert(start_ts, value); } fn get_data(&self, key: &str, start_ts: u64) -> Option { if let Some(map_mutex) = self.data.get(key) { - let map = map_mutex.lock().unwrap(); + let map = map_mutex.lock(); map.get(&start_ts).cloned() } else { None @@ -173,8 +194,54 @@ impl Storage for InMemoryStorage { fn delete_data(&self, key: &str, start_ts: u64) { if let Some(map_mutex) = self.data.get(key) { - let mut map = map_mutex.lock().unwrap(); + let mut map = map_mutex.lock(); map.remove(&start_ts); } } + + fn gc_writes(&self, key: &str, watermark: u64) -> usize { + if let Some(map_mutex) = self.writes.get(key) { + let mut map = map_mutex.lock(); + // Find the latest version at or below watermark + let latest_visible = map.range(..=watermark).next_back().map(|(k, _)| *k); + if let Some(keep_ts) = latest_visible { + // Remove all entries strictly below keep_ts + let to_remove: Vec = map.range(..keep_ts).map(|(k, _)| *k).collect(); + let count = to_remove.len(); + for ts in to_remove { + map.remove(&ts); + } + count + } else { + 0 + } + } else { + 0 + } + } + + fn gc_data(&self, key: &str, watermark: u64) -> usize { + if let Some(map_mutex) = self.data.get(key) { + let mut map = map_mutex.lock(); + // Find the latest version at or below watermark + let latest_visible = map.range(..=watermark).next_back().map(|(k, _)| *k); + if let Some(keep_ts) = latest_visible { + // Remove all entries strictly below keep_ts + let to_remove: Vec = map.range(..keep_ts).map(|(k, _)| *k).collect(); + let count = to_remove.len(); + for ts in to_remove { + map.remove(&ts); + } + count + } else { + 0 + } + } else { + 0 + } + } + + fn scan_write_keys(&self) -> Vec { + self.writes.keys().collect() + } } diff --git a/kovan-mvcc/tests/isolation.rs b/kovan-mvcc/tests/isolation.rs index cf62e18..31705bd 100644 --- a/kovan-mvcc/tests/isolation.rs +++ b/kovan-mvcc/tests/isolation.rs @@ -1,4 +1,4 @@ -use kovan_mvcc::KovanMVCC; +use kovan_mvcc::{IsolationLevel, KovanMVCC}; use std::sync::Arc; use std::thread; @@ -11,8 +11,8 @@ fn test_snapshot_isolation_readers_ignore_new_commits() { t0.write("x", b"initial".to_vec()).unwrap(); t0.commit().unwrap(); - // 2. Start Long-Running Reader (Snapshot T1) - let reader_txn = db.begin(); + // 2. Start Long-Running Reader with RepeatableRead (Snapshot T1) + let reader_txn = db.begin_with_isolation(IsolationLevel::RepeatableRead); let val_start = reader_txn.read("x").unwrap(); assert_eq!(val_start, b"initial"); diff --git a/kovan-mvcc/tests/isolation_levels.rs b/kovan-mvcc/tests/isolation_levels.rs new file mode 100644 index 0000000..e01c87e --- /dev/null +++ b/kovan-mvcc/tests/isolation_levels.rs @@ -0,0 +1,635 @@ +//! Tests for true PostgreSQL isolation levels in kovan-mvcc. +//! +//! - ReadCommitted: per-read fresh snapshot +//! - RepeatableRead: transaction-wide snapshot (standard SI) +//! - Serializable: SI + read-set validation (SSI) + +use kovan_mvcc::{IsolationLevel, KovanMVCC, MvccError}; +use std::sync::atomic::{AtomicBool, Ordering}; +use std::sync::{Arc, Barrier}; + +// =========================================================================== +// IsolationLevel enum tests +// =========================================================================== + +#[test] +fn test_isolation_level_default() { + // PostgreSQL default is Read Committed + assert_eq!(IsolationLevel::default(), IsolationLevel::ReadCommitted); +} + +#[test] +fn test_isolation_level_variants() { + let rc = IsolationLevel::ReadCommitted; + let rr = IsolationLevel::RepeatableRead; + let sr = IsolationLevel::Serializable; + assert_ne!(rc, rr); + assert_ne!(rr, sr); + assert_ne!(rc, sr); +} + +// =========================================================================== +// READ COMMITTED tests +// =========================================================================== + +#[test] +fn test_rc_sees_latest_committed_per_read() { + // T1 writes "v1" commits, T2(RC) reads (sees "v1"), + // T3 writes "v2" commits, T2 reads again (sees "v2") + let db = KovanMVCC::new(); + + let mut t1 = db.begin(); + t1.write("key", b"v1".to_vec()).unwrap(); + t1.commit().unwrap(); + + let t2 = db.begin_with_isolation(IsolationLevel::ReadCommitted); + assert_eq!(t2.read("key").unwrap(), b"v1"); + + let mut t3 = db.begin(); + t3.write("key", b"v2".to_vec()).unwrap(); + t3.commit().unwrap(); + + // RC: second read sees the new committed value + assert_eq!(t2.read("key").unwrap(), b"v2"); +} + +#[test] +fn test_rc_no_dirty_reads() { + // T1 writes but doesn't commit, T2(RC) doesn't see T1's uncommitted write + let db = KovanMVCC::new(); + + let barrier = Arc::new(Barrier::new(2)); + + let db1 = Arc::new(db); + let db2 = db1.clone(); + let b1 = barrier.clone(); + let b2 = barrier.clone(); + + let writer = std::thread::spawn(move || { + let mut txn = db1.begin(); + txn.write("key", b"uncommitted".to_vec()).unwrap(); + b1.wait(); // Writer has written but NOT committed + b1.wait(); // Wait for reader to finish + drop(txn); // Rollback + }); + + let reader = std::thread::spawn(move || { + b2.wait(); // Wait for writer to write + let txn = db2.begin_with_isolation(IsolationLevel::ReadCommitted); + let result = txn.read("key"); + assert!(result.is_none(), "RC must not see uncommitted data"); + b2.wait(); + }); + + writer.join().unwrap(); + reader.join().unwrap(); +} + +#[test] +fn test_rc_non_repeatable_read_allowed() { + // Demonstrates that RC allows non-repeatable reads + let db = KovanMVCC::new(); + + let mut setup = db.begin(); + setup.write("key", b"v1".to_vec()).unwrap(); + setup.commit().unwrap(); + + let t_rc = db.begin_with_isolation(IsolationLevel::ReadCommitted); + let read1 = t_rc.read("key").unwrap(); + assert_eq!(read1, b"v1"); + + // Concurrent commit changes the value + let mut writer = db.begin(); + writer.write("key", b"v2".to_vec()).unwrap(); + writer.commit().unwrap(); + + // RC: second read may return different value (non-repeatable read) + let read2 = t_rc.read("key").unwrap(); + assert_eq!(read2, b"v2", "RC allows non-repeatable reads"); + assert_ne!(read1, read2); +} + +#[test] +fn test_rc_phantom_read_allowed() { + // New keys appearing between reads are visible to RC + let db = KovanMVCC::new(); + + let t_rc = db.begin_with_isolation(IsolationLevel::ReadCommitted); + assert!(t_rc.read("phantom").is_none()); + + // Another txn inserts the key + let mut writer = db.begin(); + writer.write("phantom", b"appeared".to_vec()).unwrap(); + writer.commit().unwrap(); + + // RC: sees the phantom + assert_eq!(t_rc.read("phantom").unwrap(), b"appeared"); +} + +#[test] +fn test_rc_write_conflict_still_detected() { + // RC doesn't bypass write-write conflict detection + let db = KovanMVCC::new(); + + let mut setup = db.begin(); + setup.write("key", b"init".to_vec()).unwrap(); + setup.commit().unwrap(); + + let barrier = Arc::new(Barrier::new(2)); + let t1_ok = Arc::new(AtomicBool::new(false)); + let t2_ok = Arc::new(AtomicBool::new(false)); + + let db = Arc::new(db); + let db1 = db.clone(); + let db2 = db.clone(); + let b1 = barrier.clone(); + let b2 = barrier.clone(); + let t1c = t1_ok.clone(); + let t2c = t2_ok.clone(); + + let h1 = std::thread::spawn(move || { + let mut txn = db1.begin_with_isolation(IsolationLevel::ReadCommitted); + txn.write("key", b"v1".to_vec()).unwrap(); + b1.wait(); + if txn.commit().is_ok() { + t1c.store(true, Ordering::SeqCst); + } + }); + + let h2 = std::thread::spawn(move || { + let mut txn = db2.begin_with_isolation(IsolationLevel::ReadCommitted); + txn.write("key", b"v2".to_vec()).unwrap(); + b2.wait(); + std::thread::yield_now(); + if txn.commit().is_ok() { + t2c.store(true, Ordering::SeqCst); + } + }); + + h1.join().unwrap(); + h2.join().unwrap(); + + let t1 = t1_ok.load(Ordering::SeqCst); + let t2 = t2_ok.load(Ordering::SeqCst); + assert!( + !(t1 && t2), + "Write-write conflict must be detected even under RC" + ); + assert!(t1 || t2, "At least one should commit"); +} + +// =========================================================================== +// REPEATABLE READ tests +// =========================================================================== + +#[test] +fn test_rr_repeatable_reads() { + // Same key returns same value throughout transaction + let db = KovanMVCC::new(); + + let mut setup = db.begin(); + setup.write("key", b"stable".to_vec()).unwrap(); + setup.commit().unwrap(); + + let t_rr = db.begin_with_isolation(IsolationLevel::RepeatableRead); + let read1 = t_rr.read("key").unwrap(); + + let mut writer = db.begin(); + writer.write("key", b"changed".to_vec()).unwrap(); + writer.commit().unwrap(); + + let read2 = t_rr.read("key").unwrap(); + assert_eq!(read1, read2, "RR must return same value for same key"); +} + +#[test] +fn test_rr_no_dirty_reads() { + let db = KovanMVCC::new(); + + let mut writer = db.begin(); + writer.write("key", b"dirty".to_vec()).unwrap(); + + let reader = db.begin_with_isolation(IsolationLevel::RepeatableRead); + assert!(reader.read("key").is_none()); + + drop(writer); +} + +#[test] +fn test_rr_phantom_prevented() { + // New keys committed after start_ts are invisible + let db = KovanMVCC::new(); + + let reader = db.begin_with_isolation(IsolationLevel::RepeatableRead); + + let mut writer = db.begin(); + writer.write("phantom", b"new".to_vec()).unwrap(); + writer.commit().unwrap(); + + assert!( + reader.read("phantom").is_none(), + "RR prevents phantom reads" + ); +} + +#[test] +fn test_rr_write_skew_allowed() { + // Classic write skew: T1 reads X writes Y, T2 reads Y writes X + // Both should commit under RR (disjoint write sets) + let db = KovanMVCC::new(); + + let mut setup = db.begin(); + setup.write("alice_oncall", b"true".to_vec()).unwrap(); + setup.write("bob_oncall", b"true".to_vec()).unwrap(); + setup.commit().unwrap(); + + let barrier = Arc::new(Barrier::new(2)); + let t1_ok = Arc::new(AtomicBool::new(false)); + let t2_ok = Arc::new(AtomicBool::new(false)); + + let db = Arc::new(db); + let db1 = db.clone(); + let db2 = db.clone(); + let b1 = barrier.clone(); + let b2 = barrier.clone(); + let t1c = t1_ok.clone(); + let t2c = t2_ok.clone(); + + let h1 = std::thread::spawn(move || { + let mut txn = db1.begin_with_isolation(IsolationLevel::RepeatableRead); + let alice = txn.read("alice_oncall").unwrap(); + let bob = txn.read("bob_oncall").unwrap(); + b1.wait(); + if alice == b"true" && bob == b"true" { + txn.write("alice_oncall", b"false".to_vec()).unwrap(); + } + b1.wait(); + if txn.commit().is_ok() { + t1c.store(true, Ordering::SeqCst); + } + }); + + let h2 = std::thread::spawn(move || { + let mut txn = db2.begin_with_isolation(IsolationLevel::RepeatableRead); + let alice = txn.read("alice_oncall").unwrap(); + let bob = txn.read("bob_oncall").unwrap(); + b2.wait(); + if alice == b"true" && bob == b"true" { + txn.write("bob_oncall", b"false".to_vec()).unwrap(); + } + b2.wait(); + if txn.commit().is_ok() { + t2c.store(true, Ordering::SeqCst); + } + }); + + h1.join().unwrap(); + h2.join().unwrap(); + + assert!( + t1_ok.load(Ordering::SeqCst) && t2_ok.load(Ordering::SeqCst), + "RR allows write skew: both should commit" + ); +} + +#[test] +fn test_begin_defaults_to_read_committed() { + // begin() produces ReadCommitted (the PostgreSQL default) + let db = KovanMVCC::new(); + + let mut setup = db.begin(); + setup.write("key", b"v1".to_vec()).unwrap(); + setup.commit().unwrap(); + + let t_begin = db.begin(); + let t_explicit = db.begin_with_isolation(IsolationLevel::ReadCommitted); + + assert_eq!(t_begin.isolation_level(), IsolationLevel::ReadCommitted); + assert_eq!(t_begin.isolation_level(), t_explicit.isolation_level()); + + // Both see committed value + assert_eq!(t_begin.read("key").unwrap(), b"v1"); + assert_eq!(t_explicit.read("key").unwrap(), b"v1"); + + // After concurrent commit, RC sees the new value (non-repeatable read) + let mut writer = db.begin(); + writer.write("key", b"v2".to_vec()).unwrap(); + writer.commit().unwrap(); + + assert_eq!(t_begin.read("key").unwrap(), b"v2"); + assert_eq!(t_explicit.read("key").unwrap(), b"v2"); +} + +// =========================================================================== +// SERIALIZABLE tests +// =========================================================================== + +#[test] +fn test_serializable_write_skew_prevented() { + // Classic write skew: T1 reads X,Y writes X. T2 reads X,Y writes Y. + // Under Serializable, one must abort with SerializationFailure. + let db = KovanMVCC::new(); + + let mut setup = db.begin(); + setup.write("alice_oncall", b"true".to_vec()).unwrap(); + setup.write("bob_oncall", b"true".to_vec()).unwrap(); + setup.commit().unwrap(); + + let barrier = Arc::new(Barrier::new(2)); + let t1_ok = Arc::new(AtomicBool::new(false)); + let t2_ok = Arc::new(AtomicBool::new(false)); + let serialization_failure = Arc::new(AtomicBool::new(false)); + + let db = Arc::new(db); + let db1 = db.clone(); + let db2 = db.clone(); + let b1 = barrier.clone(); + let b2 = barrier.clone(); + let t1c = t1_ok.clone(); + let t2c = t2_ok.clone(); + let sf1 = serialization_failure.clone(); + let sf2 = serialization_failure.clone(); + + let h1 = std::thread::spawn(move || { + let mut txn = db1.begin_with_isolation(IsolationLevel::Serializable); + let alice = txn.read("alice_oncall").unwrap(); + let bob = txn.read("bob_oncall").unwrap(); + b1.wait(); // Both have read + if alice == b"true" && bob == b"true" { + txn.write("alice_oncall", b"false".to_vec()).unwrap(); + } + b1.wait(); // Both have written + match txn.commit() { + Ok(_) => t1c.store(true, Ordering::SeqCst), + Err(MvccError::SerializationFailure { .. }) => { + sf1.store(true, Ordering::SeqCst); + } + Err(_) => {} + } + }); + + let h2 = std::thread::spawn(move || { + let mut txn = db2.begin_with_isolation(IsolationLevel::Serializable); + let alice = txn.read("alice_oncall").unwrap(); + let bob = txn.read("bob_oncall").unwrap(); + b2.wait(); // Both have read + if alice == b"true" && bob == b"true" { + txn.write("bob_oncall", b"false".to_vec()).unwrap(); + } + b2.wait(); // Both have written + match txn.commit() { + Ok(_) => t2c.store(true, Ordering::SeqCst), + Err(MvccError::SerializationFailure { .. }) => { + sf2.store(true, Ordering::SeqCst); + } + Err(_) => {} + } + }); + + h1.join().unwrap(); + h2.join().unwrap(); + + let t1 = t1_ok.load(Ordering::SeqCst); + let t2 = t2_ok.load(Ordering::SeqCst); + let sf = serialization_failure.load(Ordering::SeqCst); + + // Under Serializable: at most one commits, at least one gets SerializationFailure + assert!( + !(t1 && t2), + "Serializable must prevent write skew: both committed!" + ); + assert!( + sf, + "At least one transaction should get SerializationFailure" + ); +} + +#[test] +fn test_serializable_no_false_positive_disjoint() { + // Two transactions with completely disjoint read/write sets both commit + let db = KovanMVCC::new(); + + let mut setup = db.begin(); + setup.write("a", b"1".to_vec()).unwrap(); + setup.write("b", b"2".to_vec()).unwrap(); + setup.commit().unwrap(); + + // T1 reads "a", writes "a" + let mut t1 = db.begin_with_isolation(IsolationLevel::Serializable); + t1.read("a"); + t1.write("a", b"10".to_vec()).unwrap(); + + // T2 reads "b", writes "b" + let mut t2 = db.begin_with_isolation(IsolationLevel::Serializable); + t2.read("b"); + t2.write("b", b"20".to_vec()).unwrap(); + + // Both should commit since their read/write sets are disjoint + assert!(t1.commit().is_ok(), "T1 should commit (disjoint sets)"); + assert!(t2.commit().is_ok(), "T2 should commit (disjoint sets)"); +} + +#[test] +fn test_serializable_read_only_txn_no_conflict() { + // Read-only serializable transaction never causes SerializationFailure + let db = KovanMVCC::new(); + + let mut setup = db.begin(); + setup.write("key", b"value".to_vec()).unwrap(); + setup.commit().unwrap(); + + let reader = db.begin_with_isolation(IsolationLevel::Serializable); + assert_eq!(reader.read("key").unwrap(), b"value"); + + // Concurrent write + let mut writer = db.begin(); + writer.write("key", b"new_value".to_vec()).unwrap(); + writer.commit().unwrap(); + + // Read again — still sees old value (snapshot) + assert_eq!(reader.read("key").unwrap(), b"value"); + + // Drop reader — no commit, no prewrite, no SSI check → no error + drop(reader); +} + +#[test] +fn test_serializable_phantom_prevention() { + // T1 reads non-existent key K, T2 inserts K and commits, + // T1 writes something and tries to commit → SerializationFailure + let db = KovanMVCC::new(); + + let mut t1 = db.begin_with_isolation(IsolationLevel::Serializable); + // Read non-existent key — added to read_set + assert!(t1.read("phantom_key").is_none()); + + // T2 inserts the key and commits + let mut t2 = db.begin(); + t2.write("phantom_key", b"inserted".to_vec()).unwrap(); + t2.commit().unwrap(); + + // T1 writes something (to trigger prewrite & SSI check) + t1.write("other_key", b"val".to_vec()).unwrap(); + let result = t1.commit(); + assert!( + matches!(result, Err(MvccError::SerializationFailure { .. })), + "Should detect phantom: read-set includes 'phantom_key' which was inserted concurrently. Got: {:?}", + result + ); +} + +#[test] +fn test_serializable_concurrent_write_same_key() { + // Two serializable txns writing the same key → write-write conflict (not serialization failure) + let db = KovanMVCC::new(); + + let mut setup = db.begin(); + setup.write("key", b"init".to_vec()).unwrap(); + setup.commit().unwrap(); + + let mut t1 = db.begin_with_isolation(IsolationLevel::Serializable); + let mut t2 = db.begin_with_isolation(IsolationLevel::Serializable); + + t1.write("key", b"v1".to_vec()).unwrap(); + t2.write("key", b"v2".to_vec()).unwrap(); + + // First committer wins + assert!(t1.commit().is_ok()); + + // Second gets WriteConflict (not SerializationFailure, since it didn't read the key) + let result = t2.commit(); + assert!( + matches!(result, Err(MvccError::WriteConflict { .. })), + "Should be WriteConflict, not SerializationFailure. Got: {:?}", + result + ); +} + +#[test] +fn test_serializable_read_set_includes_misses() { + // Reading a key that doesn't exist still adds it to read-set + let db = KovanMVCC::new(); + + let mut t1 = db.begin_with_isolation(IsolationLevel::Serializable); + // Read a non-existent key + assert!(t1.read("nonexistent").is_none()); + + // Another txn creates it + let mut t2 = db.begin(); + t2.write("nonexistent", b"now_exists".to_vec()).unwrap(); + t2.commit().unwrap(); + + // T1 writes something to trigger commit + t1.write("unrelated", b"x".to_vec()).unwrap(); + + let result = t1.commit(); + assert!( + matches!(result, Err(MvccError::SerializationFailure { .. })), + "Read miss should be in read-set. Got: {:?}", + result + ); +} + +// =========================================================================== +// Cross-level tests +// =========================================================================== + +#[test] +fn test_mixed_isolation_levels_same_db() { + // RC, RR, and Serializable transactions coexist on same KovanMVCC instance + let db = KovanMVCC::new(); + + let mut setup = db.begin(); + setup.write("key", b"v1".to_vec()).unwrap(); + setup.commit().unwrap(); + + let t_rc = db.begin_with_isolation(IsolationLevel::ReadCommitted); + let t_rr = db.begin_with_isolation(IsolationLevel::RepeatableRead); + let t_sr = db.begin_with_isolation(IsolationLevel::Serializable); + + // All see the committed value + assert_eq!(t_rc.read("key").unwrap(), b"v1"); + assert_eq!(t_rr.read("key").unwrap(), b"v1"); + assert_eq!(t_sr.read("key").unwrap(), b"v1"); + + // Concurrent commit + let mut writer = db.begin(); + writer.write("key", b"v2".to_vec()).unwrap(); + writer.commit().unwrap(); + + // RC sees new value, RR and Serializable see old value + assert_eq!(t_rc.read("key").unwrap(), b"v2"); + assert_eq!(t_rr.read("key").unwrap(), b"v1"); + assert_eq!(t_sr.read("key").unwrap(), b"v1"); +} + +#[test] +fn test_serializable_vs_rr_write_skew() { + // Same write-skew scenario: RR allows it, Serializable prevents it + let db = KovanMVCC::new(); + + // --- RR: write skew succeeds --- + let mut setup = db.begin(); + setup.write("X", 1u64.to_le_bytes().to_vec()).unwrap(); + setup.write("Y", 1u64.to_le_bytes().to_vec()).unwrap(); + setup.commit().unwrap(); + + let mut t1_rr = db.begin_with_isolation(IsolationLevel::RepeatableRead); + let mut t2_rr = db.begin_with_isolation(IsolationLevel::RepeatableRead); + + let x1 = u64::from_le_bytes(t1_rr.read("X").unwrap().try_into().unwrap()); + let y1 = u64::from_le_bytes(t1_rr.read("Y").unwrap().try_into().unwrap()); + let x2 = u64::from_le_bytes(t2_rr.read("X").unwrap().try_into().unwrap()); + let y2 = u64::from_le_bytes(t2_rr.read("Y").unwrap().try_into().unwrap()); + + if x1 + y1 >= 1 { + t1_rr.write("X", (x1 - 1).to_le_bytes().to_vec()).unwrap(); + } + if x2 + y2 >= 1 { + t2_rr.write("Y", (y2 - 1).to_le_bytes().to_vec()).unwrap(); + } + + assert!(t1_rr.commit().is_ok(), "RR T1 should commit"); + assert!(t2_rr.commit().is_ok(), "RR T2 should commit (write skew)"); + + // --- Serializable: write skew prevented --- + // Reset values + let mut reset = db.begin(); + reset.write("X", 1u64.to_le_bytes().to_vec()).unwrap(); + reset.write("Y", 1u64.to_le_bytes().to_vec()).unwrap(); + reset.commit().unwrap(); + + let mut t1_sr = db.begin_with_isolation(IsolationLevel::Serializable); + let mut t2_sr = db.begin_with_isolation(IsolationLevel::Serializable); + + let x1 = u64::from_le_bytes(t1_sr.read("X").unwrap().try_into().unwrap()); + let y1 = u64::from_le_bytes(t1_sr.read("Y").unwrap().try_into().unwrap()); + let x2 = u64::from_le_bytes(t2_sr.read("X").unwrap().try_into().unwrap()); + let y2 = u64::from_le_bytes(t2_sr.read("Y").unwrap().try_into().unwrap()); + + if x1 + y1 >= 1 { + t1_sr.write("X", (x1 - 1).to_le_bytes().to_vec()).unwrap(); + } + if x2 + y2 >= 1 { + t2_sr.write("Y", (y2 - 1).to_le_bytes().to_vec()).unwrap(); + } + + let r1 = t1_sr.commit(); + let r2 = t2_sr.commit(); + + // One must fail + let both_ok = r1.is_ok() && r2.is_ok(); + assert!( + !both_ok, + "Serializable must prevent write skew: both committed!" + ); + + // At least one should get SerializationFailure + let has_sf = matches!(r1, Err(MvccError::SerializationFailure { .. })) + || matches!(r2, Err(MvccError::SerializationFailure { .. })); + assert!( + has_sf, + "At least one should get SerializationFailure. r1={:?}, r2={:?}", + r1, r2 + ); +} diff --git a/kovan-mvcc/tests/simple_conflict.rs b/kovan-mvcc/tests/simple_conflict.rs index 5906cdc..c896ff6 100644 --- a/kovan-mvcc/tests/simple_conflict.rs +++ b/kovan-mvcc/tests/simple_conflict.rs @@ -1,4 +1,4 @@ -use kovan_mvcc::KovanMVCC; +use kovan_mvcc::{KovanMVCC, MvccError}; use std::sync::Arc; use std::thread; @@ -43,10 +43,10 @@ fn test_simple_conflict() { eprintln!("[T1] Commit result: {:?}", commit_result); commit_result } else { - Err("write1 failed".to_string()) + Err(MvccError::StorageError("write1 failed".to_string())) } } else { - Err("write0 failed".to_string()) + Err(MvccError::StorageError("write0 failed".to_string())) } }); @@ -70,10 +70,10 @@ fn test_simple_conflict() { eprintln!("[T2] Commit result: {:?}", commit_result); commit_result } else { - Err("write1 failed".to_string()) + Err(MvccError::StorageError("write1 failed".to_string())) } } else { - Err("write0 failed".to_string()) + Err(MvccError::StorageError("write0 failed".to_string())) } }); diff --git a/kovan-queue/src/seg_queue.rs b/kovan-queue/src/seg_queue.rs index 45c7010..1bd6122 100644 --- a/kovan-queue/src/seg_queue.rs +++ b/kovan-queue/src/seg_queue.rs @@ -49,6 +49,7 @@ impl Segment { pub struct SegQueue { head: CacheAligned>>, tail: CacheAligned>>, + len: AtomicUsize, } unsafe impl Send for SegQueue {} @@ -70,6 +71,7 @@ impl SegQueue { SegQueue { head: CacheAligned::new(head), tail: CacheAligned::new(tail), + len: AtomicUsize::new(0), } } @@ -118,6 +120,7 @@ impl SegQueue { slot.value.get().write(MaybeUninit::new(value)); } slot.state.store(SLOT_WRITTEN, Ordering::Release); + self.len.fetch_add(1, Ordering::Relaxed); return; } } else if state == SLOT_WRITING { @@ -185,6 +188,7 @@ impl SegQueue { .is_ok() { let value = unsafe { slot.value.get().read().assume_init() }; + self.len.fetch_sub(1, Ordering::Relaxed); return Some(value); } } else if state == SLOT_EMPTY { @@ -225,6 +229,22 @@ impl SegQueue { backoff.snooze(); } } + + /// Returns the number of elements in the queue. + /// + /// This is an approximation in concurrent scenarios: elements may be pushed + /// or popped by other threads between the moment the length is read and the + /// moment the caller acts on the returned value. + #[inline] + pub fn len(&self) -> usize { + self.len.load(Ordering::Relaxed) + } + + /// Returns `true` if the queue is empty. + #[inline] + pub fn is_empty(&self) -> bool { + self.len() == 0 + } } impl Drop for SegQueue { diff --git a/kovan/Cargo.toml b/kovan/Cargo.toml index 74114b4..9d03aca 100644 --- a/kovan/Cargo.toml +++ b/kovan/Cargo.toml @@ -25,11 +25,6 @@ default = ["std"] std = [] # Enable std-dependent features. nightly = [] # Enable nightly-only optimizations. -# Maximum concurrent thread count. Default is 128. -max-threads-256 = [] -max-threads-512 = [] -max-threads-1024 = [] - [dependencies] once_cell = { version = "1.21", default-features = false, features = ["alloc"] } portable-atomic = { version = "1.11.1", default-features = false, features = [ diff --git a/kovan/src/atom.rs b/kovan/src/atom.rs index e00b1b3..d8ace28 100644 --- a/kovan/src/atom.rs +++ b/kovan/src/atom.rs @@ -637,6 +637,8 @@ impl Drop for Atom { drop(Box::from_raw(ptr)); } } + // Flush nodes previously retired by store()/rcu() operations. + crate::flush(); } } @@ -811,6 +813,8 @@ impl Drop for AtomOption { drop(Box::from_raw(ptr)); } } + // Flush nodes previously retired by store() operations. + crate::flush(); } } diff --git a/kovan/src/guard.rs b/kovan/src/guard.rs index fb883fa..a778723 100644 --- a/kovan/src/guard.rs +++ b/kovan/src/guard.rs @@ -39,15 +39,18 @@ impl Drop for Guard { #[cfg(feature = "nightly")] { let count = HANDLE.pin_count.get(); - debug_assert!(count > 0, "Guard dropped with pin_count == 0"); - HANDLE.pin_count.set(count - 1); + HANDLE.pin_count.set(count.saturating_sub(1)); } #[cfg(not(feature = "nightly"))] { - HANDLE.with(|handle| { + // Use try_with to handle process teardown gracefully. + // During static destructor execution, TLS may already be destroyed. + // Panicking in a destructor during cleanup causes SIGABRT. + let _ = HANDLE.try_with(|handle| { let count = handle.pin_count.get(); - debug_assert!(count > 0, "Guard dropped with pin_count == 0"); - handle.pin_count.set(count - 1); + // Saturating: a dummy Guard (created when TLS was unavailable in + // pin()) was never pinned. Decrementing past 0 would be UB. + handle.pin_count.set(count.saturating_sub(1)); }); } } @@ -138,6 +141,15 @@ impl Handle { /// /// Fast path cost: 1 Acquire load of global_era (L1 cache hit) + 1 Cell read /// + 1 predictable branch. No loop, no traverse — O(1) and wait-free. + /// + /// # Wait-free bound: O(1) + /// + /// No loop. The paper's `protect()` has a convergence loop because `update_era()` + /// triggers traversal during which the epoch may advance. Kovan avoids this by + /// deferring traversal to the next `pin()` call — only an epoch store + pointer + /// re-read on the rare (epoch-change) path. The epoch store ensures `try_retire()` + /// sees the thread's current epoch; the re-read ensures the returned pointer's + /// birth_epoch ≤ the stored epoch. #[inline] fn protect_load(&self, data: &AtomicUsize, order: Ordering) -> usize { // Step 1: load the pointer @@ -173,6 +185,18 @@ impl Handle { /// do_update: transition epoch for a slot. /// Dereferences previous nodes in the slot's list, then stores new epoch. /// Returns the current epoch after update. + /// + /// # Wait-free bound: O(T) where T = number of active threads + /// + /// The exchange is a single atomic instruction (on native platforms). + /// `traverse_cache` walks the slot's list, which contains at most one + /// node per `try_retire()` insertion since the last `do_update()`. Since + /// at most T threads can insert into a slot between updates, + /// the list length (and thus traversal) is bounded by T. + /// + /// Divergence from C++ reference: uses `exchange_lo(0)` in a single step + /// instead of `exchange(INVPTR)` + `store(nullptr)`. Avoids a race where + /// a concurrent `try_retire` insertion between the two steps gets lost. #[cold] fn do_update(&self, curr_epoch: u64, index: usize, tid: usize) -> u64 { let global = self.global(); @@ -220,6 +244,14 @@ impl Handle { /// (epoch check + do_update). Inner calls just increment the pin count /// and return a Guard. Guard::drop decrements the count, so the epoch /// slot is only eligible for transition once all Guards are dropped. + /// + /// # Wait-free bound: O(T) where T = number of active threads + /// + /// Fast path: at most 16 iterations (fixed). Slow path: O(T) — + /// bounded by Crystalline-W helping. Every `advance_epoch()` call is preceded + /// by `help_read()`, so at most T concurrent epoch advances can + /// occur during the slow path loop. The helpee can also self-complete when + /// the epoch stabilizes, independent of helper progress. fn pin(&self) -> Guard { let count = self.pin_count.get(); self.pin_count.set(count + 1); @@ -267,6 +299,18 @@ impl Handle { /// Slow path for pin when epoch keeps changing. /// Sets up helping state so other threads can assist. + /// + /// # Wait-free bound: O(T) where T = number of active threads + /// + /// The main loop (lines ~294-344) exits when either: + /// 1. Epoch stabilizes and self-completion CAS succeeds, or + /// 2. A helper sets result ≠ INVPTR via help_thread. + /// + /// Both conditions are bounded by O(T) epoch advances. For + /// condition (1): after at most T concurrent advance_epoch calls + /// (each preceded by help_read), no more threads are in the advance phase + /// and the epoch stabilizes. For condition (2): at least one of those + /// T helpers will see a stable epoch and set the result. #[cold] fn slow_path(&self, index: usize, tid: usize) { let global = self.global(); @@ -437,6 +481,11 @@ impl Handle { } /// Help other threads in the slow path (matches help_read). + /// + /// # Wait-free bound: O((T ^ 2) * HR_NUM) where T = number of active threads + /// + /// Scans T * HR_NUM slots, calling help_thread for each + /// stalled thread. Each help_thread call is O(T). #[cold] fn help_read(&self, mytid: usize) { let global = self.global(); @@ -448,8 +497,9 @@ impl Handle { let hr_num = global.hr_num(); for i in 0..max_threads { + let slots = global.thread_slots(i); for j in 0..hr_num { - let result_ptr = global.thread_slots(i).state[j].result.load_lo(); + let result_ptr = slots.state[j].result.load_lo(); if result_ptr == INVPTR as u64 { self.help_thread(i, j, mytid); } @@ -458,6 +508,22 @@ impl Handle { } /// Help a specific thread complete its slow-path operation. + /// + /// # Wait-free bound: O(T) where T = number of active threads + /// + /// The main loop exits when either: + /// 1. Epoch stabilizes (curr_epoch == prev_epoch) and result CAS succeeds, or + /// 2. Another helper already set the result (result ≠ INVPTR). + /// + /// For the epoch to advance during this loop, some thread must call + /// `advance_epoch()`, which is preceded by `help_read()`. After at most + /// T concurrent epoch advances, no more threads are in the + /// advance phase and the epoch stabilizes. The bound is independent of + /// the helpee's progress — the helpee is passive. + /// + /// The seqno cleanup loops (DCAS on first/epoch with `while old_hi == seqno`) + /// are bounded by O(T) contention: once any thread advances seqno, + /// all others see old_hi ≠ seqno and exit. #[cold] fn help_thread(&self, helpee_tid: usize, index: usize, mytid: usize) { let global = self.global(); @@ -654,6 +720,13 @@ impl Handle { /// Internal: enqueue a pre-configured node into the thread-local batch. /// The node's destructor and birth_epoch must already be set by the caller. /// + /// # Wait-free bound: O(1) amortized + /// + /// Each call does O(1) work (link node into batch, update counters). + /// Every RETIRE_FREQ (64) calls: finalize batch + try_retire (O(T)). + /// Every EPOCH_FREQ (128) calls: help_read (O(T ^ 2) worst case, + /// O(1) when no stalled threads) + advance_epoch (single fetch_add). + /// /// # Safety /// /// - `node_ptr` must point to a valid, heap-allocated `RetiredNode` at offset 0. @@ -770,6 +843,17 @@ impl Handle { /// Try to retire the current batch by scanning all thread slots. /// Matches ASMR try_retire. + /// + /// # Wait-free bound: O(T + RETIRE_FREQ) where T = number of active threads + /// + /// Two phases, both bounded: + /// - **Scan phase**: O(T * SLOTS_PER_THREAD) — iterates all active thread + /// slots once, assigning batch nodes to active slots with epoch ≥ min_epoch. + /// No loops within — each slot is visited exactly once. + /// - **Insert phase**: O(RETIRE_FREQ) — iterates batch nodes, exchanging each + /// into its assigned slot. The exchange is a single atomic instruction (on + /// native platforms). Contention handling (INVPTR rollback, list tainting) + /// is O(1) per node. fn try_retire(&self) { let global = self.global(); let max_threads = global.max_threads(); @@ -785,26 +869,27 @@ impl Handle { fence(Ordering::SeqCst); let mut last = curr; for i in 0..max_threads { + let slots = global.thread_slots(i); let mut j = 0; // Regular reservation slots (0..hr_num) while j < hr_num { - let first_lo = global.thread_slots(i).first[j].load_lo(); + let first_lo = slots.first[j].load_lo(); if first_lo == INVPTR as u64 { j += 1; continue; } // Check seqno odd (in slow-path transition) - if global.thread_slots(i).first[j].load_hi() & 1 != 0 { + if slots.first[j].load_hi() & 1 != 0 { j += 1; continue; } - let epoch = global.thread_slots(i).epoch[j].load_lo(); + let epoch = slots.epoch[j].load_lo(); if epoch < min_epoch { j += 1; continue; } // Check epoch seqno odd - if global.thread_slots(i).epoch[j].load_hi() & 1 != 0 { + if slots.epoch[j].load_hi() & 1 != 0 { j += 1; continue; } @@ -820,12 +905,12 @@ impl Handle { } // Helper slots (hr_num..hr_num+2) while j < hr_num + 2 { - let first_lo = global.thread_slots(i).first[j].load_lo(); + let first_lo = slots.first[j].load_lo(); if first_lo == INVPTR as u64 { j += 1; continue; } - let epoch = global.thread_slots(i).epoch[j].load_lo(); + let epoch = slots.epoch[j].load_lo(); if epoch < min_epoch { j += 1; continue; @@ -939,6 +1024,75 @@ impl Handle { self.list_count.set(0); } } + + /// Flush: force all retired nodes on this thread to be reclaimed. + /// + /// Three phases: + /// 1. Finalize any partial batch (<64 nodes) and submit via try_retire + /// 2. Advance global epoch to make all submitted batches eligible + /// 3. Traverse own slots to process and free eligible nodes + /// + /// This does NOT guarantee that nodes retired by OTHER threads are freed — + /// those threads must flush themselves or exit. But it does guarantee that + /// all nodes retired by THIS thread are submitted and that this thread's + /// slots are drained. + fn flush(&self) { + if self.tid.get().is_none() { + return; // Thread never participated in reclamation + } + let tid = self.tid(); + let global = self.global(); + + // Phase 1: Finalize partial batch if any + let count = self.batch_count.get(); + if count > 0 { + let last = self.batch_last.get(); + let first = self.batch_first.get(); + unsafe { + (*last) + .batch_link + .store(rnode_mark(first), Ordering::SeqCst); + } + self.try_retire(); + + // Reset batch + self.batch_first.set(core::ptr::null_mut()); + self.batch_last.set(core::ptr::null_mut()); + self.batch_count.set(0); + } + + // Phase 2: Advance epoch so submitted batches become eligible. + // Each advance makes batches with min_epoch <= new_epoch eligible. + // We need enough advances so all threads' slot epochs are surpassed. + // max_threads + 2 iterations is conservative and bounded. + let max = global.max_threads() + 2; + for _ in 0..max { + global.advance_epoch(); + } + + // Phase 3: Traverse own reservation slots to drain pending lists. + // This is the same logic as do_update but without storing a new epoch. + let hr_num = global.hr_num(); + for i in 0..hr_num { + let first = global.thread_slots(tid).first[i].exchange_lo(0, Ordering::AcqRel); + if first != 0 && first != INVPTR as u64 { + let mut free_list = self.free_list.get(); + let mut list_count = self.list_count.get(); + unsafe { + crate::reclaim::traverse_cache( + &mut free_list, + &mut list_count, + first as *mut RetiredNode, + ); + } + self.free_list.set(free_list); + self.list_count.set(list_count); + } + } + + // Drain the accumulated free list + self.drain_free_list(); + } } impl Handle { @@ -1052,7 +1206,10 @@ pub(crate) fn protect_load(data: &AtomicUsize, order: Ordering) -> usize { } #[cfg(not(feature = "nightly"))] { - HANDLE.with(|handle| handle.protect_load(data, order)) + // During process teardown TLS may be destroyed. Fall back to raw load. + HANDLE + .try_with(|handle| handle.protect_load(data, order)) + .unwrap_or_else(|_| data.load(order)) } } @@ -1069,7 +1226,12 @@ pub fn pin() -> Guard { } #[cfg(not(feature = "nightly"))] { - HANDLE.with(|handle| handle.pin()) + // During process teardown TLS may be destroyed. Return a dummy guard + // whose drop is also a no-op (try_with in Guard::drop handles this). + HANDLE.try_with(|handle| handle.pin()).unwrap_or(Guard { + _private: (), + marker, + }) } } @@ -1112,7 +1274,32 @@ pub unsafe fn retire(ptr: *mut T) { #[cfg(not(feature = "nightly"))] { // SAFETY: Caller upholds the safety contract. - HANDLE.with(|handle| unsafe { handle.retire(ptr) }) + // During process teardown TLS may be destroyed. Leak the pointer — + // process memory is reclaimed by the OS on exit. + let _ = HANDLE.try_with(|handle| unsafe { handle.retire(ptr) }); + } +} + +/// Flush all retired nodes on the calling thread. +/// +/// Forces any partial batch to be finalized and submitted, advances the global +/// epoch to make submitted batches eligible, then traverses the thread's slots +/// to reclaim nodes. +/// +/// Call this before dropping data structures that use kovan (e.g. at the end of +/// a test or before process exit) to ensure retired nodes are freed promptly. +/// +/// **Note:** This only flushes the calling thread's state. To flush all threads, +/// each thread must call `flush()` independently or exit (which triggers cleanup). +pub fn flush() { + #[cfg(feature = "nightly")] + { + HANDLE.flush() + } + #[cfg(not(feature = "nightly"))] + { + // During process teardown TLS may be destroyed. No-op in that case. + let _ = HANDLE.try_with(|handle| handle.flush()); } } @@ -1135,6 +1322,7 @@ pub(crate) unsafe fn retire_raw(node_ptr: *mut RetiredNode) { } #[cfg(not(feature = "nightly"))] { - HANDLE.with(|handle| unsafe { handle.retire_raw(node_ptr) }) + // During process teardown TLS may be destroyed. Leak the node. + let _ = HANDLE.try_with(|handle| unsafe { handle.retire_raw(node_ptr) }); } } diff --git a/kovan/src/lib.rs b/kovan/src/lib.rs index 8c13cf0..bc73760 100644 --- a/kovan/src/lib.rs +++ b/kovan/src/lib.rs @@ -57,7 +57,7 @@ mod ttas; pub use atom::{Atom, AtomGuard, AtomMap, AtomMapGuard, AtomOption, Removed}; pub use atomic::{Atomic, Shared}; -pub use guard::{Guard, pin}; +pub use guard::{Guard, flush, pin}; pub use reclaim::Reclaimable; pub use retired::RetiredNode; pub use robust::{BirthEra, current_era}; diff --git a/kovan/src/reclaim.rs b/kovan/src/reclaim.rs index 914866a..b6780dd 100644 --- a/kovan/src/reclaim.rs +++ b/kovan/src/reclaim.rs @@ -56,6 +56,13 @@ pub(crate) unsafe fn get_refs_node(node: *mut RetiredNode) -> *mut RetiredNode { /// - Otherwise: follow batch_link to refs-node, fetch_sub(1) /// - If refs reaches 0: add refs-node to free list /// +/// # Wait-free bound: O(T) where T = number of active threads +/// +/// Each slot receives at most one node per `try_retire()` call. Between two +/// `do_update()` calls on the same slot, at most T `try_retire()` +/// calls can insert nodes (one per thread). The list length — and thus the +/// number of loop iterations — is bounded by T. +/// /// # Safety /// /// `next` must be a valid RetiredNode pointer (not null, not INVPTR) diff --git a/kovan/src/slot.rs b/kovan/src/slot.rs index 5772a1f..6bd4a6e 100644 --- a/kovan/src/slot.rs +++ b/kovan/src/slot.rs @@ -7,7 +7,7 @@ use crate::retired::INVPTR; use crate::ttas::TTas; use alloc::boxed::Box; -use core::sync::atomic::{AtomicU64, AtomicUsize, Ordering}; +use core::sync::atomic::{AtomicPtr, AtomicU64, AtomicUsize, Ordering}; // --------------------------------------------------------------------------- // WordPair: split-field (native) vs single-AtomicU128 (fallback) @@ -309,24 +309,20 @@ pub(crate) const HR_NUM: usize = 1; /// Total slots per thread: hr_num reservations + 2 helper slots pub(crate) const SLOTS_PER_THREAD: usize = HR_NUM + 2; -// Maximum concurrent threads. Configurable via cargo features: -// kovan = { features = ["max-threads-512"] } -// Default: 128. -#[cfg(feature = "max-threads-1024")] -pub(crate) const MAX_THREADS: usize = 1024; -#[cfg(all(feature = "max-threads-512", not(feature = "max-threads-1024")))] -pub(crate) const MAX_THREADS: usize = 512; -#[cfg(all( - feature = "max-threads-256", - not(any(feature = "max-threads-512", feature = "max-threads-1024")) -))] -pub(crate) const MAX_THREADS: usize = 256; -#[cfg(not(any( - feature = "max-threads-256", - feature = "max-threads-512", - feature = "max-threads-1024" -)))] -pub(crate) const MAX_THREADS: usize = 128; +/// Number of thread slots per page (power of 2 for shift/mask indexing). +const SLOTS_PER_PAGE: usize = 128; + +/// log2(SLOTS_PER_PAGE) — used for tid >> PAGE_SHIFT. +const PAGE_SHIFT: usize = 7; + +/// SLOTS_PER_PAGE - 1 — used for tid & PAGE_MASK. +const PAGE_MASK: usize = 127; + +/// Maximum number of pages in the page table. Supports up to 65,536 threads. +const MAX_PAGES: usize = 512; + +/// A page of thread slots, allocated on demand. +struct SlotPage([ThreadSlots; SLOTS_PER_PAGE]); /// Batch retirement frequency (try_retire every `freq` retires) pub(crate) const RETIRE_FREQ: usize = 64; @@ -380,8 +376,9 @@ impl ThreadSlots { /// Global ASMR state pub(crate) struct ASMRState { - /// Per-thread slot arrays - slots: &'static [ThreadSlots], + /// Two-level page table of per-thread slot arrays. Pages are allocated on + /// demand when new thread IDs are assigned. + pages: [AtomicPtr; MAX_PAGES], /// Global epoch counter (starts at 1) epoch: AtomicU64, /// Count of threads currently in the slow path @@ -392,14 +389,14 @@ pub(crate) struct ASMRState { free_tids: TTas>, } +/// Null-initialized page table constant for use in array initialization. +#[allow(clippy::declare_interior_mutable_const)] +const NULL_PAGE: AtomicPtr = AtomicPtr::new(core::ptr::null_mut()); + impl ASMRState { fn new() -> Self { - let mut slots_vec = alloc::vec::Vec::with_capacity(MAX_THREADS); - for _ in 0..MAX_THREADS { - slots_vec.push(ThreadSlots::new()); - } Self { - slots: Box::leak(slots_vec.into_boxed_slice()), + pages: [NULL_PAGE; MAX_PAGES], epoch: AtomicU64::new(1), slow_counter: AtomicU64::new(0), next_tid: AtomicUsize::new(0), @@ -407,10 +404,40 @@ impl ASMRState { } } - /// Get the thread slots for a given thread ID + /// Ensure the page at `page_idx` is allocated. Uses CAS to handle races. + fn ensure_page(&self, page_idx: usize) { + if self.pages[page_idx].load(Ordering::Acquire).is_null() { + let page = Box::into_raw(Box::new(SlotPage(core::array::from_fn(|_| { + ThreadSlots::new() + })))); + if self.pages[page_idx] + .compare_exchange( + core::ptr::null_mut(), + page, + Ordering::AcqRel, + Ordering::Acquire, + ) + .is_err() + { + // Another thread allocated this page first — free ours. + unsafe { + drop(Box::from_raw(page)); + } + } + } + } + + /// Get the thread slots for a given thread ID. #[inline] pub(crate) fn thread_slots(&self, tid: usize) -> &ThreadSlots { - &self.slots[tid] + let page_idx = tid >> PAGE_SHIFT; + let slot_idx = tid & PAGE_MASK; + let page = self.pages[page_idx].load(Ordering::Acquire); + debug_assert!( + !page.is_null(), + "kovan: page {page_idx} not allocated for tid {tid}" + ); + unsafe { &(*page).0[slot_idx] } } /// Get the current global epoch @@ -443,15 +470,15 @@ impl ASMRState { self.slow_counter.fetch_sub(1, Ordering::AcqRel); } - /// Total number of allocated thread slots + /// Number of active thread slots (upper bound for scanning). #[inline] pub(crate) fn max_threads(&self) -> usize { - MAX_THREADS + self.next_tid.load(Ordering::Acquire) } /// Allocate a thread ID pub(crate) fn alloc_tid(&self) -> usize { - // Try recycled IDs first + // Try recycled IDs first (page already exists for recycled tids) { let mut free = self.free_tids.lock(); if let Some(tid) = free.pop() { @@ -462,14 +489,20 @@ impl ASMRState { // if the assert panics and is caught by catch_unwind. loop { let current = self.next_tid.load(Ordering::Relaxed); + let page_idx = current >> PAGE_SHIFT; assert!( - current < MAX_THREADS, - "kovan: exceeded maximum thread count ({MAX_THREADS})" + page_idx < MAX_PAGES, + "kovan: exceeded maximum thread count ({})", + MAX_PAGES * SLOTS_PER_PAGE, ); + // Ensure the page is allocated BEFORE publishing the tid via next_tid. + // This guarantees concurrent scanners (via max_threads()) never see a + // tid whose page doesn't exist yet. + self.ensure_page(page_idx); match self.next_tid.compare_exchange_weak( current, current + 1, - Ordering::Relaxed, + Ordering::Release, Ordering::Relaxed, ) { Ok(_) => return current, @@ -480,10 +513,11 @@ impl ASMRState { /// Release a thread ID for recycling pub(crate) fn free_tid(&self, tid: usize) { - // Mark all slots inactive + // Mark all slots inactive (page is guaranteed to exist) + let slots = self.thread_slots(tid); for j in 0..SLOTS_PER_THREAD { - self.slots[tid].first[j].store(INVPTR as u64, 0, Ordering::Release); - self.slots[tid].epoch[j].store(0, 0, Ordering::Release); + slots.first[j].store(INVPTR as u64, 0, Ordering::Release); + slots.epoch[j].store(0, 0, Ordering::Release); } let mut free = self.free_tids.lock(); free.push(tid); diff --git a/kovan/tests/stress.rs b/kovan/tests/stress.rs index f85cc1a..7abba70 100644 --- a/kovan/tests/stress.rs +++ b/kovan/tests/stress.rs @@ -406,3 +406,40 @@ fn test_burst_workload() { } } } + +/// Verify that >128 concurrent threads work (previously panicked with fixed MAX_THREADS=128). +#[test] +#[cfg_attr(miri, ignore)] +fn test_many_threads_beyond_old_limit() { + let num_threads = 256; + let shared = Arc::new(Atomic::new(StressNode::new(0))); + let barrier = Arc::new(std::sync::Barrier::new(num_threads)); + + let handles: Vec<_> = (0..num_threads) + .map(|i| { + let shared = shared.clone(); + let barrier = barrier.clone(); + thread::spawn(move || { + barrier.wait(); + let guard = pin(); + let _ = shared.load(Ordering::Acquire, &guard); + // Each thread does a store to exercise retirement + let node = StressNode::new(i); + let old = shared.swap( + unsafe { kovan::Shared::from_raw(node) }, + Ordering::AcqRel, + &guard, + ); + if !old.is_null() { + unsafe { + retire(old.as_raw()); + } + } + }) + }) + .collect(); + + for h in handles { + h.join().unwrap(); + } +}