Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
153 changes: 153 additions & 0 deletions tests/test_focus_estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,3 +156,156 @@ def test_compute_midband_power_consistency():

expected_focus_slice = np.argmax(manual_powers)
assert focus_slice == expected_focus_slice


def test_subpixel_precision():
"""Test that sub-pixel precision returns float values when enabled."""
# Test parameters
ps = 6.5 / 100
lambda_ill = 0.532
NA_det = 1.4

# Create synthetic test data with a clear peak between slices
z_size, y_size, x_size = 11, 64, 64
x = np.linspace(-1, 1, x_size)
y = np.linspace(-1, 1, y_size)
z = np.linspace(-5, 5, z_size)

# Create a 3D Gaussian that peaks between slice indices
test_data = np.zeros((z_size, y_size, x_size))
true_peak_z = 5.3 # Peak between slices 5 and 6

for i, z_val in enumerate(z):
# Create Gaussian centered at true_peak_z position in physical space
gaussian_2d = np.exp(
-(
(x[None, :] ** 2 + y[:, None] ** 2)
+ (z_val - (true_peak_z - 5)) ** 2
)
)
test_data[i] = gaussian_2d

# Test without sub-pixel precision (should return integer)
focus_slice_int = focus.focus_from_transverse_band(
test_data,
NA_det,
lambda_ill,
ps,
polynomial_fit_order=4,
enable_subpixel_precision=False,
)
assert isinstance(focus_slice_int, (int, np.integer))

# Test with sub-pixel precision (should return float)
focus_slice_float = focus.focus_from_transverse_band(
test_data,
NA_det,
lambda_ill,
ps,
polynomial_fit_order=4,
enable_subpixel_precision=True,
)

# Should return a float
assert isinstance(focus_slice_float, float)

# Should be close to the true peak position
assert abs(focus_slice_float - true_peak_z) < 1.0 # Within 1 slice

# Sub-pixel result should be different from integer result
assert focus_slice_float != focus_slice_int


def test_subpixel_precision_backward_compatibility():
"""Test that default behavior (integer results) is preserved."""
ps = 6.5 / 100
lambda_ill = 0.532
NA_det = 1.4

# Create simple test data
test_data = np.random.random((5, 32, 32)).astype(np.float32)

# Test default behavior (should return integer)
focus_slice = focus.focus_from_transverse_band(
test_data,
NA_det,
lambda_ill,
ps,
polynomial_fit_order=4,
)

assert isinstance(focus_slice, (int, np.integer))


def test_subpixel_precision_with_plotting(tmp_path):
"""Test that sub-pixel precision works with plotting."""
ps = 6.5 / 100
lambda_ill = 0.532
NA_det = 1.4

# Create test data
test_data = np.random.random((7, 32, 32)).astype(np.float32)
plot_path = tmp_path / "subpixel_test.pdf"

# Should work without errors
focus_slice = focus.focus_from_transverse_band(
test_data,
NA_det,
lambda_ill,
ps,
polynomial_fit_order=4,
enable_subpixel_precision=True,
plot_path=str(plot_path),
)

assert isinstance(focus_slice, float)
assert plot_path.exists()


def test_z_focus_offset_float_type():
"""Test that z_focus_offset can accept float values in settings."""
from waveorder.cli.settings import FourierTransferFunctionSettings

# Test that float values are accepted
settings = FourierTransferFunctionSettings(z_focus_offset=1.5)
assert settings.z_focus_offset == 1.5
assert isinstance(settings.z_focus_offset, float)

# Test that "auto" still works
settings_auto = FourierTransferFunctionSettings(z_focus_offset="auto")
assert settings_auto.z_focus_offset == "auto"

# Test that integers are converted to float
settings_int = FourierTransferFunctionSettings(z_focus_offset=2)
assert settings_int.z_focus_offset == 2
assert isinstance(settings_int.z_focus_offset, (int, float))


def test_position_list_with_float_offset():
"""Test that _position_list_from_shape_scale_offset works correctly with float offsets."""
from waveorder.cli.compute_transfer_function import (
_position_list_from_shape_scale_offset,
)

# Test integer offset
pos_int = _position_list_from_shape_scale_offset(5, 1.0, 0)
expected_int = [2.0, 1.0, 0.0, -1.0, -2.0]
assert pos_int == expected_int

# Test float offset
pos_float = _position_list_from_shape_scale_offset(5, 1.0, 0.5)
expected_float = [2.5, 1.5, 0.5, -0.5, -1.5]
assert pos_float == expected_float

# Verify the difference is exactly the offset
import numpy as np

diff = np.array(pos_float) - np.array(pos_int)
assert np.allclose(diff, 0.5)

# Test with different scale and offset
pos_scaled = _position_list_from_shape_scale_offset(4, 2.0, 0.3)
# shape=4, shape//2=2, so indices are [0,1,2,3],
# positions are [(-0+2+0.3)*2, (-1+2+0.3)*2, (-2+2+0.3)*2, (-3+2+0.3)*2] = [4.6, 2.6, 0.6, -1.4]
expected_scaled = [4.6, 2.6, 0.6, -1.4]
assert np.allclose(pos_scaled, expected_scaled)
2 changes: 1 addition & 1 deletion waveorder/cli/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ class FourierTransferFunctionSettings(MyBaseModel):
yx_pixel_size: PositiveFloat = 6.5 / 20
z_pixel_size: PositiveFloat = 2.0
z_padding: NonNegativeInt = 0
z_focus_offset: Union[int, Literal["auto"]] = 0
z_focus_offset: Union[float, Literal["auto"]] = 0
index_of_refraction_media: PositiveFloat = 1.3
numerical_aperture_detection: PositiveFloat = 1.2

Expand Down
62 changes: 56 additions & 6 deletions waveorder/focus.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ def focus_from_transverse_band(
polynomial_fit_order: Optional[int] = None,
plot_path: Optional[str] = None,
threshold_FWHM: float = 0,
enable_subpixel_precision: bool = False,
):
"""Estimates the in-focus slice from a 3D stack by optimizing a transverse spatial frequency band.

Expand Down Expand Up @@ -91,12 +92,16 @@ def focus_from_transverse_band(
The default value, 0, applies no threshold, and the maximum midband power is always considered in focus.
For values > 0, the peak's FWHM must be greater than the threshold for the slice to be considered in focus.
If the peak does not meet this threshold, the function returns None.
enable_subpixel_precision: bool, optional
If True and polynomial_fit_order is provided, enables sub-pixel precision focus detection
by finding the continuous extremum of the polynomial fit. Default is False for backward compatibility.

Returns
------
slice : int or None
-------
slice : int, float, or None
If peak's FWHM > peak_width_threshold:
return the index of the in-focus slice
return the index of the in-focus slice (int if enable_subpixel_precision=False,
float if enable_subpixel_precision=True and polynomial_fit_order is not None)
else:
return None

Expand Down Expand Up @@ -140,9 +145,44 @@ def focus_from_transverse_band(
else:
x = np.arange(len(midband_sum))
coeffs = np.polyfit(x, midband_sum, polynomial_fit_order)
peak_index = minmaxfunc(np.poly1d(coeffs)(x))
poly_func = np.poly1d(coeffs)

if enable_subpixel_precision:
# Find the continuous extremum using derivative
poly_deriv = np.polyder(coeffs)
# Find roots of the derivative (critical points)
critical_points = np.roots(poly_deriv)

# Filter for real roots within the data range
real_critical_points = []
for cp in critical_points:
if np.isreal(cp) and 0 <= cp.real < len(midband_sum):
real_critical_points.append(cp.real)

if real_critical_points:
# Evaluate the polynomial at critical points to find extremum
critical_values = [
poly_func(cp) for cp in real_critical_points
]
if mode == "max":
best_idx = np.argmax(critical_values)
else: # mode == "min"
best_idx = np.argmin(critical_values)
peak_index = real_critical_points[best_idx]
else:
# Fall back to discrete maximum if no valid critical points
peak_index = float(minmaxfunc(poly_func(x)))
else:
peak_index = minmaxfunc(poly_func(x))

peak_results = peak_widths(midband_sum, [peak_index])
# For peak width calculation, use integer peak index
if enable_subpixel_precision and polynomial_fit_order is not None:
# Use the closest integer index for peak width calculation
integer_peak_index = int(np.round(peak_index))
else:
integer_peak_index = int(peak_index)

peak_results = peak_widths(midband_sum, [integer_peak_index])
peak_FWHM = peak_results[0][0]

if peak_FWHM >= threshold_FWHM:
Expand Down Expand Up @@ -215,9 +255,19 @@ def _plot_focus_metric(
):
_, ax = plt.subplots(1, 1, figsize=(4, 4))
ax.plot(midband_sum, "-k")

# Handle floating-point peak_index for plotting
if isinstance(peak_index, float) and not peak_index.is_integer():
# Use interpolation to get the y-value at the floating-point x-position
peak_y_value = np.interp(
peak_index, np.arange(len(midband_sum)), midband_sum
)
else:
peak_y_value = midband_sum[int(peak_index)]

ax.plot(
peak_index,
midband_sum[peak_index],
peak_y_value,
"go" if in_focus_index is not None else "ro",
)
ax.hlines(*peak_results[1:], color="k", linestyles="dashed")
Expand Down
Loading