1+ from django .db .models .sql .compiler import (
2+ SQLAggregateCompiler ,
3+ SQLCompiler ,
4+ SQLDeleteCompiler ,
5+ )
6+ from django .db .models .sql .compiler import SQLInsertCompiler as BaseSQLInsertCompiler
7+ from django .db .models .sql .compiler import SQLUpdateCompiler
8+ from django .db .models .sql .compiler import SQLCompiler as BaseSQLCompiler
9+ # from django.db.models.functions import JSONArray, JSONObject
10+ from django .db .models import IntegerField , FloatField , Func
11+ # from django.db.models.fields.related_descriptors import ManyToManyDescriptor as BaseManyToManyDescriptor
12+
13+
14+
15+ __all__ = [
16+ "SQLAggregateCompiler" ,
17+ "SQLCompiler" ,
18+ "SQLDeleteCompiler" ,
19+ "SQLInsertCompiler" ,
20+ "SQLUpdateCompiler" ,
21+ "GaussDBSQLCompiler" ,
22+ ]
23+
24+
25+ class InsertUnnest (list ):
26+ """
27+ Sentinel value to signal DatabaseOperations.bulk_insert_sql() that the
28+ UNNEST strategy should be used for the bulk insert.
29+ """
30+
31+ def __str__ (self ):
32+ return "UNNEST(%s)" % ", " .join (self )
33+
34+
35+ class SQLInsertCompiler (BaseSQLInsertCompiler ):
36+ def assemble_as_sql (self , fields , value_rows ):
37+ return super ().assemble_as_sql (fields , value_rows )
38+
39+ def as_sql (self ):
40+ return super ().as_sql ()
41+
42+ # def execute_sql(self, returning_fields=None):
43+ # sql, params = self.as_sql()
44+ # if "aggregation_author_frien" in sql:
45+ # sql = sql.replace("INSERT INTO", "INSERT INTO ... ON CONFLICT DO NOTHING")
46+ # cursor = self.connection.cursor()
47+ # cursor.execute(sql, params)
48+ # return cursor.rowcount
49+
50+ class GaussDBSQLCompiler (BaseSQLCompiler ):
51+ def __repr__ (self ):
52+ base = super ().__repr__ ()
53+ return base .replace ("GaussDBSQLCompiler" , "SQLCompiler" )
54+
55+ def compile (self , node , force_text = False ):
56+ if isinstance (node , Func ):
57+ func_name = getattr (node , "function" , None )
58+ if func_name is None :
59+ node .function = "json_build_object"
60+ if node .__class__ .__name__ == "OrderBy" :
61+ node .expression .is_ordering = True
62+
63+ # if isinstance(node, JSONArray):
64+ # return self._compile_json_array(node)
65+
66+ # elif isinstance(node, JSONObject):
67+ # return self._compile_json_object(node)
68+
69+ if node .__class__ .__name__ == "KeyTransform" :
70+ if getattr (node , "function" , None ) is None :
71+ node .function = "json_extract_path_text"
72+ return self ._compile_key_transform (node , force_text = force_text )
73+ elif node .__class__ .__name__ == "Cast" :
74+ return self ._compile_cast (node )
75+ elif node .__class__ .__name__ == "HasKey" :
76+ return self ._compile_has_key (node )
77+ elif node .__class__ .__name__ == "HasKeys" :
78+ return self ._compile_has_keys (node )
79+ elif node .__class__ .__name__ == "HasAnyKeys" :
80+ return self ._compile_has_any_keys (node )
81+
82+ return super ().compile (node )
83+
84+ def _compile_json_array (self , node ):
85+ if not getattr (node , "source_expressions" , None ):
86+ return "'[]'::json" , []
87+ params = []
88+ sql_parts = []
89+ for arg in node .source_expressions :
90+ arg_sql , arg_params = self .compile (arg )
91+ if not arg_sql :
92+ raise ValueError (f"Cannot compile JSONArray element: { arg !r} " )
93+ sql_parts .append (arg_sql )
94+ params .extend (arg_params )
95+
96+ sql = f"json_build_array({ ', ' .join (sql_parts )} )"
97+ return sql , params
98+
99+ def _compile_json_object (self , node ):
100+ expressions = getattr (node , "source_expressions" , []) or []
101+ if not expressions :
102+ return "'{}'::json" , []
103+ sql_parts = []
104+ params = []
105+ if len (expressions ) % 2 != 0 :
106+ raise ValueError (
107+ "JSONObject requires even number of arguments (key-value pairs)"
108+ )
109+ for i in range (0 , len (expressions ), 2 ):
110+ key_expr = expressions [i ]
111+ val_expr = expressions [i + 1 ]
112+ key_sql , key_params = self .compile (key_expr )
113+ val_sql , val_params = self .compile (val_expr )
114+
115+ key_value = getattr (key_expr , "value" , None )
116+ if isinstance (key_value , str ):
117+ key_sql = f"""'{ key_value .replace ("'" , "''" )} '"""
118+ key_params = []
119+
120+ if not key_sql or not val_sql :
121+ raise ValueError (
122+ f"Cannot compile key/value pair: { key_expr } , { val_expr } "
123+ )
124+
125+ sql_parts .append (f"{ key_sql } , { val_sql } " )
126+ params .extend (key_params + val_params )
127+ sql = f"json_build_object({ ', ' .join (sql_parts )} )"
128+ return sql , params
129+
130+ def _compile_key_transform (self , node , force_text = False ):
131+ def collect_path (n ):
132+ path = []
133+ while n .__class__ .__name__ == "KeyTransform" :
134+ key_expr = getattr (n , "key" , None ) or getattr (n , "path" , None )
135+ lhs = getattr (n , "lhs" , None )
136+
137+ if isinstance (lhs , JSONObject ) and key_expr is None :
138+ key_node = lhs .source_expressions [0 ]
139+ key_expr = getattr (key_node , "value" , key_node )
140+
141+ if key_expr is None :
142+ if lhs .__class__ .__name__ == "KeyTransform" :
143+ lhs , sub_path = collect_path (lhs )
144+ path .extend (sub_path )
145+ n = lhs
146+ continue
147+ else :
148+ return lhs , path
149+ if hasattr (key_expr , "value" ):
150+ key_expr = key_expr .value
151+ path .append (key_expr )
152+ n = lhs
153+
154+ return n , list (reversed (path ))
155+
156+ base_lhs , path = collect_path (node )
157+
158+ # if isinstance(base_lhs, JSONObject):
159+ # lhs_sql, lhs_params = self._compile_json_object(base_lhs)
160+ # current_type = "object"
161+ # elif isinstance(base_lhs, JSONArray):
162+ # lhs_sql, lhs_params = self._compile_json_array(base_lhs)
163+ # current_type = "array"
164+ if isinstance (base_lhs , Func ):
165+ return super ().compile (node )
166+ else :
167+ lhs_sql , lhs_params = super ().compile (base_lhs )
168+ current_type = "scalar"
169+ sql = lhs_sql
170+ numeric_fields = (IntegerField , FloatField )
171+
172+ for i , k in enumerate (path ):
173+ is_last = i == len (path ) - 1
174+
175+ if current_type in ("object" , "array" ):
176+ if is_last and (
177+ force_text
178+ or getattr (node , "_function_context" , False )
179+ or getattr (node , "is_ordering" , False )
180+ or isinstance (getattr (node , "output_field" , None ), numeric_fields )
181+ ):
182+ cast = (
183+ "numeric"
184+ if isinstance (
185+ getattr (node , "output_field" , None ), numeric_fields
186+ )
187+ else "text"
188+ )
189+ if current_type == "object" :
190+ sql = f"({ sql } ->>'{ k } ')::{ cast } "
191+ else :
192+ sql = f"({ sql } ->'{ k } ')::{ cast } "
193+ else :
194+ sql = f"{ sql } ->'{ k } '"
195+ current_type = "unknown"
196+ else :
197+ break
198+
199+ if not path and (
200+ force_text
201+ or getattr (node , "_function_context" , False )
202+ or getattr (node , "is_ordering" , False )
203+ ):
204+ sql = f"({ sql } )::text"
205+ if getattr (node , "_is_boolean_context" , False ):
206+ sql = (
207+ f"({ sql } ) IS NOT NULL"
208+ if getattr (node , "_negated" , False )
209+ else f"({ sql } ) IS NULL"
210+ )
211+ return sql , lhs_params
212+
213+ def _compile_cast (self , node ):
214+ try :
215+ inner_expr = getattr (node , "expression" , None )
216+ if inner_expr is None :
217+ inner_expr = (
218+ node .source_expressions [0 ]
219+ if getattr (node , "source_expressions" , None )
220+ else node
221+ )
222+
223+ expr_sql , expr_params = super ().compile (inner_expr )
224+ except Exception :
225+ return super ().compile (node )
226+
227+ db_type = None
228+ try :
229+ db_type = node .output_field .db_type (self .connection ) or "varchar"
230+ except Exception :
231+ db_type = "varchar"
232+
233+ invalid_cast_map = {
234+ "serial" : "integer" ,
235+ "bigserial" : "bigint" ,
236+ "smallserial" : "smallint" ,
237+ }
238+ db_type = invalid_cast_map .get (db_type , db_type )
239+ sql = f"{ expr_sql } ::{ db_type } "
240+ return sql , expr_params
241+
242+ def _compile_has_key (self , node ):
243+ lhs_sql , lhs_params = self .compile (node .lhs )
244+ params = lhs_params [:]
245+
246+ key_expr = (
247+ getattr (node , "rhs" , None )
248+ or getattr (node , "key" , None )
249+ or getattr (node , "_key" , None )
250+ )
251+ if key_expr is None :
252+ raise ValueError ("Cannot determine key for HasKey node" )
253+
254+ if isinstance (key_expr , str ):
255+ sql = f"{ lhs_sql } ? %s"
256+ params .append (key_expr )
257+ else :
258+ key_sql , key_params = self .compile (key_expr )
259+ if not key_sql :
260+ raise ValueError ("Cannot compile HasKey key expression" )
261+ sql = f"{ lhs_sql } ? ({ key_sql } )::text"
262+ params .extend (key_params )
263+
264+ return sql , params
265+
266+ def _compile_has_keys (self , node ):
267+ lhs_sql , lhs_params = self .compile (node .lhs )
268+ params = lhs_params [:]
269+
270+ keys = getattr (node , "rhs" , None ) or getattr (node , "keys" , None )
271+ if not keys :
272+ raise ValueError ("Cannot determine keys for HasKeys node" )
273+
274+ sql_parts = []
275+ for key_expr in keys :
276+ if isinstance (key_expr , str ):
277+ sql_parts .append ("%s" )
278+ params .append (key_expr )
279+ else :
280+ key_sql , key_params = self .compile (key_expr )
281+ sql_parts .append (f"({ key_sql } )::text" )
282+ params .extend (key_params )
283+
284+ keys_sql = ", " .join (sql_parts )
285+ sql = f"{ lhs_sql } ?& array[{ keys_sql } ]"
286+ return sql , params
287+
288+ def _compile_has_any_keys (self , node ):
289+ lhs_sql , lhs_params = self .compile (node .lhs )
290+ params = lhs_params [:]
291+
292+ keys = getattr (node , "rhs" , None ) or getattr (node , "keys" , None )
293+ if not keys :
294+ raise ValueError ("Cannot determine keys for HasAnyKeys node" )
295+
296+ sql_parts = []
297+ for key_expr in keys :
298+ if isinstance (key_expr , str ):
299+ sql_parts .append ("%s" )
300+ params .append (key_expr )
301+ else :
302+ key_sql , key_params = self .compile (key_expr )
303+ sql_parts .append (f"({ key_sql } )::text" )
304+ params .extend (key_params )
305+
306+ keys_sql = ", " .join (sql_parts )
307+ sql = f"{ lhs_sql } ?| array[{ keys_sql } ]"
308+ return sql , params
309+
310+ # class ManyToManyDescriptor(BaseManyToManyDescriptor):
311+ # def _add_items(self, manager, *objs, **kwargs):
312+ # print(f">>>ManyToManyDescriptor")
313+ # db = kwargs.get("using") or manager._db or "default"
314+ # for obj in objs:
315+ # try:
316+ # manager.through._default_manager.using(db).get_or_create(
317+ # **{
318+ # manager.source_field_name: manager.instance,
319+ # manager.target_field_name: obj,
320+ # }
321+ # )
322+ # except Exception:
323+ # pass
324+
325+ # def execute_sql(self, sql, params=None, many=False, returning_fields=None):
326+ # try:
327+ # return super().execute_sql(sql, params, many, returning_fields)
328+ # except utils.IntegrityError as e:
329+ # if "already exists" in str(e):
330+ # return
331+ # raise
0 commit comments