Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 19 additions & 4 deletions google/cloud/spanner_v1/snapshot.py
Original file line number Diff line number Diff line change
Expand Up @@ -901,13 +901,19 @@ def attempt_tracking_method():

return [partition.partition_token for partition in response.partitions]

def _begin_transaction(self, mutation: Mutation = None) -> bytes:
def _begin_transaction(
self, mutation: Mutation = None, transaction_tag: str = None
) -> bytes:
"""Begins a transaction on the database.

:type mutation: :class:`~google.cloud.spanner_v1.mutation.Mutation`
:param mutation: (Optional) Mutation to include in the begin transaction
request. Required for mutation-only transactions with multiplexed sessions.

:type transaction_tag: str
:param transaction_tag: (Optional) Transaction tag to include in the begin transaction
request.

:rtype: bytes
:returns: identifier for the transaction.

Expand All @@ -931,6 +937,17 @@ def _begin_transaction(self, mutation: Mutation = None) -> bytes:
(_metadata_with_leader_aware_routing(database._route_to_leader_enabled))
)

begin_request_kwargs = {
"session": session.name,
"options": self._build_transaction_selector_pb().begin,
"mutation_key": mutation,
}

if transaction_tag:
begin_request_kwargs["request_options"] = RequestOptions(
transaction_tag=transaction_tag
)

with trace_call(
name=f"CloudSpanner.{type(self).__name__}.begin",
session=session,
Expand All @@ -942,9 +959,7 @@ def _begin_transaction(self, mutation: Mutation = None) -> bytes:

def wrapped_method():
begin_transaction_request = BeginTransactionRequest(
session=session.name,
options=self._build_transaction_selector_pb().begin,
mutation_key=mutation,
**begin_request_kwargs
)
begin_transaction_method = functools.partial(
api.begin_transaction,
Expand Down
4 changes: 3 additions & 1 deletion google/cloud/spanner_v1/transaction.py
Original file line number Diff line number Diff line change
Expand Up @@ -714,7 +714,9 @@ def _begin_transaction(self, mutation: Mutation = None) -> bytes:
if self.rolled_back:
raise ValueError("Transaction is already rolled back")

return super(Transaction, self)._begin_transaction(mutation=mutation)
return super(Transaction, self)._begin_transaction(
mutation=mutation, transaction_tag=self.transaction_tag
)

def _begin_mutations_only_transaction(self) -> None:
"""Begins a mutations-only transaction on the database."""
Expand Down
5 changes: 4 additions & 1 deletion tests/unit/test_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -2005,9 +2005,12 @@ def unit_of_work(txn, *args, **kw):
self.assertEqual(kw, {"some_arg": "def"})

expected_options = TransactionOptions(read_write=TransactionOptions.ReadWrite())
expected_request_options = RequestOptions(transaction_tag=transaction_tag)
gax_api.begin_transaction.assert_called_once_with(
request=BeginTransactionRequest(
session=self.SESSION_NAME, options=expected_options
session=self.SESSION_NAME,
options=expected_options,
request_options=expected_request_options,
),
metadata=[
("google-cloud-resource-prefix", database.name),
Expand Down
1 change: 1 addition & 0 deletions tests/unit/test_transaction.py
Original file line number Diff line number Diff line change
Expand Up @@ -467,6 +467,7 @@ def _commit_helper(
session=session.name,
options=TransactionOptions(read_write=TransactionOptions.ReadWrite()),
mutation_key=expected_begin_mutation,
request_options=RequestOptions(transaction_tag=TRANSACTION_TAG),
)

expected_begin_metadata = base_metadata.copy()
Expand Down