diff --git a/lib/xdb.go b/lib/xdb.go index 5ecd220..3a3098c 100644 --- a/lib/xdb.go +++ b/lib/xdb.go @@ -17,12 +17,20 @@ package lib import ( "context" "database/sql" + "fmt" "strings" ) type DBContexter interface { context.Context - Conn() *sql.DB + Conn() CommonConn +} + +type CommonConn interface { + ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) + PrepareContext(ctx context.Context, query string) (*sql.Stmt, error) + QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) + QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row } type DBContext struct { @@ -39,7 +47,11 @@ func NewDBContext(ctx context.Context, conn *sql.DB) *DBContext { } } -func (ctx *DBContext) Conn() *sql.DB { +func (ctx *DBContext) Conn() CommonConn { + // use tx first + if ctx.tx != nil { + return ctx.tx + } return ctx.conn } @@ -101,6 +113,15 @@ func RDBTxnExecute(dc *DBContext, handler func(context.Context) error) error { } } + defer func() { + if err1 := recover(); err1 != nil { + dc.tx.Rollback() + + err = fmt.Errorf("%v", err1) + return + } + }() + err = handler(dc) return commitOrRollback(dc.tx, err) }