From cdef1e9b7f444a761114de8ae9031e835ffc796d Mon Sep 17 00:00:00 2001 From: Stephen Nicholas Swatman Date: Fri, 7 Nov 2025 16:11:58 +0100 Subject: [PATCH] =?UTF-8?q?Compute=20K=C3=A1lm=C3=A1n=20filter=20chi2=20fr?= =?UTF-8?q?om=20predicted=20state?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This commit migrates the computation of the chi2 value in our gain matrix updater to use the predicted state rather than the filtered state. This is mathematically equivalent but potentially more stable. --- .../kalman_filter/gain_matrix_updater.hpp | 49 +++++---- .../kalman_filter/two_filters_smoother.hpp | 104 +++++++++--------- 2 files changed, 76 insertions(+), 77 deletions(-) diff --git a/core/include/traccc/fitting/kalman_filter/gain_matrix_updater.hpp b/core/include/traccc/fitting/kalman_filter/gain_matrix_updater.hpp index df659449f7..b1c8eceb1e 100644 --- a/core/include/traccc/fitting/kalman_filter/gain_matrix_updater.hpp +++ b/core/include/traccc/fitting/kalman_filter/gain_matrix_updater.hpp @@ -74,7 +74,6 @@ struct gain_matrix_updater { // Some identity matrices // @TODO: Make constexpr work const auto I66 = matrix::identity(); - const auto I_m = matrix::identity>(); // Measurement data on surface matrix_type meas_local; @@ -123,6 +122,31 @@ struct gain_matrix_updater { const matrix_type projected_cov = algebra::matrix::transposed_product(predicted_cov, H); + { + // Calculate the chi square + const matrix_type R = + V + (H * predicted_cov * algebra::matrix::transpose(H)); + // Residual between measurement and predicted vector + const matrix_type residual = meas_local - H * predicted_vec; + const matrix_type<1, 1> chi2_mat = + algebra::matrix::transposed_product( + residual, matrix::inverse(R)) * + residual; + const scalar chi2_val = getter::element(chi2_mat, 0, 0); + + if (chi2_val < 0.f) { + TRACCC_ERROR_HOST_DEVICE("Chi2 negative"); + return kalman_fitter_status::ERROR_UPDATER_CHI2_NEGATIVE; + } + + if (!std::isfinite(chi2_val)) { + TRACCC_ERROR_HOST_DEVICE("Chi2 infinite"); + return kalman_fitter_status::ERROR_UPDATER_CHI2_NOT_FINITE; + } + + trk_state.filtered_chi2() = chi2_val; + } + const matrix_type M = H * projected_cov + V; // Kalman gain matrix @@ -162,38 +186,15 @@ struct gain_matrix_updater { return kalman_fitter_status::ERROR_QOP_ZERO; } - // Residual between measurement and (projected) filtered vector - const matrix_type residual = meas_local - H * filtered_vec; - - // Calculate the chi square - const matrix_type R = (I_m - H * K) * V; - const matrix_type<1, 1> chi2 = - algebra::matrix::transposed_product( - residual, matrix::inverse(R)) * - residual; - - const scalar chi2_val{getter::element(chi2, 0, 0)}; - TRACCC_VERBOSE_HOST("Filtered residual: " << residual); TRACCC_DEBUG_HOST("R:\n" << R); TRACCC_DEBUG_HOST_DEVICE("det(R): %f", matrix::determinant(R)); TRACCC_DEBUG_HOST("R_inv:\n" << matrix::inverse(R)); TRACCC_VERBOSE_HOST_DEVICE("Chi2: %f", chi2_val); - if (chi2_val < 0.f) { - TRACCC_ERROR_HOST_DEVICE("Chi2 negative"); - return kalman_fitter_status::ERROR_UPDATER_CHI2_NEGATIVE; - } - - if (!std::isfinite(chi2_val)) { - TRACCC_ERROR_HOST_DEVICE("Chi2 infinite"); - return kalman_fitter_status::ERROR_UPDATER_CHI2_NOT_FINITE; - } - // Set the chi2 for this track and measurement trk_state.filtered_params().set_vector(filtered_vec); trk_state.filtered_params().set_covariance(filtered_cov); - trk_state.filtered_chi2() = chi2_val; if (math::fmod(trk_state.filtered_params().theta(), 2.f * constant::pi) == 0.f) { diff --git a/core/include/traccc/fitting/kalman_filter/two_filters_smoother.hpp b/core/include/traccc/fitting/kalman_filter/two_filters_smoother.hpp index f08fc9f8e4..e15c1a1f08 100644 --- a/core/include/traccc/fitting/kalman_filter/two_filters_smoother.hpp +++ b/core/include/traccc/fitting/kalman_filter/two_filters_smoother.hpp @@ -84,6 +84,27 @@ struct two_filters_smoother { const matrix_type& predicted_vec = bound_params.vector(); + matrix_type H = + measurements.at(trk_state.measurement_index()) + .subs.template projector(); + if (dim == 1) { + getter::element(H, 1u, 0u) = 0.f; + getter::element(H, 1u, 1u) = 0.f; + } + + // Spatial resolution (Measurement covariance) + matrix_type V; + edm::get_measurement_covariance( + measurements.at(trk_state.measurement_index()), V); + if (dim == 1) { + getter::element(V, 1u, 1u) = 1.f; + } + + TRACCC_DEBUG_HOST("Measurement position: " << meas_local); + TRACCC_DEBUG_HOST("Measurement variance:\n" << V); + TRACCC_DEBUG_HOST("Predicted residual: " << meas_local - + H * predicted_vec); + // Predicted covaraince of bound track parameters const matrix_type& predicted_cov = bound_params.covariance(); @@ -134,29 +155,8 @@ struct two_filters_smoother { // Wrap the phi and theta angles in their valid ranges normalize_angles(trk_state.smoothed_params()); - matrix_type H = - measurements.at(trk_state.measurement_index()) - .subs.template projector(); - if (dim == 1) { - getter::element(H, 1u, 0u) = 0.f; - getter::element(H, 1u, 1u) = 0.f; - } - const matrix_type residual_smt = meas_local - H * smoothed_vec; - // Spatial resolution (Measurement covariance) - matrix_type V; - edm::get_measurement_covariance( - measurements.at(trk_state.measurement_index()), V); - if (dim == 1) { - getter::element(V, 1u, 1u) = 1.f; - } - - TRACCC_DEBUG_HOST("Measurement position: " << meas_local); - TRACCC_DEBUG_HOST("Measurement variance:\n" << V); - TRACCC_DEBUG_HOST("Predicted residual: " << meas_local - - H * predicted_vec); - // Eq (3.39) of "Pattern Recognition, Tracking and Vertex // Reconstruction in Particle Detectors" const matrix_type R_smt = @@ -204,7 +204,6 @@ struct two_filters_smoother { const auto I66 = matrix::identity>(); - const auto I_m = matrix::identity>(); const matrix_type projected_cov = algebra::matrix::transposed_product(predicted_cov, H); @@ -252,39 +251,38 @@ struct two_filters_smoother { bound_params.set_covariance(filtered_cov); - // Residual between measurement and (projected) filtered vector - const matrix_type residual = meas_local - H * filtered_vec; - - // Calculate backward chi2 - const matrix_type R = (I_m - H * K) * V; - // assert(matrix::determinant(R) != 0.f); // @TODO: This fails - assert(std::isfinite(matrix::determinant(R))); - const matrix_type<1, 1> chi2 = - algebra::matrix::transposed_product( - residual, matrix::inverse(R)) * - residual; - - const scalar chi2_val{getter::element(chi2, 0, 0)}; - - TRACCC_VERBOSE_HOST("Filtered residual: " << residual); - TRACCC_DEBUG_HOST("R:\n" << R); - TRACCC_DEBUG_HOST_DEVICE("det(R): %f", matrix::determinant(R)); - TRACCC_DEBUG_HOST("R_inv:\n" << matrix::inverse(R)); - TRACCC_VERBOSE_HOST_DEVICE("Chi2: %f", chi2_val); - - if (chi2_val < 0.f) { - TRACCC_ERROR_HOST_DEVICE("Chi2 negative: %f", chi2_val); - return kalman_fitter_status::ERROR_SMOOTHER_CHI2_NEGATIVE; - } - - if (!std::isfinite(chi2_val)) { - TRACCC_ERROR_HOST_DEVICE("Chi2 infinite"); - return kalman_fitter_status::ERROR_SMOOTHER_CHI2_NOT_FINITE; + { + // Calculate the chi square + const matrix_type R = + V - (H * predicted_cov * algebra::matrix::transpose(H)); + // Residual between measurement and predicted vector + const matrix_type residual = meas_local - H * predicted_vec; + const matrix_type<1, 1> chi2_mat = + algebra::matrix::transposed_product( + residual, matrix::inverse(R)) * + residual; + const scalar chi2_val = getter::element(chi2_mat, 0, 0); + + TRACCC_VERBOSE_HOST("Predicted residual: " << residual); + TRACCC_DEBUG_HOST("R:\n" << R); + TRACCC_DEBUG_HOST_DEVICE("det(R): %f", matrix::determinant(R)); + TRACCC_DEBUG_HOST("R_inv:\n" << matrix::inverse(R)); + TRACCC_VERBOSE_HOST_DEVICE("Chi2: %f", chi2_val); + + if (chi2_val < 0.f) { + TRACCC_ERROR_HOST_DEVICE("Chi2 negative: %f", chi2_val); + return kalman_fitter_status::ERROR_SMOOTHER_CHI2_NEGATIVE; + } + + if (!std::isfinite(chi2_val)) { + TRACCC_ERROR_HOST_DEVICE("Chi2 infinite"); + return kalman_fitter_status::ERROR_SMOOTHER_CHI2_NOT_FINITE; + } + + // Set backward chi2 + trk_state.backward_chi2() = chi2_val; } - // Set backward chi2 - trk_state.backward_chi2() = chi2_val; - // Wrap the phi and theta angles in their valid ranges normalize_angles(bound_params);