diff --git a/src/common/generate_best_n_within.rs b/src/common/generate_best_n_within.rs index 8a116688..21c2e576 100644 --- a/src/common/generate_best_n_within.rs +++ b/src/common/generate_best_n_within.rs @@ -76,7 +76,7 @@ macro_rules! generate_best_n_within { rd, ); - rd = Axis::rd_update(rd, D::dist1(new_off, old_off)); + rd = D::accumulate(rd, D::dist1(new_off, old_off)); if rd <= radius { off[split_dim] = new_off; diff --git a/src/common/generate_nearest_n.rs b/src/common/generate_nearest_n.rs index 828aec15..f4efac07 100644 --- a/src/common/generate_nearest_n.rs +++ b/src/common/generate_nearest_n.rs @@ -65,7 +65,7 @@ macro_rules! generate_nearest_n { rd, ); - rd = Axis::rd_update(rd, D::dist1(new_off, old_off)); + rd = D::accumulate(rd, D::dist1(new_off, old_off)); if Self::dist_belongs_in_heap(rd, results) { off[split_dim] = new_off; @@ -112,7 +112,7 @@ macro_rules! generate_nearest_n { } } - #[inline] + #[inline(always)] fn dist_belongs_in_heap(dist: A, heap: &BinaryHeap>) -> bool { heap.is_empty() || dist < heap.peek().unwrap().distance || heap.len() < heap.capacity() } diff --git a/src/common/generate_nearest_n_within_unsorted.rs b/src/common/generate_nearest_n_within_unsorted.rs index c3654cd8..09d56c4e 100644 --- a/src/common/generate_nearest_n_within_unsorted.rs +++ b/src/common/generate_nearest_n_within_unsorted.rs @@ -86,7 +86,7 @@ macro_rules! generate_nearest_n_within_unsorted { rd, ); - rd = Axis::rd_update(rd, D::dist1(new_off, old_off)); + rd = D::accumulate(rd, D::dist1(new_off, old_off)); if rd <= radius { off[split_dim] = new_off; @@ -116,7 +116,7 @@ macro_rules! generate_nearest_n_within_unsorted { .for_each(|(idx, entry)| { let distance = D::dist(query, transform(entry)); - if distance < radius { + if distance <= radius { let item = unsafe { leaf_node.content_items.get_unchecked(idx) }; let item = *transform(item); diff --git a/src/common/generate_nearest_one.rs b/src/common/generate_nearest_one.rs index e4d5c2f8..c5838b20 100644 --- a/src/common/generate_nearest_one.rs +++ b/src/common/generate_nearest_one.rs @@ -67,7 +67,7 @@ macro_rules! generate_nearest_one { nearest = nearest_neighbour; } - rd = Axis::rd_update(rd, D::dist1(new_off, old_off)); + rd = D::accumulate(rd, D::dist1(new_off, old_off)); if rd <= nearest.distance { off[split_dim] = new_off; diff --git a/src/common/generate_within_unsorted.rs b/src/common/generate_within_unsorted.rs index bc9442cc..656ab033 100644 --- a/src/common/generate_within_unsorted.rs +++ b/src/common/generate_within_unsorted.rs @@ -68,7 +68,7 @@ macro_rules! generate_within_unsorted { rd, ); - rd = Axis::rd_update(rd, D::dist1(new_off, old_off)); + rd = D::accumulate(rd, D::dist1(new_off, old_off)); if rd <= radius { off[split_dim] = new_off; @@ -98,7 +98,7 @@ macro_rules! generate_within_unsorted { .for_each(|(idx, entry)| { let distance = D::dist(query, transform(entry)); - if distance < radius { + if distance <= radius { let item = unsafe { leaf_node.content_items.get_unchecked(idx) }; let item = *transform(item); diff --git a/src/common/generate_within_unsorted_iter.rs b/src/common/generate_within_unsorted_iter.rs index f5c94ea9..a44ad861 100644 --- a/src/common/generate_within_unsorted_iter.rs +++ b/src/common/generate_within_unsorted_iter.rs @@ -78,7 +78,7 @@ macro_rules! generate_within_unsorted_iter { rd, ); - rd = Axis::rd_update(rd, D::dist1(new_off, old_off)); + rd = D::accumulate(rd, D::dist1(new_off, old_off)); if rd <= radius { off[split_dim] = new_off; @@ -108,7 +108,7 @@ macro_rules! generate_within_unsorted_iter { .for_each(|(idx, entry)| { let distance = D::dist(query, transform(entry)); - if distance < radius { + if distance <= radius { let item = unsafe { leaf_node.content_items.get_unchecked(idx) }; let item = *transform(item); diff --git a/src/fixed/distance.rs b/src/fixed/distance.rs index 5e901432..5fb3693d 100644 --- a/src/fixed/distance.rs +++ b/src/fixed/distance.rs @@ -37,22 +37,66 @@ impl DistanceMetric for Manhattan { fn dist(a: &[A; K], b: &[A; K]) -> A { a.iter() .zip(b.iter()) - .map(|(&a_val, &b_val)| { - if a_val > b_val { - a_val - b_val - } else { - b_val - a_val - } - }) + .map(|(&a_val, &b_val)| a_val.dist(b_val)) .fold(A::ZERO, |a, b| a.saturating_add(b)) } #[inline] fn dist1(a: A, b: A) -> A { - if a > b { - a - b + a.dist(b) + } + + #[inline] + fn accumulate(rd: A, delta: A) -> A { + rd.saturating_add(delta) + } +} + +/// Returns the Chebyshev distance (L-infinity norm) between two points. +/// +/// This is the maximum of the absolute differences between coordinates of points. +/// +/// # Examples +/// +/// ```rust +/// use fixed::types::extra::U0; +/// use fixed::FixedU16; +/// use kiddo::traits::DistanceMetric; +/// use kiddo::fixed::distance::Chebyshev; +/// type Fxd = FixedU16; +/// +/// let ZERO = Fxd::from_num(0); +/// let ONE = Fxd::from_num(1); +/// let TWO = Fxd::from_num(2); +/// +/// assert_eq!(ZERO, Chebyshev::dist(&[ZERO, ZERO], &[ZERO, ZERO])); +/// assert_eq!(ONE, Chebyshev::dist(&[ZERO, ZERO], &[ONE, ZERO])); +/// assert_eq!(ONE, Chebyshev::dist(&[ZERO, ZERO], &[ONE, ONE])); +/// assert_eq!(TWO, Chebyshev::dist(&[ZERO, ZERO], &[TWO, ONE])); +/// ``` +pub struct Chebyshev {} + +impl DistanceMetric for Chebyshev { + #[inline] + fn dist(a: &[A; K], b: &[A; K]) -> A { + a.iter() + .zip(b.iter()) + .map(|(&a_val, &b_val)| a_val.dist(b_val)) + .reduce(|a, b| if a > b { a } else { b }) + .unwrap_or(A::ZERO) + } + + #[inline] + fn dist1(a: A, b: A) -> A { + a.dist(b) + } + + #[inline] + fn accumulate(rd: A, delta: A) -> A { + if rd > delta { + rd } else { - b - a + delta } } } @@ -89,7 +133,7 @@ impl DistanceMetric for SquaredEuclidean { .zip(b.iter()) .map(|(&a_val, &b_val)| { let diff: A = a_val.dist(b_val); - diff * diff + diff.saturating_mul(diff) }) .fold(A::ZERO, |a, b| a.saturating_add(b)) } @@ -97,6 +141,320 @@ impl DistanceMetric for SquaredEuclidean { #[inline] fn dist1(a: A, b: A) -> A { let diff: A = a.dist(b); - diff * diff + diff.saturating_mul(diff) + } + + #[inline] + fn accumulate(rd: A, delta: A) -> A { + rd.saturating_add(delta) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use fixed::types::extra::U0; + use rstest::rstest; + + type FxdU16 = fixed::FixedU16; + + const ZERO: FxdU16 = FxdU16::ZERO; + const ONE: FxdU16 = FxdU16::lit("1"); + const TWO: FxdU16 = FxdU16::lit("2"); + const THREE: FxdU16 = FxdU16::lit("3"); + const FOUR: FxdU16 = FxdU16::lit("4"); + const FIVE: FxdU16 = FxdU16::lit("5"); + + #[rstest] + #[case([ZERO, ZERO], [ZERO, ZERO], ZERO)] + #[case([ZERO, ZERO], [ONE, ZERO], ONE)] + #[case([ZERO, ZERO], [ZERO, ONE], ONE)] + #[case([ZERO, ZERO], [ONE, ONE], ONE)] + #[case([ZERO, ZERO], [TWO, ONE], TWO)] + #[case([ZERO, ZERO], [ONE, TWO], TWO)] + fn test_chebyshev_distance_2d( + #[case] a: [FxdU16; 2], + #[case] b: [FxdU16; 2], + #[case] expected: FxdU16, + ) { + assert_eq!(Chebyshev::dist(&a, &b), expected); + } + + #[rstest] + #[case([ZERO, ZERO, ZERO], [ZERO, ZERO, ZERO], ZERO)] + #[case([ZERO, ZERO, ZERO], [ONE, TWO, THREE], THREE)] + #[case([FIVE, FIVE, FIVE], [ONE, TWO, THREE], FOUR)] + fn test_chebyshev_distance_3d( + #[case] a: [FxdU16; 3], + #[case] b: [FxdU16; 3], + #[case] expected: FxdU16, + ) { + assert_eq!(Chebyshev::dist(&a, &b), expected); + } + + #[rstest] + #[case([ZERO, ZERO], [ZERO, ZERO], ZERO)] + #[case([ZERO, ZERO], [ONE, ZERO], ONE)] + #[case([ZERO, ZERO], [ZERO, ONE], ONE)] + #[case([ZERO, ZERO], [ONE, ONE], TWO)] + #[case([TWO, THREE], [ONE, ONE], THREE)] + fn test_manhattan_distance_2d( + #[case] a: [FxdU16; 2], + #[case] b: [FxdU16; 2], + #[case] expected: FxdU16, + ) { + assert_eq!(Manhattan::dist(&a, &b), expected); + } + + #[rstest] + #[case([ZERO, ZERO], [ZERO, ZERO], ZERO)] + #[case([ZERO, ZERO], [ONE, ZERO], ONE)] + #[case([ZERO, ZERO], [ZERO, ONE], ONE)] + #[case([ZERO, ZERO], [ONE, ONE], TWO)] + #[case([TWO, TWO], [ZERO, ZERO], FxdU16::lit("8"))] + #[case([ONE, TWO], [TWO, ONE], TWO)] + fn test_squared_euclidean_distance_2d( + #[case] a: [FxdU16; 2], + #[case] b: [FxdU16; 2], + #[case] expected: FxdU16, + ) { + assert_eq!(SquaredEuclidean::dist(&a, &b), expected); + } + + #[rstest] + #[case([ZERO, ZERO, ZERO], [ZERO, ZERO, ZERO], ZERO)] + #[case([ZERO, ZERO, ZERO], [ONE, ZERO, ZERO], ONE)] + #[case([ONE, ONE, ONE], [TWO, TWO, TWO], THREE)] + fn test_squared_euclidean_distance_3d( + #[case] a: [FxdU16; 3], + #[case] b: [FxdU16; 3], + #[case] expected: FxdU16, + ) { + assert_eq!(SquaredEuclidean::dist(&a, &b), expected); + } + + #[rstest] + #[case::zero(ZERO, ZERO, ZERO)] + #[case::pos(ONE, ZERO, ONE)] + #[case::neg(ZERO, ONE, ONE)] + #[case::diff(THREE, ONE, TWO)] + fn test_manhattan_dist1(#[case] a: FxdU16, #[case] b: FxdU16, #[case] expected: FxdU16) { + assert_eq!( + >::dist1(a, b), + expected + ); + } + + #[rstest] + #[case::zero(ZERO, ZERO, ZERO)] + #[case::pos(ONE, ZERO, ONE)] + #[case::neg(ZERO, ONE, ONE)] + #[case::a_larger(TWO, ONE, ONE)] + #[case::b_larger(ONE, TWO, ONE)] + fn test_chebyshev_dist1(#[case] a: FxdU16, #[case] b: FxdU16, #[case] expected: FxdU16) { + assert_eq!( + >::dist1(a, b), + expected + ); + } + + #[rstest] + #[case::zero(ZERO, ZERO, ZERO)] + #[case::pos(ONE, ZERO, ONE)] + #[case::neg(ZERO, ONE, ONE)] + #[case::a_larger(TWO, ONE, ONE)] + #[case::b_larger(ONE, TWO, ONE)] + fn test_squared_euclidean_dist1( + #[case] a: FxdU16, + #[case] b: FxdU16, + #[case] expected: FxdU16, + ) { + assert_eq!( + >::dist1(a, b), + expected + ); + } + + #[rstest] + #[case::zero_one(ZERO, ONE, ONE)] + #[case::one_zero(ONE, ZERO, ONE)] + #[case::first_larger(ONE, TWO, TWO)] + #[case::second_larger(TWO, ONE, TWO)] + fn test_chebyshev_accumulate( + #[case] rd: FxdU16, + #[case] delta: FxdU16, + #[case] expected: FxdU16, + ) { + assert_eq!( + >::accumulate(rd, delta), + expected + ); + } +} + +#[cfg(test)] +mod integration_tests { + use super::*; + use crate::fixed::kdtree::KdTree; + use fixed::types::extra::U0; + use fixed::FixedU16; + use rstest::rstest; + + type FxdU16 = FixedU16; + + const ZERO: FxdU16 = FxdU16::ZERO; + const ONE: FxdU16 = FxdU16::lit("1"); + const TWO: FxdU16 = FxdU16::lit("2"); + const THREE: FxdU16 = FxdU16::lit("3"); + const FOUR: FxdU16 = FxdU16::lit("4"); + const FIVE: FxdU16 = FxdU16::lit("5"); + + enum DataScenario { + NoTies, + Ties, + } + + impl DataScenario { + fn get(&self, dim: usize) -> Vec> { + match (self, dim) { + (DataScenario::NoTies, 1) => { + vec![vec![ONE], vec![TWO], vec![THREE], vec![FOUR], vec![FIVE]] + } + (DataScenario::NoTies, 2) => vec![ + vec![ZERO, ZERO], + vec![ONE, ZERO], + vec![TWO, ZERO], + vec![THREE, ZERO], + vec![FOUR, ZERO], + vec![FIVE, ZERO], + ], + (DataScenario::NoTies, 3) => vec![ + vec![ZERO, ZERO, ZERO], + vec![ONE, ZERO, ZERO], + vec![TWO, ZERO, ZERO], + vec![THREE, ZERO, ZERO], + vec![FOUR, ZERO, ZERO], + vec![FIVE, ZERO, ZERO], + ], + (DataScenario::Ties, 1) => vec![ + vec![ZERO], + vec![ONE], + vec![ONE], + vec![TWO], + vec![THREE], + vec![THREE], + ], + (DataScenario::Ties, 2) => vec![ + vec![ZERO, ZERO], + vec![ONE, ZERO], + vec![ZERO, ONE], + vec![TWO, ZERO], + vec![ZERO, TWO], + vec![TWO, TWO], + ], + (DataScenario::Ties, 3) => vec![ + vec![ZERO, ZERO, ZERO], + vec![ONE, ZERO, ZERO], + vec![ZERO, ONE, ZERO], + vec![ZERO, ZERO, ONE], + vec![TWO, ZERO, ZERO], + vec![ZERO, TWO, ZERO], + ], + _ => panic!("Unsupported dimension"), + } + } + } + + fn run_test_helper>(dim: usize, scenario: DataScenario, n: usize) { + let data = scenario.get(dim); + let query_point = &data[0]; + + let mut points: Vec<[FxdU16; 6]> = Vec::with_capacity(data.len()); + for row in &data { + let mut p = [ZERO; 6]; + for (i, &val) in row.iter().enumerate() { + if i < 6 { + p[i] = val; + } + } + points.push(p); + } + + let mut query_arr = [ZERO; 6]; + for (i, &val) in query_point.iter().enumerate() { + if i < 6 { + query_arr[i] = val; + } + } + + let expected: Vec<(usize, FxdU16)> = points + .iter() + .enumerate() + .map(|(i, &point)| { + let dist = D::dist(&query_arr, &point); + (i, dist) + }) + .collect(); + + let expected_distances: Vec = expected.iter().map(|(_, d)| *d).collect(); + + let mut tree: KdTree = KdTree::new(); + for (i, point) in points.iter().enumerate() { + tree.add(point, i as u32); + } + + let results = tree.nearest_n::(&query_arr, n); + + assert_eq!(results[0].item, 0, "First result should be the query point"); + assert_eq!( + results[0].distance, ZERO, + "First result distance should be 0.0" + ); + + for (i, result) in results.iter().enumerate() { + assert_eq!( + result.distance, expected_distances[i], + "Distance at index {} should be {}, but was {}", + i, expected_distances[i], result.distance + ); + } + + if matches!(scenario, DataScenario::NoTies) { + for (i, result) in results.iter().enumerate() { + let expected_id = expected[i].0; + assert_eq!( + result.item, expected_id as u32, + "Result {}: item ID mismatch. Expected {}, got {}", + i, expected_id, result.item + ); + } + } + } + + #[rstest] + fn test_nearest_n_chebyshev( + #[values(DataScenario::NoTies, DataScenario::Ties)] scenario: DataScenario, + #[values(1, 2, 3, 4, 5, 6)] n: usize, + #[values(1, 2, 3)] dim: usize, + ) { + run_test_helper::(dim, scenario, n); + } + + #[rstest] + fn test_nearest_n_squared_euclidean( + #[values(DataScenario::NoTies, DataScenario::Ties)] scenario: DataScenario, + #[values(1, 2, 3, 4, 5, 6)] n: usize, + #[values(1, 2, 3)] dim: usize, + ) { + run_test_helper::(dim, scenario, n); + } + + #[rstest] + fn test_nearest_n_manhattan( + #[values(DataScenario::NoTies, DataScenario::Ties)] scenario: DataScenario, + #[values(1, 2, 3, 4, 5, 6)] n: usize, + #[values(1, 2, 3)] dim: usize, + ) { + run_test_helper::(dim, scenario, n); } } diff --git a/src/fixed/query/within.rs b/src/fixed/query/within.rs index 01848e88..094211af 100644 --- a/src/fixed/query/within.rs +++ b/src/fixed/query/within.rs @@ -163,7 +163,7 @@ mod tests { for &(p, item) in content { let dist = Manhattan::dist(query_point, &p); - if dist < radius { + if dist <= radius { matching_items.push((dist, item)); } } diff --git a/src/fixed/query/within_unsorted.rs b/src/fixed/query/within_unsorted.rs index 3075cc15..b5a1dcb8 100644 --- a/src/fixed/query/within_unsorted.rs +++ b/src/fixed/query/within_unsorted.rs @@ -167,7 +167,7 @@ mod tests { for &(p, item) in content { let dist = Manhattan::dist(query_point, &p); - if dist < radius { + if dist <= radius { matching_items.push((dist, item)); } } diff --git a/src/fixed/query/within_unsorted_iter.rs b/src/fixed/query/within_unsorted_iter.rs index 72cf5d4b..de22be63 100644 --- a/src/fixed/query/within_unsorted_iter.rs +++ b/src/fixed/query/within_unsorted_iter.rs @@ -178,7 +178,7 @@ mod tests { for &(p, item) in content { let dist = Manhattan::dist(query_point, &p); - if dist < radius { + if dist <= radius { matching_items.push((dist, item)); } } diff --git a/src/float/distance.rs b/src/float/distance.rs index 3784f946..6818cc10 100644 --- a/src/float/distance.rs +++ b/src/float/distance.rs @@ -40,6 +40,45 @@ impl DistanceMetric for Manhattan { } } +/// Returns the Chebyshev / L-infinity distance between two points. +/// +/// Chebyshev distance is the maximum absolute difference along any axis. +/// Also known as chessboard distance or L-infinity norm. +/// +/// re-exported as `kiddo::Chebyshev` for convenience +/// +/// # Examples +/// +/// ```rust +/// use kiddo::traits::DistanceMetric; +/// use kiddo::Chebyshev; +/// +/// assert_eq!(0f32, Chebyshev::dist(&[0f32, 0f32], &[0f32, 0f32])); +/// assert_eq!(1f32, Chebyshev::dist(&[0f32, 0f32], &[1f32, 0f32])); +/// assert_eq!(1f32, Chebyshev::dist(&[0f32, 0f32], &[1f32, 1f32])); +/// ``` +pub struct Chebyshev {} + +impl DistanceMetric for Chebyshev { + #[inline] + fn dist(a: &[A; K], b: &[A; K]) -> A { + a.iter() + .zip(b.iter()) + .map(|(&a_val, &b_val)| (a_val - b_val).abs()) + .fold(A::zero(), |acc, val| acc.max(val)) + } + + #[inline] + fn dist1(a: A, b: A) -> A { + (a - b).abs() + } + + #[inline] + fn accumulate(rd: A, delta: A) -> A { + rd.max(delta) + } +} + /// Returns the squared euclidean distance between two points. /// /// Faster than Euclidean distance due to not needing a square root, but still @@ -73,3 +112,997 @@ impl DistanceMetric for SquaredEuclidean { (a - b) * (a - b) } } + +#[cfg(test)] +mod tests { + use super::*; + use rstest::rstest; + + mod common_metric_tests { + use super::*; + + #[rstest] + #[case::zeros_1d([0.0f32], [0.0f32])] + #[case::normal_1d([1.0f32], [2.0f32])] + #[case::neg_1d([-1.0f32], [1.0f32])] + #[case::zeros_2d([0.0f32, 0.0f32], [0.0f32, 0.0f32])] + #[case::normal_2d([1.0f32, 2.0f32], [3.0f32, 4.0f32])] + #[case::large_2d([1e30f32, 1e30f32], [-1e30f32, -1e30f32])] + #[case::zeros_3d([0.0f32, 0.0f32, 0.0f32], [0.0f32, 0.0f32, 0.0f32])] + #[case::normal_3d([1.0f32, 2.0f32, 3.0f32], [4.0f32, 5.0f32, 6.0f32])] + #[case::zeros_4d([0.0f32; 4], [0.0f32; 4])] + #[case::normal_4d([1.0f32; 4], [2.0f32; 4])] + #[case::zeros_5d([0.0f32; 5], [0.0f32; 5])] + #[case::normal_5d([1.0f32; 5], [2.0f32; 5])] + fn test_metric_non_negativity>( + #[values(Manhattan {}, SquaredEuclidean {}, Chebyshev {})] _metric: D, + #[case] a: [A; K], + #[case] b: [A; K], + ) { + let distance = D::dist(&a, &b); + assert!(distance >= A::zero()); + } + + #[rstest] + #[case::zeros_1d([0.0f32])] + #[case::normal_1d([1.0f32])] + #[case::zeros_2d([0.0f32, 0.0f32])] + #[case::normal_2d([1.0f32, 2.0f32])] + #[case::zeros_3d([0.0f32, 0.0f32, 0.0f32])] + #[case::normal_3d([1.0f32, 2.0f32, 3.0f32])] + #[case::zeros_4d([0.0f32; 4])] + #[case::zeros_5d([0.0f32; 5])] + fn test_metric_identity>( + #[values(Manhattan {}, SquaredEuclidean {}, Chebyshev {})] _metric: D, + #[case] a: [A; K], + ) { + assert_eq!(D::dist(&a, &a), A::zero()); + } + + #[rstest] + #[case::normal_1d([1.0f64], [2.0f64])] + #[case::neg_1d([-1.0f64], [1.0f64])] + #[case::normal_2d([1.0f64, 2.0f64], [3.0f64, 4.0f64])] + #[case::normal_3d([1.0f64, 2.0f64, 3.0f64], [4.0f64, 5.0f64, 6.0f64])] + #[case::normal_4d([1.0f64; 4], [2.0f64; 4])] + #[case::normal_5d([1.0f64; 5], [2.0f64; 5])] + fn test_metric_symmetry>( + #[values(Manhattan {}, SquaredEuclidean {}, Chebyshev {})] _metric: D, + #[case] a: [A; K], + #[case] b: [A; K], + ) { + assert_eq!(D::dist(&a, &b), D::dist(&b, &a)); + } + } + + mod manhattan_tests { + use super::*; + + #[rstest] + #[case([0.0f32, 0.0f32], [0.0f32, 0.0f32], 0.0f32)] // identical points + #[case([0.0f32, 0.0f32], [1.0f32, 0.0f32], 1.0f32)] // single axis difference + #[case([0.0f32, 0.0f32], [0.0f32, 1.0f32], 1.0f32)] // single axis difference (other axis) + #[case([0.0f32, 0.0f32], [1.0f32, 1.0f32], 2.0f32)] // diagonal + #[case([-1.0f32, -1.0f32], [1.0f32, 1.0f32], 4.0f32)] // negative to positive + #[case([1.5f32, 2.5f32], [3.5f32, 4.5f32], 4.0f32)] // fractional values + fn test_manhattan_distance_2d( + #[case] a: [f32; 2], + #[case] b: [f32; 2], + #[case] expected: f32, + ) { + assert_eq!(Manhattan::dist(&a, &b), expected); + } + + #[rstest] + #[case([0.0f64, 0.0f64, 0.0f64], [0.0f64, 0.0f64, 0.0f64], 0.0f64)] // identical points 3D + #[case([0.0f64, 0.0f64, 0.0f64], [1.0f64, 2.0f64, 3.0f64], 6.0f64)] // 3D diagonal + #[case([1.0f64, 2.0f64, 3.0f64], [4.0f64, 5.0f64, 6.0f64], 9.0f64)] // 3D offset + fn test_manhattan_distance_3d( + #[case] a: [f64; 3], + #[case] b: [f64; 3], + #[case] expected: f64, + ) { + assert_eq!(Manhattan::dist(&a, &b), expected); + } + + #[rstest] + #[case([0.0f32], [0.0f32], 0.0f32)] // 1D identical + #[case([0.0f32], [5.0f32], 5.0f32)] // 1D positive + #[case([5.0f32], [0.0f32], 5.0f32)] // 1D negative (reversed) + #[case([-3.0f32], [7.0f32], 10.0f32)] // 1D negative to positive + fn test_manhattan_distance_1d( + #[case] a: [f32; 1], + #[case] b: [f32; 1], + #[case] expected: f32, + ) { + assert_eq!(Manhattan::dist(&a, &b), expected); + } + + #[test] + fn test_manhattan_distance_4d() { + let a = [1.0f32, 2.0f32, 3.0f32, 4.0f32]; + let b = [5.0f32, 6.0f32, 7.0f32, 8.0f32]; + let expected = 16.0f32; // |5-1| + |6-2| + |7-3| + |8-4| = 4+4+4+4 = 16 + assert_eq!(Manhattan::dist(&a, &b), expected); + } + + #[test] + fn test_manhattan_distance_5d() { + let a = [0.0f64, 1.0f64, 2.0f64, 3.0f64, 4.0f64]; + let b = [5.0f64, 6.0f64, 7.0f64, 8.0f64, 9.0f64]; + let expected = 25.0f64; // |5-0| + |6-1| + |7-2| + |8-3| + |9-4| = 5+5+5+5+5 = 25 + assert_eq!(Manhattan::dist(&a, &b), expected); + } + + #[test] + fn test_manhattan_dist1() { + assert_eq!( + >::dist1(0.0f32, 0.0f32), + 0.0f32 + ); // zero difference + assert_eq!( + >::dist1(1.0f32, 0.0f32), + 1.0f32 + ); // positive difference + assert_eq!( + >::dist1(0.0f32, 1.0f32), + 1.0f32 + ); // negative difference (reversed) + assert_eq!( + >::dist1(-2.5f32, 3.5f32), + 6.0f32 + ); // fractional negative to positive + assert_eq!( + >::dist1(1000.0f32, -1000.0f32), + 2000.0f32 + ); // large values + } + } + + mod squared_euclidean_tests { + use super::*; + + #[rstest] + #[case([0.0f32, 0.0f32], [0.0f32, 0.0f32], 0.0f32)] // identical points + #[case([0.0f32, 0.0f32], [1.0f32, 0.0f32], 1.0f32)] // single axis difference + #[case([0.0f32, 0.0f32], [0.0f32, 1.0f32], 1.0f32)] // single axis difference (other axis) + #[case([0.0f32, 0.0f32], [1.0f32, 1.0f32], 2.0f32)] // diagonal (1^2 + 1^2) + #[case([-1.0f32, -1.0f32], [1.0f32, 1.0f32], 8.0f32)] // negative to positive (2^2 + 2^2) + #[case([1.5f32, 2.5f32], [3.5f32, 4.5f32], 8.0f32)] // fractional values (2^2 + 2^2) + #[case([0.0f32, 0.0f32], [3.0f32, 4.0f32], 25.0f32)] // 3-4-5 triangle + fn test_squared_euclidean_distance_2d( + #[case] a: [f32; 2], + #[case] b: [f32; 2], + #[case] expected: f32, + ) { + assert_eq!(SquaredEuclidean::dist(&a, &b), expected); + } + + #[rstest] + #[case([0.0f64, 0.0f64, 0.0f64], [0.0f64, 0.0f64, 0.0f64], 0.0f64)] // identical points 3D + #[case([0.0f64, 0.0f64, 0.0f64], [1.0f64, 2.0f64, 2.0f64], 9.0f64)] // 3D (1^2 + 2^2 + 2^2) + #[case([1.0f64, 2.0f64, 3.0f64], [4.0f64, 5.0f64, 6.0f64], 27.0f64)] // 3D offset (3^2 + 3^2 + 3^2) + fn test_squared_euclidean_distance_3d( + #[case] a: [f64; 3], + #[case] b: [f64; 3], + #[case] expected: f64, + ) { + assert_eq!(SquaredEuclidean::dist(&a, &b), expected); + } + + #[rstest] + #[case([0.0f32], [0.0f32], 0.0f32)] // 1D identical + #[case([0.0f32], [5.0f32], 25.0f32)] // 1D positive (5^2) + #[case([5.0f32], [0.0f32], 25.0f32)] // 1D negative (reversed) + #[case([-3.0f32], [7.0f32], 100.0f32)] // 1D negative to positive (10^2) + fn test_squared_euclidean_distance_1d( + #[case] a: [f32; 1], + #[case] b: [f32; 1], + #[case] expected: f32, + ) { + assert_eq!(SquaredEuclidean::dist(&a, &b), expected); + } + + #[test] + fn test_squared_euclidean_dist1() { + assert_eq!( + >::dist1(0.0f32, 0.0f32), + 0.0f32 + ); // zero difference + assert_eq!( + >::dist1(1.0f32, 0.0f32), + 1.0f32 + ); // positive difference + assert_eq!( + >::dist1(0.0f32, 1.0f32), + 1.0f32 + ); // negative difference (reversed) + assert_eq!( + >::dist1(-2.5f32, 3.5f32), + 36.0f32 + ); // fractional negative to positive (6^2) + assert_eq!( + >::dist1(10.0f32, -10.0f32), + 400.0f32 + ); // large values (20^2) + } + + #[test] + fn test_squared_euclidean_triangle_inequality_property() { + // Test that squared Euclidean distance preserves ordering + let a = [0.0f32, 0.0f32]; + let b = [1.0f32, 0.0f32]; + let c = [1.0f32, 1.0f32]; + + let dist_ab = SquaredEuclidean::dist(&a, &b); + let dist_ac = SquaredEuclidean::dist(&a, &c); + let dist_bc = SquaredEuclidean::dist(&b, &c); + + // For these points: dist(a,b) = 1, dist(b,c) = 1, dist(a,c) = 2 + assert_eq!(dist_ab, 1.0f32); + assert_eq!(dist_bc, 1.0f32); + assert_eq!(dist_ac, 2.0f32); + } + } + + mod chebyshev_tests { + use super::*; + + #[rstest] + #[case([0.0f32, 0.0f32], [0.0f32, 0.0f32], 0.0f32)] // identical points + #[case([0.0f32, 0.0f32], [1.0f32, 0.0f32], 1.0f32)] // single axis difference + #[case([0.0f32, 0.0f32], [0.0f32, 1.0f32], 1.0f32)] // single axis difference (other axis) + #[case([0.0f32, 0.0f32], [1.0f32, 1.0f32], 1.0f32)] // diagonal + #[case([-1.0f32, -1.0f32], [1.0f32, 1.0f32], 2.0f32)] // negative to positive + #[case([1.5f32, 2.5f32], [3.5f32, 4.5f32], 2.0f32)] // fractional values + #[case([0.0f32, 0.0f32], [2.0f32, 1.0f32], 2.0f32)] // max on first axis + #[case([0.0f32, 0.0f32], [1.0f32, 2.0f32], 2.0f32)] // max on second axis + fn test_chebyshev_distance_2d( + #[case] a: [f32; 2], + #[case] b: [f32; 2], + #[case] expected: f32, + ) { + assert_eq!(Chebyshev::dist(&a, &b), expected); + } + + #[rstest] + #[case([0.0f64, 0.0f64, 0.0f64], [0.0f64, 0.0f64, 0.0f64], 0.0f64)] // identical points 3D + #[case([0.0f64, 0.0f64, 0.0f64], [1.0f64, 2.0f64, 3.0f64], 3.0f64)] // 3D diagonal (max is 3) + #[case([1.0f64, 2.0f64, 3.0f64], [4.0f64, 5.0f64, 6.0f64], 3.0f64)] // 3D offset (max is 3) + fn test_chebyshev_distance_3d( + #[case] a: [f64; 3], + #[case] b: [f64; 3], + #[case] expected: f64, + ) { + assert_eq!(Chebyshev::dist(&a, &b), expected); + } + + #[rstest] + #[case([0.0f32], [0.0f32], 0.0f32)] // 1D identical + #[case([0.0f32], [5.0f32], 5.0f32)] // 1D positive + #[case([5.0f32], [0.0f32], 5.0f32)] // 1D negative (reversed) + #[case([-3.0f32], [7.0f32], 10.0f32)] // 1D negative to positive + fn test_chebyshev_distance_1d( + #[case] a: [f32; 1], + #[case] b: [f32; 1], + #[case] expected: f32, + ) { + assert_eq!(Chebyshev::dist(&a, &b), expected); + } + + #[test] + fn test_chebyshev_distance_4d() { + let a = [1.0f32, 2.0f32, 3.0f32, 4.0f32]; + let b = [5.0f32, 6.0f32, 7.0f32, 8.0f32]; + let expected = 4.0f32; // max(|5-1|, |6-2|, |7-3|, |8-4|) = max(4, 4, 4, 4) = 4 + assert_eq!(Chebyshev::dist(&a, &b), expected); + } + + #[test] + fn test_chebyshev_distance_5d() { + let a = [0.0f64, 1.0f64, 2.0f64, 3.0f64, 4.0f64]; + let b = [5.0f64, 6.0f64, 7.0f64, 8.0f64, 9.0f64]; + let expected = 5.0f64; // max(|5-0|, |6-1|, |7-2|, |8-3|, |9-4|) = max(5, 5, 5, 5, 5) = 5 + assert_eq!(Chebyshev::dist(&a, &b), expected); + } + + #[rstest] + #[case(0.0f32, 0.0f32, 0.0f32)] // zero difference + #[case(1.0f32, 0.0f32, 1.0f32)] // positive difference + #[case(0.0f32, 1.0f32, 1.0f32)] // negative difference (reversed) + #[case(-2.5f32, 3.5f32, 6.0f32)] // fractional negative to positive + #[case(1000.0f32, -1000.0f32, 2000.0f32)] // large values + fn test_chebyshev_dist1(#[case] a: f32, #[case] b: f32, #[case] expected: f32) { + assert_eq!(>::dist1(a, b), expected); + } + + #[test] + fn test_chebyshev_max_property() { + // Test that Chebyshev correctly finds the maximum difference + let a = [0.0, 0.0]; + let b = [3.0, 1.0]; + + let result = Chebyshev::dist(&a, &b); + + // max(|0-3|, |0-1|) = max(3, 1) = 3 + assert_eq!(result, 3.0); + + // Verify it's not Manhattan (which would be 4) or Euclidean (sqrt(10)) + assert_ne!(result, 4.0); + assert_ne!(result, (10.0_f64).sqrt()); + } + } + + #[cfg(feature = "f16")] + mod f16_tests { + use super::*; + use half::f16; + + #[test] + fn test_manhattan_f16() { + let a = [f16::from_f32(0.0), f16::from_f32(0.0)]; + let b = [f16::from_f32(1.0), f16::from_f32(1.0)]; + + let result = Manhattan::dist(&a, &b); + let expected = f16::from_f32(2.0); + + assert_eq!(result, expected); + } + + #[test] + fn test_squared_euclidean_f16() { + let a = [f16::from_f32(0.0), f16::from_f32(0.0)]; + let b = [f16::from_f32(1.0), f16::from_f32(1.0)]; + + let result = SquaredEuclidean::dist(&a, &b); + let expected = f16::from_f32(2.0); + + assert_eq!(result, expected); + } + } + + mod integration_tests { + use super::*; + use crate::KdTree; + use rand::prelude::*; + use rand_distr::Normal; + use rstest::rstest; + + #[derive(Debug, Clone, Copy)] + enum DataScenario { + NoTies, + Ties, + Gaussian, + } + + #[derive(Debug, Clone, Copy)] + enum TreeType { + Mutable, + Immutable, + } + + impl DataScenario { + /// Get data scenario + /// + /// Predefined data has input dimension (`dim`) and either + /// with `DataScenario::NoTies` or `DataScenario::Ties`. + /// + /// # Parameters + /// - `dim`: The dimensionality of the data to retrieve. + /// Must be a value between 1 and 4 (inclusive). + /// + /// # Returns + /// - `Vec>`: A 2D vector where each inner vector represents a data point. + fn get(&self, dim: usize) -> Vec> { + match (self, dim) { + (DataScenario::NoTies, 1) => vec![ + vec![1.0], + vec![2.0], + vec![4.0], + vec![7.0], + vec![-9.0], + vec![16.0], + ], + (DataScenario::NoTies, 2) => vec![ + vec![0.0, 0.0], + vec![1.1, 0.1], + vec![2.3, 0.4], + vec![3.6, 0.9], + vec![5.0, 1.6], + vec![6.5, 2.5], + ], + (DataScenario::NoTies, 3) => vec![ + vec![0.0, 0.0, 0.0], + vec![1.1, 0.1, 0.01], + vec![2.3, 0.4, 0.08], + vec![-3.6, -0.9, -0.27], + vec![5.0, 1.6, 0.64], + vec![6.5, 2.5, 1.25], + ], + (DataScenario::NoTies, 4) => vec![ + vec![0.0, 0.0, 0.0, 1000.0], + vec![1.1, 0.1, 0.01, 1000.001], + vec![2.3, 0.4, 0.08, 1000.008], + vec![3.6, 0.9, 0.27, 1000.027], + vec![5.0, 1.6, 0.64, 1000.256], + vec![6.5, 2.5, 1.25, 1000.625], + ], + (DataScenario::Ties, 1) => vec![ + vec![0.0], + vec![1.0], + vec![1.0], + vec![2.0], + vec![2.0], + vec![3.0], + ], + (DataScenario::Ties, 2) => vec![ + vec![0.0, 0.0], + vec![1.0, 0.0], + vec![0.0, 1.0], + vec![-1.0, 0.0], + vec![0.0, -1.0], + vec![1.0, 1.0], + ], + (DataScenario::Ties, 3) => vec![ + vec![0.0, 0.0, 0.0], + vec![1.0, 0.0, 0.0], + vec![0.0, 1.0, 0.0], + vec![0.0, 0.0, 1.0], + vec![-1.0, 0.0, 0.0], + vec![0.0, -1.0, 0.0], + ], + (DataScenario::Ties, 4) => vec![ + vec![0.0, 0.0, 0.0, 0.0], + vec![1.0, 0.0, 0.0, 0.0], + vec![0.0, 1.0, 0.0, 0.0], + vec![0.0, 0.0, 1.0, 0.0], + vec![0.0, 0.0, 0.0, 1.0], + vec![-1.0, 0.0, 0.0, 0.0], + ], + (DataScenario::Gaussian, d) => { + let mut rng = StdRng::seed_from_u64(8757); + let normal = Normal::new(1.0, 10.0).unwrap(); + let n_samples = 200; + let mut data = vec![vec![0.0; d]; n_samples]; + for i in 0..n_samples { + for j in 0..d { + data[i][j] = normal.sample(&mut rng); + } + } + data + } + _ => panic!("Unsupported dimension {} for scenario {:?}", dim, self), + } + } + } + + /// Helper function to test nearest_n queries for `D: DistanceMetric` + /// + /// Tests KD-tree Chebyshev distance queries across different tree types and + /// data scenarios. This simplifies testing across different combinations. + /// + /// # What this function does + /// 1. Get test data points based on a scenario (NoTies/Ties) and dimensionality + /// 2. Builds either MutableKdTree (incremental) or ImmutableKdTree (bulk construction) + /// 3. Performs nearest_n query with Chebyshev distance from point 0 + /// 4. Compares results against Brute-force distances, + /// calculated from `>::dist`. + /// + /// # Choices + /// - Fixed-size array `[f64; 6]`. For `dim<6` a subspace/padding is used for practicality + /// + /// # Assertions + /// - Point 0 is always the query point (distance 0, index 0 expected first result) + /// - NoTies scenario: checks distances and item IDs for points with unique distances + /// - Ties scenario: checks distances (order among ties is non-deterministic) + fn run_test_helper>( + dim: usize, + tree_type: TreeType, + scenario: DataScenario, + n: usize, + ) { + let data = scenario.get(dim); + let query_point = &data[0]; + + let mut points: Vec<[f64; 6]> = Vec::with_capacity(data.len()); + for row in &data { + let mut p = [0.0; 6]; + for (i, &val) in row.iter().enumerate() { + p[i] = val; + } + points.push(p); + } + + let mut query_arr = [0.0; 6]; + for (i, &val) in query_point.iter().enumerate() { + if i < 6 { + query_arr[i] = val; + } + } + + // Calculate ground truth with brute-force approach + let mut expected: Vec<(usize, f64)> = points + .iter() + .enumerate() + .map(|(i, &point)| { + let dist = D::dist(&query_arr, &point); + (i, dist) + }) + .collect(); + + expected.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap()); + + let expected_distances: Vec = expected.iter().map(|(_, d)| *d).collect(); + + println!( + "Query: {:?}, TreeType: {:?}, Scenario: {:?}, dim={}, n={}", + query_point, tree_type, scenario, dim, n + ); + + // Query based on tree type + let results = match tree_type { + TreeType::Mutable => { + let mut tree: crate::float::kdtree::KdTree = + crate::float::kdtree::KdTree::new(); + for (i, point) in points.iter().enumerate() { + tree.add(point, i as u64); + } + tree.nearest_n::(&query_arr, n) + } + TreeType::Immutable => { + let tree: crate::immutable::float::kdtree::ImmutableKdTree = + crate::immutable::float::kdtree::ImmutableKdTree::new_from_slice(&points); + tree.nearest_n::(&query_arr, std::num::NonZero::new(n).unwrap()) + } + }; + + println!("Results (len: {}):", results.len()); + + assert_eq!(results[0].item, 0, "First result should be the query point"); + assert_eq!( + results[0].distance, 0.0, + "First result distance should be 0.0" + ); + + for (i, result) in results.iter().enumerate() { + assert_eq!( + result.distance, expected_distances[i], + "Distance at index {} should be {}, but was {}", + i, expected_distances[i], result.distance + ); + } + + if matches!(scenario, DataScenario::NoTies) { + for (i, result) in results.iter().enumerate() { + let expected_id = expected[i].0; + assert_eq!( + result.item, expected_id as u64, + "Result {}: item ID mismatch. Expected {}, got {}", + i, expected_id, result.item + ); + } + } + } + + /// Chebyshev distance nearest-neighbor query tests. + /// + /// Test matrix covering all combinations of mutable/immutable trees, + /// data scenarios (with/out ties), dimensions, and neighbor query counts. + #[rstest] + fn test_nearest_n_chebyshev( + #[values(TreeType::Mutable, TreeType::Immutable)] tree_type: TreeType, + #[values(DataScenario::NoTies, DataScenario::Ties, DataScenario::Gaussian)] + scenario: DataScenario, + #[values(1, 2, 3, 4, 5, 6)] n: usize, + #[values(1, 2, 3, 4)] dim: usize, + ) { + run_test_helper::(dim, tree_type, scenario, n); + } + + #[rstest] + fn test_nearest_n_squared_euclidean( + #[values(TreeType::Mutable, TreeType::Immutable)] tree_type: TreeType, + #[values(DataScenario::NoTies, DataScenario::Ties, DataScenario::Gaussian)] + scenario: DataScenario, + #[values(1, 2, 3, 4, 5, 6)] n: usize, + #[values(1, 2, 3, 4)] dim: usize, + ) { + run_test_helper::(dim, tree_type, scenario, n); + } + + #[rstest] + fn test_nearest_n_manhattan( + #[values(TreeType::Mutable, TreeType::Immutable)] tree_type: TreeType, + #[values(DataScenario::NoTies, DataScenario::Ties, DataScenario::Gaussian)] + scenario: DataScenario, + #[values(1, 2, 3, 4, 5, 6)] n: usize, + #[values(1, 2, 3, 4)] dim: usize, + ) { + run_test_helper::(dim, tree_type, scenario, n); + } + + #[test] + fn test_nearest_n_manhattan_distance() { + let mut kdtree: KdTree = KdTree::new(); + + // Add points in a simple pattern + let points = [ + ([0.0f32, 0.0f32], 0), // distance 0 from query point + ([1.0f32, 0.0f32], 1), // distance 1 from query point + ([0.0f32, 1.0f32], 2), // distance 1 from query point + ([2.0f32, 0.0f32], 3), // distance 2 from query point + ([0.0f32, 2.0f32], 4), // distance 2 from query point + ([3.0f32, 3.0f32], 5), // distance 6 from query point + ]; + + for (point, index) in points { + kdtree.add(&point, index); + } + + let query_point = [0.0f32, 0.0f32]; + let results = kdtree.nearest_n::(&query_point, 4); + + // Expected order: [0], [1], [2], [3], [4] + // Distances: 0, 1, 1, 2, 2 + // But we only ask for 4 nearest + assert_eq!(results.len(), 4); + + // First result should be the query point itself + assert_eq!(results[0].item, 0); + assert_eq!(results[0].distance, 0.0); + + // Next two should be the points at Manhattan distance 1 + assert_eq!(results[1].item, 1); + assert_eq!(results[1].distance, 1.0); + assert_eq!(results[2].item, 2); + assert_eq!(results[2].distance, 1.0); + + // Fourth should be one of the points at distance 2 + assert!(results[3].item == 3 || results[3].item == 4); + assert_eq!(results[3].distance, 2.0); + } + + #[test] + fn test_nearest_n_squared_euclidean_distance() { + let mut kdtree: KdTree = KdTree::new(); + + // Add points in a pattern where Euclidean and Manhattan differ + let points = [ + ([0.0, 0.0], 0), // distance 0 from query point + ([1.0, 0.0], 1), // Euclidean: 1, Manhattan: 1 + ([0.0, 1.0], 2), // Euclidean: 1, Manhattan: 1 + ([1.0, 1.0], 3), // Euclidean: 2, Manhattan: 2 + ([2.0, 0.0], 4), // Euclidean: 4, Manhattan: 2 + ([0.0, 2.0], 5), // Euclidean: 4, Manhattan: 2 + ([3.0, 4.0], 6), // Euclidean: 25, Manhattan: 7 + ]; + + for (point, index) in points { + kdtree.add(&point, index); + } + + let query_point = [0.0, 0.0]; + let results = kdtree.nearest_n::(&query_point, 5); + + assert_eq!(results.len(), 5); + + // First should be the query point itself + assert_eq!(results[0].item, 0); + assert_eq!(results[0].distance, 0.0); + + // Next two should be the points at Euclidean distance 1 + assert_eq!(results[1].item, 1); + assert_eq!(results[1].distance, 1.0); + assert_eq!(results[2].item, 2); + assert_eq!(results[2].distance, 1.0); + + // Next two should be the points at Euclidean distance 2 + assert_eq!(results[3].item, 3); + assert_eq!(results[3].distance, 2.0); + assert_eq!(results[4].item, 4); + assert_eq!(results[4].distance, 4.0); + + // Verify that points at squared Euclidean distance 4 are indeed farther + // than points at squared Euclidean distance 2 + assert!(results[4].distance > results[3].distance); + } + + #[test] + fn test_nearest_n_different_metrics_produce_different_orderings() { + let mut kdtree: KdTree = KdTree::new(); + + // Add points where Manhattan and Euclidean give different orderings + let points = [ + ([0.0, 0.0], 0), // origin + ([2.0, 1.0], 1), // Manhattan: 3, Euclidean^2: 5 + ([1.0, 2.0], 2), // Manhattan: 3, Euclidean^2: 5 + ([3.0, 0.0], 3), // Manhattan: 3, Euclidean^2: 9 + ([0.0, 3.0], 4), // Manhattan: 3, Euclidean^2: 9 + ]; + + for (point, index) in points { + kdtree.add(&point, index); + } + + let query_point = [0.0, 0.0]; + + let manhattan_results = kdtree.nearest_n::(&query_point, 3); + let euclidean_results = kdtree.nearest_n::(&query_point, 3); + + // Both should include the origin as first result + assert_eq!(manhattan_results[0].item, 0); + assert_eq!(euclidean_results[0].item, 0); + + // For Manhattan: points 1, 2, 3, 4 are all at distance 3 + // The ordering among ties depends on tree structure, but they should all have same distance + assert_eq!(manhattan_results[1].distance, 3.0); + assert_eq!(manhattan_results[2].distance, 3.0); + + // For Euclidean: points 1 and 2 are at distance sqrt(5) ≈ 2.236 (squared: 5) + // Points 3 and 4 are at distance 3 (squared: 9) + assert_eq!(euclidean_results[1].distance, 5.0); + assert_eq!(euclidean_results[2].distance, 5.0); + + // Verify that Euclidean ordering puts points 1 and 2 before 3 and 4 + let euclidean_items: Vec = euclidean_results + .iter() + .skip(1) // skip origin + .take(2) // take next 2 + .map(|nn| nn.item) + .collect(); + + assert!(euclidean_items.contains(&1) || euclidean_items.contains(&2)); + + // Calculate actual distances to verify our understanding + let p1 = [2.0, 1.0]; + let p2 = [1.0, 2.0]; + let p3 = [3.0, 0.0]; + + let manhattan_p1 = Manhattan::dist(&query_point, &p1); + let manhattan_p2 = Manhattan::dist(&query_point, &p2); + let manhattan_p3 = Manhattan::dist(&query_point, &p3); + + let euclidean_p1 = SquaredEuclidean::dist(&query_point, &p1); + let euclidean_p2 = SquaredEuclidean::dist(&query_point, &p2); + let euclidean_p3 = SquaredEuclidean::dist(&query_point, &p3); + + assert_eq!(manhattan_p1, 3.0); + assert_eq!(manhattan_p2, 3.0); + assert_eq!(manhattan_p3, 3.0); + + assert_eq!(euclidean_p1, 5.0); + assert_eq!(euclidean_p2, 5.0); + assert_eq!(euclidean_p3, 9.0); + } + + #[test] + fn test_nearest_n_3d_different_metrics() { + let mut kdtree: KdTree = KdTree::new(); + + // Add points in 3D space + let points = [ + ([1.0, 1.0, 1.0], 0), // origin + ([2.0, 1.0, 1.0], 1), // 1 unit away on x-axis + ([1.0, 2.0, 1.0], 2), // 1 unit away on y-axis + ([1.0, 1.0, 2.0], 3), // 1 unit away on z-axis + ([3.0, 1.0, 1.0], 4), // 2 units away on x-axis + ([0.0, 0.0, 0.0], 5), // sqrt(3) ≈ 1.732 from origin + ]; + + for (point, index) in points { + kdtree.add(&point, index); + } + + let query_point = [1.0, 1.0, 1.0]; + let results = kdtree.nearest_n::(&query_point, 4); + + assert_eq!(results.len(), 4); + + // First should be the query point itself + assert_eq!(results[0].item, 0); + assert_eq!(results[0].distance, 0.0); + + // Next three should be the points at Manhattan distance 1 + let nearby_items: Vec = results + .iter() + .skip(1) // skip origin + .take(3) // take next 3 + .map(|nn| nn.item) + .collect(); + + assert!(nearby_items.contains(&1)); + assert!(nearby_items.contains(&2)); + assert!(nearby_items.contains(&3)); + + // All nearby points should have distance 1 + for result in results.iter().skip(1).take(3) { + assert_eq!(result.distance, 1.0); + } + + // Point 4 should be farther (distance 2) and not in top 4 + let all_items: Vec = results.iter().map(|nn| nn.item).collect(); + assert!(!all_items.contains(&4)); + + // Point 5 has Manhattan distance 3, so definitely not in top 4 + assert!(!all_items.contains(&5)); + } + + #[test] + fn test_nearest_n_large_scale() { + let mut kdtree: KdTree = KdTree::new(); + + // Create a grid of points + let mut index = 0; + for x in 0i32..10 { + for y in 0i32..10 { + let point = [x as f32, y as f32]; + kdtree.add(&point, index); + index += 1; + } + } + + // Query from center of grid + let query_point = [5.0f32, 5.0f32]; + let results = kdtree.nearest_n::(&query_point, 10); + + assert_eq!(results.len(), 10); + + // First result should be the center point itself (index 55) + assert_eq!(results[0].item, 55); + assert_eq!(results[0].distance, 0.0); + + // Results should be ordered by increasing distance + for i in 1..10 { + assert!(results[i].distance >= results[i - 1].distance); + } + + // Verify distances make sense for a grid + // The nearest points should be at squared distances: 0, 1, 1, 1, 1, 2, 2, 4, 4, 5... + let expected_distances = [0.0f32, 1.0f32, 1.0f32, 1.0f32, 1.0f32, 2.0f32]; + + for (i, &expected_dist) in expected_distances.iter().enumerate() { + if i < results.len() { + assert_eq!(results[i].distance, expected_dist); + } + } + } + + #[test] + fn test_nearest_n_chebyshev_distance() { + let mut kdtree: KdTree = KdTree::new(); + + // Add points that show Chebyshev behavior + let points = [ + ([0.0f32, 0.0f32], 0), // distance 0 from query point + ([1.0f32, 0.0f32], 1), // Chebyshev: 1, Manhattan: 1, Euclidean^2: 1 + ([0.0f32, 1.0f32], 2), // Chebyshev: 1, Manhattan: 1, Euclidean^2: 1 + ([2.0f32, 0.0f32], 3), // Chebyshev: 2, Manhattan: 2, Euclidean^2: 4 + ([0.0f32, 2.0f32], 4), // Chebyshev: 2, Manhattan: 2, Euclidean^2: 4 + ([1.0f32, 1.0f32], 5), // Chebyshev: 1, Manhattan: 2, Euclidean^2: 2 + ]; + + for (point, index) in points { + kdtree.add(&point, index); + } + + let query_point = [0.0f32, 0.0f32]; + let results = kdtree.nearest_n::(&query_point, 5); + + // With Chebyshev, points at (1,0), (0,1), and (1,1) all have distance 1 + // Points at (2,0) and (0,2) have distance 2 + assert_eq!(results.len(), 5); + + // First should be the query point itself + assert_eq!(results[0].item, 0); + assert_eq!(results[0].distance, 0.0); + + // Next should all be at Chebyshev distance 1 + let nearby_items: Vec = results + .iter() + .skip(1) // skip origin + .take(4) // take next 4 + .filter(|r| (r.distance - 1.0).abs() < 0.001) // check for distance 1 (with some float tolerance) + .map(|nn| nn.item) + .collect(); + + // All of these should be in the results: 1, 2, 5 + assert!(nearby_items.contains(&1)); + assert!(nearby_items.contains(&2)); + assert!(nearby_items.contains(&5)); + } + + #[test] + fn test_within_chebyshev_distance() { + let mut kdtree: KdTree = KdTree::new(); + + // Add points with varying Chebyshev distances + let points = [ + ([0.0f32, 0.0f32], 0), // distance 0 + ([0.5f32, 0.5f32], 1), // Chebyshev: 0.5 + ([1.0f32, 0.0f32], 2), // Chebyshev: 1.0 + ([0.8f32, 0.9f32], 3), // Chebyshev: 0.9 + ([2.0f32, 0.0f32], 4), // Chebyshev: 2.0 + ([0.0f32, 2.0f32], 5), // Chebyshev: 2.0 + ([1.5f32, 1.5f32], 6), // Chebyshev: 1.5 + ]; + + for (point, index) in points { + kdtree.add(&point, index); + } + + let query_point = [0.0f32, 0.0f32]; + let radius = 1.0; // radius 1 (not squared for Chebyshev) + let mut results = kdtree.within::(&query_point, radius); + + // Sort by distance for easier verification + results.sort_by(|a, b| { + a.distance + .partial_cmp(&b.distance) + .unwrap_or(std::cmp::Ordering::Equal) + }); + + // These SHOULD be: 0, 1, 2, 3 (distances: 0, 0.5, 1.0, 0.9) + // For `<=5.2.4` found: 0, 1, 3 (index 2 is missing due to dist1 pruning issue) + let found_indices: Vec = results.iter().map(|r| r.item).collect(); + + assert!(found_indices.contains(&0)); + assert!(found_indices.contains(&1)); + assert!(found_indices.contains(&2)); + assert!(found_indices.contains(&3)); + // Should NOT include points with Chebyshev distance > 1 + assert!(!found_indices.contains(&4)); + assert!(!found_indices.contains(&5)); + assert!(!found_indices.contains(&6)); + + // Verify distances + for result in results { + assert!(result.distance <= 1.0 || (result.distance - 1.0).abs() < 0.001); + } + } + + #[test] + fn test_chebyshev_vs_manhattan_ordering() { + let mut kdtree: KdTree = KdTree::new(); + + // Points where Chebyshev and Manhattan differ significantly + let points = [ + ([0.0f32, 0.0f32], 0), // origin + ([3.0f32, 1.0f32], 1), // Chebyshev: 3, Manhattan: 4 + ([1.0f32, 3.0f32], 2), // Chebyshev: 3, Manhattan: 4 + ([2.0f32, 2.0f32], 3), // Chebyshev: 2, Manhattan: 4 + ([4.0f32, 0.5f32], 4), // Chebyshev: 4, Manhattan: 4.5 + ]; + + for (point, index) in points { + kdtree.add(&point, index); + } + + let query_point = [0.0f32, 0.0f32]; + + let chebyshev_results = kdtree.nearest_n::(&query_point, 4); + let manhattan_results = kdtree.nearest_n::(&query_point, 4); + + // Both should include the origin first + assert_eq!(chebyshev_results[0].item, 0); + assert_eq!(manhattan_results[0].item, 0); + + // With Chebyshev, nearest should be point 3 (distance 2) + // With Manhattan, nearest should be points 1 and 2 (distance 4) + assert_eq!(chebyshev_results[1].item, 3); + assert_eq!(chebyshev_results[1].distance, 2.0); + + // With Manhattan, points 1 and 2 should come before point 3 (which is distance 4) + let manhattan_items: Vec = manhattan_results + .iter() + .skip(1) + .take(3) + .map(|r| r.item) + .collect(); + assert!(manhattan_items.contains(&1) || manhattan_items.contains(&2)); + + // Verify the distance calculations are correct + assert_eq!(chebyshev_results[1].distance, 2.0); // Chebyshev: max(|2-0|, |2-0|) = 2 + assert_eq!(manhattan_results[1].distance, 4.0); // Manhattan: |3-0| + |1-0| = 4 + } + } +} diff --git a/src/float/kdtree.rs b/src/float/kdtree.rs index 4757cb3c..0ea74d6e 100644 --- a/src/float/kdtree.rs +++ b/src/float/kdtree.rs @@ -63,6 +63,7 @@ pub trait Axis: FloatCore + Default + Debug + Copy + Sync + Send + std::ops::Add fn saturating_dist(self, other: Self) -> Self; /// Used in query methods to update the rd value. A saturating add for Fixed and an add for Float + #[deprecated(since = "5.3.0", note = "Use D::accumulate instead")] // TODO: change version number if adding this change - or better so: fully get rid off rd_update fn rd_update(rd: Self, delta: Self) -> Self; } @@ -73,6 +74,7 @@ impl #[inline] fn rd_update(rd: Self, delta: Self) -> Self { + // DEPRECATED: Use D::accumulate instead rd + delta } } diff --git a/src/float/query/nearest_n_within.rs b/src/float/query/nearest_n_within.rs index 6a2c70c9..6fc78994 100644 --- a/src/float/query/nearest_n_within.rs +++ b/src/float/query/nearest_n_within.rs @@ -353,7 +353,7 @@ mod tests { for &(p, item) in content { let dist = SquaredEuclidean::dist(query_point, &p); - if dist < radius { + if dist <= radius { matching_items.push((dist, item)); } } diff --git a/src/float/query/within_unsorted.rs b/src/float/query/within_unsorted.rs index 5d0bab97..226270b9 100644 --- a/src/float/query/within_unsorted.rs +++ b/src/float/query/within_unsorted.rs @@ -209,7 +209,7 @@ mod tests { for &(p, item) in content { let dist = SquaredEuclidean::dist(query_point, &p); - if dist < radius { + if dist <= radius { matching_items.push((dist, item)); } } diff --git a/src/float_leaf_slice/leaf_slice.rs b/src/float_leaf_slice/leaf_slice.rs index 7c104c98..4da31f83 100644 --- a/src/float_leaf_slice/leaf_slice.rs +++ b/src/float_leaf_slice/leaf_slice.rs @@ -237,10 +237,10 @@ where for idx in 0..remainder_items.len() { let mut distance = A::zero(); (0..K).step_by(1).for_each(|dim| { - distance += D::dist1(remainder_points[dim][idx], query[dim]); + distance = + D::accumulate(distance, D::dist1(remainder_points[dim][idx], query[dim])); }); - - if distance < radius { + if distance <= radius { results.add(NearestNeighbour { distance, item: remainder_items[idx], @@ -271,10 +271,11 @@ where for idx in 0..remainder_items.len() { let mut distance = A::zero(); (0..K).step_by(1).for_each(|dim| { - distance += D::dist1(remainder_points[dim][idx], query[dim]); + distance = + D::accumulate(distance, D::dist1(remainder_points[dim][idx], query[dim])); }); - if distance < radius { + if distance <= radius { let item = remainder_items[idx]; if results.len() < max_qty { results.push(BestNeighbour { distance, item }); @@ -364,16 +365,13 @@ where D: DistanceMetric, Self: Sized, { - // AVX512: 4 loops of 32 iterations, each 4x unrolled, 5 instructions per pre-unrolled iteration let mut acc = [0f64; C]; (0..K).step_by(1).for_each(|dim| { let qd = [query[dim]; C]; - (0..C).step_by(1).for_each(|idx| { - acc[idx] += D::dist1(chunk[dim][idx], qd[idx]); + acc[idx] = D::accumulate(acc[idx], D::dist1(chunk[dim][idx], qd[idx])); }); }); - acc } } @@ -450,16 +448,13 @@ where D: DistanceMetric, Self: Sized, { - // AVX512: 4 loops of 32 iterations, each 4x unrolled, 5 instructions per pre-unrolled iteration let mut acc = [0f32; C]; (0..K).step_by(1).for_each(|dim| { let qd = [query[dim]; C]; - (0..C).step_by(1).for_each(|idx| { - acc[idx] += D::dist1(chunk[dim][idx], qd[idx]); + acc[idx] = D::accumulate(acc[idx], D::dist1(chunk[dim][idx], qd[idx])); }); }); - acc } } diff --git a/src/hybrid/query/nearest_n.rs b/src/hybrid/query/nearest_n.rs index f6e24fda..9d28118a 100644 --- a/src/hybrid/query/nearest_n.rs +++ b/src/hybrid/query/nearest_n.rs @@ -134,6 +134,7 @@ where } } + #[inline(always)] fn dist_belongs_in_heap(dist: A, heap: &BinaryHeap>) -> bool { heap.is_empty() || dist < heap.peek().unwrap().distance || heap.len() < heap.capacity() } diff --git a/src/hybrid/query/within.rs b/src/hybrid/query/within.rs index be5b956b..76505046 100644 --- a/src/hybrid/query/within.rs +++ b/src/hybrid/query/within.rs @@ -254,7 +254,7 @@ mod tests { for &(p, item) in content { let dist = manhattan(query_point, &p); - if dist < radius { + if dist <= radius { matching_items.push((dist, item)); } } diff --git a/src/hybrid/query/within_unsorted.rs b/src/hybrid/query/within_unsorted.rs index 683c58fe..fb93473f 100644 --- a/src/hybrid/query/within_unsorted.rs +++ b/src/hybrid/query/within_unsorted.rs @@ -256,7 +256,7 @@ mod tests { for &(p, item) in content { let dist = squared_euclidean(query_point, &p); - if dist < radius { + if dist <= radius { matching_items.push((dist, item)); } } diff --git a/src/immutable/common/generate_immutable_best_n_within.rs b/src/immutable/common/generate_immutable_best_n_within.rs index 3c4ad3bf..e5b58371 100644 --- a/src/immutable/common/generate_immutable_best_n_within.rs +++ b/src/immutable/common/generate_immutable_best_n_within.rs @@ -109,7 +109,7 @@ macro_rules! generate_immutable_best_n_within { closer_leaf_idx, ); - rd = Axis::rd_update(rd, D::dist1(new_off, old_off)); + rd = D::accumulate(rd, D::dist1(new_off, old_off)); if rd <= radius { off[split_dim] = new_off; @@ -190,7 +190,7 @@ macro_rules! generate_immutable_best_n_within { closer_leaf_idx, ); - rd = Axis::rd_update(rd, D::dist1(new_off, old_off)); + rd = D::accumulate(rd, D::dist1(new_off, old_off)); if rd <= radius { off[split_dim] = new_off; diff --git a/src/immutable/common/generate_immutable_nearest_n_within.rs b/src/immutable/common/generate_immutable_nearest_n_within.rs index 398697fd..3d3d8435 100644 --- a/src/immutable/common/generate_immutable_nearest_n_within.rs +++ b/src/immutable/common/generate_immutable_nearest_n_within.rs @@ -112,7 +112,7 @@ macro_rules! generate_immutable_nearest_n_within { closer_leaf_idx, ); - rd = Axis::rd_update(rd, D::dist1(new_off, old_off)); + rd = D::accumulate(rd, D::dist1(new_off, old_off)); if rd <= radius && rd < matching_items.max_dist() { off[split_dim] = new_off; @@ -189,7 +189,7 @@ macro_rules! generate_immutable_nearest_n_within { closer_leaf_idx, ); - rd = Axis::rd_update(rd, D::dist1(new_off, old_off)); + rd = D::accumulate(rd, D::dist1(new_off, old_off)); if rd <= radius && rd < matching_items.max_dist() { off[split_dim] = new_off; diff --git a/src/immutable/common/generate_immutable_nearest_one.rs b/src/immutable/common/generate_immutable_nearest_one.rs index 274da9e2..e72007ed 100644 --- a/src/immutable/common/generate_immutable_nearest_one.rs +++ b/src/immutable/common/generate_immutable_nearest_one.rs @@ -109,7 +109,7 @@ macro_rules! generate_immutable_nearest_one { closer_leaf_idx, ); - rd = Axis::rd_update(rd, D::dist1(new_off, old_off)); + rd = D::accumulate(rd, D::dist1(new_off, old_off)); if rd <= nearest.distance { off[split_dim as usize] = new_off; @@ -177,7 +177,7 @@ macro_rules! generate_immutable_nearest_one { rd, ); - rd = Axis::rd_update(rd, D::dist1(new_off, old_off)); + rd = D::accumulate(rd, D::dist1(new_off, old_off)); if rd <= nearest.distance { off[split_dim as usize] = new_off; diff --git a/src/immutable/common/generate_immutable_within_unsorted_iter.rs b/src/immutable/common/generate_immutable_within_unsorted_iter.rs index a203d670..94f673fd 100644 --- a/src/immutable/common/generate_immutable_within_unsorted_iter.rs +++ b/src/immutable/common/generate_immutable_within_unsorted_iter.rs @@ -83,7 +83,7 @@ macro_rules! generate_immutable_within_unsorted_iter { closer_leaf_idx, ); - rd = Axis::rd_update(rd, D::dist1(new_off, old_off)); + rd = D::accumulate(rd, D::dist1(new_off, old_off)); if rd <= radius { off[split_dim] = new_off; diff --git a/src/immutable/float/query/nearest_n_within.rs b/src/immutable/float/query/nearest_n_within.rs index 7b3cfece..d8b1603a 100644 --- a/src/immutable/float/query/nearest_n_within.rs +++ b/src/immutable/float/query/nearest_n_within.rs @@ -231,7 +231,7 @@ mod tests { for (idx, p) in content.iter().enumerate() { let dist = SquaredEuclidean::dist(query_point, p); - if dist < radius { + if dist <= radius { matching_items.push((dist, idx as u32)); } } diff --git a/src/immutable/float/query/within.rs b/src/immutable/float/query/within.rs index a04c9f7a..535d56ee 100644 --- a/src/immutable/float/query/within.rs +++ b/src/immutable/float/query/within.rs @@ -206,7 +206,7 @@ mod tests { for (idx, p) in content.iter().enumerate() { let dist = Manhattan::dist(query_point, p); - if dist < radius { + if dist <= radius { matching_items.push((dist, idx as u32)); } } diff --git a/src/immutable/float/query/within_unsorted.rs b/src/immutable/float/query/within_unsorted.rs index 28310c5a..67b4d6c2 100644 --- a/src/immutable/float/query/within_unsorted.rs +++ b/src/immutable/float/query/within_unsorted.rs @@ -204,7 +204,7 @@ mod tests { for (idx, p) in content.iter().enumerate() { let dist = SquaredEuclidean::dist(query_point, p); - if dist < radius { + if dist <= radius { matching_items.push((dist, idx as u32)); } } diff --git a/src/immutable/float/query/within_unsorted_iter.rs b/src/immutable/float/query/within_unsorted_iter.rs index 3f3cceba..dc8aa5bc 100644 --- a/src/immutable/float/query/within_unsorted_iter.rs +++ b/src/immutable/float/query/within_unsorted_iter.rs @@ -201,7 +201,7 @@ mod tests { for (idx, p) in content.iter().enumerate() { let dist = SquaredEuclidean::dist(query_point, p); - if dist < radius { + if dist <= radius { matching_items.push((dist, idx as u32)); } } diff --git a/src/lib.rs b/src/lib.rs index 557c9a2b..102bd044 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -135,6 +135,7 @@ pub type ImmutableKdTree = immutable::float::kdtree::ImmutableKdTree; pub use best_neighbour::BestNeighbour; +pub use float::distance::Chebyshev; pub use float::distance::Manhattan; pub use float::distance::SquaredEuclidean; pub use nearest_neighbour::NearestNeighbour; diff --git a/src/traits.rs b/src/traits.rs index 9561b804..81da5347 100644 --- a/src/traits.rs +++ b/src/traits.rs @@ -107,26 +107,65 @@ pub(crate) fn is_stem_index>(x: IDX) -> bool { x < ::leaf_offset() } -/// Trait that needs to be implemented by any potential distance -/// metric to be used within queries +/// Defines how distances are measured and compared for k-d tree queries. +/// +/// Implement this trait to use custom distance metrics with [`kiddo:KdTree`](crate::KdTree). +/// +/// # Distance Metrics in k-d Trees +/// +/// **Distance aggregation**: How to combine per-dimension distances into a total distance +/// - Sum-based: `dist(p,q) = Σ |p[i] - q[i]|` (Manhattan, SquaredEuclidean) +/// - Max-based: `dist(p,q) = max_i |p[i] - q[i]|` (Chebyshev/L∞) +/// +/// # Required Methods +/// +/// - [`dist()`]: Compute total distance between two points +/// - [`dist1()`]: Compute per-dimension distance component +/// - [`accumulate()`]: Aggregate distance components (add or max) +/// pub trait DistanceMetric { - /// returns the distance between two K-d points, as measured - /// by a particular distance metric + /// Returns the distance between two K-d points, as measured by this metric. fn dist(a: &[A; K], b: &[A; K]) -> A; - /// returns the distance between two points along a single axis, - /// as measured by a particular distance metric. + /// Returns the distance between two points along a single dimension. /// - /// (needs to be implemented as it is used by the NN query implementations - /// to extend the minimum acceptable distance for a node when recursing - /// back up the tree) + /// Used internally by NN query implementations to extend the minimum + /// acceptable distance for a node when recursing back up the tree. fn dist1(a: A, b: A) -> A; + + /// Aggregates a distance contribution into a running total. + /// + /// This defines how per-dimension distances combine into a total distance. + /// Choose based on your distance metric: + /// + /// - **Sum-based (L1, L2)**: Use `rd + delta` or `rd.saturating_add(delta)` for fixed-point types + /// - **Max-based (L∞/Chebyshev)**: Use `rd.max(delta)` + /// + /// The implementation should match the mathematical definition of your metric: + /// - Manhattan: `dist(p,q) = Σ |p[i] - q[i]|` -> accumulate by adding + /// - SquaredEuclidean: `dist(p,q) = Σ (p[i] - q[i])²` -> accumulate by adding + /// - Chebyshev: `dist(p,q) = max_i |p[i] - q[i]|` -> accumulate by taking max + /// - Generalised Minkowski (L_p): `dist(p,q) = (Σ |p[i] - q[i]|^p)^(1/p)`. + /// For k-d tree pruning, use the sum of powers: accumulate by adding. + /// Only the limit p → ∞ (Chebyshev) uses `max`. + /// + /// The default implementation uses regular addition (`rd + delta`), which works for + /// both integer and floating-point types. For fixed-point types where overflow is a + /// concern, override this with `rd.saturating_add(delta)`. + fn accumulate(rd: A, delta: A) -> A + where + A: std::ops::Add, + { + rd + delta + } } #[cfg(test)] mod tests { + use super::DistanceMetric; use crate::traits::Index; + use rstest::rstest; #[test] fn test_u16() { @@ -158,4 +197,45 @@ mod tests { (u32::MAX - u32::MAX.overflowing_shr(1).0).saturating_mul(bucket_size); assert_eq!(capacity_with_bucket_size, u32::MAX); } + + struct TestMetricU32; + struct TestMetricI64; + + impl DistanceMetric for TestMetricU32 { + fn dist(_a: &[u32; K], _b: &[u32; K]) -> u32 { + 0 + } + fn dist1(a: u32, _b: u32) -> u32 { + a + } + } + + impl DistanceMetric for TestMetricI64 { + fn dist(_a: &[i64; K], _b: &[i64; K]) -> i64 { + 0 + } + fn dist1(a: i64, _b: i64) -> i64 { + a + } + } + + #[rstest] + #[case(5u32, 3u32, 8u32)] + #[case(10u32, 20u32, 30u32)] + fn test_default_accumulate_u32(#[case] rd: u32, #[case] delta: u32, #[case] expected: u32) { + assert_eq!( + >::accumulate(rd, delta), + expected + ); + } + + #[rstest] + #[case(10i64, 20i64, 30i64)] + #[case(100i64, 200i64, 300i64)] + fn test_default_accumulate_i64(#[case] rd: i64, #[case] delta: i64, #[case] expected: i64) { + assert_eq!( + >::accumulate(rd, delta), + expected + ); + } }