From e436f54d52fa91603630e0e5bbe5b0f127959a21 Mon Sep 17 00:00:00 2001 From: "ivan-nedd@mail.ru" Date: Mon, 9 Sep 2024 21:25:38 +0300 Subject: [PATCH 1/3] add flag --- asyncio_redis_rate_limit/__init__.py | 6 ++++++ asyncio_redis_rate_limit/compat.py | 5 +++++ tests/test_examples.py | 27 +++++++++++++++++++++++++++ 3 files changed, 38 insertions(+) diff --git a/asyncio_redis_rate_limit/__init__.py b/asyncio_redis_rate_limit/__init__.py index fd0157d..91b36e8 100644 --- a/asyncio_redis_rate_limit/__init__.py +++ b/asyncio_redis_rate_limit/__init__.py @@ -49,6 +49,7 @@ class RateLimiter: '_backend', '_cache_prefix', '_lock', + '_use_nx_on_expire' ) def __init__( @@ -58,6 +59,7 @@ def __init__( backend: AnyRedis, *, cache_prefix: str, + use_nx_on_expire: bool = True, ) -> None: """In the future other backends might be supported as well.""" self._unique_key = unique_key @@ -65,6 +67,7 @@ def __init__( self._backend = backend self._cache_prefix = cache_prefix self._lock = asyncio.Lock() + self._use_nx_on_expire = use_nx_on_expire async def __aenter__(self: _RateLimiterT) -> _RateLimiterT: """ @@ -110,6 +113,7 @@ async def _run_pipeline( pipeline.incr(cache_key), cache_key, self._rate_spec.seconds, + use_nx=self._use_nx_on_expire, ).execute() return current_rate # type: ignore[no-any-return] @@ -130,6 +134,7 @@ def rate_limit( # noqa: WPS320 backend: AnyRedis, *, cache_prefix: str = 'aio-rate-limit', + use_nx_on_expire: bool = True, ) -> Callable[ [_CoroutineFunction[_ParamsT, _ResultT]], _CoroutineFunction[_ParamsT, _ResultT], @@ -167,6 +172,7 @@ async def factory( backend=backend, rate_spec=rate_spec, cache_prefix=cache_prefix, + use_nx_on_expire=use_nx_on_expire, ): return await function(*args, **kwargs) return factory diff --git a/asyncio_redis_rate_limit/compat.py b/asyncio_redis_rate_limit/compat.py index fdedb2d..69f22f9 100644 --- a/asyncio_redis_rate_limit/compat.py +++ b/asyncio_redis_rate_limit/compat.py @@ -44,8 +44,13 @@ def pipeline_expire( pipeline: Any, cache_key: str, seconds: int, + *, + use_nx: bool = True, ) -> AnyPipeline: """Compatibility mode for `.expire(..., nx=True)` command.""" + if not use_nx: + return pipeline.expire(cache_key, seconds) # type: ignore + if isinstance(pipeline, _AsyncPipeline): return pipeline.expire(cache_key, seconds, nx=True) # type: ignore # `aioredis` somehow does not have this boolean argument in `.expire`, diff --git a/tests/test_examples.py b/tests/test_examples.py index ad6959f..dc5d0a1 100644 --- a/tests/test_examples.py +++ b/tests/test_examples.py @@ -25,6 +25,8 @@ def __call__( self, requests: int = ..., seconds: int = ..., + *, + use_nx_on_expire: bool = ..., ) -> _LimitedSig: """We use this callback to construct `limited` test function.""" @@ -246,6 +248,31 @@ async def test_ten_reqs_in_two_secs2( await asyncio.sleep(1 + 0.5) await function() +@pytest.mark.repeat(5) +async def test_ten_reqs_in_two_secs_without_nx( + limited: _LimitedCallback, +) -> None: + """Ensure that several gathered coroutines do respect the rate limit.""" + function = limited(requests=10, seconds=2, use_nx_on_expire=False) + + # Or just consume all: + for attempt in range(10): + await function(attempt) + + # This one will fail: + with pytest.raises(RateLimitError): + await function() + + # Now, let's move time to the next second: + await asyncio.sleep(1) + + # This one will also fail: + with pytest.raises(RateLimitError): + await function() + + # Next attempts will pass: + await asyncio.sleep(1 + 0.5) + await function() class _Counter: def __init__(self) -> None: From 718e6e200646354ad62932c05078f1a8cb0916d1 Mon Sep 17 00:00:00 2001 From: "ivan-nedd@mail.ru" Date: Sun, 15 Sep 2024 16:43:31 +0300 Subject: [PATCH 2/3] remove use_nx and compat expire Add expire on first set --- asyncio_redis_rate_limit/__init__.py | 20 ++++---------------- asyncio_redis_rate_limit/compat.py | 23 ----------------------- tests/test_examples.py | 27 --------------------------- 3 files changed, 4 insertions(+), 66 deletions(-) diff --git a/asyncio_redis_rate_limit/__init__.py b/asyncio_redis_rate_limit/__init__.py index 91b36e8..4110a2f 100644 --- a/asyncio_redis_rate_limit/__init__.py +++ b/asyncio_redis_rate_limit/__init__.py @@ -6,11 +6,7 @@ from typing_extensions import ParamSpec, TypeAlias, final -from asyncio_redis_rate_limit.compat import ( - AnyPipeline, - AnyRedis, - pipeline_expire, -) +from asyncio_redis_rate_limit.compat import AnyPipeline, AnyRedis #: These aliases makes our code more readable. _Seconds: TypeAlias = int @@ -49,7 +45,6 @@ class RateLimiter: '_backend', '_cache_prefix', '_lock', - '_use_nx_on_expire' ) def __init__( @@ -59,7 +54,6 @@ def __init__( backend: AnyRedis, *, cache_prefix: str, - use_nx_on_expire: bool = True, ) -> None: """In the future other backends might be supported as well.""" self._unique_key = unique_key @@ -67,7 +61,6 @@ def __init__( self._backend = backend self._cache_prefix = cache_prefix self._lock = asyncio.Lock() - self._use_nx_on_expire = use_nx_on_expire async def __aenter__(self: _RateLimiterT) -> _RateLimiterT: """ @@ -109,12 +102,9 @@ async def _run_pipeline( pipeline: AnyPipeline, ) -> int: # https://redis.io/commands/incr/#pattern-rate-limiter-1 - current_rate, _ = await pipeline_expire( - pipeline.incr(cache_key), - cache_key, - self._rate_spec.seconds, - use_nx=self._use_nx_on_expire, - ).execute() + _, current_rate = await pipeline.set( # type: ignore[union-attr] + cache_key, 0, nx=True, ex=self._rate_spec.seconds, + ).incr(cache_key).execute() return current_rate # type: ignore[no-any-return] def _make_cache_key( @@ -134,7 +124,6 @@ def rate_limit( # noqa: WPS320 backend: AnyRedis, *, cache_prefix: str = 'aio-rate-limit', - use_nx_on_expire: bool = True, ) -> Callable[ [_CoroutineFunction[_ParamsT, _ResultT]], _CoroutineFunction[_ParamsT, _ResultT], @@ -172,7 +161,6 @@ async def factory( backend=backend, rate_spec=rate_spec, cache_prefix=cache_prefix, - use_nx_on_expire=use_nx_on_expire, ): return await function(*args, **kwargs) return factory diff --git a/asyncio_redis_rate_limit/compat.py b/asyncio_redis_rate_limit/compat.py index 69f22f9..eacfc31 100644 --- a/asyncio_redis_rate_limit/compat.py +++ b/asyncio_redis_rate_limit/compat.py @@ -38,26 +38,3 @@ class _AIORedis: # type: ignore # noqa: WPS306, WPS440 AnyPipeline: TypeAlias = Union['_AsyncPipeline[Any]', _AIOPipeline] AnyRedis: TypeAlias = Union['_AsyncRedis[Any]', _AIORedis] - - -def pipeline_expire( - pipeline: Any, - cache_key: str, - seconds: int, - *, - use_nx: bool = True, -) -> AnyPipeline: - """Compatibility mode for `.expire(..., nx=True)` command.""" - if not use_nx: - return pipeline.expire(cache_key, seconds) # type: ignore - - if isinstance(pipeline, _AsyncPipeline): - return pipeline.expire(cache_key, seconds, nx=True) # type: ignore - # `aioredis` somehow does not have this boolean argument in `.expire`, - # so, we use `EXPIRE` directly with `NX` flag. - return pipeline.execute_command( # type: ignore - 'EXPIRE', - cache_key, - seconds, - 'NX', - ) diff --git a/tests/test_examples.py b/tests/test_examples.py index dc5d0a1..ad6959f 100644 --- a/tests/test_examples.py +++ b/tests/test_examples.py @@ -25,8 +25,6 @@ def __call__( self, requests: int = ..., seconds: int = ..., - *, - use_nx_on_expire: bool = ..., ) -> _LimitedSig: """We use this callback to construct `limited` test function.""" @@ -248,31 +246,6 @@ async def test_ten_reqs_in_two_secs2( await asyncio.sleep(1 + 0.5) await function() -@pytest.mark.repeat(5) -async def test_ten_reqs_in_two_secs_without_nx( - limited: _LimitedCallback, -) -> None: - """Ensure that several gathered coroutines do respect the rate limit.""" - function = limited(requests=10, seconds=2, use_nx_on_expire=False) - - # Or just consume all: - for attempt in range(10): - await function(attempt) - - # This one will fail: - with pytest.raises(RateLimitError): - await function() - - # Now, let's move time to the next second: - await asyncio.sleep(1) - - # This one will also fail: - with pytest.raises(RateLimitError): - await function() - - # Next attempts will pass: - await asyncio.sleep(1 + 0.5) - await function() class _Counter: def __init__(self) -> None: From 1af59594412f9e546cb6bc317f87ebcc17804df0 Mon Sep 17 00:00:00 2001 From: "ivan-nedd@mail.ru" Date: Sun, 15 Sep 2024 16:43:51 +0300 Subject: [PATCH 3/3] add keydb to ci --- .github/workflows/test.yml | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 0b61fc3..19b1008 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -21,7 +21,7 @@ jobs: fail-fast: false matrix: python-version: ['3.9', '3.10', '3.11', '3.12'] - redis-image: ['redis:7.0-alpine'] + redis-image: ['redis:7.0-alpine', 'eqalpha/keydb:alpine'] env-type: ['redis'] include: @@ -31,6 +31,13 @@ jobs: - python-version: '3.9' env-type: 'dev' redis-image: 'redis:7.0-alpine' + - python-version: '3.10' + env-type: 'aioredis' + redis-image: 'eqalpha/keydb:alpine' + - python-version: '3.9' + env-type: 'dev' + redis-image: 'eqalpha/keydb:alpine' + steps: - uses: actions/checkout@v4