1111from firebolt .client .auth import Auth
1212from firebolt .client .client import AsyncClient , AsyncClientV1 , AsyncClientV2
1313from firebolt .common .base_connection import (
14+ ASYNC_QUERY_CANCEL ,
1415 ASYNC_QUERY_STATUS_REQUEST ,
1516 ASYNC_QUERY_STATUS_RUNNING ,
1617 ASYNC_QUERY_STATUS_SUCCESSFUL ,
18+ AsyncQueryInfo ,
1719 BaseConnection ,
1820)
1921from firebolt .common .cache import _firebolt_system_engine_cache
@@ -90,19 +92,33 @@ def cursor(self, **kwargs: Any) -> Cursor:
9092 return c
9193
9294 # Server-side async methods
93- async def _get_async_query_status (self , token : str ) -> str :
95+ async def _get_async_query_info (self , token : str ) -> AsyncQueryInfo :
9496 if self .cursor_type != CursorV2 :
9597 raise FireboltError (
9698 "This method is only supported for connection with service account."
9799 )
98100 cursor = self .cursor ()
99- await cursor .execute (ASYNC_QUERY_STATUS_REQUEST . format ( token = token ) )
101+ await cursor .execute (ASYNC_QUERY_STATUS_REQUEST , [ token ] )
100102 result = await cursor .fetchone ()
101103 if cursor .rowcount != 1 or not result :
102104 raise FireboltError ("Unexpected result from async query status request." )
103105 columns = cursor .description
104106 result_dict = dict (zip ([column .name for column in columns ], result ))
105- return str (result_dict .get ("status" ))
107+
108+ if not result_dict .get ("status" ) or not result_dict .get ("query_id" ):
109+ raise FireboltError (
110+ "Something went wrong - async query status request returned "
111+ "unexpected result with status and/or query id missing. "
112+ "Rerun the command and reach out to Firebolt support if "
113+ "the issue persists."
114+ )
115+
116+ # Only pass the expected keys to AsyncQueryInfo
117+ filtered_result_dict = {
118+ k : v for k , v in result_dict .items () if k in AsyncQueryInfo ._fields
119+ }
120+
121+ return AsyncQueryInfo (** filtered_result_dict )
106122
107123 async def is_async_query_running (self , token : str ) -> bool :
108124 """
@@ -114,8 +130,8 @@ async def is_async_query_running(self, token: str) -> bool:
114130 Returns:
115131 bool: True if async query is still running, False otherwise
116132 """
117- status = await self ._get_async_query_status (token )
118- return status == ASYNC_QUERY_STATUS_RUNNING
133+ async_query_details = await self ._get_async_query_info (token )
134+ return async_query_details . status == ASYNC_QUERY_STATUS_RUNNING
119135
120136 async def is_async_query_successful (self , token : str ) -> Optional [bool ]:
121137 """
@@ -128,10 +144,22 @@ async def is_async_query_successful(self, token: str) -> Optional[bool]:
128144 bool: None if the query is still running, True if successful,
129145 False otherwise
130146 """
131- status = await self ._get_async_query_status (token )
132- if status == ASYNC_QUERY_STATUS_RUNNING :
147+ async_query_details = await self ._get_async_query_info (token )
148+ if async_query_details . status == ASYNC_QUERY_STATUS_RUNNING :
133149 return None
134- return status == ASYNC_QUERY_STATUS_SUCCESSFUL
150+ return async_query_details .status == ASYNC_QUERY_STATUS_SUCCESSFUL
151+
152+ async def cancel_async_query (self , token : str ) -> None :
153+ """
154+ Cancel an async query.
155+
156+ Args:
157+ token: Async query token. Can be obtained from Cursor.async_query_token.
158+ """
159+ async_query_details = await self ._get_async_query_info (token )
160+ async_query_id = async_query_details .query_id
161+ cursor = self .cursor ()
162+ await cursor .execute (ASYNC_QUERY_CANCEL , [async_query_id ])
135163
136164 # Context manager support
137165 async def __aenter__ (self ) -> Connection :
0 commit comments