Skip to content

Commit b9d5c22

Browse files
authored
feat: support USE DATABASE query (#328)
1 parent db9e51d commit b9d5c22

File tree

17 files changed

+409
-22
lines changed

17 files changed

+409
-22
lines changed

src/firebolt/async_db/cursor.py

Lines changed: 19 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,16 @@ def __init__(
8282
super().__init__(*args, **kwargs)
8383
self._client = client
8484
self.connection = connection
85+
if connection.database:
86+
self.database = connection.database
87+
88+
@property
89+
def database(self) -> Optional[str]:
90+
return self.parameters.get("database")
91+
92+
@database.setter
93+
def database(self, database: str) -> None:
94+
self.parameters["database"] = database
8595

8696
@abstractmethod
8797
async def _api_request(
@@ -100,12 +110,8 @@ async def _raise_if_error(self, resp: Response) -> None:
100110
f"Error executing query:\n{resp.read().decode('utf-8')}"
101111
)
102112
if resp.status_code == codes.FORBIDDEN:
103-
if self.connection.database and not await self.is_db_available(
104-
self.connection.database
105-
):
106-
raise FireboltDatabaseError(
107-
f"Database {self.connection.database} does not exist"
108-
)
113+
if self.database and not await self.is_db_available(self.database):
114+
raise FireboltDatabaseError(f"Database {self.database} does not exist")
109115
raise ProgrammingError(resp.read().decode("utf-8"))
110116
if (
111117
resp.status_code == codes.SERVICE_UNAVAILABLE
@@ -200,6 +206,8 @@ async def _do_execute(
200206
query, {"output_format": JSON_OUTPUT_FORMAT}
201207
)
202208
await self._raise_if_error(resp)
209+
# get parameters from response
210+
self._parse_response_headers(resp.headers)
203211
row_set = self._row_set_from_response(resp)
204212

205213
self._append_row_set(row_set)
@@ -439,8 +447,8 @@ async def _api_request(
439447
parameters = parameters or {}
440448
if use_set_parameters:
441449
parameters = {**(self._set_parameters or {}), **parameters}
442-
if self.connection.database:
443-
parameters["database"] = self.connection.database
450+
if self.parameters:
451+
parameters = {**self.parameters, **parameters}
444452
if self.connection._is_system:
445453
assert isinstance(self._client, AsyncClientV2)
446454
parameters["account_id"] = await self._client.account_id
@@ -543,13 +551,15 @@ async def _api_request(
543551
set parameters are sent. Setting this to False will allow
544552
self._set_parameters to be ignored.
545553
"""
554+
parameters = parameters or {}
546555
if use_set_parameters:
547556
parameters = {**(self._set_parameters or {}), **(parameters or {})}
557+
if self.parameters:
558+
parameters = {**self.parameters, **parameters}
548559
return await self._client.request(
549560
url=f"/{path}",
550561
method="POST",
551562
params={
552-
"database": self.connection.database,
553563
**(parameters or dict()),
554564
},
555565
content=query,

src/firebolt/client/client.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,11 @@
1111

1212
from firebolt.client.auth import Auth
1313
from firebolt.client.auth.base import AuthRequest
14-
from firebolt.client.constants import DEFAULT_API_URL
14+
from firebolt.client.constants import (
15+
DEFAULT_API_URL,
16+
PROTOCOL_VERSION,
17+
PROTOCOL_VERSION_HEADER_NAME,
18+
)
1519
from firebolt.utils.exception import (
1620
AccountNotFoundError,
1721
FireboltEngineError,
@@ -51,6 +55,11 @@ def __init__(
5155
self._api_endpoint = URL(fix_url_schema(api_endpoint))
5256
self._auth_endpoint = get_auth_endpoint(self._api_endpoint)
5357
super().__init__(*args, auth=auth, **kwargs)
58+
self._set_default_header(PROTOCOL_VERSION_HEADER_NAME, PROTOCOL_VERSION)
59+
60+
def _set_default_header(self, key: str, value: str) -> None:
61+
if key not in self.headers:
62+
self.headers[key] = value
5463

5564
def _build_auth(self, auth: Optional[AuthTypes]) -> Auth:
5665
"""Create Auth object based on auth provided.

src/firebolt/client/constants.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44
from httpx import CookieConflict, HTTPError, InvalidURL, StreamError
55

66
DEFAULT_API_URL: str = "api.app.firebolt.io"
7+
PROTOCOL_VERSION_HEADER_NAME = "Firebolt-Protocol-Version"
8+
PROTOCOL_VERSION: str = "2.0"
79
_REQUEST_ERRORS: Tuple[Type, ...] = (
810
HTTPError,
911
InvalidURL,

src/firebolt/common/base_cursor.py

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from types import TracebackType
88
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union
99

10-
from httpx import Response
10+
from httpx import Headers, Response
1111

1212
from firebolt.common._types import (
1313
ColType,
@@ -52,6 +52,10 @@ class QueryStatus(Enum):
5252
EXECUTION_ERROR = 8
5353

5454

55+
# known parameters that can be set on the server side
56+
SERVER_SIDE_PARAMETERS = ["database"]
57+
58+
5559
@dataclass
5660
class Statistics:
5761
"""
@@ -109,6 +113,7 @@ def inner(self: BaseCursor, *args: Any, **kwargs: Any) -> Any:
109113
class BaseCursor:
110114
__slots__ = (
111115
"connection",
116+
"parameters",
112117
"_arraysize",
113118
"_client",
114119
"_state",
@@ -140,6 +145,7 @@ def __init__(self, *args: Any, **kwargs: Any) -> None:
140145
]
141146
] = []
142147
self._set_parameters: Dict[str, Any] = dict()
148+
self.parameters: Dict[str, str] = dict()
143149
self._rowcount = -1
144150
self._idx = 0
145151
self._next_set_idx = 0
@@ -243,6 +249,26 @@ def _reset(self) -> None:
243249
self._next_set_idx = 0
244250
self._query_id = ""
245251

252+
def _parse_response_headers(self, headers: Headers) -> None:
253+
"""Parse response and update relevant cursor fields."""
254+
update_parameters = headers.get("Firebolt-Update-Parameters")
255+
# parse update parameters dict and set keys as attributes
256+
if update_parameters:
257+
# parse key1=value1,key2=value2 comma separated string into dict
258+
param_dict = dict(item.split("=") for item in update_parameters.split(","))
259+
# strip whitespace from keys and values
260+
param_dict = {
261+
key.strip(): value.strip() for key, value in param_dict.items()
262+
}
263+
for key, value in param_dict.items():
264+
if key in SERVER_SIDE_PARAMETERS:
265+
self.parameters[key] = value
266+
else:
267+
logger.debug(
268+
f"Unknown parameter {key} returned by the server. "
269+
"It will be ignored."
270+
)
271+
246272
def _row_set_from_response(
247273
self, response: Response
248274
) -> Tuple[

src/firebolt/db/cursor.py

Lines changed: 19 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,16 @@ def __init__(
7676
super().__init__(*args, **kwargs)
7777
self._client = client
7878
self.connection = connection
79+
if connection.database:
80+
self.database = connection.database
81+
82+
@property
83+
def database(self) -> Optional[str]:
84+
return self.parameters.get("database")
85+
86+
@database.setter
87+
def database(self, database: str) -> None:
88+
self.parameters["database"] = database
7989

8090
def _raise_if_error(self, resp: Response) -> None:
8191
"""Raise a proper error if any"""
@@ -84,11 +94,9 @@ def _raise_if_error(self, resp: Response) -> None:
8494
f"Error executing query:\n{resp.read().decode('utf-8')}"
8595
)
8696
if resp.status_code == codes.FORBIDDEN:
87-
if self.connection.database and not self.is_db_available(
88-
self.connection.database
89-
):
97+
if self.database and not self.is_db_available(self.database):
9098
raise FireboltDatabaseError(
91-
f"Database {self.connection.database} does not exist"
99+
f"Database {self.parameters['database']} does not exist"
92100
)
93101
raise ProgrammingError(resp.read().decode("utf-8"))
94102
if (
@@ -188,6 +196,8 @@ def _do_execute(
188196
query, {"output_format": JSON_OUTPUT_FORMAT}
189197
)
190198
self._raise_if_error(resp)
199+
# get parameters from response
200+
self._parse_response_headers(resp.headers)
191201
row_set = self._row_set_from_response(resp)
192202

193203
self._append_row_set(row_set)
@@ -379,8 +389,8 @@ def _api_request(
379389
parameters = parameters or {}
380390
if use_set_parameters:
381391
parameters = {**(self._set_parameters or {}), **parameters}
382-
if self.connection.database:
383-
parameters["database"] = self.connection.database
392+
if self.parameters:
393+
parameters = {**self.parameters, **parameters}
384394
if self.connection._is_system:
385395
assert isinstance(self._client, ClientV2) # Type check
386396
parameters["account_id"] = self._client.account_id
@@ -480,13 +490,15 @@ def _api_request(
480490
set parameters are sent. Setting this to False will allow
481491
self._set_parameters to be ignored.
482492
"""
493+
parameters = parameters or {}
483494
if use_set_parameters:
484495
parameters = {**(self._set_parameters or {}), **(parameters or {})}
496+
if self.parameters:
497+
parameters = {**self.parameters, **parameters}
485498
return self._client.request(
486499
url=f"/{path}",
487500
method="POST",
488501
params={
489-
"database": self.connection.database,
490502
**(parameters or dict()),
491503
},
492504
content=query,

tests/integration/conftest.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,11 @@ def database_name() -> str:
8484
return must_env(DATABASE_NAME_ENV)
8585

8686

87+
@fixture(scope="session")
88+
def use_db_name(database_name: str):
89+
return f"{database_name}_use_db_test"
90+
91+
8792
@fixture(scope="session")
8893
def account_name() -> str:
8994
return must_env(ACCOUNT_NAME_ENV)

tests/integration/dbapi/async/V1/test_queries_async.py

Lines changed: 39 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
from decimal import Decimal
33
from typing import Any, List
44

5-
from pytest import mark, raises
5+
from pytest import fixture, mark, raises
66

77
from firebolt.async_db import Binary, Connection, Cursor, OperationalError
88
from firebolt.async_db.cursor import QueryStatus
@@ -486,3 +486,41 @@ async def test_bytea_roundtrip(
486486
assert (
487487
bytes_data.decode("utf-8") == data
488488
), "Invalid bytea data returned after roundtrip"
489+
490+
491+
@fixture
492+
async def setup_db(connection_no_engine: Connection, use_db_name: str):
493+
use_db_name = f"{use_db_name}_async"
494+
with connection_no_engine.cursor() as cursor:
495+
await cursor.execute(f"CREATE DATABASE {use_db_name}")
496+
yield
497+
await cursor.execute(f"DROP DATABASE {use_db_name}")
498+
499+
500+
@mark.xfail(reason="USE DATABASE is not yet available in 1.0 Firebolt")
501+
async def test_use_database(
502+
setup_db,
503+
connection_no_engine: Connection,
504+
use_db_name: str,
505+
database_name: str,
506+
) -> None:
507+
test_db_name = f"{use_db_name}_async"
508+
test_table_name = "verify_use_db_async"
509+
"""Use database works as expected."""
510+
with connection_no_engine.cursor() as c:
511+
await c.execute(f"USE DATABASE {test_db_name}")
512+
assert c.database == test_db_name
513+
await c.execute(f"CREATE TABLE {test_table_name} (id int)")
514+
await c.execute(
515+
"SELECT table_name FROM information_schema.tables "
516+
f"WHERE table_name = '{test_table_name}'"
517+
)
518+
assert (await c.fetchone())[0] == test_table_name, "Table was not created"
519+
# Change DB and verify table is not there
520+
await c.execute(f"USE DATABASE {database_name}")
521+
assert c.database == database_name
522+
await c.execute(
523+
"SELECT table_name FROM information_schema.tables "
524+
f"WHERE table_name = '{test_table_name}'"
525+
)
526+
assert (await c.fetchone()) is None, "Database was not changed"

tests/integration/dbapi/async/V2/test_queries_async.py

Lines changed: 41 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,14 @@
11
from datetime import date, datetime
22
from decimal import Decimal
3+
from os import environ
34
from typing import List
45

5-
from pytest import mark, raises
6+
from pytest import fixture, mark, raises
67

78
from firebolt.async_db import Binary, Connection, Cursor, OperationalError
89
from firebolt.async_db.cursor import QueryStatus
910
from firebolt.common._types import ColType, Column
11+
from tests.integration.conftest import API_ENDPOINT_ENV
1012
from tests.integration.dbapi.utils import assert_deep_eq
1113

1214
VALS_TO_INSERT_2 = ",".join(
@@ -411,3 +413,41 @@ async def test_bytea_roundtrip(
411413
assert (
412414
bytes_data.decode("utf-8") == data
413415
), "Invalid bytea data returned after roundtrip"
416+
417+
418+
@fixture
419+
async def setup_db(connection_system_engine_no_db: Connection, use_db_name: str):
420+
use_db_name = use_db_name + "_async"
421+
with connection_system_engine_no_db.cursor() as cursor:
422+
await cursor.execute(f"CREATE DATABASE {use_db_name}")
423+
yield
424+
await cursor.execute(f"DROP DATABASE {use_db_name}")
425+
426+
427+
@mark.xfail("dev" not in environ[API_ENDPOINT_ENV], reason="Only works on dev")
428+
async def test_use_database(
429+
setup_db,
430+
connection_system_engine_no_db: Connection,
431+
use_db_name: str,
432+
database_name: str,
433+
) -> None:
434+
test_db_name = use_db_name + "_async"
435+
test_table_name = "verify_use_db_async"
436+
"""Use database works as expected."""
437+
with connection_system_engine_no_db.cursor() as c:
438+
await c.execute(f"USE DATABASE {test_db_name}")
439+
assert c.database == test_db_name
440+
await c.execute(f"CREATE TABLE {test_table_name} (id int)")
441+
await c.execute(
442+
"SELECT table_name FROM information_schema.tables "
443+
f"WHERE table_name = '{test_table_name}'"
444+
)
445+
assert (await c.fetchone())[0] == test_table_name, "Table was not created"
446+
# Change DB and verify table is not there
447+
await c.execute(f"USE DATABASE {database_name}")
448+
assert c.database == database_name
449+
await c.execute(
450+
"SELECT table_name FROM information_schema.tables "
451+
f"WHERE table_name = '{test_table_name}'"
452+
)
453+
assert (await c.fetchone()) is None, "Database was not changed"

0 commit comments

Comments
 (0)