From 19dc1791a269098d3b9a686f142997c17b47fea3 Mon Sep 17 00:00:00 2001 From: Chelsea Lin Date: Tue, 10 Mar 2026 06:20:55 +0000 Subject: [PATCH 1/2] refactor: use sqlglot to build literal --- bigframes/bigquery/_operations/ai.py | 27 ++- bigframes/core/compile/sqlglot/sql/base.py | 19 ++- bigframes/core/pyformat.py | 8 +- bigframes/core/sql/__init__.py | 93 +++-------- bigframes/core/sql/literals.py | 58 ------- bigframes/core/sql/ml.py | 8 +- bigframes/dtypes.py | 30 +++- bigframes/functions/function.py | 8 +- bigframes/ml/sql.py | 5 +- bigframes/session/_io/bigquery/__init__.py | 6 +- .../core/compile/sqlglot/sql/test_base.py | 156 ++++++++++++++++++ tests/unit/core/test_sql.py | 121 +------------- 12 files changed, 245 insertions(+), 294 deletions(-) delete mode 100644 bigframes/core/sql/literals.py create mode 100644 tests/unit/core/compile/sqlglot/sql/test_base.py diff --git a/bigframes/bigquery/_operations/ai.py b/bigframes/bigquery/_operations/ai.py index dd9c4e236b1..bc28cb2e353 100644 --- a/bigframes/bigquery/_operations/ai.py +++ b/bigframes/bigquery/_operations/ai.py @@ -28,8 +28,9 @@ from bigframes import series, session from bigframes.bigquery._operations import utils as bq_utils from bigframes.core import convert +from bigframes.core.compile.sqlglot import sql as sg_sql from bigframes.core.logging import log_adapter -import bigframes.core.sql.literals +from bigframes.ml import base as ml_base from bigframes.ml import core as ml_core from bigframes.operations import ai_ops, output_schemas @@ -392,7 +393,7 @@ def generate_double( @log_adapter.method_logger(custom_base_name="bigquery_ai") def generate_embedding( - model: Union[bigframes.ml.base.BaseEstimator, str, pd.Series], + model: Union[ml_base.BaseEstimator, str, pd.Series], data: Union[dataframe.DataFrame, series.Series, pd.DataFrame, pd.Series], *, output_dimensionality: Optional[int] = None, @@ -416,7 +417,7 @@ def generate_embedding( ... ) # doctest: +SKIP Args: - model (bigframes.ml.base.BaseEstimator or str): + model (ml_base.BaseEstimator or str): The model to use for text embedding. data (bigframes.pandas.DataFrame or bigframes.pandas.Series): The data to generate embeddings for. If a Series is provided, it is @@ -458,7 +459,7 @@ def generate_embedding( model_name, session = bq_utils.get_model_name_and_session(model, data) table_sql = bq_utils.to_sql(data) - struct_fields: Dict[str, bigframes.core.sql.literals.STRUCT_VALUES] = {} + struct_fields: Dict[str, Any] = {} if output_dimensionality is not None: struct_fields["OUTPUT_DIMENSIONALITY"] = output_dimensionality if task_type is not None: @@ -478,7 +479,7 @@ def generate_embedding( FROM AI.GENERATE_EMBEDDING( MODEL `{model_name}`, ({table_sql}), - {bigframes.core.sql.literals.struct_literal(struct_fields)} + {sg_sql.to_sql(sg_sql.literal(struct_fields))} ) """ @@ -490,7 +491,7 @@ def generate_embedding( @log_adapter.method_logger(custom_base_name="bigquery_ai") def generate_text( - model: Union[bigframes.ml.base.BaseEstimator, str, pd.Series], + model: Union[ml_base.BaseEstimator, str, pd.Series], data: Union[dataframe.DataFrame, series.Series, pd.DataFrame, pd.Series], *, temperature: Optional[float] = None, @@ -519,7 +520,7 @@ def generate_text( ... ) # doctest: +SKIP Args: - model (bigframes.ml.base.BaseEstimator or str): + model (ml_base.BaseEstimator or str): The model to use for text generation. data (bigframes.pandas.DataFrame or bigframes.pandas.Series): The data to generate text for. If a Series is provided, it is @@ -591,7 +592,7 @@ def generate_text( FROM AI.GENERATE_TEXT( MODEL `{model_name}`, ({table_sql}), - {bigframes.core.sql.literals.struct_literal(struct_fields)} + {sg_sql.to_sql(sg_sql.literal(struct_fields))} ) """ @@ -603,7 +604,7 @@ def generate_text( @log_adapter.method_logger(custom_base_name="bigquery_ai") def generate_table( - model: Union[bigframes.ml.base.BaseEstimator, str, pd.Series], + model: Union[ml_base.BaseEstimator, str, pd.Series], data: Union[dataframe.DataFrame, series.Series, pd.DataFrame, pd.Series], *, output_schema: Union[str, Mapping[str, str]], @@ -635,7 +636,7 @@ def generate_table( ... ) # doctest: +SKIP Args: - model (bigframes.ml.base.BaseEstimator or str): + model (ml_base.BaseEstimator or str): The model to use for table generation. data (bigframes.pandas.DataFrame or bigframes.pandas.Series): The data to generate table for. If a Series is provided, it is @@ -677,9 +678,7 @@ def generate_table( else: output_schema_str = output_schema - struct_fields_bq: Dict[str, bigframes.core.sql.literals.STRUCT_VALUES] = { - "output_schema": output_schema_str - } + struct_fields_bq: Dict[str, Any] = {"output_schema": output_schema_str} if temperature is not None: struct_fields_bq["temperature"] = temperature if top_p is not None: @@ -691,7 +690,7 @@ def generate_table( if request_type is not None: struct_fields_bq["request_type"] = request_type - struct_sql = bigframes.core.sql.literals.struct_literal(struct_fields_bq) + struct_sql = sg_sql.to_sql(sg_sql.literal(struct_fields_bq)) query = f""" SELECT * FROM AI.GENERATE_TABLE( diff --git a/bigframes/core/compile/sqlglot/sql/base.py b/bigframes/core/compile/sqlglot/sql/base.py index 86e2153dd7e..6e888fdf5e8 100644 --- a/bigframes/core/compile/sqlglot/sql/base.py +++ b/bigframes/core/compile/sqlglot/sql/base.py @@ -57,8 +57,11 @@ def identifier(id: str) -> sge.Identifier: return sge.to_identifier(id, quoted=QUOTED) -def literal(value: typing.Any, dtype: dtypes.Dtype) -> sge.Expression: +def literal(value: typing.Any, dtype: dtypes.Dtype | None = None) -> sge.Expression: """Return a string representing column reference in a SQL.""" + if dtype is None: + dtype = dtypes.infer_literal_type(value) + sqlglot_type = sgt.from_bigframes_dtype(dtype) if dtype else None if sqlglot_type is None: if not pd.isna(value): @@ -81,6 +84,14 @@ def literal(value: typing.Any, dtype: dtypes.Dtype) -> sge.Expression: expressions=[literal(value=v, dtype=value_type) for v in value] ) return values if len(value) > 0 else cast(values, sqlglot_type) + elif dtype == dtypes.FLOAT_DTYPE: + if pd.isna(value): + if isinstance(value, (float, np.floating)) and np.isnan(value): + return constants._NAN + return cast(sge.Null(), sqlglot_type) + if np.isinf(value): + return constants._INF if value > 0 else constants._NEG_INF + return sge.convert(value) elif pd.isna(value) or (isinstance(value, pa.Scalar) and not value.is_valid): return cast(sge.Null(), sqlglot_type) elif dtype == dtypes.JSON_DTYPE: @@ -100,13 +111,11 @@ def literal(value: typing.Any, dtype: dtypes.Dtype) -> sge.Expression: return sge.func("ST_GEOGFROMTEXT", sge.convert(wkt)) elif dtype == dtypes.TIMEDELTA_DTYPE: return sge.convert(utils.timedelta_to_micros(value)) - elif dtype == dtypes.FLOAT_DTYPE: - if np.isinf(value): - return constants._INF if value > 0 else constants._NEG_INF - return sge.convert(value) else: if isinstance(value, np.generic): value = value.item() + if isinstance(value, pa.Scalar): + value = value.as_py() return sge.convert(value) diff --git a/bigframes/core/pyformat.py b/bigframes/core/pyformat.py index 8f49556ff4c..1dbb74fbb72 100644 --- a/bigframes/core/pyformat.py +++ b/bigframes/core/pyformat.py @@ -89,7 +89,7 @@ def _field_to_template_value( dry_run: bool = False, ) -> str: """Convert value to something embeddable in a SQL string.""" - import bigframes.core.sql # Avoid circular imports + import bigframes.core.compile.sqlglot.sql as sql # Avoid circular imports import bigframes.dataframe # Avoid circular imports _validate_type(name, value) @@ -107,20 +107,20 @@ def _field_to_template_value( if isinstance(value, str): return value - return bigframes.core.sql.simple_literal(value) + return sql.to_sql(sql.literal(value)) def _validate_type(name: str, value: Any): """Raises TypeError if value is unsupported.""" - import bigframes.core.sql # Avoid circular imports import bigframes.dataframe # Avoid circular imports + import bigframes.dtypes # Avoid circular imports if value is None: return # None can't be used in isinstance, but is a valid literal. supported_types = ( typing.get_args(_BQ_TABLE_TYPES) - + typing.get_args(bigframes.core.sql.SIMPLE_LITERAL_TYPES) + + bigframes.dtypes.SUPPORTED_LITERAL_TYPES + (bigframes.dataframe.DataFrame,) + (pandas.DataFrame,) ) diff --git a/bigframes/core/sql/__init__.py b/bigframes/core/sql/__init__.py index e17830042db..8c9a093802c 100644 --- a/bigframes/core/sql/__init__.py +++ b/bigframes/core/sql/__init__.py @@ -17,14 +17,19 @@ Utility functions for SQL construction. """ -import datetime -import decimal import json -import math -from typing import cast, Collection, Iterable, Mapping, Optional, TYPE_CHECKING, Union +from typing import ( + Any, + cast, + Collection, + Iterable, + Mapping, + Optional, + TYPE_CHECKING, + Union, +) import bigframes_vendored.sqlglot.expressions as sge -import shapely.geometry.base # type: ignore from bigframes.core.compile.sqlglot import sql @@ -43,68 +48,8 @@ to_wkt = dumps -SIMPLE_LITERAL_TYPES = Union[ - bytes, - str, - int, - bool, - float, - datetime.datetime, - datetime.date, - datetime.time, - decimal.Decimal, - list, -] - - -### Writing SQL Values (literals, column references, table references, etc.) -def simple_literal(value: Union[SIMPLE_LITERAL_TYPES, None]) -> str: - """Return quoted input string.""" - - # https://cloud.google.com/bigquery/docs/reference/standard-sql/lexical#literals - if value is None: - return "NULL" - elif isinstance(value, str): - # Single quoting seems to work nicer with ibis than double quoting - return f"'{sql.escape_chars(value)}'" - elif isinstance(value, bytes): - return repr(value) - elif isinstance(value, (bool, int)): - return str(value) - elif isinstance(value, float): - # https://cloud.google.com/bigquery/docs/reference/standard-sql/lexical#floating_point_literals - if math.isnan(value): - return 'CAST("nan" as FLOAT)' - if value == math.inf: - return 'CAST("+inf" as FLOAT)' - if value == -math.inf: - return 'CAST("-inf" as FLOAT)' - return str(value) - # Check datetime first as it is a subclass of date - elif isinstance(value, datetime.datetime): - if value.tzinfo is None: - return f"DATETIME('{value.isoformat()}')" - else: - return f"TIMESTAMP('{value.isoformat()}')" - elif isinstance(value, datetime.date): - return f"DATE('{value.isoformat()}')" - elif isinstance(value, datetime.time): - return f"TIME(DATETIME('1970-01-01 {value.isoformat()}'))" - elif isinstance(value, shapely.geometry.base.BaseGeometry): - return f"ST_GEOGFROMTEXT({simple_literal(to_wkt(value))})" - elif isinstance(value, decimal.Decimal): - # TODO: disambiguate BIGNUMERIC based on scale and/or precision - return f"CAST('{str(value)}' AS NUMERIC)" - elif isinstance(value, list): - simple_literals = [simple_literal(i) for i in value] - return f"[{', '.join(simple_literals)}]" - - else: - raise ValueError(f"Cannot produce literal for {value}") - - -def multi_literal(*values: str): - literal_strings = [simple_literal(i) for i in values] +def multi_literal(*values: Any): + literal_strings = [sql.to_sql(sql.literal(i)) for i in values] return "(" + ", ".join(literal_strings) + ")" @@ -210,7 +155,7 @@ def create_vector_index_ddl( rendered_options = ", ".join( [ - f"{option_name} = {simple_literal(option_value)}" + f"{option_name} = {sql.to_sql(sql.literal(option_value))}" for option_name, option_value in options.items() ] ) @@ -237,24 +182,26 @@ def create_vector_search_sql( vector_search_args = [ f"TABLE {sql.to_sql(sql.identifier(cast(str, base_table)))}", - f"{simple_literal(column_to_search)}", + f"{sql.to_sql(sql.literal(column_to_search))}", f"({sql_string})", ] if query_column_to_search is not None: vector_search_args.append( - f"query_column_to_search => {simple_literal(query_column_to_search)}" + f"query_column_to_search => {sql.to_sql(sql.literal(query_column_to_search))}" ) if top_k is not None: - vector_search_args.append(f"top_k=> {simple_literal(top_k)}") + vector_search_args.append(f"top_k=> {sql.to_sql(sql.literal(top_k))}") if distance_type is not None: - vector_search_args.append(f"distance_type => {simple_literal(distance_type)}") + vector_search_args.append( + f"distance_type => {sql.to_sql(sql.literal(distance_type))}" + ) if options is not None: vector_search_args.append( - f"options => {simple_literal(json.dumps(options, indent=None))}" + f"options => {sql.to_sql(sql.literal(json.dumps(options, indent=None)))}" ) args_str = ",\n".join(vector_search_args) diff --git a/bigframes/core/sql/literals.py b/bigframes/core/sql/literals.py deleted file mode 100644 index 59c81977315..00000000000 --- a/bigframes/core/sql/literals.py +++ /dev/null @@ -1,58 +0,0 @@ -# Copyright 2026 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from __future__ import annotations - -import collections.abc -import json -from typing import Any, List, Mapping, Union - -import bigframes.core.sql - -STRUCT_VALUES = Union[ - str, int, float, bool, Mapping[str, str], List[str], Mapping[str, Any] -] -STRUCT_TYPE = Mapping[str, STRUCT_VALUES] - - -def struct_literal(struct_options: STRUCT_TYPE) -> str: - rendered_options = [] - for option_name, option_value in struct_options.items(): - if option_name == "model_params": - json_str = json.dumps(option_value) - # Escape single quotes for SQL string literal - sql_json_str = json_str.replace("'", "''") - rendered_val = f"JSON'{sql_json_str}'" - elif isinstance(option_value, collections.abc.Mapping): - struct_body = ", ".join( - [ - f"{bigframes.core.sql.simple_literal(v)} AS {k}" - for k, v in option_value.items() - ] - ) - rendered_val = f"STRUCT({struct_body})" - elif isinstance(option_value, list): - rendered_val = ( - "[" - + ", ".join( - [bigframes.core.sql.simple_literal(v) for v in option_value] - ) - + "]" - ) - elif isinstance(option_value, bool): - rendered_val = str(option_value).lower() - else: - rendered_val = bigframes.core.sql.simple_literal(option_value) - rendered_options.append(f"{rendered_val} AS {option_name}") - return f"STRUCT({', '.join(rendered_options)})" diff --git a/bigframes/core/sql/ml.py b/bigframes/core/sql/ml.py index 0edb784c37e..93ccca6aa15 100644 --- a/bigframes/core/sql/ml.py +++ b/bigframes/core/sql/ml.py @@ -17,8 +17,6 @@ from typing import Any, Dict, List, Mapping, Optional, Union from bigframes.core.compile.sqlglot import sql as sg_sql -import bigframes.core.sql -import bigframes.core.sql.literals def create_model_ddl( @@ -76,9 +74,9 @@ def create_model_ddl( # Handle list options like model_registry="vertex_ai" # wait, usually options are key=value. # if value is list, it is [val1, val2] - rendered_val = bigframes.core.sql.simple_literal(list(option_value)) + rendered_val = sg_sql.to_sql(sg_sql.literal(list(option_value))) else: - rendered_val = bigframes.core.sql.simple_literal(option_value) + rendered_val = sg_sql.to_sql(sg_sql.literal(option_value)) rendered_options.append(f"{option_name} = {rendered_val}") @@ -108,7 +106,7 @@ def _build_struct_sql( ) -> str: if not struct_options: return "" - return f", {bigframes.core.sql.literals.struct_literal(struct_options)}" + return f", {sg_sql.to_sql(sg_sql.literal(struct_options))}" def evaluate( diff --git a/bigframes/dtypes.py b/bigframes/dtypes.py index a2abe9b817a..6b875a97d22 100644 --- a/bigframes/dtypes.py +++ b/bigframes/dtypes.py @@ -118,6 +118,21 @@ ] LOCAL_SCALAR_TYPES = typing.get_args(LOCAL_SCALAR_TYPE) +SUPPORTED_LITERAL_TYPE = typing.Union[ + bytes, + str, + int, + bool, + float, + datetime.datetime, + datetime.date, + datetime.time, + decimal.Decimal, + list, + shapely.geometry.base.BaseGeometry, +] +SUPPORTED_LITERAL_TYPES = typing.get_args(SUPPORTED_LITERAL_TYPE) + # Will have a few dtype variants: simple(eg. int, string, bool), complex (eg. list, struct), and virtual (eg. micro intervals, categorical) @dataclass(frozen=True) @@ -900,11 +915,16 @@ def is_compatible(scalar: typing.Any, dtype: Dtype) -> typing.Optional[Dtype]: def lcd_type(*dtypes: Dtype) -> Dtype: if len(dtypes) < 1: raise ValueError("at least one dypes should be provided") - if len(dtypes) == 1: - return dtypes[0] + unique_dtypes = set(dtypes) + if None in unique_dtypes: + unique_dtypes.remove(None) + + if len(unique_dtypes) == 0: + return None if len(unique_dtypes) == 1: - return unique_dtypes.pop() + return next(iter(unique_dtypes)) + # Implicit conversion currently only supported for numeric types hierarchy: list[Dtype] = [ BOOL_DTYPE, @@ -913,9 +933,9 @@ def lcd_type(*dtypes: Dtype) -> Dtype: BIGNUMERIC_DTYPE, FLOAT_DTYPE, ] - if any([dtype not in hierarchy for dtype in dtypes]): + if any([dtype not in hierarchy for dtype in unique_dtypes]): return None - lcd_index = max([hierarchy.index(dtype) for dtype in dtypes]) + lcd_index = max([hierarchy.index(dtype) for dtype in unique_dtypes]) return hierarchy[lcd_index] diff --git a/bigframes/functions/function.py b/bigframes/functions/function.py index 242daf7525d..4e06cb16633 100644 --- a/bigframes/functions/function.py +++ b/bigframes/functions/function.py @@ -214,10 +214,10 @@ def __call__(self, *args, **kwargs): if self._local_fun: return self._local_fun(*args, **kwargs) # avoid circular imports - import bigframes.core.sql as bf_sql + from bigframes.core.compile.sqlglot import sql as sg_sql import bigframes.session._io.bigquery as bf_io_bigquery - args_string = ", ".join(map(bf_sql.simple_literal, args)) + args_string = ", ".join([sg_sql.to_sql(sg_sql.literal(v)) for v in args]) sql = f"SELECT `{str(self._udf_def.routine_ref)}`({args_string})" iter, job = bf_io_bigquery.start_query_with_client( self._session.bqclient, @@ -298,10 +298,10 @@ def __call__(self, *args, **kwargs): if self._local_fun: return self._local_fun(*args, **kwargs) # avoid circular imports - import bigframes.core.sql as bf_sql + from bigframes.core.compile.sqlglot import sql as sg_sql import bigframes.session._io.bigquery as bf_io_bigquery - args_string = ", ".join(map(bf_sql.simple_literal, args)) + args_string = ", ".join([sg_sql.to_sql(sg_sql.literal(v)) for v in args]) sql = f"SELECT `{str(self._udf_def.routine_ref)}`({args_string})" iter, job = bf_io_bigquery.start_query_with_client( self._session.bqclient, diff --git a/bigframes/ml/sql.py b/bigframes/ml/sql.py index be9055e9568..894fc44b1b3 100644 --- a/bigframes/ml/sql.py +++ b/bigframes/ml/sql.py @@ -22,7 +22,6 @@ import google.cloud.bigquery from bigframes.core.compile.sqlglot import sql as sg_sql -import bigframes.core.sql as sql_vals INDENT_STR = " " @@ -35,7 +34,7 @@ class BaseSqlGenerator: def encode_value(self, v: Union[str, int, float, Iterable[str]]) -> str: """Encode a parameter value for SQL""" if isinstance(v, (str, int, float)): - return sql_vals.simple_literal(v) + return sg_sql.to_sql(sg_sql.literal(v)) elif isinstance(v, Iterable): inner = ", ".join([self.encode_value(x) for x in v]) return f"[{inner}]" @@ -62,7 +61,7 @@ def build_structs(self, **kwargs: Union[int, float, str, Mapping]) -> str: v_trans = self.build_schema(**v) if isinstance(v, Mapping) else v param_strs.append( - f"{sql_vals.simple_literal(v_trans)} AS {sg_sql.to_sql(sg_sql.identifier(k))}" + f"{sg_sql.to_sql(sg_sql.literal(v_trans))} AS {sg_sql.to_sql(sg_sql.identifier(k))}" ) return "\n" + INDENT_STR + f",\n{INDENT_STR}".join(param_strs) diff --git a/bigframes/session/_io/bigquery/__init__.py b/bigframes/session/_io/bigquery/__init__.py index a9abf6602d4..61b22d03115 100644 --- a/bigframes/session/_io/bigquery/__init__.py +++ b/bigframes/session/_io/bigquery/__init__.py @@ -534,12 +534,12 @@ def to_query( time_travel_clause = "" if time_travel_timestamp is not None: - time_travel_literal = bigframes.core.sql.simple_literal(time_travel_timestamp) + time_travel_literal = sg_sql.to_sql(sg_sql.literal(time_travel_timestamp)) time_travel_clause = f" FOR SYSTEM_TIME AS OF {time_travel_literal}" limit_clause = "" if max_results is not None: - limit_clause = f" LIMIT {bigframes.core.sql.simple_literal(max_results)}" + limit_clause = f" LIMIT {sg_sql.to_sql(sg_sql.literal(max_results))}" where_clause = f" WHERE {sql_predicate}" if sql_predicate else "" @@ -603,7 +603,7 @@ def compile_filters(filters: third_party_pandas_gbq.FiltersType) -> str: if operator_str in ["IN", "NOT IN"]: value_literal = bigframes.core.sql.multi_literal(*value) else: - value_literal = bigframes.core.sql.simple_literal(value) + value_literal = sg_sql.to_sql(sg_sql.literal(value)) expression = bigframes.core.sql.infix_op( operator_str, column_ref, value_literal ) diff --git a/tests/unit/core/compile/sqlglot/sql/test_base.py b/tests/unit/core/compile/sqlglot/sql/test_base.py new file mode 100644 index 00000000000..d11fddc954c --- /dev/null +++ b/tests/unit/core/compile/sqlglot/sql/test_base.py @@ -0,0 +1,156 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import datetime +import decimal + +import numpy as np +import pandas as pd +import pyarrow as pa +import pytest +import shapely.geometry # type: ignore + +import bigframes.core.compile.sqlglot.sql.base as sql + + +@pytest.mark.parametrize( + ("value", "expected_pattern"), + ( + pytest.param(None, "NULL", id="null"), + pytest.param(True, "TRUE", id="true"), + pytest.param(False, "FALSE", id="false"), + pytest.param(123, "123", id="int"), + pytest.param(123.75, "123.75", id="float"), + pytest.param("abc", "'abc'", id="string"), + pytest.param( + b"\x01\x02\x03ABC", "CAST(b'\\x01\\x02\\x03ABC' AS BYTES)", id="bytes" + ), + pytest.param( + decimal.Decimal("123.75"), "CAST(123.75 AS NUMERIC)", id="decimal" + ), + pytest.param( + datetime.date(2025, 1, 1), "CAST('2025-01-01' AS DATE)", id="date" + ), + pytest.param( + datetime.datetime(2025, 1, 2, 3, 45, 6, 789123), + "CAST('2025-01-02T03:45:06.789123' AS DATETIME)", + id="datetime", + ), + pytest.param( + datetime.time(12, 34, 56, 789123), + "CAST('12:34:56.789123' AS TIME)", + id="time", + ), + pytest.param( + datetime.datetime( + 2025, 1, 2, 3, 45, 6, 789123, tzinfo=datetime.timezone.utc + ), + "CAST('2025-01-02T03:45:06.789123+00:00' AS TIMESTAMP)", + id="timestamp", + ), + pytest.param( + shapely.geometry.Point(0, 1), "ST_GEOGFROMTEXT('POINT (0 1)')", id="geo" + ), + pytest.param(np.int64(123), "123", id="np_int64"), + pytest.param(np.float64(123.75), "123.75", id="np_float64"), + pytest.param(float("inf"), "CAST('Infinity' AS FLOAT64)", id="inf"), + pytest.param(float("-inf"), "CAST('-Infinity' AS FLOAT64)", id="neg_inf"), + pytest.param(float("nan"), "NULL", id="nan"), + pytest.param(pd.NA, "NULL", id="pd_na"), + pytest.param(datetime.timedelta(seconds=1), "1000000", id="timedelta"), + pytest.param("POINT (0 1)", "'POINT (0 1)'", id="string_geo"), + ), +) +def test_literal(value, expected_pattern): + got = sql.to_sql(sql.literal(value)) + assert got == expected_pattern + + +@pytest.mark.parametrize( + ("value", "dtype", "expected"), + ( + pytest.param( + decimal.Decimal("1.23"), + sql.dtypes.BIGNUMERIC_DTYPE, + "CAST(1.23 AS BIGNUMERIC)", + id="bignumeric", + ), + pytest.param( + [], + pd.ArrowDtype(pa.list_(pa.int64())), + "ARRAY[]", + id="empty_array", + ), + pytest.param( + {"a": 1, "b": "hello"}, + pd.ArrowDtype(pa.struct([("a", pa.int64()), ("b", pa.string())])), + "STRUCT(1 AS `a`, 'hello' AS `b`)", + id="struct", + ), + pytest.param( + float("nan"), + sql.dtypes.FLOAT_DTYPE, + "CAST('NaN' AS FLOAT64)", + id="explicit_nan", + ), + pytest.param( + pa.scalar(123, type=pa.int64()), + None, + "123", + id="pa_scalar_int", + ), + pytest.param( + pa.scalar(None, type=pa.int64()), + None, + "CAST(NULL AS INT64)", + id="pa_scalar_null", + ), + pytest.param( + {"a": 10}, + sql.dtypes.JSON_DTYPE, + "PARSE_JSON('{\\'a\\': 10}')", + id="json", + ), + ), +) +def test_literal_explicit_dtype(value, dtype, expected): + got = sql.to_sql(sql.literal(value, dtype=dtype)) + assert got == expected + + +@pytest.mark.parametrize( + ("value", "expected"), + ( + pytest.param([True, False], "[TRUE, FALSE]", id="bool"), + pytest.param([123, 456], "[123, 456]", id="int"), + pytest.param( + [123.75, 456.78, float("nan"), float("inf"), float("-inf")], + "[\n 123.75,\n 456.78,\n CAST('NaN' AS FLOAT64),\n CAST('Infinity' AS FLOAT64),\n CAST('-Infinity' AS FLOAT64)\n]", + id="float", + ), + pytest.param( + [b"\x01\x02\x03ABC", b"\x01\x02\x03ABC"], + "[CAST(b'\\x01\\x02\\x03ABC' AS BYTES), CAST(b'\\x01\\x02\\x03ABC' AS BYTES)]", + id="bytes", + ), + pytest.param( + [datetime.date(2025, 1, 1), datetime.date(2025, 1, 1)], + "[CAST('2025-01-01' AS DATE), CAST('2025-01-01' AS DATE)]", + id="date", + ), + ), +) +def test_literal_for_list(value: list, expected: str): + got = sql.to_sql(sql.literal(value)) + assert got == expected diff --git a/tests/unit/core/test_sql.py b/tests/unit/core/test_sql.py index 17da3008fc4..04ebb28764d 100644 --- a/tests/unit/core/test_sql.py +++ b/tests/unit/core/test_sql.py @@ -12,128 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. -import datetime -import decimal -import re - -import pytest -import shapely.geometry # type: ignore - from bigframes.core import sql -@pytest.mark.parametrize( - ("value", "expected_pattern"), - ( - # Try to have some literals for each scalar data type: - # https://cloud.google.com/bigquery/docs/reference/standard-sql/data-types - (None, "NULL"), - # TODO: support ARRAY type (possibly another method?) - (True, "True"), - (False, "False"), - ( - b"\x01\x02\x03ABC", - re.escape(r"b'\x01\x02\x03ABC'"), - ), - ( - datetime.date(2025, 1, 1), - re.escape("DATE('2025-01-01')"), - ), - ( - datetime.datetime(2025, 1, 2, 3, 45, 6, 789123), - re.escape("DATETIME('2025-01-02T03:45:06.789123')"), - ), - ( - shapely.geometry.Point(0, 1), - r"ST_GEOGFROMTEXT\('POINT \(0[.]?0* 1[.]?0*\)'\)", - ), - # TODO: INTERVAL type (e.g. from dateutil.relativedelta) - # TODO: JSON type (TBD what Python object that would correspond to) - (123, re.escape("123")), - (decimal.Decimal("123.75"), re.escape("CAST('123.75' AS NUMERIC)")), - # TODO: support BIGNUMERIC by looking at precision/scale of the DECIMAL - (123.75, re.escape("123.75")), - # TODO: support RANGE type - ("abc", re.escape("'abc'")), - # TODO: support STRUCT type (possibly another method?) - ( - datetime.time(12, 34, 56, 789123), - re.escape("TIME(DATETIME('1970-01-01 12:34:56.789123'))"), - ), - ( - datetime.datetime( - 2025, 1, 2, 3, 45, 6, 789123, tzinfo=datetime.timezone.utc - ), - re.escape("TIMESTAMP('2025-01-02T03:45:06.789123+00:00')"), - ), - ), -) -def test_simple_literal(value, expected_pattern): - got = sql.simple_literal(value) - assert re.match(expected_pattern, got) is not None - - -@pytest.mark.parametrize( - ("value", "expected_pattern"), - ( - # Try to have some list of literals for each scalar data type: - # https://cloud.google.com/bigquery/docs/reference/standard-sql/data-types - ([None, None], re.escape("[NULL, NULL]")), - ([True, False], re.escape("[True, False]")), - ( - [b"\x01\x02\x03ABC", b"\x01\x02\x03ABC"], - re.escape("[b'\\x01\\x02\\x03ABC', b'\\x01\\x02\\x03ABC']"), - ), - ( - [datetime.date(2025, 1, 1), datetime.date(2025, 1, 1)], - re.escape("[DATE('2025-01-01'), DATE('2025-01-01')]"), - ), - ( - [datetime.datetime(2025, 1, 2, 3, 45, 6, 789123)], - re.escape("[DATETIME('2025-01-02T03:45:06.789123')]"), - ), - ( - [shapely.geometry.Point(0, 1), shapely.geometry.Point(0, 2)], - r"\[ST_GEOGFROMTEXT\('POINT \(0[.]?0* 1[.]?0*\)'\), ST_GEOGFROMTEXT\('POINT \(0[.]?0* 2[.]?0*\)'\)\]", - ), - # TODO: INTERVAL type (e.g. from dateutil.relativedelta) - # TODO: JSON type (TBD what Python object that would correspond to) - ([123, 456], re.escape("[123, 456]")), - ( - [decimal.Decimal("123.75"), decimal.Decimal("456.78")], - re.escape("[CAST('123.75' AS NUMERIC), CAST('456.78' AS NUMERIC)]"), - ), - # TODO: support BIGNUMERIC by looking at precision/scale of the DECIMAL - ([123.75, 456.78], re.escape("[123.75, 456.78]")), - # TODO: support RANGE type - (["abc", "def"], re.escape("['abc', 'def']")), - # TODO: support STRUCT type (possibly another method?) - ( - [datetime.time(12, 34, 56, 789123), datetime.time(11, 25, 56, 789123)], - re.escape( - "[TIME(DATETIME('1970-01-01 12:34:56.789123')), TIME(DATETIME('1970-01-01 11:25:56.789123'))]" - ), - ), - ( - [ - datetime.datetime( - 2025, 1, 2, 3, 45, 6, 789123, tzinfo=datetime.timezone.utc - ), - datetime.datetime( - 2025, 2, 1, 4, 45, 6, 789123, tzinfo=datetime.timezone.utc - ), - ], - re.escape( - "[TIMESTAMP('2025-01-02T03:45:06.789123+00:00'), TIMESTAMP('2025-02-01T04:45:06.789123+00:00')]" - ), - ), - ), -) -def test_simple_literal_w_list(value: list, expected_pattern: str): - got = sql.simple_literal(value) - assert re.match(expected_pattern, got) is not None - - def test_create_vector_search_sql_simple(): result_query = sql.create_vector_search_sql( sql_string="SELECT embedding FROM my_embeddings_table WHERE id = 1", @@ -180,6 +61,6 @@ def test_create_vector_search_sql_all_named_parameters(): query_column_to_search => 'another_embedding_column', top_k=> 10, distance_type => 'cosine', -options => '{\\"fraction_lists_to_search\\": 0.1, \\"use_brute_force\\": false}') +options => '{"fraction_lists_to_search": 0.1, "use_brute_force": false}') """ ) From fa428f9dc5064b0525348c862631684c92b518af Mon Sep 17 00:00:00 2001 From: Chelsea Lin Date: Tue, 10 Mar 2026 17:38:03 +0000 Subject: [PATCH 2/2] fix unit tests --- bigframes/dtypes.py | 8 +++---- tests/unit/bigquery/test_ai.py | 12 +++++----- tests/unit/bigquery/test_ml.py | 22 +++++++++---------- .../core/compile/sqlglot/sql/test_base.py | 11 +++++++--- .../evaluate_model_with_options.sql | 2 +- .../explain_predict_model_with_options.sql | 2 +- .../generate_embedding_model_with_options.sql | 6 ++++- .../generate_text_model_with_options.sql | 11 +++++++++- .../global_explain_model_with_options.sql | 2 +- .../predict_model_with_options.sql | 2 +- tests/unit/ml/test_golden_sql.py | 8 +++---- tests/unit/session/test_io_bigquery.py | 4 ++-- 12 files changed, 54 insertions(+), 36 deletions(-) diff --git a/bigframes/dtypes.py b/bigframes/dtypes.py index 6b875a97d22..304428ef2fa 100644 --- a/bigframes/dtypes.py +++ b/bigframes/dtypes.py @@ -724,10 +724,6 @@ def infer_literal_type(literal) -> typing.Optional[Dtype]: # Maybe also normalize literal to canonical python representation to remove this burden from compilers? if isinstance(literal, pa.Scalar): return arrow_dtype_to_bigframes_dtype(literal.type) - if pd.api.types.is_list_like(literal): - element_types = [infer_literal_type(i) for i in literal] - common_type = lcd_type(*element_types) - return list_type(common_type) if pd.api.types.is_dict_like(literal): fields = [] for key in literal.keys(): @@ -738,6 +734,10 @@ def infer_literal_type(literal) -> typing.Optional[Dtype]: pa.field(key, field_type, nullable=(not pa.types.is_list(field_type))) ) return pd.ArrowDtype(pa.struct(fields)) + if pd.api.types.is_list_like(literal): + element_types = [infer_literal_type(i) for i in literal] + common_type = lcd_type(*element_types) + return list_type(common_type) if pd.isna(literal): return None # Null value without a definite type # Make sure to check datetime before date as datetimes are also dates diff --git a/tests/unit/bigquery/test_ai.py b/tests/unit/bigquery/test_ai.py index c73e63b9db1..2cb876d39a5 100644 --- a/tests/unit/bigquery/test_ai.py +++ b/tests/unit/bigquery/test_ai.py @@ -91,7 +91,7 @@ def test_generate_embedding_with_dataframe(mock_dataframe, mock_session): expected_part_1 = "SELECT * FROM AI.GENERATE_EMBEDDING(" expected_part_2 = f"MODEL `{model_name}`," expected_part_3 = "(SELECT * FROM my_table)," - expected_part_4 = "STRUCT(256 AS OUTPUT_DIMENSIONALITY)" + expected_part_4 = "STRUCT(256 AS `OUTPUT_DIMENSIONALITY`)" assert expected_part_1 in query assert expected_part_2 in query @@ -117,7 +117,7 @@ def test_generate_embedding_with_series(mock_embedding_series, mock_session): assert f"MODEL `{model_name}`" in query assert "(SELECT my_col AS content FROM my_table)" in query assert ( - "STRUCT(0.0 AS START_SECOND, 10.0 AS END_SECOND, 5.0 AS INTERVAL_SECONDS)" + "STRUCT(0.0 AS `START_SECOND`, 10.0 AS `END_SECOND`, 5.0 AS `INTERVAL_SECONDS`)" in query ) @@ -180,7 +180,7 @@ def test_generate_text_with_dataframe(mock_dataframe, mock_session): expected_part_1 = "SELECT * FROM AI.GENERATE_TEXT(" expected_part_2 = f"MODEL `{model_name}`," expected_part_3 = "(SELECT * FROM my_table)," - expected_part_4 = "STRUCT(256 AS MAX_OUTPUT_TOKENS)" + expected_part_4 = "STRUCT(256 AS `MAX_OUTPUT_TOKENS`)" assert expected_part_1 in query assert expected_part_2 in query @@ -238,7 +238,7 @@ def test_generate_table_with_dataframe(mock_dataframe, mock_session): expected_part_1 = "SELECT * FROM AI.GENERATE_TABLE(" expected_part_2 = f"MODEL `{model_name}`," expected_part_3 = "(SELECT * FROM my_table)," - expected_part_4 = "STRUCT('col1 STRING, col2 INT64' AS output_schema)" + expected_part_4 = "STRUCT('col1 STRING, col2 INT64' AS `output_schema`)" assert expected_part_1 in query assert expected_part_2 in query @@ -264,7 +264,7 @@ def test_generate_table_with_options(mock_dataframe, mock_session): assert f"MODEL `{model_name}`" in query assert "(SELECT * FROM my_table)" in query assert ( - "STRUCT('col1 STRING' AS output_schema, 0.5 AS temperature, 100 AS max_output_tokens)" + "STRUCT('col1 STRING' AS `output_schema`, 0.5 AS `temperature`, 100 AS `max_output_tokens`)" in query ) @@ -287,7 +287,7 @@ def test_generate_table_with_mapping_schema(mock_dataframe, mock_session): expected_part_1 = "SELECT * FROM AI.GENERATE_TABLE(" expected_part_2 = f"MODEL `{model_name}`," expected_part_3 = "(SELECT * FROM my_table)," - expected_part_4 = "STRUCT('col1 STRING, col2 INT64' AS output_schema)" + expected_part_4 = "STRUCT('col1 STRING, col2 INT64' AS `output_schema`)" assert expected_part_1 in query assert expected_part_2 in query diff --git a/tests/unit/bigquery/test_ml.py b/tests/unit/bigquery/test_ml.py index 6d39901a35e..a68133225d4 100644 --- a/tests/unit/bigquery/test_ml.py +++ b/tests/unit/bigquery/test_ml.py @@ -167,14 +167,14 @@ def test_generate_text_with_pandas_dataframe(read_pandas_mock, read_gbq_query_mo assert "ML.GENERATE_TEXT" in generated_sql assert f"MODEL `{MODEL_NAME}`" in generated_sql assert "(SELECT * FROM `pandas_df`)" in generated_sql - assert "STRUCT(0.5 AS temperature" in generated_sql - assert "128 AS max_output_tokens" in generated_sql - assert "20 AS top_k" in generated_sql - assert "0.9 AS top_p" in generated_sql - assert "true AS flatten_json_output" in generated_sql - assert "['a', 'b'] AS stop_sequences" in generated_sql - assert "true AS ground_with_google_search" in generated_sql - assert "'TYPE' AS request_type" in generated_sql + assert "STRUCT(\n 0.5 AS `temperature`" in generated_sql + assert "128 AS `max_output_tokens`" in generated_sql + assert "20 AS `top_k`" in generated_sql + assert "0.9 AS `top_p`" in generated_sql + assert "TRUE AS `flatten_json_output`" in generated_sql + assert "['a', 'b'] AS `stop_sequences`" in generated_sql + assert "TRUE AS `ground_with_google_search`" in generated_sql + assert "'TYPE' AS `request_type`" in generated_sql @mock.patch("bigframes.pandas.read_gbq_query") @@ -210,6 +210,6 @@ def test_generate_embedding_with_pandas_dataframe( assert "ML.GENERATE_EMBEDDING" in generated_sql assert f"MODEL `{MODEL_NAME}`" in generated_sql assert "(SELECT * FROM `pandas_df`)" in generated_sql - assert "true AS flatten_json_output" in generated_sql - assert "'RETRIEVAL_DOCUMENT' AS task_type" in generated_sql - assert "256 AS output_dimensionality" in generated_sql + assert "STRUCT(\n TRUE AS `flatten_json_output`" in generated_sql + assert "'RETRIEVAL_DOCUMENT' AS `task_type`" in generated_sql + assert "256 AS `output_dimensionality`" in generated_sql diff --git a/tests/unit/core/compile/sqlglot/sql/test_base.py b/tests/unit/core/compile/sqlglot/sql/test_base.py index d11fddc954c..5ba77d925d0 100644 --- a/tests/unit/core/compile/sqlglot/sql/test_base.py +++ b/tests/unit/core/compile/sqlglot/sql/test_base.py @@ -14,6 +14,7 @@ import datetime import decimal +import re import numpy as np import pandas as pd @@ -59,9 +60,6 @@ "CAST('2025-01-02T03:45:06.789123+00:00' AS TIMESTAMP)", id="timestamp", ), - pytest.param( - shapely.geometry.Point(0, 1), "ST_GEOGFROMTEXT('POINT (0 1)')", id="geo" - ), pytest.param(np.int64(123), "123", id="np_int64"), pytest.param(np.float64(123.75), "123.75", id="np_float64"), pytest.param(float("inf"), "CAST('Infinity' AS FLOAT64)", id="inf"), @@ -77,6 +75,13 @@ def test_literal(value, expected_pattern): assert got == expected_pattern +def test_literal_for_geo(): + value = shapely.geometry.Point(0, 1) + expected_pattern = r"ST_GEOGFROMTEXT\('POINT \(0[.]?0* 1[.]?0*\)'\)" + got = sql.to_sql(sql.literal(value)) + assert re.match(expected_pattern, got) is not None + + @pytest.mark.parametrize( ("value", "dtype", "expected"), ( diff --git a/tests/unit/core/sql/snapshots/test_ml/test_evaluate_model_with_options/evaluate_model_with_options.sql b/tests/unit/core/sql/snapshots/test_ml/test_evaluate_model_with_options/evaluate_model_with_options.sql index 848c36907b9..cdb66bbf0e1 100644 --- a/tests/unit/core/sql/snapshots/test_ml/test_evaluate_model_with_options/evaluate_model_with_options.sql +++ b/tests/unit/core/sql/snapshots/test_ml/test_evaluate_model_with_options/evaluate_model_with_options.sql @@ -1 +1 @@ -SELECT * FROM ML.EVALUATE(MODEL `my_model`, STRUCT(false AS perform_aggregation, 10 AS horizon, 0.95 AS confidence_level)) +SELECT * FROM ML.EVALUATE(MODEL `my_model`, STRUCT(FALSE AS `perform_aggregation`, 10 AS `horizon`, 0.95 AS `confidence_level`)) diff --git a/tests/unit/core/sql/snapshots/test_ml/test_explain_predict_model_with_options/explain_predict_model_with_options.sql b/tests/unit/core/sql/snapshots/test_ml/test_explain_predict_model_with_options/explain_predict_model_with_options.sql index 1214bba8706..7569463ea2d 100644 --- a/tests/unit/core/sql/snapshots/test_ml/test_explain_predict_model_with_options/explain_predict_model_with_options.sql +++ b/tests/unit/core/sql/snapshots/test_ml/test_explain_predict_model_with_options/explain_predict_model_with_options.sql @@ -1 +1 @@ -SELECT * FROM ML.EXPLAIN_PREDICT(MODEL `my_model`, (SELECT * FROM new_data), STRUCT(5 AS top_k_features)) +SELECT * FROM ML.EXPLAIN_PREDICT(MODEL `my_model`, (SELECT * FROM new_data), STRUCT(5 AS `top_k_features`)) diff --git a/tests/unit/core/sql/snapshots/test_ml/test_generate_embedding_model_with_options/generate_embedding_model_with_options.sql b/tests/unit/core/sql/snapshots/test_ml/test_generate_embedding_model_with_options/generate_embedding_model_with_options.sql index d07e1c1e15e..3be957079cf 100644 --- a/tests/unit/core/sql/snapshots/test_ml/test_generate_embedding_model_with_options/generate_embedding_model_with_options.sql +++ b/tests/unit/core/sql/snapshots/test_ml/test_generate_embedding_model_with_options/generate_embedding_model_with_options.sql @@ -1 +1,5 @@ -SELECT * FROM ML.GENERATE_EMBEDDING(MODEL `my_project.my_dataset.my_model`, (SELECT * FROM new_data), STRUCT(true AS flatten_json_output, 'RETRIEVAL_DOCUMENT' AS task_type, 256 AS output_dimensionality)) +SELECT * FROM ML.GENERATE_EMBEDDING(MODEL `my_project.my_dataset.my_model`, (SELECT * FROM new_data), STRUCT( + TRUE AS `flatten_json_output`, + 'RETRIEVAL_DOCUMENT' AS `task_type`, + 256 AS `output_dimensionality` +)) diff --git a/tests/unit/core/sql/snapshots/test_ml/test_generate_text_model_with_options/generate_text_model_with_options.sql b/tests/unit/core/sql/snapshots/test_ml/test_generate_text_model_with_options/generate_text_model_with_options.sql index 7839ff3fbdd..0ea26747287 100644 --- a/tests/unit/core/sql/snapshots/test_ml/test_generate_text_model_with_options/generate_text_model_with_options.sql +++ b/tests/unit/core/sql/snapshots/test_ml/test_generate_text_model_with_options/generate_text_model_with_options.sql @@ -1 +1,10 @@ -SELECT * FROM ML.GENERATE_TEXT(MODEL `my_project.my_dataset.my_model`, (SELECT * FROM new_data), STRUCT(0.5 AS temperature, 128 AS max_output_tokens, 20 AS top_k, 0.9 AS top_p, true AS flatten_json_output, ['a', 'b'] AS stop_sequences, true AS ground_with_google_search, 'TYPE' AS request_type)) +SELECT * FROM ML.GENERATE_TEXT(MODEL `my_project.my_dataset.my_model`, (SELECT * FROM new_data), STRUCT( + 0.5 AS `temperature`, + 128 AS `max_output_tokens`, + 20 AS `top_k`, + 0.9 AS `top_p`, + TRUE AS `flatten_json_output`, + ['a', 'b'] AS `stop_sequences`, + TRUE AS `ground_with_google_search`, + 'TYPE' AS `request_type` +)) diff --git a/tests/unit/core/sql/snapshots/test_ml/test_global_explain_model_with_options/global_explain_model_with_options.sql b/tests/unit/core/sql/snapshots/test_ml/test_global_explain_model_with_options/global_explain_model_with_options.sql index b8d158acfc7..396648aa1db 100644 --- a/tests/unit/core/sql/snapshots/test_ml/test_global_explain_model_with_options/global_explain_model_with_options.sql +++ b/tests/unit/core/sql/snapshots/test_ml/test_global_explain_model_with_options/global_explain_model_with_options.sql @@ -1 +1 @@ -SELECT * FROM ML.GLOBAL_EXPLAIN(MODEL `my_model`, STRUCT(true AS class_level_explain)) +SELECT * FROM ML.GLOBAL_EXPLAIN(MODEL `my_model`, STRUCT(TRUE AS `class_level_explain`)) diff --git a/tests/unit/core/sql/snapshots/test_ml/test_predict_model_with_options/predict_model_with_options.sql b/tests/unit/core/sql/snapshots/test_ml/test_predict_model_with_options/predict_model_with_options.sql index f320d47fcf4..e19f39eebba 100644 --- a/tests/unit/core/sql/snapshots/test_ml/test_predict_model_with_options/predict_model_with_options.sql +++ b/tests/unit/core/sql/snapshots/test_ml/test_predict_model_with_options/predict_model_with_options.sql @@ -1 +1 @@ -SELECT * FROM ML.PREDICT(MODEL `my_model`, (SELECT * FROM new_data), STRUCT(true AS keep_original_columns)) +SELECT * FROM ML.PREDICT(MODEL `my_model`, (SELECT * FROM new_data), STRUCT(TRUE AS `keep_original_columns`)) diff --git a/tests/unit/ml/test_golden_sql.py b/tests/unit/ml/test_golden_sql.py index 7f6843aacf6..d3d880f87ae 100644 --- a/tests/unit/ml/test_golden_sql.py +++ b/tests/unit/ml/test_golden_sql.py @@ -124,7 +124,7 @@ def test_linear_regression_default_fit( model.fit(mock_X, mock_y) mock_session._start_query_ml_ddl.assert_called_once_with( - "CREATE OR REPLACE MODEL `test-project`.`_anon123`.`temp_model_id`\nOPTIONS(\n model_type='LINEAR_REG',\n data_split_method='NO_SPLIT',\n optimize_strategy='auto_strategy',\n fit_intercept=True,\n l2_reg=0.0,\n max_iterations=20,\n learn_rate_strategy='line_search',\n min_rel_progress=0.01,\n calculate_p_values=False,\n enable_global_explain=False,\n INPUT_LABEL_COLS=['input_column_label'])\nAS input_X_y_no_index_sql" + "CREATE OR REPLACE MODEL `test-project`.`_anon123`.`temp_model_id`\nOPTIONS(\n model_type='LINEAR_REG',\n data_split_method='NO_SPLIT',\n optimize_strategy='auto_strategy',\n fit_intercept=TRUE,\n l2_reg=0.0,\n max_iterations=20,\n learn_rate_strategy='line_search',\n min_rel_progress=0.01,\n calculate_p_values=FALSE,\n enable_global_explain=FALSE,\n INPUT_LABEL_COLS=['input_column_label'])\nAS input_X_y_no_index_sql" ) @@ -134,7 +134,7 @@ def test_linear_regression_params_fit(bqml_model_factory, mock_session, mock_X, model.fit(mock_X, mock_y) mock_session._start_query_ml_ddl.assert_called_once_with( - "CREATE OR REPLACE MODEL `test-project`.`_anon123`.`temp_model_id`\nOPTIONS(\n model_type='LINEAR_REG',\n data_split_method='NO_SPLIT',\n optimize_strategy='auto_strategy',\n fit_intercept=False,\n l2_reg=0.0,\n max_iterations=20,\n learn_rate_strategy='line_search',\n min_rel_progress=0.01,\n calculate_p_values=False,\n enable_global_explain=False,\n INPUT_LABEL_COLS=['input_column_label'])\nAS input_X_y_no_index_sql" + "CREATE OR REPLACE MODEL `test-project`.`_anon123`.`temp_model_id`\nOPTIONS(\n model_type='LINEAR_REG',\n data_split_method='NO_SPLIT',\n optimize_strategy='auto_strategy',\n fit_intercept=FALSE,\n l2_reg=0.0,\n max_iterations=20,\n learn_rate_strategy='line_search',\n min_rel_progress=0.01,\n calculate_p_values=FALSE,\n enable_global_explain=FALSE,\n INPUT_LABEL_COLS=['input_column_label'])\nAS input_X_y_no_index_sql" ) @@ -169,7 +169,7 @@ def test_logistic_regression_default_fit( model.fit(mock_X, mock_y) mock_session._start_query_ml_ddl.assert_called_once_with( - "CREATE OR REPLACE MODEL `test-project`.`_anon123`.`temp_model_id`\nOPTIONS(\n model_type='LOGISTIC_REG',\n data_split_method='NO_SPLIT',\n fit_intercept=True,\n auto_class_weights=False,\n optimize_strategy='auto_strategy',\n l2_reg=0.0,\n max_iterations=20,\n learn_rate_strategy='line_search',\n min_rel_progress=0.01,\n calculate_p_values=False,\n enable_global_explain=False,\n INPUT_LABEL_COLS=['input_column_label'])\nAS input_X_y_no_index_sql", + "CREATE OR REPLACE MODEL `test-project`.`_anon123`.`temp_model_id`\nOPTIONS(\n model_type='LOGISTIC_REG',\n data_split_method='NO_SPLIT',\n fit_intercept=TRUE,\n auto_class_weights=FALSE,\n optimize_strategy='auto_strategy',\n l2_reg=0.0,\n max_iterations=20,\n learn_rate_strategy='line_search',\n min_rel_progress=0.01,\n calculate_p_values=FALSE,\n enable_global_explain=FALSE,\n INPUT_LABEL_COLS=['input_column_label'])\nAS input_X_y_no_index_sql", ) @@ -191,7 +191,7 @@ def test_logistic_regression_params_fit( model.fit(mock_X, mock_y) mock_session._start_query_ml_ddl.assert_called_once_with( - "CREATE OR REPLACE MODEL `test-project`.`_anon123`.`temp_model_id`\nOPTIONS(\n model_type='LOGISTIC_REG',\n data_split_method='NO_SPLIT',\n fit_intercept=False,\n auto_class_weights=True,\n optimize_strategy='batch_gradient_descent',\n l2_reg=0.2,\n max_iterations=30,\n learn_rate_strategy='constant',\n min_rel_progress=0.02,\n calculate_p_values=False,\n enable_global_explain=False,\n l1_reg=0.2,\n learn_rate=0.2,\n INPUT_LABEL_COLS=['input_column_label'])\nAS input_X_y_no_index_sql" + "CREATE OR REPLACE MODEL `test-project`.`_anon123`.`temp_model_id`\nOPTIONS(\n model_type='LOGISTIC_REG',\n data_split_method='NO_SPLIT',\n fit_intercept=FALSE,\n auto_class_weights=TRUE,\n optimize_strategy='batch_gradient_descent',\n l2_reg=0.2,\n max_iterations=30,\n learn_rate_strategy='constant',\n min_rel_progress=0.02,\n calculate_p_values=FALSE,\n enable_global_explain=FALSE,\n l1_reg=0.2,\n learn_rate=0.2,\n INPUT_LABEL_COLS=['input_column_label'])\nAS input_X_y_no_index_sql" ) diff --git a/tests/unit/session/test_io_bigquery.py b/tests/unit/session/test_io_bigquery.py index eb58c6bb52d..3d3832d6786 100644 --- a/tests/unit/session/test_io_bigquery.py +++ b/tests/unit/session/test_io_bigquery.py @@ -345,7 +345,7 @@ def test_bq_schema_to_sql(schema: Iterable[bigquery.SchemaField], expected: str) ), ( "SELECT `row_index`, `string_col` FROM `test_table` " - "FOR SYSTEM_TIME AS OF TIMESTAMP('2024-05-14T12:42:36.125125+00:00') " + "FOR SYSTEM_TIME AS OF CAST('2024-05-14T12:42:36.125125+00:00' AS TIMESTAMP) " "WHERE `rowindex` NOT IN (0, 6) OR `string_col` IN ('Hello, World!', " "'こんにちは') LIMIT 123" ), @@ -374,7 +374,7 @@ def test_bq_schema_to_sql(schema: Iterable[bigquery.SchemaField], expected: str) string_col, FROM `test_table` AS t ) """ - "FOR SYSTEM_TIME AS OF TIMESTAMP('2024-05-14T12:42:36.125125+00:00') " + "FOR SYSTEM_TIME AS OF CAST('2024-05-14T12:42:36.125125+00:00' AS TIMESTAMP) " "WHERE `rowindex` < 4 AND `string_col` = 'Hello, World!' " "LIMIT 123" ),