Skip to content

Commit 8e62afb

Browse files
authored
[ENH] Make better use of concurrency (#288)
We recently switched to asyncio for our database connection and API clients. That switch was made with as few changes as possible, which left multiple call sites where async calls could be sent off and awaited together instead of in sequence.
1 parent 5c30ef7 commit 8e62afb

6 files changed

Lines changed: 91 additions & 63 deletions

File tree

src/routers/mldcat_ap/dataset.py

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
Specific queries could be written to fetch e.g., a single feature or quality.
55
"""
66

7+
import asyncio
78
from typing import Annotated
89

910
from fastapi import APIRouter, Depends, HTTPException
@@ -46,13 +47,16 @@ async def get_mldcat_ap_distribution(
4647
) -> JsonLDGraph:
4748
assert user_db is not None # noqa: S101
4849
assert expdb is not None # noqa: S101
49-
oml_dataset = await get_dataset(
50-
dataset_id=distribution_id,
51-
user=user,
52-
user_db=user_db,
53-
expdb_db=expdb,
50+
oml_dataset, openml_features, oml_qualities = await asyncio.gather(
51+
get_dataset(
52+
dataset_id=distribution_id,
53+
user=user,
54+
user_db=user_db,
55+
expdb_db=expdb,
56+
),
57+
get_dataset_features(distribution_id, user, expdb),
58+
get_qualities(distribution_id, user, expdb),
5459
)
55-
openml_features = await get_dataset_features(distribution_id, user, expdb)
5660
features = [
5761
Feature(
5862
id_=f"{_server_url}/feature/{distribution_id}/{feature.index}",
@@ -61,7 +65,6 @@ async def get_mldcat_ap_distribution(
6165
)
6266
for feature in openml_features
6367
]
64-
oml_qualities = await get_qualities(distribution_id, user, expdb)
6568
qualities = [
6669
Quality(
6770
id_=f"{_server_url}/quality/{quality.name}/{distribution_id}",

src/routers/openml/datasets.py

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import asyncio
12
import re
23
from datetime import datetime
34
from enum import StrEnum
@@ -298,8 +299,10 @@ async def get_dataset_features(
298299
) -> list[Feature]:
299300
assert expdb is not None # noqa: S101
300301
await _get_dataset_raise_otherwise(dataset_id, user, expdb)
301-
features = await database.datasets.get_features(dataset_id, expdb)
302-
ontologies = await database.datasets.get_feature_ontologies(dataset_id, expdb)
302+
features, ontologies = await asyncio.gather(
303+
database.datasets.get_features(dataset_id, expdb),
304+
database.datasets.get_feature_ontologies(dataset_id, expdb),
305+
)
303306
for feature in features:
304307
feature.ontology = ontologies.get(feature.index)
305308

@@ -402,10 +405,12 @@ async def get_dataset(
402405
msg = f"No data file found for dataset {dataset_id}."
403406
raise DatasetNoDataFileError(msg)
404407

405-
tags = await database.datasets.get_tags_for(dataset_id, expdb_db)
406-
description = await database.datasets.get_description(dataset_id, expdb_db)
407-
processing_result = await _get_processing_information(dataset_id, expdb_db)
408-
status = await database.datasets.get_status(dataset_id, expdb_db)
408+
tags, description, processing_result, status = await asyncio.gather(
409+
database.datasets.get_tags_for(dataset_id, expdb_db),
410+
database.datasets.get_description(dataset_id, expdb_db),
411+
_get_processing_information(dataset_id, expdb_db),
412+
database.datasets.get_status(dataset_id, expdb_db),
413+
)
409414

410415
status_ = DatasetStatus(status.status) if status else DatasetStatus.IN_PREPARATION
411416

src/routers/openml/flows.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import asyncio
12
from typing import Annotated, Literal
23

34
from fastapi import APIRouter, Depends
@@ -40,7 +41,11 @@ async def get_flow(
4041
msg = f"Flow with id {flow_id} not found."
4142
raise FlowNotFoundError(msg)
4243

43-
parameter_rows = await database.flows.get_parameters(flow_id, expdb)
44+
parameter_rows, tags, subflow_rows = await asyncio.gather(
45+
database.flows.get_parameters(flow_id, expdb),
46+
database.flows.get_tags(flow_id, expdb),
47+
database.flows.get_subflows(flow_id, expdb),
48+
)
4449
parameters = [
4550
Parameter(
4651
name=parameter.name,
@@ -53,9 +58,6 @@ async def get_flow(
5358
)
5459
for parameter in parameter_rows
5560
]
56-
57-
tags = await database.flows.get_tags(flow_id, expdb)
58-
subflow_rows = await database.flows.get_subflows(flow_id, expdb)
5961
subflows = []
6062
for subflow in subflow_rows:
6163
subflows.append( # noqa: PERF401

src/routers/openml/setups.py

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
"""All endpoints that relate to setups."""
22

3+
import asyncio
34
from typing import Annotated
45

56
from fastapi import APIRouter, Body, Depends
@@ -27,11 +28,13 @@ async def tag_setup(
2728
expdb_db: Annotated[AsyncConnection, Depends(expdb_connection)],
2829
) -> dict[str, dict[str, str | list[str]]]:
2930
"""Add tag `tag` to setup with id `setup_id`."""
30-
if not await database.setups.get(setup_id, expdb_db):
31+
setup, setup_tags = await asyncio.gather(
32+
database.setups.get(setup_id, expdb_db),
33+
database.setups.get_tags(setup_id, expdb_db),
34+
)
35+
if not setup:
3136
msg = f"Setup {setup_id} not found."
3237
raise SetupNotFoundError(msg)
33-
34-
setup_tags = await database.setups.get_tags(setup_id, expdb_db)
3538
matched_tag_row = next((t for t in setup_tags if t.tag.casefold() == tag.casefold()), None)
3639

3740
if matched_tag_row:
@@ -51,11 +54,13 @@ async def untag_setup(
5154
expdb_db: Annotated[AsyncConnection, Depends(expdb_connection)],
5255
) -> dict[str, dict[str, str | list[str]]]:
5356
"""Remove tag `tag` from setup with id `setup_id`."""
54-
if not await database.setups.get(setup_id, expdb_db):
57+
setup, setup_tags = await asyncio.gather(
58+
database.setups.get(setup_id, expdb_db),
59+
database.setups.get_tags(setup_id, expdb_db),
60+
)
61+
if not setup:
5562
msg = f"Setup {setup_id} not found."
5663
raise SetupNotFoundError(msg)
57-
58-
setup_tags = await database.setups.get_tags(setup_id, expdb_db)
5964
matched_tag_row = next((t for t in setup_tags if t.tag.casefold() == tag.casefold()), None)
6065

6166
if not matched_tag_row:

src/routers/openml/tasks.py

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import asyncio
12
import json
23
import re
34
from typing import Annotated, cast
@@ -169,23 +170,30 @@ async def get_task(
169170
msg = f"Task {task_id} has task type {task.ttid}, but task type {task.ttid} is not found."
170171
raise InternalError(msg)
171172

173+
task_input_rows, ttios, tags = await asyncio.gather(
174+
database.tasks.get_input_for_task(task_id, expdb),
175+
database.tasks.get_task_type_inout_with_template(task_type.ttid, expdb),
176+
database.tasks.get_tags(task_id, expdb),
177+
)
172178
task_inputs = {
173-
row.input: int(row.value) if row.value.isdigit() else row.value
174-
for row in await database.tasks.get_input_for_task(task_id, expdb)
179+
row.input: int(row.value) if row.value.isdigit() else row.value for row in task_input_rows
175180
}
176-
ttios = await database.tasks.get_task_type_inout_with_template(task_type.ttid, expdb)
177181
templates = [(tt_io.name, tt_io.io, tt_io.requirement, tt_io.template_api) for tt_io in ttios]
182+
input_templates = [
183+
(name, template) for name, io, required, template in templates if io == "input"
184+
]
185+
filled_templates = await asyncio.gather(
186+
*[fill_template(template, task, task_inputs, expdb) for name, template in input_templates],
187+
)
178188
inputs = [
179-
await fill_template(template, task, task_inputs, expdb) | {"name": name}
180-
for name, io, required, template in templates
181-
if io == "input"
189+
filled | {"name": name}
190+
for (name, _), filled in zip(input_templates, filled_templates, strict=True)
182191
]
183192
outputs = [
184193
convert_template_xml_to_json(template) | {"name": name}
185194
for name, io, required, template in templates
186195
if io == "output"
187196
]
188-
tags = await database.tasks.get_tags(task_id, expdb)
189197
name = f"Task {task_id} ({task_type.name})"
190198
dataset_id = task_inputs.get("source_data")
191199
if isinstance(dataset_id, int) and (dataset := await database.datasets.get(dataset_id, expdb)):

tests/routers/openml/migration/setups_migration_test.py

Lines changed: 38 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import asyncio
12
import contextlib
23
import re
34
from collections.abc import AsyncGenerator, Callable, Iterable
@@ -114,14 +115,15 @@ async def test_setup_untag_response_is_identical_setup_doesnt_exist(
114115
tag = "totally_new_tag_for_migration_testing"
115116
api_key = ApiKey.SOME_USER
116117

117-
original = await php_api.post(
118-
"/setup/untag",
119-
data={"api_key": api_key, "tag": tag, "setup_id": setup_id},
120-
)
121-
122-
new = await py_api.post(
123-
f"/setup/untag?api_key={api_key}",
124-
json={"setup_id": setup_id, "tag": tag},
118+
original, new = await asyncio.gather(
119+
php_api.post(
120+
"/setup/untag",
121+
data={"api_key": api_key, "tag": tag, "setup_id": setup_id},
122+
),
123+
py_api.post(
124+
f"/setup/untag?api_key={api_key}",
125+
json={"setup_id": setup_id, "tag": tag},
126+
),
125127
)
126128

127129
assert original.status_code == HTTPStatus.PRECONDITION_FAILED
@@ -142,14 +144,15 @@ async def test_setup_untag_response_is_identical_tag_doesnt_exist(
142144
tag = "totally_new_tag_for_migration_testing"
143145
api_key = ApiKey.SOME_USER
144146

145-
original = await php_api.post(
146-
"/setup/untag",
147-
data={"api_key": api_key, "tag": tag, "setup_id": setup_id},
148-
)
149-
150-
new = await py_api.post(
151-
f"/setup/untag?api_key={api_key}",
152-
json={"setup_id": setup_id, "tag": tag},
147+
original, new = await asyncio.gather(
148+
php_api.post(
149+
"/setup/untag",
150+
data={"api_key": api_key, "tag": tag, "setup_id": setup_id},
151+
),
152+
py_api.post(
153+
f"/setup/untag?api_key={api_key}",
154+
json={"setup_id": setup_id, "tag": tag},
155+
),
153156
)
154157

155158
assert original.status_code == HTTPStatus.PRECONDITION_FAILED
@@ -223,14 +226,15 @@ async def test_setup_tag_response_is_identical_setup_doesnt_exist(
223226
tag = "totally_new_tag_for_migration_testing"
224227
api_key = ApiKey.SOME_USER
225228

226-
original = await php_api.post(
227-
"/setup/tag",
228-
data={"api_key": api_key, "tag": tag, "setup_id": setup_id},
229-
)
230-
231-
new = await py_api.post(
232-
f"/setup/tag?api_key={api_key}",
233-
json={"setup_id": setup_id, "tag": tag},
229+
original, new = await asyncio.gather(
230+
php_api.post(
231+
"/setup/tag",
232+
data={"api_key": api_key, "tag": tag, "setup_id": setup_id},
233+
),
234+
py_api.post(
235+
f"/setup/tag?api_key={api_key}",
236+
json={"setup_id": setup_id, "tag": tag},
237+
),
234238
)
235239

236240
assert original.status_code == HTTPStatus.PRECONDITION_FAILED
@@ -253,15 +257,16 @@ async def test_setup_tag_response_is_identical_tag_already_exists(
253257
api_key = ApiKey.SOME_USER
254258

255259
async with temporary_tags(tags=[tag], setup_id=setup_id, persist=True):
256-
original = await php_api.post(
257-
"/setup/tag",
258-
data={"api_key": api_key, "tag": tag, "setup_id": setup_id},
259-
)
260-
261-
# In Python, since PHP committed it, it's also there for Python test context
262-
new = await py_api.post(
263-
f"/setup/tag?api_key={api_key}",
264-
json={"setup_id": setup_id, "tag": tag},
260+
# Both APIs can be tested in parallel since the tag is already persisted
261+
original, new = await asyncio.gather(
262+
php_api.post(
263+
"/setup/tag",
264+
data={"api_key": api_key, "tag": tag, "setup_id": setup_id},
265+
),
266+
py_api.post(
267+
f"/setup/tag?api_key={api_key}",
268+
json={"setup_id": setup_id, "tag": tag},
269+
),
265270
)
266271

267272
assert original.status_code == HTTPStatus.INTERNAL_SERVER_ERROR

0 commit comments

Comments
 (0)