Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
6a1107e
refactor: Add cte factoring to new compiler
TrevorBergeron Feb 12, 2026
97a5a7f
Merge remote-tracking branch 'github/main' into cte_extract2
TrevorBergeron Feb 27, 2026
2d9dad5
Merge remote-tracking branch 'github/main' into cte_extract2
TrevorBergeron Mar 5, 2026
10a9798
separate logical cte nodes from concrete sql ones
TrevorBergeron Mar 6, 2026
a5283f1
move extract_ctes later in compiler
TrevorBergeron Mar 6, 2026
49ddfdf
fix _as_from_item helper
TrevorBergeron Mar 6, 2026
accf797
fix test issues and update snapshots
TrevorBergeron Mar 6, 2026
1e62fd3
fix cte emitter
TrevorBergeron Mar 6, 2026
e60f3ed
dont create with expression without ctes
TrevorBergeron Mar 6, 2026
ce9fbb9
amend selection compilation
TrevorBergeron Mar 6, 2026
74b2470
fix as_select star logic
TrevorBergeron Mar 6, 2026
069dd09
fix isin join logic
TrevorBergeron Mar 6, 2026
5f2ed0d
redo id remapper
TrevorBergeron Mar 6, 2026
75e02cc
fix id issues preventing cte factoring
TrevorBergeron Mar 7, 2026
2d382c6
refactor ir internals to reduce select count
TrevorBergeron Mar 7, 2026
5717cfc
avoid more select stars
TrevorBergeron Mar 7, 2026
b7d360a
fix identifier tests
TrevorBergeron Mar 7, 2026
10cf398
wrap more nodes in ctes
TrevorBergeron Mar 8, 2026
28865fa
enable experimental compiler for tests
TrevorBergeron Mar 8, 2026
e6043dd
avoid rewrapping nodes, wrapping root
TrevorBergeron Mar 8, 2026
8a5a96e
dont extract union all children as ctes
TrevorBergeron Mar 8, 2026
64d5ce9
fix type, tmp test change
TrevorBergeron Mar 8, 2026
cd5579b
fix isin logic
TrevorBergeron Mar 9, 2026
47a8af3
fix isin struct ics
TrevorBergeron Mar 9, 2026
dd39c51
fix isin w nulls
TrevorBergeron Mar 9, 2026
fecc37a
avoid TypedExpr warpping
TrevorBergeron Mar 9, 2026
1a3d311
reset compiler to stable
TrevorBergeron Mar 9, 2026
e601d1b
appease mypy
TrevorBergeron Mar 9, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 4 additions & 14 deletions bigframes/core/bigframe_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -330,22 +330,12 @@ def top_down(
"""
Perform a top-down transformation of the BigFrameNode tree.
"""
to_process = [self]
results: Dict[BigFrameNode, BigFrameNode] = {}

while to_process:
item = to_process.pop()
if item not in results.keys():
item_result = transform(item)
results[item] = item_result
to_process.extend(item_result.child_nodes)
@functools.cache
def recursive_transform(node: BigFrameNode) -> BigFrameNode:
return transform(node).transform_children(recursive_transform)

to_process = [self]
# for each processed item, replace its children
for item in reversed(list(results.keys())):
results[item] = results[item].transform_children(lambda x: results[x])

return results[self]
return recursive_transform(self)

def bottom_up(
self: BigFrameNode,
Expand Down
31 changes: 28 additions & 3 deletions bigframes/core/compile/sqlglot/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,8 @@ def compile_sql(request: configs.CompileRequest) -> configs.CompileResult:
if request.sort_rows:
result_node = typing.cast(nodes.ResultNode, rewrite.column_pruning(result_node))
encoded_type_refs = data_type_logger.encode_type_refs(result_node)
# TODO: Extract CTEs earlier
result_node = typing.cast(nodes.ResultNode, rewrite.extract_ctes(result_node))
sql = _compile_result_node(result_node)
return configs.CompileResult(
sql,
Expand All @@ -74,6 +76,8 @@ def compile_sql(request: configs.CompileRequest) -> configs.CompileResult:
result_node = dataclasses.replace(result_node, order_by=None)
result_node = typing.cast(nodes.ResultNode, rewrite.column_pruning(result_node))
encoded_type_refs = data_type_logger.encode_type_refs(result_node)
# TODO: Extract CTEs earlier
result_node = typing.cast(nodes.ResultNode, rewrite.extract_ctes(result_node))
sql = _compile_result_node(result_node)
# Return the ordering iff no extra columns are needed to define the row order
if ordering is not None:
Expand All @@ -94,6 +98,7 @@ def _remap_variables(
result_node, _ = rewrite.remap_variables(
node, map(identifiers.ColumnId, uid_gen.get_uid_stream("bfcol_"))
)
result_node.validate_tree()
return typing.cast(nodes.ResultNode, result_node)


Expand All @@ -102,13 +107,16 @@ def _compile_result_node(root: nodes.ResultNode) -> str:
# of nodes using the same generator.
uid_gen = guid.SequentialUIDGenerator()
root = _remap_variables(root, uid_gen)
# Remap variables creates too mayn new
# root = rewrite.select_pullup(root, prefer_source_names=False)
root = typing.cast(nodes.ResultNode, rewrite.defer_selection(root))

# Have to bind schema as the final step before compilation.
# Probably, should defer even further
root = typing.cast(nodes.ResultNode, schema_binding.bind_schema_to_tree(root))

sqlglot_ir_obj = compile_node(rewrite.as_sql_nodes(root), uid_gen)
# TODO: Bake all IDs in tree, stop passing uid_gen to emitters
sqlglot_ir_obj = compile_node(rewrite.as_sql_nodes(root, uid_gen), uid_gen)
return sqlglot_ir_obj.sql


Expand All @@ -121,7 +129,7 @@ def compile_node(
for current_node in list(node.iter_nodes_topo()):
if current_node.child_nodes == ():
# For leaf node, generates a dumpy child to pass the UID generator.
child_results = tuple([sqlglot_ir.SQLGlotIR(uid_gen=uid_gen)])
child_results = tuple([sqlglot_ir.SQLGlotIR.empty(uid_gen=uid_gen)])
else:
# Child nodes should have been compiled in the reverse topological order.
child_results = tuple(
Expand Down Expand Up @@ -256,6 +264,23 @@ def compile_isin_join(
)


@_compile_node.register
def compile_cte_ref_node(node: sql_nodes.SqlCteRefNode, child: sqlglot_ir.SQLGlotIR):
return sqlglot_ir.SQLGlotIR.from_cte_ref(
node.cte_name,
uid_gen=child.uid_gen,
)


@_compile_node.register
def compile_with_ctes_node(
node: sql_nodes.SqlWithCtesNode,
child: sqlglot_ir.SQLGlotIR,
*ctes: sqlglot_ir.SQLGlotIR,
):
return child.with_ctes(tuple(zip(node.cte_names, ctes)))


@_compile_node.register
def compile_concat(
node: nodes.ConcatNode, *children: sqlglot_ir.SQLGlotIR
Expand All @@ -271,7 +296,7 @@ def compile_concat(
]

return sqlglot_ir.SQLGlotIR.from_union(
[child._as_select() for child in children],
[child.expr.as_select_all() for child in children],
output_aliases=output_aliases,
uid_gen=uid_gen,
)
Expand Down
20 changes: 16 additions & 4 deletions bigframes/core/compile/sqlglot/expressions/comparison_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,27 +33,39 @@
@register_unary_op(ops.IsInOp, pass_op=True)
def _(expr: TypedExpr, op: ops.IsInOp) -> sge.Expression:
values = []
# bools are not comparable to non-bools in SQL, so we need to cast the expression to INT64 if the values contain non-bools.
must_upcast_bools = dtypes.is_numeric(expr.dtype, include_bool=False) or any(
dtypes.is_numeric(dtypes.bigframes_type(type(value)), include_bool=False)
for value in op.values
if not _is_null(value)
)
for value in op.values:
if _is_null(value):
continue
dtype = dtypes.bigframes_type(type(value))
if dtypes.can_compare(expr.dtype, dtype):
if must_upcast_bools and dtype == dtypes.BOOL_DTYPE:
value = int(value)
values.append(sge.convert(value))

sg_lexpr: sge.Expression = expr.expr
if expr.dtype == dtypes.BOOL_DTYPE and must_upcast_bools:
sg_lexpr = sge.cast(expr.expr, "INT64")

if op.match_nulls:
contains_nulls = any(_is_null(value) for value in op.values)
if contains_nulls:
if len(values) == 0:
return sge.Is(this=expr.expr, expression=sge.Null())
return sge.Is(this=expr.expr, expression=sge.Null()) | sge.In(
this=expr.expr, expressions=values
return sge.Is(this=sg_lexpr, expression=sge.Null())
return sge.Is(this=sg_lexpr, expression=sge.Null()) | sge.In(
this=sg_lexpr, expressions=values
)

if len(values) == 0:
return sge.convert(False)

return sge.func(
"COALESCE", sge.In(this=expr.expr, expressions=values), sge.convert(False)
"COALESCE", sge.In(this=sg_lexpr, expressions=values), sge.convert(False)
)


Expand Down
Loading
Loading