From 7f3c4b90ded9fe7f35446bb245b1a08d476e2037 Mon Sep 17 00:00:00 2001 From: Talon Chandler Date: Wed, 27 Aug 2025 11:41:15 -0700 Subject: [PATCH 1/5] refactor `compute_midband_power` --- waveorder/focus.py | 70 +++++++++++++++++++++++++++++++++++----------- 1 file changed, 54 insertions(+), 16 deletions(-) diff --git a/waveorder/focus.py b/waveorder/focus.py index 38128f3c..38fcb521 100644 --- a/waveorder/focus.py +++ b/waveorder/focus.py @@ -3,11 +3,53 @@ import matplotlib.pyplot as plt import numpy as np +import torch from scipy.signal import peak_widths from waveorder import util +def compute_midband_power( + yx_array: torch.Tensor, + NA_det: float, + lambda_ill: float, + pixel_size: float, + midband_fractions: tuple[float, float] = (0.125, 0.25), +) -> torch.Tensor: + """Compute midband spatial frequency power by summing over a 2D midband donut. + + Parameters + ---------- + yx_array : torch.Tensor + 2D tensor in (Y, X) order. + NA_det : float + Detection NA. + lambda_ill : float + Illumination wavelength. + Units are arbitrary, but must match [pixel_size]. + pixel_size : float + Object-space pixel size = camera pixel size / magnification. + Units are arbitrary, but must match [lambda_ill]. + midband_fractions : tuple[float, float], optional + The minimum and maximum fraction of the cutoff frequency that define the midband. + Default is (0.125, 0.25). + + Returns + ------- + torch.Tensor + Sum of absolute FFT values in the midband region. + """ + _, _, fxx, fyy = util.gen_coordinate(yx_array.shape, pixel_size) + frr = torch.tensor(np.sqrt(fxx**2 + fyy**2)) + xy_abs_fft = torch.abs(torch.fft.fftn(yx_array)) + cutoff = 2 * NA_det / lambda_ill + mask = torch.logical_and( + frr > cutoff * midband_fractions[0], + frr < cutoff * midband_fractions[1], + ) + return torch.sum(xy_abs_fft[mask]) + + def focus_from_transverse_band( zyx_array, NA_det, @@ -79,24 +121,20 @@ def focus_from_transverse_band( ) return 0 - # Calculate coordinates - _, Y, X = zyx_array.shape - _, _, fxx, fyy = util.gen_coordinate((Y, X), pixel_size) - frr = np.sqrt(fxx**2 + fyy**2) - - # Calculate fft - xy_abs_fft = np.abs(np.fft.fftn(zyx_array, axes=(1, 2))) - - # Calculate midband mask - cutoff = 2 * NA_det / lambda_ill - midband_mask = np.logical_and( - frr > cutoff * midband_fractions[0], - frr < cutoff * midband_fractions[1], + # Calculate midband power for each slice + midband_sum = np.array( + [ + compute_midband_power( + torch.from_numpy(zyx_array[z]), + NA_det, + lambda_ill, + pixel_size, + midband_fractions, + ).numpy() + for z in range(zyx_array.shape[0]) + ] ) - # Find slice index with min/max power in midband - midband_sum = np.sum(xy_abs_fft[:, midband_mask], axis=1) - if polynomial_fit_order is None: peak_index = minmaxfunc(midband_sum) else: From 19c094cb39726a951979aeaff6e014a2a9ea6332 Mon Sep 17 00:00:00 2001 From: Talon Chandler Date: Wed, 27 Aug 2025 11:46:27 -0700 Subject: [PATCH 2/5] test compute_midband_power --- tests/test_focus_estimator.py | 71 +++++++++++++++++++++++++++++++++++ 1 file changed, 71 insertions(+) diff --git a/tests/test_focus_estimator.py b/tests/test_focus_estimator.py index 5e4442df..8c093195 100644 --- a/tests/test_focus_estimator.py +++ b/tests/test_focus_estimator.py @@ -1,5 +1,6 @@ import numpy as np import pytest +import torch from waveorder import focus @@ -85,3 +86,73 @@ def test_focus_estimator_snr(tmp_path): assert plot_path.exists() if slice is not None: assert np.abs(slice - 10) <= 2 + + +def test_compute_midband_power(): + """Test the compute_midband_power function with torch tensors.""" + # Test parameters + ps = 6.5 / 100 + lambda_ill = 0.532 + NA_det = 1.4 + midband_fractions = (0.125, 0.25) + + # Create test data + np.random.seed(42) + test_2d_np = np.random.random((64, 64)).astype(np.float32) + test_2d_torch = torch.from_numpy(test_2d_np) + + # Test the compute_midband_power function + result = focus.compute_midband_power( + test_2d_torch, NA_det, lambda_ill, ps, midband_fractions + ) + + # Check result properties + assert isinstance(result, torch.Tensor) + assert result.shape == torch.Size([]) # scalar tensor + assert result.item() > 0 # should be positive + + # Test with different midband fractions + result2 = focus.compute_midband_power( + test_2d_torch, NA_det, lambda_ill, ps, (0.1, 0.3) + ) + assert isinstance(result2, torch.Tensor) + assert result2.item() > 0 + + # Results should be different for different bands + assert abs(result.item() - result2.item()) > 1e-6 + + +def test_compute_midband_power_consistency(): + """Test that compute_midband_power is consistent with focus_from_transverse_band.""" + # Test parameters + ps = 6.5 / 100 + lambda_ill = 0.532 + NA_det = 1.4 + midband_fractions = (0.125, 0.25) + + # Create 3D test data + np.random.seed(42) + test_3d = np.random.random((3, 32, 32)).astype(np.float32) + + # Test focus_from_transverse_band still works + focus_slice = focus.focus_from_transverse_band( + test_3d, NA_det, lambda_ill, ps, midband_fractions + ) + + assert isinstance(focus_slice, (int, np.integer)) + assert 0 <= focus_slice < test_3d.shape[0] + + # Manually compute midband power for each slice + manual_powers = [] + for z in range(test_3d.shape[0]): + power = focus.compute_midband_power( + torch.from_numpy(test_3d[z]), + NA_det, + lambda_ill, + ps, + midband_fractions, + ) + manual_powers.append(power.item()) + + expected_focus_slice = np.argmax(manual_powers) + assert focus_slice == expected_focus_slice From 079fa09febbfcf436a4a7ace3c0a98657bb5ff3e Mon Sep 17 00:00:00 2001 From: Talon Chandler Date: Tue, 2 Sep 2025 11:07:11 -0700 Subject: [PATCH 3/5] Allow float values for z_focus_offset in settings MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Change z_focus_offset type from Union[int, Literal["auto"]] to Union[float, Literal["auto"]] - This enables sub-pixel precision for focus offset values in 2D phase reconstruction - Addresses issue #470 for coarsely sampled slice improvements 🤖 Generated with [Claude Code](https://claude.ai/code) Co-Authored-By: Claude --- waveorder/cli/settings.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/waveorder/cli/settings.py b/waveorder/cli/settings.py index e28e3a15..2a885f25 100644 --- a/waveorder/cli/settings.py +++ b/waveorder/cli/settings.py @@ -62,7 +62,7 @@ class FourierTransferFunctionSettings(MyBaseModel): yx_pixel_size: PositiveFloat = 6.5 / 20 z_pixel_size: PositiveFloat = 2.0 z_padding: NonNegativeInt = 0 - z_focus_offset: Union[int, Literal["auto"]] = 0 + z_focus_offset: Union[float, Literal["auto"]] = 0 index_of_refraction_media: PositiveFloat = 1.3 numerical_aperture_detection: PositiveFloat = 1.2 From 51d438bc8454184b5a62b8518e711c83cf0747dd Mon Sep 17 00:00:00 2001 From: Talon Chandler Date: Tue, 2 Sep 2025 11:07:56 -0700 Subject: [PATCH 4/5] Add sub-pixel precision to focus_from_transverse_band MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Add enable_subpixel_precision parameter (default False for backward compatibility) - Use polynomial derivative analysis to find continuous extrema when enabled - Return float focus indices when sub-pixel precision is enabled - Update plotting function to handle float indices via interpolation - Enhance docstring with new parameter and return type information This enables more accurate focus detection for coarsely sampled data by finding focus positions between discrete slice indices. Addresses issue #470. 🤖 Generated with [Claude Code](https://claude.ai/code) Co-Authored-By: Claude --- waveorder/focus.py | 62 +++++++++++++++++++++++++++++++++++++++++----- 1 file changed, 56 insertions(+), 6 deletions(-) diff --git a/waveorder/focus.py b/waveorder/focus.py index 38fcb521..c4401d8d 100644 --- a/waveorder/focus.py +++ b/waveorder/focus.py @@ -60,6 +60,7 @@ def focus_from_transverse_band( polynomial_fit_order: Optional[int] = None, plot_path: Optional[str] = None, threshold_FWHM: float = 0, + enable_subpixel_precision: bool = False, ): """Estimates the in-focus slice from a 3D stack by optimizing a transverse spatial frequency band. @@ -91,12 +92,16 @@ def focus_from_transverse_band( The default value, 0, applies no threshold, and the maximum midband power is always considered in focus. For values > 0, the peak's FWHM must be greater than the threshold for the slice to be considered in focus. If the peak does not meet this threshold, the function returns None. + enable_subpixel_precision: bool, optional + If True and polynomial_fit_order is provided, enables sub-pixel precision focus detection + by finding the continuous extremum of the polynomial fit. Default is False for backward compatibility. Returns - ------ - slice : int or None + ------- + slice : int, float, or None If peak's FWHM > peak_width_threshold: - return the index of the in-focus slice + return the index of the in-focus slice (int if enable_subpixel_precision=False, + float if enable_subpixel_precision=True and polynomial_fit_order is not None) else: return None @@ -140,9 +145,44 @@ def focus_from_transverse_band( else: x = np.arange(len(midband_sum)) coeffs = np.polyfit(x, midband_sum, polynomial_fit_order) - peak_index = minmaxfunc(np.poly1d(coeffs)(x)) + poly_func = np.poly1d(coeffs) + + if enable_subpixel_precision: + # Find the continuous extremum using derivative + poly_deriv = np.polyder(coeffs) + # Find roots of the derivative (critical points) + critical_points = np.roots(poly_deriv) + + # Filter for real roots within the data range + real_critical_points = [] + for cp in critical_points: + if np.isreal(cp) and 0 <= cp.real < len(midband_sum): + real_critical_points.append(cp.real) + + if real_critical_points: + # Evaluate the polynomial at critical points to find extremum + critical_values = [ + poly_func(cp) for cp in real_critical_points + ] + if mode == "max": + best_idx = np.argmax(critical_values) + else: # mode == "min" + best_idx = np.argmin(critical_values) + peak_index = real_critical_points[best_idx] + else: + # Fall back to discrete maximum if no valid critical points + peak_index = float(minmaxfunc(poly_func(x))) + else: + peak_index = minmaxfunc(poly_func(x)) - peak_results = peak_widths(midband_sum, [peak_index]) + # For peak width calculation, use integer peak index + if enable_subpixel_precision and polynomial_fit_order is not None: + # Use the closest integer index for peak width calculation + integer_peak_index = int(np.round(peak_index)) + else: + integer_peak_index = int(peak_index) + + peak_results = peak_widths(midband_sum, [integer_peak_index]) peak_FWHM = peak_results[0][0] if peak_FWHM >= threshold_FWHM: @@ -215,9 +255,19 @@ def _plot_focus_metric( ): _, ax = plt.subplots(1, 1, figsize=(4, 4)) ax.plot(midband_sum, "-k") + + # Handle floating-point peak_index for plotting + if isinstance(peak_index, float) and not peak_index.is_integer(): + # Use interpolation to get the y-value at the floating-point x-position + peak_y_value = np.interp( + peak_index, np.arange(len(midband_sum)), midband_sum + ) + else: + peak_y_value = midband_sum[int(peak_index)] + ax.plot( peak_index, - midband_sum[peak_index], + peak_y_value, "go" if in_focus_index is not None else "ro", ) ax.hlines(*peak_results[1:], color="k", linestyles="dashed") From 3af6422a2c4cbf79686f0cc07215500482c4621d Mon Sep 17 00:00:00 2001 From: Talon Chandler Date: Tue, 2 Sep 2025 11:08:40 -0700 Subject: [PATCH 5/5] Add comprehensive tests for non-integer focus support MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - test_subpixel_precision: Validates float focus detection with synthetic data - test_subpixel_precision_backward_compatibility: Ensures default behavior unchanged - test_subpixel_precision_with_plotting: Tests plotting with float indices - test_z_focus_offset_float_type: Validates settings accept float z_focus_offset - test_position_list_with_float_offset: Tests position calculation pipeline All tests verify both functionality and backward compatibility. Ensures robust implementation of issue #470 requirements. 🤖 Generated with [Claude Code](https://claude.ai/code) Co-Authored-By: Claude --- tests/test_focus_estimator.py | 153 ++++++++++++++++++++++++++++++++++ 1 file changed, 153 insertions(+) diff --git a/tests/test_focus_estimator.py b/tests/test_focus_estimator.py index 8c093195..2460917c 100644 --- a/tests/test_focus_estimator.py +++ b/tests/test_focus_estimator.py @@ -156,3 +156,156 @@ def test_compute_midband_power_consistency(): expected_focus_slice = np.argmax(manual_powers) assert focus_slice == expected_focus_slice + + +def test_subpixel_precision(): + """Test that sub-pixel precision returns float values when enabled.""" + # Test parameters + ps = 6.5 / 100 + lambda_ill = 0.532 + NA_det = 1.4 + + # Create synthetic test data with a clear peak between slices + z_size, y_size, x_size = 11, 64, 64 + x = np.linspace(-1, 1, x_size) + y = np.linspace(-1, 1, y_size) + z = np.linspace(-5, 5, z_size) + + # Create a 3D Gaussian that peaks between slice indices + test_data = np.zeros((z_size, y_size, x_size)) + true_peak_z = 5.3 # Peak between slices 5 and 6 + + for i, z_val in enumerate(z): + # Create Gaussian centered at true_peak_z position in physical space + gaussian_2d = np.exp( + -( + (x[None, :] ** 2 + y[:, None] ** 2) + + (z_val - (true_peak_z - 5)) ** 2 + ) + ) + test_data[i] = gaussian_2d + + # Test without sub-pixel precision (should return integer) + focus_slice_int = focus.focus_from_transverse_band( + test_data, + NA_det, + lambda_ill, + ps, + polynomial_fit_order=4, + enable_subpixel_precision=False, + ) + assert isinstance(focus_slice_int, (int, np.integer)) + + # Test with sub-pixel precision (should return float) + focus_slice_float = focus.focus_from_transverse_band( + test_data, + NA_det, + lambda_ill, + ps, + polynomial_fit_order=4, + enable_subpixel_precision=True, + ) + + # Should return a float + assert isinstance(focus_slice_float, float) + + # Should be close to the true peak position + assert abs(focus_slice_float - true_peak_z) < 1.0 # Within 1 slice + + # Sub-pixel result should be different from integer result + assert focus_slice_float != focus_slice_int + + +def test_subpixel_precision_backward_compatibility(): + """Test that default behavior (integer results) is preserved.""" + ps = 6.5 / 100 + lambda_ill = 0.532 + NA_det = 1.4 + + # Create simple test data + test_data = np.random.random((5, 32, 32)).astype(np.float32) + + # Test default behavior (should return integer) + focus_slice = focus.focus_from_transverse_band( + test_data, + NA_det, + lambda_ill, + ps, + polynomial_fit_order=4, + ) + + assert isinstance(focus_slice, (int, np.integer)) + + +def test_subpixel_precision_with_plotting(tmp_path): + """Test that sub-pixel precision works with plotting.""" + ps = 6.5 / 100 + lambda_ill = 0.532 + NA_det = 1.4 + + # Create test data + test_data = np.random.random((7, 32, 32)).astype(np.float32) + plot_path = tmp_path / "subpixel_test.pdf" + + # Should work without errors + focus_slice = focus.focus_from_transverse_band( + test_data, + NA_det, + lambda_ill, + ps, + polynomial_fit_order=4, + enable_subpixel_precision=True, + plot_path=str(plot_path), + ) + + assert isinstance(focus_slice, float) + assert plot_path.exists() + + +def test_z_focus_offset_float_type(): + """Test that z_focus_offset can accept float values in settings.""" + from waveorder.cli.settings import FourierTransferFunctionSettings + + # Test that float values are accepted + settings = FourierTransferFunctionSettings(z_focus_offset=1.5) + assert settings.z_focus_offset == 1.5 + assert isinstance(settings.z_focus_offset, float) + + # Test that "auto" still works + settings_auto = FourierTransferFunctionSettings(z_focus_offset="auto") + assert settings_auto.z_focus_offset == "auto" + + # Test that integers are converted to float + settings_int = FourierTransferFunctionSettings(z_focus_offset=2) + assert settings_int.z_focus_offset == 2 + assert isinstance(settings_int.z_focus_offset, (int, float)) + + +def test_position_list_with_float_offset(): + """Test that _position_list_from_shape_scale_offset works correctly with float offsets.""" + from waveorder.cli.compute_transfer_function import ( + _position_list_from_shape_scale_offset, + ) + + # Test integer offset + pos_int = _position_list_from_shape_scale_offset(5, 1.0, 0) + expected_int = [2.0, 1.0, 0.0, -1.0, -2.0] + assert pos_int == expected_int + + # Test float offset + pos_float = _position_list_from_shape_scale_offset(5, 1.0, 0.5) + expected_float = [2.5, 1.5, 0.5, -0.5, -1.5] + assert pos_float == expected_float + + # Verify the difference is exactly the offset + import numpy as np + + diff = np.array(pos_float) - np.array(pos_int) + assert np.allclose(diff, 0.5) + + # Test with different scale and offset + pos_scaled = _position_list_from_shape_scale_offset(4, 2.0, 0.3) + # shape=4, shape//2=2, so indices are [0,1,2,3], + # positions are [(-0+2+0.3)*2, (-1+2+0.3)*2, (-2+2+0.3)*2, (-3+2+0.3)*2] = [4.6, 2.6, 0.6, -1.4] + expected_scaled = [4.6, 2.6, 0.6, -1.4] + assert np.allclose(pos_scaled, expected_scaled)