Skip to content

Commit 07b7c83

Browse files
committed
Changes after testing
1 parent 6e159c7 commit 07b7c83

File tree

4 files changed

+36
-17
lines changed

4 files changed

+36
-17
lines changed

src/firebolt/async_db/connection.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -237,8 +237,7 @@ async def connect(
237237
if not auth:
238238
raise ConfigurationError("auth is required to connect.")
239239

240-
if account_name:
241-
auth._account_name = account_name
240+
auth.account = account_name
242241

243242
api_endpoint = fix_url_schema(api_endpoint)
244243
# Type checks

src/firebolt/client/auth/base.py

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from abc import abstractmethod
22
from enum import IntEnum
33
from time import time
4-
from typing import AsyncGenerator, Generator, Optional
4+
from typing import AsyncGenerator, Generator, Optional, Tuple
55

66
from anyio import Lock, get_current_task
77
from httpx import Auth as HttpxAuth
@@ -54,6 +54,17 @@ def __init__(self, use_token_cache: bool = True):
5454
self._expires: Optional[int] = None
5555
self._lock = Lock()
5656

57+
@property
58+
def account(self) -> Optional[str]:
59+
return self._account_name
60+
61+
@account.setter
62+
def account(self, value: str) -> None:
63+
self._account_name = value
64+
# Now we have all the elements to fetch the cached token
65+
if not self._token:
66+
self._token, self._expires = self._get_cached_token()
67+
5768
def copy(self) -> "Auth":
5869
"""Make another auth object with same credentials.
5970
@@ -106,7 +117,7 @@ def expired(self) -> bool:
106117
"""
107118
return self._expires is not None and self._expires <= int(time())
108119

109-
def _get_cached_token(self) -> Optional[str]:
120+
def _get_cached_token(self) -> Tuple[Optional[str], Optional[int]]:
110121
"""If caching is enabled, get token from cache.
111122
112123
If caching is disabled, None is returned.
@@ -115,17 +126,17 @@ def _get_cached_token(self) -> Optional[str]:
115126
Optional[str]: Token if any, and if caching is enabled; None otherwise
116127
"""
117128
if not self._use_token_cache:
118-
return None
129+
return (None, None)
119130

120131
cache_key = SecureCacheKey(
121132
[self.principal, self.secret, self._account_name], self.secret
122133
)
123134
connection_info = _firebolt_cache.get(cache_key)
124135

125136
if connection_info and connection_info.token:
126-
return connection_info.token
137+
return (connection_info.token, connection_info.expiry_time)
127138

128-
return None
139+
return (None, None)
129140

130141
def _cache_token(self) -> None:
131142
"""If caching is enabled, cache token."""

src/firebolt/db/connection.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -66,8 +66,7 @@ def connect(
6666
if not auth:
6767
raise ConfigurationError("auth is required to connect.")
6868

69-
if account_name:
70-
auth._account_name = account_name
69+
auth.account = account_name
7170

7271
api_endpoint = fix_url_schema(api_endpoint)
7372
# Type checks

src/firebolt/utils/cache.py

Lines changed: 18 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -71,14 +71,24 @@ def __post_init__(self) -> None:
7171
"""
7272
if self.system_engine and isinstance(self.system_engine, dict):
7373
self.system_engine = EngineInfo(**self.system_engine)
74-
self.databases = {
75-
k: DatabaseInfo(**v)
76-
for k, v in self.databases.items()
77-
if isinstance(v, dict)
78-
}
79-
self.engines = {
80-
k: EngineInfo(**v) for k, v in self.engines.items() if isinstance(v, dict)
81-
}
74+
75+
# Convert dict values to dataclasses, keep existing dataclass objects
76+
new_databases = {}
77+
for k, db in self.databases.items():
78+
if isinstance(db, dict):
79+
new_databases[k] = DatabaseInfo(**db)
80+
else:
81+
new_databases[k] = db
82+
self.databases = new_databases
83+
84+
# Convert dict values to dataclasses, keep existing dataclass objects
85+
new_engines = {}
86+
for k, engine in self.engines.items():
87+
if isinstance(engine, dict):
88+
new_engines[k] = EngineInfo(**engine)
89+
else:
90+
new_engines[k] = engine
91+
self.engines = new_engines
8292

8393

8494
def noop_if_disabled(func: Callable) -> Callable:

0 commit comments

Comments
 (0)