1414
1515from __future__ import annotations
1616
17+ import abc
1718import dataclasses
1819import datetime
1920import functools
3839 to_wkt = dumps
3940
4041
42+ class SelectableFragment (abc .ABC ):
43+ """
44+ Represent a grammar fragment that can be converted to a SELECT or FROM item.
45+ """
46+
47+ def as_select_all (self ) -> sge .Select :
48+ ...
49+
50+ def select (self , * items : sge .Expression ) -> sge .Select :
51+ ...
52+
53+ def as_from_item (self ) -> sge .FromItem :
54+ ...
55+
56+
57+ class SelectFragment (SelectableFragment ):
58+ def __init__ (self , select_expr : sge .Select ):
59+ self .select_expr = select_expr
60+
61+ def as_select_all (self ) -> sge .Select :
62+ return self .select_expr
63+
64+ def select (self , * items : sge .Expression ) -> sge .Select :
65+ return sge .Select ().select (* items ).from_ (self .select_expr .subquery ())
66+
67+ def as_from_item (self ) -> sge .FromItem :
68+ return self .select_expr .subquery ()
69+
70+
71+ class TableFragment (SelectableFragment ):
72+ def __init__ (self , table : sge .Table | sge .Unnest ):
73+ self .table = table
74+
75+ def as_select_all (self ) -> sge .Select :
76+ return sge .Select ().select (sge .Star ()).from_ (self .table )
77+
78+ def select (self , * items : sge .Expression ) -> sge .Select :
79+ return sge .Select ().select (* items ).from_ (self .table )
80+
81+ def as_from_item (self ) -> sge .FromItem :
82+ return self .table
83+
84+
85+ class DeferredSelectFragment (SelectableFragment ):
86+ def __init__ (self , select_supplier : typing .Callable [[sge .Select ], sge .Select ]):
87+ self .select_supplier = select_supplier
88+
89+ def as_select_all (self ) -> sge .Select :
90+ return self .select_supplier (sge .Select ().select (sge .Star ()))
91+
92+ def select (self , * items : sge .Expression ) -> sge .Select :
93+ return self .select_supplier (sge .Select ().select (* items ))
94+
95+ def as_from_item (self ) -> sge .FromItem :
96+ return self .select_supplier (sge .Select ().select (sge .Star ())).subquery ()
97+
98+
4199@dataclasses .dataclass (frozen = True )
42100class SQLGlotIR :
43101 """Helper class to build SQLGlot Query and generate SQL string."""
44102
45- expr : typing . Union [ sge . Select , sge . Table ] = sg . select ()
103+ expr : SelectableFragment
46104 """The SQLGlot expression representing the query."""
47105
48106 uid_gen : guid .SequentialUIDGenerator = guid .SequentialUIDGenerator ()
49107 """Generator for unique identifiers."""
50108
109+ def __post_init__ (self ):
110+ assert isinstance (self .expr , SelectableFragment )
111+
51112 @property
52113 def sql (self ) -> str :
53114 """Generate SQL string from the given expression."""
54- return sql .to_sql (self .expr )
115+ return sql .to_sql (self .expr .as_select_all ())
116+
117+ @classmethod
118+ def empty (
119+ cls , uid_gen : guid .SequentialUIDGenerator = guid .SequentialUIDGenerator ()
120+ ) -> SQLGlotIR :
121+ return cls (expr = SelectFragment (sge .select ()), uid_gen = uid_gen )
122+
123+ @classmethod
124+ def from_expr (
125+ cls ,
126+ expr : sge .Expression ,
127+ uid_gen : guid .SequentialUIDGenerator = guid .SequentialUIDGenerator (),
128+ ) -> SQLGlotIR :
129+ if isinstance (expr , sge .Select ):
130+ return cls (expr = SelectFragment (expr ), uid_gen = uid_gen )
131+ elif isinstance (expr , (sge .Table , sge .Unnest )):
132+ return cls (expr = TableFragment (expr ), uid_gen = uid_gen )
133+ else :
134+ raise ValueError (f"Unsupported expression type: { type (expr )} " )
135+
136+ @classmethod
137+ def from_func (
138+ cls ,
139+ select_handler : typing .Callable [[sge .Select ], sge .Select ],
140+ uid_gen : guid .SequentialUIDGenerator = guid .SequentialUIDGenerator (),
141+ ):
142+ return cls (expr = DeferredSelectFragment (select_handler ), uid_gen = uid_gen )
55143
56144 @classmethod
57145 def from_pyarrow (
@@ -97,7 +185,7 @@ def from_pyarrow(
97185 ),
98186 ],
99187 )
100- return cls (expr = sg . select ( sge . Star ()). from_ ( expr ) , uid_gen = uid_gen )
188+ return cls . from_expr (expr = expr , uid_gen = uid_gen )
101189
102190 @classmethod
103191 def from_table (
@@ -143,9 +231,9 @@ def from_table(
143231 select_expr = select_expr .where (
144232 sg .parse_one (sql_predicate , dialect = sql .base .DIALECT ), append = False
145233 )
146- return cls (expr = select_expr , uid_gen = uid_gen )
234+ return cls . from_expr (expr = select_expr , uid_gen = uid_gen )
147235
148- return cls (expr = table_expr , uid_gen = uid_gen )
236+ return cls . from_expr (expr = table_expr , uid_gen = uid_gen )
149237
150238 @classmethod
151239 def from_cte_ref (
@@ -156,7 +244,7 @@ def from_cte_ref(
156244 table_expr = sge .Table (
157245 this = sql .identifier (cte_ref ),
158246 )
159- return cls (expr = table_expr , uid_gen = uid_gen )
247+ return cls . from_expr (expr = table_expr , uid_gen = uid_gen )
160248
161249 def select (
162250 self ,
@@ -191,7 +279,7 @@ def select(
191279 if limit is not None :
192280 new_expr = new_expr .limit (limit )
193281
194- return SQLGlotIR (expr = new_expr , uid_gen = self .uid_gen )
282+ return SQLGlotIR . from_expr (expr = new_expr , uid_gen = self .uid_gen )
195283
196284 @classmethod
197285 def from_unparsed_query (
@@ -209,7 +297,7 @@ def from_unparsed_query(
209297 )
210298 select_expr = sge .Select ().select (sge .Star ()).from_ (sge .Table (this = cte_name ))
211299 select_expr = _set_query_ctes (select_expr , [cte ])
212- return cls (expr = select_expr , uid_gen = uid_gen )
300+ return cls . from_expr (expr = select_expr , uid_gen = uid_gen )
213301
214302 @classmethod
215303 def from_union (
@@ -241,7 +329,7 @@ def from_union(
241329 final_select_expr = (
242330 sge .Select ().select (* selections ).from_ (union_expr .subquery ())
243331 )
244- return cls (expr = final_select_expr , uid_gen = uid_gen )
332+ return cls . from_expr (expr = final_select_expr , uid_gen = uid_gen )
245333
246334 def join (
247335 self ,
@@ -262,15 +350,13 @@ def join(
262350 )
263351
264352 join_type_str = join_type if join_type != "outer" else "full outer"
265- new_expr = (
266- sge . Select ()
267- . select ( sge . Star ())
268- . from_ ( left_from )
269- . join ( right_from , on = join_on , join_type = join_type_str )
353+ return SQLGlotIR . from_func (
354+ lambda select : select . from_ ( left_from ). join (
355+ right_from , on = join_on , join_type = join_type_str
356+ ),
357+ uid_gen = self . uid_gen ,
270358 )
271359
272- return SQLGlotIR (expr = new_expr , uid_gen = self .uid_gen )
273-
274360 def isin_join (
275361 self ,
276362 right : SQLGlotIR ,
@@ -280,7 +366,6 @@ def isin_join(
280366 ) -> SQLGlotIR :
281367 """Joins the current query with another SQLGlotIR instance."""
282368 left_from = self ._as_from_item ()
283- right_select = right ._as_select ()
284369
285370 new_column : sge .Expression
286371 if joins_nulls :
@@ -294,7 +379,7 @@ def isin_join(
294379 ]
295380 )
296381 right_expr1 , right_expr2 = _value_to_non_null_identity (conditions [1 ])
297- right_select = right_select .select (
382+ right_select = right . expr .select (
298383 * [
299384 sge .Struct (
300385 expressions = [
@@ -303,7 +388,6 @@ def isin_join(
303388 ]
304389 )
305390 ],
306- append = False ,
307391 )
308392
309393 new_column = sge .In (
@@ -313,7 +397,7 @@ def isin_join(
313397 else :
314398 new_column = sge .In (
315399 this = conditions [0 ].expr ,
316- expressions = [right_select . subquery ()],
400+ expressions = [right . _as_subquery ()],
317401 )
318402
319403 new_column = sge .Alias (
@@ -322,7 +406,7 @@ def isin_join(
322406 )
323407
324408 new_expr = sge .Select ().select (sge .Star (), new_column ).from_ (left_from )
325- return SQLGlotIR (expr = new_expr , uid_gen = self .uid_gen )
409+ return SQLGlotIR . from_expr (expr = new_expr , uid_gen = self .uid_gen )
326410
327411 def explode (
328412 self ,
@@ -344,8 +428,8 @@ def sample(self, fraction: float) -> SQLGlotIR:
344428 expression = sql .literal (fraction , dtypes .FLOAT_DTYPE ),
345429 )
346430
347- new_expr = self ._as_select ().where (condition , append = False )
348- return SQLGlotIR (expr = new_expr , uid_gen = self .uid_gen )
431+ new_expr = self .expr . as_select_all ().where (condition , append = False )
432+ return SQLGlotIR . from_expr (expr = new_expr , uid_gen = self .uid_gen )
349433
350434 def aggregate (
351435 self ,
@@ -368,10 +452,7 @@ def aggregate(
368452 for id , expr in aggregations
369453 ]
370454
371- new_expr = self ._as_select ()
372- new_expr = new_expr .group_by (* by_cols ).select (
373- * [* by_cols , * aggregations_expr ], append = False
374- )
455+ new_expr = self .expr .select (* [* by_cols , * aggregations_expr ]).group_by (* by_cols )
375456
376457 condition = _and (
377458 tuple (
@@ -381,7 +462,7 @@ def aggregate(
381462 )
382463 if condition is not None :
383464 new_expr = new_expr .where (condition , append = False )
384- return SQLGlotIR (expr = new_expr , uid_gen = self .uid_gen )
465+ return SQLGlotIR . from_expr (expr = new_expr , uid_gen = self .uid_gen )
385466
386467 def with_ctes (
387468 self ,
@@ -395,7 +476,7 @@ def with_ctes(
395476 for cte_name , cte in ctes
396477 ]
397478 select_expr = _set_query_ctes (self ._as_select (), sge_ctes )
398- return SQLGlotIR (expr = select_expr , uid_gen = self .uid_gen )
479+ return SQLGlotIR . from_expr (expr = select_expr , uid_gen = self .uid_gen )
399480
400481 def resample (
401482 self ,
@@ -431,7 +512,7 @@ def resample(
431512 .join (unnest_expr , join_type = "cross" )
432513 )
433514
434- return SQLGlotIR (expr = new_expr , uid_gen = self .uid_gen )
515+ return SQLGlotIR . from_expr (expr = new_expr , uid_gen = self .uid_gen )
435516
436517 def _explode_single_column (
437518 self , column_name : str , offsets_col : typing .Optional [str ]
@@ -449,12 +530,9 @@ def _explode_single_column(
449530 )
450531 selection = sge .Star (replace = [unnested_column_alias .as_ (column )])
451532
452- new_expr = self ._as_select ()
453533 # Use LEFT JOIN to preserve rows when unnesting empty arrays.
454- new_expr = new_expr .select (selection , append = False ).join (
455- unnest_expr , join_type = "LEFT"
456- )
457- return SQLGlotIR (expr = new_expr , uid_gen = self .uid_gen )
534+ new_expr = self .expr .select (selection ).join (unnest_expr , join_type = "LEFT" )
535+ return SQLGlotIR .from_expr (expr = new_expr , uid_gen = self .uid_gen )
458536
459537 def _explode_multiple_columns (
460538 self ,
@@ -492,26 +570,18 @@ def _explode_multiple_columns(
492570 for column in columns
493571 ]
494572 )
495- new_expr = self ._as_select ()
496573 # Use LEFT JOIN to preserve rows when unnesting empty arrays.
497- new_expr = new_expr .select (selection , append = False ).join (
498- unnest_expr , join_type = "LEFT"
499- )
500- return SQLGlotIR (expr = new_expr , uid_gen = self .uid_gen )
574+ new_expr = self .expr .select (selection ).join (unnest_expr , join_type = "LEFT" )
575+ return SQLGlotIR .from_expr (expr = new_expr , uid_gen = self .uid_gen )
501576
502- def _as_from_item (self ) -> typing .Union [sge .Subquery , sge .Table ]:
503- if isinstance (self .expr , sge .Select ):
504- return self .expr .subquery ()
505- else : # table or cte
506- return self .expr
577+ def _as_from_item (self ) -> typing .Union [sge .Subquery , sge .Table , sge .Unnest ]:
578+ return self .expr .as_from_item ()
507579
508580 def _as_select (self ) -> sge .Select :
509- if isinstance (self .expr , sge .Select ):
510- return self .expr
511- else : # table or cte
512- return sge .Select ().select (sge .Star ()).from_ (self .expr )
581+ return self .expr .as_select_all ()
513582
514583 def _as_subquery (self ) -> sge .Subquery :
584+ # Sometimes explicitly need a subquery, e.g. for IN expressions.
515585 return self ._as_select ().subquery ()
516586
517587
0 commit comments