Skip to content

Commit 069dd09

Browse files
fix isin join logic
1 parent 74b2470 commit 069dd09

File tree

3 files changed

+69
-132
lines changed
  • bigframes/core/compile/sqlglot
  • tests/unit/core/compile/sqlglot/snapshots/test_compile_isin

3 files changed

+69
-132
lines changed

bigframes/core/compile/sqlglot/sqlglot_ir.py

Lines changed: 62 additions & 89 deletions
Original file line number | Diff line number | Diff line change
@@ -280,32 +280,39 @@ def isin_join(
280280
) -> SQLGlotIR:
281281
"""Joins the current query with another SQLGlotIR instance."""
282282
left_from = self._as_from_item()
283-
# Prefer subquery over CTE for the IN clause's right side to improve SQL readability.
284283
right_select = right._as_select()
285284

286-
left_condition = typed_expr.TypedExpr(
287-
sge.Column(this=conditions[0].expr, table=left_from),
288-
conditions[0].dtype,
289-
)
290-
291285
new_column: sge.Expression
292286
if joins_nulls:
293-
right_table_name = sql.identifier(next(self.uid_gen.get_uid_stream("bft_")))
294-
right_condition = typed_expr.TypedExpr(
295-
sge.Column(this=conditions[1].expr, table=right_table_name),
296-
conditions[1].dtype,
287+
part1_id = sql.identifier(next(self.uid_gen.get_uid_stream("bfpart1_")))
288+
part2_id = sql.identifier(next(self.uid_gen.get_uid_stream("bfpart2_")))
289+
left_expr1, left_expr2 = _value_to_non_null_identity(conditions[0])
290+
left_as_struct = sge.Struct(
291+
expressions=[
292+
sge.PropertyEQ(this=part1_id, expression=left_expr1),
293+
sge.PropertyEQ(this=part2_id, expression=left_expr2),
294+
]
297295
)
298-
new_column = sge.Exists(
299-
this=sge.Select()
300-
.select(sge.convert(1))
301-
.from_(sge.Alias(this=right_select.subquery(), alias=right_table_name))
302-
.where(
303-
_join_condition(left_condition, right_condition, joins_nulls=True)
304-
)
296+
right_expr1, right_expr2 = _value_to_non_null_identity(conditions[1])
297+
right_select = right_select.select(
298+
*[
299+
sge.Struct(
300+
expressions=[
301+
sge.PropertyEQ(this=part1_id, expression=right_expr1),
302+
sge.PropertyEQ(this=part2_id, expression=right_expr2),
303+
]
304+
)
305+
],
306+
append=False,
307+
)
308+
309+
new_column = sge.In(
310+
this=left_as_struct,
311+
expressions=[right_select.subquery()],
305312
)
306313
else:
307314
new_column = sge.In(
308-
this=left_condition.expr,
315+
this=conditions[0].expr,
309316
expressions=[right_select.subquery()],
310317
)
311318

@@ -314,12 +321,7 @@ def isin_join(
314321
alias=sql.identifier(indicator_col),
315322
)
316323

317-
new_expr = (
318-
sge.Select()
319-
.select(sge.Column(this=sge.Star(), table=left_from), new_column)
320-
.from_(left_from)
321-
)
322-
324+
new_expr = sge.Select().select(sge.Star(), new_column).from_(left_from)
323325
return SQLGlotIR(expr=new_expr, uid_gen=self.uid_gen)
324326

325327
def explode(
@@ -543,77 +545,48 @@ def _join_condition(
543545
joins_nulls: If True, generates complex logic to handle nulls/NaNs.
544546
Otherwise, uses a simple equality check where appropriate.
545547
"""
546-
is_floating_types = (
547-
left.dtype == dtypes.FLOAT_DTYPE and right.dtype == dtypes.FLOAT_DTYPE
548-
)
549-
if not is_floating_types and not joins_nulls:
548+
if not joins_nulls:
550549
return sge.EQ(this=left.expr, expression=right.expr)
551-
552-
is_numeric_types = dtypes.is_numeric(
553-
left.dtype, include_bool=False
554-
) and dtypes.is_numeric(right.dtype, include_bool=False)
555-
if is_numeric_types:
556-
return _join_condition_for_numeric(left, right)
557-
else:
558-
return _join_condition_for_others(left, right)
559-
560-
561-
def _join_condition_for_others(
562-
left: typed_expr.TypedExpr,
563-
right: typed_expr.TypedExpr,
564-
) -> sge.And:
565-
"""Generates a join condition for non-numeric types to match pandas's
566-
null-handling logic.
567-
"""
568-
left_str = sql.cast(left.expr, "STRING")
569-
right_str = sql.cast(right.expr, "STRING")
570-
left_0 = sge.func("COALESCE", left_str, sql.literal("0", dtypes.STRING_DTYPE))
571-
left_1 = sge.func("COALESCE", left_str, sql.literal("1", dtypes.STRING_DTYPE))
572-
right_0 = sge.func("COALESCE", right_str, sql.literal("0", dtypes.STRING_DTYPE))
573-
right_1 = sge.func("COALESCE", right_str, sql.literal("1", dtypes.STRING_DTYPE))
550+
left_expr1, left_expr2 = _value_to_non_null_identity(left)
551+
right_expr1, right_expr2 = _value_to_non_null_identity(right)
574552
return sge.And(
575-
this=sge.EQ(this=left_0, expression=right_0),
576-
expression=sge.EQ(this=left_1, expression=right_1),
553+
this=sge.EQ(this=left_expr1, expression=right_expr1),
554+
expression=sge.EQ(this=left_expr2, expression=right_expr2),
577555
)
578556

579557

580-
def _join_condition_for_numeric(
581-
left: typed_expr.TypedExpr,
582-
right: typed_expr.TypedExpr,
583-
) -> sge.And:
584-
"""Generates a join condition for non-numeric types to match pandas's
585-
null-handling logic. Specifically for FLOAT types, Pandas treats NaN aren't
586-
equal so need to coalesce as well with different constants.
587-
"""
588-
is_floating_types = (
589-
left.dtype == dtypes.FLOAT_DTYPE and right.dtype == dtypes.FLOAT_DTYPE
590-
)
591-
left_0 = sge.func("COALESCE", left.expr, sql.literal(0, left.dtype))
592-
left_1 = sge.func("COALESCE", left.expr, sql.literal(1, left.dtype))
593-
right_0 = sge.func("COALESCE", right.expr, sql.literal(0, right.dtype))
594-
right_1 = sge.func("COALESCE", right.expr, sql.literal(1, right.dtype))
595-
if not is_floating_types:
596-
return sge.And(
597-
this=sge.EQ(this=left_0, expression=right_0),
598-
expression=sge.EQ(this=left_1, expression=right_1),
558+
def _value_to_non_null_identity(
559+
value: typed_expr.TypedExpr,
560+
) -> tuple[sge.Expression, sge.Expression]:
561+
# normal_value -> (normal_value, normal_value)
562+
# null_value -> (0, 1)
563+
# nan_value -> (2, 3)
564+
if dtypes.is_numeric(value.dtype, include_bool=False):
565+
expr1 = sge.func("COALESCE", value.expr, sql.literal(0, value.dtype))
566+
expr2 = sge.func("COALESCE", value.expr, sql.literal(1, value.dtype))
567+
if value.dtype == dtypes.FLOAT_DTYPE:
568+
expr1 = sge.If(
569+
this=sge.IsNan(this=value.expr),
570+
true=sql.literal(2, value.dtype),
571+
false=expr1,
572+
)
573+
expr2 = sge.If(
574+
this=sge.IsNan(this=value.expr),
575+
true=sql.literal(3, value.dtype),
576+
false=expr2,
577+
)
578+
else: # general case, convert to string and coalesce
579+
expr1 = sge.func(
580+
"COALESCE",
581+
sql.cast(value.expr, "STRING"),
582+
sql.literal("0", dtypes.STRING_DTYPE),
599583
)
600-
601-
left_2 = sge.If(
602-
this=sge.IsNan(this=left.expr), true=sql.literal(2, left.dtype), false=left_0
603-
)
604-
left_3 = sge.If(
605-
this=sge.IsNan(this=left.expr), true=sql.literal(3, left.dtype), false=left_1
606-
)
607-
right_2 = sge.If(
608-
this=sge.IsNan(this=right.expr), true=sql.literal(2, right.dtype), false=right_0
609-
)
610-
right_3 = sge.If(
611-
this=sge.IsNan(this=right.expr), true=sql.literal(3, right.dtype), false=right_1
612-
)
613-
return sge.And(
614-
this=sge.EQ(this=left_2, expression=right_2),
615-
expression=sge.EQ(this=left_3, expression=right_3),
616-
)
584+
expr2 = sge.func(
585+
"COALESCE",
586+
sql.cast(value.expr, "STRING"),
587+
sql.literal("1", dtypes.STRING_DTYPE),
588+
)
589+
return expr1, expr2
617590

618591

619592
def _set_query_ctes(

tests/unit/core/compile/sqlglot/snapshots/test_compile_isin/test_compile_isin/out.sql

Lines changed: 5 additions & 31 deletions
Original file line number | Diff line number | Diff line change
@@ -3,45 +3,19 @@ SELECT
33
`bfcol_5` AS `int64_col`
44
FROM (
55
SELECT
6-
(
7-
SELECT
8-
`rowindex` AS `bfcol_2`,
9-
`int64_col` AS `bfcol_3`
10-
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types` AS `bft_0`
11-
).*,
12-
EXISTS(
13-
SELECT
14-
1
15-
FROM (
6+
*,
7+
STRUCT(COALESCE(`bfcol_3`, 0) AS `bfpart1_0`, COALESCE(`bfcol_3`, 1) AS `bfpart2_0`) IN (
8+
(
169
SELECT
17-
`int64_too` AS `bfcol_4`
10+
STRUCT(COALESCE(`bfcol_4`, 0) AS `bfpart1_0`, COALESCE(`bfcol_4`, 1) AS `bfpart2_0`)
1811
FROM (
1912
SELECT
2013
`int64_too`
2114
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types` AS `bft_0`
2215
GROUP BY
2316
`int64_too`
2417
)
25-
) AS `bft_1`
26-
WHERE
27-
COALESCE(
28-
(
29-
SELECT
30-
`rowindex` AS `bfcol_2`,
31-
`int64_col` AS `bfcol_3`
32-
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types` AS `bft_0`
33-
).`bfcol_3`,
34-
0
35-
) = COALESCE(`bft_1`.`bfcol_4`, 0)
36-
AND COALESCE(
37-
(
38-
SELECT
39-
`rowindex` AS `bfcol_2`,
40-
`int64_col` AS `bfcol_3`
41-
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types` AS `bft_0`
42-
).`bfcol_3`,
43-
1
44-
) = COALESCE(`bft_1`.`bfcol_4`, 1)
18+
)
4519
) AS `bfcol_5`
4620
FROM (
4721
SELECT

tests/unit/core/compile/sqlglot/snapshots/test_compile_isin/test_compile_isin_not_nullable/out.sql

Lines changed: 2 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -3,18 +3,8 @@ SELECT
33
`bfcol_5` AS `rowindex_2`
44
FROM (
55
SELECT
6-
(
7-
SELECT
8-
`rowindex` AS `bfcol_2`,
9-
`rowindex_2` AS `bfcol_3`
10-
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types` AS `bft_0`
11-
).*,
12-
(
13-
SELECT
14-
`rowindex` AS `bfcol_2`,
15-
`rowindex_2` AS `bfcol_3`
16-
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types` AS `bft_0`
17-
).`bfcol_3` IN (
6+
*,
7+
`bfcol_3` IN (
188
(
199
SELECT
2010
`rowindex_2` AS `bfcol_4`

0 commit comments

Comments
 (0)