From 578ca88f0ad9612d64005439fd32fc7dc978bd26 Mon Sep 17 00:00:00 2001 From: Thomas Grainger Date: Sat, 12 Apr 2025 06:50:40 +0100 Subject: [PATCH 01/11] sniffio for local --- asgiref/local.py | 44 ++++++++++++++++++++++++++++++++------------ 1 file changed, 32 insertions(+), 12 deletions(-) diff --git a/asgiref/local.py b/asgiref/local.py index 7d228aeb..b9158751 100644 --- a/asgiref/local.py +++ b/asgiref/local.py @@ -4,6 +4,29 @@ import threading from typing import Any, Union +def _is_asyncio_running(): + try: + asyncio.get_running_loop() + except RuntimeError: + return False + else: + return True + +try: + import sniffio +except ModuleNotFoundError: + _is_async = _is_asyncio_running +else: + def _is_async(): + try: + sniffio.current_async_library() + except sniffio.AsyncLibraryNotFoundError: + pass + else: + return True + + return _is_asyncio_running() + class _CVar: """Storage utility for Local.""" @@ -83,18 +106,9 @@ def __init__(self, thread_critical: bool = False) -> None: def _lock_storage(self): # Thread safe access to storage if self._thread_critical: - try: - # this is a test for are we in a async or sync - # thread - will raise RuntimeError if there is - # no current loop - asyncio.get_running_loop() - except RuntimeError: - # We are in a sync thread, the storage is - # just the plain thread local (i.e, "global within - # this thread" - it doesn't matter where you are - # in a call stack you see the same storage) - yield self._storage - else: + # this is a test for are we in a async or sync + # thread + if _is_async(): # We are in an async thread - storage is still # local to this thread, but additionally should # behave like a context var (is only visible with @@ -108,6 +122,12 @@ def _lock_storage(self): # can't be accessed in another thread (we don't # need any locks) yield self._storage.cvar + else: + # We are in a sync thread, the storage is + # just the plain thread local (i.e, "global within + # this thread" - it doesn't matter where you are + # in a call stack you see the same storage) + yield self._storage else: # Lock for thread_critical=False as other threads # can access the exact same storage object From 8ebe662f10ad76982b1a556a2ee3afe7bff40f4f Mon Sep 17 00:00:00 2001 From: Thomas Grainger Date: Sat, 12 Apr 2025 08:04:43 +0100 Subject: [PATCH 02/11] asgiref.sync trio support --- asgiref/local.py | 3 + asgiref/sync.py | 221 ++++++++++++++++++++++++++++++++++++++++------- 2 files changed, 193 insertions(+), 31 deletions(-) diff --git a/asgiref/local.py b/asgiref/local.py index b9158751..11721c9f 100644 --- a/asgiref/local.py +++ b/asgiref/local.py @@ -4,6 +4,7 @@ import threading from typing import Any, Union + def _is_asyncio_running(): try: asyncio.get_running_loop() @@ -12,11 +13,13 @@ def _is_asyncio_running(): else: return True + try: import sniffio except ModuleNotFoundError: _is_async = _is_asyncio_running else: + def _is_async(): try: sniffio.current_async_library() diff --git a/asgiref/sync.py b/asgiref/sync.py index 0c6ea98e..c96ee2e9 100644 --- a/asgiref/sync.py +++ b/asgiref/sync.py @@ -1,6 +1,6 @@ -import asyncio import asyncio.coroutines import contextvars +import enum import functools import inspect import os @@ -69,6 +69,141 @@ def markcoroutinefunction(func: _F) -> _F: return func +def _asyncio_create_task_threadsafe(loop, awaitable): + loop.call_soon_threadsafe(loop.create_task, awaitable) + + +def _asyncio_run_in_executor(loop, executor, in_thread, callback): + fut = loop.run_in_executor(executor, in_thread) + fut.add_done_callback(callback) + return fut + + +class AsyncioTaskContext: + def __init__(self, task): + self._task = task + + def cancel(self): + return self._task.cancel() + + async def wait(self): + return await self._task + + +class TrioTaskContext: + def __init__(self, cs, event): + self._cs = cs + self._event = event + + def cancel(self): + return self._cs.cancel() + + async def wait(self): + return await self._event.wait() + + +async def _asyncio_wrap_task_context(task_context, awaitable): + if task_context is None: + return await awaitable + + current_task = asyncio.current_task() + if current_task is None: + return await awaitable + + task_context_wrapped = AsyncioTaskContext(current_task) + task_context.append(task_context_wrapped) + try: + return await awaitable + finally: + if current_task is not None: + task_context.remove(task_context_wrapped) + + +try: + import sniffio + import trio.lowlevel +except ModuleNotFoundError: + from asyncio import get_running_loop + + create_task_threadsafe = _asyncio_create_task_threadsafe + run_in_executor = _asyncio_run_in_executor + wrap_task_context = _asyncio_wrap_task_context + + def event(loop): + return asyncio.Event() + + def get_cancelled_exc(loop): + return asyncio.CancelledError + +else: + + def get_running_loop(): + try: + asynclib = sniffio.current_async_library() + except sniffio.AsyncLibraryNotFoundError: + return asyncio.get_running_loop() + + if asynclib == "asyncio": + return asyncio.get_running_loop() + if asynclib == "trio": + return trio.lowlevel.current_token() + raise RuntimeError(f"unsupported library {asynclib}") + + @trio.lowlevel.disable_ki_protection + async def wrap_awaitable(awaitable): + return await awaitable + + def create_task_threadsafe(loop, awaitable): + if isinstance(loop, trio.lowlevel.TrioToken): + try: + loop.run_sync_soon( + trio.lowlevel.spawn_system_task, + wrap_awaitable, + awaitable, + ) + except trio.RunFinishedError: + raise RuntimeError("trio loop no-longer running") + + return _asyncio_create_task_threadsafe(loop, awaitable) + + def run_in_executor(loop, executor, in_thread, callback): + if isinstance(loop, trio.lowlevel.TrioToken): + + def sync_callback(fut): + loop.run_sync_soon(callback, fut) + + fut = executor.submit(in_thread) + fut.add_done_callback(sync_callback) + return fut + + return _asyncio_run_in_executor(executor, in_thread, callback) + + def event(loop): + if isinstance(loop, trio.lowlevel.TrioToken): + return trio.Event() + return asyncio.Event() + + def get_cancelled_exc(loop): + if isinstance(loop, trio.lowlevel.TrioToken): + return trio.Cancelled + return asyncio.CancelledError + + async def wrap_task_context(loop, task_context, awaitable): + if task_context is None: + return await awaitable + + if isinstance(loop, trio.lowlevel.TrioToken): + with trio.CancelScope as scope: + ctx = TrioTaskContext(scope) + task_context.append(ctx) + try: + return await awaitable + finally: + task_context.remove(ctx) + + return await _asyncio_wrap_task_context(task_context, awaitable) + + class ThreadSensitiveContext: """Async context manager to manage context for thread sensitive mode @@ -110,6 +245,19 @@ async def __aexit__(self, exc, value, tb): SyncToAsync.thread_sensitive_context.reset(self.token) +class LoopType(enum.Enum): + ASYNCIO = enum.auto() + TRIO = enum.auto() + + +def run(async_backend, callable, /, *args): + if async_backend is LoopType.TRIO: + import trio + + return trio.run(callable, *args) + return asyncio.run(callable(*args)) + + class AsyncToSync(Generic[_P, _R]): """ Utility class which turns an awaitable that only works on the thread with @@ -129,7 +277,7 @@ class AsyncToSync(Generic[_P, _R]): # When we can't find a CurrentThreadExecutor from the context, such as # inside create_task, we'll look it up here from the running event loop. - loop_thread_executors: "Dict[asyncio.AbstractEventLoop, CurrentThreadExecutor]" = {} + loop_thread_executors: "Dict[object, CurrentThreadExecutor]" = {} def __init__( self, @@ -137,8 +285,11 @@ def __init__( Callable[_P, Coroutine[Any, Any, _R]], Callable[_P, Awaitable[_R]], ], - force_new_loop: bool = False, + force_new_loop: Union[LoopType, bool] = False, ): + if force_new_loop and not isinstance(LoopType): + force_new_loop = LoopType.ASYNCIO + if not callable(awaitable) or ( not iscoroutinefunction(awaitable) and not iscoroutinefunction(getattr(awaitable, "__call__", awaitable)) @@ -156,7 +307,7 @@ def __init__( self.force_new_loop = force_new_loop self.main_event_loop = None try: - self.main_event_loop = asyncio.get_running_loop() + self.main_event_loop = get_running_loop() except RuntimeError: # There's no event loop in this thread. pass @@ -179,7 +330,7 @@ def __call__(self, *args: _P.args, **kwargs: _P.kwargs) -> _R: # You can't call AsyncToSync from a thread with a running event loop try: - asyncio.get_running_loop() + get_running_loop() except RuntimeError: pass else: @@ -224,7 +375,7 @@ def __call__(self, *args: _P.args, **kwargs: _P.kwargs) -> _R: ) async def new_loop_wrap() -> None: - loop = asyncio.get_running_loop() + loop = get_running_loop() self.loop_thread_executors[loop] = current_executor try: await awaitable @@ -233,8 +384,9 @@ async def new_loop_wrap() -> None: if self.main_event_loop is not None: try: - self.main_event_loop.call_soon_threadsafe( - self.main_event_loop.create_task, awaitable + create_task_threadsafe( + self.main_event_loop, + awaitable, ) except RuntimeError: running_in_main_event_loop = False @@ -248,7 +400,9 @@ async def new_loop_wrap() -> None: if not running_in_main_event_loop: # Make our own event loop - in a new thread - and run inside that. loop_executor = ThreadPoolExecutor(max_workers=1) - loop_future = loop_executor.submit(asyncio.run, new_loop_wrap()) + loop_future = loop_executor.submit( + run, self.force_new_loop, new_loop_wrap + ) # Run the CurrentThreadExecutor until the future is done. current_executor.run_until_future(loop_future) # Wait for future and/or allow for exception propagation @@ -286,10 +440,6 @@ async def main_wrap( if context is not None: _restore_context(context[0]) - current_task = asyncio.current_task() - if current_task is not None and task_context is not None: - task_context.append(current_task) - try: # If we have an exception, run the function inside the except block # after raising it so exc_info is correctly populated. @@ -297,16 +447,14 @@ async def main_wrap( try: raise exc_info[1] except BaseException: - result = await awaitable + result = await wrap_task_context(task_context, awaitable) else: - result = await awaitable + result = await wrap_task_context(task_context, awaitable) except BaseException as e: call_result.set_exception(e) else: call_result.set_result(result) finally: - if current_task is not None and task_context is not None: - task_context.remove(current_task) context[0] = contextvars.copy_context() @@ -382,7 +530,7 @@ def __init__( async def __call__(self, *args: _P.args, **kwargs: _P.kwargs) -> _R: __traceback_hide__ = True # noqa: F841 - loop = asyncio.get_running_loop() + loop = get_running_loop() # Work out what thread to run the code in if self._thread_sensitive: @@ -422,10 +570,19 @@ async def __call__(self, *args: _P.args, **kwargs: _P.kwargs) -> _R: func = context.run task_context: List[asyncio.Task[Any]] = [] + executor_done = event(loop) + executor_result = None + + def callback(fut): + nonlocal executor_result + executor_done.set() + executor_result = fut + # Run the code in the right thread - exec_coro = loop.run_in_executor( - executor, - functools.partial( + exec_fut = run_in_executor( + loop=loop, + executor=executor, + in_thread=functools.partial( self.thread_handler, loop, sys.exc_info(), @@ -433,32 +590,34 @@ async def __call__(self, *args: _P.args, **kwargs: _P.kwargs) -> _R: func, child, ), + callback=callback, ) - ret: _R + # hmm this is a bit messy - needs a while loop and shield or it can + # loose cancellations on multi-cancel. try: - ret = await asyncio.shield(exec_coro) - except asyncio.CancelledError: + await executor_done.wait() + except get_cancelled_exc(loop): cancel_parent = True try: task = task_context[0] task.cancel() try: - await task + await task.wait() cancel_parent = False - except asyncio.CancelledError: + except get_cancelled_exc(loop): pass except IndexError: pass - if exec_coro.done(): + if executor_done.is_set(): raise if cancel_parent: - exec_coro.cancel() - ret = await exec_coro + exec_fut.cancel() + await executor_done.wait() + return executor_result.result() finally: _restore_context(context) self.deadlock_context.set(False) - - return ret + return executor_result.result() def __get__( self, parent: Any, objtype: Any From b9e74122edfc8f139ef5539445953fad2828e206 Mon Sep 17 00:00:00 2001 From: Thomas Grainger Date: Sat, 12 Apr 2025 09:31:27 +0100 Subject: [PATCH 03/11] fix LoopType annotation --- asgiref/sync.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/asgiref/sync.py b/asgiref/sync.py index c96ee2e9..0689e3b8 100644 --- a/asgiref/sync.py +++ b/asgiref/sync.py @@ -655,7 +655,7 @@ def thread_handler(self, loop, exc_info, task_context, func, *args, **kwargs): @overload def async_to_sync( *, - force_new_loop: bool = False, + force_new_loop: Union[LoopType, bool] = False, ) -> Callable[ [Union[Callable[_P, Coroutine[Any, Any, _R]], Callable[_P, Awaitable[_R]]]], Callable[_P, _R], @@ -670,7 +670,7 @@ def async_to_sync( Callable[_P, Awaitable[_R]], ], *, - force_new_loop: bool = False, + force_new_loop: Union[LoopType, bool] = False, ) -> Callable[_P, _R]: ... @@ -683,7 +683,7 @@ def async_to_sync( ] ] = None, *, - force_new_loop: bool = False, + force_new_loop: Union[LoopType, bool] = False, ) -> Union[ Callable[ [Union[Callable[_P, Coroutine[Any, Any, _R]], Callable[_P, Awaitable[_R]]]], From 92a2b5f47b39dcd55433823b56c39cc4e8b92dc1 Mon Sep 17 00:00:00 2001 From: Thomas Grainger Date: Sat, 12 Apr 2025 09:48:12 +0100 Subject: [PATCH 04/11] handle None executors --- asgiref/sync.py | 41 +++++++++++++++++++++++++++++++++++++---- 1 file changed, 37 insertions(+), 4 deletions(-) diff --git a/asgiref/sync.py b/asgiref/sync.py index 0689e3b8..59588b05 100644 --- a/asgiref/sync.py +++ b/asgiref/sync.py @@ -120,8 +120,10 @@ async def _asyncio_wrap_task_context(task_context, awaitable): try: + import outcome import sniffio import trio.lowlevel + import trio.to_thread except ModuleNotFoundError: from asyncio import get_running_loop @@ -166,14 +168,45 @@ def create_task_threadsafe(loop, awaitable): return _asyncio_create_task_threadsafe(loop, awaitable) + class TrioToThreadFut: + def __init__(self, cs): + self._cs = cs + self._outcome = None + + def cancel(self): + self._cs.cancel() + + def result(self): + return self._outcome.unwrap() + + def set_result(self, outcome): + self._outcome = outcome + def run_in_executor(loop, executor, in_thread, callback): if isinstance(loop, trio.lowlevel.TrioToken): + if executor is not None: + + def sync_callback(fut): + loop.run_sync_soon(callback, fut) + + fut = executor.submit(in_thread) + fut.add_done_callback(sync_callback) + return fut + + # executor is None - we need to run on the trio + # thread pool, which is a bit more complicated. + cs = trio.CancelScope() + fut = TrioToThreadFut(cs) + + async def run_in_scope(): + with cs: + await trio.to_thread.run_sync(in_thread) - def sync_callback(fut): - loop.run_sync_soon(callback, fut) + async def run_in_thread(): + fut.set_result(await outcome.acapture(run_in_scope)) + callback(fut) - fut = executor.submit(in_thread) - fut.add_done_callback(sync_callback) + trio.lowlevel.spawn_system_task(run_in_thread) return fut return _asyncio_run_in_executor(executor, in_thread, callback) From dcc36716009849e97b522f871d590fc98db1e4df Mon Sep 17 00:00:00 2001 From: Thomas Grainger Date: Sat, 12 Apr 2025 11:02:47 +0100 Subject: [PATCH 05/11] simplify trio handling of task-thread-task --- asgiref/sync.py | 274 ++++++++++++++++++++++-------------------------- 1 file changed, 126 insertions(+), 148 deletions(-) diff --git a/asgiref/sync.py b/asgiref/sync.py index 59588b05..42b357bd 100644 --- a/asgiref/sync.py +++ b/asgiref/sync.py @@ -73,54 +73,69 @@ def _asyncio_create_task_threadsafe(loop, awaitable): loop.call_soon_threadsafe(loop.create_task, awaitable) -def _asyncio_run_in_executor(loop, executor, in_thread, callback): - fut = loop.run_in_executor(executor, in_thread) - fut.add_done_callback(callback) - return fut - - -class AsyncioTaskContext: - def __init__(self, task): - self._task = task - - def cancel(self): - return self._task.cancel() - - async def wait(self): - return await self._task - - -class TrioTaskContext: - def __init__(self, cs, event): - self._cs = cs - self._event = event - - def cancel(self): - return self._cs.cancel() - - async def wait(self): - return await self._event.wait() - - -async def _asyncio_wrap_task_context(task_context, awaitable): +async def _asyncio_wrap_task_context(loop, task_context, awaitable): if task_context is None: return await awaitable - current_task = asyncio.current_task() + current_task = asyncio.current_task(loop) if current_task is None: return await awaitable - task_context_wrapped = AsyncioTaskContext(current_task) - task_context.append(task_context_wrapped) + task_context.append(current_task) try: return await awaitable finally: - if current_task is not None: - task_context.remove(task_context_wrapped) + task_context.remove(current_task) + + +async def _asyncio_run_in_executor(*, loop, executor, thread_handler, child): + context = contextvars.copy_context() + func = context.run + task_context: List[asyncio.Task[Any]] = [] + + # Run the code in the right thread + exec_coro = loop.run_in_executor( + executor, + functools.partial( + thread_handler, + loop, + sys.exc_info(), + task_context, + func, + child, + ), + ) + ret: _R + try: + ret = await asyncio.shield(exec_coro) + except asyncio.CancelledError: + cancel_parent = True + try: + task = task_context[0] + task.cancel() + try: + await task + cancel_parent = False + except asyncio.CancelledError: + pass + except IndexError: + pass + if exec_coro.done(): + raise + if cancel_parent: + exec_coro.cancel() + ret = await exec_coro + finally: + _restore_context(context) + + return ret + + +class TrioThreadCancelled(BaseException): + pass try: - import outcome import sniffio import trio.lowlevel import trio.to_thread @@ -131,12 +146,6 @@ async def _asyncio_wrap_task_context(task_context, awaitable): run_in_executor = _asyncio_run_in_executor wrap_task_context = _asyncio_wrap_task_context - def event(loop): - return asyncio.Event() - - def get_cancelled_exc(loop): - return asyncio.CancelledError - else: def get_running_loop(): @@ -168,58 +177,69 @@ def create_task_threadsafe(loop, awaitable): return _asyncio_create_task_threadsafe(loop, awaitable) - class TrioToThreadFut: - def __init__(self, cs): - self._cs = cs - self._outcome = None - - def cancel(self): - self._cs.cancel() - - def result(self): - return self._outcome.unwrap() - - def set_result(self, outcome): - self._outcome = outcome - - def run_in_executor(loop, executor, in_thread, callback): - if isinstance(loop, trio.lowlevel.TrioToken): - if executor is not None: - - def sync_callback(fut): - loop.run_sync_soon(callback, fut) - - fut = executor.submit(in_thread) - fut.add_done_callback(sync_callback) - return fut - - # executor is None - we need to run on the trio - # thread pool, which is a bit more complicated. - cs = trio.CancelScope() - fut = TrioToThreadFut(cs) - - async def run_in_scope(): - with cs: - await trio.to_thread.run_sync(in_thread) - - async def run_in_thread(): - fut.set_result(await outcome.acapture(run_in_scope)) - callback(fut) - - trio.lowlevel.spawn_system_task(run_in_thread) - return fut - - return _asyncio_run_in_executor(executor, in_thread, callback) - - def event(loop): + async def run_in_executor(*, loop, executor, thread_handler, child): if isinstance(loop, trio.lowlevel.TrioToken): - return trio.Event() - return asyncio.Event() + context = contextvars.copy_context() + func = context.run + task_context: List[asyncio.Task[Any]] = [] - def get_cancelled_exc(loop): - if isinstance(loop, trio.lowlevel.TrioToken): - return trio.Cancelled - return asyncio.CancelledError + # Run the code in the right thread + full_func = functools.partial( + thread_handler, + loop, + sys.exc_info(), + task_context, + func, + child, + ) + try: + if executor is None: + async with trio.open_nursery() as nursery: + + async def handle_abort(): + try: + await trio.sleep_forever() + except trio.Cancelled: + if task_context: + task_context[0].cancel() + raise + + try: + return await trio.to_thread.run_sync( + thread_handler, func, abandon_on_cancel=False + ) + finally: + nursery.cancel_scope.cancel() + else: + event = trio.Event() + + def callback(fut): + loop.run_sync_soon(event.set) + + fut = executor.submit(full_func) + fut.add_done_callback(callback) + + async with trio.open_nursery() as nursery: + + async def handle_abort(): + try: + await trio.sleep_forever() + except trio.Cancelled: + fut.cancel() + if task_context: + task_context[0].cancel() + raise + + with trio.CancelScope(shield=True): + await event.wait() + nursery.cancel_scope.cancel() + return fut.result() + finally: + _restore_context(context) + + return await _asyncio_run_in_executor( + loop=loop, executor=executor, thread_handler=thread_handler, func=func + ) async def wrap_task_context(loop, task_context, awaitable): if task_context is None: @@ -227,14 +247,15 @@ async def wrap_task_context(loop, task_context, awaitable): if isinstance(loop, trio.lowlevel.TrioToken): with trio.CancelScope as scope: - ctx = TrioTaskContext(scope) - task_context.append(ctx) + task_context.append(scope) try: return await awaitable finally: - task_context.remove(ctx) + task_context.remove(scope) + if scope.cancelled_caught: + raise TrioThreadCancelled - return await _asyncio_wrap_task_context(task_context, awaitable) + return await _asyncio_wrap_task_context(loop, task_context, awaitable) class ThreadSensitiveContext: @@ -470,6 +491,7 @@ async def main_wrap( __traceback_hide__ = True # noqa: F841 + loop = get_running_loop() if context is not None: _restore_context(context[0]) @@ -480,9 +502,9 @@ async def main_wrap( try: raise exc_info[1] except BaseException: - result = await wrap_task_context(task_context, awaitable) + result = await wrap_task_context(loop, task_context, awaitable) else: - result = await wrap_task_context(task_context, awaitable) + result = await wrap_task_context(loop, task_context, awaitable) except BaseException as e: call_result.set_exception(e) else: @@ -598,59 +620,15 @@ async def __call__(self, *args: _P.args, **kwargs: _P.kwargs) -> _R: # Use the passed in executor, or the loop's default if it is None executor = self._executor - context = contextvars.copy_context() - child = functools.partial(self.func, *args, **kwargs) - func = context.run - task_context: List[asyncio.Task[Any]] = [] - - executor_done = event(loop) - executor_result = None - - def callback(fut): - nonlocal executor_result - executor_done.set() - executor_result = fut - - # Run the code in the right thread - exec_fut = run_in_executor( - loop=loop, - executor=executor, - in_thread=functools.partial( - self.thread_handler, - loop, - sys.exc_info(), - task_context, - func, - child, - ), - callback=callback, - ) - # hmm this is a bit messy - needs a while loop and shield or it can - # loose cancellations on multi-cancel. try: - await executor_done.wait() - except get_cancelled_exc(loop): - cancel_parent = True - try: - task = task_context[0] - task.cancel() - try: - await task.wait() - cancel_parent = False - except get_cancelled_exc(loop): - pass - except IndexError: - pass - if executor_done.is_set(): - raise - if cancel_parent: - exec_fut.cancel() - await executor_done.wait() - return executor_result.result() + return await run_in_executor( + loop=loop, + executor=executor, + thread_handler=self.thread_handler, + child=functools.partial(self.func, *args, **kwargs), + ) finally: - _restore_context(context) self.deadlock_context.set(False) - return executor_result.result() def __get__( self, parent: Any, objtype: Any From 71906e7c8a5ffa1dc3e0b2982f8ec3b3a4bcf7b1 Mon Sep 17 00:00:00 2001 From: Thomas Grainger Date: Sat, 12 Apr 2025 11:32:17 +0100 Subject: [PATCH 06/11] actually run the cancel handlers --- asgiref/sync.py | 37 +++++++++++++++++++------------------ 1 file changed, 19 insertions(+), 18 deletions(-) diff --git a/asgiref/sync.py b/asgiref/sync.py index 42b357bd..cdde66b0 100644 --- a/asgiref/sync.py +++ b/asgiref/sync.py @@ -194,16 +194,17 @@ async def run_in_executor(*, loop, executor, thread_handler, child): ) try: if executor is None: - async with trio.open_nursery() as nursery: - async def handle_abort(): - try: - await trio.sleep_forever() - except trio.Cancelled: - if task_context: - task_context[0].cancel() - raise + async def handle_cancel(): + try: + await trio.sleep_forever() + except trio.Cancelled: + if task_context: + task_context[0].cancel() + raise + async with trio.open_nursery() as nursery: + nursery.start_soon(handle_cancel) try: return await trio.to_thread.run_sync( thread_handler, func, abandon_on_cancel=False @@ -219,17 +220,17 @@ def callback(fut): fut = executor.submit(full_func) fut.add_done_callback(callback) - async with trio.open_nursery() as nursery: - - async def handle_abort(): - try: - await trio.sleep_forever() - except trio.Cancelled: - fut.cancel() - if task_context: - task_context[0].cancel() - raise + async def handle_cancel_fut(): + try: + await trio.sleep_forever() + except trio.Cancelled: + fut.cancel() + if task_context: + task_context[0].cancel() + raise + async with trio.open_nursery() as nursery: + nursery.start_soon(handle_cancel_fut) with trio.CancelScope(shield=True): await event.wait() nursery.cancel_scope.cancel() From 23b57e1c9850ae64d79ffdaf71cd27a791cd6ba0 Mon Sep 17 00:00:00 2001 From: Thomas Grainger Date: Sun, 13 Apr 2025 08:17:47 +0100 Subject: [PATCH 07/11] catch TrioThreadCancelled --- asgiref/sync.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/asgiref/sync.py b/asgiref/sync.py index cdde66b0..12f5d06d 100644 --- a/asgiref/sync.py +++ b/asgiref/sync.py @@ -209,6 +209,8 @@ async def handle_cancel(): return await trio.to_thread.run_sync( thread_handler, func, abandon_on_cancel=False ) + except TrioThreadCancelled: + pass finally: nursery.cancel_scope.cancel() else: @@ -234,7 +236,10 @@ async def handle_cancel_fut(): with trio.CancelScope(shield=True): await event.wait() nursery.cancel_scope.cancel() - return fut.result() + try: + return fut.result() + except TrioThreadCancelled: + pass finally: _restore_context(context) @@ -247,7 +252,7 @@ async def wrap_task_context(loop, task_context, awaitable): return await awaitable if isinstance(loop, trio.lowlevel.TrioToken): - with trio.CancelScope as scope: + with trio.CancelScope() as scope: task_context.append(scope) try: return await awaitable From 89d20644f325256419bba22d1d1fed0af01cc499 Mon Sep 17 00:00:00 2001 From: Thomas Grainger Date: Sun, 13 Apr 2025 09:23:27 +0100 Subject: [PATCH 08/11] split out trio/asyncio compat --- asgiref/_asyncio.py | 82 ++++++++++++++++ asgiref/_context.py | 13 +++ asgiref/_trio.py | 136 ++++++++++++++++++++++++++ asgiref/sync.py | 232 +++++--------------------------------------- 4 files changed, 256 insertions(+), 207 deletions(-) create mode 100644 asgiref/_asyncio.py create mode 100644 asgiref/_context.py create mode 100644 asgiref/_trio.py diff --git a/asgiref/_asyncio.py b/asgiref/_asyncio.py new file mode 100644 index 00000000..b1bb1d4f --- /dev/null +++ b/asgiref/_asyncio.py @@ -0,0 +1,82 @@ +__all__ = [ + "get_running_loop", + "create_task_threadsafe", + "wrap_task_context", + "run_in_executor", +] + +import asyncio +import contextvars +import functools +import sys +from asyncio import get_running_loop +from collections.abc import Callable +from typing import Any, TypeVar + +from ._context import restore_context as _restore_context + +_R = TypeVar("_R") + + +def create_task_threadsafe(loop, awaitable) -> None: + loop.call_soon_threadsafe(loop.create_task, awaitable) + + +async def wrap_task_context(loop, task_context, awaitable): + if task_context is None: + return await awaitable + + current_task = asyncio.current_task(loop) + if current_task is None: + return await awaitable + + task_context.append(current_task) + try: + return await awaitable + finally: + task_context.remove(current_task) + + +async def run_in_executor( + *, loop, executor, thread_handler, child: Callable[[], _R] +) -> _R: + context = contextvars.copy_context() + func = context.run + task_context: list[asyncio.Task[Any]] = [] + + # Run the code in the right thread + exec_coro = loop.run_in_executor( + executor, + functools.partial( + thread_handler, + loop, + sys.exc_info(), + task_context, + func, + child, + ), + ) + ret: _R + try: + ret = await asyncio.shield(exec_coro) + except asyncio.CancelledError: + cancel_parent = True + try: + task = task_context[0] + task.cancel() + try: + await task + cancel_parent = False + except asyncio.CancelledError: + pass + except IndexError: + pass + if exec_coro.done(): + raise + if cancel_parent: + exec_coro.cancel() + ret = await exec_coro + finally: + _restore_context(context) + + return ret diff --git a/asgiref/_context.py b/asgiref/_context.py new file mode 100644 index 00000000..08af5153 --- /dev/null +++ b/asgiref/_context.py @@ -0,0 +1,13 @@ +import contextvars + + +def restore_context(context: contextvars.Context) -> None: + # Check for changes in contextvars, and set them to the current + # context for downstream consumers + for cvar in context: + cvalue = context.get(cvar) + try: + if cvar.get() != cvalue: + cvar.set(cvalue) + except LookupError: + cvar.set(cvalue) diff --git a/asgiref/_trio.py b/asgiref/_trio.py new file mode 100644 index 00000000..f002b63d --- /dev/null +++ b/asgiref/_trio.py @@ -0,0 +1,136 @@ +import asyncio +import contextvars +import functools +import sys +from typing import Any + +import sniffio +import trio.lowlevel +import trio.to_thread + +from . import _asyncio +from ._context import restore_context as _restore_context + + +class TrioThreadCancelled(BaseException): + pass + + +def get_running_loop(): + try: + asynclib = sniffio.current_async_library() + except sniffio.AsyncLibraryNotFoundError: + return asyncio.get_running_loop() + + if asynclib == "asyncio": + return asyncio.get_running_loop() + if asynclib == "trio": + return trio.lowlevel.current_token() + raise RuntimeError(f"unsupported library {asynclib}") + + +@trio.lowlevel.disable_ki_protection +async def wrap_awaitable(awaitable): + return await awaitable + + +def create_task_threadsafe(loop, awaitable): + if isinstance(loop, trio.lowlevel.TrioToken): + try: + loop.run_sync_soon( + trio.lowlevel.spawn_system_task, + wrap_awaitable, + awaitable, + ) + except trio.RunFinishedError: + raise RuntimeError("trio loop no-longer running") + + return _asyncio.create_task_threadsafe(loop, awaitable) + + +async def run_in_executor(*, loop, executor, thread_handler, child): + if isinstance(loop, trio.lowlevel.TrioToken): + context = contextvars.copy_context() + func = context.run + task_context: list[asyncio.Task[Any]] = [] + + # Run the code in the right thread + full_func = functools.partial( + thread_handler, + loop, + sys.exc_info(), + task_context, + func, + child, + ) + try: + if executor is None: + + async def handle_cancel(): + try: + await trio.sleep_forever() + except trio.Cancelled: + if task_context: + task_context[0].cancel() + raise + + async with trio.open_nursery() as nursery: + nursery.start_soon(handle_cancel) + try: + return await trio.to_thread.run_sync( + thread_handler, func, abandon_on_cancel=False + ) + except TrioThreadCancelled: + pass + finally: + nursery.cancel_scope.cancel() + else: + event = trio.Event() + + def callback(fut): + loop.run_sync_soon(event.set) + + fut = executor.submit(full_func) + fut.add_done_callback(callback) + + async def handle_cancel_fut(): + try: + await trio.sleep_forever() + except trio.Cancelled: + fut.cancel() + if task_context: + task_context[0].cancel() + raise + + async with trio.open_nursery() as nursery: + nursery.start_soon(handle_cancel_fut) + with trio.CancelScope(shield=True): + await event.wait() + nursery.cancel_scope.cancel() + try: + return fut.result() + except TrioThreadCancelled: + pass + finally: + _restore_context(context) + + return await _asyncio.run_in_executor( + loop=loop, executor=executor, thread_handler=thread_handler, func=func + ) + + +async def wrap_task_context(loop, task_context, awaitable): + if task_context is None: + return await awaitable + + if isinstance(loop, trio.lowlevel.TrioToken): + with trio.CancelScope() as scope: + task_context.append(scope) + try: + return await awaitable + finally: + task_context.remove(scope) + if scope.cancelled_caught: + raise TrioThreadCancelled + + return await _asyncio.wrap_task_context(loop, task_context, awaitable) diff --git a/asgiref/sync.py b/asgiref/sync.py index 12f5d06d..99e85bee 100644 --- a/asgiref/sync.py +++ b/asgiref/sync.py @@ -24,6 +24,7 @@ overload, ) +from ._context import restore_context as _restore_context from .current_thread_executor import CurrentThreadExecutor from .local import Local @@ -36,23 +37,35 @@ # This is not available to import at runtime from _typeshed import OptExcInfo + from ._trio import ( + create_task_threadsafe, + get_running_loop, + run_in_executor, + wrap_task_context, + ) +else: + try: + __import__("trio") + except ModuleNotFoundError: + from ._asyncio import ( + create_task_threadsafe, + get_running_loop, + run_in_executor, + wrap_task_context, + ) + else: + from ._trio import ( + create_task_threadsafe, + get_running_loop, + run_in_executor, + wrap_task_context, + ) + _F = TypeVar("_F", bound=Callable[..., Any]) _P = ParamSpec("_P") _R = TypeVar("_R") -def _restore_context(context: contextvars.Context) -> None: - # Check for changes in contextvars, and set them to the current - # context for downstream consumers - for cvar in context: - cvalue = context.get(cvar) - try: - if cvar.get() != cvalue: - cvar.set(cvalue) - except LookupError: - cvar.set(cvalue) - - # Python 3.12 deprecates asyncio.iscoroutinefunction() as an alias for # inspect.iscoroutinefunction(), whilst also removing the _is_coroutine marker. # The latter is replaced with the inspect.markcoroutinefunction decorator. @@ -69,201 +82,6 @@ def markcoroutinefunction(func: _F) -> _F: return func -def _asyncio_create_task_threadsafe(loop, awaitable): - loop.call_soon_threadsafe(loop.create_task, awaitable) - - -async def _asyncio_wrap_task_context(loop, task_context, awaitable): - if task_context is None: - return await awaitable - - current_task = asyncio.current_task(loop) - if current_task is None: - return await awaitable - - task_context.append(current_task) - try: - return await awaitable - finally: - task_context.remove(current_task) - - -async def _asyncio_run_in_executor(*, loop, executor, thread_handler, child): - context = contextvars.copy_context() - func = context.run - task_context: List[asyncio.Task[Any]] = [] - - # Run the code in the right thread - exec_coro = loop.run_in_executor( - executor, - functools.partial( - thread_handler, - loop, - sys.exc_info(), - task_context, - func, - child, - ), - ) - ret: _R - try: - ret = await asyncio.shield(exec_coro) - except asyncio.CancelledError: - cancel_parent = True - try: - task = task_context[0] - task.cancel() - try: - await task - cancel_parent = False - except asyncio.CancelledError: - pass - except IndexError: - pass - if exec_coro.done(): - raise - if cancel_parent: - exec_coro.cancel() - ret = await exec_coro - finally: - _restore_context(context) - - return ret - - -class TrioThreadCancelled(BaseException): - pass - - -try: - import sniffio - import trio.lowlevel - import trio.to_thread -except ModuleNotFoundError: - from asyncio import get_running_loop - - create_task_threadsafe = _asyncio_create_task_threadsafe - run_in_executor = _asyncio_run_in_executor - wrap_task_context = _asyncio_wrap_task_context - -else: - - def get_running_loop(): - try: - asynclib = sniffio.current_async_library() - except sniffio.AsyncLibraryNotFoundError: - return asyncio.get_running_loop() - - if asynclib == "asyncio": - return asyncio.get_running_loop() - if asynclib == "trio": - return trio.lowlevel.current_token() - raise RuntimeError(f"unsupported library {asynclib}") - - @trio.lowlevel.disable_ki_protection - async def wrap_awaitable(awaitable): - return await awaitable - - def create_task_threadsafe(loop, awaitable): - if isinstance(loop, trio.lowlevel.TrioToken): - try: - loop.run_sync_soon( - trio.lowlevel.spawn_system_task, - wrap_awaitable, - awaitable, - ) - except trio.RunFinishedError: - raise RuntimeError("trio loop no-longer running") - - return _asyncio_create_task_threadsafe(loop, awaitable) - - async def run_in_executor(*, loop, executor, thread_handler, child): - if isinstance(loop, trio.lowlevel.TrioToken): - context = contextvars.copy_context() - func = context.run - task_context: List[asyncio.Task[Any]] = [] - - # Run the code in the right thread - full_func = functools.partial( - thread_handler, - loop, - sys.exc_info(), - task_context, - func, - child, - ) - try: - if executor is None: - - async def handle_cancel(): - try: - await trio.sleep_forever() - except trio.Cancelled: - if task_context: - task_context[0].cancel() - raise - - async with trio.open_nursery() as nursery: - nursery.start_soon(handle_cancel) - try: - return await trio.to_thread.run_sync( - thread_handler, func, abandon_on_cancel=False - ) - except TrioThreadCancelled: - pass - finally: - nursery.cancel_scope.cancel() - else: - event = trio.Event() - - def callback(fut): - loop.run_sync_soon(event.set) - - fut = executor.submit(full_func) - fut.add_done_callback(callback) - - async def handle_cancel_fut(): - try: - await trio.sleep_forever() - except trio.Cancelled: - fut.cancel() - if task_context: - task_context[0].cancel() - raise - - async with trio.open_nursery() as nursery: - nursery.start_soon(handle_cancel_fut) - with trio.CancelScope(shield=True): - await event.wait() - nursery.cancel_scope.cancel() - try: - return fut.result() - except TrioThreadCancelled: - pass - finally: - _restore_context(context) - - return await _asyncio_run_in_executor( - loop=loop, executor=executor, thread_handler=thread_handler, func=func - ) - - async def wrap_task_context(loop, task_context, awaitable): - if task_context is None: - return await awaitable - - if isinstance(loop, trio.lowlevel.TrioToken): - with trio.CancelScope() as scope: - task_context.append(scope) - try: - return await awaitable - finally: - task_context.remove(scope) - if scope.cancelled_caught: - raise TrioThreadCancelled - - return await _asyncio_wrap_task_context(loop, task_context, awaitable) - - class ThreadSensitiveContext: """Async context manager to manage context for thread sensitive mode From 152d990b9069af13db6ee9f53c95125e5bdd4fc8 Mon Sep 17 00:00:00 2001 From: Thomas Grainger Date: Sun, 13 Apr 2025 10:10:30 +0100 Subject: [PATCH 09/11] mypy --- asgiref/_asyncio.py | 42 ++++++++++++++++++++++--- asgiref/_trio.py | 74 ++++++++++++++++++++++++++++++++++----------- asgiref/sync.py | 3 +- tox.ini | 1 + 4 files changed, 97 insertions(+), 23 deletions(-) diff --git a/asgiref/_asyncio.py b/asgiref/_asyncio.py index b1bb1d4f..2b450bee 100644 --- a/asgiref/_asyncio.py +++ b/asgiref/_asyncio.py @@ -6,23 +6,33 @@ ] import asyncio +import concurrent.futures import contextvars import functools import sys +import types from asyncio import get_running_loop -from collections.abc import Callable -from typing import Any, TypeVar +from collections.abc import Awaitable, Callable, Coroutine +from typing import Any, Generic, Protocol, TypeVar, Union from ._context import restore_context as _restore_context _R = TypeVar("_R") +Coro = Coroutine[Any, Any, _R] -def create_task_threadsafe(loop, awaitable) -> None: + +def create_task_threadsafe( + loop: asyncio.AbstractEventLoop, awaitable: Coro[object] +) -> None: loop.call_soon_threadsafe(loop.create_task, awaitable) -async def wrap_task_context(loop, task_context, awaitable): +async def wrap_task_context( + loop: asyncio.AbstractEventLoop, + task_context: list[asyncio.Task[Any]], + awaitable: Awaitable[_R], +) -> _R: if task_context is None: return await awaitable @@ -37,8 +47,30 @@ async def wrap_task_context(loop, task_context, awaitable): task_context.remove(current_task) +ExcInfo = Union[ + tuple[type[BaseException], BaseException, types.TracebackType], + tuple[None, None, None], +] + + +class ThreadHandlerType(Protocol, Generic[_R]): + def __call__( + self, + loop: asyncio.AbstractEventLoop, + exc_info: ExcInfo, + task_context: list[asyncio.Task[Any]], + func: Callable[[Callable[[], _R]], _R], + child: Callable[[], _R], + ) -> _R: + ... + + async def run_in_executor( - *, loop, executor, thread_handler, child: Callable[[], _R] + *, + loop: asyncio.AbstractEventLoop, + executor: concurrent.futures.ThreadPoolExecutor, + thread_handler: ThreadHandlerType[_R], + child: Callable[[], _R], ) -> _R: context = contextvars.copy_context() func = context.run diff --git a/asgiref/_trio.py b/asgiref/_trio.py index f002b63d..af56892b 100644 --- a/asgiref/_trio.py +++ b/asgiref/_trio.py @@ -1,8 +1,11 @@ import asyncio +import concurrent.futures import contextvars import functools import sys -from typing import Any +import types +from collections.abc import Awaitable, Callable, Coroutine +from typing import Any, Generic, Protocol, TypeVar, Union import sniffio import trio.lowlevel @@ -11,12 +14,20 @@ from . import _asyncio from ._context import restore_context as _restore_context +_R = TypeVar("_R") + +Coro = Coroutine[Any, Any, _R] + +Loop = Union[asyncio.AbstractEventLoop, trio.lowlevel.TrioToken] +TaskContext = list[Any] + class TrioThreadCancelled(BaseException): pass -def get_running_loop(): +def get_running_loop() -> Loop: + try: asynclib = sniffio.current_async_library() except sniffio.AsyncLibraryNotFoundError: @@ -25,16 +36,16 @@ def get_running_loop(): if asynclib == "asyncio": return asyncio.get_running_loop() if asynclib == "trio": - return trio.lowlevel.current_token() + return trio.lowlevel.current_trio_token() raise RuntimeError(f"unsupported library {asynclib}") @trio.lowlevel.disable_ki_protection -async def wrap_awaitable(awaitable): +async def wrap_awaitable(awaitable: Awaitable[_R]) -> _R: return await awaitable -def create_task_threadsafe(loop, awaitable): +def create_task_threadsafe(loop: Loop, awaitable: Coro[_R]) -> None: if isinstance(loop, trio.lowlevel.TrioToken): try: loop.run_sync_soon( @@ -44,15 +55,40 @@ def create_task_threadsafe(loop, awaitable): ) except trio.RunFinishedError: raise RuntimeError("trio loop no-longer running") + return + + _asyncio.create_task_threadsafe(loop, awaitable) - return _asyncio.create_task_threadsafe(loop, awaitable) +ExcInfo = Union[ + tuple[type[BaseException], BaseException, types.TracebackType], + tuple[None, None, None], +] -async def run_in_executor(*, loop, executor, thread_handler, child): + +class ThreadHandlerType(Protocol, Generic[_R]): + def __call__( + self, + loop: Loop, + exc_info: ExcInfo, + task_context: TaskContext, + func: Callable[[Callable[[], _R]], _R], + child: Callable[[], _R], + ) -> _R: + ... + + +async def run_in_executor( + *, + loop: Loop, + executor: concurrent.futures.ThreadPoolExecutor, + thread_handler: ThreadHandlerType[_R], + child: Callable[[], _R], +) -> _R: if isinstance(loop, trio.lowlevel.TrioToken): context = contextvars.copy_context() func = context.run - task_context: list[asyncio.Task[Any]] = [] + task_context: TaskContext = [] # Run the code in the right thread full_func = functools.partial( @@ -66,7 +102,7 @@ async def run_in_executor(*, loop, executor, thread_handler, child): try: if executor is None: - async def handle_cancel(): + async def handle_cancel() -> None: try: await trio.sleep_forever() except trio.Cancelled: @@ -84,16 +120,17 @@ async def handle_cancel(): pass finally: nursery.cancel_scope.cancel() + assert False else: event = trio.Event() - def callback(fut): + def callback(fut: object) -> None: loop.run_sync_soon(event.set) fut = executor.submit(full_func) fut.add_done_callback(callback) - async def handle_cancel_fut(): + async def handle_cancel_fut() -> None: try: await trio.sleep_forever() except trio.Cancelled: @@ -111,15 +148,19 @@ async def handle_cancel_fut(): return fut.result() except TrioThreadCancelled: pass + assert False finally: _restore_context(context) - return await _asyncio.run_in_executor( - loop=loop, executor=executor, thread_handler=thread_handler, func=func - ) + else: + return await _asyncio.run_in_executor( + loop=loop, executor=executor, thread_handler=thread_handler, child=child + ) -async def wrap_task_context(loop, task_context, awaitable): +async def wrap_task_context( + loop: Loop, task_context: Union[TaskContext, None], awaitable: Awaitable[_R] +) -> _R: if task_context is None: return await awaitable @@ -130,7 +171,6 @@ async def wrap_task_context(loop, task_context, awaitable): return await awaitable finally: task_context.remove(scope) - if scope.cancelled_caught: - raise TrioThreadCancelled + raise TrioThreadCancelled return await _asyncio.wrap_task_context(loop, task_context, awaitable) diff --git a/asgiref/sync.py b/asgiref/sync.py index 99e85bee..5de00983 100644 --- a/asgiref/sync.py +++ b/asgiref/sync.py @@ -165,7 +165,7 @@ def __init__( ], force_new_loop: Union[LoopType, bool] = False, ): - if force_new_loop and not isinstance(LoopType): + if force_new_loop and not isinstance(force_new_loop, LoopType): force_new_loop = LoopType.ASYNCIO if not callable(awaitable) or ( @@ -319,6 +319,7 @@ async def main_wrap( if context is not None: _restore_context(context[0]) + result: _R try: # If we have an exception, run the function inside the except block # after raising it so exc_info is correctly populated. diff --git a/tox.ini b/tox.ini index 49c49bce..50a1a463 100644 --- a/tox.ini +++ b/tox.ini @@ -11,6 +11,7 @@ commands = mypy: mypy . {posargs} deps = setuptools + mypy: trio [testenv:qa] skip_install = true From 9425b75e255e6b4af46796a89e9f364abaf5179c Mon Sep 17 00:00:00 2001 From: Thomas Grainger Date: Mon, 14 Apr 2025 07:53:05 +0100 Subject: [PATCH 10/11] run correct function --- asgiref/_trio.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/asgiref/_trio.py b/asgiref/_trio.py index af56892b..2c6803e0 100644 --- a/asgiref/_trio.py +++ b/asgiref/_trio.py @@ -114,7 +114,7 @@ async def handle_cancel() -> None: nursery.start_soon(handle_cancel) try: return await trio.to_thread.run_sync( - thread_handler, func, abandon_on_cancel=False + full_func, abandon_on_cancel=False ) except TrioThreadCancelled: pass From c0586d38884543741e0a7089e8d3d5303370deb2 Mon Sep 17 00:00:00 2001 From: Thomas Grainger Date: Mon, 14 Apr 2025 07:53:17 +0100 Subject: [PATCH 11/11] test on trio using anyio --- setup.cfg | 1 + tests/test_sync.py | 103 ++++++++++++++++++++++++--------------------- tox.ini | 3 +- 3 files changed, 57 insertions(+), 50 deletions(-) diff --git a/setup.cfg b/setup.cfg index ef0a4314..8d10df10 100644 --- a/setup.cfg +++ b/setup.cfg @@ -39,6 +39,7 @@ zip_safe = false tests = pytest pytest-asyncio + anyio[trio] mypy>=1.14.0 [tool:pytest] diff --git a/tests/test_sync.py b/tests/test_sync.py index 0c67308c..974e1a80 100644 --- a/tests/test_sync.py +++ b/tests/test_sync.py @@ -1,7 +1,6 @@ import asyncio import functools import multiprocessing -import sys import threading import time import warnings @@ -10,7 +9,9 @@ from typing import Any from unittest import TestCase +import anyio import pytest +import trio.to_thread from asgiref.sync import ( ThreadSensitiveContext, @@ -21,7 +22,7 @@ from asgiref.timeout import timeout -@pytest.mark.asyncio +@pytest.mark.anyio async def test_sync_to_async(): """ Tests we can call sync functions from an async thread @@ -41,6 +42,16 @@ def sync_function(): end = time.monotonic() assert result == 42 assert end - start >= 1 + + +@pytest.mark.asyncio +async def test_sync_to_async_one_worker(): + # Define sync function + @sync_to_async + def async_function(): + time.sleep(1) + return 42 + # Set workers to 1, call it twice and make sure that works right loop = asyncio.get_running_loop() old_executor = loop._default_executor or ThreadPoolExecutor() @@ -72,7 +83,7 @@ def test_sync_to_async_fail_non_function(): ) -@pytest.mark.asyncio +@pytest.mark.anyio async def test_sync_to_async_fail_async(): """ sync_to_async raises a TypeError when applied to a sync function. @@ -88,7 +99,7 @@ async def test_function(): ) -@pytest.mark.asyncio +@pytest.mark.anyio async def test_async_to_sync_fail_partial(): """ sync_to_async raises a TypeError when applied to a sync partial. @@ -106,7 +117,7 @@ async def test_function(*args): ) -@pytest.mark.asyncio +@pytest.mark.anyio async def test_sync_to_async_raises_typeerror_for_async_callable_instance(): class CallableClass: async def __call__(self): @@ -118,7 +129,7 @@ async def __call__(self): sync_to_async(CallableClass()) -@pytest.mark.asyncio +@pytest.mark.anyio async def test_sync_to_async_decorator(): """ Tests sync_to_async as a decorator @@ -134,7 +145,7 @@ def test_function(): assert result == 43 -@pytest.mark.asyncio +@pytest.mark.anyio async def test_nested_sync_to_async_retains_wrapped_function_attributes(): """ Tests that attributes of functions wrapped by sync_to_async are retained @@ -157,7 +168,7 @@ def test_function(): assert test_function.__name__ == "test_function" -@pytest.mark.asyncio +@pytest.mark.anyio async def test_sync_to_async_method_decorator(): """ Tests sync_to_async as a method decorator @@ -175,7 +186,7 @@ def test_method(self): assert result == 44 -@pytest.mark.asyncio +@pytest.mark.anyio async def test_sync_to_async_method_self_attribute(): """ Tests sync_to_async on a method copies __self__ @@ -197,7 +208,7 @@ def test_method(self): assert method.__self__ == instance -@pytest.mark.asyncio +@pytest.mark.anyio async def test_async_to_sync_to_async(): """ Tests we can call async functions from a sync thread created by async_to_sync @@ -225,7 +236,7 @@ def sync_function(): assert result["thread"] == threading.current_thread() -@pytest.mark.asyncio +@pytest.mark.anyio async def test_async_to_sync_to_async_decorator(): """ Test async_to_sync as a function decorator uses the outer thread @@ -253,9 +264,8 @@ def sync_function(): assert result["thread"] == threading.current_thread() -@pytest.mark.asyncio -@pytest.mark.skipif(sys.version_info < (3, 9), reason="requires python3.9") -async def test_async_to_sync_to_thread_decorator(): +@pytest.mark.anyio +async def test_async_to_sync_to_thread_decorator(anyio_backend_name): """ Test async_to_sync as a function decorator uses the outer thread when used inside another sync thread. @@ -270,7 +280,10 @@ async def inner_async_function(): return 42 # Check it works right - number = await asyncio.to_thread(inner_async_function) + if anyio_backend_name == "trio": + number = await trio.to_thread.run_sync(inner_async_function) + else: + number = await asyncio.to_thread(inner_async_function) assert number == 42 assert result["worked"] # Make sure that it didn't needlessly make a new async loop @@ -363,7 +376,7 @@ async def test_function(self): assert result["worked"] -@pytest.mark.asyncio +@pytest.mark.anyio async def test_async_to_sync_in_async(): """ Makes sure async_to_sync bails if you try to call it from an async loop @@ -509,7 +522,7 @@ def inner_task(): assert result["thread2"] == threading.current_thread() -@pytest.mark.asyncio +@pytest.mark.anyio async def test_thread_sensitive_outside_async(): """ Tests that thread_sensitive SyncToAsync where the outside is async code runs @@ -535,16 +548,16 @@ def inner(result): result["thread"] = threading.current_thread() # Run it (in supposed parallel!) - await asyncio.wait( - [asyncio.create_task(outer(result_1)), asyncio.create_task(inner(result_2))] - ) + async with anyio.create_task_group() as tg: + tg.start_soon(outer, result_1) + await inner(result_2) # They should not have run in the main thread, but in the same thread assert result_1["thread"] != threading.current_thread() assert result_1["thread"] == result_2["thread"] -@pytest.mark.asyncio +@pytest.mark.anyio async def test_thread_sensitive_with_context_matches(): result_1 = {} result_2 = {} @@ -557,12 +570,9 @@ def store_thread(result): async def fn(): async with ThreadSensitiveContext(): # Run it (in supposed parallel!) - await asyncio.wait( - [ - asyncio.create_task(store_thread_async(result_1)), - asyncio.create_task(store_thread_async(result_2)), - ] - ) + async with anyio.create_task_group() as tg: + tg.start_soon(store_thread_async, result_1) + await store_thread_async(result_2) await fn() @@ -571,7 +581,7 @@ async def fn(): assert result_1["thread"] == result_2["thread"] -@pytest.mark.asyncio +@pytest.mark.anyio async def test_thread_sensitive_nested_context(): result_1 = {} result_2 = {} @@ -590,7 +600,7 @@ def store_thread(result): assert result_1["thread"] == result_2["thread"] -@pytest.mark.asyncio +@pytest.mark.anyio async def test_thread_sensitive_context_without_sync_work(): async with ThreadSensitiveContext(): pass @@ -629,7 +639,7 @@ def level4(): assert result["thread"] == threading.current_thread() -@pytest.mark.asyncio +@pytest.mark.anyio async def test_thread_sensitive_double_nested_async(): """ Tests that thread_sensitive SyncToAsync nests inside itself where the @@ -729,7 +739,7 @@ def fork_first(): return queue.get(True, 1) -@pytest.mark.asyncio +@pytest.mark.anyio async def test_multiprocessing(): """ Tests that a forked process can use async_to_sync without it looking for @@ -738,7 +748,7 @@ async def test_multiprocessing(): assert await sync_to_async(fork_first)() == 42 -@pytest.mark.asyncio +@pytest.mark.anyio async def test_sync_to_async_uses_executor(): """ Tests that SyncToAsync uses the passed in executor correctly. @@ -834,7 +844,7 @@ async def async_process_that_triggers_event(): await trigger_task -@pytest.mark.asyncio +@pytest.mark.anyio async def test_sync_to_async_with_blocker_non_thread_sensitive(): """ Tests sync_to_async running on a long-time blocker in a non_thread_sensitive context. @@ -850,23 +860,20 @@ async def async_process_waiting_on_event(): async def async_process_that_triggers_event(): """Sleep, then set the event.""" - await asyncio.sleep(1) + await anyio.sleep(1) await sync_to_async(event.set)() - # Run the event setter as a task. - trigger_task = asyncio.ensure_future(async_process_that_triggers_event()) + async with anyio.create_task_group() as tg: + # Run the event setter as a task. + tg.start_soon(async_process_that_triggers_event) - try: - # wait on the event waiter, which is now blocking the event setter. - async with timeout(delay + 1): - assert await async_process_waiting_on_event() == 42 - except asyncio.TimeoutError: - # In case of timeout, set the event to unblock things, else - # downstream tests will get fouled up. - event.set() - raise - finally: - await trigger_task + try: + with anyio.fail_after(delay + 1): + assert await async_process_waiting_on_event() == 42 + except TimeoutError: + # In case of timeout, set the event to unblock things, else + # downstream tests will get fouled up. + event.set() @pytest.mark.asyncio @@ -1194,7 +1201,7 @@ async def test_function(**kwargs: Any) -> None: test_function(context=1) -@pytest.mark.asyncio +@pytest.mark.anyio async def test_sync_to_async_overlapping_kwargs() -> None: """ Tests that SyncToAsync correctly passes through kwargs to the wrapped function, diff --git a/tox.ini b/tox.ini index 50a1a463..1eb1cc7c 100644 --- a/tox.ini +++ b/tox.ini @@ -1,6 +1,6 @@ [tox] envlist = - py{38,39,310,311,312,313}-{test,mypy} + py{38,39,310,311,312,313}-{test,mypy,trio} qa [testenv] @@ -11,7 +11,6 @@ commands = mypy: mypy . {posargs} deps = setuptools - mypy: trio [testenv:qa] skip_install = true