Skip to content

Commit 2d382c6

Browse files
refactor ir internals to reduce select count
1 parent 75e02cc commit 2d382c6

File tree

42 files changed

+524
-363
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

42 files changed

+524
-363
lines changed

bigframes/core/compile/sqlglot/compiler.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -126,7 +126,7 @@ def compile_node(
126126
for current_node in list(node.iter_nodes_topo()):
127127
if current_node.child_nodes == ():
128128
# For leaf node, generates a dummy child to pass the UID generator.
129-
child_results = tuple([sqlglot_ir.SQLGlotIR(uid_gen=uid_gen)])
129+
child_results = tuple([sqlglot_ir.SQLGlotIR.empty(uid_gen=uid_gen)])
130130
else:
131131
# Child nodes should have been compiled in the reverse topological order.
132132
child_results = tuple(

bigframes/core/compile/sqlglot/sqlglot_ir.py

Lines changed: 119 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
from __future__ import annotations
1616

17+
import abc
1718
import dataclasses
1819
import datetime
1920
import functools
@@ -38,20 +39,107 @@
3839
to_wkt = dumps
3940

4041

42+
class SelectableFragment(abc.ABC):
    """
    Represent a grammar fragment that can be converted to a SELECT or FROM item.

    This is an abstract interface: every conversion below must be implemented
    by a concrete fragment, and the base class itself cannot be instantiated.
    """

    @abc.abstractmethod
    def as_select_all(self) -> sge.Select:
        """Return this fragment as a complete SELECT expression."""
        ...

    @abc.abstractmethod
    def select(self, *items: sge.Expression) -> sge.Select:
        """Return a SELECT expression projecting ``items`` from this fragment."""
        ...

    @abc.abstractmethod
    def as_from_item(self) -> sge.FromItem:
        """Return this fragment as an expression usable as a FROM item."""
        ...
55+
56+
57+
class SelectFragment(SelectableFragment):
    """Fragment that wraps an already-built SELECT expression."""

    def __init__(self, select_expr: sge.Select):
        self.select_expr = select_expr

    def as_select_all(self) -> sge.Select:
        """Return the wrapped SELECT expression as-is."""
        return self.select_expr

    def select(self, *items: sge.Expression) -> sge.Select:
        """Project ``items`` out of the wrapped SELECT, used as a subquery."""
        inner = self.select_expr.subquery()
        projection = sge.Select().select(*items)
        return projection.from_(inner)

    def as_from_item(self) -> sge.FromItem:
        """Return the wrapped SELECT as a subquery."""
        inner = self.select_expr.subquery()
        return inner
69+
70+
71+
class TableFragment(SelectableFragment):
    """Fragment that wraps a table or UNNEST expression placed directly in FROM."""

    def __init__(self, table: sge.Table | sge.Unnest):
        self.table = table

    def as_select_all(self) -> sge.Select:
        """Build a ``SELECT *`` over the wrapped table expression."""
        star_projection = sge.Select().select(sge.Star())
        return star_projection.from_(self.table)

    def select(self, *items: sge.Expression) -> sge.Select:
        """Build a SELECT of ``items`` over the wrapped table expression."""
        projection = sge.Select().select(*items)
        return projection.from_(self.table)

    def as_from_item(self) -> sge.FromItem:
        """Return the wrapped table expression itself; no subquery is created."""
        return self.table
83+
84+
85+
class DeferredSelectFragment(SelectableFragment):
    """
    Fragment whose SELECT is finished by a callable.

    The supplier receives a SELECT that carries only the projection and
    returns the SELECT expression to use.
    """

    def __init__(self, select_supplier: typing.Callable[[sge.Select], sge.Select]):
        self.select_supplier = select_supplier

    def as_select_all(self) -> sge.Select:
        """Hand a ``SELECT *`` to the supplier and return its result."""
        star_projection = sge.Select().select(sge.Star())
        return self.select_supplier(star_projection)

    def select(self, *items: sge.Expression) -> sge.Select:
        """Hand a SELECT of ``items`` to the supplier and return its result."""
        projection = sge.Select().select(*items)
        return self.select_supplier(projection)

    def as_from_item(self) -> sge.FromItem:
        """Return the supplier's ``SELECT *`` result wrapped as a subquery."""
        star_projection = sge.Select().select(sge.Star())
        completed = self.select_supplier(star_projection)
        return completed.subquery()
97+
98+
4199
@dataclasses.dataclass(frozen=True)
42100
class SQLGlotIR:
43101
"""Helper class to build SQLGlot Query and generate SQL string."""
44102

45-
expr: typing.Union[sge.Select, sge.Table] = sg.select()
103+
expr: SelectableFragment
46104
"""The SQLGlot expression representing the query."""
47105

48106
uid_gen: guid.SequentialUIDGenerator = guid.SequentialUIDGenerator()
49107
"""Generator for unique identifiers."""
50108

109+
def __post_init__(self):
    """Check that ``expr`` is a ``SelectableFragment`` instance."""
    fragment = self.expr
    assert isinstance(fragment, SelectableFragment)
111+
51112
@property
52113
def sql(self) -> str:
53114
"""Generate SQL string from the given expression."""
54-
return sql.to_sql(self.expr)
115+
return sql.to_sql(self.expr.as_select_all())
116+
117+
@classmethod
def empty(
    cls, uid_gen: guid.SequentialUIDGenerator = guid.SequentialUIDGenerator()
) -> SQLGlotIR:
    """Create an IR that wraps a bare SELECT built with no arguments."""
    blank_select = sge.select()
    fragment = SelectFragment(blank_select)
    return cls(expr=fragment, uid_gen=uid_gen)
122+
123+
@classmethod
def from_expr(
    cls,
    expr: sge.Expression,
    uid_gen: guid.SequentialUIDGenerator = guid.SequentialUIDGenerator(),
) -> SQLGlotIR:
    """Create an IR from a SQLGlot expression.

    A ``Select`` is wrapped in a ``SelectFragment``; a ``Table`` or ``Unnest``
    is wrapped in a ``TableFragment``.

    Raises:
        ValueError: If ``expr`` is none of the supported expression types.
    """
    fragment: SelectableFragment
    if isinstance(expr, sge.Select):
        fragment = SelectFragment(expr)
    elif isinstance(expr, (sge.Table, sge.Unnest)):
        fragment = TableFragment(expr)
    else:
        raise ValueError(f"Unsupported expression type: {type(expr)}")
    return cls(expr=fragment, uid_gen=uid_gen)
135+
136+
@classmethod
def from_func(
    cls,
    select_handler: typing.Callable[[sge.Select], sge.Select],
    uid_gen: guid.SequentialUIDGenerator = guid.SequentialUIDGenerator(),
) -> SQLGlotIR:
    """Create an IR whose SELECT is completed by ``select_handler``.

    The handler is stored in a ``DeferredSelectFragment``. It is called with a
    SELECT that carries only the projection and returns the SELECT to use.
    """
    return cls(expr=DeferredSelectFragment(select_handler), uid_gen=uid_gen)
55143

56144
@classmethod
57145
def from_pyarrow(
@@ -97,7 +185,7 @@ def from_pyarrow(
97185
),
98186
],
99187
)
100-
return cls(expr=sg.select(sge.Star()).from_(expr), uid_gen=uid_gen)
188+
return cls.from_expr(expr=expr, uid_gen=uid_gen)
101189

102190
@classmethod
103191
def from_table(
@@ -143,9 +231,9 @@ def from_table(
143231
select_expr = select_expr.where(
144232
sg.parse_one(sql_predicate, dialect=sql.base.DIALECT), append=False
145233
)
146-
return cls(expr=select_expr, uid_gen=uid_gen)
234+
return cls.from_expr(expr=select_expr, uid_gen=uid_gen)
147235

148-
return cls(expr=table_expr, uid_gen=uid_gen)
236+
return cls.from_expr(expr=table_expr, uid_gen=uid_gen)
149237

150238
@classmethod
151239
def from_cte_ref(
@@ -156,7 +244,7 @@ def from_cte_ref(
156244
table_expr = sge.Table(
157245
this=sql.identifier(cte_ref),
158246
)
159-
return cls(expr=table_expr, uid_gen=uid_gen)
247+
return cls.from_expr(expr=table_expr, uid_gen=uid_gen)
160248

161249
def select(
162250
self,
@@ -191,7 +279,7 @@ def select(
191279
if limit is not None:
192280
new_expr = new_expr.limit(limit)
193281

194-
return SQLGlotIR(expr=new_expr, uid_gen=self.uid_gen)
282+
return SQLGlotIR.from_expr(expr=new_expr, uid_gen=self.uid_gen)
195283

196284
@classmethod
197285
def from_unparsed_query(
@@ -209,7 +297,7 @@ def from_unparsed_query(
209297
)
210298
select_expr = sge.Select().select(sge.Star()).from_(sge.Table(this=cte_name))
211299
select_expr = _set_query_ctes(select_expr, [cte])
212-
return cls(expr=select_expr, uid_gen=uid_gen)
300+
return cls.from_expr(expr=select_expr, uid_gen=uid_gen)
213301

214302
@classmethod
215303
def from_union(
@@ -241,7 +329,7 @@ def from_union(
241329
final_select_expr = (
242330
sge.Select().select(*selections).from_(union_expr.subquery())
243331
)
244-
return cls(expr=final_select_expr, uid_gen=uid_gen)
332+
return cls.from_expr(expr=final_select_expr, uid_gen=uid_gen)
245333

246334
def join(
247335
self,
@@ -262,15 +350,13 @@ def join(
262350
)
263351

264352
join_type_str = join_type if join_type != "outer" else "full outer"
265-
new_expr = (
266-
sge.Select()
267-
.select(sge.Star())
268-
.from_(left_from)
269-
.join(right_from, on=join_on, join_type=join_type_str)
353+
return SQLGlotIR.from_func(
354+
lambda select: select.from_(left_from).join(
355+
right_from, on=join_on, join_type=join_type_str
356+
),
357+
uid_gen=self.uid_gen,
270358
)
271359

272-
return SQLGlotIR(expr=new_expr, uid_gen=self.uid_gen)
273-
274360
def isin_join(
275361
self,
276362
right: SQLGlotIR,
@@ -280,7 +366,6 @@ def isin_join(
280366
) -> SQLGlotIR:
281367
"""Joins the current query with another SQLGlotIR instance."""
282368
left_from = self._as_from_item()
283-
right_select = right._as_select()
284369

285370
new_column: sge.Expression
286371
if joins_nulls:
@@ -294,7 +379,7 @@ def isin_join(
294379
]
295380
)
296381
right_expr1, right_expr2 = _value_to_non_null_identity(conditions[1])
297-
right_select = right_select.select(
382+
right_select = right.expr.select(
298383
*[
299384
sge.Struct(
300385
expressions=[
@@ -303,7 +388,6 @@ def isin_join(
303388
]
304389
)
305390
],
306-
append=False,
307391
)
308392

309393
new_column = sge.In(
@@ -313,7 +397,7 @@ def isin_join(
313397
else:
314398
new_column = sge.In(
315399
this=conditions[0].expr,
316-
expressions=[right_select.subquery()],
400+
expressions=[right._as_subquery()],
317401
)
318402

319403
new_column = sge.Alias(
@@ -322,7 +406,7 @@ def isin_join(
322406
)
323407

324408
new_expr = sge.Select().select(sge.Star(), new_column).from_(left_from)
325-
return SQLGlotIR(expr=new_expr, uid_gen=self.uid_gen)
409+
return SQLGlotIR.from_expr(expr=new_expr, uid_gen=self.uid_gen)
326410

327411
def explode(
328412
self,
@@ -344,8 +428,8 @@ def sample(self, fraction: float) -> SQLGlotIR:
344428
expression=sql.literal(fraction, dtypes.FLOAT_DTYPE),
345429
)
346430

347-
new_expr = self._as_select().where(condition, append=False)
348-
return SQLGlotIR(expr=new_expr, uid_gen=self.uid_gen)
431+
new_expr = self.expr.as_select_all().where(condition, append=False)
432+
return SQLGlotIR.from_expr(expr=new_expr, uid_gen=self.uid_gen)
349433

350434
def aggregate(
351435
self,
@@ -368,10 +452,7 @@ def aggregate(
368452
for id, expr in aggregations
369453
]
370454

371-
new_expr = self._as_select()
372-
new_expr = new_expr.group_by(*by_cols).select(
373-
*[*by_cols, *aggregations_expr], append=False
374-
)
455+
new_expr = self.expr.select(*[*by_cols, *aggregations_expr]).group_by(*by_cols)
375456

376457
condition = _and(
377458
tuple(
@@ -381,7 +462,7 @@ def aggregate(
381462
)
382463
if condition is not None:
383464
new_expr = new_expr.where(condition, append=False)
384-
return SQLGlotIR(expr=new_expr, uid_gen=self.uid_gen)
465+
return SQLGlotIR.from_expr(expr=new_expr, uid_gen=self.uid_gen)
385466

386467
def with_ctes(
387468
self,
@@ -395,7 +476,7 @@ def with_ctes(
395476
for cte_name, cte in ctes
396477
]
397478
select_expr = _set_query_ctes(self._as_select(), sge_ctes)
398-
return SQLGlotIR(expr=select_expr, uid_gen=self.uid_gen)
479+
return SQLGlotIR.from_expr(expr=select_expr, uid_gen=self.uid_gen)
399480

400481
def resample(
401482
self,
@@ -431,7 +512,7 @@ def resample(
431512
.join(unnest_expr, join_type="cross")
432513
)
433514

434-
return SQLGlotIR(expr=new_expr, uid_gen=self.uid_gen)
515+
return SQLGlotIR.from_expr(expr=new_expr, uid_gen=self.uid_gen)
435516

436517
def _explode_single_column(
437518
self, column_name: str, offsets_col: typing.Optional[str]
@@ -449,12 +530,9 @@ def _explode_single_column(
449530
)
450531
selection = sge.Star(replace=[unnested_column_alias.as_(column)])
451532

452-
new_expr = self._as_select()
453533
# Use LEFT JOIN to preserve rows when unnesting empty arrays.
454-
new_expr = new_expr.select(selection, append=False).join(
455-
unnest_expr, join_type="LEFT"
456-
)
457-
return SQLGlotIR(expr=new_expr, uid_gen=self.uid_gen)
534+
new_expr = self.expr.select(selection).join(unnest_expr, join_type="LEFT")
535+
return SQLGlotIR.from_expr(expr=new_expr, uid_gen=self.uid_gen)
458536

459537
def _explode_multiple_columns(
460538
self,
@@ -492,26 +570,18 @@ def _explode_multiple_columns(
492570
for column in columns
493571
]
494572
)
495-
new_expr = self._as_select()
496573
# Use LEFT JOIN to preserve rows when unnesting empty arrays.
497-
new_expr = new_expr.select(selection, append=False).join(
498-
unnest_expr, join_type="LEFT"
499-
)
500-
return SQLGlotIR(expr=new_expr, uid_gen=self.uid_gen)
574+
new_expr = self.expr.select(selection).join(unnest_expr, join_type="LEFT")
575+
return SQLGlotIR.from_expr(expr=new_expr, uid_gen=self.uid_gen)
501576

502-
def _as_from_item(self) -> typing.Union[sge.Subquery, sge.Table]:
503-
if isinstance(self.expr, sge.Select):
504-
return self.expr.subquery()
505-
else: # table or cte
506-
return self.expr
577+
def _as_from_item(self) -> typing.Union[sge.Subquery, sge.Table, sge.Unnest]:
578+
return self.expr.as_from_item()
507579

508580
def _as_select(self) -> sge.Select:
509-
if isinstance(self.expr, sge.Select):
510-
return self.expr
511-
else: # table or cte
512-
return sge.Select().select(sge.Star()).from_(self.expr)
581+
return self.expr.as_select_all()
513582

514583
def _as_subquery(self) -> sge.Subquery:
584+
# Sometimes explicitly need a subquery, e.g. for IN expressions.
515585
return self._as_select().subquery()
516586

517587

bigframes/core/rewrite/pruning.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,6 @@ def prune_selection_child(
6767

6868
# Important to check this first
6969
if list(selection.ids) == list(child.ids):
70-
# Added all() here - a generator object is natively truthy
7170
if all(ref.ref.id == ref.id for ref in selection.input_output_pairs):
7271
# selection is no-op so just remove it entirely
7372
return child

bigframes/session/bq_caching_executor.py

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -490,15 +490,6 @@ def prepare_plan(
490490

491491
return plan
492492

493-
def simplify_plan(
494-
self, plan: nodes.BigFrameNode, use_cache: bool = True
495-
) -> nodes.BigFrameNode:
496-
if use_cache:
497-
plan = self.replace_cached_subtrees(plan)
498-
plan = rewrite.column_pruning(plan)
499-
plan = plan.top_down(rewrite.fold_row_counts)
500-
return plan
501-
502493
def _cache_with_cluster_cols(
503494
self, array_value: bigframes.core.ArrayValue, cluster_cols: Sequence[str]
504495
):

tests/unit/core/compile/sqlglot/aggregations/snapshots/test_binary_compiler/test_corr/out.sql

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,5 +3,10 @@ SELECT
33
FROM (
44
SELECT
55
CORR(`int64_col`, `float64_col`) AS `bfcol_2`
6-
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types` AS `bft_0`
6+
FROM (
7+
SELECT
8+
`int64_col`,
9+
`float64_col`
10+
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types` AS `bft_0`
11+
)
712
)

tests/unit/core/compile/sqlglot/aggregations/snapshots/test_binary_compiler/test_cov/out.sql

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,5 +3,10 @@ SELECT
33
FROM (
44
SELECT
55
COVAR_SAMP(`int64_col`, `float64_col`) AS `bfcol_2`
6-
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types` AS `bft_0`
6+
FROM (
7+
SELECT
8+
`int64_col`,
9+
`float64_col`
10+
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types` AS `bft_0`
11+
)
712
)

0 commit comments

Comments
 (0)