Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
85 changes: 55 additions & 30 deletions sqlite.go
Original file line number Diff line number Diff line change
Expand Up @@ -1061,6 +1061,38 @@ func (c *conn) bind(pstmt uintptr, n int, args []driver.NamedValue) (allocs []ui
allocs = nil
}()

var ordinalIndex map[int]int
byOrdinal := func(n int) (driver.NamedValue, bool) {
if ordinalIndex == nil {
ordinalIndex = make(map[int]int, len(args))
for i, v := range args {
ordinalIndex[v.Ordinal] = i
}
}
i, ok := ordinalIndex[n]
if !ok {
return driver.NamedValue{}, false
}
return args[i], true
}

var namedIndex map[string]int
byName := func(name string) (driver.NamedValue, bool) {
if namedIndex == nil {
namedIndex = make(map[string]int, len(args))
for i, v := range args {
if v.Name != "" {
namedIndex[v.Name] = i
}
}
}
i, ok := namedIndex[name]
if !ok {
return driver.NamedValue{}, false
}
return args[i], true
}

for i := 1; i <= n; i++ {
name, err := c.bindParameterName(pstmt, i)
if err != nil {
Expand All @@ -1069,41 +1101,34 @@ func (c *conn) bind(pstmt uintptr, n int, args []driver.NamedValue) (allocs []ui

var found bool
var v driver.NamedValue
for _, v = range args {
if name != "" {
// For ?NNN and $NNN params, match if NNN == v.Ordinal.
//
// Supporting this for $NNN is a special case that makes eg
// `select $1, $2, $3 ...` work without needing to use
// sql.Named.
if (name[0] == '?' || name[0] == '$') && name[1:] == strconv.Itoa(v.Ordinal) {
found = true
break
}

// sqlite supports '$', '@' and ':' prefixes for string
// identifiers and '?' for numeric, so we cannot
// combine different prefixes with the same name
// because `database/sql` requires variable names
// to start with a letter
if name[1:] == v.Name[:] {
found = true
break
}
} else {
if v.Ordinal == i {
found = true
break
if name == "" {
v, found = byOrdinal(i)
if !found {
return allocs, fmt.Errorf("missing argument with index %d", i)
}
} else {
// For ?NNN and $NNN params, match if NNN == v.Ordinal.
//
// Supporting this for $NNN is a special case that makes eg
// `select $1, $2, $3 ...` work without needing to use
// sql.Named.
if name[0] == '?' || name[0] == '$' && len(name) > 1 {
ordinal, err := strconv.ParseInt(name[1:], 10, 32)
if err == nil {
v, found = byOrdinal(int(ordinal))
if !found {
return allocs, fmt.Errorf("missing named numeric argument %q", name[1:])
}
}
}
}

if !found {
if name != "" {
return allocs, fmt.Errorf("missing named argument %q", name[1:])
if !found {
v, found = byName(name[1:])
if !found {
return allocs, fmt.Errorf("missing named argument %q", name)
}
}

return allocs, fmt.Errorf("missing argument with index %d", i)
}

var p uintptr
Expand Down
Loading