class StackFrame(Base):
    """One subframe of a smart-stacking molecule (observation request).

    Rows are keyed by (moluid, stack_num): the molecule identifier plus the
    frame's ordinal within the stack.  ``filepath`` stays NULL until the frame
    has been reduced; ``status`` moves from 'active' to 'complete' or
    'timeout' when the stack is finalized.
    """
    __tablename__ = 'stack_frames'
    id = Column(Integer, primary_key=True, autoincrement=True)
    # Molecule (stack) identifier shared by every frame in the stack.
    moluid = Column(String(100), nullable=False, index=True)
    # Ordinal of this frame within the stack (MOLFRNUM header value).
    stack_num = Column(Integer, nullable=False)
    # Total number of frames the instrument expects to deliver (FRMTOTAL).
    frmtotal = Column(Integer, nullable=False)
    camera = Column(String(50), nullable=False, index=True)
    # Path of the reduced frame; NULL while reduction is still pending.
    filepath = Column(String(255), nullable=True)
    # True when the instrument flagged this frame as the last of the stack.
    is_last = Column(Boolean, default=False)
    # Lifecycle: 'active' -> 'complete' | 'timeout'.
    status = Column(String(20), default='active', nullable=False)
    dateobs = Column(DateTime, nullable=True)
    created_at = Column(DateTime, default=datetime.datetime.utcnow)
    completed_at = Column(DateTime, nullable=True)
    __table_args__ = (
        UniqueConstraint('moluid', 'stack_num', name='uq_stack_moluid_num'),
    )

    def __repr__(self):
        # Added for debuggability only; does not affect persistence behavior.
        return (f'StackFrame(moluid={self.moluid!r}, stack_num={self.stack_num}, '
                f'status={self.status!r})')
def insert_stack_frame(db_address, moluid, stack_num, frmtotal, camera, filepath, is_last, dateobs):
    """Insert a stack frame record into the database. Duplicate (moluid, stack_num) is a no-op."""
    record = StackFrame(
        moluid=moluid,
        stack_num=stack_num,
        frmtotal=frmtotal,
        camera=camera,
        filepath=filepath,
        is_last=is_last,
        dateobs=dateobs,
    )
    try:
        with get_session(db_address) as session:
            session.add(record)
    except IntegrityError:
        # The unique (moluid, stack_num) constraint fired at commit: the
        # frame is already recorded, so keep the existing row untouched.
        pass


def get_stack_frames(db_address, moluid):
    """Return every StackFrame row recorded for *moluid*.

    NOTE(review): rows are returned after the session context exits, so they
    are detached — assumes get_session does not expire attributes on commit;
    confirm against the session factory configuration.
    """
    with get_session(db_address) as session:
        query = session.query(StackFrame).filter(StackFrame.moluid == moluid)
        return query.all()


def mark_stack_complete(db_address, moluid, status='complete'):
    """Set *status* and a completion timestamp on every frame of a moluid."""
    completed_time = datetime.datetime.utcnow()
    with get_session(db_address) as session:
        frames = session.query(StackFrame).filter(StackFrame.moluid == moluid)
        frames.update({'status': status, 'completed_at': completed_time})


def update_stack_frame_filepath(db_address, moluid, stack_num, filepath):
    """Set the reduced filepath on one existing (moluid, stack_num) record."""
    with get_session(db_address) as session:
        target = session.query(StackFrame).filter(
            StackFrame.moluid == moluid,
            StackFrame.stack_num == stack_num,
        )
        target.update({'filepath': filepath})


def cleanup_old_records(db_address, retention_days):
    """Delete finalized stack frame rows older than *retention_days* days.

    Only rows that have left the 'active' state are eligible, so in-flight
    stacks are never pruned.
    """
    oldest_allowed = datetime.datetime.utcnow() - datetime.timedelta(days=retention_days)
    with get_session(db_address) as session:
        stale = session.query(StackFrame).filter(
            StackFrame.status != 'active',
            StackFrame.completed_at < oldest_allowed,
        )
        stale.delete()
class SubframeListener(ConsumerMixin):
    """Kombu consumer that forwards subframe messages to Celery."""

    def __init__(self, runtime_context):
        self.runtime_context = runtime_context

    def on_connection_error(self, exc, interval):
        logger.error('{0}. Retrying connection in {1} seconds...'.format(exc, interval))
        self.connection = self.connection.clone()
        self.connection.ensure_connection(max_retries=10)

    def get_consumers(self, Consumer, channel):
        """Bind to banzai_stack_queue; prefetch of 1 spreads work evenly."""
        queue = Queue(self.runtime_context.STACK_QUEUE_NAME)
        consumer = Consumer(queues=[queue], callbacks=[self.on_message])
        consumer.qos(prefetch_count=1)
        return [consumer]

    def on_message(self, body, message):
        """Validate and dispatch to Celery for processing.

        Invalid messages are acked (dropped) rather than requeued so a
        malformed payload cannot poison the queue.
        """
        if not validate_message(body):
            logger.error('Invalid message received, missing required fields')
            message.ack()
            return

        process_subframe.apply_async(
            args=(body, vars(self.runtime_context)),
            queue=self.runtime_context.SUBFRAME_TASK_QUEUE_NAME,
        )
        message.ack()


def run_subframe_worker():
    """Entry point for the subframe listener."""
    runtime_context = parse_args(settings)

    logging.getLogger('amqp').setLevel(logging.WARNING)
    logger.info('Starting subframe listener')

    listener = SubframeListener(runtime_context)

    with Connection(runtime_context.broker_url) as connection:
        listener.connection = connection.clone()
        try:
            listener.run()
        except listener.connection.connection_errors:
            listener.connection = connection.clone()
            # BUG FIX: ensure_connection lives on the kombu Connection, not on
            # the ConsumerMixin.  The original called
            # listener.ensure_connection(...), which would raise
            # AttributeError on any connection error.
            listener.connection.ensure_connection(max_retries=10)
        except KeyboardInterrupt:
            logger.info('Shutting down subframe listener.')
@app.task(name='celery.process_subframe', bind=True, reject_on_worker_lost=True, max_retries=5)
def process_subframe(self, body: dict, runtime_context: dict):
    """Reduce a subframe, record it in the DB, and notify the stacking worker.

    body must contain 'fits_file' (path to the raw subframe) and may carry
    'last_frame'; runtime_context is the dict form of the pipeline Context.
    Failures are logged rather than raised so a bad frame is not retried
    forever.
    """
    # Defensive fix: os.path.join is used below, but scheduling.py's visible
    # top-of-file imports do not include os.  A function-local import is
    # harmless if the module already imports it — TODO confirm against the
    # full module header.
    import os

    try:
        runtime_context = Context(runtime_context)
        filepath = body['fits_file']

        header = fits_utils.get_primary_header(filepath)

        camera = header.get('INSTRUME', '').strip()
        dateobs_str = header.get('DATE-OBS', '')
        dateobs = parse_date_obs(dateobs_str) if dateobs_str else None

        # Phase 1: Insert the DB record before reduction so the stacking
        # worker already sees this frame as pending.
        insert_stack_frame(
            runtime_context.db_address,
            moluid=header['MOLUID'],
            stack_num=header['MOLFRNUM'],
            frmtotal=header['FRMTOTAL'],
            camera=camera,
            filepath=None,
            is_last=body.get('last_frame', False),
            dateobs=dateobs,
        )

        # Phase 2: Run the reduction pipeline on the raw subframe.
        images = stage_utils.run_pipeline_stages([{'path': filepath}], runtime_context)

        # Phase 3: Record where the reduced product landed.
        if images:
            reduced_path = os.path.join(
                images[0].get_output_directory(runtime_context),
                images[0].get_output_filename(runtime_context),
            )
            update_stack_frame_filepath(
                runtime_context.db_address,
                header['MOLUID'],
                header['MOLFRNUM'],
                reduced_path,
            )

        # Phase 4: Wake the per-camera stacking worker via Redis.  Skipped
        # when no REDIS_URL is configured (e.g. non-stacking deployments).
        redis_url = getattr(runtime_context, 'REDIS_URL', None)
        if redis_url:
            redis_client = redis.Redis.from_url(redis_url)
            push_notification(redis_client, camera, header['MOLUID'])

    except Retry:
        raise
    except Exception:
        logger.error("Error processing subframe: {error}".format(error=logs.format_exception()))
REQUIRED_MESSAGE_FIELDS = ('fits_file', 'last_frame', 'instrument_enqueue_timestamp')


def validate_message(body):
    """Return True when every required field is present in *body*."""
    for field in REQUIRED_MESSAGE_FIELDS:
        if field not in body:
            return False
    return True


def check_stack_complete(frames, frmtotal):
    """Return True if the stack is ready to finalize.

    A stack is complete when every received frame has a reduced filepath and
    either all *frmtotal* expected frames have arrived or one frame carried
    the instrument's is_last flag.
    """
    # Any un-reduced frame blocks finalization regardless of arrival count.
    if any(frame.filepath is None for frame in frames):
        return False
    received_all = len(frames) == frmtotal
    saw_last_flag = any(frame.is_last for frame in frames)
    return received_all or saw_last_flag


# ---------------------------------------------------------------------------
# Notifications
# ---------------------------------------------------------------------------

REDIS_KEY_PREFIX = 'stack:notify:'


def push_notification(redis_client, camera, moluid):
    """Record that *moluid* has a new subframe on *camera*'s Redis list."""
    redis_client.lpush(REDIS_KEY_PREFIX + camera, moluid)


def drain_notifications(redis_client, camera):
    """Atomically take all pending moluid notifications for *camera*.

    Returns a deduplicated set of moluids; an absent key yields an empty set.
    """
    live_key = f'{REDIS_KEY_PREFIX}{camera}'
    staging_key = live_key + ':draining'
    try:
        # RENAME is atomic, so pushes racing with this drain land on a fresh
        # live key and are picked up next cycle instead of being lost.
        redis_client.rename(live_key, staging_key)
    except redis_lib.exceptions.ResponseError:
        # Source key does not exist: nothing pending.
        return set()
    entries = redis_client.lrange(staging_key, 0, -1)
    redis_client.delete(staging_key)
    return {entry.decode() if isinstance(entry, bytes) else entry for entry in entries}


# ---------------------------------------------------------------------------
# Worker
# ---------------------------------------------------------------------------

def run_worker_loop(camera, db_address, redis_url, timeout_minutes=20, retention_days=30, poll_interval=5):
    """Per-camera worker loop: drain notifications, time out stale stacks, prune old rows."""
    redis_client = redis_lib.Redis.from_url(redis_url)
    while True:
        process_notifications(db_address, redis_client, camera)
        check_timeout(db_address, camera, timeout_minutes)
        dbs.cleanup_old_records(db_address, retention_days)
        time.sleep(poll_interval)
def process_notifications(db_address, redis_client, camera):
    """Drain pending notifications and finalize any stacks that are ready."""
    for moluid in drain_notifications(redis_client, camera):
        frames = dbs.get_stack_frames(db_address, moluid)
        if not frames:
            continue
        # Every frame of a moluid carries the same expected total.
        expected_total = frames[0].frmtotal
        if check_stack_complete(frames, expected_total):
            finalize_stack(db_address, moluid, status='complete')


def finalize_stack(db_address, moluid, status='complete'):
    """Mark stack finished and log the (currently mocked) downstream steps."""
    dbs.mark_stack_complete(db_address, moluid, status=status)
    # Downstream operations are placeholders for now.
    for step in ('stacking complete', 'JPEG generation', 'ingester upload'):
        logger.info(f'Mock {step} for {moluid}', extra_tags={'moluid': moluid})


def check_timeout(db_address, camera, timeout_minutes):
    """Finalize with status='timeout' every active stack older than the cutoff.

    NOTE(review): frames with a NULL dateobs never match the comparison and
    therefore never time out — confirm that is intended.
    """
    stale_before = datetime.datetime.utcnow() - datetime.timedelta(minutes=timeout_minutes)
    with dbs.get_session(db_address) as session:
        stale_rows = session.query(dbs.StackFrame.moluid).filter(
            dbs.StackFrame.camera == camera,
            dbs.StackFrame.status == 'active',
            dbs.StackFrame.dateobs < stale_before,
        ).distinct().all()
        for (moluid,) in stale_rows:
            finalize_stack(db_address, moluid, status='timeout')


def discover_cameras(db_address, site_id):
    """Return the camera codes of all instruments registered at *site_id*."""
    with dbs.get_session(db_address) as session:
        instruments = session.query(dbs.Instrument).filter(
            dbs.Instrument.site == site_id
        ).all()
        return [instrument.camera for instrument in instruments]
class StackingSupervisor:
    """Spawns and babysits one stacking worker process per camera."""

    def __init__(self, site_id, db_address, redis_url, timeout_minutes=20, retention_days=30):
        self.site_id = site_id
        self.db_address = db_address
        self.redis_url = redis_url
        self.timeout_minutes = timeout_minutes
        self.retention_days = retention_days
        self.workers = {}

    def _worker_args(self, camera):
        # Positional args forwarded verbatim to run_worker_loop.
        return (camera, self.db_address, self.redis_url, self.timeout_minutes, self.retention_days)

    def _spawn(self, camera):
        # One Process per camera so a crash in one loop is isolated.
        proc = multiprocessing.Process(
            target=run_worker_loop,
            args=self._worker_args(camera),
            name=f'stacking-worker-{camera}',
        )
        proc.start()
        self.workers[camera] = proc

    def start(self):
        """Discover cameras and spawn one worker process per camera."""
        for camera in discover_cameras(self.db_address, self.site_id):
            self._spawn(camera)
            logger.info(f'Started stacking worker for camera {camera}')

    def monitor(self, check_interval=10):
        """Poll worker health forever, restarting any worker that died."""
        while True:
            for camera, proc in list(self.workers.items()):
                if not proc.is_alive():
                    logger.warning(f'Worker for {camera} died, restarting')
                    self._spawn(camera)
            time.sleep(check_interval)

    def shutdown(self):
        """Terminate and join every worker, then forget them all."""
        for camera, proc in self.workers.items():
            proc.terminate()
            proc.join(timeout=10)
            logger.info(f'Stopped stacking worker for camera {camera}')
        self.workers.clear()


def run_supervisor():
    """Entry point: configure from the environment, start workers, monitor forever."""
    supervisor = StackingSupervisor(
        os.environ['SITE_ID'],
        os.environ['DB_ADDRESS'],
        os.environ.get('REDIS_URL', 'redis://redis:6379/0'),
        timeout_minutes=int(os.environ.get('STACK_TIMEOUT_MINUTES', '20')),
        retention_days=int(os.environ.get('STACK_RETENTION_DAYS', '30')),
    )

    def handle_signal(signum, frame):
        supervisor.shutdown()
        raise SystemExit(0)

    # Graceful teardown on both operator interrupt and orchestrator stop.
    for sig in (signal.SIGTERM, signal.SIGINT):
        signal.signal(sig, handle_signal)

    supervisor.start()
    supervisor.monitor()
def publish_to_queue(queue_name, body, broker_url='amqp://localhost:5672'):
    """Publish *body* as JSON to *queue_name* on the given RabbitMQ broker."""
    with Connection(broker_url) as conn:
        # Declare the queue first so publishing cannot race queue creation.
        Queue(queue_name, channel=conn.channel()).declare()
        with conn.Producer() as producer:
            producer.publish(body, routing_key=queue_name, serializer='json')
publish_to_queue, ) @@ -162,7 +167,6 @@ def test_06_queue_raw_frame(self, site_deployment, auth_token): @pytest.mark.e2e_site_reduction def test_07_reduction_completes(self, site_deployment): """Verify reduction completed by checking for processed output file.""" - from astropy.io import fits raw_dir = DATA_DIR / 'raw' assert raw_dir.exists(), f"Raw directory not found: {raw_dir}" @@ -205,8 +209,47 @@ def test_07_reduction_completes(self, site_deployment): if failures: pytest.fail(f"Reduction failed for {len(failures)}/{len(raw_files)} files:\n" + "\n".join(failures)) + @pytest.mark.e2e_site_reduction + def test_08_reduction_used_cached_calibrations(self, site_deployment): + """Verify reduced frames used calibrations that exist in the local cache and DB.""" + + output_dir = DATA_DIR / 'output' + reduced_files = list(output_dir.rglob('*-e91.fits.fz')) + assert reduced_files, f"No reduced files found under {output_dir}" + + cal_header_keys = {'L1IDBIAS': 'bias', 'L1IDDARK': 'dark', 'L1IDFLAT': 'flat'} + cached_files = {p.name for p in CACHE_DIR.rglob('*.fits.fz')} + errors = [] + + for reduced_path in reduced_files: + with fits.open(str(reduced_path)) as hdul: + ext = 'SCI' if 'SCI' in hdul else 0 + header = hdul[ext].header + + for key, cal_type in cal_header_keys.items(): + val = header.get(key, '') + if not val or val == 'N/A': + continue + basename = os.path.basename(val) + + if basename not in cached_files: + errors.append( + f"{reduced_path.name}: {cal_type} file '{basename}' not found in cache" + ) + + with dbs.get_session(LOCAL_DB_ADDRESS) as session: + cal = session.query(dbs.CalibrationImage).filter( + dbs.CalibrationImage.filename == basename + ).first() + if not cal or not cal.filepath: + errors.append( + f"{reduced_path.name}: {cal_type} file '{basename}' missing or NULL filepath in DB" + ) + + assert not errors, "Cached calibration verification failed:\n" + "\n".join(errors) + @pytest.mark.e2e_site_cache - def test_08_add_older_calibrations(self): + def 
test_09_add_older_calibrations(self): """Insert older calibrations to test cache updates.""" populate_publication.insert_additional_calibrations(PUBLICATION_DB_ADDRESS) @@ -222,7 +265,7 @@ def test_08_add_older_calibrations(self): ) @pytest.mark.e2e_site_cache - def test_09_older_calibrations_replicated(self): + def test_10_older_calibrations_replicated(self): """Verify new calibrations replicated (now 13 total in DB).""" def check(): with dbs.get_session(LOCAL_DB_ADDRESS) as session: @@ -236,10 +279,89 @@ def check(): "Expected 13 calibrations in local DB after replication" @pytest.mark.e2e_site_cache - def test_10_cache_updated(self): + def test_11_cache_updated(self): """Verify cache settled to exactly 7 files after older calibrations added. The download worker keeps only the top 2 per config, so the 6 older calibrations should not persist in the cache. """ _assert_cache_matches(PHASE1_EXPECTED_FILES, timeout=120) + + @pytest.mark.e2e_site_reduction + def test_12_subframe_stack_completes(self, site_deployment): + """Verify subframe stacking processes a frame end-to-end.""" + + raw_dir = DATA_DIR / 'raw' + src_path = raw_dir / RAW_FRAME_FILENAME + subframe_path = raw_dir / 'subframe_test.fits.fz' + + assert src_path.exists(), f"Raw frame not found: {src_path}" + shutil.copy2(str(src_path), str(subframe_path)) + + with fits.open(str(subframe_path), mode='update') as hdul: + hdul['SCI'].header['MOLUID'] = 'mol-e2e-test' + hdul['SCI'].header['MOLFRNUM'] = 1 + hdul['SCI'].header['FRMTOTAL'] = 1 + hdul['SCI'].header['STACK'] = 'T' + + body = { + 'fits_file': '/raw/subframe_test.fits.fz', + 'last_frame': True, + 'instrument_enqueue_timestamp': int(time.time() * 1000), + } + publish_to_queue('banzai_stack_queue', body) + + def check(): + with dbs.get_session(LOCAL_DB_ADDRESS) as session: + frames = session.query(dbs.StackFrame).filter( + dbs.StackFrame.moluid == 'mol-e2e-test' + ).all() + if frames and all(f.status == 'complete' for f in frames): + return [f.filepath 
for f in frames] + return None + + filepaths = poll_until(check, timeout=300) + assert filepaths, "Subframe stack did not complete within timeout" + + # Verify the reduced output file exists on disk. + # The container path (e.g. /reduced/lsc/sq34/.../file.fits.fz) maps to DATA_DIR/output/... + container_path = filepaths[0] + assert container_path, "StackFrame has no filepath after completion" + relative_path = container_path.removeprefix('/reduced/') + expected_path = DATA_DIR / 'output' / relative_path + + found = poll_until( + lambda p=expected_path: p.exists() and p.stat().st_size > 0, + timeout=60, interval=5 + ) + assert found, f"Reduced subframe output not found: {expected_path}" + + @pytest.mark.e2e_site_cache + def test_13_stack_timeout(self, site_deployment): + """Verify stacking supervisor times out incomplete stacks.""" + stale_dateobs = datetime.datetime.utcnow() - datetime.timedelta(minutes=25) + + for stack_num in [1, 2]: + dbs.insert_stack_frame( + LOCAL_DB_ADDRESS, + moluid='mol-e2e-timeout', + stack_num=stack_num, + frmtotal=3, + camera='sq34', + filepath='/tmp/fake.fits', + is_last=False, + dateobs=stale_dateobs, + ) + + def check(): + with dbs.get_session(LOCAL_DB_ADDRESS) as session: + frames = session.query(dbs.StackFrame).filter( + dbs.StackFrame.moluid == 'mol-e2e-timeout' + ).all() + if frames and all(f.status == 'timeout' for f in frames): + return frames + return None + + result = poll_until(check, timeout=60) + assert result, "Stacking supervisor did not timeout the stale stack" + assert len(result) == 2, f"Expected 2 timed-out frames, found {len(result)}" diff --git a/banzai/tests/test_download_worker.py b/banzai/tests/test_download_worker.py index e6e6f5e5..6d42bc9c 100644 --- a/banzai/tests/test_download_worker.py +++ b/banzai/tests/test_download_worker.py @@ -128,7 +128,7 @@ def _attrs_for_type(cal_type, **overrides): return attrs -def test_returns_top_2_per_config(db_address, tmp_path): +def test_returns_top_2_per_config(db_address): 
inst_id = _seed_db(db_address) bias_attrs = _attrs_for_type('BIAS', configuration_mode='default', binning='1x1') with dbs.get_session(db_address) as session: @@ -140,7 +140,7 @@ def test_returns_top_2_per_config(db_address, tmp_path): assert filenames == {'mid.fits', 'new.fits'} -def test_partitions_independently_by_config(db_address, tmp_path): +def test_partitions_independently_by_config(db_address): inst_id = _seed_db(db_address) with dbs.get_session(db_address) as session: for i, binning in enumerate(['1x1', '2x2']): @@ -154,7 +154,7 @@ def test_partitions_independently_by_config(db_address, tmp_path): 'bias_2x2_1.fits', 'bias_2x2_2.fits'} -def test_filters_by_instrument_type(db_address, tmp_path): +def test_filters_by_instrument_type(db_address): sinistro_id = _seed_db(db_address, camera='fa01', inst_type='1m0-SciCam-Sinistro') floyds_id = _seed_db(db_address, camera='en01', inst_type='2m0-FLOYDS-SciCam') bias_attrs = _attrs_for_type('BIAS', configuration_mode='default', binning='1x1') @@ -167,7 +167,7 @@ def test_filters_by_instrument_type(db_address, tmp_path): assert filenames == {'sinistro.fits'} -def test_wildcard_returns_all_instrument_types(db_address, tmp_path): +def test_wildcard_returns_all_instrument_types(db_address): sinistro_id = _seed_db(db_address, camera='fa01', inst_type='1m0-SciCam-Sinistro') floyds_id = _seed_db(db_address, camera='en01', inst_type='2m0-FLOYDS-SciCam') bias_attrs = _attrs_for_type('BIAS', configuration_mode='default', binning='1x1') @@ -179,7 +179,7 @@ def test_wildcard_returns_all_instrument_types(db_address, tmp_path): assert filenames == {'sinistro.fits', 'floyds.fits'} -def test_biases_ignore_filter(db_address, tmp_path): +def test_biases_ignore_filter(db_address): """Biases taken with different filters should be grouped together.""" inst_id = _seed_db(db_address) bias_attrs = _attrs_for_type('BIAS', configuration_mode='default', binning='1x1') @@ -192,7 +192,7 @@ def test_biases_ignore_filter(db_address, tmp_path): 
@pytest.fixture
def db_address(tmp_path):
    """Create a fresh SQLite DB per test with a site and two instruments."""
    address = f'sqlite:///{tmp_path}/test.db'
    dbs.create_db(address)
    with dbs.get_session(address) as session:
        session.add(dbs.Site(id='tst', timezone=0, latitude=0, longitude=0, elevation=0))
        for cam in ('cam1', 'cam2'):
            session.add(dbs.Instrument(site='tst', camera=cam, name=cam,
                                       type='1m0-SciCam-Sinistro', nx=4096, ny=4096))
    return address


@pytest.fixture
def mock_redis():
    """Return a MagicMock standing in for a Redis client."""
    client = MagicMock()
    client.lpush = MagicMock()
    client.lrange = MagicMock(return_value=[])
    client.delete = MagicMock()
    return client
class TestStatusTransitions:
    """Frames move from 'active' to a terminal status with a timestamp."""

    def test_status_active_to_complete(self, db_address):
        when = datetime.datetime(2024, 6, 15, 12, 0, 0)
        for i in range(3):
            insert_stack_frame(
                db_address, moluid='mol-comp', stack_num=i + 1, frmtotal=3,
                camera='cam1', filepath=f'/data/comp{i}.fits', is_last=(i == 2), dateobs=when,
            )
        mark_stack_complete(db_address, 'mol-comp', 'complete')
        for frame in get_stack_frames(db_address, 'mol-comp'):
            assert frame.status == 'complete'
            assert frame.completed_at is not None

    def test_status_active_to_timeout(self, db_address):
        when = datetime.datetime(2024, 6, 15, 12, 0, 0)
        for i in range(2):
            insert_stack_frame(
                db_address, moluid='mol-to', stack_num=i + 1, frmtotal=5,
                camera='cam1', filepath=f'/data/to{i}.fits', is_last=False, dateobs=when,
            )
        mark_stack_complete(db_address, 'mol-to', 'timeout')
        for frame in get_stack_frames(db_address, 'mol-to'):
            assert frame.status == 'timeout'
            assert frame.completed_at is not None
db_address, moluid='mol-stale', stack_num=i + 1, frmtotal=5, + camera='cam1', filepath=f'/data/stale{i}.fits', is_last=False, dateobs=old_dateobs, + ) + check_timeout(db_address, 'cam1', timeout_minutes=60) + frames = get_stack_frames(db_address, 'mol-stale') + for f in frames: + assert f.status == 'timeout' + + +# --------------------------------------------------------------------------- +# Redis notifications +# --------------------------------------------------------------------------- + +class TestRedisNotifications: + + def test_push_notification(self, mock_redis): + push_notification(mock_redis, 'cam1', 'mol-abc') + mock_redis.lpush.assert_called_once_with(f'{REDIS_KEY_PREFIX}cam1', 'mol-abc') + + def test_drain_for_camera(self, mock_redis): + mock_redis.lrange.return_value = [b'mol-a', b'mol-a', b'mol-b'] + result = drain_notifications(mock_redis, 'cam1') + assert result == {'mol-a', 'mol-b'} + mock_redis.rename.assert_called_once_with( + f'{REDIS_KEY_PREFIX}cam1', f'{REDIS_KEY_PREFIX}cam1:draining') + mock_redis.lrange.assert_called_once_with(f'{REDIS_KEY_PREFIX}cam1:draining', 0, -1) + mock_redis.delete.assert_called_once_with(f'{REDIS_KEY_PREFIX}cam1:draining') + + +# --------------------------------------------------------------------------- +# Multiple concurrent stacks +# --------------------------------------------------------------------------- + +class TestConcurrentStacks: + + def test_concurrent_stacks_same_camera(self, db_address): + dateobs = datetime.datetime(2024, 6, 15, 12, 0, 0) + for i in range(3): + insert_stack_frame( + db_address, moluid='mol-A', stack_num=i + 1, frmtotal=3, + camera='cam1', filepath=f'/data/a{i}.fits', is_last=(i == 2), dateobs=dateobs, + ) + for i in range(2): + insert_stack_frame( + db_address, moluid='mol-B', stack_num=i + 1, frmtotal=5, + camera='cam1', filepath=f'/data/b{i}.fits', is_last=False, dateobs=dateobs, + ) + + frames_a = get_stack_frames(db_address, 'mol-A') + frames_b = get_stack_frames(db_address, 
'mol-B') + assert len(frames_a) == 3 + assert len(frames_b) == 2 + assert check_stack_complete(frames_a, frmtotal=3) is True + assert check_stack_complete(frames_b, frmtotal=5) is False + + +# --------------------------------------------------------------------------- +# check_stack_complete +# --------------------------------------------------------------------------- + +class TestCheckStackComplete: + + @staticmethod + def _frame(filepath='/data/f.fits', is_last=False): + f = MagicMock() + f.filepath = filepath + f.is_last = is_last + return f + + def test_all_frames_arrived_and_reduced(self): + frames = [self._frame() for _ in range(3)] + assert check_stack_complete(frames, frmtotal=3) is True + + def test_partial_without_is_last(self): + frames = [self._frame() for _ in range(3)] + assert check_stack_complete(frames, frmtotal=5) is False + + def test_partial_with_is_last(self): + frames = [self._frame() for _ in range(2)] + [self._frame(is_last=True)] + assert check_stack_complete(frames, frmtotal=5) is True + + def test_is_last_waits_for_unreduced_frames(self): + frames = [self._frame(), self._frame(filepath=None, is_last=True)] + assert check_stack_complete(frames, frmtotal=5) is False + + def test_empty_frames(self): + assert check_stack_complete([], frmtotal=5) is False + + +# --------------------------------------------------------------------------- +# Retention / cleanup +# --------------------------------------------------------------------------- + +class TestRetention: + + def test_cleanup_old_records(self, db_address): + dateobs = datetime.datetime(2024, 6, 15, 12, 0, 0) + for i in range(3): + insert_stack_frame( + db_address, moluid='mol-old', stack_num=i + 1, frmtotal=3, + camera='cam1', filepath=f'/data/old{i}.fits', is_last=(i == 2), dateobs=dateobs, + ) + mark_stack_complete(db_address, 'mol-old', 'complete') + + with dbs.get_session(db_address) as session: + session.execute( + text("UPDATE stack_frames SET completed_at = :old_date WHERE moluid 
= :mol"), + {'old_date': datetime.datetime.utcnow() - datetime.timedelta(days=30), 'mol': 'mol-old'}, + ) + + cleanup_old_records(db_address, retention_days=7) + frames = get_stack_frames(db_address, 'mol-old') + assert len(frames) == 0 + + def test_cleanup_preserves_recent(self, db_address): + dateobs = datetime.datetime(2024, 6, 15, 12, 0, 0) + for i in range(3): + insert_stack_frame( + db_address, moluid='mol-recent', stack_num=i + 1, frmtotal=3, + camera='cam1', filepath=f'/data/recent{i}.fits', is_last=(i == 2), dateobs=dateobs, + ) + mark_stack_complete(db_address, 'mol-recent', 'complete') + cleanup_old_records(db_address, retention_days=7) + frames = get_stack_frames(db_address, 'mol-recent') + assert len(frames) == 3 + + +# --------------------------------------------------------------------------- +# SubframeListener on_message +# --------------------------------------------------------------------------- + +class TestSubframeListenerOnMessage: + """on_message dispatches to Celery; no FITS I/O or DB work here.""" + + @patch('banzai.main.process_subframe') + def test_on_message_dispatches_valid(self, mock_task): + ctx = MagicMock(SUBFRAME_TASK_QUEUE_NAME='subframe_tasks') + listener = SubframeListener(ctx) + + body = { + 'fits_file': '/path/to/frame.fits', + 'last_frame': False, + 'instrument_enqueue_timestamp': 1771023918500, + } + mock_message = MagicMock() + + listener.on_message(body, mock_message) + + mock_task.apply_async.assert_called_once_with( + args=(body, vars(ctx)), + queue='subframe_tasks', + ) + mock_message.ack.assert_called_once() + + @patch('banzai.main.process_subframe') + def test_on_message_invalid_no_dispatch(self, mock_task): + listener = SubframeListener(MagicMock()) + + body = { + 'last_frame': True, + # missing fits_file and instrument_enqueue_timestamp + } + mock_message = MagicMock() + + listener.on_message(body, mock_message) + + mock_task.apply_async.assert_not_called() + mock_message.ack.assert_called_once() + + +# 
--------------------------------------------------------------------------- +# process_subframe Celery task +# --------------------------------------------------------------------------- + +class TestProcessSubframe: + """Test the Celery task that does the actual subframe processing.""" + + @staticmethod + def _make_fits_header(**overrides): + """Build a FITS header with the standard stack keys.""" + h = Header() + h['INSTRUME'] = 'cam1' + h['DATE-OBS'] = '2024-01-01T00:00:00' + h['STACK'] = 'T' + h['MOLFRNUM'] = 1 + h['FRMTOTAL'] = 5 + h['MOLUID'] = 'mol-xyz' + for k, v in overrides.items(): + h[k] = v + return h + + @staticmethod + def _make_mock_image(output_dir='/data/processed', output_filename='frame-e09.fits'): + """Build a mock image returned by run_pipeline_stages.""" + img = MagicMock() + img.get_output_directory.return_value = output_dir + img.get_output_filename.return_value = output_filename + return img + + @pytest.mark.parametrize('last_frame_val, expected_is_last', [ + (False, False), + (True, True), + ]) + @patch('banzai.scheduling.stage_utils.run_pipeline_stages') + def test_process_subframe(self, mock_run_stages, last_frame_val, expected_is_last, db_address, mock_redis): + + mock_image = self._make_mock_image() + mock_run_stages.return_value = [mock_image] + + header = self._make_fits_header() + body = { + 'fits_file': '/path/to/frame.fits', + 'last_frame': last_frame_val, + 'instrument_enqueue_timestamp': 1771023918500, + } + runtime_context = {'db_address': db_address, 'REDIS_URL': 'redis://localhost:6379/0'} + + with patch('banzai.scheduling.fits_utils.get_primary_header', return_value=header), \ + patch('banzai.scheduling.redis.Redis.from_url', return_value=mock_redis): + process_subframe(body, runtime_context) + + mock_run_stages.assert_called_once() + + frames = get_stack_frames(db_address, 'mol-xyz') + assert len(frames) == 1 + assert frames[0].stack_num == 1 + assert frames[0].frmtotal == 5 + assert frames[0].camera == 'cam1' + assert 
frames[0].is_last is expected_is_last + assert frames[0].filepath == '/data/processed/frame-e09.fits' + mock_redis.lpush.assert_called_once() + + +# --------------------------------------------------------------------------- +# Supervisor +# --------------------------------------------------------------------------- + +class TestSupervisor: + + def test_discover_cameras(self, db_address): + cameras = discover_cameras(db_address, 'tst') + assert 'cam1' in cameras + assert 'cam2' in cameras + assert len(cameras) == 2 + + @patch('banzai.stacking.discover_cameras', return_value=['cam1', 'cam2', 'cam3']) + @patch('banzai.stacking.multiprocessing.Process') + def test_supervisor_spawns_per_camera(self, mock_process_cls, mock_discover): + supervisor = StackingSupervisor( + site_id='tst', + db_address='sqlite:///fake.db', + redis_url='redis://localhost:6379', + ) + supervisor.start() + assert mock_process_cls.call_count == 3 + assert mock_process_cls.return_value.start.call_count == 3 diff --git a/banzai/utils/stage_utils.py b/banzai/utils/stage_utils.py index ccd17315..74ea4e0a 100644 --- a/banzai/utils/stage_utils.py +++ b/banzai/utils/stage_utils.py @@ -63,3 +63,4 @@ def run_pipeline_stages(image_paths: list, runtime_context: Context, calibration for image in images: image.write(runtime_context) + return images diff --git a/docker-compose-site.yml b/docker-compose-site.yml index 2611a1ed..5bbdc736 100644 --- a/docker-compose-site.yml +++ b/docker-compose-site.yml @@ -159,3 +159,69 @@ services: - DOWNLOAD_WORKER_POLL_INTERVAL=${DOWNLOAD_WORKER_POLL_INTERVAL:-10} command: ["banzai_download_worker"] restart: unless-stopped + + banzai-subframe-listener: + build: . 
+ container_name: banzai-subframe-listener + depends_on: + banzai-cache-init: + condition: service_completed_successfully + rabbitmq: + condition: service_started + environment: + - DB_ADDRESS=${DB_ADDRESS} + - SUBFRAME_TASK_QUEUE_NAME=subframe_reduction_task_queue + - TASK_HOST=redis://redis:6379/0 + command: ["banzai_subframe_worker", "--broker-url=amqp://rabbitmq:5672", + "--db-address=${DB_ADDRESS}", "--rlevel=9", + "--processed-path=/reduced"] + restart: unless-stopped + + banzai-subframe-worker: + build: . + container_name: banzai-subframe-worker + depends_on: + banzai-cache-init: + condition: service_completed_successfully + redis: + condition: service_started + rabbitmq: + condition: service_started + volumes: + - ${HOST_RAW_DIR}:/raw + - ${HOST_CALS_DIR}:/calibrations + - ${HOST_REDUCED_DIR}:/reduced + environment: + - DB_ADDRESS=${DB_ADDRESS} + - CAL_DB_ADDRESS=${CAL_DB_ADDRESS:-${DB_ADDRESS}} + - TASK_HOST=redis://redis:6379/0 + - SUBFRAME_TASK_QUEUE_NAME=subframe_reduction_task_queue + - REDIS_URL=redis://redis:6379/0 + - BANZAI_WORKER_LOGLEVEL=${BANZAI_WORKER_LOGLEVEL} + - OMP_NUM_THREADS=${OMP_NUM_THREADS} + - API_ROOT=${API_ROOT} + - AUTH_TOKEN=${AUTH_TOKEN} + - DATA_ROOT=/reduced + command: ["celery", "-A", "banzai.scheduling", "worker", + "--hostname", "banzai-subframe-worker-local", + "-l", "${BANZAI_WORKER_LOGLEVEL}", + "-c", "4", + "-Q", "subframe_reduction_task_queue"] + restart: unless-stopped + + banzai-stacking-supervisor: + build: . 
+ container_name: banzai-stacking-supervisor + depends_on: + banzai-cache-init: + condition: service_completed_successfully + redis: + condition: service_started + environment: + - SITE_ID=${SITE_ID} + - DB_ADDRESS=${DB_ADDRESS} + - REDIS_URL=redis://redis:6379/0 + - STACK_TIMEOUT_MINUTES=${STACK_TIMEOUT_MINUTES:-20} + - STACK_RETENTION_DAYS=${STACK_RETENTION_DAYS:-30} + command: ["banzai_stacking_supervisor"] + restart: unless-stopped diff --git a/pyproject.toml b/pyproject.toml index b9eba032..947d85e4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -144,6 +144,8 @@ zip-safe = false banzai_create_local_db = "banzai.main:create_local_db" banzai_download_worker = "banzai.cache.download_worker:run_download_worker_daemon" banzai_cache_init = "banzai.cache.init:run_initialization" + banzai_stacking_supervisor = "banzai.stacking:run_supervisor" + banzai_subframe_worker = "banzai.main:run_subframe_worker" [tool.coverage.run] source = ["banzai"] diff --git a/pytest.ini b/pytest.ini index 686a0197..eb8f72bf 100644 --- a/pytest.ini +++ b/pytest.ini @@ -1,6 +1,6 @@ [pytest] minversion = 3.5 -norecursedirs = build docs/_build .direnv site_e2e +norecursedirs = build docs/_build .direnv site_e2e smart_stacking_integration doctest_plus = enabled addopts = -p no:warnings log_cli = True @@ -63,3 +63,5 @@ markers = stacking stats thousands_qc + smart_stacking : Smart stacking unit tests + integration_smart_stacking : Smart stacking integration tests diff --git a/site-banzai-env.default b/site-banzai-env.default index e2e8ae58..df7ee367 100644 --- a/site-banzai-env.default +++ b/site-banzai-env.default @@ -15,13 +15,19 @@ BANZAI_WORKER_LOGLEVEL=debug OMP_NUM_THREADS=2 DOWNLOAD_WORKER_POLL_INTERVAL=10 +# Smart stacking +STACK_TIMEOUT_MINUTES=20 +STACK_RETENTION_DAYS=30 +STACK_QUEUE_NAME=banzai_stack_queue + # Database DB_ADDRESS=postgresql://banzai@postgresql:5432/banzai_local PUBLICATION_NAME=banzai_calibrations # CAL_DB_ADDRESS= -# Replication -AWS_DB_ADDRESS= # 
postgresql://user:pass@aws-host.rds.amazonaws.com:5432/calibrations +# Replication (cache-init handles these) +# AWS_DB_ADDRESS= # postgresql://user:pass@aws-host.rds.amazonaws.com:5432/calibrations +# REPLICATION_SLOT_NAME= # banzai_${SITE_ID}_slot # API API_ROOT=https://archive-api.lco.global/