From 2e91c4639b383f1ffe6fd3e4e5b45fbf6f00885d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Carlson=20B=C3=BCth?= Date: Sat, 7 Feb 2026 19:16:36 +0100 Subject: [PATCH 01/10] feat: add `D::accumulate` and `D::IS_MAX_BASED` MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - deprecate `rd_update` with `D::accumulate` for consistent sum-based and max-based metrics - conditional logic for SIMD (L1/L2) and general L∞ - differentiate distance accumulation behaviour --- src/fixed/distance.rs | 14 ++++++++++++++ src/float/distance.rs | 14 ++++++++++++++ 2 files changed, 28 insertions(+) diff --git a/src/fixed/distance.rs b/src/fixed/distance.rs index 5fb3693..0ce66b8 100644 --- a/src/fixed/distance.rs +++ b/src/fixed/distance.rs @@ -99,6 +99,13 @@ impl DistanceMetric for Chebyshev { delta } } + + #[inline] + fn accumulate(rd: A, delta: A) -> A { + rd.saturating_add(delta) + } + + const IS_MAX_BASED: bool = false; } /// Returns the squared euclidean distance between two points. @@ -457,4 +464,11 @@ mod integration_tests { ) { run_test_helper::(dim, scenario, n); } + + #[inline] + fn accumulate(rd: A, delta: A) -> A { + rd.saturating_add(delta) + } + + const IS_MAX_BASED: bool = false; } diff --git a/src/float/distance.rs b/src/float/distance.rs index 035b1f0..fb587cd 100644 --- a/src/float/distance.rs +++ b/src/float/distance.rs @@ -40,6 +40,13 @@ impl DistanceMetric for Manhattan { fn dist1(a: A, b: A) -> A { (a - b).abs() } + + #[inline] + fn accumulate(rd: A, delta: A) -> A { + rd + delta + } + + const IS_MAX_BASED: bool = false; } /// Returns the Chebyshev / L-infinity distance between two points. @@ -115,6 +122,13 @@ impl DistanceMetric for SquaredEuclidean { fn dist1(a: A, b: A) -> A { (a - b) * (a - b) } + + #[inline] + fn accumulate(rd: A, delta: A) -> A { + rd + delta + } + + const IS_MAX_BASED: bool = false; } /// Returns the Minkowski distance (power distance) between two points. From 773bd14a7509316789240f1b81297e7bc947d994 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Carlson=20B=C3=BCth?= Date: Sat, 7 Feb 2026 21:38:42 +0100 Subject: [PATCH 02/10] feat: add fixed Chebyshev distance metric - integration `nearest_n` tests (Chebyshev, Manhattan, SquaredEuclidean). --- src/fixed/distance.rs | 285 ++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 285 insertions(+) diff --git a/src/fixed/distance.rs b/src/fixed/distance.rs index 0ce66b8..ff42bfe 100644 --- a/src/fixed/distance.rs +++ b/src/fixed/distance.rs @@ -108,6 +108,67 @@ impl DistanceMetric for Chebyshev { const IS_MAX_BASED: bool = false; } +/// Returns the Chebyshev distance (L-infinity norm) between two points. +/// +/// This is the maximum of the absolute differences between coordinates of points. +/// +/// # Examples +/// +/// ```rust +/// use fixed::types::extra::U0; +/// use fixed::FixedU16; +/// use kiddo::traits::DistanceMetric; +/// use kiddo::fixed::distance::Chebyshev; +/// type Fxd = FixedU16; +/// +/// let ZERO = Fxd::from_num(0); +/// let ONE = Fxd::from_num(1); +/// let TWO = Fxd::from_num(2); +/// +/// assert_eq!(ZERO, Chebyshev::dist(&[ZERO, ZERO], &[ZERO, ZERO])); +/// assert_eq!(ONE, Chebyshev::dist(&[ZERO, ZERO], &[ONE, ZERO])); +/// assert_eq!(ONE, Chebyshev::dist(&[ZERO, ZERO], &[ONE, ONE])); +/// assert_eq!(TWO, Chebyshev::dist(&[ZERO, ZERO], &[TWO, ONE])); +/// ``` +pub struct Chebyshev {} + +impl DistanceMetric for Chebyshev { + #[inline] + fn dist(a: &[A; K], b: &[A; K]) -> A { + a.iter() + .zip(b.iter()) + .map(|(&a_val, &b_val)| { + if a_val > b_val { + a_val - b_val + } else { + b_val - a_val + } + }) + .reduce(|a, b| if a > b { a } else { b }) + .unwrap_or(A::ZERO) + } + + #[inline] + fn dist1(a: A, b: A) -> A { + if a > b { + a - b + } else { + b - a + } + } + + #[inline] + fn accumulate(rd: A, delta: A) -> A { + if rd > delta { + rd + } else { + delta + } + } + + const IS_MAX_BASED: bool = true; +} + /// Returns the squared euclidean distance between two points. /// /// Faster than Euclidean distance due to not needing a square root, but still @@ -472,3 +533,227 @@ mod integration_tests { const IS_MAX_BASED: bool = false; } + +#[cfg(test)] +mod tests { + use super::*; + use fixed::types::extra::U0; + use rstest::rstest; + + type FxdU16 = fixed::FixedU16; + + const ZERO: FxdU16 = FxdU16::ZERO; + const ONE: FxdU16 = FxdU16::lit("1"); + const TWO: FxdU16 = FxdU16::lit("2"); + const THREE: FxdU16 = FxdU16::lit("3"); + const FOUR: FxdU16 = FxdU16::lit("4"); + const FIVE: FxdU16 = FxdU16::lit("5"); + + #[rstest] + #[case([ZERO, ZERO], [ZERO, ZERO], ZERO)] + #[case([ZERO, ZERO], [ONE, ZERO], ONE)] + #[case([ZERO, ZERO], [ZERO, ONE], ONE)] + #[case([ZERO, ZERO], [ONE, ONE], ONE)] + #[case([ZERO, ZERO], [TWO, ONE], TWO)] + #[case([ZERO, ZERO], [ONE, TWO], TWO)] + fn test_chebyshev_distance_2d( + #[case] a: [FxdU16; 2], + #[case] b: [FxdU16; 2], + #[case] expected: FxdU16, + ) { + assert_eq!(Chebyshev::dist(&a, &b), expected); + } + + #[rstest] + #[case([ZERO, ZERO, ZERO], [ZERO, ZERO, ZERO], ZERO)] + #[case([ZERO, ZERO, ZERO], [ONE, TWO, THREE], THREE)] + #[case([FIVE, FIVE, FIVE], [ONE, TWO, THREE], FOUR)] + fn test_chebyshev_distance_3d( + #[case] a: [FxdU16; 3], + #[case] b: [FxdU16; 3], + #[case] expected: FxdU16, + ) { + assert_eq!(Chebyshev::dist(&a, &b), expected); + } + + #[rstest] + #[case([ZERO, ZERO], [ZERO, ZERO], ZERO)] + #[case([ZERO, ZERO], [ONE, ZERO], ONE)] + #[case([ZERO, ZERO], [ZERO, ONE], ONE)] + #[case([ZERO, ZERO], [ONE, ONE], TWO)] + #[case([TWO, THREE], [ONE, ONE], THREE)] + fn test_manhattan_distance_2d( + #[case] a: [FxdU16; 2], + #[case] b: [FxdU16; 2], + #[case] expected: FxdU16, + ) { + assert_eq!(Manhattan::dist(&a, &b), expected); + } +} + +#[cfg(test)] +mod integration_tests { + use super::*; + use crate::fixed::kdtree::KdTree; + use fixed::types::extra::U0; + use fixed::FixedU16; + use rstest::rstest; + + type FxdU16 = FixedU16; + + const ZERO: FxdU16 = FxdU16::ZERO; + const ONE: FxdU16 = FxdU16::lit("1"); + const TWO: FxdU16 = FxdU16::lit("2"); + const THREE: FxdU16 = FxdU16::lit("3"); + const FOUR: FxdU16 = FxdU16::lit("4"); + const FIVE: FxdU16 = FxdU16::lit("5"); + + enum DataScenario { + NoTies, + Ties, + } + + impl DataScenario { + fn get(&self, dim: usize) -> Vec> { + match (self, dim) { + (DataScenario::NoTies, 1) => { + vec![vec![ONE], vec![TWO], vec![THREE], vec![FOUR], vec![FIVE]] + } + (DataScenario::NoTies, 2) => vec![ + vec![ZERO, ZERO], + vec![ONE, ZERO], + vec![TWO, ZERO], + vec![THREE, ZERO], + vec![FOUR, ZERO], + vec![FIVE, ZERO], + ], + (DataScenario::NoTies, 3) => vec![ + vec![ZERO, ZERO, ZERO], + vec![ONE, ZERO, ZERO], + vec![TWO, ZERO, ZERO], + vec![THREE, ZERO, ZERO], + vec![FOUR, ZERO, ZERO], + vec![FIVE, ZERO, ZERO], + ], + (DataScenario::Ties, 1) => vec![ + vec![ZERO], + vec![ONE], + vec![ONE], + vec![TWO], + vec![THREE], + vec![THREE], + ], + (DataScenario::Ties, 2) => vec![ + vec![ZERO, ZERO], + vec![ONE, ZERO], + vec![ZERO, ONE], + vec![TWO, ZERO], + vec![ZERO, TWO], + vec![TWO, TWO], + ], + (DataScenario::Ties, 3) => vec![ + vec![ZERO, ZERO, ZERO], + vec![ONE, ZERO, ZERO], + vec![ZERO, ONE, ZERO], + vec![ZERO, ZERO, ONE], + vec![TWO, ZERO, ZERO], + vec![ZERO, TWO, ZERO], + ], + _ => panic!("Unsupported dimension"), + } + } + } + + fn run_test_helper>(dim: usize, scenario: DataScenario, n: usize) { + let data = scenario.get(dim); + let query_point = &data[0]; + + let mut points: Vec<[FxdU16; 6]> = Vec::with_capacity(data.len()); + for row in &data { + let mut p = [ZERO; 6]; + for (i, &val) in row.iter().enumerate() { + if i < 6 { + p[i] = val; + } + } + points.push(p); + } + + let mut query_arr = [ZERO; 6]; + for (i, &val) in query_point.iter().enumerate() { + if i < 6 { + query_arr[i] = val; + } + } + + let expected: Vec<(usize, FxdU16)> = points + .iter() + .enumerate() + .map(|(i, &point)| { + let dist = D::dist(&query_arr, &point); + (i, dist) + }) + .collect(); + + let expected_distances: Vec = expected.iter().map(|(_, d)| *d).collect(); + + let mut tree: KdTree = KdTree::new(); + for (i, point) in points.iter().enumerate() { + tree.add(point, i as u32); + } + + let results = tree.nearest_n::(&query_arr, n); + + assert_eq!(results[0].item, 0, "First result should be the query point"); + assert_eq!( + results[0].distance, ZERO, + "First result distance should be 0.0" + ); + + for (i, result) in results.iter().enumerate() { + assert_eq!( + result.distance, expected_distances[i], + "Distance at index {} should be {}, but was {}", + i, expected_distances[i], result.distance + ); + } + + if matches!(scenario, DataScenario::NoTies) { + for (i, result) in results.iter().enumerate() { + let expected_id = expected[i].0; + assert_eq!( + result.item, expected_id as u32, + "Result {}: item ID mismatch. Expected {}, got {}", + i, expected_id, result.item + ); + } + } + } + + #[rstest] + fn test_nearest_n_chebyshev( + #[values(DataScenario::NoTies, DataScenario::Ties)] scenario: DataScenario, + #[values(1, 2, 3, 4, 5, 6)] n: usize, + #[values(1, 2, 3)] dim: usize, + ) { + run_test_helper::(dim, scenario, n); + } + + #[rstest] + fn test_nearest_n_squared_euclidean( + #[values(DataScenario::NoTies, DataScenario::Ties)] scenario: DataScenario, + #[values(1, 2, 3, 4, 5, 6)] n: usize, + #[values(1, 2, 3)] dim: usize, + ) { + run_test_helper::(dim, scenario, n); + } + + #[rstest] + fn test_nearest_n_manhattan( + #[values(DataScenario::NoTies, DataScenario::Ties)] scenario: DataScenario, + #[values(1, 2, 3, 4, 5, 6)] n: usize, + #[values(1, 2, 3)] dim: usize, + ) { + run_test_helper::(dim, scenario, n); + } +} From 85423ae4e87c282ff60cc70e4ca003213ee09a65 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Carlson=20B=C3=BCth?= Date: Sun, 8 Feb 2026 00:26:21 +0100 Subject: [PATCH 03/10] refactor: remove `D::IS_MAX_BASED`, unify heap logic, doc & test - improve `DistanceMetric` doc - add Gaussian scenario to tests --- src/fixed/distance.rs | 6 ------ 1 file changed, 6 deletions(-) diff --git a/src/fixed/distance.rs b/src/fixed/distance.rs index ff42bfe..5b71d63 100644 --- a/src/fixed/distance.rs +++ b/src/fixed/distance.rs @@ -104,8 +104,6 @@ impl DistanceMetric for Chebyshev { fn accumulate(rd: A, delta: A) -> A { rd.saturating_add(delta) } - - const IS_MAX_BASED: bool = false; } /// Returns the Chebyshev distance (L-infinity norm) between two points. @@ -165,8 +163,6 @@ impl DistanceMetric for Chebyshev { delta } } - - const IS_MAX_BASED: bool = true; } /// Returns the squared euclidean distance between two points. @@ -530,8 +526,6 @@ mod integration_tests { fn accumulate(rd: A, delta: A) -> A { rd.saturating_add(delta) } - - const IS_MAX_BASED: bool = false; } #[cfg(test)] From 16086960d3c7dab13a84110021fbbeb25491c919 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Carlson=20B=C3=BCth?= Date: Thu, 12 Feb 2026 23:54:25 +0100 Subject: [PATCH 04/10] chore: add default implementation of `accumulate` to `DistanceMetric` trait - improve test coverage --- src/fixed/distance.rs | 95 ++++++++++++++++++++++++++++++++++++++----- 1 file changed, 85 insertions(+), 10 deletions(-) diff --git a/src/fixed/distance.rs b/src/fixed/distance.rs index 5b71d63..46acb4b 100644 --- a/src/fixed/distance.rs +++ b/src/fixed/distance.rs @@ -99,11 +99,6 @@ impl DistanceMetric for Chebyshev { delta } } - - #[inline] - fn accumulate(rd: A, delta: A) -> A { - rd.saturating_add(delta) - } } /// Returns the Chebyshev distance (L-infinity norm) between two points. @@ -521,11 +516,6 @@ mod integration_tests { ) { run_test_helper::(dim, scenario, n); } - - #[inline] - fn accumulate(rd: A, delta: A) -> A { - rd.saturating_add(delta) - } } #[cfg(test)] @@ -583,6 +573,91 @@ mod tests { ) { assert_eq!(Manhattan::dist(&a, &b), expected); } + + #[rstest] + #[case([ZERO, ZERO], [ZERO, ZERO], ZERO)] + #[case([ZERO, ZERO], [ONE, ZERO], ONE)] + #[case([ZERO, ZERO], [ZERO, ONE], ONE)] + #[case([ZERO, ZERO], [ONE, ONE], FxdU16::lit("2"))] + #[case([TWO, TWO], [ZERO, ZERO], FxdU16::lit("8"))] + #[case([ONE, TWO], [TWO, ONE], FxdU16::lit("2"))] + fn test_squared_euclidean_distance_2d( + #[case] a: [FxdU16; 2], + #[case] b: [FxdU16; 2], + #[case] expected: FxdU16, + ) { + assert_eq!(SquaredEuclidean::dist(&a, &b), expected); + } + + #[rstest] + #[case([ZERO, ZERO, ZERO], [ZERO, ZERO, ZERO], ZERO)] + #[case([ZERO, ZERO, ZERO], [ONE, ZERO, ZERO], ONE)] + #[case([ONE, ONE, ONE], [TWO, TWO, TWO], THREE)] + fn test_squared_euclidean_distance_3d( + #[case] a: [FxdU16; 3], + #[case] b: [FxdU16; 3], + #[case] expected: FxdU16, + ) { + assert_eq!(SquaredEuclidean::dist(&a, &b), expected); + } + + #[rstest] + #[case::zero(ZERO, ZERO, ZERO)] + #[case::pos(ONE, ZERO, ONE)] + #[case::neg(ZERO, ONE, ONE)] + #[case::diff(THREE, ONE, TWO)] + fn test_manhattan_dist1(#[case] a: FxdU16, #[case] b: FxdU16, #[case] expected: FxdU16) { + assert_eq!( + >::dist1(a, b), + expected + ); + } + + #[rstest] + #[case::zero(ZERO, ZERO, ZERO)] + #[case::pos(ONE, ZERO, ONE)] + #[case::neg(ZERO, ONE, ONE)] + #[case::a_larger(TWO, ONE, ONE)] + #[case::b_larger(ONE, TWO, ONE)] + fn test_chebyshev_dist1(#[case] a: FxdU16, #[case] b: FxdU16, #[case] expected: FxdU16) { + assert_eq!( + >::dist1(a, b), + expected + ); + } + + #[rstest] + #[case::zero(ZERO, ZERO, ZERO)] + #[case::pos(ONE, ZERO, ONE)] + #[case::neg(ZERO, ONE, ONE)] + #[case::a_larger(TWO, ONE, ONE)] + #[case::b_larger(ONE, TWO, ONE)] + fn test_squared_euclidean_dist1( + #[case] a: FxdU16, + #[case] b: FxdU16, + #[case] expected: FxdU16, + ) { + assert_eq!( + >::dist1(a, b), + expected + ); + } + + #[rstest] + #[case::zero_one(ZERO, ONE, ONE)] + #[case::one_zero(ONE, ZERO, ONE)] + #[case::first_larger(ONE, TWO, TWO)] + #[case::second_larger(TWO, ONE, TWO)] + fn test_chebyshev_accumulate( + #[case] rd: FxdU16, + #[case] delta: FxdU16, + #[case] expected: FxdU16, + ) { + assert_eq!( + >::accumulate(rd, delta), + expected + ); + } } #[cfg(test)] From 1fc989da3d8cac56fc579e8a62914c9b99774b4f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Carlson=20B=C3=BCth?= Date: Fri, 13 Feb 2026 01:04:21 +0100 Subject: [PATCH 05/10] chore: saturating add for fixed metrics --- src/fixed/distance.rs | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/src/fixed/distance.rs b/src/fixed/distance.rs index 46acb4b..31dd3dd 100644 --- a/src/fixed/distance.rs +++ b/src/fixed/distance.rs @@ -99,6 +99,11 @@ impl DistanceMetric for Chebyshev { delta } } + + #[inline] + fn accumulate(rd: A, delta: A) -> A { + rd.saturating_add(delta) + } } /// Returns the Chebyshev distance (L-infinity norm) between two points. @@ -516,6 +521,11 @@ mod integration_tests { ) { run_test_helper::(dim, scenario, n); } + + #[inline] + fn accumulate(rd: A, delta: A) -> A { + rd.saturating_add(delta) + } } #[cfg(test)] From 5ac694679b9b0b2afdb1af2a9f736dc06b8513ee Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Carlson=20B=C3=BCth?= Date: Fri, 27 Feb 2026 09:54:24 +0100 Subject: [PATCH 06/10] fix: use saturating arithmetic for fixed-point distances - `A::dist` - `saturating_mul` --- src/fixed/distance.rs | 14 ++------------ 1 file changed, 2 insertions(+), 12 deletions(-) diff --git a/src/fixed/distance.rs b/src/fixed/distance.rs index 31dd3dd..aa4cfb9 100644 --- a/src/fixed/distance.rs +++ b/src/fixed/distance.rs @@ -135,24 +135,14 @@ impl DistanceMetric for Chebyshev { fn dist(a: &[A; K], b: &[A; K]) -> A { a.iter() .zip(b.iter()) - .map(|(&a_val, &b_val)| { - if a_val > b_val { - a_val - b_val - } else { - b_val - a_val - } - }) + .map(|(&a_val, &b_val)| a_val.dist(b_val)) .reduce(|a, b| if a > b { a } else { b }) .unwrap_or(A::ZERO) } #[inline] fn dist1(a: A, b: A) -> A { - if a > b { - a - b - } else { - b - a - } + a.dist(b) } #[inline] From 52a65e2fb9ca61117219bc25a0c47e3e2781cb37 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Carlson=20B=C3=BCth?= Date: Thu, 26 Feb 2026 11:45:50 +0100 Subject: [PATCH 07/10] feat: add `inclusive` flag to distance-based queries - add flag into query logic - add test `test_within_squared_euclidean` for all metrics --- .../generate_nearest_n_within_unsorted.rs | 25 +- src/common/generate_within.rs | 11 +- src/common/generate_within_unsorted.rs | 17 +- src/common/generate_within_unsorted_iter.rs | 22 +- src/fixed/query/within.rs | 20 ++ src/float/distance.rs | 234 ++++++++++++++---- src/float_leaf_slice/fallback.rs | 29 ++- src/float_leaf_slice/leaf_slice.rs | 55 ++-- src/hybrid/query/within.rs | 53 +++- src/hybrid/query/within_unsorted.rs | 53 +++- .../generate_immutable_best_n_within.rs | 35 ++- .../generate_immutable_nearest_n_within.rs | 35 ++- .../common/generate_immutable_within.rs | 12 +- .../generate_immutable_within_unsorted.rs | 12 +- 14 files changed, 512 insertions(+), 101 deletions(-) diff --git a/src/common/generate_nearest_n_within_unsorted.rs b/src/common/generate_nearest_n_within_unsorted.rs index 09d56c4..cef7fb6 100644 --- a/src/common/generate_nearest_n_within_unsorted.rs +++ b/src/common/generate_nearest_n_within_unsorted.rs @@ -5,22 +5,31 @@ macro_rules! generate_nearest_n_within_unsorted { #[doc = concat!$comments] #[inline] pub fn nearest_n_within(&self, query: &[A; K], dist: A, max_items: std::num::NonZero, sorted: bool) -> Vec> + where + D: DistanceMetric, + { + self.nearest_n_within_with_condition::(query, dist, max_items, sorted, true) + } + + #[doc = concat!$comments] + #[inline] + pub fn nearest_n_within_with_condition(&self, query: &[A; K], dist: A, max_items: std::num::NonZero, sorted: bool, inclusive: bool) -> Vec> where D: DistanceMetric, { if sorted || max_items < std::num::NonZero::new(usize::MAX).unwrap() { if max_items <= std::num::NonZero::new(MAX_VEC_RESULT_SIZE).unwrap() { - self.nearest_n_within_stub::>>(query, dist, max_items.get(), sorted) + self.nearest_n_within_stub::>>(query, dist, max_items.get(), sorted, inclusive) } else { - self.nearest_n_within_stub::>>(query, dist, max_items.get(), sorted) + self.nearest_n_within_stub::>>(query, dist, max_items.get(), sorted, inclusive) } } else { - self.nearest_n_within_stub::>>(query, dist, 0, sorted) + self.nearest_n_within_stub::>>(query, dist, 0, sorted, inclusive) } } fn nearest_n_within_stub, H: ResultCollection>( - &self, query: &[A; K], dist: A, res_capacity: usize, sorted: bool + &self, query: &[A; K], dist: A, res_capacity: usize, sorted: bool, inclusive: bool ) -> Vec> { let mut matching_items = H::new_with_capacity(res_capacity); let mut off = [A::zero(); K]; @@ -35,6 +44,7 @@ macro_rules! generate_nearest_n_within_unsorted { &mut matching_items, &mut off, A::zero(), + inclusive, ); } @@ -55,6 +65,7 @@ macro_rules! generate_nearest_n_within_unsorted { matching_items: &mut R, off: &mut [A; K], rd: A, + inclusive: bool, ) where D: DistanceMetric, { @@ -84,11 +95,12 @@ macro_rules! generate_nearest_n_within_unsorted { matching_items, off, rd, + inclusive, ); rd = D::accumulate(rd, D::dist1(new_off, old_off)); - if rd <= radius { + if if inclusive { rd <= radius } else { rd < radius } { off[split_dim] = new_off; self.nearest_n_within_unsorted_recurse::( query, @@ -98,6 +110,7 @@ macro_rules! generate_nearest_n_within_unsorted { matching_items, off, rd, + inclusive, ); off[split_dim] = old_off; } @@ -116,7 +129,7 @@ macro_rules! generate_nearest_n_within_unsorted { .for_each(|(idx, entry)| { let distance = D::dist(query, transform(entry)); - if distance <= radius { + if if inclusive { distance <= radius } else { distance < radius } { let item = unsafe { leaf_node.content_items.get_unchecked(idx) }; let item = *transform(item); diff --git a/src/common/generate_within.rs b/src/common/generate_within.rs index 7952636..642d02f 100644 --- a/src/common/generate_within.rs +++ b/src/common/generate_within.rs @@ -8,7 +8,16 @@ macro_rules! generate_within { where D: DistanceMetric, { - let mut matching_items = self.within_unsorted::(query, dist); + self.within_with_condition::(query, dist, true) + } + + #[doc = concat!$comments] + #[inline] + pub fn within_with_condition(&self, query: &[A; K], dist: A, inclusive: bool) -> Vec> + where + D: DistanceMetric, + { + let mut matching_items = self.within_unsorted_with_condition::(query, dist, inclusive); matching_items.sort(); matching_items } diff --git a/src/common/generate_within_unsorted.rs b/src/common/generate_within_unsorted.rs index 656ab03..10e8803 100644 --- a/src/common/generate_within_unsorted.rs +++ b/src/common/generate_within_unsorted.rs @@ -5,6 +5,15 @@ macro_rules! generate_within_unsorted { #[doc = concat!$comments] #[inline] pub fn within_unsorted(&self, query: &[A; K], dist: A) -> Vec> + where + D: DistanceMetric, + { + self.within_unsorted_with_condition::(query, dist, true) + } + + #[doc = concat!$comments] + #[inline] + pub fn within_unsorted_with_condition(&self, query: &[A; K], dist: A, inclusive: bool) -> Vec> where D: DistanceMetric, { @@ -21,6 +30,7 @@ macro_rules! generate_within_unsorted { &mut matching_items, &mut off, A::zero(), + inclusive, ); } @@ -37,6 +47,7 @@ macro_rules! generate_within_unsorted { matching_items: &mut Vec>, off: &mut [A; K], rd: A, + inclusive: bool, ) where D: DistanceMetric, { @@ -66,11 +77,12 @@ macro_rules! generate_within_unsorted { matching_items, off, rd, + inclusive, ); rd = D::accumulate(rd, D::dist1(new_off, old_off)); - if rd <= radius { + if if inclusive { rd <= radius } else { rd < radius } { off[split_dim] = new_off; self.within_unsorted_recurse::( query, @@ -80,6 +92,7 @@ macro_rules! generate_within_unsorted { matching_items, off, rd, + inclusive, ); off[split_dim] = old_off; } @@ -98,7 +111,7 @@ macro_rules! generate_within_unsorted { .for_each(|(idx, entry)| { let distance = D::dist(query, transform(entry)); - if distance <= radius { + if if inclusive { distance <= radius } else { distance < radius } { let item = unsafe { leaf_node.content_items.get_unchecked(idx) }; let item = *transform(item); diff --git a/src/common/generate_within_unsorted_iter.rs b/src/common/generate_within_unsorted_iter.rs index a44ad86..8fc691d 100644 --- a/src/common/generate_within_unsorted_iter.rs +++ b/src/common/generate_within_unsorted_iter.rs @@ -9,6 +9,20 @@ macro_rules! generate_within_unsorted_iter { query: &'query [A; K], dist: A, ) -> WithinUnsortedIter<'a, A, T> + where + D: DistanceMetric, + { + self.within_unsorted_iter_with_condition::(query, dist, true) + } + + #[doc = concat!$comments] + #[inline] + pub fn within_unsorted_iter_with_condition( + &'a self, + query: &'query [A; K], + dist: A, + inclusive: bool, + ) -> WithinUnsortedIter<'a, A, T> where D: DistanceMetric, { @@ -27,6 +41,7 @@ macro_rules! generate_within_unsorted_iter { gen_scope, &mut off, A::zero(), + inclusive, ); } @@ -46,6 +61,7 @@ macro_rules! generate_within_unsorted_iter { mut gen_scope: Scope<'scope, 'a, (), NearestNeighbour>, off: &mut [A; K], rd: A, + inclusive: bool, ) -> Scope<'scope, 'a, (), NearestNeighbour> where D: DistanceMetric, @@ -76,11 +92,12 @@ macro_rules! generate_within_unsorted_iter { gen_scope, off, rd, + inclusive, ); rd = D::accumulate(rd, D::dist1(new_off, old_off)); - if rd <= radius { + if if inclusive { rd <= radius } else { rd < radius } { off[split_dim] = new_off; gen_scope = self.within_unsorted_iter_recurse::( query, @@ -90,6 +107,7 @@ macro_rules! generate_within_unsorted_iter { gen_scope, off, rd, + inclusive, ); off[split_dim] = old_off; } @@ -108,7 +126,7 @@ macro_rules! generate_within_unsorted_iter { .for_each(|(idx, entry)| { let distance = D::dist(query, transform(entry)); - if distance <= radius { + if if inclusive { distance <= radius } else { distance < radius } { let item = unsafe { leaf_node.content_items.get_unchecked(idx) }; let item = *transform(item); diff --git a/src/fixed/query/within.rs b/src/fixed/query/within.rs index 094211a..c00b13a 100644 --- a/src/fixed/query/within.rs +++ b/src/fixed/query/within.rs @@ -50,6 +50,7 @@ mod tests { use fixed::types::extra::U14; use fixed::FixedU16; use rand::Rng; + use rstest::rstest; use std::cmp::Ordering; type Fxd = FixedU16; @@ -154,6 +155,25 @@ mod tests { } } + #[rstest] + #[case(true, 1)] + #[case(false, 0)] + fn test_within_boundary_inclusiveness(#[case] inclusive: bool, #[case] expected_len: usize) { + let mut kdtree: KdTree = KdTree::new(); + kdtree.add(&[n(1.0), n(0.0)], 1); + kdtree.add(&[n(2.0), n(0.0)], 2); + + let query = [n(0.0), n(0.0)]; + let radius = n(1.0); + + let results = kdtree.within_with_condition::(&query, radius, inclusive); + assert_eq!(results.len(), expected_len); + if expected_len > 0 { + assert_eq!(results[0].item, 1); + assert_eq!(results[0].distance, n(1.0)); + } + } + fn linear_search( content: &[([A; K], u32)], query_point: &[A; K], diff --git a/src/float/distance.rs b/src/float/distance.rs index fb587cd..f44245a 100644 --- a/src/float/distance.rs +++ b/src/float/distance.rs @@ -873,7 +873,7 @@ mod tests { (DataScenario::Gaussian, d) => { let mut rng = StdRng::seed_from_u64(8757); let normal = Normal::new(1.0, 10.0).unwrap(); - let n_samples = 200; + let n_samples = 2000; let mut data = vec![vec![0.0; d]; n_samples]; for i in 0..n_samples { for j in 0..d { @@ -906,7 +906,7 @@ mod tests { /// - Point 0 is always the query point (distance 0, index 0 expected first result) /// - NoTies scenario: checks distances and item IDs for points with unique distances /// - Ties scenario: checks distances (order among ties is non-deterministic) - fn run_test_helper>( + fn run_nearest_n_test_helper>( dim: usize, tree_type: TreeType, scenario: DataScenario, @@ -953,7 +953,7 @@ mod tests { // Query based on tree type let results = match tree_type { TreeType::Mutable => { - let mut tree: crate::float::kdtree::KdTree = + let mut tree: crate::float::kdtree::KdTree = crate::float::kdtree::KdTree::new(); for (i, point) in points.iter().enumerate() { tree.add(point, i as u64); @@ -961,7 +961,7 @@ mod tests { tree.nearest_n::(&query_arr, n) } TreeType::Immutable => { - let tree: crate::immutable::float::kdtree::ImmutableKdTree = + let tree: crate::immutable::float::kdtree::ImmutableKdTree = crate::immutable::float::kdtree::ImmutableKdTree::new_from_slice(&points); tree.nearest_n::(&query_arr, std::num::NonZero::new(n).unwrap()) } @@ -1007,7 +1007,7 @@ mod tests { #[values(1, 2, 3, 4, 5, 6)] n: usize, #[values(1, 2, 3, 4)] dim: usize, ) { - run_test_helper::(dim, tree_type, scenario, n); + run_nearest_n_test_helper::(dim, tree_type, scenario, n); } #[rstest] @@ -1018,7 +1018,7 @@ mod tests { #[values(1, 2, 3, 4, 5, 6)] n: usize, #[values(1, 2, 3, 4)] dim: usize, ) { - run_test_helper::(dim, tree_type, scenario, n); + run_nearest_n_test_helper::(dim, tree_type, scenario, n); } #[rstest] @@ -1029,7 +1029,7 @@ mod tests { #[values(1, 2, 3, 4, 5, 6)] n: usize, #[values(1, 2, 3, 4)] dim: usize, ) { - run_test_helper::(dim, tree_type, scenario, n); + run_nearest_n_test_helper::(dim, tree_type, scenario, n); } #[rstest] @@ -1042,8 +1042,8 @@ mod tests { #[values(3, 4)] p: u32, ) { match p { - 3 => run_test_helper::>(dim, tree_type, scenario, n), - 4 => run_test_helper::>(dim, tree_type, scenario, n), + 3 => run_nearest_n_test_helper::>(dim, tree_type, scenario, n), + 4 => run_nearest_n_test_helper::>(dim, tree_type, scenario, n), _ => unreachable!(), } } @@ -1058,9 +1058,13 @@ mod tests { #[values(0.5, 1.5)] p: f64, ) { if (p - 0.5).abs() < f64::EPSILON { - run_test_helper::>(dim, tree_type, scenario, n); + run_nearest_n_test_helper::>( + dim, tree_type, scenario, n, + ); } else if (p - 1.5).abs() < f64::EPSILON { - run_test_helper::>(dim, tree_type, scenario, n); + run_nearest_n_test_helper::>( + dim, tree_type, scenario, n, + ); } else { unreachable!() } @@ -1356,53 +1360,177 @@ mod tests { assert!(nearby_items.contains(&5)); } - #[test] - fn test_within_chebyshev_distance() { - let mut kdtree: KdTree = KdTree::new(); + /// Helper function to test within queries for `D: DistanceMetric` + fn run_within_test_helper>( + dim: usize, + tree_type: TreeType, + scenario: DataScenario, + radius: f64, + inclusive: bool, + ) { + let data = scenario.get(dim); + let query_point = &data[0]; - // Add points with varying Chebyshev distances - let points = [ - ([0.0f32, 0.0f32], 0), // distance 0 - ([0.5f32, 0.5f32], 1), // Chebyshev: 0.5 - ([1.0f32, 0.0f32], 2), // Chebyshev: 1.0 - ([0.8f32, 0.9f32], 3), // Chebyshev: 0.9 - ([2.0f32, 0.0f32], 4), // Chebyshev: 2.0 - ([0.0f32, 2.0f32], 5), // Chebyshev: 2.0 - ([1.5f32, 1.5f32], 6), // Chebyshev: 1.5 - ]; + let mut points: Vec<[f64; 6]> = Vec::with_capacity(data.len()); + for row in &data { + let mut p = [0.0; 6]; + for (i, &val) in row.iter().enumerate() { + p[i] = val; + } + points.push(p); + } - for (point, index) in points { - kdtree.add(&point, index); + let mut query_arr = [0.0; 6]; + for (i, &val) in query_point.iter().enumerate() { + if i < 6 { + query_arr[i] = val; + } } - let query_point = [0.0f32, 0.0f32]; - let radius = 1.0; // radius 1 (not squared for Chebyshev) - let mut results = kdtree.within::(&query_point, radius); - - // Sort by distance for easier verification - results.sort_by(|a, b| { - a.distance - .partial_cmp(&b.distance) - .unwrap_or(std::cmp::Ordering::Equal) - }); - - // These SHOULD be: 0, 1, 2, 3 (distances: 0, 0.5, 1.0, 0.9) - // For `<=5.2.4` found: 0, 1, 3 (index 2 is missing due to dist1 pruning issue) - let found_indices: Vec = results.iter().map(|r| r.item).collect(); - - assert!(found_indices.contains(&0)); - assert!(found_indices.contains(&1)); - assert!(found_indices.contains(&2)); - assert!(found_indices.contains(&3)); - // Should NOT include points with Chebyshev distance > 1 - assert!(!found_indices.contains(&4)); - assert!(!found_indices.contains(&5)); - assert!(!found_indices.contains(&6)); - - // Verify distances - for result in results { - assert!(result.distance <= 1.0 || (result.distance - 1.0).abs() < 0.001); + // Ground truth with brute-force + let mut expected: Vec<(usize, f64)> = points + .iter() + .enumerate() + .filter_map(|(i, &point)| { + let dist = D::dist(&query_arr, &point); + if if inclusive { + dist <= radius + } else { + dist < radius + } { + Some((i, dist)) + } else { + None + } + }) + .collect(); + + expected.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap()); + + println!( + "Within Query: TreeType: {:?}, Scenario: {:?}, dim={}, radius={}, inclusive={}", + tree_type, scenario, dim, radius, inclusive + ); + + // Query based on tree type + let mut results = match tree_type { + TreeType::Mutable => { + let mut tree: crate::float::kdtree::KdTree = + crate::float::kdtree::KdTree::new(); + for (i, point) in points.iter().enumerate() { + tree.add(point, i as u64); + } + tree.within_with_condition::(&query_arr, radius, inclusive) + } + TreeType::Immutable => { + let tree: crate::immutable::float::kdtree::ImmutableKdTree = + crate::immutable::float::kdtree::ImmutableKdTree::new_from_slice(&points); + tree.within_with_condition::(&query_arr, radius, inclusive) + } + }; + + results.sort_by(|a, b| a.distance.partial_cmp(&b.distance).unwrap()); + + println!( + "Results (len: {}), Expected (len: {})", + results.len(), + expected.len() + ); + + assert_eq!( + results.len(), + expected.len(), + "Result count mismatch. Expected {}, got {}", + expected.len(), + results.len() + ); + + for (i, result) in results.iter().enumerate() { + assert!( + (result.distance - expected[i].1).abs() < 1e-10, + "Distance at index {} should be {}, but was {}", + i, + expected[i].1, + result.distance + ); } + + if matches!(scenario, DataScenario::NoTies) { + for (i, result) in results.iter().enumerate() { + let expected_id = expected[i].0; + assert_eq!( + result.item, expected_id as u64, + "Result {}: item ID mismatch. Expected {}, got {}", + i, expected_id, result.item + ); + } + } + } + + #[rstest] + fn test_within_chebyshev( + #[values(TreeType::Mutable, TreeType::Immutable)] tree_type: TreeType, + #[values(DataScenario::NoTies, DataScenario::Ties, DataScenario::Gaussian)] + scenario: DataScenario, + #[values(0.1, 0.5, 1.0, 2.0)] radius: f64, + #[values(1, 2, 3, 4)] dim: usize, + #[values(true, false)] inclusive: bool, + ) { + run_within_test_helper::(dim, tree_type, scenario, radius, inclusive); + } + + #[rstest] + fn test_within_squared_euclidean( + #[values(TreeType::Mutable, TreeType::Immutable)] tree_type: TreeType, + #[values(DataScenario::NoTies, DataScenario::Ties, DataScenario::Gaussian)] + scenario: DataScenario, + #[values(0.1, 0.5, 1.0, 2.0)] radius: f64, + #[values(1, 2, 3, 4)] dim: usize, + #[values(true, false)] inclusive: bool, + ) { + run_within_test_helper::(dim, tree_type, scenario, radius, inclusive); + } + + #[rstest] + fn test_within_manhattan( + #[values(TreeType::Mutable, TreeType::Immutable)] tree_type: TreeType, + #[values(DataScenario::NoTies, DataScenario::Ties, DataScenario::Gaussian)] + scenario: DataScenario, + #[values(0.1, 0.5, 1.0, 2.0)] radius: f64, + #[values(1, 2, 3, 4)] dim: usize, + #[values(true, false)] inclusive: bool, + ) { + run_within_test_helper::(dim, tree_type, scenario, radius, inclusive); + } + + #[rstest] + #[case(true, 1)] + #[case(false, 0)] + fn test_within_boundary_inclusiveness( + #[case] inclusive: bool, + #[case] expected_len: usize, + ) { + let mut kdtree: KdTree = KdTree::new(); + kdtree.add(&[1.0, 0.0], 1); + kdtree.add(&[2.0, 0.0], 2); + + let query = [0.0, 0.0]; + let radius = 1.0; + + let results = + kdtree.within_with_condition::(&query, radius, inclusive); + assert_eq!(results.len(), expected_len); + if expected_len > 0 { + assert_eq!(results[0].item, 1); + assert_eq!(results[0].distance, 1.0); + } + + // Test nearest_n_within_with_condition + let max_qty = std::num::NonZero::new(10).unwrap(); + let results = kdtree.nearest_n_within_with_condition::( + &query, radius, max_qty, true, inclusive, + ); + assert_eq!(results.len(), expected_len); } #[test] diff --git a/src/float_leaf_slice/fallback.rs b/src/float_leaf_slice/fallback.rs index d8fadfd..25750d0 100644 --- a/src/float_leaf_slice/fallback.rs +++ b/src/float_leaf_slice/fallback.rs @@ -34,6 +34,7 @@ pub(crate) fn update_nearest_dists_within_autovec( items: &[T], radius: A, results: &mut R, + inclusive: bool, ) where usize: Cast, R: ResultCollection, @@ -42,7 +43,13 @@ pub(crate) fn update_nearest_dists_within_autovec( dists .iter() .zip(items.iter()) - .filter(|(&distance, _)| distance <= radius) + .filter(|(&distance, _)| { + if inclusive { + distance <= radius + } else { + distance < radius + } + }) .for_each(|(&distance, &item)| { results.add(NearestNeighbour { distance, item }); }); @@ -55,6 +62,7 @@ pub(crate) fn update_best_dists_within_autovec( radius: A, max_qty: usize, results: &mut BinaryHeap>, + inclusive: bool, ) where usize: Cast, { @@ -62,7 +70,13 @@ pub(crate) fn update_best_dists_within_autovec( dists .iter() .zip(items.iter()) - .filter(|(&distance, _)| distance <= radius) + .filter(|(&distance, _)| { + if inclusive { + distance <= radius + } else { + distance < radius + } + }) .for_each(|(&distance, &item)| { if results.len() < max_qty { results.push(BestNeighbour { distance, item }); @@ -125,7 +139,7 @@ mod tests { item: 100u32, }]; - update_nearest_dists_within_autovec(&dists[..], &items[..], radius, &mut results); + update_nearest_dists_within_autovec(&dists[..], &items[..], radius, &mut results, true); assert_eq!( results, @@ -157,7 +171,14 @@ mod tests { item: 100u32, }); - update_best_dists_within_autovec(&dists[..], &items[..], radius, max_qty, &mut results); + update_best_dists_within_autovec( + &dists[..], + &items[..], + radius, + max_qty, + &mut results, + true, + ); let results = results.into_vec(); diff --git a/src/float_leaf_slice/leaf_slice.rs b/src/float_leaf_slice/leaf_slice.rs index 8f2d022..4030932 100644 --- a/src/float_leaf_slice/leaf_slice.rs +++ b/src/float_leaf_slice/leaf_slice.rs @@ -140,6 +140,7 @@ where items: &[T; C], radius: Self, results: &mut R, + inclusive: bool, ) where R: ResultCollection, usize: Cast, @@ -151,6 +152,7 @@ where radius: Self, max_qty: usize, results: &mut BinaryHeap>, + inclusive: bool, ) where Self: Axis + Sized; } @@ -220,8 +222,13 @@ where } #[inline] - pub(crate) fn nearest_n_within(&self, query: &[A; K], radius: A, results: &mut R) - where + pub(crate) fn nearest_n_within( + &self, + query: &[A; K], + radius: A, + results: &mut R, + inclusive: bool, + ) where D: DistanceMetric, R: ResultCollection, { @@ -230,7 +237,7 @@ where for chunk in chunk_iter { let dists = A::dists_for_chunk::(chunk.0, query); - A::update_nearest_dists_within(dists, chunk.1, radius, results); + A::update_nearest_dists_within(dists, chunk.1, radius, results, inclusive); } #[allow(clippy::needless_range_loop)] @@ -240,7 +247,11 @@ where distance = D::accumulate(distance, D::dist1(remainder_points[dim][idx], query[dim])); }); - if distance <= radius { + if if inclusive { + distance <= radius + } else { + distance < radius + } { results.add(NearestNeighbour { distance, item: remainder_items[idx], @@ -256,6 +267,7 @@ where radius: A, max_qty: usize, results: &mut BinaryHeap>, + inclusive: bool, ) where D: DistanceMetric, { @@ -264,7 +276,7 @@ where for chunk in chunk_iter { let dists = A::dists_for_chunk::(chunk.0, query); - A::update_best_dists_within(dists, chunk.1, radius, max_qty, results); + A::update_best_dists_within(dists, chunk.1, radius, max_qty, results, inclusive); } #[allow(clippy::needless_range_loop)] @@ -275,7 +287,11 @@ where D::accumulate(distance, D::dist1(remainder_points[dim][idx], query[dim])); }); - if distance <= radius { + if if inclusive { + distance <= radius + } else { + distance < radius + } { let item = remainder_items[idx]; if results.len() < max_qty { results.push(BestNeighbour { distance, item }); @@ -336,10 +352,11 @@ where items: &[T; C], radius: f64, results: &mut R, + inclusive: bool, ) where R: ResultCollection, { - update_nearest_dists_within_autovec(&acc, items, radius, results) + update_nearest_dists_within_autovec(&acc, items, radius, results, inclusive) } #[inline] @@ -349,8 +366,9 @@ where radius: f64, max_qty: usize, results: &mut BinaryHeap>, + inclusive: bool, ) { - update_best_dists_within_autovec(&acc, items, radius, max_qty, results) + update_best_dists_within_autovec(&acc, items, radius, max_qty, results, inclusive) } } @@ -420,10 +438,11 @@ where items: &[T; C], radius: f32, results: &mut R, + inclusive: bool, ) where R: ResultCollection, { - update_nearest_dists_within_autovec(&acc, items, radius, results) + update_nearest_dists_within_autovec(&acc, items, radius, results, inclusive) } #[inline] @@ -433,8 +452,9 @@ where radius: f32, max_qty: usize, results: &mut BinaryHeap>, + inclusive: bool, ) { - update_best_dists_within_autovec(&acc, items, radius, max_qty, results) + update_best_dists_within_autovec(&acc, items, radius, max_qty, results, inclusive) } } @@ -503,7 +523,7 @@ mod test { item: 100u32, }); - f64::update_nearest_dists_within(dists, &items, radius, &mut results); + f64::update_nearest_dists_within(dists, &items, radius, &mut results, true); let results = results.into_vec(); @@ -537,7 +557,7 @@ mod test { item: 100u32, }); - f64::update_best_dists_within(dists, &items, radius, max_qty, &mut results); + f64::update_best_dists_within(dists, &items, radius, max_qty, &mut results, true); let results = results.into_vec(); @@ -569,7 +589,7 @@ mod test { item: 100u32, }); - f32::update_nearest_dists_within(dists, &items, radius, &mut results); + f32::update_nearest_dists_within(dists, &items, radius, &mut results, true); let results = results.into_vec(); @@ -603,7 +623,7 @@ mod test { item: 100u32, }); - f32::update_best_dists_within(dists, &items, radius, max_qty, &mut results); + f32::update_best_dists_within(dists, &items, radius, max_qty, &mut results, true); let results = results.into_vec(); @@ -644,7 +664,12 @@ mod test { }; let mut results: BinaryHeap> = BinaryHeap::with_capacity(10); - slice.nearest_n_within::(&[32.0f64, 0.0f64], 4.0f64, &mut results); + slice.nearest_n_within::( + &[32.0f64, 0.0f64], + 4.0f64, + &mut results, + true, + ); let items_found: Vec<_> = results.iter().map(|n| n.item).collect(); assert!( diff --git a/src/hybrid/query/within.rs b/src/hybrid/query/within.rs index 7650504..449944d 100644 --- a/src/hybrid/query/within.rs +++ b/src/hybrid/query/within.rs @@ -36,6 +36,20 @@ where /// ``` #[inline] pub fn within(&self, query: &[A; K], dist: A, distance_fn: &F) -> Vec> + where + F: Fn(&[A; K], &[A; K]) -> A, + { + self.within_with_condition(query, dist, distance_fn, true) + } + + #[inline] + pub fn within_with_condition( + &self, + query: &[A; K], + dist: A, + distance_fn: &F, + inclusive: bool, + ) -> Vec> where F: Fn(&[A; K], &[A; K]) -> A, { @@ -52,6 +66,7 @@ where &mut matching_items, &mut off, A::zero(), + inclusive, ); } @@ -68,6 +83,7 @@ where matching_items: &mut BinaryHeap>, off: &mut [A; K], rd: A, + inclusive: bool, ) where F: Fn(&[A; K], &[A; K]) -> A, { @@ -95,13 +111,14 @@ where matching_items, off, rd, + inclusive, ); // TODO: switch from dist_fn to a dist trait that can apply to 1D as well as KD // so that updating rd is not hardcoded to sq euclidean rd = rd + new_off * new_off - old_off * old_off; - if rd <= radius { + if if inclusive { rd <= radius } else { rd < radius } { off[split_dim] = new_off; self.within_recurse( query, @@ -112,6 +129,7 @@ where matching_items, off, rd, + inclusive, ); off[split_dim] = old_off; } @@ -128,7 +146,7 @@ where .for_each(|(idx, entry)| { let distance = distance_fn(query, entry); - if distance < radius { + if if inclusive { distance <= radius } else { distance < radius } { matching_items.push(Neighbour { distance, item: *leaf_node.content_items.get_unchecked(idx.az::()), @@ -143,6 +161,7 @@ where mod tests { use crate::float::distance::manhattan; use crate::float::kdtree::{Axis, KdTree}; + use rstest::rstest; use rand::Rng; use std::cmp::Ordering; @@ -245,6 +264,36 @@ mod tests { } } + #[rstest] + #[case(true, 1)] + #[case(false, 0)] + fn test_within_boundary_inclusiveness(#[case] inclusive: bool, #[case] expected_len: usize) { + let mut kdtree: KdTree = KdTree::new(); + kdtree.add(&[1.0, 0.0], 1); + kdtree.add(&[2.0, 0.0], 2); + + let query = [0.0, 0.0]; + let radius = 1.0; + + let results = kdtree.within_with_condition( + &query, + radius, + &|a, b| { + let mut dist = 0.0; + for i in 0..2 { + dist += (a[i] - b[i]) * (a[i] - b[i]); + } + dist + }, + inclusive, + ); + assert_eq!(results.len(), expected_len); + if expected_len > 0 { + assert_eq!(results[0].item, 1); + assert_eq!(results[0].distance, 1.0); + } + } + fn linear_search( content: &[([A; K], u32)], query_point: &[A; K], diff --git a/src/hybrid/query/within_unsorted.rs b/src/hybrid/query/within_unsorted.rs index fb93473..e95baf6 100644 --- a/src/hybrid/query/within_unsorted.rs +++ b/src/hybrid/query/within_unsorted.rs @@ -38,6 +38,20 @@ where dist: A, distance_fn: &F, ) -> Vec> + where + F: Fn(&[A; K], &[A; K]) -> A, + { + self.within_unsorted_with_condition(query, dist, distance_fn, true) + } + + #[inline] + pub fn within_unsorted_with_condition( + &self, + query: &[A; K], + dist: A, + distance_fn: &F, + inclusive: bool, + ) -> Vec> where F: Fn(&[A; K], &[A; K]) -> A, { @@ -54,6 +68,7 @@ where &mut matching_items, &mut off, A::zero(), + inclusive, ); } @@ -70,6 +85,7 @@ where matching_items: &mut Vec>, off: &mut [A; K], rd: A, + inclusive: bool, ) where F: Fn(&[A; K], &[A; K]) -> A, { @@ -97,13 +113,14 @@ where matching_items, off, rd, + inclusive, ); // TODO: switch from dist_fn to a dist trait that can apply to 1D as well as KD // so that updating rd is not hardcoded to sq euclidean rd = rd + new_off * new_off - old_off * old_off; - if rd <= radius { + if if inclusive { rd <= radius } else { rd < radius } { off[split_dim] = new_off; self.within_unsorted_recurse( query, @@ -114,6 +131,7 @@ where matching_items, off, rd, + inclusive, ); off[split_dim] = old_off; } @@ -130,7 +148,7 @@ where .for_each(|(idx, entry)| { let distance = distance_fn(query, entry); - if distance < radius { + if if inclusive { distance <= radius } else { distance < radius } { matching_items.push(Neighbour { distance, item: *leaf_node.content_items.get_unchecked(idx.az::()), @@ -145,6 +163,7 @@ where mod tests { use crate::float::distance::squared_euclidean; use crate::float::kdtree::{Axis, KdTree}; + use rstest::rstest; use rand::Rng; use std::cmp::Ordering; @@ -247,6 +266,36 @@ mod tests { } } + #[rstest] + #[case(true, 1)] + #[case(false, 0)] + fn test_within_unsorted_boundary_inclusiveness(#[case] inclusive: bool, #[case] expected_len: usize) { + let mut kdtree: KdTree = KdTree::new(); + kdtree.add(&[1.0, 0.0], 1); + kdtree.add(&[2.0, 0.0], 2); + + let query = [0.0, 0.0]; + let radius = 1.0; + + let results = kdtree.within_unsorted_with_condition( + &query, + radius, + &|a, b| { + let mut dist = 0.0; + for i in 0..2 { + dist += (a[i] - b[i]) * (a[i] - b[i]); + } + dist + }, + inclusive, + ); + assert_eq!(results.len(), expected_len); + if expected_len > 0 { + assert_eq!(results[0].item, 1); + assert_eq!(results[0].distance, 1.0); + } + } + fn linear_search( content: &[([A; K], u32)], query_point: &[A; K], diff --git a/src/immutable/common/generate_immutable_best_n_within.rs b/src/immutable/common/generate_immutable_best_n_within.rs index e5b5837..81ec296 100644 --- a/src/immutable/common/generate_immutable_best_n_within.rs +++ b/src/immutable/common/generate_immutable_best_n_within.rs @@ -10,6 +10,23 @@ macro_rules! generate_immutable_best_n_within { dist: A, max_qty: NonZero, ) -> impl Iterator> + where + A: LeafSliceFloat + LeafSliceFloatChunk, + usize: Cast, + D: DistanceMetric, + { + self.best_n_within_with_condition::(query, dist, max_qty, true) + } + + #[doc = concat!$comments] + #[inline] + pub fn best_n_within_with_condition( + &self, + query: &[A; K], + dist: A, + max_qty: NonZero, + inclusive: bool, + ) -> impl Iterator> where A: LeafSliceFloat + LeafSliceFloatChunk, usize: Cast, @@ -35,6 +52,7 @@ macro_rules! generate_immutable_best_n_within { A::zero(), 0, 0, + inclusive, ); #[cfg(feature = "modified_van_emde_boas")] @@ -50,6 +68,7 @@ macro_rules! generate_immutable_best_n_within { 0, 0, 0, + inclusive, ); best_items.into_iter() @@ -69,13 +88,14 @@ macro_rules! generate_immutable_best_n_within { rd: A, mut level: usize, mut leaf_idx: usize, + inclusive: bool, ) where A: LeafSliceFloat + LeafSliceFloatChunk, usize: Cast, D: DistanceMetric, { if level as isize > i32::from(self.max_stem_level) as isize { - self.search_leaf_for_best_n_within::(query, radius, max_qty, best_items, leaf_idx as usize); + self.search_leaf_for_best_n_within::(query, radius, max_qty, best_items, leaf_idx as usize, inclusive); return; } @@ -107,11 +127,12 @@ macro_rules! generate_immutable_best_n_within { rd, level, closer_leaf_idx, + inclusive, ); rd = D::accumulate(rd, D::dist1(new_off, old_off)); - if rd <= radius { + if if inclusive { rd <= radius } else { rd < radius } { off[split_dim] = new_off; self.best_n_within_recurse::( query, @@ -124,6 +145,7 @@ macro_rules! generate_immutable_best_n_within { rd, level, further_leaf_idx, + inclusive, ); off[split_dim] = old_off; } @@ -144,6 +166,7 @@ macro_rules! generate_immutable_best_n_within { mut level: i32, mut minor_level: u32, mut leaf_idx: usize, + inclusive: bool, ) where A: LeafSliceFloat + LeafSliceFloatChunk, usize: Cast, @@ -153,7 +176,7 @@ macro_rules! generate_immutable_best_n_within { use $crate::modified_van_emde_boas::modified_van_emde_boas_get_child_idx_v2_branchless; if level > i32::from(self.max_stem_level) { - self.search_leaf_for_best_n_within::(query, radius, max_qty, best_items, leaf_idx as usize); + self.search_leaf_for_best_n_within::(query, radius, max_qty, best_items, leaf_idx as usize, inclusive); return; } @@ -188,11 +211,12 @@ macro_rules! generate_immutable_best_n_within { level, minor_level, closer_leaf_idx, + inclusive, ); rd = D::accumulate(rd, D::dist1(new_off, old_off)); - if rd <= radius { + if if inclusive { rd <= radius } else { rd < radius } { off[split_dim] = new_off; self.best_n_within_recurse::( query, @@ -206,6 +230,7 @@ macro_rules! generate_immutable_best_n_within { level, minor_level, further_leaf_idx, + inclusive, ); off[split_dim] = old_off; } @@ -219,6 +244,7 @@ macro_rules! generate_immutable_best_n_within { max_qty: usize, results: &mut BinaryHeap>, leaf_idx: usize, + inclusive: bool, ) where D: DistanceMetric, { @@ -229,6 +255,7 @@ macro_rules! generate_immutable_best_n_within { radius, max_qty, results, + inclusive, ); } }; diff --git a/src/immutable/common/generate_immutable_nearest_n_within.rs b/src/immutable/common/generate_immutable_nearest_n_within.rs index 3d3d843..877e5d1 100644 --- a/src/immutable/common/generate_immutable_nearest_n_within.rs +++ b/src/immutable/common/generate_immutable_nearest_n_within.rs @@ -5,6 +5,15 @@ macro_rules! generate_immutable_nearest_n_within { #[doc = concat!$comments] #[inline] pub fn nearest_n_within(&self, query: &[A; K], dist: A, max_items: NonZero, sorted: bool) -> Vec> + where + D: DistanceMetric, + { + self.nearest_n_within_with_condition::(query, dist, max_items, sorted, true) + } + + #[doc = concat!$comments] + #[inline] + pub fn nearest_n_within_with_condition(&self, query: &[A; K], dist: A, max_items: NonZero, sorted: bool, inclusive: bool) -> Vec> where D: DistanceMetric, { @@ -12,17 +21,17 @@ macro_rules! generate_immutable_nearest_n_within { if sorted && max_items < usize::MAX { if max_items <= MAX_VEC_RESULT_SIZE { - self.nearest_n_within_stub::>>(query, dist, max_items, sorted) + self.nearest_n_within_stub::>>(query, dist, max_items, sorted, inclusive) } else { - self.nearest_n_within_stub::>>(query, dist, max_items, sorted) + self.nearest_n_within_stub::>>(query, dist, max_items, sorted, inclusive) } } else { - self.nearest_n_within_stub::>>(query, dist, 0, sorted) + self.nearest_n_within_stub::>>(query, dist, 0, sorted, inclusive) } } fn nearest_n_within_stub, H: ResultCollection>( - &self, query: &[A; K], dist: A, res_capacity: usize, sorted: bool + &self, query: &[A; K], dist: A, res_capacity: usize, sorted: bool, inclusive: bool ) -> Vec> { let mut matching_items = H::new_with_capacity(res_capacity); let mut off = [A::zero(); K]; @@ -38,6 +47,7 @@ macro_rules! generate_immutable_nearest_n_within { A::zero(), 0, 0, + inclusive, ); #[cfg(feature = "modified_van_emde_boas")] @@ -52,6 +62,7 @@ macro_rules! generate_immutable_nearest_n_within { 0, 0, 0, + inclusive, ); if sorted { @@ -74,12 +85,13 @@ macro_rules! generate_immutable_nearest_n_within { rd: A, mut level: usize, mut leaf_idx: usize, + inclusive: bool, ) where D: DistanceMetric, R: ResultCollection, { if level > i32::from(self.max_stem_level) as usize || self.stems.is_empty() { - self.search_leaf_for_nearest_n_within::(query, radius, matching_items, leaf_idx as usize); + self.search_leaf_for_nearest_n_within::(query, radius, matching_items, leaf_idx as usize, inclusive); return; } @@ -110,11 +122,12 @@ macro_rules! generate_immutable_nearest_n_within { rd, level, closer_leaf_idx, + inclusive, ); rd = D::accumulate(rd, D::dist1(new_off, old_off)); - if rd <= radius && rd < matching_items.max_dist() { + if if inclusive { rd <= radius } else { rd < radius } && rd < matching_items.max_dist() { off[split_dim] = new_off; self.nearest_n_within_recurse::( query, @@ -126,6 +139,7 @@ macro_rules! generate_immutable_nearest_n_within { rd, level, further_leaf_idx, + inclusive, ); off[split_dim] = old_off; } @@ -145,6 +159,7 @@ macro_rules! generate_immutable_nearest_n_within { mut level: i32, mut minor_level: u32, mut leaf_idx: usize, + inclusive: bool, ) where D: DistanceMetric, R: ResultCollection, @@ -153,7 +168,7 @@ macro_rules! generate_immutable_nearest_n_within { use $crate::modified_van_emde_boas::modified_van_emde_boas_get_child_idx_v2_branchless; if level > i32::from(self.max_stem_level) || self.stems.is_empty() { - self.search_leaf_for_nearest_n_within::(query, radius, matching_items, leaf_idx as usize); + self.search_leaf_for_nearest_n_within::(query, radius, matching_items, leaf_idx as usize, inclusive); return; } @@ -187,11 +202,12 @@ macro_rules! generate_immutable_nearest_n_within { level, minor_level, closer_leaf_idx, + inclusive, ); rd = D::accumulate(rd, D::dist1(new_off, old_off)); - if rd <= radius && rd < matching_items.max_dist() { + if if inclusive { rd <= radius } else { rd < radius } && rd < matching_items.max_dist() { off[split_dim] = new_off; self.nearest_n_within_recurse::( query, @@ -204,6 +220,7 @@ macro_rules! generate_immutable_nearest_n_within { level, minor_level, further_leaf_idx, + inclusive, ); off[split_dim] = old_off; } @@ -216,6 +233,7 @@ macro_rules! generate_immutable_nearest_n_within { radius: A, results: &mut R, leaf_idx: usize, + inclusive: bool, ) where D: DistanceMetric, R: ResultCollection, @@ -226,6 +244,7 @@ macro_rules! generate_immutable_nearest_n_within { query, radius, results, + inclusive, ); } }; diff --git a/src/immutable/common/generate_immutable_within.rs b/src/immutable/common/generate_immutable_within.rs index 83cc472..ee5de48 100644 --- a/src/immutable/common/generate_immutable_within.rs +++ b/src/immutable/common/generate_immutable_within.rs @@ -9,7 +9,17 @@ macro_rules! generate_immutable_within { A: LeafSliceFloat + LeafSliceFloatChunk, D: DistanceMetric, usize: Cast, { - self.nearest_n_within::(query, dist, std::num::NonZero::new(usize::MAX).unwrap(), true) + self.within_with_condition::(query, dist, true) + } + + #[doc = concat!$comments] + #[inline] + pub fn within_with_condition(&self, query: &[A; K], dist: A, inclusive: bool) -> Vec> + where + A: LeafSliceFloat + LeafSliceFloatChunk, + D: DistanceMetric, + usize: Cast, { + self.nearest_n_within_with_condition::(query, dist, std::num::NonZero::new(usize::MAX).unwrap(), true, inclusive) } }; } diff --git a/src/immutable/common/generate_immutable_within_unsorted.rs b/src/immutable/common/generate_immutable_within_unsorted.rs index c891047..e8d7af2 100644 --- a/src/immutable/common/generate_immutable_within_unsorted.rs +++ b/src/immutable/common/generate_immutable_within_unsorted.rs @@ -9,7 +9,17 @@ macro_rules! generate_immutable_within_unsorted { A: LeafSliceFloat + LeafSliceFloatChunk, D: DistanceMetric, usize: Cast, { - self.nearest_n_within::(query, dist, std::num::NonZero::new(usize::MAX).unwrap(), false) + self.within_unsorted_with_condition::(query, dist, true) + } + + #[doc = concat!$comments] + #[inline] + pub fn within_unsorted_with_condition(&self, query: &[A; K], dist: A, inclusive: bool) -> Vec> + where + A: LeafSliceFloat + LeafSliceFloatChunk, + D: DistanceMetric, + usize: Cast, { + self.nearest_n_within_with_condition::(query, dist, std::num::NonZero::new(usize::MAX).unwrap(), false, inclusive) } }; } From ae548a908e2da1e69536c062470dd4bc94a91713 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Carlson=20B=C3=BCth?= Date: Sat, 14 Mar 2026 01:02:28 +0100 Subject: [PATCH 08/10] chore: fix rebase mistakes --- src/fixed/distance.rs | 368 ------------------------------------------ src/float/distance.rs | 4 - 2 files changed, 372 deletions(-) diff --git a/src/fixed/distance.rs b/src/fixed/distance.rs index aa4cfb9..5fb3693 100644 --- a/src/fixed/distance.rs +++ b/src/fixed/distance.rs @@ -76,60 +76,6 @@ impl DistanceMetric for Manhattan { /// ``` pub struct Chebyshev {} -impl DistanceMetric for Chebyshev { - #[inline] - fn dist(a: &[A; K], b: &[A; K]) -> A { - a.iter() - .zip(b.iter()) - .map(|(&a_val, &b_val)| a_val.dist(b_val)) - .reduce(|a, b| if a > b { a } else { b }) - .unwrap_or(A::ZERO) - } - - #[inline] - fn dist1(a: A, b: A) -> A { - a.dist(b) - } - - #[inline] - fn accumulate(rd: A, delta: A) -> A { - if rd > delta { - rd - } else { - delta - } - } - - #[inline] - fn accumulate(rd: A, delta: A) -> A { - rd.saturating_add(delta) - } -} - -/// Returns the Chebyshev distance (L-infinity norm) between two points. -/// -/// This is the maximum of the absolute differences between coordinates of points. -/// -/// # Examples -/// -/// ```rust -/// use fixed::types::extra::U0; -/// use fixed::FixedU16; -/// use kiddo::traits::DistanceMetric; -/// use kiddo::fixed::distance::Chebyshev; -/// type Fxd = FixedU16; -/// -/// let ZERO = Fxd::from_num(0); -/// let ONE = Fxd::from_num(1); -/// let TWO = Fxd::from_num(2); -/// -/// assert_eq!(ZERO, Chebyshev::dist(&[ZERO, ZERO], &[ZERO, ZERO])); -/// assert_eq!(ONE, Chebyshev::dist(&[ZERO, ZERO], &[ONE, ZERO])); -/// assert_eq!(ONE, Chebyshev::dist(&[ZERO, ZERO], &[ONE, ONE])); -/// assert_eq!(TWO, Chebyshev::dist(&[ZERO, ZERO], &[TWO, ONE])); -/// ``` -pub struct Chebyshev {} - impl DistanceMetric for Chebyshev { #[inline] fn dist(a: &[A; K], b: &[A; K]) -> A { @@ -511,318 +457,4 @@ mod integration_tests { ) { run_test_helper::(dim, scenario, n); } - - #[inline] - fn accumulate(rd: A, delta: A) -> A { - rd.saturating_add(delta) - } -} - -#[cfg(test)] -mod tests { - use super::*; - use fixed::types::extra::U0; - use rstest::rstest; - - type FxdU16 = fixed::FixedU16; - - const ZERO: FxdU16 = FxdU16::ZERO; - const ONE: FxdU16 = FxdU16::lit("1"); - const TWO: FxdU16 = FxdU16::lit("2"); - const THREE: FxdU16 = FxdU16::lit("3"); - const FOUR: FxdU16 = FxdU16::lit("4"); - const FIVE: FxdU16 = FxdU16::lit("5"); - - #[rstest] - #[case([ZERO, ZERO], [ZERO, ZERO], ZERO)] - #[case([ZERO, ZERO], [ONE, ZERO], ONE)] - #[case([ZERO, ZERO], [ZERO, ONE], ONE)] - #[case([ZERO, ZERO], [ONE, ONE], ONE)] - #[case([ZERO, ZERO], [TWO, ONE], TWO)] - #[case([ZERO, ZERO], [ONE, TWO], TWO)] - fn test_chebyshev_distance_2d( - #[case] a: [FxdU16; 2], - #[case] b: [FxdU16; 2], - #[case] expected: FxdU16, - ) { - assert_eq!(Chebyshev::dist(&a, &b), expected); - } - - #[rstest] - #[case([ZERO, ZERO, ZERO], [ZERO, ZERO, ZERO], ZERO)] - #[case([ZERO, ZERO, ZERO], [ONE, TWO, THREE], THREE)] - #[case([FIVE, FIVE, FIVE], [ONE, TWO, THREE], FOUR)] - fn test_chebyshev_distance_3d( - #[case] a: [FxdU16; 3], - #[case] b: [FxdU16; 3], - #[case] expected: FxdU16, - ) { - assert_eq!(Chebyshev::dist(&a, &b), expected); - } - - #[rstest] - #[case([ZERO, ZERO], [ZERO, ZERO], ZERO)] - #[case([ZERO, ZERO], [ONE, ZERO], ONE)] - #[case([ZERO, ZERO], [ZERO, ONE], ONE)] - #[case([ZERO, ZERO], [ONE, ONE], TWO)] - #[case([TWO, THREE], [ONE, ONE], THREE)] - fn test_manhattan_distance_2d( - #[case] a: [FxdU16; 2], - #[case] b: [FxdU16; 2], - #[case] expected: FxdU16, - ) { - assert_eq!(Manhattan::dist(&a, &b), expected); - } - - #[rstest] - #[case([ZERO, ZERO], [ZERO, ZERO], ZERO)] - #[case([ZERO, ZERO], [ONE, ZERO], ONE)] - #[case([ZERO, ZERO], [ZERO, ONE], ONE)] - #[case([ZERO, ZERO], [ONE, ONE], FxdU16::lit("2"))] - #[case([TWO, TWO], [ZERO, ZERO], FxdU16::lit("8"))] - #[case([ONE, TWO], [TWO, ONE], FxdU16::lit("2"))] - fn test_squared_euclidean_distance_2d( - #[case] a: [FxdU16; 2], - #[case] b: [FxdU16; 2], - #[case] expected: FxdU16, - ) { - assert_eq!(SquaredEuclidean::dist(&a, &b), expected); - } - - #[rstest] - #[case([ZERO, ZERO, ZERO], [ZERO, ZERO, ZERO], ZERO)] - #[case([ZERO, ZERO, ZERO], [ONE, ZERO, ZERO], ONE)] - #[case([ONE, ONE, ONE], [TWO, TWO, TWO], THREE)] - fn test_squared_euclidean_distance_3d( - #[case] a: [FxdU16; 3], - #[case] b: [FxdU16; 3], - #[case] expected: FxdU16, - ) { - assert_eq!(SquaredEuclidean::dist(&a, &b), expected); - } - - #[rstest] - #[case::zero(ZERO, ZERO, ZERO)] - #[case::pos(ONE, ZERO, ONE)] - #[case::neg(ZERO, ONE, ONE)] - #[case::diff(THREE, ONE, TWO)] - fn test_manhattan_dist1(#[case] a: FxdU16, #[case] b: FxdU16, #[case] expected: FxdU16) { - assert_eq!( - >::dist1(a, b), - expected - ); - } - - #[rstest] - #[case::zero(ZERO, ZERO, ZERO)] - #[case::pos(ONE, ZERO, ONE)] - #[case::neg(ZERO, ONE, ONE)] - #[case::a_larger(TWO, ONE, ONE)] - #[case::b_larger(ONE, TWO, ONE)] - fn test_chebyshev_dist1(#[case] a: FxdU16, #[case] b: FxdU16, #[case] expected: FxdU16) { - assert_eq!( - >::dist1(a, b), - expected - ); - } - - #[rstest] - #[case::zero(ZERO, ZERO, ZERO)] - #[case::pos(ONE, ZERO, ONE)] - #[case::neg(ZERO, ONE, ONE)] - #[case::a_larger(TWO, ONE, ONE)] - #[case::b_larger(ONE, TWO, ONE)] - fn test_squared_euclidean_dist1( - #[case] a: FxdU16, - #[case] b: FxdU16, - #[case] expected: FxdU16, - ) { - assert_eq!( - >::dist1(a, b), - expected - ); - } - - #[rstest] - #[case::zero_one(ZERO, ONE, ONE)] - #[case::one_zero(ONE, ZERO, ONE)] - #[case::first_larger(ONE, TWO, TWO)] - #[case::second_larger(TWO, ONE, TWO)] - fn test_chebyshev_accumulate( - #[case] rd: FxdU16, - #[case] delta: FxdU16, - #[case] expected: FxdU16, - ) { - assert_eq!( - >::accumulate(rd, delta), - expected - ); - } -} - -#[cfg(test)] -mod integration_tests { - use super::*; - use crate::fixed::kdtree::KdTree; - use fixed::types::extra::U0; - use fixed::FixedU16; - use rstest::rstest; - - type FxdU16 = FixedU16; - - const ZERO: FxdU16 = FxdU16::ZERO; - const ONE: FxdU16 = FxdU16::lit("1"); - const TWO: FxdU16 = FxdU16::lit("2"); - const THREE: FxdU16 = FxdU16::lit("3"); - const FOUR: FxdU16 = FxdU16::lit("4"); - const FIVE: FxdU16 = FxdU16::lit("5"); - - enum DataScenario { - NoTies, - Ties, - } - - impl DataScenario { - fn get(&self, dim: usize) -> Vec> { - match (self, dim) { - (DataScenario::NoTies, 1) => { - vec![vec![ONE], vec![TWO], vec![THREE], vec![FOUR], vec![FIVE]] - } - (DataScenario::NoTies, 2) => vec![ - vec![ZERO, ZERO], - vec![ONE, ZERO], - vec![TWO, ZERO], - vec![THREE, ZERO], - vec![FOUR, ZERO], - vec![FIVE, ZERO], - ], - (DataScenario::NoTies, 3) => vec![ - vec![ZERO, ZERO, ZERO], - vec![ONE, ZERO, ZERO], - vec![TWO, ZERO, ZERO], - vec![THREE, ZERO, ZERO], - vec![FOUR, ZERO, ZERO], - vec![FIVE, ZERO, ZERO], - ], - (DataScenario::Ties, 1) => vec![ - vec![ZERO], - vec![ONE], - vec![ONE], - vec![TWO], - vec![THREE], - vec![THREE], - ], - (DataScenario::Ties, 2) => vec![ - vec![ZERO, ZERO], - vec![ONE, ZERO], - vec![ZERO, ONE], - vec![TWO, ZERO], - vec![ZERO, TWO], - vec![TWO, TWO], - ], - (DataScenario::Ties, 3) => vec![ - vec![ZERO, ZERO, ZERO], - vec![ONE, ZERO, ZERO], - vec![ZERO, ONE, ZERO], - vec![ZERO, ZERO, ONE], - vec![TWO, ZERO, ZERO], - vec![ZERO, TWO, ZERO], - ], - _ => panic!("Unsupported dimension"), - } - } - } - - fn run_test_helper>(dim: usize, scenario: DataScenario, n: usize) { - let data = scenario.get(dim); - let query_point = &data[0]; - - let mut points: Vec<[FxdU16; 6]> = Vec::with_capacity(data.len()); - for row in &data { - let mut p = [ZERO; 6]; - for (i, &val) in row.iter().enumerate() { - if i < 6 { - p[i] = val; - } - } - points.push(p); - } - - let mut query_arr = [ZERO; 6]; - for (i, &val) in query_point.iter().enumerate() { - if i < 6 { - query_arr[i] = val; - } - } - - let expected: Vec<(usize, FxdU16)> = points - .iter() - .enumerate() - .map(|(i, &point)| { - let dist = D::dist(&query_arr, &point); - (i, dist) - }) - .collect(); - - let expected_distances: Vec = expected.iter().map(|(_, d)| *d).collect(); - - let mut tree: KdTree = KdTree::new(); - for (i, point) in points.iter().enumerate() { - tree.add(point, i as u32); - } - - let results = tree.nearest_n::(&query_arr, n); - - assert_eq!(results[0].item, 0, "First result should be the query point"); - assert_eq!( - results[0].distance, ZERO, - "First result distance should be 0.0" - ); - - for (i, result) in results.iter().enumerate() { - assert_eq!( - result.distance, expected_distances[i], - "Distance at index {} should be {}, but was {}", - i, expected_distances[i], result.distance - ); - } - - if matches!(scenario, DataScenario::NoTies) { - for (i, result) in results.iter().enumerate() { - let expected_id = expected[i].0; - assert_eq!( - result.item, expected_id as u32, - "Result {}: item ID mismatch. Expected {}, got {}", - i, expected_id, result.item - ); - } - } - } - - #[rstest] - fn test_nearest_n_chebyshev( - #[values(DataScenario::NoTies, DataScenario::Ties)] scenario: DataScenario, - #[values(1, 2, 3, 4, 5, 6)] n: usize, - #[values(1, 2, 3)] dim: usize, - ) { - run_test_helper::(dim, scenario, n); - } - - #[rstest] - fn test_nearest_n_squared_euclidean( - #[values(DataScenario::NoTies, DataScenario::Ties)] scenario: DataScenario, - #[values(1, 2, 3, 4, 5, 6)] n: usize, - #[values(1, 2, 3)] dim: usize, - ) { - run_test_helper::(dim, scenario, n); - } - - #[rstest] - fn test_nearest_n_manhattan( - #[values(DataScenario::NoTies, DataScenario::Ties)] scenario: DataScenario, - #[values(1, 2, 3, 4, 5, 6)] n: usize, - #[values(1, 2, 3)] dim: usize, - ) { - run_test_helper::(dim, scenario, n); - } } diff --git a/src/float/distance.rs b/src/float/distance.rs index f44245a..65b3193 100644 --- a/src/float/distance.rs +++ b/src/float/distance.rs @@ -45,8 +45,6 @@ impl DistanceMetric for Manhattan { fn accumulate(rd: A, delta: A) -> A { rd + delta } - - const IS_MAX_BASED: bool = false; } /// Returns the Chebyshev / L-infinity distance between two points. @@ -127,8 +125,6 @@ impl DistanceMetric for SquaredEuclidean { fn accumulate(rd: A, delta: A) -> A { rd + delta } - - const IS_MAX_BASED: bool = false; } /// Returns the Minkowski distance (power distance) between two points. From c8c9fe968214d87f5160685da115391e707c7db7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Carlson=20B=C3=BCth?= Date: Mon, 16 Mar 2026 07:26:19 +0100 Subject: [PATCH 09/10] refactor: rename methods - rename `*_with_condition` to `*_exclusive` across query methods - update method calls and tests --- .../generate_nearest_n_within_unsorted.rs | 4 ++-- src/common/generate_within.rs | 6 +++--- src/common/generate_within_unsorted.rs | 4 ++-- src/common/generate_within_unsorted_iter.rs | 4 ++-- src/fixed/query/within.rs | 2 +- src/float/distance.rs | 11 +++++------ src/hybrid/query/within.rs | 14 +++++++++----- src/hybrid/query/within_unsorted.rs | 19 +++++++++++++------ .../generate_immutable_best_n_within.rs | 4 ++-- .../generate_immutable_nearest_n_within.rs | 4 ++-- .../common/generate_immutable_within.rs | 6 +++--- .../generate_immutable_within_unsorted.rs | 6 +++--- 12 files changed, 47 insertions(+), 37 deletions(-) diff --git a/src/common/generate_nearest_n_within_unsorted.rs b/src/common/generate_nearest_n_within_unsorted.rs index cef7fb6..35955f8 100644 --- a/src/common/generate_nearest_n_within_unsorted.rs +++ b/src/common/generate_nearest_n_within_unsorted.rs @@ -8,12 +8,12 @@ macro_rules! generate_nearest_n_within_unsorted { where D: DistanceMetric, { - self.nearest_n_within_with_condition::(query, dist, max_items, sorted, true) + self.nearest_n_within_exclusive::(query, dist, max_items, sorted, true) } #[doc = concat!$comments] #[inline] - pub fn nearest_n_within_with_condition(&self, query: &[A; K], dist: A, max_items: std::num::NonZero, sorted: bool, inclusive: bool) -> Vec> + pub fn nearest_n_within_exclusive(&self, query: &[A; K], dist: A, max_items: std::num::NonZero, sorted: bool, inclusive: bool) -> Vec> where D: DistanceMetric, { diff --git a/src/common/generate_within.rs b/src/common/generate_within.rs index 642d02f..60a6654 100644 --- a/src/common/generate_within.rs +++ b/src/common/generate_within.rs @@ -8,16 +8,16 @@ macro_rules! generate_within { where D: DistanceMetric, { - self.within_with_condition::(query, dist, true) + self.within_exclusive::(query, dist, true) } #[doc = concat!$comments] #[inline] - pub fn within_with_condition(&self, query: &[A; K], dist: A, inclusive: bool) -> Vec> + pub fn within_exclusive(&self, query: &[A; K], dist: A, inclusive: bool) -> Vec> where D: DistanceMetric, { - let mut matching_items = self.within_unsorted_with_condition::(query, dist, inclusive); + let mut matching_items = self.within_unsorted_exclusive::(query, dist, inclusive); matching_items.sort(); matching_items } diff --git a/src/common/generate_within_unsorted.rs b/src/common/generate_within_unsorted.rs index 10e8803..93f3b74 100644 --- a/src/common/generate_within_unsorted.rs +++ b/src/common/generate_within_unsorted.rs @@ -8,12 +8,12 @@ macro_rules! generate_within_unsorted { where D: DistanceMetric, { - self.within_unsorted_with_condition::(query, dist, true) + self.within_unsorted_exclusive::(query, dist, true) } #[doc = concat!$comments] #[inline] - pub fn within_unsorted_with_condition(&self, query: &[A; K], dist: A, inclusive: bool) -> Vec> + pub fn within_unsorted_exclusive(&self, query: &[A; K], dist: A, inclusive: bool) -> Vec> where D: DistanceMetric, { diff --git a/src/common/generate_within_unsorted_iter.rs b/src/common/generate_within_unsorted_iter.rs index 8fc691d..c069a20 100644 --- a/src/common/generate_within_unsorted_iter.rs +++ b/src/common/generate_within_unsorted_iter.rs @@ -12,12 +12,12 @@ macro_rules! generate_within_unsorted_iter { where D: DistanceMetric, { - self.within_unsorted_iter_with_condition::(query, dist, true) + self.within_unsorted_iter_exclusive::(query, dist, true) } #[doc = concat!$comments] #[inline] - pub fn within_unsorted_iter_with_condition( + pub fn within_unsorted_iter_exclusive( &'a self, query: &'query [A; K], dist: A, diff --git a/src/fixed/query/within.rs b/src/fixed/query/within.rs index c00b13a..785ed03 100644 --- a/src/fixed/query/within.rs +++ b/src/fixed/query/within.rs @@ -166,7 +166,7 @@ mod tests { let query = [n(0.0), n(0.0)]; let radius = n(1.0); - let results = kdtree.within_with_condition::(&query, radius, inclusive); + let results = kdtree.within_exclusive::(&query, radius, inclusive); assert_eq!(results.len(), expected_len); if expected_len > 0 { assert_eq!(results[0].item, 1); diff --git a/src/float/distance.rs b/src/float/distance.rs index 65b3193..4bedd98 100644 --- a/src/float/distance.rs +++ b/src/float/distance.rs @@ -1416,12 +1416,12 @@ mod tests { for (i, point) in points.iter().enumerate() { tree.add(point, i as u64); } - tree.within_with_condition::(&query_arr, radius, inclusive) + tree.within_exclusive::(&query_arr, radius, inclusive) } TreeType::Immutable => { let tree: crate::immutable::float::kdtree::ImmutableKdTree = crate::immutable::float::kdtree::ImmutableKdTree::new_from_slice(&points); - tree.within_with_condition::(&query_arr, radius, inclusive) + tree.within_exclusive::(&query_arr, radius, inclusive) } }; @@ -1513,17 +1513,16 @@ mod tests { let query = [0.0, 0.0]; let radius = 1.0; - let results = - kdtree.within_with_condition::(&query, radius, inclusive); + let results = kdtree.within_exclusive::(&query, radius, inclusive); assert_eq!(results.len(), expected_len); if expected_len > 0 { assert_eq!(results[0].item, 1); assert_eq!(results[0].distance, 1.0); } - // Test nearest_n_within_with_condition + // Test nearest_n_within_exclusive let max_qty = std::num::NonZero::new(10).unwrap(); - let results = kdtree.nearest_n_within_with_condition::( + let results = kdtree.nearest_n_within_exclusive::( &query, radius, max_qty, true, inclusive, ); assert_eq!(results.len(), expected_len); diff --git a/src/hybrid/query/within.rs b/src/hybrid/query/within.rs index 449944d..4015cdc 100644 --- a/src/hybrid/query/within.rs +++ b/src/hybrid/query/within.rs @@ -39,11 +39,11 @@ where where F: Fn(&[A; K], &[A; K]) -> A, { - self.within_with_condition(query, dist, distance_fn, true) + self.within_exclusive(query, dist, distance_fn, true) } #[inline] - pub fn within_with_condition( + pub fn within_exclusive( &self, query: &[A; K], dist: A, @@ -146,7 +146,11 @@ where .for_each(|(idx, entry)| { let distance = distance_fn(query, entry); - if if inclusive { distance <= radius } else { distance < radius } { + if if inclusive { + distance <= radius + } else { + distance < radius + } { matching_items.push(Neighbour { distance, item: *leaf_node.content_items.get_unchecked(idx.az::()), @@ -161,8 +165,8 @@ where mod tests { use crate::float::distance::manhattan; use crate::float::kdtree::{Axis, KdTree}; - use rstest::rstest; use rand::Rng; + use rstest::rstest; use std::cmp::Ordering; type AX = f32; @@ -275,7 +279,7 @@ mod tests { let query = [0.0, 0.0]; let radius = 1.0; - let results = kdtree.within_with_condition( + let results = kdtree.within_exclusive( &query, radius, &|a, b| { diff --git a/src/hybrid/query/within_unsorted.rs b/src/hybrid/query/within_unsorted.rs index e95baf6..36a321b 100644 --- a/src/hybrid/query/within_unsorted.rs +++ b/src/hybrid/query/within_unsorted.rs @@ -41,11 +41,11 @@ where where F: Fn(&[A; K], &[A; K]) -> A, { - self.within_unsorted_with_condition(query, dist, distance_fn, true) + self.within_unsorted_exclusive(query, dist, distance_fn, true) } #[inline] - pub fn within_unsorted_with_condition( + pub fn within_unsorted_exclusive( &self, query: &[A; K], dist: A, @@ -148,7 +148,11 @@ where .for_each(|(idx, entry)| { let distance = distance_fn(query, entry); - if if inclusive { distance <= radius } else { distance < radius } { + if if inclusive { + distance <= radius + } else { + distance < radius + } { matching_items.push(Neighbour { distance, item: *leaf_node.content_items.get_unchecked(idx.az::()), @@ -163,8 +167,8 @@ where mod tests { use crate::float::distance::squared_euclidean; use crate::float::kdtree::{Axis, KdTree}; - use rstest::rstest; use rand::Rng; + use rstest::rstest; use std::cmp::Ordering; type AX = f32; @@ -269,7 +273,10 @@ mod tests { #[rstest] #[case(true, 1)] #[case(false, 0)] - fn test_within_unsorted_boundary_inclusiveness(#[case] inclusive: bool, #[case] expected_len: usize) { + fn test_within_unsorted_boundary_inclusiveness( + #[case] inclusive: bool, + #[case] expected_len: usize, + ) { let mut kdtree: KdTree = KdTree::new(); kdtree.add(&[1.0, 0.0], 1); kdtree.add(&[2.0, 0.0], 2); @@ -277,7 +284,7 @@ mod tests { let query = [0.0, 0.0]; let radius = 1.0; - let results = kdtree.within_unsorted_with_condition( + let results = kdtree.within_unsorted_exclusive( &query, radius, &|a, b| { diff --git a/src/immutable/common/generate_immutable_best_n_within.rs b/src/immutable/common/generate_immutable_best_n_within.rs index 81ec296..c227e0a 100644 --- a/src/immutable/common/generate_immutable_best_n_within.rs +++ b/src/immutable/common/generate_immutable_best_n_within.rs @@ -15,12 +15,12 @@ macro_rules! generate_immutable_best_n_within { usize: Cast, D: DistanceMetric, { - self.best_n_within_with_condition::(query, dist, max_qty, true) + self.best_n_within_exclusive::(query, dist, max_qty, true) } #[doc = concat!$comments] #[inline] - pub fn best_n_within_with_condition( + pub fn best_n_within_exclusive( &self, query: &[A; K], dist: A, diff --git a/src/immutable/common/generate_immutable_nearest_n_within.rs b/src/immutable/common/generate_immutable_nearest_n_within.rs index 877e5d1..70d829f 100644 --- a/src/immutable/common/generate_immutable_nearest_n_within.rs +++ b/src/immutable/common/generate_immutable_nearest_n_within.rs @@ -8,12 +8,12 @@ macro_rules! generate_immutable_nearest_n_within { where D: DistanceMetric, { - self.nearest_n_within_with_condition::(query, dist, max_items, sorted, true) + self.nearest_n_within_exclusive::(query, dist, max_items, sorted, true) } #[doc = concat!$comments] #[inline] - pub fn nearest_n_within_with_condition(&self, query: &[A; K], dist: A, max_items: NonZero, sorted: bool, inclusive: bool) -> Vec> + pub fn nearest_n_within_exclusive(&self, query: &[A; K], dist: A, max_items: NonZero, sorted: bool, inclusive: bool) -> Vec> where D: DistanceMetric, { diff --git a/src/immutable/common/generate_immutable_within.rs b/src/immutable/common/generate_immutable_within.rs index ee5de48..82e3551 100644 --- a/src/immutable/common/generate_immutable_within.rs +++ b/src/immutable/common/generate_immutable_within.rs @@ -9,17 +9,17 @@ macro_rules! generate_immutable_within { A: LeafSliceFloat + LeafSliceFloatChunk, D: DistanceMetric, usize: Cast, { - self.within_with_condition::(query, dist, true) + self.within_exclusive::(query, dist, true) } #[doc = concat!$comments] #[inline] - pub fn within_with_condition(&self, query: &[A; K], dist: A, inclusive: bool) -> Vec> + pub fn within_exclusive(&self, query: &[A; K], dist: A, inclusive: bool) -> Vec> where A: LeafSliceFloat + LeafSliceFloatChunk, D: DistanceMetric, usize: Cast, { - self.nearest_n_within_with_condition::(query, dist, std::num::NonZero::new(usize::MAX).unwrap(), true, inclusive) + self.nearest_n_within_exclusive::(query, dist, std::num::NonZero::new(usize::MAX).unwrap(), true, inclusive) } }; } diff --git a/src/immutable/common/generate_immutable_within_unsorted.rs b/src/immutable/common/generate_immutable_within_unsorted.rs index e8d7af2..aee1320 100644 --- a/src/immutable/common/generate_immutable_within_unsorted.rs +++ b/src/immutable/common/generate_immutable_within_unsorted.rs @@ -9,17 +9,17 @@ macro_rules! generate_immutable_within_unsorted { A: LeafSliceFloat + LeafSliceFloatChunk, D: DistanceMetric, usize: Cast, { - self.within_unsorted_with_condition::(query, dist, true) + self.within_unsorted_exclusive::(query, dist, true) } #[doc = concat!$comments] #[inline] - pub fn within_unsorted_with_condition(&self, query: &[A; K], dist: A, inclusive: bool) -> Vec> + pub fn within_unsorted_exclusive(&self, query: &[A; K], dist: A, inclusive: bool) -> Vec> where A: LeafSliceFloat + LeafSliceFloatChunk, D: DistanceMetric, usize: Cast, { - self.nearest_n_within_with_condition::(query, dist, std::num::NonZero::new(usize::MAX).unwrap(), false, inclusive) + self.nearest_n_within_exclusive::(query, dist, std::num::NonZero::new(usize::MAX).unwrap(), false, inclusive) } }; } From 3c19d7c55c872920eeaca32028ca0c0475a8eb58 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Carlson=20B=C3=BCth?= Date: Mon, 16 Mar 2026 08:27:28 +0100 Subject: [PATCH 10/10] docs: add doc comments to *_exclusive methods explaining inclusive param --- src/common/generate_nearest_n_within_unsorted.rs | 4 ++++ src/common/generate_within.rs | 4 ++++ src/common/generate_within_unsorted.rs | 4 ++++ src/common/generate_within_unsorted_iter.rs | 4 ++++ src/hybrid/query/within.rs | 4 ++++ src/hybrid/query/within_unsorted.rs | 4 ++++ src/immutable/common/generate_immutable_best_n_within.rs | 4 ++++ src/immutable/common/generate_immutable_nearest_n_within.rs | 4 ++++ src/immutable/common/generate_immutable_within.rs | 4 ++++ src/immutable/common/generate_immutable_within_unsorted.rs | 4 ++++ 10 files changed, 40 insertions(+) diff --git a/src/common/generate_nearest_n_within_unsorted.rs b/src/common/generate_nearest_n_within_unsorted.rs index 35955f8..e5779e7 100644 --- a/src/common/generate_nearest_n_within_unsorted.rs +++ b/src/common/generate_nearest_n_within_unsorted.rs @@ -17,6 +17,10 @@ macro_rules! generate_nearest_n_within_unsorted { where D: DistanceMetric, { + // Like [`nearest_n_within`] but allows controlling boundary inclusiveness. + // + // When `inclusive` is true, points at exactly the maximum distance are included. + // When false, only points strictly less than the maximum distance are included. if sorted || max_items < std::num::NonZero::new(usize::MAX).unwrap() { if max_items <= std::num::NonZero::new(MAX_VEC_RESULT_SIZE).unwrap() { self.nearest_n_within_stub::>>(query, dist, max_items.get(), sorted, inclusive) diff --git a/src/common/generate_within.rs b/src/common/generate_within.rs index 60a6654..a014a2d 100644 --- a/src/common/generate_within.rs +++ b/src/common/generate_within.rs @@ -17,6 +17,10 @@ macro_rules! generate_within { where D: DistanceMetric, { + // Like [`within`] but allows controlling boundary inclusiveness. + // + // When `inclusive` is true, points at exactly the maximum distance are included. + // When false, only points strictly less than the maximum distance are included. let mut matching_items = self.within_unsorted_exclusive::(query, dist, inclusive); matching_items.sort(); matching_items diff --git a/src/common/generate_within_unsorted.rs b/src/common/generate_within_unsorted.rs index 93f3b74..4912027 100644 --- a/src/common/generate_within_unsorted.rs +++ b/src/common/generate_within_unsorted.rs @@ -17,6 +17,10 @@ macro_rules! generate_within_unsorted { where D: DistanceMetric, { + // Like [`within_unsorted`] but allows controlling boundary inclusiveness. + // + // When `inclusive` is true, points at exactly the maximum distance are included. + // When false, only points strictly less than the maximum distance are included. let mut off = [A::zero(); K]; let mut matching_items = Vec::new(); let root_index: IDX = *transform(&self.root_index); diff --git a/src/common/generate_within_unsorted_iter.rs b/src/common/generate_within_unsorted_iter.rs index c069a20..1ae2245 100644 --- a/src/common/generate_within_unsorted_iter.rs +++ b/src/common/generate_within_unsorted_iter.rs @@ -26,6 +26,10 @@ macro_rules! generate_within_unsorted_iter { where D: DistanceMetric, { + // Like [`within_unsorted_iter`] but allows controlling boundary inclusiveness. + // + // When `inclusive` is true, points at exactly the maximum distance are included. + // When false, only points strictly less than the maximum distance are included. let mut off = [A::zero(); K]; let root_index: IDX = *transform(&self.root_index); diff --git a/src/hybrid/query/within.rs b/src/hybrid/query/within.rs index 4015cdc..9bcba71 100644 --- a/src/hybrid/query/within.rs +++ b/src/hybrid/query/within.rs @@ -42,6 +42,10 @@ where self.within_exclusive(query, dist, distance_fn, true) } + /// Like [`within`] but allows controlling boundary inclusiveness. + /// + /// When `inclusive` is true, points at exactly the maximum distance are included. + /// When false, only points strictly less than the maximum distance are included. #[inline] pub fn within_exclusive( &self, diff --git a/src/hybrid/query/within_unsorted.rs b/src/hybrid/query/within_unsorted.rs index 36a321b..dfaab2b 100644 --- a/src/hybrid/query/within_unsorted.rs +++ b/src/hybrid/query/within_unsorted.rs @@ -44,6 +44,10 @@ where self.within_unsorted_exclusive(query, dist, distance_fn, true) } + /// Like [`within_unsorted`] but allows controlling boundary inclusiveness. + /// + /// When `inclusive` is true, points at exactly the maximum distance are included. + /// When false, only points strictly less than the maximum distance are included. #[inline] pub fn within_unsorted_exclusive( &self, diff --git a/src/immutable/common/generate_immutable_best_n_within.rs b/src/immutable/common/generate_immutable_best_n_within.rs index c227e0a..e881d2b 100644 --- a/src/immutable/common/generate_immutable_best_n_within.rs +++ b/src/immutable/common/generate_immutable_best_n_within.rs @@ -32,6 +32,10 @@ macro_rules! generate_immutable_best_n_within { usize: Cast, D: DistanceMetric, { + // Like [`best_n_within`] but allows controlling boundary inclusiveness. + // + // When `inclusive` is true, points at exactly the maximum distance are included. + // When false, only points strictly less than the maximum distance are included. let mut off = [A::zero(); K]; let mut best_items: BinaryHeap> = BinaryHeap::with_capacity(max_qty.into()); diff --git a/src/immutable/common/generate_immutable_nearest_n_within.rs b/src/immutable/common/generate_immutable_nearest_n_within.rs index 70d829f..8c989e3 100644 --- a/src/immutable/common/generate_immutable_nearest_n_within.rs +++ b/src/immutable/common/generate_immutable_nearest_n_within.rs @@ -17,6 +17,10 @@ macro_rules! generate_immutable_nearest_n_within { where D: DistanceMetric, { + // Like [`nearest_n_within`] but allows controlling boundary inclusiveness. + // + // When `inclusive` is true, points at exactly the maximum distance are included. + // When false, only points strictly less than the maximum distance are included. let max_items = max_items.into(); if sorted && max_items < usize::MAX { diff --git a/src/immutable/common/generate_immutable_within.rs b/src/immutable/common/generate_immutable_within.rs index 82e3551..6ed5f96 100644 --- a/src/immutable/common/generate_immutable_within.rs +++ b/src/immutable/common/generate_immutable_within.rs @@ -19,6 +19,10 @@ macro_rules! generate_immutable_within { A: LeafSliceFloat + LeafSliceFloatChunk, D: DistanceMetric, usize: Cast, { + // Like [`within`] but allows controlling boundary inclusiveness. + // + // When `inclusive` is true, points at exactly the maximum distance are included. + // When false, only points strictly less than the maximum distance are included. self.nearest_n_within_exclusive::(query, dist, std::num::NonZero::new(usize::MAX).unwrap(), true, inclusive) } }; diff --git a/src/immutable/common/generate_immutable_within_unsorted.rs b/src/immutable/common/generate_immutable_within_unsorted.rs index aee1320..529b0de 100644 --- a/src/immutable/common/generate_immutable_within_unsorted.rs +++ b/src/immutable/common/generate_immutable_within_unsorted.rs @@ -19,6 +19,10 @@ macro_rules! generate_immutable_within_unsorted { A: LeafSliceFloat + LeafSliceFloatChunk, D: DistanceMetric, usize: Cast, { + // Like [`within_unsorted`] but allows controlling boundary inclusiveness. + // + // When `inclusive` is true, points at exactly the maximum distance are included. + // When false, only points strictly less than the maximum distance are included. self.nearest_n_within_exclusive::(query, dist, std::num::NonZero::new(usize::MAX).unwrap(), false, inclusive) } };