Skip to content

Commit eb03c32

Browse files
committed
fix: transaction_tag should be set on BeginTransactionRequest
1 parent 24394d6 commit eb03c32

File tree

4 files changed

+30
-6
lines changed

4 files changed

+30
-6
lines changed

google/cloud/spanner_v1/snapshot.py

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -901,13 +901,19 @@ def attempt_tracking_method():
901901

902902
return [partition.partition_token for partition in response.partitions]
903903

904-
def _begin_transaction(self, mutation: Mutation = None) -> bytes:
904+
def _begin_transaction(
905+
self, mutation: Mutation = None, transaction_tag: str = None
906+
) -> bytes:
905907
"""Begins a transaction on the database.
906908
907909
:type mutation: :class:`~google.cloud.spanner_v1.mutation.Mutation`
908910
:param mutation: (Optional) Mutation to include in the begin transaction
909911
request. Required for mutation-only transactions with multiplexed sessions.
910912
913+
:type transaction_tag: str
914+
:param transaction_tag: (Optional) Transaction tag to include in the begin transaction
915+
request.
916+
911917
:rtype: bytes
912918
:returns: identifier for the transaction.
913919
@@ -931,6 +937,17 @@ def _begin_transaction(self, mutation: Mutation = None) -> bytes:
931937
(_metadata_with_leader_aware_routing(database._route_to_leader_enabled))
932938
)
933939

940+
begin_request_kwargs = {
941+
"session": session.name,
942+
"options": self._build_transaction_selector_pb().begin,
943+
"mutation_key": mutation,
944+
}
945+
946+
if transaction_tag:
947+
begin_request_kwargs["request_options"] = RequestOptions(
948+
transaction_tag=transaction_tag
949+
)
950+
934951
with trace_call(
935952
name=f"CloudSpanner.{type(self).__name__}.begin",
936953
session=session,
@@ -942,9 +959,7 @@ def _begin_transaction(self, mutation: Mutation = None) -> bytes:
942959

943960
def wrapped_method():
944961
begin_transaction_request = BeginTransactionRequest(
945-
session=session.name,
946-
options=self._build_transaction_selector_pb().begin,
947-
mutation_key=mutation,
962+
**begin_request_kwargs
948963
)
949964
begin_transaction_method = functools.partial(
950965
api.begin_transaction,

google/cloud/spanner_v1/transaction.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -714,7 +714,9 @@ def _begin_transaction(self, mutation: Mutation = None) -> bytes:
714714
if self.rolled_back:
715715
raise ValueError("Transaction is already rolled back")
716716

717-
return super(Transaction, self)._begin_transaction(mutation=mutation)
717+
return super(Transaction, self)._begin_transaction(
718+
mutation=mutation, transaction_tag=self.transaction_tag
719+
)
718720

719721
def _begin_mutations_only_transaction(self) -> None:
720722
"""Begins a mutations-only transaction on the database."""

tests/unit/test_session.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2005,9 +2005,12 @@ def unit_of_work(txn, *args, **kw):
20052005
self.assertEqual(kw, {"some_arg": "def"})
20062006

20072007
expected_options = TransactionOptions(read_write=TransactionOptions.ReadWrite())
2008+
expected_request_options = RequestOptions(transaction_tag=transaction_tag)
20082009
gax_api.begin_transaction.assert_called_once_with(
20092010
request=BeginTransactionRequest(
2010-
session=self.SESSION_NAME, options=expected_options
2011+
session=self.SESSION_NAME,
2012+
options=expected_options,
2013+
request_options=expected_request_options,
20112014
),
20122015
metadata=[
20132016
("google-cloud-resource-prefix", database.name),

tests/unit/test_transaction.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -463,10 +463,14 @@ def _commit_helper(
463463
if mutations is not None:
464464
self.assertEqual(transaction._transaction_id, TRANSACTION_ID)
465465

466+
request_options = RequestOptions()
467+
request_options.transaction_tag = TRANSACTION_TAG
468+
466469
expected_begin_transaction_request = BeginTransactionRequest(
467470
session=session.name,
468471
options=TransactionOptions(read_write=TransactionOptions.ReadWrite()),
469472
mutation_key=expected_begin_mutation,
473+
request_options=request_options,
470474
)
471475

472476
expected_begin_metadata = base_metadata.copy()

0 commit comments

Comments
 (0)