Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions sqlglot/optimizer/annotate_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -314,6 +314,9 @@ def annotate_scope(self, scope: Scope) -> None:
elif isinstance(source.expression, exp.Unnest):
self._set_type(col, source.expression.type)

if col.type and col.type.args.get("nullable") is False:
col.meta["nonnull"] = True

if isinstance(self.schema, MappingSchema):
for table_column in scope.table_columns:
source = scope.sources.get(table_column.name)
Expand Down Expand Up @@ -446,6 +449,11 @@ def _annotate_binary(self, expression: B) -> B:
else:
self._set_type(expression, self._maybe_coerce(left_type, right_type))

if isinstance(expression, exp.Is) or (
left.meta.get("nonnull") is True and right.meta.get("nonnull") is True
):
expression.meta["nonnull"] = True

return expression

def _annotate_unary(self, expression: E) -> E:
Expand All @@ -456,6 +464,9 @@ def _annotate_unary(self, expression: E) -> E:
else:
self._set_type(expression, expression.this.type)

if expression.this.meta.get("nonnull") is True:
expression.meta["nonnull"] = True

return expression

def _annotate_literal(self, expression: exp.Literal) -> exp.Literal:
Expand All @@ -466,6 +477,8 @@ def _annotate_literal(self, expression: exp.Literal) -> exp.Literal:
else:
self._set_type(expression, exp.DataType.Type.DOUBLE)

expression.meta["nonnull"] = True

return expression

def _annotate_with_type(
Expand Down
7 changes: 4 additions & 3 deletions sqlglot/optimizer/simplify.py
Original file line number Diff line number Diff line change
Expand Up @@ -395,14 +395,15 @@ def remove_complements(expression, root=True):
"""
Removing complements.

A AND NOT A -> FALSE
A OR NOT A -> TRUE
A AND NOT A -> FALSE (only for non-NULL A)
A OR NOT A -> TRUE (only for non-NULL A)
"""
if isinstance(expression, AND_OR) and (root or not expression.same_parent):
ops = set(expression.flatten())
for op in ops:
if isinstance(op, exp.Not) and op.this in ops:
return exp.false() if isinstance(expression, exp.And) else exp.true()
if expression.meta.get("nonnull") is True:
return exp.false() if isinstance(expression, exp.And) else exp.true()

return expression

Expand Down
35 changes: 31 additions & 4 deletions tests/fixtures/optimizer/simplify.sql
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,10 @@ y OR y;
y;

x AND NOT x;
FALSE;
NOT x AND x;

x OR NOT x;
TRUE;
NOT x OR x;

1 AND TRUE;
TRUE;
Expand Down Expand Up @@ -299,7 +299,7 @@ A XOR D XOR B XOR E XOR F XOR G XOR C;
A XOR B XOR C XOR D XOR E XOR F XOR G;

A AND NOT B AND C AND B;
FALSE;
A AND B AND C AND NOT B;

(a AND b AND c AND d) AND (d AND c AND b AND a);
a AND b AND c AND d;
Expand Down Expand Up @@ -892,7 +892,7 @@ COALESCE(x, 1) = 1;
x = 1 OR x IS NULL;

COALESCE(x, 1) IS NULL;
FALSE;
NOT x IS NULL AND x IS NULL;

COALESCE(ROW() OVER (), 1) = 1;
ROW() OVER () = 1 OR ROW() OVER () IS NULL;
Expand Down Expand Up @@ -1344,3 +1344,30 @@ WITH t0 AS (SELECT 1 AS a, 'foo' AS p) SELECT NOT NOT CASE WHEN t0.a > 1 THEN t0
# dialect: sqlite
WITH t0 AS (SELECT 1 AS a, 'foo' AS p) SELECT NOT (NOT(CASE WHEN t0.a > 1 THEN t0.a ELSE t0.p END)) AS res FROM t0;
WITH t0 AS (SELECT 1 AS a, 'foo' AS p) SELECT NOT NOT CASE WHEN t0.a > 1 THEN t0.a ELSE t0.p END AS res FROM t0;

--------------------------------------
-- Simplify complements
--------------------------------------
TRUE OR NOT TRUE;
TRUE;

TRUE AND NOT TRUE;
FALSE;

'a' OR NOT 'a';
TRUE;

'a' AND NOT 'a';
FALSE;

100 OR NOT 100;
TRUE;

100 AND NOT 100;
FALSE;

NULL OR NOT NULL;
NULL;

NULL AND NOT NULL;
NULL;
55 changes: 55 additions & 0 deletions tests/test_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1699,3 +1699,58 @@ def test_annotate_object_construct(self):
self.assertEqual(
annotated.selects[0].type.sql("snowflake"), 'OBJECT("foo" VARCHAR, "a b" VARCHAR)'
)

def test_nonnull_annotation(self):
for literal_sql in ("1", "'foo'", "2.5"):
with self.subTest(f"Test NULL annotation for literal: {literal_sql}"):
sql = f"SELECT {literal_sql}"
query = parse_one(sql)
annotated = annotate_types(query)
assert annotated.selects[0].meta.get("nonnull") is True

schema = {"foo": {"id": "INT"}}

operand_pairs = (
("1", "1", True),
("foo.id", "foo.id", None),
("1", "foo.id", None),
("foo.id", "1", None),
)

for predicate in (">", "<", ">=", "<=", "=", "!=", "<>", "LIKE", "NOT LIKE"):
for operand1, operand2, nonnull in operand_pairs:
sql_predicate = f"{operand1} {predicate} {operand2}"
with self.subTest(f"Test NULL propagation for predicate: {predicate}"):
sql = f"SELECT {sql_predicate} FROM foo"
query = parse_one(sql)
annotated = annotate_types(query, schema=schema)
assert annotated.selects[0].meta.get("nonnull") is nonnull

for predicate in ("IS NULL", "IS NOT NULL"):
sql_predicate = f"foo.id {predicate}"
with self.subTest(f"Test NULL propagation for predicate: {predicate}"):
sql = f"SELECT {sql_predicate} FROM foo"
query = parse_one(sql)
annotated = annotate_types(query, schema=schema)
assert annotated.selects[0].meta.get("nonnull") is True

for connector in ("AND", "OR"):
for predicate in (">", "<", ">=", "<=", "=", "!=", "<>", "LIKE", "NOT LIKE"):
for operand1, operand2, nonnull in operand_pairs:
sql_predicate = f"({operand1} {predicate} {operand2})"
sql_connector = f"{sql_predicate} {connector} {sql_predicate}"
with self.subTest(
f"Test NULL propagation for connector: {connector} with predicates: {predicate}"
):
sql = f"SELECT {sql_connector} FROM foo"
query = parse_one(sql)
annotated = annotate_types(query, schema=schema)
assert annotated.selects[0].meta.get("nonnull") is nonnull

for unary in ("NOT", "-"):
for value, nonnull in (("1", True), ("foo.id", None)):
with self.subTest(f"Test NULL propagation for unary: {unary} with value: {value}"):
sql = f"SELECT {unary} {value} FROM foo"
query = parse_one(sql)
annotated = annotate_types(query, schema=schema)
assert annotated.selects[0].meta.get("nonnull") is nonnull