From 34e0b8944c1656d8978af71019b52550925e1901 Mon Sep 17 00:00:00 2001 From: Joe S Date: Mon, 12 Jan 2026 16:45:55 -0800 Subject: [PATCH 01/40] implement native async client --- .github/workflows/clickhouse_ci.yml | 4 +- .github/workflows/on_push.yml | 8 +- README.md | 10 +- clickhouse_connect/common.py | 2 +- clickhouse_connect/driver/__init__.py | 147 +- clickhouse_connect/driver/aiohttp_client.py | 1759 +++++++++++++++++ clickhouse_connect/driver/asyncclient.py | 727 +++---- clickhouse_connect/driver/asyncqueue.py | 174 ++ clickhouse_connect/driver/client.py | 29 +- clickhouse_connect/driver/common.py | 56 +- clickhouse_connect/driver/httpclient.py | 14 +- clickhouse_connect/driver/query.py | 20 +- clickhouse_connect/driver/streaming.py | 311 +++ clickhouse_connect/driver/tools.py | 36 + pyproject.toml | 1 + setup.py | 1 + tests/integration_tests/conftest.py | 168 +- tests/integration_tests/test_arrow.py | 51 +- .../integration_tests/test_async_features.py | 303 +++ tests/integration_tests/test_client.py | 373 ++-- tests/integration_tests/test_contexts.py | 20 +- tests/integration_tests/test_dynamic.py | 106 +- .../integration_tests/test_error_handling.py | 134 +- tests/integration_tests/test_external_data.py | 80 +- .../test_form_encode_query.py | 57 +- tests/integration_tests/test_formats.py | 14 +- tests/integration_tests/test_geometric.py | 20 +- tests/integration_tests/test_inserts.py | 70 +- tests/integration_tests/test_jwt_auth.py | 148 +- .../integration_tests/test_multithreading.py | 121 +- tests/integration_tests/test_native.py | 102 +- tests/integration_tests/test_native_fuzz.py | 20 +- tests/integration_tests/test_network.py | 76 +- tests/integration_tests/test_numeric.py | 101 +- tests/integration_tests/test_numpy.py | 106 +- tests/integration_tests/test_pandas.py | 172 +- tests/integration_tests/test_pandas_compat.py | 86 +- tests/integration_tests/test_params.py | 48 +- tests/integration_tests/test_polars.py | 44 +- .../test_protocol_version.py | 
8 +- tests/integration_tests/test_proxy.py | 51 +- .../test_pyarrow_ddl_integration.py | 44 +- tests/integration_tests/test_raw_insert.py | 18 +- tests/integration_tests/test_session_id.py | 65 +- tests/integration_tests/test_streaming.py | 150 +- tests/integration_tests/test_temporal.py | 453 ++--- tests/integration_tests/test_timezones.py | 82 +- tests/integration_tests/test_tls.py | 36 +- tests/integration_tests/test_tools.py | 45 +- tests/integration_tests/test_vector.py | 100 +- tests/test_requirements.txt | 2 + tests/unit_tests/test_asyncqueue.py | 297 +++ tests/unit_tests/test_streaming_source.py | 443 +++++ 53 files changed, 5558 insertions(+), 1955 deletions(-) create mode 100644 clickhouse_connect/driver/aiohttp_client.py create mode 100644 clickhouse_connect/driver/asyncqueue.py create mode 100644 clickhouse_connect/driver/streaming.py create mode 100644 tests/integration_tests/test_async_features.py create mode 100644 tests/unit_tests/test_asyncqueue.py create mode 100644 tests/unit_tests/test_streaming_source.py diff --git a/.github/workflows/clickhouse_ci.yml b/.github/workflows/clickhouse_ci.yml index a8901243..c816af25 100644 --- a/.github/workflows/clickhouse_ci.yml +++ b/.github/workflows/clickhouse_ci.yml @@ -36,11 +36,11 @@ jobs: CLICKHOUSE_CONNECT_TEST_PASSWORD: ${{ secrets.INTEGRATIONS_TEAM_TESTS_CLOUD_PASSWORD_SMT }} CLICKHOUSE_CONNECT_TEST_JWT_SECRET: ${{ secrets.INTEGRATIONS_TEAM_TESTS_CLOUD_JWT_DESERT_VM_43 }} SQLALCHEMY_SILENCE_UBER_WARNING: 1 - run: pytest tests/integration_tests + run: pytest tests/integration_tests -n 4 - name: Run ClickHouse Container (HEAD) run: CLICKHOUSE_CONNECT_TEST_CH_VERSION=head docker compose up -d clickhouse - name: Run HEAD tests - run: pytest tests/integration_tests + run: pytest tests/integration_tests -n 4 - name: remove HEAD container run: docker compose down -v diff --git a/.github/workflows/on_push.yml b/.github/workflows/on_push.yml index 2114f6d2..8b5a014e 100644 --- a/.github/workflows/on_push.yml 
+++ b/.github/workflows/on_push.yml @@ -104,7 +104,7 @@ jobs: CLICKHOUSE_CONNECT_TEST_DOCKER: 'False' CLICKHOUSE_CONNECT_TEST_FUZZ: 50 SQLALCHEMY_SILENCE_UBER_WARNING: 1 - run: pytest tests + run: pytest -n 4 tests pandas-1x-compat-test: runs-on: ubuntu-latest @@ -142,7 +142,7 @@ jobs: CLICKHOUSE_CONNECT_TEST_TLS: 1 CLICKHOUSE_CONNECT_TEST_DOCKER: 'False' SQLALCHEMY_SILENCE_UBER_WARNING: 1 - run: pytest tests/integration_tests/test_pandas_compat.py tests/integration_tests/test_pandas.py + run: pytest -n 4 tests/integration_tests/test_pandas_compat.py tests/integration_tests/test_pandas.py sqlalchemy-1x-compat-test: runs-on: ubuntu-latest @@ -180,7 +180,7 @@ jobs: CLICKHOUSE_CONNECT_TEST_TLS: 1 CLICKHOUSE_CONNECT_TEST_DOCKER: 'False' SQLALCHEMY_SILENCE_UBER_WARNING: 1 - run: pytest tests/integration_tests/test_sqlalchemy + run: pytest -n 4 tests/integration_tests/test_sqlalchemy check-secret: runs-on: ubuntu-latest @@ -234,4 +234,4 @@ jobs: CLICKHOUSE_CONNECT_TEST_PASSWORD: ${{ secrets.INTEGRATIONS_TEAM_TESTS_CLOUD_PASSWORD_SMT }} CLICKHOUSE_CONNECT_TEST_JWT_SECRET: ${{ secrets.INTEGRATIONS_TEAM_TESTS_CLOUD_JWT_DESERT_VM_43 }} SQLALCHEMY_SILENCE_UBER_WARNING: 1 - run: pytest tests/integration_tests + run: pytest -n 4 tests/integration_tests diff --git a/README.md b/README.md index 8b7288b9..e1995b72 100644 --- a/README.md +++ b/README.md @@ -17,7 +17,7 @@ ClickHouse Connect currently uses the ClickHouse HTTP interface for maximum comp pip install clickhouse-connect ``` -ClickHouse Connect requires Python 3.9 or higher. We officially test against Python 3.9 through 3.13. +ClickHouse Connect requires Python 3.9 or higher. We officially test against Python 3.9 through 3.14. ### Superset Connectivity @@ -45,7 +45,13 @@ that rely on full ORM or advanced dialect functionality. ### Asyncio Support -ClickHouse Connect provides an async wrapper, so that it is possible to use the client in an `asyncio` environment. 
+ClickHouse Connect provides native async support using aiohttp. For the best performance with async applications, +install the optional async dependency: + +``` +pip install clickhouse-connect[async] +``` + See the [run_async example](./examples/run_async.py) for more details. ### Complete Documentation diff --git a/clickhouse_connect/common.py b/clickhouse_connect/common.py index 04f43524..49a3646c 100644 --- a/clickhouse_connect/common.py +++ b/clickhouse_connect/common.py @@ -29,7 +29,7 @@ class CommonSetting: _common_settings: Dict[str, CommonSetting] = {} -def build_client_name(client_name: str) -> str: +def build_client_name(client_name: Optional[str]) -> str: product_name = get_setting('product_name') product_name = product_name.strip() + ' ' if product_name else '' client_name = client_name.strip() + ' ' if client_name else '' diff --git a/clickhouse_connect/driver/__init__.py b/clickhouse_connect/driver/__init__.py index e946492e..f5d0eda8 100644 --- a/clickhouse_connect/driver/__init__.py +++ b/clickhouse_connect/driver/__init__.py @@ -1,17 +1,70 @@ import asyncio from concurrent.futures import ThreadPoolExecutor from inspect import signature -from typing import Optional, Union, Dict, Any +from typing import Optional, Union, Dict, Any, Tuple from urllib.parse import urlparse, parse_qs import clickhouse_connect.driver.ctypes from clickhouse_connect.driver.client import Client -from clickhouse_connect.driver.common import dict_copy from clickhouse_connect.driver.exceptions import ProgrammingError from clickhouse_connect.driver.httpclient import HttpClient from clickhouse_connect.driver.asyncclient import AsyncClient, DefaultThreadPoolExecutor, NEW_THREAD_POOL_EXECUTOR +def default_port(interface: str, secure: bool) -> int: + """Get default port for the given interface.""" + if interface.startswith("http"): + return 8443 if secure else 8123 + raise ValueError("Unrecognized ClickHouse interface") + + +def _parse_connection_params( + host: Optional[str], + 
username: Optional[str], + password: str, + port: int, + database: str, + interface: Optional[str], + secure: Union[bool, str], + dsn: Optional[str], + kwargs: Dict[str, Any] +) -> Tuple[str, Optional[str], str, int, str, str]: + """Parse and normalize connection parameters including DSN parsing.""" + if dsn: + parsed = urlparse(dsn) + username = username or parsed.username + password = password or parsed.password + host = host or parsed.hostname + port = port or parsed.port + if parsed.path and (not database or database == "__default__"): + database = parsed.path[1:].split("/")[0] + database = database or parsed.path + for k, v in parse_qs(parsed.query).items(): + kwargs[k] = v[0] + use_tls = str(secure).lower() == "true" or interface == "https" or (not interface and str(port) in ("443", "8443")) + if not host: + host = "localhost" + if not interface: + interface = "https" if use_tls else "http" + port = port or default_port(interface, use_tls) + if username is None and "user" in kwargs: + username = kwargs.pop("user") + if username is None and "user_name" in kwargs: + username = kwargs.pop("user_name") + if password and username is None: + username = "default" + if "compression" in kwargs and "compress" not in kwargs: + kwargs["compress"] = kwargs.pop("compression") + + return host, username, password, port, database, interface + + +def _validate_access_token(access_token: Optional[str], username: Optional[str], password: str) -> None: + """Validate that access_token and username/password are not both provided.""" + if access_token and (username or password != ""): + raise ProgrammingError("Cannot use both access_token and username/password") + + # pylint: disable=too-many-arguments,too-many-locals,too-many-branches def create_client(*, host: Optional[str] = None, @@ -84,33 +137,11 @@ def create_client(*, limits. Only available for query operations (not inserts). 
Default: False :return: ClickHouse Connect Client instance """ - if dsn: - parsed = urlparse(dsn) - username = username or parsed.username - password = password or parsed.password - host = host or parsed.hostname - port = port or parsed.port - if parsed.path and (not database or database == '__default__'): - database = parsed.path[1:].split('/')[0] - database = database or parsed.path - for k, v in parse_qs(parsed.query).items(): - kwargs[k] = v[0] - use_tls = str(secure).lower() == 'true' or interface == 'https' or (not interface and str(port) in ('443', '8443')) - if not host: - host = 'localhost' - if not interface: - interface = 'https' if use_tls else 'http' - port = port or default_port(interface, use_tls) - if access_token and (username or password != ''): - raise ProgrammingError('Cannot use both access_token and username/password') - if username is None and 'user' in kwargs: - username = kwargs.pop('user') - if username is None and 'user_name' in kwargs: - username = kwargs.pop('user_name') - if password and username is None: - username = 'default' - if 'compression' in kwargs and 'compress' not in kwargs: - kwargs['compress'] = kwargs.pop('compression') + host, username, password, port, database, interface = _parse_connection_params( + host, username, password, port, database, interface, secure, dsn, kwargs + ) + _validate_access_token(access_token, username, password) + settings = settings or {} if interface.startswith('http'): if generic_args: @@ -130,16 +161,11 @@ def create_client(*, raise ProgrammingError(f'Unrecognized client type {interface}') -def default_port(interface: str, secure: bool): - if interface.startswith('http'): - return 8443 if secure else 8123 - raise ValueError('Unrecognized ClickHouse interface') - - async def create_async_client(*, host: Optional[str] = None, username: Optional[str] = None, password: str = '', + access_token: Optional[str] = None, database: str = '__default__', interface: Optional[str] = None, port: int = 0, @@ 
-149,6 +175,9 @@ async def create_async_client(*, generic_args: Optional[Dict[str, Any]] = None, executor_threads: int = 0, executor: Union[ThreadPoolExecutor, None, DefaultThreadPoolExecutor] = NEW_THREAD_POOL_EXECUTOR, + connector_limit: int = 100, + connector_limit_per_host: int = 20, + keepalive_timeout: float = 30.0, **kwargs) -> AsyncClient: """ The preferred method to get an async ClickHouse Connect Client instance. @@ -159,6 +188,7 @@ async def create_async_client(*, :param host: The hostname or IP address of the ClickHouse server. If not set, localhost will be used. :param username: The ClickHouse username. If not set, the default ClickHouse user will be used. :param password: The password for username. + :param access_token: JWT access token. :param database: The default database for the connection. If not set, ClickHouse Connect will use the default database for username. :param interface: Must be http or https. Defaults to http, or to https if port is set to 8443 or 443 @@ -170,11 +200,11 @@ async def create_async_client(*, :param settings: ClickHouse server settings to be used with the session/every request :param generic_args: Used internally to parse DBAPI connection strings into keyword arguments and ClickHouse settings. It is not recommended to use this parameter externally - :param executor_threads: 'max_worker' threads used by the client ThreadPoolExecutor. If not set, the default - of 4 + detected CPU cores will be used - :param executor: Optional `ThreadPoolExecutor` to use for async operations. If not set, a new `ThreadPoolExecutor` - will be created with the number of threads specified by `executor_threads`. If set to `None` it will use the - default executor of the event loop. + :param executor_threads: (LEGACY) 'max_worker' threads used by the client ThreadPoolExecutor. + :param executor: (LEGACY) Optional `ThreadPoolExecutor` to use for async operations. 
+ :param connector_limit: Maximum number of allowable connections to the server (native async) + :param connector_limit_per_host: Maximum number of connections per host (native async) + :param keepalive_timeout: Time limit on idle keepalive connections (native async) :param kwargs -- Recognized keyword arguments (used by the HTTP client), see below :param compress: Enable compression for ClickHouse HTTP inserts and query results. True will select the preferred @@ -209,15 +239,34 @@ async def create_async_client(*, :param form_encode_query_params If True, query parameters will be sent as form-encoded data in the request body instead of as URL parameters. This is useful for queries with large parameter sets that might exceed URL length limits. Only available for query operations (not inserts). Default: False - :return: ClickHouse Connect Client instance + :return: ClickHouse Connect AsyncClient instance """ + host, username, password, port, database, interface = _parse_connection_params( + host, username, password, port, database, interface, secure, dsn, kwargs + ) + _validate_access_token(access_token, username, password) + + if executor_threads != 0 or executor is not NEW_THREAD_POOL_EXECUTOR: + # LEGACY PATH: User explicitly requested executor-based client + def _create_client(): + if 'autogenerate_session_id' not in kwargs: + kwargs['autogenerate_session_id'] = False + return create_client(host=host, username=username, password=password, database=database, interface=interface, + port=port, secure=secure, dsn=None, settings=settings, generic_args=generic_args, **kwargs) + + loop = asyncio.get_running_loop() + _client = await loop.run_in_executor(None, _create_client) + return AsyncClient(client=_client, executor_threads=executor_threads, executor=executor) - def _create_client(): - if 'autogenerate_session_id' not in kwargs: - kwargs['autogenerate_session_id'] = False - return create_client(host=host, username=username, password=password, database=database, 
interface=interface, - port=port, secure=secure, dsn=dsn, settings=settings, generic_args=generic_args, **kwargs) + # NATIVE PATH: Default to true async client + # Set autogenerate_session_id to False by default + if "autogenerate_session_id" not in kwargs: + kwargs["autogenerate_session_id"] = False - loop = asyncio.get_running_loop() - _client = await loop.run_in_executor(None, _create_client) - return AsyncClient(client=_client, executor_threads=executor_threads, executor=executor) + client = AsyncClient(host=host, username=username, password=password, access_token=access_token, + database=database, interface=interface, + port=port, secure=secure, dsn=None, settings=settings, generic_args=generic_args, + connector_limit=connector_limit, connector_limit_per_host=connector_limit_per_host, + keepalive_timeout=keepalive_timeout, **kwargs) + await client._initialize() # pylint: disable=protected-access + return client diff --git a/clickhouse_connect/driver/aiohttp_client.py b/clickhouse_connect/driver/aiohttp_client.py new file mode 100644 index 00000000..b8c3a8ec --- /dev/null +++ b/clickhouse_connect/driver/aiohttp_client.py @@ -0,0 +1,1759 @@ +# pylint: disable=too-many-lines + +import asyncio +import gzip +import io +import json +import logging +import re +import ssl +import time +import uuid +import pytz +import zlib +from base64 import b64encode +from datetime import tzinfo +from importlib import import_module +from importlib.metadata import version as dist_version +from typing import Any, BinaryIO, Dict, Generator, Iterable, List, Optional, Sequence, Union + +import aiohttp +import lz4.frame +import zstandard + +from clickhouse_connect import common +from clickhouse_connect.datatypes import dynamic as dynamic_module +from clickhouse_connect.datatypes.base import ClickHouseType +from clickhouse_connect.datatypes.registry import get_from_name +from clickhouse_connect.driver import httputil, tzutil +from clickhouse_connect.driver.binding import bind_query, 
quote_identifier +from clickhouse_connect.driver.client import Client +from clickhouse_connect.driver.common import StreamContext, coerce_bool, dict_copy +from clickhouse_connect.driver.compression import available_compression +from clickhouse_connect.driver.constants import CH_VERSION_WITH_PROTOCOL, PROTOCOL_VERSION_WITH_LOW_CARD +from clickhouse_connect.driver.ctypes import RespBuffCls +from clickhouse_connect.driver.exceptions import DatabaseError, DataError, OperationalError, ProgrammingError +from clickhouse_connect.driver.external import ExternalData +from clickhouse_connect.driver.insert import InsertContext +from clickhouse_connect.driver.models import ColumnDef, SettingDef +from clickhouse_connect.driver.options import IS_PANDAS_2, arrow, check_arrow, check_numpy, check_pandas, check_polars, pd, pl +from clickhouse_connect.driver.query import QueryContext, QueryResult, arrow_buffer, to_arrow +from clickhouse_connect.driver.summary import QuerySummary +from clickhouse_connect.driver.streaming import StreamingInsertSource +from clickhouse_connect.driver.transform import NativeTransform +from clickhouse_connect.driver.streaming import StreamingResponseSource, StreamingFileAdapter + +logger = logging.getLogger(__name__) +columns_only_re = re.compile(r"LIMIT 0\s*$", re.IGNORECASE) +ex_header = "X-ClickHouse-Exception-Code" + +if "br" in available_compression: + import brotli +else: + brotli = None + +def decompress_response(data: bytes, encoding: Optional[str]) -> bytes: + """Decompress response data based on Content-Encoding header.""" + + if not encoding or encoding == "identity": + return data + + if encoding == "lz4": + lz4_decom = lz4.frame.LZ4FrameDecompressor() + return lz4_decom.decompress(data, len(data)) + if encoding == "zstd": + zstd_decom = zstandard.ZstdDecompressor() + return zstd_decom.stream_reader(io.BytesIO(data)).read() + if encoding == "br": + if brotli is not None: + return brotli.decompress(data) + raise OperationalError("Brotli 
compression requested but not installed.") + if encoding == "gzip": + return gzip.decompress(data) + if encoding == "deflate": + return zlib.decompress(data) + raise OperationalError(f"Unsupported compression type: '{encoding}'. Supported compression: {', '.join(available_compression)}") + + +class BytesSource: + """Wrapper to make bytes compatible with ResponseBuffer expectations.""" + + def __init__(self, data: bytes): + self.data = data + self.gen = self._make_generator() + + def _make_generator(self): + yield self.data + + def close(self): + """No-op close method for compatibility.""" + +# pylint: disable=invalid-overridden-method, too-many-instance-attributes, too-many-public-methods, broad-exception-caught +class AiohttpAsyncClient(Client): + valid_transport_settings = {"database", "buffer_size", "session_id", + "compress", "decompress", "session_timeout", + "session_check", "query_id", "quota_key", + "wait_end_of_query", "client_protocol_version", + "role"} + optional_transport_settings = {"send_progress_in_http_headers", + "http_headers_progress_interval_ms", + "enable_http_compression"} + + # pylint: disable=too-many-arguments, too-many-positional-arguments, too-many-locals, too-many-branches, too-many-statements + def __init__( + self, + interface: str, + host: str, + port: int, + username: Optional[str] = None, + password: Optional[str] = None, + database: Optional[str] = None, + access_token: Optional[str] = None, + compress: Union[bool, str] = True, + connect_timeout: int = 10, + send_receive_timeout: int = 300, + client_name: Optional[str] = None, + verify: Union[bool, str] = True, + ca_cert: Optional[str] = None, + client_cert: Optional[str] = None, + client_cert_key: Optional[str] = None, + http_proxy: Optional[str] = None, + https_proxy: Optional[str] = None, + server_host_name: Optional[str] = None, + tls_mode: Optional[str] = None, + proxy_path: str = "", + connector_limit: int = 100, + connector_limit_per_host: int = 20, + keepalive_timeout: 
float = 30.0, + session_id: Optional[str] = None, + settings: Optional[Dict[str, Any]] = None, + query_limit: int = 0, + query_retries: int = 2, + apply_server_timezone: Optional[Union[str, bool]] = None, + utc_tz_aware: Optional[bool] = None, + show_clickhouse_errors: Optional[bool] = None, + autogenerate_session_id: Optional[bool] = None, + autogenerate_query_id: Optional[bool] = None, + form_encode_query_params: bool = False, + **kwargs, + ): + """ + Async HTTP Client using aiohttp. Initialization is handled via _initialize(). + """ + proxy_path = proxy_path.lstrip("/") + if proxy_path: + proxy_path = "/" + proxy_path + self.uri = f"{interface}://{host}:{port}{proxy_path}" + self.url = self.uri + self.form_encode_query_params = form_encode_query_params + self._rename_response_column = kwargs.get("rename_response_column") + self._initial_settings = settings + self.headers = {} + + if interface == "https": + if isinstance(verify, str) and verify.lower() == "proxy": + verify = True + tls_mode = tls_mode or "proxy" + + # Priority: access_token > mutual TLS > basic auth + if client_cert and (tls_mode is None or tls_mode == "mutual"): + if not username: + raise ProgrammingError("username parameter is required for Mutual TLS authentication") + self.headers["X-ClickHouse-User"] = username + self.headers["X-ClickHouse-SSL-Certificate-Auth"] = "on" + elif access_token: + self.headers["Authorization"] = f"Bearer {access_token}" + elif username and (not client_cert or tls_mode in ("strict", "proxy")): + credentials = b64encode(f"{username}:{password}".encode()).decode() + self.headers["Authorization"] = f"Basic {credentials}" + + self.headers["User-Agent"] = common.build_client_name(client_name) + # Prevent aiohttp from automatically requesting compressed responses + # We'll manually set Accept-Encoding when compression is desired + self.headers["Accept-Encoding"] = "identity" + self._send_receive_timeout = send_receive_timeout + + connect_timeout_val = 
float(connect_timeout) if connect_timeout is not None else None + send_receive_timeout_val = float(send_receive_timeout) if send_receive_timeout is not None else None + + self._timeout = aiohttp.ClientTimeout( + total=None, + connect=connect_timeout_val, + sock_connect=connect_timeout_val, + sock_read=send_receive_timeout_val, + ) + connector_limit_per_host = min(connector_limit_per_host, connector_limit) + + proxy_url = None + if http_proxy: + if not http_proxy.startswith("http://") and not http_proxy.startswith("https://"): + proxy_url = f"http://{http_proxy}" + else: + proxy_url = http_proxy + elif https_proxy: + if not https_proxy.startswith("http://") and not https_proxy.startswith("https://"): + proxy_url = f"http://{https_proxy}" + else: + proxy_url = https_proxy + else: + scheme = "https" if self.url.startswith("https://") else "http" + env_proxy = httputil.check_env_proxy(scheme, host, port) + if env_proxy: + if not env_proxy.startswith("http://") and not env_proxy.startswith("https://"): + proxy_url = f"http://{env_proxy}" + else: + proxy_url = env_proxy + + ssl_context = None + if interface == "https": + ssl_context = ssl.create_default_context() + ssl_verify = verify if isinstance(verify, bool) else coerce_bool(verify) + if not ssl_verify: + ssl_context.check_hostname = False + ssl_context.verify_mode = ssl.CERT_NONE + elif ca_cert: + ssl_context.load_verify_locations(ca_cert) + if client_cert: + ssl_context.load_cert_chain(client_cert, client_cert_key) + + self._ssl_context = ssl_context + self._proxy_url = proxy_url + self._connector_kwargs = { + "limit": connector_limit, + "limit_per_host": connector_limit_per_host, + "keepalive_timeout": keepalive_timeout, + "enable_cleanup_closed": True, + "force_close": False, + "ssl": ssl_context, + } + + self._session = None + self._read_format = "Native" + self._write_format = "Native" + self._transform = NativeTransform() + self._client_settings = {} + self._initialized = False + self._reported_libs = set() + 
self._last_pool_reset = None + self.headers["User-Agent"] = self.headers["User-Agent"].replace("mode:sync;", "mode:async;") + + # Store aiohttp-specific params for deferred initialization + self._compress_param = compress + self._session_id_param = session_id + self._autogenerate_session_id_param = autogenerate_session_id + self._autogenerate_query_id = ( + common.get_setting("autogenerate_query_id") if autogenerate_query_id is None else autogenerate_query_id + ) + self._active_session = None + self._send_progress = None + self._progress_interval = None + + # Call parent init with autoconnect=False to set up config without blocking I/O + super().__init__( + database=database, + query_limit=query_limit, + uri=self.uri, + query_retries=query_retries, + server_host_name=server_host_name, + apply_server_timezone=apply_server_timezone, + utc_tz_aware=utc_tz_aware, + show_clickhouse_errors=show_clickhouse_errors, + autoconnect=False + ) + + async def _initialize(self, apply_server_timezone: Optional[Union[str, bool]] = None): + """ + Async equivalent of Client._init_common_settings. + Fetches server version, timezone, and settings. 
+ """ + if not self._session: + connector = aiohttp.TCPConnector(**self._connector_kwargs) + self._session = aiohttp.ClientSession( + connector=connector, + timeout=self._timeout, + headers=self.headers, + trust_env=False, + auto_decompress=False, + skip_auto_headers={"Accept-Encoding"}, + ) + + if self._initialized: + return + + try: + if apply_server_timezone is None: + apply_server_timezone = self._deferred_apply_server_timezone + + self.server_tz, dst_safe = pytz.UTC, True + row = await self.command("SELECT version(), timezone()", use_database=False) + self.server_version, server_tz_str = tuple(row) + try: + server_tz = pytz.timezone(server_tz_str) + server_tz, dst_safe = tzutil.normalize_timezone(server_tz) + if apply_server_timezone is None: + apply_server_timezone = dst_safe + self.apply_server_timezone = apply_server_timezone == "always" or coerce_bool(apply_server_timezone) + self.server_tz = server_tz + except pytz.exceptions.UnknownTimeZoneError: + logger.warning("Warning, server is using an unrecognized timezone %s, will use UTC default", server_tz_str) + + if not self.apply_server_timezone and not tzutil.local_tz_dst_safe: + logger.warning("local timezone %s may return unexpected times due to Daylight Savings Time", tzutil.local_tz.tzname(None)) + + readonly = "readonly" + if not self.min_version("19.17"): + readonly = common.get_setting("readonly") + + server_settings = await self.query(f"SELECT name, value, {readonly} as readonly FROM system.settings LIMIT 10000") + self.server_settings = {row["name"]: SettingDef(**row) for row in server_settings.named_results()} + + if self.min_version(CH_VERSION_WITH_PROTOCOL) and common.get_setting("use_protocol_version"): + try: + test_data = await self.raw_query( + "SELECT 1 AS check", fmt="Native", settings={"client_protocol_version": PROTOCOL_VERSION_WITH_LOW_CARD} + ) + if test_data[8:16] == b"\x01\x01\x05check": + self.protocol_version = PROTOCOL_VERSION_WITH_LOW_CARD + except Exception: + pass + + if 
self._initial_settings: + for key, value in self._initial_settings.items(): + self.set_client_setting(key, value) + + compress = self._compress_param + if coerce_bool(compress): + compression = ",".join(available_compression) + self.write_compression = available_compression[0] + elif compress and compress not in ("False", "false", "0"): + if compress not in available_compression: + raise ProgrammingError(f"Unsupported compression method {compress}") + compression = compress + self.write_compression = compress + else: + compression = None + + comp_setting = self._setting_status("enable_http_compression") + self._send_comp_setting = not comp_setting.is_set and comp_setting.is_writable # pylint: disable=attribute-defined-outside-init + if comp_setting.is_set or comp_setting.is_writable: + self.compression = compression + + session_id = self._session_id_param + autogenerate_session_id = self._autogenerate_session_id_param + + if autogenerate_session_id is None: + autogenerate_session_id = common.get_setting("autogenerate_session_id") + + if session_id: + self.set_client_setting("session_id", session_id) + elif self.get_client_setting("session_id"): + pass + elif autogenerate_session_id: + self.set_client_setting("session_id", str(uuid.uuid4())) + + send_setting = self._setting_status("send_progress_in_http_headers") + self._send_progress = not send_setting.is_set and send_setting.is_writable + if (send_setting.is_set or send_setting.is_writable) and self._setting_status("http_headers_progress_interval_ms").is_writable: + self._progress_interval = str(min(120000, max(10000, (self._send_receive_timeout - 5) * 1000))) + + if self._setting_status("date_time_input_format").is_writable: + self.set_client_setting("date_time_input_format", "best_effort") + if ( + self._setting_status("allow_experimental_json_type").is_set + and self._setting_status("cast_string_to_dynamic_use_inference").is_writable + ): + self.set_client_setting("cast_string_to_dynamic_use_inference", "1") + 
if self.min_version("24.8") and not self.min_version("24.10"): + dynamic_module.json_serialization_format = 0 + + self._initialized = True + except Exception: + if self._session and not self._session.closed: + await self._session.close() + self._session = None + raise + + async def __aenter__(self): + """Async context manager entry.""" + if not self._initialized: + await self._initialize() + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + """Async context manager exit.""" + await self.close() + return False + + async def close(self): # type: ignore[override] + if self._session: + await self._session.close() + + async def close_connections(self): # type: ignore[override] + """Close all pooled connections and recreate session""" + if self._session: + await self._session.close() + connector = aiohttp.TCPConnector(**self._connector_kwargs) + self._session = aiohttp.ClientSession( + connector=connector, + timeout=self._timeout, + headers=self.headers, + trust_env=False, + auto_decompress=False, + skip_auto_headers={"Accept-Encoding"}, + ) + + def set_client_setting(self, key, value): + str_value = self._validate_setting(key, value, common.get_setting("invalid_setting_action")) + if str_value is not None: + self._client_settings[key] = str_value + + def get_client_setting(self, key) -> Optional[str]: + return self._client_settings.get(key) + + def set_access_token(self, access_token: str): + auth_header = self.headers.get("Authorization") + if auth_header and not auth_header.startswith("Bearer"): + raise ProgrammingError("Cannot set access token when a different auth type is used") + self.headers["Authorization"] = f"Bearer {access_token}" + if self._session: + self._session.headers["Authorization"] = f"Bearer {access_token}" + + def _prep_query(self, context: QueryContext): + final_query = super()._prep_query(context) + if context.is_insert: + return final_query + fmt = f"\n FORMAT {self._read_format}" + if isinstance(final_query, bytes): + 
async def _query_with_context(self, context: QueryContext) -> QueryResult:  # type: ignore[override]
    """
    Execute the query described by ``context`` and return the parsed result.

    Two paths exist:

    * If the query is not an insert and ``columns_only_re`` matches the
      uncommented query, the query is sent with ``FORMAT JSON`` and only the
      ``meta`` section of the response is used to build an empty QueryResult
      carrying column names and types.
    * Otherwise the prepared query is sent with ``stream=True`` and the response
      is parsed in an executor thread through a StreamingResponseSource.

    :param context: Query context; ``block_info``, ``rename_response_column`` and
        the response timezone are set on it as a side effect
    :return: QueryResult whose ``source`` is the streaming source (streaming path)
    """
    headers = {}
    params = {}
    if self.database:
        params["database"] = self.database
    if self.protocol_version:
        params["client_protocol_version"] = self.protocol_version
        context.block_info = True
    params.update(self._validate_settings(context.settings))
    context.rename_response_column = self._rename_response_column

    def form_fields(query_text):
        """Build the multipart fields for a form-encoded request (query, bind params, external data)."""
        parts = {}
        if context.external_data:
            params.update(context.external_data.query_params)
            parts.update(context.external_data.form_data)
        fields = {"query": query_text}
        fields.update(context.bind_params)
        for key, value in fields.items():
            # str() on bytes yields the repr ("b'...'"), which is not valid SQL,
            # so binary values are decoded instead.
            if isinstance(value, (bytes, bytearray)):
                parts[key] = (None, value.decode())
            else:
                parts[key] = (None, str(value))
        return parts

    if not context.is_insert and columns_only_re.search(context.uncommented_query):
        fmt_json_query = f"{context.final_query}\n FORMAT JSON"

        if self.form_encode_query_params:
            files = form_fields(fmt_json_query)
            response = await self._raw_request(None, params, headers, files=files, retries=self.query_retries)
        elif context.external_data:
            params.update(context.bind_params)
            params.update(context.external_data.query_params)
            params["query"] = fmt_json_query
            response = await self._raw_request(
                None, params, headers, files=context.external_data.form_data, retries=self.query_retries
            )
        else:
            params.update(context.bind_params)
            response = await self._raw_request(fmt_json_query, params, headers, retries=self.query_retries)

        body = await response.read()
        encoding = response.headers.get("Content-Encoding")
        loop = asyncio.get_running_loop()

        def decompress_and_parse_json():
            if encoding:
                decompressed_body = decompress_response(body, encoding)
            else:
                decompressed_body = body
            return json.loads(decompressed_body)

        # Offload decompression and JSON parsing to the executor
        json_result = await loop.run_in_executor(None, decompress_and_parse_json)

        names: List[str] = []
        types: List[ClickHouseType] = []
        renamer = context.column_renamer
        for col in json_result["meta"]:
            name = col["name"]
            if renamer is not None:
                try:
                    name = renamer(name)
                except Exception as e:
                    # Best effort: a failing renamer leaves the server-provided name in place
                    logger.debug("Failed to rename col '%s'. Skipping rename. Error: %s", name, e)
            names.append(name)
            types.append(get_from_name(col["type"]))
        return QueryResult([], None, tuple(names), tuple(types))

    if self.compression:
        headers["Accept-Encoding"] = self.compression
        if self._send_comp_setting:
            params["enable_http_compression"] = "1"

    final_query = self._prep_query(context)

    files = None
    data = None

    if self.form_encode_query_params:
        files = form_fields(final_query)
    elif context.external_data:
        params.update(context.bind_params)
        params.update(context.external_data.query_params)
        params["query"] = final_query
        files = context.external_data.form_data
    else:
        params.update(context.bind_params)
        data = final_query
        headers["Content-Type"] = "text/plain; charset=utf-8"

    headers = dict_copy(headers, context.transport_settings)

    response = await self._raw_request(
        data,
        params,
        headers,
        files=files,
        server_wait=not context.streaming,
        stream=True,
        retries=self.query_retries,
    )
    encoding = response.headers.get("Content-Encoding")
    tz_header = response.headers.get("X-ClickHouse-Timezone")

    loop = asyncio.get_running_loop()
    streaming_source = StreamingResponseSource(response, encoding=encoding)
    await streaming_source.start_producer(loop)

    def parse_streaming():
        """Parse response from streaming queue (runs in executor)."""
        # Wrap streaming source with ResponseBuffer. The streaming source provides a
        # .gen property that yields decompressed chunks.
        byte_source = RespBuffCls(streaming_source)
        context.set_response_tz(self._check_tz_change(tz_header))
        result = self._transform.parse_response(byte_source, context)

        # CRITICAL: For non-streaming queries, force full materialization while still in executor thread.
        # This prevents the event loop from ever calling blocking queue.sync_q.get() operations
        # which would deadlock the entire event loop when backpressure occurs
        if not context.streaming:
            if context.as_pandas and hasattr(result, 'df_result'):
                _ = result.df_result
            elif context.use_numpy and hasattr(result, 'np_result'):
                _ = result.np_result
            elif hasattr(result, 'result_set'):
                # Materialize rows (closes the stream)
                # Avoid pre-populating result_columns. User can access later if needed
                _ = result.result_set

        return result

    # Run parser in executor (pulls from queue, decompresses & parses)
    query_result = await loop.run_in_executor(None, parse_streaming)
    query_result.summary = self._summary(response)

    # Attach streaming_source to query_result.source to ensure it gets closed
    # when the query result is closed (e.g. by StreamContext.__exit__)
    query_result.source = streaming_source

    return query_result
async def query_column_block_stream(  # type: ignore[override]
    self,
    query: Optional[str] = None,
    parameters: Optional[Union[Sequence, Dict[str, Any]]] = None,
    settings: Optional[Dict[str, Any]] = None,
    query_formats: Optional[Dict[str, str]] = None,
    column_formats: Optional[Dict[str, Union[str, Dict[str, str]]]] = None,
    encoding: Optional[str] = None,
    use_none: Optional[bool] = None,
    context: Optional[QueryContext] = None,
    query_tz: Optional[Union[str, tzinfo]] = None,
    column_tzs: Optional[Dict[str, Union[str, tzinfo]]] = None,
    utc_tz_aware: Optional[bool] = None,
    external_data: Optional[ExternalData] = None,
    transport_settings: Optional[Dict[str, str]] = None,
) -> StreamContext:
    """
    Run a streaming query and expose it as column-oriented blocks.

    All arguments are forwarded (through ``locals()``) to ``_context_query`` with
    ``use_numpy=False`` and ``streaming=True``.
    :return: the ``column_block_stream`` attribute of the query result
    """
    # locals() must be evaluated before any new local is bound, so that only
    # ``self`` and the declared parameters are forwarded.
    result = await self._context_query(locals(), use_numpy=False, streaming=True)
    return result.column_block_stream
async def query_np(
    self,
    query: Optional[str] = None,
    parameters: Optional[Union[Sequence, Dict[str, Any]]] = None,
    settings: Optional[Dict[str, Any]] = None,
    query_formats: Optional[Dict[str, str]] = None,
    column_formats: Optional[Dict[str, str]] = None,
    encoding: Optional[str] = None,
    use_none: Optional[bool] = None,
    max_str_len: Optional[int] = None,
    context: Optional[QueryContext] = None,
    external_data: Optional[ExternalData] = None,
    transport_settings: Optional[Dict[str, str]] = None,
):
    """
    Run a query with ``use_numpy=True`` and return the result's ``np_result``.

    ``check_numpy()`` is called first and the "numpy" integration tag is
    registered. All arguments are forwarded to ``_context_query`` via ``locals()``.
    """
    check_numpy()
    self._add_integration_tag("numpy")
    # locals() is captured before ``result`` exists, so only ``self`` and the
    # declared parameters are forwarded.
    result = await self._context_query(locals(), use_numpy=True)
    return result.np_result
Union[str, tzinfo]]] = None, + utc_tz_aware: Optional[bool] = None, + context: Optional[QueryContext] = None, + external_data: Optional[ExternalData] = None, + use_extended_dtypes: Optional[bool] = None, + transport_settings: Optional[Dict[str, str]] = None, + ): + check_pandas() + self._add_integration_tag("pandas") + return (await self._context_query(locals(), use_numpy=True, as_pandas=True)).df_result + + async def query_df_stream( # type: ignore[override] + self, + query: Optional[str] = None, + parameters: Optional[Union[Sequence, Dict[str, Any]]] = None, + settings: Optional[Dict[str, Any]] = None, + query_formats: Optional[Dict[str, str]] = None, + column_formats: Optional[Dict[str, str]] = None, + encoding: Optional[str] = None, + use_none: Optional[bool] = None, + max_str_len: Optional[int] = None, + use_na_values: Optional[bool] = None, + query_tz: Optional[str] = None, + column_tzs: Optional[Dict[str, Union[str, tzinfo]]] = None, + utc_tz_aware: Optional[bool] = None, + context: Optional[QueryContext] = None, + external_data: Optional[ExternalData] = None, + use_extended_dtypes: Optional[bool] = None, + transport_settings: Optional[Dict[str, str]] = None, + ) -> StreamContext: + check_pandas() + self._add_integration_tag("pandas") + return (await self._context_query(locals(), use_numpy=True, as_pandas=True, streaming=True)).df_stream + + async def _context_query(self, lcls: dict, **overrides): # type: ignore[override] + """ + Helper method to create query context and execute query. + Matches sync client pattern for consistency. 
+ """ + kwargs = lcls.copy() + kwargs.pop("self") + kwargs.update(overrides) + return await self._query_with_context(self.create_query_context(**kwargs)) + + async def command( # type: ignore[override] + self, + cmd, + parameters: Optional[Union[Sequence, Dict[str, Any]]] = None, + data: Optional[Union[str, bytes]] = None, + settings: Optional[Dict] = None, + use_database: bool = True, + external_data: Optional[ExternalData] = None, + transport_settings: Optional[Dict[str, str]] = None, + ) -> Union[str, int, Sequence[str], QuerySummary]: + """ + See BaseClient doc_string for this method + """ + cmd, bind_params = bind_query(cmd, parameters, self.server_tz) + params = bind_params.copy() + headers = {} + payload = None + files = None + + if external_data: + if data: + raise ProgrammingError("Cannot combine command data with external data") from None + files = external_data.form_data + params.update(external_data.query_params) + elif isinstance(data, str): + headers["Content-Type"] = "text/plain; charset=utf-8" + payload = data.encode() + elif isinstance(data, bytes): + headers["Content-Type"] = "application/octet-stream" + payload = data + + if payload is None and not cmd: + raise ProgrammingError("Command sent without query or recognized data") from None + + if payload or files: + params["query"] = cmd + else: + payload = cmd + + if use_database and self.database: + params["database"] = self.database + params.update(self._validate_settings(settings or {})) + headers = dict_copy(headers, transport_settings) + method = "POST" if payload or files else "GET" + response = await self._raw_request(payload, params, headers, files=files, method=method, server_wait=False) + body = await response.read() + encoding = response.headers.get("Content-Encoding") + summary = self._summary(response) + + if not body: + return QuerySummary(summary) + + loop = asyncio.get_running_loop() + + def decompress_and_decode(): + if encoding: + decompressed_body = decompress_response(body, 
async def raw_stream(  # type: ignore[override]
    self,
    query: str,
    parameters: Optional[Union[Sequence, Dict[str, Any]]] = None,
    settings: Optional[Dict[str, Any]] = None,
    fmt: Optional[str] = None,
    use_database: bool = True,
    external_data: Optional[ExternalData] = None,
    transport_settings: Optional[Dict[str, str]] = None,
) -> StreamContext:
    """
    Send a raw query and return a StreamContext over the response body chunks.

    The request is prepared by ``_prep_raw_query`` and issued with
    ``stream=True`` and ``server_wait=False``. Chunks are yielded as they arrive
    from ``response.content.iter_any()``.
    """
    prepared = self._prep_raw_query(query, parameters, settings, fmt, use_database, external_data)
    body, params, headers, files = prepared
    if transport_settings:
        headers = dict_copy(headers, transport_settings)

    response = await self._raw_request(
        body,
        params,
        headers=headers,
        files=files,
        stream=True,
        server_wait=False,
        retries=self.query_retries,
    )

    async def chunk_stream():
        # The content iterator is created lazily, on first iteration.
        async for piece in response.content.iter_any():
            yield piece

    return StreamContext(response, chunk_stream())
async def insert(  # type: ignore[override]
    self,
    table: Optional[str] = None,
    data: Optional[Sequence[Sequence[Any]]] = None,
    column_names: Union[str, Iterable[str]] = "*",
    database: Optional[str] = None,
    column_types: Optional[Sequence[ClickHouseType]] = None,
    column_type_names: Optional[Sequence[str]] = None,
    column_oriented: bool = False,
    settings: Optional[Dict[str, Any]] = None,
    context: Optional[InsertContext] = None,
    transport_settings: Optional[Dict[str, str]] = None,
) -> QuerySummary:
    """
    Insert a matrix of native Python objects. When ``context`` is supplied, every
    argument except ``data`` is ignored.

    :param table: Target table
    :param data: Sequence of sequences of Python data
    :param column_names: Ordered list of column names or '*' if column types should be retrieved from the
        ClickHouse table definition
    :param database: Target database -- will use client default database if not specified.
    :param column_types: ClickHouse column types. If set then column data does not need to be retrieved from
        the server
    :param column_type_names: ClickHouse column type names. If set then column data does not need to be
        retrieved from the server
    :param column_oriented: If true the data is already "pivoted" in column form
    :param settings: Optional dictionary of ClickHouse settings (key/string values)
    :param context: Optional reusable insert context to allow repeated inserts into the same table with
        different data batches
    :param transport_settings: Optional dictionary of transport level settings (HTTP headers, etc.)
    :return: QuerySummary with summary information, throws exception if insert fails
    :raises ProgrammingError: if there is nothing to insert, or if new data is supplied together with a
        context that already holds data
    """
    nothing_pending = context is None or context.empty
    if nothing_pending and data is None:
        raise ProgrammingError("No data specified for insert") from None

    if context is None:
        context = await self.create_insert_context(
            table,
            column_names,
            database,
            column_types,
            column_type_names,
            column_oriented,
            settings,
            transport_settings=transport_settings,
        )

    if data is None:
        # The supplied context already carries the rows to send.
        return await self.data_insert(context)

    if not context.empty:
        raise ProgrammingError("Attempting to insert new data with non-empty insert context") from None
    context.data = data
    return await self.data_insert(context)
async def query_arrow_stream(  # type: ignore[override]
    self,
    query: str,
    parameters: Optional[Union[Sequence, Dict[str, Any]]] = None,
    settings: Optional[Dict[str, Any]] = None,
    use_strings: Optional[bool] = None,
    external_data: Optional[ExternalData] = None,
    transport_settings: Optional[Dict[str, str]] = None,
) -> StreamContext:
    """
    Run a query in ClickHouse ``ArrowStream`` format and return its record batches as a stream.

    The response is read and decoded in an executor thread. All record batches
    are collected there before the returned StreamContext yields the first one.

    :param query: Query statement/format string
    :param parameters: Optional dictionary used to format the query
    :param settings: Optional dictionary of ClickHouse settings (key/string values)
    :param use_strings: Convert ClickHouse String type to Arrow string type (instead of binary)
    :param external_data: ClickHouse "external data" to send with query
    :param transport_settings: Optional dictionary of transport level settings (HTTP headers, etc.)
    :return: StreamContext that yields PyArrow RecordBatch objects asynchronously
    """
    check_arrow()
    self._add_integration_tag("arrow")
    arrow_settings = self._update_arrow_settings(settings, use_strings)

    body, params, headers, files = self._prep_raw_query(
        query,
        parameters,
        arrow_settings,
        fmt="ArrowStream",
        use_database=True,
        external_data=external_data,
    )
    if transport_settings:
        headers = dict_copy(headers, transport_settings)

    response = await self._raw_request(
        body,
        params,
        headers=headers,
        files=files,
        stream=True,
        server_wait=False,
        retries=self.query_retries,
    )
    content_encoding = response.headers.get("Content-Encoding")

    running_loop = asyncio.get_running_loop()
    source = StreamingResponseSource(response, encoding=content_encoding)
    await source.start_producer(running_loop)

    def read_all_batches():
        """Decode the Arrow IPC stream in the executor, off the event loop."""
        # PyArrow needs a file-like object, so the streaming source is wrapped in an adapter.
        adapter = StreamingFileAdapter(source)
        return list(arrow.ipc.open_stream(adapter))

    record_batches = await running_loop.run_in_executor(None, read_all_batches)

    async def emit_batches():
        """Yield the already-decoded record batches one by one."""
        for record_batch in record_batches:
            yield record_batch

    return StreamContext(None, emit_batches())
async def query_df_arrow_stream(  # type: ignore[override]
    self,
    query: str,
    parameters: Optional[Union[Sequence, Dict[str, Any]]] = None,
    settings: Optional[Dict[str, Any]] = None,
    use_strings: Optional[bool] = None,
    external_data: Optional[ExternalData] = None,
    transport_settings: Optional[Dict[str, str]] = None,
    dataframe_library: str = "pandas",
) -> StreamContext:
    """
    Query method that returns the results as a stream of DataFrames with PyArrow dtype backend.
    Each DataFrame represents a record batch from the ClickHouse response.

    The response is parsed and converted in an executor thread. All DataFrames are
    built there before the returned StreamContext yields the first one.

    :param query: Query statement/format string
    :param parameters: Optional dictionary used to format the query
    :param settings: Optional dictionary of ClickHouse settings (key/string values)
    :param use_strings: Convert ClickHouse String type to Arrow string type (instead of binary)
    :param external_data: ClickHouse "external data" to send with query
    :param transport_settings: Optional dictionary of transport level settings (HTTP headers, etc.)
    :param dataframe_library: Library to use for DataFrame creation ("pandas" or "polars")
    :return: StreamContext that yields DataFrames asynchronously (pandas or polars based on dataframe_library parameter)
    :raises ProgrammingError: if pandas is requested but pandas 2.x is not in use
    :raises ValueError: if ``dataframe_library`` is neither "pandas" nor "polars"
    """
    check_arrow()
    if dataframe_library == "pandas":
        check_pandas()
        self._add_integration_tag("pandas")
        if not IS_PANDAS_2:
            raise ProgrammingError("PyArrow-backed dtypes are only supported when using pandas 2.x.")

        def converter(batch: "arrow.RecordBatch") -> "pd.DataFrame":
            return batch.to_pandas(types_mapper=pd.ArrowDtype, safe=False)

    elif dataframe_library == "polars":
        check_polars()
        self._add_integration_tag("polars")

        def converter(batch: "arrow.RecordBatch") -> "pl.DataFrame":
            return pl.from_arrow(batch)

    else:
        raise ValueError(f"dataframe_library must be 'pandas' or 'polars', got '{dataframe_library}'")

    # Register the Arrow integration as the other Arrow query methods do;
    # this method builds its own request, so nothing else adds the tag.
    self._add_integration_tag("arrow")
    settings = self._update_arrow_settings(settings, use_strings)

    body, params, headers, files = self._prep_raw_query(
        query, parameters, settings, fmt="ArrowStream",
        use_database=True, external_data=external_data
    )
    if transport_settings:
        headers = dict_copy(headers, transport_settings)

    response = await self._raw_request(
        body, params, headers=headers, files=files,
        stream=True, server_wait=False, retries=self.query_retries
    )
    encoding = response.headers.get("Content-Encoding")

    loop = asyncio.get_running_loop()
    streaming_source = StreamingResponseSource(response, encoding=encoding)
    await streaming_source.start_producer(loop)

    def parse_and_convert_streaming():
        """Parse Arrow stream and convert to DataFrames in executor (off event loop)."""
        file_adapter = StreamingFileAdapter(streaming_source)

        # PyArrow reads incrementally from adapter (which pulls from queue)
        reader = arrow.ipc.open_stream(file_adapter)
        return [converter(batch) for batch in reader]

    dataframes = await loop.run_in_executor(None, parse_and_convert_streaming)

    async def df_generator():
        """Async generator that yields DataFrames without blocking event loop."""
        for df in dataframes:
            yield df

    return StreamContext(None, df_generator())
in table or not database else f"{database}.{table}" + compression = self.write_compression if self.write_compression in ("zstd", "lz4") else None + column_names, insert_block = arrow_buffer(arrow_table, compression) + if hasattr(insert_block, "to_pybytes"): + insert_block = insert_block.to_pybytes() + return await self.raw_insert(full_table, column_names, insert_block, settings, "Arrow", transport_settings) + + async def insert_df_arrow( # type: ignore[override] + self, + table: str, + df: Union["pd.DataFrame", "pl.DataFrame"], + database: Optional[str] = None, + settings: Optional[Dict] = None, + transport_settings: Optional[Dict[str, str]] = None, + ) -> QuerySummary: + """ + Insert a pandas DataFrame with PyArrow backend or a polars DataFrame into ClickHouse using Arrow format. + This method is optimized for DataFrames that already use Arrow format, providing + better performance than the standard insert_df method. + + Validation is performed and an exception will be raised if this requirement is not met. + Polars DataFrames are natively Arrow-based and don't require additional validation. + + :param table: ClickHouse table name + :param df: Pandas DataFrame with PyArrow dtype backend or Polars DataFrame + :param database: Optional ClickHouse database name + :param settings: Optional dictionary of ClickHouse settings (key/string values) + :param transport_settings: Optional dictionary of transport level settings (HTTP headers, etc.) 
async def create_insert_context(  # type: ignore[override]
    self,
    table: str,
    column_names: Optional[Union[str, Sequence[str]]] = None,
    database: Optional[str] = None,
    column_types: Optional[Sequence[ClickHouseType]] = None,
    column_type_names: Optional[Sequence[str]] = None,
    column_oriented: bool = False,
    settings: Optional[Dict[str, Any]] = None,
    data: Optional[Sequence[Sequence[Any]]] = None,
    transport_settings: Optional[Dict[str, str]] = None,
) -> InsertContext:
    """
    Build a reusable insert context that holds state for the duration of an insert.

    :param table: Target table. A name containing "." is used verbatim; otherwise it is quoted
      and, when a database is given, prefixed with the quoted database name
    :param column_names: Optional ordered list of column names. If not set (or "*"), the columns
      returned by DESCRIBE TABLE are used in that order
    :param database: Target database
    :param column_types: ClickHouse column types. Optional Sequence of ClickHouseType objects. If neither
      column types nor column type names are set, the column types are retrieved from the server
    :param column_type_names: ClickHouse column type names. Specifies column types by name string
    :param column_oriented: If true the data is already "pivoted" in column form
    :param settings: Optional dictionary of ClickHouse settings (key/string values)
    :param data: Initial dataset for insert
    :param transport_settings: Optional dictionary of transport level settings (HTTP headers, etc.)
    :return: Reusable insert context
    :raises ValueError: if no column names can be determined
    :raises ProgrammingError: if a column name is not found in the table description, or the number
      of names and types differ
    """
    if "." in table:
        full_table = table
    elif database:
        full_table = f"{quote_identifier(database)}.{quote_identifier(table)}"
    else:
        full_table = quote_identifier(table)

    # The table is only described when the caller supplied no type information at all.
    table_columns = []
    if column_types is None and column_type_names is None:
        described = await self.query(f"DESCRIBE TABLE {full_table}", settings=settings)
        for row in described.named_results():
            if row["default_type"] in ("ALIAS", "MATERIALIZED"):
                continue
            table_columns.append(ColumnDef(**row))

    if isinstance(column_names, str) and column_names != "*":
        column_names = [column_names]
    elif column_names is None or isinstance(column_names, str):
        # None or "*": take every described column, names and types together.
        column_names = [col.name for col in table_columns]
        column_types = [col.ch_type for col in table_columns]

    # len() rather than truthiness: column_names may be an array-like such as a pandas Index.
    if len(column_names) == 0:
        raise ValueError("Column names must be specified for insert")

    if not column_types:
        if column_type_names:
            column_types = [get_from_name(type_name) for type_name in column_type_names]
        else:
            by_name = {col.name: col for col in table_columns}
            try:
                column_types = [by_name[name].ch_type for name in column_names]
            except KeyError as ex:
                raise ProgrammingError(f"Unrecognized column {ex} in table {table}") from None

    if len(column_names) != len(column_types):
        raise ProgrammingError("Column names do not match column types") from None

    return InsertContext(
        full_table,
        column_names,
        column_types,
        column_oriented=column_oriented,
        settings=settings,
        transport_settings=transport_settings,
        data=data,
    )
+ + Uses true streaming via reverse bridge pattern: + - Sync producer (serializer) runs in executor, puts blocks in queue + - Async consumer (network) pulls from queue and yields to aiohttp + - Bounded queue provides backpressure to prevent memory bloat + """ + if context.empty: + logger.debug("No data included in insert, skipping") + return QuerySummary() + + if context.compression is None: + context.compression = self.write_compression + + loop = asyncio.get_running_loop() + + streaming_source = StreamingInsertSource( + transform=self._transform, context=context, loop=loop, maxsize=10 + ) + + streaming_source.start_producer() + + headers = {"Content-Type": "application/octet-stream"} + if context.compression: + headers["Content-Encoding"] = context.compression + + params = {} + if self.database: + params["database"] = self.database + params.update(self._validate_settings(context.settings)) + headers = dict_copy(headers, context.transport_settings) + + try: + response = await self._raw_request( + streaming_source.async_generator(), params, headers=headers, server_wait=False + ) + logger.debug("Context insert response code: %d", response.status) + except Exception: + await streaming_source.close() + + if context.insert_exception: + ex = context.insert_exception + context.insert_exception = None + raise ex from None + raise + finally: + await streaming_source.close() + + context.data = None + return QuerySummary(self._summary(response)) + + async def insert_df( # type: ignore[override] + self, + table: Optional[str] = None, + df=None, + database: Optional[str] = None, + settings: Optional[Dict] = None, + column_names: Optional[Sequence[str]] = None, + column_types: Optional[Sequence[ClickHouseType]] = None, + column_type_names: Optional[Sequence[str]] = None, + context: Optional[InsertContext] = None, + transport_settings: Optional[Dict[str, str]] = None, + ) -> QuerySummary: + """ + Insert a pandas DataFrame into ClickHouse. 
async def raw_insert(  # type: ignore[override]
    self,
    table: Optional[str] = None,
    column_names: Optional[Sequence[str]] = None,
    insert_block: Optional[Union[str, bytes, Generator[bytes, None, None], BinaryIO]] = None,
    settings: Optional[Dict] = None,
    fmt: Optional[str] = None,
    compression: Optional[str] = None,
    transport_settings: Optional[Dict[str, str]] = None,
) -> QuerySummary:
    """
    Send an already-formatted insert payload to the server.

    When ``table`` is given, an ``INSERT INTO <table>[ (cols)] FORMAT <fmt>`` statement is built.
    For an uncompressed ``str``/``bytes``/``bytearray`` payload the statement is prepended to the
    body followed by a newline; for every other payload (compressed data, generators, file-like
    objects, ``None``) it is sent as the ``query`` URL parameter instead.

    :param table: Target table name, inserted into the statement as given
    :param column_names: Optional column names; each is passed through quote_identifier
    :param insert_block: Payload to send as the request body
    :param settings: Optional dictionary of ClickHouse settings (key/string values)
    :param fmt: ClickHouse input format name; defaults to the client's ``_write_format``
    :param compression: If set, sent as the ``Content-Encoding`` request header
    :param transport_settings: Optional dictionary of transport level settings (HTTP headers, etc.)
    :return: QuerySummary built from the response headers
    """
    params = {}
    headers = {"Content-Type": "application/octet-stream"}
    if compression:
        headers["Content-Encoding"] = compression

    if table:
        cols = f" ({', '.join([quote_identifier(x) for x in column_names])})" if column_names is not None else ""
        fmt_str = fmt if fmt else self._write_format
        query = f"INSERT INTO {table}{cols} FORMAT {fmt_str}"
        if not compression and isinstance(insert_block, str):
            insert_block = query + "\n" + insert_block
        elif not compression and isinstance(insert_block, (bytes, bytearray)):
            # Only real byte strings can be concatenated. A file-like object (BinaryIO) would
            # raise TypeError here, so streams take the URL-parameter route below.
            insert_block = (query + "\n").encode() + insert_block
        else:
            params["query"] = query

    if self.database:
        params["database"] = self.database
    params.update(self._validate_settings(settings or {}))
    headers = dict_copy(headers, transport_settings)

    response = await self._raw_request(insert_block, params, headers, server_wait=False)
    logger.debug("Raw insert response code: %d", response.status)
    return QuerySummary(self._summary(response))
async def _error_handler(self, response: aiohttp.ClientResponse, retried: bool = False):
    """
    Turn an unsuccessful HTTP response into a raised driver exception.

    The response body is read, decompressed when a Content-Encoding header is present, and
    formatted off the event loop. A failure while reading or decoding the body is logged and
    the error is raised without body text. The response is always closed before raising.

    :param response: The aiohttp response that signalled an error
    :param retried: Selects the exception type raised
    :raises OperationalError: when ``retried`` is true
    :raises DatabaseError: when ``retried`` is false
    """
    try:
        body = ""
        try:
            raw_body = await response.read()
            encoding = response.headers.get("Content-Encoding")

            def render_body():
                # Decompress only when the server declared an encoding, then decode leniently.
                payload = decompress_response(raw_body, encoding) if encoding else raw_body
                return common.format_error(payload.decode(errors="backslashreplace")).strip()

            body = await asyncio.get_running_loop().run_in_executor(None, render_body)
        except Exception:
            logger.warning("Failed to read error response body", exc_info=True)

        if not self.show_clickhouse_errors:
            err_str = "The ClickHouse server returned an error"
        else:
            err_code = response.headers.get(ex_header)
            if err_code:
                err_str = f"Received ClickHouse exception, code: {err_code}"
            else:
                err_str = f"HTTP driver received HTTP status {response.status}"
            if body:
                err_str = f"{err_str}, server response: {body}"

        err_str = f"{err_str} (for url {self.url})"
    finally:
        response.close()

    error_type = OperationalError if retried else DatabaseError
    raise error_type(err_str) from None
+ ) + + reset_seconds = common.get_setting("max_connection_age") + if reset_seconds: + now = time.time() + if self._last_pool_reset is None: + self._last_pool_reset = now + elif self._last_pool_reset < now - reset_seconds: + logger.debug("connection expiration - resetting connection pool") + await self.close_connections() + self._last_pool_reset = now + + final_params = dict_copy(self._client_settings, params) + if server_wait: + final_params.setdefault("wait_end_of_query", "1") + if self._send_progress: + final_params.setdefault("send_progress_in_http_headers", "1") + if self._progress_interval: + final_params.setdefault("http_headers_progress_interval_ms", self._progress_interval) + if self._autogenerate_query_id and "query_id" not in final_params: + final_params["query_id"] = str(uuid.uuid4()) + + req_headers = dict_copy(self.headers, headers) + if self.server_host_name: + req_headers["Host"] = self.server_host_name + query_session = final_params.get("session_id") + attempts = 0 + + # pylint: disable=too-many-nested-blocks + while True: + attempts += 1 + + if query_session: + if query_session == self._active_session: + raise ProgrammingError( + "Attempt to execute concurrent queries within the same session. " + "Please use a separate client instance per concurrent query." 
+ ) + self._active_session = query_session + + try: + # Construct full URL (aiohttp doesn't have base_url) + url = f"{self.url}/" + request_kwargs = {"method": method, "url": url, "params": final_params, "headers": req_headers} + if hasattr(self, "_proxy_url") and self._proxy_url: + request_kwargs["proxy"] = self._proxy_url + if files: + # IMPORTANT: Must set content_type on text fields to force multipart/form-data encoding + # Without content_type, aiohttp uses application/x-www-form-urlencoded + form = aiohttp.FormData() + for field_name, field_value in files.items(): + if isinstance(field_value, tuple): + if field_value[0] is None: + form.add_field(field_name, str(field_value[1]), content_type='text/plain') + else: + filename = field_value[0] + file_data = field_value[1] + content_type = field_value[2] if len(field_value) > 2 else None + form.add_field(field_name, file_data, filename=filename, content_type=content_type) + else: + form.add_field(field_name, field_value, content_type='text/plain') + request_kwargs["data"] = form + elif isinstance(data, dict): + request_kwargs["data"] = data + else: + request_kwargs["data"] = data + + response = await self._session.request(**request_kwargs) + if 200 <= response.status < 300 and not response.headers.get(ex_header): + return response + + if response.status in (429, 503, 504): + if attempts > retries: + await self._error_handler(response, retried=True) + else: + logger.debug("Retrying request with status code %s (attempt %s/%s)", response.status, attempts, retries + 1) + await asyncio.sleep(0.1 * attempts) + response.close() + continue + await self._error_handler(response) + + except aiohttp.ServerConnectionError as e: + if "Connection reset" in str(e) or "Remote end closed" in str(e) or "Cannot connect" in str(e): + if attempts == 1: + logger.debug("Retrying after connection error from remote host") + continue + raise OperationalError(f"Network Error: {str(e)}") from e + + except (aiohttp.ClientError, 
@staticmethod
def _summary(response: aiohttp.ClientResponse):
    """
    Build the summary dictionary for a response from its headers.

    The ``X-ClickHouse-Summary`` header, when present and valid JSON, supplies the initial
    content; a missing or undecodable header yields an empty starting dictionary. The
    ``query_id`` key is then set from ``X-ClickHouse-Query-Id`` (empty string if absent).

    :param response: Response whose headers are inspected
    :return: The summary dictionary
    """
    headers = response.headers
    summary = {}
    try:
        summary = json.loads(headers["X-ClickHouse-Summary"])
    except (KeyError, json.JSONDecodeError):
        # Header absent or not valid JSON: keep the empty summary.
        pass
    summary["query_id"] = headers.get("X-ClickHouse-Query-Id", "")
    return summary
+ + This class maintains backward compatibility with the legacy executor-based async client + while also supporting direct instantiation for native async operations (though + get_async_client() is the recommended approach for new code). """ def __init__(self, *, - client: Client, + client: Optional[Client] = None, executor_threads: int = 0, - executor: Union[ThreadPoolExecutor, None, DefaultThreadPoolExecutor] = NEW_THREAD_POOL_EXECUTOR): - if isinstance(client, HttpClient): - client.headers['User-Agent'] = client.headers['User-Agent'].replace('mode:sync;', 'mode:async;') - self.client = client - if executor_threads == 0: - executor_threads = min(32, (os.cpu_count() or 1) + 4) # Mimic the default behavior - if executor is NEW_THREAD_POOL_EXECUTOR: - self.new_executor = True - self.executor = ThreadPoolExecutor(max_workers=executor_threads) + executor: Union[ThreadPoolExecutor, None, DefaultThreadPoolExecutor] = NEW_THREAD_POOL_EXECUTOR, + **kwargs): + """ + Create async client. + + Args: + client: (LEGACY - DEPRECATED) Sync client to wrap with ThreadPoolExecutor + executor_threads: (LEGACY - DEPRECATED) Thread pool size for legacy mode + executor: (LEGACY - DEPRECATED) Custom ThreadPoolExecutor for legacy mode + **kwargs: Arguments passed to AiohttpAsyncClient (native mode) + """ + if client is not None: + # LEGACY PATH: User passed sync client. use executor-based wrapper + warnings.warn( + "Passing 'client=' to AsyncClient is deprecated. " + "Use create_async_client(host=..., port=...) instead. 
" + "Legacy executor-based mode may be removed in the future.", + DeprecationWarning, + stacklevel=2 + ) + self._impl = _LegacyAsyncWrapper(client, executor_threads, executor) else: - if executor_threads != 0: - logger.warning('executor_threads parameter is ignored when passing an executor object') - - self.new_executor = False - self.executor = executor + # NATIVE PATH: Create aiohttp client + if not AIOHTTP_AVAILABLE: + raise ImportError( + "Native async support requires aiohttp. " + "Install with: pip install clickhouse-connect[async]\n" + "Alternatively, use the legacy executor-based async by passing a sync client to AsyncClient." + ) + self._impl = AiohttpAsyncClient(**kwargs) + + # Proxy all methods to implementation + # pylint: disable=protected-access + async def _initialize(self): + if hasattr(self._impl, '_initialize'): + await self._impl._initialize() def set_client_setting(self, key, value): """ @@ -61,21 +94,21 @@ def set_client_setting(self, key, value): :param key: ClickHouse setting name :param value: ClickHouse setting value """ - self.client.set_client_setting(key=key, value=value) + self._impl.set_client_setting(key=key, value=value) def get_client_setting(self, key) -> Optional[str]: """ :param key: The setting key :return: The string value of the setting, if it exists, or None """ - return self.client.get_client_setting(key=key) + return self._impl.get_client_setting(key=key) def set_access_token(self, access_token: str): """ Set the ClickHouse access token for the client :param access_token: Access token string """ - self.client.set_access_token(access_token) + return self._impl.set_access_token(access_token) def min_version(self, version_str: str) -> bool: """ @@ -85,16 +118,13 @@ def min_version(self, version_str: str) -> bool: :param version_str: A version string consisting of up to 4 integers delimited by dots :return: True if version_str is greater than the server_version, False if less than """ - return 
self.client.min_version(version_str) + return self._impl.min_version(version_str) async def close(self): """ Subclass implementation to close the connection to the server/deallocate the client """ - self.client.close() - - if self.new_executor: - await asyncio.to_thread(self.executor.shutdown, True) + return await self._impl.close() async def query(self, query: Optional[str] = None, @@ -107,7 +137,7 @@ async def query(self, column_oriented: Optional[bool] = None, use_numpy: Optional[bool] = None, max_str_len: Optional[int] = None, - context: QueryContext = None, + context: Optional[QueryContext] = None, query_tz: Optional[Union[str, tzinfo]] = None, column_tzs: Optional[Dict[str, Union[str, tzinfo]]] = None, utc_tz_aware: Optional[bool] = None, @@ -118,18 +148,12 @@ async def query(self, For parameters, see the create_query_context method. :return: QueryResult -- data and metadata from response """ - - def _query(): - return self.client.query(query=query, parameters=parameters, settings=settings, query_formats=query_formats, - column_formats=column_formats, encoding=encoding, use_none=use_none, - column_oriented=column_oriented, use_numpy=use_numpy, max_str_len=max_str_len, - context=context, query_tz=query_tz, column_tzs=column_tzs, - utc_tz_aware=utc_tz_aware, - external_data=external_data, transport_settings=transport_settings) - - loop = asyncio.get_running_loop() - result = await loop.run_in_executor(self.executor, _query) - return result + return await self._impl.query(query=query, parameters=parameters, settings=settings, + query_formats=query_formats, column_formats=column_formats, + encoding=encoding, use_none=use_none, column_oriented=column_oriented, + use_numpy=use_numpy, max_str_len=max_str_len, context=context, + query_tz=query_tz, column_tzs=column_tzs, utc_tz_aware=utc_tz_aware, + external_data=external_data, transport_settings=transport_settings) async def query_column_block_stream(self, query: Optional[str] = None, @@ -139,7 +163,7 @@ async def 
query_column_block_stream(self, column_formats: Optional[Dict[str, Union[str, Dict[str, str]]]] = None, encoding: Optional[str] = None, use_none: Optional[bool] = None, - context: QueryContext = None, + context: Optional[QueryContext] = None, query_tz: Optional[Union[str, tzinfo]] = None, column_tzs: Optional[Dict[str, Union[str, tzinfo]]] = None, utc_tz_aware: Optional[bool] = None, @@ -151,18 +175,12 @@ async def query_column_block_stream(self, For parameters, see the create_query_context method. :return: StreamContext -- Iterable stream context that returns column oriented blocks """ - - def _query_column_block_stream(): - return self.client.query_column_block_stream(query=query, parameters=parameters, settings=settings, - query_formats=query_formats, column_formats=column_formats, - encoding=encoding, use_none=use_none, context=context, - query_tz=query_tz, column_tzs=column_tzs, - utc_tz_aware=utc_tz_aware, - external_data=external_data, transport_settings=transport_settings) - - loop = asyncio.get_running_loop() - result = await loop.run_in_executor(self.executor, _query_column_block_stream) - return result + return await self._impl.query_column_block_stream(query=query, parameters=parameters, settings=settings, + query_formats=query_formats, column_formats=column_formats, + encoding=encoding, use_none=use_none, context=context, + query_tz=query_tz, column_tzs=column_tzs, + utc_tz_aware=utc_tz_aware, + external_data=external_data, transport_settings=transport_settings) async def query_row_block_stream(self, query: Optional[str] = None, @@ -172,7 +190,7 @@ async def query_row_block_stream(self, column_formats: Optional[Dict[str, Union[str, Dict[str, str]]]] = None, encoding: Optional[str] = None, use_none: Optional[bool] = None, - context: QueryContext = None, + context: Optional[QueryContext] = None, query_tz: Optional[Union[str, tzinfo]] = None, column_tzs: Optional[Dict[str, Union[str, tzinfo]]] = None, utc_tz_aware: Optional[bool] = None, @@ -183,18 
+201,12 @@ async def query_row_block_stream(self, For parameters, see the create_query_context method. :return: StreamContext -- Iterable stream context that returns blocks of rows """ - - def _query_row_block_stream(): - return self.client.query_row_block_stream(query=query, parameters=parameters, settings=settings, - query_formats=query_formats, column_formats=column_formats, - encoding=encoding, use_none=use_none, context=context, - query_tz=query_tz, column_tzs=column_tzs, - utc_tz_aware=utc_tz_aware, - external_data=external_data, transport_settings=transport_settings) - - loop = asyncio.get_running_loop() - result = await loop.run_in_executor(self.executor, _query_row_block_stream) - return result + return await self._impl.query_row_block_stream(query=query, parameters=parameters, settings=settings, + query_formats=query_formats, column_formats=column_formats, + encoding=encoding, use_none=use_none, context=context, + query_tz=query_tz, column_tzs=column_tzs, + utc_tz_aware=utc_tz_aware, + external_data=external_data, transport_settings=transport_settings) async def query_rows_stream(self, query: Optional[str] = None, @@ -204,7 +216,7 @@ async def query_rows_stream(self, column_formats: Optional[Dict[str, Union[str, Dict[str, str]]]] = None, encoding: Optional[str] = None, use_none: Optional[bool] = None, - context: QueryContext = None, + context: Optional[QueryContext] = None, query_tz: Optional[Union[str, tzinfo]] = None, column_tzs: Optional[Dict[str, Union[str, tzinfo]]] = None, utc_tz_aware: Optional[bool] = None, @@ -215,24 +227,18 @@ async def query_rows_stream(self, For parameters, see the create_query_context method. 
:return: StreamContext -- Iterable stream context that returns blocks of rows """ - - def _query_rows_stream(): - return self.client.query_rows_stream(query=query, parameters=parameters, settings=settings, - query_formats=query_formats, column_formats=column_formats, - encoding=encoding, use_none=use_none, context=context, - query_tz=query_tz, column_tzs=column_tzs, - utc_tz_aware=utc_tz_aware, - external_data=external_data, transport_settings=transport_settings) - - loop = asyncio.get_running_loop() - result = await loop.run_in_executor(self.executor, _query_rows_stream) - return result + return await self._impl.query_rows_stream(query=query, parameters=parameters, settings=settings, + query_formats=query_formats, column_formats=column_formats, + encoding=encoding, use_none=use_none, context=context, + query_tz=query_tz, column_tzs=column_tzs, + utc_tz_aware=utc_tz_aware, + external_data=external_data, transport_settings=transport_settings) async def raw_query(self, query: str, parameters: Optional[Union[Sequence, Dict[str, Any]]] = None, settings: Optional[Dict[str, Any]] = None, - fmt: str = None, + fmt: Optional[str] = None, use_database: bool = True, external_data: Optional[ExternalData] = None, transport_settings: Optional[Dict[str, str]] = None) -> bytes: @@ -248,20 +254,15 @@ async def raw_query(self, :param transport_settings: Optional dictionary of transport level settings (HTTP headers, etc.) 
:return: bytes representing raw ClickHouse return value based on format """ + return await self._impl.raw_query(query=query, parameters=parameters, settings=settings, fmt=fmt, + use_database=use_database, external_data=external_data, + transport_settings=transport_settings) - def _raw_query(): - return self.client.raw_query(query=query, parameters=parameters, settings=settings, fmt=fmt, - use_database=use_database, external_data=external_data, - transport_settings=transport_settings) - - loop = asyncio.get_running_loop() - result = await loop.run_in_executor(self.executor, _raw_query) - return result - - async def raw_stream(self, query: str, + async def raw_stream(self, + query: str, parameters: Optional[Union[Sequence, Dict[str, Any]]] = None, settings: Optional[Dict[str, Any]] = None, - fmt: str = None, + fmt: Optional[str] = None, use_database: bool = True, external_data: Optional[ExternalData] = None, transport_settings: Optional[Dict[str, str]] = None) -> io.IOBase: @@ -277,14 +278,9 @@ async def raw_stream(self, query: str, :param transport_settings: Optional dictionary of transport level settings (HTTP headers, etc.) 
:return: io.IOBase stream/iterator for the result """ - - def _raw_stream(): - return self.client.raw_stream(query=query, parameters=parameters, settings=settings, fmt=fmt, - use_database=use_database, external_data=external_data, transport_settings=transport_settings) - - loop = asyncio.get_running_loop() - result = await loop.run_in_executor(self.executor, _raw_stream) - return result + return await self._impl.raw_stream(query=query, parameters=parameters, settings=settings, fmt=fmt, + use_database=use_database, external_data=external_data, + transport_settings=transport_settings) async def query_np(self, query: Optional[str] = None, @@ -295,7 +291,7 @@ async def query_np(self, encoding: Optional[str] = None, use_none: Optional[bool] = None, max_str_len: Optional[int] = None, - context: QueryContext = None, + context: Optional[QueryContext] = None, external_data: Optional[ExternalData] = None, transport_settings: Optional[Dict[str, str]] = None): """ @@ -303,16 +299,11 @@ async def query_np(self, For parameter values, see the create_query_context method. 
:return: Numpy array representing the result set """ - - def _query_np(): - return self.client.query_np(query=query, parameters=parameters, settings=settings, - query_formats=query_formats, column_formats=column_formats, encoding=encoding, - use_none=use_none, max_str_len=max_str_len, context=context, - external_data=external_data, transport_settings=transport_settings) - - loop = asyncio.get_running_loop() - result = await loop.run_in_executor(self.executor, _query_np) - return result + return await self._impl.query_np(query=query, parameters=parameters, settings=settings, + query_formats=query_formats, column_formats=column_formats, + encoding=encoding, use_none=use_none, max_str_len=max_str_len, + context=context, external_data=external_data, + transport_settings=transport_settings) async def query_np_stream(self, query: Optional[str] = None, @@ -323,24 +314,19 @@ async def query_np_stream(self, encoding: Optional[str] = None, use_none: Optional[bool] = None, max_str_len: Optional[int] = None, - context: QueryContext = None, + context: Optional[QueryContext] = None, external_data: Optional[ExternalData] = None, transport_settings: Optional[Dict[str, str]] = None) -> StreamContext: """ Query method that returns the results as a stream of numpy arrays. For parameter values, see the create_query_context method. 
- :return: Generator that yield a numpy array per block representing the result set + :return: Numpy array representing the result set """ - - def _query_np_stream(): - return self.client.query_np_stream(query=query, parameters=parameters, settings=settings, - query_formats=query_formats, column_formats=column_formats, - encoding=encoding, use_none=use_none, max_str_len=max_str_len, - context=context, external_data=external_data, transport_settings=transport_settings) - - loop = asyncio.get_running_loop() - result = await loop.run_in_executor(self.executor, _query_np_stream) - return result + return await self._impl.query_np_stream(query=query, parameters=parameters, settings=settings, + query_formats=query_formats, column_formats=column_formats, + encoding=encoding, use_none=use_none, max_str_len=max_str_len, + context=context, external_data=external_data, + transport_settings=transport_settings) async def query_df(self, query: Optional[str] = None, @@ -355,7 +341,7 @@ async def query_df(self, query_tz: Optional[str] = None, column_tzs: Optional[Dict[str, Union[str, tzinfo]]] = None, utc_tz_aware: Optional[bool] = None, - context: QueryContext = None, + context: Optional[QueryContext] = None, external_data: Optional[ExternalData] = None, use_extended_dtypes: Optional[bool] = None, transport_settings: Optional[Dict[str, str]] = None): @@ -364,29 +350,22 @@ async def query_df(self, For parameter values, see the create_query_context method. 
:return: Pandas dataframe representing the result set """ - - def _query_df(): - return self.client.query_df(query=query, parameters=parameters, settings=settings, - query_formats=query_formats, column_formats=column_formats, encoding=encoding, - use_none=use_none, max_str_len=max_str_len, use_na_values=use_na_values, - query_tz=query_tz, column_tzs=column_tzs, utc_tz_aware=utc_tz_aware, - context=context, - external_data=external_data, use_extended_dtypes=use_extended_dtypes, - transport_settings=transport_settings) - - loop = asyncio.get_running_loop() - result = await loop.run_in_executor(self.executor, _query_df) - return result + return await self._impl.query_df(query=query, parameters=parameters, settings=settings, + query_formats=query_formats, column_formats=column_formats, + encoding=encoding, use_none=use_none, max_str_len=max_str_len, + use_na_values=use_na_values, query_tz=query_tz, column_tzs=column_tzs, + utc_tz_aware=utc_tz_aware, context=context, external_data=external_data, + use_extended_dtypes=use_extended_dtypes, transport_settings=transport_settings) async def query_df_arrow( - self, - query: str, - parameters: Optional[Union[Sequence, Dict[str, Any]]] = None, - settings: Optional[Dict[str, Any]] = None, - use_strings: Optional[bool] = None, - external_data: Optional[ExternalData] = None, - transport_settings: Optional[Dict[str, str]] = None, - dataframe_library: str = "pandas", + self, + query: str, + parameters: Optional[Union[Sequence, Dict[str, Any]]] = None, + settings: Optional[Dict[str, Any]] = None, + use_strings: Optional[bool] = None, + external_data: Optional[ExternalData] = None, + transport_settings: Optional[Dict[str, str]] = None, + dataframe_library: str = "pandas", ) -> Union["pd.DataFrame", "pl.DataFrame"]: """ Query method using the ClickHouse Arrow format to return a DataFrame @@ -402,21 +381,10 @@ async def query_df_arrow( :param dataframe_library: Library to use for DataFrame creation ("pandas" or "polars") :return: 
DataFrame (pandas or polars based on dataframe_library parameter) """ - - def _query_df_arrow(): - return self.client.query_df_arrow( - query=query, - parameters=parameters, - settings=settings, - use_strings=use_strings, - external_data=external_data, - transport_settings=transport_settings, - dataframe_library=dataframe_library - ) - - loop = asyncio.get_running_loop() - result = await loop.run_in_executor(self.executor, _query_df_arrow) - return result + return await self._impl.query_df_arrow(query=query, parameters=parameters, settings=settings, + use_strings=use_strings, external_data=external_data, + transport_settings=transport_settings, + dataframe_library=dataframe_library) async def query_df_stream(self, query: Optional[str] = None, @@ -431,7 +399,7 @@ async def query_df_stream(self, query_tz: Optional[str] = None, column_tzs: Optional[Dict[str, Union[str, tzinfo]]] = None, utc_tz_aware: Optional[bool] = None, - context: QueryContext = None, + context: Optional[QueryContext] = None, external_data: Optional[ExternalData] = None, use_extended_dtypes: Optional[bool] = None, transport_settings: Optional[Dict[str, str]] = None) -> StreamContext: @@ -440,31 +408,21 @@ async def query_df_stream(self, For parameter values, see the create_query_context method. 
:return: Generator that yields a Pandas dataframe per block representing the result set """ - - def _query_df_stream(): - return self.client.query_df_stream(query=query, parameters=parameters, settings=settings, - query_formats=query_formats, column_formats=column_formats, - encoding=encoding, - use_none=use_none, max_str_len=max_str_len, use_na_values=use_na_values, - query_tz=query_tz, column_tzs=column_tzs, - utc_tz_aware=utc_tz_aware, context=context, - external_data=external_data, use_extended_dtypes=use_extended_dtypes, - transport_settings=transport_settings) - - loop = asyncio.get_running_loop() - result = await loop.run_in_executor(self.executor, _query_df_stream) - return result - - async def query_df_arrow_stream( - self, - query: str, - parameters: Optional[Union[Sequence, Dict[str, Any]]] = None, - settings: Optional[Dict[str, Any]] = None, - use_strings: Optional[bool] = None, - external_data: Optional[ExternalData] = None, - transport_settings: Optional[Dict[str, str]] = None, - dataframe_library: str = "pandas" - ) -> StreamContext: + return await self._impl.query_df_stream(query=query, parameters=parameters, settings=settings, + query_formats=query_formats, column_formats=column_formats, + encoding=encoding, use_none=use_none, max_str_len=max_str_len, + use_na_values=use_na_values, query_tz=query_tz, column_tzs=column_tzs, + utc_tz_aware=utc_tz_aware, context=context, external_data=external_data, + use_extended_dtypes=use_extended_dtypes, transport_settings=transport_settings) + + async def query_df_arrow_stream(self, + query: str, + parameters: Optional[Union[Sequence, Dict[str, Any]]] = None, + settings: Optional[Dict[str, Any]] = None, + use_strings: Optional[bool] = None, + external_data: Optional[ExternalData] = None, + transport_settings: Optional[Dict[str, str]] = None, + dataframe_library: str = 'pandas') -> StreamContext: """ Query method that returns the results as a stream of DataFrames with PyArrow dtype backend. 
Each DataFrame represents a block from the ClickHouse response. @@ -478,42 +436,12 @@ async def query_df_arrow_stream( :param dataframe_library: Library to use for DataFrame creation ("pandas" or "polars") :return: StreamContext that yields DataFrames (pandas or polars based on dataframe_library parameter) """ + return await self._impl.query_df_arrow_stream(query=query, parameters=parameters, settings=settings, + use_strings=use_strings, external_data=external_data, + transport_settings=transport_settings, + dataframe_library=dataframe_library) - def _query_df_arrow_stream(): - return self.client.query_df_arrow_stream( - query=query, - parameters=parameters, - settings=settings, - use_strings=use_strings, - external_data=external_data, - transport_settings=transport_settings, - dataframe_library=dataframe_library - ) - - loop = asyncio.get_running_loop() - result = await loop.run_in_executor(self.executor, _query_df_arrow_stream) - return result - - def create_query_context(self, - query: Optional[Union[str, bytes]] = None, - parameters: Optional[Union[Sequence, Dict[str, Any]]] = None, - settings: Optional[Dict[str, Any]] = None, - query_formats: Optional[Dict[str, str]] = None, - column_formats: Optional[Dict[str, Union[str, Dict[str, str]]]] = None, - encoding: Optional[str] = None, - use_none: Optional[bool] = None, - column_oriented: Optional[bool] = None, - use_numpy: Optional[bool] = False, - max_str_len: Optional[int] = 0, - context: Optional[QueryContext] = None, - query_tz: Optional[Union[str, tzinfo]] = None, - column_tzs: Optional[Dict[str, Union[str, tzinfo]]] = None, - use_na_values: Optional[bool] = None, - streaming: bool = False, - as_pandas: bool = False, - external_data: Optional[ExternalData] = None, - use_extended_dtypes: Optional[bool] = None, - transport_settings: Optional[Dict[str, str]] = None) -> QueryContext: + def create_query_context(self, *args, **kwargs) -> QueryContext: """ Creates or updates a reusable QueryContext object :param 
query: Query statement/format string @@ -546,18 +474,7 @@ def create_query_context(self, :param transport_settings: Optional dictionary of transport level settings (HTTP headers, etc.) :return: Reusable QueryContext """ - - return self.client.create_query_context(query=query, parameters=parameters, settings=settings, - query_formats=query_formats, column_formats=column_formats, - encoding=encoding, use_none=use_none, - column_oriented=column_oriented, - use_numpy=use_numpy, max_str_len=max_str_len, context=context, - query_tz=query_tz, column_tzs=column_tzs, - use_na_values=use_na_values, - streaming=streaming, as_pandas=as_pandas, - external_data=external_data, - use_extended_dtypes=use_extended_dtypes, - transport_settings=transport_settings) + return self._impl.create_query_context(*args, **kwargs) async def query_arrow(self, query: str, @@ -576,15 +493,9 @@ async def query_arrow(self, :param transport_settings: Optional dictionary of transport level settings (HTTP headers, etc.) :return: PyArrow.Table """ - - def _query_arrow(): - return self.client.query_arrow(query=query, parameters=parameters, settings=settings, - use_strings=use_strings, external_data=external_data, - transport_settings=transport_settings) - - loop = asyncio.get_running_loop() - result = await loop.run_in_executor(self.executor, _query_arrow) - return result + return await self._impl.query_arrow(query=query, parameters=parameters, settings=settings, + use_strings=use_strings, external_data=external_data, + transport_settings=transport_settings) async def query_arrow_stream(self, query: str, @@ -603,21 +514,15 @@ async def query_arrow_stream(self, :param transport_settings: Optional dictionary of transport level settings (HTTP headers, etc.) 
:return: Generator that yields a PyArrow.Table for per block representing the result set """ - - def _query_arrow_stream(): - return self.client.query_arrow_stream(query=query, parameters=parameters, settings=settings, - use_strings=use_strings, external_data=external_data, - transport_settings=transport_settings) - - loop = asyncio.get_running_loop() - result = await loop.run_in_executor(self.executor, _query_arrow_stream) - return result + return await self._impl.query_arrow_stream(query=query, parameters=parameters, settings=settings, + use_strings=use_strings, external_data=external_data, + transport_settings=transport_settings) async def command(self, - cmd: str, + cmd, parameters: Optional[Union[Sequence, Dict[str, Any]]] = None, - data: Union[str, bytes] = None, - settings: Optional[Dict[str, Any]] = None, + data: Optional[Union[str, bytes]] = None, + settings: Optional[Dict] = None, use_database: bool = True, external_data: Optional[ExternalData] = None, transport_settings: Optional[Dict[str, str]] = None) -> Union[str, int, Sequence[str], QuerySummary]: @@ -635,39 +540,27 @@ async def command(self, :return: Decoded response from ClickHouse as either a string, int, or sequence of strings, or QuerySummary if no data returned """ - - def _command(): - return self.client.command(cmd=cmd, parameters=parameters, data=data, settings=settings, - use_database=use_database, external_data=external_data, - transport_settings=transport_settings) - - loop = asyncio.get_running_loop() - result = await loop.run_in_executor(self.executor, _command) - return result + return await self._impl.command(cmd=cmd, parameters=parameters, data=data, settings=settings, + use_database=use_database, external_data=external_data, + transport_settings=transport_settings) async def ping(self) -> bool: """ Validate the connection, does not throw an Exception (see debug logs) :return: ClickHouse server is up and reachable """ - - def _ping(): - return self.client.ping() - - loop = 
asyncio.get_running_loop() - result = await loop.run_in_executor(self.executor, _ping) - return result + return await self._impl.ping() async def insert(self, table: Optional[str] = None, - data: Sequence[Sequence[Any]] = None, + data: Optional[Sequence[Sequence[Any]]] = None, column_names: Union[str, Iterable[str]] = '*', database: Optional[str] = None, - column_types: Sequence[ClickHouseType] = None, - column_type_names: Sequence[str] = None, + column_types: Optional[Sequence[ClickHouseType]] = None, + column_type_names: Optional[Sequence[str]] = None, column_oriented: bool = False, settings: Optional[Dict[str, Any]] = None, - context: InsertContext = None, + context: Optional[InsertContext] = None, transport_settings: Optional[Dict[str, str]] = None) -> QuerySummary: """ Method to insert multiple rows/data matrix of native Python objects. If context is specified arguments @@ -688,25 +581,20 @@ async def insert(self, :param transport_settings: Optional dictionary of transport level settings (HTTP headers, etc.) 
:return: QuerySummary with summary information, throws exception if insert fails """ + return await self._impl.insert(table=table, data=data, column_names=column_names, database=database, + column_types=column_types, column_type_names=column_type_names, + column_oriented=column_oriented, settings=settings, context=context, + transport_settings=transport_settings) - def _insert(): - return self.client.insert(table=table, data=data, column_names=column_names, database=database, - column_types=column_types, column_type_names=column_type_names, - column_oriented=column_oriented, settings=settings, context=context, - transport_settings=transport_settings) - - loop = asyncio.get_running_loop() - result = await loop.run_in_executor(self.executor, _insert) - return result - - async def insert_df(self, table: str = None, - df=None, + async def insert_df(self, + table: Optional[str] = None, + df = None, database: Optional[str] = None, settings: Optional[Dict] = None, column_names: Optional[Sequence[str]] = None, - column_types: Sequence[ClickHouseType] = None, - column_type_names: Sequence[str] = None, - context: InsertContext = None, + column_types: Optional[Sequence[ClickHouseType]] = None, + column_type_names: Optional[Sequence[str]] = None, + context: Optional[InsertContext] = None, transport_settings: Optional[Dict[str, str]] = None) -> QuerySummary: """ Insert a pandas DataFrame into ClickHouse. If context is specified arguments other than df are ignored @@ -725,19 +613,15 @@ async def insert_df(self, table: str = None, :param transport_settings: Optional dictionary of transport level settings (HTTP headers, etc.) 
:return: QuerySummary with summary information, throws exception if insert fails """ + return await self._impl.insert_df(table=table, df=df, database=database, settings=settings, + column_names=column_names, column_types=column_types, + column_type_names=column_type_names, context=context, + transport_settings=transport_settings) - def _insert_df(): - return self.client.insert_df(table=table, df=df, database=database, settings=settings, - column_names=column_names, - column_types=column_types, column_type_names=column_type_names, - context=context, transport_settings=transport_settings) - - loop = asyncio.get_running_loop() - result = await loop.run_in_executor(self.executor, _insert_df) - return result - - async def insert_arrow(self, table: str, - arrow_table, database: str = None, + async def insert_arrow(self, + table: str, + arrow_table, + database: Optional[str] = None, settings: Optional[Dict] = None, transport_settings: Optional[Dict[str, str]] = None) -> QuerySummary: """ @@ -749,31 +633,23 @@ async def insert_arrow(self, table: str, :param transport_settings: Optional dictionary of transport level settings (HTTP headers, etc.) 
:return: QuerySummary with summary information, throws exception if insert fails """ + return await self._impl.insert_arrow(table=table, arrow_table=arrow_table, database=database, + settings=settings, transport_settings=transport_settings) - def _insert_arrow(): - return self.client.insert_arrow(table=table, arrow_table=arrow_table, database=database, - settings=settings, transport_settings=transport_settings) - - loop = asyncio.get_running_loop() - result = await loop.run_in_executor(self.executor, _insert_arrow) - return result - - async def insert_df_arrow( - self, - table: str, - df: Union["pd.DataFrame", "pl.DataFrame"], - database: Optional[str] = None, - settings: Optional[Dict] = None, - transport_settings: Optional[Dict[str, str]] = None, - ) -> QuerySummary: + async def insert_df_arrow(self, + table: str, + df, + database: Optional[str] = None, + settings: Optional[Dict] = None, + transport_settings: Optional[Dict[str, str]] = None) -> QuerySummary: """ Insert a pandas DataFrame with PyArrow backend or a polars DataFrame into ClickHouse using Arrow format. This method is optimized for DataFrames that already use Arrow format, providing better performance than the standard insert_df method. - + Validation is performed and an exception will be raised if this requirement is not met. Polars DataFrames are natively Arrow-based and don't require additional validation. - + :param table: ClickHouse table name :param df: Pandas DataFrame with PyArrow dtype backend or Polars DataFrame :param database: Optional ClickHouse database name @@ -781,26 +657,15 @@ async def insert_df_arrow( :param transport_settings: Optional dictionary of transport level settings (HTTP headers, etc.) 
:return: QuerySummary with summary information, throws exception if insert fails """ - - def _insert_df_arrow(): - return self.client.insert_df_arrow( - table=table, - df=df, - database=database, - settings=settings, - transport_settings=transport_settings, - ) - - loop = asyncio.get_running_loop() - result = await loop.run_in_executor(self.executor, _insert_df_arrow) - return result + return await self._impl.insert_df_arrow(table=table, df=df, database=database, settings=settings, + transport_settings=transport_settings) async def create_insert_context(self, table: str, column_names: Optional[Union[str, Sequence[str]]] = None, database: Optional[str] = None, - column_types: Sequence[ClickHouseType] = None, - column_type_names: Sequence[str] = None, + column_types: Optional[Sequence[ClickHouseType]] = None, + column_type_names: Optional[Sequence[str]] = None, column_oriented: bool = False, settings: Optional[Dict[str, Any]] = None, data: Optional[Sequence[Sequence[Any]]] = None, @@ -821,16 +686,10 @@ async def create_insert_context(self, :param transport_settings: Optional dictionary of transport level settings (HTTP headers, etc.) 
:return: Reusable insert context """ - - def _create_insert_context(): - return self.client.create_insert_context(table=table, column_names=column_names, database=database, - column_types=column_types, column_type_names=column_type_names, - column_oriented=column_oriented, settings=settings, data=data, - transport_settings=transport_settings) - - loop = asyncio.get_running_loop() - result = await loop.run_in_executor(self.executor, _create_insert_context) - return result + return await self._impl.create_insert_context(table=table, column_names=column_names, database=database, + column_types=column_types, column_type_names=column_type_names, + column_oriented=column_oriented, settings=settings, data=data, + transport_settings=transport_settings) async def data_insert(self, context: InsertContext) -> QuerySummary: """ @@ -838,17 +697,12 @@ async def data_insert(self, context: InsertContext) -> QuerySummary: :context: InsertContext parameter object :return: No return, throws an exception if the insert fails """ + return await self._impl.data_insert(context) - def _data_insert(): - return self.client.data_insert(context=context) - - loop = asyncio.get_running_loop() - result = await loop.run_in_executor(self.executor, _data_insert) - return result - - async def raw_insert(self, table: str, + async def raw_insert(self, + table: Optional[str] = None, column_names: Optional[Sequence[str]] = None, - insert_block: Union[str, bytes, Generator[bytes, None, None], BinaryIO] = None, + insert_block: Optional[Union[str, bytes, Generator[bytes, None, None], BinaryIO]] = None, settings: Optional[Dict] = None, fmt: Optional[str] = None, compression: Optional[str] = None, @@ -863,18 +717,173 @@ async def raw_insert(self, table: str, :param transport_settings: Optional dictionary of transport level settings (HTTP headers, etc.) 
:param fmt: Valid clickhouse format """ - - def _raw_insert(): - return self.client.raw_insert(table=table, column_names=column_names, insert_block=insert_block, - settings=settings, fmt=fmt, compression=compression, - transport_settings=transport_settings) - - loop = asyncio.get_running_loop() - result = await loop.run_in_executor(self.executor, _raw_insert) - return result + return await self._impl.raw_insert(table=table, column_names=column_names, insert_block=insert_block, + settings=settings, fmt=fmt, compression=compression, + transport_settings=transport_settings) async def __aenter__(self) -> "AsyncClient": + if hasattr(self._impl, '_initialize'): + await self._impl._initialize() return self async def __aexit__(self, exc_type, exc_val, exc_tb) -> None: await self.close() + + def __getattr__(self, name): + return getattr(self._impl, name) + + def __setattr__(self, name, value): + if name in ("_impl",) or "_impl" not in self.__dict__: + super().__setattr__(name, value) + return + if hasattr(self._impl, name): + setattr(self._impl, name, value) + else: + super().__setattr__(name, value) + + +class _LegacyAsyncWrapper: + """ + Legacy executor-based async wrapper (DEPRECATED). + + This wraps a sync HttpClient and runs all operations in a ThreadPoolExecutor. + Maintained for backward compatibility but may be removed in the future. 
+ """ + + def __init__( + self, + client: Client, + executor_threads: int = 0, + executor: Union[ThreadPoolExecutor, None, DefaultThreadPoolExecutor] = NEW_THREAD_POOL_EXECUTOR, + ): + if isinstance(client, HttpClient): + client.headers["User-Agent"] = client.headers["User-Agent"].replace("mode:sync;", "mode:async;") + self.client = client + if executor_threads == 0: + executor_threads = min(32, (os.cpu_count() or 1) + 4) + if executor is NEW_THREAD_POOL_EXECUTOR: + self.new_executor = True + self.executor = ThreadPoolExecutor(max_workers=executor_threads) + else: + if executor_threads != 0: + logger.warning("executor_threads parameter is ignored when passing an executor object") + self.new_executor = False + self.executor = executor + + if not AIOHTTP_AVAILABLE: + logger.info( + "Using executor-based async (legacy mode). " + "For better performance with true native async, install: pip install clickhouse-connect[async]" + ) + + def set_client_setting(self, key, value): + self.client.set_client_setting(key=key, value=value) + + def get_client_setting(self, key) -> Optional[str]: + return self.client.get_client_setting(key=key) + + def set_access_token(self, access_token: str): + self.client.set_access_token(access_token) + + def min_version(self, version_str: str) -> bool: + return self.client.min_version(version_str) + + async def close(self): + self.client.close() + if self.new_executor: + await asyncio.to_thread(self.executor.shutdown, True) + + async def query(self, *args, **kwargs) -> QueryResult: + loop = asyncio.get_running_loop() + return await loop.run_in_executor(self.executor, lambda: self.client.query(*args, **kwargs)) + + async def query_column_block_stream(self, *args, **kwargs) -> StreamContext: + loop = asyncio.get_running_loop() + return await loop.run_in_executor(self.executor, lambda: self.client.query_column_block_stream(*args, **kwargs)) + + async def query_row_block_stream(self, *args, **kwargs) -> StreamContext: + loop = 
asyncio.get_running_loop() + return await loop.run_in_executor(self.executor, lambda: self.client.query_row_block_stream(*args, **kwargs)) + + async def query_rows_stream(self, *args, **kwargs) -> StreamContext: + loop = asyncio.get_running_loop() + return await loop.run_in_executor(self.executor, lambda: self.client.query_rows_stream(*args, **kwargs)) + + async def raw_query(self, *args, **kwargs) -> bytes: + loop = asyncio.get_running_loop() + return await loop.run_in_executor(self.executor, lambda: self.client.raw_query(*args, **kwargs)) + + async def raw_stream(self, *args, **kwargs) -> io.IOBase: + loop = asyncio.get_running_loop() + return await loop.run_in_executor(self.executor, lambda: self.client.raw_stream(*args, **kwargs)) + + async def query_np(self, *args, **kwargs): + loop = asyncio.get_running_loop() + return await loop.run_in_executor(self.executor, lambda: self.client.query_np(*args, **kwargs)) + + async def query_np_stream(self, *args, **kwargs) -> StreamContext: + loop = asyncio.get_running_loop() + return await loop.run_in_executor(self.executor, lambda: self.client.query_np_stream(*args, **kwargs)) + + async def query_df(self, *args, **kwargs): + loop = asyncio.get_running_loop() + return await loop.run_in_executor(self.executor, lambda: self.client.query_df(*args, **kwargs)) + + async def query_df_arrow(self, *args, **kwargs): + loop = asyncio.get_running_loop() + return await loop.run_in_executor(self.executor, lambda: self.client.query_df_arrow(*args, **kwargs)) + + async def query_df_stream(self, *args, **kwargs) -> StreamContext: + loop = asyncio.get_running_loop() + return await loop.run_in_executor(self.executor, lambda: self.client.query_df_stream(*args, **kwargs)) + + async def query_df_arrow_stream(self, *args, **kwargs) -> StreamContext: + loop = asyncio.get_running_loop() + return await loop.run_in_executor(self.executor, lambda: self.client.query_df_arrow_stream(*args, **kwargs)) + + def create_query_context(self, *args, **kwargs) 
-> QueryContext: + return self.client.create_query_context(*args, **kwargs) + + async def query_arrow(self, *args, **kwargs): + loop = asyncio.get_running_loop() + return await loop.run_in_executor(self.executor, lambda: self.client.query_arrow(*args, **kwargs)) + + async def query_arrow_stream(self, *args, **kwargs) -> StreamContext: + loop = asyncio.get_running_loop() + return await loop.run_in_executor(self.executor, lambda: self.client.query_arrow_stream(*args, **kwargs)) + + async def command(self, *args, **kwargs) -> Union[str, int, Sequence[str], QuerySummary]: + loop = asyncio.get_running_loop() + return await loop.run_in_executor(self.executor, lambda: self.client.command(*args, **kwargs)) + + async def ping(self) -> bool: + loop = asyncio.get_running_loop() + return await loop.run_in_executor(self.executor, lambda: self.client.ping()) # pylint: disable=unnecessary-lambda + + async def insert(self, *args, **kwargs) -> QuerySummary: + loop = asyncio.get_running_loop() + return await loop.run_in_executor(self.executor, lambda: self.client.insert(*args, **kwargs)) + + async def insert_df(self, *args, **kwargs) -> QuerySummary: + loop = asyncio.get_running_loop() + return await loop.run_in_executor(self.executor, lambda: self.client.insert_df(*args, **kwargs)) + + async def insert_arrow(self, *args, **kwargs) -> QuerySummary: + loop = asyncio.get_running_loop() + return await loop.run_in_executor(self.executor, lambda: self.client.insert_arrow(*args, **kwargs)) + + async def insert_df_arrow(self, *args, **kwargs) -> QuerySummary: + loop = asyncio.get_running_loop() + return await loop.run_in_executor(self.executor, lambda: self.client.insert_df_arrow(*args, **kwargs)) + + async def create_insert_context(self, *args, **kwargs) -> InsertContext: + loop = asyncio.get_running_loop() + return await loop.run_in_executor(self.executor, lambda: self.client.create_insert_context(*args, **kwargs)) + + async def data_insert(self, context: InsertContext) -> QuerySummary: 
+ loop = asyncio.get_running_loop() + return await loop.run_in_executor(self.executor, lambda: self.client.data_insert(context)) + + async def raw_insert(self, *args, **kwargs) -> QuerySummary: + loop = asyncio.get_running_loop() + return await loop.run_in_executor(self.executor, lambda: self.client.raw_insert(*args, **kwargs)) diff --git a/clickhouse_connect/driver/asyncqueue.py b/clickhouse_connect/driver/asyncqueue.py new file mode 100644 index 00000000..9cd84b49 --- /dev/null +++ b/clickhouse_connect/driver/asyncqueue.py @@ -0,0 +1,174 @@ +import asyncio +import threading +from collections import deque +from typing import Deque, Generic, Optional, TypeVar + +__all__ = ["AsyncSyncQueue", "Empty", "Full", "EOF_SENTINEL"] + +T = TypeVar("T") + +EOF_SENTINEL = object() + + +# pylint: disable=too-many-instance-attributes +class AsyncSyncQueue(Generic[T]): + """High-performance bridge between AsyncIO and Threading.""" + + def __init__(self, maxsize: int = 100): + self._maxsize = maxsize + self._queue: Deque[T] = deque() + self._shutdown = False + self._loop: Optional[asyncio.AbstractEventLoop] = None + + self._lock = threading.Lock() + + self._sync_not_empty = threading.Condition(self._lock) + self._sync_not_full = threading.Condition(self._lock) + + self._async_getters: Deque[asyncio.Future] = deque() + self._async_putters: Deque[asyncio.Future] = deque() + + self.sync_q = _SyncQueueInterface(self) + self.async_q = _AsyncQueueInterface(self) + + def _bind_loop(self): + """Lazy-bind to the running loop on first async access.""" + if self._loop is None: + try: + self._loop = asyncio.get_running_loop() + except RuntimeError: + pass + + def _wakeup_async_waiter(self, waiter_queue: Deque[asyncio.Future]): + """Helper: Wake up the next async waiter in the queue safely.""" + while waiter_queue: + fut = waiter_queue.popleft() + if not fut.done(): + self._loop.call_soon_threadsafe(fut.set_result, None) + break + + def shutdown(self): + """Terminates the queue. 
All readers will receive EOF_SENTINEL.""" + with self._lock: + self._shutdown = True + + self._sync_not_empty.notify_all() + self._sync_not_full.notify_all() + + if self._loop and not self._loop.is_closed(): + for fut in list(self._async_getters): + if not fut.done(): + self._loop.call_soon_threadsafe(fut.set_result, None) + for fut in list(self._async_putters): + if not fut.done(): + self._loop.call_soon_threadsafe(fut.set_result, None) + self._async_getters.clear() + self._async_putters.clear() + + @property + def qsize(self) -> int: + with self._lock: + return len(self._queue) + + +# pylint: disable=protected-access +class _SyncQueueInterface(Generic[T]): + def __init__(self, parent: AsyncSyncQueue[T]): + self._p = parent + + def get(self, block: bool = True, timeout: Optional[float] = None) -> T: + with self._p._lock: + while not self._p._queue and not self._p._shutdown: + if not block: + raise Empty() + + if not self._p._sync_not_empty.wait(timeout): + raise Empty() + + if not self._p._queue and self._p._shutdown: + return EOF_SENTINEL + + item = self._p._queue.popleft() + self._p._sync_not_full.notify() + self._p._wakeup_async_waiter(self._p._async_putters) + + return item + + def put(self, item: T, block: bool = True, timeout: Optional[float] = None) -> None: + with self._p._lock: + if self._p._shutdown: + raise RuntimeError("Queue is shutdown") + + while self._p._maxsize > 0 and len(self._p._queue) >= self._p._maxsize: + if not block: + raise Full() + if not self._p._sync_not_full.wait(timeout): + raise Full() + if self._p._shutdown: + raise RuntimeError("Queue is shutdown") + + self._p._queue.append(item) + + self._p._sync_not_empty.notify() + self._p._wakeup_async_waiter(self._p._async_getters) + + +class _AsyncQueueInterface(Generic[T]): + def __init__(self, parent: AsyncSyncQueue[T]): + self._p = parent + + async def get(self) -> T: + self._p._bind_loop() + while True: + with self._p._lock: + if self._p._queue: + item = self._p._queue.popleft() + 
self._p._sync_not_full.notify() + self._p._wakeup_async_waiter(self._p._async_putters) + return item + + if self._p._shutdown: + return EOF_SENTINEL + + fut = self._p._loop.create_future() + self._p._async_getters.append(fut) + + try: + await fut + except asyncio.CancelledError: + with self._p._lock: + if fut in self._p._async_getters: + self._p._async_getters.remove(fut) + raise + + async def put(self, item: T) -> None: + self._p._bind_loop() + while True: + with self._p._lock: + if self._p._shutdown: + raise RuntimeError("Queue is shutdown") + + if self._p._maxsize <= 0 or len(self._p._queue) < self._p._maxsize: + self._p._queue.append(item) + self._p._sync_not_empty.notify() + self._p._wakeup_async_waiter(self._p._async_getters) + return + + fut = self._p._loop.create_future() + self._p._async_putters.append(fut) + + try: + await fut + except asyncio.CancelledError: + with self._p._lock: + if fut in self._p._async_putters: + self._p._async_putters.remove(fut) + raise + + +class Empty(Exception): + pass + + +class Full(Exception): + pass diff --git a/clickhouse_connect/driver/client.py b/clickhouse_connect/driver/client.py index 1a7a53a6..d85fde60 100644 --- a/clickhouse_connect/driver/client.py +++ b/clickhouse_connect/driver/client.py @@ -36,8 +36,8 @@ class Client(ABC): """ Base ClickHouse Connect client """ - compression: str = None - write_compression: str = None + compression: Optional[str] = None + write_compression: Optional[str] = None protocol_version = 0 valid_transport_settings = set() optional_transport_settings = set() @@ -48,14 +48,15 @@ class Client(ABC): show_clickhouse_errors = True def __init__(self, - database: str, + database: Optional[str], query_limit: int, uri: str, query_retries: int, server_host_name: Optional[str], apply_server_timezone: Optional[Union[str, bool]], utc_tz_aware: Optional[bool], - show_clickhouse_errors: Optional[bool]): + show_clickhouse_errors: Optional[bool], + autoconnect: bool = True): """ Shared initialization of 
ClickHouse Connect client :param database: database name @@ -63,6 +64,8 @@ def __init__(self, :param uri: uri for error messages :param utc_tz_aware: Default timezone behavior when the active timezone resolves to UTC. If True, timezone-aware UTC datetimes are returned; otherwise legacy naive datetimes are used. + :param autoconnect: If True, immediately connect to server and fetch settings. If False, + defer connection to _connect() method. Used by async clients to avoid blocking I/O in __init__. """ self.query_limit = coerce_int(query_limit) self.query_retries = coerce_int(query_retries) @@ -73,7 +76,17 @@ def __init__(self, self.server_host_name = server_host_name self.uri = uri self.utc_tz_aware = bool(utc_tz_aware) - self._init_common_settings(apply_server_timezone) + + # Initialize attributes that will be set during connection + self.server_version = None + self.server_tz = pytz.UTC + self.server_settings = {} + + if autoconnect: + self._init_common_settings(apply_server_timezone) + else: + # Store for deferred connection + self._deferred_apply_server_timezone = apply_server_timezone def _init_common_settings(self, apply_server_timezone: Optional[Union[str, bool]]): self.server_tz, dst_safe = pytz.UTC, True @@ -334,9 +347,9 @@ def raw_stream(self, query: str, fmt: str = None, use_database: bool = True, external_data: Optional[ExternalData] = None, - transport_settings: Optional[Dict[str, str]] = None) -> io.IOBase: + transport_settings: Optional[Dict[str, str]] = None) -> Union[io.IOBase, StreamContext]: """ - Query method that returns the result as an io.IOBase iterator + Query method that returns the result as a stream iterator. :param query: Query statement/format string :param parameters: Optional dictionary used to format the query :param settings: Optional dictionary of ClickHouse settings (key/string values) @@ -345,7 +358,7 @@ def raw_stream(self, query: str, database context. :param external_data: External data to send with the query. 
:param transport_settings: Optional dictionary of transport level settings (HTTP headers, etc.) - :return: io.IOBase stream/iterator for the result + :return: io.IOBase (sync) or StreamContext (async) - both support iteration over raw bytes """ # pylint: disable=duplicate-code,unused-argument diff --git a/clickhouse_connect/driver/common.py b/clickhouse_connect/driver/common.py index 76ccc7f2..7770f671 100644 --- a/clickhouse_connect/driver/common.py +++ b/clickhouse_connect/driver/common.py @@ -1,6 +1,7 @@ import array import struct import sys +import asyncio from typing import Sequence, MutableSequence, Dict, Optional, Union, Generator, Callable @@ -190,7 +191,8 @@ def __eq__(self, other): class StreamContext: """ Wraps a generator and its "source" in a Context. This ensures that the source will be "closed" even if the - generator is not fully consumed or there is an exception during consumption + generator is not fully consumed or there is an exception during consumption. Supports both synchronous and + asynchronous usage. 
""" __slots__ = 'source', 'gen', '_in_context' @@ -218,6 +220,58 @@ def __exit__(self, exc_type, exc_val, exc_tb): self.source.close() self.gen = None + def __aiter__(self): + return self + + async def __anext__(self): + if not self._in_context: + raise ProgrammingError("Stream should be used within a context") + try: + if hasattr(self.gen, "__anext__"): + return await self.gen.__anext__() + + def _next_wrapper(): + try: + return True, self.gen.__next__() + except StopIteration: + return False, None + + loop = asyncio.get_running_loop() + has_value, value = await loop.run_in_executor(None, _next_wrapper) + if not has_value: + raise StopAsyncIteration from None + return value + except (StopAsyncIteration, StopIteration): + raise StopAsyncIteration from None + except Exception as ex: + if not isinstance(ex, StreamClosedError): + self._in_context = False + if hasattr(self.source, "close"): + if hasattr(self.source.close, "__await__"): + await self.source.close() + else: + self.source.close() + self.gen = None + raise ex + + async def __aenter__(self): + if not self.gen: + raise StreamClosedError + self._in_context = True + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + self._in_context = False + if hasattr(self.source, 'aclose'): + await self.source.aclose() + elif hasattr(self.source, 'close'): + if hasattr(self.source.close, '__await__'): + await self.source.close() + else: + self.source.close() + self.gen = None + + # pylint: disable=too-many-return-statements def get_rename_method(method: Optional[str]) -> Optional[Callable[[str], str]]: def _to_camel(s: str) -> str: diff --git a/clickhouse_connect/driver/httpclient.py b/clickhouse_connect/driver/httpclient.py index 4c58a30d..4bfea703 100644 --- a/clickhouse_connect/driver/httpclient.py +++ b/clickhouse_connect/driver/httpclient.py @@ -185,7 +185,8 @@ def __init__(self, server_host_name=server_host_name, apply_server_timezone=apply_server_timezone, utc_tz_aware=utc_tz_aware, - 
show_clickhouse_errors=show_clickhouse_errors) + show_clickhouse_errors=show_clickhouse_errors, + autoconnect=True) self.params = dict_copy(self.params, self._validate_settings(ch_settings)) comp_setting = self._setting_status('enable_http_compression') self._send_comp_setting = not comp_setting.is_set and comp_setting.is_writable @@ -437,42 +438,31 @@ def _error_handler(self, response: HTTPResponse, retried: bool = False) -> None: """ try: body = "" - # Always try to read the response body for context. try: - # get_response_data reads body and decodes it for the error message raw_body = get_response_data(response) body = common.format_error( raw_body.decode(errors="backslashreplace") ).strip() except Exception: # pylint: disable=broad-except - # If we can't read or decode the body, we'll proceed without it logger.warning("Failed to read error response body", exc_info=True) - # Build the error message if self.show_clickhouse_errors: err_code = response.headers.get(ex_header) if err_code: - # Prioritize the specific ClickHouse exception code if it exists. 
err_str = f"Received ClickHouse exception, code: {err_code}" else: - # Otherwise, just use the generic HTTP status err_str = f"HTTP driver received HTTP status {response.status}" if body: - # Always append the body if it exists err_str = f"{err_str}, server response: {body}" else: - # Simple message for when detailed errors are disabled err_str = "The ClickHouse server returned an error" - # Add the URL for additional context err_str = f"{err_str} (for url {self.url})" finally: - # Ensure closed response to prevent resource leaks response.close() - # Raise the appropriate exception class raise OperationalError(err_str) if retried else DatabaseError(err_str) from None def _raw_request(self, diff --git a/clickhouse_connect/driver/query.py b/clickhouse_connect/driver/query.py index 8001bd22..fbf297d1 100644 --- a/clickhouse_connect/driver/query.py +++ b/clickhouse_connect/driver/query.py @@ -280,12 +280,20 @@ def result_set(self) -> Matrix: @property def result_columns(self) -> Matrix: if self._result_columns is None: - result = [[] for _ in range(len(self.column_names))] - with self.column_block_stream as stream: - for block in stream: - for base, added in zip(result, block): - base.extend(added) - self._result_columns = result + # If rows are already materialized and stream is closed, transpose from rows + # This happens when async client eagerly materializes result_rows + if self._result_rows is not None and self._block_gen is None: + if self._result_rows: + self._result_columns = list(map(list, zip(*self._result_rows))) + else: + self._result_columns = [[] for _ in range(len(self.column_names))] + else: + result = [[] for _ in range(len(self.column_names))] + with self.column_block_stream as stream: + for block in stream: + for base, added in zip(result, block): + base.extend(added) + self._result_columns = result return self._result_columns @property diff --git a/clickhouse_connect/driver/streaming.py b/clickhouse_connect/driver/streaming.py new file mode 100644 
index 00000000..b2fb19dd --- /dev/null +++ b/clickhouse_connect/driver/streaming.py @@ -0,0 +1,311 @@ +import asyncio +import logging +import threading +import zlib +from typing import Iterator, Optional + +import lz4.frame +import zstandard + +from clickhouse_connect.driver.asyncqueue import EOF_SENTINEL, AsyncSyncQueue +from clickhouse_connect.driver.compression import available_compression +from clickhouse_connect.driver.types import Closable + +logger = logging.getLogger(__name__) + +__all__ = ["StreamingResponseSource", "StreamingFileAdapter", "StreamingInsertSource"] + +if "br" in available_compression: + import brotli +else: + brotli = None + + +# pylint: disable=too-many-instance-attributes, broad-exception-caught +class StreamingResponseSource(Closable): + """Streaming source that feeds chunks from async producer to sync consumer.""" + + READ_BUFFER_SIZE = 512 * 1024 + + def __init__(self, response, encoding: Optional[str] = None): + self.response = response + self.encoding = encoding + + # maxsize=10 means max ~10 socket reads buffered + self.queue = AsyncSyncQueue(maxsize=10) + + self._decompressor = None + self._decompressor_initialized = False + + # Multiple accesses to .gen must return the same generator, not create new ones + self._gen_cache = None + + self._producer_task = None + self._producer_started = threading.Event() + self._producer_error: Optional[Exception] = None + self._producer_completed = False + + async def start_producer(self, loop: asyncio.AbstractEventLoop): + """Start the async producer task. + Must be called from the event loop thread before consuming. 
+ """ + + async def producer(): + """Async producer: reads chunks from response, feeds queue.""" + try: + while True: + chunk = await self.response.content.read(self.READ_BUFFER_SIZE) + if not chunk: + break + await self.queue.async_q.put(chunk) + + await self.queue.async_q.put(EOF_SENTINEL) + self._producer_completed = True + + except Exception as e: + logger.error("Producer error while streaming response: %s", e, exc_info=True) + self._producer_error = e + + try: + await self.queue.async_q.put(e) + except RuntimeError: + pass + + finally: + self.queue.shutdown() + + self._producer_task = loop.create_task(producer()) + self._producer_started.set() + + @property + def gen(self) -> Iterator[bytes]: + """Generator that yields decompressed chunks. + + CRITICAL: Returns cached generator to prevent multiple generators + from competing to read from the same queue. + """ + if self._gen_cache is not None: + return self._gen_cache + + self._gen_cache = self._create_generator() + return self._gen_cache + + # pylint: disable=too-many-branches + def _create_generator(self) -> Iterator[bytes]: + """Creates the actual generator function.""" + if not self._producer_started.wait(timeout=5.0): + raise RuntimeError("Producer failed to start within timeout") + + if self.encoding and not self._decompressor_initialized: + self._decompressor_initialized = True + try: + self._decompressor = self._create_decompressor(self.encoding) + except Exception as e: + logger.error("Failed to create decompressor for %s: %s", self.encoding, e) + raise + + # pylint: disable=too-many-nested-blocks + while True: + chunk = self.queue.sync_q.get() + + if chunk is EOF_SENTINEL: + if self._decompressor: + try: + if hasattr(self._decompressor, "flush"): + final = self._decompressor.flush() + if final: + yield final + except Exception as e: + logger.error("Error flushing decompressor: %s", e, exc_info=True) + raise + break + + if isinstance(chunk, Exception): + raise chunk + + if self._decompressor: + try: + 
if hasattr(self._decompressor, "decompress"): + decompressed = self._decompressor.decompress(chunk) + else: + decompressed = self._decompressor.process(chunk) + if decompressed: + yield decompressed + except Exception as e: + logger.error("Decompression error: %s", e, exc_info=True) + raise + else: + yield chunk + + @staticmethod + def _create_decompressor(encoding: str): + """Create incremental decompressor for encoding.""" + if encoding == "gzip": + return zlib.decompressobj(16 + zlib.MAX_WBITS) + + if encoding == "deflate": + return zlib.decompressobj() + + if encoding == "br": + if brotli is not None: + return brotli.Decompressor() + raise ImportError("brotli compression requires 'brotli' package. Install with: pip install brotli") + + if encoding == "zstd": + return zstandard.ZstdDecompressor().decompressobj() + + if encoding == "lz4": + return lz4.frame.LZ4FrameDecompressor() + + raise ValueError(f"Unsupported compression encoding: {encoding}") + + async def aclose(self): + """Async cleanup resources""" + self.queue.shutdown() + + if self._producer_task and not self._producer_task.done(): + self._producer_task.cancel() + try: + await self._producer_task + except asyncio.CancelledError: + pass + except Exception: + pass + + if self.response and not self.response.closed: + if not self._producer_completed: + self.response.close() + await asyncio.sleep(0.05) + + def close(self): + """Synchronous cleanup resources""" + self.queue.shutdown() + + if self._producer_task and not self._producer_task.done(): + self._producer_task.cancel() + + if self.response and not self.response.closed: + if not self._producer_completed: + self.response.close() + + +class StreamingFileAdapter: + """File-like adapter for PyArrow streaming.""" + + def __init__(self, streaming_source): + self.streaming_source = streaming_source + self.gen = streaming_source.gen + self.buffer = b"" + self.closed = False + self.eof = False + + def read(self, size: int = -1) -> bytes: + """Read up to size 
bytes from stream""" + if self.closed or self.eof: + return b"" + + if size != -1 and len(self.buffer) >= size: + result = self.buffer[:size] + self.buffer = self.buffer[size:] + return result + + chunks = [self.buffer] if self.buffer else [] + current_len = len(self.buffer) + self.buffer = b"" + + while (size == -1 or current_len < size) and not self.eof: + try: + chunk = next(self.gen) + if chunk: + chunks.append(chunk) + current_len += len(chunk) + else: + self.eof = True + break + except StopIteration: + self.eof = True + break + + full_data = b"".join(chunks) + + if size == -1 or len(full_data) <= size: + return full_data + + result = full_data[:size] + self.buffer = full_data[size:] + return result + + def close(self): + self.closed = True + + +class StreamingInsertSource: + """Streaming source for async inserts (reverse bridge)""" + + def __init__(self, transform, context, loop: asyncio.AbstractEventLoop, maxsize: int = 10): + self.transform = transform + self.context = context + self.loop = loop + self.queue = AsyncSyncQueue(maxsize=maxsize) + self._producer_future = None + self._started = False + + def start_producer(self): + if self._started: + raise RuntimeError("Producer already started") + self._started = True + + def producer(): + try: + for block in self.transform.build_insert(self.context): + self.queue.sync_q.put(block) + + self.queue.sync_q.put(EOF_SENTINEL) + + except Exception as e: + logger.error("Insert producer error: %s", e, exc_info=True) + try: + self.queue.sync_q.put(e) + except Exception: + pass + finally: + self.queue.shutdown() + + self._producer_future = self.loop.run_in_executor(None, producer) + + async def async_generator(self): + """Async generator that yields blocks for aiohttp streaming.""" + if not self._started: + raise RuntimeError("Producer not started, call start_producer() first") + + try: + while True: + chunk = await self.queue.async_q.get() + + if chunk is EOF_SENTINEL: + break + + if isinstance(chunk, Exception): + 
raise chunk + + yield chunk + + except Exception as e: + logger.error("Insert consumer error: %s", e, exc_info=True) + raise + finally: + if self._producer_future and not self._producer_future.done(): + try: + await self._producer_future + except Exception: + pass + + async def close(self): + self.queue.shutdown() + if self._producer_future and not self._producer_future.done(): + try: + await asyncio.wait_for(self._producer_future, timeout=1.0) + except asyncio.TimeoutError: + logger.warning("Insert producer did not finish within timeout") + except Exception: + pass diff --git a/clickhouse_connect/driver/tools.py b/clickhouse_connect/driver/tools.py index 42480858..d26e55a1 100644 --- a/clickhouse_connect/driver/tools.py +++ b/clickhouse_connect/driver/tools.py @@ -1,3 +1,4 @@ +import asyncio from typing import Optional, Sequence, Dict, Any from clickhouse_connect.driver import Client @@ -31,3 +32,38 @@ def insert_file(client: Client, fmt=fmt, settings=settings, compression=compression) + + +async def insert_file_async(client, + table: str, + file_path: str, + fmt: Optional[str] = None, + column_names: Optional[Sequence[str]] = None, + database: Optional[str] = None, + settings: Optional[Dict[str, Any]] = None, + compression: Optional[str] = None) -> QuerySummary: + + if not database and table[0] not in ('`', "'") and table.find('.') > 0: + full_table = table + elif database: + full_table = f'{quote_identifier(database)}.{quote_identifier(table)}' + else: + full_table = quote_identifier(table) + if not fmt: + fmt = 'CSV' if column_names else 'CSVWithNames' + if compression is None: + if file_path.endswith('.gzip') or file_path.endswith('.gz'): + compression = 'gzip' + + def read_file(): + with open(file_path, 'rb') as file: + return file.read() + + file_data = await asyncio.to_thread(read_file) + + return await client.raw_insert(full_table, + column_names=column_names, + insert_block=file_data, + fmt=fmt, + settings=settings, + compression=compression) diff --git 
a/pyproject.toml b/pyproject.toml index 7d37b3e0..ce7f096e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -11,3 +11,4 @@ log_cli = true log_cli_level = "INFO" env_files = ["test.env"] asyncio_default_fixture_loop_scope = "session" +addopts = "-n 4" diff --git a/setup.py b/setup.py index e5fa4249..4950fa59 100644 --- a/setup.py +++ b/setup.py @@ -73,6 +73,7 @@ def run_setup(try_c: bool = True): 'arrow': ['pyarrow>=22.0; python_version>="3.14"', 'pyarrow; python_version<"3.14"'], 'orjson': ['orjson'], 'tzlocal': ['tzlocal>=4.0'], + 'async': ['aiohttp>=3.8.0'], }, tests_require=['pytest'], entry_points={ diff --git a/tests/integration_tests/conftest.py b/tests/integration_tests/conftest.py index 12598689..d972f4ed 100644 --- a/tests/integration_tests/conftest.py +++ b/tests/integration_tests/conftest.py @@ -1,3 +1,4 @@ +import asyncio import sys import os import random @@ -8,6 +9,7 @@ import pytest_asyncio from pytest import fixture +from clickhouse_connect import get_async_client from clickhouse_connect import common from clickhouse_connect.driver.common import coerce_bool from clickhouse_connect.driver.exceptions import OperationalError @@ -35,6 +37,8 @@ class TestException(BaseException): pass +# pylint: disable=redefined-outer-name + @fixture(scope='session', autouse=True, name='test_config') def test_config_fixture() -> Iterator[TestConfig]: common.set_setting('max_connection_age', 15) # Make sure resetting connections doesn't break stuff @@ -92,6 +96,142 @@ def test_table_engine_fixture() -> Iterator[str]: yield 'MergeTree' +@fixture(scope="module") +def shared_loop(): + """Shared event loop for running async clients in sync test context.""" + loop = asyncio.new_event_loop() + yield loop + loop.close() + + +@fixture(params=["sync", "async"]) +def client_mode(request): + return request.param + + +@fixture +def call(client_mode, shared_loop): + """Wrapper to call functions in appropriate sync/async context.""" + if client_mode == "sync": + return lambda fn, 
*args, **kwargs: fn(*args, **kwargs) + return lambda fn, *args, **kwargs: shared_loop.run_until_complete(fn(*args, **kwargs)) + + +@fixture +def consume_stream(client_mode, call): + """Fixture to consume a stream in either sync or async mode.""" + + def _consume(stream, callback=None): + if client_mode == "sync": + with stream: + for item in stream: + if callback: + callback(item) + else: + + async def runner(): + async with stream: + async for item in stream: + if callback: + callback(item) + + call(runner) + + return _consume + + +@fixture +def client_factory(client_mode, test_config, shared_loop): + """Factory for creating clients with custom configuration in tests.""" + clients = [] + + def factory(**kwargs): + config = { + "host": test_config.host, + "port": test_config.port, + "username": test_config.username, + "password": test_config.password, + "database": test_config.test_database, + "compress": test_config.compress, + **kwargs, + } + + if client_mode == "sync": + client = create_client(**config) + else: + client = shared_loop.run_until_complete(get_async_client(**config)) + + clients.append(client) + return client + + yield factory + + for client in clients: + try: + if client_mode == "sync": + client.close() + else: + shared_loop.run_until_complete(client.close()) + except Exception: # pylint: disable=broad-exception-caught + pass + + +@fixture +def param_client(client_mode, test_config, shared_loop): + """Provides client based on client_mode parameter.""" + if client_mode == "sync": + client = create_client( + host=test_config.host, + port=test_config.port, + username=test_config.username, + password=test_config.password, + database=test_config.test_database, + compress=test_config.compress, + settings={"allow_suspicious_low_cardinality_types": 1}, + client_name="int_tests/param_sync", + ) + if client.min_version("22.8"): + client.set_client_setting("database_replicated_enforce_synchronous_settings", 1) + if client.min_version("24.8") and 
(client.min_version("24.12") or not test_config.cloud): + client.set_client_setting("allow_experimental_json_type", 1) + client.set_client_setting("allow_experimental_dynamic_type", 1) + client.set_client_setting("allow_experimental_variant_type", 1) + if test_config.insert_quorum: + client.set_client_setting("insert_quorum", test_config.insert_quorum) + elif test_config.cloud: + client.set_client_setting("select_sequential_consistency", 1) + + yield client + client.close() + else: + client = shared_loop.run_until_complete( + get_async_client( + host=test_config.host, + port=test_config.port, + username=test_config.username, + password=test_config.password, + database=test_config.test_database, + compress=test_config.compress, + settings={"allow_suspicious_low_cardinality_types": 1}, + client_name="int_tests/param_async", + ) + ) + + if client.min_version("22.8"): + client.set_client_setting("database_replicated_enforce_synchronous_settings", "1") + if client.min_version("24.8"): + client.set_client_setting("allow_experimental_json_type", "1") + client.set_client_setting("allow_experimental_dynamic_type", "1") + client.set_client_setting("allow_experimental_variant_type", "1") + if test_config.insert_quorum: + client.set_client_setting("insert_quorum", str(test_config.insert_quorum)) + elif test_config.cloud: + client.set_client_setting("select_sequential_consistency", "1") + + yield client + shared_loop.run_until_complete(client.close()) + + # pylint: disable=too-many-branches @fixture(scope='session', autouse=True, name='test_client') def test_client_fixture(test_config: TestConfig, test_create_client: Callable) -> Iterator[Client]: @@ -142,12 +282,38 @@ def test_client_fixture(test_config: TestConfig, test_create_client: Callable) - sys.stderr.write('Successfully stopped docker compose') -@pytest_asyncio.fixture(scope='session', autouse=True, name='test_async_client') +@pytest_asyncio.fixture(scope='session', name='test_async_client') async def 
test_async_client_fixture(test_client: Client) -> AsyncContextManager[AsyncClient]: async with AsyncClient(client=test_client) as client: yield client +@pytest_asyncio.fixture(scope="function", loop_scope="function", name="test_native_async_client") +async def test_native_async_client_fixture(test_config: TestConfig) -> AsyncContextManager: + """Function-scoped fixture for aiohttp async client""" + async with await get_async_client( + host=test_config.host, + port=test_config.port, + username=test_config.username, + password=test_config.password, + database=test_config.test_database, + compress=test_config.compress, + client_name="int_tests/aiohttp_async", + ) as client: + if client.min_version("22.8"): + client.set_client_setting("database_replicated_enforce_synchronous_settings", "1") + if client.min_version("24.8"): + client.set_client_setting("allow_experimental_json_type", "1") + client.set_client_setting("allow_experimental_dynamic_type", "1") + client.set_client_setting("allow_experimental_variant_type", "1") + if test_config.insert_quorum: + client.set_client_setting("insert_quorum", str(test_config.insert_quorum)) + elif test_config.cloud: + client.set_client_setting("select_sequential_consistency", "1") + + yield client + + @fixture(scope='session', name='table_context') def table_context_fixture(test_client: Client, test_table_engine: str): def context(table: str, diff --git a/tests/integration_tests/test_arrow.py b/tests/integration_tests/test_arrow.py index aa681cbe..cb4646e8 100644 --- a/tests/integration_tests/test_arrow.py +++ b/tests/integration_tests/test_arrow.py @@ -8,18 +8,18 @@ from clickhouse_connect.driver.options import arrow -def test_arrow(test_client: Client, table_context: Callable): +def test_arrow(param_client, call, table_context: Callable): if not arrow: pytest.skip('PyArrow package not available') - if not test_client.min_version('21'): - pytest.skip(f'PyArrow is not supported in this server version {test_client.server_version}') + 
if not param_client.min_version('21'): + pytest.skip(f'PyArrow is not supported in this server version {param_client.server_version}') with table_context('test_arrow_insert', ['animal String', 'legs Int64']): n_legs = arrow.array([2, 4, 5, 100] * 50) animals = arrow.array(['Flamingo', 'Horse', 'Brittle stars', 'Centipede'] * 50) names = ['legs', 'animal'] insert_table = arrow.Table.from_arrays([n_legs, animals], names=names) - test_client.insert_arrow('test_arrow_insert', insert_table) - result_table = test_client.query_arrow('SELECT * FROM test_arrow_insert', use_strings=False) + call(param_client.insert_arrow, 'test_arrow_insert', insert_table) + result_table = call(param_client.query_arrow, 'SELECT * FROM test_arrow_insert', use_strings=False) arrow_schema = result_table.schema assert arrow_schema.field(0).name == 'animal' assert arrow_schema.field(0).type == arrow.binary() @@ -29,29 +29,34 @@ def test_arrow(test_client: Client, table_context: Callable): assert arrow.compute.sum(result_table['legs']).as_py() == 5550 assert len(result_table.columns) == 2 - arrow_table = test_client.query_arrow('SELECT number from system.numbers LIMIT 500', - settings={'max_block_size': 50}) + arrow_table = call(param_client.query_arrow, 'SELECT number from system.numbers LIMIT 500', + settings={'max_block_size': 50}) arrow_schema = arrow_table.schema assert arrow_schema.field(0).name == 'number' assert arrow_schema.field(0).type.id == 8 assert arrow_table.num_rows == 500 -def test_arrow_stream(test_client: Client, table_context: Callable): +def test_arrow_stream(param_client, call, table_context, consume_stream): if not arrow: pytest.skip('PyArrow package not available') - if not test_client.min_version('21'): - pytest.skip(f'PyArrow is not supported in this server version {test_client.server_version}') + if not param_client.min_version('21'): + pytest.skip(f'PyArrow is not supported in this server version {param_client.server_version}') with table_context('test_arrow_insert', 
['counter Int64', 'letter String']): counter = arrow.array(range(1000000)) alphabet = string.ascii_lowercase letter = arrow.array([alphabet[x % 26] for x in range(1000000)]) names = ['counter', 'letter'] insert_table = arrow.Table.from_arrays([counter, letter], names=names) - test_client.insert_arrow('test_arrow_insert', insert_table) - stream = test_client.query_arrow_stream('SELECT * FROM test_arrow_insert', use_strings=True) - with stream: - result_tables = list(stream) + call(param_client.insert_arrow, 'test_arrow_insert', insert_table) + stream = call(param_client.query_arrow_stream, 'SELECT * FROM test_arrow_insert', use_strings=True) + result_tables = [] + + def process(table): + result_tables.append(table) + + consume_stream(stream, process) + # Hopefully we made the table long enough we got multiple tables in the query assert len(result_tables) > 1 total_rows = 0 @@ -67,20 +72,20 @@ def test_arrow_stream(test_client: Client, table_context: Callable): assert total_rows == 1000000 -def test_arrow_map(test_client: Client, table_context: Callable): +def test_arrow_map(param_client, call, table_context: Callable): if not arrow: pytest.skip('PyArrow package not available') - if not test_client.min_version('21'): - pytest.skip(f'PyArrow is not supported in this server version {test_client.server_version}') + if not param_client.min_version('21'): + pytest.skip(f'PyArrow is not supported in this server version {param_client.server_version}') with table_context('test_arrow_map', ['trade_date Date, code String', 'kdj Map(String, Float32)', 'update_time DateTime DEFAULT now()']): data = [[date(2023, 10, 15), 'C1', {'k': 2.5, 'd': 0, 'j': 0}], [date(2023, 10, 16), 'C2', {'k': 3.5, 'd': 0, 'j': -.372}]] - test_client.insert('test_arrow_map', data, column_names=('trade_date', 'code', 'kdj'), - settings={'insert_deduplication_token': '10381'}) - arrow_table = test_client.query_arrow('SELECT * FROM test_arrow_map ORDER BY trade_date', - use_strings=True) + 
call(param_client.insert, 'test_arrow_map', data, column_names=('trade_date', 'code', 'kdj'), + settings={'insert_deduplication_token': '10381'}) + arrow_table = call(param_client.query_arrow, 'SELECT * FROM test_arrow_map ORDER BY trade_date', + use_strings=True) assert isinstance(arrow_table.schema, arrow.Schema) - test_client.insert_arrow('test_arrow_map', arrow_table, settings={'insert_deduplication_token': '10382'}) - assert 4 == test_client.command('SELECT count() FROM test_arrow_map') + call(param_client.insert_arrow, 'test_arrow_map', arrow_table, settings={'insert_deduplication_token': '10382'}) + assert 4 == call(param_client.command, 'SELECT count() FROM test_arrow_map') diff --git a/tests/integration_tests/test_async_features.py b/tests/integration_tests/test_async_features.py new file mode 100644 index 00000000..ce32b8c2 --- /dev/null +++ b/tests/integration_tests/test_async_features.py @@ -0,0 +1,303 @@ +import asyncio +import time +from typing import Callable + +import pytest + +from clickhouse_connect import get_async_client +from clickhouse_connect.driver.exceptions import OperationalError, ProgrammingError + +# pylint: disable=protected-access + + +@pytest.mark.asyncio +async def test_concurrent_queries(test_config): + """Verify multiple queries execute concurrently (not sequentially).""" + async with await get_async_client( + host=test_config.host, + port=test_config.port, + username=test_config.username, + password=test_config.password, + database=test_config.test_database, + autogenerate_session_id=False, + ) as client: + queries = [client.query(f"SELECT {i}, sleep(0.1)") for i in range(10)] + + start = time.time() + results = await asyncio.gather(*queries) + elapsed = time.time() - start + + assert elapsed < 0.5, f"Took {elapsed}s, queries appear to run sequentially" + assert len(results) == 10 + + for i, result in enumerate(results): + assert result.row_count == 1 + first_row = result.result_rows[0] + assert first_row[0] == i + + 
+@pytest.mark.asyncio +async def test_stream_cancellation(test_config): + """Test that early exit from async iteration doesn't leak resources.""" + async with await get_async_client( + host=test_config.host, + port=test_config.port, + username=test_config.username, + password=test_config.password, + database=test_config.test_database, + ) as client: + stream = await client.query_rows_stream("SELECT number FROM numbers(100000)", settings={"max_block_size": 1000}) + + count = 0 + async with stream: + async for _ in stream: + count += 1 + if count >= 10: + break + + assert count == 10 + + result = await client.query("SELECT 1") + assert result.result_rows[0][0] == 1 + + +@pytest.mark.asyncio +async def test_concurrent_streams(test_config): + """Verify multiple streams can run in parallel.""" + async with await get_async_client( + host=test_config.host, + port=test_config.port, + username=test_config.username, + password=test_config.password, + database=test_config.test_database, + autogenerate_session_id=False, + ) as client: + + async def consume_stream(stream_id: int): + stream = await client.query_rows_stream( + f"SELECT number FROM numbers(1000) WHERE number % 3 = {stream_id}", settings={"max_block_size": 100} + ) + total = 0 + async with stream: + async for row in stream: + total += row[0] + return total + + start = time.time() + results = await asyncio.gather(consume_stream(0), consume_stream(1), consume_stream(2)) + elapsed = time.time() - start + + assert len(results) == 3 + assert all(r > 0 for r in results) + assert elapsed < 5.0 + + +@pytest.mark.asyncio +async def test_context_manager_cleanup(test_config): + """Test proper resource cleanup on context manager exit.""" + client = await get_async_client( + host=test_config.host, + port=test_config.port, + username=test_config.username, + password=test_config.password, + database=test_config.test_database, + ) + + assert client._initialized is True + assert client._session is not None + + async with client: + 
result = await client.query("SELECT 1") + assert result.result_rows[0][0] == 1 + + assert client._session is None or client._session.closed + + with pytest.raises((RuntimeError, OperationalError)): + await client.query("SELECT 1") + + +@pytest.mark.asyncio +async def test_session_concurrency_protection(test_config): + """Test that concurrent queries in the same session are blocked.""" + async with await get_async_client( + host=test_config.host, + port=test_config.port, + username=test_config.username, + password=test_config.password, + database=test_config.test_database, + session_id="test_concurrent_session", + ) as client: + + async def long_query(): + return await client.query("SELECT sleep(0.5), 1") + + async def quick_query(): + await asyncio.sleep(0.1) + return await client.query("SELECT 1") + + with pytest.raises(ProgrammingError) as exc_info: + await asyncio.gather(long_query(), quick_query()) + + assert "concurrent" in str(exc_info.value).lower() or "session" in str(exc_info.value).lower() + + +@pytest.mark.asyncio +async def test_timeout_handling(test_config): + """Test that async timeout exceptions propagate correctly.""" + async with await get_async_client( + host=test_config.host, + port=test_config.port, + username=test_config.username, + password=test_config.password, + database=test_config.test_database, + send_receive_timeout=1, # 1 second timeout + autogenerate_session_id=False, # No session to avoid session locking after timeout + ) as client: + # This query should timeout (sleep 2 seconds with 1 second timeout) + with pytest.raises((asyncio.TimeoutError, OperationalError)): + await client.query("SELECT sleep(2)") + + # Client should remain functional after timeout + result = await client.query("SELECT 1") + assert result.result_rows[0][0] == 1 + + +@pytest.mark.asyncio +async def test_connection_pool_reuse(test_config): + """Verify connection pooling works correctly under load.""" + async with await get_async_client( + host=test_config.host, + 
port=test_config.port, + username=test_config.username, + password=test_config.password, + database=test_config.test_database, + connector_limit=10, # Limit pool size + connector_limit_per_host=5, + autogenerate_session_id=False, + ) as client: + # Run more queries in parallel than pool size + queries = [client.query(f"SELECT {i}") for i in range(50)] + + start = time.time() + results = await asyncio.gather(*queries) + elapsed = time.time() - start + + assert len(results) == 50 + for i, result in enumerate(results): + assert result.result_rows[0][0] == i + + assert elapsed < 10.0 + + +@pytest.mark.asyncio +async def test_concurrent_inserts(test_config, table_context: Callable): + """Test multiple inserts can run in parallel.""" + with table_context("test_concurrent_inserts", ["id UInt32", "value String"]) as ctx: + async with await get_async_client( + host=test_config.host, + port=test_config.port, + username=test_config.username, + password=test_config.password, + database=test_config.test_database, + autogenerate_session_id=False, + ) as client: + + async def insert_batch(start_id: int, count: int): + data = [[start_id + i, f"value_{start_id + i}"] for i in range(count)] + await client.insert(ctx.table, data) + + await asyncio.gather( + insert_batch(0, 10), + insert_batch(100, 10), + insert_batch(200, 10), + insert_batch(300, 10), + insert_batch(400, 10), + ) + + result = await client.query(f"SELECT count() FROM {ctx.table}") + assert result.result_rows[0][0] == 50 + + +@pytest.mark.asyncio +async def test_error_isolation(test_config): + """Test that one failing query doesn't break other concurrent queries.""" + async with await get_async_client( + host=test_config.host, + port=test_config.port, + username=test_config.username, + password=test_config.password, + database=test_config.test_database, + autogenerate_session_id=False, + ) as client: + + async def good_query(n: int): + return await client.query(f"SELECT {n}") + + async def bad_query(): + return await 
client.query("SELECT invalid_syntax_here!!!") + + results = await asyncio.gather(good_query(1), bad_query(), good_query(2), bad_query(), good_query(3), return_exceptions=True) + + assert results[0].result_rows[0][0] == 1 + assert results[2].result_rows[0][0] == 2 + assert results[4].result_rows[0][0] == 3 + + assert isinstance(results[1], Exception) + assert isinstance(results[3], Exception) + + +@pytest.mark.asyncio +async def test_streaming_early_termination(test_config): + """Verify streaming can be terminated early without issues.""" + async with await get_async_client( + host=test_config.host, + port=test_config.port, + username=test_config.username, + password=test_config.password, + database=test_config.test_database, + autogenerate_session_id=False, # Don't use session to avoid locking + ) as client: + stream = await client.query_rows_stream("SELECT number, repeat('x', 10000) FROM numbers(100000)", settings={"max_block_size": 1000}) + + count = 0 + async with stream: + async for row in stream: + count += 1 + if count >= 1000: + break # Early termination + + assert count == 1000 + + # Client should still be functional after early termination + result = await client.query("SELECT 1") + assert result.result_rows[0][0] == 1 + + stream2 = await client.query_rows_stream("SELECT number FROM numbers(100)", settings={"max_block_size": 10}) + + count2 = 0 + async with stream2: + async for row in stream2: + count2 += 1 + + assert count2 == 100 + + +@pytest.mark.asyncio +async def test_regular_query_streams_then_materializes(test_config): + """Verify regular query() uses streaming internally but materializes result.""" + async with await get_async_client( + host=test_config.host, + port=test_config.port, + username=test_config.username, + password=test_config.password, + database=test_config.test_database, + ) as client: + result = await client.query("SELECT number FROM numbers(10000)") + + assert len(result.result_rows) == 10000 + assert result.result_rows[0][0] == 0 
+ assert result.result_rows[-1][0] == 9999 + + expected_numbers = list(range(10000)) + actual_numbers = [row[0] for row in result.result_rows] + assert actual_numbers == expected_numbers diff --git a/tests/integration_tests/test_client.py b/tests/integration_tests/test_client.py index eb24a6b5..e7ae4d7d 100644 --- a/tests/integration_tests/test_client.py +++ b/tests/integration_tests/test_client.py @@ -29,49 +29,50 @@ def _is_valid_uuid_v4(id_string: str) -> bool: return False -def test_ping(test_client: Client): - assert test_client.ping() is True +def test_ping(param_client, call): + assert call(param_client.ping) is True -def test_query(test_client: Client): - result = test_client.query('SELECT * FROM system.tables') +def test_query(param_client, call): + result = call(param_client.query, 'SELECT * FROM system.tables') assert len(result.result_set) > 0 assert result.row_count > 0 assert result.first_item == next(result.named_results()) -def test_command(test_client: Client): - version = test_client.command('SELECT version()') +def test_command(param_client, call): + version = call(param_client.command, 'SELECT version()') assert int(version.split('.')[0]) >= 19 -def test_client_name(test_client: Client): - user_agent = test_client.headers['User-Agent'] - assert 'test' in user_agent +def test_client_name(param_client, client_mode): + user_agent = param_client.headers['User-Agent'] + assert 'test' in user_agent or 'param' in user_agent assert 'py/' in user_agent + assert f"mode:{client_mode}" in user_agent -def test_transport_settings(test_client: Client): - result = test_client.query('SELECT name,database FROM system.tables', +def test_transport_settings(param_client, call): + result = call(param_client.query, 'SELECT name,database FROM system.tables', transport_settings={'X-Workload': 'ONLINE'}) assert result.column_names == ('name', 'database') assert len(result.result_set) > 0 -def test_none_database(test_client: Client): - old_db = test_client.database - 
test_db = test_client.command('select currentDatabase()') +def test_none_database(param_client, call): + old_db = param_client.database + test_db = call(param_client.command, 'select currentDatabase()') assert test_db == old_db try: - test_client.database = None - test_client.query('SELECT * FROM system.tables') - test_db = test_client.command('select currentDatabase()') + param_client.database = None + call(param_client.query, 'SELECT * FROM system.tables') + test_db = call(param_client.command, 'select currentDatabase()') assert test_db == 'default' - test_client.database = old_db - test_db = test_client.command('select currentDatabase()') + param_client.database = old_db + test_db = call(param_client.command, 'select currentDatabase()') assert test_db == old_db finally: - test_client.database = old_db + param_client.database = old_db def test_session_params(test_config: TestConfig): @@ -89,15 +90,40 @@ def test_session_params(test_config: TestConfig): if client.min_version('21'): if test_config.host != 'localhost': return # By default, the session log isn't enabled, so we only validate in environments we control - sleep(10) # Allow the log entries to flush to tables - result = client.query( - f"SELECT session_id, user FROM system.session_log WHERE session_id = '{session_id}' AND " + - 'event_time > now() - 30').result_set - assert result[0] == (session_id, test_config.username) - result = client.query( - "SELECT query_id, user FROM system.query_log WHERE query_id = 'test_session_params' AND " + - 'event_time > now() - 30').result_set - assert result[0] == ('test_session_params', test_config.username) + + def check_session_in_log(): + max_retries = 100 + for _ in range(max_retries): + result = client.query( + f"SELECT session_id, user FROM system.session_log WHERE session_id = '{session_id}' AND " + + 'event_time > now() - 30').result_set + + if len(result) > 0: + assert result[0] == (session_id, test_config.username) + return + + sleep(0.1) + + 
pytest.fail(f"session_id '{session_id}' did not appear in system.session_log after {max_retries * 0.1}s") + + def check_query_in_log(): + max_retries = 100 + for _ in range(max_retries): + result = client.query( + "SELECT query_id, user FROM system.query_log WHERE query_id = 'test_session_params' AND " + + 'event_time > now() - 30').result_set + + if len(result) > 0: + assert result[0] == ('test_session_params', test_config.username) + return + + sleep(0.1) + + pytest.fail(f"query_id 'test_session_params' did not appear in system.query_log after {max_retries * 0.1}s") + + # Check both logs with smart retry logic + check_session_in_log() + check_query_in_log() def test_dsn_config(test_config: TestConfig): @@ -105,31 +131,33 @@ def test_dsn_config(test_config: TestConfig): dsn = (f'clickhousedb://{test_config.username}:{test_config.password}@{test_config.host}:{test_config.port}' + f'/{test_config.test_database}?session_id={session_id}&show_clickhouse_errors=false') client = create_client(dsn=dsn) - assert client.get_client_setting('session_id') == session_id - count = client.command('SELECT count() from system.tables') - assert client.database == test_config.test_database - assert count > 0 try: - client.query('SELECT nothing') - except DatabaseError as ex: - assert 'returned an error' in str(ex) - client.close() + assert client.get_client_setting('session_id') == session_id + count = client.command('SELECT count() from system.tables') + assert client.database == test_config.test_database + assert count > 0 + try: + client.query('SELECT nothing') + except DatabaseError as ex: + assert 'returned an error' in str(ex) + finally: + client.close() -def test_no_columns_and_types_when_no_results(test_client: Client): +def test_no_columns_and_types_when_no_results(param_client, call): """ In case of no results, the column names and types are not returned when FORMAT Native is set. This may cause a lot of confusion. 
Read more: https://github.com/ClickHouse/clickhouse-connect/issues/257 """ - result = test_client.query('SELECT name, database, NOW() as dt FROM system.tables WHERE FALSE') + result = call(param_client.query, 'SELECT name, database, NOW() as dt FROM system.tables WHERE FALSE') assert result.column_names == () assert result.column_types == () assert result.result_set == [] -def test_get_columns_only(test_client: Client): - result = test_client.query('SELECT name, database, NOW() as dt FROM system.tables LIMIT 0') +def test_get_columns_only(param_client, call): + result = call(param_client.query, 'SELECT name, database, NOW() as dt FROM system.tables LIMIT 0') assert result.column_names == ('name', 'database', 'dt') assert len(result.column_types) == 3 assert isinstance(result.column_types[0], datatypes.string.String) @@ -137,28 +165,28 @@ def test_get_columns_only(test_client: Client): assert isinstance(result.column_types[2], datatypes.temporal.DateTime) assert len(result.result_set) == 0 - test_client.query('CREATE TABLE IF NOT EXISTS test_zero_insert (v Int8) ENGINE MergeTree() ORDER BY tuple()') - test_client.query('INSERT INTO test_zero_insert SELECT 1 LIMIT 0') + call(param_client.query, 'CREATE TABLE IF NOT EXISTS test_zero_insert (v Int8) ENGINE MergeTree() ORDER BY tuple()') + call(param_client.query, 'INSERT INTO test_zero_insert SELECT 1 LIMIT 0') -def test_no_limit(test_client: Client): - old_limit = test_client.query_limit - test_client.limit = 0 - result = test_client.query('SELECT name FROM system.databases') +def test_no_limit(param_client, call): + old_limit = param_client.query_limit + param_client.limit = 0 + result = call(param_client.query, 'SELECT name FROM system.databases') assert len(result.result_set) > 0 - test_client.limit = old_limit + param_client.limit = old_limit -def test_multiline_query(test_client: Client): - result = test_client.query(""" +def test_multiline_query(param_client, call): + result = call(param_client.query, """ SELECT 
* FROM system.tables """) assert len(result.result_set) > 0 -def test_query_with_inline_comment(test_client: Client): - result = test_client.query(""" +def test_query_with_inline_comment(param_client, call): + result = call(param_client.query, """ SELECT * -- This is just a comment FROM system.tables LIMIT 77 @@ -167,8 +195,8 @@ def test_query_with_inline_comment(test_client: Client): assert len(result.result_set) > 0 -def test_query_with_comment(test_client: Client): - result = test_client.query(""" +def test_query_with_comment(param_client, call): + result = call(param_client.query, """ SELECT * /* This is: a multiline comment */ @@ -177,14 +205,14 @@ def test_query_with_comment(test_client: Client): assert len(result.result_set) > 0 -def test_insert_csv_format(test_client: Client, test_table_engine: str): - test_client.command('DROP TABLE IF EXISTS test_csv') - test_client.command( +def test_insert_csv_format(param_client, call, test_table_engine: str): + call(param_client.command, 'DROP TABLE IF EXISTS test_csv') + call(param_client.command, 'CREATE TABLE test_csv ("key" String, "val1" Int32, "val2" Int32) ' + f'ENGINE {test_table_engine} ORDER BY tuple()') sql = f'INSERT INTO test_csv ("key", "val1", "val2") FORMAT CSV {CSV_CONTENT}' - test_client.command(sql) - result = test_client.query('SELECT * from test_csv') + call(param_client.command, sql) + result = call(param_client.query, 'SELECT * from test_csv') def compare_rows(row_1, row_2): return all(c1 == c2 for c1, c2 in zip(row_1, row_2)) @@ -194,35 +222,35 @@ def compare_rows(row_1, row_2): assert compare_rows(result.result_set[4], ['hij', 1, 0]) -def test_non_latin_query(test_client: Client): - result = test_client.query("SELECT database, name FROM system.tables WHERE engine_full IN ('空')") +def test_non_latin_query(param_client, call): + result = call(param_client.query, "SELECT database, name FROM system.tables WHERE engine_full IN ('空')") assert len(result.result_set) == 0 -def 
test_error_decode(test_client: Client): +def test_error_decode(param_client, call): try: - test_client.query("SELECT database, name FROM system.tables WHERE has_own_data = '空'") + call(param_client.query, "SELECT database, name FROM system.tables WHERE has_own_data = '空'") except DatabaseError as ex: assert '空' in str(ex) -def test_command_as_query(test_client: Client): +def test_command_as_query(param_client, call): # Test that non-SELECT and non-INSERT statements are treated as commands and # just return the QueryResult metadata - result = test_client.query("SET count_distinct_implementation = 'uniq'") + result = call(param_client.query, "SET count_distinct_implementation = 'uniq'") assert 'query_id' in result.first_item -def test_show_create(test_client: Client): - if not test_client.min_version('21'): - pytest.skip(f'Not supported server version {test_client.server_version}') - result = test_client.query('SHOW CREATE TABLE system.tables') +def test_show_create(param_client, call): + if not param_client.min_version('21'): + pytest.skip(f'Not supported server version {param_client.server_version}') + result = call(param_client.query, 'SHOW CREATE TABLE system.tables') result.close() assert 'statement' in result.column_names -def test_empty_result(test_client: Client): - assert len(test_client.query("SELECT * FROM system.tables WHERE name = '_NOT_A THING'").result_rows) == 0 +def test_empty_result(param_client, call): + assert len(call(param_client.query, "SELECT * FROM system.tables WHERE name = '_NOT_A THING'").result_rows) == 0 def test_temporary_tables(test_client: Client, test_config: TestConfig): @@ -250,29 +278,29 @@ def test_temporary_tables(test_client: Client, test_config: TestConfig): test_client.command('DROP TABLE IF EXISTS temp_test_table', settings=session_settings) -def test_str_as_bytes(test_client: Client, table_context: Callable): +def test_str_as_bytes(param_client, call, table_context: Callable): with table_context('test_insert_bytes', ['key 
UInt32', 'byte_str String', 'n_byte_str Nullable(String)']): - test_client.insert('test_insert_bytes', [[0, 'str_0', 'n_str_0'], [1, 'str_1', 'n_str_0']]) - test_client.insert('test_insert_bytes', [[2, 'str_2'.encode('ascii'), 'n_str_2'.encode()], + call(param_client.insert, 'test_insert_bytes', [[0, 'str_0', 'n_str_0'], [1, 'str_1', 'n_str_0']]) + call(param_client.insert, 'test_insert_bytes', [[2, 'str_2'.encode('ascii'), 'n_str_2'.encode()], [3, b'str_3', b'str_3'], [4, bytearray([5, 120, 24]), bytes([16, 48, 52])], [5, b'', None] ]) - result_set = test_client.query('SELECT * FROM test_insert_bytes ORDER BY key').result_columns + result_set = call(param_client.query, 'SELECT * FROM test_insert_bytes ORDER BY key').result_columns assert result_set[1][0] == 'str_0' assert result_set[1][3] == 'str_3' assert result_set[2][5] is None assert result_set[1][4].encode() == b'\x05\x78\x18' - result_set = test_client.query('SELECT * FROM test_insert_bytes ORDER BY key', + result_set = call(param_client.query, 'SELECT * FROM test_insert_bytes ORDER BY key', query_formats={'String': 'bytes'}).result_columns assert result_set[1][0] == b'str_0' assert result_set[1][4] == b'\x05\x78\x18' assert result_set[2][4] == b'\x10\x30\x34' -def test_embedded_binary(test_client: Client): +def test_embedded_binary(param_client, call): binary_params = {'$xx$': 'col1,col2\n100,700'.encode()} - result = test_client.raw_query( + result = call(param_client.raw_query, 'SELECT col2, col1 FROM format(CSVWithNames, $xx$)', parameters=binary_params) assert result == b'700\t100\n' @@ -280,83 +308,45 @@ def test_embedded_binary(test_client: Client): with open(movies_file, 'rb') as f: # read bytes data = f.read() binary_params = {'$parquet$': data} - result = test_client.query( + result = call(param_client.query, 'SELECT movie, rating FROM format(Parquet, $parquet$) ORDER BY movie', parameters=binary_params) assert result.first_item['movie'] == '12 Angry Men' binary_params = {'$mult$': 
'foobar'.encode()} - result = test_client.query("SELECT $mult$ as m1, $mult$ as m2 WHERE m1 = 'foobar'", parameters=binary_params) + result = call(param_client.query, "SELECT $mult$ as m1, $mult$ as m2 WHERE m1 = 'foobar'", parameters=binary_params) assert result.first_item['m2'] == 'foobar' -def test_column_rename_setting_none(test_config: TestConfig): +def test_column_rename_setting_none(client_factory, call): sql = "SELECT 1 as `a.b.c d_e`" - session_id = "TEST_SESSION_ID_" + test_config.test_database - client = create_client( - session_id=session_id, - host=test_config.host, - port=test_config.port, - username=test_config.username, - password=test_config.password, - ) - names = client.query( - sql, - ).column_names - client.close() + client = client_factory() + names = call(client.query, sql).column_names assert names[0] == "a.b.c d_e" -def test_column_rename_limit_0_path(test_config: TestConfig): +def test_column_rename_limit_0_path(client_factory, call): + """Test column renaming with LIMIT 0 query (no data returned).""" sql = "SELECT 1 as `a.b.c d_e` LIMIT 0" - session_id = "TEST_SESSION_ID_" + test_config.test_database - client = create_client( - session_id=session_id, - host=test_config.host, - port=test_config.port, - username=test_config.username, - password=test_config.password, - rename_response_column="to_camelcase_without_prefix", - ) - names = client.query( - sql, - ).column_names - client.close() + client = client_factory(rename_response_column="to_camelcase_without_prefix") + names = call(client.query, sql).column_names assert names[0] == "cDE" -def test_column_rename_data_path(test_config: TestConfig): +def test_column_rename_data_path(client_factory, call): + """Test column renaming with data returned.""" sql = "SELECT 1 as `a.b.c d_e`" - session_id = "TEST_SESSION_ID_" + test_config.test_database - client = create_client( - session_id=session_id, - host=test_config.host, - port=test_config.port, - username=test_config.username, - 
password=test_config.password, - rename_response_column="to_camelcase_without_prefix", - ) - names = client.query( - sql, - ).column_names - client.close() + client = client_factory(rename_response_column="to_camelcase_without_prefix") + names = call(client.query, sql).column_names assert names[0] == "cDE" -def test_column_rename_with_bad_option(test_config: TestConfig): - session_id = "TEST_SESSION_ID_" + test_config.test_database - +def test_column_rename_with_bad_option(client_factory): + """Test that invalid rename option raises ValueError.""" with pytest.raises(ValueError, match="Invalid option"): - create_client( - session_id=session_id, - host=test_config.host, - port=test_config.port, - username=test_config.username, - password=test_config.password, - rename_response_column="not_an_option", - ) + client_factory(rename_response_column="not_an_option") -def test_role_setting_works(test_client: Client, test_config: TestConfig): +def test_role_setting_works(param_client: Client, test_config: TestConfig, client_factory: Callable, call): if test_config.cloud: pytest.skip("Skipping role test in cloud mode - cannot create custom users") @@ -364,82 +354,73 @@ def test_role_setting_works(test_client: Client, test_config: TestConfig): user_limited = 'limit_rows_user' user_password = 'R7m!pZt9qL#x' - test_client.command(f'CREATE ROLE IF NOT EXISTS {role_limited}') - test_client.command(f'CREATE USER IF NOT EXISTS {user_limited} IDENTIFIED BY \'{user_password}\'') - test_client.command(f'GRANT SELECT ON system.numbers TO {user_limited}') - test_client.command(f'GRANT {role_limited} TO {user_limited}') - test_client.command(f'SET DEFAULT ROLE NONE TO {user_limited}') + call(param_client.command, f'CREATE ROLE IF NOT EXISTS {role_limited}') + call(param_client.command, f'CREATE USER IF NOT EXISTS {user_limited} IDENTIFIED BY \'{user_password}\'') + call(param_client.command, f'GRANT SELECT ON system.numbers TO {user_limited}') + call(param_client.command, f'GRANT 
{role_limited} TO {user_limited}') + call(param_client.command, f'SET DEFAULT ROLE NONE TO {user_limited}') - client = create_client( - host=test_client.server_host_name, - port=test_client.url.rsplit(':', 1)[-1].split('/')[0], + client = client_factory( + host=test_config.host, + port=test_config.port, username=user_limited, password=user_password, ) # the default should not have the role - res = client.query('SELECT currentRoles()') + res = call(client.query, 'SELECT currentRoles()') assert res.result_rows == [([],)] # passing it as a per-query setting should work - res = client.query('SELECT currentRoles()', settings={'role': role_limited}) + res = call(client.query, 'SELECT currentRoles()', settings={'role': role_limited}) assert res.result_rows == [([role_limited],)] # passing it as a per-client setting should work - role_client = create_client( - host=test_client.server_host_name, - port=test_client.url.rsplit(':', 1)[-1].split('/')[0], + role_client = client_factory( + host=test_config.host, + port=test_config.port, username=user_limited, password=user_password, settings={'role': role_limited}, ) - res = role_client.query('SELECT currentRoles()') + res = call(role_client.query, 'SELECT currentRoles()') assert res.result_rows == [([role_limited],)] -def test_query_id_autogeneration(test_client: Client, test_table_engine: str): +def test_query_id_autogeneration(param_client: Client, test_table_engine: str, call): """Test that query_id is auto-generated for query(), command(), and insert() methods""" - result = test_client.query("SELECT 1") + result = call(param_client.query, "SELECT 1") assert _is_valid_uuid_v4(result.query_id) - summary = test_client.command("DROP TABLE IF EXISTS test_query_id_nonexistent") + summary = call(param_client.command, "DROP TABLE IF EXISTS test_query_id_nonexistent") assert _is_valid_uuid_v4(summary.query_id()) - test_client.command("DROP TABLE IF EXISTS test_query_id_insert") - test_client.command(f"CREATE TABLE 
test_query_id_insert (id UInt32) ENGINE {test_table_engine} ORDER BY id") - summary = test_client.insert("test_query_id_insert", [[1], [2], [3]], column_names=["id"]) + call(param_client.command, "DROP TABLE IF EXISTS test_query_id_insert") + call(param_client.command, f"CREATE TABLE test_query_id_insert (id UInt32) ENGINE {test_table_engine} ORDER BY id") + summary = call(param_client.insert, "test_query_id_insert", [[1], [2], [3]], column_names=["id"]) assert _is_valid_uuid_v4(summary.query_id()) - test_client.command("DROP TABLE test_query_id_insert") + call(param_client.command, "DROP TABLE test_query_id_insert") -def test_query_id_manual_override(test_client: Client): +def test_query_id_manual_override(param_client: Client, call): """Test that manually specified query_id is respected and not overwritten""" manual_query_id = "test_manual_query_id_override" - result = test_client.query("SELECT 1", settings={"query_id": manual_query_id}) + result = call(param_client.query, "SELECT 1", settings={"query_id": manual_query_id}) assert result.query_id == manual_query_id # pylint: disable=protected-access -def test_query_id_disabled(test_config: TestConfig): +def test_query_id_disabled(client_factory, call): """Test that autogenerate_query_id=False works correctly""" - client_no_autogen = create_client( - host=test_config.host, - port=test_config.port, - username=test_config.username, - password=test_config.password, - autogenerate_query_id=False, - ) - + client_no_autogen = client_factory(autogenerate_query_id=False) assert client_no_autogen._autogenerate_query_id is False # Even with autogen disabled, server generates a query_id - result = client_no_autogen.query("SELECT 1") + result = call(client_no_autogen.query, "SELECT 1") assert _is_valid_uuid_v4(result.query_id) - client_no_autogen.close() - -def test_query_id_in_query_logs(test_client: Client, test_config: TestConfig): +def test_query_id_in_query_logs(param_client: Client, test_config: TestConfig, call): 
"""Test that query_id appears in ClickHouse's system.query_log for observability""" if test_config.cloud: pytest.skip("Skipping query_log test in cloud environment") @@ -447,7 +428,8 @@ def test_query_id_in_query_logs(test_client: Client, test_config: TestConfig): def check_in_logs(test_query_id): max_retries = 30 for _ in range(max_retries): - log_result = test_client.query( + log_result = call( + param_client.query, "SELECT query_id FROM system.query_log WHERE query_id = {query_id:String} AND event_time > now() - 30 LIMIT 1", parameters={"query_id": test_query_id} ) @@ -462,11 +444,56 @@ def check_in_logs(test_query_id): pytest.fail(f"query_id '{test_query_id}' did not appear in system.query_log after {max_retries * 0.1}s") # Manual override check - test_query_id_manual = "test_query_id_in_logs" - test_client.query("SELECT 1 as num", settings={"query_id": test_query_id_manual}) + test_query_id_manual = f"test_query_id_in_logs_{uuid.uuid4()}" + call(param_client.query, "SELECT 1 as num", settings={"query_id": test_query_id_manual}) check_in_logs(test_query_id_manual) # Autogen check - result = test_client.query("SELECT 2 as num") + result = call(param_client.query, "SELECT 2 as num") test_query_id_auto = result.query_id check_in_logs(test_query_id_auto) + + +def test_compression_enabled(client_factory, call, table_context): + """Test that compression works when enabled.""" + client = client_factory(compress=True) + + assert client.compression is not None + assert client.write_compression is not None + + with table_context("test_compression", ["id", "data"], ["UInt32", "String"]): + data = [[i, f"data_{i}"] for i in range(100)] + call(client.insert, "test_compression", data) + + result = call(client.query, "SELECT COUNT(*) FROM test_compression") + assert result.result_rows[0][0] == 100 + + +def test_compression_disabled(client_factory, call, table_context): + """Test that compression can be explicitly disabled.""" + client = client_factory(compress=False) + + 
assert client.compression is None + assert client.write_compression is None + + with table_context("test_no_compression", ["id", "data"], ["UInt32", "String"]): + data = [[i, f"data_{i}"] for i in range(100)] + call(client.insert, "test_no_compression", data) + + result = call(client.query, "SELECT COUNT(*) FROM test_no_compression") + assert result.result_rows[0][0] == 100 + + +def test_compression_gzip(client_factory, call, table_context): + """Test that gzip compression works.""" + client = client_factory(compress="gzip") + + assert client.compression == "gzip" + assert client.write_compression == "gzip" + + with table_context("test_gzip", ["id", "data"], ["UInt32", "String"]): + data = [[i, f"data_{i}" * 10] for i in range(50)] + call(client.insert, "test_gzip", data) + + result = call(client.query, "SELECT COUNT(*) FROM test_gzip") + assert result.result_rows[0][0] == 50 diff --git a/tests/integration_tests/test_contexts.py b/tests/integration_tests/test_contexts.py index cf7048ca..27600860 100644 --- a/tests/integration_tests/test_contexts.py +++ b/tests/integration_tests/test_contexts.py @@ -3,27 +3,27 @@ from clickhouse_connect.driver import Client -def test_contexts(test_client: Client, table_context: Callable): +def test_contexts(param_client: Client, call, table_context: Callable): with table_context('test_contexts', ['key Int32', 'value1 String', 'value2 String']) as ctx: data = [[1, 'v1', 'v2'], [2, 'v3', 'v4']] - insert_context = test_client.create_insert_context(table=ctx.table, data=data) - test_client.insert(context=insert_context) - query_context = test_client.create_query_context( + insert_context = call(param_client.create_insert_context, table=ctx.table, data=data) + call(param_client.insert, context=insert_context) + query_context = param_client.create_query_context( query=f'SELECT value1, value2 FROM {ctx.table} WHERE key = {{k:Int32}}', parameters={'k': 2}, column_oriented=True) - result = test_client.query(context=query_context) + result = 
call(param_client.query, context=query_context) assert result.result_set[1][0] == 'v4' query_context.set_parameter('k', 1) - result = test_client.query(context=query_context) + result = call(param_client.query, context=query_context) assert result.row_count == 1 assert result.result_set[1][0] data = [[1, 'v5', 'v6'], [2, 'v7', 'v8']] - test_client.insert(data=data, context=insert_context) - result = test_client.query(context=query_context) + call(param_client.insert, data=data, context=insert_context) + result = call(param_client.query, context=query_context) assert result.row_count == 2 insert_context.data = [[5, 'v5', 'v6'], [7, 'v7', 'v8']] - test_client.insert(context=insert_context) - assert test_client.command(f'SELECT count() FROM {ctx.table}') == 6 + call(param_client.insert, context=insert_context) + assert call(param_client.command, f'SELECT count() FROM {ctx.table}') == 6 diff --git a/tests/integration_tests/test_dynamic.py b/tests/integration_tests/test_dynamic.py index 513b5a14..b0f1e147 100644 --- a/tests/integration_tests/test_dynamic.py +++ b/tests/integration_tests/test_dynamic.py @@ -11,19 +11,19 @@ from tests.integration_tests.conftest import TestConfig -def type_available(test_client: Client, data_type: str): - if test_client.get_client_setting(f'allow_experimental_{data_type}_type') is None: +def type_available(param_client: Client, data_type: str): + if param_client.get_client_setting(f'allow_experimental_{data_type}_type') is None: return - setting_def = test_client.server_settings.get(f'allow_experimental_{data_type}_value', None) + setting_def = param_client.server_settings.get(f'allow_experimental_{data_type}_value', None) if setting_def is not None and setting_def.value == '1': return - pytest.skip(f'New {data_type.upper()} type not available in this version: {test_client.server_version}') + pytest.skip(f'New {data_type.upper()} type not available in this version: {param_client.server_version}') -def test_variant(test_client: Client, 
table_context: Callable): +def test_variant(param_client: Client, call, table_context: Callable): pytest.skip('Variant string inserts broken') - type_available(test_client, 'variant') + type_available(param_client, 'variant') with table_context('basic_variants', [ 'key Int32', 'v1 Variant(UInt64, String, Array(UInt64), UUID)', @@ -33,17 +33,17 @@ def test_variant(test_client: Client, table_context: Callable): [3, 'bef56f14-0870-4f82-a35e-9a47eff45a5b', 777.25], [4, [120, 250], 88.2] ] - test_client.insert('basic_variants', data) - result = test_client.query('SELECT * FROM basic_variants ORDER BY key').result_set + call(param_client.insert, 'basic_variants', data) + result = call(param_client.query, 'SELECT * FROM basic_variants ORDER BY key').result_set assert result[2][1] == UUID('bef56f14-0870-4f82-a35e-9a47eff45a5b') assert result[2][2] == 777.25 assert result[3][1] == [120, 250] assert result[3][2] == IPv4Address('243.12.55.44') -def test_nested_variant(test_client: Client, table_context: Callable): +def test_nested_variant(param_client: Client, call, table_context: Callable): pytest.skip('Variant string inserts broken') - type_available(test_client, 'variant') + type_available(param_client, 'variant') with table_context('nested_variants', [ 'key Int32', 'm1 Map(String, Variant(String, UInt128, Bool))', @@ -61,8 +61,8 @@ def test_nested_variant(test_client: Client, table_context: Callable): (), ] ] - test_client.insert('nested_variants', data) - result = test_client.query('SELECT * FROM nested_variants ORDER BY key').result_set + call(param_client.insert, 'nested_variants', data) + result = call(param_client.query, 'SELECT * FROM nested_variants ORDER BY key').result_set assert result[0][1]['k1'] == 'string1' assert result[0][1]['k2'] == 34782477743 assert result[0][2] == (-40, True) @@ -70,19 +70,19 @@ def test_nested_variant(test_client: Client, table_context: Callable): assert result[1][1]['k3'] == 100 -def test_dynamic_nested(test_client: Client, 
table_context: Callable): - type_available(test_client, 'dynamic') +def test_dynamic_nested(param_client: Client, call, table_context: Callable): + type_available(param_client, 'dynamic') with table_context('nested_dynamics', [ 'm2 Map(String, Dynamic)' ], order_by='()'): data = [({'k4': 'string8', 'k5': 5000},)] - test_client.insert('nested_dynamics', data) - result = test_client.query('SELECT * FROM nested_dynamics').result_set + call(param_client.insert, 'nested_dynamics', data) + result = call(param_client.query, 'SELECT * FROM nested_dynamics').result_set assert result[0][0]['k5'] == '5000' -def test_dynamic(test_client: Client, table_context: Callable): - type_available(test_client, 'dynamic') +def test_dynamic(param_client: Client, call, table_context: Callable): + type_available(param_client, 'dynamic') with table_context('basic_dynamic', [ 'key UInt64', 'v1 Dynamic', @@ -92,15 +92,15 @@ def test_dynamic(test_client: Client, table_context: Callable): [2, 'a string', 55.2], [4, [120, 250], 577.22] ] - test_client.insert('basic_dynamic', data) - result = test_client.query('SELECT * FROM basic_dynamic ORDER BY key').result_set + call(param_client.insert, 'basic_dynamic', data) + result = call(param_client.query, 'SELECT * FROM basic_dynamic ORDER BY key').result_set assert result[2][1] == 'bef56f14-0870-4f82-a35e-9a47eff45a5b' assert result[3][1] == '[120, 250]' assert result[2][2] == '777.25' -def test_basic_json(test_client: Client, table_context: Callable): - type_available(test_client, 'json') +def test_basic_json(param_client: Client, call, table_context: Callable): + type_available(param_client, 'json') with table_context('new_json_basic', [ 'key Int32', 'value JSON', @@ -110,12 +110,12 @@ def test_basic_json(test_client: Client, table_context: Callable): jv1 = {'key1': 337, 'value.2': 'vvvv', 'HKD@spéçiäl': 'Special K', 'blank': 'not_really_blank'} njv2 = {'nk1': -302, 'nk2': {'sub1': 372, 'sub2': 'a string'}} njv3 = {'nk1': 5832.44, 'nk2': {'sub1': 
47788382, 'sub2': 'sub2val', 'sub3': 'sub3str', 'space key': 'spacey'}} - test_client.insert('new_json_basic', [ + call(param_client.insert, 'new_json_basic', [ [5, jv1, None], [20, None, njv2], [25, jv3, njv3]]) - result = test_client.query('SELECT * FROM new_json_basic ORDER BY key').result_set + result = call(param_client.query, 'SELECT * FROM new_json_basic ORDER BY key').result_set json1 = result[0][1] assert json1['HKD@spéçiäl'] == 'Special K' assert 'key3' not in json1 @@ -132,23 +132,25 @@ def test_basic_json(test_client: Client, table_context: Callable): assert null_json3['nk2']['space key'] == 'spacey' set_write_format('JSON', 'string') - test_client.insert('new_json_basic', [[999, '{"key4": 283, "value.2": "str_value"}', '{"nk1":53}']]) - result = test_client.query('SELECT value.key4, null_value.nk1 FROM new_json_basic ORDER BY key').result_set + call(param_client.insert, 'new_json_basic', [[999, '{"key4": 283, "value.2": "str_value"}', '{"nk1":53}']]) + result = call(param_client.query, 'SELECT value.key4, null_value.nk1 FROM new_json_basic ORDER BY key').result_set assert result[3][0] == 283 assert result[3][1] == 53 -def test_json_escaped_dots_roundtrip(test_client: Client, table_context: Callable): - type_available(test_client, "json") - if test_client.server_settings.get("json_type_escape_dots_in_keys") is None: +def test_json_escaped_dots_roundtrip(param_client: Client, call, table_context: Callable): + type_available(param_client, "json") + if param_client.server_settings.get("json_type_escape_dots_in_keys") is None: pytest.skip("json_type_escape_dots_in_keys setting unavailable on this server version") # with escaping enabled dots are preserved in keys - test_client.command("SET json_type_escape_dots_in_keys=1") + if not param_client.get_client_setting("session_id"): + param_client.set_client_setting("session_id", str(UUID(int=0))) + call(param_client.command, "SET json_type_escape_dots_in_keys=1") with table_context("json_dots_escape", ["value 
JSON"], order_by="()"): payload = {"a.b": 123, "c": {"d.e": 456}} - test_client.insert("json_dots_escape", [[payload]]) - result = test_client.query("SELECT value FROM json_dots_escape").result_set + call(param_client.insert, "json_dots_escape", [[payload]]) + result = call(param_client.query, "SELECT value FROM json_dots_escape").result_set returned = result[0][0] assert "a.b" in returned @@ -158,11 +160,11 @@ def test_json_escaped_dots_roundtrip(test_client: Client, table_context: Callabl assert returned["c"]["d.e"] == 456 # with escaping disabled dots create nested structure - test_client.command("SET json_type_escape_dots_in_keys=0") + call(param_client.command, "SET json_type_escape_dots_in_keys=0") with table_context("json_dots_no_escape", ["value JSON"], order_by="()"): payload = {"a.b": 789} - test_client.insert("json_dots_no_escape", [[payload]]) - result = test_client.query("SELECT value FROM json_dots_no_escape").result_set + call(param_client.insert, "json_dots_no_escape", [[payload]]) + result = call(param_client.query, "SELECT value FROM json_dots_no_escape").result_set returned = result[0][0] assert "a" in returned @@ -170,22 +172,22 @@ def test_json_escaped_dots_roundtrip(test_client: Client, table_context: Callabl assert returned["a"]["b"] == 789 -def test_typed_json(test_client: Client, table_context: Callable): - type_available(test_client, 'json') +def test_typed_json(param_client: Client, call, table_context: Callable): + type_available(param_client, 'json') with table_context('new_json_typed', [ 'key Int32', 'value JSON(max_dynamic_paths=150, `a.b` DateTime64(3), SKIP a.c)' ]): v1 = '{"a":{"b":"2020-10-15T10:15:44.877", "c":"skip_me"}}' - test_client.insert('new_json_typed', [[1, v1]]) - result = test_client.query('SELECT * FROM new_json_typed ORDER BY key') + call(param_client.insert, 'new_json_typed', [[1, v1]]) + result = call(param_client.query, 'SELECT * FROM new_json_typed ORDER BY key') json1 = result.result_set[0][1] assert 
json1['a']['b'] == datetime.datetime(2020, 10, 15, 10, 15, 44, 877000) -def test_nullable_json(test_client: Client, table_context: Callable): - if not test_client.min_version('25.2'): - pytest.skip(f'Nullable(JSON) type not available in this version: {test_client.server_version}') +def test_nullable_json(param_client: Client, call, table_context: Callable): + if not param_client.min_version('25.2'): + pytest.skip(f'Nullable(JSON) type not available in this version: {param_client.server_version}') with table_context("nullable_json", [ "key Int32", "value_1 Nullable(JSON)", @@ -194,8 +196,8 @@ def test_nullable_json(test_client: Client, table_context: Callable): ]): v1 = {"item_a": 5, "item_b": 10} - test_client.insert("nullable_json", [[1, v1, json.dumps(v1), None], [2, v1, None, None]]) - result = test_client.query('SELECT * FROM nullable_json ORDER BY key') + call(param_client.insert, "nullable_json", [[1, v1, json.dumps(v1), None], [2, v1, None, None]]) + result = call(param_client.query, 'SELECT * FROM nullable_json ORDER BY key') assert result.result_set[0][1] == v1 assert result.result_set[1][1] == v1 assert result.result_set[0][2] == v1 @@ -204,26 +206,26 @@ def test_nullable_json(test_client: Client, table_context: Callable): assert result.result_set[1][3] is None -def test_complex_json(test_client: Client, table_context: Callable): - type_available(test_client, 'json') - if not test_client.min_version('24.10'): +def test_complex_json(param_client: Client, call, table_context: Callable): + type_available(param_client, 'json') + if not param_client.min_version('24.10'): pytest.skip('Complex JSON broken before 24.10') with table_context('new_json_complex', [ 'key Int32', 'value Tuple(t JSON)' ]): data = [[100, ({'a': 'qwe123', 'b': 'main', 'c': None},)]] - test_client.insert('new_json_complex', data) - result = test_client.query('SELECT * FROM new_json_complex ORDER BY key') + call(param_client.insert, 'new_json_complex', data) + result = 
call(param_client.query, 'SELECT * FROM new_json_complex ORDER BY key') json1 = result.result_set[0][1] assert json1['t']['a'] == 'qwe123' -def test_json_str_time(test_client: Client, test_config: TestConfig): +def test_json_str_time(param_client: Client, call, test_config: TestConfig): - if not test_client.min_version('25.1') or test_config.cloud: + if not param_client.min_version('25.1') or test_config.cloud: pytest.skip('JSON string/numbers bug before 25.1, skipping') - result = test_client.query("SELECT '{\"timerange\": \"2025-01-01T00:00:00+0000\"}'::JSON").result_set + result = call(param_client.query, "SELECT '{\"timerange\": \"2025-01-01T00:00:00+0000\"}'::JSON").result_set assert result[0][0]['timerange'] == datetime.datetime(2025, 1, 1) # The following query is broken -- looks like something to do with Nullable(String) in the Tuple diff --git a/tests/integration_tests/test_error_handling.py b/tests/integration_tests/test_error_handling.py index 262e7fa1..4f004d21 100644 --- a/tests/integration_tests/test_error_handling.py +++ b/tests/integration_tests/test_error_handling.py @@ -1,95 +1,67 @@ import logging import pytest -from clickhouse_connect import create_client from clickhouse_connect.driver.exceptions import DatabaseError, OperationalError from tests.integration_tests.conftest import TestConfig -# pylint: disable=attribute-defined-outside-init - -class TestErrorHandling: - """Tests for error handling in the ClickHouse Connect client""" - - @pytest.fixture(autouse=True) - def setup(self, test_config: TestConfig): - self.config = test_config - - def test_wrong_port_error_message(self): - """ - Test that connecting to the wrong port properly propagates - the error message from ClickHouse. 
- """ - if self.config.cloud: - pytest.skip("Skipping wrong port test in cloud environ.") - wrong_port = 9000 - - with pytest.raises((DatabaseError, OperationalError)) as excinfo: - create_client( - host=self.config.host, - port=wrong_port, - username=self.config.username, - password=self.config.password, - ) +def test_wrong_port_error_message(client_factory, test_config: TestConfig): + """ + Test that connecting to the wrong port properly propagates + the error message from ClickHouse. + """ + if test_config.cloud: + pytest.skip("Skipping wrong port test in cloud environ.") + wrong_port = 9000 + + with pytest.raises((DatabaseError, OperationalError)) as excinfo: + client_factory(port=wrong_port) + + error_message = str(excinfo.value) + assert ( + f"Port {wrong_port} is for clickhouse-client program" in error_message + or "You must use port 8123 for HTTP" in error_message + ) + +def test_connection_refused_error(client_factory, test_config: TestConfig, caplog): + """ + Test that connecting to a port where nothing is listening + produces a clear error message. 
+ """ + if test_config.cloud: + pytest.skip("Skipping connection refused test in cloud environ.") + # Suppress urllib3 and aiohttp connection warnings + urllib3_logger = logging.getLogger("urllib3.connectionpool") + original_urllib3_level = urllib3_logger.level + urllib3_logger.setLevel(logging.CRITICAL) + + # Swallow logging messages to prevent polluting pytest output + caplog.set_level(logging.CRITICAL) + + try: + # Use a port that shouldn't have anything listening + unused_port = 45678 + + # Try connecting to an unused port - should fail with connection refused + with pytest.raises(OperationalError) as excinfo: + client_factory(port=unused_port) error_message = str(excinfo.value) assert ( - f"Port {wrong_port} is for clickhouse-client program" in error_message - or "You must use port 8123 for HTTP" in error_message + "Connection refused" in error_message + or "Failed to establish a new connection" in error_message + or "Cannot connect to host" in error_message ) + finally: + # Restore the original logging level + urllib3_logger.setLevel(original_urllib3_level) - def test_connection_refused_error(self, caplog): - """ - Test that connecting to a port where nothing is listening - produces a clear error message. 
- """ - if self.config.cloud: - pytest.skip("Skipping connection refused test in cloud environ.") - # Suppress urllib3 connection pool warnings - urllib3_logger = logging.getLogger("urllib3.connectionpool") - original_level = urllib3_logger.level - urllib3_logger.setLevel(logging.CRITICAL) - - # Swallow logging messages to prevent polluting pytest output - caplog.set_level(logging.CRITICAL) - - try: - # Use a port that shouldn't have anything listening - unused_port = 45678 - - # Try connecting to an unused port - should fail with connection refused - with pytest.raises(OperationalError) as excinfo: - create_client( - host=self.config.host, - port=unused_port, - username=self.config.username, - password=self.config.password, - ) - - error_message = str(excinfo.value) - assert ( - "Connection refused" in error_message - or "Failed to establish a new connection" in error_message - ) - finally: - # Restore the original logging level - urllib3_logger.setLevel(original_level) - - def test_successful_connection(self): - """ - Verify that connecting to the correct port works properly. - This serves as a sanity check that the test environment is configured correctly. 
- """ - # Connect to the correct HTTP port - client = create_client( - host=self.config.host, - port=self.config.port, # Use the port from test config - username=self.config.username, - password=self.config.password, - ) - # Simple query to verify connection works - result = client.command("SELECT 1") - assert result == 1 +def test_successful_connection(client_factory, call): + """Verify that connecting to the correct port works properly.""" + # Connect to the correct HTTP port (uses defaults from test_config) + client = client_factory() - client.close() + # Simple query to verify connection works + result = call(client.command, "SELECT 1") + assert result == 1 diff --git a/tests/integration_tests/test_external_data.py b/tests/integration_tests/test_external_data.py index 4775458a..cf24396e 100644 --- a/tests/integration_tests/test_external_data.py +++ b/tests/integration_tests/test_external_data.py @@ -2,7 +2,6 @@ import pytest -from clickhouse_connect import get_client from clickhouse_connect.driver import Client from clickhouse_connect.driver.external import ExternalData from clickhouse_connect.driver.options import arrow @@ -11,40 +10,40 @@ ext_settings = {'input_format_allow_errors_num': 10, 'input_format_allow_errors_ratio': .2} -def test_external_simple(test_client: Client): +def test_external_simple(param_client: Client, call): data_file = f'{Path(__file__).parent}/movies.csv' data = ExternalData(data_file, fmt='CSVWithNames', structure=['movie String', 'year UInt16', 'rating Decimal32(3)']) - result = test_client.query('SELECT * FROM movies ORDER BY movie', + result = call(param_client.query, 'SELECT * FROM movies ORDER BY movie', external_data=data, settings=ext_settings).result_rows assert result[0][0] == '12 Angry Men' -def test_external_arrow(test_client: Client): +def test_external_arrow(param_client: Client, call): if not arrow: pytest.skip('PyArrow package not available') - if not test_client.min_version('21'): - pytest.skip(f'PyArrow is not 
supported in this server version {test_client.server_version}') + if not param_client.min_version('21'): + pytest.skip(f'PyArrow is not supported in this server version {param_client.server_version}') data_file = f'{Path(__file__).parent}/movies.csv' data = ExternalData(data_file, fmt='CSVWithNames', structure=['movie String', 'year UInt16', 'rating Decimal32(3)']) - result = test_client.query_arrow('SELECT * FROM movies ORDER BY movie', + result = call(param_client.query_arrow, 'SELECT * FROM movies ORDER BY movie', external_data=data, settings=ext_settings) assert str(result[0][0]) == '12 Angry Men' -def test_external_multiple(test_client: Client): +def test_external_multiple(param_client: Client, call): movies_file = f'{Path(__file__).parent}/movies.csv' data = ExternalData(movies_file, fmt='CSVWithNames', structure=['movie String', 'year UInt16', 'rating Decimal32(3)']) actors_file = f'{Path(__file__).parent}/actors.csv' data.add_file(actors_file, fmt='CSV', types='String,UInt16,String') - result = test_client.query('SELECT * FROM actors;', external_data=data, settings={ + result = call(param_client.query, 'SELECT * FROM actors;', external_data=data, settings={ 'input_format_allow_errors_num': 10, 'input_format_allow_errors_ratio': .2}).result_rows assert result[1][1] == 1940 - result = test_client.query( + result = call(param_client.query, 'SELECT _1, movie FROM actors INNER JOIN movies ON actors._3 = movies.movie AND actors._2 = 1940', external_data=data, settings=ext_settings).result_rows @@ -52,84 +51,77 @@ def test_external_multiple(test_client: Client): assert result[0][1] == 'Scarface' -def test_external_parquet(test_config: TestConfig, test_client: Client): +def test_external_parquet(test_config: TestConfig, param_client: Client, call): if test_config.cloud: pytest.skip('External data join not working in SMT, skipping') movies_file = f'{Path(__file__).parent}/movies.parquet' - test_client.command('DROP TABLE IF EXISTS movies') - 
test_client.command('DROP TABLE IF EXISTS num') - test_client.command(""" + call(param_client.command, 'DROP TABLE IF EXISTS movies') + call(param_client.command, 'DROP TABLE IF EXISTS num') + call(param_client.command, """ CREATE TABLE IF NOT EXISTS num (number UInt64, t String) ENGINE = MergeTree ORDER BY number""") - test_client.command(""" + call(param_client.command, """ INSERT INTO num SELECT number, concat(toString(number), 'x') as t FROM numbers(2500) WHERE (number > 1950) AND (number < 2025) """) data = ExternalData(movies_file, fmt='Parquet', structure=['movie String', 'year UInt16', 'rating Float64']) - result = test_client.query( + result = call(param_client.query, "SELECT * FROM movies INNER JOIN num ON movies.year = number AND t = '2000x' ORDER BY movie", settings={'output_format_parquet_string_as_string': 1}, external_data=data).result_rows assert len(result) == 5 assert result[2][0] == 'Memento' - test_client.command('DROP TABLE num') + call(param_client.command, 'DROP TABLE num') -def test_external_binary(test_client: Client): - actors = 'Robert Redford\t1936\tThe Sting\nAl Pacino\t1940\tScarface'.encode() +def test_external_binary(param_client: Client, call): + actors = b'Robert Redford\t1936\tThe Sting\nAl Pacino\t1940\tScarface' data = ExternalData(file_name='actors.csv', data=actors, structure='name String, birth_year UInt16, movie String') - result = test_client.query('SELECT * FROM actors ORDER BY birth_year DESC', external_data=data).result_rows + result = call(param_client.query, 'SELECT * FROM actors ORDER BY birth_year DESC', external_data=data).result_rows assert len(result) == 2 assert result[1][2] == 'The Sting' -def test_external_empty_binary(test_client: Client): +def test_external_empty_binary(param_client: Client, call): data = ExternalData(file_name='empty.csv', data=b'', structure='name String') - result = test_client.query('SELECT * FROM empty', external_data=data).result_rows + result = call(param_client.query, 'SELECT * FROM 
empty', external_data=data).result_rows assert len(result) == 0 -def test_external_raw(test_client: Client): +def test_external_raw(param_client: Client, call): movies_file = f'{Path(__file__).parent}/movies.parquet' data = ExternalData(movies_file, fmt='Parquet', structure=['movie String', 'year UInt16', 'rating Float64']) - result = test_client.raw_query('SELECT avg(rating) FROM movies', external_data=data) + result = call(param_client.raw_query, 'SELECT avg(rating) FROM movies', external_data=data) assert '8.25' == result.decode()[0:4] -def test_external_command(test_client: Client): +def test_external_command(param_client: Client, call): movies_file = f'{Path(__file__).parent}/movies.parquet' data = ExternalData(movies_file, fmt='Parquet', structure=['movie String', 'year UInt16', 'rating Float64']) - result = test_client.command('SELECT avg(rating) FROM movies', external_data=data) + result = call(param_client.command, 'SELECT avg(rating) FROM movies', external_data=data) assert '8.25' == result[0:4] - test_client.command('DROP TABLE IF EXISTS movies_ext') - if test_client.min_version('22.8'): - query_result = test_client.query('CREATE TABLE movies_ext ENGINE MergeTree() ORDER BY tuple() EMPTY ' + + call(param_client.command, 'DROP TABLE IF EXISTS movies_ext') + if param_client.min_version('22.8'): + query_result = call(param_client.query, 'CREATE TABLE movies_ext ENGINE MergeTree() ORDER BY tuple() EMPTY ' + 'AS SELECT * FROM movies', external_data=data) assert 'query_id' in query_result.first_item - test_client.raw_query('INSERT INTO movies_ext SELECT * FROM movies', external_data=data) - assert 250 == test_client.command('SELECT COUNT() FROM movies_ext') - - -def test_external_with_form_encode(test_config: TestConfig): - form_client = get_client( - host=test_config.host, - port=test_config.port, - username=test_config.username, - password=test_config.password, - database=test_config.test_database, - form_encode_query_params=True - ) + 
call(param_client.raw_query, 'INSERT INTO movies_ext SELECT * FROM movies', external_data=data) + assert 250 == call(param_client.command, 'SELECT COUNT() FROM movies_ext') + + +def test_external_with_form_encode(client_factory, call): + form_client = client_factory(form_encode_query_params=True) movies_file = f'{Path(__file__).parent}/movies.csv' data = ExternalData(movies_file, fmt='CSVWithNames', structure=['movie String', 'year UInt16', 'rating Decimal32(3)']) # Test with parameters in the query - result = form_client.query( + result = call(form_client.query, 'SELECT * FROM movies WHERE year > {year:UInt16} ORDER BY rating DESC LIMIT {limit:UInt32}', parameters={'year': 1990, 'limit': 5}, external_data=data, @@ -144,7 +136,7 @@ def test_external_with_form_encode(test_config: TestConfig): assert ratings == sorted(ratings, reverse=True) # Test raw query with external data and form encoding - raw_result = form_client.raw_query( + raw_result = call(form_client.raw_query, 'SELECT COUNT() FROM movies WHERE rating > {min_rating:Decimal32(3)}', parameters={'min_rating': 8.0}, external_data=data, diff --git a/tests/integration_tests/test_form_encode_query.py b/tests/integration_tests/test_form_encode_query.py index 907c1797..05a881d1 100644 --- a/tests/integration_tests/test_form_encode_query.py +++ b/tests/integration_tests/test_form_encode_query.py @@ -1,13 +1,12 @@ from typing import Callable -from clickhouse_connect import get_client from clickhouse_connect.driver import Client from tests.integration_tests.conftest import TestConfig -def test_form_encode_query_basic(test_client: Client, test_config: TestConfig, table_context: Callable): +def test_form_encode_query_basic(client_factory, call, test_config: TestConfig, table_context: Callable): """Test that form_encode_query sends parameters as form data""" - form_client = get_client( + form_client = client_factory( host=test_config.host, port=test_config.port, username=test_config.username, @@ -17,19 +16,19 @@ def 
test_form_encode_query_basic(test_client: Client, test_config: TestConfig, t ) with table_context('test_form_encode', ['id UInt32', 'name String', 'value Float64']): - test_client.insert('test_form_encode', + call(form_client.insert, 'test_form_encode', [[1, 'test1', 10.5], [2, 'test2', 20.3], [3, 'test3', 30.7]]) - result = form_client.query( + result = call(form_client.query, 'SELECT * FROM test_form_encode WHERE id = {id:UInt32}', parameters={'id': 2} ) assert result.row_count == 1 assert result.first_row[1] == 'test2' - result = form_client.query( + result = call(form_client.query, 'SELECT * FROM test_form_encode WHERE name = {name:String} AND value > {val:Float64}', parameters={'name': 'test3', 'val': 25.0} ) @@ -37,9 +36,9 @@ def test_form_encode_query_basic(test_client: Client, test_config: TestConfig, t assert result.first_row[0] == 3 -def test_form_encode_with_arrays(test_client: Client, test_config: TestConfig, table_context: Callable): +def test_form_encode_with_arrays(client_factory, call, test_config: TestConfig, table_context: Callable): """Test form_encode_query with array parameters""" - form_client = get_client( + form_client = client_factory( host=test_config.host, port=test_config.port, username=test_config.username, @@ -49,19 +48,19 @@ def test_form_encode_with_arrays(test_client: Client, test_config: TestConfig, t ) with table_context('test_form_arrays', ['id UInt32', 'tags Array(String)']): - test_client.insert('test_form_arrays', + call(form_client.insert, 'test_form_arrays', [[1, ['tag1', 'tag2']], [2, ['tag2', 'tag3']], [3, ['tag1', 'tag3']]]) - result = form_client.query( + result = call(form_client.query, 'SELECT * FROM test_form_arrays WHERE has(tags, {tag:String})', parameters={'tag': 'tag3'} ) assert result.row_count == 2 ids = [1, 3] - result = form_client.query( + result = call(form_client.query, 'SELECT * FROM test_form_arrays WHERE id IN {ids:Array(UInt32)}', parameters={'ids': ids} ) @@ -69,9 +68,9 @@ def 
test_form_encode_with_arrays(test_client: Client, test_config: TestConfig, t assert sorted([row[0] for row in result.result_rows]) == [1, 3] -def test_form_encode_raw_query(test_config: TestConfig): +def test_form_encode_raw_query(client_factory, call, test_config: TestConfig): """Test form_encode_query with raw_query method""" - form_client = get_client( + form_client = client_factory( host=test_config.host, port=test_config.port, username=test_config.username, @@ -80,7 +79,7 @@ def test_form_encode_raw_query(test_config: TestConfig): form_encode_query_params=True ) - result = form_client.raw_query( + result = call(form_client.raw_query, 'SELECT {a:Int32} + {b:Int32} as sum', parameters={'a': 10, 'b': 20} ) @@ -88,9 +87,9 @@ def test_form_encode_raw_query(test_config: TestConfig): assert b'30' in result -def test_form_encode_vs_regular(test_client: Client, test_config: TestConfig, table_context: Callable): +def test_form_encode_vs_regular(client_factory, param_client: Client, call, test_config: TestConfig, table_context: Callable): """Verify that form_encode_query produces same results as regular parameter handling""" - regular_client = get_client( + regular_client = client_factory( host=test_config.host, port=test_config.port, username=test_config.username, @@ -99,7 +98,7 @@ def test_form_encode_vs_regular(test_client: Client, test_config: TestConfig, ta form_encode_query_params=False ) - form_client = get_client( + form_client = client_factory( host=test_config.host, port=test_config.port, username=test_config.username, @@ -109,22 +108,22 @@ def test_form_encode_vs_regular(test_client: Client, test_config: TestConfig, ta ) with table_context('test_comparison', ['id UInt32', 'text String', 'score Float64']): - test_client.insert('test_comparison', + call(param_client.insert, 'test_comparison', [[i, f'text_{i}', i * 1.5] for i in range(1, 11)]) query = 'SELECT * FROM test_comparison WHERE id > {min_id:UInt32} AND score < {max_score:Float64} ORDER BY id' params = 
{'min_id': 3, 'max_score': 12.0} - regular_result = regular_client.query(query, parameters=params) - form_result = form_client.query(query, parameters=params) + regular_result = call(regular_client.query, query, parameters=params) + form_result = call(form_client.query, query, parameters=params) assert regular_result.result_rows == form_result.result_rows assert regular_result.row_count == form_result.row_count -def test_form_encode_nullable_params(test_config: TestConfig): +def test_form_encode_nullable_params(client_factory, call, test_config: TestConfig): """Test form_encode_query with nullable parameters""" - form_client = get_client( + form_client = client_factory( host=test_config.host, port=test_config.port, username=test_config.username, @@ -133,22 +132,22 @@ def test_form_encode_nullable_params(test_config: TestConfig): form_encode_query_params=True ) - result = form_client.query( + result = call(form_client.query, 'SELECT {val:Nullable(String)} IS NULL as is_null', parameters={'val': None} ) assert result.first_row[0] == 1 - result = form_client.query( + result = call(form_client.query, 'SELECT {val:Nullable(String)} as value', parameters={'val': 'test_value'} ) assert result.first_row[0] == 'test_value' -def test_form_encode_schema_probe_query(test_config: TestConfig, table_context: Callable): +def test_form_encode_schema_probe_query(client_factory, call, test_config: TestConfig, table_context: Callable): """Test that schema-probe queries (LIMIT 0) work correctly with form_encode_query_params""" - form_client = get_client( + form_client = client_factory( host=test_config.host, port=test_config.port, username=test_config.username, @@ -158,14 +157,14 @@ def test_form_encode_schema_probe_query(test_config: TestConfig, table_context: ) # Test with a simple LIMIT 0 query - result = form_client.query('SELECT name, database, NOW() as dt FROM system.tables LIMIT 0') + result = call(form_client.query, 'SELECT name, database, NOW() as dt FROM system.tables LIMIT 
0') assert result.column_names == ('name', 'database', 'dt') assert len(result.column_types) == 3 assert len(result.result_set) == 0 # Test with LIMIT 0 and parameters with table_context('test_schema_probe', ['id UInt32', 'name String', 'value Float64']): - result = form_client.query( + result = call(form_client.query, 'SELECT * FROM test_schema_probe WHERE id = {id:UInt32} LIMIT 0', parameters={'id': 1} ) @@ -174,7 +173,7 @@ def test_form_encode_schema_probe_query(test_config: TestConfig, table_context: assert len(result.result_set) == 0 # Test with complex query and parameters - result = form_client.query( + result = call(form_client.query, 'SELECT id, name, value * {multiplier:Float64} as adjusted_value ' 'FROM test_schema_probe ' 'WHERE name = {filter_name:String} LIMIT 0', diff --git a/tests/integration_tests/test_formats.py b/tests/integration_tests/test_formats.py index 2e9876c7..0a1895d4 100644 --- a/tests/integration_tests/test_formats.py +++ b/tests/integration_tests/test_formats.py @@ -1,15 +1,11 @@ -from clickhouse_connect.driver import Client, ProgrammingError +from clickhouse_connect.driver import Client -def test_uint64_format(test_client: Client): +def test_uint64_format(param_client: Client, call): # Default should be unsigned - result = test_client.query('SELECT toUInt64(9523372036854775807) as value') + result = call(param_client.query, "SELECT toUInt64(9523372036854775807) as value") assert result.result_set[0][0] == 9523372036854775807 - result = test_client.query('SELECT toUInt64(9523372036854775807) as value', query_formats={'UInt64': 'signed'}) + result = call(param_client.query, "SELECT toUInt64(9523372036854775807) as value", query_formats={"UInt64": "signed"}) assert result.result_set[0][0] == -8923372036854775809 - result = test_client.query('SELECT toUInt64(9523372036854775807) as value', query_formats={'UInt64': 'native'}) + result = call(param_client.query, "SELECT toUInt64(9523372036854775807) as value", query_formats={"UInt64": 
"native"}) assert result.result_set[0][0] == 9523372036854775807 - try: - test_client.query('SELECT toUInt64(9523372036854775807) as signed', query_formats={'UInt64': 'huh'}) - except ProgrammingError: - pass diff --git a/tests/integration_tests/test_geometric.py b/tests/integration_tests/test_geometric.py index 5e667c78..a6f576ed 100644 --- a/tests/integration_tests/test_geometric.py +++ b/tests/integration_tests/test_geometric.py @@ -3,32 +3,32 @@ from clickhouse_connect.driver import Client -def test_point_column(test_client: Client, table_context: Callable): +def test_point_column(param_client: Client, call, table_context: Callable): with table_context('point_column_test', ['key Int32', 'point Point']): data = [[1, (3.55, 3.55)], [2, (4.55, 4.55)]] - test_client.insert('point_column_test', data) + call(param_client.insert, 'point_column_test', data) - query_result = test_client.query('SELECT * FROM point_column_test ORDER BY key').result_rows + query_result = call(param_client.query, 'SELECT * FROM point_column_test ORDER BY key').result_rows assert len(query_result) == 2 assert query_result[0] == (1, (3.55, 3.55)) assert query_result[1] == (2, (4.55, 4.55)) -def test_ring_column(test_client: Client, table_context: Callable): +def test_ring_column(param_client: Client, call, table_context: Callable): with table_context('ring_column_test', ['key Int32', 'ring Ring']): data = [[1, [(5.522, 58.472),(3.55, 3.55)]], [2, [(4.55, 4.55)]]] - test_client.insert('ring_column_test', data) + call(param_client.insert, 'ring_column_test', data) - query_result = test_client.query('SELECT * FROM ring_column_test ORDER BY key').result_rows + query_result = call(param_client.query, 'SELECT * FROM ring_column_test ORDER BY key').result_rows assert len(query_result) == 2 assert query_result[0] == (1, [(5.522, 58.472),(3.55, 3.55)]) assert query_result[1] == (2, [(4.55, 4.55)]) -def test_polygon_column(test_client: Client, table_context: Callable): +def 
test_polygon_column(param_client: Client, call, table_context: Callable): with table_context('polygon_column_test', ['key Int32', 'polygon Polygon']): - res = test_client.query("SELECT readWKTPolygon('POLYGON ((-64.8 32.3, -65.5 18.3, -80.3 25.2, -64.8 32.3))') as polygon") + res = call(param_client.query, "SELECT readWKTPolygon('POLYGON ((-64.8 32.3, -65.5 18.3, -80.3 25.2, -64.8 32.3))') as polygon") pg = res.first_row[0] - test_client.insert('polygon_column_test', [(1, pg), (4, pg)]) - query_result = test_client.query('SELECT key, polygon FROM polygon_column_test WHERE key = 4') + call(param_client.insert, 'polygon_column_test', [(1, pg), (4, pg)]) + query_result = call(param_client.query, 'SELECT key, polygon FROM polygon_column_test WHERE key = 4') assert query_result.first_row[1] == pg diff --git a/tests/integration_tests/test_inserts.py b/tests/integration_tests/test_inserts.py index 123cf8d5..5adfa97c 100644 --- a/tests/integration_tests/test_inserts.py +++ b/tests/integration_tests/test_inserts.py @@ -1,91 +1,91 @@ from decimal import Decimal from typing import Callable +import pytest + from clickhouse_connect.driver.client import Client from clickhouse_connect.driver.exceptions import DataError -def test_insert(test_client: Client, test_table_engine: str): - if test_client.min_version('19'): - test_client.command('DROP TABLE IF EXISTS test_system_insert') +def test_insert(param_client: Client, call, test_table_engine: str): + if param_client.min_version('19'): + call(param_client.command, 'DROP TABLE IF EXISTS test_system_insert') else: - test_client.command('DROP TABLE IF EXISTS test_system_insert SYNC') - test_client.command(f'CREATE TABLE test_system_insert AS system.tables Engine {test_table_engine} ORDER BY name') - tables_result = test_client.query('SELECT * from system.tables') - test_client.insert(table='test_system_insert', column_names='*', data=tables_result.result_set) - copy_result = test_client.command('SELECT count() from 
test_system_insert') + call(param_client.command, 'DROP TABLE IF EXISTS test_system_insert SYNC') + call(param_client.command, f'CREATE TABLE test_system_insert AS system.tables Engine {test_table_engine} ORDER BY name') + tables_result = call(param_client.query, 'SELECT * from system.tables') + call(param_client.insert, table='test_system_insert', column_names='*', data=tables_result.result_set) + copy_result = call(param_client.command, 'SELECT count() from test_system_insert') assert tables_result.row_count == copy_result - test_client.command('DROP TABLE IF EXISTS test_system_insert') + call(param_client.command, 'DROP TABLE IF EXISTS test_system_insert') -def test_decimal_conv(test_client: Client, table_context: Callable): +def test_decimal_conv(param_client: Client, call, table_context: Callable): with table_context('test_num_conv', ['col1 UInt64', 'col2 Int32', 'f1 Float64']): data = [[Decimal(5), Decimal(-182), Decimal(55.2)], [Decimal(57238478234), Decimal(77), Decimal(-29.5773)]] - test_client.insert('test_num_conv', data) - result = test_client.query('SELECT * FROM test_num_conv').result_set + call(param_client.insert, 'test_num_conv', data) + result = call(param_client.query, 'SELECT * FROM test_num_conv').result_set assert result == [(5, -182, 55.2), (57238478234, 77, -29.5773)] -def test_float_decimal_conv(test_client: Client, table_context: Callable): +def test_float_decimal_conv(param_client: Client, call, table_context: Callable): with table_context('test_float_to_dec_conv', ['col1 Decimal32(6)','col2 Decimal32(6)', 'col3 Decimal128(6)', 'col4 Decimal128(6)']): data = [[0.492917, 0.49291700, 0.492917, 0.49291700]] - test_client.insert('test_float_to_dec_conv', data) - result = test_client.query('SELECT * FROM test_float_to_dec_conv').result_set + call(param_client.insert, 'test_float_to_dec_conv', data) + result = call(param_client.query, 'SELECT * FROM test_float_to_dec_conv').result_set assert result == [(Decimal("0.492917"), Decimal("0.492917"), 
Decimal("0.492917"), Decimal("0.492917"))] -def test_bad_data_insert(test_client: Client, table_context: Callable): +def test_bad_data_insert(param_client: Client, call, table_context: Callable): with table_context('test_bad_insert', ['key Int32', 'float_col Float64']): data = [[1, 3.22], [2, 'nope']] - try: - test_client.insert('test_bad_insert', data) - except DataError as ex: - assert 'array' in str(ex) + with pytest.raises(DataError, match="array"): + call(param_client.insert, 'test_bad_insert', data) -def test_bad_strings(test_client: Client, table_context: Callable): +def test_bad_strings(param_client: Client, call, table_context: Callable): with table_context('test_bad_strings', 'key Int32, fs FixedString(6), nsf Nullable(FixedString(4))'): try: - test_client.insert('test_bad_strings', [[1, b'\x0535', None]]) + call(param_client.insert, 'test_bad_strings', [[1, b'\x0535', None]]) except DataError as ex: assert 'match' in str(ex) try: - test_client.insert('test_bad_strings', [[1, b'\x0535abc', '😀🙃']]) + call(param_client.insert, 'test_bad_strings', [[1, b'\x0535abc', '😀🙃']]) except DataError as ex: assert 'encoded' in str(ex) -def test_low_card_dictionary_size(test_client: Client, table_context: Callable): +def test_low_card_dictionary_size(param_client: Client, call, table_context: Callable): with table_context('test_low_card_dict', 'key Int32, lc LowCardinality(String)', settings={'index_granularity': 65536 }): data = [[x, str(x)] for x in range(30000)] - test_client.insert('test_low_card_dict', data) - assert 30000 == test_client.command('SELECT count() FROM test_low_card_dict') + call(param_client.insert, 'test_low_card_dict', data) + assert 30000 == call(param_client.command, 'SELECT count() FROM test_low_card_dict') -def test_column_names_spaces(test_client: Client, table_context: Callable): +def test_column_names_spaces(param_client: Client, call, table_context: Callable): with table_context('test_column_spaces', columns=['key 1', 'value 1'], 
column_types=['Int32', 'String']): data = [[1, 'str 1'], [2, 'str 2']] - test_client.insert('test_column_spaces', data) - result = test_client.query('SELECT * FROM test_column_spaces').result_rows + call(param_client.insert, 'test_column_spaces', data) + result = call(param_client.query, 'SELECT * FROM test_column_spaces').result_rows assert result[0][0] == 1 assert result[1][1] == 'str 2' -def test_numeric_conversion(test_client: Client, table_context: Callable): +def test_numeric_conversion(param_client: Client, call, table_context: Callable): with table_context('test_numeric_convert', columns=['key Int32', 'n_int Nullable(UInt64)', 'n_flt Nullable(Float64)']): data = [[1, None, None], [2, '2', '5.32']] - test_client.insert('test_numeric_convert', data) - result = test_client.query('SELECT * FROM test_numeric_convert').result_rows + call(param_client.insert, 'test_numeric_convert', data) + result = call(param_client.query, 'SELECT * FROM test_numeric_convert').result_rows assert result[1][1] == 2 assert result[1][2] == float('5.32') - test_client.command('TRUNCATE TABLE test_numeric_convert') + call(param_client.command, 'TRUNCATE TABLE test_numeric_convert') data = [[0, '55', '532.48'], [1, None, None], [2, '2', '5.32']] - test_client.insert('test_numeric_convert', data) - result = test_client.query('SELECT * FROM test_numeric_convert').result_rows + call(param_client.insert, 'test_numeric_convert', data) + result = call(param_client.query, 'SELECT * FROM test_numeric_convert').result_rows assert result[0][1] == 55 assert result[0][2] == 532.48 assert result[1][1] is None diff --git a/tests/integration_tests/test_jwt_auth.py b/tests/integration_tests/test_jwt_auth.py index 02249cd5..592b613f 100644 --- a/tests/integration_tests/test_jwt_auth.py +++ b/tests/integration_tests/test_jwt_auth.py @@ -2,157 +2,69 @@ import pytest -from clickhouse_connect.driver import create_client, ProgrammingError, create_async_client +from clickhouse_connect.driver import 
ProgrammingError from tests.integration_tests.conftest import TestConfig +CHECK_CLOUD_MODE_QUERY = "SELECT value='1' FROM system.settings WHERE name='cloud_mode'" +JWT_SECRET_ENV_KEY = "CLICKHOUSE_CONNECT_TEST_JWT_SECRET" -def test_jwt_auth_sync_client(test_config: TestConfig): - if not test_config.cloud: - pytest.skip('Skipping JWT test in non-Cloud mode') - - access_token = make_access_token() - client = create_client( - host=test_config.host, - port=test_config.port, - access_token=access_token - ) - result = client.query(query=CHECK_CLOUD_MODE_QUERY).result_set - assert result == [(True,)] - - -def test_jwt_auth_sync_client_set_access_token(test_config: TestConfig): - if not test_config.cloud: - pytest.skip('Skipping JWT test in non-Cloud mode') - - access_token = make_access_token() - client = create_client( - host=test_config.host, - port=test_config.port, - access_token=access_token, - ) - - # Should still work after the override - access_token = make_access_token() - client.set_access_token(access_token) - - result = client.query(query=CHECK_CLOUD_MODE_QUERY).result_set - assert result == [(True,)] - - -def test_jwt_auth_sync_client_config_errors(): - with pytest.raises(ProgrammingError): - create_client( - username='bob', - access_token='foobar' - ) - with pytest.raises(ProgrammingError): - create_client( - username='bob', - password='secret', - access_token='foo' - ) - with pytest.raises(ProgrammingError): - create_client( - password='secret', - access_token='foo' - ) - - -def test_jwt_auth_sync_client_set_access_token_errors(test_config: TestConfig): - if not test_config.cloud: - pytest.skip('Skipping JWT test in non-Cloud mode') - - client = create_client( - host=test_config.host, - port=test_config.port, - username=test_config.username, - password=test_config.password, - ) - # Can't use JWT with username/password - access_token = make_access_token() - with pytest.raises(ProgrammingError): - client.set_access_token(access_token) +def 
make_access_token(): + """Get JWT secret from environment for testing.""" + secret = environ.get(JWT_SECRET_ENV_KEY) + if not secret: + raise ValueError(f"{JWT_SECRET_ENV_KEY} environment variable is not set") + return secret -@pytest.mark.asyncio -async def test_jwt_auth_async_client(test_config: TestConfig): +def test_jwt_auth_client(test_config: TestConfig, client_factory, call): + """Test JWT authentication with both sync and async clients.""" if not test_config.cloud: - pytest.skip('Skipping JWT test in non-Cloud mode') + pytest.skip("Skipping JWT test in non-Cloud mode") access_token = make_access_token() - client = await create_async_client( - host=test_config.host, - port=test_config.port, - access_token=access_token - ) - result = (await client.query(query=CHECK_CLOUD_MODE_QUERY)).result_set + client = client_factory(access_token=access_token) + result = call(client.query, CHECK_CLOUD_MODE_QUERY).result_set assert result == [(True,)] -@pytest.mark.asyncio -async def test_jwt_auth_async_client_set_access_token(test_config: TestConfig): +def test_jwt_auth_client_set_access_token(test_config: TestConfig, client_factory, call): + """Test setting JWT access token dynamically with both sync and async clients.""" if not test_config.cloud: - pytest.skip('Skipping JWT test in non-Cloud mode') + pytest.skip("Skipping JWT test in non-Cloud mode") access_token = make_access_token() - client = await create_async_client( - host=test_config.host, - port=test_config.port, - access_token=access_token, - ) + client = client_factory(access_token=access_token) access_token = make_access_token() client.set_access_token(access_token) - result = (await client.query(query=CHECK_CLOUD_MODE_QUERY)).result_set + result = call(client.query, CHECK_CLOUD_MODE_QUERY).result_set assert result == [(True,)] -@pytest.mark.asyncio -async def test_jwt_auth_async_client_config_errors(): +def test_jwt_auth_client_config_errors(client_factory): + """Test JWT configuration validation catches 
invalid combinations.""" with pytest.raises(ProgrammingError): - await create_async_client( - username='bob', - access_token='foobar' - ) + client_factory(username="bob", access_token="foobar") + with pytest.raises(ProgrammingError): - await create_async_client( - username='bob', - password='secret', - access_token='foo' - ) + client_factory(username="bob", password="secret", access_token="foo") + with pytest.raises(ProgrammingError): - await create_async_client( - password='secret', - access_token='foo' - ) + client_factory(password="secret", access_token="foo") -@pytest.mark.asyncio -async def test_jwt_auth_async_client_set_access_token_errors(test_config: TestConfig): +def test_jwt_auth_client_set_access_token_errors(test_config: TestConfig, client_factory): + """Test that JWT cannot be set when using username/password authentication.""" if not test_config.cloud: - pytest.skip('Skipping JWT test in non-Cloud mode') + pytest.skip("Skipping JWT test in non-Cloud mode") - client = await create_async_client( - host=test_config.host, - port=test_config.port, + client = client_factory( username=test_config.username, password=test_config.password, ) - # Can't use JWT with username/password access_token = make_access_token() with pytest.raises(ProgrammingError): client.set_access_token(access_token) - - -CHECK_CLOUD_MODE_QUERY = "SELECT value='1' FROM system.settings WHERE name='cloud_mode'" -JWT_SECRET_ENV_KEY = 'CLICKHOUSE_CONNECT_TEST_JWT_SECRET' - - -def make_access_token(): - secret = environ.get(JWT_SECRET_ENV_KEY) - if not secret: - raise ValueError(f'{JWT_SECRET_ENV_KEY} environment variable is not set') - return secret diff --git a/tests/integration_tests/test_multithreading.py b/tests/integration_tests/test_multithreading.py index 1b01eaf4..0806f18f 100644 --- a/tests/integration_tests/test_multithreading.py +++ b/tests/integration_tests/test_multithreading.py @@ -1,29 +1,120 @@ +import asyncio import threading +import uuid import pytest -from 
clickhouse_connect.driver import Client from clickhouse_connect.driver.exceptions import ProgrammingError from tests.integration_tests.conftest import TestConfig -def test_threading_error(test_config: TestConfig, test_client: Client): +def test_sync_client_sequential_thread_access(param_client, client_mode, call, test_config: TestConfig): + """Test that sync clients can handle sequential access from different threads.""" + if client_mode != "sync": + pytest.skip("Only testing sync client behavior") + if test_config.cloud: - pytest.skip('Skipping threading test in ClickHouse Cloud') - thrown = None - - class QueryThread (threading.Thread): - def run(self): - nonlocal thrown - try: - test_client.command('SELECT randomString(512) FROM numbers(1000000)') - except ProgrammingError as ex: - thrown = ex - - threads = [QueryThread(), QueryThread()] + pytest.skip("Skipping threading test in ClickHouse Cloud") + + results = [] + errors = [] + + def run_query(value): + try: + result = param_client.command(f"SELECT {value}") + results.append(result) + except Exception as ex: # pylint: disable=broad-exception-caught + errors.append(ex) + + threads = [threading.Thread(target=run_query, args=(i,)) for i in range(3)] for thread in threads: thread.start() + thread.join() + + assert len(errors) == 0, f"Unexpected errors: {errors}" + assert len(results) == 3 + assert results == [0, 1, 2] + + +def test_async_client_threadsafe_submission(param_client, client_mode, call, test_config: TestConfig, shared_loop): + """Test that async clients work correctly with run_coroutine_threadsafe from multiple threads.""" + if client_mode != "async": + pytest.skip("Only testing async client behavior") + + if test_config.cloud: + pytest.skip("Skipping threading test in ClickHouse Cloud") + + results = [] + errors = [] + lock = threading.Lock() + + def run_query_threadsafe(value): + try: + future = asyncio.run_coroutine_threadsafe( + param_client.command(f"SELECT {value}"), + shared_loop + ) + result = 
future.result(timeout=5) + with lock: + results.append(result) + except Exception as ex: # pylint: disable=broad-exception-caught + with lock: + errors.append(ex) + + + threads = [threading.Thread(target=run_query_threadsafe, args=(i,)) for i in range(3)] + for thread in threads: + thread.start() + + call(asyncio.sleep, 2) + + for thread in threads: + thread.join() + + assert len(errors) == 0, f"Unexpected errors: {errors}" + assert len(results) == 3 + assert sorted(results) == [0, 1, 2] + + +def test_concurrent_session_usage_detection(client_mode, call, test_config: TestConfig, client_factory, shared_loop): + """Test that ClickHouse server detects concurrent usage of the same session.""" + if test_config.cloud: + pytest.skip("Skipping session concurrency test in ClickHouse Cloud") + + session_id = str(uuid.uuid4()) + client1 = client_factory(session_id=session_id) + client2 = client_factory(session_id=session_id) + + thrown = [] + + def run_query(client): + try: + if client_mode == "sync": + client.command("SELECT sleep(1)") + else: + future = asyncio.run_coroutine_threadsafe( + client.command("SELECT sleep(1)"), + shared_loop + ) + future.result(timeout=5) + except (ProgrammingError, Exception) as ex: # pylint: disable=broad-exception-caught + thrown.append(ex) + + threads = [ + threading.Thread(target=run_query, args=(client1,)), + threading.Thread(target=run_query, args=(client2,)) + ] + + for thread in threads: + thread.start() + + if client_mode == "async": + call(asyncio.sleep, 2) + for thread in threads: thread.join() - assert 'concurrent' in str(thrown) + # At least one should fail due to concurrent session usage + assert len(thrown) > 0, "Expected ClickHouse to detect concurrent session usage" + assert any("concurrent" in str(ex).lower() or "session" in str(ex).lower() for ex in thrown), \ + f"Expected session concurrency error, got: {thrown}" diff --git a/tests/integration_tests/test_native.py b/tests/integration_tests/test_native.py index 
01df2a5a..f9a88140 100644 --- a/tests/integration_tests/test_native.py +++ b/tests/integration_tests/test_native.py @@ -13,52 +13,52 @@ from clickhouse_connect.driver.common import coerce_bool -def test_low_card(test_client: Client, table_context: Callable): +def test_low_card(param_client: Client, call, table_context: Callable): with table_context('native_test', ['key LowCardinality(Int32)', 'value_1 LowCardinality(String)']): - test_client.insert('native_test', [[55, 'TV1'], [-578328, 'TV38882'], [57372, 'Kabc/defXX']]) - result = test_client.query("SELECT * FROM native_test WHERE value_1 LIKE '%abc/def%'") + call(param_client.insert, 'native_test', [[55, 'TV1'], [-578328, 'TV38882'], [57372, 'Kabc/defXX']]) + result = call(param_client.query, "SELECT * FROM native_test WHERE value_1 LIKE '%abc/def%'") assert len(result.result_set) == 1 -def test_low_card_uuid(test_client: Client, table_context: Callable): +def test_low_card_uuid(param_client: Client, call, table_context: Callable): with table_context('low_card_uuid', ['dt Date', 'low_card_uuid LowCardinality(UUID)']): data = ([date(2023, 1, 1), '80397B00E0B248AFAF34AE11A5546A3B'], [date(2024, 1, 1), '70397B00-E0B2-48AF-AF34-AE11A5546A3B']) - test_client.insert('low_card_uuid', data) - result = test_client.query("SELECT * FROM low_card_uuid order by dt").result_set + call(param_client.insert, 'low_card_uuid', data) + result = call(param_client.query, "SELECT * FROM low_card_uuid order by dt").result_set assert len(result) == 2 assert str(result[0][1]) == '80397b00-e0b2-48af-af34-ae11a5546a3b' assert str(result[1][1]) == '70397b00-e0b2-48af-af34-ae11a5546a3b' -def test_bare_datetime64(test_client: Client, table_context: Callable): +def test_bare_datetime64(param_client: Client, call, table_context: Callable): with table_context('bare_datetime64_test', ['key UInt32', 'dt64 DateTime64']): - test_client.insert('bare_datetime64_test', + call(param_client.insert, 'bare_datetime64_test', [[1, datetime(2023, 3, 25, 10, 
5, 44, 772402)], [2, datetime.now()], [3, datetime(1965, 10, 15, 12, 0, 0)]]) - result = test_client.query('SELECT * FROM bare_datetime64_test ORDER BY key').result_rows + result = call(param_client.query, 'SELECT * FROM bare_datetime64_test ORDER BY key').result_rows assert result[0][0] == 1 assert result[0][1] == datetime(2023, 3, 25, 10, 5, 44, 772000) assert result[2][1] == datetime(1965, 10, 15, 12, 0, 0) -def test_nulls(test_client: Client, table_context: Callable): +def test_nulls(param_client: Client, call, table_context: Callable): with table_context('nullable_test', ['key UInt32', 'null_str Nullable(String)', 'null_int Nullable(Int64)']): - test_client.insert('nullable_test', [[1, None, None], + call(param_client.insert, 'nullable_test', [[1, None, None], [2, 'nonnull', -57382882345666], [3, None, 5882374747732834], [4, 'nonnull2', None]]) - result = test_client.query('SELECT * FROM nullable_test ORDER BY key', use_none=False).result_rows + result = call(param_client.query, 'SELECT * FROM nullable_test ORDER BY key', use_none=False).result_rows assert result[2] == (3, '', 5882374747732834) assert result[3] == (4, 'nonnull2', 0) - result = test_client.query('SELECT * FROM nullable_test ORDER BY key').result_rows + result = call(param_client.query, 'SELECT * FROM nullable_test ORDER BY key').result_rows assert result[1] == (2, 'nonnull', -57382882345666) assert result[2] == (3, None, 5882374747732834) assert result[3] == (4, 'nonnull2', None) -def test_old_json(test_client: Client, table_context: Callable): +def test_old_json(param_client: Client, call, table_context: Callable): if not coerce_bool(os.environ.get('CLICKHOUSE_CONNECT_TEST_OLD_JSON_TYPE')): pytest.skip('Deprecated JSON type not tested') with table_context('old_json_test', [ @@ -71,12 +71,12 @@ def test_old_json(test_client: Client, table_context: Callable): jv3 = {'key3': 752, 'value.2': 'v2_rules', 'blank': None} njv2 = {'nk1': -302, 'nk2': {'sub1': 372, 'sub2': 'a string'}} njv3 = {'nk1': 
5832.44, 'nk2': {'sub1': 47788382, 'sub2':'sub2val', 'sub3': 'sub3str', 'space key': 'spacey'}} - test_client.insert('old_json_test', [ + call(param_client.insert, 'old_json_test', [ [5, jv1, -44, None], [20, None, 5200, njv2], [25, jv3, 7302, njv3]]) - result = test_client.query('SELECT * FROM old_json_test ORDER BY key') + result = call(param_client.query, 'SELECT * FROM old_json_test ORDER BY key') json1 = result.result_set[0][1] assert json1['HKD@spéçiäl'] == 'Special K' assert json1['key3'] == 0 @@ -93,24 +93,24 @@ def test_old_json(test_client: Client, table_context: Callable): assert null_json3['nk2']['space key'] == 'spacey' set_write_format('JSON', 'string') - test_client.insert('native_json_test', [[999, '{"key4": 283, "value.2": "str_value"}', 77, '{"nk1":53}']]) - result = test_client.query('SELECT value.key4, null_value.nk1 FROM native_json_test ORDER BY key') + call(param_client.insert, 'native_json_test', [[999, '{"key4": 283, "value.2": "str_value"}', 77, '{"nk1":53}']]) + result = call(param_client.query, 'SELECT value.key4, null_value.nk1 FROM native_json_test ORDER BY key') assert result.result_set[3][0] == 283 assert result.result_set[3][1] == 53 -def test_read_formats(test_client: Client, test_table_engine: str): - test_client.command('DROP TABLE IF EXISTS read_format_test') - test_client.command('CREATE TABLE read_format_test (key Int32, uuid UUID, fs FixedString(10), ipv4 IPv4,' + +def test_read_formats(param_client: Client, call, test_table_engine: str): + call(param_client.command, 'DROP TABLE IF EXISTS read_format_test') + call(param_client.command, 'CREATE TABLE read_format_test (key Int32, uuid UUID, fs FixedString(10), ipv4 IPv4,' + 'ip_array Array(IPv6), tup Tuple(u1 UInt64, ip2 IPv4))' + f'Engine {test_table_engine} ORDER BY key') uuid1 = uuid.UUID('23E45688e89B-12D3-3273-426614174000') uuid2 = uuid.UUID('77AA3278-3728-12d3-5372-000377723832') row1 = (1, uuid1, '530055777k', '10.251.30.50', ['2600::', '2001:4860:4860::8844'], (7372, 
'10.20.30.203')) row2 = (2, uuid2, 'short str', '10.44.75.20', ['74:382::3332', '8700:5200::5782:3992'], (7320, '252.18.4.50')) - test_client.insert('read_format_test', [row1, row2]) + call(param_client.insert, 'read_format_test', [row1, row2]) - result = test_client.query('SELECT * FROM read_format_test;;;').result_set + result = call(param_client.query, 'SELECT * FROM read_format_test;;;').result_set assert result[0][1] == uuid1 assert result[1][3] == IPv4Address('10.44.75.20') assert result[0][2] == b'\x35\x33\x30\x30\x35\x35\x37\x37\x37\x6b' @@ -118,104 +118,104 @@ def test_read_formats(test_client: Client, test_table_engine: str): assert result[0][5]['ip2'] == IPv4Address('10.20.30.203') set_default_formats('uuid', 'string', 'ip*', 'string', 'FixedString', 'string') - result = test_client.query('SELECT * FROM read_format_test').result_set + result = call(param_client.query, 'SELECT * FROM read_format_test').result_set assert result[0][1] == '23e45688-e89b-12d3-3273-426614174000' assert result[1][3] == '10.44.75.20' assert result[0][2] == '530055777k' assert result[0][4][1] == '2001:4860:4860::8844' clear_default_format('ip*') - result = test_client.query('SELECT * FROM read_format_test').result_set + result = call(param_client.query, 'SELECT * FROM read_format_test').result_set assert result[0][1] == '23e45688-e89b-12d3-3273-426614174000' assert result[1][3] == IPv4Address('10.44.75.20') assert result[0][4][1] == IPv6Address('2001:4860:4860::8844') assert result[0][2] == '530055777k' # Test query formats - result = test_client.query('SELECT * FROM read_format_test', query_formats={'IP*': 'string', + result = call(param_client.query, 'SELECT * FROM read_format_test', query_formats={'IP*': 'string', 'tup': 'json'}).result_set assert result[1][3] == '10.44.75.20' assert result[0][5] == b'{"u1":7372,"ip2":"10.20.30.203"}' # Ensure that the query format clears - result = test_client.query('SELECT * FROM read_format_test').result_set + result = 
call(param_client.query, 'SELECT * FROM read_format_test').result_set assert result[1][3] == IPv4Address('10.44.75.20') assert result[0][5]['ip2'] == IPv4Address('10.20.30.203') # Test column formats - result = test_client.query('SELECT * FROM read_format_test', column_formats={'ipv4': 'string', + result = call(param_client.query, 'SELECT * FROM read_format_test', column_formats={'ipv4': 'string', 'tup': 'tuple'}).result_set assert result[1][3] == '10.44.75.20' assert result[0][5][1] == IPv4Address('10.20.30.203') # Ensure that the column format clears - result = test_client.query('SELECT * FROM read_format_test').result_set + result = call(param_client.query, 'SELECT * FROM read_format_test').result_set assert result[1][3] == IPv4Address('10.44.75.20') assert result[0][5]['ip2'] == IPv4Address('10.20.30.203') # Test sub column formats set_read_format('tuple', 'tuple') - result = test_client.query('SELECT * FROM read_format_test', column_formats={'tup': {'ip*': 'string'}}).result_set + result = call(param_client.query, 'SELECT * FROM read_format_test', column_formats={'tup': {'ip*': 'string'}}).result_set assert result[0][5][1] == '10.20.30.203' set_read_format('tuple', 'native') - result = test_client.query('SELECT * FROM read_format_test', column_formats={'tup': {'ip*': 'string'}}).result_set + result = call(param_client.query, 'SELECT * FROM read_format_test', column_formats={'tup': {'ip*': 'string'}}).result_set assert result[0][5]['ip2'] == '10.20.30.203' -def test_tuple_inserts(test_client: Client, table_context: Callable): +def test_tuple_inserts(param_client: Client, call, table_context: Callable): with table_context('insert_tuple_test', ['key Int32', 'named Tuple(fl Float64, `ns space` Nullable(String))', 'unnamed Tuple(Float64, Nullable(String))']): data = [[1, (3.55, 'str1'), (555, None)], [2, (-43.2, None), (0, 'str2')]] - test_client.insert('insert_tuple_test', data, settings={'insert_deduplication_token': 5772}) + call(param_client.insert, 
'insert_tuple_test', data, settings={'insert_deduplication_token': 5772}) data = [[1, {'fl': 3.55, 'ns space': 'str1'}, (555, None)], [2, {'fl': -43.2}, (0, 'str2')]] - test_client.insert('insert_tuple_test', data, settings={'insert_deduplication_token': 5773}) - query_result = test_client.query('SELECT * FROM insert_tuple_test ORDER BY key').result_rows + call(param_client.insert, 'insert_tuple_test', data, settings={'insert_deduplication_token': 5773}) + query_result = call(param_client.query, 'SELECT * FROM insert_tuple_test ORDER BY key').result_rows assert len(query_result) == 4 assert query_result[0] == query_result[1] assert query_result[2] == query_result[3] -def test_agg_function(test_client: Client, table_context: Callable): +def test_agg_function(param_client: Client, call, table_context: Callable): with table_context('agg_func_test', ['key Int32', 'str SimpleAggregateFunction(any, String)', 'lc_str SimpleAggregateFunction(any, LowCardinality(String))'], engine='AggregatingMergeTree'): - test_client.insert('agg_func_test', [(1, 'str', 'lc_str')]) - row = test_client.query('SELECT str, lc_str FROM agg_func_test').first_row + call(param_client.insert, 'agg_func_test', [(1, 'str', 'lc_str')]) + row = call(param_client.query, 'SELECT str, lc_str FROM agg_func_test').first_row assert row[0] == 'str' assert row[1] == 'lc_str' -def test_decimal_rounding(test_client: Client, table_context: Callable): +def test_decimal_rounding(param_client: Client, call, table_context: Callable): test_vals = [732.4, 75.57, 75.49, 40.16] with table_context('test_decimal', ['key Int32, value Decimal(10, 2)']): - test_client.insert('test_decimal', [[ix, x] for ix, x in enumerate(test_vals)]) - values = test_client.query('SELECT value FROM test_decimal').result_columns[0] + call(param_client.insert, 'test_decimal', [[ix, x] for ix, x in enumerate(test_vals)]) + values = call(param_client.query, 'SELECT value FROM test_decimal').result_columns[0] with decimal.localcontext() as 
dec_ctx: dec_ctx.prec = 10 assert [decimal.Decimal(str(x)) for x in test_vals] == values -def test_empty_maps(test_client: Client): - result = test_client.query("select Cast(([],[]), 'Map(String, Map(String, String))')") +def test_empty_maps(param_client: Client, call): + result = call(param_client.query, "select Cast(([],[]), 'Map(String, Map(String, String))')") assert result.first_row[0] == {} -def test_fixed_str_padding(test_client: Client, table_context: Callable): +def test_fixed_str_padding(param_client: Client, call, table_context: Callable): table = 'test_fixed_str_padding' with table_context(table, 'key Int32, value FixedString(3)'): - test_client.insert(table, [[1, 'abc']]) - test_client.insert(table, [[2, 'a']]) - test_client.insert(table, [[3, '']]) - result = test_client.query(f'select * from {table} ORDER BY key') + call(param_client.insert, table, [[1, 'abc']]) + call(param_client.insert, table, [[2, 'a']]) + call(param_client.insert, table, [[3, '']]) + result = call(param_client.query, f'select * from {table} ORDER BY key') assert result.result_columns[1] == [b'abc', b'a\x00\x00', b'\x00\x00\x00'] -def test_nonstandard_column_names(test_client: Client, table_context: Callable): +def test_nonstandard_column_names(param_client: Client, call, table_context: Callable): table = 'пример_кириллица' with table_context(table, 'колонка String') as t: - test_client.insert(t.table, (('привет',),)) - result = test_client.query(f'SELECT * FROM {t.table}').result_set + call(param_client.insert, t.table, (('привет',),)) + result = call(param_client.query, f'SELECT * FROM {t.table}').result_set assert result[0][0] == 'привет' diff --git a/tests/integration_tests/test_native_fuzz.py b/tests/integration_tests/test_native_fuzz.py index b654e7b8..0d83e4c8 100644 --- a/tests/integration_tests/test_native_fuzz.py +++ b/tests/integration_tests/test_native_fuzz.py @@ -13,28 +13,28 @@ # pylint: disable=duplicate-code -def test_query_fuzz(test_client: Client, 
test_table_engine: str): - if not test_client.min_version('21'): - pytest.skip(f'flatten_nested setting not supported in this server version {test_client.server_version}') +def test_query_fuzz(param_client: Client, call, test_table_engine: str): + if not param_client.min_version('21'): + pytest.skip(f'flatten_nested setting not supported in this server version {param_client.server_version}') test_runs = int(os.environ.get('CLICKHOUSE_CONNECT_TEST_FUZZ', '250')) - test_client.apply_server_timezone = True + param_client.apply_server_timezone = True try: for _ in range(test_runs): - test_client.command('DROP TABLE IF EXISTS fuzz_test') + call(param_client.command, 'DROP TABLE IF EXISTS fuzz_test') data_rows = random.randint(0, MAX_DATA_ROWS) col_names, col_types = random_columns(TEST_COLUMNS) - data = random_data(col_types, data_rows, test_client.server_tz) + data = random_data(col_types, data_rows, param_client.server_tz) col_names = ('row_id',) + col_names col_types = (get_from_name('UInt32'),) + col_types col_defs = [TableColumnDef(name, ch_type) for name, ch_type in zip(col_names, col_types)] create_stmt = create_table('fuzz_test', col_defs, test_table_engine, {'order by': 'row_id'}) - test_client.command(create_stmt, settings={'flatten_nested': 0}) - test_client.insert('fuzz_test', data, col_names) + call(param_client.command, create_stmt, settings={'flatten_nested': 0}) + call(param_client.insert, 'fuzz_test', data, col_names) - data_result = test_client.query('SELECT * FROM fuzz_test') + data_result = call(param_client.query, 'SELECT * FROM fuzz_test') if data_rows: assert data_result.column_names == col_names assert data_result.result_set == data finally: - test_client.apply_server_timezone = False + param_client.apply_server_timezone = False diff --git a/tests/integration_tests/test_network.py b/tests/integration_tests/test_network.py index c19b81e5..01842d02 100644 --- a/tests/integration_tests/test_network.py +++ b/tests/integration_tests/test_network.py @@ 
-1,8 +1,9 @@ from ipaddress import IPv4Address, IPv6Address +from typing import Callable + import pytest from clickhouse_connect.driver import Client -from tests.integration_tests.conftest import TestConfig # A collection of diverse IPv6 addresses for testing IPV6_TEST_CASES = [ @@ -15,37 +16,13 @@ ] -# pylint: disable=attribute-defined-outside-init -class TestIPv6: - """Integration tests for ClickHouse IPv6 data type handling.""" - - client: Client - table_name: str = "ipv6_integration_test" - - @pytest.fixture(autouse=True) - def setup_teardown(self, test_config: TestConfig, test_client: Client): - """Create the test table before each test and drop it after.""" - self.config = test_config - self.client = test_client - self.client.command(f"DROP TABLE IF EXISTS {self.table_name}") - self.client.command( - f""" - CREATE TABLE {self.table_name} ( - id UInt32, - ip_addr IPv6, - ip_addr_nullable Nullable(IPv6) - ) ENGINE = MergeTree ORDER BY id - """ - ) - yield - self.client.command(f"DROP TABLE IF EXISTS {self.table_name}") - - def test_ipv6_round_trip(self): - """Tests that various IPv6 addresses can be inserted as objects and read back correctly.""" +def test_ipv6_round_trip(param_client: Client, call, table_context: Callable): + """Test that various IPv6 addresses can be inserted as objects and read back correctly.""" + with table_context("ipv6_round_trip_test", ["id UInt32", "ip_addr IPv6", "ip_addr_nullable Nullable(IPv6)"], order_by="id"): data = [[i, ip, ip] for i, ip in enumerate(IPV6_TEST_CASES)] - self.client.insert(self.table_name, data) + call(param_client.insert, "ipv6_round_trip_test", data) - result = self.client.query(f"SELECT * FROM {self.table_name} ORDER BY id") + result = call(param_client.query, "SELECT * FROM ipv6_round_trip_test ORDER BY id") assert result.row_count == len(IPV6_TEST_CASES) for i, ip in enumerate(IPV6_TEST_CASES): @@ -53,9 +30,9 @@ def test_ipv6_round_trip(self): assert result.result_rows[i][2] == ip assert 
isinstance(result.result_rows[i][1], IPv6Address) - def test_ipv4_mapping_and_promotion(self): - """Tests that plain IPv4 strings/objects are correctly promoted to IPv4-mapped - IPv6 addresses on insertion and read back correctly.""" +def test_ipv4_mapping_and_promotion(param_client: Client, call, table_context: Callable): + """Test that plain IPv4 strings/objects are correctly promoted to IPv4-mapped IPv6 addresses.""" + with table_context("ipv4_promotion_test", ["id UInt32", "ip_addr IPv6", "ip_addr_nullable Nullable(IPv6)"], order_by="id"): test_ips = [ "198.51.100.1", IPv4Address("203.0.113.255"), @@ -68,45 +45,44 @@ def test_ipv4_mapping_and_promotion(self): ] data = [[i, ip, None] for i, ip in enumerate(test_ips)] - self.client.insert(self.table_name, data) + call(param_client.insert, "ipv4_promotion_test", data) - result = self.client.query( - f"SELECT id, ip_addr FROM {self.table_name} ORDER BY id" - ) + result = call(param_client.query, "SELECT id, ip_addr FROM ipv4_promotion_test ORDER BY id") assert result.row_count == len(test_ips) for i, ip in enumerate(expected_ips): assert isinstance(result.result_rows[i][1], IPv6Address) assert result.result_rows[i][1] == ip - def test_null_handling(self): - """Tests inserting and retrieving NULL values in an IPv6 column.""" +def test_ipv6_null_handling(param_client: Client, call, table_context: Callable): + """Test inserting and retrieving NULL values in an IPv6 column.""" + with table_context("ipv6_null_test", ["id UInt32", "ip_addr IPv6", "ip_addr_nullable Nullable(IPv6)"], order_by="id"): data = [[1, "::1", None], [2, "2001:db8::", "2001:db8::"]] - self.client.insert(self.table_name, data) + call(param_client.insert, "ipv6_null_test", data) - result = self.client.query( - f"SELECT id, ip_addr_nullable FROM {self.table_name} ORDER BY id" - ) + result = call(param_client.query, "SELECT id, ip_addr_nullable FROM ipv6_null_test ORDER BY id") assert result.row_count == 2 assert result.result_rows[0][1] is None assert 
result.result_rows[1][1] == IPv6Address("2001:db8::") - def test_read_as_string(self): - """Tests reading IPv6 values as strings using the toString() function.""" +def test_ipv6_read_as_string(param_client: Client, call, table_context: Callable): + """Test reading IPv6 values as strings using the toString() function.""" + with table_context("ipv6_string_test", ["id UInt32", "ip_addr IPv6", "ip_addr_nullable Nullable(IPv6)"], order_by="id"): ip = IPV6_TEST_CASES[0] - self.client.insert(self.table_name, [[1, ip, None]]) + call(param_client.insert, "ipv6_string_test", [[1, ip, None]]) - result = self.client.query(f"SELECT toString(ip_addr) FROM {self.table_name}") + result = call(param_client.query, "SELECT toString(ip_addr) FROM ipv6_string_test") assert result.row_count == 1 read_val = result.result_rows[0][0] assert isinstance(read_val, str) assert read_val == str(ip) - def test_insert_invalid_ipv6_fails(self): - """Tests that the client correctly rejects an invalid IPv6 string.""" +def test_ipv6_insert_invalid_fails(param_client: Client, call, table_context: Callable): + """Test that the client correctly rejects an invalid IPv6 string.""" + with table_context("ipv6_invalid_test", ["id UInt32", "ip_addr IPv6", "ip_addr_nullable Nullable(IPv6)"], order_by="id"): with pytest.raises(ValueError) as excinfo: - self.client.insert(self.table_name, [[1, "not a valid ip address", None]]) + call(param_client.insert, "ipv6_invalid_test", [[1, "not a valid ip address", None]]) assert "Failed to parse 'not a valid ip address'" in str(excinfo.value) diff --git a/tests/integration_tests/test_numeric.py b/tests/integration_tests/test_numeric.py index 677549e9..c3160ead 100644 --- a/tests/integration_tests/test_numeric.py +++ b/tests/integration_tests/test_numeric.py @@ -1,96 +1,67 @@ +from typing import Callable import pytest -from clickhouse_connect.driver import Client -from tests.integration_tests.conftest import TestConfig - - -# pylint: disable=duplicate-code -# pylint: 
disable=attribute-defined-outside-init -class TestBFloat16: - """Integration tests for ClickHouse BFloat16 data type handling.""" - - client: Client - table_name: str = "bf16_integration_test" - - # pylint: disable=no-self-use - @pytest.fixture(scope="class", autouse=True) - def check_version(self, test_client: Client): - """Skips the entire class if the server version is too old.""" - if not test_client.min_version("24.11"): - pytest.skip( - f"BFloat16 type not supported in ClickHouse version {test_client.server_version}" - ) - - @pytest.fixture(autouse=True) - def setup_teardown(self, test_config: TestConfig, test_client: Client): - """Create the test table before each test and drop it after.""" - self.config = test_config - self.client = test_client - self.client.command(f"DROP TABLE IF EXISTS {self.table_name}") - self.client.command( - f""" - CREATE TABLE {self.table_name} ( - id UInt32, - bfloat16 BFloat16, - bfloat16_nullable Nullable(BFloat16) - ) ENGINE = MergeTree ORDER BY id - """ - ) - yield - self.client.command(f"DROP TABLE IF EXISTS {self.table_name}") - - def test_bf16_round_trip(self): - """Basic round trip test with precision loss.""" + +def test_bfloat16_round_trip(param_client, call, table_context: Callable): + """Test BFloat16 data type with precision loss on round trip.""" + if not param_client.min_version("24.11"): + pytest.skip(f"BFloat16 type not supported in ClickHouse version {param_client.server_version}") + + with table_context('bf16_test', ['id UInt32', 'bfloat16 BFloat16', 'bfloat16_nullable Nullable(BFloat16)'], + order_by='id'): input_data = [[0, 3.141592, -2.71828], [1, 3.141592, -2.71828]] expected = [[0, 3.140625, -2.703125], [1, 3.140625, -2.703125]] - self.client.insert(self.table_name, input_data) + call(param_client.insert, 'bf16_test', input_data) - result = self.client.query(f"SELECT * FROM {self.table_name} ORDER BY id") + result = call(param_client.query, "SELECT * FROM bf16_test ORDER BY id") assert result.row_count == 
len(input_data) for result_row, expected_row in zip(result.result_rows, expected): assert list(result_row) == expected_row assert isinstance(result_row[1], float) - def test_bf16_nullable_round_trip(self): - """Basic round nullable trip test with precision loss.""" +def test_bfloat16_nullable_round_trip(param_client, call, table_context: Callable): + """Test BFloat16 nullable column with precision loss.""" + if not param_client.min_version("24.11"): + pytest.skip(f"BFloat16 type not supported in ClickHouse version {param_client.server_version}") + + with table_context('bf16_nullable_test', ['id UInt32', 'bfloat16 BFloat16', 'bfloat16_nullable Nullable(BFloat16)'], + order_by='id'): input_data = [[0, 3.141592, None], [1, 3.141592, -2.71828]] expected = [[0, 3.140625, None], [1, 3.140625, -2.703125]] - self.client.insert(self.table_name, input_data) + call(param_client.insert, 'bf16_nullable_test', input_data) - result = self.client.query(f"SELECT * FROM {self.table_name} ORDER BY id") + result = call(param_client.query, "SELECT * FROM bf16_nullable_test ORDER BY id") assert result.row_count == len(input_data) for result_row, expected_row in zip(result.result_rows, expected): assert list(result_row) == expected_row assert isinstance(result_row[1], float) - def test_bf16_empty_and_all_null_inserts(self): - """Tests inserting no rows, and inserting rows with all-null columns.""" - self.client.insert(self.table_name, []) - result = self.client.query(f"SELECT count() FROM {self.table_name}") +def test_bfloat16_empty_and_all_null_inserts(param_client, call, table_context: Callable): + """Test BFloat16 with empty inserts and all-null columns.""" + if not param_client.min_version("24.11"): + pytest.skip(f"BFloat16 type not supported in ClickHouse version {param_client.server_version}") + + with table_context('bf16_empty_test', ['id UInt32', 'bfloat16 BFloat16', 'bfloat16_nullable Nullable(BFloat16)'], + order_by='id'): + # Test empty insert + call(param_client.insert, 
'bf16_empty_test', []) + result = call(param_client.query, "SELECT count() FROM bf16_empty_test") assert result.result_rows[0][0] == 0 input_data = [[0, 3.141592, None], [1, -2.71828, None]] expected = [[0, 3.140625, None], [1, -2.703125, None]] - self.client.insert(self.table_name, input_data) + call(param_client.insert, 'bf16_empty_test', input_data) - result = self.client.query(f"SELECT * FROM {self.table_name} ORDER BY id") + result = call(param_client.query, "SELECT * FROM bf16_empty_test ORDER BY id") assert result.row_count == len(input_data) for result_row, expected_row in zip(result.result_rows, expected): assert list(result_row) == expected_row -class TestSpecialIntervalTypes: - """Integration tests for ClickHouse special interval type handling.""" - - client: Client - - @pytest.fixture(autouse=True) - def setup_teardown(self, test_client: Client): - self.client = test_client - - def test_interval_selects_work(self): - result = self.client.query("SELECT INTERVAL 30 DAY") - assert result.result_rows[0][0] == 30 +def test_interval_selects(param_client, call): + """Test that interval type selects work correctly.""" + result = call(param_client.query, "SELECT INTERVAL 30 DAY") + assert result.result_rows[0][0] == 30 diff --git a/tests/integration_tests/test_numpy.py b/tests/integration_tests/test_numpy.py index 2d4844a9..299d1bb4 100644 --- a/tests/integration_tests/test_numpy.py +++ b/tests/integration_tests/test_numpy.py @@ -18,30 +18,30 @@ pytestmark = pytest.mark.skipif(np is None, reason='Numpy package not installed') -def test_numpy_dates(test_client: Client, table_context: Callable): +def test_numpy_dates(param_client: Client, call, table_context: Callable): np_array = np.array(dt_ds, dtype='datetime64[s]').reshape(-1, 1) source_arr = np_array.copy() with table_context('test_numpy_dates', dt_ds_columns, dt_ds_types): - test_client.insert('test_numpy_dates', np_array) - new_np_array = test_client.query_np('SELECT * FROM test_numpy_dates') + 
call(param_client.insert, 'test_numpy_dates', np_array) + new_np_array = call(param_client.query_np, 'SELECT * FROM test_numpy_dates') assert np.array_equal(np_array, new_np_array) assert np.array_equal(source_arr, np_array) -def test_invalid_date(test_client): +def test_invalid_date(param_client: Client, call): try: sql = "SELECT cast(now64(), 'DateTime64(1)')" - if not test_client.min_version('20'): + if not param_client.min_version('20'): sql = "SELECT cast(now(), 'DateTime')" - test_client.query_df(sql) + call(param_client.query_df, sql) except ProgrammingError as ex: assert 'milliseconds' in str(ex) -def test_numpy_record_type(test_client: Client, table_context: Callable): +def test_numpy_record_type(param_client: Client, call, table_context: Callable): dt_type = 'datetime64[ns]' ds_types = basic_ds_types - if not test_client.min_version('20'): + if not param_client.min_version('20'): dt_type = 'datetime64[s]' ds_types = basic_ds_types_ver19 @@ -49,18 +49,18 @@ def test_numpy_record_type(test_client: Client, table_context: Callable): source_arr = np_array.copy() np_array.dtype.names = basic_ds_columns with table_context('test_numpy_basic', basic_ds_columns, ds_types): - test_client.insert('test_numpy_basic', np_array) - new_np_array = test_client.query_np('SELECT * FROM test_numpy_basic', max_str_len=20) + call(param_client.insert, 'test_numpy_basic', np_array) + new_np_array = call(param_client.query_np, 'SELECT * FROM test_numpy_basic', max_str_len=20) assert np.array_equal(np_array, new_np_array) - empty_np_array = test_client.query_np("SELECT * FROM test_numpy_basic WHERE key = 'NOT A KEY' ") + empty_np_array = call(param_client.query_np, "SELECT * FROM test_numpy_basic WHERE key = 'NOT A KEY' ") assert len(empty_np_array) == 0 assert np.array_equal(source_arr, np_array) -def test_numpy_object_type(test_client: Client, table_context: Callable): +def test_numpy_object_type(param_client: Client, call, table_context: Callable): dt_type = 'datetime64[ns]' 
ds_types = basic_ds_types - if not test_client.min_version('20'): + if not param_client.min_version('20'): dt_type = 'datetime64[s]' ds_types = basic_ds_types_ver19 @@ -68,88 +68,114 @@ def test_numpy_object_type(test_client: Client, table_context: Callable): np_array.dtype.names = basic_ds_columns source_arr = np_array.copy() with table_context('test_numpy_basic', basic_ds_columns, ds_types): - test_client.insert('test_numpy_basic', np_array) - new_np_array = test_client.query_np('SELECT * FROM test_numpy_basic') + call(param_client.insert, 'test_numpy_basic', np_array) + new_np_array = call(param_client.query_np, 'SELECT * FROM test_numpy_basic') assert np.array_equal(np_array, new_np_array) assert np.array_equal(source_arr, np_array) -def test_numpy_nulls(test_client: Client, table_context: Callable): +def test_numpy_nulls(param_client: Client, call, table_context: Callable): np_types = [(col_name, 'O') for col_name in null_ds_columns] np_array = np.rec.fromrecords(null_ds, dtype=np_types) source_arr = np_array.copy() with table_context('test_numpy_nulls', null_ds_columns, null_ds_types): - test_client.insert('test_numpy_nulls', np_array) - new_np_array = test_client.query_np('SELECT * FROM test_numpy_nulls', use_none=True) + call(param_client.insert, 'test_numpy_nulls', np_array) + new_np_array = call(param_client.query_np, 'SELECT * FROM test_numpy_nulls', use_none=True) assert list_equal(np_array.tolist(), new_np_array.tolist()) assert list_equal(source_arr.tolist(), np_array.tolist()) -def test_numpy_matrix(test_client: Client, table_context: Callable): +def test_numpy_matrix(param_client: Client, call, table_context: Callable): source = [25000, -37283, 4000, 25770, 40032, 33002, 73086, -403882, 57723, 77382, 1213477, 2, 0, 5777732, 99827616] source_array = np.array(source, dtype='int32') matrix = source_array.reshape((5, 3)) matrix_copy = matrix.copy() with table_context('test_numpy_matrix', ['col1 Int32', 'col2 Int32', 'col3 Int32']): - 
test_client.insert('test_numpy_matrix', matrix) - py_result = test_client.query('SELECT * FROM test_numpy_matrix').result_set + call(param_client.insert, 'test_numpy_matrix', matrix) + py_result = call(param_client.query, 'SELECT * FROM test_numpy_matrix').result_set assert list(py_result[1]) == [25000, -37283, 4000] - numpy_result = test_client.query_np('SELECT * FROM test_numpy_matrix') + numpy_result = call(param_client.query_np, 'SELECT * FROM test_numpy_matrix') assert list(numpy_result[1]) == list(py_result[1]) - test_client.command('TRUNCATE TABLE test_numpy_matrix') - numpy_result = test_client.query_np('SELECT * FROM test_numpy_matrix') + call(param_client.command, 'TRUNCATE TABLE test_numpy_matrix') + numpy_result = call(param_client.query_np, 'SELECT * FROM test_numpy_matrix') assert np.size(numpy_result) == 0 assert np.array_equal(matrix, matrix_copy) -def test_numpy_bigint_matrix(test_client: Client, table_context: Callable): +def test_numpy_bigint_matrix(param_client: Client, call, table_context: Callable): source = [25000, -37283, 4000, 25770, 40032, 33002, 73086, -403882, 57723, 77382, 1213477, 2, 0, 5777732, 99827616] source_array = np.array(source, dtype='int64') matrix = source_array.reshape((5, 3)) matrix_copy = matrix.copy() columns = ['col1 UInt256', 'col2 Int64', 'col3 Int128'] - if not test_client.min_version('21'): + if not param_client.min_version('21'): columns = ['col1 UInt64', 'col2 Int64', 'col3 Int64'] with table_context('test_numpy_bigint_matrix', columns): - test_client.insert('test_numpy_bigint_matrix', matrix) - py_result = test_client.query('SELECT * FROM test_numpy_bigint_matrix').result_set + call(param_client.insert, 'test_numpy_bigint_matrix', matrix) + py_result = call(param_client.query, 'SELECT * FROM test_numpy_bigint_matrix').result_set assert list(py_result[1]) == [25000, -37283, 4000] - numpy_result = test_client.query_np('SELECT * FROM test_numpy_bigint_matrix') + numpy_result = call(param_client.query_np, 'SELECT * 
FROM test_numpy_bigint_matrix') assert list(numpy_result[1]) == list(py_result[1]) assert np.array_equal(matrix, matrix_copy) -def test_numpy_bigint_object(test_client: Client, table_context: Callable): +def test_numpy_bigint_object(param_client: Client, call, table_context: Callable): source = [('key1', 347288, datetime.datetime(1999, 10, 15, 12, 3, 44)), ('key2', '348147832478', datetime.datetime.now())] np_array = np.array(source, dtype='O,uint64,datetime64[s]') source_arr = np_array.copy() columns = ['key String', 'big_value UInt256', 'dt DateTime'] - if not test_client.min_version('21'): + if not param_client.min_version('21'): columns = ['key String', 'big_value UInt64', 'dt DateTime'] with table_context('test_numpy_bigint_object', columns): - test_client.insert('test_numpy_bigint_object', np_array) - py_result = test_client.query('SELECT * FROM test_numpy_bigint_object').result_set + call(param_client.insert, 'test_numpy_bigint_object', np_array) + py_result = call(param_client.query, 'SELECT * FROM test_numpy_bigint_object').result_set assert list(py_result[0]) == list(source[0]) - numpy_result = test_client.query_np('SELECT * FROM test_numpy_bigint_object') + numpy_result = call(param_client.query_np, 'SELECT * FROM test_numpy_bigint_object') assert list(py_result[1]) == list(numpy_result[1]) assert np.array_equal(source_arr, np_array) -def test_numpy_streams(test_client: Client): - if not test_client.min_version('22'): - pytest.skip(f'generateRandom is not supported in this server version {test_client.server_version}') +def test_numpy_streams(param_client: Client, call, consume_stream): + if not param_client.min_version('22'): + pytest.skip(f'generateRandom is not supported in this server version {param_client.server_version}') runs = os.environ.get('CLICKHOUSE_CONNECT_TEST_FUZZ', '250') for _ in range(int(runs) // 2): query_rows = random.randint(0, 5000) + 20000 stream_count = 0 row_count = 0 query = random_query(query_rows) - stream = 
test_client.query_np_stream(query, settings={'max_block_size': 5000}) - with stream: - for np_array in stream: + stream = call(param_client.query_np_stream, query, settings={'max_block_size': 5000}) + + def process(np_array): + nonlocal stream_count, row_count + stream_count += 1 + row_count += np_array.shape[0] + + consume_stream(stream, process) + + assert row_count == query_rows + assert stream_count > 2 + + +@pytest.mark.asyncio +async def test_numpy_streams_async(test_native_async_client): + """Async-only numpy streaming test.""" + if not test_native_async_client.min_version("22"): + pytest.skip(f'generateRandom is not supported in this server version {test_native_async_client.server_version}') + + runs = os.environ.get("CLICKHOUSE_CONNECT_TEST_FUZZ", "250") + for _ in range(int(runs) // 2): + query_rows = random.randint(0, 5000) + 20000 + stream_count = 0 + row_count = 0 + query = random_query(query_rows) + + stream = await test_native_async_client.query_np_stream(query, settings={"max_block_size": 5000}) + async with stream: + async for np_array in stream: stream_count += 1 row_count += np_array.shape[0] assert row_count == query_rows diff --git a/tests/integration_tests/test_pandas.py b/tests/integration_tests/test_pandas.py index 622fcfb1..b136c835 100644 --- a/tests/integration_tests/test_pandas.py +++ b/tests/integration_tests/test_pandas.py @@ -16,23 +16,23 @@ pytestmark = pytest.mark.skipif(pd is None, reason='Pandas package not installed') -def test_pandas_basic(test_client: Client, test_table_engine: str): - df = test_client.query_df('SELECT * FROM system.tables') +def test_pandas_basic(param_client: Client, call, test_table_engine: str): + df = call(param_client.query_df, 'SELECT * FROM system.tables') source_df = df.copy() - test_client.command('DROP TABLE IF EXISTS test_system_insert_pd') - test_client.command(f'CREATE TABLE test_system_insert_pd as system.tables Engine {test_table_engine}' - f' ORDER BY (database, name)') - 
test_client.insert_df('test_system_insert_pd', df) - new_df = test_client.query_df('SELECT * FROM test_system_insert_pd') - test_client.command('DROP TABLE IF EXISTS test_system_insert_pd') + call(param_client.command, 'DROP TABLE IF EXISTS test_system_insert_pd') + call(param_client.command, f'CREATE TABLE test_system_insert_pd as system.tables Engine {test_table_engine}' + f' ORDER BY (database, name)') + call(param_client.insert_df, 'test_system_insert_pd', df) + new_df = call(param_client.query_df, 'SELECT * FROM test_system_insert_pd') + call(param_client.command, 'DROP TABLE IF EXISTS test_system_insert_pd') assert new_df.columns.all() == df.columns.all() assert df.equals(source_df) - df = test_client.query_df("SELECT * FROM system.tables WHERE engine = 'not_a_thing'") + df = call(param_client.query_df, "SELECT * FROM system.tables WHERE engine = 'not_a_thing'") assert len(df) == 0 assert isinstance(df, pd.DataFrame) -def test_pandas_nulls(test_client: Client, table_context: Callable): +def test_pandas_nulls(param_client: Client, call, table_context: Callable): df = pd.DataFrame(null_ds, columns=['key', 'num', 'flt', 'str', 'dt', 'd']) source_df = df.copy() insert_columns = ['key', 'num', 'flt', 'str', 'dt', 'day_col'] @@ -40,13 +40,13 @@ def test_pandas_nulls(test_client: Client, table_context: Callable): 'str String', 'dt DateTime', 'day_col Date']): with pytest.raises(DataError): - test_client.insert_df('test_pandas_nulls_bad', df, column_names=insert_columns) + call(param_client.insert_df, 'test_pandas_nulls_bad', df, column_names=insert_columns) with table_context('test_pandas_nulls_good', ['key String', 'num Nullable(Int32)', 'flt Nullable(Float32)', 'str Nullable(String)', "dt Nullable(DateTime('America/Denver'))", 'day_col Nullable(Date)']): - test_client.insert_df('test_pandas_nulls_good', df, column_names=insert_columns) - result_df = test_client.query_df('SELECT * FROM test_pandas_nulls_good') + call(param_client.insert_df, 
'test_pandas_nulls_good', df, column_names=insert_columns) + result_df = call(param_client.query_df, 'SELECT * FROM test_pandas_nulls_good') assert result_df.iloc[0]['num'] == 1000 assert pd.isna(result_df.iloc[2]['num']) assert result_df.iloc[1]['day_col'] == pd.Timestamp(year=1976, month=5, day=5) @@ -56,18 +56,18 @@ def test_pandas_nulls(test_client: Client, table_context: Callable): assert pd.isna(result_df.iloc[2]['num']) assert pd.isnull(result_df.iloc[3]['flt']) assert result_df['num'].dtype.name == 'Int32' - if test_client.protocol_version: + if param_client.protocol_version: assert isinstance(result_df['dt'].dtype, pd.core.dtypes.dtypes.DatetimeTZDtype) assert result_df.iloc[2]['str'] == 'value3' assert df.equals(source_df) -def test_pandas_all_null_float(test_client: Client): - df = test_client.query_df("SELECT number, cast(NULL, 'Nullable(Float64)') as flt FROM numbers(500)") +def test_pandas_all_null_float(param_client: Client, call): + df = call(param_client.query_df, "SELECT number, cast(NULL, 'Nullable(Float64)') as flt FROM numbers(500)") assert df['flt'].dtype.name == 'float64' -def test_pandas_csv(test_client: Client, table_context: Callable): +def test_pandas_csv(param_client: Client, call, table_context: Callable): csv = """ key,num,flt,str,dt,d key1,555,25.44,string1,2022-11-22 15:00:44,2001-02-14 @@ -78,32 +78,32 @@ def test_pandas_csv(test_client: Client, table_context: Callable): df = df[['num', 'flt']].astype('Float32') source_df = df.copy() with table_context('test_pandas_csv', null_ds_columns, null_ds_types): - test_client.insert_df('test_pandas_csv', df) - result_df = test_client.query_df('SELECT * FROM test_pandas_csv') + call(param_client.insert_df, 'test_pandas_csv', df) + result_df = call(param_client.query_df, 'SELECT * FROM test_pandas_csv') assert np.isclose(result_df.iloc[0]['flt'], 25.44) assert pd.isna(result_df.iloc[1]['flt']) - result_df = test_client.query('SELECT * FROM test_pandas_csv') + result_df = 
call(param_client.query, 'SELECT * FROM test_pandas_csv') assert pd.isna(result_df.result_set[1][2]) assert df.equals(source_df) -def test_pandas_context_inserts(test_client: Client, table_context: Callable): +def test_pandas_context_inserts(param_client: Client, call, table_context: Callable): with table_context('test_pandas_multiple', null_ds_columns, null_ds_types): df = pd.DataFrame(null_ds, columns=null_ds_columns) source_df = df.copy() - insert_context = test_client.create_insert_context('test_pandas_multiple', df.columns) + insert_context = call(param_client.create_insert_context, 'test_pandas_multiple', df.columns) insert_context.data = df - test_client.data_insert(insert_context) - assert test_client.command('SELECT count() FROM test_pandas_multiple') == 4 + call(param_client.data_insert, insert_context) + assert call(param_client.command, 'SELECT count() FROM test_pandas_multiple') == 4 next_df = pd.DataFrame( [['key4', -415, None, 'value4', datetime(2022, 7, 4, 15, 33, 4, 5233), date(1999, 12, 31)]], columns=null_ds_columns) - test_client.insert_df(df=next_df, context=insert_context) - assert test_client.command('SELECT count() FROM test_pandas_multiple') == 5 + call(param_client.insert_df, df=next_df, context=insert_context) + assert call(param_client.command, 'SELECT count() FROM test_pandas_multiple') == 5 assert df.equals(source_df) -def test_pandas_low_card(test_client: Client, table_context: Callable): +def test_pandas_low_card(param_client: Client, call, table_context: Callable): with table_context('test_pandas_low_card', ['key String', 'value LowCardinality(Nullable(String))', 'date_value LowCardinality(Nullable(DateTime))', @@ -117,8 +117,8 @@ def test_pandas_low_card(test_client: Client, table_context: Callable): ], columns=['key', 'value', 'date_value', 'int_value']) source_df = df.copy() - test_client.insert_df('test_pandas_low_card', df) - result_df = test_client.query_df('SELECT * FROM test_pandas_low_card', use_none=True) + 
call(param_client.insert_df, 'test_pandas_low_card', df) + result_df = call(param_client.query_df, 'SELECT * FROM test_pandas_low_card', use_none=True) assert result_df.iloc[0]['value'] == 'test_string_0' assert result_df.iloc[1]['value'] == 'test_string_1' assert result_df.iloc[0]['date_value'] == pd.Timestamp(2022, 10, 15, 4, 25) @@ -129,39 +129,39 @@ def test_pandas_low_card(test_client: Client, table_context: Callable): assert df.equals(source_df) -def test_pandas_large_types(test_client: Client, table_context: Callable): +def test_pandas_large_types(param_client: Client, call, table_context: Callable): columns = ['key String', 'value Int256', 'u_value UInt256' ] key2_value = 30000000000000000000000000000000000 - if not test_client.min_version('21'): + if not param_client.min_version('21'): columns = ['key String', 'value Int64'] key2_value = 3000000000000000000 with table_context('test_pandas_big_int', columns): df = pd.DataFrame([['key1', 2000, 50], ['key2', key2_value, 70], ['key3', -2350, 70]], columns=['key', 'value', 'u_value']) source_df = df.copy() - test_client.insert_df('test_pandas_big_int', df) - result_df = test_client.query_df('SELECT * FROM test_pandas_big_int') + call(param_client.insert_df, 'test_pandas_big_int', df) + result_df = call(param_client.query_df, 'SELECT * FROM test_pandas_big_int') assert result_df.iloc[0]['value'] == 2000 assert result_df.iloc[1]['value'] == key2_value assert df.equals(source_df) -def test_pandas_enums(test_client: Client, table_context: Callable): +def test_pandas_enums(param_client: Client, call, table_context: Callable): columns = ['key String', "value Enum8('Moscow' = 0, 'Rostov' = 1, 'Kiev' = 2)", "null_value Nullable(Enum8('red'=0,'blue'=5,'yellow'=10))"] with table_context('test_pandas_enums', columns): df = pd.DataFrame([['key1', 1, 0], ['key2', 0, None]], columns=['key', 'value', 'null_value']) source_df = df.copy() - test_client.insert_df('test_pandas_enums', df) - result_df = 
test_client.query_df('SELECT * FROM test_pandas_enums ORDER BY key') + call(param_client.insert_df, 'test_pandas_enums', df) + result_df = call(param_client.query_df, 'SELECT * FROM test_pandas_enums ORDER BY key') assert result_df.iloc[0]['value'] == 'Rostov' assert result_df.iloc[1]['value'] == 'Moscow' assert result_df.iloc[1]['null_value'] is None assert result_df.iloc[0]['null_value'] == 'red' assert df.equals(source_df) df = pd.DataFrame([['key3', 'Rostov', 'blue'], ['key4', 'Moscow', None]], columns=['key', 'value', 'null_value']) - test_client.insert_df('test_pandas_enums', df) - result_df = test_client.query_df('SELECT * FROM test_pandas_enums ORDER BY key') + call(param_client.insert_df, 'test_pandas_enums', df) + result_df = call(param_client.query_df, 'SELECT * FROM test_pandas_enums ORDER BY key') assert result_df.iloc[2]['key'] == 'key3' assert result_df.iloc[2]['value'] == 'Rostov' assert result_df.iloc[3]['value'] == 'Moscow' @@ -169,9 +169,9 @@ def test_pandas_enums(test_client: Client, table_context: Callable): assert result_df.iloc[3]['null_value'] is None -def test_pandas_datetime64(test_client: Client, table_context: Callable): - if not test_client.min_version('20'): - pytest.skip(f'DateTime64 not supported in this server version {test_client.server_version}') +def test_pandas_datetime64(param_client: Client, call, table_context: Callable): + if not param_client.min_version('20'): + pytest.skip(f'DateTime64 not supported in this server version {param_client.server_version}') nano_timestamp = pd.Timestamp(1992, 11, 6, 12, 50, 40, 7420, 44) milli_timestamp = pd.Timestamp(2022, 5, 3, 10, 44, 10, 55000) chicago_timestamp = milli_timestamp.tz_localize('America/Chicago') @@ -184,8 +184,8 @@ def test_pandas_datetime64(test_client: Client, table_context: Callable): ['key2', nano_timestamp, milli_timestamp, chicago_timestamp]], columns=['key', 'nanos', 'millis', 'chicago']) source_df = df.copy() - test_client.insert_df('test_pandas_dt64', df) - 
result_df = test_client.query_df('SELECT * FROM test_pandas_dt64') + call(param_client.insert_df, 'test_pandas_dt64', df) + result_df = call(param_client.query_df, 'SELECT * FROM test_pandas_dt64') assert result_df.iloc[0]['nanos'] == now assert result_df.iloc[1]['nanos'] == nano_timestamp assert result_df.iloc[1]['millis'] == milli_timestamp @@ -194,37 +194,40 @@ def test_pandas_datetime64(test_client: Client, table_context: Callable): test_dt = np.array(['2017-11-22 15:42:58.270000+00:00'][0]) assert df.equals(source_df) df = pd.DataFrame([['key3', pd.to_datetime(test_dt)]], columns=['key', 'nanos']) - test_client.insert_df('test_pandas_dt64', df) - result_df = test_client.query_df('SELECT * FROM test_pandas_dt64 WHERE key = %s', parameters=('key3',)) + call(param_client.insert_df, 'test_pandas_dt64', df) + result_df = call(param_client.query_df, 'SELECT * FROM test_pandas_dt64 WHERE key = %s', parameters=('key3',)) assert result_df.iloc[0]['nanos'].second == 58 -def test_pandas_streams(test_client: Client): - if not test_client.min_version('22'): - pytest.skip(f'generateRandom is not supported in this server version {test_client.server_version}') +def test_pandas_streams(param_client: Client, call, consume_stream): + if not param_client.min_version('22'): + pytest.skip(f'generateRandom is not supported in this server version {param_client.server_version}') runs = os.environ.get('CLICKHOUSE_CONNECT_TEST_FUZZ', '250') for _ in range(int(runs) // 2): query_rows = random.randint(0, 5000) + 20000 stream_count = 0 row_count = 0 query = random_query(query_rows, date32=False) - stream = test_client.query_df_stream(query, settings={'max_block_size': 5000}) - with stream: - for df in stream: - stream_count += 1 - row_count += len(df) + stream = call(param_client.query_df_stream, query, settings={'max_block_size': 5000}) + + def process(df): + nonlocal stream_count, row_count + stream_count += 1 + row_count += len(df) + + consume_stream(stream, process) assert row_count == 
query_rows assert stream_count > 2 -def test_pandas_date(test_client: Client, table_context:Callable): +def test_pandas_date(param_client: Client, call, table_context: Callable): with table_context('test_pandas_date', ['key UInt32', 'dt Date', 'null_dt Nullable(Date)']): df = pd.DataFrame([[1, pd.Timestamp(1992, 10, 15), pd.Timestamp(2023, 5, 4)], [2, pd.Timestamp(2088, 1, 31), pd.NaT], [3, pd.Timestamp(1971, 4, 15), pd.Timestamp(2101, 12, 31)]], columns=['key', 'dt', 'null_dt']) - test_client.insert_df('test_pandas_date', df) - result_df = test_client.query_df('SELECT * FROM test_pandas_date') + call(param_client.insert_df, 'test_pandas_date', df) + result_df = call(param_client.query_df, 'SELECT * FROM test_pandas_date') assert result_df.iloc[0]['dt'] == pd.Timestamp(1992, 10, 15) assert result_df.iloc[1]['dt'] == pd.Timestamp(2088, 1, 31) assert result_df.iloc[0]['null_dt'] == pd.Timestamp(2023, 5, 4) @@ -232,14 +235,14 @@ def test_pandas_date(test_client: Client, table_context:Callable): assert result_df.iloc[2]['null_dt'] == pd.Timestamp(2101, 12, 31) -def test_pandas_date32(test_client: Client, table_context:Callable): +def test_pandas_date32(param_client: Client, call, table_context: Callable): with table_context('test_pandas_date32', ['key UInt32', 'dt Date32', 'null_dt Nullable(Date32)']): df = pd.DataFrame([[1, pd.Timestamp(1992, 10, 15), pd.Timestamp(2023, 5, 4)], [2, pd.Timestamp(2088, 1, 31), pd.NaT], [3, pd.Timestamp(1968, 4, 15), pd.Timestamp(2101, 12, 31)]], columns=['key', 'dt', 'null_dt']) - test_client.insert_df('test_pandas_date32', df) - result_df = test_client.query_df('SELECT * FROM test_pandas_date32') + call(param_client.insert_df, 'test_pandas_date32', df) + result_df = call(param_client.query_df, 'SELECT * FROM test_pandas_date32') assert result_df.iloc[1]['dt'] == pd.Timestamp(2088, 1, 31) assert result_df.iloc[0]['dt'] == pd.Timestamp(1992, 10, 15) assert result_df.iloc[0]['null_dt'] == pd.Timestamp(2023, 5, 4) @@ -248,15 +251,15 @@ def 
test_pandas_date32(test_client: Client, table_context:Callable): assert result_df.iloc[2]['dt'] == pd.Timestamp(1968, 4, 15) -def test_pandas_row_df(test_client: Client, table_context:Callable): +def test_pandas_row_df(param_client: Client, call, table_context: Callable): with table_context('test_pandas_row_df', ['key UInt64', 'dt DateTime64(6)', 'fs FixedString(5)']): df = pd.DataFrame({'key': [1, 2], 'dt': [pd.Timestamp(2023, 5, 4, 10, 20), pd.Timestamp(2023, 10, 15, 14, 50, 2, 4038)], 'fs': ['seven', 'bit']}) df = df.iloc[1:] source_df = df.copy() - test_client.insert_df('test_pandas_row_df', df) - result_df = test_client.query_df('SELECT * FROM test_pandas_row_df', column_formats={'fs': 'string'}) + call(param_client.insert_df, 'test_pandas_row_df', df) + result_df = call(param_client.query_df, 'SELECT * FROM test_pandas_row_df', column_formats={'fs': 'string'}) assert str(result_df.dtypes.iloc[2]) == 'string' assert result_df.iloc[0]['key'] == 2 assert result_df.iloc[0]['dt'] == pd.Timestamp(2023, 10, 15, 14, 50, 2, 4038) @@ -265,30 +268,30 @@ def test_pandas_row_df(test_client: Client, table_context:Callable): assert source_df.equals(df) -def test_pandas_null_strings(test_client: Client, table_context:Callable): +def test_pandas_null_strings(param_client: Client, call, table_context: Callable): with table_context('test_pandas_null_strings', ['id String', 'test_col LowCardinality(String)']): row = {'id': 'id', 'test_col': None} df = pd.DataFrame([row]) assert df['test_col'].isnull().values.all() with pytest.raises(DataError): - test_client.insert_df('test_pandas_null_strings', df) + call(param_client.insert_df, 'test_pandas_null_strings', df) row2 = {'id': 'id2', 'test_col': 'val'} df = pd.DataFrame([row, row2]) with pytest.raises(DataError): - test_client.insert_df('test_pandas_null_strings', df) + call(param_client.insert_df, 'test_pandas_null_strings', df) -def test_pandas_small_blocks(test_config: TestConfig, test_client: Client): +def 
test_pandas_small_blocks(test_config: TestConfig, param_client: Client, call): if test_config.cloud: pytest.skip('Skipping performance test in ClickHouse Cloud') - res = test_client.query_df('SELECT number, randomString(512) FROM numbers(1000000)', + res = call(param_client.query_df, 'SELECT number, randomString(512) FROM numbers(1000000)', settings={'max_block_size': 250}) assert len(res) == 1000000 -def test_pandas_string_to_df_insert(test_client: Client, table_context: Callable): - if not test_client.min_version('25.2'): - pytest.skip(f'Nullable(JSON) type not available in this version: {test_client.server_version}') +def test_pandas_string_to_df_insert(param_client: Client, call, table_context: Callable): + if not param_client.min_version('25.2'): + pytest.skip(f'Nullable(JSON) type not available in this version: {param_client.server_version}') with table_context( "test_pandas_string_to_df_insert", [ @@ -325,8 +328,8 @@ def test_pandas_string_to_df_insert(test_client: Client, table_context: Callable ] df = pd.DataFrame(data) - test_client.insert_df("test_pandas_string_to_df_insert", df) - result_df = test_client.query_df( + call(param_client.insert_df, "test_pandas_string_to_df_insert", df) + result_df = call(param_client.query_df, "SELECT * FROM test_pandas_string_to_df_insert ORDER BY id" ) @@ -336,10 +339,10 @@ def test_pandas_string_to_df_insert(test_client: Client, table_context: Callable def test_pandas_time( - test_config: TestConfig, test_client: Client, table_context: Callable + test_config: TestConfig, param_client: Client, call, table_context: Callable ): """Round trip test for Time types""" - if not test_client.min_version("25.6"): + if not param_client.min_version("25.6"): pytest.skip("Time and types require ClickHouse 25.6+") if test_config.cloud: @@ -348,7 +351,6 @@ def test_pandas_time( ) table_name = "time_tests" - test_client.command("SET enable_time_time64_type = 1") with table_context( table_name, @@ -356,6 +358,7 @@ def test_pandas_time( "t 
Time", "nt Nullable(Time)", ], + settings={"enable_time_time64_type": 1}, ): test_data = { "t": [timedelta(seconds=1), timedelta(seconds=2)], @@ -363,11 +366,10 @@ def test_pandas_time( } df = pd.DataFrame(test_data) - test_client.insert(table_name, df) + call(param_client.insert, table_name, df) - df_res = test_client.query_df(f"SELECT * FROM {table_name}") - print(df_res.to_string()) - print(df_res.dtypes) + df_res = call(param_client.query_df, f"SELECT * FROM {table_name}") + print(df_res) assert df_res["t"][0] == pd.Timedelta("0 days 00:00:01") assert df_res["t"][1] == pd.Timedelta("0 days 00:00:02") assert df_res["nt"][0] == pd.Timedelta("0 days 00:01:00") @@ -375,10 +377,10 @@ def test_pandas_time( def test_pandas_time64( - test_config: TestConfig, test_client: Client, table_context: Callable + test_config: TestConfig, param_client: Client, call, table_context: Callable ): """Round trip test for Time64 types""" - if not test_client.min_version("25.6"): + if not param_client.min_version("25.6"): pytest.skip("Time64 types require ClickHouse 25.6+") if test_config.cloud: @@ -387,7 +389,6 @@ def test_pandas_time64( ) table_name = "time64_tests" - test_client.command("SET enable_time_time64_type = 1") with table_context( table_name, @@ -399,6 +400,7 @@ def test_pandas_time64( "t64_9 Time64(9)", "nt64_9 Nullable(Time64(9))", ], + settings={"enable_time_time64_type": 1}, ): test_data = { "t64_3": [1, 2], @@ -410,10 +412,10 @@ def test_pandas_time64( } df = pd.DataFrame(test_data) - test_client.insert(table_name, df) + call(param_client.insert, table_name, df) # Make sure the df insert worked correctly - int_res = test_client.query( + int_res = call(param_client.query, f"SELECT * FROM {table_name}", query_formats={"Time": "int", "Time64": "int"}, ) @@ -421,7 +423,7 @@ def test_pandas_time64( assert rows[0] == (1, 45000, 10500000, 60, 1100000000, 30000500000) assert rows[1] == (2, None, 10000000, None, 600000000000, None) - df_res = test_client.query_df(f"SELECT * 
FROM {table_name}") + df_res = call(param_client.query_df, f"SELECT * FROM {table_name}") expected_row_0 = [ pd.Timedelta(t) for t in [ diff --git a/tests/integration_tests/test_pandas_compat.py b/tests/integration_tests/test_pandas_compat.py index 15ab9a71..5723d498 100644 --- a/tests/integration_tests/test_pandas_compat.py +++ b/tests/integration_tests/test_pandas_compat.py @@ -14,7 +14,7 @@ pytestmark = pytest.mark.skipif(pd is None, reason="Pandas package not installed") -def test_pandas_date_compat(test_client: Client, table_context: Callable): +def test_pandas_date_compat(param_client: Client, call, table_context: Callable): table_name = "test_date" with table_context( table_name, @@ -32,15 +32,15 @@ def test_pandas_date_compat(test_client: Client, table_context: Callable): ], columns=["key", "dt", "ndt"], ) - test_client.insert_df(table_name, df) + call(param_client.insert_df, table_name, df) set_setting(SETTING_NAME, False) - result_df = test_client.query_df(f"SELECT * FROM {table_name}") + result_df = call(param_client.query_df, f"SELECT * FROM {table_name}") for dt in list(result_df.dtypes)[1:]: assert "[ns]" in str(dt) set_setting(SETTING_NAME, True) - result_df = test_client.query_df(f"SELECT * FROM {table_name}") + result_df = call(param_client.query_df, f"SELECT * FROM {table_name}") if IS_PANDAS_2: res = "[s]" @@ -51,7 +51,7 @@ def test_pandas_date_compat(test_client: Client, table_context: Callable): assert res in str(dt) -def test_pandas_date32_compat(test_client: Client, table_context: Callable): +def test_pandas_date32_compat(param_client: Client, call, table_context: Callable): table_name = "test_date32" with table_context( table_name, @@ -69,15 +69,15 @@ def test_pandas_date32_compat(test_client: Client, table_context: Callable): ], columns=["key", "dt", "ndt"], ) - test_client.insert_df(table_name, df) + call(param_client.insert_df, table_name, df) set_setting(SETTING_NAME, False) - result_df = test_client.query_df(f"SELECT * FROM 
{table_name}") + result_df = call(param_client.query_df, f"SELECT * FROM {table_name}") for dt in list(result_df.dtypes)[1:]: assert "[ns]" in str(dt) set_setting(SETTING_NAME, True) - result_df = test_client.query_df(f"SELECT * FROM {table_name}") + result_df = call(param_client.query_df, f"SELECT * FROM {table_name}") if IS_PANDAS_2: res = "[s]" @@ -88,7 +88,7 @@ def test_pandas_date32_compat(test_client: Client, table_context: Callable): assert res in str(dt) -def test_pandas_datetime_compat(test_client: Client, table_context: Callable): +def test_pandas_datetime_compat(param_client: Client, call, table_context: Callable): table_name = "test_datetime" with table_context( table_name, @@ -106,15 +106,15 @@ def test_pandas_datetime_compat(test_client: Client, table_context: Callable): ], columns=["key", "dt", "ndt"], ) - test_client.insert_df(table_name, df) + call(param_client.insert_df, table_name, df) set_setting(SETTING_NAME, False) - result_df = test_client.query_df(f"SELECT * FROM {table_name}") + result_df = call(param_client.query_df, f"SELECT * FROM {table_name}") for dt in list(result_df.dtypes)[1:]: assert "[ns]" in str(dt) set_setting(SETTING_NAME, True) - result_df = test_client.query_df(f"SELECT * FROM {table_name}") + result_df = call(param_client.query_df, f"SELECT * FROM {table_name}") if IS_PANDAS_2: res = "[s]" @@ -125,7 +125,7 @@ def test_pandas_datetime_compat(test_client: Client, table_context: Callable): assert res in str(dt) -def test_pandas_datetime64_compat(test_client: Client, table_context: Callable): +def test_pandas_datetime64_compat(param_client: Client, call, table_context: Callable): table_name = "test_datetime64" with table_context( table_name, @@ -162,15 +162,15 @@ def test_pandas_datetime64_compat(test_client: Client, table_context: Callable): ], columns=["key", "dt3", "null_dt3", "dt6", "null_dt6", "dt9", "null_dt9"], ) - test_client.insert_df(table_name, df) + call(param_client.insert_df, table_name, df) 
set_setting(SETTING_NAME, False) - result_df = test_client.query_df(f"SELECT * FROM {table_name}") + result_df = call(param_client.query_df, f"SELECT * FROM {table_name}") for dt in list(result_df.dtypes)[1:]: assert "[ns]" in str(dt) set_setting(SETTING_NAME, True) - result_df = test_client.query_df(f"SELECT * FROM {table_name}") + result_df = call(param_client.query_df, f"SELECT * FROM {table_name}") if IS_PANDAS_2: dts = list(result_df.dtypes)[1:] @@ -183,18 +183,13 @@ def test_pandas_datetime64_compat(test_client: Client, table_context: Callable): assert res in str(dt) -def test_pandas_time_compat( - test_config: TestConfig, - test_client: Client, - table_context: Callable, -): - if not test_client.min_version("25.6"): +def test_pandas_time_compat(test_config: TestConfig, param_client: Client, call, table_context: Callable): + if not param_client.min_version("25.6"): pytest.skip("Time types require ClickHouse 25.6+") if test_config.cloud: pytest.skip("Time types require settings change, but settings are locked in cloud, skipping tests.") - test_client.command("SET enable_time_time64_type = 1") table_name = "test_time" with table_context( table_name, @@ -202,7 +197,7 @@ def test_pandas_time_compat( "key UInt8", "t Time", "null_t Nullable(Time)", - ], + ], settings={"enable_time_time64_type": 1} ): data = ( [1, datetime.timedelta(hours=5), 500], @@ -210,15 +205,15 @@ def test_pandas_time_compat( [3, -datetime.timedelta(minutes=45), 600], ) - test_client.insert(table_name, data) + call(param_client.insert, table_name, data) set_setting(SETTING_NAME, False) - result_df = test_client.query_df(f"SELECT * FROM {table_name}") + result_df = call(param_client.query_df, f"SELECT * FROM {table_name}") for dt in list(result_df.dtypes)[1:]: assert "[ns]" in str(dt) set_setting(SETTING_NAME, True) - result_df = test_client.query_df(f"SELECT * FROM {table_name}") + result_df = call(param_client.query_df, f"SELECT * FROM {table_name}") if IS_PANDAS_2: res = "[s]" @@ -229,18 
+224,13 @@ def test_pandas_time_compat( assert res in str(dt) -def test_pandas_time64_compat( - test_config: TestConfig, - test_client: Client, - table_context: Callable, -): - if not test_client.min_version("25.6"): +def test_pandas_time64_compat(test_config: TestConfig, param_client: Client, call, table_context: Callable): + if not param_client.min_version("25.6"): pytest.skip("Time64 types require ClickHouse 25.6+") if test_config.cloud: pytest.skip("Time types require settings change, but settings are locked in cloud, skipping tests.") - test_client.command("SET enable_time_time64_type = 1") table_name = "test_time64" with table_context( table_name, @@ -252,22 +242,22 @@ def test_pandas_time64_compat( "null_t6 Nullable(Time64(6))", "t9 Time64(9)", "null_t9 Nullable(Time64(9))", - ], + ], settings={"enable_time_time64_type": 1} ): data = ( [1, 1, 2, 3, 4, 5, 6], [2, 10, None, 30, None, 50, None], [3, 100, 200, 300, 400, 500, 600], ) - test_client.insert(table_name, data) + call(param_client.insert, table_name, data) set_setting(SETTING_NAME, False) - result_df = test_client.query_df(f"SELECT * FROM {table_name}") + result_df = call(param_client.query_df, f"SELECT * FROM {table_name}") for dt in list(result_df.dtypes)[1:]: assert "[ns]" in str(dt) set_setting(SETTING_NAME, True) - result_df = test_client.query_df(f"SELECT * FROM {table_name}") + result_df = call(param_client.query_df, f"SELECT * FROM {table_name}") if IS_PANDAS_2: dts = list(result_df.dtypes)[1:] @@ -280,7 +270,7 @@ def test_pandas_time64_compat( assert res in str(dt) -def test_pandas_query_df_arrow(test_client: Client, table_context: Callable): +def test_pandas_query_df_arrow(param_client: Client, call, table_context: Callable): if not arrow: pytest.skip("PyArrow package not available") @@ -304,17 +294,17 @@ def test_pandas_query_df_arrow(test_client: Client, table_context: Callable): [2, pd.Timestamp(2023, 5, 5), None, -45678912, 8.5555588888, "string 2", None, 1], [3, pd.Timestamp(2023, 5, 6), 
30, 789123456, 3.14159, "string 3", None, 1], ) - test_client.insert(table_name, data) + call(param_client.insert, table_name, data) if IS_PANDAS_2: - result_df = test_client.query_df_arrow(f"SELECT * FROM {table_name}") + result_df = call(param_client.query_df_arrow, f"SELECT * FROM {table_name}") for dt in list(result_df.dtypes): assert isinstance(dt, pd.ArrowDtype) else: with pytest.raises(ProgrammingError): - result_df = test_client.query_df_arrow(f"SELECT * FROM {table_name}") + result_df = call(param_client.query_df_arrow, f"SELECT * FROM {table_name}") -def test_pandas_insert_df_arrow(test_client: Client, table_context: Callable): +def test_pandas_insert_df_arrow(param_client: Client, call, table_context: Callable): if not arrow: pytest.skip("PyArrow package not available") @@ -332,12 +322,12 @@ def test_pandas_insert_df_arrow(test_client: Client, table_context: Callable): ): if IS_PANDAS_2: df = df.convert_dtypes(dtype_backend="pyarrow") - test_client.insert_df_arrow(table_name, df) - res_df = test_client.query(f"SELECT * from {table_name} ORDER BY i64") + call(param_client.insert_df_arrow, table_name, df) + res_df = call(param_client.query, f"SELECT * from {table_name} ORDER BY i64") assert res_df.result_rows == [(51, 421, "b"), (78, None, "a")] else: with pytest.raises(ProgrammingError, match="pandas 2.x"): - test_client.insert_df_arrow(table_name, df) + call(param_client.insert_df_arrow, table_name, df) with table_context( table_name, @@ -351,7 +341,7 @@ def test_pandas_insert_df_arrow(test_client: Client, table_context: Callable): df = pd.DataFrame(data, columns=["i64", "ni64", "str"]) df["i64"] = df["i64"].astype(pd.ArrowDtype(arrow.int64())) with pytest.raises(ProgrammingError, match="Non-Arrow columns found"): - test_client.insert_df_arrow(table_name, df) + call(param_client.insert_df_arrow, table_name, df) else: with pytest.raises(ProgrammingError, match="pandas 2.x"): - test_client.insert_df_arrow(table_name, df) + 
call(param_client.insert_df_arrow, table_name, df) diff --git a/tests/integration_tests/test_params.py b/tests/integration_tests/test_params.py index 98c4fabe..bd271cd7 100644 --- a/tests/integration_tests/test_params.py +++ b/tests/integration_tests/test_params.py @@ -5,84 +5,84 @@ from clickhouse_connect.driver.binding import DT64Param -def test_params(test_client: Client, table_context: Callable): - result = test_client.query('SELECT name, database FROM system.tables WHERE database = {db:String}', +def test_params(param_client: Client, call, table_context: Callable): + result = call(param_client.query, 'SELECT name, database FROM system.tables WHERE database = {db:String}', parameters={'db': 'system'}) assert result.first_item['database'] == 'system' - if test_client.min_version('21'): - result = test_client.query('SELECT name, {col:String} FROM system.tables WHERE table ILIKE {t:String}', + if param_client.min_version('21'): + result = call(param_client.query, 'SELECT name, {col:String} FROM system.tables WHERE table ILIKE {t:String}', parameters={'t': '%rr%', 'col': 'database'}) assert 'rr' in result.first_item['name'] first_date = datetime.strptime('Jun 1 2005 1:33PM', '%b %d %Y %I:%M%p') - first_date = test_client.server_tz.localize(first_date) + first_date = param_client.server_tz.localize(first_date) second_date = datetime.strptime('Dec 25 2022 5:00AM', '%b %d %Y %I:%M%p') - second_date = test_client.server_tz.localize(second_date) + second_date = param_client.server_tz.localize(second_date) with table_context('test_bind_params', ['key UInt64', 'dt DateTime', 'value String', 't Tuple(String, String)']): - test_client.insert('test_bind_params', + call(param_client.insert, 'test_bind_params', [[1, first_date, 'v11', ('one', 'two')], [2, second_date, 'v21', ('t1', 't2')], [3, datetime.now(), 'v31', ('str1', 'str2')]]) - result = test_client.query('SELECT * FROM test_bind_params WHERE dt = {dt:DateTime}', + result = call(param_client.query, 'SELECT * FROM 
test_bind_params WHERE dt = {dt:DateTime}', parameters={'dt': second_date}) assert result.first_item['key'] == 2 - result = test_client.query('SELECT * FROM test_bind_params WHERE dt = %(dt)s', + result = call(param_client.query, 'SELECT * FROM test_bind_params WHERE dt = %(dt)s', parameters={'dt': first_date}) assert result.first_item['key'] == 1 - result = test_client.query("SELECT * FROM test_bind_params WHERE value != %(v)s AND value like '%%1'", + result = call(param_client.query, "SELECT * FROM test_bind_params WHERE value != %(v)s AND value like '%%1'", parameters={'v': 'v11'}) assert result.row_count == 2 - result = test_client.query('SELECT * FROM test_bind_params WHERE value IN %(tp)s', + result = call(param_client.query, 'SELECT * FROM test_bind_params WHERE value IN %(tp)s', parameters={'tp': ('v18', 'v31')}) assert result.first_item['key'] == 3 - result = test_client.query('SELECT number FROM numbers(10) WHERE {n:Nullable(String)} IS NULL', + result = call(param_client.query, 'SELECT number FROM numbers(10) WHERE {n:Nullable(String)} IS NULL', parameters={'n': None}).result_rows assert len(result) == 10 date_params = [date(2023, 6, 1), date(2023, 8, 5)] - result = test_client.query('SELECT {l:Array(Date)}', parameters={'l': date_params}).first_row + result = call(param_client.query, 'SELECT {l:Array(Date)}', parameters={'l': date_params}).first_row assert date_params == result[0] dt_params = [datetime(2023, 6, 1, 7, 40, 2), datetime(2023, 8, 17, 20, 0, 10)] - result = test_client.query('SELECT {l:Array(DateTime)}', parameters={'l': dt_params}).first_row + result = call(param_client.query, 'SELECT {l:Array(DateTime)}', parameters={'l': dt_params}).first_row assert dt_params == result[0] num_array_params = [2.5, 5.3, 7.4] - result = test_client.query('SELECT {l:Array(Float64)}', parameters={'l': num_array_params}).first_row + result = call(param_client.query, 'SELECT {l:Array(Float64)}', parameters={'l': num_array_params}).first_row assert 
num_array_params == result[0] - result = test_client.query('SELECT %(l)s', parameters={'l': num_array_params}).first_row + result = call(param_client.query, 'SELECT %(l)s', parameters={'l': num_array_params}).first_row assert num_array_params == result[0] tp_params = ('str1', 'str2') - result = test_client.query('SELECT %(tp)s', parameters={'tp': tp_params}).first_row + result = call(param_client.query, 'SELECT %(tp)s', parameters={'tp': tp_params}).first_row assert tp_params == result[0] num_params = {'p_0': 2, 'p_1': 100523.55} - result = test_client.query( + result = call(param_client.query, 'SELECT count() FROM system.tables WHERE total_rows > %(p_0)d and total_rows < %(p_1)f', parameters=num_params) assert result.first_row[0] > 0 -def test_datetime_64_params(test_client: Client): +def test_datetime_64_params(param_client: Client, call): dt_values = [datetime(2023, 6, 1, 7, 40, 2, 250306), datetime(2023, 8, 17, 20, 0, 10, 777722)] dt_params = {f'd{ix}': DT64Param(v) for ix, v in enumerate(dt_values)} - result = test_client.query('SELECT {d0:DateTime64(3)}, {d1:Datetime64(9)}', parameters=dt_params).first_row + result = call(param_client.query, 'SELECT {d0:DateTime64(3)}, {d1:Datetime64(9)}', parameters=dt_params).first_row assert result[0] == dt_values[0].replace(microsecond=250000) assert result[1] == dt_values[1] - result = test_client.query('SELECT {a1:Array(DateTime64(6))}', parameters={'a1': [dt_params['d0'], dt_params['d1']]}).first_row + result = call(param_client.query, 'SELECT {a1:Array(DateTime64(6))}', parameters={'a1': [dt_params['d0'], dt_params['d1']]}).first_row assert result[0] == dt_values dt_params = {f'd{ix}_64': v for ix, v in enumerate(dt_values)} - result = test_client.query('SELECT {d0:DateTime64(3)}, {d1:Datetime64(9)}', parameters=dt_params).first_row + result = call(param_client.query, 'SELECT {d0:DateTime64(3)}, {d1:Datetime64(9)}', parameters=dt_params).first_row assert result[0] == dt_values[0].replace(microsecond=250000) assert 
result[1] == dt_values[1] - result = test_client.query('SELECT {a1:Array(DateTime64(6))}', + result = call(param_client.query, 'SELECT {a1:Array(DateTime64(6))}', parameters={'a1_64': dt_values}).first_row assert result[0] == dt_values dt_params = [DT64Param(v) for v in dt_values] - result = test_client.query("SELECT %s as string, toDateTime64(%s,6) as dateTime", parameters = dt_params).first_row + result = call(param_client.query, "SELECT %s as string, toDateTime64(%s,6) as dateTime", parameters = dt_params).first_row assert result == ('2023-06-01 07:40:02.250306', dt_values[1]) diff --git a/tests/integration_tests/test_polars.py b/tests/integration_tests/test_polars.py index 136353e7..63137dca 100644 --- a/tests/integration_tests/test_polars.py +++ b/tests/integration_tests/test_polars.py @@ -13,7 +13,7 @@ ] -def test_polars_insert(test_client: Client, table_context: Callable): +def test_polars_insert(param_client: Client, call, table_context: Callable): with table_context( "test_polars", [ @@ -35,12 +35,12 @@ def test_polars_insert(test_client: Client, table_context: Callable): "day_col": [date(2025, 7, 1), date(2025, 8, 1), date(2025, 8, 12)], } ) - test_client.insert_df_arrow(ctx.table, df) - res = test_client.query(f"SELECT key FROM {ctx.table}") + call(param_client.insert_df_arrow, ctx.table, df) + res = call(param_client.query, f"SELECT key FROM {ctx.table}") assert [i[0] for i in res.result_rows] == df["key"].to_list() -def test_bad_insert_fails(test_client: Client, table_context: Callable): +def test_bad_insert_fails(param_client: Client, call, table_context: Callable): with table_context( "test_polars", [ @@ -53,10 +53,10 @@ def test_bad_insert_fails(test_client: Client, table_context: Callable): ], ): with pytest.raises(TypeError, match="got list"): - test_client.insert_df_arrow("test_polars", [[1, 2, 3]]) + call(param_client.insert_df_arrow, "test_polars", [[1, 2, 3]]) -def test_polars_query(test_client: Client, table_context: Callable): +def 
test_polars_query(param_client: Client, call, table_context: Callable): with table_context( "test_polars", [ @@ -76,32 +76,44 @@ def test_polars_query(test_client: Client, table_context: Callable): [datetime(2025, 7, 1, 10, 30, 0, 0), datetime(2025, 8, 1, 10, 30, 0, 0), datetime(2025, 8, 12, 10, 30, 1, 0)], [date(2025, 7, 1), date(2025, 8, 1), date(2025, 8, 12)], ] - test_client.insert( + call(param_client.insert, ctx.table, data, column_names=["key", "num", "flt", "str", "dt", "day_col"], column_oriented=True, ) - df = test_client.query_df_arrow(f"SELECT key FROM {ctx.table}", dataframe_library="polars") + df = call(param_client.query_df_arrow, f"SELECT key FROM {ctx.table}", dataframe_library="polars") assert isinstance(df, pl.DataFrame) assert data[0] == df["key"].to_list() -def test_polars_arrow_stream(test_client: Client, table_context: Callable): +# pylint: disable=too-many-locals +def test_polars_arrow_stream(param_client: Client, call, client_mode, table_context: Callable): if not arrow: pytest.skip("PyArrow package not available") - if not test_client.min_version("21"): - pytest.skip(f"PyArrow is not supported in this server version {test_client.server_version}") - with table_context("test_arrow_insert", ["counter Int64", "letter String"]): + if not param_client.min_version("21"): + pytest.skip(f"PyArrow is not supported in this server version {param_client.server_version}") + with table_context("test_arrow_insert", ["counter Int64", "letter String"]) as ctx: counter = arrow.array(range(1000000)) alphabet = string.ascii_lowercase letter = arrow.array([alphabet[x % 26] for x in range(1000000)]) names = ["counter", "letter"] insert_table = arrow.Table.from_arrays([counter, letter], names=names) - test_client.insert_arrow("test_arrow_insert", insert_table) - stream = test_client.query_df_arrow_stream("SELECT * FROM test_arrow_insert", dataframe_library="polars") - with stream: - result_dfs = list(stream) + call(param_client.insert_arrow, ctx.table, 
insert_table) + stream = call(param_client.query_df_arrow_stream, f"SELECT * FROM {ctx.table}", dataframe_library="polars") + + result_dfs = [] + + if client_mode == "sync": + with stream: + for df in stream: + result_dfs.append(df) + else: + async def consume_async(): + async with stream: + async for df in stream: + result_dfs.append(df) + call(consume_async) assert len(result_dfs) > 1 total_rows = 0 diff --git a/tests/integration_tests/test_protocol_version.py b/tests/integration_tests/test_protocol_version.py index bbe56cfa..a9d63635 100644 --- a/tests/integration_tests/test_protocol_version.py +++ b/tests/integration_tests/test_protocol_version.py @@ -1,12 +1,12 @@ from clickhouse_connect.driver import Client -def test_protocol_version(test_client: Client): +def test_protocol_version(param_client: Client, call): query = "select toDateTime(1676369730, 'Asia/Shanghai') as dt FORMAT Native" - raw = test_client.raw_query(query) + raw = call(param_client.raw_query, query) assert raw.hex() == '0101026474084461746554696d65425feb63' - if test_client.min_version('23.3'): - raw = test_client.raw_query(query, settings={'client_protocol_version': 54337}) + if param_client.min_version('23.3'): + raw = call(param_client.raw_query, query, settings={'client_protocol_version': 54337}) ch_type = raw[14:39].decode() assert ch_type == "DateTime('Asia/Shanghai')" diff --git a/tests/integration_tests/test_proxy.py b/tests/integration_tests/test_proxy.py index df5746bf..cf8d2b11 100644 --- a/tests/integration_tests/test_proxy.py +++ b/tests/integration_tests/test_proxy.py @@ -4,63 +4,74 @@ import pytest from urllib3 import ProxyManager -import clickhouse_connect from tests.integration_tests.conftest import TestConfig -def test_proxies(test_config: TestConfig): +# pylint: disable=protected-access +def test_proxies(client_factory, call, test_config: TestConfig): if not test_config.proxy_address: pytest.skip('Proxy address not configured') if test_config.port in (8123, 10723): - client 
= clickhouse_connect.get_client(host=test_config.host, + client = client_factory(host=test_config.host, port=test_config.port, username=test_config.username, password=test_config.password, http_proxy=test_config.proxy_address) - assert '2' in client.command('SELECT version()') - client.close() + assert '2' in call(client.command, 'SELECT version()') try: os.environ['HTTP_PROXY'] = f'http://{test_config.proxy_address}' - client = clickhouse_connect.get_client(host=test_config.host, + client = client_factory(host=test_config.host, port=test_config.port, username=test_config.username, password=test_config.password) - assert isinstance(client.http, ProxyManager) - assert '2' in client.command('SELECT version()') - client.close() + if hasattr(client, 'http'): + # Sync client uses urllib3 + assert isinstance(client.http, ProxyManager) + else: + # Async client uses aiohttp + assert hasattr(client, '_proxy_url') and client._proxy_url is not None + assert '2' in call(client.command, 'SELECT version()') os.environ['no_proxy'] = f'{test_config.host}:{test_config.port}' - client = clickhouse_connect.get_client(host=test_config.host, + client = client_factory(host=test_config.host, port=test_config.port, username=test_config.username, password=test_config.password) - assert not isinstance(client.http, ProxyManager) - assert '2' in client.command('SELECT version()') - client.close() + # Check proxy is NOT configured + if hasattr(client, 'http'): + # Sync client uses urllib3 + assert not isinstance(client.http, ProxyManager) + else: + # Async client uses aiohttp + assert not hasattr(client, '_proxy_url') or client._proxy_url is None + assert '2' in call(client.command, 'SELECT version()') finally: os.environ.pop('HTTP_PROXY', None) os.environ.pop('no_proxy', None) else: cert_file = f'{Path(__file__).parent}/proxy_ca_cert.crt' - client = clickhouse_connect.get_client(host=test_config.host, + client = client_factory(host=test_config.host, port=test_config.port, 
username=test_config.username, password=test_config.password, ca_cert=cert_file, https_proxy=test_config.proxy_address) - assert '2' in client.command('SELECT version()') - client.close() + assert '2' in call(client.command, 'SELECT version()') try: os.environ['HTTPS_PROXY'] = f'{test_config.proxy_address}' - client = clickhouse_connect.get_client(host=test_config.host, + client = client_factory(host=test_config.host, port=test_config.port, username=test_config.username, password=test_config.password, ca_cert=cert_file) - assert isinstance(client.http, ProxyManager) - assert '2' in client.command('SELECT version()') - client.close() + if hasattr(client, 'http'): + # Sync client uses urllib3 + assert isinstance(client.http, ProxyManager) + else: + # Async client uses aiohttp + assert hasattr(client, '_proxy_url') and client._proxy_url is not None + assert '2' in call(client.command, 'SELECT version()') finally: os.environ.pop('HTTPS_PROXY', None) diff --git a/tests/integration_tests/test_pyarrow_ddl_integration.py b/tests/integration_tests/test_pyarrow_ddl_integration.py index cd3106e3..14d9f0cd 100644 --- a/tests/integration_tests/test_pyarrow_ddl_integration.py +++ b/tests/integration_tests/test_pyarrow_ddl_integration.py @@ -13,15 +13,15 @@ pytest.importorskip("pyarrow") -def test_arrow_create_table_and_insert(test_client: Client): - if not test_client.min_version("20"): +def test_arrow_create_table_and_insert(param_client: Client, call): + if not param_client.min_version("20"): pytest.skip( - f"Not supported server version {test_client.server_version}" + f"Not supported server version {param_client.server_version}" ) table_name = "test_arrow_basic_integration" - test_client.command(f"DROP TABLE IF EXISTS {table_name}") + call(param_client.command, f"DROP TABLE IF EXISTS {table_name}") schema = pa.schema( [ @@ -38,7 +38,7 @@ def test_arrow_create_table_and_insert(test_client: Client): engine="MergeTree", engine_params={"ORDER BY": "id"}, ) - 
test_client.command(ddl) + call(param_client.command, ddl) arrow_table = pa.table( { @@ -50,9 +50,9 @@ def test_arrow_create_table_and_insert(test_client: Client): schema=schema, ) - test_client.insert_arrow(table=table_name, arrow_table=arrow_table) + call(param_client.insert_arrow, table=table_name, arrow_table=arrow_table) - result = test_client.query( + result = call(param_client.query, f"SELECT id, name, score, flag FROM {table_name} ORDER BY id" ) assert result.result_rows == [ @@ -60,13 +60,13 @@ def test_arrow_create_table_and_insert(test_client: Client): (2, "b", 2.5, False), ] - test_client.command(f"DROP TABLE IF EXISTS {table_name}") + call(param_client.command, f"DROP TABLE IF EXISTS {table_name}") -def test_arrow_schema_to_column_defs(test_client: Client): +def test_arrow_schema_to_column_defs(param_client: Client, call): table_name = "test_arrow_manual_integration" - test_client.command(f"DROP TABLE IF EXISTS {table_name}") + call(param_client.command, f"DROP TABLE IF EXISTS {table_name}") schema = pa.schema( [ @@ -84,7 +84,7 @@ def test_arrow_schema_to_column_defs(test_client: Client): engine="MergeTree", engine_params={"ORDER BY": "id"}, ) - test_client.command(ddl) + call(param_client.command, ddl) arrow_table = pa.table( { @@ -94,26 +94,26 @@ def test_arrow_schema_to_column_defs(test_client: Client): schema=schema, ) - test_client.insert_arrow(table=table_name, arrow_table=arrow_table) + call(param_client.insert_arrow, table=table_name, arrow_table=arrow_table) - result = test_client.query(f"SELECT id, name FROM {table_name} ORDER BY id") + result = call(param_client.query, f"SELECT id, name FROM {table_name} ORDER BY id") assert result.result_rows == [ (10, "x"), (20, "y"), ] - test_client.command(f"DROP TABLE IF EXISTS {table_name}") + call(param_client.command, f"DROP TABLE IF EXISTS {table_name}") -def test_arrow_datetime_create_and_insert(test_client: Client): - if not test_client.min_version("20"): +def 
test_arrow_datetime_create_and_insert(param_client: Client, call): + if not param_client.min_version("20"): pytest.skip( - f"Not supported server version {test_client.server_version}" + f"Not supported server version {param_client.server_version}" ) table_name = "test_arrow_datetime_integration" - test_client.command(f"DROP TABLE IF EXISTS {table_name}") + call(param_client.command, f"DROP TABLE IF EXISTS {table_name}") schema = pa.schema( [ @@ -130,7 +130,7 @@ def test_arrow_datetime_create_and_insert(test_client: Client): engine="MergeTree", engine_params={"ORDER BY": "id"}, ) - test_client.command(ddl) + call(param_client.command, ddl) arrow_table = pa.table( { @@ -148,9 +148,9 @@ def test_arrow_datetime_create_and_insert(test_client: Client): schema=schema, ) - test_client.insert_arrow(table=table_name, arrow_table=arrow_table) + call(param_client.insert_arrow, table=table_name, arrow_table=arrow_table) - result = test_client.query( + result = call(param_client.query, f"SELECT id, event_date, event_ts, event_ts_tz " f"FROM {table_name} ORDER BY id" ) @@ -162,4 +162,4 @@ def test_arrow_datetime_create_and_insert(test_client: Client): assert rows[1][0] == 2 assert str(rows[1][1]) == "2025-01-02" - test_client.command(f"DROP TABLE IF EXISTS {table_name}") + call(param_client.command, f"DROP TABLE IF EXISTS {table_name}") diff --git a/tests/integration_tests/test_raw_insert.py b/tests/integration_tests/test_raw_insert.py index 518088ac..daf93c86 100644 --- a/tests/integration_tests/test_raw_insert.py +++ b/tests/integration_tests/test_raw_insert.py @@ -4,31 +4,31 @@ from clickhouse_connect.driver import Client -def test_raw_insert(test_client: Client, table_context: Callable): +def test_raw_insert(param_client: Client, call, table_context: Callable): with table_context('test_raw_insert', ["`weir'd` String", 'value String']): csv = 'value1\nvalue2' - test_client.raw_insert('test_raw_insert', ['"weir\'d"'], csv.encode(), fmt='CSV') - result = 
test_client.query('SELECT * FROM test_raw_insert') + call(param_client.raw_insert, 'test_raw_insert', ['"weir\'d"'], csv.encode(), fmt='CSV') + result = call(param_client.query, 'SELECT * FROM test_raw_insert') assert result.result_set[1][0] == 'value2' - test_client.command('TRUNCATE TABLE test_raw_insert') + call(param_client.command, 'TRUNCATE TABLE test_raw_insert') tsv = 'weird1\tvalue__`2\nweird2\tvalue77' - test_client.raw_insert('test_raw_insert', ["`weir'd`", 'value'], tsv, fmt='TSV') - result = test_client.query('SELECT * FROM test_raw_insert') + call(param_client.raw_insert, 'test_raw_insert', ["`weir'd`", 'value'], tsv, fmt='TSV') + result = call(param_client.query, 'SELECT * FROM test_raw_insert') assert result.result_set[0][1] == 'value__`2' assert result.result_set[1][1] == 'value77' -def test_raw_insert_compression(test_client: Client, table_context: Callable): +def test_raw_insert_compression(param_client: Client, call, table_context: Callable): data_file = f'{Path(__file__).parent}/movies.csv.gz' with open(data_file, mode='rb') as movies_file: data = movies_file.read() with table_context('test_gzip_movies', ['movie String', 'year UInt16', 'rating Decimal32(3)']): - test_client.raw_insert('test_gzip_movies', None, data, fmt='CSV', compression='gzip', + call(param_client.raw_insert, 'test_gzip_movies', None, data, fmt='CSV', compression='gzip', settings={'input_format_allow_errors_ratio': .2, 'input_format_allow_errors_num': 5} ) - res = test_client.query( + res = call(param_client.query, 'SELECT count() as count, sum(rating) as rating, max(year) as year FROM test_gzip_movies').first_item assert res['count'] == 248 assert res['year'] == 2022 diff --git a/tests/integration_tests/test_session_id.py b/tests/integration_tests/test_session_id.py index 17ac9112..11a656b5 100644 --- a/tests/integration_tests/test_session_id.py +++ b/tests/integration_tests/test_session_id.py @@ -4,25 +4,22 @@ import pytest -from clickhouse_connect.driver import 
create_async_client -from tests.integration_tests.conftest import TestConfig - SESSION_KEY = 'session_id' -def test_client_default_session_id(test_create_client: Callable): +def test_client_default_session_id(client_factory: Callable): # by default, the sync client will autogenerate the session id - client = test_create_client() + # for async clients, we need to explicitly enable it + client = client_factory(autogenerate_session_id=True) session_id = client.get_client_setting(SESSION_KEY) try: uuid.UUID(session_id) except ValueError: pytest.fail(f"Invalid session_id: {session_id}") - client.close() -def test_client_autogenerate_session_id(test_create_client: Callable): - client = test_create_client() +def test_client_autogenerate_session_id(client_factory: Callable): + client = client_factory(autogenerate_session_id=True) session_id = client.get_client_setting(SESSION_KEY) try: uuid.UUID(session_id) @@ -30,49 +27,23 @@ def test_client_autogenerate_session_id(test_create_client: Callable): pytest.fail(f"Invalid session_id: {session_id}") -def test_client_custom_session_id(test_create_client: Callable): +def test_client_custom_session_id(client_factory: Callable): session_id = 'custom_session_id' - client = test_create_client(session_id=session_id) + client = client_factory(session_id=session_id) assert client.get_client_setting(SESSION_KEY) == session_id - client.close() -@pytest.mark.asyncio -async def test_async_client_default_session_id(test_config: TestConfig): - # by default, the async client will NOT autogenerate the session id - async_client = await create_async_client(database=test_config.test_database, - host=test_config.host, - port=test_config.port, - user=test_config.username, - password=test_config.password) - assert async_client.get_client_setting(SESSION_KEY) is None - await async_client.close() +def test_explicit_session_id(client_factory: Callable, call): + """Test explicit session_id allows sharing state like temp tables.""" + session_id = 
f"test_session_{uuid.uuid4()}" + client = client_factory(session_id=session_id) + assert client.get_client_setting("session_id") == session_id -@pytest.mark.asyncio -async def test_async_client_autogenerate_session_id(test_config: TestConfig): - async_client = await create_async_client(database=test_config.test_database, - host=test_config.host, - port=test_config.port, - user=test_config.username, - password=test_config.password, - autogenerate_session_id=True) - session_id = async_client.get_client_setting(SESSION_KEY) - try: - uuid.UUID(session_id) - except ValueError: - pytest.fail(f"Invalid session_id: {session_id}") - await async_client.close() - + call(client.command, "CREATE TEMPORARY TABLE temp_test (id UInt32, val String)") + call(client.command, "INSERT INTO temp_test VALUES (1, 'a'), (2, 'b')") -@pytest.mark.asyncio -async def test_async_client_custom_session_id(test_config: TestConfig): - session_id = 'custom_session_id' - async_client = await create_async_client(database=test_config.test_database, - host=test_config.host, - port=test_config.port, - user=test_config.username, - password=test_config.password, - session_id=session_id) - assert async_client.get_client_setting(SESSION_KEY) == session_id - await async_client.close() + result = call(client.query, "SELECT * FROM temp_test ORDER BY id") + assert result.row_count == 2 + assert result.result_rows[0] == (1, "a") + assert result.result_rows[1] == (2, "b") diff --git a/tests/integration_tests/test_streaming.py b/tests/integration_tests/test_streaming.py index 8be1ed82..43eab850 100644 --- a/tests/integration_tests/test_streaming.py +++ b/tests/integration_tests/test_streaming.py @@ -1,81 +1,127 @@ import random import string +import pytest -from clickhouse_connect.driver import Client from clickhouse_connect.driver.exceptions import StreamClosedError, ProgrammingError, StreamFailureError -def test_row_stream(test_client: Client): - row_stream = test_client.query_rows_stream('SELECT number FROM 
numbers(10000)') +def test_row_stream(param_client, call, consume_stream): + stream = call(param_client.query_rows_stream, 'SELECT number FROM numbers(10000)') total = 0 - with row_stream: - for row in row_stream: - total += row[0] - try: - with row_stream: - pass - except StreamClosedError: - pass + + def process(row): + nonlocal total + total += row[0] + + consume_stream(stream, process) + + # Verify stream is closed by trying to consume it again + # This logic relies on consume_stream handling the context manager which checks state + with pytest.raises(StreamClosedError): + consume_stream(stream, lambda x: None) + assert total == 49995000 -def test_column_block_stream(test_client: Client): +def test_column_block_stream(param_client, call, consume_stream): random_string = 'randomStringUTF8(50)' - if not test_client.min_version('20'): + if not param_client.min_version('20'): random_string = random.choices(string.ascii_lowercase, k=50) - block_stream = test_client.query_column_block_stream(f'SELECT number, {random_string} FROM numbers(10000)', - settings={'max_block_size': 4000}) + stream = call(param_client.query_column_block_stream, + f'SELECT number, {random_string} FROM numbers(10000)', + settings={'max_block_size': 4000}) total = 0 block_count = 0 - with block_stream: - for block in block_stream: - block_count += 1 - total += sum(block[0]) + + def process(block): + nonlocal total, block_count + block_count += 1 + total += sum(block[0]) + + consume_stream(stream, process) + assert total == 49995000 assert block_count > 1 -def test_row_block_stream(test_client: Client): +def test_row_block_stream(param_client, call, consume_stream): random_string = 'randomStringUTF8(50)' - if not test_client.min_version('20'): + if not param_client.min_version('20'): random_string = random.choices(string.ascii_lowercase, k=50) - block_stream = test_client.query_row_block_stream(f'SELECT number, {random_string} FROM numbers(10000)', - settings={'max_block_size': 4000}) + stream = 
call(param_client.query_row_block_stream, + f'SELECT number, {random_string} FROM numbers(10000)', + settings={'max_block_size': 4000}) total = 0 block_count = 0 - with block_stream: - for block in block_stream: - block_count += 1 - for row in block: - total += row[0] + + def process(block): + nonlocal total, block_count + block_count += 1 + for row in block: + total += row[0] + + consume_stream(stream, process) + assert total == 49995000 assert block_count > 1 -def test_stream_errors(test_client: Client): - query_result = test_client.query('SELECT number FROM numbers(100000)') - try: - for _ in query_result.row_block_stream: - pass - except ProgrammingError as ex: - assert 'context' in str(ex) +def test_stream_errors(param_client, call, client_mode, consume_stream): + query_result = call(param_client.query, 'SELECT number FROM numbers(100000)') + + # 1. Test accessing without context manager raises error + if client_mode == 'sync': + with pytest.raises(ProgrammingError, match="context"): + for _ in query_result.row_block_stream: + pass + else: + async def try_iter(): + async for _ in query_result.row_block_stream: + pass + with pytest.raises((ProgrammingError, TypeError)): + call(try_iter) + assert query_result.row_count == 100000 + + # 2. Test that previous access consumed the generator, so next access raises StreamClosedError + with pytest.raises(StreamClosedError): + # Note: query_result.rows_stream creates a NEW StreamContext, but its internal generator + # (self._block_gen) was consumed by the property access in step 1. 
+ consume_stream(query_result.rows_stream) + + +def test_stream_failure(param_client, call, consume_stream): + query = ('SELECT toString(cityHash64(number)) FROM numbers(10000000)' + + ' where intDiv(1,number-300000)>-100000000') + + stream = call(param_client.query_row_block_stream, query) + blocks = 0 + failed = False + + def process(block): # pylint: disable=unused-argument + nonlocal blocks + blocks += 1 + try: - with query_result.rows_stream as stream: - assert sum(row[0] for row in stream) == 3882 - except StreamClosedError: - pass - - -def test_stream_failure(test_client: Client): - with test_client.query_row_block_stream('SELECT toString(cityHash64(number)) FROM numbers(10000000)' + - ' where intDiv(1,number-300000)>-100000000') as stream: - blocks = 0 - failed = False - try: - for _ in stream: - blocks += 1 - except StreamFailureError as ex: - failed = True - assert 'division by zero' in str(ex).lower() + consume_stream(stream, process) + except StreamFailureError as ex: + failed = True + assert 'division by zero' in str(ex).lower() + assert failed + + +def test_raw_stream(param_client, call, consume_stream): + """Test raw_stream for streaming response.""" + chunks = [] + stream = call(param_client.raw_stream, "SELECT number FROM system.numbers LIMIT 1000", fmt="TabSeparated") + + def process(chunk): + nonlocal chunks + chunks.append(chunk) + + consume_stream(stream, process) + + assert len(chunks) > 0 + full_data = b"".join(chunks) + assert len(full_data) > 0 diff --git a/tests/integration_tests/test_temporal.py b/tests/integration_tests/test_temporal.py index 665bd22a..cb1eef1e 100644 --- a/tests/integration_tests/test_temporal.py +++ b/tests/integration_tests/test_temporal.py @@ -1,32 +1,22 @@ -from datetime import timedelta, time -from typing import List, Any +from datetime import time, timedelta +from typing import Any, List + import pytest -from clickhouse_connect.driver import Client from tests.integration_tests.conftest import TestConfig -# pylint: 
disable=no-self-use - +# Module-level version and cloud checks @pytest.fixture(autouse=True, scope="module") def module_setup_and_checks(test_client, test_config: TestConfig): - """ - Performs all module-level setup: - - Skips if in a cloud environment where settings are locked. - - Skips if the server version is too old for Time/Time64 types. - """ + """Check prerequisites for Time/Time64 type tests.""" if test_config.cloud: - pytest.skip( - "Time/Time64 types require settings change, but settings are locked in cloud, skipping tests.", - allow_module_level=True, - ) + pytest.skip("Time/Time64 types require settings change, but settings are locked in cloud") version_str = test_client.query("SELECT version()").result_rows[0][0] major, minor, *_ = map(int, version_str.split(".")) if (major, minor) < (25, 6): - pytest.skip( - "Time and Time64 types require ClickHouse 25.6+", allow_module_level=True - ) + pytest.skip("Time and Time64 types require ClickHouse 25.6+") class TimeTestData: @@ -77,8 +67,47 @@ class TimeTestData: TIME64_NS_TICKS = [3723123456789, -5500000000, 1] +class ClockTimeData: + """Test data for datetime.time objects.""" + + TIME_OBJS = [ + time(0, 0, 5), + time(1, 2, 3), + time(23, 59, 59), + time(0, 0, 0), + ] + + TIME_DELTAS = [ + timedelta(seconds=5), + timedelta(hours=1, minutes=2, seconds=3), + timedelta(hours=23, minutes=59, seconds=59), + timedelta(0), + ] + + TIME64_US_OBJS = [ + time(1, 2, 3, 123456), + time(0, 0, 0, 1), + time(23, 59, 59, 999999), + ] + + TIME64_NS_OBJS = [ + time(1, 2, 3, 123456), + time(0, 0, 0), + time(23, 59, 59, 123000), + ] + + TABLE_NAME = "temp_time_test" -COLUMN_COUNT = 7 + +STANDARD_TIME_TABLE_SCHEMA = [ + "id UInt32", + "t Time", + "t_nullable Nullable(Time)", + "t64_us Time64(6)", + "t64_us_nullable Nullable(Time64(6))", + "t64_ns Time64(9)", + "t64_ns_nullable Nullable(Time64(9))", +] def create_test_row(row_id: int, time_val: Any) -> List[Any]: @@ -99,106 +128,70 @@ def create_nullable_test_row(row_id: int, 
**column_values) -> List[Any]: ] -@pytest.fixture(autouse=True) -def setup_time_table(test_client: Client): - """Setup and teardown test table with Time and Time64 columns.""" - client = test_client - - client.command("SET enable_time_time64_type = 1") - - client.command(f"DROP TABLE IF EXISTS {TABLE_NAME}") - client.command( - f""" - CREATE TABLE {TABLE_NAME} ( - id UInt32, - t Time, - t_nullable Nullable(Time), - t64_us Time64(6), - t64_us_nullable Nullable(Time64(6)), - t64_ns Time64(9), - t64_ns_nullable Nullable(Time64(9)) - ) ENGINE = MergeTree ORDER BY id - """ - ) - - yield client - - client.command(f"DROP TABLE IF EXISTS {TABLE_NAME}") - - -def insert_test_data(client: Client, rows: List[List[Any]]) -> None: - """Insert test data into the test table.""" - client.insert(TABLE_NAME, rows) - - -def query_column(client: Client, column: str, **query_formats) -> List[Any]: - """Query a single column ordered by ID with optional format specifications.""" - query_result = client.query( - f"SELECT {column} FROM {TABLE_NAME} ORDER BY id", query_formats=query_formats - ) - return [row[0] for row in query_result.result_rows] - - -class TestTimeRoundtrip: - """Test round-trip conversion for Time type.""" - - def test_time_native_format(self, test_client: Client): - """Test Time round-trip with native timedelta format.""" +def test_time_native_format(param_client, call, table_context): + """Test Time round-trip with native timedelta format.""" + with table_context(TABLE_NAME, STANDARD_TIME_TABLE_SCHEMA, settings={"enable_time_time64_type": 1}): rows = [create_test_row(i, td) for i, td in enumerate(TimeTestData.TIME_DELTAS)] - insert_test_data(test_client, rows) + call(param_client.insert, TABLE_NAME, rows) - result = query_column(test_client, "t") - assert result == TimeTestData.TIME_DELTAS - - def test_time_string_format(self, test_client: Client): - """Test Time round-trip with string format.""" - rows = [create_test_row(i, s) for i, s in 
enumerate(TimeTestData.TIME_STRINGS)] - insert_test_data(test_client, rows) + result = call(param_client.query, f"SELECT t FROM {TABLE_NAME} ORDER BY id") + result_values = [row[0] for row in result.result_rows] + assert result_values == TimeTestData.TIME_DELTAS - result = query_column(test_client, "t", Time="string") - assert result == TimeTestData.TIME_STRINGS - def test_time_int_format(self, test_client: Client): - """Test Time round-trip with integer format.""" - rows = [create_test_row(i, val) for i, val in enumerate(TimeTestData.TIME_INTS)] - insert_test_data(test_client, rows) - - result = query_column(test_client, "t", Time="int") - assert result == TimeTestData.TIME_INTS +def test_time_string_format(param_client, call, table_context): + """Test Time round-trip with string format.""" + with table_context(TABLE_NAME, STANDARD_TIME_TABLE_SCHEMA, settings={"enable_time_time64_type": 1}): + rows = [create_test_row(i, s) for i, s in enumerate(TimeTestData.TIME_STRINGS)] + call(param_client.insert, TABLE_NAME, rows) + result = call(param_client.query, f"SELECT t FROM {TABLE_NAME} ORDER BY id", query_formats={"Time": "string"}) + result_values = [row[0] for row in result.result_rows] + assert result_values == TimeTestData.TIME_STRINGS -class TestTime64Roundtrip: - """Test round-trip conversion for Time64 types.""" - @pytest.mark.parametrize( - "column,strings,deltas,ticks,type_name", - [ - ( - "t64_us", - TimeTestData.TIME64_US_STRINGS, - TimeTestData.TIME64_US_DELTAS, - TimeTestData.TIME64_US_TICKS, - "Time64", - ), - ( - "t64_ns", - TimeTestData.TIME64_NS_STRINGS, - TimeTestData.TIME64_NS_DELTAS, - TimeTestData.TIME64_NS_TICKS, - "Time64", - ), - ], - ) - def test_time64_all_formats( - self, - test_client: Client, - column: str, - strings: List[str], - deltas: List[timedelta], - ticks: List[int], - type_name: str, - ): - """Test Time64 round-trip with all supported formats.""" +def test_time_int_format(param_client, call, table_context): + """Test Time 
round-trip with integer format.""" + with table_context(TABLE_NAME, STANDARD_TIME_TABLE_SCHEMA, settings={"enable_time_time64_type": 1}): + rows = [create_test_row(i, val) for i, val in enumerate(TimeTestData.TIME_INTS)] + call(param_client.insert, TABLE_NAME, rows) + + result = call(param_client.query, f"SELECT t FROM {TABLE_NAME} ORDER BY id", query_formats={"Time": "int"}) + result_values = [row[0] for row in result.result_rows] + assert result_values == TimeTestData.TIME_INTS + + +@pytest.mark.parametrize( + "column,strings,deltas,ticks,type_name", + [ + ( + "t64_us", + TimeTestData.TIME64_US_STRINGS, + TimeTestData.TIME64_US_DELTAS, + TimeTestData.TIME64_US_TICKS, + "Time64", + ), + ( + "t64_ns", + TimeTestData.TIME64_NS_STRINGS, + TimeTestData.TIME64_NS_DELTAS, + TimeTestData.TIME64_NS_TICKS, + "Time64", + ), + ], +) +def test_time64_all_formats( + param_client, + call, + table_context, + column: str, + strings: List[str], + deltas: List[timedelta], + ticks: List[int], + type_name: str, +): + """Test Time64 round-trip with all supported formats.""" + with table_context(TABLE_NAME, STANDARD_TIME_TABLE_SCHEMA, settings={"enable_time_time64_type": 1}): rows = [] for i, string_val in enumerate(strings): row = [i, timedelta(0), None, timedelta(0), None, timedelta(0), None] @@ -208,23 +201,24 @@ def test_time64_all_formats( row[5] = string_val rows.append(row) - insert_test_data(test_client, rows) + call(param_client.insert, TABLE_NAME, rows) - result = query_column(test_client, column) - assert result == deltas + result = call(param_client.query, f"SELECT {column} FROM {TABLE_NAME} ORDER BY id") + result_values = [row[0] for row in result.result_rows] + assert result_values == deltas - result = query_column(test_client, column, **{type_name: "string"}) - assert result == strings + result = call(param_client.query, f"SELECT {column} FROM {TABLE_NAME} ORDER BY id", query_formats={type_name: "string"}) + result_values = [row[0] for row in result.result_rows] + assert 
result_values == strings - result = query_column(test_client, column, **{type_name: "int"}) - assert result == ticks + result = call(param_client.query, f"SELECT {column} FROM {TABLE_NAME} ORDER BY id", query_formats={type_name: "int"}) + result_values = [row[0] for row in result.result_rows] + assert result_values == ticks -class TestNullableColumns: - """Test nullable Time and Time64 columns.""" - - def test_nullable_time_columns(self, test_client: Client): - """Test that nullable columns handle None values correctly.""" +def test_nullable_time_columns(param_client, call, table_context): + """Test that nullable columns handle None values correctly.""" + with table_context(TABLE_NAME, STANDARD_TIME_TABLE_SCHEMA, settings={"enable_time_time64_type": 1}): rows = [ create_nullable_test_row( 0, @@ -245,163 +239,136 @@ def test_nullable_time_columns(self, test_client: Client): t64_ns_nullable="07:00:00.123000000", ), ] - insert_test_data(test_client, rows) + call(param_client.insert, TABLE_NAME, rows) - result = query_column(test_client, "t_nullable") + result = call(param_client.query, f"SELECT t_nullable FROM {TABLE_NAME} ORDER BY id") + result_values = [row[0] for row in result.result_rows] expected = [None, timedelta(hours=5)] - assert result == expected + assert result_values == expected - result = query_column(test_client, "t64_us_nullable") + result = call(param_client.query, f"SELECT t64_us_nullable FROM {TABLE_NAME} ORDER BY id") + result_values = [row[0] for row in result.result_rows] expected = [None, timedelta(hours=6, microseconds=789000)] - assert result == expected + assert result_values == expected - result = query_column(test_client, "t64_ns_nullable") + result = call(param_client.query, f"SELECT t64_ns_nullable FROM {TABLE_NAME} ORDER BY id") + result_values = [row[0] for row in result.result_rows] expected = [None, timedelta(hours=7, microseconds=123000)] - assert result == expected - - -class TestErrorHandling: - """Test error handling for invalid 
inputs.""" - - @pytest.mark.parametrize( - "invalid_value", - [ - "1000:00:00", # Out of range string - 3600000, # Out of range int - ], - ) - def test_time_out_of_range_values(self, test_client: Client, invalid_value: Any): - """Test that out-of-range Time values raise ValueError.""" + assert result_values == expected + + +@pytest.mark.parametrize( + "invalid_value", + [ + "1000:00:00", # Out of range string + 3600000, # Out of range int + ], +) +def test_time_out_of_range_values(param_client, call, table_context, invalid_value: Any): + """Test that out-of-range Time values raise ValueError.""" + with table_context(TABLE_NAME, STANDARD_TIME_TABLE_SCHEMA, settings={"enable_time_time64_type": 1}): with pytest.raises(ValueError, match="out of range"): rows = [create_test_row(0, invalid_value)] - insert_test_data(test_client, rows) - - @pytest.mark.parametrize( - "time_val,time64_val", - [ - ("1:2:3:4", "1:2:3:4"), # Too many colons - ("10:70:00", "10:70:00"), # Invalid minutes - ("10:00:00.123.456", "10:00:00.123.456"), # Invalid fractional format - ], - ) - def test_invalid_time_formats( - self, test_client: Client, time_val: str, time64_val: str - ): - """Test that invalid time formats raise ValueError.""" + call(param_client.insert, TABLE_NAME, rows) + + +@pytest.mark.parametrize( + "time_val,time64_val", + [ + ("1:2:3:4", "1:2:3:4"), # Too many colons + ("10:70:00", "10:70:00"), # Invalid minutes + ("10:00:00.123.456", "10:00:00.123.456"), # Invalid fractional format + ], +) +def test_invalid_time_formats(param_client, call, table_context, time_val: str, time64_val: str): + """Test that invalid time formats raise ValueError.""" + with table_context(TABLE_NAME, STANDARD_TIME_TABLE_SCHEMA, settings={"enable_time_time64_type": 1}): with pytest.raises(ValueError): - rows = [ - create_nullable_test_row( - 0, t=time_val, t64_us=time64_val, t64_ns=time64_val - ) - ] - insert_test_data(test_client, rows) - + rows = [create_nullable_test_row(0, t=time_val, 
t64_us=time64_val, t64_ns=time64_val)] + call(param_client.insert, TABLE_NAME, rows) -class TestMixedInputTypes: - """Test handling of mixed input types.""" - def test_timedelta_input_conversion(self, test_client: Client): - """Test conversion of timedelta inputs to internal representation.""" +def test_timedelta_input_conversion(param_client, call, table_context): + """Test conversion of timedelta inputs to internal representation.""" + with table_context(TABLE_NAME, STANDARD_TIME_TABLE_SCHEMA, settings={"enable_time_time64_type": 1}): test_deltas = TimeTestData.TIME_DELTAS[:3] rows = [create_test_row(i, td) for i, td in enumerate(test_deltas)] - insert_test_data(test_client, rows) + call(param_client.insert, TABLE_NAME, rows) - result = query_column(test_client, "t") - assert result == test_deltas - assert all(isinstance(td, timedelta) for td in result) + result = call(param_client.query, f"SELECT t FROM {TABLE_NAME} ORDER BY id") + result_values = [row[0] for row in result.result_rows] + assert result_values == test_deltas + assert all(isinstance(td, timedelta) for td in result_values) - def test_integer_input_conversion(self, test_client: Client): - """Test conversion of integer inputs to internal representation.""" + +def test_integer_input_conversion(param_client, call, table_context): + """Test conversion of integer inputs to internal representation.""" + with table_context(TABLE_NAME, STANDARD_TIME_TABLE_SCHEMA, settings={"enable_time_time64_type": 1}): test_ints = TimeTestData.TIME_INTS rows = [create_test_row(i, val) for i, val in enumerate(test_ints)] - insert_test_data(test_client, rows) + call(param_client.insert, TABLE_NAME, rows) - result = query_column(test_client, "t") + result = call(param_client.query, f"SELECT t FROM {TABLE_NAME} ORDER BY id") + result_values = [row[0] for row in result.result_rows] expected = TimeTestData.TIME_DELTAS - assert result == expected + assert result_values == expected -class ClockTimeData: - TIME_OBJS = [ - time(0, 
0, 5), - time(1, 2, 3), - time(23, 59, 59), - time(0, 0, 0), - ] - - TIME_DELTAS = [ - timedelta(seconds=5), - timedelta(hours=1, minutes=2, seconds=3), - timedelta(hours=23, minutes=59, seconds=59), - timedelta(0), - ] - - TIME64_US_OBJS = [ - time(1, 2, 3, 123456), - time(0, 0, 0, 1), - time(23, 59, 59, 999999), - ] - - TIME64_NS_OBJS = [ - time(1, 2, 3, 123456), # 123 456 789 ns -> 123 456 us - time(0, 0, 0), # 1 ns (floor) - time(23, 59, 59, 123000), # 123 000 µs - ] - - -class TestTimeDatetimeTimeRoundtrip: +def test_time_roundtrip_time_format(param_client, call, table_context): """Ensure Time columns accept & return datetime.time when format='time'.""" - - def test_roundtrip_time_format(self, test_client: Client): + with table_context(TABLE_NAME, STANDARD_TIME_TABLE_SCHEMA, settings={"enable_time_time64_type": 1}): rows = [create_test_row(i, t) for i, t in enumerate(ClockTimeData.TIME_OBJS)] - insert_test_data(test_client, rows) - result = query_column(test_client, "t", Time="time") - assert result == ClockTimeData.TIME_OBJS + call(param_client.insert, TABLE_NAME, rows) - def test_default_read_from_time_objects(self, test_client: Client): - """Writing time objects + default read still yields timedelta.""" + result = call(param_client.query, f"SELECT t FROM {TABLE_NAME} ORDER BY id", query_formats={"Time": "time"}) + result_values = [row[0] for row in result.result_rows] + assert result_values == ClockTimeData.TIME_OBJS + + +def test_time_default_read_from_time_objects(param_client, call, table_context): + """Writing time objects + default read still yields timedelta.""" + with table_context(TABLE_NAME, STANDARD_TIME_TABLE_SCHEMA, settings={"enable_time_time64_type": 1}): rows = [create_test_row(i, t) for i, t in enumerate(ClockTimeData.TIME_OBJS)] - insert_test_data(test_client, rows) + call(param_client.insert, TABLE_NAME, rows) - result = query_column(test_client, "t") - assert result == ClockTimeData.TIME_DELTAS + result = call(param_client.query, f"SELECT 
t FROM {TABLE_NAME} ORDER BY id") + result_values = [row[0] for row in result.result_rows] + assert result_values == ClockTimeData.TIME_DELTAS -class TestTime64DatetimeTimeRoundtrip: +@pytest.mark.parametrize( + "column,objects", + [ + ("t64_us", ClockTimeData.TIME64_US_OBJS), + ("t64_ns", ClockTimeData.TIME64_NS_OBJS), + ], +) +def test_time64_time_format(param_client, call, table_context, column: str, objects: List[time]): """Validate Time64(6/9) ⇄ datetime.time conversions.""" - - @pytest.mark.parametrize( - "column,objects", - [ - ("t64_us", ClockTimeData.TIME64_US_OBJS), - ("t64_ns", ClockTimeData.TIME64_NS_OBJS), - ], - ) - def test_time64_time_format( - self, test_client: Client, column: str, objects: List[time] - ): + with table_context(TABLE_NAME, STANDARD_TIME_TABLE_SCHEMA, settings={"enable_time_time64_type": 1}): rows = [create_test_row(i, t) for i, t in enumerate(objects)] - insert_test_data(test_client, rows) + call(param_client.insert, TABLE_NAME, rows) - result = query_column(test_client, column, Time64="time") - assert result == objects + result = call(param_client.query, f"SELECT {column} FROM {TABLE_NAME} ORDER BY id", query_formats={"Time64": "time"}) + result_values = [row[0] for row in result.result_rows] + assert result_values == objects -class TestTimeFormatErrorHandling: - """Errors that only show up when requesting format='time'.""" - - def test_negative_value_cannot_be_coerced_to_time(self, test_client: Client): - """Database contains -2 s; asking for format='time' should fail.""" +def test_negative_value_cannot_be_coerced_to_time(param_client, call, table_context): + """Database contains -2 s; asking for format='time' should fail.""" + with table_context(TABLE_NAME, STANDARD_TIME_TABLE_SCHEMA, settings={"enable_time_time64_type": 1}): rows = [create_test_row(0, timedelta(seconds=-2))] - insert_test_data(test_client, rows) + call(param_client.insert, TABLE_NAME, rows) with pytest.raises(ValueError, match="outside valid range"): - 
query_column(test_client, "t", Time="time") + call(param_client.query, f"SELECT t FROM {TABLE_NAME}", query_formats={"Time": "time"}) + - def test_over_24h_value_cannot_be_coerced_to_time(self, test_client: Client): - """30 h is legal for ClickHouse but illegal for datetime.time.""" +def test_over_24h_value_cannot_be_coerced_to_time(param_client, call, table_context): + """30 h is legal for ClickHouse but illegal for datetime.time.""" + with table_context(TABLE_NAME, STANDARD_TIME_TABLE_SCHEMA, settings={"enable_time_time64_type": 1}): rows = [create_test_row(0, "030:00:00")] - insert_test_data(test_client, rows) + call(param_client.insert, TABLE_NAME, rows) with pytest.raises(ValueError, match="outside valid range"): - query_column(test_client, "t", Time="time") + call(param_client.query, f"SELECT t FROM {TABLE_NAME}", query_formats={"Time": "time"}) diff --git a/tests/integration_tests/test_timezones.py b/tests/integration_tests/test_timezones.py index 97b623c8..9db02405 100644 --- a/tests/integration_tests/test_timezones.py +++ b/tests/integration_tests/test_timezones.py @@ -11,8 +11,8 @@ chicago_tz = pytz.timezone('America/Chicago').localize(datetime(2020, 8, 8, 10, 5, 5)).tzinfo -def test_basic_timezones(test_client: Client): - row = test_client.query("SELECT toDateTime('2022-10-25 10:55:22', 'America/Chicago') as chicago," + +def test_basic_timezones(param_client: Client, call): + row = call(param_client.query, "SELECT toDateTime('2022-10-25 10:55:22', 'America/Chicago') as chicago," + "toDateTime('2023-07-05 15:10:40') as utc", query_tz='America/Chicago').first_row @@ -24,8 +24,8 @@ def test_basic_timezones(test_client: Client): assert row[1].hour == 10 assert row[1].day == 5 - if test_client.min_version('20'): - row = test_client.query("SELECT toDateTime64('2022-10-25 10:55:22.789123', 6, 'America/Chicago')", + if param_client.min_version('20'): + row = call(param_client.query, "SELECT toDateTime64('2022-10-25 10:55:22.789123', 6, 'America/Chicago')", 
query_tz='America/Chicago').first_row assert row[0].tzinfo == chicago_tz assert row[0].hour == 10 @@ -33,14 +33,14 @@ def test_basic_timezones(test_client: Client): assert row[0].microsecond == 789123 -def test_server_timezone(test_client: Client): +def test_server_timezone(param_client: Client, call): # This test is really for manual testing since changing the timezone on the test ClickHouse server # still requires a restart. Other tests will depend on https://github.com/ClickHouse/ClickHouse/pull/44149 - test_client.apply_server_timezone = True + param_client.apply_server_timezone = True test_datetime = datetime(2023, 3, 18, 16, 4, 25) try: - date = test_client.query('SELECT toDateTime(%s) as st', parameters=[test_datetime]).first_row[0] - if test_client.server_tz == pytz.UTC: + date = call(param_client.query, 'SELECT toDateTime(%s) as st', parameters=[test_datetime]).first_row[0] + if param_client.server_tz == pytz.UTC: assert date.tzinfo is None assert date == datetime(2023, 3, 18, 16, 4, 25, tzinfo=None) assert date.timestamp() == 1679155465 @@ -50,15 +50,15 @@ def test_server_timezone(test_client: Client): assert date.tzinfo == den_tz assert date.timestamp() == 1679177065 finally: - test_client.apply_server_timezone = False + param_client.apply_server_timezone = False -def test_column_timezones(test_client: Client): +def test_column_timezones(param_client: Client, call): date_tz64 = "toDateTime64('2023-01-02 15:44:22.7832', 6, 'Asia/Shanghai')" - if not test_client.min_version('20'): + if not param_client.min_version('20'): date_tz64 = "toDateTime('2023-01-02 15:44:22', 'Asia/Shanghai')" column_tzs = {'chicago': 'America/Chicago', 'china': 'Asia/Shanghai'} - row = test_client.query("SELECT toDateTime('2022-10-25 10:55:22', 'America/Chicago') as chicago," + + row = call(param_client.query, "SELECT toDateTime('2022-10-25 10:55:22', 'America/Chicago') as chicago," + f'{date_tz64} as china,' + "toDateTime('2023-07-05 15:10:40') as utc", 
column_tzs=column_tzs).first_row @@ -67,27 +67,27 @@ def test_column_timezones(test_client: Client): assert row[1].tzinfo == china_tz assert row[2].tzinfo is None - if test_client.min_version('20'): - row = test_client.query("SELECT toDateTime('2022-10-25 10:55:22', 'America/Chicago') as chicago," + + if param_client.min_version('20'): + row = call(param_client.query, "SELECT toDateTime('2022-10-25 10:55:22', 'America/Chicago') as chicago," + "toDateTime64('2023-01-02 15:44:22.7832', 6, 'Asia/Shanghai') as china").first_row - if test_client.protocol_version: + if param_client.protocol_version: assert row[0].tzinfo == chicago_tz else: assert row[0].tzinfo is None assert row[1].tzinfo == china_tz # DateTime64 columns work correctly -def test_local_timezones(test_client: Client): +def test_local_timezones(param_client: Client, call): denver_tz = pytz.timezone('America/Denver') tzutil.local_tz = denver_tz - test_client.apply_server_timezone = False + param_client.apply_server_timezone = False try: - row = test_client.query("SELECT toDateTime('2022-10-25 10:55:22'," + + row = call(param_client.query, "SELECT toDateTime('2022-10-25 10:55:22'," + "'America/Chicago') as chicago," + "toDateTime('2023-07-05 15:10:40') as raw_utc_dst," + "toDateTime('2023-07-05 12:44:22', 'UTC') as forced_utc," + "toDateTime('2023-12-31 17:00:55') as raw_utc_std").first_row - if test_client.protocol_version: + if param_client.protocol_version: assert row[0].tzinfo.tzname(None) == chicago_tz.tzname(None) else: assert row[0].tzinfo.tzname(None) == denver_tz.tzname(None) @@ -96,98 +96,98 @@ def test_local_timezones(test_client: Client): assert row[3].tzinfo.tzname(None) == denver_tz.tzname(None) finally: tzutil.local_tz = pytz.UTC - test_client.apply_server_timezone = True + param_client.apply_server_timezone = True -def test_naive_timezones(test_client: Client): - row = test_client.query("SELECT toDateTime('2022-10-25 10:55:22', 'America/Chicago') as chicago," + +def 
test_naive_timezones(param_client: Client, call): + row = call(param_client.query, "SELECT toDateTime('2022-10-25 10:55:22', 'America/Chicago') as chicago," + "toDateTime('2023-07-05 15:10:40') as utc").first_row - if test_client.protocol_version: + if param_client.protocol_version: assert row[0].tzinfo == chicago_tz else: assert row[0].tzinfo is None assert row[1].tzinfo is None -def test_timezone_binding_client(test_client: Client): +def test_timezone_binding_client(param_client: Client, call): os.environ['TZ'] = 'America/Denver' time.tzset() denver_tz = pytz.timezone('America/Denver') tzutil.local_tz = denver_tz - test_client.apply_server_timezone = False + param_client.apply_server_timezone = False denver_time = datetime(2023, 3, 18, 16, 4, 25, tzinfo=denver_tz) try: - server_time = test_client.query( + server_time = call(param_client.query, 'SELECT toDateTime(%(dt)s) as dt', parameters={'dt': denver_time}).first_row[0] assert server_time == denver_time finally: os.environ['TZ'] = 'UTC' tzutil.local_tz = pytz.UTC time.tzset() - test_client.apply_server_timezone = True + param_client.apply_server_timezone = True naive_time = datetime(2023, 3, 18, 16, 4, 25) - server_time = test_client.query( + server_time = call(param_client.query, 'SELECT toDateTime(%(dt)s) as dt', parameters={'dt': naive_time}).first_row[0] assert server_time.astimezone(pytz.UTC) == naive_time.astimezone(pytz.UTC) utc_time = datetime(2023, 3, 18, 16, 4, 25, tzinfo=pytz.UTC) - server_time = test_client.query( + server_time = call(param_client.query, 'SELECT toDateTime(%(dt)s) as dt', parameters={'dt': utc_time}).first_row[0] assert server_time.astimezone(pytz.UTC) == utc_time -def test_timezone_binding_server(test_client: Client): +def test_timezone_binding_server(param_client: Client, call): os.environ['TZ'] = 'America/Denver' time.tzset() denver_tz = pytz.timezone('America/Denver') tzutil.local_tz = denver_tz - test_client.apply_server_timezone = False + param_client.apply_server_timezone = 
False denver_time = datetime(2022, 3, 18, 16, 4, 25, tzinfo=denver_tz) try: - server_time = test_client.query( + server_time = call(param_client.query, 'SELECT toDateTime({dt:DateTime}) as dt', parameters={'dt': denver_time}).first_row[0] assert server_time == denver_time finally: os.environ['TZ'] = 'UTC' time.tzset() tzutil.local_tz = pytz.UTC - test_client.apply_server_timezone = True + param_client.apply_server_timezone = True naive_time = datetime(2022, 3, 18, 16, 4, 25) - server_time = test_client.query( + server_time = call(param_client.query, 'SELECT toDateTime({dt:DateTime}) as dt', parameters={'dt': naive_time}).first_row[0] assert naive_time.astimezone(pytz.UTC) == server_time.astimezone(pytz.UTC) utc_time = datetime(2020, 3, 18, 16, 4, 25, tzinfo=pytz.UTC) - server_time = test_client.query( + server_time = call(param_client.query, 'SELECT toDateTime({dt:DateTime}) as dt', parameters={'dt': utc_time}).first_row[0] assert server_time.astimezone(pytz.UTC) == utc_time -def test_utc_tz_aware(test_client: Client): - row = test_client.query("SELECT toDateTime('2023-07-05 15:10:40') as dt," + +def test_utc_tz_aware(param_client: Client, call): + row = call(param_client.query, "SELECT toDateTime('2023-07-05 15:10:40') as dt," + "toDateTime('2023-07-05 15:10:40', 'UTC') as dt_utc", query_tz='UTC').first_row assert row[0].tzinfo is None assert row[1].tzinfo is None - row = test_client.query("SELECT toDateTime('2023-07-05 15:10:40') as dt," + + row = call(param_client.query, "SELECT toDateTime('2023-07-05 15:10:40') as dt," + "toDateTime('2023-07-05 15:10:40', 'UTC') as dt_utc", query_tz='UTC', utc_tz_aware=True).first_row assert row[0].tzinfo == pytz.UTC assert row[1].tzinfo == pytz.UTC - if test_client.min_version('20'): - row = test_client.query("SELECT toDateTime64('2023-07-05 15:10:40.123456', 6) as dt64," + + if param_client.min_version('20'): + row = call(param_client.query, "SELECT toDateTime64('2023-07-05 15:10:40.123456', 6) as dt64," + 
"toDateTime64('2023-07-05 15:10:40.123456', 6, 'UTC') as dt64_utc", query_tz='UTC').first_row assert row[0].tzinfo is None assert row[1].tzinfo is None assert row[0].microsecond == 123456 - row = test_client.query("SELECT toDateTime64('2023-07-05 15:10:40.123456', 6) as dt64," + + row = call(param_client.query, "SELECT toDateTime64('2023-07-05 15:10:40.123456', 6) as dt64," + "toDateTime64('2023-07-05 15:10:40.123456', 6, 'UTC') as dt64_utc", query_tz='UTC', utc_tz_aware=True).first_row assert row[0].tzinfo == pytz.UTC diff --git a/tests/integration_tests/test_tls.py b/tests/integration_tests/test_tls.py index 378c064a..d0c7087c 100644 --- a/tests/integration_tests/test_tls.py +++ b/tests/integration_tests/test_tls.py @@ -1,9 +1,7 @@ import os import pytest -from urllib3.exceptions import SSLError -from clickhouse_connect import get_client from clickhouse_connect.driver.common import coerce_bool from clickhouse_connect.driver.exceptions import OperationalError from tests.helpers import PROJECT_ROOT_DIR @@ -13,39 +11,41 @@ host = 'server1.clickhouse.test' -def test_basic_tls(): +def test_basic_tls(client_factory, call): if not coerce_bool(os.environ.get('CLICKHOUSE_CONNECT_TEST_TLS', 'False')): pytest.skip('TLS tests not enabled') - client = get_client(interface='https', host=host, port=10843, verify=False) - assert client.command("SELECT 'insecure'") == 'insecure' - client.close_connections() + client = client_factory(interface='https', host=host, port=10843, verify=False, database='default') + assert call(client.command, "SELECT 'insecure'") == 'insecure' - client = get_client(interface='https', host=host, port=10843, ca_cert=f'{cert_dir}ca.crt') - assert client.command("SELECT 'verify_server'") == 'verify_server' - client.close_connections() + client = client_factory(interface='https', host=host, port=10843, ca_cert=f'{cert_dir}ca.crt', database='default') + assert call(client.command, "SELECT 'verify_server'") == 'verify_server' try: - 
get_client(interface='https', host='localhost', port=10843, ca_cert=f'{cert_dir}ca.crt') + client_factory(interface='https', host='localhost', port=10843, ca_cert=f'{cert_dir}ca.crt', database='default') pytest.fail('Expected TLS exception with a different hostname') except OperationalError as ex: - assert isinstance(ex.__cause__.reason, SSLError) # pylint: disable=no-member - client.close_connections() + # For sync (urllib3): ex.__cause__.reason is SSLError + # For async (aiohttp): ex.__cause__ is ClientConnectorCertificateError + assert ex.__cause__ is not None + assert 'SSL' in str(ex.__cause__) or 'certificate' in str(ex.__cause__).lower() try: - get_client(interface='https', host='localhost', port=10843) + client_factory(interface='https', host='localhost', port=10843, database='default') pytest.fail('Expected TLS exception with a self-signed cert') except OperationalError as ex: - assert isinstance(ex.__cause__.reason, SSLError) # pylint: disable=no-member + assert ex.__cause__ is not None + assert 'SSL' in str(ex.__cause__) or 'certificate' in str(ex.__cause__).lower() -def test_mutual_tls(): +def test_mutual_tls(client_factory, call): if not coerce_bool(os.environ.get('CLICKHOUSE_CONNECT_TEST_TLS', 'False')): pytest.skip('TLS tests not enabled') - client = get_client(interface='https', + client = client_factory(interface='https', username='cert_user', host=host, port=10843, ca_cert=f'{cert_dir}ca.crt', client_cert=f'{cert_dir}client.crt', - client_cert_key=f'{cert_dir}client.key') - assert client.command('SELECT user()') == 'cert_user' + client_cert_key=f'{cert_dir}client.key', + database='default') + assert call(client.command, 'SELECT user()') == 'cert_user' diff --git a/tests/integration_tests/test_tools.py b/tests/integration_tests/test_tools.py index 6ed546af..adbd0c4d 100644 --- a/tests/integration_tests/test_tools.py +++ b/tests/integration_tests/test_tools.py @@ -2,39 +2,50 @@ from typing import Callable from clickhouse_connect.driver import Client 
-from clickhouse_connect.driver.tools import insert_file -from tests.integration_tests.conftest import TestConfig +from clickhouse_connect.driver.tools import insert_file, insert_file_async -def test_csv_upload(test_client: Client, table_context: Callable): +def test_csv_upload(param_client: Client, call, table_context: Callable, client_mode): data_file = f'{Path(__file__).parent}/movies.csv.gz' with table_context('test_csv_upload', ['movie String', 'year UInt16', 'rating Decimal32(3)']): - insert_file(test_client, 'test_csv_upload', data_file, - settings={'input_format_allow_errors_ratio': .2, - 'input_format_allow_errors_num': 5}) - res = test_client.query( + # Use appropriate insert_file version based on client type + if client_mode == "async": + call(insert_file_async, param_client, 'test_csv_upload', data_file, + settings={'input_format_allow_errors_ratio': .2, + 'input_format_allow_errors_num': 5}) + else: # Sync client + insert_file(param_client, 'test_csv_upload', data_file, + settings={'input_format_allow_errors_ratio': .2, + 'input_format_allow_errors_num': 5}) + res = call(param_client.query, 'SELECT count() as count, sum(rating) as rating, max(year) as year FROM test_csv_upload').first_item assert res['count'] == 248 assert res['year'] == 2022 -def test_parquet_upload(test_config: TestConfig, test_client: Client, table_context: Callable): +def test_parquet_upload(param_client: Client, call, client_mode, table_context: Callable): data_file = f'{Path(__file__).parent}/movies.parquet' - full_table = f'{test_config.test_database}.test_parquet_upload' - with table_context(full_table, ['movie String', 'year UInt16', 'rating Float64']): - insert_file(test_client, full_table, data_file, 'Parquet', - settings={'output_format_parquet_string_as_string': 1}) - res = test_client.query( - f'SELECT count() as count, sum(rating) as rating, max(year) as year FROM {full_table}').first_item + with table_context('test_parquet_upload', ['movie String', 'year UInt16', 
'rating Float64']): + if client_mode == "async": + call(insert_file_async, param_client, 'test_parquet_upload', data_file, 'Parquet', + settings={'output_format_parquet_string_as_string': 1}) + else: + insert_file(param_client, 'test_parquet_upload', data_file, 'Parquet', + settings={'output_format_parquet_string_as_string': 1}) + res = call(param_client.query, + 'SELECT count() as count, sum(rating) as rating, max(year) as year FROM test_parquet_upload').first_item assert res['count'] == 250 assert res['year'] == 2022 -def test_json_insert(test_client: Client, table_context: Callable): +def test_json_insert(param_client: Client, call, client_mode, table_context: Callable): data_file = f'{Path(__file__).parent}/json_test.ndjson' with table_context('test_json_upload', ['key UInt16', 'flt_val Float64', 'int_val Int8']): - insert_file(test_client, 'test_json_upload', data_file, 'JSONEachRow') - res = test_client.query('SELECT * FROM test_json_upload ORDER BY key').result_rows + if client_mode == "async": + call(insert_file_async, param_client, 'test_json_upload', data_file, 'JSONEachRow') + else: + insert_file(param_client, 'test_json_upload', data_file, 'JSONEachRow') + res = call(param_client.query, 'SELECT * FROM test_json_upload ORDER BY key').result_rows assert res[1][0] == 17 assert res[1][1] == 5.3 assert res[1][2] == 121 diff --git a/tests/integration_tests/test_vector.py b/tests/integration_tests/test_vector.py index e51b801a..9077d386 100644 --- a/tests/integration_tests/test_vector.py +++ b/tests/integration_tests/test_vector.py @@ -23,7 +23,7 @@ def module_setup_and_checks(test_client: Client, test_config: TestConfig): test_client.command("SET allow_experimental_qbit_type = 1") -def test_qbit_roundtrip_float64(test_client: Client, table_context: Callable): +def test_qbit_roundtrip_float64(param_client: Client, call, table_context: Callable): """Test QBit(Float64) round-trip accuracy with fruit_animal example data""" with table_context("fruit_animal", 
["word String", "vec QBit(Float64, 5)"]): @@ -33,13 +33,13 @@ def test_qbit_roundtrip_float64(test_client: Client, table_context: Callable): ("orange", [0.93338752, 2.06571317, -0.54612565, -1.51625717, 0.69775337]), ] - test_client.insert("fruit_animal", test_data) - count = test_client.query("SELECT COUNT(*) FROM fruit_animal").result_set[0][0] + call(param_client.insert, "fruit_animal", test_data) + count = call(param_client.query, "SELECT COUNT(*) FROM fruit_animal").result_set[0][0] assert count == 3 for word, original_vec in test_data: - result = test_client.query("SELECT vec FROM fruit_animal WHERE word = %(word)s", parameters={"word": word}) + result = call(param_client.query, "SELECT vec FROM fruit_animal WHERE word = %(word)s", parameters={"word": word}) retrieved_vec = result.result_set[0][0] assert isinstance(retrieved_vec, list) @@ -47,7 +47,7 @@ def test_qbit_roundtrip_float64(test_client: Client, table_context: Callable): assert retrieved_vec == original_vec -def test_qbit_roundtrip_float32(test_client: Client, table_context: Callable): +def test_qbit_roundtrip_float32(param_client: Client, call, table_context: Callable): """Test QBit(Float32) round-trip accuracy""" with table_context("vectors_f32", ["id Int32", "vec QBit(Float32, 8)"]): @@ -57,10 +57,10 @@ def test_qbit_roundtrip_float32(test_client: Client, table_context: Callable): (3, [-1.5, -2.5, -3.5, -4.5, -5.5, -6.5, -7.5, -8.5]), ] - test_client.insert("vectors_f32", test_data) + call(param_client.insert, "vectors_f32", test_data) for id_val, original_vec in test_data: - result = test_client.query("SELECT vec FROM vectors_f32 WHERE id = %(id)s", parameters={"id": id_val}) + result = call(param_client.query, "SELECT vec FROM vectors_f32 WHERE id = %(id)s", parameters={"id": id_val}) retrieved_vec = result.result_set[0][0] assert isinstance(retrieved_vec, list) @@ -68,7 +68,7 @@ def test_qbit_roundtrip_float32(test_client: Client, table_context: Callable): assert retrieved_vec == 
pytest.approx(original_vec, rel=1e-6) -def test_qbit_roundtrip_bfloat16(test_client: Client, table_context: Callable): +def test_qbit_roundtrip_bfloat16(param_client: Client, call, table_context: Callable): """Test QBit(BFloat16) round-trip with appropriate tolerance""" with table_context("vectors_bf16", ["id Int32", "vec QBit(BFloat16, 8)"]): @@ -77,10 +77,10 @@ def test_qbit_roundtrip_bfloat16(test_client: Client, table_context: Callable): (2, [0.5, 1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5]), ] - test_client.insert("vectors_bf16", test_data) + call(param_client.insert, "vectors_bf16", test_data) for id_val, original_vec in test_data: - result = test_client.query("SELECT vec FROM vectors_bf16 WHERE id = %(id)s", parameters={"id": id_val}) + result = call(param_client.query, "SELECT vec FROM vectors_bf16 WHERE id = %(id)s", parameters={"id": id_val}) retrieved_vec = result.result_set[0][0] assert isinstance(retrieved_vec, list) @@ -88,7 +88,7 @@ def test_qbit_roundtrip_bfloat16(test_client: Client, table_context: Callable): assert retrieved_vec == pytest.approx(original_vec, rel=1e-2, abs=1e-2) -def test_qbit_distance_search(test_client: Client, table_context: Callable): +def test_qbit_distance_search(param_client: Client, call, table_context: Callable): """Test L2DistanceTransposed with different precision levels""" with table_context("fruit_animal", ["word String", "vec QBit(Float64, 5)"]): @@ -101,13 +101,13 @@ def test_qbit_distance_search(test_client: Client, table_context: Callable): ("horse", [-0.61435682, 0.4851571, 1.21091247, -0.62530446, -1.33082533]), ] - test_client.insert("fruit_animal", test_data) + call(param_client.insert, "fruit_animal", test_data) # Search for "lemon" vector lemon_vector = [-0.88693672, 1.31532824, -0.51182908, -0.99652702, 0.59907770] # Full precision search (64-bit) - full_precision = test_client.query( + full_precision = call(param_client.query, """ SELECT word, L2DistanceTransposed(vec, %(lemon)s, 64) AS distance FROM fruit_animal 
@@ -121,7 +121,7 @@ def test_qbit_distance_search(test_client: Client, table_context: Callable): assert apple_distance == pytest.approx(0.1464, abs=1e-3) # Reduced precision search - reduced_precision = test_client.query( + reduced_precision = call(param_client.query, """ SELECT word, L2DistanceTransposed(vec, %(lemon)s, 12) AS distance FROM fruit_animal @@ -134,7 +134,7 @@ def test_qbit_distance_search(test_client: Client, table_context: Callable): assert reduced_precision.result_set[0][1] > 0 -def test_qbit_batch_insert(test_client: Client, table_context: Callable): +def test_qbit_batch_insert(param_client: Client, call, table_context: Callable): """Test batch insert with multiple vectors""" dimension = 16 @@ -147,22 +147,22 @@ def test_qbit_batch_insert(test_client: Client, table_context: Callable): vector = [random.uniform(-1.0, 1.0) for _ in range(dimension)] batch_data.append((i, vector)) - test_client.insert("embeddings", batch_data) + call(param_client.insert, "embeddings", batch_data) - count = test_client.query("SELECT COUNT(*) FROM embeddings").result_set[0][0] + count = call(param_client.query, "SELECT COUNT(*) FROM embeddings").result_set[0][0] assert count == 100 # Spot check a few for test_id in [0, 50, 99]: original_vec = batch_data[test_id][1] - result = test_client.query("SELECT embedding FROM embeddings WHERE id = %(id)s", parameters={"id": test_id}) + result = call(param_client.query, "SELECT embedding FROM embeddings WHERE id = %(id)s", parameters={"id": test_id}) retrieved_vec = result.result_set[0][0] assert len(retrieved_vec) == dimension assert retrieved_vec == pytest.approx(original_vec, rel=1e-6) -def test_qbit_null_handling(test_client: Client, table_context: Callable): +def test_qbit_null_handling(param_client: Client, call, table_context: Callable): """Test QBit with NULL values using Nullable wrapper""" with table_context("nullable_vecs", ["id Int32", "vec Nullable(QBit(Float32, 4))"]): @@ -172,49 +172,49 @@ def 
test_qbit_null_handling(test_client: Client, table_context: Callable): (3, [5.0, 6.0, 7.0, 8.0]), ] - test_client.insert("nullable_vecs", test_data) + call(param_client.insert, "nullable_vecs", test_data) - result = test_client.query("SELECT id, vec FROM nullable_vecs ORDER BY id") + result = call(param_client.query, "SELECT id, vec FROM nullable_vecs ORDER BY id") assert result.result_set[0][1] == pytest.approx([1.0, 2.0, 3.0, 4.0]) assert result.result_set[1][1] is None assert result.result_set[2][1] == pytest.approx([5.0, 6.0, 7.0, 8.0]) -def test_qbit_dimension_mismatch_error(test_client: Client, table_context: Callable): +def test_qbit_dimension_mismatch_error(param_client: Client, call, table_context: Callable): """Test that inserting vectors with wrong dimensions raises an error""" with table_context("dim_test", ["id Int32", "vec QBit(Float32, 8)"]): wrong_data = [(1, [1.0, 2.0, 3.0, 4.0, 5.0])] with pytest.raises(ValueError) as exc_info: - test_client.insert("dim_test", wrong_data) + call(param_client.insert, "dim_test", wrong_data) assert "dimension mismatch" in str(exc_info.value).lower() -def test_qbit_empty_insert(test_client: Client, table_context: Callable): +def test_qbit_empty_insert(param_client: Client, call, table_context: Callable): """Test inserting an empty list (no rows)""" with table_context("empty_test", ["id Int32", "vec QBit(Float32, 4)"]): - test_client.insert("empty_test", []) - result = test_client.query("SELECT COUNT(*) FROM empty_test") + call(param_client.insert, "empty_test", []) + result = call(param_client.query, "SELECT COUNT(*) FROM empty_test") assert result.result_set[0][0] == 0 -def test_qbit_single_row(test_client: Client, table_context: Callable): +def test_qbit_single_row(param_client: Client, call, table_context: Callable): """Test inserting a single row""" with table_context("single_row", ["id Int32", "vec QBit(Float32, 4)"]): single_data = [(1, [1.0, 2.0, 3.0, 4.0])] - test_client.insert("single_row", single_data) + 
call(param_client.insert, "single_row", single_data) - result = test_client.query("SELECT id, vec FROM single_row") + result = call(param_client.query, "SELECT id, vec FROM single_row") assert len(result.result_set) == 1 assert result.result_set[0][0] == 1 assert result.result_set[0][1] == pytest.approx([1.0, 2.0, 3.0, 4.0]) -def test_qbit_special_float_values(test_client: Client, table_context: Callable): +def test_qbit_special_float_values(param_client: Client, call, table_context: Callable): """Test QBit with special float values (inf, -inf, nan)""" with table_context("special_floats", ["id Int32", "vec QBit(Float64, 4)"]): @@ -225,8 +225,8 @@ def test_qbit_special_float_values(test_client: Client, table_context: Callable) (4, [0.0, -0.0, 1.0, -1.0]), ] - test_client.insert("special_floats", test_data) - result = test_client.query("SELECT id, vec FROM special_floats ORDER BY id") + call(param_client.insert, "special_floats", test_data) + result = call(param_client.query, "SELECT id, vec FROM special_floats ORDER BY id") assert result.result_set[0][1][0] == float("inf") assert result.result_set[0][1][1] == pytest.approx(1.0) @@ -241,68 +241,69 @@ def test_qbit_special_float_values(test_client: Client, table_context: Callable) assert result.result_set[3][1][1] == -0.0 or result.result_set[3][1][1] == 0.0 -def test_qbit_edge_case_dimensions(test_client: Client, table_context: Callable): +def test_qbit_edge_case_dimensions(param_client: Client, call, table_context: Callable): """Test QBit with edge case dimensions (1, not multiple of 8)""" with table_context("dim_one", ["id Int32", "vec QBit(Float32, 1)"]): test_data = [(1, [1.0]), (2, [3.14])] - test_client.insert("dim_one", test_data) + call(param_client.insert, "dim_one", test_data) - result = test_client.query("SELECT vec FROM dim_one ORDER BY id") + result = call(param_client.query, "SELECT vec FROM dim_one ORDER BY id") assert result.result_set[0][0] == pytest.approx([1.0]) assert result.result_set[1][0] == 
pytest.approx([3.14]) with table_context("dim_five", ["id Int32", "vec QBit(Float32, 5)"]): test_data = [(1, [1.0, 2.0, 3.0, 4.0, 5.0])] - test_client.insert("dim_five", test_data) + call(param_client.insert, "dim_five", test_data) - result = test_client.query("SELECT vec FROM dim_five") + result = call(param_client.query, "SELECT vec FROM dim_five") assert result.result_set[0][0] == pytest.approx([1.0, 2.0, 3.0, 4.0, 5.0]) -def test_qbit_very_large_batch(test_client: Client, table_context: Callable): +def test_qbit_very_large_batch(param_client: Client, call, table_context: Callable): """Test inserting a very large batch of vectors (1000 rows)""" with table_context("large_batch", ["id Int32", "vec QBit(Float32, 8)"]): random.seed(1) large_batch = [(i, [random.uniform(-10, 10) for _ in range(8)]) for i in range(1000)] - test_client.insert("large_batch", large_batch) + call(param_client.insert, "large_batch", large_batch) - count = test_client.query("SELECT COUNT(*) FROM large_batch").result_set[0][0] + count = call(param_client.query, "SELECT COUNT(*) FROM large_batch").result_set[0][0] assert count == 1000 for check_id in [0, 500, 999]: original_vec = large_batch[check_id][1] - result = test_client.query("SELECT vec FROM large_batch WHERE id = %(id)s", parameters={"id": check_id}) + result = call(param_client.query, "SELECT vec FROM large_batch WHERE id = %(id)s", parameters={"id": check_id}) retrieved_vec = result.result_set[0][0] assert retrieved_vec == pytest.approx(original_vec, rel=1e-6) -def test_qbit_all_nulls(test_client: Client, table_context: Callable): +def test_qbit_all_nulls(param_client: Client, call, table_context: Callable): """Test QBit nullable column with all NULL values""" with table_context("all_nulls", ["id Int32", "vec Nullable(QBit(Float32, 4))"]): test_data = [(1, None), (2, None), (3, None)] - test_client.insert("all_nulls", test_data) + call(param_client.insert, "all_nulls", test_data) - result = test_client.query("SELECT vec FROM 
all_nulls ORDER BY id") + result = call(param_client.query, "SELECT vec FROM all_nulls ORDER BY id") assert all(row[0] is None for row in result.result_set) -def test_qbit_all_zeros(test_client: Client, table_context: Callable): +def test_qbit_all_zeros(param_client: Client, call, table_context: Callable): """Test QBit with all zero vectors""" with table_context("all_zeros", ["id Int32", "vec QBit(Float32, 4)"]): test_data = [(1, [0.0, 0.0, 0.0, 0.0]), (2, [0.0, 0.0, 0.0, 0.0])] - test_client.insert("all_zeros", test_data) + call(param_client.insert, "all_zeros", test_data) - result = test_client.query("SELECT vec FROM all_zeros ORDER BY id") + result = call(param_client.query, "SELECT vec FROM all_zeros ORDER BY id") assert result.result_set[0][0] == pytest.approx([0.0, 0.0, 0.0, 0.0]) assert result.result_set[1][0] == pytest.approx([0.0, 0.0, 0.0, 0.0]) -def test_invalid_dimension(table_context: Callable): +# pylint: disable=unused-argument +def test_invalid_dimension(param_client: Client, call, table_context: Callable): """Try creating a column with a negative dimension.""" with pytest.raises(DatabaseError): @@ -310,7 +311,8 @@ def test_invalid_dimension(table_context: Callable): pass -def test_invalid_element_type(table_context: Callable): +# pylint: disable=unused-argument +def test_invalid_element_type(param_client: Client, call, table_context: Callable): """Try creating a column with an invalid element type.""" with pytest.raises(DatabaseError): diff --git a/tests/test_requirements.txt b/tests/test_requirements.txt index 31f1f695..ef0e8bdd 100644 --- a/tests/test_requirements.txt +++ b/tests/test_requirements.txt @@ -2,6 +2,7 @@ pytz urllib3>=1.26 setuptools certifi +aiohttp>=3.8.0 sqlalchemy>=2.0,<3.0 cython==3.0.11 pyarrow @@ -10,6 +11,7 @@ pytest-asyncio pytest-mock pytest-dotenv pytest-cov +pytest-xdist numpy~=1.22.0; python_version >= '3.8' and python_version <= '3.10' numpy~=1.26.0; python_version >= '3.11' and python_version <= '3.12' numpy~=2.1.0; 
python_version >= '3.13' diff --git a/tests/unit_tests/test_asyncqueue.py b/tests/unit_tests/test_asyncqueue.py new file mode 100644 index 00000000..b675b000 --- /dev/null +++ b/tests/unit_tests/test_asyncqueue.py @@ -0,0 +1,297 @@ +import asyncio +import threading +import time + +import pytest + +from clickhouse_connect.driver.asyncqueue import EOF_SENTINEL, AsyncSyncQueue, Empty + +# pylint: disable=broad-exception-caught + + +def test_async_put_sync_get(): + """Test async producer putting items, sync consumer getting them.""" + queue = AsyncSyncQueue(maxsize=5) + items_received = [] + + async def async_producer(): + """Put items from async context.""" + for i in range(10): + await queue.async_q.put(f"item_{i}") + await queue.async_q.put(EOF_SENTINEL) + + def sync_consumer(): + """Get items from sync context.""" + while True: + item = queue.sync_q.get() + if item is EOF_SENTINEL: + break + items_received.append(item) + + async def run_test(): + consumer_thread = threading.Thread(target=sync_consumer) + consumer_thread.start() + + await async_producer() + + consumer_thread.join(timeout=5.0) + assert not consumer_thread.is_alive(), "Consumer thread hung" + + asyncio.run(run_test()) + + assert len(items_received) == 10 + assert items_received == [f"item_{i}" for i in range(10)] + + +def test_sync_put_async_get(): + """Test sync producer putting items, async consumer getting them.""" + queue = AsyncSyncQueue(maxsize=5) + items_received = [] + + def sync_producer(): + """Put items from sync context.""" + for i in range(10): + queue.sync_q.put(f"item_{i}") + queue.sync_q.put(EOF_SENTINEL) + + async def async_consumer(): + """Get items from async context.""" + while True: + item = await queue.async_q.get() + if item is EOF_SENTINEL: + break + items_received.append(item) + + async def run_test(): + producer_thread = threading.Thread(target=sync_producer) + producer_thread.start() + + await async_consumer() + + producer_thread.join(timeout=5.0) + assert not 
producer_thread.is_alive(), "Producer thread hung" + + asyncio.run(run_test()) + + assert len(items_received) == 10 + assert items_received == [f"item_{i}" for i in range(10)] + + +def test_backpressure_async_producer(): + """Test that bounded queue provides backpressure to async producer.""" + queue = AsyncSyncQueue(maxsize=3) + produced = [] + consumed = [] + + async def fast_producer(): + """Producer that tries to produce faster than consumer.""" + for i in range(10): + produced.append(f"before_put_{i}") + await queue.async_q.put(f"item_{i}") + produced.append(f"after_put_{i}") + await queue.async_q.put(EOF_SENTINEL) + + def slow_consumer(): + """Consumer that's slower than producer.""" + while True: + time.sleep(0.01) + item = queue.sync_q.get() + if item is EOF_SENTINEL: + break + consumed.append(item) + + async def run_test(): + consumer_thread = threading.Thread(target=slow_consumer) + consumer_thread.start() + + await fast_producer() + + consumer_thread.join(timeout=5.0) + assert not consumer_thread.is_alive() + + asyncio.run(run_test()) + + assert len(consumed) == 10 + assert consumed == [f"item_{i}" for i in range(10)] + + +def test_backpressure_sync_producer(): + """Test that bounded queue provides backpressure to sync producer.""" + queue = AsyncSyncQueue(maxsize=3) + produced = [] + consumed = [] + + def fast_producer(): + """Producer that tries to produce faster than consumer.""" + for i in range(10): + produced.append(f"before_put_{i}") + queue.sync_q.put(f"item_{i}") + produced.append(f"after_put_{i}") + queue.sync_q.put(EOF_SENTINEL) + + async def slow_consumer(): + """Consumer that's slower than producer.""" + while True: + await asyncio.sleep(0.01) + item = await queue.async_q.get() + if item is EOF_SENTINEL: + break + consumed.append(item) + + async def run_test(): + producer_thread = threading.Thread(target=fast_producer) + producer_thread.start() + + await slow_consumer() + + producer_thread.join(timeout=5.0) + assert not 
producer_thread.is_alive() + + asyncio.run(run_test()) + + assert len(consumed) == 10 + assert consumed == [f"item_{i}" for i in range(10)] + + +def test_shutdown_unblocks_consumer(): + """Test that shutdown() unblocks a consumer waiting on an empty queue.""" + queue = AsyncSyncQueue(maxsize=2) + consumer_unblocked = threading.Event() + + def blocking_consumer(): + """Consumer that will block waiting for items.""" + try: + item = queue.sync_q.get(timeout=2.0) + if item is EOF_SENTINEL: + consumer_unblocked.set() + except Exception: + pass + + async def run_test(): + consumer_thread = threading.Thread(target=blocking_consumer) + consumer_thread.start() + + await asyncio.sleep(0.1) + + queue.shutdown() + + consumer_thread.join(timeout=2.0) + assert consumer_unblocked.is_set(), "Consumer was not unblocked by shutdown" + + asyncio.run(run_test()) + + +def test_shutdown_unblocks_producer(): + """Test that shutdown() unblocks a producer waiting on a full queue.""" + queue = AsyncSyncQueue(maxsize=2) + producer_unblocked = threading.Event() + + async def blocking_producer(): + """Producer that will block when queue is full.""" + try: + await queue.async_q.put("item1") + await queue.async_q.put("item2") + + await asyncio.wait_for(queue.async_q.put("item3"), timeout=2.0) + except (RuntimeError, asyncio.TimeoutError): + producer_unblocked.set() + except Exception as e: + print(f"Producer caught unexpected exception: {e}") + + async def run_test(): + producer_task = asyncio.create_task(blocking_producer()) + + await asyncio.sleep(0.1) + + queue.shutdown() + + await producer_task + assert producer_unblocked.is_set(), "Producer was not unblocked by shutdown" + + asyncio.run(run_test()) + + +def test_multiple_producers_single_consumer(): + """Test multiple async producers with single sync consumer.""" + queue = AsyncSyncQueue(maxsize=10) + items_received = [] + + async def producer(producer_id, count): + """Producer that sends count items.""" + for i in range(count): + await 
queue.async_q.put(f"p{producer_id}_item{i}") + + def consumer(): + """Consumer that reads until getting 30 items (3 producers × 10 items).""" + received = 0 + while received < 30: + item = queue.sync_q.get(timeout=5.0) + items_received.append(item) + received += 1 + + async def run_test(): + consumer_thread = threading.Thread(target=consumer) + consumer_thread.start() + + await asyncio.gather(producer(0, 10), producer(1, 10), producer(2, 10)) + + consumer_thread.join(timeout=5.0) + assert not consumer_thread.is_alive() + + asyncio.run(run_test()) + + assert len(items_received) == 30 + assert len(set(items_received)) == 30 + + +def test_exception_propagation(): + """Test that exceptions can be passed through the queue.""" + queue = AsyncSyncQueue(maxsize=5) + exception_received = [] + + async def producer_with_error(): + """Producer that sends an exception.""" + await queue.async_q.put("item1") + await queue.async_q.put("item2") + await queue.async_q.put(ValueError("test error")) + await queue.async_q.put(EOF_SENTINEL) + + def consumer(): + """Consumer that should receive the exception.""" + items = [] + while True: + item = queue.sync_q.get() + if item is EOF_SENTINEL: + break + if isinstance(item, Exception): + exception_received.append(item) + else: + items.append(item) + return items + + async def run_test(): + consumer_thread = threading.Thread(target=consumer) + consumer_thread.start() + + await producer_with_error() + + consumer_thread.join(timeout=5.0) + assert not consumer_thread.is_alive() + + asyncio.run(run_test()) + + assert len(exception_received) == 1 + assert isinstance(exception_received[0], ValueError) + assert str(exception_received[0]) == "test error" + + +def test_empty_exception_on_non_blocking_get(): + """Test that non-blocking get raises Empty when queue is empty.""" + queue = AsyncSyncQueue(maxsize=5) + + with pytest.raises(Empty): + queue.sync_q.get(block=False) + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git 
a/tests/unit_tests/test_streaming_source.py b/tests/unit_tests/test_streaming_source.py new file mode 100644 index 00000000..6794a0a2 --- /dev/null +++ b/tests/unit_tests/test_streaming_source.py @@ -0,0 +1,443 @@ +import asyncio +import gzip +import time +import zlib +from unittest.mock import Mock + +import lz4.frame +import pytest +import zstandard + +from clickhouse_connect.driver.streaming import ( + StreamingInsertSource, + StreamingResponseSource, +) + + +class MockAsyncIterator: + """Mock async iterator for simulating aiohttp response content.""" + + def __init__(self, chunks): + self.chunks = chunks + self.index = 0 + + def __aiter__(self): + return self + + async def __anext__(self): + if self.index >= len(self.chunks): + raise StopAsyncIteration + chunk = self.chunks[self.index] + self.index += 1 + return chunk + + +class MockContent: + """Mock aiohttp StreamReader content.""" + + def __init__(self, chunks): + self.chunks = chunks + self.index = 0 + + async def read(self, n=-1): # pylint: disable=unused-argument + """Mock read method that returns chunks sequentially.""" + if self.index >= len(self.chunks): + return b"" + chunk = self.chunks[self.index] + self.index += 1 + return chunk + + +class MockResponse: + """Mock aiohttp ClientResponse.""" + + def __init__(self, chunks, encoding=None): + self.content = MockContent(chunks) + self.headers = {"Content-Encoding": encoding} if encoding else {} + self.status = 200 + self.closed = False + + def close(self): + self.closed = True + + +@pytest.mark.asyncio +async def test_basic_streaming_no_compression(): + """Test basic streaming without compression.""" + chunks = [b"hello ", b"world", b"!"] + response = MockResponse(chunks) + + source = StreamingResponseSource(response, encoding=None) + loop = asyncio.get_running_loop() + + await source.start_producer(loop) + + def consume(): + result = [] + for chunk in source.gen: + result.append(chunk) + return result + + result = await loop.run_in_executor(None, 
consume) + + assert result == chunks + assert b"".join(result) == b"hello world!" + + +@pytest.mark.asyncio +async def test_streaming_with_gzip_compression(): + """Test streaming with gzip decompression.""" + original_data = b"hello world! " * 1000 + compressed = gzip.compress(original_data) + chunk_size = 100 + chunks = [compressed[i : i + chunk_size] for i in range(0, len(compressed), chunk_size)] + + response = MockResponse(chunks, encoding="gzip") + source = StreamingResponseSource(response, encoding="gzip") + loop = asyncio.get_running_loop() + + await source.start_producer(loop) + + def consume(): + result = [] + for chunk in source.gen: + result.append(chunk) + return b"".join(result) + + decompressed = await loop.run_in_executor(None, consume) + + assert decompressed == original_data + + +@pytest.mark.asyncio +async def test_streaming_with_deflate_compression(): + """Test streaming with deflate decompression.""" + original_data = b"test data " * 500 + compressed = zlib.compress(original_data) + + chunks = [compressed[i : i + 50] for i in range(0, len(compressed), 50)] + + response = MockResponse(chunks, encoding="deflate") + source = StreamingResponseSource(response, encoding="deflate") + loop = asyncio.get_running_loop() + + await source.start_producer(loop) + + def consume(): + result = [] + for chunk in source.gen: + result.append(chunk) + return b"".join(result) + + decompressed = await loop.run_in_executor(None, consume) + + assert decompressed == original_data + + +@pytest.mark.asyncio +async def test_streaming_with_zstd_compression(): + """Test streaming with zstd decompression.""" + original_data = b"zstd test data " * 500 + compressor = zstandard.ZstdCompressor() + compressed = compressor.compress(original_data) + + chunks = [compressed[i : i + 50] for i in range(0, len(compressed), 50)] + + response = MockResponse(chunks, encoding="zstd") + source = StreamingResponseSource(response, encoding="zstd") + loop = asyncio.get_running_loop() + + await 
source.start_producer(loop) + + def consume(): + result = [] + for chunk in source.gen: + result.append(chunk) + return b"".join(result) + + decompressed = await loop.run_in_executor(None, consume) + + assert decompressed == original_data + + +@pytest.mark.asyncio +async def test_streaming_with_lz4_compression(): + """Test streaming with lz4 decompression.""" + original_data = b"lz4 test data " * 500 + compressed = lz4.frame.compress(original_data) + + chunks = [compressed[i : i + 50] for i in range(0, len(compressed), 50)] + + response = MockResponse(chunks, encoding="lz4") + source = StreamingResponseSource(response, encoding="lz4") + loop = asyncio.get_running_loop() + + await source.start_producer(loop) + + def consume(): + result = [] + for chunk in source.gen: + result.append(chunk) + return b"".join(result) + + decompressed = await loop.run_in_executor(None, consume) + + assert decompressed == original_data + + +@pytest.mark.asyncio +async def test_empty_stream(): + """Test streaming with empty response.""" + response = MockResponse([]) + source = StreamingResponseSource(response, encoding=None) + loop = asyncio.get_running_loop() + + await source.start_producer(loop) + + def consume(): + result = [] + for chunk in source.gen: + result.append(chunk) + return result + + result = await loop.run_in_executor(None, consume) + + assert result == [] + + +@pytest.mark.asyncio +async def test_single_large_chunk(): + """Test streaming with single large chunk.""" + large_chunk = b"x" * 1000000 + response = MockResponse([large_chunk]) + source = StreamingResponseSource(response, encoding=None) + loop = asyncio.get_running_loop() + + await source.start_producer(loop) + + def consume(): + result = [] + for chunk in source.gen: + result.append(chunk) + return result + + result = await loop.run_in_executor(None, consume) + + assert len(result) == 1 + assert result[0] == large_chunk + + +@pytest.mark.asyncio +async def test_many_small_chunks(): + """Test streaming with many 
small chunks.""" + chunks = [f"chunk{i}".encode() for i in range(1000)] + response = MockResponse(chunks) + source = StreamingResponseSource(response, encoding=None) + loop = asyncio.get_running_loop() + + await source.start_producer(loop) + + def consume(): + result = [] + for chunk in source.gen: + result.append(chunk) + return result + + result = await loop.run_in_executor(None, consume) + + assert len(result) == 1000 + assert result == chunks + + +@pytest.mark.asyncio +async def test_generator_caching(): + """Test that .gen property returns cached generator.""" + response = MockResponse([b"test"]) + source = StreamingResponseSource(response, encoding=None) + loop = asyncio.get_running_loop() + + await source.start_producer(loop) + + # Access .gen multiple times - should return same generator + gen1 = source.gen + gen2 = source.gen + + assert gen1 is gen2, "Generator should be cached" + + +@pytest.mark.asyncio +async def test_producer_error_propagation(): + """Test that producer errors are propagated to consumer.""" + + class FailingContent: + @staticmethod + async def read(n=-1): + raise ValueError("Producer error!") + + response = Mock() + response.content = FailingContent() + response.headers = {} + response.closed = False + + source = StreamingResponseSource(response, encoding=None) + loop = asyncio.get_running_loop() + + await source.start_producer(loop) + + def consume(): + try: + for _ in source.gen: + pass + except ValueError as e: + return str(e) + return "No error raised!" + + error_msg = await loop.run_in_executor(None, consume) + + assert error_msg == "Producer error!" + + +@pytest.mark.asyncio +async def test_gzip_with_incremental_decompression(): + """Test that gzip decompression works incrementally with streaming.""" + original_data = b"The quick brown fox jumps over the lazy dog. 
" * 100 + compressed = gzip.compress(original_data) + + # Split compressed data into very small chunks to force incremental decompression + chunks = [compressed[i : i + 10] for i in range(0, len(compressed), 10)] + + response = MockResponse(chunks, encoding="gzip") + source = StreamingResponseSource(response, encoding="gzip") + loop = asyncio.get_running_loop() + + await source.start_producer(loop) + + def consume(): + """Consume and verify we get multiple decompressed chunks.""" + chunks_received = [] + for chunk in source.gen: + chunks_received.append(chunk) + return chunks_received, b"".join(chunks_received) + + chunks_received, decompressed = await loop.run_in_executor(None, consume) + + assert decompressed == original_data + assert len([c for c in chunks_received if c]) > 0 + + +@pytest.mark.asyncio +async def test_backpressure_with_bounded_queue(): + """Test that bounded queue provides backpressure.""" + # Create many chunks to test backpressure + chunks = [f"chunk{i}".encode() for i in range(100)] + response = MockResponse(chunks) + + source = StreamingResponseSource(response, encoding=None) + loop = asyncio.get_running_loop() + + await source.start_producer(loop) + + # Slow consumer + def slow_consume(): + result = [] + for chunk in source.gen: + time.sleep(0.001) + result.append(chunk) + return result + + result = await loop.run_in_executor(None, slow_consume) + + # All chunks should still be received despite slow consumer + assert len(result) == 100 + assert result == chunks + + +class MockTransform: + """Mock NativeTransform.""" + + def __init__(self, chunks=None): + self.chunks = chunks or [b"chunk1", b"chunk2"] + + def build_insert(self, context): # pylint: disable=unused-argument + yield from self.chunks + + +class FailingTransform: + """Mock NativeTransform that raises error.""" + + @staticmethod + def build_insert(context): # pylint: disable=unused-argument + yield b"chunk1" + raise ValueError("Serialization error") + + +class MockContext: + """Mock 
InsertContext.""" + + +@pytest.mark.asyncio +async def test_streaming_insert_basic(): + """Test basic streaming insert (reverse bridge).""" + transform = MockTransform() + context = MockContext() + loop = asyncio.get_running_loop() + + source = StreamingInsertSource(transform, context, loop) + source.start_producer() + + chunks = [] + async for chunk in source.async_generator(): + chunks.append(chunk) + + await source.close() + + assert chunks == [b"chunk1", b"chunk2"] + + +@pytest.mark.asyncio +async def test_streaming_insert_error_propagation(): + """Test that insert producer errors are propagated to async consumer.""" + transform = FailingTransform() + context = MockContext() + loop = asyncio.get_running_loop() + + source = StreamingInsertSource(transform, context, loop) + source.start_producer() + + chunks = [] + with pytest.raises(ValueError, match="Serialization error"): + async for chunk in source.async_generator(): + chunks.append(chunk) + + await source.close() + + # Should have received first chunk before error + assert chunks == [b"chunk1"] + + +@pytest.mark.asyncio +async def test_streaming_insert_backpressure(): + """Test backpressure in streaming insert.""" + chunks = [f"chunk{i}".encode() for i in range(100)] + transform = MockTransform(chunks) + context = MockContext() + loop = asyncio.get_running_loop() + + # Small queue size to force backpressure + source = StreamingInsertSource(transform, context, loop, maxsize=2) + source.start_producer() + + received = [] + async for chunk in source.async_generator(): + received.append(chunk) + # Yield to allow producer to run (since we're in same loop/process) + await asyncio.sleep(0.001) + + await source.close() + + assert len(received) == 100 + assert received == chunks + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) From 397125457805b907bacb632ef92de674f9817743 Mon Sep 17 00:00:00 2001 From: Joe S Date: Mon, 12 Jan 2026 16:52:14 -0800 Subject: [PATCH 02/40] add pylint exceptions --- 
pyproject.toml | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/pyproject.toml b/pyproject.toml index ce7f096e..cb28f806 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -12,3 +12,13 @@ log_cli_level = "INFO" env_files = ["test.env"] asyncio_default_fixture_loop_scope = "session" addopts = "-n 4" + +[tool.pylint.messages_control] +disable = [ + "duplicate-code", # Expected duplication between async/sync implementations +] + +[tool.pylint.typecheck] +ignored-modules = [ + "brotli", # Optional compression library +] From d1f262d00cea4a1ea92ed71beed7f7e9622acbc9 Mon Sep 17 00:00:00 2001 From: Joe S Date: Mon, 12 Jan 2026 22:40:32 -0800 Subject: [PATCH 03/40] linting --- clickhouse_connect/driver/aiohttp_client.py | 2 +- clickhouse_connect/driver/streaming.py | 2 ++ pyproject.toml | 10 ---------- 3 files changed, 3 insertions(+), 11 deletions(-) diff --git a/clickhouse_connect/driver/aiohttp_client.py b/clickhouse_connect/driver/aiohttp_client.py index b8c3a8ec..aae46496 100644 --- a/clickhouse_connect/driver/aiohttp_client.py +++ b/clickhouse_connect/driver/aiohttp_client.py @@ -1,4 +1,4 @@ -# pylint: disable=too-many-lines +# pylint: disable=too-many-lines,duplicate-code,import-error import asyncio import gzip diff --git a/clickhouse_connect/driver/streaming.py b/clickhouse_connect/driver/streaming.py index b2fb19dd..99742d79 100644 --- a/clickhouse_connect/driver/streaming.py +++ b/clickhouse_connect/driver/streaming.py @@ -1,3 +1,5 @@ +# pylint: disable=import-error + import asyncio import logging import threading diff --git a/pyproject.toml b/pyproject.toml index cb28f806..ce7f096e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -12,13 +12,3 @@ log_cli_level = "INFO" env_files = ["test.env"] asyncio_default_fixture_loop_scope = "session" addopts = "-n 4" - -[tool.pylint.messages_control] -disable = [ - "duplicate-code", # Expected duplication between async/sync implementations -] - -[tool.pylint.typecheck] -ignored-modules = [ - "brotli", # 
Optional compression library -] From 83ad8866c8e9f3a41f484bf1617a9351f06266f5 Mon Sep 17 00:00:00 2001 From: Joe S Date: Mon, 12 Jan 2026 23:04:55 -0800 Subject: [PATCH 04/40] improve test config setup --- tests/integration_tests/conftest.py | 15 +++ tests/integration_tests/test_arrow.py | 6 +- .../integration_tests/test_async_features.py | 95 +++---------------- .../test_form_encode_query.py | 76 +++------------ .../integration_tests/test_multithreading.py | 93 ++++++++++++------ 5 files changed, 109 insertions(+), 176 deletions(-) diff --git a/tests/integration_tests/conftest.py b/tests/integration_tests/conftest.py index d972f4ed..2219f12d 100644 --- a/tests/integration_tests/conftest.py +++ b/tests/integration_tests/conftest.py @@ -1,3 +1,5 @@ +# pylint: disable=duplicate-code + import asyncio import sys import os @@ -37,6 +39,19 @@ class TestException(BaseException): pass +def make_client_config(test_config: TestConfig, **kwargs): + """Helper to build client config dict from test_config with optional overrides.""" + return { + "host": test_config.host, + "port": test_config.port, + "username": test_config.username, + "password": test_config.password, + "database": test_config.test_database, + "compress": test_config.compress, + **kwargs, + } + + # pylint: disable=redefined-outer-name @fixture(scope='session', autouse=True, name='test_config') diff --git a/tests/integration_tests/test_arrow.py b/tests/integration_tests/test_arrow.py index cb4646e8..25387ed2 100644 --- a/tests/integration_tests/test_arrow.py +++ b/tests/integration_tests/test_arrow.py @@ -8,7 +8,7 @@ from clickhouse_connect.driver.options import arrow -def test_arrow(param_client, call, table_context: Callable): +def test_arrow(param_client: Client, call, table_context: Callable): if not arrow: pytest.skip('PyArrow package not available') if not param_client.min_version('21'): @@ -37,7 +37,7 @@ def test_arrow(param_client, call, table_context: Callable): assert arrow_table.num_rows == 500 -def 
test_arrow_stream(param_client, call, table_context, consume_stream): +def test_arrow_stream(param_client: Client, call, table_context, consume_stream): if not arrow: pytest.skip('PyArrow package not available') if not param_client.min_version('21'): @@ -72,7 +72,7 @@ def process(table): assert total_rows == 1000000 -def test_arrow_map(param_client, call, table_context: Callable): +def test_arrow_map(param_client: Client, call, table_context: Callable): if not arrow: pytest.skip('PyArrow package not available') if not param_client.min_version('21'): diff --git a/tests/integration_tests/test_async_features.py b/tests/integration_tests/test_async_features.py index ce32b8c2..abc70b97 100644 --- a/tests/integration_tests/test_async_features.py +++ b/tests/integration_tests/test_async_features.py @@ -6,6 +6,7 @@ from clickhouse_connect import get_async_client from clickhouse_connect.driver.exceptions import OperationalError, ProgrammingError +from tests.integration_tests.conftest import make_client_config # pylint: disable=protected-access @@ -13,14 +14,7 @@ @pytest.mark.asyncio async def test_concurrent_queries(test_config): """Verify multiple queries execute concurrently (not sequentially).""" - async with await get_async_client( - host=test_config.host, - port=test_config.port, - username=test_config.username, - password=test_config.password, - database=test_config.test_database, - autogenerate_session_id=False, - ) as client: + async with await get_async_client(**make_client_config(test_config, autogenerate_session_id=False)) as client: queries = [client.query(f"SELECT {i}, sleep(0.1)") for i in range(10)] start = time.time() @@ -39,13 +33,7 @@ async def test_concurrent_queries(test_config): @pytest.mark.asyncio async def test_stream_cancellation(test_config): """Test that early exit from async iteration doesn't leak resources.""" - async with await get_async_client( - host=test_config.host, - port=test_config.port, - username=test_config.username, - 
password=test_config.password, - database=test_config.test_database, - ) as client: + async with await get_async_client(**make_client_config(test_config)) as client: stream = await client.query_rows_stream("SELECT number FROM numbers(100000)", settings={"max_block_size": 1000}) count = 0 @@ -64,14 +52,7 @@ async def test_stream_cancellation(test_config): @pytest.mark.asyncio async def test_concurrent_streams(test_config): """Verify multiple streams can run in parallel.""" - async with await get_async_client( - host=test_config.host, - port=test_config.port, - username=test_config.username, - password=test_config.password, - database=test_config.test_database, - autogenerate_session_id=False, - ) as client: + async with await get_async_client(**make_client_config(test_config, autogenerate_session_id=False)) as client: async def consume_stream(stream_id: int): stream = await client.query_rows_stream( @@ -95,13 +76,7 @@ async def consume_stream(stream_id: int): @pytest.mark.asyncio async def test_context_manager_cleanup(test_config): """Test proper resource cleanup on context manager exit.""" - client = await get_async_client( - host=test_config.host, - port=test_config.port, - username=test_config.username, - password=test_config.password, - database=test_config.test_database, - ) + client = await get_async_client(**make_client_config(test_config)) assert client._initialized is True assert client._session is not None @@ -119,14 +94,7 @@ async def test_context_manager_cleanup(test_config): @pytest.mark.asyncio async def test_session_concurrency_protection(test_config): """Test that concurrent queries in the same session are blocked.""" - async with await get_async_client( - host=test_config.host, - port=test_config.port, - username=test_config.username, - password=test_config.password, - database=test_config.test_database, - session_id="test_concurrent_session", - ) as client: + async with await get_async_client(**make_client_config(test_config, 
session_id="test_concurrent_session")) as client: async def long_query(): return await client.query("SELECT sleep(0.5), 1") @@ -145,11 +113,7 @@ async def quick_query(): async def test_timeout_handling(test_config): """Test that async timeout exceptions propagate correctly.""" async with await get_async_client( - host=test_config.host, - port=test_config.port, - username=test_config.username, - password=test_config.password, - database=test_config.test_database, + **make_client_config(test_config), send_receive_timeout=1, # 1 second timeout autogenerate_session_id=False, # No session to avoid session locking after timeout ) as client: @@ -166,11 +130,7 @@ async def test_timeout_handling(test_config): async def test_connection_pool_reuse(test_config): """Verify connection pooling works correctly under load.""" async with await get_async_client( - host=test_config.host, - port=test_config.port, - username=test_config.username, - password=test_config.password, - database=test_config.test_database, + **make_client_config(test_config), connector_limit=10, # Limit pool size connector_limit_per_host=5, autogenerate_session_id=False, @@ -193,14 +153,7 @@ async def test_connection_pool_reuse(test_config): async def test_concurrent_inserts(test_config, table_context: Callable): """Test multiple inserts can run in parallel.""" with table_context("test_concurrent_inserts", ["id UInt32", "value String"]) as ctx: - async with await get_async_client( - host=test_config.host, - port=test_config.port, - username=test_config.username, - password=test_config.password, - database=test_config.test_database, - autogenerate_session_id=False, - ) as client: + async with await get_async_client(**make_client_config(test_config, autogenerate_session_id=False)) as client: async def insert_batch(start_id: int, count: int): data = [[start_id + i, f"value_{start_id + i}"] for i in range(count)] @@ -221,14 +174,7 @@ async def insert_batch(start_id: int, count: int): @pytest.mark.asyncio async def 
test_error_isolation(test_config): """Test that one failing query doesn't break other concurrent queries.""" - async with await get_async_client( - host=test_config.host, - port=test_config.port, - username=test_config.username, - password=test_config.password, - database=test_config.test_database, - autogenerate_session_id=False, - ) as client: + async with await get_async_client(**make_client_config(test_config, autogenerate_session_id=False)) as client: async def good_query(n: int): return await client.query(f"SELECT {n}") @@ -249,19 +195,12 @@ async def bad_query(): @pytest.mark.asyncio async def test_streaming_early_termination(test_config): """Verify streaming can be terminated early without issues.""" - async with await get_async_client( - host=test_config.host, - port=test_config.port, - username=test_config.username, - password=test_config.password, - database=test_config.test_database, - autogenerate_session_id=False, # Don't use session to avoid locking - ) as client: + async with await get_async_client(**make_client_config(test_config, autogenerate_session_id=False)) as client: stream = await client.query_rows_stream("SELECT number, repeat('x', 10000) FROM numbers(100000)", settings={"max_block_size": 1000}) count = 0 async with stream: - async for row in stream: + async for _ in stream: count += 1 if count >= 1000: break # Early termination @@ -276,7 +215,7 @@ async def test_streaming_early_termination(test_config): count2 = 0 async with stream2: - async for row in stream2: + async for _ in stream2: count2 += 1 assert count2 == 100 @@ -285,13 +224,7 @@ async def test_streaming_early_termination(test_config): @pytest.mark.asyncio async def test_regular_query_streams_then_materializes(test_config): """Verify regular query() uses streaming internally but materializes result.""" - async with await get_async_client( - host=test_config.host, - port=test_config.port, - username=test_config.username, - password=test_config.password, - 
database=test_config.test_database, - ) as client: + async with await get_async_client(**make_client_config(test_config)) as client: result = await client.query("SELECT number FROM numbers(10000)") assert len(result.result_rows) == 10000 diff --git a/tests/integration_tests/test_form_encode_query.py b/tests/integration_tests/test_form_encode_query.py index 05a881d1..3b80d70f 100644 --- a/tests/integration_tests/test_form_encode_query.py +++ b/tests/integration_tests/test_form_encode_query.py @@ -1,19 +1,11 @@ from typing import Callable from clickhouse_connect.driver import Client -from tests.integration_tests.conftest import TestConfig -def test_form_encode_query_basic(client_factory, call, test_config: TestConfig, table_context: Callable): +def test_form_encode_query_basic(client_factory, call, table_context: Callable): """Test that form_encode_query sends parameters as form data""" - form_client = client_factory( - host=test_config.host, - port=test_config.port, - username=test_config.username, - password=test_config.password, - database=test_config.test_database, - form_encode_query_params=True - ) + form_client = client_factory(form_encode_query_params=True) with table_context('test_form_encode', ['id UInt32', 'name String', 'value Float64']): call(form_client.insert, 'test_form_encode', @@ -36,16 +28,9 @@ def test_form_encode_query_basic(client_factory, call, test_config: TestConfig, assert result.first_row[0] == 3 -def test_form_encode_with_arrays(client_factory, call, test_config: TestConfig, table_context: Callable): +def test_form_encode_with_arrays(client_factory, call, table_context: Callable): """Test form_encode_query with array parameters""" - form_client = client_factory( - host=test_config.host, - port=test_config.port, - username=test_config.username, - password=test_config.password, - database=test_config.test_database, - form_encode_query_params=True - ) + form_client = client_factory(form_encode_query_params=True) with 
table_context('test_form_arrays', ['id UInt32', 'tags Array(String)']): call(form_client.insert, 'test_form_arrays', @@ -68,16 +53,9 @@ def test_form_encode_with_arrays(client_factory, call, test_config: TestConfig, assert sorted([row[0] for row in result.result_rows]) == [1, 3] -def test_form_encode_raw_query(client_factory, call, test_config: TestConfig): +def test_form_encode_raw_query(client_factory, call): """Test form_encode_query with raw_query method""" - form_client = client_factory( - host=test_config.host, - port=test_config.port, - username=test_config.username, - password=test_config.password, - database=test_config.test_database, - form_encode_query_params=True - ) + form_client = client_factory(form_encode_query_params=True) result = call(form_client.raw_query, 'SELECT {a:Int32} + {b:Int32} as sum', @@ -87,25 +65,11 @@ def test_form_encode_raw_query(client_factory, call, test_config: TestConfig): assert b'30' in result -def test_form_encode_vs_regular(client_factory, param_client: Client, call, test_config: TestConfig, table_context: Callable): +def test_form_encode_vs_regular(client_factory, param_client: Client, call, table_context: Callable): """Verify that form_encode_query produces same results as regular parameter handling""" - regular_client = client_factory( - host=test_config.host, - port=test_config.port, - username=test_config.username, - password=test_config.password, - database=test_config.test_database, - form_encode_query_params=False - ) + regular_client = client_factory(form_encode_query_params=False) - form_client = client_factory( - host=test_config.host, - port=test_config.port, - username=test_config.username, - password=test_config.password, - database=test_config.test_database, - form_encode_query_params=True - ) + form_client = client_factory(form_encode_query_params=True) with table_context('test_comparison', ['id UInt32', 'text String', 'score Float64']): call(param_client.insert, 'test_comparison', @@ -121,16 +85,9 @@ def 
test_form_encode_vs_regular(client_factory, param_client: Client, call, test assert regular_result.row_count == form_result.row_count -def test_form_encode_nullable_params(client_factory, call, test_config: TestConfig): +def test_form_encode_nullable_params(client_factory, call): """Test form_encode_query with nullable parameters""" - form_client = client_factory( - host=test_config.host, - port=test_config.port, - username=test_config.username, - password=test_config.password, - database=test_config.test_database, - form_encode_query_params=True - ) + form_client = client_factory(form_encode_query_params=True) result = call(form_client.query, 'SELECT {val:Nullable(String)} IS NULL as is_null', @@ -145,16 +102,9 @@ def test_form_encode_nullable_params(client_factory, call, test_config: TestConf assert result.first_row[0] == 'test_value' -def test_form_encode_schema_probe_query(client_factory, call, test_config: TestConfig, table_context: Callable): +def test_form_encode_schema_probe_query(client_factory, call, table_context: Callable): """Test that schema-probe queries (LIMIT 0) work correctly with form_encode_query_params""" - form_client = client_factory( - host=test_config.host, - port=test_config.port, - username=test_config.username, - password=test_config.password, - database=test_config.test_database, - form_encode_query_params=True - ) + form_client = client_factory(form_encode_query_params=True) # Test with a simple LIMIT 0 query result = call(form_client.query, 'SELECT name, database, NOW() as dt FROM system.tables LIMIT 0') diff --git a/tests/integration_tests/test_multithreading.py b/tests/integration_tests/test_multithreading.py index 0806f18f..f18ccbc0 100644 --- a/tests/integration_tests/test_multithreading.py +++ b/tests/integration_tests/test_multithreading.py @@ -4,15 +4,13 @@ import pytest +from clickhouse_connect import create_client, get_async_client from clickhouse_connect.driver.exceptions import ProgrammingError -from 
tests.integration_tests.conftest import TestConfig +from tests.integration_tests.conftest import TestConfig, make_client_config -def test_sync_client_sequential_thread_access(param_client, client_mode, call, test_config: TestConfig): +def test_sync_client_sequential_thread_access(test_client, test_config: TestConfig): """Test that sync clients can handle sequential access from different threads.""" - if client_mode != "sync": - pytest.skip("Only testing sync client behavior") - if test_config.cloud: pytest.skip("Skipping threading test in ClickHouse Cloud") @@ -21,7 +19,7 @@ def test_sync_client_sequential_thread_access(param_client, client_mode, call, t def run_query(value): try: - result = param_client.command(f"SELECT {value}") + result = test_client.command(f"SELECT {value}") results.append(result) except Exception as ex: # pylint: disable=broad-exception-caught errors.append(ex) @@ -36,14 +34,13 @@ def run_query(value): assert results == [0, 1, 2] -def test_async_client_threadsafe_submission(param_client, client_mode, call, test_config: TestConfig, shared_loop): +@pytest.mark.asyncio +async def test_async_client_threadsafe_submission(test_native_async_client, test_config: TestConfig): """Test that async clients work correctly with run_coroutine_threadsafe from multiple threads.""" - if client_mode != "async": - pytest.skip("Only testing async client behavior") - if test_config.cloud: pytest.skip("Skipping threading test in ClickHouse Cloud") + loop = asyncio.get_running_loop() results = [] errors = [] lock = threading.Lock() @@ -51,8 +48,8 @@ def test_async_client_threadsafe_submission(param_client, client_mode, call, tes def run_query_threadsafe(value): try: future = asyncio.run_coroutine_threadsafe( - param_client.command(f"SELECT {value}"), - shared_loop + test_native_async_client.command(f"SELECT {value}"), + loop ) result = future.result(timeout=5) with lock: @@ -61,12 +58,11 @@ def run_query_threadsafe(value): with lock: errors.append(ex) - threads = 
[threading.Thread(target=run_query_threadsafe, args=(i,)) for i in range(3)] for thread in threads: thread.start() - call(asyncio.sleep, 2) + await asyncio.sleep(2) for thread in threads: thread.join() @@ -76,27 +72,20 @@ def run_query_threadsafe(value): assert sorted(results) == [0, 1, 2] -def test_concurrent_session_usage_detection(client_mode, call, test_config: TestConfig, client_factory, shared_loop): - """Test that ClickHouse server detects concurrent usage of the same session.""" +def test_sync_concurrent_session_usage_detection(test_config: TestConfig): + """Test that ClickHouse server detects concurrent usage of the same session (sync client).""" if test_config.cloud: pytest.skip("Skipping session concurrency test in ClickHouse Cloud") session_id = str(uuid.uuid4()) - client1 = client_factory(session_id=session_id) - client2 = client_factory(session_id=session_id) + client1 = create_client(**make_client_config(test_config, session_id=session_id)) + client2 = create_client(**make_client_config(test_config, session_id=session_id)) thrown = [] def run_query(client): try: - if client_mode == "sync": - client.command("SELECT sleep(1)") - else: - future = asyncio.run_coroutine_threadsafe( - client.command("SELECT sleep(1)"), - shared_loop - ) - future.result(timeout=5) + client.command("SELECT sleep(1)") except (ProgrammingError, Exception) as ex: # pylint: disable=broad-exception-caught thrown.append(ex) @@ -108,13 +97,59 @@ def run_query(client): for thread in threads: thread.start() - if client_mode == "async": - call(asyncio.sleep, 2) - for thread in threads: thread.join() + try: + client1.close() + client2.close() + except Exception: # pylint: disable=broad-exception-caught + pass + # At least one should fail due to concurrent session usage assert len(thrown) > 0, "Expected ClickHouse to detect concurrent session usage" assert any("concurrent" in str(ex).lower() or "session" in str(ex).lower() for ex in thrown), \ f"Expected session concurrency error, got: 
{thrown}" + + +@pytest.mark.asyncio +async def test_async_concurrent_session_usage_detection(test_config: TestConfig): + """Test that ClickHouse server detects concurrent usage of the same session (async client).""" + if test_config.cloud: + pytest.skip("Skipping session concurrency test in ClickHouse Cloud") + + session_id = str(uuid.uuid4()) + loop = asyncio.get_running_loop() + + async with await get_async_client(**make_client_config(test_config, session_id=session_id)) as client1, \ + await get_async_client(**make_client_config(test_config, session_id=session_id)) as client2: + + thrown = [] + + def run_query(client): + try: + future = asyncio.run_coroutine_threadsafe( + client.command("SELECT sleep(1)"), + loop + ) + future.result(timeout=5) + except (ProgrammingError, Exception) as ex: # pylint: disable=broad-exception-caught + thrown.append(ex) + + threads = [ + threading.Thread(target=run_query, args=(client1,)), + threading.Thread(target=run_query, args=(client2,)) + ] + + for thread in threads: + thread.start() + + await asyncio.sleep(2) + + for thread in threads: + thread.join() + + # At least one should fail due to concurrent session usage + assert len(thrown) > 0, "Expected ClickHouse to detect concurrent session usage" + assert any("concurrent" in str(ex).lower() or "session" in str(ex).lower() for ex in thrown), \ + f"Expected session concurrency error, got: {thrown}" From 4aff17d207b37440b2aa8a16c95adcaa3d2b2b70 Mon Sep 17 00:00:00 2001 From: Joe S Date: Tue, 13 Jan 2026 12:22:45 -0800 Subject: [PATCH 05/40] fix some tests --- CHANGELOG.md | 1 + tests/integration_tests/test_client.py | 17 ++++++----------- 2 files changed, 7 insertions(+), 11 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 79c47184..fb100174 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -27,6 +27,7 @@ The supported method of passing ClickHouse server settings is to prefix such arg - Fix issue with DROP table in client temp table test. 
### Improvements +- Implement a native async client. Closes [#141](https://github.com/ClickHouse/clickhouse-connect/issues/141) - Add support for QBit data type. Closes [#570](https://github.com/ClickHouse/clickhouse-connect/issues/570) - Add the ability to create table from PyArrow objects. Addresses [#588](https://github.com/ClickHouse/clickhouse-connect/issues/588) - Always generate query_id from the client side as a UUID4 if it is not explicitly set. Closes [#596](https://github.com/ClickHouse/clickhouse-connect/issues/596) diff --git a/tests/integration_tests/test_client.py b/tests/integration_tests/test_client.py index e7ae4d7d..df3e5e15 100644 --- a/tests/integration_tests/test_client.py +++ b/tests/integration_tests/test_client.py @@ -75,16 +75,11 @@ def test_none_database(param_client, call): param_client.database = old_db -def test_session_params(test_config: TestConfig): +def test_session_params(test_config: TestConfig, client_factory, call): session_id = 'TEST_SESSION_ID_' + test_config.test_database - client = create_client( - session_id=session_id, - host=test_config.host, - port=test_config.port, - username=test_config.username, - password=test_config.password) - result = client.query('SELECT number FROM system.numbers LIMIT 5', - settings={'query_id': 'test_session_params'}).result_set + client = client_factory(session_id=session_id) + result = call(client.query, 'SELECT number FROM system.numbers LIMIT 5', + settings={'query_id': 'test_session_params'}).result_set assert len(result) == 5 if client.min_version('21'): @@ -94,7 +89,7 @@ def test_session_params(test_config: TestConfig): def check_session_in_log(): max_retries = 100 for _ in range(max_retries): - result = client.query( + result = call(client.query, f"SELECT session_id, user FROM system.session_log WHERE session_id = '{session_id}' AND " + 'event_time > now() - 30').result_set @@ -109,7 +104,7 @@ def check_session_in_log(): def check_query_in_log(): max_retries = 100 for _ in 
range(max_retries): - result = client.query( + result = call(client.query, "SELECT query_id, user FROM system.query_log WHERE query_id = 'test_session_params' AND " + 'event_time > now() - 30').result_set From 7392e2bd06e4dd5a144c2dda68f9ddf1e92c2b48 Mon Sep 17 00:00:00 2001 From: Joe S Date: Tue, 13 Jan 2026 13:16:45 -0800 Subject: [PATCH 06/40] enable log tables on startup --- .docker/clickhouse/single_node/config.xml | 2 ++ .docker/clickhouse/single_node_tls/config.xml | 2 ++ 2 files changed, 4 insertions(+) diff --git a/.docker/clickhouse/single_node/config.xml b/.docker/clickhouse/single_node/config.xml index 076ed6ce..d42d2637 100644 --- a/.docker/clickhouse/single_node/config.xml +++ b/.docker/clickhouse/single_node/config.xml @@ -21,6 +21,8 @@ 1 + 1 + system query_log
diff --git a/.docker/clickhouse/single_node_tls/config.xml b/.docker/clickhouse/single_node_tls/config.xml index 4ff6cc01..bc710317 100644 --- a/.docker/clickhouse/single_node_tls/config.xml +++ b/.docker/clickhouse/single_node_tls/config.xml @@ -21,6 +21,8 @@ 1 + 1 + /etc/clickhouse-server/certs/server.crt From d96d1d503773cea521d68af53340951c13b60b40 Mon Sep 17 00:00:00 2001 From: Joe S Date: Tue, 13 Jan 2026 13:30:45 -0800 Subject: [PATCH 07/40] improve error handling --- clickhouse_connect/driver/streaming.py | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/clickhouse_connect/driver/streaming.py b/clickhouse_connect/driver/streaming.py index 99742d79..a4a32e94 100644 --- a/clickhouse_connect/driver/streaming.py +++ b/clickhouse_connect/driver/streaming.py @@ -6,6 +6,7 @@ import zlib from typing import Iterator, Optional +from aiohttp.client_exceptions import ClientPayloadError import lz4.frame import zstandard @@ -65,6 +66,17 @@ async def producer(): self._producer_completed = True except Exception as e: + # Swallowing ClientPayloadError to match sync client behavior (which swallows incomplete read errors) + # and allow partial data to be processed (which might contain the error message from ClickHouse) + if isinstance(e, ClientPayloadError): + logger.warning("Payload error while streaming response: %s", e) + try: + await self.queue.async_q.put(EOF_SENTINEL) + except RuntimeError: + pass + self._producer_completed = True + return + logger.error("Producer error while streaming response: %s", e, exc_info=True) self._producer_error = e @@ -310,4 +322,4 @@ async def close(self): except asyncio.TimeoutError: logger.warning("Insert producer did not finish within timeout") except Exception: - pass + pass \ No newline at end of file From 9019bc26b04e4abbe4084085e092a45dbab63395 Mon Sep 17 00:00:00 2001 From: Joe S Date: Tue, 13 Jan 2026 13:37:48 -0800 Subject: [PATCH 08/40] linting --- clickhouse_connect/driver/streaming.py | 2 +- 1 file 
changed, 1 insertion(+), 1 deletion(-) diff --git a/clickhouse_connect/driver/streaming.py b/clickhouse_connect/driver/streaming.py index a4a32e94..b0cde9ba 100644 --- a/clickhouse_connect/driver/streaming.py +++ b/clickhouse_connect/driver/streaming.py @@ -322,4 +322,4 @@ async def close(self): except asyncio.TimeoutError: logger.warning("Insert producer did not finish within timeout") except Exception: - pass \ No newline at end of file + pass From 759c7366120ebd9bdeac17bc50f0e28d5800e722 Mon Sep 17 00:00:00 2001 From: Joe S Date: Tue, 13 Jan 2026 14:10:28 -0800 Subject: [PATCH 09/40] fix deprecation warning in aiohttp --- clickhouse_connect/driver/aiohttp_client.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/clickhouse_connect/driver/aiohttp_client.py b/clickhouse_connect/driver/aiohttp_client.py index aae46496..822d079b 100644 --- a/clickhouse_connect/driver/aiohttp_client.py +++ b/clickhouse_connect/driver/aiohttp_client.py @@ -7,6 +7,7 @@ import logging import re import ssl +import sys import time import uuid import pytz @@ -222,10 +223,13 @@ def __init__( "limit": connector_limit, "limit_per_host": connector_limit_per_host, "keepalive_timeout": keepalive_timeout, - "enable_cleanup_closed": True, "force_close": False, "ssl": ssl_context, } + # enable_cleanup_closed is only needed for Python < 3.14 (cpython issue fixed in 3.14) + # https://github.com/python/cpython/pull/118960 + if sys.version_info < (3, 13, 4): + self._connector_kwargs["enable_cleanup_closed"] = True self._session = None self._read_format = "Native" From 4e050ab9a17eacba0f7a09834d30440f208bd736 Mon Sep 17 00:00:00 2001 From: Joe S Date: Tue, 13 Jan 2026 14:33:32 -0800 Subject: [PATCH 10/40] error handling --- clickhouse_connect/driver/streaming.py | 10 ++-------- clickhouse_connect/driver/transform.py | 4 +++- 2 files changed, 5 insertions(+), 9 deletions(-) diff --git a/clickhouse_connect/driver/streaming.py b/clickhouse_connect/driver/streaming.py index 
b0cde9ba..ce29c48c 100644 --- a/clickhouse_connect/driver/streaming.py +++ b/clickhouse_connect/driver/streaming.py @@ -66,16 +66,10 @@ async def producer(): self._producer_completed = True except Exception as e: - # Swallowing ClientPayloadError to match sync client behavior (which swallows incomplete read errors) - # and allow partial data to be processed (which might contain the error message from ClickHouse) if isinstance(e, ClientPayloadError): logger.warning("Payload error while streaming response: %s", e) - try: - await self.queue.async_q.put(EOF_SENTINEL) - except RuntimeError: - pass - self._producer_completed = True - return + # Don't send exception - let consumer discover incomplete stream naturally + return # Queue shutdown happens in finally block logger.error("Producer error while streaming response: %s", e, exc_info=True) self._producer_error = e diff --git a/clickhouse_connect/driver/transform.py b/clickhouse_connect/driver/transform.py index 9edee9e1..96865865 100644 --- a/clickhouse_connect/driver/transform.py +++ b/clickhouse_connect/driver/transform.py @@ -56,8 +56,10 @@ def get_block(): if isinstance(ex, StreamCompleteException): # We ran out of data before it was expected, this could be ClickHouse reporting an error # in the response - if source.last_message: + if source.last_message and b'Code: ' in source.last_message: raise StreamFailureError(extract_error_message(source.last_message)) from None + # If there's no ClickHouse error in the buffer, raise generic stream failure + raise StreamFailureError("Stream ended unexpectedly (connection closed by server)") from ex raise block_num += 1 return result_block From 841c5d8de30da8bc6a6ea81eb65d0af91557804f Mon Sep 17 00:00:00 2001 From: Joe S Date: Tue, 13 Jan 2026 14:33:57 -0800 Subject: [PATCH 11/40] separate sync/async stream failure tests --- tests/integration_tests/test_streaming.py | 67 +++++++++++++++-------- 1 file changed, 45 insertions(+), 22 deletions(-) diff --git 
a/tests/integration_tests/test_streaming.py b/tests/integration_tests/test_streaming.py index 43eab850..3a485156 100644 --- a/tests/integration_tests/test_streaming.py +++ b/tests/integration_tests/test_streaming.py @@ -66,44 +66,67 @@ def process(block): assert block_count > 1 -def test_stream_errors(param_client, call, client_mode, consume_stream): - query_result = call(param_client.query, 'SELECT number FROM numbers(100000)') +def test_stream_errors_sync(test_client): + query_result = test_client.query('SELECT number FROM numbers(100000)') # 1. Test accessing without context manager raises error - if client_mode == 'sync': - with pytest.raises(ProgrammingError, match="context"): - for _ in query_result.row_block_stream: - pass - else: - async def try_iter(): - async for _ in query_result.row_block_stream: - pass - with pytest.raises((ProgrammingError, TypeError)): - call(try_iter) + with pytest.raises(ProgrammingError, match="context"): + for _ in query_result.row_block_stream: + pass assert query_result.row_count == 100000 # 2. Test that previous access consumed the generator, so next access raises StreamClosedError with pytest.raises(StreamClosedError): - # Note: query_result.rows_stream creates a NEW StreamContext, but its internal generator - # (self._block_gen) was consumed by the property access in step 1. 
- consume_stream(query_result.rows_stream) + with query_result.rows_stream as stream: + for _ in stream: + pass -def test_stream_failure(param_client, call, consume_stream): +@pytest.mark.asyncio +async def test_stream_errors_async(test_native_async_client): + stream = await test_native_async_client.query_row_block_stream('SELECT number FROM numbers(100)') + async with stream: + async for _ in stream: + pass + + # Try to reuse + with pytest.raises(StreamClosedError): + async with stream: + async for _ in stream: + pass + + +def test_stream_failure_sync(test_client): query = ('SELECT toString(cityHash64(number)) FROM numbers(10000000)' + ' where intDiv(1,number-300000)>-100000000') - stream = call(param_client.query_row_block_stream, query) - blocks = 0 + stream = test_client.query_row_block_stream(query) failed = False - def process(block): # pylint: disable=unused-argument - nonlocal blocks - blocks += 1 + try: + with stream: + for _ in stream: + pass + except StreamFailureError as ex: + failed = True + assert 'division by zero' in str(ex).lower() + + assert failed + + +@pytest.mark.asyncio +async def test_stream_failure_async(test_native_async_client): + query = ('SELECT toString(cityHash64(number)) FROM numbers(10000000)' + + ' where intDiv(1,number-300000)>-100000000') + + stream = await test_native_async_client.query_row_block_stream(query) + failed = False try: - consume_stream(stream, process) + async with stream: + async for _ in stream: + pass except StreamFailureError as ex: failed = True assert 'division by zero' in str(ex).lower() From fb31ccb53f59273678ff25f469c894dce891894f Mon Sep 17 00:00:00 2001 From: Joe S Date: Tue, 13 Jan 2026 14:54:34 -0800 Subject: [PATCH 12/40] more error handling issues --- clickhouse_connect/driver/streaming.py | 5 ----- clickhouse_connect/driver/transform.py | 8 ++++++++ 2 files changed, 8 insertions(+), 5 deletions(-) diff --git a/clickhouse_connect/driver/streaming.py b/clickhouse_connect/driver/streaming.py index 
ce29c48c..4624487c 100644 --- a/clickhouse_connect/driver/streaming.py +++ b/clickhouse_connect/driver/streaming.py @@ -66,11 +66,6 @@ async def producer(): self._producer_completed = True except Exception as e: - if isinstance(e, ClientPayloadError): - logger.warning("Payload error while streaming response: %s", e) - # Don't send exception - let consumer discover incomplete stream naturally - return # Queue shutdown happens in finally block - logger.error("Producer error while streaming response: %s", e, exc_info=True) self._producer_error = e diff --git a/clickhouse_connect/driver/transform.py b/clickhouse_connect/driver/transform.py index 96865865..89153178 100644 --- a/clickhouse_connect/driver/transform.py +++ b/clickhouse_connect/driver/transform.py @@ -60,6 +60,14 @@ def get_block(): raise StreamFailureError(extract_error_message(source.last_message)) from None # If there's no ClickHouse error in the buffer, raise generic stream failure raise StreamFailureError("Stream ended unexpectedly (connection closed by server)") from ex + + # Handle async streaming errors (ClientPayloadError from aiohttp) + if ex.__class__.__name__ == 'ClientPayloadError': + # Check if ClickHouse sent an error message before closing the connection + if source.last_message and b'Code: ' in source.last_message: + raise StreamFailureError(extract_error_message(source.last_message)) from None + raise StreamFailureError("Stream failed during read (connection closed by server)") from ex + raise block_num += 1 return result_block From b61a1d22f1bcb438b23488cd1d6152cbf4e395c4 Mon Sep 17 00:00:00 2001 From: Joe S Date: Tue, 13 Jan 2026 14:56:34 -0800 Subject: [PATCH 13/40] linting --- clickhouse_connect/driver/streaming.py | 1 - 1 file changed, 1 deletion(-) diff --git a/clickhouse_connect/driver/streaming.py b/clickhouse_connect/driver/streaming.py index 4624487c..99742d79 100644 --- a/clickhouse_connect/driver/streaming.py +++ b/clickhouse_connect/driver/streaming.py @@ -6,7 +6,6 @@ import 
zlib from typing import Iterator, Optional -from aiohttp.client_exceptions import ClientPayloadError import lz4.frame import zstandard From 0f14167638823324320d1c364d673bf9b14c7ee8 Mon Sep 17 00:00:00 2001 From: Joe S Date: Tue, 13 Jan 2026 16:28:45 -0800 Subject: [PATCH 14/40] test updates and improvements --- tests/integration_tests/conftest.py | 5 +++++ tests/integration_tests/test_async_features.py | 2 +- tests/integration_tests/test_sqlalchemy/test_delete.py | 3 +++ tests/test_requirements.txt | 3 ++- 4 files changed, 11 insertions(+), 2 deletions(-) diff --git a/tests/integration_tests/conftest.py b/tests/integration_tests/conftest.py index 2219f12d..5b1e6a97 100644 --- a/tests/integration_tests/conftest.py +++ b/tests/integration_tests/conftest.py @@ -171,6 +171,11 @@ def factory(**kwargs): **kwargs, } + # Clear username/password if access_token is provided + if "access_token" in kwargs: + config["username"] = None + config["password"] = "" + if client_mode == "sync": client = create_client(**config) else: diff --git a/tests/integration_tests/test_async_features.py b/tests/integration_tests/test_async_features.py index abc70b97..8bc6bb7b 100644 --- a/tests/integration_tests/test_async_features.py +++ b/tests/integration_tests/test_async_features.py @@ -157,7 +157,7 @@ async def test_concurrent_inserts(test_config, table_context: Callable): async def insert_batch(start_id: int, count: int): data = [[start_id + i, f"value_{start_id + i}"] for i in range(count)] - await client.insert(ctx.table, data) + await client.insert(ctx.table, data, settings={"wait_for_async_insert": 1}) await asyncio.gather( insert_batch(0, 10), diff --git a/tests/integration_tests/test_sqlalchemy/test_delete.py b/tests/integration_tests/test_sqlalchemy/test_delete.py index 0bf9e189..f07f2add 100644 --- a/tests/integration_tests/test_sqlalchemy/test_delete.py +++ b/tests/integration_tests/test_sqlalchemy/test_delete.py @@ -6,6 +6,7 @@ from 
clickhouse_connect.cc_sqlalchemy.datatypes.sqltypes import String, UInt64 from clickhouse_connect.cc_sqlalchemy.ddl.tableengine import engine_map +from tests.integration_tests.test_sqlalchemy.conftest import verify_tables_ready def test_delete_with_table_object(test_engine: Engine, test_db: str, test_table_engine: str): @@ -179,6 +180,8 @@ def test_explicit_delete(test_engine: Engine, test_table_engine: str): conn.execute(db.insert(test_table).values({"id": 1, "name": "hello world"})) conn.execute(db.insert(test_table).values({"id": 2, "name": "test data"})) conn.execute(db.insert(test_table).values({"id": 3, "name": "hello test"})) + # Wait for inserts to complete in cloud environments + verify_tables_ready(conn, {"delete_explicit_test": 3}) starting = conn.execute(db.select(test_table).order_by(test_table.c.id)).fetchall() assert len(starting) == 3 assert [row.id for row in starting] == [1, 2, 3] diff --git a/tests/test_requirements.txt b/tests/test_requirements.txt index ef0e8bdd..79a85386 100644 --- a/tests/test_requirements.txt +++ b/tests/test_requirements.txt @@ -14,7 +14,8 @@ pytest-cov pytest-xdist numpy~=1.22.0; python_version >= '3.8' and python_version <= '3.10' numpy~=1.26.0; python_version >= '3.11' and python_version <= '3.12' -numpy~=2.1.0; python_version >= '3.13' +numpy~=2.1.0; python_version == '3.13' +numpy>=2.1.0; python_version >= '3.14' pandas>=2.0,<3.0 polars>=1.0 zstandard; python_version < "3.14" From 74f573739b42a134a1872097ab5e2ccbfa459af5 Mon Sep 17 00:00:00 2001 From: Joe S Date: Tue, 13 Jan 2026 16:49:39 -0800 Subject: [PATCH 15/40] test adjustments --- tests/integration_tests/conftest.py | 5 ----- tests/integration_tests/test_jwt_auth.py | 4 ++-- 2 files changed, 2 insertions(+), 7 deletions(-) diff --git a/tests/integration_tests/conftest.py b/tests/integration_tests/conftest.py index 5b1e6a97..2219f12d 100644 --- a/tests/integration_tests/conftest.py +++ b/tests/integration_tests/conftest.py @@ -171,11 +171,6 @@ def 
factory(**kwargs): **kwargs, } - # Clear username/password if access_token is provided - if "access_token" in kwargs: - config["username"] = None - config["password"] = "" - if client_mode == "sync": client = create_client(**config) else: diff --git a/tests/integration_tests/test_jwt_auth.py b/tests/integration_tests/test_jwt_auth.py index 592b613f..1b9d273f 100644 --- a/tests/integration_tests/test_jwt_auth.py +++ b/tests/integration_tests/test_jwt_auth.py @@ -23,7 +23,7 @@ def test_jwt_auth_client(test_config: TestConfig, client_factory, call): pytest.skip("Skipping JWT test in non-Cloud mode") access_token = make_access_token() - client = client_factory(access_token=access_token) + client = client_factory(username=None, password="", access_token=access_token) result = call(client.query, CHECK_CLOUD_MODE_QUERY).result_set assert result == [(True,)] @@ -34,7 +34,7 @@ def test_jwt_auth_client_set_access_token(test_config: TestConfig, client_factor pytest.skip("Skipping JWT test in non-Cloud mode") access_token = make_access_token() - client = client_factory(access_token=access_token) + client = client_factory(username=None, password="", access_token=access_token) access_token = make_access_token() client.set_access_token(access_token) From 5781c52fcd70013bec4fa6ef12501edc0473359d Mon Sep 17 00:00:00 2001 From: Joe S Date: Thu, 15 Jan 2026 11:24:33 -0800 Subject: [PATCH 16/40] increase async stream buffer size --- clickhouse_connect/driver/streaming.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/clickhouse_connect/driver/streaming.py b/clickhouse_connect/driver/streaming.py index 99742d79..f21f6a5e 100644 --- a/clickhouse_connect/driver/streaming.py +++ b/clickhouse_connect/driver/streaming.py @@ -27,7 +27,7 @@ class StreamingResponseSource(Closable): """Streaming source that feeds chunks from async producer to sync consumer.""" - READ_BUFFER_SIZE = 512 * 1024 + READ_BUFFER_SIZE = 1024 * 1024 def __init__(self, response, encoding: Optional[str] 
= None): self.response = response From 6334be8183f5b85ddf4e21ad426eeaecaeda8f5f Mon Sep 17 00:00:00 2001 From: Joe S Date: Thu, 15 Jan 2026 11:33:01 -0800 Subject: [PATCH 17/40] don't parallelize cloud tests --- .github/workflows/on_push.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/on_push.yml b/.github/workflows/on_push.yml index 8b5a014e..d9661e7c 100644 --- a/.github/workflows/on_push.yml +++ b/.github/workflows/on_push.yml @@ -234,4 +234,4 @@ jobs: CLICKHOUSE_CONNECT_TEST_PASSWORD: ${{ secrets.INTEGRATIONS_TEAM_TESTS_CLOUD_PASSWORD_SMT }} CLICKHOUSE_CONNECT_TEST_JWT_SECRET: ${{ secrets.INTEGRATIONS_TEAM_TESTS_CLOUD_JWT_DESERT_VM_43 }} SQLALCHEMY_SILENCE_UBER_WARNING: 1 - run: pytest -n 4 tests/integration_tests + run: pytest -n 1 tests/integration_tests From e060dfd369d1ea9bb8918868778a3ee7c97b3c86 Mon Sep 17 00:00:00 2001 From: Joe S Date: Thu, 15 Jan 2026 12:13:39 -0800 Subject: [PATCH 18/40] apply consistency settings to client factory fixture --- tests/integration_tests/conftest.py | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/tests/integration_tests/conftest.py b/tests/integration_tests/conftest.py index 2219f12d..99c972af 100644 --- a/tests/integration_tests/conftest.py +++ b/tests/integration_tests/conftest.py @@ -173,8 +173,28 @@ def factory(**kwargs): if client_mode == "sync": client = create_client(**config) + if client.min_version("22.8"): + client.set_client_setting("database_replicated_enforce_synchronous_settings", 1) + if client.min_version("24.8") and (client.min_version("24.12") or not test_config.cloud): + client.set_client_setting("allow_experimental_json_type", 1) + client.set_client_setting("allow_experimental_dynamic_type", 1) + client.set_client_setting("allow_experimental_variant_type", 1) + if test_config.insert_quorum: + client.set_client_setting("insert_quorum", test_config.insert_quorum) + elif test_config.cloud: + 
client.set_client_setting("select_sequential_consistency", 1) else: client = shared_loop.run_until_complete(get_async_client(**config)) + if client.min_version("22.8"): + client.set_client_setting("database_replicated_enforce_synchronous_settings", "1") + if client.min_version("24.8"): + client.set_client_setting("allow_experimental_json_type", "1") + client.set_client_setting("allow_experimental_dynamic_type", "1") + client.set_client_setting("allow_experimental_variant_type", "1") + if test_config.insert_quorum: + client.set_client_setting("insert_quorum", str(test_config.insert_quorum)) + elif test_config.cloud: + client.set_client_setting("select_sequential_consistency", "1") clients.append(client) return client From 590fd12f386eb1368b71bfcb0fbeee2486571365 Mon Sep 17 00:00:00 2001 From: Joe S Date: Thu, 15 Jan 2026 12:13:55 -0800 Subject: [PATCH 19/40] up cloud par to 4 --- .github/workflows/on_push.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/on_push.yml b/.github/workflows/on_push.yml index d9661e7c..8b5a014e 100644 --- a/.github/workflows/on_push.yml +++ b/.github/workflows/on_push.yml @@ -234,4 +234,4 @@ jobs: CLICKHOUSE_CONNECT_TEST_PASSWORD: ${{ secrets.INTEGRATIONS_TEAM_TESTS_CLOUD_PASSWORD_SMT }} CLICKHOUSE_CONNECT_TEST_JWT_SECRET: ${{ secrets.INTEGRATIONS_TEAM_TESTS_CLOUD_JWT_DESERT_VM_43 }} SQLALCHEMY_SILENCE_UBER_WARNING: 1 - run: pytest -n 1 tests/integration_tests + run: pytest -n 4 tests/integration_tests From 1b2578ecc302335d52e31af1efa135c7be477b06 Mon Sep 17 00:00:00 2001 From: Joe S Date: Thu, 15 Jan 2026 12:51:47 -0800 Subject: [PATCH 20/40] enforce sequential consistency in helper --- tests/integration_tests/conftest.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/tests/integration_tests/conftest.py b/tests/integration_tests/conftest.py index 99c972af..f3498e84 100644 --- a/tests/integration_tests/conftest.py +++ b/tests/integration_tests/conftest.py @@ -41,6 +41,12 @@ class 
TestException(BaseException): def make_client_config(test_config: TestConfig, **kwargs): """Helper to build client config dict from test_config with optional overrides.""" + settings = kwargs.pop("settings", {}).copy() + if test_config.insert_quorum: + settings["insert_quorum"] = test_config.insert_quorum + elif test_config.cloud: + settings["select_sequential_consistency"] = 1 + return { "host": test_config.host, "port": test_config.port, @@ -48,6 +54,7 @@ def make_client_config(test_config: TestConfig, **kwargs): "password": test_config.password, "database": test_config.test_database, "compress": test_config.compress, + "settings": settings, **kwargs, } From 08abfee3ae3d429dea013dc90e7cccafd8788123 Mon Sep 17 00:00:00 2001 From: Joe S Date: Thu, 15 Jan 2026 14:02:53 -0800 Subject: [PATCH 21/40] update enable_cleanup_closed vers --- clickhouse_connect/driver/aiohttp_client.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/clickhouse_connect/driver/aiohttp_client.py b/clickhouse_connect/driver/aiohttp_client.py index 822d079b..97a24f8d 100644 --- a/clickhouse_connect/driver/aiohttp_client.py +++ b/clickhouse_connect/driver/aiohttp_client.py @@ -226,9 +226,10 @@ def __init__( "force_close": False, "ssl": ssl_context, } - # enable_cleanup_closed is only needed for Python < 3.14 (cpython issue fixed in 3.14) + # enable_cleanup_closed is only needed for Python < 3.12.7 or == 3.13.0 + # The underlying SSL connection leak was fixed in 3.12.7 and 3.13.1+ # https://github.com/python/cpython/pull/118960 - if sys.version_info < (3, 13, 4): + if sys.version_info < (3, 12, 7) or sys.version_info[:3] == (3, 13, 0): self._connector_kwargs["enable_cleanup_closed"] = True self._session = None From 85f086020d647efb663316455adeee6e852934c3 Mon Sep 17 00:00:00 2001 From: Joe S Date: Thu, 15 Jan 2026 14:03:37 -0800 Subject: [PATCH 22/40] add conn abort to wrong port assertion --- tests/integration_tests/test_error_handling.py | 1 + 1 file changed, 1 
insertion(+) diff --git a/tests/integration_tests/test_error_handling.py b/tests/integration_tests/test_error_handling.py index 4f004d21..2509ca11 100644 --- a/tests/integration_tests/test_error_handling.py +++ b/tests/integration_tests/test_error_handling.py @@ -51,6 +51,7 @@ def test_connection_refused_error(client_factory, test_config: TestConfig, caplo "Connection refused" in error_message or "Failed to establish a new connection" in error_message or "Cannot connect to host" in error_message + or "Connection aborted" in error_message # Port occasionally occupied in CI, apparently ) finally: # Restore the original logging level From 1f1650018daf4c8911b8f44a01d33c8fabd6386f Mon Sep 17 00:00:00 2001 From: Joe S Date: Thu, 15 Jan 2026 15:24:37 -0800 Subject: [PATCH 23/40] small test refactor --- clickhouse_connect/driver/transform.py | 2 ++ tests/integration_tests/test_streaming.py | 20 ++++++++------------ 2 files changed, 10 insertions(+), 12 deletions(-) diff --git a/clickhouse_connect/driver/transform.py b/clickhouse_connect/driver/transform.py index 89153178..8f3a24f6 100644 --- a/clickhouse_connect/driver/transform.py +++ b/clickhouse_connect/driver/transform.py @@ -62,6 +62,8 @@ def get_block(): raise StreamFailureError("Stream ended unexpectedly (connection closed by server)") from ex # Handle async streaming errors (ClientPayloadError from aiohttp) + # Note: Error message extraction is best-effort here. 
aiohttp's HTTP parser may detect + # protocol errors before the error message is buffered, due to streaming race conditions if ex.__class__.__name__ == 'ClientPayloadError': # Check if ClickHouse sent an error message before closing the connection if source.last_message and b'Code: ' in source.last_message: diff --git a/tests/integration_tests/test_streaming.py b/tests/integration_tests/test_streaming.py index 3a485156..654b83d4 100644 --- a/tests/integration_tests/test_streaming.py +++ b/tests/integration_tests/test_streaming.py @@ -102,17 +102,15 @@ def test_stream_failure_sync(test_client): ' where intDiv(1,number-300000)>-100000000') stream = test_client.query_row_block_stream(query) - failed = False - try: + with pytest.raises(StreamFailureError) as excinfo: with stream: for _ in stream: pass - except StreamFailureError as ex: - failed = True - assert 'division by zero' in str(ex).lower() - assert failed + error_msg = str(excinfo.value).lower() + # Race condition: may get actual ClickHouse error or generic connection closed + assert 'division by zero' in error_msg or 'connection closed' in error_msg @pytest.mark.asyncio @@ -121,17 +119,15 @@ async def test_stream_failure_async(test_native_async_client): ' where intDiv(1,number-300000)>-100000000') stream = await test_native_async_client.query_row_block_stream(query) - failed = False - try: + with pytest.raises(StreamFailureError) as excinfo: async with stream: async for _ in stream: pass - except StreamFailureError as ex: - failed = True - assert 'division by zero' in str(ex).lower() - assert failed + error_msg = str(excinfo.value).lower() + # Race condition: may get actual ClickHouse error or generic connection closed + assert 'division by zero' in error_msg or 'connection closed' in error_msg def test_raw_stream(param_client, call, consume_stream): From e95f61b988495610fdd423a29eb89e5755d2f3c8 Mon Sep 17 00:00:00 2001 From: Joe S Date: Thu, 15 Jan 2026 15:56:14 -0800 Subject: [PATCH 24/40] more consistency 
--- tests/integration_tests/test_sqlalchemy/conftest.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/tests/integration_tests/test_sqlalchemy/conftest.py b/tests/integration_tests/test_sqlalchemy/conftest.py index 7cf76abe..3fbb180d 100644 --- a/tests/integration_tests/test_sqlalchemy/conftest.py +++ b/tests/integration_tests/test_sqlalchemy/conftest.py @@ -11,11 +11,17 @@ @fixture(scope='module', name='test_engine') def test_engine_fixture(test_config: TestConfig) -> Iterator[Engine]: - test_engine: Engine = create_engine( - f'clickhousedb://{test_config.username}:{test_config.password}@{test_config.host}:' + - f'{test_config.port}/{test_config.test_database}?ch_http_max_field_name_size=99999' + + conn_str = ( + f'clickhousedb://{test_config.username}:{test_config.password}@{test_config.host}:' + f'{test_config.port}/{test_config.test_database}?ch_http_max_field_name_size=99999' '&use_skip_indexes=0&ca_cert=certifi&query_limit=2333&compression=zstd' ) + if test_config.cloud: + conn_str += '&select_sequential_consistency=1' + if test_config.insert_quorum: + conn_str += f'&insert_quorum={test_config.insert_quorum}' + + test_engine: Engine = create_engine(conn_str) yield test_engine test_engine.dispose() From ceb672aa607f02ddc17222046bcc2ccadcbebdd4 Mon Sep 17 00:00:00 2001 From: Joe S Date: Fri, 16 Jan 2026 15:30:20 -0800 Subject: [PATCH 25/40] add deadlock check in async queue --- clickhouse_connect/driver/aiohttp_client.py | 11 ++++------- clickhouse_connect/driver/asyncqueue.py | 21 +++++++++++++++++++++ tests/integration_tests/test_native_fuzz.py | 20 ++++++++++++++++---- 3 files changed, 41 insertions(+), 11 deletions(-) diff --git a/clickhouse_connect/driver/aiohttp_client.py b/clickhouse_connect/driver/aiohttp_client.py index 97a24f8d..d3d5a656 100644 --- a/clickhouse_connect/driver/aiohttp_client.py +++ b/clickhouse_connect/driver/aiohttp_client.py @@ -545,18 +545,15 @@ def parse_streaming(): 
context.set_response_tz(self._check_tz_change(tz_header)) result = self._transform.parse_response(byte_source, context) - # CRITICAL: For non-streaming queries, force full materialization while still in executor thread. - # This prevents the event loop from ever calling blocking queue.sync_q.get() operations - # which would deadlock the entire event loop when backpressure occurs + # For Pandas/Numpy, we must materialize in the executor because the resulting objects + # (DataFrame, Array) are fully in-memory structures. + # For standard queries, we return a lazy QueryResult. Accessing .result_set on the event loop + # will raise a ProgrammingError (deadlock check), encouraging usage of .rows_stream. if not context.streaming: if context.as_pandas and hasattr(result, 'df_result'): _ = result.df_result elif context.use_numpy and hasattr(result, 'np_result'): _ = result.np_result - elif hasattr(result, 'result_set'): - # Materialize rows (closes the stream) - # Avoid pre-populating result_columns. User can access later if needed - _ = result.result_set return result diff --git a/clickhouse_connect/driver/asyncqueue.py b/clickhouse_connect/driver/asyncqueue.py index 9cd84b49..e052a16f 100644 --- a/clickhouse_connect/driver/asyncqueue.py +++ b/clickhouse_connect/driver/asyncqueue.py @@ -3,6 +3,8 @@ from collections import deque from typing import Deque, Generic, Optional, TypeVar +from clickhouse_connect.driver.exceptions import ProgrammingError + __all__ = ["AsyncSyncQueue", "Empty", "Full", "EOF_SENTINEL"] T = TypeVar("T") @@ -39,6 +41,22 @@ def _bind_loop(self): except RuntimeError: pass + def _check_deadlock(self): + """Check if blocking would cause a deadlock on the event loop.""" + if self._loop is None: + return + + try: + current_loop = asyncio.get_running_loop() + if current_loop is self._loop: + raise ProgrammingError( + "Deadlock detected: Synchronous blocking operation called on event loop thread. 
" + "This usually happens when iterating a stream synchronously (e.g., 'for row in result') " + "instead of asynchronously ('async for row in result') inside an async function." + ) + except RuntimeError: + pass + def _wakeup_async_waiter(self, waiter_queue: Deque[asyncio.Future]): """Helper: Wake up the next async waiter in the queue safely.""" while waiter_queue: @@ -82,6 +100,7 @@ def get(self, block: bool = True, timeout: Optional[float] = None) -> T: if not block: raise Empty() + self._p._check_deadlock() if not self._p._sync_not_empty.wait(timeout): raise Empty() @@ -102,6 +121,8 @@ def put(self, item: T, block: bool = True, timeout: Optional[float] = None) -> N while self._p._maxsize > 0 and len(self._p._queue) >= self._p._maxsize: if not block: raise Full() + + self._p._check_deadlock() if not self._p._sync_not_full.wait(timeout): raise Full() if self._p._shutdown: diff --git a/tests/integration_tests/test_native_fuzz.py b/tests/integration_tests/test_native_fuzz.py index 0d83e4c8..191e74d9 100644 --- a/tests/integration_tests/test_native_fuzz.py +++ b/tests/integration_tests/test_native_fuzz.py @@ -1,3 +1,4 @@ +import asyncio import os import random @@ -13,7 +14,7 @@ # pylint: disable=duplicate-code -def test_query_fuzz(param_client: Client, call, test_table_engine: str): +def test_query_fuzz(param_client: Client, call, test_table_engine: str, client_mode: str): if not param_client.min_version('21'): pytest.skip(f'flatten_nested setting not supported in this server version {param_client.server_version}') test_runs = int(os.environ.get('CLICKHOUSE_CONNECT_TEST_FUZZ', '250')) @@ -32,9 +33,20 @@ def test_query_fuzz(param_client: Client, call, test_table_engine: str): call(param_client.command, create_stmt, settings={'flatten_nested': 0}) call(param_client.insert, 'fuzz_test', data, col_names) - data_result = call(param_client.query, 'SELECT * FROM fuzz_test') + if client_mode == 'async': + async def get_results(): + result = await param_client.query('SELECT * 
FROM fuzz_test') + loop = asyncio.get_running_loop() + rows = await loop.run_in_executor(None, lambda: list(result.result_set)) + return rows, result.column_names + result_rows, result_cols = call(get_results) + else: + data_result = call(param_client.query, 'SELECT * FROM fuzz_test') + result_rows = data_result.result_set + result_cols = data_result.column_names + if data_rows: - assert data_result.column_names == col_names - assert data_result.result_set == data + assert result_cols == col_names + assert result_rows == data finally: param_client.apply_server_timezone = False From 038dfb5b59618d33757a09082a6e4898b68ae3a1 Mon Sep 17 00:00:00 2001 From: Joe S Date: Fri, 16 Jan 2026 22:43:50 -0800 Subject: [PATCH 26/40] enforce consistent use of async iter stream consumption --- clickhouse_connect/driver/aiohttp_client.py | 15 +++++++++++---- tests/integration_tests/test_pandas_compat.py | 19 ++++++++++++++++--- 2 files changed, 27 insertions(+), 7 deletions(-) diff --git a/clickhouse_connect/driver/aiohttp_client.py b/clickhouse_connect/driver/aiohttp_client.py index d3d5a656..b2bdd0d0 100644 --- a/clickhouse_connect/driver/aiohttp_client.py +++ b/clickhouse_connect/driver/aiohttp_client.py @@ -310,7 +310,12 @@ async def _initialize(self, apply_server_timezone: Optional[Union[str, bool]] = readonly = common.get_setting("readonly") server_settings = await self.query(f"SELECT name, value, {readonly} as readonly FROM system.settings LIMIT 10000") - self.server_settings = {row["name"]: SettingDef(**row) for row in server_settings.named_results()} + settings_map = {} + async with server_settings.rows_stream as stream: + async for row in stream: + row_dict = dict(zip(server_settings.column_names, row)) + settings_map[row_dict["name"]] = SettingDef(**row_dict) + self.server_settings = settings_map if self.min_version(CH_VERSION_WITH_PROTOCOL) and common.get_setting("use_protocol_version"): try: @@ -1378,9 +1383,11 @@ async def create_insert_context( # type: 
ignore[override] column_defs = [] if column_types is None and column_type_names is None: describe_result = await self.query(f"DESCRIBE TABLE {full_table}", settings=settings) - column_defs = [ - ColumnDef(**row) for row in describe_result.named_results() if row["default_type"] not in ("ALIAS", "MATERIALIZED") - ] + async with describe_result.rows_stream as stream: + async for row in stream: + row_dict = dict(zip(describe_result.column_names, row)) + if row_dict["default_type"] not in ("ALIAS", "MATERIALIZED"): + column_defs.append(ColumnDef(**row_dict)) if column_names is None or isinstance(column_names, str) and column_names == "*": column_names = [cd.name for cd in column_defs] column_types = [cd.ch_type for cd in column_defs] diff --git a/tests/integration_tests/test_pandas_compat.py b/tests/integration_tests/test_pandas_compat.py index 5723d498..1782e298 100644 --- a/tests/integration_tests/test_pandas_compat.py +++ b/tests/integration_tests/test_pandas_compat.py @@ -304,7 +304,7 @@ def test_pandas_query_df_arrow(param_client: Client, call, table_context: Callab result_df = call(param_client.query_df_arrow, f"SELECT * FROM {table_name}") -def test_pandas_insert_df_arrow(param_client: Client, call, table_context: Callable): +def test_pandas_insert_df_arrow(param_client: Client, call, table_context: Callable, client_mode: str): if not arrow: pytest.skip("PyArrow package not available") @@ -323,8 +323,21 @@ def test_pandas_insert_df_arrow(param_client: Client, call, table_context: Calla if IS_PANDAS_2: df = df.convert_dtypes(dtype_backend="pyarrow") call(param_client.insert_df_arrow, table_name, df) - res_df = call(param_client.query, f"SELECT * from {table_name} ORDER BY i64") - assert res_df.result_rows == [(51, 421, "b"), (78, None, "a")] + + if client_mode == 'async': + async def get_rows(): + result = await param_client.query(f"SELECT * from {table_name} ORDER BY i64") + rows = [] + async with result.rows_stream as stream: + async for row in stream: + 
rows.append(row) + return rows + rows = call(get_rows) + else: + res_df = call(param_client.query, f"SELECT * from {table_name} ORDER BY i64") + rows = res_df.result_rows + + assert rows == [(51, 421, "b"), (78, None, "a")] else: with pytest.raises(ProgrammingError, match="pandas 2.x"): call(param_client.insert_df_arrow, table_name, df) From beaa28a79c80172e599799654c8b0cb985160a69 Mon Sep 17 00:00:00 2001 From: Joe S Date: Sat, 17 Jan 2026 12:56:07 -0800 Subject: [PATCH 27/40] linting --- tests/integration_tests/test_pandas_compat.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/integration_tests/test_pandas_compat.py b/tests/integration_tests/test_pandas_compat.py index 1782e298..4650914e 100644 --- a/tests/integration_tests/test_pandas_compat.py +++ b/tests/integration_tests/test_pandas_compat.py @@ -323,7 +323,7 @@ def test_pandas_insert_df_arrow(param_client: Client, call, table_context: Calla if IS_PANDAS_2: df = df.convert_dtypes(dtype_backend="pyarrow") call(param_client.insert_df_arrow, table_name, df) - + if client_mode == 'async': async def get_rows(): result = await param_client.query(f"SELECT * from {table_name} ORDER BY i64") @@ -336,7 +336,7 @@ async def get_rows(): else: res_df = call(param_client.query, f"SELECT * from {table_name} ORDER BY i64") rows = res_df.result_rows - + assert rows == [(51, 421, "b"), (78, None, "a")] else: with pytest.raises(ProgrammingError, match="pandas 2.x"): From 45c96c1db5c01257f4564670e847d0b4b3435c38 Mon Sep 17 00:00:00 2001 From: Joe S Date: Sat, 17 Jan 2026 22:46:34 -0800 Subject: [PATCH 28/40] DO materialize non-streaming queries --- clickhouse_connect/driver/aiohttp_client.py | 17 ++++++----------- tests/integration_tests/test_arrow.py | 6 +++--- tests/integration_tests/test_pandas_compat.py | 19 +++---------------- 3 files changed, 12 insertions(+), 30 deletions(-) diff --git a/clickhouse_connect/driver/aiohttp_client.py b/clickhouse_connect/driver/aiohttp_client.py index 
b2bdd0d0..ad9ca62f 100644 --- a/clickhouse_connect/driver/aiohttp_client.py +++ b/clickhouse_connect/driver/aiohttp_client.py @@ -310,12 +310,7 @@ async def _initialize(self, apply_server_timezone: Optional[Union[str, bool]] = readonly = common.get_setting("readonly") server_settings = await self.query(f"SELECT name, value, {readonly} as readonly FROM system.settings LIMIT 10000") - settings_map = {} - async with server_settings.rows_stream as stream: - async for row in stream: - row_dict = dict(zip(server_settings.column_names, row)) - settings_map[row_dict["name"]] = SettingDef(**row_dict) - self.server_settings = settings_map + self.server_settings = {row["name"]: SettingDef(**row) for row in server_settings.named_results()} if self.min_version(CH_VERSION_WITH_PROTOCOL) and common.get_setting("use_protocol_version"): try: @@ -559,6 +554,8 @@ def parse_streaming(): _ = result.df_result elif context.use_numpy and hasattr(result, 'np_result'): _ = result.np_result + elif isinstance(result, QueryResult): + _ = result.result_set return result @@ -1383,11 +1380,9 @@ async def create_insert_context( # type: ignore[override] column_defs = [] if column_types is None and column_type_names is None: describe_result = await self.query(f"DESCRIBE TABLE {full_table}", settings=settings) - async with describe_result.rows_stream as stream: - async for row in stream: - row_dict = dict(zip(describe_result.column_names, row)) - if row_dict["default_type"] not in ("ALIAS", "MATERIALIZED"): - column_defs.append(ColumnDef(**row_dict)) + column_defs = [ + ColumnDef(**row) for row in describe_result.named_results() if row["default_type"] not in ("ALIAS", "MATERIALIZED") + ] if column_names is None or isinstance(column_names, str) and column_names == "*": column_names = [cd.name for cd in column_defs] column_types = [cd.ch_type for cd in column_defs] diff --git a/tests/integration_tests/test_arrow.py b/tests/integration_tests/test_arrow.py index 25387ed2..676240d0 100644 --- 
a/tests/integration_tests/test_arrow.py +++ b/tests/integration_tests/test_arrow.py @@ -30,7 +30,7 @@ def test_arrow(param_client: Client, call, table_context: Callable): assert len(result_table.columns) == 2 arrow_table = call(param_client.query_arrow, 'SELECT number from system.numbers LIMIT 500', - settings={'max_block_size': 50}) + settings={'max_block_size': 50}) arrow_schema = arrow_table.schema assert arrow_schema.field(0).name == 'number' assert arrow_schema.field(0).type.id == 8 @@ -83,9 +83,9 @@ def test_arrow_map(param_client: Client, call, table_context: Callable): data = [[date(2023, 10, 15), 'C1', {'k': 2.5, 'd': 0, 'j': 0}], [date(2023, 10, 16), 'C2', {'k': 3.5, 'd': 0, 'j': -.372}]] call(param_client.insert, 'test_arrow_map', data, column_names=('trade_date', 'code', 'kdj'), - settings={'insert_deduplication_token': '10381'}) + settings={'insert_deduplication_token': '10381'}) arrow_table = call(param_client.query_arrow, 'SELECT * FROM test_arrow_map ORDER BY trade_date', - use_strings=True) + use_strings=True) assert isinstance(arrow_table.schema, arrow.Schema) call(param_client.insert_arrow, 'test_arrow_map', arrow_table, settings={'insert_deduplication_token': '10382'}) assert 4 == call(param_client.command, 'SELECT count() FROM test_arrow_map') diff --git a/tests/integration_tests/test_pandas_compat.py b/tests/integration_tests/test_pandas_compat.py index 4650914e..5723d498 100644 --- a/tests/integration_tests/test_pandas_compat.py +++ b/tests/integration_tests/test_pandas_compat.py @@ -304,7 +304,7 @@ def test_pandas_query_df_arrow(param_client: Client, call, table_context: Callab result_df = call(param_client.query_df_arrow, f"SELECT * FROM {table_name}") -def test_pandas_insert_df_arrow(param_client: Client, call, table_context: Callable, client_mode: str): +def test_pandas_insert_df_arrow(param_client: Client, call, table_context: Callable): if not arrow: pytest.skip("PyArrow package not available") @@ -323,21 +323,8 @@ def 
test_pandas_insert_df_arrow(param_client: Client, call, table_context: Calla if IS_PANDAS_2: df = df.convert_dtypes(dtype_backend="pyarrow") call(param_client.insert_df_arrow, table_name, df) - - if client_mode == 'async': - async def get_rows(): - result = await param_client.query(f"SELECT * from {table_name} ORDER BY i64") - rows = [] - async with result.rows_stream as stream: - async for row in stream: - rows.append(row) - return rows - rows = call(get_rows) - else: - res_df = call(param_client.query, f"SELECT * from {table_name} ORDER BY i64") - rows = res_df.result_rows - - assert rows == [(51, 421, "b"), (78, None, "a")] + res_df = call(param_client.query, f"SELECT * from {table_name} ORDER BY i64") + assert res_df.result_rows == [(51, 421, "b"), (78, None, "a")] else: with pytest.raises(ProgrammingError, match="pandas 2.x"): call(param_client.insert_df_arrow, table_name, df) From 66438d81936a7f29f0f0ac26e6e1bdf7a382ba28 Mon Sep 17 00:00:00 2001 From: Joe S Date: Sun, 18 Jan 2026 21:45:53 -0800 Subject: [PATCH 29/40] accept either concurrent session error --- tests/integration_tests/test_async_features.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/tests/integration_tests/test_async_features.py b/tests/integration_tests/test_async_features.py index 8bc6bb7b..86b8ee03 100644 --- a/tests/integration_tests/test_async_features.py +++ b/tests/integration_tests/test_async_features.py @@ -5,7 +5,7 @@ import pytest from clickhouse_connect import get_async_client -from clickhouse_connect.driver.exceptions import OperationalError, ProgrammingError +from clickhouse_connect.driver.exceptions import DatabaseError, OperationalError, ProgrammingError from tests.integration_tests.conftest import make_client_config # pylint: disable=protected-access @@ -103,10 +103,17 @@ async def quick_query(): await asyncio.sleep(0.1) return await client.query("SELECT 1") - with pytest.raises(ProgrammingError) as exc_info: + # This can raise either: + # 
- ProgrammingError (client-side detection - best effort) + # - DatabaseError code 373 (server-side SESSION_IS_LOCKED - when client check is too slow) + # Both are valid ways to detect the concurrent session violation. + with pytest.raises((ProgrammingError, DatabaseError)) as exc_info: await asyncio.gather(long_query(), quick_query()) - assert "concurrent" in str(exc_info.value).lower() or "session" in str(exc_info.value).lower() + # Verify it's the right kind of error (concurrent session access) + error_msg = str(exc_info.value).lower() + assert ("concurrent" in error_msg or "session" in error_msg or "locked" in error_msg), \ + f"Expected session concurrency error, got: {exc_info.value}" @pytest.mark.asyncio From 95d30d8c8dcc774fd3082893b14d1cc22b9349e0 Mon Sep 17 00:00:00 2001 From: Joe S Date: Tue, 20 Jan 2026 09:21:08 -0800 Subject: [PATCH 30/40] update async client to always reset context --- clickhouse_connect/driver/aiohttp_client.py | 2 +- tests/integration_tests/test_contexts.py | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/clickhouse_connect/driver/aiohttp_client.py b/clickhouse_connect/driver/aiohttp_client.py index ad9ca62f..e4c02b5a 100644 --- a/clickhouse_connect/driver/aiohttp_client.py +++ b/clickhouse_connect/driver/aiohttp_client.py @@ -1460,8 +1460,8 @@ async def data_insert(self, context: InsertContext) -> QuerySummary: # type: ig raise finally: await streaming_source.close() + context.data = None - context.data = None return QuerySummary(self._summary(response)) async def insert_df( # type: ignore[override] diff --git a/tests/integration_tests/test_contexts.py b/tests/integration_tests/test_contexts.py index a3874ff4..c25332e0 100644 --- a/tests/integration_tests/test_contexts.py +++ b/tests/integration_tests/test_contexts.py @@ -29,14 +29,14 @@ def test_contexts(param_client: Client, call, table_context: Callable): call(param_client.insert, context=insert_context) assert call(param_client.command, f'SELECT count() 
FROM {ctx.table}') == 6 -def test_insert_context_data_cleared_on_failure(test_client: Client, table_context: Callable): +def test_insert_context_data_cleared_on_failure(param_client: Client, call, table_context: Callable): with table_context('test_contexts', ['key Int32', 'value1 String', 'value2 String']) as ctx: data = [[1, "v1", "v2"], [2, "v3", "v4"]] - insert_context = test_client.create_insert_context(table=ctx.table, data=data) + insert_context = call(param_client.create_insert_context, table=ctx.table, data=data) insert_context.table = f"{ctx.table}__does_not_exist" with pytest.raises(Exception): - test_client.insert(context=insert_context) + call(param_client.insert, context=insert_context) assert insert_context.data is None From f4e0010a552ae75c055f9617e2b030f63a218e78 Mon Sep 17 00:00:00 2001 From: Joe S Date: Thu, 29 Jan 2026 16:23:55 -0800 Subject: [PATCH 31/40] add pipelining to arrow methods --- clickhouse_connect/driver/aiohttp_client.py | 149 +++++++++++++++----- 1 file changed, 116 insertions(+), 33 deletions(-) diff --git a/clickhouse_connect/driver/aiohttp_client.py b/clickhouse_connect/driver/aiohttp_client.py index e4c02b5a..65683b04 100644 --- a/clickhouse_connect/driver/aiohttp_client.py +++ b/clickhouse_connect/driver/aiohttp_client.py @@ -38,8 +38,9 @@ from clickhouse_connect.driver.insert import InsertContext from clickhouse_connect.driver.models import ColumnDef, SettingDef from clickhouse_connect.driver.options import IS_PANDAS_2, arrow, check_arrow, check_numpy, check_pandas, check_polars, pd, pl -from clickhouse_connect.driver.query import QueryContext, QueryResult, arrow_buffer, to_arrow +from clickhouse_connect.driver.query import QueryContext, QueryResult, arrow_buffer from clickhouse_connect.driver.summary import QuerySummary +from clickhouse_connect.driver.asyncqueue import AsyncSyncQueue, EOF_SENTINEL from clickhouse_connect.driver.streaming import StreamingInsertSource from clickhouse_connect.driver.transform import 
NativeTransform from clickhouse_connect.driver.streaming import StreamingResponseSource, StreamingFileAdapter @@ -1044,16 +1045,33 @@ async def query_arrow( check_arrow() self._add_integration_tag("arrow") settings = self._update_arrow_settings(settings, use_strings) - return to_arrow( - await self.raw_query( - query, - parameters, - settings, - fmt="Arrow", - external_data=external_data, - transport_settings=transport_settings, - ) + + body, params, headers, files = self._prep_raw_query( + query, parameters, settings, fmt="ArrowStream", + use_database=True, external_data=external_data + ) + if transport_settings: + headers = dict_copy(headers, transport_settings) + + response = await self._raw_request( + body, params, headers=headers, files=files, + stream=True, server_wait=False, retries=self.query_retries ) + encoding = response.headers.get("Content-Encoding") + + loop = asyncio.get_running_loop() + streaming_source = StreamingResponseSource(response, encoding=encoding) + await streaming_source.start_producer(loop) + + def parse_arrow_stream(): + file_adapter = StreamingFileAdapter(streaming_source) + reader = arrow.ipc.open_stream(file_adapter) + return reader.read_all() + + try: + return await loop.run_in_executor(None, parse_arrow_stream) + finally: + await streaming_source.aclose() async def query_arrow_stream( # type: ignore[override] self, @@ -1096,26 +1114,58 @@ async def query_arrow_stream( # type: ignore[override] streaming_source = StreamingResponseSource(response, encoding=encoding) await streaming_source.start_producer(loop) + queue = AsyncSyncQueue(maxsize=10) + + class _ArrowStreamSource: + def __init__(self, source, q): + self._source = source + self._queue = q + + async def aclose(self): + self._queue.shutdown() + await self._source.aclose() + + def close(self): + self._queue.shutdown() + self._source.close() + def parse_arrow_streaming(): """Parse Arrow stream incrementally in executor (off event loop).""" - # Wrap streaming source with 
file-like adapter for PyArrow - file_adapter = StreamingFileAdapter(streaming_source) - reader = arrow.ipc.open_stream(file_adapter) + try: + file_adapter = StreamingFileAdapter(streaming_source) + reader = arrow.ipc.open_stream(file_adapter) - batches = [] - for batch in reader: - batches.append(batch) + for batch in reader: + try: + queue.sync_q.put(batch) + except RuntimeError: + return - return batches + try: + queue.sync_q.put(EOF_SENTINEL) + except RuntimeError: + return + except Exception as e: + try: + queue.sync_q.put(e) + except Exception: + pass + finally: + queue.shutdown() - batches = await loop.run_in_executor(None, parse_arrow_streaming) + loop.run_in_executor(None, parse_arrow_streaming) async def arrow_batch_generator(): """Async generator that yields record batches without blocking event loop.""" - for batch in batches: - yield batch + while True: + item = await queue.async_q.get() + if item is EOF_SENTINEL: + break + if isinstance(item, Exception): + raise item + yield item - return StreamContext(None, arrow_batch_generator()) + return StreamContext(_ArrowStreamSource(streaming_source, queue), arrow_batch_generator()) async def query_df_arrow( self, @@ -1234,27 +1284,60 @@ def converter(table: "arrow.Table") -> "pl.DataFrame": streaming_source = StreamingResponseSource(response, encoding=encoding) await streaming_source.start_producer(loop) + queue = AsyncSyncQueue(maxsize=10) + + class _ArrowDFStreamSource: + def __init__(self, source, q): + self._source = source + self._queue = q + + async def aclose(self): + self._queue.shutdown() + await self._source.aclose() + + def close(self): + self._queue.shutdown() + self._source.close() + def parse_and_convert_streaming(): """Parse Arrow stream and convert to DataFrames in executor (off event loop).""" - file_adapter = StreamingFileAdapter(streaming_source) + try: + file_adapter = StreamingFileAdapter(streaming_source) - # PyArrow reads incrementally from adapter (which pulls from queue) - reader = 
arrow.ipc.open_stream(file_adapter) + # PyArrow reads incrementally from adapter (which pulls from queue) + reader = arrow.ipc.open_stream(file_adapter) - dataframes = [] - for batch in reader: - dataframes.append(converter(batch)) + for batch in reader: + try: + queue.sync_q.put(converter(batch)) + except RuntimeError: + return - return dataframes + try: + queue.sync_q.put(EOF_SENTINEL) + except RuntimeError: + return + except Exception as e: + try: + queue.sync_q.put(e) + except Exception: + pass + finally: + queue.shutdown() - dataframes = await loop.run_in_executor(None, parse_and_convert_streaming) + loop.run_in_executor(None, parse_and_convert_streaming) async def df_generator(): """Async generator that yields DataFrames without blocking event loop.""" - for df in dataframes: - yield df - - return StreamContext(None, df_generator()) + while True: + item = await queue.async_q.get() + if item is EOF_SENTINEL: + break + if isinstance(item, Exception): + raise item + yield item + + return StreamContext(_ArrowDFStreamSource(streaming_source, queue), df_generator()) async def insert_arrow( # type: ignore[override] self, From f65502af865b90e495b96b90744dce06d3d17438 Mon Sep 17 00:00:00 2001 From: Joe S Date: Wed, 11 Feb 2026 15:19:01 -0800 Subject: [PATCH 32/40] feature parity with recent sync client changes --- clickhouse_connect/driver/aiohttp_client.py | 28 ++++++++++++++--- clickhouse_connect/driver/streaming.py | 8 ++++- .../test_mid_stream_exception.py | 31 +++++++++++++++++++ tests/unit_tests/test_streaming_source.py | 5 +-- 4 files changed, 64 insertions(+), 8 deletions(-) diff --git a/clickhouse_connect/driver/aiohttp_client.py b/clickhouse_connect/driver/aiohttp_client.py index 65683b04..41a6ce70 100644 --- a/clickhouse_connect/driver/aiohttp_client.py +++ b/clickhouse_connect/driver/aiohttp_client.py @@ -28,7 +28,7 @@ from clickhouse_connect.datatypes.registry import get_from_name from clickhouse_connect.driver import httputil, tzutil from 
clickhouse_connect.driver.binding import bind_query, quote_identifier -from clickhouse_connect.driver.client import Client +from clickhouse_connect.driver.client import Client, _strip_utc_timezone_from_arrow from clickhouse_connect.driver.common import StreamContext, coerce_bool, dict_copy from clickhouse_connect.driver.compression import available_compression from clickhouse_connect.driver.constants import CH_VERSION_WITH_PROTOCOL, PROTOCOL_VERSION_WITH_LOW_CARD @@ -48,6 +48,7 @@ logger = logging.getLogger(__name__) columns_only_re = re.compile(r"LIMIT 0\s*$", re.IGNORECASE) ex_header = "X-ClickHouse-Exception-Code" +ex_tag_header = "X-ClickHouse-Exception-Tag" if "br" in available_compression: import brotli @@ -323,6 +324,11 @@ async def _initialize(self, apply_server_timezone: Optional[Union[str, bool]] = except Exception: pass + cancel_setting = self._setting_status("cancel_http_readonly_queries_on_client_close") + if cancel_setting.is_writable and not cancel_setting.is_set and \ + "cancel_http_readonly_queries_on_client_close" not in (self._initial_settings or {}): + self._client_settings["cancel_http_readonly_queries_on_client_close"] = "1" + if self._initial_settings: for key, value in self._initial_settings.items(): self.set_client_setting(key, value) @@ -533,9 +539,10 @@ def decompress_and_parse_json(): stream=True, retries=self.query_retries) encoding = response.headers.get("Content-Encoding") tz_header = response.headers.get("X-ClickHouse-Timezone") + exception_tag = response.headers.get(ex_tag_header) loop = asyncio.get_running_loop() - streaming_source = StreamingResponseSource(response, encoding=encoding) + streaming_source = StreamingResponseSource(response, encoding=encoding, exception_tag=exception_tag) await streaming_source.start_producer(loop) def parse_streaming(): @@ -1058,9 +1065,10 @@ async def query_arrow( stream=True, server_wait=False, retries=self.query_retries ) encoding = response.headers.get("Content-Encoding") + exception_tag = 
response.headers.get(ex_tag_header) loop = asyncio.get_running_loop() - streaming_source = StreamingResponseSource(response, encoding=encoding) + streaming_source = StreamingResponseSource(response, encoding=encoding, exception_tag=exception_tag) await streaming_source.start_producer(loop) def parse_arrow_stream(): @@ -1109,9 +1117,10 @@ async def query_arrow_stream( # type: ignore[override] stream=True, server_wait=False, retries=self.query_retries ) encoding = response.headers.get("Content-Encoding") + exception_tag = response.headers.get(ex_tag_header) loop = asyncio.get_running_loop() - streaming_source = StreamingResponseSource(response, encoding=encoding) + streaming_source = StreamingResponseSource(response, encoding=encoding, exception_tag=exception_tag) await streaming_source.start_producer(loop) queue = AsyncSyncQueue(maxsize=10) @@ -1200,6 +1209,8 @@ async def query_df_arrow( raise ProgrammingError("PyArrow-backed dtypes are only supported when using pandas 2.x.") def converter(table: "arrow.Table") -> "pd.DataFrame": + if not self.utc_tz_aware: + table = _strip_utc_timezone_from_arrow(table) return table.to_pandas(types_mapper=pd.ArrowDtype, safe=False) elif dataframe_library == "polars": @@ -1207,6 +1218,8 @@ def converter(table: "arrow.Table") -> "pd.DataFrame": self._add_integration_tag("polars") def converter(table: "arrow.Table") -> "pl.DataFrame": + if not self.utc_tz_aware: + table = _strip_utc_timezone_from_arrow(table) return pl.from_arrow(table) else: @@ -1254,6 +1267,8 @@ async def query_df_arrow_stream( # type: ignore[override] raise ProgrammingError("PyArrow-backed dtypes are only supported when using pandas 2.x.") def converter(table: "arrow.Table") -> "pd.DataFrame": + if not self.utc_tz_aware: + table = _strip_utc_timezone_from_arrow(table) return table.to_pandas(types_mapper=pd.ArrowDtype, safe=False) elif dataframe_library == "polars": @@ -1261,6 +1276,8 @@ def converter(table: "arrow.Table") -> "pd.DataFrame": 
self._add_integration_tag("polars") def converter(table: "arrow.Table") -> "pl.DataFrame": + if not self.utc_tz_aware: + table = _strip_utc_timezone_from_arrow(table) return pl.from_arrow(table) else: @@ -1279,9 +1296,10 @@ def converter(table: "arrow.Table") -> "pl.DataFrame": stream=True, server_wait=False, retries=self.query_retries ) encoding = response.headers.get("Content-Encoding") + exception_tag = response.headers.get(ex_tag_header) loop = asyncio.get_running_loop() - streaming_source = StreamingResponseSource(response, encoding=encoding) + streaming_source = StreamingResponseSource(response, encoding=encoding, exception_tag=exception_tag) await streaming_source.start_producer(loop) queue = AsyncSyncQueue(maxsize=10) diff --git a/clickhouse_connect/driver/streaming.py b/clickhouse_connect/driver/streaming.py index f21f6a5e..476083e7 100644 --- a/clickhouse_connect/driver/streaming.py +++ b/clickhouse_connect/driver/streaming.py @@ -11,6 +11,7 @@ from clickhouse_connect.driver.asyncqueue import EOF_SENTINEL, AsyncSyncQueue from clickhouse_connect.driver.compression import available_compression +from clickhouse_connect.driver.exceptions import OperationalError from clickhouse_connect.driver.types import Closable logger = logging.getLogger(__name__) @@ -29,9 +30,10 @@ class StreamingResponseSource(Closable): READ_BUFFER_SIZE = 1024 * 1024 - def __init__(self, response, encoding: Optional[str] = None): + def __init__(self, response, encoding: Optional[str] = None, exception_tag: Optional[str] = None): self.response = response self.encoding = encoding + self.exception_tag = exception_tag # maxsize=10 means max ~10 socket reads buffered self.queue = AsyncSyncQueue(maxsize=10) @@ -54,11 +56,13 @@ async def start_producer(self, loop: asyncio.AbstractEventLoop): async def producer(): """Async producer: reads chunks from response, feeds queue.""" + data_sent = False try: while True: chunk = await self.response.content.read(self.READ_BUFFER_SIZE) if not chunk: break 
+ data_sent = True await self.queue.async_q.put(chunk) await self.queue.async_q.put(EOF_SENTINEL) @@ -66,6 +70,8 @@ async def producer(): except Exception as e: logger.error("Producer error while streaming response: %s", e, exc_info=True) + if not data_sent: + e = OperationalError("Failed to read response data from server") self._producer_error = e try: diff --git a/tests/integration_tests/test_mid_stream_exception.py b/tests/integration_tests/test_mid_stream_exception.py index 174f9d3a..848716bf 100644 --- a/tests/integration_tests/test_mid_stream_exception.py +++ b/tests/integration_tests/test_mid_stream_exception.py @@ -29,3 +29,34 @@ def test_mid_stream_exception_streaming(test_client: Client): error_msg = str(exc_info.value) assert "Value passed to 'throwIf' function is non-zero" in error_msg assert test_client.command("SELECT 1") == 1 + + +@pytest.mark.asyncio +async def test_mid_stream_exception_async(test_native_async_client): + """Test that mid-stream exceptions are properly detected and raised (async).""" + query = "SELECT sleepEachRow(0.01), throwIf(number=100) FROM numbers(200)" + + with pytest.raises(StreamFailureError) as exc_info: + result = await test_native_async_client.query(query, settings={"max_block_size": 1, "wait_end_of_query": 0}) + _ = result.result_set + + error_msg = str(exc_info.value) + assert "Value passed to 'throwIf' function is non-zero" in error_msg + assert await test_native_async_client.command("SELECT 1") == 1 + + +@pytest.mark.asyncio +async def test_mid_stream_exception_streaming_async(test_native_async_client): + """Test that mid-stream exceptions are properly detected in streaming mode (async).""" + query = "SELECT sleepEachRow(0.01), throwIf(number=100) FROM numbers(200)" + + with pytest.raises(StreamFailureError) as exc_info: + async with await test_native_async_client.query_rows_stream( + query, settings={"max_block_size": 1, "wait_end_of_query": 0} + ) as stream: + async for _ in stream: + pass + + error_msg = 
str(exc_info.value) + assert "Value passed to 'throwIf' function is non-zero" in error_msg + assert await test_native_async_client.command("SELECT 1") == 1 diff --git a/tests/unit_tests/test_streaming_source.py b/tests/unit_tests/test_streaming_source.py index 6794a0a2..06c62ec5 100644 --- a/tests/unit_tests/test_streaming_source.py +++ b/tests/unit_tests/test_streaming_source.py @@ -8,6 +8,7 @@ import pytest import zstandard +from clickhouse_connect.driver.exceptions import OperationalError from clickhouse_connect.driver.streaming import ( StreamingInsertSource, StreamingResponseSource, @@ -288,13 +289,13 @@ def consume(): try: for _ in source.gen: pass - except ValueError as e: + except OperationalError as e: return str(e) return "No error raised!" error_msg = await loop.run_in_executor(None, consume) - assert error_msg == "Producer error!" + assert error_msg == "Failed to read response data from server" @pytest.mark.asyncio From 46bc5975c98d7df2b9e25f6fec90ab25d45af44c Mon Sep 17 00:00:00 2001 From: Joe S Date: Wed, 11 Feb 2026 15:32:59 -0800 Subject: [PATCH 33/40] fix jwt test after bad merge conflict resolution --- tests/integration_tests/test_jwt_auth.py | 118 +++++++++++++++++------ 1 file changed, 91 insertions(+), 27 deletions(-) diff --git a/tests/integration_tests/test_jwt_auth.py b/tests/integration_tests/test_jwt_auth.py index 6bd01782..9e2e69ad 100644 --- a/tests/integration_tests/test_jwt_auth.py +++ b/tests/integration_tests/test_jwt_auth.py @@ -1,24 +1,30 @@ +from os import environ + import pytest -from clickhouse_connect.driver import ProgrammingError +from clickhouse_connect.driver import create_client, ProgrammingError, create_async_client from tests.integration_tests.conftest import TestConfig -pytest.skip("JWT tests are not yet configured", allow_module_level=True) +pytest.skip('JWT tests are not yet configured', allow_module_level=True) def test_jwt_auth_sync_client(test_config: TestConfig): if not test_config.cloud: - pytest.skip("Skipping 
JWT test in non-Cloud mode") + pytest.skip('Skipping JWT test in non-Cloud mode') access_token = make_access_token() - client = create_client(host=test_config.host, port=test_config.port, access_token=access_token) + client = create_client( + host=test_config.host, + port=test_config.port, + access_token=access_token + ) result = client.query(query=CHECK_CLOUD_MODE_QUERY).result_set assert result == [(True,)] def test_jwt_auth_sync_client_set_access_token(test_config: TestConfig): if not test_config.cloud: - pytest.skip("Skipping JWT test in non-Cloud mode") + pytest.skip('Skipping JWT test in non-Cloud mode') access_token = make_access_token() client = create_client( @@ -37,60 +43,118 @@ def test_jwt_auth_sync_client_set_access_token(test_config: TestConfig): def test_jwt_auth_sync_client_config_errors(): with pytest.raises(ProgrammingError): - create_client(username="bob", access_token="foobar") + create_client( + username='bob', + access_token='foobar' + ) with pytest.raises(ProgrammingError): - create_client(username="bob", password="secret", access_token="foo") + create_client( + username='bob', + password='secret', + access_token='foo' + ) with pytest.raises(ProgrammingError): - create_client(password="secret", access_token="foo") + create_client( + password='secret', + access_token='foo' + ) def test_jwt_auth_sync_client_set_access_token_errors(test_config: TestConfig): if not test_config.cloud: - pytest.skip("Skipping JWT test in non-Cloud mode") + pytest.skip('Skipping JWT test in non-Cloud mode') + + client = create_client( + host=test_config.host, + port=test_config.port, + username=test_config.username, + password=test_config.password, + ) + + # Can't use JWT with username/password + access_token = make_access_token() + with pytest.raises(ProgrammingError): + client.set_access_token(access_token) + + +@pytest.mark.asyncio +async def test_jwt_auth_async_client(test_config: TestConfig): + if not test_config.cloud: + pytest.skip('Skipping JWT test in 
non-Cloud mode') access_token = make_access_token() - client = client_factory(username=None, password="", access_token=access_token) - result = call(client.query, CHECK_CLOUD_MODE_QUERY).result_set + client = await create_async_client( + host=test_config.host, + port=test_config.port, + access_token=access_token + ) + result = (await client.query(query=CHECK_CLOUD_MODE_QUERY)).result_set assert result == [(True,)] -def test_jwt_auth_client_set_access_token(test_config: TestConfig, client_factory, call): - """Test setting JWT access token dynamically with both sync and async clients.""" +@pytest.mark.asyncio +async def test_jwt_auth_async_client_set_access_token(test_config: TestConfig): if not test_config.cloud: - pytest.skip("Skipping JWT test in non-Cloud mode") + pytest.skip('Skipping JWT test in non-Cloud mode') access_token = make_access_token() - client = client_factory(username=None, password="", access_token=access_token) + client = await create_async_client( + host=test_config.host, + port=test_config.port, + access_token=access_token, + ) access_token = make_access_token() client.set_access_token(access_token) - result = call(client.query, CHECK_CLOUD_MODE_QUERY).result_set + result = (await client.query(query=CHECK_CLOUD_MODE_QUERY)).result_set assert result == [(True,)] -def test_jwt_auth_client_config_errors(client_factory): - """Test JWT configuration validation catches invalid combinations.""" +@pytest.mark.asyncio +async def test_jwt_auth_async_client_config_errors(): with pytest.raises(ProgrammingError): - client_factory(username="bob", access_token="foobar") - + await create_async_client( + username='bob', + access_token='foobar' + ) with pytest.raises(ProgrammingError): - client_factory(username="bob", password="secret", access_token="foo") - + await create_async_client( + username='bob', + password='secret', + access_token='foo' + ) with pytest.raises(ProgrammingError): - client_factory(password="secret", access_token="foo") + await 
create_async_client( + password='secret', + access_token='foo' + ) -def test_jwt_auth_client_set_access_token_errors(test_config: TestConfig, client_factory): - """Test that JWT cannot be set when using username/password authentication.""" +@pytest.mark.asyncio +async def test_jwt_auth_async_client_set_access_token_errors(test_config: TestConfig): if not test_config.cloud: - pytest.skip("Skipping JWT test in non-Cloud mode") + pytest.skip('Skipping JWT test in non-Cloud mode') - client = client_factory( + client = await create_async_client( + host=test_config.host, + port=test_config.port, username=test_config.username, password=test_config.password, ) + # Can't use JWT with username/password access_token = make_access_token() with pytest.raises(ProgrammingError): client.set_access_token(access_token) + + +CHECK_CLOUD_MODE_QUERY = "SELECT value='1' FROM system.settings WHERE name='cloud_mode'" +JWT_SECRET_ENV_KEY = 'CLICKHOUSE_CONNECT_TEST_JWT_SECRET' + + +def make_access_token(): + secret = environ.get(JWT_SECRET_ENV_KEY) + if not secret: + raise ValueError(f'{JWT_SECRET_ENV_KEY} environment variable is not set') + return secret From 6af89ee270f8ad4a5f3103489199102ddf6836de Mon Sep 17 00:00:00 2001 From: Joe S Date: Thu, 12 Feb 2026 09:46:40 -0800 Subject: [PATCH 34/40] 0.12.0rc1 release prep --- CHANGELOG.md | 2 ++ clickhouse_connect/__version__.py | 2 +- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 44667b63..4df443b9 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -22,6 +22,8 @@ instead of being passed as ClickHouse server settings. This is in conjunction wi The supported method of passing ClickHouse server settings is to prefix such arguments/query parameters with`ch_`. ## UNRELEASED + +## 0.12.0rc1, 2026-02-11 - Implement a native async client. 
Closes [#141](https://github.com/ClickHouse/clickhouse-connect/issues/141) ## 0.11.0, 2026-02-10 diff --git a/clickhouse_connect/__version__.py b/clickhouse_connect/__version__.py index ea402e72..9d35b3ed 100644 --- a/clickhouse_connect/__version__.py +++ b/clickhouse_connect/__version__.py @@ -1 +1 @@ -version = "0.11.0" +version = "0.12.0rc1" From 7a62f64761d0a1aa38742aa3d189ff32fde6176e Mon Sep 17 00:00:00 2001 From: Joe S Date: Thu, 12 Feb 2026 10:19:04 -0800 Subject: [PATCH 35/40] fix assertion in racy test --- tests/integration_tests/test_streaming.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/tests/integration_tests/test_streaming.py b/tests/integration_tests/test_streaming.py index 654b83d4..a6f10f15 100644 --- a/tests/integration_tests/test_streaming.py +++ b/tests/integration_tests/test_streaming.py @@ -120,15 +120,11 @@ async def test_stream_failure_async(test_native_async_client): stream = await test_native_async_client.query_row_block_stream(query) - with pytest.raises(StreamFailureError) as excinfo: + with pytest.raises(StreamFailureError): async with stream: async for _ in stream: pass - error_msg = str(excinfo.value).lower() - # Race condition: may get actual ClickHouse error or generic connection closed - assert 'division by zero' in error_msg or 'connection closed' in error_msg - def test_raw_stream(param_client, call, consume_stream): """Test raw_stream for streaming response.""" From cb0346646df8f685285a0cf5cd7e245e1dd7f209 Mon Sep 17 00:00:00 2001 From: Joe S Date: Thu, 12 Feb 2026 10:29:02 -0800 Subject: [PATCH 36/40] clean up streaming source on exception --- clickhouse_connect/driver/aiohttp_client.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/clickhouse_connect/driver/aiohttp_client.py b/clickhouse_connect/driver/aiohttp_client.py index 41a6ce70..196a3212 100644 --- a/clickhouse_connect/driver/aiohttp_client.py +++ b/clickhouse_connect/driver/aiohttp_client.py @@ -568,7 +568,11 @@ def 
parse_streaming(): return result # Run parser in executor (pulls from queue, decompresses & parses) - query_result = await loop.run_in_executor(None, parse_streaming) + try: + query_result = await loop.run_in_executor(None, parse_streaming) + except Exception: + await streaming_source.aclose() + raise query_result.summary = self._summary(response) # Attach streaming_source to query_result.source to ensure it gets closed From 4a2e74d1a3046a37aac2a94ab79d12d74109b8ec Mon Sep 17 00:00:00 2001 From: Joe S Date: Wed, 25 Mar 2026 14:50:01 -0700 Subject: [PATCH 37/40] linting --- clickhouse_connect/driver/aiohttp_client.py | 1 + 1 file changed, 1 insertion(+) diff --git a/clickhouse_connect/driver/aiohttp_client.py b/clickhouse_connect/driver/aiohttp_client.py index 288924f4..77ef2348 100644 --- a/clickhouse_connect/driver/aiohttp_client.py +++ b/clickhouse_connect/driver/aiohttp_client.py @@ -272,6 +272,7 @@ def __init__( autoconnect=False ) + # pylint: disable=attribute-defined-outside-init async def _initialize(self): """ Async equivalent of Client._init_common_settings. 
From dd983aacc057f39796febd8d7bb7804a811a494d Mon Sep 17 00:00:00 2001 From: Joe S Date: Wed, 25 Mar 2026 14:59:44 -0700 Subject: [PATCH 38/40] more linting --- tests/integration_tests/test_timezones.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/integration_tests/test_timezones.py b/tests/integration_tests/test_timezones.py index 3e409961..a4cfd2f7 100644 --- a/tests/integration_tests/test_timezones.py +++ b/tests/integration_tests/test_timezones.py @@ -199,7 +199,7 @@ def test_tz_mode(param_client: Client, call): assert row[0].microsecond == 123456 -def test_apply_server_timezone_setter_deprecated(param_client: Client, call): +def test_apply_server_timezone_setter_deprecated(param_client: Client): """Setting client.apply_server_timezone should emit a DeprecationWarning and update state.""" try: with warnings.catch_warnings(record=True) as w: @@ -224,7 +224,7 @@ def test_apply_server_timezone_setter_deprecated(param_client: Client, call): param_client.tz_source = "auto" -def test_apply_server_timezone_getter_deprecated(param_client: Client, call): +def test_apply_server_timezone_getter_deprecated(param_client: Client): """Reading client.apply_server_timezone should emit a DeprecationWarning.""" with warnings.catch_warnings(record=True) as w: warnings.simplefilter("always") @@ -233,13 +233,13 @@ def test_apply_server_timezone_getter_deprecated(param_client: Client, call): assert issubclass(w[0].category, DeprecationWarning) -def test_tz_source_setter_validates(param_client: Client, call): +def test_tz_source_setter_validates(param_client: Client): """Setting client.tz_source to an invalid value should raise ProgrammingError.""" with pytest.raises(ProgrammingError, match='tz_source must be'): param_client.tz_source = "serer" -def test_tz_source_setter_auto_restores_dst_safe(param_client: Client, call): +def test_tz_source_setter_auto_restores_dst_safe(param_client: Client): """Setting tz_source back to 'auto' should re-resolve 
based on server DST safety.""" original = param_client._apply_server_tz try: From 6885950ef3c6d821636b88514c11995f249e94e1 Mon Sep 17 00:00:00 2001 From: Joe S Date: Thu, 26 Mar 2026 15:25:19 -0700 Subject: [PATCH 39/40] deprecate legacy async client and replace with async native --- README.md | 10 +- clickhouse_connect/driver/__init__.py | 94 +- clickhouse_connect/driver/aiohttp_client.py | 1878 ------------- clickhouse_connect/driver/asyncclient.py | 2520 ++++++++++++------ examples/run_async.py | 58 +- tests/integration_tests/conftest.py | 12 +- tests/integration_tests/test_async_client.py | 302 --- 7 files changed, 1837 insertions(+), 3037 deletions(-) delete mode 100644 clickhouse_connect/driver/aiohttp_client.py delete mode 100644 tests/integration_tests/test_async_client.py diff --git a/README.md b/README.md index a5d77813..175b42d1 100644 --- a/README.md +++ b/README.md @@ -53,19 +53,15 @@ are not implemented. The dialect is best suited for SQLAlchemy Core usage and Su ### Asyncio Support -ClickHouse Connect provides native async support using aiohttp. For the best performance with async applications, +ClickHouse Connect provides native async support using aiohttp. To use the async client, install the optional async dependency: ``` pip install clickhouse-connect[async] ``` -See the [run_async example](./examples/run_async.py) for more details. - -The current `AsyncClient` is a thread-pool executor wrapper around the synchronous client and is deprecated. -In 1.0.0 it will be replaced by a fully native async implementation. The API surface is the same, -with one difference: you will no longer be able to create a sync client first and pass it to the -`AsyncClient` constructor. Instead, use `clickhouse_connect.get_async_client()` directly. +Then create a client with `clickhouse_connect.get_async_client()`. See the +[run_async example](./examples/run_async.py) for more details. 
### Complete Documentation diff --git a/clickhouse_connect/driver/__init__.py b/clickhouse_connect/driver/__init__.py index e72d8055..79caa705 100644 --- a/clickhouse_connect/driver/__init__.py +++ b/clickhouse_connect/driver/__init__.py @@ -1,18 +1,34 @@ -import asyncio -from concurrent.futures import ThreadPoolExecutor +from __future__ import annotations + from inspect import signature -from typing import Optional, Union, Dict, Any, Tuple +from typing import Optional, Union, Dict, Any, Tuple, TYPE_CHECKING from urllib.parse import urlparse, parse_qs import clickhouse_connect.driver.ctypes # noqa: F401 -- side-effect import from clickhouse_connect.driver.client import Client from clickhouse_connect.driver.exceptions import ProgrammingError from clickhouse_connect.driver.httpclient import HttpClient -from clickhouse_connect.driver.asyncclient import AsyncClient, DefaultThreadPoolExecutor, NEW_THREAD_POOL_EXECUTOR + +if TYPE_CHECKING: + from clickhouse_connect.driver.asyncclient import AsyncClient __all__ = ['Client', 'AsyncClient', 'create_client', 'create_async_client'] +def __getattr__(name): + if name == "AsyncClient": + try: + from clickhouse_connect.driver.asyncclient import AsyncClient # pylint: disable=import-outside-toplevel + except ModuleNotFoundError as ex: + if ex.name == "aiohttp" or (ex.name and ex.name.startswith("aiohttp.")): # pylint: disable=no-member + raise ImportError( + "Async support requires aiohttp. 
Install with: pip install clickhouse-connect[async]" + ) from ex + raise + return AsyncClient + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") + + def default_port(interface: str, secure: bool) -> int: """Get default port for the given interface.""" if interface.startswith("http"): @@ -182,14 +198,14 @@ async def create_async_client(*, dsn: Optional[str] = None, settings: Optional[Dict[str, Any]] = None, generic_args: Optional[Dict[str, Any]] = None, - executor_threads: int = 0, - executor: Union[ThreadPoolExecutor, None, DefaultThreadPoolExecutor] = NEW_THREAD_POOL_EXECUTOR, connector_limit: int = 100, connector_limit_per_host: int = 20, keepalive_timeout: float = 30.0, **kwargs) -> AsyncClient: """ The preferred method to get an async ClickHouse Connect Client instance. + Requires the async extra: pip install clickhouse-connect[async] + For sync version, see create_client. Unlike sync version, the 'autogenerate_session_id' setting by default is False. @@ -209,12 +225,10 @@ async def create_async_client(*, :param settings: ClickHouse server settings to be used with the session/every request :param generic_args: Used internally to parse DBAPI connection strings into keyword arguments and ClickHouse settings. It is not recommended to use this parameter externally - :param executor_threads: (LEGACY) 'max_worker' threads used by the client ThreadPoolExecutor. - :param executor: (LEGACY) Optional `ThreadPoolExecutor` to use for async operations. 
- :param connector_limit: Maximum number of allowable connections to the server (native async) - :param connector_limit_per_host: Maximum number of connections per host (native async) - :param keepalive_timeout: Time limit on idle keepalive connections (native async) - :param kwargs -- Recognized keyword arguments (used by the HTTP client), see below + :param connector_limit: Maximum number of allowable connections to the server + :param connector_limit_per_host: Maximum number of connections per host + :param keepalive_timeout: Time limit on idle keepalive connections + :param kwargs -- Recognized keyword arguments (used by the async HTTP client), see below :param compress: Enable compression for ClickHouse HTTP inserts and query results. True will select the preferred compression method (lz4). A str of 'lz4', 'zstd', 'brotli', or 'gzip' can be used to use a specific compression type @@ -223,7 +237,6 @@ async def create_async_client(*, :param send_receive_timeout: Read timeout in seconds for http connection :param client_name: client_name prepended to the HTTP User Agent header. Set this to track client queries in the ClickHouse system.query_log. - :param send_progress: Deprecated, has no effect. Previous functionality is now automatically determined :param verify: Verify the server certificate in secure/https mode :param ca_cert: If verify is True, the file path to Certificate Authority root to validate ClickHouse server certificate, in .pem format. Ignored if verify is False. This is not necessary if the ClickHouse server @@ -235,8 +248,6 @@ async def create_async_client(*, is not included the Client Certificate key file :param session_id ClickHouse session id. If not specified and the common setting 'autogenerate_session_id' is True, the client will generate a UUID1 session id - :param pool_mgr Optional urllib3 PoolManager for this client. 
Useful for creating separate connection - pools for multiple client endpoints for applications with many clients :param http_proxy http proxy address. Equivalent to setting the HTTP_PROXY environment variable :param https_proxy https proxy address. Equivalent to setting the HTTPS_PROXY environment variable :param server_host_name This is the server host name that will be checked against a TLS certificate for @@ -257,32 +268,47 @@ async def create_async_client(*, limits. Only available for query operations (not inserts). Default: False :return: ClickHouse Connect AsyncClient instance """ + try: + from clickhouse_connect.driver.asyncclient import AsyncClient as _AsyncClient # pylint: disable=import-outside-toplevel + except ModuleNotFoundError as ex: + if ex.name == "aiohttp" or (ex.name and ex.name.startswith("aiohttp.")): # pylint: disable=no-member + raise ImportError( + "Async support requires aiohttp. Install with: pip install clickhouse-connect[async]" + ) from ex + raise + + if "pool_mgr" in kwargs: + raise ProgrammingError( + "pool_mgr is not supported by the async client. " + "Use connector_limit and connector_limit_per_host to configure connection pooling." 
+ ) + host, username, password, port, database, interface = _parse_connection_params( host, username, password, port, database, interface, secure, dsn, kwargs ) _validate_access_token(access_token, username, password) - if executor_threads != 0 or executor is not NEW_THREAD_POOL_EXECUTOR: - # LEGACY PATH: User explicitly requested executor-based client - def _create_client(): - if 'autogenerate_session_id' not in kwargs: - kwargs['autogenerate_session_id'] = False - return create_client(host=host, username=username, password=password, database=database, interface=interface, - port=port, secure=secure, dsn=None, settings=settings, generic_args=generic_args, **kwargs) - - loop = asyncio.get_running_loop() - _client = await loop.run_in_executor(None, _create_client) - return AsyncClient(client=_client, executor_threads=executor_threads, executor=executor) + settings = settings or {} + if generic_args: + client_params = signature(_AsyncClient).parameters + for name, value in generic_args.items(): + if name in client_params: + kwargs[name] = value + elif name == "compression": + if "compress" not in kwargs: + kwargs["compress"] = value + else: + if name.startswith("ch_"): + name = name[3:] + settings[name] = value - # NATIVE PATH: Default to true async client - # Set autogenerate_session_id to False by default if "autogenerate_session_id" not in kwargs: kwargs["autogenerate_session_id"] = False - client = AsyncClient(host=host, username=username, password=password, access_token=access_token, - database=database, interface=interface, - port=port, secure=secure, dsn=None, settings=settings, generic_args=generic_args, - connector_limit=connector_limit, connector_limit_per_host=connector_limit_per_host, - keepalive_timeout=keepalive_timeout, **kwargs) - await client._initialize() # pylint: disable=protected-access + client = _AsyncClient(interface=interface, host=host, port=port, username=username, password=password, + database=database, access_token=access_token, + 
settings=settings, + connector_limit=connector_limit, connector_limit_per_host=connector_limit_per_host, + keepalive_timeout=keepalive_timeout, **kwargs) + await client._initialize() # pylint: disable=protected-access return client diff --git a/clickhouse_connect/driver/aiohttp_client.py b/clickhouse_connect/driver/aiohttp_client.py deleted file mode 100644 index 77ef2348..00000000 --- a/clickhouse_connect/driver/aiohttp_client.py +++ /dev/null @@ -1,1878 +0,0 @@ -# pylint: disable=too-many-lines,duplicate-code,import-error - -import asyncio -import gzip -import io -import json -import logging -import re -import ssl -import sys -import time -import uuid -import pytz -import zlib -from base64 import b64encode -from datetime import tzinfo -from importlib import import_module -from importlib.metadata import version as dist_version -from typing import Any, BinaryIO, Dict, Generator, Iterable, List, Literal, Optional, Sequence, Union - -import aiohttp -import lz4.frame -import zstandard - -from clickhouse_connect import common -from clickhouse_connect.datatypes import dynamic as dynamic_module -from clickhouse_connect.datatypes.base import ClickHouseType -from clickhouse_connect.datatypes.registry import get_from_name -from clickhouse_connect.driver import httputil, tzutil -from clickhouse_connect.driver.binding import bind_query, quote_identifier -from clickhouse_connect.driver.client import Client, _apply_arrow_tz_policy -from clickhouse_connect.driver.common import StreamContext, coerce_bool, dict_copy -from clickhouse_connect.driver.compression import available_compression -from clickhouse_connect.driver.constants import CH_VERSION_WITH_PROTOCOL, PROTOCOL_VERSION_WITH_LOW_CARD -from clickhouse_connect.driver.ctypes import RespBuffCls -from clickhouse_connect.driver.exceptions import DatabaseError, DataError, OperationalError, ProgrammingError -from clickhouse_connect.driver.external import ExternalData -from clickhouse_connect.driver.insert import InsertContext 
-from clickhouse_connect.driver.models import ColumnDef, SettingDef -from clickhouse_connect.driver.options import IS_PANDAS_2, arrow, check_arrow, check_numpy, check_pandas, check_polars, pd, pl -from clickhouse_connect.driver.query import QueryContext, QueryResult, TzMode, TzSource, arrow_buffer -from clickhouse_connect.driver.summary import QuerySummary -from clickhouse_connect.driver.asyncqueue import AsyncSyncQueue, EOF_SENTINEL -from clickhouse_connect.driver.streaming import StreamingInsertSource -from clickhouse_connect.driver.transform import NativeTransform -from clickhouse_connect.driver.streaming import StreamingResponseSource, StreamingFileAdapter - -logger = logging.getLogger(__name__) -columns_only_re = re.compile(r"LIMIT 0\s*$", re.IGNORECASE) -ex_header = "X-ClickHouse-Exception-Code" -ex_tag_header = "X-ClickHouse-Exception-Tag" - -if "br" in available_compression: - import brotli -else: - brotli = None - -def decompress_response(data: bytes, encoding: Optional[str]) -> bytes: - """Decompress response data based on Content-Encoding header.""" - - if not encoding or encoding == "identity": - return data - - if encoding == "lz4": - lz4_decom = lz4.frame.LZ4FrameDecompressor() - return lz4_decom.decompress(data, len(data)) - if encoding == "zstd": - zstd_decom = zstandard.ZstdDecompressor() - return zstd_decom.stream_reader(io.BytesIO(data)).read() - if encoding == "br": - if brotli is not None: - return brotli.decompress(data) - raise OperationalError("Brotli compression requested but not installed.") - if encoding == "gzip": - return gzip.decompress(data) - if encoding == "deflate": - return zlib.decompress(data) - raise OperationalError(f"Unsupported compression type: '{encoding}'. 
Supported compression: {', '.join(available_compression)}") - - -class BytesSource: - """Wrapper to make bytes compatible with ResponseBuffer expectations.""" - - def __init__(self, data: bytes): - self.data = data - self.gen = self._make_generator() - - def _make_generator(self): - yield self.data - - def close(self): - """No-op close method for compatibility.""" - -# pylint: disable=invalid-overridden-method, too-many-instance-attributes, too-many-public-methods, broad-exception-caught -class AiohttpAsyncClient(Client): - valid_transport_settings = {"database", "buffer_size", "session_id", - "compress", "decompress", "session_timeout", - "session_check", "query_id", "quota_key", - "wait_end_of_query", "client_protocol_version", - "role"} - optional_transport_settings = {"send_progress_in_http_headers", - "http_headers_progress_interval_ms", - "enable_http_compression"} - - # pylint: disable=too-many-arguments, too-many-positional-arguments, too-many-locals, too-many-branches, too-many-statements - def __init__( - self, - interface: str, - host: str, - port: int, - username: Optional[str] = None, - password: Optional[str] = None, - database: Optional[str] = None, - access_token: Optional[str] = None, - compress: Union[bool, str] = True, - connect_timeout: int = 10, - send_receive_timeout: int = 300, - client_name: Optional[str] = None, - verify: Union[bool, str] = True, - ca_cert: Optional[str] = None, - client_cert: Optional[str] = None, - client_cert_key: Optional[str] = None, - http_proxy: Optional[str] = None, - https_proxy: Optional[str] = None, - server_host_name: Optional[str] = None, - tls_mode: Optional[str] = None, - proxy_path: str = "", - connector_limit: int = 100, - connector_limit_per_host: int = 20, - keepalive_timeout: float = 30.0, - session_id: Optional[str] = None, - settings: Optional[Dict[str, Any]] = None, - query_limit: int = 0, - query_retries: int = 2, - tz_source: Optional[TzSource] = None, - tz_mode: Optional[TzMode] = None, - 
apply_server_timezone: Optional[Union[str, bool]] = None, - utc_tz_aware: Optional[Union[bool, Literal["schema"]]] = None, - show_clickhouse_errors: Optional[bool] = None, - autogenerate_session_id: Optional[bool] = None, - autogenerate_query_id: Optional[bool] = None, - form_encode_query_params: bool = False, - **kwargs, - ): - """ - Async HTTP Client using aiohttp. Initialization is handled via _initialize(). - """ - proxy_path = proxy_path.lstrip("/") - if proxy_path: - proxy_path = "/" + proxy_path - self.uri = f"{interface}://{host}:{port}{proxy_path}" - self.url = self.uri - self.form_encode_query_params = form_encode_query_params - self._rename_response_column = kwargs.get("rename_response_column") - self._initial_settings = settings - self.headers = {} - - if interface == "https": - if isinstance(verify, str) and verify.lower() == "proxy": - verify = True - tls_mode = tls_mode or "proxy" - - # Priority: access_token > mutual TLS > basic auth - if client_cert and (tls_mode is None or tls_mode == "mutual"): - if not username: - raise ProgrammingError("username parameter is required for Mutual TLS authentication") - self.headers["X-ClickHouse-User"] = username - self.headers["X-ClickHouse-SSL-Certificate-Auth"] = "on" - elif access_token: - self.headers["Authorization"] = f"Bearer {access_token}" - elif username and (not client_cert or tls_mode in ("strict", "proxy")): - credentials = b64encode(f"{username}:{password}".encode()).decode() - self.headers["Authorization"] = f"Basic {credentials}" - - self.headers["User-Agent"] = common.build_client_name(client_name) - # Prevent aiohttp from automatically requesting compressed responses - # We'll manually set Accept-Encoding when compression is desired - self.headers["Accept-Encoding"] = "identity" - self._send_receive_timeout = send_receive_timeout - - connect_timeout_val = float(connect_timeout) if connect_timeout is not None else None - send_receive_timeout_val = float(send_receive_timeout) if 
send_receive_timeout is not None else None - - self._timeout = aiohttp.ClientTimeout( - total=None, - connect=connect_timeout_val, - sock_connect=connect_timeout_val, - sock_read=send_receive_timeout_val, - ) - connector_limit_per_host = min(connector_limit_per_host, connector_limit) - - proxy_url = None - if http_proxy: - if not http_proxy.startswith("http://") and not http_proxy.startswith("https://"): - proxy_url = f"http://{http_proxy}" - else: - proxy_url = http_proxy - elif https_proxy: - if not https_proxy.startswith("http://") and not https_proxy.startswith("https://"): - proxy_url = f"http://{https_proxy}" - else: - proxy_url = https_proxy - else: - scheme = "https" if self.url.startswith("https://") else "http" - env_proxy = httputil.check_env_proxy(scheme, host, port) - if env_proxy: - if not env_proxy.startswith("http://") and not env_proxy.startswith("https://"): - proxy_url = f"http://{env_proxy}" - else: - proxy_url = env_proxy - - ssl_context = None - if interface == "https": - ssl_context = ssl.create_default_context() - ssl_verify = verify if isinstance(verify, bool) else coerce_bool(verify) - if not ssl_verify: - ssl_context.check_hostname = False - ssl_context.verify_mode = ssl.CERT_NONE - elif ca_cert: - ssl_context.load_verify_locations(ca_cert) - if client_cert: - ssl_context.load_cert_chain(client_cert, client_cert_key) - - self._ssl_context = ssl_context - self._proxy_url = proxy_url - self._connector_kwargs = { - "limit": connector_limit, - "limit_per_host": connector_limit_per_host, - "keepalive_timeout": keepalive_timeout, - "force_close": False, - "ssl": ssl_context, - } - # enable_cleanup_closed is only needed for Python < 3.12.7 or == 3.13.0 - # The underlying SSL connection leak was fixed in 3.12.7 and 3.13.1+ - # https://github.com/python/cpython/pull/118960 - if sys.version_info < (3, 12, 7) or sys.version_info[:3] == (3, 13, 0): - self._connector_kwargs["enable_cleanup_closed"] = True - - self._session = None - self._read_format = 
"Native" - self._write_format = "Native" - self._transform = NativeTransform() - self._client_settings = {} - self._initialized = False - self._reported_libs = set() - self._last_pool_reset = None - self.headers["User-Agent"] = self.headers["User-Agent"].replace("mode:sync;", "mode:async;") - - # Store aiohttp-specific params for deferred initialization - self._compress_param = compress - self._session_id_param = session_id - self._autogenerate_session_id_param = autogenerate_session_id - self._autogenerate_query_id = ( - common.get_setting("autogenerate_query_id") if autogenerate_query_id is None else autogenerate_query_id - ) - self._active_session = None - self._send_progress = None - self._progress_interval = None - - # Call parent init with autoconnect=False to set up config without blocking I/O - super().__init__( - database=database, - query_limit=query_limit, - uri=self.uri, - query_retries=query_retries, - server_host_name=server_host_name, - tz_source=tz_source, - tz_mode=tz_mode, - utc_tz_aware=utc_tz_aware, - apply_server_timezone=apply_server_timezone, - show_clickhouse_errors=show_clickhouse_errors, - autoconnect=False - ) - - # pylint: disable=attribute-defined-outside-init - async def _initialize(self): - """ - Async equivalent of Client._init_common_settings. - Fetches server version, timezone, and settings. 
- """ - if not self._session: - connector = aiohttp.TCPConnector(**self._connector_kwargs) - self._session = aiohttp.ClientSession( - connector=connector, - timeout=self._timeout, - headers=self.headers, - trust_env=False, - auto_decompress=False, - skip_auto_headers={"Accept-Encoding"}, - ) - - if self._initialized: - return - - try: - tz_source = self._deferred_tz_source - - self.server_tz, self._dst_safe = pytz.UTC, True - row = await self.command("SELECT version(), timezone()", use_database=False) - self.server_version, server_tz_str = tuple(row) - try: - server_tz = pytz.timezone(server_tz_str) - server_tz, self._dst_safe = tzutil.normalize_timezone(server_tz) - if tz_source == "auto": - self._apply_server_tz = self._dst_safe - else: - self._apply_server_tz = tz_source == "server" - self.server_tz = server_tz - except pytz.exceptions.UnknownTimeZoneError: - logger.warning("Warning, server is using an unrecognized timezone %s, will use UTC default", server_tz_str) - - if not self._apply_server_tz and not tzutil.local_tz_dst_safe: - logger.warning("local timezone %s may return unexpected times due to Daylight Savings Time", tzutil.local_tz.tzname(None)) - - readonly = "readonly" - if not self.min_version("19.17"): - readonly = common.get_setting("readonly") - - server_settings = await self.query(f"SELECT name, value, {readonly} as readonly FROM system.settings LIMIT 10000") - self.server_settings = {row["name"]: SettingDef(**row) for row in server_settings.named_results()} - - if self.min_version(CH_VERSION_WITH_PROTOCOL) and common.get_setting("use_protocol_version"): - try: - test_data = await self.raw_query( - "SELECT 1 AS check", fmt="Native", settings={"client_protocol_version": PROTOCOL_VERSION_WITH_LOW_CARD} - ) - if test_data[8:16] == b"\x01\x01\x05check": - self.protocol_version = PROTOCOL_VERSION_WITH_LOW_CARD - except Exception: - pass - - cancel_setting = self._setting_status("cancel_http_readonly_queries_on_client_close") - if 
cancel_setting.is_writable and not cancel_setting.is_set and \ - "cancel_http_readonly_queries_on_client_close" not in (self._initial_settings or {}): - self._client_settings["cancel_http_readonly_queries_on_client_close"] = "1" - - if self._initial_settings: - for key, value in self._initial_settings.items(): - self.set_client_setting(key, value) - - compress = self._compress_param - if coerce_bool(compress): - compression = ",".join(available_compression) - self.write_compression = available_compression[0] - elif compress and compress not in ("False", "false", "0"): - if compress not in available_compression: - raise ProgrammingError(f"Unsupported compression method {compress}") - compression = compress - self.write_compression = compress - else: - compression = None - - comp_setting = self._setting_status("enable_http_compression") - self._send_comp_setting = not comp_setting.is_set and comp_setting.is_writable # pylint: disable=attribute-defined-outside-init - if comp_setting.is_set or comp_setting.is_writable: - self.compression = compression - - session_id = self._session_id_param - autogenerate_session_id = self._autogenerate_session_id_param - - if autogenerate_session_id is None: - autogenerate_session_id = common.get_setting("autogenerate_session_id") - - if session_id: - self.set_client_setting("session_id", session_id) - elif self.get_client_setting("session_id"): - pass - elif autogenerate_session_id: - self.set_client_setting("session_id", str(uuid.uuid4())) - - send_setting = self._setting_status("send_progress_in_http_headers") - self._send_progress = not send_setting.is_set and send_setting.is_writable - if (send_setting.is_set or send_setting.is_writable) and self._setting_status("http_headers_progress_interval_ms").is_writable: - self._progress_interval = str(min(120000, max(10000, (self._send_receive_timeout - 5) * 1000))) - - if self._setting_status("date_time_input_format").is_writable: - self.set_client_setting("date_time_input_format", 
"best_effort") - if ( - self._setting_status("allow_experimental_json_type").is_set - and self._setting_status("cast_string_to_dynamic_use_inference").is_writable - ): - self.set_client_setting("cast_string_to_dynamic_use_inference", "1") - if self.min_version("24.8") and not self.min_version("24.10"): - dynamic_module.json_serialization_format = 0 - - self._initialized = True - except Exception: - if self._session and not self._session.closed: - await self._session.close() - self._session = None - raise - - async def __aenter__(self): - """Async context manager entry.""" - if not self._initialized: - await self._initialize() - return self - - async def __aexit__(self, exc_type, exc_val, exc_tb): - """Async context manager exit.""" - await self.close() - return False - - async def close(self): # type: ignore[override] - if self._session: - await self._session.close() - - async def close_connections(self): # type: ignore[override] - """Close all pooled connections and recreate session""" - if self._session: - await self._session.close() - connector = aiohttp.TCPConnector(**self._connector_kwargs) - self._session = aiohttp.ClientSession( - connector=connector, - timeout=self._timeout, - headers=self.headers, - trust_env=False, - auto_decompress=False, - skip_auto_headers={"Accept-Encoding"}, - ) - - def set_client_setting(self, key, value): - str_value = self._validate_setting(key, value, common.get_setting("invalid_setting_action")) - if str_value is not None: - self._client_settings[key] = str_value - - def get_client_setting(self, key) -> Optional[str]: - return self._client_settings.get(key) - - def set_access_token(self, access_token: str): - auth_header = self.headers.get("Authorization") - if auth_header and not auth_header.startswith("Bearer"): - raise ProgrammingError("Cannot set access token when a different auth type is used") - self.headers["Authorization"] = f"Bearer {access_token}" - if self._session: - self._session.headers["Authorization"] = f"Bearer 
{access_token}" - - def _prep_query(self, context: QueryContext): - final_query = super()._prep_query(context) - if context.is_insert: - return final_query - fmt = f"\n FORMAT {self._read_format}" - if isinstance(final_query, bytes): - return final_query + fmt.encode() - return final_query + fmt - - async def _query_with_context(self, context: QueryContext) -> QueryResult: # type: ignore[override] - headers = {} - params = {} - if self.database: - params["database"] = self.database - if self.protocol_version: - params["client_protocol_version"] = self.protocol_version - context.block_info = True - params.update(self._validate_settings(context.settings)) - context.rename_response_column = self._rename_response_column - - if not context.is_insert and columns_only_re.search(context.uncommented_query): - fmt_json_query = f"{context.final_query}\n FORMAT JSON" - fields = {"query": fmt_json_query} - fields.update(context.bind_params) - - if self.form_encode_query_params: - files = {} - if context.external_data: - params.update(context.external_data.query_params) - files.update(context.external_data.form_data) - - for k, v in fields.items(): - files[k] = (None, str(v)) - response = await self._raw_request(None, params, headers, files=files, retries=self.query_retries) - elif context.external_data: - params.update(context.bind_params) - params.update(context.external_data.query_params) - params["query"] = fmt_json_query - response = await self._raw_request(None, params, headers, files=context.external_data.form_data, retries=self.query_retries) - else: - params.update(context.bind_params) - response = await self._raw_request(fmt_json_query, params, headers, retries=self.query_retries) - - body = await response.read() - encoding = response.headers.get("Content-Encoding") - loop = asyncio.get_running_loop() - - def decompress_and_parse_json(): - if encoding: - decompressed_body = decompress_response(body, encoding) - else: - decompressed_body = body - return 
json.loads(decompressed_body) - - # Offload to executor - json_result = await loop.run_in_executor(None, decompress_and_parse_json) - - names: List[str] = [] - types: List[ClickHouseType] = [] - renamer = context.column_renamer - for col in json_result["meta"]: - name = col["name"] - if renamer is not None: - try: - name = renamer(name) - except Exception as e: - logger.debug("Failed to rename col '%s'. Skipping rename. Error: %s", name, e) - names.append(name) - types.append(get_from_name(col["type"])) - return QueryResult([], None, tuple(names), tuple(types)) - - if self.compression: - headers["Accept-Encoding"] = self.compression - if self._send_comp_setting: - params["enable_http_compression"] = "1" - - final_query = self._prep_query(context) - - files = None - data = None - - if self.form_encode_query_params: - fields = {"query": final_query} - fields.update(context.bind_params) - - files = {} - if context.external_data: - params.update(context.external_data.query_params) - files.update(context.external_data.form_data) - - for k, v in fields.items(): - files[k] = (None, str(v)) - elif context.external_data: - params.update(context.bind_params) - params.update(context.external_data.query_params) - params["query"] = final_query - files = context.external_data.form_data - else: - params.update(context.bind_params) - data = final_query - headers["Content-Type"] = "text/plain; charset=utf-8" - - headers = dict_copy(headers, context.transport_settings) - - response = await self._raw_request(data, params, headers, files=files, - server_wait=not context.streaming, - stream=True, retries=self.query_retries) - encoding = response.headers.get("Content-Encoding") - tz_header = response.headers.get("X-ClickHouse-Timezone") - exception_tag = response.headers.get(ex_tag_header) - - loop = asyncio.get_running_loop() - streaming_source = StreamingResponseSource(response, encoding=encoding, exception_tag=exception_tag) - await streaming_source.start_producer(loop) - - def 
parse_streaming(): - """Parse response from streaming queue (runs in executor).""" - # Wrap streaming source with ResponseBuffer. The streaming source provides a - # .gen property that yields decompressed chunks. - byte_source = RespBuffCls(streaming_source) - context.set_response_tz(self._check_tz_change(tz_header)) - result = self._transform.parse_response(byte_source, context) - - # For Pandas/Numpy, we must materialize in the executor because the resulting objects - # (DataFrame, Array) are fully in-memory structures. - # For standard queries, we return a lazy QueryResult. Accessing .result_set on the event loop - # will raise a ProgrammingError (deadlock check), encouraging usage of .rows_stream. - if not context.streaming: - if context.as_pandas and hasattr(result, 'df_result'): - _ = result.df_result - elif context.use_numpy and hasattr(result, 'np_result'): - _ = result.np_result - elif isinstance(result, QueryResult): - _ = result.result_set - - return result - - # Run parser in executor (pulls from queue, decompresses & parses) - try: - query_result = await loop.run_in_executor(None, parse_streaming) - except Exception: - await streaming_source.aclose() - raise - query_result.summary = self._summary(response) - - # Attach streaming_source to query_result.source to ensure it gets closed - # when the query result is closed (e.g. 
by StreamContext.__exit__) - query_result.source = streaming_source - - return query_result - - - # pylint: disable=arguments-differ - async def query( # type: ignore[override] - self, - query: Optional[str] = None, - parameters: Optional[Union[Sequence, Dict[str, Any]]] = None, - settings: Optional[Dict[str, Any]] = None, - query_formats: Optional[Dict[str, str]] = None, - column_formats: Optional[Dict[str, Union[str, Dict[str, str]]]] = None, - encoding: Optional[str] = None, - use_none: Optional[bool] = None, - column_oriented: Optional[bool] = None, - use_numpy: Optional[bool] = None, - max_str_len: Optional[int] = None, - context: Optional[QueryContext] = None, - query_tz: Optional[Union[str, tzinfo]] = None, - column_tzs: Optional[Dict[str, Union[str, tzinfo]]] = None, - utc_tz_aware: Optional[bool] = None, - external_data: Optional[ExternalData] = None, - transport_settings: Optional[Dict[str, str]] = None, - tz_mode: Optional[TzMode] = None, - ) -> QueryResult: - """ - Main query method for SELECT, DESCRIBE and other SQL statements that return a result matrix. 
For - parameters, see the create_query_context method - :return: QueryResult -- data and metadata from response - """ - if query and query.lower().strip().startswith("select __connect_version__"): - return QueryResult( - [[f"ClickHouse Connect v.{common.version()} ⓒ ClickHouse Inc."]], None, ("connect_version",), (get_from_name("String"),) - ) - if not context: - context = self.create_query_context( - query=query, - parameters=parameters, - settings=settings, - query_formats=query_formats, - column_formats=column_formats, - encoding=encoding, - use_none=use_none, - column_oriented=column_oriented, - use_numpy=use_numpy, - max_str_len=max_str_len, - query_tz=query_tz, - column_tzs=column_tzs, - utc_tz_aware=utc_tz_aware, - external_data=external_data, - transport_settings=transport_settings, - tz_mode=tz_mode, - ) - - if context.is_command: - response = await self.command( - query, - parameters=context.parameters, - settings=context.settings, - external_data=context.external_data, - transport_settings=context.transport_settings, - ) - if isinstance(response, QuerySummary): - return response.as_query_result() - return QueryResult([response] if isinstance(response, list) else [[response]]) - - return await self._query_with_context(context) - - async def query_column_block_stream( # type: ignore[override] - self, - query: Optional[str] = None, - parameters: Optional[Union[Sequence, Dict[str, Any]]] = None, - settings: Optional[Dict[str, Any]] = None, - query_formats: Optional[Dict[str, str]] = None, - column_formats: Optional[Dict[str, Union[str, Dict[str, str]]]] = None, - encoding: Optional[str] = None, - use_none: Optional[bool] = None, - context: Optional[QueryContext] = None, - query_tz: Optional[Union[str, tzinfo]] = None, - column_tzs: Optional[Dict[str, Union[str, tzinfo]]] = None, - utc_tz_aware: Optional[bool] = None, - external_data: Optional[ExternalData] = None, - transport_settings: Optional[Dict[str, str]] = None, - tz_mode: Optional[TzMode] = None, - ) 
-> StreamContext: - """ - Async version of query_column_block_stream. - Returns a StreamContext that yields column-oriented blocks. - """ - return (await self._context_query(locals(), use_numpy=False, streaming=True)).column_block_stream - - async def query_row_block_stream( # type: ignore[override] - self, - query: Optional[str] = None, - parameters: Optional[Union[Sequence, Dict[str, Any]]] = None, - settings: Optional[Dict[str, Any]] = None, - query_formats: Optional[Dict[str, str]] = None, - column_formats: Optional[Dict[str, Union[str, Dict[str, str]]]] = None, - encoding: Optional[str] = None, - use_none: Optional[bool] = None, - context: Optional[QueryContext] = None, - query_tz: Optional[Union[str, tzinfo]] = None, - column_tzs: Optional[Dict[str, Union[str, tzinfo]]] = None, - utc_tz_aware: Optional[bool] = None, - external_data: Optional[ExternalData] = None, - transport_settings: Optional[Dict[str, str]] = None, - tz_mode: Optional[TzMode] = None, - ) -> StreamContext: - """ - Async version of query_row_block_stream. - Returns a StreamContext that yields row-oriented blocks. 
- """ - return (await self._context_query(locals(), use_numpy=False, streaming=True)).row_block_stream - - async def query_rows_stream( # type: ignore[override] - self, - query: Optional[str] = None, - parameters: Optional[Union[Sequence, Dict[str, Any]]] = None, - settings: Optional[Dict[str, Any]] = None, - query_formats: Optional[Dict[str, str]] = None, - column_formats: Optional[Dict[str, Union[str, Dict[str, str]]]] = None, - encoding: Optional[str] = None, - use_none: Optional[bool] = None, - context: Optional[QueryContext] = None, - query_tz: Optional[Union[str, tzinfo]] = None, - column_tzs: Optional[Dict[str, Union[str, tzinfo]]] = None, - utc_tz_aware: Optional[bool] = None, - external_data: Optional[ExternalData] = None, - transport_settings: Optional[Dict[str, str]] = None, - tz_mode: Optional[TzMode] = None, - ) -> StreamContext: - """ - Async version of query_rows_stream. - Returns a StreamContext that yields individual rows. - """ - return (await self._context_query(locals(), use_numpy=False, streaming=True)).rows_stream - - # pylint: disable=unused-argument - async def query_np( - self, - query: Optional[str] = None, - parameters: Optional[Union[Sequence, Dict[str, Any]]] = None, - settings: Optional[Dict[str, Any]] = None, - query_formats: Optional[Dict[str, str]] = None, - column_formats: Optional[Dict[str, str]] = None, - encoding: Optional[str] = None, - use_none: Optional[bool] = None, - max_str_len: Optional[int] = None, - context: Optional[QueryContext] = None, - external_data: Optional[ExternalData] = None, - transport_settings: Optional[Dict[str, str]] = None, - ): - check_numpy() - self._add_integration_tag("numpy") - return (await self._context_query(locals(), use_numpy=True)).np_result - - async def query_np_stream( # type: ignore[override] - self, - query: Optional[str] = None, - parameters: Optional[Union[Sequence, Dict[str, Any]]] = None, - settings: Optional[Dict[str, Any]] = None, - query_formats: Optional[Dict[str, str]] = None, - 
column_formats: Optional[Dict[str, str]] = None, - encoding: Optional[str] = None, - use_none: Optional[bool] = None, - max_str_len: Optional[int] = None, - context: Optional[QueryContext] = None, - external_data: Optional[ExternalData] = None, - transport_settings: Optional[Dict[str, str]] = None, - ) -> StreamContext: - check_numpy() - self._add_integration_tag("numpy") - return (await self._context_query(locals(), use_numpy=True, streaming=True)).np_stream - - async def query_df( - self, - query: Optional[str] = None, - parameters: Optional[Union[Sequence, Dict[str, Any]]] = None, - settings: Optional[Dict[str, Any]] = None, - query_formats: Optional[Dict[str, str]] = None, - column_formats: Optional[Dict[str, str]] = None, - encoding: Optional[str] = None, - use_none: Optional[bool] = None, - max_str_len: Optional[int] = None, - use_na_values: Optional[bool] = None, - query_tz: Optional[str] = None, - column_tzs: Optional[Dict[str, Union[str, tzinfo]]] = None, - utc_tz_aware: Optional[bool] = None, - context: Optional[QueryContext] = None, - external_data: Optional[ExternalData] = None, - use_extended_dtypes: Optional[bool] = None, - transport_settings: Optional[Dict[str, str]] = None, - tz_mode: Optional[TzMode] = None, - ): - check_pandas() - self._add_integration_tag("pandas") - return (await self._context_query(locals(), use_numpy=True, as_pandas=True)).df_result - - async def query_df_stream( # type: ignore[override] - self, - query: Optional[str] = None, - parameters: Optional[Union[Sequence, Dict[str, Any]]] = None, - settings: Optional[Dict[str, Any]] = None, - query_formats: Optional[Dict[str, str]] = None, - column_formats: Optional[Dict[str, str]] = None, - encoding: Optional[str] = None, - use_none: Optional[bool] = None, - max_str_len: Optional[int] = None, - use_na_values: Optional[bool] = None, - query_tz: Optional[str] = None, - column_tzs: Optional[Dict[str, Union[str, tzinfo]]] = None, - utc_tz_aware: Optional[bool] = None, - context: 
Optional[QueryContext] = None, - external_data: Optional[ExternalData] = None, - use_extended_dtypes: Optional[bool] = None, - transport_settings: Optional[Dict[str, str]] = None, - tz_mode: Optional[TzMode] = None, - ) -> StreamContext: - check_pandas() - self._add_integration_tag("pandas") - return (await self._context_query(locals(), use_numpy=True, as_pandas=True, streaming=True)).df_stream - - async def _context_query(self, lcls: dict, **overrides): # type: ignore[override] - """ - Helper method to create query context and execute query. - Matches sync client pattern for consistency. - """ - kwargs = lcls.copy() - kwargs.pop("self") - kwargs.update(overrides) - return await self._query_with_context(self.create_query_context(**kwargs)) - - async def command( # type: ignore[override] - self, - cmd, - parameters: Optional[Union[Sequence, Dict[str, Any]]] = None, - data: Optional[Union[str, bytes]] = None, - settings: Optional[Dict] = None, - use_database: bool = True, - external_data: Optional[ExternalData] = None, - transport_settings: Optional[Dict[str, str]] = None, - ) -> Union[str, int, Sequence[str], QuerySummary]: - """ - See BaseClient doc_string for this method - """ - cmd, bind_params = bind_query(cmd, parameters, self.server_tz) - params = bind_params.copy() - headers = {} - payload = None - files = None - - if external_data: - if data: - raise ProgrammingError("Cannot combine command data with external data") from None - files = external_data.form_data - params.update(external_data.query_params) - elif isinstance(data, str): - headers["Content-Type"] = "text/plain; charset=utf-8" - payload = data.encode() - elif isinstance(data, bytes): - headers["Content-Type"] = "application/octet-stream" - payload = data - - if payload is None and not cmd: - raise ProgrammingError("Command sent without query or recognized data") from None - - if payload or files: - params["query"] = cmd - else: - payload = cmd - - if use_database and self.database: - 
params["database"] = self.database - params.update(self._validate_settings(settings or {})) - headers = dict_copy(headers, transport_settings) - method = "POST" if payload or files else "GET" - response = await self._raw_request(payload, params, headers, files=files, method=method, server_wait=False) - body = await response.read() - encoding = response.headers.get("Content-Encoding") - summary = self._summary(response) - - if not body: - return QuerySummary(summary) - - loop = asyncio.get_running_loop() - - def decompress_and_decode(): - if encoding: - decompressed_body = decompress_response(body, encoding) - else: - decompressed_body = body - try: - result = decompressed_body.decode()[:-1].split("\t") - if len(result) == 1: - try: - return int(result[0]) - except ValueError: - return result[0] - return result - except UnicodeDecodeError: - return str(decompressed_body) - - return await loop.run_in_executor(None, decompress_and_decode) - - async def ping(self) -> bool: # type: ignore[override] - try: - url = f"{self.url}/ping" - timeout = aiohttp.ClientTimeout(total=3.0) - async with self._session.get(url, timeout=timeout) as response: - return 200 <= response.status < 300 - except (aiohttp.ClientError, asyncio.TimeoutError): - logger.debug("ping failed", exc_info=True) - return False - - async def raw_query( # type: ignore[override] - self, - query: str, - parameters: Optional[Union[Sequence, Dict[str, Any]]] = None, - settings: Optional[Dict[str, Any]] = None, - fmt: Optional[str] = None, - use_database: bool = True, - external_data: Optional[ExternalData] = None, - transport_settings: Optional[Dict[str, str]] = None, - ) -> bytes: - """ - See BaseClient doc_string for this method - """ - body, params, headers, files = self._prep_raw_query(query, parameters, settings, fmt, use_database, external_data) - if transport_settings: - headers = dict_copy(headers, transport_settings) - - response = await self._raw_request(body, params, headers=headers, files=files, 
retries=self.query_retries) - response_data = await response.read() - encoding = response.headers.get("Content-Encoding") - - if encoding: - loop = asyncio.get_running_loop() - response_data = await loop.run_in_executor(None, decompress_response, response_data, encoding) - - return response_data - - async def raw_stream( # type: ignore[override] - self, - query: str, - parameters: Optional[Union[Sequence, Dict[str, Any]]] = None, - settings: Optional[Dict[str, Any]] = None, - fmt: Optional[str] = None, - use_database: bool = True, - external_data: Optional[ExternalData] = None, - transport_settings: Optional[Dict[str, str]] = None, - ) -> StreamContext: - - body, params, headers, files = self._prep_raw_query(query, parameters, settings, fmt, use_database, external_data) - if transport_settings: - headers = dict_copy(headers, transport_settings) - - response = await self._raw_request( - body, params, headers=headers, files=files, stream=True, server_wait=False, retries=self.query_retries - ) - - async def byte_iterator(): - async for chunk in response.content.iter_any(): - yield chunk - - return StreamContext(response, byte_iterator()) - - def _prep_raw_query(self, query, parameters, settings, fmt, use_database, external_data): - """ - Prepare raw query for execution. - - Note: Unlike sync client which returns (body, params, fields), this async version - returns (body, params, headers, files) because aiohttp requires headers to be - configured before the request() call, while urllib3 can add them during request. 
- """ - if fmt: - query += f"\n FORMAT {fmt}" - - final_query, bind_params = bind_query(query, parameters, self.server_tz) - params = self._validate_settings(settings or {}) - if use_database and self.database: - params["database"] = self.database - - headers = {} - files = None - body = None - - if external_data and not self.form_encode_query_params and isinstance(final_query, bytes): - raise ProgrammingError("Binary query cannot be placed in URL when using External Data; enable form encoding.") - - if self.form_encode_query_params: - files = {} - files["query"] = (None, final_query if isinstance(final_query, str) else final_query.decode()) - for k, v in bind_params.items(): - files[k] = (None, str(v)) - - if external_data: - params.update(external_data.query_params) - files.update(external_data.form_data) - - body = None - elif external_data: - params.update(bind_params) - params["query"] = final_query - params.update(external_data.query_params) - files = external_data.form_data - body = None - else: - params.update(bind_params) - body = final_query.encode() if isinstance(final_query, str) else final_query - - return body, params, headers, files - - async def insert( # type: ignore[override] - self, - table: Optional[str] = None, - data: Optional[Sequence[Sequence[Any]]] = None, - column_names: Union[str, Iterable[str]] = "*", - database: Optional[str] = None, - column_types: Optional[Sequence[ClickHouseType]] = None, - column_type_names: Optional[Sequence[str]] = None, - column_oriented: bool = False, - settings: Optional[Dict[str, Any]] = None, - context: Optional[InsertContext] = None, - transport_settings: Optional[Dict[str, str]] = None, - ) -> QuerySummary: - """ - Method to insert multiple rows/data matrix of native Python objects. 
If context is specified arguments - other than data are ignored - :param table: Target table - :param data: Sequence of sequences of Python data - :param column_names: Ordered list of column names or '*' if column types should be retrieved from the - ClickHouse table definition - :param database: Target database -- will use client default database if not specified. - :param column_types: ClickHouse column types. If set then column data does not need to be retrieved from - the server - :param column_type_names: ClickHouse column type names. If set then column data does not need to be - retrieved from the server - :param column_oriented: If true the data is already "pivoted" in column form - :param settings: Optional dictionary of ClickHouse settings (key/string values) - :param context: Optional reusable insert context to allow repeated inserts into the same table with - different data batches - :param transport_settings: Optional dictionary of transport level settings (HTTP headers, etc.) 
- :return: QuerySummary with summary information, throws exception if insert fails - """ - if (context is None or context.empty) and data is None: - raise ProgrammingError("No data specified for insert") from None - if context is None: - context = await self.create_insert_context( - table, - column_names, - database, - column_types, - column_type_names, - column_oriented, - settings, - transport_settings=transport_settings, - ) - if data is not None: - if not context.empty: - raise ProgrammingError("Attempting to insert new data with non-empty insert context") from None - context.data = data - return await self.data_insert(context) - - async def query_arrow( - self, - query: str, - parameters: Optional[Union[Sequence, Dict[str, Any]]] = None, - settings: Optional[Dict[str, Any]] = None, - use_strings: Optional[bool] = None, - external_data: Optional[ExternalData] = None, - transport_settings: Optional[Dict[str, str]] = None, - ): - """ - Query method using the ClickHouse Arrow format to return a PyArrow table - :param query: Query statement/format string - :param parameters: Optional dictionary used to format the query - :param settings: Optional dictionary of ClickHouse settings (key/string values) - :param use_strings: Convert ClickHouse String type to Arrow string type (instead of binary) - :param external_data: ClickHouse "external data" to send with query - :param transport_settings: Optional dictionary of transport level settings (HTTP headers, etc.) 
- :return: PyArrow.Table - """ - check_arrow() - self._add_integration_tag("arrow") - settings = self._update_arrow_settings(settings, use_strings) - - body, params, headers, files = self._prep_raw_query( - query, parameters, settings, fmt="ArrowStream", - use_database=True, external_data=external_data - ) - if transport_settings: - headers = dict_copy(headers, transport_settings) - - response = await self._raw_request( - body, params, headers=headers, files=files, - stream=True, server_wait=False, retries=self.query_retries - ) - encoding = response.headers.get("Content-Encoding") - exception_tag = response.headers.get(ex_tag_header) - - loop = asyncio.get_running_loop() - streaming_source = StreamingResponseSource(response, encoding=encoding, exception_tag=exception_tag) - await streaming_source.start_producer(loop) - - def parse_arrow_stream(): - file_adapter = StreamingFileAdapter(streaming_source) - reader = arrow.ipc.open_stream(file_adapter) - table = reader.read_all() - return _apply_arrow_tz_policy(table, self.tz_mode) - - try: - return await loop.run_in_executor(None, parse_arrow_stream) - finally: - await streaming_source.aclose() - - async def query_arrow_stream( # type: ignore[override] - self, - query: str, - parameters: Optional[Union[Sequence, Dict[str, Any]]] = None, - settings: Optional[Dict[str, Any]] = None, - use_strings: Optional[bool] = None, - external_data: Optional[ExternalData] = None, - transport_settings: Optional[Dict[str, str]] = None, - ) -> StreamContext: - """ - Query method that returns the results as a stream of Arrow record batches. 
- - :param query: Query statement/format string - :param parameters: Optional dictionary used to format the query - :param settings: Optional dictionary of ClickHouse settings (key/string values) - :param use_strings: Convert ClickHouse String type to Arrow string type (instead of binary) - :param external_data: ClickHouse "external data" to send with query - :param transport_settings: Optional dictionary of transport level settings (HTTP headers, etc.) - :return: StreamContext that yields PyArrow RecordBatch objects asynchronously - """ - check_arrow() - self._add_integration_tag("arrow") - settings = self._update_arrow_settings(settings, use_strings) - - body, params, headers, files = self._prep_raw_query( - query, parameters, settings, fmt="ArrowStream", - use_database=True, external_data=external_data - ) - if transport_settings: - headers = dict_copy(headers, transport_settings) - - response = await self._raw_request( - body, params, headers=headers, files=files, - stream=True, server_wait=False, retries=self.query_retries - ) - encoding = response.headers.get("Content-Encoding") - exception_tag = response.headers.get(ex_tag_header) - - loop = asyncio.get_running_loop() - streaming_source = StreamingResponseSource(response, encoding=encoding, exception_tag=exception_tag) - await streaming_source.start_producer(loop) - - queue = AsyncSyncQueue(maxsize=10) - - class _ArrowStreamSource: - def __init__(self, source, q): - self._source = source - self._queue = q - - async def aclose(self): - self._queue.shutdown() - await self._source.aclose() - - def close(self): - self._queue.shutdown() - self._source.close() - - def parse_arrow_streaming(): - """Parse Arrow stream incrementally in executor (off event loop).""" - try: - file_adapter = StreamingFileAdapter(streaming_source) - reader = arrow.ipc.open_stream(file_adapter) - - for batch in reader: - try: - batch = _apply_arrow_tz_policy(batch, self.tz_mode) - queue.sync_q.put(batch) - except RuntimeError: - return - 
- try: - queue.sync_q.put(EOF_SENTINEL) - except RuntimeError: - return - except Exception as e: - try: - queue.sync_q.put(e) - except Exception: - pass - finally: - queue.shutdown() - - loop.run_in_executor(None, parse_arrow_streaming) - - async def arrow_batch_generator(): - """Async generator that yields record batches without blocking event loop.""" - while True: - item = await queue.async_q.get() - if item is EOF_SENTINEL: - break - if isinstance(item, Exception): - raise item - yield item - - return StreamContext(_ArrowStreamSource(streaming_source, queue), arrow_batch_generator()) - - async def query_df_arrow( - self, - query: str, - parameters: Optional[Union[Sequence, Dict[str, Any]]] = None, - settings: Optional[Dict[str, Any]] = None, - use_strings: Optional[bool] = None, - external_data: Optional[ExternalData] = None, - transport_settings: Optional[Dict[str, str]] = None, - dataframe_library: str = "pandas", - ) -> Union["pd.DataFrame", "pl.DataFrame"]: - """ - Query method using the ClickHouse Arrow format to return a DataFrame - with PyArrow dtype backend. This provides better performance and memory efficiency - compared to the standard query_df method, though fewer output formatting options. - - :param query: Query statement/format string - :param parameters: Optional dictionary used to format the query - :param settings: Optional dictionary of ClickHouse settings (key/string values) - :param use_strings: Convert ClickHouse String type to Arrow string type (instead of binary) - :param external_data: ClickHouse "external data" to send with query - :param transport_settings: Optional dictionary of transport level settings (HTTP headers, etc.) 
- :param dataframe_library: Library to use for DataFrame creation ("pandas" or "polars") - :return: DataFrame (pandas or polars based on dataframe_library parameter) - """ - check_arrow() - - if dataframe_library == "pandas": - check_pandas() - self._add_integration_tag("pandas") - if not IS_PANDAS_2: - raise ProgrammingError("PyArrow-backed dtypes are only supported when using pandas 2.x.") - - def converter(table: "arrow.Table") -> "pd.DataFrame": - table = _apply_arrow_tz_policy(table, self.tz_mode) - return table.to_pandas(types_mapper=pd.ArrowDtype, safe=False) - - elif dataframe_library == "polars": - check_polars() - self._add_integration_tag("polars") - - def converter(table: "arrow.Table") -> "pl.DataFrame": - table = _apply_arrow_tz_policy(table, self.tz_mode) - return pl.from_arrow(table) - - else: - raise ValueError(f"dataframe_library must be 'pandas' or 'polars', got '{dataframe_library}'") - - arrow_table = await self.query_arrow( - query=query, - parameters=parameters, - settings=settings, - use_strings=use_strings, - external_data=external_data, - transport_settings=transport_settings, - ) - - return converter(arrow_table) - - async def query_df_arrow_stream( # type: ignore[override] - self, - query: str, - parameters: Optional[Union[Sequence, Dict[str, Any]]] = None, - settings: Optional[Dict[str, Any]] = None, - use_strings: Optional[bool] = None, - external_data: Optional[ExternalData] = None, - transport_settings: Optional[Dict[str, str]] = None, - dataframe_library: str = "pandas", - ) -> StreamContext: - """ - Query method that returns the results as a stream of DataFrames with PyArrow dtype backend. - Each DataFrame represents a record batch from the ClickHouse response. 
- - :param query: Query statement/format string - :param parameters: Optional dictionary used to format the query - :param settings: Optional dictionary of ClickHouse settings (key/string values) - :param use_strings: Convert ClickHouse String type to Arrow string type (instead of binary) - :param external_data: ClickHouse "external data" to send with query - :param transport_settings: Optional dictionary of transport level settings (HTTP headers, etc.) - :param dataframe_library: Library to use for DataFrame creation ("pandas" or "polars") - :return: StreamContext that yields DataFrames asynchronously (pandas or polars based on dataframe_library parameter) - """ - check_arrow() - if dataframe_library == "pandas": - check_pandas() - self._add_integration_tag("pandas") - if not IS_PANDAS_2: - raise ProgrammingError("PyArrow-backed dtypes are only supported when using pandas 2.x.") - - def converter(table: "arrow.Table") -> "pd.DataFrame": - table = _apply_arrow_tz_policy(table, self.tz_mode) - return table.to_pandas(types_mapper=pd.ArrowDtype, safe=False) - - elif dataframe_library == "polars": - check_polars() - self._add_integration_tag("polars") - - def converter(table: "arrow.Table") -> "pl.DataFrame": - table = _apply_arrow_tz_policy(table, self.tz_mode) - return pl.from_arrow(table) - - else: - raise ValueError(f"dataframe_library must be 'pandas' or 'polars', got '{dataframe_library}'") - settings = self._update_arrow_settings(settings, use_strings) - - body, params, headers, files = self._prep_raw_query( - query, parameters, settings, fmt="ArrowStream", - use_database=True, external_data=external_data - ) - if transport_settings: - headers = dict_copy(headers, transport_settings) - - response = await self._raw_request( - body, params, headers=headers, files=files, - stream=True, server_wait=False, retries=self.query_retries - ) - encoding = response.headers.get("Content-Encoding") - exception_tag = response.headers.get(ex_tag_header) - - loop = 
asyncio.get_running_loop() - streaming_source = StreamingResponseSource(response, encoding=encoding, exception_tag=exception_tag) - await streaming_source.start_producer(loop) - - queue = AsyncSyncQueue(maxsize=10) - - class _ArrowDFStreamSource: - def __init__(self, source, q): - self._source = source - self._queue = q - - async def aclose(self): - self._queue.shutdown() - await self._source.aclose() - - def close(self): - self._queue.shutdown() - self._source.close() - - def parse_and_convert_streaming(): - """Parse Arrow stream and convert to DataFrames in executor (off event loop).""" - try: - file_adapter = StreamingFileAdapter(streaming_source) - - # PyArrow reads incrementally from adapter (which pulls from queue) - reader = arrow.ipc.open_stream(file_adapter) - - for batch in reader: - try: - queue.sync_q.put(converter(batch)) - except RuntimeError: - return - - try: - queue.sync_q.put(EOF_SENTINEL) - except RuntimeError: - return - except Exception as e: - try: - queue.sync_q.put(e) - except Exception: - pass - finally: - queue.shutdown() - - loop.run_in_executor(None, parse_and_convert_streaming) - - async def df_generator(): - """Async generator that yields DataFrames without blocking event loop.""" - while True: - item = await queue.async_q.get() - if item is EOF_SENTINEL: - break - if isinstance(item, Exception): - raise item - yield item - - return StreamContext(_ArrowDFStreamSource(streaming_source, queue), df_generator()) - - async def insert_arrow( # type: ignore[override] - self, - table: str, - arrow_table, - database: Optional[str] = None, - settings: Optional[Dict] = None, - transport_settings: Optional[Dict[str, str]] = None, - ) -> QuerySummary: - """ - Insert a PyArrow table DataFrame into ClickHouse using raw Arrow format - :param table: ClickHouse table - :param arrow_table: PyArrow Table object - :param database: Optional ClickHouse database - :param settings: Optional dictionary of ClickHouse settings (key/string values) - :param 
transport_settings: Optional dictionary of transport level settings (HTTP headers, etc.) - """ - check_arrow() - self._add_integration_tag("arrow") - full_table = table if "." in table or not database else f"{database}.{table}" - compression = self.write_compression if self.write_compression in ("zstd", "lz4") else None - column_names, insert_block = arrow_buffer(arrow_table, compression) - if hasattr(insert_block, "to_pybytes"): - insert_block = insert_block.to_pybytes() - return await self.raw_insert(full_table, column_names, insert_block, settings, "Arrow", transport_settings) - - async def insert_df_arrow( # type: ignore[override] - self, - table: str, - df: Union["pd.DataFrame", "pl.DataFrame"], - database: Optional[str] = None, - settings: Optional[Dict] = None, - transport_settings: Optional[Dict[str, str]] = None, - ) -> QuerySummary: - """ - Insert a pandas DataFrame with PyArrow backend or a polars DataFrame into ClickHouse using Arrow format. - This method is optimized for DataFrames that already use Arrow format, providing - better performance than the standard insert_df method. - - Validation is performed and an exception will be raised if this requirement is not met. - Polars DataFrames are natively Arrow-based and don't require additional validation. - - :param table: ClickHouse table name - :param df: Pandas DataFrame with PyArrow dtype backend or Polars DataFrame - :param database: Optional ClickHouse database name - :param settings: Optional dictionary of ClickHouse settings (key/string values) - :param transport_settings: Optional dictionary of transport level settings (HTTP headers, etc.) 
- :return: QuerySummary with summary information, throws exception if insert fails - """ - check_arrow() - - if pd is not None and isinstance(df, pd.DataFrame): - df_lib = "pandas" - elif pl is not None and isinstance(df, pl.DataFrame): - df_lib = "polars" - else: - if pd is None and pl is None: - raise ImportError("A DataFrame library (pandas or polars) must be installed to use insert_df_arrow.") - raise TypeError(f"df must be either a pandas DataFrame or polars DataFrame, got {type(df).__name__}") - - if df_lib == "pandas": - if not IS_PANDAS_2: - raise ProgrammingError("PyArrow-backed dtypes are only supported when using pandas 2.x.") - - non_arrow_cols = [col for col, dtype in df.dtypes.items() if not isinstance(dtype, pd.ArrowDtype)] - if non_arrow_cols: - raise ProgrammingError( - f"insert_df_arrow requires all columns to use PyArrow dtypes. Non-Arrow columns found: [{', '.join(non_arrow_cols)}]. " - ) - try: - arrow_table = arrow.Table.from_pandas(df, preserve_index=False) - except Exception as e: - raise DataError(f"Failed to convert pandas DataFrame to Arrow table: {e}") from e - else: - try: - arrow_table = df.to_arrow() - except Exception as e: - raise DataError(f"Failed to convert polars DataFrame to Arrow table: {e}") from e - - self._add_integration_tag(df_lib) - return await self.insert_arrow( - table=table, - arrow_table=arrow_table, - database=database, - settings=settings, - transport_settings=transport_settings, - ) - - async def create_insert_context( # type: ignore[override] - self, - table: str, - column_names: Optional[Union[str, Sequence[str]]] = None, - database: Optional[str] = None, - column_types: Optional[Sequence[ClickHouseType]] = None, - column_type_names: Optional[Sequence[str]] = None, - column_oriented: bool = False, - settings: Optional[Dict[str, Any]] = None, - data: Optional[Sequence[Sequence[Any]]] = None, - transport_settings: Optional[Dict[str, str]] = None, - ) -> InsertContext: - """ - Builds a reusable insert context to 
hold state for a duration of an insert - :param table: Target table - :param database: Target database. If not set, uses the client default database - :param column_names: Optional ordered list of column names. If not set, all columns ('*') will be assumed - in the order specified by the table definition - :param database: Target database -- will use client default database if not specified - :param column_types: ClickHouse column types. Optional Sequence of ClickHouseType objects. If neither column - types nor column type names are set, actual column types will be retrieved from the server. - :param column_type_names: ClickHouse column type names. Specified column types by name string - :param column_oriented: If true the data is already "pivoted" in column form - :param settings: Optional dictionary of ClickHouse settings (key/string values) - :param data: Initial dataset for insert - :param transport_settings: Optional dictionary of transport level settings (HTTP headers, etc.) - :return: Reusable insert context - """ - full_table = table - if "." 
not in table: - if database: - full_table = f"{quote_identifier(database)}.{quote_identifier(table)}" - else: - full_table = quote_identifier(table) - column_defs = [] - if column_types is None and column_type_names is None: - describe_result = await self.query(f"DESCRIBE TABLE {full_table}", settings=settings) - column_defs = [ - ColumnDef(**row) for row in describe_result.named_results() if row["default_type"] not in ("ALIAS", "MATERIALIZED") - ] - if column_names is None or isinstance(column_names, str) and column_names == "*": - column_names = [cd.name for cd in column_defs] - column_types = [cd.ch_type for cd in column_defs] - elif isinstance(column_names, str): - column_names = [column_names] - if len(column_names) == 0: - raise ValueError("Column names must be specified for insert") - if not column_types: - if column_type_names: - column_types = [get_from_name(name) for name in column_type_names] - else: - column_map = {d.name: d for d in column_defs} - try: - column_types = [column_map[name].ch_type for name in column_names] - except KeyError as ex: - raise ProgrammingError(f"Unrecognized column {ex} in table {table}") from None - if len(column_names) != len(column_types): - raise ProgrammingError("Column names do not match column types") from None - return InsertContext( - full_table, - column_names, - column_types, - column_oriented=column_oriented, - settings=settings, - transport_settings=transport_settings, - data=data, - ) - - async def data_insert(self, context: InsertContext) -> QuerySummary: # type: ignore[override] - """ - See BaseClient doc_string for this method. 
- - Uses true streaming via reverse bridge pattern: - - Sync producer (serializer) runs in executor, puts blocks in queue - - Async consumer (network) pulls from queue and yields to aiohttp - - Bounded queue provides backpressure to prevent memory bloat - """ - if context.empty: - logger.debug("No data included in insert, skipping") - return QuerySummary() - - if context.compression is None: - context.compression = self.write_compression - - loop = asyncio.get_running_loop() - - streaming_source = StreamingInsertSource( - transform=self._transform, context=context, loop=loop, maxsize=10 - ) - - streaming_source.start_producer() - - headers = {"Content-Type": "application/octet-stream"} - if context.compression: - headers["Content-Encoding"] = context.compression - - params = {} - if self.database: - params["database"] = self.database - params.update(self._validate_settings(context.settings)) - headers = dict_copy(headers, context.transport_settings) - - try: - response = await self._raw_request( - streaming_source.async_generator(), params, headers=headers, server_wait=False - ) - logger.debug("Context insert response code: %d", response.status) - except Exception: - await streaming_source.close() - - if context.insert_exception: - ex = context.insert_exception - context.insert_exception = None - raise ex from None - raise - finally: - await streaming_source.close() - context.data = None - - return QuerySummary(self._summary(response)) - - async def insert_df( # type: ignore[override] - self, - table: Optional[str] = None, - df=None, - database: Optional[str] = None, - settings: Optional[Dict] = None, - column_names: Optional[Sequence[str]] = None, - column_types: Optional[Sequence[ClickHouseType]] = None, - column_type_names: Optional[Sequence[str]] = None, - context: Optional[InsertContext] = None, - transport_settings: Optional[Dict[str, str]] = None, - ) -> QuerySummary: - """ - Insert a pandas DataFrame into ClickHouse. 
If context is specified arguments other than df are ignored - :param table: ClickHouse table - :param df: two-dimensional pandas dataframe - :param database: Optional ClickHouse database - :param settings: Optional dictionary of ClickHouse settings (key/string values) - :param column_names: An optional list of ClickHouse column names. If not set, the DataFrame column names - will be used - :param column_types: ClickHouse column types. If set then column data does not need to be retrieved from - the server - :param column_type_names: ClickHouse column type names. If set then column data does not need to be - retrieved from the server - :param context: Optional reusable insert context to allow repeated inserts into the same table with - different data batches - :param transport_settings: Optional dictionary of transport level settings (HTTP headers, etc.) - :return: QuerySummary with summary information, throws exception if insert fails - """ - check_pandas() - self._add_integration_tag("pandas") - if context is None: - if column_names is None: - column_names = df.columns - elif len(column_names) != len(df.columns): - raise ProgrammingError("DataFrame column count does not match insert_columns") from None - return await self.insert( - table, - df, - column_names, - database, - column_types=column_types, - column_type_names=column_type_names, - settings=settings, - transport_settings=transport_settings, - context=context, - ) - - async def raw_insert( # type: ignore[override] - self, - table: Optional[str] = None, - column_names: Optional[Sequence[str]] = None, - insert_block: Optional[Union[str, bytes, Generator[bytes, None, None], BinaryIO]] = None, - settings: Optional[Dict] = None, - fmt: Optional[str] = None, - compression: Optional[str] = None, - transport_settings: Optional[Dict[str, str]] = None, - ) -> QuerySummary: - """ - See BaseClient doc_string for this method - """ - params = {} - headers = {"Content-Type": "application/octet-stream"} - if compression: 
- headers["Content-Encoding"] = compression - - if table: - cols = f" ({', '.join([quote_identifier(x) for x in column_names])})" if column_names is not None else "" - fmt_str = fmt if fmt else self._write_format - query = f"INSERT INTO {table}{cols} FORMAT {fmt_str}" - if not compression and isinstance(insert_block, str): - insert_block = query + "\n" + insert_block - elif not compression and isinstance(insert_block, (bytes, bytearray, BinaryIO)): - insert_block = (query + "\n").encode() + insert_block - else: - params["query"] = query - - if self.database: - params["database"] = self.database - params.update(self._validate_settings(settings or {})) - headers = dict_copy(headers, transport_settings) - - response = await self._raw_request(insert_block, params, headers, server_wait=False) - logger.debug("Raw insert response code: %d", response.status) - return QuerySummary(self._summary(response)) - - def _add_integration_tag(self, name: str): - """ - Dynamically adds a product (like pandas or sqlalchemy) to the User-Agent string details section. 
- """ - if not common.get_setting("send_integration_tags") or name in self._reported_libs: - return - - try: - ver = "unknown" - try: - ver = dist_version(name) - except Exception: - try: - mod = import_module(name) - ver = getattr(mod, "__version__", "unknown") - except Exception: - pass - - product_info = f"{name}/{ver}" - - ua = self.headers.get("User-Agent", "") - start = ua.find("(") - if start == -1: - return - end = ua.find(")", start + 1) - if end == -1: - return - - details = ua[start + 1 : end].strip() - - if product_info in details: - self._reported_libs.add(name) - return - - new_details = f"{product_info}; {details}" if details else product_info - new_ua = f"{ua[: start + 1]}{new_details}{ua[end:]}" - self.headers["User-Agent"] = new_ua.strip() - if self._session: - self._session.headers["User-Agent"] = new_ua.strip() - - self._reported_libs.add(name) - logger.debug("Added '%s' to User-Agent", product_info) - - except Exception as e: - logger.debug("Problem adding '%s' to User-Agent: %s", name, e) - - async def _error_handler(self, response: aiohttp.ClientResponse, retried: bool = False): - """ - Handles HTTP errors. Tries to be robust and provide maximum context. 
- """ - try: - body = "" - try: - raw_body = await response.read() - encoding = response.headers.get("Content-Encoding") - - if encoding: - loop = asyncio.get_running_loop() - - def decompress_and_decode(): - decompressed = decompress_response(raw_body, encoding) - return common.format_error(decompressed.decode(errors="backslashreplace")).strip() - - body = await loop.run_in_executor(None, decompress_and_decode) - else: - loop = asyncio.get_running_loop() - body = await loop.run_in_executor( - None, - lambda: common.format_error(raw_body.decode(errors="backslashreplace")).strip() - ) - except Exception: - logger.warning("Failed to read error response body", exc_info=True) - - if self.show_clickhouse_errors: - err_code = response.headers.get(ex_header) - if err_code: - err_str = f"Received ClickHouse exception, code: {err_code}" - else: - err_str = f"HTTP driver received HTTP status {response.status}" - - if body: - err_str = f"{err_str}, server response: {body}" - else: - err_str = "The ClickHouse server returned an error" - - err_str = f"{err_str} (for url {self.url})" - - finally: - response.close() - - raise OperationalError(err_str) if retried else DatabaseError(err_str) from None - - async def _raw_request( - self, - data, - params, - headers=None, - files=None, - method="POST", - stream=False, - server_wait=True, - retries: int = 0, - ) -> aiohttp.ClientResponse: - if self._session is None: - raise ProgrammingError( - "Session not initialized. Use 'async with get_async_client(...)' or call 'await client._initialize()' first." 
- ) - - reset_seconds = common.get_setting("max_connection_age") - if reset_seconds: - now = time.time() - if self._last_pool_reset is None: - self._last_pool_reset = now - elif self._last_pool_reset < now - reset_seconds: - logger.debug("connection expiration - resetting connection pool") - await self.close_connections() - self._last_pool_reset = now - - final_params = dict_copy(self._client_settings, params) - if server_wait: - final_params.setdefault("wait_end_of_query", "1") - if self._send_progress: - final_params.setdefault("send_progress_in_http_headers", "1") - if self._progress_interval: - final_params.setdefault("http_headers_progress_interval_ms", self._progress_interval) - if self._autogenerate_query_id and "query_id" not in final_params: - final_params["query_id"] = str(uuid.uuid4()) - - req_headers = dict_copy(self.headers, headers) - if self.server_host_name: - req_headers["Host"] = self.server_host_name - query_session = final_params.get("session_id") - attempts = 0 - - # pylint: disable=too-many-nested-blocks - while True: - attempts += 1 - - if query_session: - if query_session == self._active_session: - raise ProgrammingError( - "Attempt to execute concurrent queries within the same session. " - "Please use a separate client instance per concurrent query." 
- ) - self._active_session = query_session - - try: - # Construct full URL (aiohttp doesn't have base_url) - url = f"{self.url}/" - request_kwargs = {"method": method, "url": url, "params": final_params, "headers": req_headers} - if hasattr(self, "_proxy_url") and self._proxy_url: - request_kwargs["proxy"] = self._proxy_url - if files: - # IMPORTANT: Must set content_type on text fields to force multipart/form-data encoding - # Without content_type, aiohttp uses application/x-www-form-urlencoded - form = aiohttp.FormData() - for field_name, field_value in files.items(): - if isinstance(field_value, tuple): - if field_value[0] is None: - form.add_field(field_name, str(field_value[1]), content_type='text/plain') - else: - filename = field_value[0] - file_data = field_value[1] - content_type = field_value[2] if len(field_value) > 2 else None - form.add_field(field_name, file_data, filename=filename, content_type=content_type) - else: - form.add_field(field_name, field_value, content_type='text/plain') - request_kwargs["data"] = form - elif isinstance(data, dict): - request_kwargs["data"] = data - else: - request_kwargs["data"] = data - - response = await self._session.request(**request_kwargs) - if 200 <= response.status < 300 and not response.headers.get(ex_header): - return response - - if response.status in (429, 503, 504): - if attempts > retries: - await self._error_handler(response, retried=True) - else: - logger.debug("Retrying request with status code %s (attempt %s/%s)", response.status, attempts, retries + 1) - await asyncio.sleep(0.1 * attempts) - response.close() - continue - await self._error_handler(response) - - except aiohttp.ServerConnectionError as e: - if "Connection reset" in str(e) or "Remote end closed" in str(e) or "Cannot connect" in str(e): - if attempts == 1: - logger.debug("Retrying after connection error from remote host") - continue - raise OperationalError(f"Network Error: {str(e)}") from e - - except (aiohttp.ClientError, 
asyncio.TimeoutError) as e: - raise OperationalError(f"Network Error: {str(e)}") from e - - finally: - if query_session: - self._active_session = None - - @staticmethod - def _summary(response: aiohttp.ClientResponse): - summary = {} - if "X-ClickHouse-Summary" in response.headers: - try: - summary = json.loads(response.headers["X-ClickHouse-Summary"]) - except json.JSONDecodeError: - pass - summary["query_id"] = response.headers.get("X-ClickHouse-Query-Id", "") - return summary diff --git a/clickhouse_connect/driver/asyncclient.py b/clickhouse_connect/driver/asyncclient.py index 7706bd8a..5fc11464 100644 --- a/clickhouse_connect/driver/asyncclient.py +++ b/clickhouse_connect/driver/asyncclient.py @@ -1,382 +1,1210 @@ +# pylint: disable=too-many-lines,duplicate-code,import-error + import asyncio +import gzip import io +import json import logging -import os -import warnings - -from concurrent.futures.thread import ThreadPoolExecutor +import re +import ssl +import sys +import time +import uuid +import pytz +import zlib +from base64 import b64encode from datetime import tzinfo -from typing import Optional, Union, Dict, Any, Sequence, Iterable, Generator, BinaryIO - -try: - from clickhouse_connect.driver.aiohttp_client import AiohttpAsyncClient - AIOHTTP_AVAILABLE = True -except ImportError: - AiohttpAsyncClient = None - AIOHTTP_AVAILABLE = False - -from clickhouse_connect.driver.client import Client -from clickhouse_connect.driver.common import StreamContext -from clickhouse_connect.driver.httpclient import HttpClient -from clickhouse_connect.driver.external import ExternalData -from clickhouse_connect.driver.query import QueryContext, QueryResult -from clickhouse_connect.driver.summary import QuerySummary +from importlib import import_module +from importlib.metadata import version as dist_version +from typing import Any, BinaryIO, Dict, Generator, Iterable, List, Literal, Optional, Sequence, Union + +import aiohttp +import lz4.frame +import zstandard + +from 
def decompress_response(data: bytes, encoding: Optional[str]) -> bytes:
    """Decompress a fully buffered HTTP response body.

    :param data: Raw body bytes exactly as received from the server.
    :param encoding: Value of the ``Content-Encoding`` response header.  ``None``,
      an empty string, or ``"identity"`` returns ``data`` unchanged.
    :return: The decompressed payload.
    :raises OperationalError: If the encoding is not one of lz4, zstd, br, gzip or
      deflate, or if it is ``"br"`` and the optional brotli package is missing.
    """

    if not encoding or encoding == "identity":
        return data

    if encoding == "lz4":
        # max_length limits the size of the *output*, so bounding it by
        # len(data) truncates every payload that expands when decompressed.
        # Leave it unbounded.  A frame decompressor stops at the end of its
        # frame and parks any following bytes in unused_data, so each frame
        # gets a fresh decompressor.
        chunks = []
        remaining = data
        while remaining:
            lz4_decom = lz4.frame.LZ4FrameDecompressor()
            chunks.append(lz4_decom.decompress(remaining))
            if not lz4_decom.eof:
                # Incomplete trailing frame: return what could be decoded.
                break
            remaining = lz4_decom.unused_data
        return b"".join(chunks)
    if encoding == "zstd":
        zstd_decom = zstandard.ZstdDecompressor()
        # Without read_across_frames the reader stops after the first frame
        # and drops the rest of a multi-frame body.
        return zstd_decom.stream_reader(io.BytesIO(data), read_across_frames=True).read()
    if encoding == "br":
        if brotli is not None:
            return brotli.decompress(data)
        raise OperationalError("Brotli compression requested but not installed.")
    if encoding == "gzip":
        return gzip.decompress(data)
    if encoding == "deflate":
        return zlib.decompress(data)
    raise OperationalError(f"Unsupported compression type: '{encoding}'. Supported compression: {', '.join(available_compression)}")
connect_timeout: int = 10, + send_receive_timeout: int = 300, + client_name: Optional[str] = None, + verify: Union[bool, str] = True, + ca_cert: Optional[str] = None, + client_cert: Optional[str] = None, + client_cert_key: Optional[str] = None, + http_proxy: Optional[str] = None, + https_proxy: Optional[str] = None, + server_host_name: Optional[str] = None, + tls_mode: Optional[str] = None, + proxy_path: str = "", + connector_limit: int = 100, + connector_limit_per_host: int = 20, + keepalive_timeout: float = 30.0, + session_id: Optional[str] = None, + settings: Optional[Dict[str, Any]] = None, + query_limit: int = 0, + query_retries: int = 2, + tz_source: Optional[TzSource] = None, + tz_mode: Optional[TzMode] = None, + apply_server_timezone: Optional[Union[str, bool]] = None, + utc_tz_aware: Optional[Union[bool, Literal["schema"]]] = None, + show_clickhouse_errors: Optional[bool] = None, + autogenerate_session_id: Optional[bool] = None, + autogenerate_query_id: Optional[bool] = None, + form_encode_query_params: bool = False, + **kwargs, + ): + """ + Async HTTP Client using aiohttp. Initialization is handled via _initialize(). 
+ """ + proxy_path = proxy_path.lstrip("/") + if proxy_path: + proxy_path = "/" + proxy_path + self.uri = f"{interface}://{host}:{port}{proxy_path}" + self.url = self.uri + self.form_encode_query_params = form_encode_query_params + self._rename_response_column = kwargs.get("rename_response_column") + self._initial_settings = settings + self.headers = {} + + if interface == "https": + if isinstance(verify, str) and verify.lower() == "proxy": + verify = True + tls_mode = tls_mode or "proxy" + + # Priority: access_token > mutual TLS > basic auth + if client_cert and (tls_mode is None or tls_mode == "mutual"): + if not username: + raise ProgrammingError("username parameter is required for Mutual TLS authentication") + self.headers["X-ClickHouse-User"] = username + self.headers["X-ClickHouse-SSL-Certificate-Auth"] = "on" + elif access_token: + self.headers["Authorization"] = f"Bearer {access_token}" + elif username and (not client_cert or tls_mode in ("strict", "proxy")): + credentials = b64encode(f"{username}:{password}".encode()).decode() + self.headers["Authorization"] = f"Basic {credentials}" + + self.headers["User-Agent"] = common.build_client_name(client_name) + # Prevent aiohttp from automatically requesting compressed responses + # We'll manually set Accept-Encoding when compression is desired + self.headers["Accept-Encoding"] = "identity" + self._send_receive_timeout = send_receive_timeout + + connect_timeout_val = float(connect_timeout) if connect_timeout is not None else None + send_receive_timeout_val = float(send_receive_timeout) if send_receive_timeout is not None else None + + self._timeout = aiohttp.ClientTimeout( + total=None, + connect=connect_timeout_val, + sock_connect=connect_timeout_val, + sock_read=send_receive_timeout_val, + ) + connector_limit_per_host = min(connector_limit_per_host, connector_limit) + + proxy_url = None + if http_proxy: + if not http_proxy.startswith("http://") and not http_proxy.startswith("https://"): + proxy_url = 
f"http://{http_proxy}" + else: + proxy_url = http_proxy + elif https_proxy: + if not https_proxy.startswith("http://") and not https_proxy.startswith("https://"): + proxy_url = f"http://{https_proxy}" + else: + proxy_url = https_proxy + else: + scheme = "https" if self.url.startswith("https://") else "http" + env_proxy = httputil.check_env_proxy(scheme, host, port) + if env_proxy: + if not env_proxy.startswith("http://") and not env_proxy.startswith("https://"): + proxy_url = f"http://{env_proxy}" + else: + proxy_url = env_proxy + + ssl_context = None + if interface == "https": + ssl_context = ssl.create_default_context() + ssl_verify = verify if isinstance(verify, bool) else coerce_bool(verify) + if not ssl_verify: + ssl_context.check_hostname = False + ssl_context.verify_mode = ssl.CERT_NONE + elif ca_cert: + ssl_context.load_verify_locations(ca_cert) + if client_cert: + ssl_context.load_cert_chain(client_cert, client_cert_key) + + self._ssl_context = ssl_context + self._proxy_url = proxy_url + self._connector_kwargs = { + "limit": connector_limit, + "limit_per_host": connector_limit_per_host, + "keepalive_timeout": keepalive_timeout, + "force_close": False, + "ssl": ssl_context, + } + # enable_cleanup_closed is only needed for Python < 3.12.7 or == 3.13.0 + # The underlying SSL connection leak was fixed in 3.12.7 and 3.13.1+ + # https://github.com/python/cpython/pull/118960 + if sys.version_info < (3, 12, 7) or sys.version_info[:3] == (3, 13, 0): + self._connector_kwargs["enable_cleanup_closed"] = True + + self._session = None + self._read_format = "Native" + self._write_format = "Native" + self._transform = NativeTransform() + self._client_settings = {} + self._initialized = False + self._reported_libs = set() + self._last_pool_reset = None + self.headers["User-Agent"] = self.headers["User-Agent"].replace("mode:sync;", "mode:async;") + + # Store aiohttp-specific params for deferred initialization + self._compress_param = compress + self._session_id_param = 
# pylint: disable=attribute-defined-outside-init
async def _initialize(self):
    """
    Async equivalent of Client._init_common_settings.

    Creates the aiohttp session when there is none, then -- unless the client is
    already marked initialized -- queries the server for its version, timezone and
    settings, and derives the client-side transport settings from them
    (compression, session id, progress headers, date/JSON input settings).

    If any step fails, the session is closed and dropped before the exception is
    re-raised, so a later call starts from a clean state.
    """
    if not self._session:
        connector = aiohttp.TCPConnector(**self._connector_kwargs)
        self._session = aiohttp.ClientSession(
            connector=connector,
            timeout=self._timeout,
            headers=self.headers,
            trust_env=False,
            auto_decompress=False,
            skip_auto_headers={"Accept-Encoding"},
        )
    if self._initialized:
        return

    try:
        tz_source = self._deferred_tz_source

        # Start from UTC so the bootstrap queries below have a timezone to work with.
        self.server_tz, self._dst_safe = pytz.UTC, True
        row = await self.command("SELECT version(), timezone()", use_database=False)
        self.server_version, server_tz_str = tuple(row)
        try:
            server_tz = pytz.timezone(server_tz_str)
            server_tz, self._dst_safe = tzutil.normalize_timezone(server_tz)
            if tz_source == "auto":
                self._apply_server_tz = self._dst_safe
            else:
                self._apply_server_tz = tz_source == "server"
            self.server_tz = server_tz
        except pytz.exceptions.UnknownTimeZoneError:
            logger.warning("Warning, server is using an unrecognized timezone %s, will use UTC default", server_tz_str)

        if not self._apply_server_tz and not tzutil.local_tz_dst_safe:
            logger.warning("local timezone %s may return unexpected times due to Daylight Savings Time", tzutil.local_tz.tzname(None))

        # Servers older than 19.17 get the configured substitute expression
        # in place of the `readonly` column.
        readonly = "readonly"
        if not self.min_version("19.17"):
            readonly = common.get_setting("readonly")

        server_settings = await self.query(f"SELECT name, value, {readonly} as readonly FROM system.settings LIMIT 10000")
        self.server_settings = {row["name"]: SettingDef(**row) for row in server_settings.named_results()}

        if self.min_version(CH_VERSION_WITH_PROTOCOL) and common.get_setting("use_protocol_version"):
            try:
                test_data = await self.raw_query(
                    "SELECT 1 AS check", fmt="Native", settings={"client_protocol_version": PROTOCOL_VERSION_WITH_LOW_CARD}
                )
                if test_data[8:16] == b"\x01\x01\x05check":
                    self.protocol_version = PROTOCOL_VERSION_WITH_LOW_CARD
            except Exception:
                # Best effort: the client works without the newer protocol
                # version, but leave a trace instead of failing silently.
                logger.debug("Protocol version probe failed, continuing without client_protocol_version", exc_info=True)

        cancel_setting = self._setting_status("cancel_http_readonly_queries_on_client_close")
        if cancel_setting.is_writable and not cancel_setting.is_set and \
                "cancel_http_readonly_queries_on_client_close" not in (self._initial_settings or {}):
            self._client_settings["cancel_http_readonly_queries_on_client_close"] = "1"

        if self._initial_settings:
            for key, value in self._initial_settings.items():
                self.set_client_setting(key, value)

        compress = self._compress_param
        if coerce_bool(compress):
            compression = ",".join(available_compression)
            self.write_compression = available_compression[0]
        elif compress and compress not in ("False", "false", "0"):
            if compress not in available_compression:
                raise ProgrammingError(f"Unsupported compression method {compress}")
            compression = compress
            self.write_compression = compress
        else:
            compression = None

        comp_setting = self._setting_status("enable_http_compression")
        self._send_comp_setting = not comp_setting.is_set and comp_setting.is_writable  # pylint: disable=attribute-defined-outside-init
        if comp_setting.is_set or comp_setting.is_writable:
            self.compression = compression

        session_id = self._session_id_param
        autogenerate_session_id = self._autogenerate_session_id_param

        if autogenerate_session_id is None:
            autogenerate_session_id = common.get_setting("autogenerate_session_id")

        # Precedence: explicit session_id argument, then a session_id that
        # arrived through the initial settings, then an autogenerated one.
        if session_id:
            self.set_client_setting("session_id", session_id)
        elif self.get_client_setting("session_id"):
            pass
        elif autogenerate_session_id:
            self.set_client_setting("session_id", str(uuid.uuid4()))

        send_setting = self._setting_status("send_progress_in_http_headers")
        self._send_progress = not send_setting.is_set and send_setting.is_writable
        if (send_setting.is_set or send_setting.is_writable) and self._setting_status("http_headers_progress_interval_ms").is_writable:
            timeout = self._send_receive_timeout
            if timeout is None:
                # The constructor accepts send_receive_timeout=None (no read
                # timeout); fall back to the upper bound instead of raising
                # TypeError on the arithmetic below.
                interval_ms = 120000
            else:
                # int() keeps the setting an integer string for float timeouts.
                interval_ms = min(120000, max(10000, int((timeout - 5) * 1000)))
            self._progress_interval = str(interval_ms)

        if self._setting_status("date_time_input_format").is_writable:
            self.set_client_setting("date_time_input_format", "best_effort")
        if (
            self._setting_status("allow_experimental_json_type").is_set
            and self._setting_status("cast_string_to_dynamic_use_inference").is_writable
        ):
            self.set_client_setting("cast_string_to_dynamic_use_inference", "1")
        if self.min_version("24.8") and not self.min_version("24.10"):
            dynamic_module.json_serialization_format = 0

        self._initialized = True
    except Exception:
        # Do not leave a half-initialized client holding an open session.
        if self._session and not self._session.closed:
            await self._session.close()
        self._session = None
        raise
async def _query_with_context(self, context: QueryContext) -> QueryResult:  # type: ignore[override]
    """
    Execute the query described by ``context`` and return its result.

    Two paths exist:

    * Queries whose uncommented text ends in ``LIMIT 0`` (and are not inserts) are
      sent with ``FORMAT JSON``.  The whole body is read, decompressed and parsed
      off the event loop, and only the ``meta`` section is used to build an empty
      QueryResult carrying the column names and types.
    * All other queries are sent in the client's read format.  The response is
      wrapped in a StreamingResponseSource and parsed in the default executor.

    :param context: Query context holding the SQL, bound parameters, settings,
      optional external data and transport settings.
    :return: The parsed query result, with ``summary`` and ``source`` populated on
      the streaming path.
    """
    headers = {}
    params = {}
    if self.database:
        params["database"] = self.database
    if self.protocol_version:
        params["client_protocol_version"] = self.protocol_version
        context.block_info = True
    params.update(self._validate_settings(context.settings))
    context.rename_response_column = self._rename_response_column

    # Column-metadata-only path ("... LIMIT 0").
    if not context.is_insert and columns_only_re.search(context.uncommented_query):
        fmt_json_query = f"{context.final_query}\n FORMAT JSON"
        fields = {"query": fmt_json_query}
        fields.update(context.bind_params)

        if self.form_encode_query_params:
            # Query text and bind parameters travel as multipart form fields.
            files = {}
            if context.external_data:
                params.update(context.external_data.query_params)
                files.update(context.external_data.form_data)

            for k, v in fields.items():
                files[k] = (None, str(v))
            response = await self._raw_request(None, params, headers, files=files, retries=self.query_retries)
        elif context.external_data:
            # External data occupies the body, so the query goes in the URL parameters.
            params.update(context.bind_params)
            params.update(context.external_data.query_params)
            params["query"] = fmt_json_query
            response = await self._raw_request(None, params, headers, files=context.external_data.form_data, retries=self.query_retries)
        else:
            params.update(context.bind_params)
            response = await self._raw_request(fmt_json_query, params, headers, retries=self.query_retries)

        body = await response.read()
        encoding = response.headers.get("Content-Encoding")
        loop = asyncio.get_running_loop()

        def decompress_and_parse_json():
            if encoding:
                decompressed_body = decompress_response(body, encoding)
            else:
                decompressed_body = body
            return json.loads(decompressed_body)

        # Offload to executor
        json_result = await loop.run_in_executor(None, decompress_and_parse_json)

        names: List[str] = []
        types: List[ClickHouseType] = []
        renamer = context.column_renamer
        for col in json_result["meta"]:
            name = col["name"]
            if renamer is not None:
                # A failing renamer is not fatal: the original column name is kept.
                try:
                    name = renamer(name)
                except Exception as e:
                    logger.debug("Failed to rename col '%s'. Skipping rename. Error: %s", name, e)
            names.append(name)
            types.append(get_from_name(col["type"]))
        return QueryResult([], None, tuple(names), tuple(types))

    # Regular path: ask for a compressed response when compression is configured.
    if self.compression:
        headers["Accept-Encoding"] = self.compression
        if self._send_comp_setting:
            params["enable_http_compression"] = "1"

    final_query = self._prep_query(context)

    files = None
    data = None

    # Same three request shapes as the LIMIT 0 path above.
    if self.form_encode_query_params:
        fields = {"query": final_query}
        fields.update(context.bind_params)

        files = {}
        if context.external_data:
            params.update(context.external_data.query_params)
            files.update(context.external_data.form_data)

        for k, v in fields.items():
            files[k] = (None, str(v))
    elif context.external_data:
        params.update(context.bind_params)
        params.update(context.external_data.query_params)
        params["query"] = final_query
        files = context.external_data.form_data
    else:
        params.update(context.bind_params)
        data = final_query
        headers["Content-Type"] = "text/plain; charset=utf-8"

    # Per-query transport settings are merged into the request headers.
    headers = dict_copy(headers, context.transport_settings)

    response = await self._raw_request(data, params, headers, files=files,
                                       server_wait=not context.streaming,
                                       stream=True, retries=self.query_retries)
    encoding = response.headers.get("Content-Encoding")
    tz_header = response.headers.get("X-ClickHouse-Timezone")
    exception_tag = response.headers.get(ex_tag_header)

    loop = asyncio.get_running_loop()
    streaming_source = StreamingResponseSource(response, encoding=encoding, exception_tag=exception_tag)
    await streaming_source.start_producer(loop)

    def parse_streaming():
        """Parse response from streaming queue (runs in executor)."""
        # Wrap streaming source with ResponseBuffer. The streaming source provides a
        # .gen property that yields decompressed chunks.
        byte_source = RespBuffCls(streaming_source)
        context.set_response_tz(self._check_tz_change(tz_header))
        result = self._transform.parse_response(byte_source, context)

        # For Pandas/Numpy, we must materialize in the executor because the resulting objects
        # (DataFrame, Array) are fully in-memory structures.
        # For standard queries, we return a lazy QueryResult. Accessing .result_set on the event loop
        # will raise a ProgrammingError (deadlock check), encouraging usage of .rows_stream.
        # Non-streaming results are therefore touched here, inside the executor thread.
        if not context.streaming:
            if context.as_pandas and hasattr(result, 'df_result'):
                _ = result.df_result
            elif context.use_numpy and hasattr(result, 'np_result'):
                _ = result.np_result
            elif isinstance(result, QueryResult):
                _ = result.result_set

        return result

    # Run parser in executor (pulls from queue, decompresses & parses)
    try:
        query_result = await loop.run_in_executor(None, parse_streaming)
    except Exception:
        # Release the streaming source before propagating the parse failure.
        await streaming_source.aclose()
        raise
    query_result.summary = self._summary(response)

    # Attach streaming_source to query_result.source to ensure it gets closed
    # when the query result is closed (e.g. by StreamContext.__exit__)
    query_result.source = streaming_source

    return query_result
use executor-based wrapper - warnings.warn( - "Passing 'client=' to AsyncClient is deprecated. " - "Use create_async_client(host=..., port=...) instead. " - "Legacy executor-based mode may be removed in the future.", - DeprecationWarning, - stacklevel=2 + if query and query.lower().strip().startswith("select __connect_version__"): + return QueryResult( + [[f"ClickHouse Connect v.{common.version()} ⓒ ClickHouse Inc."]], None, ("connect_version",), (get_from_name("String"),) + ) + if not context: + context = self.create_query_context( + query=query, + parameters=parameters, + settings=settings, + query_formats=query_formats, + column_formats=column_formats, + encoding=encoding, + use_none=use_none, + column_oriented=column_oriented, + use_numpy=use_numpy, + max_str_len=max_str_len, + query_tz=query_tz, + column_tzs=column_tzs, + utc_tz_aware=utc_tz_aware, + external_data=external_data, + transport_settings=transport_settings, + tz_mode=tz_mode, ) - self._impl = _LegacyAsyncWrapper(client, executor_threads, executor) - else: - # NATIVE PATH: Create aiohttp client - if not AIOHTTP_AVAILABLE: - raise ImportError( - "Native async support requires aiohttp. " - "Install with: pip install clickhouse-connect[async]\n" - "Alternatively, use the legacy executor-based async by passing a sync client to AsyncClient." 
- ) - self._impl = AiohttpAsyncClient(**kwargs) - # Proxy all methods to implementation - # pylint: disable=protected-access - async def _initialize(self): - if hasattr(self._impl, '_initialize'): - await self._impl._initialize() + if context.is_command: + response = await self.command( + query, + parameters=context.parameters, + settings=context.settings, + external_data=context.external_data, + transport_settings=context.transport_settings, + ) + if isinstance(response, QuerySummary): + return response.as_query_result() + return QueryResult([response] if isinstance(response, list) else [[response]]) - def set_client_setting(self, key, value): + return await self._query_with_context(context) + + async def query_column_block_stream( # type: ignore[override] + self, + query: Optional[str] = None, + parameters: Optional[Union[Sequence, Dict[str, Any]]] = None, + settings: Optional[Dict[str, Any]] = None, + query_formats: Optional[Dict[str, str]] = None, + column_formats: Optional[Dict[str, Union[str, Dict[str, str]]]] = None, + encoding: Optional[str] = None, + use_none: Optional[bool] = None, + context: Optional[QueryContext] = None, + query_tz: Optional[Union[str, tzinfo]] = None, + column_tzs: Optional[Dict[str, Union[str, tzinfo]]] = None, + utc_tz_aware: Optional[bool] = None, + external_data: Optional[ExternalData] = None, + transport_settings: Optional[Dict[str, str]] = None, + tz_mode: Optional[TzMode] = None, + ) -> StreamContext: """ - Set a clickhouse setting for the client after initialization. If a setting is not recognized by ClickHouse, - or the setting is identified as "read_only", this call will either throw a Programming exception or attempt - to send the setting anyway based on the common setting 'invalid_setting_action'. - :param key: ClickHouse setting name - :param value: ClickHouse setting value + Async version of query_column_block_stream. + Returns a StreamContext that yields column-oriented blocks. 
""" - self._impl.set_client_setting(key=key, value=value) + return (await self._context_query(locals(), use_numpy=False, streaming=True)).column_block_stream - def get_client_setting(self, key) -> Optional[str]: + async def query_row_block_stream( # type: ignore[override] + self, + query: Optional[str] = None, + parameters: Optional[Union[Sequence, Dict[str, Any]]] = None, + settings: Optional[Dict[str, Any]] = None, + query_formats: Optional[Dict[str, str]] = None, + column_formats: Optional[Dict[str, Union[str, Dict[str, str]]]] = None, + encoding: Optional[str] = None, + use_none: Optional[bool] = None, + context: Optional[QueryContext] = None, + query_tz: Optional[Union[str, tzinfo]] = None, + column_tzs: Optional[Dict[str, Union[str, tzinfo]]] = None, + utc_tz_aware: Optional[bool] = None, + external_data: Optional[ExternalData] = None, + transport_settings: Optional[Dict[str, str]] = None, + tz_mode: Optional[TzMode] = None, + ) -> StreamContext: """ - :param key: The setting key - :return: The string value of the setting, if it exists, or None + Async version of query_row_block_stream. + Returns a StreamContext that yields row-oriented blocks. 
""" - return self._impl.get_client_setting(key=key) + return (await self._context_query(locals(), use_numpy=False, streaming=True)).row_block_stream - def set_access_token(self, access_token: str): + async def query_rows_stream( # type: ignore[override] + self, + query: Optional[str] = None, + parameters: Optional[Union[Sequence, Dict[str, Any]]] = None, + settings: Optional[Dict[str, Any]] = None, + query_formats: Optional[Dict[str, str]] = None, + column_formats: Optional[Dict[str, Union[str, Dict[str, str]]]] = None, + encoding: Optional[str] = None, + use_none: Optional[bool] = None, + context: Optional[QueryContext] = None, + query_tz: Optional[Union[str, tzinfo]] = None, + column_tzs: Optional[Dict[str, Union[str, tzinfo]]] = None, + utc_tz_aware: Optional[bool] = None, + external_data: Optional[ExternalData] = None, + transport_settings: Optional[Dict[str, str]] = None, + tz_mode: Optional[TzMode] = None, + ) -> StreamContext: + """ + Async version of query_rows_stream. + Returns a StreamContext that yields individual rows. 
+ """ + return (await self._context_query(locals(), use_numpy=False, streaming=True)).rows_stream + + # pylint: disable=unused-argument + async def query_np( + self, + query: Optional[str] = None, + parameters: Optional[Union[Sequence, Dict[str, Any]]] = None, + settings: Optional[Dict[str, Any]] = None, + query_formats: Optional[Dict[str, str]] = None, + column_formats: Optional[Dict[str, str]] = None, + encoding: Optional[str] = None, + use_none: Optional[bool] = None, + max_str_len: Optional[int] = None, + context: Optional[QueryContext] = None, + external_data: Optional[ExternalData] = None, + transport_settings: Optional[Dict[str, str]] = None, + ): + check_numpy() + self._add_integration_tag("numpy") + return (await self._context_query(locals(), use_numpy=True)).np_result + + async def query_np_stream( # type: ignore[override] + self, + query: Optional[str] = None, + parameters: Optional[Union[Sequence, Dict[str, Any]]] = None, + settings: Optional[Dict[str, Any]] = None, + query_formats: Optional[Dict[str, str]] = None, + column_formats: Optional[Dict[str, str]] = None, + encoding: Optional[str] = None, + use_none: Optional[bool] = None, + max_str_len: Optional[int] = None, + context: Optional[QueryContext] = None, + external_data: Optional[ExternalData] = None, + transport_settings: Optional[Dict[str, str]] = None, + ) -> StreamContext: + check_numpy() + self._add_integration_tag("numpy") + return (await self._context_query(locals(), use_numpy=True, streaming=True)).np_stream + + async def query_df( + self, + query: Optional[str] = None, + parameters: Optional[Union[Sequence, Dict[str, Any]]] = None, + settings: Optional[Dict[str, Any]] = None, + query_formats: Optional[Dict[str, str]] = None, + column_formats: Optional[Dict[str, str]] = None, + encoding: Optional[str] = None, + use_none: Optional[bool] = None, + max_str_len: Optional[int] = None, + use_na_values: Optional[bool] = None, + query_tz: Optional[str] = None, + column_tzs: Optional[Dict[str, 
Union[str, tzinfo]]] = None, + utc_tz_aware: Optional[bool] = None, + context: Optional[QueryContext] = None, + external_data: Optional[ExternalData] = None, + use_extended_dtypes: Optional[bool] = None, + transport_settings: Optional[Dict[str, str]] = None, + tz_mode: Optional[TzMode] = None, + ): + check_pandas() + self._add_integration_tag("pandas") + return (await self._context_query(locals(), use_numpy=True, as_pandas=True)).df_result + + async def query_df_stream( # type: ignore[override] + self, + query: Optional[str] = None, + parameters: Optional[Union[Sequence, Dict[str, Any]]] = None, + settings: Optional[Dict[str, Any]] = None, + query_formats: Optional[Dict[str, str]] = None, + column_formats: Optional[Dict[str, str]] = None, + encoding: Optional[str] = None, + use_none: Optional[bool] = None, + max_str_len: Optional[int] = None, + use_na_values: Optional[bool] = None, + query_tz: Optional[str] = None, + column_tzs: Optional[Dict[str, Union[str, tzinfo]]] = None, + utc_tz_aware: Optional[bool] = None, + context: Optional[QueryContext] = None, + external_data: Optional[ExternalData] = None, + use_extended_dtypes: Optional[bool] = None, + transport_settings: Optional[Dict[str, str]] = None, + tz_mode: Optional[TzMode] = None, + ) -> StreamContext: + check_pandas() + self._add_integration_tag("pandas") + return (await self._context_query(locals(), use_numpy=True, as_pandas=True, streaming=True)).df_stream + + async def _context_query(self, lcls: dict, **overrides): # type: ignore[override] """ - Set the ClickHouse access token for the client - :param access_token: Access token string + Helper method to create query context and execute query. + Matches sync client pattern for consistency. 
""" - return self._impl.set_access_token(access_token) + kwargs = lcls.copy() + kwargs.pop("self") + kwargs.update(overrides) + return await self._query_with_context(self.create_query_context(**kwargs)) - def min_version(self, version_str: str) -> bool: + async def command( # type: ignore[override] + self, + cmd, + parameters: Optional[Union[Sequence, Dict[str, Any]]] = None, + data: Optional[Union[str, bytes]] = None, + settings: Optional[Dict] = None, + use_database: bool = True, + external_data: Optional[ExternalData] = None, + transport_settings: Optional[Dict[str, str]] = None, + ) -> Union[str, int, Sequence[str], QuerySummary]: """ - Determine whether the connected server is at least the submitted version - For Altinity Stable versions like 22.8.15.25.altinitystable - the last condition in the first list comprehension expression is added - :param version_str: A version string consisting of up to 4 integers delimited by dots - :return: True if version_str is greater than the server_version, False if less than + See BaseClient doc_string for this method """ - return self._impl.min_version(version_str) + cmd, bind_params = bind_query(cmd, parameters, self.server_tz) + params = bind_params.copy() + headers = {} + payload = None + files = None + + if external_data: + if data: + raise ProgrammingError("Cannot combine command data with external data") from None + files = external_data.form_data + params.update(external_data.query_params) + elif isinstance(data, str): + headers["Content-Type"] = "text/plain; charset=utf-8" + payload = data.encode() + elif isinstance(data, bytes): + headers["Content-Type"] = "application/octet-stream" + payload = data + + if payload is None and not cmd: + raise ProgrammingError("Command sent without query or recognized data") from None + + if payload or files: + params["query"] = cmd + else: + payload = cmd + + if use_database and self.database: + params["database"] = self.database + params.update(self._validate_settings(settings or 
{})) + headers = dict_copy(headers, transport_settings) + method = "POST" if payload or files else "GET" + response = await self._raw_request(payload, params, headers, files=files, method=method, server_wait=False) + body = await response.read() + encoding = response.headers.get("Content-Encoding") + summary = self._summary(response) + + if not body: + return QuerySummary(summary) + + loop = asyncio.get_running_loop() - async def close(self) -> None: + def decompress_and_decode(): + if encoding: + decompressed_body = decompress_response(body, encoding) + else: + decompressed_body = body + try: + result = decompressed_body.decode()[:-1].split("\t") + if len(result) == 1: + try: + return int(result[0]) + except ValueError: + return result[0] + return result + except UnicodeDecodeError: + return str(decompressed_body) + + return await loop.run_in_executor(None, decompress_and_decode) + + async def ping(self) -> bool: # type: ignore[override] + try: + url = f"{self.url}/ping" + timeout = aiohttp.ClientTimeout(total=3.0) + async with self._session.get(url, timeout=timeout) as response: + return 200 <= response.status < 300 + except (aiohttp.ClientError, asyncio.TimeoutError): + logger.debug("ping failed", exc_info=True) + return False + + async def raw_query( # type: ignore[override] + self, + query: str, + parameters: Optional[Union[Sequence, Dict[str, Any]]] = None, + settings: Optional[Dict[str, Any]] = None, + fmt: Optional[str] = None, + use_database: bool = True, + external_data: Optional[ExternalData] = None, + transport_settings: Optional[Dict[str, str]] = None, + ) -> bytes: + """ + See BaseClient doc_string for this method """ - Subclass implementation to close the connection to the server/deallocate the client + body, params, headers, files = self._prep_raw_query(query, parameters, settings, fmt, use_database, external_data) + if transport_settings: + headers = dict_copy(headers, transport_settings) + + response = await self._raw_request(body, params, 
headers=headers, files=files, retries=self.query_retries) + response_data = await response.read() + encoding = response.headers.get("Content-Encoding") + + if encoding: + loop = asyncio.get_running_loop() + response_data = await loop.run_in_executor(None, decompress_response, response_data, encoding) + + return response_data + + async def raw_stream( # type: ignore[override] + self, + query: str, + parameters: Optional[Union[Sequence, Dict[str, Any]]] = None, + settings: Optional[Dict[str, Any]] = None, + fmt: Optional[str] = None, + use_database: bool = True, + external_data: Optional[ExternalData] = None, + transport_settings: Optional[Dict[str, str]] = None, + ) -> StreamContext: + + body, params, headers, files = self._prep_raw_query(query, parameters, settings, fmt, use_database, external_data) + if transport_settings: + headers = dict_copy(headers, transport_settings) + + response = await self._raw_request( + body, params, headers=headers, files=files, stream=True, server_wait=False, retries=self.query_retries + ) + + async def byte_iterator(): + async for chunk in response.content.iter_any(): + yield chunk + + return StreamContext(response, byte_iterator()) + + def _prep_raw_query(self, query, parameters, settings, fmt, use_database, external_data): """ - return await self._impl.close() + Prepare raw query for execution. 
- async def query(self, - query: Optional[str] = None, - parameters: Optional[Union[Sequence, Dict[str, Any]]] = None, - settings: Optional[Dict[str, Any]] = None, - query_formats: Optional[Dict[str, str]] = None, - column_formats: Optional[Dict[str, Union[str, Dict[str, str]]]] = None, - encoding: Optional[str] = None, - use_none: Optional[bool] = None, - column_oriented: Optional[bool] = None, - use_numpy: Optional[bool] = None, - max_str_len: Optional[int] = None, - context: Optional[QueryContext] = None, - query_tz: Optional[Union[str, tzinfo]] = None, - column_tzs: Optional[Dict[str, Union[str, tzinfo]]] = None, - utc_tz_aware: Optional[bool] = None, - external_data: Optional[ExternalData] = None, - transport_settings: Optional[Dict[str, str]] = None, - tz_mode: Optional[str] = None) -> QueryResult: + Note: Unlike sync client which returns (body, params, fields), this async version + returns (body, params, headers, files) because aiohttp requires headers to be + configured before the request() call, while urllib3 can add them during request. """ - Main query method for SELECT, DESCRIBE and other SQL statements that return a result matrix. - For parameters, see the create_query_context method. 
- :return: QueryResult -- data and metadata from response + if fmt: + query += f"\n FORMAT {fmt}" + + final_query, bind_params = bind_query(query, parameters, self.server_tz) + params = self._validate_settings(settings or {}) + if use_database and self.database: + params["database"] = self.database + + headers = {} + files = None + body = None + + if external_data and not self.form_encode_query_params and isinstance(final_query, bytes): + raise ProgrammingError("Binary query cannot be placed in URL when using External Data; enable form encoding.") + + if self.form_encode_query_params: + files = {} + files["query"] = (None, final_query if isinstance(final_query, str) else final_query.decode()) + for k, v in bind_params.items(): + files[k] = (None, str(v)) + + if external_data: + params.update(external_data.query_params) + files.update(external_data.form_data) + + body = None + elif external_data: + params.update(bind_params) + params["query"] = final_query + params.update(external_data.query_params) + files = external_data.form_data + body = None + else: + params.update(bind_params) + body = final_query.encode() if isinstance(final_query, str) else final_query + + return body, params, headers, files + + async def insert( # type: ignore[override] + self, + table: Optional[str] = None, + data: Optional[Sequence[Sequence[Any]]] = None, + column_names: Union[str, Iterable[str]] = "*", + database: Optional[str] = None, + column_types: Optional[Sequence[ClickHouseType]] = None, + column_type_names: Optional[Sequence[str]] = None, + column_oriented: bool = False, + settings: Optional[Dict[str, Any]] = None, + context: Optional[InsertContext] = None, + transport_settings: Optional[Dict[str, str]] = None, + ) -> QuerySummary: """ - return await self._impl.query(query=query, parameters=parameters, settings=settings, - query_formats=query_formats, column_formats=column_formats, - encoding=encoding, use_none=use_none, column_oriented=column_oriented, - use_numpy=use_numpy, 
max_str_len=max_str_len, context=context, - query_tz=query_tz, column_tzs=column_tzs, utc_tz_aware=utc_tz_aware, - external_data=external_data, transport_settings=transport_settings, - tz_mode=tz_mode) - - async def query_column_block_stream(self, - query: Optional[str] = None, - parameters: Optional[Union[Sequence, Dict[str, Any]]] = None, - settings: Optional[Dict[str, Any]] = None, - query_formats: Optional[Dict[str, str]] = None, - column_formats: Optional[Dict[str, Union[str, Dict[str, str]]]] = None, - encoding: Optional[str] = None, - use_none: Optional[bool] = None, - context: Optional[QueryContext] = None, - query_tz: Optional[Union[str, tzinfo]] = None, - column_tzs: Optional[Dict[str, Union[str, tzinfo]]] = None, - utc_tz_aware: Optional[bool] = None, - external_data: Optional[ExternalData] = None, - transport_settings: Optional[Dict[str, str]] = None, - tz_mode: Optional[str] = None, - ) -> StreamContext: - """ - Variation of main query method that returns a stream of column oriented blocks. - For parameters, see the create_query_context method. 
- :return: StreamContext -- Iterable stream context that returns column oriented blocks - """ - return await self._impl.query_column_block_stream(query=query, parameters=parameters, settings=settings, - query_formats=query_formats, column_formats=column_formats, - encoding=encoding, use_none=use_none, context=context, - query_tz=query_tz, column_tzs=column_tzs, - utc_tz_aware=utc_tz_aware, - external_data=external_data, transport_settings=transport_settings, - tz_mode=tz_mode) - - async def query_row_block_stream(self, - query: Optional[str] = None, - parameters: Optional[Union[Sequence, Dict[str, Any]]] = None, - settings: Optional[Dict[str, Any]] = None, - query_formats: Optional[Dict[str, str]] = None, - column_formats: Optional[Dict[str, Union[str, Dict[str, str]]]] = None, - encoding: Optional[str] = None, - use_none: Optional[bool] = None, - context: Optional[QueryContext] = None, - query_tz: Optional[Union[str, tzinfo]] = None, - column_tzs: Optional[Dict[str, Union[str, tzinfo]]] = None, - utc_tz_aware: Optional[bool] = None, - external_data: Optional[ExternalData] = None, - transport_settings: Optional[Dict[str, str]] = None, - tz_mode: Optional[str] = None) -> StreamContext: - """ - Variation of main query method that returns a stream of row oriented blocks. - For parameters, see the create_query_context method. 
- :return: StreamContext -- Iterable stream context that returns blocks of rows - """ - return await self._impl.query_row_block_stream(query=query, parameters=parameters, settings=settings, - query_formats=query_formats, column_formats=column_formats, - encoding=encoding, use_none=use_none, context=context, - query_tz=query_tz, column_tzs=column_tzs, - utc_tz_aware=utc_tz_aware, - external_data=external_data, transport_settings=transport_settings, - tz_mode=tz_mode) - - async def query_rows_stream(self, - query: Optional[str] = None, - parameters: Optional[Union[Sequence, Dict[str, Any]]] = None, - settings: Optional[Dict[str, Any]] = None, - query_formats: Optional[Dict[str, str]] = None, - column_formats: Optional[Dict[str, Union[str, Dict[str, str]]]] = None, - encoding: Optional[str] = None, - use_none: Optional[bool] = None, - context: Optional[QueryContext] = None, - query_tz: Optional[Union[str, tzinfo]] = None, - column_tzs: Optional[Dict[str, Union[str, tzinfo]]] = None, - utc_tz_aware: Optional[bool] = None, - external_data: Optional[ExternalData] = None, - transport_settings: Optional[Dict[str, str]] = None, - tz_mode: Optional[str] = None) -> StreamContext: - """ - Variation of main query method that returns a stream of row oriented blocks. - For parameters, see the create_query_context method. 
- :return: StreamContext -- Iterable stream context that returns blocks of rows - """ - return await self._impl.query_rows_stream(query=query, parameters=parameters, settings=settings, - query_formats=query_formats, column_formats=column_formats, - encoding=encoding, use_none=use_none, context=context, - query_tz=query_tz, column_tzs=column_tzs, - utc_tz_aware=utc_tz_aware, - external_data=external_data, transport_settings=transport_settings, - tz_mode=tz_mode) - - async def raw_query(self, - query: str, - parameters: Optional[Union[Sequence, Dict[str, Any]]] = None, - settings: Optional[Dict[str, Any]] = None, - fmt: Optional[str] = None, - use_database: bool = True, - external_data: Optional[ExternalData] = None, - transport_settings: Optional[Dict[str, str]] = None) -> bytes: - """ - Query method that simply returns the raw ClickHouse format bytes. + Method to insert multiple rows/data matrix of native Python objects. If context is specified arguments + other than data are ignored + :param table: Target table + :param data: Sequence of sequences of Python data + :param column_names: Ordered list of column names or '*' if column types should be retrieved from the + ClickHouse table definition + :param database: Target database -- will use client default database if not specified. + :param column_types: ClickHouse column types. If set then column data does not need to be retrieved from + the server + :param column_type_names: ClickHouse column type names. If set then column data does not need to be + retrieved from the server + :param column_oriented: If true the data is already "pivoted" in column form + :param settings: Optional dictionary of ClickHouse settings (key/string values) + :param context: Optional reusable insert context to allow repeated inserts into the same table with + different data batches + :param transport_settings: Optional dictionary of transport level settings (HTTP headers, etc.) 
+ :return: QuerySummary with summary information, throws exception if insert fails + """ + if (context is None or context.empty) and data is None: + raise ProgrammingError("No data specified for insert") from None + if context is None: + context = await self.create_insert_context( + table, + column_names, + database, + column_types, + column_type_names, + column_oriented, + settings, + transport_settings=transport_settings, + ) + if data is not None: + if not context.empty: + raise ProgrammingError("Attempting to insert new data with non-empty insert context") from None + context.data = data + return await self.data_insert(context) + + async def query_arrow( + self, + query: str, + parameters: Optional[Union[Sequence, Dict[str, Any]]] = None, + settings: Optional[Dict[str, Any]] = None, + use_strings: Optional[bool] = None, + external_data: Optional[ExternalData] = None, + transport_settings: Optional[Dict[str, str]] = None, + ): + """ + Query method using the ClickHouse Arrow format to return a PyArrow table :param query: Query statement/format string :param parameters: Optional dictionary used to format the query :param settings: Optional dictionary of ClickHouse settings (key/string values) - :param fmt: ClickHouse output format - :param use_database Send the database parameter to ClickHouse so the command will be executed in the client - database context - :param external_data External data to send with the query + :param use_strings: Convert ClickHouse String type to Arrow string type (instead of binary) + :param external_data: ClickHouse "external data" to send with query :param transport_settings: Optional dictionary of transport level settings (HTTP headers, etc.) 
- :return: bytes representing raw ClickHouse return value based on format - """ - return await self._impl.raw_query(query=query, parameters=parameters, settings=settings, fmt=fmt, - use_database=use_database, external_data=external_data, - transport_settings=transport_settings) - - async def raw_stream(self, - query: str, - parameters: Optional[Union[Sequence, Dict[str, Any]]] = None, - settings: Optional[Dict[str, Any]] = None, - fmt: Optional[str] = None, - use_database: bool = True, - external_data: Optional[ExternalData] = None, - transport_settings: Optional[Dict[str, str]] = None) -> io.IOBase: - """ - Query method that returns the result as an io.IOBase iterator. + :return: PyArrow.Table + """ + check_arrow() + self._add_integration_tag("arrow") + settings = self._update_arrow_settings(settings, use_strings) + + body, params, headers, files = self._prep_raw_query( + query, parameters, settings, fmt="ArrowStream", + use_database=True, external_data=external_data + ) + if transport_settings: + headers = dict_copy(headers, transport_settings) + + response = await self._raw_request( + body, params, headers=headers, files=files, + stream=True, server_wait=False, retries=self.query_retries + ) + encoding = response.headers.get("Content-Encoding") + exception_tag = response.headers.get(ex_tag_header) + + loop = asyncio.get_running_loop() + streaming_source = StreamingResponseSource(response, encoding=encoding, exception_tag=exception_tag) + await streaming_source.start_producer(loop) + + def parse_arrow_stream(): + file_adapter = StreamingFileAdapter(streaming_source) + reader = options.arrow.ipc.open_stream(file_adapter) + table = reader.read_all() + return _apply_arrow_tz_policy(table, self.tz_mode) + + try: + return await loop.run_in_executor(None, parse_arrow_stream) + finally: + await streaming_source.aclose() + + async def query_arrow_stream( # type: ignore[override] + self, + query: str, + parameters: Optional[Union[Sequence, Dict[str, Any]]] = None, + 
settings: Optional[Dict[str, Any]] = None, + use_strings: Optional[bool] = None, + external_data: Optional[ExternalData] = None, + transport_settings: Optional[Dict[str, str]] = None, + ) -> StreamContext: + """ + Query method that returns the results as a stream of Arrow record batches. + :param query: Query statement/format string :param parameters: Optional dictionary used to format the query :param settings: Optional dictionary of ClickHouse settings (key/string values) - :param fmt: ClickHouse output format - :param use_database Send the database parameter to ClickHouse so the command will be executed in the client - database context - :param external_data External data to send with the query + :param use_strings: Convert ClickHouse String type to Arrow string type (instead of binary) + :param external_data: ClickHouse "external data" to send with query :param transport_settings: Optional dictionary of transport level settings (HTTP headers, etc.) - :return: io.IOBase stream/iterator for the result - """ - return await self._impl.raw_stream(query=query, parameters=parameters, settings=settings, fmt=fmt, - use_database=use_database, external_data=external_data, - transport_settings=transport_settings) - - async def query_np(self, - query: Optional[str] = None, - parameters: Optional[Union[Sequence, Dict[str, Any]]] = None, - settings: Optional[Dict[str, Any]] = None, - query_formats: Optional[Dict[str, str]] = None, - column_formats: Optional[Dict[str, str]] = None, - encoding: Optional[str] = None, - use_none: Optional[bool] = None, - max_str_len: Optional[int] = None, - context: Optional[QueryContext] = None, - external_data: Optional[ExternalData] = None, - transport_settings: Optional[Dict[str, str]] = None): - """ - Query method that returns the results as a numpy array. - For parameter values, see the create_query_context method. 
- :return: Numpy array representing the result set - """ - return await self._impl.query_np(query=query, parameters=parameters, settings=settings, - query_formats=query_formats, column_formats=column_formats, - encoding=encoding, use_none=use_none, max_str_len=max_str_len, - context=context, external_data=external_data, - transport_settings=transport_settings) - - async def query_np_stream(self, - query: Optional[str] = None, - parameters: Optional[Union[Sequence, Dict[str, Any]]] = None, - settings: Optional[Dict[str, Any]] = None, - query_formats: Optional[Dict[str, str]] = None, - column_formats: Optional[Dict[str, str]] = None, - encoding: Optional[str] = None, - use_none: Optional[bool] = None, - max_str_len: Optional[int] = None, - context: Optional[QueryContext] = None, - external_data: Optional[ExternalData] = None, - transport_settings: Optional[Dict[str, str]] = None) -> StreamContext: - """ - Query method that returns the results as a stream of numpy arrays. - For parameter values, see the create_query_context method. 
- :return: Numpy array representing the result set - """ - return await self._impl.query_np_stream(query=query, parameters=parameters, settings=settings, - query_formats=query_formats, column_formats=column_formats, - encoding=encoding, use_none=use_none, max_str_len=max_str_len, - context=context, external_data=external_data, - transport_settings=transport_settings) - - async def query_df(self, - query: Optional[str] = None, - parameters: Optional[Union[Sequence, Dict[str, Any]]] = None, - settings: Optional[Dict[str, Any]] = None, - query_formats: Optional[Dict[str, str]] = None, - column_formats: Optional[Dict[str, str]] = None, - encoding: Optional[str] = None, - use_none: Optional[bool] = None, - max_str_len: Optional[int] = None, - use_na_values: Optional[bool] = None, - query_tz: Optional[str] = None, - column_tzs: Optional[Dict[str, Union[str, tzinfo]]] = None, - utc_tz_aware: Optional[bool] = None, - context: Optional[QueryContext] = None, - external_data: Optional[ExternalData] = None, - use_extended_dtypes: Optional[bool] = None, - transport_settings: Optional[Dict[str, str]] = None, - tz_mode: Optional[str] = None): - """ - Query method that results the results as a pandas dataframe. - For parameter values, see the create_query_context method. 
- :return: Pandas dataframe representing the result set - """ - return await self._impl.query_df(query=query, parameters=parameters, settings=settings, - query_formats=query_formats, column_formats=column_formats, - encoding=encoding, use_none=use_none, max_str_len=max_str_len, - use_na_values=use_na_values, query_tz=query_tz, column_tzs=column_tzs, - utc_tz_aware=utc_tz_aware, context=context, external_data=external_data, - use_extended_dtypes=use_extended_dtypes, transport_settings=transport_settings, - tz_mode=tz_mode) + :return: StreamContext that yields PyArrow RecordBatch objects asynchronously + """ + check_arrow() + self._add_integration_tag("arrow") + settings = self._update_arrow_settings(settings, use_strings) + + body, params, headers, files = self._prep_raw_query( + query, parameters, settings, fmt="ArrowStream", + use_database=True, external_data=external_data + ) + if transport_settings: + headers = dict_copy(headers, transport_settings) + + response = await self._raw_request( + body, params, headers=headers, files=files, + stream=True, server_wait=False, retries=self.query_retries + ) + encoding = response.headers.get("Content-Encoding") + exception_tag = response.headers.get(ex_tag_header) + + loop = asyncio.get_running_loop() + streaming_source = StreamingResponseSource(response, encoding=encoding, exception_tag=exception_tag) + await streaming_source.start_producer(loop) + + queue = AsyncSyncQueue(maxsize=10) + + class _ArrowStreamSource: + def __init__(self, source, q): + self._source = source + self._queue = q + + async def aclose(self): + self._queue.shutdown() + await self._source.aclose() + + def close(self): + self._queue.shutdown() + self._source.close() + + def parse_arrow_streaming(): + """Parse Arrow stream incrementally in executor (off event loop).""" + try: + file_adapter = StreamingFileAdapter(streaming_source) + reader = options.arrow.ipc.open_stream(file_adapter) + + for batch in reader: + try: + batch = 
_apply_arrow_tz_policy(batch, self.tz_mode) + queue.sync_q.put(batch) + except RuntimeError: + return + + try: + queue.sync_q.put(EOF_SENTINEL) + except RuntimeError: + return + except Exception as e: + try: + queue.sync_q.put(e) + except Exception: + pass + finally: + queue.shutdown() + + loop.run_in_executor(None, parse_arrow_streaming) + + async def arrow_batch_generator(): + """Async generator that yields record batches without blocking event loop.""" + while True: + item = await queue.async_q.get() + if item is EOF_SENTINEL: + break + if isinstance(item, Exception): + raise item + yield item + + return StreamContext(_ArrowStreamSource(streaming_source, queue), arrow_batch_generator()) async def query_df_arrow( - self, - query: str, - parameters: Optional[Union[Sequence, Dict[str, Any]]] = None, - settings: Optional[Dict[str, Any]] = None, - use_strings: Optional[bool] = None, - external_data: Optional[ExternalData] = None, - transport_settings: Optional[Dict[str, str]] = None, - dataframe_library: str = "pandas", - ) -> Union["pd.DataFrame", "pl.DataFrame"]: + self, + query: str, + parameters: Optional[Union[Sequence, Dict[str, Any]]] = None, + settings: Optional[Dict[str, Any]] = None, + use_strings: Optional[bool] = None, + external_data: Optional[ExternalData] = None, + transport_settings: Optional[Dict[str, str]] = None, + dataframe_library: str = "pandas", + ) -> Union["options.pd.DataFrame", "options.pl.DataFrame"]: """ Query method using the ClickHouse Arrow format to return a DataFrame with PyArrow dtype backend. 
This provides better performance and memory efficiency @@ -391,53 +1219,53 @@ async def query_df_arrow( :param dataframe_library: Library to use for DataFrame creation ("pandas" or "polars") :return: DataFrame (pandas or polars based on dataframe_library parameter) """ - return await self._impl.query_df_arrow(query=query, parameters=parameters, settings=settings, - use_strings=use_strings, external_data=external_data, - transport_settings=transport_settings, - dataframe_library=dataframe_library) - - async def query_df_stream(self, - query: Optional[str] = None, - parameters: Optional[Union[Sequence, Dict[str, Any]]] = None, - settings: Optional[Dict[str, Any]] = None, - query_formats: Optional[Dict[str, str]] = None, - column_formats: Optional[Dict[str, str]] = None, - encoding: Optional[str] = None, - use_none: Optional[bool] = None, - max_str_len: Optional[int] = None, - use_na_values: Optional[bool] = None, - query_tz: Optional[str] = None, - column_tzs: Optional[Dict[str, Union[str, tzinfo]]] = None, - utc_tz_aware: Optional[bool] = None, - context: Optional[QueryContext] = None, - external_data: Optional[ExternalData] = None, - use_extended_dtypes: Optional[bool] = None, - transport_settings: Optional[Dict[str, str]] = None, - tz_mode: Optional[str] = None) -> StreamContext: - """ - Query method that returns the results as a StreamContext. - For parameter values, see the create_query_context method. 
- :return: Generator that yields a Pandas dataframe per block representing the result set - """ - return await self._impl.query_df_stream(query=query, parameters=parameters, settings=settings, - query_formats=query_formats, column_formats=column_formats, - encoding=encoding, use_none=use_none, max_str_len=max_str_len, - use_na_values=use_na_values, query_tz=query_tz, column_tzs=column_tzs, - utc_tz_aware=utc_tz_aware, context=context, external_data=external_data, - use_extended_dtypes=use_extended_dtypes, transport_settings=transport_settings, - tz_mode=tz_mode) - - async def query_df_arrow_stream(self, - query: str, - parameters: Optional[Union[Sequence, Dict[str, Any]]] = None, - settings: Optional[Dict[str, Any]] = None, - use_strings: Optional[bool] = None, - external_data: Optional[ExternalData] = None, - transport_settings: Optional[Dict[str, str]] = None, - dataframe_library: str = 'pandas') -> StreamContext: + check_arrow() + + if dataframe_library == "pandas": + check_pandas() + self._add_integration_tag("pandas") + if not options.IS_PANDAS_2: + raise ProgrammingError("PyArrow-backed dtypes are only supported when using pandas 2.x.") + + def converter(table: "options.arrow.Table") -> "options.pd.DataFrame": + table = _apply_arrow_tz_policy(table, self.tz_mode) + return table.to_pandas(types_mapper=options.pd.ArrowDtype, safe=False) + + elif dataframe_library == "polars": + check_polars() + self._add_integration_tag("polars") + + def converter(table: "options.arrow.Table") -> "options.pl.DataFrame": + table = _apply_arrow_tz_policy(table, self.tz_mode) + return options.pl.from_arrow(table) + + else: + raise ValueError(f"dataframe_library must be 'pandas' or 'polars', got '{dataframe_library}'") + + arrow_table = await self.query_arrow( + query=query, + parameters=parameters, + settings=settings, + use_strings=use_strings, + external_data=external_data, + transport_settings=transport_settings, + ) + + return converter(arrow_table) + + async def 
query_df_arrow_stream( # type: ignore[override] + self, + query: str, + parameters: Optional[Union[Sequence, Dict[str, Any]]] = None, + settings: Optional[Dict[str, Any]] = None, + use_strings: Optional[bool] = None, + external_data: Optional[ExternalData] = None, + transport_settings: Optional[Dict[str, str]] = None, + dataframe_library: str = "pandas", + ) -> StreamContext: """ Query method that returns the results as a stream of DataFrames with PyArrow dtype backend. - Each DataFrame represents a block from the ClickHouse response. + Each DataFrame represents a record batch from the ClickHouse response. :param query: Query statement/format string :param parameters: Optional dictionary used to format the query @@ -446,196 +1274,112 @@ async def query_df_arrow_stream(self, :param external_data: ClickHouse "external data" to send with query :param transport_settings: Optional dictionary of transport level settings (HTTP headers, etc.) :param dataframe_library: Library to use for DataFrame creation ("pandas" or "polars") - :return: StreamContext that yields DataFrames (pandas or polars based on dataframe_library parameter) + :return: StreamContext that yields DataFrames asynchronously (pandas or polars based on dataframe_library parameter) """ - return await self._impl.query_df_arrow_stream(query=query, parameters=parameters, settings=settings, - use_strings=use_strings, external_data=external_data, - transport_settings=transport_settings, - dataframe_library=dataframe_library) + check_arrow() + if dataframe_library == "pandas": + check_pandas() + self._add_integration_tag("pandas") + if not options.IS_PANDAS_2: + raise ProgrammingError("PyArrow-backed dtypes are only supported when using pandas 2.x.") - def create_query_context(self, *args, **kwargs) -> QueryContext: - """ - Creates or updates a reusable QueryContext object - :param query: Query statement/format string - :param parameters: Optional dictionary used to format the query - :param settings: Optional 
dictionary of ClickHouse settings (key/string values) - :param query_formats: See QueryContext __init__ docstring - :param column_formats: See QueryContext __init__ docstring - :param encoding: See QueryContext __init__ docstring - :param use_none: Use None for ClickHouse NULL instead of default values. Note that using None in Numpy - arrays will force the numpy array dtype to 'object', which is often inefficient. This effect also - will impact the performance of Pandas dataframes. - :param column_oriented: Deprecated. Controls orientation of the QueryResult result_set property - :param use_numpy: Return QueryResult columns as one-dimensional numpy arrays - :param max_str_len: Limit returned ClickHouse String values to this length, which allows a Numpy - structured array even with ClickHouse variable length String columns. If 0, Numpy arrays for - String columns will always be object arrays - :param context: An existing QueryContext to be updated with any provided parameter values - :param query_tz Either a string or a pytz tzinfo object. (Strings will be converted to tzinfo objects). - Values for any DateTime or DateTime64 column in the query will be converted to Python datetime.datetime - objects with the selected timezone - :param column_tzs A dictionary of column names to tzinfo objects (or strings that will be converted to - tzinfo objects). The timezone will be applied to datetime objects returned in the query - :param use_na_values: Deprecated alias for use_advanced_dtypes - :param as_pandas Return the result columns as pandas.Series objects - :param streaming Marker used to correctly configure streaming queries - :param external_data ClickHouse "external data" to send with query - :param use_extended_dtypes: Only relevant to Pandas Dataframe queries. Use Pandas "missing types", such as - pandas.NA and pandas.NaT for ClickHouse NULL values, as well as extended Pandas dtypes such as IntegerArray - and StringArray. 
Defaulted to True for query_df methods - :param transport_settings: Optional dictionary of transport level settings (HTTP headers, etc.) - :return: Reusable QueryContext - """ - return self._impl.create_query_context(*args, **kwargs) + def converter(table: "options.arrow.Table") -> "options.pd.DataFrame": + table = _apply_arrow_tz_policy(table, self.tz_mode) + return table.to_pandas(types_mapper=options.pd.ArrowDtype, safe=False) - async def query_arrow(self, - query: str, - parameters: Optional[Union[Sequence, Dict[str, Any]]] = None, - settings: Optional[Dict[str, Any]] = None, - use_strings: Optional[bool] = None, - external_data: Optional[ExternalData] = None, - transport_settings: Optional[Dict[str, str]] = None): - """ - Query method using the ClickHouse Arrow format to return a PyArrow table - :param query: Query statement/format string - :param parameters: Optional dictionary used to format the query - :param settings: Optional dictionary of ClickHouse settings (key/string values) - :param use_strings: Convert ClickHouse String type to Arrow string type (instead of binary) - :param external_data ClickHouse "external data" to send with query - :param transport_settings: Optional dictionary of transport level settings (HTTP headers, etc.) 
- :return: PyArrow.Table - """ - return await self._impl.query_arrow(query=query, parameters=parameters, settings=settings, - use_strings=use_strings, external_data=external_data, - transport_settings=transport_settings) - - async def query_arrow_stream(self, - query: str, - parameters: Optional[Union[Sequence, Dict[str, Any]]] = None, - settings: Optional[Dict[str, Any]] = None, - use_strings: Optional[bool] = None, - external_data: Optional[ExternalData] = None, - transport_settings: Optional[Dict[str, str]] = None) -> StreamContext: - """ - Query method that returns the results as a stream of Arrow tables - :param query: Query statement/format string - :param parameters: Optional dictionary used to format the query - :param settings: Optional dictionary of ClickHouse settings (key/string values) - :param use_strings: Convert ClickHouse String type to Arrow string type (instead of binary) - :param external_data ClickHouse "external data" to send with query - :param transport_settings: Optional dictionary of transport level settings (HTTP headers, etc.) 
- :return: Generator that yields a PyArrow.Table for per block representing the result set - """ - return await self._impl.query_arrow_stream(query=query, parameters=parameters, settings=settings, - use_strings=use_strings, external_data=external_data, - transport_settings=transport_settings) - - async def command(self, - cmd, - parameters: Optional[Union[Sequence, Dict[str, Any]]] = None, - data: Optional[Union[str, bytes]] = None, - settings: Optional[Dict] = None, - use_database: bool = True, - external_data: Optional[ExternalData] = None, - transport_settings: Optional[Dict[str, str]] = None) -> Union[str, int, Sequence[str], QuerySummary]: - """ - Client method that returns a single value instead of a result set - :param cmd: ClickHouse query/command as a python format string - :param parameters: Optional dictionary of key/values pairs to be formatted - :param data: Optional 'data' for the command (for INSERT INTO in particular) - :param settings: Optional dictionary of ClickHouse settings (key/string values) - :param use_database: Send the database parameter to ClickHouse so the command will be executed in the client - database context. Otherwise, no database will be specified with the command. This is useful for determining - the default user database - :param external_data ClickHouse "external data" to send with command/query - :param transport_settings: Optional dictionary of transport level settings (HTTP headers, etc.) 
- :return: Decoded response from ClickHouse as either a string, int, or sequence of strings, or QuerySummary - if no data returned - """ - return await self._impl.command(cmd=cmd, parameters=parameters, data=data, settings=settings, - use_database=use_database, external_data=external_data, - transport_settings=transport_settings) + elif dataframe_library == "polars": + check_polars() + self._add_integration_tag("polars") - async def ping(self) -> bool: - """ - Validate the connection, does not throw an Exception (see debug logs) - :return: ClickHouse server is up and reachable - """ - return await self._impl.ping() + def converter(table: "options.arrow.Table") -> "options.pl.DataFrame": + table = _apply_arrow_tz_policy(table, self.tz_mode) + return options.pl.from_arrow(table) - async def insert(self, - table: Optional[str] = None, - data: Optional[Sequence[Sequence[Any]]] = None, - column_names: Union[str, Iterable[str]] = '*', - database: Optional[str] = None, - column_types: Optional[Sequence[ClickHouseType]] = None, - column_type_names: Optional[Sequence[str]] = None, - column_oriented: bool = False, - settings: Optional[Dict[str, Any]] = None, - context: Optional[InsertContext] = None, - transport_settings: Optional[Dict[str, str]] = None) -> QuerySummary: - """ - Method to insert multiple rows/data matrix of native Python objects. If context is specified arguments - other than data are ignored - :param table: Target table - :param data: Sequence of sequences of Python data - :param column_names: Ordered list of column names or '*' if column types should be retrieved from the - ClickHouse table definition - :param database: Target database -- will use client default database if not specified. - :param column_types: ClickHouse column types. If set then column data does not need to be retrieved from - the server - :param column_type_names: ClickHouse column type names. 
If set then column data does not need to be - retrieved from the server - :param column_oriented: If true the data is already "pivoted" in column form - :param settings: Optional dictionary of ClickHouse settings (key/string values) - :param context: Optional reusable insert context to allow repeated inserts into the same table with - different data batches - :param transport_settings: Optional dictionary of transport level settings (HTTP headers, etc.) - :return: QuerySummary with summary information, throws exception if insert fails - """ - return await self._impl.insert(table=table, data=data, column_names=column_names, database=database, - column_types=column_types, column_type_names=column_type_names, - column_oriented=column_oriented, settings=settings, context=context, - transport_settings=transport_settings) - - async def insert_df(self, - table: Optional[str] = None, - df = None, - database: Optional[str] = None, - settings: Optional[Dict] = None, - column_names: Optional[Sequence[str]] = None, - column_types: Optional[Sequence[ClickHouseType]] = None, - column_type_names: Optional[Sequence[str]] = None, - context: Optional[InsertContext] = None, - transport_settings: Optional[Dict[str, str]] = None) -> QuerySummary: - """ - Insert a pandas DataFrame into ClickHouse. If context is specified arguments other than df are ignored - :param table: ClickHouse table - :param df: two-dimensional pandas dataframe - :param database: Optional ClickHouse database - :param settings: Optional dictionary of ClickHouse settings (key/string values) - :param column_names: An optional list of ClickHouse column names. If not set, the DataFrame column names - will be used - :param column_types: ClickHouse column types. If set then column data does not need to be retrieved from - the server - :param column_type_names: ClickHouse column type names. 
If set then column data does not need to be - retrieved from the server - :param context: Optional reusable insert context to allow repeated inserts into the same table with - different data batches - :param transport_settings: Optional dictionary of transport level settings (HTTP headers, etc.) - :return: QuerySummary with summary information, throws exception if insert fails - """ - return await self._impl.insert_df(table=table, df=df, database=database, settings=settings, - column_names=column_names, column_types=column_types, - column_type_names=column_type_names, context=context, - transport_settings=transport_settings) + else: + raise ValueError(f"dataframe_library must be 'pandas' or 'polars', got '{dataframe_library}'") + settings = self._update_arrow_settings(settings, use_strings) + + body, params, headers, files = self._prep_raw_query( + query, parameters, settings, fmt="ArrowStream", + use_database=True, external_data=external_data + ) + if transport_settings: + headers = dict_copy(headers, transport_settings) + + response = await self._raw_request( + body, params, headers=headers, files=files, + stream=True, server_wait=False, retries=self.query_retries + ) + encoding = response.headers.get("Content-Encoding") + exception_tag = response.headers.get(ex_tag_header) - async def insert_arrow(self, - table: str, - arrow_table, - database: Optional[str] = None, - settings: Optional[Dict] = None, - transport_settings: Optional[Dict[str, str]] = None) -> QuerySummary: + loop = asyncio.get_running_loop() + streaming_source = StreamingResponseSource(response, encoding=encoding, exception_tag=exception_tag) + await streaming_source.start_producer(loop) + + queue = AsyncSyncQueue(maxsize=10) + + class _ArrowDFStreamSource: + def __init__(self, source, q): + self._source = source + self._queue = q + + async def aclose(self): + self._queue.shutdown() + await self._source.aclose() + + def close(self): + self._queue.shutdown() + self._source.close() + + def 
parse_and_convert_streaming(): + """Parse Arrow stream and convert to DataFrames in executor (off event loop).""" + try: + file_adapter = StreamingFileAdapter(streaming_source) + + # PyArrow reads incrementally from adapter (which pulls from queue) + reader = options.arrow.ipc.open_stream(file_adapter) + + for batch in reader: + try: + queue.sync_q.put(converter(batch)) + except RuntimeError: + return + + try: + queue.sync_q.put(EOF_SENTINEL) + except RuntimeError: + return + except Exception as e: + try: + queue.sync_q.put(e) + except Exception: + pass + finally: + queue.shutdown() + + loop.run_in_executor(None, parse_and_convert_streaming) + + async def df_generator(): + """Async generator that yields DataFrames without blocking event loop.""" + while True: + item = await queue.async_q.get() + if item is EOF_SENTINEL: + break + if isinstance(item, Exception): + raise item + yield item + + return StreamContext(_ArrowDFStreamSource(streaming_source, queue), df_generator()) + + async def insert_arrow( # type: ignore[override] + self, + table: str, + arrow_table, + database: Optional[str] = None, + settings: Optional[Dict] = None, + transport_settings: Optional[Dict[str, str]] = None, + ) -> QuerySummary: """ Insert a PyArrow table DataFrame into ClickHouse using raw Arrow format :param table: ClickHouse table @@ -643,17 +1387,24 @@ async def insert_arrow(self, :param database: Optional ClickHouse database :param settings: Optional dictionary of ClickHouse settings (key/string values) :param transport_settings: Optional dictionary of transport level settings (HTTP headers, etc.) 
- :return: QuerySummary with summary information, throws exception if insert fails """ - return await self._impl.insert_arrow(table=table, arrow_table=arrow_table, database=database, - settings=settings, transport_settings=transport_settings) - - async def insert_df_arrow(self, - table: str, - df, - database: Optional[str] = None, - settings: Optional[Dict] = None, - transport_settings: Optional[Dict[str, str]] = None) -> QuerySummary: + check_arrow() + self._add_integration_tag("arrow") + full_table = table if "." in table or not database else f"{database}.{table}" + compression = self.write_compression if self.write_compression in ("zstd", "lz4") else None + column_names, insert_block = arrow_buffer(arrow_table, compression) + if hasattr(insert_block, "to_pybytes"): + insert_block = insert_block.to_pybytes() + return await self.raw_insert(full_table, column_names, insert_block, settings, "Arrow", transport_settings) + + async def insert_df_arrow( # type: ignore[override] + self, + table: str, + df: Union["options.pd.DataFrame", "options.pl.DataFrame"], + database: Optional[str] = None, + settings: Optional[Dict] = None, + transport_settings: Optional[Dict[str, str]] = None, + ) -> QuerySummary: """ Insert a pandas DataFrame with PyArrow backend or a polars DataFrame into ClickHouse using Arrow format. This method is optimized for DataFrames that already use Arrow format, providing @@ -669,19 +1420,57 @@ async def insert_df_arrow(self, :param transport_settings: Optional dictionary of transport level settings (HTTP headers, etc.) 
:return: QuerySummary with summary information, throws exception if insert fails """ - return await self._impl.insert_df_arrow(table=table, df=df, database=database, settings=settings, - transport_settings=transport_settings) + check_arrow() - async def create_insert_context(self, - table: str, - column_names: Optional[Union[str, Sequence[str]]] = None, - database: Optional[str] = None, - column_types: Optional[Sequence[ClickHouseType]] = None, - column_type_names: Optional[Sequence[str]] = None, - column_oriented: bool = False, - settings: Optional[Dict[str, Any]] = None, - data: Optional[Sequence[Sequence[Any]]] = None, - transport_settings: Optional[Dict[str, str]] = None) -> InsertContext: + if options.pd is not None and isinstance(df, options.pd.DataFrame): + df_lib = "pandas" + elif options.pl is not None and isinstance(df, options.pl.DataFrame): + df_lib = "polars" + else: + if options.pd is None and options.pl is None: + raise ImportError("A DataFrame library (pandas or polars) must be installed to use insert_df_arrow.") + raise TypeError(f"df must be either a pandas DataFrame or polars DataFrame, got {type(df).__name__}") + + if df_lib == "pandas": + if not options.IS_PANDAS_2: + raise ProgrammingError("PyArrow-backed dtypes are only supported when using pandas 2.x.") + + non_arrow_cols = [col for col, dtype in df.dtypes.items() if not isinstance(dtype, options.pd.ArrowDtype)] + if non_arrow_cols: + raise ProgrammingError( + f"insert_df_arrow requires all columns to use PyArrow dtypes. Non-Arrow columns found: [{', '.join(non_arrow_cols)}]. 
" + ) + try: + arrow_table = options.arrow.Table.from_pandas(df, preserve_index=False) + except Exception as e: + raise DataError(f"Failed to convert pandas DataFrame to Arrow table: {e}") from e + else: + try: + arrow_table = df.to_arrow() + except Exception as e: + raise DataError(f"Failed to convert polars DataFrame to Arrow table: {e}") from e + + self._add_integration_tag(df_lib) + return await self.insert_arrow( + table=table, + arrow_table=arrow_table, + database=database, + settings=settings, + transport_settings=transport_settings, + ) + + async def create_insert_context( # type: ignore[override] + self, + table: str, + column_names: Optional[Union[str, Sequence[str]]] = None, + database: Optional[str] = None, + column_types: Optional[Sequence[ClickHouseType]] = None, + column_type_names: Optional[Sequence[str]] = None, + column_oriented: bool = False, + settings: Optional[Dict[str, Any]] = None, + data: Optional[Sequence[Sequence[Any]]] = None, + transport_settings: Optional[Dict[str, str]] = None, + ) -> InsertContext: """ Builds a reusable insert context to hold state for a duration of an insert :param table: Target table @@ -698,204 +1487,393 @@ async def create_insert_context(self, :param transport_settings: Optional dictionary of transport level settings (HTTP headers, etc.) 
:return: Reusable insert context """ - return await self._impl.create_insert_context(table=table, column_names=column_names, database=database, - column_types=column_types, column_type_names=column_type_names, - column_oriented=column_oriented, settings=settings, data=data, - transport_settings=transport_settings) - - async def data_insert(self, context: InsertContext) -> QuerySummary: - """ - Subclass implementation of the data insert - :context: InsertContext parameter object - :return: No return, throws an exception if the insert fails - """ - return await self._impl.data_insert(context) - - async def raw_insert(self, - table: Optional[str] = None, - column_names: Optional[Sequence[str]] = None, - insert_block: Optional[Union[str, bytes, Generator[bytes, None, None], BinaryIO]] = None, - settings: Optional[Dict] = None, - fmt: Optional[str] = None, - compression: Optional[str] = None, - transport_settings: Optional[Dict[str, str]] = None) -> QuerySummary: - """ - Insert data already formatted in a bytes object - :param table: Table name (whether qualified with the database name or not) - :param column_names: Sequence of column names - :param insert_block: Binary or string data already in a recognized ClickHouse format - :param settings: Optional dictionary of ClickHouse settings (key/string values) - :param compression: Recognized ClickHouse `Accept-Encoding` header compression value - :param transport_settings: Optional dictionary of transport level settings (HTTP headers, etc.) - :param fmt: Valid clickhouse format + full_table = table + if "." 
not in table: + if database: + full_table = f"{quote_identifier(database)}.{quote_identifier(table)}" + else: + full_table = quote_identifier(table) + column_defs = [] + if column_types is None and column_type_names is None: + describe_result = await self.query(f"DESCRIBE TABLE {full_table}", settings=settings) + column_defs = [ + ColumnDef(**row) for row in describe_result.named_results() if row["default_type"] not in ("ALIAS", "MATERIALIZED") + ] + if column_names is None or isinstance(column_names, str) and column_names == "*": + column_names = [cd.name for cd in column_defs] + column_types = [cd.ch_type for cd in column_defs] + elif isinstance(column_names, str): + column_names = [column_names] + if len(column_names) == 0: + raise ValueError("Column names must be specified for insert") + if not column_types: + if column_type_names: + column_types = [get_from_name(name) for name in column_type_names] + else: + column_map = {d.name: d for d in column_defs} + try: + column_types = [column_map[name].ch_type for name in column_names] + except KeyError as ex: + raise ProgrammingError(f"Unrecognized column {ex} in table {table}") from None + if len(column_names) != len(column_types): + raise ProgrammingError("Column names do not match column types") from None + return InsertContext( + full_table, + column_names, + column_types, + column_oriented=column_oriented, + settings=settings, + transport_settings=transport_settings, + data=data, + ) + + async def data_insert(self, context: InsertContext) -> QuerySummary: # type: ignore[override] """ - return await self._impl.raw_insert(table=table, column_names=column_names, insert_block=insert_block, - settings=settings, fmt=fmt, compression=compression, - transport_settings=transport_settings) - - async def __aenter__(self) -> "AsyncClient": - if hasattr(self._impl, '_initialize'): - await self._impl._initialize() - return self + See BaseClient doc_string for this method. 
- async def __aexit__(self, exc_type, exc_val, exc_tb) -> None: - await self.close() - - def __getattr__(self, name): - return getattr(self._impl, name) - - def __setattr__(self, name, value): - if name in ("_impl",) or "_impl" not in self.__dict__: - super().__setattr__(name, value) - return - if hasattr(self._impl, name): - setattr(self._impl, name, value) - else: - super().__setattr__(name, value) - - -class _LegacyAsyncWrapper: - """ - Legacy executor-based async wrapper (DEPRECATED). - - This wraps a sync HttpClient and runs all operations in a ThreadPoolExecutor. - Maintained for backward compatibility but may be removed in the future. - """ - - def __init__( - self, - client: Client, - executor_threads: int = 0, - executor: Union[ThreadPoolExecutor, None, DefaultThreadPoolExecutor] = NEW_THREAD_POOL_EXECUTOR, - ): - if isinstance(client, HttpClient): - client.headers["User-Agent"] = client.headers["User-Agent"].replace("mode:sync;", "mode:async;") - self.client = client - if executor_threads == 0: - executor_threads = min(32, (os.cpu_count() or 1) + 4) - if executor is NEW_THREAD_POOL_EXECUTOR: - self.new_executor = True - self.executor = ThreadPoolExecutor(max_workers=executor_threads) - else: - if executor_threads != 0: - logger.warning("executor_threads parameter is ignored when passing an executor object") - self.new_executor = False - self.executor = executor - - if not AIOHTTP_AVAILABLE: - logger.info( - "Using executor-based async (legacy mode). 
" - "For better performance with true native async, install: pip install clickhouse-connect[async]" - ) - - def set_client_setting(self, key, value): - self.client.set_client_setting(key=key, value=value) - - def get_client_setting(self, key) -> Optional[str]: - return self.client.get_client_setting(key=key) - - def set_access_token(self, access_token: str): - self.client.set_access_token(access_token) - - def min_version(self, version_str: str) -> bool: - return self.client.min_version(version_str) - - async def close(self): - self.client.close() - if self.new_executor: - await asyncio.to_thread(self.executor.shutdown, True) - - async def query(self, *args, **kwargs) -> QueryResult: - loop = asyncio.get_running_loop() - return await loop.run_in_executor(self.executor, lambda: self.client.query(*args, **kwargs)) - - async def query_column_block_stream(self, *args, **kwargs) -> StreamContext: - loop = asyncio.get_running_loop() - return await loop.run_in_executor(self.executor, lambda: self.client.query_column_block_stream(*args, **kwargs)) - - async def query_row_block_stream(self, *args, **kwargs) -> StreamContext: - loop = asyncio.get_running_loop() - return await loop.run_in_executor(self.executor, lambda: self.client.query_row_block_stream(*args, **kwargs)) - - async def query_rows_stream(self, *args, **kwargs) -> StreamContext: - loop = asyncio.get_running_loop() - return await loop.run_in_executor(self.executor, lambda: self.client.query_rows_stream(*args, **kwargs)) - - async def raw_query(self, *args, **kwargs) -> bytes: - loop = asyncio.get_running_loop() - return await loop.run_in_executor(self.executor, lambda: self.client.raw_query(*args, **kwargs)) - - async def raw_stream(self, *args, **kwargs) -> io.IOBase: - loop = asyncio.get_running_loop() - return await loop.run_in_executor(self.executor, lambda: self.client.raw_stream(*args, **kwargs)) - - async def query_np(self, *args, **kwargs): - loop = asyncio.get_running_loop() - return await 
loop.run_in_executor(self.executor, lambda: self.client.query_np(*args, **kwargs)) - - async def query_np_stream(self, *args, **kwargs) -> StreamContext: - loop = asyncio.get_running_loop() - return await loop.run_in_executor(self.executor, lambda: self.client.query_np_stream(*args, **kwargs)) - - async def query_df(self, *args, **kwargs): - loop = asyncio.get_running_loop() - return await loop.run_in_executor(self.executor, lambda: self.client.query_df(*args, **kwargs)) - - async def query_df_arrow(self, *args, **kwargs): - loop = asyncio.get_running_loop() - return await loop.run_in_executor(self.executor, lambda: self.client.query_df_arrow(*args, **kwargs)) - - async def query_df_stream(self, *args, **kwargs) -> StreamContext: - loop = asyncio.get_running_loop() - return await loop.run_in_executor(self.executor, lambda: self.client.query_df_stream(*args, **kwargs)) - - async def query_df_arrow_stream(self, *args, **kwargs) -> StreamContext: - loop = asyncio.get_running_loop() - return await loop.run_in_executor(self.executor, lambda: self.client.query_df_arrow_stream(*args, **kwargs)) + Uses true streaming via reverse bridge pattern: + - Sync producer (serializer) runs in executor, puts blocks in queue + - Async consumer (network) pulls from queue and yields to aiohttp + - Bounded queue provides backpressure to prevent memory bloat + """ + if context.empty: + logger.debug("No data included in insert, skipping") + return QuerySummary() - def create_query_context(self, *args, **kwargs) -> QueryContext: - return self.client.create_query_context(*args, **kwargs) + if context.compression is None: + context.compression = self.write_compression - async def query_arrow(self, *args, **kwargs): loop = asyncio.get_running_loop() - return await loop.run_in_executor(self.executor, lambda: self.client.query_arrow(*args, **kwargs)) - async def query_arrow_stream(self, *args, **kwargs) -> StreamContext: - loop = asyncio.get_running_loop() - return await 
loop.run_in_executor(self.executor, lambda: self.client.query_arrow_stream(*args, **kwargs)) + streaming_source = StreamingInsertSource( + transform=self._transform, context=context, loop=loop, maxsize=10 + ) - async def command(self, *args, **kwargs) -> Union[str, int, Sequence[str], QuerySummary]: - loop = asyncio.get_running_loop() - return await loop.run_in_executor(self.executor, lambda: self.client.command(*args, **kwargs)) + streaming_source.start_producer() - async def ping(self) -> bool: - loop = asyncio.get_running_loop() - return await loop.run_in_executor(self.executor, lambda: self.client.ping()) # pylint: disable=unnecessary-lambda + headers = {"Content-Type": "application/octet-stream"} + if context.compression: + headers["Content-Encoding"] = context.compression - async def insert(self, *args, **kwargs) -> QuerySummary: - loop = asyncio.get_running_loop() - return await loop.run_in_executor(self.executor, lambda: self.client.insert(*args, **kwargs)) - - async def insert_df(self, *args, **kwargs) -> QuerySummary: - loop = asyncio.get_running_loop() - return await loop.run_in_executor(self.executor, lambda: self.client.insert_df(*args, **kwargs)) + params = {} + if self.database: + params["database"] = self.database + params.update(self._validate_settings(context.settings)) + headers = dict_copy(headers, context.transport_settings) - async def insert_arrow(self, *args, **kwargs) -> QuerySummary: - loop = asyncio.get_running_loop() - return await loop.run_in_executor(self.executor, lambda: self.client.insert_arrow(*args, **kwargs)) - - async def insert_df_arrow(self, *args, **kwargs) -> QuerySummary: - loop = asyncio.get_running_loop() - return await loop.run_in_executor(self.executor, lambda: self.client.insert_df_arrow(*args, **kwargs)) - - async def create_insert_context(self, *args, **kwargs) -> InsertContext: - loop = asyncio.get_running_loop() - return await loop.run_in_executor(self.executor, lambda: self.client.create_insert_context(*args, 
**kwargs)) + try: + response = await self._raw_request( + streaming_source.async_generator(), params, headers=headers, server_wait=False + ) + logger.debug("Context insert response code: %d", response.status) + except Exception: + await streaming_source.close() + + if context.insert_exception: + ex = context.insert_exception + context.insert_exception = None + raise ex from None + raise + finally: + await streaming_source.close() + context.data = None + + return QuerySummary(self._summary(response)) + + async def insert_df( # type: ignore[override] + self, + table: Optional[str] = None, + df=None, + database: Optional[str] = None, + settings: Optional[Dict] = None, + column_names: Optional[Sequence[str]] = None, + column_types: Optional[Sequence[ClickHouseType]] = None, + column_type_names: Optional[Sequence[str]] = None, + context: Optional[InsertContext] = None, + transport_settings: Optional[Dict[str, str]] = None, + ) -> QuerySummary: + """ + Insert a pandas DataFrame into ClickHouse. If context is specified arguments other than df are ignored + :param table: ClickHouse table + :param df: two-dimensional pandas dataframe + :param database: Optional ClickHouse database + :param settings: Optional dictionary of ClickHouse settings (key/string values) + :param column_names: An optional list of ClickHouse column names. If not set, the DataFrame column names + will be used + :param column_types: ClickHouse column types. If set then column data does not need to be retrieved from + the server + :param column_type_names: ClickHouse column type names. If set then column data does not need to be + retrieved from the server + :param context: Optional reusable insert context to allow repeated inserts into the same table with + different data batches + :param transport_settings: Optional dictionary of transport level settings (HTTP headers, etc.) 
+ :return: QuerySummary with summary information, throws exception if insert fails + """ + check_pandas() + self._add_integration_tag("pandas") + if context is None: + if column_names is None: + column_names = df.columns + elif len(column_names) != len(df.columns): + raise ProgrammingError("DataFrame column count does not match insert_columns") from None + return await self.insert( + table, + df, + column_names, + database, + column_types=column_types, + column_type_names=column_type_names, + settings=settings, + transport_settings=transport_settings, + context=context, + ) + + async def raw_insert( # type: ignore[override] + self, + table: Optional[str] = None, + column_names: Optional[Sequence[str]] = None, + insert_block: Optional[Union[str, bytes, Generator[bytes, None, None], BinaryIO]] = None, + settings: Optional[Dict] = None, + fmt: Optional[str] = None, + compression: Optional[str] = None, + transport_settings: Optional[Dict[str, str]] = None, + ) -> QuerySummary: + """ + See BaseClient doc_string for this method + """ + params = {} + headers = {"Content-Type": "application/octet-stream"} + if compression: + headers["Content-Encoding"] = compression + + if table: + cols = f" ({', '.join([quote_identifier(x) for x in column_names])})" if column_names is not None else "" + fmt_str = fmt if fmt else self._write_format + query = f"INSERT INTO {table}{cols} FORMAT {fmt_str}" + if not compression and isinstance(insert_block, str): + insert_block = query + "\n" + insert_block + elif not compression and isinstance(insert_block, (bytes, bytearray, BinaryIO)): + insert_block = (query + "\n").encode() + insert_block + else: + params["query"] = query + + if self.database: + params["database"] = self.database + params.update(self._validate_settings(settings or {})) + headers = dict_copy(headers, transport_settings) + + response = await self._raw_request(insert_block, params, headers, server_wait=False) + logger.debug("Raw insert response code: %d", response.status) + 
return QuerySummary(self._summary(response)) + + def _add_integration_tag(self, name: str): + """ + Dynamically adds a product (like pandas or sqlalchemy) to the User-Agent string details section. + """ + if not common.get_setting("send_integration_tags") or name in self._reported_libs: + return - async def data_insert(self, context: InsertContext) -> QuerySummary: - loop = asyncio.get_running_loop() - return await loop.run_in_executor(self.executor, lambda: self.client.data_insert(context)) + try: + ver = "unknown" + try: + ver = dist_version(name) + except Exception: + try: + mod = import_module(name) + ver = getattr(mod, "__version__", "unknown") + except Exception: + pass + + product_info = f"{name}/{ver}" + + ua = self.headers.get("User-Agent", "") + start = ua.find("(") + if start == -1: + return + end = ua.find(")", start + 1) + if end == -1: + return + + details = ua[start + 1 : end].strip() + + if product_info in details: + self._reported_libs.add(name) + return + + new_details = f"{product_info}; {details}" if details else product_info + new_ua = f"{ua[: start + 1]}{new_details}{ua[end:]}" + self.headers["User-Agent"] = new_ua.strip() + if self._session: + self._session.headers["User-Agent"] = new_ua.strip() + + self._reported_libs.add(name) + logger.debug("Added '%s' to User-Agent", product_info) + + except Exception as e: + logger.debug("Problem adding '%s' to User-Agent: %s", name, e) + + async def _error_handler(self, response: aiohttp.ClientResponse, retried: bool = False): + """ + Handles HTTP errors. Tries to be robust and provide maximum context. 
+ """ + try: + body = "" + try: + raw_body = await response.read() + encoding = response.headers.get("Content-Encoding") + + if encoding: + loop = asyncio.get_running_loop() + + def decompress_and_decode(): + decompressed = decompress_response(raw_body, encoding) + return common.format_error(decompressed.decode(errors="backslashreplace")).strip() + + body = await loop.run_in_executor(None, decompress_and_decode) + else: + loop = asyncio.get_running_loop() + body = await loop.run_in_executor( + None, + lambda: common.format_error(raw_body.decode(errors="backslashreplace")).strip() + ) + except Exception: + logger.warning("Failed to read error response body", exc_info=True) + + if self.show_clickhouse_errors: + err_code = response.headers.get(ex_header) + if err_code: + err_str = f"Received ClickHouse exception, code: {err_code}" + else: + err_str = f"HTTP driver received HTTP status {response.status}" + + if body: + err_str = f"{err_str}, server response: {body}" + else: + err_str = "The ClickHouse server returned an error" + + err_str = f"{err_str} (for url {self.url})" + + finally: + response.close() + + raise OperationalError(err_str) if retried else DatabaseError(err_str) from None + + async def _raw_request( + self, + data, + params, + headers=None, + files=None, + method="POST", + stream=False, + server_wait=True, + retries: int = 0, + ) -> aiohttp.ClientResponse: + if self._session is None: + raise ProgrammingError( + "Session not initialized. Use 'async with get_async_client(...)' or call 'await client._initialize()' first." 
+ ) - async def raw_insert(self, *args, **kwargs) -> QuerySummary: - loop = asyncio.get_running_loop() - return await loop.run_in_executor(self.executor, lambda: self.client.raw_insert(*args, **kwargs)) + reset_seconds = common.get_setting("max_connection_age") + if reset_seconds: + now = time.time() + if self._last_pool_reset is None: + self._last_pool_reset = now + elif self._last_pool_reset < now - reset_seconds: + logger.debug("connection expiration - resetting connection pool") + await self.close_connections() + self._last_pool_reset = now + + final_params = dict_copy(self._client_settings, params) + if server_wait: + final_params.setdefault("wait_end_of_query", "1") + if self._send_progress: + final_params.setdefault("send_progress_in_http_headers", "1") + if self._progress_interval: + final_params.setdefault("http_headers_progress_interval_ms", self._progress_interval) + if self._autogenerate_query_id and "query_id" not in final_params: + final_params["query_id"] = str(uuid.uuid4()) + + req_headers = dict_copy(self.headers, headers) + if self.server_host_name: + req_headers["Host"] = self.server_host_name + query_session = final_params.get("session_id") + attempts = 0 + + # pylint: disable=too-many-nested-blocks + while True: + attempts += 1 + + if query_session: + if query_session == self._active_session: + raise ProgrammingError( + "Attempt to execute concurrent queries within the same session. " + "Please use a separate client instance per concurrent query." 
+ ) + self._active_session = query_session + + try: + # Construct full URL (aiohttp doesn't have base_url) + url = f"{self.url}/" + request_kwargs = {"method": method, "url": url, "params": final_params, "headers": req_headers} + if hasattr(self, "_proxy_url") and self._proxy_url: + request_kwargs["proxy"] = self._proxy_url + if files: + # IMPORTANT: Must set content_type on text fields to force multipart/form-data encoding + # Without content_type, aiohttp uses application/x-www-form-urlencoded + form = aiohttp.FormData() + for field_name, field_value in files.items(): + if isinstance(field_value, tuple): + if field_value[0] is None: + form.add_field(field_name, str(field_value[1]), content_type='text/plain') + else: + filename = field_value[0] + file_data = field_value[1] + content_type = field_value[2] if len(field_value) > 2 else None + form.add_field(field_name, file_data, filename=filename, content_type=content_type) + else: + form.add_field(field_name, field_value, content_type='text/plain') + request_kwargs["data"] = form + elif isinstance(data, dict): + request_kwargs["data"] = data + else: + request_kwargs["data"] = data + + response = await self._session.request(**request_kwargs) + if 200 <= response.status < 300 and not response.headers.get(ex_header): + return response + + if response.status in (429, 503, 504): + if attempts > retries: + await self._error_handler(response, retried=True) + else: + logger.debug("Retrying request with status code %s (attempt %s/%s)", response.status, attempts, retries + 1) + await asyncio.sleep(0.1 * attempts) + response.close() + continue + await self._error_handler(response) + + except aiohttp.ServerConnectionError as e: + if "Connection reset" in str(e) or "Remote end closed" in str(e) or "Cannot connect" in str(e): + if attempts == 1: + logger.debug("Retrying after connection error from remote host") + continue + raise OperationalError(f"Network Error: {str(e)}") from e + + except (aiohttp.ClientError, 
asyncio.TimeoutError) as e: + raise OperationalError(f"Network Error: {str(e)}") from e + + finally: + if query_session: + self._active_session = None + + @staticmethod + def _summary(response: aiohttp.ClientResponse): + summary = {} + if "X-ClickHouse-Summary" in response.headers: + try: + summary = json.loads(response.headers["X-ClickHouse-Summary"]) + except json.JSONDecodeError: + pass + summary["query_id"] = response.headers.get("X-ClickHouse-Query-Id", "") + return summary diff --git a/examples/run_async.py b/examples/run_async.py index 069e3f90..7728bce8 100644 --- a/examples/run_async.py +++ b/examples/run_async.py @@ -1,51 +1,37 @@ #!/usr/bin/env python -u """ -This example will execute 10 queries in total, 2 concurrent queries at a time. -Each query will sleep for 2 seconds before returning. -Here's a sample output that shows that the queries are executed concurrently in batches of 2: -``` -Completed query 1, elapsed ms since start: 2002 -Completed query 0, elapsed ms since start: 2002 -Completed query 3, elapsed ms since start: 4004 -Completed query 2, elapsed ms since start: 4005 -Completed query 4, elapsed ms since start: 6006 -Completed query 5, elapsed ms since start: 6007 -Completed query 6, elapsed ms since start: 8009 -Completed query 7, elapsed ms since start: 8009 -Completed query 9, elapsed ms since start: 10011 -Completed query 8, elapsed ms since start: 10011 -``` +Demonstrates concurrent async queries using clickhouse-connect. + +Executes 10 queries with a concurrency limit of 2. Each query sleeps for 2 seconds, +so the total wall time is ~10 seconds rather than ~20. + +Sample output: + Completed query 1, elapsed: 2002ms + Completed query 0, elapsed: 2003ms + Completed query 3, elapsed: 4005ms + Completed query 2, elapsed: 4005ms + ... 
""" import asyncio -from datetime import datetime +import time import clickhouse_connect -QUERIES = 10 -SEMAPHORE = 2 - async def concurrent_queries(): - test_query = "SELECT sleep(2)" - client = await clickhouse_connect.get_async_client() - - start = datetime.now() - - async def semaphore_wrapper(sm: asyncio.Semaphore, num: int): - async with sm: - await client.query(query=test_query) - print(f"Completed query {num}, " - f"elapsed ms since start: {int((datetime.now() - start).total_seconds() * 1000)}") - - semaphore = asyncio.Semaphore(SEMAPHORE) - await asyncio.gather(*[semaphore_wrapper(semaphore, num) for num in range(QUERIES)]) - await client.close() + async with await clickhouse_connect.get_async_client() as client: + semaphore = asyncio.Semaphore(2) + start = time.monotonic() + async def run_query(num: int): + async with semaphore: + await client.query("SELECT sleep(2)") + elapsed = int((time.monotonic() - start) * 1000) + print(f"Completed query {num}, elapsed: {elapsed}ms") -async def main(): - await concurrent_queries() + await asyncio.gather(*(run_query(i) for i in range(10))) -asyncio.run(main()) +asyncio.run(concurrent_queries()) diff --git a/tests/integration_tests/conftest.py b/tests/integration_tests/conftest.py index f3498e84..a681628a 100644 --- a/tests/integration_tests/conftest.py +++ b/tests/integration_tests/conftest.py @@ -17,7 +17,7 @@ from clickhouse_connect.driver.exceptions import OperationalError from clickhouse_connect.tools.testing import TableContext from clickhouse_connect.driver.httpclient import HttpClient -from clickhouse_connect.driver import AsyncClient, Client, create_client +from clickhouse_connect.driver import Client, create_client from tests.helpers import PROJECT_ROOT_DIR @@ -324,15 +324,9 @@ def test_client_fixture(test_config: TestConfig, test_create_client: Callable) - sys.stderr.write('Successfully stopped docker compose') -@pytest_asyncio.fixture(scope='session', name='test_async_client') -async def 
test_async_client_fixture(test_client: Client) -> AsyncContextManager[AsyncClient]: - async with AsyncClient(client=test_client) as client: - yield client - - @pytest_asyncio.fixture(scope="function", loop_scope="function", name="test_native_async_client") async def test_native_async_client_fixture(test_config: TestConfig) -> AsyncContextManager: - """Function-scoped fixture for aiohttp async client""" + """Function-scoped async client fixture""" async with await get_async_client( host=test_config.host, port=test_config.port, @@ -340,7 +334,7 @@ async def test_native_async_client_fixture(test_config: TestConfig) -> AsyncCont password=test_config.password, database=test_config.test_database, compress=test_config.compress, - client_name="int_tests/aiohttp_async", + client_name="int_tests/native_async", ) as client: if client.min_version("22.8"): client.set_client_setting("database_replicated_enforce_synchronous_settings", "1") diff --git a/tests/integration_tests/test_async_client.py b/tests/integration_tests/test_async_client.py deleted file mode 100644 index 64e5f215..00000000 --- a/tests/integration_tests/test_async_client.py +++ /dev/null @@ -1,302 +0,0 @@ -""" -AsyncClient tests that verify that the wrapper for each method is working correctly. 
-""" - -from typing import Callable - -import numpy as np -import pandas as pd -import pytest - -from clickhouse_connect.driver.options import arrow, IS_PANDAS_2 # pylint: disable=no-name-in-module -from clickhouse_connect.driver.exceptions import ProgrammingError -from clickhouse_connect.driver import AsyncClient - - -@pytest.mark.asyncio -async def test_client_settings(test_async_client: AsyncClient): - key = 'prefer_column_name_to_alias' - value = '1' - test_async_client.set_client_setting(key, value) - assert test_async_client.get_client_setting(key) == value - - -@pytest.mark.asyncio -async def test_min_version(test_async_client: AsyncClient): - assert test_async_client.min_version('19') is True - assert test_async_client.min_version('22.4') is True - assert test_async_client.min_version('99999') is False - - -@pytest.mark.asyncio -async def test_query(test_async_client: AsyncClient): - result = await test_async_client.query('SELECT * FROM system.tables') - assert len(result.result_set) > 0 - assert result.row_count > 0 - assert result.first_item == next(result.named_results()) - - -stream_query = 'SELECT number, randomStringUTF8(50) FROM numbers(10000)' -stream_settings = {'max_block_size': 4000} - - -# pylint: disable=duplicate-code -@pytest.mark.asyncio -async def test_query_column_block_stream(test_async_client: AsyncClient): - block_stream = await test_async_client.query_column_block_stream(stream_query, settings=stream_settings) - total = 0 - block_count = 0 - with block_stream: - for block in block_stream: - block_count += 1 - total += sum(block[0]) - assert total == 49995000 - assert block_count > 1 - - -# pylint: disable=duplicate-code -@pytest.mark.asyncio -async def test_query_row_block_stream(test_async_client: AsyncClient): - block_stream = await test_async_client.query_row_block_stream(stream_query, settings=stream_settings) - total = 0 - block_count = 0 - with block_stream: - for block in block_stream: - block_count += 1 - for row in block: - 
total += row[0] - assert total == 49995000 - assert block_count > 1 - - -@pytest.mark.asyncio -async def test_query_rows_stream(test_async_client: AsyncClient): - row_stream = await test_async_client.query_rows_stream('SELECT number FROM numbers(10000)') - total = 0 - with row_stream: - for row in row_stream: - total += row[0] - assert total == 49995000 - - -@pytest.mark.asyncio -async def test_raw_query(test_async_client: AsyncClient): - result = await test_async_client.raw_query('SELECT 42') - assert result == b'42\n' - - -@pytest.mark.asyncio -async def test_raw_stream(test_async_client: AsyncClient): - stream = await test_async_client.raw_stream('SELECT 42') - result = b'' - with stream: - for chunk in stream: - result += chunk - assert result == b'42\n' - - -@pytest.mark.asyncio -async def test_query_np(test_async_client: AsyncClient): - result = await test_async_client.query_np('SELECT number FROM numbers(5)') - assert isinstance(result, np.ndarray) - assert list(result) == [[0], [1], [2], [3], [4]] - - -@pytest.mark.asyncio -async def test_query_np_stream(test_async_client: AsyncClient): - stream = await test_async_client.query_np_stream('SELECT number FROM numbers(5)') - result = np.array([]) - with stream: - for block in stream: - result = np.append(result, block) - assert list(result) == [0, 1, 2, 3, 4] - - -@pytest.mark.asyncio -async def test_query_df(test_async_client: AsyncClient): - result = await test_async_client.query_df('SELECT number FROM numbers(5)') - assert isinstance(result, pd.DataFrame) - assert list(result['number']) == [0, 1, 2, 3, 4] - - -@pytest.mark.asyncio -async def test_query_df_stream(test_async_client: AsyncClient): - stream = await test_async_client.query_df_stream('SELECT number FROM numbers(5)') - result = [] - with stream: - for block in stream: - result.append(list(block['number'])) - assert result == [[0, 1, 2, 3, 4]] - - -@pytest.mark.asyncio -async def test_create_query_context(test_async_client: AsyncClient): - 
query_context = test_async_client.create_query_context( - query='SELECT {k: Int32}', - parameters={'k': 42}, - column_oriented=True) - result = await test_async_client.query(context=query_context) - assert result.row_count == 1 - assert result.result_set == [[42]] - - -@pytest.mark.asyncio -async def test_query_arrow(test_async_client: AsyncClient): - if not arrow: - pytest.skip('PyArrow package not available') - result = await test_async_client.query_arrow('SELECT number FROM numbers(5)') - assert isinstance(result, arrow.Table) - assert list(result[0].to_pylist()) == [0, 1, 2, 3, 4] - - -@pytest.mark.asyncio -async def test_query_arrow_stream(test_async_client: AsyncClient): - if not arrow: - pytest.skip('PyArrow package not available') - stream = await test_async_client.query_arrow_stream('SELECT number FROM numbers(5)') - result = [] - with stream: - for block in stream: - result.append(block[0].to_pylist()) - assert result == [[0, 1, 2, 3, 4]] - - -@pytest.mark.asyncio -async def test_command(test_async_client: AsyncClient): - version = await test_async_client.command('SELECT version()') - assert int(version.split('.')[0]) >= 19 - - -@pytest.mark.asyncio -async def test_ping(test_async_client: AsyncClient): - assert await test_async_client.ping() is True - - -@pytest.mark.asyncio -async def test_insert(test_async_client: AsyncClient, table_context: Callable): - with table_context('test_async_client_insert', ['key UInt32', 'value String']) as ctx: - await test_async_client.insert(ctx.table, [[42, 'str_0'], [144, 'str_1']]) - result_set = (await test_async_client.query(f"SELECT * FROM {ctx.table} ORDER BY key ASC")).result_columns - assert result_set == [[42, 144], ['str_0', 'str_1']] - - -@pytest.mark.asyncio -async def test_insert_df(test_async_client: AsyncClient, table_context: Callable): - with table_context('test_async_client_insert_df', ['key UInt32', 'value String']) as ctx: - df = pd.DataFrame([[42, 'str_0'], [144, 'str_1']], columns=['key', 'value']) - 
await test_async_client.insert_df(ctx.table, df) - result_set = (await test_async_client.query(f"SELECT * FROM {ctx.table} ORDER BY key ASC")).result_columns - assert result_set == [[42, 144], ['str_0', 'str_1']] - - -@pytest.mark.asyncio -async def test_insert_arrow(test_async_client: AsyncClient, table_context: Callable): - if not arrow: - pytest.skip('PyArrow package not available') - with table_context('test_async_client_insert_arrow', ['key UInt32', 'value String']) as ctx: - data = arrow.Table.from_arrays([arrow.array([42, 144]), arrow.array(['str_0', 'str_1'])], names=['key', 'value']) - await test_async_client.insert_arrow(ctx.table, data) - result_set = (await test_async_client.query(f"SELECT * FROM {ctx.table} ORDER BY key ASC")).result_columns - assert result_set == [[42, 144], ['str_0', 'str_1']] - - -@pytest.mark.asyncio -async def test_create_insert_context(test_async_client: AsyncClient, table_context: Callable): - with table_context('test_async_client_create_insert_context', ['key UInt32', 'value String']) as ctx: - data = [[1, 'a'], [2, 'b']] - insert_context = await test_async_client.create_insert_context(table=ctx.table, data=data) - await test_async_client.insert(context=insert_context) - result = (await test_async_client.query(f'SELECT * FROM {ctx.table} ORDER BY key ASC')).result_columns - assert result == [[1, 2], ['a', 'b']] - - -@pytest.mark.asyncio -async def test_data_insert(test_async_client: AsyncClient, table_context: Callable): - with table_context('test_async_client_data_insert', ['key UInt32', 'value String']) as ctx: - df = pd.DataFrame([[42, 'str_0'], [144, 'str_1']], columns=['key', 'value']) - insert_context = await test_async_client.create_insert_context(ctx.table, df.columns) - insert_context.data = df - await test_async_client.data_insert(insert_context) - result_set = (await test_async_client.query(f"SELECT * FROM {ctx.table} ORDER BY key ASC")).result_columns - assert result_set == [[42, 144], ['str_0', 'str_1']] - - 
-@pytest.mark.asyncio -async def test_raw_insert(test_async_client: AsyncClient, table_context: Callable): - with table_context('test_async_client_raw_insert', ['key UInt32', 'value String']) as ctx: - await test_async_client.raw_insert(table=ctx.table, - column_names=['key', 'value'], - insert_block='42,"foo"\n144,"bar"\n', - fmt='CSV') - result_set = (await test_async_client.query(f"SELECT * FROM {ctx.table} ORDER BY key ASC")).result_columns - assert result_set == [[42, 144], ['foo', 'bar']] - - -@pytest.mark.asyncio -async def test_query_df_arrow(test_async_client: AsyncClient, table_context: Callable): - if not arrow: - pytest.skip("PyArrow package not available") - - data = [[78, pd.NA, "a"], [51, 421, "b"]] - df = pd.DataFrame(data, columns=["i64", "ni64", "str"]) - - with table_context( - "df_pyarrow_query_test", - [ - "i64 Int64", - "ni64 Nullable(Int64)", - "str String", - ], - ) as ctx: - if IS_PANDAS_2: - df = df.convert_dtypes(dtype_backend="pyarrow") - await test_async_client.insert_df(ctx.table, df) - result_df = await test_async_client.query_df_arrow(f"SELECT * FROM {ctx.table} ORDER BY i64") - for dt in list(result_df.dtypes): - assert isinstance(dt, pd.ArrowDtype) - else: - with pytest.raises(ProgrammingError): - result_df = await test_async_client.query_df_arrow(f"SELECT * FROM {ctx.table}") - - -@pytest.mark.asyncio -async def test_insert_df_arrow(test_async_client: AsyncClient, table_context: Callable): - if not arrow: - pytest.skip("PyArrow package not available") - - data = [[78, pd.NA, "a"], [51, 421, "b"]] - df = pd.DataFrame(data, columns=["i64", "ni64", "str"]) - - with table_context( - "df_pyarrow_insert_test", - [ - "i64 Int64", - "ni64 Nullable(Int64)", - "str String", - ], - ) as ctx: - if IS_PANDAS_2: - df = df.convert_dtypes(dtype_backend="pyarrow") - await test_async_client.insert_df_arrow(ctx.table, df) - res_df = await test_async_client.query(f"SELECT * from {ctx.table} ORDER BY i64") - assert res_df.result_rows == [(51, 421, 
"b"), (78, None, "a")] - else: - with pytest.raises(ProgrammingError, match="pandas 2.x"): - await test_async_client.insert_df_arrow(ctx.table, df) - - with table_context( - "df_pyarrow_insert_test", - [ - "i64 Int64", - "ni64 Nullable(Int64)", - "str String", - ], - ) as ctx: - if IS_PANDAS_2: - df = pd.DataFrame(data, columns=["i64", "ni64", "str"]) - df["i64"] = df["i64"].astype(pd.ArrowDtype(arrow.int64())) - with pytest.raises(ProgrammingError, match="Non-Arrow columns found"): - await test_async_client.insert_df_arrow(ctx.table, df) - else: - with pytest.raises(ProgrammingError, match="pandas 2.x"): - await test_async_client.insert_df_arrow(ctx.table, df) From 8bb3d0a063e11a704043ec2f9db676e7bd78da18 Mon Sep 17 00:00:00 2001 From: Joe S Date: Thu, 26 Mar 2026 15:25:33 -0700 Subject: [PATCH 40/40] update changelog --- CHANGELOG.md | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 2bacd3a8..0772e2bb 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,8 +1,5 @@ # ClickHouse Connect ChangeLog -### WARNING -- Breaking change for AsyncClient close() -The AsyncClient close() method is now async and should be called as an async function. - ### WARNING -- Python 3.8 EOL Python 3.8 was EOL on 2024-10-07. It is no longer tested, and versions after 2025-04-07 will not include Python 3.8 wheel distributions. As of version 0.8.15, wheels are not built for Python 3.8 AARCH64 versions due to @@ -23,6 +20,15 @@ The supported method of passing ClickHouse server settings is to prefix such arg ## UNRELEASED +### Breaking Changes +- Remove the legacy executor-based async client. The `AsyncClient(client=...)` constructor pattern, `executor_threads`, and `executor` parameters are no longer supported. Use `clickhouse_connect.get_async_client()` (or `create_async_client()`) which creates a native aiohttp-based async client directly. The `pool_mgr` parameter is also rejected on the async path. 
`aiohttp` remains an optional dependency, installed via `pip install clickhouse-connect[async]`. +- The internal `AiohttpAsyncClient` class has been renamed to `AsyncClient` and the module `clickhouse_connect.driver.aiohttp_client` has been removed. Import `AsyncClient` from `clickhouse_connect.driver` as before. + +### Improvements +- Lazy loading of optional dependencies (numpy, pandas, pyarrow, polars) now applies to the async client as well, matching the pattern established in 0.15.0 for the sync client. +- Clearer error message when attempting to use the async client without aiohttp installed. +- The `generic_args` parameter is now properly parsed on the async client creation path, matching the sync client behavior. + ## 0.15.0, 2026-03-26 ### Improvements