From 830ea77336d00078aec268294a0523fc62eccf1e Mon Sep 17 00:00:00 2001 From: Muneeb Amer Date: Wed, 4 Feb 2026 18:11:59 +0500 Subject: [PATCH] fix: added upload fail points (#355) --- examples/example_data_collection_vx300s.py | 3 + ...mple_data_collection_vx300s_data_daemon.py | 1 - .../communications_management/data_bridge.py | 76 +++++++++++++------ neuracore/data_daemon/const.py | 1 + .../data_daemon/state_management/tables.py | 1 + .../test_zmq_sockets.py | 24 +++--- 6 files changed, 69 insertions(+), 37 deletions(-) diff --git a/examples/example_data_collection_vx300s.py b/examples/example_data_collection_vx300s.py index 646308e7..2fbff2bb 100644 --- a/examples/example_data_collection_vx300s.py +++ b/examples/example_data_collection_vx300s.py @@ -6,6 +6,7 @@ from common.transfer_cube import BIMANUAL_VIPERX_URDF_PATH, make_sim_env import neuracore as nc +from neuracore.data_daemon.lifecycle.daemon_lifecycle import ensure_daemon_running def main(args): @@ -30,6 +31,8 @@ def main(args): ) print("Created Dataset...") + ensure_daemon_running() + try: for episode_idx in range(num_episodes): print(f"Starting episode {episode_idx}") diff --git a/examples/example_data_collection_vx300s_data_daemon.py b/examples/example_data_collection_vx300s_data_daemon.py index 9de7d423..623e0740 100644 --- a/examples/example_data_collection_vx300s_data_daemon.py +++ b/examples/example_data_collection_vx300s_data_daemon.py @@ -43,7 +43,6 @@ def main(args: dict) -> None: # start daemon process ensure_daemon_running() try: - print("Starting!!!") for episode_idx in range(num_episodes): print(f"Starting episode {episode_idx}") diff --git a/neuracore/data_daemon/communications_management/data_bridge.py b/neuracore/data_daemon/communications_management/data_bridge.py index b34663ee..84f49e4b 100644 --- a/neuracore/data_daemon/communications_management/data_bridge.py +++ b/neuracore/data_daemon/communications_management/data_bridge.py @@ -22,6 +22,7 @@ DATA_TYPE_FIELD_SIZE, DEFAULT_RING_BUFFER_SIZE, HEARTBEAT_TIMEOUT_SECS, + NEVER_OPENED_TIMEOUT_SECS, TRACE_ID_FIELD_SIZE, ) from neuracore.data_daemon.event_emitter import Emitter, get_emitter @@ -58,6 +59,7 @@ class ChannelState: last_heartbeat: datetime = field(default_factory=lambda: datetime.now(timezone.utc)) reader: ChannelMessageReader | None = None trace_id: str | None = None + created_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc)) def touch(self) -> None: """Update the last heartbeat time for the channel. @@ -94,6 +96,7 @@ def __init__( self.comm = comm_manager or CommunicationsManager() self.recording_disk_manager = recording_disk_manager self.channels: dict[str, ChannelState] = {} + self._closed_producers: set[str] = set() self._recording_traces: dict[str, set[str]] = {} self._trace_recordings: dict[str, str] = {} self._trace_metadata: dict[str, dict[str, str | int | None]] = {} @@ -104,24 +107,25 @@ def __init__( CommandType.DATA_CHUNK: self._handle_write_data_chunk, CommandType.HEARTBEAT: self._handle_heartbeat, CommandType.TRACE_END: self._handle_end_trace, - CommandType.RECORDING_STOPPED: self._handle_recording_stopped, } self._emitter = get_emitter() - self._emitter.on(Emitter.TRACE_WRITTEN, self.cleanup_stopped_channels) self._running = False + self._emitter.on(Emitter.TRACE_WRITTEN, self.cleanup_channel_on_trace_written) def run(self) -> None: - """Run the daemon main loop. + """Starts the daemon and begins accepting messages from producers. - This starts the consumer socket, and then enters an infinite loop where it: - - Receives ManagementMessages from producers over ZMQ - - Handles messages from producers using the `handle_message` function - - Cleans up expired channels using the `_cleanup_expired_channels` function - - Drains channel messages using the `_drain_channel_messages` function + This function blocks until the daemon is shutdown via Ctrl-C. - The loop will exit on a KeyboardInterrupt (e.g. Ctrl+C), and will then call - `cleanup_daemon` on the communications manager to clean up resources. + It is responsible for: + + - Starting the ZMQ consumer and publisher sockets. + - Receiving and processing management messages from producers. + - Periodically cleaning up expired channels. + - Draining full messages from the ring buffer. + + :return: None """ if self._running: raise RuntimeError("Daemon is already running") @@ -185,17 +189,36 @@ def handle_message(self, message: MessageEnvelope) -> None: cmd = message.command if producer_id is None: + # Stop recording commands are sent without a producer_id / channel if cmd != CommandType.RECORDING_STOPPED: logger.warning("Missing producer_id for command %s", cmd) return - channel = ChannelState(producer_id="recording-context") - else: - existing = self.channels.get(producer_id) - if existing is None: - existing = ChannelState(producer_id=producer_id) - self.channels[producer_id] = existing - logger.info("Created new channel for producer_id=%s", producer_id) - channel = existing + self._handle_recording_stopped(message) + return + + if ( + producer_id in self._closed_producers + and cmd != CommandType.OPEN_RING_BUFFER + ): + logger.warning( + "Ignoring command %s from closed producer_id=%s", + cmd, + producer_id, + ) + return + + if ( + cmd == CommandType.OPEN_RING_BUFFER + and producer_id in self._closed_producers + ): + self._closed_producers.discard(producer_id) + + existing = self.channels.get(producer_id) + if existing is None: + existing = ChannelState(producer_id=producer_id) + self.channels[producer_id] = existing + logger.info("Created new channel for producer_id=%s", producer_id) + channel = existing channel.touch() handler = self._command_handlers.get(cmd) @@ -250,6 +273,7 @@ def _handle_open_ring_buffer( def _drain_channel_messages(self) -> None: """Poll all channels for completed messages and handle them.""" for channel in self.channels.values(): + # guard against uninitialised channels if channel.reader is None or channel.ring_buffer is None: continue # Loop to receive full message @@ -564,9 +588,7 @@ def _handle_end_trace( self._remove_trace(str(recording_id), str(trace_id)) - def _handle_recording_stopped( - self, _: ChannelState, message: MessageEnvelope - ) -> None: + def _handle_recording_stopped(self, message: MessageEnvelope) -> None: """Handle a RECORDING_STOPPED message from a producer. This function is called when a producer sends a RECORDING_STOPPED message to @@ -599,7 +621,7 @@ def _finalize_pending_closes(self) -> None: self._closed_recordings.update(self._pending_close_recordings) self._pending_close_recordings.clear() - def cleanup_stopped_channels( + def cleanup_channel_on_trace_written( self, trace_id: str, _: str | None = None, @@ -631,12 +653,16 @@ def cleanup_stopped_channels( def _cleanup_expired_channels(self) -> None: """Remove channels whose heartbeat has not been seen within the timeout.""" now = datetime.now(timezone.utc) - timeout = timedelta(seconds=HEARTBEAT_TIMEOUT_SECS) + heartbeat_timeout = timedelta(seconds=HEARTBEAT_TIMEOUT_SECS) + never_opened_timeout = timedelta(seconds=NEVER_OPENED_TIMEOUT_SECS) to_remove = [ producer_id for producer_id, state in self.channels.items() - if now - state.last_heartbeat > timeout + # Missed heart beat + if (now - state.last_heartbeat > heartbeat_timeout) + # Never opened and timed out + or (now - state.created_at > never_opened_timeout) ] for producer_id in to_remove: @@ -661,5 +687,5 @@ def _cleanup_expired_channels(self) -> None: }, ), ) - # Here is where you would also clean up any shared memory segments. del self.channels[producer_id] + self._closed_producers.add(producer_id) diff --git a/neuracore/data_daemon/const.py b/neuracore/data_daemon/const.py index 2596038c..6a6ad56c 100644 --- a/neuracore/data_daemon/const.py +++ b/neuracore/data_daemon/const.py @@ -6,6 +6,7 @@ from pathlib import Path HEARTBEAT_TIMEOUT_SECS = 10 +NEVER_OPENED_TIMEOUT_SECS = 20 API_URL = os.getenv("NEURACORE_API_URL", "https://api.neuracore.app/api") TRACE_ID_FIELD_SIZE = 36 # bytes allocated for the trace_id string in chunk headers diff --git a/neuracore/data_daemon/state_management/tables.py b/neuracore/data_daemon/state_management/tables.py index 9c0e7ed0..9f350931 100644 --- a/neuracore/data_daemon/state_management/tables.py +++ b/neuracore/data_daemon/state_management/tables.py @@ -80,3 +80,4 @@ Index("idx_traces_trace_id", traces.c.trace_id) Index("idx_traces_status", traces.c.status) +Index("idx_traces_next_retry_at", traces.c.next_retry_at) diff --git a/tests/unit/data_daemon/communications_management/test_zmq_sockets.py b/tests/unit/data_daemon/communications_management/test_zmq_sockets.py index 5718499c..a189178c 100644 --- a/tests/unit/data_daemon/communications_management/test_zmq_sockets.py +++ b/tests/unit/data_daemon/communications_management/test_zmq_sockets.py @@ -218,16 +218,6 @@ def test_zmq_commands_and_message_flow(daemon_runtime) -> None: payload = json.dumps({"seq": 1}).encode("utf-8") active_trace_id = producer.trace_id - producer.send_data( - payload, - data_type=DataType.CUSTOM_1D, - data_type_name="custom", - robot_instance=1, - robot_id="robot-1", - dataset_id="dataset-1", - ) - producer.end_trace() - trace_written: list[int] = [] def on_trace_written(trace_id: str, _: str, bytes_written: int) -> None: @@ -236,6 +226,18 @@ def on_trace_written(trace_id: str, _: str, bytes_written: int) -> None: get_emitter().on(Emitter.TRACE_WRITTEN, on_trace_written) try: + producer.send_data( + payload, + data_type=DataType.CUSTOM_1D, + data_type_name="custom", + robot_instance=1, + robot_id="robot-1", + dataset_id="dataset-1", + ) + assert _wait_for( + lambda: active_trace_id in daemon._trace_recordings, timeout=0.5 + ) + producer.end_trace() assert _wait_for(lambda: trace_written, timeout=1) finally: get_emitter().remove_listener(Emitter.TRACE_WRITTEN, on_trace_written) @@ -420,7 +422,7 @@ def on_stop_recording(rec_id: str) -> None: assert recording_id not in daemon._pending_close_recordings assert recording_id not in daemon._closed_recordings - daemon._handle_recording_stopped(None, msg) + daemon._handle_recording_stopped(msg) assert recording_id in daemon._pending_close_recordings assert recording_id not in daemon._closed_recordings