| 
3 | 3 | from __future__ import annotations  | 
4 | 4 | 
 
  | 
5 | 5 | import os  | 
 | 6 | +import threading  | 
6 | 7 | from concurrent import futures  | 
7 | 8 | from http.server import ThreadingHTTPServer  | 
8 |  | -from typing import Any, Callable, Coroutine, Optional, TypeVar, overload  | 
 | 9 | +from typing import Any, Callable, Coroutine, List, Optional, TypeVar, overload  | 
9 | 10 | from urllib.parse import urlsplit  | 
10 | 11 | 
 
  | 
11 | 12 | from typing_extensions import ParamSpec, TypeAlias  | 
 | 
31 | 32 |     "Status",  | 
32 | 33 |     "all",  | 
33 | 34 |     "any",  | 
 | 35 | +    "batch",  | 
34 | 36 |     "call",  | 
35 | 37 |     "function",  | 
36 | 38 |     "gather",  | 
 | 
44 | 46 | T = TypeVar("T")  | 
45 | 47 | 
 
  | 
46 | 48 | _registry: Optional[Registry] = None  | 
47 |  | - | 
 | 49 | +_workers: List[Callable[None, None]] = []  | 
 | 50 | +_threads: List[threading.Thread] = []  | 
48 | 51 | 
 
  | 
49 | 52 | def default_registry():  | 
50 | 53 |     global _registry  | 
@@ -89,10 +92,35 @@ def run(init: Optional[Callable[P, None]] = None, *args: P.args, **kwargs: P.kwa  | 
89 | 92 |     parsed_url = urlsplit("//" + address)  | 
90 | 93 |     server_address = (parsed_url.hostname or "", parsed_url.port or 0)  | 
91 | 94 |     server = ThreadingHTTPServer(server_address, Dispatch(default_registry()))  | 
 | 95 | + | 
 | 96 | +    for worker in _workers:  | 
 | 97 | +        def entrypoint():  | 
 | 98 | +            try:  | 
 | 99 | +                worker()  | 
 | 100 | +            finally:  | 
 | 101 | +                server.shutdown()  | 
 | 102 | +        _threads.append(threading.Thread(target=entrypoint))  | 
 | 103 | + | 
 | 104 | +    for thread in _threads:  | 
 | 105 | +        thread.start()  | 
 | 106 | + | 
92 | 107 |     try:  | 
93 | 108 |         if init is not None:  | 
94 | 109 |             init(*args, **kwargs)  | 
95 | 110 |         server.serve_forever()  | 
96 | 111 |     finally:  | 
97 | 112 |         server.shutdown()  | 
98 | 113 |         server.server_close()  | 
 | 114 | + | 
 | 115 | +        for thread in _threads:  | 
 | 116 | +            thread.join()  | 
 | 117 | + | 
 | 118 | +def batch() -> Batch:  | 
 | 119 | +    """Create a new batch object."""  | 
 | 120 | +    return default_registry().batch()  | 
 | 121 | + | 
 | 122 | + | 
 | 123 | +def worker(fn: Callable[None, None]) -> Callable[None, None]:  | 
 | 124 | +    """Decorator declaring workers that will be started when dipatch.run is called."""  | 
 | 125 | +    _workers.append(fn)  | 
 | 126 | +    return fn  | 
0 commit comments