Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 34 additions & 14 deletions src/ddfacet_kernels/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -361,7 +361,7 @@ def reorganise_convolution_filter(cf: npt.NDArray, oversampling: int) -> npt.NDA
for j in range(oversampling):
result[i, j, :, :] = cf[i::oversampling, j::oversampling]

return result.reshape(cf.shape)
return result


def find_max_support(radius: float, maxw: float, min_wave: float) -> int:
Expand Down Expand Up @@ -436,17 +436,13 @@ class FacetWKernelData:
# U′= U − W · l0
# V′= V − W · m0
clm: Tuple[float, float]
# The kernel support
support: int
# The kernel oversampling factor
oversampling: int
# W values for each plane
w_values: npt.NDArray[np.float64]
# Flattened W Kernels for each W plane
# Can be reshaped to (oversampling, oversampling, support, support)
# W Kernels for each W plane
# with shape (oversampling, oversampling, support, support)
w_kernels: List[npt.NDArray[np.complex64]]
# Flattened Conjugate W kernels for each W plane
# Can be reshaped to (oversampling, oversampling, support, support)
# with shape (oversampling, oversampling, support, support)
w_kernels_conj: List[npt.NDArray[np.complex64]]

@property
Expand Down Expand Up @@ -474,6 +470,34 @@ def nwplanes(self) -> int:
"""Number of w planes"""
return len(self.w_kernels)

@property
def oversampling(self) -> int:
"""Oversampling factor"""
assert len(self.w_kernels) > 0
return self.w_kernels[0].shape[0] # (o, o, s, s)

@property
def support(self) -> int:
"""Kernel Support"""
assert len(self.w_kernels) > 0
return self.w_kernels[0].shape[2] # (o, o, s, s)

@property
def w_kernels_ravel(self) -> List[npt.NDArray[np.complex64]]:
"""Return flattened versions of the w kernels"""
return [self._reshape_kernel(wk) for wk in self.w_kernels]

@property
def w_kernels_conj_ravel(self) -> List[npt.NDArray[np.complex64]]:
"""Return flattened versions of the conjugate w kernels"""
return [self._reshape_kernel(wkc) for wkc in self.w_kernels_conj]

def _reshape_kernel(
self, kernel: npt.NDArray[np.complex64]
) -> npt.NDArray[np.complex64]:
o, _, s, _ = kernel.shape
return kernel.reshape((o * s, o * s))


def facet_w_kernels(
nwplanes: int,
Expand Down Expand Up @@ -547,9 +571,7 @@ def facet_w_kernels(
fzw = np.require(fzw, dtype=np.complex64, requirements=["A", "C"])
fzw_conj = np.require(fzw_conj, dtype=np.complex64, requirements=["A", "C"])

return FacetWKernelData(
(l0, m0), (cl, cm), support, oversampling, w_values, [fzw], [fzw_conj]
)
return FacetWKernelData((l0, m0), (cl, cm), w_values, [fzw], [fzw_conj])

wkernels = []
wkernels_conj = []
Expand Down Expand Up @@ -591,6 +613,4 @@ def facet_w_kernels(
wkernels.append(fzw)
wkernels_conj.append(fzw_conj)

return FacetWKernelData(
(l0, m0), (cl, cm), support, oversampling, w_values, wkernels, wkernels_conj
)
return FacetWKernelData((l0, m0), (cl, cm), w_values, wkernels, wkernels_conj)
6 changes: 4 additions & 2 deletions tests/test_kernels.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,8 +80,10 @@ def test_ddfacet_allclose(ddf_wkernel_data):
assert_array_almost_equal(wkernel_data.support, kw["Sup"], decimal=7)

assert len(wkernel_data.w_kernels) == len(ddf_wkernel_data["WPlanes"])
for this, ddf in zip(wkernel_data.w_kernels, ddf_wkernel_data["WPlanes"]):
for this, ddf in zip(wkernel_data.w_kernels_ravel, ddf_wkernel_data["WPlanes"]):
assert_array_almost_equal(this, ddf, decimal=7)

for this, ddf in zip(wkernel_data.w_kernels_conj, ddf_wkernel_data["WPlanesConj"]):
for this, ddf in zip(
wkernel_data.w_kernels_conj_ravel, ddf_wkernel_data["WPlanesConj"]
):
assert_array_almost_equal(this, ddf, decimal=7)