3 changes: 3 additions & 0 deletions examples/example_data_collection_vx300s.py
@@ -6,6 +6,7 @@
 from common.transfer_cube import BIMANUAL_VIPERX_URDF_PATH, make_sim_env
 
 import neuracore as nc
+from neuracore.data_daemon.lifecycle.daemon_lifecycle import ensure_daemon_running
 
 
 def main(args):
@@ -30,6 +31,8 @@ def main(args):
     )
     print("Created Dataset...")
 
+    ensure_daemon_running()
+
     try:
         for episode_idx in range(num_episodes):
             print(f"Starting episode {episode_idx}")
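The pattern this change establishes: create the dataset first, then start the daemon, then record episodes. A condensed sketch of that flow, assuming only the pieces visible in the diff (the dataset setup is elided, and the `num_episodes` default is illustrative):

    from neuracore.data_daemon.lifecycle.daemon_lifecycle import ensure_daemon_running

    def main(num_episodes: int = 5) -> None:
        # Dataset creation via neuracore happens here (elided; see diff above).
        # Start the data daemon before the first episode is recorded.
        ensure_daemon_running()
        for episode_idx in range(num_episodes):
            print(f"Starting episode {episode_idx}")
            # ... step the environment and log observations for this episode ...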
1 change: 0 additions & 1 deletion examples/example_data_collection_vx300s_data_daemon.py
@@ -43,7 +43,6 @@ def main(args: dict) -> None:
     # start daemon process
     ensure_daemon_running()
     try:
-        print("Starting!!!")
         for episode_idx in range(num_episodes):
             print(f"Starting episode {episode_idx}")
 
76 changes: 51 additions & 25 deletions neuracore/data_daemon/communications_management/data_bridge.py
@@ -22,6 +22,7 @@
     DATA_TYPE_FIELD_SIZE,
     DEFAULT_RING_BUFFER_SIZE,
     HEARTBEAT_TIMEOUT_SECS,
+    NEVER_OPENED_TIMEOUT_SECS,
     TRACE_ID_FIELD_SIZE,
 )
 from neuracore.data_daemon.event_emitter import Emitter, get_emitter
@@ -58,6 +59,7 @@ class ChannelState:
     last_heartbeat: datetime = field(default_factory=lambda: datetime.now(timezone.utc))
     reader: ChannelMessageReader | None = None
     trace_id: str | None = None
+    created_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc))
 
     def touch(self) -> None:
         """Update the last heartbeat time for the channel.
@@ -94,6 +96,7 @@ def __init__(
         self.comm = comm_manager or CommunicationsManager()
         self.recording_disk_manager = recording_disk_manager
         self.channels: dict[str, ChannelState] = {}
+        self._closed_producers: set[str] = set()
         self._recording_traces: dict[str, set[str]] = {}
         self._trace_recordings: dict[str, str] = {}
         self._trace_metadata: dict[str, dict[str, str | int | None]] = {}
@@ -104,24 +107,25 @@ def __init__(
             CommandType.DATA_CHUNK: self._handle_write_data_chunk,
             CommandType.HEARTBEAT: self._handle_heartbeat,
             CommandType.TRACE_END: self._handle_end_trace,
+            CommandType.RECORDING_STOPPED: self._handle_recording_stopped,
         }
 
         self._emitter = get_emitter()
-        self._emitter.on(Emitter.TRACE_WRITTEN, self.cleanup_stopped_channels)
         self._running = False
+        self._emitter.on(Emitter.TRACE_WRITTEN, self.cleanup_channel_on_trace_written)
 
     def run(self) -> None:
-        """Run the daemon main loop.
+        """Starts the daemon and begins accepting messages from producers.
 
-        This starts the consumer socket, and then enters an infinite loop where it:
-        - Receives ManagementMessages from producers over ZMQ
-        - Handles messages from producers using the `handle_message` function
-        - Cleans up expired channels using the `_cleanup_expired_channels` function
-        - Drains channel messages using the `_drain_channel_messages` function
+        This function blocks until the daemon is shut down via Ctrl-C.
 
-        The loop will exit on a KeyboardInterrupt (e.g. Ctrl+C), and will then call
-        `cleanup_daemon` on the communications manager to clean up resources.
-
+        It is responsible for:
+
+        - Starting the ZMQ consumer and publisher sockets.
+        - Receiving and processing management messages from producers.
+        - Periodically cleaning up expired channels.
+        - Draining full messages from the ring buffer.
+
+        :return: None
         """
         if self._running:
             raise RuntimeError("Daemon is already running")
@@ -185,17 +189,36 @@ def handle_message(self, message: MessageEnvelope) -> None:
         cmd = message.command
 
         if producer_id is None:
+            # Stop recording commands are sent without a producer_id / channel
             if cmd != CommandType.RECORDING_STOPPED:
                 logger.warning("Missing producer_id for command %s", cmd)
                 return
-            channel = ChannelState(producer_id="recording-context")
-        else:
-            existing = self.channels.get(producer_id)
-            if existing is None:
-                existing = ChannelState(producer_id=producer_id)
-                self.channels[producer_id] = existing
-                logger.info("Created new channel for producer_id=%s", producer_id)
-            channel = existing
+            self._handle_recording_stopped(message)
+            return
+
+        if (
+            producer_id in self._closed_producers
+            and cmd != CommandType.OPEN_RING_BUFFER
+        ):
+            logger.warning(
+                "Ignoring command %s from closed producer_id=%s",
+                cmd,
+                producer_id,
+            )
+            return
+
+        if (
+            cmd == CommandType.OPEN_RING_BUFFER
+            and producer_id in self._closed_producers
+        ):
+            self._closed_producers.discard(producer_id)
+
+        existing = self.channels.get(producer_id)
+        if existing is None:
+            existing = ChannelState(producer_id=producer_id)
+            self.channels[producer_id] = existing
+            logger.info("Created new channel for producer_id=%s", producer_id)
+        channel = existing
         channel.touch()
 
         handler = self._command_handlers.get(cmd)
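Net effect of the new guards: once a producer has been closed out, everything it sends is dropped until it re-registers with OPEN_RING_BUFFER, which removes it from the closed set. A condensed, self-contained sketch of that gate, with CommandType reduced to plain strings for illustration:

    def should_process(cmd: str, producer_id: str, closed: set[str]) -> bool:
        # Closed producers are ignored unless they are re-opening their channel.
        if producer_id in closed and cmd != "OPEN_RING_BUFFER":
            return False
        if cmd == "OPEN_RING_BUFFER":
            closed.discard(producer_id)  # re-opening revives the producer
        return True

    # A closed producer's DATA_CHUNK is dropped, but OPEN_RING_BUFFER
    # re-admits it, after which its commands flow again.
    closed = {"prod-1"}
    assert not should_process("DATA_CHUNK", "prod-1", closed)
    assert should_process("OPEN_RING_BUFFER", "prod-1", closed)
    assert should_process("DATA_CHUNK", "prod-1", closed)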
@@ -250,6 +273,7 @@ def _handle_open_ring_buffer(
     def _drain_channel_messages(self) -> None:
         """Poll all channels for completed messages and handle them."""
         for channel in self.channels.values():
+            # guard against uninitialised channels
             if channel.reader is None or channel.ring_buffer is None:
                 continue
             # Loop to receive full message
@@ -564,9 +588,7 @@ def _handle_end_trace(
 
         self._remove_trace(str(recording_id), str(trace_id))
 
-    def _handle_recording_stopped(
-        self, _: ChannelState, message: MessageEnvelope
-    ) -> None:
+    def _handle_recording_stopped(self, message: MessageEnvelope) -> None:
         """Handle a RECORDING_STOPPED message from a producer.
 
         This function is called when a producer sends a RECORDING_STOPPED message to
@@ -599,7 +621,7 @@ def _finalize_pending_closes(self) -> None:
         self._closed_recordings.update(self._pending_close_recordings)
         self._pending_close_recordings.clear()
 
-    def cleanup_stopped_channels(
+    def cleanup_channel_on_trace_written(
         self,
         trace_id: str,
         _: str | None = None,
@@ -631,12 +653,16 @@ def cleanup_stopped_channels(
     def _cleanup_expired_channels(self) -> None:
         """Remove channels whose heartbeat has not been seen within the timeout."""
         now = datetime.now(timezone.utc)
-        timeout = timedelta(seconds=HEARTBEAT_TIMEOUT_SECS)
+        heartbeat_timeout = timedelta(seconds=HEARTBEAT_TIMEOUT_SECS)
+        never_opened_timeout = timedelta(seconds=NEVER_OPENED_TIMEOUT_SECS)
 
         to_remove = [
             producer_id
             for producer_id, state in self.channels.items()
-            if now - state.last_heartbeat > timeout
+            # Missed heartbeat
+            if (now - state.last_heartbeat > heartbeat_timeout)
+            # Never opened and timed out
+            or (now - state.created_at > never_opened_timeout)
         ]
 
         for producer_id in to_remove:
@@ -661,5 +687,5 @@ def _cleanup_expired_channels(self) -> None:
                 },
             ),
         )
-        # Here is where you would also clean up any shared memory segments.
         del self.channels[producer_id]
+        self._closed_producers.add(producer_id)
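With the new constant, channel expiry now has two triggers. A standalone sketch of the predicate as written (note that the second clause keys off `created_at` alone, so it applies to every channel, not only ones that never opened a ring buffer):

    from datetime import datetime, timedelta, timezone

    HEARTBEAT_TIMEOUT_SECS = 10      # from const.py
    NEVER_OPENED_TIMEOUT_SECS = 20   # from const.py

    def is_expired(last_heartbeat: datetime, created_at: datetime) -> bool:
        now = datetime.now(timezone.utc)
        # Trigger 1: the producer stopped heartbeating.
        missed_heartbeat = now - last_heartbeat > timedelta(
            seconds=HEARTBEAT_TIMEOUT_SECS
        )
        # Trigger 2: the channel has existed longer than the never-opened timeout.
        never_opened = now - created_at > timedelta(
            seconds=NEVER_OPENED_TIMEOUT_SECS
        )
        return missed_heartbeat or never_opened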
1 change: 1 addition & 0 deletions neuracore/data_daemon/const.py
@@ -6,6 +6,7 @@
 from pathlib import Path
 
 HEARTBEAT_TIMEOUT_SECS = 10
+NEVER_OPENED_TIMEOUT_SECS = 20
 API_URL = os.getenv("NEURACORE_API_URL", "https://api.neuracore.app/api")
 
 TRACE_ID_FIELD_SIZE = 36  # bytes allocated for the trace_id string in chunk headers
1 change: 1 addition & 0 deletions neuracore/data_daemon/state_management/tables.py
@@ -80,3 +80,4 @@
 
 Index("idx_traces_trace_id", traces.c.trace_id)
 Index("idx_traces_status", traces.c.status)
+Index("idx_traces_next_retry_at", traces.c.next_retry_at)
@@ -218,16 +218,6 @@ def test_zmq_commands_and_message_flow(daemon_runtime) -> None:
 
     payload = json.dumps({"seq": 1}).encode("utf-8")
     active_trace_id = producer.trace_id
-    producer.send_data(
-        payload,
-        data_type=DataType.CUSTOM_1D,
-        data_type_name="custom",
-        robot_instance=1,
-        robot_id="robot-1",
-        dataset_id="dataset-1",
-    )
-    producer.end_trace()
-
     trace_written: list[int] = []
 
     def on_trace_written(trace_id: str, _: str, bytes_written: int) -> None:
@@ -236,6 +226,18 @@ def on_trace_written(trace_id: str, _: str, bytes_written: int) -> None:
 
     get_emitter().on(Emitter.TRACE_WRITTEN, on_trace_written)
     try:
+        producer.send_data(
+            payload,
+            data_type=DataType.CUSTOM_1D,
+            data_type_name="custom",
+            robot_instance=1,
+            robot_id="robot-1",
+            dataset_id="dataset-1",
+        )
+        assert _wait_for(
+            lambda: active_trace_id in daemon._trace_recordings, timeout=0.5
+        )
+        producer.end_trace()
         assert _wait_for(lambda: trace_written, timeout=1)
     finally:
         get_emitter().remove_listener(Emitter.TRACE_WRITTEN, on_trace_written)
@@ -420,7 +422,7 @@ def on_stop_recording(rec_id: str) -> None:
     assert recording_id not in daemon._pending_close_recordings
     assert recording_id not in daemon._closed_recordings
 
-    daemon._handle_recording_stopped(None, msg)
+    daemon._handle_recording_stopped(msg)
 
     assert recording_id in daemon._pending_close_recordings
     assert recording_id not in daemon._closed_recordings
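The reordered test now waits for the daemon to register the trace before calling end_trace(), closing a race in the original version. `_wait_for` is not shown in this diff; it is presumably a small polling helper along these lines:

    import time
    from typing import Callable

    def _wait_for(predicate: Callable[[], object], timeout: float = 1.0) -> bool:
        # Poll until the predicate is truthy or the timeout elapses.
        deadline = time.monotonic() + timeout
        while time.monotonic() < deadline:
            if predicate():
                return True
            time.sleep(0.01)
        return bool(predicate())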