Skip to content

Commit 5357b01

Browse files
igennovaigennova
andauthored
[ENH] Add POST /setup/tag endpoint (#271)
Fixes: #64 - Implemented the POST /setup/tag FastAPI endpoint and underlying database logic. - Modified the response to return the full array of remaining tags to match the old PHP API behavior. - Added comprehensive unit tests and `py_api` vs `php_api migration` parity tests. - Updated `docs/migration.md` to document the new tag array response behavior. --------- Co-authored-by: igennova <luckynegi025@gmail.com>
1 parent c9ada15 commit 5357b01

6 files changed

Lines changed: 268 additions & 27 deletions

File tree

docs/migration.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,11 @@ For example, after tagging dataset 21 with the tag `"foo"`:
107107
}
108108
```
109109

110+
## Setups
111+
112+
### `POST /setup/tag` and `POST /setup/untag`
113+
When successful, the "tag" property in the returned response is now always a list, even if only one tag exists for the entity. When removing the last tag, the "tag" property will be an empty list `[]` instead of being omitted from the response.
114+
110115
## Studies
111116

112117
### `GET /{id_or_alias}`

src/database/setups.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,3 +46,16 @@ async def untag(setup_id: int, tag: str, connection: AsyncConnection) -> None:
4646
),
4747
parameters={"setup_id": setup_id, "tag": tag},
4848
)
49+
50+
51+
async def tag(setup_id: int, tag: str, user_id: int, connection: AsyncConnection) -> None:
52+
"""Add tag `tag` to setup with id `setup_id`."""
53+
await connection.execute(
54+
text(
55+
"""
56+
INSERT INTO setup_tag (id, tag, uploader)
57+
VALUES (:setup_id, :tag, :user_id)
58+
""",
59+
),
60+
parameters={"setup_id": setup_id, "tag": tag, "user_id": user_id},
61+
)

src/routers/openml/setups.py

Lines changed: 33 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,14 +6,43 @@
66
from sqlalchemy.ext.asyncio import AsyncConnection
77

88
import database.setups
9-
from core.errors import SetupNotFoundError, TagNotFoundError, TagNotOwnedError
9+
from core.errors import (
10+
SetupNotFoundError,
11+
TagAlreadyExistsError,
12+
TagNotFoundError,
13+
TagNotOwnedError,
14+
)
1015
from database.users import User, UserGroup
1116
from routers.dependencies import expdb_connection, fetch_user_or_raise
1217
from routers.types import SystemString64
1318

1419
router = APIRouter(prefix="/setup", tags=["setup"])
1520

1621

22+
@router.post(path="/tag")
23+
async def tag_setup(
24+
setup_id: Annotated[int, Body()],
25+
tag: Annotated[str, SystemString64],
26+
user: Annotated[User, Depends(fetch_user_or_raise)],
27+
expdb_db: Annotated[AsyncConnection, Depends(expdb_connection)],
28+
) -> dict[str, dict[str, str | list[str]]]:
29+
"""Add tag `tag` to setup with id `setup_id`."""
30+
if not await database.setups.get(setup_id, expdb_db):
31+
msg = f"Setup {setup_id} not found."
32+
raise SetupNotFoundError(msg)
33+
34+
setup_tags = await database.setups.get_tags(setup_id, expdb_db)
35+
matched_tag_row = next((t for t in setup_tags if t.tag.casefold() == tag.casefold()), None)
36+
37+
if matched_tag_row:
38+
msg = f"Setup {setup_id} already has tag {tag!r}."
39+
raise TagAlreadyExistsError(msg)
40+
41+
await database.setups.tag(setup_id, tag, user.user_id, expdb_db)
42+
all_tags = [t.tag for t in setup_tags] + [tag]
43+
return {"setup_tag": {"id": str(setup_id), "tag": all_tags}}
44+
45+
1746
@router.post(path="/untag")
1847
async def untag_setup(
1948
setup_id: Annotated[int, Body()],
@@ -40,5 +69,7 @@ async def untag_setup(
4069
raise TagNotOwnedError(msg)
4170

4271
await database.setups.untag(setup_id, matched_tag_row.tag, expdb_db)
43-
remaining_tags = [t.tag.casefold() for t in setup_tags if t != matched_tag_row]
72+
remaining_tags = [
73+
t.tag for t in setup_tags if t.tag.casefold() != matched_tag_row.tag.casefold()
74+
]
4475
return {"setup_untag": {"id": str(setup_id), "tag": remaining_tags}}

tests/conftest.py

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import contextlib
22
import json
3-
from collections.abc import AsyncIterator, Iterator
3+
from collections.abc import AsyncIterator, Iterable, Iterator
44
from pathlib import Path
55
from typing import Any, NamedTuple
66

@@ -28,6 +28,29 @@ async def automatic_rollback(engine: AsyncEngine) -> AsyncIterator[AsyncConnecti
2828
await transaction.rollback()
2929

3030

31+
@contextlib.asynccontextmanager
32+
async def temporary_records(
33+
connection: AsyncConnection,
34+
insert_queries: Iterable[tuple[str, dict[str, Any] | None]],
35+
delete_queries: Iterable[tuple[str, dict[str, Any] | None]],
36+
*,
37+
persist: bool = False,
38+
) -> AsyncIterator[None]:
39+
"""Execute insert queries on enter and their corresponding delete queries on exit."""
40+
for query, parameters in insert_queries:
41+
await connection.execute(text(query), parameters=parameters)
42+
if persist:
43+
await connection.commit()
44+
45+
try:
46+
yield
47+
finally:
48+
for query, parameters in delete_queries:
49+
await connection.execute(text(query), parameters=parameters)
50+
if persist:
51+
await connection.commit()
52+
53+
3154
@pytest.fixture
3255
async def expdb_test() -> AsyncIterator[AsyncConnection]:
3356
async with automatic_rollback(expdb_database()) as connection:

tests/routers/openml/migration/setups_migration_test.py

Lines changed: 146 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,51 @@
11
import contextlib
22
import re
3-
from collections.abc import AsyncGenerator, Iterable
3+
from collections.abc import AsyncGenerator, Callable, Iterable
4+
from contextlib import AbstractAsyncContextManager
45
from http import HTTPStatus
56

67
import httpx
78
import pytest
89
from sqlalchemy import text
910
from sqlalchemy.ext.asyncio import AsyncConnection
1011

12+
from tests.conftest import temporary_records
1113
from tests.users import OWNER_USER, ApiKey
1214

1315

16+
@pytest.fixture
17+
def temporary_tags(
18+
expdb_test: AsyncConnection,
19+
) -> Callable[..., AbstractAsyncContextManager[None]]:
20+
@contextlib.asynccontextmanager
21+
async def _temporary_tags(
22+
tags: Iterable[str], setup_id: int, *, persist: bool = False
23+
) -> AsyncGenerator[None]:
24+
insert_queries = [
25+
(
26+
"INSERT INTO setup_tag(`id`,`tag`,`uploader`) VALUES (:setup_id, :tag, :user_id);",
27+
{"setup_id": setup_id, "tag": tag, "user_id": OWNER_USER.user_id},
28+
)
29+
for tag in tags
30+
]
31+
delete_queries = [
32+
(
33+
"DELETE FROM setup_tag WHERE `id`=:setup_id AND `tag`=:tag",
34+
{"setup_id": setup_id, "tag": tag},
35+
)
36+
for tag in tags
37+
]
38+
async with temporary_records(
39+
connection=expdb_test,
40+
insert_queries=insert_queries,
41+
delete_queries=delete_queries,
42+
persist=persist,
43+
):
44+
yield
45+
46+
return _temporary_tags
47+
48+
1449
@pytest.mark.parametrize(
1550
"api_key",
1651
[ApiKey.ADMIN, ApiKey.SOME_USER, ApiKey.OWNER_USER],
@@ -26,33 +61,11 @@ async def test_setup_untag_response_is_identical_when_tag_exists(
2661
other_tags: list[str],
2762
py_api: httpx.AsyncClient,
2863
php_api: httpx.AsyncClient,
29-
expdb_test: AsyncConnection,
64+
temporary_tags: Callable[..., AbstractAsyncContextManager[None]],
3065
) -> None:
3166
setup_id = 1
3267
tag = "totally_new_tag_for_migration_testing"
3368

34-
@contextlib.asynccontextmanager
35-
async def temporary_tags(
36-
tags: Iterable[str], setup_id: int, *, persist: bool = False
37-
) -> AsyncGenerator[None]:
38-
for tag in tags:
39-
await expdb_test.execute(
40-
text(
41-
"INSERT INTO setup_tag(`id`,`tag`,`uploader`) VALUES (:setup_id, :tag, :user_id);" # noqa: E501
42-
),
43-
parameters={"setup_id": setup_id, "tag": tag, "user_id": OWNER_USER.user_id},
44-
)
45-
if persist:
46-
await expdb_test.commit()
47-
yield
48-
for tag in tags:
49-
await expdb_test.execute(
50-
text("DELETE FROM setup_tag WHERE `id`=:setup_id AND `tag`=:tag"),
51-
parameters={"setup_id": setup_id, "tag": tag},
52-
)
53-
if persist:
54-
await expdb_test.commit()
55-
5669
all_tags = [tag, *other_tags]
5770
async with temporary_tags(tags=all_tags, setup_id=setup_id, persist=True):
5871
original = await php_api.post(
@@ -147,3 +160,112 @@ async def test_setup_untag_response_is_identical_tag_doesnt_exist(
147160
r"Setup \d+ does not have tag '\S+'.",
148161
new.json()["detail"],
149162
)
163+
164+
165+
@pytest.mark.parametrize(
166+
"api_key",
167+
[ApiKey.ADMIN, ApiKey.SOME_USER],
168+
ids=["Administrator", "non-owner"],
169+
)
170+
@pytest.mark.parametrize(
171+
"other_tags",
172+
[[], ["some_other_tag"], ["foo_some_other_tag", "bar_some_other_tag"]],
173+
ids=["none", "one tag", "two tags"],
174+
)
175+
async def test_setup_tag_response_is_identical_when_tag_doesnt_exist( # noqa: PLR0913
176+
api_key: str,
177+
other_tags: list[str],
178+
py_api: httpx.AsyncClient,
179+
php_api: httpx.AsyncClient,
180+
expdb_test: AsyncConnection,
181+
temporary_tags: Callable[..., AbstractAsyncContextManager[None]],
182+
) -> None:
183+
setup_id = 1
184+
tag = "totally_new_tag_for_migration_testing"
185+
186+
async with temporary_tags(tags=other_tags, setup_id=setup_id, persist=True):
187+
original = await php_api.post(
188+
"/setup/tag",
189+
data={"api_key": api_key, "tag": tag, "setup_id": setup_id},
190+
)
191+
192+
await expdb_test.execute(
193+
text("DELETE FROM setup_tag WHERE `id`=:setup_id AND `tag`=:tag"),
194+
parameters={"setup_id": setup_id, "tag": tag},
195+
)
196+
await expdb_test.commit()
197+
198+
async with temporary_tags(tags=other_tags, setup_id=setup_id):
199+
new = await py_api.post(
200+
f"/setup/tag?api_key={api_key}",
201+
json={"setup_id": setup_id, "tag": tag},
202+
)
203+
204+
assert new.status_code == HTTPStatus.OK
205+
assert original.status_code == new.status_code
206+
original_tag = original.json()["setup_tag"]
207+
new_tag = new.json()["setup_tag"]
208+
assert original_tag["id"] == new_tag["id"]
209+
if tags := original_tag.get("tag"):
210+
if isinstance(tags, str):
211+
assert tags == new_tag["tag"][0]
212+
else:
213+
assert set(tags) == set(new_tag["tag"])
214+
else:
215+
assert new_tag["tag"] == []
216+
217+
218+
async def test_setup_tag_response_is_identical_setup_doesnt_exist(
219+
py_api: httpx.AsyncClient,
220+
php_api: httpx.AsyncClient,
221+
) -> None:
222+
setup_id = 999999
223+
tag = "totally_new_tag_for_migration_testing"
224+
api_key = ApiKey.SOME_USER
225+
226+
original = await php_api.post(
227+
"/setup/tag",
228+
data={"api_key": api_key, "tag": tag, "setup_id": setup_id},
229+
)
230+
231+
new = await py_api.post(
232+
f"/setup/tag?api_key={api_key}",
233+
json={"setup_id": setup_id, "tag": tag},
234+
)
235+
236+
assert original.status_code == HTTPStatus.PRECONDITION_FAILED
237+
assert new.status_code == HTTPStatus.NOT_FOUND
238+
assert original.json()["error"]["message"] == "Entity not found."
239+
assert original.json()["error"]["code"] == new.json()["code"]
240+
assert re.match(
241+
r"Setup \d+ not found.",
242+
new.json()["detail"],
243+
)
244+
245+
246+
async def test_setup_tag_response_is_identical_tag_already_exists(
247+
py_api: httpx.AsyncClient,
248+
php_api: httpx.AsyncClient,
249+
temporary_tags: Callable[..., AbstractAsyncContextManager[None]],
250+
) -> None:
251+
setup_id = 1
252+
tag = "totally_new_tag_for_migration_testing"
253+
api_key = ApiKey.SOME_USER
254+
255+
async with temporary_tags(tags=[tag], setup_id=setup_id, persist=True):
256+
original = await php_api.post(
257+
"/setup/tag",
258+
data={"api_key": api_key, "tag": tag, "setup_id": setup_id},
259+
)
260+
261+
# In Python, since PHP committed it, it's also there for Python test context
262+
new = await py_api.post(
263+
f"/setup/tag?api_key={api_key}",
264+
json={"setup_id": setup_id, "tag": tag},
265+
)
266+
267+
assert original.status_code == HTTPStatus.INTERNAL_SERVER_ERROR
268+
assert new.status_code == HTTPStatus.CONFLICT
269+
assert original.json()["error"]["code"] == new.json()["code"]
270+
assert original.json()["error"]["message"] == "Entity already tagged by this tag."
271+
assert new.json()["detail"] == f"Setup {setup_id} already has tag {tag!r}."

tests/routers/openml/setups_test.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,3 +83,50 @@ async def test_setup_untag_success(
8383
text("SELECT * FROM setup_tag WHERE id = 1 AND tag = 'test_success_tag'")
8484
)
8585
assert len(rows.all()) == 0
86+
87+
88+
async def test_setup_tag_missing_auth(py_api: httpx.AsyncClient) -> None:
89+
response = await py_api.post("/setup/tag", json={"setup_id": 1, "tag": "test_tag"})
90+
assert response.status_code == HTTPStatus.UNAUTHORIZED
91+
assert response.json()["code"] == "103"
92+
assert response.json()["detail"] == "Authentication failed"
93+
94+
95+
async def test_setup_tag_unknown_setup(py_api: httpx.AsyncClient) -> None:
96+
response = await py_api.post(
97+
f"/setup/tag?api_key={ApiKey.SOME_USER}",
98+
json={"setup_id": 999999, "tag": "test_tag"},
99+
)
100+
assert response.status_code == HTTPStatus.NOT_FOUND
101+
assert re.match(r"Setup \d+ not found.", response.json()["detail"])
102+
103+
104+
@pytest.mark.mut
105+
async def test_setup_tag_already_exists(
106+
py_api: httpx.AsyncClient, expdb_test: AsyncConnection
107+
) -> None:
108+
await expdb_test.execute(
109+
text("INSERT INTO setup_tag (id, tag, uploader) VALUES (1, 'existing_tag_123', 2);")
110+
)
111+
response = await py_api.post(
112+
f"/setup/tag?api_key={ApiKey.SOME_USER}",
113+
json={"setup_id": 1, "tag": "existing_tag_123"},
114+
)
115+
assert response.status_code == HTTPStatus.CONFLICT
116+
assert response.json()["detail"] == "Setup 1 already has tag 'existing_tag_123'."
117+
118+
119+
@pytest.mark.mut
120+
async def test_setup_tag_success(py_api: httpx.AsyncClient, expdb_test: AsyncConnection) -> None:
121+
response = await py_api.post(
122+
f"/setup/tag?api_key={ApiKey.SOME_USER}",
123+
json={"setup_id": 1, "tag": "my_new_success_tag"},
124+
)
125+
126+
assert response.status_code == HTTPStatus.OK
127+
assert "my_new_success_tag" in response.json()["setup_tag"]["tag"]
128+
129+
rows = await expdb_test.execute(
130+
text("SELECT * FROM setup_tag WHERE id = 1 AND tag = 'my_new_success_tag'")
131+
)
132+
assert len(rows.all()) == 1

0 commit comments

Comments
 (0)