Skip to content

Commit 4b4596e

Browse files
Suman DeSuman De
authored andcommitted
feat: add support for yugabyte smart driver
1 parent ce83d3f commit 4b4596e

File tree

4 files changed

+64
-46
lines changed

4 files changed

+64
-46
lines changed

internal/codegen/golang/driver.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@ func parseDriver(sqlPackage string) opts.SQLDriver {
88
return opts.SQLDriverPGXV4
99
case opts.SQLPackagePGXV5:
1010
return opts.SQLDriverPGXV5
11+
case opts.SQLPackageYugaBytePGXV5:
12+
return opts.SQLDriverYugaBytePGXV5
1113
default:
1214
return opts.SQLDriverLibPQ
1315
}

internal/codegen/golang/imports.go

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -132,6 +132,10 @@ func (i *importer) dbImports() fileImports {
132132
case opts.SQLDriverPGXV5:
133133
pkg = append(pkg, ImportSpec{Path: "github.com/jackc/pgx/v5/pgconn"})
134134
pkg = append(pkg, ImportSpec{Path: "github.com/jackc/pgx/v5"})
135+
136+
case opts.SQLDriverYugaBytePGXV5:
137+
pkg = append(pkg, ImportSpec{Path: "github.com/yugabyte/pgx/v5/pgconn"})
138+
pkg = append(pkg, ImportSpec{Path: "github.com/yugabyte/pgx/v5"})
135139
default:
136140
std = append(std, ImportSpec{Path: "database/sql"})
137141
if i.Options.EmitPreparedQueries {
@@ -176,6 +180,8 @@ func buildImports(options *opts.Options, queries []Query, uses func(string) bool
176180
pkg[ImportSpec{Path: "github.com/jackc/pgconn"}] = struct{}{}
177181
case opts.SQLDriverPGXV5:
178182
pkg[ImportSpec{Path: "github.com/jackc/pgx/v5/pgconn"}] = struct{}{}
183+
case opts.SQLDriverYugaBytePGXV5:
184+
pkg[ImportSpec{Path: "github.com/yugabyte/pgx/v5/pgconn"}] = struct{}{}
179185
default:
180186
std["database/sql"] = struct{}{}
181187
}
@@ -191,6 +197,8 @@ func buildImports(options *opts.Options, queries []Query, uses func(string) bool
191197
if uses("pgtype.") {
192198
if sqlpkg == opts.SQLDriverPGXV5 {
193199
pkg[ImportSpec{Path: "github.com/jackc/pgx/v5/pgtype"}] = struct{}{}
200+
} else if sqlpkg == opts.SQLDriverYugaBytePGXV5 {
201+
pkg[ImportSpec{Path: "github.com/yugabyte/pgx/v5/pgtype"}] = struct{}{}
194202
} else {
195203
pkg[ImportSpec{Path: "github.com/jackc/pgtype"}] = struct{}{}
196204
}
@@ -489,6 +497,8 @@ func (i *importer) batchImports() fileImports {
489497
pkg[ImportSpec{Path: "github.com/jackc/pgx/v4"}] = struct{}{}
490498
case opts.SQLDriverPGXV5:
491499
pkg[ImportSpec{Path: "github.com/jackc/pgx/v5"}] = struct{}{}
500+
case opts.SQLDriverYugaBytePGXV5:
501+
pkg[ImportSpec{Path: "github.com/yugabyte/pgx/v5"}] = struct{}{}
492502
}
493503

494504
return sortedImports(std, pkg)

internal/codegen/golang/opts/enum.go

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -5,15 +5,17 @@ import "fmt"
55
type SQLDriver string
66

77
const (
8-
SQLPackagePGXV4 string = "pgx/v4"
9-
SQLPackagePGXV5 string = "pgx/v5"
10-
SQLPackageStandard string = "database/sql"
8+
SQLPackagePGXV4 string = "pgx/v4"
9+
SQLPackagePGXV5 string = "pgx/v5"
10+
SQLPackageStandard string = "database/sql"
11+
SQLPackageYugaBytePGXV5 string = "yb/pgx/v5"
1112
)
1213

1314
var validPackages = map[string]struct{}{
14-
string(SQLPackagePGXV4): {},
15-
string(SQLPackagePGXV5): {},
16-
string(SQLPackageStandard): {},
15+
string(SQLPackagePGXV4): {},
16+
string(SQLPackagePGXV5): {},
17+
string(SQLPackageYugaBytePGXV5): {},
18+
string(SQLPackageStandard): {},
1719
}
1820

1921
func validatePackage(sqlPackage string) error {
@@ -28,13 +30,15 @@ const (
2830
SQLDriverPGXV5 = "github.com/jackc/pgx/v5"
2931
SQLDriverLibPQ = "github.com/lib/pq"
3032
SQLDriverGoSQLDriverMySQL = "github.com/go-sql-driver/mysql"
33+
SQLDriverYugaBytePGXV5 = "github.com/yugabyte/pgx/v5"
3134
)
3235

3336
var validDrivers = map[string]struct{}{
3437
string(SQLDriverPGXV4): {},
3538
string(SQLDriverPGXV5): {},
3639
string(SQLDriverLibPQ): {},
3740
string(SQLDriverGoSQLDriverMySQL): {},
41+
string(SQLDriverYugaBytePGXV5): {},
3842
}
3943

4044
func validateDriver(sqlDriver string) error {
@@ -45,7 +49,7 @@ func validateDriver(sqlDriver string) error {
4549
}
4650

4751
func (d SQLDriver) IsPGX() bool {
48-
return d == SQLDriverPGXV4 || d == SQLDriverPGXV5
52+
return d == SQLDriverPGXV4 || d == SQLDriverPGXV5 || d == SQLDriverYugaBytePGXV5
4953
}
5054

5155
func (d SQLDriver) IsGoSQLDriverMySQL() bool {
@@ -58,6 +62,8 @@ func (d SQLDriver) Package() string {
5862
return SQLPackagePGXV4
5963
case SQLDriverPGXV5:
6064
return SQLPackagePGXV5
65+
case SQLDriverYugaBytePGXV5:
66+
return SQLPackageYugaBytePGXV5
6167
default:
6268
return SQLPackageStandard
6369
}

0 commit comments

Comments
 (0)