Skip to content
29 changes: 23 additions & 6 deletions src/common/generate_nearest_n_within_unsorted.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,19 +8,32 @@ macro_rules! generate_nearest_n_within_unsorted {
where
D: DistanceMetric<A, K>,
{
self.nearest_n_within_exclusive::<D>(query, dist, max_items, sorted, true)
}

#[doc = concat!$comments]
#[inline]
pub fn nearest_n_within_exclusive<D>(&self, query: &[A; K], dist: A, max_items: std::num::NonZero<usize>, sorted: bool, inclusive: bool) -> Vec<NearestNeighbour<A, T>>
where
D: DistanceMetric<A, K>,
{
// Like [`nearest_n_within`] but allows controlling boundary inclusiveness.
//
// When `inclusive` is true, points at exactly the maximum distance are included.
// When false, only points strictly less than the maximum distance are included.
if sorted || max_items < std::num::NonZero::new(usize::MAX).unwrap() {
if max_items <= std::num::NonZero::new(MAX_VEC_RESULT_SIZE).unwrap() {
self.nearest_n_within_stub::<D, SortedVec<NearestNeighbour<A, T>>>(query, dist, max_items.get(), sorted)
self.nearest_n_within_stub::<D, SortedVec<NearestNeighbour<A, T>>>(query, dist, max_items.get(), sorted, inclusive)
} else {
self.nearest_n_within_stub::<D, BinaryHeap<NearestNeighbour<A, T>>>(query, dist, max_items.get(), sorted)
self.nearest_n_within_stub::<D, BinaryHeap<NearestNeighbour<A, T>>>(query, dist, max_items.get(), sorted, inclusive)
}
} else {
self.nearest_n_within_stub::<D, Vec<NearestNeighbour<A,T>>>(query, dist, 0, sorted)
self.nearest_n_within_stub::<D, Vec<NearestNeighbour<A,T>>>(query, dist, 0, sorted, inclusive)
}
}

fn nearest_n_within_stub<D: DistanceMetric<A, K>, H: ResultCollection<A, T>>(
&self, query: &[A; K], dist: A, res_capacity: usize, sorted: bool
&self, query: &[A; K], dist: A, res_capacity: usize, sorted: bool, inclusive: bool
) -> Vec<NearestNeighbour<A, T>> {
let mut matching_items = H::new_with_capacity(res_capacity);
let mut off = [A::zero(); K];
Expand All @@ -35,6 +48,7 @@ macro_rules! generate_nearest_n_within_unsorted {
&mut matching_items,
&mut off,
A::zero(),
inclusive,
);
}

Expand All @@ -55,6 +69,7 @@ macro_rules! generate_nearest_n_within_unsorted {
matching_items: &mut R,
off: &mut [A; K],
rd: A,
inclusive: bool,
) where
D: DistanceMetric<A, K>,
{
Expand Down Expand Up @@ -84,11 +99,12 @@ macro_rules! generate_nearest_n_within_unsorted {
matching_items,
off,
rd,
inclusive,
);

rd = D::accumulate(rd, D::dist1(new_off, old_off));

if rd <= radius {
if if inclusive { rd <= radius } else { rd < radius } {
off[split_dim] = new_off;
self.nearest_n_within_unsorted_recurse::<D, R>(
query,
Expand All @@ -98,6 +114,7 @@ macro_rules! generate_nearest_n_within_unsorted {
matching_items,
off,
rd,
inclusive,
);
off[split_dim] = old_off;
}
Expand All @@ -116,7 +133,7 @@ macro_rules! generate_nearest_n_within_unsorted {
.for_each(|(idx, entry)| {
let distance = D::dist(query, transform(entry));

if distance <= radius {
if if inclusive { distance <= radius } else { distance < radius } {
let item = unsafe { leaf_node.content_items.get_unchecked(idx) };
let item = *transform(item);

Expand Down
15 changes: 14 additions & 1 deletion src/common/generate_within.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,20 @@ macro_rules! generate_within {
where
D: DistanceMetric<A, K>,
{
let mut matching_items = self.within_unsorted::<D>(query, dist);
self.within_exclusive::<D>(query, dist, true)
}

#[doc = concat!$comments]
#[inline]
pub fn within_exclusive<D>(&self, query: &[A; K], dist: A, inclusive: bool) -> Vec<NearestNeighbour<A, T>>
where
D: DistanceMetric<A, K>,
{
// Like [`within`] but allows controlling boundary inclusiveness.
//
// When `inclusive` is true, points at exactly the maximum distance are included.
// When false, only points strictly less than the maximum distance are included.
let mut matching_items = self.within_unsorted_exclusive::<D>(query, dist, inclusive);
matching_items.sort();
matching_items
}
Expand Down
21 changes: 19 additions & 2 deletions src/common/generate_within_unsorted.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,19 @@ macro_rules! generate_within_unsorted {
where
D: DistanceMetric<A, K>,
{
self.within_unsorted_exclusive::<D>(query, dist, true)
}

#[doc = concat!$comments]
#[inline]
pub fn within_unsorted_exclusive<D>(&self, query: &[A; K], dist: A, inclusive: bool) -> Vec<NearestNeighbour<A, T>>
where
D: DistanceMetric<A, K>,
{
// Like [`within_unsorted`] but allows controlling boundary inclusiveness.
//
// When `inclusive` is true, points at exactly the maximum distance are included.
// When false, only points strictly less than the maximum distance are included.
let mut off = [A::zero(); K];
let mut matching_items = Vec::new();
let root_index: IDX = *transform(&self.root_index);
Expand All @@ -21,6 +34,7 @@ macro_rules! generate_within_unsorted {
&mut matching_items,
&mut off,
A::zero(),
inclusive,
);
}

Expand All @@ -37,6 +51,7 @@ macro_rules! generate_within_unsorted {
matching_items: &mut Vec<NearestNeighbour<A, T>>,
off: &mut [A; K],
rd: A,
inclusive: bool,
) where
D: DistanceMetric<A, K>,
{
Expand Down Expand Up @@ -66,11 +81,12 @@ macro_rules! generate_within_unsorted {
matching_items,
off,
rd,
inclusive,
);

rd = D::accumulate(rd, D::dist1(new_off, old_off));

if rd <= radius {
if if inclusive { rd <= radius } else { rd < radius } {
off[split_dim] = new_off;
self.within_unsorted_recurse::<D>(
query,
Expand All @@ -80,6 +96,7 @@ macro_rules! generate_within_unsorted {
matching_items,
off,
rd,
inclusive,
);
off[split_dim] = old_off;
}
Expand All @@ -98,7 +115,7 @@ macro_rules! generate_within_unsorted {
.for_each(|(idx, entry)| {
let distance = D::dist(query, transform(entry));

if distance <= radius {
if if inclusive { distance <= radius } else { distance < radius } {
let item = unsafe { leaf_node.content_items.get_unchecked(idx) };
let item = *transform(item);

Expand Down
26 changes: 24 additions & 2 deletions src/common/generate_within_unsorted_iter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,24 @@ macro_rules! generate_within_unsorted_iter {
where
D: DistanceMetric<A, K>,
{
self.within_unsorted_iter_exclusive::<D>(query, dist, true)
}

#[doc = concat!$comments]
#[inline]
pub fn within_unsorted_iter_exclusive<D>(
&'a self,
query: &'query [A; K],
dist: A,
inclusive: bool,
) -> WithinUnsortedIter<'a, A, T>
where
D: DistanceMetric<A, K>,
{
// Like [`within_unsorted_iter`] but allows controlling boundary inclusiveness.
//
// When `inclusive` is true, points at exactly the maximum distance are included.
// When false, only points strictly less than the maximum distance are included.
let mut off = [A::zero(); K];
let root_index: IDX = *transform(&self.root_index);

Expand All @@ -27,6 +45,7 @@ macro_rules! generate_within_unsorted_iter {
gen_scope,
&mut off,
A::zero(),
inclusive,
);
}

Expand All @@ -46,6 +65,7 @@ macro_rules! generate_within_unsorted_iter {
mut gen_scope: Scope<'scope, 'a, (), NearestNeighbour<A, T>>,
off: &mut [A; K],
rd: A,
inclusive: bool,
) -> Scope<'scope, 'a, (), NearestNeighbour<A, T>>
where
D: DistanceMetric<A, K>,
Expand Down Expand Up @@ -76,11 +96,12 @@ macro_rules! generate_within_unsorted_iter {
gen_scope,
off,
rd,
inclusive,
);

rd = D::accumulate(rd, D::dist1(new_off, old_off));

if rd <= radius {
if if inclusive { rd <= radius } else { rd < radius } {
off[split_dim] = new_off;
gen_scope = self.within_unsorted_iter_recurse::<D>(
query,
Expand All @@ -90,6 +111,7 @@ macro_rules! generate_within_unsorted_iter {
gen_scope,
off,
rd,
inclusive,
);
off[split_dim] = old_off;
}
Expand All @@ -108,7 +130,7 @@ macro_rules! generate_within_unsorted_iter {
.for_each(|(idx, entry)| {
let distance = D::dist(query, transform(entry));

if distance <= radius {
if if inclusive { distance <= radius } else { distance < radius } {
let item = unsafe { leaf_node.content_items.get_unchecked(idx) };
let item = *transform(item);

Expand Down
20 changes: 20 additions & 0 deletions src/fixed/query/within.rs
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ mod tests {
use fixed::types::extra::U14;
use fixed::FixedU16;
use rand::Rng;
use rstest::rstest;
use std::cmp::Ordering;

type Fxd = FixedU16<U14>;
Expand Down Expand Up @@ -154,6 +155,25 @@ mod tests {
}
}

#[rstest]
#[case(true, 1)]
#[case(false, 0)]
fn test_within_boundary_inclusiveness(#[case] inclusive: bool, #[case] expected_len: usize) {
let mut kdtree: KdTree<Fxd, u32, 2, 5, u32> = KdTree::new();
kdtree.add(&[n(1.0), n(0.0)], 1);
kdtree.add(&[n(2.0), n(0.0)], 2);

let query = [n(0.0), n(0.0)];
let radius = n(1.0);

let results = kdtree.within_exclusive::<Manhattan>(&query, radius, inclusive);
assert_eq!(results.len(), expected_len);
if expected_len > 0 {
assert_eq!(results[0].item, 1);
assert_eq!(results[0].distance, n(1.0));
}
}

fn linear_search<A: Axis, const K: usize>(
content: &[([A; K], u32)],
query_point: &[A; K],
Expand Down
Loading
Loading