Skip to content

Commit ce9b691

Browse files
committed
refactor: use sqlglot to build literal
1 parent 077cb2e commit ce9b691

File tree

12 files changed

+239
-288
lines changed

12 files changed

+239
-288
lines changed

bigframes/bigquery/_operations/ai.py

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,8 @@
2828
from bigframes import series, session
2929
from bigframes.bigquery._operations import utils as bq_utils
3030
from bigframes.core import convert
31+
from bigframes.core.compile.sqlglot import sql as sg_sql
3132
from bigframes.core.logging import log_adapter
32-
import bigframes.core.sql.literals
3333
from bigframes.ml import core as ml_core
3434
from bigframes.operations import ai_ops, output_schemas
3535

@@ -458,7 +458,7 @@ def generate_embedding(
458458
model_name, session = bq_utils.get_model_name_and_session(model, data)
459459
table_sql = bq_utils.to_sql(data)
460460

461-
struct_fields: Dict[str, bigframes.core.sql.literals.STRUCT_VALUES] = {}
461+
struct_fields: Dict[str, Any] = {}
462462
if output_dimensionality is not None:
463463
struct_fields["OUTPUT_DIMENSIONALITY"] = output_dimensionality
464464
if task_type is not None:
@@ -478,7 +478,7 @@ def generate_embedding(
478478
FROM AI.GENERATE_EMBEDDING(
479479
MODEL `{model_name}`,
480480
({table_sql}),
481-
{bigframes.core.sql.literals.struct_literal(struct_fields)}
481+
{sg_sql.to_sql(sg_sql.literal(struct_fields))}
482482
)
483483
"""
484484

@@ -591,7 +591,7 @@ def generate_text(
591591
FROM AI.GENERATE_TEXT(
592592
MODEL `{model_name}`,
593593
({table_sql}),
594-
{bigframes.core.sql.literals.struct_literal(struct_fields)}
594+
{sg_sql.to_sql(sg_sql.literal(struct_fields))}
595595
)
596596
"""
597597

@@ -677,9 +677,7 @@ def generate_table(
677677
else:
678678
output_schema_str = output_schema
679679

680-
struct_fields_bq: Dict[str, bigframes.core.sql.literals.STRUCT_VALUES] = {
681-
"output_schema": output_schema_str
682-
}
680+
struct_fields_bq: Dict[str, Any] = {"output_schema": output_schema_str}
683681
if temperature is not None:
684682
struct_fields_bq["temperature"] = temperature
685683
if top_p is not None:
@@ -691,7 +689,7 @@ def generate_table(
691689
if request_type is not None:
692690
struct_fields_bq["request_type"] = request_type
693691

694-
struct_sql = bigframes.core.sql.literals.struct_literal(struct_fields_bq)
692+
struct_sql = sg_sql.to_sql(sg_sql.literal(struct_fields_bq))
695693
query = f"""
696694
SELECT *
697695
FROM AI.GENERATE_TABLE(

bigframes/core/compile/sqlglot/sql/base.py

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -57,8 +57,11 @@ def identifier(id: str) -> sge.Identifier:
5757
return sge.to_identifier(id, quoted=QUOTED)
5858

5959

60-
def literal(value: typing.Any, dtype: dtypes.Dtype) -> sge.Expression:
60+
def literal(value: typing.Any, dtype: dtypes.Dtype | None = None) -> sge.Expression:
6161
"""Return a SQLGlot expression representing a literal value in SQL."""
62+
if dtype is None:
63+
dtype = dtypes.infer_literal_type(value)
64+
6265
sqlglot_type = sgt.from_bigframes_dtype(dtype) if dtype else None
6366
if sqlglot_type is None:
6467
if not pd.isna(value):
@@ -81,6 +84,14 @@ def literal(value: typing.Any, dtype: dtypes.Dtype) -> sge.Expression:
8184
expressions=[literal(value=v, dtype=value_type) for v in value]
8285
)
8386
return values if len(value) > 0 else cast(values, sqlglot_type)
87+
elif dtype == dtypes.FLOAT_DTYPE:
88+
if pd.isna(value):
89+
if isinstance(value, (float, np.floating)) and np.isnan(value):
90+
return constants._NAN
91+
return cast(sge.Null(), sqlglot_type)
92+
if np.isinf(value):
93+
return constants._INF if value > 0 else constants._NEG_INF
94+
return sge.convert(value)
8495
elif pd.isna(value) or (isinstance(value, pa.Scalar) and not value.is_valid):
8596
return cast(sge.Null(), sqlglot_type)
8697
elif dtype == dtypes.JSON_DTYPE:
@@ -100,13 +111,11 @@ def literal(value: typing.Any, dtype: dtypes.Dtype) -> sge.Expression:
100111
return sge.func("ST_GEOGFROMTEXT", sge.convert(wkt))
101112
elif dtype == dtypes.TIMEDELTA_DTYPE:
102113
return sge.convert(utils.timedelta_to_micros(value))
103-
elif dtype == dtypes.FLOAT_DTYPE:
104-
if np.isinf(value):
105-
return constants._INF if value > 0 else constants._NEG_INF
106-
return sge.convert(value)
107114
else:
108115
if isinstance(value, np.generic):
109116
value = value.item()
117+
if isinstance(value, pa.Scalar):
118+
value = value.as_py()
110119
return sge.convert(value)
111120

112121

bigframes/core/pyformat.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,7 @@ def _field_to_template_value(
8989
dry_run: bool = False,
9090
) -> str:
9191
"""Convert value to something embeddable in a SQL string."""
92-
import bigframes.core.sql # Avoid circular imports
92+
import bigframes.core.compile.sqlglot.sql as sql # Avoid circular imports
9393
import bigframes.dataframe # Avoid circular imports
9494

9595
_validate_type(name, value)
@@ -107,20 +107,20 @@ def _field_to_template_value(
107107
if isinstance(value, str):
108108
return value
109109

110-
return bigframes.core.sql.simple_literal(value)
110+
return sql.to_sql(sql.literal(value))
111111

112112

113113
def _validate_type(name: str, value: Any):
114114
"""Raises TypeError if value is unsupported."""
115-
import bigframes.core.sql # Avoid circular imports
116115
import bigframes.dataframe # Avoid circular imports
116+
import bigframes.dtypes # Avoid circular imports
117117

118118
if value is None:
119119
return # None can't be used in isinstance, but is a valid literal.
120120

121121
supported_types = (
122122
typing.get_args(_BQ_TABLE_TYPES)
123-
+ typing.get_args(bigframes.core.sql.SIMPLE_LITERAL_TYPES)
123+
+ bigframes.dtypes.SUPPORTED_LITERAL_TYPES
124124
+ (bigframes.dataframe.DataFrame,)
125125
+ (pandas.DataFrame,)
126126
)

bigframes/core/sql/__init__.py

Lines changed: 20 additions & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -17,14 +17,19 @@
1717
Utility functions for SQL construction.
1818
"""
1919

20-
import datetime
21-
import decimal
2220
import json
23-
import math
24-
from typing import cast, Collection, Iterable, Mapping, Optional, TYPE_CHECKING, Union
21+
from typing import (
22+
Any,
23+
cast,
24+
Collection,
25+
Iterable,
26+
Mapping,
27+
Optional,
28+
TYPE_CHECKING,
29+
Union,
30+
)
2531

2632
import bigframes_vendored.sqlglot.expressions as sge
27-
import shapely.geometry.base # type: ignore
2833

2934
from bigframes.core.compile.sqlglot import sql
3035

@@ -43,68 +48,8 @@
4348
to_wkt = dumps
4449

4550

46-
SIMPLE_LITERAL_TYPES = Union[
47-
bytes,
48-
str,
49-
int,
50-
bool,
51-
float,
52-
datetime.datetime,
53-
datetime.date,
54-
datetime.time,
55-
decimal.Decimal,
56-
list,
57-
]
58-
59-
60-
### Writing SQL Values (literals, column references, table references, etc.)
61-
def simple_literal(value: Union[SIMPLE_LITERAL_TYPES, None]) -> str:
62-
"""Return quoted input string."""
63-
64-
# https://cloud.google.com/bigquery/docs/reference/standard-sql/lexical#literals
65-
if value is None:
66-
return "NULL"
67-
elif isinstance(value, str):
68-
# Single quoting seems to work nicer with ibis than double quoting
69-
return f"'{sql.escape_chars(value)}'"
70-
elif isinstance(value, bytes):
71-
return repr(value)
72-
elif isinstance(value, (bool, int)):
73-
return str(value)
74-
elif isinstance(value, float):
75-
# https://cloud.google.com/bigquery/docs/reference/standard-sql/lexical#floating_point_literals
76-
if math.isnan(value):
77-
return 'CAST("nan" as FLOAT)'
78-
if value == math.inf:
79-
return 'CAST("+inf" as FLOAT)'
80-
if value == -math.inf:
81-
return 'CAST("-inf" as FLOAT)'
82-
return str(value)
83-
# Check datetime first as it is a subclass of date
84-
elif isinstance(value, datetime.datetime):
85-
if value.tzinfo is None:
86-
return f"DATETIME('{value.isoformat()}')"
87-
else:
88-
return f"TIMESTAMP('{value.isoformat()}')"
89-
elif isinstance(value, datetime.date):
90-
return f"DATE('{value.isoformat()}')"
91-
elif isinstance(value, datetime.time):
92-
return f"TIME(DATETIME('1970-01-01 {value.isoformat()}'))"
93-
elif isinstance(value, shapely.geometry.base.BaseGeometry):
94-
return f"ST_GEOGFROMTEXT({simple_literal(to_wkt(value))})"
95-
elif isinstance(value, decimal.Decimal):
96-
# TODO: disambiguate BIGNUMERIC based on scale and/or precision
97-
return f"CAST('{str(value)}' AS NUMERIC)"
98-
elif isinstance(value, list):
99-
simple_literals = [simple_literal(i) for i in value]
100-
return f"[{', '.join(simple_literals)}]"
101-
102-
else:
103-
raise ValueError(f"Cannot produce literal for {value}")
104-
105-
106-
def multi_literal(*values: str):
107-
literal_strings = [simple_literal(i) for i in values]
51+
def multi_literal(*values: Any):
52+
literal_strings = [sql.to_sql(sql.literal(i)) for i in values]
10853
return "(" + ", ".join(literal_strings) + ")"
10954

11055

@@ -210,7 +155,7 @@ def create_vector_index_ddl(
210155

211156
rendered_options = ", ".join(
212157
[
213-
f"{option_name} = {simple_literal(option_value)}"
158+
f"{option_name} = {sql.to_sql(sql.literal(option_value))}"
214159
for option_name, option_value in options.items()
215160
]
216161
)
@@ -237,24 +182,26 @@ def create_vector_search_sql(
237182

238183
vector_search_args = [
239184
f"TABLE {sql.to_sql(sql.identifier(cast(str, base_table)))}",
240-
f"{simple_literal(column_to_search)}",
185+
f"{sql.to_sql(sql.literal(column_to_search))}",
241186
f"({sql_string})",
242187
]
243188

244189
if query_column_to_search is not None:
245190
vector_search_args.append(
246-
f"query_column_to_search => {simple_literal(query_column_to_search)}"
191+
f"query_column_to_search => {sql.to_sql(sql.literal(query_column_to_search))}"
247192
)
248193

249194
if top_k is not None:
250-
vector_search_args.append(f"top_k=> {simple_literal(top_k)}")
195+
vector_search_args.append(f"top_k=> {sql.to_sql(sql.literal(top_k))}")
251196

252197
if distance_type is not None:
253-
vector_search_args.append(f"distance_type => {simple_literal(distance_type)}")
198+
vector_search_args.append(
199+
f"distance_type => {sql.to_sql(sql.literal(distance_type))}"
200+
)
254201

255202
if options is not None:
256203
vector_search_args.append(
257-
f"options => {simple_literal(json.dumps(options, indent=None))}"
204+
f"options => {sql.to_sql(sql.literal(json.dumps(options, indent=None)))}"
258205
)
259206

260207
args_str = ",\n".join(vector_search_args)

bigframes/core/sql/literals.py

Lines changed: 0 additions & 58 deletions
This file was deleted.

bigframes/core/sql/ml.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,6 @@
1717
from typing import Any, Dict, List, Mapping, Optional, Union
1818

1919
from bigframes.core.compile.sqlglot import sql as sg_sql
20-
import bigframes.core.sql
21-
import bigframes.core.sql.literals
2220

2321

2422
def create_model_ddl(
@@ -76,9 +74,9 @@ def create_model_ddl(
7674
# Handle list options like model_registry="vertex_ai"
7775
# Note: options are usually rendered as key = value.
7876
# If the value is a list, it is rendered as [val1, val2].
79-
rendered_val = bigframes.core.sql.simple_literal(list(option_value))
77+
rendered_val = sg_sql.to_sql(sg_sql.literal(list(option_value)))
8078
else:
81-
rendered_val = bigframes.core.sql.simple_literal(option_value)
79+
rendered_val = sg_sql.to_sql(sg_sql.literal(option_value))
8280

8381
rendered_options.append(f"{option_name} = {rendered_val}")
8482

@@ -108,7 +106,7 @@ def _build_struct_sql(
108106
) -> str:
109107
if not struct_options:
110108
return ""
111-
return f", {bigframes.core.sql.literals.struct_literal(struct_options)}"
109+
return f", {sg_sql.to_sql(sg_sql.literal(struct_options))}"
112110

113111

114112
def evaluate(

0 commit comments

Comments
 (0)