Skip to content

Commit 419c1ff

Browse files
Support Trio with httpx (#3089)
* Support trio with httpx * Fix lint * Stop trying to run clients tests twice in a row * Remove anyio from runtime dependencies * Use custom _sleep function * Stop pretending that we want/can run YAML test with both trio and asyncio * Document Trio support * Remove debug print * Remove useless anyio fixture * Handle more markers * Fix references to asyncio * Use anyio for bulk flush timeout * linter fixes * revert pyproject.toml to official transport releases --------- Co-authored-by: Miguel Grinberg <miguel.grinberg@gmail.com>
1 parent c275ded commit 419c1ff

33 files changed

+263
-220
lines changed

docs/reference/async.md

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,27 @@ All APIs that are available under the sync client are also available under the a
5050

5151
See also the [Using OpenTelemetry](/reference/opentelemetry.md) page.
5252

53+
## Trio support
54+
55+
If you prefer using Trio instead of asyncio to take advantage of its better structured concurrency support, you can use the HTTPX async node, which supports Trio out of the box.
56+
57+
```python
58+
import trio
59+
from elasticsearch import AsyncElasticsearch
60+
61+
client = AsyncElasticsearch(
62+
"https://...",
63+
api_key="...",
64+
node_class="httpxasync")
65+
66+
async def main():
67+
resp = await client.info()
68+
print(resp.body)
69+
70+
trio.run(main)
71+
```
72+
73+
The one limitation of Trio support is that it does not currently support node sniffing, which was not implemented with structured concurrency in mind.
5374

5475
## Frequently Asked Questions [_frequently_asked_questions]
5576

docs/reference/dsl_how_to_guides.md

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1555,6 +1555,12 @@ The DSL module supports async/await with [asyncio](https://docs.python.org/3/lib
15551555
$ python -m pip install "elasticsearch[async]"
15561556
```
15571557

1558+
The DSL module also supports [Trio](https://trio.readthedocs.io/en/stable/) when using the Async HTTPX client. You do need to install Trio and HTTPX separately:
1559+
1560+
```bash
1561+
$ python -m pip install elasticsearch trio httpx
1562+
```
1563+
15581564
### Connections [_connections]
15591565

15601566
Use the `async_connections` module to manage your asynchronous connections.
@@ -1565,6 +1571,14 @@ from elasticsearch.dsl import async_connections
15651571
async_connections.create_connection(hosts=['localhost'], timeout=20)
15661572
```
15671573

1574+
If you're using Trio, you need to explicitly request the Async HTTPX client:
1575+
1576+
```python
1577+
from elasticsearch.dsl import async_connections
1578+
1579+
async_connections.create_connection(hosts=['localhost'], node_class="httpxasync")
1580+
```
1581+
15681582
All the options available in the `connections` module can be used with `async_connections`.
15691583

15701584
#### How to avoid *Unclosed client session / connector* warnings on exit [_how_to_avoid_unclosed_client_session_connector_warnings_on_exit]
@@ -1576,8 +1590,6 @@ es = async_connections.get_connection()
15761590
await es.close()
15771591
```
15781592

1579-
1580-
15811593
### Search DSL [_search_dsl]
15821594

15831595
Use the `AsyncSearch` class to perform asynchronous searches.

elasticsearch/_async/helpers.py

Lines changed: 28 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,9 @@
3333
Union,
3434
)
3535

36-
from ..compat import safe_task
36+
import sniffio
37+
from anyio import create_memory_object_stream, create_task_group, move_on_after
38+
3739
from ..exceptions import ApiError, NotFoundError, TransportError
3840
from ..helpers.actions import (
3941
_TYPE_BULK_ACTION,
@@ -57,6 +59,15 @@
5759
T = TypeVar("T")
5860

5961

62+
async def _sleep(seconds: float) -> None:
63+
if sniffio.current_async_library() == "trio":
64+
import trio
65+
66+
await trio.sleep(seconds)
67+
else:
68+
await asyncio.sleep(seconds)
69+
70+
6071
async def _chunk_actions(
6172
actions: AsyncIterable[_TYPE_BULK_ACTION_HEADER_WITH_META_AND_BODY],
6273
chunk_size: int,
@@ -82,32 +93,36 @@ async def _chunk_actions(
8293
chunk_size=chunk_size, max_chunk_bytes=max_chunk_bytes, serializer=serializer
8394
)
8495

96+
action: _TYPE_BULK_ACTION_WITH_META
97+
data: _TYPE_BULK_ACTION_BODY
8598
if not flush_after_seconds:
8699
async for action, data in actions:
87100
ret = chunker.feed(action, data)
88101
if ret:
89102
yield ret
90103
else:
91-
item_queue: asyncio.Queue[_TYPE_BULK_ACTION_HEADER_WITH_META_AND_BODY] = (
92-
asyncio.Queue()
93-
)
104+
sender, receiver = create_memory_object_stream[
105+
_TYPE_BULK_ACTION_HEADER_WITH_META_AND_BODY
106+
]()
94107

95108
async def get_items() -> None:
96109
try:
97110
async for item in actions:
98-
await item_queue.put(item)
111+
await sender.send(item)
99112
finally:
100-
await item_queue.put((BulkMeta.done, None))
113+
await sender.send((BulkMeta.done, None))
114+
115+
async with create_task_group() as tg:
116+
tg.start_soon(get_items)
101117

102-
async with safe_task(get_items()):
103118
timeout: Optional[float] = flush_after_seconds
104119
while True:
105-
try:
106-
action, data = await asyncio.wait_for(
107-
item_queue.get(), timeout=timeout
108-
)
120+
action = {}
121+
data = None
122+
with move_on_after(timeout) as scope:
123+
action, data = await receiver.receive()
109124
timeout = flush_after_seconds
110-
except asyncio.TimeoutError:
125+
if scope.cancelled_caught:
111126
action, data = BulkMeta.flush, None
112127
timeout = None
113128

@@ -294,9 +309,7 @@ async def map_actions() -> (
294309
]
295310
] = []
296311
if attempt:
297-
await asyncio.sleep(
298-
min(max_backoff, initial_backoff * 2 ** (attempt - 1))
299-
)
312+
await _sleep(min(max_backoff, initial_backoff * 2 ** (attempt - 1)))
300313

301314
try:
302315
data: Union[

elasticsearch/compat.py

Lines changed: 2 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -15,14 +15,13 @@
1515
# specific language governing permissions and limitations
1616
# under the License.
1717

18-
import asyncio
1918
import inspect
2019
import os
2120
import sys
22-
from contextlib import asynccontextmanager, contextmanager
21+
from contextlib import contextmanager
2322
from pathlib import Path
2423
from threading import Thread
25-
from typing import Any, AsyncIterator, Callable, Coroutine, Iterator, Tuple, Type, Union
24+
from typing import Any, Callable, Iterator, Tuple, Type, Union
2625

2726
string_types: Tuple[Type[str], Type[bytes]] = (str, bytes)
2827

@@ -105,22 +104,10 @@ def run() -> None:
105104
raise captured_exception
106105

107106

108-
@asynccontextmanager
109-
async def safe_task(coro: Coroutine[Any, Any, Any]) -> AsyncIterator[asyncio.Task[Any]]:
110-
"""Run a background task within a context manager block.
111-
112-
The task is awaited when the block ends.
113-
"""
114-
task = asyncio.create_task(coro)
115-
yield task
116-
await task
117-
118-
119107
__all__ = [
120108
"string_types",
121109
"to_str",
122110
"to_bytes",
123111
"warn_stacklevel",
124112
"safe_thread",
125-
"safe_task",
126113
]

pyproject.toml

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,9 +40,11 @@ keywords = [
4040
]
4141
dynamic = ["version"]
4242
dependencies = [
43-
"elastic-transport>=9.1.0,<10",
43+
"elastic-transport>=9.2.0,<10",
4444
"python-dateutil",
4545
"typing-extensions",
46+
"sniffio",
47+
"anyio",
4648
]
4749

4850
[project.optional-dependencies]
@@ -55,6 +57,7 @@ vectorstore_mmr = ["numpy>=1", "simsimd>=3"]
5557
dev = [
5658
"requests>=2, <3",
5759
"aiohttp",
60+
"httpx",
5861
"pytest",
5962
"pytest-cov",
6063
"pytest-mock",
@@ -77,6 +80,7 @@ dev = [
7780
"mapbox-vector-tile",
7881
"jinja2",
7982
"tqdm",
83+
"trio",
8084
"mypy",
8185
"pyright",
8286
"types-python-dateutil",

test_elasticsearch/test_async/test_server/conftest.py

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -16,27 +16,26 @@
1616
# under the License.
1717

1818
import pytest
19-
import pytest_asyncio
19+
import sniffio
2020

2121
import elasticsearch
2222

2323
from ...utils import CA_CERTS, wipe_cluster
2424

25-
pytestmark = pytest.mark.asyncio
2625

27-
28-
@pytest_asyncio.fixture(scope="function")
26+
@pytest.fixture(scope="function")
2927
async def async_client_factory(elasticsearch_url):
30-
31-
if not hasattr(elasticsearch, "AsyncElasticsearch"):
32-
pytest.skip("test requires 'AsyncElasticsearch' and aiohttp to be installed")
33-
28+
kwargs = {}
29+
if sniffio.current_async_library() == "trio":
30+
kwargs["node_class"] = "httpxasync"
3431
# Unfortunately the asyncio client needs to be rebuilt every
3532
# test execution due to how pytest-asyncio manages
3633
# event loops (one per test!)
3734
client = None
3835
try:
39-
client = elasticsearch.AsyncElasticsearch(elasticsearch_url, ca_certs=CA_CERTS)
36+
client = elasticsearch.AsyncElasticsearch(
37+
elasticsearch_url, ca_certs=CA_CERTS, **kwargs
38+
)
4039
yield client
4140
finally:
4241
if client:

test_elasticsearch/test_async/test_server/test_clients.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818

1919
import pytest
2020

21-
pytestmark = pytest.mark.asyncio
21+
pytestmark = pytest.mark.anyio
2222

2323

2424
@pytest.mark.parametrize("kwargs", [{"body": {"text": "привет"}}, {"text": "привет"}])

test_elasticsearch/test_async/test_server/test_helpers.py

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -15,21 +15,20 @@
1515
# specific language governing permissions and limitations
1616
# under the License.
1717

18-
import asyncio
1918
import logging
2019
import time
2120
from datetime import datetime, timedelta, timezone
2221
from unittest.mock import MagicMock, call, patch
2322

23+
import anyio
2424
import pytest
25-
import pytest_asyncio
2625
from elastic_transport import ApiResponseMeta, ObjectApiResponse
2726

2827
from elasticsearch import helpers
2928
from elasticsearch.exceptions import ApiError
3029
from elasticsearch.helpers import ScanError
3130

32-
pytestmark = [pytest.mark.asyncio]
31+
pytestmark = pytest.mark.anyio
3332

3433

3534
class AsyncMock(MagicMock):
@@ -93,7 +92,7 @@ async def test_all_documents_get_inserted(self, async_client):
9392
async def test_documents_data_types(self, async_client):
9493
async def async_gen():
9594
for x in range(100):
96-
await asyncio.sleep(0)
95+
await anyio.sleep(0)
9796
yield {"answer": x, "_id": x}
9897

9998
def sync_gen():
@@ -129,7 +128,7 @@ async def async_gen():
129128
yield {"answer": 2, "_id": 0}
130129
yield {"answer": 1, "_id": 1}
131130
yield helpers.BULK_FLUSH
132-
await asyncio.sleep(0.5)
131+
await anyio.sleep(0.5)
133132
yield {"answer": 2, "_id": 2}
134133

135134
timestamps = []
@@ -146,7 +145,7 @@ async def test_timeout_flushes(self, async_client):
146145
async def async_gen():
147146
yield {"answer": 2, "_id": 0}
148147
yield {"answer": 1, "_id": 1}
149-
await asyncio.sleep(0.5)
148+
await anyio.sleep(0.5)
150149
yield {"answer": 2, "_id": 2}
151150

152151
timestamps = []
@@ -531,7 +530,7 @@ def __await__(self):
531530
return self().__await__()
532531

533532

534-
@pytest_asyncio.fixture(scope="function")
533+
@pytest.fixture(scope="function")
535534
async def scan_teardown(async_client):
536535
yield
537536
await async_client.clear_scroll(scroll_id="_all")
@@ -955,7 +954,7 @@ async def test_scan_from_keyword_is_aliased(async_client, scan_kwargs):
955954
assert "from" not in search_mock.call_args[1]
956955

957956

958-
@pytest_asyncio.fixture(scope="function")
957+
@pytest.fixture(scope="function")
959958
async def reindex_setup(async_client):
960959
bulk = []
961960
for x in range(100):
@@ -1033,7 +1032,7 @@ async def test_all_documents_get_moved(self, async_client, reindex_setup):
10331032
)["_source"]
10341033

10351034

1036-
@pytest_asyncio.fixture(scope="function")
1035+
@pytest.fixture(scope="function")
10371036
async def parent_reindex_setup(async_client):
10381037
body = {
10391038
"settings": {"number_of_shards": 1, "number_of_replicas": 0},
@@ -1094,7 +1093,7 @@ async def test_children_are_reindexed_correctly(
10941093
} == q
10951094

10961095

1097-
@pytest_asyncio.fixture(scope="function")
1096+
@pytest.fixture(scope="function")
10981097
async def reindex_data_stream_setup(async_client):
10991098
dt = datetime.now(tz=timezone.utc)
11001099
bulk = []

test_elasticsearch/test_async/test_server/test_mapbox_vector_tile.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,14 +16,13 @@
1616
# under the License.
1717

1818
import pytest
19-
import pytest_asyncio
2019

2120
from elasticsearch import RequestError
2221

23-
pytestmark = pytest.mark.asyncio
22+
pytestmark = pytest.mark.anyio
2423

2524

26-
@pytest_asyncio.fixture(scope="function")
25+
@pytest.fixture(scope="function")
2726
async def mvt_setup(async_client):
2827
await async_client.indices.create(
2928
index="museums",

test_elasticsearch/test_async/test_server/test_rest_api_spec.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,6 @@
2525
import warnings
2626

2727
import pytest
28-
import pytest_asyncio
2928

3029
from elasticsearch import ElasticsearchWarning, RequestError
3130

@@ -39,6 +38,8 @@
3938
)
4039
from ...utils import parse_version
4140

41+
# We're not using `pytest.mark.anyio` here because it would run the test suite twice,
42+
# which does not work as it does not fully clean up after itself.
4243
pytestmark = pytest.mark.asyncio
4344

4445
XPACK_FEATURES = None
@@ -240,7 +241,7 @@ async def _feature_enabled(self, name):
240241
return name in XPACK_FEATURES
241242

242243

243-
@pytest_asyncio.fixture(scope="function")
244+
@pytest.fixture(scope="function")
244245
def async_runner(async_client_factory):
245246
return AsyncYamlRunner(async_client_factory)
246247

0 commit comments

Comments
 (0)