11import base64
22import contextlib
33import json
4- import logging
54import re
65import uuid
76import warnings
8- from typing import Any
7+ import weakref
8+ from typing import TYPE_CHECKING , Any , Optional
99from urllib .parse import quote
1010
1111import google .oauth2 .credentials
1212import numpy as np
1313import requests
14+ import wrapt
1415from cryptography .hazmat .backends import default_backend
1516from cryptography .hazmat .primitives import serialization
1617from google .api_core .client_info import ClientInfo
1718from google .cloud import bigquery
1819from packaging .version import parse as parse_version
1920from pydantic import BaseModel
20- from sqlalchemy .engine import URL , create_engine , make_url
21+ from sqlalchemy .engine import URL , Connection , create_engine , make_url
2122from sqlalchemy .exc import ResourceClosedError
2223
2324from deepnote_core .pydantic_compat_helpers import model_validate_compat
2829 get_project_auth_headers ,
2930)
3031from deepnote_toolkit .ipython_utils import output_sql_metadata
32+ from deepnote_toolkit .logging import LoggerManager
3133from deepnote_toolkit .ocelots .pandas .utils import deduplicate_columns
3234from deepnote_toolkit .sql .duckdb_sql import execute_duckdb_sql
3335from deepnote_toolkit .sql .jinjasql_utils import render_jinja_sql_template
3739from deepnote_toolkit .sql .sql_utils import is_single_select_query
3840from deepnote_toolkit .sql .url_utils import replace_user_pass_in_pg_url
3941
40- logger = logging .getLogger (__name__ )
if TYPE_CHECKING:
    # These names are imported for static type checkers only. They are not
    # defined at runtime, so any annotation that mentions them has to stay
    # unevaluated (e.g. written as a string).
    try:
        from sqlalchemy.engine.interfaces import DBAPIConnection, DBAPICursor
    except ImportError:
        # Not available in SQLAlchemy < 2.0. We use them only for typing, so replace with Any
        DBAPIConnection = Any
        DBAPICursor = Any

# Module-level logger, obtained through the toolkit's LoggerManager.
logger = LoggerManager().get_logger()
4151
4252
4353class IntegrationFederatedAuthParams (BaseModel ):
@@ -517,12 +527,97 @@ def _query_data_source(
517527 engine .dispose ()
518528
519529
class CursorTrackingDBAPIConnection(wrapt.ObjectProxy):
    """Wraps DBAPI connection to track cursors as they're created.

    Every attribute other than ``cursor`` and ``cancel_all_cursors`` is
    forwarded to the wrapped connection by ``wrapt.ObjectProxy``. Cursors are
    held in a ``weakref.WeakSet``, so tracking them does not keep them alive.
    """

    def __init__(
        self,
        wrapped: "DBAPIConnection",
        # The whole annotation is quoted: subscripting ``weakref.WeakSet`` at
        # definition time is only supported on Python 3.9+, and the DBAPI
        # type names exist only under ``TYPE_CHECKING``.
        cursor_registry: "Optional[weakref.WeakSet[DBAPICursor]]" = None,
    ) -> None:
        """Create the proxy.

        Args:
            wrapped: The DBAPI connection to proxy.
            cursor_registry: Optional set that receives every cursor created
                through this proxy. Passing one lets the caller share the
                registry with another object; when omitted a private
                ``WeakSet`` is created.
        """
        super().__init__(wrapped)
        # Use provided registry or create our own. The ``_self_`` prefix makes
        # wrapt store the attribute on the proxy instead of the wrapped object.
        self._self_cursor_registry = (
            cursor_registry if cursor_registry is not None else weakref.WeakSet()
        )

    def cursor(self, *args, **kwargs):
        """Create a cursor on the wrapped connection and record it.

        All arguments are passed through unchanged. If the cursor object cannot
        be stored in the weak set (``TypeError``), a warning is logged and the
        cursor is still returned, just untracked.
        """
        cursor = self.__wrapped__.cursor(*args, **kwargs)
        try:
            self._self_cursor_registry.add(cursor)
        except TypeError:
            logger.warning(
                f"DBAPI Cursor of type {type(cursor)} can't be added to weakset and thus can't be tracked."
            )
        return cursor

    def cancel_all_cursors(self):
        """Cancel all tracked cursors. Best-effort, ignores errors."""
        for cursor in self._self_cursor_registry:
            _cancel_cursor(cursor)
559+
class CursorTrackingSQLAlchemyConnection(wrapt.ObjectProxy):
    """SQLAlchemy connection proxy whose DBAPI cursors can be cancelled in bulk.

    On construction, the DBAPI connection stored on the wrapped connection's
    ``.connection`` object is swapped for a ``CursorTrackingDBAPIConnection``
    sharing this proxy's cursor registry, so cursors opened through it
    (including by exec_driver_sql) end up tracked here.
    """

    def __init__(self, wrapped: Connection) -> None:
        super().__init__(wrapped)
        # Weak registry shared with the DBAPI-level tracking proxy.
        self._self_cursors: "weakref.WeakSet[DBAPICursor]" = weakref.WeakSet()
        self._install_dbapi_wrapper()

    def _install_dbapi_wrapper(self):
        """Swap the internal DBAPI connection for a cursor-tracking proxy.

        Never raises: any failure is logged as a warning and tracking is
        simply not installed.
        """
        try:
            holder = self.__wrapped__.connection
            # Newer SQLAlchemy exposes ``dbapi_connection``; SQLAlchemy
            # pre v1.4 keeps the DBAPI connection under ``connection``.
            attr_name = (
                "dbapi_connection"
                if hasattr(holder, "dbapi_connection")
                else "connection"
            )
            raw_dbapi = getattr(holder, attr_name)

            if raw_dbapi is None:
                logger.warning(
                    f"DBAPI connection is None (connection type {type(self.__wrapped__)}), cannot install tracking"
                )
                return

            tracker = CursorTrackingDBAPIConnection(raw_dbapi, self._self_cursors)
            setattr(holder, attr_name, tracker)
        except Exception as e:
            logger.warning(f"Could not install DBAPI wrapper: {e}")

    def cancel_all_cursors(self):
        """Cancel all tracked cursors. Best-effort, ignores errors."""
        for tracked in self._self_cursors:
            _cancel_cursor(tracked)
602+
603+ def _cancel_cursor (cursor : "DBAPICursor" ) -> None :
604+ """Best-effort cancel a cursor using available methods."""
605+ try :
606+ if hasattr (cursor , "cancel" ) and callable (cursor .cancel ):
607+ cursor .cancel ()
608+ except (Exception , KeyboardInterrupt ):
609+ pass # Best effort, ignore all errors
610+
611+
520612def _execute_sql_on_engine (engine , query , bind_params ):
521613 """Run *query* on *engine* and return a DataFrame.
522614
523615 Uses pandas.read_sql_query to execute the query with a SQLAlchemy connection.
524616 For pandas 2.2+ and SQLAlchemy < 2.0, which requires a raw DB-API connection with a `.cursor()` attribute,
525617 we use the underlying connection.
618+
    On KeyboardInterrupt (e.g. from cell cancellation), all cursors created during
    execution are cancelled to stop running queries on the server.
526621 """
527622
528623 import pandas as pd
@@ -544,26 +639,30 @@ def _execute_sql_on_engine(engine, query, bind_params):
544639 )
545640
546641 with engine .begin () as connection :
547- try :
548- # For pandas 2.2+, use raw connection to avoid 'cursor' AttributeError
549- connection_for_pandas = (
550- connection . connection if needs_raw_connection else connection
551- )
642+ # For pandas 2.2+ with SQLAlchemy < 2.0, use raw DBAPI connection
643+ if needs_raw_connection :
644+ tracking_connection = CursorTrackingDBAPIConnection ( connection . connection )
645+ else :
646+ tracking_connection = CursorTrackingSQLAlchemyConnection ( connection )
552647
648+ try :
553649 # pandas.read_sql_query expects params as tuple (not list) for qmark/format style
554650 params_for_pandas = (
555651 tuple (bind_params ) if isinstance (bind_params , list ) else bind_params
556652 )
557653
558654 return pd .read_sql_query (
559655 query ,
560- con = connection_for_pandas ,
656+ con = tracking_connection ,
561657 params = params_for_pandas ,
562658 coerce_float = coerce_float ,
563659 )
564660 except ResourceClosedError :
565661 # this happens if the query is e.g. UPDATE and pandas tries to create a dataframe from its result
566662 return None
663+ except KeyboardInterrupt :
664+ tracking_connection .cancel_all_cursors ()
665+ raise
567666
568667
569668def _build_params_for_bigquery_oauth (params ):
0 commit comments