From 55143d2960e22af8028becc58bf1654e37458dc1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Carlson=20B=C3=BCth?= Date: Thu, 5 Feb 2026 15:25:02 +0100 Subject: [PATCH 01/17] test: add coverage for Manhattan and Squared Euclidean distance metrics --- src/float/distance.rs | 251 ++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 251 insertions(+) diff --git a/src/float/distance.rs b/src/float/distance.rs index 3784f946..dfd1e34e 100644 --- a/src/float/distance.rs +++ b/src/float/distance.rs @@ -73,3 +73,254 @@ impl DistanceMetric for SquaredEuclidean { (a - b) * (a - b) } } + +#[cfg(test)] +mod tests { + use super::*; + use rstest::rstest; + + mod manhattan_tests { + use super::*; + + #[rstest] + #[case([0.0f32, 0.0f32], [0.0f32, 0.0f32], 0.0f32)] // identical points + #[case([0.0f32, 0.0f32], [1.0f32, 0.0f32], 1.0f32)] // single axis difference + #[case([0.0f32, 0.0f32], [0.0f32, 1.0f32], 1.0f32)] // single axis difference (other axis) + #[case([0.0f32, 0.0f32], [1.0f32, 1.0f32], 2.0f32)] // diagonal + #[case([-1.0f32, -1.0f32], [1.0f32, 1.0f32], 4.0f32)] // negative to positive + #[case([1.5f32, 2.5f32], [3.5f32, 4.5f32], 4.0f32)] // fractional values + fn test_manhattan_distance_2d( + #[case] a: [f32; 2], + #[case] b: [f32; 2], + #[case] expected: f32, + ) { + assert_eq!(Manhattan::dist(&a, &b), expected); + } + + #[rstest] + #[case([0.0f64, 0.0f64, 0.0f64], [0.0f64, 0.0f64, 0.0f64], 0.0f64)] // identical points 3D + #[case([0.0f64, 0.0f64, 0.0f64], [1.0f64, 2.0f64, 3.0f64], 6.0f64)] // 3D diagonal + #[case([1.0f64, 2.0f64, 3.0f64], [4.0f64, 5.0f64, 6.0f64], 9.0f64)] // 3D offset + fn test_manhattan_distance_3d( + #[case] a: [f64; 3], + #[case] b: [f64; 3], + #[case] expected: f64, + ) { + assert_eq!(Manhattan::dist(&a, &b), expected); + } + + #[rstest] + #[case([0.0f32], [0.0f32], 0.0f32)] // 1D identical + #[case([0.0f32], [5.0f32], 5.0f32)] // 1D positive + #[case([5.0f32], [0.0f32], 5.0f32)] // 1D negative (reversed) + #[case([-3.0f32], [7.0f32], 10.0f32)] // 1D negative to positive + fn test_manhattan_distance_1d( + #[case] a: [f32; 1], + #[case] b: [f32; 1], + #[case] expected: f32, + ) { + assert_eq!(Manhattan::dist(&a, &b), expected); + } + + #[test] + fn test_manhattan_distance_4d() { + let a = [1.0f32, 2.0f32, 3.0f32, 4.0f32]; + let b = [5.0f32, 6.0f32, 7.0f32, 8.0f32]; + let expected = 16.0f32; // |5-1| + |6-2| + |7-3| + |8-4| = 4+4+4+4 = 16 + assert_eq!(Manhattan::dist(&a, &b), expected); + } + + #[test] + fn test_manhattan_distance_5d() { + let a = [0.0f64, 1.0f64, 2.0f64, 3.0f64, 4.0f64]; + let b = [5.0f64, 6.0f64, 7.0f64, 8.0f64, 9.0f64]; + let expected = 25.0f64; // |5-0| + |6-1| + |7-2| + |8-3| + |9-4| = 5+5+5+5+5 = 25 + assert_eq!(Manhattan::dist(&a, &b), expected); + } + + #[test] + fn test_manhattan_dist1() { + assert_eq!( + >::dist1(0.0f32, 0.0f32), + 0.0f32 + ); // zero difference + assert_eq!( + >::dist1(1.0f32, 0.0f32), + 1.0f32 + ); // positive difference + assert_eq!( + >::dist1(0.0f32, 1.0f32), + 1.0f32 + ); // negative difference (reversed) + assert_eq!( + >::dist1(-2.5f32, 3.5f32), + 6.0f32 + ); // fractional negative to positive + assert_eq!( + >::dist1(1000.0f32, -1000.0f32), + 2000.0f32 + ); // large values + } + + #[test] + fn test_manhattan_symmetry() { + let a = [1.0f64, 2.0f64, 3.0f64]; + let b = [4.0f64, 5.0f64, 6.0f64]; + + assert_eq!(Manhattan::dist(&a, &b), Manhattan::dist(&b, &a)); + } + + #[test] + fn test_manhattan_identity() { + let a = [1.0f32, 2.0f32, 3.0f32]; + assert_eq!(Manhattan::dist(&a, &a), 0.0f32); + } + + #[test] + fn test_manhattan_non_negativity() { + let a = [1.0f32, 2.0f32]; + let b = [3.0f32, 4.0f32]; + let distance = Manhattan::dist(&a, &b); + assert!(distance >= 0.0f32); + } + } + + mod squared_euclidean_tests { + use super::*; + + #[rstest] + #[case([0.0f32, 0.0f32], [0.0f32, 0.0f32], 0.0f32)] // identical points + #[case([0.0f32, 0.0f32], [1.0f32, 0.0f32], 1.0f32)] // single axis difference + #[case([0.0f32, 0.0f32], [0.0f32, 1.0f32], 1.0f32)] // single axis difference (other axis) + #[case([0.0f32, 0.0f32], [1.0f32, 1.0f32], 2.0f32)] // diagonal (1^2 + 1^2) + #[case([-1.0f32, -1.0f32], [1.0f32, 1.0f32], 8.0f32)] // negative to positive (2^2 + 2^2) + #[case([1.5f32, 2.5f32], [3.5f32, 4.5f32], 8.0f32)] // fractional values (2^2 + 2^2) + #[case([0.0f32, 0.0f32], [3.0f32, 4.0f32], 25.0f32)] // 3-4-5 triangle + fn test_squared_euclidean_distance_2d( + #[case] a: [f32; 2], + #[case] b: [f32; 2], + #[case] expected: f32, + ) { + assert_eq!(SquaredEuclidean::dist(&a, &b), expected); + } + + #[rstest] + #[case([0.0f64, 0.0f64, 0.0f64], [0.0f64, 0.0f64, 0.0f64], 0.0f64)] // identical points 3D + #[case([0.0f64, 0.0f64, 0.0f64], [1.0f64, 2.0f64, 2.0f64], 9.0f64)] // 3D (1^2 + 2^2 + 2^2) + #[case([1.0f64, 2.0f64, 3.0f64], [4.0f64, 5.0f64, 6.0f64], 27.0f64)] // 3D offset (3^2 + 3^2 + 3^2) + fn test_squared_euclidean_distance_3d( + #[case] a: [f64; 3], + #[case] b: [f64; 3], + #[case] expected: f64, + ) { + assert_eq!(SquaredEuclidean::dist(&a, &b), expected); + } + + #[rstest] + #[case([0.0f32], [0.0f32], 0.0f32)] // 1D identical + #[case([0.0f32], [5.0f32], 25.0f32)] // 1D positive (5^2) + #[case([5.0f32], [0.0f32], 25.0f32)] // 1D negative (reversed) + #[case([-3.0f32], [7.0f32], 100.0f32)] // 1D negative to positive (10^2) + fn test_squared_euclidean_distance_1d( + #[case] a: [f32; 1], + #[case] b: [f32; 1], + #[case] expected: f32, + ) { + assert_eq!(SquaredEuclidean::dist(&a, &b), expected); + } + + #[test] + fn test_squared_euclidean_dist1() { + assert_eq!( + >::dist1(0.0f32, 0.0f32), + 0.0f32 + ); // zero difference + assert_eq!( + >::dist1(1.0f32, 0.0f32), + 1.0f32 + ); // positive difference + assert_eq!( + >::dist1(0.0f32, 1.0f32), + 1.0f32 + ); // negative difference (reversed) + assert_eq!( + >::dist1(-2.5f32, 3.5f32), + 36.0f32 + ); // fractional negative to positive (6^2) + assert_eq!( + >::dist1(10.0f32, -10.0f32), + 400.0f32 + ); // large values (20^2) + } + + #[test] + fn test_squared_euclidean_symmetry() { + let a = [1.0f64, 2.0f64, 3.0f64]; + let b = [4.0f64, 5.0f64, 6.0f64]; + + assert_eq!( + SquaredEuclidean::dist(&a, &b), + SquaredEuclidean::dist(&b, &a) + ); + } + + #[test] + fn test_squared_euclidean_identity() { + let a = [1.0f32, 2.0f32, 3.0f32]; + assert_eq!(SquaredEuclidean::dist(&a, &a), 0.0f32); + } + + #[test] + fn test_squared_euclidean_non_negativity() { + let a = [1.0f32, 2.0f32]; + let b = [3.0f32, 4.0f32]; + let distance = SquaredEuclidean::dist(&a, &b); + assert!(distance >= 0.0f32); + } + + #[test] + fn test_squared_euclidean_triangle_inequality_property() { + // Test that squared Euclidean distance preserves ordering + let a = [0.0f32, 0.0f32]; + let b = [1.0f32, 0.0f32]; + let c = [1.0f32, 1.0f32]; + + let dist_ab = SquaredEuclidean::dist(&a, &b); + let dist_ac = SquaredEuclidean::dist(&a, &c); + let dist_bc = SquaredEuclidean::dist(&b, &c); + + // For these points: dist(a,b) = 1, dist(b,c) = 1, dist(a,c) = 2 + assert_eq!(dist_ab, 1.0f32); + assert_eq!(dist_bc, 1.0f32); + assert_eq!(dist_ac, 2.0f32); + } + } + + #[cfg(feature = "f16")] + mod f16_tests { + use super::*; + use half::f16; + + #[test] + fn test_manhattan_f16() { + let a = [f16::from_f32(0.0), f16::from_f32(0.0)]; + let b = [f16::from_f32(1.0), f16::from_f32(1.0)]; + + let result = Manhattan::dist(&a, &b); + let expected = f16::from_f32(2.0); + + assert_eq!(result, expected); + } + + #[test] + fn test_squared_euclidean_f16() { + let a = [f16::from_f32(0.0), f16::from_f32(0.0)]; + let b = [f16::from_f32(1.0), f16::from_f32(1.0)]; + + let result = SquaredEuclidean::dist(&a, &b); + let expected = f16::from_f32(2.0); + + assert_eq!(result, expected); + } + } +} From a7471bf5d8a4fe2ec58f0d662f33e06dc82e61db Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Carlson=20B=C3=BCth?= Date: Thu, 5 Feb 2026 22:51:30 +0100 Subject: [PATCH 02/17] feat: add Chebyshev distance metric and test coverage --- src/float/distance.rs | 539 ++++++++++++++++++++++++++++++++++++++++++ src/lib.rs | 1 + 2 files changed, 540 insertions(+) diff --git a/src/float/distance.rs b/src/float/distance.rs index dfd1e34e..0d3369f7 100644 --- a/src/float/distance.rs +++ b/src/float/distance.rs @@ -40,6 +40,40 @@ impl DistanceMetric for Manhattan { } } +/// Returns the Chebyshev / L-infinity distance between two points. +/// +/// Chebyshev distance is the maximum absolute difference along any axis. +/// Also known as chessboard distance or L-infinity norm. +/// +/// re-exported as `kiddo::Chebyshev` for convenience +/// +/// # Examples +/// +/// ```rust +/// use kiddo::traits::DistanceMetric; +/// use kiddo::Chebyshev; +/// +/// assert_eq!(0f32, Chebyshev::dist(&[0f32, 0f32], &[0f32, 0f32])); +/// assert_eq!(1f32, Chebyshev::dist(&[0f32, 0f32], &[1f32, 0f32])); +/// assert_eq!(1f32, Chebyshev::dist(&[0f32, 0f32], &[1f32, 1f32])); +/// ``` +pub struct Chebyshev {} + +impl DistanceMetric for Chebyshev { + #[inline] + fn dist(a: &[A; K], b: &[A; K]) -> A { + a.iter() + .zip(b.iter()) + .map(|(&a_val, &b_val)| (a_val - b_val).abs()) + .fold(A::zero(), |acc, val| acc.max(val)) + } + + #[inline] + fn dist1(a: A, b: A) -> A { + (a - b).abs() + } +} + /// Returns the squared euclidean distance between two points. /// /// Faster than Euclidean distance due to not needing a square root, but still @@ -296,6 +330,115 @@ mod tests { } } + mod chebyshev_tests { + use super::*; + + #[rstest] + #[case([0.0f32, 0.0f32], [0.0f32, 0.0f32], 0.0f32)] // identical points + #[case([0.0f32, 0.0f32], [1.0f32, 0.0f32], 1.0f32)] // single axis difference + #[case([0.0f32, 0.0f32], [0.0f32, 1.0f32], 1.0f32)] // single axis difference (other axis) + #[case([0.0f32, 0.0f32], [1.0f32, 1.0f32], 1.0f32)] // diagonal + #[case([-1.0f32, -1.0f32], [1.0f32, 1.0f32], 2.0f32)] // negative to positive + #[case([1.5f32, 2.5f32], [3.5f32, 4.5f32], 2.0f32)] // fractional values + #[case([0.0f32, 0.0f32], [2.0f32, 1.0f32], 2.0f32)] // max on first axis + #[case([0.0f32, 0.0f32], [1.0f32, 2.0f32], 2.0f32)] // max on second axis + fn test_chebyshev_distance_2d( + #[case] a: [f32; 2], + #[case] b: [f32; 2], + #[case] expected: f32, + ) { + assert_eq!(Chebyshev::dist(&a, &b), expected); + } + + #[rstest] + #[case([0.0f64, 0.0f64, 0.0f64], [0.0f64, 0.0f64, 0.0f64], 0.0f64)] // identical points 3D + #[case([0.0f64, 0.0f64, 0.0f64], [1.0f64, 2.0f64, 3.0f64], 3.0f64)] // 3D diagonal (max is 3) + #[case([1.0f64, 2.0f64, 3.0f64], [4.0f64, 5.0f64, 6.0f64], 3.0f64)] // 3D offset (max is 3) + fn test_chebyshev_distance_3d( + #[case] a: [f64; 3], + #[case] b: [f64; 3], + #[case] expected: f64, + ) { + assert_eq!(Chebyshev::dist(&a, &b), expected); + } + + #[rstest] + #[case([0.0f32], [0.0f32], 0.0f32)] // 1D identical + #[case([0.0f32], [5.0f32], 5.0f32)] // 1D positive + #[case([5.0f32], [0.0f32], 5.0f32)] // 1D negative (reversed) + #[case([-3.0f32], [7.0f32], 10.0f32)] // 1D negative to positive + fn test_chebyshev_distance_1d( + #[case] a: [f32; 1], + #[case] b: [f32; 1], + #[case] expected: f32, + ) { + assert_eq!(Chebyshev::dist(&a, &b), expected); + } + + #[test] + fn test_chebyshev_distance_4d() { + let a = [1.0f32, 2.0f32, 3.0f32, 4.0f32]; + let b = [5.0f32, 6.0f32, 7.0f32, 8.0f32]; + let expected = 4.0f32; // max(|5-1|, |6-2|, |7-3|, |8-4|) = max(4, 4, 4, 4) = 4 + assert_eq!(Chebyshev::dist(&a, &b), expected); + } + + #[test] + fn test_chebyshev_distance_5d() { + let a = [0.0f64, 1.0f64, 2.0f64, 3.0f64, 4.0f64]; + let b = [5.0f64, 6.0f64, 7.0f64, 8.0f64, 9.0f64]; + let expected = 5.0f64; // max(|5-0|, |6-1|, |7-2|, |8-3|, |9-4|) = max(5, 5, 5, 5, 5) = 5 + assert_eq!(Chebyshev::dist(&a, &b), expected); + } + + #[rstest] + #[case(0.0f32, 0.0f32, 0.0f32)] // zero difference + #[case(1.0f32, 0.0f32, 1.0f32)] // positive difference + #[case(0.0f32, 1.0f32, 1.0f32)] // negative difference (reversed) + #[case(-2.5f32, 3.5f32, 6.0f32)] // fractional negative to positive + #[case(1000.0f32, -1000.0f32, 2000.0f32)] // large values + fn test_chebyshev_dist1(#[case] a: f32, #[case] b: f32, #[case] expected: f32) { + assert_eq!(>::dist1(a, b), expected); + } + + #[test] + fn test_chebyshev_symmetry() { + let a = [1.0f64, 2.0f64, 3.0f64]; + let b = [4.0f64, 5.0f64, 6.0f64]; + assert_eq!(Chebyshev::dist(&a, &b), Chebyshev::dist(&b, &a)); + } + + #[test] + fn test_chebyshev_identity() { + let a = [1.0f32, 2.0f32, 3.0f32]; + assert_eq!(Chebyshev::dist(&a, &a), 0.0f32); + } + + #[test] + fn test_chebyshev_non_negativity() { + let a = [1.0f32, 2.0f32]; + let b = [3.0f32, 4.0f32]; + let distance = Chebyshev::dist(&a, &b); + assert!(distance >= 0.0f32); + } + + #[test] + fn test_chebyshev_max_property() { + // Test that Chebyshev correctly finds the maximum difference + let a = [0.0, 0.0]; + let b = [3.0, 1.0]; + + let result = Chebyshev::dist(&a, &b); + + // max(|0-3|, |0-1|) = max(3, 1) = 3 + assert_eq!(result, 3.0); + + // Verify it's not Manhattan (which would be 4) or Euclidean (sqrt(10)) + assert_ne!(result, 4.0); + assert_ne!(result, (10.0_f64).sqrt()); + } + } + #[cfg(feature = "f16")] mod f16_tests { use super::*; @@ -323,4 +466,400 @@ mod tests { assert_eq!(result, expected); } } + + mod integration_tests { + use super::*; + use crate::KdTree; + + #[test] + fn test_nearest_n_manhattan_distance() { + let mut kdtree: KdTree = KdTree::new(); + + // Add points in a simple pattern + let points = [ + ([0.0f32, 0.0f32], 0), // distance 0 from query point + ([1.0f32, 0.0f32], 1), // distance 1 from query point + ([0.0f32, 1.0f32], 2), // distance 1 from query point + ([2.0f32, 0.0f32], 3), // distance 2 from query point + ([0.0f32, 2.0f32], 4), // distance 2 from query point + ([3.0f32, 3.0f32], 5), // distance 6 from query point + ]; + + for (point, index) in points { + kdtree.add(&point, index); + } + + let query_point = [0.0f32, 0.0f32]; + let results = kdtree.nearest_n::(&query_point, 4); + + // Expected order: [0], [1], [2], [3], [4] + // Distances: 0, 1, 1, 2, 2 + // But we only ask for 4 nearest + assert_eq!(results.len(), 4); + + // First result should be the query point itself + assert_eq!(results[0].item, 0); + assert_eq!(results[0].distance, 0.0); + + // Next two should be the points at Manhattan distance 1 + assert_eq!(results[1].item, 1); + assert_eq!(results[1].distance, 1.0); + assert_eq!(results[2].item, 2); + assert_eq!(results[2].distance, 1.0); + + // Fourth should be one of the points at distance 2 + assert!(results[3].item == 3 || results[3].item == 4); + assert_eq!(results[3].distance, 2.0); + } + + #[test] + fn test_nearest_n_squared_euclidean_distance() { + let mut kdtree: KdTree = KdTree::new(); + + // Add points in a pattern where Euclidean and Manhattan differ + let points = [ + ([0.0, 0.0], 0), // distance 0 from query point + ([1.0, 0.0], 1), // Euclidean: 1, Manhattan: 1 + ([0.0, 1.0], 2), // Euclidean: 1, Manhattan: 1 + ([1.0, 1.0], 3), // Euclidean: 2, Manhattan: 2 + ([2.0, 0.0], 4), // Euclidean: 4, Manhattan: 2 + ([0.0, 2.0], 5), // Euclidean: 4, Manhattan: 2 + ([3.0, 4.0], 6), // Euclidean: 25, Manhattan: 7 + ]; + + for (point, index) in points { + kdtree.add(&point, index); + } + + let query_point = [0.0, 0.0]; + let results = kdtree.nearest_n::(&query_point, 5); + + assert_eq!(results.len(), 5); + + // First should be the query point itself + assert_eq!(results[0].item, 0); + assert_eq!(results[0].distance, 0.0); + + // Next two should be the points at Euclidean distance 1 + assert_eq!(results[1].item, 1); + assert_eq!(results[1].distance, 1.0); + assert_eq!(results[2].item, 2); + assert_eq!(results[2].distance, 1.0); + + // Next two should be the points at Euclidean distance 2 + assert_eq!(results[3].item, 3); + assert_eq!(results[3].distance, 2.0); + assert_eq!(results[4].item, 4); + assert_eq!(results[4].distance, 4.0); + + // Verify that points at squared Euclidean distance 4 are indeed farther + // than points at squared Euclidean distance 2 + assert!(results[4].distance > results[3].distance); + } + + #[test] + fn test_nearest_n_different_metrics_produce_different_orderings() { + let mut kdtree: KdTree = KdTree::new(); + + // Add points where Manhattan and Euclidean give different orderings + let points = [ + ([0.0, 0.0], 0), // origin + ([2.0, 1.0], 1), // Manhattan: 3, Euclidean^2: 5 + ([1.0, 2.0], 2), // Manhattan: 3, Euclidean^2: 5 + ([3.0, 0.0], 3), // Manhattan: 3, Euclidean^2: 9 + ([0.0, 3.0], 4), // Manhattan: 3, Euclidean^2: 9 + ]; + + for (point, index) in points { + kdtree.add(&point, index); + } + + let query_point = [0.0, 0.0]; + + let manhattan_results = kdtree.nearest_n::(&query_point, 3); + let euclidean_results = kdtree.nearest_n::(&query_point, 3); + + // Both should include the origin as first result + assert_eq!(manhattan_results[0].item, 0); + assert_eq!(euclidean_results[0].item, 0); + + // For Manhattan: points 1, 2, 3, 4 are all at distance 3 + // The ordering among ties depends on tree structure, but they should all have same distance + assert_eq!(manhattan_results[1].distance, 3.0); + assert_eq!(manhattan_results[2].distance, 3.0); + + // For Euclidean: points 1 and 2 are at distance sqrt(5) ≈ 2.236 (squared: 5) + // Points 3 and 4 are at distance 3 (squared: 9) + assert_eq!(euclidean_results[1].distance, 5.0); + assert_eq!(euclidean_results[2].distance, 5.0); + + // Verify that Euclidean ordering puts points 1 and 2 before 3 and 4 + let euclidean_items: Vec = euclidean_results + .iter() + .skip(1) // skip origin + .take(2) // take next 2 + .map(|nn| nn.item) + .collect(); + + assert!(euclidean_items.contains(&1) || euclidean_items.contains(&2)); + + // Calculate actual distances to verify our understanding + let p1 = [2.0, 1.0]; + let p2 = [1.0, 2.0]; + let p3 = [3.0, 0.0]; + + let manhattan_p1 = Manhattan::dist(&query_point, &p1); + let manhattan_p2 = Manhattan::dist(&query_point, &p2); + let manhattan_p3 = Manhattan::dist(&query_point, &p3); + + let euclidean_p1 = SquaredEuclidean::dist(&query_point, &p1); + let euclidean_p2 = SquaredEuclidean::dist(&query_point, &p2); + let euclidean_p3 = SquaredEuclidean::dist(&query_point, &p3); + + assert_eq!(manhattan_p1, 3.0); + assert_eq!(manhattan_p2, 3.0); + assert_eq!(manhattan_p3, 3.0); + + assert_eq!(euclidean_p1, 5.0); + assert_eq!(euclidean_p2, 5.0); + assert_eq!(euclidean_p3, 9.0); + } + + #[test] + fn test_nearest_n_3d_different_metrics() { + let mut kdtree: KdTree = KdTree::new(); + + // Add points in 3D space + let points = [ + ([1.0, 1.0, 1.0], 0), // origin + ([2.0, 1.0, 1.0], 1), // 1 unit away on x-axis + ([1.0, 2.0, 1.0], 2), // 1 unit away on y-axis + ([1.0, 1.0, 2.0], 3), // 1 unit away on z-axis + ([3.0, 1.0, 1.0], 4), // 2 units away on x-axis + ([0.0, 0.0, 0.0], 5), // sqrt(3) ≈ 1.732 from origin + ]; + + for (point, index) in points { + kdtree.add(&point, index); + } + + let query_point = [1.0, 1.0, 1.0]; + let results = kdtree.nearest_n::(&query_point, 4); + + assert_eq!(results.len(), 4); + + // First should be the query point itself + assert_eq!(results[0].item, 0); + assert_eq!(results[0].distance, 0.0); + + // Next three should be the points at Manhattan distance 1 + let nearby_items: Vec = results + .iter() + .skip(1) // skip origin + .take(3) // take next 3 + .map(|nn| nn.item) + .collect(); + + assert!(nearby_items.contains(&1)); + assert!(nearby_items.contains(&2)); + assert!(nearby_items.contains(&3)); + + // All nearby points should have distance 1 + for result in results.iter().skip(1).take(3) { + assert_eq!(result.distance, 1.0); + } + + // Point 4 should be farther (distance 2) and not in top 4 + let all_items: Vec = results.iter().map(|nn| nn.item).collect(); + assert!(!all_items.contains(&4)); + + // Point 5 has Manhattan distance 3, so definitely not in top 4 + assert!(!all_items.contains(&5)); + } + + #[test] + fn test_nearest_n_large_scale() { + let mut kdtree: KdTree = KdTree::new(); + + // Create a grid of points + let mut index = 0; + for x in 0i32..10 { + for y in 0i32..10 { + let point = [x as f32, y as f32]; + kdtree.add(&point, index); + index += 1; + } + } + + // Query from center of grid + let query_point = [5.0f32, 5.0f32]; + let results = kdtree.nearest_n::(&query_point, 10); + + assert_eq!(results.len(), 10); + + // First result should be the center point itself (index 55) + assert_eq!(results[0].item, 55); + assert_eq!(results[0].distance, 0.0); + + // Results should be ordered by increasing distance + for i in 1..10 { + assert!(results[i].distance >= results[i - 1].distance); + } + + // Verify distances make sense for a grid + // The nearest points should be at squared distances: 0, 1, 1, 1, 1, 2, 2, 4, 4, 5... + let expected_distances = [0.0f32, 1.0f32, 1.0f32, 1.0f32, 1.0f32, 2.0f32]; + + for (i, &expected_dist) in expected_distances.iter().enumerate() { + if i < results.len() { + assert_eq!(results[i].distance, expected_dist); + } + } + } + + #[test] + fn test_nearest_n_chebyshev_distance() { + let mut kdtree: KdTree = KdTree::new(); + + // Add points that show Chebyshev behavior + let points = [ + ([0.0f32, 0.0f32], 0), // distance 0 from query point + ([1.0f32, 0.0f32], 1), // Chebyshev: 1, Manhattan: 1, Euclidean^2: 1 + ([0.0f32, 1.0f32], 2), // Chebyshev: 1, Manhattan: 1, Euclidean^2: 1 + ([2.0f32, 0.0f32], 3), // Chebyshev: 2, Manhattan: 2, Euclidean^2: 4 + ([0.0f32, 2.0f32], 4), // Chebyshev: 2, Manhattan: 2, Euclidean^2: 4 + ([1.0f32, 1.0f32], 5), // Chebyshev: 1, Manhattan: 2, Euclidean^2: 2 + ]; + + for (point, index) in points { + kdtree.add(&point, index); + } + + let query_point = [0.0f32, 0.0f32]; + let results = kdtree.nearest_n::(&query_point, 5); + + // With Chebyshev, points at (1,0), (0,1), and (1,1) all have distance 1 + // Points at (2,0) and (0,2) have distance 2 + assert_eq!(results.len(), 5); + + // First should be the query point itself + assert_eq!(results[0].item, 0); + assert_eq!(results[0].distance, 0.0); + + // Next should all be at Chebyshev distance 1 + let nearby_items: Vec = results + .iter() + .skip(1) // skip origin + .take(4) // take next 4 + .filter(|r| (r.distance - 1.0).abs() < 0.001) // check for distance 1 (with some float tolerance) + .map(|nn| nn.item) + .collect(); + + // All of these should be in the results: 1, 2, 5 + assert!(nearby_items.contains(&1)); + assert!(nearby_items.contains(&2)); + assert!(nearby_items.contains(&5)); + } + + #[test] + fn test_within_chebyshev_distance() { + let mut kdtree: KdTree = KdTree::new(); + + // Add points with varying Chebyshev distances + let points = [ + ([0.0f32, 0.0f32], 0), // distance 0 + ([0.5f32, 0.5f32], 1), // Chebyshev: 0.5 + ([1.0f32, 0.0f32], 2), // Chebyshev: 1.0 + ([0.8f32, 0.9f32], 3), // Chebyshev: 0.9 + ([2.0f32, 0.0f32], 4), // Chebyshev: 2.0 + ([0.0f32, 2.0f32], 5), // Chebyshev: 2.0 + ([1.5f32, 1.5f32], 6), // Chebyshev: 1.5 + ]; + + for (point, index) in points { + kdtree.add(&point, index); + } + + let query_point = [0.0f32, 0.0f32]; + let radius = 1.0; // radius 1 (not squared for Chebyshev) + let mut results = kdtree.within::(&query_point, radius); + + // Sort by distance for easier verification + results.sort_by(|a, b| { + a.distance + .partial_cmp(&b.distance) + .unwrap_or(std::cmp::Ordering::Equal) + }); + + // NOTE: This test demonstrates a known limitation with Chebyshev distance: + // The k-d tree query logic uses dist1 for pruning, which is incorrect for Chebyshev. + // Point at [1.0, 0.0] (index 2) has Chebyshev distance exactly 1.0 but is NOT found. + + // Should include points with Chebyshev distance <= 1 + // These SHOULD be: 0, 1, 2, 3 (distances: 0, 0.5, 1.0, 0.9) + // But ACTUALLY FINDS: 0, 1, 3 (index 2 is missing due to dist1 pruning issue) + let found_indices: Vec = results.iter().map(|r| r.item).collect(); + + assert!(found_indices.contains(&0)); + assert!(found_indices.contains(&1)); + // This assert FAILS - demonstrates the bug + assert!(found_indices.contains(&2)); // currently not included, but should! + assert!(found_indices.contains(&3)); + + // Should NOT include points with Chebyshev distance > 1 + assert!(!found_indices.contains(&4)); + assert!(!found_indices.contains(&5)); + assert!(!found_indices.contains(&6)); + + // Verify distances + for result in results { + assert!(result.distance <= 1.0 || (result.distance - 1.0).abs() < 0.001); + } + } + + #[test] + fn test_chebyshev_vs_manhattan_ordering() { + let mut kdtree: KdTree = KdTree::new(); + + // Points where Chebyshev and Manhattan differ significantly + let points = [ + ([0.0f32, 0.0f32], 0), // origin + ([3.0f32, 1.0f32], 1), // Chebyshev: 3, Manhattan: 4 + ([1.0f32, 3.0f32], 2), // Chebyshev: 3, Manhattan: 4 + ([2.0f32, 2.0f32], 3), // Chebyshev: 2, Manhattan: 4 + ([4.0f32, 0.5f32], 4), // Chebyshev: 4, Manhattan: 4.5 + ]; + + for (point, index) in points { + kdtree.add(&point, index); + } + + let query_point = [0.0f32, 0.0f32]; + + let chebyshev_results = kdtree.nearest_n::(&query_point, 4); + let manhattan_results = kdtree.nearest_n::(&query_point, 4); + + // Both should include the origin first + assert_eq!(chebyshev_results[0].item, 0); + assert_eq!(manhattan_results[0].item, 0); + + // With Chebyshev, nearest should be point 3 (distance 2) + // With Manhattan, nearest should be points 1 and 2 (distance 4) + assert_eq!(chebyshev_results[1].item, 3); + assert_eq!(chebyshev_results[1].distance, 2.0); + + // With Manhattan, points 1 and 2 should come before point 3 (which is distance 4) + let manhattan_items: Vec = manhattan_results + .iter() + .skip(1) + .take(3) + .map(|r| r.item) + .collect(); + assert!(manhattan_items.contains(&1) || manhattan_items.contains(&2)); + + // Verify the distance calculations are correct + assert_eq!(chebyshev_results[1].distance, 2.0); // Chebyshev: max(|2-0|, |2-0|) = 2 + assert_eq!(manhattan_results[1].distance, 4.0); // Manhattan: |3-0| + |1-0| = 4 + } + } } diff --git a/src/lib.rs b/src/lib.rs index 557c9a2b..102bd044 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -135,6 +135,7 @@ pub type ImmutableKdTree = immutable::float::kdtree::ImmutableKdTree; pub use best_neighbour::BestNeighbour; +pub use float::distance::Chebyshev; pub use float::distance::Manhattan; pub use float::distance::SquaredEuclidean; pub use nearest_neighbour::NearestNeighbour; From a5cf629e7e725afa31ff41312a6f86069f0b0258 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Carlson=20B=C3=BCth?= Date: Sat, 7 Feb 2026 16:36:11 +0100 Subject: [PATCH 03/17] test: add integration tests for Chebyshev, Manhattan, and Squared Euclidean distance metrics --- src/float/distance.rs | 253 +++++++++++++++++++++++++++++++++++++++++- 1 file changed, 251 insertions(+), 2 deletions(-) diff --git a/src/float/distance.rs b/src/float/distance.rs index 0d3369f7..7cd335f2 100644 --- a/src/float/distance.rs +++ b/src/float/distance.rs @@ -469,7 +469,256 @@ mod tests { mod integration_tests { use super::*; - use crate::KdTree; + use crate::{ImmutableKdTree, KdTree}; + use rstest::rstest; + + #[derive(Debug, Clone, Copy)] + enum DataScenario { + NoTies, + Ties, + } + + #[derive(Debug, Clone, Copy)] + enum TreeType { + Mutable, + Immutable, + } + + impl DataScenario { + /// Get data scenario + /// + /// Data is ordered to appear in increasing distance to the 0-th point. + /// Predefined data has input dimension (`dim`) and either + /// with `DataScenario::NoTies` or `DataScenario::Ties`. + /// + /// # Parameters + /// - `dim`: The dimensionality of the data to retrieve. + /// Must be a value between 1 and 4 (inclusive). + /// + /// # Returns + /// - `Vec>`: A 2D vector where each inner vector represents a data point. + fn get(&self, dim: usize) -> Vec> { + match (self, dim) { + (DataScenario::NoTies, 1) => vec![ + vec![1.0], + vec![2.0], + vec![4.0], + vec![7.0], + vec![-9.0], + vec![16.0], + ], + (DataScenario::NoTies, 2) => vec![ + vec![0.0, 0.0], + vec![1.1, 0.1], + vec![2.3, 0.4], + vec![3.6, 0.9], + vec![5.0, 1.6], + vec![6.5, 2.5], + ], + (DataScenario::NoTies, 3) => vec![ + vec![0.0, 0.0, 0.0], + vec![1.1, 0.1, 0.01], + vec![2.3, 0.4, 0.08], + vec![-3.6, -0.9, -0.27], + vec![5.0, 1.6, 0.64], + vec![6.5, 2.5, 1.25], + ], + (DataScenario::NoTies, 4) => vec![ + vec![0.0, 0.0, 0.0, 1000.0], + vec![1.1, 0.1, 0.01, 1000.001], + vec![2.3, 0.4, 0.08, 1000.008], + vec![3.6, 0.9, 0.27, 1000.027], + vec![5.0, 1.6, 0.64, 1000.256], + vec![6.5, 2.5, 1.25, 1000.625], + ], + (DataScenario::Ties, 1) => vec![ + vec![0.0], + vec![1.0], + vec![1.0], + vec![2.0], + vec![2.0], + vec![3.0], + ], + (DataScenario::Ties, 2) => vec![ + vec![0.0, 0.0], + vec![1.0, 0.0], + vec![0.0, 1.0], + vec![-1.0, 0.0], + vec![0.0, -1.0], + vec![1.0, 1.0], + ], + (DataScenario::Ties, 3) => vec![ + vec![0.0, 0.0, 0.0], + vec![1.0, 0.0, 0.0], + vec![0.0, 1.0, 0.0], + vec![0.0, 0.0, 1.0], + vec![-1.0, 0.0, 0.0], + vec![0.0, -1.0, 0.0], + ], + (DataScenario::Ties, 4) => vec![ + vec![0.0, 0.0, 0.0, 0.0], + vec![1.0, 0.0, 0.0, 0.0], + vec![0.0, 1.0, 0.0, 0.0], + vec![0.0, 0.0, 1.0, 0.0], + vec![0.0, 0.0, 0.0, 1.0], + vec![-1.0, 0.0, 0.0, 0.0], + ], + _ => panic!("Unsupported dimension {} for scenario {:?}", dim, self), + } + } + } + + /// Helper function to test nearest_n queries for `D: DistanceMetric` + /// + /// Tests KD-tree Chebyshev distance queries across different tree types and + /// data scenarios. This simplifies testing across different combinations. + /// + /// # What this function does + /// 1. Get test data points based on a scenario (NoTies/Ties) and dimensionality + /// 2. Builds either MutableKdTree (incremental) or ImmutableKdTree (bulk construction) + /// 3. Performs nearest_n query with Chebyshev distance from point 0 + /// 4. Compares results against Brute-force distances, + /// calculated from `>::dist`. + /// + /// # Choices + /// - Fixed-size array `[f64; 6]`. For `dim<6` a subspace/padding is used for practicality + /// + /// # Assertions + /// - Point 0 is always the query point (distance 0, index 0 expected first result) + /// - NoTies scenario: checks distances and item IDs for points with unique distances + /// - Ties scenario: checks distances (order among ties is non-deterministic) + fn run_test_helper>( + dim: usize, + tree_type: TreeType, + scenario: DataScenario, + n: usize, + ) { + let data = scenario.get(dim); + let query_point = &data[0]; + + let mut points: Vec<[f64; 6]> = Vec::with_capacity(data.len()); + for row in &data { + let mut p = [0.0; 6]; + for (i, &val) in row.iter().enumerate() { + p[i] = val; + } + points.push(p); + } + + let mut query_arr = [0.0; 6]; + for (i, &val) in query_point.iter().enumerate() { + if i < 6 { + query_arr[i] = val; + } + } + + // Calculate ground truth with brute-force approach + let expected: Vec<(usize, f64)> = points + .iter() + .enumerate() + .map(|(i, &point)| { + let dist = D::dist(&query_arr, &point); + (i, dist) + }) + .collect(); + + let expected_distances: Vec = expected.iter().map(|(_, d)| *d).collect(); + + println!( + "Query: {:?}, TreeType: {:?}, Scenario: {:?}, dim={}, n={}", + query_point, tree_type, scenario, dim, n + ); + + // Query based on tree type + let results = match tree_type { + TreeType::Mutable => { + let mut tree: KdTree = KdTree::new(); + for (i, point) in points.iter().enumerate() { + tree.add(point, i as u64); + } + tree.nearest_n::(&query_arr, n) + } + TreeType::Immutable => { + let tree: ImmutableKdTree = ImmutableKdTree::new_from_slice(&points); + tree.nearest_n::(&query_arr, std::num::NonZero::new(n).unwrap()) + } + }; + + println!("Results (len: {}):", results.len()); + + assert_eq!(results[0].item, 0, "First result should be the query point"); + assert_eq!( + results[0].distance, 0.0, + "First result distance should be 0.0" + ); + + for (i, result) in results.iter().enumerate() { + assert_eq!( + result.distance, expected_distances[i], + "Distance at index {} should be {}, but was {}", + i, expected_distances[i], result.distance + ); + } + + if matches!(scenario, DataScenario::NoTies) { + for (i, result) in results.iter().enumerate() { + let expected_id = expected[i].0; + assert_eq!( + result.item, expected_id as u64, + "Result {}: item ID mismatch. Expected {}, got {}", + i, expected_id, result.item + ); + } + } + } + + /// Chebyshev distance nearest-neighbor query tests. + /// + /// Test matrix covering all combinations of mutable/immutable trees, + /// data scenarios (with/out ties), dimensions, and neighbor query counts. + /// + /// Currently passing tests: + /// - All MutableKdTree tests pass + /// - ImmutableKdTree with NoTies: + /// - Pass for when just querying the root n=1 or dim=1 + /// - ImmutableKdTree with Ties: Several pass (one edge case failure for n=6, dim=2) + /// + /// Currently failing tests (16 of 96): + /// - ImmutableKdTree + NoTies: fails for dim>=2 AND n>=2 (15 failures) + /// - ImmutableKdTree + Ties: 1 failure (n=6, dim=2) + /// + /// TODO: Hypothesis: Problem might be `rd_update` in `src/float/kdtree.rs` + /// using `+` aggregation (sensible for sum-based metrics like L1/L2). + /// L_inf would need `max` aggregation. + #[rstest] + fn test_nearest_n_chebyshev( + #[values(TreeType::Mutable, TreeType::Immutable)] tree_type: TreeType, + #[values(DataScenario::NoTies, DataScenario::Ties)] scenario: DataScenario, + #[values(1, 2, 3, 4, 5, 6)] n: usize, + #[values(1, 2, 3, 4)] dim: usize, + ) { + run_test_helper::(dim, tree_type, scenario, n); + } + + #[rstest] + fn test_nearest_n_squared_euclidean( + #[values(TreeType::Mutable, TreeType::Immutable)] tree_type: TreeType, + #[values(DataScenario::NoTies, DataScenario::Ties)] scenario: DataScenario, + #[values(1, 2, 3, 4, 5, 6)] n: usize, + #[values(1, 2, 3, 4)] dim: usize, + ) { + run_test_helper::(dim, tree_type, scenario, n); + } + + #[rstest] + fn test_nearest_n_manhattan( + #[values(TreeType::Mutable, TreeType::Immutable)] tree_type: TreeType, + #[values(DataScenario::NoTies, DataScenario::Ties)] scenario: DataScenario, + #[values(1, 2, 3, 4, 5, 6)] n: usize, + #[values(1, 2, 3, 4)] dim: usize, + ) { + run_test_helper::(dim, tree_type, scenario, n); + } #[test] fn test_nearest_n_manhattan_distance() { @@ -803,7 +1052,7 @@ mod tests { assert!(found_indices.contains(&0)); assert!(found_indices.contains(&1)); // This assert FAILS - demonstrates the bug - assert!(found_indices.contains(&2)); // currently not included, but should! + assert!(found_indices.contains(&2)); // currently not included, but should! assert!(found_indices.contains(&3)); // Should NOT include points with Chebyshev distance > 1 From 898da96733cc6657e9c80cc207e957fbb80330ec Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Carlson=20B=C3=BCth?= Date: Sat, 7 Feb 2026 17:36:08 +0100 Subject: [PATCH 04/17] refactor: inclusive radius matching and leaf note remainder loops --- .../generate_nearest_n_within_unsorted.rs | 2 +- src/common/generate_within_unsorted.rs | 2 +- src/common/generate_within_unsorted_iter.rs | 2 +- src/fixed/query/within.rs | 2 +- src/fixed/query/within_unsorted.rs | 2 +- src/fixed/query/within_unsorted_iter.rs | 2 +- src/float/query/nearest_n_within.rs | 2 +- src/float/query/within_unsorted.rs | 2 +- src/float_leaf_slice/leaf_slice.rs | 72 +++++++++++-------- src/hybrid/query/within.rs | 2 +- src/hybrid/query/within_unsorted.rs | 2 +- src/immutable/float/query/nearest_n_within.rs | 2 +- src/immutable/float/query/within.rs | 2 +- src/immutable/float/query/within_unsorted.rs | 2 +- .../float/query/within_unsorted_iter.rs | 2 +- 15 files changed, 58 insertions(+), 42 deletions(-) diff --git a/src/common/generate_nearest_n_within_unsorted.rs b/src/common/generate_nearest_n_within_unsorted.rs index c3654cd8..ec7daca2 100644 --- a/src/common/generate_nearest_n_within_unsorted.rs +++ b/src/common/generate_nearest_n_within_unsorted.rs @@ -116,7 +116,7 @@ macro_rules! generate_nearest_n_within_unsorted { .for_each(|(idx, entry)| { let distance = D::dist(query, transform(entry)); - if distance < radius { + if distance <= radius { let item = unsafe { leaf_node.content_items.get_unchecked(idx) }; let item = *transform(item); diff --git a/src/common/generate_within_unsorted.rs b/src/common/generate_within_unsorted.rs index bc9442cc..33026412 100644 --- a/src/common/generate_within_unsorted.rs +++ b/src/common/generate_within_unsorted.rs @@ -98,7 +98,7 @@ macro_rules! generate_within_unsorted { .for_each(|(idx, entry)| { let distance = D::dist(query, transform(entry)); - if distance < radius { + if distance <= radius { let item = unsafe { leaf_node.content_items.get_unchecked(idx) }; let item = *transform(item); diff --git a/src/common/generate_within_unsorted_iter.rs b/src/common/generate_within_unsorted_iter.rs index f5c94ea9..a6d0123b 100644 --- a/src/common/generate_within_unsorted_iter.rs +++ b/src/common/generate_within_unsorted_iter.rs @@ -108,7 +108,7 @@ macro_rules! generate_within_unsorted_iter { .for_each(|(idx, entry)| { let distance = D::dist(query, transform(entry)); - if distance < radius { + if distance <= radius { let item = unsafe { leaf_node.content_items.get_unchecked(idx) }; let item = *transform(item); diff --git a/src/fixed/query/within.rs b/src/fixed/query/within.rs index 01848e88..094211af 100644 --- a/src/fixed/query/within.rs +++ b/src/fixed/query/within.rs @@ -163,7 +163,7 @@ mod tests { for &(p, item) in content { let dist = Manhattan::dist(query_point, &p); - if dist < radius { + if dist <= radius { matching_items.push((dist, item)); } } diff --git a/src/fixed/query/within_unsorted.rs b/src/fixed/query/within_unsorted.rs index 3075cc15..b5a1dcb8 100644 --- a/src/fixed/query/within_unsorted.rs +++ b/src/fixed/query/within_unsorted.rs @@ -167,7 +167,7 @@ mod tests { for &(p, item) in content { let dist = Manhattan::dist(query_point, &p); - if dist < radius { + if dist <= radius { matching_items.push((dist, item)); } } diff --git a/src/fixed/query/within_unsorted_iter.rs b/src/fixed/query/within_unsorted_iter.rs index 72cf5d4b..de22be63 100644 --- a/src/fixed/query/within_unsorted_iter.rs +++ b/src/fixed/query/within_unsorted_iter.rs @@ -178,7 +178,7 @@ mod tests { for &(p, item) in content { let dist = Manhattan::dist(query_point, &p); - if dist < radius { + if dist <= radius { matching_items.push((dist, item)); } } diff --git a/src/float/query/nearest_n_within.rs b/src/float/query/nearest_n_within.rs index 6a2c70c9..6fc78994 100644 --- a/src/float/query/nearest_n_within.rs +++ b/src/float/query/nearest_n_within.rs @@ -353,7 +353,7 @@ mod tests { for &(p, item) in content { let dist = SquaredEuclidean::dist(query_point, &p); - if dist < radius { + if dist <= radius { matching_items.push((dist, item)); } } diff --git a/src/float/query/within_unsorted.rs b/src/float/query/within_unsorted.rs index 5d0bab97..226270b9 100644 --- a/src/float/query/within_unsorted.rs +++ b/src/float/query/within_unsorted.rs @@ -209,7 +209,7 @@ mod tests { for &(p, item) in content { let dist = SquaredEuclidean::dist(query_point, &p); - if dist < radius { + if dist <= radius { matching_items.push((dist, item)); } } diff --git a/src/float_leaf_slice/leaf_slice.rs b/src/float_leaf_slice/leaf_slice.rs index 7c104c98..9bc8d638 100644 --- a/src/float_leaf_slice/leaf_slice.rs +++ b/src/float_leaf_slice/leaf_slice.rs @@ -235,10 +235,14 @@ where #[allow(clippy::needless_range_loop)] for idx in 0..remainder_items.len() { - let mut distance = A::zero(); - (0..K).step_by(1).for_each(|dim| { - distance += D::dist1(remainder_points[dim][idx], query[dim]); - }); + let distance = D::dist( + &(0..K) + .map(|dim| remainder_points[dim][idx]) + .collect::>() + .try_into() + .unwrap(), + query, + ); if distance < radius { results.add(NearestNeighbour { @@ -269,10 +273,14 @@ where #[allow(clippy::needless_range_loop)] for idx in 0..remainder_items.len() { - let mut distance = A::zero(); - (0..K).step_by(1).for_each(|dim| { - distance += D::dist1(remainder_points[dim][idx], query[dim]); - }); + let distance = D::dist( + &(0..K) + .map(|dim| remainder_points[dim][idx]) + .collect::>() + .try_into() + .unwrap(), + query, + ); if distance < radius { let item = remainder_items[idx]; @@ -365,16 +373,20 @@ where Self: Sized, { // AVX512: 4 loops of 32 iterations, each 4x unrolled, 5 instructions per pre-unrolled iteration - let mut acc = [0f64; C]; - (0..K).step_by(1).for_each(|dim| { - let qd = [query[dim]; C]; - - (0..C).step_by(1).for_each(|idx| { - acc[idx] += D::dist1(chunk[dim][idx], qd[idx]); - }); - }); - - acc + // TODO: For each point in chunk, compute full distance using D::dist + // This is prob slower than SIMD, but works for all metrics + (0..C) + .map(|idx| { + let point: [Self; K] = (0..K) + .map(|dim| chunk[dim][idx]) + .collect::>() + .try_into() + .unwrap(); + D::dist(&point, query) + }) + .collect::>() + .try_into() + .unwrap() } } @@ -451,16 +463,20 @@ where Self: Sized, { // AVX512: 4 loops of 32 iterations, each 4x unrolled, 5 instructions per pre-unrolled iteration - let mut acc = [0f32; C]; - (0..K).step_by(1).for_each(|dim| { - let qd = [query[dim]; C]; - - (0..C).step_by(1).for_each(|idx| { - acc[idx] += D::dist1(chunk[dim][idx], qd[idx]); - }); - }); - - acc + // TODO: For each point in chunk, compute full distance using D::dist + // Same as above, optimisation to be recovered again + (0..C) + .map(|idx| { + let point: [Self; K] = (0..K) + .map(|dim| chunk[dim][idx]) + .collect::>() + .try_into() + .unwrap(); + D::dist(&point, query) + }) + .collect::>() + .try_into() + .unwrap() } } diff --git a/src/hybrid/query/within.rs b/src/hybrid/query/within.rs index be5b956b..76505046 100644 --- a/src/hybrid/query/within.rs +++ b/src/hybrid/query/within.rs @@ -254,7 +254,7 @@ mod tests { for &(p, item) in content { let dist = manhattan(query_point, &p); - if dist < radius { + if dist <= radius { matching_items.push((dist, item)); } } diff --git a/src/hybrid/query/within_unsorted.rs b/src/hybrid/query/within_unsorted.rs index 683c58fe..fb93473f 100644 --- a/src/hybrid/query/within_unsorted.rs +++ b/src/hybrid/query/within_unsorted.rs @@ -256,7 +256,7 @@ mod tests { for &(p, item) in content { let dist = squared_euclidean(query_point, &p); - if dist < radius { + if dist <= radius { matching_items.push((dist, item)); } } diff --git a/src/immutable/float/query/nearest_n_within.rs b/src/immutable/float/query/nearest_n_within.rs index 7b3cfece..d8b1603a 100644 --- a/src/immutable/float/query/nearest_n_within.rs +++ b/src/immutable/float/query/nearest_n_within.rs @@ -231,7 +231,7 @@ mod tests { for (idx, p) in content.iter().enumerate() { let dist = SquaredEuclidean::dist(query_point, p); - if dist < radius { + if dist <= radius { matching_items.push((dist, idx as u32)); } } diff --git a/src/immutable/float/query/within.rs b/src/immutable/float/query/within.rs index a04c9f7a..535d56ee 100644 --- a/src/immutable/float/query/within.rs +++ b/src/immutable/float/query/within.rs @@ -206,7 +206,7 @@ mod tests { for (idx, p) in content.iter().enumerate() { let dist = Manhattan::dist(query_point, p); - if dist < radius { + if dist <= radius { matching_items.push((dist, idx as u32)); } } diff --git a/src/immutable/float/query/within_unsorted.rs b/src/immutable/float/query/within_unsorted.rs index 28310c5a..67b4d6c2 100644 --- a/src/immutable/float/query/within_unsorted.rs +++ b/src/immutable/float/query/within_unsorted.rs @@ -204,7 +204,7 @@ mod tests { for (idx, p) in content.iter().enumerate() { let dist = SquaredEuclidean::dist(query_point, p); - if dist < radius { + if dist <= radius { matching_items.push((dist, idx as u32)); } } diff --git a/src/immutable/float/query/within_unsorted_iter.rs b/src/immutable/float/query/within_unsorted_iter.rs index 3f3cceba..dc8aa5bc 100644 --- a/src/immutable/float/query/within_unsorted_iter.rs +++ b/src/immutable/float/query/within_unsorted_iter.rs @@ -201,7 +201,7 @@ mod tests { for (idx, p) in content.iter().enumerate() { let dist = SquaredEuclidean::dist(query_point, p); - if dist < radius { + if dist <= radius { matching_items.push((dist, idx as u32)); } } From 728f3eef75f596ec31326156aac1fc2b780de85a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Carlson=20B=C3=BCth?= Date: Sat, 7 Feb 2026 17:38:32 +0100 Subject: [PATCH 05/17] fix: fix over-pruning for L_inf --- src/float/distance.rs | 4 ---- src/float/kdtree.rs | 2 +- 2 files changed, 1 insertion(+), 5 deletions(-) diff --git a/src/float/distance.rs b/src/float/distance.rs index 7cd335f2..6e9b5ff3 100644 --- a/src/float/distance.rs +++ b/src/float/distance.rs @@ -686,10 +686,6 @@ mod tests { /// Currently failing tests (16 of 96): /// - ImmutableKdTree + NoTies: fails for dim>=2 AND n>=2 (15 failures) /// - ImmutableKdTree + Ties: 1 failure (n=6, dim=2) - /// - /// TODO: Hypothesis: Problem might be `rd_update` in `src/float/kdtree.rs` - /// using `+` aggregation (sensible for sum-based metrics like L1/L2). - /// L_inf would need `max` aggregation. #[rstest] fn test_nearest_n_chebyshev( #[values(TreeType::Mutable, TreeType::Immutable)] tree_type: TreeType, diff --git a/src/float/kdtree.rs b/src/float/kdtree.rs index 4757cb3c..35f5d682 100644 --- a/src/float/kdtree.rs +++ b/src/float/kdtree.rs @@ -73,7 +73,7 @@ impl #[inline] fn rd_update(rd: Self, delta: Self) -> Self { - rd + delta + rd.max(delta) } } From 78cf3738f0d084815ace229d8ae9b7e6750704d7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Carlson=20B=C3=BCth?= Date: Sat, 7 Feb 2026 19:16:36 +0100 Subject: [PATCH 06/17] feat: add `D::accumulate` and `D::IS_MAX_BASED` MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - deprecate `rd_update` with `D::accumulate` for consistent sum-based and max-based metrics - conditional logic for SIMD (L1/L2) and general L∞ - differentiate distance accumulation behaviour --- src/common/generate_best_n_within.rs | 2 +- src/common/generate_nearest_n.rs | 2 +- .../generate_nearest_n_within_unsorted.rs | 2 +- src/common/generate_nearest_one.rs | 2 +- src/common/generate_within_unsorted.rs | 2 +- src/common/generate_within_unsorted_iter.rs | 2 +- src/fixed/distance.rs | 14 ++ src/float/distance.rs | 21 +++ src/float/kdtree.rs | 4 +- src/float_leaf_slice/leaf_slice.rs | 143 ++++++++++++------ src/traits.rs | 23 +++ 11 files changed, 160 insertions(+), 57 deletions(-) diff --git a/src/common/generate_best_n_within.rs b/src/common/generate_best_n_within.rs index 8a116688..21c2e576 100644 --- a/src/common/generate_best_n_within.rs +++ b/src/common/generate_best_n_within.rs @@ -76,7 +76,7 @@ macro_rules! generate_best_n_within { rd, ); - rd = Axis::rd_update(rd, D::dist1(new_off, old_off)); + rd = D::accumulate(rd, D::dist1(new_off, old_off)); if rd <= radius { off[split_dim] = new_off; diff --git a/src/common/generate_nearest_n.rs b/src/common/generate_nearest_n.rs index 828aec15..adccc0bf 100644 --- a/src/common/generate_nearest_n.rs +++ b/src/common/generate_nearest_n.rs @@ -65,7 +65,7 @@ macro_rules! generate_nearest_n { rd, ); - rd = Axis::rd_update(rd, D::dist1(new_off, old_off)); + rd = D::accumulate(rd, D::dist1(new_off, old_off)); if Self::dist_belongs_in_heap(rd, results) { off[split_dim] = new_off; diff --git a/src/common/generate_nearest_n_within_unsorted.rs b/src/common/generate_nearest_n_within_unsorted.rs index ec7daca2..09d56c4e 100644 --- a/src/common/generate_nearest_n_within_unsorted.rs +++ b/src/common/generate_nearest_n_within_unsorted.rs @@ -86,7 +86,7 @@ macro_rules! generate_nearest_n_within_unsorted { rd, ); - rd = Axis::rd_update(rd, D::dist1(new_off, old_off)); + rd = D::accumulate(rd, D::dist1(new_off, old_off)); if rd <= radius { off[split_dim] = new_off; diff --git a/src/common/generate_nearest_one.rs b/src/common/generate_nearest_one.rs index e4d5c2f8..c5838b20 100644 --- a/src/common/generate_nearest_one.rs +++ b/src/common/generate_nearest_one.rs @@ -67,7 +67,7 @@ macro_rules! generate_nearest_one { nearest = nearest_neighbour; } - rd = Axis::rd_update(rd, D::dist1(new_off, old_off)); + rd = D::accumulate(rd, D::dist1(new_off, old_off)); if rd <= nearest.distance { off[split_dim] = new_off; diff --git a/src/common/generate_within_unsorted.rs b/src/common/generate_within_unsorted.rs index 33026412..656ab033 100644 --- a/src/common/generate_within_unsorted.rs +++ b/src/common/generate_within_unsorted.rs @@ -68,7 +68,7 @@ macro_rules! generate_within_unsorted { rd, ); - rd = Axis::rd_update(rd, D::dist1(new_off, old_off)); + rd = D::accumulate(rd, D::dist1(new_off, old_off)); if rd <= radius { off[split_dim] = new_off; diff --git a/src/common/generate_within_unsorted_iter.rs b/src/common/generate_within_unsorted_iter.rs index a6d0123b..a44ad861 100644 --- a/src/common/generate_within_unsorted_iter.rs +++ b/src/common/generate_within_unsorted_iter.rs @@ -78,7 +78,7 @@ macro_rules! generate_within_unsorted_iter { rd, ); - rd = Axis::rd_update(rd, D::dist1(new_off, old_off)); + rd = D::accumulate(rd, D::dist1(new_off, old_off)); if rd <= radius { off[split_dim] = new_off; diff --git a/src/fixed/distance.rs b/src/fixed/distance.rs index 5e901432..6dc132db 100644 --- a/src/fixed/distance.rs +++ b/src/fixed/distance.rs @@ -55,6 +55,13 @@ impl DistanceMetric for Manhattan { b - a } } + + #[inline] + fn accumulate(rd: A, delta: A) -> A { + rd.saturating_add(delta) + } + + const IS_MAX_BASED: bool = false; } /// Returns the squared euclidean distance between two points. @@ -99,4 +106,11 @@ impl DistanceMetric for SquaredEuclidean { let diff: A = a.dist(b); diff * diff } + + #[inline] + fn accumulate(rd: A, delta: A) -> A { + rd.saturating_add(delta) + } + + const IS_MAX_BASED: bool = false; } diff --git a/src/float/distance.rs b/src/float/distance.rs index 6e9b5ff3..8053c812 100644 --- a/src/float/distance.rs +++ b/src/float/distance.rs @@ -38,6 +38,13 @@ impl DistanceMetric for Manhattan { fn dist1(a: A, b: A) -> A { (a - b).abs() } + + #[inline] + fn accumulate(rd: A, delta: A) -> A { + rd + delta + } + + const IS_MAX_BASED: bool = false; } /// Returns the Chebyshev / L-infinity distance between two points. @@ -72,6 +79,13 @@ impl DistanceMetric for Chebyshev { fn dist1(a: A, b: A) -> A { (a - b).abs() } + + #[inline] + fn accumulate(rd: A, delta: A) -> A { + rd.max(delta) + } + + const IS_MAX_BASED: bool = true; } /// Returns the squared euclidean distance between two points. @@ -106,6 +120,13 @@ impl DistanceMetric for SquaredEuclidean { fn dist1(a: A, b: A) -> A { (a - b) * (a - b) } + + #[inline] + fn accumulate(rd: A, delta: A) -> A { + rd + delta + } + + const IS_MAX_BASED: bool = false; } #[cfg(test)] diff --git a/src/float/kdtree.rs b/src/float/kdtree.rs index 35f5d682..0ea74d6e 100644 --- a/src/float/kdtree.rs +++ b/src/float/kdtree.rs @@ -63,6 +63,7 @@ pub trait Axis: FloatCore + Default + Debug + Copy + Sync + Send + std::ops::Add fn saturating_dist(self, other: Self) -> Self; /// Used in query methods to update the rd value. A saturating add for Fixed and an add for Float + #[deprecated(since = "5.3.0", note = "Use D::accumulate instead")] // TODO: change version number if adding this change - or better so: fully get rid off rd_update fn rd_update(rd: Self, delta: Self) -> Self; } @@ -73,7 +74,8 @@ impl #[inline] fn rd_update(rd: Self, delta: Self) -> Self { - rd.max(delta) + // DEPRECATED: Use D::accumulate instead + rd + delta } } diff --git a/src/float_leaf_slice/leaf_slice.rs b/src/float_leaf_slice/leaf_slice.rs index 9bc8d638..381f3311 100644 --- a/src/float_leaf_slice/leaf_slice.rs +++ b/src/float_leaf_slice/leaf_slice.rs @@ -235,16 +235,24 @@ where #[allow(clippy::needless_range_loop)] for idx in 0..remainder_items.len() { - let distance = D::dist( - &(0..K) - .map(|dim| remainder_points[dim][idx]) - .collect::>() - .try_into() - .unwrap(), - query, - ); - - if distance < radius { + let distance = match D::IS_MAX_BASED { + true => D::dist( + &(0..K) + .map(|dim| remainder_points[dim][idx]) + .collect::>() + .try_into() + .unwrap(), + query, + ), + false => { + let mut dist = A::zero(); + (0..K).step_by(1).for_each(|dim| { + dist += D::dist1(remainder_points[dim][idx], query[dim]); + }); + dist + } + }; + if distance <= radius { results.add(NearestNeighbour { distance, item: remainder_items[idx], @@ -273,16 +281,25 @@ where #[allow(clippy::needless_range_loop)] for idx in 0..remainder_items.len() { - let distance = D::dist( - &(0..K) - .map(|dim| remainder_points[dim][idx]) - .collect::>() - .try_into() - .unwrap(), - query, - ); - - if distance < radius { + let distance = match D::IS_MAX_BASED { + true => D::dist( + &(0..K) + .map(|dim| remainder_points[dim][idx]) + .collect::>() + .try_into() + .unwrap(), + query, + ), + false => { + let mut distance = A::zero(); + (0..K).step_by(1).for_each(|dim| { + distance += D::dist1(remainder_points[dim][idx], query[dim]); + }); + distance + } + }; + + if distance <= radius { let item = remainder_items[idx]; if results.len() < max_qty { results.push(BestNeighbour { distance, item }); @@ -372,21 +389,34 @@ where D: DistanceMetric, Self: Sized, { - // AVX512: 4 loops of 32 iterations, each 4x unrolled, 5 instructions per pre-unrolled iteration - // TODO: For each point in chunk, compute full distance using D::dist - // This is prob slower than SIMD, but works for all metrics - (0..C) - .map(|idx| { - let point: [Self; K] = (0..K) - .map(|dim| chunk[dim][idx]) - .collect::>() - .try_into() - .unwrap(); - D::dist(&point, query) - }) - .collect::>() - .try_into() - .unwrap() + if D::IS_MAX_BASED { + // Generic version for max-based metrics (Chebyshev) + // For each point in chunk, compute full distance using D::dist + (0..C) + .map(|idx| { + let point: [Self; K] = (0..K) + .map(|dim| chunk[dim][idx]) + .collect::>() + .try_into() + .unwrap(); + D::dist(&point, query) + }) + .collect::>() + .try_into() + .unwrap() + } else { + // SIMD-optimised version for sum-based metrics (Manhattan, SquaredEuclidean) + // AVX512: 4 loops of 32 iterations, each 4x unrolled, 5 instructions per pre-unrolled iteration + let mut acc = [0f64; C]; + (0..K).step_by(1).for_each(|dim| { + let qd = [query[dim]; C]; + + (0..C).step_by(1).for_each(|idx| { + acc[idx] += D::dist1(chunk[dim][idx], qd[idx]); + }); + }); + acc + } } } @@ -462,21 +492,34 @@ where D: DistanceMetric, Self: Sized, { - // AVX512: 4 loops of 32 iterations, each 4x unrolled, 5 instructions per pre-unrolled iteration - // TODO: For each point in chunk, compute full distance using D::dist - // Same as above, optimisation to be recovered again - (0..C) - .map(|idx| { - let point: [Self; K] = (0..K) - .map(|dim| chunk[dim][idx]) - .collect::>() - .try_into() - .unwrap(); - D::dist(&point, query) - }) - .collect::>() - .try_into() - .unwrap() + if D::IS_MAX_BASED { + // Generic version for max-based metrics (Chebyshev) + // For each point in chunk, compute full distance using D::dist + (0..C) + .map(|idx| { + let point: [Self; K] = (0..K) + .map(|dim| chunk[dim][idx]) + .collect::>() + .try_into() + .unwrap(); + D::dist(&point, query) + }) + .collect::>() + .try_into() + .unwrap() + } else { + // SIMD-optimised version for sum-based metrics (Manhattan, SquaredEuclidean) + // AVX512: 4 loops of 32 iterations, each 4x unrolled, 5 instructions per pre-unrolled iteration + let mut acc = [0f32; C]; + (0..K).step_by(1).for_each(|dim| { + let qd = [query[dim]; C]; + + (0..C).step_by(1).for_each(|idx| { + acc[idx] += D::dist1(chunk[dim][idx], qd[idx]); + }); + }); + acc + } } } diff --git a/src/traits.rs b/src/traits.rs index 9561b804..da32a7c9 100644 --- a/src/traits.rs +++ b/src/traits.rs @@ -121,6 +121,29 @@ pub trait DistanceMetric { /// to extend the minimum acceptable distance for a node when recursing /// back up the tree) fn dist1(a: A, b: A) -> A; + + /// Accumulates a distance contribution for this metric. + /// + /// For sum-based metrics (Manhattan, SquaredEuclidean), this adds the contribution. + /// For max-based metrics (Chebyshev), this uses max aggregation. + /// + /// This is used in pruning logic to estimate lower bounds on distance to subtrees. + /// + /// # Migration from pre-v5.3.0 code // TODO: Update version number + /// + /// If you have custom distance metrics that worked before v5.3.0, implement this + /// method as `rd + delta` (or `rd.saturating_add(delta)` for fixed-point types) + /// to maintain backward compatible behaviour. + /// + /// Choose based on your distance metric: + /// - **Sum-based (L1, L2)**: Use `rd + delta` or `rd.saturating_add(delta)` + /// - **Max-based (L∞/Chebyshev)**: Use `rd.max(delta)` + fn accumulate(rd: A, delta: A) -> A; + + /// Whether this metric uses max-based aggregation (Chebyshev) instead of sum-based. + /// + /// Max-based metrics (L∞) do not use SIMD-optimised sum accumulation. + const IS_MAX_BASED: bool = false; } #[cfg(test)] From ec7fcb32419228c9f924f15ff88e28509208ad2b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Carlson=20B=C3=BCth?= Date: Sat, 7 Feb 2026 21:38:42 +0100 Subject: [PATCH 07/17] feat: add fixed Chebyshev distance metric - integration `nearest_n` tests (Chebyshev, Manhattan, SquaredEuclidean). --- src/common/generate_nearest_n.rs | 16 +- src/fixed/distance.rs | 285 ++++++++++++++++++ src/hybrid/query/nearest_n.rs | 15 +- .../generate_immutable_best_n_within.rs | 4 +- .../generate_immutable_nearest_n_within.rs | 4 +- .../common/generate_immutable_nearest_one.rs | 4 +- ...generate_immutable_within_unsorted_iter.rs | 2 +- 7 files changed, 314 insertions(+), 16 deletions(-) diff --git a/src/common/generate_nearest_n.rs b/src/common/generate_nearest_n.rs index adccc0bf..f2e6c310 100644 --- a/src/common/generate_nearest_n.rs +++ b/src/common/generate_nearest_n.rs @@ -67,7 +67,7 @@ macro_rules! generate_nearest_n { rd = D::accumulate(rd, D::dist1(new_off, old_off)); - if Self::dist_belongs_in_heap(rd, results) { + if Self::dist_belongs_in_heap(rd, results, D::IS_MAX_BASED) { off[split_dim] = new_off; self.nearest_n_recurse::( query, @@ -94,7 +94,7 @@ macro_rules! generate_nearest_n { .for_each(|(idx, entry)| { let distance: A = D::dist(query, transform(entry)); - if Self::dist_belongs_in_heap(distance, results) { + if Self::dist_belongs_in_heap(distance, results, D::IS_MAX_BASED) { let item = unsafe { leaf_node.content_items.get_unchecked(idx) }; let item = *transform(item); @@ -112,9 +112,15 @@ macro_rules! generate_nearest_n { } } - #[inline] - fn dist_belongs_in_heap(dist: A, heap: &BinaryHeap>) -> bool { - heap.is_empty() || dist < heap.peek().unwrap().distance || heap.len() < heap.capacity() + #[inline(always)] + fn dist_belongs_in_heap(dist: A, heap: &BinaryHeap>, is_max_based: bool) -> bool { + heap.is_empty() || ( + if is_max_based { + dist <= heap.peek().unwrap().distance + } else { + dist < heap.peek().unwrap().distance + } + ) || heap.len() < heap.capacity() } }; } diff --git a/src/fixed/distance.rs b/src/fixed/distance.rs index 6dc132db..c0b99655 100644 --- a/src/fixed/distance.rs +++ b/src/fixed/distance.rs @@ -64,6 +64,67 @@ impl DistanceMetric for Manhattan { const IS_MAX_BASED: bool = false; } +/// Returns the Chebyshev distance (L-infinity norm) between two points. +/// +/// This is the maximum of the absolute differences between coordinates of points. +/// +/// # Examples +/// +/// ```rust +/// use fixed::types::extra::U0; +/// use fixed::FixedU16; +/// use kiddo::traits::DistanceMetric; +/// use kiddo::fixed::distance::Chebyshev; +/// type Fxd = FixedU16; +/// +/// let ZERO = Fxd::from_num(0); +/// let ONE = Fxd::from_num(1); +/// let TWO = Fxd::from_num(2); +/// +/// assert_eq!(ZERO, Chebyshev::dist(&[ZERO, ZERO], &[ZERO, ZERO])); +/// assert_eq!(ONE, Chebyshev::dist(&[ZERO, ZERO], &[ONE, ZERO])); +/// assert_eq!(ONE, Chebyshev::dist(&[ZERO, ZERO], &[ONE, ONE])); +/// assert_eq!(TWO, Chebyshev::dist(&[ZERO, ZERO], &[TWO, ONE])); +/// ``` +pub struct Chebyshev {} + +impl DistanceMetric for Chebyshev { + #[inline] + fn dist(a: &[A; K], b: &[A; K]) -> A { + a.iter() + .zip(b.iter()) + .map(|(&a_val, &b_val)| { + if a_val > b_val { + a_val - b_val + } else { + b_val - a_val + } + }) + .reduce(|a, b| if a > b { a } else { b }) + .unwrap_or(A::ZERO) + } + + #[inline] + fn dist1(a: A, b: A) -> A { + if a > b { + a - b + } else { + b - a + } + } + + #[inline] + fn accumulate(rd: A, delta: A) -> A { + if rd > delta { + rd + } else { + delta + } + } + + const IS_MAX_BASED: bool = true; +} + /// Returns the squared euclidean distance between two points. /// /// Faster than Euclidean distance due to not needing a square root, but still @@ -114,3 +175,227 @@ impl DistanceMetric for SquaredEuclidean { const IS_MAX_BASED: bool = false; } + +#[cfg(test)] +mod tests { + use super::*; + use fixed::types::extra::U0; + use rstest::rstest; + + type FxdU16 = fixed::FixedU16; + + const ZERO: FxdU16 = FxdU16::ZERO; + const ONE: FxdU16 = FxdU16::lit("1"); + const TWO: FxdU16 = FxdU16::lit("2"); + const THREE: FxdU16 = FxdU16::lit("3"); + const FOUR: FxdU16 = FxdU16::lit("4"); + const FIVE: FxdU16 = FxdU16::lit("5"); + + #[rstest] + #[case([ZERO, ZERO], [ZERO, ZERO], ZERO)] + #[case([ZERO, ZERO], [ONE, ZERO], ONE)] + #[case([ZERO, ZERO], [ZERO, ONE], ONE)] + #[case([ZERO, ZERO], [ONE, ONE], ONE)] + #[case([ZERO, ZERO], [TWO, ONE], TWO)] + #[case([ZERO, ZERO], [ONE, TWO], TWO)] + fn test_chebyshev_distance_2d( + #[case] a: [FxdU16; 2], + #[case] b: [FxdU16; 2], + #[case] expected: FxdU16, + ) { + assert_eq!(Chebyshev::dist(&a, &b), expected); + } + + #[rstest] + #[case([ZERO, ZERO, ZERO], [ZERO, ZERO, ZERO], ZERO)] + #[case([ZERO, ZERO, ZERO], [ONE, TWO, THREE], THREE)] + #[case([FIVE, FIVE, FIVE], [ONE, TWO, THREE], FOUR)] + fn test_chebyshev_distance_3d( + #[case] a: [FxdU16; 3], + #[case] b: [FxdU16; 3], + #[case] expected: FxdU16, + ) { + assert_eq!(Chebyshev::dist(&a, &b), expected); + } + + #[rstest] + #[case([ZERO, ZERO], [ZERO, ZERO], ZERO)] + #[case([ZERO, ZERO], [ONE, ZERO], ONE)] + #[case([ZERO, ZERO], [ZERO, ONE], ONE)] + #[case([ZERO, ZERO], [ONE, ONE], TWO)] + #[case([TWO, THREE], [ONE, ONE], THREE)] + fn test_manhattan_distance_2d( + #[case] a: [FxdU16; 2], + #[case] b: [FxdU16; 2], + #[case] expected: FxdU16, + ) { + assert_eq!(Manhattan::dist(&a, &b), expected); + } +} + +#[cfg(test)] +mod integration_tests { + use super::*; + use crate::fixed::kdtree::KdTree; + use fixed::types::extra::U0; + use fixed::FixedU16; + use rstest::rstest; + + type FxdU16 = FixedU16; + + const ZERO: FxdU16 = FxdU16::ZERO; + const ONE: FxdU16 = FxdU16::lit("1"); + const TWO: FxdU16 = FxdU16::lit("2"); + const THREE: FxdU16 = FxdU16::lit("3"); + const FOUR: FxdU16 = FxdU16::lit("4"); + const FIVE: FxdU16 = FxdU16::lit("5"); + + enum DataScenario { + NoTies, + Ties, + } + + impl DataScenario { + fn get(&self, dim: usize) -> Vec> { + match (self, dim) { + (DataScenario::NoTies, 1) => { + vec![vec![ONE], vec![TWO], vec![THREE], vec![FOUR], vec![FIVE]] + } + (DataScenario::NoTies, 2) => vec![ + vec![ZERO, ZERO], + vec![ONE, ZERO], + vec![TWO, ZERO], + vec![THREE, ZERO], + vec![FOUR, ZERO], + vec![FIVE, ZERO], + ], + (DataScenario::NoTies, 3) => vec![ + vec![ZERO, ZERO, ZERO], + vec![ONE, ZERO, ZERO], + vec![TWO, ZERO, ZERO], + vec![THREE, ZERO, ZERO], + vec![FOUR, ZERO, ZERO], + vec![FIVE, ZERO, ZERO], + ], + (DataScenario::Ties, 1) => vec![ + vec![ZERO], + vec![ONE], + vec![ONE], + vec![TWO], + vec![THREE], + vec![THREE], + ], + (DataScenario::Ties, 2) => vec![ + vec![ZERO, ZERO], + vec![ONE, ZERO], + vec![ZERO, ONE], + vec![TWO, ZERO], + vec![ZERO, TWO], + vec![TWO, TWO], + ], + (DataScenario::Ties, 3) => vec![ + vec![ZERO, ZERO, ZERO], + vec![ONE, ZERO, ZERO], + vec![ZERO, ONE, ZERO], + vec![ZERO, ZERO, ONE], + vec![TWO, ZERO, ZERO], + vec![ZERO, TWO, ZERO], + ], + _ => panic!("Unsupported dimension"), + } + } + } + + fn run_test_helper>(dim: usize, scenario: DataScenario, n: usize) { + let data = scenario.get(dim); + let query_point = &data[0]; + + let mut points: Vec<[FxdU16; 6]> = Vec::with_capacity(data.len()); + for row in &data { + let mut p = [ZERO; 6]; + for (i, &val) in row.iter().enumerate() { + if i < 6 { + p[i] = val; + } + } + points.push(p); + } + + let mut query_arr = [ZERO; 6]; + for (i, &val) in query_point.iter().enumerate() { + if i < 6 { + query_arr[i] = val; + } + } + + let expected: Vec<(usize, FxdU16)> = points + .iter() + .enumerate() + .map(|(i, &point)| { + let dist = D::dist(&query_arr, &point); + (i, dist) + }) + .collect(); + + let expected_distances: Vec = expected.iter().map(|(_, d)| *d).collect(); + + let mut tree: KdTree = KdTree::new(); + for (i, point) in points.iter().enumerate() { + tree.add(point, i as u32); + } + + let results = tree.nearest_n::(&query_arr, n); + + assert_eq!(results[0].item, 0, "First result should be the query point"); + assert_eq!( + results[0].distance, ZERO, + "First result distance should be 0.0" + ); + + for (i, result) in results.iter().enumerate() { + assert_eq!( + result.distance, expected_distances[i], + "Distance at index {} should be {}, but was {}", + i, expected_distances[i], result.distance + ); + } + + if matches!(scenario, DataScenario::NoTies) { + for (i, result) in results.iter().enumerate() { + let expected_id = expected[i].0; + assert_eq!( + result.item, expected_id as u32, + "Result {}: item ID mismatch. Expected {}, got {}", + i, expected_id, result.item + ); + } + } + } + + #[rstest] + fn test_nearest_n_chebyshev( + #[values(DataScenario::NoTies, DataScenario::Ties)] scenario: DataScenario, + #[values(1, 2, 3, 4, 5, 6)] n: usize, + #[values(1, 2, 3)] dim: usize, + ) { + run_test_helper::(dim, scenario, n); + } + + #[rstest] + fn test_nearest_n_squared_euclidean( + #[values(DataScenario::NoTies, DataScenario::Ties)] scenario: DataScenario, + #[values(1, 2, 3, 4, 5, 6)] n: usize, + #[values(1, 2, 3)] dim: usize, + ) { + run_test_helper::(dim, scenario, n); + } + + #[rstest] + fn test_nearest_n_manhattan( + #[values(DataScenario::NoTies, DataScenario::Ties)] scenario: DataScenario, + #[values(1, 2, 3, 4, 5, 6)] n: usize, + #[values(1, 2, 3)] dim: usize, + ) { + run_test_helper::(dim, scenario, n); + } +} diff --git a/src/hybrid/query/nearest_n.rs b/src/hybrid/query/nearest_n.rs index f6e24fda..710d0a0c 100644 --- a/src/hybrid/query/nearest_n.rs +++ b/src/hybrid/query/nearest_n.rs @@ -93,7 +93,7 @@ where // TODO: switch from dist_fn to a dist trait that can apply to 1D as well as KD // so that updating rd is not hardcoded to sq euclidean rd = rd + new_off * new_off - old_off * old_off; - if Self::dist_belongs_in_heap(rd, results) { + if Self::dist_belongs_in_heap(rd, results, D::IS_MAX_BASED) { off[split_dim] = new_off; self.nearest_n_recurse( query, @@ -118,7 +118,7 @@ where .enumerate() .for_each(|(idx, entry)| { let distance: A = distance_fn(query, entry); - if Self::dist_belongs_in_heap(distance, results) { + if Self::dist_belongs_in_heap(distance, results, D::IS_MAX_BASED) { let item = unsafe { *leaf_node.content_items.get_unchecked(idx) }; let element = Neighbour { distance, item }; if results.len() < results.capacity() { @@ -134,8 +134,15 @@ where } } - fn dist_belongs_in_heap(dist: A, heap: &BinaryHeap>) -> bool { - heap.is_empty() || dist < heap.peek().unwrap().distance || heap.len() < heap.capacity() + #[inline(always)] + fn dist_belongs_in_heap(dist: A, heap: &BinaryHeap>, is_max_based: bool) -> bool { + heap.is_empty() || ( + if is_max_based { + dist <= heap.peek().unwrap().distance + } else { + dist < heap.peek().unwrap().distance + } + ) || heap.len() < heap.capacity() } } diff --git a/src/immutable/common/generate_immutable_best_n_within.rs b/src/immutable/common/generate_immutable_best_n_within.rs index 3c4ad3bf..e5b58371 100644 --- a/src/immutable/common/generate_immutable_best_n_within.rs +++ b/src/immutable/common/generate_immutable_best_n_within.rs @@ -109,7 +109,7 @@ macro_rules! generate_immutable_best_n_within { closer_leaf_idx, ); - rd = Axis::rd_update(rd, D::dist1(new_off, old_off)); + rd = D::accumulate(rd, D::dist1(new_off, old_off)); if rd <= radius { off[split_dim] = new_off; @@ -190,7 +190,7 @@ macro_rules! generate_immutable_best_n_within { closer_leaf_idx, ); - rd = Axis::rd_update(rd, D::dist1(new_off, old_off)); + rd = D::accumulate(rd, D::dist1(new_off, old_off)); if rd <= radius { off[split_dim] = new_off; diff --git a/src/immutable/common/generate_immutable_nearest_n_within.rs b/src/immutable/common/generate_immutable_nearest_n_within.rs index 398697fd..3d3d8435 100644 --- a/src/immutable/common/generate_immutable_nearest_n_within.rs +++ b/src/immutable/common/generate_immutable_nearest_n_within.rs @@ -112,7 +112,7 @@ macro_rules! generate_immutable_nearest_n_within { closer_leaf_idx, ); - rd = Axis::rd_update(rd, D::dist1(new_off, old_off)); + rd = D::accumulate(rd, D::dist1(new_off, old_off)); if rd <= radius && rd < matching_items.max_dist() { off[split_dim] = new_off; @@ -189,7 +189,7 @@ macro_rules! generate_immutable_nearest_n_within { closer_leaf_idx, ); - rd = Axis::rd_update(rd, D::dist1(new_off, old_off)); + rd = D::accumulate(rd, D::dist1(new_off, old_off)); if rd <= radius && rd < matching_items.max_dist() { off[split_dim] = new_off; diff --git a/src/immutable/common/generate_immutable_nearest_one.rs b/src/immutable/common/generate_immutable_nearest_one.rs index 274da9e2..e72007ed 100644 --- a/src/immutable/common/generate_immutable_nearest_one.rs +++ b/src/immutable/common/generate_immutable_nearest_one.rs @@ -109,7 +109,7 @@ macro_rules! generate_immutable_nearest_one { closer_leaf_idx, ); - rd = Axis::rd_update(rd, D::dist1(new_off, old_off)); + rd = D::accumulate(rd, D::dist1(new_off, old_off)); if rd <= nearest.distance { off[split_dim as usize] = new_off; @@ -177,7 +177,7 @@ macro_rules! generate_immutable_nearest_one { rd, ); - rd = Axis::rd_update(rd, D::dist1(new_off, old_off)); + rd = D::accumulate(rd, D::dist1(new_off, old_off)); if rd <= nearest.distance { off[split_dim as usize] = new_off; diff --git a/src/immutable/common/generate_immutable_within_unsorted_iter.rs b/src/immutable/common/generate_immutable_within_unsorted_iter.rs index a203d670..94f673fd 100644 --- a/src/immutable/common/generate_immutable_within_unsorted_iter.rs +++ b/src/immutable/common/generate_immutable_within_unsorted_iter.rs @@ -83,7 +83,7 @@ macro_rules! generate_immutable_within_unsorted_iter { closer_leaf_idx, ); - rd = Axis::rd_update(rd, D::dist1(new_off, old_off)); + rd = D::accumulate(rd, D::dist1(new_off, old_off)); if rd <= radius { off[split_dim] = new_off; From ccecab0590ac39525a8f6880a17e3e575395426b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Carlson=20B=C3=BCth?= Date: Sat, 7 Feb 2026 21:45:15 +0100 Subject: [PATCH 08/17] refactor: in-loop accumulation for max-based metrics --- src/float_leaf_slice/leaf_slice.rs | 103 ++++++++++++----------------- 1 file changed, 41 insertions(+), 62 deletions(-) diff --git a/src/float_leaf_slice/leaf_slice.rs b/src/float_leaf_slice/leaf_slice.rs index 381f3311..5eed2b90 100644 --- a/src/float_leaf_slice/leaf_slice.rs +++ b/src/float_leaf_slice/leaf_slice.rs @@ -235,23 +235,20 @@ where #[allow(clippy::needless_range_loop)] for idx in 0..remainder_items.len() { - let distance = match D::IS_MAX_BASED { - true => D::dist( - &(0..K) - .map(|dim| remainder_points[dim][idx]) - .collect::>() - .try_into() - .unwrap(), - query, - ), - false => { - let mut dist = A::zero(); - (0..K).step_by(1).for_each(|dim| { - dist += D::dist1(remainder_points[dim][idx], query[dim]); - }); - dist - } + let distance = if D::IS_MAX_BASED { + let mut dist = A::zero(); + (0..K).step_by(1).for_each(|dim| { + dist = dist.max(D::dist1(remainder_points[dim][idx], query[dim])); + }); + dist + } else { + let mut dist = A::zero(); + (0..K).step_by(1).for_each(|dim| { + dist += D::dist1(remainder_points[dim][idx], query[dim]); + }); + dist }; + if distance <= radius { results.add(NearestNeighbour { distance, @@ -281,22 +278,18 @@ where #[allow(clippy::needless_range_loop)] for idx in 0..remainder_items.len() { - let distance = match D::IS_MAX_BASED { - true => D::dist( - &(0..K) - .map(|dim| remainder_points[dim][idx]) - .collect::>() - .try_into() - .unwrap(), - query, - ), - false => { - let mut distance = A::zero(); - (0..K).step_by(1).for_each(|dim| { - distance += D::dist1(remainder_points[dim][idx], query[dim]); - }); - distance - } + let distance = if D::IS_MAX_BASED { + let mut dist = A::zero(); + (0..K).step_by(1).for_each(|dim| { + dist = dist.max(D::dist1(remainder_points[dim][idx], query[dim])); + }); + dist + } else { + let mut dist = A::zero(); + (0..K).step_by(1).for_each(|dim| { + dist += D::dist1(remainder_points[dim][idx], query[dim]); + }); + dist }; if distance <= radius { @@ -390,22 +383,15 @@ where Self: Sized, { if D::IS_MAX_BASED { - // Generic version for max-based metrics (Chebyshev) - // For each point in chunk, compute full distance using D::dist - (0..C) - .map(|idx| { - let point: [Self; K] = (0..K) - .map(|dim| chunk[dim][idx]) - .collect::>() - .try_into() - .unwrap(); - D::dist(&point, query) - }) - .collect::>() - .try_into() - .unwrap() + let mut acc = [0f64; C]; + (0..K).step_by(1).for_each(|dim| { + let qd = [query[dim]; C]; + (0..C).step_by(1).for_each(|idx| { + acc[idx] = acc[idx].max(D::dist1(chunk[dim][idx], qd[idx])); + }); + }); + acc } else { - // SIMD-optimised version for sum-based metrics (Manhattan, SquaredEuclidean) // AVX512: 4 loops of 32 iterations, each 4x unrolled, 5 instructions per pre-unrolled iteration let mut acc = [0f64; C]; (0..K).step_by(1).for_each(|dim| { @@ -493,22 +479,15 @@ where Self: Sized, { if D::IS_MAX_BASED { - // Generic version for max-based metrics (Chebyshev) - // For each point in chunk, compute full distance using D::dist - (0..C) - .map(|idx| { - let point: [Self; K] = (0..K) - .map(|dim| chunk[dim][idx]) - .collect::>() - .try_into() - .unwrap(); - D::dist(&point, query) - }) - .collect::>() - .try_into() - .unwrap() + let mut acc = [0f32; C]; + (0..K).step_by(1).for_each(|dim| { + let qd = [query[dim]; C]; + (0..C).step_by(1).for_each(|idx| { + acc[idx] = acc[idx].max(D::dist1(chunk[dim][idx], qd[idx])); + }); + }); + acc } else { - // SIMD-optimised version for sum-based metrics (Manhattan, SquaredEuclidean) // AVX512: 4 loops of 32 iterations, each 4x unrolled, 5 instructions per pre-unrolled iteration let mut acc = [0f32; C]; (0..K).step_by(1).for_each(|dim| { From f9feccd1afd744d3f216cdb41e9a37c75e8fee11 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Carlson=20B=C3=BCth?= Date: Sat, 7 Feb 2026 23:42:39 +0100 Subject: [PATCH 09/17] refactor: unify distance accumulation logic with `D::accumulate` --- src/float_leaf_slice/leaf_slice.rs | 91 ++++++++---------------------- 1 file changed, 24 insertions(+), 67 deletions(-) diff --git a/src/float_leaf_slice/leaf_slice.rs b/src/float_leaf_slice/leaf_slice.rs index 5eed2b90..4da31f83 100644 --- a/src/float_leaf_slice/leaf_slice.rs +++ b/src/float_leaf_slice/leaf_slice.rs @@ -235,20 +235,11 @@ where #[allow(clippy::needless_range_loop)] for idx in 0..remainder_items.len() { - let distance = if D::IS_MAX_BASED { - let mut dist = A::zero(); - (0..K).step_by(1).for_each(|dim| { - dist = dist.max(D::dist1(remainder_points[dim][idx], query[dim])); - }); - dist - } else { - let mut dist = A::zero(); - (0..K).step_by(1).for_each(|dim| { - dist += D::dist1(remainder_points[dim][idx], query[dim]); - }); - dist - }; - + let mut distance = A::zero(); + (0..K).step_by(1).for_each(|dim| { + distance = + D::accumulate(distance, D::dist1(remainder_points[dim][idx], query[dim])); + }); if distance <= radius { results.add(NearestNeighbour { distance, @@ -278,19 +269,11 @@ where #[allow(clippy::needless_range_loop)] for idx in 0..remainder_items.len() { - let distance = if D::IS_MAX_BASED { - let mut dist = A::zero(); - (0..K).step_by(1).for_each(|dim| { - dist = dist.max(D::dist1(remainder_points[dim][idx], query[dim])); - }); - dist - } else { - let mut dist = A::zero(); - (0..K).step_by(1).for_each(|dim| { - dist += D::dist1(remainder_points[dim][idx], query[dim]); - }); - dist - }; + let mut distance = A::zero(); + (0..K).step_by(1).for_each(|dim| { + distance = + D::accumulate(distance, D::dist1(remainder_points[dim][idx], query[dim])); + }); if distance <= radius { let item = remainder_items[idx]; @@ -382,27 +365,14 @@ where D: DistanceMetric, Self: Sized, { - if D::IS_MAX_BASED { - let mut acc = [0f64; C]; - (0..K).step_by(1).for_each(|dim| { - let qd = [query[dim]; C]; - (0..C).step_by(1).for_each(|idx| { - acc[idx] = acc[idx].max(D::dist1(chunk[dim][idx], qd[idx])); - }); - }); - acc - } else { - // AVX512: 4 loops of 32 iterations, each 4x unrolled, 5 instructions per pre-unrolled iteration - let mut acc = [0f64; C]; - (0..K).step_by(1).for_each(|dim| { - let qd = [query[dim]; C]; - - (0..C).step_by(1).for_each(|idx| { - acc[idx] += D::dist1(chunk[dim][idx], qd[idx]); - }); + let mut acc = [0f64; C]; + (0..K).step_by(1).for_each(|dim| { + let qd = [query[dim]; C]; + (0..C).step_by(1).for_each(|idx| { + acc[idx] = D::accumulate(acc[idx], D::dist1(chunk[dim][idx], qd[idx])); }); - acc - } + }); + acc } } @@ -478,27 +448,14 @@ where D: DistanceMetric, Self: Sized, { - if D::IS_MAX_BASED { - let mut acc = [0f32; C]; - (0..K).step_by(1).for_each(|dim| { - let qd = [query[dim]; C]; - (0..C).step_by(1).for_each(|idx| { - acc[idx] = acc[idx].max(D::dist1(chunk[dim][idx], qd[idx])); - }); - }); - acc - } else { - // AVX512: 4 loops of 32 iterations, each 4x unrolled, 5 instructions per pre-unrolled iteration - let mut acc = [0f32; C]; - (0..K).step_by(1).for_each(|dim| { - let qd = [query[dim]; C]; - - (0..C).step_by(1).for_each(|idx| { - acc[idx] += D::dist1(chunk[dim][idx], qd[idx]); - }); + let mut acc = [0f32; C]; + (0..K).step_by(1).for_each(|dim| { + let qd = [query[dim]; C]; + (0..C).step_by(1).for_each(|idx| { + acc[idx] = D::accumulate(acc[idx], D::dist1(chunk[dim][idx], qd[idx])); }); - acc - } + }); + acc } } From 53759e27fac701a77a791a5a291d5ed85a57b507 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Carlson=20B=C3=BCth?= Date: Sun, 8 Feb 2026 00:26:21 +0100 Subject: [PATCH 10/17] refactor: remove `D::IS_MAX_BASED`, unify heap logic, doc & test - improve `DistanceMetric` doc - add Gaussian scenario to tests --- src/common/generate_nearest_n.rs | 14 ++----- src/fixed/distance.rs | 6 --- src/float/distance.rs | 43 ++++++++++++--------- src/hybrid/query/nearest_n.rs | 14 ++----- src/traits.rs | 64 ++++++++++++++++++++------------ 5 files changed, 74 insertions(+), 67 deletions(-) diff --git a/src/common/generate_nearest_n.rs b/src/common/generate_nearest_n.rs index f2e6c310..f4efac07 100644 --- a/src/common/generate_nearest_n.rs +++ b/src/common/generate_nearest_n.rs @@ -67,7 +67,7 @@ macro_rules! generate_nearest_n { rd = D::accumulate(rd, D::dist1(new_off, old_off)); - if Self::dist_belongs_in_heap(rd, results, D::IS_MAX_BASED) { + if Self::dist_belongs_in_heap(rd, results) { off[split_dim] = new_off; self.nearest_n_recurse::( query, @@ -94,7 +94,7 @@ macro_rules! generate_nearest_n { .for_each(|(idx, entry)| { let distance: A = D::dist(query, transform(entry)); - if Self::dist_belongs_in_heap(distance, results, D::IS_MAX_BASED) { + if Self::dist_belongs_in_heap(distance, results) { let item = unsafe { leaf_node.content_items.get_unchecked(idx) }; let item = *transform(item); @@ -113,14 +113,8 @@ macro_rules! generate_nearest_n { } #[inline(always)] - fn dist_belongs_in_heap(dist: A, heap: &BinaryHeap>, is_max_based: bool) -> bool { - heap.is_empty() || ( - if is_max_based { - dist <= heap.peek().unwrap().distance - } else { - dist < heap.peek().unwrap().distance - } - ) || heap.len() < heap.capacity() + fn dist_belongs_in_heap(dist: A, heap: &BinaryHeap>) -> bool { + heap.is_empty() || dist < heap.peek().unwrap().distance || heap.len() < heap.capacity() } }; } diff --git a/src/fixed/distance.rs b/src/fixed/distance.rs index c0b99655..bac0d179 100644 --- a/src/fixed/distance.rs +++ b/src/fixed/distance.rs @@ -60,8 +60,6 @@ impl DistanceMetric for Manhattan { fn accumulate(rd: A, delta: A) -> A { rd.saturating_add(delta) } - - const IS_MAX_BASED: bool = false; } /// Returns the Chebyshev distance (L-infinity norm) between two points. @@ -121,8 +119,6 @@ impl DistanceMetric for Chebyshev { delta } } - - const IS_MAX_BASED: bool = true; } /// Returns the squared euclidean distance between two points. @@ -172,8 +168,6 @@ impl DistanceMetric for SquaredEuclidean { fn accumulate(rd: A, delta: A) -> A { rd.saturating_add(delta) } - - const IS_MAX_BASED: bool = false; } #[cfg(test)] diff --git a/src/float/distance.rs b/src/float/distance.rs index 8053c812..5453d103 100644 --- a/src/float/distance.rs +++ b/src/float/distance.rs @@ -43,8 +43,6 @@ impl DistanceMetric for Manhattan { fn accumulate(rd: A, delta: A) -> A { rd + delta } - - const IS_MAX_BASED: bool = false; } /// Returns the Chebyshev / L-infinity distance between two points. @@ -84,8 +82,6 @@ impl DistanceMetric for Chebyshev { fn accumulate(rd: A, delta: A) -> A { rd.max(delta) } - - const IS_MAX_BASED: bool = true; } /// Returns the squared euclidean distance between two points. @@ -125,8 +121,6 @@ impl DistanceMetric for SquaredEuclidean { fn accumulate(rd: A, delta: A) -> A { rd + delta } - - const IS_MAX_BASED: bool = false; } #[cfg(test)] @@ -490,13 +484,16 @@ mod tests { mod integration_tests { use super::*; - use crate::{ImmutableKdTree, KdTree}; + use crate::KdTree; use rstest::rstest; + use rand::prelude::*; + use rand_distr::Normal; #[derive(Debug, Clone, Copy)] enum DataScenario { NoTies, Ties, + Gaussian, } #[derive(Debug, Clone, Copy)] @@ -508,7 +505,6 @@ mod tests { impl DataScenario { /// Get data scenario /// - /// Data is ordered to appear in increasing distance to the 0-th point. /// Predefined data has input dimension (`dim`) and either /// with `DataScenario::NoTies` or `DataScenario::Ties`. /// @@ -584,6 +580,18 @@ mod tests { vec![0.0, 0.0, 0.0, 1.0], vec![-1.0, 0.0, 0.0, 0.0], ], + (DataScenario::Gaussian, d) => { + let mut rng = StdRng::seed_from_u64(8757); + let normal = Normal::new(1.0, 10.0).unwrap(); + let n_samples = 200; + let mut data = vec![vec![0.0; d]; n_samples]; + for i in 0..n_samples { + for j in 0..d { + data[i][j] = normal.sample(&mut rng); + } + } + data + } _ => panic!("Unsupported dimension {} for scenario {:?}", dim, self), } } @@ -634,7 +642,7 @@ mod tests { } // Calculate ground truth with brute-force approach - let expected: Vec<(usize, f64)> = points + let mut expected: Vec<(usize, f64)> = points .iter() .enumerate() .map(|(i, &point)| { @@ -643,6 +651,8 @@ mod tests { }) .collect(); + expected.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap()); + let expected_distances: Vec = expected.iter().map(|(_, d)| *d).collect(); println!( @@ -653,14 +663,15 @@ mod tests { // Query based on tree type let results = match tree_type { TreeType::Mutable => { - let mut tree: KdTree = KdTree::new(); + let mut tree: crate::float::kdtree::KdTree = crate::float::kdtree::KdTree::new(); for (i, point) in points.iter().enumerate() { tree.add(point, i as u64); } tree.nearest_n::(&query_arr, n) } TreeType::Immutable => { - let tree: ImmutableKdTree = ImmutableKdTree::new_from_slice(&points); + let tree: crate::immutable::float::kdtree::ImmutableKdTree = + crate::immutable::float::kdtree::ImmutableKdTree::new_from_slice(&points); tree.nearest_n::(&query_arr, std::num::NonZero::new(n).unwrap()) } }; @@ -710,7 +721,7 @@ mod tests { #[rstest] fn test_nearest_n_chebyshev( #[values(TreeType::Mutable, TreeType::Immutable)] tree_type: TreeType, - #[values(DataScenario::NoTies, DataScenario::Ties)] scenario: DataScenario, + #[values(DataScenario::NoTies, DataScenario::Ties, DataScenario::Gaussian)] scenario: DataScenario, #[values(1, 2, 3, 4, 5, 6)] n: usize, #[values(1, 2, 3, 4)] dim: usize, ) { @@ -720,7 +731,7 @@ mod tests { #[rstest] fn test_nearest_n_squared_euclidean( #[values(TreeType::Mutable, TreeType::Immutable)] tree_type: TreeType, - #[values(DataScenario::NoTies, DataScenario::Ties)] scenario: DataScenario, + #[values(DataScenario::NoTies, DataScenario::Ties, DataScenario::Gaussian)] scenario: DataScenario, #[values(1, 2, 3, 4, 5, 6)] n: usize, #[values(1, 2, 3, 4)] dim: usize, ) { @@ -730,7 +741,7 @@ mod tests { #[rstest] fn test_nearest_n_manhattan( #[values(TreeType::Mutable, TreeType::Immutable)] tree_type: TreeType, - #[values(DataScenario::NoTies, DataScenario::Ties)] scenario: DataScenario, + #[values(DataScenario::NoTies, DataScenario::Ties, DataScenario::Gaussian)] scenario: DataScenario, #[values(1, 2, 3, 4, 5, 6)] n: usize, #[values(1, 2, 3, 4)] dim: usize, ) { @@ -1068,10 +1079,8 @@ mod tests { assert!(found_indices.contains(&0)); assert!(found_indices.contains(&1)); - // This assert FAILS - demonstrates the bug - assert!(found_indices.contains(&2)); // currently not included, but should! + assert!(found_indices.contains(&2)); assert!(found_indices.contains(&3)); - // Should NOT include points with Chebyshev distance > 1 assert!(!found_indices.contains(&4)); assert!(!found_indices.contains(&5)); diff --git a/src/hybrid/query/nearest_n.rs b/src/hybrid/query/nearest_n.rs index 710d0a0c..9d28118a 100644 --- a/src/hybrid/query/nearest_n.rs +++ b/src/hybrid/query/nearest_n.rs @@ -93,7 +93,7 @@ where // TODO: switch from dist_fn to a dist trait that can apply to 1D as well as KD // so that updating rd is not hardcoded to sq euclidean rd = rd + new_off * new_off - old_off * old_off; - if Self::dist_belongs_in_heap(rd, results, D::IS_MAX_BASED) { + if Self::dist_belongs_in_heap(rd, results) { off[split_dim] = new_off; self.nearest_n_recurse( query, @@ -118,7 +118,7 @@ where .enumerate() .for_each(|(idx, entry)| { let distance: A = distance_fn(query, entry); - if Self::dist_belongs_in_heap(distance, results, D::IS_MAX_BASED) { + if Self::dist_belongs_in_heap(distance, results) { let item = unsafe { *leaf_node.content_items.get_unchecked(idx) }; let element = Neighbour { distance, item }; if results.len() < results.capacity() { @@ -135,14 +135,8 @@ where } #[inline(always)] - fn dist_belongs_in_heap(dist: A, heap: &BinaryHeap>, is_max_based: bool) -> bool { - heap.is_empty() || ( - if is_max_based { - dist <= heap.peek().unwrap().distance - } else { - dist < heap.peek().unwrap().distance - } - ) || heap.len() < heap.capacity() + fn dist_belongs_in_heap(dist: A, heap: &BinaryHeap>) -> bool { + heap.is_empty() || dist < heap.peek().unwrap().distance || heap.len() < heap.capacity() } } diff --git a/src/traits.rs b/src/traits.rs index da32a7c9..ea12ace0 100644 --- a/src/traits.rs +++ b/src/traits.rs @@ -107,43 +107,59 @@ pub(crate) fn is_stem_index>(x: IDX) -> bool { x < ::leaf_offset() } -/// Trait that needs to be implemented by any potential distance -/// metric to be used within queries +/// Defines how distances are measured and compared for k-d tree queries. +/// +/// Implement this trait to use custom distance metrics with [`kiddo:KdTree`](crate::KdTree). +/// +/// # Distance Metrics in k-d Trees +/// +/// **Distance aggregation**: How to combine per-dimension distances into a total distance +/// - Sum-based: `dist(p,q) = Σ |p[i] - q[i]|` (Manhattan, SquaredEuclidean) +/// - Max-based: `dist(p,q) = max_i |p[i] - q[i]|` (Chebyshev/L∞) +/// +/// # Required Methods +/// +/// - [`dist()`]: Compute total distance between two points +/// - [`dist1()`]: Compute per-dimension distance component +/// - [`accumulate()`]: Aggregate distance components (add or max) +/// +/// # Migration from pre-v5.3.0 Code // TODO: update version number for breaking change +/// +/// For custom distance metrics that worked before v5.3.0: +/// - Implement `accumulate()` as `rd + delta` (or `rd.saturating_add(delta)` for fixed-point) +/// pub trait DistanceMetric { - /// returns the distance between two K-d points, as measured - /// by a particular distance metric + /// Returns the distance between two K-d points, as measured by this metric. fn dist(a: &[A; K], b: &[A; K]) -> A; - /// returns the distance between two points along a single axis, - /// as measured by a particular distance metric. + /// Returns the distance between two points along a single dimension. /// - /// (needs to be implemented as it is used by the NN query implementations - /// to extend the minimum acceptable distance for a node when recursing - /// back up the tree) + /// Used internally by NN query implementations to extend the minimum + /// acceptable distance for a node when recursing back up the tree. fn dist1(a: A, b: A) -> A; - /// Accumulates a distance contribution for this metric. + /// Aggregates a distance contribution into a running total. /// - /// For sum-based metrics (Manhattan, SquaredEuclidean), this adds the contribution. - /// For max-based metrics (Chebyshev), this uses max aggregation. + /// This defines how per-dimension distances combine into a total distance. + /// Choose based on your distance metric: + /// + /// - **Sum-based (L1, L2)**: Use `rd + delta` or `rd.saturating_add(delta)` for fixed-point types + /// - **Max-based (L∞/Chebyshev)**: Use `rd.max(delta)` /// - /// This is used in pruning logic to estimate lower bounds on distance to subtrees. + /// The implementation should match the mathematical definition of your metric: + /// - Manhattan: `dist(p,q) = Σ |p[i] - q[i]|` -> accumulate by adding + /// - SquaredEuclidean: `dist(p,q) = Σ (p[i] - q[i])²` -> accumulate by adding + /// - Chebyshev: `dist(p,q) = max_i |p[i] - q[i]|` -> accumulate by taking max + /// - Generalised Minkowski (L_p): `dist(p,q) = (Σ |p[i] - q[i]|^p)^(1/p)`. + /// For k-d tree pruning, use the sum of powers: accumulate by adding. + /// Only the limit p → ∞ (Chebyshev) uses `max`. /// - /// # Migration from pre-v5.3.0 code // TODO: Update version number + /// # Migration from pre-v5.3.0 Code // TODO: update version number for breaking change /// - /// If you have custom distance metrics that worked before v5.3.0, implement this + /// If you have custom distance metrics that worked before v5.3.0, implement this // TODO: update version number for breaking change /// method as `rd + delta` (or `rd.saturating_add(delta)` for fixed-point types) /// to maintain backward compatible behaviour. - /// - /// Choose based on your distance metric: - /// - **Sum-based (L1, L2)**: Use `rd + delta` or `rd.saturating_add(delta)` - /// - **Max-based (L∞/Chebyshev)**: Use `rd.max(delta)` fn accumulate(rd: A, delta: A) -> A; - - /// Whether this metric uses max-based aggregation (Chebyshev) instead of sum-based. - /// - /// Max-based metrics (L∞) do not use SIMD-optimised sum accumulation. - const IS_MAX_BASED: bool = false; } #[cfg(test)] From 8239e39668ee99f7b7d94c39934c2db0e57c3aae Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Carlson=20B=C3=BCth?= Date: Sun, 8 Feb 2026 00:32:24 +0100 Subject: [PATCH 11/17] chore: change test comment & lint --- src/float/distance.rs | 21 ++++++++++----------- 1 file changed, 10 insertions(+), 11 deletions(-) diff --git a/src/float/distance.rs b/src/float/distance.rs index 5453d103..76fb8faf 100644 --- a/src/float/distance.rs +++ b/src/float/distance.rs @@ -485,9 +485,9 @@ mod tests { mod integration_tests { use super::*; use crate::KdTree; - use rstest::rstest; use rand::prelude::*; use rand_distr::Normal; + use rstest::rstest; #[derive(Debug, Clone, Copy)] enum DataScenario { @@ -663,7 +663,8 @@ mod tests { // Query based on tree type let results = match tree_type { TreeType::Mutable => { - let mut tree: crate::float::kdtree::KdTree = crate::float::kdtree::KdTree::new(); + let mut tree: crate::float::kdtree::KdTree = + crate::float::kdtree::KdTree::new(); for (i, point) in points.iter().enumerate() { tree.add(point, i as u64); } @@ -721,7 +722,8 @@ mod tests { #[rstest] fn test_nearest_n_chebyshev( #[values(TreeType::Mutable, TreeType::Immutable)] tree_type: TreeType, - #[values(DataScenario::NoTies, DataScenario::Ties, DataScenario::Gaussian)] scenario: DataScenario, + #[values(DataScenario::NoTies, DataScenario::Ties, DataScenario::Gaussian)] + scenario: DataScenario, #[values(1, 2, 3, 4, 5, 6)] n: usize, #[values(1, 2, 3, 4)] dim: usize, ) { @@ -731,7 +733,8 @@ mod tests { #[rstest] fn test_nearest_n_squared_euclidean( #[values(TreeType::Mutable, TreeType::Immutable)] tree_type: TreeType, - #[values(DataScenario::NoTies, DataScenario::Ties, DataScenario::Gaussian)] scenario: DataScenario, + #[values(DataScenario::NoTies, DataScenario::Ties, DataScenario::Gaussian)] + scenario: DataScenario, #[values(1, 2, 3, 4, 5, 6)] n: usize, #[values(1, 2, 3, 4)] dim: usize, ) { @@ -741,7 +744,8 @@ mod tests { #[rstest] fn test_nearest_n_manhattan( #[values(TreeType::Mutable, TreeType::Immutable)] tree_type: TreeType, - #[values(DataScenario::NoTies, DataScenario::Ties, DataScenario::Gaussian)] scenario: DataScenario, + #[values(DataScenario::NoTies, DataScenario::Ties, DataScenario::Gaussian)] + scenario: DataScenario, #[values(1, 2, 3, 4, 5, 6)] n: usize, #[values(1, 2, 3, 4)] dim: usize, ) { @@ -1068,13 +1072,8 @@ mod tests { .unwrap_or(std::cmp::Ordering::Equal) }); - // NOTE: This test demonstrates a known limitation with Chebyshev distance: - // The k-d tree query logic uses dist1 for pruning, which is incorrect for Chebyshev. - // Point at [1.0, 0.0] (index 2) has Chebyshev distance exactly 1.0 but is NOT found. - - // Should include points with Chebyshev distance <= 1 // These SHOULD be: 0, 1, 2, 3 (distances: 0, 0.5, 1.0, 0.9) - // But ACTUALLY FINDS: 0, 1, 3 (index 2 is missing due to dist1 pruning issue) + // For `<=5.2.4` found: 0, 1, 3 (index 2 is missing due to dist1 pruning issue) let found_indices: Vec = results.iter().map(|r| r.item).collect(); assert!(found_indices.contains(&0)); From 142a74e43a9043d29df9d5e39a0566c067abe815 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Carlson=20B=C3=BCth?= Date: Sun, 8 Feb 2026 00:53:34 +0100 Subject: [PATCH 12/17] refactor: make metric property tests reusable --- src/float/distance.rs | 125 +++++++++++++++++++----------------------- 1 file changed, 57 insertions(+), 68 deletions(-) diff --git a/src/float/distance.rs b/src/float/distance.rs index 76fb8faf..932b73bb 100644 --- a/src/float/distance.rs +++ b/src/float/distance.rs @@ -128,6 +128,63 @@ mod tests { use super::*; use rstest::rstest; + mod common_metric_tests { + use super::*; + + #[rstest] + #[case::zeros_1d([0.0f32], [0.0f32])] + #[case::normal_1d([1.0f32], [2.0f32])] + #[case::neg_1d([-1.0f32], [1.0f32])] + #[case::zeros_2d([0.0f32, 0.0f32], [0.0f32, 0.0f32])] + #[case::normal_2d([1.0f32, 2.0f32], [3.0f32, 4.0f32])] + #[case::large_2d([1e30f32, 1e30f32], [-1e30f32, -1e30f32])] + #[case::zeros_3d([0.0f32, 0.0f32, 0.0f32], [0.0f32, 0.0f32, 0.0f32])] + #[case::normal_3d([1.0f32, 2.0f32, 3.0f32], [4.0f32, 5.0f32, 6.0f32])] + #[case::zeros_4d([0.0f32; 4], [0.0f32; 4])] + #[case::normal_4d([1.0f32; 4], [2.0f32; 4])] + #[case::zeros_5d([0.0f32; 5], [0.0f32; 5])] + #[case::normal_5d([1.0f32; 5], [2.0f32; 5])] + fn test_metric_non_negativity>( + #[values(Manhattan {}, SquaredEuclidean {}, Chebyshev {})] _metric: D, + #[case] a: [A; K], + #[case] b: [A; K], + ) { + let distance = D::dist(&a, &b); + assert!(distance >= A::zero()); + } + + #[rstest] + #[case::zeros_1d([0.0f32])] + #[case::normal_1d([1.0f32])] + #[case::zeros_2d([0.0f32, 0.0f32])] + #[case::normal_2d([1.0f32, 2.0f32])] + #[case::zeros_3d([0.0f32, 0.0f32, 0.0f32])] + #[case::normal_3d([1.0f32, 2.0f32, 3.0f32])] + #[case::zeros_4d([0.0f32; 4])] + #[case::zeros_5d([0.0f32; 5])] + fn test_metric_identity>( + #[values(Manhattan {}, SquaredEuclidean {}, Chebyshev {})] _metric: D, + #[case] a: [A; K], + ) { + assert_eq!(D::dist(&a, &a), A::zero()); + } + + #[rstest] + #[case::normal_1d([1.0f64], [2.0f64])] + #[case::neg_1d([-1.0f64], [1.0f64])] + #[case::normal_2d([1.0f64, 2.0f64], [3.0f64, 4.0f64])] + #[case::normal_3d([1.0f64, 2.0f64, 3.0f64], [4.0f64, 5.0f64, 6.0f64])] + #[case::normal_4d([1.0f64; 4], [2.0f64; 4])] + #[case::normal_5d([1.0f64; 5], [2.0f64; 5])] + fn test_metric_symmetry>( + #[values(Manhattan {}, SquaredEuclidean {}, Chebyshev {})] _metric: D, + #[case] a: [A; K], + #[case] b: [A; K], + ) { + assert_eq!(D::dist(&a, &b), D::dist(&b, &a)); + } + } + mod manhattan_tests { use super::*; @@ -210,28 +267,6 @@ mod tests { 2000.0f32 ); // large values } - - #[test] - fn test_manhattan_symmetry() { - let a = [1.0f64, 2.0f64, 3.0f64]; - let b = [4.0f64, 5.0f64, 6.0f64]; - - assert_eq!(Manhattan::dist(&a, &b), Manhattan::dist(&b, &a)); - } - - #[test] - fn test_manhattan_identity() { - let a = [1.0f32, 2.0f32, 3.0f32]; - assert_eq!(Manhattan::dist(&a, &a), 0.0f32); - } - - #[test] - fn test_manhattan_non_negativity() { - let a = [1.0f32, 2.0f32]; - let b = [3.0f32, 4.0f32]; - let distance = Manhattan::dist(&a, &b); - assert!(distance >= 0.0f32); - } } mod squared_euclidean_tests { @@ -302,31 +337,6 @@ mod tests { ); // large values (20^2) } - #[test] - fn test_squared_euclidean_symmetry() { - let a = [1.0f64, 2.0f64, 3.0f64]; - let b = [4.0f64, 5.0f64, 6.0f64]; - - assert_eq!( - SquaredEuclidean::dist(&a, &b), - SquaredEuclidean::dist(&b, &a) - ); - } - - #[test] - fn test_squared_euclidean_identity() { - let a = [1.0f32, 2.0f32, 3.0f32]; - assert_eq!(SquaredEuclidean::dist(&a, &a), 0.0f32); - } - - #[test] - fn test_squared_euclidean_non_negativity() { - let a = [1.0f32, 2.0f32]; - let b = [3.0f32, 4.0f32]; - let distance = SquaredEuclidean::dist(&a, &b); - assert!(distance >= 0.0f32); - } - #[test] fn test_squared_euclidean_triangle_inequality_property() { // Test that squared Euclidean distance preserves ordering @@ -416,27 +426,6 @@ mod tests { assert_eq!(>::dist1(a, b), expected); } - #[test] - fn test_chebyshev_symmetry() { - let a = [1.0f64, 2.0f64, 3.0f64]; - let b = [4.0f64, 5.0f64, 6.0f64]; - assert_eq!(Chebyshev::dist(&a, &b), Chebyshev::dist(&b, &a)); - } - - #[test] - fn test_chebyshev_identity() { - let a = [1.0f32, 2.0f32, 3.0f32]; - assert_eq!(Chebyshev::dist(&a, &a), 0.0f32); - } - - #[test] - fn test_chebyshev_non_negativity() { - let a = [1.0f32, 2.0f32]; - let b = [3.0f32, 4.0f32]; - let distance = Chebyshev::dist(&a, &b); - assert!(distance >= 0.0f32); - } - #[test] fn test_chebyshev_max_property() { // Test that Chebyshev correctly finds the maximum difference From faa6dab401d8a8c69c9e1be2cec5e27234a5d2ed Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Carlson=20B=C3=BCth?= Date: Thu, 12 Feb 2026 23:54:25 +0100 Subject: [PATCH 13/17] chore: add default implementation of `accumulate` to `DistanceMetric` trait - improve test coverage --- src/fixed/distance.rs | 95 ++++++++++++++++++++++++++++++++++++++----- src/float/distance.rs | 10 ----- src/traits.rs | 63 +++++++++++++++++++++++----- 3 files changed, 137 insertions(+), 31 deletions(-) diff --git a/src/fixed/distance.rs b/src/fixed/distance.rs index bac0d179..b6275a3e 100644 --- a/src/fixed/distance.rs +++ b/src/fixed/distance.rs @@ -55,11 +55,6 @@ impl DistanceMetric for Manhattan { b - a } } - - #[inline] - fn accumulate(rd: A, delta: A) -> A { - rd.saturating_add(delta) - } } /// Returns the Chebyshev distance (L-infinity norm) between two points. @@ -163,11 +158,6 @@ impl DistanceMetric for SquaredEuclidean { let diff: A = a.dist(b); diff * diff } - - #[inline] - fn accumulate(rd: A, delta: A) -> A { - rd.saturating_add(delta) - } } #[cfg(test)] @@ -225,6 +215,91 @@ mod tests { ) { assert_eq!(Manhattan::dist(&a, &b), expected); } + + #[rstest] + #[case([ZERO, ZERO], [ZERO, ZERO], ZERO)] + #[case([ZERO, ZERO], [ONE, ZERO], ONE)] + #[case([ZERO, ZERO], [ZERO, ONE], ONE)] + #[case([ZERO, ZERO], [ONE, ONE], FxdU16::lit("2"))] + #[case([TWO, TWO], [ZERO, ZERO], FxdU16::lit("8"))] + #[case([ONE, TWO], [TWO, ONE], FxdU16::lit("2"))] + fn test_squared_euclidean_distance_2d( + #[case] a: [FxdU16; 2], + #[case] b: [FxdU16; 2], + #[case] expected: FxdU16, + ) { + assert_eq!(SquaredEuclidean::dist(&a, &b), expected); + } + + #[rstest] + #[case([ZERO, ZERO, ZERO], [ZERO, ZERO, ZERO], ZERO)] + #[case([ZERO, ZERO, ZERO], [ONE, ZERO, ZERO], ONE)] + #[case([ONE, ONE, ONE], [TWO, TWO, TWO], THREE)] + fn test_squared_euclidean_distance_3d( + #[case] a: [FxdU16; 3], + #[case] b: [FxdU16; 3], + #[case] expected: FxdU16, + ) { + assert_eq!(SquaredEuclidean::dist(&a, &b), expected); + } + + #[rstest] + #[case::zero(ZERO, ZERO, ZERO)] + #[case::pos(ONE, ZERO, ONE)] + #[case::neg(ZERO, ONE, ONE)] + #[case::diff(THREE, ONE, TWO)] + fn test_manhattan_dist1(#[case] a: FxdU16, #[case] b: FxdU16, #[case] expected: FxdU16) { + assert_eq!( + >::dist1(a, b), + expected + ); + } + + #[rstest] + #[case::zero(ZERO, ZERO, ZERO)] + #[case::pos(ONE, ZERO, ONE)] + #[case::neg(ZERO, ONE, ONE)] + #[case::a_larger(TWO, ONE, ONE)] + #[case::b_larger(ONE, TWO, ONE)] + fn test_chebyshev_dist1(#[case] a: FxdU16, #[case] b: FxdU16, #[case] expected: FxdU16) { + assert_eq!( + >::dist1(a, b), + expected + ); + } + + #[rstest] + #[case::zero(ZERO, ZERO, ZERO)] + #[case::pos(ONE, ZERO, ONE)] + #[case::neg(ZERO, ONE, ONE)] + #[case::a_larger(TWO, ONE, ONE)] + #[case::b_larger(ONE, TWO, ONE)] + fn test_squared_euclidean_dist1( + #[case] a: FxdU16, + #[case] b: FxdU16, + #[case] expected: FxdU16, + ) { + assert_eq!( + >::dist1(a, b), + expected + ); + } + + #[rstest] + #[case::zero_one(ZERO, ONE, ONE)] + #[case::one_zero(ONE, ZERO, ONE)] + #[case::first_larger(ONE, TWO, TWO)] + #[case::second_larger(TWO, ONE, TWO)] + fn test_chebyshev_accumulate( + #[case] rd: FxdU16, + #[case] delta: FxdU16, + #[case] expected: FxdU16, + ) { + assert_eq!( + >::accumulate(rd, delta), + expected + ); + } } #[cfg(test)] diff --git a/src/float/distance.rs b/src/float/distance.rs index 932b73bb..37beb03d 100644 --- a/src/float/distance.rs +++ b/src/float/distance.rs @@ -38,11 +38,6 @@ impl DistanceMetric for Manhattan { fn dist1(a: A, b: A) -> A { (a - b).abs() } - - #[inline] - fn accumulate(rd: A, delta: A) -> A { - rd + delta - } } /// Returns the Chebyshev / L-infinity distance between two points. @@ -116,11 +111,6 @@ impl DistanceMetric for SquaredEuclidean { fn dist1(a: A, b: A) -> A { (a - b) * (a - b) } - - #[inline] - fn accumulate(rd: A, delta: A) -> A { - rd + delta - } } #[cfg(test)] diff --git a/src/traits.rs b/src/traits.rs index ea12ace0..81da5347 100644 --- a/src/traits.rs +++ b/src/traits.rs @@ -123,11 +123,6 @@ pub(crate) fn is_stem_index>(x: IDX) -> bool { /// - [`dist1()`]: Compute per-dimension distance component /// - [`accumulate()`]: Aggregate distance components (add or max) /// -/// # Migration from pre-v5.3.0 Code // TODO: update version number for breaking change -/// -/// For custom distance metrics that worked before v5.3.0: -/// - Implement `accumulate()` as `rd + delta` (or `rd.saturating_add(delta)` for fixed-point) -/// pub trait DistanceMetric { /// Returns the distance between two K-d points, as measured by this metric. fn dist(a: &[A; K], b: &[A; K]) -> A; @@ -154,18 +149,23 @@ pub trait DistanceMetric { /// For k-d tree pruning, use the sum of powers: accumulate by adding. /// Only the limit p → ∞ (Chebyshev) uses `max`. /// - /// # Migration from pre-v5.3.0 Code // TODO: update version number for breaking change - /// - /// If you have custom distance metrics that worked before v5.3.0, implement this // TODO: update version number for breaking change - /// method as `rd + delta` (or `rd.saturating_add(delta)` for fixed-point types) - /// to maintain backward compatible behaviour. - fn accumulate(rd: A, delta: A) -> A; + /// The default implementation uses regular addition (`rd + delta`), which works for + /// both integer and floating-point types. For fixed-point types where overflow is a + /// concern, override this with `rd.saturating_add(delta)`. + fn accumulate(rd: A, delta: A) -> A + where + A: std::ops::Add, + { + rd + delta + } } #[cfg(test)] mod tests { + use super::DistanceMetric; use crate::traits::Index; + use rstest::rstest; #[test] fn test_u16() { @@ -197,4 +197,45 @@ mod tests { (u32::MAX - u32::MAX.overflowing_shr(1).0).saturating_mul(bucket_size); assert_eq!(capacity_with_bucket_size, u32::MAX); } + + struct TestMetricU32; + struct TestMetricI64; + + impl DistanceMetric for TestMetricU32 { + fn dist(_a: &[u32; K], _b: &[u32; K]) -> u32 { + 0 + } + fn dist1(a: u32, _b: u32) -> u32 { + a + } + } + + impl DistanceMetric for TestMetricI64 { + fn dist(_a: &[i64; K], _b: &[i64; K]) -> i64 { + 0 + } + fn dist1(a: i64, _b: i64) -> i64 { + a + } + } + + #[rstest] + #[case(5u32, 3u32, 8u32)] + #[case(10u32, 20u32, 30u32)] + fn test_default_accumulate_u32(#[case] rd: u32, #[case] delta: u32, #[case] expected: u32) { + assert_eq!( + >::accumulate(rd, delta), + expected + ); + } + + #[rstest] + #[case(10i64, 20i64, 30i64)] + #[case(100i64, 200i64, 300i64)] + fn test_default_accumulate_i64(#[case] rd: i64, #[case] delta: i64, #[case] expected: i64) { + assert_eq!( + >::accumulate(rd, delta), + expected + ); + } } From 2b4dc673a611057ba2bf9fcf42a6cbcb38bf7a92 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Carlson=20B=C3=BCth?= Date: Fri, 13 Feb 2026 01:04:21 +0100 Subject: [PATCH 14/17] chore: saturating add for fixed metrics --- src/fixed/distance.rs | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/src/fixed/distance.rs b/src/fixed/distance.rs index b6275a3e..c2a6bb9d 100644 --- a/src/fixed/distance.rs +++ b/src/fixed/distance.rs @@ -55,6 +55,11 @@ impl DistanceMetric for Manhattan { b - a } } + + #[inline] + fn accumulate(rd: A, delta: A) -> A { + rd.saturating_add(delta) + } } /// Returns the Chebyshev distance (L-infinity norm) between two points. @@ -158,6 +163,11 @@ impl DistanceMetric for SquaredEuclidean { let diff: A = a.dist(b); diff * diff } + + #[inline] + fn accumulate(rd: A, delta: A) -> A { + rd.saturating_add(delta) + } } #[cfg(test)] From 0de47af473f557e8ee7f3ce29d3a47fba992a268 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Carlson=20B=C3=BCth?= Date: Thu, 26 Feb 2026 15:30:21 +0100 Subject: [PATCH 15/17] chore: remove outdated docstring test information --- src/float/distance.rs | 10 ---------- 1 file changed, 10 deletions(-) diff --git a/src/float/distance.rs b/src/float/distance.rs index 37beb03d..6818cc10 100644 --- a/src/float/distance.rs +++ b/src/float/distance.rs @@ -688,16 +688,6 @@ mod tests { /// /// Test matrix covering all combinations of mutable/immutable trees, /// data scenarios (with/out ties), dimensions, and neighbor query counts. - /// - /// Currently passing tests: - /// - All MutableKdTree tests pass - /// - ImmutableKdTree with NoTies: - /// - Pass for when just querying the root n=1 or dim=1 - /// - ImmutableKdTree with Ties: Several pass (one edge case failure for n=6, dim=2) - /// - /// Currently failing tests (16 of 96): - /// - ImmutableKdTree + NoTies: fails for dim>=2 AND n>=2 (15 failures) - /// - ImmutableKdTree + Ties: 1 failure (n=6, dim=2) #[rstest] fn test_nearest_n_chebyshev( #[values(TreeType::Mutable, TreeType::Immutable)] tree_type: TreeType, From f33b6e9149d1766f7925e61cbc17d8dee1aa1dac Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Carlson=20B=C3=BCth?= Date: Fri, 27 Feb 2026 09:40:44 +0100 Subject: [PATCH 16/17] chore: replace `FxdU16::lit("2")` with `TWO` --- src/fixed/distance.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/fixed/distance.rs b/src/fixed/distance.rs index c2a6bb9d..f992a132 100644 --- a/src/fixed/distance.rs +++ b/src/fixed/distance.rs @@ -230,9 +230,9 @@ mod tests { #[case([ZERO, ZERO], [ZERO, ZERO], ZERO)] #[case([ZERO, ZERO], [ONE, ZERO], ONE)] #[case([ZERO, ZERO], [ZERO, ONE], ONE)] - #[case([ZERO, ZERO], [ONE, ONE], FxdU16::lit("2"))] + #[case([ZERO, ZERO], [ONE, ONE], TWO)] #[case([TWO, TWO], [ZERO, ZERO], FxdU16::lit("8"))] - #[case([ONE, TWO], [TWO, ONE], FxdU16::lit("2"))] + #[case([ONE, TWO], [TWO, ONE], TWO)] fn test_squared_euclidean_distance_2d( #[case] a: [FxdU16; 2], #[case] b: [FxdU16; 2], From c4e8db94c83bbf6be23abfa99a952985205796d9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Carlson=20B=C3=BCth?= Date: Fri, 27 Feb 2026 09:54:24 +0100 Subject: [PATCH 17/17] fix: use saturating arithmetic for fixed-point distances - `A::dist` - `saturating_mul` --- src/fixed/distance.rs | 32 ++++++-------------------------- 1 file changed, 6 insertions(+), 26 deletions(-) diff --git a/src/fixed/distance.rs b/src/fixed/distance.rs index f992a132..5fb3693d 100644 --- a/src/fixed/distance.rs +++ b/src/fixed/distance.rs @@ -37,23 +37,13 @@ impl DistanceMetric for Manhattan { fn dist(a: &[A; K], b: &[A; K]) -> A { a.iter() .zip(b.iter()) - .map(|(&a_val, &b_val)| { - if a_val > b_val { - a_val - b_val - } else { - b_val - a_val - } - }) + .map(|(&a_val, &b_val)| a_val.dist(b_val)) .fold(A::ZERO, |a, b| a.saturating_add(b)) } #[inline] fn dist1(a: A, b: A) -> A { - if a > b { - a - b - } else { - b - a - } + a.dist(b) } #[inline] @@ -91,24 +81,14 @@ impl DistanceMetric for Chebyshev { fn dist(a: &[A; K], b: &[A; K]) -> A { a.iter() .zip(b.iter()) - .map(|(&a_val, &b_val)| { - if a_val > b_val { - a_val - b_val - } else { - b_val - a_val - } - }) + .map(|(&a_val, &b_val)| a_val.dist(b_val)) .reduce(|a, b| if a > b { a } else { b }) .unwrap_or(A::ZERO) } #[inline] fn dist1(a: A, b: A) -> A { - if a > b { - a - b - } else { - b - a - } + a.dist(b) } #[inline] @@ -153,7 +133,7 @@ impl DistanceMetric for SquaredEuclidean { .zip(b.iter()) .map(|(&a_val, &b_val)| { let diff: A = a_val.dist(b_val); - diff * diff + diff.saturating_mul(diff) }) .fold(A::ZERO, |a, b| a.saturating_add(b)) } @@ -161,7 +141,7 @@ impl DistanceMetric for SquaredEuclidean { #[inline] fn dist1(a: A, b: A) -> A { let diff: A = a.dist(b); - diff * diff + diff.saturating_mul(diff) } #[inline]