@@ -172,7 +172,7 @@ func (mc *mysqlConn) close() {
172172}
173173
174174// Closes the network connection and unsets internal variables. Do not call this
175- // function after successfully authentication, call Close instead. This function
175+ // function after successful authentication, call Close instead. This function
176176// is called before auth or on auth failure because MySQL will have already
177177// closed the network connection.
178178func (mc * mysqlConn ) cleanup () {
@@ -245,9 +245,106 @@ func (mc *mysqlConn) Prepare(query string) (driver.Stmt, error) {
245245 return stmt , err
246246}
247247
248+ // findParamPositions returns the positions of real parameter holders ('?') in the query, ignoring those in comments, strings, or backticks.
249+ func findParamPositions (query string ) []int {
250+ const (
251+ stateNormal = iota
252+ stateString
253+ stateEscape
254+ stateEOLComment
255+ stateSlashStarComment
256+ stateBacktick
257+ )
258+
259+ var (
260+ QUOTE_BYTE = byte ('\'' )
261+ DBL_QUOTE_BYTE = byte ('"' )
262+ BACKSLASH_BYTE = byte ('\\' )
263+ QUESTION_MARK_BYTE = byte ('?' )
264+ SLASH_BYTE = byte ('/' )
265+ STAR_BYTE = byte ('*' )
266+ HASH_BYTE = byte ('#' )
267+ MINUS_BYTE = byte ('-' )
268+ LINE_FEED_BYTE = byte ('\n' )
269+ RADICAL_BYTE = byte ('`' )
270+ )
271+
272+ paramPositions := make ([]int , 0 )
273+ state := stateNormal
274+ singleQuotes := false
275+ lastChar := byte (0 )
276+ lenq := len (query )
277+ for i := 0 ; i < lenq ; i ++ {
278+ currentChar := query [i ]
279+ if state == stateEscape && ! ((currentChar == QUOTE_BYTE && singleQuotes ) || (currentChar == DBL_QUOTE_BYTE && ! singleQuotes )) {
280+ state = stateString
281+ lastChar = currentChar
282+ continue
283+ }
284+ switch currentChar {
285+ case STAR_BYTE :
286+ if state == stateNormal && lastChar == SLASH_BYTE {
287+ state = stateSlashStarComment
288+ }
289+ case SLASH_BYTE :
290+ if state == stateSlashStarComment && lastChar == STAR_BYTE {
291+ state = stateNormal
292+ } else if state == stateNormal && lastChar == SLASH_BYTE {
293+ state = stateEOLComment
294+ }
295+ case HASH_BYTE :
296+ if state == stateNormal {
297+ state = stateEOLComment
298+ }
299+ case MINUS_BYTE :
300+ if state == stateNormal && lastChar == MINUS_BYTE {
301+ state = stateEOLComment
302+ }
303+ case LINE_FEED_BYTE :
304+ if state == stateEOLComment {
305+ state = stateNormal
306+ }
307+ case DBL_QUOTE_BYTE :
308+ if state == stateNormal {
309+ state = stateString
310+ singleQuotes = false
311+ } else if state == stateString && ! singleQuotes {
312+ state = stateNormal
313+ } else if state == stateEscape {
314+ state = stateString
315+ }
316+ case QUOTE_BYTE :
317+ if state == stateNormal {
318+ state = stateString
319+ singleQuotes = true
320+ } else if state == stateString && singleQuotes {
321+ state = stateNormal
322+ } else if state == stateEscape {
323+ state = stateString
324+ }
325+ case BACKSLASH_BYTE :
326+ if state == stateString {
327+ state = stateEscape
328+ }
329+ case QUESTION_MARK_BYTE :
330+ if state == stateNormal {
331+ paramPositions = append (paramPositions , i )
332+ }
333+ case RADICAL_BYTE :
334+ if state == stateBacktick {
335+ state = stateNormal
336+ } else if state == stateNormal {
337+ state = stateBacktick
338+ }
339+ }
340+ lastChar = currentChar
341+ }
342+ return paramPositions
343+ }
344+
248345func (mc * mysqlConn ) interpolateParams (query string , args []driver.Value ) (string , error ) {
249- // Number of ? should be same to len(args )
250- if strings . Count ( query , "?" ) != len (args ) {
346+ paramPositions := findParamPositions ( query )
347+ if len ( paramPositions ) != len (args ) {
251348 return "" , driver .ErrSkip
252349 }
253350
@@ -261,21 +358,16 @@ func (mc *mysqlConn) interpolateParams(query string, args []driver.Value) (strin
261358 }
262359 buf = buf [:0 ]
263360 argPos := 0
361+ lastIdx := 0
264362
265- for i := 0 ; i < len (query ); i ++ {
266- q := strings .IndexByte (query [i :], '?' )
267- if q == - 1 {
268- buf = append (buf , query [i :]... )
269- break
270- }
271- buf = append (buf , query [i :i + q ]... )
272- i += q
273-
363+ for _ , qmIdx := range paramPositions {
364+ buf = append (buf , query [lastIdx :qmIdx ]... )
274365 arg := args [argPos ]
275366 argPos ++
276367
277368 if arg == nil {
278369 buf = append (buf , "NULL" ... )
370+ lastIdx = qmIdx + 1
279371 continue
280372 }
281373
@@ -339,7 +431,9 @@ func (mc *mysqlConn) interpolateParams(query string, args []driver.Value) (strin
339431 if len (buf )+ 4 > mc .maxAllowedPacket {
340432 return "" , driver .ErrSkip
341433 }
434+ lastIdx = qmIdx + 1
342435 }
436+ buf = append (buf , query [lastIdx :]... )
343437 if argPos != len (args ) {
344438 return "" , driver .ErrSkip
345439 }
0 commit comments