Skip to content
This repository was archived by the owner on Aug 21, 2025. It is now read-only.

Commit 513ae3a

Browse files
authored
feat: improve typing definitions, introduce mypy in CI (#338)
* feat: do not supress callback exceptions raised from calling callbacks * fix: add explicit clauses for ConnectionClosedOK in send and heartbeat * fix: change import order * fix: apply ruff format * feat: introduce flake.nix leveraging pyproject.nix introduce flake.nix configuration leveraging the already existing information in pyproject.toml, by using pyproject.nix to parse it and generate the correct derivation. this way, any updates to pyproject.toml should automatically be reflected in the flake configuration, with low additional maintaining overhead required. the only time when this may come as a problem is due to dependabot constantly upgrading the packages to versions that do not exist yet in nixpkgs. it may be required to force dependabot to also update the nixpkgs version declared in the flake for it to not also break (or at least to upgrade less often) * chore: make `flake.nix` aware of dev dependencies, including ruff and pre-commit by adding `dependency-groups`, `pyproject.nix` can now parse the dependency groups and add these to the project, by passing `groups = ["dev"]` when calling the renderer. special care was needed in order to expose `ruff` and `pre-commit` to the toplevel of the environment. by default, it will treat them as python dependencies and hide them in the python wrapper. by using `groupBy`, we split the dependency list in "python" dependencies and top level ones, such that they can be exposed. sadly, these toplevel packages are not exposed in `python.pkgs` so we must manually override them in a case by case basis. * chore: refactor `dependencies-for` * feat: add nix develop CI job ensure that `nix develop` works on CI * fix: change job name to be more compatible with the other actions * feat: setup infra and run `python -m pytest` to run tests in nix setup action do not rely on poetry to run tests, we'd like to not rely on it from within the nix environment * fix: do not rely on `npx` as its not on nix develop use manual `supabase start` command instead * fix: add `--command` typo * chore: change name of nix develop command part * feat: add coverage information and upload it to coveralls * chore: run nix setup tests in both ubuntu and macos * chore: undo macos latest as it apparently does not work for some reason, `supabase start` does not seem to work on `macos-latest` host in CI * chore: switch from pylsp + mypy to pyright (basepyright) it seems that pyright is better and faster than mypy in most if not all cases, and the only reason I wasn't using it before is because it was not working with eglot's default config. * feat: add `basedpyright` to `pyproject.toml` instead of flake.nix only this makes it available to all users, preparing for including it into the standard CI later on * fix: change back to python-lsp-server with pylsp-mypy * chore: improve type definitions files to make mypy happy there's still a lot of issues that I intend to improve upon. this is just an initial set of changes to ensure that `mypy` doesn't complain. we should further strengthen the mypy checking rules (currently there are a lot of `Any`s in the code) as the code improves. * chore: add mypy to CI as a step * fix: run mypy through poetry * fix: check for phx_ref_prev before calling del * fix: use `python -m mypy` instead of `mypy` directly * fix: run type check after `make run_tests` so that `poetry install` is ran * fix: `StrEnum` does not exist in python 3.9 * fix: remove `phx_ref` from presence dict * fix: add deprecation warning for calling send with dicts * fix: fix config payload, improve more types * fix: make typing definitions compatible with python 3.9 also enforce that mypy uses python 3.9 only rules. * fix: import annotations for 3.9 to not complain about type error * fix: move mypy to Makefile `run_tests`, run it before actual tests * format: apply ruff reformating * fix: try explicitly annotating `Callback` as a `TypeAlias` * fix: change `dict` to `Dict` type due to 3.9 * fix: finally, import annotations to stop runtime from breaking * format: apply ruff format one last time * format: reorder import order in test_connection * format: trim whitespace
1 parent a8b063f commit 513ae3a

File tree

12 files changed

+339
-158
lines changed

12 files changed

+339
-158
lines changed

.github/workflows/ci.yml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,8 @@ jobs:
5353
github_access_token: ${{ secrets.GITHUB_TOKEN }}
5454
- name: Clone Repository
5555
uses: actions/checkout@v4
56+
- name: Type check
57+
run: nix develop --command mypy ./realtime
5658
- name: Start Supabase local development setup
5759
run: nix develop --command supabase start --workdir infra -x studio,mailpit,edge-runtime,logflare,vector,supavisor,imgproxy,storage-api
5860
- name: Run python tests through nix

Makefile

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,14 @@ install_poetry:
55
curl -sSL https://install.python-poetry.org | python -
66
poetry install
77

8-
tests: install tests_only tests_pre_commit
8+
tests: install run_mypy tests_only tests_pre_commit
99

1010
tests_pre_commit:
1111
poetry run pre-commit run --all-files
1212

13+
run_mypy:
14+
poetry run mypy ./realtime
15+
1316
run_infra:
1417
npx supabase start --workdir infra -x studio,mailpit,edge-runtime,logflare,vector,supavisor,imgproxy,storage-api
1518

poetry.lock

Lines changed: 162 additions & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pyproject.toml

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ requires-python = ">=3.9"
1313
dependencies = [
1414
"websockets >=11,<16",
1515
"typing-extensions >=4.14.0",
16+
"pydantic (>=2.11.7,<3.0.0)",
1617
]
1718

1819
[tool.poetry.group.dev.dependencies]
@@ -77,3 +78,13 @@ keep-runtime-typing = true
7778
[tool.pytest.ini_options]
7879
asyncio_mode = "strict"
7980
asyncio_default_fixture_loop_scope = "function"
81+
82+
[tool.mypy]
83+
python_version = "3.9"
84+
check_untyped_defs = true
85+
allow_redefinition = true
86+
87+
warn_return_any = true
88+
warn_unused_configs = true
89+
warn_redundant_casts = true
90+
warn_unused_ignores = true

realtime/_async/channel.py

Lines changed: 49 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import asyncio
44
import json
55
import logging
6-
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional
6+
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Mapping, Optional
77

88
from realtime.types import (
99
Binding,
@@ -16,6 +16,7 @@
1616
RealtimeSubscribeStates,
1717
)
1818

19+
from ..message import Message
1920
from ..transformers import http_endpoint_url
2021
from .presence import (
2122
AsyncRealtimePresence,
@@ -52,23 +53,27 @@ def __init__(
5253
:param params: Optional parameters for connection.
5354
"""
5455
self.socket = socket
55-
self.params = params or {}
56-
if self.params.get("config") is None:
57-
self.params["config"] = {
58-
"broadcast": {"ack": False, "self": False},
59-
"presence": {"key": ""},
60-
"private": False,
56+
self.params: RealtimeChannelOptions = (
57+
params
58+
if params
59+
else {
60+
"config": {
61+
"broadcast": {"ack": False, "self": False},
62+
"presence": {"key": ""},
63+
"private": False,
64+
}
6165
}
66+
)
6267

6368
self.topic = topic
6469
self._joined_once = False
65-
self.bindings: Dict[str, List[Binding]] = {}
70+
self.bindings: dict[str, list[Binding]] = {}
6671
self.presence = AsyncRealtimePresence(self)
6772
self.state = ChannelStates.CLOSED
68-
self._push_buffer: List[AsyncPush] = []
73+
self._push_buffer: list[AsyncPush] = []
6974
self.timeout = self.socket.timeout
7075

71-
self.join_push = AsyncPush(self, ChannelEvents.join, self.params)
76+
self.join_push: AsyncPush = AsyncPush(self, ChannelEvents.join, self.params)
7277
self.rejoin_timer = AsyncTimer(
7378
self._rejoin_until_connected, lambda tries: 2**tries
7479
)
@@ -111,8 +116,9 @@ def on_error(payload, *args):
111116
self._on("close", on_close)
112117
self._on("error", on_error)
113118

114-
def on_reply(payload, ref):
115-
self._trigger(self._reply_event_name(ref), payload)
119+
def on_reply(payload: Dict[str, Any], ref: Optional[str]):
120+
if ref:
121+
self._trigger(self._reply_event_name(ref), payload)
116122

117123
self._on(ChannelEvents.reply, on_reply)
118124

@@ -169,22 +175,24 @@ async def subscribe(
169175
presence = config.get("presence", {})
170176
private = config.get("private", False)
171177

172-
access_token_payload = {}
173-
config = {
174-
"broadcast": broadcast,
175-
"presence": presence,
176-
"private": private,
177-
"postgres_changes": list(
178-
map(lambda x: x.filter, self.bindings.get("postgres_changes", []))
179-
),
178+
config_payload: Dict[str, Any] = {
179+
"config": {
180+
"broadcast": broadcast,
181+
"presence": presence,
182+
"private": private,
183+
"postgres_changes": list(
184+
map(
185+
lambda x: x.filter,
186+
self.bindings.get("postgres_changes", []),
187+
)
188+
),
189+
}
180190
}
181191

182192
if self.socket.access_token:
183-
access_token_payload["access_token"] = self.socket.access_token
193+
config_payload["access_token"] = self.socket.access_token
184194

185-
self.join_push.update_payload(
186-
{**{"config": config}, **access_token_payload}
187-
)
195+
self.join_push.update_payload(config_payload)
188196
self._joined_once = True
189197

190198
def on_join_push_ok(payload: Dict[str, Any]):
@@ -253,7 +261,7 @@ def on_join_push_timeout(*args):
253261

254262
return self
255263

256-
async def unsubscribe(self):
264+
async def unsubscribe(self) -> None:
257265
"""
258266
Unsubscribe from the channel and leave the topic.
259267
Sets channel state to LEAVING and cleans up timers and pushes.
@@ -263,9 +271,9 @@ async def unsubscribe(self):
263271
self.rejoin_timer.reset()
264272
self.join_push.destroy()
265273

266-
def _close(*args):
274+
def _close(*args) -> None:
267275
logger.info(f"channel {self.topic} leave")
268-
self._trigger(ChannelEvents.close, "leave")
276+
self._trigger(ChannelEvents.close, {})
269277

270278
leave_push = AsyncPush(self, ChannelEvents.leave, {})
271279
leave_push.receive("ok", _close).receive("timeout", _close)
@@ -310,21 +318,24 @@ async def join(self) -> AsyncRealtimeChannel:
310318
:return: Channel
311319
"""
312320
try:
313-
await self.socket.send(
314-
{
315-
"topic": self.topic,
316-
"event": "phx_join",
317-
"payload": {"config": self.params},
318-
"ref": None,
319-
}
321+
message = Message(
322+
topic=self.topic,
323+
event=ChannelEvents.join,
324+
payload={"config": self.params},
325+
ref=None,
320326
)
327+
await self.socket.send(message)
328+
return self
321329
except Exception as e:
322330
print(e)
323331
return self
324332

325333
# Event handling methods
326334
def _on(
327-
self, type: str, callback: Callback, filter: Optional[Dict[str, Any]] = None
335+
self,
336+
type: str,
337+
callback: Callback[[Dict[str, Any], Optional[str]], None],
338+
filter: Optional[Dict[str, Any]] = None,
328339
) -> AsyncRealtimeChannel:
329340
"""
330341
Set up a listener for a specific event.
@@ -411,7 +422,7 @@ def on_postgres_changes(
411422
)
412423

413424
def on_system(
414-
self, callback: Callable[[Dict[str, Any], None]]
425+
self, callback: Callable[[Dict[str, Any]], None]
415426
) -> AsyncRealtimeChannel:
416427
"""
417428
Set up a listener for system events.
@@ -508,7 +519,7 @@ def _can_push(self):
508519
async def send_presence(self, event: str, data: Any) -> None:
509520
await self.push(ChannelEvents.presence, {"event": event, "payload": data})
510521

511-
def _trigger(self, type: str, payload: Optional[Any], ref: Optional[str] = None):
522+
def _trigger(self, type: str, payload: Dict[str, Any], ref: Optional[str] = None):
512523
type_lowercase = type.lower()
513524
events = [
514525
ChannelEvents.close,
@@ -562,7 +573,7 @@ def _trigger(self, type: str, payload: Optional[Any], ref: Optional[str] = None)
562573
elif binding.type == type_lowercase:
563574
binding.callback(payload, ref)
564575

565-
def _reply_event_name(self, ref: str):
576+
def _reply_event_name(self, ref: str) -> str:
566577
return f"chan_reply_{ref}"
567578

568579
async def _rejoin_until_connected(self):

realtime/_async/client.py

Lines changed: 22 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,12 @@
33
import logging
44
import re
55
from functools import wraps
6-
from typing import Any, Callable, Dict, List, Optional
6+
from typing import Any, Callable, Dict, List, Optional, Union
77
from urllib.parse import urlencode, urlparse, urlunparse
88

99
import websockets
1010
from websockets import connect
11-
from websockets.client import ClientProtocol
11+
from websockets.asyncio.client import ClientConnection
1212

1313
from ..exceptions import NotConnectedError
1414
from ..message import Message
@@ -62,7 +62,7 @@ def __init__(
6262
:param timeout: Connection timeout in seconds. Defaults to DEFAULT_TIMEOUT.
6363
"""
6464
if not is_ws_url(url):
65-
ValueError("url must be a valid WebSocket URL or HTTP URL string")
65+
raise ValueError("url must be a valid WebSocket URL or HTTP URL string")
6666
self.url = f"{re.sub(r'https://', 'wss://', re.sub(r'http://', 'ws://', url, flags=re.IGNORECASE), flags=re.IGNORECASE)}/websocket"
6767
if token:
6868
self.url += f"?apikey={token}"
@@ -72,7 +72,7 @@ def __init__(
7272
self.access_token = token
7373
self.send_buffer: List[Callable] = []
7474
self.hb_interval = hb_interval
75-
self._ws_connection: Optional[ClientProtocol] = None
75+
self._ws_connection: Optional[ClientConnection] = None
7676
self.ref = 0
7777
self.auto_reconnect = auto_reconnect
7878
self.channels: Dict[str, AsyncRealtimeChannel] = {}
@@ -97,13 +97,15 @@ async def _listen(self) -> None:
9797

9898
try:
9999
async for msg in self._ws_connection:
100-
logger.info(f"receive: {msg}")
100+
logger.info(f"receive: {msg!r}")
101101

102-
msg = Message(**json.loads(msg))
103-
channel = self.channels.get(msg.topic)
102+
message = Message.model_validate_json(msg)
103+
channel = self.channels.get(message.topic)
104104

105105
if channel:
106-
channel._trigger(msg.event, msg.payload, msg.ref)
106+
channel._trigger(
107+
message.event, dict(**message.payload), message.ref
108+
)
107109
except websockets.exceptions.ConnectionClosedError as e:
108110
await self._on_connect_error(e)
109111

@@ -236,7 +238,7 @@ async def _heartbeat(self) -> None:
236238

237239
while self.is_connected:
238240
try:
239-
data = dict(
241+
data = Message(
240242
topic=PHOENIX_CHANNEL,
241243
event=ChannelEvents.heartbeat,
242244
payload={},
@@ -294,14 +296,6 @@ async def remove_all_channels(self) -> None:
294296

295297
await self.close()
296298

297-
def summary(self) -> None:
298-
"""
299-
Prints a list of topics and event the socket is listening to
300-
:return: None
301-
"""
302-
for topic, channel in self.channels.items():
303-
print(f"Topic: {topic} | Events: {[e for e, _ in channel.listeners]}]")
304-
305299
async def set_auth(self, token: Optional[str]) -> None:
306300
"""
307301
Set the authentication token for the connection and update all joined channels.
@@ -325,7 +319,7 @@ def _make_ref(self) -> str:
325319
self.ref += 1
326320
return f"{self.ref}"
327321

328-
async def send(self, message: Dict[str, Any]) -> None:
322+
async def send(self, message: Union[Message, Dict[str, Any]]) -> None:
329323
"""
330324
Send a message through the WebSocket connection.
331325
@@ -340,16 +334,22 @@ async def send(self, message: Dict[str, Any]) -> None:
340334
Returns:
341335
None
342336
"""
343-
344-
message = json.dumps(message)
345-
logger.info(f"send: {message}")
337+
if isinstance(message, Message):
338+
msg = message
339+
else:
340+
logger.warning(
341+
"Warning: calling AsyncRealtimeClient.send with a dictionary is deprecated. Please call it with a Message object instead. This will be a hard error in the future."
342+
)
343+
msg = Message(**message)
344+
message_str = msg.model_dump_json()
345+
logger.info(f"send: {message_str}")
346346

347347
async def send_message():
348348
if not self._ws_connection:
349349
raise NotConnectedError("_send")
350350

351351
try:
352-
await self._ws_connection.send(message)
352+
await self._ws_connection.send(message_str)
353353
except websockets.exceptions.ConnectionClosedError as e:
354354
await self._on_connect_error(e)
355355
except websockets.exceptions.ConnectionClosedOK:

0 commit comments

Comments
 (0)