diff --git a/python/lsst/pipe/tasks/brightStarSubtraction/brightStarCutout.py b/python/lsst/pipe/tasks/brightStarSubtraction/brightStarCutout.py
new file mode 100644
index 000000000..2c612fe7e
--- /dev/null
+++ b/python/lsst/pipe/tasks/brightStarSubtraction/brightStarCutout.py
@@ -0,0 +1,885 @@
+# This file is part of pipe_tasks.
+#
+# Developed for the LSST Data Management System.
+# This product includes software developed by the LSST Project
+# (https://www.lsst.org).
+# See the COPYRIGHT file at the top-level directory of this distribution
+# for details of code ownership.
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program. If not, see .
+
+"""Extract bright star cutouts; normalize and warp, optionally fit the PSF."""
+
+__all__ = ["BrightStarCutoutConnections", "BrightStarCutoutConfig", "BrightStarCutoutTask"]
+
+import math
+from copy import deepcopy
+from typing import Any, Iterable, cast
+
+import astropy.units as u
+import numpy as np
+from astropy.coordinates import SkyCoord
+from astropy.table import Table
+from lsst.afw.cameraGeom import FIELD_ANGLE, FOCAL_PLANE, PIXELS
+from lsst.afw.detection import Footprint, FootprintSet, Threshold
+from lsst.afw.geom import SkyWcs, SpanSet, makeModifiedWcs
+from lsst.afw.geom.transformFactory import makeTransform
+from lsst.afw.image import ExposureF, ImageD, ImageF, MaskedImageF
+from lsst.afw.math import BackgroundList, FixedKernel, WarpingControl, warpImage
+from lsst.daf.butler import DataCoordinate
+from lsst.geom import (
+ AffineTransform,
+ Angle,
+ Box2I,
+ Extent2D,
+ Extent2I,
+ Point2D,
+ Point2I,
+ SpherePoint,
+ arcseconds,
+ floor,
+ radians,
+)
+from lsst.meas.algorithms import (
+ BrightStarStamp,
+ BrightStarStamps,
+ KernelPsf,
+ LoadReferenceObjectsConfig,
+ ReferenceObjectLoader,
+ WarpedPsf,
+)
+from lsst.pex.config import ChoiceField, ConfigField, Field, ListField
+from lsst.pipe.base import PipelineTask, PipelineTaskConfig, PipelineTaskConnections, Struct
+from lsst.pipe.base.connectionTypes import Input, Output, PrerequisiteInput
+from lsst.utils.timer import timeMethod
+
+NEIGHBOR_MASK_PLANE = "NEIGHBOR"
+
+
+class BrightStarCutoutConnections(
+ PipelineTaskConnections,
+ dimensions=("instrument", "visit", "detector"),
+):
+ """Connections for BrightStarCutoutTask."""
+
+ ref_cat = PrerequisiteInput(
+ name="gaia_dr3_20230707",
+ storageClass="SimpleCatalog",
+ doc="Reference catalog that contains bright star positions.",
+ dimensions=("skypix",),
+ multiple=True,
+ deferLoad=True,
+ )
+ input_image = Input(
+ name="preliminary_visit_image",
+ storageClass="ExposureF",
+ doc="Background-subtracted input exposure from which to extract bright star stamp cutouts.",
+ dimensions=("visit", "detector"),
+ )
+ input_background = Input(
+ name="preliminary_visit_image_background",
+ storageClass="Background",
+ doc="Background model for the input exposure, to be added back on during processing.",
+ dimensions=("visit", "detector"),
+ )
+ extended_psf = Input(
+ name="extended_psf",
+ storageClass="ImageF",
+ doc="Extended PSF model, built from stacking bright star cutouts.",
+ dimensions=("band",),
+ )
+ bright_star_stamps = Output(
+ name="bright_star_stamps",
+ storageClass="BrightStarStamps",
+ doc="Set of preprocessed postage stamp cutouts, each centered on a single bright star.",
+ dimensions=("visit", "detector"),
+ )
+
+ def __init__(self, *, config: "BrightStarCutoutConfig | None" = None):
+ super().__init__(config=config)
+ assert config is not None
+ if not config.useExtendedPsf:
+ self.inputs.remove("extended_psf")
+
+
+class BrightStarCutoutConfig(
+ PipelineTaskConfig,
+ pipelineConnections=BrightStarCutoutConnections,
+):
+ """Configuration parameters for BrightStarCutoutTask."""
+
+ # Star selection
+ magRange = ListField[float](
+ doc="Magnitude range in Gaia G. Cutouts will be made for all stars in this range.",
+ default=[0, 18],
+ )
+ excludeRadiusArcsec = Field[float](
+ doc="Stars with a star in the range ``excludeMagRange`` mag in ``excludeRadiusArcsec`` are not used.",
+ default=5,
+ )
+ excludeMagRange = ListField[float](
+ doc="Stars with a star in the range ``excludeMagRange`` mag in ``excludeRadiusArcsec`` are not used.",
+ default=[0, 20],
+ )
+ minAreaFraction = Field[float](
+ doc="Minimum fraction of the stamp area, post-masking, that must remain for a cutout to be retained.",
+ default=0.1,
+ )
+ # offFrameMagLim = Field[float](
+ # doc="Stars fainter than this limit are only included if they appear within the frame boundaries.",
+ # default=15.0,
+ # )
+ badMaskPlanes = ListField[str](
+ doc="Mask planes that identify excluded pixels for the calculation of ``minAreaFraction`` and, "
+ "optionally, fitting of the PSF.",
+ default=[
+ "BAD",
+ "CR",
+ "CROSSTALK",
+ "EDGE",
+ "NO_DATA",
+ "SAT",
+ "SUSPECT",
+ "UNMASKEDNAN",
+ NEIGHBOR_MASK_PLANE,
+ ],
+ )
+
+ # Stamp configuration
+ stampSize = ListField[int](
+ doc="Size of the stamps to be extracted, in pixels.",
+ default=(251, 251),
+ )
+ stampSizePadding = Field[float](
+ doc="Multiplicative factor applied to the cutout stamp size, to guard against post-warp data loss.",
+ default=1.1,
+ )
+ warpingKernelName = ChoiceField[str](
+ doc="Warping kernel.",
+ default="lanczos5",
+ allowed={
+ "bilinear": "bilinear interpolation",
+ "lanczos3": "Lanczos kernel of order 3",
+ "lanczos4": "Lanczos kernel of order 4",
+ "lanczos5": "Lanczos kernel of order 5",
+ },
+ )
+ maskWarpingKernelName = ChoiceField[str](
+ doc="Warping kernel for mask.",
+ default="bilinear",
+ allowed={
+ "bilinear": "bilinear interpolation",
+ "lanczos3": "Lanczos kernel of order 3",
+ "lanczos4": "Lanczos kernel of order 4",
+ "lanczos5": "Lanczos kernel of order 5",
+ },
+ )
+ # scalePsfModel = Field[bool](
+ # doc="If True, uses a scale factor to bring the PSF model data to the same level as the star data.",
+ # default=True,
+ # )
+
+ # PSF Fitting
+ useExtendedPsf = Field[bool](
+ doc="Use the extended PSF model to estimate the bright star cutout normalization factor.",
+ default=False,
+ )
+ doFitPsf = Field[bool](
+ doc="Fit a scaled PSF and a simple background to each bright star cutout.",
+ default=True,
+ )
+ useMedianVariance = Field[bool](
+ doc="Use the median of the variance plane for PSF fitting.",
+ default=False,
+ )
+ psfMaskedFluxFracThreshold = Field[float](
+ doc="Maximum allowed fraction of masked PSF flux for PSF fitting to occur.",
+ default=0.97,
+ )
+ fitIterations = Field[int](
+ doc="Number of iterations to constrain PSF fitting.",
+ default=5,
+ )
+
+ # Misc
+ loadReferenceObjectsConfig = ConfigField[LoadReferenceObjectsConfig](
+ doc="Reference object loader for astrometric calibration.",
+ )
+
+
+class BrightStarCutoutTask(PipelineTask):
+ """Extract bright star cutouts; normalize and warp to the same pixel grid.
+
+ The BrightStarCutoutTask is used to extract, process, and store small image
+ cutouts (or "postage stamps") around bright stars.
+ This task essentially consists of three principal steps.
+ First, it identifies bright stars within an exposure using a reference
+ catalog and extracts a stamp around each.
+ Second, it shifts and warps each stamp to remove optical distortions and
+ sample all stars on the same pixel grid.
+ Finally, it optionally fits a PSF and a simple background model.
+ This final fitting procedure may be used to normalize each bright star
+ stamp prior to stacking when producing extended PSF models.
+ """
+
+ ConfigClass = BrightStarCutoutConfig
+ _DefaultName = "brightStarCutout"
+ config: BrightStarCutoutConfig
+
+ def __init__(self, initInputs=None, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ stamp_size = Extent2D(*self.config.stampSize.list())
+ stamp_radius = floor(stamp_size / 2)
+ self.stamp_bbox = Box2I(corner=Point2I(0, 0), dimensions=Extent2I(1, 1)).dilatedBy(stamp_radius)
+ padded_stamp_size = stamp_size * self.config.stampSizePadding
+ self.padded_stamp_radius = floor(padded_stamp_size / 2)
+ self.padded_stamp_bbox = Box2I(corner=Point2I(0, 0), dimensions=Extent2I(1, 1)).dilatedBy(
+ self.padded_stamp_radius
+ )
+ # self.modelScale = 1
+
+ def runQuantum(self, butlerQC, input_refs, output_refs):
+ inputs = butlerQC.get(input_refs)
+ inputs["data_id"] = butlerQC.quantum.dataId
+ ref_obj_loader = ReferenceObjectLoader(
+ dataIds=[ref.datasetRef.dataId for ref in input_refs.ref_cat],
+ refCats=inputs.pop("ref_cat"),
+ name=self.config.connections.ref_cat,
+ config=self.config.loadReferenceObjectsConfig,
+ )
+ extended_psf = inputs.pop("extended_psf", None)
+ output = self.run(**inputs, extended_psf=extended_psf, ref_obj_loader=ref_obj_loader)
+ # Only ingest Stamp if it exists; prevents ingesting an empty FITS file
+ if output:
+ butlerQC.put(output, output_refs)
+
+ @timeMethod
+ def run(
+ self,
+ input_image: ExposureF,
+ input_background: BackgroundList,
+ extended_psf: ImageF | None,
+ ref_obj_loader: ReferenceObjectLoader,
+ data_id: dict[str, Any] | DataCoordinate,
+ ):
+ """Identify bright stars within an exposure using a reference catalog,
+ extract stamps around each, warp/shift stamps onto a common frame and
+ then optionally fit a PSF plus plane model.
+
+ Parameters
+ ----------
+ input_image : `~lsst.afw.image.ExposureF`
+ The background-subtracted image to extract bright star stamps from.
+ input_background : `~lsst.afw.math.BackgroundList`
+ The background model associated with the input exposure.
+ extended_psf : `~lsst.afw.image.ImageF` | `None`
+ The extended PSF model, built from stacking bright star cutouts.
+ ref_obj_loader :
+ `~lsst.meas.algorithms.ReferenceObjectLoader`, optional
+ Loader to find objects within a reference catalog.
+ data_id : `dict` or `~lsst.daf.butler.DataCoordinate`
+ The data ID of the detector that bright stars are extracted from.
+ Both 'visit' and 'detector' will be persisted in the output data.
+
+ Returns
+ -------
+ bright_star_stamps_results : `~lsst.pipe.base.Struct`
+ Results as a struct with attributes:
+
+ ``bright_star_stamps``
+ (`~lsst.meas.algorithms.brightStarStamps.BrightStarStamps`)
+ """
+ wcs = input_image.getWcs()
+ bbox = input_image.getBBox()
+
+ # Get reference catalog stars
+ ref_cat = self._get_ref_cat(ref_obj_loader, wcs, bbox)
+ zipRaDec = zip(refCatBright["coord_ra"] * radians, refCatBright["coord_dec"] * radians)
+ spherePoints = [SpherePoint(ra, dec) for ra, dec in zipRaDec]
+ pixCoords = wcs.skyToPixel(spherePoints)
+
+ # Restore original subtracted background
+ inputMI = inputExposure.getMaskedImage()
+ inputMI += inputBackground.getImage()
+
+ # Set up NEIGHBOR mask plane; associate footprints with stars
+ inputExposure.mask.addMaskPlane(NEIGHBOR_MASK_PLANE)
+ allFootprints, associations = self._associateFootprints(inputExposure, pixCoords, plane="DETECTED")
+
+ # TODO: If we eventually have better PhotoCalibs (eg FGCM), apply here
+ inputMI = inputExposure.getPhotoCalib().calibrateImage(inputMI, False)
+
+ # Set up transform
+ detector = inputExposure.detector
+ pixelScale = wcs.getPixelScale().asArcseconds() * arcseconds
+ pixToFocalPlaneTan = detector.getTransform(PIXELS, FIELD_ANGLE).then(
+ makeTransform(AffineTransform.makeScaling(1 / pixelScale.asRadians()))
+ )
+
+ # Loop over each bright star
+ warpingControl = WarpingControl(self.config.warpingKernelName, self.config.maskWarpingKernelName)
+ stamps, goodFracs, stamps_fitPsfResults = [], [], []
+ for starIndex, (obj, pixCoord) in enumerate(zip(refCatBright, pixCoords)): # type: ignore
+ # Excluding faint stars that are not within the frame.
+ if obj["mag"] > self.config.offFrameMagLim and not self.star_in_frame(pixCoord, bbox):
+ continue
+ footprintIndex = associations.get(starIndex, None)
+ stampMI = MaskedImageF(self.paddedStampBBox)
+
+ # Set NEIGHBOR footprints in the mask plane
+ if footprintIndex:
+ neighborFootprints = [fp for i, fp in enumerate(allFootprints) if i != footprintIndex]
+ self._setFootprints(inputMI, neighborFootprints, NEIGHBOR_MASK_PLANE)
+ else:
+ self._setFootprints(inputMI, allFootprints, NEIGHBOR_MASK_PLANE)
+
+ # Define linear shifting to recenter stamps
+ coordFocalPlaneTan = pixToFocalPlaneTan.applyForward(pixCoord) # center of warped star
+ shift = makeTransform(AffineTransform(Point2D(0, 0) - coordFocalPlaneTan))
+ angle = np.arctan2(coordFocalPlaneTan.getY(), coordFocalPlaneTan.getX()) * radians
+ rotation = makeTransform(AffineTransform.makeRotation(-angle))
+ pixToPolar = pixToFocalPlaneTan.then(shift).then(rotation)
+
+ # Apply the warp to the star stamp (in-place)
+ warpImage(stampMI, inputMI, pixToPolar, warpingControl)
+
+ # Trim to the base stamp size, check mask coverage, update metadata
+ stampMI = stampMI[self.stampBBox]
+ badMaskBitMask = stampMI.mask.getPlaneBitMask(self.config.badMaskPlanes)
+ goodFrac = np.sum(stampMI.mask.array & badMaskBitMask == 0) / stampMI.mask.array.size
+ goodFracs.append(goodFrac)
+ if goodFrac < self.config.minAreaFraction:
+ continue
+
+ # Fit a scaled PSF and a pedestal to each bright star cutout
+ psf = WarpedPsf(inputExposure.getPsf(), pixToPolar, warpingControl)
+ constantPsf = KernelPsf(FixedKernel(psf.computeKernelImage(Point2D(0, 0))))
+ if self.config.useExtendedPsf:
+ psfImage = deepcopy(extendedPsf) # Assumed to be warped, center at [0,0]
+ else:
+ psfImage = constantPsf.computeKernelImage(constantPsf.getAveragePosition())
+ # TODO: maybe we want to generate a smaller psf in case the following happens?
+ # The following could happen for when the user chooses small stampSize ~(50, 50)
+ if (
+ psfImage.array.shape[0] > stampMI.image.array.shape[0]
+ or psfImage.array.shape[1] > stampMI.image.array.shape[1]
+ ):
+ continue
+ # Computing an scale factor that brings the model to the similar level of the star.
+ self.computeModelScale(stampMI, psfImage)
+ psfImage.array *= self.modelScale # ####### model scale correction ########
+
+ fitPsfResults = {}
+
+ if self.config.doFitPsf:
+ fitPsfResults = self._fitPsf(stampMI, psfImage)
+ stamps_fitPsfResults.append(fitPsfResults)
+
+ # Save the stamp if the PSF fit was successful or no fit requested
+ if fitPsfResults or not self.config.doFitPsf:
+ distance_mm, theta_angle = self.star_location_on_focal(pixCoord, detector)
+
+ stamp = BrightStarStamp(
+ stamp_im=stampMI,
+ psf=constantPsf,
+ wcs=makeModifiedWcs(pixToPolar, wcs, False),
+ visit=cast(int, dataId["visit"]),
+ detector=cast(int, dataId["detector"]),
+ ref_id=obj["id"],
+ ref_mag=obj["mag"],
+ position=pixCoord,
+ focal_plane_radius=distance_mm,
+ focal_plane_angle=theta_angle, # TODO: add the lsst.geom.Angle here
+ scale=fitPsfResults.get("scale", None),
+ scale_err=fitPsfResults.get("scaleErr", None),
+ pedestal=fitPsfResults.get("pedestal", None),
+ pedestal_err=fitPsfResults.get("pedestalErr", None),
+ pedestal_scale_cov=fitPsfResults.get("pedestalScaleCov", None),
+ gradient_x=fitPsfResults.get("xGradient", None),
+ gradient_y=fitPsfResults.get("yGradient", None),
+ global_reduced_chi_squared=fitPsfResults.get("globalReducedChiSquared", None),
+ global_degrees_of_freedom=fitPsfResults.get("globalDegreesOfFreedom", None),
+ psf_reduced_chi_squared=fitPsfResults.get("psfReducedChiSquared", None),
+ psf_degrees_of_freedom=fitPsfResults.get("psfDegreesOfFreedom", None),
+ psf_masked_flux_fraction=fitPsfResults.get("psfMaskedFluxFrac", None),
+ )
+ print(
+ obj["mag"],
+ fitPsfResults.get("globalReducedChiSquared", None),
+ fitPsfResults.get("globalDegreesOfFreedom", None),
+ fitPsfResults.get("psfReducedChiSquared", None),
+ fitPsfResults.get("psfDegreesOfFreedom", None),
+ fitPsfResults.get("psfMaskedFluxFrac", None),
+ )
+ stamps.append(stamp)
+
+ self.log.info(
+ "Extracted %i bright star stamp%s. "
+ "Excluded %i star%s: insufficient area (%i), PSF fit failure (%i).",
+ len(stamps),
+ "" if len(stamps) == 1 else "s",
+ len(refCatBright) - len(stamps),
+ "" if len(refCatBright) - len(stamps) == 1 else "s",
+ np.sum(np.array(goodFracs) < self.config.minAreaFraction),
+ (
+ np.sum(np.isnan([x.get("pedestal", np.nan) for x in stamps_fitPsfResults]))
+ if self.config.doFitPsf
+ else 0
+ ),
+ )
+ brightStarStamps = BrightStarStamps(stamps)
+ return Struct(brightStarStamps=brightStarStamps)
+
+ def star_location_on_focal(self, pixCoord, detector):
+ star_focal_plane_coords = detector.transform(pixCoord, PIXELS, FOCAL_PLANE)
+ star_x_fp = star_focal_plane_coords.getX()
+ star_y_fp = star_focal_plane_coords.getY()
+ distance_mm = np.sqrt(star_x_fp**2 + star_y_fp**2)
+ theta_rad = math.atan2(star_y_fp, star_x_fp)
+ theta_angle = Angle(theta_rad, radians)
+ return distance_mm, theta_angle
+
+ def star_in_frame(self, pixCoord, inputExposureBBox):
+ if (
+ pixCoord[0] < 0
+ or pixCoord[1] < 0
+ or pixCoord[0] > inputExposureBBox.getDimensions()[0]
+ or pixCoord[1] > inputExposureBBox.getDimensions()[1]
+ ):
+ return False
+ return True
+
+ def _get_ref_cat(self, ref_obj_loader: ReferenceObjectLoader, wcs: SkyWcs, bbox: Box2I) -> Table:
+ """Get a subset of the reference catalog.
+
+ Trim the reference catalog to only those objects within the exposure
+ bounding box dilated by half the bright star stamp size.
+ This ensures all stars that overlap the exposure are included.
+
+ Parameters
+ ----------
+ ref_obj_loader : `~lsst.meas.algorithms.ReferenceObjectLoader`
+ Loader to find objects within a reference catalog.
+ wcs : `~lsst.afw.geom.SkyWcs`
+ World coordinate system.
+ bbox : `~lsst.geom.Box2I`
+ Bounding box of the image.
+
+ Returns
+ -------
+ ref_cat : `~astropy.table.Table`
+ Subset of the reference catalog.
+ """
+ # Get all stars within a dilated bbox
+ dilated_bbox = bbox.dilatedBy(self.padded_stamp_radius)
+ within_dilated_bbox = ref_obj_loader.loadPixelBox(dilated_bbox, wcs, filterName="phot_g_mean")
+ ref_cat_full = within_dilated_bbox.refCat
+ flux_field: str = within_dilated_bbox.fluxField
+
+ # Trim to stars within the desired magnitude range
+ flux_range_nearby = sorted(((self.config.excludeMagRange * u.ABmag).to(u.nJy)).to_value())
+ flux_range_bright = sorted(((self.config.magRange * u.ABmag).to(u.nJy)).to_value())
+ stars_magnitude_limited = (
+ ref_cat_full[flux_field] > np.min((flux_range_nearby[0], flux_range_bright[0]))
+ ) & (ref_cat_full[flux_field] < np.max((flux_range_nearby[1], flux_range_bright[1])))
+ ref_cat_subset = Table(
+ ref_cat_full.extract("id", "coord_ra", "coord_dec", flux_field, where=stars_magnitude_limited)
+ )
+ stars_nearby = (ref_cat_subset[flux_field] >= flux_range_nearby[0]) & (
+ ref_cat_subset[flux_field] <= flux_range_nearby[1]
+ )
+ stars_bright = (ref_cat_subset[flux_field] >= flux_range_bright[0]) & (
+ ref_cat_subset[flux_field] <= flux_range_bright[1]
+ )
+
+ # Exclude stars with bright enough neighbors in a specified radius
+ coords = SkyCoord(ref_cat_subset["coord_ra"], ref_cat_subset["coord_dec"], unit="rad")
+ exclude_radius_arcsec = self.config.excludeRadiusArcsec * u.arcsec
+ ref_cat_bright_isolated = []
+ for coord in cast(Iterable[SkyCoord], coords[stars_bright]):
+ neighbors = coords[stars_nearby]
+ separations = coord.separation(neighbors).to(u.arcsec)
+ too_close = (separations > 0) & (separations <= exclude_radius_arcsec) # ensure not self matched
+ ref_cat_bright_isolated.append(not too_close.any())
+ ref_cat_bright = cast(Table, ref_cat_subset[stars_bright][ref_cat_bright_isolated])
+ breakpoint()
+
+ fluxNanojansky = refCatBright[fluxField][:] * u.nJy # type: ignore
+ refCatBright["mag"] = fluxNanojansky.to(u.ABmag).to_value() # AB magnitudes
+
+ self.log.info(
+ "Identified %i of %i star%s which satisfy: frame overlap; in the range %s mag; no neighboring "
+ "stars within %s arcsec.",
+ len(refCatBright),
+ len(refCatFull),
+ "" if len(refCatFull) == 1 else "s",
+ self.config.magRange,
+ self.config.excludeArcsecRadius,
+ )
+
+ return refCatBright
+
+ def _associateFootprints(
+ self, inputExposure: ExposureF, pixCoords: list[Point2D], plane: str
+ ) -> tuple[list[Footprint], dict[int, int]]:
+ """Associate footprints from a given mask plane with specific objects.
+
+ Footprints from the given mask plane are associated with objects at the
+ coordinates provided, where possible.
+
+ Parameters
+ ----------
+ inputExposure : `~lsst.afw.image.ExposureF`
+ The input exposure with a mask plane.
+ pixCoords : `list` [`~lsst.geom.Point2D`]
+ The pixel coordinates of the objects.
+ plane : `str`
+ The mask plane used to identify masked pixels.
+
+ Returns
+ -------
+ footprints : `list` [`~lsst.afw.detection.Footprint`]
+ The footprints from the input exposure.
+ associations : `dict`[int, int]
+ Association indices between objects (key) and footprints (value).
+ """
+ detThreshold = Threshold(inputExposure.mask.getPlaneBitMask(plane), Threshold.BITMASK)
+ footprintSet = FootprintSet(inputExposure.mask, detThreshold)
+ footprints = footprintSet.getFootprints()
+ associations = {}
+ for starIndex, pixCoord in enumerate(pixCoords):
+ for footprintIndex, footprint in enumerate(footprints):
+ if footprint.contains(Point2I(pixCoord)):
+ associations[starIndex] = footprintIndex
+ break
+ self.log.debug(
+ "Associated %i of %i star%s to one each of the %i %s footprint%s.",
+ len(associations),
+ len(pixCoords),
+ "" if len(pixCoords) == 1 else "s",
+ len(footprints),
+ plane,
+ "" if len(footprints) == 1 else "s",
+ )
+ return footprints, associations
+
+ def _setFootprints(self, inputExposure: ExposureF, footprints: list, maskPlane: str):
+ """Set footprints in a given mask plane.
+
+ Parameters
+ ----------
+ inputExposure : `~lsst.afw.image.ExposureF`
+ The input exposure to modify.
+ footprints : `list` [`~lsst.afw.detection.Footprint`]
+ The footprints to set in the mask plane.
+ maskPlane : `str`
+ The mask plane to set the footprints in.
+
+ Notes
+ -----
+ This method modifies the ``inputExposure`` object in-place.
+ """
+ detThreshold = Threshold(inputExposure.mask.getPlaneBitMask(maskPlane), Threshold.BITMASK)
+ detThresholdValue = int(detThreshold.getValue())
+ footprintSet = FootprintSet(inputExposure.mask, detThreshold)
+
+ # Wipe any existing footprints in the mask plane
+ inputExposure.mask.clearMaskPlane(int(np.log2(detThresholdValue)))
+
+ # Set the footprints in the mask plane
+ footprintSet.setFootprints(footprints)
+ footprintSet.setMask(inputExposure.mask, maskPlane)
+
+ def _fitPsf(self, stampMI: MaskedImageF, psfImage: ImageD | ImageF) -> dict[str, Any]:
+ """Fit a scaled PSF and a pedestal to each bright star cutout.
+
+ Parameters
+ ----------
+ stampMI : `~lsst.afw.image.MaskedImageF`
+ The masked image of the bright star cutout.
+ psfImage : `~lsst.afw.image.ImageD` | `~lsst.afw.image.ImageF`
+ The PSF model to fit.
+
+ Returns
+ -------
+ fitPsfResults : `dict`[`str`, `float`]
+ The result of the PSF fitting, with keys:
+
+ ``scale`` : `float`
+ The scale factor.
+ ``scaleErr`` : `float`
+ The error on the scale factor.
+ ``pedestal`` : `float`
+ The pedestal value.
+ ``pedestalErr`` : `float`
+ The error on the pedestal value.
+ ``pedestalScaleCov`` : `float`
+ The covariance between the pedestal and scale factor.
+ ``xGradient`` : `float`
+ The gradient in the x-direction.
+ ``yGradient`` : `float`
+ The gradient in the y-direction.
+ ``globalReducedChiSquared`` : `float`
+ The global reduced chi-squared goodness-of-fit.
+ ``globalDegreesOfFreedom`` : `int`
+ The global number of degrees of freedom.
+ ``psfReducedChiSquared`` : `float`
+ The PSF BBox reduced chi-squared goodness-of-fit.
+ ``psfDegreesOfFreedom`` : `int`
+ The PSF BBox number of degrees of freedom.
+ ``psfMaskedFluxFrac`` : `float`
+ The fraction of the PSF image flux masked by bad pixels.
+ """
+ badMaskBitMask = stampMI.mask.getPlaneBitMask(self.config.badMaskPlanes)
+
+ # Calculate the fraction of the PSF image flux masked by bad pixels
+ psfMaskedPixels = ImageF(psfImage.getBBox())
+ psfMaskedPixels.array[:, :] = (stampMI.mask[psfImage.getBBox()].array & badMaskBitMask).astype(bool)
+ psfMaskedFluxFrac = (
+ np.dot(psfImage.array.flat, psfMaskedPixels.array.flat).astype(np.float64) / psfImage.array.sum()
+ )
+ if psfMaskedFluxFrac > self.config.psfMaskedFluxFracThreshold:
+ return {} # Handle cases where the PSF image is mostly masked
+
+ # Generating good spans for gradient-pedestal fitting (including the star DETECTED mask).
+ gradientGoodSpans = self.generate_gradient_spans(stampMI, badMaskBitMask)
+ varianceData = gradientGoodSpans.flatten(stampMI.variance.array, stampMI.getXY0())
+ if self.config.useMedianVariance:
+ varianceData = np.median(varianceData)
+ sigmaData = np.sqrt(varianceData)
+
+ for i in range(self.config.fitIterations):
+ # Gradient-pedestal fitting:
+ if i > 0:
+ # if i > 0, there should be scale factor from the previous fit iteration. Therefore, we can
+ # remove the star using the scale factor.
+ stamp = self.remove_star(stampMI, scale, paddedPsfImage) # noqa: F821
+ else:
+ stamp = deepcopy(stampMI.image.array)
+
+ imageDataGr = gradientGoodSpans.flatten(stamp, stampMI.getXY0()) / sigmaData # B
+ nData = len(imageDataGr)
+ coefficientMatrix = np.ones((nData, 3), dtype=float) # A
+ coefficientMatrix[:, 0] /= sigmaData
+ coefficientMatrix[:, 1:] = gradientGoodSpans.indices().T
+ coefficientMatrix[:, 1] /= sigmaData
+ coefficientMatrix[:, 2] /= sigmaData
+
+ try:
+ grSolutions, grSumSquaredResiduals, *_ = np.linalg.lstsq(
+ coefficientMatrix, imageDataGr, rcond=None
+ )
+ covarianceMatrix = np.linalg.inv(
+ np.dot(coefficientMatrix.transpose(), coefficientMatrix)
+ ) # C
+ except np.linalg.LinAlgError:
+ return {} # Handle singular matrix errors
+ if grSumSquaredResiduals.size == 0:
+ return {} # Handle cases where sum of the squared residuals are empty
+
+ pedestal = grSolutions[0]
+ pedestalErr = np.sqrt(covarianceMatrix[0, 0])
+ scalePedestalCov = None
+ xGradient = grSolutions[2]
+ yGradient = grSolutions[1]
+
+ # Scale fitting:
+ updatedStampMI = deepcopy(stampMI)
+ self._removePedestalAndGradient(updatedStampMI, pedestal, xGradient, yGradient)
+
+ # Create a padded version of the input constant PSF image
+ paddedPsfImage = ImageF(updatedStampMI.getBBox())
+ paddedPsfImage[psfImage.getBBox()] = psfImage.convertF()
+
+ # Generating a mask plane while considering bad pixels in the psf model.
+ mask = self.add_psf_mask(paddedPsfImage, updatedStampMI)
+ # Create consistently masked data
+ scaleGoodSpans = self.generate_good_spans(mask, updatedStampMI.getBBox(), badMaskBitMask)
+
+ imageData = scaleGoodSpans.flatten(updatedStampMI.image.array, updatedStampMI.getXY0())
+ psfData = scaleGoodSpans.flatten(paddedPsfImage.array, paddedPsfImage.getXY0())
+ scaleCoefficientMatrix = psfData.reshape(psfData.shape[0], 1)
+
+ try:
+ scaleSolution, scaleSumSquaredResiduals, *_ = np.linalg.lstsq(
+ scaleCoefficientMatrix, imageData, rcond=None
+ )
+ except np.linalg.LinAlgError:
+ return {} # Handle singular matrix errors
+ if scaleSumSquaredResiduals.size == 0:
+ return {} # Handle cases where sum of the squared residuals are empty
+ scale = scaleSolution[0]
+ if scale <= 0:
+ return {} # Handle cases where the PSF scale fit has failed
+ # TODO: calculate scale error and store it.
+ scaleErr = None
+
+ scale *= self.modelScale # ####### model scale correction ########
+ nData = len(imageData)
+
+ # Calculate global (whole image) reduced chi-squared (scaling fit is assumed as the main fitting
+ # process here.)
+ globalChiSquared = np.sum(scaleSumSquaredResiduals)
+ globalDegreesOfFreedom = nData - 1
+ globalReducedChiSquared = np.float64(globalChiSquared / globalDegreesOfFreedom)
+
+ # Calculate PSF BBox reduced chi-squared
+ psfBBoxscaleGoodSpans = scaleGoodSpans.clippedTo(psfImage.getBBox())
+ psfBBoxscaleGoodSpansX, psfBBoxscaleGoodSpansY = psfBBoxscaleGoodSpans.indices()
+ psfBBoxData = psfBBoxscaleGoodSpans.flatten(stampMI.image.array, stampMI.getXY0())
+ paddedPsfImage.array /= self.modelScale # ####### model scale correction ########
+ psfBBoxModel = (
+ psfBBoxscaleGoodSpans.flatten(paddedPsfImage.array, stampMI.getXY0()) * scale
+ + pedestal
+ + psfBBoxscaleGoodSpansX * xGradient
+ + psfBBoxscaleGoodSpansY * yGradient
+ )
+ psfBBoxResiduals = (psfBBoxData - psfBBoxModel) ** 2 # / psfBBoxVariance
+ psfBBoxChiSquared = np.sum(psfBBoxResiduals)
+ psfBBoxDegreesOfFreedom = len(psfBBoxData) - 1
+ psfBBoxReducedChiSquared = psfBBoxChiSquared / psfBBoxDegreesOfFreedom
+ return dict(
+ scale=scale,
+ scaleErr=scaleErr,
+ pedestal=pedestal,
+ pedestalErr=pedestalErr,
+ xGradient=xGradient,
+ yGradient=yGradient,
+ pedestalScaleCov=scalePedestalCov,
+ globalReducedChiSquared=globalReducedChiSquared,
+ globalDegreesOfFreedom=globalDegreesOfFreedom,
+ psfReducedChiSquared=psfBBoxReducedChiSquared,
+ psfDegreesOfFreedom=psfBBoxDegreesOfFreedom,
+ psfMaskedFluxFrac=psfMaskedFluxFrac,
+ )
+
+ def add_psf_mask(self, psfImage, stampMI, maskZeros=True):
+ """
+ Creates a new mask by adding PSF bad pixels to an existing stamp mask.
+
+ This method identifies "bad" pixels in the PSF image (NaNs and
+ optionally zeros/non-positives) and adds them to a deep copy
+ of the input stamp's mask.
+
+ Args:
+ psfImage : `~lsst.afw.image.ImageD` | `~lsst.afw.image.ImageF`
+ The PSF image object.
+ stampMI: `~lsst.afw.image.MaskedImageF`
+ The masked image of the bright star cutout.
+ maskZeros (bool, optional): If True (default), mask pixels
+ where the PSF is <= 0. If False, only mask pixels < 0.
+
+ Returns:
+ Any: A new mask object (deep copy) with the PSF mask planes added.
+ """
+ cond = np.isnan(psfImage.array)
+ if maskZeros:
+ cond |= psfImage.array <= 0
+ else:
+ cond |= psfImage.array < 0
+ mask = deepcopy(stampMI.mask)
+ mask.array[cond] = np.bitwise_or(mask.array[cond], 1)
+ return mask
+
+ def _removePedestalAndGradient(self, stampMI, pedestal, xGradient, yGradient):
+ """Apply fitted pedestal and gradients to a single bright star stamp."""
+ stampBBox = stampMI.getBBox()
+ xGrid, yGrid = np.meshgrid(stampBBox.getX().arange(), stampBBox.getY().arange())
+ xPlane = ImageF((xGrid * xGradient).astype(np.float32), xy0=stampMI.getXY0())
+ yPlane = ImageF((yGrid * yGradient).astype(np.float32), xy0=stampMI.getXY0())
+ stampMI -= pedestal
+ stampMI -= xPlane
+ stampMI -= yPlane
+
+ def remove_star(self, stampMI, scale, psfImage):
+ """
+ Subtracts a scaled PSF model from a star image.
+
+ This performs a simple subtraction: `image - (psf * scale)`.
+
+ Args:
+ stampMI: `~lsst.afw.image.MaskedImageF`
+ The masked image of the bright star cutout.
+ scale (float): The scaling factor to apply to the PSF.
+ psfImage: `~lsst.afw.image.ImageD` | `~lsst.afw.image.ImageF`
+ The PSF image object.
+
+ Returns:
+ np.ndarray: A new 2D numpy array containing the star-subtracted
+ image.
+ """
+ star_removed_cutout = stampMI.image.array - psfImage.array * scale
+ return star_removed_cutout
+
+ def computeModelScale(self, stampMI, psfImage):
+ """
+ Computes the scaling factor of the given model against a star.
+
+ Args:
+ stampMI : `~lsst.afw.image.MaskedImageF`
+ The masked image of the bright star cutout.
+ psfImage : `~lsst.afw.image.ImageD` | `~lsst.afw.image.ImageF`
+ The given PSF model.
+ """
+ cond = stampMI.mask.array == 0
+ self.starMedianValue = np.median(stampMI.image.array[cond]).astype(np.float64)
+
+ psfPos = psfImage.array > 0
+
+ imageArray = stampMI.image.array - self.starMedianValue
+ imageArrayPos = imageArray > 0
+ self.modelScale = np.nanmean(imageArray[imageArrayPos]) / np.nanmean(psfImage.array[psfPos])
+
+ def generate_gradient_spans(self, stampMI, badMaskBitMask):
+ """
+ Generates spans of "good" pixels for gradient fitting.
+
+ This method creates a combined bitmask by OR-ing the provided
+ `badMaskBitMask` with the "DETECTED" plane from the stamp's mask.
+ It then calls `self.generate_good_spans` to find all pixel spans
+ not covered by this combined mask.
+
+ Args:
+ stampMI: `~lsst.afw.image.MaskedImageF`
+ The masked image of the bright star cutout.
+ badMaskBitMask (int): A bitmask representing planes to be
+ considered "bad" for gradient fitting.
+
+ Returns:
+ gradientGoodSpans: A SpanSet object containing the "good" spans.
+ """
+ detectedMaskBitMask = stampMI.mask.getPlaneBitMask("DETECTED")
+ gradientBitMask = np.bitwise_or(badMaskBitMask, detectedMaskBitMask)
+
+ gradientGoodSpans = self.generate_good_spans(stampMI.mask, stampMI.getBBox(), gradientBitMask)
+ return gradientGoodSpans
+
+ def generate_good_spans(self, mask, bBox, badBitMask):
+ """
+ Generates a SpanSet of "good" pixels from a mask.
+
+ This method identifies all spans within a given bounding box (`bBox`)
+ that are *not* flagged by the `badBitMask` in the provided `mask`.
+
+ Args:
+ mask (lsst.afw.image.MaskedImageF.mask): The mask object (e.g., `stampMI.mask`).
+ bBox (lsst.geom.Box2I): The bounding box of the image (e.g., `stampMI.getBBox()`).
+ badBitMask (int): The combined bitmask of planes to exclude.
+
+ Returns:
+ goodSpans: A SpanSet object representing all "good" spans.
+ """
+ badSpans = SpanSet.fromMask(mask, badBitMask)
+ goodSpans = SpanSet(bBox).intersectNot(badSpans)
+ return goodSpans
diff --git a/tests/test_brightStarCutout.py b/tests/test_brightStarCutout.py
new file mode 100644
index 000000000..88da41440
--- /dev/null
+++ b/tests/test_brightStarCutout.py
@@ -0,0 +1,103 @@
+# This file is part of pipe_tasks.
+#
+# Developed for the LSST Data Management System.
+# This product includes software developed by the LSST Project
+# (https://www.lsst.org).
+# See the COPYRIGHT file at the top-level directory of this distribution
+# for details of code ownership.
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program. If not, see .
+
+import unittest
+
+import lsst.afw.cameraGeom.testUtils
+import lsst.afw.image
+import lsst.utils.tests
+import numpy as np
+from lsst.afw.image import ImageD, ImageF, MaskedImageF
+from lsst.afw.math import FixedKernel
+from lsst.geom import Point2I
+from lsst.meas.algorithms import KernelPsf
+from lsst.pipe.tasks.brightStarSubtraction import BrightStarCutoutConfig, BrightStarCutoutTask
+
+
+class BrightStarCutoutTestCase(lsst.utils.tests.TestCase):
+ def setUp(self):
+ # Fit values
+ self.scale = 2.34e5
+ self.pedestal = 3210.1
+ self.xGradient = 5.432
+ self.yGradient = 10.987
+
+ # Create a pedestal + 2D plane
+ xCoords = np.linspace(-50, 50, 101)
+ yCoords = np.linspace(-50, 50, 101)
+ xPlane, yPlane = np.meshgrid(xCoords, yCoords)
+ pedestal = np.ones_like(xPlane) * self.pedestal
+
+ # Create a pseudo-PSF
+ dist_from_center = np.sqrt(xPlane**2 + yPlane**2)
+ psfArray = np.exp(-dist_from_center / 5)
+ psfArray /= np.sum(psfArray)
+ fixedKernel = FixedKernel(ImageD(psfArray))
+ psf = KernelPsf(fixedKernel)
+ self.psf = psf.computeKernelImage(psf.getAveragePosition())
+
+ # Bring everything together to construct a stamp masked image
+ stampArray = psfArray * self.scale + pedestal + xPlane * self.xGradient + yPlane * self.yGradient
+ stampIm = ImageF((stampArray).astype(np.float32))
+ stampVa = ImageF(stampIm.getBBox(), 654.321)
+ self.stampMI = MaskedImageF(image=stampIm, variance=stampVa)
+ self.stampMI.setXY0(Point2I(-50, -50))
+
+ # Ensure that all mask planes required by the task are in-place;
+ # new mask plane entries will be created as necessary
+ badMaskPlanes = [
+ "BAD",
+ "CR",
+ "CROSSTALK",
+ "EDGE",
+ "NO_DATA",
+ "SAT",
+ "SUSPECT",
+ "UNMASKEDNAN",
+ "NEIGHBOR",
+ ]
+ _ = [self.stampMI.mask.addMaskPlane(mask) for mask in badMaskPlanes]
+
+ def test_fitPsf(self):
+ """Test the PSF fitting method."""
+ brightStarCutoutConfig = BrightStarCutoutConfig()
+ brightStarCutoutTask = BrightStarCutoutTask(config=brightStarCutoutConfig)
+ fitPsfResults = brightStarCutoutTask._fitPsf(
+ self.stampMI,
+ self.psf,
+ )
+ assert abs(fitPsfResults["scale"] - self.scale) / self.scale < 1e-6
+ assert abs(fitPsfResults["pedestal"] - self.pedestal) / self.pedestal < 1e-6
+ assert abs(fitPsfResults["xGradient"] - self.xGradient) / self.xGradient < 1e-6
+ assert abs(fitPsfResults["yGradient"] - self.yGradient) / self.yGradient < 1e-6
+
+
+def setup_module(module):
+ lsst.utils.tests.init()
+
+
+class MemoryTestCase(lsst.utils.tests.MemoryTestCase):
+ pass
+
+
+if __name__ == "__main__":
+ lsst.utils.tests.init()
+ unittest.main()