-
Notifications
You must be signed in to change notification settings - Fork 47
Generic Python Lakebase Checkpointing #327
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
c7fc3b6
6613060
6feb1b3
50d7fc3
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,141 @@ | ||
| from pydantic import BaseModel, Field | ||
| from datetime import datetime | ||
| from .lakebase import LakebaseClient | ||
| from typing import Tuple | ||
| from psycopg.types.json import Json | ||
| from abc import ABC, abstractmethod | ||
|
Check failure on line 6 in src/databricks_ai_bridge/lakebase_checkpointer.py
|
||
| from typing import Any, Union, Tuple | ||
|
Check failure on line 7 in src/databricks_ai_bridge/lakebase_checkpointer.py
|
||
| from uuid import uuid4 | ||
|
|
||
class GenericCheckpoint(BaseModel):
    """A single checkpoint row: an ``id`` plus a JSON-serializable ``state`` dict.

    Subclasses may add columns, but must keep ``id`` (it is how a checkpoint is
    located) and override the SQL-generation methods below to match any schema
    changes.
    """

    ####################
    # CHECKPOINT COLUMNS
    ####################
    # `id` column must exist for locating the checkpoint, even in subclasses.
    id: str = Field(default_factory=lambda: str(uuid4()))
    # Arbitrary JSON-serializable payload, stored in the `state` jsonb column.
    state: dict = {}
    creation_timestamp: datetime = Field(default_factory=datetime.now)
    update_timestamp: datetime = Field(default_factory=datetime.now)

    ####################
    # UPDATE ATTRIBUTES
    ####################
    def update(self, **kwargs) -> None:
        """Set the given attributes on this checkpoint and bump ``update_timestamp``.

        Raises:
            AttributeError: if a keyword does not name an existing attribute.
                A raise is used instead of ``assert`` because assertions are
                stripped when Python runs with optimization (``-O``).
        """
        for key, value in kwargs.items():
            if not hasattr(self, key):
                raise AttributeError(
                    f"Attribute {key} does not exist in {self.__class__.__name__}"
                )
            setattr(self, key, value)
        self.update_timestamp = datetime.now()

    @staticmethod
    def _validated_table_name(table_name: str) -> str:
        """Return ``table_name`` if it is a safe SQL identifier, else raise.

        The SQL builders below interpolate the table name directly into the
        statement text (table names cannot be bound parameters), so anything
        that is not a dotted chain of plain Python identifiers is rejected to
        prevent SQL injection. All column VALUES are still passed as bound
        parameters (``%s``).

        Raises:
            ValueError: if ``table_name`` is not a dotted identifier chain.
        """
        if not table_name or not all(
            part.isidentifier() for part in table_name.split(".")
        ):
            raise ValueError(f"Invalid SQL table name: {table_name!r}")
        return table_name

    #################################
    # SQL GENERATION IMPLEMENTATIONS
    #################################
    # If subclassing `GenericCheckpoint`, implement new methods to handle any
    # changes in attributes.
    def generate_insert_sql(self, table_name: str) -> Tuple[str, tuple]:
        """Return (sql, params) that inserts this checkpoint as a new row."""
        table = self._validated_table_name(table_name)
        sql = f"""
        INSERT INTO {table}
        (id, state, creation_timestamp, update_timestamp)
        VALUES (%s, %s, %s, %s)
        """
        return sql, (self.id, Json(self.state), self.creation_timestamp, self.update_timestamp)

    def generate_update_sql(self, table_name: str) -> Tuple[str, tuple]:
        """Return (sql, params) that overwrites the state of this exact row.

        The row is identified by (id, creation_timestamp), matching the UNIQUE
        constraint created in `generate_init_sql`.
        """
        table = self._validated_table_name(table_name)
        sql = f"""
        UPDATE {table}
        SET state = %s, update_timestamp = %s
        WHERE id = %s AND creation_timestamp = %s
        """
        return sql, (Json(self.state), self.update_timestamp, self.id, self.creation_timestamp)

    def generate_init_sql(self, table_name: str) -> Tuple[str, tuple]:
        """Return (sql, params) that creates the checkpoint table if missing."""
        table = self._validated_table_name(table_name)
        sql = f"""
        CREATE TABLE IF NOT EXISTS {table} (
            lb_id bigserial PRIMARY KEY,
            id text NOT NULL,
            state jsonb NOT NULL default '{{}}',
            creation_timestamp timestamp NOT NULL DEFAULT CURRENT_TIMESTAMP,
            update_timestamp timestamp NOT NULL DEFAULT CURRENT_TIMESTAMP,
            UNIQUE(id,creation_timestamp)
        )
        """
        return sql, None

    def generate_retrieve_checkpoint_sql(self, table_name: str) -> Tuple[str, tuple]:
        """Return (sql, params) fetching only the most recently updated row for this id."""
        sql_all, params = self.generate_retrieval_all_checkpoints_sql(table_name=table_name)
        sql = f"""
        {sql_all} LIMIT 1
        """
        return sql, params

    def generate_retrieval_all_checkpoints_sql(self, table_name: str) -> Tuple[str, tuple]:
        """Return (sql, params) fetching all rows for this id, newest first."""
        table = self._validated_table_name(table_name)
        sql = f"""
        SELECT id, state, creation_timestamp, update_timestamp
        FROM {table}
        WHERE id = %s
        ORDER BY update_timestamp DESC
        """
        return sql, (self.id,)
|
|
||
class LakebaseCheckpointer:
    """Persists `GenericCheckpoint` rows to a Lakebase (Postgres) table.

    Supports use as a context manager to mirror the existing checkpoint-saver
    classes::

        with LakebaseCheckpointer(client, "checkpoints") as ckpt:
            ckpt.save_checkpoint(id="abc", state={"k": "v"})
    """

    def __init__(self,
                 lakebase_client: LakebaseClient,
                 sessions_table_name: str,
                 checkpoint_class: type[GenericCheckpoint] = GenericCheckpoint,
                 create_tables: bool = True):
        """Create a checkpointer backed by ``sessions_table_name``.

        Args:
            lakebase_client: client used to execute SQL against Lakebase.
            sessions_table_name: target table; must be a dotted chain of plain
                identifiers (it is interpolated into SQL text, so anything else
                is rejected to prevent SQL injection).
            checkpoint_class: `GenericCheckpoint` subclass defining the schema.
            create_tables: when True (default), create the backing table on
                init; pass False if the table is managed externally.

        Raises:
            ValueError: if ``sessions_table_name`` is not a safe identifier.
        """
        # Validate eagerly: `checkpoint_exists` interpolates this name into
        # SQL text directly, so it must be a plain (possibly dotted) identifier.
        if not sessions_table_name or not all(
            part.isidentifier() for part in sessions_table_name.split(".")
        ):
            raise ValueError(f"Invalid SQL table name: {sessions_table_name!r}")
        self.checkpoint_class = checkpoint_class
        self.client = lakebase_client
        self.table_name = sessions_table_name
        if create_tables:
            self.init_schema()

    def __enter__(self) -> "LakebaseCheckpointer":
        """Enter a context-manager block; connection lifetime is owned by the client."""
        return self

    def __exit__(self, exc_type, exc_value, traceback) -> None:
        """Exit the context-manager block; nothing to release here."""
        return None

    def init_schema(self) -> None:
        """Create the checkpoint table if it does not already exist."""
        _checkpoint = self.checkpoint_class()
        sql, params = _checkpoint.generate_init_sql(table_name=self.table_name)
        self.client.execute(sql=sql, params=params)

    def get_most_recent_checkpoint(self, id: str) -> GenericCheckpoint | None:
        """Return the most recently updated checkpoint for ``id``, or None."""
        _checkpoint = self.checkpoint_class(id=id)
        sql, params = _checkpoint.generate_retrieve_checkpoint_sql(table_name=self.table_name)
        response = self.client.execute(sql=sql, params=params)
        if not response:
            return None
        return self.checkpoint_class(**response[0])

    def get_all_checkpoints(self, id: str) -> list[GenericCheckpoint]:
        """Return every checkpoint stored for ``id``, newest first."""
        _checkpoint = self.checkpoint_class(id=id)
        sql, params = _checkpoint.generate_retrieval_all_checkpoints_sql(table_name=self.table_name)
        response = self.client.execute(sql=sql, params=params)
        return [self.checkpoint_class(**row) for row in response]

    def checkpoint_exists(self, id: str) -> bool:
        """Return True if at least one checkpoint row exists for ``id``."""
        # Table name was validated in __init__; the id is a bound parameter.
        sql = f"""
        SELECT COUNT(*) AS count
        FROM {self.table_name}
        WHERE id = %s
        """
        response = self.client.execute(sql=sql, params=(id,))
        return response[0]["count"] > 0

    def update_most_recent_checkpoint(self, id: str, **checkpoint_kwargs) -> None:
        """Apply ``checkpoint_kwargs`` to the newest checkpoint for ``id`` and persist it.

        Raises:
            ValueError: if no checkpoint exists for ``id`` (previously this
                crashed with AttributeError on None).
        """
        _checkpoint = self.get_most_recent_checkpoint(id=id)
        if _checkpoint is None:
            raise ValueError(f"No checkpoint exists for id {id!r}; cannot update.")
        _checkpoint.update(**checkpoint_kwargs)
        sql, params = _checkpoint.generate_update_sql(table_name=self.table_name)
        self.client.execute(sql=sql, params=params)

    def insert_checkpoint(self, id: str, **checkpoint_kwargs) -> None:
        """Insert a brand-new checkpoint row for ``id``."""
        _checkpoint = self.checkpoint_class(id=id, **checkpoint_kwargs)
        sql, params = _checkpoint.generate_insert_sql(table_name=self.table_name)
        self.client.execute(sql=sql, params=params)

    def save_checkpoint(self, id: str, overwrite: bool = False, **checkpoint_kwargs) -> None:
        """Save a checkpoint: insert a new row, or overwrite the newest one.

        Args:
            id: checkpoint identifier.
            overwrite: when True and a checkpoint already exists for ``id``,
                update the most recent row in place instead of inserting.
        """
        if overwrite and self.checkpoint_exists(id=id):
            # Fixed typo: was `self.update_checkpoint`, which does not exist.
            self.update_most_recent_checkpoint(id=id, **checkpoint_kwargs)
        else:
            # Don't overwrite the most recent checkpoint; insert a new one.
            self.insert_checkpoint(id=id, **checkpoint_kwargs)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
instead of assert, can we raise a ValueError or AttributeError instead? assert statements are stripped when python runs with optimization requested (command line option -O).
https://docs.python.org/3/reference/simple_stmts.html#the-assert-statement