|
3 | 3 | from posthog.client import Client |
4 | 4 |
|
5 | 5 | try: |
6 | | - from asgiref.sync import iscoroutinefunction |
| 6 | + from asgiref.sync import iscoroutinefunction, markcoroutinefunction |
7 | 7 | except ImportError: |
8 | | - # Fallback for older Django versions |
| 8 | + # Fallback for older Django versions without asgiref |
9 | 9 | import asyncio |
10 | 10 |
|
11 | 11 | iscoroutinefunction = asyncio.iscoroutinefunction |
12 | 12 |
|
| 13 | + # No-op fallback for markcoroutinefunction |
| 14 | + # Older Django versions without asgiref typically don't support async middleware anyway |
| 15 | + def markcoroutinefunction(func): |
| 16 | + return func |
| 17 | + |
| 18 | + |
13 | 19 | if TYPE_CHECKING: |
14 | 20 | from django.http import HttpRequest, HttpResponse # noqa: F401 |
15 | 21 | from typing import Callable, Dict, Any, Optional, Union, Awaitable # noqa: F401 |
@@ -39,26 +45,24 @@ class PosthogContextMiddleware: |
39 | 45 | See the context documentation for more information. The extracted distinct ID and session ID, if found, are used to |
40 | 46 | associate all events captured in the middleware context with the same distinct ID and session as currently active on the |
41 | 47 | frontend. See the documentation for `set_context_session` and `identify_context` for more details. |
| 48 | +
|
| 49 | + This middleware is hybrid-capable: it supports both WSGI (sync) and ASGI (async) Django applications. The middleware |
| 50 | + detects at initialization whether the next middleware in the chain is async or sync, and adapts its behavior accordingly. |
| 51 | + This ensures compatibility with both pure sync and pure async middleware chains, as well as mixed chains in ASGI mode. |
42 | 52 | """ |
43 | 53 |
|
44 | | - # Django middleware capability flags |
45 | 54 | sync_capable = True |
46 | 55 | async_capable = True |
47 | 56 |
|
48 | 57 | def __init__(self, get_response): |
49 | 58 | # type: (Union[Callable[[HttpRequest], HttpResponse], Callable[[HttpRequest], Awaitable[HttpResponse]]]) -> None |
| 59 | + self.get_response = get_response |
50 | 60 | self._is_coroutine = iscoroutinefunction(get_response) |
51 | | - self._async_get_response = None # type: Optional[Callable[[HttpRequest], Awaitable[HttpResponse]]] |
52 | | - self._sync_get_response = None # type: Optional[Callable[[HttpRequest], HttpResponse]] |
53 | 61 |
|
| 62 | + # Mark this instance as a coroutine function if get_response is async |
| 63 | + # This is required for Django to correctly detect async middleware |
54 | 64 | if self._is_coroutine: |
55 | | - self._async_get_response = cast( |
56 | | - "Callable[[HttpRequest], Awaitable[HttpResponse]]", get_response |
57 | | - ) |
58 | | - else: |
59 | | - self._sync_get_response = cast( |
60 | | - "Callable[[HttpRequest], HttpResponse]", get_response |
61 | | - ) |
| 65 | + markcoroutinefunction(self) |
62 | 66 |
|
63 | 67 | from django.conf import settings |
64 | 68 |
|
@@ -181,40 +185,38 @@ def extract_request_user(self, request): |
181 | 185 | return user_id, email |
182 | 186 |
|
183 | 187 | def __call__(self, request): |
184 | | - # type: (HttpRequest) -> HttpResponse |
185 | | - # Purely defensive around django's internal sync/async handling - this should be unreachable, but if it's reached, we may |
186 | | - # as well return something semi-meaningful |
187 | | - if self._is_coroutine: |
188 | | - raise RuntimeError( |
189 | | - "PosthogContextMiddleware received sync call but get_response is async" |
190 | | - ) |
| 188 | + # type: (HttpRequest) -> Union[HttpResponse, Awaitable[HttpResponse]] |
| 189 | + """ |
| 190 | + Unified entry point for both sync and async request handling. |
191 | 191 |
|
192 | | - if self.request_filter and not self.request_filter(request): |
193 | | - assert self._sync_get_response is not None |
194 | | - return self._sync_get_response(request) |
| 192 | + When sync_capable and async_capable are both True, Django passes requests |
| 193 | + without conversion. This method detects the mode and routes accordingly. |
| 194 | + """ |
| 195 | + if self._is_coroutine: |
| 196 | + return self.__acall__(request) |
| 197 | + else: |
| 198 | + # Synchronous path |
| 199 | + if self.request_filter and not self.request_filter(request): |
| 200 | + return self.get_response(request) |
195 | 201 |
|
196 | | - with contexts.new_context(self.capture_exceptions, client=self.client): |
197 | | - for k, v in self.extract_tags(request).items(): |
198 | | - contexts.tag(k, v) |
| 202 | + with contexts.new_context(self.capture_exceptions, client=self.client): |
| 203 | + for k, v in self.extract_tags(request).items(): |
| 204 | + contexts.tag(k, v) |
199 | 205 |
|
200 | | - assert self._sync_get_response is not None |
201 | | - return self._sync_get_response(request) |
| 206 | + return self.get_response(request) |
202 | 207 |
|
203 | 208 | async def __acall__(self, request): |
204 | | - # type: (HttpRequest) -> HttpResponse |
| 209 | + # type: (HttpRequest) -> Awaitable[HttpResponse] |
| 210 | + """ |
| 211 | + Asynchronous entry point for async request handling. |
| 212 | +
|
| 213 | + This method is called when the middleware chain is async. |
| 214 | + """ |
205 | 215 | if self.request_filter and not self.request_filter(request): |
206 | | - if self._async_get_response is not None: |
207 | | - return await self._async_get_response(request) |
208 | | - else: |
209 | | - assert self._sync_get_response is not None |
210 | | - return self._sync_get_response(request) |
| 216 | + return await self.get_response(request) |
211 | 217 |
|
212 | 218 | with contexts.new_context(self.capture_exceptions, client=self.client): |
213 | 219 | for k, v in self.extract_tags(request).items(): |
214 | 220 | contexts.tag(k, v) |
215 | 221 |
|
216 | | - if self._async_get_response is not None: |
217 | | - return await self._async_get_response(request) |
218 | | - else: |
219 | | - assert self._sync_get_response is not None |
220 | | - return self._sync_get_response(request) |
| 222 | + return await self.get_response(request) |
0 commit comments