Generic Python Lakebase Checkpointing #327
coltonpeltier-db wants to merge 4 commits into databricks:main from
… functions more generic to handle future subclasses
Thanks for getting started with adding a framework-agnostic checkpointing class!
- "Checkpoint" is a term fairly specific to LangGraph, so I wonder whether calling this class a checkpointer would confuse users. It is mostly a state persistence class, so maybe something like "LakebaseStateSaver" works better? Open to suggestions.
- Can we also update the design so that, instead of creating and passing in a LakebaseClient, we pass the Lakebase instance name directly to this class? It could look something like:
  checkpointer = LakebaseCheckpointer(instance_name="my-instance")
  This mirrors our other checkpointing class (e.g. CheckpointSaver(instance_name=LAKEBASE_INSTANCE_NAME)).
- Reminder to include and test both sync and async methods for this class.
- Can you include more usage examples and tests when you are finished? More specifically, can you test this by creating a short-term memory agent on Apps? You can refer to our existing example that uses checkpointing with the LangGraph framework and make sure the functionality remains the same: https://github.com/databricks/app-templates/tree/main/agent-langgraph-short-term-memory
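The constructor shape and the sync/async pairing requested above could be sketched roughly as follows. This is only an illustration under stated assumptions: the in-memory dict stands in for a real Lakebase connection, the method names save, get, asave, and aget are hypothetical, and wrapping the sync methods with asyncio.to_thread is just one possible approach, not necessarily what the existing classes do.

```python
import asyncio
from typing import Any, Dict, Optional


class LakebaseStateSaver:
    """Hypothetical state saver configured only by a Lakebase instance name."""

    def __init__(self, instance_name: str) -> None:
        self.instance_name = instance_name
        # Assumption: a dict stands in for the real Lakebase connection.
        self._store: Dict[str, Dict[str, Any]] = {}

    def save(self, session_id: str, state: Dict[str, Any]) -> None:
        # Store a copy so later caller-side mutation does not leak in.
        self._store[session_id] = dict(state)

    def get(self, session_id: str) -> Optional[Dict[str, Any]]:
        return self._store.get(session_id)

    async def asave(self, session_id: str, state: Dict[str, Any]) -> None:
        # Async variant: run the blocking sync method in a worker thread.
        await asyncio.to_thread(self.save, session_id, state)

    async def aget(self, session_id: str) -> Optional[Dict[str, Any]]:
        return await asyncio.to_thread(self.get, session_id)


async def _demo() -> Optional[Dict[str, Any]]:
    saver = LakebaseStateSaver(instance_name="my-instance")
    await saver.asave("session-1", {"turn": 1})
    return await saver.aget("session-1")


async_result = asyncio.run(_demo())
```

Keeping the sync and async methods side by side on one class makes it straightforward to run the same assertions against both paths.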
    def generate_update_sql(self, table_name : str) -> Tuple[str, tuple]:
        sql = f"""
            UPDATE {table_name}
Every SQL method uses direct string interpolation, which puts us at risk of SQL injection. Can you take a look at the LakebaseClient implementation: https://github.com/databricks/databricks-ai-bridge/blob/main/src/databricks_ai_bridge/lakebase.py
which uses SQL identifiers: https://www.psycopg.org/psycopg3/docs/api/sql.html
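To illustrate the concern, here is a minimal standard-library sketch of the general idea: the table name is validated and quoted as an identifier, and the values travel as bound parameters. It uses sqlite3 and a hand-rolled quote_identifier helper purely as stand-ins so the example runs anywhere; both are assumptions for illustration, and the column names are made up. In the actual class the psycopg SQL composition module linked above is the better tool.

```python
import re
import sqlite3
from typing import Tuple

# Conservative allow-list: letters, digits, underscores, not starting with a digit.
_IDENTIFIER_RE = re.compile(r"[A-Za-z_][A-Za-z0-9_]*")


def quote_identifier(name: str) -> str:
    """Hypothetical helper: validate a SQL identifier, then double-quote it."""
    if not _IDENTIFIER_RE.fullmatch(name):
        raise ValueError(f"Unsafe SQL identifier: {name!r}")
    return '"' + name + '"'


def generate_update_sql(table_name: str, session_id: str, state: str) -> Tuple[str, tuple]:
    # Only the validated identifier is interpolated; values are bound parameters.
    # sqlite3 uses "?" placeholders, whereas psycopg uses "%s".
    query = f"UPDATE {quote_identifier(table_name)} SET state = ? WHERE id = ?"
    return query, (state, session_id)


conn = sqlite3.connect(":memory:")
conn.execute('CREATE TABLE "sessions" (id TEXT PRIMARY KEY, state TEXT)')
conn.execute('INSERT INTO "sessions" VALUES (?, ?)', ("abc", "old"))

query, params = generate_update_sql("sessions", "abc", "new")
conn.execute(query, params)
row = conn.execute('SELECT state FROM "sessions" WHERE id = ?', ("abc",)).fetchone()
```

A table name such as `sessions; DROP TABLE sessions` is rejected with a ValueError before any SQL is built.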
    if overwrite:
        checkpoint_exists = self.checkpoint_exists(id = id)
        if checkpoint_exists:
            self.update_checkpoint(id = id, **checkpoint_kwargs)
update_checkpoint is a nonexistent function in this class -- is this a typo (should it be self.update_most_recent_checkpoint)?
        """
        return sql, (self.id,)


class LakebaseCheckpointer:
Can we implement context manager support to mirror our existing CheckpointSaver class? https://github.com/databricks/databricks-ai-bridge/blob/main/integrations/langchain/src/databricks_langchain/checkpoint.py#L48
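A rough sketch of what context manager support could look like is below. The open and close methods and the is_open flag are hypothetical placeholders for acquiring and releasing the underlying Lakebase connection; this is not the existing CheckpointSaver implementation.

```python
class LakebaseCheckpointer:
    """Hypothetical sketch: only the context-manager plumbing is shown."""

    def __init__(self, instance_name: str) -> None:
        self.instance_name = instance_name
        self.is_open = False

    def open(self) -> None:
        # Placeholder for acquiring the real connection or pool.
        self.is_open = True

    def close(self) -> None:
        # Placeholder for releasing the real connection or pool.
        self.is_open = False

    def __enter__(self) -> "LakebaseCheckpointer":
        self.open()
        return self

    def __exit__(self, exc_type, exc, tb) -> None:
        # Always release resources; returning None lets exceptions propagate.
        self.close()


with LakebaseCheckpointer(instance_name="my-instance") as checkpointer:
    opened_inside = checkpointer.is_open
closed_after = not checkpointer.is_open
```

Because `__exit__` runs even when the body raises, callers do not need their own try/finally around the checkpointer.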
        self.checkpoint_class = checkpoint_class
        self.client = lakebase_client
        self.table_name = sessions_table_name
        self.init_schema()
init_schema is called automatically, so it attempts to create a table every time the class is instantiated.
Can you add a create_tables bool (defaulting to True) to mirror our OpenAI session class, where we create tables on init: https://github.com/databricks/databricks-ai-bridge/blob/main/integrations/openai/src/databricks_openai/agents/session.py#L120
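One way the create_tables flag could work is sketched below. The sqlite3 connection is only a runnable stand-in for Lakebase, and the sessions table layout and the table_exists helper are assumptions made for the example.

```python
import sqlite3


class LakebaseCheckpointer:
    """Hypothetical sketch of an opt-out schema initialization flag."""

    def __init__(self, conn: sqlite3.Connection, create_tables: bool = True) -> None:
        self.conn = conn
        # Callers who manage the schema themselves can pass create_tables=False.
        if create_tables:
            self.init_schema()

    def init_schema(self) -> None:
        # IF NOT EXISTS keeps repeated initialization harmless.
        self.conn.execute(
            "CREATE TABLE IF NOT EXISTS sessions (id TEXT PRIMARY KEY, state TEXT)"
        )


def table_exists(conn: sqlite3.Connection, name: str) -> bool:
    """Helper for the example: check sqlite's catalog for a table."""
    row = conn.execute(
        "SELECT 1 FROM sqlite_master WHERE type = 'table' AND name = ?", (name,)
    ).fetchone()
    return row is not None


conn_skip = sqlite3.connect(":memory:")
LakebaseCheckpointer(conn_skip, create_tables=False)
skipped = table_exists(conn_skip, "sessions")

conn_default = sqlite3.connect(":memory:")
LakebaseCheckpointer(conn_default)
created = table_exists(conn_default, "sessions")
```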
    def update_most_recent_checkpoint(self, id : str, **checkpoint_kwargs) -> None:
        # Get the most recent checkpoint
        _checkpoint = self.get_most_recent_checkpoint(id = id)
        _checkpoint.update(**checkpoint_kwargs)
get_most_recent_checkpoint returns None when no recent checkpoint exists -- can we handle this case?
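A minimal sketch of one way to handle the missing-checkpoint case follows. The in-memory dict replaces the real Lakebase lookup, the state field is made up, and raising LookupError is just one option; inserting a new checkpoint instead would be an equally reasonable design.

```python
from typing import Dict, Optional


class GenericCheckpoint:
    """Simplified stand-in for the PR's checkpoint class."""

    def __init__(self, id: str, state: str = "") -> None:
        self.id = id
        self.state = state

    def update(self, **kwargs) -> None:
        for key, value in kwargs.items():
            setattr(self, key, value)


class LakebaseCheckpointer:
    def __init__(self) -> None:
        # Assumption: a dict stands in for the Lakebase table.
        self._rows: Dict[str, GenericCheckpoint] = {}

    def get_most_recent_checkpoint(self, id: str) -> Optional[GenericCheckpoint]:
        return self._rows.get(id)

    def update_most_recent_checkpoint(self, id: str, **checkpoint_kwargs) -> None:
        checkpoint = self.get_most_recent_checkpoint(id=id)
        if checkpoint is None:
            # Fail with a clear message instead of an AttributeError on None.
            raise LookupError(f"No checkpoint found for id {id!r}")
        checkpoint.update(**checkpoint_kwargs)


checkpointer = LakebaseCheckpointer()
checkpointer._rows["abc"] = GenericCheckpoint(id="abc", state="old")
checkpointer.update_most_recent_checkpoint(id="abc", state="new")
updated_state = checkpointer.get_most_recent_checkpoint(id="abc").state
```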
    def update(self, **kwargs):
        for k,v in kwargs.items():
            # Check attribute exists
            assert hasattr(self, k), f"Attribute {k} does not exist in {self.__class__.__name__}"
Instead of assert, can we raise a ValueError or AttributeError? assert statements are stripped when Python runs with optimization requested (command line option -O).
https://docs.python.org/3/reference/simple_stmts.html#the-assert-statement
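The suggested change could look roughly like this. The id and state fields are assumptions standing in for whatever attributes the real checkpoint class defines.

```python
class GenericCheckpoint:
    """Simplified stand-in for the PR's checkpoint class."""

    def __init__(self, id: str, state: str = "") -> None:
        self.id = id
        self.state = state

    def update(self, **kwargs) -> None:
        for key, value in kwargs.items():
            # An explicit raise still runs under "python -O", unlike assert.
            if not hasattr(self, key):
                raise AttributeError(
                    f"Attribute {key} does not exist in {self.__class__.__name__}"
                )
            setattr(self, key, value)


checkpoint = GenericCheckpoint(id="abc")
checkpoint.update(state="running")

try:
    checkpoint.update(unknown_field=1)
    error_message = ""
except AttributeError as exc:
    error_message = str(exc)
```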
What
Added classes and functions to handle generic agent checkpointing with lakebase instances.
Why
Existing checkpoint implementations depend on LangChain / LangGraph.
What Changed
Added lakebase_checkpointer.py, a file which contains two key classes:
- GenericCheckpoint - class that defines the checkpoint schema. Can be subclassed to change the checkpoint schema.
- LakebaseCheckpointer - accepts a LakebaseClient and the checkpoint class as input. Implements helper functions for interacting with LakebaseClient using the checkpoint class (save, insert, update, get).