Skip to content

Commit f47f283

Browse files
authored
Group UNION DISTINCT selects when followed by UNION ALL (#116)
1 parent 2ac2ff1 commit f47f283

File tree

31 files changed

+285
-141
lines changed

31 files changed

+285
-141
lines changed

ast/ast.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -637,6 +637,7 @@ func (p *Projection) End() token.Position { return p.Position }
637637
// ProjectionSelectQuery represents the SELECT part of a projection.
638638
type ProjectionSelectQuery struct {
639639
Position token.Position `json:"-"`
640+
With []Expression `json:"with,omitempty"` // WITH clause expressions
640641
Columns []Expression `json:"columns"`
641642
GroupBy []Expression `json:"group_by,omitempty"`
642643
OrderBy []Expression `json:"order_by,omitempty"` // ORDER BY columns
@@ -700,6 +701,7 @@ const (
700701
AlterModifyOrderBy AlterCommandType = "MODIFY_ORDER_BY"
701702
AlterModifySampleBy AlterCommandType = "MODIFY_SAMPLE_BY"
702703
AlterRemoveSampleBy AlterCommandType = "REMOVE_SAMPLE_BY"
704+
AlterApplyDeletedMask AlterCommandType = "APPLY_DELETED_MASK"
703705
)
704706

705707
// TruncateQuery represents a TRUNCATE statement.
@@ -983,6 +985,7 @@ type RenameQuery struct {
983985
To string `json:"to,omitempty"` // Deprecated: for backward compat
984986
OnCluster string `json:"on_cluster,omitempty"`
985987
Settings []*SettingExpr `json:"settings,omitempty"`
988+
IfExists bool `json:"if_exists,omitempty"` // IF EXISTS modifier
986989
}
987990

988991
func (r *RenameQuery) Pos() token.Position { return r.Position }

internal/explain/expressions.go

Lines changed: 85 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -211,6 +211,31 @@ func isSimpleLiteralOrNegation(e ast.Expression) bool {
211211
return false
212212
}
213213

214+
// isSimpleLiteralOrNestedLiteral checks if an expression is a literal (including nested tuples/arrays of literals)
215+
// Returns false for complex expressions like subqueries, function calls, identifiers, etc.
216+
func isSimpleLiteralOrNestedLiteral(e ast.Expression) bool {
217+
if lit, ok := e.(*ast.Literal); ok {
218+
// For nested arrays/tuples, recursively check if all elements are also literals
219+
if lit.Type == ast.LiteralArray || lit.Type == ast.LiteralTuple {
220+
if exprs, ok := lit.Value.([]ast.Expression); ok {
221+
for _, elem := range exprs {
222+
if !isSimpleLiteralOrNestedLiteral(elem) {
223+
return false
224+
}
225+
}
226+
}
227+
}
228+
return true
229+
}
230+
// Unary minus of a literal integer/float is also simple (negative number)
231+
if unary, ok := e.(*ast.UnaryExpr); ok && unary.Op == "-" {
232+
if lit, ok := unary.Operand.(*ast.Literal); ok {
233+
return lit.Type == ast.LiteralInteger || lit.Type == ast.LiteralFloat
234+
}
235+
}
236+
return false
237+
}
238+
214239
// containsOnlyArraysOrTuples checks if a slice of expressions contains
215240
// only array or tuple literals (including empty arrays).
216241
// Returns true if the slice is empty or contains only arrays/tuples.
@@ -952,16 +977,39 @@ func explainWithElement(sb *strings.Builder, n *ast.WithElement, indent string,
952977
// When name is empty, don't show the alias part
953978
switch e := n.Query.(type) {
954979
case *ast.Literal:
955-
// Empty tuples should be rendered as Function tuple, not Literal
980+
// Tuples containing complex expressions (subqueries, function calls, etc) should be rendered as Function tuple
981+
// But tuples of simple literals (including nested tuples of literals) stay as Literal
956982
if e.Type == ast.LiteralTuple {
957-
if exprs, ok := e.Value.([]ast.Expression); ok && len(exprs) == 0 {
958-
if n.Name != "" {
959-
fmt.Fprintf(sb, "%sFunction tuple (alias %s) (children %d)\n", indent, n.Name, 1)
983+
if exprs, ok := e.Value.([]ast.Expression); ok {
984+
needsFunctionFormat := false
985+
// Empty tuples always use Function tuple format
986+
if len(exprs) == 0 {
987+
needsFunctionFormat = true
960988
} else {
961-
fmt.Fprintf(sb, "%sFunction tuple (children %d)\n", indent, 1)
989+
for _, expr := range exprs {
990+
// Check if any element is a truly complex expression (not just a literal)
991+
if !isSimpleLiteralOrNestedLiteral(expr) {
992+
needsFunctionFormat = true
993+
break
994+
}
995+
}
996+
}
997+
if needsFunctionFormat {
998+
if n.Name != "" {
999+
fmt.Fprintf(sb, "%sFunction tuple (alias %s) (children %d)\n", indent, n.Name, 1)
1000+
} else {
1001+
fmt.Fprintf(sb, "%sFunction tuple (children %d)\n", indent, 1)
1002+
}
1003+
if len(exprs) > 0 {
1004+
fmt.Fprintf(sb, "%s ExpressionList (children %d)\n", indent, len(exprs))
1005+
} else {
1006+
fmt.Fprintf(sb, "%s ExpressionList\n", indent)
1007+
}
1008+
for _, expr := range exprs {
1009+
Node(sb, expr, depth+2)
1010+
}
1011+
return
9621012
}
963-
fmt.Fprintf(sb, "%s ExpressionList\n", indent)
964-
return
9651013
}
9661014
}
9671015
// Arrays containing non-literal expressions should be rendered as Function array
@@ -1064,6 +1112,36 @@ func explainWithElement(sb *strings.Builder, n *ast.WithElement, indent string,
10641112
explainArrayAccessWithAlias(sb, e, n.Name, indent, depth)
10651113
case *ast.BetweenExpr:
10661114
explainBetweenExprWithAlias(sb, e, n.Name, indent, depth)
1115+
case *ast.UnaryExpr:
1116+
// For unary minus with numeric literal, output as negative literal with alias
1117+
if e.Op == "-" {
1118+
if lit, ok := e.Operand.(*ast.Literal); ok && (lit.Type == ast.LiteralInteger || lit.Type == ast.LiteralFloat) {
1119+
// Format as negative literal
1120+
negLit := &ast.Literal{
1121+
Position: lit.Position,
1122+
Type: lit.Type,
1123+
Value: lit.Value,
1124+
}
1125+
if n.Name != "" {
1126+
fmt.Fprintf(sb, "%sLiteral %s (alias %s)\n", indent, formatNegativeLiteral(negLit), n.Name)
1127+
} else {
1128+
fmt.Fprintf(sb, "%sLiteral %s\n", indent, formatNegativeLiteral(negLit))
1129+
}
1130+
return
1131+
}
1132+
}
1133+
// For other unary expressions, output as function
1134+
fnName := "negate"
1135+
if e.Op == "NOT" {
1136+
fnName = "not"
1137+
}
1138+
if n.Name != "" {
1139+
fmt.Fprintf(sb, "%sFunction %s (alias %s) (children %d)\n", indent, fnName, n.Name, 1)
1140+
} else {
1141+
fmt.Fprintf(sb, "%sFunction %s (children %d)\n", indent, fnName, 1)
1142+
}
1143+
fmt.Fprintf(sb, "%s ExpressionList (children %d)\n", indent, 1)
1144+
Node(sb, e.Operand, depth+2)
10671145
default:
10681146
// For other types, just output the expression (alias may be lost)
10691147
Node(sb, n.Query, depth)

internal/explain/format.go

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -144,6 +144,26 @@ func FormatLiteral(lit *ast.Literal) string {
144144
}
145145
}
146146

147+
// formatNegativeLiteral formats a numeric literal with a negative sign prepended
148+
func formatNegativeLiteral(lit *ast.Literal) string {
149+
switch lit.Type {
150+
case ast.LiteralInteger:
151+
switch val := lit.Value.(type) {
152+
case int64:
153+
return fmt.Sprintf("Int64_-%d", val)
154+
case uint64:
155+
return fmt.Sprintf("Int64_-%d", val)
156+
default:
157+
return fmt.Sprintf("Int64_-%v", lit.Value)
158+
}
159+
case ast.LiteralFloat:
160+
val := lit.Value.(float64)
161+
return fmt.Sprintf("Float64_-%s", FormatFloat(val))
162+
default:
163+
return fmt.Sprintf("-%v", lit.Value)
164+
}
165+
}
166+
147167
// formatArrayLiteral formats an array literal for EXPLAIN AST output
148168
func formatArrayLiteral(val interface{}) string {
149169
exprs, ok := val.([]ast.Expression)

internal/explain/functions.go

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1133,7 +1133,12 @@ func explainInExpr(sb *strings.Builder, n *ast.InExpr, indent string, depth int)
11331133
fmt.Fprintf(sb, "%s Function tuple (children %d)\n", indent, 1)
11341134
if allParenthesizedPrimitives {
11351135
// Expand the elements
1136-
fmt.Fprintf(sb, "%s ExpressionList (children %d)\n", indent, len(elems))
1136+
// For empty tuples, don't include children count
1137+
if len(elems) == 0 {
1138+
fmt.Fprintf(sb, "%s ExpressionList\n", indent)
1139+
} else {
1140+
fmt.Fprintf(sb, "%s ExpressionList (children %d)\n", indent, len(elems))
1141+
}
11371142
for _, elem := range elems {
11381143
Node(sb, elem, depth+4)
11391144
}

internal/explain/select.go

Lines changed: 63 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -298,16 +298,21 @@ func explainSelectWithUnionQuery(sb *strings.Builder, n *ast.SelectWithUnionQuer
298298
// ClickHouse optimizes UNION ALL when selects have identical expressions but different aliases.
299299
// In that case, only the first SELECT is shown since column names come from the first SELECT anyway.
300300
selects := simplifyUnionSelects(n.Selects)
301+
302+
// Check if we need to group selects due to mode changes
303+
// e.g., A UNION DISTINCT B UNION ALL C -> (A UNION DISTINCT B) UNION ALL C
304+
groupedSelects := groupSelectsByUnionMode(selects, n.UnionModes)
305+
301306
// Wrap selects in ExpressionList
302-
fmt.Fprintf(sb, "%s ExpressionList (children %d)\n", indent, len(selects))
307+
fmt.Fprintf(sb, "%s ExpressionList (children %d)\n", indent, len(groupedSelects))
303308

304309
// Check if first operand has a WITH clause to be inherited by subsequent operands
305310
var inheritedWith []ast.Expression
306311
if len(selects) > 0 {
307312
inheritedWith = extractWithClause(selects[0])
308313
}
309314

310-
for i, sel := range selects {
315+
for i, sel := range groupedSelects {
311316
if i > 0 && len(inheritedWith) > 0 {
312317
// Subsequent operands inherit the WITH clause from the first operand
313318
explainSelectQueryWithInheritedWith(sb, sel, inheritedWith, depth+2)
@@ -620,6 +625,62 @@ func simplifyUnionSelects(selects []ast.Statement) []ast.Statement {
620625
return selects
621626
}
622627

628+
// groupSelectsByUnionMode groups selects when union modes change from DISTINCT to ALL.
629+
// For example, A UNION DISTINCT B UNION ALL C becomes (A UNION DISTINCT B) UNION ALL C.
630+
// This matches ClickHouse's EXPLAIN AST output which nests DISTINCT groups before ALL.
631+
// Note: The reverse (ALL followed by DISTINCT) does NOT trigger nesting.
632+
func groupSelectsByUnionMode(selects []ast.Statement, unionModes []string) []ast.Statement {
633+
if len(selects) < 3 || len(unionModes) < 2 {
634+
return selects
635+
}
636+
637+
// Normalize union modes (strip "UNION " prefix if present)
638+
normalizeMode := func(mode string) string {
639+
if len(mode) > 6 && mode[:6] == "UNION " {
640+
return mode[6:]
641+
}
642+
return mode
643+
}
644+
645+
// Only group when DISTINCT transitions to ALL
646+
// Find first DISTINCT mode, then check if it's followed by ALL
647+
firstMode := normalizeMode(unionModes[0])
648+
if firstMode != "DISTINCT" {
649+
return selects
650+
}
651+
652+
// Find where DISTINCT ends and ALL begins
653+
modeChangeIdx := -1
654+
for i := 1; i < len(unionModes); i++ {
655+
if normalizeMode(unionModes[i]) == "ALL" {
656+
modeChangeIdx = i
657+
break
658+
}
659+
}
660+
661+
// If no DISTINCT->ALL transition found, return as-is
662+
if modeChangeIdx == -1 {
663+
return selects
664+
}
665+
666+
// Create a nested SelectWithUnionQuery for selects 0..modeChangeIdx (inclusive)
667+
// modeChangeIdx is the index of the union operator, so we include selects[0] through selects[modeChangeIdx]
668+
nestedSelects := selects[:modeChangeIdx+1]
669+
nestedModes := unionModes[:modeChangeIdx]
670+
671+
nested := &ast.SelectWithUnionQuery{
672+
Selects: nestedSelects,
673+
UnionModes: nestedModes,
674+
}
675+
676+
// Result is [nested, selects[modeChangeIdx+1], ...]
677+
result := make([]ast.Statement, 0, len(selects)-modeChangeIdx)
678+
result = append(result, nested)
679+
result = append(result, selects[modeChangeIdx+1:]...)
680+
681+
return result
682+
}
683+
623684
func countSelectQueryChildren(n *ast.SelectQuery) int {
624685
count := 1 // columns ExpressionList
625686
// WITH clause

internal/explain/statements.go

Lines changed: 23 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -446,7 +446,13 @@ func explainCreateQuery(sb *strings.Builder, n *ast.CreateQuery, indent string,
446446
if len(n.OrderBy) > 0 {
447447
if len(n.OrderBy) == 1 {
448448
if ident, ok := n.OrderBy[0].(*ast.Identifier); ok {
449-
fmt.Fprintf(sb, "%s Identifier %s\n", storageIndent, ident.Name())
449+
// When ORDER BY has modifiers (ASC/DESC), wrap in StorageOrderByElement
450+
if n.OrderByHasModifiers {
451+
fmt.Fprintf(sb, "%s StorageOrderByElement (children %d)\n", storageIndent, 1)
452+
fmt.Fprintf(sb, "%s Identifier %s\n", storageIndent, ident.Name())
453+
} else {
454+
fmt.Fprintf(sb, "%s Identifier %s\n", storageIndent, ident.Name())
455+
}
450456
} else if lit, ok := n.OrderBy[0].(*ast.Literal); ok && lit.Type == ast.LiteralTuple {
451457
// Handle tuple literal - for ORDER BY with modifiers (DESC/ASC),
452458
// ClickHouse outputs just "Function tuple" without children
@@ -1620,6 +1626,10 @@ func explainAlterCommand(sb *strings.Builder, cmd *ast.AlterCommand, indent stri
16201626
if cmdType == ast.AlterClearStatistics {
16211627
cmdType = ast.AlterDropStatistics
16221628
}
1629+
// ATTACH PARTITION ... FROM table is shown as REPLACE_PARTITION in EXPLAIN AST
1630+
if cmdType == ast.AlterAttachPartition && cmd.FromTable != "" {
1631+
cmdType = ast.AlterReplacePartition
1632+
}
16231633
// DETACH_PARTITION is shown as DROP_PARTITION in EXPLAIN AST
16241634
if cmdType == ast.AlterDetachPartition {
16251635
cmdType = ast.AlterDropPartition
@@ -1802,7 +1812,7 @@ func explainAlterCommand(sb *strings.Builder, cmd *ast.AlterCommand, indent stri
18021812
case ast.AlterModifySetting:
18031813
fmt.Fprintf(sb, "%s Set\n", indent)
18041814
case ast.AlterDropPartition, ast.AlterDetachPartition, ast.AlterAttachPartition,
1805-
ast.AlterReplacePartition, ast.AlterFetchPartition, ast.AlterMovePartition, ast.AlterFreezePartition, ast.AlterApplyPatches:
1815+
ast.AlterReplacePartition, ast.AlterFetchPartition, ast.AlterMovePartition, ast.AlterFreezePartition, ast.AlterApplyPatches, ast.AlterApplyDeletedMask:
18061816
if cmd.Partition != nil {
18071817
// PARTITION ALL is shown as Partition_ID (empty) in EXPLAIN AST
18081818
if ident, ok := cmd.Partition.(*ast.Identifier); ok && strings.ToUpper(ident.Name()) == "ALL" {
@@ -1910,6 +1920,9 @@ func explainProjection(sb *strings.Builder, p *ast.Projection, indent string, de
19101920

19111921
func explainProjectionSelectQuery(sb *strings.Builder, q *ast.ProjectionSelectQuery, indent string, depth int) {
19121922
children := 0
1923+
if len(q.With) > 0 {
1924+
children++
1925+
}
19131926
if len(q.Columns) > 0 {
19141927
children++
19151928
}
@@ -1920,6 +1933,13 @@ func explainProjectionSelectQuery(sb *strings.Builder, q *ast.ProjectionSelectQu
19201933
children++
19211934
}
19221935
fmt.Fprintf(sb, "%sProjectionSelectQuery (children %d)\n", indent, children)
1936+
// Output WITH clause first
1937+
if len(q.With) > 0 {
1938+
fmt.Fprintf(sb, "%s ExpressionList (children %d)\n", indent, len(q.With))
1939+
for _, w := range q.With {
1940+
Node(sb, w, depth+2)
1941+
}
1942+
}
19231943
if len(q.Columns) > 0 {
19241944
fmt.Fprintf(sb, "%s ExpressionList (children %d)\n", indent, len(q.Columns))
19251945
for _, col := range q.Columns {
@@ -2085,7 +2105,7 @@ func countAlterCommandChildren(cmd *ast.AlterCommand) int {
20852105
case ast.AlterModifySetting:
20862106
children = 1
20872107
case ast.AlterDropPartition, ast.AlterDetachPartition, ast.AlterAttachPartition,
2088-
ast.AlterReplacePartition, ast.AlterFetchPartition, ast.AlterMovePartition, ast.AlterFreezePartition, ast.AlterApplyPatches:
2108+
ast.AlterReplacePartition, ast.AlterFetchPartition, ast.AlterMovePartition, ast.AlterFreezePartition, ast.AlterApplyPatches, ast.AlterApplyDeletedMask:
20892109
if cmd.Partition != nil {
20902110
children++
20912111
}

0 commit comments

Comments
 (0)