diff --git a/.docker/clickhouse/single_node/config.xml b/.docker/clickhouse/single_node/config.xml index 076ed6ce..d42d2637 100644 --- a/.docker/clickhouse/single_node/config.xml +++ b/.docker/clickhouse/single_node/config.xml @@ -21,6 +21,8 @@ 1 + 1 + system query_log
diff --git a/.docker/clickhouse/single_node_tls/config.xml b/.docker/clickhouse/single_node_tls/config.xml index 4ff6cc01..bc710317 100644 --- a/.docker/clickhouse/single_node_tls/config.xml +++ b/.docker/clickhouse/single_node_tls/config.xml @@ -21,6 +21,8 @@ 1 + 1 + /etc/clickhouse-server/certs/server.crt diff --git a/.github/workflows/clickhouse_ci.yml b/.github/workflows/clickhouse_ci.yml index 8200639e..effbd4cc 100644 --- a/.github/workflows/clickhouse_ci.yml +++ b/.github/workflows/clickhouse_ci.yml @@ -36,11 +36,11 @@ jobs: CLICKHOUSE_CONNECT_TEST_PASSWORD: ${{ secrets.INTEGRATIONS_TEAM_TESTS_CLOUD_PASSWORD_SMT_PROD }} CLICKHOUSE_CONNECT_TEST_JWT_SECRET: ${{ secrets.INTEGRATIONS_TEAM_TESTS_CLOUD_JWT_DESERT_VM_43 }} SQLALCHEMY_SILENCE_UBER_WARNING: 1 - run: pytest tests/integration_tests + run: pytest tests/integration_tests -n 4 - name: Run ClickHouse Container (HEAD) run: CLICKHOUSE_CONNECT_TEST_CH_VERSION=head docker compose up -d clickhouse - name: Run HEAD tests - run: pytest tests/integration_tests + run: pytest tests/integration_tests -n 4 - name: remove HEAD container run: docker compose down -v diff --git a/.github/workflows/on_push.yml b/.github/workflows/on_push.yml index 72092f00..58f6a9ff 100644 --- a/.github/workflows/on_push.yml +++ b/.github/workflows/on_push.yml @@ -135,7 +135,7 @@ jobs: CLICKHOUSE_CONNECT_TEST_DOCKER: 'False' CLICKHOUSE_CONNECT_TEST_FUZZ: 50 SQLALCHEMY_SILENCE_UBER_WARNING: 1 - run: pytest tests + run: pytest -n 4 tests pandas-1x-compat-test: runs-on: ubuntu-latest @@ -173,7 +173,7 @@ jobs: CLICKHOUSE_CONNECT_TEST_TLS: 1 CLICKHOUSE_CONNECT_TEST_DOCKER: 'False' SQLALCHEMY_SILENCE_UBER_WARNING: 1 - run: pytest tests/integration_tests/test_pandas_compat.py tests/integration_tests/test_pandas.py + run: pytest -n 4 tests/integration_tests/test_pandas_compat.py tests/integration_tests/test_pandas.py sqlalchemy-1x-compat-test: runs-on: ubuntu-latest @@ -211,7 +211,7 @@ jobs: CLICKHOUSE_CONNECT_TEST_TLS: 1 CLICKHOUSE_CONNECT_TEST_DOCKER: 'False' SQLALCHEMY_SILENCE_UBER_WARNING: 1 - run: pytest tests/integration_tests/test_sqlalchemy + run: pytest -n 4 tests/integration_tests/test_sqlalchemy check-secret: runs-on: ubuntu-latest @@ -264,4 +264,4 @@ jobs: CLICKHOUSE_CONNECT_TEST_PASSWORD: ${{ secrets.INTEGRATIONS_TEAM_TESTS_CLOUD_PASSWORD_SMT_PROD }} CLICKHOUSE_CONNECT_TEST_JWT_SECRET: ${{ secrets.INTEGRATIONS_TEAM_TESTS_CLOUD_JWT_DESERT_VM_43 }} SQLALCHEMY_SILENCE_UBER_WARNING: 1 - run: pytest tests/integration_tests + run: pytest -n 4 tests/integration_tests diff --git a/CHANGELOG.md b/CHANGELOG.md index 7aa4e9f8..0772e2bb 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,8 +1,5 @@ # ClickHouse Connect ChangeLog -### WARNING -- Breaking change for AsyncClient close() -The AsyncClient close() method is now async and should be called as an async function. - ### WARNING -- Python 3.8 EOL Python 3.8 was EOL on 2024-10-07. It is no longer tested, and versions after 2025-04-07 will not include Python 3.8 wheel distributions. As of version 0.8.15, wheels are not built for Python 3.8 AARCH64 versions due to @@ -23,6 +20,15 @@ The supported method of passing ClickHouse server settings is to prefix such arg ## UNRELEASED +### Breaking Changes +- Remove the legacy executor-based async client. The `AsyncClient(client=...)` constructor pattern, `executor_threads`, and `executor` parameters are no longer supported. Use `clickhouse_connect.get_async_client()` (or `create_async_client()`) which creates a native aiohttp-based async client directly. The `pool_mgr` parameter is also rejected on the async path. `aiohttp` remains an optional dependency, installed via `pip install clickhouse-connect[async]`. +- The internal `AiohttpAsyncClient` class has been renamed to `AsyncClient` and the module `clickhouse_connect.driver.aiohttp_client` has been removed. Import `AsyncClient` from `clickhouse_connect.driver` as before. + +### Improvements +- Lazy loading of optional dependencies (numpy, pandas, pyarrow, polars) now applies to the async client as well, matching the pattern established in 0.15.0 for the sync client. +- Clearer error message when attempting to use the async client without aiohttp installed. +- The `generic_args` parameter is now properly parsed on the async client creation path, matching the sync client behavior. + ## 0.15.0, 2026-03-26 ### Improvements @@ -57,7 +63,6 @@ The supported method of passing ClickHouse server settings is to prefix such arg ### Deprecations - Pandas 1.x support is now deprecated and will be removed in v1.0.0. A `DeprecationWarning` is emitted at import time for pandas 1.x users. -- The current `AsyncClient` is a thread-pool wrapper around the sync client and now emits a `FutureWarning` on creation, pointing users to the fully native async client available as a prerelease: `pip install 'clickhouse-connect[async]==0.12.0rc1'`. This prerelease branch is based on 0.11.0 and is gathering feedback ahead of 1.0.0, where it will become the default async implementation. It is a drop-in replacement with the same API surface. ### Improvements - Added support for the `SAMPLE` clause in SQLAlchemy statements. Note: Due to a SQLAlchemy limitation, only one hint (SAMPLE or FINAL) can be applied per table; chaining both will silently ignore one. For now, this change enables use of sample(), but chaining with final() is not yet supported. Closes [#634](https://github.com/ClickHouse/clickhouse-connect/issues/634) @@ -84,6 +89,9 @@ are now serialized using their native ClickHouse types client-side (e.g. inserti - Recognize `UPDATE` as a command so lightweight updates work correctly via `client.query()` and SQLAlchemy. - SQLAlchemy: `GROUP BY` now renders label aliases instead of full expressions which avoids circular reference errors when an alias shadows a source column name in ClickHouse. +## 0.12.0rc1, 2026-02-11 +- Implement a native async client. Closes [#141](https://github.com/ClickHouse/clickhouse-connect/issues/141) + ## 0.11.0, 2026-02-10 ### Python 3.9 Deprecation diff --git a/README.md b/README.md index 7d0a5977..175b42d1 100644 --- a/README.md +++ b/README.md @@ -53,13 +53,15 @@ are not implemented. The dialect is best suited for SQLAlchemy Core usage and Su ### Asyncio Support -ClickHouse Connect provides an `AsyncClient` for use in `asyncio` environments. -See the [run_async example](./examples/run_async.py) for more details. +ClickHouse Connect provides native async support using aiohttp. To use the async client, +install the optional async dependency: -The current `AsyncClient` is a thread-pool executor wrapper around the synchronous client and is deprecated. -In 1.0.0 it will be replaced by a fully native async implementation. The API surface is the same, -with one difference: you will no longer be able to create a sync client first and pass it to the -`AsyncClient` constructor. Instead, use `clickhouse_connect.get_async_client()` directly. +``` +pip install clickhouse-connect[async] +``` + +Then create a client with `clickhouse_connect.get_async_client()`. See the +[run_async example](./examples/run_async.py) for more details. ### Complete Documentation diff --git a/clickhouse_connect/common.py b/clickhouse_connect/common.py index 04f43524..49a3646c 100644 --- a/clickhouse_connect/common.py +++ b/clickhouse_connect/common.py @@ -29,7 +29,7 @@ class CommonSetting: _common_settings: Dict[str, CommonSetting] = {} -def build_client_name(client_name: str) -> str: +def build_client_name(client_name: Optional[str]) -> str: product_name = get_setting('product_name') product_name = product_name.strip() + ' ' if product_name else '' client_name = client_name.strip() + ' ' if client_name else '' diff --git a/clickhouse_connect/driver/__init__.py b/clickhouse_connect/driver/__init__.py index aac2f85e..79caa705 100644 --- a/clickhouse_connect/driver/__init__.py +++ b/clickhouse_connect/driver/__init__.py @@ -1,20 +1,88 @@ -import asyncio -import warnings -from concurrent.futures import ThreadPoolExecutor +from __future__ import annotations + from inspect import signature -from typing import Optional, Union, Dict, Any +from typing import Optional, Union, Dict, Any, Tuple, TYPE_CHECKING from urllib.parse import urlparse, parse_qs -import clickhouse_connect.driver.ctypes +import clickhouse_connect.driver.ctypes # noqa: F401 -- side-effect import from clickhouse_connect.driver.client import Client -from clickhouse_connect.driver.common import dict_copy from clickhouse_connect.driver.exceptions import ProgrammingError from clickhouse_connect.driver.httpclient import HttpClient -from clickhouse_connect.driver.asyncclient import AsyncClient, DefaultThreadPoolExecutor, NEW_THREAD_POOL_EXECUTOR + +if TYPE_CHECKING: + from clickhouse_connect.driver.asyncclient import AsyncClient __all__ = ['Client', 'AsyncClient', 'create_client', 'create_async_client'] +def __getattr__(name): + if name == "AsyncClient": + try: + from clickhouse_connect.driver.asyncclient import AsyncClient # pylint: disable=import-outside-toplevel + except ModuleNotFoundError as ex: + if ex.name == "aiohttp" or (ex.name and ex.name.startswith("aiohttp.")): # pylint: disable=no-member + raise ImportError( + "Async support requires aiohttp. Install with: pip install clickhouse-connect[async]" + ) from ex + raise + return AsyncClient + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") + + +def default_port(interface: str, secure: bool) -> int: + """Get default port for the given interface.""" + if interface.startswith("http"): + return 8443 if secure else 8123 + raise ValueError("Unrecognized ClickHouse interface") + + +def _parse_connection_params( + host: Optional[str], + username: Optional[str], + password: str, + port: int, + database: str, + interface: Optional[str], + secure: Union[bool, str], + dsn: Optional[str], + kwargs: Dict[str, Any] +) -> Tuple[str, Optional[str], str, int, str, str]: + """Parse and normalize connection parameters including DSN parsing.""" + if dsn: + parsed = urlparse(dsn) + username = username or parsed.username + password = password or parsed.password + host = host or parsed.hostname + port = port or parsed.port + if parsed.path and (not database or database == "__default__"): + database = parsed.path[1:].split("/")[0] + database = database or parsed.path + for k, v in parse_qs(parsed.query).items(): + kwargs[k] = v[0] + use_tls = str(secure).lower() == "true" or interface == "https" or (not interface and str(port) in ("443", "8443")) + if not host: + host = "localhost" + if not interface: + interface = "https" if use_tls else "http" + port = port or default_port(interface, use_tls) + if username is None and "user" in kwargs: + username = kwargs.pop("user") + if username is None and "user_name" in kwargs: + username = kwargs.pop("user_name") + if password and username is None: + username = "default" + if "compression" in kwargs and "compress" not in kwargs: + kwargs["compress"] = kwargs.pop("compression") + + return host, username, password, port, database, interface + + +def _validate_access_token(access_token: Optional[str], username: Optional[str], password: str) -> None: + """Validate that access_token and username/password are not both provided.""" + if access_token and (username or password != ""): + raise ProgrammingError("Cannot use both access_token and username/password") + + # pylint: disable=too-many-arguments,too-many-locals,too-many-branches def create_client(*, host: Optional[str] = None, @@ -94,33 +162,11 @@ def create_client(*, limits. Only available for query operations (not inserts). Default: False :return: ClickHouse Connect Client instance """ - if dsn: - parsed = urlparse(dsn) - username = username or parsed.username - password = password or parsed.password - host = host or parsed.hostname - port = port or parsed.port - if parsed.path and (not database or database == '__default__'): - database = parsed.path[1:].split('/')[0] - database = database or parsed.path - for k, v in parse_qs(parsed.query).items(): - kwargs[k] = v[0] - use_tls = str(secure).lower() == 'true' or interface == 'https' or (not interface and str(port) in ('443', '8443')) - if not host: - host = 'localhost' - if not interface: - interface = 'https' if use_tls else 'http' - port = port or default_port(interface, use_tls) - if access_token and (username or password != ''): - raise ProgrammingError('Cannot use both access_token and username/password') - if username is None and 'user' in kwargs: - username = kwargs.pop('user') - if username is None and 'user_name' in kwargs: - username = kwargs.pop('user_name') - if password and username is None: - username = 'default' - if 'compression' in kwargs and 'compress' not in kwargs: - kwargs['compress'] = kwargs.pop('compression') + host, username, password, port, database, interface = _parse_connection_params( + host, username, password, port, database, interface, secure, dsn, kwargs + ) + _validate_access_token(access_token, username, password) + settings = settings or {} if interface.startswith('http'): if generic_args: @@ -140,16 +186,11 @@ def create_client(*, raise ProgrammingError(f'Unrecognized client type {interface}') -def default_port(interface: str, secure: bool): - if interface.startswith('http'): - return 8443 if secure else 8123 - raise ValueError('Unrecognized ClickHouse interface') - - async def create_async_client(*, host: Optional[str] = None, username: Optional[str] = None, password: str = '', + access_token: Optional[str] = None, database: str = '__default__', interface: Optional[str] = None, port: int = 0, @@ -157,11 +198,14 @@ async def create_async_client(*, dsn: Optional[str] = None, settings: Optional[Dict[str, Any]] = None, generic_args: Optional[Dict[str, Any]] = None, - executor_threads: int = 0, - executor: Union[ThreadPoolExecutor, None, DefaultThreadPoolExecutor] = NEW_THREAD_POOL_EXECUTOR, + connector_limit: int = 100, + connector_limit_per_host: int = 20, + keepalive_timeout: float = 30.0, **kwargs) -> AsyncClient: """ The preferred method to get an async ClickHouse Connect Client instance. + Requires the async extra: pip install clickhouse-connect[async] + For sync version, see create_client. Unlike sync version, the 'autogenerate_session_id' setting by default is False. @@ -169,6 +213,7 @@ async def create_async_client(*, :param host: The hostname or IP address of the ClickHouse server. If not set, localhost will be used. :param username: The ClickHouse username. If not set, the default ClickHouse user will be used. :param password: The password for username. + :param access_token: JWT access token. :param database: The default database for the connection. If not set, ClickHouse Connect will use the default database for username. :param interface: Must be http or https. Defaults to http, or to https if port is set to 8443 or 443 @@ -180,12 +225,10 @@ async def create_async_client(*, :param settings: ClickHouse server settings to be used with the session/every request :param generic_args: Used internally to parse DBAPI connection strings into keyword arguments and ClickHouse settings. It is not recommended to use this parameter externally - :param executor_threads: 'max_worker' threads used by the client ThreadPoolExecutor. If not set, the default - of 4 + detected CPU cores will be used - :param executor: Optional `ThreadPoolExecutor` to use for async operations. If not set, a new `ThreadPoolExecutor` - will be created with the number of threads specified by `executor_threads`. If set to `None` it will use the - default executor of the event loop. - :param kwargs -- Recognized keyword arguments (used by the HTTP client), see below + :param connector_limit: Maximum number of allowable connections to the server + :param connector_limit_per_host: Maximum number of connections per host + :param keepalive_timeout: Time limit on idle keepalive connections + :param kwargs -- Recognized keyword arguments (used by the async HTTP client), see below :param compress: Enable compression for ClickHouse HTTP inserts and query results. True will select the preferred compression method (lz4). A str of 'lz4', 'zstd', 'brotli', or 'gzip' can be used to use a specific compression type @@ -194,7 +237,6 @@ async def create_async_client(*, :param send_receive_timeout: Read timeout in seconds for http connection :param client_name: client_name prepended to the HTTP User Agent header. Set this to track client queries in the ClickHouse system.query_log. - :param send_progress: Deprecated, has no effect. Previous functionality is now automatically determined :param verify: Verify the server certificate in secure/https mode :param ca_cert: If verify is True, the file path to Certificate Authority root to validate ClickHouse server certificate, in .pem format. Ignored if verify is False. This is not necessary if the ClickHouse server @@ -206,8 +248,6 @@ async def create_async_client(*, is not included the Client Certificate key file :param session_id ClickHouse session id. If not specified and the common setting 'autogenerate_session_id' is True, the client will generate a UUID1 session id - :param pool_mgr Optional urllib3 PoolManager for this client. Useful for creating separate connection - pools for multiple client endpoints for applications with many clients :param http_proxy http proxy address. Equivalent to setting the HTTP_PROXY environment variable :param https_proxy https proxy address. Equivalent to setting the HTTPS_PROXY environment variable :param server_host_name This is the server host name that will be checked against a TLS certificate for @@ -226,27 +266,49 @@ async def create_async_client(*, :param form_encode_query_params If True, query parameters will be sent as form-encoded data in the request body instead of as URL parameters. This is useful for queries with large parameter sets that might exceed URL length limits. Only available for query operations (not inserts). Default: False - :return: ClickHouse Connect Client instance + :return: ClickHouse Connect AsyncClient instance """ + try: + from clickhouse_connect.driver.asyncclient import AsyncClient as _AsyncClient # pylint: disable=import-outside-toplevel + except ModuleNotFoundError as ex: + if ex.name == "aiohttp" or (ex.name and ex.name.startswith("aiohttp.")): # pylint: disable=no-member + raise ImportError( + "Async support requires aiohttp. Install with: pip install clickhouse-connect[async]" + ) from ex + raise - warnings.warn( - "The current async client is a thread-pool wrapper around the sync client. " - "A fully native async client is available for testing as a prerelease: " - "pip install 'clickhouse-connect[async]==0.12.0rc1'. " - "This prerelease branch is based on 0.11.0 and is gathering feedback ahead of 1.0.0, " - "where it will become the default async implementation. It is a drop-in replacement " - "with the same API surface. The main line includes additional updates that the native " - "client will receive when merged into 1.0.0.", - FutureWarning, - stacklevel=2, + if "pool_mgr" in kwargs: + raise ProgrammingError( + "pool_mgr is not supported by the async client. " + "Use connector_limit and connector_limit_per_host to configure connection pooling." + ) + + host, username, password, port, database, interface = _parse_connection_params( + host, username, password, port, database, interface, secure, dsn, kwargs ) + _validate_access_token(access_token, username, password) + + settings = settings or {} + if generic_args: + client_params = signature(_AsyncClient).parameters + for name, value in generic_args.items(): + if name in client_params: + kwargs[name] = value + elif name == "compression": + if "compress" not in kwargs: + kwargs["compress"] = value + else: + if name.startswith("ch_"): + name = name[3:] + settings[name] = value - def _create_client(): - if 'autogenerate_session_id' not in kwargs: - kwargs['autogenerate_session_id'] = False - return create_client(host=host, username=username, password=password, database=database, interface=interface, - port=port, secure=secure, dsn=dsn, settings=settings, generic_args=generic_args, **kwargs) + if "autogenerate_session_id" not in kwargs: + kwargs["autogenerate_session_id"] = False - loop = asyncio.get_running_loop() - _client = await loop.run_in_executor(None, _create_client) - return AsyncClient(client=_client, executor_threads=executor_threads, executor=executor) + client = _AsyncClient(interface=interface, host=host, port=port, username=username, password=password, + database=database, access_token=access_token, + settings=settings, + connector_limit=connector_limit, connector_limit_per_host=connector_limit_per_host, + keepalive_timeout=keepalive_timeout, **kwargs) + await client._initialize() # pylint: disable=protected-access + return client diff --git a/clickhouse_connect/driver/asyncclient.py b/clickhouse_connect/driver/asyncclient.py index 3f046042..5fc11464 100644 --- a/clickhouse_connect/driver/asyncclient.py +++ b/clickhouse_connect/driver/asyncclient.py @@ -1,395 +1,1053 @@ +# pylint: disable=too-many-lines,duplicate-code,import-error + import asyncio +import gzip import io +import json import logging -import os -from concurrent.futures.thread import ThreadPoolExecutor +import re +import ssl +import sys +import time +import uuid +import pytz +import zlib +from base64 import b64encode from datetime import tzinfo -from typing import Literal, Optional, Union, Dict, Any, Sequence, Iterable, Generator, BinaryIO, TYPE_CHECKING +from importlib import import_module +from importlib.metadata import version as dist_version +from typing import Any, BinaryIO, Dict, Generator, Iterable, List, Literal, Optional, Sequence, Union -from clickhouse_connect.driver.client import Client -from clickhouse_connect.driver.query import TzMode -from clickhouse_connect.driver.common import StreamContext -from clickhouse_connect.driver.httpclient import HttpClient -from clickhouse_connect.driver.external import ExternalData -from clickhouse_connect.driver.query import QueryContext, QueryResult -from clickhouse_connect.driver.summary import QuerySummary +import aiohttp +import lz4.frame +import zstandard + +from clickhouse_connect import common +from clickhouse_connect.datatypes import dynamic as dynamic_module from clickhouse_connect.datatypes.base import ClickHouseType +from clickhouse_connect.datatypes.registry import get_from_name +from clickhouse_connect.driver import httputil, tzutil +from clickhouse_connect.driver.binding import bind_query, quote_identifier +from clickhouse_connect.driver.client import Client, _apply_arrow_tz_policy +from clickhouse_connect.driver.common import StreamContext, coerce_bool, dict_copy +from clickhouse_connect.driver.compression import available_compression +from clickhouse_connect.driver.constants import CH_VERSION_WITH_PROTOCOL, PROTOCOL_VERSION_WITH_LOW_CARD +from clickhouse_connect.driver.ctypes import RespBuffCls +from clickhouse_connect.driver.exceptions import DatabaseError, DataError, OperationalError, ProgrammingError +from clickhouse_connect.driver.external import ExternalData from clickhouse_connect.driver.insert import InsertContext - -if TYPE_CHECKING: - import numpy - import pandas - import pyarrow +from clickhouse_connect.driver.models import ColumnDef, SettingDef +from clickhouse_connect.driver import options +from clickhouse_connect.driver.options import check_arrow, check_numpy, check_pandas, check_polars +from clickhouse_connect.driver.query import QueryContext, QueryResult, TzMode, TzSource, arrow_buffer +from clickhouse_connect.driver.summary import QuerySummary +from clickhouse_connect.driver.asyncqueue import AsyncSyncQueue, EOF_SENTINEL +from clickhouse_connect.driver.streaming import StreamingInsertSource +from clickhouse_connect.driver.transform import NativeTransform +from clickhouse_connect.driver.streaming import StreamingResponseSource, StreamingFileAdapter logger = logging.getLogger(__name__) +columns_only_re = re.compile(r"LIMIT 0\s*$", re.IGNORECASE) +ex_header = "X-ClickHouse-Exception-Code" +ex_tag_header = "X-ClickHouse-Exception-Tag" + +if "br" in available_compression: + import brotli +else: + brotli = None + +def decompress_response(data: bytes, encoding: Optional[str]) -> bytes: + """Decompress response data based on Content-Encoding header.""" + + if not encoding or encoding == "identity": + return data + + if encoding == "lz4": + lz4_decom = lz4.frame.LZ4FrameDecompressor() + return lz4_decom.decompress(data, len(data)) + if encoding == "zstd": + zstd_decom = zstandard.ZstdDecompressor() + return zstd_decom.stream_reader(io.BytesIO(data)).read() + if encoding == "br": + if brotli is not None: + return brotli.decompress(data) + raise OperationalError("Brotli compression requested but not installed.") + if encoding == "gzip": + return gzip.decompress(data) + if encoding == "deflate": + return zlib.decompress(data) + raise OperationalError(f"Unsupported compression type: '{encoding}'. Supported compression: {', '.join(available_compression)}") + + +class BytesSource: + """Wrapper to make bytes compatible with ResponseBuffer expectations.""" + + def __init__(self, data: bytes): + self.data = data + self.gen = self._make_generator() + + def _make_generator(self): + yield self.data + + def close(self): + """No-op close method for compatibility.""" + +# pylint: disable=invalid-overridden-method, too-many-instance-attributes, too-many-public-methods, broad-exception-caught +class AsyncClient(Client): + valid_transport_settings = {"database", "buffer_size", "session_id", + "compress", "decompress", "session_timeout", + "session_check", "query_id", "quota_key", + "wait_end_of_query", "client_protocol_version", + "role"} + optional_transport_settings = {"send_progress_in_http_headers", + "http_headers_progress_interval_ms", + "enable_http_compression"} + + # pylint: disable=too-many-arguments, too-many-positional-arguments, too-many-locals, too-many-branches, too-many-statements + def __init__( + self, + interface: str, + host: str, + port: int, + username: Optional[str] = None, + password: Optional[str] = None, + database: Optional[str] = None, + access_token: Optional[str] = None, + compress: Union[bool, str] = True, + connect_timeout: int = 10, + send_receive_timeout: int = 300, + client_name: Optional[str] = None, + verify: Union[bool, str] = True, + ca_cert: Optional[str] = None, + client_cert: Optional[str] = None, + client_cert_key: Optional[str] = None, + http_proxy: Optional[str] = None, + https_proxy: Optional[str] = None, + server_host_name: Optional[str] = None, + tls_mode: Optional[str] = None, + proxy_path: str = "", + connector_limit: int = 100, + connector_limit_per_host: int = 20, + keepalive_timeout: float = 30.0, + session_id: Optional[str] = None, + settings: Optional[Dict[str, Any]] = None, + query_limit: int = 0, + query_retries: int = 2, + tz_source: Optional[TzSource] = None, + tz_mode: Optional[TzMode] = None, + apply_server_timezone: Optional[Union[str, bool]] = None, + utc_tz_aware: Optional[Union[bool, Literal["schema"]]] = None, + show_clickhouse_errors: Optional[bool] = None, + autogenerate_session_id: Optional[bool] = None, + autogenerate_query_id: Optional[bool] = None, + form_encode_query_params: bool = False, + **kwargs, + ): + """ + Async HTTP Client using aiohttp. Initialization is handled via _initialize(). + """ + proxy_path = proxy_path.lstrip("/") + if proxy_path: + proxy_path = "/" + proxy_path + self.uri = f"{interface}://{host}:{port}{proxy_path}" + self.url = self.uri + self.form_encode_query_params = form_encode_query_params + self._rename_response_column = kwargs.get("rename_response_column") + self._initial_settings = settings + self.headers = {} + + if interface == "https": + if isinstance(verify, str) and verify.lower() == "proxy": + verify = True + tls_mode = tls_mode or "proxy" + + # Priority: access_token > mutual TLS > basic auth + if client_cert and (tls_mode is None or tls_mode == "mutual"): + if not username: + raise ProgrammingError("username parameter is required for Mutual TLS authentication") + self.headers["X-ClickHouse-User"] = username + self.headers["X-ClickHouse-SSL-Certificate-Auth"] = "on" + elif access_token: + self.headers["Authorization"] = f"Bearer {access_token}" + elif username and (not client_cert or tls_mode in ("strict", "proxy")): + credentials = b64encode(f"{username}:{password}".encode()).decode() + self.headers["Authorization"] = f"Basic {credentials}" + + self.headers["User-Agent"] = common.build_client_name(client_name) + # Prevent aiohttp from automatically requesting compressed responses + # We'll manually set Accept-Encoding when compression is desired + self.headers["Accept-Encoding"] = "identity" + self._send_receive_timeout = send_receive_timeout + + connect_timeout_val = float(connect_timeout) if connect_timeout is not None else None + send_receive_timeout_val = float(send_receive_timeout) if send_receive_timeout is not None else None + + self._timeout = aiohttp.ClientTimeout( + total=None, + connect=connect_timeout_val, + sock_connect=connect_timeout_val, + sock_read=send_receive_timeout_val, + ) + connector_limit_per_host = min(connector_limit_per_host, connector_limit) + + proxy_url = None + if http_proxy: + if not http_proxy.startswith("http://") and not http_proxy.startswith("https://"): + proxy_url = f"http://{http_proxy}" + else: + proxy_url = http_proxy + elif https_proxy: + if not https_proxy.startswith("http://") and not https_proxy.startswith("https://"): + proxy_url = f"http://{https_proxy}" + else: + proxy_url = https_proxy + else: + scheme = "https" if self.url.startswith("https://") else "http" + env_proxy = httputil.check_env_proxy(scheme, host, port) + if env_proxy: + if not env_proxy.startswith("http://") and not env_proxy.startswith("https://"): + proxy_url = f"http://{env_proxy}" + else: + proxy_url = env_proxy + + ssl_context = None + if interface == "https": + ssl_context = ssl.create_default_context() + ssl_verify = verify if isinstance(verify, bool) else coerce_bool(verify) + if not ssl_verify: + ssl_context.check_hostname = False + ssl_context.verify_mode = ssl.CERT_NONE + elif ca_cert: + ssl_context.load_verify_locations(ca_cert) + if client_cert: + ssl_context.load_cert_chain(client_cert, client_cert_key) + + self._ssl_context = ssl_context + self._proxy_url = proxy_url + self._connector_kwargs = { + "limit": connector_limit, + "limit_per_host": connector_limit_per_host, + "keepalive_timeout": keepalive_timeout, + "force_close": False, + "ssl": ssl_context, + } + # enable_cleanup_closed is only needed for Python < 3.12.7 or == 3.13.0 + # The underlying SSL connection leak was fixed in 3.12.7 and 3.13.1+ + # https://github.com/python/cpython/pull/118960 + if sys.version_info < (3, 12, 7) or sys.version_info[:3] == (3, 13, 0): + self._connector_kwargs["enable_cleanup_closed"] = True + + self._session = None + self._read_format = "Native" + self._write_format = "Native" + self._transform = NativeTransform() + self._client_settings = {} + self._initialized = False + self._reported_libs = set() + self._last_pool_reset = None + self.headers["User-Agent"] = self.headers["User-Agent"].replace("mode:sync;", "mode:async;") + + # Store aiohttp-specific params for deferred initialization + self._compress_param = compress + self._session_id_param = session_id + self._autogenerate_session_id_param = autogenerate_session_id + self._autogenerate_query_id = ( + common.get_setting("autogenerate_query_id") if autogenerate_query_id is None else autogenerate_query_id + ) + self._active_session = None + self._send_progress = None + self._progress_interval = None + + # Call parent init with autoconnect=False to set up config without blocking I/O + super().__init__( + database=database, + query_limit=query_limit, + uri=self.uri, + query_retries=query_retries, + server_host_name=server_host_name, + tz_source=tz_source, + tz_mode=tz_mode, + utc_tz_aware=utc_tz_aware, + apply_server_timezone=apply_server_timezone, + show_clickhouse_errors=show_clickhouse_errors, + autoconnect=False + ) + + # pylint: disable=attribute-defined-outside-init + async def _initialize(self): + """ + Async equivalent of Client._init_common_settings. + Fetches server version, timezone, and settings. + """ + if not self._session: + connector = aiohttp.TCPConnector(**self._connector_kwargs) + self._session = aiohttp.ClientSession( + connector=connector, + timeout=self._timeout, + headers=self.headers, + trust_env=False, + auto_decompress=False, + skip_auto_headers={"Accept-Encoding"}, + ) + if self._initialized: + return + + try: + tz_source = self._deferred_tz_source + + self.server_tz, self._dst_safe = pytz.UTC, True + row = await self.command("SELECT version(), timezone()", use_database=False) + self.server_version, server_tz_str = tuple(row) + try: + server_tz = pytz.timezone(server_tz_str) + server_tz, self._dst_safe = tzutil.normalize_timezone(server_tz) + if tz_source == "auto": + self._apply_server_tz = self._dst_safe + else: + self._apply_server_tz = tz_source == "server" + self.server_tz = server_tz + except pytz.exceptions.UnknownTimeZoneError: + logger.warning("Warning, server is using an unrecognized timezone %s, will use UTC default", server_tz_str) + + if not self._apply_server_tz and not tzutil.local_tz_dst_safe: + logger.warning("local timezone %s may return unexpected times due to Daylight Savings Time", tzutil.local_tz.tzname(None)) + + readonly = "readonly" + if not self.min_version("19.17"): + readonly = common.get_setting("readonly") + + server_settings = await self.query(f"SELECT name, value, {readonly} as readonly FROM system.settings LIMIT 10000") + self.server_settings = {row["name"]: SettingDef(**row) for row in server_settings.named_results()} + + if self.min_version(CH_VERSION_WITH_PROTOCOL) and common.get_setting("use_protocol_version"): + try: + test_data = await self.raw_query( + "SELECT 1 AS check", fmt="Native", settings={"client_protocol_version": PROTOCOL_VERSION_WITH_LOW_CARD} + ) + if test_data[8:16] == b"\x01\x01\x05check": + self.protocol_version = PROTOCOL_VERSION_WITH_LOW_CARD + except Exception: + pass + + cancel_setting = self._setting_status("cancel_http_readonly_queries_on_client_close") + if cancel_setting.is_writable and not cancel_setting.is_set and \ + "cancel_http_readonly_queries_on_client_close" not in (self._initial_settings or {}): + self._client_settings["cancel_http_readonly_queries_on_client_close"] = "1" + + if self._initial_settings: + for key, value in self._initial_settings.items(): + self.set_client_setting(key, value) + + compress = self._compress_param + if coerce_bool(compress): + compression = ",".join(available_compression) + self.write_compression = available_compression[0] + elif compress and compress not in ("False", "false", "0"): + if compress not in available_compression: + raise ProgrammingError(f"Unsupported compression method {compress}") + compression = compress + self.write_compression = compress + else: + compression = None + + comp_setting = self._setting_status("enable_http_compression") + self._send_comp_setting = not comp_setting.is_set and comp_setting.is_writable # pylint: disable=attribute-defined-outside-init + if comp_setting.is_set or comp_setting.is_writable: + self.compression = compression + + session_id = self._session_id_param + autogenerate_session_id = self._autogenerate_session_id_param + + if autogenerate_session_id is None: + autogenerate_session_id = common.get_setting("autogenerate_session_id") + + if session_id: + self.set_client_setting("session_id", session_id) + elif self.get_client_setting("session_id"): + pass + elif autogenerate_session_id: + self.set_client_setting("session_id", str(uuid.uuid4())) + + send_setting = self._setting_status("send_progress_in_http_headers") + self._send_progress = not send_setting.is_set and send_setting.is_writable + if (send_setting.is_set or send_setting.is_writable) and self._setting_status("http_headers_progress_interval_ms").is_writable: + self._progress_interval = str(min(120000, max(10000, (self._send_receive_timeout - 5) * 1000))) + + if self._setting_status("date_time_input_format").is_writable: + self.set_client_setting("date_time_input_format", "best_effort") + if ( + self._setting_status("allow_experimental_json_type").is_set + and self._setting_status("cast_string_to_dynamic_use_inference").is_writable + ): + self.set_client_setting("cast_string_to_dynamic_use_inference", "1") + if self.min_version("24.8") and not self.min_version("24.10"): + dynamic_module.json_serialization_format = 0 + + self._initialized = True + except Exception: + if self._session and not self._session.closed: + await self._session.close() + self._session = None + raise + + async def __aenter__(self): + """Async context manager entry.""" + if not self._initialized: + await self._initialize() + return self -class DefaultThreadPoolExecutor: - pass - + async def __aexit__(self, exc_type, exc_val, exc_tb): + """Async context manager exit.""" + await self.close() + return False + + async def close(self): # type: ignore[override] + if self._session: + await self._session.close() + + async def close_connections(self): # type: ignore[override] + """Close all pooled connections and recreate session""" + if self._session: + await self._session.close() + connector = aiohttp.TCPConnector(**self._connector_kwargs) + self._session = aiohttp.ClientSession( + connector=connector, + timeout=self._timeout, + headers=self.headers, + trust_env=False, + auto_decompress=False, + skip_auto_headers={"Accept-Encoding"}, + ) + + def set_client_setting(self, key, value): + str_value = self._validate_setting(key, value, common.get_setting("invalid_setting_action")) + if str_value is not None: + self._client_settings[key] = str_value + + def get_client_setting(self, key) -> Optional[str]: + return self._client_settings.get(key) + + def set_access_token(self, access_token: str): + auth_header = self.headers.get("Authorization") + if auth_header and not auth_header.startswith("Bearer"): + raise ProgrammingError("Cannot set access token when a different auth type is used") + self.headers["Authorization"] = f"Bearer {access_token}" + if self._session: + self._session.headers["Authorization"] = f"Bearer {access_token}" + + def _prep_query(self, context: QueryContext): + final_query = super()._prep_query(context) + if context.is_insert: + return final_query + fmt = f"\n FORMAT {self._read_format}" + if isinstance(final_query, bytes): + return final_query + fmt.encode() + return final_query + fmt + + async def _query_with_context(self, context: QueryContext) -> QueryResult: # type: ignore[override] + headers = {} + params = {} + if self.database: + params["database"] = self.database + if self.protocol_version: + params["client_protocol_version"] = self.protocol_version + context.block_info = True + params.update(self._validate_settings(context.settings)) + context.rename_response_column = self._rename_response_column + + if not context.is_insert and columns_only_re.search(context.uncommented_query): + fmt_json_query = f"{context.final_query}\n FORMAT JSON" + fields = {"query": fmt_json_query} + fields.update(context.bind_params) + + if self.form_encode_query_params: + files = {} + if context.external_data: + params.update(context.external_data.query_params) + files.update(context.external_data.form_data) + + for k, v in fields.items(): + files[k] = (None, str(v)) + response = await self._raw_request(None, params, headers, files=files, retries=self.query_retries) + elif context.external_data: + params.update(context.bind_params) + params.update(context.external_data.query_params) + params["query"] = fmt_json_query + response = await self._raw_request(None, params, headers, files=context.external_data.form_data, retries=self.query_retries) + else: + params.update(context.bind_params) + response = await self._raw_request(fmt_json_query, params, headers, retries=self.query_retries) + + body = await response.read() + encoding = response.headers.get("Content-Encoding") + loop = asyncio.get_running_loop() + + def decompress_and_parse_json(): + if encoding: + decompressed_body = decompress_response(body, encoding) + else: + decompressed_body = body + return json.loads(decompressed_body) + + # Offload to executor + json_result = await loop.run_in_executor(None, decompress_and_parse_json) + + names: List[str] = [] + types: List[ClickHouseType] = [] + renamer = context.column_renamer + for col in json_result["meta"]: + name = col["name"] + if renamer is not None: + try: + name = renamer(name) + except Exception as e: + logger.debug("Failed to rename col '%s'. Skipping rename. Error: %s", name, e) + names.append(name) + types.append(get_from_name(col["type"])) + return QueryResult([], None, tuple(names), tuple(types)) + + if self.compression: + headers["Accept-Encoding"] = self.compression + if self._send_comp_setting: + params["enable_http_compression"] = "1" + + final_query = self._prep_query(context) + + files = None + data = None + + if self.form_encode_query_params: + fields = {"query": final_query} + fields.update(context.bind_params) + + files = {} + if context.external_data: + params.update(context.external_data.query_params) + files.update(context.external_data.form_data) + + for k, v in fields.items(): + files[k] = (None, str(v)) + elif context.external_data: + params.update(context.bind_params) + params.update(context.external_data.query_params) + params["query"] = final_query + files = context.external_data.form_data + else: + params.update(context.bind_params) + data = final_query + headers["Content-Type"] = "text/plain; charset=utf-8" -# Sentinel value to preserve default behavior and also allow passing `None` -NEW_THREAD_POOL_EXECUTOR = DefaultThreadPoolExecutor() + headers = dict_copy(headers, context.transport_settings) + response = await self._raw_request(data, params, headers, files=files, + server_wait=not context.streaming, + stream=True, retries=self.query_retries) + encoding = response.headers.get("Content-Encoding") + tz_header = response.headers.get("X-ClickHouse-Timezone") + exception_tag = response.headers.get(ex_tag_header) -# pylint: disable=too-many-public-methods,too-many-instance-attributes,too-many-arguments,too-many-positional-arguments,too-many-locals -class AsyncClient: - """ - AsyncClient is a wrapper around the ClickHouse Client object that allows for async calls to the ClickHouse server. - Internally, each of the methods that uses IO is wrapped in a call to EventLoop.run_in_executor. - """ + loop = asyncio.get_running_loop() + streaming_source = StreamingResponseSource(response, encoding=encoding, exception_tag=exception_tag) + await streaming_source.start_producer(loop) + + def parse_streaming(): + """Parse response from streaming queue (runs in executor).""" + # Wrap streaming source with ResponseBuffer. The streaming source provides a + # .gen property that yields decompressed chunks. + byte_source = RespBuffCls(streaming_source) + context.set_response_tz(self._check_tz_change(tz_header)) + result = self._transform.parse_response(byte_source, context) + + # For Pandas/Numpy, we must materialize in the executor because the resulting objects + # (DataFrame, Array) are fully in-memory structures. + # For standard queries, we return a lazy QueryResult. Accessing .result_set on the event loop + # will raise a ProgrammingError (deadlock check), encouraging usage of .rows_stream. + if not context.streaming: + if context.as_pandas and hasattr(result, 'df_result'): + _ = result.df_result + elif context.use_numpy and hasattr(result, 'np_result'): + _ = result.np_result + elif isinstance(result, QueryResult): + _ = result.result_set + + return result + + # Run parser in executor (pulls from queue, decompresses & parses) + try: + query_result = await loop.run_in_executor(None, parse_streaming) + except Exception: + await streaming_source.aclose() + raise + query_result.summary = self._summary(response) + + # Attach streaming_source to query_result.source to ensure it gets closed + # when the query result is closed (e.g. by StreamContext.__exit__) + query_result.source = streaming_source + + return query_result + + + # pylint: disable=arguments-differ + async def query( # type: ignore[override] + self, + query: Optional[str] = None, + parameters: Optional[Union[Sequence, Dict[str, Any]]] = None, + settings: Optional[Dict[str, Any]] = None, + query_formats: Optional[Dict[str, str]] = None, + column_formats: Optional[Dict[str, Union[str, Dict[str, str]]]] = None, + encoding: Optional[str] = None, + use_none: Optional[bool] = None, + column_oriented: Optional[bool] = None, + use_numpy: Optional[bool] = None, + max_str_len: Optional[int] = None, + context: Optional[QueryContext] = None, + query_tz: Optional[Union[str, tzinfo]] = None, + column_tzs: Optional[Dict[str, Union[str, tzinfo]]] = None, + utc_tz_aware: Optional[bool] = None, + external_data: Optional[ExternalData] = None, + transport_settings: Optional[Dict[str, str]] = None, + tz_mode: Optional[TzMode] = None, + ) -> QueryResult: + """ + Main query method for SELECT, DESCRIBE and other SQL statements that return a result matrix. For + parameters, see the create_query_context method + :return: QueryResult -- data and metadata from response + """ + if query and query.lower().strip().startswith("select __connect_version__"): + return QueryResult( + [[f"ClickHouse Connect v.{common.version()} ⓒ ClickHouse Inc."]], None, ("connect_version",), (get_from_name("String"),) + ) + if not context: + context = self.create_query_context( + query=query, + parameters=parameters, + settings=settings, + query_formats=query_formats, + column_formats=column_formats, + encoding=encoding, + use_none=use_none, + column_oriented=column_oriented, + use_numpy=use_numpy, + max_str_len=max_str_len, + query_tz=query_tz, + column_tzs=column_tzs, + utc_tz_aware=utc_tz_aware, + external_data=external_data, + transport_settings=transport_settings, + tz_mode=tz_mode, + ) - def __init__(self, - *, - client: Client, - executor_threads: int = 0, - executor: Union[ThreadPoolExecutor, None, DefaultThreadPoolExecutor] = NEW_THREAD_POOL_EXECUTOR): - if isinstance(client, HttpClient): - client.headers['User-Agent'] = client.headers['User-Agent'].replace('mode:sync;', 'mode:async;') - self.client = client - if executor_threads == 0: - executor_threads = min(32, (os.cpu_count() or 1) + 4) # Mimic the default behavior - if executor is NEW_THREAD_POOL_EXECUTOR: - self.new_executor = True - self.executor = ThreadPoolExecutor(max_workers=executor_threads) - else: - if executor_threads != 0: - logger.warning('executor_threads parameter is ignored when passing an executor object') + if context.is_command: + response = await self.command( + query, + parameters=context.parameters, + settings=context.settings, + external_data=context.external_data, + transport_settings=context.transport_settings, + ) + if isinstance(response, QuerySummary): + return response.as_query_result() + return QueryResult([response] if isinstance(response, list) else [[response]]) - self.new_executor = False - self.executor = executor + return await self._query_with_context(context) - def set_client_setting(self, key: str, value: Any) -> None: + async def query_column_block_stream( # type: ignore[override] + self, + query: Optional[str] = None, + parameters: Optional[Union[Sequence, Dict[str, Any]]] = None, + settings: Optional[Dict[str, Any]] = None, + query_formats: Optional[Dict[str, str]] = None, + column_formats: Optional[Dict[str, Union[str, Dict[str, str]]]] = None, + encoding: Optional[str] = None, + use_none: Optional[bool] = None, + context: Optional[QueryContext] = None, + query_tz: Optional[Union[str, tzinfo]] = None, + column_tzs: Optional[Dict[str, Union[str, tzinfo]]] = None, + utc_tz_aware: Optional[bool] = None, + external_data: Optional[ExternalData] = None, + transport_settings: Optional[Dict[str, str]] = None, + tz_mode: Optional[TzMode] = None, + ) -> StreamContext: """ - Set a clickhouse setting for the client after initialization. If a setting is not recognized by ClickHouse, - or the setting is identified as "read_only", this call will either throw a Programming exception or attempt - to send the setting anyway based on the common setting 'invalid_setting_action'. - :param key: ClickHouse setting name - :param value: ClickHouse setting value + Async version of query_column_block_stream. + Returns a StreamContext that yields column-oriented blocks. """ - self.client.set_client_setting(key=key, value=value) + return (await self._context_query(locals(), use_numpy=False, streaming=True)).column_block_stream - def get_client_setting(self, key: str) -> Optional[str]: + async def query_row_block_stream( # type: ignore[override] + self, + query: Optional[str] = None, + parameters: Optional[Union[Sequence, Dict[str, Any]]] = None, + settings: Optional[Dict[str, Any]] = None, + query_formats: Optional[Dict[str, str]] = None, + column_formats: Optional[Dict[str, Union[str, Dict[str, str]]]] = None, + encoding: Optional[str] = None, + use_none: Optional[bool] = None, + context: Optional[QueryContext] = None, + query_tz: Optional[Union[str, tzinfo]] = None, + column_tzs: Optional[Dict[str, Union[str, tzinfo]]] = None, + utc_tz_aware: Optional[bool] = None, + external_data: Optional[ExternalData] = None, + transport_settings: Optional[Dict[str, str]] = None, + tz_mode: Optional[TzMode] = None, + ) -> StreamContext: """ - :param key: The setting key - :return: The string value of the setting, if it exists, or None + Async version of query_row_block_stream. + Returns a StreamContext that yields row-oriented blocks. """ - return self.client.get_client_setting(key=key) + return (await self._context_query(locals(), use_numpy=False, streaming=True)).row_block_stream - def set_access_token(self, access_token: str) -> None: + async def query_rows_stream( # type: ignore[override] + self, + query: Optional[str] = None, + parameters: Optional[Union[Sequence, Dict[str, Any]]] = None, + settings: Optional[Dict[str, Any]] = None, + query_formats: Optional[Dict[str, str]] = None, + column_formats: Optional[Dict[str, Union[str, Dict[str, str]]]] = None, + encoding: Optional[str] = None, + use_none: Optional[bool] = None, + context: Optional[QueryContext] = None, + query_tz: Optional[Union[str, tzinfo]] = None, + column_tzs: Optional[Dict[str, Union[str, tzinfo]]] = None, + utc_tz_aware: Optional[bool] = None, + external_data: Optional[ExternalData] = None, + transport_settings: Optional[Dict[str, str]] = None, + tz_mode: Optional[TzMode] = None, + ) -> StreamContext: """ - Set the ClickHouse access token for the client - :param access_token: Access token string + Async version of query_rows_stream. + Returns a StreamContext that yields individual rows. """ - self.client.set_access_token(access_token) + return (await self._context_query(locals(), use_numpy=False, streaming=True)).rows_stream - def min_version(self, version_str: str) -> bool: + # pylint: disable=unused-argument + async def query_np( + self, + query: Optional[str] = None, + parameters: Optional[Union[Sequence, Dict[str, Any]]] = None, + settings: Optional[Dict[str, Any]] = None, + query_formats: Optional[Dict[str, str]] = None, + column_formats: Optional[Dict[str, str]] = None, + encoding: Optional[str] = None, + use_none: Optional[bool] = None, + max_str_len: Optional[int] = None, + context: Optional[QueryContext] = None, + external_data: Optional[ExternalData] = None, + transport_settings: Optional[Dict[str, str]] = None, + ): + check_numpy() + self._add_integration_tag("numpy") + return (await self._context_query(locals(), use_numpy=True)).np_result + + async def query_np_stream( # type: ignore[override] + self, + query: Optional[str] = None, + parameters: Optional[Union[Sequence, Dict[str, Any]]] = None, + settings: Optional[Dict[str, Any]] = None, + query_formats: Optional[Dict[str, str]] = None, + column_formats: Optional[Dict[str, str]] = None, + encoding: Optional[str] = None, + use_none: Optional[bool] = None, + max_str_len: Optional[int] = None, + context: Optional[QueryContext] = None, + external_data: Optional[ExternalData] = None, + transport_settings: Optional[Dict[str, str]] = None, + ) -> StreamContext: + check_numpy() + self._add_integration_tag("numpy") + return (await self._context_query(locals(), use_numpy=True, streaming=True)).np_stream + + async def query_df( + self, + query: Optional[str] = None, + parameters: Optional[Union[Sequence, Dict[str, Any]]] = None, + settings: Optional[Dict[str, Any]] = None, + query_formats: Optional[Dict[str, str]] = None, + column_formats: Optional[Dict[str, str]] = None, + encoding: Optional[str] = None, + use_none: Optional[bool] = None, + max_str_len: Optional[int] = None, + use_na_values: Optional[bool] = None, + query_tz: Optional[str] = None, + column_tzs: Optional[Dict[str, Union[str, tzinfo]]] = None, + utc_tz_aware: Optional[bool] = None, + context: Optional[QueryContext] = None, + external_data: Optional[ExternalData] = None, + use_extended_dtypes: Optional[bool] = None, + transport_settings: Optional[Dict[str, str]] = None, + tz_mode: Optional[TzMode] = None, + ): + check_pandas() + self._add_integration_tag("pandas") + return (await self._context_query(locals(), use_numpy=True, as_pandas=True)).df_result + + async def query_df_stream( # type: ignore[override] + self, + query: Optional[str] = None, + parameters: Optional[Union[Sequence, Dict[str, Any]]] = None, + settings: Optional[Dict[str, Any]] = None, + query_formats: Optional[Dict[str, str]] = None, + column_formats: Optional[Dict[str, str]] = None, + encoding: Optional[str] = None, + use_none: Optional[bool] = None, + max_str_len: Optional[int] = None, + use_na_values: Optional[bool] = None, + query_tz: Optional[str] = None, + column_tzs: Optional[Dict[str, Union[str, tzinfo]]] = None, + utc_tz_aware: Optional[bool] = None, + context: Optional[QueryContext] = None, + external_data: Optional[ExternalData] = None, + use_extended_dtypes: Optional[bool] = None, + transport_settings: Optional[Dict[str, str]] = None, + tz_mode: Optional[TzMode] = None, + ) -> StreamContext: + check_pandas() + self._add_integration_tag("pandas") + return (await self._context_query(locals(), use_numpy=True, as_pandas=True, streaming=True)).df_stream + + async def _context_query(self, lcls: dict, **overrides): # type: ignore[override] """ - Determine whether the connected server is at least the submitted version - For Altinity Stable versions like 22.8.15.25.altinitystable - the last condition in the first list comprehension expression is added - :param version_str: A version string consisting of up to 4 integers delimited by dots - :return: True if version_str is greater than the server_version, False if less than + Helper method to create query context and execute query. + Matches sync client pattern for consistency. """ - return self.client.min_version(version_str) + kwargs = lcls.copy() + kwargs.pop("self") + kwargs.update(overrides) + return await self._query_with_context(self.create_query_context(**kwargs)) - async def close(self) -> None: + async def command( # type: ignore[override] + self, + cmd, + parameters: Optional[Union[Sequence, Dict[str, Any]]] = None, + data: Optional[Union[str, bytes]] = None, + settings: Optional[Dict] = None, + use_database: bool = True, + external_data: Optional[ExternalData] = None, + transport_settings: Optional[Dict[str, str]] = None, + ) -> Union[str, int, Sequence[str], QuerySummary]: """ - Subclass implementation to close the connection to the server/deallocate the client + See BaseClient doc_string for this method """ - self.client.close() + cmd, bind_params = bind_query(cmd, parameters, self.server_tz) + params = bind_params.copy() + headers = {} + payload = None + files = None + + if external_data: + if data: + raise ProgrammingError("Cannot combine command data with external data") from None + files = external_data.form_data + params.update(external_data.query_params) + elif isinstance(data, str): + headers["Content-Type"] = "text/plain; charset=utf-8" + payload = data.encode() + elif isinstance(data, bytes): + headers["Content-Type"] = "application/octet-stream" + payload = data + + if payload is None and not cmd: + raise ProgrammingError("Command sent without query or recognized data") from None + + if payload or files: + params["query"] = cmd + else: + payload = cmd + + if use_database and self.database: + params["database"] = self.database + params.update(self._validate_settings(settings or {})) + headers = dict_copy(headers, transport_settings) + method = "POST" if payload or files else "GET" + response = await self._raw_request(payload, params, headers, files=files, method=method, server_wait=False) + body = await response.read() + encoding = response.headers.get("Content-Encoding") + summary = self._summary(response) - if self.new_executor: - await asyncio.to_thread(self.executor.shutdown, True) + if not body: + return QuerySummary(summary) - async def query(self, - query: Optional[str] = None, - parameters: Optional[Union[Sequence, Dict[str, Any]]] = None, - settings: Optional[Dict[str, Any]] = None, - query_formats: Optional[Dict[str, str]] = None, - column_formats: Optional[Dict[str, Union[str, Dict[str, str]]]] = None, - encoding: Optional[str] = None, - use_none: Optional[bool] = None, - column_oriented: Optional[bool] = None, - use_numpy: Optional[bool] = None, - max_str_len: Optional[int] = None, - context: QueryContext = None, - query_tz: Optional[Union[str, tzinfo]] = None, - column_tzs: Optional[Dict[str, Union[str, tzinfo]]] = None, - utc_tz_aware: Optional[Union[bool, Literal["schema"]]] = None, - external_data: Optional[ExternalData] = None, - transport_settings: Optional[Dict[str, str]] = None, - tz_mode: Optional[TzMode] = None) -> QueryResult: + loop = asyncio.get_running_loop() + + def decompress_and_decode(): + if encoding: + decompressed_body = decompress_response(body, encoding) + else: + decompressed_body = body + try: + result = decompressed_body.decode()[:-1].split("\t") + if len(result) == 1: + try: + return int(result[0]) + except ValueError: + return result[0] + return result + except UnicodeDecodeError: + return str(decompressed_body) + + return await loop.run_in_executor(None, decompress_and_decode) + + async def ping(self) -> bool: # type: ignore[override] + try: + url = f"{self.url}/ping" + timeout = aiohttp.ClientTimeout(total=3.0) + async with self._session.get(url, timeout=timeout) as response: + return 200 <= response.status < 300 + except (aiohttp.ClientError, asyncio.TimeoutError): + logger.debug("ping failed", exc_info=True) + return False + + async def raw_query( # type: ignore[override] + self, + query: str, + parameters: Optional[Union[Sequence, Dict[str, Any]]] = None, + settings: Optional[Dict[str, Any]] = None, + fmt: Optional[str] = None, + use_database: bool = True, + external_data: Optional[ExternalData] = None, + transport_settings: Optional[Dict[str, str]] = None, + ) -> bytes: """ - Main query method for SELECT, DESCRIBE and other SQL statements that return a result matrix. - For parameters, see the create_query_context method. - :return: QueryResult -- data and metadata from response + See BaseClient doc_string for this method """ + body, params, headers, files = self._prep_raw_query(query, parameters, settings, fmt, use_database, external_data) + if transport_settings: + headers = dict_copy(headers, transport_settings) - def _query(): - return self.client.query(query=query, parameters=parameters, settings=settings, query_formats=query_formats, - column_formats=column_formats, encoding=encoding, use_none=use_none, - column_oriented=column_oriented, use_numpy=use_numpy, max_str_len=max_str_len, - context=context, query_tz=query_tz, column_tzs=column_tzs, - tz_mode=tz_mode, utc_tz_aware=utc_tz_aware, - external_data=external_data, transport_settings=transport_settings) + response = await self._raw_request(body, params, headers=headers, files=files, retries=self.query_retries) + response_data = await response.read() + encoding = response.headers.get("Content-Encoding") - loop = asyncio.get_running_loop() - result = await loop.run_in_executor(self.executor, _query) - return result - - async def query_column_block_stream(self, - query: Optional[str] = None, - parameters: Optional[Union[Sequence, Dict[str, Any]]] = None, - settings: Optional[Dict[str, Any]] = None, - query_formats: Optional[Dict[str, str]] = None, - column_formats: Optional[Dict[str, Union[str, Dict[str, str]]]] = None, - encoding: Optional[str] = None, - use_none: Optional[bool] = None, - context: QueryContext = None, - query_tz: Optional[Union[str, tzinfo]] = None, - column_tzs: Optional[Dict[str, Union[str, tzinfo]]] = None, - utc_tz_aware: Optional[Union[bool, Literal["schema"]]] = None, - external_data: Optional[ExternalData] = None, - transport_settings: Optional[Dict[str, str]] = None, - tz_mode: Optional[TzMode] = None, - ) -> StreamContext: - """ - Variation of main query method that returns a stream of column oriented blocks. - For parameters, see the create_query_context method. - :return: StreamContext -- Iterable stream context that returns column oriented blocks - """ - - def _query_column_block_stream(): - return self.client.query_column_block_stream(query=query, parameters=parameters, settings=settings, - query_formats=query_formats, column_formats=column_formats, - encoding=encoding, use_none=use_none, context=context, - query_tz=query_tz, column_tzs=column_tzs, - tz_mode=tz_mode, utc_tz_aware=utc_tz_aware, - external_data=external_data, transport_settings=transport_settings) + if encoding: + loop = asyncio.get_running_loop() + response_data = await loop.run_in_executor(None, decompress_response, response_data, encoding) - loop = asyncio.get_running_loop() - result = await loop.run_in_executor(self.executor, _query_column_block_stream) - return result - - async def query_row_block_stream(self, - query: Optional[str] = None, - parameters: Optional[Union[Sequence, Dict[str, Any]]] = None, - settings: Optional[Dict[str, Any]] = None, - query_formats: Optional[Dict[str, str]] = None, - column_formats: Optional[Dict[str, Union[str, Dict[str, str]]]] = None, - encoding: Optional[str] = None, - use_none: Optional[bool] = None, - context: QueryContext = None, - query_tz: Optional[Union[str, tzinfo]] = None, - column_tzs: Optional[Dict[str, Union[str, tzinfo]]] = None, - utc_tz_aware: Optional[Union[bool, Literal["schema"]]] = None, - external_data: Optional[ExternalData] = None, - transport_settings: Optional[Dict[str, str]] = None, - tz_mode: Optional[TzMode] = None) -> StreamContext: - """ - Variation of main query method that returns a stream of row oriented blocks. - For parameters, see the create_query_context method. - :return: StreamContext -- Iterable stream context that returns blocks of rows - """ - - def _query_row_block_stream(): - return self.client.query_row_block_stream(query=query, parameters=parameters, settings=settings, - query_formats=query_formats, column_formats=column_formats, - encoding=encoding, use_none=use_none, context=context, - query_tz=query_tz, column_tzs=column_tzs, - tz_mode=tz_mode, utc_tz_aware=utc_tz_aware, - external_data=external_data, transport_settings=transport_settings) + return response_data - loop = asyncio.get_running_loop() - result = await loop.run_in_executor(self.executor, _query_row_block_stream) - return result - - async def query_rows_stream(self, - query: Optional[str] = None, - parameters: Optional[Union[Sequence, Dict[str, Any]]] = None, - settings: Optional[Dict[str, Any]] = None, - query_formats: Optional[Dict[str, str]] = None, - column_formats: Optional[Dict[str, Union[str, Dict[str, str]]]] = None, - encoding: Optional[str] = None, - use_none: Optional[bool] = None, - context: QueryContext = None, - query_tz: Optional[Union[str, tzinfo]] = None, - column_tzs: Optional[Dict[str, Union[str, tzinfo]]] = None, - utc_tz_aware: Optional[Union[bool, Literal["schema"]]] = None, - external_data: Optional[ExternalData] = None, - transport_settings: Optional[Dict[str, str]] = None, - tz_mode: Optional[TzMode] = None) -> StreamContext: - """ - Variation of main query method that returns a stream of row oriented blocks. - For parameters, see the create_query_context method. - :return: StreamContext -- Iterable stream context that returns blocks of rows - """ - - def _query_rows_stream(): - return self.client.query_rows_stream(query=query, parameters=parameters, settings=settings, - query_formats=query_formats, column_formats=column_formats, - encoding=encoding, use_none=use_none, context=context, - query_tz=query_tz, column_tzs=column_tzs, - tz_mode=tz_mode, utc_tz_aware=utc_tz_aware, - external_data=external_data, transport_settings=transport_settings) + async def raw_stream( # type: ignore[override] + self, + query: str, + parameters: Optional[Union[Sequence, Dict[str, Any]]] = None, + settings: Optional[Dict[str, Any]] = None, + fmt: Optional[str] = None, + use_database: bool = True, + external_data: Optional[ExternalData] = None, + transport_settings: Optional[Dict[str, str]] = None, + ) -> StreamContext: - loop = asyncio.get_running_loop() - result = await loop.run_in_executor(self.executor, _query_rows_stream) - return result - - async def raw_query(self, - query: str, - parameters: Optional[Union[Sequence, Dict[str, Any]]] = None, - settings: Optional[Dict[str, Any]] = None, - fmt: str = None, - use_database: bool = True, - external_data: Optional[ExternalData] = None, - transport_settings: Optional[Dict[str, str]] = None) -> bytes: - """ - Query method that simply returns the raw ClickHouse format bytes. - :param query: Query statement/format string - :param parameters: Optional dictionary used to format the query - :param settings: Optional dictionary of ClickHouse settings (key/string values) - :param fmt: ClickHouse output format - :param use_database Send the database parameter to ClickHouse so the command will be executed in the client - database context - :param external_data External data to send with the query - :param transport_settings: Optional dictionary of transport level settings (HTTP headers, etc.) - :return: bytes representing raw ClickHouse return value based on format - """ + body, params, headers, files = self._prep_raw_query(query, parameters, settings, fmt, use_database, external_data) + if transport_settings: + headers = dict_copy(headers, transport_settings) - def _raw_query(): - return self.client.raw_query(query=query, parameters=parameters, settings=settings, fmt=fmt, - use_database=use_database, external_data=external_data, - transport_settings=transport_settings) + response = await self._raw_request( + body, params, headers=headers, files=files, stream=True, server_wait=False, retries=self.query_retries + ) - loop = asyncio.get_running_loop() - result = await loop.run_in_executor(self.executor, _raw_query) - return result - - async def raw_stream(self, query: str, - parameters: Optional[Union[Sequence, Dict[str, Any]]] = None, - settings: Optional[Dict[str, Any]] = None, - fmt: str = None, - use_database: bool = True, - external_data: Optional[ExternalData] = None, - transport_settings: Optional[Dict[str, str]] = None) -> io.IOBase: - """ - Query method that returns the result as an io.IOBase iterator. - :param query: Query statement/format string - :param parameters: Optional dictionary used to format the query - :param settings: Optional dictionary of ClickHouse settings (key/string values) - :param fmt: ClickHouse output format - :param use_database Send the database parameter to ClickHouse so the command will be executed in the client - database context - :param external_data External data to send with the query - :param transport_settings: Optional dictionary of transport level settings (HTTP headers, etc.) - :return: io.IOBase stream/iterator for the result - """ + async def byte_iterator(): + async for chunk in response.content.iter_any(): + yield chunk - def _raw_stream(): - return self.client.raw_stream(query=query, parameters=parameters, settings=settings, fmt=fmt, - use_database=use_database, external_data=external_data, transport_settings=transport_settings) + return StreamContext(response, byte_iterator()) - loop = asyncio.get_running_loop() - result = await loop.run_in_executor(self.executor, _raw_stream) - return result - - async def query_np(self, - query: Optional[str] = None, - parameters: Optional[Union[Sequence, Dict[str, Any]]] = None, - settings: Optional[Dict[str, Any]] = None, - query_formats: Optional[Dict[str, str]] = None, - column_formats: Optional[Dict[str, str]] = None, - encoding: Optional[str] = None, - use_none: Optional[bool] = None, - max_str_len: Optional[int] = None, - context: QueryContext = None, - external_data: Optional[ExternalData] = None, - transport_settings: Optional[Dict[str, str]] = None) -> 'numpy.ndarray': - """ - Query method that returns the results as a numpy array. - For parameter values, see the create_query_context method. - :return: Numpy array representing the result set - """ - - def _query_np(): - return self.client.query_np(query=query, parameters=parameters, settings=settings, - query_formats=query_formats, column_formats=column_formats, encoding=encoding, - use_none=use_none, max_str_len=max_str_len, context=context, - external_data=external_data, transport_settings=transport_settings) + def _prep_raw_query(self, query, parameters, settings, fmt, use_database, external_data): + """ + Prepare raw query for execution. - loop = asyncio.get_running_loop() - result = await loop.run_in_executor(self.executor, _query_np) - return result - - async def query_np_stream(self, - query: Optional[str] = None, - parameters: Optional[Union[Sequence, Dict[str, Any]]] = None, - settings: Optional[Dict[str, Any]] = None, - query_formats: Optional[Dict[str, str]] = None, - column_formats: Optional[Dict[str, str]] = None, - encoding: Optional[str] = None, - use_none: Optional[bool] = None, - max_str_len: Optional[int] = None, - context: QueryContext = None, - external_data: Optional[ExternalData] = None, - transport_settings: Optional[Dict[str, str]] = None) -> StreamContext: - """ - Query method that returns the results as a stream of numpy arrays. - For parameter values, see the create_query_context method. - :return: Generator that yield a numpy array per block representing the result set - """ - - def _query_np_stream(): - return self.client.query_np_stream(query=query, parameters=parameters, settings=settings, - query_formats=query_formats, column_formats=column_formats, - encoding=encoding, use_none=use_none, max_str_len=max_str_len, - context=context, external_data=external_data, transport_settings=transport_settings) + Note: Unlike sync client which returns (body, params, fields), this async version + returns (body, params, headers, files) because aiohttp requires headers to be + configured before the request() call, while urllib3 can add them during request. + """ + if fmt: + query += f"\n FORMAT {fmt}" + + final_query, bind_params = bind_query(query, parameters, self.server_tz) + params = self._validate_settings(settings or {}) + if use_database and self.database: + params["database"] = self.database + + headers = {} + files = None + body = None + + if external_data and not self.form_encode_query_params and isinstance(final_query, bytes): + raise ProgrammingError("Binary query cannot be placed in URL when using External Data; enable form encoding.") + + if self.form_encode_query_params: + files = {} + files["query"] = (None, final_query if isinstance(final_query, str) else final_query.decode()) + for k, v in bind_params.items(): + files[k] = (None, str(v)) + + if external_data: + params.update(external_data.query_params) + files.update(external_data.form_data) + + body = None + elif external_data: + params.update(bind_params) + params["query"] = final_query + params.update(external_data.query_params) + files = external_data.form_data + body = None + else: + params.update(bind_params) + body = final_query.encode() if isinstance(final_query, str) else final_query - loop = asyncio.get_running_loop() - result = await loop.run_in_executor(self.executor, _query_np_stream) - return result - - async def query_df(self, - query: Optional[str] = None, - parameters: Optional[Union[Sequence, Dict[str, Any]]] = None, - settings: Optional[Dict[str, Any]] = None, - query_formats: Optional[Dict[str, str]] = None, - column_formats: Optional[Dict[str, str]] = None, - encoding: Optional[str] = None, - use_none: Optional[bool] = None, - max_str_len: Optional[int] = None, - use_na_values: Optional[bool] = None, - query_tz: Optional[str] = None, - column_tzs: Optional[Dict[str, Union[str, tzinfo]]] = None, - utc_tz_aware: Optional[Union[bool, Literal["schema"]]] = None, - context: QueryContext = None, - external_data: Optional[ExternalData] = None, - use_extended_dtypes: Optional[bool] = None, - transport_settings: Optional[Dict[str, str]] = None, - tz_mode: Optional[TzMode] = None) -> 'pandas.DataFrame': - """ - Query method that results the results as a pandas dataframe. - For parameter values, see the create_query_context method. - :return: Pandas dataframe representing the result set - """ - - def _query_df(): - return self.client.query_df(query=query, parameters=parameters, settings=settings, - query_formats=query_formats, column_formats=column_formats, encoding=encoding, - use_none=use_none, max_str_len=max_str_len, use_na_values=use_na_values, - query_tz=query_tz, column_tzs=column_tzs, tz_mode=tz_mode, - utc_tz_aware=utc_tz_aware, context=context, - external_data=external_data, use_extended_dtypes=use_extended_dtypes, - transport_settings=transport_settings) + return body, params, headers, files - loop = asyncio.get_running_loop() - result = await loop.run_in_executor(self.executor, _query_df) - return result + async def insert( # type: ignore[override] + self, + table: Optional[str] = None, + data: Optional[Sequence[Sequence[Any]]] = None, + column_names: Union[str, Iterable[str]] = "*", + database: Optional[str] = None, + column_types: Optional[Sequence[ClickHouseType]] = None, + column_type_names: Optional[Sequence[str]] = None, + column_oriented: bool = False, + settings: Optional[Dict[str, Any]] = None, + context: Optional[InsertContext] = None, + transport_settings: Optional[Dict[str, str]] = None, + ) -> QuerySummary: + """ + Method to insert multiple rows/data matrix of native Python objects. If context is specified arguments + other than data are ignored + :param table: Target table + :param data: Sequence of sequences of Python data + :param column_names: Ordered list of column names or '*' if column types should be retrieved from the + ClickHouse table definition + :param database: Target database -- will use client default database if not specified. + :param column_types: ClickHouse column types. If set then column data does not need to be retrieved from + the server + :param column_type_names: ClickHouse column type names. If set then column data does not need to be + retrieved from the server + :param column_oriented: If true the data is already "pivoted" in column form + :param settings: Optional dictionary of ClickHouse settings (key/string values) + :param context: Optional reusable insert context to allow repeated inserts into the same table with + different data batches + :param transport_settings: Optional dictionary of transport level settings (HTTP headers, etc.) + :return: QuerySummary with summary information, throws exception if insert fails + """ + if (context is None or context.empty) and data is None: + raise ProgrammingError("No data specified for insert") from None + if context is None: + context = await self.create_insert_context( + table, + column_names, + database, + column_types, + column_type_names, + column_oriented, + settings, + transport_settings=transport_settings, + ) + if data is not None: + if not context.empty: + raise ProgrammingError("Attempting to insert new data with non-empty insert context") from None + context.data = data + return await self.data_insert(context) - async def query_df_arrow( + async def query_arrow( self, query: str, parameters: Optional[Union[Sequence, Dict[str, Any]]] = None, @@ -397,77 +1055,51 @@ async def query_df_arrow( use_strings: Optional[bool] = None, external_data: Optional[ExternalData] = None, transport_settings: Optional[Dict[str, str]] = None, - dataframe_library: str = "pandas", - ) -> Union["pd.DataFrame", "pl.DataFrame"]: + ): """ - Query method using the ClickHouse Arrow format to return a DataFrame - with PyArrow dtype backend. This provides better performance and memory efficiency - compared to the standard query_df method, though fewer output formatting options. - + Query method using the ClickHouse Arrow format to return a PyArrow table :param query: Query statement/format string :param parameters: Optional dictionary used to format the query :param settings: Optional dictionary of ClickHouse settings (key/string values) :param use_strings: Convert ClickHouse String type to Arrow string type (instead of binary) :param external_data: ClickHouse "external data" to send with query :param transport_settings: Optional dictionary of transport level settings (HTTP headers, etc.) - :param dataframe_library: Library to use for DataFrame creation ("pandas" or "polars") - :return: DataFrame (pandas or polars based on dataframe_library parameter) + :return: PyArrow.Table """ - - def _query_df_arrow(): - return self.client.query_df_arrow( - query=query, - parameters=parameters, - settings=settings, - use_strings=use_strings, - external_data=external_data, - transport_settings=transport_settings, - dataframe_library=dataframe_library - ) + check_arrow() + self._add_integration_tag("arrow") + settings = self._update_arrow_settings(settings, use_strings) + + body, params, headers, files = self._prep_raw_query( + query, parameters, settings, fmt="ArrowStream", + use_database=True, external_data=external_data + ) + if transport_settings: + headers = dict_copy(headers, transport_settings) + + response = await self._raw_request( + body, params, headers=headers, files=files, + stream=True, server_wait=False, retries=self.query_retries + ) + encoding = response.headers.get("Content-Encoding") + exception_tag = response.headers.get(ex_tag_header) loop = asyncio.get_running_loop() - result = await loop.run_in_executor(self.executor, _query_df_arrow) - return result - - async def query_df_stream(self, - query: Optional[str] = None, - parameters: Optional[Union[Sequence, Dict[str, Any]]] = None, - settings: Optional[Dict[str, Any]] = None, - query_formats: Optional[Dict[str, str]] = None, - column_formats: Optional[Dict[str, str]] = None, - encoding: Optional[str] = None, - use_none: Optional[bool] = None, - max_str_len: Optional[int] = None, - use_na_values: Optional[bool] = None, - query_tz: Optional[str] = None, - column_tzs: Optional[Dict[str, Union[str, tzinfo]]] = None, - utc_tz_aware: Optional[Union[bool, Literal["schema"]]] = None, - context: QueryContext = None, - external_data: Optional[ExternalData] = None, - use_extended_dtypes: Optional[bool] = None, - transport_settings: Optional[Dict[str, str]] = None, - tz_mode: Optional[TzMode] = None) -> StreamContext: - """ - Query method that returns the results as a StreamContext. - For parameter values, see the create_query_context method. - :return: Generator that yields a Pandas dataframe per block representing the result set - """ - - def _query_df_stream(): - return self.client.query_df_stream(query=query, parameters=parameters, settings=settings, - query_formats=query_formats, column_formats=column_formats, - encoding=encoding, - use_none=use_none, max_str_len=max_str_len, use_na_values=use_na_values, - query_tz=query_tz, column_tzs=column_tzs, - tz_mode=tz_mode, utc_tz_aware=utc_tz_aware, context=context, - external_data=external_data, use_extended_dtypes=use_extended_dtypes, - transport_settings=transport_settings) + streaming_source = StreamingResponseSource(response, encoding=encoding, exception_tag=exception_tag) + await streaming_source.start_producer(loop) - loop = asyncio.get_running_loop() - result = await loop.run_in_executor(self.executor, _query_df_stream) - return result + def parse_arrow_stream(): + file_adapter = StreamingFileAdapter(streaming_source) + reader = options.arrow.ipc.open_stream(file_adapter) + table = reader.read_all() + return _apply_arrow_tz_policy(table, self.tz_mode) + + try: + return await loop.run_in_executor(None, parse_arrow_stream) + finally: + await streaming_source.aclose() - async def query_df_arrow_stream( + async def query_arrow_stream( # type: ignore[override] self, query: str, parameters: Optional[Union[Sequence, Dict[str, Any]]] = None, @@ -475,11 +1107,9 @@ async def query_df_arrow_stream( use_strings: Optional[bool] = None, external_data: Optional[ExternalData] = None, transport_settings: Optional[Dict[str, str]] = None, - dataframe_library: str = "pandas" ) -> StreamContext: """ - Query method that returns the results as a stream of DataFrames with PyArrow dtype backend. - Each DataFrame represents a block from the ClickHouse response. + Query method that returns the results as a stream of Arrow record batches. :param query: Query statement/format string :param parameters: Optional dictionary used to format the query @@ -487,274 +1117,269 @@ async def query_df_arrow_stream( :param use_strings: Convert ClickHouse String type to Arrow string type (instead of binary) :param external_data: ClickHouse "external data" to send with query :param transport_settings: Optional dictionary of transport level settings (HTTP headers, etc.) - :param dataframe_library: Library to use for DataFrame creation ("pandas" or "polars") - :return: StreamContext that yields DataFrames (pandas or polars based on dataframe_library parameter) + :return: StreamContext that yields PyArrow RecordBatch objects asynchronously """ - - def _query_df_arrow_stream(): - return self.client.query_df_arrow_stream( - query=query, - parameters=parameters, - settings=settings, - use_strings=use_strings, - external_data=external_data, - transport_settings=transport_settings, - dataframe_library=dataframe_library - ) + check_arrow() + self._add_integration_tag("arrow") + settings = self._update_arrow_settings(settings, use_strings) + + body, params, headers, files = self._prep_raw_query( + query, parameters, settings, fmt="ArrowStream", + use_database=True, external_data=external_data + ) + if transport_settings: + headers = dict_copy(headers, transport_settings) + + response = await self._raw_request( + body, params, headers=headers, files=files, + stream=True, server_wait=False, retries=self.query_retries + ) + encoding = response.headers.get("Content-Encoding") + exception_tag = response.headers.get(ex_tag_header) loop = asyncio.get_running_loop() - result = await loop.run_in_executor(self.executor, _query_df_arrow_stream) - return result - - def create_query_context(self, - query: Optional[Union[str, bytes]] = None, - parameters: Optional[Union[Sequence, Dict[str, Any]]] = None, - settings: Optional[Dict[str, Any]] = None, - query_formats: Optional[Dict[str, str]] = None, - column_formats: Optional[Dict[str, Union[str, Dict[str, str]]]] = None, - encoding: Optional[str] = None, - use_none: Optional[bool] = None, - column_oriented: Optional[bool] = None, - use_numpy: Optional[bool] = False, - max_str_len: Optional[int] = 0, - context: Optional[QueryContext] = None, - query_tz: Optional[Union[str, tzinfo]] = None, - column_tzs: Optional[Dict[str, Union[str, tzinfo]]] = None, - tz_mode: Optional[TzMode] = None, - use_na_values: Optional[bool] = None, - streaming: bool = False, - as_pandas: bool = False, - external_data: Optional[ExternalData] = None, - use_extended_dtypes: Optional[bool] = None, - transport_settings: Optional[Dict[str, str]] = None, - utc_tz_aware: Optional[Union[bool, Literal["schema"]]] = None) -> QueryContext: - """ - Creates or updates a reusable QueryContext object - :param query: Query statement/format string - :param parameters: Optional dictionary used to format the query - :param settings: Optional dictionary of ClickHouse settings (key/string values) - :param query_formats: See QueryContext __init__ docstring - :param column_formats: See QueryContext __init__ docstring - :param encoding: See QueryContext __init__ docstring - :param use_none: Use None for ClickHouse NULL instead of default values. Note that using None in Numpy - arrays will force the numpy array dtype to 'object', which is often inefficient. This effect also - will impact the performance of Pandas dataframes. - :param column_oriented: Deprecated. Controls orientation of the QueryResult result_set property - :param use_numpy: Return QueryResult columns as one-dimensional numpy arrays - :param max_str_len: Limit returned ClickHouse String values to this length, which allows a Numpy - structured array even with ClickHouse variable length String columns. If 0, Numpy arrays for - String columns will always be object arrays - :param context: An existing QueryContext to be updated with any provided parameter values - :param query_tz Either a string or a pytz tzinfo object. (Strings will be converted to tzinfo objects). - Values for any DateTime or DateTime64 column in the query will be converted to Python datetime.datetime - objects with the selected timezone - :param column_tzs A dictionary of column names to tzinfo objects (or strings that will be converted to - tzinfo objects). The timezone will be applied to datetime objects returned in the query - :param use_na_values: Deprecated alias for use_advanced_dtypes - :param as_pandas Return the result columns as pandas.Series objects - :param streaming Marker used to correctly configure streaming queries - :param external_data ClickHouse "external data" to send with query - :param use_extended_dtypes: Only relevant to Pandas Dataframe queries. Use Pandas "missing types", such as - pandas.NA and pandas.NaT for ClickHouse NULL values, as well as extended Pandas dtypes such as IntegerArray - and StringArray. Defaulted to True for query_df methods - :param transport_settings: Optional dictionary of transport level settings (HTTP headers, etc.) - :return: Reusable QueryContext - """ - - return self.client.create_query_context(query=query, parameters=parameters, settings=settings, - query_formats=query_formats, column_formats=column_formats, - encoding=encoding, use_none=use_none, - column_oriented=column_oriented, - use_numpy=use_numpy, max_str_len=max_str_len, context=context, - query_tz=query_tz, column_tzs=column_tzs, - tz_mode=tz_mode, utc_tz_aware=utc_tz_aware, - use_na_values=use_na_values, - streaming=streaming, as_pandas=as_pandas, - external_data=external_data, - use_extended_dtypes=use_extended_dtypes, - transport_settings=transport_settings) - - async def query_arrow(self, - query: str, - parameters: Optional[Union[Sequence, Dict[str, Any]]] = None, - settings: Optional[Dict[str, Any]] = None, - use_strings: Optional[bool] = None, - external_data: Optional[ExternalData] = None, - transport_settings: Optional[Dict[str, str]] = None) -> 'pyarrow.Table': - """ - Query method using the ClickHouse Arrow format to return a PyArrow table - :param query: Query statement/format string - :param parameters: Optional dictionary used to format the query - :param settings: Optional dictionary of ClickHouse settings (key/string values) - :param use_strings: Convert ClickHouse String type to Arrow string type (instead of binary) - :param external_data ClickHouse "external data" to send with query - :param transport_settings: Optional dictionary of transport level settings (HTTP headers, etc.) - :return: PyArrow.Table - """ + streaming_source = StreamingResponseSource(response, encoding=encoding, exception_tag=exception_tag) + await streaming_source.start_producer(loop) + + queue = AsyncSyncQueue(maxsize=10) + + class _ArrowStreamSource: + def __init__(self, source, q): + self._source = source + self._queue = q + + async def aclose(self): + self._queue.shutdown() + await self._source.aclose() + + def close(self): + self._queue.shutdown() + self._source.close() + + def parse_arrow_streaming(): + """Parse Arrow stream incrementally in executor (off event loop).""" + try: + file_adapter = StreamingFileAdapter(streaming_source) + reader = options.arrow.ipc.open_stream(file_adapter) + + for batch in reader: + try: + batch = _apply_arrow_tz_policy(batch, self.tz_mode) + queue.sync_q.put(batch) + except RuntimeError: + return + + try: + queue.sync_q.put(EOF_SENTINEL) + except RuntimeError: + return + except Exception as e: + try: + queue.sync_q.put(e) + except Exception: + pass + finally: + queue.shutdown() + + loop.run_in_executor(None, parse_arrow_streaming) + + async def arrow_batch_generator(): + """Async generator that yields record batches without blocking event loop.""" + while True: + item = await queue.async_q.get() + if item is EOF_SENTINEL: + break + if isinstance(item, Exception): + raise item + yield item + + return StreamContext(_ArrowStreamSource(streaming_source, queue), arrow_batch_generator()) - def _query_arrow(): - return self.client.query_arrow(query=query, parameters=parameters, settings=settings, - use_strings=use_strings, external_data=external_data, - transport_settings=transport_settings) + async def query_df_arrow( + self, + query: str, + parameters: Optional[Union[Sequence, Dict[str, Any]]] = None, + settings: Optional[Dict[str, Any]] = None, + use_strings: Optional[bool] = None, + external_data: Optional[ExternalData] = None, + transport_settings: Optional[Dict[str, str]] = None, + dataframe_library: str = "pandas", + ) -> Union["options.pd.DataFrame", "options.pl.DataFrame"]: + """ + Query method using the ClickHouse Arrow format to return a DataFrame + with PyArrow dtype backend. This provides better performance and memory efficiency + compared to the standard query_df method, though fewer output formatting options. - loop = asyncio.get_running_loop() - result = await loop.run_in_executor(self.executor, _query_arrow) - return result - - async def query_arrow_stream(self, - query: str, - parameters: Optional[Union[Sequence, Dict[str, Any]]] = None, - settings: Optional[Dict[str, Any]] = None, - use_strings: Optional[bool] = None, - external_data: Optional[ExternalData] = None, - transport_settings: Optional[Dict[str, str]] = None) -> StreamContext: - """ - Query method that returns the results as a stream of Arrow tables :param query: Query statement/format string :param parameters: Optional dictionary used to format the query :param settings: Optional dictionary of ClickHouse settings (key/string values) - :param use_strings: Convert ClickHouse String type to Arrow string type (instead of binary) - :param external_data ClickHouse "external data" to send with query + :param use_strings: Convert ClickHouse String type to Arrow string type (instead of binary) + :param external_data: ClickHouse "external data" to send with query :param transport_settings: Optional dictionary of transport level settings (HTTP headers, etc.) - :return: Generator that yields a PyArrow.Table for per block representing the result set + :param dataframe_library: Library to use for DataFrame creation ("pandas" or "polars") + :return: DataFrame (pandas or polars based on dataframe_library parameter) """ + check_arrow() - def _query_arrow_stream(): - return self.client.query_arrow_stream(query=query, parameters=parameters, settings=settings, - use_strings=use_strings, external_data=external_data, - transport_settings=transport_settings) + if dataframe_library == "pandas": + check_pandas() + self._add_integration_tag("pandas") + if not options.IS_PANDAS_2: + raise ProgrammingError("PyArrow-backed dtypes are only supported when using pandas 2.x.") - loop = asyncio.get_running_loop() - result = await loop.run_in_executor(self.executor, _query_arrow_stream) - return result - - async def command(self, - cmd: str, - parameters: Optional[Union[Sequence, Dict[str, Any]]] = None, - data: Union[str, bytes] = None, - settings: Optional[Dict[str, Any]] = None, - use_database: bool = True, - external_data: Optional[ExternalData] = None, - transport_settings: Optional[Dict[str, str]] = None) -> Union[str, int, Sequence[str], QuerySummary]: - """ - Client method that returns a single value instead of a result set - :param cmd: ClickHouse query/command as a python format string - :param parameters: Optional dictionary of key/values pairs to be formatted - :param data: Optional 'data' for the command (for INSERT INTO in particular) - :param settings: Optional dictionary of ClickHouse settings (key/string values) - :param use_database: Send the database parameter to ClickHouse so the command will be executed in the client - database context. Otherwise, no database will be specified with the command. This is useful for determining - the default user database - :param external_data ClickHouse "external data" to send with command/query - :param transport_settings: Optional dictionary of transport level settings (HTTP headers, etc.) - :return: Decoded response from ClickHouse as either a string, int, or sequence of strings, or QuerySummary - if no data returned - """ + def converter(table: "options.arrow.Table") -> "options.pd.DataFrame": + table = _apply_arrow_tz_policy(table, self.tz_mode) + return table.to_pandas(types_mapper=options.pd.ArrowDtype, safe=False) - def _command(): - return self.client.command(cmd=cmd, parameters=parameters, data=data, settings=settings, - use_database=use_database, external_data=external_data, - transport_settings=transport_settings) + elif dataframe_library == "polars": + check_polars() + self._add_integration_tag("polars") - loop = asyncio.get_running_loop() - result = await loop.run_in_executor(self.executor, _command) - return result + def converter(table: "options.arrow.Table") -> "options.pl.DataFrame": + table = _apply_arrow_tz_policy(table, self.tz_mode) + return options.pl.from_arrow(table) - async def ping(self) -> bool: - """ - Validate the connection, does not throw an Exception (see debug logs) - :return: ClickHouse server is up and reachable - """ + else: + raise ValueError(f"dataframe_library must be 'pandas' or 'polars', got '{dataframe_library}'") - def _ping(): - return self.client.ping() + arrow_table = await self.query_arrow( + query=query, + parameters=parameters, + settings=settings, + use_strings=use_strings, + external_data=external_data, + transport_settings=transport_settings, + ) - loop = asyncio.get_running_loop() - result = await loop.run_in_executor(self.executor, _ping) - return result - - async def insert(self, - table: Optional[str] = None, - data: Sequence[Sequence[Any]] = None, - column_names: Union[str, Iterable[str]] = '*', - database: Optional[str] = None, - column_types: Sequence[ClickHouseType] = None, - column_type_names: Sequence[str] = None, - column_oriented: bool = False, - settings: Optional[Dict[str, Any]] = None, - context: InsertContext = None, - transport_settings: Optional[Dict[str, str]] = None) -> QuerySummary: + return converter(arrow_table) + + async def query_df_arrow_stream( # type: ignore[override] + self, + query: str, + parameters: Optional[Union[Sequence, Dict[str, Any]]] = None, + settings: Optional[Dict[str, Any]] = None, + use_strings: Optional[bool] = None, + external_data: Optional[ExternalData] = None, + transport_settings: Optional[Dict[str, str]] = None, + dataframe_library: str = "pandas", + ) -> StreamContext: """ - Method to insert multiple rows/data matrix of native Python objects. If context is specified arguments - other than data are ignored - :param table: Target table - :param data: Sequence of sequences of Python data - :param column_names: Ordered list of column names or '*' if column types should be retrieved from the - ClickHouse table definition - :param database: Target database -- will use client default database if not specified. - :param column_types: ClickHouse column types. If set then column data does not need to be retrieved from - the server - :param column_type_names: ClickHouse column type names. If set then column data does not need to be - retrieved from the server - :param column_oriented: If true the data is already "pivoted" in column form + Query method that returns the results as a stream of DataFrames with PyArrow dtype backend. + Each DataFrame represents a record batch from the ClickHouse response. + + :param query: Query statement/format string + :param parameters: Optional dictionary used to format the query :param settings: Optional dictionary of ClickHouse settings (key/string values) - :param context: Optional reusable insert context to allow repeated inserts into the same table with - different data batches + :param use_strings: Convert ClickHouse String type to Arrow string type (instead of binary) + :param external_data: ClickHouse "external data" to send with query :param transport_settings: Optional dictionary of transport level settings (HTTP headers, etc.) - :return: QuerySummary with summary information, throws exception if insert fails + :param dataframe_library: Library to use for DataFrame creation ("pandas" or "polars") + :return: StreamContext that yields DataFrames asynchronously (pandas or polars based on dataframe_library parameter) """ + check_arrow() + if dataframe_library == "pandas": + check_pandas() + self._add_integration_tag("pandas") + if not options.IS_PANDAS_2: + raise ProgrammingError("PyArrow-backed dtypes are only supported when using pandas 2.x.") - def _insert(): - return self.client.insert(table=table, data=data, column_names=column_names, database=database, - column_types=column_types, column_type_names=column_type_names, - column_oriented=column_oriented, settings=settings, context=context, - transport_settings=transport_settings) + def converter(table: "options.arrow.Table") -> "options.pd.DataFrame": + table = _apply_arrow_tz_policy(table, self.tz_mode) + return table.to_pandas(types_mapper=options.pd.ArrowDtype, safe=False) - loop = asyncio.get_running_loop() - result = await loop.run_in_executor(self.executor, _insert) - return result - - async def insert_df(self, table: str = None, - df=None, - database: Optional[str] = None, - settings: Optional[Dict] = None, - column_names: Optional[Sequence[str]] = None, - column_types: Sequence[ClickHouseType] = None, - column_type_names: Sequence[str] = None, - context: InsertContext = None, - transport_settings: Optional[Dict[str, str]] = None) -> QuerySummary: - """ - Insert a pandas DataFrame into ClickHouse. If context is specified arguments other than df are ignored - :param table: ClickHouse table - :param df: two-dimensional pandas dataframe - :param database: Optional ClickHouse database - :param settings: Optional dictionary of ClickHouse settings (key/string values) - :param column_names: An optional list of ClickHouse column names. If not set, the DataFrame column names - will be used - :param column_types: ClickHouse column types. If set then column data does not need to be retrieved from - the server - :param column_type_names: ClickHouse column type names. If set then column data does not need to be - retrieved from the server - :param context: Optional reusable insert context to allow repeated inserts into the same table with - different data batches - :param transport_settings: Optional dictionary of transport level settings (HTTP headers, etc.) - :return: QuerySummary with summary information, throws exception if insert fails - """ + elif dataframe_library == "polars": + check_polars() + self._add_integration_tag("polars") - def _insert_df(): - return self.client.insert_df(table=table, df=df, database=database, settings=settings, - column_names=column_names, - column_types=column_types, column_type_names=column_type_names, - context=context, transport_settings=transport_settings) + def converter(table: "options.arrow.Table") -> "options.pl.DataFrame": + table = _apply_arrow_tz_policy(table, self.tz_mode) + return options.pl.from_arrow(table) - loop = asyncio.get_running_loop() - result = await loop.run_in_executor(self.executor, _insert_df) - return result + else: + raise ValueError(f"dataframe_library must be 'pandas' or 'polars', got '{dataframe_library}'") + settings = self._update_arrow_settings(settings, use_strings) + + body, params, headers, files = self._prep_raw_query( + query, parameters, settings, fmt="ArrowStream", + use_database=True, external_data=external_data + ) + if transport_settings: + headers = dict_copy(headers, transport_settings) + + response = await self._raw_request( + body, params, headers=headers, files=files, + stream=True, server_wait=False, retries=self.query_retries + ) + encoding = response.headers.get("Content-Encoding") + exception_tag = response.headers.get(ex_tag_header) - async def insert_arrow(self, table: str, - arrow_table, database: str = None, - settings: Optional[Dict] = None, - transport_settings: Optional[Dict[str, str]] = None) -> QuerySummary: + loop = asyncio.get_running_loop() + streaming_source = StreamingResponseSource(response, encoding=encoding, exception_tag=exception_tag) + await streaming_source.start_producer(loop) + + queue = AsyncSyncQueue(maxsize=10) + + class _ArrowDFStreamSource: + def __init__(self, source, q): + self._source = source + self._queue = q + + async def aclose(self): + self._queue.shutdown() + await self._source.aclose() + + def close(self): + self._queue.shutdown() + self._source.close() + + def parse_and_convert_streaming(): + """Parse Arrow stream and convert to DataFrames in executor (off event loop).""" + try: + file_adapter = StreamingFileAdapter(streaming_source) + + # PyArrow reads incrementally from adapter (which pulls from queue) + reader = options.arrow.ipc.open_stream(file_adapter) + + for batch in reader: + try: + queue.sync_q.put(converter(batch)) + except RuntimeError: + return + + try: + queue.sync_q.put(EOF_SENTINEL) + except RuntimeError: + return + except Exception as e: + try: + queue.sync_q.put(e) + except Exception: + pass + finally: + queue.shutdown() + + loop.run_in_executor(None, parse_and_convert_streaming) + + async def df_generator(): + """Async generator that yields DataFrames without blocking event loop.""" + while True: + item = await queue.async_q.get() + if item is EOF_SENTINEL: + break + if isinstance(item, Exception): + raise item + yield item + + return StreamContext(_ArrowDFStreamSource(streaming_source, queue), df_generator()) + + async def insert_arrow( # type: ignore[override] + self, + table: str, + arrow_table, + database: Optional[str] = None, + settings: Optional[Dict] = None, + transport_settings: Optional[Dict[str, str]] = None, + ) -> QuerySummary: """ Insert a PyArrow table DataFrame into ClickHouse using raw Arrow format :param table: ClickHouse table @@ -762,21 +1387,20 @@ async def insert_arrow(self, table: str, :param database: Optional ClickHouse database :param settings: Optional dictionary of ClickHouse settings (key/string values) :param transport_settings: Optional dictionary of transport level settings (HTTP headers, etc.) - :return: QuerySummary with summary information, throws exception if insert fails """ - - def _insert_arrow(): - return self.client.insert_arrow(table=table, arrow_table=arrow_table, database=database, - settings=settings, transport_settings=transport_settings) - - loop = asyncio.get_running_loop() - result = await loop.run_in_executor(self.executor, _insert_arrow) - return result - - async def insert_df_arrow( + check_arrow() + self._add_integration_tag("arrow") + full_table = table if "." in table or not database else f"{database}.{table}" + compression = self.write_compression if self.write_compression in ("zstd", "lz4") else None + column_names, insert_block = arrow_buffer(arrow_table, compression) + if hasattr(insert_block, "to_pybytes"): + insert_block = insert_block.to_pybytes() + return await self.raw_insert(full_table, column_names, insert_block, settings, "Arrow", transport_settings) + + async def insert_df_arrow( # type: ignore[override] self, table: str, - df: Union["pd.DataFrame", "pl.DataFrame"], + df: Union["options.pd.DataFrame", "options.pl.DataFrame"], database: Optional[str] = None, settings: Optional[Dict] = None, transport_settings: Optional[Dict[str, str]] = None, @@ -796,30 +1420,57 @@ async def insert_df_arrow( :param transport_settings: Optional dictionary of transport level settings (HTTP headers, etc.) :return: QuerySummary with summary information, throws exception if insert fails """ + check_arrow() - def _insert_df_arrow(): - return self.client.insert_df_arrow( - table=table, - df=df, - database=database, - settings=settings, - transport_settings=transport_settings, - ) - - loop = asyncio.get_running_loop() - result = await loop.run_in_executor(self.executor, _insert_df_arrow) - return result - - async def create_insert_context(self, - table: str, - column_names: Optional[Union[str, Sequence[str]]] = None, - database: Optional[str] = None, - column_types: Sequence[ClickHouseType] = None, - column_type_names: Sequence[str] = None, - column_oriented: bool = False, - settings: Optional[Dict[str, Any]] = None, - data: Optional[Sequence[Sequence[Any]]] = None, - transport_settings: Optional[Dict[str, str]] = None) -> InsertContext: + if options.pd is not None and isinstance(df, options.pd.DataFrame): + df_lib = "pandas" + elif options.pl is not None and isinstance(df, options.pl.DataFrame): + df_lib = "polars" + else: + if options.pd is None and options.pl is None: + raise ImportError("A DataFrame library (pandas or polars) must be installed to use insert_df_arrow.") + raise TypeError(f"df must be either a pandas DataFrame or polars DataFrame, got {type(df).__name__}") + + if df_lib == "pandas": + if not options.IS_PANDAS_2: + raise ProgrammingError("PyArrow-backed dtypes are only supported when using pandas 2.x.") + + non_arrow_cols = [col for col, dtype in df.dtypes.items() if not isinstance(dtype, options.pd.ArrowDtype)] + if non_arrow_cols: + raise ProgrammingError( + f"insert_df_arrow requires all columns to use PyArrow dtypes. Non-Arrow columns found: [{', '.join(non_arrow_cols)}]. " + ) + try: + arrow_table = options.arrow.Table.from_pandas(df, preserve_index=False) + except Exception as e: + raise DataError(f"Failed to convert pandas DataFrame to Arrow table: {e}") from e + else: + try: + arrow_table = df.to_arrow() + except Exception as e: + raise DataError(f"Failed to convert polars DataFrame to Arrow table: {e}") from e + + self._add_integration_tag(df_lib) + return await self.insert_arrow( + table=table, + arrow_table=arrow_table, + database=database, + settings=settings, + transport_settings=transport_settings, + ) + + async def create_insert_context( # type: ignore[override] + self, + table: str, + column_names: Optional[Union[str, Sequence[str]]] = None, + database: Optional[str] = None, + column_types: Optional[Sequence[ClickHouseType]] = None, + column_type_names: Optional[Sequence[str]] = None, + column_oriented: bool = False, + settings: Optional[Dict[str, Any]] = None, + data: Optional[Sequence[Sequence[Any]]] = None, + transport_settings: Optional[Dict[str, str]] = None, + ) -> InsertContext: """ Builds a reusable insert context to hold state for a duration of an insert :param table: Target table @@ -836,60 +1487,393 @@ async def create_insert_context(self, :param transport_settings: Optional dictionary of transport level settings (HTTP headers, etc.) :return: Reusable insert context """ - - def _create_insert_context(): - return self.client.create_insert_context(table=table, column_names=column_names, database=database, - column_types=column_types, column_type_names=column_type_names, - column_oriented=column_oriented, settings=settings, data=data, - transport_settings=transport_settings) - - loop = asyncio.get_running_loop() - result = await loop.run_in_executor(self.executor, _create_insert_context) - return result - - async def data_insert(self, context: InsertContext) -> QuerySummary: + full_table = table + if "." not in table: + if database: + full_table = f"{quote_identifier(database)}.{quote_identifier(table)}" + else: + full_table = quote_identifier(table) + column_defs = [] + if column_types is None and column_type_names is None: + describe_result = await self.query(f"DESCRIBE TABLE {full_table}", settings=settings) + column_defs = [ + ColumnDef(**row) for row in describe_result.named_results() if row["default_type"] not in ("ALIAS", "MATERIALIZED") + ] + if column_names is None or isinstance(column_names, str) and column_names == "*": + column_names = [cd.name for cd in column_defs] + column_types = [cd.ch_type for cd in column_defs] + elif isinstance(column_names, str): + column_names = [column_names] + if len(column_names) == 0: + raise ValueError("Column names must be specified for insert") + if not column_types: + if column_type_names: + column_types = [get_from_name(name) for name in column_type_names] + else: + column_map = {d.name: d for d in column_defs} + try: + column_types = [column_map[name].ch_type for name in column_names] + except KeyError as ex: + raise ProgrammingError(f"Unrecognized column {ex} in table {table}") from None + if len(column_names) != len(column_types): + raise ProgrammingError("Column names do not match column types") from None + return InsertContext( + full_table, + column_names, + column_types, + column_oriented=column_oriented, + settings=settings, + transport_settings=transport_settings, + data=data, + ) + + async def data_insert(self, context: InsertContext) -> QuerySummary: # type: ignore[override] """ - Subclass implementation of the data insert - :context: InsertContext parameter object - :return: No return, throws an exception if the insert fails + See BaseClient doc_string for this method. + + Uses true streaming via reverse bridge pattern: + - Sync producer (serializer) runs in executor, puts blocks in queue + - Async consumer (network) pulls from queue and yields to aiohttp + - Bounded queue provides backpressure to prevent memory bloat """ + if context.empty: + logger.debug("No data included in insert, skipping") + return QuerySummary() - def _data_insert(): - return self.client.data_insert(context=context) + if context.compression is None: + context.compression = self.write_compression loop = asyncio.get_running_loop() - result = await loop.run_in_executor(self.executor, _data_insert) - return result - - async def raw_insert(self, table: str, - column_names: Optional[Sequence[str]] = None, - insert_block: Union[str, bytes, Generator[bytes, None, None], BinaryIO] = None, - settings: Optional[Dict] = None, - fmt: Optional[str] = None, - compression: Optional[str] = None, - transport_settings: Optional[Dict[str, str]] = None) -> QuerySummary: - """ - Insert data already formatted in a bytes object - :param table: Table name (whether qualified with the database name or not) - :param column_names: Sequence of column names - :param insert_block: Binary or string data already in a recognized ClickHouse format - :param settings: Optional dictionary of ClickHouse settings (key/string values) - :param compression: Recognized ClickHouse `Accept-Encoding` header compression value - :param transport_settings: Optional dictionary of transport level settings (HTTP headers, etc.) - :param fmt: Valid clickhouse format - """ - def _raw_insert(): - return self.client.raw_insert(table=table, column_names=column_names, insert_block=insert_block, - settings=settings, fmt=fmt, compression=compression, - transport_settings=transport_settings) + streaming_source = StreamingInsertSource( + transform=self._transform, context=context, loop=loop, maxsize=10 + ) - loop = asyncio.get_running_loop() - result = await loop.run_in_executor(self.executor, _raw_insert) - return result + streaming_source.start_producer() - async def __aenter__(self) -> "AsyncClient": - return self + headers = {"Content-Type": "application/octet-stream"} + if context.compression: + headers["Content-Encoding"] = context.compression - async def __aexit__(self, exc_type, exc_val, exc_tb) -> None: - await self.close() + params = {} + if self.database: + params["database"] = self.database + params.update(self._validate_settings(context.settings)) + headers = dict_copy(headers, context.transport_settings) + + try: + response = await self._raw_request( + streaming_source.async_generator(), params, headers=headers, server_wait=False + ) + logger.debug("Context insert response code: %d", response.status) + except Exception: + await streaming_source.close() + + if context.insert_exception: + ex = context.insert_exception + context.insert_exception = None + raise ex from None + raise + finally: + await streaming_source.close() + context.data = None + + return QuerySummary(self._summary(response)) + + async def insert_df( # type: ignore[override] + self, + table: Optional[str] = None, + df=None, + database: Optional[str] = None, + settings: Optional[Dict] = None, + column_names: Optional[Sequence[str]] = None, + column_types: Optional[Sequence[ClickHouseType]] = None, + column_type_names: Optional[Sequence[str]] = None, + context: Optional[InsertContext] = None, + transport_settings: Optional[Dict[str, str]] = None, + ) -> QuerySummary: + """ + Insert a pandas DataFrame into ClickHouse. If context is specified arguments other than df are ignored + :param table: ClickHouse table + :param df: two-dimensional pandas dataframe + :param database: Optional ClickHouse database + :param settings: Optional dictionary of ClickHouse settings (key/string values) + :param column_names: An optional list of ClickHouse column names. If not set, the DataFrame column names + will be used + :param column_types: ClickHouse column types. If set then column data does not need to be retrieved from + the server + :param column_type_names: ClickHouse column type names. If set then column data does not need to be + retrieved from the server + :param context: Optional reusable insert context to allow repeated inserts into the same table with + different data batches + :param transport_settings: Optional dictionary of transport level settings (HTTP headers, etc.) + :return: QuerySummary with summary information, throws exception if insert fails + """ + check_pandas() + self._add_integration_tag("pandas") + if context is None: + if column_names is None: + column_names = df.columns + elif len(column_names) != len(df.columns): + raise ProgrammingError("DataFrame column count does not match insert_columns") from None + return await self.insert( + table, + df, + column_names, + database, + column_types=column_types, + column_type_names=column_type_names, + settings=settings, + transport_settings=transport_settings, + context=context, + ) + + async def raw_insert( # type: ignore[override] + self, + table: Optional[str] = None, + column_names: Optional[Sequence[str]] = None, + insert_block: Optional[Union[str, bytes, Generator[bytes, None, None], BinaryIO]] = None, + settings: Optional[Dict] = None, + fmt: Optional[str] = None, + compression: Optional[str] = None, + transport_settings: Optional[Dict[str, str]] = None, + ) -> QuerySummary: + """ + See BaseClient doc_string for this method + """ + params = {} + headers = {"Content-Type": "application/octet-stream"} + if compression: + headers["Content-Encoding"] = compression + + if table: + cols = f" ({', '.join([quote_identifier(x) for x in column_names])})" if column_names is not None else "" + fmt_str = fmt if fmt else self._write_format + query = f"INSERT INTO {table}{cols} FORMAT {fmt_str}" + if not compression and isinstance(insert_block, str): + insert_block = query + "\n" + insert_block + elif not compression and isinstance(insert_block, (bytes, bytearray, BinaryIO)): + insert_block = (query + "\n").encode() + insert_block + else: + params["query"] = query + + if self.database: + params["database"] = self.database + params.update(self._validate_settings(settings or {})) + headers = dict_copy(headers, transport_settings) + + response = await self._raw_request(insert_block, params, headers, server_wait=False) + logger.debug("Raw insert response code: %d", response.status) + return QuerySummary(self._summary(response)) + + def _add_integration_tag(self, name: str): + """ + Dynamically adds a product (like pandas or sqlalchemy) to the User-Agent string details section. + """ + if not common.get_setting("send_integration_tags") or name in self._reported_libs: + return + + try: + ver = "unknown" + try: + ver = dist_version(name) + except Exception: + try: + mod = import_module(name) + ver = getattr(mod, "__version__", "unknown") + except Exception: + pass + + product_info = f"{name}/{ver}" + + ua = self.headers.get("User-Agent", "") + start = ua.find("(") + if start == -1: + return + end = ua.find(")", start + 1) + if end == -1: + return + + details = ua[start + 1 : end].strip() + + if product_info in details: + self._reported_libs.add(name) + return + + new_details = f"{product_info}; {details}" if details else product_info + new_ua = f"{ua[: start + 1]}{new_details}{ua[end:]}" + self.headers["User-Agent"] = new_ua.strip() + if self._session: + self._session.headers["User-Agent"] = new_ua.strip() + + self._reported_libs.add(name) + logger.debug("Added '%s' to User-Agent", product_info) + + except Exception as e: + logger.debug("Problem adding '%s' to User-Agent: %s", name, e) + + async def _error_handler(self, response: aiohttp.ClientResponse, retried: bool = False): + """ + Handles HTTP errors. Tries to be robust and provide maximum context. + """ + try: + body = "" + try: + raw_body = await response.read() + encoding = response.headers.get("Content-Encoding") + + if encoding: + loop = asyncio.get_running_loop() + + def decompress_and_decode(): + decompressed = decompress_response(raw_body, encoding) + return common.format_error(decompressed.decode(errors="backslashreplace")).strip() + + body = await loop.run_in_executor(None, decompress_and_decode) + else: + loop = asyncio.get_running_loop() + body = await loop.run_in_executor( + None, + lambda: common.format_error(raw_body.decode(errors="backslashreplace")).strip() + ) + except Exception: + logger.warning("Failed to read error response body", exc_info=True) + + if self.show_clickhouse_errors: + err_code = response.headers.get(ex_header) + if err_code: + err_str = f"Received ClickHouse exception, code: {err_code}" + else: + err_str = f"HTTP driver received HTTP status {response.status}" + + if body: + err_str = f"{err_str}, server response: {body}" + else: + err_str = "The ClickHouse server returned an error" + + err_str = f"{err_str} (for url {self.url})" + + finally: + response.close() + + raise OperationalError(err_str) if retried else DatabaseError(err_str) from None + + async def _raw_request( + self, + data, + params, + headers=None, + files=None, + method="POST", + stream=False, + server_wait=True, + retries: int = 0, + ) -> aiohttp.ClientResponse: + if self._session is None: + raise ProgrammingError( + "Session not initialized. Use 'async with get_async_client(...)' or call 'await client._initialize()' first." + ) + + reset_seconds = common.get_setting("max_connection_age") + if reset_seconds: + now = time.time() + if self._last_pool_reset is None: + self._last_pool_reset = now + elif self._last_pool_reset < now - reset_seconds: + logger.debug("connection expiration - resetting connection pool") + await self.close_connections() + self._last_pool_reset = now + + final_params = dict_copy(self._client_settings, params) + if server_wait: + final_params.setdefault("wait_end_of_query", "1") + if self._send_progress: + final_params.setdefault("send_progress_in_http_headers", "1") + if self._progress_interval: + final_params.setdefault("http_headers_progress_interval_ms", self._progress_interval) + if self._autogenerate_query_id and "query_id" not in final_params: + final_params["query_id"] = str(uuid.uuid4()) + + req_headers = dict_copy(self.headers, headers) + if self.server_host_name: + req_headers["Host"] = self.server_host_name + query_session = final_params.get("session_id") + attempts = 0 + + # pylint: disable=too-many-nested-blocks + while True: + attempts += 1 + + if query_session: + if query_session == self._active_session: + raise ProgrammingError( + "Attempt to execute concurrent queries within the same session. " + "Please use a separate client instance per concurrent query." + ) + self._active_session = query_session + + try: + # Construct full URL (aiohttp doesn't have base_url) + url = f"{self.url}/" + request_kwargs = {"method": method, "url": url, "params": final_params, "headers": req_headers} + if hasattr(self, "_proxy_url") and self._proxy_url: + request_kwargs["proxy"] = self._proxy_url + if files: + # IMPORTANT: Must set content_type on text fields to force multipart/form-data encoding + # Without content_type, aiohttp uses application/x-www-form-urlencoded + form = aiohttp.FormData() + for field_name, field_value in files.items(): + if isinstance(field_value, tuple): + if field_value[0] is None: + form.add_field(field_name, str(field_value[1]), content_type='text/plain') + else: + filename = field_value[0] + file_data = field_value[1] + content_type = field_value[2] if len(field_value) > 2 else None + form.add_field(field_name, file_data, filename=filename, content_type=content_type) + else: + form.add_field(field_name, field_value, content_type='text/plain') + request_kwargs["data"] = form + elif isinstance(data, dict): + request_kwargs["data"] = data + else: + request_kwargs["data"] = data + + response = await self._session.request(**request_kwargs) + if 200 <= response.status < 300 and not response.headers.get(ex_header): + return response + + if response.status in (429, 503, 504): + if attempts > retries: + await self._error_handler(response, retried=True) + else: + logger.debug("Retrying request with status code %s (attempt %s/%s)", response.status, attempts, retries + 1) + await asyncio.sleep(0.1 * attempts) + response.close() + continue + await self._error_handler(response) + + except aiohttp.ServerConnectionError as e: + if "Connection reset" in str(e) or "Remote end closed" in str(e) or "Cannot connect" in str(e): + if attempts == 1: + logger.debug("Retrying after connection error from remote host") + continue + raise OperationalError(f"Network Error: {str(e)}") from e + + except (aiohttp.ClientError, asyncio.TimeoutError) as e: + raise OperationalError(f"Network Error: {str(e)}") from e + + finally: + if query_session: + self._active_session = None + + @staticmethod + def _summary(response: aiohttp.ClientResponse): + summary = {} + if "X-ClickHouse-Summary" in response.headers: + try: + summary = json.loads(response.headers["X-ClickHouse-Summary"]) + except json.JSONDecodeError: + pass + summary["query_id"] = response.headers.get("X-ClickHouse-Query-Id", "") + return summary diff --git a/clickhouse_connect/driver/asyncqueue.py b/clickhouse_connect/driver/asyncqueue.py new file mode 100644 index 00000000..e052a16f --- /dev/null +++ b/clickhouse_connect/driver/asyncqueue.py @@ -0,0 +1,195 @@ +import asyncio +import threading +from collections import deque +from typing import Deque, Generic, Optional, TypeVar + +from clickhouse_connect.driver.exceptions import ProgrammingError + +__all__ = ["AsyncSyncQueue", "Empty", "Full", "EOF_SENTINEL"] + +T = TypeVar("T") + +EOF_SENTINEL = object() + + +# pylint: disable=too-many-instance-attributes +class AsyncSyncQueue(Generic[T]): + """High-performance bridge between AsyncIO and Threading.""" + + def __init__(self, maxsize: int = 100): + self._maxsize = maxsize + self._queue: Deque[T] = deque() + self._shutdown = False + self._loop: Optional[asyncio.AbstractEventLoop] = None + + self._lock = threading.Lock() + + self._sync_not_empty = threading.Condition(self._lock) + self._sync_not_full = threading.Condition(self._lock) + + self._async_getters: Deque[asyncio.Future] = deque() + self._async_putters: Deque[asyncio.Future] = deque() + + self.sync_q = _SyncQueueInterface(self) + self.async_q = _AsyncQueueInterface(self) + + def _bind_loop(self): + """Lazy-bind to the running loop on first async access.""" + if self._loop is None: + try: + self._loop = asyncio.get_running_loop() + except RuntimeError: + pass + + def _check_deadlock(self): + """Check if blocking would cause a deadlock on the event loop.""" + if self._loop is None: + return + + try: + current_loop = asyncio.get_running_loop() + if current_loop is self._loop: + raise ProgrammingError( + "Deadlock detected: Synchronous blocking operation called on event loop thread. " + "This usually happens when iterating a stream synchronously (e.g., 'for row in result') " + "instead of asynchronously ('async for row in result') inside an async function." + ) + except RuntimeError: + pass + + def _wakeup_async_waiter(self, waiter_queue: Deque[asyncio.Future]): + """Helper: Wake up the next async waiter in the queue safely.""" + while waiter_queue: + fut = waiter_queue.popleft() + if not fut.done(): + self._loop.call_soon_threadsafe(fut.set_result, None) + break + + def shutdown(self): + """Terminates the queue. All readers will receive EOF_SENTINEL.""" + with self._lock: + self._shutdown = True + + self._sync_not_empty.notify_all() + self._sync_not_full.notify_all() + + if self._loop and not self._loop.is_closed(): + for fut in list(self._async_getters): + if not fut.done(): + self._loop.call_soon_threadsafe(fut.set_result, None) + for fut in list(self._async_putters): + if not fut.done(): + self._loop.call_soon_threadsafe(fut.set_result, None) + self._async_getters.clear() + self._async_putters.clear() + + @property + def qsize(self) -> int: + with self._lock: + return len(self._queue) + + +# pylint: disable=protected-access +class _SyncQueueInterface(Generic[T]): + def __init__(self, parent: AsyncSyncQueue[T]): + self._p = parent + + def get(self, block: bool = True, timeout: Optional[float] = None) -> T: + with self._p._lock: + while not self._p._queue and not self._p._shutdown: + if not block: + raise Empty() + + self._p._check_deadlock() + if not self._p._sync_not_empty.wait(timeout): + raise Empty() + + if not self._p._queue and self._p._shutdown: + return EOF_SENTINEL + + item = self._p._queue.popleft() + self._p._sync_not_full.notify() + self._p._wakeup_async_waiter(self._p._async_putters) + + return item + + def put(self, item: T, block: bool = True, timeout: Optional[float] = None) -> None: + with self._p._lock: + if self._p._shutdown: + raise RuntimeError("Queue is shutdown") + + while self._p._maxsize > 0 and len(self._p._queue) >= self._p._maxsize: + if not block: + raise Full() + + self._p._check_deadlock() + if not self._p._sync_not_full.wait(timeout): + raise Full() + if self._p._shutdown: + raise RuntimeError("Queue is shutdown") + + self._p._queue.append(item) + + self._p._sync_not_empty.notify() + self._p._wakeup_async_waiter(self._p._async_getters) + + +class _AsyncQueueInterface(Generic[T]): + def __init__(self, parent: AsyncSyncQueue[T]): + self._p = parent + + async def get(self) -> T: + self._p._bind_loop() + while True: + with self._p._lock: + if self._p._queue: + item = self._p._queue.popleft() + self._p._sync_not_full.notify() + self._p._wakeup_async_waiter(self._p._async_putters) + return item + + if self._p._shutdown: + return EOF_SENTINEL + + fut = self._p._loop.create_future() + self._p._async_getters.append(fut) + + try: + await fut + except asyncio.CancelledError: + with self._p._lock: + if fut in self._p._async_getters: + self._p._async_getters.remove(fut) + raise + + async def put(self, item: T) -> None: + self._p._bind_loop() + while True: + with self._p._lock: + if self._p._shutdown: + raise RuntimeError("Queue is shutdown") + + if self._p._maxsize <= 0 or len(self._p._queue) < self._p._maxsize: + self._p._queue.append(item) + self._p._sync_not_empty.notify() + self._p._wakeup_async_waiter(self._p._async_getters) + return + + fut = self._p._loop.create_future() + self._p._async_putters.append(fut) + + try: + await fut + except asyncio.CancelledError: + with self._p._lock: + if fut in self._p._async_putters: + self._p._async_putters.remove(fut) + raise + + +class Empty(Exception): + pass + + +class Full(Exception): + pass diff --git a/clickhouse_connect/driver/client.py b/clickhouse_connect/driver/client.py index 234f0e46..df6145b1 100644 --- a/clickhouse_connect/driver/client.py +++ b/clickhouse_connect/driver/client.py @@ -124,8 +124,8 @@ class Client(ABC): """ Base ClickHouse Connect client """ - compression: str = None - write_compression: str = None + compression: Optional[str] = None + write_compression: Optional[str] = None protocol_version = 0 valid_transport_settings = set() optional_transport_settings = set() @@ -190,7 +190,7 @@ def utc_tz_aware(self) -> Union[bool, Literal["schema"]]: return _TZ_MODE_TO_UTC_TZ_AWARE[self.tz_mode] def __init__(self, - database: str, + database: Optional[str], query_limit: int, uri: str, query_retries: int, @@ -199,7 +199,8 @@ def __init__(self, tz_mode: Optional[TzMode] = None, show_clickhouse_errors: Optional[bool] = None, utc_tz_aware: Optional[Union[bool, Literal["schema"]]] = None, - apply_server_timezone: Optional[Union[str, bool]] = None): + apply_server_timezone: Optional[Union[str, bool]] = None, + autoconnect: bool = True): """ Shared initialization of ClickHouse Connect client :param database: database name @@ -214,6 +215,8 @@ def __init__(self, for bare DateTime columns. :param utc_tz_aware: Deprecated. Use tz_mode instead. :param apply_server_timezone: Deprecated. Use tz_source instead. + :param autoconnect: If True, immediately connect to server and fetch settings. If False, + defer connection to _connect() method. Used by async clients to avoid blocking I/O in __init__. """ self.query_limit = coerce_int(query_limit) self.query_retries = coerce_int(query_retries) @@ -226,7 +229,17 @@ def __init__(self, self.tz_mode = _resolve_tz_mode(tz_mode, utc_tz_aware) resolved_tz_source = _resolve_tz_source(tz_source, apply_server_timezone) self._tz_source = resolved_tz_source - self._init_common_settings(resolved_tz_source) + + # Initialize attributes that will be set during connection + self.server_version = None + self.server_tz = pytz.UTC + self.server_settings = {} + + if autoconnect: + self._init_common_settings(resolved_tz_source) + else: + # Store for deferred async initialization + self._deferred_tz_source = resolved_tz_source def _init_common_settings(self, tz_source: TzSource): self.server_tz, self._dst_safe = pytz.UTC, True @@ -495,9 +508,9 @@ def raw_stream(self, query: str, fmt: str = None, use_database: bool = True, external_data: Optional[ExternalData] = None, - transport_settings: Optional[Dict[str, str]] = None) -> io.IOBase: + transport_settings: Optional[Dict[str, str]] = None) -> Union[io.IOBase, StreamContext]: """ - Query method that returns the result as an io.IOBase iterator + Query method that returns the result as a stream iterator. :param query: Query statement/format string :param parameters: Optional dictionary used to format the query :param settings: Optional dictionary of ClickHouse settings (key/string values) @@ -506,7 +519,7 @@ def raw_stream(self, query: str, database context. :param external_data: External data to send with the query. :param transport_settings: Optional dictionary of transport level settings (HTTP headers, etc.) - :return: io.IOBase stream/iterator for the result + :return: io.IOBase (sync) or StreamContext (async) - both support iteration over raw bytes """ # pylint: disable=duplicate-code,unused-argument diff --git a/clickhouse_connect/driver/common.py b/clickhouse_connect/driver/common.py index 15ff2bdf..820483c3 100644 --- a/clickhouse_connect/driver/common.py +++ b/clickhouse_connect/driver/common.py @@ -1,6 +1,7 @@ import array import struct import sys +import asyncio from typing import Any, Sequence, MutableSequence, Dict, Optional, Union, Generator, Callable @@ -190,7 +191,8 @@ def __eq__(self, other): class StreamContext: """ Wraps a generator and its "source" in a Context. This ensures that the source will be "closed" even if the - generator is not fully consumed or there is an exception during consumption + generator is not fully consumed or there is an exception during consumption. Supports both synchronous and + asynchronous usage. """ __slots__ = 'source', 'gen', '_in_context' @@ -218,6 +220,58 @@ def __exit__(self, exc_type, exc_val, exc_tb): self.source.close() self.gen = None + def __aiter__(self): + return self + + async def __anext__(self): + if not self._in_context: + raise ProgrammingError("Stream should be used within a context") + try: + if hasattr(self.gen, "__anext__"): + return await self.gen.__anext__() + + def _next_wrapper(): + try: + return True, self.gen.__next__() + except StopIteration: + return False, None + + loop = asyncio.get_running_loop() + has_value, value = await loop.run_in_executor(None, _next_wrapper) + if not has_value: + raise StopAsyncIteration from None + return value + except (StopAsyncIteration, StopIteration): + raise StopAsyncIteration from None + except Exception as ex: + if not isinstance(ex, StreamClosedError): + self._in_context = False + if hasattr(self.source, "close"): + if hasattr(self.source.close, "__await__"): + await self.source.close() + else: + self.source.close() + self.gen = None + raise ex + + async def __aenter__(self): + if not self.gen: + raise StreamClosedError + self._in_context = True + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + self._in_context = False + if hasattr(self.source, 'aclose'): + await self.source.aclose() + elif hasattr(self.source, 'close'): + if hasattr(self.source.close, '__await__'): + await self.source.close() + else: + self.source.close() + self.gen = None + + # pylint: disable=too-many-return-statements def get_rename_method(method: Optional[str]) -> Optional[Callable[[str], str]]: def _to_camel(s: str) -> str: diff --git a/clickhouse_connect/driver/httpclient.py b/clickhouse_connect/driver/httpclient.py index f456e3c9..55cf8c66 100644 --- a/clickhouse_connect/driver/httpclient.py +++ b/clickhouse_connect/driver/httpclient.py @@ -190,7 +190,8 @@ def __init__(self, tz_mode=tz_mode, utc_tz_aware=utc_tz_aware, apply_server_timezone=apply_server_timezone, - show_clickhouse_errors=show_clickhouse_errors) + show_clickhouse_errors=show_clickhouse_errors, + autoconnect=True) self.params = dict_copy(self.params, self._validate_settings(ch_settings)) cancel_setting = self._setting_status("cancel_http_readonly_queries_on_client_close") if cancel_setting.is_writable and not cancel_setting.is_set and \ @@ -449,42 +450,31 @@ def _error_handler(self, response: HTTPResponse, retried: bool = False) -> None: """ try: body = "" - # Always try to read the response body for context. try: - # get_response_data reads body and decodes it for the error message raw_body = get_response_data(response) body = common.format_error( raw_body.decode(errors="backslashreplace") ).strip() except Exception: # pylint: disable=broad-except - # If we can't read or decode the body, we'll proceed without it logger.warning("Failed to read error response body", exc_info=True) - # Build the error message if self.show_clickhouse_errors: err_code = response.headers.get(ex_header) if err_code: - # Prioritize the specific ClickHouse exception code if it exists. err_str = f"Received ClickHouse exception, code: {err_code}" else: - # Otherwise, just use the generic HTTP status err_str = f"HTTP driver received HTTP status {response.status}" if body: - # Always append the body if it exists err_str = f"{err_str}, server response: {body}" else: - # Simple message for when detailed errors are disabled err_str = "The ClickHouse server returned an error" - # Add the URL for additional context err_str = f"{err_str} (for url {self.url})" finally: - # Ensure closed response to prevent resource leaks response.close() - # Raise the appropriate exception class raise OperationalError(err_str) if retried else DatabaseError(err_str) from None def _raw_request(self, diff --git a/clickhouse_connect/driver/query.py b/clickhouse_connect/driver/query.py index 5ce5f8bc..6a0f7d41 100644 --- a/clickhouse_connect/driver/query.py +++ b/clickhouse_connect/driver/query.py @@ -431,12 +431,20 @@ def result_set(self) -> Matrix: @property def result_columns(self) -> Matrix: if self._result_columns is None: - result = [[] for _ in range(len(self.column_names))] - with self.column_block_stream as stream: - for block in stream: - for base, added in zip(result, block): - base.extend(added) - self._result_columns = result + # If rows are already materialized and stream is closed, transpose from rows + # This happens when async client eagerly materializes result_rows + if self._result_rows is not None and self._block_gen is None: + if self._result_rows: + self._result_columns = list(map(list, zip(*self._result_rows))) + else: + self._result_columns = [[] for _ in range(len(self.column_names))] + else: + result = [[] for _ in range(len(self.column_names))] + with self.column_block_stream as stream: + for block in stream: + for base, added in zip(result, block): + base.extend(added) + self._result_columns = result return self._result_columns @property diff --git a/clickhouse_connect/driver/streaming.py b/clickhouse_connect/driver/streaming.py new file mode 100644 index 00000000..476083e7 --- /dev/null +++ b/clickhouse_connect/driver/streaming.py @@ -0,0 +1,319 @@ +# pylint: disable=import-error + +import asyncio +import logging +import threading +import zlib +from typing import Iterator, Optional + +import lz4.frame +import zstandard + +from clickhouse_connect.driver.asyncqueue import EOF_SENTINEL, AsyncSyncQueue +from clickhouse_connect.driver.compression import available_compression +from clickhouse_connect.driver.exceptions import OperationalError +from clickhouse_connect.driver.types import Closable + +logger = logging.getLogger(__name__) + +__all__ = ["StreamingResponseSource", "StreamingFileAdapter", "StreamingInsertSource"] + +if "br" in available_compression: + import brotli +else: + brotli = None + + +# pylint: disable=too-many-instance-attributes, broad-exception-caught +class StreamingResponseSource(Closable): + """Streaming source that feeds chunks from async producer to sync consumer.""" + + READ_BUFFER_SIZE = 1024 * 1024 + + def __init__(self, response, encoding: Optional[str] = None, exception_tag: Optional[str] = None): + self.response = response + self.encoding = encoding + self.exception_tag = exception_tag + + # maxsize=10 means max ~10 socket reads buffered + self.queue = AsyncSyncQueue(maxsize=10) + + self._decompressor = None + self._decompressor_initialized = False + + # Multiple accesses to .gen must return the same generator, not create new ones + self._gen_cache = None + + self._producer_task = None + self._producer_started = threading.Event() + self._producer_error: Optional[Exception] = None + self._producer_completed = False + + async def start_producer(self, loop: asyncio.AbstractEventLoop): + """Start the async producer task. + Must be called from the event loop thread before consuming. + """ + + async def producer(): + """Async producer: reads chunks from response, feeds queue.""" + data_sent = False + try: + while True: + chunk = await self.response.content.read(self.READ_BUFFER_SIZE) + if not chunk: + break + data_sent = True + await self.queue.async_q.put(chunk) + + await self.queue.async_q.put(EOF_SENTINEL) + self._producer_completed = True + + except Exception as e: + logger.error("Producer error while streaming response: %s", e, exc_info=True) + if not data_sent: + e = OperationalError("Failed to read response data from server") + self._producer_error = e + + try: + await self.queue.async_q.put(e) + except RuntimeError: + pass + + finally: + self.queue.shutdown() + + self._producer_task = loop.create_task(producer()) + self._producer_started.set() + + @property + def gen(self) -> Iterator[bytes]: + """Generator that yields decompressed chunks. + + CRITICAL: Returns cached generator to prevent multiple generators + from competing to read from the same queue. + """ + if self._gen_cache is not None: + return self._gen_cache + + self._gen_cache = self._create_generator() + return self._gen_cache + + # pylint: disable=too-many-branches + def _create_generator(self) -> Iterator[bytes]: + """Creates the actual generator function.""" + if not self._producer_started.wait(timeout=5.0): + raise RuntimeError("Producer failed to start within timeout") + + if self.encoding and not self._decompressor_initialized: + self._decompressor_initialized = True + try: + self._decompressor = self._create_decompressor(self.encoding) + except Exception as e: + logger.error("Failed to create decompressor for %s: %s", self.encoding, e) + raise + + # pylint: disable=too-many-nested-blocks + while True: + chunk = self.queue.sync_q.get() + + if chunk is EOF_SENTINEL: + if self._decompressor: + try: + if hasattr(self._decompressor, "flush"): + final = self._decompressor.flush() + if final: + yield final + except Exception as e: + logger.error("Error flushing decompressor: %s", e, exc_info=True) + raise + break + + if isinstance(chunk, Exception): + raise chunk + + if self._decompressor: + try: + if hasattr(self._decompressor, "decompress"): + decompressed = self._decompressor.decompress(chunk) + else: + decompressed = self._decompressor.process(chunk) + if decompressed: + yield decompressed + except Exception as e: + logger.error("Decompression error: %s", e, exc_info=True) + raise + else: + yield chunk + + @staticmethod + def _create_decompressor(encoding: str): + """Create incremental decompressor for encoding.""" + if encoding == "gzip": + return zlib.decompressobj(16 + zlib.MAX_WBITS) + + if encoding == "deflate": + return zlib.decompressobj() + + if encoding == "br": + if brotli is not None: + return brotli.Decompressor() + raise ImportError("brotli compression requires 'brotli' package. Install with: pip install brotli") + + if encoding == "zstd": + return zstandard.ZstdDecompressor().decompressobj() + + if encoding == "lz4": + return lz4.frame.LZ4FrameDecompressor() + + raise ValueError(f"Unsupported compression encoding: {encoding}") + + async def aclose(self): + """Async cleanup resources""" + self.queue.shutdown() + + if self._producer_task and not self._producer_task.done(): + self._producer_task.cancel() + try: + await self._producer_task + except asyncio.CancelledError: + pass + except Exception: + pass + + if self.response and not self.response.closed: + if not self._producer_completed: + self.response.close() + await asyncio.sleep(0.05) + + def close(self): + """Synchronous cleanup resources""" + self.queue.shutdown() + + if self._producer_task and not self._producer_task.done(): + self._producer_task.cancel() + + if self.response and not self.response.closed: + if not self._producer_completed: + self.response.close() + + +class StreamingFileAdapter: + """File-like adapter for PyArrow streaming.""" + + def __init__(self, streaming_source): + self.streaming_source = streaming_source + self.gen = streaming_source.gen + self.buffer = b"" + self.closed = False + self.eof = False + + def read(self, size: int = -1) -> bytes: + """Read up to size bytes from stream""" + if self.closed or self.eof: + return b"" + + if size != -1 and len(self.buffer) >= size: + result = self.buffer[:size] + self.buffer = self.buffer[size:] + return result + + chunks = [self.buffer] if self.buffer else [] + current_len = len(self.buffer) + self.buffer = b"" + + while (size == -1 or current_len < size) and not self.eof: + try: + chunk = next(self.gen) + if chunk: + chunks.append(chunk) + current_len += len(chunk) + else: + self.eof = True + break + except StopIteration: + self.eof = True + break + + full_data = b"".join(chunks) + + if size == -1 or len(full_data) <= size: + return full_data + + result = full_data[:size] + self.buffer = full_data[size:] + return result + + def close(self): + self.closed = True + + +class StreamingInsertSource: + """Streaming source for async inserts (reverse bridge)""" + + def __init__(self, transform, context, loop: asyncio.AbstractEventLoop, maxsize: int = 10): + self.transform = transform + self.context = context + self.loop = loop + self.queue = AsyncSyncQueue(maxsize=maxsize) + self._producer_future = None + self._started = False + + def start_producer(self): + if self._started: + raise RuntimeError("Producer already started") + self._started = True + + def producer(): + try: + for block in self.transform.build_insert(self.context): + self.queue.sync_q.put(block) + + self.queue.sync_q.put(EOF_SENTINEL) + + except Exception as e: + logger.error("Insert producer error: %s", e, exc_info=True) + try: + self.queue.sync_q.put(e) + except Exception: + pass + finally: + self.queue.shutdown() + + self._producer_future = self.loop.run_in_executor(None, producer) + + async def async_generator(self): + """Async generator that yields blocks for aiohttp streaming.""" + if not self._started: + raise RuntimeError("Producer not started, call start_producer() first") + + try: + while True: + chunk = await self.queue.async_q.get() + + if chunk is EOF_SENTINEL: + break + + if isinstance(chunk, Exception): + raise chunk + + yield chunk + + except Exception as e: + logger.error("Insert consumer error: %s", e, exc_info=True) + raise + finally: + if self._producer_future and not self._producer_future.done(): + try: + await self._producer_future + except Exception: + pass + + async def close(self): + self.queue.shutdown() + if self._producer_future and not self._producer_future.done(): + try: + await asyncio.wait_for(self._producer_future, timeout=1.0) + except asyncio.TimeoutError: + logger.warning("Insert producer did not finish within timeout") + except Exception: + pass diff --git a/clickhouse_connect/driver/tools.py b/clickhouse_connect/driver/tools.py index 42480858..d26e55a1 100644 --- a/clickhouse_connect/driver/tools.py +++ b/clickhouse_connect/driver/tools.py @@ -1,3 +1,4 @@ +import asyncio from typing import Optional, Sequence, Dict, Any from clickhouse_connect.driver import Client @@ -31,3 +32,38 @@ def insert_file(client: Client, fmt=fmt, settings=settings, compression=compression) + + +async def insert_file_async(client, + table: str, + file_path: str, + fmt: Optional[str] = None, + column_names: Optional[Sequence[str]] = None, + database: Optional[str] = None, + settings: Optional[Dict[str, Any]] = None, + compression: Optional[str] = None) -> QuerySummary: + + if not database and table[0] not in ('`', "'") and table.find('.') > 0: + full_table = table + elif database: + full_table = f'{quote_identifier(database)}.{quote_identifier(table)}' + else: + full_table = quote_identifier(table) + if not fmt: + fmt = 'CSV' if column_names else 'CSVWithNames' + if compression is None: + if file_path.endswith('.gzip') or file_path.endswith('.gz'): + compression = 'gzip' + + def read_file(): + with open(file_path, 'rb') as file: + return file.read() + + file_data = await asyncio.to_thread(read_file) + + return await client.raw_insert(full_table, + column_names=column_names, + insert_block=file_data, + fmt=fmt, + settings=settings, + compression=compression) diff --git a/clickhouse_connect/driver/transform.py b/clickhouse_connect/driver/transform.py index ce4e9b0a..51bc0c0b 100644 --- a/clickhouse_connect/driver/transform.py +++ b/clickhouse_connect/driver/transform.py @@ -71,6 +71,20 @@ def get_block(): if not error_msg: error_msg = extract_error_message(source.last_message) raise StreamFailureError(error_msg) from None + raise StreamFailureError("Stream ended unexpectedly (connection closed by server)") from ex + + # Handle async streaming errors (ClientPayloadError from aiohttp) + if ex.__class__.__name__ == "ClientPayloadError": + if source.last_message: + error_msg = None + exception_tag = getattr(source, "exception_tag", None) + if exception_tag: + error_msg = extract_exception_with_tag(source.last_message, exception_tag) + if not error_msg: + error_msg = extract_error_message(source.last_message) + raise StreamFailureError(error_msg) from None + raise StreamFailureError("Stream failed during read (connection closed by server)") from ex + raise block_num += 1 return result_block diff --git a/examples/run_async.py b/examples/run_async.py index 069e3f90..7728bce8 100644 --- a/examples/run_async.py +++ b/examples/run_async.py @@ -1,51 +1,37 @@ #!/usr/bin/env python -u """ -This example will execute 10 queries in total, 2 concurrent queries at a time. -Each query will sleep for 2 seconds before returning. -Here's a sample output that shows that the queries are executed concurrently in batches of 2: -``` -Completed query 1, elapsed ms since start: 2002 -Completed query 0, elapsed ms since start: 2002 -Completed query 3, elapsed ms since start: 4004 -Completed query 2, elapsed ms since start: 4005 -Completed query 4, elapsed ms since start: 6006 -Completed query 5, elapsed ms since start: 6007 -Completed query 6, elapsed ms since start: 8009 -Completed query 7, elapsed ms since start: 8009 -Completed query 9, elapsed ms since start: 10011 -Completed query 8, elapsed ms since start: 10011 -``` +Demonstrates concurrent async queries using clickhouse-connect. + +Executes 10 queries with a concurrency limit of 2. Each query sleeps for 2 seconds, +so the total wall time is ~10 seconds rather than ~20. + +Sample output: + Completed query 1, elapsed: 2002ms + Completed query 0, elapsed: 2003ms + Completed query 3, elapsed: 4005ms + Completed query 2, elapsed: 4005ms + ... """ import asyncio -from datetime import datetime +import time import clickhouse_connect -QUERIES = 10 -SEMAPHORE = 2 - async def concurrent_queries(): - test_query = "SELECT sleep(2)" - client = await clickhouse_connect.get_async_client() - - start = datetime.now() - - async def semaphore_wrapper(sm: asyncio.Semaphore, num: int): - async with sm: - await client.query(query=test_query) - print(f"Completed query {num}, " - f"elapsed ms since start: {int((datetime.now() - start).total_seconds() * 1000)}") - - semaphore = asyncio.Semaphore(SEMAPHORE) - await asyncio.gather(*[semaphore_wrapper(semaphore, num) for num in range(QUERIES)]) - await client.close() + async with await clickhouse_connect.get_async_client() as client: + semaphore = asyncio.Semaphore(2) + start = time.monotonic() + async def run_query(num: int): + async with semaphore: + await client.query("SELECT sleep(2)") + elapsed = int((time.monotonic() - start) * 1000) + print(f"Completed query {num}, elapsed: {elapsed}ms") -async def main(): - await concurrent_queries() + await asyncio.gather(*(run_query(i) for i in range(10))) -asyncio.run(main()) +asyncio.run(concurrent_queries()) diff --git a/pyproject.toml b/pyproject.toml index 40b152e2..7854c502 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -11,3 +11,4 @@ log_cli = true log_cli_level = "INFO" env_files = ["test.env"] asyncio_default_fixture_loop_scope = "session" +addopts = "-n 4" diff --git a/setup.py b/setup.py index 7fe65086..265f5d65 100644 --- a/setup.py +++ b/setup.py @@ -72,6 +72,7 @@ def run_setup(try_c: bool = True): 'arrow': ['pyarrow>=22.0; python_version>="3.14"', 'pyarrow; python_version<"3.14"'], 'orjson': ['orjson'], 'tzlocal': ['tzlocal>=4.0'], + 'async': ['aiohttp>=3.8.0'], }, tests_require=['pytest'], entry_points={ diff --git a/tests/integration_tests/conftest.py b/tests/integration_tests/conftest.py index 12598689..a681628a 100644 --- a/tests/integration_tests/conftest.py +++ b/tests/integration_tests/conftest.py @@ -1,3 +1,6 @@ +# pylint: disable=duplicate-code + +import asyncio import sys import os import random @@ -8,12 +11,13 @@ import pytest_asyncio from pytest import fixture +from clickhouse_connect import get_async_client from clickhouse_connect import common from clickhouse_connect.driver.common import coerce_bool from clickhouse_connect.driver.exceptions import OperationalError from clickhouse_connect.tools.testing import TableContext from clickhouse_connect.driver.httpclient import HttpClient -from clickhouse_connect.driver import AsyncClient, Client, create_client +from clickhouse_connect.driver import Client, create_client from tests.helpers import PROJECT_ROOT_DIR @@ -35,6 +39,28 @@ class TestException(BaseException): pass +def make_client_config(test_config: TestConfig, **kwargs): + """Helper to build client config dict from test_config with optional overrides.""" + settings = kwargs.pop("settings", {}).copy() + if test_config.insert_quorum: + settings["insert_quorum"] = test_config.insert_quorum + elif test_config.cloud: + settings["select_sequential_consistency"] = 1 + + return { + "host": test_config.host, + "port": test_config.port, + "username": test_config.username, + "password": test_config.password, + "database": test_config.test_database, + "compress": test_config.compress, + "settings": settings, + **kwargs, + } + + +# pylint: disable=redefined-outer-name + @fixture(scope='session', autouse=True, name='test_config') def test_config_fixture() -> Iterator[TestConfig]: common.set_setting('max_connection_age', 15) # Make sure resetting connections doesn't break stuff @@ -92,6 +118,162 @@ def test_table_engine_fixture() -> Iterator[str]: yield 'MergeTree' +@fixture(scope="module") +def shared_loop(): + """Shared event loop for running async clients in sync test context.""" + loop = asyncio.new_event_loop() + yield loop + loop.close() + + +@fixture(params=["sync", "async"]) +def client_mode(request): + return request.param + + +@fixture +def call(client_mode, shared_loop): + """Wrapper to call functions in appropriate sync/async context.""" + if client_mode == "sync": + return lambda fn, *args, **kwargs: fn(*args, **kwargs) + return lambda fn, *args, **kwargs: shared_loop.run_until_complete(fn(*args, **kwargs)) + + +@fixture +def consume_stream(client_mode, call): + """Fixture to consume a stream in either sync or async mode.""" + + def _consume(stream, callback=None): + if client_mode == "sync": + with stream: + for item in stream: + if callback: + callback(item) + else: + + async def runner(): + async with stream: + async for item in stream: + if callback: + callback(item) + + call(runner) + + return _consume + + +@fixture +def client_factory(client_mode, test_config, shared_loop): + """Factory for creating clients with custom configuration in tests.""" + clients = [] + + def factory(**kwargs): + config = { + "host": test_config.host, + "port": test_config.port, + "username": test_config.username, + "password": test_config.password, + "database": test_config.test_database, + "compress": test_config.compress, + **kwargs, + } + + if client_mode == "sync": + client = create_client(**config) + if client.min_version("22.8"): + client.set_client_setting("database_replicated_enforce_synchronous_settings", 1) + if client.min_version("24.8") and (client.min_version("24.12") or not test_config.cloud): + client.set_client_setting("allow_experimental_json_type", 1) + client.set_client_setting("allow_experimental_dynamic_type", 1) + client.set_client_setting("allow_experimental_variant_type", 1) + if test_config.insert_quorum: + client.set_client_setting("insert_quorum", test_config.insert_quorum) + elif test_config.cloud: + client.set_client_setting("select_sequential_consistency", 1) + else: + client = shared_loop.run_until_complete(get_async_client(**config)) + if client.min_version("22.8"): + client.set_client_setting("database_replicated_enforce_synchronous_settings", "1") + if client.min_version("24.8"): + client.set_client_setting("allow_experimental_json_type", "1") + client.set_client_setting("allow_experimental_dynamic_type", "1") + client.set_client_setting("allow_experimental_variant_type", "1") + if test_config.insert_quorum: + client.set_client_setting("insert_quorum", str(test_config.insert_quorum)) + elif test_config.cloud: + client.set_client_setting("select_sequential_consistency", "1") + + clients.append(client) + return client + + yield factory + + for client in clients: + try: + if client_mode == "sync": + client.close() + else: + shared_loop.run_until_complete(client.close()) + except Exception: # pylint: disable=broad-exception-caught + pass + + +@fixture +def param_client(client_mode, test_config, shared_loop): + """Provides client based on client_mode parameter.""" + if client_mode == "sync": + client = create_client( + host=test_config.host, + port=test_config.port, + username=test_config.username, + password=test_config.password, + database=test_config.test_database, + compress=test_config.compress, + settings={"allow_suspicious_low_cardinality_types": 1}, + client_name="int_tests/param_sync", + ) + if client.min_version("22.8"): + client.set_client_setting("database_replicated_enforce_synchronous_settings", 1) + if client.min_version("24.8") and (client.min_version("24.12") or not test_config.cloud): + client.set_client_setting("allow_experimental_json_type", 1) + client.set_client_setting("allow_experimental_dynamic_type", 1) + client.set_client_setting("allow_experimental_variant_type", 1) + if test_config.insert_quorum: + client.set_client_setting("insert_quorum", test_config.insert_quorum) + elif test_config.cloud: + client.set_client_setting("select_sequential_consistency", 1) + + yield client + client.close() + else: + client = shared_loop.run_until_complete( + get_async_client( + host=test_config.host, + port=test_config.port, + username=test_config.username, + password=test_config.password, + database=test_config.test_database, + compress=test_config.compress, + settings={"allow_suspicious_low_cardinality_types": 1}, + client_name="int_tests/param_async", + ) + ) + + if client.min_version("22.8"): + client.set_client_setting("database_replicated_enforce_synchronous_settings", "1") + if client.min_version("24.8"): + client.set_client_setting("allow_experimental_json_type", "1") + client.set_client_setting("allow_experimental_dynamic_type", "1") + client.set_client_setting("allow_experimental_variant_type", "1") + if test_config.insert_quorum: + client.set_client_setting("insert_quorum", str(test_config.insert_quorum)) + elif test_config.cloud: + client.set_client_setting("select_sequential_consistency", "1") + + yield client + shared_loop.run_until_complete(client.close()) + + # pylint: disable=too-many-branches @fixture(scope='session', autouse=True, name='test_client') def test_client_fixture(test_config: TestConfig, test_create_client: Callable) -> Iterator[Client]: @@ -142,9 +324,29 @@ def test_client_fixture(test_config: TestConfig, test_create_client: Callable) - sys.stderr.write('Successfully stopped docker compose') -@pytest_asyncio.fixture(scope='session', autouse=True, name='test_async_client') -async def test_async_client_fixture(test_client: Client) -> AsyncContextManager[AsyncClient]: - async with AsyncClient(client=test_client) as client: +@pytest_asyncio.fixture(scope="function", loop_scope="function", name="test_native_async_client") +async def test_native_async_client_fixture(test_config: TestConfig) -> AsyncContextManager: + """Function-scoped async client fixture""" + async with await get_async_client( + host=test_config.host, + port=test_config.port, + username=test_config.username, + password=test_config.password, + database=test_config.test_database, + compress=test_config.compress, + client_name="int_tests/native_async", + ) as client: + if client.min_version("22.8"): + client.set_client_setting("database_replicated_enforce_synchronous_settings", "1") + if client.min_version("24.8"): + client.set_client_setting("allow_experimental_json_type", "1") + client.set_client_setting("allow_experimental_dynamic_type", "1") + client.set_client_setting("allow_experimental_variant_type", "1") + if test_config.insert_quorum: + client.set_client_setting("insert_quorum", str(test_config.insert_quorum)) + elif test_config.cloud: + client.set_client_setting("select_sequential_consistency", "1") + yield client diff --git a/tests/integration_tests/test_arrow.py b/tests/integration_tests/test_arrow.py index bf051b6e..35706963 100644 --- a/tests/integration_tests/test_arrow.py +++ b/tests/integration_tests/test_arrow.py @@ -8,18 +8,18 @@ from clickhouse_connect.driver.options import arrow # pylint: disable=no-name-in-module -def test_arrow(test_client: Client, table_context: Callable): +def test_arrow(param_client: Client, call, table_context: Callable): if not arrow: pytest.skip('PyArrow package not available') - if not test_client.min_version('21'): - pytest.skip(f'PyArrow is not supported in this server version {test_client.server_version}') + if not param_client.min_version('21'): + pytest.skip(f'PyArrow is not supported in this server version {param_client.server_version}') with table_context('test_arrow_insert', ['animal String', 'legs Int64']): n_legs = arrow.array([2, 4, 5, 100] * 50) animals = arrow.array(['Flamingo', 'Horse', 'Brittle stars', 'Centipede'] * 50) names = ['legs', 'animal'] insert_table = arrow.Table.from_arrays([n_legs, animals], names=names) - test_client.insert_arrow('test_arrow_insert', insert_table) - result_table = test_client.query_arrow('SELECT * FROM test_arrow_insert', use_strings=False) + call(param_client.insert_arrow, 'test_arrow_insert', insert_table) + result_table = call(param_client.query_arrow, 'SELECT * FROM test_arrow_insert', use_strings=False) arrow_schema = result_table.schema assert arrow_schema.field(0).name == 'animal' assert arrow_schema.field(0).type == arrow.binary() @@ -29,7 +29,7 @@ def test_arrow(test_client: Client, table_context: Callable): assert arrow.compute.sum(result_table['legs']).as_py() == 5550 assert len(result_table.columns) == 2 - arrow_table = test_client.query_arrow('SELECT number from system.numbers LIMIT 500', + arrow_table = call(param_client.query_arrow, 'SELECT number from system.numbers LIMIT 500', settings={'max_block_size': 50}) arrow_schema = arrow_table.schema assert arrow_schema.field(0).name == 'number' @@ -37,21 +37,26 @@ def test_arrow(test_client: Client, table_context: Callable): assert arrow_table.num_rows == 500 -def test_arrow_stream(test_client: Client, table_context: Callable): +def test_arrow_stream(param_client: Client, call, table_context, consume_stream): if not arrow: pytest.skip('PyArrow package not available') - if not test_client.min_version('21'): - pytest.skip(f'PyArrow is not supported in this server version {test_client.server_version}') + if not param_client.min_version('21'): + pytest.skip(f'PyArrow is not supported in this server version {param_client.server_version}') with table_context('test_arrow_insert', ['counter Int64', 'letter String']): counter = arrow.array(range(1000000)) alphabet = string.ascii_lowercase letter = arrow.array([alphabet[x % 26] for x in range(1000000)]) names = ['counter', 'letter'] insert_table = arrow.Table.from_arrays([counter, letter], names=names) - test_client.insert_arrow('test_arrow_insert', insert_table) - stream = test_client.query_arrow_stream('SELECT * FROM test_arrow_insert', use_strings=True) - with stream: - result_tables = list(stream) + call(param_client.insert_arrow, 'test_arrow_insert', insert_table) + stream = call(param_client.query_arrow_stream, 'SELECT * FROM test_arrow_insert', use_strings=True) + result_tables = [] + + def process(table): + result_tables.append(table) + + consume_stream(stream, process) + # Hopefully we made the table long enough we got multiple tables in the query assert len(result_tables) > 1 total_rows = 0 @@ -67,20 +72,20 @@ def test_arrow_stream(test_client: Client, table_context: Callable): assert total_rows == 1000000 -def test_arrow_map(test_client: Client, table_context: Callable): +def test_arrow_map(param_client: Client, call, table_context: Callable): if not arrow: pytest.skip('PyArrow package not available') - if not test_client.min_version('21'): - pytest.skip(f'PyArrow is not supported in this server version {test_client.server_version}') + if not param_client.min_version('21'): + pytest.skip(f'PyArrow is not supported in this server version {param_client.server_version}') with table_context('test_arrow_map', ['trade_date Date, code String', 'kdj Map(String, Float32)', 'update_time DateTime DEFAULT now()']): data = [[date(2023, 10, 15), 'C1', {'k': 2.5, 'd': 0, 'j': 0}], [date(2023, 10, 16), 'C2', {'k': 3.5, 'd': 0, 'j': -.372}]] - test_client.insert('test_arrow_map', data, column_names=('trade_date', 'code', 'kdj'), + call(param_client.insert, 'test_arrow_map', data, column_names=('trade_date', 'code', 'kdj'), settings={'insert_deduplication_token': '10381'}) - arrow_table = test_client.query_arrow('SELECT * FROM test_arrow_map ORDER BY trade_date', + arrow_table = call(param_client.query_arrow, 'SELECT * FROM test_arrow_map ORDER BY trade_date', use_strings=True) assert isinstance(arrow_table.schema, arrow.Schema) - test_client.insert_arrow('test_arrow_map', arrow_table, settings={'insert_deduplication_token': '10382'}) - assert 4 == test_client.command('SELECT count() FROM test_arrow_map') + call(param_client.insert_arrow, 'test_arrow_map', arrow_table, settings={'insert_deduplication_token': '10382'}) + assert 4 == call(param_client.command, 'SELECT count() FROM test_arrow_map') diff --git a/tests/integration_tests/test_async_client.py b/tests/integration_tests/test_async_client.py deleted file mode 100644 index 64e5f215..00000000 --- a/tests/integration_tests/test_async_client.py +++ /dev/null @@ -1,302 +0,0 @@ -""" -AsyncClient tests that verify that the wrapper for each method is working correctly. -""" - -from typing import Callable - -import numpy as np -import pandas as pd -import pytest - -from clickhouse_connect.driver.options import arrow, IS_PANDAS_2 # pylint: disable=no-name-in-module -from clickhouse_connect.driver.exceptions import ProgrammingError -from clickhouse_connect.driver import AsyncClient - - -@pytest.mark.asyncio -async def test_client_settings(test_async_client: AsyncClient): - key = 'prefer_column_name_to_alias' - value = '1' - test_async_client.set_client_setting(key, value) - assert test_async_client.get_client_setting(key) == value - - -@pytest.mark.asyncio -async def test_min_version(test_async_client: AsyncClient): - assert test_async_client.min_version('19') is True - assert test_async_client.min_version('22.4') is True - assert test_async_client.min_version('99999') is False - - -@pytest.mark.asyncio -async def test_query(test_async_client: AsyncClient): - result = await test_async_client.query('SELECT * FROM system.tables') - assert len(result.result_set) > 0 - assert result.row_count > 0 - assert result.first_item == next(result.named_results()) - - -stream_query = 'SELECT number, randomStringUTF8(50) FROM numbers(10000)' -stream_settings = {'max_block_size': 4000} - - -# pylint: disable=duplicate-code -@pytest.mark.asyncio -async def test_query_column_block_stream(test_async_client: AsyncClient): - block_stream = await test_async_client.query_column_block_stream(stream_query, settings=stream_settings) - total = 0 - block_count = 0 - with block_stream: - for block in block_stream: - block_count += 1 - total += sum(block[0]) - assert total == 49995000 - assert block_count > 1 - - -# pylint: disable=duplicate-code -@pytest.mark.asyncio -async def test_query_row_block_stream(test_async_client: AsyncClient): - block_stream = await test_async_client.query_row_block_stream(stream_query, settings=stream_settings) - total = 0 - block_count = 0 - with block_stream: - for block in block_stream: - block_count += 1 - for row in block: - total += row[0] - assert total == 49995000 - assert block_count > 1 - - -@pytest.mark.asyncio -async def test_query_rows_stream(test_async_client: AsyncClient): - row_stream = await test_async_client.query_rows_stream('SELECT number FROM numbers(10000)') - total = 0 - with row_stream: - for row in row_stream: - total += row[0] - assert total == 49995000 - - -@pytest.mark.asyncio -async def test_raw_query(test_async_client: AsyncClient): - result = await test_async_client.raw_query('SELECT 42') - assert result == b'42\n' - - -@pytest.mark.asyncio -async def test_raw_stream(test_async_client: AsyncClient): - stream = await test_async_client.raw_stream('SELECT 42') - result = b'' - with stream: - for chunk in stream: - result += chunk - assert result == b'42\n' - - -@pytest.mark.asyncio -async def test_query_np(test_async_client: AsyncClient): - result = await test_async_client.query_np('SELECT number FROM numbers(5)') - assert isinstance(result, np.ndarray) - assert list(result) == [[0], [1], [2], [3], [4]] - - -@pytest.mark.asyncio -async def test_query_np_stream(test_async_client: AsyncClient): - stream = await test_async_client.query_np_stream('SELECT number FROM numbers(5)') - result = np.array([]) - with stream: - for block in stream: - result = np.append(result, block) - assert list(result) == [0, 1, 2, 3, 4] - - -@pytest.mark.asyncio -async def test_query_df(test_async_client: AsyncClient): - result = await test_async_client.query_df('SELECT number FROM numbers(5)') - assert isinstance(result, pd.DataFrame) - assert list(result['number']) == [0, 1, 2, 3, 4] - - -@pytest.mark.asyncio -async def test_query_df_stream(test_async_client: AsyncClient): - stream = await test_async_client.query_df_stream('SELECT number FROM numbers(5)') - result = [] - with stream: - for block in stream: - result.append(list(block['number'])) - assert result == [[0, 1, 2, 3, 4]] - - -@pytest.mark.asyncio -async def test_create_query_context(test_async_client: AsyncClient): - query_context = test_async_client.create_query_context( - query='SELECT {k: Int32}', - parameters={'k': 42}, - column_oriented=True) - result = await test_async_client.query(context=query_context) - assert result.row_count == 1 - assert result.result_set == [[42]] - - -@pytest.mark.asyncio -async def test_query_arrow(test_async_client: AsyncClient): - if not arrow: - pytest.skip('PyArrow package not available') - result = await test_async_client.query_arrow('SELECT number FROM numbers(5)') - assert isinstance(result, arrow.Table) - assert list(result[0].to_pylist()) == [0, 1, 2, 3, 4] - - -@pytest.mark.asyncio -async def test_query_arrow_stream(test_async_client: AsyncClient): - if not arrow: - pytest.skip('PyArrow package not available') - stream = await test_async_client.query_arrow_stream('SELECT number FROM numbers(5)') - result = [] - with stream: - for block in stream: - result.append(block[0].to_pylist()) - assert result == [[0, 1, 2, 3, 4]] - - -@pytest.mark.asyncio -async def test_command(test_async_client: AsyncClient): - version = await test_async_client.command('SELECT version()') - assert int(version.split('.')[0]) >= 19 - - -@pytest.mark.asyncio -async def test_ping(test_async_client: AsyncClient): - assert await test_async_client.ping() is True - - -@pytest.mark.asyncio -async def test_insert(test_async_client: AsyncClient, table_context: Callable): - with table_context('test_async_client_insert', ['key UInt32', 'value String']) as ctx: - await test_async_client.insert(ctx.table, [[42, 'str_0'], [144, 'str_1']]) - result_set = (await test_async_client.query(f"SELECT * FROM {ctx.table} ORDER BY key ASC")).result_columns - assert result_set == [[42, 144], ['str_0', 'str_1']] - - -@pytest.mark.asyncio -async def test_insert_df(test_async_client: AsyncClient, table_context: Callable): - with table_context('test_async_client_insert_df', ['key UInt32', 'value String']) as ctx: - df = pd.DataFrame([[42, 'str_0'], [144, 'str_1']], columns=['key', 'value']) - await test_async_client.insert_df(ctx.table, df) - result_set = (await test_async_client.query(f"SELECT * FROM {ctx.table} ORDER BY key ASC")).result_columns - assert result_set == [[42, 144], ['str_0', 'str_1']] - - -@pytest.mark.asyncio -async def test_insert_arrow(test_async_client: AsyncClient, table_context: Callable): - if not arrow: - pytest.skip('PyArrow package not available') - with table_context('test_async_client_insert_arrow', ['key UInt32', 'value String']) as ctx: - data = arrow.Table.from_arrays([arrow.array([42, 144]), arrow.array(['str_0', 'str_1'])], names=['key', 'value']) - await test_async_client.insert_arrow(ctx.table, data) - result_set = (await test_async_client.query(f"SELECT * FROM {ctx.table} ORDER BY key ASC")).result_columns - assert result_set == [[42, 144], ['str_0', 'str_1']] - - -@pytest.mark.asyncio -async def test_create_insert_context(test_async_client: AsyncClient, table_context: Callable): - with table_context('test_async_client_create_insert_context', ['key UInt32', 'value String']) as ctx: - data = [[1, 'a'], [2, 'b']] - insert_context = await test_async_client.create_insert_context(table=ctx.table, data=data) - await test_async_client.insert(context=insert_context) - result = (await test_async_client.query(f'SELECT * FROM {ctx.table} ORDER BY key ASC')).result_columns - assert result == [[1, 2], ['a', 'b']] - - -@pytest.mark.asyncio -async def test_data_insert(test_async_client: AsyncClient, table_context: Callable): - with table_context('test_async_client_data_insert', ['key UInt32', 'value String']) as ctx: - df = pd.DataFrame([[42, 'str_0'], [144, 'str_1']], columns=['key', 'value']) - insert_context = await test_async_client.create_insert_context(ctx.table, df.columns) - insert_context.data = df - await test_async_client.data_insert(insert_context) - result_set = (await test_async_client.query(f"SELECT * FROM {ctx.table} ORDER BY key ASC")).result_columns - assert result_set == [[42, 144], ['str_0', 'str_1']] - - -@pytest.mark.asyncio -async def test_raw_insert(test_async_client: AsyncClient, table_context: Callable): - with table_context('test_async_client_raw_insert', ['key UInt32', 'value String']) as ctx: - await test_async_client.raw_insert(table=ctx.table, - column_names=['key', 'value'], - insert_block='42,"foo"\n144,"bar"\n', - fmt='CSV') - result_set = (await test_async_client.query(f"SELECT * FROM {ctx.table} ORDER BY key ASC")).result_columns - assert result_set == [[42, 144], ['foo', 'bar']] - - -@pytest.mark.asyncio -async def test_query_df_arrow(test_async_client: AsyncClient, table_context: Callable): - if not arrow: - pytest.skip("PyArrow package not available") - - data = [[78, pd.NA, "a"], [51, 421, "b"]] - df = pd.DataFrame(data, columns=["i64", "ni64", "str"]) - - with table_context( - "df_pyarrow_query_test", - [ - "i64 Int64", - "ni64 Nullable(Int64)", - "str String", - ], - ) as ctx: - if IS_PANDAS_2: - df = df.convert_dtypes(dtype_backend="pyarrow") - await test_async_client.insert_df(ctx.table, df) - result_df = await test_async_client.query_df_arrow(f"SELECT * FROM {ctx.table} ORDER BY i64") - for dt in list(result_df.dtypes): - assert isinstance(dt, pd.ArrowDtype) - else: - with pytest.raises(ProgrammingError): - result_df = await test_async_client.query_df_arrow(f"SELECT * FROM {ctx.table}") - - -@pytest.mark.asyncio -async def test_insert_df_arrow(test_async_client: AsyncClient, table_context: Callable): - if not arrow: - pytest.skip("PyArrow package not available") - - data = [[78, pd.NA, "a"], [51, 421, "b"]] - df = pd.DataFrame(data, columns=["i64", "ni64", "str"]) - - with table_context( - "df_pyarrow_insert_test", - [ - "i64 Int64", - "ni64 Nullable(Int64)", - "str String", - ], - ) as ctx: - if IS_PANDAS_2: - df = df.convert_dtypes(dtype_backend="pyarrow") - await test_async_client.insert_df_arrow(ctx.table, df) - res_df = await test_async_client.query(f"SELECT * from {ctx.table} ORDER BY i64") - assert res_df.result_rows == [(51, 421, "b"), (78, None, "a")] - else: - with pytest.raises(ProgrammingError, match="pandas 2.x"): - await test_async_client.insert_df_arrow(ctx.table, df) - - with table_context( - "df_pyarrow_insert_test", - [ - "i64 Int64", - "ni64 Nullable(Int64)", - "str String", - ], - ) as ctx: - if IS_PANDAS_2: - df = pd.DataFrame(data, columns=["i64", "ni64", "str"]) - df["i64"] = df["i64"].astype(pd.ArrowDtype(arrow.int64())) - with pytest.raises(ProgrammingError, match="Non-Arrow columns found"): - await test_async_client.insert_df_arrow(ctx.table, df) - else: - with pytest.raises(ProgrammingError, match="pandas 2.x"): - await test_async_client.insert_df_arrow(ctx.table, df) diff --git a/tests/integration_tests/test_async_features.py b/tests/integration_tests/test_async_features.py new file mode 100644 index 00000000..86b8ee03 --- /dev/null +++ b/tests/integration_tests/test_async_features.py @@ -0,0 +1,243 @@ +import asyncio +import time +from typing import Callable + +import pytest + +from clickhouse_connect import get_async_client +from clickhouse_connect.driver.exceptions import DatabaseError, OperationalError, ProgrammingError +from tests.integration_tests.conftest import make_client_config + +# pylint: disable=protected-access + + +@pytest.mark.asyncio +async def test_concurrent_queries(test_config): + """Verify multiple queries execute concurrently (not sequentially).""" + async with await get_async_client(**make_client_config(test_config, autogenerate_session_id=False)) as client: + queries = [client.query(f"SELECT {i}, sleep(0.1)") for i in range(10)] + + start = time.time() + results = await asyncio.gather(*queries) + elapsed = time.time() - start + + assert elapsed < 0.5, f"Took {elapsed}s, queries appear to run sequentially" + assert len(results) == 10 + + for i, result in enumerate(results): + assert result.row_count == 1 + first_row = result.result_rows[0] + assert first_row[0] == i + + +@pytest.mark.asyncio +async def test_stream_cancellation(test_config): + """Test that early exit from async iteration doesn't leak resources.""" + async with await get_async_client(**make_client_config(test_config)) as client: + stream = await client.query_rows_stream("SELECT number FROM numbers(100000)", settings={"max_block_size": 1000}) + + count = 0 + async with stream: + async for _ in stream: + count += 1 + if count >= 10: + break + + assert count == 10 + + result = await client.query("SELECT 1") + assert result.result_rows[0][0] == 1 + + +@pytest.mark.asyncio +async def test_concurrent_streams(test_config): + """Verify multiple streams can run in parallel.""" + async with await get_async_client(**make_client_config(test_config, autogenerate_session_id=False)) as client: + + async def consume_stream(stream_id: int): + stream = await client.query_rows_stream( + f"SELECT number FROM numbers(1000) WHERE number % 3 = {stream_id}", settings={"max_block_size": 100} + ) + total = 0 + async with stream: + async for row in stream: + total += row[0] + return total + + start = time.time() + results = await asyncio.gather(consume_stream(0), consume_stream(1), consume_stream(2)) + elapsed = time.time() - start + + assert len(results) == 3 + assert all(r > 0 for r in results) + assert elapsed < 5.0 + + +@pytest.mark.asyncio +async def test_context_manager_cleanup(test_config): + """Test proper resource cleanup on context manager exit.""" + client = await get_async_client(**make_client_config(test_config)) + + assert client._initialized is True + assert client._session is not None + + async with client: + result = await client.query("SELECT 1") + assert result.result_rows[0][0] == 1 + + assert client._session is None or client._session.closed + + with pytest.raises((RuntimeError, OperationalError)): + await client.query("SELECT 1") + + +@pytest.mark.asyncio +async def test_session_concurrency_protection(test_config): + """Test that concurrent queries in the same session are blocked.""" + async with await get_async_client(**make_client_config(test_config, session_id="test_concurrent_session")) as client: + + async def long_query(): + return await client.query("SELECT sleep(0.5), 1") + + async def quick_query(): + await asyncio.sleep(0.1) + return await client.query("SELECT 1") + + # This can raise either: + # - ProgrammingError (client-side detection - best effort) + # - DatabaseError code 373 (server-side SESSION_IS_LOCKED - when client check is too slow) + # Both are valid ways to detect the concurrent session violation. + with pytest.raises((ProgrammingError, DatabaseError)) as exc_info: + await asyncio.gather(long_query(), quick_query()) + + # Verify it's the right kind of error (concurrent session access) + error_msg = str(exc_info.value).lower() + assert ("concurrent" in error_msg or "session" in error_msg or "locked" in error_msg), \ + f"Expected session concurrency error, got: {exc_info.value}" + + +@pytest.mark.asyncio +async def test_timeout_handling(test_config): + """Test that async timeout exceptions propagate correctly.""" + async with await get_async_client( + **make_client_config(test_config), + send_receive_timeout=1, # 1 second timeout + autogenerate_session_id=False, # No session to avoid session locking after timeout + ) as client: + # This query should timeout (sleep 2 seconds with 1 second timeout) + with pytest.raises((asyncio.TimeoutError, OperationalError)): + await client.query("SELECT sleep(2)") + + # Client should remain functional after timeout + result = await client.query("SELECT 1") + assert result.result_rows[0][0] == 1 + + +@pytest.mark.asyncio +async def test_connection_pool_reuse(test_config): + """Verify connection pooling works correctly under load.""" + async with await get_async_client( + **make_client_config(test_config), + connector_limit=10, # Limit pool size + connector_limit_per_host=5, + autogenerate_session_id=False, + ) as client: + # Run more queries in parallel than pool size + queries = [client.query(f"SELECT {i}") for i in range(50)] + + start = time.time() + results = await asyncio.gather(*queries) + elapsed = time.time() - start + + assert len(results) == 50 + for i, result in enumerate(results): + assert result.result_rows[0][0] == i + + assert elapsed < 10.0 + + +@pytest.mark.asyncio +async def test_concurrent_inserts(test_config, table_context: Callable): + """Test multiple inserts can run in parallel.""" + with table_context("test_concurrent_inserts", ["id UInt32", "value String"]) as ctx: + async with await get_async_client(**make_client_config(test_config, autogenerate_session_id=False)) as client: + + async def insert_batch(start_id: int, count: int): + data = [[start_id + i, f"value_{start_id + i}"] for i in range(count)] + await client.insert(ctx.table, data, settings={"wait_for_async_insert": 1}) + + await asyncio.gather( + insert_batch(0, 10), + insert_batch(100, 10), + insert_batch(200, 10), + insert_batch(300, 10), + insert_batch(400, 10), + ) + + result = await client.query(f"SELECT count() FROM {ctx.table}") + assert result.result_rows[0][0] == 50 + + +@pytest.mark.asyncio +async def test_error_isolation(test_config): + """Test that one failing query doesn't break other concurrent queries.""" + async with await get_async_client(**make_client_config(test_config, autogenerate_session_id=False)) as client: + + async def good_query(n: int): + return await client.query(f"SELECT {n}") + + async def bad_query(): + return await client.query("SELECT invalid_syntax_here!!!") + + results = await asyncio.gather(good_query(1), bad_query(), good_query(2), bad_query(), good_query(3), return_exceptions=True) + + assert results[0].result_rows[0][0] == 1 + assert results[2].result_rows[0][0] == 2 + assert results[4].result_rows[0][0] == 3 + + assert isinstance(results[1], Exception) + assert isinstance(results[3], Exception) + + +@pytest.mark.asyncio +async def test_streaming_early_termination(test_config): + """Verify streaming can be terminated early without issues.""" + async with await get_async_client(**make_client_config(test_config, autogenerate_session_id=False)) as client: + stream = await client.query_rows_stream("SELECT number, repeat('x', 10000) FROM numbers(100000)", settings={"max_block_size": 1000}) + + count = 0 + async with stream: + async for _ in stream: + count += 1 + if count >= 1000: + break # Early termination + + assert count == 1000 + + # Client should still be functional after early termination + result = await client.query("SELECT 1") + assert result.result_rows[0][0] == 1 + + stream2 = await client.query_rows_stream("SELECT number FROM numbers(100)", settings={"max_block_size": 10}) + + count2 = 0 + async with stream2: + async for _ in stream2: + count2 += 1 + + assert count2 == 100 + + +@pytest.mark.asyncio +async def test_regular_query_streams_then_materializes(test_config): + """Verify regular query() uses streaming internally but materializes result.""" + async with await get_async_client(**make_client_config(test_config)) as client: + result = await client.query("SELECT number FROM numbers(10000)") + + assert len(result.result_rows) == 10000 + assert result.result_rows[0][0] == 0 + assert result.result_rows[-1][0] == 9999 + + expected_numbers = list(range(10000)) + actual_numbers = [row[0] for row in result.result_rows] + assert actual_numbers == expected_numbers diff --git a/tests/integration_tests/test_client.py b/tests/integration_tests/test_client.py index eb24a6b5..df3e5e15 100644 --- a/tests/integration_tests/test_client.py +++ b/tests/integration_tests/test_client.py @@ -29,75 +29,96 @@ def _is_valid_uuid_v4(id_string: str) -> bool: return False -def test_ping(test_client: Client): - assert test_client.ping() is True +def test_ping(param_client, call): + assert call(param_client.ping) is True -def test_query(test_client: Client): - result = test_client.query('SELECT * FROM system.tables') +def test_query(param_client, call): + result = call(param_client.query, 'SELECT * FROM system.tables') assert len(result.result_set) > 0 assert result.row_count > 0 assert result.first_item == next(result.named_results()) -def test_command(test_client: Client): - version = test_client.command('SELECT version()') +def test_command(param_client, call): + version = call(param_client.command, 'SELECT version()') assert int(version.split('.')[0]) >= 19 -def test_client_name(test_client: Client): - user_agent = test_client.headers['User-Agent'] - assert 'test' in user_agent +def test_client_name(param_client, client_mode): + user_agent = param_client.headers['User-Agent'] + assert 'test' in user_agent or 'param' in user_agent assert 'py/' in user_agent + assert f"mode:{client_mode}" in user_agent -def test_transport_settings(test_client: Client): - result = test_client.query('SELECT name,database FROM system.tables', +def test_transport_settings(param_client, call): + result = call(param_client.query, 'SELECT name,database FROM system.tables', transport_settings={'X-Workload': 'ONLINE'}) assert result.column_names == ('name', 'database') assert len(result.result_set) > 0 -def test_none_database(test_client: Client): - old_db = test_client.database - test_db = test_client.command('select currentDatabase()') +def test_none_database(param_client, call): + old_db = param_client.database + test_db = call(param_client.command, 'select currentDatabase()') assert test_db == old_db try: - test_client.database = None - test_client.query('SELECT * FROM system.tables') - test_db = test_client.command('select currentDatabase()') + param_client.database = None + call(param_client.query, 'SELECT * FROM system.tables') + test_db = call(param_client.command, 'select currentDatabase()') assert test_db == 'default' - test_client.database = old_db - test_db = test_client.command('select currentDatabase()') + param_client.database = old_db + test_db = call(param_client.command, 'select currentDatabase()') assert test_db == old_db finally: - test_client.database = old_db + param_client.database = old_db -def test_session_params(test_config: TestConfig): +def test_session_params(test_config: TestConfig, client_factory, call): session_id = 'TEST_SESSION_ID_' + test_config.test_database - client = create_client( - session_id=session_id, - host=test_config.host, - port=test_config.port, - username=test_config.username, - password=test_config.password) - result = client.query('SELECT number FROM system.numbers LIMIT 5', - settings={'query_id': 'test_session_params'}).result_set + client = client_factory(session_id=session_id) + result = call(client.query, 'SELECT number FROM system.numbers LIMIT 5', + settings={'query_id': 'test_session_params'}).result_set assert len(result) == 5 if client.min_version('21'): if test_config.host != 'localhost': return # By default, the session log isn't enabled, so we only validate in environments we control - sleep(10) # Allow the log entries to flush to tables - result = client.query( - f"SELECT session_id, user FROM system.session_log WHERE session_id = '{session_id}' AND " + - 'event_time > now() - 30').result_set - assert result[0] == (session_id, test_config.username) - result = client.query( - "SELECT query_id, user FROM system.query_log WHERE query_id = 'test_session_params' AND " + - 'event_time > now() - 30').result_set - assert result[0] == ('test_session_params', test_config.username) + + def check_session_in_log(): + max_retries = 100 + for _ in range(max_retries): + result = call(client.query, + f"SELECT session_id, user FROM system.session_log WHERE session_id = '{session_id}' AND " + + 'event_time > now() - 30').result_set + + if len(result) > 0: + assert result[0] == (session_id, test_config.username) + return + + sleep(0.1) + + pytest.fail(f"session_id '{session_id}' did not appear in system.session_log after {max_retries * 0.1}s") + + def check_query_in_log(): + max_retries = 100 + for _ in range(max_retries): + result = call(client.query, + "SELECT query_id, user FROM system.query_log WHERE query_id = 'test_session_params' AND " + + 'event_time > now() - 30').result_set + + if len(result) > 0: + assert result[0] == ('test_session_params', test_config.username) + return + + sleep(0.1) + + pytest.fail(f"query_id 'test_session_params' did not appear in system.query_log after {max_retries * 0.1}s") + + # Check both logs with smart retry logic + check_session_in_log() + check_query_in_log() def test_dsn_config(test_config: TestConfig): @@ -105,31 +126,33 @@ def test_dsn_config(test_config: TestConfig): dsn = (f'clickhousedb://{test_config.username}:{test_config.password}@{test_config.host}:{test_config.port}' + f'/{test_config.test_database}?session_id={session_id}&show_clickhouse_errors=false') client = create_client(dsn=dsn) - assert client.get_client_setting('session_id') == session_id - count = client.command('SELECT count() from system.tables') - assert client.database == test_config.test_database - assert count > 0 try: - client.query('SELECT nothing') - except DatabaseError as ex: - assert 'returned an error' in str(ex) - client.close() + assert client.get_client_setting('session_id') == session_id + count = client.command('SELECT count() from system.tables') + assert client.database == test_config.test_database + assert count > 0 + try: + client.query('SELECT nothing') + except DatabaseError as ex: + assert 'returned an error' in str(ex) + finally: + client.close() -def test_no_columns_and_types_when_no_results(test_client: Client): +def test_no_columns_and_types_when_no_results(param_client, call): """ In case of no results, the column names and types are not returned when FORMAT Native is set. This may cause a lot of confusion. Read more: https://github.com/ClickHouse/clickhouse-connect/issues/257 """ - result = test_client.query('SELECT name, database, NOW() as dt FROM system.tables WHERE FALSE') + result = call(param_client.query, 'SELECT name, database, NOW() as dt FROM system.tables WHERE FALSE') assert result.column_names == () assert result.column_types == () assert result.result_set == [] -def test_get_columns_only(test_client: Client): - result = test_client.query('SELECT name, database, NOW() as dt FROM system.tables LIMIT 0') +def test_get_columns_only(param_client, call): + result = call(param_client.query, 'SELECT name, database, NOW() as dt FROM system.tables LIMIT 0') assert result.column_names == ('name', 'database', 'dt') assert len(result.column_types) == 3 assert isinstance(result.column_types[0], datatypes.string.String) @@ -137,28 +160,28 @@ def test_get_columns_only(test_client: Client): assert isinstance(result.column_types[2], datatypes.temporal.DateTime) assert len(result.result_set) == 0 - test_client.query('CREATE TABLE IF NOT EXISTS test_zero_insert (v Int8) ENGINE MergeTree() ORDER BY tuple()') - test_client.query('INSERT INTO test_zero_insert SELECT 1 LIMIT 0') + call(param_client.query, 'CREATE TABLE IF NOT EXISTS test_zero_insert (v Int8) ENGINE MergeTree() ORDER BY tuple()') + call(param_client.query, 'INSERT INTO test_zero_insert SELECT 1 LIMIT 0') -def test_no_limit(test_client: Client): - old_limit = test_client.query_limit - test_client.limit = 0 - result = test_client.query('SELECT name FROM system.databases') +def test_no_limit(param_client, call): + old_limit = param_client.query_limit + param_client.limit = 0 + result = call(param_client.query, 'SELECT name FROM system.databases') assert len(result.result_set) > 0 - test_client.limit = old_limit + param_client.limit = old_limit -def test_multiline_query(test_client: Client): - result = test_client.query(""" +def test_multiline_query(param_client, call): + result = call(param_client.query, """ SELECT * FROM system.tables """) assert len(result.result_set) > 0 -def test_query_with_inline_comment(test_client: Client): - result = test_client.query(""" +def test_query_with_inline_comment(param_client, call): + result = call(param_client.query, """ SELECT * -- This is just a comment FROM system.tables LIMIT 77 @@ -167,8 +190,8 @@ def test_query_with_inline_comment(test_client: Client): assert len(result.result_set) > 0 -def test_query_with_comment(test_client: Client): - result = test_client.query(""" +def test_query_with_comment(param_client, call): + result = call(param_client.query, """ SELECT * /* This is: a multiline comment */ @@ -177,14 +200,14 @@ def test_query_with_comment(test_client: Client): assert len(result.result_set) > 0 -def test_insert_csv_format(test_client: Client, test_table_engine: str): - test_client.command('DROP TABLE IF EXISTS test_csv') - test_client.command( +def test_insert_csv_format(param_client, call, test_table_engine: str): + call(param_client.command, 'DROP TABLE IF EXISTS test_csv') + call(param_client.command, 'CREATE TABLE test_csv ("key" String, "val1" Int32, "val2" Int32) ' + f'ENGINE {test_table_engine} ORDER BY tuple()') sql = f'INSERT INTO test_csv ("key", "val1", "val2") FORMAT CSV {CSV_CONTENT}' - test_client.command(sql) - result = test_client.query('SELECT * from test_csv') + call(param_client.command, sql) + result = call(param_client.query, 'SELECT * from test_csv') def compare_rows(row_1, row_2): return all(c1 == c2 for c1, c2 in zip(row_1, row_2)) @@ -194,35 +217,35 @@ def compare_rows(row_1, row_2): assert compare_rows(result.result_set[4], ['hij', 1, 0]) -def test_non_latin_query(test_client: Client): - result = test_client.query("SELECT database, name FROM system.tables WHERE engine_full IN ('空')") +def test_non_latin_query(param_client, call): + result = call(param_client.query, "SELECT database, name FROM system.tables WHERE engine_full IN ('空')") assert len(result.result_set) == 0 -def test_error_decode(test_client: Client): +def test_error_decode(param_client, call): try: - test_client.query("SELECT database, name FROM system.tables WHERE has_own_data = '空'") + call(param_client.query, "SELECT database, name FROM system.tables WHERE has_own_data = '空'") except DatabaseError as ex: assert '空' in str(ex) -def test_command_as_query(test_client: Client): +def test_command_as_query(param_client, call): # Test that non-SELECT and non-INSERT statements are treated as commands and # just return the QueryResult metadata - result = test_client.query("SET count_distinct_implementation = 'uniq'") + result = call(param_client.query, "SET count_distinct_implementation = 'uniq'") assert 'query_id' in result.first_item -def test_show_create(test_client: Client): - if not test_client.min_version('21'): - pytest.skip(f'Not supported server version {test_client.server_version}') - result = test_client.query('SHOW CREATE TABLE system.tables') +def test_show_create(param_client, call): + if not param_client.min_version('21'): + pytest.skip(f'Not supported server version {param_client.server_version}') + result = call(param_client.query, 'SHOW CREATE TABLE system.tables') result.close() assert 'statement' in result.column_names -def test_empty_result(test_client: Client): - assert len(test_client.query("SELECT * FROM system.tables WHERE name = '_NOT_A THING'").result_rows) == 0 +def test_empty_result(param_client, call): + assert len(call(param_client.query, "SELECT * FROM system.tables WHERE name = '_NOT_A THING'").result_rows) == 0 def test_temporary_tables(test_client: Client, test_config: TestConfig): @@ -250,29 +273,29 @@ def test_temporary_tables(test_client: Client, test_config: TestConfig): test_client.command('DROP TABLE IF EXISTS temp_test_table', settings=session_settings) -def test_str_as_bytes(test_client: Client, table_context: Callable): +def test_str_as_bytes(param_client, call, table_context: Callable): with table_context('test_insert_bytes', ['key UInt32', 'byte_str String', 'n_byte_str Nullable(String)']): - test_client.insert('test_insert_bytes', [[0, 'str_0', 'n_str_0'], [1, 'str_1', 'n_str_0']]) - test_client.insert('test_insert_bytes', [[2, 'str_2'.encode('ascii'), 'n_str_2'.encode()], + call(param_client.insert, 'test_insert_bytes', [[0, 'str_0', 'n_str_0'], [1, 'str_1', 'n_str_0']]) + call(param_client.insert, 'test_insert_bytes', [[2, 'str_2'.encode('ascii'), 'n_str_2'.encode()], [3, b'str_3', b'str_3'], [4, bytearray([5, 120, 24]), bytes([16, 48, 52])], [5, b'', None] ]) - result_set = test_client.query('SELECT * FROM test_insert_bytes ORDER BY key').result_columns + result_set = call(param_client.query, 'SELECT * FROM test_insert_bytes ORDER BY key').result_columns assert result_set[1][0] == 'str_0' assert result_set[1][3] == 'str_3' assert result_set[2][5] is None assert result_set[1][4].encode() == b'\x05\x78\x18' - result_set = test_client.query('SELECT * FROM test_insert_bytes ORDER BY key', + result_set = call(param_client.query, 'SELECT * FROM test_insert_bytes ORDER BY key', query_formats={'String': 'bytes'}).result_columns assert result_set[1][0] == b'str_0' assert result_set[1][4] == b'\x05\x78\x18' assert result_set[2][4] == b'\x10\x30\x34' -def test_embedded_binary(test_client: Client): +def test_embedded_binary(param_client, call): binary_params = {'$xx$': 'col1,col2\n100,700'.encode()} - result = test_client.raw_query( + result = call(param_client.raw_query, 'SELECT col2, col1 FROM format(CSVWithNames, $xx$)', parameters=binary_params) assert result == b'700\t100\n' @@ -280,83 +303,45 @@ def test_embedded_binary(test_client: Client): with open(movies_file, 'rb') as f: # read bytes data = f.read() binary_params = {'$parquet$': data} - result = test_client.query( + result = call(param_client.query, 'SELECT movie, rating FROM format(Parquet, $parquet$) ORDER BY movie', parameters=binary_params) assert result.first_item['movie'] == '12 Angry Men' binary_params = {'$mult$': 'foobar'.encode()} - result = test_client.query("SELECT $mult$ as m1, $mult$ as m2 WHERE m1 = 'foobar'", parameters=binary_params) + result = call(param_client.query, "SELECT $mult$ as m1, $mult$ as m2 WHERE m1 = 'foobar'", parameters=binary_params) assert result.first_item['m2'] == 'foobar' -def test_column_rename_setting_none(test_config: TestConfig): +def test_column_rename_setting_none(client_factory, call): sql = "SELECT 1 as `a.b.c d_e`" - session_id = "TEST_SESSION_ID_" + test_config.test_database - client = create_client( - session_id=session_id, - host=test_config.host, - port=test_config.port, - username=test_config.username, - password=test_config.password, - ) - names = client.query( - sql, - ).column_names - client.close() + client = client_factory() + names = call(client.query, sql).column_names assert names[0] == "a.b.c d_e" -def test_column_rename_limit_0_path(test_config: TestConfig): +def test_column_rename_limit_0_path(client_factory, call): + """Test column renaming with LIMIT 0 query (no data returned).""" sql = "SELECT 1 as `a.b.c d_e` LIMIT 0" - session_id = "TEST_SESSION_ID_" + test_config.test_database - client = create_client( - session_id=session_id, - host=test_config.host, - port=test_config.port, - username=test_config.username, - password=test_config.password, - rename_response_column="to_camelcase_without_prefix", - ) - names = client.query( - sql, - ).column_names - client.close() + client = client_factory(rename_response_column="to_camelcase_without_prefix") + names = call(client.query, sql).column_names assert names[0] == "cDE" -def test_column_rename_data_path(test_config: TestConfig): +def test_column_rename_data_path(client_factory, call): + """Test column renaming with data returned.""" sql = "SELECT 1 as `a.b.c d_e`" - session_id = "TEST_SESSION_ID_" + test_config.test_database - client = create_client( - session_id=session_id, - host=test_config.host, - port=test_config.port, - username=test_config.username, - password=test_config.password, - rename_response_column="to_camelcase_without_prefix", - ) - names = client.query( - sql, - ).column_names - client.close() + client = client_factory(rename_response_column="to_camelcase_without_prefix") + names = call(client.query, sql).column_names assert names[0] == "cDE" -def test_column_rename_with_bad_option(test_config: TestConfig): - session_id = "TEST_SESSION_ID_" + test_config.test_database - +def test_column_rename_with_bad_option(client_factory): + """Test that invalid rename option raises ValueError.""" with pytest.raises(ValueError, match="Invalid option"): - create_client( - session_id=session_id, - host=test_config.host, - port=test_config.port, - username=test_config.username, - password=test_config.password, - rename_response_column="not_an_option", - ) + client_factory(rename_response_column="not_an_option") -def test_role_setting_works(test_client: Client, test_config: TestConfig): +def test_role_setting_works(param_client: Client, test_config: TestConfig, client_factory: Callable, call): if test_config.cloud: pytest.skip("Skipping role test in cloud mode - cannot create custom users") @@ -364,82 +349,73 @@ def test_role_setting_works(test_client: Client, test_config: TestConfig): user_limited = 'limit_rows_user' user_password = 'R7m!pZt9qL#x' - test_client.command(f'CREATE ROLE IF NOT EXISTS {role_limited}') - test_client.command(f'CREATE USER IF NOT EXISTS {user_limited} IDENTIFIED BY \'{user_password}\'') - test_client.command(f'GRANT SELECT ON system.numbers TO {user_limited}') - test_client.command(f'GRANT {role_limited} TO {user_limited}') - test_client.command(f'SET DEFAULT ROLE NONE TO {user_limited}') + call(param_client.command, f'CREATE ROLE IF NOT EXISTS {role_limited}') + call(param_client.command, f'CREATE USER IF NOT EXISTS {user_limited} IDENTIFIED BY \'{user_password}\'') + call(param_client.command, f'GRANT SELECT ON system.numbers TO {user_limited}') + call(param_client.command, f'GRANT {role_limited} TO {user_limited}') + call(param_client.command, f'SET DEFAULT ROLE NONE TO {user_limited}') - client = create_client( - host=test_client.server_host_name, - port=test_client.url.rsplit(':', 1)[-1].split('/')[0], + client = client_factory( + host=test_config.host, + port=test_config.port, username=user_limited, password=user_password, ) # the default should not have the role - res = client.query('SELECT currentRoles()') + res = call(client.query, 'SELECT currentRoles()') assert res.result_rows == [([],)] # passing it as a per-query setting should work - res = client.query('SELECT currentRoles()', settings={'role': role_limited}) + res = call(client.query, 'SELECT currentRoles()', settings={'role': role_limited}) assert res.result_rows == [([role_limited],)] # passing it as a per-client setting should work - role_client = create_client( - host=test_client.server_host_name, - port=test_client.url.rsplit(':', 1)[-1].split('/')[0], + role_client = client_factory( + host=test_config.host, + port=test_config.port, username=user_limited, password=user_password, settings={'role': role_limited}, ) - res = role_client.query('SELECT currentRoles()') + res = call(role_client.query, 'SELECT currentRoles()') assert res.result_rows == [([role_limited],)] -def test_query_id_autogeneration(test_client: Client, test_table_engine: str): +def test_query_id_autogeneration(param_client: Client, test_table_engine: str, call): """Test that query_id is auto-generated for query(), command(), and insert() methods""" - result = test_client.query("SELECT 1") + result = call(param_client.query, "SELECT 1") assert _is_valid_uuid_v4(result.query_id) - summary = test_client.command("DROP TABLE IF EXISTS test_query_id_nonexistent") + summary = call(param_client.command, "DROP TABLE IF EXISTS test_query_id_nonexistent") assert _is_valid_uuid_v4(summary.query_id()) - test_client.command("DROP TABLE IF EXISTS test_query_id_insert") - test_client.command(f"CREATE TABLE test_query_id_insert (id UInt32) ENGINE {test_table_engine} ORDER BY id") - summary = test_client.insert("test_query_id_insert", [[1], [2], [3]], column_names=["id"]) + call(param_client.command, "DROP TABLE IF EXISTS test_query_id_insert") + call(param_client.command, f"CREATE TABLE test_query_id_insert (id UInt32) ENGINE {test_table_engine} ORDER BY id") + summary = call(param_client.insert, "test_query_id_insert", [[1], [2], [3]], column_names=["id"]) assert _is_valid_uuid_v4(summary.query_id()) - test_client.command("DROP TABLE test_query_id_insert") + call(param_client.command, "DROP TABLE test_query_id_insert") -def test_query_id_manual_override(test_client: Client): +def test_query_id_manual_override(param_client: Client, call): """Test that manually specified query_id is respected and not overwritten""" manual_query_id = "test_manual_query_id_override" - result = test_client.query("SELECT 1", settings={"query_id": manual_query_id}) + result = call(param_client.query, "SELECT 1", settings={"query_id": manual_query_id}) assert result.query_id == manual_query_id # pylint: disable=protected-access -def test_query_id_disabled(test_config: TestConfig): +def test_query_id_disabled(client_factory, call): """Test that autogenerate_query_id=False works correctly""" - client_no_autogen = create_client( - host=test_config.host, - port=test_config.port, - username=test_config.username, - password=test_config.password, - autogenerate_query_id=False, - ) - + client_no_autogen = client_factory(autogenerate_query_id=False) assert client_no_autogen._autogenerate_query_id is False # Even with autogen disabled, server generates a query_id - result = client_no_autogen.query("SELECT 1") + result = call(client_no_autogen.query, "SELECT 1") assert _is_valid_uuid_v4(result.query_id) - client_no_autogen.close() - -def test_query_id_in_query_logs(test_client: Client, test_config: TestConfig): +def test_query_id_in_query_logs(param_client: Client, test_config: TestConfig, call): """Test that query_id appears in ClickHouse's system.query_log for observability""" if test_config.cloud: pytest.skip("Skipping query_log test in cloud environment") @@ -447,7 +423,8 @@ def test_query_id_in_query_logs(test_client: Client, test_config: TestConfig): def check_in_logs(test_query_id): max_retries = 30 for _ in range(max_retries): - log_result = test_client.query( + log_result = call( + param_client.query, "SELECT query_id FROM system.query_log WHERE query_id = {query_id:String} AND event_time > now() - 30 LIMIT 1", parameters={"query_id": test_query_id} ) @@ -462,11 +439,56 @@ def check_in_logs(test_query_id): pytest.fail(f"query_id '{test_query_id}' did not appear in system.query_log after {max_retries * 0.1}s") # Manual override check - test_query_id_manual = "test_query_id_in_logs" - test_client.query("SELECT 1 as num", settings={"query_id": test_query_id_manual}) + test_query_id_manual = f"test_query_id_in_logs_{uuid.uuid4()}" + call(param_client.query, "SELECT 1 as num", settings={"query_id": test_query_id_manual}) check_in_logs(test_query_id_manual) # Autogen check - result = test_client.query("SELECT 2 as num") + result = call(param_client.query, "SELECT 2 as num") test_query_id_auto = result.query_id check_in_logs(test_query_id_auto) + + +def test_compression_enabled(client_factory, call, table_context): + """Test that compression works when enabled.""" + client = client_factory(compress=True) + + assert client.compression is not None + assert client.write_compression is not None + + with table_context("test_compression", ["id", "data"], ["UInt32", "String"]): + data = [[i, f"data_{i}"] for i in range(100)] + call(client.insert, "test_compression", data) + + result = call(client.query, "SELECT COUNT(*) FROM test_compression") + assert result.result_rows[0][0] == 100 + + +def test_compression_disabled(client_factory, call, table_context): + """Test that compression can be explicitly disabled.""" + client = client_factory(compress=False) + + assert client.compression is None + assert client.write_compression is None + + with table_context("test_no_compression", ["id", "data"], ["UInt32", "String"]): + data = [[i, f"data_{i}"] for i in range(100)] + call(client.insert, "test_no_compression", data) + + result = call(client.query, "SELECT COUNT(*) FROM test_no_compression") + assert result.result_rows[0][0] == 100 + + +def test_compression_gzip(client_factory, call, table_context): + """Test that gzip compression works.""" + client = client_factory(compress="gzip") + + assert client.compression == "gzip" + assert client.write_compression == "gzip" + + with table_context("test_gzip", ["id", "data"], ["UInt32", "String"]): + data = [[i, f"data_{i}" * 10] for i in range(50)] + call(client.insert, "test_gzip", data) + + result = call(client.query, "SELECT COUNT(*) FROM test_gzip") + assert result.result_rows[0][0] == 50 diff --git a/tests/integration_tests/test_contexts.py b/tests/integration_tests/test_contexts.py index 2accb436..c25332e0 100644 --- a/tests/integration_tests/test_contexts.py +++ b/tests/integration_tests/test_contexts.py @@ -4,39 +4,39 @@ from clickhouse_connect.driver import Client -def test_contexts(test_client: Client, table_context: Callable): +def test_contexts(param_client: Client, call, table_context: Callable): with table_context('test_contexts', ['key Int32', 'value1 String', 'value2 String']) as ctx: data = [[1, 'v1', 'v2'], [2, 'v3', 'v4']] - insert_context = test_client.create_insert_context(table=ctx.table, data=data) - test_client.insert(context=insert_context) - query_context = test_client.create_query_context( + insert_context = call(param_client.create_insert_context, table=ctx.table, data=data) + call(param_client.insert, context=insert_context) + query_context = param_client.create_query_context( query=f'SELECT value1, value2 FROM {ctx.table} WHERE key = {{k:Int32}}', parameters={'k': 2}, column_oriented=True) - result = test_client.query(context=query_context) + result = call(param_client.query, context=query_context) assert result.result_set[1][0] == 'v4' query_context.set_parameter('k', 1) - result = test_client.query(context=query_context) + result = call(param_client.query, context=query_context) assert result.row_count == 1 assert result.result_set[1][0] data = [[1, 'v5', 'v6'], [2, 'v7', 'v8']] - test_client.insert(data=data, context=insert_context) - result = test_client.query(context=query_context) + call(param_client.insert, data=data, context=insert_context) + result = call(param_client.query, context=query_context) assert result.row_count == 2 insert_context.data = [[5, 'v5', 'v6'], [7, 'v7', 'v8']] - test_client.insert(context=insert_context) - assert test_client.command(f'SELECT count() FROM {ctx.table}') == 6 + call(param_client.insert, context=insert_context) + assert call(param_client.command, f'SELECT count() FROM {ctx.table}') == 6 -def test_insert_context_data_cleared_on_failure(test_client: Client, table_context: Callable): +def test_insert_context_data_cleared_on_failure(param_client: Client, call, table_context: Callable): with table_context('test_contexts', ['key Int32', 'value1 String', 'value2 String']) as ctx: data = [[1, "v1", "v2"], [2, "v3", "v4"]] - insert_context = test_client.create_insert_context(table=ctx.table, data=data) + insert_context = call(param_client.create_insert_context, table=ctx.table, data=data) insert_context.table = f"{ctx.table}__does_not_exist" with pytest.raises(Exception): - test_client.insert(context=insert_context) + call(param_client.insert, context=insert_context) assert insert_context.data is None diff --git a/tests/integration_tests/test_dynamic.py b/tests/integration_tests/test_dynamic.py index e3d65880..7c34f6a5 100644 --- a/tests/integration_tests/test_dynamic.py +++ b/tests/integration_tests/test_dynamic.py @@ -14,18 +14,18 @@ from tests.integration_tests.conftest import TestConfig -def type_available(test_client: Client, data_type: str): - if test_client.get_client_setting(f'allow_experimental_{data_type}_type') is None: +def type_available(param_client: Client, data_type: str): + if param_client.get_client_setting(f'allow_experimental_{data_type}_type') is None: return - setting_def = test_client.server_settings.get(f'allow_experimental_{data_type}_type', None) + setting_def = param_client.server_settings.get(f'allow_experimental_{data_type}_type', None) if setting_def is not None and setting_def.value == '1': return - pytest.skip(f'New {data_type.upper()} type not available in this version: {test_client.server_version}') + pytest.skip(f'New {data_type.upper()} type not available in this version: {param_client.server_version}') -def test_variant(test_client: Client, table_context: Callable): - type_available(test_client, 'variant') +def test_variant(param_client: Client, call, table_context: Callable): + type_available(param_client, 'variant') with table_context('basic_variants', [ 'key Int32', 'v1 Variant(UInt64, String, Array(UInt64), UUID)', @@ -35,8 +35,8 @@ def test_variant(test_client: Client, table_context: Callable): [3, UUID('bef56f14-0870-4f82-a35e-9a47eff45a5b'), decimal.Decimal('777.25')], [4, [120, 250], IPv4Address('243.12.55.44')] ] - test_client.insert('basic_variants', data) - result = test_client.query('SELECT * FROM basic_variants ORDER BY key').result_set + call(param_client.insert, 'basic_variants', data) + result = call(param_client.query, 'SELECT * FROM basic_variants ORDER BY key').result_set assert result[0][1] == 58322 assert result[1][1] == 'a string' assert result[2][1] == UUID('bef56f14-0870-4f82-a35e-9a47eff45a5b') @@ -45,8 +45,8 @@ def test_variant(test_client: Client, table_context: Callable): assert result[3][2] == IPv4Address('243.12.55.44') -def test_nested_variant(test_client: Client, table_context: Callable): - type_available(test_client, 'variant') +def test_nested_variant(param_client: Client, call, table_context: Callable): + type_available(param_client, 'variant') with table_context('nested_variants', [ 'key Int32', 'm1 Map(String, Variant(String, UInt128, Bool))', @@ -64,8 +64,8 @@ def test_nested_variant(test_client: Client, table_context: Callable): (), ] ] - test_client.insert('nested_variants', data) - result = test_client.query('SELECT * FROM nested_variants ORDER BY key').result_set + call(param_client.insert, 'nested_variants', data) + result = call(param_client.query, 'SELECT * FROM nested_variants ORDER BY key').result_set assert result[0][1]['k1'] == 'string1' assert result[0][1]['k2'] == 34782477743 assert result[0][2] == (-40, True) @@ -73,117 +73,117 @@ def test_nested_variant(test_client: Client, table_context: Callable): assert result[1][1]['k3'] == 100 -def test_variant_bool_int_ordering(test_client: Client, table_context: Callable): - type_available(test_client, 'variant') +def test_variant_bool_int_ordering(param_client: Client, call, table_context: Callable): + type_available(param_client, 'variant') with table_context('variant_bool_int', [ 'key Int32', 'v1 Variant(Bool, Int32)']): - data = [[1, True], [2, 42], [3, False], [4, -7]] - test_client.insert('variant_bool_int', data) - result = test_client.query('SELECT * FROM variant_bool_int ORDER BY key').result_set + data = [[1, True], [2, 57], [3, False], [4, -7]] + call(param_client.insert, 'variant_bool_int', data) + result = call(param_client.query, 'SELECT * FROM variant_bool_int ORDER BY key').result_set assert result[0][1] is True - assert result[1][1] == 42 + assert result[1][1] == 57 assert result[2][1] is False assert result[3][1] == -7 -def test_variant_no_string_error(test_client: Client, table_context: Callable): - type_available(test_client, 'variant') +def test_variant_no_string_error(param_client: Client, call, table_context: Callable): + type_available(param_client, 'variant') with table_context('variant_no_string', [ 'key Int32', 'v1 Variant(Int64, Float64)']): with pytest.raises(DataError): - test_client.insert('variant_no_string', [[1, 'hello']]) + call(param_client.insert, 'variant_no_string', [[1, 'hello']]) -def test_variant_ambiguous_arrays(test_client: Client, table_context: Callable): - type_available(test_client, 'variant') +def test_variant_ambiguous_arrays(param_client: Client, call, table_context: Callable): + type_available(param_client, 'variant') with table_context('variant_arrays', [ 'key Int32', 'v1 Variant(Array(UInt32), Array(String))']): data = [[1, typed_variant([1, 2, 3], 'Array(UInt32)')], [2, typed_variant(['a', 'b'], 'Array(String)')]] - test_client.insert('variant_arrays', data) - result = test_client.query('SELECT * FROM variant_arrays ORDER BY key').result_set + call(param_client.insert, 'variant_arrays', data) + result = call(param_client.query, 'SELECT * FROM variant_arrays ORDER BY key').result_set assert result[0][1] == [1, 2, 3] assert result[1][1] == ['a', 'b'] -def test_variant_empty_array_fallback(test_client: Client, table_context: Callable): - type_available(test_client, 'variant') +def test_variant_empty_array_fallback(param_client: Client, call, table_context: Callable): + type_available(param_client, 'variant') with table_context('variant_empty_array', [ 'key Int32', 'v1 Variant(Array(UInt32), Array(String))']): data = [[1, typed_variant([], 'Array(UInt32)')]] - test_client.insert('variant_empty_array', data) - result = test_client.query('SELECT * FROM variant_empty_array ORDER BY key').result_set + call(param_client.insert, 'variant_empty_array', data) + result = call(param_client.query, 'SELECT * FROM variant_empty_array ORDER BY key').result_set assert result[0][1] == [] -def test_variant_uuid_dispatch(test_client: Client, table_context: Callable): - type_available(test_client, 'variant') +def test_variant_uuid_dispatch(param_client: Client, call, table_context: Callable): + type_available(param_client, 'variant') with table_context('variant_uuid', [ 'key Int32', 'v1 Variant(UUID, String)']): test_uuid = UUID('bef56f14-0870-4f82-a35e-9a47eff45a5b') data = [[1, test_uuid], [2, 'just a string']] - test_client.insert('variant_uuid', data) - result = test_client.query('SELECT * FROM variant_uuid ORDER BY key').result_set + call(param_client.insert, 'variant_uuid', data) + result = call(param_client.query, 'SELECT * FROM variant_uuid ORDER BY key').result_set assert result[0][1] == test_uuid assert result[1][1] == 'just a string' -def test_variant_no_implicit_coercion(test_client: Client, table_context: Callable): - type_available(test_client, 'variant') +def test_variant_no_implicit_coercion(param_client: Client, call, table_context: Callable): + type_available(param_client, 'variant') with table_context('variant_no_coerce', [ 'key Int32', 'v1 Variant(Int32, String)']): test_uuid = UUID('bef56f14-0870-4f82-a35e-9a47eff45a5b') with pytest.raises(DataError): - test_client.insert('variant_no_coerce', [[1, test_uuid]]) + call(param_client.insert, 'variant_no_coerce', [[1, test_uuid]]) -def test_variant_all_null(test_client: Client, table_context: Callable): - type_available(test_client, 'variant') +def test_variant_all_null(param_client: Client, call, table_context: Callable): + type_available(param_client, 'variant') with table_context('variant_all_null', [ 'key Int32', 'v1 Variant(Int32, String)']): data = [[1, None], [2, None], [3, None]] - test_client.insert('variant_all_null', data) - result = test_client.query('SELECT * FROM variant_all_null ORDER BY key').result_set + call(param_client.insert, 'variant_all_null', data) + result = call(param_client.query, 'SELECT * FROM variant_all_null ORDER BY key').result_set assert result[0][1] is None assert result[1][1] is None assert result[2][1] is None -def test_variant_leading_nulls(test_client: Client, table_context: Callable): - type_available(test_client, 'variant') +def test_variant_leading_nulls(param_client: Client, call, table_context: Callable): + type_available(param_client, 'variant') with table_context('variant_leading_nulls', [ 'key Int32', 'v1 Variant(Int32, String)']): - data = [[1, None], [2, None], [3, 42], [4, 'hello'], [5, None]] - test_client.insert('variant_leading_nulls', data) - result = test_client.query('SELECT * FROM variant_leading_nulls ORDER BY key').result_set + data = [[1, None], [2, None], [3, 57], [4, 'hello'], [5, None]] + call(param_client.insert, 'variant_leading_nulls', data) + result = call(param_client.query, 'SELECT * FROM variant_leading_nulls ORDER BY key').result_set assert result[0][1] is None assert result[1][1] is None - assert result[2][1] == 42 + assert result[2][1] == 57 assert result[3][1] == 'hello' assert result[4][1] is None -def test_dynamic_nested(test_client: Client, table_context: Callable): - type_available(test_client, 'dynamic') +def test_dynamic_nested(param_client: Client, call, table_context: Callable): + type_available(param_client, 'dynamic') with table_context('nested_dynamics', [ 'm2 Map(String, Dynamic)' ], order_by='()'): data = [({'k4': 'string8', 'k5': 5000},)] - test_client.insert('nested_dynamics', data) - result = test_client.query('SELECT * FROM nested_dynamics').result_set + call(param_client.insert, 'nested_dynamics', data) + result = call(param_client.query, 'SELECT * FROM nested_dynamics').result_set assert result[0][0]['k5'] == '5000' -def test_dynamic(test_client: Client, table_context: Callable): - type_available(test_client, 'dynamic') +def test_dynamic(param_client: Client, call, table_context: Callable): + type_available(param_client, 'dynamic') with table_context('basic_dynamic', [ 'key UInt64', 'v1 Dynamic', @@ -193,25 +193,25 @@ def test_dynamic(test_client: Client, table_context: Callable): [2, 'a string', 55.2], [4, [120, 250], 577.22] ] - test_client.insert('basic_dynamic', data) - result = test_client.query('SELECT * FROM basic_dynamic ORDER BY key').result_set + call(param_client.insert, 'basic_dynamic', data) + result = call(param_client.query, 'SELECT * FROM basic_dynamic ORDER BY key').result_set assert result[2][1] == 'bef56f14-0870-4f82-a35e-9a47eff45a5b' assert result[3][1] == '[120, 250]' assert result[2][2] == '777.25' -def test_dynamic_shared_variant_unsupported_types(test_client: Client, table_context: Callable): - type_available(test_client, "dynamic") +def test_dynamic_shared_variant_unsupported_types(param_client: Client, call, table_context: Callable): + type_available(param_client, "dynamic") with table_context("dynamic_shared_variant", [ "id UInt8", "d Dynamic(max_types=0)", ]): - test_client.command("INSERT INTO dynamic_shared_variant SELECT 1, toDate('2024-01-02')") - test_client.command("INSERT INTO dynamic_shared_variant SELECT 2, toDateTime('2024-01-02 03:04:05')") - test_client.command("INSERT INTO dynamic_shared_variant SELECT 3, [1, 2, 3]") - test_client.command("INSERT INTO dynamic_shared_variant SELECT 4, 'hello'") + call(param_client.command, "INSERT INTO dynamic_shared_variant SELECT 1, toDate('2024-01-02')") + call(param_client.command, "INSERT INTO dynamic_shared_variant SELECT 2, toDateTime('2024-01-02 03:04:05')") + call(param_client.command, "INSERT INTO dynamic_shared_variant SELECT 3, [1, 2, 3]") + call(param_client.command, "INSERT INTO dynamic_shared_variant SELECT 4, 'hello'") - result = test_client.query("SELECT * FROM dynamic_shared_variant ORDER BY id").result_set + result = call(param_client.query, "SELECT * FROM dynamic_shared_variant ORDER BY id").result_set assert result[0][1] == b"\x0f\x0cM" assert result[1][1] == b"\x11%}\x93e" @@ -219,8 +219,8 @@ def test_dynamic_shared_variant_unsupported_types(test_client: Client, table_con assert result[3][1] == "hello" -def test_basic_json(test_client: Client, table_context: Callable): - type_available(test_client, 'json') +def test_basic_json(param_client: Client, call, table_context: Callable): + type_available(param_client, 'json') with table_context('new_json_basic', [ 'key Int32', 'value JSON', @@ -230,12 +230,12 @@ def test_basic_json(test_client: Client, table_context: Callable): jv1 = {'key1': 337, 'value.2': 'vvvv', 'HKD@spéçiäl': 'Special K', 'blank': 'not_really_blank'} njv2 = {'nk1': -302, 'nk2': {'sub1': 372, 'sub2': 'a string'}} njv3 = {'nk1': 5832.44, 'nk2': {'sub1': 47788382, 'sub2': 'sub2val', 'sub3': 'sub3str', 'space key': 'spacey'}} - test_client.insert('new_json_basic', [ + call(param_client.insert, 'new_json_basic', [ [5, jv1, None], [20, None, njv2], [25, jv3, njv3]]) - result = test_client.query('SELECT * FROM new_json_basic ORDER BY key').result_set + result = call(param_client.query, 'SELECT * FROM new_json_basic ORDER BY key').result_set json1 = result[0][1] assert json1['HKD@spéçiäl'] == 'Special K' assert 'key3' not in json1 @@ -252,23 +252,25 @@ def test_basic_json(test_client: Client, table_context: Callable): assert null_json3['nk2']['space key'] == 'spacey' set_write_format('JSON', 'string') - test_client.insert('new_json_basic', [[999, '{"key4": 283, "value.2": "str_value"}', '{"nk1":53}']]) - result = test_client.query('SELECT value.key4, null_value.nk1 FROM new_json_basic ORDER BY key').result_set + call(param_client.insert, 'new_json_basic', [[999, '{"key4": 283, "value.2": "str_value"}', '{"nk1":53}']]) + result = call(param_client.query, 'SELECT value.key4, null_value.nk1 FROM new_json_basic ORDER BY key').result_set assert result[3][0] == 283 assert result[3][1] == 53 -def test_json_escaped_dots_roundtrip(test_client: Client, table_context: Callable): - type_available(test_client, "json") - if test_client.server_settings.get("json_type_escape_dots_in_keys") is None: +def test_json_escaped_dots_roundtrip(param_client: Client, call, table_context: Callable): + type_available(param_client, "json") + if param_client.server_settings.get("json_type_escape_dots_in_keys") is None: pytest.skip("json_type_escape_dots_in_keys setting unavailable on this server version") # with escaping enabled dots are preserved in keys - test_client.command("SET json_type_escape_dots_in_keys=1") + if not param_client.get_client_setting("session_id"): + param_client.set_client_setting("session_id", str(UUID(int=0))) + call(param_client.command, "SET json_type_escape_dots_in_keys=1") with table_context("json_dots_escape", ["value JSON"], order_by="()"): payload = {"a.b": 123, "c": {"d.e": 456}} - test_client.insert("json_dots_escape", [[payload]]) - result = test_client.query("SELECT value FROM json_dots_escape").result_set + call(param_client.insert, "json_dots_escape", [[payload]]) + result = call(param_client.query, "SELECT value FROM json_dots_escape").result_set returned = result[0][0] assert "a.b" in returned @@ -278,11 +280,11 @@ def test_json_escaped_dots_roundtrip(test_client: Client, table_context: Callabl assert returned["c"]["d.e"] == 456 # with escaping disabled dots create nested structure - test_client.command("SET json_type_escape_dots_in_keys=0") + call(param_client.command, "SET json_type_escape_dots_in_keys=0") with table_context("json_dots_no_escape", ["value JSON"], order_by="()"): payload = {"a.b": 789} - test_client.insert("json_dots_no_escape", [[payload]]) - result = test_client.query("SELECT value FROM json_dots_no_escape").result_set + call(param_client.insert, "json_dots_no_escape", [[payload]]) + result = call(param_client.query, "SELECT value FROM json_dots_no_escape").result_set returned = result[0][0] assert "a" in returned @@ -290,22 +292,22 @@ def test_json_escaped_dots_roundtrip(test_client: Client, table_context: Callabl assert returned["a"]["b"] == 789 -def test_typed_json(test_client: Client, table_context: Callable): - type_available(test_client, 'json') +def test_typed_json(param_client: Client, call, table_context: Callable): + type_available(param_client, 'json') with table_context('new_json_typed', [ 'key Int32', 'value JSON(max_dynamic_paths=150, `a.b` DateTime64(3), SKIP a.c)' ]): v1 = '{"a":{"b":"2020-10-15T10:15:44.877", "c":"skip_me"}}' - test_client.insert('new_json_typed', [[1, v1]]) - result = test_client.query('SELECT * FROM new_json_typed ORDER BY key') + call(param_client.insert, 'new_json_typed', [[1, v1]]) + result = call(param_client.query, 'SELECT * FROM new_json_typed ORDER BY key') json1 = result.result_set[0][1] assert json1['a']['b'] == datetime.datetime(2020, 10, 15, 10, 15, 44, 877000) -def test_nullable_json(test_client: Client, table_context: Callable): - if not test_client.min_version('25.2'): - pytest.skip(f'Nullable(JSON) type not available in this version: {test_client.server_version}') +def test_nullable_json(param_client: Client, call, table_context: Callable): + if not param_client.min_version('25.2'): + pytest.skip(f'Nullable(JSON) type not available in this version: {param_client.server_version}') with table_context("nullable_json", [ "key Int32", "value_1 Nullable(JSON)", @@ -314,8 +316,8 @@ def test_nullable_json(test_client: Client, table_context: Callable): ]): v1 = {"item_a": 5, "item_b": 10} - test_client.insert("nullable_json", [[1, v1, json.dumps(v1), None], [2, v1, None, None]]) - result = test_client.query('SELECT * FROM nullable_json ORDER BY key') + call(param_client.insert, "nullable_json", [[1, v1, json.dumps(v1), None], [2, v1, None, None]]) + result = call(param_client.query, 'SELECT * FROM nullable_json ORDER BY key') assert result.result_set[0][1] == v1 assert result.result_set[1][1] == v1 assert result.result_set[0][2] == v1 @@ -324,26 +326,26 @@ def test_nullable_json(test_client: Client, table_context: Callable): assert result.result_set[1][3] is None -def test_complex_json(test_client: Client, table_context: Callable): - type_available(test_client, 'json') - if not test_client.min_version('24.10'): +def test_complex_json(param_client: Client, call, table_context: Callable): + type_available(param_client, 'json') + if not param_client.min_version('24.10'): pytest.skip('Complex JSON broken before 24.10') with table_context('new_json_complex', [ 'key Int32', 'value Tuple(t JSON)' ]): data = [[100, ({'a': 'qwe123', 'b': 'main', 'c': None},)]] - test_client.insert('new_json_complex', data) - result = test_client.query('SELECT * FROM new_json_complex ORDER BY key') + call(param_client.insert, 'new_json_complex', data) + result = call(param_client.query, 'SELECT * FROM new_json_complex ORDER BY key') json1 = result.result_set[0][1] assert json1['t']['a'] == 'qwe123' -def test_json_str_time(test_client: Client, test_config: TestConfig): +def test_json_str_time(param_client: Client, call, test_config: TestConfig): - if not test_client.min_version('25.1') or test_config.cloud: + if not param_client.min_version('25.1') or test_config.cloud: pytest.skip('JSON string/numbers bug before 25.1, skipping') - result = test_client.query("SELECT '{\"timerange\": \"2025-01-01T00:00:00+0000\"}'::JSON").result_set + result = call(param_client.query, "SELECT '{\"timerange\": \"2025-01-01T00:00:00+0000\"}'::JSON").result_set assert result[0][0]['timerange'] == datetime.datetime(2025, 1, 1) # The following query is broken -- looks like something to do with Nullable(String) in the Tuple @@ -351,32 +353,32 @@ def test_json_str_time(test_client: Client, test_config: TestConfig): # settings={'input_format_json_read_numbers_as_strings': 0}).result_set -def test_typed_variant_ambiguous_scalars(test_client: Client, table_context: Callable): - type_available(test_client, 'variant') +def test_typed_variant_ambiguous_scalars(param_client: Client, call, table_context: Callable): + type_available(param_client, 'variant') with table_context('variant_ambig_scalars', [ 'key Int32', 'v1 Variant(Int64, Float64)']): - data = [[1, typed_variant(42, 'Int64')], - [2, typed_variant(42, 'Float64')], + data = [[1, typed_variant(57, 'Int64')], + [2, typed_variant(57, 'Float64')], [3, None]] - test_client.insert('variant_ambig_scalars', data) - result = test_client.query('SELECT * FROM variant_ambig_scalars ORDER BY key').result_set - assert result[0][1] == 42 - assert result[1][1] == 42 + call(param_client.insert, 'variant_ambig_scalars', data) + result = call(param_client.query, 'SELECT * FROM variant_ambig_scalars ORDER BY key').result_set + assert result[0][1] == 57 + assert result[1][1] == 57 assert result[2][1] is None -def test_typed_variant_mixed_with_inference(test_client: Client, table_context: Callable): - type_available(test_client, 'variant') +def test_typed_variant_mixed_with_inference(param_client: Client, call, table_context: Callable): + type_available(param_client, 'variant') with table_context('variant_mixed_infer', [ 'key Int32', 'v1 Variant(Int64, String, Float64)']): - data = [[1, typed_variant(42, 'Int64')], + data = [[1, typed_variant(57, 'Int64')], [2, 'hello'], [3, 3.14]] - test_client.insert('variant_mixed_infer', data) - result = test_client.query('SELECT * FROM variant_mixed_infer ORDER BY key').result_set - assert result[0][1] == 42 + call(param_client.insert, 'variant_mixed_infer', data) + result = call(param_client.query, 'SELECT * FROM variant_mixed_infer ORDER BY key').result_set + assert result[0][1] == 57 assert result[1][1] == 'hello' assert result[2][1] == 3.14 @@ -390,48 +392,48 @@ def test_typed_variant_validation(): typed_variant(None, 'Int32') -def test_typed_variant_member_mismatch(test_client: Client, table_context: Callable): - type_available(test_client, 'variant') +def test_typed_variant_member_mismatch(param_client: Client, call, table_context: Callable): + type_available(param_client, 'variant') with table_context('variant_mismatch', [ 'key Int32', 'v1 Variant(Int32, String)']): with pytest.raises(DataError): - test_client.insert('variant_mismatch', [[1, typed_variant(42, 'Float64')]]) + call(param_client.insert, 'variant_mismatch', [[1, typed_variant(57, 'Float64')]]) -def test_variant_in_tuple(test_client: Client, table_context: Callable): - type_available(test_client, 'variant') +def test_variant_in_tuple(param_client: Client, call, table_context: Callable): + type_available(param_client, 'variant') with table_context('variant_in_tuple', [ 'key Int32', 't1 Tuple(Int64, Variant(Bool, String, Int32))']): data = [[1, (-40, True)], [2, (340283, 'str')], [3, (0, 55)]] - test_client.insert('variant_in_tuple', data) - result = test_client.query('SELECT * FROM variant_in_tuple ORDER BY key').result_set + call(param_client.insert, 'variant_in_tuple', data) + result = call(param_client.query, 'SELECT * FROM variant_in_tuple ORDER BY key').result_set assert result[0][1] == (-40, True) assert result[1][1] == (340283, 'str') assert result[2][1] == (0, 55) -def test_typed_variant_name_normalization(test_client: Client, table_context: Callable): - type_available(test_client, 'variant') +def test_typed_variant_name_normalization(param_client: Client, call, table_context: Callable): + type_available(param_client, 'variant') with table_context('variant_norm', [ 'key Int32', 'v1 Variant(Decimal(10, 2), String)']): data = [[1, typed_variant(decimal.Decimal('1.50'), 'Decimal(10,2)')]] - test_client.insert('variant_norm', data) - result = test_client.query('SELECT * FROM variant_norm ORDER BY key').result_set + call(param_client.insert, 'variant_norm', data) + result = call(param_client.query, 'SELECT * FROM variant_norm ORDER BY key').result_set assert result[0][1] == decimal.Decimal('1.50') -def test_json_with_many_paths(test_client: Client, table_context: Callable): +def test_json_with_many_paths(param_client: Client, call, table_context: Callable): """Test JSON with many dynamic paths to exercise the shared data structure.""" - type_available(test_client, "json") + type_available(param_client, "json") with table_context("json_many_paths", ["id Int32", "data JSON(max_dynamic_paths=5)"]): large_json = {f"key_{i}": f"value_{i}" for i in range(20)} - test_client.insert("json_many_paths", [[1, large_json]]) - result = test_client.query("SELECT * FROM json_many_paths").result_set + call(param_client.insert, "json_many_paths", [[1, large_json]]) + result = call(param_client.query, "SELECT * FROM json_many_paths").result_set assert result[0][0] == 1 returned_json = result[0][1] @@ -442,9 +444,9 @@ def test_json_with_many_paths(test_client: Client, table_context: Callable): assert returned_json[f"key_{i}"] == f"value_{i}" -def test_json_with_long_values(test_client: Client, table_context: Callable): +def test_json_with_long_values(param_client: Client, call, table_context: Callable): """Test JSON shared data with long string values (>127 chars) to verify VarInt decoding.""" - type_available(test_client, "json") + type_available(param_client, "json") with table_context("json_long_values", ["id Int32", "data JSON(max_dynamic_paths=2)"]): short_val = "a" * 10 medium_val = "b" * 150 @@ -455,8 +457,8 @@ def test_json_with_long_values(test_client: Client, table_context: Callable): "key_1": medium_val, "key_2": long_val, } - test_client.insert("json_long_values", [[1, test_json]]) - result = test_client.query("SELECT * FROM json_long_values").result_set + call(param_client.insert, "json_long_values", [[1, test_json]]) + result = call(param_client.query, "SELECT * FROM json_long_values").result_set assert result[0][0] == 1 returned_json = result[0][1] @@ -466,9 +468,9 @@ def test_json_with_long_values(test_client: Client, table_context: Callable): assert returned_json["key_2"] == long_val -def test_json_shared_data_primitive_types(test_client: Client, table_context: Callable): +def test_json_shared_data_primitive_types(param_client: Client, call, table_context: Callable): """Tests round-trip of integers, floats, booleans, strings, and NULL in shared data.""" - type_available(test_client, "json") + type_available(param_client, "json") with table_context("json_primitive_types", ["id Int32", "data JSON(max_dynamic_paths=2)"]): test_data = { @@ -494,8 +496,8 @@ def test_json_shared_data_primitive_types(test_client: Client, table_context: Ca "negative_int": -1, } - test_client.insert("json_primitive_types", [[1, test_data]]) - result = test_client.query("SELECT * FROM json_primitive_types").result_set + call(param_client.insert, "json_primitive_types", [[1, test_data]]) + result = call(param_client.query, "SELECT * FROM json_primitive_types").result_set assert result[0][0] == 1 returned = result[0][1] @@ -523,20 +525,20 @@ def test_json_shared_data_primitive_types(test_client: Client, table_context: Ca assert returned["negative_int"] == test_data["negative_int"] -def test_json_shared_data_multiple_rows(test_client: Client, table_context: Callable): +def test_json_shared_data_multiple_rows(param_client: Client, call, table_context: Callable): """Test JSON shared data with multiple rows to ensure consistent decoding.""" - type_available(test_client, "json") + type_available(param_client, "json") with table_context("json_multirow", ["id Int32", "data JSON(max_dynamic_paths=2)"]): test_data = [ {"a": "string_val", "b": 100, "c": 3.14, "d": True, "e": "more"}, - {"a": 42, "b": "different", "c": False, "d": 2.718, "e": -999}, + {"a": 57, "b": "different", "c": False, "d": 2.718, "e": -999}, {"a": 0, "b": 0.0, "c": "", "d": None, "e": False}, ] rows = [[i + 1, data] for i, data in enumerate(test_data)] - test_client.insert("json_multirow", rows) - result = test_client.query("SELECT * FROM json_multirow ORDER BY id").result_set + call(param_client.insert, "json_multirow", rows) + result = call(param_client.query, "SELECT * FROM json_multirow ORDER BY id").result_set # Row 1 assert result[0][0] == 1 @@ -566,13 +568,13 @@ def test_json_shared_data_multiple_rows(test_client: Client, table_context: Call assert row3["e"] is test_data[2]["e"] # Query column with nulls via dot notation - result_w_nulls = test_client.query("SELECT data.d FROM json_multirow ORDER BY id").result_set + result_w_nulls = call(param_client.query, "SELECT data.d FROM json_multirow ORDER BY id").result_set assert [result[0] for result in result_w_nulls] == [item["d"] for item in test_data] -def test_json_shared_data_nested_keys(test_client: Client, table_context: Callable): +def test_json_shared_data_nested_keys(param_client: Client, call, table_context: Callable): """Test that dotted keys in shared data are properly nested into dicts.""" - type_available(test_client, "json") + type_available(param_client, "json") with table_context("json_shared_nested", ["id Int32", "data JSON(max_dynamic_paths=2)"]): test_data = { @@ -583,8 +585,8 @@ def test_json_shared_data_nested_keys(test_client: Client, table_context: Callab "nested.sibling": "sibling_value", "flat_overflow": "flat_value", } - test_client.insert("json_shared_nested", [[1, test_data]]) - result = test_client.query("SELECT * FROM json_shared_nested").result_set + call(param_client.insert, "json_shared_nested", [[1, test_data]]) + result = call(param_client.query, "SELECT * FROM json_shared_nested").result_set returned = result[0][1] assert isinstance(returned, dict) @@ -594,9 +596,9 @@ def test_json_shared_data_nested_keys(test_client: Client, table_context: Callab assert returned["flat_overflow"] == "flat_value" -def test_json_dynamic_variant_decoding(test_client: Client, table_context: Callable): +def test_json_dynamic_variant_decoding(param_client: Client, call, table_context: Callable): """Test with one nested JSON path changing type across rows.""" - type_available(test_client, "json") + type_available(param_client, "json") with table_context("json_dyn_variant", ["id Int32", "attributes JSON"]): rows = [ [1, {"agent.value": "computer-use", "meta.region": "us-west-2"}], @@ -604,8 +606,8 @@ def test_json_dynamic_variant_decoding(test_client: Client, table_context: Calla [3, {"agent.value": True, "meta.region": "eu-central-1"}], [4, {"agent.value": 2.81, "meta.region": "ap-southeast-1"}], ] - test_client.insert("json_dyn_variant", rows) - result = test_client.query("SELECT * FROM json_dyn_variant ORDER BY id").result_set + call(param_client.insert, "json_dyn_variant", rows) + result = call(param_client.query, "SELECT * FROM json_dyn_variant ORDER BY id").result_set assert result[0][1]["agent"]["value"] == "computer-use" assert result[1][1]["agent"]["value"] == 2164330 @@ -617,9 +619,9 @@ def test_json_dynamic_variant_decoding(test_client: Client, table_context: Calla assert result[3][1]["meta"]["region"] == "ap-southeast-1" -def test_json_dynamic_variant_multiple_rows(test_client: Client, table_context: Callable): +def test_json_dynamic_variant_multiple_rows(param_client: Client, call, table_context: Callable): """Test with top-level JSON keys changing type across rows.""" - type_available(test_client, "json") + type_available(param_client, "json") with table_context("json_dyn_multi", ["id Int32", "data JSON"]): rows = [ [1, {"value": 95, "status": "user_1"}], @@ -627,8 +629,8 @@ def test_json_dynamic_variant_multiple_rows(test_client: Client, table_context: [3, {"value": True, "status": 4.5}], [4, {"value": 0.0, "status": -10}], ] - test_client.insert("json_dyn_multi", rows) - result = test_client.query("SELECT * FROM json_dyn_multi ORDER BY id").result_set + call(param_client.insert, "json_dyn_multi", rows) + result = call(param_client.query, "SELECT * FROM json_dyn_multi ORDER BY id").result_set r1 = result[0][1] assert r1["value"] == 95 diff --git a/tests/integration_tests/test_error_handling.py b/tests/integration_tests/test_error_handling.py index 262e7fa1..2509ca11 100644 --- a/tests/integration_tests/test_error_handling.py +++ b/tests/integration_tests/test_error_handling.py @@ -1,95 +1,68 @@ import logging import pytest -from clickhouse_connect import create_client from clickhouse_connect.driver.exceptions import DatabaseError, OperationalError from tests.integration_tests.conftest import TestConfig -# pylint: disable=attribute-defined-outside-init - -class TestErrorHandling: - """Tests for error handling in the ClickHouse Connect client""" - - @pytest.fixture(autouse=True) - def setup(self, test_config: TestConfig): - self.config = test_config - - def test_wrong_port_error_message(self): - """ - Test that connecting to the wrong port properly propagates - the error message from ClickHouse. - """ - if self.config.cloud: - pytest.skip("Skipping wrong port test in cloud environ.") - wrong_port = 9000 - - with pytest.raises((DatabaseError, OperationalError)) as excinfo: - create_client( - host=self.config.host, - port=wrong_port, - username=self.config.username, - password=self.config.password, - ) +def test_wrong_port_error_message(client_factory, test_config: TestConfig): + """ + Test that connecting to the wrong port properly propagates + the error message from ClickHouse. + """ + if test_config.cloud: + pytest.skip("Skipping wrong port test in cloud environ.") + wrong_port = 9000 + + with pytest.raises((DatabaseError, OperationalError)) as excinfo: + client_factory(port=wrong_port) + + error_message = str(excinfo.value) + assert ( + f"Port {wrong_port} is for clickhouse-client program" in error_message + or "You must use port 8123 for HTTP" in error_message + ) + +def test_connection_refused_error(client_factory, test_config: TestConfig, caplog): + """ + Test that connecting to a port where nothing is listening + produces a clear error message. + """ + if test_config.cloud: + pytest.skip("Skipping connection refused test in cloud environ.") + # Suppress urllib3 and aiohttp connection warnings + urllib3_logger = logging.getLogger("urllib3.connectionpool") + original_urllib3_level = urllib3_logger.level + urllib3_logger.setLevel(logging.CRITICAL) + + # Swallow logging messages to prevent polluting pytest output + caplog.set_level(logging.CRITICAL) + + try: + # Use a port that shouldn't have anything listening + unused_port = 45678 + + # Try connecting to an unused port - should fail with connection refused + with pytest.raises(OperationalError) as excinfo: + client_factory(port=unused_port) error_message = str(excinfo.value) assert ( - f"Port {wrong_port} is for clickhouse-client program" in error_message - or "You must use port 8123 for HTTP" in error_message + "Connection refused" in error_message + or "Failed to establish a new connection" in error_message + or "Cannot connect to host" in error_message + or "Connection aborted" in error_message # Port occasionally occupied in CI, apparently ) + finally: + # Restore the original logging level + urllib3_logger.setLevel(original_urllib3_level) - def test_connection_refused_error(self, caplog): - """ - Test that connecting to a port where nothing is listening - produces a clear error message. - """ - if self.config.cloud: - pytest.skip("Skipping connection refused test in cloud environ.") - # Suppress urllib3 connection pool warnings - urllib3_logger = logging.getLogger("urllib3.connectionpool") - original_level = urllib3_logger.level - urllib3_logger.setLevel(logging.CRITICAL) - - # Swallow logging messages to prevent polluting pytest output - caplog.set_level(logging.CRITICAL) - - try: - # Use a port that shouldn't have anything listening - unused_port = 45678 - - # Try connecting to an unused port - should fail with connection refused - with pytest.raises(OperationalError) as excinfo: - create_client( - host=self.config.host, - port=unused_port, - username=self.config.username, - password=self.config.password, - ) - - error_message = str(excinfo.value) - assert ( - "Connection refused" in error_message - or "Failed to establish a new connection" in error_message - ) - finally: - # Restore the original logging level - urllib3_logger.setLevel(original_level) - - def test_successful_connection(self): - """ - Verify that connecting to the correct port works properly. - This serves as a sanity check that the test environment is configured correctly. - """ - # Connect to the correct HTTP port - client = create_client( - host=self.config.host, - port=self.config.port, # Use the port from test config - username=self.config.username, - password=self.config.password, - ) - # Simple query to verify connection works - result = client.command("SELECT 1") - assert result == 1 +def test_successful_connection(client_factory, call): + """Verify that connecting to the correct port works properly.""" + # Connect to the correct HTTP port (uses defaults from test_config) + client = client_factory() - client.close() + # Simple query to verify connection works + result = call(client.command, "SELECT 1") + assert result == 1 diff --git a/tests/integration_tests/test_external_data.py b/tests/integration_tests/test_external_data.py index d43fca69..4f16069f 100644 --- a/tests/integration_tests/test_external_data.py +++ b/tests/integration_tests/test_external_data.py @@ -2,7 +2,6 @@ import pytest -from clickhouse_connect import get_client from clickhouse_connect.driver import Client from clickhouse_connect.driver.external import ExternalData from clickhouse_connect.driver.options import arrow # pylint: disable=no-name-in-module @@ -11,40 +10,40 @@ ext_settings = {'input_format_allow_errors_num': 10, 'input_format_allow_errors_ratio': .2} -def test_external_simple(test_client: Client): +def test_external_simple(param_client: Client, call): data_file = f'{Path(__file__).parent}/movies.csv' data = ExternalData(data_file, fmt='CSVWithNames', structure=['movie String', 'year UInt16', 'rating Decimal32(3)']) - result = test_client.query('SELECT * FROM movies ORDER BY movie', + result = call(param_client.query, 'SELECT * FROM movies ORDER BY movie', external_data=data, settings=ext_settings).result_rows assert result[0][0] == '12 Angry Men' -def test_external_arrow(test_client: Client): +def test_external_arrow(param_client: Client, call): if not arrow: pytest.skip('PyArrow package not available') - if not test_client.min_version('21'): - pytest.skip(f'PyArrow is not supported in this server version {test_client.server_version}') + if not param_client.min_version('21'): + pytest.skip(f'PyArrow is not supported in this server version {param_client.server_version}') data_file = f'{Path(__file__).parent}/movies.csv' data = ExternalData(data_file, fmt='CSVWithNames', structure=['movie String', 'year UInt16', 'rating Decimal32(3)']) - result = test_client.query_arrow('SELECT * FROM movies ORDER BY movie', + result = call(param_client.query_arrow, 'SELECT * FROM movies ORDER BY movie', external_data=data, settings=ext_settings) assert str(result[0][0]) == '12 Angry Men' -def test_external_multiple(test_client: Client): +def test_external_multiple(param_client: Client, call): movies_file = f'{Path(__file__).parent}/movies.csv' data = ExternalData(movies_file, fmt='CSVWithNames', structure=['movie String', 'year UInt16', 'rating Decimal32(3)']) actors_file = f'{Path(__file__).parent}/actors.csv' data.add_file(actors_file, fmt='CSV', types='String,UInt16,String') - result = test_client.query('SELECT * FROM actors;', external_data=data, settings={ + result = call(param_client.query, 'SELECT * FROM actors;', external_data=data, settings={ 'input_format_allow_errors_num': 10, 'input_format_allow_errors_ratio': .2}).result_rows assert result[1][1] == 1940 - result = test_client.query( + result = call(param_client.query, 'SELECT _1, movie FROM actors INNER JOIN movies ON actors._3 = movies.movie AND actors._2 = 1940', external_data=data, settings=ext_settings).result_rows @@ -52,84 +51,77 @@ def test_external_multiple(test_client: Client): assert result[0][1] == 'Scarface' -def test_external_parquet(test_config: TestConfig, test_client: Client): +def test_external_parquet(test_config: TestConfig, param_client: Client, call): if test_config.cloud: pytest.skip('External data join not working in SMT, skipping') movies_file = f'{Path(__file__).parent}/movies.parquet' - test_client.command('DROP TABLE IF EXISTS movies') - test_client.command('DROP TABLE IF EXISTS num') - test_client.command(""" + call(param_client.command, 'DROP TABLE IF EXISTS movies') + call(param_client.command, 'DROP TABLE IF EXISTS num') + call(param_client.command, """ CREATE TABLE IF NOT EXISTS num (number UInt64, t String) ENGINE = MergeTree ORDER BY number""") - test_client.command(""" + call(param_client.command, """ INSERT INTO num SELECT number, concat(toString(number), 'x') as t FROM numbers(2500) WHERE (number > 1950) AND (number < 2025) """) data = ExternalData(movies_file, fmt='Parquet', structure=['movie String', 'year UInt16', 'rating Float64']) - result = test_client.query( + result = call(param_client.query, "SELECT * FROM movies INNER JOIN num ON movies.year = number AND t = '2000x' ORDER BY movie", settings={'output_format_parquet_string_as_string': 1}, external_data=data).result_rows assert len(result) == 5 assert result[2][0] == 'Memento' - test_client.command('DROP TABLE num') + call(param_client.command, 'DROP TABLE num') -def test_external_binary(test_client: Client): - actors = 'Robert Redford\t1936\tThe Sting\nAl Pacino\t1940\tScarface'.encode() +def test_external_binary(param_client: Client, call): + actors = b'Robert Redford\t1936\tThe Sting\nAl Pacino\t1940\tScarface' data = ExternalData(file_name='actors.csv', data=actors, structure='name String, birth_year UInt16, movie String') - result = test_client.query('SELECT * FROM actors ORDER BY birth_year DESC', external_data=data).result_rows + result = call(param_client.query, 'SELECT * FROM actors ORDER BY birth_year DESC', external_data=data).result_rows assert len(result) == 2 assert result[1][2] == 'The Sting' -def test_external_empty_binary(test_client: Client): +def test_external_empty_binary(param_client: Client, call): data = ExternalData(file_name='empty.csv', data=b'', structure='name String') - result = test_client.query('SELECT * FROM empty', external_data=data).result_rows + result = call(param_client.query, 'SELECT * FROM empty', external_data=data).result_rows assert len(result) == 0 -def test_external_raw(test_client: Client): +def test_external_raw(param_client: Client, call): movies_file = f'{Path(__file__).parent}/movies.parquet' data = ExternalData(movies_file, fmt='Parquet', structure=['movie String', 'year UInt16', 'rating Float64']) - result = test_client.raw_query('SELECT avg(rating) FROM movies', external_data=data) + result = call(param_client.raw_query, 'SELECT avg(rating) FROM movies', external_data=data) assert '8.25' == result.decode()[0:4] -def test_external_command(test_client: Client): +def test_external_command(param_client: Client, call): movies_file = f'{Path(__file__).parent}/movies.parquet' data = ExternalData(movies_file, fmt='Parquet', structure=['movie String', 'year UInt16', 'rating Float64']) - result = test_client.command('SELECT avg(rating) FROM movies', external_data=data) + result = call(param_client.command, 'SELECT avg(rating) FROM movies', external_data=data) assert '8.25' == result[0:4] - test_client.command('DROP TABLE IF EXISTS movies_ext') - if test_client.min_version('22.8'): - query_result = test_client.query('CREATE TABLE movies_ext ENGINE MergeTree() ORDER BY tuple() EMPTY ' + + call(param_client.command, 'DROP TABLE IF EXISTS movies_ext') + if param_client.min_version('22.8'): + query_result = call(param_client.query, 'CREATE TABLE movies_ext ENGINE MergeTree() ORDER BY tuple() EMPTY ' + 'AS SELECT * FROM movies', external_data=data) assert 'query_id' in query_result.first_item - test_client.raw_query('INSERT INTO movies_ext SELECT * FROM movies', external_data=data) - assert 250 == test_client.command('SELECT COUNT() FROM movies_ext') - - -def test_external_with_form_encode(test_config: TestConfig): - form_client = get_client( - host=test_config.host, - port=test_config.port, - username=test_config.username, - password=test_config.password, - database=test_config.test_database, - form_encode_query_params=True - ) + call(param_client.raw_query, 'INSERT INTO movies_ext SELECT * FROM movies', external_data=data) + assert 250 == call(param_client.command, 'SELECT COUNT() FROM movies_ext') + + +def test_external_with_form_encode(client_factory, call): + form_client = client_factory(form_encode_query_params=True) movies_file = f'{Path(__file__).parent}/movies.csv' data = ExternalData(movies_file, fmt='CSVWithNames', structure=['movie String', 'year UInt16', 'rating Decimal32(3)']) # Test with parameters in the query - result = form_client.query( + result = call(form_client.query, 'SELECT * FROM movies WHERE year > {year:UInt16} ORDER BY rating DESC LIMIT {limit:UInt32}', parameters={'year': 1990, 'limit': 5}, external_data=data, @@ -144,7 +136,7 @@ def test_external_with_form_encode(test_config: TestConfig): assert ratings == sorted(ratings, reverse=True) # Test raw query with external data and form encoding - raw_result = form_client.raw_query( + raw_result = call(form_client.raw_query, 'SELECT COUNT() FROM movies WHERE rating > {min_rating:Decimal32(3)}', parameters={'min_rating': 8.0}, external_data=data, diff --git a/tests/integration_tests/test_form_encode_query.py b/tests/integration_tests/test_form_encode_query.py index 907c1797..3b80d70f 100644 --- a/tests/integration_tests/test_form_encode_query.py +++ b/tests/integration_tests/test_form_encode_query.py @@ -1,35 +1,26 @@ from typing import Callable -from clickhouse_connect import get_client from clickhouse_connect.driver import Client -from tests.integration_tests.conftest import TestConfig -def test_form_encode_query_basic(test_client: Client, test_config: TestConfig, table_context: Callable): +def test_form_encode_query_basic(client_factory, call, table_context: Callable): """Test that form_encode_query sends parameters as form data""" - form_client = get_client( - host=test_config.host, - port=test_config.port, - username=test_config.username, - password=test_config.password, - database=test_config.test_database, - form_encode_query_params=True - ) + form_client = client_factory(form_encode_query_params=True) with table_context('test_form_encode', ['id UInt32', 'name String', 'value Float64']): - test_client.insert('test_form_encode', + call(form_client.insert, 'test_form_encode', [[1, 'test1', 10.5], [2, 'test2', 20.3], [3, 'test3', 30.7]]) - result = form_client.query( + result = call(form_client.query, 'SELECT * FROM test_form_encode WHERE id = {id:UInt32}', parameters={'id': 2} ) assert result.row_count == 1 assert result.first_row[1] == 'test2' - result = form_client.query( + result = call(form_client.query, 'SELECT * FROM test_form_encode WHERE name = {name:String} AND value > {val:Float64}', parameters={'name': 'test3', 'val': 25.0} ) @@ -37,31 +28,24 @@ def test_form_encode_query_basic(test_client: Client, test_config: TestConfig, t assert result.first_row[0] == 3 -def test_form_encode_with_arrays(test_client: Client, test_config: TestConfig, table_context: Callable): +def test_form_encode_with_arrays(client_factory, call, table_context: Callable): """Test form_encode_query with array parameters""" - form_client = get_client( - host=test_config.host, - port=test_config.port, - username=test_config.username, - password=test_config.password, - database=test_config.test_database, - form_encode_query_params=True - ) + form_client = client_factory(form_encode_query_params=True) with table_context('test_form_arrays', ['id UInt32', 'tags Array(String)']): - test_client.insert('test_form_arrays', + call(form_client.insert, 'test_form_arrays', [[1, ['tag1', 'tag2']], [2, ['tag2', 'tag3']], [3, ['tag1', 'tag3']]]) - result = form_client.query( + result = call(form_client.query, 'SELECT * FROM test_form_arrays WHERE has(tags, {tag:String})', parameters={'tag': 'tag3'} ) assert result.row_count == 2 ids = [1, 3] - result = form_client.query( + result = call(form_client.query, 'SELECT * FROM test_form_arrays WHERE id IN {ids:Array(UInt32)}', parameters={'ids': ids} ) @@ -69,18 +53,11 @@ def test_form_encode_with_arrays(test_client: Client, test_config: TestConfig, t assert sorted([row[0] for row in result.result_rows]) == [1, 3] -def test_form_encode_raw_query(test_config: TestConfig): +def test_form_encode_raw_query(client_factory, call): """Test form_encode_query with raw_query method""" - form_client = get_client( - host=test_config.host, - port=test_config.port, - username=test_config.username, - password=test_config.password, - database=test_config.test_database, - form_encode_query_params=True - ) + form_client = client_factory(form_encode_query_params=True) - result = form_client.raw_query( + result = call(form_client.raw_query, 'SELECT {a:Int32} + {b:Int32} as sum', parameters={'a': 10, 'b': 20} ) @@ -88,84 +65,56 @@ def test_form_encode_raw_query(test_config: TestConfig): assert b'30' in result -def test_form_encode_vs_regular(test_client: Client, test_config: TestConfig, table_context: Callable): +def test_form_encode_vs_regular(client_factory, param_client: Client, call, table_context: Callable): """Verify that form_encode_query produces same results as regular parameter handling""" - regular_client = get_client( - host=test_config.host, - port=test_config.port, - username=test_config.username, - password=test_config.password, - database=test_config.test_database, - form_encode_query_params=False - ) + regular_client = client_factory(form_encode_query_params=False) - form_client = get_client( - host=test_config.host, - port=test_config.port, - username=test_config.username, - password=test_config.password, - database=test_config.test_database, - form_encode_query_params=True - ) + form_client = client_factory(form_encode_query_params=True) with table_context('test_comparison', ['id UInt32', 'text String', 'score Float64']): - test_client.insert('test_comparison', + call(param_client.insert, 'test_comparison', [[i, f'text_{i}', i * 1.5] for i in range(1, 11)]) query = 'SELECT * FROM test_comparison WHERE id > {min_id:UInt32} AND score < {max_score:Float64} ORDER BY id' params = {'min_id': 3, 'max_score': 12.0} - regular_result = regular_client.query(query, parameters=params) - form_result = form_client.query(query, parameters=params) + regular_result = call(regular_client.query, query, parameters=params) + form_result = call(form_client.query, query, parameters=params) assert regular_result.result_rows == form_result.result_rows assert regular_result.row_count == form_result.row_count -def test_form_encode_nullable_params(test_config: TestConfig): +def test_form_encode_nullable_params(client_factory, call): """Test form_encode_query with nullable parameters""" - form_client = get_client( - host=test_config.host, - port=test_config.port, - username=test_config.username, - password=test_config.password, - database=test_config.test_database, - form_encode_query_params=True - ) + form_client = client_factory(form_encode_query_params=True) - result = form_client.query( + result = call(form_client.query, 'SELECT {val:Nullable(String)} IS NULL as is_null', parameters={'val': None} ) assert result.first_row[0] == 1 - result = form_client.query( + result = call(form_client.query, 'SELECT {val:Nullable(String)} as value', parameters={'val': 'test_value'} ) assert result.first_row[0] == 'test_value' -def test_form_encode_schema_probe_query(test_config: TestConfig, table_context: Callable): +def test_form_encode_schema_probe_query(client_factory, call, table_context: Callable): """Test that schema-probe queries (LIMIT 0) work correctly with form_encode_query_params""" - form_client = get_client( - host=test_config.host, - port=test_config.port, - username=test_config.username, - password=test_config.password, - database=test_config.test_database, - form_encode_query_params=True - ) + form_client = client_factory(form_encode_query_params=True) # Test with a simple LIMIT 0 query - result = form_client.query('SELECT name, database, NOW() as dt FROM system.tables LIMIT 0') + result = call(form_client.query, 'SELECT name, database, NOW() as dt FROM system.tables LIMIT 0') assert result.column_names == ('name', 'database', 'dt') assert len(result.column_types) == 3 assert len(result.result_set) == 0 # Test with LIMIT 0 and parameters with table_context('test_schema_probe', ['id UInt32', 'name String', 'value Float64']): - result = form_client.query( + result = call(form_client.query, 'SELECT * FROM test_schema_probe WHERE id = {id:UInt32} LIMIT 0', parameters={'id': 1} ) @@ -174,7 +123,7 @@ def test_form_encode_schema_probe_query(test_config: TestConfig, table_context: assert len(result.result_set) == 0 # Test with complex query and parameters - result = form_client.query( + result = call(form_client.query, 'SELECT id, name, value * {multiplier:Float64} as adjusted_value ' 'FROM test_schema_probe ' 'WHERE name = {filter_name:String} LIMIT 0', diff --git a/tests/integration_tests/test_formats.py b/tests/integration_tests/test_formats.py index 2e9876c7..0a1895d4 100644 --- a/tests/integration_tests/test_formats.py +++ b/tests/integration_tests/test_formats.py @@ -1,15 +1,11 @@ -from clickhouse_connect.driver import Client, ProgrammingError +from clickhouse_connect.driver import Client -def test_uint64_format(test_client: Client): +def test_uint64_format(param_client: Client, call): # Default should be unsigned - result = test_client.query('SELECT toUInt64(9523372036854775807) as value') + result = call(param_client.query, "SELECT toUInt64(9523372036854775807) as value") assert result.result_set[0][0] == 9523372036854775807 - result = test_client.query('SELECT toUInt64(9523372036854775807) as value', query_formats={'UInt64': 'signed'}) + result = call(param_client.query, "SELECT toUInt64(9523372036854775807) as value", query_formats={"UInt64": "signed"}) assert result.result_set[0][0] == -8923372036854775809 - result = test_client.query('SELECT toUInt64(9523372036854775807) as value', query_formats={'UInt64': 'native'}) + result = call(param_client.query, "SELECT toUInt64(9523372036854775807) as value", query_formats={"UInt64": "native"}) assert result.result_set[0][0] == 9523372036854775807 - try: - test_client.query('SELECT toUInt64(9523372036854775807) as signed', query_formats={'UInt64': 'huh'}) - except ProgrammingError: - pass diff --git a/tests/integration_tests/test_geometric.py b/tests/integration_tests/test_geometric.py index 5e667c78..a6f576ed 100644 --- a/tests/integration_tests/test_geometric.py +++ b/tests/integration_tests/test_geometric.py @@ -3,32 +3,32 @@ from clickhouse_connect.driver import Client -def test_point_column(test_client: Client, table_context: Callable): +def test_point_column(param_client: Client, call, table_context: Callable): with table_context('point_column_test', ['key Int32', 'point Point']): data = [[1, (3.55, 3.55)], [2, (4.55, 4.55)]] - test_client.insert('point_column_test', data) + call(param_client.insert, 'point_column_test', data) - query_result = test_client.query('SELECT * FROM point_column_test ORDER BY key').result_rows + query_result = call(param_client.query, 'SELECT * FROM point_column_test ORDER BY key').result_rows assert len(query_result) == 2 assert query_result[0] == (1, (3.55, 3.55)) assert query_result[1] == (2, (4.55, 4.55)) -def test_ring_column(test_client: Client, table_context: Callable): +def test_ring_column(param_client: Client, call, table_context: Callable): with table_context('ring_column_test', ['key Int32', 'ring Ring']): data = [[1, [(5.522, 58.472),(3.55, 3.55)]], [2, [(4.55, 4.55)]]] - test_client.insert('ring_column_test', data) + call(param_client.insert, 'ring_column_test', data) - query_result = test_client.query('SELECT * FROM ring_column_test ORDER BY key').result_rows + query_result = call(param_client.query, 'SELECT * FROM ring_column_test ORDER BY key').result_rows assert len(query_result) == 2 assert query_result[0] == (1, [(5.522, 58.472),(3.55, 3.55)]) assert query_result[1] == (2, [(4.55, 4.55)]) -def test_polygon_column(test_client: Client, table_context: Callable): +def test_polygon_column(param_client: Client, call, table_context: Callable): with table_context('polygon_column_test', ['key Int32', 'polygon Polygon']): - res = test_client.query("SELECT readWKTPolygon('POLYGON ((-64.8 32.3, -65.5 18.3, -80.3 25.2, -64.8 32.3))') as polygon") + res = call(param_client.query, "SELECT readWKTPolygon('POLYGON ((-64.8 32.3, -65.5 18.3, -80.3 25.2, -64.8 32.3))') as polygon") pg = res.first_row[0] - test_client.insert('polygon_column_test', [(1, pg), (4, pg)]) - query_result = test_client.query('SELECT key, polygon FROM polygon_column_test WHERE key = 4') + call(param_client.insert, 'polygon_column_test', [(1, pg), (4, pg)]) + query_result = call(param_client.query, 'SELECT key, polygon FROM polygon_column_test WHERE key = 4') assert query_result.first_row[1] == pg diff --git a/tests/integration_tests/test_inserts.py b/tests/integration_tests/test_inserts.py index 123cf8d5..5adfa97c 100644 --- a/tests/integration_tests/test_inserts.py +++ b/tests/integration_tests/test_inserts.py @@ -1,91 +1,91 @@ from decimal import Decimal from typing import Callable +import pytest + from clickhouse_connect.driver.client import Client from clickhouse_connect.driver.exceptions import DataError -def test_insert(test_client: Client, test_table_engine: str): - if test_client.min_version('19'): - test_client.command('DROP TABLE IF EXISTS test_system_insert') +def test_insert(param_client: Client, call, test_table_engine: str): + if param_client.min_version('19'): + call(param_client.command, 'DROP TABLE IF EXISTS test_system_insert') else: - test_client.command('DROP TABLE IF EXISTS test_system_insert SYNC') - test_client.command(f'CREATE TABLE test_system_insert AS system.tables Engine {test_table_engine} ORDER BY name') - tables_result = test_client.query('SELECT * from system.tables') - test_client.insert(table='test_system_insert', column_names='*', data=tables_result.result_set) - copy_result = test_client.command('SELECT count() from test_system_insert') + call(param_client.command, 'DROP TABLE IF EXISTS test_system_insert SYNC') + call(param_client.command, f'CREATE TABLE test_system_insert AS system.tables Engine {test_table_engine} ORDER BY name') + tables_result = call(param_client.query, 'SELECT * from system.tables') + call(param_client.insert, table='test_system_insert', column_names='*', data=tables_result.result_set) + copy_result = call(param_client.command, 'SELECT count() from test_system_insert') assert tables_result.row_count == copy_result - test_client.command('DROP TABLE IF EXISTS test_system_insert') + call(param_client.command, 'DROP TABLE IF EXISTS test_system_insert') -def test_decimal_conv(test_client: Client, table_context: Callable): +def test_decimal_conv(param_client: Client, call, table_context: Callable): with table_context('test_num_conv', ['col1 UInt64', 'col2 Int32', 'f1 Float64']): data = [[Decimal(5), Decimal(-182), Decimal(55.2)], [Decimal(57238478234), Decimal(77), Decimal(-29.5773)]] - test_client.insert('test_num_conv', data) - result = test_client.query('SELECT * FROM test_num_conv').result_set + call(param_client.insert, 'test_num_conv', data) + result = call(param_client.query, 'SELECT * FROM test_num_conv').result_set assert result == [(5, -182, 55.2), (57238478234, 77, -29.5773)] -def test_float_decimal_conv(test_client: Client, table_context: Callable): +def test_float_decimal_conv(param_client: Client, call, table_context: Callable): with table_context('test_float_to_dec_conv', ['col1 Decimal32(6)','col2 Decimal32(6)', 'col3 Decimal128(6)', 'col4 Decimal128(6)']): data = [[0.492917, 0.49291700, 0.492917, 0.49291700]] - test_client.insert('test_float_to_dec_conv', data) - result = test_client.query('SELECT * FROM test_float_to_dec_conv').result_set + call(param_client.insert, 'test_float_to_dec_conv', data) + result = call(param_client.query, 'SELECT * FROM test_float_to_dec_conv').result_set assert result == [(Decimal("0.492917"), Decimal("0.492917"), Decimal("0.492917"), Decimal("0.492917"))] -def test_bad_data_insert(test_client: Client, table_context: Callable): +def test_bad_data_insert(param_client: Client, call, table_context: Callable): with table_context('test_bad_insert', ['key Int32', 'float_col Float64']): data = [[1, 3.22], [2, 'nope']] - try: - test_client.insert('test_bad_insert', data) - except DataError as ex: - assert 'array' in str(ex) + with pytest.raises(DataError, match="array"): + call(param_client.insert, 'test_bad_insert', data) -def test_bad_strings(test_client: Client, table_context: Callable): +def test_bad_strings(param_client: Client, call, table_context: Callable): with table_context('test_bad_strings', 'key Int32, fs FixedString(6), nsf Nullable(FixedString(4))'): try: - test_client.insert('test_bad_strings', [[1, b'\x0535', None]]) + call(param_client.insert, 'test_bad_strings', [[1, b'\x0535', None]]) except DataError as ex: assert 'match' in str(ex) try: - test_client.insert('test_bad_strings', [[1, b'\x0535abc', '😀🙃']]) + call(param_client.insert, 'test_bad_strings', [[1, b'\x0535abc', '😀🙃']]) except DataError as ex: assert 'encoded' in str(ex) -def test_low_card_dictionary_size(test_client: Client, table_context: Callable): +def test_low_card_dictionary_size(param_client: Client, call, table_context: Callable): with table_context('test_low_card_dict', 'key Int32, lc LowCardinality(String)', settings={'index_granularity': 65536 }): data = [[x, str(x)] for x in range(30000)] - test_client.insert('test_low_card_dict', data) - assert 30000 == test_client.command('SELECT count() FROM test_low_card_dict') + call(param_client.insert, 'test_low_card_dict', data) + assert 30000 == call(param_client.command, 'SELECT count() FROM test_low_card_dict') -def test_column_names_spaces(test_client: Client, table_context: Callable): +def test_column_names_spaces(param_client: Client, call, table_context: Callable): with table_context('test_column_spaces', columns=['key 1', 'value 1'], column_types=['Int32', 'String']): data = [[1, 'str 1'], [2, 'str 2']] - test_client.insert('test_column_spaces', data) - result = test_client.query('SELECT * FROM test_column_spaces').result_rows + call(param_client.insert, 'test_column_spaces', data) + result = call(param_client.query, 'SELECT * FROM test_column_spaces').result_rows assert result[0][0] == 1 assert result[1][1] == 'str 2' -def test_numeric_conversion(test_client: Client, table_context: Callable): +def test_numeric_conversion(param_client: Client, call, table_context: Callable): with table_context('test_numeric_convert', columns=['key Int32', 'n_int Nullable(UInt64)', 'n_flt Nullable(Float64)']): data = [[1, None, None], [2, '2', '5.32']] - test_client.insert('test_numeric_convert', data) - result = test_client.query('SELECT * FROM test_numeric_convert').result_rows + call(param_client.insert, 'test_numeric_convert', data) + result = call(param_client.query, 'SELECT * FROM test_numeric_convert').result_rows assert result[1][1] == 2 assert result[1][2] == float('5.32') - test_client.command('TRUNCATE TABLE test_numeric_convert') + call(param_client.command, 'TRUNCATE TABLE test_numeric_convert') data = [[0, '55', '532.48'], [1, None, None], [2, '2', '5.32']] - test_client.insert('test_numeric_convert', data) - result = test_client.query('SELECT * FROM test_numeric_convert').result_rows + call(param_client.insert, 'test_numeric_convert', data) + result = call(param_client.query, 'SELECT * FROM test_numeric_convert').result_rows assert result[0][1] == 55 assert result[0][2] == 532.48 assert result[1][1] is None diff --git a/tests/integration_tests/test_jwt_auth.py b/tests/integration_tests/test_jwt_auth.py index 07cc5867..9e2e69ad 100644 --- a/tests/integration_tests/test_jwt_auth.py +++ b/tests/integration_tests/test_jwt_auth.py @@ -7,6 +7,7 @@ pytest.skip('JWT tests are not yet configured', allow_module_level=True) + def test_jwt_auth_sync_client(test_config: TestConfig): if not test_config.cloud: pytest.skip('Skipping JWT test in non-Cloud mode') diff --git a/tests/integration_tests/test_mid_stream_exception.py b/tests/integration_tests/test_mid_stream_exception.py index 174f9d3a..848716bf 100644 --- a/tests/integration_tests/test_mid_stream_exception.py +++ b/tests/integration_tests/test_mid_stream_exception.py @@ -29,3 +29,34 @@ def test_mid_stream_exception_streaming(test_client: Client): error_msg = str(exc_info.value) assert "Value passed to 'throwIf' function is non-zero" in error_msg assert test_client.command("SELECT 1") == 1 + + +@pytest.mark.asyncio +async def test_mid_stream_exception_async(test_native_async_client): + """Test that mid-stream exceptions are properly detected and raised (async).""" + query = "SELECT sleepEachRow(0.01), throwIf(number=100) FROM numbers(200)" + + with pytest.raises(StreamFailureError) as exc_info: + result = await test_native_async_client.query(query, settings={"max_block_size": 1, "wait_end_of_query": 0}) + _ = result.result_set + + error_msg = str(exc_info.value) + assert "Value passed to 'throwIf' function is non-zero" in error_msg + assert await test_native_async_client.command("SELECT 1") == 1 + + +@pytest.mark.asyncio +async def test_mid_stream_exception_streaming_async(test_native_async_client): + """Test that mid-stream exceptions are properly detected in streaming mode (async).""" + query = "SELECT sleepEachRow(0.01), throwIf(number=100) FROM numbers(200)" + + with pytest.raises(StreamFailureError) as exc_info: + async with await test_native_async_client.query_rows_stream( + query, settings={"max_block_size": 1, "wait_end_of_query": 0} + ) as stream: + async for _ in stream: + pass + + error_msg = str(exc_info.value) + assert "Value passed to 'throwIf' function is non-zero" in error_msg + assert await test_native_async_client.command("SELECT 1") == 1 diff --git a/tests/integration_tests/test_multithreading.py b/tests/integration_tests/test_multithreading.py index 1b01eaf4..f18ccbc0 100644 --- a/tests/integration_tests/test_multithreading.py +++ b/tests/integration_tests/test_multithreading.py @@ -1,29 +1,155 @@ +import asyncio import threading +import uuid import pytest -from clickhouse_connect.driver import Client +from clickhouse_connect import create_client, get_async_client from clickhouse_connect.driver.exceptions import ProgrammingError -from tests.integration_tests.conftest import TestConfig +from tests.integration_tests.conftest import TestConfig, make_client_config -def test_threading_error(test_config: TestConfig, test_client: Client): +def test_sync_client_sequential_thread_access(test_client, test_config: TestConfig): + """Test that sync clients can handle sequential access from different threads.""" if test_config.cloud: - pytest.skip('Skipping threading test in ClickHouse Cloud') - thrown = None + pytest.skip("Skipping threading test in ClickHouse Cloud") + + results = [] + errors = [] + + def run_query(value): + try: + result = test_client.command(f"SELECT {value}") + results.append(result) + except Exception as ex: # pylint: disable=broad-exception-caught + errors.append(ex) + + threads = [threading.Thread(target=run_query, args=(i,)) for i in range(3)] + for thread in threads: + thread.start() + thread.join() + + assert len(errors) == 0, f"Unexpected errors: {errors}" + assert len(results) == 3 + assert results == [0, 1, 2] + + +@pytest.mark.asyncio +async def test_async_client_threadsafe_submission(test_native_async_client, test_config: TestConfig): + """Test that async clients work correctly with run_coroutine_threadsafe from multiple threads.""" + if test_config.cloud: + pytest.skip("Skipping threading test in ClickHouse Cloud") + + loop = asyncio.get_running_loop() + results = [] + errors = [] + lock = threading.Lock() + + def run_query_threadsafe(value): + try: + future = asyncio.run_coroutine_threadsafe( + test_native_async_client.command(f"SELECT {value}"), + loop + ) + result = future.result(timeout=5) + with lock: + results.append(result) + except Exception as ex: # pylint: disable=broad-exception-caught + with lock: + errors.append(ex) + + threads = [threading.Thread(target=run_query_threadsafe, args=(i,)) for i in range(3)] + for thread in threads: + thread.start() + + await asyncio.sleep(2) + + for thread in threads: + thread.join() + + assert len(errors) == 0, f"Unexpected errors: {errors}" + assert len(results) == 3 + assert sorted(results) == [0, 1, 2] - class QueryThread (threading.Thread): - def run(self): - nonlocal thrown - try: - test_client.command('SELECT randomString(512) FROM numbers(1000000)') - except ProgrammingError as ex: - thrown = ex - threads = [QueryThread(), QueryThread()] +def test_sync_concurrent_session_usage_detection(test_config: TestConfig): + """Test that ClickHouse server detects concurrent usage of the same session (sync client).""" + if test_config.cloud: + pytest.skip("Skipping session concurrency test in ClickHouse Cloud") + + session_id = str(uuid.uuid4()) + client1 = create_client(**make_client_config(test_config, session_id=session_id)) + client2 = create_client(**make_client_config(test_config, session_id=session_id)) + + thrown = [] + + def run_query(client): + try: + client.command("SELECT sleep(1)") + except (ProgrammingError, Exception) as ex: # pylint: disable=broad-exception-caught + thrown.append(ex) + + threads = [ + threading.Thread(target=run_query, args=(client1,)), + threading.Thread(target=run_query, args=(client2,)) + ] + for thread in threads: thread.start() + for thread in threads: thread.join() - assert 'concurrent' in str(thrown) + try: + client1.close() + client2.close() + except Exception: # pylint: disable=broad-exception-caught + pass + + # At least one should fail due to concurrent session usage + assert len(thrown) > 0, "Expected ClickHouse to detect concurrent session usage" + assert any("concurrent" in str(ex).lower() or "session" in str(ex).lower() for ex in thrown), \ + f"Expected session concurrency error, got: {thrown}" + + +@pytest.mark.asyncio +async def test_async_concurrent_session_usage_detection(test_config: TestConfig): + """Test that ClickHouse server detects concurrent usage of the same session (async client).""" + if test_config.cloud: + pytest.skip("Skipping session concurrency test in ClickHouse Cloud") + + session_id = str(uuid.uuid4()) + loop = asyncio.get_running_loop() + + async with await get_async_client(**make_client_config(test_config, session_id=session_id)) as client1, \ + await get_async_client(**make_client_config(test_config, session_id=session_id)) as client2: + + thrown = [] + + def run_query(client): + try: + future = asyncio.run_coroutine_threadsafe( + client.command("SELECT sleep(1)"), + loop + ) + future.result(timeout=5) + except (ProgrammingError, Exception) as ex: # pylint: disable=broad-exception-caught + thrown.append(ex) + + threads = [ + threading.Thread(target=run_query, args=(client1,)), + threading.Thread(target=run_query, args=(client2,)) + ] + + for thread in threads: + thread.start() + + await asyncio.sleep(2) + + for thread in threads: + thread.join() + + # At least one should fail due to concurrent session usage + assert len(thrown) > 0, "Expected ClickHouse to detect concurrent session usage" + assert any("concurrent" in str(ex).lower() or "session" in str(ex).lower() for ex in thrown), \ + f"Expected session concurrency error, got: {thrown}" diff --git a/tests/integration_tests/test_native.py b/tests/integration_tests/test_native.py index bed64f0d..0905465a 100644 --- a/tests/integration_tests/test_native.py +++ b/tests/integration_tests/test_native.py @@ -8,64 +8,66 @@ from clickhouse_connect.driver import Client -def test_low_card(test_client: Client, table_context: Callable): +def test_low_card(param_client: Client, call, table_context: Callable): with table_context('native_test', ['key LowCardinality(Int32)', 'value_1 LowCardinality(String)']): - test_client.insert('native_test', [[55, 'TV1'], [-578328, 'TV38882'], [57372, 'Kabc/defXX']]) - result = test_client.query("SELECT * FROM native_test WHERE value_1 LIKE '%abc/def%'") + call(param_client.insert, 'native_test', [[55, 'TV1'], [-578328, 'TV38882'], [57372, 'Kabc/defXX']]) + result = call(param_client.query, "SELECT * FROM native_test WHERE value_1 LIKE '%abc/def%'") assert len(result.result_set) == 1 -def test_low_card_uuid(test_client: Client, table_context: Callable): +def test_low_card_uuid(param_client: Client, call, table_context: Callable): with table_context('low_card_uuid', ['dt Date', 'low_card_uuid LowCardinality(UUID)']): data = ([date(2023, 1, 1), '80397B00E0B248AFAF34AE11A5546A3B'], [date(2024, 1, 1), '70397B00-E0B2-48AF-AF34-AE11A5546A3B']) - test_client.insert('low_card_uuid', data) - result = test_client.query("SELECT * FROM low_card_uuid order by dt").result_set + call(param_client.insert, 'low_card_uuid', data) + result = call(param_client.query, "SELECT * FROM low_card_uuid order by dt").result_set assert len(result) == 2 assert str(result[0][1]) == '80397b00-e0b2-48af-af34-ae11a5546a3b' assert str(result[1][1]) == '70397b00-e0b2-48af-af34-ae11a5546a3b' -def test_bare_datetime64(test_client: Client, table_context: Callable): +def test_bare_datetime64(param_client: Client, call, table_context: Callable): with table_context('bare_datetime64_test', ['key UInt32', 'dt64 DateTime64']): - test_client.insert('bare_datetime64_test', + call(param_client.insert, 'bare_datetime64_test', [[1, datetime(2023, 3, 25, 10, 5, 44, 772402)], [2, datetime.now()], [3, datetime(1965, 10, 15, 12, 0, 0)]]) - result = test_client.query('SELECT * FROM bare_datetime64_test ORDER BY key').result_rows + result = call(param_client.query, 'SELECT * FROM bare_datetime64_test ORDER BY key').result_rows assert result[0][0] == 1 assert result[0][1] == datetime(2023, 3, 25, 10, 5, 44, 772000) assert result[2][1] == datetime(1965, 10, 15, 12, 0, 0) -def test_nulls(test_client: Client, table_context: Callable): +def test_nulls(param_client: Client, call, table_context: Callable): with table_context('nullable_test', ['key UInt32', 'null_str Nullable(String)', 'null_int Nullable(Int64)']): - test_client.insert('nullable_test', [[1, None, None], + call(param_client.insert, 'nullable_test', [[1, None, None], [2, 'nonnull', -57382882345666], [3, None, 5882374747732834], [4, 'nonnull2', None]]) - result = test_client.query('SELECT * FROM nullable_test ORDER BY key', use_none=False).result_rows + result = call(param_client.query, 'SELECT * FROM nullable_test ORDER BY key', use_none=False).result_rows assert result[2] == (3, '', 5882374747732834) assert result[3] == (4, 'nonnull2', 0) - result = test_client.query('SELECT * FROM nullable_test ORDER BY key').result_rows + result = call(param_client.query, 'SELECT * FROM nullable_test ORDER BY key').result_rows assert result[1] == (2, 'nonnull', -57382882345666) assert result[2] == (3, None, 5882374747732834) assert result[3] == (4, 'nonnull2', None) -def test_read_formats(test_client: Client, test_table_engine: str): - test_client.command('DROP TABLE IF EXISTS read_format_test') - test_client.command('CREATE TABLE read_format_test (key Int32, uuid UUID, fs FixedString(10), ipv4 IPv4,' + + + +def test_read_formats(param_client: Client, call, test_table_engine: str): + call(param_client.command, 'DROP TABLE IF EXISTS read_format_test') + call(param_client.command, 'CREATE TABLE read_format_test (key Int32, uuid UUID, fs FixedString(10), ipv4 IPv4,' + 'ip_array Array(IPv6), tup Tuple(u1 UInt64, ip2 IPv4))' + f'Engine {test_table_engine} ORDER BY key') uuid1 = uuid.UUID('23E45688e89B-12D3-3273-426614174000') uuid2 = uuid.UUID('77AA3278-3728-12d3-5372-000377723832') row1 = (1, uuid1, '530055777k', '10.251.30.50', ['2600::', '2001:4860:4860::8844'], (7372, '10.20.30.203')) row2 = (2, uuid2, 'short str', '10.44.75.20', ['74:382::3332', '8700:5200::5782:3992'], (7320, '252.18.4.50')) - test_client.insert('read_format_test', [row1, row2]) + call(param_client.insert, 'read_format_test', [row1, row2]) - result = test_client.query('SELECT * FROM read_format_test;;;').result_set + result = call(param_client.query, 'SELECT * FROM read_format_test;;;').result_set assert result[0][1] == uuid1 assert result[1][3] == IPv4Address('10.44.75.20') assert result[0][2] == b'\x35\x33\x30\x30\x35\x35\x37\x37\x37\x6b' @@ -73,104 +75,104 @@ def test_read_formats(test_client: Client, test_table_engine: str): assert result[0][5]['ip2'] == IPv4Address('10.20.30.203') set_default_formats('uuid', 'string', 'ip*', 'string', 'FixedString', 'string') - result = test_client.query('SELECT * FROM read_format_test').result_set + result = call(param_client.query, 'SELECT * FROM read_format_test').result_set assert result[0][1] == '23e45688-e89b-12d3-3273-426614174000' assert result[1][3] == '10.44.75.20' assert result[0][2] == '530055777k' assert result[0][4][1] == '2001:4860:4860::8844' clear_default_format('ip*') - result = test_client.query('SELECT * FROM read_format_test').result_set + result = call(param_client.query, 'SELECT * FROM read_format_test').result_set assert result[0][1] == '23e45688-e89b-12d3-3273-426614174000' assert result[1][3] == IPv4Address('10.44.75.20') assert result[0][4][1] == IPv6Address('2001:4860:4860::8844') assert result[0][2] == '530055777k' # Test query formats - result = test_client.query('SELECT * FROM read_format_test', query_formats={'IP*': 'string', + result = call(param_client.query, 'SELECT * FROM read_format_test', query_formats={'IP*': 'string', 'tup': 'json'}).result_set assert result[1][3] == '10.44.75.20' assert result[0][5] == b'{"u1":7372,"ip2":"10.20.30.203"}' # Ensure that the query format clears - result = test_client.query('SELECT * FROM read_format_test').result_set + result = call(param_client.query, 'SELECT * FROM read_format_test').result_set assert result[1][3] == IPv4Address('10.44.75.20') assert result[0][5]['ip2'] == IPv4Address('10.20.30.203') # Test column formats - result = test_client.query('SELECT * FROM read_format_test', column_formats={'ipv4': 'string', + result = call(param_client.query, 'SELECT * FROM read_format_test', column_formats={'ipv4': 'string', 'tup': 'tuple'}).result_set assert result[1][3] == '10.44.75.20' assert result[0][5][1] == IPv4Address('10.20.30.203') # Ensure that the column format clears - result = test_client.query('SELECT * FROM read_format_test').result_set + result = call(param_client.query, 'SELECT * FROM read_format_test').result_set assert result[1][3] == IPv4Address('10.44.75.20') assert result[0][5]['ip2'] == IPv4Address('10.20.30.203') # Test sub column formats set_read_format('tuple', 'tuple') - result = test_client.query('SELECT * FROM read_format_test', column_formats={'tup': {'ip*': 'string'}}).result_set + result = call(param_client.query, 'SELECT * FROM read_format_test', column_formats={'tup': {'ip*': 'string'}}).result_set assert result[0][5][1] == '10.20.30.203' set_read_format('tuple', 'native') - result = test_client.query('SELECT * FROM read_format_test', column_formats={'tup': {'ip*': 'string'}}).result_set + result = call(param_client.query, 'SELECT * FROM read_format_test', column_formats={'tup': {'ip*': 'string'}}).result_set assert result[0][5]['ip2'] == '10.20.30.203' -def test_tuple_inserts(test_client: Client, table_context: Callable): +def test_tuple_inserts(param_client: Client, call, table_context: Callable): with table_context('insert_tuple_test', ['key Int32', 'named Tuple(fl Float64, `ns space` Nullable(String))', 'unnamed Tuple(Float64, Nullable(String))']): data = [[1, (3.55, 'str1'), (555, None)], [2, (-43.2, None), (0, 'str2')]] - test_client.insert('insert_tuple_test', data, settings={'insert_deduplication_token': 5772}) + call(param_client.insert, 'insert_tuple_test', data, settings={'insert_deduplication_token': 5772}) data = [[1, {'fl': 3.55, 'ns space': 'str1'}, (555, None)], [2, {'fl': -43.2}, (0, 'str2')]] - test_client.insert('insert_tuple_test', data, settings={'insert_deduplication_token': 5773}) - query_result = test_client.query('SELECT * FROM insert_tuple_test ORDER BY key').result_rows + call(param_client.insert, 'insert_tuple_test', data, settings={'insert_deduplication_token': 5773}) + query_result = call(param_client.query, 'SELECT * FROM insert_tuple_test ORDER BY key').result_rows assert len(query_result) == 4 assert query_result[0] == query_result[1] assert query_result[2] == query_result[3] -def test_agg_function(test_client: Client, table_context: Callable): +def test_agg_function(param_client: Client, call, table_context: Callable): with table_context('agg_func_test', ['key Int32', 'str SimpleAggregateFunction(any, String)', 'lc_str SimpleAggregateFunction(any, LowCardinality(String))'], engine='AggregatingMergeTree'): - test_client.insert('agg_func_test', [(1, 'str', 'lc_str')]) - row = test_client.query('SELECT str, lc_str FROM agg_func_test').first_row + call(param_client.insert, 'agg_func_test', [(1, 'str', 'lc_str')]) + row = call(param_client.query, 'SELECT str, lc_str FROM agg_func_test').first_row assert row[0] == 'str' assert row[1] == 'lc_str' -def test_decimal_rounding(test_client: Client, table_context: Callable): +def test_decimal_rounding(param_client: Client, call, table_context: Callable): test_vals = [732.4, 75.57, 75.49, 40.16] with table_context('test_decimal', ['key Int32, value Decimal(10, 2)']): - test_client.insert('test_decimal', [[ix, x] for ix, x in enumerate(test_vals)]) - values = test_client.query('SELECT value FROM test_decimal').result_columns[0] + call(param_client.insert, 'test_decimal', [[ix, x] for ix, x in enumerate(test_vals)]) + values = call(param_client.query, 'SELECT value FROM test_decimal').result_columns[0] with decimal.localcontext() as dec_ctx: dec_ctx.prec = 10 assert [decimal.Decimal(str(x)) for x in test_vals] == values -def test_empty_maps(test_client: Client): - result = test_client.query("select Cast(([],[]), 'Map(String, Map(String, String))')") +def test_empty_maps(param_client: Client, call): + result = call(param_client.query, "select Cast(([],[]), 'Map(String, Map(String, String))')") assert result.first_row[0] == {} -def test_fixed_str_padding(test_client: Client, table_context: Callable): +def test_fixed_str_padding(param_client: Client, call, table_context: Callable): table = 'test_fixed_str_padding' with table_context(table, 'key Int32, value FixedString(3)'): - test_client.insert(table, [[1, 'abc']]) - test_client.insert(table, [[2, 'a']]) - test_client.insert(table, [[3, '']]) - result = test_client.query(f'select * from {table} ORDER BY key') + call(param_client.insert, table, [[1, 'abc']]) + call(param_client.insert, table, [[2, 'a']]) + call(param_client.insert, table, [[3, '']]) + result = call(param_client.query, f'select * from {table} ORDER BY key') assert result.result_columns[1] == [b'abc', b'a\x00\x00', b'\x00\x00\x00'] -def test_nonstandard_column_names(test_client: Client, table_context: Callable): +def test_nonstandard_column_names(param_client: Client, call, table_context: Callable): table = 'пример_кириллица' with table_context(table, 'колонка String') as t: - test_client.insert(t.table, (('привет',),)) - result = test_client.query(f'SELECT * FROM {t.table}').result_set + call(param_client.insert, t.table, (('привет',),)) + result = call(param_client.query, f'SELECT * FROM {t.table}').result_set assert result[0][0] == 'привет' diff --git a/tests/integration_tests/test_native_fuzz.py b/tests/integration_tests/test_native_fuzz.py index dcce0aa2..69227175 100644 --- a/tests/integration_tests/test_native_fuzz.py +++ b/tests/integration_tests/test_native_fuzz.py @@ -1,3 +1,4 @@ +import asyncio import os import random @@ -13,28 +14,39 @@ # pylint: disable=duplicate-code -def test_query_fuzz(test_client: Client, test_table_engine: str): - if not test_client.min_version('21'): - pytest.skip(f'flatten_nested setting not supported in this server version {test_client.server_version}') +def test_query_fuzz(param_client: Client, call, test_table_engine: str, client_mode: str): + if not param_client.min_version('21'): + pytest.skip(f'flatten_nested setting not supported in this server version {param_client.server_version}') test_runs = int(os.environ.get('CLICKHOUSE_CONNECT_TEST_FUZZ', '250')) - test_client.tz_source = "server" + param_client.tz_source = "server" try: for _ in range(test_runs): - test_client.command('DROP TABLE IF EXISTS fuzz_test') + call(param_client.command, 'DROP TABLE IF EXISTS fuzz_test') data_rows = random.randint(0, MAX_DATA_ROWS) col_names, col_types = random_columns(TEST_COLUMNS) - data = random_data(col_types, data_rows, test_client.server_tz) + data = random_data(col_types, data_rows, param_client.server_tz) col_names = ('row_id',) + col_names col_types = (get_from_name('UInt32'),) + col_types col_defs = [TableColumnDef(name, ch_type) for name, ch_type in zip(col_names, col_types)] create_stmt = create_table('fuzz_test', col_defs, test_table_engine, {'order by': 'row_id'}) - test_client.command(create_stmt, settings={'flatten_nested': 0}) - test_client.insert('fuzz_test', data, col_names) + call(param_client.command, create_stmt, settings={'flatten_nested': 0}) + call(param_client.insert, 'fuzz_test', data, col_names) + + if client_mode == 'async': + async def get_results(): + result = await param_client.query('SELECT * FROM fuzz_test') + loop = asyncio.get_running_loop() + rows = await loop.run_in_executor(None, lambda: list(result.result_set)) + return rows, result.column_names + result_rows, result_cols = call(get_results) + else: + data_result = call(param_client.query, 'SELECT * FROM fuzz_test') + result_rows = data_result.result_set + result_cols = data_result.column_names - data_result = test_client.query('SELECT * FROM fuzz_test') if data_rows: - assert data_result.column_names == col_names - assert data_result.result_set == data + assert result_cols == col_names + assert result_rows == data finally: - test_client.tz_source = "auto" + param_client.tz_source = "auto" diff --git a/tests/integration_tests/test_network.py b/tests/integration_tests/test_network.py index c19b81e5..01842d02 100644 --- a/tests/integration_tests/test_network.py +++ b/tests/integration_tests/test_network.py @@ -1,8 +1,9 @@ from ipaddress import IPv4Address, IPv6Address +from typing import Callable + import pytest from clickhouse_connect.driver import Client -from tests.integration_tests.conftest import TestConfig # A collection of diverse IPv6 addresses for testing IPV6_TEST_CASES = [ @@ -15,37 +16,13 @@ ] -# pylint: disable=attribute-defined-outside-init -class TestIPv6: - """Integration tests for ClickHouse IPv6 data type handling.""" - - client: Client - table_name: str = "ipv6_integration_test" - - @pytest.fixture(autouse=True) - def setup_teardown(self, test_config: TestConfig, test_client: Client): - """Create the test table before each test and drop it after.""" - self.config = test_config - self.client = test_client - self.client.command(f"DROP TABLE IF EXISTS {self.table_name}") - self.client.command( - f""" - CREATE TABLE {self.table_name} ( - id UInt32, - ip_addr IPv6, - ip_addr_nullable Nullable(IPv6) - ) ENGINE = MergeTree ORDER BY id - """ - ) - yield - self.client.command(f"DROP TABLE IF EXISTS {self.table_name}") - - def test_ipv6_round_trip(self): - """Tests that various IPv6 addresses can be inserted as objects and read back correctly.""" +def test_ipv6_round_trip(param_client: Client, call, table_context: Callable): + """Test that various IPv6 addresses can be inserted as objects and read back correctly.""" + with table_context("ipv6_round_trip_test", ["id UInt32", "ip_addr IPv6", "ip_addr_nullable Nullable(IPv6)"], order_by="id"): data = [[i, ip, ip] for i, ip in enumerate(IPV6_TEST_CASES)] - self.client.insert(self.table_name, data) + call(param_client.insert, "ipv6_round_trip_test", data) - result = self.client.query(f"SELECT * FROM {self.table_name} ORDER BY id") + result = call(param_client.query, "SELECT * FROM ipv6_round_trip_test ORDER BY id") assert result.row_count == len(IPV6_TEST_CASES) for i, ip in enumerate(IPV6_TEST_CASES): @@ -53,9 +30,9 @@ def test_ipv6_round_trip(self): assert result.result_rows[i][2] == ip assert isinstance(result.result_rows[i][1], IPv6Address) - def test_ipv4_mapping_and_promotion(self): - """Tests that plain IPv4 strings/objects are correctly promoted to IPv4-mapped - IPv6 addresses on insertion and read back correctly.""" +def test_ipv4_mapping_and_promotion(param_client: Client, call, table_context: Callable): + """Test that plain IPv4 strings/objects are correctly promoted to IPv4-mapped IPv6 addresses.""" + with table_context("ipv4_promotion_test", ["id UInt32", "ip_addr IPv6", "ip_addr_nullable Nullable(IPv6)"], order_by="id"): test_ips = [ "198.51.100.1", IPv4Address("203.0.113.255"), @@ -68,45 +45,44 @@ def test_ipv4_mapping_and_promotion(self): ] data = [[i, ip, None] for i, ip in enumerate(test_ips)] - self.client.insert(self.table_name, data) + call(param_client.insert, "ipv4_promotion_test", data) - result = self.client.query( - f"SELECT id, ip_addr FROM {self.table_name} ORDER BY id" - ) + result = call(param_client.query, "SELECT id, ip_addr FROM ipv4_promotion_test ORDER BY id") assert result.row_count == len(test_ips) for i, ip in enumerate(expected_ips): assert isinstance(result.result_rows[i][1], IPv6Address) assert result.result_rows[i][1] == ip - def test_null_handling(self): - """Tests inserting and retrieving NULL values in an IPv6 column.""" +def test_ipv6_null_handling(param_client: Client, call, table_context: Callable): + """Test inserting and retrieving NULL values in an IPv6 column.""" + with table_context("ipv6_null_test", ["id UInt32", "ip_addr IPv6", "ip_addr_nullable Nullable(IPv6)"], order_by="id"): data = [[1, "::1", None], [2, "2001:db8::", "2001:db8::"]] - self.client.insert(self.table_name, data) + call(param_client.insert, "ipv6_null_test", data) - result = self.client.query( - f"SELECT id, ip_addr_nullable FROM {self.table_name} ORDER BY id" - ) + result = call(param_client.query, "SELECT id, ip_addr_nullable FROM ipv6_null_test ORDER BY id") assert result.row_count == 2 assert result.result_rows[0][1] is None assert result.result_rows[1][1] == IPv6Address("2001:db8::") - def test_read_as_string(self): - """Tests reading IPv6 values as strings using the toString() function.""" +def test_ipv6_read_as_string(param_client: Client, call, table_context: Callable): + """Test reading IPv6 values as strings using the toString() function.""" + with table_context("ipv6_string_test", ["id UInt32", "ip_addr IPv6", "ip_addr_nullable Nullable(IPv6)"], order_by="id"): ip = IPV6_TEST_CASES[0] - self.client.insert(self.table_name, [[1, ip, None]]) + call(param_client.insert, "ipv6_string_test", [[1, ip, None]]) - result = self.client.query(f"SELECT toString(ip_addr) FROM {self.table_name}") + result = call(param_client.query, "SELECT toString(ip_addr) FROM ipv6_string_test") assert result.row_count == 1 read_val = result.result_rows[0][0] assert isinstance(read_val, str) assert read_val == str(ip) - def test_insert_invalid_ipv6_fails(self): - """Tests that the client correctly rejects an invalid IPv6 string.""" +def test_ipv6_insert_invalid_fails(param_client: Client, call, table_context: Callable): + """Test that the client correctly rejects an invalid IPv6 string.""" + with table_context("ipv6_invalid_test", ["id UInt32", "ip_addr IPv6", "ip_addr_nullable Nullable(IPv6)"], order_by="id"): with pytest.raises(ValueError) as excinfo: - self.client.insert(self.table_name, [[1, "not a valid ip address", None]]) + call(param_client.insert, "ipv6_invalid_test", [[1, "not a valid ip address", None]]) assert "Failed to parse 'not a valid ip address'" in str(excinfo.value) diff --git a/tests/integration_tests/test_numeric.py b/tests/integration_tests/test_numeric.py index 677549e9..c3160ead 100644 --- a/tests/integration_tests/test_numeric.py +++ b/tests/integration_tests/test_numeric.py @@ -1,96 +1,67 @@ +from typing import Callable import pytest -from clickhouse_connect.driver import Client -from tests.integration_tests.conftest import TestConfig - - -# pylint: disable=duplicate-code -# pylint: disable=attribute-defined-outside-init -class TestBFloat16: - """Integration tests for ClickHouse BFloat16 data type handling.""" - - client: Client - table_name: str = "bf16_integration_test" - - # pylint: disable=no-self-use - @pytest.fixture(scope="class", autouse=True) - def check_version(self, test_client: Client): - """Skips the entire class if the server version is too old.""" - if not test_client.min_version("24.11"): - pytest.skip( - f"BFloat16 type not supported in ClickHouse version {test_client.server_version}" - ) - - @pytest.fixture(autouse=True) - def setup_teardown(self, test_config: TestConfig, test_client: Client): - """Create the test table before each test and drop it after.""" - self.config = test_config - self.client = test_client - self.client.command(f"DROP TABLE IF EXISTS {self.table_name}") - self.client.command( - f""" - CREATE TABLE {self.table_name} ( - id UInt32, - bfloat16 BFloat16, - bfloat16_nullable Nullable(BFloat16) - ) ENGINE = MergeTree ORDER BY id - """ - ) - yield - self.client.command(f"DROP TABLE IF EXISTS {self.table_name}") - - def test_bf16_round_trip(self): - """Basic round trip test with precision loss.""" + +def test_bfloat16_round_trip(param_client, call, table_context: Callable): + """Test BFloat16 data type with precision loss on round trip.""" + if not param_client.min_version("24.11"): + pytest.skip(f"BFloat16 type not supported in ClickHouse version {param_client.server_version}") + + with table_context('bf16_test', ['id UInt32', 'bfloat16 BFloat16', 'bfloat16_nullable Nullable(BFloat16)'], + order_by='id'): input_data = [[0, 3.141592, -2.71828], [1, 3.141592, -2.71828]] expected = [[0, 3.140625, -2.703125], [1, 3.140625, -2.703125]] - self.client.insert(self.table_name, input_data) + call(param_client.insert, 'bf16_test', input_data) - result = self.client.query(f"SELECT * FROM {self.table_name} ORDER BY id") + result = call(param_client.query, "SELECT * FROM bf16_test ORDER BY id") assert result.row_count == len(input_data) for result_row, expected_row in zip(result.result_rows, expected): assert list(result_row) == expected_row assert isinstance(result_row[1], float) - def test_bf16_nullable_round_trip(self): - """Basic round nullable trip test with precision loss.""" +def test_bfloat16_nullable_round_trip(param_client, call, table_context: Callable): + """Test BFloat16 nullable column with precision loss.""" + if not param_client.min_version("24.11"): + pytest.skip(f"BFloat16 type not supported in ClickHouse version {param_client.server_version}") + + with table_context('bf16_nullable_test', ['id UInt32', 'bfloat16 BFloat16', 'bfloat16_nullable Nullable(BFloat16)'], + order_by='id'): input_data = [[0, 3.141592, None], [1, 3.141592, -2.71828]] expected = [[0, 3.140625, None], [1, 3.140625, -2.703125]] - self.client.insert(self.table_name, input_data) + call(param_client.insert, 'bf16_nullable_test', input_data) - result = self.client.query(f"SELECT * FROM {self.table_name} ORDER BY id") + result = call(param_client.query, "SELECT * FROM bf16_nullable_test ORDER BY id") assert result.row_count == len(input_data) for result_row, expected_row in zip(result.result_rows, expected): assert list(result_row) == expected_row assert isinstance(result_row[1], float) - def test_bf16_empty_and_all_null_inserts(self): - """Tests inserting no rows, and inserting rows with all-null columns.""" - self.client.insert(self.table_name, []) - result = self.client.query(f"SELECT count() FROM {self.table_name}") +def test_bfloat16_empty_and_all_null_inserts(param_client, call, table_context: Callable): + """Test BFloat16 with empty inserts and all-null columns.""" + if not param_client.min_version("24.11"): + pytest.skip(f"BFloat16 type not supported in ClickHouse version {param_client.server_version}") + + with table_context('bf16_empty_test', ['id UInt32', 'bfloat16 BFloat16', 'bfloat16_nullable Nullable(BFloat16)'], + order_by='id'): + # Test empty insert + call(param_client.insert, 'bf16_empty_test', []) + result = call(param_client.query, "SELECT count() FROM bf16_empty_test") assert result.result_rows[0][0] == 0 input_data = [[0, 3.141592, None], [1, -2.71828, None]] expected = [[0, 3.140625, None], [1, -2.703125, None]] - self.client.insert(self.table_name, input_data) + call(param_client.insert, 'bf16_empty_test', input_data) - result = self.client.query(f"SELECT * FROM {self.table_name} ORDER BY id") + result = call(param_client.query, "SELECT * FROM bf16_empty_test ORDER BY id") assert result.row_count == len(input_data) for result_row, expected_row in zip(result.result_rows, expected): assert list(result_row) == expected_row -class TestSpecialIntervalTypes: - """Integration tests for ClickHouse special interval type handling.""" - - client: Client - - @pytest.fixture(autouse=True) - def setup_teardown(self, test_client: Client): - self.client = test_client - - def test_interval_selects_work(self): - result = self.client.query("SELECT INTERVAL 30 DAY") - assert result.result_rows[0][0] == 30 +def test_interval_selects(param_client, call): + """Test that interval type selects work correctly.""" + result = call(param_client.query, "SELECT INTERVAL 30 DAY") + assert result.result_rows[0][0] == 30 diff --git a/tests/integration_tests/test_numpy.py b/tests/integration_tests/test_numpy.py index 874e1744..73dce763 100644 --- a/tests/integration_tests/test_numpy.py +++ b/tests/integration_tests/test_numpy.py @@ -18,30 +18,30 @@ pytestmark = pytest.mark.skipif(np is None, reason='Numpy package not installed') -def test_numpy_dates(test_client: Client, table_context: Callable): +def test_numpy_dates(param_client: Client, call, table_context: Callable): np_array = np.array(dt_ds, dtype='datetime64[s]').reshape(-1, 1) source_arr = np_array.copy() with table_context('test_numpy_dates', dt_ds_columns, dt_ds_types): - test_client.insert('test_numpy_dates', np_array) - new_np_array = test_client.query_np('SELECT * FROM test_numpy_dates') + call(param_client.insert, 'test_numpy_dates', np_array) + new_np_array = call(param_client.query_np, 'SELECT * FROM test_numpy_dates') assert np.array_equal(np_array, new_np_array) assert np.array_equal(source_arr, np_array) -def test_invalid_date(test_client): +def test_invalid_date(param_client: Client, call): try: sql = "SELECT cast(now64(), 'DateTime64(1)')" - if not test_client.min_version('20'): + if not param_client.min_version('20'): sql = "SELECT cast(now(), 'DateTime')" - test_client.query_df(sql) + call(param_client.query_df, sql) except ProgrammingError as ex: assert 'milliseconds' in str(ex) -def test_numpy_record_type(test_client: Client, table_context: Callable): +def test_numpy_record_type(param_client: Client, call, table_context: Callable): dt_type = 'datetime64[ns]' ds_types = basic_ds_types - if not test_client.min_version('20'): + if not param_client.min_version('20'): dt_type = 'datetime64[s]' ds_types = basic_ds_types_ver19 @@ -49,18 +49,18 @@ def test_numpy_record_type(test_client: Client, table_context: Callable): source_arr = np_array.copy() np_array.dtype.names = basic_ds_columns with table_context('test_numpy_basic', basic_ds_columns, ds_types): - test_client.insert('test_numpy_basic', np_array) - new_np_array = test_client.query_np('SELECT * FROM test_numpy_basic', max_str_len=20) + call(param_client.insert, 'test_numpy_basic', np_array) + new_np_array = call(param_client.query_np, 'SELECT * FROM test_numpy_basic', max_str_len=20) assert np.array_equal(np_array, new_np_array) - empty_np_array = test_client.query_np("SELECT * FROM test_numpy_basic WHERE key = 'NOT A KEY' ") + empty_np_array = call(param_client.query_np, "SELECT * FROM test_numpy_basic WHERE key = 'NOT A KEY' ") assert len(empty_np_array) == 0 assert np.array_equal(source_arr, np_array) -def test_numpy_object_type(test_client: Client, table_context: Callable): +def test_numpy_object_type(param_client: Client, call, table_context: Callable): dt_type = 'datetime64[ns]' ds_types = basic_ds_types - if not test_client.min_version('20'): + if not param_client.min_version('20'): dt_type = 'datetime64[s]' ds_types = basic_ds_types_ver19 @@ -68,88 +68,114 @@ def test_numpy_object_type(test_client: Client, table_context: Callable): np_array.dtype.names = basic_ds_columns source_arr = np_array.copy() with table_context('test_numpy_basic', basic_ds_columns, ds_types): - test_client.insert('test_numpy_basic', np_array) - new_np_array = test_client.query_np('SELECT * FROM test_numpy_basic') + call(param_client.insert, 'test_numpy_basic', np_array) + new_np_array = call(param_client.query_np, 'SELECT * FROM test_numpy_basic') assert np.array_equal(np_array, new_np_array) assert np.array_equal(source_arr, np_array) -def test_numpy_nulls(test_client: Client, table_context: Callable): +def test_numpy_nulls(param_client: Client, call, table_context: Callable): np_types = [(col_name, 'O') for col_name in null_ds_columns] np_array = np.rec.fromrecords(null_ds, dtype=np_types) source_arr = np_array.copy() with table_context('test_numpy_nulls', null_ds_columns, null_ds_types): - test_client.insert('test_numpy_nulls', np_array) - new_np_array = test_client.query_np('SELECT * FROM test_numpy_nulls', use_none=True) + call(param_client.insert, 'test_numpy_nulls', np_array) + new_np_array = call(param_client.query_np, 'SELECT * FROM test_numpy_nulls', use_none=True) assert list_equal(np_array.tolist(), new_np_array.tolist()) assert list_equal(source_arr.tolist(), np_array.tolist()) -def test_numpy_matrix(test_client: Client, table_context: Callable): +def test_numpy_matrix(param_client: Client, call, table_context: Callable): source = [25000, -37283, 4000, 25770, 40032, 33002, 73086, -403882, 57723, 77382, 1213477, 2, 0, 5777732, 99827616] source_array = np.array(source, dtype='int32') matrix = source_array.reshape((5, 3)) matrix_copy = matrix.copy() with table_context('test_numpy_matrix', ['col1 Int32', 'col2 Int32', 'col3 Int32']): - test_client.insert('test_numpy_matrix', matrix) - py_result = test_client.query('SELECT * FROM test_numpy_matrix').result_set + call(param_client.insert, 'test_numpy_matrix', matrix) + py_result = call(param_client.query, 'SELECT * FROM test_numpy_matrix').result_set assert list(py_result[1]) == [25000, -37283, 4000] - numpy_result = test_client.query_np('SELECT * FROM test_numpy_matrix') + numpy_result = call(param_client.query_np, 'SELECT * FROM test_numpy_matrix') assert list(numpy_result[1]) == list(py_result[1]) - test_client.command('TRUNCATE TABLE test_numpy_matrix') - numpy_result = test_client.query_np('SELECT * FROM test_numpy_matrix') + call(param_client.command, 'TRUNCATE TABLE test_numpy_matrix') + numpy_result = call(param_client.query_np, 'SELECT * FROM test_numpy_matrix') assert np.size(numpy_result) == 0 assert np.array_equal(matrix, matrix_copy) -def test_numpy_bigint_matrix(test_client: Client, table_context: Callable): +def test_numpy_bigint_matrix(param_client: Client, call, table_context: Callable): source = [25000, -37283, 4000, 25770, 40032, 33002, 73086, -403882, 57723, 77382, 1213477, 2, 0, 5777732, 99827616] source_array = np.array(source, dtype='int64') matrix = source_array.reshape((5, 3)) matrix_copy = matrix.copy() columns = ['col1 UInt256', 'col2 Int64', 'col3 Int128'] - if not test_client.min_version('21'): + if not param_client.min_version('21'): columns = ['col1 UInt64', 'col2 Int64', 'col3 Int64'] with table_context('test_numpy_bigint_matrix', columns): - test_client.insert('test_numpy_bigint_matrix', matrix) - py_result = test_client.query('SELECT * FROM test_numpy_bigint_matrix').result_set + call(param_client.insert, 'test_numpy_bigint_matrix', matrix) + py_result = call(param_client.query, 'SELECT * FROM test_numpy_bigint_matrix').result_set assert list(py_result[1]) == [25000, -37283, 4000] - numpy_result = test_client.query_np('SELECT * FROM test_numpy_bigint_matrix') + numpy_result = call(param_client.query_np, 'SELECT * FROM test_numpy_bigint_matrix') assert list(numpy_result[1]) == list(py_result[1]) assert np.array_equal(matrix, matrix_copy) -def test_numpy_bigint_object(test_client: Client, table_context: Callable): +def test_numpy_bigint_object(param_client: Client, call, table_context: Callable): source = [('key1', 347288, datetime.datetime(1999, 10, 15, 12, 3, 44)), ('key2', '348147832478', datetime.datetime.now())] np_array = np.array(source, dtype='O,uint64,datetime64[s]') source_arr = np_array.copy() columns = ['key String', 'big_value UInt256', 'dt DateTime'] - if not test_client.min_version('21'): + if not param_client.min_version('21'): columns = ['key String', 'big_value UInt64', 'dt DateTime'] with table_context('test_numpy_bigint_object', columns): - test_client.insert('test_numpy_bigint_object', np_array) - py_result = test_client.query('SELECT * FROM test_numpy_bigint_object').result_set + call(param_client.insert, 'test_numpy_bigint_object', np_array) + py_result = call(param_client.query, 'SELECT * FROM test_numpy_bigint_object').result_set assert list(py_result[0]) == list(source[0]) - numpy_result = test_client.query_np('SELECT * FROM test_numpy_bigint_object') + numpy_result = call(param_client.query_np, 'SELECT * FROM test_numpy_bigint_object') assert list(py_result[1]) == list(numpy_result[1]) assert np.array_equal(source_arr, np_array) -def test_numpy_streams(test_client: Client): - if not test_client.min_version('22'): - pytest.skip(f'generateRandom is not supported in this server version {test_client.server_version}') +def test_numpy_streams(param_client: Client, call, consume_stream): + if not param_client.min_version('22'): + pytest.skip(f'generateRandom is not supported in this server version {param_client.server_version}') runs = os.environ.get('CLICKHOUSE_CONNECT_TEST_FUZZ', '250') for _ in range(int(runs) // 2): query_rows = random.randint(0, 5000) + 20000 stream_count = 0 row_count = 0 query = random_query(query_rows) - stream = test_client.query_np_stream(query, settings={'max_block_size': 5000}) - with stream: - for np_array in stream: + stream = call(param_client.query_np_stream, query, settings={'max_block_size': 5000}) + + def process(np_array): + nonlocal stream_count, row_count + stream_count += 1 + row_count += np_array.shape[0] + + consume_stream(stream, process) + + assert row_count == query_rows + assert stream_count > 2 + + +@pytest.mark.asyncio +async def test_numpy_streams_async(test_native_async_client): + """Async-only numpy streaming test.""" + if not test_native_async_client.min_version("22"): + pytest.skip(f'generateRandom is not supported in this server version {test_native_async_client.server_version}') + + runs = os.environ.get("CLICKHOUSE_CONNECT_TEST_FUZZ", "250") + for _ in range(int(runs) // 2): + query_rows = random.randint(0, 5000) + 20000 + stream_count = 0 + row_count = 0 + query = random_query(query_rows) + + stream = await test_native_async_client.query_np_stream(query, settings={"max_block_size": 5000}) + async with stream: + async for np_array in stream: stream_count += 1 row_count += np_array.shape[0] assert row_count == query_rows diff --git a/tests/integration_tests/test_pandas.py b/tests/integration_tests/test_pandas.py index 33310133..a7e38132 100644 --- a/tests/integration_tests/test_pandas.py +++ b/tests/integration_tests/test_pandas.py @@ -16,23 +16,23 @@ pytestmark = pytest.mark.skipif(pd is None, reason='Pandas package not installed') -def test_pandas_basic(test_client: Client, test_table_engine: str): - df = test_client.query_df('SELECT * FROM system.tables') +def test_pandas_basic(param_client: Client, call, test_table_engine: str): + df = call(param_client.query_df, 'SELECT * FROM system.tables') source_df = df.copy() - test_client.command('DROP TABLE IF EXISTS test_system_insert_pd') - test_client.command(f'CREATE TABLE test_system_insert_pd as system.tables Engine {test_table_engine}' - f' ORDER BY (database, name)') - test_client.insert_df('test_system_insert_pd', df) - new_df = test_client.query_df('SELECT * FROM test_system_insert_pd') - test_client.command('DROP TABLE IF EXISTS test_system_insert_pd') + call(param_client.command, 'DROP TABLE IF EXISTS test_system_insert_pd') + call(param_client.command, f'CREATE TABLE test_system_insert_pd as system.tables Engine {test_table_engine}' + f' ORDER BY (database, name)') + call(param_client.insert_df, 'test_system_insert_pd', df) + new_df = call(param_client.query_df, 'SELECT * FROM test_system_insert_pd') + call(param_client.command, 'DROP TABLE IF EXISTS test_system_insert_pd') assert new_df.columns.all() == df.columns.all() assert df.equals(source_df) - df = test_client.query_df("SELECT * FROM system.tables WHERE engine = 'not_a_thing'") + df = call(param_client.query_df, "SELECT * FROM system.tables WHERE engine = 'not_a_thing'") assert len(df) == 0 assert isinstance(df, pd.DataFrame) -def test_pandas_nulls(test_client: Client, table_context: Callable): +def test_pandas_nulls(param_client: Client, call, table_context: Callable): df = pd.DataFrame(null_ds, columns=['key', 'num', 'flt', 'str', 'dt', 'd']) source_df = df.copy() insert_columns = ['key', 'num', 'flt', 'str', 'dt', 'day_col'] @@ -40,13 +40,13 @@ def test_pandas_nulls(test_client: Client, table_context: Callable): 'str String', 'dt DateTime', 'day_col Date']): with pytest.raises(DataError): - test_client.insert_df('test_pandas_nulls_bad', df, column_names=insert_columns) + call(param_client.insert_df, 'test_pandas_nulls_bad', df, column_names=insert_columns) with table_context('test_pandas_nulls_good', ['key String', 'num Nullable(Int32)', 'flt Nullable(Float32)', 'str Nullable(String)', "dt Nullable(DateTime('America/Denver'))", 'day_col Nullable(Date)']): - test_client.insert_df('test_pandas_nulls_good', df, column_names=insert_columns) - result_df = test_client.query_df('SELECT * FROM test_pandas_nulls_good') + call(param_client.insert_df, 'test_pandas_nulls_good', df, column_names=insert_columns) + result_df = call(param_client.query_df, 'SELECT * FROM test_pandas_nulls_good') assert result_df.iloc[0]['num'] == 1000 assert pd.isna(result_df.iloc[2]['num']) assert result_df.iloc[1]['day_col'] == pd.Timestamp(year=1976, month=5, day=5) @@ -56,18 +56,18 @@ def test_pandas_nulls(test_client: Client, table_context: Callable): assert pd.isna(result_df.iloc[2]['num']) assert pd.isnull(result_df.iloc[3]['flt']) assert result_df['num'].dtype.name == 'Int32' - if test_client.protocol_version: + if param_client.protocol_version: assert isinstance(result_df['dt'].dtype, pd.core.dtypes.dtypes.DatetimeTZDtype) assert result_df.iloc[2]['str'] == 'value3' assert df.equals(source_df) -def test_pandas_all_null_float(test_client: Client): - df = test_client.query_df("SELECT number, cast(NULL, 'Nullable(Float64)') as flt FROM numbers(500)") +def test_pandas_all_null_float(param_client: Client, call): + df = call(param_client.query_df, "SELECT number, cast(NULL, 'Nullable(Float64)') as flt FROM numbers(500)") assert df['flt'].dtype.name == 'float64' -def test_pandas_csv(test_client: Client, table_context: Callable): +def test_pandas_csv(param_client: Client, call, table_context: Callable): csv = """ key,num,flt,str,dt,d key1,555,25.44,string1,2022-11-22 15:00:44,2001-02-14 @@ -78,32 +78,32 @@ def test_pandas_csv(test_client: Client, table_context: Callable): df = df[['num', 'flt']].astype('Float32') source_df = df.copy() with table_context('test_pandas_csv', null_ds_columns, null_ds_types): - test_client.insert_df('test_pandas_csv', df) - result_df = test_client.query_df('SELECT * FROM test_pandas_csv') + call(param_client.insert_df, 'test_pandas_csv', df) + result_df = call(param_client.query_df, 'SELECT * FROM test_pandas_csv') assert np.isclose(result_df.iloc[0]['flt'], 25.44) assert pd.isna(result_df.iloc[1]['flt']) - result_df = test_client.query('SELECT * FROM test_pandas_csv') + result_df = call(param_client.query, 'SELECT * FROM test_pandas_csv') assert pd.isna(result_df.result_set[1][2]) assert df.equals(source_df) -def test_pandas_context_inserts(test_client: Client, table_context: Callable): +def test_pandas_context_inserts(param_client: Client, call, table_context: Callable): with table_context('test_pandas_multiple', null_ds_columns, null_ds_types): df = pd.DataFrame(null_ds, columns=null_ds_columns) source_df = df.copy() - insert_context = test_client.create_insert_context('test_pandas_multiple', df.columns) + insert_context = call(param_client.create_insert_context, 'test_pandas_multiple', df.columns) insert_context.data = df - test_client.data_insert(insert_context) - assert test_client.command('SELECT count() FROM test_pandas_multiple') == 4 + call(param_client.data_insert, insert_context) + assert call(param_client.command, 'SELECT count() FROM test_pandas_multiple') == 4 next_df = pd.DataFrame( [['key4', -415, None, 'value4', datetime(2022, 7, 4, 15, 33, 4, 5233), date(1999, 12, 31)]], columns=null_ds_columns) - test_client.insert_df(df=next_df, context=insert_context) - assert test_client.command('SELECT count() FROM test_pandas_multiple') == 5 + call(param_client.insert_df, df=next_df, context=insert_context) + assert call(param_client.command, 'SELECT count() FROM test_pandas_multiple') == 5 assert df.equals(source_df) -def test_pandas_low_card(test_client: Client, table_context: Callable): +def test_pandas_low_card(param_client: Client, call, table_context: Callable): with table_context('test_pandas_low_card', ['key String', 'value LowCardinality(Nullable(String))', 'date_value LowCardinality(Nullable(DateTime))', @@ -117,8 +117,8 @@ def test_pandas_low_card(test_client: Client, table_context: Callable): ], columns=['key', 'value', 'date_value', 'int_value']) source_df = df.copy() - test_client.insert_df('test_pandas_low_card', df) - result_df = test_client.query_df('SELECT * FROM test_pandas_low_card', use_none=True) + call(param_client.insert_df, 'test_pandas_low_card', df) + result_df = call(param_client.query_df, 'SELECT * FROM test_pandas_low_card', use_none=True) assert result_df.iloc[0]['value'] == 'test_string_0' assert result_df.iloc[1]['value'] == 'test_string_1' assert result_df.iloc[0]['date_value'] == pd.Timestamp(2022, 10, 15, 4, 25) @@ -129,39 +129,39 @@ def test_pandas_low_card(test_client: Client, table_context: Callable): assert df.equals(source_df) -def test_pandas_large_types(test_client: Client, table_context: Callable): +def test_pandas_large_types(param_client: Client, call, table_context: Callable): columns = ['key String', 'value Int256', 'u_value UInt256' ] key2_value = 30000000000000000000000000000000000 - if not test_client.min_version('21'): + if not param_client.min_version('21'): columns = ['key String', 'value Int64'] key2_value = 3000000000000000000 with table_context('test_pandas_big_int', columns): df = pd.DataFrame([['key1', 2000, 50], ['key2', key2_value, 70], ['key3', -2350, 70]], columns=['key', 'value', 'u_value']) source_df = df.copy() - test_client.insert_df('test_pandas_big_int', df) - result_df = test_client.query_df('SELECT * FROM test_pandas_big_int') + call(param_client.insert_df, 'test_pandas_big_int', df) + result_df = call(param_client.query_df, 'SELECT * FROM test_pandas_big_int') assert result_df.iloc[0]['value'] == 2000 assert result_df.iloc[1]['value'] == key2_value assert df.equals(source_df) -def test_pandas_enums(test_client: Client, table_context: Callable): +def test_pandas_enums(param_client: Client, call, table_context: Callable): columns = ['key String', "value Enum8('Moscow' = 0, 'Rostov' = 1, 'Kiev' = 2)", "null_value Nullable(Enum8('red'=0,'blue'=5,'yellow'=10))"] with table_context('test_pandas_enums', columns): df = pd.DataFrame([['key1', 1, 0], ['key2', 0, None]], columns=['key', 'value', 'null_value']) source_df = df.copy() - test_client.insert_df('test_pandas_enums', df) - result_df = test_client.query_df('SELECT * FROM test_pandas_enums ORDER BY key') + call(param_client.insert_df, 'test_pandas_enums', df) + result_df = call(param_client.query_df, 'SELECT * FROM test_pandas_enums ORDER BY key') assert result_df.iloc[0]['value'] == 'Rostov' assert result_df.iloc[1]['value'] == 'Moscow' assert result_df.iloc[1]['null_value'] is None assert result_df.iloc[0]['null_value'] == 'red' assert df.equals(source_df) df = pd.DataFrame([['key3', 'Rostov', 'blue'], ['key4', 'Moscow', None]], columns=['key', 'value', 'null_value']) - test_client.insert_df('test_pandas_enums', df) - result_df = test_client.query_df('SELECT * FROM test_pandas_enums ORDER BY key') + call(param_client.insert_df, 'test_pandas_enums', df) + result_df = call(param_client.query_df, 'SELECT * FROM test_pandas_enums ORDER BY key') assert result_df.iloc[2]['key'] == 'key3' assert result_df.iloc[2]['value'] == 'Rostov' assert result_df.iloc[3]['value'] == 'Moscow' @@ -169,9 +169,9 @@ def test_pandas_enums(test_client: Client, table_context: Callable): assert result_df.iloc[3]['null_value'] is None -def test_pandas_datetime64(test_client: Client, table_context: Callable): - if not test_client.min_version('20'): - pytest.skip(f'DateTime64 not supported in this server version {test_client.server_version}') +def test_pandas_datetime64(param_client: Client, call, table_context: Callable): + if not param_client.min_version('20'): + pytest.skip(f'DateTime64 not supported in this server version {param_client.server_version}') nano_timestamp = pd.Timestamp(1992, 11, 6, 12, 50, 40, 7420, 44) milli_timestamp = pd.Timestamp(2022, 5, 3, 10, 44, 10, 55000) chicago_timestamp = milli_timestamp.tz_localize('America/Chicago') @@ -184,8 +184,8 @@ def test_pandas_datetime64(test_client: Client, table_context: Callable): ['key2', nano_timestamp, milli_timestamp, chicago_timestamp]], columns=['key', 'nanos', 'millis', 'chicago']) source_df = df.copy() - test_client.insert_df('test_pandas_dt64', df) - result_df = test_client.query_df('SELECT * FROM test_pandas_dt64') + call(param_client.insert_df, 'test_pandas_dt64', df) + result_df = call(param_client.query_df, 'SELECT * FROM test_pandas_dt64') assert result_df.iloc[0]['nanos'] == now assert result_df.iloc[1]['nanos'] == nano_timestamp assert result_df.iloc[1]['millis'] == milli_timestamp @@ -194,37 +194,40 @@ def test_pandas_datetime64(test_client: Client, table_context: Callable): test_dt = np.array(['2017-11-22 15:42:58.270000+00:00'][0]) assert df.equals(source_df) df = pd.DataFrame([['key3', pd.to_datetime(test_dt)]], columns=['key', 'nanos']) - test_client.insert_df('test_pandas_dt64', df) - result_df = test_client.query_df('SELECT * FROM test_pandas_dt64 WHERE key = %s', parameters=('key3',)) + call(param_client.insert_df, 'test_pandas_dt64', df) + result_df = call(param_client.query_df, 'SELECT * FROM test_pandas_dt64 WHERE key = %s', parameters=('key3',)) assert result_df.iloc[0]['nanos'].second == 58 -def test_pandas_streams(test_client: Client): - if not test_client.min_version('22'): - pytest.skip(f'generateRandom is not supported in this server version {test_client.server_version}') +def test_pandas_streams(param_client: Client, call, consume_stream): + if not param_client.min_version('22'): + pytest.skip(f'generateRandom is not supported in this server version {param_client.server_version}') runs = os.environ.get('CLICKHOUSE_CONNECT_TEST_FUZZ', '250') for _ in range(int(runs) // 2): query_rows = random.randint(0, 5000) + 20000 stream_count = 0 row_count = 0 query = random_query(query_rows, date32=False) - stream = test_client.query_df_stream(query, settings={'max_block_size': 5000}) - with stream: - for df in stream: - stream_count += 1 - row_count += len(df) + stream = call(param_client.query_df_stream, query, settings={'max_block_size': 5000}) + + def process(df): + nonlocal stream_count, row_count + stream_count += 1 + row_count += len(df) + + consume_stream(stream, process) assert row_count == query_rows assert stream_count > 2 -def test_pandas_date(test_client: Client, table_context:Callable): +def test_pandas_date(param_client: Client, call, table_context: Callable): with table_context('test_pandas_date', ['key UInt32', 'dt Date', 'null_dt Nullable(Date)']): df = pd.DataFrame([[1, pd.Timestamp(1992, 10, 15), pd.Timestamp(2023, 5, 4)], [2, pd.Timestamp(2088, 1, 31), pd.NaT], [3, pd.Timestamp(1971, 4, 15), pd.Timestamp(2101, 12, 31)]], columns=['key', 'dt', 'null_dt']) - test_client.insert_df('test_pandas_date', df) - result_df = test_client.query_df('SELECT * FROM test_pandas_date') + call(param_client.insert_df, 'test_pandas_date', df) + result_df = call(param_client.query_df, 'SELECT * FROM test_pandas_date') assert result_df.iloc[0]['dt'] == pd.Timestamp(1992, 10, 15) assert result_df.iloc[1]['dt'] == pd.Timestamp(2088, 1, 31) assert result_df.iloc[0]['null_dt'] == pd.Timestamp(2023, 5, 4) @@ -232,14 +235,14 @@ def test_pandas_date(test_client: Client, table_context:Callable): assert result_df.iloc[2]['null_dt'] == pd.Timestamp(2101, 12, 31) -def test_pandas_date32(test_client: Client, table_context:Callable): +def test_pandas_date32(param_client: Client, call, table_context: Callable): with table_context('test_pandas_date32', ['key UInt32', 'dt Date32', 'null_dt Nullable(Date32)']): df = pd.DataFrame([[1, pd.Timestamp(1992, 10, 15), pd.Timestamp(2023, 5, 4)], [2, pd.Timestamp(2088, 1, 31), pd.NaT], [3, pd.Timestamp(1968, 4, 15), pd.Timestamp(2101, 12, 31)]], columns=['key', 'dt', 'null_dt']) - test_client.insert_df('test_pandas_date32', df) - result_df = test_client.query_df('SELECT * FROM test_pandas_date32') + call(param_client.insert_df, 'test_pandas_date32', df) + result_df = call(param_client.query_df, 'SELECT * FROM test_pandas_date32') assert result_df.iloc[1]['dt'] == pd.Timestamp(2088, 1, 31) assert result_df.iloc[0]['dt'] == pd.Timestamp(1992, 10, 15) assert result_df.iloc[0]['null_dt'] == pd.Timestamp(2023, 5, 4) @@ -248,15 +251,15 @@ def test_pandas_date32(test_client: Client, table_context:Callable): assert result_df.iloc[2]['dt'] == pd.Timestamp(1968, 4, 15) -def test_pandas_row_df(test_client: Client, table_context:Callable): +def test_pandas_row_df(param_client: Client, call, table_context: Callable): with table_context('test_pandas_row_df', ['key UInt64', 'dt DateTime64(6)', 'fs FixedString(5)']): df = pd.DataFrame({'key': [1, 2], 'dt': [pd.Timestamp(2023, 5, 4, 10, 20), pd.Timestamp(2023, 10, 15, 14, 50, 2, 4038)], 'fs': ['seven', 'bit']}) df = df.iloc[1:] source_df = df.copy() - test_client.insert_df('test_pandas_row_df', df) - result_df = test_client.query_df('SELECT * FROM test_pandas_row_df', column_formats={'fs': 'string'}) + call(param_client.insert_df, 'test_pandas_row_df', df) + result_df = call(param_client.query_df, 'SELECT * FROM test_pandas_row_df', column_formats={'fs': 'string'}) assert str(result_df.dtypes.iloc[2]) == 'string' assert result_df.iloc[0]['key'] == 2 assert result_df.iloc[0]['dt'] == pd.Timestamp(2023, 10, 15, 14, 50, 2, 4038) @@ -265,30 +268,30 @@ def test_pandas_row_df(test_client: Client, table_context:Callable): assert source_df.equals(df) -def test_pandas_null_strings(test_client: Client, table_context:Callable): +def test_pandas_null_strings(param_client: Client, call, table_context: Callable): with table_context('test_pandas_null_strings', ['id String', 'test_col LowCardinality(String)']): row = {'id': 'id', 'test_col': None} df = pd.DataFrame([row]) assert df['test_col'].isnull().values.all() with pytest.raises(DataError): - test_client.insert_df('test_pandas_null_strings', df) + call(param_client.insert_df, 'test_pandas_null_strings', df) row2 = {'id': 'id2', 'test_col': 'val'} df = pd.DataFrame([row, row2]) with pytest.raises(DataError): - test_client.insert_df('test_pandas_null_strings', df) + call(param_client.insert_df, 'test_pandas_null_strings', df) -def test_pandas_small_blocks(test_config: TestConfig, test_client: Client): +def test_pandas_small_blocks(test_config: TestConfig, param_client: Client, call): if test_config.cloud: pytest.skip('Skipping performance test in ClickHouse Cloud') - res = test_client.query_df('SELECT number, randomString(512) FROM numbers(1000000)', + res = call(param_client.query_df, 'SELECT number, randomString(512) FROM numbers(1000000)', settings={'max_block_size': 250}) assert len(res) == 1000000 -def test_pandas_string_to_df_insert(test_client: Client, table_context: Callable): - if not test_client.min_version('25.2'): - pytest.skip(f'Nullable(JSON) type not available in this version: {test_client.server_version}') +def test_pandas_string_to_df_insert(param_client: Client, call, table_context: Callable): + if not param_client.min_version('25.2'): + pytest.skip(f'Nullable(JSON) type not available in this version: {param_client.server_version}') with table_context( "test_pandas_string_to_df_insert", [ @@ -325,8 +328,8 @@ def test_pandas_string_to_df_insert(test_client: Client, table_context: Callable ] df = pd.DataFrame(data) - test_client.insert_df("test_pandas_string_to_df_insert", df) - result_df = test_client.query_df( + call(param_client.insert_df, "test_pandas_string_to_df_insert", df) + result_df = call(param_client.query_df, "SELECT * FROM test_pandas_string_to_df_insert ORDER BY id" ) @@ -336,10 +339,10 @@ def test_pandas_string_to_df_insert(test_client: Client, table_context: Callable def test_pandas_time( - test_config: TestConfig, test_client: Client, table_context: Callable + test_config: TestConfig, param_client: Client, call, table_context: Callable ): """Round trip test for Time types""" - if not test_client.min_version("25.6"): + if not param_client.min_version("25.6"): pytest.skip("Time and types require ClickHouse 25.6+") if test_config.cloud: @@ -348,7 +351,6 @@ def test_pandas_time( ) table_name = "time_tests" - test_client.command("SET enable_time_time64_type = 1") with table_context( table_name, @@ -356,6 +358,7 @@ def test_pandas_time( "t Time", "nt Nullable(Time)", ], + settings={"enable_time_time64_type": 1}, ): test_data = { "t": [timedelta(seconds=1), timedelta(seconds=2)], @@ -363,11 +366,10 @@ def test_pandas_time( } df = pd.DataFrame(test_data) - test_client.insert(table_name, df) + call(param_client.insert, table_name, df) - df_res = test_client.query_df(f"SELECT * FROM {table_name}") - print(df_res.to_string()) - print(df_res.dtypes) + df_res = call(param_client.query_df, f"SELECT * FROM {table_name}") + print(df_res) assert df_res["t"][0] == pd.Timedelta("0 days 00:00:01") assert df_res["t"][1] == pd.Timedelta("0 days 00:00:02") assert df_res["nt"][0] == pd.Timedelta("0 days 00:01:00") @@ -375,10 +377,10 @@ def test_pandas_time( def test_pandas_time64( - test_config: TestConfig, test_client: Client, table_context: Callable + test_config: TestConfig, param_client: Client, call, table_context: Callable ): """Round trip test for Time64 types""" - if not test_client.min_version("25.6"): + if not param_client.min_version("25.6"): pytest.skip("Time64 types require ClickHouse 25.6+") if test_config.cloud: @@ -387,7 +389,6 @@ def test_pandas_time64( ) table_name = "time64_tests" - test_client.command("SET enable_time_time64_type = 1") with table_context( table_name, @@ -399,6 +400,7 @@ def test_pandas_time64( "t64_9 Time64(9)", "nt64_9 Nullable(Time64(9))", ], + settings={"enable_time_time64_type": 1}, ): test_data = { "t64_3": [1, 2], @@ -410,10 +412,10 @@ def test_pandas_time64( } df = pd.DataFrame(test_data) - test_client.insert(table_name, df) + call(param_client.insert, table_name, df) # Make sure the df insert worked correctly - int_res = test_client.query( + int_res = call(param_client.query, f"SELECT * FROM {table_name}", query_formats={"Time": "int", "Time64": "int"}, ) @@ -421,7 +423,7 @@ def test_pandas_time64( assert rows[0] == (1, 45000, 10500000, 60, 1100000000, 30000500000) assert rows[1] == (2, None, 10000000, None, 600000000000, None) - df_res = test_client.query_df(f"SELECT * FROM {table_name}") + df_res = call(param_client.query_df, f"SELECT * FROM {table_name}") expected_row_0 = [ pd.Timedelta(t) for t in [ diff --git a/tests/integration_tests/test_pandas_compat.py b/tests/integration_tests/test_pandas_compat.py index 06ac4bb7..5e287885 100644 --- a/tests/integration_tests/test_pandas_compat.py +++ b/tests/integration_tests/test_pandas_compat.py @@ -14,7 +14,7 @@ pytestmark = pytest.mark.skipif(pd is None, reason="Pandas package not installed") -def test_pandas_date_compat(test_client: Client, table_context: Callable): +def test_pandas_date_compat(param_client: Client, call, table_context: Callable): table_name = "test_date" with table_context( table_name, @@ -32,15 +32,15 @@ def test_pandas_date_compat(test_client: Client, table_context: Callable): ], columns=["key", "dt", "ndt"], ) - test_client.insert_df(table_name, df) + call(param_client.insert_df, table_name, df) set_setting(SETTING_NAME, False) - result_df = test_client.query_df(f"SELECT * FROM {table_name}") + result_df = call(param_client.query_df, f"SELECT * FROM {table_name}") for dt in list(result_df.dtypes)[1:]: assert "[ns]" in str(dt) set_setting(SETTING_NAME, True) - result_df = test_client.query_df(f"SELECT * FROM {table_name}") + result_df = call(param_client.query_df, f"SELECT * FROM {table_name}") if IS_PANDAS_2: res = "[s]" @@ -51,7 +51,7 @@ def test_pandas_date_compat(test_client: Client, table_context: Callable): assert res in str(dt) -def test_pandas_date32_compat(test_client: Client, table_context: Callable): +def test_pandas_date32_compat(param_client: Client, call, table_context: Callable): table_name = "test_date32" with table_context( table_name, @@ -69,15 +69,15 @@ def test_pandas_date32_compat(test_client: Client, table_context: Callable): ], columns=["key", "dt", "ndt"], ) - test_client.insert_df(table_name, df) + call(param_client.insert_df, table_name, df) set_setting(SETTING_NAME, False) - result_df = test_client.query_df(f"SELECT * FROM {table_name}") + result_df = call(param_client.query_df, f"SELECT * FROM {table_name}") for dt in list(result_df.dtypes)[1:]: assert "[ns]" in str(dt) set_setting(SETTING_NAME, True) - result_df = test_client.query_df(f"SELECT * FROM {table_name}") + result_df = call(param_client.query_df, f"SELECT * FROM {table_name}") if IS_PANDAS_2: res = "[s]" @@ -88,7 +88,7 @@ def test_pandas_date32_compat(test_client: Client, table_context: Callable): assert res in str(dt) -def test_pandas_datetime_compat(test_client: Client, table_context: Callable): +def test_pandas_datetime_compat(param_client: Client, call, table_context: Callable): table_name = "test_datetime" with table_context( table_name, @@ -106,15 +106,15 @@ def test_pandas_datetime_compat(test_client: Client, table_context: Callable): ], columns=["key", "dt", "ndt"], ) - test_client.insert_df(table_name, df) + call(param_client.insert_df, table_name, df) set_setting(SETTING_NAME, False) - result_df = test_client.query_df(f"SELECT * FROM {table_name}") + result_df = call(param_client.query_df, f"SELECT * FROM {table_name}") for dt in list(result_df.dtypes)[1:]: assert "[ns]" in str(dt) set_setting(SETTING_NAME, True) - result_df = test_client.query_df(f"SELECT * FROM {table_name}") + result_df = call(param_client.query_df, f"SELECT * FROM {table_name}") if IS_PANDAS_2: res = "[s]" @@ -125,7 +125,7 @@ def test_pandas_datetime_compat(test_client: Client, table_context: Callable): assert res in str(dt) -def test_pandas_datetime64_compat(test_client: Client, table_context: Callable): +def test_pandas_datetime64_compat(param_client: Client, call, table_context: Callable): table_name = "test_datetime64" with table_context( table_name, @@ -162,15 +162,15 @@ def test_pandas_datetime64_compat(test_client: Client, table_context: Callable): ], columns=["key", "dt3", "null_dt3", "dt6", "null_dt6", "dt9", "null_dt9"], ) - test_client.insert_df(table_name, df) + call(param_client.insert_df, table_name, df) set_setting(SETTING_NAME, False) - result_df = test_client.query_df(f"SELECT * FROM {table_name}") + result_df = call(param_client.query_df, f"SELECT * FROM {table_name}") for dt in list(result_df.dtypes)[1:]: assert "[ns]" in str(dt) set_setting(SETTING_NAME, True) - result_df = test_client.query_df(f"SELECT * FROM {table_name}") + result_df = call(param_client.query_df, f"SELECT * FROM {table_name}") if IS_PANDAS_2: dts = list(result_df.dtypes)[1:] @@ -183,18 +183,13 @@ def test_pandas_datetime64_compat(test_client: Client, table_context: Callable): assert res in str(dt) -def test_pandas_time_compat( - test_config: TestConfig, - test_client: Client, - table_context: Callable, -): - if not test_client.min_version("25.6"): +def test_pandas_time_compat(test_config: TestConfig, param_client: Client, call, table_context: Callable): + if not param_client.min_version("25.6"): pytest.skip("Time types require ClickHouse 25.6+") if test_config.cloud: pytest.skip("Time types require settings change, but settings are locked in cloud, skipping tests.") - test_client.command("SET enable_time_time64_type = 1") table_name = "test_time" with table_context( table_name, @@ -202,7 +197,7 @@ def test_pandas_time_compat( "key UInt8", "t Time", "null_t Nullable(Time)", - ], + ], settings={"enable_time_time64_type": 1} ): data = ( [1, datetime.timedelta(hours=5), 500], @@ -210,15 +205,15 @@ def test_pandas_time_compat( [3, -datetime.timedelta(minutes=45), 600], ) - test_client.insert(table_name, data) + call(param_client.insert, table_name, data) set_setting(SETTING_NAME, False) - result_df = test_client.query_df(f"SELECT * FROM {table_name}") + result_df = call(param_client.query_df, f"SELECT * FROM {table_name}") for dt in list(result_df.dtypes)[1:]: assert "[ns]" in str(dt) set_setting(SETTING_NAME, True) - result_df = test_client.query_df(f"SELECT * FROM {table_name}") + result_df = call(param_client.query_df, f"SELECT * FROM {table_name}") if IS_PANDAS_2: res = "[s]" @@ -229,18 +224,13 @@ def test_pandas_time_compat( assert res in str(dt) -def test_pandas_time64_compat( - test_config: TestConfig, - test_client: Client, - table_context: Callable, -): - if not test_client.min_version("25.6"): +def test_pandas_time64_compat(test_config: TestConfig, param_client: Client, call, table_context: Callable): + if not param_client.min_version("25.6"): pytest.skip("Time64 types require ClickHouse 25.6+") if test_config.cloud: pytest.skip("Time types require settings change, but settings are locked in cloud, skipping tests.") - test_client.command("SET enable_time_time64_type = 1") table_name = "test_time64" with table_context( table_name, @@ -252,22 +242,22 @@ def test_pandas_time64_compat( "null_t6 Nullable(Time64(6))", "t9 Time64(9)", "null_t9 Nullable(Time64(9))", - ], + ], settings={"enable_time_time64_type": 1} ): data = ( [1, 1, 2, 3, 4, 5, 6], [2, 10, None, 30, None, 50, None], [3, 100, 200, 300, 400, 500, 600], ) - test_client.insert(table_name, data) + call(param_client.insert, table_name, data) set_setting(SETTING_NAME, False) - result_df = test_client.query_df(f"SELECT * FROM {table_name}") + result_df = call(param_client.query_df, f"SELECT * FROM {table_name}") for dt in list(result_df.dtypes)[1:]: assert "[ns]" in str(dt) set_setting(SETTING_NAME, True) - result_df = test_client.query_df(f"SELECT * FROM {table_name}") + result_df = call(param_client.query_df, f"SELECT * FROM {table_name}") if IS_PANDAS_2: dts = list(result_df.dtypes)[1:] @@ -280,7 +270,7 @@ def test_pandas_time64_compat( assert res in str(dt) -def test_pandas_query_df_arrow(test_client: Client, table_context: Callable): +def test_pandas_query_df_arrow(param_client: Client, call, table_context: Callable): if not arrow: pytest.skip("PyArrow package not available") @@ -304,17 +294,17 @@ def test_pandas_query_df_arrow(test_client: Client, table_context: Callable): [2, pd.Timestamp(2023, 5, 5), None, -45678912, 8.5555588888, "string 2", None, 1], [3, pd.Timestamp(2023, 5, 6), 30, 789123456, 3.14159, "string 3", None, 1], ) - test_client.insert(table_name, data) + call(param_client.insert, table_name, data) if IS_PANDAS_2: - result_df = test_client.query_df_arrow(f"SELECT * FROM {table_name}") + result_df = call(param_client.query_df_arrow, f"SELECT * FROM {table_name}") for dt in list(result_df.dtypes): assert isinstance(dt, pd.ArrowDtype) else: with pytest.raises(ProgrammingError): - result_df = test_client.query_df_arrow(f"SELECT * FROM {table_name}") + result_df = call(param_client.query_df_arrow, f"SELECT * FROM {table_name}") -def test_pandas_insert_df_arrow(test_client: Client, table_context: Callable): +def test_pandas_insert_df_arrow(param_client: Client, call, table_context: Callable): if not arrow: pytest.skip("PyArrow package not available") @@ -332,12 +322,12 @@ def test_pandas_insert_df_arrow(test_client: Client, table_context: Callable): ): if IS_PANDAS_2: df = df.convert_dtypes(dtype_backend="pyarrow") - test_client.insert_df_arrow(table_name, df) - res_df = test_client.query(f"SELECT * from {table_name} ORDER BY i64") + call(param_client.insert_df_arrow, table_name, df) + res_df = call(param_client.query, f"SELECT * from {table_name} ORDER BY i64") assert res_df.result_rows == [(51, 421, "b"), (78, None, "a")] else: with pytest.raises(ProgrammingError, match="pandas 2.x"): - test_client.insert_df_arrow(table_name, df) + call(param_client.insert_df_arrow, table_name, df) with table_context( table_name, @@ -351,7 +341,7 @@ def test_pandas_insert_df_arrow(test_client: Client, table_context: Callable): df = pd.DataFrame(data, columns=["i64", "ni64", "str"]) df["i64"] = df["i64"].astype(pd.ArrowDtype(arrow.int64())) with pytest.raises(ProgrammingError, match="Non-Arrow columns found"): - test_client.insert_df_arrow(table_name, df) + call(param_client.insert_df_arrow, table_name, df) else: with pytest.raises(ProgrammingError, match="pandas 2.x"): - test_client.insert_df_arrow(table_name, df) + call(param_client.insert_df_arrow, table_name, df) diff --git a/tests/integration_tests/test_params.py b/tests/integration_tests/test_params.py index 98c4fabe..bd271cd7 100644 --- a/tests/integration_tests/test_params.py +++ b/tests/integration_tests/test_params.py @@ -5,84 +5,84 @@ from clickhouse_connect.driver.binding import DT64Param -def test_params(test_client: Client, table_context: Callable): - result = test_client.query('SELECT name, database FROM system.tables WHERE database = {db:String}', +def test_params(param_client: Client, call, table_context: Callable): + result = call(param_client.query, 'SELECT name, database FROM system.tables WHERE database = {db:String}', parameters={'db': 'system'}) assert result.first_item['database'] == 'system' - if test_client.min_version('21'): - result = test_client.query('SELECT name, {col:String} FROM system.tables WHERE table ILIKE {t:String}', + if param_client.min_version('21'): + result = call(param_client.query, 'SELECT name, {col:String} FROM system.tables WHERE table ILIKE {t:String}', parameters={'t': '%rr%', 'col': 'database'}) assert 'rr' in result.first_item['name'] first_date = datetime.strptime('Jun 1 2005 1:33PM', '%b %d %Y %I:%M%p') - first_date = test_client.server_tz.localize(first_date) + first_date = param_client.server_tz.localize(first_date) second_date = datetime.strptime('Dec 25 2022 5:00AM', '%b %d %Y %I:%M%p') - second_date = test_client.server_tz.localize(second_date) + second_date = param_client.server_tz.localize(second_date) with table_context('test_bind_params', ['key UInt64', 'dt DateTime', 'value String', 't Tuple(String, String)']): - test_client.insert('test_bind_params', + call(param_client.insert, 'test_bind_params', [[1, first_date, 'v11', ('one', 'two')], [2, second_date, 'v21', ('t1', 't2')], [3, datetime.now(), 'v31', ('str1', 'str2')]]) - result = test_client.query('SELECT * FROM test_bind_params WHERE dt = {dt:DateTime}', + result = call(param_client.query, 'SELECT * FROM test_bind_params WHERE dt = {dt:DateTime}', parameters={'dt': second_date}) assert result.first_item['key'] == 2 - result = test_client.query('SELECT * FROM test_bind_params WHERE dt = %(dt)s', + result = call(param_client.query, 'SELECT * FROM test_bind_params WHERE dt = %(dt)s', parameters={'dt': first_date}) assert result.first_item['key'] == 1 - result = test_client.query("SELECT * FROM test_bind_params WHERE value != %(v)s AND value like '%%1'", + result = call(param_client.query, "SELECT * FROM test_bind_params WHERE value != %(v)s AND value like '%%1'", parameters={'v': 'v11'}) assert result.row_count == 2 - result = test_client.query('SELECT * FROM test_bind_params WHERE value IN %(tp)s', + result = call(param_client.query, 'SELECT * FROM test_bind_params WHERE value IN %(tp)s', parameters={'tp': ('v18', 'v31')}) assert result.first_item['key'] == 3 - result = test_client.query('SELECT number FROM numbers(10) WHERE {n:Nullable(String)} IS NULL', + result = call(param_client.query, 'SELECT number FROM numbers(10) WHERE {n:Nullable(String)} IS NULL', parameters={'n': None}).result_rows assert len(result) == 10 date_params = [date(2023, 6, 1), date(2023, 8, 5)] - result = test_client.query('SELECT {l:Array(Date)}', parameters={'l': date_params}).first_row + result = call(param_client.query, 'SELECT {l:Array(Date)}', parameters={'l': date_params}).first_row assert date_params == result[0] dt_params = [datetime(2023, 6, 1, 7, 40, 2), datetime(2023, 8, 17, 20, 0, 10)] - result = test_client.query('SELECT {l:Array(DateTime)}', parameters={'l': dt_params}).first_row + result = call(param_client.query, 'SELECT {l:Array(DateTime)}', parameters={'l': dt_params}).first_row assert dt_params == result[0] num_array_params = [2.5, 5.3, 7.4] - result = test_client.query('SELECT {l:Array(Float64)}', parameters={'l': num_array_params}).first_row + result = call(param_client.query, 'SELECT {l:Array(Float64)}', parameters={'l': num_array_params}).first_row assert num_array_params == result[0] - result = test_client.query('SELECT %(l)s', parameters={'l': num_array_params}).first_row + result = call(param_client.query, 'SELECT %(l)s', parameters={'l': num_array_params}).first_row assert num_array_params == result[0] tp_params = ('str1', 'str2') - result = test_client.query('SELECT %(tp)s', parameters={'tp': tp_params}).first_row + result = call(param_client.query, 'SELECT %(tp)s', parameters={'tp': tp_params}).first_row assert tp_params == result[0] num_params = {'p_0': 2, 'p_1': 100523.55} - result = test_client.query( + result = call(param_client.query, 'SELECT count() FROM system.tables WHERE total_rows > %(p_0)d and total_rows < %(p_1)f', parameters=num_params) assert result.first_row[0] > 0 -def test_datetime_64_params(test_client: Client): +def test_datetime_64_params(param_client: Client, call): dt_values = [datetime(2023, 6, 1, 7, 40, 2, 250306), datetime(2023, 8, 17, 20, 0, 10, 777722)] dt_params = {f'd{ix}': DT64Param(v) for ix, v in enumerate(dt_values)} - result = test_client.query('SELECT {d0:DateTime64(3)}, {d1:Datetime64(9)}', parameters=dt_params).first_row + result = call(param_client.query, 'SELECT {d0:DateTime64(3)}, {d1:Datetime64(9)}', parameters=dt_params).first_row assert result[0] == dt_values[0].replace(microsecond=250000) assert result[1] == dt_values[1] - result = test_client.query('SELECT {a1:Array(DateTime64(6))}', parameters={'a1': [dt_params['d0'], dt_params['d1']]}).first_row + result = call(param_client.query, 'SELECT {a1:Array(DateTime64(6))}', parameters={'a1': [dt_params['d0'], dt_params['d1']]}).first_row assert result[0] == dt_values dt_params = {f'd{ix}_64': v for ix, v in enumerate(dt_values)} - result = test_client.query('SELECT {d0:DateTime64(3)}, {d1:Datetime64(9)}', parameters=dt_params).first_row + result = call(param_client.query, 'SELECT {d0:DateTime64(3)}, {d1:Datetime64(9)}', parameters=dt_params).first_row assert result[0] == dt_values[0].replace(microsecond=250000) assert result[1] == dt_values[1] - result = test_client.query('SELECT {a1:Array(DateTime64(6))}', + result = call(param_client.query, 'SELECT {a1:Array(DateTime64(6))}', parameters={'a1_64': dt_values}).first_row assert result[0] == dt_values dt_params = [DT64Param(v) for v in dt_values] - result = test_client.query("SELECT %s as string, toDateTime64(%s,6) as dateTime", parameters = dt_params).first_row + result = call(param_client.query, "SELECT %s as string, toDateTime64(%s,6) as dateTime", parameters = dt_params).first_row assert result == ('2023-06-01 07:40:02.250306', dt_values[1]) diff --git a/tests/integration_tests/test_polars.py b/tests/integration_tests/test_polars.py index 16ebfab8..20e63da6 100644 --- a/tests/integration_tests/test_polars.py +++ b/tests/integration_tests/test_polars.py @@ -13,7 +13,7 @@ ] -def test_polars_insert(test_client: Client, table_context: Callable): +def test_polars_insert(param_client: Client, call, table_context: Callable): with table_context( "test_polars", [ @@ -35,12 +35,12 @@ def test_polars_insert(test_client: Client, table_context: Callable): "day_col": [date(2025, 7, 1), date(2025, 8, 1), date(2025, 8, 12)], } ) - test_client.insert_df_arrow(ctx.table, df) - res = test_client.query(f"SELECT key FROM {ctx.table}") + call(param_client.insert_df_arrow, ctx.table, df) + res = call(param_client.query, f"SELECT key FROM {ctx.table}") assert [i[0] for i in res.result_rows] == df["key"].to_list() -def test_bad_insert_fails(test_client: Client, table_context: Callable): +def test_bad_insert_fails(param_client: Client, call, table_context: Callable): with table_context( "test_polars", [ @@ -53,10 +53,10 @@ def test_bad_insert_fails(test_client: Client, table_context: Callable): ], ): with pytest.raises(TypeError, match="got list"): - test_client.insert_df_arrow("test_polars", [[1, 2, 3]]) + call(param_client.insert_df_arrow, "test_polars", [[1, 2, 3]]) -def test_polars_query(test_client: Client, table_context: Callable): +def test_polars_query(param_client: Client, call, table_context: Callable): with table_context( "test_polars", [ @@ -76,32 +76,44 @@ def test_polars_query(test_client: Client, table_context: Callable): [datetime(2025, 7, 1, 10, 30, 0, 0), datetime(2025, 8, 1, 10, 30, 0, 0), datetime(2025, 8, 12, 10, 30, 1, 0)], [date(2025, 7, 1), date(2025, 8, 1), date(2025, 8, 12)], ] - test_client.insert( + call(param_client.insert, ctx.table, data, column_names=["key", "num", "flt", "str", "dt", "day_col"], column_oriented=True, ) - df = test_client.query_df_arrow(f"SELECT key FROM {ctx.table}", dataframe_library="polars") + df = call(param_client.query_df_arrow, f"SELECT key FROM {ctx.table}", dataframe_library="polars") assert isinstance(df, pl.DataFrame) assert data[0] == df["key"].to_list() -def test_polars_arrow_stream(test_client: Client, table_context: Callable): +# pylint: disable=too-many-locals +def test_polars_arrow_stream(param_client: Client, call, client_mode, table_context: Callable): if not arrow: pytest.skip("PyArrow package not available") - if not test_client.min_version("21"): - pytest.skip(f"PyArrow is not supported in this server version {test_client.server_version}") - with table_context("test_arrow_insert", ["counter Int64", "letter String"]): + if not param_client.min_version("21"): + pytest.skip(f"PyArrow is not supported in this server version {param_client.server_version}") + with table_context("test_arrow_insert", ["counter Int64", "letter String"]) as ctx: counter = arrow.array(range(1000000)) alphabet = string.ascii_lowercase letter = arrow.array([alphabet[x % 26] for x in range(1000000)]) names = ["counter", "letter"] insert_table = arrow.Table.from_arrays([counter, letter], names=names) - test_client.insert_arrow("test_arrow_insert", insert_table) - stream = test_client.query_df_arrow_stream("SELECT * FROM test_arrow_insert", dataframe_library="polars") - with stream: - result_dfs = list(stream) + call(param_client.insert_arrow, ctx.table, insert_table) + stream = call(param_client.query_df_arrow_stream, f"SELECT * FROM {ctx.table}", dataframe_library="polars") + + result_dfs = [] + + if client_mode == "sync": + with stream: + for df in stream: + result_dfs.append(df) + else: + async def consume_async(): + async with stream: + async for df in stream: + result_dfs.append(df) + call(consume_async) assert len(result_dfs) > 1 total_rows = 0 diff --git a/tests/integration_tests/test_protocol_version.py b/tests/integration_tests/test_protocol_version.py index bbe56cfa..a9d63635 100644 --- a/tests/integration_tests/test_protocol_version.py +++ b/tests/integration_tests/test_protocol_version.py @@ -1,12 +1,12 @@ from clickhouse_connect.driver import Client -def test_protocol_version(test_client: Client): +def test_protocol_version(param_client: Client, call): query = "select toDateTime(1676369730, 'Asia/Shanghai') as dt FORMAT Native" - raw = test_client.raw_query(query) + raw = call(param_client.raw_query, query) assert raw.hex() == '0101026474084461746554696d65425feb63' - if test_client.min_version('23.3'): - raw = test_client.raw_query(query, settings={'client_protocol_version': 54337}) + if param_client.min_version('23.3'): + raw = call(param_client.raw_query, query, settings={'client_protocol_version': 54337}) ch_type = raw[14:39].decode() assert ch_type == "DateTime('Asia/Shanghai')" diff --git a/tests/integration_tests/test_proxy.py b/tests/integration_tests/test_proxy.py index df5746bf..cf8d2b11 100644 --- a/tests/integration_tests/test_proxy.py +++ b/tests/integration_tests/test_proxy.py @@ -4,63 +4,74 @@ import pytest from urllib3 import ProxyManager -import clickhouse_connect from tests.integration_tests.conftest import TestConfig -def test_proxies(test_config: TestConfig): +# pylint: disable=protected-access +def test_proxies(client_factory, call, test_config: TestConfig): if not test_config.proxy_address: pytest.skip('Proxy address not configured') if test_config.port in (8123, 10723): - client = clickhouse_connect.get_client(host=test_config.host, + client = client_factory(host=test_config.host, port=test_config.port, username=test_config.username, password=test_config.password, http_proxy=test_config.proxy_address) - assert '2' in client.command('SELECT version()') - client.close() + assert '2' in call(client.command, 'SELECT version()') try: os.environ['HTTP_PROXY'] = f'http://{test_config.proxy_address}' - client = clickhouse_connect.get_client(host=test_config.host, + client = client_factory(host=test_config.host, port=test_config.port, username=test_config.username, password=test_config.password) - assert isinstance(client.http, ProxyManager) - assert '2' in client.command('SELECT version()') - client.close() + if hasattr(client, 'http'): + # Sync client uses urllib3 + assert isinstance(client.http, ProxyManager) + else: + # Async client uses aiohttp + assert hasattr(client, '_proxy_url') and client._proxy_url is not None + assert '2' in call(client.command, 'SELECT version()') os.environ['no_proxy'] = f'{test_config.host}:{test_config.port}' - client = clickhouse_connect.get_client(host=test_config.host, + client = client_factory(host=test_config.host, port=test_config.port, username=test_config.username, password=test_config.password) - assert not isinstance(client.http, ProxyManager) - assert '2' in client.command('SELECT version()') - client.close() + # Check proxy is NOT configured + if hasattr(client, 'http'): + # Sync client uses urllib3 + assert not isinstance(client.http, ProxyManager) + else: + # Async client uses aiohttp + assert not hasattr(client, '_proxy_url') or client._proxy_url is None + assert '2' in call(client.command, 'SELECT version()') finally: os.environ.pop('HTTP_PROXY', None) os.environ.pop('no_proxy', None) else: cert_file = f'{Path(__file__).parent}/proxy_ca_cert.crt' - client = clickhouse_connect.get_client(host=test_config.host, + client = client_factory(host=test_config.host, port=test_config.port, username=test_config.username, password=test_config.password, ca_cert=cert_file, https_proxy=test_config.proxy_address) - assert '2' in client.command('SELECT version()') - client.close() + assert '2' in call(client.command, 'SELECT version()') try: os.environ['HTTPS_PROXY'] = f'{test_config.proxy_address}' - client = clickhouse_connect.get_client(host=test_config.host, + client = client_factory(host=test_config.host, port=test_config.port, username=test_config.username, password=test_config.password, ca_cert=cert_file) - assert isinstance(client.http, ProxyManager) - assert '2' in client.command('SELECT version()') - client.close() + if hasattr(client, 'http'): + # Sync client uses urllib3 + assert isinstance(client.http, ProxyManager) + else: + # Async client uses aiohttp + assert hasattr(client, '_proxy_url') and client._proxy_url is not None + assert '2' in call(client.command, 'SELECT version()') finally: os.environ.pop('HTTPS_PROXY', None) diff --git a/tests/integration_tests/test_pyarrow_ddl_integration.py b/tests/integration_tests/test_pyarrow_ddl_integration.py index cd3106e3..14d9f0cd 100644 --- a/tests/integration_tests/test_pyarrow_ddl_integration.py +++ b/tests/integration_tests/test_pyarrow_ddl_integration.py @@ -13,15 +13,15 @@ pytest.importorskip("pyarrow") -def test_arrow_create_table_and_insert(test_client: Client): - if not test_client.min_version("20"): +def test_arrow_create_table_and_insert(param_client: Client, call): + if not param_client.min_version("20"): pytest.skip( - f"Not supported server version {test_client.server_version}" + f"Not supported server version {param_client.server_version}" ) table_name = "test_arrow_basic_integration" - test_client.command(f"DROP TABLE IF EXISTS {table_name}") + call(param_client.command, f"DROP TABLE IF EXISTS {table_name}") schema = pa.schema( [ @@ -38,7 +38,7 @@ def test_arrow_create_table_and_insert(test_client: Client): engine="MergeTree", engine_params={"ORDER BY": "id"}, ) - test_client.command(ddl) + call(param_client.command, ddl) arrow_table = pa.table( { @@ -50,9 +50,9 @@ def test_arrow_create_table_and_insert(test_client: Client): schema=schema, ) - test_client.insert_arrow(table=table_name, arrow_table=arrow_table) + call(param_client.insert_arrow, table=table_name, arrow_table=arrow_table) - result = test_client.query( + result = call(param_client.query, f"SELECT id, name, score, flag FROM {table_name} ORDER BY id" ) assert result.result_rows == [ @@ -60,13 +60,13 @@ def test_arrow_create_table_and_insert(test_client: Client): (2, "b", 2.5, False), ] - test_client.command(f"DROP TABLE IF EXISTS {table_name}") + call(param_client.command, f"DROP TABLE IF EXISTS {table_name}") -def test_arrow_schema_to_column_defs(test_client: Client): +def test_arrow_schema_to_column_defs(param_client: Client, call): table_name = "test_arrow_manual_integration" - test_client.command(f"DROP TABLE IF EXISTS {table_name}") + call(param_client.command, f"DROP TABLE IF EXISTS {table_name}") schema = pa.schema( [ @@ -84,7 +84,7 @@ def test_arrow_schema_to_column_defs(test_client: Client): engine="MergeTree", engine_params={"ORDER BY": "id"}, ) - test_client.command(ddl) + call(param_client.command, ddl) arrow_table = pa.table( { @@ -94,26 +94,26 @@ def test_arrow_schema_to_column_defs(test_client: Client): schema=schema, ) - test_client.insert_arrow(table=table_name, arrow_table=arrow_table) + call(param_client.insert_arrow, table=table_name, arrow_table=arrow_table) - result = test_client.query(f"SELECT id, name FROM {table_name} ORDER BY id") + result = call(param_client.query, f"SELECT id, name FROM {table_name} ORDER BY id") assert result.result_rows == [ (10, "x"), (20, "y"), ] - test_client.command(f"DROP TABLE IF EXISTS {table_name}") + call(param_client.command, f"DROP TABLE IF EXISTS {table_name}") -def test_arrow_datetime_create_and_insert(test_client: Client): - if not test_client.min_version("20"): +def test_arrow_datetime_create_and_insert(param_client: Client, call): + if not param_client.min_version("20"): pytest.skip( - f"Not supported server version {test_client.server_version}" + f"Not supported server version {param_client.server_version}" ) table_name = "test_arrow_datetime_integration" - test_client.command(f"DROP TABLE IF EXISTS {table_name}") + call(param_client.command, f"DROP TABLE IF EXISTS {table_name}") schema = pa.schema( [ @@ -130,7 +130,7 @@ def test_arrow_datetime_create_and_insert(test_client: Client): engine="MergeTree", engine_params={"ORDER BY": "id"}, ) - test_client.command(ddl) + call(param_client.command, ddl) arrow_table = pa.table( { @@ -148,9 +148,9 @@ def test_arrow_datetime_create_and_insert(test_client: Client): schema=schema, ) - test_client.insert_arrow(table=table_name, arrow_table=arrow_table) + call(param_client.insert_arrow, table=table_name, arrow_table=arrow_table) - result = test_client.query( + result = call(param_client.query, f"SELECT id, event_date, event_ts, event_ts_tz " f"FROM {table_name} ORDER BY id" ) @@ -162,4 +162,4 @@ def test_arrow_datetime_create_and_insert(test_client: Client): assert rows[1][0] == 2 assert str(rows[1][1]) == "2025-01-02" - test_client.command(f"DROP TABLE IF EXISTS {table_name}") + call(param_client.command, f"DROP TABLE IF EXISTS {table_name}") diff --git a/tests/integration_tests/test_raw_insert.py b/tests/integration_tests/test_raw_insert.py index 518088ac..daf93c86 100644 --- a/tests/integration_tests/test_raw_insert.py +++ b/tests/integration_tests/test_raw_insert.py @@ -4,31 +4,31 @@ from clickhouse_connect.driver import Client -def test_raw_insert(test_client: Client, table_context: Callable): +def test_raw_insert(param_client: Client, call, table_context: Callable): with table_context('test_raw_insert', ["`weir'd` String", 'value String']): csv = 'value1\nvalue2' - test_client.raw_insert('test_raw_insert', ['"weir\'d"'], csv.encode(), fmt='CSV') - result = test_client.query('SELECT * FROM test_raw_insert') + call(param_client.raw_insert, 'test_raw_insert', ['"weir\'d"'], csv.encode(), fmt='CSV') + result = call(param_client.query, 'SELECT * FROM test_raw_insert') assert result.result_set[1][0] == 'value2' - test_client.command('TRUNCATE TABLE test_raw_insert') + call(param_client.command, 'TRUNCATE TABLE test_raw_insert') tsv = 'weird1\tvalue__`2\nweird2\tvalue77' - test_client.raw_insert('test_raw_insert', ["`weir'd`", 'value'], tsv, fmt='TSV') - result = test_client.query('SELECT * FROM test_raw_insert') + call(param_client.raw_insert, 'test_raw_insert', ["`weir'd`", 'value'], tsv, fmt='TSV') + result = call(param_client.query, 'SELECT * FROM test_raw_insert') assert result.result_set[0][1] == 'value__`2' assert result.result_set[1][1] == 'value77' -def test_raw_insert_compression(test_client: Client, table_context: Callable): +def test_raw_insert_compression(param_client: Client, call, table_context: Callable): data_file = f'{Path(__file__).parent}/movies.csv.gz' with open(data_file, mode='rb') as movies_file: data = movies_file.read() with table_context('test_gzip_movies', ['movie String', 'year UInt16', 'rating Decimal32(3)']): - test_client.raw_insert('test_gzip_movies', None, data, fmt='CSV', compression='gzip', + call(param_client.raw_insert, 'test_gzip_movies', None, data, fmt='CSV', compression='gzip', settings={'input_format_allow_errors_ratio': .2, 'input_format_allow_errors_num': 5} ) - res = test_client.query( + res = call(param_client.query, 'SELECT count() as count, sum(rating) as rating, max(year) as year FROM test_gzip_movies').first_item assert res['count'] == 248 assert res['year'] == 2022 diff --git a/tests/integration_tests/test_session_id.py b/tests/integration_tests/test_session_id.py index 17ac9112..11a656b5 100644 --- a/tests/integration_tests/test_session_id.py +++ b/tests/integration_tests/test_session_id.py @@ -4,25 +4,22 @@ import pytest -from clickhouse_connect.driver import create_async_client -from tests.integration_tests.conftest import TestConfig - SESSION_KEY = 'session_id' -def test_client_default_session_id(test_create_client: Callable): +def test_client_default_session_id(client_factory: Callable): # by default, the sync client will autogenerate the session id - client = test_create_client() + # for async clients, we need to explicitly enable it + client = client_factory(autogenerate_session_id=True) session_id = client.get_client_setting(SESSION_KEY) try: uuid.UUID(session_id) except ValueError: pytest.fail(f"Invalid session_id: {session_id}") - client.close() -def test_client_autogenerate_session_id(test_create_client: Callable): - client = test_create_client() +def test_client_autogenerate_session_id(client_factory: Callable): + client = client_factory(autogenerate_session_id=True) session_id = client.get_client_setting(SESSION_KEY) try: uuid.UUID(session_id) @@ -30,49 +27,23 @@ def test_client_autogenerate_session_id(test_create_client: Callable): pytest.fail(f"Invalid session_id: {session_id}") -def test_client_custom_session_id(test_create_client: Callable): +def test_client_custom_session_id(client_factory: Callable): session_id = 'custom_session_id' - client = test_create_client(session_id=session_id) + client = client_factory(session_id=session_id) assert client.get_client_setting(SESSION_KEY) == session_id - client.close() -@pytest.mark.asyncio -async def test_async_client_default_session_id(test_config: TestConfig): - # by default, the async client will NOT autogenerate the session id - async_client = await create_async_client(database=test_config.test_database, - host=test_config.host, - port=test_config.port, - user=test_config.username, - password=test_config.password) - assert async_client.get_client_setting(SESSION_KEY) is None - await async_client.close() +def test_explicit_session_id(client_factory: Callable, call): + """Test explicit session_id allows sharing state like temp tables.""" + session_id = f"test_session_{uuid.uuid4()}" + client = client_factory(session_id=session_id) + assert client.get_client_setting("session_id") == session_id -@pytest.mark.asyncio -async def test_async_client_autogenerate_session_id(test_config: TestConfig): - async_client = await create_async_client(database=test_config.test_database, - host=test_config.host, - port=test_config.port, - user=test_config.username, - password=test_config.password, - autogenerate_session_id=True) - session_id = async_client.get_client_setting(SESSION_KEY) - try: - uuid.UUID(session_id) - except ValueError: - pytest.fail(f"Invalid session_id: {session_id}") - await async_client.close() - + call(client.command, "CREATE TEMPORARY TABLE temp_test (id UInt32, val String)") + call(client.command, "INSERT INTO temp_test VALUES (1, 'a'), (2, 'b')") -@pytest.mark.asyncio -async def test_async_client_custom_session_id(test_config: TestConfig): - session_id = 'custom_session_id' - async_client = await create_async_client(database=test_config.test_database, - host=test_config.host, - port=test_config.port, - user=test_config.username, - password=test_config.password, - session_id=session_id) - assert async_client.get_client_setting(SESSION_KEY) == session_id - await async_client.close() + result = call(client.query, "SELECT * FROM temp_test ORDER BY id") + assert result.row_count == 2 + assert result.result_rows[0] == (1, "a") + assert result.result_rows[1] == (2, "b") diff --git a/tests/integration_tests/test_sqlalchemy/conftest.py b/tests/integration_tests/test_sqlalchemy/conftest.py index 7cf76abe..3fbb180d 100644 --- a/tests/integration_tests/test_sqlalchemy/conftest.py +++ b/tests/integration_tests/test_sqlalchemy/conftest.py @@ -11,11 +11,17 @@ @fixture(scope='module', name='test_engine') def test_engine_fixture(test_config: TestConfig) -> Iterator[Engine]: - test_engine: Engine = create_engine( - f'clickhousedb://{test_config.username}:{test_config.password}@{test_config.host}:' + - f'{test_config.port}/{test_config.test_database}?ch_http_max_field_name_size=99999' + + conn_str = ( + f'clickhousedb://{test_config.username}:{test_config.password}@{test_config.host}:' + f'{test_config.port}/{test_config.test_database}?ch_http_max_field_name_size=99999' '&use_skip_indexes=0&ca_cert=certifi&query_limit=2333&compression=zstd' ) + if test_config.cloud: + conn_str += '&select_sequential_consistency=1' + if test_config.insert_quorum: + conn_str += f'&insert_quorum={test_config.insert_quorum}' + + test_engine: Engine = create_engine(conn_str) yield test_engine test_engine.dispose() diff --git a/tests/integration_tests/test_sqlalchemy/test_delete.py b/tests/integration_tests/test_sqlalchemy/test_delete.py index 0bf9e189..f07f2add 100644 --- a/tests/integration_tests/test_sqlalchemy/test_delete.py +++ b/tests/integration_tests/test_sqlalchemy/test_delete.py @@ -6,6 +6,7 @@ from clickhouse_connect.cc_sqlalchemy.datatypes.sqltypes import String, UInt64 from clickhouse_connect.cc_sqlalchemy.ddl.tableengine import engine_map +from tests.integration_tests.test_sqlalchemy.conftest import verify_tables_ready def test_delete_with_table_object(test_engine: Engine, test_db: str, test_table_engine: str): @@ -179,6 +180,8 @@ def test_explicit_delete(test_engine: Engine, test_table_engine: str): conn.execute(db.insert(test_table).values({"id": 1, "name": "hello world"})) conn.execute(db.insert(test_table).values({"id": 2, "name": "test data"})) conn.execute(db.insert(test_table).values({"id": 3, "name": "hello test"})) + # Wait for inserts to complete in cloud environments + verify_tables_ready(conn, {"delete_explicit_test": 3}) starting = conn.execute(db.select(test_table).order_by(test_table.c.id)).fetchall() assert len(starting) == 3 assert [row.id for row in starting] == [1, 2, 3] diff --git a/tests/integration_tests/test_streaming.py b/tests/integration_tests/test_streaming.py index 8be1ed82..a6f10f15 100644 --- a/tests/integration_tests/test_streaming.py +++ b/tests/integration_tests/test_streaming.py @@ -1,81 +1,142 @@ import random import string +import pytest -from clickhouse_connect.driver import Client from clickhouse_connect.driver.exceptions import StreamClosedError, ProgrammingError, StreamFailureError -def test_row_stream(test_client: Client): - row_stream = test_client.query_rows_stream('SELECT number FROM numbers(10000)') +def test_row_stream(param_client, call, consume_stream): + stream = call(param_client.query_rows_stream, 'SELECT number FROM numbers(10000)') total = 0 - with row_stream: - for row in row_stream: - total += row[0] - try: - with row_stream: - pass - except StreamClosedError: - pass + + def process(row): + nonlocal total + total += row[0] + + consume_stream(stream, process) + + # Verify stream is closed by trying to consume it again + # This logic relies on consume_stream handling the context manager which checks state + with pytest.raises(StreamClosedError): + consume_stream(stream, lambda x: None) + assert total == 49995000 -def test_column_block_stream(test_client: Client): +def test_column_block_stream(param_client, call, consume_stream): random_string = 'randomStringUTF8(50)' - if not test_client.min_version('20'): + if not param_client.min_version('20'): random_string = random.choices(string.ascii_lowercase, k=50) - block_stream = test_client.query_column_block_stream(f'SELECT number, {random_string} FROM numbers(10000)', - settings={'max_block_size': 4000}) + stream = call(param_client.query_column_block_stream, + f'SELECT number, {random_string} FROM numbers(10000)', + settings={'max_block_size': 4000}) total = 0 block_count = 0 - with block_stream: - for block in block_stream: - block_count += 1 - total += sum(block[0]) + + def process(block): + nonlocal total, block_count + block_count += 1 + total += sum(block[0]) + + consume_stream(stream, process) + assert total == 49995000 assert block_count > 1 -def test_row_block_stream(test_client: Client): +def test_row_block_stream(param_client, call, consume_stream): random_string = 'randomStringUTF8(50)' - if not test_client.min_version('20'): + if not param_client.min_version('20'): random_string = random.choices(string.ascii_lowercase, k=50) - block_stream = test_client.query_row_block_stream(f'SELECT number, {random_string} FROM numbers(10000)', - settings={'max_block_size': 4000}) + stream = call(param_client.query_row_block_stream, + f'SELECT number, {random_string} FROM numbers(10000)', + settings={'max_block_size': 4000}) total = 0 block_count = 0 - with block_stream: - for block in block_stream: - block_count += 1 - for row in block: - total += row[0] + + def process(block): + nonlocal total, block_count + block_count += 1 + for row in block: + total += row[0] + + consume_stream(stream, process) + assert total == 49995000 assert block_count > 1 -def test_stream_errors(test_client: Client): +def test_stream_errors_sync(test_client): query_result = test_client.query('SELECT number FROM numbers(100000)') - try: + + # 1. Test accessing without context manager raises error + with pytest.raises(ProgrammingError, match="context"): for _ in query_result.row_block_stream: pass - except ProgrammingError as ex: - assert 'context' in str(ex) + assert query_result.row_count == 100000 - try: + + # 2. Test that previous access consumed the generator, so next access raises StreamClosedError + with pytest.raises(StreamClosedError): with query_result.rows_stream as stream: - assert sum(row[0] for row in stream) == 3882 - except StreamClosedError: - pass + for _ in stream: + pass -def test_stream_failure(test_client: Client): - with test_client.query_row_block_stream('SELECT toString(cityHash64(number)) FROM numbers(10000000)' + - ' where intDiv(1,number-300000)>-100000000') as stream: - blocks = 0 - failed = False - try: +@pytest.mark.asyncio +async def test_stream_errors_async(test_native_async_client): + stream = await test_native_async_client.query_row_block_stream('SELECT number FROM numbers(100)') + async with stream: + async for _ in stream: + pass + + # Try to reuse + with pytest.raises(StreamClosedError): + async with stream: + async for _ in stream: + pass + + +def test_stream_failure_sync(test_client): + query = ('SELECT toString(cityHash64(number)) FROM numbers(10000000)' + + ' where intDiv(1,number-300000)>-100000000') + + stream = test_client.query_row_block_stream(query) + + with pytest.raises(StreamFailureError) as excinfo: + with stream: for _ in stream: - blocks += 1 - except StreamFailureError as ex: - failed = True - assert 'division by zero' in str(ex).lower() - assert failed + pass + + error_msg = str(excinfo.value).lower() + # Race condition: may get actual ClickHouse error or generic connection closed + assert 'division by zero' in error_msg or 'connection closed' in error_msg + + +@pytest.mark.asyncio +async def test_stream_failure_async(test_native_async_client): + query = ('SELECT toString(cityHash64(number)) FROM numbers(10000000)' + + ' where intDiv(1,number-300000)>-100000000') + + stream = await test_native_async_client.query_row_block_stream(query) + + with pytest.raises(StreamFailureError): + async with stream: + async for _ in stream: + pass + + +def test_raw_stream(param_client, call, consume_stream): + """Test raw_stream for streaming response.""" + chunks = [] + stream = call(param_client.raw_stream, "SELECT number FROM system.numbers LIMIT 1000", fmt="TabSeparated") + + def process(chunk): + nonlocal chunks + chunks.append(chunk) + + consume_stream(stream, process) + + assert len(chunks) > 0 + full_data = b"".join(chunks) + assert len(full_data) > 0 diff --git a/tests/integration_tests/test_temporal.py b/tests/integration_tests/test_temporal.py index 665bd22a..cb1eef1e 100644 --- a/tests/integration_tests/test_temporal.py +++ b/tests/integration_tests/test_temporal.py @@ -1,32 +1,22 @@ -from datetime import timedelta, time -from typing import List, Any +from datetime import time, timedelta +from typing import Any, List + import pytest -from clickhouse_connect.driver import Client from tests.integration_tests.conftest import TestConfig -# pylint: disable=no-self-use - +# Module-level version and cloud checks @pytest.fixture(autouse=True, scope="module") def module_setup_and_checks(test_client, test_config: TestConfig): - """ - Performs all module-level setup: - - Skips if in a cloud environment where settings are locked. - - Skips if the server version is too old for Time/Time64 types. - """ + """Check prerequisites for Time/Time64 type tests.""" if test_config.cloud: - pytest.skip( - "Time/Time64 types require settings change, but settings are locked in cloud, skipping tests.", - allow_module_level=True, - ) + pytest.skip("Time/Time64 types require settings change, but settings are locked in cloud") version_str = test_client.query("SELECT version()").result_rows[0][0] major, minor, *_ = map(int, version_str.split(".")) if (major, minor) < (25, 6): - pytest.skip( - "Time and Time64 types require ClickHouse 25.6+", allow_module_level=True - ) + pytest.skip("Time and Time64 types require ClickHouse 25.6+") class TimeTestData: @@ -77,8 +67,47 @@ class TimeTestData: TIME64_NS_TICKS = [3723123456789, -5500000000, 1] +class ClockTimeData: + """Test data for datetime.time objects.""" + + TIME_OBJS = [ + time(0, 0, 5), + time(1, 2, 3), + time(23, 59, 59), + time(0, 0, 0), + ] + + TIME_DELTAS = [ + timedelta(seconds=5), + timedelta(hours=1, minutes=2, seconds=3), + timedelta(hours=23, minutes=59, seconds=59), + timedelta(0), + ] + + TIME64_US_OBJS = [ + time(1, 2, 3, 123456), + time(0, 0, 0, 1), + time(23, 59, 59, 999999), + ] + + TIME64_NS_OBJS = [ + time(1, 2, 3, 123456), + time(0, 0, 0), + time(23, 59, 59, 123000), + ] + + TABLE_NAME = "temp_time_test" -COLUMN_COUNT = 7 + +STANDARD_TIME_TABLE_SCHEMA = [ + "id UInt32", + "t Time", + "t_nullable Nullable(Time)", + "t64_us Time64(6)", + "t64_us_nullable Nullable(Time64(6))", + "t64_ns Time64(9)", + "t64_ns_nullable Nullable(Time64(9))", +] def create_test_row(row_id: int, time_val: Any) -> List[Any]: @@ -99,106 +128,70 @@ def create_nullable_test_row(row_id: int, **column_values) -> List[Any]: ] -@pytest.fixture(autouse=True) -def setup_time_table(test_client: Client): - """Setup and teardown test table with Time and Time64 columns.""" - client = test_client - - client.command("SET enable_time_time64_type = 1") - - client.command(f"DROP TABLE IF EXISTS {TABLE_NAME}") - client.command( - f""" - CREATE TABLE {TABLE_NAME} ( - id UInt32, - t Time, - t_nullable Nullable(Time), - t64_us Time64(6), - t64_us_nullable Nullable(Time64(6)), - t64_ns Time64(9), - t64_ns_nullable Nullable(Time64(9)) - ) ENGINE = MergeTree ORDER BY id - """ - ) - - yield client - - client.command(f"DROP TABLE IF EXISTS {TABLE_NAME}") - - -def insert_test_data(client: Client, rows: List[List[Any]]) -> None: - """Insert test data into the test table.""" - client.insert(TABLE_NAME, rows) - - -def query_column(client: Client, column: str, **query_formats) -> List[Any]: - """Query a single column ordered by ID with optional format specifications.""" - query_result = client.query( - f"SELECT {column} FROM {TABLE_NAME} ORDER BY id", query_formats=query_formats - ) - return [row[0] for row in query_result.result_rows] - - -class TestTimeRoundtrip: - """Test round-trip conversion for Time type.""" - - def test_time_native_format(self, test_client: Client): - """Test Time round-trip with native timedelta format.""" +def test_time_native_format(param_client, call, table_context): + """Test Time round-trip with native timedelta format.""" + with table_context(TABLE_NAME, STANDARD_TIME_TABLE_SCHEMA, settings={"enable_time_time64_type": 1}): rows = [create_test_row(i, td) for i, td in enumerate(TimeTestData.TIME_DELTAS)] - insert_test_data(test_client, rows) + call(param_client.insert, TABLE_NAME, rows) - result = query_column(test_client, "t") - assert result == TimeTestData.TIME_DELTAS - - def test_time_string_format(self, test_client: Client): - """Test Time round-trip with string format.""" - rows = [create_test_row(i, s) for i, s in enumerate(TimeTestData.TIME_STRINGS)] - insert_test_data(test_client, rows) + result = call(param_client.query, f"SELECT t FROM {TABLE_NAME} ORDER BY id") + result_values = [row[0] for row in result.result_rows] + assert result_values == TimeTestData.TIME_DELTAS - result = query_column(test_client, "t", Time="string") - assert result == TimeTestData.TIME_STRINGS - def test_time_int_format(self, test_client: Client): - """Test Time round-trip with integer format.""" - rows = [create_test_row(i, val) for i, val in enumerate(TimeTestData.TIME_INTS)] - insert_test_data(test_client, rows) - - result = query_column(test_client, "t", Time="int") - assert result == TimeTestData.TIME_INTS +def test_time_string_format(param_client, call, table_context): + """Test Time round-trip with string format.""" + with table_context(TABLE_NAME, STANDARD_TIME_TABLE_SCHEMA, settings={"enable_time_time64_type": 1}): + rows = [create_test_row(i, s) for i, s in enumerate(TimeTestData.TIME_STRINGS)] + call(param_client.insert, TABLE_NAME, rows) + result = call(param_client.query, f"SELECT t FROM {TABLE_NAME} ORDER BY id", query_formats={"Time": "string"}) + result_values = [row[0] for row in result.result_rows] + assert result_values == TimeTestData.TIME_STRINGS -class TestTime64Roundtrip: - """Test round-trip conversion for Time64 types.""" - @pytest.mark.parametrize( - "column,strings,deltas,ticks,type_name", - [ - ( - "t64_us", - TimeTestData.TIME64_US_STRINGS, - TimeTestData.TIME64_US_DELTAS, - TimeTestData.TIME64_US_TICKS, - "Time64", - ), - ( - "t64_ns", - TimeTestData.TIME64_NS_STRINGS, - TimeTestData.TIME64_NS_DELTAS, - TimeTestData.TIME64_NS_TICKS, - "Time64", - ), - ], - ) - def test_time64_all_formats( - self, - test_client: Client, - column: str, - strings: List[str], - deltas: List[timedelta], - ticks: List[int], - type_name: str, - ): - """Test Time64 round-trip with all supported formats.""" +def test_time_int_format(param_client, call, table_context): + """Test Time round-trip with integer format.""" + with table_context(TABLE_NAME, STANDARD_TIME_TABLE_SCHEMA, settings={"enable_time_time64_type": 1}): + rows = [create_test_row(i, val) for i, val in enumerate(TimeTestData.TIME_INTS)] + call(param_client.insert, TABLE_NAME, rows) + + result = call(param_client.query, f"SELECT t FROM {TABLE_NAME} ORDER BY id", query_formats={"Time": "int"}) + result_values = [row[0] for row in result.result_rows] + assert result_values == TimeTestData.TIME_INTS + + +@pytest.mark.parametrize( + "column,strings,deltas,ticks,type_name", + [ + ( + "t64_us", + TimeTestData.TIME64_US_STRINGS, + TimeTestData.TIME64_US_DELTAS, + TimeTestData.TIME64_US_TICKS, + "Time64", + ), + ( + "t64_ns", + TimeTestData.TIME64_NS_STRINGS, + TimeTestData.TIME64_NS_DELTAS, + TimeTestData.TIME64_NS_TICKS, + "Time64", + ), + ], +) +def test_time64_all_formats( + param_client, + call, + table_context, + column: str, + strings: List[str], + deltas: List[timedelta], + ticks: List[int], + type_name: str, +): + """Test Time64 round-trip with all supported formats.""" + with table_context(TABLE_NAME, STANDARD_TIME_TABLE_SCHEMA, settings={"enable_time_time64_type": 1}): rows = [] for i, string_val in enumerate(strings): row = [i, timedelta(0), None, timedelta(0), None, timedelta(0), None] @@ -208,23 +201,24 @@ def test_time64_all_formats( row[5] = string_val rows.append(row) - insert_test_data(test_client, rows) + call(param_client.insert, TABLE_NAME, rows) - result = query_column(test_client, column) - assert result == deltas + result = call(param_client.query, f"SELECT {column} FROM {TABLE_NAME} ORDER BY id") + result_values = [row[0] for row in result.result_rows] + assert result_values == deltas - result = query_column(test_client, column, **{type_name: "string"}) - assert result == strings + result = call(param_client.query, f"SELECT {column} FROM {TABLE_NAME} ORDER BY id", query_formats={type_name: "string"}) + result_values = [row[0] for row in result.result_rows] + assert result_values == strings - result = query_column(test_client, column, **{type_name: "int"}) - assert result == ticks + result = call(param_client.query, f"SELECT {column} FROM {TABLE_NAME} ORDER BY id", query_formats={type_name: "int"}) + result_values = [row[0] for row in result.result_rows] + assert result_values == ticks -class TestNullableColumns: - """Test nullable Time and Time64 columns.""" - - def test_nullable_time_columns(self, test_client: Client): - """Test that nullable columns handle None values correctly.""" +def test_nullable_time_columns(param_client, call, table_context): + """Test that nullable columns handle None values correctly.""" + with table_context(TABLE_NAME, STANDARD_TIME_TABLE_SCHEMA, settings={"enable_time_time64_type": 1}): rows = [ create_nullable_test_row( 0, @@ -245,163 +239,136 @@ def test_nullable_time_columns(self, test_client: Client): t64_ns_nullable="07:00:00.123000000", ), ] - insert_test_data(test_client, rows) + call(param_client.insert, TABLE_NAME, rows) - result = query_column(test_client, "t_nullable") + result = call(param_client.query, f"SELECT t_nullable FROM {TABLE_NAME} ORDER BY id") + result_values = [row[0] for row in result.result_rows] expected = [None, timedelta(hours=5)] - assert result == expected + assert result_values == expected - result = query_column(test_client, "t64_us_nullable") + result = call(param_client.query, f"SELECT t64_us_nullable FROM {TABLE_NAME} ORDER BY id") + result_values = [row[0] for row in result.result_rows] expected = [None, timedelta(hours=6, microseconds=789000)] - assert result == expected + assert result_values == expected - result = query_column(test_client, "t64_ns_nullable") + result = call(param_client.query, f"SELECT t64_ns_nullable FROM {TABLE_NAME} ORDER BY id") + result_values = [row[0] for row in result.result_rows] expected = [None, timedelta(hours=7, microseconds=123000)] - assert result == expected - - -class TestErrorHandling: - """Test error handling for invalid inputs.""" - - @pytest.mark.parametrize( - "invalid_value", - [ - "1000:00:00", # Out of range string - 3600000, # Out of range int - ], - ) - def test_time_out_of_range_values(self, test_client: Client, invalid_value: Any): - """Test that out-of-range Time values raise ValueError.""" + assert result_values == expected + + +@pytest.mark.parametrize( + "invalid_value", + [ + "1000:00:00", # Out of range string + 3600000, # Out of range int + ], +) +def test_time_out_of_range_values(param_client, call, table_context, invalid_value: Any): + """Test that out-of-range Time values raise ValueError.""" + with table_context(TABLE_NAME, STANDARD_TIME_TABLE_SCHEMA, settings={"enable_time_time64_type": 1}): with pytest.raises(ValueError, match="out of range"): rows = [create_test_row(0, invalid_value)] - insert_test_data(test_client, rows) - - @pytest.mark.parametrize( - "time_val,time64_val", - [ - ("1:2:3:4", "1:2:3:4"), # Too many colons - ("10:70:00", "10:70:00"), # Invalid minutes - ("10:00:00.123.456", "10:00:00.123.456"), # Invalid fractional format - ], - ) - def test_invalid_time_formats( - self, test_client: Client, time_val: str, time64_val: str - ): - """Test that invalid time formats raise ValueError.""" + call(param_client.insert, TABLE_NAME, rows) + + +@pytest.mark.parametrize( + "time_val,time64_val", + [ + ("1:2:3:4", "1:2:3:4"), # Too many colons + ("10:70:00", "10:70:00"), # Invalid minutes + ("10:00:00.123.456", "10:00:00.123.456"), # Invalid fractional format + ], +) +def test_invalid_time_formats(param_client, call, table_context, time_val: str, time64_val: str): + """Test that invalid time formats raise ValueError.""" + with table_context(TABLE_NAME, STANDARD_TIME_TABLE_SCHEMA, settings={"enable_time_time64_type": 1}): with pytest.raises(ValueError): - rows = [ - create_nullable_test_row( - 0, t=time_val, t64_us=time64_val, t64_ns=time64_val - ) - ] - insert_test_data(test_client, rows) - + rows = [create_nullable_test_row(0, t=time_val, t64_us=time64_val, t64_ns=time64_val)] + call(param_client.insert, TABLE_NAME, rows) -class TestMixedInputTypes: - """Test handling of mixed input types.""" - def test_timedelta_input_conversion(self, test_client: Client): - """Test conversion of timedelta inputs to internal representation.""" +def test_timedelta_input_conversion(param_client, call, table_context): + """Test conversion of timedelta inputs to internal representation.""" + with table_context(TABLE_NAME, STANDARD_TIME_TABLE_SCHEMA, settings={"enable_time_time64_type": 1}): test_deltas = TimeTestData.TIME_DELTAS[:3] rows = [create_test_row(i, td) for i, td in enumerate(test_deltas)] - insert_test_data(test_client, rows) + call(param_client.insert, TABLE_NAME, rows) - result = query_column(test_client, "t") - assert result == test_deltas - assert all(isinstance(td, timedelta) for td in result) + result = call(param_client.query, f"SELECT t FROM {TABLE_NAME} ORDER BY id") + result_values = [row[0] for row in result.result_rows] + assert result_values == test_deltas + assert all(isinstance(td, timedelta) for td in result_values) - def test_integer_input_conversion(self, test_client: Client): - """Test conversion of integer inputs to internal representation.""" + +def test_integer_input_conversion(param_client, call, table_context): + """Test conversion of integer inputs to internal representation.""" + with table_context(TABLE_NAME, STANDARD_TIME_TABLE_SCHEMA, settings={"enable_time_time64_type": 1}): test_ints = TimeTestData.TIME_INTS rows = [create_test_row(i, val) for i, val in enumerate(test_ints)] - insert_test_data(test_client, rows) + call(param_client.insert, TABLE_NAME, rows) - result = query_column(test_client, "t") + result = call(param_client.query, f"SELECT t FROM {TABLE_NAME} ORDER BY id") + result_values = [row[0] for row in result.result_rows] expected = TimeTestData.TIME_DELTAS - assert result == expected + assert result_values == expected -class ClockTimeData: - TIME_OBJS = [ - time(0, 0, 5), - time(1, 2, 3), - time(23, 59, 59), - time(0, 0, 0), - ] - - TIME_DELTAS = [ - timedelta(seconds=5), - timedelta(hours=1, minutes=2, seconds=3), - timedelta(hours=23, minutes=59, seconds=59), - timedelta(0), - ] - - TIME64_US_OBJS = [ - time(1, 2, 3, 123456), - time(0, 0, 0, 1), - time(23, 59, 59, 999999), - ] - - TIME64_NS_OBJS = [ - time(1, 2, 3, 123456), # 123 456 789 ns -> 123 456 us - time(0, 0, 0), # 1 ns (floor) - time(23, 59, 59, 123000), # 123 000 µs - ] - - -class TestTimeDatetimeTimeRoundtrip: +def test_time_roundtrip_time_format(param_client, call, table_context): """Ensure Time columns accept & return datetime.time when format='time'.""" - - def test_roundtrip_time_format(self, test_client: Client): + with table_context(TABLE_NAME, STANDARD_TIME_TABLE_SCHEMA, settings={"enable_time_time64_type": 1}): rows = [create_test_row(i, t) for i, t in enumerate(ClockTimeData.TIME_OBJS)] - insert_test_data(test_client, rows) - result = query_column(test_client, "t", Time="time") - assert result == ClockTimeData.TIME_OBJS + call(param_client.insert, TABLE_NAME, rows) - def test_default_read_from_time_objects(self, test_client: Client): - """Writing time objects + default read still yields timedelta.""" + result = call(param_client.query, f"SELECT t FROM {TABLE_NAME} ORDER BY id", query_formats={"Time": "time"}) + result_values = [row[0] for row in result.result_rows] + assert result_values == ClockTimeData.TIME_OBJS + + +def test_time_default_read_from_time_objects(param_client, call, table_context): + """Writing time objects + default read still yields timedelta.""" + with table_context(TABLE_NAME, STANDARD_TIME_TABLE_SCHEMA, settings={"enable_time_time64_type": 1}): rows = [create_test_row(i, t) for i, t in enumerate(ClockTimeData.TIME_OBJS)] - insert_test_data(test_client, rows) + call(param_client.insert, TABLE_NAME, rows) - result = query_column(test_client, "t") - assert result == ClockTimeData.TIME_DELTAS + result = call(param_client.query, f"SELECT t FROM {TABLE_NAME} ORDER BY id") + result_values = [row[0] for row in result.result_rows] + assert result_values == ClockTimeData.TIME_DELTAS -class TestTime64DatetimeTimeRoundtrip: +@pytest.mark.parametrize( + "column,objects", + [ + ("t64_us", ClockTimeData.TIME64_US_OBJS), + ("t64_ns", ClockTimeData.TIME64_NS_OBJS), + ], +) +def test_time64_time_format(param_client, call, table_context, column: str, objects: List[time]): """Validate Time64(6/9) ⇄ datetime.time conversions.""" - - @pytest.mark.parametrize( - "column,objects", - [ - ("t64_us", ClockTimeData.TIME64_US_OBJS), - ("t64_ns", ClockTimeData.TIME64_NS_OBJS), - ], - ) - def test_time64_time_format( - self, test_client: Client, column: str, objects: List[time] - ): + with table_context(TABLE_NAME, STANDARD_TIME_TABLE_SCHEMA, settings={"enable_time_time64_type": 1}): rows = [create_test_row(i, t) for i, t in enumerate(objects)] - insert_test_data(test_client, rows) + call(param_client.insert, TABLE_NAME, rows) - result = query_column(test_client, column, Time64="time") - assert result == objects + result = call(param_client.query, f"SELECT {column} FROM {TABLE_NAME} ORDER BY id", query_formats={"Time64": "time"}) + result_values = [row[0] for row in result.result_rows] + assert result_values == objects -class TestTimeFormatErrorHandling: - """Errors that only show up when requesting format='time'.""" - - def test_negative_value_cannot_be_coerced_to_time(self, test_client: Client): - """Database contains -2 s; asking for format='time' should fail.""" +def test_negative_value_cannot_be_coerced_to_time(param_client, call, table_context): + """Database contains -2 s; asking for format='time' should fail.""" + with table_context(TABLE_NAME, STANDARD_TIME_TABLE_SCHEMA, settings={"enable_time_time64_type": 1}): rows = [create_test_row(0, timedelta(seconds=-2))] - insert_test_data(test_client, rows) + call(param_client.insert, TABLE_NAME, rows) with pytest.raises(ValueError, match="outside valid range"): - query_column(test_client, "t", Time="time") + call(param_client.query, f"SELECT t FROM {TABLE_NAME}", query_formats={"Time": "time"}) + - def test_over_24h_value_cannot_be_coerced_to_time(self, test_client: Client): - """30 h is legal for ClickHouse but illegal for datetime.time.""" +def test_over_24h_value_cannot_be_coerced_to_time(param_client, call, table_context): + """30 h is legal for ClickHouse but illegal for datetime.time.""" + with table_context(TABLE_NAME, STANDARD_TIME_TABLE_SCHEMA, settings={"enable_time_time64_type": 1}): rows = [create_test_row(0, "030:00:00")] - insert_test_data(test_client, rows) + call(param_client.insert, TABLE_NAME, rows) with pytest.raises(ValueError, match="outside valid range"): - query_column(test_client, "t", Time="time") + call(param_client.query, f"SELECT t FROM {TABLE_NAME}", query_formats={"Time": "time"}) diff --git a/tests/integration_tests/test_timezones.py b/tests/integration_tests/test_timezones.py index 0e4603b9..a4cfd2f7 100644 --- a/tests/integration_tests/test_timezones.py +++ b/tests/integration_tests/test_timezones.py @@ -15,8 +15,8 @@ # pylint:disable=protected-access -def test_basic_timezones(test_client: Client): - row = test_client.query("SELECT toDateTime('2022-10-25 10:55:22', 'America/Chicago') as chicago," + +def test_basic_timezones(param_client: Client, call): + row = call(param_client.query, "SELECT toDateTime('2022-10-25 10:55:22', 'America/Chicago') as chicago," + "toDateTime('2023-07-05 15:10:40') as utc", query_tz='America/Chicago').first_row @@ -28,8 +28,8 @@ def test_basic_timezones(test_client: Client): assert row[1].hour == 10 assert row[1].day == 5 - if test_client.min_version('20'): - row = test_client.query("SELECT toDateTime64('2022-10-25 10:55:22.789123', 6, 'America/Chicago')", + if param_client.min_version('20'): + row = call(param_client.query, "SELECT toDateTime64('2022-10-25 10:55:22.789123', 6, 'America/Chicago')", query_tz='America/Chicago').first_row assert row[0].tzinfo == chicago_tz assert row[0].hour == 10 @@ -37,14 +37,14 @@ def test_basic_timezones(test_client: Client): assert row[0].microsecond == 789123 -def test_server_timezone(test_client: Client): +def test_server_timezone(param_client: Client, call): # This test is really for manual testing since changing the timezone on the test ClickHouse server # still requires a restart. Other tests will depend on https://github.com/ClickHouse/ClickHouse/pull/44149 - test_client.tz_source = "server" + param_client.tz_source = "server" test_datetime = datetime(2023, 3, 18, 16, 4, 25) try: - date = test_client.query('SELECT toDateTime(%s) as st', parameters=[test_datetime]).first_row[0] - if test_client.server_tz == pytz.UTC: + date = call(param_client.query, 'SELECT toDateTime(%s) as st', parameters=[test_datetime]).first_row[0] + if param_client.server_tz == pytz.UTC: assert date.tzinfo is None assert date == datetime(2023, 3, 18, 16, 4, 25, tzinfo=None) assert date.timestamp() == 1679155465 @@ -54,15 +54,15 @@ def test_server_timezone(test_client: Client): assert date.tzinfo == den_tz assert date.timestamp() == 1679177065 finally: - test_client.tz_source = "auto" + param_client.tz_source = "auto" -def test_column_timezones(test_client: Client): +def test_column_timezones(param_client: Client, call): date_tz64 = "toDateTime64('2023-01-02 15:44:22.7832', 6, 'Asia/Shanghai')" - if not test_client.min_version('20'): + if not param_client.min_version('20'): date_tz64 = "toDateTime('2023-01-02 15:44:22', 'Asia/Shanghai')" column_tzs = {'chicago': 'America/Chicago', 'china': 'Asia/Shanghai'} - row = test_client.query("SELECT toDateTime('2022-10-25 10:55:22', 'America/Chicago') as chicago," + + row = call(param_client.query, "SELECT toDateTime('2022-10-25 10:55:22', 'America/Chicago') as chicago," + f'{date_tz64} as china,' + "toDateTime('2023-07-05 15:10:40') as utc", column_tzs=column_tzs).first_row @@ -71,27 +71,27 @@ def test_column_timezones(test_client: Client): assert row[1].tzinfo == china_tz assert row[2].tzinfo is None - if test_client.min_version('20'): - row = test_client.query("SELECT toDateTime('2022-10-25 10:55:22', 'America/Chicago') as chicago," + + if param_client.min_version('20'): + row = call(param_client.query, "SELECT toDateTime('2022-10-25 10:55:22', 'America/Chicago') as chicago," + "toDateTime64('2023-01-02 15:44:22.7832', 6, 'Asia/Shanghai') as china").first_row - if test_client.protocol_version: + if param_client.protocol_version: assert row[0].tzinfo == chicago_tz else: assert row[0].tzinfo is None assert row[1].tzinfo == china_tz # DateTime64 columns work correctly -def test_local_timezones(test_client: Client): +def test_local_timezones(param_client: Client, call): denver_tz = pytz.timezone('America/Denver') tzutil.local_tz = denver_tz - test_client.tz_source = "local" + param_client.tz_source = "local" try: - row = test_client.query("SELECT toDateTime('2022-10-25 10:55:22'," + + row = call(param_client.query, "SELECT toDateTime('2022-10-25 10:55:22'," + "'America/Chicago') as chicago," + "toDateTime('2023-07-05 15:10:40') as raw_utc_dst," + "toDateTime('2023-07-05 12:44:22', 'UTC') as forced_utc," + "toDateTime('2023-12-31 17:00:55') as raw_utc_std").first_row - if test_client.protocol_version: + if param_client.protocol_version: assert row[0].tzinfo.tzname(None) == chicago_tz.tzname(None) else: assert row[0].tzinfo.tzname(None) == denver_tz.tzname(None) @@ -100,98 +100,98 @@ def test_local_timezones(test_client: Client): assert row[3].tzinfo.tzname(None) == denver_tz.tzname(None) finally: tzutil.local_tz = pytz.UTC - test_client.tz_source = "auto" + param_client.tz_source = "auto" -def test_naive_timezones(test_client: Client): - row = test_client.query("SELECT toDateTime('2022-10-25 10:55:22', 'America/Chicago') as chicago," + +def test_naive_timezones(param_client: Client, call): + row = call(param_client.query, "SELECT toDateTime('2022-10-25 10:55:22', 'America/Chicago') as chicago," + "toDateTime('2023-07-05 15:10:40') as utc").first_row - if test_client.protocol_version: + if param_client.protocol_version: assert row[0].tzinfo == chicago_tz else: assert row[0].tzinfo is None assert row[1].tzinfo is None -def test_timezone_binding_client(test_client: Client): +def test_timezone_binding_client(param_client: Client, call): os.environ['TZ'] = 'America/Denver' time.tzset() denver_tz = pytz.timezone('America/Denver') tzutil.local_tz = denver_tz - test_client.tz_source = "local" + param_client.tz_source = "local" denver_time = datetime(2023, 3, 18, 16, 4, 25, tzinfo=denver_tz) try: - server_time = test_client.query( + server_time = call(param_client.query, 'SELECT toDateTime(%(dt)s) as dt', parameters={'dt': denver_time}).first_row[0] assert server_time == denver_time finally: os.environ['TZ'] = 'UTC' tzutil.local_tz = pytz.UTC time.tzset() - test_client.tz_source = "auto" + param_client.tz_source = "auto" naive_time = datetime(2023, 3, 18, 16, 4, 25) - server_time = test_client.query( + server_time = call(param_client.query, 'SELECT toDateTime(%(dt)s) as dt', parameters={'dt': naive_time}).first_row[0] assert server_time.astimezone(pytz.UTC) == naive_time.astimezone(pytz.UTC) utc_time = datetime(2023, 3, 18, 16, 4, 25, tzinfo=pytz.UTC) - server_time = test_client.query( + server_time = call(param_client.query, 'SELECT toDateTime(%(dt)s) as dt', parameters={'dt': utc_time}).first_row[0] assert server_time.astimezone(pytz.UTC) == utc_time -def test_timezone_binding_server(test_client: Client): +def test_timezone_binding_server(param_client: Client, call): os.environ['TZ'] = 'America/Denver' time.tzset() denver_tz = pytz.timezone('America/Denver') tzutil.local_tz = denver_tz - test_client.tz_source = "local" + param_client.tz_source = "local" denver_time = datetime(2022, 3, 18, 16, 4, 25, tzinfo=denver_tz) try: - server_time = test_client.query( + server_time = call(param_client.query, 'SELECT toDateTime({dt:DateTime}) as dt', parameters={'dt': denver_time}).first_row[0] assert server_time == denver_time finally: os.environ['TZ'] = 'UTC' time.tzset() tzutil.local_tz = pytz.UTC - test_client.tz_source = "auto" + param_client.tz_source = "auto" naive_time = datetime(2022, 3, 18, 16, 4, 25) - server_time = test_client.query( + server_time = call(param_client.query, 'SELECT toDateTime({dt:DateTime}) as dt', parameters={'dt': naive_time}).first_row[0] assert naive_time.astimezone(pytz.UTC) == server_time.astimezone(pytz.UTC) utc_time = datetime(2020, 3, 18, 16, 4, 25, tzinfo=pytz.UTC) - server_time = test_client.query( + server_time = call(param_client.query, 'SELECT toDateTime({dt:DateTime}) as dt', parameters={'dt': utc_time}).first_row[0] assert server_time.astimezone(pytz.UTC) == utc_time -def test_tz_mode(test_client: Client): - row = test_client.query("SELECT toDateTime('2023-07-05 15:10:40') as dt," + +def test_tz_mode(param_client: Client, call): + row = call(param_client.query, "SELECT toDateTime('2023-07-05 15:10:40') as dt," + "toDateTime('2023-07-05 15:10:40', 'UTC') as dt_utc", query_tz='UTC').first_row assert row[0].tzinfo is None assert row[1].tzinfo is None - row = test_client.query("SELECT toDateTime('2023-07-05 15:10:40') as dt," + + row = call(param_client.query, "SELECT toDateTime('2023-07-05 15:10:40') as dt," + "toDateTime('2023-07-05 15:10:40', 'UTC') as dt_utc", query_tz='UTC', tz_mode="aware").first_row assert row[0].tzinfo == pytz.UTC assert row[1].tzinfo == pytz.UTC - if test_client.min_version('20'): - row = test_client.query("SELECT toDateTime64('2023-07-05 15:10:40.123456', 6) as dt64," + + if param_client.min_version('20'): + row = call(param_client.query, "SELECT toDateTime64('2023-07-05 15:10:40.123456', 6) as dt64," + "toDateTime64('2023-07-05 15:10:40.123456', 6, 'UTC') as dt64_utc", query_tz='UTC').first_row assert row[0].tzinfo is None assert row[1].tzinfo is None assert row[0].microsecond == 123456 - row = test_client.query("SELECT toDateTime64('2023-07-05 15:10:40.123456', 6) as dt64," + + row = call(param_client.query, "SELECT toDateTime64('2023-07-05 15:10:40.123456', 6) as dt64," + "toDateTime64('2023-07-05 15:10:40.123456', 6, 'UTC') as dt64_utc", query_tz='UTC', tz_mode="aware").first_row assert row[0].tzinfo == pytz.UTC @@ -199,51 +199,53 @@ def test_tz_mode(test_client: Client): assert row[0].microsecond == 123456 -def test_apply_server_timezone_setter_deprecated(test_client: Client): +def test_apply_server_timezone_setter_deprecated(param_client: Client): """Setting client.apply_server_timezone should emit a DeprecationWarning and update state.""" try: with warnings.catch_warnings(record=True) as w: warnings.simplefilter("always") - test_client.apply_server_timezone = True - assert len(w) == 1 - assert issubclass(w[0].category, DeprecationWarning) - assert "apply_server_timezone is deprecated" in str(w[0].message) - assert test_client.tz_source == "server" - assert test_client._apply_server_tz is True + param_client.apply_server_timezone = True + assert len(w) >= 1 + dep_warnings = [x for x in w if issubclass(x.category, DeprecationWarning)] + assert len(dep_warnings) >= 1 + assert "apply_server_timezone is deprecated" in str(dep_warnings[0].message) + assert param_client.tz_source == "server" + assert param_client._apply_server_tz is True with warnings.catch_warnings(record=True) as w: warnings.simplefilter("always") - test_client.apply_server_timezone = False - assert len(w) == 1 - assert issubclass(w[0].category, DeprecationWarning) - assert test_client.tz_source == "local" - assert test_client._apply_server_tz is False + param_client.apply_server_timezone = False + assert len(w) >= 1 + dep_warnings = [x for x in w if issubclass(x.category, DeprecationWarning)] + assert len(dep_warnings) >= 1 + assert param_client.tz_source == "local" + assert param_client._apply_server_tz is False finally: - test_client.tz_source = "auto" + param_client.tz_source = "auto" -def test_apply_server_timezone_getter_deprecated(test_client: Client): +def test_apply_server_timezone_getter_deprecated(param_client: Client): """Reading client.apply_server_timezone should emit a DeprecationWarning.""" with warnings.catch_warnings(record=True) as w: warnings.simplefilter("always") - _ = test_client.apply_server_timezone + _ = param_client.apply_server_timezone assert len(w) == 1 assert issubclass(w[0].category, DeprecationWarning) -def test_tz_source_setter_validates(test_client: Client): +def test_tz_source_setter_validates(param_client: Client): """Setting client.tz_source to an invalid value should raise ProgrammingError.""" with pytest.raises(ProgrammingError, match='tz_source must be'): - test_client.tz_source = "serer" + param_client.tz_source = "serer" -def test_tz_source_setter_auto_restores_dst_safe(test_client: Client): +def test_tz_source_setter_auto_restores_dst_safe(param_client: Client): """Setting tz_source back to 'auto' should re-resolve based on server DST safety.""" - original = test_client._apply_server_tz + original = param_client._apply_server_tz try: - test_client.tz_source = "local" - assert test_client._apply_server_tz is False - test_client.tz_source = "auto" - assert test_client._apply_server_tz == original + param_client.tz_source = "local" + assert param_client._apply_server_tz is False + param_client.tz_source = "auto" + assert param_client._apply_server_tz == original finally: - test_client.tz_source = "auto" + param_client.tz_source = "auto" diff --git a/tests/integration_tests/test_tls.py b/tests/integration_tests/test_tls.py index 378c064a..d0c7087c 100644 --- a/tests/integration_tests/test_tls.py +++ b/tests/integration_tests/test_tls.py @@ -1,9 +1,7 @@ import os import pytest -from urllib3.exceptions import SSLError -from clickhouse_connect import get_client from clickhouse_connect.driver.common import coerce_bool from clickhouse_connect.driver.exceptions import OperationalError from tests.helpers import PROJECT_ROOT_DIR @@ -13,39 +11,41 @@ host = 'server1.clickhouse.test' -def test_basic_tls(): +def test_basic_tls(client_factory, call): if not coerce_bool(os.environ.get('CLICKHOUSE_CONNECT_TEST_TLS', 'False')): pytest.skip('TLS tests not enabled') - client = get_client(interface='https', host=host, port=10843, verify=False) - assert client.command("SELECT 'insecure'") == 'insecure' - client.close_connections() + client = client_factory(interface='https', host=host, port=10843, verify=False, database='default') + assert call(client.command, "SELECT 'insecure'") == 'insecure' - client = get_client(interface='https', host=host, port=10843, ca_cert=f'{cert_dir}ca.crt') - assert client.command("SELECT 'verify_server'") == 'verify_server' - client.close_connections() + client = client_factory(interface='https', host=host, port=10843, ca_cert=f'{cert_dir}ca.crt', database='default') + assert call(client.command, "SELECT 'verify_server'") == 'verify_server' try: - get_client(interface='https', host='localhost', port=10843, ca_cert=f'{cert_dir}ca.crt') + client_factory(interface='https', host='localhost', port=10843, ca_cert=f'{cert_dir}ca.crt', database='default') pytest.fail('Expected TLS exception with a different hostname') except OperationalError as ex: - assert isinstance(ex.__cause__.reason, SSLError) # pylint: disable=no-member - client.close_connections() + # For sync (urllib3): ex.__cause__.reason is SSLError + # For async (aiohttp): ex.__cause__ is ClientConnectorCertificateError + assert ex.__cause__ is not None + assert 'SSL' in str(ex.__cause__) or 'certificate' in str(ex.__cause__).lower() try: - get_client(interface='https', host='localhost', port=10843) + client_factory(interface='https', host='localhost', port=10843, database='default') pytest.fail('Expected TLS exception with a self-signed cert') except OperationalError as ex: - assert isinstance(ex.__cause__.reason, SSLError) # pylint: disable=no-member + assert ex.__cause__ is not None + assert 'SSL' in str(ex.__cause__) or 'certificate' in str(ex.__cause__).lower() -def test_mutual_tls(): +def test_mutual_tls(client_factory, call): if not coerce_bool(os.environ.get('CLICKHOUSE_CONNECT_TEST_TLS', 'False')): pytest.skip('TLS tests not enabled') - client = get_client(interface='https', + client = client_factory(interface='https', username='cert_user', host=host, port=10843, ca_cert=f'{cert_dir}ca.crt', client_cert=f'{cert_dir}client.crt', - client_cert_key=f'{cert_dir}client.key') - assert client.command('SELECT user()') == 'cert_user' + client_cert_key=f'{cert_dir}client.key', + database='default') + assert call(client.command, 'SELECT user()') == 'cert_user' diff --git a/tests/integration_tests/test_tools.py b/tests/integration_tests/test_tools.py index 6ed546af..adbd0c4d 100644 --- a/tests/integration_tests/test_tools.py +++ b/tests/integration_tests/test_tools.py @@ -2,39 +2,50 @@ from typing import Callable from clickhouse_connect.driver import Client -from clickhouse_connect.driver.tools import insert_file -from tests.integration_tests.conftest import TestConfig +from clickhouse_connect.driver.tools import insert_file, insert_file_async -def test_csv_upload(test_client: Client, table_context: Callable): +def test_csv_upload(param_client: Client, call, table_context: Callable, client_mode): data_file = f'{Path(__file__).parent}/movies.csv.gz' with table_context('test_csv_upload', ['movie String', 'year UInt16', 'rating Decimal32(3)']): - insert_file(test_client, 'test_csv_upload', data_file, - settings={'input_format_allow_errors_ratio': .2, - 'input_format_allow_errors_num': 5}) - res = test_client.query( + # Use appropriate insert_file version based on client type + if client_mode == "async": + call(insert_file_async, param_client, 'test_csv_upload', data_file, + settings={'input_format_allow_errors_ratio': .2, + 'input_format_allow_errors_num': 5}) + else: # Sync client + insert_file(param_client, 'test_csv_upload', data_file, + settings={'input_format_allow_errors_ratio': .2, + 'input_format_allow_errors_num': 5}) + res = call(param_client.query, 'SELECT count() as count, sum(rating) as rating, max(year) as year FROM test_csv_upload').first_item assert res['count'] == 248 assert res['year'] == 2022 -def test_parquet_upload(test_config: TestConfig, test_client: Client, table_context: Callable): +def test_parquet_upload(param_client: Client, call, client_mode, table_context: Callable): data_file = f'{Path(__file__).parent}/movies.parquet' - full_table = f'{test_config.test_database}.test_parquet_upload' - with table_context(full_table, ['movie String', 'year UInt16', 'rating Float64']): - insert_file(test_client, full_table, data_file, 'Parquet', - settings={'output_format_parquet_string_as_string': 1}) - res = test_client.query( - f'SELECT count() as count, sum(rating) as rating, max(year) as year FROM {full_table}').first_item + with table_context('test_parquet_upload', ['movie String', 'year UInt16', 'rating Float64']): + if client_mode == "async": + call(insert_file_async, param_client, 'test_parquet_upload', data_file, 'Parquet', + settings={'output_format_parquet_string_as_string': 1}) + else: + insert_file(param_client, 'test_parquet_upload', data_file, 'Parquet', + settings={'output_format_parquet_string_as_string': 1}) + res = call(param_client.query, + 'SELECT count() as count, sum(rating) as rating, max(year) as year FROM test_parquet_upload').first_item assert res['count'] == 250 assert res['year'] == 2022 -def test_json_insert(test_client: Client, table_context: Callable): +def test_json_insert(param_client: Client, call, client_mode, table_context: Callable): data_file = f'{Path(__file__).parent}/json_test.ndjson' with table_context('test_json_upload', ['key UInt16', 'flt_val Float64', 'int_val Int8']): - insert_file(test_client, 'test_json_upload', data_file, 'JSONEachRow') - res = test_client.query('SELECT * FROM test_json_upload ORDER BY key').result_rows + if client_mode == "async": + call(insert_file_async, param_client, 'test_json_upload', data_file, 'JSONEachRow') + else: + insert_file(param_client, 'test_json_upload', data_file, 'JSONEachRow') + res = call(param_client.query, 'SELECT * FROM test_json_upload ORDER BY key').result_rows assert res[1][0] == 17 assert res[1][1] == 5.3 assert res[1][2] == 121 diff --git a/tests/integration_tests/test_vector.py b/tests/integration_tests/test_vector.py index e51b801a..9077d386 100644 --- a/tests/integration_tests/test_vector.py +++ b/tests/integration_tests/test_vector.py @@ -23,7 +23,7 @@ def module_setup_and_checks(test_client: Client, test_config: TestConfig): test_client.command("SET allow_experimental_qbit_type = 1") -def test_qbit_roundtrip_float64(test_client: Client, table_context: Callable): +def test_qbit_roundtrip_float64(param_client: Client, call, table_context: Callable): """Test QBit(Float64) round-trip accuracy with fruit_animal example data""" with table_context("fruit_animal", ["word String", "vec QBit(Float64, 5)"]): @@ -33,13 +33,13 @@ def test_qbit_roundtrip_float64(test_client: Client, table_context: Callable): ("orange", [0.93338752, 2.06571317, -0.54612565, -1.51625717, 0.69775337]), ] - test_client.insert("fruit_animal", test_data) - count = test_client.query("SELECT COUNT(*) FROM fruit_animal").result_set[0][0] + call(param_client.insert, "fruit_animal", test_data) + count = call(param_client.query, "SELECT COUNT(*) FROM fruit_animal").result_set[0][0] assert count == 3 for word, original_vec in test_data: - result = test_client.query("SELECT vec FROM fruit_animal WHERE word = %(word)s", parameters={"word": word}) + result = call(param_client.query, "SELECT vec FROM fruit_animal WHERE word = %(word)s", parameters={"word": word}) retrieved_vec = result.result_set[0][0] assert isinstance(retrieved_vec, list) @@ -47,7 +47,7 @@ def test_qbit_roundtrip_float64(test_client: Client, table_context: Callable): assert retrieved_vec == original_vec -def test_qbit_roundtrip_float32(test_client: Client, table_context: Callable): +def test_qbit_roundtrip_float32(param_client: Client, call, table_context: Callable): """Test QBit(Float32) round-trip accuracy""" with table_context("vectors_f32", ["id Int32", "vec QBit(Float32, 8)"]): @@ -57,10 +57,10 @@ def test_qbit_roundtrip_float32(test_client: Client, table_context: Callable): (3, [-1.5, -2.5, -3.5, -4.5, -5.5, -6.5, -7.5, -8.5]), ] - test_client.insert("vectors_f32", test_data) + call(param_client.insert, "vectors_f32", test_data) for id_val, original_vec in test_data: - result = test_client.query("SELECT vec FROM vectors_f32 WHERE id = %(id)s", parameters={"id": id_val}) + result = call(param_client.query, "SELECT vec FROM vectors_f32 WHERE id = %(id)s", parameters={"id": id_val}) retrieved_vec = result.result_set[0][0] assert isinstance(retrieved_vec, list) @@ -68,7 +68,7 @@ def test_qbit_roundtrip_float32(test_client: Client, table_context: Callable): assert retrieved_vec == pytest.approx(original_vec, rel=1e-6) -def test_qbit_roundtrip_bfloat16(test_client: Client, table_context: Callable): +def test_qbit_roundtrip_bfloat16(param_client: Client, call, table_context: Callable): """Test QBit(BFloat16) round-trip with appropriate tolerance""" with table_context("vectors_bf16", ["id Int32", "vec QBit(BFloat16, 8)"]): @@ -77,10 +77,10 @@ def test_qbit_roundtrip_bfloat16(test_client: Client, table_context: Callable): (2, [0.5, 1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5]), ] - test_client.insert("vectors_bf16", test_data) + call(param_client.insert, "vectors_bf16", test_data) for id_val, original_vec in test_data: - result = test_client.query("SELECT vec FROM vectors_bf16 WHERE id = %(id)s", parameters={"id": id_val}) + result = call(param_client.query, "SELECT vec FROM vectors_bf16 WHERE id = %(id)s", parameters={"id": id_val}) retrieved_vec = result.result_set[0][0] assert isinstance(retrieved_vec, list) @@ -88,7 +88,7 @@ def test_qbit_roundtrip_bfloat16(test_client: Client, table_context: Callable): assert retrieved_vec == pytest.approx(original_vec, rel=1e-2, abs=1e-2) -def test_qbit_distance_search(test_client: Client, table_context: Callable): +def test_qbit_distance_search(param_client: Client, call, table_context: Callable): """Test L2DistanceTransposed with different precision levels""" with table_context("fruit_animal", ["word String", "vec QBit(Float64, 5)"]): @@ -101,13 +101,13 @@ def test_qbit_distance_search(test_client: Client, table_context: Callable): ("horse", [-0.61435682, 0.4851571, 1.21091247, -0.62530446, -1.33082533]), ] - test_client.insert("fruit_animal", test_data) + call(param_client.insert, "fruit_animal", test_data) # Search for "lemon" vector lemon_vector = [-0.88693672, 1.31532824, -0.51182908, -0.99652702, 0.59907770] # Full precision search (64-bit) - full_precision = test_client.query( + full_precision = call(param_client.query, """ SELECT word, L2DistanceTransposed(vec, %(lemon)s, 64) AS distance FROM fruit_animal @@ -121,7 +121,7 @@ def test_qbit_distance_search(test_client: Client, table_context: Callable): assert apple_distance == pytest.approx(0.1464, abs=1e-3) # Reduced precision search - reduced_precision = test_client.query( + reduced_precision = call(param_client.query, """ SELECT word, L2DistanceTransposed(vec, %(lemon)s, 12) AS distance FROM fruit_animal @@ -134,7 +134,7 @@ def test_qbit_distance_search(test_client: Client, table_context: Callable): assert reduced_precision.result_set[0][1] > 0 -def test_qbit_batch_insert(test_client: Client, table_context: Callable): +def test_qbit_batch_insert(param_client: Client, call, table_context: Callable): """Test batch insert with multiple vectors""" dimension = 16 @@ -147,22 +147,22 @@ def test_qbit_batch_insert(test_client: Client, table_context: Callable): vector = [random.uniform(-1.0, 1.0) for _ in range(dimension)] batch_data.append((i, vector)) - test_client.insert("embeddings", batch_data) + call(param_client.insert, "embeddings", batch_data) - count = test_client.query("SELECT COUNT(*) FROM embeddings").result_set[0][0] + count = call(param_client.query, "SELECT COUNT(*) FROM embeddings").result_set[0][0] assert count == 100 # Spot check a few for test_id in [0, 50, 99]: original_vec = batch_data[test_id][1] - result = test_client.query("SELECT embedding FROM embeddings WHERE id = %(id)s", parameters={"id": test_id}) + result = call(param_client.query, "SELECT embedding FROM embeddings WHERE id = %(id)s", parameters={"id": test_id}) retrieved_vec = result.result_set[0][0] assert len(retrieved_vec) == dimension assert retrieved_vec == pytest.approx(original_vec, rel=1e-6) -def test_qbit_null_handling(test_client: Client, table_context: Callable): +def test_qbit_null_handling(param_client: Client, call, table_context: Callable): """Test QBit with NULL values using Nullable wrapper""" with table_context("nullable_vecs", ["id Int32", "vec Nullable(QBit(Float32, 4))"]): @@ -172,49 +172,49 @@ def test_qbit_null_handling(test_client: Client, table_context: Callable): (3, [5.0, 6.0, 7.0, 8.0]), ] - test_client.insert("nullable_vecs", test_data) + call(param_client.insert, "nullable_vecs", test_data) - result = test_client.query("SELECT id, vec FROM nullable_vecs ORDER BY id") + result = call(param_client.query, "SELECT id, vec FROM nullable_vecs ORDER BY id") assert result.result_set[0][1] == pytest.approx([1.0, 2.0, 3.0, 4.0]) assert result.result_set[1][1] is None assert result.result_set[2][1] == pytest.approx([5.0, 6.0, 7.0, 8.0]) -def test_qbit_dimension_mismatch_error(test_client: Client, table_context: Callable): +def test_qbit_dimension_mismatch_error(param_client: Client, call, table_context: Callable): """Test that inserting vectors with wrong dimensions raises an error""" with table_context("dim_test", ["id Int32", "vec QBit(Float32, 8)"]): wrong_data = [(1, [1.0, 2.0, 3.0, 4.0, 5.0])] with pytest.raises(ValueError) as exc_info: - test_client.insert("dim_test", wrong_data) + call(param_client.insert, "dim_test", wrong_data) assert "dimension mismatch" in str(exc_info.value).lower() -def test_qbit_empty_insert(test_client: Client, table_context: Callable): +def test_qbit_empty_insert(param_client: Client, call, table_context: Callable): """Test inserting an empty list (no rows)""" with table_context("empty_test", ["id Int32", "vec QBit(Float32, 4)"]): - test_client.insert("empty_test", []) - result = test_client.query("SELECT COUNT(*) FROM empty_test") + call(param_client.insert, "empty_test", []) + result = call(param_client.query, "SELECT COUNT(*) FROM empty_test") assert result.result_set[0][0] == 0 -def test_qbit_single_row(test_client: Client, table_context: Callable): +def test_qbit_single_row(param_client: Client, call, table_context: Callable): """Test inserting a single row""" with table_context("single_row", ["id Int32", "vec QBit(Float32, 4)"]): single_data = [(1, [1.0, 2.0, 3.0, 4.0])] - test_client.insert("single_row", single_data) + call(param_client.insert, "single_row", single_data) - result = test_client.query("SELECT id, vec FROM single_row") + result = call(param_client.query, "SELECT id, vec FROM single_row") assert len(result.result_set) == 1 assert result.result_set[0][0] == 1 assert result.result_set[0][1] == pytest.approx([1.0, 2.0, 3.0, 4.0]) -def test_qbit_special_float_values(test_client: Client, table_context: Callable): +def test_qbit_special_float_values(param_client: Client, call, table_context: Callable): """Test QBit with special float values (inf, -inf, nan)""" with table_context("special_floats", ["id Int32", "vec QBit(Float64, 4)"]): @@ -225,8 +225,8 @@ def test_qbit_special_float_values(test_client: Client, table_context: Callable) (4, [0.0, -0.0, 1.0, -1.0]), ] - test_client.insert("special_floats", test_data) - result = test_client.query("SELECT id, vec FROM special_floats ORDER BY id") + call(param_client.insert, "special_floats", test_data) + result = call(param_client.query, "SELECT id, vec FROM special_floats ORDER BY id") assert result.result_set[0][1][0] == float("inf") assert result.result_set[0][1][1] == pytest.approx(1.0) @@ -241,68 +241,69 @@ def test_qbit_special_float_values(test_client: Client, table_context: Callable) assert result.result_set[3][1][1] == -0.0 or result.result_set[3][1][1] == 0.0 -def test_qbit_edge_case_dimensions(test_client: Client, table_context: Callable): +def test_qbit_edge_case_dimensions(param_client: Client, call, table_context: Callable): """Test QBit with edge case dimensions (1, not multiple of 8)""" with table_context("dim_one", ["id Int32", "vec QBit(Float32, 1)"]): test_data = [(1, [1.0]), (2, [3.14])] - test_client.insert("dim_one", test_data) + call(param_client.insert, "dim_one", test_data) - result = test_client.query("SELECT vec FROM dim_one ORDER BY id") + result = call(param_client.query, "SELECT vec FROM dim_one ORDER BY id") assert result.result_set[0][0] == pytest.approx([1.0]) assert result.result_set[1][0] == pytest.approx([3.14]) with table_context("dim_five", ["id Int32", "vec QBit(Float32, 5)"]): test_data = [(1, [1.0, 2.0, 3.0, 4.0, 5.0])] - test_client.insert("dim_five", test_data) + call(param_client.insert, "dim_five", test_data) - result = test_client.query("SELECT vec FROM dim_five") + result = call(param_client.query, "SELECT vec FROM dim_five") assert result.result_set[0][0] == pytest.approx([1.0, 2.0, 3.0, 4.0, 5.0]) -def test_qbit_very_large_batch(test_client: Client, table_context: Callable): +def test_qbit_very_large_batch(param_client: Client, call, table_context: Callable): """Test inserting a very large batch of vectors (1000 rows)""" with table_context("large_batch", ["id Int32", "vec QBit(Float32, 8)"]): random.seed(1) large_batch = [(i, [random.uniform(-10, 10) for _ in range(8)]) for i in range(1000)] - test_client.insert("large_batch", large_batch) + call(param_client.insert, "large_batch", large_batch) - count = test_client.query("SELECT COUNT(*) FROM large_batch").result_set[0][0] + count = call(param_client.query, "SELECT COUNT(*) FROM large_batch").result_set[0][0] assert count == 1000 for check_id in [0, 500, 999]: original_vec = large_batch[check_id][1] - result = test_client.query("SELECT vec FROM large_batch WHERE id = %(id)s", parameters={"id": check_id}) + result = call(param_client.query, "SELECT vec FROM large_batch WHERE id = %(id)s", parameters={"id": check_id}) retrieved_vec = result.result_set[0][0] assert retrieved_vec == pytest.approx(original_vec, rel=1e-6) -def test_qbit_all_nulls(test_client: Client, table_context: Callable): +def test_qbit_all_nulls(param_client: Client, call, table_context: Callable): """Test QBit nullable column with all NULL values""" with table_context("all_nulls", ["id Int32", "vec Nullable(QBit(Float32, 4))"]): test_data = [(1, None), (2, None), (3, None)] - test_client.insert("all_nulls", test_data) + call(param_client.insert, "all_nulls", test_data) - result = test_client.query("SELECT vec FROM all_nulls ORDER BY id") + result = call(param_client.query, "SELECT vec FROM all_nulls ORDER BY id") assert all(row[0] is None for row in result.result_set) -def test_qbit_all_zeros(test_client: Client, table_context: Callable): +def test_qbit_all_zeros(param_client: Client, call, table_context: Callable): """Test QBit with all zero vectors""" with table_context("all_zeros", ["id Int32", "vec QBit(Float32, 4)"]): test_data = [(1, [0.0, 0.0, 0.0, 0.0]), (2, [0.0, 0.0, 0.0, 0.0])] - test_client.insert("all_zeros", test_data) + call(param_client.insert, "all_zeros", test_data) - result = test_client.query("SELECT vec FROM all_zeros ORDER BY id") + result = call(param_client.query, "SELECT vec FROM all_zeros ORDER BY id") assert result.result_set[0][0] == pytest.approx([0.0, 0.0, 0.0, 0.0]) assert result.result_set[1][0] == pytest.approx([0.0, 0.0, 0.0, 0.0]) -def test_invalid_dimension(table_context: Callable): +# pylint: disable=unused-argument +def test_invalid_dimension(param_client: Client, call, table_context: Callable): """Try creating a column with a negative dimension.""" with pytest.raises(DatabaseError): @@ -310,7 +311,8 @@ def test_invalid_dimension(table_context: Callable): pass -def test_invalid_element_type(table_context: Callable): +# pylint: disable=unused-argument +def test_invalid_element_type(param_client: Client, call, table_context: Callable): """Try creating a column with an invalid element type.""" with pytest.raises(DatabaseError): diff --git a/tests/test_requirements.txt b/tests/test_requirements.txt index 0b4510d9..79a85386 100644 --- a/tests/test_requirements.txt +++ b/tests/test_requirements.txt @@ -2,6 +2,7 @@ pytz urllib3>=1.26 setuptools certifi +aiohttp>=3.8.0 sqlalchemy>=2.0,<3.0 cython==3.0.11 pyarrow @@ -10,6 +11,7 @@ pytest-asyncio pytest-mock pytest-dotenv pytest-cov +pytest-xdist numpy~=1.22.0; python_version >= '3.8' and python_version <= '3.10' numpy~=1.26.0; python_version >= '3.11' and python_version <= '3.12' numpy~=2.1.0; python_version == '3.13' diff --git a/tests/unit_tests/test_asyncqueue.py b/tests/unit_tests/test_asyncqueue.py new file mode 100644 index 00000000..b675b000 --- /dev/null +++ b/tests/unit_tests/test_asyncqueue.py @@ -0,0 +1,297 @@ +import asyncio +import threading +import time + +import pytest + +from clickhouse_connect.driver.asyncqueue import EOF_SENTINEL, AsyncSyncQueue, Empty + +# pylint: disable=broad-exception-caught + + +def test_async_put_sync_get(): + """Test async producer putting items, sync consumer getting them.""" + queue = AsyncSyncQueue(maxsize=5) + items_received = [] + + async def async_producer(): + """Put items from async context.""" + for i in range(10): + await queue.async_q.put(f"item_{i}") + await queue.async_q.put(EOF_SENTINEL) + + def sync_consumer(): + """Get items from sync context.""" + while True: + item = queue.sync_q.get() + if item is EOF_SENTINEL: + break + items_received.append(item) + + async def run_test(): + consumer_thread = threading.Thread(target=sync_consumer) + consumer_thread.start() + + await async_producer() + + consumer_thread.join(timeout=5.0) + assert not consumer_thread.is_alive(), "Consumer thread hung" + + asyncio.run(run_test()) + + assert len(items_received) == 10 + assert items_received == [f"item_{i}" for i in range(10)] + + +def test_sync_put_async_get(): + """Test sync producer putting items, async consumer getting them.""" + queue = AsyncSyncQueue(maxsize=5) + items_received = [] + + def sync_producer(): + """Put items from sync context.""" + for i in range(10): + queue.sync_q.put(f"item_{i}") + queue.sync_q.put(EOF_SENTINEL) + + async def async_consumer(): + """Get items from async context.""" + while True: + item = await queue.async_q.get() + if item is EOF_SENTINEL: + break + items_received.append(item) + + async def run_test(): + producer_thread = threading.Thread(target=sync_producer) + producer_thread.start() + + await async_consumer() + + producer_thread.join(timeout=5.0) + assert not producer_thread.is_alive(), "Producer thread hung" + + asyncio.run(run_test()) + + assert len(items_received) == 10 + assert items_received == [f"item_{i}" for i in range(10)] + + +def test_backpressure_async_producer(): + """Test that bounded queue provides backpressure to async producer.""" + queue = AsyncSyncQueue(maxsize=3) + produced = [] + consumed = [] + + async def fast_producer(): + """Producer that tries to produce faster than consumer.""" + for i in range(10): + produced.append(f"before_put_{i}") + await queue.async_q.put(f"item_{i}") + produced.append(f"after_put_{i}") + await queue.async_q.put(EOF_SENTINEL) + + def slow_consumer(): + """Consumer that's slower than producer.""" + while True: + time.sleep(0.01) + item = queue.sync_q.get() + if item is EOF_SENTINEL: + break + consumed.append(item) + + async def run_test(): + consumer_thread = threading.Thread(target=slow_consumer) + consumer_thread.start() + + await fast_producer() + + consumer_thread.join(timeout=5.0) + assert not consumer_thread.is_alive() + + asyncio.run(run_test()) + + assert len(consumed) == 10 + assert consumed == [f"item_{i}" for i in range(10)] + + +def test_backpressure_sync_producer(): + """Test that bounded queue provides backpressure to sync producer.""" + queue = AsyncSyncQueue(maxsize=3) + produced = [] + consumed = [] + + def fast_producer(): + """Producer that tries to produce faster than consumer.""" + for i in range(10): + produced.append(f"before_put_{i}") + queue.sync_q.put(f"item_{i}") + produced.append(f"after_put_{i}") + queue.sync_q.put(EOF_SENTINEL) + + async def slow_consumer(): + """Consumer that's slower than producer.""" + while True: + await asyncio.sleep(0.01) + item = await queue.async_q.get() + if item is EOF_SENTINEL: + break + consumed.append(item) + + async def run_test(): + producer_thread = threading.Thread(target=fast_producer) + producer_thread.start() + + await slow_consumer() + + producer_thread.join(timeout=5.0) + assert not producer_thread.is_alive() + + asyncio.run(run_test()) + + assert len(consumed) == 10 + assert consumed == [f"item_{i}" for i in range(10)] + + +def test_shutdown_unblocks_consumer(): + """Test that shutdown() unblocks a consumer waiting on an empty queue.""" + queue = AsyncSyncQueue(maxsize=2) + consumer_unblocked = threading.Event() + + def blocking_consumer(): + """Consumer that will block waiting for items.""" + try: + item = queue.sync_q.get(timeout=2.0) + if item is EOF_SENTINEL: + consumer_unblocked.set() + except Exception: + pass + + async def run_test(): + consumer_thread = threading.Thread(target=blocking_consumer) + consumer_thread.start() + + await asyncio.sleep(0.1) + + queue.shutdown() + + consumer_thread.join(timeout=2.0) + assert consumer_unblocked.is_set(), "Consumer was not unblocked by shutdown" + + asyncio.run(run_test()) + + +def test_shutdown_unblocks_producer(): + """Test that shutdown() unblocks a producer waiting on a full queue.""" + queue = AsyncSyncQueue(maxsize=2) + producer_unblocked = threading.Event() + + async def blocking_producer(): + """Producer that will block when queue is full.""" + try: + await queue.async_q.put("item1") + await queue.async_q.put("item2") + + await asyncio.wait_for(queue.async_q.put("item3"), timeout=2.0) + except (RuntimeError, asyncio.TimeoutError): + producer_unblocked.set() + except Exception as e: + print(f"Producer caught unexpected exception: {e}") + + async def run_test(): + producer_task = asyncio.create_task(blocking_producer()) + + await asyncio.sleep(0.1) + + queue.shutdown() + + await producer_task + assert producer_unblocked.is_set(), "Producer was not unblocked by shutdown" + + asyncio.run(run_test()) + + +def test_multiple_producers_single_consumer(): + """Test multiple async producers with single sync consumer.""" + queue = AsyncSyncQueue(maxsize=10) + items_received = [] + + async def producer(producer_id, count): + """Producer that sends count items.""" + for i in range(count): + await queue.async_q.put(f"p{producer_id}_item{i}") + + def consumer(): + """Consumer that reads until getting 30 items (3 producers × 10 items).""" + received = 0 + while received < 30: + item = queue.sync_q.get(timeout=5.0) + items_received.append(item) + received += 1 + + async def run_test(): + consumer_thread = threading.Thread(target=consumer) + consumer_thread.start() + + await asyncio.gather(producer(0, 10), producer(1, 10), producer(2, 10)) + + consumer_thread.join(timeout=5.0) + assert not consumer_thread.is_alive() + + asyncio.run(run_test()) + + assert len(items_received) == 30 + assert len(set(items_received)) == 30 + + +def test_exception_propagation(): + """Test that exceptions can be passed through the queue.""" + queue = AsyncSyncQueue(maxsize=5) + exception_received = [] + + async def producer_with_error(): + """Producer that sends an exception.""" + await queue.async_q.put("item1") + await queue.async_q.put("item2") + await queue.async_q.put(ValueError("test error")) + await queue.async_q.put(EOF_SENTINEL) + + def consumer(): + """Consumer that should receive the exception.""" + items = [] + while True: + item = queue.sync_q.get() + if item is EOF_SENTINEL: + break + if isinstance(item, Exception): + exception_received.append(item) + else: + items.append(item) + return items + + async def run_test(): + consumer_thread = threading.Thread(target=consumer) + consumer_thread.start() + + await producer_with_error() + + consumer_thread.join(timeout=5.0) + assert not consumer_thread.is_alive() + + asyncio.run(run_test()) + + assert len(exception_received) == 1 + assert isinstance(exception_received[0], ValueError) + assert str(exception_received[0]) == "test error" + + +def test_empty_exception_on_non_blocking_get(): + """Test that non-blocking get raises Empty when queue is empty.""" + queue = AsyncSyncQueue(maxsize=5) + + with pytest.raises(Empty): + queue.sync_q.get(block=False) + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git a/tests/unit_tests/test_streaming_source.py b/tests/unit_tests/test_streaming_source.py new file mode 100644 index 00000000..06c62ec5 --- /dev/null +++ b/tests/unit_tests/test_streaming_source.py @@ -0,0 +1,444 @@ +import asyncio +import gzip +import time +import zlib +from unittest.mock import Mock + +import lz4.frame +import pytest +import zstandard + +from clickhouse_connect.driver.exceptions import OperationalError +from clickhouse_connect.driver.streaming import ( + StreamingInsertSource, + StreamingResponseSource, +) + + +class MockAsyncIterator: + """Mock async iterator for simulating aiohttp response content.""" + + def __init__(self, chunks): + self.chunks = chunks + self.index = 0 + + def __aiter__(self): + return self + + async def __anext__(self): + if self.index >= len(self.chunks): + raise StopAsyncIteration + chunk = self.chunks[self.index] + self.index += 1 + return chunk + + +class MockContent: + """Mock aiohttp StreamReader content.""" + + def __init__(self, chunks): + self.chunks = chunks + self.index = 0 + + async def read(self, n=-1): # pylint: disable=unused-argument + """Mock read method that returns chunks sequentially.""" + if self.index >= len(self.chunks): + return b"" + chunk = self.chunks[self.index] + self.index += 1 + return chunk + + +class MockResponse: + """Mock aiohttp ClientResponse.""" + + def __init__(self, chunks, encoding=None): + self.content = MockContent(chunks) + self.headers = {"Content-Encoding": encoding} if encoding else {} + self.status = 200 + self.closed = False + + def close(self): + self.closed = True + + +@pytest.mark.asyncio +async def test_basic_streaming_no_compression(): + """Test basic streaming without compression.""" + chunks = [b"hello ", b"world", b"!"] + response = MockResponse(chunks) + + source = StreamingResponseSource(response, encoding=None) + loop = asyncio.get_running_loop() + + await source.start_producer(loop) + + def consume(): + result = [] + for chunk in source.gen: + result.append(chunk) + return result + + result = await loop.run_in_executor(None, consume) + + assert result == chunks + assert b"".join(result) == b"hello world!" + + +@pytest.mark.asyncio +async def test_streaming_with_gzip_compression(): + """Test streaming with gzip decompression.""" + original_data = b"hello world! " * 1000 + compressed = gzip.compress(original_data) + chunk_size = 100 + chunks = [compressed[i : i + chunk_size] for i in range(0, len(compressed), chunk_size)] + + response = MockResponse(chunks, encoding="gzip") + source = StreamingResponseSource(response, encoding="gzip") + loop = asyncio.get_running_loop() + + await source.start_producer(loop) + + def consume(): + result = [] + for chunk in source.gen: + result.append(chunk) + return b"".join(result) + + decompressed = await loop.run_in_executor(None, consume) + + assert decompressed == original_data + + +@pytest.mark.asyncio +async def test_streaming_with_deflate_compression(): + """Test streaming with deflate decompression.""" + original_data = b"test data " * 500 + compressed = zlib.compress(original_data) + + chunks = [compressed[i : i + 50] for i in range(0, len(compressed), 50)] + + response = MockResponse(chunks, encoding="deflate") + source = StreamingResponseSource(response, encoding="deflate") + loop = asyncio.get_running_loop() + + await source.start_producer(loop) + + def consume(): + result = [] + for chunk in source.gen: + result.append(chunk) + return b"".join(result) + + decompressed = await loop.run_in_executor(None, consume) + + assert decompressed == original_data + + +@pytest.mark.asyncio +async def test_streaming_with_zstd_compression(): + """Test streaming with zstd decompression.""" + original_data = b"zstd test data " * 500 + compressor = zstandard.ZstdCompressor() + compressed = compressor.compress(original_data) + + chunks = [compressed[i : i + 50] for i in range(0, len(compressed), 50)] + + response = MockResponse(chunks, encoding="zstd") + source = StreamingResponseSource(response, encoding="zstd") + loop = asyncio.get_running_loop() + + await source.start_producer(loop) + + def consume(): + result = [] + for chunk in source.gen: + result.append(chunk) + return b"".join(result) + + decompressed = await loop.run_in_executor(None, consume) + + assert decompressed == original_data + + +@pytest.mark.asyncio +async def test_streaming_with_lz4_compression(): + """Test streaming with lz4 decompression.""" + original_data = b"lz4 test data " * 500 + compressed = lz4.frame.compress(original_data) + + chunks = [compressed[i : i + 50] for i in range(0, len(compressed), 50)] + + response = MockResponse(chunks, encoding="lz4") + source = StreamingResponseSource(response, encoding="lz4") + loop = asyncio.get_running_loop() + + await source.start_producer(loop) + + def consume(): + result = [] + for chunk in source.gen: + result.append(chunk) + return b"".join(result) + + decompressed = await loop.run_in_executor(None, consume) + + assert decompressed == original_data + + +@pytest.mark.asyncio +async def test_empty_stream(): + """Test streaming with empty response.""" + response = MockResponse([]) + source = StreamingResponseSource(response, encoding=None) + loop = asyncio.get_running_loop() + + await source.start_producer(loop) + + def consume(): + result = [] + for chunk in source.gen: + result.append(chunk) + return result + + result = await loop.run_in_executor(None, consume) + + assert result == [] + + +@pytest.mark.asyncio +async def test_single_large_chunk(): + """Test streaming with single large chunk.""" + large_chunk = b"x" * 1000000 + response = MockResponse([large_chunk]) + source = StreamingResponseSource(response, encoding=None) + loop = asyncio.get_running_loop() + + await source.start_producer(loop) + + def consume(): + result = [] + for chunk in source.gen: + result.append(chunk) + return result + + result = await loop.run_in_executor(None, consume) + + assert len(result) == 1 + assert result[0] == large_chunk + + +@pytest.mark.asyncio +async def test_many_small_chunks(): + """Test streaming with many small chunks.""" + chunks = [f"chunk{i}".encode() for i in range(1000)] + response = MockResponse(chunks) + source = StreamingResponseSource(response, encoding=None) + loop = asyncio.get_running_loop() + + await source.start_producer(loop) + + def consume(): + result = [] + for chunk in source.gen: + result.append(chunk) + return result + + result = await loop.run_in_executor(None, consume) + + assert len(result) == 1000 + assert result == chunks + + +@pytest.mark.asyncio +async def test_generator_caching(): + """Test that .gen property returns cached generator.""" + response = MockResponse([b"test"]) + source = StreamingResponseSource(response, encoding=None) + loop = asyncio.get_running_loop() + + await source.start_producer(loop) + + # Access .gen multiple times - should return same generator + gen1 = source.gen + gen2 = source.gen + + assert gen1 is gen2, "Generator should be cached" + + +@pytest.mark.asyncio +async def test_producer_error_propagation(): + """Test that producer errors are propagated to consumer.""" + + class FailingContent: + @staticmethod + async def read(n=-1): + raise ValueError("Producer error!") + + response = Mock() + response.content = FailingContent() + response.headers = {} + response.closed = False + + source = StreamingResponseSource(response, encoding=None) + loop = asyncio.get_running_loop() + + await source.start_producer(loop) + + def consume(): + try: + for _ in source.gen: + pass + except OperationalError as e: + return str(e) + return "No error raised!" + + error_msg = await loop.run_in_executor(None, consume) + + assert error_msg == "Failed to read response data from server" + + +@pytest.mark.asyncio +async def test_gzip_with_incremental_decompression(): + """Test that gzip decompression works incrementally with streaming.""" + original_data = b"The quick brown fox jumps over the lazy dog. " * 100 + compressed = gzip.compress(original_data) + + # Split compressed data into very small chunks to force incremental decompression + chunks = [compressed[i : i + 10] for i in range(0, len(compressed), 10)] + + response = MockResponse(chunks, encoding="gzip") + source = StreamingResponseSource(response, encoding="gzip") + loop = asyncio.get_running_loop() + + await source.start_producer(loop) + + def consume(): + """Consume and verify we get multiple decompressed chunks.""" + chunks_received = [] + for chunk in source.gen: + chunks_received.append(chunk) + return chunks_received, b"".join(chunks_received) + + chunks_received, decompressed = await loop.run_in_executor(None, consume) + + assert decompressed == original_data + assert len([c for c in chunks_received if c]) > 0 + + +@pytest.mark.asyncio +async def test_backpressure_with_bounded_queue(): + """Test that bounded queue provides backpressure.""" + # Create many chunks to test backpressure + chunks = [f"chunk{i}".encode() for i in range(100)] + response = MockResponse(chunks) + + source = StreamingResponseSource(response, encoding=None) + loop = asyncio.get_running_loop() + + await source.start_producer(loop) + + # Slow consumer + def slow_consume(): + result = [] + for chunk in source.gen: + time.sleep(0.001) + result.append(chunk) + return result + + result = await loop.run_in_executor(None, slow_consume) + + # All chunks should still be received despite slow consumer + assert len(result) == 100 + assert result == chunks + + +class MockTransform: + """Mock NativeTransform.""" + + def __init__(self, chunks=None): + self.chunks = chunks or [b"chunk1", b"chunk2"] + + def build_insert(self, context): # pylint: disable=unused-argument + yield from self.chunks + + +class FailingTransform: + """Mock NativeTransform that raises error.""" + + @staticmethod + def build_insert(context): # pylint: disable=unused-argument + yield b"chunk1" + raise ValueError("Serialization error") + + +class MockContext: + """Mock InsertContext.""" + + +@pytest.mark.asyncio +async def test_streaming_insert_basic(): + """Test basic streaming insert (reverse bridge).""" + transform = MockTransform() + context = MockContext() + loop = asyncio.get_running_loop() + + source = StreamingInsertSource(transform, context, loop) + source.start_producer() + + chunks = [] + async for chunk in source.async_generator(): + chunks.append(chunk) + + await source.close() + + assert chunks == [b"chunk1", b"chunk2"] + + +@pytest.mark.asyncio +async def test_streaming_insert_error_propagation(): + """Test that insert producer errors are propagated to async consumer.""" + transform = FailingTransform() + context = MockContext() + loop = asyncio.get_running_loop() + + source = StreamingInsertSource(transform, context, loop) + source.start_producer() + + chunks = [] + with pytest.raises(ValueError, match="Serialization error"): + async for chunk in source.async_generator(): + chunks.append(chunk) + + await source.close() + + # Should have received first chunk before error + assert chunks == [b"chunk1"] + + +@pytest.mark.asyncio +async def test_streaming_insert_backpressure(): + """Test backpressure in streaming insert.""" + chunks = [f"chunk{i}".encode() for i in range(100)] + transform = MockTransform(chunks) + context = MockContext() + loop = asyncio.get_running_loop() + + # Small queue size to force backpressure + source = StreamingInsertSource(transform, context, loop, maxsize=2) + source.start_producer() + + received = [] + async for chunk in source.async_generator(): + received.append(chunk) + # Yield to allow producer to run (since we're in same loop/process) + await asyncio.sleep(0.001) + + await source.close() + + assert len(received) == 100 + assert received == chunks + + +if __name__ == "__main__": + pytest.main([__file__, "-v"])