From da8a72fd3a33ae4be32da66da7e407199a9b9521 Mon Sep 17 00:00:00 2001 From: Paul Querna Date: Wed, 18 Dec 2024 04:31:44 +0000 Subject: [PATCH 1/2] conn.bind: use indexes by ordinal or name when finding args --- sqlite.go | 62 ++++++++++++++++++++++++++++--------------------------- 1 file changed, 32 insertions(+), 30 deletions(-) diff --git a/sqlite.go b/sqlite.go index 3e63143..7736405 100644 --- a/sqlite.go +++ b/sqlite.go @@ -1061,6 +1061,15 @@ func (c *conn) bind(pstmt uintptr, n int, args []driver.NamedValue) (allocs []ui allocs = nil }() + ordinalIndex := make(map[int]driver.NamedValue, len(args)) + namedIndex := make(map[string]driver.NamedValue, len(args)) + for _, v := range args { + ordinalIndex[v.Ordinal] = v + if v.Name != "" { + namedIndex[v.Name] = v + } + } + for i := 1; i <= n; i++ { name, err := c.bindParameterName(pstmt, i) if err != nil { @@ -1069,41 +1078,34 @@ func (c *conn) bind(pstmt uintptr, n int, args []driver.NamedValue) (allocs []ui var found bool var v driver.NamedValue - for _, v = range args { - if name != "" { - // For ?NNN and $NNN params, match if NNN == v.Ordinal. - // - // Supporting this for $NNN is a special case that makes eg - // `select $1, $2, $3 ...` work without needing to use - // sql.Named. - if (name[0] == '?' || name[0] == '$') && name[1:] == strconv.Itoa(v.Ordinal) { - found = true - break - } - // sqlite supports '$', '@' and ':' prefixes for string - // identifiers and '?' for numeric, so we cannot - // combine different prefixes with the same name - // because `database/sql` requires variable names - // to start with a letter - if name[1:] == v.Name[:] { - found = true - break - } - } else { - if v.Ordinal == i { - found = true - break + if name == "" { + v, found = ordinalIndex[i] + if !found { + return allocs, fmt.Errorf("missing argument with index %d", i) + } + } else { + // For ?NNN and $NNN params, match if NNN == v.Ordinal. + // + // Supporting this for $NNN is a special case that makes eg + // `select $1, $2, $3 ...` work without needing to use + // sql.Named. + if name[0] == '?' || name[0] == '$' && len(name) > 1 { + ordinal, err := strconv.ParseInt(name[1:], 10, 32) + if err == nil { + v, found = ordinalIndex[int(ordinal)] + if !found { + return allocs, fmt.Errorf("missing named numeric argument %q", name[1:]) + } } } - } - if !found { - if name != "" { - return allocs, fmt.Errorf("missing named argument %q", name[1:]) + if !found { + v, found = namedIndex[name[1:]] + if !found { + return allocs, fmt.Errorf("missing named argument %q", name) + } } - - return allocs, fmt.Errorf("missing argument with index %d", i) } var p uintptr From e558987f60f9e57f085f4e406abe22b8afa58a3b Mon Sep 17 00:00:00 2001 From: Paul Querna Date: Wed, 18 Dec 2024 05:05:37 +0000 Subject: [PATCH 2/2] minimize allocations --- sqlite.go | 41 ++++++++++++++++++++++++++++++++--------- 1 file changed, 32 insertions(+), 9 deletions(-) diff --git a/sqlite.go b/sqlite.go index 7736405..82a85cb 100644 --- a/sqlite.go +++ b/sqlite.go @@ -1061,13 +1061,36 @@ func (c *conn) bind(pstmt uintptr, n int, args []driver.NamedValue) (allocs []ui allocs = nil }() - ordinalIndex := make(map[int]driver.NamedValue, len(args)) - namedIndex := make(map[string]driver.NamedValue, len(args)) - for _, v := range args { - ordinalIndex[v.Ordinal] = v - if v.Name != "" { - namedIndex[v.Name] = v + var ordinalIndex map[int]int + byOrdinal := func(n int) (driver.NamedValue, bool) { + if ordinalIndex == nil { + ordinalIndex = make(map[int]int, len(args)) + for i, v := range args { + ordinalIndex[v.Ordinal] = i + } + } + i, ok := ordinalIndex[n] + if !ok { + return driver.NamedValue{}, false + } + return args[i], true + } + + var namedIndex map[string]int + byName := func(name string) (driver.NamedValue, bool) { + if namedIndex == nil { + namedIndex = make(map[string]int, len(args)) + for i, v := range args { + if v.Name != "" { + namedIndex[v.Name] = i + } + } + } + i, ok := namedIndex[name] + if !ok { + return driver.NamedValue{}, false } + return args[i], true } for i := 1; i <= n; i++ { @@ -1080,7 +1103,7 @@ func (c *conn) bind(pstmt uintptr, n int, args []driver.NamedValue) (allocs []ui var v driver.NamedValue if name == "" { - v, found = ordinalIndex[i] + v, found = byOrdinal(i) if !found { return allocs, fmt.Errorf("missing argument with index %d", i) } @@ -1093,7 +1116,7 @@ func (c *conn) bind(pstmt uintptr, n int, args []driver.NamedValue) (allocs []ui if name[0] == '?' || name[0] == '$' && len(name) > 1 { ordinal, err := strconv.ParseInt(name[1:], 10, 32) if err == nil { - v, found = ordinalIndex[int(ordinal)] + v, found = byOrdinal(int(ordinal)) if !found { return allocs, fmt.Errorf("missing named numeric argument %q", name[1:]) } @@ -1101,7 +1124,7 @@ func (c *conn) bind(pstmt uintptr, n int, args []driver.NamedValue) (allocs []ui } if !found { - v, found = namedIndex[name[1:]] + v, found = byName(name[1:]) if !found { return allocs, fmt.Errorf("missing named argument %q", name) }