From be55967840cb8199b8677dd18767cdf1daba7497 Mon Sep 17 00:00:00 2001 From: Michael Langmayr Date: Thu, 23 Oct 2025 16:37:24 -0700 Subject: [PATCH] fix structure, add b-spline fitting to normalize flat function --- .gitignore | 4 +- config/config.yaml | 5 +- src/core/flat.py | 186 ++++++++++++++++-- src/core/flows.py | 53 ----- src/core/qa.py | 2 - src/core/tracing.py | 4 +- src/keck_primitives/create_master_flat.py | 58 ++++++ .../{normalize_flat.py => load_flat.py} | 13 +- src/keck_primitives/qa_plot.py | 17 ++ src/keck_primitives/save_corrected.py | 18 ++ src/keck_primitives/save_trace.py | 16 ++ src/keck_primitives/trace_slits.py | 9 +- src/main.py | 45 ++++- src/workflows/flows/batch_flat_flow.py | 41 +++- .../prefect_tasks/create_correction.py | 6 - .../prefect_tasks/create_master_flat.py | 45 +++++ src/workflows/prefect_tasks/load_flat.py | 14 +- src/workflows/prefect_tasks/normalize_flat.py | 16 -- src/workflows/prefect_tasks/qa_plot.py | 16 +- src/workflows/prefect_tasks/save_corrected.py | 17 +- src/workflows/prefect_tasks/save_trace.py | 15 +- 21 files changed, 477 insertions(+), 123 deletions(-) delete mode 100644 src/core/flows.py create mode 100644 src/keck_primitives/create_master_flat.py rename src/keck_primitives/{normalize_flat.py => load_flat.py} (50%) create mode 100644 src/keck_primitives/qa_plot.py create mode 100644 src/keck_primitives/save_corrected.py create mode 100644 src/keck_primitives/save_trace.py delete mode 100644 src/workflows/prefect_tasks/create_correction.py create mode 100644 src/workflows/prefect_tasks/create_master_flat.py delete mode 100644 src/workflows/prefect_tasks/normalize_flat.py diff --git a/.gitignore b/.gitignore index 7d20e5d..8ebe52b 100644 --- a/.gitignore +++ b/.gitignore @@ -174,4 +174,6 @@ cython_debug/ # Output files output/* -.python-version \ No newline at end of file +.python-version + +.DS_Store \ No newline at end of file diff --git a/config/config.yaml b/config/config.yaml index 77d6629..17f7b35 100644 --- a/config/config.yaml +++ b/config/config.yaml @@ -1,3 +1,6 @@ # config.yaml input_dir: data/lris2_flats -output_dir: output \ No newline at end of file +output_dir: output + +# Set to true to start Prefect UI (at http://127.0.0.1:4200) +use_prefect_server: true diff --git a/src/core/flat.py b/src/core/flat.py index 3055d01..1c88f19 100644 --- a/src/core/flat.py +++ b/src/core/flat.py @@ -1,38 +1,51 @@ import os -from prefect import task import numpy as np from astropy.io import fits -from typing import Tuple +from typing import Tuple, List, Optional +from scipy.interpolate import splrep, BSpline +from .tracing import trace_slits_1d + -@task(name="Load FITS Frame", description="Load a flat field FITS file", tags=["load"]) def load_flat_frame(filepath: str) -> Tuple[np.ndarray, dict]: """Load a FITS file and return its data and header.""" with fits.open(filepath) as hdul: - data = hdul[0].data + data = hdul[0].data.astype(np.float64) header = hdul[0].header return data, header -@task(name="Normalize Flat", description="Normalize the flat field by median", tags=["normalize"]) def normalize_flat(data: np.ndarray) -> np.ndarray: - """Normalize the flat field data by dividing by the median value.""" - median = np.median(data[data > 0]) + """ + Normalize the flat field data by dividing by the median of the illuminated area. + """ + data = data.astype(np.float64) + + # Define illuminated area as pixels above 10% of maximum value + threshold = 0.1 * np.max(data) + illuminated_mask = data > threshold + + # Compute median from illuminated area only + if np.any(illuminated_mask): + median = np.median(data[illuminated_mask]) + else: + # Fallback: use median of all positive values + median = np.median(data[data > 0]) + + # Normalize - result should be ~1.0 in illuminated areas return data / median -@task(name="Create Flat Correction", description="Invert normalized flat to create correction map", tags=["correction"]) def create_flat_correction(norm_data: np.ndarray) -> np.ndarray: """Create a flat correction map by inverting the normalized flat.""" - correction = 1.0 / (norm_data + 1e-8) # Avoid division by zero + correction = 1.0 / (norm_data + 1e-8) # Avoid division by zero correction[np.isnan(correction)] = 1.0 correction[np.isinf(correction)] = 1.0 return correction -@task(name="Save Corrected FITS", description="Apply flat correction and write corrected FITS file", tags=["output", "fits"]) def save_corrected_fits(original_data: np.ndarray, correction: np.ndarray, header: dict, output_path: str) -> str: """Apply the flat correction to the original data and save as a new FITS file.""" - corrected_data = original_data * correction + corrected_data = original_data.astype(np.float64) * correction # Add DRP history to header header.add_history("DRP: Flat field correction applied") @@ -43,3 +56,154 @@ def save_corrected_fits(original_data: np.ndarray, correction: np.ndarray, heade os.makedirs(os.path.dirname(output_path), exist_ok=True) hdul.writeto(output_path, overwrite=True) return output_path + + +def fit_bspline_1d(x: np.ndarray, y: np.ndarray, n_knots: int = 100, k: int = 3) -> Tuple[np.ndarray, np.ndarray]: + """ + Fit a B-spline to 1D data and return the fit. + + Simplified version of Bspline.iterfit + + Args: + x: x-coordinates (must be sorted) + y: y-values + n_knots: Number of internal knots for the spline + k: Spline order (3 = cubic) + + Returns: + Tuple of (fitted_y, knots_used) + """ + # Remove NaN and inf values + valid = np.isfinite(x) & np.isfinite(y) + x_valid = x[valid] + y_valid = y[valid] + + if len(x_valid) < k + 1: + # Not enough points for spline, return mean + return np.full_like(y, np.mean(y_valid)), np.array([]) + + # Create knots evenly spaced across the data range + x_min, x_max = np.min(x_valid), np.max(x_valid) + # Adjust n_knots if we don't have enough data points + n_knots = min(n_knots, len(x_valid) // (k + 1)) + if n_knots < 1: + n_knots = 1 + + knots = np.linspace(x_min, x_max, n_knots + 2)[1:-1] + + # Fit the spline + try: + tck = splrep(x_valid, y_valid, t=knots, k=k, s=0) + fitted = BSpline(*tck)(x) + except Exception: + # Fallback to polynomial fit if spline fails + coeffs = np.polyfit(x_valid, y_valid, min(3, len(x_valid) - 1)) + fitted = np.polyval(coeffs, x) + knots = np.array([]) + + return fitted, knots + + +def normalize_flat_spectroscopic( + data: np.ndarray, + slit_positions: Optional[List[int]] = None, + slit_width: int = 50, + n_knots_spectral: int = 100, + low_signal_threshold: float = 30.0, + edge_trim_pixels: int = 5 +) -> np.ndarray: + """ + Spectroscopic flat normalization using B-spline fitting. + (simplified from KCWI) + + 1. Identifies slit traces + 2. For each slit, fits a smooth B-spline along the spectral direction + 3. Creates a smooth model of the illumination pattern + 4. Normalizes the flat to create a ratio map + + Args: + data: 2D flat field image (spatial x spectral) + slit_positions: Optional list of slit center positions. If None, auto-detect. + slit_width: Width around each slit center to extract (pixels) + n_knots_spectral: Number of knots for B-spline fit along spectral direction + low_signal_threshold: Pixels below this value get no correction + edge_trim_pixels: Number of pixels to trim from edges of each slit + + Returns: + Normalized flat field (ratio map for correction) + """ + data = data.astype(np.float64) + ny, nx = data.shape + + # Auto-detect slits if not provided + if slit_positions is None: + slit_positions = trace_slits_1d(data) + + if len(slit_positions) == 0: + # Fallback to simple normalization if no slits found + print("Warning: No slits detected, using simple normalization") + return normalize_flat(data) + + print(f"Processing {len(slit_positions)} slits") + + # Create smooth model of the flat field + flat_model = np.zeros_like(data) + + # Process each slit + for slit_idx, slit_center in enumerate(slit_positions): + # Define slit region + y_start = max(0, slit_center - slit_width // 2) + y_end = min(ny, slit_center + slit_width // 2) + + # Extract slit region + slit_data = data[y_start:y_end, :] + + # For each column (spectral pixel), get median across slit + spectral_profile = np.median(slit_data, axis=0) + + # Fit B-spline along spectral direction + x_coords = np.arange(nx) + spectral_fit, _ = fit_bspline_1d(x_coords, spectral_profile, n_knots=n_knots_spectral) + + # Replicate the fit across the slit width + for i in range(y_start, y_end): + # Apply edge trimming + if i < y_start + edge_trim_pixels or i >= y_end - edge_trim_pixels: + flat_model[i, :] = 0.0 # Will be masked later + else: + flat_model[i, :] = spectral_fit + + # Create ratio map + ratio = np.ones_like(data) + + # Only correct where we have valid model and good signal + valid = (flat_model > 0) & (data > low_signal_threshold) + ratio[valid] = flat_model[valid] / data[valid] + + # Trim unreasonable values + ratio[ratio < 0] = 1.0 # Negative ratios (line 906-908) + ratio[ratio > 3.0] = 1.0 # High ratios at edges (line 911-915) + ratio[~np.isfinite(ratio)] = 1.0 # NaN/inf values + + # Low signal regions get no correction + ratio[data < low_signal_threshold] = 1.0 + + return ratio + + +def create_master_flat(flat_data: np.ndarray, **kwargs) -> np.ndarray: + """ + Create a master flat correction map. + + Args: + flat_data: 2D flat field image + **kwargs: Additional arguments + + Returns: + Master flat correction map (multiply science data by this) + """ + ratio = normalize_flat_spectroscopic(flat_data, **kwargs) + # The ratio is already the correction map + correction = ratio + + return correction diff --git a/src/core/flows.py b/src/core/flows.py deleted file mode 100644 index 9e07c7a..0000000 --- a/src/core/flows.py +++ /dev/null @@ -1,53 +0,0 @@ -from prefect import flow, task, get_run_logger -from prefect.task_runners import ConcurrentTaskRunner -import os -from core.flat import ( - load_flat_frame, - normalize_flat, - create_flat_correction, - save_corrected_fits, -) -from core.tracing import trace_slits_1d, save_trace_solution -from core.qa import generate_qa_plot - -@task(name="Process Single LRIS2 Flat Frame", description="Run all DRP steps on a single flat FITS file") -def process_flat_frame(filepath: str, output_dir: str): - """Process a single LRIS2 flat field FITS file through all DRP steps.""" - filename = os.path.splitext(os.path.basename(filepath))[0] - - data, header = load_flat_frame(filepath) - norm = normalize_flat(data) - correction = create_flat_correction(norm) - slit_positions = trace_slits_1d(data) - - trace_path = os.path.join(output_dir, filename, "slit_trace.txt") - plot_path = os.path.join(output_dir, filename, "flat_norm_qa.png") - corrected_path = os.path.join(output_dir, filename, "flat_corrected.fits") - - save_trace_solution(slit_positions, trace_path) - generate_qa_plot(norm, plot_path, title=f"Normalized Flat: {filename}") - save_corrected_fits(data, correction, header, corrected_path) - - -@flow( - name="Batch Process LRIS2 Flats", - description="Batch process all LRIS2 flats in parallel", - task_runner=ConcurrentTaskRunner(max_workers=2), # Adjust concurrency here -) -def batch_process_all_flats(input_dir: str, output_dir: str): - """Batch process all LRIS2 flat field FITS files in the input directory.""" - logger = get_run_logger() - - fits_files = [ - os.path.join(input_dir, f) - for f in os.listdir(input_dir) - if f.lower().endswith(".fits") - ] - - logger.info(f"Found {len(fits_files)} FITS files.") - - futures = [process_flat_frame.submit(fp, output_dir) for fp in fits_files] - - # Wait for all to complete - for future in futures: - future.result() diff --git a/src/core/qa.py b/src/core/qa.py index dfa0f26..7b3d638 100644 --- a/src/core/qa.py +++ b/src/core/qa.py @@ -1,4 +1,3 @@ -from prefect import task import os import matplotlib matplotlib.use('Agg') # Use non-interactive backend for saving plots @@ -6,7 +5,6 @@ import numpy as np -@task(name="Generate QA Plot", description="Save normalized flat as PNG", tags=["qa", "plot"]) def generate_qa_plot(data: np.ndarray, output_path: str, title: str = "Flat QA") -> str: """Generate a QA plot for the normalized flat field data.""" diff --git a/src/core/tracing.py b/src/core/tracing.py index 5829297..fe15f5f 100644 --- a/src/core/tracing.py +++ b/src/core/tracing.py @@ -1,10 +1,9 @@ -from prefect import task from typing import List import os import numpy as np from scipy.signal import find_peaks -@task(name="Trace Slits", description="Find slit center peaks by collapsing along dispersion axis", tags=["trace"]) + def trace_slits_1d(data: np.ndarray) -> List[int]: """Trace slit positions by finding peaks in the 1D profile of the flat field data.""" profile = np.median(data, axis=1) @@ -12,7 +11,6 @@ def trace_slits_1d(data: np.ndarray) -> List[int]: return list(peaks) -@task(name="Save Trace Solution", description="Write slit positions to file", tags=["save", "trace"]) def save_trace_solution(slit_positions: List[int], output_path: str) -> str: """Save the traced slit positions to a text file.""" os.makedirs(os.path.dirname(output_path), exist_ok=True) diff --git a/src/keck_primitives/create_master_flat.py b/src/keck_primitives/create_master_flat.py new file mode 100644 index 0000000..a7fb8cf --- /dev/null +++ b/src/keck_primitives/create_master_flat.py @@ -0,0 +1,58 @@ +from keckdrpframework.primitives.base_img import BaseImg +from core.flat import create_master_flat + + +class CreateMasterFlat(BaseImg): + """ + Create a master flat correction + + This primitive wraps the create_master_flat function from core.flat. + """ + + def __init__(self, action, context): + BaseImg.__init__(self, action, context) + self.logger = context.pipeline_logger if hasattr(context, 'pipeline_logger') else None + + def _perform(self, args, config=None): + """ + Create master flat correction. + + Args: + args: Arguments object with: + - flat_data: 2D numpy array of flat field data + - slit_positions: Optional list of slit positions (for spectroscopic method) + - slit_width: Width around slit centers (default: 50) + - n_knots_spectral: Number of B-spline knots (default: 100) + - low_signal_threshold: Threshold for masking low signal (default: 30.0) + - edge_trim_pixels: Pixels to trim from slit edges (default: 5) + + Returns: + Arguments object with: + - correction: Master flat correction array + """ + if config is None: + config = {} + + flat_data = args["flat_data"] + method = args["method"] if "method" in args else "spectroscopic" + + # Get optional parameters for spectroscopic method + kwargs = {} + if method == "spectroscopic": + if "slit_positions" in args: + kwargs["slit_positions"] = args["slit_positions"] + kwargs["slit_width"] = args["slit_width"] if "slit_width" in args else 50 + kwargs["n_knots_spectral"] = args["n_knots_spectral"] if "n_knots_spectral" in args else 100 + kwargs["low_signal_threshold"] = args["low_signal_threshold"] if "low_signal_threshold" in args else 30.0 + kwargs["edge_trim_pixels"] = args["edge_trim_pixels"] if "edge_trim_pixels" in args else 5 + + if self.logger: + self.logger.info(f"Creating master flat using {method} method") + + # Create the master flat correction + correction = create_master_flat(flat_data, **kwargs) + + if self.logger: + self.logger.info(f"Master flat correction created successfully") + + return {"correction": correction} diff --git a/src/keck_primitives/normalize_flat.py b/src/keck_primitives/load_flat.py similarity index 50% rename from src/keck_primitives/normalize_flat.py rename to src/keck_primitives/load_flat.py index b702946..d1bc566 100644 --- a/src/keck_primitives/normalize_flat.py +++ b/src/keck_primitives/load_flat.py @@ -1,13 +1,14 @@ -import numpy as np from keckdrpframework.models.arguments import Arguments from keckdrpframework.primitives.base_primitive import BasePrimitive +from core.flat import load_flat_frame -class NormalizeFlat(BasePrimitive): + +class LoadFlat(BasePrimitive): def __init__(self, action, context): super().__init__(action, context) def _perform(self, input_args: Arguments, config: dict) -> dict: - data = input_args["flat_data"] - median = np.median(data[data > 0]) - norm = data / median - return {"norm_data": norm} + """Load a FITS file and return its data and header.""" + filepath = input_args["filepath"] + data, header = load_flat_frame(filepath) + return {"flat_data": data, "header": header} diff --git a/src/keck_primitives/qa_plot.py b/src/keck_primitives/qa_plot.py new file mode 100644 index 0000000..bc16027 --- /dev/null +++ b/src/keck_primitives/qa_plot.py @@ -0,0 +1,17 @@ +from keckdrpframework.models.arguments import Arguments +from keckdrpframework.primitives.base_primitive import BasePrimitive +from core.qa import generate_qa_plot + + +class GenerateQAPlot(BasePrimitive): + def __init__(self, action, context): + super().__init__(action, context) + + def _perform(self, input_args: Arguments, config: dict) -> dict: + """Generate a QA plot for the normalized flat field data.""" + data = input_args["data"] + output_path = input_args["output_path"] + title = input_args["title"] if "title" in input_args else "Flat QA" + + result_path = generate_qa_plot(data, output_path, title) + return {"output_path": result_path} diff --git a/src/keck_primitives/save_corrected.py b/src/keck_primitives/save_corrected.py new file mode 100644 index 0000000..66da9e0 --- /dev/null +++ b/src/keck_primitives/save_corrected.py @@ -0,0 +1,18 @@ +from keckdrpframework.models.arguments import Arguments +from keckdrpframework.primitives.base_primitive import BasePrimitive +from core.flat import save_corrected_fits + + +class SaveCorrectedFits(BasePrimitive): + def __init__(self, action, context): + super().__init__(action, context) + + def _perform(self, input_args: Arguments, config: dict) -> dict: + """Apply the flat correction to the original data and save as a new FITS file.""" + original_data = input_args["original_data"] + correction = input_args["correction"] + header = input_args["header"] + output_path = input_args["output_path"] + + result_path = save_corrected_fits(original_data, correction, header, output_path) + return {"output_path": result_path} diff --git a/src/keck_primitives/save_trace.py b/src/keck_primitives/save_trace.py new file mode 100644 index 0000000..b1760ce --- /dev/null +++ b/src/keck_primitives/save_trace.py @@ -0,0 +1,16 @@ +from keckdrpframework.models.arguments import Arguments +from keckdrpframework.primitives.base_primitive import BasePrimitive +from core.tracing import save_trace_solution + + +class SaveTraceSolution(BasePrimitive): + def __init__(self, action, context): + super().__init__(action, context) + + def _perform(self, input_args: Arguments, config: dict) -> dict: + """Save the traced slit positions to a text file.""" + slit_positions = input_args["slit_positions"] + output_path = input_args["output_path"] + + result_path = save_trace_solution(slit_positions, output_path) + return {"output_path": result_path} diff --git a/src/keck_primitives/trace_slits.py b/src/keck_primitives/trace_slits.py index 195d7cc..3b58469 100644 --- a/src/keck_primitives/trace_slits.py +++ b/src/keck_primitives/trace_slits.py @@ -1,7 +1,7 @@ -import numpy as np -from scipy.signal import find_peaks from keckdrpframework.models.arguments import Arguments from keckdrpframework.primitives.base_primitive import BasePrimitive +from core.tracing import trace_slits_1d + class TraceSlits1D(BasePrimitive): def __init__(self, action, context): @@ -9,6 +9,5 @@ def __init__(self, action, context): def _perform(self, input_args: Arguments, config: dict) -> dict: data = input_args["flat_data"] - profile = np.median(data, axis=1) - peaks, _ = find_peaks(profile, distance=20, prominence=0.05) - return {"slit_positions": list(peaks)} + slit_positions = trace_slits_1d(data) + return {"slit_positions": slit_positions} diff --git a/src/main.py b/src/main.py index ee46364..ad7bc5f 100644 --- a/src/main.py +++ b/src/main.py @@ -5,6 +5,8 @@ """ import yaml import os +import subprocess +import time from workflows.flows.batch_flat_flow import batch_process_all_flats def load_config(config_path="config/config.yaml"): @@ -16,8 +18,47 @@ def load_config(config_path="config/config.yaml"): config = load_config() input_dir = config["input_dir"] output_dir = config["output_dir"] + use_prefect_server = config.get("use_prefect_server", True) os.makedirs(output_dir, exist_ok=True) - print(f"🟢 Starting batch processing of FITS files in {input_dir}") - batch_process_all_flats(input_dir=input_dir, output_dir=output_dir) + server_process = None + if use_prefect_server: + print("šŸš€ Starting Prefect server with UI...") + # Start Prefect server in background + server_process = subprocess.Popen( + ["prefect", "server", "start"], + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + text=True + ) + + print("āœ… Prefect server started!") + print("🌐 Dashboard available at: http://127.0.0.1:4200") + print() + + try: + print(f"🟢 Starting batch processing of FITS files in {input_dir}") + batch_process_all_flats(input_dir=input_dir, output_dir=output_dir) + + if use_prefect_server: + print("\nāœ… Pipeline completed!") + print("🌐 View the dashboard at: http://127.0.0.1:4200") + print("šŸ“Š Press Ctrl+C to stop the server and exit.\n") + + try: + # Keep running so server stays up + import signal + signal.pause() + + except KeyboardInterrupt: + print("\nšŸ›‘ Received interrupt signal...") + raise + except KeyboardInterrupt: + print("\nšŸ›‘ Shutting down...") + finally: + if server_process: + print("Stopping Prefect server...") + server_process.terminate() + server_process.wait() + print("Server stopped.") diff --git a/src/workflows/flows/batch_flat_flow.py b/src/workflows/flows/batch_flat_flow.py index 2703fd9..e75aebe 100644 --- a/src/workflows/flows/batch_flat_flow.py +++ b/src/workflows/flows/batch_flat_flow.py @@ -2,8 +2,7 @@ from prefect import flow, task, get_run_logger from prefect.task_runners import ConcurrentTaskRunner from workflows.prefect_tasks.load_flat import load_flat_frame_task -from workflows.prefect_tasks.normalize_flat import normalize_flat_task -from workflows.prefect_tasks.create_correction import create_flat_correction_task +from workflows.prefect_tasks.create_master_flat import create_master_flat_task from workflows.prefect_tasks.trace_slits import trace_slits_task from workflows.prefect_tasks.save_corrected import save_corrected_fits_task from workflows.prefect_tasks.save_trace import save_trace_solution_task @@ -12,7 +11,13 @@ @task(name="Process Single Flat Frame") def process_single_flat_frame(flat_fits_path: str, output_dir: str): - """Process a single LRIS2 flat FITS file through all DRP steps.""" + """ + Process a single LRIS2 flat FITS file through all DRP steps. + + Args: + flat_fits_path: Path to input FITS file + output_dir: Output directory for results + """ logger = get_run_logger() filename = os.path.splitext(os.path.basename(flat_fits_path))[0] @@ -25,27 +30,45 @@ def process_single_flat_frame(flat_fits_path: str, output_dir: str): os.makedirs(os.path.dirname(corrected_output), exist_ok=True) # Load FITS + logger.info(f"Loading {flat_fits_path}") data, header = load_flat_frame_task(flat_fits_path) - # DRP steps - norm = normalize_flat_task(data) - correction = create_flat_correction_task(norm) + # Trace slits + logger.info("Tracing slits") slit_positions = trace_slits_task(data) + # Create master flat + logger.info("Creating master flat correction") + correction = create_master_flat_task( + data, + slit_positions=slit_positions, + slit_width=50, + n_knots_spectral=100, + low_signal_threshold=30.0, + edge_trim_pixels=5 + ) + # Save outputs + logger.info("Saving results") save_corrected_fits_task(data, correction, header, corrected_output) save_trace_solution_task(slit_positions, trace_output) - generate_qa_plot_task(norm, qa_output) + generate_qa_plot_task(correction, qa_output) logger.info(f"Finished processing {flat_fits_path}") @flow( name="Batch Process LRIS2 Flats", - description="Process all flat frames concurrently using Prefect", + description="Process all flat frames using spectroscopic flat fielding", task_runner=ConcurrentTaskRunner(max_workers=2), # You can adjust this ) def batch_process_all_flats(input_dir: str, output_dir: str): - """Process all FITS files in a directory using concurrent subflows.""" + """ + Process all FITS files in a directory using spectroscopic flat fielding. + + Args: + input_dir: Directory containing input FITS files + output_dir: Directory for output files + """ logger = get_run_logger() fits_files = [ diff --git a/src/workflows/prefect_tasks/create_correction.py b/src/workflows/prefect_tasks/create_correction.py deleted file mode 100644 index 349e050..0000000 --- a/src/workflows/prefect_tasks/create_correction.py +++ /dev/null @@ -1,6 +0,0 @@ -from prefect import task -from core.flat import create_flat_correction - -@task(name="Create Flat Correction") -def create_flat_correction_task(norm_data): - return create_flat_correction.fn(norm_data) \ No newline at end of file diff --git a/src/workflows/prefect_tasks/create_master_flat.py b/src/workflows/prefect_tasks/create_master_flat.py new file mode 100644 index 0000000..75d542b --- /dev/null +++ b/src/workflows/prefect_tasks/create_master_flat.py @@ -0,0 +1,45 @@ +from prefect import task +from keckdrpframework.models.arguments import Arguments +from keck_primitives.create_master_flat import CreateMasterFlat +from keck_primitives.utils import DummyAction, DummyContext + + +@task(name="Create Master Flat") +def create_master_flat_task( + flat_data, + slit_positions=None, + slit_width=50, + n_knots_spectral=100, + low_signal_threshold=30.0, + edge_trim_pixels=5 +): + """ + Create master flat correction using spectroscopic method. + + Args: + flat_data: 2D numpy array of flat field data + slit_positions: Optional list of slit positions (None = auto-detect) + slit_width: Width around slit centers in pixels + n_knots_spectral: Number of B-spline knots for spectral fitting + low_signal_threshold: Pixels below this value get no correction + edge_trim_pixels: Number of pixels to trim from slit edges + + Returns: + Master flat correction array (multiply science data by this) + """ + args = Arguments() + args["flat_data"] = flat_data + args["method"] = "spectroscopic" + + if slit_positions is not None: + args["slit_positions"] = slit_positions + args["slit_width"] = slit_width + args["n_knots_spectral"] = n_knots_spectral + args["low_signal_threshold"] = low_signal_threshold + args["edge_trim_pixels"] = edge_trim_pixels + + action = DummyAction(args=args) + context = DummyContext() + + result = CreateMasterFlat(action, context)._perform(args, config={}) + return result["correction"] diff --git a/src/workflows/prefect_tasks/load_flat.py b/src/workflows/prefect_tasks/load_flat.py index 62089c8..4bd7764 100644 --- a/src/workflows/prefect_tasks/load_flat.py +++ b/src/workflows/prefect_tasks/load_flat.py @@ -1,6 +1,16 @@ from prefect import task -from core.flat import load_flat_frame +from keckdrpframework.models.arguments import Arguments +from keck_primitives.load_flat import LoadFlat +from keck_primitives.utils import DummyAction, DummyContext + @task(name="Load Flat Frame") def load_flat_frame_task(filepath: str): - return load_flat_frame.fn(filepath) \ No newline at end of file + args = Arguments() + args["filepath"] = filepath + + action = DummyAction(args=args) + context = DummyContext() + + result = LoadFlat(action, context)._perform(args, config={}) + return result["flat_data"], result["header"] \ No newline at end of file diff --git a/src/workflows/prefect_tasks/normalize_flat.py b/src/workflows/prefect_tasks/normalize_flat.py deleted file mode 100644 index 6599645..0000000 --- a/src/workflows/prefect_tasks/normalize_flat.py +++ /dev/null @@ -1,16 +0,0 @@ -from prefect import task -from keckdrpframework.models.arguments import Arguments -from keck_primitives.normalize_flat import NormalizeFlat -from keck_primitives.utils import DummyAction, DummyContext - - -@task(name="Normalize Flat") -def normalize_flat_task(data): - args = Arguments() - args["flat_data"] = data - - action = DummyAction(args=args) - context = DummyContext() - - result = NormalizeFlat(action, context)._perform(args, config={}) - return result["norm_data"] diff --git a/src/workflows/prefect_tasks/qa_plot.py b/src/workflows/prefect_tasks/qa_plot.py index fc23cad..8a94f69 100644 --- a/src/workflows/prefect_tasks/qa_plot.py +++ b/src/workflows/prefect_tasks/qa_plot.py @@ -1,6 +1,18 @@ from prefect import task -from core.qa import generate_qa_plot +from keckdrpframework.models.arguments import Arguments +from keck_primitives.qa_plot import GenerateQAPlot +from keck_primitives.utils import DummyAction, DummyContext + @task(name="Generate QA Plot") def generate_qa_plot_task(data, output_path: str, title: str = "Flat QA"): - return generate_qa_plot.fn(data, output_path, title) \ No newline at end of file + args = Arguments() + args["data"] = data + args["output_path"] = output_path + args["title"] = title + + action = DummyAction(args=args) + context = DummyContext() + + result = GenerateQAPlot(action, context)._perform(args, config={}) + return result["output_path"] \ No newline at end of file diff --git a/src/workflows/prefect_tasks/save_corrected.py b/src/workflows/prefect_tasks/save_corrected.py index 61d1cb3..4f12796 100644 --- a/src/workflows/prefect_tasks/save_corrected.py +++ b/src/workflows/prefect_tasks/save_corrected.py @@ -1,6 +1,19 @@ from prefect import task -from core.flat import save_corrected_fits +from keckdrpframework.models.arguments import Arguments +from keck_primitives.save_corrected import SaveCorrectedFits +from keck_primitives.utils import DummyAction, DummyContext + @task(name="Save Corrected FITS") def save_corrected_fits_task(original_data, correction, header, output_path: str): - return save_corrected_fits.fn(original_data, correction, header, output_path) \ No newline at end of file + args = Arguments() + args["original_data"] = original_data + args["correction"] = correction + args["header"] = header + args["output_path"] = output_path + + action = DummyAction(args=args) + context = DummyContext() + + result = SaveCorrectedFits(action, context)._perform(args, config={}) + return result["output_path"] \ No newline at end of file diff --git a/src/workflows/prefect_tasks/save_trace.py b/src/workflows/prefect_tasks/save_trace.py index d78b3bc..c32a8d2 100644 --- a/src/workflows/prefect_tasks/save_trace.py +++ b/src/workflows/prefect_tasks/save_trace.py @@ -1,6 +1,17 @@ from prefect import task -from core.tracing import save_trace_solution +from keckdrpframework.models.arguments import Arguments +from keck_primitives.save_trace import SaveTraceSolution +from keck_primitives.utils import DummyAction, DummyContext + @task(name="Save Trace Solution") def save_trace_solution_task(slit_positions, output_path: str): - return save_trace_solution.fn(slit_positions, output_path) \ No newline at end of file + args = Arguments() + args["slit_positions"] = slit_positions + args["output_path"] = output_path + + action = DummyAction(args=args) + context = DummyContext() + + result = SaveTraceSolution(action, context)._perform(args, config={}) + return result["output_path"] \ No newline at end of file