Skip to content

Commit 4b9ef3c

Browse files
committed
Added ANSI_QUOTES mode
1 parent 24fd111 commit 4b9ef3c

File tree

7 files changed

+117
-26
lines changed

7 files changed

+117
-26
lines changed

README.md

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,34 @@ func main() {
6868
}
6969
```
7070

71+
Parsing with SQL mode `ANSI_QUOTES` enabled:
72+
73+
Treat `"` as an identifier quote character (like the \` quote character) rather than as a string quote character. You can still use \` to quote identifiers while this mode is enabled. With `ANSI_QUOTES` enabled, you cannot use double quotation marks to quote literal strings, because they are interpreted as identifiers.
74+
75+
```go
76+
package main
77+
78+
import (
79+
"github.com/SananGuliyev/sqlparser"
80+
)
81+
82+
func main() {
83+
sql := "SELECT * FROM table WHERE a = 'abc'"
84+
sqlparser.SQLMode = sqlparser.SQLModeANSIQuotes
85+
stmt, err := sqlparser.Parse(sql)
86+
if err != nil {
87+
// Do something with the err
88+
}
89+
90+
// Otherwise do something with stmt
91+
switch stmt := stmt.(type) {
92+
case *sqlparser.Select:
93+
_ = stmt
94+
case *sqlparser.Insert:
95+
}
96+
}
97+
```
98+
7199
See [parse_test.go](https://github.com/SananGuliyev/sqlparser/blob/master/parse_test.go) for more examples, or read the [godoc](https://godoc.org/github.com/SananGuliyev/sqlparser).
72100

73101

analyzer.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -300,7 +300,7 @@ func ExtractSetValues(sql string) (keyValues map[SetKey]interface{}, scope strin
300300
if setStmt.Scope != "" && scope != "" {
301301
return nil, "", fmt.Errorf("unsupported in set: mixed using of variable scope")
302302
}
303-
_, out := NewStringTokenizer(key).Scan()
303+
_, out := NewStringTokenizer(key, SQLMode).Scan()
304304
key = string(out)
305305
}
306306

ast.go

Lines changed: 22 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,13 @@ import (
2929
"github.com/SananGuliyev/sqlparser/dependency/sqltypes"
3030
)
3131

32+
const (
33+
SQLModeStandard = iota
34+
SQLModeANSIQuotes
35+
)
36+
37+
var SQLMode = SQLModeStandard
38+
3239
// Instructions for creating new types: If a type
3340
// needs to satisfy an interface, declare that function
3441
// along with that interface. This will help users
@@ -46,7 +53,7 @@ import (
4653
// is partially parsed but still contains a syntax error, the
4754
// error is ignored and the DDL is returned anyway.
4855
func Parse(sql string) (Statement, error) {
49-
tokenizer := NewStringTokenizer(sql)
56+
tokenizer := NewStringTokenizer(sql, SQLMode)
5057
if yyParse(tokenizer) != 0 {
5158
if tokenizer.partialDDL != nil {
5259
log.Printf("ignoring error parsing DDL '%s': %v", sql, tokenizer.LastError)
@@ -61,7 +68,7 @@ func Parse(sql string) (Statement, error) {
6168
// ParseStrictDDL is the same as Parse except it errors on
6269
// partially parsed DDL statements.
6370
func ParseStrictDDL(sql string) (Statement, error) {
64-
tokenizer := NewStringTokenizer(sql)
71+
tokenizer := NewStringTokenizer(sql, SQLMode)
6572
if yyParse(tokenizer) != 0 {
6673
return nil, tokenizer.LastError
6774
}
@@ -97,7 +104,7 @@ func ParseNext(tokenizer *Tokenizer) (Statement, error) {
97104
// SplitStatement returns the first sql statement up to either a ; or EOF
98105
// and the remainder from the given buffer
99106
func SplitStatement(blob string) (string, string, error) {
100-
tokenizer := NewStringTokenizer(blob)
107+
tokenizer := NewStringTokenizer(blob, SQLMode)
101108
tkn := 0
102109
for {
103110
tkn, _ = tokenizer.Scan()
@@ -118,7 +125,7 @@ func SplitStatement(blob string) (string, string, error) {
118125
// returns the sql pieces blob contains; or error if sql cannot be parsed
119126
func SplitStatementToPieces(blob string) (pieces []string, err error) {
120127
pieces = make([]string, 0, 16)
121-
tokenizer := NewStringTokenizer(blob)
128+
tokenizer := NewStringTokenizer(blob, SQLMode)
122129

123130
tkn := 0
124131
var stmt string
@@ -3430,6 +3437,12 @@ func Backtick(in string) string {
34303437
}
34313438

34323439
func formatID(buf *TrackedBuffer, original, lowered string) {
3440+
var identChar rune
3441+
if SQLMode == SQLModeANSIQuotes {
3442+
identChar = '"'
3443+
} else {
3444+
identChar = '`'
3445+
}
34333446
isDbSystemVariable := false
34343447
if len(original) > 1 && original[:2] == "@@" {
34353448
isDbSystemVariable = true
@@ -3449,14 +3462,14 @@ func formatID(buf *TrackedBuffer, original, lowered string) {
34493462
return
34503463

34513464
mustEscape:
3452-
buf.WriteByte('`')
3465+
_, _ = buf.WriteRune(identChar)
34533466
for _, c := range original {
3454-
buf.WriteRune(c)
3455-
if c == '`' {
3456-
buf.WriteByte('`')
3467+
_, _ = buf.WriteRune(c)
3468+
if c == identChar {
3469+
_, _ = buf.WriteRune(identChar)
34573470
}
34583471
}
3459-
buf.WriteByte('`')
3472+
_, _ = buf.WriteRune(identChar)
34603473
}
34613474

34623475
func compliantName(in string) string {

parse_next_test.go

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ func TestParseNextErrors(t *testing.T) {
6767
}
6868

6969
sql := tcase.input + "; select 1 from t"
70-
tokens := NewStringTokenizer(sql)
70+
tokens := NewStringTokenizer(sql, SQLMode)
7171

7272
// The first statement should be an error
7373
_, err := ParseNext(tokens)
@@ -136,13 +136,12 @@ func TestParseNextEdgeCases(t *testing.T) {
136136
}}
137137

138138
for _, test := range tests {
139-
tokens := NewStringTokenizer(test.input)
139+
tokens := NewStringTokenizer(test.input, SQLMode)
140140

141141
for i, want := range test.want {
142142
tree, err := ParseNext(tokens)
143143
if err != nil {
144144
t.Fatalf("[%d] ParseNext(%q) err = %q, want nil", i, test.input, err)
145-
continue
146145
}
147146

148147
if got := String(tree); got != want {

token.go

Lines changed: 28 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ type Tokenizer struct {
4444
posVarIndex int
4545
ParseTree Statement
4646
partialDDL *DDL
47+
sqlMode int
4748
nesting int
4849
multi bool
4950
specialComment *Tokenizer
@@ -55,11 +56,12 @@ type Tokenizer struct {
5556

5657
// NewStringTokenizer creates a new Tokenizer for the
5758
// sql string.
58-
func NewStringTokenizer(sql string) *Tokenizer {
59+
func NewStringTokenizer(sql string, sqlMode int) *Tokenizer {
5960
buf := []byte(sql)
6061
return &Tokenizer{
6162
buf: buf,
6263
bufSize: len(buf),
64+
sqlMode: sqlMode,
6365
}
6466
}
6567

@@ -595,7 +597,12 @@ func (tkn *Tokenizer) Scan() (int, []byte) {
595597
return NE, nil
596598
}
597599
return int(ch), nil
598-
case '\'', '"':
600+
case '\'':
601+
return tkn.scanString(ch, STRING)
602+
case '"':
603+
if tkn.sqlMode == SQLModeANSIQuotes {
604+
return tkn.scanLiteralIdentifier()
605+
}
599606
return tkn.scanString(ch, STRING)
600607
case '`':
601608
return tkn.scanLiteralIdentifier()
@@ -667,25 +674,41 @@ func (tkn *Tokenizer) scanBitLiteral() (int, []byte) {
667674
func (tkn *Tokenizer) scanLiteralIdentifier() (int, []byte) {
668675
buffer := &bytes2.Buffer{}
669676
backTickSeen := false
677+
quoteSeen := false
670678
for {
671679
if backTickSeen {
672680
if tkn.lastChar != '`' {
673681
break
674682
}
675683
backTickSeen = false
676-
buffer.WriteByte('`')
684+
_ = buffer.WriteByte('`')
685+
tkn.next()
686+
continue
687+
}
688+
if quoteSeen {
689+
if tkn.lastChar != '"' {
690+
break
691+
}
692+
quoteSeen = false
693+
_ = buffer.WriteByte('"')
677694
tkn.next()
678695
continue
679696
}
680697
// The previous char was not a backtick.
681698
switch tkn.lastChar {
682699
case '`':
683700
backTickSeen = true
701+
case '"':
702+
if tkn.sqlMode == SQLModeANSIQuotes {
703+
quoteSeen = true
704+
} else {
705+
_ = buffer.WriteByte(byte(tkn.lastChar))
706+
}
684707
case eofChar:
685708
// Premature EOF.
686709
return LEX_ERROR, buffer.Bytes()
687710
default:
688-
buffer.WriteByte(byte(tkn.lastChar))
711+
_ = buffer.WriteByte(byte(tkn.lastChar))
689712
}
690713
tkn.next()
691714
}
@@ -880,7 +903,7 @@ func (tkn *Tokenizer) scanMySQLSpecificComment() (int, []byte) {
880903
tkn.consumeNext(buffer)
881904
}
882905
_, sql := ExtractMysqlComment(buffer.String())
883-
tkn.specialComment = NewStringTokenizer(sql)
906+
tkn.specialComment = NewStringTokenizer(sql, SQLMode)
884907
return tkn.Scan()
885908
}
886909

token_test.go

Lines changed: 30 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ func TestLiteralID(t *testing.T) {
5757
}}
5858

5959
for _, tcase := range testcases {
60-
tkn := NewStringTokenizer(tcase.in)
60+
tkn := NewStringTokenizer(tcase.in, SQLMode)
6161
id, out := tkn.Scan()
6262
if tcase.id != id || string(out) != tcase.out {
6363
t.Errorf("Scan(%s): %d, %s, want %d, %s", tcase.in, id, out, tcase.id, tcase.out)
@@ -130,7 +130,7 @@ func TestString(t *testing.T) {
130130
}}
131131

132132
for _, tcase := range testcases {
133-
id, got := NewStringTokenizer(tcase.in).Scan()
133+
id, got := NewStringTokenizer(tcase.in, SQLMode).Scan()
134134
if tcase.id != id || string(got) != tcase.want {
135135
t.Errorf("Scan(%q) = (%s, %q), want (%s, %q)", tcase.in, tokenName(id), got, tokenName(tcase.id), tcase.want)
136136
}
@@ -189,3 +189,31 @@ func TestSplitStatement(t *testing.T) {
189189
}
190190
}
191191
}
192+
193+
func TestParseANSIQuotesMode(t *testing.T) {
194+
testcases := []struct {
195+
in string
196+
out string
197+
}{{
198+
in: `select * from "table"`,
199+
out: `select * from "table"`,
200+
}, {
201+
in: `select * from "tbl"`,
202+
out: `select * from tbl`,
203+
}}
204+
205+
SQLMode = SQLModeANSIQuotes
206+
for _, tcase := range testcases {
207+
stmt, err := Parse(tcase.in)
208+
if err != nil {
209+
t.Errorf("EndOfStatementPosition(%s): ERROR: %v", tcase.in, err)
210+
continue
211+
}
212+
213+
finalSQL := String(stmt)
214+
if tcase.out != finalSQL {
215+
t.Errorf("EndOfStatementPosition(%s) got sql \"%s\" want \"%s\"", tcase.in, finalSQL, tcase.out)
216+
}
217+
}
218+
SQLMode = SQLModeStandard
219+
}

tracked_buffer.go

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ func (buf *TrackedBuffer) Myprintf(format string, values ...interface{}) {
6868
i++
6969
}
7070
if i > lasti {
71-
buf.WriteString(format[lasti:i])
71+
_, _ = buf.WriteString(format[lasti:i])
7272
}
7373
if i >= end {
7474
break
@@ -78,18 +78,18 @@ func (buf *TrackedBuffer) Myprintf(format string, values ...interface{}) {
7878
case 'c':
7979
switch v := values[fieldnum].(type) {
8080
case byte:
81-
buf.WriteByte(v)
81+
_ = buf.WriteByte(v)
8282
case rune:
83-
buf.WriteRune(v)
83+
_, _ = buf.WriteRune(v)
8484
default:
8585
panic(fmt.Sprintf("unexpected TrackedBuffer type %T", v))
8686
}
8787
case 's':
8888
switch v := values[fieldnum].(type) {
8989
case []byte:
90-
buf.Write(v)
90+
_, _ = buf.Write(v)
9191
case string:
92-
buf.WriteString(v)
92+
_, _ = buf.WriteString(v)
9393
default:
9494
panic(fmt.Sprintf("unexpected TrackedBuffer type %T", v))
9595
}
@@ -118,7 +118,7 @@ func (buf *TrackedBuffer) WriteArg(arg string) {
118118
offset: buf.Len(),
119119
length: len(arg),
120120
})
121-
buf.WriteString(arg)
121+
_, _ = buf.WriteString(arg)
122122
}
123123

124124
// ParsedQuery returns a ParsedQuery that contains bind

0 commit comments

Comments
 (0)