Skip to content
Merged
28 changes: 21 additions & 7 deletions core/patina_internal_collections/src/bst.rs
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,10 @@ where
///
pub fn get(&self, key: &D::Key) -> Option<&D> {
// Look the key up with `get_node` and hand back only the payload, not the node itself.
match self.get_node(key) {
Some(node) => Some(&node.data),
Some(node) => {
// SAFETY: `node` came from `get_node`, which only follows `root()` and the
// `left()`/`right()` links. This relies on every linked node having had its data
// written by `Storage::add` (via `init_data`).
// NOTE(review): confirm that a node on the available list can never be linked
// into the tree.
Some(unsafe { node.data() })
}
None => None,
}
}
Expand All @@ -186,7 +189,7 @@ where
// SAFETY: The pointer comes from as_mut_ptr() on a valid node reference obtained from get_node().
// The caller is responsible for ensuring that the mutable reference doesn't modify key-affecting
// values.
Some(unsafe { &mut (*ptr).data })
Some(unsafe { (*ptr).data_mut() })
}
None => None,
}
Expand All @@ -209,7 +212,10 @@ where
///
pub fn get_with_idx(&self, idx: usize) -> Option<&D> {
// Index straight into the backing storage; no key comparison is involved.
match self.storage.get(idx) {
Some(node) => Some(&node.data),
Some(node) => {
// SAFETY: Relies on `storage.get(idx)` only yielding nodes whose data was written
// by `Storage::add`.
// NOTE(review): `Storage::with_capacity` and `Storage::expand` fill every slot with
// `Node::new_uninit()`, so a slot that is still on the available list has
// uninitialized data. Confirm that `Storage::get` rejects such indices; otherwise
// this safe function can hand out a reference to uninitialized memory.
Some(unsafe { node.data() })
}
None => None,
}
}
Expand All @@ -236,7 +242,10 @@ where
///
pub unsafe fn get_with_idx_mut(&mut self, idx: usize) -> Option<&mut D> {
// Index straight into the backing storage; no key comparison is involved.
match self.storage.get_mut(idx) {
Some(node) => Some(&mut node.data),
Some(node) => {
// SAFETY: Relies on `storage.get_mut(idx)` only yielding nodes whose data was
// written by `Storage::add`.
// NOTE(review): `Storage::with_capacity` and `Storage::expand` fill every slot with
// `Node::new_uninit()`, so a slot that is still on the available list has
// uninitialized data. Confirm that `Storage::get_mut` rejects such indices, or
// make "idx refers to an in-use node" part of this function's safety contract.
Some(unsafe { node.data_mut() })
}
None => None,
}
}
Expand Down Expand Up @@ -281,7 +290,8 @@ where
let mut current = self.root();
let mut closest = None;
while let Some(node) = current {
match key.cmp(node.data.key()) {
// SAFETY: Nodes in the tree always have initialized data
match key.cmp(unsafe { node.data() }.key()) {
Ordering::Equal => return Some(self.storage.idx(node.as_mut_ptr())),
Ordering::Less => current = node.left(),
Ordering::Greater => {
Expand Down Expand Up @@ -494,7 +504,8 @@ where
fn get_node(&self, key: &D::Key) -> Option<&Node<D>> {
let mut current_idx = self.root();
while let Some(node) = current_idx {
match key.cmp(node.data.key()) {
// SAFETY: Nodes in the tree always have initialized data
match key.cmp(unsafe { node.data() }.key()) {
Ordering::Equal => return Some(node),
Ordering::Less => current_idx = node.left(),
Ordering::Greater => current_idx = node.right(),
Expand Down Expand Up @@ -646,7 +657,8 @@ where
// Appends the data of every node under `node` to `values`, visiting the left subtree first,
// then the node itself, then the right subtree. Recurses once per node; `None` ends a branch.
fn _dfs(node: Option<&Node<D>>, values: &mut alloc::vec::Vec<D>) {
if let Some(node) = node {
Self::_dfs(node.left(), values);
values.push(node.data);
// SAFETY: Nodes in the tree always have initialized data
// The dereference copies the value out of the node, which needs `D: Copy`.
// NOTE(review): that bound is not visible in this part of the file.
values.push(unsafe { *node.data() });
Self::_dfs(node.right(), values);
}
}
Expand All @@ -666,6 +678,7 @@ where

#[cfg(test)]
#[coverage(off)]
#[allow(clippy::undocumented_unsafe_blocks)]
mod tests {
use crate::{Bst, node_size};

Expand Down Expand Up @@ -883,6 +896,7 @@ mod tests {
}

#[cfg(test)]
#[allow(clippy::undocumented_unsafe_blocks)]
mod fuzz_tests {
extern crate std;
use crate::{Bst, node_size};
Expand Down
159 changes: 122 additions & 37 deletions core/patina_internal_collections/src/node.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
//!
//! SPDX-License-Identifier: Apache-2.0
//!
use core::{cell::Cell, mem, ptr::NonNull, slice};
use core::{cell::Cell, mem, mem::MaybeUninit, ptr::NonNull, slice};

use crate::{Error, Result, SliceKey};

Expand Down Expand Up @@ -51,23 +51,32 @@ where

/// Create a new storage container with a slice of memory.
///
/// The slice is reinterpreted as `slice.len() / size_of::<Node<D>>()` nodes; bytes left over
/// after the last whole node are not used. Every node is written with `Node::new_uninit()`,
/// so no node has initialized data when this returns. When at least one node fits, the nodes
/// are passed to `build_linked_list` and the first one becomes the head of the available list.
pub fn with_capacity(slice: &'a mut [u8]) -> Storage<'a, D> {
let storage = Storage {
// SAFETY: This is reinterpreting a byte slice as a Node<D> slice.
// 1. The alignment is checked implicitly by the slice bounds.
// 2. The correct number of Node<D> elements that fit in the byte slice is calculated.
// 3. The lifetime ensures the byte slice remains valid for the storage's lifetime
data: unsafe {
slice::from_raw_parts_mut::<'a, Node<D>>(
slice as *mut [u8] as *mut Node<D>,
slice.len() / mem::size_of::<Node<D>>(),
)
},
length: 0,
available: Cell::default(),
// SAFETY: This is reinterpreting a byte slice as a MaybeUninit<Node<D>> slice.
// Using MaybeUninit explicitly represents uninitialized memory.
// The element count is `slice.len() / size_of::<Node<D>>()`, so the new slice stays within
// the bytes of `slice`, and `'a` keeps that memory borrowed for the storage's lifetime.
// NOTE(review): nothing visible here checks that the start of `slice` is aligned for
// `Node<D>`. `slice::from_raw_parts_mut` requires an aligned pointer, and a `[u8]` only
// guarantees alignment 1 — confirm that callers always pass a suitably aligned buffer.
let uninit_buffer = unsafe {
slice::from_raw_parts_mut::<'a, MaybeUninit<Node<D>>>(
slice as *mut [u8] as *mut MaybeUninit<Node<D>>,
slice.len() / mem::size_of::<Node<D>>(),
)
};

Self::build_linked_list(storage.data);
storage.available.set(storage.data[0].as_mut_ptr());
// Initialize nodes with uninitialized data fields
for elem in uninit_buffer.iter_mut() {
elem.write(Node::new_uninit());
}

// SAFETY: All nodes have been initialized (though their data fields are uninitialized).
// We can now safely convert from MaybeUninit<Node<D>> to Node<D>.
let buffer =
unsafe { slice::from_raw_parts_mut(uninit_buffer.as_mut_ptr() as *mut Node<D>, uninit_buffer.len()) };

let storage = Storage { data: buffer, length: 0, available: Cell::default() };

// A buffer too small for even one node has nothing to serve as the list head, so
// `available` keeps its default value in that case.
if !storage.data.is_empty() {
Self::build_linked_list(storage.data);
storage.available.set(storage.data[0].as_mut_ptr());
}

storage
}

Expand Down Expand Up @@ -105,7 +114,11 @@ where
node.set_left(None);
node.set_right(None);
node.set_parent(None);
node.data = data;
// SAFETY: The node is from the available list, so its data field is uninitialized.
// We initialize it here when moving the node to the "in use" state.
unsafe {
node.init_data(data);
}
self.length += 1;
Ok((self.idx(node.as_mut_ptr()), node))
} else {
Expand Down Expand Up @@ -191,18 +204,32 @@ where
///
/// O(n)
pub fn expand(&mut self, slice: &'a mut [u8]) {
// SAFETY: This is reinterpreting a byte slice as a Node<D> slice.
// SAFETY: This is reinterpreting a byte slice as a MaybeUninit<Node<D>> slice.
// Using MaybeUninit explicitly represents uninitialized memory and avoids undefined
// behavior from creating references to uninitialized Node<D>.
// 1. The alignment is handled by slice casting rules
// 2. The correct number of Node<D> elements that fit in the byte slice is calculated
// 3. The lifetime 'a ensures the byte slice remains valid for the storage's lifetime
let buffer = unsafe {
slice::from_raw_parts_mut::<'a, Node<D>>(
slice as *mut [u8] as *mut Node<D>,
// 4. MaybeUninit<T> has the same size and alignment as T
let uninit_buffer = unsafe {
slice::from_raw_parts_mut::<'a, MaybeUninit<Node<D>>>(
slice as *mut [u8] as *mut MaybeUninit<Node<D>>,
slice.len() / mem::size_of::<Node<D>>(),
)
};

assert!(buffer.len() >= self.capacity());
assert!(uninit_buffer.len() >= self.capacity());

// Initialize all new nodes with uninitialized data fields.
// Nodes at indices 0..self.capacity() will be overwritten with copied data below.
for elem in uninit_buffer.iter_mut() {
elem.write(Node::new_uninit());
}

// SAFETY: All nodes have been initialized (though their data fields are uninitialized).
// We can now safely convert from MaybeUninit<Node<D>> to Node<D>.
let buffer =
unsafe { slice::from_raw_parts_mut(uninit_buffer.as_mut_ptr() as *mut Node<D>, uninit_buffer.len()) };

// When current capacity is 0, we just need to copy the data and build the available list
if self.capacity() == 0 {
Expand All @@ -213,10 +240,15 @@ where
}

// Copy the data from the old buffer to the new buffer. Update the pointers to the new buffer
for i in 0..self.capacity() {
for i in 0..self.len() {
let old = &self.data[i];

buffer[i].data = old.data;
// SAFETY: Nodes at indices 0..self.len() are "in use" and have initialized data.
// We copy the initialized data from old to new.
unsafe {
let old_data = old.data();
buffer[i].data = MaybeUninit::new(*old_data);
}
buffer[i].set_color(old.color());

if let Some(left) = old.left() {
Expand Down Expand Up @@ -467,7 +499,7 @@ pub struct Node<D>
where
D: SliceKey,
{
pub data: D,
pub(crate) data: MaybeUninit<D>,
color: Cell<bool>,
parent: Cell<*mut Node<D>>,
left: Cell<*mut Node<D>>,
Expand All @@ -478,8 +510,48 @@ impl<D> Node<D>
where
D: SliceKey,
{
/// Builds a `RED` node with default parent/left/right links and a payload slot that is
/// still uninitialized.
///
/// Nothing may read the payload until `init_data()` has stored a value in it.
pub fn new_uninit() -> Self {
    let data = MaybeUninit::uninit();
    let color = Cell::new(RED);
    let (parent, left, right) = (Cell::default(), Cell::default(), Cell::default());
    Node { data, color, parent, left, right }
}

/// Initialize the data field of an uninitialized node.
/// # Safety
/// The caller must ensure the data field has not been previously initialized.
pub unsafe fn init_data(&mut self, data: D) {
self.data.write(data);
}

/// Creates a new Node with initialized data.
/// Used for testing purposes.
///
/// The node is `RED` and its parent/left/right links hold their default values.
#[cfg(test)]
pub fn new(data: D) -> Self {
Node { data, color: Cell::new(RED), parent: Cell::default(), left: Cell::default(), right: Cell::default() }
// Start from an uninitialized node, then fill the payload slot so that `data()` and
// `data_mut()` are sound on the returned node.
let mut node = Self::new_uninit();
node.data.write(data);
node
}

/// Borrows the payload stored in this node.
///
/// # Safety
/// The payload slot must have been initialized before this is called.
pub unsafe fn data(&self) -> &D {
    let slot = &self.data;
    // SAFETY: The caller promises the slot holds an initialized `D`.
    unsafe { slot.assume_init_ref() }
}

/// Mutably borrows the payload stored in this node.
///
/// # Safety
/// The payload slot must have been initialized before this is called.
pub unsafe fn data_mut(&mut self) -> &mut D {
    let slot = &mut self.data;
    // SAFETY: The caller promises the slot holds an initialized `D`.
    unsafe { slot.assume_init_mut() }
}

pub fn height_and_balance(node: Option<&Node<D>>) -> (i32, bool) {
Expand Down Expand Up @@ -587,7 +659,9 @@ where
// A node exposes the key of the payload it stores, so it can be compared the same way as `D`.
impl<D: SliceKey> SliceKey for Node<D> {
type Key = D::Key;
fn key(&self) -> &Self::Key {
self.data.key()
// SAFETY: This method is only called on nodes that are in use (initialized).
// Nodes in the available list are never accessed for their key.
// NOTE(review): `key()` is a safe method and `Node::new_uninit()` is a safe constructor,
// so the compiler does not enforce this invariant. Calling `key()` on a node whose data
// was never initialized would be undefined behavior — confirm that no such node can
// reach a caller of this trait method.
unsafe { self.data().key() }
}
}

Expand All @@ -605,7 +679,8 @@ mod tests {
for i in 0..10 {
let (index, node) = storage.add(i).unwrap();
assert_eq!(index, i);
assert_eq!(node.data, i);
// SAFETY: Node was just added with data, so it's initialized
assert_eq!(unsafe { *node.data() }, i);
assert_eq!(storage.len(), i + 1);
}

Expand All @@ -616,16 +691,22 @@ mod tests {
storage.delete(storage.get(5).unwrap().as_mut_ptr());
let (index, node) = storage.add(11).unwrap();
assert_eq!(index, 5);
assert_eq!(node.data, 11);
// SAFETY: Node was just added with data, so it's initialized
assert_eq!(unsafe { *node.data() }, 11);

// Try and get a mutable reference to a node
{
let node = storage.get_mut(5).unwrap();
assert_eq!(node.data, 11);
node.data = 12;
// SAFETY: Node is in use, so data is initialized
assert_eq!(unsafe { *node.data() }, 11);
// SAFETY: Node is in use, we can modify the initialized data
unsafe {
*node.data_mut() = 12;
}
}
let node = storage.get(5).unwrap();
assert_eq!(node.data, 12);
// SAFETY: Node is in use, so data is initialized
assert_eq!(unsafe { *node.data() }, 12);
}

#[test]
Expand All @@ -643,8 +724,10 @@ mod tests {

p4.set_parent(Some(p1));

assert_eq!(Node::sibling(p2).unwrap().data, 3);
assert_eq!(Node::sibling(p3).unwrap().data, 2);
// SAFETY: Test nodes are created with initialized data via Node::new()
assert_eq!(unsafe { *Node::sibling(p2).unwrap().data() }, 3);
// SAFETY: Test nodes are created with initialized data via Node::new()
assert_eq!(unsafe { *Node::sibling(p3).unwrap().data() }, 2);
assert!(Node::sibling(p1).is_none());
}

Expand Down Expand Up @@ -683,7 +766,8 @@ mod tests {
p2.set_right(Some(p4));
p4.set_parent(Some(p2));

assert_eq!(Node::predecessor(p1).unwrap().data, 4);
// SAFETY: Test nodes are created with initialized data via Node::new()
assert_eq!(unsafe { *Node::predecessor(p1).unwrap().data() }, 4);
assert!(Node::predecessor(p4).is_none());
}

Expand All @@ -703,7 +787,8 @@ mod tests {
p2.set_right(Some(p4));
p4.set_parent(Some(p2));

assert_eq!(Node::successor(p1).unwrap().data, 3);
// SAFETY: Test nodes are created with initialized data via Node::new()
assert_eq!(unsafe { *Node::successor(p1).unwrap().data() }, 3);
assert!(Node::successor(p4).is_none());
}

Expand Down Expand Up @@ -785,7 +870,7 @@ mod tests {
}

#[test]
#[should_panic(expected = "assertion failed: buffer.len() >= self.capacity()")]
#[should_panic(expected = "assertion failed: uninit_buffer.len() >= self.capacity()")]
fn test_expand_prevents_capacity_shrink() {
// Verify that expand() prevents shrinking capacity
const INITIAL_SIZE: usize = 10;
Expand Down
Loading
Loading