Skip to content

Commit 811c874

Browse files
feat: FIR-42242 allow setting query timeout in python sdk (#410)
1 parent 5dd7ba8 commit 811c874

File tree

19 files changed

+353
-246
lines changed

19 files changed

+353
-246
lines changed

src/firebolt/async_db/cursor.py

Lines changed: 67 additions & 90 deletions
Original file line numberDiff line numberDiff line change
@@ -12,19 +12,23 @@
1212
List,
1313
Optional,
1414
Sequence,
15-
Tuple,
1615
Union,
1716
)
1817
from urllib.parse import urljoin
1918

20-
from httpx import URL, Headers, Response, codes
19+
from httpx import (
20+
URL,
21+
USE_CLIENT_DEFAULT,
22+
Headers,
23+
Response,
24+
TimeoutException,
25+
codes,
26+
)
2127

2228
from firebolt.client.client import AsyncClient, AsyncClientV1, AsyncClientV2
2329
from firebolt.common._types import (
2430
ColType,
25-
Column,
2631
ParameterType,
27-
RawColType,
2832
SetParameter,
2933
split_format_sql,
3034
)
@@ -35,7 +39,7 @@
3539
UPDATE_PARAMETERS_HEADER,
3640
BaseCursor,
3741
CursorState,
38-
Statistics,
42+
RowSet,
3943
_parse_update_endpoint,
4044
_parse_update_parameters,
4145
_raise_if_internal_set_parameter,
@@ -47,7 +51,9 @@
4751
FireboltDatabaseError,
4852
OperationalError,
4953
ProgrammingError,
54+
QueryTimeoutError,
5055
)
56+
from firebolt.utils.timeout_controller import TimeoutController
5157
from firebolt.utils.urls import DATABASES_URL, ENGINES_URL
5258

5359
if TYPE_CHECKING:
@@ -86,15 +92,45 @@ def __init__(
8692
if connection.init_parameters:
8793
self._update_set_parameters(connection.init_parameters)
8894

89-
@abstractmethod
9095
async def _api_request(
9196
self,
9297
query: str = "",
9398
parameters: Optional[dict[str, Any]] = None,
9499
path: str = "",
95100
use_set_parameters: bool = True,
101+
timeout: Optional[float] = None,
96102
) -> Response:
97-
...
103+
"""
104+
Query API, return Response object.
105+
106+
Args:
107+
query (str): SQL query
108+
parameters (Optional[Sequence[ParameterType]]): A sequence of substitution
109+
parameters. Used to replace '?' placeholders inside a query with
110+
actual values. Note: In order to "output_format" dict value, it
111+
must be an empty string. If no value not specified,
112+
JSON_OUTPUT_FORMAT will be used.
113+
path (str): endpoint suffix, for example "cancel" or "status"
114+
use_set_parameters: Optional[bool]: Some queries will fail if additional
115+
set parameters are sent. Setting this to False will allow
116+
self._set_parameters to be ignored.
117+
timeout (Optional[float]): Request execution timeout in seconds
118+
"""
119+
parameters = parameters or {}
120+
if use_set_parameters:
121+
parameters = {**(self._set_parameters or {}), **parameters}
122+
if self.parameters:
123+
parameters = {**self.parameters, **parameters}
124+
try:
125+
return await self._client.request(
126+
url=urljoin(self.engine_url.rstrip("/") + "/", path or ""),
127+
method="POST",
128+
params=parameters,
129+
content=query,
130+
timeout=timeout if timeout is not None else USE_CLIENT_DEFAULT,
131+
)
132+
except TimeoutException:
133+
raise QueryTimeoutError()
98134

99135
async def _raise_if_error(self, resp: Response) -> None:
100136
"""Raise a proper error if any"""
@@ -119,10 +155,14 @@ async def _raise_if_error(self, resp: Response) -> None:
119155
_print_error_body(resp)
120156
resp.raise_for_status()
121157

122-
async def _validate_set_parameter(self, parameter: SetParameter) -> None:
158+
async def _validate_set_parameter(
159+
self, parameter: SetParameter, timeout: Optional[float]
160+
) -> None:
123161
"""Validate parameter by executing simple query with it."""
124162
_raise_if_internal_set_parameter(parameter)
125-
resp = await self._api_request("select 1", {parameter.name: parameter.value})
163+
resp = await self._api_request(
164+
"select 1", {parameter.name: parameter.value}, timeout=timeout
165+
)
126166
# Handle invalid set parameter
127167
if resp.status_code == codes.BAD_REQUEST:
128168
raise OperationalError(resp.text)
@@ -151,29 +191,30 @@ async def _do_execute(
151191
raw_query: str,
152192
parameters: Sequence[Sequence[ParameterType]],
153193
skip_parsing: bool = False,
194+
timeout: Optional[float] = None,
154195
) -> None:
155196
self._reset()
156197
# Allow users to manually skip parsing for performance improvement.
157198
queries: List[Union[SetParameter, str]] = (
158199
[raw_query] if skip_parsing else split_format_sql(raw_query, parameters)
159200
)
201+
timeout_controller = TimeoutController(timeout)
160202
try:
161203
for query in queries:
162204
start_time = time.time()
163205
Cursor._log_query(query)
206+
timeout_controller.raise_if_timeout()
164207

165-
# Define type for mypy
166-
row_set: Tuple[
167-
int,
168-
Optional[List[Column]],
169-
Optional[Statistics],
170-
Optional[List[List[RawColType]]],
171-
] = (-1, None, None, None)
172208
if isinstance(query, SetParameter):
173-
await self._validate_set_parameter(query)
209+
row_set: RowSet = (-1, None, None, None)
210+
await self._validate_set_parameter(
211+
query, timeout_controller.remaining()
212+
)
174213
else:
175214
resp = await self._api_request(
176-
query, {"output_format": JSON_OUTPUT_FORMAT}
215+
query,
216+
{"output_format": JSON_OUTPUT_FORMAT},
217+
timeout=timeout_controller.remaining(),
177218
)
178219
await self._raise_if_error(resp)
179220
await self._parse_response_headers(resp.headers)
@@ -198,6 +239,7 @@ async def execute(
198239
query: str,
199240
parameters: Optional[Sequence[ParameterType]] = None,
200241
skip_parsing: bool = False,
242+
timeout_seconds: Optional[float] = None,
201243
) -> Union[int, str]:
202244
"""Prepare and execute a database query.
203245
@@ -221,19 +263,23 @@ async def execute(
221263
skip_parsing (bool): Flag to disable query parsing. This will
222264
disable parameterized, multi-statement and SET queries,
223265
while improving performance
266+
timeout_seconds (Optional[float]): Query execution timeout in seconds
224267
225268
Returns:
226269
int: Query row count.
227270
"""
228271
params_list = [parameters] if parameters else []
229-
await self._do_execute(query, params_list, skip_parsing)
272+
await self._do_execute(
273+
query, params_list, skip_parsing, timeout=timeout_seconds
274+
)
230275
return self.rowcount
231276

232277
@check_not_closed
233278
async def executemany(
234279
self,
235280
query: str,
236281
parameters_seq: Sequence[Sequence[ParameterType]],
282+
timeout_seconds: Optional[float] = None,
237283
) -> Union[int, str]:
238284
"""Prepare and execute a database query.
239285
@@ -258,11 +304,12 @@ async def executemany(
258304
substitution parameter sets. Used to replace '?' placeholders inside a
259305
query with actual values from each set in a sequence. Resulting queries
260306
for each subset are executed sequentially.
307+
timeout_seconds (Optional[float]): Query execution timeout in seconds.
261308
262309
Returns:
263310
int: Query row count.
264311
"""
265-
await self._do_execute(query, parameters_seq)
312+
await self._do_execute(query, parameters_seq, timeout=timeout_seconds)
266313
return self.rowcount
267314

268315
@abstractmethod
@@ -345,40 +392,6 @@ def __init__(
345392
assert isinstance(client, AsyncClientV2)
346393
super().__init__(*args, client=client, connection=connection, **kwargs)
347394

348-
async def _api_request(
349-
self,
350-
query: str = "",
351-
parameters: Optional[dict[str, Any]] = None,
352-
path: str = "",
353-
use_set_parameters: bool = True,
354-
) -> Response:
355-
"""
356-
Query API, return Response object.
357-
358-
Args:
359-
query (str): SQL query
360-
parameters (Optional[Sequence[ParameterType]]): A sequence of substitution
361-
parameters. Used to replace '?' placeholders inside a query with
362-
actual values. Note: In order to "output_format" dict value, it
363-
must be an empty string. If no value not specified,
364-
JSON_OUTPUT_FORMAT will be used.
365-
path (str): endpoint suffix, for example "cancel" or "status"
366-
use_set_parameters: Optional[bool]: Some queries will fail if additional
367-
set parameters are sent. Setting this to False will allow
368-
self._set_parameters to be ignored.
369-
"""
370-
parameters = parameters or {}
371-
if use_set_parameters:
372-
parameters = {**(self._set_parameters or {}), **parameters}
373-
if self.parameters:
374-
parameters = {**self.parameters, **parameters}
375-
return await self._client.request(
376-
url=urljoin(self.engine_url.rstrip("/") + "/", path or ""),
377-
method="POST",
378-
params=parameters,
379-
content=query,
380-
)
381-
382395
async def is_db_available(self, database_name: str) -> bool:
383396
"""
384397
Verify that the database exists.
@@ -415,42 +428,6 @@ def __init__(
415428
assert isinstance(client, AsyncClientV1)
416429
super().__init__(*args, client=client, connection=connection, **kwargs)
417430

418-
async def _api_request(
419-
self,
420-
query: Optional[str] = "",
421-
parameters: Optional[dict[str, Any]] = None,
422-
path: Optional[str] = "",
423-
use_set_parameters: Optional[bool] = True,
424-
) -> Response:
425-
"""
426-
Query API, return Response object.
427-
428-
Args:
429-
query (str): SQL query
430-
parameters (Optional[Sequence[ParameterType]]): A sequence of substitution
431-
parameters. Used to replace '?' placeholders inside a query with
432-
actual values. Note: In order to "output_format" dict value, it
433-
must be an empty string. If no value not specified,
434-
JSON_OUTPUT_FORMAT will be used.
435-
path (str): endpoint suffix, for example "cancel" or "status"
436-
use_set_parameters: Optional[bool]: Some queries will fail if additional
437-
set parameters are sent. Setting this to False will allow
438-
self._set_parameters to be ignored.
439-
"""
440-
parameters = parameters or {}
441-
if use_set_parameters:
442-
parameters = {**(self._set_parameters or {}), **(parameters or {})}
443-
if self.parameters:
444-
parameters = {**self.parameters, **parameters}
445-
return await self._client.request(
446-
url=urljoin(self.engine_url.rstrip("/") + "/", path or ""),
447-
method="POST",
448-
params={
449-
**(parameters or dict()),
450-
},
451-
content=query,
452-
)
453-
454431
async def is_db_available(self, database_name: str) -> bool:
455432
"""
456433
Verify that the database exists.

src/firebolt/common/base_cursor.py

Lines changed: 13 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,14 @@ def __post_init__(self) -> None:
112112
setattr(self, field.name, _type(value))
113113

114114

115+
RowSet = Tuple[
116+
int,
117+
Optional[List[Column]],
118+
Optional[Statistics],
119+
Optional[List[List[RawColType]]],
120+
]
121+
122+
115123
def check_not_closed(func: Callable) -> Callable:
116124
"""(Decorator) ensure cursor is not closed before calling method."""
117125

@@ -166,14 +174,7 @@ def __init__(self, *args: Any, **kwargs: Any) -> None:
166174
self._rows: Optional[List[List[RawColType]]] = None
167175
self._descriptions: Optional[List[Column]] = None
168176
self._statistics: Optional[Statistics] = None
169-
self._row_sets: List[
170-
Tuple[
171-
int,
172-
Optional[List[Column]],
173-
Optional[Statistics],
174-
Optional[List[List[RawColType]]],
175-
]
176-
] = []
177+
self._row_sets: List[RowSet] = []
177178
# User-defined set parameters
178179
self._set_parameters: Dict[str, Any] = dict()
179180
# Server-side parameters (user can't change them)
@@ -333,19 +334,12 @@ def engine_name(self) -> str:
333334
return self.parameters["engine"]
334335
return URL(self.engine_url).host.split(".")[0].replace("-", "_")
335336

336-
def _row_set_from_response(
337-
self, response: Response
338-
) -> Tuple[
339-
int,
340-
Optional[List[Column]],
341-
Optional[Statistics],
342-
Optional[List[List[RawColType]]],
343-
]:
337+
def _row_set_from_response(self, response: Response) -> RowSet:
344338
"""Fetch information about executed query from http response."""
345339

346340
# Empty response is returned for insert query
347341
if response.headers.get("content-length", "") == "0":
348-
return (-1, None, None, None)
342+
return -1, None, None, None
349343
try:
350344
# Skip parsing floats to properly parse them later
351345
query_data = response.json(parse_float=str)
@@ -359,18 +353,13 @@ def _row_set_from_response(
359353
statistics = Statistics(**query_data["statistics"])
360354
# Parse data during fetch
361355
rows = query_data["data"]
362-
return (rowcount, descriptions, statistics, rows)
356+
return rowcount, descriptions, statistics, rows
363357
except (KeyError, ValueError) as err:
364358
raise DataError(f"Invalid query data format: {str(err)}")
365359

366360
def _append_row_set(
367361
self,
368-
row_set: Tuple[
369-
int,
370-
Optional[List[Column]],
371-
Optional[Statistics],
372-
Optional[List[List[RawColType]]],
373-
],
362+
row_set: RowSet,
374363
) -> None:
375364
"""Store information about executed query."""
376365
self._row_sets.append(row_set)

0 commit comments

Comments
 (0)