Skip to content

Commit b9e8259

Browse files
committed
Mid-refactor
1 parent 2e729e5 commit b9e8259

File tree

1 file changed

+103
-94
lines changed

1 file changed

+103
-94
lines changed

python/lsst/pipe/tasks/brightStarSubtraction/brightStarCutout.py

Lines changed: 103 additions & 94 deletions
Original file line numberDiff line numberDiff line change
@@ -73,34 +73,34 @@ class BrightStarCutoutConnections(
7373
):
7474
"""Connections for BrightStarCutoutTask."""
7575

76-
refCat = PrerequisiteInput(
76+
ref_cat = PrerequisiteInput(
7777
name="gaia_dr3_20230707",
7878
storageClass="SimpleCatalog",
7979
doc="Reference catalog that contains bright star positions.",
8080
dimensions=("skypix",),
8181
multiple=True,
8282
deferLoad=True,
8383
)
84-
inputExposure = Input(
85-
name="calexp",
84+
input_image = Input(
85+
name="preliminary_visit_image",
8686
storageClass="ExposureF",
8787
doc="Background-subtracted input exposure from which to extract bright star stamp cutouts.",
8888
dimensions=("visit", "detector"),
8989
)
90-
inputBackground = Input(
91-
name="calexpBackground",
90+
input_background = Input(
91+
name="preliminary_visit_image_background",
9292
storageClass="Background",
9393
doc="Background model for the input exposure, to be added back on during processing.",
9494
dimensions=("visit", "detector"),
9595
)
96-
extendedPsf = Input(
97-
name="extendedPsf2",
96+
extended_psf = Input(
97+
name="extended_psf",
9898
storageClass="ImageF",
9999
doc="Extended PSF model, built from stacking bright star cutouts.",
100100
dimensions=("band",),
101101
)
102-
brightStarStamps = Output(
103-
name="brightStarStamps",
102+
bright_star_stamps = Output(
103+
name="bright_star_stamps",
104104
storageClass="BrightStarStamps",
105105
doc="Set of preprocessed postage stamp cutouts, each centered on a single bright star.",
106106
dimensions=("visit", "detector"),
@@ -110,7 +110,7 @@ def __init__(self, *, config: "BrightStarCutoutConfig | None" = None):
110110
super().__init__(config=config)
111111
assert config is not None
112112
if not config.useExtendedPsf:
113-
self.inputs.remove("extendedPsf")
113+
self.inputs.remove("extended_psf")
114114

115115

116116
class BrightStarCutoutConfig(
@@ -124,18 +124,22 @@ class BrightStarCutoutConfig(
124124
doc="Magnitude range in Gaia G. Cutouts will be made for all stars in this range.",
125125
default=[0, 18],
126126
)
127-
excludeArcsecRadius = Field[float](
128-
doc="Stars with a star in the range ``excludeMagRange`` mag in ``excludeArcsecRadius`` are not used.",
127+
excludeRadiusArcsec = Field[float](
128+
doc="Stars with a star in the range ``excludeMagRange`` mag in ``excludeRadiusArcsec`` are not used.",
129129
default=5,
130130
)
131131
excludeMagRange = ListField[float](
132-
doc="Stars with a star in the range ``excludeMagRange`` mag in ``excludeArcsecRadius`` are not used.",
132+
doc="Stars with a star in the range ``excludeMagRange`` mag in ``excludeRadiusArcsec`` are not used.",
133133
default=[0, 20],
134134
)
135135
minAreaFraction = Field[float](
136136
doc="Minimum fraction of the stamp area, post-masking, that must remain for a cutout to be retained.",
137137
default=0.1,
138138
)
139+
# offFrameMagLim = Field[float](
140+
# doc="Stars fainter than this limit are only included if they appear within the frame boundaries.",
141+
# default=15.0,
142+
# )
139143
badMaskPlanes = ListField[str](
140144
doc="Mask planes that identify excluded pixels for the calculation of ``minAreaFraction`` and, "
141145
"optionally, fitting of the PSF.",
@@ -151,6 +155,8 @@ class BrightStarCutoutConfig(
151155
NEIGHBOR_MASK_PLANE,
152156
],
153157
)
158+
159+
# Stamp configuration
154160
stampSize = ListField[int](
155161
doc="Size of the stamps to be extracted, in pixels.",
156162
default=(251, 251),
@@ -179,18 +185,18 @@ class BrightStarCutoutConfig(
179185
"lanczos5": "Lanczos kernel of order 5",
180186
},
181187
)
182-
scalePsfModel = Field[bool](
183-
doc="If True, uses a scale factor to bring the PSF model data to the same level as the star data.",
184-
default=True,
185-
)
188+
# scalePsfModel = Field[bool](
189+
# doc="If True, uses a scale factor to bring the PSF model data to the same level as the star data.",
190+
# default=True,
191+
# )
186192

187193
# PSF Fitting
188194
useExtendedPsf = Field[bool](
189-
doc="Use the extended PSF model to normalize bright star cutouts.",
195+
doc="Use the extended PSF model to estimate the bright star cutout normalization factor.",
190196
default=False,
191197
)
192198
doFitPsf = Field[bool](
193-
doc="Fit a scaled PSF and a pedestal to each bright star cutout.",
199+
doc="Fit a scaled PSF and a simple background to each bright star cutout.",
194200
default=True,
195201
)
196202
useMedianVariance = Field[bool](
@@ -202,13 +208,9 @@ class BrightStarCutoutConfig(
202208
default=0.97,
203209
)
204210
fitIterations = Field[int](
205-
doc="Number of iterations over pedestal-gradient and scaling fit.",
211+
doc="Number of iterations to constrain PSF fitting.",
206212
default=5,
207213
)
208-
offFrameMagLim = Field[float](
209-
doc="Stars fainter than this limit are only included if they appear within the frame boundaries.",
210-
default=15.0,
211-
)
212214

213215
# Misc
214216
loadReferenceObjectsConfig = ConfigField[LoadReferenceObjectsConfig](
@@ -226,7 +228,7 @@ class BrightStarCutoutTask(PipelineTask):
226228
catalog and extracts a stamp around each.
227229
Second, it shifts and warps each stamp to remove optical distortions and
228230
sample all stars on the same pixel grid.
229-
Finally, it optionally fits a PSF plus plane flux model to the cutout.
231+
Finally, it optionally fits a PSF and a simple background model.
230232
This final fitting procedure may be used to normalize each bright star
231233
stamp prior to stacking when producing extended PSF models.
232234
"""
@@ -237,69 +239,72 @@ class BrightStarCutoutTask(PipelineTask):
237239

238240
def __init__(self, initInputs=None, *args, **kwargs):
239241
super().__init__(*args, **kwargs)
240-
stampSize = Extent2D(*self.config.stampSize.list())
241-
stampRadius = floor(stampSize / 2)
242-
self.stampBBox = Box2I(corner=Point2I(0, 0), dimensions=Extent2I(1, 1)).dilatedBy(stampRadius)
243-
paddedStampSize = stampSize * self.config.stampSizePadding
244-
self.paddedStampRadius = floor(paddedStampSize / 2)
245-
self.paddedStampBBox = Box2I(corner=Point2I(0, 0), dimensions=Extent2I(1, 1)).dilatedBy(
246-
self.paddedStampRadius
242+
stamp_size = Extent2D(*self.config.stampSize.list())
243+
stamp_radius = floor(stamp_size / 2)
244+
self.stamp_bbox = Box2I(corner=Point2I(0, 0), dimensions=Extent2I(1, 1)).dilatedBy(stamp_radius)
245+
padded_stamp_size = stamp_size * self.config.stampSizePadding
246+
self.padded_stamp_radius = floor(padded_stamp_size / 2)
247+
self.padded_stamp_bbox = Box2I(corner=Point2I(0, 0), dimensions=Extent2I(1, 1)).dilatedBy(
248+
self.padded_stamp_radius
247249
)
248-
self.modelScale = 1
249-
250-
def runQuantum(self, butlerQC, inputRefs, outputRefs):
251-
inputs = butlerQC.get(inputRefs)
252-
inputs["dataId"] = butlerQC.quantum.dataId
253-
refObjLoader = ReferenceObjectLoader(
254-
dataIds=[ref.datasetRef.dataId for ref in inputRefs.refCat],
255-
refCats=inputs.pop("refCat"),
256-
name=self.config.connections.refCat,
250+
# self.modelScale = 1
251+
252+
def runQuantum(self, butlerQC, input_refs, output_refs):
253+
inputs = butlerQC.get(input_refs)
254+
inputs["data_id"] = butlerQC.quantum.dataId
255+
ref_obj_loader = ReferenceObjectLoader(
256+
dataIds=[ref.datasetRef.dataId for ref in input_refs.ref_cat],
257+
refCats=inputs.pop("ref_cat"),
258+
name=self.config.connections.ref_cat,
257259
config=self.config.loadReferenceObjectsConfig,
258260
)
259-
extendedPsf = inputs.pop("extendedPsf", None)
260-
output = self.run(**inputs, extendedPsf=extendedPsf, refObjLoader=refObjLoader)
261+
extended_psf = inputs.pop("extended_psf", None)
262+
output = self.run(**inputs, extended_psf=extended_psf, ref_obj_loader=ref_obj_loader)
261263
# Only ingest Stamp if it exists; prevents ingesting an empty FITS file
262264
if output:
263-
butlerQC.put(output, outputRefs)
265+
butlerQC.put(output, output_refs)
264266

265267
@timeMethod
266268
def run(
267269
self,
268-
inputExposure: ExposureF,
269-
inputBackground: BackgroundList,
270-
extendedPsf: ImageF | None,
271-
refObjLoader: ReferenceObjectLoader,
272-
dataId: dict[str, Any] | DataCoordinate,
270+
input_image: ExposureF,
271+
input_background: BackgroundList,
272+
extended_psf: ImageF | None,
273+
ref_obj_loader: ReferenceObjectLoader,
274+
data_id: dict[str, Any] | DataCoordinate,
273275
):
274276
"""Identify bright stars within an exposure using a reference catalog,
275277
extract stamps around each, warp/shift stamps onto a common frame and
276278
then optionally fit a PSF plus plane model.
277279
278280
Parameters
279281
----------
280-
inputExposure : `~lsst.afw.image.ExposureF`
281-
The background-subtracted image to extract bright star stamps.
282-
inputBackground : `~lsst.afw.math.BackgroundList`
282+
input_image : `~lsst.afw.image.ExposureF`
283+
The background-subtracted image to extract bright star stamps from.
284+
input_background : `~lsst.afw.math.BackgroundList`
283285
The background model associated with the input exposure.
284-
refObjLoader : `~lsst.meas.algorithms.ReferenceObjectLoader`, optional
286+
extended_psf : `~lsst.afw.image.ImageF` | `None`
287+
The extended PSF model, built from stacking bright star cutouts.
288+
ref_obj_loader :
289+
`~lsst.meas.algorithms.ReferenceObjectLoader`, optional
285290
Loader to find objects within a reference catalog.
286-
dataId : `dict` or `~lsst.daf.butler.DataCoordinate`
287-
The dataId of the exposure that bright stars are extracted from.
291+
data_id : `dict` or `~lsst.daf.butler.DataCoordinate`
292+
The data ID of the detector that bright stars are extracted from.
288293
Both 'visit' and 'detector' will be persisted in the output data.
289294
290295
Returns
291296
-------
292-
brightStarResults : `~lsst.pipe.base.Struct`
297+
bright_star_stamps_results : `~lsst.pipe.base.Struct`
293298
Results as a struct with attributes:
294299
295-
``brightStarStamps``
300+
``bright_star_stamps``
296301
(`~lsst.meas.algorithms.brightStarStamps.BrightStarStamps`)
297302
"""
298-
wcs = inputExposure.getWcs()
299-
bbox = inputExposure.getBBox()
300-
warpingControl = WarpingControl(self.config.warpingKernelName, self.config.maskWarpingKernelName)
303+
wcs = input_image.getWcs()
304+
bbox = input_image.getBBox()
301305

302-
refCatBright = self._getRefCatBright(refObjLoader, wcs, bbox)
306+
# Get reference catalog stars
307+
ref_cat = self._get_ref_cat(ref_obj_loader, wcs, bbox)
303308
zipRaDec = zip(refCatBright["coord_ra"] * radians, refCatBright["coord_dec"] * radians)
304309
spherePoints = [SpherePoint(ra, dec) for ra, dec in zipRaDec]
305310
pixCoords = wcs.skyToPixel(spherePoints)
@@ -323,6 +328,7 @@ def run(
323328
)
324329

325330
# Loop over each bright star
331+
warpingControl = WarpingControl(self.config.warpingKernelName, self.config.maskWarpingKernelName)
326332
stamps, goodFracs, stamps_fitPsfResults = [], [], []
327333
for starIndex, (obj, pixCoord) in enumerate(zip(refCatBright, pixCoords)): # type: ignore
328334
# Excluding faint stars that are not within the frame.
@@ -454,57 +460,60 @@ def star_in_frame(self, pixCoord, inputExposureBBox):
454460
return False
455461
return True
456462

457-
def _getRefCatBright(self, refObjLoader: ReferenceObjectLoader, wcs: SkyWcs, bbox: Box2I) -> Table:
458-
"""Get a bright star subset of the reference catalog.
463+
def _get_ref_cat(self, ref_obj_loader: ReferenceObjectLoader, wcs: SkyWcs, bbox: Box2I) -> Table:
464+
"""Get a subset of the reference catalog.
459465
460466
Trim the reference catalog to only those objects within the exposure
461467
bounding box dilated by half the bright star stamp size.
462468
This ensures all stars that overlap the exposure are included.
463469
464470
Parameters
465471
----------
466-
refObjLoader : `~lsst.meas.algorithms.ReferenceObjectLoader`
472+
ref_obj_loader : `~lsst.meas.algorithms.ReferenceObjectLoader`
467473
Loader to find objects within a reference catalog.
468474
wcs : `~lsst.afw.geom.SkyWcs`
469475
World coordinate system.
470476
bbox : `~lsst.geom.Box2I`
471-
Bounding box of the exposure.
477+
Bounding box of the image.
472478
473479
Returns
474480
-------
475-
refCatBright : `~astropy.table.Table`
476-
Bright star subset of the reference catalog.
481+
ref_cat : `~astropy.table.Table`
482+
Subset of the reference catalog.
477483
"""
478-
dilatedBBox = bbox.dilatedBy(self.paddedStampRadius)
479-
withinExposure = refObjLoader.loadPixelBox(dilatedBBox, wcs, filterName="phot_g_mean")
480-
refCatFull = withinExposure.refCat
481-
fluxField: str = withinExposure.fluxField
482-
483-
proxFluxRange = sorted(((self.config.excludeMagRange * u.ABmag).to(u.nJy)).to_value())
484-
brightFluxRange = sorted(((self.config.magRange * u.ABmag).to(u.nJy)).to_value())
485-
486-
subsetStars = (refCatFull[fluxField] > np.min((proxFluxRange[0], brightFluxRange[0]))) & (
487-
refCatFull[fluxField] < np.max((proxFluxRange[1], brightFluxRange[1]))
484+
# Get all stars within a dilated bbox
485+
dilated_bbox = bbox.dilatedBy(self.padded_stamp_radius)
486+
within_dilated_bbox = ref_obj_loader.loadPixelBox(dilated_bbox, wcs, filterName="phot_g_mean")
487+
ref_cat_full = within_dilated_bbox.refCat
488+
flux_field: str = within_dilated_bbox.fluxField
489+
490+
# Trim to stars within the desired magnitude range
491+
flux_range_nearby = sorted(((self.config.excludeMagRange * u.ABmag).to(u.nJy)).to_value())
492+
flux_range_bright = sorted(((self.config.magRange * u.ABmag).to(u.nJy)).to_value())
493+
stars_magnitude_limited = (
494+
ref_cat_full[flux_field] > np.min((flux_range_nearby[0], flux_range_bright[0]))
495+
) & (ref_cat_full[flux_field] < np.max((flux_range_nearby[1], flux_range_bright[1])))
496+
ref_cat_subset = Table(
497+
ref_cat_full.extract("id", "coord_ra", "coord_dec", flux_field, where=stars_magnitude_limited)
488498
)
489-
refCatSubset = Table(refCatFull.extract("id", "coord_ra", "coord_dec", fluxField, where=subsetStars))
490-
491-
proxStars = (refCatSubset[fluxField] >= proxFluxRange[0]) & (
492-
refCatSubset[fluxField] <= proxFluxRange[1]
499+
stars_nearby = (ref_cat_subset[flux_field] >= flux_range_nearby[0]) & (
500+
ref_cat_subset[flux_field] <= flux_range_nearby[1]
493501
)
494-
brightStars = (refCatSubset[fluxField] >= brightFluxRange[0]) & (
495-
refCatSubset[fluxField] <= brightFluxRange[1]
502+
stars_bright = (ref_cat_subset[flux_field] >= flux_range_bright[0]) & (
503+
ref_cat_subset[flux_field] <= flux_range_bright[1]
496504
)
497505

498-
coords = SkyCoord(refCatSubset["coord_ra"], refCatSubset["coord_dec"], unit="rad")
499-
excludeArcsecRadius = self.config.excludeArcsecRadius * u.arcsec # type: ignore
500-
refCatBrightIsolated = []
501-
for coord in cast(Iterable[SkyCoord], coords[brightStars]):
502-
neighbors = coords[proxStars]
503-
seps = coord.separation(neighbors).to(u.arcsec)
504-
tooClose = (seps > 0) & (seps <= excludeArcsecRadius) # not self matched
505-
refCatBrightIsolated.append(not tooClose.any())
506-
507-
refCatBright = cast(Table, refCatSubset[brightStars][refCatBrightIsolated])
506+
# Exclude stars with bright enough neighbors in a specified radius
507+
coords = SkyCoord(ref_cat_subset["coord_ra"], ref_cat_subset["coord_dec"], unit="rad")
508+
exclude_radius_arcsec = self.config.excludeRadiusArcsec * u.arcsec
509+
ref_cat_bright_isolated = []
510+
for coord in cast(Iterable[SkyCoord], coords[stars_bright]):
511+
neighbors = coords[stars_nearby]
512+
separations = coord.separation(neighbors).to(u.arcsec)
513+
too_close = (separations > 0) & (separations <= exclude_radius_arcsec) # ensure not self matched
514+
ref_cat_bright_isolated.append(not too_close.any())
515+
ref_cat_bright = cast(Table, ref_cat_subset[stars_bright][ref_cat_bright_isolated])
516+
breakpoint()
508517

509518
fluxNanojansky = refCatBright[fluxField][:] * u.nJy # type: ignore
510519
refCatBright["mag"] = fluxNanojansky.to(u.ABmag).to_value() # AB magnitudes
@@ -652,7 +661,7 @@ def _fitPsf(self, stampMI: MaskedImageF, psfImage: ImageD | ImageF) -> dict[str,
652661

653662
for i in range(self.config.fitIterations):
654663
# Gradient-pedestal fitting:
655-
if i:
664+
if i > 0:
656665
# if i > 0, there should be scale factor from the previous fit iteration. Therefore, we can
657666
# remove the star using the scale factor.
658667
stamp = self.remove_star(stampMI, scale, paddedPsfImage) # noqa: F821

0 commit comments

Comments
 (0)