diff --git a/streaming/base/dataset.py b/streaming/base/dataset.py
index d463d6f0e..914f0215e 100644
--- a/streaming/base/dataset.py
+++ b/streaming/base/dataset.py
@@ -1120,7 +1120,7 @@ def evict_shard(self, shard_id: int) -> None:
         """
         # Lock the cache. FileLocks contain threading Locks, which are not pickleable, which is
         # incompatible with spawn, so must be created lazily.
-        if not hasattr(self, CACHE_FILELOCK):
+        if not hasattr(self, '_cache_filelock'):
             self._cache_filelock = FileLock(self._cache_filelock_path)
 
         with self._cache_filelock:
@@ -1133,7 +1133,7 @@ def evict_coldest_shard(self) -> None:
         """
         # Lock the cache. FileLocks contain threading Locks, which are not pickleable, which is
         # incompatible with spawn, so must be created lazily.
-        if not hasattr(self, CACHE_FILELOCK):
+        if not hasattr(self, '_cache_filelock'):
             self._cache_filelock = FileLock(self._cache_filelock_path)
 
         with self._cache_filelock:
@@ -1154,7 +1154,7 @@ def prepare_shard(self, shard_id: int, blocking: bool = True) -> None:
         """
         # Lock the cache. FileLocks contain threading Locks, which are not pickleable, which is
         # incompatible with spawn, so must be created lazily.
-        if not hasattr(self, CACHE_FILELOCK):
+        if not hasattr(self, '_cache_filelock'):
             self._cache_filelock = FileLock(self._cache_filelock_path)
         lock = self._cache_filelock
         lock.acquire()
diff --git a/tests/test_eviction.py b/tests/test_eviction.py
index e30e84583..229b36b69 100644
--- a/tests/test_eviction.py
+++ b/tests/test_eviction.py
@@ -231,3 +231,61 @@ def test_cache_limit_lower_than_few_shards(local_remote_dir: Any, cache_limit: s
                             shuffle=False,
                             batch_size=4,
                             cache_limit=cache_limit)
+
+
+@pytest.mark.usefixtures('local_remote_dir')
+def test_cache_filelock_reuse(local_remote_dir: tuple[str, str]):
+    """Test that _cache_filelock is reused across multiple calls instead of being recreated.
+
+    This test verifies the fix for issue #963 where the CACHE_FILELOCK constant had a mismatch
+    with the actual attribute name, causing a new FileLock to be created on every call.
+    """
+    num_samples = 1000
+    local, remote = local_remote_dir
+    columns = {'data': 'bytes'}
+    compression = None
+    hashes = None
+    size_limit = 500
+
+    # Create a small dataset
+    with MDSWriter(out=remote,
+                   columns=columns,
+                   compression=compression,
+                   hashes=hashes,
+                   size_limit=size_limit) as out:
+        for _ in range(num_samples):
+            sample = {'data': b'\0'}
+            out.write(sample)
+
+    dataset = StreamingDataset(remote=remote, local=local, batch_size=1)
+
+    # First call to prepare_shard should create the filelock
+    dataset.prepare_shard(0)
+    assert hasattr(dataset, '_cache_filelock'), 'Expected _cache_filelock to be created'
+
+    # Store reference to the filelock object
+    first_filelock = dataset._cache_filelock
+
+    # Second call to prepare_shard should reuse the same filelock
+    if dataset.num_shards > 1:
+        dataset.prepare_shard(1)
+        second_filelock = dataset._cache_filelock
+        assert first_filelock is second_filelock, \
+            'Expected _cache_filelock to be reused, not recreated'
+
+    # Call evict_coldest_shard and verify filelock is still the same
+    if dataset.num_shards > 0:
+        dataset.evict_coldest_shard()
+        third_filelock = dataset._cache_filelock
+        assert first_filelock is third_filelock, \
+            'Expected _cache_filelock to be reused in evict_coldest_shard'
+
+    # Call evict_shard and verify filelock is still the same
+    if dataset.num_shards > 1:
+        dataset.prepare_shard(1)  # Prepare shard 1 again
+        dataset.evict_shard(1)
+        fourth_filelock = dataset._cache_filelock
+        assert first_filelock is fourth_filelock, \
+            'Expected _cache_filelock to be reused in evict_shard'
+
+    rmtree(local, ignore_errors=False)