diff --git a/superset/sql_parse.py b/superset/sql_parse.py index f721f456d093..11e4279aa276 100644 --- a/superset/sql_parse.py +++ b/superset/sql_parse.py @@ -19,12 +19,13 @@ from __future__ import annotations +import enum import logging import re import urllib.parse from collections.abc import Iterable, Iterator from dataclasses import dataclass -from typing import Any, cast +from typing import Any, cast, Generic, TypeVar from unittest.mock import Mock import sqlglot @@ -334,89 +335,175 @@ def is_cte(source: exp.Table, scope: Scope) -> bool: return source.name in ctes_in_scope -class SQLScript: +# To avoid unnecessary parsing/formatting of queries, the statement has the concept of +# an "internal representation", which is the AST of the SQL statement. For most of the +# engines supported by Superset this is `sqlglot.exp.Expression`, but there is a special +# case: KustoKQL uses a different syntax and there are no Python parsers for it, so we +# store the AST as a string (the original query), and manipulate it with regular +# expressions. +InternalRepresentation = TypeVar("InternalRepresentation") + +# The base type. This helps type checking the `split_query` method correctly, since each +# derived class has a more specific return type (the class itself). This will no longer +# be needed once Python 3.11 is the lowest version supported. See PEP 673 for more +# information: https://peps.python.org/pep-0673/ +TBaseSQLStatement = TypeVar("TBaseSQLStatement") # pylint: disable=invalid-name + + +class BaseSQLStatement(Generic[InternalRepresentation]): """ - A SQL script, with 0+ statements. + Base class for SQL statements. + + The class can be instantiated with a string representation of the query or, for + efficiency reasons, with a pre-parsed AST. This is useful with `sqlglot.parse`, + which will split a query in multiple already parsed statements. + + The `engine` parameter comes from the `engine` attribute in a Superset DB engine + spec. 
""" def __init__( self, - query: str, - engine: str | None = None, + statement: str | InternalRepresentation, + engine: str, ): - dialect = SQLGLOT_DIALECTS.get(engine) if engine else None + self._parsed: InternalRepresentation = ( + self._parse_statement(statement, engine) + if isinstance(statement, str) + else statement + ) + self.engine = engine + self.tables = self._extract_tables_from_statement(self._parsed, self.engine) - self.statements = [ - SQLStatement(statement, engine=engine) - for statement in parse(query, dialect=dialect) - if statement - ] + @classmethod + def split_query( + cls: type[TBaseSQLStatement], + query: str, + engine: str, + ) -> list[TBaseSQLStatement]: + """ + Split a query into multiple instantiated statements. + + This is a helper function to split a full SQL query into multiple + `BaseSQLStatement` instances. It's used by `SQLScript` when instantiating the + statements within a query. + """ + raise NotImplementedError() + + @classmethod + def _parse_statement( + cls, + statement: str, + engine: str, + ) -> InternalRepresentation: + """ + Parse a string containing a single SQL statement, and returns the parsed AST. + + Derived classes should not assume that `statement` contains a single statement, + and MUST explicitly validate that. Since this validation is parser dependent the + responsibility is left to the children classes. + """ + raise NotImplementedError() + + @classmethod + def _extract_tables_from_statement( + cls, + parsed: InternalRepresentation, + engine: str, + ) -> set[Table]: + """ + Extract all table references in a given statement. + """ + raise NotImplementedError() def format(self, comments: bool = True) -> str: """ - Pretty-format the SQL query. + Format the statement, optionally ommitting comments. 
""" - return ";\n".join(statement.format(comments) for statement in self.statements) + raise NotImplementedError() - def get_settings(self) -> dict[str, str]: + def get_settings(self) -> dict[str, str | bool]: """ - Return the settings for the SQL query. + Return any settings set by the statement. - >>> statement = SQLScript("SET foo = 'bar'; SET foo = 'baz'") - >>> statement.get_settings() - {"foo": "'baz'"} + For example, for this statement: + sql> SET foo = 'bar'; + + The method should return `{"foo": "'bar'"}`. Note the single quotes. """ - settings: dict[str, str] = {} - for statement in self.statements: - settings.update(statement.get_settings()) + raise NotImplementedError() - return settings + def __str__(self) -> str: + return self.format() -class SQLStatement: +class SQLStatement(BaseSQLStatement[exp.Expression]): """ A SQL statement. - This class provides helper methods to manipulate and introspect SQL. + This class is used for all engines with dialects that can be parsed using sqlglot. 
""" def __init__( self, statement: str | exp.Expression, - engine: str | None = None, + engine: str, ): - dialect = SQLGLOT_DIALECTS.get(engine) if engine else None + self._dialect = SQLGLOT_DIALECTS.get(engine) + super().__init__(statement, engine) - if isinstance(statement, str): - try: - self._parsed = self._parse_statement(statement, dialect) - except ParseError as ex: - raise SupersetParseError(statement, engine) from ex - else: - self._parsed = statement + @classmethod + def split_query( + cls, + query: str, + engine: str, + ) -> list[SQLStatement]: + dialect = SQLGLOT_DIALECTS.get(engine) - self._dialect = dialect - self.tables = extract_tables_from_statement(self._parsed, dialect) + try: + statements = sqlglot.parse(query, dialect=dialect) + except sqlglot.errors.ParseError as ex: + raise SupersetParseError("Unable to split query") from ex - @staticmethod + return [cls(statement, engine) for statement in statements if statement] + + @classmethod def _parse_statement( - sql_statement: str, - dialect: Dialects | None, + cls, + statement: str, + engine: str, ) -> exp.Expression: """ Parse a single SQL statement. """ - statements = [ - statement - for statement in sqlglot.parse(sql_statement, dialect=dialect) - if statement - ] + dialect = SQLGLOT_DIALECTS.get(engine) + + # We could parse with `sqlglot.parse_one` to get a single statement, but we need + # to verify that the string contains exactly one statement. 
+ try: + statements = sqlglot.parse(statement, dialect=dialect) + except sqlglot.errors.ParseError as ex: + raise SupersetParseError("Unable to split query") from ex + + statements = [statement for statement in statements if statement] if len(statements) != 1: - raise ValueError("SQLStatement should have exactly one statement") + raise SupersetParseError("SQLStatement should have exactly one statement") return statements[0] + @classmethod + def _extract_tables_from_statement( + cls, + parsed: exp.Expression, + engine: str, + ) -> set[Table]: + """ + Find all referenced tables. + """ + dialect = SQLGLOT_DIALECTS.get(engine) + return extract_tables_from_statement(parsed, dialect) + def format(self, comments: bool = True) -> str: """ Pretty-format the SQL statement. @@ -424,7 +511,7 @@ def format(self, comments: bool = True) -> str: write = Dialect.get_or_raise(self._dialect) return write.generate(self._parsed, copy=False, comments=comments, pretty=True) - def get_settings(self) -> dict[str, str]: + def get_settings(self) -> dict[str, str | bool]: """ Return the settings for the SQL statement. @@ -440,6 +527,192 @@ def get_settings(self) -> dict[str, str]: } +class KQLSplitState(enum.Enum): + """ + State machine for splitting a KQL query. + + The state machine keeps track of whether we're inside a string or not, so we + don't split the query in a semi-colon that's part of a string. + """ + + OUTSIDE_STRING = enum.auto() + INSIDE_SINGLE_QUOTED_STRING = enum.auto() + INSIDE_DOUBLE_QUOTED_STRING = enum.auto() + INSIDE_MULTILINE_STRING = enum.auto() + + +def split_kql(kql: str) -> list[str]: + """ + Custom function for splitting KQL statements. 
+ """ + statements = [] + state = KQLSplitState.OUTSIDE_STRING + statement_start = 0 + query = kql if kql.endswith(";") else kql + ";" + for i, character in enumerate(query): + if state == KQLSplitState.OUTSIDE_STRING: + if character == ";": + statements.append(query[statement_start:i]) + statement_start = i + 1 + elif character == "'": + state = KQLSplitState.INSIDE_SINGLE_QUOTED_STRING + elif character == '"': + state = KQLSplitState.INSIDE_DOUBLE_QUOTED_STRING + elif character == "`" and query[i - 2 : i] == "``": + state = KQLSplitState.INSIDE_MULTILINE_STRING + + elif ( + state == KQLSplitState.INSIDE_SINGLE_QUOTED_STRING + and character == "'" + and query[i - 1] != "\\" + ): + state = KQLSplitState.OUTSIDE_STRING + + elif ( + state == KQLSplitState.INSIDE_DOUBLE_QUOTED_STRING + and character == '"' + and query[i - 1] != "\\" + ): + state = KQLSplitState.OUTSIDE_STRING + + elif ( + state == KQLSplitState.INSIDE_MULTILINE_STRING + and character == "`" + and query[i - 2 : i] == "``" + ): + state = KQLSplitState.OUTSIDE_STRING + + return statements + + +class KustoKQLStatement(BaseSQLStatement[str]): + """ + Special class for Kusto KQL. + + Kusto KQL is a SQL-like language, but it's not supported by sqlglot. Queries look + like this: + + StormEvents + | summarize PropertyDamage = sum(DamageProperty) by State + | join kind=innerunique PopulationData on State + | project State, PropertyDamagePerCapita = PropertyDamage / Population + | sort by PropertyDamagePerCapita + + See https://learn.microsoft.com/en-us/azure/data-explorer/kusto/query/ for more + details about it. + """ + + @classmethod + def split_query( + cls, + query: str, + engine: str, + ) -> list[KustoKQLStatement]: + """ + Split a query at semi-colons. + + Since we don't have a parser, we use a simple state machine based function. See + https://learn.microsoft.com/en-us/azure/data-explorer/kusto/query/scalar-data-types/string + for more information. 
+ """ + return [cls(statement, engine) for statement in split_kql(query)] + + @classmethod + def _parse_statement( + cls, + statement: str, + engine: str, + ) -> str: + if engine != "kustokql": + raise SupersetParseError(f"Invalid engine: {engine}") + + statements = split_kql(statement) + if len(statements) != 1: + raise SupersetParseError("SQLStatement should have exactly one statement") + + return statements[0].strip() + + @classmethod + def _extract_tables_from_statement(cls, parsed: str, engine: str) -> set[Table]: + """ + Extract all tables referenced in the statement. + + StormEvents + | where InjuriesDirect + InjuriesIndirect > 50 + | join (PopulationData) on State + | project State, Population, TotalInjuries = InjuriesDirect + InjuriesIndirect + + """ + logger.warning( + "Kusto KQL doesn't support table extraction. This means that data access " + "roles will not be enforced by Superset in the database." + ) + return set() + + def format(self, comments: bool = True) -> str: + """ + Pretty-format the SQL statement. + """ + return self._parsed + + def get_settings(self) -> dict[str, str | bool]: + """ + Return the settings for the SQL statement. + + >>> statement = KustoKQLStatement("set querytrace;") + >>> statement.get_settings() + {"querytrace": True} + + """ + set_regex = r"^set\s+(?P<name>\w+)(?:\s*=\s*(?P<value>\w+))?$" + if match := re.match(set_regex, self._parsed, re.IGNORECASE): + return {match.group("name"): match.group("value") or True} + + return {} + + +class SQLScript: + """ + A SQL script, with 0+ statements. + """ + + # Special engines that can't be parsed using sqlglot. Supporting non-SQL engines + # adds a lot of complexity to Superset, so we should avoid adding new engines to + # this data structure. 
+ special_engines = { + "kustokql": KustoKQLStatement, + } + + def __init__( + self, + query: str, + engine: str, + ): + statement_class = self.special_engines.get(engine, SQLStatement) + self.statements = statement_class.split_query(query, engine) + + def format(self, comments: bool = True) -> str: + """ + Pretty-format the SQL query. + """ + return ";\n".join(statement.format(comments) for statement in self.statements) + + def get_settings(self) -> dict[str, str | bool]: + """ + Return the settings for the SQL query. + + >>> statement = SQLScript("SET foo = 'bar'; SET foo = 'baz'") + >>> statement.get_settings() + {"foo": "'baz'"} + + """ + settings: dict[str, str | bool] = {} + for statement in self.statements: + settings.update(statement.get_settings()) + + return settings + + class ParsedQuery: def __init__( self, diff --git a/tests/unit_tests/sql_parse_tests.py b/tests/unit_tests/sql_parse_tests.py index aa4171e763fe..79958b074330 100644 --- a/tests/unit_tests/sql_parse_tests.py +++ b/tests/unit_tests/sql_parse_tests.py @@ -37,8 +37,10 @@ has_table_query, insert_rls_as_subquery, insert_rls_in_predicate, + KustoKQLStatement, ParsedQuery, sanitize_clause, + split_kql, SQLScript, SQLStatement, strip_comments_from_sql, @@ -1883,21 +1885,31 @@ def test_sqlquery() -> None: """ Test the `SQLScript` class. """ - script = SQLScript("SELECT 1; SELECT 2;") + script = SQLScript("SELECT 1; SELECT 2;", "sqlite") assert len(script.statements) == 2 assert script.format() == "SELECT\n 1;\nSELECT\n 2" assert script.statements[0].format() == "SELECT\n 1" - script = SQLScript("SET a=1; SET a=2; SELECT 3;") + script = SQLScript("SET a=1; SET a=2; SELECT 3;", "sqlite") assert script.get_settings() == {"a": "2"} + query = SQLScript( + """set querytrace; +Events | take 100""", + "kustokql", + ) + assert query.get_settings() == {"querytrace": True} + def test_sqlstatement() -> None: """ Test the `SQLStatement` class. 
""" - statement = SQLStatement("SELECT * FROM table1 UNION ALL SELECT * FROM table2") + statement = SQLStatement( + "SELECT * FROM table1 UNION ALL SELECT * FROM table2", + "sqlite", + ) assert statement.tables == { Table(table="table1", schema=None, catalog=None), @@ -1908,7 +1920,7 @@ def test_sqlstatement() -> None: == "SELECT\n *\nFROM table1\nUNION ALL\nSELECT\n *\nFROM table2" ) - statement = SQLStatement("SET a=1") + statement = SQLStatement("SET a=1", "sqlite") assert statement.get_settings() == {"a": "1"} @@ -1950,3 +1962,137 @@ def test_extract_tables_from_jinja_sql( extract_tables_from_jinja_sql(sql.format(engine=engine, macro=macro), engine) == expected ) + + +def test_kustokqlstatement_split_query() -> None: + """ + Test the `KustoKQLStatement` split method. + """ + statements = KustoKQLStatement.split_query( + """ +let totalPagesPerDay = PageViews +| summarize by Page, Day = startofday(Timestamp) +| summarize count() by Day; +let materializedScope = PageViews +| summarize by Page, Day = startofday(Timestamp); +let cachedResult = materialize(materializedScope); +cachedResult +| project Page, Day1 = Day +| join kind = inner +( + cachedResult + | project Page, Day2 = Day +) +on Page +| where Day2 > Day1 +| summarize count() by Day1, Day2 +| join kind = inner + totalPagesPerDay +on $left.Day1 == $right.Day +| project Day1, Day2, Percentage = count_*100.0/count_1 + """, + "kustokql", + ) + assert len(statements) == 4 + + +def test_kustokqlstatement_with_program() -> None: + """ + Test the `KustoKQLStatement` split method when the KQL has a program. + """ + statements = KustoKQLStatement.split_query( + """ +print program = ``` + public class Program { + public static void Main() { + System.Console.WriteLine("Hello!"); + } + }``` + """, + "kustokql", + ) + assert len(statements) == 1 + + +def test_kustokqlstatement_with_set() -> None: + """ + Test the `KustoKQLStatement` split method when the KQL has a set command. 
+ """ + statements = KustoKQLStatement.split_query( + """ +set querytrace; +Events | take 100 + """, + "kustokql", + ) + assert len(statements) == 2 + assert statements[0].format() == "set querytrace" + assert statements[1].format() == "Events | take 100" + + +@pytest.mark.parametrize( + "kql,statements", + [ + ('print banner=strcat("Hello", ", ", "World!")', 1), + (r"print 'O\'Malley\'s'", 1), + (r"print 'O\'Mal;ley\'s'", 1), + ("print ```foo;\nbar;\nbaz;```\n", 1), + ], +) +def test_kustokql_statement_split_special(kql: str, statements: int) -> None: + assert len(KustoKQLStatement.split_query(kql, "kustokql")) == statements + + +def test_split_kql() -> None: + """ + Test the `split_kql` function. + """ + kql = """ +let totalPagesPerDay = PageViews +| summarize by Page, Day = startofday(Timestamp) +| summarize count() by Day; +let materializedScope = PageViews +| summarize by Page, Day = startofday(Timestamp); +let cachedResult = materialize(materializedScope); +cachedResult +| project Page, Day1 = Day +| join kind = inner +( + cachedResult + | project Page, Day2 = Day +) +on Page +| where Day2 > Day1 +| summarize count() by Day1, Day2 +| join kind = inner + totalPagesPerDay +on $left.Day1 == $right.Day +| project Day1, Day2, Percentage = count_*100.0/count_1 + """ + assert split_kql(kql) == [ + """ +let totalPagesPerDay = PageViews +| summarize by Page, Day = startofday(Timestamp) +| summarize count() by Day""", + """ +let materializedScope = PageViews +| summarize by Page, Day = startofday(Timestamp)""", + """ +let cachedResult = materialize(materializedScope)""", + """ +cachedResult +| project Page, Day1 = Day +| join kind = inner +( + cachedResult + | project Page, Day2 = Day +) +on Page +| where Day2 > Day1 +| summarize count() by Day1, Day2 +| join kind = inner + totalPagesPerDay +on $left.Day1 == $right.Day +| project Day1, Day2, Percentage = count_*100.0/count_1 + """, + ]