diff --git a/CHANGELOG.md b/CHANGELOG.md index 0223930b..cd92589c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -24,9 +24,14 @@ The supported method of passing ClickHouse server settings is to prefix such arg ## UNRELEASED ### New Features +- SQLAlchemy: Comprehensive ClickHouse JOIN support via the new `ch_join()` helper. All strictness modifiers (`ALL`, `ANY`, `SEMI`, `ANTI`, `ASOF`), the `GLOBAL` distribution modifier, and explicit `CROSS JOIN` are now available. Use with `select_from()` to generate ClickHouse-specific join syntax like `GLOBAL ALL LEFT OUTER JOIN`. Closes [#635](https://github.com/ClickHouse/clickhouse-connect/issues/635) +- SQLAlchemy: `array_join()` now supports multiple columns for parallel array expansion. Pass a list of columns and a matching list of aliases to generate `ARRAY JOIN col1 AS a, col2 AS b, col3 AS c`. Single-column usage is unchanged. Closes [#633](https://github.com/ClickHouse/clickhouse-connect/issues/633) +- SQLAlchemy: `ch_join()` now supports `USING` syntax via the new `using` parameter. Pass a list of column name strings to generate `USING (col1, col2)` instead of `ON`. This is important for `FULL OUTER JOIN` where `USING` merges the join column correctly while `ON` produces default values (0, '') for unmatched sides. Closes [#636](https://github.com/ClickHouse/clickhouse-connect/issues/636) - SQLAlchemy: Add missing Replicated table engine variants: `ReplicatedReplacingMergeTree`, `ReplicatedCollapsingMergeTree`, `ReplicatedVersionedCollapsingMergeTree`, and `ReplicatedGraphiteMergeTree`. Closes [#687](https://github.com/ClickHouse/clickhouse-connect/issues/687) ### Bug Fixes +- SQLAlchemy: Fix `.final()` and `.sample()` silently overwriting each other when chained. Both methods now store modifiers as custom attributes on the `Select` instance and render them during compilation, replacing the previous `with_hint()` approach that only allowed one hint per table. Chaining in either order (e.g. 
`select(t).final().sample(0.1)`) correctly produces `FROM t FINAL SAMPLE 0.1`. Also fixes rendering for aliased tables (`FROM t AS u FINAL`) and supports explicit table targeting in joins. Fixes [#658](https://github.com/ClickHouse/clickhouse-connect/issues/658) +- SQLAlchemy: Fix `sqlalchemy.values()` to generate ClickHouse's `VALUES` table function syntax. The compiler now emits `VALUES('col1 Type1, col2 Type2', ...)` with the column structure as the first argument, instead of the standard SQL form that places column names after the alias. Generic SQLAlchemy types are mapped to ClickHouse equivalents (e.g. `Integer` to `Int32`, `String` to `String`). Also handles CTE usage by wrapping in `SELECT * FROM VALUES(...)`. Fixes [#681](https://github.com/ClickHouse/clickhouse-connect/issues/681) - SQLAlchemy: Fix `GraphiteMergeTree` and `ReplicatedGraphiteMergeTree` to properly single-quote the `config_section` argument as ClickHouse requires. ## 0.14.1, 2026-03-11 diff --git a/README.md b/README.md index 8b7288b9..31a9d877 100644 --- a/README.md +++ b/README.md @@ -32,16 +32,23 @@ When creating a Superset Data Source, either use the provided connection dialog, ### SQLAlchemy Implementation ClickHouse Connect includes a lightweight SQLAlchemy dialect implementation focused on compatibility with **Superset** -and **SQLAlchemy Core**. +and **SQLAlchemy Core**. Both SQLAlchemy 1.4 and 2.x are supported. SQLAlchemy 1.4 compatibility is maintained +because Apache Superset currently requires `sqlalchemy>=1.4,<2`. 
Supported features include: - Basic query execution via SQLAlchemy Core -- `SELECT` queries with `JOIN`s, `ARRAY JOIN`, and `FINAL` modifier +- `SELECT` queries with `JOIN`s (including ClickHouse-specific strictness, `USING`, and `GLOBAL` modifiers), + `ARRAY JOIN` (single and multi-column), `FINAL`, and `SAMPLE` +- `VALUES` table function syntax - Lightweight `DELETE` statements -The implementation does not include ORM support and is not intended as a full SQLAlchemy dialect. While it can support -a range of Core-based applications beyond Superset, it may not be suitable for more complex SQLAlchemy applications -that rely on full ORM or advanced dialect functionality. +A small number of features require SQLAlchemy 2.x: `Values.cte()` and certain literal-rendering behaviors. +All other dialect features, including those used by Superset, work on both 1.4 and 2.x. + +Basic ORM usage works for insert-heavy, read-focused workloads: declarative model definitions, `CREATE TABLE`, +`session.add()`, `bulk_save_objects()`, and read queries all function correctly. However, full ORM support is not +provided. UPDATE compilation, foreign key/relationship reflection, autoincrement/RETURNING, and cascade operations +are not implemented. The dialect is best suited for SQLAlchemy Core usage and Superset connectivity. 
### Asyncio Support diff --git a/clickhouse_connect/cc_sqlalchemy/__init__.py b/clickhouse_connect/cc_sqlalchemy/__init__.py index 1e644119..5d505cf7 100644 --- a/clickhouse_connect/cc_sqlalchemy/__init__.py +++ b/clickhouse_connect/cc_sqlalchemy/__init__.py @@ -1,10 +1,10 @@ from clickhouse_connect import driver_name from clickhouse_connect.cc_sqlalchemy.datatypes.base import schema_types -from clickhouse_connect.cc_sqlalchemy.sql import final -from clickhouse_connect.cc_sqlalchemy.sql.clauses import array_join, ArrayJoin +from clickhouse_connect.cc_sqlalchemy.sql import final, sample +from clickhouse_connect.cc_sqlalchemy.sql.clauses import array_join, ArrayJoin, ch_join, ClickHouseJoin # pylint: disable=invalid-name dialect_name = driver_name ischema_names = schema_types -__all__ = ['dialect_name', 'ischema_names', 'array_join', 'ArrayJoin', 'final'] +__all__ = ['dialect_name', 'ischema_names', 'array_join', 'ArrayJoin', 'ch_join', 'ClickHouseJoin', 'final', 'sample'] diff --git a/clickhouse_connect/cc_sqlalchemy/sql/__init__.py b/clickhouse_connect/cc_sqlalchemy/sql/__init__.py index a61a1005..f115d34f 100644 --- a/clickhouse_connect/cc_sqlalchemy/sql/__init__.py +++ b/clickhouse_connect/cc_sqlalchemy/sql/__init__.py @@ -5,6 +5,10 @@ from clickhouse_connect.driver.binding import quote_identifier +# Dialect name used for non-rendering statement hints that only serve to +# differentiate cache keys when FINAL/SAMPLE modifiers are applied. +_CH_MODIFIER_DIALECT = "_ch_modifier" + def full_table(table_name: str, schema: Optional[str] = None) -> str: if table_name.startswith('(') or '.' in table_name or not schema: @@ -16,38 +20,61 @@ def format_table(table: Table): return full_table(table.name, table.schema) -def final(select_stmt: Select, table: Optional[FromClause] = None) -> Select: - """ - Apply the ClickHouse FINAL modifier to a select statement. - - Args: - select_stmt: The SQLAlchemy Select statement to modify. 
def _resolve_target(select_stmt: Select, table: Optional[FromClause], method_name: str) -> FromClause:
    """Resolve the FROM element a ClickHouse modifier (FINAL/SAMPLE) applies to.

    Args:
        select_stmt: The statement being modified.
        table: Explicit target, or None to use the statement's single FROM element.
        method_name: Name of the calling helper (e.g. "final" or "sample"); it is
            only used to build error messages.

    Returns:
        The FromClause the modifier should be attached to.

    Raises:
        TypeError: If ``select_stmt`` is not a Select, or the resolved target is
            not a FromClause.
        ValueError: If no target was given and the statement has zero or more
            than one FROM element, or if the resolved target is a join.
    """
    if not isinstance(select_stmt, Select):
        raise TypeError(f"{method_name}() expects a SQLAlchemy Select instance")

    target = table
    if target is None:
        froms = select_stmt.get_final_froms()
        if not froms:
            raise ValueError(f"{method_name}() requires a table to apply the {method_name.upper()} modifier.")
        if len(froms) > 1:
            raise ValueError(
                f"{method_name}() is ambiguous for statements with multiple FROM clauses. "
                "Specify the table explicitly."
            )
        target = froms[0]

    if not isinstance(target, FromClause):
        raise TypeError("table must be a SQLAlchemy FromClause when provided")

    # A statement built with select_from(<join>) exposes the join itself as its
    # single FROM element. A join is a FromClause but carries no table name, so
    # it cannot be the subject of FINAL/SAMPLE; fail with an actionable message
    # instead of an AttributeError further down the call chain.
    if getattr(target, "_is_join", False):
        raise ValueError(
            f"{method_name}() cannot be applied to a JOIN. Specify the table explicitly."
        )

    return target
def sample(select_stmt: Select, sample_value: Union[str, int, float], table: Optional[FromClause] = None) -> Select:
    """Return a copy of ``select_stmt`` carrying a ClickHouse SAMPLE modifier.

    Args:
        select_stmt: The SELECT statement to modify.
        sample_value: The sample expression: a float between 0 and 1 for a
            fractional sample (e.g. 0.1 for 10%), an integer for an approximate
            row count, or a string for expressions like '1/10 OFFSET 1/2'.
        table: The target table to sample. Required when the query joins
            multiple tables, optional when there is a single FROM target.

    Returns:
        A new Select whose ``_ch_sample`` mapping associates the target with
        ``sample_value``; mappings already present on ``select_stmt`` are kept.
    """
    target = _resolve_target(select_stmt, table, "sample")

    # Build a fresh mapping so the statement we were given is never mutated.
    samples = {**getattr(select_stmt, "_ch_sample", {}), target: sample_value}

    # The statement hint is registered under a dialect name that is never
    # rendered; it exists only so differently-sampled statements get
    # different cache keys.
    result = select_stmt.with_statement_hint(
        f"SAMPLE:{_target_cache_key(target)}:{sample_value}",
        dialect_name=_CH_MODIFIER_DIALECT,
    )
    result._ch_sample = samples
    return result
+ + +def _normalize_array_columns(array_column, alias): + """Normalize single/multi column input into a list of (column, alias_or_none) tuples.""" + if isinstance(array_column, (list, tuple)): + columns = list(array_column) + if not columns: + raise ValueError("At least one array column is required") + if alias is None: + aliases = [None] * len(columns) + elif isinstance(alias, (list, tuple)): + aliases = list(alias) + if len(aliases) != len(columns): + raise ValueError(f"Length of alias list ({len(aliases)}) must match " f"length of array_column list ({len(columns)})") + else: + raise ValueError("alias must be a list when array_column is a list") + else: + columns = [array_column] + if isinstance(alias, (list, tuple)): + raise ValueError("alias must be a string or None when array_column is a single column") + aliases = [alias] + + return list(zip(columns, aliases)) # pylint: disable=protected-access,too-many-ancestors,abstract-method,unused-argument class ArrayJoin(Immutable, FromClause): - """Represents ClickHouse ARRAY JOIN clause""" + """Represents ClickHouse ARRAY JOIN clause. + + Supports single or multiple array columns with optional per-column aliases. + Multiple columns are expanded in parallel (zipped by position), not as a + cartesian product. All arrays in a single ARRAY JOIN must have the same + length per row unless enable_unaligned_array_join is set on the server. + + See: https://clickhouse.com/docs/sql-reference/statements/select/array-join + """ __visit_name__ = "array_join" _is_from_container = True @@ -12,18 +47,19 @@ class ArrayJoin(Immutable, FromClause): _is_join = True def __init__(self, left, array_column, alias=None, is_left=False): - """Initialize ARRAY JOIN clause + """Initialize ARRAY JOIN clause. 
Args: - left: The left side (table or subquery) - array_column: The array column to join - alias: Optional alias for the joined array elements - is_left: If True, use LEFT ARRAY JOIN instead of ARRAY JOIN + left: The left side (table or subquery). + array_column: A single array column, or a list/tuple of array columns. + alias: Optional alias. A single string when array_column is a single + column, or a list/tuple of strings (same length as array_column) + when array_column is a list. None means no aliases. + is_left: If True, use LEFT ARRAY JOIN instead of ARRAY JOIN. """ super().__init__() self.left = left - self.array_column = array_column - self.alias = alias + self.array_columns = _normalize_array_columns(array_column, alias) self.is_left = is_left self._is_clone_of = None @@ -50,10 +86,10 @@ def _clone(self, **kw): return c def _copy_internals(self, clone=None, **kw): - """Copy internal state for cloning + """Copy internal state for cloning. This ensures that when queries are cloned (e.g., for subqueries, unions, or CTEs), - the left FromClause and array_column references are properly deep-cloned. + the left FromClause and array column references are properly deep-cloned. """ def _default_clone(elem, **kwargs): return elem @@ -61,33 +97,194 @@ def _default_clone(elem, **kwargs): if clone is None: clone = _default_clone - # Clone the left FromClause and array column to ensure proper - # reference handling in complex query scenarios self.left = clone(self.left, **kw) - self.array_column = clone(self.array_column, **kw) + self.array_columns = [ + (clone(col, **kw), alias) + for col, alias in self.array_columns + ] def array_join(left, array_column, alias=None, is_left=False): - """Create an ARRAY JOIN clause + """Create an ARRAY JOIN clause. + + Supports single or multiple array columns. When multiple columns are + provided, they are expanded in parallel (zipped by index position). 
Args: - left: The left side (table or subquery) - array_column: The array column to join - alias: Optional alias for the joined array elements - is_left: If True, use LEFT ARRAY JOIN instead of ARRAY JOIN + left: The left side (table or subquery). + array_column: A single array column, or a list/tuple of array columns. + alias: Optional alias. A single string when array_column is a single + column, or a list/tuple of strings (same length as array_column) + when array_column is a list. None means no aliases. + is_left: If True, use LEFT ARRAY JOIN instead of ARRAY JOIN. Returns: - ArrayJoin: An ArrayJoin clause element + ArrayJoin: An ArrayJoin clause element. - Example: + Examples: from clickhouse_connect.cc_sqlalchemy.sql.clauses import array_join - # Basic ARRAY JOIN + # Single column ARRAY JOIN query = select(table).select_from(array_join(table, table.c.tags)) - # LEFT ARRAY JOIN with alias + # Single column LEFT ARRAY JOIN with alias + query = select(table).select_from( + array_join(table, table.c.tags, alias="tag", is_left=True) + ) + + # Multiple columns with aliases query = select(table).select_from( - array_join(table, table.c.tags, alias='tag', is_left=True) + array_join( + table, + [table.c.names, table.c.prices, table.c.quantities], + alias=["name", "price", "quantity"], + ) ) """ return ArrayJoin(left, array_column, alias, is_left) + + +_VALID_STRICTNESS = frozenset({None, "ALL", "ANY", "SEMI", "ANTI", "ASOF"}) +_VALID_DISTRIBUTION = frozenset({None, "GLOBAL"}) + + +def _validate_ch_join(strictness, distribution, onclause, isouter, full, is_cross, using): + """Validate ClickHouse join parameter combinations.""" + if strictness not in _VALID_STRICTNESS: + raise ValueError(f"Invalid strictness {strictness!r}. Must be one of: ALL, ANY, SEMI, ANTI, ASOF") + if distribution not in _VALID_DISTRIBUTION: + raise ValueError(f"Invalid distribution {distribution!r}. 
Must be: GLOBAL") + if is_cross and strictness is not None: + raise ValueError("Strictness modifiers cannot be used with CROSS JOIN") + if is_cross and (isouter or full): + raise ValueError("CROSS JOIN cannot be combined with isouter or full") + if strictness in ("SEMI", "ANTI") and not isouter: + raise ValueError(f"{strictness} JOIN requires isouter=True (LEFT) or swapped table order (RIGHT)") + if strictness == "ASOF" and full: + raise ValueError("ASOF is not supported with FULL joins") + if using is not None: + if is_cross: + raise ValueError("USING cannot be combined with CROSS JOIN") + if onclause is not None: + raise ValueError("Cannot specify both onclause and using") + if not isinstance(using, (list, tuple)) or not using: + raise ValueError("using must be a non-empty list of column name strings") + if not all(isinstance(col, str) for col in using): + raise ValueError("using must contain only column name strings") + + +def _build_using_onclause(left, right, using): + """Build an equality onclause from USING column names. + + This gives SQLAlchemy's from-linter proper column references so it + knows the tables are connected. The compiler renders USING instead of ON. + """ + conditions = [] + for col in using: + try: + conditions.append(left.c[col] == right.c[col]) + except KeyError: + left_cols = {c.name for c in left.c} + right_cols = {c.name for c in right.c} + missing_from = [] + if col not in left_cols: + missing_from.append(str(left)) + if col not in right_cols: + missing_from.append(str(right)) + raise ValueError(f"USING column {col!r} not found in: {', '.join(missing_from)}") from None + return and_(*conditions) if len(conditions) > 1 else conditions[0] + + +# pylint: disable=too-many-ancestors,abstract-method +class ClickHouseJoin(Join): + """A SQLAlchemy Join subclass that supports ClickHouse-specific join features. 
+ + ClickHouse JOIN syntax: [GLOBAL] [ALL|ANY|SEMI|ANTI|ASOF] [INNER|LEFT|RIGHT|FULL|CROSS] JOIN + + Strictness modifiers control how multiple matches are handled: + - ALL: return all matching rows (default, standard SQL behavior) + - ANY: return only the first match per left row + - SEMI: acts as an allowlist on join keys, no Cartesian product + - ANTI: acts as a denylist on join keys, no Cartesian product + - ASOF: time-series join, finds the closest match + + Distribution modifier: + - GLOBAL: broadcasts the right table to all nodes in distributed queries + + USING clause: + - Joins on same-named columns from both tables. Unlike ON, USING merges + matched columns into one, which is important for FULL OUTER JOIN where + ON produces default values (0, '') for unmatched sides. + + Note: RIGHT JOIN is achieved by swapping table order, which is standard SQLAlchemy behavior. + ASOF JOIN requires the last ON condition to be an inequality which is validated by + the ClickHouse server, not here. Not all strictness/join type combinations are supported + by every join algorithm and the server will report unsupported combinations. 
+ """ + + __visit_name__ = "join" + + _traverse_internals = Join._traverse_internals + [ + ("strictness", InternalTraversal.dp_string), + ("distribution", InternalTraversal.dp_string), + ("_is_cross", InternalTraversal.dp_boolean), + ("using_columns", InternalTraversal.dp_string_list), + ] + + def __init__(self, left, right, onclause=None, isouter=False, full=False, + strictness=None, distribution=None, _is_cross=False, using=None): + if strictness is not None: + strictness = strictness.upper() + if distribution is not None: + distribution = distribution.upper() + + _validate_ch_join(strictness, distribution, onclause, isouter, full, _is_cross, using) + + effective_onclause = _build_using_onclause(left, right, using) if using else onclause + super().__init__(left, right, effective_onclause, isouter, full) + self.strictness = strictness + self.distribution = distribution + self._is_cross = _is_cross + self.using_columns = list(using) if using is not None else None + + +def ch_join( + left, + right, + onclause=None, + *, + isouter=False, + full=False, + cross=False, + using=None, + strictness: Optional[str] = None, + distribution: Optional[str] = None, +): + """Create a ClickHouse JOIN with optional strictness, distribution, and USING support. + + Args: + left: The left side table or selectable. + right: The right side table or selectable. + onclause: The ON clause expression. Mutually exclusive with ``using``. + isouter: If True, render a LEFT OUTER JOIN. + full: If True, render a FULL OUTER JOIN. + cross: If True, render a CROSS JOIN. Cannot be combined with + onclause, using, or strictness modifiers. + using: A list of column name strings for USING syntax. The columns + must have the same name in both tables. Mutually exclusive with + ``onclause``. Produces ``USING (col1, col2)`` instead of ``ON``. + strictness: ClickHouse strictness modifier, one of + "ALL", "ANY", "SEMI", "ANTI", or "ASOF". + distribution: ClickHouse distribution modifier "GLOBAL". 
def _resolve_ch_type_name(sqla_type):
    """Return the ClickHouse type name to use for a SQLAlchemy type instance.

    ClickHouse-native types (``ChSqlaType``) report their own ``name``. Generic
    SQLAlchemy types are matched against ordered lookup tables; anything that
    matches nothing falls back to ``String``.
    """
    if isinstance(sqla_type, ChSqlaType):
        return sqla_type.name

    # Lookup order is significant: subtypes must be tested before their parents.
    integer_and_float = (
        (sqltypes.SmallInteger, "Int16"),
        (sqltypes.BigInteger, "Int64"),
        (sqltypes.Integer, "Int32"),
        (sqltypes.Float, "Float64"),
    )
    for generic_cls, ch_name in integer_and_float:
        if isinstance(sqla_type, generic_cls):
            return ch_name

    if isinstance(sqla_type, sqltypes.Numeric):
        # Missing precision/scale default to Decimal(18, 0).
        precision = sqla_type.precision or 18
        scale = sqla_type.scale or 0
        return f"Decimal({precision}, {scale})"

    remaining = (
        (sqltypes.Boolean, "Bool"),
        (sqltypes.DateTime, "DateTime"),
        (sqltypes.Date, "Date"),
    )
    for generic_cls, ch_name in remaining:
        if isinstance(sqla_type, generic_cls):
            return ch_name

    # String types and every unrecognized type share the same fallback.
    return "String"
def visit_array_join(self, array_join_clause, asfrom=False, from_linter=None, **kw):
    """Render ``<left> [LEFT ]ARRAY JOIN col [AS alias], ...``.

    The left side is compiled as a FROM element; each array column is compiled
    in order and followed by its quoted alias when one was given.
    """
    def render_column(column, alias):
        rendered = self.process(column, **kw)
        if alias is None:
            return rendered
        return f"{rendered} AS {self.preparer.quote(alias)}"

    source_sql = self.process(array_join_clause.left, asfrom=True, from_linter=from_linter, **kw)

    keyword = "ARRAY JOIN"
    if array_join_clause.is_left:
        keyword = "LEFT " + keyword

    columns_sql = ", ".join(
        render_column(column, alias)
        for column, alias in array_join_clause.array_columns
    )
    return f"{source_sql} {keyword} {columns_sql}"
def _compose_select_body(self, text, select, compile_state, inner_columns, froms, byfrom, toplevel, kwargs):
    """Compile a SELECT body with that statement's FINAL/SAMPLE modifiers in scope.

    Builds a ``{from_target: "FINAL SAMPLE x"}`` mapping from the ``_ch_final``
    and ``_ch_sample`` attributes of ``select`` and exposes it to the table and
    alias visitors through ``self._ch_from_modifiers`` while the parent
    implementation renders the body. The previous mapping is restored
    afterwards so enclosing statements continue with their own modifiers.
    """
    ch_final = getattr(select, "_ch_final", set())
    ch_sample = getattr(select, "_ch_sample", {})

    mods = {}
    for target in ch_final | set(ch_sample):
        parts = []
        if target in ch_final:
            parts.append("FINAL")
        if target in ch_sample:
            parts.append(f"SAMPLE {ch_sample[target]}")
        mods[target] = " ".join(parts)

    # Install this statement's mapping even when it is empty. Nested SELECTs
    # (subqueries, CTEs) are compiled while the enclosing body is rendered; if
    # an unmodified inner statement left the outer mapping in place, an inner
    # FROM on the same table object would wrongly inherit FINAL/SAMPLE.
    prev = getattr(self, "_ch_from_modifiers", None)
    self._ch_from_modifiers = mods
    try:
        return super()._compose_select_body(text, select, compile_state, inner_columns, froms, byfrom, toplevel, kwargs)
    finally:
        self._ch_from_modifiers = prev
alias in mods: + result += " " + mods[alias] + return result diff --git a/tests/integration_tests/test_sqlalchemy/test_array_join.py b/tests/integration_tests/test_sqlalchemy/test_array_join.py index 9d261541..88625f7a 100644 --- a/tests/integration_tests/test_sqlalchemy/test_array_join.py +++ b/tests/integration_tests/test_sqlalchemy/test_array_join.py @@ -11,6 +11,8 @@ def test_tables(test_engine: Engine, test_db: str): """Create test tables for ARRAY JOIN tests""" with test_engine.begin() as conn: conn.execute(text(f"DROP TABLE IF EXISTS {test_db}.test_array_join")) + conn.execute(text(f"DROP TABLE IF EXISTS {test_db}.test_multi_array_join")) + conn.execute( text( f""" @@ -23,6 +25,19 @@ def test_tables(test_engine: Engine, test_db: str): ) ) + conn.execute( + text( + f""" + CREATE TABLE {test_db}.test_multi_array_join ( + id UInt32, + names Array(String), + prices Array(UInt32), + quantities Array(UInt32) + ) ENGINE MergeTree() ORDER BY id + """ + ) + ) + conn.execute( text( f""" @@ -35,12 +50,26 @@ def test_tables(test_engine: Engine, test_db: str): ) ) - # Verify data is actually queryable before yielding to tests - verify_tables_ready(conn, {f"{test_db}.test_array_join": 4}) + conn.execute( + text( + f""" + INSERT INTO {test_db}.test_multi_array_join VALUES + (1, ['widget_a', 'widget_b'], [100, 200], [5, 10]), + (2, ['widget_c'], [300], [15]), + (3, [], [], []) + """ + ) + ) + + verify_tables_ready(conn, { + f"{test_db}.test_array_join": 4, + f"{test_db}.test_multi_array_join": 3, + }) yield conn.execute(text(f"DROP TABLE IF EXISTS {test_db}.test_array_join")) + conn.execute(text(f"DROP TABLE IF EXISTS {test_db}.test_multi_array_join")) def test_array_join(test_engine: Engine, test_db: str): @@ -106,3 +135,109 @@ def test_left_array_join_with_alias(test_engine: Engine, test_db: str): charlie_rows = [row for row in rows if row.name == "Charlie"] assert len(charlie_rows) == 1 assert charlie_rows[0].tag == "" + + +def test_multi_column_array_join(test_engine: 
Engine, test_db: str): + """Test ARRAY JOIN with multiple columns expanded in parallel""" + with test_engine.begin() as conn: + metadata = MetaData(schema=test_db) + test_table = Table("test_multi_array_join", metadata, autoload_with=test_engine) + + query = ( + select( + test_table.c.id, + literal_column("item_name"), + literal_column("price"), + literal_column("qty"), + ) + .select_from( + array_join( + test_table, + [test_table.c.names, test_table.c.prices, test_table.c.quantities], + alias=["item_name", "price", "qty"], + ) + ) + .order_by(test_table.c.id, literal_column("item_name")) + ) + + compiled_str = str(query.compile(dialect=test_engine.dialect)) + assert "ARRAY JOIN" in compiled_str.upper() + # All three columns should appear comma-separated after ARRAY JOIN + assert "AS `item_name`" in compiled_str + assert "AS `price`" in compiled_str + assert "AS `qty`" in compiled_str + + result = conn.execute(query) + rows = result.fetchall() + + # id=1 has 2 elements, id=2 has 1 element -> 3 rows total + assert len(rows) == 3 + assert rows[0] == (1, "widget_a", 100, 5) + assert rows[1] == (1, "widget_b", 200, 10) + assert rows[2] == (2, "widget_c", 300, 15) + + +def test_multi_column_array_join_no_aliases(test_engine: Engine, test_db: str): + """Test multi-column ARRAY JOIN without aliases""" + with test_engine.begin() as conn: + metadata = MetaData(schema=test_db) + test_table = Table("test_multi_array_join", metadata, autoload_with=test_engine) + + query = ( + select(test_table.c.id, test_table.c.names, test_table.c.prices) + .select_from( + array_join( + test_table, + [test_table.c.names, test_table.c.prices], + ) + ) + .order_by(test_table.c.id, test_table.c.names) + ) + + compiled_str = str(query.compile(dialect=test_engine.dialect)) + assert "ARRAY JOIN" in compiled_str.upper() + assert "AS" not in compiled_str.split("ARRAY JOIN")[1] + + result = conn.execute(query) + rows = result.fetchall() + assert len(rows) == 3 + assert rows[0] == (1, "widget_a", 100) 
+ assert rows[1] == (1, "widget_b", 200) + assert rows[2] == (2, "widget_c", 300) + + +def test_multi_column_left_array_join(test_engine: Engine, test_db: str): + """Test LEFT ARRAY JOIN with multiple columns preserves empty-array rows""" + with test_engine.begin() as conn: + metadata = MetaData(schema=test_db) + test_table = Table("test_multi_array_join", metadata, autoload_with=test_engine) + + query = ( + select( + test_table.c.id, + literal_column("item_name"), + literal_column("price"), + ) + .select_from( + array_join( + test_table, + [test_table.c.names, test_table.c.prices], + alias=["item_name", "price"], + is_left=True, + ) + ) + .order_by(test_table.c.id, literal_column("item_name")) + ) + + compiled_str = str(query.compile(dialect=test_engine.dialect)) + assert "LEFT ARRAY JOIN" in compiled_str.upper() + + result = conn.execute(query) + rows = result.fetchall() + + # id=1 has 2, id=2 has 1, id=3 has 0 (preserved by LEFT) = 4 + assert len(rows) == 4 + empty_rows = [r for r in rows if r.id == 3] + assert len(empty_rows) == 1 + assert empty_rows[0].item_name == "" # default for String + assert empty_rows[0].price == 0 # default for UInt32 diff --git a/tests/integration_tests/test_sqlalchemy/test_select.py b/tests/integration_tests/test_sqlalchemy/test_select.py index 8af30947..b9125ae3 100644 --- a/tests/integration_tests/test_sqlalchemy/test_select.py +++ b/tests/integration_tests/test_sqlalchemy/test_select.py @@ -1,6 +1,6 @@ # pylint: disable=no-member from pytest import fixture -from sqlalchemy import MetaData, Table, func, select, text +from sqlalchemy import MetaData, Table, func, literal_column, select, text from sqlalchemy.engine import Engine from clickhouse_connect import common @@ -9,6 +9,7 @@ String, UInt32, ) +from clickhouse_connect.cc_sqlalchemy.sql.clauses import ch_join from tests.integration_tests.test_sqlalchemy.conftest import verify_tables_ready @@ -95,10 +96,53 @@ def test_tables(test_engine: Engine, test_db: str): ) ) + 
conn.execute(text(f"DROP TABLE IF EXISTS {test_db}.test_using_sales")) + conn.execute(text(f"DROP TABLE IF EXISTS {test_db}.test_using_returns")) + + conn.execute( + text( + f""" + CREATE TABLE {test_db}.test_using_sales ( + product_id UInt32, + sold UInt32 + ) ENGINE MergeTree() ORDER BY product_id + """ + ) + ) + conn.execute( + text( + f""" + CREATE TABLE {test_db}.test_using_returns ( + product_id UInt32, + returned UInt32 + ) ENGINE MergeTree() ORDER BY product_id + """ + ) + ) + + conn.execute( + text( + f""" + INSERT INTO {test_db}.test_using_sales VALUES + (1, 10), (2, 20), (3, 30) + """ + ) + ) + conn.execute( + text( + f""" + INSERT INTO {test_db}.test_using_returns VALUES + (2, 5), (3, 10), (4, 15) + """ + ) + ) + verify_tables_ready(conn, { f"{test_db}.select_test_users": 3, f"{test_db}.select_test_orders": 4, - f"{test_db}.test_argmax": 5 + f"{test_db}.test_argmax": 5, + f"{test_db}.test_using_sales": 3, + f"{test_db}.test_using_returns": 3, }) yield @@ -106,6 +150,8 @@ def test_tables(test_engine: Engine, test_db: str): conn.execute(text(f"DROP TABLE IF EXISTS {test_db}.select_test_users")) conn.execute(text(f"DROP TABLE IF EXISTS {test_db}.select_test_orders")) conn.execute(text(f"DROP TABLE IF EXISTS {test_db}.test_argmax")) + conn.execute(text(f"DROP TABLE IF EXISTS {test_db}.test_using_sales")) + conn.execute(text(f"DROP TABLE IF EXISTS {test_db}.test_using_returns")) def test_basic_select(test_engine: Engine, test_db: str): @@ -163,6 +209,56 @@ def test_basic_select_with_sample(test_engine: Engine, test_db: str): assert compiled_str.endswith("SAMPLE 1") +def test_final_and_sample_chained(test_engine: Engine, test_db: str): + """Chaining .final() and .sample() in either order should produce both clauses.""" + metadata = MetaData(schema=test_db) + users = Table("select_test_users", metadata, autoload_with=test_engine) + + # final() then sample() + query_fs = select(users).final().sample(0.1) + compiled_fs = 
str(query_fs.compile(dialect=test_engine.dialect)) + assert "FINAL" in compiled_fs + assert "SAMPLE 0.1" in compiled_fs + assert compiled_fs.index("FINAL") < compiled_fs.index("SAMPLE") + + # sample() then final() + query_sf = select(users).sample(0.1).final() + compiled_sf = str(query_sf.compile(dialect=test_engine.dialect)) + assert "FINAL" in compiled_sf + assert "SAMPLE 0.1" in compiled_sf + assert compiled_sf.index("FINAL") < compiled_sf.index("SAMPLE") + + +def test_final_and_sample_with_alias(test_engine: Engine, test_db: str): + """FINAL/SAMPLE on aliased tables renders after the alias suffix.""" + metadata = MetaData(schema=test_db) + users = Table("select_test_users", metadata, autoload_with=test_engine) + alias = users.alias("u") + + compiled = str(select(alias).final().sample(0.1).compile(dialect=test_engine.dialect)) + assert "AS `u` FINAL SAMPLE 0.1" in compiled + assert "FINAL AS" not in compiled + + # Reversed order produces the same output + compiled_rev = str(select(alias).sample(0.1).final().compile(dialect=test_engine.dialect)) + assert "AS `u` FINAL SAMPLE 0.1" in compiled_rev + + +def test_final_with_explicit_table_on_join(test_engine: Engine, test_db: str): + """FINAL applied to a specific table in a join renders correctly.""" + metadata = MetaData(schema=test_db) + users = Table("select_test_users", metadata, autoload_with=test_engine) + orders = Table("select_test_orders", metadata, autoload_with=test_engine) + + join = users.join(orders, users.c.id == orders.c.user_id) + query = select(users.c.id, orders.c.product).select_from(join).final(users) + compiled = str(query.compile(dialect=test_engine.dialect)) + # FINAL should appear between the users table and the JOIN keyword + from_clause = compiled[compiled.index("FROM"):] + assert "select_test_users` FINAL" in from_clause + assert "FINAL" not in from_clause[from_clause.index("JOIN"):] + + def test_select_with_where_with_sample(test_engine: Engine, test_db: str): with test_engine.begin() as 
conn: metadata = MetaData(schema=test_db) @@ -340,3 +436,149 @@ def test_argmax_aggregate_function(test_engine: Engine, test_db: str): assert rows[1].id == 2 assert rows[1].latest_name == "Bob_v2" assert rows[1].latest_value == 250 + + +def test_all_inner_ch_join(test_engine: Engine, test_db: str): + """ALL INNER JOIN returns all matching rows""" + with test_engine.begin() as conn: + metadata = MetaData(schema=test_db) + users = Table("select_test_users", metadata, autoload_with=test_engine) + orders = Table("select_test_orders", metadata, autoload_with=test_engine) + + query = select(users.c.id, users.c.name, orders.c.product).select_from( + ch_join(users, orders, users.c.id == orders.c.user_id, strictness="ALL") + ) + + compiled = query.compile(dialect=test_engine.dialect) + assert "ALL INNER JOIN" in str(compiled).upper() + + result = conn.execute(query) + rows = result.fetchall() + assert len(rows) == 4 + + +def test_any_left_ch_join(test_engine: Engine, test_db: str): + """ANY LEFT JOIN returns at most one match per left row""" + with test_engine.begin() as conn: + metadata = MetaData(schema=test_db) + users = Table("select_test_users", metadata, autoload_with=test_engine) + orders = Table("select_test_orders", metadata, autoload_with=test_engine) + + query = select(users.c.id, users.c.name, orders.c.product).select_from( + ch_join(users, orders, users.c.id == orders.c.user_id, isouter=True, strictness="ANY") + ) + + compiled = query.compile(dialect=test_engine.dialect) + sql_str = str(compiled).upper() + assert "ANY LEFT OUTER JOIN" in sql_str + + result = conn.execute(query) + rows = result.fetchall() + # ANY returns at most one order per user; user_id=1 has 2 orders but gets 1 + assert len(rows) == 3 + user_ids = [row.id for row in rows] + assert sorted(user_ids) == [1, 2, 3] + + +def test_global_all_left_ch_join(test_engine: Engine, test_db: str): + """GLOBAL ALL LEFT OUTER JOIN compiles and executes correctly""" + with test_engine.begin() as conn: + 
metadata = MetaData(schema=test_db) + users = Table("select_test_users", metadata, autoload_with=test_engine) + orders = Table("select_test_orders", metadata, autoload_with=test_engine) + + query = select(users.c.id, users.c.name, orders.c.product).select_from( + ch_join(users, orders, users.c.id == orders.c.user_id, isouter=True, strictness="ALL", distribution="GLOBAL") + ) + + compiled = query.compile(dialect=test_engine.dialect) + sql_str = str(compiled).upper() + assert "GLOBAL ALL LEFT OUTER JOIN" in sql_str + + result = conn.execute(query) + rows = result.fetchall() + # LEFT JOIN: at least all 3 users returned + assert len(rows) >= 3 + user_names = {row.name for row in rows} + assert {"Alice", "Bob", "Charlie"}.issubset(user_names) + + +def test_using_inner_join(test_engine: Engine, test_db: str): + """INNER JOIN USING on a shared column name""" + with test_engine.begin() as conn: + metadata = MetaData(schema=test_db) + sales = Table("test_using_sales", metadata, autoload_with=test_engine) + returns = Table("test_using_returns", metadata, autoload_with=test_engine) + + query = ( + select(sales.c.product_id, sales.c.sold, returns.c.returned) + .select_from(ch_join(sales, returns, using=["product_id"])) + .order_by(sales.c.product_id) + ) + + compiled_str = str(query.compile(dialect=test_engine.dialect)) + assert "USING" in compiled_str + assert "ON" not in compiled_str + + result = conn.execute(query) + rows = result.fetchall() + # Only product_id 2 and 3 exist in both tables + assert len(rows) == 2 + assert rows[0] == (2, 20, 5) + assert rows[1] == (3, 30, 10) + + +def test_using_full_outer_join(test_engine: Engine, test_db: str): + """FULL OUTER JOIN USING merges the join column correctly.""" + with test_engine.begin() as conn: + metadata = MetaData(schema=test_db) + sales = Table("test_using_sales", metadata, autoload_with=test_engine) + returns = Table("test_using_returns", metadata, autoload_with=test_engine) + + # Use unqualified product_id to get the 
merged USING column + pid = literal_column("product_id") + query = ( + select(pid, sales.c.sold, returns.c.returned) + .select_from(ch_join(sales, returns, using=["product_id"], full=True)) + .order_by(pid) + ) + + compiled_str = str(query.compile(dialect=test_engine.dialect)) + assert "FULL OUTER JOIN" in compiled_str + assert "USING" in compiled_str + + result = conn.execute(query) + rows = result.fetchall() + assert len(rows) == 4 + + by_pid = {row.product_id: row for row in rows} + # product_id=4 only in returns. With USING, product_id is 4 (correct). + # With ON, it would be 0 (wrong). + assert by_pid[4].product_id == 4 + assert by_pid[4].sold == 0 + assert by_pid[4].returned == 15 + # product_id=1 only in sales + assert by_pid[1].sold == 10 + assert by_pid[1].returned == 0 + + +def test_using_with_strictness_integration(test_engine: Engine, test_db: str): + """ANY INNER JOIN with USING compiles and executes""" + with test_engine.begin() as conn: + metadata = MetaData(schema=test_db) + sales = Table("test_using_sales", metadata, autoload_with=test_engine) + returns = Table("test_using_returns", metadata, autoload_with=test_engine) + + query = ( + select(sales.c.product_id, sales.c.sold, returns.c.returned) + .select_from(ch_join(sales, returns, using=["product_id"], strictness="ANY")) + .order_by(sales.c.product_id) + ) + + compiled_str = str(query.compile(dialect=test_engine.dialect)) + assert "ANY INNER JOIN" in compiled_str + assert "USING" in compiled_str + + result = conn.execute(query) + rows = result.fetchall() + assert len(rows) == 2 diff --git a/tests/integration_tests/test_sqlalchemy/test_values.py b/tests/integration_tests/test_sqlalchemy/test_values.py new file mode 100644 index 00000000..02a41268 --- /dev/null +++ b/tests/integration_tests/test_sqlalchemy/test_values.py @@ -0,0 +1,63 @@ +import pytest +import sqlalchemy as db +from sqlalchemy.engine import Engine + +from clickhouse_connect.cc_sqlalchemy.datatypes.sqltypes import DateTime + +SA_2 = 
db.__version__ >= "2" + + +def test_values_round_trip_multi_column(test_engine: Engine): + with test_engine.begin() as conn: + values_clause = db.values( + db.column("id", db.Integer), + db.column("name", db.String), + name="v", + ).data([(17, "user_1"), (29, "user_2")]) + + rows = conn.execute( + db.select(values_clause.c.id, values_clause.c.name).select_from(values_clause).order_by(values_clause.c.id) + ).fetchall() + + assert [(row.id, row.name) for row in rows] == [(17, "user_1"), (29, "user_2")] + + +def test_values_round_trip_single_column(test_engine: Engine): + with test_engine.begin() as conn: + values_clause = db.values( + db.column("score", db.Integer), + name="v", + ).data([(17,), (29,)]) + + total = conn.execute(db.select(db.func.sum(values_clause.c.score)).select_from(values_clause)).scalar() + + assert total == 46 + + +def test_values_round_trip_type_name_with_quotes(test_engine: Engine): + with test_engine.begin() as conn: + values_clause = db.values( + db.column("event_ts", DateTime("UTC")), + name="v", + ).data([("2024-01-02 03:04:05",)]) + + value = conn.execute(db.select(values_clause.c.event_ts).select_from(values_clause)).scalar() + + assert str(value).startswith("2024-01-02 03:04:05") + + +@pytest.mark.skipif(not SA_2, reason="Values.cte() was added in SA 2.x") +def test_values_cte_round_trip(test_engine: Engine): + with test_engine.begin() as conn: + values_clause = ( + db.values( + db.column("id", db.Integer), + name="v", + ) + .data([(17,), (29,)]) + .cte("input_rows") + ) + + value = conn.execute(db.select(db.func.max(values_clause.c.id)).select_from(values_clause)).scalar() + + assert value == 29 diff --git a/tests/unit_tests/test_sqlalchemy/test_array_join.py b/tests/unit_tests/test_sqlalchemy/test_array_join.py new file mode 100644 index 00000000..7496f877 --- /dev/null +++ b/tests/unit_tests/test_sqlalchemy/test_array_join.py @@ -0,0 +1,161 @@ +import pytest +import sqlalchemy as db + +from 
clickhouse_connect.cc_sqlalchemy.datatypes.sqltypes import Array, String, UInt32 +from clickhouse_connect.cc_sqlalchemy.dialect import ClickHouseDialect +from clickhouse_connect.cc_sqlalchemy.sql.clauses import ArrayJoin, array_join + +dialect = ClickHouseDialect() +metadata = db.MetaData() + +products = db.Table( + "products", + metadata, + db.Column("id", UInt32), + db.Column("names", Array(String)), + db.Column("prices", Array(UInt32)), + db.Column("quantities", Array(UInt32)), +) + + +def compile_sql(query): + return str(query.compile(dialect=dialect, compile_kwargs={"literal_binds": True})) + + +def test_single_column_no_alias(): + query = db.select(products.c.id, products.c.names).select_from(array_join(products, products.c.names)) + sql = compile_sql(query) + assert "ARRAY JOIN" in sql + assert "LEFT" not in sql + assert "AS" not in sql.split("ARRAY JOIN")[1] + + +def test_single_column_with_alias(): + query = db.select(products.c.id, db.literal_column("n")).select_from(array_join(products, products.c.names, alias="n")) + sql = compile_sql(query) + assert "ARRAY JOIN" in sql + assert "AS `n`" in sql + + +def test_single_column_left(): + query = db.select(products.c.id).select_from(array_join(products, products.c.names, is_left=True)) + sql = compile_sql(query) + assert "LEFT ARRAY JOIN" in sql + + +def test_multi_column_with_aliases(): + query = db.select( + products.c.id, + db.literal_column("item_name"), + db.literal_column("price"), + db.literal_column("qty"), + ).select_from( + array_join( + products, + [products.c.names, products.c.prices, products.c.quantities], + alias=["item_name", "price", "qty"], + ) + ) + sql = compile_sql(query) + after_aj = sql.split("ARRAY JOIN")[1] + assert "AS `item_name`" in after_aj + assert "AS `price`" in after_aj + assert "AS `qty`" in after_aj + # Columns should be comma-separated + assert after_aj.count(",") >= 2 + + +def test_multi_column_no_aliases(): + query = db.select(products.c.id, products.c.names, 
products.c.prices).select_from( + array_join( + products, + [products.c.names, products.c.prices], + ) + ) + sql = compile_sql(query) + after_aj = sql.split("ARRAY JOIN")[1] + assert "AS" not in after_aj + assert "`names`" in after_aj + assert "`prices`" in after_aj + + +def test_multi_column_left(): + query = db.select(products.c.id).select_from( + array_join( + products, + [products.c.names, products.c.prices], + alias=["n", "p"], + is_left=True, + ) + ) + sql = compile_sql(query) + assert "LEFT ARRAY JOIN" in sql + assert "AS `n`" in sql + assert "AS `p`" in sql + + +def test_multi_column_mixed_aliases(): + """Some columns aliased, some not""" + query = db.select( + products.c.id, + db.literal_column("item_name"), + products.c.prices, + db.literal_column("qty"), + ).select_from( + array_join( + products, + [products.c.names, products.c.prices, products.c.quantities], + alias=["item_name", None, "qty"], + ) + ) + sql = compile_sql(query) + after_aj = sql.split("ARRAY JOIN")[1] + assert "AS `item_name`" in after_aj + assert "AS `qty`" in after_aj + # prices should appear without an alias + assert "`prices`" in after_aj + # Make sure there's no AS immediately following prices + prices_segment = after_aj.split("`prices`")[1].lstrip() + assert prices_segment.startswith(",") + + +def test_error_alias_list_with_single_column(): + with pytest.raises(ValueError, match="must be a string or None"): + array_join(products, products.c.names, alias=["n"]) + + +def test_error_alias_string_with_multi_column(): + with pytest.raises(ValueError, match="must be a list"): + array_join(products, [products.c.names, products.c.prices], alias="n") + + +def test_error_alias_length_mismatch(): + with pytest.raises(ValueError, match="must match"): + array_join( + products, + [products.c.names, products.c.prices], + alias=["n"], + ) + + +def test_error_empty_column_list(): + with pytest.raises(ValueError, match="At least one"): + array_join(products, []) + + +def 
test_direct_constructor_backward_compat(): + """ArrayJoin is public API. Old-style positional calls must still work.""" + aj = ArrayJoin(products, products.c.names, "n", True) + query = db.select(products.c.id, db.literal_column("n")).select_from(aj) + sql = compile_sql(query) + assert "LEFT ARRAY JOIN" in sql + assert "AS `n`" in sql + + +def test_direct_constructor_no_alias(): + """ArrayJoin constructor with no alias, keyword is_left.""" + aj = ArrayJoin(products, products.c.names, is_left=False) + query = db.select(products.c.id, products.c.names).select_from(aj) + sql = compile_sql(query) + assert "ARRAY JOIN" in sql + assert "AS" not in sql.split("ARRAY JOIN")[1] diff --git a/tests/unit_tests/test_sqlalchemy/test_ch_join.py b/tests/unit_tests/test_sqlalchemy/test_ch_join.py new file mode 100644 index 00000000..2007b06a --- /dev/null +++ b/tests/unit_tests/test_sqlalchemy/test_ch_join.py @@ -0,0 +1,300 @@ +import pytest +import sqlalchemy as db + +from clickhouse_connect.cc_sqlalchemy.datatypes.sqltypes import String, UInt32 +from clickhouse_connect.cc_sqlalchemy.dialect import ClickHouseDialect +from clickhouse_connect.cc_sqlalchemy.sql.clauses import ch_join + +dialect = ClickHouseDialect() +metadata = db.MetaData() + +users = db.Table( + "users", + metadata, + db.Column("id", UInt32), + db.Column("name", String), +) + +orders = db.Table( + "orders", + metadata, + db.Column("id", UInt32), + db.Column("user_id", UInt32), + db.Column("product", String), +) + +items = db.Table( + "items", + metadata, + db.Column("id", UInt32), + db.Column("order_id", UInt32), + db.Column("sku", String), +) + + +def compile_query(stmt): + return str(stmt.compile(dialect=dialect, compile_kwargs={"literal_binds": True})) + + +def test_all_inner_join(): + j = ch_join(users, orders, users.c.id == orders.c.user_id, strictness="ALL") + sql = compile_query(db.select(users.c.name).select_from(j)) + assert "ALL INNER JOIN" in sql + + +def test_any_inner_join(): + j = ch_join(users, 
orders, users.c.id == orders.c.user_id, strictness="ANY") + sql = compile_query(db.select(users.c.name).select_from(j)) + assert "ANY INNER JOIN" in sql + + +def test_any_left_outer_join(): + j = ch_join(users, orders, users.c.id == orders.c.user_id, isouter=True, strictness="ANY") + sql = compile_query(db.select(users.c.name).select_from(j)) + assert "ANY LEFT OUTER JOIN" in sql + + +def test_asof_inner_join(): + j = ch_join(users, orders, users.c.id == orders.c.user_id, strictness="ASOF") + sql = compile_query(db.select(users.c.name).select_from(j)) + assert "ASOF INNER JOIN" in sql + + +def test_asof_left_outer_join(): + j = ch_join(users, orders, users.c.id == orders.c.user_id, isouter=True, strictness="ASOF") + sql = compile_query(db.select(users.c.name).select_from(j)) + assert "ASOF LEFT OUTER JOIN" in sql + + +def test_semi_left_outer_join(): + j = ch_join(users, orders, users.c.id == orders.c.user_id, isouter=True, strictness="SEMI") + sql = compile_query(db.select(users.c.name).select_from(j)) + assert "SEMI LEFT OUTER JOIN" in sql + + +def test_anti_left_outer_join(): + j = ch_join(users, orders, users.c.id == orders.c.user_id, isouter=True, strictness="ANTI") + sql = compile_query(db.select(users.c.name).select_from(j)) + assert "ANTI LEFT OUTER JOIN" in sql + + +def test_all_full_outer_join(): + j = ch_join(users, orders, users.c.id == orders.c.user_id, full=True, strictness="ALL") + sql = compile_query(db.select(users.c.name).select_from(j)) + assert "ALL FULL OUTER JOIN" in sql + + +def test_global_inner_join(): + j = ch_join(users, orders, users.c.id == orders.c.user_id, distribution="GLOBAL") + sql = compile_query(db.select(users.c.name).select_from(j)) + assert "GLOBAL INNER JOIN" in sql + + +def test_global_only_join(): + """GLOBAL without strictness on an INNER JOIN.""" + j = ch_join(users, orders, users.c.id == orders.c.user_id, distribution="GLOBAL") + sql = compile_query(db.select(users.c.name).select_from(j)) + assert "GLOBAL INNER JOIN" in 
sql + assert "ALL" not in sql + assert "ANY" not in sql + + +def test_global_all_left_outer_join(): + j = ch_join(users, orders, users.c.id == orders.c.user_id, isouter=True, strictness="ALL", distribution="GLOBAL") + sql = compile_query(db.select(users.c.name).select_from(j)) + assert "GLOBAL ALL LEFT OUTER JOIN" in sql + + +def test_global_asof_left_outer_join(): + j = ch_join(users, orders, users.c.id == orders.c.user_id, isouter=True, strictness="ASOF", distribution="GLOBAL") + sql = compile_query(db.select(users.c.name).select_from(j)) + assert "GLOBAL ASOF LEFT OUTER JOIN" in sql + + +def test_no_modifiers_inner_join(): + j = ch_join(users, orders, users.c.id == orders.c.user_id) + sql = compile_query(db.select(users.c.name).select_from(j)) + assert " INNER JOIN " in sql + assert "ALL" not in sql + assert "GLOBAL" not in sql + + +def test_standard_join_unchanged(): + j = users.join(orders, users.c.id == orders.c.user_id) + sql = compile_query(db.select(users.c.name).select_from(j)) + assert " INNER JOIN " in sql + assert "ALL" not in sql + assert "GLOBAL" not in sql + + +def test_standard_outerjoin_unchanged(): + j = users.outerjoin(orders, users.c.id == orders.c.user_id) + sql = compile_query(db.select(users.c.name).select_from(j)) + assert " LEFT OUTER JOIN " in sql + + +def test_case_insensitive_strictness(): + j = ch_join(users, orders, users.c.id == orders.c.user_id, strictness="all") + sql = compile_query(db.select(users.c.name).select_from(j)) + assert "ALL INNER JOIN" in sql + + +def test_case_insensitive_distribution(): + j = ch_join(users, orders, users.c.id == orders.c.user_id, distribution="global") + sql = compile_query(db.select(users.c.name).select_from(j)) + assert "GLOBAL INNER JOIN" in sql + + +def test_invalid_strictness_raises(): + with pytest.raises(ValueError, match="Invalid strictness"): + ch_join(users, orders, users.c.id == orders.c.user_id, strictness="PARTIAL") + + +def test_invalid_distribution_raises(): + with 
pytest.raises(ValueError, match="Invalid distribution"): + ch_join(users, orders, users.c.id == orders.c.user_id, distribution="LOCAL") + + +def test_cross_join_with_strictness_raises(): + with pytest.raises(ValueError, match="CROSS JOIN"): + ch_join(users, orders, cross=True, strictness="ALL") + + +def test_cross_join_with_onclause_raises(): + with pytest.raises(ValueError, match="cross=True conflicts"): + ch_join(users, orders, users.c.id == orders.c.user_id, cross=True) + + +def test_cross_join_with_isouter_raises(): + with pytest.raises(ValueError, match="isouter or full"): + ch_join(users, orders, cross=True, isouter=True) + + +def test_cross_join_with_full_raises(): + with pytest.raises(ValueError, match="isouter or full"): + ch_join(users, orders, cross=True, full=True) + + +def test_semi_inner_raises(): + with pytest.raises(ValueError, match="SEMI JOIN requires isouter=True"): + ch_join(users, orders, users.c.id == orders.c.user_id, strictness="SEMI") + + +def test_anti_inner_raises(): + with pytest.raises(ValueError, match="ANTI JOIN requires isouter=True"): + ch_join(users, orders, users.c.id == orders.c.user_id, strictness="ANTI") + + +def test_asof_full_join_raises(): + with pytest.raises(ValueError, match="ASOF is not supported with FULL"): + ch_join(users, orders, users.c.id == orders.c.user_id, full=True, strictness="ASOF") + + +def test_cross_join(): + j = ch_join(users, orders, cross=True) + sql = compile_query(db.select(users.c.name).select_from(j)) + assert "CROSS JOIN" in sql + assert "ON" not in sql + + +def test_global_cross_join(): + j = ch_join(users, orders, cross=True, distribution="GLOBAL") + sql = compile_query(db.select(users.c.name).select_from(j)) + assert "GLOBAL CROSS JOIN" in sql + assert "ON" not in sql + + +def test_chained_joins(): + j1 = ch_join(users, orders, users.c.id == orders.c.user_id, strictness="ALL") + j2 = ch_join(j1, items, orders.c.id == items.c.order_id, strictness="ANY") + sql = 
compile_query(db.select(users.c.name, items.c.sku).select_from(j2)) + assert "ALL INNER JOIN" in sql + assert "ANY INNER JOIN" in sql + + +def test_using_single_column(): + j = ch_join(users, orders, using=["id"]) + sql = compile_query(db.select(users.c.name).select_from(j)) + assert "INNER JOIN" in sql + assert "USING (`id`)" in sql + assert "ON" not in sql + + +def test_using_multiple_columns(): + # Use users + items which both have 'id'; add a second shared column name for the test + t1 = db.Table("t1", db.MetaData(), db.Column("a", UInt32), db.Column("b", UInt32), db.Column("x", String)) + t2 = db.Table("t2", db.MetaData(), db.Column("a", UInt32), db.Column("b", UInt32), db.Column("y", String)) + j = ch_join(t1, t2, using=["a", "b"]) + sql = compile_query(db.select(t1.c.x, t2.c.y).select_from(j)) + assert "USING (`a`, `b`)" in sql + + +def test_using_full_outer_join(): + j = ch_join(users, orders, using=["id"], full=True) + sql = compile_query(db.select(users.c.name).select_from(j)) + assert "FULL OUTER JOIN" in sql + assert "USING (`id`)" in sql + assert "ON" not in sql + + +def test_using_left_outer_join(): + j = ch_join(users, orders, using=["id"], isouter=True) + sql = compile_query(db.select(users.c.name).select_from(j)) + assert "LEFT OUTER JOIN" in sql + assert "USING (`id`)" in sql + + +def test_using_with_strictness(): + j = ch_join(users, orders, using=["id"], strictness="ANY") + sql = compile_query(db.select(users.c.name).select_from(j)) + assert "ANY INNER JOIN" in sql + assert "USING (`id`)" in sql + + +def test_using_with_distribution(): + j = ch_join(users, orders, using=["id"], distribution="GLOBAL") + sql = compile_query(db.select(users.c.name).select_from(j)) + assert "GLOBAL INNER JOIN" in sql + assert "USING (`id`)" in sql + + +def test_using_with_all_modifiers(): + j = ch_join(users, orders, using=["id"], full=True, strictness="ALL", distribution="GLOBAL") + sql = compile_query(db.select(users.c.name).select_from(j)) + assert "GLOBAL ALL 
FULL OUTER JOIN" in sql + assert "USING (`id`)" in sql + + +def test_using_with_onclause_raises(): + with pytest.raises(ValueError, match="Cannot specify both onclause and using"): + ch_join(users, orders, users.c.id == orders.c.id, using=["id"]) + + +def test_using_with_cross_raises(): + with pytest.raises(ValueError, match="cross=True conflicts with using"): + ch_join(users, orders, cross=True, using=["id"]) + + +def test_using_empty_list_raises(): + with pytest.raises(ValueError, match="non-empty list"): + ch_join(users, orders, using=[]) + + +def test_using_non_string_raises(): + with pytest.raises(ValueError, match="column name strings"): + ch_join(users, orders, using=[users.c.id]) + + +def test_using_missing_column_raises(): + with pytest.raises(ValueError, match="USING column 'missing'.*not found"): + ch_join(users, orders, using=["missing"]) + + +# pylint: disable=protected-access +def test_using_cache_key_differs_from_on(): + """USING and ON joins on the same column must produce different cache keys.""" + j_on = ch_join(users, orders, users.c.id == orders.c.id) + j_using = ch_join(users, orders, using=["id"]) + key_on = j_on._generate_cache_key() + key_using = j_using._generate_cache_key() + assert key_on != key_using diff --git a/tests/unit_tests/test_sqlalchemy/test_values.py b/tests/unit_tests/test_sqlalchemy/test_values.py new file mode 100644 index 00000000..42b383ff --- /dev/null +++ b/tests/unit_tests/test_sqlalchemy/test_values.py @@ -0,0 +1,66 @@ +from datetime import datetime + +import pytest +from sqlalchemy import DateTime as SqlaDateTime +import sqlalchemy as db + +from clickhouse_connect.cc_sqlalchemy.datatypes.sqltypes import DateTime +from clickhouse_connect.cc_sqlalchemy.dialect import ClickHouseDialect + +SA_2 = db.__version__ >= "2" + +dialect = ClickHouseDialect() + + +def compile_query(stmt): + return str(stmt.compile(dialect=dialect, compile_kwargs={"literal_binds": True})) + + +def 
test_values_renders_clickhouse_table_function_syntax(): + values_clause = db.values( + db.column("id", db.Integer), + db.column("name", db.String), + name="v", + ).data([(13, "user_1"), (29, "user_2")]) + + sql = compile_query(db.select(values_clause)) + + assert "FROM VALUES('id Int32, name String', (13, 'user_1'), (29, 'user_2')) AS `v`" in sql + assert "FROM (VALUES" not in sql + assert "AS `v` (`id`, `name`)" not in sql + + +def test_values_escapes_structure_literal_for_clickhouse_type_names(): + values_clause = db.values( + db.column("ts", DateTime("UTC")), + name="v", + ).data([("2024-01-02 03:04:05",)]) + + sql = compile_query(db.select(values_clause)) + + assert "VALUES('ts DateTime(''UTC'')', ('2024-01-02 03:04:05')) AS `v`" in sql + + +@pytest.mark.skipif(not SA_2, reason="SA 1.4 lacks literal datetime rendering for this type") +def test_values_maps_generic_sqla_datetime_type(): + values_clause = db.values( + db.column("ts", SqlaDateTime()), + name="v", + ).data([(datetime(2024, 1, 2, 3, 4, 5),)]) + + sql = compile_query(db.select(values_clause)) + + assert "VALUES('ts DateTime', ('2024-01-02 03:04:05')) AS `v`" in sql + + +@pytest.mark.skipif(not SA_2, reason="Values.cte() was added in SA 2.x") +def test_values_cte_wraps_table_function_in_select(): + values_clause = db.values( + db.column("id", db.Integer), + name="v", + ).data([(17,), (29,)]).cte("input_rows") + + sql = compile_query(db.select(values_clause.c.id).select_from(values_clause)) + + assert "WITH `input_rows`(`id`) AS" in sql + assert "(SELECT * FROM VALUES('id Int32', (17), (29)))" in sql