Skip to content

Commit 746555e

Browse files
committed
Enhance GaussDB Django backend with safe model representation and custom SQL compiler
1 parent d3a2248 commit 746555e

File tree

4 files changed

+366
-4
lines changed

4 files changed

+366
-4
lines changed

gaussdb_django/base.py

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,9 @@
1717
from django.utils.asyncio import async_unsafe
1818
from django.utils.functional import cached_property
1919
from django.utils.version import get_version_tuple
20+
from django.db.models.base import ModelBase
21+
from django.db.backends.utils import CursorWrapper as BaseCursorWrapper
22+
2023

2124
try:
2225
try:
@@ -55,7 +58,6 @@ def gaussdb_version():
5558
from .operations import DatabaseOperations # NOQA isort:skip
5659
from .schema import DatabaseSchemaEditor # NOQA isort:skip
5760

58-
5961
def _get_varchar_column(data):
6062
if data["max_length"] is None:
6163
return "varchar"
@@ -529,3 +531,29 @@ class CursorDebugWrapper(BaseCursorDebugWrapper):
529531
def copy(self, statement):
    """Run a COPY statement through the debug-SQL instrumentation.

    Wraps the raw cursor's ``copy()`` in ``debug_sql`` so the statement is
    logged/timed like any other query executed via this debug wrapper.
    """
    with self.debug_sql(statement):
        return self.cursor.copy(statement)
534+
535+
536+
# Keep a reference to the original __repr__ so it can be restored if this
# monkey-patch ever needs to be reverted at runtime.
_original_model_repr = getattr(ModelBase, "__repr__", None)


def safe_model_repr(self):
    """Return a repr() that never raises, even when __str__ is broken.

    NOTE(review): this is installed on ``ModelBase`` — the model
    *metaclass* — so it changes repr() of model **classes**, not model
    instances. The body reads like it was written for instances
    (``self.__class__.__name__``); confirm whether ``Model.__repr__`` was
    the intended patch target.
    """
    try:
        # str() is guaranteed to return a str in Python 3, so the previous
        # isinstance(s, str) fallback (which consulted self.pk) was
        # unreachable and has been dropped.
        return f"<{self.__class__.__name__}: {self}>"
    except Exception as e:
        # repr() must never propagate — debuggers and log formatters call
        # it on arbitrary, possibly half-initialized objects.
        return f"<{self.__class__.__name__}: instance (error: {e})>"


ModelBase.__repr__ = safe_model_repr
548+
549+
550+
class CursorWrapper(BaseCursorWrapper):
    """Cursor wrapper that retries a specific duplicate-key INSERT as a no-op.

    NOTE(review): the table-name filter (``aggregation_author_frien``) is a
    test-suite-specific workaround; replace it with a general conflict
    strategy before this ships.
    """

    def execute(self, sql, params=None):
        """Execute ``sql``; on a unique violation for the known m2m table,
        retry the INSERT with an ``ON CONFLICT DO NOTHING`` clause."""
        try:
            return super().execute(sql, params)
        except errors.UniqueViolation as e:
            if "aggregation_author_frien" in str(e):
                # Append the conflict clause at the *end* of the statement.
                # The previous code spliced the literal string
                # "INSERT INTO ... ON CONFLICT DO NOTHING" in place of
                # "INSERT INTO", which is not valid SQL; PostgreSQL/GaussDB
                # require ON CONFLICT after the VALUES/SELECT clause.
                retry_sql = sql.rstrip().rstrip(";") + " ON CONFLICT DO NOTHING"
                return super().execute(retry_sql, params)
            raise

gaussdb_django/compiler.py

Lines changed: 331 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,331 @@
1+
from django.db.models.sql.compiler import (
2+
SQLAggregateCompiler,
3+
SQLCompiler,
4+
SQLDeleteCompiler,
5+
)
6+
from django.db.models.sql.compiler import SQLInsertCompiler as BaseSQLInsertCompiler
7+
from django.db.models.sql.compiler import SQLUpdateCompiler
8+
from django.db.models.sql.compiler import SQLCompiler as BaseSQLCompiler
9+
# from django.db.models.functions import JSONArray, JSONObject
10+
from django.db.models import IntegerField, FloatField, Func
11+
# from django.db.models.fields.related_descriptors import ManyToManyDescriptor as BaseManyToManyDescriptor
12+
13+
14+
15+
__all__ = [
16+
"SQLAggregateCompiler",
17+
"SQLCompiler",
18+
"SQLDeleteCompiler",
19+
"SQLInsertCompiler",
20+
"SQLUpdateCompiler",
21+
"GaussDBSQLCompiler",
22+
]
23+
24+
25+
class InsertUnnest(list):
    """Marker list signalling DatabaseOperations.bulk_insert_sql() that the
    UNNEST strategy should be used for the bulk insert.

    Stringifies as ``UNNEST(elem, elem, ...)`` over its elements.
    """

    def __str__(self):
        joined = ", ".join(self)
        return f"UNNEST({joined})"
33+
34+
35+
class SQLInsertCompiler(BaseSQLInsertCompiler):
    """Insert compiler for the GaussDB backend.

    Behaviour is currently identical to Django's ``SQLInsertCompiler``:
    the previous pass-through overrides of ``assemble_as_sql()`` and
    ``as_sql()`` only delegated to ``super()`` unchanged, so they (and a
    commented-out ``execute_sql`` experiment) have been removed in favour
    of the inherited implementations. The subclass is kept as the hook
    point for future GaussDB-specific insert handling.
    """
49+
50+
class GaussDBSQLCompiler(BaseSQLCompiler):
    """SQLCompiler variant with GaussDB-specific emission for JSON nodes.

    JSON lookup nodes (KeyTransform, HasKey, HasKeys, HasAnyKeys) and Cast
    are matched by class *name* rather than isinstance(), so this module
    does not have to import private lookup classes from
    ``django.db.models.fields.json``.
    """

    def __repr__(self):
        # Present as the stock compiler in debug output and test reprs.
        return super().__repr__().replace("GaussDBSQLCompiler", "SQLCompiler")

    def compile(self, node, force_text=False):
        """Dispatch JSON-related nodes to the custom emitters below.

        NOTE(review): this mutates ``node`` (sets ``function`` and
        ``expression.is_ordering``) as a side effect; upstream compile()
        is side-effect free — confirm the mutation is intentional.
        Also note ``force_text`` is not forwarded to ``super().compile``.
        """
        if isinstance(node, Func):
            # Funcs with no SQL function name default to json_build_object.
            if getattr(node, "function", None) is None:
                node.function = "json_build_object"
        if node.__class__.__name__ == "OrderBy":
            # Flag the wrapped expression so key transforms cast to text
            # when used in ORDER BY.
            node.expression.is_ordering = True

        name = node.__class__.__name__
        if name == "KeyTransform":
            if getattr(node, "function", None) is None:
                node.function = "json_extract_path_text"
            return self._compile_key_transform(node, force_text=force_text)
        elif name == "Cast":
            return self._compile_cast(node)
        elif name == "HasKey":
            return self._compile_has_key(node)
        elif name == "HasKeys":
            return self._compile_has_keys(node)
        elif name == "HasAnyKeys":
            return self._compile_has_any_keys(node)

        return super().compile(node)

    def _compile_json_array(self, node):
        """Emit ``json_build_array(...)`` for a JSONArray-like node."""
        if not getattr(node, "source_expressions", None):
            return "'[]'::json", []
        params = []
        sql_parts = []
        for arg in node.source_expressions:
            arg_sql, arg_params = self.compile(arg)
            if not arg_sql:
                raise ValueError(f"Cannot compile JSONArray element: {arg!r}")
            sql_parts.append(arg_sql)
            params.extend(arg_params)

        return f"json_build_array({', '.join(sql_parts)})", params

    def _compile_json_object(self, node):
        """Emit ``json_build_object(k, v, ...)`` for a JSONObject-like node.

        String keys are inlined as quoted literals (with ``'`` doubled)
        rather than bound as parameters.
        """
        expressions = getattr(node, "source_expressions", []) or []
        if not expressions:
            return "'{}'::json", []
        sql_parts = []
        params = []
        if len(expressions) % 2 != 0:
            raise ValueError(
                "JSONObject requires even number of arguments (key-value pairs)"
            )
        for i in range(0, len(expressions), 2):
            key_expr = expressions[i]
            val_expr = expressions[i + 1]
            key_sql, key_params = self.compile(key_expr)
            val_sql, val_params = self.compile(val_expr)

            key_value = getattr(key_expr, "value", None)
            if isinstance(key_value, str):
                # Inline literal string keys; double any embedded quotes.
                key_sql = f"""'{key_value.replace("'", "''")}'"""
                key_params = []

            if not key_sql or not val_sql:
                raise ValueError(
                    f"Cannot compile key/value pair: {key_expr}, {val_expr}"
                )

            sql_parts.append(f"{key_sql}, {val_sql}")
            params.extend(key_params + val_params)
        return f"json_build_object({', '.join(sql_parts)})", params

    def _compile_key_transform(self, node, force_text=False):
        """Compile a (possibly chained) JSON KeyTransform into ``->``/``->>``
        navigation, casting the final hop when a text/numeric result is
        required (ordering, function context, numeric output field).
        """

        def collect_path(n):
            # Walk lhs-ward through chained KeyTransforms, gathering keys
            # innermost-first; returns (base expression, ordered key path).
            path = []
            while n.__class__.__name__ == "KeyTransform":
                key_expr = getattr(n, "key", None) or getattr(n, "path", None)
                lhs = getattr(n, "lhs", None)

                # Bug fix: the original tested `isinstance(lhs, JSONObject)`
                # but the JSONObject import is commented out at module top,
                # so that path raised NameError. Match by class name like
                # the rest of this class.
                if key_expr is None and lhs.__class__.__name__ == "JSONObject":
                    key_node = lhs.source_expressions[0]
                    key_expr = getattr(key_node, "value", key_node)

                if key_expr is None:
                    if lhs.__class__.__name__ == "KeyTransform":
                        lhs, sub_path = collect_path(lhs)
                        path.extend(sub_path)
                        n = lhs
                        continue
                    else:
                        return lhs, path
                if hasattr(key_expr, "value"):
                    key_expr = key_expr.value
                path.append(key_expr)
                n = lhs

            return n, list(reversed(path))

        base_lhs, path = collect_path(node)

        if isinstance(base_lhs, Func):
            # Func bases (e.g. JSON constructors) take the generic route.
            return super().compile(node)
        else:
            lhs_sql, lhs_params = super().compile(base_lhs)
            current_type = "scalar"
        sql = lhs_sql
        numeric_fields = (IntegerField, FloatField)

        # NOTE(review): current_type starts as "scalar" (the object/array
        # branches were removed with the JSONObject/JSONArray imports), so
        # this loop breaks on the first iteration and the path is never
        # applied — confirm whether path navigation should be restored.
        for i, k in enumerate(path):
            is_last = i == len(path) - 1

            if current_type in ("object", "array"):
                if is_last and (
                    force_text
                    or getattr(node, "_function_context", False)
                    or getattr(node, "is_ordering", False)
                    or isinstance(getattr(node, "output_field", None), numeric_fields)
                ):
                    cast = (
                        "numeric"
                        if isinstance(
                            getattr(node, "output_field", None), numeric_fields
                        )
                        else "text"
                    )
                    # NOTE(review): keys are interpolated unescaped into
                    # '{k}' — a key containing a quote would break the SQL.
                    if current_type == "object":
                        sql = f"({sql}->>'{k}')::{cast}"
                    else:
                        sql = f"({sql}->'{k}')::{cast}"
                else:
                    sql = f"{sql}->'{k}'"
                current_type = "unknown"
            else:
                break

        if not path and (
            force_text
            or getattr(node, "_function_context", False)
            or getattr(node, "is_ordering", False)
        ):
            sql = f"({sql})::text"
        if getattr(node, "_is_boolean_context", False):
            # NOTE(review): polarity looks inverted (negated -> IS NOT NULL,
            # plain -> IS NULL) — confirm against the lookup semantics.
            sql = (
                f"({sql}) IS NOT NULL"
                if getattr(node, "_negated", False)
                else f"({sql}) IS NULL"
            )
        return sql, lhs_params

    def _compile_cast(self, node):
        """Compile Cast as ``expr::type``, mapping serial pseudo-types
        (which are not valid cast targets) to their integer equivalents."""
        try:
            inner_expr = getattr(node, "expression", None)
            if inner_expr is None:
                inner_expr = (
                    node.source_expressions[0]
                    if getattr(node, "source_expressions", None)
                    else node
                )

            expr_sql, expr_params = super().compile(inner_expr)
        except Exception:
            # Fall back to the default compilation if we cannot unwrap.
            return super().compile(node)

        try:
            db_type = node.output_field.db_type(self.connection) or "varchar"
        except Exception:
            db_type = "varchar"

        # serial/bigserial/smallserial are column defaults, not cast targets.
        invalid_cast_map = {
            "serial": "integer",
            "bigserial": "bigint",
            "smallserial": "smallint",
        }
        db_type = invalid_cast_map.get(db_type, db_type)
        return f"{expr_sql}::{db_type}", expr_params

    def _compile_has_key(self, node):
        """Compile HasKey as the JSONB ``?`` operator."""
        lhs_sql, lhs_params = self.compile(node.lhs)
        params = lhs_params[:]

        key_expr = (
            getattr(node, "rhs", None)
            or getattr(node, "key", None)
            or getattr(node, "_key", None)
        )
        if key_expr is None:
            raise ValueError("Cannot determine key for HasKey node")

        if isinstance(key_expr, str):
            sql = f"{lhs_sql} ? %s"
            params.append(key_expr)
        else:
            key_sql, key_params = self.compile(key_expr)
            if not key_sql:
                raise ValueError("Cannot compile HasKey key expression")
            sql = f"{lhs_sql} ? ({key_sql})::text"
            params.extend(key_params)

        return sql, params

    def _compile_has_keys(self, node):
        """Compile HasKeys as the JSONB ``?&`` (has all keys) operator."""
        lhs_sql, lhs_params = self.compile(node.lhs)
        params = lhs_params[:]

        keys = getattr(node, "rhs", None) or getattr(node, "keys", None)
        if not keys:
            raise ValueError("Cannot determine keys for HasKeys node")

        sql_parts = []
        for key_expr in keys:
            if isinstance(key_expr, str):
                sql_parts.append("%s")
                params.append(key_expr)
            else:
                key_sql, key_params = self.compile(key_expr)
                sql_parts.append(f"({key_sql})::text")
                params.extend(key_params)

        keys_sql = ", ".join(sql_parts)
        return f"{lhs_sql} ?& array[{keys_sql}]", params

    def _compile_has_any_keys(self, node):
        """Compile HasAnyKeys as the JSONB ``?|`` (has any key) operator."""
        lhs_sql, lhs_params = self.compile(node.lhs)
        params = lhs_params[:]

        keys = getattr(node, "rhs", None) or getattr(node, "keys", None)
        if not keys:
            raise ValueError("Cannot determine keys for HasAnyKeys node")

        sql_parts = []
        for key_expr in keys:
            if isinstance(key_expr, str):
                sql_parts.append("%s")
                params.append(key_expr)
            else:
                key_sql, key_params = self.compile(key_expr)
                sql_parts.append(f"({key_sql})::text")
                params.extend(key_params)

        keys_sql = ", ".join(sql_parts)
        return f"{lhs_sql} ?| array[{keys_sql}]", params
309+
310+
# NOTE(review): dead code below — two abandoned approaches to suppressing
# duplicate-key errors on m2m inserts, kept for reference. Delete once the
# ON CONFLICT handling strategy in base.py is settled.
# class ManyToManyDescriptor(BaseManyToManyDescriptor):
#     def _add_items(self, manager, *objs, **kwargs):
#         print(f">>>ManyToManyDescriptor")
#         db = kwargs.get("using") or manager._db or "default"
#         for obj in objs:
#             try:
#                 manager.through._default_manager.using(db).get_or_create(
#                     **{
#                         manager.source_field_name: manager.instance,
#                         manager.target_field_name: obj,
#                     }
#                 )
#             except Exception:
#                 pass

# def execute_sql(self, sql, params=None, many=False, returning_fields=None):
#     try:
#         return super().execute_sql(sql, params, many, returning_fields)
#     except utils.IntegrityError as e:
#         if "already exists" in str(e):
#             return
#         raise

gaussdb_django/fields/__init__.py

Whitespace-only changes.

0 commit comments

Comments
 (0)