diff --git a/CHANGELOG.md b/CHANGELOG.md index e8ae7a8..cc9039a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,18 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [Unreleased] +## [0.3.4] - 2025-11-07 + +### Fixed +- Fixed call tracking bug where `try_acquire()` and `async_try_acquire()` did not increment `call_count` when `track_calls=True` +- Fixed call tracking bug where `acquire()` did not increment `call_count` when called directly (outside context manager) with `track_calls=True` +- Fixed inefficient double-locking in `_record_call()` implementation +- All acquisition methods now consistently track calls when tracking is enabled + +### Changed +- Refactored call tracking to be implemented at the internal method level (`_try_consume_one_token_sync`, `_try_acquire_sync`) for cleaner architecture +- Context managers (`__enter__`, `__aenter__`) now delegate tracking to underlying acquisition methods + ## [0.3.3] - 2025-11-07 ### Added diff --git a/src/easylimit/__init__.py b/src/easylimit/__init__.py index 9e925c0..7ae1487 100644 --- a/src/easylimit/__init__.py +++ b/src/easylimit/__init__.py @@ -7,5 +7,5 @@ from .rate_limiter import CallStats, RateLimiter -__version__ = "0.3.3" +__version__ = "0.3.4" __all__ = ["RateLimiter", "CallStats"] diff --git a/src/easylimit/rate_limiter.py b/src/easylimit/rate_limiter.py index 98f0b3a..3736ff1 100644 --- a/src/easylimit/rate_limiter.py +++ b/src/easylimit/rate_limiter.py @@ -252,6 +252,8 @@ def try_acquire(self) -> bool: self._refill_tokens() if self.tokens >= 1: self.tokens -= 1 + if self._track_calls: + self._record_call(0.0) return True return False @@ -398,6 +400,8 @@ def _try_consume_one_token_sync(self, start_time: float, timeout: Optional[float self._refill_tokens() if self.tokens >= 1: self.tokens -= 1 + if self._track_calls: + self._record_call(time.time() - start_time) return True, 0.0, False if timeout is not None and (time.time() - start_time) >= timeout: return False, 0.0, True @@ -410,19 +414,25 @@ def _try_acquire_sync(self) -> bool: self._refill_tokens() if self.tokens >= 1: self.tokens -= 1 + if self._track_calls: + self._record_call(0.0) return True return False def _record_call(self, delay: float) -> None: - """Record tracking info under sync lock.""" - with self.lock: - self._call_count += 1 - now_ts = time.time() - self._timestamps.append(now_ts) - self._delays.append(delay) - self._last_call_time = datetime.now() - cutoff_time = now_ts - self._history_window - self._timestamps = [ts for ts in self._timestamps if ts >= cutoff_time] + """ + Record tracking info (caller must hold self.lock). + + Args: + delay: Time spent waiting for token acquisition + """ + self._call_count += 1 + now_ts = time.time() + self._timestamps.append(now_ts) + self._delays.append(delay) + self._last_call_time = datetime.now() + cutoff_time = now_ts - self._history_window + self._timestamps = [ts for ts in self._timestamps if ts >= cutoff_time] async def async_acquire(self, timeout: Optional[float] = None) -> bool: """ @@ -438,9 +448,6 @@ async def async_acquire(self, timeout: Optional[float] = None) -> bool: while True: acquired, sleep_time, timed_out = await _to_thread(self._try_consume_one_token_sync, start_time, timeout) if acquired: - if self._track_calls: - delay = time.time() - start_time - await _to_thread(self._record_call, delay) return True if timed_out: return False diff --git a/tests/test_async_rate_limiter.py b/tests/test_async_rate_limiter.py index 13095d2..ad8b3dc 100644 --- a/tests/test_async_rate_limiter.py +++ b/tests/test_async_rate_limiter.py @@ -77,6 +77,25 @@ async def test_async_call_tracking(self) -> None: assert stats.total_calls == 3 assert stats.average_delay_seconds >= 0.0 + async def test_async_acquire_records_tracking(self) -> None: + """Direct async_acquire() should increment the tracked call count.""" + limiter = RateLimiter(limit=2, track_calls=True) + + assert limiter.call_count == 0 + assert await limiter.async_acquire() is True + assert limiter.call_count == 1 + + async def test_async_try_acquire_records_tracking(self) -> None: + """async_try_acquire() should only count successful acquisitions.""" + limiter = RateLimiter(limit=1, track_calls=True) + + assert await limiter.async_try_acquire() is True + assert limiter.call_count == 1 + + # Subsequent call has no tokens available yet + assert await limiter.async_try_acquire() is False + assert limiter.call_count == 1 + class TestMixedSyncAsync: """Test mixed sync and async usage to ensure unified locking works.""" diff --git a/tests/test_call_tracking.py b/tests/test_call_tracking.py index 0535f3b..561e82b 100644 --- a/tests/test_call_tracking.py +++ b/tests/test_call_tracking.py @@ -85,6 +85,28 @@ def worker() -> None: assert limiter.call_count == 15 + def test_call_count_increments_for_acquire(self) -> None: + """Call tracking should include direct acquire() usage.""" + limiter = RateLimiter(limit=5, track_calls=True) + + assert limiter.call_count == 0 + assert limiter.acquire() is True + assert limiter.call_count == 1 + + def test_call_count_increments_for_try_acquire(self) -> None: + """Call tracking should include try_acquire() successes only.""" + limiter = RateLimiter(limit=2, track_calls=True) + + assert limiter.try_acquire() is True + assert limiter.call_count == 1 + + assert limiter.try_acquire() is True + assert limiter.call_count == 2 + + # Bucket is empty now; failure should not increment the counter + assert limiter.try_acquire() is False + assert limiter.call_count == 2 + def test_reset_call_count(self) -> None: """Test resetting call count.""" limiter = RateLimiter(limit=5, track_calls=True)