Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion banzai/cache/replication.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
from sqlalchemy import text, create_engine
from banzai import dbs
from banzai.logs import get_logger

# PostgreSQL logical replication is managed via server-level DDL commands
Expand Down
74 changes: 73 additions & 1 deletion banzai/dbs.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,9 @@
from dateutil.parser import parse
import requests
from sqlalchemy import create_engine, pool, func, make_url
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import sessionmaker
from sqlalchemy import Column, Integer, String, DateTime, ForeignKey, Boolean, CHAR, JSON, UniqueConstraint, Float, Text
from sqlalchemy import Column, Integer, String, DateTime, ForeignKey, Boolean, CHAR, JSON, UniqueConstraint, Float
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.sql.expression import true
from contextlib import contextmanager
Expand Down Expand Up @@ -116,6 +117,24 @@ class ProcessedImage(Base):
tries = Column(Integer, default=0)


class StackFrame(Base):
    """One subframe belonging to a stack, keyed uniquely by (moluid, stack_num).

    A row is inserted when a subframe arrives (see insert_stack_frame), its
    ``filepath`` is filled in after reduction, and ``status``/``completed_at``
    are stamped when the whole stack finishes (see mark_stack_complete).
    """
    __tablename__ = 'stack_frames'
    id = Column(Integer, primary_key=True, autoincrement=True)
    # Identifier shared by all frames of one stack; presumably the MOLUID
    # FITS header value -- TODO confirm against the upstream scheduler.
    moluid = Column(String(100), nullable=False, index=True)
    # Index of this frame within the stack (from the MOLFRNUM header).
    stack_num = Column(Integer, nullable=False)
    # Total number of frames expected in the stack (from the FRMTOTAL header).
    frmtotal = Column(Integer, nullable=False)
    # Instrument name (from the INSTRUME header).
    camera = Column(String(50), nullable=False, index=True)
    # Path to the reduced product; NULL until reduction completes.
    filepath = Column(String(255), nullable=True)
    # Whether the producer flagged this as the final frame of the stack.
    is_last = Column(Boolean, default=False)
    # Lifecycle flag: 'active' until mark_stack_complete sets it
    # (e.g. 'complete' or 'timeout').
    status = Column(String(20), default='active', nullable=False)
    dateobs = Column(DateTime, nullable=True)
    created_at = Column(DateTime, default=datetime.datetime.utcnow)
    # Set by mark_stack_complete; used by cleanup_old_records for retention.
    completed_at = Column(DateTime, nullable=True)
    __table_args__ = (
        UniqueConstraint('moluid', 'stack_num', name='uq_stack_moluid_num'),
    )


def parse_configdb(configdb_address):
"""
Parse the contents of the configdb.
Expand Down Expand Up @@ -580,3 +599,56 @@ def replicate_instrument(instrument_record, db_address):

add_or_update_record(db_session, Instrument, equivalence_criteria, record_attributes)
db_session.commit()


def insert_stack_frame(db_address, moluid, stack_num, frmtotal, camera, filepath, is_last, dateobs):
    """Persist a new StackFrame row.

    A duplicate (moluid, stack_num) pair violates the table's unique
    constraint and is deliberately treated as a no-op, so re-delivered
    messages do not create duplicate records.
    """
    record = StackFrame(moluid=moluid,
                        stack_num=stack_num,
                        frmtotal=frmtotal,
                        camera=camera,
                        filepath=filepath,
                        is_last=is_last,
                        dateobs=dateobs)
    try:
        with get_session(db_address) as session:
            session.add(record)
    except IntegrityError:
        # Row already exists for this (moluid, stack_num) -- nothing to do.
        pass


def get_stack_frames(db_address, moluid):
    """Return every StackFrame row recorded for the given moluid."""
    with get_session(db_address) as session:
        frames_query = session.query(StackFrame).filter(StackFrame.moluid == moluid)
        return frames_query.all()


def mark_stack_complete(db_address, moluid, status='complete'):
    """Stamp every frame of a moluid with the given status and a completion time.

    The default status is 'complete'; callers may pass e.g. 'timeout' instead.
    """
    finished_at = datetime.datetime.utcnow()
    new_values = {'status': status, 'completed_at': finished_at}
    with get_session(db_address) as session:
        frames = session.query(StackFrame).filter(StackFrame.moluid == moluid)
        frames.update(new_values)


def update_stack_frame_filepath(db_address, moluid, stack_num, filepath):
    """Record the reduced-product path on one (moluid, stack_num) frame."""
    with get_session(db_address) as session:
        frame = session.query(StackFrame).filter(
            StackFrame.moluid == moluid,
            StackFrame.stack_num == stack_num,
        )
        frame.update({'filepath': filepath})


def cleanup_old_records(db_address, retention_days):
    """Purge finished stack frame rows older than the retention window.

    Only rows that have left the 'active' state are eligible, and only
    when their completed_at stamp predates now - retention_days.
    """
    oldest_allowed = datetime.datetime.utcnow() - datetime.timedelta(days=retention_days)
    with get_session(db_address) as session:
        stale_frames = session.query(StackFrame).filter(
            StackFrame.status != 'active',
            StackFrame.completed_at < oldest_allowed,
        )
        stale_frames.delete()
53 changes: 52 additions & 1 deletion banzai/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,8 @@
from banzai import settings, dbs, logs, calibrations
from banzai.context import Context
from banzai.utils import date_utils, stage_utils, import_utils, image_utils, fits_utils, file_utils
from banzai.scheduling import process_image, app, schedule_calibration_stacking
from banzai.scheduling import process_image, process_subframe, app, schedule_calibration_stacking
from banzai.stacking import validate_message
from banzai.data import DataProduct
from celery.schedules import crontab
import celery
Expand Down Expand Up @@ -228,6 +229,56 @@ def start_listener(runtime_context):
logger.info('Shutting down pipeline listener.')


class SubframeListener(ConsumerMixin):
    """Consumes subframe messages from the stack queue and dispatches them to Celery."""

    def __init__(self, runtime_context):
        self.runtime_context = runtime_context

    def on_connection_error(self, exc, interval):
        # Re-clone the broker connection and retry before kombu gives up.
        logger.error('{0}. Retrying connection in {1} seconds...'.format(exc, interval))
        self.connection = self.connection.clone()
        self.connection.ensure_connection(max_retries=10)

    def get_consumers(self, Consumer, channel):
        """Bind to banzai_stack_queue."""
        stack_queue = Queue(self.runtime_context.STACK_QUEUE_NAME)
        subframe_consumer = Consumer(queues=[stack_queue], callbacks=[self.on_message])
        # Hold at most one unacknowledged message so work spreads across listeners.
        subframe_consumer.qos(prefetch_count=1)
        return [subframe_consumer]

    def on_message(self, body, message):
        """Validate and dispatch to Celery for processing."""
        if validate_message(body):
            process_subframe.apply_async(
                args=(body, vars(self.runtime_context)),
                queue=self.runtime_context.SUBFRAME_TASK_QUEUE_NAME,
            )
        else:
            logger.error('Invalid message received, missing required fields')
        # Ack in both branches: invalid messages are dropped, valid ones
        # are already handed off to the task queue.
        message.ack()


def run_subframe_worker():
    """Entry point for the subframe listener.

    Parses runtime settings, opens a broker connection, and runs a
    SubframeListener until interrupted. If the connection drops, it is
    re-cloned and re-established (up to 10 retries) before giving up.
    """
    runtime_context = parse_args(settings)

    logging.getLogger('amqp').setLevel(logging.WARNING)
    logger.info('Starting subframe listener')

    listener = SubframeListener(runtime_context)

    with Connection(runtime_context.broker_url) as connection:
        listener.connection = connection.clone()
        try:
            listener.run()
        except listener.connection.connection_errors:
            listener.connection = connection.clone()
            # Bug fix: ensure_connection is a method of the kombu Connection,
            # not of the listener itself -- the previous
            # `listener.ensure_connection(...)` would raise AttributeError.
            listener.connection.ensure_connection(max_retries=10)
        except KeyboardInterrupt:
            logger.info('Shutting down subframe listener.')


def mark_frame(mark_as):
parser = argparse.ArgumentParser(description="Set the is_bad flag to mark the frame as {mark_as}"
"for a calibration frame in the database ".format(mark_as=mark_as))
Expand Down
60 changes: 58 additions & 2 deletions banzai/scheduling.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,19 @@
from datetime import datetime, timedelta, timezone
from dateutil.parser import parse

import redis
from celery import Celery
from kombu import Queue
from celery.exceptions import Retry
from banzai import dbs, calibrations, logs
from banzai.utils import date_utils, realtime_utils, stage_utils
from banzai.utils import date_utils, fits_utils, realtime_utils, stage_utils
from banzai.metrics import add_telemetry_span_attribute, add_telemetry_span_event
from celery.signals import worker_process_init
from banzai.context import Context
from banzai.utils.observation_utils import filter_calibration_blocks_for_type, get_calibration_blocks_for_time_range
from banzai.utils.date_utils import get_stacking_date_range
from banzai.utils.date_utils import get_stacking_date_range, parse_date_obs
from banzai.dbs import insert_stack_frame, update_stack_frame_filepath
from banzai.stacking import push_notification
try:
from opentelemetry.instrumentation.celery import CeleryInstrumentor
OPENTELEMETRY_AVAILABLE = True
Expand Down Expand Up @@ -247,3 +250,56 @@ def process_image(self, file_info: dict, runtime_context: dict):
except Exception:
logger.error("Exception processing frame: {error}".format(error=logs.format_exception()),
extra_tags={'file_info': file_info})


@app.task(name='celery.process_subframe', bind=True, reject_on_worker_lost=True, max_retries=5)
def process_subframe(self, body: dict, runtime_context: dict):
    """Reduce one subframe: register it in the DB, run the pipeline, record
    the reduced product's path, and wake the stacking worker via redis.

    Errors are logged rather than raised so a bad message cannot wedge the
    worker (Celery Retry is re-raised so retries still function).
    """
    try:
        runtime_context = Context(runtime_context)
        subframe_path = body['fits_file']
        header = fits_utils.get_primary_header(subframe_path)

        camera = header.get('INSTRUME', '').strip()
        raw_dateobs = header.get('DATE-OBS', '')
        dateobs = None
        if raw_dateobs:
            dateobs = parse_date_obs(raw_dateobs)

        moluid = header['MOLUID']
        stack_num = header['MOLFRNUM']

        # Register the frame before reducing so the stacking worker can
        # already see it; filepath stays NULL until reduction succeeds.
        insert_stack_frame(runtime_context.db_address,
                           moluid=moluid,
                           stack_num=stack_num,
                           frmtotal=header['FRMTOTAL'],
                           camera=camera,
                           filepath=None,
                           is_last=body.get('last_frame', False),
                           dateobs=dateobs)

        # Run the reduction pipeline on the raw subframe.
        images = stage_utils.run_pipeline_stages([{'path': subframe_path}], runtime_context)

        # Point the DB record at the reduced product, if one was produced.
        if images:
            reduced_image = images[0]
            output_path = os.path.join(reduced_image.get_output_directory(runtime_context),
                                       reduced_image.get_output_filename(runtime_context))
            update_stack_frame_filepath(runtime_context.db_address, moluid, stack_num, output_path)

        # Wake the stack worker; skipped entirely when no redis URL is configured.
        notification_url = getattr(runtime_context, 'REDIS_URL', None)
        if notification_url:
            push_notification(redis.Redis.from_url(notification_url), camera, moluid)

    except Retry:
        raise
    except Exception:
        logger.error("Error processing subframe: {error}".format(error=logs.format_exception()))
4 changes: 4 additions & 0 deletions banzai/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,3 +163,7 @@
LARGE_WORKER_QUEUE = os.getenv('CELERY_LARGE_TASK_QUEUE_NAME', 'celery_large')

REFERENCE_CATALOG_URL = os.getenv('REFERENCE_CATALOG_URL', 'http://phot-catalog.lco.gtn/')

# Celery queue that process_subframe tasks are dispatched to by the listener.
SUBFRAME_TASK_QUEUE_NAME = os.getenv('SUBFRAME_TASK_QUEUE_NAME', 'subframe_tasks')
# AMQP queue the SubframeListener consumes incoming stack messages from.
STACK_QUEUE_NAME = os.getenv('STACK_QUEUE_NAME', 'banzai_stack_queue')
# Redis instance used to notify the stacking worker of newly reduced subframes.
REDIS_URL = os.getenv('REDIS_URL', 'redis://redis:6379/0')
Loading
Loading