diff --git a/streaming/base/constant.py b/streaming/base/constant.py
index f26f638fb..12ac5eeb3 100644
--- a/streaming/base/constant.py
+++ b/streaming/base/constant.py
@@ -34,3 +34,6 @@
 
 # Default download timeout
 DEFAULT_TIMEOUT = 60.0
+
+# Upper bound (exclusive) on the prefix integers tried when looking for a shared memory prefix
+MAX_PREFIX_INT = 1000
diff --git a/streaming/base/shared/prefix.py b/streaming/base/shared/prefix.py
index b25743df2..b6b03d865 100644
--- a/streaming/base/shared/prefix.py
+++ b/streaming/base/shared/prefix.py
@@ -16,7 +16,8 @@
 import numpy as np
 from torch import distributed as dist
 
-from streaming.base.constant import BARRIER_FILELOCK, CACHE_FILELOCK, LOCALS, SHM_TO_CLEAN, TICK
+from streaming.base.constant import (BARRIER_FILELOCK, CACHE_FILELOCK, LOCALS, MAX_PREFIX_INT,
+                                     SHM_TO_CLEAN, TICK)
 from streaming.base.shared import SharedMemory
 from streaming.base.world import World
 
@@ -113,6 +114,13 @@ def _check_and_find(streams_local: list[str], streams_remote: list[Union[str, No
 
     for prefix_int in _each_prefix_int():
 
+        # Fail loudly once the cap is reached: if the filelock checks report every prefix as
+        # taken (e.g. mocked os.path.exists / os.stat, or stale files in TMPDIR), stop probing.
+        if prefix_int >= MAX_PREFIX_INT:
+            raise ValueError(f'prefix_int exceeds {MAX_PREFIX_INT}. This may happen ' +
+                             'when you mock os.path.exists or os.stat functions so the filelock ' +
+                             'checks always return `True`, or when you need to clean up TMPDIR.')
+
         name = _get_path(prefix_int, shm_name)
 
         # Check if any shared memory filelocks exist for the current prefix
diff --git a/tests/test_shared.py b/tests/test_shared.py
index d1914617d..d73c6c811 100644
--- a/tests/test_shared.py
+++ b/tests/test_shared.py
@@ -190,3 +190,14 @@ def test_shared_memory_permission_error(mock_shared_memory_class: MagicMock):
     with patch('os.path.exists', return_value=False):
         next_prefix = _check_and_find(['local'], [None], LOCALS)
         assert next_prefix == 1
+
+
+@pytest.mark.usefixtures('local_remote_dir')
+def test_shared_memory_infinity_exception(local_remote_dir: tuple[str, str]):
+    """get_shm_prefix must raise ValueError when os.path.exists always returns True."""
+    local, remote = local_remote_dir
+    with patch('os.path.exists', return_value=True):
+        with pytest.raises(ValueError, match='prefix_int exceeds .*clean up TMPDIR.'):
+            _, _ = get_shm_prefix(streams_local=[local],
+                                  streams_remote=[remote],
+                                  world=World.detect())