Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 73 additions & 0 deletions banzai/dbs.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from dateutil.parser import parse
import requests
from sqlalchemy import create_engine, pool, func, make_url
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import sessionmaker
from sqlalchemy import Column, Integer, String, DateTime, ForeignKey, Boolean, CHAR, JSON, UniqueConstraint, Float, Text
from sqlalchemy.ext.declarative import declarative_base
Expand Down Expand Up @@ -116,6 +117,25 @@ class ProcessedImage(Base):
tries = Column(Integer, default=0)



class StackFrame(Base):
    """One raw sub-frame belonging to a smart-stacking group.

    Frames are grouped by ``moluid``; a stack is assembled once all
    ``frmtotal`` frames have arrived (or a frame flagged ``is_last`` is
    seen). Rows are uniquely keyed on (moluid, stack_num) so duplicate
    notifications for the same frame cannot create duplicate records.
    """
    __tablename__ = 'stack_frames'
    id = Column(Integer, primary_key=True, autoincrement=True)
    # Identifier of the stacking group this frame belongs to.
    moluid = Column(String(100), nullable=False, index=True)
    # 1-based(?) position of this frame within the stack — TODO confirm origin.
    stack_num = Column(Integer, nullable=False)
    # Total number of frames expected for this stack.
    frmtotal = Column(Integer, nullable=False)
    camera = Column(String(50), nullable=False, index=True)
    # Path of the reduced product; NULL until reduction finishes.
    filepath = Column(String(255), nullable=True)
    # Set when the instrument signals this is the final frame of the stack.
    is_last = Column(Boolean, default=False)
    # Lifecycle: 'active' while accumulating, then 'complete' or 'timeout'.
    status = Column(String(20), default='active', nullable=False)
    dateobs = Column(DateTime, nullable=True)
    created_at = Column(DateTime, default=datetime.datetime.utcnow)
    # Stamped when the stack is finalized (see mark_stack_complete).
    completed_at = Column(DateTime, nullable=True)
    __table_args__ = (
        UniqueConstraint('moluid', 'stack_num', name='uq_stack_moluid_num'),
    )


def parse_configdb(configdb_address):
"""
Parse the contents of the configdb.
Expand Down Expand Up @@ -580,3 +600,56 @@ def replicate_instrument(instrument_record, db_address):

add_or_update_record(db_session, Instrument, equivalence_criteria, record_attributes)
db_session.commit()


def insert_stack_frame(db_address, moluid, stack_num, frmtotal, camera, filepath, is_last, dateobs):
    """Insert a stack frame record into the database.

    A duplicate (moluid, stack_num) pair violates the table's unique
    constraint and is treated as a no-op, making redundant notifications
    for the same frame safe to replay.

    :param db_address: Database connection string.
    :param moluid: Identifier of the stacking group.
    :param stack_num: Position of this frame within the stack.
    :param frmtotal: Total number of frames expected for the stack.
    :param camera: Camera code that produced the frame.
    :param filepath: Reduced product path, or None if not yet reduced.
    :param is_last: True when the instrument flags this as the final frame.
    :param dateobs: Observation timestamp of the frame.
    """
    try:
        with get_session(db_address) as session:
            session.add(StackFrame(
                moluid=moluid,
                stack_num=stack_num,
                frmtotal=frmtotal,
                camera=camera,
                filepath=filepath,
                is_last=is_last,
                dateobs=dateobs,
            ))
            # Commit explicitly, matching this module's convention (see
            # replicate_instrument): without it the pending INSERT may never
            # be flushed, and the duplicate-key IntegrityError would not be
            # raised inside this try block.
            session.commit()
    except IntegrityError:
        # (moluid, stack_num) already recorded — intentional no-op.
        pass


def get_stack_frames(db_address, moluid):
    """Return every StackFrame row recorded for *moluid*."""
    with get_session(db_address) as session:
        frames = session.query(StackFrame).filter_by(moluid=moluid).all()
    return frames


def mark_stack_complete(db_address, moluid, status='complete'):
    """Mark all frames for *moluid* with *status* ('complete' or 'timeout')
    and stamp their completed_at with the current UTC time.
    """
    now = datetime.datetime.utcnow()
    with get_session(db_address) as session:
        session.query(StackFrame).filter(
            StackFrame.moluid == moluid
        ).update({'status': status, 'completed_at': now})
        # Explicit commit, matching the convention used elsewhere in this
        # module; otherwise the UPDATE could be discarded when the session
        # closes without committing.
        session.commit()


def update_stack_frame_filepath(db_address, moluid, stack_num, filepath):
    """Set the reduced filepath on an existing stack frame record.

    A no-op if no row matches (moluid, stack_num).
    """
    with get_session(db_address) as session:
        session.query(StackFrame).filter(
            StackFrame.moluid == moluid,
            StackFrame.stack_num == stack_num,
        ).update({'filepath': filepath})
        # Explicit commit, matching the convention used elsewhere in this
        # module; otherwise the UPDATE could be discarded when the session
        # closes without committing.
        session.commit()


def cleanup_old_records(db_address, retention_days):
    """Delete finalized stack frame records older than *retention_days*.

    Only non-'active' rows are considered; rows whose completed_at is NULL
    never match the cutoff comparison (SQL NULL semantics) and are kept.
    """
    cutoff = datetime.datetime.utcnow() - datetime.timedelta(days=retention_days)
    with get_session(db_address) as session:
        session.query(StackFrame).filter(
            StackFrame.status != 'active',
            StackFrame.completed_at < cutoff,
        ).delete()
        # Explicit commit, matching the convention used elsewhere in this
        # module; otherwise the DELETE could be discarded when the session
        # closes without committing.
        session.commit()
4 changes: 4 additions & 0 deletions banzai/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,3 +163,7 @@
LARGE_WORKER_QUEUE = os.getenv('CELERY_LARGE_TASK_QUEUE_NAME', 'celery_large')

REFERENCE_CATALOG_URL = os.getenv('REFERENCE_CATALOG_URL', 'http://phot-catalog.lco.gtn/')

# Queue receiving individual sub-frame reduction tasks.
SUBFRAME_TASK_QUEUE_NAME = os.getenv('SUBFRAME_TASK_QUEUE_NAME', 'subframe_tasks')
# Queue used for stack-related tasks.
STACK_QUEUE_NAME = os.getenv('STACK_QUEUE_NAME', 'banzai_stack_queue')
# Redis connection URL (default targets the in-cluster 'redis' service);
# presumably backs the stacking notification lists — confirm against banzai.stacking.
REDIS_URL = os.getenv('REDIS_URL', 'redis://redis:6379/0')
195 changes: 195 additions & 0 deletions banzai/stacking.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,195 @@
"""Smart stacking: worker, supervisor, and helper functions."""
import datetime
import multiprocessing
import os
import signal
import time

import redis as redis_lib

from banzai import dbs
from banzai.logs import get_logger

logger = get_logger()


# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------

REQUIRED_MESSAGE_FIELDS = ('fits_file', 'last_frame', 'instrument_enqueue_timestamp')


def validate_message(body):
    """Return True when *body* carries every required queue-message field."""
    return set(REQUIRED_MESSAGE_FIELDS).issubset(body)


def check_stack_complete(frames, frmtotal):
    """Return True if the stack is ready to finalize.

    Ready means every received frame has been reduced (non-None filepath),
    and either the expected frame count has arrived or some frame carried
    the instrument's is_last flag.
    """
    # Any unreduced frame blocks finalization outright.
    for frame in frames:
        if frame.filepath is None:
            return False
    received_all = len(frames) == frmtotal
    saw_last_flag = any(frame.is_last for frame in frames)
    return received_all or saw_last_flag


# ---------------------------------------------------------------------------
# Notifications
# ---------------------------------------------------------------------------

REDIS_KEY_PREFIX = 'stack:notify:'


def push_notification(redis_client, camera, moluid):
    """Push a moluid notification onto the Redis list for a camera."""
    redis_client.lpush(f'{REDIS_KEY_PREFIX}{camera}', moluid)


def drain_notifications(redis_client, camera):
    """Drain and return the deduplicated set of moluids queued for *camera*.

    Items are popped one at a time: each LPOP is atomic, so notifications
    pushed concurrently are either consumed by this drain or left intact for
    the next one. This avoids the failure mode of a RENAME-based drain,
    where a crash between RENAME and DELETE leaves a drain key that the next
    RENAME silently overwrites, losing those notifications; it also avoids
    swallowing unrelated ResponseErrors.
    """
    key = f'{REDIS_KEY_PREFIX}{camera}'
    moluids = set()
    while True:
        item = redis_client.lpop(key)
        if item is None:
            break
        # redis-py returns bytes unless decode_responses=True was set.
        moluids.add(item.decode() if isinstance(item, bytes) else item)
    return moluids


# ---------------------------------------------------------------------------
# Worker
# ---------------------------------------------------------------------------

def run_worker_loop(camera, db_address, redis_url, timeout_minutes=20, retention_days=30, poll_interval=5):
    """Per-camera worker loop (runs forever).

    Each pass drains pending notifications and finalizes ready stacks,
    times out stale active stacks, prunes old finalized records, then
    sleeps for *poll_interval* seconds.
    """
    client = redis_lib.Redis.from_url(redis_url)
    while True:
        process_notifications(db_address, client, camera)
        check_timeout(db_address, camera, timeout_minutes)
        dbs.cleanup_old_records(db_address, retention_days)
        time.sleep(poll_interval)


def process_notifications(db_address, redis_client, camera):
    """Drain pending notifications for *camera* and finalize ready stacks.

    For each distinct moluid seen, the latest frame state is re-queried from
    the database; moluids with no recorded frames are skipped.
    """
    for moluid in drain_notifications(redis_client, camera):
        frames = dbs.get_stack_frames(db_address, moluid)
        # frmtotal is taken from the first frame — assumes it is consistent
        # across a stack's frames (enforced upstream? TODO confirm).
        if frames and check_stack_complete(frames, frames[0].frmtotal):
            finalize_stack(db_address, moluid, status='complete')


def finalize_stack(db_address, moluid, status='complete'):
    """Mark the stack finalized in the DB and log the (mock) post-steps.

    Stacking, JPEG generation, and ingester upload are currently only
    logged placeholders.
    """
    dbs.mark_stack_complete(db_address, moluid, status=status)
    for step in ('stacking complete', 'JPEG generation', 'ingester upload'):
        logger.info(f'Mock {step} for {moluid}', extra_tags={'moluid': moluid})


def check_timeout(db_address, camera, timeout_minutes):
    """Finalize stale active stacks for *camera* with status='timeout'.

    A stack is stale when its frames' dateobs predates the cutoff; frames
    with a NULL dateobs never match and so never time out.
    """
    cutoff = datetime.datetime.utcnow() - datetime.timedelta(minutes=timeout_minutes)
    with dbs.get_session(db_address) as session:
        rows = session.query(dbs.StackFrame.moluid).filter(
            dbs.StackFrame.camera == camera,
            dbs.StackFrame.status == 'active',
            dbs.StackFrame.dateobs < cutoff,
        ).distinct().all()
    for row in rows:
        finalize_stack(db_address, row[0], status='timeout')


# ---------------------------------------------------------------------------
# Supervisor
# ---------------------------------------------------------------------------

def discover_cameras(db_address, site_id):
    """Return the camera codes of all instruments registered at *site_id*."""
    with dbs.get_session(db_address) as session:
        rows = session.query(dbs.Instrument).filter(dbs.Instrument.site == site_id).all()
    return [row.camera for row in rows]


class StackingSupervisor:
    """Spawns and babysits one stacking worker process per camera at a site.

    start() discovers cameras and launches workers; monitor() restarts any
    worker whose process dies; shutdown() terminates them all. The
    previously duplicated process-creation code in start() and monitor()
    is factored into _spawn().
    """

    def __init__(self, site_id, db_address, redis_url, timeout_minutes=20, retention_days=30):
        self.site_id = site_id
        self.db_address = db_address
        self.redis_url = redis_url
        self.timeout_minutes = timeout_minutes
        self.retention_days = retention_days
        # camera -> multiprocessing.Process of its worker
        self.workers = {}

    def _worker_args(self, camera):
        # Positional args for run_worker_loop (poll_interval keeps its default).
        return (camera, self.db_address, self.redis_url, self.timeout_minutes, self.retention_days)

    def _spawn(self, camera):
        """Start (or restart) the worker process for *camera* and register it."""
        proc = multiprocessing.Process(
            target=run_worker_loop,
            args=self._worker_args(camera),
            name=f'stacking-worker-{camera}',
        )
        proc.start()
        self.workers[camera] = proc

    def start(self):
        """Discover cameras and spawn one worker process per camera."""
        for camera in discover_cameras(self.db_address, self.site_id):
            self._spawn(camera)
            logger.info(f'Started stacking worker for camera {camera}')

    def monitor(self, check_interval=10):
        """Block forever, restarting any worker whose process has died."""
        while True:
            # Snapshot items() since _spawn mutates self.workers.
            for camera, proc in list(self.workers.items()):
                if not proc.is_alive():
                    logger.warning(f'Worker for {camera} died, restarting')
                    self._spawn(camera)
            time.sleep(check_interval)

    def shutdown(self):
        """Terminate all workers and wait (up to 10 s each) for them to exit."""
        for camera, proc in self.workers.items():
            proc.terminate()
            proc.join(timeout=10)
            logger.info(f'Stopped stacking worker for camera {camera}')
        self.workers.clear()


def run_supervisor():
    """Entry point for the stacking supervisor.

    Reads configuration from the environment (SITE_ID and DB_ADDRESS are
    required), installs SIGTERM/SIGINT handlers for graceful shutdown,
    then starts and monitors the per-camera workers (blocks forever).
    """
    supervisor = StackingSupervisor(
        os.environ['SITE_ID'],
        os.environ['DB_ADDRESS'],
        os.environ.get('REDIS_URL', 'redis://redis:6379/0'),
        timeout_minutes=int(os.environ.get('STACK_TIMEOUT_MINUTES', '20')),
        retention_days=int(os.environ.get('STACK_RETENTION_DAYS', '30')),
    )

    def handle_signal(signum, frame):
        # Stop the workers before exiting so they don't outlive the supervisor.
        supervisor.shutdown()
        raise SystemExit(0)

    for sig in (signal.SIGTERM, signal.SIGINT):
        signal.signal(sig, handle_signal)

    supervisor.start()
    supervisor.monitor()
2 changes: 1 addition & 1 deletion banzai/tests/site_e2e/test_site_e2e.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import pytest
import requests
from sqlalchemy import create_engine, text
from astropy.io import fits

from banzai import dbs
from banzai.tests.site_e2e.utils import populate_publication
Expand Down Expand Up @@ -162,7 +163,6 @@ def test_06_queue_raw_frame(self, site_deployment, auth_token):
@pytest.mark.e2e_site_reduction
def test_07_reduction_completes(self, site_deployment):
"""Verify reduction completed by checking for processed output file."""
from astropy.io import fits

raw_dir = DATA_DIR / 'raw'
assert raw_dir.exists(), f"Raw directory not found: {raw_dir}"
Expand Down
Loading
Loading