Skip to content

Commit b3f56e1

Browse files
committed
fix pre commit
1 parent 4b5080e commit b3f56e1

File tree

3 files changed

+48
-27
lines changed

3 files changed

+48
-27
lines changed

src/firebolt/utils/cache.py

Lines changed: 27 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -192,20 +192,34 @@ def __hash__(self) -> int:
192192
return hash(self.key)
193193

194194

195-
class FileBasedCache(UtilCache[ConnectionInfo]):
195+
class FileBasedCache:
196196
"""
197197
File-based cache that persists to disk with encryption.
198-
Extends UtilCache to provide persistent storage using encrypted files.
198+
Uses composition to combine in-memory caching with persistent storage
199+
using encrypted files.
199200
"""
200201

201-
def __init__(self, cache_name: str = ""):
202-
super().__init__(cache_name)
202+
def __init__(self, memory_cache: UtilCache[ConnectionInfo], cache_name: str = ""):
203+
self.memory_cache = memory_cache
203204
self._data_dir = user_data_dir(appname=APPNAME) # TODO: change to new dir
204205
makedirs(self._data_dir, exist_ok=True)
206+
# FileBasedCache has its own disabled state, independent of memory cache
207+
cache_env_var = f"FIREBOLT_SDK_DISABLE_CACHE_${cache_name}"
208+
self.disabled = os.getenv("FIREBOLT_SDK_DISABLE_CACHE", False) or os.getenv(
209+
cache_env_var, False
210+
)
211+
212+
def disable(self) -> None:
213+
"""Disable the file-based cache."""
214+
self.disabled = True
215+
216+
def enable(self) -> None:
217+
"""Enable the file-based cache."""
218+
self.disabled = False
205219

206220
def _get_file_path(self, key: SecureCacheKey) -> str:
207221
"""Get the file path for a cache key."""
208-
cache_key = self.create_key(key)
222+
cache_key = self.memory_cache.create_key(key)
209223
encrypted_filename = generate_encrypted_file_name(cache_key, key.encryption_key)
210224
return path.join(self._data_dir, encrypted_filename)
211225

@@ -250,7 +264,7 @@ def get(self, key: SecureCacheKey) -> Optional[ConnectionInfo]:
250264
return None
251265

252266
# First try memory cache
253-
memory_result = super().get(key)
267+
memory_result = self.memory_cache.get(key)
254268
if memory_result is not None:
255269
logger.debug("Cache hit in memory")
256270
return memory_result
@@ -265,7 +279,7 @@ def get(self, key: SecureCacheKey) -> Optional[ConnectionInfo]:
265279
data = ConnectionInfo(**raw_data)
266280

267281
# Add to memory cache and return
268-
super().set(key, data)
282+
self.memory_cache.set(key, data)
269283
return data
270284

271285
def set(self, key: SecureCacheKey, value: ConnectionInfo) -> None:
@@ -275,7 +289,7 @@ def set(self, key: SecureCacheKey, value: ConnectionInfo) -> None:
275289

276290
logger.debug("Setting value in cache")
277291
# First set in memory
278-
super().set(key, value)
292+
self.memory_cache.set(key, value)
279293

280294
file_path = self._get_file_path(key)
281295
encrypter = FernetEncrypter(generate_salt(), key.encryption_key)
@@ -289,7 +303,7 @@ def delete(self, key: SecureCacheKey) -> None:
289303
return
290304

291305
# Delete from memory
292-
super().delete(key)
306+
self.memory_cache.delete(key)
293307

294308
# Delete from disk
295309
file_path = self._get_file_path(key)
@@ -303,7 +317,9 @@ def delete(self, key: SecureCacheKey) -> None:
303317
def clear(self) -> None:
304318
# Clear memory only, as deleting every file is not safe
305319
logger.debug("Clearing memory cache")
306-
super().clear()
320+
self.memory_cache.clear()
307321

308322

309-
_firebolt_cache = FileBasedCache(cache_name="connection_info")
323+
_firebolt_cache = FileBasedCache(
324+
UtilCache[ConnectionInfo](cache_name="memory_cache"), cache_name="file_cache"
325+
)

src/firebolt/utils/usage_tracker.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,11 @@
88
from typing import Dict, List, Optional, Tuple
99

1010
from firebolt import __version__
11-
from firebolt.utils.cache import ConnectionInfo, ReprCacheable, _firebolt_cache
11+
from firebolt.utils.cache import (
12+
ConnectionInfo,
13+
SecureCacheKey,
14+
_firebolt_cache,
15+
)
1216

1317

1418
@dataclass
@@ -228,7 +232,7 @@ def get_user_agent_header(
228232

229233

230234
def get_cache_tracking_params(
231-
cache_key: ReprCacheable, conn_id: str
235+
cache_key: SecureCacheKey, conn_id: str
232236
) -> List[Tuple[str, str]]:
233237
ua_parameters = []
234238
ua_parameters.append(("connId", conn_id))

tests/unit/utils/test_cache.py

Lines changed: 15 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -355,25 +355,25 @@ def test_cache_disable_enable_behavior(
355355

356356
def test_helper_functions():
357357
"""Test the backward compatibility helper functions."""
358-
from tests.unit.test_cache_helpers import cache_token, get_cached_token
359358
from firebolt.utils.cache import _firebolt_cache
360-
359+
from tests.unit.test_cache_helpers import cache_token, get_cached_token
360+
361361
_firebolt_cache.enable()
362362
_firebolt_cache.clear()
363-
363+
364364
# Test caching and retrieving tokens
365365
principal = "test_user"
366366
secret = "test_secret"
367367
token = "test_token"
368368
account_name = "test_account"
369-
369+
370370
# Cache token
371371
cache_token(principal, secret, token, 9999, account_name)
372-
372+
373373
# Retrieve token
374374
cached_token = get_cached_token(principal, secret, account_name)
375375
assert cached_token == token
376-
376+
377377
# Test with None account name
378378
cache_token(principal, secret, token, 9999, None)
379379
cached_token_none = get_cached_token(principal, secret, None)
@@ -385,31 +385,32 @@ def test_connection_info_post_init():
385385
# Test with dictionary inputs that should be converted to dataclasses
386386
engine_dict = {"url": "http://test.com", "params": {"key": "value"}}
387387
db_dict = {"name": "test_db"}
388-
388+
389389
connection_info = ConnectionInfo(
390390
id="test",
391391
system_engine=engine_dict,
392392
databases={"db1": db_dict},
393-
engines={"engine1": engine_dict}
393+
engines={"engine1": engine_dict},
394394
)
395-
395+
396396
# Should convert dicts to dataclasses
397-
from firebolt.utils.cache import EngineInfo, DatabaseInfo
397+
from firebolt.utils.cache import DatabaseInfo, EngineInfo
398+
398399
assert isinstance(connection_info.system_engine, EngineInfo)
399400
assert isinstance(connection_info.databases["db1"], DatabaseInfo)
400401
assert isinstance(connection_info.engines["engine1"], EngineInfo)
401-
402+
402403
# Test with already converted dataclass objects
403404
engine_obj = EngineInfo(url="http://test.com", params={"key": "value"})
404405
db_obj = DatabaseInfo(name="test_db")
405-
406+
406407
connection_info2 = ConnectionInfo(
407408
id="test2",
408409
system_engine=engine_obj,
409410
databases={"db1": db_obj},
410-
engines={"engine1": engine_obj}
411+
engines={"engine1": engine_obj},
411412
)
412-
413+
413414
# Should remain as dataclasses
414415
assert connection_info2.system_engine is engine_obj
415416
assert connection_info2.databases["db1"] is db_obj

0 commit comments

Comments
 (0)