Skip to content

Commit 3d92df9

Browse files
authored
feat(FIR-43324): async cancellation method (#420)
1 parent 4d7008e commit 3d92df9

File tree

9 files changed

+342
-16
lines changed

9 files changed

+342
-16
lines changed

docsrc/Connecting_and_queries.rst

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -654,6 +654,23 @@ has finished successfully, None if query is still running and False if the query
654654
else:
655655
print("Query failed")
656656

657+
Cancelling a running query
658+
--------------------------
659+
660+
To cancel a running query, use the :py:meth:`firebolt.db.connection.Connection.cancel_async_query` method. This method
661+
will send a cancel request to the server and the query will be stopped.
662+
663+
::
664+
665+
token = cursor.async_query_token
666+
connection.cancel_async_query(token)
667+
668+
# Verify that the query was cancelled
669+
running = connection.is_async_query_running(token)
670+
print(running) # False
671+
successful = connection.is_async_query_successful(token)
672+
print(successful) # False
673+
657674

658675
Thread safety
659676
==============================

src/firebolt/async_db/connection.py

Lines changed: 36 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,11 @@
1111
from firebolt.client.auth import Auth
1212
from firebolt.client.client import AsyncClient, AsyncClientV1, AsyncClientV2
1313
from firebolt.common.base_connection import (
14+
ASYNC_QUERY_CANCEL,
1415
ASYNC_QUERY_STATUS_REQUEST,
1516
ASYNC_QUERY_STATUS_RUNNING,
1617
ASYNC_QUERY_STATUS_SUCCESSFUL,
18+
AsyncQueryInfo,
1719
BaseConnection,
1820
)
1921
from firebolt.common.cache import _firebolt_system_engine_cache
@@ -90,19 +92,33 @@ def cursor(self, **kwargs: Any) -> Cursor:
9092
return c
9193

9294
# Server-side async methods
93-
async def _get_async_query_status(self, token: str) -> str:
95+
async def _get_async_query_info(self, token: str) -> AsyncQueryInfo:
9496
if self.cursor_type != CursorV2:
9597
raise FireboltError(
9698
"This method is only supported for connection with service account."
9799
)
98100
cursor = self.cursor()
99-
await cursor.execute(ASYNC_QUERY_STATUS_REQUEST.format(token=token))
101+
await cursor.execute(ASYNC_QUERY_STATUS_REQUEST, [token])
100102
result = await cursor.fetchone()
101103
if cursor.rowcount != 1 or not result:
102104
raise FireboltError("Unexpected result from async query status request.")
103105
columns = cursor.description
104106
result_dict = dict(zip([column.name for column in columns], result))
105-
return str(result_dict.get("status"))
107+
108+
if not result_dict.get("status") or not result_dict.get("query_id"):
109+
raise FireboltError(
110+
"Something went wrong - async query status request returned "
111+
"unexpected result with status and/or query id missing. "
112+
"Rerun the command and reach out to Firebolt support if "
113+
"the issue persists."
114+
)
115+
116+
# Only pass the expected keys to AsyncQueryInfo
117+
filtered_result_dict = {
118+
k: v for k, v in result_dict.items() if k in AsyncQueryInfo._fields
119+
}
120+
121+
return AsyncQueryInfo(**filtered_result_dict)
106122

107123
async def is_async_query_running(self, token: str) -> bool:
108124
"""
@@ -114,8 +130,8 @@ async def is_async_query_running(self, token: str) -> bool:
114130
Returns:
115131
bool: True if async query is still running, False otherwise
116132
"""
117-
status = await self._get_async_query_status(token)
118-
return status == ASYNC_QUERY_STATUS_RUNNING
133+
async_query_details = await self._get_async_query_info(token)
134+
return async_query_details.status == ASYNC_QUERY_STATUS_RUNNING
119135

120136
async def is_async_query_successful(self, token: str) -> Optional[bool]:
121137
"""
@@ -128,10 +144,22 @@ async def is_async_query_successful(self, token: str) -> Optional[bool]:
128144
bool: None if the query is still running, True if successful,
129145
False otherwise
130146
"""
131-
status = await self._get_async_query_status(token)
132-
if status == ASYNC_QUERY_STATUS_RUNNING:
147+
async_query_details = await self._get_async_query_info(token)
148+
if async_query_details.status == ASYNC_QUERY_STATUS_RUNNING:
133149
return None
134-
return status == ASYNC_QUERY_STATUS_SUCCESSFUL
150+
return async_query_details.status == ASYNC_QUERY_STATUS_SUCCESSFUL
151+
152+
async def cancel_async_query(self, token: str) -> None:
153+
"""
154+
Cancel an async query.
155+
156+
Args:
157+
token: Async query token. Can be obtained from Cursor.async_query_token.
158+
"""
159+
async_query_details = await self._get_async_query_info(token)
160+
async_query_id = async_query_details.query_id
161+
cursor = self.cursor()
162+
await cursor.execute(ASYNC_QUERY_CANCEL, [async_query_id])
135163

136164
# Context manager support
137165
async def __aenter__(self) -> Connection:

src/firebolt/common/base_connection.py

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,30 @@
1+
from collections import namedtuple
12
from typing import Any, List, Type
23

34
from firebolt.utils.exception import ConnectionClosedError
45

56
ASYNC_QUERY_STATUS_RUNNING = "RUNNING"
67
ASYNC_QUERY_STATUS_SUCCESSFUL = "ENDED_SUCCESSFULLY"
7-
ASYNC_QUERY_STATUS_REQUEST = "CALL fb_GetAsyncStatus('{token}')"
8+
ASYNC_QUERY_STATUS_REQUEST = "CALL fb_GetAsyncStatus(?)"
9+
ASYNC_QUERY_CANCEL = "CANCEL QUERY WHERE query_id=?"
10+
11+
AsyncQueryInfo = namedtuple(
12+
"AsyncQueryInfo",
13+
[
14+
"account_name",
15+
"user_name",
16+
"submitted_time",
17+
"start_time",
18+
"end_time",
19+
"status",
20+
"request_id",
21+
"query_id",
22+
"error_message",
23+
"scanned_bytes",
24+
"scanned_rows",
25+
"retries",
26+
],
27+
)
828

929

1030
class BaseConnection:

src/firebolt/db/connection.py

Lines changed: 35 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,11 @@
1010
from firebolt.client import DEFAULT_API_URL, Client, ClientV1, ClientV2
1111
from firebolt.client.auth import Auth
1212
from firebolt.common.base_connection import (
13+
ASYNC_QUERY_CANCEL,
1314
ASYNC_QUERY_STATUS_REQUEST,
1415
ASYNC_QUERY_STATUS_RUNNING,
1516
ASYNC_QUERY_STATUS_SUCCESSFUL,
17+
AsyncQueryInfo,
1618
BaseConnection,
1719
)
1820
from firebolt.common.cache import _firebolt_system_engine_cache
@@ -227,19 +229,34 @@ def close(self) -> None:
227229
self._is_closed = True
228230

229231
# Server-side async methods
230-
def _get_async_query_status(self, token: str) -> str:
232+
233+
def _get_async_query_info(self, token: str) -> AsyncQueryInfo:
231234
if self.cursor_type != CursorV2:
232235
raise FireboltError(
233236
"This method is only supported for connection with service account."
234237
)
235238
cursor = self.cursor()
236-
cursor.execute(ASYNC_QUERY_STATUS_REQUEST.format(token=token))
239+
cursor.execute(ASYNC_QUERY_STATUS_REQUEST, [token])
237240
result = cursor.fetchone()
238241
if cursor.rowcount != 1 or not result:
239242
raise FireboltError("Unexpected result from async query status request.")
240243
columns = cursor.description
241244
result_dict = dict(zip([column.name for column in columns], result))
242-
return result_dict["status"]
245+
246+
if not result_dict.get("status") or not result_dict.get("query_id"):
247+
raise FireboltError(
248+
"Something went wrong - async query status request returned "
249+
"unexpected result with status and/or query id missing. "
250+
"Rerun the command and reach out to Firebolt support if "
251+
"the issue persists."
252+
)
253+
254+
# Only pass the expected keys to AsyncQueryInfo
255+
filtered_result_dict = {
256+
k: v for k, v in result_dict.items() if k in AsyncQueryInfo._fields
257+
}
258+
259+
return AsyncQueryInfo(**filtered_result_dict)
243260

244261
def is_async_query_running(self, token: str) -> bool:
245262
"""
@@ -251,7 +268,7 @@ def is_async_query_running(self, token: str) -> bool:
251268
Returns:
252269
bool: True if async query is still running, False otherwise
253270
"""
254-
return self._get_async_query_status(token) == ASYNC_QUERY_STATUS_RUNNING
271+
return self._get_async_query_info(token).status == ASYNC_QUERY_STATUS_RUNNING
255272

256273
def is_async_query_successful(self, token: str) -> Optional[bool]:
257274
"""
@@ -264,10 +281,21 @@ def is_async_query_successful(self, token: str) -> Optional[bool]:
264281
bool: None if the query is still running, True if successful,
265282
False otherwise
266283
"""
267-
status = self._get_async_query_status(token)
268-
if status == ASYNC_QUERY_STATUS_RUNNING:
284+
async_query_info = self._get_async_query_info(token)
285+
if async_query_info.status == ASYNC_QUERY_STATUS_RUNNING:
269286
return None
270-
return status == ASYNC_QUERY_STATUS_SUCCESSFUL
287+
return async_query_info.status == ASYNC_QUERY_STATUS_SUCCESSFUL
288+
289+
def cancel_async_query(self, token: str) -> None:
290+
"""
291+
Cancel an async query.
292+
293+
Args:
294+
token: Async query token. Can be obtained from Cursor.async_query_token.
295+
"""
296+
async_query_id = self._get_async_query_info(token).query_id
297+
cursor = self.cursor()
298+
cursor.execute(ASYNC_QUERY_CANCEL, [async_query_id])
271299

272300
# Context manager support
273301
def __enter__(self) -> Connection:

tests/integration/dbapi/async/V2/test_server_async.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -102,3 +102,23 @@ async def test_check_async_execution_fails(connection: Connection) -> None:
102102
await cursor.execute_async(f"MALFORMED QUERY")
103103
with raises(FireboltError):
104104
cursor.async_query_token
105+
106+
107+
async def test_cancel_async_query(connection: Connection) -> None:
108+
cursor = connection.cursor()
109+
rnd_suffix = str(randint(0, 1000))
110+
table_name = f"test_insert_async_{rnd_suffix}"
111+
try:
112+
await cursor.execute(f"CREATE TABLE {table_name} (id LONG)")
113+
await cursor.execute_async(f"INSERT INTO {table_name} {LONG_SELECT}")
114+
token = cursor.async_query_token
115+
assert token is not None, "Async token was not returned"
116+
assert await connection.is_async_query_running(token) == True
117+
await connection.cancel_async_query(token)
118+
assert await connection.is_async_query_running(token) == False
119+
assert await connection.is_async_query_successful(token) == False
120+
await cursor.execute(f"SELECT * FROM {table_name}")
121+
result = await cursor.fetchall()
122+
assert result == []
123+
finally:
124+
await cursor.execute(f"DROP TABLE {table_name}")

tests/integration/dbapi/sync/V2/test_server_async.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,3 +98,23 @@ def test_check_async_execution_fails(connection: Connection) -> None:
9898
cursor.execute_async(f"MALFORMED QUERY")
9999
with raises(FireboltError):
100100
cursor.async_query_token
101+
102+
103+
def test_cancel_async_query(connection: Connection) -> None:
104+
cursor = connection.cursor()
105+
rnd_suffix = str(randint(0, 1000))
106+
table_name = f"test_insert_async_{rnd_suffix}"
107+
try:
108+
cursor.execute(f"CREATE TABLE {table_name} (id LONG)")
109+
cursor.execute_async(f"INSERT INTO {table_name} {LONG_SELECT}")
110+
token = cursor.async_query_token
111+
assert token is not None, "Async token was not returned"
112+
assert connection.is_async_query_running(token) == True
113+
connection.cancel_async_query(token)
114+
assert connection.is_async_query_running(token) == False
115+
assert connection.is_async_query_successful(token) == False
116+
cursor.execute(f"SELECT * FROM {table_name}")
117+
result = cursor.fetchall()
118+
assert result == []
119+
finally:
120+
cursor.execute(f"DROP TABLE {table_name}")

tests/unit/async_db/test_connection.py

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -492,3 +492,89 @@ async def test_async_query_status_unexpected_result(
492492
await connection.is_async_query_running("token")
493493
with raises(FireboltError):
494494
await connection.is_async_query_successful("token")
495+
496+
497+
async def test_async_query_status_no_id_or_status(
498+
db_name: str,
499+
account_name: str,
500+
engine_name: str,
501+
auth: Auth,
502+
api_endpoint: str,
503+
httpx_mock: HTTPXMock,
504+
query_url: str,
505+
async_query_callback_factory: Callable,
506+
async_query_meta: List[Tuple[str, str]],
507+
async_query_data: List[List[ColType]],
508+
mock_connection_flow: Callable,
509+
):
510+
mock_connection_flow()
511+
data_no_query_id = async_query_data[0].copy()
512+
data_no_query_id[7] = ""
513+
data_no_query_status = async_query_data[0].copy()
514+
data_no_query_status[5] = ""
515+
for data_case in [data_no_query_id, data_no_query_status]:
516+
async_query_status_running_callback = async_query_callback_factory(
517+
[data_case], async_query_meta
518+
)
519+
httpx_mock.add_callback(
520+
async_query_status_running_callback,
521+
url=query_url,
522+
match_content="CALL fb_GetAsyncStatus('token')".encode("utf-8"),
523+
)
524+
async with await connect(
525+
database=db_name,
526+
auth=auth,
527+
engine_name=engine_name,
528+
account_name=account_name,
529+
api_endpoint=api_endpoint,
530+
) as connection:
531+
with raises(FireboltError):
532+
await connection.is_async_query_running("token")
533+
with raises(FireboltError):
534+
await connection.is_async_query_successful("token")
535+
536+
537+
async def test_async_query_cancellation(
538+
db_name: str,
539+
account_name: str,
540+
engine_name: str,
541+
auth: Auth,
542+
api_endpoint: str,
543+
httpx_mock: HTTPXMock,
544+
query_url: str,
545+
query_callback: Callable,
546+
async_query_callback_factory: Callable,
547+
async_query_data: List[List[ColType]],
548+
async_query_meta: List[Tuple[str, str]],
549+
mock_connection_flow: Callable,
550+
):
551+
"""Test async query cancellation"""
552+
mock_connection_flow()
553+
async_query_data[0][5] = "RUNNING"
554+
async_query_status_running_callback = async_query_callback_factory(
555+
async_query_data, async_query_meta
556+
)
557+
558+
query_dict = dict(zip([m[0] for m in async_query_meta], async_query_data[0]))
559+
query_id = query_dict["query_id"]
560+
561+
httpx_mock.add_callback(
562+
async_query_status_running_callback,
563+
url=query_url,
564+
match_content="CALL fb_GetAsyncStatus('token')".encode("utf-8"),
565+
)
566+
567+
httpx_mock.add_callback(
568+
query_callback,
569+
url=query_url,
570+
match_content=f"CANCEL QUERY WHERE query_id='{query_id}'".encode("utf-8"),
571+
)
572+
573+
async with await connect(
574+
database=db_name,
575+
auth=auth,
576+
engine_name=engine_name,
577+
account_name=account_name,
578+
api_endpoint=api_endpoint,
579+
) as connection:
580+
await connection.cancel_async_query("token")

0 commit comments

Comments
 (0)