Skip to content

Commit 93e3ef0

Browse files
authored
fix: Add generic mechanism to cancel queries on exception (#64)
1 parent 764c35a commit 93e3ef0

2 files changed

Lines changed: 231 additions & 10 deletions

File tree

deepnote_toolkit/sql/sql_execution.py

Lines changed: 109 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,24 @@
11
import base64
22
import contextlib
33
import json
4-
import logging
54
import re
65
import uuid
76
import warnings
8-
from typing import Any
7+
import weakref
8+
from typing import TYPE_CHECKING, Any, Optional
99
from urllib.parse import quote
1010

1111
import google.oauth2.credentials
1212
import numpy as np
1313
import requests
14+
import wrapt
1415
from cryptography.hazmat.backends import default_backend
1516
from cryptography.hazmat.primitives import serialization
1617
from google.api_core.client_info import ClientInfo
1718
from google.cloud import bigquery
1819
from packaging.version import parse as parse_version
1920
from pydantic import BaseModel
20-
from sqlalchemy.engine import URL, create_engine, make_url
21+
from sqlalchemy.engine import URL, Connection, create_engine, make_url
2122
from sqlalchemy.exc import ResourceClosedError
2223

2324
from deepnote_core.pydantic_compat_helpers import model_validate_compat
@@ -28,6 +29,7 @@
2829
get_project_auth_headers,
2930
)
3031
from deepnote_toolkit.ipython_utils import output_sql_metadata
32+
from deepnote_toolkit.logging import LoggerManager
3133
from deepnote_toolkit.ocelots.pandas.utils import deduplicate_columns
3234
from deepnote_toolkit.sql.duckdb_sql import execute_duckdb_sql
3335
from deepnote_toolkit.sql.jinjasql_utils import render_jinja_sql_template
@@ -37,7 +39,15 @@
3739
from deepnote_toolkit.sql.sql_utils import is_single_select_query
3840
from deepnote_toolkit.sql.url_utils import replace_user_pass_in_pg_url
3941

40-
logger = logging.getLogger(__name__)
42+
if TYPE_CHECKING:
43+
try:
44+
from sqlalchemy.engine.interfaces import DBAPIConnection, DBAPICursor
45+
except ImportError:
46+
# Not available in SQLAlchemy < 2.0. We use them only for typing, so replace with Any
47+
DBAPIConnection = Any
48+
DBAPICursor = Any
49+
50+
logger = LoggerManager().get_logger()
4151

4252

4353
class IntegrationFederatedAuthParams(BaseModel):
@@ -517,12 +527,97 @@ def _query_data_source(
517527
engine.dispose()
518528

519529

530+
class CursorTrackingDBAPIConnection(wrapt.ObjectProxy):
531+
"""Wraps DBAPI connection to track cursors as they're created."""
532+
533+
def __init__(
534+
self,
535+
wrapped: "DBAPIConnection",
536+
cursor_registry: Optional[weakref.WeakSet["DBAPICursor"]] = None,
537+
) -> None:
538+
super().__init__(wrapped)
539+
# Use provided registry or create our own
540+
self._self_cursor_registry = (
541+
cursor_registry if cursor_registry is not None else weakref.WeakSet()
542+
)
543+
544+
def cursor(self, *args, **kwargs):
545+
cursor = self.__wrapped__.cursor(*args, **kwargs)
546+
try:
547+
self._self_cursor_registry.add(cursor)
548+
except TypeError:
549+
logger.warning(
550+
f"DBAPI Cursor of type {type(cursor)} can't be added to weakset and thus can't be tracked."
551+
)
552+
return cursor
553+
554+
def cancel_all_cursors(self):
555+
"""Cancel all tracked cursors. Best-effort, ignores errors."""
556+
for cursor in self._self_cursor_registry:
557+
_cancel_cursor(cursor)
558+
559+
560+
class CursorTrackingSQLAlchemyConnection(wrapt.ObjectProxy):
561+
"""A SQLAlchemy connection wrapper that tracks cursors for cancellation.
562+
563+
This wrapper replaces the internal _dbapi_connection with a tracking proxy,
564+
so all cursors created (including by exec_driver_sql) are tracked.
565+
"""
566+
567+
def __init__(self, wrapped: Connection) -> None:
568+
super().__init__(wrapped)
569+
self._self_cursors: weakref.WeakSet[DBAPICursor] = weakref.WeakSet()
570+
self._install_dbapi_wrapper()
571+
572+
def _install_dbapi_wrapper(self):
573+
"""Replace SQLAlchemy's internal DBAPI connection with our tracking wrapper."""
574+
try:
575+
# Access the internal DBAPI connection
576+
if hasattr(self.__wrapped__.connection, "dbapi_connection"):
577+
dbapi_conn = self.__wrapped__.connection.dbapi_connection
578+
dbapi_connection_attr_name = "dbapi_connection"
579+
else:
580+
# SQLAlchemy pre v1.4
581+
dbapi_conn = self.__wrapped__.connection.connection
582+
dbapi_connection_attr_name = "connection"
583+
if dbapi_conn is None:
584+
logger.warning(
585+
f"DBAPI connection is None (connection type {type(self.__wrapped__)}), cannot install tracking"
586+
)
587+
return
588+
589+
setattr(
590+
self.__wrapped__.connection,
591+
dbapi_connection_attr_name,
592+
CursorTrackingDBAPIConnection(dbapi_conn, self._self_cursors),
593+
)
594+
except Exception as e:
595+
logger.warning(f"Could not install DBAPI wrapper: {e}")
596+
597+
def cancel_all_cursors(self):
598+
"""Cancel all tracked cursors. Best-effort, ignores errors."""
599+
for cursor in self._self_cursors:
600+
_cancel_cursor(cursor)
601+
602+
603+
def _cancel_cursor(cursor: "DBAPICursor") -> None:
604+
"""Best-effort cancel a cursor using available methods."""
605+
try:
606+
if hasattr(cursor, "cancel") and callable(cursor.cancel):
607+
cursor.cancel()
608+
except (Exception, KeyboardInterrupt):
609+
pass # Best effort, ignore all errors
610+
611+
520612
def _execute_sql_on_engine(engine, query, bind_params):
521613
"""Run *query* on *engine* and return a DataFrame.
522614
523615
Uses pandas.read_sql_query to execute the query with a SQLAlchemy connection.
524616
For pandas 2.2+ and SQLAlchemy < 2.0, which requires a raw DB-API connection with a `.cursor()` attribute,
525617
we use the underlying connection.
618+
619+
On exceptions (including KeyboardInterrupt from cell cancellation), all cursors
620+
created during execution are cancelled to stop running queries on the server.
526621
"""
527622

528623
import pandas as pd
@@ -544,26 +639,30 @@ def _execute_sql_on_engine(engine, query, bind_params):
544639
)
545640

546641
with engine.begin() as connection:
547-
try:
548-
# For pandas 2.2+, use raw connection to avoid 'cursor' AttributeError
549-
connection_for_pandas = (
550-
connection.connection if needs_raw_connection else connection
551-
)
642+
# For pandas 2.2+ with SQLAlchemy < 2.0, use raw DBAPI connection
643+
if needs_raw_connection:
644+
tracking_connection = CursorTrackingDBAPIConnection(connection.connection)
645+
else:
646+
tracking_connection = CursorTrackingSQLAlchemyConnection(connection)
552647

648+
try:
553649
# pandas.read_sql_query expects params as tuple (not list) for qmark/format style
554650
params_for_pandas = (
555651
tuple(bind_params) if isinstance(bind_params, list) else bind_params
556652
)
557653

558654
return pd.read_sql_query(
559655
query,
560-
con=connection_for_pandas,
656+
con=tracking_connection,
561657
params=params_for_pandas,
562658
coerce_float=coerce_float,
563659
)
564660
except ResourceClosedError:
565661
# this happens if the query is e.g. UPDATE and pandas tries to create a dataframe from its result
566662
return None
663+
except KeyboardInterrupt:
664+
tracking_connection.cancel_all_cursors()
665+
raise
567666

568667

569668
def _build_params_for_bigquery_oauth(params):

tests/unit/test_sql_execution_internal.py

Lines changed: 122 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import uuid
2+
from typing import Any
23
from unittest import mock
34

45
import numpy as np
@@ -9,6 +10,42 @@
910
from deepnote_toolkit.sql import sql_execution as se
1011

1112

13+
def _setup_mock_engine_with_cursor(mock_cursor: mock.Mock) -> mock.Mock:
14+
"""Helper to set up mock engine and connection with a custom cursor.
15+
16+
Returns mock_engine that can be passed to _execute_sql_on_engine.
17+
"""
18+
import sqlalchemy
19+
20+
mock_dbapi_connection: mock.Mock = mock.Mock()
21+
mock_dbapi_connection.cursor.return_value = mock_cursor
22+
23+
mock_pool_connection = mock.Mock()
24+
mock_pool_connection.dbapi_connection = mock_dbapi_connection
25+
mock_pool_connection.cursor.side_effect = (
26+
lambda: mock_pool_connection.dbapi_connection.cursor()
27+
)
28+
29+
mock_sa_connection = mock.Mock(spec=sqlalchemy.engine.Connection)
30+
mock_sa_connection.connection = mock_pool_connection
31+
mock_sa_connection.in_transaction.return_value = False
32+
33+
def mock_exec_driver_sql(sql: str, *args: Any) -> mock.Mock:
34+
cursor: mock.Mock = mock_sa_connection.connection.cursor()
35+
cursor.execute(sql, *args)
36+
return cursor
37+
38+
mock_sa_connection.exec_driver_sql = mock_exec_driver_sql
39+
40+
mock_engine = mock.Mock()
41+
mock_engine.begin.return_value.__enter__ = mock.Mock(
42+
return_value=mock_sa_connection
43+
)
44+
mock_engine.begin.return_value.__exit__ = mock.Mock(return_value=False)
45+
46+
return mock_engine
47+
48+
1249
def test_bigquery_wait_or_cancel_handles_keyboard_interrupt():
1350
import google.cloud.bigquery._job_helpers as _job_helpers
1451

@@ -30,6 +67,91 @@ def test_bigquery_wait_or_cancel_handles_keyboard_interrupt():
3067
mock_job.cancel.assert_called_once_with(retry=None, timeout=30.0)
3168

3269

70+
def test_execute_sql_on_engine_cancels_cursor_on_keyboard_interrupt():
71+
"""Test that _execute_sql_on_engine cancels cursors on KeyboardInterrupt."""
72+
73+
mock_cursor = mock.MagicMock()
74+
mock_cursor.execute.side_effect = KeyboardInterrupt("Cancelled")
75+
76+
mock_engine = _setup_mock_engine_with_cursor(mock_cursor)
77+
78+
with pytest.raises(KeyboardInterrupt):
79+
se._execute_sql_on_engine(mock_engine, "SELECT 1", {})
80+
81+
mock_cursor.cancel.assert_called_once()
82+
83+
84+
def test_execute_sql_on_engine_handles_cancel_errors_gracefully():
85+
"""Test that _execute_sql_on_engine handles cancel errors gracefully."""
86+
87+
mock_cursor = mock.MagicMock()
88+
mock_cursor.execute.side_effect = KeyboardInterrupt("Cancelled")
89+
mock_cursor.cancel.side_effect = RuntimeError("Cancel failed")
90+
91+
mock_engine = _setup_mock_engine_with_cursor(mock_cursor)
92+
93+
# Should raise original KeyboardInterrupt, not the cancel error
94+
with pytest.raises(KeyboardInterrupt):
95+
se._execute_sql_on_engine(mock_engine, "SELECT 1", {})
96+
97+
mock_cursor.cancel.assert_called_once()
98+
99+
100+
def test_cursor_tracking_dbapi_connection_cancel_all_cursors():
101+
"""Test that CursorTrackingDBAPIConnection.cancel_all_cursors cancels all tracked cursors."""
102+
mock_wrapped_conn = mock.Mock()
103+
cursor1 = mock.Mock()
104+
cursor2 = mock.Mock()
105+
mock_wrapped_conn.cursor.side_effect = [cursor1, cursor2]
106+
107+
tracking_conn = se.CursorTrackingDBAPIConnection(mock_wrapped_conn)
108+
109+
# Create two cursors
110+
tracking_conn.cursor()
111+
tracking_conn.cursor()
112+
113+
# Cancel all cursors
114+
tracking_conn.cancel_all_cursors()
115+
116+
cursor1.cancel.assert_called_once()
117+
cursor2.cancel.assert_called_once()
118+
119+
120+
def test_cursor_tracking_dbapi_connection_handles_unhashable_cursor():
121+
"""Test that CursorTrackingDBAPIConnection handles cursors that can't be added to weakset."""
122+
mock_wrapped_conn = mock.Mock()
123+
124+
class UnhashableCursor:
125+
__hash__ = None
126+
127+
unhashable_cursor = UnhashableCursor()
128+
mock_wrapped_conn.cursor.return_value = unhashable_cursor
129+
130+
tracking_conn = se.CursorTrackingDBAPIConnection(mock_wrapped_conn)
131+
132+
with mock.patch.object(se.logger, "warning") as mock_warning:
133+
result = tracking_conn.cursor()
134+
135+
assert result is unhashable_cursor
136+
mock_warning.assert_called_once()
137+
assert "can't be added to weakset" in mock_warning.call_args[0][0]
138+
139+
140+
def test_cursor_tracking_sqlalchemy_connection_handles_none_dbapi_connection():
141+
"""Test that CursorTrackingSQLAlchemyConnection handles None dbapi connection."""
142+
mock_conn_pool = mock.Mock()
143+
mock_conn_pool.dbapi_connection = None
144+
145+
mock_sa_conn = mock.Mock()
146+
mock_sa_conn.connection = mock_conn_pool
147+
148+
with mock.patch.object(se.logger, "warning") as mock_warning:
149+
se.CursorTrackingSQLAlchemyConnection(mock_sa_conn)
150+
151+
mock_warning.assert_called_once()
152+
assert "DBAPI connection is None" in mock_warning.call_args[0][0]
153+
154+
33155
def test_build_params_for_bigquery_oauth_ok():
34156
with mock.patch(
35157
"deepnote_toolkit.sql.sql_execution.bigquery.Client"

0 commit comments

Comments
 (0)