From 0c56dbcc37505ac0edc7b110ffb47b1d196ef0a2 Mon Sep 17 00:00:00 2001 From: Alexander Hans Date: Tue, 28 Oct 2025 18:31:09 +0100 Subject: [PATCH] Make bin calculation more robust --- cpp/dbscan.cpp | 9 +++++---- python/dbscan_test.py | 39 ++++++++++++++++++++++++++++++++++++++- 2 files changed, 43 insertions(+), 5 deletions(-) diff --git a/cpp/dbscan.cpp b/cpp/dbscan.cpp index 6011473..d892aa0 100644 --- a/cpp/dbscan.cpp +++ b/cpp/dbscan.cpp @@ -45,16 +45,17 @@ auto Dbscan::fit_predict(std::vector const& points) -> std::vecto float const range_x{max[0] - min[0]}; float const range_y{max[1] - min[1]}; // add 1e-7 to handle the case where range is exactly divisible by eps_ - auto const num_bins_x{static_cast(std::ceil((range_x + 1e-7f) / eps_))}; - auto const num_bins_y{static_cast(std::ceil((range_y + 1e-7f) / eps_))}; + auto const num_bins_x{static_cast(std::ceil((range_x + eps_ / 2) / eps_))}; + auto const num_bins_y{static_cast(std::ceil((range_y + eps_ / 2) / eps_))}; // count number of points in every bin counts_.assign(num_bins_x * num_bins_y, 0); // FIRST PASS OVER THE POINTS for (auto const& pt : points) { - auto const bin_x{static_cast(std::floor((pt[0] - min[0]) / eps_))}; - auto const bin_y{static_cast(std::floor((pt[1] - min[1]) / eps_))}; + auto const bin_x{std::min(static_cast(std::floor((pt[0] - min[0]) / eps_)), num_bins_x - 1)}; + auto const bin_y{std::min(static_cast(std::floor((pt[1] - min[1]) / eps_)), num_bins_y - 1)}; + auto const index{bin_y * num_bins_x + bin_x}; counts_[index] += 1; } diff --git a/python/dbscan_test.py b/python/dbscan_test.py index 634792b..528e214 100644 --- a/python/dbscan_test.py +++ b/python/dbscan_test.py @@ -72,5 +72,42 @@ def test_points_on_border(): np.testing.assert_equal(y_pred, np.array([-1, -1])) +def test_points_on_border2(): + """Another test with points on the edge of the eps grid.""" + X = np.array( + [ + [1.0000000e01, -2.5000000e-01], + [1.0000000e01, -2.0000000e-01], + [1.0000000e01, -1.5000001e-01], + [1.0000000e01, -1.0000000e-01], + [1.0000000e01, -5.0000001e-02], + [1.0000000e01, -1.3877788e-17], + [1.0000000e01, 5.0000001e-02], + [1.0000000e01, 1.0000000e-01], + [1.0000000e01, 1.5000001e-01], + [1.0000000e01, 2.0000000e-01], + [1.5000000e01, -2.5000000e-01], + [1.5000000e01, -2.0000000e-01], + [1.5000000e01, -1.5000001e-01], + [1.5000000e01, -1.0000000e-01], + [1.5000000e01, -5.0000001e-02], + [1.5000000e01, -1.3877788e-17], + [1.5000000e01, 5.0000001e-02], + [1.5000000e01, 1.0000000e-01], + [1.5000000e01, 1.5000001e-01], + [1.5000000e01, 2.0000000e-01], + ] + ) + + dbscan = py_dbscan.DBSCAN(0.5, 2) + y_pred = dbscan.fit_predict(X) + + label_a = y_pred[0] + label_b = y_pred[10] + assert label_a != label_b + assert np.all(y_pred[:10] == label_a) + assert np.all(y_pred[10:] == label_b) + + if __name__ == "__main__": - sys.exit(pytest.main([__file__, "-rP"])) + sys.exit(pytest.main([__file__, "-rP"] + sys.argv[1:]))