Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
141 changes: 141 additions & 0 deletions src/databricks_ai_bridge/lakebase_checkpointer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,141 @@
import re
from abc import ABC, abstractmethod
from datetime import datetime
from typing import Optional, Tuple

from psycopg.types.json import Json
from pydantic import BaseModel, Field

from .lakebase import LakebaseClient

Check failure on line 6 in src/databricks_ai_bridge/lakebase_checkpointer.py

View workflow job for this annotation

GitHub Actions / ruff check for .

Ruff (F401)

src/databricks_ai_bridge/lakebase_checkpointer.py:6:22: F401 `abc.abstractmethod` imported but unused

Check failure on line 6 in src/databricks_ai_bridge/lakebase_checkpointer.py

View workflow job for this annotation

GitHub Actions / ruff check for .

Ruff (F401)

src/databricks_ai_bridge/lakebase_checkpointer.py:6:17: F401 `abc.ABC` imported but unused
from typing import Any, Union, Tuple

Check failure on line 7 in src/databricks_ai_bridge/lakebase_checkpointer.py

View workflow job for this annotation

GitHub Actions / ruff check for .

Ruff (F401)

src/databricks_ai_bridge/lakebase_checkpointer.py:7:25: F401 `typing.Union` imported but unused

Check failure on line 7 in src/databricks_ai_bridge/lakebase_checkpointer.py

View workflow job for this annotation

GitHub Actions / ruff check for .

Ruff (F401)

src/databricks_ai_bridge/lakebase_checkpointer.py:7:20: F401 `typing.Any` imported but unused
from uuid import uuid4

Check failure on line 8 in src/databricks_ai_bridge/lakebase_checkpointer.py

View workflow job for this annotation

GitHub Actions / ruff check for .

Ruff (I001)

src/databricks_ai_bridge/lakebase_checkpointer.py:1:1: I001 Import block is un-sorted or un-formatted

class GenericCheckpoint(BaseModel):
    """A single checkpoint row: an ``id``, a JSON ``state`` payload, and
    creation/update timestamps.

    Subclasses that add columns must re-implement the ``generate_*_sql``
    methods to handle any changes in attributes; the ``id`` attribute must
    always exist so the checkpoint can be located.
    """

    ####################
    # CHECKPOINT COLUMNS
    ####################
    # `id` column must exist for locating the checkpoint, even in subclasses.
    id: str = Field(default_factory=lambda: str(uuid4()))
    # Explicit factory avoids any ambiguity about a shared mutable default.
    state: dict = Field(default_factory=dict)
    creation_timestamp: datetime = Field(default_factory=datetime.now)
    update_timestamp: datetime = Field(default_factory=datetime.now)

    ####################
    # UPDATE ATTRIBUTES
    ####################
    def update(self, **kwargs) -> None:
        """Assign existing attributes from ``kwargs`` and refresh ``update_timestamp``.

        Raises:
            AttributeError: if a keyword does not name an existing attribute.
        """
        for k, v in kwargs.items():
            # Raise rather than assert: asserts are stripped under `python -O`.
            if not hasattr(self, k):
                raise AttributeError(
                    f"Attribute {k} does not exist in {self.__class__.__name__}"
                )
            setattr(self, k, v)
        self.update_timestamp = datetime.now()

    #################################
    # SQL GENERATION IMPLEMENTATIONS
    #################################
    # If subclassing the `GenericCheckpoint` class, implement new methods to
    # handle any changes in attributes.
    @staticmethod
    def _validate_table_name(table_name: str) -> str:
        """Allow only plain (optionally schema-qualified) SQL identifiers.

        ``table_name`` is interpolated directly into SQL text, so anything
        other than a bare identifier would be an injection vector.
        NOTE(review): composing with psycopg's ``sql.Identifier`` would be
        stronger still, provided the client's ``execute`` accepts
        ``sql.Composed`` objects — confirm against LakebaseClient.
        """
        if not re.fullmatch(
            r"[A-Za-z_][A-Za-z0-9_$]*(\.[A-Za-z_][A-Za-z0-9_$]*)*", table_name
        ):
            raise ValueError(f"Invalid table name: {table_name!r}")
        return table_name

    def generate_insert_sql(self, table_name: str) -> Tuple[str, tuple]:
        """Return ``(sql, params)`` inserting this checkpoint as a new row."""
        table_name = self._validate_table_name(table_name)
        sql = f"""
            INSERT INTO {table_name}
            (id, state, creation_timestamp, update_timestamp)
            VALUES (%s, %s, %s, %s)
        """
        return sql, (self.id, Json(self.state), self.creation_timestamp, self.update_timestamp)

    def generate_update_sql(self, table_name: str) -> Tuple[str, tuple]:
        """Return ``(sql, params)`` updating the row matching this checkpoint's
        ``id`` and ``creation_timestamp``."""
        table_name = self._validate_table_name(table_name)
        sql = f"""
            UPDATE {table_name}
            SET state = %s, update_timestamp = %s
            WHERE id = %s AND creation_timestamp = %s
        """
        return sql, (Json(self.state), self.update_timestamp, self.id, self.creation_timestamp)

    def generate_init_sql(self, table_name: str) -> Tuple[str, Optional[tuple]]:
        """Return ``(sql, None)`` creating the checkpoint table if absent.

        The DDL takes no parameters, hence the ``None`` second element
        (annotation widened to ``Optional[tuple]`` to match).
        """
        table_name = self._validate_table_name(table_name)
        sql = f"""
            CREATE TABLE IF NOT EXISTS {table_name} (
                lb_id bigserial PRIMARY KEY,
                id text NOT NULL,
                state jsonb NOT NULL default '{{}}',
                creation_timestamp timestamp NOT NULL DEFAULT CURRENT_TIMESTAMP,
                update_timestamp timestamp NOT NULL DEFAULT CURRENT_TIMESTAMP,
                UNIQUE(id,creation_timestamp)
            )
        """
        return sql, None

    def generate_retrieve_checkpoint_sql(self, table_name: str) -> Tuple[str, tuple]:
        """Return ``(sql, params)`` fetching only the most recently updated
        checkpoint for this ``id``."""
        sql_all, params = self.generate_retrieval_all_checkpoints_sql(table_name=table_name)
        sql = f"""
            {sql_all} LIMIT 1
        """
        return sql, params

    def generate_retrieval_all_checkpoints_sql(self, table_name: str) -> Tuple[str, tuple]:
        """Return ``(sql, params)`` fetching every checkpoint for this ``id``,
        newest first."""
        table_name = self._validate_table_name(table_name)
        sql = f"""
            SELECT id, state, creation_timestamp, update_timestamp
            FROM {table_name}
            WHERE id = %s
            ORDER BY update_timestamp DESC
        """
        return sql, (self.id,)

class LakebaseCheckpointer:
    """Persist and retrieve checkpoints in a Lakebase-backed Postgres table.

    Usable as a context manager, mirroring the existing checkpoint saver:

        with LakebaseCheckpointer(client, "checkpoints") as cp:
            cp.save_checkpoint(id="abc", state={"k": "v"})
    """

    def __init__(
        self,
        lakebase_client: LakebaseClient,
        sessions_table_name: str,
        checkpoint_class: type[GenericCheckpoint] = GenericCheckpoint,
        create_tables: bool = True,
    ):
        """
        Args:
            lakebase_client: client used to execute SQL against Lakebase.
            sessions_table_name: name of the table checkpoints are stored in.
            checkpoint_class: ``GenericCheckpoint`` (sub)class used to
                (de)serialize rows.
            create_tables: when True (default), issue the CREATE TABLE IF NOT
                EXISTS DDL on init; pass False to skip the round-trip when the
                table is known to exist.
        """
        self.checkpoint_class = checkpoint_class
        self.client = lakebase_client
        self.table_name = sessions_table_name
        if create_tables:
            self.init_schema()

    def __enter__(self) -> "LakebaseCheckpointer":
        """Context-manager entry; returns self."""
        return self

    def __exit__(self, exc_type, exc_value, traceback) -> None:
        # NOTE(review): the LakebaseClient appears to own the connection
        # lifecycle, so nothing is released here — confirm against the
        # client implementation.
        return None

    def init_schema(self) -> None:
        """Create the checkpoint table if it does not already exist."""
        checkpoint = self.checkpoint_class()
        sql, params = checkpoint.generate_init_sql(table_name=self.table_name)
        self.client.execute(sql=sql, params=params)

    def get_most_recent_checkpoint(self, id: str) -> GenericCheckpoint | None:
        """Return the most recently updated checkpoint for ``id``, or None
        when no row exists."""
        checkpoint = self.checkpoint_class(id=id)
        sql, params = checkpoint.generate_retrieve_checkpoint_sql(table_name=self.table_name)
        response = self.client.execute(sql=sql, params=params)
        # Falsy covers both an empty result set and a None client response.
        if not response:
            return None
        return self.checkpoint_class(**response[0])

    def get_all_checkpoints(self, id: str) -> list[GenericCheckpoint]:
        """Return every checkpoint stored for ``id``, newest first."""
        checkpoint = self.checkpoint_class(id=id)
        sql, params = checkpoint.generate_retrieval_all_checkpoints_sql(table_name=self.table_name)
        response = self.client.execute(sql=sql, params=params)
        return [self.checkpoint_class(**row) for row in (response or [])]

    def checkpoint_exists(self, id: str) -> bool:
        """Return True if at least one checkpoint row exists for ``id``."""
        # NOTE(review): table name is interpolated into SQL text here —
        # restrict it to a valid identifier (or use psycopg sql.Identifier).
        sql = f"""
            SELECT COUNT(*) AS count
            FROM {self.table_name}
            WHERE id = %s
        """
        response = self.client.execute(sql=sql, params=(id,))
        return bool(response) and response[0]["count"] > 0

    def update_most_recent_checkpoint(self, id: str, **checkpoint_kwargs) -> None:
        """Apply ``checkpoint_kwargs`` to the newest checkpoint for ``id``.

        Raises:
            ValueError: if no checkpoint exists for ``id``.
        """
        checkpoint = self.get_most_recent_checkpoint(id=id)
        if checkpoint is None:
            raise ValueError(f"No checkpoint found for id {id!r}; nothing to update.")
        checkpoint.update(**checkpoint_kwargs)
        sql, params = checkpoint.generate_update_sql(table_name=self.table_name)
        self.client.execute(sql=sql, params=params)

    def insert_checkpoint(self, id: str, **checkpoint_kwargs) -> None:
        """Insert a brand-new checkpoint row for ``id``."""
        checkpoint = self.checkpoint_class(id=id, **checkpoint_kwargs)
        sql, params = checkpoint.generate_insert_sql(table_name=self.table_name)
        self.client.execute(sql=sql, params=params)

    def save_checkpoint(self, id: str, overwrite: bool = False, **checkpoint_kwargs) -> None:
        """Insert a checkpoint; when ``overwrite`` is True and a row already
        exists for ``id``, update the newest one instead."""
        if overwrite and self.checkpoint_exists(id=id):
            # Fixed typo: original called the nonexistent `update_checkpoint`.
            self.update_most_recent_checkpoint(id=id, **checkpoint_kwargs)
        else:
            # Not overwriting (or nothing to overwrite): insert a new row.
            self.insert_checkpoint(id=id, **checkpoint_kwargs)
Loading