@@ -59,6 +59,9 @@ type SpannerConn interface {
5959 // RunBatch sends all batched DDL or DML statements to Spanner. This is a
6060 // no-op if no statements have been batched or if there is no active batch.
6161 RunBatch (ctx context.Context ) error
62+ // RunDmlBatch sends all batched DML statements to Spanner. This is a
63+ // no-op if no statements have been batched or if there is no active DML batch.
64+ RunDmlBatch (ctx context.Context ) (SpannerResult , error )
6265 // AbortBatch aborts the current DDL or DML batch and discards all batched
6366 // statements.
6467 AbortBatch () error
@@ -202,6 +205,11 @@ type SpannerConn interface {
202205 withTempBatchReadOnlyTransactionOptions (options * BatchReadOnlyTransactionOptions )
203206}
204207
208+ type SpannerResult interface {
209+ driver.Result
210+ BatchRowsAffected () ([]int64 , error )
211+ }
212+
205213var _ SpannerConn = & conn {}
206214
207215type conn struct {
@@ -261,6 +269,7 @@ type conn struct {
261269 // tempBatchReadOnlyTransactionOptions are temporarily set right before a
262270 // batch read-only transaction is started on a Spanner connection.
263271 tempBatchReadOnlyTransactionOptions * BatchReadOnlyTransactionOptions
272+ tempProtoTransactionOptions * spannerpb.TransactionOptions
264273}
265274
266275func (c * conn ) UnderlyingClient () (* spanner.Client , error ) {
@@ -444,6 +453,18 @@ func (c *conn) RunBatch(ctx context.Context) error {
444453 return err
445454}
446455
456+ func (c * conn ) RunDmlBatch (ctx context.Context ) (SpannerResult , error ) {
457+ res , err := c .runBatch (ctx )
458+ if err != nil {
459+ return nil , err
460+ }
461+ spannerRes , ok := res .(SpannerResult )
462+ if ! ok {
463+ return nil , spanner .ToSpannerError (status .Errorf (codes .FailedPrecondition , "not a DML batch" ))
464+ }
465+ return spannerRes , nil
466+ }
467+
447468func (c * conn ) AbortBatch () error {
448469 _ , err := c .abortBatch ()
449470 return err
@@ -522,7 +543,7 @@ func (c *conn) runDDLBatch(ctx context.Context) (driver.Result, error) {
522543 return c .execDDL (ctx , statements ... )
523544}
524545
525- func (c * conn ) runDMLBatch (ctx context.Context ) (driver. Result , error ) {
546+ func (c * conn ) runDMLBatch (ctx context.Context ) (SpannerResult , error ) {
526547 statements := c .batch .statements
527548 options := c .batch .options
528549 options .QueryOptions .LastStatement = true
@@ -541,7 +562,7 @@ func (c *conn) abortBatch() (driver.Result, error) {
541562
542563func (c * conn ) execDDL (ctx context.Context , statements ... spanner.Statement ) (driver.Result , error ) {
543564 if c .batch != nil && c .batch .tp == dml {
544- return nil , spanner .ToSpannerError (status .Error (codes .FailedPrecondition , "This connection has an active DML batch" ))
565+ return nil , spanner .ToSpannerError (status .Error (codes .FailedPrecondition , "This connection has an active DDL batch" ))
545566 }
546567 if c .batch != nil && c .batch .tp == ddl {
547568 c .batch .statements = append (c .batch .statements , statements ... )
@@ -567,7 +588,7 @@ func (c *conn) execDDL(ctx context.Context, statements ...spanner.Statement) (dr
567588 return driver .ResultNoRows , nil
568589}
569590
570- func (c * conn ) execBatchDML (ctx context.Context , statements []spanner.Statement , options ExecOptions ) (driver. Result , error ) {
591+ func (c * conn ) execBatchDML (ctx context.Context , statements []spanner.Statement , options ExecOptions ) (SpannerResult , error ) {
571592 if len (statements ) == 0 {
572593 return & result {}, nil
573594 }
@@ -586,7 +607,7 @@ func (c *conn) execBatchDML(ctx context.Context, statements []spanner.Statement,
586607 return err
587608 }, options .TransactionOptions )
588609 }
589- return & result {rowsAffected : sum (affected )}, err
610+ return & result {rowsAffected : sum (affected ), batchUpdateCounts : affected }, err
590611}
591612
592613func sum (affected []int64 ) int64 {
0 commit comments