@@ -280,32 +280,39 @@ def isin_join(
280280 ) -> SQLGlotIR :
281281 """Joins the current query with another SQLGlotIR instance."""
282282 left_from = self ._as_from_item ()
283- # Prefer subquery over CTE for the IN clause's right side to improve SQL readability.
284283 right_select = right ._as_select ()
285284
286- left_condition = typed_expr .TypedExpr (
287- sge .Column (this = conditions [0 ].expr , table = left_from ),
288- conditions [0 ].dtype ,
289- )
290-
291285 new_column : sge .Expression
292286 if joins_nulls :
293- right_table_name = sql .identifier (next (self .uid_gen .get_uid_stream ("bft_" )))
294- right_condition = typed_expr .TypedExpr (
295- sge .Column (this = conditions [1 ].expr , table = right_table_name ),
296- conditions [1 ].dtype ,
287+ part1_id = sql .identifier (next (self .uid_gen .get_uid_stream ("bfpart1_" )))
288+ part2_id = sql .identifier (next (self .uid_gen .get_uid_stream ("bfpart2_" )))
289+ left_expr1 , left_expr2 = _value_to_non_null_identity (conditions [0 ])
290+ left_as_struct = sge .Struct (
291+ expressions = [
292+ sge .PropertyEQ (this = part1_id , expression = left_expr1 ),
293+ sge .PropertyEQ (this = part2_id , expression = left_expr2 ),
294+ ]
297295 )
298- new_column = sge .Exists (
299- this = sge .Select ()
300- .select (sge .convert (1 ))
301- .from_ (sge .Alias (this = right_select .subquery (), alias = right_table_name ))
302- .where (
303- _join_condition (left_condition , right_condition , joins_nulls = True )
304- )
296+ right_expr1 , right_expr2 = _value_to_non_null_identity (conditions [1 ])
297+ right_select = right_select .select (
298+ * [
299+ sge .Struct (
300+ expressions = [
301+ sge .PropertyEQ (this = part1_id , expression = right_expr1 ),
302+ sge .PropertyEQ (this = part2_id , expression = right_expr2 ),
303+ ]
304+ )
305+ ],
306+ append = False ,
307+ )
308+
309+ new_column = sge .In (
310+ this = left_as_struct ,
311+ expressions = [right_select .subquery ()],
305312 )
306313 else :
307314 new_column = sge .In (
308- this = left_condition .expr ,
315+ this = conditions [ 0 ] .expr ,
309316 expressions = [right_select .subquery ()],
310317 )
311318
@@ -314,12 +321,7 @@ def isin_join(
314321 alias = sql .identifier (indicator_col ),
315322 )
316323
317- new_expr = (
318- sge .Select ()
319- .select (sge .Column (this = sge .Star (), table = left_from ), new_column )
320- .from_ (left_from )
321- )
322-
324+ new_expr = sge .Select ().select (sge .Star (), new_column ).from_ (left_from )
323325 return SQLGlotIR (expr = new_expr , uid_gen = self .uid_gen )
324326
325327 def explode (
@@ -543,77 +545,48 @@ def _join_condition(
543545 joins_nulls: If True, generates complex logic to handle nulls/NaNs.
544546 Otherwise, uses a simple equality check where appropriate.
545547 """
546- is_floating_types = (
547- left .dtype == dtypes .FLOAT_DTYPE and right .dtype == dtypes .FLOAT_DTYPE
548- )
549- if not is_floating_types and not joins_nulls :
548+ if not joins_nulls :
550549 return sge .EQ (this = left .expr , expression = right .expr )
551-
552- is_numeric_types = dtypes .is_numeric (
553- left .dtype , include_bool = False
554- ) and dtypes .is_numeric (right .dtype , include_bool = False )
555- if is_numeric_types :
556- return _join_condition_for_numeric (left , right )
557- else :
558- return _join_condition_for_others (left , right )
559-
560-
561- def _join_condition_for_others (
562- left : typed_expr .TypedExpr ,
563- right : typed_expr .TypedExpr ,
564- ) -> sge .And :
565- """Generates a join condition for non-numeric types to match pandas's
566- null-handling logic.
567- """
568- left_str = sql .cast (left .expr , "STRING" )
569- right_str = sql .cast (right .expr , "STRING" )
570- left_0 = sge .func ("COALESCE" , left_str , sql .literal ("0" , dtypes .STRING_DTYPE ))
571- left_1 = sge .func ("COALESCE" , left_str , sql .literal ("1" , dtypes .STRING_DTYPE ))
572- right_0 = sge .func ("COALESCE" , right_str , sql .literal ("0" , dtypes .STRING_DTYPE ))
573- right_1 = sge .func ("COALESCE" , right_str , sql .literal ("1" , dtypes .STRING_DTYPE ))
550+ left_expr1 , left_expr2 = _value_to_non_null_identity (left )
551+ right_expr1 , right_expr2 = _value_to_non_null_identity (right )
574552 return sge .And (
575- this = sge .EQ (this = left_0 , expression = right_0 ),
576- expression = sge .EQ (this = left_1 , expression = right_1 ),
553+ this = sge .EQ (this = left_expr1 , expression = right_expr1 ),
554+ expression = sge .EQ (this = left_expr2 , expression = right_expr2 ),
577555 )
578556
579557
580- def _join_condition_for_numeric (
581- left : typed_expr .TypedExpr ,
582- right : typed_expr .TypedExpr ,
583- ) -> sge .And :
584- """Generates a join condition for non-numeric types to match pandas's
585- null-handling logic. Specifically for FLOAT types, Pandas treats NaN aren't
586- equal so need to coalesce as well with different constants.
587- """
588- is_floating_types = (
589- left .dtype == dtypes .FLOAT_DTYPE and right .dtype == dtypes .FLOAT_DTYPE
590- )
591- left_0 = sge .func ("COALESCE" , left .expr , sql .literal (0 , left .dtype ))
592- left_1 = sge .func ("COALESCE" , left .expr , sql .literal (1 , left .dtype ))
593- right_0 = sge .func ("COALESCE" , right .expr , sql .literal (0 , right .dtype ))
594- right_1 = sge .func ("COALESCE" , right .expr , sql .literal (1 , right .dtype ))
595- if not is_floating_types :
596- return sge .And (
597- this = sge .EQ (this = left_0 , expression = right_0 ),
598- expression = sge .EQ (this = left_1 , expression = right_1 ),
def _value_to_non_null_identity(
    value: typed_expr.TypedExpr,
) -> tuple[sge.Expression, sge.Expression]:
    """Build a pair of expressions that identify ``value`` without yielding NULL.

    The two returned expressions are meant to be compared component-wise
    (both must match) so that NULL, and NaN for floats, can take part in
    equality or ``IN`` checks:

    * non-null value ``v`` -> ``(v, v)``
    * NULL                 -> ``(0, 1)``
    * NaN (float only)     -> ``(2, 3)``

    A non-null value always produces two identical components, so it can
    never match a sentinel pair, whose components differ from each other.

    Args:
        value: Typed expression to encode. Its ``dtype`` selects the
            encoding: non-boolean numeric types keep their own type, every
            other type is cast to ``STRING`` first.

    Returns:
        A ``(first, second)`` tuple of SQL expressions.
    """
    if not dtypes.is_numeric(value.dtype, include_bool=False):
        # General case: compare via the STRING representation, substituting
        # the "0" / "1" sentinels when the value is NULL.
        first, second = (
            sge.func(
                "COALESCE",
                sql.cast(value.expr, "STRING"),
                sql.literal(null_marker, dtypes.STRING_DTYPE),
            )
            for null_marker in ("0", "1")
        )
        return first, second

    # Numeric case: keep the native type and substitute 0 / 1 for NULL.
    first, second = (
        sge.func("COALESCE", value.expr, sql.literal(null_marker, value.dtype))
        for null_marker in (0, 1)
    )

    if value.dtype == dtypes.FLOAT_DTYPE:
        # Floats additionally map NaN onto its own sentinel pair (2, 3); the
        # COALESCE expressions built above become the non-NaN fallback.
        # A fresh IsNan node is built for each component rather than sharing one.
        nan_cases = ((2, first), (3, second))
        first, second = (
            sge.If(
                this=sge.IsNan(this=value.expr),
                true=sql.literal(nan_marker, value.dtype),
                false=fallback,
            )
            for nan_marker, fallback in nan_cases
        )

    return first, second
617590
618591
619592def _set_query_ctes (
0 commit comments