Skip to content
32 changes: 32 additions & 0 deletions sqlglot/dialects/snowflake.py
Original file line number Diff line number Diff line change
Expand Up @@ -542,6 +542,37 @@ def _annotate_reverse(self: TypeAnnotator, expression: exp.Reverse) -> exp.Rever
return expression


def _annotate_decode_case(self: TypeAnnotator, expression: exp.DecodeCase) -> exp.DecodeCase:
"""Annotate DecodeCase with the type inferred from return values only.

DECODE uses the format: DECODE(expr, val1, ret1, val2, ret2, ..., default)
We only look at the return values (ret1, ret2, ..., default) to determine the type,
not the comparison values (val1, val2, ...) or the expression being compared.
"""
self._annotate_args(expression)

expressions = expression.expressions

# Return values are at indices 2, 4, 6, ... and the last element (if even length)
# DECODE(expr, val1, ret1, val2, ret2, ..., default)
return_types = [expressions[i].type for i in range(2, len(expressions), 2)]

# If the total number of expressions is even, the last one is the default
# Example:
# DECODE(x, 1, 'a', 2, 'b') -> len=5 (odd), no default
# DECODE(x, 1, 'a', 2, 'b', 'default') -> len=6 (even), has default
if len(expressions) % 2 == 0:
return_types.append(expressions[-1].type)

# Determine the common type from all return values
last_type = None
for ret_type in return_types:
last_type = self._maybe_coerce(last_type or ret_type, ret_type)

self._set_type(expression, last_type)
return expression


def _annotate_timestamp_from_parts(
self: TypeAnnotator, expression: exp.TimestampFromParts
) -> exp.TimestampFromParts:
Expand Down Expand Up @@ -760,6 +791,7 @@ class Snowflake(Dialect):
exp.TimeAdd: _annotate_date_or_time_add,
exp.GreatestIgnoreNulls: lambda self, e: self._annotate_by_args(e, "expressions"),
exp.LeastIgnoreNulls: lambda self, e: self._annotate_by_args(e, "expressions"),
exp.DecodeCase: _annotate_decode_case,
exp.Reverse: _annotate_reverse,
exp.TimestampFromParts: _annotate_timestamp_from_parts,
}
Expand Down
16 changes: 16 additions & 0 deletions tests/fixtures/optimizer/annotate_functions.sql
Original file line number Diff line number Diff line change
Expand Up @@ -2315,6 +2315,22 @@ DATE;
PREVIOUS_DAY(CAST('2024-05-09 08:50:57' AS TIMESTAMP), 'MONDAY');
DATE;

# dialect: snowflake
DECODE(x, 1, 100, 2, 200, 0);
INT;

# dialect: snowflake
DECODE(status, 'A', 'Active', 'I', 'Inactive', 'Neither');
VARCHAR;

# dialect: snowflake
DECODE(100, 100, 1, 90, 2, 5.5);
DOUBLE;

# dialect: snowflake
DECODE(x, 1, 100, NULL);
INT;

# dialect: snowflake
PI();
DOUBLE;
Expand Down