diff --git a/neuracore/data_daemon/connection_management/connection_manager.py b/neuracore/data_daemon/connection_management/connection_manager.py index 21724723..dce81024 100644 --- a/neuracore/data_daemon/connection_management/connection_manager.py +++ b/neuracore/data_daemon/connection_management/connection_manager.py @@ -25,26 +25,33 @@ class ConnectionManager: def __init__( self, + *, timeout: float = 5.0, check_interval: float = 10.0, + offline_mode: bool = False, ) -> None: """Initialize the connection manager. Args: - emitter: Event emitter for broadcasting connection state + config_manager: Config to resolve for Connection Manager timeout: Timeout in seconds for connectivity checks check_interval: Seconds between connectivity checks + offline_mode: Daemon is in offline mode; skip connectivity checks """ self._timeout = timeout self._check_interval = check_interval self._is_connected = False self._running = False self._checker_thread: threading.Thread | None = None + self._offline_mode = offline_mode emitter.emit(Emitter.IS_CONNECTED, self._is_connected) def start(self) -> None: """Start the connection monitoring thread.""" + if self._offline_mode: + logger.info("ConnectionManager in offline mode") + return if self._running: logger.warning("ConnectionManager already running") return @@ -64,6 +71,9 @@ def stop(self, timeout: float = 5.0) -> None: Args: timeout: Maximum time to wait for thread to stop """ + if self._offline_mode: + logger.info("ConnectionManager in offline mode") + return if not self._running: logger.warning("ConnectionManager not running") return diff --git a/neuracore/data_daemon/upload_management/upload_manager.py b/neuracore/data_daemon/upload_management/upload_manager.py index adfc230c..735d5040 100644 --- a/neuracore/data_daemon/upload_management/upload_manager.py +++ b/neuracore/data_daemon/upload_management/upload_manager.py @@ -10,7 +10,6 @@ from neuracore_types import DataType, RecordingDataTraceStatus -from neuracore.data_daemon.config_manager.daemon_config import DaemonConfig from neuracore.data_daemon.event_emitter import Emitter, emitter from neuracore.data_daemon.models import TraceErrorCode, TraceStatus, get_content_type from neuracore.data_daemon.upload_management.trace_manager import TraceManager @@ -32,12 +31,14 @@ class UploadManager(TraceManager): Uploads are triggered via READY_FOR_UPLOAD events from state manager. """ - def __init__(self, config: DaemonConfig): - """Initialize the upload manager.""" - self._config = config + def __init__(self, num_threads: int = 4): + """Initialize the upload manager. + Args: + num_threads: Number of concurrent upload threads + """ # Threading - self._num_threads = self._config.num_threads or 4 + self._num_threads = num_threads self._executor = ThreadPoolExecutor( max_workers=self._num_threads, thread_name_prefix="uploader", diff --git a/tests/unit/data_daemon/connection_manager/test_connection_manager.py b/tests/unit/data_daemon/connection_manager/test_connection_manager.py index 22da991c..a2d1b30d 100644 --- a/tests/unit/data_daemon/connection_manager/test_connection_manager.py +++ b/tests/unit/data_daemon/connection_manager/test_connection_manager.py @@ -3,6 +3,7 @@ from __future__ import annotations import time +from collections.abc import Callable import pytest @@ -12,23 +13,85 @@ from neuracore.data_daemon.event_emitter import Emitter, emitter +class IsConnectedCapture: + """Captures IS_CONNECTED events for testing.""" + + def __init__(self) -> None: + self.received: list[bool] = [] + + def handler(self, is_connected: bool) -> None: + self.received.append(is_connected) + + def __eq__(self, other: object) -> bool: + return self.received == other + + +@pytest.fixture +def is_connected_capture(): + """Fixture that captures IS_CONNECTED events and cleans up after test.""" + capture = IsConnectedCapture() + emitter.on(Emitter.IS_CONNECTED, capture.handler) + yield capture + emitter.remove_listener(Emitter.IS_CONNECTED, capture.handler) + + @pytest.fixture -def manager() -> ConnectionManager: - """Create a ConnectionManager instance for testing.""" - return ConnectionManager( - timeout=2.0, - check_interval=1.0, - ) +def manager_factory() -> Callable[..., ConnectionManager]: + """Factory to create ConnectionManager instances with custom settings.""" + + def _make( + *, offline_mode: bool = False, timeout: float = 2.0, check_interval: float = 1.0 + ): + return ConnectionManager( + timeout=timeout, + check_interval=check_interval, + offline_mode=offline_mode, + ) + + return _make -def test_connection_manager_initializes_correctly(manager: ConnectionManager) -> None: +def test_connection_manager_initializes_correctly( + is_connected_capture, manager_factory +) -> None: """Test that ConnectionManager initializes with correct defaults.""" + manager = manager_factory() + assert manager._running is False + assert manager._checker_thread is None + assert manager._offline_mode is False + assert is_connected_capture == [False] + + +def test_connection_manager_initializes_offline_mode( + is_connected_capture, manager_factory +) -> None: + """Test that ConnectionManager in offline mode does not emit IS_CONNECTED.""" + manager = manager_factory(offline_mode=True) assert manager._running is False assert manager._checker_thread is None + assert manager._offline_mode is True + assert is_connected_capture == [False] -def test_connection_manager_start_stop(manager: ConnectionManager) -> None: +def test_connection_manager_offline_mode_start_does_nothing(manager_factory) -> None: + """Test that start() in offline mode does not start the checker thread.""" + manager = manager_factory(offline_mode=True) + manager.start() + assert manager._running is False + assert manager._checker_thread is None + + +def test_connection_manager_offline_mode_stop_does_nothing(manager_factory) -> None: + """Test that stop() in offline mode does nothing.""" + manager = manager_factory(offline_mode=True) + manager.stop() + assert manager._running is False + assert manager._checker_thread is None + + +def test_connection_manager_start_stop(manager_factory) -> None: """Test basic start and stop functionality.""" + manager = manager_factory() # Start manager manager.start() assert manager._running is True @@ -42,67 +105,75 @@ def test_connection_manager_start_stop(manager: ConnectionManager) -> None: assert manager._running is False -def test_connection_manager_emits_events_on_state_change() -> None: - """Test that events are emitted when connection state changes.""" - received: list[bool] = [] +def test_connection_manager_emits_true_when_connected( + is_connected_capture, manager_factory +) -> None: + """Test that IS_CONNECTED event is emitted with True when connectivity succeeds.""" + manager = manager_factory(check_interval=0.1) - def handler(is_connected: bool) -> None: - received.append(is_connected) + manager._check_connectivity = lambda: True + manager.start() + time.sleep(0.5) - emitter.on(Emitter.IS_CONNECTED, handler) - try: - manager = ConnectionManager( - timeout=2.0, - check_interval=0.5, - ) + assert is_connected_capture == [False, True] + + manager.stop() - manager.start() - time.sleep(2) - manager.stop() - assert len(received) > 0 - finally: - emitter.remove_listener(Emitter.IS_CONNECTED, handler) +def test_connection_manager_emits_false_when_disconnected( + is_connected_capture, manager_factory +) -> None: + """Test that IS_CONNECTED event is emitted with False when connectivity fails.""" + manager = manager_factory(check_interval=0.1) + manager._check_connectivity = lambda: False + manager.start() + time.sleep(0.5) + manager.stop() -def test_connection_manager_tracks_state_changes() -> None: - """Test that connection state changes are tracked correctly.""" - received: list[bool] = [] + assert is_connected_capture == [False] - def handler(is_connected: bool) -> None: - received.append(is_connected) - emitter.on(Emitter.IS_CONNECTED, handler) - try: - manager = ConnectionManager( - timeout=2.0, - check_interval=0.3, - ) +def test_connection_manager_offline_mode_never_emits_true( + is_connected_capture, manager_factory +) -> None: + """Test that offline mode never emits IS_CONNECTED as True.""" + manager = manager_factory(offline_mode=True, check_interval=0.1) - connection_states = [True, True, False, False, True] - state_index = [0] + manager._check_connectivity = lambda: True + manager.start() + time.sleep(0.5) + manager.stop() + + assert is_connected_capture == [False] - def mock_check_connectivity() -> bool: - state = connection_states[state_index[0] % len(connection_states)] - state_index[0] += 1 - return state - manager._check_connectivity = mock_check_connectivity +def test_connection_manager_tracks_state_changes( + is_connected_capture, manager_factory +) -> None: + """Test that connection state changes are tracked correctly.""" + manager = manager_factory(check_interval=0.3) - manager.start() - time.sleep(2) - manager.stop() + connection_states = [True, True, False, False, True] + state_index = [0] - assert len(received) >= 2 + def mock_check_connectivity() -> bool: + state = connection_states[state_index[0] % len(connection_states)] + state_index[0] += 1 + return state - assert True in received - assert False in received - finally: - emitter.remove_listener(Emitter.IS_CONNECTED, handler) + manager._check_connectivity = mock_check_connectivity + manager.start() + time.sleep(2) + manager.stop() -def test_connection_manager_is_connected_method(manager: ConnectionManager) -> None: + assert is_connected_capture == [False, True, False, True] + + +def test_connection_manager_is_connected_method(manager_factory) -> None: """Test the is_connected() method returns current state.""" + manager = manager_factory() current_state = manager.is_connected() assert isinstance(current_state, bool) @@ -115,8 +186,9 @@ def test_connection_manager_is_connected_method(manager: ConnectionManager) -> N manager.stop() -def test_connection_manager_double_start_is_safe(manager: ConnectionManager) -> None: +def test_connection_manager_double_start_is_safe(manager_factory) -> None: """Test that calling start twice is handled gracefully.""" + manager = manager_factory() manager.start() assert manager._running is True @@ -126,10 +198,9 @@ def test_connection_manager_double_start_is_safe(manager: ConnectionManager) -> manager.stop() -def test_connection_manager_stop_without_start_is_safe( - manager: ConnectionManager, -) -> None: +def test_connection_manager_stop_without_start_is_safe(manager_factory) -> None: """Test that calling stop without start is handled gracefully.""" + manager = manager_factory() assert manager._running is False manager.stop() @@ -138,15 +209,17 @@ def test_connection_manager_stop_without_start_is_safe( def test_connection_manager_get_available_bandwidth_returns_none( - manager: ConnectionManager, + manager_factory, ) -> None: """Test that get_available_bandwidth returns None (placeholder).""" + manager = manager_factory() bandwidth = manager.get_available_bandwidth() assert bandwidth is None -def test_connection_manager_stops_thread_on_stop(manager: ConnectionManager) -> None: +def test_connection_manager_stops_thread_on_stop(manager_factory) -> None: """Test that the checking thread actually stops.""" + manager = manager_factory() manager.start() thread = manager._checker_thread @@ -160,55 +233,22 @@ def test_connection_manager_stops_thread_on_stop(manager: ConnectionManager) -> assert thread.is_alive() is False -def test_connection_manager_handles_check_exceptions() -> None: +def test_connection_manager_handles_check_exceptions(manager_factory) -> None: """Test that exceptions in connectivity check are handled gracefully.""" - received: list[bool] = [] - - def handler(is_connected: bool) -> None: - received.append(is_connected) - - emitter.on(Emitter.IS_CONNECTED, handler) - try: - manager = ConnectionManager(timeout=2.0, check_interval=0.3) - - check_count = [0] - - def mock_check_that_raises() -> bool: - check_count[0] += 1 - if check_count[0] == 2: - raise RuntimeError("Test exception") - return True - - manager._check_connectivity = mock_check_that_raises + manager = manager_factory(check_interval=0.3) - manager.start() - time.sleep(1.5) - manager.stop() + check_count = [0] - assert check_count[0] >= 3 - finally: - emitter.remove_listener(Emitter.IS_CONNECTED, handler) + def mock_check_that_raises() -> bool: + check_count[0] += 1 + if check_count[0] == 2: + raise RuntimeError("Test exception") + return True + manager._check_connectivity = mock_check_that_raises -def test_connection_manager_only_emits_on_state_change() -> None: - """Test that events are only emitted when state actually changes.""" - received: list[bool] = [] - - def handler(is_connected: bool) -> None: - received.append(is_connected) - - emitter.on(Emitter.IS_CONNECTED, handler) - try: - manager = ConnectionManager(timeout=2.0, check_interval=0.3) - - manager._check_connectivity = True - - manager.start() - time.sleep(1.5) - manager.stop() - assert len(received) <= 2 + manager.start() + time.sleep(1.5) + manager.stop() - if len(received) > 1: - assert all(received[1:]) - finally: - emitter.remove_listener(Emitter.IS_CONNECTED, handler) + assert check_count[0] >= 3 diff --git a/tests/unit/data_daemon/helpers/mock_config_manager.py b/tests/unit/data_daemon/helpers/mock_config_manager.py index ae258c77..3c27bca5 100644 --- a/tests/unit/data_daemon/helpers/mock_config_manager.py +++ b/tests/unit/data_daemon/helpers/mock_config_manager.py @@ -1,12 +1,13 @@ from __future__ import annotations -from dataclasses import dataclass, replace -from pathlib import Path +from dataclasses import dataclass from typing import Any @dataclass(frozen=True) -class MockConfigManager: +class MockDaemonConfig: + """Mock version of DaemonConfig for testing.""" + storage_limit: int | None = None bandwidth_limit: int | None = None path_to_store_record: str | None = None @@ -16,8 +17,42 @@ class MockConfigManager: api_key: str | None = None current_org_id: str | None = None - def path_to_store_record_from(self, path: Path) -> MockConfigManager: - return replace(self, path_to_store_record=str(path)) + @classmethod + def with_defaults(cls) -> MockDaemonConfig: + """Return a MockDaemonConfig with sensible defaults for testing.""" + return cls( + storage_limit=None, + bandwidth_limit=None, + path_to_store_record=None, + num_threads=1, + keep_wakelock_while_upload=False, + offline=False, + api_key=None, + current_org_id=None, + ) + + +class MockConfigManager: + """Mock version of ConfigManager for testing. + + Mirrors real ConfigManager by merging overrides onto defaults. + """ + + def __init__(self, config: MockDaemonConfig | None = None, **kwargs: Any) -> None: + if config is not None: + self._overrides = config + else: + self._overrides = MockDaemonConfig(**kwargs) - def resolve_effective_config(self, *args: Any, **kwargs: Any) -> MockConfigManager: - return self + def resolve_effective_config(self, *args: Any, **kwargs: Any) -> MockDaemonConfig: + """Merge overrides onto defaults, returning effective config.""" + defaults = MockDaemonConfig.with_defaults() + merged = { + field: ( + getattr(self._overrides, field) + if getattr(self._overrides, field) is not None + else getattr(defaults, field) + ) + for field in defaults.__dataclass_fields__ + } + return MockDaemonConfig(**merged) diff --git a/tests/unit/data_daemon/test_daemon_rdm_integration.py b/tests/unit/data_daemon/test_daemon_rdm_integration.py index dc294ddf..1bd8329a 100644 --- a/tests/unit/data_daemon/test_daemon_rdm_integration.py +++ b/tests/unit/data_daemon/test_daemon_rdm_integration.py @@ -75,7 +75,7 @@ class TestDaemonInit: def test_daemon_accepts_config_manager(self, tmp_path: Any) -> None: """Daemon should accept config_manager parameter.""" - mock_config = MockConfigManager().path_to_store_record_from(tmp_path) + mock_config = MockConfigManager(path_to_store_record=str(tmp_path)) mock_comm = MockComm() mock_rdm = MockRDM() @@ -104,7 +104,10 @@ def test_daemon_creates_rdm_with_config_if_not_provided( self, tmp_path: Any ) -> None: """Daemon should create RDM using config_manager if RDM not provided.""" - mock_config = MockConfigManager().path_to_store_record_from(tmp_path) + mock_path, mock_storage = str(tmp_path), 1000 + mock_config = MockConfigManager( + path_to_store_record=mock_path, storage_limit=mock_storage + ) mock_comm = MockComm() with patch( @@ -118,7 +121,10 @@ def test_daemon_creates_rdm_with_config_if_not_provided( config_manager=mock_config, ) - mock_rdm_class.assert_called_once_with(mock_config) + mock_rdm_class.assert_called_once_with( + path_to_store_record=mock_path, + storage_limit_bytes=mock_storage, + ) assert daemon.recording_disk_manager is mock_rdm_instance @@ -132,7 +138,7 @@ class TestOnCompleteMessage: def test_on_complete_message_enqueues_to_rdm(self, tmp_path: Any) -> None: """_on_complete_message should construct CompleteMessage and enqueue to RDM.""" - mock_config = MockConfigManager().path_to_store_record_from(tmp_path) + mock_config = MockConfigManager(path_to_store_record=str(tmp_path)) mock_comm = MockComm() mock_rdm = MockRDM() @@ -162,7 +168,7 @@ def test_on_complete_message_enqueues_to_rdm(self, tmp_path: Any) -> None: def test_on_complete_message_with_final_chunk(self, tmp_path: Any) -> None: """_on_complete_message should set final_chunk=True when specified.""" - mock_config = MockConfigManager().path_to_store_record_from(tmp_path) + mock_config = MockConfigManager(path_to_store_record=str(tmp_path)) mock_comm = MockComm() mock_rdm = MockRDM() @@ -188,7 +194,7 @@ def test_on_complete_message_with_final_chunk(self, tmp_path: Any) -> None: def test_on_complete_message_uses_trace_metadata(self, tmp_path: Any) -> None: """_on_complete_message should use metadata from _trace_metadata.""" - mock_config = MockConfigManager().path_to_store_record_from(tmp_path) + mock_config = MockConfigManager(path_to_store_record=str(tmp_path)) mock_comm = MockComm() mock_rdm = MockRDM() @@ -229,7 +235,7 @@ def test_on_complete_message_uses_trace_metadata(self, tmp_path: Any) -> None: def test_on_complete_message_handles_missing_metadata(self, tmp_path: Any) -> None: """_on_complete_message should handle missing metadata gracefully.""" - mock_config = MockConfigManager().path_to_store_record_from(tmp_path) + mock_config = MockConfigManager(path_to_store_record=str(tmp_path)) mock_comm = MockComm() mock_rdm = MockRDM() @@ -263,7 +269,7 @@ def test_on_complete_message_handles_empty_recording_id( self, tmp_path: Any ) -> None: """_on_complete_message should use empty string if recording_id is None.""" - mock_config = MockConfigManager().path_to_store_record_from(tmp_path) + mock_config = MockConfigManager(path_to_store_record=str(tmp_path)) mock_comm = MockComm() mock_rdm = MockRDM() @@ -298,7 +304,7 @@ class TestHandleEndTrace: def test_handle_end_trace_sends_final_chunk_message(self, tmp_path: Any) -> None: """_handle_end_trace should send final_chunk=True message to RDM.""" - mock_config = MockConfigManager().path_to_store_record_from(tmp_path) + mock_config = MockConfigManager(path_to_store_record=str(tmp_path)) mock_comm = MockComm() mock_rdm = MockRDM() @@ -343,7 +349,7 @@ def test_handle_end_trace_uses_custom_1d_for_unknown_data_type( self, tmp_path: Any ) -> None: """_handle_end_trace should default to CUSTOM_1D for unknown data_type.""" - mock_config = MockConfigManager().path_to_store_record_from(tmp_path) + mock_config = MockConfigManager(path_to_store_record=str(tmp_path)) mock_comm = MockComm() mock_rdm = MockRDM() @@ -384,7 +390,7 @@ def test_handle_end_trace_uses_custom_1d_for_missing_metadata( self, tmp_path: Any ) -> None: """_handle_end_trace should default to CUSTOM_1D if no metadata exists.""" - mock_config = MockConfigManager().path_to_store_record_from(tmp_path) + mock_config = MockConfigManager(path_to_store_record=str(tmp_path)) mock_comm = MockComm() mock_rdm = MockRDM() @@ -420,7 +426,7 @@ def test_handle_end_trace_uses_custom_1d_for_missing_metadata( def test_handle_end_trace_removes_trace_after_sending(self, tmp_path: Any) -> None: """_handle_end_trace should remove trace from internal state after sending.""" - mock_config = MockConfigManager().path_to_store_record_from(tmp_path) + mock_config = MockConfigManager(path_to_store_record=str(tmp_path)) mock_comm = MockComm() mock_rdm = MockRDM() @@ -456,7 +462,7 @@ def test_handle_end_trace_removes_trace_after_sending(self, tmp_path: Any) -> No def test_handle_end_trace_skips_if_missing_trace_id(self, tmp_path: Any) -> None: """_handle_end_trace should skip if trace_id is missing.""" - mock_config = MockConfigManager().path_to_store_record_from(tmp_path) + mock_config = MockConfigManager(path_to_store_record=str(tmp_path)) mock_comm = MockComm() mock_rdm = MockRDM() @@ -487,7 +493,7 @@ def test_handle_end_trace_skips_if_missing_recording_id( self, tmp_path: Any ) -> None: """_handle_end_trace should skip if recording_id is missing.""" - mock_config = MockConfigManager().path_to_store_record_from(tmp_path) + mock_config = MockConfigManager(path_to_store_record=str(tmp_path)) mock_comm = MockComm() mock_rdm = MockRDM() @@ -554,8 +560,8 @@ class TestDrainChannelMessages: def test_drain_channel_messages_passes_data_type_to_on_complete( self, tmp_path: Any ) -> None: - """_drain_channel_messages passes data_type from reader to handler.""" - mock_config = MockConfigManager().path_to_store_record_from(tmp_path) + """_drain_channel_messages passes data_type to _on_complete_message.""" + mock_config = MockConfigManager(path_to_store_record=str(tmp_path)) mock_comm = MockComm() mock_rdm = MockRDM() @@ -599,7 +605,7 @@ def test_drain_channel_messages_handles_multi_chunk_message( self, tmp_path: Any ) -> None: """_drain_channel_messages should reassemble multi-chunk messages correctly.""" - mock_config = MockConfigManager().path_to_store_record_from(tmp_path) + mock_config = MockConfigManager(path_to_store_record=str(tmp_path)) mock_comm = MockComm() mock_rdm = MockRDM() @@ -677,7 +683,7 @@ def test_on_complete_message_handles_all_data_types( self, tmp_path: Any, data_type: DataType ) -> None: """_on_complete_message should handle all DataType values.""" - mock_config = MockConfigManager().path_to_store_record_from(tmp_path) + mock_config = MockConfigManager(path_to_store_record=str(tmp_path)) mock_comm = MockComm() mock_rdm = MockRDM() @@ -711,12 +717,16 @@ class TestExpiredChannelCleanup: """Tests for _cleanup_expired_channels() method.""" def test_cleanup_expired_channels_sends_final_chunk(self, tmp_path: Any) -> None: - """_cleanup_expired_channels sends final_chunk for expired channels.""" + """ + _cleanup_expired_channels sends final_chunk for expired channels. + + This covers the case where a channel has active traces. + """ from datetime import timedelta from neuracore.data_daemon.const import HEARTBEAT_TIMEOUT_SECS - mock_config = MockConfigManager().path_to_store_record_from(tmp_path) + mock_config = MockConfigManager(path_to_store_record=str(tmp_path)) mock_comm = MockComm() mock_rdm = MockRDM() @@ -764,7 +774,7 @@ def test_cleanup_expired_channels_no_trace_no_message(self, tmp_path: Any) -> No from neuracore.data_daemon.const import HEARTBEAT_TIMEOUT_SECS - mock_config = MockConfigManager().path_to_store_record_from(tmp_path) + mock_config = MockConfigManager(path_to_store_record=str(tmp_path)) mock_comm = MockComm() mock_rdm = MockRDM() @@ -797,8 +807,10 @@ def test_cleanup_expired_channels_no_trace_no_message(self, tmp_path: Any) -> No def test_cleanup_expired_channels_skips_active_channels( self, tmp_path: Any ) -> None: - """_cleanup_expired_channels skips channels with recent heartbeat.""" - mock_config = MockConfigManager().path_to_store_record_from(tmp_path) + """ + _cleanup_expired_channels should not remove channels with recent heartbeat. + """ + mock_config = MockConfigManager(path_to_store_record=str(tmp_path)) mock_comm = MockComm() mock_rdm = MockRDM() @@ -836,7 +848,7 @@ class TestRDMEnqueueErrorHandling: def test_on_complete_message_handles_enqueue_exception(self, tmp_path: Any) -> None: """_on_complete_message should catch and log exceptions from RDM.enqueue().""" - mock_config = MockConfigManager().path_to_store_record_from(tmp_path) + mock_config = MockConfigManager(path_to_store_record=str(tmp_path)) mock_comm = MockComm() # Create a mock RDM that raises an exception @@ -867,7 +879,7 @@ def enqueue(self, message: CompleteMessage) -> None: def test_daemon_continues_after_enqueue_failure(self, tmp_path: Any) -> None: """Daemon should continue processing after RDM.enqueue() failure.""" - mock_config = MockConfigManager().path_to_store_record_from(tmp_path) + mock_config = MockConfigManager(path_to_store_record=str(tmp_path)) mock_comm = MockComm() # Create a mock RDM that fails on first call, succeeds on second diff --git a/tests/unit/data_daemon/upload_manager/test_upload_manager.py b/tests/unit/data_daemon/upload_manager/test_upload_manager.py index 96cc161e..bbaca64f 100644 --- a/tests/unit/data_daemon/upload_manager/test_upload_manager.py +++ b/tests/unit/data_daemon/upload_manager/test_upload_manager.py @@ -13,7 +13,6 @@ import pytest from neuracore_types import DataType, RecordingDataTraceStatus -from neuracore.data_daemon.config_manager.daemon_config import DaemonConfig from neuracore.data_daemon.event_emitter import Emitter, emitter from neuracore.data_daemon.models import TraceErrorCode, TraceStatus from neuracore.data_daemon.upload_management.upload_manager import UploadManager @@ -43,8 +42,7 @@ def mock_auth(): @pytest.fixture def upload_manager() -> UploadManager: """Create and cleanup UploadManager instance.""" - config = DaemonConfig(num_threads=2) - manager = UploadManager(config=config) + manager = UploadManager(num_threads=2) try: yield manager finally: @@ -54,8 +52,7 @@ def upload_manager() -> UploadManager: @pytest.fixture def upload_manager_with_more_threads() -> UploadManager: """Create UploadManager with more threads for concurrent tests.""" - config = DaemonConfig(num_threads=4) - manager = UploadManager(config=config) + manager = UploadManager(num_threads=4) try: yield manager finally: @@ -80,13 +77,11 @@ def setup_test_env(mock_auth): def test_initialize_with_config() -> None: """Initialization with configuration.""" - config = DaemonConfig(num_threads=8) - manager = UploadManager(config=config) + manager = UploadManager(num_threads=8) try: assert manager._num_threads == 8 assert manager._executor is not None - assert manager._config == config finally: manager.shutdown(wait=False) @@ -594,8 +589,7 @@ def mock_upload(): def test_upload_manager_shutdown_waits_for_in_flight_uploads(test_file: Path) -> None: """Test UploadManager shutdown waits for in-flight uploads.""" - config = DaemonConfig(num_threads=2) - upload_manager = UploadManager(config=config) + upload_manager = UploadManager(num_threads=2) upload_completed = [] @@ -640,8 +634,7 @@ def slow_upload(): def test_upload_manager_shutdown_unsubscribes_from_events(test_file: Path) -> None: """Test UploadManager shutdown unsubscribes from READY_FOR_UPLOAD events.""" - config = DaemonConfig(num_threads=2) - upload_manager = UploadManager(config=config) + upload_manager = UploadManager(num_threads=2) upload_manager.shutdown()