diff --git a/src/common/generate_nearest_n_within_unsorted.rs b/src/common/generate_nearest_n_within_unsorted.rs index 09d56c4e..e5779e75 100644 --- a/src/common/generate_nearest_n_within_unsorted.rs +++ b/src/common/generate_nearest_n_within_unsorted.rs @@ -8,19 +8,32 @@ macro_rules! generate_nearest_n_within_unsorted { where D: DistanceMetric, { + self.nearest_n_within_exclusive::(query, dist, max_items, sorted, true) + } + + #[doc = concat!$comments] + #[inline] + pub fn nearest_n_within_exclusive(&self, query: &[A; K], dist: A, max_items: std::num::NonZero, sorted: bool, inclusive: bool) -> Vec> + where + D: DistanceMetric, + { + // Like [`nearest_n_within`] but allows controlling boundary inclusiveness. + // + // When `inclusive` is true, points at exactly the maximum distance are included. + // When false, only points strictly less than the maximum distance are included. if sorted || max_items < std::num::NonZero::new(usize::MAX).unwrap() { if max_items <= std::num::NonZero::new(MAX_VEC_RESULT_SIZE).unwrap() { - self.nearest_n_within_stub::>>(query, dist, max_items.get(), sorted) + self.nearest_n_within_stub::>>(query, dist, max_items.get(), sorted, inclusive) } else { - self.nearest_n_within_stub::>>(query, dist, max_items.get(), sorted) + self.nearest_n_within_stub::>>(query, dist, max_items.get(), sorted, inclusive) } } else { - self.nearest_n_within_stub::>>(query, dist, 0, sorted) + self.nearest_n_within_stub::>>(query, dist, 0, sorted, inclusive) } } fn nearest_n_within_stub, H: ResultCollection>( - &self, query: &[A; K], dist: A, res_capacity: usize, sorted: bool + &self, query: &[A; K], dist: A, res_capacity: usize, sorted: bool, inclusive: bool ) -> Vec> { let mut matching_items = H::new_with_capacity(res_capacity); let mut off = [A::zero(); K]; @@ -35,6 +48,7 @@ macro_rules! generate_nearest_n_within_unsorted { &mut matching_items, &mut off, A::zero(), + inclusive, ); } @@ -55,6 +69,7 @@ macro_rules! generate_nearest_n_within_unsorted { matching_items: &mut R, off: &mut [A; K], rd: A, + inclusive: bool, ) where D: DistanceMetric, { @@ -84,11 +99,12 @@ macro_rules! generate_nearest_n_within_unsorted { matching_items, off, rd, + inclusive, ); rd = D::accumulate(rd, D::dist1(new_off, old_off)); - if rd <= radius { + if if inclusive { rd <= radius } else { rd < radius } { off[split_dim] = new_off; self.nearest_n_within_unsorted_recurse::( query, @@ -98,6 +114,7 @@ macro_rules! generate_nearest_n_within_unsorted { matching_items, off, rd, + inclusive, ); off[split_dim] = old_off; } @@ -116,7 +133,7 @@ macro_rules! generate_nearest_n_within_unsorted { .for_each(|(idx, entry)| { let distance = D::dist(query, transform(entry)); - if distance <= radius { + if if inclusive { distance <= radius } else { distance < radius } { let item = unsafe { leaf_node.content_items.get_unchecked(idx) }; let item = *transform(item); diff --git a/src/common/generate_within.rs b/src/common/generate_within.rs index 7952636f..a014a2d4 100644 --- a/src/common/generate_within.rs +++ b/src/common/generate_within.rs @@ -8,7 +8,20 @@ macro_rules! generate_within { where D: DistanceMetric, { - let mut matching_items = self.within_unsorted::(query, dist); + self.within_exclusive::(query, dist, true) + } + + #[doc = concat!$comments] + #[inline] + pub fn within_exclusive(&self, query: &[A; K], dist: A, inclusive: bool) -> Vec> + where + D: DistanceMetric, + { + // Like [`within`] but allows controlling boundary inclusiveness. + // + // When `inclusive` is true, points at exactly the maximum distance are included. + // When false, only points strictly less than the maximum distance are included. + let mut matching_items = self.within_unsorted_exclusive::(query, dist, inclusive); matching_items.sort(); matching_items } diff --git a/src/common/generate_within_unsorted.rs b/src/common/generate_within_unsorted.rs index 656ab033..49120271 100644 --- a/src/common/generate_within_unsorted.rs +++ b/src/common/generate_within_unsorted.rs @@ -8,6 +8,19 @@ macro_rules! generate_within_unsorted { where D: DistanceMetric, { + self.within_unsorted_exclusive::(query, dist, true) + } + + #[doc = concat!$comments] + #[inline] + pub fn within_unsorted_exclusive(&self, query: &[A; K], dist: A, inclusive: bool) -> Vec> + where + D: DistanceMetric, + { + // Like [`within_unsorted`] but allows controlling boundary inclusiveness. + // + // When `inclusive` is true, points at exactly the maximum distance are included. + // When false, only points strictly less than the maximum distance are included. let mut off = [A::zero(); K]; let mut matching_items = Vec::new(); let root_index: IDX = *transform(&self.root_index); @@ -21,6 +34,7 @@ macro_rules! generate_within_unsorted { &mut matching_items, &mut off, A::zero(), + inclusive, ); } @@ -37,6 +51,7 @@ macro_rules! generate_within_unsorted { matching_items: &mut Vec>, off: &mut [A; K], rd: A, + inclusive: bool, ) where D: DistanceMetric, { @@ -66,11 +81,12 @@ macro_rules! generate_within_unsorted { matching_items, off, rd, + inclusive, ); rd = D::accumulate(rd, D::dist1(new_off, old_off)); - if rd <= radius { + if if inclusive { rd <= radius } else { rd < radius } { off[split_dim] = new_off; self.within_unsorted_recurse::( query, @@ -80,6 +96,7 @@ macro_rules! generate_within_unsorted { matching_items, off, rd, + inclusive, ); off[split_dim] = old_off; } @@ -98,7 +115,7 @@ macro_rules! generate_within_unsorted { .for_each(|(idx, entry)| { let distance = D::dist(query, transform(entry)); - if distance <= radius { + if if inclusive { distance <= radius } else { distance < radius } { let item = unsafe { leaf_node.content_items.get_unchecked(idx) }; let item = *transform(item); diff --git a/src/common/generate_within_unsorted_iter.rs b/src/common/generate_within_unsorted_iter.rs index a44ad861..1ae22454 100644 --- a/src/common/generate_within_unsorted_iter.rs +++ b/src/common/generate_within_unsorted_iter.rs @@ -12,6 +12,24 @@ macro_rules! generate_within_unsorted_iter { where D: DistanceMetric, { + self.within_unsorted_iter_exclusive::(query, dist, true) + } + + #[doc = concat!$comments] + #[inline] + pub fn within_unsorted_iter_exclusive( + &'a self, + query: &'query [A; K], + dist: A, + inclusive: bool, + ) -> WithinUnsortedIter<'a, A, T> + where + D: DistanceMetric, + { + // Like [`within_unsorted_iter`] but allows controlling boundary inclusiveness. + // + // When `inclusive` is true, points at exactly the maximum distance are included. + // When false, only points strictly less than the maximum distance are included. let mut off = [A::zero(); K]; let root_index: IDX = *transform(&self.root_index); @@ -27,6 +45,7 @@ macro_rules! generate_within_unsorted_iter { gen_scope, &mut off, A::zero(), + inclusive, ); } @@ -46,6 +65,7 @@ macro_rules! generate_within_unsorted_iter { mut gen_scope: Scope<'scope, 'a, (), NearestNeighbour>, off: &mut [A; K], rd: A, + inclusive: bool, ) -> Scope<'scope, 'a, (), NearestNeighbour> where D: DistanceMetric, @@ -76,11 +96,12 @@ macro_rules! generate_within_unsorted_iter { gen_scope, off, rd, + inclusive, ); rd = D::accumulate(rd, D::dist1(new_off, old_off)); - if rd <= radius { + if if inclusive { rd <= radius } else { rd < radius } { off[split_dim] = new_off; gen_scope = self.within_unsorted_iter_recurse::( query, @@ -90,6 +111,7 @@ macro_rules! generate_within_unsorted_iter { gen_scope, off, rd, + inclusive, ); off[split_dim] = old_off; } @@ -108,7 +130,7 @@ macro_rules! generate_within_unsorted_iter { .for_each(|(idx, entry)| { let distance = D::dist(query, transform(entry)); - if distance <= radius { + if if inclusive { distance <= radius } else { distance < radius } { let item = unsafe { leaf_node.content_items.get_unchecked(idx) }; let item = *transform(item); diff --git a/src/fixed/query/within.rs b/src/fixed/query/within.rs index 094211af..785ed034 100644 --- a/src/fixed/query/within.rs +++ b/src/fixed/query/within.rs @@ -50,6 +50,7 @@ mod tests { use fixed::types::extra::U14; use fixed::FixedU16; use rand::Rng; + use rstest::rstest; use std::cmp::Ordering; type Fxd = FixedU16; @@ -154,6 +155,25 @@ mod tests { } } + #[rstest] + #[case(true, 1)] + #[case(false, 0)] + fn test_within_boundary_inclusiveness(#[case] inclusive: bool, #[case] expected_len: usize) { + let mut kdtree: KdTree = KdTree::new(); + kdtree.add(&[n(1.0), n(0.0)], 1); + kdtree.add(&[n(2.0), n(0.0)], 2); + + let query = [n(0.0), n(0.0)]; + let radius = n(1.0); + + let results = kdtree.within_exclusive::(&query, radius, inclusive); + assert_eq!(results.len(), expected_len); + if expected_len > 0 { + assert_eq!(results[0].item, 1); + assert_eq!(results[0].distance, n(1.0)); + } + } + fn linear_search( content: &[([A; K], u32)], query_point: &[A; K], diff --git a/src/float/distance.rs b/src/float/distance.rs index 035b1f0f..4bedd982 100644 --- a/src/float/distance.rs +++ b/src/float/distance.rs @@ -40,6 +40,11 @@ impl DistanceMetric for Manhattan { fn dist1(a: A, b: A) -> A { (a - b).abs() } + + #[inline] + fn accumulate(rd: A, delta: A) -> A { + rd + delta + } } /// Returns the Chebyshev / L-infinity distance between two points. @@ -115,6 +120,11 @@ impl DistanceMetric for SquaredEuclidean { fn dist1(a: A, b: A) -> A { (a - b) * (a - b) } + + #[inline] + fn accumulate(rd: A, delta: A) -> A { + rd + delta + } } /// Returns the Minkowski distance (power distance) between two points. @@ -859,7 +869,7 @@ mod tests { (DataScenario::Gaussian, d) => { let mut rng = StdRng::seed_from_u64(8757); let normal = Normal::new(1.0, 10.0).unwrap(); - let n_samples = 200; + let n_samples = 2000; let mut data = vec![vec![0.0; d]; n_samples]; for i in 0..n_samples { for j in 0..d { @@ -892,7 +902,7 @@ mod tests { /// - Point 0 is always the query point (distance 0, index 0 expected first result) /// - NoTies scenario: checks distances and item IDs for points with unique distances /// - Ties scenario: checks distances (order among ties is non-deterministic) - fn run_test_helper>( + fn run_nearest_n_test_helper>( dim: usize, tree_type: TreeType, scenario: DataScenario, @@ -939,7 +949,7 @@ mod tests { // Query based on tree type let results = match tree_type { TreeType::Mutable => { - let mut tree: crate::float::kdtree::KdTree = + let mut tree: crate::float::kdtree::KdTree = crate::float::kdtree::KdTree::new(); for (i, point) in points.iter().enumerate() { tree.add(point, i as u64); @@ -947,7 +957,7 @@ mod tests { tree.nearest_n::(&query_arr, n) } TreeType::Immutable => { - let tree: crate::immutable::float::kdtree::ImmutableKdTree = + let tree: crate::immutable::float::kdtree::ImmutableKdTree = crate::immutable::float::kdtree::ImmutableKdTree::new_from_slice(&points); tree.nearest_n::(&query_arr, std::num::NonZero::new(n).unwrap()) } @@ -993,7 +1003,7 @@ mod tests { #[values(1, 2, 3, 4, 5, 6)] n: usize, #[values(1, 2, 3, 4)] dim: usize, ) { - run_test_helper::(dim, tree_type, scenario, n); + run_nearest_n_test_helper::(dim, tree_type, scenario, n); } #[rstest] @@ -1004,7 +1014,7 @@ mod tests { #[values(1, 2, 3, 4, 5, 6)] n: usize, #[values(1, 2, 3, 4)] dim: usize, ) { - run_test_helper::(dim, tree_type, scenario, n); + run_nearest_n_test_helper::(dim, tree_type, scenario, n); } #[rstest] @@ -1015,7 +1025,7 @@ mod tests { #[values(1, 2, 3, 4, 5, 6)] n: usize, #[values(1, 2, 3, 4)] dim: usize, ) { - run_test_helper::(dim, tree_type, scenario, n); + run_nearest_n_test_helper::(dim, tree_type, scenario, n); } #[rstest] @@ -1028,8 +1038,8 @@ mod tests { #[values(3, 4)] p: u32, ) { match p { - 3 => run_test_helper::>(dim, tree_type, scenario, n), - 4 => run_test_helper::>(dim, tree_type, scenario, n), + 3 => run_nearest_n_test_helper::>(dim, tree_type, scenario, n), + 4 => run_nearest_n_test_helper::>(dim, tree_type, scenario, n), _ => unreachable!(), } } @@ -1044,9 +1054,13 @@ mod tests { #[values(0.5, 1.5)] p: f64, ) { if (p - 0.5).abs() < f64::EPSILON { - run_test_helper::>(dim, tree_type, scenario, n); + run_nearest_n_test_helper::>( + dim, tree_type, scenario, n, + ); } else if (p - 1.5).abs() < f64::EPSILON { - run_test_helper::>(dim, tree_type, scenario, n); + run_nearest_n_test_helper::>( + dim, tree_type, scenario, n, + ); } else { unreachable!() } @@ -1342,55 +1356,178 @@ mod tests { assert!(nearby_items.contains(&5)); } - #[test] - fn test_within_chebyshev_distance() { - let mut kdtree: KdTree = KdTree::new(); + /// Helper function to test within queries for `D: DistanceMetric` + fn run_within_test_helper>( + dim: usize, + tree_type: TreeType, + scenario: DataScenario, + radius: f64, + inclusive: bool, + ) { + let data = scenario.get(dim); + let query_point = &data[0]; - // Add points with varying Chebyshev distances - let points = [ - ([0.0f32, 0.0f32], 0), // distance 0 - ([0.5f32, 0.5f32], 1), // Chebyshev: 0.5 - ([1.0f32, 0.0f32], 2), // Chebyshev: 1.0 - ([0.8f32, 0.9f32], 3), // Chebyshev: 0.9 - ([2.0f32, 0.0f32], 4), // Chebyshev: 2.0 - ([0.0f32, 2.0f32], 5), // Chebyshev: 2.0 - ([1.5f32, 1.5f32], 6), // Chebyshev: 1.5 - ]; + let mut points: Vec<[f64; 6]> = Vec::with_capacity(data.len()); + for row in &data { + let mut p = [0.0; 6]; + for (i, &val) in row.iter().enumerate() { + p[i] = val; + } + points.push(p); + } - for (point, index) in points { - kdtree.add(&point, index); + let mut query_arr = [0.0; 6]; + for (i, &val) in query_point.iter().enumerate() { + if i < 6 { + query_arr[i] = val; + } } - let query_point = [0.0f32, 0.0f32]; - let radius = 1.0; // radius 1 (not squared for Chebyshev) - let mut results = kdtree.within::(&query_point, radius); - - // Sort by distance for easier verification - results.sort_by(|a, b| { - a.distance - .partial_cmp(&b.distance) - .unwrap_or(std::cmp::Ordering::Equal) - }); - - // These SHOULD be: 0, 1, 2, 3 (distances: 0, 0.5, 1.0, 0.9) - // For `<=5.2.4` found: 0, 1, 3 (index 2 is missing due to dist1 pruning issue) - let found_indices: Vec = results.iter().map(|r| r.item).collect(); - - assert!(found_indices.contains(&0)); - assert!(found_indices.contains(&1)); - assert!(found_indices.contains(&2)); - assert!(found_indices.contains(&3)); - // Should NOT include points with Chebyshev distance > 1 - assert!(!found_indices.contains(&4)); - assert!(!found_indices.contains(&5)); - assert!(!found_indices.contains(&6)); - - // Verify distances - for result in results { - assert!(result.distance <= 1.0 || (result.distance - 1.0).abs() < 0.001); + // Ground truth with brute-force + let mut expected: Vec<(usize, f64)> = points + .iter() + .enumerate() + .filter_map(|(i, &point)| { + let dist = D::dist(&query_arr, &point); + if if inclusive { + dist <= radius + } else { + dist < radius + } { + Some((i, dist)) + } else { + None + } + }) + .collect(); + + expected.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap()); + + println!( + "Within Query: TreeType: {:?}, Scenario: {:?}, dim={}, radius={}, inclusive={}", + tree_type, scenario, dim, radius, inclusive + ); + + // Query based on tree type + let mut results = match tree_type { + TreeType::Mutable => { + let mut tree: crate::float::kdtree::KdTree = + crate::float::kdtree::KdTree::new(); + for (i, point) in points.iter().enumerate() { + tree.add(point, i as u64); + } + tree.within_exclusive::(&query_arr, radius, inclusive) + } + TreeType::Immutable => { + let tree: crate::immutable::float::kdtree::ImmutableKdTree = + crate::immutable::float::kdtree::ImmutableKdTree::new_from_slice(&points); + tree.within_exclusive::(&query_arr, radius, inclusive) + } + }; + + results.sort_by(|a, b| a.distance.partial_cmp(&b.distance).unwrap()); + + println!( + "Results (len: {}), Expected (len: {})", + results.len(), + expected.len() + ); + + assert_eq!( + results.len(), + expected.len(), + "Result count mismatch. Expected {}, got {}", + expected.len(), + results.len() + ); + + for (i, result) in results.iter().enumerate() { + assert!( + (result.distance - expected[i].1).abs() < 1e-10, + "Distance at index {} should be {}, but was {}", + i, + expected[i].1, + result.distance + ); + } + + if matches!(scenario, DataScenario::NoTies) { + for (i, result) in results.iter().enumerate() { + let expected_id = expected[i].0; + assert_eq!( + result.item, expected_id as u64, + "Result {}: item ID mismatch. Expected {}, got {}", + i, expected_id, result.item + ); + } } } + #[rstest] + fn test_within_chebyshev( + #[values(TreeType::Mutable, TreeType::Immutable)] tree_type: TreeType, + #[values(DataScenario::NoTies, DataScenario::Ties, DataScenario::Gaussian)] + scenario: DataScenario, + #[values(0.1, 0.5, 1.0, 2.0)] radius: f64, + #[values(1, 2, 3, 4)] dim: usize, + #[values(true, false)] inclusive: bool, + ) { + run_within_test_helper::(dim, tree_type, scenario, radius, inclusive); + } + + #[rstest] + fn test_within_squared_euclidean( + #[values(TreeType::Mutable, TreeType::Immutable)] tree_type: TreeType, + #[values(DataScenario::NoTies, DataScenario::Ties, DataScenario::Gaussian)] + scenario: DataScenario, + #[values(0.1, 0.5, 1.0, 2.0)] radius: f64, + #[values(1, 2, 3, 4)] dim: usize, + #[values(true, false)] inclusive: bool, + ) { + run_within_test_helper::(dim, tree_type, scenario, radius, inclusive); + } + + #[rstest] + fn test_within_manhattan( + #[values(TreeType::Mutable, TreeType::Immutable)] tree_type: TreeType, + #[values(DataScenario::NoTies, DataScenario::Ties, DataScenario::Gaussian)] + scenario: DataScenario, + #[values(0.1, 0.5, 1.0, 2.0)] radius: f64, + #[values(1, 2, 3, 4)] dim: usize, + #[values(true, false)] inclusive: bool, + ) { + run_within_test_helper::(dim, tree_type, scenario, radius, inclusive); + } + + #[rstest] + #[case(true, 1)] + #[case(false, 0)] + fn test_within_boundary_inclusiveness( + #[case] inclusive: bool, + #[case] expected_len: usize, + ) { + let mut kdtree: KdTree = KdTree::new(); + kdtree.add(&[1.0, 0.0], 1); + kdtree.add(&[2.0, 0.0], 2); + + let query = [0.0, 0.0]; + let radius = 1.0; + + let results = kdtree.within_exclusive::(&query, radius, inclusive); + assert_eq!(results.len(), expected_len); + if expected_len > 0 { + assert_eq!(results[0].item, 1); + assert_eq!(results[0].distance, 1.0); + } + + // Test nearest_n_within_exclusive + let max_qty = std::num::NonZero::new(10).unwrap(); + let results = kdtree.nearest_n_within_exclusive::( + &query, radius, max_qty, true, inclusive, + ); + assert_eq!(results.len(), expected_len); + } + #[test] fn test_chebyshev_vs_manhattan_ordering() { let mut kdtree: KdTree = KdTree::new(); diff --git a/src/float_leaf_slice/fallback.rs b/src/float_leaf_slice/fallback.rs index d8fadfd7..25750d06 100644 --- a/src/float_leaf_slice/fallback.rs +++ b/src/float_leaf_slice/fallback.rs @@ -34,6 +34,7 @@ pub(crate) fn update_nearest_dists_within_autovec( items: &[T], radius: A, results: &mut R, + inclusive: bool, ) where usize: Cast, R: ResultCollection, @@ -42,7 +43,13 @@ pub(crate) fn update_nearest_dists_within_autovec( dists .iter() .zip(items.iter()) - .filter(|(&distance, _)| distance <= radius) + .filter(|(&distance, _)| { + if inclusive { + distance <= radius + } else { + distance < radius + } + }) .for_each(|(&distance, &item)| { results.add(NearestNeighbour { distance, item }); }); @@ -55,6 +62,7 @@ pub(crate) fn update_best_dists_within_autovec( radius: A, max_qty: usize, results: &mut BinaryHeap>, + inclusive: bool, ) where usize: Cast, { @@ -62,7 +70,13 @@ pub(crate) fn update_best_dists_within_autovec( dists .iter() .zip(items.iter()) - .filter(|(&distance, _)| distance <= radius) + .filter(|(&distance, _)| { + if inclusive { + distance <= radius + } else { + distance < radius + } + }) .for_each(|(&distance, &item)| { if results.len() < max_qty { results.push(BestNeighbour { distance, item }); @@ -125,7 +139,7 @@ mod tests { item: 100u32, }]; - update_nearest_dists_within_autovec(&dists[..], &items[..], radius, &mut results); + update_nearest_dists_within_autovec(&dists[..], &items[..], radius, &mut results, true); assert_eq!( results, @@ -157,7 +171,14 @@ mod tests { item: 100u32, }); - update_best_dists_within_autovec(&dists[..], &items[..], radius, max_qty, &mut results); + update_best_dists_within_autovec( + &dists[..], + &items[..], + radius, + max_qty, + &mut results, + true, + ); let results = results.into_vec(); diff --git a/src/float_leaf_slice/leaf_slice.rs b/src/float_leaf_slice/leaf_slice.rs index 8f2d0222..4030932b 100644 --- a/src/float_leaf_slice/leaf_slice.rs +++ b/src/float_leaf_slice/leaf_slice.rs @@ -140,6 +140,7 @@ where items: &[T; C], radius: Self, results: &mut R, + inclusive: bool, ) where R: ResultCollection, usize: Cast, @@ -151,6 +152,7 @@ where radius: Self, max_qty: usize, results: &mut BinaryHeap>, + inclusive: bool, ) where Self: Axis + Sized; } @@ -220,8 +222,13 @@ where } #[inline] - pub(crate) fn nearest_n_within(&self, query: &[A; K], radius: A, results: &mut R) - where + pub(crate) fn nearest_n_within( + &self, + query: &[A; K], + radius: A, + results: &mut R, + inclusive: bool, + ) where D: DistanceMetric, R: ResultCollection, { @@ -230,7 +237,7 @@ where for chunk in chunk_iter { let dists = A::dists_for_chunk::(chunk.0, query); - A::update_nearest_dists_within(dists, chunk.1, radius, results); + A::update_nearest_dists_within(dists, chunk.1, radius, results, inclusive); } #[allow(clippy::needless_range_loop)] @@ -240,7 +247,11 @@ where distance = D::accumulate(distance, D::dist1(remainder_points[dim][idx], query[dim])); }); - if distance <= radius { + if if inclusive { + distance <= radius + } else { + distance < radius + } { results.add(NearestNeighbour { distance, item: remainder_items[idx], @@ -256,6 +267,7 @@ where radius: A, max_qty: usize, results: &mut BinaryHeap>, + inclusive: bool, ) where D: DistanceMetric, { @@ -264,7 +276,7 @@ where for chunk in chunk_iter { let dists = A::dists_for_chunk::(chunk.0, query); - A::update_best_dists_within(dists, chunk.1, radius, max_qty, results); + A::update_best_dists_within(dists, chunk.1, radius, max_qty, results, inclusive); } #[allow(clippy::needless_range_loop)] @@ -275,7 +287,11 @@ where D::accumulate(distance, D::dist1(remainder_points[dim][idx], query[dim])); }); - if distance <= radius { + if if inclusive { + distance <= radius + } else { + distance < radius + } { let item = remainder_items[idx]; if results.len() < max_qty { results.push(BestNeighbour { distance, item }); @@ -336,10 +352,11 @@ where items: &[T; C], radius: f64, results: &mut R, + inclusive: bool, ) where R: ResultCollection, { - update_nearest_dists_within_autovec(&acc, items, radius, results) + update_nearest_dists_within_autovec(&acc, items, radius, results, inclusive) } #[inline] @@ -349,8 +366,9 @@ where radius: f64, max_qty: usize, results: &mut BinaryHeap>, + inclusive: bool, ) { - update_best_dists_within_autovec(&acc, items, radius, max_qty, results) + update_best_dists_within_autovec(&acc, items, radius, max_qty, results, inclusive) } } @@ -420,10 +438,11 @@ where items: &[T; C], radius: f32, results: &mut R, + inclusive: bool, ) where R: ResultCollection, { - update_nearest_dists_within_autovec(&acc, items, radius, results) + update_nearest_dists_within_autovec(&acc, items, radius, results, inclusive) } #[inline] @@ -433,8 +452,9 @@ where radius: f32, max_qty: usize, results: &mut BinaryHeap>, + inclusive: bool, ) { - update_best_dists_within_autovec(&acc, items, radius, max_qty, results) + update_best_dists_within_autovec(&acc, items, radius, max_qty, results, inclusive) } } @@ -503,7 +523,7 @@ mod test { item: 100u32, }); - f64::update_nearest_dists_within(dists, &items, radius, &mut results); + f64::update_nearest_dists_within(dists, &items, radius, &mut results, true); let results = results.into_vec(); @@ -537,7 +557,7 @@ mod test { item: 100u32, }); - f64::update_best_dists_within(dists, &items, radius, max_qty, &mut results); + f64::update_best_dists_within(dists, &items, radius, max_qty, &mut results, true); let results = results.into_vec(); @@ -569,7 +589,7 @@ mod test { item: 100u32, }); - f32::update_nearest_dists_within(dists, &items, radius, &mut results); + f32::update_nearest_dists_within(dists, &items, radius, &mut results, true); let results = results.into_vec(); @@ -603,7 +623,7 @@ mod test { item: 100u32, }); - f32::update_best_dists_within(dists, &items, radius, max_qty, &mut results); + f32::update_best_dists_within(dists, &items, radius, max_qty, &mut results, true); let results = results.into_vec(); @@ -644,7 +664,12 @@ mod test { }; let mut results: BinaryHeap> = BinaryHeap::with_capacity(10); - slice.nearest_n_within::(&[32.0f64, 0.0f64], 4.0f64, &mut results); + slice.nearest_n_within::( + &[32.0f64, 0.0f64], + 4.0f64, + &mut results, + true, + ); let items_found: Vec<_> = results.iter().map(|n| n.item).collect(); assert!( diff --git a/src/hybrid/query/within.rs b/src/hybrid/query/within.rs index 76505046..9bcba711 100644 --- a/src/hybrid/query/within.rs +++ b/src/hybrid/query/within.rs @@ -36,6 +36,24 @@ where /// ``` #[inline] pub fn within(&self, query: &[A; K], dist: A, distance_fn: &F) -> Vec> + where + F: Fn(&[A; K], &[A; K]) -> A, + { + self.within_exclusive(query, dist, distance_fn, true) + } + + /// Like [`within`] but allows controlling boundary inclusiveness. + /// + /// When `inclusive` is true, points at exactly the maximum distance are included. + /// When false, only points strictly less than the maximum distance are included. + #[inline] + pub fn within_exclusive( + &self, + query: &[A; K], + dist: A, + distance_fn: &F, + inclusive: bool, + ) -> Vec> where F: Fn(&[A; K], &[A; K]) -> A, { @@ -52,6 +70,7 @@ where &mut matching_items, &mut off, A::zero(), + inclusive, ); } @@ -68,6 +87,7 @@ where matching_items: &mut BinaryHeap>, off: &mut [A; K], rd: A, + inclusive: bool, ) where F: Fn(&[A; K], &[A; K]) -> A, { @@ -95,13 +115,14 @@ where matching_items, off, rd, + inclusive, ); // TODO: switch from dist_fn to a dist trait that can apply to 1D as well as KD // so that updating rd is not hardcoded to sq euclidean rd = rd + new_off * new_off - old_off * old_off; - if rd <= radius { + if if inclusive { rd <= radius } else { rd < radius } { off[split_dim] = new_off; self.within_recurse( query, @@ -112,6 +133,7 @@ where matching_items, off, rd, + inclusive, ); off[split_dim] = old_off; } @@ -128,7 +150,11 @@ where .for_each(|(idx, entry)| { let distance = distance_fn(query, entry); - if distance < radius { + if if inclusive { + distance <= radius + } else { + distance < radius + } { matching_items.push(Neighbour { distance, item: *leaf_node.content_items.get_unchecked(idx.az::()), @@ -144,6 +170,7 @@ mod tests { use crate::float::distance::manhattan; use crate::float::kdtree::{Axis, KdTree}; use rand::Rng; + use rstest::rstest; use std::cmp::Ordering; type AX = f32; @@ -245,6 +272,36 @@ mod tests { } } + #[rstest] + #[case(true, 1)] + #[case(false, 0)] + fn test_within_boundary_inclusiveness(#[case] inclusive: bool, #[case] expected_len: usize) { + let mut kdtree: KdTree = KdTree::new(); + kdtree.add(&[1.0, 0.0], 1); + kdtree.add(&[2.0, 0.0], 2); + + let query = [0.0, 0.0]; + let radius = 1.0; + + let results = kdtree.within_exclusive( + &query, + radius, + &|a, b| { + let mut dist = 0.0; + for i in 0..2 { + dist += (a[i] - b[i]) * (a[i] - b[i]); + } + dist + }, + inclusive, + ); + assert_eq!(results.len(), expected_len); + if expected_len > 0 { + assert_eq!(results[0].item, 1); + assert_eq!(results[0].distance, 1.0); + } + } + fn linear_search( content: &[([A; K], u32)], query_point: &[A; K], diff --git a/src/hybrid/query/within_unsorted.rs b/src/hybrid/query/within_unsorted.rs index fb93473f..dfaab2b0 100644 --- a/src/hybrid/query/within_unsorted.rs +++ b/src/hybrid/query/within_unsorted.rs @@ -38,6 +38,24 @@ where dist: A, distance_fn: &F, ) -> Vec> + where + F: Fn(&[A; K], &[A; K]) -> A, + { + self.within_unsorted_exclusive(query, dist, distance_fn, true) + } + + /// Like [`within_unsorted`] but allows controlling boundary inclusiveness. + /// + /// When `inclusive` is true, points at exactly the maximum distance are included. + /// When false, only points strictly less than the maximum distance are included. + #[inline] + pub fn within_unsorted_exclusive( + &self, + query: &[A; K], + dist: A, + distance_fn: &F, + inclusive: bool, + ) -> Vec> where F: Fn(&[A; K], &[A; K]) -> A, { @@ -54,6 +72,7 @@ where &mut matching_items, &mut off, A::zero(), + inclusive, ); } @@ -70,6 +89,7 @@ where matching_items: &mut Vec>, off: &mut [A; K], rd: A, + inclusive: bool, ) where F: Fn(&[A; K], &[A; K]) -> A, { @@ -97,13 +117,14 @@ where matching_items, off, rd, + inclusive, ); // TODO: switch from dist_fn to a dist trait that can apply to 1D as well as KD // so that updating rd is not hardcoded to sq euclidean rd = rd + new_off * new_off - old_off * old_off; - if rd <= radius { + if if inclusive { rd <= radius } else { rd < radius } { off[split_dim] = new_off; self.within_unsorted_recurse( query, @@ -114,6 +135,7 @@ where matching_items, off, rd, + inclusive, ); off[split_dim] = old_off; } @@ -130,7 +152,11 @@ where .for_each(|(idx, entry)| { let distance = distance_fn(query, entry); - if distance < radius { + if if inclusive { + distance <= radius + } else { + distance < radius + } { matching_items.push(Neighbour { distance, item: *leaf_node.content_items.get_unchecked(idx.az::()), @@ -146,6 +172,7 @@ mod tests { use crate::float::distance::squared_euclidean; use crate::float::kdtree::{Axis, KdTree}; use rand::Rng; + use rstest::rstest; use std::cmp::Ordering; type AX = f32; @@ -247,6 +274,39 @@ mod tests { } } + #[rstest] + #[case(true, 1)] + #[case(false, 0)] + fn test_within_unsorted_boundary_inclusiveness( + #[case] inclusive: bool, + #[case] expected_len: usize, + ) { + let mut kdtree: KdTree = KdTree::new(); + kdtree.add(&[1.0, 0.0], 1); + kdtree.add(&[2.0, 0.0], 2); + + let query = [0.0, 0.0]; + let radius = 1.0; + + let results = kdtree.within_unsorted_exclusive( + &query, + radius, + &|a, b| { + let mut dist = 0.0; + for i in 0..2 { + dist += (a[i] - b[i]) * (a[i] - b[i]); + } + dist + }, + inclusive, + ); + assert_eq!(results.len(), expected_len); + if expected_len > 0 { + assert_eq!(results[0].item, 1); + assert_eq!(results[0].distance, 1.0); + } + } + fn linear_search( content: &[([A; K], u32)], query_point: &[A; K], diff --git a/src/immutable/common/generate_immutable_best_n_within.rs b/src/immutable/common/generate_immutable_best_n_within.rs index e5b58371..e881d2b7 100644 --- a/src/immutable/common/generate_immutable_best_n_within.rs +++ b/src/immutable/common/generate_immutable_best_n_within.rs @@ -15,6 +15,27 @@ macro_rules! generate_immutable_best_n_within { usize: Cast, D: DistanceMetric, { + self.best_n_within_exclusive::(query, dist, max_qty, true) + } + + #[doc = concat!$comments] + #[inline] + pub fn best_n_within_exclusive( + &self, + query: &[A; K], + dist: A, + max_qty: NonZero, + inclusive: bool, + ) -> impl Iterator> + where + A: LeafSliceFloat + LeafSliceFloatChunk, + usize: Cast, + D: DistanceMetric, + { + // Like [`best_n_within`] but allows controlling boundary inclusiveness. + // + // When `inclusive` is true, points at exactly the maximum distance are included. + // When false, only points strictly less than the maximum distance are included. let mut off = [A::zero(); K]; let mut best_items: BinaryHeap> = BinaryHeap::with_capacity(max_qty.into()); @@ -35,6 +56,7 @@ macro_rules! generate_immutable_best_n_within { A::zero(), 0, 0, + inclusive, ); #[cfg(feature = "modified_van_emde_boas")] @@ -50,6 +72,7 @@ macro_rules! generate_immutable_best_n_within { 0, 0, 0, + inclusive, ); best_items.into_iter() @@ -69,13 +92,14 @@ macro_rules! generate_immutable_best_n_within { rd: A, mut level: usize, mut leaf_idx: usize, + inclusive: bool, ) where A: LeafSliceFloat + LeafSliceFloatChunk, usize: Cast, D: DistanceMetric, { if level as isize > i32::from(self.max_stem_level) as isize { - self.search_leaf_for_best_n_within::(query, radius, max_qty, best_items, leaf_idx as usize); + self.search_leaf_for_best_n_within::(query, radius, max_qty, best_items, leaf_idx as usize, inclusive); return; } @@ -107,11 +131,12 @@ macro_rules! generate_immutable_best_n_within { rd, level, closer_leaf_idx, + inclusive, ); rd = D::accumulate(rd, D::dist1(new_off, old_off)); - if rd <= radius { + if if inclusive { rd <= radius } else { rd < radius } { off[split_dim] = new_off; self.best_n_within_recurse::( query, @@ -124,6 +149,7 @@ macro_rules! generate_immutable_best_n_within { rd, level, further_leaf_idx, + inclusive, ); off[split_dim] = old_off; } @@ -144,6 +170,7 @@ macro_rules! generate_immutable_best_n_within { mut level: i32, mut minor_level: u32, mut leaf_idx: usize, + inclusive: bool, ) where A: LeafSliceFloat + LeafSliceFloatChunk, usize: Cast, @@ -153,7 +180,7 @@ macro_rules! generate_immutable_best_n_within { use $crate::modified_van_emde_boas::modified_van_emde_boas_get_child_idx_v2_branchless; if level > i32::from(self.max_stem_level) { - self.search_leaf_for_best_n_within::(query, radius, max_qty, best_items, leaf_idx as usize); + self.search_leaf_for_best_n_within::(query, radius, max_qty, best_items, leaf_idx as usize, inclusive); return; } @@ -188,11 +215,12 @@ macro_rules! generate_immutable_best_n_within { level, minor_level, closer_leaf_idx, + inclusive, ); rd = D::accumulate(rd, D::dist1(new_off, old_off)); - if rd <= radius { + if if inclusive { rd <= radius } else { rd < radius } { off[split_dim] = new_off; self.best_n_within_recurse::( query, @@ -206,6 +234,7 @@ macro_rules! generate_immutable_best_n_within { level, minor_level, further_leaf_idx, + inclusive, ); off[split_dim] = old_off; } @@ -219,6 +248,7 @@ macro_rules! generate_immutable_best_n_within { max_qty: usize, results: &mut BinaryHeap>, leaf_idx: usize, + inclusive: bool, ) where D: DistanceMetric, { @@ -229,6 +259,7 @@ macro_rules! generate_immutable_best_n_within { radius, max_qty, results, + inclusive, ); } }; diff --git a/src/immutable/common/generate_immutable_nearest_n_within.rs b/src/immutable/common/generate_immutable_nearest_n_within.rs index 3d3d8435..8c989e37 100644 --- a/src/immutable/common/generate_immutable_nearest_n_within.rs +++ b/src/immutable/common/generate_immutable_nearest_n_within.rs @@ -8,21 +8,34 @@ macro_rules! generate_immutable_nearest_n_within { where D: DistanceMetric, { + self.nearest_n_within_exclusive::(query, dist, max_items, sorted, true) + } + + #[doc = concat!$comments] + #[inline] + pub fn nearest_n_within_exclusive(&self, query: &[A; K], dist: A, max_items: NonZero, sorted: bool, inclusive: bool) -> Vec> + where + D: DistanceMetric, + { + // Like [`nearest_n_within`] but allows controlling boundary inclusiveness. + // + // When `inclusive` is true, points at exactly the maximum distance are included. + // When false, only points strictly less than the maximum distance are included. let max_items = max_items.into(); if sorted && max_items < usize::MAX { if max_items <= MAX_VEC_RESULT_SIZE { - self.nearest_n_within_stub::>>(query, dist, max_items, sorted) + self.nearest_n_within_stub::>>(query, dist, max_items, sorted, inclusive) } else { - self.nearest_n_within_stub::>>(query, dist, max_items, sorted) + self.nearest_n_within_stub::>>(query, dist, max_items, sorted, inclusive) } } else { - self.nearest_n_within_stub::>>(query, dist, 0, sorted) + self.nearest_n_within_stub::>>(query, dist, 0, sorted, inclusive) } } fn nearest_n_within_stub, H: ResultCollection>( - &self, query: &[A; K], dist: A, res_capacity: usize, sorted: bool + &self, query: &[A; K], dist: A, res_capacity: usize, sorted: bool, inclusive: bool ) -> Vec> { let mut matching_items = H::new_with_capacity(res_capacity); let mut off = [A::zero(); K]; @@ -38,6 +51,7 @@ macro_rules! generate_immutable_nearest_n_within { A::zero(), 0, 0, + inclusive, ); #[cfg(feature = "modified_van_emde_boas")] @@ -52,6 +66,7 @@ macro_rules! generate_immutable_nearest_n_within { 0, 0, 0, + inclusive, ); if sorted { @@ -74,12 +89,13 @@ macro_rules! generate_immutable_nearest_n_within { rd: A, mut level: usize, mut leaf_idx: usize, + inclusive: bool, ) where D: DistanceMetric, R: ResultCollection, { if level > i32::from(self.max_stem_level) as usize || self.stems.is_empty() { - self.search_leaf_for_nearest_n_within::(query, radius, matching_items, leaf_idx as usize); + self.search_leaf_for_nearest_n_within::(query, radius, matching_items, leaf_idx as usize, inclusive); return; } @@ -110,11 +126,12 @@ macro_rules! generate_immutable_nearest_n_within { rd, level, closer_leaf_idx, + inclusive, ); rd = D::accumulate(rd, D::dist1(new_off, old_off)); - if rd <= radius && rd < matching_items.max_dist() { + if if inclusive { rd <= radius } else { rd < radius } && rd < matching_items.max_dist() { off[split_dim] = new_off; self.nearest_n_within_recurse::( query, @@ -126,6 +143,7 @@ macro_rules! generate_immutable_nearest_n_within { rd, level, further_leaf_idx, + inclusive, ); off[split_dim] = old_off; } @@ -145,6 +163,7 @@ macro_rules! generate_immutable_nearest_n_within { mut level: i32, mut minor_level: u32, mut leaf_idx: usize, + inclusive: bool, ) where D: DistanceMetric, R: ResultCollection, @@ -153,7 +172,7 @@ macro_rules! generate_immutable_nearest_n_within { use $crate::modified_van_emde_boas::modified_van_emde_boas_get_child_idx_v2_branchless; if level > i32::from(self.max_stem_level) || self.stems.is_empty() { - self.search_leaf_for_nearest_n_within::(query, radius, matching_items, leaf_idx as usize); + self.search_leaf_for_nearest_n_within::(query, radius, matching_items, leaf_idx as usize, inclusive); return; } @@ -187,11 +206,12 @@ macro_rules! generate_immutable_nearest_n_within { level, minor_level, closer_leaf_idx, + inclusive, ); rd = D::accumulate(rd, D::dist1(new_off, old_off)); - if rd <= radius && rd < matching_items.max_dist() { + if if inclusive { rd <= radius } else { rd < radius } && rd < matching_items.max_dist() { off[split_dim] = new_off; self.nearest_n_within_recurse::( query, @@ -204,6 +224,7 @@ macro_rules! generate_immutable_nearest_n_within { level, minor_level, further_leaf_idx, + inclusive, ); off[split_dim] = old_off; } @@ -216,6 +237,7 @@ macro_rules! generate_immutable_nearest_n_within { radius: A, results: &mut R, leaf_idx: usize, + inclusive: bool, ) where D: DistanceMetric, R: ResultCollection, @@ -226,6 +248,7 @@ macro_rules! generate_immutable_nearest_n_within { query, radius, results, + inclusive, ); } }; diff --git a/src/immutable/common/generate_immutable_within.rs b/src/immutable/common/generate_immutable_within.rs index 83cc4725..6ed5f967 100644 --- a/src/immutable/common/generate_immutable_within.rs +++ b/src/immutable/common/generate_immutable_within.rs @@ -9,7 +9,21 @@ macro_rules! generate_immutable_within { A: LeafSliceFloat + LeafSliceFloatChunk, D: DistanceMetric, usize: Cast, { - self.nearest_n_within::(query, dist, std::num::NonZero::new(usize::MAX).unwrap(), true) + self.within_exclusive::(query, dist, true) + } + + #[doc = concat!$comments] + #[inline] + pub fn within_exclusive(&self, query: &[A; K], dist: A, inclusive: bool) -> Vec> + where + A: LeafSliceFloat + LeafSliceFloatChunk, + D: DistanceMetric, + usize: Cast, { + // Like [`within`] but allows controlling boundary inclusiveness. + // + // When `inclusive` is true, points at exactly the maximum distance are included. + // When false, only points strictly less than the maximum distance are included. + self.nearest_n_within_exclusive::(query, dist, std::num::NonZero::new(usize::MAX).unwrap(), true, inclusive) } }; } diff --git a/src/immutable/common/generate_immutable_within_unsorted.rs b/src/immutable/common/generate_immutable_within_unsorted.rs index c8910474..529b0de7 100644 --- a/src/immutable/common/generate_immutable_within_unsorted.rs +++ b/src/immutable/common/generate_immutable_within_unsorted.rs @@ -9,7 +9,21 @@ macro_rules! generate_immutable_within_unsorted { A: LeafSliceFloat + LeafSliceFloatChunk, D: DistanceMetric, usize: Cast, { - self.nearest_n_within::(query, dist, std::num::NonZero::new(usize::MAX).unwrap(), false) + self.within_unsorted_exclusive::(query, dist, true) + } + + #[doc = concat!$comments] + #[inline] + pub fn within_unsorted_exclusive(&self, query: &[A; K], dist: A, inclusive: bool) -> Vec> + where + A: LeafSliceFloat + LeafSliceFloatChunk, + D: DistanceMetric, + usize: Cast, { + // Like [`within_unsorted`] but allows controlling boundary inclusiveness. + // + // When `inclusive` is true, points at exactly the maximum distance are included. + // When false, only points strictly less than the maximum distance are included. + self.nearest_n_within_exclusive::(query, dist, std::num::NonZero::new(usize::MAX).unwrap(), false, inclusive) } }; }