Skip to content

Commit b97aa71

Browse files
finish initial valkey
1 parent c8e78dd commit b97aa71

File tree

4 files changed

+80
-81
lines changed

4 files changed

+80
-81
lines changed

reflex/state.py

Lines changed: 66 additions & 79 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@
4040
)
4141

4242
from glide import (
43+
OK,
4344
ConditionalChange,
4445
ExpirySet,
4546
ExpiryType,
@@ -77,7 +78,6 @@
7778
BaseModelV1 = BaseModelV2
7879

7980
import wrapt
80-
from redis.exceptions import ResponseError
8181

8282
import reflex.istate.dynamic
8383
from reflex import constants
@@ -100,6 +100,7 @@
100100
ImmutableStateError,
101101
InvalidStateManagerMode,
102102
LockExpiredError,
103+
RedisConfigError,
103104
ReflexRuntimeError,
104105
SetUndefinedStateVarError,
105106
StateSchemaMismatchError,
@@ -3217,29 +3218,33 @@ class StateManagerRedis(StateManager):
32173218
}
32183219

32193220
# This lock is used to ensure we only subscribe to keyspace events once per token and worker
3220-
# _pubsub_locks: Dict[bytes, asyncio.Lock] = pydantic.PrivateAttr({})
3221-
3222-
_pubsub_clients: Dict[bytes, GlideClient] = pydantic.PrivateAttr({})
3221+
_pubsub_locks: Dict[bytes, asyncio.Lock] = pydantic.PrivateAttr({})
32233222

32243223
async def get_redis(self) -> GlideClient:
32253224
"""Get the redis client.
32263225
32273226
Returns:
32283227
The redis client.
3228+
3229+
Raises:
3230+
RedisConfigError: If the redis client could not be configured.
32293231
"""
32303232
if self.redis is not None:
32313233
return self.redis
32323234
redis = await prerequisites.get_redis()
32333235
assert redis is not None
3234-
try:
3235-
_ = await redis.config_set(
3236-
{"notify-keyspace-events": self._redis_notify_keyspace_events},
3236+
config_result = await redis.config_set(
3237+
{"notify-keyspace-events": self._redis_notify_keyspace_events},
3238+
)
3239+
# Some redis servers only allow out-of-band configuration, so ignore errors here.
3240+
if (
3241+
config_result != OK
3242+
and not environment.REFLEX_IGNORE_REDIS_CONFIG_ERROR.get()
3243+
):
3244+
raise RedisConfigError(
3245+
f"Failed to set notify-keyspace-events: {config_result}"
32373246
)
3238-
# TODO: adjust exception for glide
3239-
except ResponseError:
3240-
# Some redis servers only allow out-of-band configuration, so ignore errors here.
3241-
if not environment.REFLEX_IGNORE_REDIS_CONFIG_ERROR.get():
3242-
raise
3247+
32433248
self.redis = redis
32443249
return redis
32453250

@@ -3407,6 +3412,7 @@ async def set_state(
34073412
"""
34083413
# Check that we're holding the lock.
34093414
redis = await self.get_redis()
3415+
34103416
if lock_id is not None and await redis.get(self._lock_key(token)) != lock_id:
34113417
raise LockExpiredError(
34123418
f"Lock expired for token {token} while processing. Consider increasing "
@@ -3440,9 +3446,15 @@ async def set_state(
34403446
_ = await redis.set(
34413447
_substate_key(client_token, state),
34423448
pickle_state,
3443-
expiry=self.expiry,
3444-
# ex=self.token_expiration,
3449+
expiry=ExpirySet(
3450+
expiry_type=ExpiryType.MILLSEC,
3451+
value=self.token_expiration,
3452+
),
34453453
)
3454+
# if str(res) != OK:
3455+
# raise RuntimeError(
3456+
# f"Failed to set state for token {token}. {res} {OK}"
3457+
# )
34463458

34473459
# Wait for substates to be persisted.
34483460
for t in tasks:
@@ -3478,18 +3490,6 @@ def _lock_key(token: str) -> bytes:
34783490
client_token = _split_substate_key(token)[0]
34793491
return f"{client_token}_lock".encode()
34803492

3481-
@property
3482-
def expiry(self) -> ExpirySet:
3483-
"""Get the expiry set for the token.
3484-
3485-
Returns:
3486-
The expiry set for the token.
3487-
"""
3488-
return ExpirySet(
3489-
expiry_type=ExpiryType.SEC,
3490-
value=self.token_expiration,
3491-
)
3492-
34933493
async def _try_get_lock(self, lock_key: bytes, lock_id: bytes) -> bool | None:
34943494
"""Try to get a redis lock for a token.
34953495
@@ -3504,10 +3504,13 @@ async def _try_get_lock(self, lock_key: bytes, lock_id: bytes) -> bool | None:
35043504
response = await redis.set(
35053505
lock_key,
35063506
lock_id,
3507-
expiry=self.expiry,
3507+
expiry=ExpirySet(
3508+
expiry_type=ExpiryType.MILLSEC,
3509+
value=self.lock_expiration,
3510+
),
35083511
conditional_set=ConditionalChange.ONLY_IF_DOES_NOT_EXIST,
35093512
)
3510-
return bool(response)
3513+
return str(response) == OK
35113514

35123515
async def get_pubsub(self, lock_key: bytes) -> GlideClient:
35133516
"""Get the pubsub client for a lock key channel.
@@ -3519,12 +3522,10 @@ async def get_pubsub(self, lock_key: bytes) -> GlideClient:
35193522
The pubsub client.
35203523
"""
35213524
lock_key_channel = f"__keyspace@0__:{lock_key.decode()}"
3522-
if lock_key_channel in self._pubsub_clients:
3523-
return self._pubsub_clients[lock_key_channel]
35243525
pubsub_config = GlideClientConfiguration.PubSubSubscriptions(
35253526
channels_and_patterns={
35263527
GlideClientConfiguration.PubSubChannelModes.Pattern: {lock_key_channel},
3527-
GlideClientConfiguration.PubSubChannelModes.Exact: {lock_key_channel},
3528+
# GlideClientConfiguration.PubSubChannelModes.Exact: {lock_key_channel},
35283529
},
35293530
callback=None,
35303531
context=None,
@@ -3534,7 +3535,6 @@ async def get_pubsub(self, lock_key: bytes) -> GlideClient:
35343535
)
35353536
assert config is not None
35363537
pubsub = await GlideClient.create(config)
3537-
self._pubsub_clients[lock_key] = pubsub
35383538
return pubsub
35393539

35403540
async def _wait_lock(self, lock_key: bytes, lock_id: bytes) -> None:
@@ -3545,58 +3545,48 @@ async def _wait_lock(self, lock_key: bytes, lock_id: bytes) -> None:
35453545
Args:
35463546
lock_key: The redis key for the lock.
35473547
lock_id: The ID of the lock.
3548-
3549-
Raises:
3550-
ResponseError: when the keyspace config cannot be set.
35513548
"""
35523549
state_is_locked = False
35533550
# Enable keyspace notifications for the lock key, so we know when it is available.
35543551
redis = await self.get_redis()
3555-
pubsub = await self.get_pubsub(lock_key)
3556-
# async with self.redis.pubsub() as pubsub:
3557-
# await pubsub.psubscribe(lock_key_channel)
3558-
# await pubsub.get_pubsub_message()
3559-
count = 0
3560-
while not state_is_locked:
3561-
count += 1
3562-
if count > 10000:
3563-
raise Exception("Could not obtain lock")
3564-
# wait for the lock to be released
3565-
print("waiting for lock to be released")
3566-
while True:
3567-
if not await redis.exists([lock_key]):
3568-
# if not pubsub.try_get_pubsub_message():
3569-
break # key was removed, try to get the lock again
3570-
message = await pubsub.get_pubsub_message(
3571-
# ignore_subscribe_messages=True,
3572-
# timeout=self.lock_expiration / 1000.0,
3573-
)
3574-
# if message.pattern is None:
3575-
# # raise Exception("Pattern is None")
3576-
# continue
3577-
# raise Exception(message)
3578-
if message.message in self._redis_keyspace_lock_release_events:
3579-
break
3580-
state_is_locked = await self._try_get_lock(lock_key, lock_id)
3552+
if lock_key not in self._pubsub_locks:
3553+
self._pubsub_locks[lock_key] = asyncio.Lock()
3554+
async with self._pubsub_locks[lock_key]:
3555+
pubsub = await self.get_pubsub(lock_key)
3556+
while not state_is_locked:
3557+
# wait for the lock to be released
3558+
while True:
3559+
# check if we missed lock release events
3560+
if await redis.exists([lock_key]) == 0:
3561+
break # key was removed, try to get the lock again
3562+
3563+
try:
3564+
# TODO: alternative to ignore_subscribe_messages?
3565+
message = await asyncio.wait_for(
3566+
pubsub.get_pubsub_message(),
3567+
timeout=self.lock_expiration / 1000.0,
3568+
)
3569+
except asyncio.TimeoutError:
3570+
continue
3571+
if message.message in self._redis_keyspace_lock_release_events:
3572+
break
3573+
state_is_locked = await self._try_get_lock(lock_key, lock_id)
35813574

35823575
@override
3583-
async def disconnect(self, token: str):
3576+
async def disconnect(self, token: str) -> None:
35843577
"""Disconnect the token from the redis client.
35853578
35863579
Args:
35873580
token: The token to disconnect.
35883581
"""
35893582
lock_key = self._lock_key(token)
3590-
# if lock := self._pubsub_locks.get(lock_key):
3591-
# if lock.locked():
3592-
# lock.release()
3593-
# del self._pubsub_locks[lock_key]
3594-
if client := self._pubsub_clients.get(self._lock_key(token)):
3595-
await client.close()
3596-
del self._pubsub_clients[lock_key]
3583+
if lock := self._pubsub_locks.get(lock_key):
3584+
if lock.locked():
3585+
lock.release()
3586+
del self._pubsub_locks[lock_key]
35973587

35983588
@contextlib.asynccontextmanager
3599-
async def _lock(self, token: str):
3589+
async def _lock(self, token: str) -> AsyncIterator[bytes]:
36003590
"""Obtain a redis lock for a token.
36013591
36023592
Args:
@@ -3626,8 +3616,10 @@ async def _lock(self, token: str):
36263616
# only delete our lock
36273617
redis = await self.get_redis()
36283618
_ = await redis.delete([lock_key])
3619+
# if not res:
3620+
# raise RuntimeError(f"Failed to release lock for token {token}")
36293621

3630-
async def close(self):
3622+
async def close(self) -> None:
36313623
"""Explicitly close the redis connection and connection_pool.
36323624
36333625
It is necessary in testing scenarios to close between asyncio test cases
@@ -3636,14 +3628,9 @@ async def close(self):
36363628
36373629
Note: Connections will be automatically reopened when needed.
36383630
"""
3639-
# await self.redis.aclose(close_connection_pool=True)
3640-
# TODO: is this needed with glide?
3641-
redis = await self.get_redis()
3642-
await redis.close()
3643-
3644-
for pubsub in self._pubsub_clients.values():
3645-
await pubsub.close()
3646-
self._pubsub_clients = {}
3631+
if self.redis is not None:
3632+
await self.redis.close()
3633+
self.redis = None
36473634

36483635

36493636
def get_state_manager() -> StateManager:

reflex/utils/exceptions.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,10 @@ class LockExpiredError(ReflexError):
8787
"""Raised when the state lock expires while an event is being processed."""
8888

8989

90+
class RedisConfigError(ReflexError):
91+
"""Raised when the Redis configuration is not applied correctly."""
92+
93+
9094
class MatchTypeError(ReflexError, TypeError):
9195
"""Raised when the return types of match cases are different."""
9296

tests/integration/test_background_task.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ async def handle_event_yield_only(self):
4444

4545
@rx.event(background=True)
4646
async def fast_yielding(self):
47-
for _ in range(100000):
47+
for _ in range(1000):
4848
yield State.increment()
4949

5050
@rx.event
@@ -409,4 +409,4 @@ def test_fast_yielding(
409409
assert background_task._poll_for(lambda: counter.text == "0", timeout=5)
410410

411411
fast_yielding_button.click()
412-
assert background_task._poll_for(lambda: counter.text == "100000", timeout=1200)
412+
assert background_task._poll_for(lambda: counter.text == "1000", timeout=20)

tests/units/test_state.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1928,6 +1928,14 @@ async def test_state_proxy(grandchild_state: GrandchildState, mock_app: rx.App):
19281928
# Cannot access substates
19291929
sp.substates[""]
19301930

1931+
assert (
1932+
sp.router.session.client_token == grandchild_state.router.session.client_token
1933+
)
1934+
assert (
1935+
sp.__wrapped__.router.session.client_token
1936+
== grandchild_state.router.session.client_token
1937+
)
1938+
assert sp.router.session.client_token is not None
19311939
async with sp:
19321940
assert sp._self_actx is not None
19331941
assert sp._self_mutable # proxy is mutable inside context

0 commit comments

Comments
 (0)