diff --git a/src/ddfacet_kernels/__init__.py b/src/ddfacet_kernels/__init__.py index 0e958b9..1eaed62 100644 --- a/src/ddfacet_kernels/__init__.py +++ b/src/ddfacet_kernels/__init__.py @@ -361,7 +361,7 @@ def reorganise_convolution_filter(cf: npt.NDArray, oversampling: int) -> npt.NDA for j in range(oversampling): result[i, j, :, :] = cf[i::oversampling, j::oversampling] - return result.reshape(cf.shape) + return result def find_max_support(radius: float, maxw: float, min_wave: float) -> int: @@ -436,17 +436,13 @@ class FacetWKernelData: # U′= U − W · l0 # V′= V − W · m0 clm: Tuple[float, float] - # The kernel support - support: int - # The kernel oversampling factor - oversampling: int # W values for each plane w_values: npt.NDArray[np.float64] - # Flattened W Kernels for each W plane - # Can be reshaped to (oversampling, oversampling, support, support) + # W Kernels for each W plane + # with shape (oversampling, oversampling, support, support) w_kernels: List[npt.NDArray[np.complex64]] # Flattened Conjugate W kernels for each W plane - # Can be reshaped to (oversampling, oversampling, support, support) + # with shape (oversampling, oversampling, support, support) w_kernels_conj: List[npt.NDArray[np.complex64]] @property @@ -474,6 +470,34 @@ def nwplanes(self) -> int: """Number of w planes""" return len(self.w_kernels) + @property + def oversampling(self) -> int: + """Oversampling factor""" + assert len(self.w_kernels) > 0 + return self.w_kernels[0].shape[0] # (o, o, s, s) + + @property + def support(self) -> int: + """Kernel Support""" + assert len(self.w_kernels) > 0 + return self.w_kernels[0].shape[2] # (o, o, s, s) + + @property + def w_kernels_ravel(self) -> List[npt.NDArray[np.complex64]]: + """Return flattened versions of the w kernels""" + return [self._reshape_kernel(wk) for wk in self.w_kernels] + + @property + def w_kernels_conj_ravel(self) -> List[npt.NDArray[np.complex64]]: + """Return flattened versions of the conjugate w kernels""" + return [self._reshape_kernel(wkc) for wkc in self.w_kernels_conj] + + def _reshape_kernel( + self, kernel: npt.NDArray[np.complex64] + ) -> npt.NDArray[np.complex64]: + o, _, s, _ = kernel.shape + return kernel.reshape((o * s, o * s)) + def facet_w_kernels( nwplanes: int, @@ -547,9 +571,7 @@ def facet_w_kernels( fzw = np.require(fzw, dtype=np.complex64, requirements=["A", "C"]) fzw_conj = np.require(fzw_conj, dtype=np.complex64, requirements=["A", "C"]) - return FacetWKernelData( - (l0, m0), (cl, cm), support, oversampling, w_values, [fzw], [fzw_conj] - ) + return FacetWKernelData((l0, m0), (cl, cm), w_values, [fzw], [fzw_conj]) wkernels = [] wkernels_conj = [] @@ -591,6 +613,4 @@ def facet_w_kernels( wkernels.append(fzw) wkernels_conj.append(fzw_conj) - return FacetWKernelData( - (l0, m0), (cl, cm), support, oversampling, w_values, wkernels, wkernels_conj - ) + return FacetWKernelData((l0, m0), (cl, cm), w_values, wkernels, wkernels_conj) diff --git a/tests/test_kernels.py b/tests/test_kernels.py index e0f91ab..38b4e45 100644 --- a/tests/test_kernels.py +++ b/tests/test_kernels.py @@ -80,8 +80,10 @@ def test_ddfacet_allclose(ddf_wkernel_data): assert_array_almost_equal(wkernel_data.support, kw["Sup"], decimal=7) assert len(wkernel_data.w_kernels) == len(ddf_wkernel_data["WPlanes"]) - for this, ddf in zip(wkernel_data.w_kernels, ddf_wkernel_data["WPlanes"]): + for this, ddf in zip(wkernel_data.w_kernels_ravel, ddf_wkernel_data["WPlanes"]): assert_array_almost_equal(this, ddf, decimal=7) - for this, ddf in zip(wkernel_data.w_kernels_conj, ddf_wkernel_data["WPlanesConj"]): + for this, ddf in zip( + wkernel_data.w_kernels_conj_ravel, ddf_wkernel_data["WPlanesConj"] + ): assert_array_almost_equal(this, ddf, decimal=7)