diff --git a/CLAUDE.md b/CLAUDE.md new file mode 100644 index 00000000..53fa7a35 --- /dev/null +++ b/CLAUDE.md @@ -0,0 +1,83 @@ +# CLAUDE.md + +This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository. + +## Project Overview + +pg_graphql is a PostgreSQL extension written in Rust using the pgrx framework. It reflects a GraphQL schema from existing SQL tables and enables GraphQL queries directly within PostgreSQL, without additional servers or processes. + +## Build Commands + +```bash +# Build and install extension (required after any Rust changes) +cargo pgrx install + +# Interactive psql with extension installed +cargo pgrx run pg16 # or pg14, pg15, pg17, pg18 +``` + +## Testing + +**IMPORTANT**: pg_regress is the primary test suite for checking regressions. Always run pg_regress tests to verify changes, not just unit tests. + +Tests use PostgreSQL's pg_regress framework with SQL files: +- Test SQL files: `test/sql/*.sql` +- Expected output: `test/expected/*.out` +- Actual output (after running): `results/*.out` + +```bash +# Preferred: Run all pg_regress tests (installs extension and runs tests) +./run_tests.sh + +# Run specific test(s) +./run_tests.sh test_name another_test + +# Alternative: Manual steps +cargo pgrx install --pg-config /opt/homebrew/opt/postgresql@17/bin/pg_config --features pg17 --no-default-features +./bin/installcheck_local + +# Run unit tests only (not sufficient for regression testing) +cargo pgrx test pg17 +``` + +When writing or editing tests: +1. Create/modify SQL in `test/sql/test_name.sql` +2. Run the test to generate output in `results/test_name.out` +3. Manually verify the output +4. Copy to `test/expected/test_name.out` to make it pass + +**Never modify expected output files** (`test/expected/*.out`) unless you have verified that the new output is correct. If a test fails, investigate and fix the code, don't change the expected output. + +## Architecture + +The extension processes GraphQL queries through this pipeline: + +1. **Entry Point** (`src/lib.rs`): The `resolve` function is exposed as `graphql._internal_resolve`. It parses the GraphQL query and orchestrates resolution. + +2. **SQL Schema Loading** (`src/sql_types.rs`): Loads PostgreSQL schema metadata (tables, columns, functions, foreign keys, permissions) into Rust structs via SQL queries in `sql/load_sql_context.sql`. + +3. **GraphQL Schema Building** (`src/graphql.rs`): Transforms SQL metadata into a GraphQL schema (`__Schema`). Tables become connection types, foreign keys become relationships, and functions can extend types. + +4. **Query Resolution** (`src/resolve.rs`): Validates the GraphQL query against the schema, handles fragments, variables, and operation selection. + +5. **SQL Transpilation** (`src/transpile.rs`): Converts validated GraphQL operations into SQL queries. Implements `QueryEntrypoint` and `MutationEntrypoint` traits that generate and execute SQL. 
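+
+For example, the whole pipeline is exercised by a single `graphql.resolve()` call
+from psql (sketch; assumes an `account` table has been reflected into the GraphQL schema):
+
+```sql
+select graphql.resolve($$
+  { accountCollection(first: 1) { edges { node { id } } } }
+$$);
+```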
+ +Key SQL files loaded as extension SQL: +- `sql/resolve.sql`: Wrapper function exposing `graphql.resolve()` +- `sql/directives.sql`: Schema comment directive parsing (`@graphql({...})`) + +## Debugging + +Use the pgrx elog macro to print debug output that appears in test `.out` files: + +```rust +pgrx_pg_sys::submodules::elog::info!("debug: {:?}", value); +``` + +## Documentation + +```bash +pip install -r docs/requirements_docs.txt +mkdocs serve +# Visit http://127.0.0.1:8000/pg_graphql/ +``` diff --git a/PLAN.md b/PLAN.md new file mode 100644 index 00000000..a3880bc0 --- /dev/null +++ b/PLAN.md @@ -0,0 +1,1323 @@ +# SQL AST Refactoring Plan for pg_graphql + +## Executive Summary + +This plan outlines a comprehensive refactoring of the pg_graphql transpiler to use a type-safe SQL Abstract Syntax Tree (AST). The new architecture will provide: + +1. **Type Safety**: Compile-time guarantees for SQL structure validity +2. **Modularity**: Clean separation between AST, rendering, and execution +3. **Future Extensibility**: Support for nested inserts requiring multiple SQL executions +4. **Debuggability**: Rich logging and telemetry throughout the pipeline +5. **Testability**: Each layer independently testable + +--- + +## Current Architecture Analysis + +### Current Flow +``` +GraphQL Query → Parser (resolve.rs) + → Builder Creation (builder.rs) + → SQL String Generation (transpile.rs) - uses format!() strings + → Execution via pgrx SPI +``` + +### Problems with Current Approach + +1. **String-based SQL Generation**: Uses `format!()` macros throughout `transpile.rs`, making it easy to create malformed SQL +2. **Tight pgrx Coupling**: `ParamContext` is tightly coupled to pgrx's `DatumWithOid` +3. **No Intermediate Representation**: No way to inspect/validate/transform SQL before rendering +4. **Single Execution Assumption**: Current design assumes one SQL statement per request +5. **Limited Debugging**: No structured way to log generated SQL or execution plans + +--- + +## Proposed Architecture + +### New Flow +``` +GraphQL Query → Parser (resolve.rs) + → Builder Creation (builder.rs) + → AST Construction (NEW: src/ast/mod.rs) + → SQL Rendering (NEW: src/ast/render.rs) + → Execution Plan (NEW: src/executor/mod.rs) + → Execution via pgrx SPI +``` + +### Module Structure + +``` +src/ +├── ast/ # NEW: Standalone SQL AST module +│ ├── mod.rs # Public API, re-exports +│ ├── expr.rs # Expression types (columns, literals, operators) +│ ├── stmt.rs # Statement types (SELECT, INSERT, UPDATE, DELETE) +│ ├── cte.rs # CTE (WITH clause) support +│ ├── types.rs # SQL type representations +│ ├── render.rs # SQL string rendering (the only place SQL strings are built) +│ ├── params.rs # Parameter handling (decoupled from pgrx) +│ └── validate.rs # AST validation utilities +│ +├── executor/ # NEW: Execution layer +│ ├── mod.rs # Public API +│ ├── plan.rs # Execution plans (single or multi-statement) +│ ├── pgrx_backend.rs # pgrx-specific execution +│ └── telemetry.rs # Logging and metrics +│ +├── transpile.rs # REFACTOR: Now builds AST instead of strings +├── builder.rs # UNCHANGED initially +└── ... +``` + +--- + +## Phase 1: Core AST Module (Standalone, No pgrx Dependencies) + +### 1.1 Expression Types (`src/ast/expr.rs`) + +```rust +/// SQL expressions - the building blocks +#[derive(Debug, Clone, PartialEq)] +pub enum Expr { + /// Column reference: table.column or just column + Column(ColumnRef), + + /// Literal value with type information + Literal(Literal), + + /// Parameterized value: $1, $2, etc. 
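+    /// (rendered with an explicit cast, e.g. `($1::text)` -- see `ParamRef::type_cast`)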
+    Param(ParamRef),
+
+    /// Binary operation: expr op expr
+    BinaryOp {
+        left: Box<Expr>,
+        op: BinaryOperator,
+        right: Box<Expr>,
+    },
+
+    /// Unary operation: op expr (e.g., NOT)
+    UnaryOp {
+        op: UnaryOperator,
+        expr: Box<Expr>,
+    },
+
+    /// Function call: func(args)
+    FunctionCall(FunctionCall),
+
+    /// CASE WHEN ... THEN ... ELSE ... END
+    Case(CaseExpr),
+
+    /// Subquery: (SELECT ...)
+    Subquery(Box<SelectStmt>),
+
+    /// Array construction: ARRAY[...]
+    Array(Vec<Expr>),
+
+    /// Type cast: expr::type
+    Cast {
+        expr: Box<Expr>,
+        target_type: SqlType,
+    },
+
+    /// IS NULL / IS NOT NULL
+    IsNull {
+        expr: Box<Expr>,
+        negated: bool,
+    },
+
+    /// expr IN (values) or expr = ANY(array)
+    InList {
+        expr: Box<Expr>,
+        list: Vec<Expr>,
+        negated: bool,
+    },
+
+    /// Aggregate function with optional filter
+    Aggregate(AggregateExpr),
+
+    /// JSON/JSONB building functions
+    JsonBuild(JsonBuildExpr),
+
+    /// Raw SQL (escape hatch, should be minimized)
+    Raw(String),
+}
+
+#[derive(Debug, Clone, PartialEq)]
+pub struct ColumnRef {
+    pub table_alias: Option<Ident>,
+    pub column: Ident,
+}
+
+#[derive(Debug, Clone, PartialEq)]
+pub struct ParamRef {
+    pub index: usize,
+    pub type_cast: SqlType,
+}
+
+#[derive(Debug, Clone, PartialEq)]
+pub enum Literal {
+    Null,
+    Bool(bool),
+    Integer(i64),
+    Float(f64),
+    String(String),
+    Default, // SQL DEFAULT keyword
+}
+
+#[derive(Debug, Clone, PartialEq)]
+pub enum BinaryOperator {
+    // Comparison
+    Eq, NotEq, Lt, LtEq, Gt, GtEq,
+    // Array
+    Contains, ContainedBy, Overlap, Any,
+    // String
+    Like, ILike, RegEx, IRegEx, StartsWith,
+    // Logical
+    And, Or,
+    // Arithmetic
+    Add, Sub, Mul, Div,
+    // JSON
+    JsonExtract, JsonExtractText,
+}
+
+#[derive(Debug, Clone, PartialEq)]
+pub struct FunctionCall {
+    pub schema: Option<Ident>,
+    pub name: Ident,
+    pub args: Vec<FunctionArg>,
+    pub filter: Option<Box<Expr>>,
+    pub order_by: Option<Vec<OrderByExpr>>,
+}
+
+#[derive(Debug, Clone, PartialEq)]
+pub enum FunctionArg {
+    Unnamed(Expr),
+    Named { name: Ident, value: Expr },
+}
+
+#[derive(Debug, Clone, PartialEq)]
+pub struct AggregateExpr {
+    pub function: AggregateFunction,
+    pub args: Vec<Expr>,
+    pub filter: Option<Box<Expr>>,
+    pub order_by: Option<Vec<OrderByExpr>>,
+}
+
+#[derive(Debug, Clone, PartialEq)]
+pub enum AggregateFunction {
+    Count,
+    Sum,
+    Avg,
+    Min,
+    Max,
+    JsonAgg,
+    JsonbAgg,
+    ArrayAgg,
+    BoolAnd,
+    BoolOr,
+    Coalesce,
+}
+
+#[derive(Debug, Clone, PartialEq)]
+pub enum JsonBuildExpr {
+    Object(Vec<(Expr, Expr)>), // key-value pairs
+    Array(Vec<Expr>),
+}
+
+/// A quoted identifier
+#[derive(Debug, Clone, PartialEq, Eq, Hash)]
+pub struct Ident(pub String);
+```
+
+### 1.2 Statement Types (`src/ast/stmt.rs`)
+
+```rust
+/// SQL statements
+#[derive(Debug, Clone)]
+pub enum Stmt {
+    Select(SelectStmt),
+    Insert(InsertStmt),
+    Update(UpdateStmt),
+    Delete(DeleteStmt),
+}
+
+#[derive(Debug, Clone)]
+pub struct SelectStmt {
+    pub ctes: Vec<Cte>,
+    pub columns: Vec<SelectColumn>,
+    pub from: Option<FromClause>,
+    pub where_clause: Option<Expr>,
+    pub group_by: Vec<Expr>,
+    pub having: Option<Expr>,
+    pub order_by: Vec<OrderByExpr>,
+    pub limit: Option<Expr>,
+    pub offset: Option<Expr>,
+}
+
+#[derive(Debug, Clone)]
+pub enum SelectColumn {
+    Expr { expr: Expr, alias: Option<Ident> },
+    AllFrom { table: Ident },
+}
+
+#[derive(Debug, Clone)]
+pub enum FromClause {
+    Table {
+        schema: Option<Ident>,
+        name: Ident,
+        alias: Option<Ident>,
+    },
+    Subquery {
+        query: Box<SelectStmt>,
+        alias: Ident,
+    },
+    Function {
+        call: FunctionCall,
+        alias: Ident,
+    },
+    Join {
+        left: Box<FromClause>,
+        join_type: JoinType,
+        right: Box<FromClause>,
+        on: Option<Expr>,
+    },
+    CrossJoin {
+        left: Box<FromClause>,
+        right: Box<FromClause>,
+    },
+}
+
+#[derive(Debug, Clone)]
+pub enum JoinType {
+    Inner,
+    Left,
+    Right,
+    Full,
+}
+
+#[derive(Debug, Clone)]
+pub struct OrderByExpr {
+    pub expr: Expr,
+    pub direction: Option<OrderDirection>,
+    pub nulls: Option<NullsOrder>,
+}
+
+#[derive(Debug, Clone)]
+pub enum OrderDirection {
+    Asc,
+    Desc,
+}
+
+#[derive(Debug, Clone)]
+pub enum NullsOrder {
+    First,
+    Last,
+}
+
+#[derive(Debug, Clone)]
+pub struct InsertStmt {
+    pub ctes: Vec<Cte>,
+    pub schema: Option<Ident>,
+    pub table: Ident,
+    pub columns: Vec<Ident>,
+    pub values: InsertValues,
+    pub returning: Vec<SelectColumn>,
+}
+
+#[derive(Debug, Clone)]
+pub enum InsertValues {
+    Values(Vec<Vec<Expr>>), // Multiple rows
+    Query(Box<SelectStmt>),
+}
+
+#[derive(Debug, Clone)]
+pub struct UpdateStmt {
+    pub ctes: Vec<Cte>,
+    pub schema: Option<Ident>,
+    pub table: Ident,
+    pub alias: Option<Ident>,
+    pub set: Vec<(Ident, Expr)>,
+    pub where_clause: Option<Expr>,
+    pub returning: Vec<SelectColumn>,
+}
+
+#[derive(Debug, Clone)]
+pub struct DeleteStmt {
+    pub ctes: Vec<Cte>,
+    pub schema: Option<Ident>,
+    pub table: Ident,
+    pub alias: Option<Ident>,
+    pub where_clause: Option<Expr>,
+    pub returning: Vec<SelectColumn>,
+}
+```
+
+### 1.3 CTE Support (`src/ast/cte.rs`)
+
+```rust
+#[derive(Debug, Clone)]
+pub struct Cte {
+    pub name: Ident,
+    pub columns: Option<Vec<Ident>>,
+    pub query: CteQuery,
+    pub materialized: Option<bool>,
+}
+
+#[derive(Debug, Clone)]
+pub enum CteQuery {
+    Select(SelectStmt),
+    Insert(InsertStmt),
+    Update(UpdateStmt),
+    Delete(DeleteStmt),
+}
+```
+
+### 1.4 Type System (`src/ast/types.rs`)
+
+```rust
+#[derive(Debug, Clone, PartialEq)]
+pub struct SqlType {
+    pub schema: Option<String>,
+    pub name: String,
+    pub oid: Option<u32>, // PostgreSQL OID when known
+    pub is_array: bool,
+}
+
+impl SqlType {
+    pub fn text() -> Self {
+        Self { schema: None, name: "text".into(), oid: Some(25), is_array: false }
+    }
+
+    pub fn integer() -> Self {
+        Self { schema: None, name: "integer".into(), oid: Some(23), is_array: false }
+    }
+
+    pub fn bigint() -> Self {
+        Self { schema: None, name: "bigint".into(), oid: Some(20), is_array: false }
+    }
+
+    pub fn jsonb() -> Self {
+        Self { schema: None, name: "jsonb".into(), oid: Some(3802), is_array: false }
+    }
+
+    pub fn boolean() -> Self {
+        Self { schema: None, name: "boolean".into(), oid: Some(16), is_array: false }
+    }
+
+    pub fn array_of(base: Self) -> Self {
+        Self { is_array: true, ..base }
+    }
+
+    pub fn custom(schema: Option<String>, name: String) -> Self {
+        Self { schema, name, oid: None, is_array: false }
+    }
+}
+```
+
+### 1.5 Parameter Handling (`src/ast/params.rs`)
+
+```rust
+/// A parameter value that can be rendered to SQL
+/// Decoupled from pgrx - this is the AST's view of parameters
+#[derive(Debug, Clone)]
+pub struct Param {
+    pub index: usize,
+    pub value: ParamValue,
+    pub sql_type: SqlType,
+}
+
+#[derive(Debug, Clone)]
+pub enum ParamValue {
+    Null,
+    Bool(bool),
+    String(String),
+    Integer(i64),
+    Float(f64),
+    Array(Vec<ParamValue>),
+    Json(serde_json::Value),
+}
+
+/// Collects parameters during AST construction
+#[derive(Debug, Default)]
+pub struct ParamCollector {
+    params: Vec<Param>,
+}
+
+impl ParamCollector {
+    pub fn new() -> Self {
+        Self { params: Vec::new() }
+    }
+
+    /// Add a parameter and return its reference expression
+    pub fn add(&mut self, value: ParamValue, sql_type: SqlType) -> Expr {
+        let index = self.params.len() + 1; // 1-indexed for SQL
+        self.params.push(Param { index, value, sql_type: sql_type.clone() });
+        Expr::Param(ParamRef { index, type_cast: sql_type })
+    }
+
+    /// Get all collected parameters
+    pub fn into_params(self) -> Vec<Param> {
+        self.params
+    }
+
+    /// Get parameters as a slice
+    pub fn params(&self) -> &[Param] {
+        &self.params
+    }
+}
+```
+
+### 1.6 SQL Rendering (`src/ast/render.rs`)
+
+```rust
+use std::fmt::Write;
+
+pub struct SqlRenderer
{ + output: String, + indent_level: usize, + pretty: bool, +} + +impl SqlRenderer { + pub fn new() -> Self { + Self { output: String::new(), indent_level: 0, pretty: false } + } + + pub fn pretty() -> Self { + Self { output: String::new(), indent_level: 0, pretty: true } + } + + pub fn render_stmt(&mut self, stmt: &Stmt) -> &str { + match stmt { + Stmt::Select(s) => self.render_select(s), + Stmt::Insert(s) => self.render_insert(s), + Stmt::Update(s) => self.render_update(s), + Stmt::Delete(s) => self.render_delete(s), + } + &self.output + } + + fn render_select(&mut self, stmt: &SelectStmt) { + self.render_ctes(&stmt.ctes); + self.write("SELECT "); + self.render_select_columns(&stmt.columns); + if let Some(from) = &stmt.from { + self.newline(); + self.write("FROM "); + self.render_from(from); + } + if let Some(where_clause) = &stmt.where_clause { + self.newline(); + self.write("WHERE "); + self.render_expr(where_clause); + } + // ... group_by, having, order_by, limit, offset + } + + fn render_expr(&mut self, expr: &Expr) { + match expr { + Expr::Column(col) => { + if let Some(table) = &col.table_alias { + self.write_ident(table); + self.write("."); + } + self.write_ident(&col.column); + } + Expr::Param(p) => { + write!(self.output, "(${}", p.index).unwrap(); + self.write("::"); + self.render_type(&p.type_cast); + self.write(")"); + } + Expr::BinaryOp { left, op, right } => { + self.render_expr(left); + self.write(" "); + self.render_binary_op(op); + self.write(" "); + self.render_expr(right); + } + // ... all other expression types + _ => todo!("Render {:?}", expr), + } + } + + fn write_ident(&mut self, ident: &Ident) { + // Use PostgreSQL's quote_ident rules + // For now, always quote to be safe + write!(self.output, "\"{}\"", ident.0.replace('"', "\"\"")).unwrap(); + } + + fn write(&mut self, s: &str) { + self.output.push_str(s); + } + + fn newline(&mut self) { + if self.pretty { + self.output.push('\n'); + for _ in 0..self.indent_level { + self.output.push_str(" "); + } + } else { + self.output.push(' '); + } + } +} + +/// Convenience function for quick rendering +pub fn render(stmt: &Stmt) -> String { + let mut renderer = SqlRenderer::new(); + renderer.render_stmt(stmt); + renderer.output +} + +pub fn render_pretty(stmt: &Stmt) -> String { + let mut renderer = SqlRenderer::pretty(); + renderer.render_stmt(stmt); + renderer.output +} +``` + +--- + +## Phase 2: Executor Module + +### 2.1 Execution Plans (`src/executor/plan.rs`) + +```rust +use crate::ast::{Stmt, Param}; + +/// An execution plan that may contain one or more SQL statements +#[derive(Debug)] +pub struct ExecutionPlan { + pub steps: Vec, + pub telemetry: PlanTelemetry, +} + +#[derive(Debug)] +pub struct ExecutionStep { + pub id: String, + pub stmt: Stmt, + pub params: Vec, + pub description: String, + pub depends_on: Vec, // IDs of steps this depends on +} + +#[derive(Debug, Default)] +pub struct PlanTelemetry { + pub graphql_query: Option, + pub operation_name: Option, + pub created_at: std::time::Instant, +} + +impl ExecutionPlan { + pub fn single(stmt: Stmt, params: Vec, description: &str) -> Self { + Self { + steps: vec![ExecutionStep { + id: "main".to_string(), + stmt, + params, + description: description.to_string(), + depends_on: vec![], + }], + telemetry: PlanTelemetry::default(), + } + } + + /// For future nested inserts: create a multi-step plan + pub fn multi(steps: Vec) -> Self { + Self { + steps, + telemetry: PlanTelemetry::default(), + } + } + + pub fn with_graphql_context(mut self, query: &str, operation_name: 
Option<&str>) -> Self {
+        self.telemetry.graphql_query = Some(query.to_string());
+        self.telemetry.operation_name = operation_name.map(|s| s.to_string());
+        self
+    }
+}
+```
+
+### 2.2 pgrx Backend (`src/executor/pgrx_backend.rs`)
+
+```rust
+use crate::ast::{Param, ParamValue, SqlType, render};
+use crate::executor::ExecutionPlan;
+use crate::error::{GraphQLError, GraphQLResult};
+use pgrx::prelude::*;
+use pgrx::datum::DatumWithOid;
+use pgrx::spi::SpiClient;
+
+/// Converts AST parameters to pgrx Datums
+pub fn params_to_datums(params: &[Param]) -> GraphQLResult<Vec<DatumWithOid<'static>>> {
+    params.iter().map(param_to_datum).collect()
+}
+
+fn param_to_datum(param: &Param) -> GraphQLResult<DatumWithOid<'static>> {
+    let datum = match &param.value {
+        ParamValue::Null => None::<String>.into_datum(),
+        ParamValue::Bool(b) => b.to_string().into_datum(),
+        ParamValue::String(s) => s.clone().into_datum(),
+        ParamValue::Integer(i) => i.to_string().into_datum(),
+        ParamValue::Float(f) => f.to_string().into_datum(),
+        ParamValue::Array(arr) => {
+            let strings: Vec<Option<String>> = arr.iter()
+                .map(|v| match v {
+                    ParamValue::Null => None,
+                    ParamValue::String(s) => Some(s.clone()),
+                    ParamValue::Integer(i) => Some(i.to_string()),
+                    ParamValue::Float(f) => Some(f.to_string()),
+                    ParamValue::Bool(b) => Some(b.to_string()),
+                    _ => None, // Nested arrays not supported
+                })
+                .collect();
+            strings.into_datum()
+        }
+        ParamValue::Json(v) => v.to_string().into_datum(),
+    };
+
+    let oid = if param.sql_type.is_array {
+        pgrx::pg_sys::TEXTARRAYOID
+    } else {
+        pgrx::pg_sys::TEXTOID
+    };
+
+    Ok(unsafe { DatumWithOid::new(datum, oid) })
+}
+
+/// Execute a query plan and return JSON result
+pub fn execute_query(plan: &ExecutionPlan) -> GraphQLResult<serde_json::Value> {
+    // Log the plan for debugging
+    log_execution_plan(plan);
+
+    if plan.steps.len() != 1 {
+        return Err(GraphQLError::internal(
+            "Query execution currently only supports single-step plans"
+        ));
+    }
+
+    let step = &plan.steps[0];
+    let sql = render(&step.stmt);
+    let datums = params_to_datums(&step.params)?;
+
+    log_sql_execution(&sql, &step.params);
+
+    Spi::connect(|client| {
+        let result = client.select(&sql, Some(1), &datums)?;
+        if result.is_empty() {
+            Ok(serde_json::Value::Null)
+        } else {
+            let jsonb: pgrx::JsonB = result.first().get(1)?
+                .ok_or_else(|| GraphQLError::internal("No result from query"))?;
+            Ok(jsonb.0)
+        }
+    }).map_err(|e| GraphQLError::sql_execution(format!("SPI error: {:?}", e)))
+}
+
+/// Execute a mutation plan and return JSON result
+pub fn execute_mutation<'conn>(
+    plan: &ExecutionPlan,
+    client: &mut SpiClient<'conn>,
+) -> GraphQLResult<serde_json::Value> {
+    log_execution_plan(plan);
+
+    if plan.steps.len() != 1 {
+        // TODO: For nested inserts, we'll need to execute multiple steps
+        return Err(GraphQLError::internal(
+            "Mutation execution currently only supports single-step plans"
+        ));
+    }
+
+    let step = &plan.steps[0];
+    let sql = render(&step.stmt);
+    let datums = params_to_datums(&step.params)?;
+
+    log_sql_execution(&sql, &step.params);
+
+    let result = client.update(&sql, None, &datums)
+        .map_err(|_| GraphQLError::sql_execution("Failed to execute mutation"))?;
+
+    let jsonb: pgrx::JsonB = result.first().get(1)?
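+        // SPI rows are 1-indexed; column 1 carries the jsonb document built by the generated SELECT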
+ .ok_or_else(|| GraphQLError::internal("No result from mutation"))?; + + Ok(jsonb.0) +} +``` + +### 2.3 Telemetry (`src/executor/telemetry.rs`) + +```rust +use crate::ast::{Param, render_pretty}; +use crate::executor::ExecutionPlan; + +/// Log level for SQL telemetry +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum LogLevel { + Off, + Basic, // Just SQL and timing + Detailed, // SQL + parameters + Debug, // Everything including AST dump +} + +/// Get current log level from GUC or environment +pub fn get_log_level() -> LogLevel { + // TODO: Read from pg_graphql.log_level GUC + // For now, check environment variable + match std::env::var("PG_GRAPHQL_LOG_LEVEL").as_deref() { + Ok("off") => LogLevel::Off, + Ok("basic") => LogLevel::Basic, + Ok("detailed") => LogLevel::Detailed, + Ok("debug") => LogLevel::Debug, + _ => LogLevel::Off, + } +} + +pub fn log_execution_plan(plan: &ExecutionPlan) { + let level = get_log_level(); + if level == LogLevel::Off { + return; + } + + pgrx::info!( + "pg_graphql: Executing plan with {} step(s)", + plan.steps.len() + ); + + if level >= LogLevel::Detailed { + if let Some(query) = &plan.telemetry.graphql_query { + pgrx::info!("pg_graphql: GraphQL query:\n{}", query); + } + } +} + +pub fn log_sql_execution(sql: &str, params: &[Param]) { + let level = get_log_level(); + if level == LogLevel::Off { + return; + } + + pgrx::info!("pg_graphql: Executing SQL:\n{}", sql); + + if level >= LogLevel::Detailed { + for param in params { + pgrx::info!( + "pg_graphql: Param ${}: {:?} ({})", + param.index, + param.value, + param.sql_type.name + ); + } + } +} + +pub fn log_execution_result(duration_ms: u64, row_count: usize) { + let level = get_log_level(); + if level >= LogLevel::Basic { + pgrx::info!( + "pg_graphql: Execution completed in {}ms, {} rows", + duration_ms, + row_count + ); + } +} +``` + +--- + +## Phase 3: Refactor Transpile Module + +### 3.1 New Transpiler Architecture + +The transpiler will be refactored to build AST nodes instead of strings. Each builder's `to_sql()` method becomes `to_ast()`. 
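+
+Where the current transpiler interpolates identifiers and structure directly into
+strings, the AST path assembles typed nodes and leaves quoting to the renderer. A
+minimal before/after sketch (identifier names illustrative, not actual code):
+
+```rust
+// Old style (string-based): quoting and SQL shape are the caller's problem.
+// format!("select * from {}.{} as {}", quote_ident(schema), quote_ident(table), alias)
+
+// New style (AST-based): shape is enforced by types, quoting by the renderer.
+let from = FromClause::Table {
+    schema: Some(Ident("public".into())),
+    name: Ident("account".into()),
+    alias: Some(Ident("b0".into())),
+};
+```
+
+The refactored module structure: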
+
+```rust
+// src/transpile.rs - new structure
+
+use crate::ast::*;
+use crate::builder::*;
+use crate::error::GraphQLResult;
+use crate::executor::{ExecutionPlan, ExecutionStep};
+
+/// Trait for types that can be transpiled to an execution plan
+pub trait ToExecutionPlan {
+    fn to_plan(&self) -> GraphQLResult<ExecutionPlan>;
+}
+
+/// Trait for types that can be transpiled to an AST expression or statement fragment
+pub trait ToAst {
+    type Output;
+    fn to_ast(&self, params: &mut ParamCollector) -> GraphQLResult<Self::Output>;
+}
+
+impl ToExecutionPlan for InsertBuilder {
+    fn to_plan(&self) -> GraphQLResult<ExecutionPlan> {
+        let mut params = ParamCollector::new();
+        let stmt = self.to_ast(&mut params)?;
+        Ok(ExecutionPlan::single(
+            Stmt::Select(stmt), // INSERT wrapped in CTE returns via SELECT
+            params.into_params(),
+            "Insert mutation"
+        ))
+    }
+}
+
+impl ToAst for InsertBuilder {
+    type Output = SelectStmt;
+
+    fn to_ast(&self, params: &mut ParamCollector) -> GraphQLResult<Self::Output> {
+        let block_name = Ident(rand_block_name_raw());
+
+        // Build the INSERT CTE
+        let insert_cte = self.build_insert_cte(params)?;
+
+        // Build the SELECT that reads from the CTE
+        let select_columns = self.build_select_columns(&block_name, params)?;
+
+        Ok(SelectStmt {
+            ctes: vec![insert_cte],
+            columns: select_columns,
+            from: Some(FromClause::Table {
+                schema: None,
+                name: Ident("affected".to_string()),
+                alias: Some(block_name),
+            }),
+            where_clause: None,
+            group_by: vec![],
+            having: None,
+            order_by: vec![],
+            limit: None,
+            offset: None,
+        })
+    }
+}
+
+impl InsertBuilder {
+    fn build_insert_cte(&self, params: &mut ParamCollector) -> GraphQLResult<Cte> {
+        let columns: Vec<Ident> = self.referenced_columns()
+            .iter()
+            .map(|c| Ident(c.name.clone()))
+            .collect();
+
+        let values: Vec<Vec<Expr>> = self.objects
+            .iter()
+            .map(|row| self.row_to_exprs(row, params))
+            .collect::<GraphQLResult<_>>()?;
+
+        let returning = self.table.columns
+            .iter()
+            .filter(|c| c.permissions.is_selectable)
+            .map(|c| SelectColumn::Expr {
+                expr: Expr::Column(ColumnRef {
+                    table_alias: None,
+                    column: Ident(c.name.clone()),
+                }),
+                alias: None,
+            })
+            .collect();
+
+        Ok(Cte {
+            name: Ident("affected".to_string()),
+            columns: None,
+            query: CteQuery::Insert(InsertStmt {
+                ctes: vec![],
+                schema: Some(Ident(self.table.schema.clone())),
+                table: Ident(self.table.name.clone()),
+                columns,
+                values: InsertValues::Values(values),
+                returning,
+            }),
+            materialized: None,
+        })
+    }
+
+    fn row_to_exprs(
+        &self,
+        row: &InsertRowBuilder,
+        params: &mut ParamCollector,
+    ) -> GraphQLResult<Vec<Expr>> {
+        self.referenced_columns()
+            .iter()
+            .map(|col| {
+                match row.row.get(&col.name) {
+                    None | Some(InsertElemValue::Default) => Ok(Expr::Literal(Literal::Default)),
+                    Some(InsertElemValue::Value(val)) => {
+                        let param_value = json_to_param_value(val)?;
+                        let sql_type = SqlType::custom(None, col.type_name.clone());
+                        Ok(params.add(param_value, sql_type))
+                    }
+                }
+            })
+            .collect()
+    }
+}
+```
+
+### 3.2 Filter Expression Building
+
+```rust
+impl ToAst for FilterBuilder {
+    type Output = Expr;
+
+    fn to_ast(&self, params: &mut ParamCollector) -> GraphQLResult<Self::Output> {
+        if self.elems.is_empty() {
+            return Ok(Expr::Literal(Literal::Bool(true)));
+        }
+
+        let exprs: Vec<Expr> = self.elems
+            .iter()
+            .map(|elem| elem.to_ast(params))
+            .collect::<GraphQLResult<_>>()?;
+
+        // Combine with AND
+        Ok(exprs.into_iter().reduce(|acc, expr| {
+            Expr::BinaryOp {
+                left: Box::new(acc),
+                op: BinaryOperator::And,
+                right: Box::new(expr),
+            }
+        }).unwrap_or(Expr::Literal(Literal::Bool(true))))
+    }
+}
+
+impl ToAst for FilterBuilderElem {
+    type Output = Expr;
+
+    fn to_ast(&self, params: &mut ParamCollector) -> GraphQLResult<Self::Output> {
+        match self {
+            Self::Column { column, op, value } => {
+                let col_expr = Expr::Column(ColumnRef {
+                    table_alias: None, // Will be set by caller
+                    column: Ident(column.name.clone()),
+                });
+
+                match op {
+                    FilterOp::Is => {
+                        let is_null = match value.as_str() {
+                            Some("NULL") => true,
+                            Some("NOT_NULL") => false,
+                            _ => return Err(GraphQLError::sql_generation("Invalid IS filter value")),
+                        };
+                        Ok(Expr::IsNull {
+                            expr: Box::new(col_expr),
+                            negated: !is_null,
+                        })
+                    }
+                    _ => {
+                        let sql_type = self.determine_type(column, op);
+                        let param_value = json_to_param_value(value)?;
+                        let param_expr = params.add(param_value, sql_type);
+
+                        Ok(Expr::BinaryOp {
+                            left: Box::new(col_expr),
+                            op: filter_op_to_binary_op(op),
+                            right: Box::new(param_expr),
+                        })
+                    }
+                }
+            }
+            Self::NodeId(node_id) => node_id.to_ast(params),
+            Self::Compound(compound) => compound.to_ast(params),
+        }
+    }
+}
+
+fn filter_op_to_binary_op(op: &FilterOp) -> BinaryOperator {
+    match op {
+        FilterOp::Equal => BinaryOperator::Eq,
+        FilterOp::NotEqual => BinaryOperator::NotEq,
+        FilterOp::LessThan => BinaryOperator::Lt,
+        FilterOp::LessThanEqualTo => BinaryOperator::LtEq,
+        FilterOp::GreaterThan => BinaryOperator::Gt,
+        FilterOp::GreaterThanEqualTo => BinaryOperator::GtEq,
+        FilterOp::In => BinaryOperator::Any,
+        FilterOp::StartsWith => BinaryOperator::StartsWith,
+        FilterOp::Like => BinaryOperator::Like,
+        FilterOp::ILike => BinaryOperator::ILike,
+        FilterOp::RegEx => BinaryOperator::RegEx,
+        FilterOp::IRegEx => BinaryOperator::IRegEx,
+        FilterOp::Contains => BinaryOperator::Contains,
+        FilterOp::ContainedBy => BinaryOperator::ContainedBy,
+        FilterOp::Overlap => BinaryOperator::Overlap,
+        FilterOp::Is => unreachable!("Is handled separately"),
+    }
+}
+```
+
+---
+
+## Phase 4: Testing Strategy
+
+### 4.1 Unit Tests for AST Module
+
+Create `src/ast/tests.rs`:
+
+```rust
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn test_render_simple_select() {
+        let stmt = Stmt::Select(SelectStmt {
+            ctes: vec![],
+            columns: vec![SelectColumn::Expr {
+                expr: Expr::Column(ColumnRef {
+                    table_alias: Some(Ident("t".to_string())),
+                    column: Ident("id".to_string()),
+                }),
+                alias: None,
+            }],
+            from: Some(FromClause::Table {
+                schema: Some(Ident("public".to_string())),
+                name: Ident("users".to_string()),
+                alias: Some(Ident("t".to_string())),
+            }),
+            where_clause: None,
+            group_by: vec![],
+            having: None,
+            order_by: vec![],
+            limit: None,
+            offset: None,
+        });
+
+        let sql = render(&stmt);
+        assert!(sql.contains("SELECT"));
+        assert!(sql.contains("\"t\".\"id\""));
+        assert!(sql.contains("FROM \"public\".\"users\""));
+    }
+
+    #[test]
+    fn test_render_insert_with_returning() {
+        // Test INSERT statement rendering
+    }
+
+    #[test]
+    fn test_param_collector() {
+        let mut collector = ParamCollector::new();
+        let expr1 = collector.add(ParamValue::String("hello".into()), SqlType::text());
+        let expr2 = collector.add(ParamValue::Integer(42), SqlType::integer());
+
+        assert_eq!(collector.params().len(), 2);
+        match &expr1 {
+            Expr::Param(p) => assert_eq!(p.index, 1),
+            _ => panic!("Expected Param"),
+        }
+    }
+
+    #[test]
+    fn test_jsonb_build_object_chunking() {
+        // Test that large objects are properly chunked with ||
+    }
+}
+```
+
+### 4.2 Integration with Existing pg_regress Tests
+
+All existing tests in `test/sql/` remain unchanged. The refactoring should produce **identical SQL output** (modulo whitespace) for all existing queries.
+ +We'll add a test helper that compares old and new transpiler output: + +```rust +// In development only: compare outputs +#[cfg(feature = "transpiler_validation")] +fn validate_transpiler_output( + old_sql: &str, + new_sql: &str, +) { + // Normalize whitespace and compare + let normalize = |s: &str| s.split_whitespace().collect::>().join(" "); + assert_eq!(normalize(old_sql), normalize(new_sql)); +} +``` + +### 4.3 New AST-Specific Tests + +Create new test files: + +- `test/sql/ast_basic.sql` - Basic AST rendering tests +- `test/sql/ast_cte.sql` - CTE generation tests +- `test/sql/ast_params.sql` - Parameter handling tests +- `test/sql/telemetry.sql` - Logging/telemetry tests + +--- + +## Phase 5: Implementation Order + +### Step 1: Create AST Module Foundation (No Breaking Changes) +1. Create `src/ast/mod.rs` with module structure +2. Implement `expr.rs` with all expression types +3. Implement `stmt.rs` with statement types +4. Implement `cte.rs` for CTE support +5. Implement `types.rs` for SQL types +6. Implement `params.rs` for parameter collection +7. Implement `render.rs` for SQL generation +8. Add comprehensive unit tests + +**Verification**: Run `cargo test` - all AST unit tests pass + +### Step 2: Create Executor Module (No Breaking Changes) +1. Create `src/executor/mod.rs` +2. Implement `plan.rs` for execution plans +3. Implement `telemetry.rs` for logging +4. Implement `pgrx_backend.rs` for pgrx integration +5. Add unit tests + +**Verification**: Run `cargo test` - all executor unit tests pass + +### Step 3: Add Parallel Transpilation Path +1. Add `ToAst` trait implementations alongside existing `to_sql()` methods +2. Start with `InsertBuilder` as it's well-understood +3. Add feature flag to switch between old and new paths +4. Compare outputs in development mode + +**Verification**: +- `cargo pgrx install --features pg18` +- `./bin/installcheck mutation_insert` + +### Step 4: Incrementally Migrate Builders +1. Migrate `FilterBuilder` and `FilterBuilderElem` +2. Migrate `UpdateBuilder` +3. Migrate `DeleteBuilder` +4. Migrate `NodeBuilder` +5. Migrate `ConnectionBuilder` (most complex) +6. Migrate `FunctionCallBuilder` + +For each builder: +- Implement `ToAst` trait +- Run corresponding tests +- Compare output with old implementation + +**Verification after each**: Run relevant `./bin/installcheck` tests + +### Step 5: Remove Old String-Based Code +1. Remove feature flag +2. Delete old `to_sql()` methods +3. Update `resolve.rs` to use new executor +4. Clean up any remaining string-based SQL generation + +**Verification**: Run full test suite: `./bin/installcheck` + +### Step 6: Add Multi-Statement Support (Future Extensibility) +1. Extend `ExecutionPlan` for dependent steps +2. Implement step ordering in executor +3. Add transaction handling for multi-statement plans +4. Document API for nested inserts + +--- + +## Design Decisions + +### Why Text-Based Parameters? +The current approach converts all parameters to text and lets PostgreSQL cast them. We maintain this pattern because: +1. It's proven to work reliably +2. PostgreSQL's type coercion is well-tested +3. It avoids complex Datum construction +4. The generated SQL is readable and debuggable + +### Why CTE-Based Mutations? +CTEs provide atomic execution and allow us to: +1. Check affected row counts before committing +2. Return data from the modified rows +3. Compose complex operations +4. Support future features like nested inserts + +### Why Separate Render Phase? +Having a dedicated render phase allows: +1. 
AST inspection/validation before execution +2. Different rendering styles (compact vs pretty) +3. SQL logging and debugging +4. Future optimizations at the AST level + +### Identifier Quoting Strategy +All identifiers are quoted by default using PostgreSQL's rules: +- Double quotes around the identifier +- Internal double quotes escaped by doubling +- This prevents SQL injection and handles special characters + +--- + +## PostgreSQL Version Compatibility + +The AST module must work with PostgreSQL 14-18. Key considerations: + +1. **Standard SQL constructs only**: The AST uses standard PostgreSQL features available in all supported versions +2. **No version-specific features**: Avoid using features only in newer PostgreSQL versions +3. **Test matrix**: CI runs against all supported versions + +--- + +## Risk Mitigation + +### Risk: Breaking Existing Functionality +**Mitigation**: +- Parallel implementation with feature flag +- Output comparison in development +- Full regression test suite +- Incremental migration + +### Risk: Performance Regression +**Mitigation**: +- Benchmark key queries before/after +- AST construction adds minimal overhead +- Rendering is O(n) in SQL length +- No runtime overhead once SQL is generated + +### Risk: Complex Nested Queries +**Mitigation**: +- ConnectionBuilder tests cover complex CTEs +- Start with simpler builders +- Extensive unit tests for edge cases + +--- + +## Success Criteria + +1. **All existing tests pass**: `./bin/installcheck` returns 0 +2. **No SQL changes**: Generated SQL is semantically identical +3. **Type safety**: No string interpolation for SQL identifiers or values +4. **Modularity**: AST module has no pgrx dependencies +5. **Debuggability**: Telemetry shows generated SQL and parameters +6. **Documentation**: All new types and functions documented +7. **Future-ready**: Clear path for multi-statement execution + +--- + +## Appendix: Full Module Structure + +``` +src/ +├── ast/ +│ ├── mod.rs # 50 lines - re-exports +│ ├── expr.rs # 400 lines - expression types +│ ├── stmt.rs # 300 lines - statement types +│ ├── cte.rs # 50 lines - CTE types +│ ├── types.rs # 100 lines - SQL type system +│ ├── params.rs # 100 lines - parameter handling +│ ├── render.rs # 600 lines - SQL rendering +│ ├── validate.rs # 100 lines - validation utilities +│ └── tests.rs # 400 lines - unit tests +│ +├── executor/ +│ ├── mod.rs # 50 lines - re-exports +│ ├── plan.rs # 150 lines - execution plans +│ ├── pgrx_backend.rs # 200 lines - pgrx execution +│ └── telemetry.rs # 150 lines - logging +│ +├── transpile.rs # REFACTORED: 1500 lines - builds AST +├── builder.rs # UNCHANGED: ~800 lines +├── resolve.rs # MINOR CHANGES: uses executor +├── graphql.rs # UNCHANGED +├── lib.rs # MINOR CHANGES: module declarations +├── error.rs # MINOR ADDITIONS: new error types +├── sql_types.rs # UNCHANGED +├── constants.rs # UNCHANGED +├── gson.rs # UNCHANGED +├── omit.rs # UNCHANGED +├── parser_util.rs # UNCHANGED +└── merge.rs # UNCHANGED + +Estimated new code: ~2,500 lines +Estimated refactored code: ~1,500 lines +Total: ~4,000 lines of changes +``` diff --git a/bin/installcheck_local b/bin/installcheck_local new file mode 100755 index 00000000..4d141752 --- /dev/null +++ b/bin/installcheck_local @@ -0,0 +1,56 @@ +#! 
/bin/bash + +######## +# Vars # +######## +TMPDIR="$(mktemp -d)" +export PGDATA="$TMPDIR" +export PGHOST="$TMPDIR" +export PGUSER=postgres +export PGDATABASE=postgres +export PGTZ=UTC +export PG_COLOR=auto + +# Use postgresql@17 binaries +PG17=/opt/homebrew/opt/postgresql@17 +export PATH="$PG17/bin:$PATH" + +#################### +# Ensure Clean Env # +#################### +# Stop the server (if running) +trap 'pg_ctl stop -m i' sigint sigterm exit +# Remove temporary data dir +rm -rf "$TMPDIR" + +############## +# Initialize # +############## +# Initialize: setting PGUSER as the owner +initdb --no-locale --encoding=UTF8 --nosync -U "$PGUSER" +# Start the server +pg_ctl start -o "-F -c listen_addresses=\"\" -c log_min_messages=WARNING -k $PGDATA" +# Create the test db +createdb contrib_regression + +######### +# Tests # +######### +TESTDIR="test" +PGXS=$(dirname $(pg_config --pgxs)) +REGRESS="${PGXS}/../test/regress/pg_regress" + +# Test names can be passed as parameters to this script. +# If any test names are passed run only those tests. +# Otherwise run all tests. +if [ "$#" -ne 0 ]; then + TESTS=$@ +else + TESTS=$(ls ${TESTDIR}/sql | sed -e 's/\..*$//' | sort) +fi + +# Execute the test fixtures +psql -v ON_ERROR_STOP=1 -f test/fixtures.sql -d contrib_regression + +# Run tests +${REGRESS} --use-existing --dbname=contrib_regression --inputdir=${TESTDIR} ${TESTS} diff --git a/dockerfiles/db/Dockerfile b/dockerfiles/db/Dockerfile index 853b95a9..52533dbc 100644 --- a/dockerfiles/db/Dockerfile +++ b/dockerfiles/db/Dockerfile @@ -35,6 +35,6 @@ RUN cargo pgrx init --pg${PG_MAJOR} $(which pg_config) USER root COPY . . -RUN cargo pgrx install --release --features pg${PG_MAJOR} +RUN cargo pgrx install --release --features "pg${PG_MAJOR} ast_transpile" USER postgres diff --git a/run_tests.sh b/run_tests.sh new file mode 100755 index 00000000..93c3c5c4 --- /dev/null +++ b/run_tests.sh @@ -0,0 +1,13 @@ +#!/bin/bash +# Run pg_regress tests for pg_graphql +# Usage: ./run_tests.sh [test_name ...] + +set -e + +export PATH="/opt/homebrew/opt/postgresql@17/bin:$PATH" + +# Install the extension +cargo pgrx install --pg-config /opt/homebrew/opt/postgresql@17/bin/pg_config --features pg17 --no-default-features + +# Run tests +./bin/installcheck_local "$@" diff --git a/src/ast/builder_bridge.rs b/src/ast/builder_bridge.rs new file mode 100644 index 00000000..a81894dd --- /dev/null +++ b/src/ast/builder_bridge.rs @@ -0,0 +1,508 @@ +//! Bridge module connecting existing builders to the AST system +//! +//! This module provides the `ToAst` trait and implementations that convert +//! the existing builder structures into AST nodes. This allows incremental +//! migration from the string-based SQL generation to the type-safe AST. +//! +//! # Design Philosophy +//! +//! - Maintain backwards compatibility with existing code +//! - Use proper parameter binding (never string interpolation for values) +//! - Support both the old ParamContext and new ParamCollector +//! - Identifiers are always properly quoted via the AST's Ident type + +use super::{ + AggregateExpr, AggregateFunction, ColumnRef, Expr, FunctionArg, FunctionCall, JsonBuildExpr, + Literal, OrderByExpr, ParamCollector, SqlType, +}; +use crate::error::GraphQLResult; + +// Re-export json_to_param_value from params module (it's already defined there) +pub use super::params::json_to_param_value; + +/// Trait for converting builders to AST nodes +/// +/// This trait is implemented by builder types to convert them into +/// type-safe AST representations. 
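+///
+/// A minimal usage sketch (hypothetical `builder`; the real impls live in the
+/// `transpile_*.rs` modules referenced from `execute.rs`):
+///
+/// ```ignore
+/// let mut params = ParamCollector::new();
+/// let ast = builder.to_ast(&mut params)?;  // typed AST + collected params
+/// let sql = crate::ast::render(&ast.stmt); // SQL text exists only after rendering
+/// ```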
+pub trait ToAst {
+    /// The AST type this builder produces
+    type Ast;
+
+    /// Convert this builder to an AST node, collecting parameters
+    fn to_ast(&self, params: &mut ParamCollector) -> GraphQLResult<Self::Ast>;
+}
+
+/// Context for AST building that tracks block names and other metadata
+#[derive(Debug, Clone)]
+pub struct AstBuildContext {
+    /// The current block/alias name for table references
+    pub block_name: String,
+    /// Counter for generating unique block names
+    block_counter: usize,
+}
+
+impl AstBuildContext {
+    pub fn new() -> Self {
+        Self {
+            block_name: Self::generate_block_name(),
+            block_counter: 0,
+        }
+    }
+
+    /// Create a context with a specific block name
+    pub fn with_block_name(block_name: impl Into<String>) -> Self {
+        Self {
+            block_name: block_name.into(),
+            block_counter: 0,
+        }
+    }
+
+    /// Generate a random block name (like the existing rand_block_name)
+    fn generate_block_name() -> String {
+        use rand::distributions::Alphanumeric;
+        use rand::{thread_rng, Rng};
+        thread_rng()
+            .sample_iter(&Alphanumeric)
+            .take(7)
+            .map(char::from)
+            .collect::<String>()
+            .to_lowercase()
+    }
+
+    /// Get a new unique block name
+    pub fn next_block_name(&mut self) -> String {
+        self.block_counter += 1;
+        format!("{}_{}", self.block_name, self.block_counter)
+    }
+}
+
+impl Default for AstBuildContext {
+    fn default() -> Self {
+        Self::new()
+    }
+}
+
+/// Convert a SQL type name string to a SqlType
+///
+/// Handles array types (ending with []) and common PostgreSQL types.
+pub fn type_name_to_sql_type(type_name: &str) -> SqlType {
+    let is_array = type_name.ends_with("[]");
+    let base_name = if is_array {
+        &type_name[..type_name.len() - 2]
+    } else {
+        type_name
+    };
+
+    let base_type = match base_name.to_lowercase().as_str() {
+        "text" | "varchar" | "char" | "character varying" | "character" => SqlType::text(),
+        "integer" | "int" | "int4" => SqlType::integer(),
+        "bigint" | "int8" => SqlType::bigint(),
+        "smallint" | "int2" => SqlType::smallint(),
+        "boolean" | "bool" => SqlType::boolean(),
+        "real" | "float4" => SqlType::real(),
+        "double precision" | "float8" => SqlType::double_precision(),
+        "numeric" | "decimal" => SqlType::numeric(),
+        "uuid" => SqlType::uuid(),
+        "json" => SqlType::json(),
+        "jsonb" => SqlType::jsonb(),
+        "timestamp" | "timestamp without time zone" => SqlType::timestamp(),
+        "timestamptz" | "timestamp with time zone" => SqlType::timestamptz(),
+        "date" => SqlType::date(),
+        "time" | "time without time zone" => SqlType::time(),
+        "timetz" | "time with time zone" => SqlType::new("time with time zone"),
+        "bytea" => SqlType::bytea(),
+        _ => {
+            // For custom types, preserve the original name
+            if base_name.contains('.') {
+                let parts: Vec<&str> = base_name.splitn(2, '.').collect();
+                SqlType::with_schema(parts[0], parts[1])
+            } else {
+                SqlType::new(base_name)
+            }
+        }
+    };
+
+    if is_array {
+        base_type.as_array()
+    } else {
+        base_type
+    }
+}
+
+/// Add a JSON value as a parameter and return the expression referencing it
+///
+/// This is the key function for safe parameter binding. It:
+/// 1. Converts the JSON value to a ParamValue
+/// 2. Adds it to the ParamCollector with proper type info
+/// 3.
Returns an Expr that references the parameter with type cast +pub fn add_param_from_json( + params: &mut ParamCollector, + value: &serde_json::Value, + type_name: &str, +) -> GraphQLResult { + let param_value = json_to_param_value(value); + let sql_type = type_name_to_sql_type(type_name); + Ok(params.add(param_value, sql_type)) +} + +// ============================================================================= +// Expression Builder - Consolidated helper functions for creating Expr nodes +// ============================================================================= + +/// Builder for creating SQL expressions with a fluent API. +/// +/// This struct consolidates all expression-building helper functions into +/// a single namespace for better discoverability and organization. +/// +/// # Example +/// +/// ```rust,ignore +/// use pg_graphql::ast::ExprBuilder; +/// +/// let col = ExprBuilder::column("users", "id"); +/// let lit = ExprBuilder::string("hello"); +/// let agg = ExprBuilder::count_star(); +/// ``` +pub struct ExprBuilder; + +impl ExprBuilder { + // ------------------------------------------------------------------------- + // Column references + // ------------------------------------------------------------------------- + + /// Create a qualified column reference (table.column) + #[inline] + pub fn column(table_alias: &str, column_name: &str) -> Expr { + Expr::Column(ColumnRef::qualified(table_alias, column_name)) + } + + /// Create an unqualified column reference + #[inline] + pub fn column_unqualified(column_name: &str) -> Expr { + Expr::Column(ColumnRef::new(column_name)) + } + + // ------------------------------------------------------------------------- + // Literals + // ------------------------------------------------------------------------- + + /// Create a string literal + #[inline] + pub fn string(s: &str) -> Expr { + Expr::Literal(Literal::String(s.to_string())) + } + + /// Create an integer literal + #[inline] + pub fn int(i: i64) -> Expr { + Expr::Literal(Literal::Integer(i)) + } + + /// Create a boolean literal + #[inline] + pub fn bool(b: bool) -> Expr { + Expr::Literal(Literal::Bool(b)) + } + + /// Create a NULL literal + #[inline] + pub fn null() -> Expr { + Expr::Literal(Literal::Null) + } + + /// Create a DEFAULT expression (for INSERT) + #[inline] + pub fn default() -> Expr { + Expr::Literal(Literal::Default) + } + + // ------------------------------------------------------------------------- + // Function calls + // ------------------------------------------------------------------------- + + /// Create a simple function call + #[inline] + pub fn func(name: &str, args: Vec) -> Expr { + Expr::FunctionCall(FunctionCall::new( + name, + args.into_iter().map(FunctionArg::unnamed).collect(), + )) + } + + /// Create a schema-qualified function call + #[inline] + pub fn func_with_schema(schema: &str, name: &str, args: Vec) -> Expr { + Expr::FunctionCall(FunctionCall::with_schema( + schema, + name, + args.into_iter().map(FunctionArg::unnamed).collect(), + )) + } + + /// Create a COALESCE expression + #[inline] + pub fn coalesce(args: Vec) -> Expr { + Expr::Coalesce(args) + } + + // ------------------------------------------------------------------------- + // Aggregates + // ------------------------------------------------------------------------- + + /// Create a COUNT(*) expression + #[inline] + pub fn count_star() -> Expr { + Expr::Aggregate(AggregateExpr::count_star()) + } + + /// Create a jsonb_agg expression + #[inline] + pub fn jsonb_agg(expr: Expr) -> 
Expr { + Expr::Aggregate(AggregateExpr::new(AggregateFunction::JsonbAgg, vec![expr])) + } + + /// Create a jsonb_agg expression with FILTER clause + pub fn jsonb_agg_filtered(expr: Expr, filter: Expr) -> Expr { + let mut agg = AggregateExpr::new(AggregateFunction::JsonbAgg, vec![expr]); + agg.filter = Some(Box::new(filter)); + Expr::Aggregate(agg) + } + + /// Create a jsonb_agg expression with ORDER BY and optional FILTER + pub fn jsonb_agg_ordered(expr: Expr, order_by: Vec, filter: Option) -> Expr { + let mut agg = AggregateExpr::new(AggregateFunction::JsonbAgg, vec![expr]); + if !order_by.is_empty() { + agg.order_by = Some(order_by); + } + if let Some(f) = filter { + agg.filter = Some(Box::new(f)); + } + Expr::Aggregate(agg) + } + + // ------------------------------------------------------------------------- + // JSON/JSONB helpers + // ------------------------------------------------------------------------- + + /// Create a jsonb_build_object expression + pub fn jsonb_object(pairs: Vec<(String, Expr)>) -> Expr { + Expr::JsonBuild(JsonBuildExpr::Object( + pairs + .into_iter() + .map(|(k, v)| (Expr::Literal(Literal::String(k)), v)) + .collect(), + )) + } + + /// Create an empty jsonb array: jsonb_build_array() + #[inline] + pub fn empty_jsonb_array() -> Expr { + Self::func("jsonb_build_array", vec![]) + } + + /// Create an empty jsonb object: '{}'::jsonb + #[inline] + pub fn empty_jsonb_object() -> Expr { + Expr::Cast { + expr: Box::new(Expr::Literal(Literal::String("{}".to_string()))), + target_type: type_name_to_sql_type("jsonb"), + } + } +} + +// ============================================================================= +// Standalone helper functions (for backwards compatibility) +// ============================================================================= + +/// Helper to create a column reference expression +#[inline] +pub fn column_ref(table_alias: &str, column_name: &str) -> Expr { + ExprBuilder::column(table_alias, column_name) +} + +/// Helper to create an unqualified column reference +#[inline] +pub fn column_ref_unqualified(column_name: &str) -> Expr { + ExprBuilder::column_unqualified(column_name) +} + +/// Helper to create a simple function call expression +#[inline] +pub fn func_call(name: &str, args: Vec) -> Expr { + ExprBuilder::func(name, args) +} + +/// Helper to create a schema-qualified function call +#[inline] +pub fn func_call_schema(schema: &str, name: &str, args: Vec) -> Expr { + ExprBuilder::func_with_schema(schema, name, args) +} + +/// Helper to build jsonb_build_object calls +#[inline] +pub fn jsonb_build_object(pairs: Vec<(String, Expr)>) -> Expr { + ExprBuilder::jsonb_object(pairs) +} + +/// Helper to build jsonb_agg calls +#[inline] +pub fn jsonb_agg(expr: Expr) -> Expr { + ExprBuilder::jsonb_agg(expr) +} + +/// Helper to build jsonb_agg calls with a FILTER clause +#[inline] +pub fn jsonb_agg_with_filter(expr: Expr, filter: Expr) -> Expr { + ExprBuilder::jsonb_agg_filtered(expr, filter) +} + +/// Helper to build jsonb_agg calls with ORDER BY and FILTER clauses +#[inline] +pub fn jsonb_agg_with_order_and_filter( + expr: Expr, + order_by: Vec, + filter: Option, +) -> Expr { + ExprBuilder::jsonb_agg_ordered(expr, order_by, filter) +} + +/// Helper to build coalesce calls +#[inline] +pub fn coalesce(args: Vec) -> Expr { + ExprBuilder::coalesce(args) +} + +/// Helper to build count(*) expression +#[inline] +pub fn count_star() -> Expr { + ExprBuilder::count_star() +} + +/// Helper to create a string literal expression +#[inline] +pub fn string_literal(s: 
&str) -> Expr { + ExprBuilder::string(s) +} + +/// Helper to create an empty jsonb array expression +#[inline] +pub fn empty_jsonb_array() -> Expr { + ExprBuilder::empty_jsonb_array() +} + +/// Helper to create an empty jsonb object expression +#[inline] +pub fn empty_jsonb_object() -> Expr { + ExprBuilder::empty_jsonb_object() +} + +/// Helper to create an integer literal +#[inline] +pub fn int_literal(i: i64) -> Expr { + ExprBuilder::int(i) +} + +/// Helper to create a boolean literal +#[inline] +pub fn bool_literal(b: bool) -> Expr { + ExprBuilder::bool(b) +} + +/// Helper to create a NULL literal +#[inline] +pub fn null_literal() -> Expr { + ExprBuilder::null() +} + +/// Helper to create DEFAULT expression (for INSERT) +#[inline] +pub fn default_expr() -> Expr { + ExprBuilder::default() +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::ast::ParamValue; + + #[test] + fn test_type_name_to_sql_type() { + let t = type_name_to_sql_type("integer"); + assert_eq!(t.name, "integer"); + assert!(!t.is_array); + + let t = type_name_to_sql_type("text[]"); + assert_eq!(t.name, "text"); + assert!(t.is_array); + + let t = type_name_to_sql_type("public.my_type"); + assert_eq!(t.name, "my_type"); + assert_eq!(t.schema, Some("public".to_string())); + } + + #[test] + fn test_add_param_from_json() { + use serde_json::json; + + let mut params = ParamCollector::new(); + + let expr = add_param_from_json(&mut params, &json!(42), "integer").unwrap(); + + // Should be a parameter reference + match expr { + Expr::Param(p) => { + assert_eq!(p.index, 1); + assert_eq!(p.type_cast.name, "integer"); + } + _ => panic!("Expected Param expression"), + } + + // Check the collected parameter + let collected = params.into_params(); + assert_eq!(collected.len(), 1); + assert_eq!(collected[0].index, 1); + match &collected[0].value { + ParamValue::Integer(i) => assert_eq!(*i, 42), + _ => panic!("Expected integer"), + } + } + + #[test] + fn test_ast_build_context() { + let ctx = AstBuildContext::new(); + assert!(!ctx.block_name.is_empty()); + assert_eq!(ctx.block_name.len(), 7); + } + + #[test] + fn test_helper_functions() { + // Test column_ref + let col = column_ref("t", "id"); + match col { + Expr::Column(c) => { + assert_eq!(c.table_alias.unwrap().0, "t"); + assert_eq!(c.column.0, "id"); + } + _ => panic!("Expected column"), + } + + // Test func_call + let f = func_call("count", vec![Expr::Raw("*".to_string())]); + match f { + Expr::FunctionCall(fc) => { + assert_eq!(fc.name.0, "count"); + assert_eq!(fc.args.len(), 1); + } + _ => panic!("Expected function call"), + } + + // Test jsonb_build_object + let obj = jsonb_build_object(vec![("key".to_string(), string_literal("value"))]); + match obj { + Expr::JsonBuild(JsonBuildExpr::Object(pairs)) => { + assert_eq!(pairs.len(), 1); // 1 key-value pair + } + _ => panic!("Expected JsonBuild Object"), + } + } +} diff --git a/src/ast/cte.rs b/src/ast/cte.rs new file mode 100644 index 00000000..d30099e9 --- /dev/null +++ b/src/ast/cte.rs @@ -0,0 +1,150 @@ +//! Common Table Expression (CTE) support +//! +//! CTEs are used extensively in pg_graphql for atomic mutations and +//! complex query composition. + +use super::expr::Ident; +use super::stmt::{DeleteStmt, InsertStmt, SelectStmt, UpdateStmt}; + +/// A Common Table Expression (CTE) in a WITH clause +#[derive(Debug, Clone, PartialEq)] +pub struct Cte { + /// Name of the CTE + pub name: Ident, + /// Optional column list: WITH cte(col1, col2) AS (...) 
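+    /// (omitted from the rendered SQL when `None`)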
+    pub columns: Option<Vec<Ident>>,
+    /// The query that defines the CTE
+    pub query: CteQuery,
+    /// MATERIALIZED / NOT MATERIALIZED hint
+    pub materialized: Option<bool>,
+}
+
+impl Cte {
+    /// Create a new CTE with a SELECT query
+    pub fn select(name: impl Into<Ident>, query: SelectStmt) -> Self {
+        Self {
+            name: name.into(),
+            columns: None,
+            query: CteQuery::Select(query),
+            materialized: None,
+        }
+    }
+
+    /// Create a new CTE with an INSERT query
+    pub fn insert(name: impl Into<Ident>, query: InsertStmt) -> Self {
+        Self {
+            name: name.into(),
+            columns: None,
+            query: CteQuery::Insert(query),
+            materialized: None,
+        }
+    }
+
+    /// Create a new CTE with an UPDATE query
+    pub fn update(name: impl Into<Ident>, query: UpdateStmt) -> Self {
+        Self {
+            name: name.into(),
+            columns: None,
+            query: CteQuery::Update(query),
+            materialized: None,
+        }
+    }
+
+    /// Create a new CTE with a DELETE query
+    pub fn delete(name: impl Into<Ident>, query: DeleteStmt) -> Self {
+        Self {
+            name: name.into(),
+            columns: None,
+            query: CteQuery::Delete(query),
+            materialized: None,
+        }
+    }
+
+    /// Add column aliases to the CTE
+    pub fn with_columns(mut self, columns: Vec<impl Into<Ident>>) -> Self {
+        self.columns = Some(columns.into_iter().map(|c| c.into()).collect());
+        self
+    }
+
+    /// Mark the CTE as MATERIALIZED
+    pub fn materialized(mut self) -> Self {
+        self.materialized = Some(true);
+        self
+    }
+
+    /// Mark the CTE as NOT MATERIALIZED
+    pub fn not_materialized(mut self) -> Self {
+        self.materialized = Some(false);
+        self
+    }
+}
+
+/// The query that defines a CTE
+///
+/// CTEs can contain any DML statement (SELECT, INSERT, UPDATE, DELETE).
+/// This is particularly useful for data-modifying CTEs.
+#[derive(Debug, Clone, PartialEq)]
+pub enum CteQuery {
+    Select(SelectStmt),
+    Insert(InsertStmt),
+    Update(UpdateStmt),
+    Delete(DeleteStmt),
+}
+
+impl From<SelectStmt> for CteQuery {
+    fn from(stmt: SelectStmt) -> Self {
+        Self::Select(stmt)
+    }
+}
+
+impl From<InsertStmt> for CteQuery {
+    fn from(stmt: InsertStmt) -> Self {
+        Self::Insert(stmt)
+    }
+}
+
+impl From<UpdateStmt> for CteQuery {
+    fn from(stmt: UpdateStmt) -> Self {
+        Self::Update(stmt)
+    }
+}
+
+impl From<DeleteStmt> for CteQuery {
+    fn from(stmt: DeleteStmt) -> Self {
+        Self::Delete(stmt)
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use crate::ast::{InsertValues, SelectColumn};
+
+    #[test]
+    fn test_cte_select() {
+        let cte = Cte::select(
+            "active_users",
+            SelectStmt::columns(vec![SelectColumn::star()]),
+        );
+
+        assert_eq!(cte.name.0, "active_users");
+        assert!(matches!(cte.query, CteQuery::Select(_)));
+    }
+
+    #[test]
+    fn test_cte_insert() {
+        let insert = InsertStmt::new("users", vec![], InsertValues::DefaultValues);
+        let cte = Cte::insert("new_user", insert).with_columns(vec!["id", "name"]);
+
+        assert_eq!(cte.columns.unwrap().len(), 2);
+    }
+
+    #[test]
+    fn test_cte_materialized() {
+        let cte = Cte::select("temp", SelectStmt::new()).materialized();
+        assert_eq!(cte.materialized, Some(true));
+
+        let cte = Cte::select("temp", SelectStmt::new()).not_materialized();
+        assert_eq!(cte.materialized, Some(false));
+    }
+}
diff --git a/src/ast/execute.rs b/src/ast/execute.rs
new file mode 100644
index 00000000..4525e51c
--- /dev/null
+++ b/src/ast/execute.rs
@@ -0,0 +1,282 @@
+//! AST-based query execution
+//!
+//! This module provides execution methods for builders using the new AST system.
+//! It bridges the gap between the existing builder infrastructure and the new
+//! type-safe SQL generation and execution.
+//!
+//! # Parameter Binding
+//!
+//! All parameters are converted to text representation and passed with TEXTOID.
+//! PostgreSQL handles type conversion via SQL-side cast expressions like `($1::integer)`.
+//! This matches the original transpiler behavior exactly.
+
+use crate::ast::{render, Param, ParamCollector, ParamValue};
+use crate::error::{GraphQLError, GraphQLResult};
+use pgrx::datum::DatumWithOid;
+use pgrx::pg_sys::PgBuiltInOids;
+use pgrx::spi::{self, Spi, SpiClient};
+use pgrx::{IntoDatum, JsonB, PgOid};
+
+/// Convert all collected parameters to pgrx datums (as text)
+///
+/// All parameters are converted to text representation and passed with TEXTOID
+/// (or TEXTARRAYOID for arrays). PostgreSQL handles type conversion via the
+/// SQL-side cast expressions like `($1::integer)`.
+fn params_to_datums(params: &ParamCollector) -> Vec<DatumWithOid<'static>> {
+    params.params().iter().map(param_to_datum).collect()
+}
+
+/// Convert a parameter to a pgrx DatumWithOid (as text)
+fn param_to_datum(param: &Param) -> DatumWithOid<'static> {
+    let type_oid = if param.sql_type.is_array {
+        PgOid::BuiltIn(PgBuiltInOids::TEXTARRAYOID)
+    } else {
+        PgOid::BuiltIn(PgBuiltInOids::TEXTOID)
+    };
+
+    match &param.value {
+        ParamValue::Null => DatumWithOid::null_oid(type_oid.value()),
+
+        ParamValue::String(s) => {
+            let datum = s.clone().into_datum();
+            match datum {
+                Some(d) => unsafe { DatumWithOid::new(d, type_oid.value()) },
+                None => DatumWithOid::null_oid(type_oid.value()),
+            }
+        }
+
+        ParamValue::Integer(i) => {
+            let datum = i.to_string().into_datum();
+            match datum {
+                Some(d) => unsafe { DatumWithOid::new(d, type_oid.value()) },
+                None => DatumWithOid::null_oid(type_oid.value()),
+            }
+        }
+
+        ParamValue::Float(f) => {
+            let datum = f.to_string().into_datum();
+            match datum {
+                Some(d) => unsafe { DatumWithOid::new(d, type_oid.value()) },
+                None => DatumWithOid::null_oid(type_oid.value()),
+            }
+        }
+
+        ParamValue::Bool(b) => {
+            let datum = b.to_string().into_datum();
+            match datum {
+                Some(d) => unsafe { DatumWithOid::new(d, type_oid.value()) },
+                None => DatumWithOid::null_oid(type_oid.value()),
+            }
+        }
+
+        ParamValue::Json(j) => {
+            let datum = j.to_string().into_datum();
+            match datum {
+                Some(d) => unsafe { DatumWithOid::new(d, type_oid.value()) },
+                None => DatumWithOid::null_oid(type_oid.value()),
+            }
+        }
+
+        ParamValue::Array(arr) => {
+            // Convert array to PostgreSQL text array format
+            let elements: Vec<Option<String>> = arr.iter().map(param_value_to_string).collect();
+            let datum = elements.into_datum();
+            match datum {
+                Some(d) => unsafe { DatumWithOid::new(d, type_oid.value()) },
+                None => DatumWithOid::null_oid(type_oid.value()),
+            }
+        }
+    }
+}
+
+/// Convert a ParamValue to Option<String> for array element conversion
+fn param_value_to_string(value: &ParamValue) -> Option<String> {
+    match value {
+        ParamValue::Null => None,
+        ParamValue::String(s) => Some(s.clone()),
+        ParamValue::Integer(i) => Some(i.to_string()),
+        ParamValue::Float(f) => Some(f.to_string()),
+        ParamValue::Bool(b) => Some(b.to_string()),
+        ParamValue::Json(j) => Some(j.to_string()),
+        ParamValue::Array(_) => value.to_sql_literal(),
+    }
+}
+
+/// Execute a query builder using the AST path
+///
+/// This trait can be implemented alongside existing execute methods
+/// to allow gradual migration to the new AST system.
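+///
+/// # Example
+///
+/// An illustrative call pattern only; the builder values here are assumptions,
+/// not taken from the source. Queries go through `execute_via_ast`, mutations
+/// through `execute_mutation_via_ast`.
+///
+/// ```rust,ignore
+/// // Query path: renders the AST, binds params as text, runs SPI select.
+/// let value: serde_json::Value = node_builder.execute_via_ast()?;
+///
+/// // Mutation path: same pipeline, but threads the SPI connection through.
+/// let (value, conn) = insert_builder.execute_mutation_via_ast(conn)?;
+/// ```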
+pub trait AstExecutable {
+    type Ast;
+
+    fn to_ast(&self, params: &mut ParamCollector) -> GraphQLResult<Self::Ast>;
+
+    fn execute_via_ast(&self) -> GraphQLResult<serde_json::Value>
+    where
+        Self::Ast: AsStatement,
+    {
+        let mut params = ParamCollector::new();
+        let ast = self.to_ast(&mut params)?;
+
+        // Render the AST to SQL
+        let sql = ast.render_sql();
+
+        // Convert parameters to pgrx format (all as text)
+        let pgrx_params = params_to_datums(&params);
+
+        // Execute via SPI
+        let spi_result: Result<Option<JsonB>, spi::Error> = Spi::connect(|c| {
+            let val = c.select(&sql, Some(1), &pgrx_params)?;
+            if val.is_empty() {
+                Ok(None)
+            } else {
+                val.first().get::<JsonB>(1)
+            }
+        });
+
+        match spi_result {
+            Ok(Some(jsonb)) => Ok(jsonb.0),
+            Ok(None) => Ok(serde_json::Value::Null),
+            Err(e) => Err(GraphQLError::internal(format!("{}", e))),
+        }
+    }
+
+    fn execute_mutation_via_ast<'conn, 'c>(
+        &self,
+        conn: &'c mut SpiClient<'conn>,
+    ) -> GraphQLResult<(serde_json::Value, &'c mut SpiClient<'conn>)>
+    where
+        Self::Ast: AsStatement,
+    {
+        let mut params = ParamCollector::new();
+        let ast = self.to_ast(&mut params)?;
+
+        // Render the AST to SQL
+        let sql = ast.render_sql();
+
+        // Convert parameters to pgrx format (all as text)
+        let pgrx_params = params_to_datums(&params);
+
+        // Execute via SPI update (for mutations)
+        let res_q = conn.update(&sql, None, &pgrx_params).map_err(|_| {
+            GraphQLError::sql_execution("Internal Error: Failed to execute AST-generated mutation")
+        })?;
+
+        let res: JsonB = match res_q.first().get::<JsonB>(1) {
+            Ok(Some(dat)) => dat,
+            Ok(None) => JsonB(serde_json::Value::Null),
+            Err(e) => {
+                return Err(GraphQLError::sql_generation(format!(
+                    "Internal Error: Failed to load result from AST query: {e}"
+                )));
+            }
+        };
+
+        Ok((res.0, conn))
+    }
+}
+
+/// Trait for types that can be rendered as SQL statements
+pub trait AsStatement {
+    fn render_sql(&self) -> String;
+}
+
+// Implement AsStatement for our AST result types
+impl AsStatement for crate::ast::InsertAst {
+    fn render_sql(&self) -> String {
+        render(&self.stmt)
+    }
+}
+
+impl AsStatement for crate::ast::UpdateAst {
+    fn render_sql(&self) -> String {
+        render(&self.stmt)
+    }
+}
+
+impl AsStatement for crate::ast::DeleteAst {
+    fn render_sql(&self) -> String {
+        render(&self.stmt)
+    }
+}
+
+impl AsStatement for crate::ast::NodeAst {
+    fn render_sql(&self) -> String {
+        render(&self.stmt)
+    }
+}
+
+impl AsStatement for crate::ast::ConnectionAst {
+    fn render_sql(&self) -> String {
+        render(&self.stmt)
+    }
+}
+
+impl AsStatement for crate::ast::FunctionCallAst {
+    fn render_sql(&self) -> String {
+        render(&self.stmt)
+    }
+}
+
+// Implement AstExecutable for builders
+// These connect the ToAst implementations from transpile_*.rs to the execution trait
+
+use crate::ast::ToAst;
+use crate::builder::{
+    ConnectionBuilder, DeleteBuilder, FunctionCallBuilder, InsertBuilder, NodeBuilder,
+    UpdateBuilder,
+};
+
+impl AstExecutable for InsertBuilder {
+    type Ast = crate::ast::InsertAst;
+
+    fn to_ast(&self, params: &mut ParamCollector) -> GraphQLResult<Self::Ast> {
+        <Self as ToAst>::to_ast(self, params)
+    }
+}
+
+impl AstExecutable for UpdateBuilder {
+    type Ast = crate::ast::UpdateAst;
+
+    fn to_ast(&self, params: &mut ParamCollector) -> GraphQLResult<Self::Ast> {
+        <Self as ToAst>::to_ast(self, params)
+    }
+}
+
+impl AstExecutable for DeleteBuilder {
+    type Ast = crate::ast::DeleteAst;
+
+    fn to_ast(&self, params: &mut ParamCollector) -> GraphQLResult<Self::Ast> {
+        <Self as ToAst>::to_ast(self, params)
+    }
+}
+
+impl AstExecutable for NodeBuilder {
+    type Ast = crate::ast::NodeAst;
+
+    fn to_ast(&self, params: &mut ParamCollector) -> GraphQLResult<Self::Ast> {
+        <Self as ToAst>::to_ast(self, params)
+    }
+}
+
+impl AstExecutable for ConnectionBuilder {
+    type Ast = crate::ast::ConnectionAst;
+
+    fn to_ast(&self, params: &mut ParamCollector) -> GraphQLResult<Self::Ast> {
+        <Self as ToAst>::to_ast(self, params)
+    }
+}
+
+impl AstExecutable for FunctionCallBuilder {
+    type Ast = crate::ast::FunctionCallAst;
+
+    fn to_ast(&self, params: &mut ParamCollector) -> GraphQLResult<Self::Ast> {
+        <Self as ToAst>::to_ast(self, params)
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    // Tests that don't require pgrx runtime can go here
+    // Full integration tests require pg_regress
+}
diff --git a/src/ast/expr.rs b/src/ast/expr.rs
new file mode 100644
index 00000000..e9db02a6
--- /dev/null
+++ b/src/ast/expr.rs
@@ -0,0 +1,747 @@
+//! SQL expression types
+//!
+//! This module defines all SQL expression types that can appear in queries.
+//! Expressions are the building blocks of SQL: columns, literals, operators,
+//! function calls, etc.
+
+use super::types::SqlType;
+
+/// A quoted SQL identifier (table name, column name, etc.)
+///
+/// Identifiers are always quoted when rendered to prevent SQL injection
+/// and handle special characters correctly.
+#[derive(Debug, Clone, PartialEq, Eq, Hash)]
+pub struct Ident(pub String);
+
+impl Ident {
+    /// Create a new identifier from any string-like type
+    #[inline]
+    pub fn new(s: impl Into<String>) -> Self {
+        Self(s.into())
+    }
+
+    /// Get the identifier as a string slice
+    #[inline]
+    pub fn as_str(&self) -> &str {
+        &self.0
+    }
+}
+
+impl From<String> for Ident {
+    fn from(s: String) -> Self {
+        Self(s)
+    }
+}
+
+impl From<&str> for Ident {
+    fn from(s: &str) -> Self {
+        Self(s.to_string())
+    }
+}
+
+impl std::fmt::Display for Ident {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        write!(f, "{}", self.0)
+    }
+}
+
+/// Reference to a column, optionally qualified with a table alias
+#[derive(Debug, Clone, PartialEq)]
+pub struct ColumnRef {
+    /// Table alias (e.g., "t" in "t.id")
+    pub table_alias: Option<Ident>,
+    /// Column name
+    pub column: Ident,
+}
+
+impl ColumnRef {
+    pub fn new(column: impl Into<Ident>) -> Self {
+        Self {
+            table_alias: None,
+            column: column.into(),
+        }
+    }
+
+    pub fn qualified(table: impl Into<Ident>, column: impl Into<Ident>) -> Self {
+        Self {
+            table_alias: Some(table.into()),
+            column: column.into(),
+        }
+    }
+}
+
+/// Reference to a query parameter ($1, $2, etc.) with type cast
+#[derive(Debug, Clone, PartialEq)]
+pub struct ParamRef {
+    /// 1-indexed parameter number
+    pub index: usize,
+    /// Type to cast the parameter to
+    pub type_cast: SqlType,
+}
+
+/// SQL literal values
+#[derive(Debug, Clone, PartialEq)]
+pub enum Literal {
+    /// SQL NULL
+    Null,
+    /// Boolean true/false
+    Bool(bool),
+    /// Integer literal
+    Integer(i64),
+    /// Floating point literal
+    Float(f64),
+    /// String literal (will be properly quoted)
+    String(String),
+    /// SQL DEFAULT keyword (for INSERT statements)
+    Default,
+}
+
+impl Literal {
+    pub fn string(s: impl Into<String>) -> Self {
+        Self::String(s.into())
+    }
+}
+
+/// Binary operators
+#[derive(Debug, Clone, Copy, PartialEq, Eq)]
+pub enum BinaryOperator {
+    // Comparison
+    Eq,
+    NotEq,
+    Lt,
+    LtEq,
+    Gt,
+    GtEq,
+
+    // Array operators
+    Contains,    // @>
+    ContainedBy, // <@
+    Overlap,     // &&
+    Any,         // = ANY(...)
+ + // String operators + Like, + ILike, + RegEx, // ~ + IRegEx, // ~* + StartsWith, // ^@ + + // Logical + And, + Or, + + // Arithmetic + Add, + Sub, + Mul, + Div, + Mod, + + // JSON operators + JsonExtract, // -> + JsonExtractText, // ->> + JsonPath, // #> + JsonPathText, // #>> + + // JSONB concatenation + JsonConcat, // || +} + +impl BinaryOperator { + /// Get the SQL representation of this operator + pub fn as_sql(&self) -> &'static str { + match self { + Self::Eq => "=", + Self::NotEq => "<>", + Self::Lt => "<", + Self::LtEq => "<=", + Self::Gt => ">", + Self::GtEq => ">=", + Self::Contains => "@>", + Self::ContainedBy => "<@", + Self::Overlap => "&&", + Self::Any => "= any", + Self::Like => "like", + Self::ILike => "ilike", + Self::RegEx => "~", + Self::IRegEx => "~*", + Self::StartsWith => "^@", + Self::And => "and", + Self::Or => "or", + Self::Add => "+", + Self::Sub => "-", + Self::Mul => "*", + Self::Div => "/", + Self::Mod => "%", + Self::JsonExtract => "->", + Self::JsonExtractText => "->>", + Self::JsonPath => "#>", + Self::JsonPathText => "#>>", + Self::JsonConcat => "||", + } + } +} + +/// Unary operators +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum UnaryOperator { + Not, + Neg, + BitNot, +} + +impl UnaryOperator { + pub fn as_sql(&self) -> &'static str { + match self { + Self::Not => "not", + Self::Neg => "-", + Self::BitNot => "~", + } + } +} + +/// Function argument (can be named or positional) +#[derive(Debug, Clone, PartialEq)] +pub enum FunctionArg { + /// Positional argument + Unnamed(Expr), + /// Named argument (name => value) + Named { name: Ident, value: Expr }, +} + +impl FunctionArg { + pub fn unnamed(expr: Expr) -> Self { + Self::Unnamed(expr) + } + + pub fn named(name: impl Into, value: Expr) -> Self { + Self::Named { + name: name.into(), + value, + } + } +} + +/// A function call expression +#[derive(Debug, Clone, PartialEq)] +pub struct FunctionCall { + /// Schema (e.g., "pg_catalog", "public") + pub schema: Option, + /// Function name + pub name: Ident, + /// Arguments + pub args: Vec, + /// FILTER clause (for aggregate functions) + pub filter: Option>, + /// ORDER BY within the function (for ordered-set aggregates) + pub order_by: Option>, +} + +impl FunctionCall { + pub fn new(name: impl Into, args: Vec) -> Self { + Self { + schema: None, + name: name.into(), + args, + filter: None, + order_by: None, + } + } + + pub fn with_schema( + schema: impl Into, + name: impl Into, + args: Vec, + ) -> Self { + Self { + schema: Some(schema.into()), + name: name.into(), + args, + filter: None, + order_by: None, + } + } + + /// Add a FILTER clause + pub fn with_filter(mut self, filter: Expr) -> Self { + self.filter = Some(Box::new(filter)); + self + } + + /// Add ORDER BY within the function + pub fn with_order_by(mut self, order_by: Vec) -> Self { + self.order_by = Some(order_by); + self + } +} + +/// Aggregate functions +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum AggregateFunction { + Count, + Sum, + Avg, + Min, + Max, + JsonAgg, + JsonbAgg, + ArrayAgg, + BoolAnd, + BoolOr, + StringAgg, +} + +impl AggregateFunction { + pub fn as_sql(&self) -> &'static str { + match self { + Self::Count => "count", + Self::Sum => "sum", + Self::Avg => "avg", + Self::Min => "min", + Self::Max => "max", + Self::JsonAgg => "json_agg", + Self::JsonbAgg => "jsonb_agg", + Self::ArrayAgg => "array_agg", + Self::BoolAnd => "bool_and", + Self::BoolOr => "bool_or", + Self::StringAgg => "string_agg", + } + } +} + +/// An aggregate expression with optional FILTER and ORDER 
BY +#[derive(Debug, Clone, PartialEq)] +pub struct AggregateExpr { + pub function: AggregateFunction, + pub args: Vec, + pub distinct: bool, + pub filter: Option>, + pub order_by: Option>, +} + +impl AggregateExpr { + pub fn new(function: AggregateFunction, args: Vec) -> Self { + Self { + function, + args, + distinct: false, + filter: None, + order_by: None, + } + } + + pub fn count_star() -> Self { + Self::new( + AggregateFunction::Count, + vec![Expr::Literal(Literal::String("*".to_string()))], + ) + } + + pub fn count_all() -> Self { + Self::new(AggregateFunction::Count, vec![]) + } + + pub fn with_distinct(mut self) -> Self { + self.distinct = true; + self + } + + pub fn with_filter(mut self, filter: Expr) -> Self { + self.filter = Some(Box::new(filter)); + self + } + + pub fn with_order_by(mut self, order_by: Vec) -> Self { + self.order_by = Some(order_by); + self + } +} + +/// CASE expression +#[derive(Debug, Clone, PartialEq)] +pub struct CaseExpr { + /// CASE (simple case) vs CASE WHEN (searched case) + pub operand: Option>, + /// WHEN ... THEN ... pairs + pub when_clauses: Vec<(Expr, Expr)>, + /// ELSE clause + pub else_clause: Option>, +} + +impl CaseExpr { + /// Create a searched CASE expression (CASE WHEN ... THEN ...) + pub fn searched(when_clauses: Vec<(Expr, Expr)>, else_clause: Option) -> Self { + Self { + operand: None, + when_clauses, + else_clause: else_clause.map(Box::new), + } + } + + /// Create a simple CASE expression (CASE x WHEN ... THEN ...) + pub fn simple( + operand: Expr, + when_clauses: Vec<(Expr, Expr)>, + else_clause: Option, + ) -> Self { + Self { + operand: Some(Box::new(operand)), + when_clauses, + else_clause: else_clause.map(Box::new), + } + } +} + +/// JSON/JSONB building expressions +#[derive(Debug, Clone, PartialEq)] +pub enum JsonBuildExpr { + /// jsonb_build_object(k1, v1, k2, v2, ...) + Object(Vec<(Expr, Expr)>), + /// jsonb_build_array(v1, v2, ...) + Array(Vec), +} + +/// ORDER BY expression component +#[derive(Debug, Clone, PartialEq)] +pub struct OrderByExpr { + pub expr: Expr, + pub direction: Option, + pub nulls: Option, +} + +impl OrderByExpr { + pub fn new(expr: Expr) -> Self { + Self { + expr, + direction: None, + nulls: None, + } + } + + pub fn asc(expr: Expr) -> Self { + Self { + expr, + direction: Some(OrderDirection::Asc), + nulls: None, + } + } + + pub fn desc(expr: Expr) -> Self { + Self { + expr, + direction: Some(OrderDirection::Desc), + nulls: None, + } + } + + pub fn with_nulls(mut self, nulls: NullsOrder) -> Self { + self.nulls = Some(nulls); + self + } +} + +/// ORDER BY direction +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum OrderDirection { + Asc, + Desc, +} + +impl OrderDirection { + pub fn as_sql(&self) -> &'static str { + match self { + Self::Asc => "asc", + Self::Desc => "desc", + } + } +} + +/// NULLS FIRST/LAST in ORDER BY +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum NullsOrder { + First, + Last, +} + +impl NullsOrder { + pub fn as_sql(&self) -> &'static str { + match self { + Self::First => "nulls first", + Self::Last => "nulls last", + } + } +} + +/// The main expression enum encompassing all SQL expression types +#[derive(Debug, Clone, PartialEq)] +pub enum Expr { + /// Column reference: table.column or just column + Column(ColumnRef), + + /// Literal value + Literal(Literal), + + /// Parameterized value: $1, $2, etc. 
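+    /// Produced by `ParamCollector::add`; renders with its cast, e.g. `($1::text)`.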
+    Param(ParamRef),
+
+    /// Binary operation: expr op expr
+    BinaryOp {
+        left: Box<Expr>,
+        op: BinaryOperator,
+        right: Box<Expr>,
+    },
+
+    /// Unary operation: op expr (e.g., NOT)
+    UnaryOp { op: UnaryOperator, expr: Box<Expr> },
+
+    /// Function call
+    FunctionCall(FunctionCall),
+
+    /// Aggregate function
+    Aggregate(AggregateExpr),
+
+    /// CASE expression
+    Case(CaseExpr),
+
+    /// Subquery: (SELECT ...)
+    Subquery(Box<SelectStmt>),
+
+    /// Array literal: ARRAY[...]
+    Array(Vec<Expr>),
+
+    /// Type cast: expr::type
+    Cast {
+        expr: Box<Expr>,
+        target_type: SqlType,
+    },
+
+    /// IS NULL / IS NOT NULL
+    IsNull { expr: Box<Expr>, negated: bool },
+
+    /// expr IN (values) - for a list
+    InList {
+        expr: Box<Expr>,
+        list: Vec<Expr>,
+        negated: bool,
+    },
+
+    /// expr BETWEEN low AND high
+    Between {
+        expr: Box<Expr>,
+        low: Box<Expr>,
+        high: Box<Expr>,
+        negated: bool,
+    },
+
+    /// EXISTS (subquery)
+    Exists {
+        subquery: Box<SelectStmt>,
+        negated: bool,
+    },
+
+    /// JSON/JSONB building
+    JsonBuild(JsonBuildExpr),
+
+    /// Coalesce function: COALESCE(expr1, expr2, ...)
+    Coalesce(Vec<Expr>),
+
+    /// Parenthesized expression (for explicit grouping)
+    Nested(Box<Expr>),
+
+    /// Raw SQL string - **DEPRECATED**: Use only in tests.
+    ///
+    /// This is a security-sensitive escape hatch that bypasses SQL injection protection.
+    /// All production code should use type-safe AST nodes instead. If you need SQL
+    /// functionality not yet supported by the AST, add a proper node type.
+    ///
+    /// # Security Warning
+    ///
+    /// Never use this with user-provided input. The string is rendered directly
+    /// to SQL without any escaping or validation.
+    #[cfg(test)]
+    Raw(String),
+
+    /// Array index access: array[index]
+    ArrayIndex { array: Box<Expr>, index: Box<Expr> },
+
+    /// Function call with ORDER BY clause (e.g., array_agg(x ORDER BY y))
+    FunctionCallWithOrderBy {
+        name: String,
+        args: Vec<Expr>,
+        order_by: Vec<OrderByExpr>,
+    },
+
+    /// ROW constructor: ROW(expr1, expr2, ...) or (expr1, expr2, ...)
+    Row(Vec<Expr>),
+}
+
+impl Expr {
+    // Convenience constructors
+
+    /// Create a column reference
+    pub fn column(name: impl Into<Ident>) -> Self {
+        Self::Column(ColumnRef::new(name))
+    }
+
+    /// Create a qualified column reference (table.column)
+    pub fn qualified_column(table: impl Into<Ident>, column: impl Into<Ident>) -> Self {
+        Self::Column(ColumnRef::qualified(table, column))
+    }
+
+    /// Create a NULL literal
+    pub fn null() -> Self {
+        Self::Literal(Literal::Null)
+    }
+
+    /// Create a boolean literal
+    pub fn bool(b: bool) -> Self {
+        Self::Literal(Literal::Bool(b))
+    }
+
+    /// Create an integer literal
+    pub fn int(n: i64) -> Self {
+        Self::Literal(Literal::Integer(n))
+    }
+
+    /// Create a string literal
+    pub fn string(s: impl Into<String>) -> Self {
+        Self::Literal(Literal::String(s.into()))
+    }
+
+    /// Create a binary operation
+    pub fn binary(left: Expr, op: BinaryOperator, right: Expr) -> Self {
+        Self::BinaryOp {
+            left: Box::new(left),
+            op,
+            right: Box::new(right),
+        }
+    }
+
+    /// Create a NOT expression
+    pub fn not(expr: Expr) -> Self {
+        Self::UnaryOp {
+            op: UnaryOperator::Not,
+            expr: Box::new(expr),
+        }
+    }
+
+    /// Create an IS NULL expression
+    pub fn is_null(expr: Expr) -> Self {
+        Self::IsNull {
+            expr: Box::new(expr),
+            negated: false,
+        }
+    }
+
+    /// Create an IS NOT NULL expression
+    pub fn is_not_null(expr: Expr) -> Self {
+        Self::IsNull {
+            expr: Box::new(expr),
+            negated: true,
+        }
+    }
+
+    /// Create a type cast
+    pub fn cast(expr: Expr, target_type: SqlType) -> Self {
+        Self::Cast {
+            expr: Box::new(expr),
+            target_type,
+        }
+    }
+
+    /// Create a function call
+    pub fn function(name: impl Into<Ident>, args: Vec<Expr>) -> Self {
+        Self::FunctionCall(FunctionCall::new(
+            name,
+            args.into_iter().map(FunctionArg::Unnamed).collect(),
+        ))
+    }
+
+    /// Create a COALESCE expression
+    pub fn coalesce(exprs: Vec<Expr>) -> Self {
+        Self::Coalesce(exprs)
+    }
+
+    /// Create jsonb_build_object
+    pub fn jsonb_build_object(pairs: Vec<(Expr, Expr)>) -> Self {
+        Self::JsonBuild(JsonBuildExpr::Object(pairs))
+    }
+
+    /// Create jsonb_build_array
+    pub fn jsonb_build_array(exprs: Vec<Expr>) -> Self {
+        Self::JsonBuild(JsonBuildExpr::Array(exprs))
+    }
+
+    /// Wrap in parentheses
+    pub fn nested(self) -> Self {
+        Self::Nested(Box::new(self))
+    }
+
+    /// Combine with AND
+    pub fn and(self, other: Expr) -> Self {
+        Self::binary(self, BinaryOperator::And, other)
+    }
+
+    /// Combine with OR
+    pub fn or(self, other: Expr) -> Self {
+        Self::binary(self, BinaryOperator::Or, other)
+    }
+
+    /// Check equality
+    pub fn eq(self, other: Expr) -> Self {
+        Self::binary(self, BinaryOperator::Eq, other)
+    }
+
+    /// Create raw SQL - **DEPRECATED**: Use only in tests.
+    ///
+    /// # Security Warning
+    ///
+    /// This bypasses SQL injection protection. Never use with user input.
+ #[cfg(test)] + pub fn raw(sql: impl Into) -> Self { + Self::Raw(sql.into()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_column_ref() { + let col = ColumnRef::new("id"); + assert_eq!(col.column.as_str(), "id"); + assert!(col.table_alias.is_none()); + + let col = ColumnRef::qualified("users", "id"); + assert_eq!(col.table_alias.unwrap().as_str(), "users"); + assert_eq!(col.column.as_str(), "id"); + } + + #[test] + fn test_expr_constructors() { + let expr = Expr::qualified_column("t", "id"); + match expr { + Expr::Column(c) => { + assert_eq!(c.table_alias.unwrap().as_str(), "t"); + assert_eq!(c.column.as_str(), "id"); + } + _ => panic!("Expected Column"), + } + + let expr = Expr::int(42); + match expr { + Expr::Literal(Literal::Integer(n)) => assert_eq!(n, 42), + _ => panic!("Expected Integer"), + } + } + + #[test] + fn test_binary_op() { + let expr = Expr::qualified_column("t", "id").eq(Expr::int(1)); + match expr { + Expr::BinaryOp { op, .. } => assert_eq!(op, BinaryOperator::Eq), + _ => panic!("Expected BinaryOp"), + } + } +} diff --git a/src/ast/mod.rs b/src/ast/mod.rs new file mode 100644 index 00000000..f570a6eb --- /dev/null +++ b/src/ast/mod.rs @@ -0,0 +1,77 @@ +//! SQL Abstract Syntax Tree (AST) module +//! +//! This module provides a type-safe representation of SQL statements that can be +//! constructed programmatically and rendered to SQL strings. It is intentionally +//! decoupled from pgrx to allow for: +//! +//! - Independent testing without PostgreSQL +//! - Potential reuse in other contexts +//! - Clear separation of concerns +//! +//! # Architecture +//! +//! The AST is built from several components: +//! +//! - [`expr`]: SQL expressions (columns, literals, operators, functions) +//! - [`stmt`]: SQL statements (SELECT, INSERT, UPDATE, DELETE) +//! - [`cte`]: Common Table Expressions (WITH clauses) +//! - [`types`]: SQL type representations +//! - [`params`]: Parameter handling for prepared statements +//! - [`render`]: SQL string generation +//! +//! # Example +//! +//! ```rust,ignore +//! use pg_graphql::ast::*; +//! +//! let mut params = ParamCollector::new(); +//! let stmt = SelectStmt { +//! columns: vec![SelectColumn::star()], +//! from: Some(FromClause::table("public", "users", "u")), +//! where_clause: Some(Expr::binary( +//! Expr::column("u", "id"), +//! BinaryOperator::Eq, +//! params.add(ParamValue::Integer(1), SqlType::integer()), +//! )), +//! ..Default::default() +//! }; +//! +//! let sql = render(&Stmt::Select(stmt)); +//! // SELECT * FROM "public"."users" "u" WHERE "u"."id" = ($1::integer) +//! ``` + +mod builder_bridge; +mod cte; +mod execute; +mod expr; +mod params; +mod render; +mod stmt; +mod transpile_connection; +mod transpile_delete; +mod transpile_filter; +mod transpile_function_call; +mod transpile_insert; +mod transpile_node; +mod transpile_update; +mod types; + +// Re-export all public types +pub use builder_bridge::*; +pub use cte::*; +pub use execute::*; +pub use expr::*; +pub use params::*; +pub use render::*; +pub use stmt::*; +pub use transpile_connection::*; +pub use transpile_delete::*; +pub use transpile_filter::*; +pub use transpile_function_call::*; +pub use transpile_insert::*; +pub use transpile_node::*; +pub use transpile_update::*; +pub use types::*; + +#[cfg(test)] +mod tests; diff --git a/src/ast/params.rs b/src/ast/params.rs new file mode 100644 index 00000000..ec8ca749 --- /dev/null +++ b/src/ast/params.rs @@ -0,0 +1,311 @@ +//! Parameter handling for prepared statements +//! +//! 
This module provides a pgrx-independent way to collect and manage
+//! query parameters. The actual conversion to pgrx Datums happens
+//! in the executor module.
+
+use super::expr::{Expr, ParamRef};
+use super::types::SqlType;
+
+/// A parameter value that can be rendered to SQL
+///
+/// This is intentionally decoupled from pgrx to allow the AST module
+/// to be tested independently.
+#[derive(Debug, Clone, PartialEq)]
+pub enum ParamValue {
+    /// SQL NULL
+    Null,
+    /// Boolean value
+    Bool(bool),
+    /// String value
+    String(String),
+    /// Integer value
+    Integer(i64),
+    /// Floating point value
+    Float(f64),
+    /// Array of values (for array parameters)
+    Array(Vec<ParamValue>),
+    /// JSON value (stored as serde_json::Value)
+    Json(serde_json::Value),
+}
+
+impl ParamValue {
+    /// Check if this is a null value
+    pub fn is_null(&self) -> bool {
+        matches!(self, Self::Null)
+    }
+
+    /// Convert to a string representation for SQL
+    pub fn to_sql_literal(&self) -> Option<String> {
+        match self {
+            Self::Null => None,
+            Self::Bool(b) => Some(b.to_string()),
+            Self::String(s) => Some(s.clone()),
+            Self::Integer(n) => Some(n.to_string()),
+            Self::Float(f) => Some(f.to_string()),
+            Self::Array(arr) => {
+                let elements: Vec<String> =
+                    arr.iter().filter_map(|v| v.to_sql_literal()).collect();
+                Some(format!("{{{}}}", elements.join(",")))
+            }
+            Self::Json(v) => Some(v.to_string()),
+        }
+    }
+}
+
+impl From<bool> for ParamValue {
+    fn from(b: bool) -> Self {
+        Self::Bool(b)
+    }
+}
+
+impl From<String> for ParamValue {
+    fn from(s: String) -> Self {
+        Self::String(s)
+    }
+}
+
+impl From<&str> for ParamValue {
+    fn from(s: &str) -> Self {
+        Self::String(s.to_string())
+    }
+}
+
+impl From<i64> for ParamValue {
+    fn from(n: i64) -> Self {
+        Self::Integer(n)
+    }
+}
+
+impl From<i32> for ParamValue {
+    fn from(n: i32) -> Self {
+        Self::Integer(n as i64)
+    }
+}
+
+impl From<f64> for ParamValue {
+    fn from(f: f64) -> Self {
+        Self::Float(f)
+    }
+}
+
+impl From<serde_json::Value> for ParamValue {
+    fn from(v: serde_json::Value) -> Self {
+        Self::Json(v)
+    }
+}
+
+impl<T: Into<ParamValue>> From<Option<T>> for ParamValue {
+    fn from(opt: Option<T>) -> Self {
+        match opt {
+            Some(v) => v.into(),
+            None => Self::Null,
+        }
+    }
+}
+
+/// A collected parameter with its index, value, and type
+#[derive(Debug, Clone)]
+pub struct Param {
+    /// 1-indexed parameter number
+    pub index: usize,
+    /// The parameter value
+    pub value: ParamValue,
+    /// The SQL type for casting
+    pub sql_type: SqlType,
+}
+
+impl Param {
+    pub fn new(index: usize, value: ParamValue, sql_type: SqlType) -> Self {
+        Self {
+            index,
+            value,
+            sql_type,
+        }
+    }
+}
+
+/// Collects parameters during AST construction
+///
+/// This allows building parameterized queries without worrying about
+/// parameter numbering. Each call to `add()` returns an expression
+/// that references the parameter.
+#[derive(Debug, Default)]
+pub struct ParamCollector {
+    params: Vec<Param>,
+}
+
+impl ParamCollector {
+    /// Create a new empty parameter collector
+    pub fn new() -> Self {
+        Self { params: Vec::new() }
+    }
+
+    /// Add a parameter and return an expression that references it
+    ///
+    /// Parameters are 1-indexed in SQL ($1, $2, etc.)
+    pub fn add(&mut self, value: ParamValue, sql_type: SqlType) -> Expr {
+        let index = self.params.len() + 1; // 1-indexed for SQL
+        self.params.push(Param {
+            index,
+            value,
+            sql_type: sql_type.clone(),
+        });
+        Expr::Param(ParamRef {
+            index,
+            type_cast: sql_type,
+        })
+    }
+
+    /// Add a parameter from a serde_json::Value
+    ///
+    /// This is a convenience method for the common case of converting
+    /// GraphQL input values to parameters.
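+    ///
+    /// # Example
+    ///
+    /// A minimal sketch (illustrative values):
+    ///
+    /// ```rust,ignore
+    /// let mut params = ParamCollector::new();
+    /// // Binds $1 as text; renders as ($1::text) in the generated SQL.
+    /// let expr = params.add_json(&serde_json::json!("Alice"), SqlType::text());
+    /// ```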
+ pub fn add_json(&mut self, value: &serde_json::Value, sql_type: SqlType) -> Expr { + let param_value = json_to_param_value(value); + self.add(param_value, sql_type) + } + + /// Get all collected parameters + pub fn into_params(self) -> Vec { + self.params + } + + /// Get parameters as a slice + pub fn params(&self) -> &[Param] { + &self.params + } + + /// Get the number of collected parameters + pub fn len(&self) -> usize { + self.params.len() + } + + /// Check if no parameters have been collected + pub fn is_empty(&self) -> bool { + self.params.is_empty() + } + + /// Get the next parameter index that would be assigned + pub fn next_index(&self) -> usize { + self.params.len() + 1 + } +} + +/// Convert a serde_json::Value to a ParamValue +pub fn json_to_param_value(value: &serde_json::Value) -> ParamValue { + match value { + serde_json::Value::Null => ParamValue::Null, + serde_json::Value::Bool(b) => ParamValue::Bool(*b), + serde_json::Value::Number(n) => { + if let Some(i) = n.as_i64() { + ParamValue::Integer(i) + } else if let Some(f) = n.as_f64() { + ParamValue::Float(f) + } else { + // Fallback to string representation + ParamValue::String(n.to_string()) + } + } + serde_json::Value::String(s) => ParamValue::String(s.clone()), + serde_json::Value::Array(arr) => { + ParamValue::Array(arr.iter().map(json_to_param_value).collect()) + } + serde_json::Value::Object(_) => { + // Store objects as JSON + ParamValue::Json(value.clone()) + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_param_collector() { + let mut collector = ParamCollector::new(); + + let expr1 = collector.add(ParamValue::String("hello".into()), SqlType::text()); + let expr2 = collector.add(ParamValue::Integer(42), SqlType::integer()); + + assert_eq!(collector.len(), 2); + + match &expr1 { + Expr::Param(p) => { + assert_eq!(p.index, 1); + assert_eq!(p.type_cast.name, "text"); + } + _ => panic!("Expected Param"), + } + + match &expr2 { + Expr::Param(p) => { + assert_eq!(p.index, 2); + assert_eq!(p.type_cast.name, "integer"); + } + _ => panic!("Expected Param"), + } + + let params = collector.into_params(); + assert_eq!(params.len(), 2); + assert!(matches!(params[0].value, ParamValue::String(_))); + assert!(matches!(params[1].value, ParamValue::Integer(42))); + } + + #[test] + fn test_json_to_param_value() { + assert!(matches!( + json_to_param_value(&serde_json::Value::Null), + ParamValue::Null + )); + + assert!(matches!( + json_to_param_value(&serde_json::json!(true)), + ParamValue::Bool(true) + )); + + assert!(matches!( + json_to_param_value(&serde_json::json!(42)), + ParamValue::Integer(42) + )); + + assert!(matches!( + json_to_param_value(&serde_json::json!("hello")), + ParamValue::String(s) if s == "hello" + )); + + let arr = json_to_param_value(&serde_json::json!([1, 2, 3])); + match arr { + ParamValue::Array(v) => assert_eq!(v.len(), 3), + _ => panic!("Expected Array"), + } + } + + #[test] + fn test_param_value_to_sql_literal() { + assert_eq!(ParamValue::Null.to_sql_literal(), None); + assert_eq!( + ParamValue::Bool(true).to_sql_literal(), + Some("true".to_string()) + ); + assert_eq!( + ParamValue::Integer(42).to_sql_literal(), + Some("42".to_string()) + ); + assert_eq!( + ParamValue::String("hello".into()).to_sql_literal(), + Some("hello".to_string()) + ); + } + + #[test] + fn test_param_value_from() { + let _: ParamValue = true.into(); + let _: ParamValue = "hello".into(); + let _: ParamValue = 42i32.into(); + let _: ParamValue = 42i64.into(); + let _: ParamValue = 3.14f64.into(); + let _: 
ParamValue = None::<i32>.into();
+        let _: ParamValue = Some(42i32).into();
+    }
+}
diff --git a/src/ast/render.rs b/src/ast/render.rs
new file mode 100644
index 00000000..32b37cc5
--- /dev/null
+++ b/src/ast/render.rs
@@ -0,0 +1,1186 @@
+//! SQL string rendering
+//!
+//! This module converts AST nodes to SQL strings. It is the only place
+//! in the codebase where SQL strings are constructed.
+//!
+//! # Architecture
+//!
+//! The rendering system is built around two key components:
+//!
+//! - [`Render`] trait: Implemented by all AST nodes to define how they render to SQL
+//! - [`SqlRenderer`]: The rendering context that handles output buffering and formatting
+//!
+//! # Safety
+//!
+//! All identifiers are quoted using PostgreSQL's native `quote_ident()` function.
+//! All literals are escaped using PostgreSQL's native `quote_literal()` function.
+
+use super::cte::{Cte, CteQuery};
+use super::expr::*;
+use super::stmt::*;
+use super::types::SqlType;
+use pgrx::{direct_function_call, pg_sys, IntoDatum};
+use std::fmt::Write;
+
+/// Quote an identifier using PostgreSQL's native quote_ident() function
+fn quote_ident(ident: &str) -> String {
+    unsafe {
+        direct_function_call::<String>(pg_sys::quote_ident, &[ident.into_datum()])
+            .expect("quote_ident failed")
+    }
+}
+
+/// Quote a literal using PostgreSQL's native quote_literal() function
+fn quote_literal(lit: &str) -> String {
+    unsafe {
+        direct_function_call::<String>(pg_sys::quote_literal, &[lit.into_datum()])
+            .expect("quote_literal failed")
+    }
+}
+
+// =============================================================================
+// Render Trait
+// =============================================================================
+
+/// Trait for AST nodes that can be rendered to SQL.
+///
+/// This trait allows each AST node type to define its own rendering logic
+/// while sharing the common [`SqlRenderer`] infrastructure.
+/// +/// # Example +/// +/// ```rust,ignore +/// use pg_graphql::ast::{Render, SqlRenderer, Expr}; +/// +/// let expr = Expr::int(42); +/// let mut renderer = SqlRenderer::new(); +/// expr.render(&mut renderer); +/// let sql = renderer.into_sql(); +/// assert_eq!(sql, "42"); +/// ``` +pub trait Render { + /// Render this node to the given SQL renderer + fn render(&self, renderer: &mut SqlRenderer); +} + +// Implement Render for Stmt +impl Render for Stmt { + fn render(&self, renderer: &mut SqlRenderer) { + match self { + Stmt::Select(s) => s.render(renderer), + Stmt::Insert(s) => s.render(renderer), + Stmt::Update(s) => s.render(renderer), + Stmt::Delete(s) => s.render(renderer), + } + } +} + +// Implement Render for SelectStmt +impl Render for SelectStmt { + fn render(&self, renderer: &mut SqlRenderer) { + renderer.render_select(self); + } +} + +// Implement Render for InsertStmt +impl Render for InsertStmt { + fn render(&self, renderer: &mut SqlRenderer) { + renderer.render_insert(self); + } +} + +// Implement Render for UpdateStmt +impl Render for UpdateStmt { + fn render(&self, renderer: &mut SqlRenderer) { + renderer.render_update(self); + } +} + +// Implement Render for DeleteStmt +impl Render for DeleteStmt { + fn render(&self, renderer: &mut SqlRenderer) { + renderer.render_delete(self); + } +} + +// Implement Render for Expr +impl Render for Expr { + fn render(&self, renderer: &mut SqlRenderer) { + renderer.render_expr(self); + } +} + +// Implement Render for Literal +impl Render for Literal { + fn render(&self, renderer: &mut SqlRenderer) { + renderer.render_literal(self); + } +} + +// Implement Render for Ident +impl Render for Ident { + fn render(&self, renderer: &mut SqlRenderer) { + renderer.write_ident(self); + } +} + +// ============================================================================= +// Constants +// ============================================================================= + +/// Default buffer capacity for simple queries +const DEFAULT_BUFFER_CAPACITY: usize = 1024; + +/// Buffer capacity for queries with CTEs (e.g., connection queries) +const CTE_BUFFER_CAPACITY: usize = 4096; + +/// Buffer capacity for complex queries with many CTEs +const LARGE_BUFFER_CAPACITY: usize = 8192; + +/// SQL renderer with optional pretty-printing +pub struct SqlRenderer { + output: String, + indent_level: usize, + pretty: bool, +} + +impl SqlRenderer { + /// Create a new renderer with compact output + pub fn new() -> Self { + Self { + output: String::with_capacity(DEFAULT_BUFFER_CAPACITY), + indent_level: 0, + pretty: false, + } + } + + /// Create a new renderer with a specific buffer capacity + pub fn with_capacity(capacity: usize) -> Self { + Self { + output: String::with_capacity(capacity), + indent_level: 0, + pretty: false, + } + } + + /// Create a new renderer with pretty-printed output + pub fn pretty() -> Self { + Self { + output: String::with_capacity(DEFAULT_BUFFER_CAPACITY), + indent_level: 0, + pretty: true, + } + } + + /// Estimate appropriate buffer capacity based on statement complexity + pub fn estimate_capacity(stmt: &Stmt) -> usize { + match stmt { + Stmt::Select(s) => { + let cte_count = s.ctes.len(); + if cte_count >= 5 { + LARGE_BUFFER_CAPACITY + } else if cte_count > 0 { + CTE_BUFFER_CAPACITY + } else { + DEFAULT_BUFFER_CAPACITY + } + } + Stmt::Insert(s) => { + let cte_count = s.ctes.len(); + if cte_count > 0 { + CTE_BUFFER_CAPACITY + } else { + DEFAULT_BUFFER_CAPACITY + } + } + Stmt::Update(s) => { + let cte_count = s.ctes.len(); + if cte_count > 0 { 
+ CTE_BUFFER_CAPACITY + } else { + DEFAULT_BUFFER_CAPACITY + } + } + Stmt::Delete(s) => { + let cte_count = s.ctes.len(); + if cte_count > 0 { + CTE_BUFFER_CAPACITY + } else { + DEFAULT_BUFFER_CAPACITY + } + } + } + } + + /// Render a statement and return the SQL string + pub fn render_stmt(&mut self, stmt: &Stmt) -> &str { + match stmt { + Stmt::Select(s) => self.render_select(s), + Stmt::Insert(s) => self.render_insert(s), + Stmt::Update(s) => self.render_update(s), + Stmt::Delete(s) => self.render_delete(s), + } + &self.output + } + + /// Take ownership of the rendered SQL string + pub fn into_sql(self) -> String { + self.output + } + + // ========================================================================= + // Statement rendering + // ========================================================================= + + fn render_select(&mut self, stmt: &SelectStmt) { + self.render_ctes(&stmt.ctes); + self.write("select "); + self.render_select_columns(&stmt.columns); + + if let Some(from) = &stmt.from { + self.newline(); + self.write("from "); + self.render_from(from); + } + + if let Some(where_clause) = &stmt.where_clause { + self.newline(); + self.write("where "); + self.render_expr(where_clause); + } + + if !stmt.group_by.is_empty() { + self.newline(); + self.write("group by "); + self.render_expr_list(&stmt.group_by); + } + + if let Some(having) = &stmt.having { + self.newline(); + self.write("having "); + self.render_expr(having); + } + + if !stmt.order_by.is_empty() { + self.newline(); + self.write("order by "); + self.render_order_by(&stmt.order_by); + } + + if let Some(limit) = stmt.limit { + self.newline(); + write!(self.output, "limit {}", limit).unwrap(); + } + + if let Some(offset) = stmt.offset { + self.newline(); + write!(self.output, "offset {}", offset).unwrap(); + } + } + + fn render_insert(&mut self, stmt: &InsertStmt) { + self.render_ctes(&stmt.ctes); + self.write("insert into "); + self.render_table_name(stmt.schema.as_ref(), &stmt.table); + + if !stmt.columns.is_empty() { + self.write("("); + for (i, col) in stmt.columns.iter().enumerate() { + if i > 0 { + self.write(", "); + } + self.write_ident(col); + } + self.write(")"); + } + + self.newline(); + match &stmt.values { + InsertValues::Values(rows) => { + self.write("values "); + for (i, row) in rows.iter().enumerate() { + if i > 0 { + self.write(", "); + } + self.write("("); + self.render_expr_list(row); + self.write(")"); + } + } + InsertValues::Query(query) => { + self.render_select(query); + } + InsertValues::DefaultValues => { + self.write("default values"); + } + } + + if let Some(on_conflict) = &stmt.on_conflict { + self.newline(); + self.render_on_conflict(on_conflict); + } + + if !stmt.returning.is_empty() { + self.newline(); + self.write("returning "); + self.render_select_columns(&stmt.returning); + } + } + + fn render_update(&mut self, stmt: &UpdateStmt) { + self.render_ctes(&stmt.ctes); + self.write("update "); + self.render_table_name(stmt.schema.as_ref(), &stmt.table); + + if let Some(alias) = &stmt.alias { + self.write(" as "); + self.write_ident(alias); + } + + self.newline(); + self.write("set "); + for (i, (col, expr)) in stmt.set.iter().enumerate() { + if i > 0 { + self.write(", "); + } + self.write_ident(col); + self.write(" = "); + self.render_expr(expr); + } + + if let Some(from) = &stmt.from { + self.newline(); + self.write("from "); + self.render_from(from); + } + + if let Some(where_clause) = &stmt.where_clause { + self.newline(); + self.write("where "); + self.render_expr(where_clause); + 
} + + if !stmt.returning.is_empty() { + self.newline(); + self.write("returning "); + self.render_select_columns(&stmt.returning); + } + } + + fn render_delete(&mut self, stmt: &DeleteStmt) { + self.render_ctes(&stmt.ctes); + self.write("delete from "); + self.render_table_name(stmt.schema.as_ref(), &stmt.table); + + if let Some(alias) = &stmt.alias { + self.write(" as "); + self.write_ident(alias); + } + + if let Some(using) = &stmt.using { + self.newline(); + self.write("using "); + self.render_from(using); + } + + if let Some(where_clause) = &stmt.where_clause { + self.newline(); + self.write("where "); + self.render_expr(where_clause); + } + + if !stmt.returning.is_empty() { + self.newline(); + self.write("returning "); + self.render_select_columns(&stmt.returning); + } + } + + // ========================================================================= + // CTE rendering + // ========================================================================= + + fn render_ctes(&mut self, ctes: &[Cte]) { + if ctes.is_empty() { + return; + } + + self.write("with "); + for (i, cte) in ctes.iter().enumerate() { + if i > 0 { + self.write(", "); + } + self.render_cte(cte); + } + self.newline(); + } + + fn render_cte(&mut self, cte: &Cte) { + self.write_ident(&cte.name); + + if let Some(columns) = &cte.columns { + self.write("("); + for (i, col) in columns.iter().enumerate() { + if i > 0 { + self.write(", "); + } + self.write_ident(col); + } + self.write(")"); + } + + self.write(" as "); + + if let Some(materialized) = cte.materialized { + if materialized { + self.write("materialized "); + } else { + self.write("not materialized "); + } + } + + self.write("("); + self.indent(); + self.newline(); + + match &cte.query { + CteQuery::Select(s) => self.render_select(s), + CteQuery::Insert(s) => self.render_insert(s), + CteQuery::Update(s) => self.render_update(s), + CteQuery::Delete(s) => self.render_delete(s), + } + + self.dedent(); + self.newline(); + self.write(")"); + } + + // ========================================================================= + // FROM clause rendering + // ========================================================================= + + fn render_from(&mut self, from: &FromClause) { + match from { + FromClause::Table { + schema, + name, + alias, + } => { + self.render_table_name(schema.as_ref(), name); + if let Some(alias) = alias { + self.write(" "); + self.write_ident(alias); + } + } + FromClause::Subquery { query, alias } => { + self.write("("); + self.render_select(query); + self.write(") "); + self.write_ident(alias); + } + FromClause::Function { call, alias } => { + self.render_function_call(call); + self.write(" "); + self.write_ident(alias); + } + FromClause::Join { + left, + join_type, + right, + on, + } => { + self.render_from(left); + self.newline(); + self.write(join_type.as_sql()); + self.write(" "); + self.render_from(right); + if let Some(on_expr) = on { + self.write(" on "); + self.render_expr(on_expr); + } + } + FromClause::CrossJoin { left, right } => { + self.render_from(left); + self.write(" cross join "); + self.render_from(right); + } + FromClause::Lateral { subquery, alias } => { + self.write("lateral ("); + self.render_select(subquery); + self.write(") "); + self.write_ident(alias); + } + } + } + + // ========================================================================= + // Expression rendering + // ========================================================================= + + fn render_expr(&mut self, expr: &Expr) { + match expr { + Expr::Column(col) => 
{ + if let Some(table) = &col.table_alias { + self.write_ident(table); + self.write("."); + } + self.write_ident(&col.column); + } + + Expr::Literal(lit) => self.render_literal(lit), + + Expr::Param(p) => { + write!(self.output, "(${}", p.index).unwrap(); + self.write("::"); + self.render_type(&p.type_cast); + self.write(")"); + } + + Expr::BinaryOp { left, op, right } => { + self.render_expr(left); + self.write(" "); + self.write(op.as_sql()); + // The ANY operator needs parentheses around the right operand: col = any(arr) + if *op == BinaryOperator::Any { + self.write("("); + self.render_expr(right); + self.write(")"); + } else { + self.write(" "); + self.render_expr(right); + } + } + + Expr::UnaryOp { op, expr } => { + self.write(op.as_sql()); + self.write("("); + self.render_expr(expr); + self.write(")"); + } + + Expr::FunctionCall(call) => { + self.render_function_call(call); + } + + Expr::Aggregate(agg) => { + self.render_aggregate(agg); + } + + Expr::Case(case) => { + self.render_case(case); + } + + Expr::Subquery(query) => { + self.write("("); + self.render_select(query); + self.write(")"); + } + + Expr::Array(exprs) => { + self.write("array["); + self.render_expr_list(exprs); + self.write("]"); + } + + Expr::Cast { expr, target_type } => { + self.render_expr(expr); + self.write("::"); + self.render_type(target_type); + } + + Expr::IsNull { expr, negated } => { + self.render_expr(expr); + if *negated { + self.write(" is not null"); + } else { + self.write(" is null"); + } + } + + Expr::InList { + expr, + list, + negated, + } => { + self.render_expr(expr); + if *negated { + self.write(" not in ("); + } else { + self.write(" in ("); + } + self.render_expr_list(list); + self.write(")"); + } + + Expr::Between { + expr, + low, + high, + negated, + } => { + self.render_expr(expr); + if *negated { + self.write(" not between "); + } else { + self.write(" between "); + } + self.render_expr(low); + self.write(" and "); + self.render_expr(high); + } + + Expr::Exists { subquery, negated } => { + if *negated { + self.write("not "); + } + self.write("exists ("); + self.render_select(subquery); + self.write(")"); + } + + Expr::JsonBuild(json) => { + self.render_json_build(json); + } + + Expr::Coalesce(exprs) => { + self.write("coalesce("); + self.render_expr_list(exprs); + self.write(")"); + } + + Expr::Nested(inner) => { + self.write("("); + self.render_expr(inner); + self.write(")"); + } + + // Raw SQL - SECURITY SENSITIVE + // This variant is only available in test code (#[cfg(test)]). + // It outputs the string directly without any escaping. + // NEVER expose this outside of tests - use proper AST nodes instead. + #[cfg(test)] + Expr::Raw(sql) => { + self.write(sql); + } + + Expr::ArrayIndex { array, index } => { + self.write("("); + self.render_expr(array); + self.write(")["); + self.render_expr(index); + self.write("]"); + } + + Expr::FunctionCallWithOrderBy { + name, + args, + order_by, + } => { + self.write(name); + self.write("("); + for (i, arg) in args.iter().enumerate() { + if i > 0 { + self.write(", "); + } + self.render_expr(arg); + } + if !order_by.is_empty() { + self.write(" order by "); + self.render_order_by(order_by); + } + self.write(")"); + } + + Expr::Row(exprs) => { + // ROW constructor: ROW(expr1, expr2, ...) or just (expr1, expr2, ...) 
+ // We use the explicit ROW keyword for clarity + self.write("row("); + self.render_expr_list(exprs); + self.write(")"); + } + } + } + + fn render_literal(&mut self, lit: &Literal) { + match lit { + Literal::Null => self.write("null"), + Literal::Bool(b) => self.write(if *b { "true" } else { "false" }), + Literal::Integer(n) => write!(self.output, "{}", n).unwrap(), + Literal::Float(f) => write!(self.output, "{}", f).unwrap(), + Literal::String(s) => self.write_literal(s), + Literal::Default => self.write("default"), + } + } + + fn render_function_call(&mut self, call: &FunctionCall) { + if let Some(schema) = &call.schema { + self.write_ident(schema); + self.write("."); + } + self.write_ident(&call.name); + self.write("("); + + for (i, arg) in call.args.iter().enumerate() { + if i > 0 { + self.write(", "); + } + match arg { + FunctionArg::Unnamed(expr) => self.render_expr(expr), + FunctionArg::Named { name, value } => { + self.write_ident(name); + self.write(" => "); + self.render_expr(value); + } + } + } + + if let Some(order_by) = &call.order_by { + if !order_by.is_empty() { + self.write(" order by "); + self.render_order_by(order_by); + } + } + + self.write(")"); + + if let Some(filter) = &call.filter { + self.write(" filter (where "); + self.render_expr(filter); + self.write(")"); + } + } + + fn render_aggregate(&mut self, agg: &AggregateExpr) { + self.write(agg.function.as_sql()); + self.write("("); + + if agg.distinct { + self.write("distinct "); + } + + // Special case for count(*) + if agg.args.is_empty() && matches!(agg.function, AggregateFunction::Count) { + self.write("*"); + } else { + self.render_expr_list(&agg.args); + } + + if let Some(order_by) = &agg.order_by { + if !order_by.is_empty() { + self.write(" order by "); + self.render_order_by(order_by); + } + } + + self.write(")"); + + if let Some(filter) = &agg.filter { + self.write(" filter (where "); + self.render_expr(filter); + self.write(")"); + } + } + + fn render_case(&mut self, case: &CaseExpr) { + self.write("case"); + + if let Some(operand) = &case.operand { + self.write(" "); + self.render_expr(operand); + } + + for (when_expr, then_expr) in &case.when_clauses { + self.write(" when "); + self.render_expr(when_expr); + self.write(" then "); + self.render_expr(then_expr); + } + + if let Some(else_clause) = &case.else_clause { + self.write(" else "); + self.render_expr(else_clause); + } + + self.write(" end"); + } + + fn render_json_build(&mut self, json: &JsonBuildExpr) { + match json { + JsonBuildExpr::Object(pairs) => { + self.write("jsonb_build_object("); + for (i, (key, value)) in pairs.iter().enumerate() { + if i > 0 { + self.write(", "); + } + self.render_expr(key); + self.write(", "); + self.render_expr(value); + } + self.write(")"); + } + JsonBuildExpr::Array(exprs) => { + self.write("jsonb_build_array("); + self.render_expr_list(exprs); + self.write(")"); + } + } + } + + fn render_on_conflict(&mut self, on_conflict: &OnConflict) { + self.write("on conflict "); + + match &on_conflict.target { + OnConflictTarget::Columns(cols) => { + self.write("("); + for (i, col) in cols.iter().enumerate() { + if i > 0 { + self.write(", "); + } + self.write_ident(col); + } + self.write(") "); + } + OnConflictTarget::Constraint(name) => { + self.write("on constraint "); + self.write_ident(name); + self.write(" "); + } + } + + match &on_conflict.action { + OnConflictAction::DoNothing => { + self.write("do nothing"); + } + OnConflictAction::DoUpdate { set, where_clause } => { + self.write("do update set "); + for (i, (col, 
expr)) in set.iter().enumerate() {
+                    if i > 0 {
+                        self.write(", ");
+                    }
+                    self.write_ident(col);
+                    self.write(" = ");
+                    self.render_expr(expr);
+                }
+                if let Some(where_clause) = where_clause {
+                    self.write(" where ");
+                    self.render_expr(where_clause);
+                }
+            }
+        }
+    }
+
+    // =========================================================================
+    // Helper methods
+    // =========================================================================
+
+    fn render_select_columns(&mut self, columns: &[SelectColumn]) {
+        for (i, col) in columns.iter().enumerate() {
+            if i > 0 {
+                self.write(", ");
+            }
+            match col {
+                SelectColumn::Expr { expr, alias } => {
+                    self.render_expr(expr);
+                    if let Some(alias) = alias {
+                        self.write(" as ");
+                        self.write_ident(alias);
+                    }
+                }
+                SelectColumn::Star => self.write("*"),
+                SelectColumn::QualifiedStar { table } => {
+                    self.write_ident(table);
+                    self.write(".*");
+                }
+            }
+        }
+    }
+
+    fn render_expr_list(&mut self, exprs: &[Expr]) {
+        for (i, expr) in exprs.iter().enumerate() {
+            if i > 0 {
+                self.write(", ");
+            }
+            self.render_expr(expr);
+        }
+    }
+
+    fn render_order_by(&mut self, order_by: &[OrderByExpr]) {
+        for (i, ob) in order_by.iter().enumerate() {
+            if i > 0 {
+                self.write(", ");
+            }
+            self.render_expr(&ob.expr);
+            if let Some(dir) = &ob.direction {
+                self.write(" ");
+                self.write(dir.as_sql());
+            }
+            if let Some(nulls) = &ob.nulls {
+                self.write(" ");
+                self.write(nulls.as_sql());
+            }
+        }
+    }
+
+    fn render_table_name(&mut self, schema: Option<&Ident>, name: &Ident) {
+        if let Some(schema) = schema {
+            self.write_ident(schema);
+            self.write(".");
+        }
+        self.write_ident(name);
+    }
+
+    fn render_type(&mut self, sql_type: &SqlType) {
+        // Type names come from trusted sources (database schema metadata loaded via
+        // SQL queries), so we output them directly without quoting.
+        if let Some(schema) = &sql_type.schema {
+            self.output.push_str(schema);
+            self.output.push('.');
+        }
+        self.output.push_str(&sql_type.name);
+        if sql_type.is_array {
+            self.output.push_str("[]");
+        }
+    }
+
+    // =========================================================================
+    // Low-level output methods
+    // =========================================================================
+
+    fn write(&mut self, s: &str) {
+        self.output.push_str(s);
+    }
+
+    fn write_ident(&mut self, ident: &Ident) {
+        // Use PostgreSQL's native quote_ident() for proper escaping
+        self.output.push_str(&quote_ident(ident.as_str()));
+    }
+
+    fn write_literal(&mut self, s: &str) {
+        // Use PostgreSQL's native quote_literal() for proper escaping
+        self.output.push_str(&quote_literal(s));
+    }
+
+    fn newline(&mut self) {
+        if self.pretty {
+            self.output.push('\n');
+            for _ in 0..self.indent_level {
+                self.output.push_str(" ");
+            }
+        } else {
+            self.output.push(' ');
+        }
+    }
+
+    fn indent(&mut self) {
+        self.indent_level += 1;
+    }
+
+    fn dedent(&mut self) {
+        if self.indent_level > 0 {
+            self.indent_level -= 1;
+        }
+    }
+}
+
+impl Default for SqlRenderer {
+    fn default() -> Self {
+        Self::new()
+    }
+}
+
+// =========================================================================
+// Convenience functions
+// =========================================================================
+
+/// Render a statement to a compact SQL string
+///
+/// Uses estimated buffer capacity based on statement complexity to reduce allocations.
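+///
+/// # Example
+///
+/// A minimal sketch mirroring the unit tests below:
+///
+/// ```rust,ignore
+/// let stmt = SelectStmt::columns(vec![SelectColumn::star()])
+///     .with_from(FromClause::table("users"));
+/// let sql = render(&Stmt::Select(stmt));
+/// // => select * from "users"
+/// ```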
+pub fn render(stmt: &Stmt) -> String { + let capacity = SqlRenderer::estimate_capacity(stmt); + let mut renderer = SqlRenderer::with_capacity(capacity); + renderer.render_stmt(stmt); + renderer.into_sql() +} + +/// Render a statement to a pretty-printed SQL string +pub fn render_pretty(stmt: &Stmt) -> String { + let capacity = SqlRenderer::estimate_capacity(stmt); + let mut renderer = SqlRenderer { + output: String::with_capacity(capacity), + indent_level: 0, + pretty: true, + }; + renderer.render_stmt(stmt); + renderer.into_sql() +} + +/// Render just a SELECT statement +pub fn render_select(stmt: &SelectStmt) -> String { + render(&Stmt::Select(stmt.clone())) +} + +/// Render just an expression +pub fn render_expr(expr: &Expr) -> String { + let mut renderer = SqlRenderer::new(); + renderer.render_expr(expr); + renderer.into_sql() +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::ast::*; + + #[test] + fn test_render_simple_select() { + let stmt = + SelectStmt::columns(vec![SelectColumn::star()]).with_from(FromClause::table("users")); + + let sql = render(&Stmt::Select(stmt)); + assert!(sql.contains("select *")); + assert!(sql.contains("from \"users\"")); + } + + #[test] + fn test_render_select_with_where() { + let stmt = SelectStmt::columns(vec![ + SelectColumn::expr(Expr::qualified_column("t", "id")), + SelectColumn::expr(Expr::qualified_column("t", "name")), + ]) + .with_from(FromClause::table("users").with_alias("t")) + .with_where(Expr::qualified_column("t", "active").eq(Expr::bool(true))); + + let sql = render(&Stmt::Select(stmt)); + assert!(sql.contains("\"t\".\"id\"")); + assert!(sql.contains("\"t\".\"name\"")); + assert!(sql.contains("where")); + assert!(sql.contains("\"t\".\"active\" = true")); + } + + #[test] + fn test_render_insert() { + let stmt = InsertStmt::new( + "users", + vec![Ident::new("name"), Ident::new("email")], + InsertValues::Values(vec![vec![ + Expr::string("Alice"), + Expr::string("alice@example.com"), + ]]), + ) + .with_schema("public") + .with_returning(vec![SelectColumn::expr(Expr::column("id"))]); + + let sql = render(&Stmt::Insert(stmt)); + assert!(sql.contains("insert into \"public\".\"users\"")); + assert!(sql.contains("(\"name\", \"email\")")); + assert!(sql.contains("values")); + assert!(sql.contains("returning")); + } + + #[test] + fn test_render_update() { + let stmt = UpdateStmt::new("users", vec![(Ident::new("name"), Expr::string("Bob"))]) + .with_where(Expr::column("id").eq(Expr::int(1))); + + let sql = render(&Stmt::Update(stmt)); + assert!(sql.contains("update \"users\"")); + assert!(sql.contains("set \"name\" =")); + assert!(sql.contains("where")); + } + + #[test] + fn test_render_delete() { + let stmt = DeleteStmt::new("users").with_where(Expr::column("id").eq(Expr::int(1))); + + let sql = render(&Stmt::Delete(stmt)); + assert!(sql.contains("delete from \"users\"")); + assert!(sql.contains("where")); + } + + #[test] + fn test_render_cte() { + let inner_select = SelectStmt::columns(vec![SelectColumn::star()]) + .with_from(FromClause::table("users")) + .with_where(Expr::column("active").eq(Expr::bool(true))); + + let outer_select = SelectStmt::columns(vec![SelectColumn::star()]) + .with_from(FromClause::table("active_users")) + .with_ctes(vec![Cte::select("active_users", inner_select)]); + + let sql = render(&Stmt::Select(outer_select)); + assert!(sql.contains("with \"active_users\" as")); + assert!(sql.contains("from \"active_users\"")); + } + + #[test] + fn test_render_param() { + let param_ref = ParamRef { + index: 1, + type_cast: 
SqlType::text(), + }; + let expr = Expr::Param(param_ref); + let sql = render_expr(&expr); + assert_eq!(sql, "($1::text)"); + } + + #[test] + fn test_render_jsonb_build_object() { + let expr = Expr::jsonb_build_object(vec![ + (Expr::string("name"), Expr::string("Alice")), + (Expr::string("age"), Expr::int(30)), + ]); + let sql = render_expr(&expr); + assert!(sql.contains("jsonb_build_object")); + assert!(sql.contains("'name'")); + assert!(sql.contains("'Alice'")); + } + + #[test] + fn test_render_aggregate() { + let agg = AggregateExpr::new(AggregateFunction::Count, vec![]); + let expr = Expr::Aggregate(agg); + let sql = render_expr(&expr); + assert_eq!(sql, "count(*)"); + } + + #[test] + fn test_render_case() { + let case = CaseExpr::searched( + vec![ + ( + Expr::column("status").eq(Expr::string("active")), + Expr::int(1), + ), + ( + Expr::column("status").eq(Expr::string("inactive")), + Expr::int(0), + ), + ], + Some(Expr::int(-1)), + ); + let expr = Expr::Case(case); + let sql = render_expr(&expr); + assert!(sql.contains("case")); + assert!(sql.contains("when")); + assert!(sql.contains("then")); + assert!(sql.contains("else")); + assert!(sql.contains("end")); + } + + #[test] + fn test_ident_quoting() { + let ident = Ident::new("user\"name"); + let mut renderer = SqlRenderer::new(); + renderer.write_ident(&ident); + let sql = renderer.into_sql(); + assert_eq!(sql, "\"user\"\"name\""); + } + + #[test] + fn test_string_literal_with_quotes() { + let mut renderer = SqlRenderer::new(); + renderer.write_literal("it's a test"); + let sql = renderer.into_sql(); + assert!(sql.contains("$__$")); + } +} diff --git a/src/ast/stmt.rs b/src/ast/stmt.rs new file mode 100644 index 00000000..7294621f --- /dev/null +++ b/src/ast/stmt.rs @@ -0,0 +1,541 @@ +//! SQL statement types +//! +//! This module defines the top-level SQL statement types: SELECT, INSERT, UPDATE, DELETE. 
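+//!
+//! A minimal usage sketch (illustrative, not a doctest; it assumes the builder
+//! methods defined below):
+//!
+//! ```ignore
+//! let stmt = SelectStmt::columns(vec![SelectColumn::star()])
+//!     .with_from(FromClause::table("users"))
+//!     .with_where(Expr::column("active").eq(Expr::bool(true)))
+//!     .with_limit(10);
+//! // render(&Stmt::Select(stmt)) yields roughly:
+//! //   select * from "users" where "active" = true limit 10
+//! ```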
+
+use super::cte::Cte;
+use super::expr::{Expr, FunctionCall, Ident, OrderByExpr};
+
+/// Top-level SQL statement
+#[derive(Debug, Clone, PartialEq)]
+pub enum Stmt {
+    Select(SelectStmt),
+    Insert(InsertStmt),
+    Update(UpdateStmt),
+    Delete(DeleteStmt),
+}
+
+impl Stmt {
+    pub fn select(stmt: SelectStmt) -> Self {
+        Self::Select(stmt)
+    }
+
+    pub fn insert(stmt: InsertStmt) -> Self {
+        Self::Insert(stmt)
+    }
+
+    pub fn update(stmt: UpdateStmt) -> Self {
+        Self::Update(stmt)
+    }
+
+    pub fn delete(stmt: DeleteStmt) -> Self {
+        Self::Delete(stmt)
+    }
+}
+
+/// SELECT statement
+#[derive(Debug, Clone, PartialEq, Default)]
+pub struct SelectStmt {
+    /// WITH clause (CTEs)
+    pub ctes: Vec<Cte>,
+    /// SELECT columns
+    pub columns: Vec<SelectColumn>,
+    /// FROM clause
+    pub from: Option<FromClause>,
+    /// WHERE clause
+    pub where_clause: Option<Expr>,
+    /// GROUP BY clause
+    pub group_by: Vec<Expr>,
+    /// HAVING clause
+    pub having: Option<Expr>,
+    /// ORDER BY clause
+    pub order_by: Vec<OrderByExpr>,
+    /// LIMIT
+    pub limit: Option<u64>,
+    /// OFFSET
+    pub offset: Option<u64>,
+}
+
+impl SelectStmt {
+    pub fn new() -> Self {
+        Self::default()
+    }
+
+    /// Create a simple SELECT with columns
+    pub fn columns(columns: Vec<SelectColumn>) -> Self {
+        Self {
+            columns,
+            ..Default::default()
+        }
+    }
+
+    pub fn with_from(mut self, from: FromClause) -> Self {
+        self.from = Some(from);
+        self
+    }
+
+    pub fn with_where(mut self, expr: Expr) -> Self {
+        self.where_clause = Some(expr);
+        self
+    }
+
+    pub fn with_ctes(mut self, ctes: Vec<Cte>) -> Self {
+        self.ctes = ctes;
+        self
+    }
+
+    pub fn with_order_by(mut self, order_by: Vec<OrderByExpr>) -> Self {
+        self.order_by = order_by;
+        self
+    }
+
+    pub fn with_limit(mut self, limit: u64) -> Self {
+        self.limit = Some(limit);
+        self
+    }
+
+    pub fn with_offset(mut self, offset: u64) -> Self {
+        self.offset = Some(offset);
+        self
+    }
+
+    pub fn with_group_by(mut self, group_by: Vec<Expr>) -> Self {
+        self.group_by = group_by;
+        self
+    }
+}
+
+/// A column in a SELECT clause
+#[derive(Debug, Clone, PartialEq)]
+pub enum SelectColumn {
+    /// An expression with optional alias: expr AS alias
+    Expr { expr: Expr, alias: Option<Ident> },
+    /// All columns: *
+    Star,
+    /// All columns from a table: table.*
+    QualifiedStar { table: Ident },
+}
+
+impl SelectColumn {
+    /// Create an expression column without alias
+    pub fn expr(expr: Expr) -> Self {
+        Self::Expr { expr, alias: None }
+    }
+
+    /// Create an expression column with alias
+    pub fn expr_as(expr: Expr, alias: impl Into<Ident>) -> Self {
+        Self::Expr {
+            expr,
+            alias: Some(alias.into()),
+        }
+    }
+
+    /// Create a star (SELECT *)
+    pub fn star() -> Self {
+        Self::Star
+    }
+
+    /// Create a qualified star (SELECT table.*)
+    pub fn qualified_star(table: impl Into<Ident>) -> Self {
+        Self::QualifiedStar {
+            table: table.into(),
+        }
+    }
+}
+
+/// FROM clause
+#[derive(Debug, Clone, PartialEq)]
+pub enum FromClause {
+    /// Simple table reference
+    Table {
+        schema: Option<Ident>,
+        name: Ident,
+        alias: Option<Ident>,
+    },
+    /// Subquery
+    Subquery {
+        query: Box<SelectStmt>,
+        alias: Ident,
+    },
+    /// Function call as table source
+    Function { call: FunctionCall, alias: Ident },
+    /// JOIN clause
+    Join {
+        left: Box<FromClause>,
+        join_type: JoinType,
+        right: Box<FromClause>,
+        on: Option<Expr>,
+    },
+    /// CROSS JOIN
+    CrossJoin {
+        left: Box<FromClause>,
+        right: Box<FromClause>,
+    },
+    /// LATERAL subquery
+    Lateral {
+        subquery: Box<SelectStmt>,
+        alias: Ident,
+    },
+}
+
+impl FromClause {
+    /// Create a simple table reference
+    pub fn table(name: impl Into<Ident>) -> Self {
+        Self::Table {
+            schema: None,
+            name: name.into(),
+            alias: None,
+        }
+    }
+
+    /// Create a schema-qualified table reference
+    pub fn qualified_table(schema: impl Into<Ident>, name: impl Into<Ident>) -> Self {
+        Self::Table {
+            schema: Some(schema.into()),
+            name: name.into(),
+            alias: None,
+        }
+    }
+
+    /// Create a table reference with alias
+    pub fn table_alias(
+        schema: Option<impl Into<Ident>>,
+        name: impl Into<Ident>,
+        alias: impl Into<Ident>,
+    ) -> Self {
+        Self::Table {
+            schema: schema.map(|s| s.into()),
+            name: name.into(),
+            alias: Some(alias.into()),
+        }
+    }
+
+    /// Create a subquery source
+    pub fn subquery(query: SelectStmt, alias: impl Into<Ident>) -> Self {
+        Self::Subquery {
+            query: Box::new(query),
+            alias: alias.into(),
+        }
+    }
+
+    /// Create a function call source
+    pub fn function(call: FunctionCall, alias: impl Into<Ident>) -> Self {
+        Self::Function {
+            call,
+            alias: alias.into(),
+        }
+    }
+
+    /// Add an alias to this FROM clause
+    pub fn with_alias(self, alias: impl Into<Ident>) -> Self {
+        match self {
+            Self::Table { schema, name, .. } => Self::Table {
+                schema,
+                name,
+                alias: Some(alias.into()),
+            },
+            _ => self,
+        }
+    }
+
+    /// Create a LEFT JOIN
+    pub fn left_join(self, right: FromClause, on: Expr) -> Self {
+        Self::Join {
+            left: Box::new(self),
+            join_type: JoinType::Left,
+            right: Box::new(right),
+            on: Some(on),
+        }
+    }
+
+    /// Create an INNER JOIN
+    pub fn inner_join(self, right: FromClause, on: Expr) -> Self {
+        Self::Join {
+            left: Box::new(self),
+            join_type: JoinType::Inner,
+            right: Box::new(right),
+            on: Some(on),
+        }
+    }
+
+    /// Create a CROSS JOIN
+    pub fn cross_join(self, right: FromClause) -> Self {
+        Self::CrossJoin {
+            left: Box::new(self),
+            right: Box::new(right),
+        }
+    }
+}
+
+/// JOIN type
+#[derive(Debug, Clone, Copy, PartialEq, Eq)]
+pub enum JoinType {
+    Inner,
+    Left,
+    Right,
+    Full,
+}
+
+impl JoinType {
+    pub fn as_sql(&self) -> &'static str {
+        match self {
+            Self::Inner => "inner join",
+            Self::Left => "left join",
+            Self::Right => "right join",
+            Self::Full => "full join",
+        }
+    }
+}
+
+/// INSERT statement
+#[derive(Debug, Clone, PartialEq)]
+pub struct InsertStmt {
+    /// WITH clause (CTEs)
+    pub ctes: Vec<Cte>,
+    /// Target schema
+    pub schema: Option<Ident>,
+    /// Target table
+    pub table: Ident,
+    /// Target columns
+    pub columns: Vec<Ident>,
+    /// Values to insert
+    pub values: InsertValues,
+    /// RETURNING clause
+    pub returning: Vec<SelectColumn>,
+    /// ON CONFLICT clause (for upserts)
+    pub on_conflict: Option<OnConflict>,
+}
+
+impl InsertStmt {
+    pub fn new(table: impl Into<Ident>, columns: Vec<Ident>, values: InsertValues) -> Self {
+        Self {
+            ctes: vec![],
+            schema: None,
+            table: table.into(),
+            columns,
+            values,
+            returning: vec![],
+            on_conflict: None,
+        }
+    }
+
+    pub fn with_schema(mut self, schema: impl Into<Ident>) -> Self {
+        self.schema = Some(schema.into());
+        self
+    }
+
+    pub fn with_returning(mut self, returning: Vec<SelectColumn>) -> Self {
+        self.returning = returning;
+        self
+    }
+
+    pub fn with_ctes(mut self, ctes: Vec<Cte>) -> Self {
+        self.ctes = ctes;
+        self
+    }
+}
+
+/// Values for INSERT statement
+#[derive(Debug, Clone, PartialEq)]
+pub enum InsertValues {
+    /// VALUES (row1), (row2), ...
+    Values(Vec<Vec<Expr>>),
+    /// INSERT ... SELECT ...
+    Query(Box<SelectStmt>),
+    /// DEFAULT VALUES
+    DefaultValues,
+}
+
+/// ON CONFLICT clause for upserts
+#[derive(Debug, Clone, PartialEq)]
+pub struct OnConflict {
+    pub target: OnConflictTarget,
+    pub action: OnConflictAction,
+}
+
+/// Target for ON CONFLICT
+#[derive(Debug, Clone, PartialEq)]
+pub enum OnConflictTarget {
+    /// ON CONFLICT (column1, column2)
+    Columns(Vec<Ident>),
+    /// ON CONFLICT ON CONSTRAINT constraint_name
+    Constraint(Ident),
+}
+
+/// Action for ON CONFLICT
+#[derive(Debug, Clone, PartialEq)]
+pub enum OnConflictAction {
+    /// DO NOTHING
+    DoNothing,
+    /// DO UPDATE SET ...
+    DoUpdate {
+        set: Vec<(Ident, Expr)>,
+        where_clause: Option<Expr>,
+    },
+}
+
+/// UPDATE statement
+#[derive(Debug, Clone, PartialEq)]
+pub struct UpdateStmt {
+    /// WITH clause (CTEs)
+    pub ctes: Vec<Cte>,
+    /// Target schema
+    pub schema: Option<Ident>,
+    /// Target table
+    pub table: Ident,
+    /// Table alias
+    pub alias: Option<Ident>,
+    /// SET clause: column = value pairs
+    pub set: Vec<(Ident, Expr)>,
+    /// FROM clause (for UPDATE ... FROM ...)
+    pub from: Option<FromClause>,
+    /// WHERE clause
+    pub where_clause: Option<Expr>,
+    /// RETURNING clause
+    pub returning: Vec<SelectColumn>,
+}
+
+impl UpdateStmt {
+    pub fn new(table: impl Into<Ident>, set: Vec<(Ident, Expr)>) -> Self {
+        Self {
+            ctes: vec![],
+            schema: None,
+            table: table.into(),
+            alias: None,
+            set,
+            from: None,
+            where_clause: None,
+            returning: vec![],
+        }
+    }
+
+    pub fn with_schema(mut self, schema: impl Into<Ident>) -> Self {
+        self.schema = Some(schema.into());
+        self
+    }
+
+    pub fn with_alias(mut self, alias: impl Into<Ident>) -> Self {
+        self.alias = Some(alias.into());
+        self
+    }
+
+    pub fn with_where(mut self, expr: Expr) -> Self {
+        self.where_clause = Some(expr);
+        self
+    }
+
+    pub fn with_returning(mut self, returning: Vec<SelectColumn>) -> Self {
+        self.returning = returning;
+        self
+    }
+
+    pub fn with_ctes(mut self, ctes: Vec<Cte>) -> Self {
+        self.ctes = ctes;
+        self
+    }
+}
+
+/// DELETE statement
+#[derive(Debug, Clone, PartialEq)]
+pub struct DeleteStmt {
+    /// WITH clause (CTEs)
+    pub ctes: Vec<Cte>,
+    /// Target schema
+    pub schema: Option<Ident>,
+    /// Target table
+    pub table: Ident,
+    /// Table alias
+    pub alias: Option<Ident>,
+    /// USING clause (for DELETE ... USING ...)
+    pub using: Option<FromClause>,
+    /// WHERE clause
+    pub where_clause: Option<Expr>,
+    /// RETURNING clause
+    pub returning: Vec<SelectColumn>,
+}
+
+impl DeleteStmt {
+    pub fn new(table: impl Into<Ident>) -> Self {
+        Self {
+            ctes: vec![],
+            schema: None,
+            table: table.into(),
+            alias: None,
+            using: None,
+            where_clause: None,
+            returning: vec![],
+        }
+    }
+
+    pub fn with_schema(mut self, schema: impl Into<Ident>) -> Self {
+        self.schema = Some(schema.into());
+        self
+    }
+
+    pub fn with_alias(mut self, alias: impl Into<Ident>) -> Self {
+        self.alias = Some(alias.into());
+        self
+    }
+
+    pub fn with_where(mut self, expr: Expr) -> Self {
+        self.where_clause = Some(expr);
+        self
+    }
+
+    pub fn with_returning(mut self, returning: Vec<SelectColumn>) -> Self {
+        self.returning = returning;
+        self
+    }
+
+    pub fn with_ctes(mut self, ctes: Vec<Cte>) -> Self {
+        self.ctes = ctes;
+        self
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn test_select_builder() {
+        let stmt = SelectStmt::columns(vec![SelectColumn::star()])
+            .with_from(FromClause::table("users"))
+            .with_limit(10);
+
+        assert!(matches!(stmt.from, Some(FromClause::Table { ..
}))); + assert_eq!(stmt.limit, Some(10)); + } + + #[test] + fn test_from_clause_join() { + let from = FromClause::table("users").with_alias("u").left_join( + FromClause::table("orders").with_alias("o"), + Expr::qualified_column("u", "id").eq(Expr::qualified_column("o", "user_id")), + ); + + assert!(matches!( + from, + FromClause::Join { + join_type: JoinType::Left, + .. + } + )); + } + + #[test] + fn test_insert_stmt() { + let stmt = InsertStmt::new( + "users", + vec![Ident::new("name"), Ident::new("email")], + InsertValues::Values(vec![vec![ + Expr::string("Alice"), + Expr::string("alice@example.com"), + ]]), + ) + .with_schema("public") + .with_returning(vec![SelectColumn::star()]); + + assert_eq!(stmt.table.0, "users"); + assert_eq!(stmt.columns.len(), 2); + } +} diff --git a/src/ast/tests.rs b/src/ast/tests.rs new file mode 100644 index 00000000..d625de9f --- /dev/null +++ b/src/ast/tests.rs @@ -0,0 +1,622 @@ +//! Comprehensive tests for the AST module +//! +//! These tests verify that the AST correctly represents SQL constructs +//! and renders them properly. + +use super::*; + +mod expr_tests { + use super::*; + + #[test] + fn test_column_expressions() { + // Simple column + let col = Expr::column("id"); + let sql = render_expr(&col); + assert_eq!(sql, "\"id\""); + + // Qualified column + let col = Expr::qualified_column("users", "email"); + let sql = render_expr(&col); + assert_eq!(sql, "\"users\".\"email\""); + } + + #[test] + fn test_literal_expressions() { + assert_eq!(render_expr(&Expr::null()), "null"); + assert_eq!(render_expr(&Expr::bool(true)), "true"); + assert_eq!(render_expr(&Expr::bool(false)), "false"); + assert_eq!(render_expr(&Expr::int(42)), "42"); + assert_eq!(render_expr(&Expr::int(-100)), "-100"); + assert_eq!(render_expr(&Expr::string("hello")), "'hello'"); + } + + #[test] + fn test_binary_operations() { + // Equality + let expr = Expr::column("id").eq(Expr::int(1)); + assert!(render_expr(&expr).contains("= 1")); + + // Comparison + let expr = Expr::binary(Expr::column("age"), BinaryOperator::GtEq, Expr::int(18)); + assert!(render_expr(&expr).contains(">= 18")); + + // Logical AND + let expr = Expr::column("a") + .eq(Expr::int(1)) + .and(Expr::column("b").eq(Expr::int(2))); + let sql = render_expr(&expr); + assert!(sql.contains("and")); + + // Logical OR + let expr = Expr::column("a") + .eq(Expr::int(1)) + .or(Expr::column("b").eq(Expr::int(2))); + let sql = render_expr(&expr); + assert!(sql.contains("or")); + } + + #[test] + fn test_is_null() { + let expr = Expr::is_null(Expr::column("deleted_at")); + assert!(render_expr(&expr).contains("is null")); + + let expr = Expr::is_not_null(Expr::column("deleted_at")); + assert!(render_expr(&expr).contains("is not null")); + } + + #[test] + fn test_in_list() { + let expr = Expr::InList { + expr: Box::new(Expr::column("status")), + list: vec![Expr::string("active"), Expr::string("pending")], + negated: false, + }; + let sql = render_expr(&expr); + assert!(sql.contains("in (")); + assert!(sql.contains("'active'")); + assert!(sql.contains("'pending'")); + + let expr = Expr::InList { + expr: Box::new(Expr::column("status")), + list: vec![Expr::string("deleted")], + negated: true, + }; + let sql = render_expr(&expr); + assert!(sql.contains("not in (")); + } + + #[test] + fn test_between() { + let expr = Expr::Between { + expr: Box::new(Expr::column("age")), + low: Box::new(Expr::int(18)), + high: Box::new(Expr::int(65)), + negated: false, + }; + let sql = render_expr(&expr); + assert!(sql.contains("between")); + 
assert!(sql.contains("18")); + assert!(sql.contains("65")); + } + + #[test] + fn test_type_cast() { + let expr = Expr::cast(Expr::column("id"), SqlType::text()); + let sql = render_expr(&expr); + assert!(sql.contains("::text")); + } + + #[test] + fn test_function_call() { + let expr = Expr::function("lower", vec![Expr::column("name")]); + let sql = render_expr(&expr); + assert!(sql.contains("\"lower\"(")); + } + + #[test] + fn test_coalesce() { + let expr = Expr::coalesce(vec![ + Expr::column("nickname"), + Expr::column("name"), + Expr::string("Unknown"), + ]); + let sql = render_expr(&expr); + assert!(sql.contains("coalesce(")); + } + + #[test] + fn test_case_expression() { + let case = CaseExpr::searched( + vec![ + (Expr::column("x").eq(Expr::int(1)), Expr::string("one")), + (Expr::column("x").eq(Expr::int(2)), Expr::string("two")), + ], + Some(Expr::string("other")), + ); + let expr = Expr::Case(case); + let sql = render_expr(&expr); + assert!(sql.contains("case")); + assert!(sql.contains("when")); + assert!(sql.contains("then")); + assert!(sql.contains("else")); + assert!(sql.contains("end")); + } + + #[test] + fn test_jsonb_build() { + // Object + let expr = Expr::jsonb_build_object(vec![(Expr::string("key"), Expr::string("value"))]); + let sql = render_expr(&expr); + assert!(sql.contains("jsonb_build_object(")); + + // Array + let expr = Expr::jsonb_build_array(vec![Expr::int(1), Expr::int(2)]); + let sql = render_expr(&expr); + assert!(sql.contains("jsonb_build_array(")); + } + + #[test] + fn test_aggregate_expressions() { + // COUNT(*) + let agg = AggregateExpr::count_all(); + let expr = Expr::Aggregate(agg); + assert_eq!(render_expr(&expr), "count(*)"); + + // SUM with column + let agg = AggregateExpr::new(AggregateFunction::Sum, vec![Expr::column("amount")]); + let expr = Expr::Aggregate(agg); + assert!(render_expr(&expr).contains("sum(")); + + // COUNT with DISTINCT + let agg = AggregateExpr::new(AggregateFunction::Count, vec![Expr::column("user_id")]) + .with_distinct(); + let expr = Expr::Aggregate(agg); + let sql = render_expr(&expr); + assert!(sql.contains("distinct")); + + // Aggregate with FILTER + let agg = AggregateExpr::new(AggregateFunction::Count, vec![]) + .with_filter(Expr::column("active").eq(Expr::bool(true))); + let expr = Expr::Aggregate(agg); + let sql = render_expr(&expr); + assert!(sql.contains("filter (where")); + } + + #[test] + fn test_array_operators() { + // Contains + let expr = Expr::binary( + Expr::column("tags"), + BinaryOperator::Contains, + Expr::Array(vec![Expr::string("rust")]), + ); + let sql = render_expr(&expr); + assert!(sql.contains("@>")); + + // Overlap + let expr = Expr::binary( + Expr::column("tags"), + BinaryOperator::Overlap, + Expr::column("other_tags"), + ); + let sql = render_expr(&expr); + assert!(sql.contains("&&")); + } +} + +mod stmt_tests { + use super::*; + + #[test] + fn test_simple_select() { + let stmt = + SelectStmt::columns(vec![SelectColumn::star()]).with_from(FromClause::table("users")); + let sql = render(&Stmt::Select(stmt)); + assert!(sql.contains("select *")); + assert!(sql.contains("from \"users\"")); + } + + #[test] + fn test_select_with_join() { + let stmt = SelectStmt::columns(vec![ + SelectColumn::expr(Expr::qualified_column("u", "name")), + SelectColumn::expr(Expr::qualified_column("o", "total")), + ]) + .with_from(FromClause::table("users").with_alias("u").left_join( + FromClause::table("orders").with_alias("o"), + Expr::qualified_column("u", "id").eq(Expr::qualified_column("o", "user_id")), + )); + let sql = 
render(&Stmt::Select(stmt)); + assert!(sql.contains("left join")); + assert!(sql.contains("on")); + } + + #[test] + fn test_select_with_subquery() { + let inner = SelectStmt::columns(vec![SelectColumn::expr(Expr::column("user_id"))]) + .with_from(FromClause::table("orders")); + let outer = SelectStmt::columns(vec![SelectColumn::star()]) + .with_from(FromClause::table("users")) + .with_where(Expr::InList { + expr: Box::new(Expr::column("id")), + list: vec![Expr::Subquery(Box::new(inner))], + negated: false, + }); + let sql = render(&Stmt::Select(outer)); + assert!(sql.contains("in (")); + assert!(sql.contains("select")); + } + + #[test] + fn test_select_with_order_limit_offset() { + let stmt = SelectStmt::columns(vec![SelectColumn::star()]) + .with_from(FromClause::table("users")) + .with_order_by(vec![ + OrderByExpr::desc(Expr::column("created_at")).with_nulls(NullsOrder::Last) + ]) + .with_limit(10) + .with_offset(20); + let sql = render(&Stmt::Select(stmt)); + assert!(sql.contains("order by")); + assert!(sql.contains("desc")); + assert!(sql.contains("nulls last")); + assert!(sql.contains("limit 10")); + assert!(sql.contains("offset 20")); + } + + #[test] + fn test_insert_single_row() { + let stmt = InsertStmt::new( + "users", + vec![Ident::new("name"), Ident::new("email")], + InsertValues::Values(vec![vec![ + Expr::string("Alice"), + Expr::string("alice@example.com"), + ]]), + ); + let sql = render(&Stmt::Insert(stmt)); + assert!(sql.contains("insert into")); + assert!(sql.contains("values")); + } + + #[test] + fn test_insert_multiple_rows() { + let stmt = InsertStmt::new( + "users", + vec![Ident::new("name")], + InsertValues::Values(vec![vec![Expr::string("Alice")], vec![Expr::string("Bob")]]), + ); + let sql = render(&Stmt::Insert(stmt)); + assert!(sql.contains("'Alice'")); + assert!(sql.contains("'Bob'")); + } + + #[test] + fn test_insert_with_returning() { + let stmt = InsertStmt::new( + "users", + vec![Ident::new("name")], + InsertValues::Values(vec![vec![Expr::string("Alice")]]), + ) + .with_returning(vec![SelectColumn::star()]); + let sql = render(&Stmt::Insert(stmt)); + assert!(sql.contains("returning *")); + } + + #[test] + fn test_insert_default_values() { + let stmt = InsertStmt::new("users", vec![], InsertValues::DefaultValues); + let sql = render(&Stmt::Insert(stmt)); + assert!(sql.contains("default values")); + } + + #[test] + fn test_update_basic() { + let stmt = UpdateStmt::new( + "users", + vec![ + (Ident::new("name"), Expr::string("Bob")), + (Ident::new("updated_at"), Expr::raw("now()")), + ], + ) + .with_where(Expr::column("id").eq(Expr::int(1))); + let sql = render(&Stmt::Update(stmt)); + assert!(sql.contains("update")); + assert!(sql.contains("set")); + assert!(sql.contains("where")); + } + + #[test] + fn test_delete_basic() { + let stmt = + DeleteStmt::new("users").with_where(Expr::column("active").eq(Expr::bool(false))); + let sql = render(&Stmt::Delete(stmt)); + assert!(sql.contains("delete from")); + assert!(sql.contains("where")); + } +} + +mod cte_tests { + use super::*; + + #[test] + fn test_simple_cte() { + let inner = SelectStmt::columns(vec![SelectColumn::star()]) + .with_from(FromClause::table("users")) + .with_where(Expr::column("active").eq(Expr::bool(true))); + + let outer = SelectStmt::columns(vec![SelectColumn::star()]) + .with_from(FromClause::table("active_users")) + .with_ctes(vec![Cte::select("active_users", inner)]); + + let sql = render(&Stmt::Select(outer)); + assert!(sql.contains("with")); + assert!(sql.contains("\"active_users\"")); + 
assert!(sql.contains("as (")); + } + + #[test] + fn test_cte_with_columns() { + let inner = SelectStmt::columns(vec![ + SelectColumn::expr(Expr::column("id")), + SelectColumn::expr(Expr::column("name")), + ]) + .with_from(FromClause::table("users")); + + let cte = Cte::select("user_info", inner).with_columns(vec!["user_id", "user_name"]); + + let outer = SelectStmt::columns(vec![SelectColumn::star()]) + .with_from(FromClause::table("user_info")) + .with_ctes(vec![cte]); + + let sql = render(&Stmt::Select(outer)); + assert!(sql.contains("\"user_info\"(\"user_id\", \"user_name\")")); + } + + #[test] + fn test_data_modifying_cte() { + let insert = InsertStmt::new( + "users", + vec![Ident::new("name")], + InsertValues::Values(vec![vec![Expr::string("New User")]]), + ) + .with_returning(vec![SelectColumn::star()]); + + let outer = SelectStmt::columns(vec![SelectColumn::star()]) + .with_from(FromClause::table("new_user")) + .with_ctes(vec![Cte::insert("new_user", insert)]); + + let sql = render(&Stmt::Select(outer)); + assert!(sql.contains("with")); + assert!(sql.contains("insert into")); + assert!(sql.contains("returning")); + } + + #[test] + fn test_multiple_ctes() { + let cte1 = Cte::select( + "a", + SelectStmt::columns(vec![SelectColumn::expr(Expr::int(1))]), + ); + let cte2 = Cte::select( + "b", + SelectStmt::columns(vec![SelectColumn::expr(Expr::int(2))]), + ); + + let outer = SelectStmt::columns(vec![SelectColumn::star()]) + .with_from(FromClause::table("a").cross_join(FromClause::table("b"))) + .with_ctes(vec![cte1, cte2]); + + let sql = render(&Stmt::Select(outer)); + assert!(sql.contains("with")); + assert!(sql.contains("\"a\"")); + assert!(sql.contains("\"b\"")); + } +} + +mod params_tests { + use super::*; + + #[test] + fn test_param_collection() { + let mut collector = ParamCollector::new(); + + let _e1 = collector.add(ParamValue::String("test".into()), SqlType::text()); + let _e2 = collector.add(ParamValue::Integer(42), SqlType::integer()); + let _e3 = collector.add(ParamValue::Null, SqlType::text()); + + assert_eq!(collector.len(), 3); + assert!(!collector.is_empty()); + + let params = collector.into_params(); + assert_eq!(params[0].index, 1); + assert_eq!(params[1].index, 2); + assert_eq!(params[2].index, 3); + } + + #[test] + fn test_param_in_query() { + let mut collector = ParamCollector::new(); + let param = collector.add( + ParamValue::String("alice@example.com".into()), + SqlType::text(), + ); + + let stmt = SelectStmt::columns(vec![SelectColumn::star()]) + .with_from(FromClause::table("users")) + .with_where(Expr::column("email").eq(param)); + + let sql = render(&Stmt::Select(stmt)); + assert!(sql.contains("($1::text)")); + } + + #[test] + fn test_array_param() { + let mut collector = ParamCollector::new(); + let param = collector.add( + ParamValue::Array(vec![ + ParamValue::Integer(1), + ParamValue::Integer(2), + ParamValue::Integer(3), + ]), + SqlType::integer().into_array(), + ); + + let stmt = SelectStmt::columns(vec![SelectColumn::star()]) + .with_from(FromClause::table("users")) + .with_where(Expr::binary(Expr::column("id"), BinaryOperator::Any, param)); + + let sql = render(&Stmt::Select(stmt)); + assert!(sql.contains("($1::integer[])")); + } +} + +mod type_tests { + use super::*; + + #[test] + fn test_builtin_types() { + assert_eq!(SqlType::text().to_sql_string(), "text"); + assert_eq!(SqlType::integer().to_sql_string(), "integer"); + assert_eq!(SqlType::bigint().to_sql_string(), "bigint"); + assert_eq!(SqlType::boolean().to_sql_string(), "boolean"); + 
assert_eq!(SqlType::jsonb().to_sql_string(), "jsonb"); + assert_eq!(SqlType::uuid().to_sql_string(), "uuid"); + } + + #[test] + fn test_array_types() { + assert_eq!(SqlType::text().into_array().to_sql_string(), "text[]"); + assert_eq!(SqlType::integer().as_array().to_sql_string(), "integer[]"); + } + + #[test] + fn test_custom_types() { + let t = SqlType::with_schema("public", "my_enum"); + assert_eq!(t.to_sql_string(), "public.my_enum"); + + let t = SqlType::from_name("public.my_type[]"); + assert_eq!(t.schema, Some("public".to_string())); + assert_eq!(t.name, "my_type"); + assert!(t.is_array); + } +} + +mod integration_tests { + use super::*; + + /// Test a realistic mutation pattern used by pg_graphql + #[test] + fn test_insert_mutation_pattern() { + // This mirrors the INSERT pattern in transpile.rs + let mut params = ParamCollector::new(); + + // Simulate inserting a user + let email_param = params.add( + ParamValue::String("user@example.com".into()), + SqlType::text(), + ); + let name_param = params.add(ParamValue::String("Test User".into()), SqlType::text()); + + // Build the INSERT CTE + let insert = InsertStmt::new( + "account", + vec![Ident::new("email"), Ident::new("name")], + InsertValues::Values(vec![vec![email_param, name_param]]), + ) + .with_schema("public") + .with_returning(vec![ + SelectColumn::expr(Expr::column("id")), + SelectColumn::expr(Expr::column("email")), + SelectColumn::expr(Expr::column("name")), + ]); + + // Build the outer SELECT with jsonb_build_object + let select = SelectStmt::columns(vec![SelectColumn::expr(Expr::jsonb_build_object(vec![ + ( + Expr::string("affectedCount"), + Expr::Aggregate(AggregateExpr::count_all()), + ), + ( + Expr::string("records"), + Expr::coalesce(vec![ + Expr::Aggregate(AggregateExpr::new( + AggregateFunction::JsonbAgg, + vec![Expr::jsonb_build_object(vec![ + (Expr::string("id"), Expr::column("id")), + (Expr::string("email"), Expr::column("email")), + ])], + )), + Expr::raw("jsonb_build_array()"), + ]), + ), + ]))]) + .with_from(FromClause::table("affected").with_alias("affected")) + .with_ctes(vec![Cte::insert("affected", insert)]); + + let sql = render(&Stmt::Select(select)); + + // Verify the structure + assert!(sql.contains("with")); + assert!(sql.contains("insert into")); + assert!(sql.contains("returning")); + assert!(sql.contains("jsonb_build_object")); + assert!(sql.contains("jsonb_agg")); + assert!(sql.contains("($1::text)")); + assert!(sql.contains("($2::text)")); + } + + /// Test the update mutation pattern with at_most check + #[test] + fn test_update_mutation_pattern() { + let mut params = ParamCollector::new(); + + let new_name = params.add(ParamValue::String("Updated".into()), SqlType::text()); + let filter_id = params.add(ParamValue::Integer(1), SqlType::integer()); + + // Build UPDATE CTE + let update = UpdateStmt::new("users", vec![(Ident::new("name"), new_name)]) + .with_schema("public") + .with_alias("t") + .with_where(Expr::qualified_column("t", "id").eq(filter_id)) + .with_returning(vec![SelectColumn::star()]); + + // Total count CTE + let total_count = SelectStmt::columns(vec![SelectColumn::expr_as( + Expr::Aggregate(AggregateExpr::count_all()), + "total_count", + )]) + .with_from(FromClause::table("impacted")); + + // Main select with safety check + let safety_check = Expr::Case(CaseExpr::searched( + vec![( + Expr::binary( + Expr::column("total_count"), + BinaryOperator::Gt, + Expr::int(1), // at_most = 1 + ), + Expr::raw("graphql.exception($a$update impacts too many records$a$)::jsonb"), + )], + 
Some(Expr::jsonb_build_object(vec![(
+                Expr::string("affectedCount"),
+                Expr::column("total_count"),
+            )])),
+        ));
+
+        let main_select = SelectStmt::columns(vec![SelectColumn::expr(safety_check)])
+            .with_from(FromClause::table("total"))
+            .with_ctes(vec![
+                Cte::update("impacted", update),
+                Cte::select("total", total_count).with_columns(vec!["total_count"]),
+            ]);
+
+        let sql = render(&Stmt::Select(main_select));
+
+        assert!(sql.contains("with"));
+        assert!(sql.contains("update"));
+        assert!(sql.contains("case"));
+        assert!(sql.contains("graphql.exception"));
+    }
+}
diff --git a/src/ast/transpile_connection.rs b/src/ast/transpile_connection.rs
new file mode 100644
index 00000000..9cfcba7c
--- /dev/null
+++ b/src/ast/transpile_connection.rs
@@ -0,0 +1,2118 @@
+//! AST-based transpilation for ConnectionBuilder
+//!
+//! This module implements the ToAst trait for ConnectionBuilder, converting it
+//! to a type-safe AST that can be rendered to SQL. ConnectionBuilder generates
+//! complex pagination queries with multiple CTEs for:
+//! - Record fetching with pagination
+//! - Total count
+//! - Has next/previous page detection
+//! - Aggregate computations
+
+use super::{
+    add_param_from_json, build_filter_expr, build_node_object_expr, coalesce, column_ref,
+    count_star, empty_jsonb_object, func_call, jsonb_agg_with_order_and_filter, jsonb_build_object,
+    string_literal, AstBuildContext, ToAst,
+};
+use crate::ast::{
+    BinaryOperator, ColumnRef, Cte, CteQuery, Expr, FromClause, Ident, Literal, NullsOrder,
+    OrderByExpr, OrderDirection, ParamCollector, SelectColumn, SelectStmt, Stmt,
+};
+use crate::builder::{
+    AggregateBuilder, AggregateSelection, ConnectionBuilder, ConnectionSelection, Cursor,
+    EdgeBuilder, EdgeSelection, OrderByBuilder, PageInfoBuilder, PageInfoSelection,
+};
+use crate::error::{GraphQLError, GraphQLResult};
+use crate::sql_types::Table;
+
+/// The result of transpiling a ConnectionBuilder to AST
+pub struct ConnectionAst {
+    /// The complete SQL statement
+    pub stmt: Stmt,
+}
+
+impl ToAst for ConnectionBuilder {
+    type Ast = ConnectionAst;
+
+    fn to_ast(&self, params: &mut ParamCollector) -> GraphQLResult<Self::Ast> {
+        let ctx = AstBuildContext::new();
+        let block_name = ctx.block_name.clone();
+
+        // Build the __records CTE - main data fetch with pagination
+        let records_cte = build_records_cte(self, &block_name, params)?;
+
+        // Build the __total_count CTE
+        let total_count_cte = build_total_count_cte(self, &block_name, params)?;
+
+        // Build __has_next_page and __has_previous_page CTEs
+        let (has_next_cte, has_prev_cte) = build_pagination_ctes(self, &block_name, params)?;
+
+        // Build the __has_records CTE
+        let has_records_cte = build_has_records_cte();
+
+        // Check if aggregates are requested and build the aggregate CTE
+        let aggregate_builder = self.selections.iter().find_map(|sel| match sel {
+            ConnectionSelection::Aggregate(builder) => Some(builder),
+            _ => None,
+        });
+        let aggregates_cte = build_aggregates_cte(self, &block_name, aggregate_builder, params)?;
+
+        // Build the main selection object (excluding aggregates - they're handled separately)
+        let object_columns = build_connection_object(self, &block_name, params)?;
+
+        // Build the __base_object CTE that combines everything
+        let base_object_cte = build_base_object_cte(&object_columns, &block_name);
+
+        // Build the final SELECT expression with aggregate merge if needed
+        let final_expr = if let Some(agg_builder) = aggregate_builder {
+            // Merge: coalesce(__base_object.obj, '{}'::jsonb)
+            //     || jsonb_build_object(agg_alias, coalesce(__aggregates.agg_result, '{}'::jsonb))
+            Expr::BinaryOp {
+                left: Box::new(coalesce(vec![
+                    column_ref("__base_object", "obj"),
+                    empty_jsonb_object(),
+                ])),
+                op: BinaryOperator::JsonConcat,
+                right: Box::new(jsonb_build_object(vec![(
+                    agg_builder.alias.clone(),
+                    coalesce(vec![
+                        column_ref("__aggregates", "agg_result"),
+                        empty_jsonb_object(),
+                    ]),
+                )])),
+            }
+        } else {
+            coalesce(vec![
+                column_ref("__base_object", "obj"),
+                empty_jsonb_object(),
+            ])
+        };
+
+        // Combine all CTEs
+        let stmt = Stmt::Select(SelectStmt {
+            ctes: vec![
+                records_cte,
+                total_count_cte,
+                has_next_cte,
+                has_prev_cte,
+                has_records_cte,
+                aggregates_cte,
+                base_object_cte,
+            ],
+            columns: vec![SelectColumn::expr(final_expr)],
+            from: Some(FromClause::Join {
+                left: Box::new(FromClause::Join {
+                    left: Box::new(FromClause::Subquery {
+                        query: Box::new(SelectStmt {
+                            ctes: vec![],
+                            columns: vec![SelectColumn::expr(Expr::Literal(Literal::Integer(1)))],
+                            from: None,
+                            where_clause: None,
+                            group_by: vec![],
+                            having: None,
+                            order_by: vec![],
+                            limit: None,
+                            offset: None,
+                        }),
+                        alias: Ident::new("__dummy_for_left_join"),
+                    }),
+                    join_type: super::JoinType::Left,
+                    right: Box::new(FromClause::Table {
+                        schema: None,
+                        name: Ident::new("__base_object"),
+                        alias: None,
+                    }),
+                    on: Some(Expr::Literal(Literal::Bool(true))),
+                }),
+                join_type: super::JoinType::Inner,
+                right: Box::new(FromClause::Table {
+                    schema: None,
+                    name: Ident::new("__aggregates"),
+                    alias: None,
+                }),
+                on: Some(Expr::Literal(Literal::Bool(true))),
+            }),
+            where_clause: None,
+            group_by: vec![],
+            having: None,
+            order_by: vec![],
+            limit: None,
+            offset: None,
+        });
+
+        Ok(ConnectionAst { stmt })
+    }
+}
+
+/// Build the __records CTE that fetches the actual data
+fn build_records_cte(
+    conn: &ConnectionBuilder,
+    block_name: &str,
+    params: &mut ParamCollector,
+) -> GraphQLResult<Cte> {
+    // Build filter expression
+    let filter_clause = build_filter_expr(&conn.filter, &conn.source.table, block_name, params)?;
+
+    // Get cursor (before or after)
+    let cursor = conn.before.as_ref().or(conn.after.as_ref());
+
+    // Determine if this is reverse pagination (using 'before' or 'last')
+    let is_reverse = conn.before.is_some() || (conn.last.is_some() && conn.first.is_none());
+
+    // Get order by (reversed if using reverse pagination)
+    let order_by_builder = if is_reverse {
+        conn.order_by.reverse()
+    } else {
+        conn.order_by.clone()
+    };
+
+    // Build ORDER BY expressions
+    let order_by = build_order_by_exprs(&order_by_builder, block_name);
+
+    // Build cursor pagination clause if cursor exists
+    let pagination_clause = if let Some(cursor) = cursor {
+        build_cursor_pagination_clause(
+            &conn.source.table,
+            &order_by_builder,
+            cursor,
+            block_name,
+            params,
+            false, // Don't include cursor's own record
+        )?
+    } else {
+        None
+    };
+
+    // Combine filter and pagination clauses
+    let where_clause = match (filter_clause, pagination_clause) {
+        (Some(f), Some(p)) => Some(Expr::BinaryOp {
+            left: Box::new(f),
+            op: BinaryOperator::And,
+            right: Box::new(p),
+        }),
+        (Some(f), None) => Some(f),
+        (None, Some(p)) => Some(p),
+        (None, None) => None,
+    };
+
+    // Calculate limit
+    let limit = conn
+        .first
+        .or(conn.last)
+        .map(|l| std::cmp::min(l, conn.max_rows))
+        .unwrap_or(conn.max_rows);
+
+    let offset = conn.offset.unwrap_or(0);
+
+    // Select all selectable columns
+    let columns: Vec<SelectColumn> = conn
+        .source
+        .table
+        .columns
+        .iter()
+        .filter(|c| c.permissions.is_selectable)
+        .map(|c| SelectColumn::expr(Expr::Column(ColumnRef::new(c.name.clone()))))
+        .collect();
+
+    let select = SelectStmt {
+        ctes: vec![],
+        columns,
+        from: Some(FromClause::Table {
+            schema: Some(Ident::new(conn.source.table.schema.clone())),
+            name: Ident::new(conn.source.table.name.clone()),
+            alias: Some(Ident::new(block_name)),
+        }),
+        where_clause,
+        group_by: vec![],
+        having: None,
+        order_by,
+        limit: Some(limit),
+        offset: if offset > 0 { Some(offset) } else { None },
+    };
+
+    Ok(Cte {
+        name: Ident::new("__records"),
+        columns: None,
+        query: CteQuery::Select(select),
+        materialized: None,
+    })
+}
+
+/// Build cursor pagination clause
+///
+/// Generates a WHERE clause expression for cursor-based pagination using the
+/// after or before cursor. The clause filters records based on the cursor's
+/// position in the ordered result set.
+fn build_cursor_pagination_clause(
+    table: &Table,
+    order_by: &OrderByBuilder,
+    cursor: &Cursor,
+    block_name: &str,
+    params: &mut ParamCollector,
+    allow_equality: bool,
+) -> GraphQLResult<Option<Expr>> {
+    if cursor.elems.is_empty() {
+        return Ok(if allow_equality {
+            Some(Expr::Literal(Literal::Bool(true)))
+        } else {
+            None
+        });
+    }
+
+    // Build the recursive pagination clause
+    build_cursor_pagination_clause_recursive(
+        table,
+        order_by,
+        cursor,
+        block_name,
+        params,
+        allow_equality,
+        0,
+    )
+}
+
+/// Recursive helper for building cursor pagination clause
+fn build_cursor_pagination_clause_recursive(
+    table: &Table,
+    order_by: &OrderByBuilder,
+    cursor: &Cursor,
+    block_name: &str,
+    params: &mut ParamCollector,
+    allow_equality: bool,
+    depth: usize,
+) -> GraphQLResult<Option<Expr>> {
+    // Check if cursor has more elements than order_by columns
+    if depth < cursor.elems.len() && depth >= order_by.elems.len() {
+        return Err(GraphQLError::validation(
+            "orderBy clause incompatible with pagination cursor",
+        ));
+    }
+
+    if depth >= cursor.elems.len() {
+        return Ok(if allow_equality {
+            Some(Expr::Literal(Literal::Bool(true)))
+        } else {
+            None
+        });
+    }
+
+    let cursor_elem = &cursor.elems[depth];
+    let order_elem = &order_by.elems[depth];
+    let column = &order_elem.column;
+
+    // Find the column type for parameter casting
+    let col_expr = column_ref(block_name, &column.name);
+    let val_expr = add_param_from_json(params, &cursor_elem.value, &column.type_name)?;
+
+    // Determine comparison operator based on sort direction
+    let op = if order_elem.direction.is_asc() {
+        BinaryOperator::Gt
+    } else {
+        BinaryOperator::Lt
+    };
+
+    let nulls_first = order_elem.direction.nulls_first();
+
+    // Build the cursor comparison expression for proper null handling.
+    //
+    // For ASC with NULLS FIRST, the sort order is: NULL, 1, 2, 3, ...
+    // For ASC with NULLS LAST, the sort order is: 1, 2, 3, ..., NULL
+    // For DESC with NULLS FIRST, the sort order is: NULL, 3, 2, 1, ...
+ // For DESC with NULLS LAST, the sort order is: 3, 2, 1, ..., NULL + // + // The ">" comparison (for ASC) or "<" comparison (for DESC) needs to account for: + // 1. Standard comparison when both are non-null: col > val (or col < val) + // 2. When cursor val IS NULL and col IS NOT NULL: + // - With NULLS FIRST: non-null values come AFTER null, so col > cursor (include) + // - With NULLS LAST: non-null values come BEFORE null, so col < cursor (exclude) + // 3. When cursor val IS NOT NULL and col IS NULL: + // - With NULLS FIRST: null values come BEFORE non-null, so col < cursor (exclude for >) + // - With NULLS LAST: null values come AFTER non-null, so col > cursor (include for >) + // + // The correct expression is: + // (col > val) + // OR (col IS NOT NULL AND val IS NULL AND nulls_first) -- case 2: non-null > null when nulls_first + // OR (col IS NULL AND val IS NOT NULL AND NOT nulls_first) -- case 3: null > non-null when nulls_last + let main_compare = Expr::BinaryOp { + left: Box::new(col_expr.clone()), + op, + right: Box::new(val_expr.clone()), + }; + + // Case 2: (col IS NOT NULL AND val IS NULL AND nulls_first) + // When nulls come first, non-null values are "greater than" null values + let null_case_2 = Expr::BinaryOp { + left: Box::new(Expr::BinaryOp { + left: Box::new(Expr::IsNull { + expr: Box::new(col_expr.clone()), + negated: true, // IS NOT NULL + }), + op: BinaryOperator::And, + right: Box::new(Expr::IsNull { + expr: Box::new(val_expr.clone()), + negated: false, // IS NULL + }), + }), + op: BinaryOperator::And, + right: Box::new(Expr::Literal(Literal::Bool(nulls_first))), + }; + + // Case 3: (col IS NULL AND val IS NOT NULL AND NOT nulls_first) + // When nulls come last, null values are "greater than" non-null values + let null_case_3 = Expr::BinaryOp { + left: Box::new(Expr::BinaryOp { + left: Box::new(Expr::IsNull { + expr: Box::new(col_expr.clone()), + negated: false, // IS NULL + }), + op: BinaryOperator::And, + right: Box::new(Expr::IsNull { + expr: Box::new(val_expr.clone()), + negated: true, // IS NOT NULL + }), + }), + op: BinaryOperator::And, + right: Box::new(Expr::Literal(Literal::Bool(!nulls_first))), + }; + + // Combine: main_compare OR null_case_2 OR null_case_3 + let first_condition = Expr::Nested(Box::new(Expr::BinaryOp { + left: Box::new(Expr::BinaryOp { + left: Box::new(main_compare), + op: BinaryOperator::Or, + right: Box::new(null_case_2), + }), + op: BinaryOperator::Or, + right: Box::new(null_case_3), + })); + + // Build equality check for recursion: (col = val OR (col IS NULL AND val IS NULL)) + let equality_check = Expr::Nested(Box::new(Expr::BinaryOp { + left: Box::new(Expr::BinaryOp { + left: Box::new(col_expr.clone()), + op: BinaryOperator::Eq, + right: Box::new(val_expr.clone()), + }), + op: BinaryOperator::Or, + right: Box::new(Expr::BinaryOp { + left: Box::new(Expr::IsNull { + expr: Box::new(col_expr), + negated: false, + }), + op: BinaryOperator::And, + right: Box::new(Expr::IsNull { + expr: Box::new(val_expr), + negated: false, + }), + }), + })); + + // Recurse to next level + let recurse = build_cursor_pagination_clause_recursive( + table, + order_by, + cursor, + block_name, + params, + allow_equality, + depth + 1, + )?; + + // Build: first_condition OR (equality_check AND recurse) + // Wrap in Nested for proper grouping + let result = match recurse { + Some(recurse_expr) => Expr::Nested(Box::new(Expr::BinaryOp { + left: Box::new(first_condition), + op: BinaryOperator::Or, + right: Box::new(Expr::Nested(Box::new(Expr::BinaryOp { + left: 
Box::new(equality_check),
+                op: BinaryOperator::And,
+                right: Box::new(recurse_expr),
+            }))),
+        })),
+        None => first_condition,
+    };
+
+    Ok(Some(result))
+}
+
+/// Build the __total_count CTE
+fn build_total_count_cte(
+    conn: &ConnectionBuilder,
+    block_name: &str,
+    params: &mut ParamCollector,
+) -> GraphQLResult<Cte> {
+    let where_clause = build_filter_expr(&conn.filter, &conn.source.table, block_name, params)?;
+
+    let select = SelectStmt {
+        ctes: vec![],
+        columns: vec![SelectColumn::expr(count_star())],
+        from: Some(FromClause::Table {
+            schema: Some(Ident::new(conn.source.table.schema.clone())),
+            name: Ident::new(conn.source.table.name.clone()),
+            alias: Some(Ident::new(block_name)),
+        }),
+        where_clause,
+        group_by: vec![],
+        having: None,
+        order_by: vec![],
+        limit: None,
+        offset: None,
+    };
+
+    Ok(Cte {
+        name: Ident::new("__total_count"),
+        columns: Some(vec![Ident::new("___total_count")]),
+        query: CteQuery::Select(select),
+        materialized: None,
+    })
+}
+
+/// Build pagination detection CTEs
+fn build_pagination_ctes(
+    conn: &ConnectionBuilder,
+    block_name: &str,
+    params: &mut ParamCollector,
+) -> GraphQLResult<(Cte, Cte)> {
+    let limit = conn
+        .first
+        .or(conn.last)
+        .map(|l| std::cmp::min(l, conn.max_rows))
+        .unwrap_or(conn.max_rows);
+
+    let offset = conn.offset.unwrap_or(0);
+
+    // Determine if this is reverse pagination
+    let is_reverse = conn.before.is_some() || (conn.last.is_some() && conn.first.is_none());
+
+    // Get order by (reversed if using reverse pagination)
+    let order_by_builder = if is_reverse {
+        conn.order_by.reverse()
+    } else {
+        conn.order_by.clone()
+    };
+
+    let order_by = build_order_by_exprs(&order_by_builder, block_name);
+
+    // Get cursor (before or after)
+    let cursor = conn.before.as_ref().or(conn.after.as_ref());
+
+    // Build filter expression
+    let filter_clause = build_filter_expr(&conn.filter, &conn.source.table, block_name, params)?;
+
+    // Build cursor pagination clause if cursor exists
+    let pagination_clause = if let Some(cursor) = cursor {
+        build_cursor_pagination_clause(
+            &conn.source.table,
+            &order_by_builder,
+            cursor,
+            block_name,
+            params,
+            false,
+        )?
+ } else { + None + }; + + // Combine filter and pagination clauses + let where_clause = match (filter_clause.clone(), pagination_clause) { + (Some(f), Some(p)) => Some(Expr::BinaryOp { + left: Box::new(f), + op: BinaryOperator::And, + right: Box::new(p), + }), + (Some(f), None) => Some(f), + (None, Some(p)) => Some(p), + (None, None) => None, + }; + + // __has_next_page: select count(*) > limit from (limited subquery) + let page_plus_1_select = SelectStmt { + ctes: vec![], + columns: vec![SelectColumn::expr(Expr::Literal(Literal::Integer(1)))], + from: Some(FromClause::Table { + schema: Some(Ident::new(conn.source.table.schema.clone())), + name: Ident::new(conn.source.table.name.clone()), + alias: Some(Ident::new(block_name)), + }), + where_clause: where_clause.clone(), + group_by: vec![], + having: None, + order_by: order_by.clone(), + limit: Some(limit + 1), + offset: if offset > 0 { Some(offset) } else { None }, + }; + + // __has_previous_page: check if there's a record that would have come before __records + // We check if the first record (by original order) is NOT in __records + // If offset > 0, then by definition there's a previous page + let has_prev_select = if offset > 0 { + // Simple case: offset > 0 means there's definitely a previous page + SelectStmt { + ctes: vec![], + columns: vec![SelectColumn::expr(Expr::Literal(Literal::Bool(true)))], + from: None, + where_clause: None, + group_by: vec![], + having: None, + order_by: vec![], + limit: None, + offset: None, + } + } else { + // Complex case: check if first record (by order) is NOT in __records + // This handles cursor-based pagination correctly + // + // Query structure: + // with page_minus_1 as ( + // select not (pk_tuple = any(__records.seen)) is_pkey_in_records + // from table + // left join (select array_agg(pk_tuple) from __records) __records(seen) on true + // where filter_clause + // order by order_by_clause + // limit 1 + // ) + // select coalesce(bool_and(is_pkey_in_records), false) from page_minus_1 + + // Build pk tuple expression: (pk_col1, pk_col2, ...) 
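+        // A single-column key is compared as the bare column; a multi-column
+        // key is compared as a ROW expression over its columns (see below).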
+ let pk_columns = conn.source.table.primary_key_columns(); + let pk_tuple_from_table: Expr = if pk_columns.len() == 1 { + column_ref(block_name, &pk_columns[0].name) + } else { + // For multi-column pk, create a ROW expression + Expr::Row( + pk_columns + .iter() + .map(|c| column_ref(block_name, &c.name)) + .collect(), + ) + }; + + let pk_tuple_from_records: Expr = if pk_columns.len() == 1 { + column_ref("__records", &pk_columns[0].name) + } else { + Expr::Row( + pk_columns + .iter() + .map(|c| column_ref("__records", &c.name)) + .collect(), + ) + }; + + // Build: select array_agg(pk_tuple) as seen from __records + let seen_subquery = SelectStmt { + ctes: vec![], + columns: vec![SelectColumn::Expr { + expr: Expr::Aggregate(super::AggregateExpr::new( + super::AggregateFunction::ArrayAgg, + vec![pk_tuple_from_records], + )), + alias: Some(Ident::new("seen")), + }], + from: Some(FromClause::Table { + schema: None, + name: Ident::new("__records"), + alias: None, + }), + where_clause: None, + group_by: vec![], + having: None, + order_by: vec![], + limit: None, + offset: None, + }; + + // Build: not (pk_tuple = any(__seen_records.seen)) + let is_pkey_in_records = Expr::UnaryOp { + op: super::UnaryOperator::Not, + expr: Box::new(Expr::BinaryOp { + left: Box::new(pk_tuple_from_table), + op: BinaryOperator::Any, + right: Box::new(column_ref("__seen_records", "seen")), + }), + }; + + // Build page_minus_1 CTE + let page_minus_1_select = SelectStmt { + ctes: vec![], + columns: vec![SelectColumn::Expr { + expr: is_pkey_in_records, + alias: Some(Ident::new("is_pkey_in_records")), + }], + from: Some(FromClause::Join { + left: Box::new(FromClause::Table { + schema: Some(Ident::new(conn.source.table.schema.clone())), + name: Ident::new(conn.source.table.name.clone()), + alias: Some(Ident::new(block_name)), + }), + join_type: super::JoinType::Left, + right: Box::new(FromClause::Subquery { + query: Box::new(seen_subquery), + alias: Ident::new("__seen_records"), + }), + on: Some(Expr::Literal(Literal::Bool(true))), + }), + where_clause: filter_clause.clone(), + group_by: vec![], + having: None, + order_by: order_by.clone(), + limit: Some(1), + offset: None, + }; + + // Build final select: coalesce(bool_and(is_pkey_in_records), false) + SelectStmt { + ctes: vec![Cte { + name: Ident::new("page_minus_1"), + columns: None, + query: CteQuery::Select(page_minus_1_select), + materialized: None, + }], + columns: vec![SelectColumn::expr(coalesce(vec![ + Expr::Aggregate(super::AggregateExpr::new( + super::AggregateFunction::BoolAnd, + vec![column_ref("page_minus_1", "is_pkey_in_records")], + )), + Expr::Literal(Literal::Bool(false)), + ]))], + from: Some(FromClause::Table { + schema: None, + name: Ident::new("page_minus_1"), + alias: None, + }), + where_clause: None, + group_by: vec![], + having: None, + order_by: vec![], + limit: None, + offset: None, + } + }; + + // For reverse pagination (before/last), swap hasNextPage and hasPreviousPage queries + // This is because when paginating backwards, "next" in the query direction + // is actually "previous" in the logical ordering + let (next_query, prev_query) = if is_reverse { + ( + CteQuery::Select(has_prev_select), + CteQuery::Select(SelectStmt { + ctes: vec![Cte { + name: Ident::new("page_plus_1"), + columns: None, + query: CteQuery::Select(page_plus_1_select), + materialized: None, + }], + columns: vec![SelectColumn::expr(Expr::BinaryOp { + left: Box::new(count_star()), + op: BinaryOperator::Gt, + right: Box::new(Expr::Literal(Literal::Integer(limit as i64))), + 
})], + from: Some(FromClause::Table { + schema: None, + name: Ident::new("page_plus_1"), + alias: None, + }), + where_clause: None, + group_by: vec![], + having: None, + order_by: vec![], + limit: None, + offset: None, + }), + ) + } else { + ( + CteQuery::Select(SelectStmt { + ctes: vec![Cte { + name: Ident::new("page_plus_1"), + columns: None, + query: CteQuery::Select(page_plus_1_select), + materialized: None, + }], + columns: vec![SelectColumn::expr(Expr::BinaryOp { + left: Box::new(count_star()), + op: BinaryOperator::Gt, + right: Box::new(Expr::Literal(Literal::Integer(limit as i64))), + })], + from: Some(FromClause::Table { + schema: None, + name: Ident::new("page_plus_1"), + alias: None, + }), + where_clause: None, + group_by: vec![], + having: None, + order_by: vec![], + limit: None, + offset: None, + }), + CteQuery::Select(has_prev_select), + ) + }; + + let has_next_cte = Cte { + name: Ident::new("__has_next_page"), + columns: Some(vec![Ident::new("___has_next_page")]), + query: next_query, + materialized: None, + }; + + let has_prev_cte = Cte { + name: Ident::new("__has_previous_page"), + columns: Some(vec![Ident::new("___has_previous_page")]), + query: prev_query, + materialized: None, + }; + + Ok((has_next_cte, has_prev_cte)) +} + +/// Build the __has_records CTE +fn build_has_records_cte() -> Cte { + Cte { + name: Ident::new("__has_records"), + columns: Some(vec![Ident::new("has_records")]), + query: CteQuery::Select(SelectStmt { + ctes: vec![], + columns: vec![SelectColumn::expr(Expr::Exists { + subquery: Box::new(SelectStmt { + ctes: vec![], + columns: vec![SelectColumn::expr(Expr::Literal(Literal::Integer(1)))], + from: Some(FromClause::Table { + schema: None, + name: Ident::new("__records"), + alias: None, + }), + where_clause: None, + group_by: vec![], + having: None, + order_by: vec![], + limit: None, + offset: None, + }), + negated: false, + })], + from: None, + where_clause: None, + group_by: vec![], + having: None, + order_by: vec![], + limit: None, + offset: None, + }), + materialized: None, + } +} + +/// Build the __aggregates CTE for aggregate computations +fn build_aggregates_cte( + conn: &ConnectionBuilder, + block_name: &str, + agg_builder: Option<&AggregateBuilder>, + params: &mut ParamCollector, +) -> GraphQLResult { + let Some(agg_builder) = agg_builder else { + // No aggregates requested - return a dummy CTE with null + return Ok(Cte { + name: Ident::new("__aggregates"), + columns: Some(vec![Ident::new("agg_result")]), + query: CteQuery::Select(SelectStmt { + ctes: vec![], + columns: vec![SelectColumn::expr(Expr::Cast { + expr: Box::new(Expr::Literal(Literal::Null)), + target_type: super::type_name_to_sql_type("jsonb"), + })], + from: None, + where_clause: None, + group_by: vec![], + having: None, + order_by: vec![], + limit: None, + offset: None, + }), + materialized: None, + }); + }; + + // Build the aggregate select list + let agg_pairs = build_aggregate_select_list(agg_builder, block_name); + + // Build WHERE clause (same filter as records) + let where_clause = build_filter_expr(&conn.filter, &conn.source.table, block_name, params)?; + + Ok(Cte { + name: Ident::new("__aggregates"), + columns: Some(vec![Ident::new("agg_result")]), + query: CteQuery::Select(SelectStmt { + ctes: vec![], + columns: vec![SelectColumn::expr(jsonb_build_object(agg_pairs))], + from: Some(FromClause::Table { + schema: Some(Ident::new(conn.source.table.schema.clone())), + name: Ident::new(conn.source.table.name.clone()), + alias: Some(Ident::new(block_name)), + }), + where_clause, 
+ group_by: vec![], + having: None, + order_by: vec![], + limit: None, + offset: None, + }), + materialized: None, + }) +} + +/// Build the aggregate select list (count, sum, avg, min, max) +fn build_aggregate_select_list( + agg_builder: &AggregateBuilder, + block_name: &str, +) -> Vec<(String, Expr)> { + let mut pairs = Vec::new(); + + for selection in &agg_builder.selections { + match selection { + AggregateSelection::Count { alias } => { + pairs.push((alias.clone(), count_star())); + } + AggregateSelection::Sum { + alias, + column_builders, + } => { + let field_pairs = + build_aggregate_field_pairs(column_builders, block_name, "sum", false); + pairs.push((alias.clone(), jsonb_build_object(field_pairs))); + } + AggregateSelection::Avg { + alias, + column_builders, + } => { + // AVG needs numeric cast for precision + let field_pairs = + build_aggregate_field_pairs(column_builders, block_name, "avg", true); + pairs.push((alias.clone(), jsonb_build_object(field_pairs))); + } + AggregateSelection::Min { + alias, + column_builders, + } => { + let field_pairs = + build_aggregate_field_pairs(column_builders, block_name, "min", false); + pairs.push((alias.clone(), jsonb_build_object(field_pairs))); + } + AggregateSelection::Max { + alias, + column_builders, + } => { + let field_pairs = + build_aggregate_field_pairs(column_builders, block_name, "max", false); + pairs.push((alias.clone(), jsonb_build_object(field_pairs))); + } + AggregateSelection::Typename { alias, typename } => { + pairs.push((alias.clone(), string_literal(typename))); + } + } + } + + pairs +} + +/// Build field pairs for aggregate functions (sum/avg/min/max) +fn build_aggregate_field_pairs( + column_builders: &[crate::builder::ColumnBuilder], + block_name: &str, + func_name: &str, + cast_to_numeric: bool, +) -> Vec<(String, Expr)> { + column_builders + .iter() + .map(|col_builder| { + let col_expr = column_ref(block_name, &col_builder.column.name); + let arg_expr = if cast_to_numeric { + Expr::Cast { + expr: Box::new(col_expr), + target_type: super::type_name_to_sql_type("numeric"), + } + } else { + col_expr + }; + let agg_expr = func_call(func_name, vec![arg_expr]); + (col_builder.alias.clone(), agg_expr) + }) + .collect() +} + +/// Build the connection object columns from selections +fn build_connection_object( + conn: &ConnectionBuilder, + block_name: &str, + params: &mut ParamCollector, +) -> GraphQLResult> { + let mut pairs = Vec::new(); + + for selection in &conn.selections { + match selection { + ConnectionSelection::TotalCount { alias } => { + pairs.push((alias.clone(), column_ref("__total_count", "___total_count"))); + } + ConnectionSelection::Edge(edges_builder) => { + let edges_expr = build_edges_expr( + edges_builder, + block_name, + &conn.order_by, + &conn.source.table, + params, + )?; + pairs.push((edges_builder.alias.clone(), edges_expr)); + } + ConnectionSelection::PageInfo(page_info_builder) => { + let page_info_expr = + build_page_info_expr(page_info_builder, block_name, &conn.order_by)?; + pairs.push((page_info_builder.alias.clone(), page_info_expr)); + } + ConnectionSelection::Typename { alias, typename } => { + pairs.push((alias.clone(), string_literal(typename))); + } + ConnectionSelection::Aggregate(_) => { + // Aggregate handling would be added here + // For now, skip + } + } + } + + Ok(pairs) +} + +/// Build edges expression +fn build_edges_expr( + edges: &EdgeBuilder, + block_name: &str, + order_by: &crate::builder::OrderByBuilder, + table: &Table, + params: &mut ParamCollector, +) -> GraphQLResult { + 
+    let mut edge_pairs = Vec::new();
+
+    for selection in &edges.selections {
+        match selection {
+            EdgeSelection::Cursor { alias } => {
+                // Build cursor: translate(encode(convert_to(jsonb_build_array(to_jsonb(col1), to_jsonb(col2), ...)::text, 'utf-8'), 'base64'), E'\n', '')
+                let cursor_expr = build_cursor_expr(block_name, order_by);
+                edge_pairs.push((alias.clone(), cursor_expr));
+            }
+            EdgeSelection::Node(node_builder) => {
+                let node_expr =
+                    build_node_object_expr(&node_builder.selections, block_name, params)?;
+                edge_pairs.push((node_builder.alias.clone(), node_expr));
+            }
+            EdgeSelection::Typename { alias, typename } => {
+                edge_pairs.push((alias.clone(), string_literal(typename)));
+            }
+        }
+    }
+
+    // Build: coalesce(jsonb_agg(jsonb_build_object(...) order by ... ) filter (where pk is not null), '[]')
+    // The filter clause excludes null rows from LEFT JOIN when no matching records exist
+    // The ORDER BY re-reverses results that were fetched in reverse order for backward pagination
+    let edge_object = jsonb_build_object(edge_pairs);
+
+    // Build ORDER BY expressions for jsonb_agg using NORMAL order (not reversed)
+    // This re-sorts results that were fetched in reverse order for backward pagination
+    let order_by_exprs: Vec<OrderByExpr> = order_by
+        .elems
+        .iter()
+        .map(|elem| {
+            // Convert the combined OrderDirection enum to separate direction and nulls
+            let (direction, nulls) = match elem.direction {
+                crate::builder::OrderDirection::AscNullsFirst => {
+                    (OrderDirection::Asc, NullsOrder::First)
+                }
+                crate::builder::OrderDirection::AscNullsLast => {
+                    (OrderDirection::Asc, NullsOrder::Last)
+                }
+                crate::builder::OrderDirection::DescNullsFirst => {
+                    (OrderDirection::Desc, NullsOrder::First)
+                }
+                crate::builder::OrderDirection::DescNullsLast => {
+                    (OrderDirection::Desc, NullsOrder::Last)
+                }
+            };
+            OrderByExpr {
+                expr: column_ref(block_name, &elem.column.name),
+                direction: Some(direction),
+                nulls: Some(nulls),
+            }
+        })
+        .collect();
+
+    // Get the first primary key column to use in the filter
+    let pk_columns = table.primary_key_columns();
+    let filter_expr = pk_columns.first().map(|pk_col| {
+        // Build: block_name.pk_col is not null
+        Expr::IsNull {
+            expr: Box::new(column_ref(block_name, &pk_col.name)),
+            negated: true, // IS NOT NULL
+        }
+    });
+
+    // Build jsonb_agg with order by and optional filter
+    let agg_expr = jsonb_agg_with_order_and_filter(edge_object, order_by_exprs, filter_expr);
+
+    // Wrap with CASE WHEN for safety: return empty [] when no records exist
+    // This handles edge cases where LEFT JOIN produces NULL rows
+    let has_records_ref = column_ref("__has_records", "has_records");
+    let case_expr = Expr::Case(super::CaseExpr {
+        operand: None,
+        when_clauses: vec![(has_records_ref, coalesce(vec![agg_expr, super::empty_jsonb_array()]))],
+        else_clause: Some(Box::new(super::empty_jsonb_array())),
+    });
+
+    Ok(case_expr)
+}
+
+/// Build cursor expression from order_by columns
+/// Format: translate(encode(convert_to(jsonb_build_array(to_jsonb(col1), ...)::text, 'utf-8'), 'base64'), E'\n', '')
+fn build_cursor_expr(block_name: &str, order_by: &crate::builder::OrderByBuilder) -> Expr {
+    // Build to_jsonb(block_name.column) for each order_by column
+    let jsonb_cols: Vec<Expr> = order_by
+        .elems
+        .iter()
+        .map(|elem| func_call("to_jsonb", vec![column_ref(block_name, &elem.column.name)]))
+        .collect();
+
+    // Build jsonb_build_array(...)
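+    // e.g. with order_by (id) and a row where id = 1:
+    //   jsonb_build_array(to_jsonb(a.id))::text -> '[1]' -> base64 'WzFd'
+    // (illustrative; 'a' stands for the generated block alias)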
+    let jsonb_array = func_call("jsonb_build_array", jsonb_cols);
+
+    // Cast to text
+    let as_text = Expr::Cast {
+        expr: Box::new(jsonb_array),
+        target_type: super::type_name_to_sql_type("text"),
+    };
+
+    // convert_to(..., 'utf-8')
+    let converted = func_call("convert_to", vec![as_text, string_literal("utf-8")]);
+
+    // encode(..., 'base64')
+    let encoded = func_call("encode", vec![converted, string_literal("base64")]);
+
+    // translate(..., E'\n', '')
+    func_call(
+        "translate",
+        vec![encoded, string_literal("\n"), string_literal("")],
+    )
+}
+
+/// Build page info expression
+fn build_page_info_expr(
+    page_info: &PageInfoBuilder,
+    block_name: &str,
+    order_by: &crate::builder::OrderByBuilder,
+) -> GraphQLResult<Expr> {
+    let mut pairs = Vec::new();
+
+    // Build cursor expression for start/end cursor
+    let cursor_expr = build_cursor_expr(block_name, order_by);
+
+    // Build forward order by expressions (for startCursor)
+    let forward_order_by = build_order_by_exprs(order_by, block_name);
+
+    // Build reversed order by expressions (for endCursor)
+    let reversed_order_by: Vec<OrderByExpr> = order_by
+        .elems
+        .iter()
+        .map(|elem| {
+            let (direction, nulls) = match elem.direction {
+                crate::builder::OrderDirection::AscNullsFirst => {
+                    (Some(OrderDirection::Desc), Some(NullsOrder::Last))
+                }
+                crate::builder::OrderDirection::AscNullsLast => {
+                    (Some(OrderDirection::Desc), Some(NullsOrder::First))
+                }
+                crate::builder::OrderDirection::DescNullsFirst => {
+                    (Some(OrderDirection::Asc), Some(NullsOrder::Last))
+                }
+                crate::builder::OrderDirection::DescNullsLast => {
+                    (Some(OrderDirection::Asc), Some(NullsOrder::First))
+                }
+            };
+            OrderByExpr {
+                expr: column_ref(block_name, &elem.column.name),
+                direction,
+                nulls,
+            }
+        })
+        .collect();
+
+    for selection in &page_info.selections {
+        match selection {
+            PageInfoSelection::HasNextPage { alias } => {
+                pairs.push((
+                    alias.clone(),
+                    coalesce(vec![
+                        func_call(
+                            "bool_and",
+                            vec![column_ref("__has_next_page", "___has_next_page")],
+                        ),
+                        Expr::Literal(Literal::Bool(false)),
+                    ]),
+                ));
+            }
+            PageInfoSelection::HasPreviousPage { alias } => {
+                pairs.push((
+                    alias.clone(),
+                    coalesce(vec![
+                        func_call(
+                            "bool_and",
+                            vec![column_ref("__has_previous_page", "___has_previous_page")],
+                        ),
+                        Expr::Literal(Literal::Bool(false)),
+                    ]),
+                ));
+            }
+            PageInfoSelection::StartCursor { alias } => {
+                // case when __has_records.has_records then (array_agg(cursor order by order_by))[1] else null end
+                let array_agg_expr = Expr::ArrayIndex {
+                    array: Box::new(Expr::FunctionCallWithOrderBy {
+                        name: "array_agg".to_string(),
+                        args: vec![cursor_expr.clone()],
+                        order_by: forward_order_by.clone(),
+                    }),
+                    index: Box::new(Expr::Literal(Literal::Integer(1))),
+                };
+                let case_expr = Expr::Case(crate::ast::CaseExpr::searched(
+                    vec![(column_ref("__has_records", "has_records"), array_agg_expr)],
+                    Some(Expr::Literal(Literal::Null)),
+                ));
+                pairs.push((alias.clone(), case_expr));
+            }
+            PageInfoSelection::EndCursor { alias } => {
+                // case when __has_records.has_records then (array_agg(cursor order by order_by_reversed))[1] else null end
+                let array_agg_expr = Expr::ArrayIndex {
+                    array: Box::new(Expr::FunctionCallWithOrderBy {
+                        name: "array_agg".to_string(),
+                        args: vec![cursor_expr.clone()],
+                        order_by: reversed_order_by.clone(),
+                    }),
+                    index: Box::new(Expr::Literal(Literal::Integer(1))),
+                };
+                let case_expr = Expr::Case(crate::ast::CaseExpr::searched(
+                    vec![(column_ref("__has_records", "has_records"), array_agg_expr)],
+                    Some(Expr::Literal(Literal::Null)),
+                ));
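+                // With the reversed ordering, the last row in forward order is
+                // aggregated first, so element [1] is the end cursor.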
+                pairs.push((alias.clone(), case_expr));
+            }
+            PageInfoSelection::Typename { alias, typename } => {
+                pairs.push((alias.clone(), string_literal(typename)));
+            }
+        }
+    }
+
+    Ok(jsonb_build_object(pairs))
+}
+
+/// Build the __base_object CTE
+fn build_base_object_cte(object_columns: &[(String, Expr)], block_name: &str) -> Cte {
+    let object_expr = jsonb_build_object(object_columns.to_vec());
+
+    // The key insight: __records is LEFT JOINed and aliased as block_name
+    // This allows object_clause expressions (which use block_name) to work correctly
+    Cte {
+        name: Ident::new("__base_object"),
+        columns: Some(vec![Ident::new("obj")]),
+        query: CteQuery::Select(SelectStmt {
+            ctes: vec![],
+            columns: vec![SelectColumn::Expr {
+                expr: object_expr,
+                alias: None,
+            }],
+            from: Some(FromClause::Join {
+                left: Box::new(FromClause::Join {
+                    left: Box::new(FromClause::Join {
+                        left: Box::new(FromClause::Join {
+                            left: Box::new(FromClause::Table {
+                                schema: None,
+                                name: Ident::new("__total_count"),
+                                alias: None,
+                            }),
+                            join_type: super::JoinType::Inner,
+                            right: Box::new(FromClause::Table {
+                                schema: None,
+                                name: Ident::new("__has_next_page"),
+                                alias: None,
+                            }),
+                            on: Some(Expr::Literal(Literal::Bool(true))),
+                        }),
+                        join_type: super::JoinType::Inner,
+                        right: Box::new(FromClause::Table {
+                            schema: None,
+                            name: Ident::new("__has_previous_page"),
+                            alias: None,
+                        }),
+                        on: Some(Expr::Literal(Literal::Bool(true))),
+                    }),
+                    join_type: super::JoinType::Inner,
+                    right: Box::new(FromClause::Table {
+                        schema: None,
+                        name: Ident::new("__has_records"),
+                        alias: None,
+                    }),
+                    on: Some(Expr::Literal(Literal::Bool(true))),
+                }),
+                // LEFT JOIN __records aliased as block_name
+                join_type: super::JoinType::Left,
+                right: Box::new(FromClause::Table {
+                    schema: None,
+                    name: Ident::new("__records"),
+                    alias: Some(Ident::new(block_name)),
+                }),
+                on: Some(Expr::Literal(Literal::Bool(true))),
+            }),
+            where_clause: None,
+            group_by: vec![
+                column_ref("__total_count", "___total_count"),
+                column_ref("__has_next_page", "___has_next_page"),
+                column_ref("__has_previous_page", "___has_previous_page"),
+                column_ref("__has_records", "has_records"),
+            ],
+            having: None,
+            order_by: vec![],
+            limit: None,
+            offset: None,
+        }),
+        materialized: None,
+    }
+}
+
+/// Build ORDER BY expressions from the order_by builder
+fn build_order_by_exprs(
+    order_by: &crate::builder::OrderByBuilder,
+    block_name: &str,
+) -> Vec<OrderByExpr> {
+    order_by
+        .elems
+        .iter()
+        .map(|elem| {
+            let (direction, nulls) = match elem.direction {
+                crate::builder::OrderDirection::AscNullsFirst => {
+                    (Some(OrderDirection::Asc), Some(NullsOrder::First))
+                }
+                crate::builder::OrderDirection::AscNullsLast => {
+                    (Some(OrderDirection::Asc), Some(NullsOrder::Last))
+                }
+                crate::builder::OrderDirection::DescNullsFirst => {
+                    (Some(OrderDirection::Desc), Some(NullsOrder::First))
+                }
+                crate::builder::OrderDirection::DescNullsLast => {
+                    (Some(OrderDirection::Desc), Some(NullsOrder::Last))
+                }
+            };
+            OrderByExpr {
+                expr: column_ref(block_name, &elem.column.name),
+                direction,
+                nulls,
+            }
+        })
+        .collect()
+}
+
+/// Build a connection query as a subquery expression (for nested connections inside nodes)
+///
+/// This is the public interface for building a connection when it's nested inside a node selection.
+/// It builds the same complex CTE-based query but returns it as a subquery expression.
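+///
+/// CTE layout (sketch): __records, __total_count, __has_next_page,
+/// __has_previous_page, __has_records, __aggregates, and a final
+/// __base_object that joins them into the response object.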
+pub fn build_connection_subquery(
+    conn: &ConnectionBuilder,
+    parent_block_name: &str,
+    params: &mut ParamCollector,
+) -> GraphQLResult<Expr> {
+    let ctx = AstBuildContext::new();
+    let block_name = ctx.block_name.clone();
+
+    // Build the join clause for the foreign key relationship
+    let join_clause = build_fkey_join_clause(conn, &block_name, parent_block_name);
+
+    // Build the __records CTE - main data fetch with pagination and FK join
+    let records_cte = build_records_cte_with_join(conn, &block_name, &join_clause, params)?;
+
+    // Build the __total_count CTE with FK join
+    let total_count_cte = build_total_count_cte_with_join(conn, &block_name, &join_clause, params)?;
+
+    // Build __has_next_page and __has_previous_page CTEs with FK join
+    let (has_next_cte, has_prev_cte) =
+        build_pagination_ctes_with_join(conn, &block_name, &join_clause, params)?;
+
+    // Build the __has_records CTE
+    let has_records_cte = build_has_records_cte();
+
+    // Check if aggregates are requested and build the aggregate CTE
+    let aggregate_builder = conn.selections.iter().find_map(|sel| match sel {
+        ConnectionSelection::Aggregate(builder) => Some(builder),
+        _ => None,
+    });
+    let aggregates_cte =
+        build_aggregates_cte_with_join(conn, &block_name, &join_clause, aggregate_builder, params)?;
+
+    // Build the main selection object (excluding aggregates - they're handled separately)
+    let object_columns = build_connection_object(conn, &block_name, params)?;
+
+    // Build the __base_object CTE that combines everything
+    let base_object_cte = build_base_object_cte(&object_columns, &block_name);
+
+    // Build the final SELECT expression with aggregate merge if needed
+    let final_expr = if let Some(agg_builder) = aggregate_builder {
+        Expr::BinaryOp {
+            left: Box::new(coalesce(vec![
+                column_ref("__base_object", "obj"),
+                empty_jsonb_object(),
+            ])),
+            op: BinaryOperator::JsonConcat,
+            right: Box::new(jsonb_build_object(vec![(
+                agg_builder.alias.clone(),
+                coalesce(vec![
+                    column_ref("__aggregates", "agg_result"),
+                    empty_jsonb_object(),
+                ]),
+            )])),
+        }
+    } else {
+        coalesce(vec![
+            column_ref("__base_object", "obj"),
+            empty_jsonb_object(),
+        ])
+    };
+
+    // Build the full select statement as a subquery
+    let select = SelectStmt {
+        ctes: vec![
+            records_cte,
+            total_count_cte,
+            has_next_cte,
+            has_prev_cte,
+            has_records_cte,
+            aggregates_cte,
+            base_object_cte,
+        ],
+        columns: vec![SelectColumn::expr(final_expr)],
+        from: Some(FromClause::Join {
+            left: Box::new(FromClause::Join {
+                left: Box::new(FromClause::Subquery {
+                    query: Box::new(SelectStmt {
+                        ctes: vec![],
+                        columns: vec![SelectColumn::expr(Expr::Literal(Literal::Integer(1)))],
+                        from: None,
+                        where_clause: None,
+                        group_by: vec![],
+                        having: None,
+                        order_by: vec![],
+                        limit: None,
+                        offset: None,
+                    }),
+                    alias: Ident::new("__dummy_for_left_join"),
+                }),
+                join_type: super::JoinType::Left,
+                right: Box::new(FromClause::Table {
+                    schema: None,
+                    name: Ident::new("__base_object"),
+                    alias: None,
+                }),
+                on: Some(Expr::Literal(Literal::Bool(true))),
+            }),
+            join_type: super::JoinType::Inner,
+            right: Box::new(FromClause::Table {
+                schema: None,
+                name: Ident::new("__aggregates"),
+                alias: None,
+            }),
+            on: Some(Expr::Literal(Literal::Bool(true))),
+        }),
+        where_clause: None,
+        group_by: vec![],
+        having: None,
+        order_by: vec![],
+        limit: None,
+        offset: None,
+    };
+
+    Ok(Expr::Subquery(Box::new(select)))
+}
+
+/// Build the join clause for a foreign key relationship
+fn build_fkey_join_clause(
+    conn: &ConnectionBuilder,
+    block_name: &str,
+    parent_block_name: &str,
+) -> Option<Expr> {
+    let fkey_reversible = conn.source.fkey.as_ref()?;
+    let fkey = &fkey_reversible.fkey;
+    let reverse = fkey_reversible.reverse_reference;
+
+    let (local_cols, parent_cols) = if reverse {
+        (
+            &fkey.local_table_meta.column_names,
+            &fkey.referenced_table_meta.column_names,
+        )
+    } else {
+        (
+            &fkey.referenced_table_meta.column_names,
+            &fkey.local_table_meta.column_names,
+        )
+    };
+
+    let mut conditions: Vec<Expr> = Vec::new();
+    for (local_col, parent_col) in local_cols.iter().zip(parent_cols.iter()) {
+        conditions.push(Expr::BinaryOp {
+            left: Box::new(column_ref(block_name, local_col)),
+            op: BinaryOperator::Eq,
+            right: Box::new(column_ref(parent_block_name, parent_col)),
+        });
+    }
+
+    if conditions.is_empty() {
+        None
+    } else {
+        Some(super::combine_with_and(conditions))
+    }
+}
+
+/// Build __records CTE with join clause for nested connections
+fn build_records_cte_with_join(
+    conn: &ConnectionBuilder,
+    block_name: &str,
+    join_clause: &Option<Expr>,
+    params: &mut ParamCollector,
+) -> GraphQLResult<Cte> {
+    let filter_clause = build_filter_expr(&conn.filter, &conn.source.table, block_name, params)?;
+    let order_by = build_order_by_exprs(&conn.order_by, block_name);
+
+    let limit = conn
+        .first
+        .or(conn.last)
+        .map(|l| std::cmp::min(l, conn.max_rows))
+        .unwrap_or(conn.max_rows);
+
+    let offset = conn.offset.unwrap_or(0);
+
+    // Combine join clause with filter clause
+    let where_clause = match (join_clause.clone(), filter_clause) {
+        (Some(join), Some(filter)) => Some(Expr::BinaryOp {
+            left: Box::new(join),
+            op: BinaryOperator::And,
+            right: Box::new(filter),
+        }),
+        (Some(join), None) => Some(join),
+        (None, Some(filter)) => Some(filter),
+        (None, None) => None,
+    };
+
+    let columns: Vec<SelectColumn> = conn
+        .source
+        .table
+        .columns
+        .iter()
+        .filter(|c| c.permissions.is_selectable)
+        .map(|c| SelectColumn::expr(Expr::Column(ColumnRef::new(c.name.clone()))))
+        .collect();
+
+    let select = SelectStmt {
+        ctes: vec![],
+        columns,
+        from: Some(FromClause::Table {
+            schema: Some(Ident::new(conn.source.table.schema.clone())),
+            name: Ident::new(conn.source.table.name.clone()),
+            alias: Some(Ident::new(block_name)),
+        }),
+        where_clause,
+        group_by: vec![],
+        having: None,
+        order_by,
+        limit: Some(limit),
+        offset: if offset > 0 { Some(offset) } else { None },
+    };
+
+    Ok(Cte {
+        name: Ident::new("__records"),
+        columns: None,
+        query: CteQuery::Select(select),
+        materialized: None,
+    })
+}
+
+/// Build __total_count CTE with join clause for nested connections
+fn build_total_count_cte_with_join(
+    conn: &ConnectionBuilder,
+    block_name: &str,
+    join_clause: &Option<Expr>,
+    params: &mut ParamCollector,
+) -> GraphQLResult<Cte> {
+    let filter_clause = build_filter_expr(&conn.filter, &conn.source.table, block_name, params)?;
+
+    let where_clause = match (join_clause.clone(), filter_clause) {
+        (Some(join), Some(filter)) => Some(Expr::BinaryOp {
+            left: Box::new(join),
+            op: BinaryOperator::And,
+            right: Box::new(filter),
+        }),
+        (Some(join), None) => Some(join),
+        (None, Some(filter)) => Some(filter),
+        (None, None) => None,
+    };
+
+    let select = SelectStmt {
+        ctes: vec![],
+        columns: vec![SelectColumn::expr(count_star())],
+        from: Some(FromClause::Table {
+            schema: Some(Ident::new(conn.source.table.schema.clone())),
+            name: Ident::new(conn.source.table.name.clone()),
+            alias: Some(Ident::new(block_name)),
+        }),
+        where_clause,
+        group_by: vec![],
+        having: None,
+        order_by: vec![],
+        limit: None,
+        offset: None,
+    };
+
+    Ok(Cte {
+        name: Ident::new("__total_count"),
+        columns: Some(vec![Ident::new("___total_count")]),
+        query: CteQuery::Select(select),
+        materialized: None,
+    })
+}
+
+/// Build pagination CTEs with join clause for nested connections
+fn build_pagination_ctes_with_join(
+    conn: &ConnectionBuilder,
+    block_name: &str,
+    join_clause: &Option<Expr>,
+    params: &mut ParamCollector,
+) -> GraphQLResult<(Cte, Cte)> {
+    let limit = conn
+        .first
+        .or(conn.last)
+        .map(|l| std::cmp::min(l, conn.max_rows))
+        .unwrap_or(conn.max_rows);
+
+    let offset = conn.offset.unwrap_or(0);
+
+    let filter_clause = build_filter_expr(&conn.filter, &conn.source.table, block_name, params)?;
+    let order_by = build_order_by_exprs(&conn.order_by, block_name);
+
+    let where_clause = match (join_clause.clone(), filter_clause) {
+        (Some(join), Some(filter)) => Some(Expr::BinaryOp {
+            left: Box::new(join),
+            op: BinaryOperator::And,
+            right: Box::new(filter),
+        }),
+        (Some(join), None) => Some(join),
+        (None, Some(filter)) => Some(filter),
+        (None, None) => None,
+    };
+
+    let page_plus_1_select = SelectStmt {
+        ctes: vec![],
+        columns: vec![SelectColumn::expr(Expr::Literal(Literal::Integer(1)))],
+        from: Some(FromClause::Table {
+            schema: Some(Ident::new(conn.source.table.schema.clone())),
+            name: Ident::new(conn.source.table.name.clone()),
+            alias: Some(Ident::new(block_name)),
+        }),
+        where_clause: where_clause.clone(),
+        group_by: vec![],
+        having: None,
+        order_by: order_by.clone(),
+        limit: Some(limit + 1),
+        offset: if offset > 0 { Some(offset) } else { None },
+    };
+
+    let has_next_cte = Cte {
+        name: Ident::new("__has_next_page"),
+        columns: Some(vec![Ident::new("___has_next_page")]),
+        query: CteQuery::Select(SelectStmt {
+            ctes: vec![Cte {
+                name: Ident::new("page_plus_1"),
+                columns: None,
+                query: CteQuery::Select(page_plus_1_select),
+                materialized: None,
+            }],
+            columns: vec![SelectColumn::expr(Expr::BinaryOp {
+                left: Box::new(count_star()),
+                op: BinaryOperator::Gt,
+                right: Box::new(Expr::Literal(Literal::Integer(limit as i64))),
+            })],
+            from: Some(FromClause::Table {
+                schema: None,
+                name: Ident::new("page_plus_1"),
+                alias: None,
+            }),
+            where_clause: None,
+            group_by: vec![],
+            having: None,
+            order_by: vec![],
+            limit: None,
+            offset: None,
+        }),
+        materialized: None,
+    };
+
+    let has_prev_cte = Cte {
+        name: Ident::new("__has_previous_page"),
+        columns: Some(vec![Ident::new("___has_previous_page")]),
+        query: CteQuery::Select(SelectStmt {
+            ctes: vec![],
+            columns: vec![SelectColumn::expr(Expr::Literal(Literal::Bool(offset > 0)))],
+            from: None,
+            where_clause: None,
+            group_by: vec![],
+            having: None,
+            order_by: vec![],
+            limit: None,
+            offset: None,
+        }),
+        materialized: None,
+    };
+
+    Ok((has_next_cte, has_prev_cte))
+}
+
+/// Build aggregates CTE with join clause for nested connections
+fn build_aggregates_cte_with_join(
+    conn: &ConnectionBuilder,
+    block_name: &str,
+    join_clause: &Option<Expr>,
+    agg_builder: Option<&AggregateBuilder>,
+    params: &mut ParamCollector,
+) -> GraphQLResult<Cte> {
+    let Some(agg_builder) = agg_builder else {
+        return Ok(Cte {
+            name: Ident::new("__aggregates"),
+            columns: Some(vec![Ident::new("agg_result")]),
+            query: CteQuery::Select(SelectStmt {
+                ctes: vec![],
+                columns: vec![SelectColumn::expr(Expr::Cast {
+                    expr: Box::new(Expr::Literal(Literal::Null)),
+                    target_type: super::type_name_to_sql_type("jsonb"),
+                })],
+                from: None,
+                where_clause: None,
+                group_by: vec![],
+                having: None,
+                order_by: vec![],
+                limit: None,
+                offset: None,
+            }),
+            materialized: None,
+        });
+    };
+
+    let agg_pairs = build_aggregate_select_list(agg_builder, block_name);
+    let filter_clause = build_filter_expr(&conn.filter, &conn.source.table, block_name, params)?;
+
+    let where_clause = match (join_clause.clone(), filter_clause) {
+        (Some(join), Some(filter)) => Some(Expr::BinaryOp {
+            left: Box::new(join),
+            op: BinaryOperator::And,
+            right: Box::new(filter),
+        }),
+        (Some(join), None) => Some(join),
+        (None, Some(filter)) => Some(filter),
+        (None, None) => None,
+    };
+
+    Ok(Cte {
+        name: Ident::new("__aggregates"),
+        columns: Some(vec![Ident::new("agg_result")]),
+        query: CteQuery::Select(SelectStmt {
+            ctes: vec![],
+            columns: vec![SelectColumn::expr(jsonb_build_object(agg_pairs))],
+            from: Some(FromClause::Table {
+                schema: Some(Ident::new(conn.source.table.schema.clone())),
+                name: Ident::new(conn.source.table.name.clone()),
+                alias: Some(Ident::new(block_name)),
+            }),
+            where_clause,
+            group_by: vec![],
+            having: None,
+            order_by: vec![],
+            limit: None,
+            offset: None,
+        }),
+        materialized: None,
+    })
+}
+
+/// Build a connection query using a function call as the FROM source
+///
+/// This is used when a function returns `setof <table>` and we want to expose
+/// that as a connection. The function call replaces the table in the FROM clause.
+pub fn build_function_connection_subquery_full(
+    conn: &ConnectionBuilder,
+    func_call: super::FunctionCall,
+    params: &mut ParamCollector,
+) -> GraphQLResult<Expr> {
+    let ctx = AstBuildContext::new();
+    let block_name = ctx.block_name.clone();
+
+    // Build the from clause using the function call
+    let from_clause = FromClause::Function {
+        call: func_call.clone(),
+        alias: Ident::new(&block_name),
+    };
+
+    // Build the __records CTE - main data fetch with pagination
+    let records_cte = build_records_cte_from_function(conn, &block_name, &from_clause, params)?;
+
+    // Build the __total_count CTE
+    let total_count_cte =
+        build_total_count_cte_from_function(conn, &block_name, &from_clause, params)?;
+
+    // Build __has_next_page and __has_previous_page CTEs
+    let (has_next_cte, has_prev_cte) =
+        build_pagination_ctes_from_function(conn, &block_name, &from_clause, params)?;
+
+    // Build the __has_records CTE
+    let has_records_cte = build_has_records_cte();
+
+    // Check if aggregates are requested and build the aggregate CTE
+    let aggregate_builder = conn.selections.iter().find_map(|sel| match sel {
+        ConnectionSelection::Aggregate(builder) => Some(builder),
+        _ => None,
+    });
+    let aggregates_cte = build_aggregates_cte_from_function(
+        conn,
+        &block_name,
+        &from_clause,
+        aggregate_builder,
+        params,
+    )?;
+
+    // Build the main selection object (excluding aggregates - they're handled separately)
+    let object_columns = build_connection_object(conn, &block_name, params)?;
+
+    // Build the __base_object CTE that combines everything
+    let base_object_cte = build_base_object_cte(&object_columns, &block_name);
+
+    // Build the final SELECT expression with aggregate merge if needed
+    let final_expr = if let Some(agg_builder) = aggregate_builder {
+        Expr::BinaryOp {
+            left: Box::new(coalesce(vec![
+                column_ref("__base_object", "obj"),
+                empty_jsonb_object(),
+            ])),
+            op: BinaryOperator::JsonConcat,
+            right: Box::new(jsonb_build_object(vec![(
+                agg_builder.alias.clone(),
+                coalesce(vec![
+                    column_ref("__aggregates", "agg_result"),
+                    empty_jsonb_object(),
+                ]),
+            )])),
+        }
+    } else {
+        coalesce(vec![
+            column_ref("__base_object", "obj"),
+            empty_jsonb_object(),
+        ])
+    };
+
+    // Build the full select statement as a subquery
+    let select = SelectStmt {
+        ctes: vec![
+            records_cte,
+            total_count_cte,
+            has_next_cte,
+            has_prev_cte,
+            has_records_cte,
+            aggregates_cte,
+            base_object_cte,
+        ],
+        columns: vec![SelectColumn::expr(final_expr)],
+        from: Some(FromClause::Join {
+            left: Box::new(FromClause::Join {
+                left: Box::new(FromClause::Subquery {
+                    query: Box::new(SelectStmt {
+                        ctes: vec![],
+                        columns: vec![SelectColumn::expr(Expr::Literal(Literal::Integer(1)))],
+                        from: None,
+                        where_clause: None,
+                        group_by: vec![],
+                        having: None,
+                        order_by: vec![],
+                        limit: None,
+                        offset: None,
+                    }),
+                    alias: Ident::new("__dummy_for_left_join"),
+                }),
+                join_type: super::JoinType::Left,
+                right: Box::new(FromClause::Table {
+                    schema: None,
+                    name: Ident::new("__base_object"),
+                    alias: None,
+                }),
+                on: Some(Expr::Literal(Literal::Bool(true))),
+            }),
+            join_type: super::JoinType::Inner,
+            right: Box::new(FromClause::Table {
+                schema: None,
+                name: Ident::new("__aggregates"),
+                alias: None,
+            }),
+            on: Some(Expr::Literal(Literal::Bool(true))),
+        }),
+        where_clause: None,
+        group_by: vec![],
+        having: None,
+        order_by: vec![],
+        limit: None,
+        offset: None,
+    };
+
+    Ok(Expr::Subquery(Box::new(select)))
+}
+
+/// Build __records CTE using a function as the FROM source
+fn build_records_cte_from_function(
+    conn: &ConnectionBuilder,
+    block_name: &str,
+    from_clause: &FromClause,
+    params: &mut ParamCollector,
+) -> GraphQLResult<Cte> {
+    let filter_clause = build_filter_expr(&conn.filter, &conn.source.table, block_name, params)?;
+
+    // Determine if this is reverse pagination (using 'before' or 'last')
+    let is_reverse = conn.before.is_some() || (conn.last.is_some() && conn.first.is_none());
+
+    // Get order by (reversed if using reverse pagination)
+    let order_by_builder = if is_reverse {
+        conn.order_by.reverse()
+    } else {
+        conn.order_by.clone()
+    };
+
+    let order_by = build_order_by_exprs(&order_by_builder, block_name);
+
+    let limit = conn
+        .first
+        .or(conn.last)
+        .map(|l| std::cmp::min(l, conn.max_rows))
+        .unwrap_or(conn.max_rows);
+
+    let offset = conn.offset.unwrap_or(0);
+
+    let columns: Vec<SelectColumn> = conn
+        .source
+        .table
+        .columns
+        .iter()
+        .filter(|c| c.permissions.is_selectable)
+        .map(|c| SelectColumn::expr(Expr::Column(ColumnRef::new(c.name.clone()))))
+        .collect();
+
+    let select = SelectStmt {
+        ctes: vec![],
+        columns,
+        from: Some(from_clause.clone()),
+        where_clause: filter_clause,
+        group_by: vec![],
+        having: None,
+        order_by,
+        limit: Some(limit),
+        offset: if offset > 0 { Some(offset) } else { None },
+    };
+
+    Ok(Cte {
+        name: Ident::new("__records"),
+        columns: None,
+        query: CteQuery::Select(select),
+        materialized: None,
+    })
+}
+
+/// Build __total_count CTE using a function as the FROM source
+fn build_total_count_cte_from_function(
+    conn: &ConnectionBuilder,
+    block_name: &str,
+    from_clause: &FromClause,
+    params: &mut ParamCollector,
+) -> GraphQLResult<Cte> {
+    let filter_clause = build_filter_expr(&conn.filter, &conn.source.table, block_name, params)?;
+
+    let select = SelectStmt {
+        ctes: vec![],
+        columns: vec![SelectColumn::expr(count_star())],
+        from: Some(from_clause.clone()),
+        where_clause: filter_clause,
+        group_by: vec![],
+        having: None,
+        order_by: vec![],
+        limit: None,
+        offset: None,
+    };
+
+    Ok(Cte {
+        name: Ident::new("__total_count"),
+        columns: Some(vec![Ident::new("___total_count")]),
+        query: CteQuery::Select(select),
+        materialized: None,
+    })
+}
+
+/// Build pagination CTEs using a function as the FROM source
+fn build_pagination_ctes_from_function(
+    conn: &ConnectionBuilder,
+    block_name: &str,
+    from_clause: &FromClause,
+    params: &mut ParamCollector,
+) -> GraphQLResult<(Cte, Cte)> {
+    let limit = conn
+        .first
+        .or(conn.last)
+        .map(|l| std::cmp::min(l, conn.max_rows))
+        .unwrap_or(conn.max_rows);
+
+    let offset = conn.offset.unwrap_or(0);
+
+    // Determine if this is reverse pagination
+    let is_reverse = conn.before.is_some() || (conn.last.is_some() && conn.first.is_none());
+
+    // Get order by (reversed if using reverse pagination)
+    let order_by_builder = if is_reverse {
+        conn.order_by.reverse()
+    } else {
+        conn.order_by.clone()
+    };
+
+    let filter_clause = build_filter_expr(&conn.filter, &conn.source.table, block_name, params)?;
+    let order_by = build_order_by_exprs(&order_by_builder, block_name);
+
+    let page_plus_1_select = SelectStmt {
+        ctes: vec![],
+        columns: vec![SelectColumn::expr(Expr::Literal(Literal::Integer(1)))],
+        from: Some(from_clause.clone()),
+        where_clause: filter_clause.clone(),
+        group_by: vec![],
+        having: None,
+        order_by: order_by.clone(),
+        limit: Some(limit + 1),
+        offset: if offset > 0 { Some(offset) } else { None },
+    };
+
+    let has_more_query = CteQuery::Select(SelectStmt {
+        ctes: vec![Cte {
+            name: Ident::new("page_plus_1"),
+            columns: None,
+            query: CteQuery::Select(page_plus_1_select),
+            materialized: None,
+        }],
+        columns: vec![SelectColumn::expr(Expr::BinaryOp {
+            left: Box::new(count_star()),
+            op: BinaryOperator::Gt,
+            right: Box::new(Expr::Literal(Literal::Integer(limit as i64))),
+        })],
+        from: Some(FromClause::Table {
+            schema: None,
+            name: Ident::new("page_plus_1"),
+            alias: None,
+        }),
+        where_clause: None,
+        group_by: vec![],
+        having: None,
+        order_by: vec![],
+        limit: None,
+        offset: None,
+    });
+
+    let offset_check_query = CteQuery::Select(SelectStmt {
+        ctes: vec![],
+        columns: vec![SelectColumn::expr(Expr::Literal(Literal::Bool(offset > 0)))],
+        from: None,
+        where_clause: None,
+        group_by: vec![],
+        having: None,
+        order_by: vec![],
+        limit: None,
+        offset: None,
+    });
+
+    // For reverse pagination (before/last), swap hasNextPage and hasPreviousPage
+    let (next_query, prev_query) = if is_reverse {
+        (offset_check_query, has_more_query)
+    } else {
+        (has_more_query, offset_check_query)
+    };
+
+    let has_next_cte = Cte {
+        name: Ident::new("__has_next_page"),
+        columns: Some(vec![Ident::new("___has_next_page")]),
+        query: next_query,
+        materialized: None,
+    };
+
+    let has_prev_cte = Cte {
+        name: Ident::new("__has_previous_page"),
+        columns: Some(vec![Ident::new("___has_previous_page")]),
+        query: prev_query,
+        materialized: None,
+    };
+
+    Ok((has_next_cte, has_prev_cte))
+}
+
+/// Build aggregates CTE using a function as the FROM source
+fn build_aggregates_cte_from_function(
+    conn: &ConnectionBuilder,
+    block_name: &str,
+    from_clause: &FromClause,
+    agg_builder: Option<&AggregateBuilder>,
+    params: &mut ParamCollector,
+) -> GraphQLResult<Cte> {
+    let Some(agg_builder) = agg_builder else {
+        return Ok(Cte {
+            name: Ident::new("__aggregates"),
+            columns: Some(vec![Ident::new("agg_result")]),
+            query: CteQuery::Select(SelectStmt {
+                ctes: vec![],
+                columns: vec![SelectColumn::expr(Expr::Cast {
+                    expr: Box::new(Expr::Literal(Literal::Null)),
+                    target_type: super::type_name_to_sql_type("jsonb"),
+                })],
+                from: None,
+                where_clause: None,
+                group_by: vec![],
+                having: None,
+                order_by: vec![],
+                limit: None,
+                offset: None,
+            }),
+            materialized: None,
+        });
+    };
+
+    let agg_pairs = build_aggregate_select_list(agg_builder, block_name);
+    let filter_clause = build_filter_expr(&conn.filter, &conn.source.table, block_name, params)?;
+
+    Ok(Cte {
Ident::new("__aggregates"), + columns: Some(vec![Ident::new("agg_result")]), + query: CteQuery::Select(SelectStmt { + ctes: vec![], + columns: vec![SelectColumn::expr(jsonb_build_object(agg_pairs))], + from: Some(from_clause.clone()), + where_clause: filter_clause, + group_by: vec![], + having: None, + order_by: vec![], + limit: None, + offset: None, + }), + materialized: None, + }) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_has_records_cte() { + let cte = build_has_records_cte(); + assert_eq!(cte.name.0, "__has_records"); + } +} diff --git a/src/ast/transpile_delete.rs b/src/ast/transpile_delete.rs new file mode 100644 index 00000000..9bb72dd3 --- /dev/null +++ b/src/ast/transpile_delete.rs @@ -0,0 +1,338 @@ +//! AST-based transpilation for DeleteBuilder +//! +//! This module implements the ToAst trait for DeleteBuilder, converting it +//! to a type-safe AST that can be rendered to SQL. + +use super::{ + build_connection_subquery, build_filter_expr, build_relation_subquery_expr, coalesce, + column_ref, count_star, empty_jsonb_array, func_call, func_call_schema, jsonb_agg, + jsonb_build_object, string_literal, AstBuildContext, ToAst, +}; +use crate::ast::{ + BinaryOperator, CaseExpr, ColumnRef, Cte, CteQuery, DeleteStmt, Expr, FromClause, Ident, + Literal, ParamCollector, SelectColumn, SelectStmt, Stmt, +}; +use crate::builder::{ + DeleteBuilder, DeleteSelection, FunctionBuilder, FunctionSelection, NodeBuilder, NodeSelection, +}; +use crate::error::GraphQLResult; + +/// The result of transpiling a DeleteBuilder to AST +pub struct DeleteAst { + /// The complete SQL statement + pub stmt: Stmt, +} + +impl ToAst for DeleteBuilder { + type Ast = DeleteAst; + + fn to_ast(&self, params: &mut ParamCollector) -> GraphQLResult { + let ctx = AstBuildContext::new(); + let block_name = ctx.block_name.clone(); + + // Build WHERE clause from filter + let where_clause = build_filter_expr(&self.filter, &self.table, &block_name, params)?; + + // Build RETURNING clause - all selectable columns + let returning: Vec = self + .table + .columns + .iter() + .filter(|c| c.permissions.is_selectable) + .map(|c| SelectColumn::expr(Expr::Column(ColumnRef::new(c.name.clone())))) + .collect(); + + // Build the DELETE statement for the 'impacted' CTE + let delete_stmt = DeleteStmt { + ctes: vec![], + schema: Some(Ident::new(self.table.schema.clone())), + table: Ident::new(self.table.name.clone()), + alias: Some(Ident::new(block_name.clone())), + using: None, + where_clause, + returning, + }; + + // Build the select columns for jsonb_build_object + let select_columns = build_select_columns(&self.selections, &block_name, params)?; + + // Build the complex CTE structure + let stmt = + build_delete_with_at_most(delete_stmt, select_columns, &block_name, self.at_most); + + Ok(DeleteAst { stmt }) + } +} + +/// Build the select columns for jsonb_build_object from DeleteSelection +fn build_select_columns( + selections: &[DeleteSelection], + block_name: &str, + params: &mut ParamCollector, +) -> GraphQLResult> { + let mut pairs = Vec::new(); + + for selection in selections { + match selection { + DeleteSelection::AffectedCount { alias } => { + pairs.push((string_literal(alias), count_star())); + } + DeleteSelection::Records(node_builder) => { + let node_expr = build_node_builder_expr(node_builder, block_name, params)?; + pairs.push(( + string_literal(&node_builder.alias), + coalesce(vec![jsonb_agg(node_expr), empty_jsonb_array()]), + )); + } + DeleteSelection::Typename { alias, typename } => { + 
+                pairs.push((string_literal(alias), string_literal(typename)));
+            }
+        }
+    }
+
+    Ok(pairs)
+}
+
+/// Build expression for a NodeBuilder (simplified)
+fn build_node_builder_expr(
+    node_builder: &NodeBuilder,
+    block_name: &str,
+    params: &mut ParamCollector,
+) -> GraphQLResult<Expr> {
+    let mut pairs = Vec::new();
+
+    for selection in &node_builder.selections {
+        match selection {
+            NodeSelection::Column(col_builder) => {
+                // Use build_column_expr which handles enum mappings
+                let col_expr = super::build_column_expr(col_builder, block_name);
+                pairs.push((col_builder.alias.clone(), col_expr));
+            }
+            NodeSelection::Typename { alias, typename } => {
+                pairs.push((alias.clone(), string_literal(typename)));
+            }
+            NodeSelection::NodeId(node_id_builder) => {
+                let pk_exprs: Vec<Expr> = node_id_builder
+                    .columns
+                    .iter()
+                    .map(|c| func_call("to_jsonb", vec![column_ref(block_name, &c.name)]))
+                    .collect();
+
+                let mut array_args = vec![
+                    string_literal(&node_id_builder.schema_name),
+                    string_literal(&node_id_builder.table_name),
+                ];
+                array_args.extend(pk_exprs);
+
+                let jsonb_array = func_call("jsonb_build_array", array_args);
+                let as_text = Expr::Cast {
+                    expr: Box::new(jsonb_array),
+                    target_type: super::type_name_to_sql_type("text"),
+                };
+                let converted = func_call("convert_to", vec![as_text, string_literal("utf-8")]);
+                let encoded = func_call("encode", vec![converted, string_literal("base64")]);
+                let translated = func_call(
+                    "translate",
+                    vec![encoded, string_literal("\n"), string_literal("")],
+                );
+
+                pairs.push((node_id_builder.alias.clone(), translated));
+            }
+            NodeSelection::Function(func_builder) => {
+                // Build function expression: schema.function(row::schema.table)
+                let func_expr = build_function_expr(func_builder, block_name)?;
+                pairs.push((func_builder.alias.clone(), func_expr));
+            }
+            NodeSelection::Connection(conn_builder) => {
+                // Connection selections - build subquery
+                let conn_expr = build_connection_subquery(conn_builder, block_name, params)?;
+                pairs.push((conn_builder.alias.clone(), conn_expr));
+            }
+            NodeSelection::Node(nested_node) => {
+                // Nested node selections - build subquery
+                let node_expr = build_relation_subquery_expr(nested_node, block_name, params)?;
+                pairs.push((nested_node.alias.clone(), node_expr));
+            }
+        }
+    }
+
+    Ok(jsonb_build_object(pairs))
+}
+
+/// Build the full DELETE with at_most check using CTEs
+fn build_delete_with_at_most(
+    delete_stmt: DeleteStmt,
+    select_columns: Vec<(Expr, Expr)>,
+    block_name: &str,
+    at_most: i64,
+) -> Stmt {
+    // CTE 1: impacted AS (DELETE ...)
+    let impacted_cte = Cte {
+        name: Ident::new("impacted"),
+        columns: None,
+        query: CteQuery::Delete(delete_stmt),
+        materialized: None,
+    };
+
+    // CTE 2: total(total_count) AS (SELECT count(*) FROM impacted)
+    let total_cte = Cte {
+        name: Ident::new("total"),
+        columns: Some(vec![Ident::new("total_count")]),
+        query: CteQuery::Select(SelectStmt {
+            ctes: vec![],
+            columns: vec![SelectColumn::expr(count_star())],
+            from: Some(FromClause::Table {
+                schema: None,
+                name: Ident::new("impacted"),
+                alias: None,
+            }),
+            where_clause: None,
+            group_by: vec![],
+            having: None,
+            order_by: vec![],
+            limit: None,
+            offset: None,
+        }),
+        materialized: None,
+    };
+
+    // CTE 3: req(res) AS (SELECT jsonb_build_object(...) FROM impacted LIMIT 1)
+    let req_cte = Cte {
+        name: Ident::new("req"),
+        columns: Some(vec![Ident::new("res")]),
+        query: CteQuery::Select(SelectStmt {
+            ctes: vec![],
+            columns: vec![SelectColumn::expr(Expr::JsonBuild(
+                super::JsonBuildExpr::Object(select_columns),
+            ))],
+            from: Some(FromClause::Table {
+                schema: None,
+                name: Ident::new("impacted"),
+                alias: Some(Ident::new(block_name)),
+            }),
+            where_clause: None,
+            group_by: vec![],
+            having: None,
+            order_by: vec![],
+            limit: Some(1),
+            offset: None,
+        }),
+        materialized: None,
+    };
+
+    // CTE 4: wrapper(res) AS (SELECT CASE WHEN total.total_count > at_most THEN exception ELSE req.res END ...)
+    let case_expr = Expr::Case(CaseExpr::searched(
+        vec![(
+            // WHEN total.total_count > at_most
+            Expr::BinaryOp {
+                left: Box::new(column_ref("total", "total_count")),
+                op: BinaryOperator::Gt,
+                right: Box::new(Expr::Literal(Literal::Integer(at_most))),
+            },
+            // THEN graphql.exception(...)::jsonb
+            Expr::Cast {
+                expr: Box::new(func_call_schema(
+                    "graphql",
+                    "exception",
+                    vec![string_literal("delete impacts too many records")],
+                )),
+                target_type: super::type_name_to_sql_type("jsonb"),
+            },
+        )],
+        // ELSE req.res
+        Some(column_ref("req", "res")),
+    ));
+
+    let wrapper_cte = Cte {
+        name: Ident::new("wrapper"),
+        columns: Some(vec![Ident::new("res")]),
+        query: CteQuery::Select(SelectStmt {
+            ctes: vec![],
+            columns: vec![SelectColumn::expr(case_expr)],
+            from: Some(FromClause::Join {
+                left: Box::new(FromClause::Table {
+                    schema: None,
+                    name: Ident::new("total"),
+                    alias: None,
+                }),
+                join_type: super::JoinType::Left,
+                right: Box::new(FromClause::Table {
+                    schema: None,
+                    name: Ident::new("req"),
+                    alias: None,
+                }),
+                on: Some(Expr::Literal(Literal::Bool(true))),
+            }),
+            where_clause: None,
+            group_by: vec![],
+            having: None,
+            order_by: vec![],
+            limit: Some(1),
+            offset: None,
+        }),
+        materialized: None,
+    };
+
+    // Final SELECT from wrapper
+    Stmt::Select(SelectStmt {
+        ctes: vec![impacted_cte, total_cte, req_cte, wrapper_cte],
+        columns: vec![SelectColumn::expr(column_ref("wrapper", "res"))],
+        from: Some(FromClause::Table {
+            schema: None,
+            name: Ident::new("wrapper"),
+            alias: None,
+        }),
+        where_clause: None,
+        group_by: vec![],
+        having: None,
+        order_by: vec![],
+        limit: None,
+        offset: None,
+    })
+}
+
+/// Build a function expression for scalar function calls
+fn build_function_expr(func_builder: &FunctionBuilder, block_name: &str) -> GraphQLResult<Expr> {
+    use super::apply_type_cast;
+    use crate::ast::{FunctionArg, FunctionCall, SqlType};
+
+    // Build the row argument with type cast: block_name::schema.table
+    let row_arg = Expr::Cast {
+        expr: Box::new(Expr::Column(ColumnRef::new(block_name))),
+        target_type: SqlType::with_schema(
+            func_builder.table.schema.clone(),
+            func_builder.table.name.clone(),
+        ),
+    };
+
+    let args = vec![FunctionArg::unnamed(row_arg)];
+
+    let func_expr = Expr::FunctionCall(FunctionCall::with_schema(
+        func_builder.function.schema_name.clone(),
+        func_builder.function.name.clone(),
+        args,
+    ));
+
+    // Apply type cast for special types
+    match &func_builder.selection {
+        FunctionSelection::ScalarSelf | FunctionSelection::Array => {
+            Ok(apply_type_cast(func_expr, func_builder.function.type_oid))
+        }
+        FunctionSelection::Connection(_) | FunctionSelection::Node(_) => {
+            // These are complex selections - return raw function call
+            Ok(func_expr)
+        }
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn test_delete_ast_structure() {
+        // Basic structure test - full testing via pg_regress
+        let ident = Ident::new("test_table");
+        assert_eq!(ident.as_str(), "test_table");
+    }
+}
diff --git a/src/ast/transpile_filter.rs b/src/ast/transpile_filter.rs
new file mode 100644
index 00000000..f83e7388
--- /dev/null
+++ b/src/ast/transpile_filter.rs
@@ -0,0 +1,298 @@
+//! AST-based transpilation for FilterBuilder
+//!
+//! This module implements filter expression building using the AST.
+//! Filters are used in WHERE clauses for queries and mutations.
+
+use super::{add_param_from_json, column_ref};
+use crate::ast::{BinaryOperator, Expr, Literal, ParamCollector, UnaryOperator};
+use crate::builder::{CompoundFilterBuilder, FilterBuilder, FilterBuilderElem};
+use crate::error::{GraphQLError, GraphQLResult};
+use crate::graphql::FilterOp;
+use crate::sql_types::Table;
+
+/// Build a WHERE clause expression from a FilterBuilder
+pub fn build_filter_expr(
+    filter: &FilterBuilder,
+    table: &Table,
+    block_name: &str,
+    params: &mut ParamCollector,
+) -> GraphQLResult<Option<Expr>> {
+    if filter.elems.is_empty() {
+        return Ok(None);
+    }
+
+    let mut conditions = Vec::new();
+
+    for elem in &filter.elems {
+        let expr = build_filter_elem_expr(elem, table, block_name, params)?;
+        conditions.push(expr);
+    }
+
+    // Combine all conditions with AND
+    Ok(Some(combine_with_and(conditions)))
+}
+
+/// Combine multiple expressions with AND
+pub fn combine_with_and(mut conditions: Vec<Expr>) -> Expr {
+    if conditions.len() == 1 {
+        conditions.remove(0)
+    } else {
+        let mut combined = conditions.remove(0);
+        for cond in conditions {
+            combined = Expr::BinaryOp {
+                left: Box::new(combined),
+                op: BinaryOperator::And,
+                right: Box::new(cond),
+            };
+        }
+        combined
+    }
+}
+
+/// Combine multiple expressions with OR
+fn combine_with_or(mut conditions: Vec<Expr>) -> Expr {
+    if conditions.len() == 1 {
+        conditions.remove(0)
+    } else {
+        let mut combined = conditions.remove(0);
+        for cond in conditions {
+            combined = Expr::BinaryOp {
+                left: Box::new(combined),
+                op: BinaryOperator::Or,
+                right: Box::new(cond),
+            };
+        }
+        combined
+    }
+}
+
+/// Build expression for a single filter element
+fn build_filter_elem_expr(
+    elem: &FilterBuilderElem,
+    table: &Table,
+    block_name: &str,
+    params: &mut ParamCollector,
+) -> GraphQLResult<Expr> {
+    match elem {
+        FilterBuilderElem::Column { column, op, value } => {
+            let col_expr = column_ref(block_name, &column.name);
+
+            match op {
+                FilterOp::Is => {
+                    // IS NULL / IS NOT NULL
+                    let is_null = match value {
+                        serde_json::Value::String(s) => match s.as_str() {
+                            "NULL" => true,
+                            "NOT_NULL" => false,
+                            _ => {
+                                return Err(GraphQLError::sql_generation(
+                                    "Error transpiling Is filter value",
+                                ))
+                            }
+                        },
+                        _ => {
+                            return Err(GraphQLError::sql_generation(
+                                "Error transpiling Is filter value type",
+                            ))
+                        }
+                    };
+
+                    Ok(Expr::IsNull {
+                        expr: Box::new(col_expr),
+                        negated: !is_null,
+                    })
+                }
+                FilterOp::In | FilterOp::Contains | FilterOp::ContainedBy | FilterOp::Overlap => {
+                    // Array operations use array cast
+                    let cast_type_name = match op {
+                        FilterOp::In | FilterOp::Contains => format!("{}[]", column.type_name),
+                        _ => column.type_name.clone(),
+                    };
+
+                    let param_expr = add_param_from_json(params, value, &cast_type_name)?;
+
+                    let binary_op = match op {
+                        FilterOp::In => BinaryOperator::Any,
+                        FilterOp::Contains => BinaryOperator::Contains,
+                        FilterOp::ContainedBy => BinaryOperator::ContainedBy,
+                        FilterOp::Overlap => BinaryOperator::Overlap,
+                        _ => unreachable!(),
+                    };
+
+                    Ok(Expr::BinaryOp {
+                        left: Box::new(col_expr),
+                        op: binary_op,
+                        right: Box::new(param_expr),
+                    })
+                }
+                _ => {
+                    // If the value is null for comparison operators, comparing with NULL
+                    // always produces NULL (unknown), so we should return FALSE for
+                    // consistent filtering semantics
+                    if value.is_null() {
+                        return Ok(Expr::Literal(Literal::Bool(false)));
+                    }
+
+                    // Standard comparison operators
+                    let param_expr = add_param_from_json(params, value, &column.type_name)?;
+
+                    let binary_op = match op {
+                        FilterOp::Equal => BinaryOperator::Eq,
+                        FilterOp::NotEqual => BinaryOperator::NotEq,
+                        FilterOp::LessThan => BinaryOperator::Lt,
+                        FilterOp::LessThanEqualTo => BinaryOperator::LtEq,
+                        FilterOp::GreaterThan => BinaryOperator::Gt,
+                        FilterOp::GreaterThanEqualTo => BinaryOperator::GtEq,
+                        FilterOp::Like => BinaryOperator::Like,
+                        FilterOp::ILike => BinaryOperator::ILike,
+                        FilterOp::RegEx => BinaryOperator::RegEx,
+                        FilterOp::IRegEx => BinaryOperator::IRegEx,
+                        FilterOp::StartsWith => BinaryOperator::StartsWith,
+                        _ => unreachable!(),
+                    };
+
+                    Ok(Expr::BinaryOp {
+                        left: Box::new(col_expr),
+                        op: binary_op,
+                        right: Box::new(param_expr),
+                    })
+                }
+            }
+        }
+        FilterBuilderElem::NodeId(node_id_instance) => {
+            // Validate that nodeId belongs to this table
+            if (&node_id_instance.schema_name, &node_id_instance.table_name)
+                != (&table.schema, &table.name)
+            {
+                return Err(GraphQLError::validation(
+                    "nodeId belongs to a different collection",
+                ));
+            }
+
+            // Get primary key columns
+            let pk_columns = table.primary_key_columns();
+
+            if pk_columns.len() != node_id_instance.values.len() {
+                return Err(GraphQLError::validation(
+                    "NodeId value count doesn't match primary key columns",
+                ));
+            }
+
+            // Build conditions for each primary key column
+            let mut conditions = Vec::new();
+            for (col, value) in pk_columns.iter().zip(&node_id_instance.values) {
+                let col_expr = column_ref(block_name, &col.name);
+                let param_expr = add_param_from_json(params, value, &col.type_name)?;
+
+                conditions.push(Expr::BinaryOp {
+                    left: Box::new(col_expr),
+                    op: BinaryOperator::Eq,
+                    right: Box::new(param_expr),
+                });
+            }
+
+            Ok(combine_with_and(conditions))
+        }
+        FilterBuilderElem::Compound(compound) => {
+            build_compound_filter_expr(compound, table, block_name, params)
+        }
+    }
+}
+
+/// Build expression for a compound filter (AND/OR/NOT)
+fn build_compound_filter_expr(
+    compound: &CompoundFilterBuilder,
+    table: &Table,
+    block_name: &str,
+    params: &mut ParamCollector,
+) -> GraphQLResult<Expr> {
+    match compound {
+        CompoundFilterBuilder::And(elems) => {
+            let mut conditions = Vec::new();
+            for elem in elems {
+                let expr = build_filter_elem_expr(elem, table, block_name, params)?;
+                conditions.push(expr);
+            }
+
+            if conditions.is_empty() {
+                // Empty AND is true
+                Ok(Expr::Literal(Literal::Bool(true)))
+            } else {
+                Ok(combine_with_and(conditions))
+            }
+        }
+        CompoundFilterBuilder::Or(elems) => {
+            let mut conditions = Vec::new();
+            for elem in elems {
+                let expr = build_filter_elem_expr(elem, table, block_name, params)?;
+                conditions.push(expr);
+            }
+
+            if conditions.is_empty() {
+                // Empty OR is false
+                Ok(Expr::Literal(Literal::Bool(false)))
+            } else {
+                Ok(combine_with_or(conditions))
+            }
+        }
+        CompoundFilterBuilder::Not(elem) => {
+            let expr = build_filter_elem_expr(elem, table, block_name, params)?;
+            Ok(Expr::UnaryOp {
+                op: UnaryOperator::Not,
+                expr: Box::new(expr),
+            })
+        }
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    // Note: Full testing requires setting up Table, Column, FilterBuilder etc.
+    // Integration tests via pg_regress are more appropriate.
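+
+    #[test]
+    fn test_combine_with_and_single_passthrough() {
+        // A single condition is returned unchanged rather than wrapped in AND
+        let combined = combine_with_and(vec![Expr::Literal(Literal::Bool(true))]);
+        assert_eq!(combined, Expr::Literal(Literal::Bool(true)));
+    }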
+
+    #[test]
+    fn test_filter_ops_mapping() {
+        // Just verify the operator mappings compile correctly
+        let ops = [
+            (FilterOp::Equal, BinaryOperator::Eq),
+            (FilterOp::NotEqual, BinaryOperator::NotEq),
+            (FilterOp::LessThan, BinaryOperator::Lt),
+            (FilterOp::GreaterThan, BinaryOperator::Gt),
+        ];
+
+        for (filter_op, binary_op) in ops {
+            // Just check the mapping exists
+            let _ = format!("{:?} -> {:?}", filter_op, binary_op);
+        }
+    }
+
+    #[test]
+    fn test_combine_with_and() {
+        let exprs = vec![
+            Expr::Literal(Literal::Bool(true)),
+            Expr::Literal(Literal::Bool(false)),
+        ];
+
+        let combined = combine_with_and(exprs);
+        match combined {
+            Expr::BinaryOp { op, .. } => assert_eq!(op, BinaryOperator::And),
+            _ => panic!("Expected BinaryOp"),
+        }
+    }
+
+    #[test]
+    fn test_combine_with_or() {
+        let exprs = vec![
+            Expr::Literal(Literal::Bool(true)),
+            Expr::Literal(Literal::Bool(false)),
+        ];
+
+        let combined = combine_with_or(exprs);
+        match combined {
+            Expr::BinaryOp { op, .. } => assert_eq!(op, BinaryOperator::Or),
+            _ => panic!("Expected BinaryOp"),
+        }
+    }
+}
diff --git a/src/ast/transpile_function_call.rs b/src/ast/transpile_function_call.rs
new file mode 100644
index 00000000..d22a237a
--- /dev/null
+++ b/src/ast/transpile_function_call.rs
@@ -0,0 +1,156 @@
+//! AST-based transpilation for FunctionCallBuilder
+//!
+//! This module implements the ToAst trait for FunctionCallBuilder, converting it
+//! to a type-safe AST that can be rendered to SQL. FunctionCallBuilder handles
+//! top-level function calls in queries and mutations.
+
+use super::{
+    add_param_from_json, apply_type_cast, build_function_connection_subquery_full,
+    build_node_object_expr, coalesce, AstBuildContext, ToAst,
+};
+use crate::ast::{
+    Expr, FromClause, FunctionArg, FunctionCall, Ident, Literal, ParamCollector, SelectColumn,
+    SelectStmt, Stmt,
+};
+use crate::builder::{FuncCallReturnTypeBuilder, FunctionCallBuilder};
+use crate::error::GraphQLResult;
+
+/// The result of transpiling a FunctionCallBuilder to AST
+pub struct FunctionCallAst {
+    /// The complete SQL statement
+    pub stmt: Stmt,
+}
+
+impl ToAst for FunctionCallBuilder {
+    type Ast = FunctionCallAst;
+
+    fn to_ast(&self, params: &mut ParamCollector) -> GraphQLResult<Self::Ast> {
+        let ctx = AstBuildContext::new();
+        let block_name = ctx.block_name.clone();
+
+        // Build the function call arguments
+        let args = build_function_args(&self.args_builder, params)?;
+
+        // Build the function call expression
+        let func_call = FunctionCall::with_schema(
+            self.function.schema_name.clone(),
+            self.function.name.clone(),
+            args,
+        );
+
+        // Build the query based on return type
+        let select_expr = match &self.return_type_builder {
+            FuncCallReturnTypeBuilder::Scalar | FuncCallReturnTypeBuilder::List => {
+                // SELECT to_jsonb(schema.func(args)::type_adjustment)
+                let func_expr = Expr::FunctionCall(func_call);
+                let adjusted_expr = apply_type_cast(func_expr, self.function.type_oid);
+                func_call_expr("to_jsonb", vec![adjusted_expr])
+            }
+            FuncCallReturnTypeBuilder::Node(node_builder) => {
+                // SELECT coalesce((SELECT node_object FROM schema.func(args) AS block WHERE NOT (block IS NULL)), null::jsonb)
+                let object_expr =
+                    build_node_object_expr(&node_builder.selections, &block_name, params)?;
+
+                // Handle empty selections
+                let object_expr = if node_builder.selections.is_empty() {
+                    func_call_expr("jsonb_build_object", vec![])
+                } else {
+                    object_expr
+                };
+
+                // Build: NOT (block_name IS NULL)
+                let not_null_check = Expr::UnaryOp {
+                    op: super::UnaryOperator::Not,
+                    expr: Box::new(Expr::IsNull {
+                        expr: Box::new(Expr::Column(super::ColumnRef::new(block_name.as_str()))),
+                        negated: false,
+                    }),
+                };
+
+                // Build the inner subquery
+                let inner_subquery = SelectStmt {
+                    ctes: vec![],
+                    columns: vec![SelectColumn::expr(object_expr)],
+                    from: Some(FromClause::Function {
+                        call: func_call,
+                        alias: Ident::new(&block_name),
+                    }),
+                    where_clause: Some(not_null_check),
+                    group_by: vec![],
+                    having: None,
+                    order_by: vec![],
+                    limit: None,
+                    offset: None,
+                };
+
+                // Wrap in coalesce with null::jsonb fallback
+                coalesce(vec![
+                    Expr::Subquery(Box::new(inner_subquery)),
+                    Expr::Cast {
+                        expr: Box::new(Expr::Literal(Literal::Null)),
+                        target_type: super::type_name_to_sql_type("jsonb"),
+                    },
+                ])
+            }
+            FuncCallReturnTypeBuilder::Connection(conn_builder) => {
+                // Build a connection query using the function as the FROM source
+                // This returns an Expr::Subquery containing the full connection query
+                build_function_connection_subquery_full(conn_builder, func_call, params)?
+            }
+        };
+
+        // For Connection type, the select_expr is already a subquery that returns the result
+        // For other types, we need to wrap in SELECT
+
+        // Build the final SELECT statement
+        let select = SelectStmt {
+            ctes: vec![],
+            columns: vec![SelectColumn::expr(select_expr)],
+            from: None,
+            where_clause: None,
+            group_by: vec![],
+            having: None,
+            order_by: vec![],
+            limit: None,
+            offset: None,
+        };
+
+        Ok(FunctionCallAst {
+            stmt: Stmt::Select(select),
+        })
+    }
+}
+
+/// Build function call arguments from FuncCallArgsBuilder
+fn build_function_args(
+    args_builder: &crate::builder::FuncCallArgsBuilder,
+    params: &mut ParamCollector,
+) -> GraphQLResult<Vec<FunctionArg>> {
+    let mut args = Vec::new();
+
+    for (arg_meta, arg_value) in &args_builder.args {
+        if let Some(arg) = arg_meta {
+            // Build the parameter expression with type cast
+            let param_expr = add_param_from_json(params, arg_value, &arg.type_name)?;
+
+            // Create a named argument: name => value
+            args.push(FunctionArg::named(arg.name.as_str(), param_expr));
+        }
+    }
+
+    Ok(args)
+}
+
+/// Helper to create a function call expression
+fn func_call_expr(name: &str, args: Vec<Expr>) -> Expr {
+    Expr::FunctionCall(FunctionCall::new(
+        name,
+        args.into_iter().map(FunctionArg::unnamed).collect(),
+    ))
+}
+
+#[cfg(test)]
+mod tests {
+    // Integration tests via pg_regress are more appropriate for this module
+    // since it requires a full GraphQL schema and function setup.
+}
diff --git a/src/ast/transpile_insert.rs b/src/ast/transpile_insert.rs
new file mode 100644
index 00000000..686c7089
--- /dev/null
+++ b/src/ast/transpile_insert.rs
@@ -0,0 +1,298 @@
+//! AST-based transpilation for InsertBuilder
+//!
+//! This module implements the ToAst trait for InsertBuilder, converting it
+//! to a type-safe AST that can be rendered to SQL.
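+//!
+//! Sketch of the statement shape this produces (the exact SQL text is owned
+//! by the AST renderer and may differ in quoting and whitespace):
+//!
+//! ```sql
+//! WITH affected AS (
+//!     INSERT INTO "schema"."table" ("a", "b")
+//!     VALUES ($1, DEFAULT)
+//!     RETURNING "a", "b"
+//! )
+//! SELECT jsonb_build_object(...) FROM affected AS block;
+//! ```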
+
+use super::{
+    add_param_from_json, apply_type_cast, build_connection_subquery, build_relation_subquery_expr,
+    coalesce, column_ref, count_star, default_expr, empty_jsonb_array, jsonb_agg,
+    jsonb_build_object, string_literal, type_name_to_sql_type, AstBuildContext, ToAst,
+};
+use crate::ast::{
+    Cte, CteQuery, Expr, FromClause, Ident, InsertStmt, InsertValues, ParamCollector, SelectColumn,
+    SelectStmt, Stmt,
+};
+use crate::builder::{
+    FunctionBuilder, FunctionSelection, InsertBuilder, InsertElemValue, InsertRowBuilder,
+    InsertSelection,
+};
+use crate::error::GraphQLResult;
+use crate::sql_types::Column;
+use std::collections::HashSet;
+use std::sync::Arc;
+
+/// The result of transpiling an InsertBuilder to AST
+pub struct InsertAst {
+    /// The complete SQL statement
+    pub stmt: Stmt,
+}
+
+impl ToAst for InsertBuilder {
+    type Ast = InsertAst;
+
+    fn to_ast(&self, params: &mut ParamCollector) -> GraphQLResult<Self::Ast> {
+        let ctx = AstBuildContext::new();
+        let block_name = &ctx.block_name;
+
+        // Build the select columns for the outer query
+        let select_columns = build_select_columns(&self.selections, block_name, params)?;
+
+        // Identify all columns provided in any of `object` rows
+        let referenced_column_names: HashSet<&String> =
+            self.objects.iter().flat_map(|x| x.row.keys()).collect();
+
+        let referenced_columns: Vec<&Arc<Column>> = self
+            .table
+            .columns
+            .iter()
+            .filter(|c| referenced_column_names.contains(&c.name))
+            .collect();
+
+        // Build column names for the INSERT
+        let column_names: Vec<Ident> = referenced_columns
+            .iter()
+            .map(|c| Ident::new(c.name.clone()))
+            .collect();
+
+        // Build VALUES rows
+        let mut values_rows = Vec::with_capacity(self.objects.len());
+        for row_builder in &self.objects {
+            let row = build_insert_row(row_builder, &referenced_columns, params)?;
+            values_rows.push(row);
+        }
+
+        // Build the RETURNING clause - all selectable columns as SelectColumn
+        let returning: Vec<SelectColumn> = self
+            .table
+            .columns
+            .iter()
+            .filter(|c| c.permissions.is_selectable)
+            .map(|c| SelectColumn::expr(Expr::Column(super::ColumnRef::new(c.name.clone()))))
+            .collect();
+
+        // Build the INSERT statement
+        let insert_stmt = InsertStmt {
+            ctes: vec![],
+            schema: Some(Ident::new(self.table.schema.clone())),
+            table: Ident::new(self.table.name.clone()),
+            columns: column_names,
+            values: InsertValues::Values(values_rows),
+            on_conflict: None,
+            returning,
+        };
+
+        // Wrap in CTE: WITH affected AS (INSERT ...) SELECT jsonb_build_object(...) FROM affected
FROM affected + let cte = Cte { + name: Ident::new("affected"), + columns: None, + query: CteQuery::Insert(insert_stmt), + materialized: None, + }; + + // Build the outer SELECT + let outer_select = SelectStmt { + ctes: vec![cte], + columns: vec![SelectColumn::Expr { + expr: Expr::JsonBuild(super::JsonBuildExpr::Object(select_columns)), + alias: None, + }], + from: Some(FromClause::Table { + schema: None, + name: Ident::new("affected"), + alias: Some(Ident::new(block_name.clone())), + }), + where_clause: None, + group_by: vec![], + having: None, + order_by: vec![], + limit: None, + offset: None, + }; + + Ok(InsertAst { + stmt: Stmt::Select(outer_select), + }) + } +} + +/// Build the select columns for jsonb_build_object from InsertSelection +fn build_select_columns( + selections: &[InsertSelection], + block_name: &str, + params: &mut ParamCollector, +) -> GraphQLResult> { + let mut pairs = Vec::new(); + + for selection in selections { + match selection { + InsertSelection::AffectedCount { alias } => { + pairs.push((string_literal(alias), count_star())); + } + InsertSelection::Records(node_builder) => { + // For now, use a simplified approach - we'll fully implement NodeBuilder later + // This creates: coalesce(jsonb_agg(jsonb_build_object(...)), '[]') + let node_expr = build_node_builder_expr(node_builder, block_name, params)?; + pairs.push(( + string_literal(&node_builder.alias), + coalesce(vec![jsonb_agg(node_expr), empty_jsonb_array()]), + )); + } + InsertSelection::Typename { alias, typename } => { + pairs.push((string_literal(alias), string_literal(typename))); + } + } + } + + Ok(pairs) +} + +/// Build expression for a NodeBuilder (simplified for now) +fn build_node_builder_expr( + node_builder: &crate::builder::NodeBuilder, + block_name: &str, + params: &mut ParamCollector, +) -> GraphQLResult { + use crate::builder::NodeSelection; + + let mut pairs = Vec::new(); + + for selection in &node_builder.selections { + match selection { + NodeSelection::Column(col_builder) => { + // Use build_column_expr which handles enum mappings + let col_expr = super::build_column_expr(col_builder, block_name); + pairs.push((col_builder.alias.clone(), col_expr)); + } + NodeSelection::Typename { alias, typename } => { + pairs.push((alias.clone(), string_literal(typename))); + } + NodeSelection::NodeId(node_id_builder) => { + // NodeId is base64 encoded JSON array of [schema, table, pk_values...] 
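+                // Generated SQL shape (schema, table, and pk names below are
+                // illustrative only):
+                //   translate(
+                //       encode(
+                //           convert_to(
+                //               jsonb_build_array('public', 'account', to_jsonb(b0.id))::text,
+                //               'utf-8'),
+                //           'base64'),
+                //       E'\n', '')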
+ let pk_exprs: Vec = node_id_builder + .columns + .iter() + .map(|c| super::func_call("to_jsonb", vec![column_ref(block_name, &c.name)])) + .collect(); + + // Build: translate(encode(convert_to(jsonb_build_array(schema, table, pk_vals...)::text, 'utf-8'), 'base64'), E'\n', '') + let mut array_args = vec![ + string_literal(&node_id_builder.schema_name), + string_literal(&node_id_builder.table_name), + ]; + array_args.extend(pk_exprs); + + let jsonb_array = super::func_call("jsonb_build_array", array_args); + let as_text = Expr::Cast { + expr: Box::new(jsonb_array), + target_type: type_name_to_sql_type("text"), + }; + let converted = + super::func_call("convert_to", vec![as_text, string_literal("utf-8")]); + let encoded = super::func_call("encode", vec![converted, string_literal("base64")]); + let translated = super::func_call( + "translate", + vec![encoded, string_literal("\n"), string_literal("")], + ); + + pairs.push((node_id_builder.alias.clone(), translated)); + } + NodeSelection::Function(func_builder) => { + // Build function expression: schema.function(row::schema.table) + let func_expr = build_function_expr(func_builder, block_name)?; + pairs.push((func_builder.alias.clone(), func_expr)); + } + NodeSelection::Connection(conn_builder) => { + // Connection selections - build subquery + let conn_expr = build_connection_subquery(conn_builder, block_name, params)?; + pairs.push((conn_builder.alias.clone(), conn_expr)); + } + NodeSelection::Node(nested_node) => { + // Nested node selections - build subquery + let node_expr = build_relation_subquery_expr(nested_node, block_name, params)?; + pairs.push((nested_node.alias.clone(), node_expr)); + } + } + } + + Ok(jsonb_build_object(pairs)) +} + +/// Build a single row of values for INSERT +fn build_insert_row( + row_builder: &InsertRowBuilder, + columns: &[&Arc], + params: &mut ParamCollector, +) -> GraphQLResult> { + let mut row = Vec::with_capacity(columns.len()); + + for column in columns { + let expr = match row_builder.row.get(&column.name) { + None => default_expr(), + Some(elem) => match elem { + InsertElemValue::Default => default_expr(), + InsertElemValue::Value(val) => add_param_from_json(params, val, &column.type_name)?, + }, + }; + row.push(expr); + } + + Ok(row) +} + +/// Build a function expression for scalar function calls +fn build_function_expr(func_builder: &FunctionBuilder, block_name: &str) -> GraphQLResult { + use crate::ast::{ColumnRef, FunctionArg, FunctionCall, SqlType}; + + // Build the row argument with type cast: block_name::schema.table + let row_arg = Expr::Cast { + expr: Box::new(Expr::Column(ColumnRef::new(block_name))), + target_type: SqlType::with_schema( + func_builder.table.schema.clone(), + func_builder.table.name.clone(), + ), + }; + + let args = vec![FunctionArg::unnamed(row_arg)]; + + let func_expr = Expr::FunctionCall(FunctionCall::with_schema( + func_builder.function.schema_name.clone(), + func_builder.function.name.clone(), + args, + )); + + // Apply type cast for special types + match &func_builder.selection { + FunctionSelection::ScalarSelf | FunctionSelection::Array => { + Ok(apply_type_cast(func_expr, func_builder.function.type_oid)) + } + FunctionSelection::Connection(_) | FunctionSelection::Node(_) => { + // These are complex selections - return raw function call + Ok(func_expr) + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + // Note: Full testing requires setting up Table, Column, etc. which is complex. + // Integration tests via pg_regress are more appropriate for full coverage. 
+ + #[test] + fn test_build_insert_row_default() { + // Test that default values are handled correctly + let row_builder = InsertRowBuilder { + row: std::collections::HashMap::new(), + }; + + // We can't easily test without Column instances, but we can verify + // the function compiles and handles empty cases + let columns: Vec<&Arc> = vec![]; + let mut params = ParamCollector::new(); + + let result = build_insert_row(&row_builder, &columns, &mut params); + assert!(result.is_ok()); + assert!(result.unwrap().is_empty()); + } +} diff --git a/src/ast/transpile_node.rs b/src/ast/transpile_node.rs new file mode 100644 index 00000000..41574b17 --- /dev/null +++ b/src/ast/transpile_node.rs @@ -0,0 +1,551 @@ +//! AST-based transpilation for NodeBuilder +//! +//! This module implements the ToAst trait for NodeBuilder, converting it +//! to a type-safe AST that can be rendered to SQL. + +use super::{ + add_param_from_json, column_ref, func_call, jsonb_build_object, string_literal, + AstBuildContext, ToAst, +}; +use crate::ast::{ + BinaryOperator, Expr, FromClause, FunctionArg, Ident, ParamCollector, SelectColumn, SelectStmt, + Stmt, +}; +use crate::builder::{ + ColumnBuilder, FunctionSelection, NodeBuilder, NodeIdBuilder, NodeIdInstance, NodeSelection, +}; +use crate::error::{GraphQLError, GraphQLResult}; +use crate::sql_types::{Table, TypeDetails}; + +/// The result of transpiling a NodeBuilder to AST for entrypoint queries +pub struct NodeAst { + /// The complete SQL statement + pub stmt: Stmt, +} + +/// The result of transpiling a NodeBuilder to an expression (for nested selections) +pub struct NodeExprAst { + /// The expression representing this node + pub expr: Expr, +} + +impl ToAst for NodeBuilder { + type Ast = NodeAst; + + fn to_ast(&self, params: &mut ParamCollector) -> GraphQLResult { + let ctx = AstBuildContext::new(); + let block_name = ctx.block_name.clone(); + + // Build the object clause from selections + let object_expr = build_node_object_expr(&self.selections, &block_name, params)?; + + // Build WHERE clause from node_id + let node_id = self + .node_id + .as_ref() + .ok_or("Expected nodeId argument missing")?; + + let where_clause = build_node_id_filter(node_id, &self.table, &block_name, params)?; + + // Build the SELECT statement + let select = SelectStmt { + ctes: vec![], + columns: vec![SelectColumn::expr(object_expr)], + from: Some(FromClause::Table { + schema: Some(Ident::new(self.table.schema.clone())), + name: Ident::new(self.table.name.clone()), + alias: Some(Ident::new(block_name)), + }), + where_clause: Some(where_clause), + group_by: vec![], + having: None, + order_by: vec![], + limit: None, + offset: None, + }; + + // Wrap in a subquery expression + Ok(NodeAst { + stmt: Stmt::Select(select), + }) + } +} + +/// Build the node_id filter as a WHERE clause expression +fn build_node_id_filter( + node_id: &NodeIdInstance, + table: &Table, + block_name: &str, + params: &mut ParamCollector, +) -> GraphQLResult { + // Validate that nodeId belongs to this table + if (&node_id.schema_name, &node_id.table_name) != (&table.schema, &table.name) { + return Err(GraphQLError::validation( + "nodeId belongs to a different collection", + )); + } + + let pk_columns = table.primary_key_columns(); + let mut conditions = Vec::new(); + + for (col, val) in pk_columns.iter().zip(node_id.values.iter()) { + let col_expr = column_ref(block_name, &col.name); + let val_expr = add_param_from_json(params, val, &col.type_name)?; + + conditions.push(Expr::BinaryOp { + left: Box::new(col_expr), + op: 
BinaryOperator::Eq, + right: Box::new(val_expr), + }); + } + + // Combine with AND + if conditions.len() == 1 { + Ok(conditions.remove(0)) + } else { + let mut combined = conditions.remove(0); + for cond in conditions { + combined = Expr::BinaryOp { + left: Box::new(combined), + op: BinaryOperator::And, + right: Box::new(cond), + }; + } + Ok(combined) + } +} + +/// Build expression for a NodeBuilder's selections (jsonb_build_object) +pub fn build_node_object_expr( + selections: &[NodeSelection], + block_name: &str, + params: &mut ParamCollector, +) -> GraphQLResult { + let mut all_pairs: Vec<(String, Expr)> = Vec::new(); + + for selection in selections { + match selection { + NodeSelection::Column(col_builder) => { + let col_expr = build_column_expr(col_builder, block_name); + all_pairs.push((col_builder.alias.clone(), col_expr)); + } + NodeSelection::Typename { alias, typename } => { + all_pairs.push((alias.clone(), string_literal(typename))); + } + NodeSelection::NodeId(node_id_builder) => { + let node_id_expr = build_node_id_expr(node_id_builder, block_name); + all_pairs.push((node_id_builder.alias.clone(), node_id_expr)); + } + NodeSelection::Function(func_builder) => { + let func_expr = build_function_expr(func_builder, block_name, params)?; + all_pairs.push((func_builder.alias.clone(), func_expr)); + } + NodeSelection::Connection(conn_builder) => { + // Connection needs a subquery - will be implemented in transpile_connection + let conn_expr = build_connection_subquery_expr(conn_builder, block_name, params)?; + all_pairs.push((conn_builder.alias.clone(), conn_expr)); + } + NodeSelection::Node(nested_node) => { + // Nested node relation - build as subquery + let node_expr = build_relation_subquery_expr(nested_node, block_name, params)?; + all_pairs.push((nested_node.alias.clone(), node_expr)); + } + } + } + + // jsonb_build_object has a limit of 100 arguments (50 pairs) + // If we have more, we need to chunk and concatenate with || + const MAX_PAIRS_PER_CALL: usize = 50; + + if all_pairs.len() <= MAX_PAIRS_PER_CALL { + Ok(jsonb_build_object(all_pairs)) + } else { + // Chunk into multiple jsonb_build_object calls and concatenate + let mut chunks: Vec = all_pairs + .chunks(MAX_PAIRS_PER_CALL) + .map(|chunk| jsonb_build_object(chunk.to_vec())) + .collect(); + + // Concatenate with || operator (JsonConcat for JSONB) + let mut result = chunks.remove(0); + for chunk in chunks { + result = Expr::BinaryOp { + left: Box::new(result), + op: BinaryOperator::JsonConcat, + right: Box::new(chunk), + }; + } + Ok(result) + } +} + +/// Build expression for a column selection +/// +/// This handles enum mappings by generating CASE expressions when the column +/// has an enum type with custom mappings defined. 
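+///
+/// For example, with a hypothetical mapping `{'active' => 'ACTIVE'}` on a
+/// column `status`, the generated expression is roughly:
+///
+/// ```sql
+/// case when b0.status = 'active' then 'ACTIVE' else b0.status::text end
+/// ```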
+pub fn build_column_expr(col_builder: &ColumnBuilder, block_name: &str) -> Expr { + let col_ref = column_ref(block_name, &col_builder.column.name); + + // Check if this is an enum column with mappings + let maybe_enum = col_builder + .column + .type_ + .as_ref() + .and_then(|t| match &t.details { + Some(TypeDetails::Enum(enum_)) => Some(enum_), + _ => None, + }); + + if let Some(enum_) = maybe_enum { + if let Some(ref mappings) = enum_.directives.mappings { + // Build CASE expression for enum mappings + // case when col = 'pg_value1' then 'graphql_value1' when col = 'pg_value2' then 'graphql_value2' else col::text end + let when_clauses: Vec<(Expr, Expr)> = mappings + .iter() + .map(|(pg_val, graphql_val)| { + ( + Expr::BinaryOp { + left: Box::new(col_ref.clone()), + op: BinaryOperator::Eq, + right: Box::new(string_literal(pg_val)), + }, + string_literal(graphql_val), + ) + }) + .collect(); + + let else_clause = Expr::Cast { + expr: Box::new(col_ref), + target_type: super::type_name_to_sql_type("text"), + }; + + return Expr::Case(super::CaseExpr::searched(when_clauses, Some(else_clause))); + } + } + + // Apply type adjustment for special OIDs + apply_type_cast(col_ref, col_builder.column.type_oid) +} + +/// Apply suffix casts for types that need special handling +/// +/// This handles types that need to be converted for GraphQL output: +/// - bigint (20) -> text (prevents precision loss in JSON) +/// - json/jsonb (114/3802) -> text via #>> '{}' +/// - numeric (1700) -> text (prevents precision loss) +/// - bigint[] (1016) -> text[] +/// - json[]/jsonb[] (199/3807) -> text[] +/// - numeric[] (1231) -> text[] +pub fn apply_type_cast(expr: Expr, type_oid: u32) -> Expr { + match type_oid { + 20 => Expr::Cast { + // bigints as text + expr: Box::new(expr), + target_type: super::type_name_to_sql_type("text"), + }, + 114 | 3802 => Expr::BinaryOp { + // json/b as stringified using #>> '{}' (empty path extracts root as text) + // Use string literal '{}' which PostgreSQL interprets as text[] for the path + left: Box::new(expr), + op: BinaryOperator::JsonPathText, + right: Box::new(string_literal("{}")), + }, + 1700 => Expr::Cast { + // numeric as text + expr: Box::new(expr), + target_type: super::type_name_to_sql_type("text"), + }, + 1016 => Expr::Cast { + // bigint arrays as array of text + expr: Box::new(expr), + target_type: super::type_name_to_sql_type("text[]"), + }, + 199 | 3807 => Expr::Cast { + // json/b array as array of text + expr: Box::new(expr), + target_type: super::type_name_to_sql_type("text[]"), + }, + 1231 => Expr::Cast { + // numeric array as array of text + expr: Box::new(expr), + target_type: super::type_name_to_sql_type("text[]"), + }, + _ => expr, + } +} + +/// Build expression for nodeId (base64 encoded JSON array) +fn build_node_id_expr(node_id_builder: &NodeIdBuilder, block_name: &str) -> Expr { + // Build: translate(encode(convert_to(jsonb_build_array(schema, table, pk_vals...)::text, 'utf-8'), 'base64'), E'\n', '') + let pk_exprs: Vec = node_id_builder + .columns + .iter() + .map(|c| func_call("to_jsonb", vec![column_ref(block_name, &c.name)])) + .collect(); + + let mut array_args = vec![ + string_literal(&node_id_builder.schema_name), + string_literal(&node_id_builder.table_name), + ]; + array_args.extend(pk_exprs); + + let jsonb_array = func_call("jsonb_build_array", array_args); + let as_text = Expr::Cast { + expr: Box::new(jsonb_array), + target_type: super::type_name_to_sql_type("text"), + }; + let converted = func_call("convert_to", vec![as_text, 
string_literal("utf-8")]); + let encoded = func_call("encode", vec![converted, string_literal("base64")]); + + func_call( + "translate", + vec![encoded, string_literal("\n"), string_literal("")], + ) +} + +/// Build expression for a function call +fn build_function_expr( + func_builder: &crate::builder::FunctionBuilder, + block_name: &str, + params: &mut ParamCollector, +) -> GraphQLResult { + // The function takes the row as a typed argument: schema.function(block_name::schema.table) + // Build the row argument with type cast + let row_arg = Expr::Cast { + expr: Box::new(Expr::Column(super::ColumnRef::new(block_name))), + target_type: super::SqlType::with_schema( + func_builder.table.schema.clone(), + func_builder.table.name.clone(), + ), + }; + + let args = vec![FunctionArg::unnamed(row_arg)]; + + let func_call = super::FunctionCall::with_schema( + func_builder.function.schema_name.clone(), + func_builder.function.name.clone(), + args, + ); + + match &func_builder.selection { + FunctionSelection::ScalarSelf | FunctionSelection::Array => { + // For scalar/array selections, the result is the function call with type cast + let func_expr = Expr::FunctionCall(func_call); + Ok(apply_type_cast(func_expr, func_builder.function.type_oid)) + } + FunctionSelection::Node(node_builder) => { + // For node selection (function returning a single row), wrap in a subquery: + // (SELECT node_object FROM schema.func(block_name::schema.table) AS func_block WHERE NOT (func_block IS NULL)) + let func_block_name = AstBuildContext::new().block_name; + + // Build the node object expression + let object_expr = + build_node_object_expr(&node_builder.selections, &func_block_name, params)?; + + // Build: NOT (func_block IS NULL) + let not_null_check = Expr::UnaryOp { + op: super::UnaryOperator::Not, + expr: Box::new(Expr::IsNull { + expr: Box::new(Expr::Column(super::ColumnRef::new(func_block_name.as_str()))), + negated: false, + }), + }; + + // Build the subquery + let subquery = SelectStmt { + ctes: vec![], + columns: vec![SelectColumn::expr(object_expr)], + from: Some(FromClause::Function { + call: func_call, + alias: Ident::new(&func_block_name), + }), + where_clause: Some(not_null_check), + group_by: vec![], + having: None, + order_by: vec![], + limit: None, + offset: None, + }; + + Ok(Expr::Subquery(Box::new(subquery))) + } + FunctionSelection::Connection(conn_builder) => { + // For connection selection (function returning setof), build a connection subquery + // that uses the function as its FROM clause instead of a table + build_function_connection_subquery(func_builder, conn_builder, block_name, params) + } + } +} + +/// Build a connection subquery for a function that returns setof +/// +/// This is used when a function returns a set of rows (setof) and we want to +/// expose that as a connection. The key difference from a regular connection +/// is that the FROM clause uses the function call instead of a table. 
+fn build_function_connection_subquery( + func_builder: &crate::builder::FunctionBuilder, + conn_builder: &crate::builder::ConnectionBuilder, + parent_block_name: &str, + params: &mut ParamCollector, +) -> GraphQLResult { + // Build the function call argument: parent_block_name::schema.table + let row_arg = Expr::Cast { + expr: Box::new(Expr::Column(super::ColumnRef::new(parent_block_name))), + target_type: super::SqlType::with_schema( + func_builder.table.schema.clone(), + func_builder.table.name.clone(), + ), + }; + + let func_call = super::FunctionCall::with_schema( + func_builder.function.schema_name.clone(), + func_builder.function.name.clone(), + vec![FunctionArg::unnamed(row_arg)], + ); + + // Build the connection subquery using the function as the FROM source + super::build_function_connection_subquery_full(conn_builder, func_call, params) +} + +/// Build a subquery expression for a connection selection +fn build_connection_subquery_expr( + conn_builder: &crate::builder::ConnectionBuilder, + parent_block_name: &str, + params: &mut ParamCollector, +) -> GraphQLResult { + // Use the full connection subquery implementation from transpile_connection + super::build_connection_subquery(conn_builder, parent_block_name, params) +} + +/// Build a subquery expression for a nested node relation +pub fn build_relation_subquery_expr( + nested_node: &NodeBuilder, + parent_block_name: &str, + params: &mut ParamCollector, +) -> GraphQLResult { + let ctx = AstBuildContext::new(); + let block_name = ctx.block_name.clone(); + + // Build the object clause from nested node's selections + let object_expr = build_node_object_expr(&nested_node.selections, &block_name, params)?; + + // Get the foreign key and direction + let fkey = nested_node + .fkey + .as_ref() + .ok_or("Internal Error: relation key")?; + let reverse_reference = nested_node + .reverse_reference + .ok_or("Internal Error: relation reverse reference")?; + + // Build the join condition + let join_condition = build_join_condition( + fkey, + reverse_reference, + &block_name, + parent_block_name, + &nested_node.table, + )?; + + // Build the subquery + let subquery = SelectStmt { + ctes: vec![], + columns: vec![SelectColumn::expr(object_expr)], + from: Some(FromClause::Table { + schema: Some(Ident::new(nested_node.table.schema.clone())), + name: Ident::new(nested_node.table.name.clone()), + alias: Some(Ident::new(block_name)), + }), + where_clause: Some(join_condition), + group_by: vec![], + having: None, + order_by: vec![], + limit: None, + offset: None, + }; + + Ok(Expr::Subquery(Box::new(subquery))) +} + +/// Build a join condition for a foreign key relationship +fn build_join_condition( + fkey: &crate::sql_types::ForeignKey, + reverse_reference: bool, + child_block_name: &str, + parent_block_name: &str, + _table: &Table, +) -> GraphQLResult { + let mut conditions = Vec::new(); + + // ForeignKey has local_table_meta and referenced_table_meta, each with column_names + // Depending on direction, pair up the columns + let pairs: Vec<(&String, &String)> = if reverse_reference { + // Parent has the referenced columns, child has the local columns + fkey.local_table_meta + .column_names + .iter() + .zip(fkey.referenced_table_meta.column_names.iter()) + .collect() + } else { + // Parent has the local columns, child has the referenced columns + fkey.referenced_table_meta + .column_names + .iter() + .zip(fkey.local_table_meta.column_names.iter()) + .collect() + }; + + for (child_col, parent_col) in pairs { + let child_expr = 
column_ref(child_block_name, child_col); + let parent_expr = column_ref(parent_block_name, parent_col); + + conditions.push(Expr::BinaryOp { + left: Box::new(child_expr), + op: BinaryOperator::Eq, + right: Box::new(parent_expr), + }); + } + + // Combine with AND + if conditions.len() == 1 { + Ok(conditions.remove(0)) + } else { + let mut combined = conditions.remove(0); + for cond in conditions { + combined = Expr::BinaryOp { + left: Box::new(combined), + op: BinaryOperator::And, + right: Box::new(cond), + }; + } + Ok(combined) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::ast::ColumnRef; + + #[test] + fn test_type_cast_bigint() { + let expr = Expr::Column(ColumnRef::new("test")); + let casted = apply_type_cast(expr, 20); + match casted { + Expr::Cast { target_type, .. } => { + assert_eq!(target_type.name, "text"); + } + _ => panic!("Expected Cast expression"), + } + } + + #[test] + fn test_type_cast_no_change() { + let expr = Expr::Column(ColumnRef::new("test")); + let result = apply_type_cast(expr.clone(), 25); // text OID, no cast needed + match result { + Expr::Column(_) => {} + _ => panic!("Expected no cast for text type"), + } + } +} diff --git a/src/ast/transpile_update.rs b/src/ast/transpile_update.rs new file mode 100644 index 00000000..05235548 --- /dev/null +++ b/src/ast/transpile_update.rs @@ -0,0 +1,370 @@ +//! AST-based transpilation for UpdateBuilder +//! +//! This module implements the ToAst trait for UpdateBuilder, converting it +//! to a type-safe AST that can be rendered to SQL. + +use super::{ + add_param_from_json, build_connection_subquery, build_filter_expr, + build_relation_subquery_expr, coalesce, column_ref, count_star, empty_jsonb_array, func_call, + func_call_schema, jsonb_agg, jsonb_build_object, string_literal, AstBuildContext, ToAst, +}; +use crate::ast::{ + BinaryOperator, CaseExpr, ColumnRef, Cte, CteQuery, Expr, FromClause, Ident, Literal, + ParamCollector, SelectColumn, SelectStmt, Stmt, UpdateStmt, +}; +use crate::builder::{ + FunctionBuilder, FunctionSelection, NodeBuilder, NodeSelection, UpdateBuilder, UpdateSelection, +}; +use crate::error::GraphQLResult; + +/// The result of transpiling an UpdateBuilder to AST +pub struct UpdateAst { + /// The complete SQL statement + pub stmt: Stmt, +} + +impl ToAst for UpdateBuilder { + type Ast = UpdateAst; + + fn to_ast(&self, params: &mut ParamCollector) -> GraphQLResult { + let ctx = AstBuildContext::new(); + let block_name = ctx.block_name.clone(); + + // Build SET clause + let set_clauses = build_set_clauses(&self.set.set, &self.table, params)?; + + // Build WHERE clause from filter + let where_clause = build_filter_expr(&self.filter, &self.table, &block_name, params)?; + + // Build RETURNING clause - all selectable columns + let returning: Vec = self + .table + .columns + .iter() + .filter(|c| c.permissions.is_selectable) + .map(|c| SelectColumn::expr(Expr::Column(ColumnRef::new(c.name.clone())))) + .collect(); + + // Build the UPDATE statement for the 'impacted' CTE + let update_stmt = UpdateStmt { + ctes: vec![], + schema: Some(Ident::new(self.table.schema.clone())), + table: Ident::new(self.table.name.clone()), + alias: Some(Ident::new(block_name.clone())), + set: set_clauses, + from: None, + where_clause, + returning, + }; + + // Build the select columns for jsonb_build_object + let select_columns = build_select_columns(&self.selections, &block_name, params)?; + + // Build the complex CTE structure: + // WITH impacted AS (UPDATE ...), total(total_count) AS (...), req(res) AS (...), 
wrapper(res) AS (...) + let stmt = + build_update_with_at_most(update_stmt, select_columns, &block_name, self.at_most); + + Ok(UpdateAst { stmt }) + } +} + +/// Build SET clauses from the SetBuilder +fn build_set_clauses( + set_map: &std::collections::HashMap, + table: &crate::sql_types::Table, + params: &mut ParamCollector, +) -> GraphQLResult> { + let mut clauses = Vec::new(); + + for (column_name, value) in set_map { + let column = table + .columns + .iter() + .find(|c| &c.name == column_name) + .expect("Failed to find column in update builder"); + + let value_expr = add_param_from_json(params, value, &column.type_name)?; + + clauses.push((Ident::new(column_name.clone()), value_expr)); + } + + Ok(clauses) +} + +/// Build the select columns for jsonb_build_object from UpdateSelection +fn build_select_columns( + selections: &[UpdateSelection], + block_name: &str, + params: &mut ParamCollector, +) -> GraphQLResult> { + let mut pairs = Vec::new(); + + for selection in selections { + match selection { + UpdateSelection::AffectedCount { alias } => { + pairs.push((string_literal(alias), count_star())); + } + UpdateSelection::Records(node_builder) => { + let node_expr = build_node_builder_expr(node_builder, block_name, params)?; + pairs.push(( + string_literal(&node_builder.alias), + coalesce(vec![jsonb_agg(node_expr), empty_jsonb_array()]), + )); + } + UpdateSelection::Typename { alias, typename } => { + pairs.push((string_literal(alias), string_literal(typename))); + } + } + } + + Ok(pairs) +} + +/// Build expression for a NodeBuilder (simplified) +fn build_node_builder_expr( + node_builder: &NodeBuilder, + block_name: &str, + params: &mut ParamCollector, +) -> GraphQLResult { + let mut pairs = Vec::new(); + + for selection in &node_builder.selections { + match selection { + NodeSelection::Column(col_builder) => { + // Use build_column_expr which handles enum mappings + let col_expr = super::build_column_expr(col_builder, block_name); + pairs.push((col_builder.alias.clone(), col_expr)); + } + NodeSelection::Typename { alias, typename } => { + pairs.push((alias.clone(), string_literal(typename))); + } + NodeSelection::NodeId(node_id_builder) => { + let pk_exprs: Vec = node_id_builder + .columns + .iter() + .map(|c| func_call("to_jsonb", vec![column_ref(block_name, &c.name)])) + .collect(); + + let mut array_args = vec![ + string_literal(&node_id_builder.schema_name), + string_literal(&node_id_builder.table_name), + ]; + array_args.extend(pk_exprs); + + let jsonb_array = func_call("jsonb_build_array", array_args); + let as_text = Expr::Cast { + expr: Box::new(jsonb_array), + target_type: super::type_name_to_sql_type("text"), + }; + let converted = func_call("convert_to", vec![as_text, string_literal("utf-8")]); + let encoded = func_call("encode", vec![converted, string_literal("base64")]); + let translated = func_call( + "translate", + vec![encoded, string_literal("\n"), string_literal("")], + ); + + pairs.push((node_id_builder.alias.clone(), translated)); + } + NodeSelection::Function(func_builder) => { + // Build function expression: schema.function(row::schema.table) + let func_expr = build_function_expr(func_builder, block_name)?; + pairs.push((func_builder.alias.clone(), func_expr)); + } + NodeSelection::Connection(conn_builder) => { + // Connection selections - build subquery + let conn_expr = build_connection_subquery(conn_builder, block_name, params)?; + pairs.push((conn_builder.alias.clone(), conn_expr)); + } + NodeSelection::Node(nested_node) => { + // Nested node selections - build 
subquery + let node_expr = build_relation_subquery_expr(nested_node, block_name, params)?; + pairs.push((nested_node.alias.clone(), node_expr)); + } + } + } + + Ok(jsonb_build_object(pairs)) +} + +/// Build the full UPDATE with at_most check using CTEs +fn build_update_with_at_most( + update_stmt: UpdateStmt, + select_columns: Vec<(Expr, Expr)>, + block_name: &str, + at_most: i64, +) -> Stmt { + // CTE 1: impacted AS (UPDATE ...) + let impacted_cte = Cte { + name: Ident::new("impacted"), + columns: None, + query: CteQuery::Update(update_stmt), + materialized: None, + }; + + // CTE 2: total(total_count) AS (SELECT count(*) FROM impacted) + let total_cte = Cte { + name: Ident::new("total"), + columns: Some(vec![Ident::new("total_count")]), + query: CteQuery::Select(SelectStmt { + ctes: vec![], + columns: vec![SelectColumn::expr(count_star())], + from: Some(FromClause::Table { + schema: None, + name: Ident::new("impacted"), + alias: None, + }), + where_clause: None, + group_by: vec![], + having: None, + order_by: vec![], + limit: None, + offset: None, + }), + materialized: None, + }; + + // CTE 3: req(res) AS (SELECT jsonb_build_object(...) FROM impacted LIMIT 1) + let req_cte = Cte { + name: Ident::new("req"), + columns: Some(vec![Ident::new("res")]), + query: CteQuery::Select(SelectStmt { + ctes: vec![], + columns: vec![SelectColumn::expr(Expr::JsonBuild( + super::JsonBuildExpr::Object(select_columns), + ))], + from: Some(FromClause::Table { + schema: None, + name: Ident::new("impacted"), + alias: Some(Ident::new(block_name)), + }), + where_clause: None, + group_by: vec![], + having: None, + order_by: vec![], + limit: Some(1), + offset: None, + }), + materialized: None, + }; + + // CTE 4: wrapper(res) AS (SELECT CASE WHEN total.total_count > at_most THEN exception ELSE req.res END ...) 
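+    //
+    // Overall rendered shape (illustrative):
+    //   WITH impacted AS (UPDATE ... RETURNING ...),
+    //        total(total_count) AS (SELECT count(*) FROM impacted),
+    //        req(res) AS (SELECT jsonb_build_object(...) FROM impacted AS b0 LIMIT 1),
+    //        wrapper(res) AS (
+    //            SELECT CASE WHEN total.total_count > {at_most}
+    //                        THEN graphql.exception('update impacts too many records')::jsonb
+    //                        ELSE req.res END
+    //            FROM total LEFT JOIN req ON true
+    //            LIMIT 1)
+    //   SELECT res FROM wrapper;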
+ let case_expr = Expr::Case(CaseExpr::searched( + vec![( + // WHEN total.total_count > at_most + Expr::BinaryOp { + left: Box::new(column_ref("total", "total_count")), + op: BinaryOperator::Gt, + right: Box::new(Expr::Literal(Literal::Integer(at_most))), + }, + // THEN graphql.exception(...)::jsonb + Expr::Cast { + expr: Box::new(func_call_schema( + "graphql", + "exception", + vec![string_literal("update impacts too many records")], + )), + target_type: super::type_name_to_sql_type("jsonb"), + }, + )], + // ELSE req.res + Some(column_ref("req", "res")), + )); + + let wrapper_cte = Cte { + name: Ident::new("wrapper"), + columns: Some(vec![Ident::new("res")]), + query: CteQuery::Select(SelectStmt { + ctes: vec![], + columns: vec![SelectColumn::expr(case_expr)], + from: Some(FromClause::Join { + left: Box::new(FromClause::Table { + schema: None, + name: Ident::new("total"), + alias: None, + }), + join_type: super::JoinType::Left, + right: Box::new(FromClause::Table { + schema: None, + name: Ident::new("req"), + alias: None, + }), + on: Some(Expr::Literal(Literal::Bool(true))), + }), + where_clause: None, + group_by: vec![], + having: None, + order_by: vec![], + limit: Some(1), + offset: None, + }), + materialized: None, + }; + + // Final SELECT from wrapper + Stmt::Select(SelectStmt { + ctes: vec![impacted_cte, total_cte, req_cte, wrapper_cte], + columns: vec![SelectColumn::expr(column_ref("wrapper", "res"))], + from: Some(FromClause::Table { + schema: None, + name: Ident::new("wrapper"), + alias: None, + }), + where_clause: None, + group_by: vec![], + having: None, + order_by: vec![], + limit: None, + offset: None, + }) +} + +/// Build a function expression for scalar function calls +fn build_function_expr(func_builder: &FunctionBuilder, block_name: &str) -> GraphQLResult { + use super::apply_type_cast; + use crate::ast::{FunctionArg, FunctionCall, SqlType}; + + // Build the row argument with type cast: block_name::schema.table + let row_arg = Expr::Cast { + expr: Box::new(Expr::Column(ColumnRef::new(block_name))), + target_type: SqlType::with_schema( + func_builder.table.schema.clone(), + func_builder.table.name.clone(), + ), + }; + + let args = vec![FunctionArg::unnamed(row_arg)]; + + let func_expr = Expr::FunctionCall(FunctionCall::with_schema( + func_builder.function.schema_name.clone(), + func_builder.function.name.clone(), + args, + )); + + // Apply type cast for special types + match &func_builder.selection { + FunctionSelection::ScalarSelf | FunctionSelection::Array => { + Ok(apply_type_cast(func_expr, func_builder.function.type_oid)) + } + FunctionSelection::Connection(_) | FunctionSelection::Node(_) => { + // These are complex selections - return raw function call + Ok(func_expr) + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_set_clause_structure() { + // Basic structure test - full testing via pg_regress + // SET clauses are represented as tuples (Ident, Expr) + let clause: (Ident, Expr) = ( + Ident::new("name"), + Expr::Literal(Literal::String("test".to_string())), + ); + assert_eq!(clause.0 .0, "name"); + } +} diff --git a/src/ast/types.rs b/src/ast/types.rs new file mode 100644 index 00000000..afea3e0e --- /dev/null +++ b/src/ast/types.rs @@ -0,0 +1,337 @@ +//! SQL type representations +//! +//! This module defines SQL types used for parameter casting and type safety. 
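+//!
+//! A few constructions and their rendered SQL, per the API below:
+//!
+//! ```ignore
+//! assert_eq!(SqlType::text().to_sql_string(), "text");
+//! assert_eq!(SqlType::text().into_array().to_sql_string(), "text[]");
+//! assert_eq!(SqlType::from_name("public.my_type[]").to_sql_string(), "public.my_type[]");
+//! ```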
+ +/// A SQL type with optional schema qualification and array support +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub struct SqlType { + /// Schema name (e.g., "public", "pg_catalog") + pub schema: Option, + /// Type name (e.g., "text", "integer", "my_enum") + pub name: String, + /// PostgreSQL OID when known (for optimization) + pub oid: Option, + /// Whether this is an array type + pub is_array: bool, +} + +impl SqlType { + /// Create a new SQL type + pub fn new(name: impl Into) -> Self { + Self { + schema: None, + name: name.into(), + oid: None, + is_array: false, + } + } + + /// Create a type with schema qualification + pub fn with_schema(schema: impl Into, name: impl Into) -> Self { + Self { + schema: Some(schema.into()), + name: name.into(), + oid: None, + is_array: false, + } + } + + /// Create a custom type (user-defined) + pub fn custom(schema: Option, name: String) -> Self { + Self { + schema, + name, + oid: None, + is_array: false, + } + } + + /// Create an array version of this type + pub fn into_array(self) -> Self { + Self { + is_array: true, + ..self + } + } + + /// Create an array version of this type (non-consuming) + pub fn as_array(&self) -> Self { + Self { + is_array: true, + ..self.clone() + } + } + + // Common PostgreSQL types with known OIDs + + /// PostgreSQL `text` type (OID 25) + pub fn text() -> Self { + Self { + schema: None, + name: "text".into(), + oid: Some(25), + is_array: false, + } + } + + /// PostgreSQL `integer` type (OID 23) + pub fn integer() -> Self { + Self { + schema: None, + name: "integer".into(), + oid: Some(23), + is_array: false, + } + } + + /// PostgreSQL `bigint` type (OID 20) + pub fn bigint() -> Self { + Self { + schema: None, + name: "bigint".into(), + oid: Some(20), + is_array: false, + } + } + + /// PostgreSQL `smallint` type (OID 21) + pub fn smallint() -> Self { + Self { + schema: None, + name: "smallint".into(), + oid: Some(21), + is_array: false, + } + } + + /// PostgreSQL `boolean` type (OID 16) + pub fn boolean() -> Self { + Self { + schema: None, + name: "boolean".into(), + oid: Some(16), + is_array: false, + } + } + + /// PostgreSQL `real` type (OID 700) + pub fn real() -> Self { + Self { + schema: None, + name: "real".into(), + oid: Some(700), + is_array: false, + } + } + + /// PostgreSQL `double precision` type (OID 701) + pub fn double_precision() -> Self { + Self { + schema: None, + name: "double precision".into(), + oid: Some(701), + is_array: false, + } + } + + /// PostgreSQL `numeric` type (OID 1700) + pub fn numeric() -> Self { + Self { + schema: None, + name: "numeric".into(), + oid: Some(1700), + is_array: false, + } + } + + /// PostgreSQL `json` type (OID 114) + pub fn json() -> Self { + Self { + schema: None, + name: "json".into(), + oid: Some(114), + is_array: false, + } + } + + /// PostgreSQL `jsonb` type (OID 3802) + pub fn jsonb() -> Self { + Self { + schema: None, + name: "jsonb".into(), + oid: Some(3802), + is_array: false, + } + } + + /// PostgreSQL `uuid` type (OID 2950) + pub fn uuid() -> Self { + Self { + schema: None, + name: "uuid".into(), + oid: Some(2950), + is_array: false, + } + } + + /// PostgreSQL `timestamp` type (OID 1114) + pub fn timestamp() -> Self { + Self { + schema: None, + name: "timestamp".into(), + oid: Some(1114), + is_array: false, + } + } + + /// PostgreSQL `timestamptz` type (OID 1184) + pub fn timestamptz() -> Self { + Self { + schema: None, + name: "timestamptz".into(), + oid: Some(1184), + is_array: false, + } + } + + /// PostgreSQL `date` type (OID 1082) + pub fn date() -> Self 
{ + Self { + schema: None, + name: "date".into(), + oid: Some(1082), + is_array: false, + } + } + + /// PostgreSQL `time` type (OID 1083) + pub fn time() -> Self { + Self { + schema: None, + name: "time".into(), + oid: Some(1083), + is_array: false, + } + } + + /// PostgreSQL `bytea` type (OID 17) + pub fn bytea() -> Self { + Self { + schema: None, + name: "bytea".into(), + oid: Some(17), + is_array: false, + } + } + + /// Create from a type name string (e.g., "text", "integer[]", "public.my_type") + pub fn from_name(type_name: &str) -> Self { + let is_array = type_name.ends_with("[]"); + let base_name = if is_array { + &type_name[..type_name.len() - 2] + } else { + type_name + }; + + // Check for schema qualification + let (schema, name) = if let Some(dot_pos) = base_name.find('.') { + ( + Some(base_name[..dot_pos].to_string()), + base_name[dot_pos + 1..].to_string(), + ) + } else { + (None, base_name.to_string()) + }; + + Self { + schema, + name, + oid: None, + is_array, + } + } + + /// Get the full type name for SQL rendering + pub fn to_sql_string(&self) -> String { + let mut result = String::new(); + + if let Some(schema) = &self.schema { + result.push_str(schema); + result.push('.'); + } + + result.push_str(&self.name); + + if self.is_array { + result.push_str("[]"); + } + + result + } + + /// Check if this type requires special handling for JSON serialization + /// (e.g., bigint needs to be converted to text to avoid precision loss) + pub fn needs_text_cast_for_json(&self) -> bool { + matches!( + self.oid, + Some(20) // bigint + | Some(1700) // numeric + ) + } + + /// Check if this is a JSON/JSONB type + pub fn is_json(&self) -> bool { + matches!(self.oid, Some(114) | Some(3802)) || self.name == "json" || self.name == "jsonb" + } +} + +impl Default for SqlType { + fn default() -> Self { + Self::text() + } +} + +impl std::fmt::Display for SqlType { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{}", self.to_sql_string()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_basic_types() { + assert_eq!(SqlType::text().to_sql_string(), "text"); + assert_eq!(SqlType::integer().to_sql_string(), "integer"); + assert_eq!(SqlType::jsonb().to_sql_string(), "jsonb"); + } + + #[test] + fn test_array_types() { + assert_eq!(SqlType::text().into_array().to_sql_string(), "text[]"); + assert_eq!(SqlType::integer().as_array().to_sql_string(), "integer[]"); + } + + #[test] + fn test_from_name() { + let t = SqlType::from_name("text"); + assert_eq!(t.name, "text"); + assert!(!t.is_array); + + let t = SqlType::from_name("integer[]"); + assert_eq!(t.name, "integer"); + assert!(t.is_array); + + let t = SqlType::from_name("public.my_type"); + assert_eq!(t.schema, Some("public".to_string())); + assert_eq!(t.name, "my_type"); + } + + #[test] + fn test_schema_qualified() { + let t = SqlType::with_schema("myschema", "mytype"); + assert_eq!(t.to_sql_string(), "myschema.mytype"); + } +} diff --git a/src/builder.rs b/src/builder.rs index 6baf1526..fe8f42ce 100644 --- a/src/builder.rs +++ b/src/builder.rs @@ -5,7 +5,7 @@ use crate::gson; use crate::parser_util::*; use crate::sql_types::*; use graphql_parser::query::*; -use serde::Serialize; +use serde::ser::{Serialize, SerializeMap, Serializer}; use std::collections::HashMap; use std::hash::Hash; use std::ops::Deref; @@ -2159,7 +2159,7 @@ where // Introspection #[allow(clippy::large_enum_variant)] -#[derive(Serialize, Clone, Debug)] +#[derive(Clone, Debug)] pub enum __FieldField { Name, Description, @@ 
-2170,7 +2170,7 @@ pub enum __FieldField { Typename { alias: String, typename: String }, } -#[derive(Serialize, Clone, Debug)] +#[derive(Clone, Debug)] pub struct __FieldSelection { pub alias: String, pub selection: __FieldField, @@ -2182,7 +2182,7 @@ pub struct __FieldBuilder { pub selections: Vec<__FieldSelection>, } -#[derive(Serialize, Clone, Debug)] +#[derive(Clone, Debug)] pub enum __EnumValueField { Name, Description, @@ -2191,7 +2191,7 @@ pub enum __EnumValueField { Typename { alias: String, typename: String }, } -#[derive(Serialize, Clone, Debug)] +#[derive(Clone, Debug)] pub struct __EnumValueSelection { pub alias: String, pub selection: __EnumValueField, @@ -2204,7 +2204,7 @@ pub struct __EnumValueBuilder { } #[allow(clippy::large_enum_variant)] -#[derive(Serialize, Clone, Debug)] +#[derive(Clone, Debug)] pub enum __InputValueField { Name, Description, @@ -2215,7 +2215,7 @@ pub enum __InputValueField { Typename { alias: String, typename: String }, } -#[derive(Serialize, Clone, Debug)] +#[derive(Clone, Debug)] pub struct __InputValueSelection { pub alias: String, pub selection: __InputValueField, @@ -2281,9 +2281,8 @@ pub struct __DirectiveBuilder { pub selections: Vec<__DirectiveSelection>, } -#[derive(Serialize, Clone, Debug)] +#[derive(Clone, Debug)] #[allow(dead_code)] -#[serde(untagged)] pub enum __SchemaField { Description, Types(Vec<__TypeBuilder>), @@ -2294,7 +2293,7 @@ pub enum __SchemaField { Typename { alias: String, typename: String }, } -#[derive(Serialize, Clone, Debug)] +#[derive(Clone, Debug)] pub struct __SchemaSelection { pub alias: String, pub selection: __SchemaField, @@ -2974,3 +2973,222 @@ impl __Schema { } } } + +// Custom Serialize implementations for introspection builders +// These serialize based on the selected fields, not all fields + +impl Serialize for __FieldBuilder { + fn serialize(&self, serializer: S) -> Result + where + S: Serializer, + { + let mut map = serializer.serialize_map(Some(self.selections.len()))?; + + for selection in &self.selections { + match &selection.selection { + __FieldField::Name => { + map.serialize_entry(&selection.alias, &self.field.name())?; + } + __FieldField::Description => { + map.serialize_entry(&selection.alias, &self.field.description())?; + } + + __FieldField::IsDeprecated => { + map.serialize_entry(&selection.alias, &self.field.is_deprecated())?; + } + __FieldField::DeprecationReason => { + map.serialize_entry(&selection.alias, &self.field.deprecation_reason())?; + } + __FieldField::Arguments(input_value_builders) => { + map.serialize_entry(&selection.alias, input_value_builders)?; + } + __FieldField::Type(t) => { + map.serialize_entry(&selection.alias, t)?; + } + __FieldField::Typename { alias, typename } => { + map.serialize_entry(alias, typename)?; + } + } + } + map.end() + } +} + +impl Serialize for __TypeBuilder { + fn serialize(&self, serializer: S) -> Result + where + S: Serializer, + { + let mut map = serializer.serialize_map(Some(self.selections.len()))?; + + for selection in &self.selections { + match &selection.selection { + __TypeField::Kind => { + map.serialize_entry(&selection.alias, &format!("{:?}", self.type_.kind()))?; + } + __TypeField::Name => { + map.serialize_entry(&selection.alias, &self.type_.name())?; + } + __TypeField::Description => { + map.serialize_entry(&selection.alias, &self.type_.description())?; + } + __TypeField::Fields(fields) => { + map.serialize_entry(&selection.alias, fields)?; + } + __TypeField::InputFields(input_field_builders) => { + map.serialize_entry(&selection.alias, 
input_field_builders)?; + } + __TypeField::Interfaces(interfaces) => { + map.serialize_entry(&selection.alias, &interfaces)?; + } + __TypeField::EnumValues(enum_values) => { + map.serialize_entry(&selection.alias, enum_values)?; + } + __TypeField::PossibleTypes(possible_types) => { + map.serialize_entry(&selection.alias, &possible_types)?; + } + __TypeField::OfType(t_builder) => { + map.serialize_entry(&selection.alias, t_builder)?; + } + __TypeField::Typename { alias, typename } => { + map.serialize_entry(alias, typename)?; + } + } + } + map.end() + } +} + +impl Serialize for __DirectiveBuilder { + fn serialize(&self, serializer: S) -> Result + where + S: Serializer, + { + let mut map = serializer.serialize_map(Some(self.selections.len()))?; + for selection in &self.selections { + match &selection.selection { + __DirectiveField::Name => { + map.serialize_entry(&selection.alias, &self.directive.name())?; + } + __DirectiveField::Description => { + map.serialize_entry(&selection.alias, &self.directive.description())?; + } + __DirectiveField::Locations => { + map.serialize_entry(&selection.alias, &self.directive.locations())?; + } + __DirectiveField::Args(args) => { + map.serialize_entry(&selection.alias, args)?; + } + __DirectiveField::IsRepeatable => { + map.serialize_entry(&selection.alias, &self.directive.is_repeatable())?; + } + __DirectiveField::Typename { alias, typename } => { + map.serialize_entry(alias, typename)?; + } + } + } + map.end() + } +} + +impl Serialize for __SchemaBuilder { + fn serialize(&self, serializer: S) -> Result + where + S: Serializer, + { + let mut map = serializer.serialize_map(Some(self.selections.len()))?; + + for selection in &self.selections { + match &selection.selection { + __SchemaField::Description => { + map.serialize_entry(&selection.alias, &self.description)?; + } + __SchemaField::Types(type_builders) => { + map.serialize_entry(&selection.alias, &type_builders)?; + } + __SchemaField::QueryType(type_builder) => { + map.serialize_entry(&selection.alias, &type_builder)?; + } + __SchemaField::MutationType(type_builder) => { + map.serialize_entry(&selection.alias, &type_builder)?; + } + __SchemaField::SubscriptionType(type_builder) => { + map.serialize_entry(&selection.alias, &type_builder)?; + } + __SchemaField::Directives(directives) => { + map.serialize_entry(&selection.alias, directives)?; + } + __SchemaField::Typename { alias, typename } => { + map.serialize_entry(alias, typename)?; + } + } + } + map.end() + } +} + +impl Serialize for __InputValueBuilder { + fn serialize(&self, serializer: S) -> Result + where + S: Serializer, + { + let mut map = serializer.serialize_map(Some(self.selections.len()))?; + + for selection in &self.selections { + match &selection.selection { + __InputValueField::Name => { + map.serialize_entry(&selection.alias, &self.input_value.name())?; + } + __InputValueField::Description => { + map.serialize_entry(&selection.alias, &self.input_value.description())?; + } + __InputValueField::Type(type_builder) => { + map.serialize_entry(&selection.alias, &type_builder)?; + } + __InputValueField::DefaultValue => { + map.serialize_entry(&selection.alias, &self.input_value.default_value())?; + } + __InputValueField::IsDeprecated => { + map.serialize_entry(&selection.alias, &self.input_value.is_deprecated())?; + } + __InputValueField::DeprecationReason => { + map.serialize_entry(&selection.alias, &self.input_value.deprecation_reason())?; + } + __InputValueField::Typename { alias, typename } => { + map.serialize_entry(alias, typename)?; + 
}
+            }
+        }
+        map.end()
+    }
+}
+
+impl Serialize for __EnumValueBuilder {
+    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
+    where
+        S: Serializer,
+    {
+        let mut map = serializer.serialize_map(Some(self.selections.len()))?;
+
+        for selection in &self.selections {
+            match &selection.selection {
+                __EnumValueField::Name => {
+                    map.serialize_entry(&selection.alias, &self.enum_value.name())?;
+                }
+                __EnumValueField::Description => {
+                    map.serialize_entry(&selection.alias, &self.enum_value.description())?;
+                }
+                __EnumValueField::IsDeprecated => {
+                    map.serialize_entry(&selection.alias, &self.enum_value.is_deprecated())?;
+                }
+                __EnumValueField::DeprecationReason => {
+                    map.serialize_entry(&selection.alias, &self.enum_value.deprecation_reason())?;
+                }
+                __EnumValueField::Typename { alias, typename } => {
+                    map.serialize_entry(alias, typename)?;
+                }
+            }
+        }
+        map.end()
+    }
+}
diff --git a/src/executor/mod.rs b/src/executor/mod.rs
new file mode 100644
index 00000000..c5777158
--- /dev/null
+++ b/src/executor/mod.rs
@@ -0,0 +1,51 @@
+//! Execution module for pg_graphql
+//!
+//! This module handles the execution of SQL statements generated from the AST.
+//! It provides:
+//!
+//! - Execution plans that can contain one or more SQL statements
+//! - Telemetry and logging for debugging
+//! - pgrx-specific execution backend
+//!
+//! # Architecture
+//!
+//! The executor is designed to support future features like nested inserts
+//! that require multiple SQL statements to be executed in a single transaction.
+//!
+//! ```text
+//! ExecutionPlan
+//! └── ExecutionStep[]
+//!     ├── stmt: Stmt (AST)
+//!     ├── params: Vec<Param>
+//!     └── depends_on: Vec<String>
+//! ```
+//!
+//! For now, all operations use single-step plans, but the infrastructure
+//! is in place for multi-step execution.
+
+mod plan;
+mod telemetry;
+
+#[cfg(any(
+    feature = "pg14",
+    feature = "pg15",
+    feature = "pg16",
+    feature = "pg17",
+    feature = "pg18"
+))]
+mod pgrx_backend;
+
+pub use plan::*;
+pub use telemetry::*;
+
+// Re-export pgrx_backend when building with pgrx
+#[cfg(any(
+    feature = "pg14",
+    feature = "pg15",
+    feature = "pg16",
+    feature = "pg17",
+    feature = "pg18"
+))]
+pub use pgrx_backend::*;
diff --git a/src/executor/pgrx_backend.rs b/src/executor/pgrx_backend.rs
new file mode 100644
index 00000000..06b48948
--- /dev/null
+++ b/src/executor/pgrx_backend.rs
@@ -0,0 +1,233 @@
+//! pgrx-specific execution backend
+//!
+//! This module handles the actual execution of SQL statements using pgrx's
+//! SPI (Server Programming Interface). It converts AST parameters to pgrx
+//! Datums and executes queries.
+
+use crate::ast::{render, Param, ParamValue};
+use crate::error::{GraphQLError, GraphQLResult};
+use crate::executor::{log_result, log_sql, ExecutionPlan};
+use pgrx::datum::DatumWithOid;
+use pgrx::prelude::*;
+use pgrx::spi::SpiClient;
+use std::time::Instant;
+
+/// Convert AST parameters to pgrx Datums
+///
+/// All parameters are converted to text and cast at the SQL level.
+/// This matches the existing behavior in transpile.rs.
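+///
+/// For example, `ParamValue::Integer(42)` is passed to SPI as the text `'42'`
+/// with a text OID, and the rendered SQL is expected to cast it back to the
+/// column's type, e.g. `($1)::bigint` (cast shown for illustration).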
+pub fn params_to_datums(params: &[Param]) -> GraphQLResult<Vec<DatumWithOid<'static>>> {
+    params.iter().map(param_to_datum).collect()
+}
+
+fn param_to_datum(param: &Param) -> GraphQLResult<DatumWithOid<'static>> {
+    let datum = match &param.value {
+        ParamValue::Null => {
+            let null: Option<String> = None;
+            null.into_datum()
+        }
+        ParamValue::Bool(b) => b.to_string().into_datum(),
+        ParamValue::String(s) => s.clone().into_datum(),
+        ParamValue::Integer(i) => i.to_string().into_datum(),
+        ParamValue::Float(f) => f.to_string().into_datum(),
+        ParamValue::Array(arr) => {
+            let strings: Vec<Option<String>> = arr
+                .iter()
+                .map(|v| match v {
+                    ParamValue::Null => None,
+                    ParamValue::String(s) => Some(s.clone()),
+                    ParamValue::Integer(i) => Some(i.to_string()),
+                    ParamValue::Float(f) => Some(f.to_string()),
+                    ParamValue::Bool(b) => Some(b.to_string()),
+                    ParamValue::Array(_) => None, // Nested arrays not supported
+                    ParamValue::Json(j) => Some(j.to_string()),
+                })
+                .collect();
+            strings.into_datum()
+        }
+        ParamValue::Json(v) => v.to_string().into_datum(),
+    };
+
+    // Use text OID for all parameters - PostgreSQL will cast as needed
+    let oid = if param.sql_type.is_array {
+        pgrx::pg_sys::TEXTARRAYOID
+    } else {
+        pgrx::pg_sys::TEXTOID
+    };
+
+    Ok(unsafe { DatumWithOid::new(datum, oid) })
+}
+
+/// Execute a query plan and return JSON result
+///
+/// This is used for SELECT queries (including those with CTEs).
+pub fn execute_query(plan: &ExecutionPlan) -> GraphQLResult<serde_json::Value> {
+    if !plan.is_single_step() {
+        return Err(GraphQLError::internal(
+            "Query execution currently only supports single-step plans",
+        ));
+    }
+
+    let step = plan
+        .main_step()
+        .ok_or_else(|| GraphQLError::internal("No main step in execution plan"))?;
+
+    let sql = render(&step.stmt);
+    let datums = params_to_datums(&step.params)?;
+
+    log_sql(&sql, &step.params);
+    let start = Instant::now();
+
+    let result = Spi::connect(|client| {
+        let result = client.select(&sql, Some(1), &datums)?;
+        if result.is_empty() {
+            Ok(serde_json::Value::Null)
+        } else {
+            let jsonb: pgrx::JsonB = result
+                .first()
+                .get(1)?
+                .ok_or_else(|| spi::Error::InvalidPosition)?;
+            Ok(jsonb.0)
+        }
+    });
+
+    match &result {
+        Ok(_) => log_result(start, true),
+        Err(e) => {
+            log_result(start, false);
+            return Err(GraphQLError::sql_execution(format!("SPI error: {:?}", e)));
+        }
+    }
+
+    result.map_err(|e: spi::Error| GraphQLError::sql_execution(format!("SPI error: {:?}", e)))
+}
+
+/// Execute a mutation plan and return JSON result
+///
+/// This is used for INSERT, UPDATE, DELETE operations.
+/// It takes a mutable SPI client to participate in the current transaction.
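+///
+/// Sketch of a call site (plan construction and client acquisition assumed):
+///
+/// ```ignore
+/// let plan = ExecutionPlan::single(stmt, params, "update mutation");
+/// let (json, client) = execute_mutation(&plan, client)?;
+/// ```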
+pub fn execute_mutation<'conn>( + plan: &ExecutionPlan, + client: &'conn mut SpiClient<'conn>, +) -> GraphQLResult<(serde_json::Value, &'conn mut SpiClient<'conn>)> { + if !plan.is_single_step() { + // TODO: For nested inserts, implement multi-step execution + return Err(GraphQLError::internal( + "Mutation execution currently only supports single-step plans", + )); + } + + let step = plan.main_step().ok_or_else(|| { + GraphQLError::internal("No main step in execution plan") + })?; + + let sql = render(&step.stmt); + let datums = params_to_datums(&step.params)?; + + log_sql(&sql, &step.params); + let start = Instant::now(); + + let result = client.update(&sql, None, &datums); + + match &result { + Ok(_) => {} + Err(_) => { + log_result(start, false); + return Err(GraphQLError::sql_execution( + "Failed to execute mutation", + )); + } + } + + let res_q = result.map_err(|_| { + GraphQLError::sql_execution("Internal Error: Failed to execute transpiled query") + })?; + + let jsonb: pgrx::JsonB = match res_q.first().get::(1) { + Ok(Some(dat)) => dat, + Ok(None) => pgrx::JsonB(serde_json::Value::Null), + Err(e) => { + log_result(start, false); + return Err(GraphQLError::sql_generation(format!( + "Internal Error: Failed to load result from transpiled query: {e}" + ))); + } + }; + + log_result(start, true); + Ok((jsonb.0, client)) +} + +/// Execute a multi-step mutation plan (for future nested inserts) +/// +/// This executes all steps in dependency order within the same transaction. +#[allow(dead_code)] +pub fn execute_multi_step_mutation<'conn>( + plan: &ExecutionPlan, + client: &'conn mut SpiClient<'conn>, +) -> GraphQLResult<(Vec, &'conn mut SpiClient<'conn>)> { + let mut results = Vec::with_capacity(plan.steps.len()); + + for step in plan.steps_in_order() { + let sql = render(&step.stmt); + let datums = params_to_datums(&step.params)?; + + log_sql(&sql, &step.params); + let start = Instant::now(); + + let result = client.update(&sql, None, &datums).map_err(|_| { + log_result(start, false); + GraphQLError::sql_execution(format!( + "Failed to execute step '{}': {}", + step.id, step.description + )) + })?; + + let jsonb: pgrx::JsonB = match result.first().get::(1) { + Ok(Some(dat)) => dat, + Ok(None) => pgrx::JsonB(serde_json::Value::Null), + Err(e) => { + log_result(start, false); + return Err(GraphQLError::sql_generation(format!( + "Failed to get result from step '{}': {e}", + step.id + ))); + } + }; + + log_result(start, true); + results.push(jsonb.0); + } + + Ok((results, client)) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_param_value_conversion() { + // Test that ParamValue conversions produce the expected string representations + let string_val = ParamValue::String("hello".into()); + assert!(!string_val.is_null()); + + let int_val = ParamValue::Integer(42); + assert_eq!(int_val.to_sql_literal(), Some("42".to_string())); + + let null_val = ParamValue::Null; + assert!(null_val.is_null()); + assert_eq!(null_val.to_sql_literal(), None); + + let array_val = ParamValue::Array(vec![ + ParamValue::Integer(1), + ParamValue::Integer(2), + ]); + let literal = array_val.to_sql_literal(); + assert!(literal.is_some()); + assert!(literal.unwrap().contains("1")); + } + + // Note: pgrx-specific tests (params_to_datums, execute_*) require + // a PostgreSQL connection and are tested via pg_regress instead. +} diff --git a/src/executor/plan.rs b/src/executor/plan.rs new file mode 100644 index 00000000..0fd4d20f --- /dev/null +++ b/src/executor/plan.rs @@ -0,0 +1,299 @@ +//! 
diff --git a/src/executor/plan.rs b/src/executor/plan.rs
new file mode 100644
index 00000000..0fd4d20f
--- /dev/null
+++ b/src/executor/plan.rs
@@ -0,0 +1,299 @@
+//! Execution plans for SQL statements
+//!
+//! An execution plan represents one or more SQL statements that need to be
+//! executed together. Currently, all GraphQL operations result in single-step
+//! plans, but the infrastructure supports multi-step plans for future features
+//! like nested inserts.
+
+use crate::ast::{Param, Stmt};
+use std::time::Instant;
+
+/// An execution plan containing one or more SQL statements
+#[derive(Debug)]
+pub struct ExecutionPlan {
+    /// The steps to execute
+    pub steps: Vec<ExecutionStep>,
+    /// Telemetry information for debugging
+    pub telemetry: PlanTelemetry,
+}
+
+/// A single step in an execution plan
+#[derive(Debug)]
+pub struct ExecutionStep {
+    /// Unique identifier for this step
+    pub id: String,
+    /// The SQL statement (as AST)
+    pub stmt: Stmt,
+    /// Parameters for this statement
+    pub params: Vec<Param>,
+    /// Human-readable description of what this step does
+    pub description: String,
+    /// IDs of steps that must complete before this one
+    pub depends_on: Vec<String>,
+}
+
+/// Telemetry information attached to an execution plan
+#[derive(Debug)]
+pub struct PlanTelemetry {
+    /// The original GraphQL query (if available)
+    pub graphql_query: Option<String>,
+    /// The operation name (if specified)
+    pub operation_name: Option<String>,
+    /// When the plan was created
+    pub created_at: Instant,
+    /// Custom tags for categorization
+    pub tags: Vec<(String, String)>,
+}
+
+impl Default for PlanTelemetry {
+    fn default() -> Self {
+        Self {
+            graphql_query: None,
+            operation_name: None,
+            created_at: Instant::now(),
+            tags: Vec::new(),
+        }
+    }
+}
+
+impl PlanTelemetry {
+    pub fn new() -> Self {
+        Self::default()
+    }
+
+    /// Set the GraphQL query
+    pub fn with_query(mut self, query: impl Into<String>) -> Self {
+        self.graphql_query = Some(query.into());
+        self
+    }
+
+    /// Set the operation name
+    pub fn with_operation_name(mut self, name: impl Into<String>) -> Self {
+        self.operation_name = Some(name.into());
+        self
+    }
+
+    /// Add a tag
+    pub fn with_tag(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
+        self.tags.push((key.into(), value.into()));
+        self
+    }
+
+    /// Get elapsed time since plan creation
+    pub fn elapsed_ms(&self) -> u128 {
+        self.created_at.elapsed().as_millis()
+    }
+}
+
+impl ExecutionPlan {
+    /// Create a single-step execution plan
+    ///
+    /// This is the most common case: one GraphQL operation = one SQL statement.
+    pub fn single(stmt: Stmt, params: Vec<Param>, description: impl Into<String>) -> Self {
+        Self {
+            steps: vec![ExecutionStep {
+                id: "main".to_string(),
+                stmt,
+                params,
+                description: description.into(),
+                depends_on: vec![],
+            }],
+            telemetry: PlanTelemetry::default(),
+        }
+    }
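+
+    // Usage sketch: `stmt`, `params`, `query_text`, and `operation_name` stand
+    // in for transpiler output and are not defined here; the builder methods
+    // shown are the ones defined on this type below.
+    //
+    //     let plan = ExecutionPlan::single(stmt, params, "allAccounts connection")
+    //         .with_graphql_context(query_text, operation_name)
+    //         .with_tag("type", "query");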
+    /// Create a multi-step execution plan
+    ///
+    /// Use this for operations that require multiple SQL statements,
+    /// such as nested inserts.
+    pub fn multi(steps: Vec<ExecutionStep>) -> Self {
+        Self {
+            steps,
+            telemetry: PlanTelemetry::default(),
+        }
+    }
+
+    /// Attach GraphQL context to the plan
+    pub fn with_graphql_context(
+        mut self,
+        query: impl Into<String>,
+        operation_name: Option<impl Into<String>>,
+    ) -> Self {
+        self.telemetry.graphql_query = Some(query.into());
+        self.telemetry.operation_name = operation_name.map(|n| n.into());
+        self
+    }
+
+    /// Attach telemetry to the plan
+    pub fn with_telemetry(mut self, telemetry: PlanTelemetry) -> Self {
+        self.telemetry = telemetry;
+        self
+    }
+
+    /// Add a tag to the plan's telemetry
+    pub fn with_tag(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
+        self.telemetry.tags.push((key.into(), value.into()));
+        self
+    }
+
+    /// Check if this is a single-step plan
+    pub fn is_single_step(&self) -> bool {
+        self.steps.len() == 1
+    }
+
+    /// Get the number of steps in the plan
+    pub fn step_count(&self) -> usize {
+        self.steps.len()
+    }
+
+    /// Get the main step (for single-step plans)
+    pub fn main_step(&self) -> Option<&ExecutionStep> {
+        if self.steps.len() == 1 {
+            self.steps.first()
+        } else {
+            self.steps.iter().find(|s| s.id == "main")
+        }
+    }
+
+    /// Get a step by ID
+    pub fn get_step(&self, id: &str) -> Option<&ExecutionStep> {
+        self.steps.iter().find(|s| s.id == id)
+    }
+
+    /// Get steps in execution order (respecting dependencies)
+    ///
+    /// Returns steps sorted so that dependencies come before dependents.
+    /// Panics if there are circular dependencies.
+    pub fn steps_in_order(&self) -> Vec<&ExecutionStep> {
+        // Simple topological sort
+        let mut result = Vec::new();
+        let mut completed: std::collections::HashSet<&str> = std::collections::HashSet::new();
+
+        while result.len() < self.steps.len() {
+            let mut made_progress = false;
+
+            for step in &self.steps {
+                if completed.contains(step.id.as_str()) {
+                    continue;
+                }
+
+                let deps_satisfied = step
+                    .depends_on
+                    .iter()
+                    .all(|dep| completed.contains(dep.as_str()));
+
+                if deps_satisfied {
+                    result.push(step);
+                    completed.insert(&step.id);
+                    made_progress = true;
+                }
+            }
+
+            if !made_progress && result.len() < self.steps.len() {
+                panic!("Circular dependency detected in execution plan");
+            }
+        }
+
+        result
+    }
+}
+
+impl ExecutionStep {
+    /// Create a new execution step
+    pub fn new(
+        id: impl Into<String>,
+        stmt: Stmt,
+        params: Vec<Param>,
+        description: impl Into<String>,
+    ) -> Self {
+        Self {
+            id: id.into(),
+            stmt,
+            params,
+            description: description.into(),
+            depends_on: vec![],
+        }
+    }
+
+    /// Add a dependency on another step
+    pub fn depends_on(mut self, step_id: impl Into<String>) -> Self {
+        self.depends_on.push(step_id.into());
+        self
+    }
+
+    /// Add multiple dependencies
+    pub fn depends_on_all(mut self, step_ids: Vec<impl Into<String>>) -> Self {
+        for id in step_ids {
+            self.depends_on.push(id.into());
+        }
+        self
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use crate::ast::*;
+
+    fn dummy_stmt() -> Stmt {
+        Stmt::Select(SelectStmt::columns(vec![SelectColumn::star()]))
+    }
+
+    #[test]
+    fn test_single_step_plan() {
+        let plan = ExecutionPlan::single(dummy_stmt(), vec![], "test query");
+
+        assert!(plan.is_single_step());
+        assert_eq!(plan.step_count(), 1);
+        assert!(plan.main_step().is_some());
+    }
+
+    #[test]
+    fn test_multi_step_plan() {
+        let steps = vec![
+            ExecutionStep::new("step1", dummy_stmt(), vec![], "first"),
+            ExecutionStep::new("step2", dummy_stmt(), vec![], "second").depends_on("step1"),
+        ];
+
+        let plan = ExecutionPlan::multi(steps);
+
+        assert!(!plan.is_single_step());
+        assert_eq!(plan.step_count(), 2);
+    }
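+
+    // Illustrative test sketch (addition, not in the original patch): documents
+    // the fallback in main_step(), which on multi-step plans returns the step
+    // literally named "main".
+    #[test]
+    fn test_main_step_on_multi_step_plan() {
+        let steps = vec![
+            ExecutionStep::new("aux", dummy_stmt(), vec![], "auxiliary"),
+            ExecutionStep::new("main", dummy_stmt(), vec![], "primary"),
+        ];
+        let plan = ExecutionPlan::multi(steps);
+        assert_eq!(plan.main_step().map(|s| s.id.as_str()), Some("main"));
+    }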
+    #[test]
+    fn test_steps_in_order() {
+        let steps = vec![
+            ExecutionStep::new("c", dummy_stmt(), vec![], "third")
+                .depends_on("a")
+                .depends_on("b"),
+            ExecutionStep::new("a", dummy_stmt(), vec![], "first"),
+            ExecutionStep::new("b", dummy_stmt(), vec![], "second").depends_on("a"),
+        ];
+
+        let plan = ExecutionPlan::multi(steps);
+        let ordered: Vec<&str> = plan.steps_in_order().iter().map(|s| s.id.as_str()).collect();
+
+        // a must come before b, b must come before c
+        let a_idx = ordered.iter().position(|&id| id == "a").unwrap();
+        let b_idx = ordered.iter().position(|&id| id == "b").unwrap();
+        let c_idx = ordered.iter().position(|&id| id == "c").unwrap();
+
+        assert!(a_idx < b_idx);
+        assert!(b_idx < c_idx);
+    }
+
+    #[test]
+    fn test_telemetry() {
+        let plan = ExecutionPlan::single(dummy_stmt(), vec![], "test")
+            .with_graphql_context("query { users { id } }", Some("GetUsers"))
+            .with_tag("type", "query");
+
+        assert!(plan.telemetry.graphql_query.is_some());
+        assert_eq!(
+            plan.telemetry.operation_name,
+            Some("GetUsers".to_string())
+        );
+        assert!(!plan.telemetry.tags.is_empty());
+    }
+}
diff --git a/src/executor/telemetry.rs b/src/executor/telemetry.rs
new file mode 100644
index 00000000..10f3075b
--- /dev/null
+++ b/src/executor/telemetry.rs
@@ -0,0 +1,289 @@
+//! Telemetry and logging for SQL execution
+//!
+//! This module provides configurable logging for debugging SQL generation
+//! and execution. It can be controlled via environment variables or
+//! PostgreSQL GUC settings.
+//!
+//! # Configuration
+//!
+//! Set the `PG_GRAPHQL_LOG_LEVEL` environment variable to one of:
+//! - `off` - No logging (default)
+//! - `basic` - Log SQL and timing only
+//! - `detailed` - Log SQL, parameters, and GraphQL context
+//! - `debug` - Log everything including AST dumps
+//!
+//! # Example
+//!
+//! ```bash
+//! export PG_GRAPHQL_LOG_LEVEL=detailed
+//! ```
+
+use crate::ast::Param;
+use std::time::Instant;
+
+/// Log level for SQL telemetry
+#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
+pub enum LogLevel {
+    /// No logging
+    Off = 0,
+    /// Basic info: SQL and timing
+    Basic = 1,
+    /// Detailed: SQL, parameters, GraphQL context
+    Detailed = 2,
+    /// Debug: Everything including AST
+    Debug = 3,
+}
+
+impl LogLevel {
+    /// Parse from string
+    pub fn from_str(s: &str) -> Self {
+        match s.to_lowercase().as_str() {
+            "basic" => Self::Basic,
+            "detailed" => Self::Detailed,
+            "debug" => Self::Debug,
+            _ => Self::Off,
+        }
+    }
+}
+
+impl Default for LogLevel {
+    fn default() -> Self {
+        Self::Off
+    }
+}
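+
+// Usage sketch: the levels derive `Ord`, so call sites gate output with
+// ordinary comparisons instead of matching on each variant:
+//
+//     if get_log_level() >= LogLevel::Detailed {
+//         // also emit parameter values, not just the SQL text
+//     }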
+/// Get the current log level from the environment
+///
+/// Checks the `PG_GRAPHQL_LOG_LEVEL` environment variable.
+pub fn get_log_level() -> LogLevel {
+    std::env::var("PG_GRAPHQL_LOG_LEVEL")
+        .map(|s| LogLevel::from_str(&s))
+        .unwrap_or(LogLevel::Off)
+}
+
+/// Log an execution plan
+#[cfg(any(
+    feature = "pg14",
+    feature = "pg15",
+    feature = "pg16",
+    feature = "pg17",
+    feature = "pg18"
+))]
+pub fn log_plan(plan: &super::ExecutionPlan) {
+    let level = get_log_level();
+    if level < LogLevel::Basic {
+        return;
+    }
+
+    pgrx::info!(
+        "pg_graphql: Execution plan with {} step(s), elapsed: {}ms",
+        plan.steps.len(),
+        plan.telemetry.elapsed_ms()
+    );
+
+    if level >= LogLevel::Detailed {
+        if let Some(query) = &plan.telemetry.graphql_query {
+            let truncated = if query.len() > 500 {
+                format!("{}...", &query[..500])
+            } else {
+                query.clone()
+            };
+            pgrx::info!("pg_graphql: GraphQL: {}", truncated);
+        }
+
+        if let Some(op_name) = &plan.telemetry.operation_name {
+            pgrx::info!("pg_graphql: Operation: {}", op_name);
+        }
+
+        for (key, value) in &plan.telemetry.tags {
+            pgrx::info!("pg_graphql: Tag {}: {}", key, value);
+        }
+    }
+}
+
+/// Log SQL execution
+#[cfg(any(
+    feature = "pg14",
+    feature = "pg15",
+    feature = "pg16",
+    feature = "pg17",
+    feature = "pg18"
+))]
+pub fn log_sql(sql: &str, params: &[Param]) {
+    let level = get_log_level();
+    if level < LogLevel::Basic {
+        return;
+    }
+
+    // Truncate very long SQL for basic logging
+    let sql_display = if level >= LogLevel::Detailed {
+        sql.to_string()
+    } else if sql.len() > 1000 {
+        format!("{}...", &sql[..1000])
+    } else {
+        sql.to_string()
+    };
+
+    pgrx::info!("pg_graphql: SQL:\n{}", sql_display);
+
+    if level >= LogLevel::Detailed && !params.is_empty() {
+        for param in params {
+            pgrx::info!(
+                "pg_graphql: Param ${}: {:?} ({})",
+                param.index,
+                param.value,
+                param.sql_type
+            );
+        }
+    }
+}
+
+/// Log execution result
+#[cfg(any(
+    feature = "pg14",
+    feature = "pg15",
+    feature = "pg16",
+    feature = "pg17",
+    feature = "pg18"
+))]
+pub fn log_result(start: Instant, success: bool) {
+    let level = get_log_level();
+    if level < LogLevel::Basic {
+        return;
+    }
+
+    let duration_ms = start.elapsed().as_millis();
+    let status = if success { "completed" } else { "failed" };
+
+    pgrx::info!("pg_graphql: Execution {} in {}ms", status, duration_ms);
+}
+
+/// Log an error
+#[cfg(any(
+    feature = "pg14",
+    feature = "pg15",
+    feature = "pg16",
+    feature = "pg17",
+    feature = "pg18"
+))]
+pub fn log_error(context: &str, error: &str) {
+    let level = get_log_level();
+    if level < LogLevel::Basic {
+        return;
+    }
+
+    pgrx::warning!("pg_graphql: Error in {}: {}", context, error);
+}
+
+// Non-pgrx versions for testing
+
+#[cfg(not(any(
+    feature = "pg14",
+    feature = "pg15",
+    feature = "pg16",
+    feature = "pg17",
+    feature = "pg18"
+)))]
+pub fn log_plan(_plan: &super::ExecutionPlan) {}
+
+#[cfg(not(any(
+    feature = "pg14",
+    feature = "pg15",
+    feature = "pg16",
+    feature = "pg17",
+    feature = "pg18"
+)))]
+pub fn log_sql(_sql: &str, _params: &[Param]) {}
+
+#[cfg(not(any(
+    feature = "pg14",
+    feature = "pg15",
+    feature = "pg16",
+    feature = "pg17",
+    feature = "pg18"
+)))]
+pub fn log_result(_start: Instant, _success: bool) {}
+
+#[cfg(not(any(
+    feature = "pg14",
+    feature = "pg15",
+    feature = "pg16",
+    feature = "pg17",
+    feature = "pg18"
+)))]
+pub fn log_error(_context: &str, _error: &str) {}
+
+/// A guard that logs execution timing on drop
+pub struct ExecutionTimer {
+    start: Instant,
+    context: String,
+    #[allow(dead_code)]
+    logged: bool,
+}
+
+impl ExecutionTimer {
+    /// Start a new execution timer
+    pub fn new(context: impl Into<String>) -> Self {
+        Self {
+            start: Instant::now(),
+            context: context.into(),
+            logged: false,
+        }
+    }
+
+    /// Get elapsed time in milliseconds
+    pub fn elapsed_ms(&self) -> u128 {
+        self.start.elapsed().as_millis()
+    }
+
+    /// Mark as successful and log
+    pub fn success(mut self) {
+        self.logged = true;
+        log_result(self.start, true);
+    }
+
+    /// Mark as failed and log
+    pub fn failure(mut self, error: &str) {
+        self.logged = true;
+        log_error(&self.context, error);
+        log_result(self.start, false);
+    }
+}
+
+impl Drop for ExecutionTimer {
+    fn drop(&mut self) {
+        // Log if not already logged (implicit failure)
+        if !self.logged {
+            log_result(self.start, false);
+        }
+    }
+}
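+
+// Usage sketch (`run_step` is a stand-in, not a real function): the Drop impl
+// above records an implicit failure, so a step that early-returns without
+// calling success()/failure() still shows up in the log as failed.
+//
+//     let timer = ExecutionTimer::new("execute_query");
+//     match run_step() {
+//         Ok(v) => { timer.success(); Ok(v) }
+//         Err(e) => { timer.failure(&e.to_string()); Err(e) }
+//     }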
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn test_log_level_parsing() {
+        assert_eq!(LogLevel::from_str("off"), LogLevel::Off);
+        assert_eq!(LogLevel::from_str("basic"), LogLevel::Basic);
+        assert_eq!(LogLevel::from_str("detailed"), LogLevel::Detailed);
+        assert_eq!(LogLevel::from_str("debug"), LogLevel::Debug);
+        assert_eq!(LogLevel::from_str("BASIC"), LogLevel::Basic);
+        assert_eq!(LogLevel::from_str("invalid"), LogLevel::Off);
+    }
+
+    #[test]
+    fn test_log_level_ordering() {
+        assert!(LogLevel::Off < LogLevel::Basic);
+        assert!(LogLevel::Basic < LogLevel::Detailed);
+        assert!(LogLevel::Detailed < LogLevel::Debug);
+    }
+
+    #[test]
+    fn test_execution_timer() {
+        let timer = ExecutionTimer::new("test");
+        std::thread::sleep(std::time::Duration::from_millis(10));
+        assert!(timer.elapsed_ms() >= 10);
+    }
+}
diff --git a/src/lib.rs b/src/lib.rs
index 88ba19d3..8fd23161 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -5,9 +5,11 @@ use pgrx::*;
 use resolve::resolve_inner;
 use serde_json::json;
 
+pub mod ast;
 mod builder;
 mod constants;
 mod error;
+pub mod executor;
 mod graphql;
 mod gson;
 mod merge;
@@ -15,7 +17,6 @@ mod omit;
 mod parser_util;
 mod resolve;
 mod sql_types;
-mod transpile;
 
 pg_module_magic!();
diff --git a/src/resolve.rs b/src/resolve.rs
index 2f734bba..76940b9c 100644
--- a/src/resolve.rs
+++ b/src/resolve.rs
@@ -8,7 +8,6 @@ use crate::graphql::*;
 use crate::omit::*;
 use crate::parser_util::*;
 use crate::sql_types::get_one_readonly;
-use crate::transpile::{MutationEntrypoint, QueryEntrypoint};
 use graphql_parser::query::Selection;
 use graphql_parser::query::{
     Definition, Document, FragmentDefinition, Mutation, OperationDefinition, Query, SelectionSet,
@@ -17,6 +16,8 @@
 use itertools::Itertools;
 use serde_json::{json, Value};
 
+use crate::ast::AstExecutable;
+
 #[allow(non_snake_case)]
 pub fn resolve_inner<'a, T>(
     document: Document<'a, T>,
@@ -220,14 +221,18 @@
             );
 
             match connection_builder {
-                Ok(builder) => match builder.execute() {
-                    Ok(d) => {
-                        res_data[alias_or_name(selection)] = d;
+                Ok(builder) => {
+                    let result = builder.execute_via_ast();
+
+                    match result {
+                        Ok(d) => {
+                            res_data[alias_or_name(selection)] = d;
+                        }
+                        Err(msg) => res_errors.push(ErrorMessage {
+                            message: msg.to_string(),
+                        }),
+                    }
-                    Err(msg) => res_errors.push(ErrorMessage {
-                        message: msg.to_string(),
-                    }),
-                },
+                }
                 Err(msg) => res_errors.push(ErrorMessage {
                     message: msg.to_string(),
                 }),
@@ -244,14 +249,18 @@
             );
 
             match node_builder {
-                Ok(builder) => match builder.execute() {
-                    Ok(d) => {
-                        res_data[alias_or_name(selection)] = d;
+                Ok(builder) => {
+                    let result = builder.execute_via_ast();
+
+                    match result {
+                        Ok(d) => {
+                            res_data[alias_or_name(selection)] = d;
+                        }
+                        Err(msg) => res_errors.push(ErrorMessage {
+                            message: msg.to_string(),
+                        }),
+                    }
-                    Err(msg) => res_errors.push(ErrorMessage {
-                        message:
msg.to_string(), - }), - }, + } Err(msg) => res_errors.push(ErrorMessage { message: msg.to_string(), }), @@ -318,9 +327,9 @@ where match function_call_builder { Ok(builder) => { - match ::execute( - &builder, - ) { + let result = builder.execute_via_ast(); + + match result { Ok(d) => { res_data[alias_or_name(selection)] = d; } @@ -451,7 +460,7 @@ where } }; - let (d, conn) = builder.execute(conn)?; + let (d, conn) = builder.execute_mutation_via_ast(conn)?; res_data[alias_or_name(selection)] = d; conn @@ -470,7 +479,7 @@ where } }; - let (d, conn) = builder.execute(conn)?; + let (d, conn) = builder.execute_mutation_via_ast(conn)?; res_data[alias_or_name(selection)] = d; conn } @@ -488,7 +497,7 @@ where } }; - let (d, conn) = builder.execute(conn)?; + let (d, conn) = builder.execute_mutation_via_ast(conn)?; res_data[alias_or_name(selection)] = d; conn } @@ -513,9 +522,7 @@ where }; let (d, conn) = - ::execute( - &builder, conn, - )?; + builder.execute_mutation_via_ast(conn)?; res_data[alias_or_name(selection)] = d; conn } diff --git a/src/transpile.rs b/src/transpile.rs deleted file mode 100644 index aa6fbaa2..00000000 --- a/src/transpile.rs +++ /dev/null @@ -1,1954 +0,0 @@ -use crate::builder::*; -use crate::constants::aggregate; -use crate::error::{GraphQLError, GraphQLResult}; -use crate::graphql::*; -use crate::sql_types::{Column, ForeignKey, ForeignKeyTableInfo, Function, Table, TypeDetails}; -use itertools::Itertools; -use pgrx::datum::DatumWithOid; -use pgrx::pg_sys::PgBuiltInOids; -use pgrx::prelude::*; -use pgrx::spi::SpiClient; -use pgrx::{direct_function_call, JsonB}; -use serde::ser::{Serialize, SerializeMap, Serializer}; -use std::cmp; -use std::collections::HashSet; -use std::sync::Arc; - -pub fn quote_ident(ident: &str) -> String { - unsafe { - direct_function_call::(pg_sys::quote_ident, &[ident.into_datum()]) - .expect("failed to quote ident") - } -} - -pub fn quote_literal(ident: &str) -> String { - unsafe { - direct_function_call::(pg_sys::quote_literal, &[ident.into_datum()]) - .expect("failed to quote literal") - } -} - -pub fn rand_block_name() -> String { - use rand::distributions::Alphanumeric; - use rand::{thread_rng, Rng}; - quote_ident( - &thread_rng() - .sample_iter(&Alphanumeric) - .take(7) - .map(char::from) - .collect::() - .to_lowercase(), - ) -} - -pub trait MutationEntrypoint<'conn> { - fn to_sql_entrypoint(&self, param_context: &mut ParamContext) -> GraphQLResult; - - fn execute<'c>( - &self, - conn: &'c mut SpiClient<'conn>, - ) -> GraphQLResult<(serde_json::Value, &'c mut SpiClient<'conn>)> { - let mut param_context = ParamContext { params: vec![] }; - let sql = &self.to_sql_entrypoint(&mut param_context); - let sql = match sql { - Ok(sql) => sql, - Err(err) => { - return Err(err.clone()); - } - }; - - let res_q = conn.update(sql, None, ¶m_context.params).map_err(|_| { - GraphQLError::sql_execution("Internal Error: Failed to execute transpiled query") - })?; - - let res: pgrx::JsonB = match res_q.first().get::(1) { - Ok(Some(dat)) => dat, - Ok(None) => JsonB(serde_json::Value::Null), - Err(e) => { - return Err(GraphQLError::sql_generation(format!( - "Internal Error: Failed to load result from transpiled query: {e}" - ))); - } - }; - - Ok((res.0, conn)) - } -} - -pub trait QueryEntrypoint { - fn to_sql_entrypoint(&self, param_context: &mut ParamContext) -> GraphQLResult; - - fn execute(&self) -> GraphQLResult { - let mut param_context = ParamContext { params: vec![] }; - let sql = &self.to_sql_entrypoint(&mut param_context); - let sql = match sql { - 
Ok(sql) => sql, - Err(err) => { - return Err(err.clone()); - } - }; - - let spi_result: Result, spi::Error> = Spi::connect(|c| { - let val = c.select(sql, Some(1), ¶m_context.params)?; - // Get a value from the query - if val.is_empty() { - Ok(None) - } else { - val.first().get::(1) - } - }); - - match spi_result { - Ok(Some(jsonb)) => Ok(jsonb.0), - Ok(None) => Ok(serde_json::Value::Null), - _ => Err(GraphQLError::internal( - "Internal Error: Failed to execute transpiled query", - )), - } - } -} - -impl Table { - fn to_selectable_columns_clause(&self) -> String { - self.columns - .iter() - .filter(|x| x.permissions.is_selectable) - .map(|x| quote_ident(&x.name)) - .collect::>() - .join(", ") - } - - /// a priamry key tuple clause selects the columns of the primary key as a composite record - /// that is useful in "has_previous_page" by letting us compare records on a known unique key - fn to_primary_key_tuple_clause(&self, block_name: &str) -> String { - let pkey_cols: Vec<&Arc> = self.primary_key_columns(); - - let pkey_frags: Vec = pkey_cols - .iter() - .map(|x| format!("{block_name}.{}", quote_ident(&x.name))) - .collect(); - - format!("({})", pkey_frags.join(",")) - } - - fn to_cursor_clause(&self, block_name: &str, order_by: &OrderByBuilder) -> String { - let frags: Vec = order_by - .elems - .iter() - .map(|x| { - let quoted_col_name = quote_ident(&x.column.name); - format!("to_jsonb({block_name}.{quoted_col_name})") - }) - .collect(); - - let clause = frags.join(", "); - - format!("translate(encode(convert_to(jsonb_build_array({clause})::text, 'utf-8'), 'base64'), E'\n', '')") - } - - #[allow(clippy::only_used_in_recursion)] - fn to_pagination_clause( - &self, - block_name: &str, - order_by: &OrderByBuilder, - cursor: &Cursor, - param_context: &mut ParamContext, - allow_equality: bool, - ) -> GraphQLResult { - // When paginating, allowe_equality should be false because we don't want to - // include the cursor's record in the page - // - // when checking to see if a previous page exists, allow_equality should be - // true, in combination with a reversed order_by because the existence of the - // cursor's record proves that there is a previous page - - // [id asc, name desc] - /* - "( - ( id > x1 or ( id is not null and x1 is null and )) - or (( id = x1 or ( id is null and x1 is null )) and ) - - )" - */ - if cursor.elems.is_empty() { - return Ok(format!("{allow_equality}")); - } - let mut next_cursor = cursor.clone(); - let cursor_elem = next_cursor.elems.remove(0); - - if order_by.elems.is_empty() { - return Err(GraphQLError::validation( - "orderBy clause incompatible with pagination cursor", - )); - } - let mut next_order_by = order_by.clone(); - let order_elem = next_order_by.elems.remove(0); - - let column = order_elem.column; - let quoted_col = quote_ident(&column.name); - - let val = cursor_elem.value; - - let val_clause = param_context.clause_for(&val, &column.type_name)?; - - let recurse_clause = self.to_pagination_clause( - block_name, - &next_order_by, - &next_cursor, - param_context, - allow_equality, - )?; - - let nulls_first: bool = order_elem.direction.nulls_first(); - - let op = match order_elem.direction.is_asc() { - true => ">", - false => "<", - }; - - Ok(format!("( - ( {block_name}.{quoted_col} {op} {val_clause} or ( {block_name}.{quoted_col} is not null and {val_clause} is null and {nulls_first})) - or (( {block_name}.{quoted_col} = {val_clause} or ( {block_name}.{quoted_col} is null and {val_clause} is null)) and {recurse_clause}) - - )")) - } - - fn 
to_join_clause( - &self, - fkey: &ForeignKey, - reverse_reference: bool, - quoted_block_name: &str, - quoted_parent_block_name: &str, - ) -> GraphQLResult { - let mut equality_clauses = vec!["true".to_string()]; - - let table_ref: &ForeignKeyTableInfo; - let foreign_ref: &ForeignKeyTableInfo; - - match reverse_reference { - true => { - table_ref = &fkey.local_table_meta; - foreign_ref = &fkey.referenced_table_meta; - } - false => { - table_ref = &fkey.referenced_table_meta; - foreign_ref = &fkey.local_table_meta; - } - }; - - for (local_col_name, parent_col_name) in table_ref - .column_names - .iter() - .zip(foreign_ref.column_names.iter()) - { - let quoted_parent_literal_col = format!( - "{}.{}", - quoted_parent_block_name, - quote_ident(parent_col_name) - ); - let quoted_local_literal_col = - format!("{}.{}", quoted_block_name, quote_ident(local_col_name)); - - let equality_clause = format!( - "{} = {}", - quoted_local_literal_col, quoted_parent_literal_col - ); - - equality_clauses.push(equality_clause); - } - Ok(equality_clauses.join(" and ")) - } -} - -impl MutationEntrypoint<'_> for InsertBuilder { - fn to_sql_entrypoint(&self, param_context: &mut ParamContext) -> GraphQLResult { - let quoted_block_name = rand_block_name(); - let quoted_schema = quote_ident(&self.table.schema); - let quoted_table = quote_ident(&self.table.name); - - let frags: Vec = self - .selections - .iter() - .map(|x| x.to_sql("ed_block_name, param_context)) - .collect::, _>>()?; - - let selectable_columns_clause = self.table.to_selectable_columns_clause(); - - let select_clause = frags.join(", "); - - // Identify all columns provided in any of `object` rows - let referenced_column_names: HashSet<&String> = - self.objects.iter().flat_map(|x| x.row.keys()).collect(); - - let referenced_columns: Vec<&Arc> = self - .table - .columns - .iter() - .filter(|c| referenced_column_names.contains(&c.name)) - .collect(); - - // Order matters. This must be in the same order as `referenced_columns` - let referenced_columns_clause: String = referenced_columns - .iter() - .map(|c| quote_ident(&c.name)) - .collect::>() - .join(", "); - - let mut values_rows_clause: Vec = vec![]; - - for row_map in &self.objects { - let mut working_row = vec![]; - for column in referenced_columns.iter() { - let elem_clause = match row_map.row.get(&column.name) { - None => "default".to_string(), - Some(elem) => match elem { - InsertElemValue::Default => "default".to_string(), - InsertElemValue::Value(val) => { - param_context.clause_for(val, &column.type_name)? - } - }, - }; - working_row.push(elem_clause); - } - // (1, 'hello', 5) - let insert_row_clause = format!("({})", working_row.join(", ")); - values_rows_clause.push(insert_row_clause); - } - - let values_clause = values_rows_clause.join(", "); - - Ok(format!( - " - with affected as ( - insert into {quoted_schema}.{quoted_table}({referenced_columns_clause}) - values {values_clause} - returning {selectable_columns_clause} - ) - select - jsonb_build_object({select_clause}) - from - affected as {quoted_block_name}; - " - )) - } -} - -impl InsertSelection { - pub fn to_sql( - &self, - block_name: &str, - param_context: &mut ParamContext, - ) -> GraphQLResult { - let r = match self { - Self::AffectedCount { alias } => { - format!("{}, count(*)", quote_literal(alias)) - } - Self::Records(x) => { - format!( - "{}, coalesce(jsonb_agg({}), jsonb_build_array())", - quote_literal(&x.alias), - x.to_sql(block_name, param_context)? 
- ) - } - Self::Typename { alias, typename } => { - format!("{}, {}", quote_literal(alias), quote_literal(typename)) - } - }; - Ok(r) - } -} - -impl UpdateSelection { - pub fn to_sql( - &self, - block_name: &str, - param_context: &mut ParamContext, - ) -> GraphQLResult { - let r = match self { - Self::AffectedCount { alias } => { - format!("{}, count(*)", quote_literal(alias)) - } - Self::Records(x) => { - format!( - "{}, coalesce(jsonb_agg({}), jsonb_build_array())", - quote_literal(&x.alias), - x.to_sql(block_name, param_context)? - ) - } - Self::Typename { alias, typename } => { - format!("{}, {}", quote_literal(alias), quote_literal(typename)) - } - }; - Ok(r) - } -} - -impl DeleteSelection { - pub fn to_sql( - &self, - block_name: &str, - param_context: &mut ParamContext, - ) -> GraphQLResult { - let r = match self { - Self::AffectedCount { alias } => { - format!("{}, count(*)", quote_literal(alias)) - } - Self::Records(x) => { - format!( - "{}, coalesce(jsonb_agg({}), jsonb_build_array())", - quote_literal(&x.alias), - x.to_sql(block_name, param_context)? - ) - } - Self::Typename { alias, typename } => { - format!("{}, {}", quote_literal(alias), quote_literal(typename)) - } - }; - - Ok(r) - } -} - -impl MutationEntrypoint<'_> for UpdateBuilder { - fn to_sql_entrypoint(&self, param_context: &mut ParamContext) -> GraphQLResult { - let quoted_block_name = rand_block_name(); - let quoted_schema = quote_ident(&self.table.schema); - let quoted_table = quote_ident(&self.table.name); - - let frags: Vec = self - .selections - .iter() - .map(|x| x.to_sql("ed_block_name, param_context)) - .collect::, _>>()?; - - let select_clause = frags.join(", "); - - let set_clause: String = { - let mut set_clause_frags = vec![]; - for (column_name, val) in &self.set.set { - let quoted_column = quote_ident(column_name); - - let column: &Column = self - .table - .columns - .iter() - .find(|x| &x.name == column_name) - .expect("Failed to find field in update builder"); - - let value_clause = param_context.clause_for(val, &column.type_name)?; - - let set_clause_frag = format!("{quoted_column} = {value_clause}"); - set_clause_frags.push(set_clause_frag); - } - set_clause_frags.join(", ") - }; - - let selectable_columns_clause = self.table.to_selectable_columns_clause(); - - let where_clause = - self.filter - .to_where_clause("ed_block_name, &self.table, param_context)?; - - let at_most = self.at_most; - - Ok(format!( - " - with impacted as ( - update {quoted_schema}.{quoted_table} as {quoted_block_name} - set {set_clause} - where {where_clause} - returning {selectable_columns_clause} - ), - total(total_count) as ( - select - count(*) - from - impacted - ), - req(res) as ( - select - jsonb_build_object({select_clause}) - from - impacted {quoted_block_name} - limit 1 - ), - wrapper(res) as ( - select - case - when total.total_count > {at_most} then graphql.exception($a$update impacts too many records$a$)::jsonb - else req.res - end - from - total - left join req - on true - limit 1 - ) - select - res - from - wrapper; - " - )) - } -} - -impl MutationEntrypoint<'_> for DeleteBuilder { - fn to_sql_entrypoint(&self, param_context: &mut ParamContext) -> GraphQLResult { - let quoted_block_name = rand_block_name(); - let quoted_schema = quote_ident(&self.table.schema); - let quoted_table = quote_ident(&self.table.name); - - let frags: Vec = self - .selections - .iter() - .map(|x| x.to_sql("ed_block_name, param_context)) - .collect::, _>>()?; - - let select_clause = frags.join(", "); - let where_clause = - self.filter - 
.to_where_clause("ed_block_name, &self.table, param_context)?; - - let selectable_columns_clause = self.table.to_selectable_columns_clause(); - - let at_most = self.at_most; - - Ok(format!( - " - with impacted as ( - delete from {quoted_schema}.{quoted_table} as {quoted_block_name} - where {where_clause} - returning {selectable_columns_clause} - ), - total(total_count) as ( - select - count(*) - from - impacted - ), - req(res) as ( - select - jsonb_build_object({select_clause}) - from - impacted {quoted_block_name} - limit 1 - ), - wrapper(res) as ( - select - case - when total.total_count > {at_most} then graphql.exception($a$delete impacts too many records$a$)::jsonb - else req.res - end - from - total - left join req - on true - limit 1 - ) - select - res - from - wrapper; - " - )) - } -} - -impl FunctionCallBuilder { - fn to_sql(&self, param_context: &mut ParamContext) -> GraphQLResult { - let mut arg_clauses = vec![]; - for (arg, arg_value) in &self.args_builder.args { - if let Some(arg) = arg { - let arg_clause = param_context.clause_for(arg_value, &arg.type_name)?; - let named_arg_clause = format!("{} => {}", quote_ident(&arg.name), arg_clause); - arg_clauses.push(named_arg_clause); - } - } - - let args_clause = format!("({})", arg_clauses.join(", ")); - - let block_name = &rand_block_name(); - let func_schema = quote_ident(&self.function.schema_name); - let func_name = quote_ident(&self.function.name); - - let query = match &self.return_type_builder { - FuncCallReturnTypeBuilder::Scalar | FuncCallReturnTypeBuilder::List => { - let type_adjustment_clause = apply_suffix_casts(self.function.type_oid); - format!("select to_jsonb({func_schema}.{func_name}{args_clause}{type_adjustment_clause}) {block_name};") - } - FuncCallReturnTypeBuilder::Node(node_builder) => { - let select_clause = node_builder.to_sql(block_name, param_context)?; - let select_clause = if select_clause.is_empty() { - "jsonb_build_object()".to_string() - } else { - select_clause - }; - format!("select coalesce((select {select_clause} from {func_schema}.{func_name}{args_clause} {block_name} where not ({block_name} is null)), null::jsonb);") - } - FuncCallReturnTypeBuilder::Connection(connection_builder) => { - let from_clause = format!("{func_schema}.{func_name}{args_clause}"); - let select_clause = connection_builder.to_sql( - Some(block_name), - param_context, - None, - Some(from_clause), - )?; - select_clause.to_string() - } - }; - - Ok(query) - } -} - -impl MutationEntrypoint<'_> for FunctionCallBuilder { - fn to_sql_entrypoint(&self, param_context: &mut ParamContext) -> GraphQLResult { - self.to_sql(param_context) - } -} - -impl QueryEntrypoint for FunctionCallBuilder { - fn to_sql_entrypoint(&self, param_context: &mut ParamContext) -> GraphQLResult { - self.to_sql(param_context) - } -} - -impl OrderByBuilder { - fn to_order_by_clause(&self, block_name: &str) -> String { - let mut frags = vec![]; - - for elem in &self.elems { - let quoted_column_name = quote_ident(&elem.column.name); - let direction_clause = match elem.direction { - OrderDirection::AscNullsFirst => "asc nulls first", - OrderDirection::AscNullsLast => "asc nulls last", - OrderDirection::DescNullsFirst => "desc nulls first", - OrderDirection::DescNullsLast => "desc nulls last", - }; - let elem_clause = format!("{block_name}.{quoted_column_name} {direction_clause}"); - frags.push(elem_clause) - } - frags.join(", ") - } -} - -pub fn json_to_text_datum(val: &serde_json::Value) -> GraphQLResult> { - use serde_json::Value; - let null: Option = None; - 
match val { - Value::Null => Ok(null.into_datum()), - Value::Bool(x) => Ok(x.to_string().into_datum()), - Value::String(x) => Ok(x.into_datum()), - Value::Number(x) => Ok(x.to_string().into_datum()), - Value::Array(xarr) => { - let mut inner_vals: Vec> = vec![]; - for elem in xarr { - let str_elem = match elem { - Value::Null => None, - Value::Bool(x) => Some(x.to_string()), - Value::String(x) => Some(x.to_string()), - Value::Number(x) => Some(x.to_string()), - Value::Array(_) => { - return Err(GraphQLError::type_error( - "Unexpected array in input value array", - )); - } - Value::Object(_) => { - return Err(GraphQLError::validation( - "Unexpected object in input value array", - )); - } - }; - inner_vals.push(str_elem); - } - Ok(inner_vals.into_datum()) - } - // Should this ever happen? json input is escaped so it would be a string. - Value::Object(_) => Err(GraphQLError::validation("Unexpected object in input value")), - } -} - -pub struct ParamContext<'src> { - pub params: Vec>, -} - -impl<'src> ParamContext<'src> { - // Pushes a parameter into the context and returns a SQL clause to reference it - //fn clause_for(&mut self, param: (PgOid, Option)) -> String { - fn clause_for(&mut self, value: &serde_json::Value, type_name: &str) -> GraphQLResult { - let type_oid = match type_name.ends_with("[]") { - true => PgOid::BuiltIn(PgBuiltInOids::TEXTARRAYOID), - false => PgOid::BuiltIn(PgBuiltInOids::TEXTOID), - }; - - let val_datum = json_to_text_datum(value)?; - let datum_with_oid = unsafe { DatumWithOid::new(val_datum, type_oid.value()) }; - self.params.push(datum_with_oid); - Ok(format!("(${}::{})", self.params.len(), type_name)) - } -} - -impl FilterBuilderElem { - fn to_sql( - &self, - block_name: &str, - table: &Table, - param_context: &mut ParamContext, - ) -> GraphQLResult { - match self { - Self::Column { column, op, value } => { - let frag = match op { - FilterOp::Is => { - format!( - "{block_name}.{} {}", - quote_ident(&column.name), - match value { - serde_json::Value::String(x) => { - match x.as_str() { - "NULL" => "is null", - "NOT_NULL" => "is not null", - _ => { - return Err(GraphQLError::sql_generation( - "Error transpiling Is filter value", - )) - } - } - } - _ => { - return Err(GraphQLError::sql_generation( - "Error transpiling Is filter value type", - )); - } - } - ) - } - _ => { - let cast_type_name = match op { - FilterOp::In => format!("{}[]", column.type_name), - FilterOp::Contains => format!("{}[]", column.type_name), - FilterOp::ContainedBy => format!("{}[]", column.type_name), - FilterOp::Overlap => format!("{}[]", column.type_name), - _ => column.type_name.clone(), - }; - - let val_clause = param_context.clause_for(value, &cast_type_name)?; - - format!( - "{block_name}.{} {} {}", - quote_ident(&column.name), - match op { - FilterOp::Equal => "=", - FilterOp::NotEqual => "<>", - FilterOp::LessThan => "<", - FilterOp::LessThanEqualTo => "<=", - FilterOp::GreaterThan => ">", - FilterOp::GreaterThanEqualTo => ">=", - FilterOp::In => "= any", - FilterOp::StartsWith => "^@", - FilterOp::Like => "like", - FilterOp::ILike => "ilike", - FilterOp::RegEx => "~", - FilterOp::IRegEx => "~*", - FilterOp::Contains => "@>", - FilterOp::ContainedBy => "<@", - FilterOp::Overlap => "&&", - FilterOp::Is => { - return Err(GraphQLError::sql_generation( - "Error transpiling Is filter", - )); - } - }, - val_clause - ) - } - }; - Ok(frag) - } - Self::NodeId(node_id) => node_id.to_sql(block_name, table, param_context), - FilterBuilderElem::Compound(compound_builder) => { - 
compound_builder.to_sql(block_name, table, param_context) - } - } - } -} - -impl CompoundFilterBuilder { - fn to_sql( - &self, - block_name: &str, - table: &Table, - param_context: &mut ParamContext, - ) -> GraphQLResult { - Ok(match self { - CompoundFilterBuilder::And(elements) => { - let bool_expressions = elements - .iter() - .map(|e| e.to_sql(block_name, table, param_context)) - .collect::, _>>()?; - format!("({})", bool_expressions.join(" and ")) - } - CompoundFilterBuilder::Or(elements) => { - let bool_expressions = elements - .iter() - .map(|e| e.to_sql(block_name, table, param_context)) - .collect::, _>>()?; - format!("({})", bool_expressions.join(" or ")) - } - CompoundFilterBuilder::Not(elem) => { - format!("not({})", elem.to_sql(block_name, table, param_context)?) - } - }) - } -} - -impl FilterBuilder { - fn to_where_clause( - &self, - block_name: &str, - table: &Table, - param_context: &mut ParamContext, - ) -> GraphQLResult { - let mut frags = vec!["true".to_string()]; - - for elem in &self.elems { - let frag = elem.to_sql(block_name, table, param_context)?; - frags.push(frag); - } - Ok(frags.join(" and ")) - } -} - -pub struct FromFunction { - function: Arc, - input_table: Arc, - // The block name for the functions argument - input_block_name: String, -} - -impl ConnectionBuilder { - fn page_selections(&self) -> Vec { - self.selections - .iter() - .flat_map(|x| match x { - ConnectionSelection::PageInfo(page_info_builder) => { - page_info_builder.selections.clone() - } - _ => vec![], - }) - .collect() - } - - fn requested_next_page(&self) -> bool { - self.page_selections() - .iter() - .any(|x| matches!(&x, PageInfoSelection::HasNextPage { alias: _ })) - } - - fn requested_previous_page(&self) -> bool { - self.page_selections() - .iter() - .any(|x| matches!(&x, PageInfoSelection::HasPreviousPage { alias: _ })) - } - - fn is_reverse_pagination(&self) -> bool { - self.last.is_some() || self.before.is_some() - } - - fn to_join_clause( - &self, - quoted_block_name: &str, - quoted_parent_block_name: &Option<&str>, - ) -> GraphQLResult { - match &self.source.fkey { - Some(fkey) => { - let quoted_parent_block_name = quoted_parent_block_name - .ok_or("Internal Error: Parent block name is required when fkey_ix is set")?; - self.source.table.to_join_clause( - &fkey.fkey, - fkey.reverse_reference, - quoted_block_name, - quoted_parent_block_name, - ) - } - None => Ok("true".to_string()), - } - } - - fn object_clause( - &self, - quoted_block_name: &str, - param_context: &mut ParamContext, - ) -> GraphQLResult { - let frags: Vec = self - .selections - .iter() - .filter_map(|x| { - x.to_sql( - quoted_block_name, - &self.order_by, - &self.source.table, - param_context, - ) - .transpose() - }) - .collect::, _>>()?; - - Ok(frags.join(", ")) - } - - fn limit_clause(&self) -> u64 { - cmp::min( - self.first - .unwrap_or_else(|| self.last.unwrap_or(self.max_rows)), - self.max_rows, - ) - } - - //TODO:Revisit if from_clause is the best name - #[allow(clippy::wrong_self_convention)] - fn from_clause(&self, quoted_block_name: &str, function: &Option) -> String { - let quoted_schema = quote_ident(&self.source.table.schema); - let quoted_table = quote_ident(&self.source.table.name); - - match function { - Some(from_function) => { - let quoted_func_schema = quote_ident(&from_function.function.schema_name); - let quoted_func = quote_ident(&from_function.function.name); - let input_block_name = &from_function.input_block_name; - let quoted_input_schema = quote_ident(&from_function.input_table.schema); - let 
quoted_input_table = quote_ident(&from_function.input_table.name); - format!("{quoted_func_schema}.{quoted_func}({input_block_name}::{quoted_input_schema}.{quoted_input_table}) {quoted_block_name}") - } - None => { - format!("{quoted_schema}.{quoted_table} {quoted_block_name}") - } - } - } - - // Generates the *contents* of the aggregate jsonb_build_object - fn aggregate_select_list(&self, quoted_block_name: &str) -> GraphQLResult> { - let Some(agg_builder) = self.selections.iter().find_map(|sel| match sel { - ConnectionSelection::Aggregate(builder) => Some(builder), - _ => None, - }) else { - return Ok(None); - }; - - let mut agg_selections = vec![]; - - for selection in &agg_builder.selections { - match selection { - AggregateSelection::Count { alias } => { - // Produces: 'count_alias', count(*) - agg_selections.push(format!("{}, count(*)", quote_literal(alias))); - } - AggregateSelection::Sum { - alias, - column_builders: selections, - } - | AggregateSelection::Avg { - alias, - column_builders: selections, - } - | AggregateSelection::Min { - alias, - column_builders: selections, - } - | AggregateSelection::Max { - alias, - column_builders: selections, - } => { - let pg_func = match selection { - AggregateSelection::Sum { .. } => aggregate::SUM, - AggregateSelection::Avg { .. } => aggregate::AVG, - AggregateSelection::Min { .. } => aggregate::MIN, - AggregateSelection::Max { .. } => aggregate::MAX, - AggregateSelection::Count { .. } => { - unreachable!("Count should be handled by its own arm") - } - AggregateSelection::Typename { .. } => { - unreachable!("Typename should be handled by its own arm") - } - }; - - let mut field_selections = vec![]; - for col_builder in selections { - let col_sql = col_builder.to_sql(quoted_block_name)?; - let col_alias = &col_builder.alias; - - // Always cast avg input to numeric for precision - let col_sql_casted = if pg_func == "avg" { - format!("{}::numeric", col_sql) - } else { - col_sql - }; - // Produces: 'col_alias', agg_func(col) - field_selections.push(format!( - "{}, {}({})", - quote_literal(col_alias), - pg_func, - col_sql_casted - )); - } - // Produces: 'agg_alias', jsonb_build_object('col_alias', agg_func(col), ...) 
- agg_selections.push(format!( - "{}, jsonb_build_object({})", - quote_literal(alias), - field_selections.join(", ") - )); - } - AggregateSelection::Typename { alias, typename } => { - // Produces: '__typename', 'AggregateTypeName' - agg_selections.push(format!( - "{}, {}", - quote_literal(alias), - quote_literal(typename) - )); - } - } - } - - if agg_selections.is_empty() { - Ok(None) - } else { - Ok(Some(agg_selections.join(", "))) - } - } - - pub fn to_sql( - &self, - quoted_parent_block_name: Option<&str>, - param_context: &mut ParamContext, - from_func: Option, - from_clause: Option, - ) -> GraphQLResult { - let quoted_block_name = rand_block_name(); - - let from_clause = match from_clause { - Some(from_clause) => format!("{from_clause} {quoted_block_name}"), - None => self.from_clause("ed_block_name, &from_func), - }; - - let where_clause = - self.filter - .to_where_clause("ed_block_name, &self.source.table, param_context)?; - - let order_by_clause = self.order_by.to_order_by_clause("ed_block_name); - let order_by_clause_reversed = self - .order_by - .reverse() - .to_order_by_clause("ed_block_name); - - let order_by_clause_records = match self.is_reverse_pagination() { - true => &order_by_clause_reversed, - false => &order_by_clause, - }; - - let requested_next_page = self.requested_next_page(); - let requested_previous_page = self.requested_previous_page(); - - let join_clause = self.to_join_clause("ed_block_name, "ed_parent_block_name)?; - - let cursor = &self.before.clone().or_else(|| self.after.clone()); - - let object_clause = self.object_clause("ed_block_name, param_context)?; - let aggregate_select_list = self.aggregate_select_list("ed_block_name)?; - - let selectable_columns_clause = self.source.table.to_selectable_columns_clause(); - - let pkey_tuple_clause_from_block = self - .source - .table - .to_primary_key_tuple_clause("ed_block_name); - let pkey_tuple_clause_from_records = - self.source.table.to_primary_key_tuple_clause("__records"); - - let pagination_clause = { - let order_by = match self.is_reverse_pagination() { - true => self.order_by.reverse(), - false => self.order_by.clone(), - }; - match cursor { - Some(cursor) => self.source.table.to_pagination_clause( - "ed_block_name, - &order_by, - cursor, - param_context, - false, - )?, - None => "true".to_string(), - } - }; - - let limit = self.limit_clause(); - let offset = self.offset.unwrap_or(0); - - // Determine if aggregates are requested based on if we generated a select list - let requested_aggregates = aggregate_select_list.is_some(); - - // initialized assuming forwards pagination - let mut has_next_page_query = format!( - " - with page_plus_1 as ( - select - 1 - from - {from_clause} - where - {join_clause} - and {where_clause} - and {pagination_clause} - order by - {order_by_clause} - limit ({limit} + 1) - offset ({offset}) - ) - select count(*) > {limit} from page_plus_1 - " - ); - - let mut has_prev_page_query = format!(" - with page_minus_1 as ( - select - not ({pkey_tuple_clause_from_block} = any( __records.seen )) is_pkey_in_records - from - {from_clause} - left join (select array_agg({pkey_tuple_clause_from_records}) from __records ) __records(seen) - on true - where - {join_clause} - and {where_clause} - order by - {order_by_clause_records} - limit 1 - ) - select coalesce(bool_and(is_pkey_in_records), false) from page_minus_1 - "); - - if self.is_reverse_pagination() { - // Reverse has_next_page and has_previous_page - std::mem::swap(&mut has_next_page_query, &mut has_prev_page_query); - } - if 
!requested_next_page { - has_next_page_query = "select null".to_string() - } - if !requested_previous_page { - has_prev_page_query = "select null".to_string() - } - - // Build aggregate CTE if requested - let aggregate_cte = if requested_aggregates { - let select_list_str = aggregate_select_list.unwrap_or_default(); - format!( - r#" - ,__aggregates(agg_result) as ( - select - jsonb_build_object({select_list_str}) - from - {from_clause} - where - {join_clause} - and {where_clause} - ) - "# - ) - } else { - r#" - ,__aggregates(agg_result) as (select null::jsonb) - "# - .to_string() - }; - - // Add helper cte to set page info correctly for empty collections - let has_records_cte = r#" - ,__has_records(has_records) as (select exists(select 1 from __records)) - "#; - - // Clause containing selections *not* including the aggregate - let base_object_clause = object_clause; // Renamed original object_clause - - // Clause to merge the aggregate result if requested - let aggregate_merge_clause = if requested_aggregates { - let agg_alias = self - .selections - .iter() - .find_map(|sel| match sel { - ConnectionSelection::Aggregate(builder) => Some(builder.alias.clone()), - _ => None, - }) - .ok_or( - "Internal Error: Aggregate builder not found when requested_aggregates is true", - )?; - format!( - "|| jsonb_build_object({}, coalesce(__aggregates.agg_result, '{{}}'::jsonb))", - quote_literal(&agg_alias) - ) - } else { - "".to_string() - }; - - Ok(format!( - r#" - ( - with __records as ( - select - {selectable_columns_clause} - from - {from_clause} - where - true - and {join_clause} - and {where_clause} - and {pagination_clause} - order by - {order_by_clause_records} - limit - {limit} - offset - {offset} - ), - __total_count(___total_count) as ( - select - count(*) - from - {from_clause} - where - {join_clause} - and {where_clause} - ), - __has_next_page(___has_next_page) as ( - {has_next_page_query} - ), - __has_previous_page(___has_previous_page) as ( - {has_prev_page_query} - ) - {has_records_cte} - {aggregate_cte}, - __base_object as ( - select jsonb_build_object({base_object_clause}) as obj - from - __total_count - cross join __has_next_page - cross join __has_previous_page - cross join __has_records - left join __records {quoted_block_name} on true - group by - __total_count.___total_count, - __has_next_page.___has_next_page, - __has_previous_page.___has_previous_page, - __has_records.has_records - ) - select - coalesce(__base_object.obj, '{{}}'::jsonb) {aggregate_merge_clause} - from - (select 1) as __dummy_for_left_join - left join __base_object on true - cross join __aggregates - ) - "# - )) - } -} - -impl QueryEntrypoint for ConnectionBuilder { - fn to_sql_entrypoint(&self, param_context: &mut ParamContext) -> GraphQLResult { - self.to_sql(None, param_context, None, None) - } -} - -impl PageInfoBuilder { - pub fn to_sql( - &self, - _block_name: &str, - order_by: &OrderByBuilder, - table: &Table, - ) -> GraphQLResult { - let frags: Vec = self - .selections - .iter() - .map(|x| x.to_sql(_block_name, order_by, table)) - .collect::, _>>()?; - - let x = frags.join(", "); - - Ok(format!("jsonb_build_object({})", x,)) - } -} - -impl PageInfoSelection { - pub fn to_sql( - &self, - block_name: &str, - order_by: &OrderByBuilder, - table: &Table, - ) -> GraphQLResult { - let order_by_clause = order_by.to_order_by_clause(block_name); - let order_by_clause_reversed = order_by.reverse().to_order_by_clause(block_name); - - let cursor_clause = table.to_cursor_clause(block_name, order_by); - - Ok(match self { 
- Self::StartCursor { alias } => { - format!( - "{}, case when __has_records.has_records then (array_agg({cursor_clause} order by {order_by_clause}))[1] else null end", - quote_literal(alias) - ) - } - Self::EndCursor { alias } => { - format!( - "{}, case when __has_records.has_records then (array_agg({cursor_clause} order by {order_by_clause_reversed}))[1] else null end", - quote_literal(alias) - ) - } - Self::HasNextPage { alias } => { - format!( - "{}, coalesce(bool_and(__has_next_page.___has_next_page), false)", - quote_literal(alias) - ) - } - Self::HasPreviousPage { alias } => { - format!( - "{}, coalesce(bool_and(__has_previous_page.___has_previous_page), false)", - quote_literal(alias) - ) - } - Self::Typename { alias, typename } => { - format!("{}, {}", quote_literal(alias), quote_literal(typename)) - } - }) - } -} - -impl ConnectionSelection { - pub fn to_sql( - &self, - block_name: &str, - order_by: &OrderByBuilder, - table: &Table, - param_context: &mut ParamContext, - ) -> GraphQLResult> { - Ok(match self { - Self::Edge(x) => Some(format!( - "{}, {}", - quote_literal(&x.alias), - x.to_sql(block_name, order_by, table, param_context)? - )), - Self::PageInfo(x) => Some(format!( - "{}, {}", - quote_literal(&x.alias), - x.to_sql(block_name, order_by, table)? - )), - Self::TotalCount { alias } => Some(format!( - "{}, coalesce(__total_count.___total_count, 0)", - quote_literal(alias), - )), - Self::Typename { alias, typename } => Some(format!( - "{}, {}", - quote_literal(alias), - quote_literal(typename) - )), - // SQL generation is handled by ConnectionBuilder::aggregate_select_list - // and the results are merged in later in the process - Self::Aggregate(_) => None, - }) - } -} - -impl EdgeBuilder { - pub fn to_sql( - &self, - block_name: &str, - order_by: &OrderByBuilder, - table: &Table, - param_context: &mut ParamContext, - ) -> GraphQLResult { - let frags: Vec = self - .selections - .iter() - .map(|x| x.to_sql(block_name, order_by, table, param_context)) - .collect::, _>>()?; - - let x = frags.join(", "); - let order_by_clause = order_by.to_order_by_clause(block_name); - - // Get the first primary key column name to use in the filter - let first_pk_col = table.primary_key_columns().first().map(|col| &col.name); - - // Create a filter clause that checks if any primary key column is not NULL - let filter_clause = if let Some(pk_col) = first_pk_col { - format!( - "filter (where {}.{} is not null)", - block_name, - quote_ident(pk_col) - ) - } else { - "".to_string() // Fallback if no primary key columns (should be rare) - }; - - Ok(format!( - "coalesce( - jsonb_agg( - jsonb_build_object({x}) - order by {order_by_clause} - ) {filter_clause}, - jsonb_build_array() - )" - )) - } -} - -impl EdgeSelection { - pub fn to_sql( - &self, - block_name: &str, - order_by: &OrderByBuilder, - table: &Table, - param_context: &mut ParamContext, - ) -> GraphQLResult { - Ok(match self { - Self::Cursor { alias } => { - let cursor_clause = table.to_cursor_clause(block_name, order_by); - format!("{}, {cursor_clause}", quote_literal(alias)) - } - Self::Node(builder) => format!( - "{}, {}", - quote_literal(&builder.alias), - builder.to_sql(block_name, param_context)? 
- ), - Self::Typename { alias, typename } => { - format!("{}, {}", quote_literal(alias), quote_literal(typename)) - } - }) - } -} - -impl NodeBuilder { - pub fn to_sql( - &self, - block_name: &str, - param_context: &mut ParamContext, - ) -> GraphQLResult { - let frags: Vec = self - .selections - .iter() - .map(|x| x.to_sql(block_name, param_context)) - .collect::, _>>()?; - - const MAX_ARGS_IN_JSONB_BUILD_OBJECT: usize = 100; //jsonb_build_object has a limit of 100 arguments - const ARGS_PER_FRAG: usize = 2; // each x.to_sql(...) function above return a pair of args - const CHUNK_SIZE: usize = MAX_ARGS_IN_JSONB_BUILD_OBJECT / ARGS_PER_FRAG; - - let frags: Vec = frags - .chunks(CHUNK_SIZE) - .map(|chunks| format!("jsonb_build_object({})", chunks.join(", "))) - .collect(); - - Ok(frags.join(" || ").to_string()) - } - - pub fn to_relation_sql( - &self, - parent_block_name: &str, - param_context: &mut ParamContext, - ) -> GraphQLResult { - let quoted_block_name = rand_block_name(); - let quoted_schema = quote_ident(&self.table.schema); - let quoted_table = quote_ident(&self.table.name); - - let fkey = self.fkey.as_ref().ok_or("Internal Error: relation key")?; - let reverse_reference = self - .reverse_reference - .ok_or("Internal Error: relation reverse reference")?; - - let frags: Vec = self - .selections - .iter() - .map(|x| x.to_sql("ed_block_name, param_context)) - .collect::, _>>()?; - - let object_clause: Vec = frags - .chunks(50) - .map(|chunks| format!("jsonb_build_object({})", chunks.join(", "))) - .collect(); - - let object_clause_string = object_clause.join(" || ").to_string(); - - let join_clause = self.table.to_join_clause( - fkey, - reverse_reference, - "ed_block_name, - parent_block_name, - )?; - - Ok(format!( - " - ( - select - {object_clause_string} - from - {quoted_schema}.{quoted_table} as {quoted_block_name} - where - {join_clause} - )" - )) - } -} - -impl QueryEntrypoint for NodeBuilder { - fn to_sql_entrypoint(&self, param_context: &mut ParamContext) -> GraphQLResult { - let quoted_block_name = rand_block_name(); - let quoted_schema = quote_ident(&self.table.schema); - let quoted_table = quote_ident(&self.table.name); - let object_clause = self.to_sql("ed_block_name, param_context)?; - - let node_id = self - .node_id - .as_ref() - .ok_or("Expected nodeId argument missing")?; - - let node_id_clause = node_id.to_sql("ed_block_name, &self.table, param_context)?; - - Ok(format!( - " - ( - select - {object_clause} - from - {quoted_schema}.{quoted_table} as {quoted_block_name} - where - {node_id_clause} - ) - " - )) - } -} - -impl NodeIdInstance { - pub fn to_sql( - &self, - block_name: &str, - table: &Table, - param_context: &mut ParamContext, - ) -> GraphQLResult { - // TODO: abstract this logical check into builder. 
It is not related to - // transpiling and should not be in this module - if (&self.schema_name, &self.table_name) != (&table.schema, &table.name) { - return Err(GraphQLError::validation( - "nodeId belongs to a different collection", - )); - } - - let mut col_val_pairs: Vec = vec![]; - for (col, val) in table.primary_key_columns().iter().zip(self.values.iter()) { - let column_name = &col.name; - let val_clause = param_context.clause_for(val, &col.type_name)?; - col_val_pairs.push(format!("{block_name}.{column_name} = {val_clause}")) - } - Ok(col_val_pairs.join(" and ")) - } -} - -// Returns a :: casts suffix that can be appended to a type for oids that need special -// handling -fn apply_suffix_casts(type_oid: u32) -> String { - match type_oid { - 20 => "::text", // bigints as text - 114 | 3802 => "#>> '{}'", // json/b as stringified - 1700 => "::text", // numeric as text - 1016 => "::text[]", // bigint arrays as array of text - 199 | 3807 => "::text[]", // json/b array as array of text - 1231 => "::text[]", // numeric array as array of text - _ => "", - } - .to_string() -} - -impl NodeSelection { - pub fn to_sql( - &self, - block_name: &str, - param_context: &mut ParamContext, - ) -> GraphQLResult { - Ok(match self { - // TODO need to provide alias when called from node builder. - Self::Connection(builder) => format!( - "{}, {}", - quote_literal(&builder.alias), - builder.to_sql(Some(block_name), param_context, None, None)? - ), - Self::Node(builder) => format!( - "{}, {}", - quote_literal(&builder.alias), - builder.to_relation_sql(block_name, param_context)? - ), - Self::Column(builder) => { - let type_adjustment_clause = apply_suffix_casts(builder.column.type_oid); - - format!( - "{}, {}{}", - quote_literal(&builder.alias), - builder.to_sql(block_name)?, - type_adjustment_clause - ) - } - Self::Function(builder) => { - let type_adjustment_clause = apply_suffix_casts(builder.function.type_oid); - format!( - "{}, {}{}", - quote_literal(&builder.alias), - builder.to_sql(block_name, param_context)?, - type_adjustment_clause - ) - } - Self::NodeId(builder) => format!( - "{}, {}", - quote_literal(&builder.alias), - builder.to_sql(block_name)? 
-            Self::Typename { alias, typename } => {
-                format!("{}, {}", quote_literal(alias), quote_literal(typename))
-            }
-        })
-    }
-}
-
-impl ColumnBuilder {
-    pub fn to_sql(&self, block_name: &str) -> GraphQLResult<String> {
-        let col = format!("{}.{}", &block_name, quote_ident(&self.column.name));
-        let maybe_enum = self.column.type_.as_ref().and_then(|t| match t.details {
-            Some(TypeDetails::Enum(ref enum_)) => Some(enum_),
-            _ => None,
-        });
-        if let Some(enum_) = maybe_enum {
-            match enum_.directives.mappings {
-                Some(ref mappings) => {
-                    let cases = mappings
-                        .iter()
-                        .map(|(k, v)| {
-                            format!(
-                                "when {col} = {} then {}",
-                                quote_literal(k),
-                                quote_literal(v)
-                            )
-                        })
-                        .join(" ");
-                    Ok(format!("case {cases} else {col}::text end"))
-                }
-                _ => Ok(col),
-            }
-        } else {
-            Ok(col)
-        }
-    }
-}
-
-impl NodeIdBuilder {
-    pub fn to_sql(&self, block_name: &str) -> GraphQLResult<String> {
-        let column_selects: Vec<String> = self
-            .columns
-            .iter()
-            .map(|col| format!("{}.{}", block_name, col.name))
-            .collect();
-        let column_clause = column_selects.join(", ");
-        let schema_name = quote_literal(&self.schema_name);
-        let table_name = quote_literal(&self.table_name);
-        Ok(format!(
-            "translate(encode(convert_to(jsonb_build_array({schema_name}, {table_name}, {column_clause})::text, 'utf-8'), 'base64'), E'\n', '')"
-        ))
-    }
-}
-
-impl FunctionBuilder {
-    pub fn to_sql(
-        &self,
-        block_name: &str,
-        param_context: &mut ParamContext,
-    ) -> GraphQLResult<String> {
-        let schema_name = quote_ident(&self.function.schema_name);
-        let function_name = quote_ident(&self.function.name);
-
-        let sql_frag = match &self.selection {
-            FunctionSelection::ScalarSelf => format!(
-                "{schema_name}.{function_name}({block_name}::{}.{})",
-                quote_ident(&self.table.schema),
-                quote_ident(&self.table.name)
-            ),
-            FunctionSelection::Array => format!(
-                // The current implementation does not support enums or record types
-                // correctly; however, those functions are filtered out upstream
-                "{schema_name}.{function_name}({block_name}::{}.{})",
-                quote_ident(&self.table.schema),
-                quote_ident(&self.table.name)
-            ),
-            FunctionSelection::Node(node_builder) => {
-                let func_block_name = rand_block_name();
-                let object_clause = node_builder.to_sql(&func_block_name, param_context)?;
-
-                let from_clause = format!(
-                    "{schema_name}.{function_name}({block_name}::{}.{})",
-                    quote_ident(&self.table.schema),
-                    quote_ident(&self.table.name)
-                );
-                format!(
-                    "
-                    (
-                        select
-                            {object_clause}
-                        from
-                            {from_clause} as {func_block_name}
-                        where
-                            not ({func_block_name} is null)
-                    )
-                    "
-                )
-            }
-            FunctionSelection::Connection(connection_builder) => connection_builder.to_sql(
-                None,
-                param_context,
-                Some(FromFunction {
-                    function: Arc::clone(&self.function),
-                    input_table: Arc::clone(&self.table),
-                    input_block_name: block_name.to_string(),
-                }),
-                None,
-            )?,
-        };
-        Ok(sql_frag)
-    }
-}
-
-impl Serialize for __FieldBuilder {
-    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
-    where
-        S: Serializer,
-    {
-        let mut map = serializer.serialize_map(Some(self.selections.len()))?;
-
-        for selection in &self.selections {
-            match &selection.selection {
-                __FieldField::Name => {
-                    map.serialize_entry(&selection.alias, &self.field.name())?;
-                }
-                __FieldField::Description => {
-                    map.serialize_entry(&selection.alias, &self.field.description())?;
-                }
-                __FieldField::IsDeprecated => {
-                    map.serialize_entry(&selection.alias, &self.field.is_deprecated())?;
-                }
-                __FieldField::DeprecationReason => {
-                    map.serialize_entry(&selection.alias, &self.field.deprecation_reason())?;
-                }
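-                // Composite fields (arguments, type) serialize recursively
-                // through their own builders' Serialize impls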
-                __FieldField::Arguments(input_value_builders) => {
-                    map.serialize_entry(&selection.alias, input_value_builders)?;
-                }
-                __FieldField::Type(t) => {
-                    // TODO
-                    map.serialize_entry(&selection.alias, t)?;
-                }
-                __FieldField::Typename { alias, typename } => {
-                    map.serialize_entry(&alias, typename)?;
-                }
-            }
-        }
-        map.end()
-    }
-}
-
-impl Serialize for __TypeBuilder {
-    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
-    where
-        S: Serializer,
-    {
-        let mut map = serializer.serialize_map(Some(self.selections.len()))?;
-
-        for selection in &self.selections {
-            match &selection.selection {
-                __TypeField::Kind => {
-                    map.serialize_entry(&selection.alias, &format!("{:?}", self.type_.kind()))?;
-                }
-                __TypeField::Name => {
-                    map.serialize_entry(&selection.alias, &self.type_.name())?;
-                }
-                __TypeField::Description => {
-                    map.serialize_entry(&selection.alias, &self.type_.description())?;
-                }
-                __TypeField::Fields(fields) => {
-                    map.serialize_entry(&selection.alias, fields)?;
-                }
-                __TypeField::InputFields(input_field_builders) => {
-                    map.serialize_entry(&selection.alias, input_field_builders)?;
-                }
-                __TypeField::Interfaces(interfaces) => {
-                    map.serialize_entry(&selection.alias, &interfaces)?;
-                }
-                __TypeField::EnumValues(enum_values) => {
-                    map.serialize_entry(&selection.alias, enum_values)?;
-                }
-                __TypeField::PossibleTypes(possible_types) => {
-                    map.serialize_entry(&selection.alias, &possible_types)?;
-                }
-                __TypeField::OfType(t_builder) => {
-                    map.serialize_entry(&selection.alias, t_builder)?;
-                }
-                __TypeField::Typename { alias, typename } => {
-                    map.serialize_entry(&alias, typename)?;
-                }
-            }
-        }
-        map.end()
-    }
-}
-
-impl Serialize for __DirectiveBuilder {
-    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
-    where
-        S: Serializer,
-    {
-        let mut map = serializer.serialize_map(Some(self.selections.len()))?;
-        for selection in &self.selections {
-            match &selection.selection {
-                __DirectiveField::Name => {
-                    map.serialize_entry(&selection.alias, &self.directive.name())?;
-                }
-                __DirectiveField::Description => {
-                    map.serialize_entry(&selection.alias, &self.directive.description())?;
-                }
-                __DirectiveField::Locations => {
-                    map.serialize_entry(&selection.alias, &self.directive.locations())?;
-                }
-                __DirectiveField::Args(args) => {
-                    map.serialize_entry(&selection.alias, args)?;
-                }
-                __DirectiveField::IsRepeatable => {
-                    map.serialize_entry(&selection.alias, &self.directive.is_repeatable())?;
-                }
-                __DirectiveField::Typename { alias, typename } => {
-                    map.serialize_entry(&alias, typename)?;
-                }
-            }
-        }
-        map.end()
-    }
-}
-
-impl Serialize for __SchemaBuilder {
-    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
-    where
-        S: Serializer,
-    {
-        let mut map = serializer.serialize_map(Some(self.selections.len()))?;
-
-        for selection in &self.selections {
-            match &selection.selection {
-                __SchemaField::Description => {
-                    map.serialize_entry(&selection.alias, &self.description)?;
-                }
-                __SchemaField::Types(type_builders) => {
-                    map.serialize_entry(&selection.alias, &type_builders)?;
-                }
-                __SchemaField::QueryType(type_builder) => {
-                    map.serialize_entry(&selection.alias, &type_builder)?;
-                }
-                __SchemaField::MutationType(type_builder) => {
-                    map.serialize_entry(&selection.alias, &type_builder)?;
-                }
-                __SchemaField::SubscriptionType(type_builder) => {
-                    map.serialize_entry(&selection.alias, &type_builder)?;
-                }
-                __SchemaField::Directives(directives) => {
-                    map.serialize_entry(&selection.alias, directives)?;
-                }
-                __SchemaField::Typename { alias, typename } => {
-                    map.serialize_entry(&alias, typename)?;
-                }
-            }
-        }
-        map.end()
-    }
-}
-
-impl Serialize for __InputValueBuilder {
-    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
-    where
-        S: Serializer,
-    {
-        let mut map = serializer.serialize_map(Some(self.selections.len()))?;
-
-        for selection in &self.selections {
-            match &selection.selection {
-                __InputValueField::Name => {
-                    map.serialize_entry(&selection.alias, &self.input_value.name())?;
-                }
-                __InputValueField::Description => {
-                    map.serialize_entry(&selection.alias, &self.input_value.description())?;
-                }
-                __InputValueField::Type(type_builder) => {
-                    map.serialize_entry(&selection.alias, &type_builder)?;
-                }
-                __InputValueField::DefaultValue => {
-                    map.serialize_entry(&selection.alias, &self.input_value.default_value())?;
-                }
-                __InputValueField::IsDeprecated => {
-                    map.serialize_entry(&selection.alias, &self.input_value.is_deprecated())?;
-                }
-                __InputValueField::DeprecationReason => {
-                    map.serialize_entry(&selection.alias, &self.input_value.deprecation_reason())?;
-                }
-                __InputValueField::Typename { alias, typename } => {
-                    map.serialize_entry(&alias, typename)?;
-                }
-            }
-        }
-        map.end()
-    }
-}
-
-impl Serialize for __EnumValueBuilder {
-    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
-    where
-        S: Serializer,
-    {
-        let mut map = serializer.serialize_map(Some(self.selections.len()))?;
-
-        for selection in &self.selections {
-            match &selection.selection {
-                __EnumValueField::Name => {
-                    map.serialize_entry(&selection.alias, &self.enum_value.name())?;
-                }
-                __EnumValueField::Description => {
-                    map.serialize_entry(&selection.alias, &self.enum_value.description())?;
-                }
-                __EnumValueField::IsDeprecated => {
-                    map.serialize_entry(&selection.alias, &self.enum_value.is_deprecated())?;
-                }
-                __EnumValueField::DeprecationReason => {
-                    map.serialize_entry(&selection.alias, &self.enum_value.deprecation_reason())?;
-                }
-                __EnumValueField::Typename { alias, typename } => {
-                    map.serialize_entry(&alias, typename)?;
-                }
-            }
-        }
-        map.end()
-    }
-}
-
-#[cfg(any(test, feature = "pg_test"))]
-#[pgrx::pg_schema]
-mod tests {
-    use crate::transpile::*;
-
-    #[pg_test]
-    fn test_quote_ident() {
-        let res = quote_ident("hello world");
-        assert_eq!(res, r#""hello world""#);
-    }
-
-    #[pg_test]
-    fn test_quote_literal() {
-        let res = quote_literal("hel'lo world");
-        assert_eq!(res, r#"'hel''lo world'"#);
-    }
-}
diff --git a/test/expected/cursor_pagination_nulls.out b/test/expected/cursor_pagination_nulls.out
new file mode 100644
index 00000000..448e7745
--- /dev/null
+++ b/test/expected/cursor_pagination_nulls.out
@@ -0,0 +1,1019 @@
+-- Test cursor pagination with nulls first/last ordering
+-- This test verifies correct behavior when paginating through data with NULL values
+-- using different null ordering strategies (NULLS FIRST vs NULLS LAST)
+begin;
+    comment on schema public is '@graphql({"inflect_names": false})';
+    create table items(
+        id int primary key,
+        priority int, -- nullable column for testing null ordering
+        name text
+    );
+    -- Insert test data with strategic NULL placement
+    -- IDs 1-5: have priority values
+    -- IDs 6-8: NULL priority (to test null ordering)
+    insert into items(id, priority, name) values
+        (1, 10, 'low'),
+        (2, 20, 'medium'),
+        (3, 30, 'high'),
+        (4, 20, 'medium-alt'), -- duplicate priority
+        (5, 10, 'low-alt'),    -- duplicate priority
+        (6, null, 'unset-a'),
+        (7, null, 'unset-b'),
+        (8, null, 'unset-c');
+    -- Show the data for reference
+    select * from items order by id;
+ id | priority |    name
+----+----------+------------
+  1 |       10 | low
+  2 |       20 | medium
+  3 |       30 | high
+  4 |       20 | medium-alt
+ 5 | 10 | low-alt + 6 | | unset-a + 7 | | unset-b + 8 | | unset-c +(8 rows) + + -- ========================================================================== + -- Test 1: NULLS LAST with ASC (default behavior) + -- Order should be: 1,5 (priority=10), 2,4 (priority=20), 3 (priority=30), 6,7,8 (null) + -- ========================================================================== + select jsonb_pretty( + graphql.resolve($$ + { + itemsCollection( + orderBy: [{priority: AscNullsLast}, {id: AscNullsLast}] + ) { + edges { + node { + id + priority + name + } + } + } + } + $$) + ); + jsonb_pretty +----------------------------------------------- + { + + "data": { + + "itemsCollection": { + + "edges": [ + + { + + "node": { + + "id": 1, + + "name": "low", + + "priority": 10 + + } + + }, + + { + + "node": { + + "id": 5, + + "name": "low-alt", + + "priority": 10 + + } + + }, + + { + + "node": { + + "id": 2, + + "name": "medium", + + "priority": 20 + + } + + }, + + { + + "node": { + + "id": 4, + + "name": "medium-alt",+ + "priority": 20 + + } + + }, + + { + + "node": { + + "id": 3, + + "name": "high", + + "priority": 30 + + } + + }, + + { + + "node": { + + "id": 6, + + "name": "unset-a", + + "priority": null + + } + + }, + + { + + "node": { + + "id": 7, + + "name": "unset-b", + + "priority": null + + } + + }, + + { + + "node": { + + "id": 8, + + "name": "unset-c", + + "priority": null + + } + + } + + ] + + } + + } + + } +(1 row) + + -- ========================================================================== + -- Test 2: NULLS FIRST with ASC + -- Order should be: 6,7,8 (null), 1,5 (priority=10), 2,4 (priority=20), 3 (priority=30) + -- ========================================================================== + select jsonb_pretty( + graphql.resolve($$ + { + itemsCollection( + orderBy: [{priority: AscNullsFirst}, {id: AscNullsLast}] + ) { + edges { + node { + id + priority + name + } + } + } + } + $$) + ); + jsonb_pretty +----------------------------------------------- + { + + "data": { + + "itemsCollection": { + + "edges": [ + + { + + "node": { + + "id": 6, + + "name": "unset-a", + + "priority": null + + } + + }, + + { + + "node": { + + "id": 7, + + "name": "unset-b", + + "priority": null + + } + + }, + + { + + "node": { + + "id": 8, + + "name": "unset-c", + + "priority": null + + } + + }, + + { + + "node": { + + "id": 1, + + "name": "low", + + "priority": 10 + + } + + }, + + { + + "node": { + + "id": 5, + + "name": "low-alt", + + "priority": 10 + + } + + }, + + { + + "node": { + + "id": 2, + + "name": "medium", + + "priority": 20 + + } + + }, + + { + + "node": { + + "id": 4, + + "name": "medium-alt",+ + "priority": 20 + + } + + }, + + { + + "node": { + + "id": 3, + + "name": "high", + + "priority": 30 + + } + + } + + ] + + } + + } + + } +(1 row) + + -- ========================================================================== + -- Test 3: Cursor pagination with NULLS LAST - first 3, then next page + -- Should get: first page [1,5,2], then from cursor after id=2 get [4,3,6] + -- ========================================================================== + -- First page + select jsonb_pretty( + graphql.resolve($$ + { + itemsCollection( + first: 3 + orderBy: [{priority: AscNullsLast}, {id: AscNullsLast}] + ) { + edges { + cursor + node { + id + priority + } + } + } + } + $$) + ); + jsonb_pretty +-------------------------------------------------- + { + + "data": { + + "itemsCollection": { + + "edges": [ + + { + + "node": { + + "id": 1, + + "priority": 10 + + }, + + "cursor": "WzEwLCAxLCAxXQ=="+ + }, + + { 
+ + "node": { + + "id": 5, + + "priority": 10 + + }, + + "cursor": "WzEwLCA1LCA1XQ=="+ + }, + + { + + "node": { + + "id": 2, + + "priority": 20 + + }, + + "cursor": "WzIwLCAyLCAyXQ=="+ + } + + ] + + } + + } + + } +(1 row) + + -- Next page using cursor [20, 2] (priority=20, id=2) + select jsonb_pretty( + graphql.resolve($$ + query NextPage($afterCursor: Cursor) { + itemsCollection( + first: 3 + after: $afterCursor + orderBy: [{priority: AscNullsLast}, {id: AscNullsLast}] + ) { + edges { + node { + id + priority + } + } + } + } + $$, + jsonb_build_object('afterCursor', graphql.encode('[20, 2]'::jsonb)) + ) + ); + jsonb_pretty +------------------------------------------ + { + + "data": { + + "itemsCollection": { + + "edges": [ + + { + + "node": { + + "id": 4, + + "priority": 20 + + } + + }, + + { + + "node": { + + "id": 3, + + "priority": 30 + + } + + }, + + { + + "node": { + + "id": 6, + + "priority": null+ + } + + } + + ] + + } + + } + + } +(1 row) + + -- ========================================================================== + -- Test 4: Cursor pagination with NULLS FIRST - first 3, then next page + -- Should get: first page [6,7,8], then from cursor after id=8 get [1,5,2] + -- ========================================================================== + -- First page (nulls come first) + select jsonb_pretty( + graphql.resolve($$ + { + itemsCollection( + first: 3 + orderBy: [{priority: AscNullsFirst}, {id: AscNullsLast}] + ) { + edges { + cursor + node { + id + priority + } + } + } + } + $$) + ); + jsonb_pretty +-------------------------------------------------- + { + + "data": { + + "itemsCollection": { + + "edges": [ + + { + + "node": { + + "id": 6, + + "priority": null + + }, + + "cursor": "W251bGwsIDYsIDZd"+ + }, + + { + + "node": { + + "id": 7, + + "priority": null + + }, + + "cursor": "W251bGwsIDcsIDdd"+ + }, + + { + + "node": { + + "id": 8, + + "priority": null + + }, + + "cursor": "W251bGwsIDgsIDhd"+ + } + + ] + + } + + } + + } +(1 row) + + -- Next page using cursor [null, 8] (priority=null, id=8) + select jsonb_pretty( + graphql.resolve($$ + query NextPage($afterCursor: Cursor) { + itemsCollection( + first: 3 + after: $afterCursor + orderBy: [{priority: AscNullsFirst}, {id: AscNullsLast}] + ) { + edges { + node { + id + priority + } + } + } + } + $$, + jsonb_build_object('afterCursor', graphql.encode('[null, 8]'::jsonb)) + ) + ); + jsonb_pretty +---------------------------------------- + { + + "data": { + + "itemsCollection": { + + "edges": [ + + { + + "node": { + + "id": 1, + + "priority": 10+ + } + + }, + + { + + "node": { + + "id": 5, + + "priority": 10+ + } + + }, + + { + + "node": { + + "id": 2, + + "priority": 20+ + } + + } + + ] + + } + + } + + } +(1 row) + + -- ========================================================================== + -- Test 5: Reverse pagination (last/before) with NULLS LAST + -- Get last 3, then previous page + -- ========================================================================== + -- Last 3 items (should be 3, 6, 7, 8 area - reversed from end) + select jsonb_pretty( + graphql.resolve($$ + { + itemsCollection( + last: 3 + orderBy: [{priority: AscNullsLast}, {id: AscNullsLast}] + ) { + edges { + cursor + node { + id + priority + } + } + } + } + $$) + ); + jsonb_pretty +-------------------------------------------------- + { + + "data": { + + "itemsCollection": { + + "edges": [ + + { + + "node": { + + "id": 6, + + "priority": null + + }, + + "cursor": "W251bGwsIDYsIDZd"+ + }, + + { + + "node": { + + "id": 7, + + "priority": null + + }, + + 
"cursor": "W251bGwsIDcsIDdd"+ + }, + + { + + "node": { + + "id": 8, + + "priority": null + + }, + + "cursor": "W251bGwsIDgsIDhd"+ + } + + ] + + } + + } + + } +(1 row) + + -- Previous page using before cursor + select jsonb_pretty( + graphql.resolve($$ + query PrevPage($beforeCursor: Cursor) { + itemsCollection( + last: 3 + before: $beforeCursor + orderBy: [{priority: AscNullsLast}, {id: AscNullsLast}] + ) { + edges { + node { + id + priority + } + } + } + } + $$, + jsonb_build_object('beforeCursor', graphql.encode('[null, 6]'::jsonb)) + ) + ); + jsonb_pretty +---------------------------------------- + { + + "data": { + + "itemsCollection": { + + "edges": [ + + { + + "node": { + + "id": 2, + + "priority": 20+ + } + + }, + + { + + "node": { + + "id": 4, + + "priority": 20+ + } + + }, + + { + + "node": { + + "id": 3, + + "priority": 30+ + } + + } + + ] + + } + + } + + } +(1 row) + + -- ========================================================================== + -- Test 6: Reverse pagination (last/before) with NULLS FIRST + -- ========================================================================== + -- Last 3 items with nulls first ordering + select jsonb_pretty( + graphql.resolve($$ + { + itemsCollection( + last: 3 + orderBy: [{priority: AscNullsFirst}, {id: AscNullsLast}] + ) { + edges { + cursor + node { + id + priority + } + } + } + } + $$) + ); + jsonb_pretty +-------------------------------------------------- + { + + "data": { + + "itemsCollection": { + + "edges": [ + + { + + "node": { + + "id": 2, + + "priority": 20 + + }, + + "cursor": "WzIwLCAyLCAyXQ=="+ + }, + + { + + "node": { + + "id": 4, + + "priority": 20 + + }, + + "cursor": "WzIwLCA0LCA0XQ=="+ + }, + + { + + "node": { + + "id": 3, + + "priority": 30 + + }, + + "cursor": "WzMwLCAzLCAzXQ=="+ + } + + ] + + } + + } + + } +(1 row) + + -- Previous page + select jsonb_pretty( + graphql.resolve($$ + query PrevPage($beforeCursor: Cursor) { + itemsCollection( + last: 3 + before: $beforeCursor + orderBy: [{priority: AscNullsFirst}, {id: AscNullsLast}] + ) { + edges { + node { + id + priority + } + } + } + } + $$, + jsonb_build_object('beforeCursor', graphql.encode('[20, 4]'::jsonb)) + ) + ); + jsonb_pretty +---------------------------------------- + { + + "data": { + + "itemsCollection": { + + "edges": [ + + { + + "node": { + + "id": 1, + + "priority": 10+ + } + + }, + + { + + "node": { + + "id": 5, + + "priority": 10+ + } + + }, + + { + + "node": { + + "id": 2, + + "priority": 20+ + } + + } + + ] + + } + + } + + } +(1 row) + + -- ========================================================================== + -- Test 7: DESC with NULLS FIRST (nulls at start of descending order) + -- Order: 6,7,8 (null), 3 (30), 2,4 (20), 1,5 (10) + -- ========================================================================== + select jsonb_pretty( + graphql.resolve($$ + { + itemsCollection( + orderBy: [{priority: DescNullsFirst}, {id: AscNullsLast}] + ) { + edges { + node { + id + priority + } + } + } + } + $$) + ); + jsonb_pretty +------------------------------------------ + { + + "data": { + + "itemsCollection": { + + "edges": [ + + { + + "node": { + + "id": 6, + + "priority": null+ + } + + }, + + { + + "node": { + + "id": 7, + + "priority": null+ + } + + }, + + { + + "node": { + + "id": 8, + + "priority": null+ + } + + }, + + { + + "node": { + + "id": 3, + + "priority": 30 + + } + + }, + + { + + "node": { + + "id": 2, + + "priority": 20 + + } + + }, + + { + + "node": { + + "id": 4, + + "priority": 20 + + } + + }, + + { + + "node": { + + "id": 1, + + 
"priority": 10 + + } + + }, + + { + + "node": { + + "id": 5, + + "priority": 10 + + } + + } + + ] + + } + + } + + } +(1 row) + + -- Paginate through with cursor + select jsonb_pretty( + graphql.resolve($$ + query NextPage($afterCursor: Cursor) { + itemsCollection( + first: 3 + after: $afterCursor + orderBy: [{priority: DescNullsFirst}, {id: AscNullsLast}] + ) { + edges { + node { + id + priority + } + } + } + } + $$, + jsonb_build_object('afterCursor', graphql.encode('[null, 8]'::jsonb)) + ) + ); + jsonb_pretty +---------------------------------------- + { + + "data": { + + "itemsCollection": { + + "edges": [ + + { + + "node": { + + "id": 3, + + "priority": 30+ + } + + }, + + { + + "node": { + + "id": 2, + + "priority": 20+ + } + + }, + + { + + "node": { + + "id": 4, + + "priority": 20+ + } + + } + + ] + + } + + } + + } +(1 row) + + -- ========================================================================== + -- Test 8: DESC with NULLS LAST (nulls at end of descending order) + -- Order: 3 (30), 2,4 (20), 1,5 (10), 6,7,8 (null) + -- ========================================================================== + select jsonb_pretty( + graphql.resolve($$ + { + itemsCollection( + orderBy: [{priority: DescNullsLast}, {id: AscNullsLast}] + ) { + edges { + node { + id + priority + } + } + } + } + $$) + ); + jsonb_pretty +------------------------------------------ + { + + "data": { + + "itemsCollection": { + + "edges": [ + + { + + "node": { + + "id": 3, + + "priority": 30 + + } + + }, + + { + + "node": { + + "id": 2, + + "priority": 20 + + } + + }, + + { + + "node": { + + "id": 4, + + "priority": 20 + + } + + }, + + { + + "node": { + + "id": 1, + + "priority": 10 + + } + + }, + + { + + "node": { + + "id": 5, + + "priority": 10 + + } + + }, + + { + + "node": { + + "id": 6, + + "priority": null+ + } + + }, + + { + + "node": { + + "id": 7, + + "priority": null+ + } + + }, + + { + + "node": { + + "id": 8, + + "priority": null+ + } + + } + + ] + + } + + } + + } +(1 row) + + -- Paginate - after priority=20, id=4, should get 1,5 then nulls + select jsonb_pretty( + graphql.resolve($$ + query NextPage($afterCursor: Cursor) { + itemsCollection( + first: 4 + after: $afterCursor + orderBy: [{priority: DescNullsLast}, {id: AscNullsLast}] + ) { + edges { + node { + id + priority + } + } + } + } + $$, + jsonb_build_object('afterCursor', graphql.encode('[20, 4]'::jsonb)) + ) + ); + jsonb_pretty +------------------------------------------ + { + + "data": { + + "itemsCollection": { + + "edges": [ + + { + + "node": { + + "id": 1, + + "priority": 10 + + } + + }, + + { + + "node": { + + "id": 5, + + "priority": 10 + + } + + }, + + { + + "node": { + + "id": 6, + + "priority": null+ + } + + }, + + { + + "node": { + + "id": 7, + + "priority": null+ + } + + } + + ] + + } + + } + + } +(1 row) + + -- ========================================================================== + -- Test 9: Edge case - cursor at NULL value boundary (transitioning from null to non-null) + -- With NULLS FIRST, cursor at last null should give first non-null items + -- ========================================================================== + select jsonb_pretty( + graphql.resolve($$ + query AfterLastNull($afterCursor: Cursor) { + itemsCollection( + first: 2 + after: $afterCursor + orderBy: [{priority: AscNullsFirst}, {id: AscNullsLast}] + ) { + edges { + node { + id + priority + name + } + } + } + } + $$, + jsonb_build_object('afterCursor', graphql.encode('[null, 8]'::jsonb)) + ) + ); + jsonb_pretty +-------------------------------------------- + 
{ + + "data": { + + "itemsCollection": { + + "edges": [ + + { + + "node": { + + "id": 1, + + "name": "low", + + "priority": 10 + + } + + }, + + { + + "node": { + + "id": 5, + + "name": "low-alt",+ + "priority": 10 + + } + + } + + ] + + } + + } + + } +(1 row) + + -- ========================================================================== + -- Test 10: Edge case - cursor at non-NULL value boundary (transitioning to nulls) + -- With NULLS LAST, cursor at last non-null should give null items + -- ========================================================================== + select jsonb_pretty( + graphql.resolve($$ + query AfterLastNonNull($afterCursor: Cursor) { + itemsCollection( + first: 3 + after: $afterCursor + orderBy: [{priority: AscNullsLast}, {id: AscNullsLast}] + ) { + edges { + node { + id + priority + name + } + } + } + } + $$, + jsonb_build_object('afterCursor', graphql.encode('[30, 3]'::jsonb)) + ) + ); + jsonb_pretty +-------------------------------------------- + { + + "data": { + + "itemsCollection": { + + "edges": [ + + { + + "node": { + + "id": 6, + + "name": "unset-a",+ + "priority": null + + } + + }, + + { + + "node": { + + "id": 7, + + "name": "unset-b",+ + "priority": null + + } + + }, + + { + + "node": { + + "id": 8, + + "name": "unset-c",+ + "priority": null + + } + + } + + ] + + } + + } + + } +(1 row) + +rollback; diff --git a/test/expected/resolve_connection_pagination_args.out b/test/expected/resolve_connection_pagination_args.out index c5fd87af..bd62c69a 100644 --- a/test/expected/resolve_connection_pagination_args.out +++ b/test/expected/resolve_connection_pagination_args.out @@ -582,13 +582,6 @@ begin; "data": { + "blogCollection": { + "edges": [ + - { + - "node": { + - "id": 14, + - "title": "b", + - "reversed": 1 + - } + - }, + { + "node": { + "id": 3, + @@ -616,6 +609,13 @@ begin; "title": "b", + "reversed": 2 + } + + }, + + { + + "node": { + + "id": 12, + + "title": null,+ + "reversed": 3 + + } + } + ] + } + diff --git a/test/sql/cursor_pagination_nulls.sql b/test/sql/cursor_pagination_nulls.sql new file mode 100644 index 00000000..a173544e --- /dev/null +++ b/test/sql/cursor_pagination_nulls.sql @@ -0,0 +1,405 @@ +-- Test cursor pagination with nulls first/last ordering +-- This test verifies correct behavior when paginating through data with NULL values +-- using different null ordering strategies (NULLS FIRST vs NULLS LAST) + +begin; + comment on schema public is '@graphql({"inflect_names": false})'; + + create table items( + id int primary key, + priority int, -- nullable column for testing null ordering + name text + ); + + -- Insert test data with strategic NULL placement + -- IDs 1-5: have priority values + -- IDs 6-8: NULL priority (to test null ordering) + insert into items(id, priority, name) values + (1, 10, 'low'), + (2, 20, 'medium'), + (3, 30, 'high'), + (4, 20, 'medium-alt'), -- duplicate priority + (5, 10, 'low-alt'), -- duplicate priority + (6, null, 'unset-a'), + (7, null, 'unset-b'), + (8, null, 'unset-c'); + + -- Show the data for reference + select * from items order by id; + + -- ========================================================================== + -- Test 1: NULLS LAST with ASC (default behavior) + -- Order should be: 1,5 (priority=10), 2,4 (priority=20), 3 (priority=30), 6,7,8 (null) + -- ========================================================================== + + select jsonb_pretty( + graphql.resolve($$ + { + itemsCollection( + orderBy: [{priority: AscNullsLast}, {id: AscNullsLast}] + ) { + edges { + node { + id + priority + 
name + } + } + } + } + $$) + ); + + -- ========================================================================== + -- Test 2: NULLS FIRST with ASC + -- Order should be: 6,7,8 (null), 1,5 (priority=10), 2,4 (priority=20), 3 (priority=30) + -- ========================================================================== + + select jsonb_pretty( + graphql.resolve($$ + { + itemsCollection( + orderBy: [{priority: AscNullsFirst}, {id: AscNullsLast}] + ) { + edges { + node { + id + priority + name + } + } + } + } + $$) + ); + + -- ========================================================================== + -- Test 3: Cursor pagination with NULLS LAST - first 3, then next page + -- Should get: first page [1,5,2], then from cursor after id=2 get [4,3,6] + -- ========================================================================== + + -- First page + select jsonb_pretty( + graphql.resolve($$ + { + itemsCollection( + first: 3 + orderBy: [{priority: AscNullsLast}, {id: AscNullsLast}] + ) { + edges { + cursor + node { + id + priority + } + } + } + } + $$) + ); + + -- Next page using cursor [20, 2] (priority=20, id=2) + select jsonb_pretty( + graphql.resolve($$ + query NextPage($afterCursor: Cursor) { + itemsCollection( + first: 3 + after: $afterCursor + orderBy: [{priority: AscNullsLast}, {id: AscNullsLast}] + ) { + edges { + node { + id + priority + } + } + } + } + $$, + jsonb_build_object('afterCursor', graphql.encode('[20, 2]'::jsonb)) + ) + ); + + -- ========================================================================== + -- Test 4: Cursor pagination with NULLS FIRST - first 3, then next page + -- Should get: first page [6,7,8], then from cursor after id=8 get [1,5,2] + -- ========================================================================== + + -- First page (nulls come first) + select jsonb_pretty( + graphql.resolve($$ + { + itemsCollection( + first: 3 + orderBy: [{priority: AscNullsFirst}, {id: AscNullsLast}] + ) { + edges { + cursor + node { + id + priority + } + } + } + } + $$) + ); + + -- Next page using cursor [null, 8] (priority=null, id=8) + select jsonb_pretty( + graphql.resolve($$ + query NextPage($afterCursor: Cursor) { + itemsCollection( + first: 3 + after: $afterCursor + orderBy: [{priority: AscNullsFirst}, {id: AscNullsLast}] + ) { + edges { + node { + id + priority + } + } + } + } + $$, + jsonb_build_object('afterCursor', graphql.encode('[null, 8]'::jsonb)) + ) + ); + + -- ========================================================================== + -- Test 5: Reverse pagination (last/before) with NULLS LAST + -- Get last 3, then previous page + -- ========================================================================== + + -- Last 3 items (should be 3, 6, 7, 8 area - reversed from end) + select jsonb_pretty( + graphql.resolve($$ + { + itemsCollection( + last: 3 + orderBy: [{priority: AscNullsLast}, {id: AscNullsLast}] + ) { + edges { + cursor + node { + id + priority + } + } + } + } + $$) + ); + + -- Previous page using before cursor + select jsonb_pretty( + graphql.resolve($$ + query PrevPage($beforeCursor: Cursor) { + itemsCollection( + last: 3 + before: $beforeCursor + orderBy: [{priority: AscNullsLast}, {id: AscNullsLast}] + ) { + edges { + node { + id + priority + } + } + } + } + $$, + jsonb_build_object('beforeCursor', graphql.encode('[null, 6]'::jsonb)) + ) + ); + + -- ========================================================================== + -- Test 6: Reverse pagination (last/before) with NULLS FIRST + -- 
========================================================================== + + -- Last 3 items with nulls first ordering + select jsonb_pretty( + graphql.resolve($$ + { + itemsCollection( + last: 3 + orderBy: [{priority: AscNullsFirst}, {id: AscNullsLast}] + ) { + edges { + cursor + node { + id + priority + } + } + } + } + $$) + ); + + -- Previous page + select jsonb_pretty( + graphql.resolve($$ + query PrevPage($beforeCursor: Cursor) { + itemsCollection( + last: 3 + before: $beforeCursor + orderBy: [{priority: AscNullsFirst}, {id: AscNullsLast}] + ) { + edges { + node { + id + priority + } + } + } + } + $$, + jsonb_build_object('beforeCursor', graphql.encode('[20, 4]'::jsonb)) + ) + ); + + -- ========================================================================== + -- Test 7: DESC with NULLS FIRST (nulls at start of descending order) + -- Order: 6,7,8 (null), 3 (30), 2,4 (20), 1,5 (10) + -- ========================================================================== + + select jsonb_pretty( + graphql.resolve($$ + { + itemsCollection( + orderBy: [{priority: DescNullsFirst}, {id: AscNullsLast}] + ) { + edges { + node { + id + priority + } + } + } + } + $$) + ); + + -- Paginate through with cursor + select jsonb_pretty( + graphql.resolve($$ + query NextPage($afterCursor: Cursor) { + itemsCollection( + first: 3 + after: $afterCursor + orderBy: [{priority: DescNullsFirst}, {id: AscNullsLast}] + ) { + edges { + node { + id + priority + } + } + } + } + $$, + jsonb_build_object('afterCursor', graphql.encode('[null, 8]'::jsonb)) + ) + ); + + -- ========================================================================== + -- Test 8: DESC with NULLS LAST (nulls at end of descending order) + -- Order: 3 (30), 2,4 (20), 1,5 (10), 6,7,8 (null) + -- ========================================================================== + + select jsonb_pretty( + graphql.resolve($$ + { + itemsCollection( + orderBy: [{priority: DescNullsLast}, {id: AscNullsLast}] + ) { + edges { + node { + id + priority + } + } + } + } + $$) + ); + + -- Paginate - after priority=20, id=4, should get 1,5 then nulls + select jsonb_pretty( + graphql.resolve($$ + query NextPage($afterCursor: Cursor) { + itemsCollection( + first: 4 + after: $afterCursor + orderBy: [{priority: DescNullsLast}, {id: AscNullsLast}] + ) { + edges { + node { + id + priority + } + } + } + } + $$, + jsonb_build_object('afterCursor', graphql.encode('[20, 4]'::jsonb)) + ) + ); + + -- ========================================================================== + -- Test 9: Edge case - cursor at NULL value boundary (transitioning from null to non-null) + -- With NULLS FIRST, cursor at last null should give first non-null items + -- ========================================================================== + + select jsonb_pretty( + graphql.resolve($$ + query AfterLastNull($afterCursor: Cursor) { + itemsCollection( + first: 2 + after: $afterCursor + orderBy: [{priority: AscNullsFirst}, {id: AscNullsLast}] + ) { + edges { + node { + id + priority + name + } + } + } + } + $$, + jsonb_build_object('afterCursor', graphql.encode('[null, 8]'::jsonb)) + ) + ); + + -- ========================================================================== + -- Test 10: Edge case - cursor at non-NULL value boundary (transitioning to nulls) + -- With NULLS LAST, cursor at last non-null should give null items + -- ========================================================================== + + select jsonb_pretty( + graphql.resolve($$ + query AfterLastNonNull($afterCursor: Cursor) { + 
itemsCollection( + first: 3 + after: $afterCursor + orderBy: [{priority: AscNullsLast}, {id: AscNullsLast}] + ) { + edges { + node { + id + priority + name + } + } + } + } + $$, + jsonb_build_object('afterCursor', graphql.encode('[30, 3]'::jsonb)) + ) + ); + +rollback;