diff --git a/.gitignore b/.gitignore index d083ea1dd..220056c20 100644 --- a/.gitignore +++ b/.gitignore @@ -62,3 +62,6 @@ system_tests/local_test_setup # Make sure a generated file isn't accidentally committed. pylintrc pylintrc.test + +# Benchmarking results and logs +__benchmark_results__/** diff --git a/.kokoro/trampoline_v2.sh b/.kokoro/trampoline_v2.sh index 35fa52923..d03f92dfc 100755 --- a/.kokoro/trampoline_v2.sh +++ b/.kokoro/trampoline_v2.sh @@ -26,8 +26,8 @@ # To run this script, first download few files from gcs to /dev/shm. # (/dev/shm is passed into the container as KOKORO_GFILE_DIR). # -# gsutil cp gs://cloud-devrel-kokoro-resources/python-docs-samples/secrets_viewer_service_account.json /dev/shm -# gsutil cp gs://cloud-devrel-kokoro-resources/python-docs-samples/automl_secrets.txt /dev/shm +# gcloud storage cp gs://cloud-devrel-kokoro-resources/python-docs-samples/secrets_viewer_service_account.json /dev/shm +# gcloud storage cp gs://cloud-devrel-kokoro-resources/python-docs-samples/automl_secrets.txt /dev/shm # # Then run the script. # .kokoro/trampoline_v2.sh diff --git a/.librarian/state.yaml b/.librarian/state.yaml index 1502e804d..80e2355be 100644 --- a/.librarian/state.yaml +++ b/.librarian/state.yaml @@ -1,7 +1,7 @@ image: us-central1-docker.pkg.dev/cloud-sdk-librarian-prod/images-prod/python-librarian-generator@sha256:8e2c32496077054105bd06c54a59d6a6694287bc053588e24debe6da6920ad91 libraries: - id: google-cloud-storage - version: 3.6.0 + version: 3.9.0 last_generated_commit: 5400ccce473c439885bd6bf2924fd242271bfcab apis: - path: google/storage/v2 diff --git a/CHANGELOG.md b/CHANGELOG.md index da1f2149b..4c46db115 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,6 +4,60 @@ [1]: https://pypi.org/project/google-cloud-storage/#history +## [3.9.0](https://github.com/googleapis/python-storage/compare/v3.8.0...v3.9.0) (2026-02-02) + + +### Features + +* add get_object method for async grpc client (#1735) ([0e5ec29bc6a31b77bcfba4254cef5bffb199095c](https://github.com/googleapis/python-storage/commit/0e5ec29bc6a31b77bcfba4254cef5bffb199095c)) +* expose `DELETE_OBJECT` in `AsyncGrpcClient` (#1718) ([c8dd7a0b124c395b7b60189ee78f47aba8d51f7d](https://github.com/googleapis/python-storage/commit/c8dd7a0b124c395b7b60189ee78f47aba8d51f7d)) +* update generation for MRD (#1730) ([08bc7082db7392f13bc8c51511b4afa9c7b157c9](https://github.com/googleapis/python-storage/commit/08bc7082db7392f13bc8c51511b4afa9c7b157c9)) +* Move Zonal Buckets features of `_experimental` (#1728) ([74c9ecc54173420bfcd48498a8956088a035af50](https://github.com/googleapis/python-storage/commit/74c9ecc54173420bfcd48498a8956088a035af50)) +* add default user agent for grpc (#1726) ([7b319469d2e495ea0bf7367f3949190e8f5d9fff](https://github.com/googleapis/python-storage/commit/7b319469d2e495ea0bf7367f3949190e8f5d9fff)) +* expose finalized_time in blob.py applicable for GET_OBJECT in ZB (#1719) ([8e21a7fe54d0a043f31937671003630a1985a5d2](https://github.com/googleapis/python-storage/commit/8e21a7fe54d0a043f31937671003630a1985a5d2)) +* add context manager to mrd (#1724) ([5ac2808a69195c688ed42c3604d4bfadbb602a66](https://github.com/googleapis/python-storage/commit/5ac2808a69195c688ed42c3604d4bfadbb602a66)) +* integrate writes strategy and appendable object writer (#1695) ([dbd162b3583e32e6f705a51f5c3fef333a9b89d0](https://github.com/googleapis/python-storage/commit/dbd162b3583e32e6f705a51f5c3fef333a9b89d0)) +* Add support for opening via `write_handle` and fix `write_handle` type (#1715) 
([2bc15fa570683ba584230c51b439d189dbdcd580](https://github.com/googleapis/python-storage/commit/2bc15fa570683ba584230c51b439d189dbdcd580)) +* Add micro-benchmarks for writes comparing standard (regional) vs rapid (zonal) buckets. (#1707) ([dbe9d8b89d975dfbed8c830a5687ccfafea51d5f](https://github.com/googleapis/python-storage/commit/dbe9d8b89d975dfbed8c830a5687ccfafea51d5f)) +* Add micro-benchmarks for reads comparing standard (regional) vs rapid (zonal) buckets. (#1697) ([1917649fac41481da1adea6c2a9f4ab1298a34c4](https://github.com/googleapis/python-storage/commit/1917649fac41481da1adea6c2a9f4ab1298a34c4)) +* send `user_agent` to grpc channel (#1712) ([cdb2486bb051dcbfbffc2510aff6aacede5e54d3](https://github.com/googleapis/python-storage/commit/cdb2486bb051dcbfbffc2510aff6aacede5e54d3)) +* add samples for appendable objects writes and reads (#1705) ([2e1a1eb5cbe1c909f1f892a0cc74fe63c8ef36ff](https://github.com/googleapis/python-storage/commit/2e1a1eb5cbe1c909f1f892a0cc74fe63c8ef36ff)) +* add samples for appendable objects writes and reads ([2e1a1eb5cbe1c909f1f892a0cc74fe63c8ef36ff](https://github.com/googleapis/python-storage/commit/2e1a1eb5cbe1c909f1f892a0cc74fe63c8ef36ff)) +* add support for `generation=0` to avoid overwriting existing objects and add `is_stream_open` support (#1709) ([ea0f5bf8316f4bfcff2728d9d1baa68dde6ebdae](https://github.com/googleapis/python-storage/commit/ea0f5bf8316f4bfcff2728d9d1baa68dde6ebdae)) +* add support for `generation=0` to prevent overwriting existing objects ([ea0f5bf8316f4bfcff2728d9d1baa68dde6ebdae](https://github.com/googleapis/python-storage/commit/ea0f5bf8316f4bfcff2728d9d1baa68dde6ebdae)) +* add `is_stream_open` property to AsyncAppendableObjectWriter for stream status check ([ea0f5bf8316f4bfcff2728d9d1baa68dde6ebdae](https://github.com/googleapis/python-storage/commit/ea0f5bf8316f4bfcff2728d9d1baa68dde6ebdae)) + + +### Bug Fixes + +* receive eof while closing reads stream (#1733) ([2ef63396dca1c36f9b0f0f3cf87a61b5aa4bd465](https://github.com/googleapis/python-storage/commit/2ef63396dca1c36f9b0f0f3cf87a61b5aa4bd465)) +* Change contructors of MRD and AAOW AsyncGrpcClient.grpc_client to AsyncGrpcClient (#1727) ([e730bf50c4584f737ab86b2e409ddb27b40d2cec](https://github.com/googleapis/python-storage/commit/e730bf50c4584f737ab86b2e409ddb27b40d2cec)) +* instance grpc client once per process in benchmarks (#1725) ([721ea2dd6c6db2aa91fd3b90e56a831aaaa64061](https://github.com/googleapis/python-storage/commit/721ea2dd6c6db2aa91fd3b90e56a831aaaa64061)) +* update write handle on every recv() (#1716) ([5d9fafe1466b5ccb1db4a814967a5cc8465148a2](https://github.com/googleapis/python-storage/commit/5d9fafe1466b5ccb1db4a814967a5cc8465148a2)) +* Fix formatting in setup.py dependencies list (#1713) ([cc4831d7e253b265b0b96e08b5479f4c759be442](https://github.com/googleapis/python-storage/commit/cc4831d7e253b265b0b96e08b5479f4c759be442)) +* implement requests_done method to signal end of requests in async streams. Gracefully close streams. (#1700) ([6c160794afded5e8f4179399f1fe5248e32bf707](https://github.com/googleapis/python-storage/commit/6c160794afded5e8f4179399f1fe5248e32bf707)) +* implement requests_done method to signal end of requests in async streams. Gracefully close streams. 
([6c160794afded5e8f4179399f1fe5248e32bf707](https://github.com/googleapis/python-storage/commit/6c160794afded5e8f4179399f1fe5248e32bf707)) + +## [3.8.0](https://github.com/googleapis/python-storage/compare/v3.7.0...v3.8.0) (2026-01-13) + + +### Features + +* flush the last chunk in append method (#1699) ([89bfe7a5fcd0391da35e9ceccc185279782b5420](https://github.com/googleapis/python-storage/commit/89bfe7a5fcd0391da35e9ceccc185279782b5420)) +* add write resumption strategy (#1663) ([a57ea0ec786a84c7ae9ed82c6ae5d38ecadba4af](https://github.com/googleapis/python-storage/commit/a57ea0ec786a84c7ae9ed82c6ae5d38ecadba4af)) +* add bidi stream retry manager. (#1632) ([d90f0ee09902a21b186106bcf0a8cb0b81b34340](https://github.com/googleapis/python-storage/commit/d90f0ee09902a21b186106bcf0a8cb0b81b34340)) +* implement "append_from_file" (#1686) ([1333c956da18b4db753cda98c41c3619c84caf69](https://github.com/googleapis/python-storage/commit/1333c956da18b4db753cda98c41c3619c84caf69)) +* make flush size configurable (#1677) ([f7095faf0a81239894ff9d277849788b62eb6ac5](https://github.com/googleapis/python-storage/commit/f7095faf0a81239894ff9d277849788b62eb6ac5)) +* compute chunk wise checksum for bidi_writes (#1675) ([139390cb01f93a2d61e7ec201e3637dffe0b2a34](https://github.com/googleapis/python-storage/commit/139390cb01f93a2d61e7ec201e3637dffe0b2a34)) +* expose persisted size in mrd (#1671) ([0e2961bef285fc064174a5c18e3db05c7a682521](https://github.com/googleapis/python-storage/commit/0e2961bef285fc064174a5c18e3db05c7a682521)) + + +### Bug Fixes + +* add system test for opening with read_handle (#1672) ([6dc711dacd4d38c573aa4ca9ad71fe412c0e49c1](https://github.com/googleapis/python-storage/commit/6dc711dacd4d38c573aa4ca9ad71fe412c0e49c1)) +* no state lookup while opening bidi-write stream (#1636) ([2d5a7b16846a69f3a911844971241899f60cce14](https://github.com/googleapis/python-storage/commit/2d5a7b16846a69f3a911844971241899f60cce14)) +* close write object stream always (#1661) ([4a609a4b3f4ba1396825911cb02f8a9649135cd5](https://github.com/googleapis/python-storage/commit/4a609a4b3f4ba1396825911cb02f8a9649135cd5)) + ## [3.7.0](https://github.com/googleapis/python-storage/compare/v3.6.0...v3.7.0) (2025-12-09) diff --git a/cloudbuild/run_zonal_tests.sh b/cloudbuild/run_zonal_tests.sh index ef94e629b..22ca8fe4b 100644 --- a/cloudbuild/run_zonal_tests.sh +++ b/cloudbuild/run_zonal_tests.sh @@ -6,6 +6,7 @@ sudo apt-get update && sudo apt-get install -y git python3-pip python3-venv # Clone the repository and checkout the specific commit from the build trigger. git clone https://github.com/googleapis/python-storage.git cd python-storage +git fetch origin "refs/pull/${_PR_NUMBER}/head" git checkout ${COMMIT_SHA} @@ -22,5 +23,7 @@ pip install -e . 
echo '--- Setting up environment variables on VM ---' export ZONAL_BUCKET=${_ZONAL_BUCKET} export RUN_ZONAL_SYSTEM_TESTS=True -echo '--- Running Zonal tests on VM ---' +CURRENT_ULIMIT=$(ulimit -n) +echo '--- Running Zonal tests on VM with ulimit set to ---' $CURRENT_ULIMIT pytest -vv -s --log-format='%(asctime)s %(levelname)s %(message)s' --log-date-format='%H:%M:%S' tests/system/test_zonal.py +pytest -vv -s --log-format='%(asctime)s %(levelname)s %(message)s' --log-date-format='%H:%M:%S' samples/snippets/zonal_buckets/zonal_snippets_test.py diff --git a/cloudbuild/zb-system-tests-cloudbuild.yaml b/cloudbuild/zb-system-tests-cloudbuild.yaml index 383c4fa96..562eae175 100644 --- a/cloudbuild/zb-system-tests-cloudbuild.yaml +++ b/cloudbuild/zb-system-tests-cloudbuild.yaml @@ -3,6 +3,7 @@ substitutions: _ZONE: "us-central1-a" _SHORT_BUILD_ID: ${BUILD_ID:0:8} _VM_NAME: "py-sdk-sys-test-${_SHORT_BUILD_ID}" + _ULIMIT: "10000" # 10k, for gRPC bidi streams @@ -67,7 +68,7 @@ steps: # Execute the script on the VM via SSH. # Capture the exit code to ensure cleanup happens before the build fails. set +e - gcloud compute ssh ${_VM_NAME} --zone=${_ZONE} --internal-ip --ssh-key-file=/workspace/.ssh/google_compute_engine --command="COMMIT_SHA=${COMMIT_SHA} _ZONAL_BUCKET=${_ZONAL_BUCKET} bash run_zonal_tests.sh" + gcloud compute ssh ${_VM_NAME} --zone=${_ZONE} --internal-ip --ssh-key-file=/workspace/.ssh/google_compute_engine --command="ulimit -n ${_ULIMIT}; COMMIT_SHA=${COMMIT_SHA} _ZONAL_BUCKET=${_ZONAL_BUCKET} _PR_NUMBER=${_PR_NUMBER} bash run_zonal_tests.sh" EXIT_CODE=$? set -e diff --git a/google/cloud/_storage_v2/gapic_version.py b/google/cloud/_storage_v2/gapic_version.py index d69b0530e..0d5599e8b 100644 --- a/google/cloud/_storage_v2/gapic_version.py +++ b/google/cloud/_storage_v2/gapic_version.py @@ -13,4 +13,4 @@ # See the License for the specific language governing permissions and # limitations under the License. # -__version__ = "3.6.0" # {x-release-please-version} +__version__ = "3.9.0" # {x-release-please-version} diff --git a/google/cloud/storage/_experimental/asyncio/_utils.py b/google/cloud/storage/_experimental/asyncio/_utils.py new file mode 100644 index 000000000..7e81a4bc7 --- /dev/null +++ b/google/cloud/storage/_experimental/asyncio/_utils.py @@ -0,0 +1,11 @@ +import warnings + +# Import everything from the new stable module +from google.cloud.storage.asyncio._utils import * # noqa + +warnings.warn( + "google.cloud.storage._experimental.asyncio._utils has been moved to google.cloud.storage.asyncio._utils. " + "Please update your imports.", + DeprecationWarning, + stacklevel=2, +) diff --git a/google/cloud/storage/_experimental/asyncio/async_abstract_object_stream.py b/google/cloud/storage/_experimental/asyncio/async_abstract_object_stream.py index 49d7a293a..538241bd2 100644 --- a/google/cloud/storage/_experimental/asyncio/async_abstract_object_stream.py +++ b/google/cloud/storage/_experimental/asyncio/async_abstract_object_stream.py @@ -1,67 +1,11 @@ -# Copyright 2025 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-# See the License for the specific language governing permissions and -# limitations under the License. +import warnings -import abc -from typing import Any, Optional +# Import everything from the new stable module +from google.cloud.storage.asyncio.async_abstract_object_stream import * # noqa - -class _AsyncAbstractObjectStream(abc.ABC): - """Abstract base class to represent gRPC bidi-stream for GCS ``Object``. - - Concrete implementation of this class could be ``_AsyncReadObjectStream`` - or ``_AsyncWriteObjectStream``. - - :type bucket_name: str - :param bucket_name: (Optional) The name of the bucket containing the object. - - :type object_name: str - :param object_name: (Optional) The name of the object. - - :type generation_number: int - :param generation_number: (Optional) If present, selects a specific revision of - this object. - - :type handle: bytes - :param handle: (Optional) The handle for the object, could be read_handle or - write_handle, based on how the stream is used. - """ - - def __init__( - self, - bucket_name: str, - object_name: str, - generation_number: Optional[int] = None, - handle: Optional[bytes] = None, - ) -> None: - super().__init__() - self.bucket_name: str = bucket_name - self.object_name: str = object_name - self.generation_number: Optional[int] = generation_number - self.handle: Optional[bytes] = handle - - @abc.abstractmethod - async def open(self) -> None: - pass - - @abc.abstractmethod - async def close(self) -> None: - pass - - @abc.abstractmethod - async def send(self, protobuf: Any) -> None: - pass - - @abc.abstractmethod - async def recv(self) -> Any: - pass +warnings.warn( + "google.cloud.storage._experimental.asyncio.async_abstract_object_stream has been moved to google.cloud.storage.asyncio.async_abstract_object_stream. " + "Please update your imports.", + DeprecationWarning, + stacklevel=2, +) diff --git a/google/cloud/storage/_experimental/asyncio/async_appendable_object_writer.py b/google/cloud/storage/_experimental/asyncio/async_appendable_object_writer.py index 27c4b4f19..53b813643 100644 --- a/google/cloud/storage/_experimental/asyncio/async_appendable_object_writer.py +++ b/google/cloud/storage/_experimental/asyncio/async_appendable_object_writer.py @@ -1,321 +1,11 @@ -# Copyright 2025 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -NOTE: -This is _experimental module for upcoming support for Rapid Storage. -(https://cloud.google.com/blog/products/storage-data-transfer/high-performance-storage-innovations-for-ai-hpc#:~:text=your%20AI%20workloads%3A-,Rapid%20Storage,-%3A%20A%20new) +import warnings -APIs may not work as intended and are not stable yet. Feature is not -GA(Generally Available) yet, please contact your TAM (Technical Account Manager) -if you want to use these Rapid Storage APIs. 
+# Import everything from the new stable module +from google.cloud.storage.asyncio.async_appendable_object_writer import * # noqa -""" -from typing import Optional, Union -from google.cloud import _storage_v2 -from google.cloud.storage._experimental.asyncio.async_grpc_client import ( - AsyncGrpcClient, +warnings.warn( + "google.cloud.storage._experimental.asyncio.async_appendable_object_writer has been moved to google.cloud.storage.asyncio.async_appendable_object_writer. " + "Please update your imports.", + DeprecationWarning, + stacklevel=2, ) -from google.cloud.storage._experimental.asyncio.async_write_object_stream import ( - _AsyncWriteObjectStream, -) - - -_MAX_CHUNK_SIZE_BYTES = 2 * 1024 * 1024 # 2 MiB -_MAX_BUFFER_SIZE_BYTES = 16 * 1024 * 1024 # 16 MiB - - -class AsyncAppendableObjectWriter: - """Class for appending data to a GCS Appendable Object asynchronously.""" - - def __init__( - self, - client: AsyncGrpcClient.grpc_client, - bucket_name: str, - object_name: str, - generation=None, - write_handle=None, - ): - """ - Class for appending data to a GCS Appendable Object. - - Example usage: - - ``` - - from google.cloud.storage._experimental.asyncio.async_grpc_client import AsyncGrpcClient - from google.cloud.storage._experimental.asyncio.async_appendable_object_writer import AsyncAppendableObjectWriter - import asyncio - - client = AsyncGrpcClient().grpc_client - bucket_name = "my-bucket" - object_name = "my-appendable-object" - - # instantiate the writer - writer = AsyncAppendableObjectWriter(client, bucket_name, object_name) - # open the writer, (underlying gRPC bidi-stream will be opened) - await writer.open() - - # append data, it can be called multiple times. - await writer.append(b"hello world") - await writer.append(b"some more data") - - # optionally flush data to persist. - await writer.flush() - - # close the gRPC stream. - # Please note closing the program will also close the stream, - # however it's recommended to close the stream if no more data to append - # to clean up gRPC connection (which means CPU/memory/network resources) - await writer.close() - ``` - - :type client: :class:`~google.cloud.storage._experimental.asyncio.async_grpc_client.AsyncGrpcClient.grpc_client` - :param client: async grpc client to use for making API requests. - - :type bucket_name: str - :param bucket_name: The name of the GCS bucket containing the object. - - :type object_name: str - :param object_name: The name of the GCS Appendable Object to be written. - - :type generation: int - :param generation: (Optional) If present, selects a specific revision of - that object. - If None, a new object is created. - If None and Object already exists then it'll will be - overwritten. - - :type write_handle: bytes - :param write_handle: (Optional) An existing handle for writing the object. - If provided, opening the bidi-gRPC connection will be faster. - """ - self.client = client - self.bucket_name = bucket_name - self.object_name = object_name - self.write_handle = write_handle - self.generation = generation - - self.write_obj_stream = _AsyncWriteObjectStream( - client=self.client, - bucket_name=self.bucket_name, - object_name=self.object_name, - generation_number=self.generation, - write_handle=self.write_handle, - ) - self._is_stream_open: bool = False - # `offset` is the latest size of the object without staleless. - self.offset: Optional[int] = None - # `persisted_size` is the total_bytes persisted in the GCS server. 
- # Please note: `offset` and `persisted_size` are same when the stream is - # opened. - self.persisted_size: Optional[int] = None - - async def state_lookup(self) -> int: - """Returns the persisted_size - - :rtype: int - :returns: persisted size. - - :raises ValueError: If the stream is not open (i.e., `open()` has not - been called). - """ - if not self._is_stream_open: - raise ValueError("Stream is not open. Call open() before state_lookup().") - - await self.write_obj_stream.send( - _storage_v2.BidiWriteObjectRequest( - state_lookup=True, - ) - ) - response = await self.write_obj_stream.recv() - self.persisted_size = response.persisted_size - return self.persisted_size - - async def open(self) -> None: - """Opens the underlying bidi-gRPC stream. - - :raises ValueError: If the stream is already open. - - """ - if self._is_stream_open: - raise ValueError("Underlying bidi-gRPC stream is already open") - - await self.write_obj_stream.open() - self._is_stream_open = True - if self.generation is None: - self.generation = self.write_obj_stream.generation_number - self.write_handle = self.write_obj_stream.write_handle - self.persisted_size = self.write_obj_stream.persisted_size - - async def append(self, data: bytes) -> None: - """Appends data to the Appendable object. - - calling `self.append` will append bytes at the end of the current size - ie. `self.offset` bytes relative to the begining of the object. - - This method sends the provided `data` to the GCS server in chunks. - and persists data in GCS at every `_MAX_BUFFER_SIZE_BYTES` bytes by - calling `self.simple_flush`. - - :type data: bytes - :param data: The bytes to append to the object. - - :rtype: None - - :raises ValueError: If the stream is not open (i.e., `open()` has not - been called). - """ - - if not self._is_stream_open: - raise ValueError("Stream is not open. Call open() before append().") - total_bytes = len(data) - if total_bytes == 0: - # TODO: add warning. - return - if self.offset is None: - assert self.persisted_size is not None - self.offset = self.persisted_size - - start_idx = 0 - bytes_to_flush = 0 - while start_idx < total_bytes: - end_idx = min(start_idx + _MAX_CHUNK_SIZE_BYTES, total_bytes) - await self.write_obj_stream.send( - _storage_v2.BidiWriteObjectRequest( - write_offset=self.offset, - checksummed_data=_storage_v2.ChecksummedData( - content=data[start_idx:end_idx] - ), - ) - ) - chunk_size = end_idx - start_idx - self.offset += chunk_size - bytes_to_flush += chunk_size - if bytes_to_flush >= _MAX_BUFFER_SIZE_BYTES: - await self.simple_flush() - bytes_to_flush = 0 - start_idx = end_idx - - async def simple_flush(self) -> None: - """Flushes the data to the server. - Please note: Unlike `flush` it does not do `state_lookup` - - :rtype: None - - :raises ValueError: If the stream is not open (i.e., `open()` has not - been called). - """ - if not self._is_stream_open: - raise ValueError("Stream is not open. Call open() before simple_flush().") - - await self.write_obj_stream.send( - _storage_v2.BidiWriteObjectRequest( - flush=True, - ) - ) - - async def flush(self) -> int: - """Flushes the data to the server. - - :rtype: int - :returns: The persisted size after flush. - - :raises ValueError: If the stream is not open (i.e., `open()` has not - been called). - """ - if not self._is_stream_open: - raise ValueError("Stream is not open. 
Call open() before flush().") - - await self.write_obj_stream.send( - _storage_v2.BidiWriteObjectRequest( - flush=True, - state_lookup=True, - ) - ) - response = await self.write_obj_stream.recv() - self.persisted_size = response.persisted_size - self.offset = self.persisted_size - return self.persisted_size - - async def close(self, finalize_on_close=False) -> Union[int, _storage_v2.Object]: - """Closes the underlying bidi-gRPC stream. - - :type finalize_on_close: bool - :param finalize_on_close: Finalizes the Appendable Object. No more data - can be appended. - - rtype: Union[int, _storage_v2.Object] - returns: Updated `self.persisted_size` by default after closing the - bidi-gRPC stream. However, if `finalize_on_close=True` is passed, - returns the finalized object resource. - - :raises ValueError: If the stream is not open (i.e., `open()` has not - been called). - - """ - if not self._is_stream_open: - raise ValueError("Stream is not open. Call open() before close().") - - if finalize_on_close: - await self.finalize() - else: - await self.flush() - - await self.write_obj_stream.close() - - self._is_stream_open = False - self.offset = None - return self.object_resource if finalize_on_close else self.persisted_size - - async def finalize(self) -> _storage_v2.Object: - """Finalizes the Appendable Object. - - Note: Once finalized no more data can be appended. - - rtype: google.cloud.storage_v2.types.Object - returns: The finalized object resource. - - :raises ValueError: If the stream is not open (i.e., `open()` has not - been called). - """ - if not self._is_stream_open: - raise ValueError("Stream is not open. Call open() before finalize().") - - await self.write_obj_stream.send( - _storage_v2.BidiWriteObjectRequest(finish_write=True) - ) - response = await self.write_obj_stream.recv() - self.object_resource = response.resource - self.persisted_size = self.object_resource.size - return self.object_resource - - # helper methods. 
- async def append_from_string(self, data: str): - """ - str data will be encoded to bytes using utf-8 encoding calling - - self.append(data.encode("utf-8")) - """ - raise NotImplementedError("append_from_string is not implemented yet.") - - async def append_from_stream(self, stream_obj): - """ - At a time read a chunk of data (16MiB) from `stream_obj` - and call self.append(chunk) - """ - raise NotImplementedError("append_from_stream is not implemented yet.") - - async def append_from_file(self, file_path: str): - """Create a file object from `file_path` and call append_from_stream(file_obj)""" - raise NotImplementedError("append_from_file is not implemented yet.") diff --git a/google/cloud/storage/_experimental/asyncio/async_client.py b/google/cloud/storage/_experimental/asyncio/async_client.py index bd8817a09..3c4fddbca 100644 --- a/google/cloud/storage/_experimental/asyncio/async_client.py +++ b/google/cloud/storage/_experimental/asyncio/async_client.py @@ -16,10 +16,14 @@ import functools -from google.cloud.storage._experimental.asyncio.async_helpers import ASYNC_DEFAULT_TIMEOUT +from google.cloud.storage._experimental.asyncio.async_helpers import ( + ASYNC_DEFAULT_TIMEOUT, +) from google.cloud.storage._experimental.asyncio.async_helpers import ASYNC_DEFAULT_RETRY from google.cloud.storage._experimental.asyncio.async_helpers import AsyncHTTPIterator -from google.cloud.storage._experimental.asyncio.async_helpers import _do_nothing_page_start +from google.cloud.storage._experimental.asyncio.async_helpers import ( + _do_nothing_page_start, +) from google.cloud.storage._opentelemetry_tracing import create_trace_span from google.cloud.storage._experimental.asyncio.async_creds import AsyncCredsWrapper from google.cloud.storage.abstracts.base_client import BaseClient @@ -28,6 +32,7 @@ try: from google.auth.aio.transport import sessions + AsyncSession = sessions.AsyncAuthorizedSession _AIO_AVAILABLE = True except ImportError: @@ -70,12 +75,16 @@ def __init__( client_info=client_info, client_options=client_options, extra_headers=extra_headers, - api_key=api_key + api_key=api_key, ) - self.credentials = AsyncCredsWrapper(self._credentials) # self._credential is synchronous. - self._connection = AsyncConnection(self, **self.connection_kw_args) # adapter for async communication + self.credentials = AsyncCredsWrapper( + self._credentials + ) # self._credential is synchronous. 
+ self._connection = AsyncConnection( + self, **self.connection_kw_args + ) # adapter for async communication self._async_http_internal = _async_http - self._async_http_passed_by_user = (_async_http is not None) + self._async_http_passed_by_user = _async_http is not None @property def async_http(self): @@ -86,7 +95,10 @@ def async_http(self): async def close(self): """Close the session, if it exists""" - if self._async_http_internal is not None and not self._async_http_passed_by_user: + if ( + self._async_http_internal is not None + and not self._async_http_passed_by_user + ): await self._async_http_internal.close() async def _get_resource( diff --git a/google/cloud/storage/_experimental/asyncio/async_creds.py b/google/cloud/storage/_experimental/asyncio/async_creds.py index 2fb899b19..e2abc3316 100644 --- a/google/cloud/storage/_experimental/asyncio/async_creds.py +++ b/google/cloud/storage/_experimental/asyncio/async_creds.py @@ -5,21 +5,23 @@ try: from google.auth.aio import credentials as aio_creds_module + BaseCredentials = aio_creds_module.Credentials _AIO_AVAILABLE = True except ImportError: BaseCredentials = object _AIO_AVAILABLE = False + class AsyncCredsWrapper(BaseCredentials): """Wraps synchronous Google Auth credentials to provide an asynchronous interface. Args: - sync_creds (google.auth.credentials.Credentials): The synchronous credentials + sync_creds (google.auth.credentials.Credentials): The synchronous credentials instance to wrap. Raises: - ImportError: If instantiated in an environment where 'google.auth.aio' + ImportError: If instantiated in an environment where 'google.auth.aio' is not available. """ @@ -36,9 +38,7 @@ def __init__(self, sync_creds): async def refresh(self, request): """Refreshes the access token.""" loop = asyncio.get_running_loop() - await loop.run_in_executor( - None, self.creds.refresh, Request() - ) + await loop.run_in_executor(None, self.creds.refresh, Request()) @property def valid(self): diff --git a/google/cloud/storage/_experimental/asyncio/async_grpc_client.py b/google/cloud/storage/_experimental/asyncio/async_grpc_client.py index a5cccca59..558ff0c5a 100644 --- a/google/cloud/storage/_experimental/asyncio/async_grpc_client.py +++ b/google/cloud/storage/_experimental/asyncio/async_grpc_client.py @@ -1,90 +1,11 @@ -# Copyright 2025 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. +import warnings -"""An async client for interacting with Google Cloud Storage using the gRPC API.""" +# Import everything from the new stable module +from google.cloud.storage.asyncio.async_grpc_client import * # noqa -from google.cloud import _storage_v2 as storage_v2 - - -class AsyncGrpcClient: - """An asynchronous client for interacting with Google Cloud Storage using the gRPC API. - - :type credentials: :class:`~google.auth.credentials.Credentials` - :param credentials: (Optional) The OAuth2 Credentials to use for this - client. If not passed, falls back to the default - inferred from the environment. 
- - :type client_info: :class:`~google.api_core.client_info.ClientInfo` - :param client_info: - The client info used to send a user-agent string along with API - requests. If ``None``, then default info will be used. - - :type client_options: :class:`~google.api_core.client_options.ClientOptions` or :class:`dict` - :param client_options: (Optional) Client options used to set user options - on the client. - - :type attempt_direct_path: bool - :param attempt_direct_path: - (Optional) Whether to attempt to use DirectPath for gRPC connections. - Defaults to ``True``. - """ - - def __init__( - self, - credentials=None, - client_info=None, - client_options=None, - *, - attempt_direct_path=True, - ): - self._grpc_client = self._create_async_grpc_client( - credentials=credentials, - client_info=client_info, - client_options=client_options, - attempt_direct_path=attempt_direct_path, - ) - - def _create_async_grpc_client( - self, - credentials=None, - client_info=None, - client_options=None, - attempt_direct_path=True, - ): - transport_cls = storage_v2.StorageAsyncClient.get_transport_class( - "grpc_asyncio" - ) - channel = transport_cls.create_channel( - attempt_direct_path=attempt_direct_path, credentials=credentials - ) - transport = transport_cls(channel=channel) - - return storage_v2.StorageAsyncClient( - transport=transport, - client_info=client_info, - client_options=client_options, - ) - - @property - def grpc_client(self): - """The underlying gRPC client. - - This property gives users direct access to the `_storage_v2.StorageAsyncClient` - instance. This can be useful for accessing - newly added or experimental RPCs that are not yet exposed through - the high-level GrpcClient. - Returns: - google.cloud._storage_v2.StorageAsyncClient: The configured GAPIC client. - """ - return self._grpc_client +warnings.warn( + "google.cloud.storage._experimental.asyncio.async_grpc_client has been moved to google.cloud.storage.asyncio.async_grpc_client. " + "Please update your imports.", + DeprecationWarning, + stacklevel=2, +) diff --git a/google/cloud/storage/_experimental/asyncio/async_helpers.py b/google/cloud/storage/_experimental/asyncio/async_helpers.py index bfebfaafa..4a7d78732 100644 --- a/google/cloud/storage/_experimental/asyncio/async_helpers.py +++ b/google/cloud/storage/_experimental/asyncio/async_helpers.py @@ -24,6 +24,7 @@ async def _do_nothing_page_start(iterator, page, response): # pylint: disable=unused-argument pass + class AsyncHTTPIterator(AsyncIterator): """A generic class for iterating through HTTP/JSON API list responses asynchronously. @@ -32,7 +33,7 @@ class AsyncHTTPIterator(AsyncIterator): api_request (Callable): The **async** function to use to make API requests. This must be an awaitable. path (str): The method path to query for the list of items. - item_to_value (Callable[AsyncIterator, Any]): Callable to convert an item + item_to_value (Callable[AsyncIterator, Any]): Callable to convert an item from the type in the JSON response into a native object. items_key (str): The key in the API response where the list of items can be found. @@ -40,7 +41,7 @@ class AsyncHTTPIterator(AsyncIterator): page_size (int): The maximum number of results to fetch per page. max_results (int): The maximum number of results to fetch. extra_params (dict): Extra query string parameters for the API call. - page_start (Callable): Callable to provide special behavior after a new page + page_start (Callable): Callable to provide special behavior after a new page is created. 
next_token (str): The name of the field used in the response for page tokens. """ @@ -137,6 +138,4 @@ def _get_query_params(self): async def _get_next_page_response(self): """Requests the next page from the path provided asynchronously.""" params = self._get_query_params() - return await self.api_request( - method="GET", path=self.path, query_params=params - ) + return await self.api_request(method="GET", path=self.path, query_params=params) diff --git a/google/cloud/storage/_experimental/asyncio/async_multi_range_downloader.py b/google/cloud/storage/_experimental/asyncio/async_multi_range_downloader.py index c05cb1c08..bfc2c7c2b 100644 --- a/google/cloud/storage/_experimental/asyncio/async_multi_range_downloader.py +++ b/google/cloud/storage/_experimental/asyncio/async_multi_range_downloader.py @@ -1,348 +1,11 @@ -# Copyright 2025 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. +import warnings -from __future__ import annotations -import asyncio -import google_crc32c -from google.api_core import exceptions -from google_crc32c import Checksum +# Import everything from the new stable module +from google.cloud.storage.asyncio.async_multi_range_downloader import * # noqa -from typing import List, Optional, Tuple - -from google.cloud.storage._experimental.asyncio.async_read_object_stream import ( - _AsyncReadObjectStream, -) -from google.cloud.storage._experimental.asyncio.async_grpc_client import ( - AsyncGrpcClient, +warnings.warn( + "google.cloud.storage._experimental.asyncio.async_multi_range_downloader has been moved to google.cloud.storage.asyncio.async_multi_range_downloader. " + "Please update your imports.", + DeprecationWarning, + stacklevel=2, ) - -from io import BytesIO -from google.cloud import _storage_v2 -from google.cloud.storage.exceptions import DataCorruption -from google.cloud.storage._helpers import generate_random_56_bit_integer - - -_MAX_READ_RANGES_PER_BIDI_READ_REQUEST = 100 - - -class Result: - """An instance of this class will be populated and retured for each - `read_range` provided to ``download_ranges`` method. - - """ - - def __init__(self, bytes_requested: int): - # only while instantiation, should not be edited later. - # hence there's no setter, only getter is provided. - self._bytes_requested: int = bytes_requested - self._bytes_written: int = 0 - - @property - def bytes_requested(self) -> int: - return self._bytes_requested - - @property - def bytes_written(self) -> int: - return self._bytes_written - - @bytes_written.setter - def bytes_written(self, value: int): - self._bytes_written = value - - def __repr__(self): - return f"bytes_requested: {self._bytes_requested}, bytes_written: {self._bytes_written}" - - -class AsyncMultiRangeDownloader: - """Provides an interface for downloading multiple ranges of a GCS ``Object`` - concurrently. - - Example usage: - - .. 
code-block:: python - - client = AsyncGrpcClient().grpc_client - mrd = await AsyncMultiRangeDownloader.create_mrd( - client, bucket_name="chandrasiri-rs", object_name="test_open9" - ) - my_buff1 = open('my_fav_file.txt', 'wb') - my_buff2 = BytesIO() - my_buff3 = BytesIO() - my_buff4 = any_object_which_provides_BytesIO_like_interface() - await mrd.download_ranges( - [ - # (start_byte, bytes_to_read, writeable_buffer) - (0, 100, my_buff1), - (100, 20, my_buff2), - (200, 123, my_buff3), - (300, 789, my_buff4), - ] - ) - - # verify data in buffers... - assert my_buff2.getbuffer().nbytes == 20 - - - """ - - @classmethod - async def create_mrd( - cls, - client: AsyncGrpcClient.grpc_client, - bucket_name: str, - object_name: str, - generation_number: Optional[int] = None, - read_handle: Optional[bytes] = None, - ) -> AsyncMultiRangeDownloader: - """Initializes a MultiRangeDownloader and opens the underlying bidi-gRPC - object for reading. - - :type client: :class:`~google.cloud.storage._experimental.asyncio.async_grpc_client.AsyncGrpcClient.grpc_client` - :param client: The asynchronous client to use for making API requests. - - :type bucket_name: str - :param bucket_name: The name of the bucket containing the object. - - :type object_name: str - :param object_name: The name of the object to be read. - - :type generation_number: int - :param generation_number: (Optional) If present, selects a specific - revision of this object. - - :type read_handle: bytes - :param read_handle: (Optional) An existing handle for reading the object. - If provided, opening the bidi-gRPC connection will be faster. - - :rtype: :class:`~google.cloud.storage._experimental.asyncio.async_multi_range_downloader.AsyncMultiRangeDownloader` - :returns: An initialized AsyncMultiRangeDownloader instance for reading. - """ - mrd = cls(client, bucket_name, object_name, generation_number, read_handle) - await mrd.open() - return mrd - - def __init__( - self, - client: AsyncGrpcClient.grpc_client, - bucket_name: str, - object_name: str, - generation_number: Optional[int] = None, - read_handle: Optional[bytes] = None, - ) -> None: - """Constructor for AsyncMultiRangeDownloader, clients are not adviced to - use it directly. Instead it's adviced to use the classmethod `create_mrd`. - - :type client: :class:`~google.cloud.storage._experimental.asyncio.async_grpc_client.AsyncGrpcClient.grpc_client` - :param client: The asynchronous client to use for making API requests. - - :type bucket_name: str - :param bucket_name: The name of the bucket containing the object. - - :type object_name: str - :param object_name: The name of the object to be read. - - :type generation_number: int - :param generation_number: (Optional) If present, selects a specific revision of - this object. - - :type read_handle: bytes - :param read_handle: (Optional) An existing read handle. - """ - - # Verify that the fast, C-accelerated version of crc32c is available. - # If not, raise an error to prevent silent performance degradation. - if google_crc32c.implementation != "c": - raise exceptions.NotFound( - "The google-crc32c package is not installed with C support. " - "Bidi reads require the C extension for data integrity checks." - "For more information, see https://github.com/googleapis/python-crc32c." 
- ) - - self.client = client - self.bucket_name = bucket_name - self.object_name = object_name - self.generation_number = generation_number - self.read_handle = read_handle - self.read_obj_str: Optional[_AsyncReadObjectStream] = None - self._is_stream_open: bool = False - - self._read_id_to_writable_buffer_dict = {} - self._read_id_to_download_ranges_id = {} - self._download_ranges_id_to_pending_read_ids = {} - self.persisted_size: Optional[int] = None # updated after opening the stream - - async def open(self) -> None: - """Opens the bidi-gRPC connection to read from the object. - - This method initializes and opens an `_AsyncReadObjectStream` (bidi-gRPC stream) to - for downloading ranges of data from GCS ``Object``. - - "Opening" constitutes fetching object metadata such as generation number - and read handle and sets them as attributes if not already set. - """ - if self._is_stream_open: - raise ValueError("Underlying bidi-gRPC stream is already open") - self.read_obj_str = _AsyncReadObjectStream( - client=self.client, - bucket_name=self.bucket_name, - object_name=self.object_name, - generation_number=self.generation_number, - read_handle=self.read_handle, - ) - await self.read_obj_str.open() - self._is_stream_open = True - if self.generation_number is None: - self.generation_number = self.read_obj_str.generation_number - self.read_handle = self.read_obj_str.read_handle - if self.read_obj_str.persisted_size is not None: - self.persisted_size = self.read_obj_str.persisted_size - return - - async def download_ranges( - self, read_ranges: List[Tuple[int, int, BytesIO]], lock: asyncio.Lock = None - ) -> None: - """Downloads multiple byte ranges from the object into the buffers - provided by user. - - :type read_ranges: List[Tuple[int, int, "BytesIO"]] - :param read_ranges: A list of tuples, where each tuple represents a - combintaion of byte_range and writeable buffer in format - - (`start_byte`, `bytes_to_read`, `writeable_buffer`). Buffer has - to be provided by the user, and user has to make sure appropriate - memory is available in the application to avoid out-of-memory crash. - - :type lock: asyncio.Lock - :param lock: (Optional) An asyncio lock to synchronize sends and recvs - on the underlying bidi-GRPC stream. This is required when multiple - coroutines are calling this method concurrently. - - i.e. Example usage with multiple coroutines: - - ``` - lock = asyncio.Lock() - task1 = asyncio.create_task(mrd.download_ranges(ranges1, lock)) - task2 = asyncio.create_task(mrd.download_ranges(ranges2, lock)) - await asyncio.gather(task1, task2) - - ``` - - If user want to call this method serially from multiple coroutines, - then providing a lock is not necessary. - - ``` - await mrd.download_ranges(ranges1) - await mrd.download_ranges(ranges2) - - # ... some other code code... - - ``` - - - :raises ValueError: if the underlying bidi-GRPC stream is not open. - :raises ValueError: if the length of read_ranges is more than 1000. - :raises DataCorruption: if a checksum mismatch is detected while reading data. 
- - """ - - if len(read_ranges) > 1000: - raise ValueError( - "Invalid input - length of read_ranges cannot be more than 1000" - ) - - if not self._is_stream_open: - raise ValueError("Underlying bidi-gRPC stream is not open") - - if lock is None: - lock = asyncio.Lock() - - _func_id = generate_random_56_bit_integer() - read_ids_in_current_func = set() - for i in range(0, len(read_ranges), _MAX_READ_RANGES_PER_BIDI_READ_REQUEST): - read_ranges_segment = read_ranges[ - i : i + _MAX_READ_RANGES_PER_BIDI_READ_REQUEST - ] - - read_ranges_for_bidi_req = [] - for j, read_range in enumerate(read_ranges_segment): - read_id = generate_random_56_bit_integer() - read_ids_in_current_func.add(read_id) - self._read_id_to_download_ranges_id[read_id] = _func_id - self._read_id_to_writable_buffer_dict[read_id] = read_range[2] - bytes_requested = read_range[1] - read_ranges_for_bidi_req.append( - _storage_v2.ReadRange( - read_offset=read_range[0], - read_length=bytes_requested, - read_id=read_id, - ) - ) - async with lock: - await self.read_obj_str.send( - _storage_v2.BidiReadObjectRequest( - read_ranges=read_ranges_for_bidi_req - ) - ) - self._download_ranges_id_to_pending_read_ids[ - _func_id - ] = read_ids_in_current_func - - while len(self._download_ranges_id_to_pending_read_ids[_func_id]) > 0: - async with lock: - response = await self.read_obj_str.recv() - - if response is None: - raise Exception("None response received, something went wrong.") - - for object_data_range in response.object_data_ranges: - if object_data_range.read_range is None: - raise Exception("Invalid response, read_range is None") - - checksummed_data = object_data_range.checksummed_data - data = checksummed_data.content - server_checksum = checksummed_data.crc32c - - client_crc32c = Checksum(data).digest() - client_checksum = int.from_bytes(client_crc32c, "big") - - if server_checksum != client_checksum: - raise DataCorruption( - response, - f"Checksum mismatch for read_id {object_data_range.read_range.read_id}. " - f"Server sent {server_checksum}, client calculated {client_checksum}.", - ) - - read_id = object_data_range.read_range.read_id - buffer = self._read_id_to_writable_buffer_dict[read_id] - buffer.write(data) - - if object_data_range.range_end: - tmp_dn_ranges_id = self._read_id_to_download_ranges_id[read_id] - self._download_ranges_id_to_pending_read_ids[ - tmp_dn_ranges_id - ].remove(read_id) - del self._read_id_to_download_ranges_id[read_id] - - async def close(self): - """ - Closes the underlying bidi-gRPC connection. - """ - if not self._is_stream_open: - raise ValueError("Underlying bidi-gRPC stream is not open") - await self.read_obj_str.close() - self.read_obj_str = None - self._is_stream_open = False - - @property - def is_stream_open(self) -> bool: - return self._is_stream_open diff --git a/google/cloud/storage/_experimental/asyncio/async_read_object_stream.py b/google/cloud/storage/_experimental/asyncio/async_read_object_stream.py index df2430ee6..cb39386f2 100644 --- a/google/cloud/storage/_experimental/asyncio/async_read_object_stream.py +++ b/google/cloud/storage/_experimental/asyncio/async_read_object_stream.py @@ -1,165 +1,11 @@ -# Copyright 2025 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -NOTE: -This is _experimental module for upcoming support for Rapid Storage. -(https://cloud.google.com/blog/products/storage-data-transfer/high-performance-storage-innovations-for-ai-hpc#:~:text=your%20AI%20workloads%3A-,Rapid%20Storage,-%3A%20A%20new) +import warnings -APIs may not work as intended and are not stable yet. Feature is not -GA(Generally Available) yet, please contact your TAM(Technical Account Manager) -if you want to use these APIs. +# Import everything from the new stable module +from google.cloud.storage.asyncio.async_read_object_stream import * # noqa -""" - -from typing import Optional -from google.cloud import _storage_v2 -from google.cloud.storage._experimental.asyncio.async_grpc_client import AsyncGrpcClient -from google.cloud.storage._experimental.asyncio.async_abstract_object_stream import ( - _AsyncAbstractObjectStream, +warnings.warn( + "google.cloud.storage._experimental.asyncio.async_read_object_stream has been moved to google.cloud.storage.asyncio.async_read_object_stream. " + "Please update your imports.", + DeprecationWarning, + stacklevel=2, ) - -from google.api_core.bidi_async import AsyncBidiRpc - - -class _AsyncReadObjectStream(_AsyncAbstractObjectStream): - """Class representing a gRPC bidi-stream for reading data from a GCS ``Object``. - - This class provides a unix socket-like interface to a GCS ``Object``, with - methods like ``open``, ``close``, ``send``, and ``recv``. - - :type client: :class:`~google.cloud.storage._experimental.asyncio.async_grpc_client.AsyncGrpcClient.grpc_client` - :param client: async grpc client to use for making API requests. - - :type bucket_name: str - :param bucket_name: The name of the GCS ``bucket`` containing the object. - - :type object_name: str - :param object_name: The name of the GCS ``object`` to be read. - - :type generation_number: int - :param generation_number: (Optional) If present, selects a specific revision of - this object. - - :type read_handle: bytes - :param read_handle: (Optional) An existing handle for reading the object. - If provided, opening the bidi-gRPC connection will be faster. 
- """ - - def __init__( - self, - client: AsyncGrpcClient.grpc_client, - bucket_name: str, - object_name: str, - generation_number: Optional[int] = None, - read_handle: Optional[bytes] = None, - ) -> None: - if client is None: - raise ValueError("client must be provided") - if bucket_name is None: - raise ValueError("bucket_name must be provided") - if object_name is None: - raise ValueError("object_name must be provided") - - super().__init__( - bucket_name=bucket_name, - object_name=object_name, - generation_number=generation_number, - ) - self.client: AsyncGrpcClient.grpc_client = client - self.read_handle: Optional[bytes] = read_handle - - self._full_bucket_name = f"projects/_/buckets/{self.bucket_name}" - - self.rpc = self.client._client._transport._wrapped_methods[ - self.client._client._transport.bidi_read_object - ] - self.metadata = (("x-goog-request-params", f"bucket={self._full_bucket_name}"),) - self.socket_like_rpc: Optional[AsyncBidiRpc] = None - self._is_stream_open: bool = False - self.persisted_size: Optional[int] = None - - async def open(self) -> None: - """Opens the bidi-gRPC connection to read from the object. - - This method sends an initial request to start the stream and receives - the first response containing metadata and a read handle. - """ - if self._is_stream_open: - raise ValueError("Stream is already open") - self.first_bidi_read_req = _storage_v2.BidiReadObjectRequest( - read_object_spec=_storage_v2.BidiReadObjectSpec( - bucket=self._full_bucket_name, - object=self.object_name, - read_handle=self.read_handle, - ), - ) - self.socket_like_rpc = AsyncBidiRpc( - self.rpc, initial_request=self.first_bidi_read_req, metadata=self.metadata - ) - await self.socket_like_rpc.open() # this is actually 1 send - response = await self.socket_like_rpc.recv() - # populated only in the first response of bidi-stream and when opened - # without using `read_handle` - if hasattr(response, "metadata") and response.metadata: - if self.generation_number is None: - self.generation_number = response.metadata.generation - # update persisted size - self.persisted_size = response.metadata.size - - self.read_handle = response.read_handle - - self._is_stream_open = True - - async def close(self) -> None: - """Closes the bidi-gRPC connection.""" - if not self._is_stream_open: - raise ValueError("Stream is not open") - await self.socket_like_rpc.close() - self._is_stream_open = False - - async def send( - self, bidi_read_object_request: _storage_v2.BidiReadObjectRequest - ) -> None: - """Sends a request message on the stream. - - Args: - bidi_read_object_request (:class:`~google.cloud._storage_v2.types.BidiReadObjectRequest`): - The request message to send. This is typically used to specify - the read offset and limit. - """ - if not self._is_stream_open: - raise ValueError("Stream is not open") - await self.socket_like_rpc.send(bidi_read_object_request) - - async def recv(self) -> _storage_v2.BidiReadObjectResponse: - """Receives a response from the stream. - - This method waits for the next message from the server, which could - contain object data or metadata. - - Returns: - :class:`~google.cloud._storage_v2.types.BidiReadObjectResponse`: - The response message from the server. 
- """ - if not self._is_stream_open: - raise ValueError("Stream is not open") - response = await self.socket_like_rpc.recv() - # Update read_handle if present in response - if response and response.read_handle: - self.read_handle = response.read_handle - return response - - @property - def is_stream_open(self) -> bool: - return self._is_stream_open diff --git a/google/cloud/storage/_experimental/asyncio/async_write_object_stream.py b/google/cloud/storage/_experimental/asyncio/async_write_object_stream.py index 183a8eeb1..132e2c9d0 100644 --- a/google/cloud/storage/_experimental/asyncio/async_write_object_stream.py +++ b/google/cloud/storage/_experimental/asyncio/async_write_object_stream.py @@ -1,188 +1,11 @@ -# Copyright 2025 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -NOTE: -This is _experimental module for upcoming support for Rapid Storage. -(https://cloud.google.com/blog/products/storage-data-transfer/high-performance-storage-innovations-for-ai-hpc#:~:text=your%20AI%20workloads%3A-,Rapid%20Storage,-%3A%20A%20new) +import warnings -APIs may not work as intended and are not stable yet. Feature is not -GA(Generally Available) yet, please contact your TAM(Technical Account Manager) -if you want to use these Rapid Storage APIs. +# Import everything from the new stable module +from google.cloud.storage.asyncio.async_write_object_stream import * # noqa -""" -from typing import Optional -from google.cloud import _storage_v2 -from google.cloud.storage._experimental.asyncio.async_grpc_client import AsyncGrpcClient -from google.cloud.storage._experimental.asyncio.async_abstract_object_stream import ( - _AsyncAbstractObjectStream, +warnings.warn( + "google.cloud.storage._experimental.asyncio.async_write_object_stream has been moved to google.cloud.storage.asyncio.async_write_object_stream. " + "Please update your imports.", + DeprecationWarning, + stacklevel=2, ) -from google.api_core.bidi_async import AsyncBidiRpc - - -class _AsyncWriteObjectStream(_AsyncAbstractObjectStream): - """Class representing a gRPC bidi-stream for writing data from a GCS - ``Appendable Object``. - - This class provides a unix socket-like interface to a GCS ``Object``, with - methods like ``open``, ``close``, ``send``, and ``recv``. - - :type client: :class:`~google.cloud.storage._experimental.asyncio.async_grpc_client.AsyncGrpcClient.grpc_client` - :param client: async grpc client to use for making API requests. - - :type bucket_name: str - :param bucket_name: The name of the GCS ``bucket`` containing the object. - - :type object_name: str - :param object_name: The name of the GCS ``Appendable Object`` to be write. - - :type generation_number: int - :param generation_number: (Optional) If present, selects a specific revision of - this object. If None, a new object is created. - - :type write_handle: bytes - :param write_handle: (Optional) An existing handle for writing the object. - If provided, opening the bidi-gRPC connection will be faster. 
- """ - - def __init__( - self, - client: AsyncGrpcClient.grpc_client, - bucket_name: str, - object_name: str, - generation_number: Optional[int] = None, # None means new object - write_handle: Optional[bytes] = None, - ) -> None: - if client is None: - raise ValueError("client must be provided") - if bucket_name is None: - raise ValueError("bucket_name must be provided") - if object_name is None: - raise ValueError("object_name must be provided") - - super().__init__( - bucket_name=bucket_name, - object_name=object_name, - generation_number=generation_number, - ) - self.client: AsyncGrpcClient.grpc_client = client - self.write_handle: Optional[bytes] = write_handle - - self._full_bucket_name = f"projects/_/buckets/{self.bucket_name}" - - self.rpc = self.client._client._transport._wrapped_methods[ - self.client._client._transport.bidi_write_object - ] - - self.metadata = (("x-goog-request-params", f"bucket={self._full_bucket_name}"),) - self.socket_like_rpc: Optional[AsyncBidiRpc] = None - self._is_stream_open: bool = False - self.first_bidi_write_req = None - self.persisted_size = 0 - self.object_resource: Optional[_storage_v2.Object] = None - - async def open(self) -> None: - """Opening an object for write , should do it's state lookup - to know what's the persisted size is. - """ - if self._is_stream_open: - raise ValueError("Stream is already open") - - # Create a new object or overwrite existing one if generation_number - # is None. This makes it consistent with GCS JSON API behavior. - # Created object type would be Appendable Object. - if self.generation_number is None: - self.first_bidi_write_req = _storage_v2.BidiWriteObjectRequest( - write_object_spec=_storage_v2.WriteObjectSpec( - resource=_storage_v2.Object( - name=self.object_name, bucket=self._full_bucket_name - ), - appendable=True, - ), - ) - else: - self.first_bidi_write_req = _storage_v2.BidiWriteObjectRequest( - append_object_spec=_storage_v2.AppendObjectSpec( - bucket=self._full_bucket_name, - object=self.object_name, - generation=self.generation_number, - ), - ) - - self.socket_like_rpc = AsyncBidiRpc( - self.rpc, initial_request=self.first_bidi_write_req, metadata=self.metadata - ) - - await self.socket_like_rpc.open() # this is actually 1 send - response = await self.socket_like_rpc.recv() - self._is_stream_open = True - - if not response.resource: - raise ValueError( - "Failed to obtain object resource after opening the stream" - ) - if not response.resource.generation: - raise ValueError( - "Failed to obtain object generation after opening the stream" - ) - - if not response.write_handle: - raise ValueError("Failed to obtain write_handle after opening the stream") - - if not response.resource.size: - # Appending to a 0 byte appendable object. - self.persisted_size = 0 - else: - self.persisted_size = response.resource.size - - self.generation_number = response.resource.generation - self.write_handle = response.write_handle - - async def close(self) -> None: - """Closes the bidi-gRPC connection.""" - if not self._is_stream_open: - raise ValueError("Stream is not open") - await self.socket_like_rpc.close() - self._is_stream_open = False - - async def send( - self, bidi_write_object_request: _storage_v2.BidiWriteObjectRequest - ) -> None: - """Sends a request message on the stream. - - Args: - bidi_write_object_request (:class:`~google.cloud._storage_v2.types.BidiReadObjectRequest`): - The request message to send. This is typically used to specify - the read offset and limit. 
- """ - if not self._is_stream_open: - raise ValueError("Stream is not open") - await self.socket_like_rpc.send(bidi_write_object_request) - - async def recv(self) -> _storage_v2.BidiWriteObjectResponse: - """Receives a response from the stream. - - This method waits for the next message from the server, which could - contain object data or metadata. - - Returns: - :class:`~google.cloud._storage_v2.types.BidiWriteObjectResponse`: - The response message from the server. - """ - if not self._is_stream_open: - raise ValueError("Stream is not open") - return await self.socket_like_rpc.recv() - - @property - def is_stream_open(self) -> bool: - return self._is_stream_open diff --git a/google/cloud/storage/_experimental/asyncio/retry/_helpers.py b/google/cloud/storage/_experimental/asyncio/retry/_helpers.py new file mode 100644 index 000000000..092986f58 --- /dev/null +++ b/google/cloud/storage/_experimental/asyncio/retry/_helpers.py @@ -0,0 +1,11 @@ +import warnings + +# Import everything from the new stable module +from google.cloud.storage.asyncio.retry._helpers import * # noqa + +warnings.warn( + "google.cloud.storage._experimental.asyncio.retry._helpers has been moved to google.cloud.storage.asyncio.retry._helpers. " + "Please update your imports.", + DeprecationWarning, + stacklevel=2, +) diff --git a/google/cloud/storage/_experimental/asyncio/retry/base_strategy.py b/google/cloud/storage/_experimental/asyncio/retry/base_strategy.py index e32125069..58c58136c 100644 --- a/google/cloud/storage/_experimental/asyncio/retry/base_strategy.py +++ b/google/cloud/storage/_experimental/asyncio/retry/base_strategy.py @@ -1,69 +1,11 @@ -import abc -from typing import Any, Iterable +import warnings +# Import everything from the new stable module +from google.cloud.storage.asyncio.retry.base_strategy import * # noqa -class _BaseResumptionStrategy(abc.ABC): - """Abstract base class defining the interface for a bidi stream resumption strategy. - - This class defines the skeleton for a pluggable strategy that contains - all the service-specific logic for a given bidi operation (e.g., reads - or writes). This allows a generic retry manager to handle the common - retry loop while sending the state management and request generation - to a concrete implementation of this class. - """ - - @abc.abstractmethod - def generate_requests(self, state: Any) -> Iterable[Any]: - """Generates the next batch of requests based on the current state. - - This method is called at the beginning of each retry attempt. It should - inspect the provided state object and generate the appropriate list of - request protos to send to the server. For example, a read strategy - would use this to implement "Smarter Resumption" by creating smaller - `ReadRange` requests for partially downloaded ranges. For bidi-writes, - it will set the `write_offset` field to the persisted size received - from the server in the next request. - - :type state: Any - :param state: An object containing all the state needed for the - operation (e.g., requested ranges, user buffers, - bytes written). - """ - pass - - @abc.abstractmethod - def update_state_from_response(self, response: Any, state: Any) -> None: - """Updates the state based on a successful server response. - - This method is called for every message received from the server. It is - responsible for processing the response and updating the shared state - object. - - :type response: Any - :param response: The response message received from the server. 
- - :type state: Any - :param state: The shared state object for the operation, which will be - mutated by this method. - """ - pass - - @abc.abstractmethod - async def recover_state_on_failure(self, error: Exception, state: Any) -> None: - """Prepares the state for the next retry attempt after a failure. - - This method is called when a retriable gRPC error occurs. It is - responsible for performing any necessary actions to ensure the next - retry attempt can succeed. For bidi reads, its primary role is to - handle the `BidiReadObjectRedirectError` by extracting the - `routing_token` and updating the state. For bidi writes, it will update - the state to reflect any bytes that were successfully persisted before - the failure. - - :type error: :class:`Exception` - :param error: The exception that was caught by the retry engine. - - :type state: Any - :param state: The shared state object for the operation. - """ - pass +warnings.warn( + "google.cloud.storage._experimental.asyncio.retry.base_strategy has been moved to google.cloud.storage.asyncio.retry.base_strategy. " + "Please update your imports.", + DeprecationWarning, + stacklevel=2, +) diff --git a/google/cloud/storage/_experimental/asyncio/retry/bidi_stream_retry_manager.py b/google/cloud/storage/_experimental/asyncio/retry/bidi_stream_retry_manager.py new file mode 100644 index 000000000..331ee1326 --- /dev/null +++ b/google/cloud/storage/_experimental/asyncio/retry/bidi_stream_retry_manager.py @@ -0,0 +1,11 @@ +import warnings + +# Import everything from the new stable module +from google.cloud.storage.asyncio.retry.bidi_stream_retry_manager import * # noqa + +warnings.warn( + "google.cloud.storage._experimental.asyncio.retry.bidi_stream_retry_manager has been moved to google.cloud.storage.asyncio.retry.bidi_stream_retry_manager. " + "Please update your imports.", + DeprecationWarning, + stacklevel=2, +) diff --git a/google/cloud/storage/_experimental/asyncio/retry/reads_resumption_strategy.py b/google/cloud/storage/_experimental/asyncio/retry/reads_resumption_strategy.py index d5d080358..8f7051b6a 100644 --- a/google/cloud/storage/_experimental/asyncio/retry/reads_resumption_strategy.py +++ b/google/cloud/storage/_experimental/asyncio/retry/reads_resumption_strategy.py @@ -1,85 +1,11 @@ -from typing import Any, List, IO +import warnings -from google.cloud import _storage_v2 as storage_v2 -from google.cloud.storage.exceptions import DataCorruption -from google.cloud.storage._experimental.asyncio.retry.base_strategy import ( - _BaseResumptionStrategy, -) -from google.cloud._storage_v2.types.storage import BidiReadObjectRedirectedError - - -class _DownloadState: - """A helper class to track the state of a single range download.""" - - def __init__( - self, initial_offset: int, initial_length: int, user_buffer: IO[bytes] - ): - self.initial_offset = initial_offset - self.initial_length = initial_length - self.user_buffer = user_buffer - self.bytes_written = 0 - self.next_expected_offset = initial_offset - self.is_complete = False - - -class _ReadResumptionStrategy(_BaseResumptionStrategy): - """The concrete resumption strategy for bidi reads.""" - - def generate_requests(self, state: dict) -> List[storage_v2.ReadRange]: - """Generates new ReadRange requests for all incomplete downloads. +# Import everything from the new stable module +from google.cloud.storage.asyncio.retry.reads_resumption_strategy import * # noqa - :type state: dict - :param state: A dictionary mapping a read_id to its corresponding - _DownloadState object. 
- """ - pending_requests = [] - for read_id, read_state in state.items(): - if not read_state.is_complete: - new_offset = read_state.initial_offset + read_state.bytes_written - new_length = read_state.initial_length - read_state.bytes_written - - new_request = storage_v2.ReadRange( - read_offset=new_offset, - read_length=new_length, - read_id=read_id, - ) - pending_requests.append(new_request) - return pending_requests - - def update_state_from_response( - self, response: storage_v2.BidiReadObjectResponse, state: dict - ) -> None: - """Processes a server response, performs integrity checks, and updates state.""" - for object_data_range in response.object_data_ranges: - read_id = object_data_range.read_range.read_id - read_state = state[read_id] - - # Offset Verification - chunk_offset = object_data_range.read_range.read_offset - if chunk_offset != read_state.next_expected_offset: - raise DataCorruption(response, f"Offset mismatch for read_id {read_id}") - - data = object_data_range.checksummed_data.content - chunk_size = len(data) - read_state.bytes_written += chunk_size - read_state.next_expected_offset += chunk_size - read_state.user_buffer.write(data) - - # Final Byte Count Verification - if object_data_range.range_end: - read_state.is_complete = True - if ( - read_state.initial_length != 0 - and read_state.bytes_written != read_state.initial_length - ): - raise DataCorruption( - response, f"Byte count mismatch for read_id {read_id}" - ) - - async def recover_state_on_failure(self, error: Exception, state: Any) -> None: - """Handles BidiReadObjectRedirectedError for reads.""" - # This would parse the gRPC error details, extract the routing_token, - # and store it on the shared state object. - cause = getattr(error, "cause", error) - if isinstance(cause, BidiReadObjectRedirectedError): - state["routing_token"] = cause.routing_token +warnings.warn( + "google.cloud.storage._experimental.asyncio.retry.reads_resumption_strategy has been moved to google.cloud.storage.asyncio.retry.reads_resumption_strategy. " + "Please update your imports.", + DeprecationWarning, + stacklevel=2, +) diff --git a/google/cloud/storage/_experimental/asyncio/retry/writes_resumption_strategy.py b/google/cloud/storage/_experimental/asyncio/retry/writes_resumption_strategy.py new file mode 100644 index 000000000..7d2493841 --- /dev/null +++ b/google/cloud/storage/_experimental/asyncio/retry/writes_resumption_strategy.py @@ -0,0 +1,11 @@ +import warnings + +# Import everything from the new stable module +from google.cloud.storage.asyncio.retry.writes_resumption_strategy import * # noqa + +warnings.warn( + "google.cloud.storage._experimental.asyncio.retry.writes_resumption_strategy has been moved to google.cloud.storage.asyncio.retry.writes_resumption_strategy. " + "Please update your imports.", + DeprecationWarning, + stacklevel=2, +) diff --git a/google/cloud/storage/_experimental/grpc_client.py b/google/cloud/storage/_experimental/grpc_client.py index 7a739b7b7..99ecbe044 100644 --- a/google/cloud/storage/_experimental/grpc_client.py +++ b/google/cloud/storage/_experimental/grpc_client.py @@ -1,122 +1,11 @@ -# Copyright 2025 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. +import warnings -"""A client for interacting with Google Cloud Storage using the gRPC API.""" +# Import everything from the new stable module +from google.cloud.storage.grpc_client import * # noqa -from google.cloud.client import ClientWithProject -from google.cloud import _storage_v2 as storage_v2 - -_marker = object() - - -class GrpcClient(ClientWithProject): - """A client for interacting with Google Cloud Storage using the gRPC API. - - :type project: str or None - :param project: The project which the client acts on behalf of. If not - passed, falls back to the default inferred from the - environment. - - :type credentials: :class:`~google.auth.credentials.Credentials` - :param credentials: (Optional) The OAuth2 Credentials to use for this - client. If not passed, falls back to the default - inferred from the environment. - - :type client_info: :class:`~google.api_core.client_info.ClientInfo` - :param client_info: - The client info used to send a user-agent string along with API - requests. If ``None``, then default info will be used. Generally, - you only need to set this if you're developing your own library - or partner tool. - - :type client_options: :class:`~google.api_core.client_options.ClientOptions` or :class:`dict` - :param client_options: (Optional) Client options used to set user options - on the client. A non-default universe domain or API endpoint should be - set through client_options. - - :type api_key: string - :param api_key: - (Optional) An API key. Mutually exclusive with any other credentials. - This parameter is an alias for setting `client_options.api_key` and - will supersede any API key set in the `client_options` parameter. - - :type attempt_direct_path: bool - :param attempt_direct_path: - (Optional) Whether to attempt to use DirectPath for gRPC connections. - This provides a direct, unproxied connection to GCS for lower latency - and higher throughput, and is highly recommended when running on Google - Cloud infrastructure. Defaults to ``True``. 
- """ - - def __init__( - self, - project=_marker, - credentials=None, - client_info=None, - client_options=None, - *, - api_key=None, - attempt_direct_path=True, - ): - super(GrpcClient, self).__init__(project=project, credentials=credentials) - - if isinstance(client_options, dict): - if api_key: - client_options["api_key"] = api_key - elif client_options is None: - client_options = {} if not api_key else {"api_key": api_key} - elif api_key: - client_options.api_key = api_key - - self._grpc_client = self._create_gapic_client( - credentials=credentials, - client_info=client_info, - client_options=client_options, - attempt_direct_path=attempt_direct_path, - ) - - def _create_gapic_client( - self, - credentials=None, - client_info=None, - client_options=None, - attempt_direct_path=True, - ): - """Creates and configures the low-level GAPIC `storage_v2` client.""" - transport_cls = storage_v2.StorageClient.get_transport_class("grpc") - - channel = transport_cls.create_channel(attempt_direct_path=attempt_direct_path) - - transport = transport_cls(credentials=credentials, channel=channel) - - return storage_v2.StorageClient( - credentials=credentials, - transport=transport, - client_info=client_info, - client_options=client_options, - ) - - @property - def grpc_client(self): - """The underlying gRPC client. - - This property gives users direct access to the `storage_v2.StorageClient` - instance. This can be useful for accessing - newly added or experimental RPCs that are not yet exposed through - the high-level GrpcClient. - - Returns: - google.cloud.storage_v2.StorageClient: The configured GAPIC client. - """ - return self._grpc_client +warnings.warn( + "google.cloud.storage._experimental.grpc_client has been moved to google.cloud.storage.grpc_client. " + "Please update your imports.", + DeprecationWarning, + stacklevel=2, +) diff --git a/google/cloud/storage/_media/requests/download.py b/google/cloud/storage/_media/requests/download.py index 13e049bd3..c5686fcb7 100644 --- a/google/cloud/storage/_media/requests/download.py +++ b/google/cloud/storage/_media/requests/download.py @@ -774,6 +774,5 @@ def flush(self): def has_unconsumed_tail(self) -> bool: return self._decoder.has_unconsumed_tail - else: # pragma: NO COVER _BrotliDecoder = None # type: ignore # pragma: NO COVER diff --git a/google/cloud/storage/abstracts/base_client.py b/google/cloud/storage/abstracts/base_client.py index c2030cb89..ce89a8bec 100644 --- a/google/cloud/storage/abstracts/base_client.py +++ b/google/cloud/storage/abstracts/base_client.py @@ -30,6 +30,7 @@ marker = object() + class BaseClient(ClientWithProject, ABC): """Abstract class for python-storage Client""" @@ -248,7 +249,7 @@ def _connection(self, value): """ if self._base_connection is not None: raise ValueError("Connection already set on client") - self._base_connection = value + self._base_connection = value @property def _use_client_cert(self): @@ -260,9 +261,7 @@ def _use_client_cert(self): if hasattr(mtls, "should_use_client_cert"): use_client_cert = mtls.should_use_client_cert() else: - use_client_cert = ( - os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE") == "true" - ) + use_client_cert = os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE") == "true" return use_client_cert def _push_batch(self, batch): diff --git a/google/cloud/storage/asyncio/_utils.py b/google/cloud/storage/asyncio/_utils.py new file mode 100644 index 000000000..170a0cfae --- /dev/null +++ b/google/cloud/storage/asyncio/_utils.py @@ -0,0 +1,41 @@ +# Copyright 2025 Google LLC +# +# Licensed 
under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import google_crc32c + +from google.api_core import exceptions + + +def raise_if_no_fast_crc32c(): + """Check if the C-accelerated version of google-crc32c is available. + + If not, raise an error to prevent silent performance degradation. + + raises google.api_core.exceptions.FailedPrecondition: If the C extension is not available. + returns: True if the C extension is available. + rtype: bool + + """ + if google_crc32c.implementation != "c": + raise exceptions.FailedPrecondition( + "The google-crc32c package is not installed with C support. " + "C extension is required for faster data integrity checks." + "For more information, see https://github.com/googleapis/python-crc32c." + ) + + +def update_write_handle_if_exists(obj, response): + """Update the write_handle attribute of an object if it exists in the response.""" + if hasattr(response, "write_handle") and response.write_handle is not None: + obj.write_handle = response.write_handle diff --git a/google/cloud/storage/asyncio/async_abstract_object_stream.py b/google/cloud/storage/asyncio/async_abstract_object_stream.py new file mode 100644 index 000000000..26cbab7a0 --- /dev/null +++ b/google/cloud/storage/asyncio/async_abstract_object_stream.py @@ -0,0 +1,67 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import abc +from typing import Any, Optional + + +class _AsyncAbstractObjectStream(abc.ABC): + """Abstract base class to represent gRPC bidi-stream for GCS ``Object``. + + Concrete implementation of this class could be ``_AsyncReadObjectStream`` + or ``_AsyncWriteObjectStream``. + + :type bucket_name: str + :param bucket_name: (Optional) The name of the bucket containing the object. + + :type object_name: str + :param object_name: (Optional) The name of the object. + + :type generation_number: int + :param generation_number: (Optional) If present, selects a specific revision of + this object. + + :type handle: Any + :param handle: (Optional) The handle for the object, could be read_handle or + write_handle, based on how the stream is used. 
+ """ + + def __init__( + self, + bucket_name: str, + object_name: str, + generation_number: Optional[int] = None, + handle: Optional[Any] = None, + ) -> None: + super().__init__() + self.bucket_name: str = bucket_name + self.object_name: str = object_name + self.generation_number: Optional[int] = generation_number + self.handle: Optional[Any] = handle + + @abc.abstractmethod + async def open(self) -> None: + pass + + @abc.abstractmethod + async def close(self) -> None: + pass + + @abc.abstractmethod + async def send(self, protobuf: Any) -> None: + pass + + @abc.abstractmethod + async def recv(self) -> Any: + pass diff --git a/google/cloud/storage/asyncio/async_appendable_object_writer.py b/google/cloud/storage/asyncio/async_appendable_object_writer.py new file mode 100644 index 000000000..3ab06f8ba --- /dev/null +++ b/google/cloud/storage/asyncio/async_appendable_object_writer.py @@ -0,0 +1,586 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from io import BufferedReader +import io +import logging +from typing import List, Optional, Tuple, Union + +from google.api_core import exceptions +from google.api_core.retry_async import AsyncRetry +from google.rpc import status_pb2 +from google.cloud._storage_v2.types import BidiWriteObjectRedirectedError +from google.cloud._storage_v2.types.storage import BidiWriteObjectRequest + + +from . 
import _utils +from google.cloud import _storage_v2 +from google.cloud.storage.asyncio.async_grpc_client import ( + AsyncGrpcClient, +) +from google.cloud.storage.asyncio.async_write_object_stream import ( + _AsyncWriteObjectStream, +) +from google.cloud.storage.asyncio.retry.bidi_stream_retry_manager import ( + _BidiStreamRetryManager, +) +from google.cloud.storage.asyncio.retry.writes_resumption_strategy import ( + _WriteResumptionStrategy, + _WriteState, +) +from google.cloud.storage.asyncio.retry._helpers import ( + _extract_bidi_writes_redirect_proto, +) + + +_MAX_CHUNK_SIZE_BYTES = 2 * 1024 * 1024 # 2 MiB +_DEFAULT_FLUSH_INTERVAL_BYTES = 16 * 1024 * 1024 # 16 MiB +_BIDI_WRITE_REDIRECTED_TYPE_URL = ( + "type.googleapis.com/google.storage.v2.BidiWriteObjectRedirectedError" +) +logger = logging.getLogger(__name__) + + +def _is_write_retryable(exc): + """Predicate to determine if a write operation should be retried.""" + + if isinstance( + exc, + ( + exceptions.InternalServerError, + exceptions.ServiceUnavailable, + exceptions.DeadlineExceeded, + exceptions.TooManyRequests, + BidiWriteObjectRedirectedError, + ), + ): + logger.warning(f"Retryable write exception encountered: {exc}") + return True + + grpc_error = None + if isinstance(exc, exceptions.Aborted) and exc.errors: + grpc_error = exc.errors[0] + if isinstance(grpc_error, BidiWriteObjectRedirectedError): + return True + + trailers = grpc_error.trailing_metadata() + if not trailers: + return False + + status_details_bin = None + for key, value in trailers: + if key == "grpc-status-details-bin": + status_details_bin = value + break + + if status_details_bin: + status_proto = status_pb2.Status() + try: + status_proto.ParseFromString(status_details_bin) + for detail in status_proto.details: + if detail.type_url == _BIDI_WRITE_REDIRECTED_TYPE_URL: + return True + except Exception: + logger.error( + "Error unpacking redirect details from gRPC error. Exception: ", + {exc}, + ) + return False + return False + + +class AsyncAppendableObjectWriter: + """Class for appending data to a GCS Appendable Object asynchronously.""" + + def __init__( + self, + client: AsyncGrpcClient, + bucket_name: str, + object_name: str, + generation: Optional[int] = None, + write_handle: Optional[_storage_v2.BidiWriteHandle] = None, + writer_options: Optional[dict] = None, + ): + """ + Class for appending data to a GCS Appendable Object. + + Example usage: + + ``` + + from google.cloud.storage.asyncio.async_grpc_client import AsyncGrpcClient + from google.cloud.storage.asyncio.async_appendable_object_writer import AsyncAppendableObjectWriter + import asyncio + + client = AsyncGrpcClient() + bucket_name = "my-bucket" + object_name = "my-appendable-object" + + # instantiate the writer + writer = AsyncAppendableObjectWriter(client, bucket_name, object_name) + # open the writer, (underlying gRPC bidi-stream will be opened) + await writer.open() + + # append data, it can be called multiple times. + await writer.append(b"hello world") + await writer.append(b"some more data") + + # optionally flush data to persist. + await writer.flush() + + # close the gRPC stream. + # Please note closing the program will also close the stream, + # however it's recommended to close the stream if no more data to append + # to clean up gRPC connection (which means CPU/memory/network resources) + await writer.close() + ``` + + :type client: :class:`~google.cloud.storage.asyncio.async_grpc_client.AsyncGrpcClient` + :param client: async grpc client to use for making API requests. 
+ + :type bucket_name: str + :param bucket_name: The name of the GCS bucket containing the object. + + :type object_name: str + :param object_name: The name of the GCS Appendable Object to be written. + + :type generation: Optional[int] + :param generation: (Optional) If present, creates writer for that + specific revision of that object. Use this to append data to an + existing Appendable Object. + + Setting to ``0`` makes the `writer.open()` succeed only if + object doesn't exist in the bucket (useful for not accidentally + overwriting existing objects). + + Warning: If `None`, a new object is created. If an object with the + same name already exists, it will be overwritten the moment + `writer.open()` is called. + + :type write_handle: _storage_v2.BidiWriteHandle + :param write_handle: (Optional) An handle for writing the object. + If provided, opening the bidi-gRPC connection will be faster. + + :type writer_options: dict + :param writer_options: (Optional) A dictionary of writer options. + Supported options: + - "FLUSH_INTERVAL_BYTES": int + The number of bytes to append before "persisting" data in GCS + servers. Default is `_DEFAULT_FLUSH_INTERVAL_BYTES`. + Must be a multiple of `_MAX_CHUNK_SIZE_BYTES`. + """ + _utils.raise_if_no_fast_crc32c() + self.client = client + self.bucket_name = bucket_name + self.object_name = object_name + self.write_handle = write_handle + self.generation = generation + + self.write_obj_stream: Optional[_AsyncWriteObjectStream] = None + self._is_stream_open: bool = False + # `offset` is the latest size of the object without staleless. + self.offset: Optional[int] = None + # `persisted_size` is the total_bytes persisted in the GCS server. + # Please note: `offset` and `persisted_size` are same when the stream is + # opened. + self.persisted_size: Optional[int] = None + if writer_options is None: + writer_options = {} + self.flush_interval = writer_options.get( + "FLUSH_INTERVAL_BYTES", _DEFAULT_FLUSH_INTERVAL_BYTES + ) + if self.flush_interval < _MAX_CHUNK_SIZE_BYTES: + raise exceptions.OutOfRange( + f"flush_interval must be >= {_MAX_CHUNK_SIZE_BYTES} , but provided {self.flush_interval}" + ) + if self.flush_interval % _MAX_CHUNK_SIZE_BYTES != 0: + raise exceptions.OutOfRange( + f"flush_interval must be a multiple of {_MAX_CHUNK_SIZE_BYTES}, but provided {self.flush_interval}" + ) + self.bytes_appended_since_last_flush = 0 + self._routing_token: Optional[str] = None + self.object_resource: Optional[_storage_v2.Object] = None + + async def state_lookup(self) -> int: + """Returns the persisted_size + + :rtype: int + :returns: persisted size. + + :raises ValueError: If the stream is not open (i.e., `open()` has not + been called). + """ + if not self._is_stream_open: + raise ValueError("Stream is not open. 
Call open() before state_lookup().") + + await self.write_obj_stream.send( + _storage_v2.BidiWriteObjectRequest( + state_lookup=True, + ) + ) + response = await self.write_obj_stream.recv() + self.persisted_size = response.persisted_size + return self.persisted_size + + def _on_open_error(self, exc): + """Extracts routing token and write handle on redirect error during open.""" + redirect_proto = _extract_bidi_writes_redirect_proto(exc) + if redirect_proto: + if redirect_proto.routing_token: + self._routing_token = redirect_proto.routing_token + if redirect_proto.write_handle: + self.write_handle = redirect_proto.write_handle + if redirect_proto.generation: + self.generation = redirect_proto.generation + + async def open( + self, + retry_policy: Optional[AsyncRetry] = None, + metadata: Optional[List[Tuple[str, str]]] = None, + ) -> None: + """Opens the underlying bidi-gRPC stream. + + :raises ValueError: If the stream is already open. + + """ + if self._is_stream_open: + raise ValueError("Underlying bidi-gRPC stream is already open") + + if retry_policy is None: + retry_policy = AsyncRetry( + predicate=_is_write_retryable, on_error=self._on_open_error + ) + else: + original_on_error = retry_policy._on_error + + def combined_on_error(exc): + self._on_open_error(exc) + if original_on_error: + original_on_error(exc) + + retry_policy = AsyncRetry( + predicate=_is_write_retryable, + initial=retry_policy._initial, + maximum=retry_policy._maximum, + multiplier=retry_policy._multiplier, + deadline=retry_policy._deadline, + on_error=combined_on_error, + ) + + async def _do_open(): + current_metadata = list(metadata) if metadata else [] + + # Cleanup stream from previous failed attempt, if any. + if self.write_obj_stream: + if self.write_obj_stream.is_stream_open: + try: + await self.write_obj_stream.close() + except Exception as e: + logger.warning( + "Error closing previous write stream during open retry. Got exception: ", + {e}, + ) + self.write_obj_stream = None + self._is_stream_open = False + + self.write_obj_stream = _AsyncWriteObjectStream( + client=self.client.grpc_client, + bucket_name=self.bucket_name, + object_name=self.object_name, + generation_number=self.generation, + write_handle=self.write_handle, + routing_token=self._routing_token, + ) + + if self._routing_token: + current_metadata.append( + ("x-goog-request-params", f"routing_token={self._routing_token}") + ) + + await self.write_obj_stream.open( + metadata=current_metadata if metadata else None + ) + + if self.write_obj_stream.generation_number: + self.generation = self.write_obj_stream.generation_number + if self.write_obj_stream.write_handle: + self.write_handle = self.write_obj_stream.write_handle + if self.write_obj_stream.persisted_size is not None: + self.persisted_size = self.write_obj_stream.persisted_size + + self._is_stream_open = True + self._routing_token = None + + await retry_policy(_do_open)() + + async def append( + self, + data: bytes, + retry_policy: Optional[AsyncRetry] = None, + metadata: Optional[List[Tuple[str, str]]] = None, + ) -> None: + """Appends data to the Appendable object with automatic retries. + + calling `self.append` will append bytes at the end of the current size + ie. `self.offset` bytes relative to the begining of the object. + + This method sends the provided `data` to the GCS server in chunks. + and persists data in GCS at every `_DEFAULT_FLUSH_INTERVAL_BYTES` bytes + or at the last chunk whichever is earlier. Persisting is done by setting + `flush=True` on request. 
+ + :type data: bytes + :param data: The bytes to append to the object. + + :type retry_policy: :class:`~google.api_core.retry_async.AsyncRetry` + :param retry_policy: (Optional) The retry policy to use for the operation. + + :type metadata: List[Tuple[str, str]] + :param metadata: (Optional) The metadata to be sent with the request. + + :raises ValueError: If the stream is not open. + """ + if not self._is_stream_open: + raise ValueError("Stream is not open. Call open() before append().") + if not data: + logger.debug("No data provided to append; returning without action.") + return + + if retry_policy is None: + retry_policy = AsyncRetry(predicate=_is_write_retryable) + + strategy = _WriteResumptionStrategy() + buffer = io.BytesIO(data) + attempt_count = 0 + + def send_and_recv_generator( + requests: List[BidiWriteObjectRequest], + state: dict[str, _WriteState], + metadata: Optional[List[Tuple[str, str]]] = None, + ): + async def generator(): + nonlocal attempt_count + nonlocal requests + attempt_count += 1 + resp = None + write_state = state["write_state"] + # If this is a retry or redirect, we must re-open the stream + if attempt_count > 1 or write_state.routing_token: + logger.info( + f"Re-opening the stream with attempt_count: {attempt_count}" + ) + if self.write_obj_stream and self.write_obj_stream.is_stream_open: + await self.write_obj_stream.close() + + current_metadata = list(metadata) if metadata else [] + if write_state.routing_token: + current_metadata.append( + ( + "x-goog-request-params", + f"routing_token={write_state.routing_token}", + ) + ) + self._routing_token = write_state.routing_token + + self._is_stream_open = False + await self.open(metadata=current_metadata) + + write_state.persisted_size = self.persisted_size + write_state.write_handle = self.write_handle + write_state.routing_token = None + + write_state.user_buffer.seek(write_state.persisted_size) + write_state.bytes_sent = write_state.persisted_size + write_state.bytes_since_last_flush = 0 + + requests = strategy.generate_requests(state) + + num_requests = len(requests) + for i, chunk_req in enumerate(requests): + if i == num_requests - 1: + chunk_req.state_lookup = True + chunk_req.flush = True + await self.write_obj_stream.send(chunk_req) + + resp = await self.write_obj_stream.recv() + if resp: + if resp.persisted_size is not None: + self.persisted_size = resp.persisted_size + state["write_state"].persisted_size = resp.persisted_size + self.offset = self.persisted_size + if resp.write_handle: + self.write_handle = resp.write_handle + state["write_state"].write_handle = resp.write_handle + self.bytes_appended_since_last_flush = 0 + + yield resp + + return generator() + + # State initialization + write_state = _WriteState(_MAX_CHUNK_SIZE_BYTES, buffer, self.flush_interval) + write_state.write_handle = self.write_handle + write_state.persisted_size = self.persisted_size + write_state.bytes_sent = self.persisted_size + write_state.bytes_since_last_flush = self.bytes_appended_since_last_flush + + retry_manager = _BidiStreamRetryManager( + _WriteResumptionStrategy(), + lambda r, s: send_and_recv_generator(r, s, metadata), + ) + await retry_manager.execute({"write_state": write_state}, retry_policy) + + # Sync local markers + self.write_obj_stream.persisted_size = write_state.persisted_size + self.write_obj_stream.write_handle = write_state.write_handle + self.bytes_appended_since_last_flush = write_state.bytes_since_last_flush + self.persisted_size = write_state.persisted_size + self.offset = 
write_state.persisted_size + + async def simple_flush(self) -> None: + """Flushes the data to the server. + Please note: Unlike `flush` it does not do `state_lookup` + + :rtype: None + + :raises ValueError: If the stream is not open (i.e., `open()` has not + been called). + """ + if not self._is_stream_open: + raise ValueError("Stream is not open. Call open() before simple_flush().") + + await self.write_obj_stream.send( + _storage_v2.BidiWriteObjectRequest( + flush=True, + ) + ) + self.bytes_appended_since_last_flush = 0 + + async def flush(self) -> int: + """Flushes the data to the server. + + :rtype: int + :returns: The persisted size after flush. + + :raises ValueError: If the stream is not open (i.e., `open()` has not + been called). + """ + if not self._is_stream_open: + raise ValueError("Stream is not open. Call open() before flush().") + + await self.write_obj_stream.send( + _storage_v2.BidiWriteObjectRequest( + flush=True, + state_lookup=True, + ) + ) + response = await self.write_obj_stream.recv() + self.persisted_size = response.persisted_size + self.offset = self.persisted_size + self.bytes_appended_since_last_flush = 0 + return self.persisted_size + + async def close(self, finalize_on_close=False) -> Union[int, _storage_v2.Object]: + """Closes the underlying bidi-gRPC stream. + + :type finalize_on_close: bool + :param finalize_on_close: Finalizes the Appendable Object. No more data + can be appended. + + rtype: Union[int, _storage_v2.Object] + returns: Updated `self.persisted_size` by default after closing the + bidi-gRPC stream. However, if `finalize_on_close=True` is passed, + returns the finalized object resource. + + :raises ValueError: If the stream is not open (i.e., `open()` has not + been called). + + """ + if not self._is_stream_open: + raise ValueError("Stream is not open. Call open() before close().") + + if finalize_on_close: + return await self.finalize() + + await self.write_obj_stream.close() + + self._is_stream_open = False + self.offset = None + return self.persisted_size + + async def finalize(self) -> _storage_v2.Object: + """Finalizes the Appendable Object. + + Note: Once finalized no more data can be appended. + This method is different from `close`. if `.close()` is called data may + still be appended to object at a later point in time by opening with + generation number. + (i.e. `open(..., generation=)`. + However if `.finalize()` is called no more data can be appended to the + object. + + rtype: google.cloud.storage_v2.types.Object + returns: The finalized object resource. + + :raises ValueError: If the stream is not open (i.e., `open()` has not + been called). + """ + if not self._is_stream_open: + raise ValueError("Stream is not open. Call open() before finalize().") + + await self.write_obj_stream.send( + _storage_v2.BidiWriteObjectRequest(finish_write=True) + ) + response = await self.write_obj_stream.recv() + self.object_resource = response.resource + self.persisted_size = self.object_resource.size + await self.write_obj_stream.close() + + self._is_stream_open = False + self.offset = None + return self.object_resource + + @property + def is_stream_open(self) -> bool: + return self._is_stream_open + + # helper methods. 
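The `close()` and `finalize()` docstrings above draw a distinction that is easy to miss: a closed appendable object can be reopened and extended by generation, while a finalized one cannot. A minimal sketch of that flow, assuming the writer API exactly as added in this diff and using placeholder bucket and object names:

```python
import asyncio

from google.cloud.storage.asyncio.async_grpc_client import AsyncGrpcClient
from google.cloud.storage.asyncio.async_appendable_object_writer import (
    AsyncAppendableObjectWriter,
)


async def main():
    client = AsyncGrpcClient()

    # First session: create the appendable object and persist some data.
    writer = AsyncAppendableObjectWriter(client, "my-bucket", "my-appendable-object")
    await writer.open()
    await writer.append(b"first batch of bytes")
    persisted = await writer.flush()  # flush() also refreshes persisted_size
    generation = writer.generation
    await writer.close()  # object stays appendable

    # Second session: reopen the same generation, append more, then seal it.
    writer = AsyncAppendableObjectWriter(
        client, "my-bucket", "my-appendable-object", generation=generation
    )
    await writer.open()
    await writer.append(b"last bytes")
    resource = await writer.finalize()  # no further appends are possible
    print(persisted, resource.size)


asyncio.run(main())
```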
+ async def append_from_string(self, data: str): + """ + str data will be encoded to bytes using utf-8 encoding calling + + self.append(data.encode("utf-8")) + """ + raise NotImplementedError("append_from_string is not implemented yet.") + + async def append_from_stream(self, stream_obj): + """ + At a time read a chunk of data (16MiB) from `stream_obj` + and call self.append(chunk) + """ + raise NotImplementedError("append_from_stream is not implemented yet.") + + async def append_from_file( + self, file_obj: BufferedReader, block_size: int = _DEFAULT_FLUSH_INTERVAL_BYTES + ): + """ + Appends data to an Appendable Object using file_handle which is opened + for reading in binary mode. + + :type file_obj: file + :param file_obj: A file handle opened in binary mode for reading. + + """ + while block := file_obj.read(block_size): + await self.append(block) diff --git a/google/cloud/storage/asyncio/async_grpc_client.py b/google/cloud/storage/asyncio/async_grpc_client.py new file mode 100644 index 000000000..640e7fe38 --- /dev/null +++ b/google/cloud/storage/asyncio/async_grpc_client.py @@ -0,0 +1,223 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""An async client for interacting with Google Cloud Storage using the gRPC API.""" + +from google.cloud import _storage_v2 as storage_v2 +from google.cloud._storage_v2.services.storage.transports.base import ( + DEFAULT_CLIENT_INFO, +) +from google.cloud.storage import __version__ + + +class AsyncGrpcClient: + """An asynchronous client for interacting with Google Cloud Storage using the gRPC API. + + :type credentials: :class:`~google.auth.credentials.Credentials` + :param credentials: (Optional) The OAuth2 Credentials to use for this + client. If not passed, falls back to the default + inferred from the environment. + + :type client_info: :class:`~google.api_core.client_info.ClientInfo` + :param client_info: + The client info used to send a user-agent string along with API + requests. If ``None``, then default info will be used. + + :type client_options: :class:`~google.api_core.client_options.ClientOptions` + :param client_options: (Optional) Client options used to set user options + on the client. + + :type attempt_direct_path: bool + :param attempt_direct_path: + (Optional) Whether to attempt to use DirectPath for gRPC connections. + Defaults to ``True``. 
+ """ + + def __init__( + self, + credentials=None, + client_info=None, + client_options=None, + *, + attempt_direct_path=True, + ): + if client_info is None: + client_info = DEFAULT_CLIENT_INFO + client_info.client_library_version = __version__ + if client_info.user_agent is None: + client_info.user_agent = "" + agent_version = f"gcloud-python/{__version__}" + if agent_version not in client_info.user_agent: + client_info.user_agent += f" {agent_version} " + + self._grpc_client = self._create_async_grpc_client( + credentials=credentials, + client_info=client_info, + client_options=client_options, + attempt_direct_path=attempt_direct_path, + ) + + def _create_async_grpc_client( + self, + credentials=None, + client_info=None, + client_options=None, + attempt_direct_path=True, + ): + transport_cls = storage_v2.StorageAsyncClient.get_transport_class( + "grpc_asyncio" + ) + + primary_user_agent = client_info.to_user_agent() + + channel = transport_cls.create_channel( + attempt_direct_path=attempt_direct_path, + credentials=credentials, + options=(("grpc.primary_user_agent", primary_user_agent),), + ) + transport = transport_cls(channel=channel) + + return storage_v2.StorageAsyncClient( + transport=transport, + client_info=client_info, + client_options=client_options, + ) + + @property + def grpc_client(self): + """The underlying gRPC client. + + This property gives users direct access to the `_storage_v2.StorageAsyncClient` + instance. This can be useful for accessing + newly added or experimental RPCs that are not yet exposed through + the high-level GrpcClient. + Returns: + google.cloud._storage_v2.StorageAsyncClient: The configured GAPIC client. + """ + return self._grpc_client + + async def delete_object( + self, + bucket_name, + object_name, + generation=None, + if_generation_match=None, + if_generation_not_match=None, + if_metageneration_match=None, + if_metageneration_not_match=None, + **kwargs, + ): + """Deletes an object and its metadata. + + :type bucket_name: str + :param bucket_name: The name of the bucket in which the object resides. + + :type object_name: str + :param object_name: The name of the object to delete. + + :type generation: int + :param generation: + (Optional) If present, permanently deletes a specific generation + of an object. + + :type if_generation_match: int + :param if_generation_match: (Optional) + + :type if_generation_not_match: int + :param if_generation_not_match: (Optional) + + :type if_metageneration_match: int + :param if_metageneration_match: (Optional) + + :type if_metageneration_not_match: int + :param if_metageneration_not_match: (Optional) + + + """ + # The gRPC API requires the bucket name to be in the format "projects/_/buckets/bucket_name" + bucket_path = f"projects/_/buckets/{bucket_name}" + request = storage_v2.DeleteObjectRequest( + bucket=bucket_path, + object=object_name, + generation=generation, + if_generation_match=if_generation_match, + if_generation_not_match=if_generation_not_match, + if_metageneration_match=if_metageneration_match, + if_metageneration_not_match=if_metageneration_not_match, + **kwargs, + ) + await self._grpc_client.delete_object(request=request) + + async def get_object( + self, + bucket_name, + object_name, + generation=None, + if_generation_match=None, + if_generation_not_match=None, + if_metageneration_match=None, + if_metageneration_not_match=None, + soft_deleted=None, + **kwargs, + ): + """Retrieves an object's metadata. 
+ + In the gRPC API, this is performed by the GetObject RPC, which + returns the object resource (metadata) without the object's data. + + :type bucket_name: str + :param bucket_name: The name of the bucket in which the object resides. + + :type object_name: str + :param object_name: The name of the object. + + :type generation: int + :param generation: + (Optional) If present, selects a specific generation of an object. + + :type if_generation_match: int + :param if_generation_match: (Optional) Precondition for object generation match. + + :type if_generation_not_match: int + :param if_generation_not_match: (Optional) Precondition for object generation mismatch. + + :type if_metageneration_match: int + :param if_metageneration_match: (Optional) Precondition for metageneration match. + + :type if_metageneration_not_match: int + :param if_metageneration_not_match: (Optional) Precondition for metageneration mismatch. + + :type soft_deleted: bool + :param soft_deleted: + (Optional) If True, return the soft-deleted version of this object. + + :rtype: :class:`google.cloud._storage_v2.types.Object` + :returns: The object metadata resource. + """ + bucket_path = f"projects/_/buckets/{bucket_name}" + + request = storage_v2.GetObjectRequest( + bucket=bucket_path, + object=object_name, + generation=generation, + if_generation_match=if_generation_match, + if_generation_not_match=if_generation_not_match, + if_metageneration_match=if_metageneration_match, + if_metageneration_not_match=if_metageneration_not_match, + soft_deleted=soft_deleted or False, + **kwargs, + ) + + # Calls the underlying GAPIC StorageAsyncClient.get_object method + return await self._grpc_client.get_object(request=request) diff --git a/google/cloud/storage/asyncio/async_multi_range_downloader.py b/google/cloud/storage/asyncio/async_multi_range_downloader.py new file mode 100644 index 000000000..3ee773a04 --- /dev/null +++ b/google/cloud/storage/asyncio/async_multi_range_downloader.py @@ -0,0 +1,526 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
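For the `get_object` and `delete_object` helpers added to `AsyncGrpcClient` above, a minimal usage sketch (bucket and object names are placeholders; `get_object` returns only the metadata resource, not the object data):

```python
import asyncio

from google.cloud.storage.asyncio.async_grpc_client import AsyncGrpcClient


async def main():
    client = AsyncGrpcClient()

    # Fetch the object resource (metadata only).
    obj = await client.get_object("my-bucket", "my-object")
    print(obj.name, obj.size, obj.generation)

    # Delete that exact generation, guarded by a matching-generation precondition.
    await client.delete_object(
        "my-bucket",
        "my-object",
        generation=obj.generation,
        if_generation_match=obj.generation,
    )


asyncio.run(main())
```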
+ +from __future__ import annotations +import asyncio +import logging +from google.api_core import exceptions +from google.api_core.retry_async import AsyncRetry +from google.cloud.storage.asyncio.retry._helpers import _handle_redirect +from google.rpc import status_pb2 + +from typing import List, Optional, Tuple, Any, Dict + +from ._utils import raise_if_no_fast_crc32c +from google.cloud.storage.asyncio.async_read_object_stream import ( + _AsyncReadObjectStream, +) +from google.cloud.storage.asyncio.async_grpc_client import ( + AsyncGrpcClient, +) +from google.cloud.storage.asyncio.retry.bidi_stream_retry_manager import ( + _BidiStreamRetryManager, +) +from google.cloud.storage.asyncio.retry.reads_resumption_strategy import ( + _ReadResumptionStrategy, + _DownloadState, +) + +from io import BytesIO +from google.cloud import _storage_v2 +from google.cloud.storage._helpers import generate_random_56_bit_integer + + +_MAX_READ_RANGES_PER_BIDI_READ_REQUEST = 100 +_BIDI_READ_REDIRECTED_TYPE_URL = ( + "type.googleapis.com/google.storage.v2.BidiReadObjectRedirectedError" +) + +logger = logging.getLogger(__name__) + + +def _is_read_retryable(exc): + """Predicate to determine if a read operation should be retried.""" + if isinstance( + exc, + ( + exceptions.InternalServerError, + exceptions.ServiceUnavailable, + exceptions.DeadlineExceeded, + exceptions.TooManyRequests, + ), + ): + return True + + if not isinstance(exc, exceptions.Aborted) or not exc.errors: + return False + + try: + grpc_error = exc.errors[0] + trailers = grpc_error.trailing_metadata() + if not trailers: + return False + + status_details_bin = next( + (v for k, v in trailers if k == "grpc-status-details-bin"), None + ) + + if not status_details_bin: + return False + + status_proto = status_pb2.Status() + status_proto.ParseFromString(status_details_bin) + return any( + detail.type_url == _BIDI_READ_REDIRECTED_TYPE_URL + for detail in status_proto.details + ) + except Exception as e: + logger.error(f"Error parsing status_details_bin: {e}") + return False + + +class AsyncMultiRangeDownloader: + """Provides an interface for downloading multiple ranges of a GCS ``Object`` + concurrently. + + Example usage: + + .. code-block:: python + + client = AsyncGrpcClient() + mrd = await AsyncMultiRangeDownloader.create_mrd( + client, bucket_name="chandrasiri-rs", object_name="test_open9" + ) + my_buff1 = open('my_fav_file.txt', 'wb') + my_buff2 = BytesIO() + my_buff3 = BytesIO() + my_buff4 = any_object_which_provides_BytesIO_like_interface() + await mrd.download_ranges( + [ + # (start_byte, bytes_to_read, writeable_buffer) + (0, 100, my_buff1), + (100, 20, my_buff2), + (200, 123, my_buff3), + (300, 789, my_buff4), + ] + ) + + # verify data in buffers... + assert my_buff2.getbuffer().nbytes == 20 + + + """ + + @classmethod + async def create_mrd( + cls, + client: AsyncGrpcClient, + bucket_name: str, + object_name: str, + generation: Optional[int] = None, + read_handle: Optional[_storage_v2.BidiReadHandle] = None, + retry_policy: Optional[AsyncRetry] = None, + metadata: Optional[List[Tuple[str, str]]] = None, + **kwargs, + ) -> AsyncMultiRangeDownloader: + """Initializes a MultiRangeDownloader and opens the underlying bidi-gRPC + object for reading. + + :type client: :class:`~google.cloud.storage.asyncio.async_grpc_client.AsyncGrpcClient` + :param client: The asynchronous client to use for making API requests. + + :type bucket_name: str + :param bucket_name: The name of the bucket containing the object. 
+ + :type object_name: str + :param object_name: The name of the object to be read. + + :type generation: int + :param generation: (Optional) If present, selects a specific + revision of this object. + + :type read_handle: _storage_v2.BidiReadHandle + :param read_handle: (Optional) An existing handle for reading the object. + If provided, opening the bidi-gRPC connection will be faster. + + :type retry_policy: :class:`~google.api_core.retry_async.AsyncRetry` + :param retry_policy: (Optional) The retry policy to use for the ``open`` operation. + + :type metadata: List[Tuple[str, str]] + :param metadata: (Optional) The metadata to be sent with the ``open`` request. + + :rtype: :class:`~google.cloud.storage.asyncio.async_multi_range_downloader.AsyncMultiRangeDownloader` + :returns: An initialized AsyncMultiRangeDownloader instance for reading. + """ + mrd = cls( + client, + bucket_name, + object_name, + generation=generation, + read_handle=read_handle, + **kwargs, + ) + await mrd.open(retry_policy=retry_policy, metadata=metadata) + return mrd + + def __init__( + self, + client: AsyncGrpcClient, + bucket_name: str, + object_name: str, + generation: Optional[int] = None, + read_handle: Optional[_storage_v2.BidiReadHandle] = None, + **kwargs, + ) -> None: + """Constructor for AsyncMultiRangeDownloader, clients are not adviced to + use it directly. Instead it's adviced to use the classmethod `create_mrd`. + + :type client: :class:`~google.cloud.storage.asyncio.async_grpc_client.AsyncGrpcClient` + :param client: The asynchronous client to use for making API requests. + + :type bucket_name: str + :param bucket_name: The name of the bucket containing the object. + + :type object_name: str + :param object_name: The name of the object to be read. + + :type generation: int + :param generation: (Optional) If present, selects a specific revision of + this object. + + :type read_handle: _storage_v2.BidiReadHandle + :param read_handle: (Optional) An existing read handle. + """ + if "generation_number" in kwargs: + if generation is not None: + raise TypeError( + "Cannot set both 'generation' and 'generation_number'. " + "Use 'generation' for new code." + ) + logger.warning( + "'generation_number' is deprecated and will be removed in a future " + "major release. Please use 'generation' instead." 
+ ) + generation = kwargs.pop("generation_number") + + raise_if_no_fast_crc32c() + + self.client = client + self.bucket_name = bucket_name + self.object_name = object_name + self.generation = generation + self.read_handle: Optional[_storage_v2.BidiReadHandle] = read_handle + self.read_obj_str: Optional[_AsyncReadObjectStream] = None + self._is_stream_open: bool = False + self._routing_token: Optional[str] = None + self._read_id_to_writable_buffer_dict = {} + self._read_id_to_download_ranges_id = {} + self._download_ranges_id_to_pending_read_ids = {} + self.persisted_size: Optional[int] = None # updated after opening the stream + + async def __aenter__(self): + """Opens the underlying bidi-gRPC connection to read from the object.""" + await self.open() + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + """Closes the underlying bidi-gRPC connection.""" + if self.is_stream_open: + await self.close() + + def _on_open_error(self, exc): + """Extracts routing token and read handle on redirect error during open.""" + routing_token, read_handle = _handle_redirect(exc) + if routing_token: + self._routing_token = routing_token + if read_handle: + self.read_handle = read_handle + + async def open( + self, + retry_policy: Optional[AsyncRetry] = None, + metadata: Optional[List[Tuple[str, str]]] = None, + ) -> None: + """Opens the bidi-gRPC connection to read from the object.""" + if self._is_stream_open: + raise ValueError("Underlying bidi-gRPC stream is already open") + + if retry_policy is None: + retry_policy = AsyncRetry( + predicate=_is_read_retryable, on_error=self._on_open_error + ) + else: + original_on_error = retry_policy._on_error + + def combined_on_error(exc): + self._on_open_error(exc) + if original_on_error: + original_on_error(exc) + + retry_policy = AsyncRetry( + predicate=_is_read_retryable, + initial=retry_policy._initial, + maximum=retry_policy._maximum, + multiplier=retry_policy._multiplier, + deadline=retry_policy._deadline, + on_error=combined_on_error, + ) + + async def _do_open(): + current_metadata = list(metadata) if metadata else [] + + # Cleanup stream from previous failed attempt, if any. 
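The retry plumbing above wraps any caller-supplied `AsyncRetry` so that the redirect and read-handle bookkeeping in `_on_open_error` still runs before the caller's own `on_error`. A sketch of supplying a custom policy through `create_mrd`, with placeholder names and backoff values chosen only for illustration:

```python
import asyncio
from io import BytesIO

from google.api_core.retry_async import AsyncRetry
from google.cloud.storage.asyncio.async_grpc_client import AsyncGrpcClient
from google.cloud.storage.asyncio.async_multi_range_downloader import (
    AsyncMultiRangeDownloader,
)


async def main():
    client = AsyncGrpcClient()
    # Caller-tuned backoff; open() re-wraps it with the redirect-aware predicate.
    my_retry = AsyncRetry(initial=1.0, maximum=30.0, multiplier=2.0)

    mrd = await AsyncMultiRangeDownloader.create_mrd(
        client,
        bucket_name="my-bucket",   # placeholder
        object_name="my-object",   # placeholder
        retry_policy=my_retry,
    )
    buf = BytesIO()
    await mrd.download_ranges([(0, 100, buf)])  # first 100 bytes into buf
    await mrd.close()


asyncio.run(main())
```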
+            if self.read_obj_str:
+                if self.read_obj_str.is_stream_open:
+                    try:
+                        await self.read_obj_str.close()
+                    except exceptions.GoogleAPICallError as e:
+                        logger.warning(
+                            f"Failed to close existing stream during resumption: {e}"
+                        )
+                self.read_obj_str = None
+                self._is_stream_open = False
+
+            self.read_obj_str = _AsyncReadObjectStream(
+                client=self.client.grpc_client,
+                bucket_name=self.bucket_name,
+                object_name=self.object_name,
+                generation_number=self.generation,
+                read_handle=self.read_handle,
+            )
+
+            if self._routing_token:
+                current_metadata.append(
+                    ("x-goog-request-params", f"routing_token={self._routing_token}")
+                )
+                self._routing_token = None
+
+            await self.read_obj_str.open(
+                metadata=current_metadata if current_metadata else None
+            )
+
+            if self.read_obj_str.generation_number:
+                self.generation = self.read_obj_str.generation_number
+            if self.read_obj_str.read_handle:
+                self.read_handle = self.read_obj_str.read_handle
+            if self.read_obj_str.persisted_size is not None:
+                self.persisted_size = self.read_obj_str.persisted_size
+
+            self._is_stream_open = True
+
+        await retry_policy(_do_open)()
+
+    async def download_ranges(
+        self,
+        read_ranges: List[Tuple[int, int, BytesIO]],
+        lock: Optional[asyncio.Lock] = None,
+        retry_policy: Optional[AsyncRetry] = None,
+        metadata: Optional[List[Tuple[str, str]]] = None,
+    ) -> None:
+        """Downloads multiple byte ranges from the object into the buffers
+        provided by the user, with automatic retries.
+
+        :type read_ranges: List[Tuple[int, int, "BytesIO"]]
+        :param read_ranges: A list of tuples, where each tuple represents a
+            byte range and a writeable buffer in the format
+            (`start_byte`, `bytes_to_read`, `writeable_buffer`). The buffer
+            must be provided by the user, who is responsible for ensuring
+            enough memory is available in the application to avoid an
+            out-of-memory crash.
+
+            Special cases:
+            If `bytes_to_read` is 0, it is interpreted as a request to download
+            everything from `start_byte` to the end of the object.
+            Examples:
+            * (0, 0, buffer): downloads from byte 0 to the end, i.e. the entire object.
+            * (100, 0, buffer): downloads from byte 100 to the end.
+
+        :type lock: asyncio.Lock
+        :param lock: (Optional) An asyncio lock to synchronize sends and recvs
+            on the underlying bidi-gRPC stream. This is required when multiple
+            coroutines call this method concurrently.
+
+            Example usage with multiple coroutines:
+
+            ```
+            lock = asyncio.Lock()
+            task1 = asyncio.create_task(mrd.download_ranges(ranges1, lock))
+            task2 = asyncio.create_task(mrd.download_ranges(ranges2, lock))
+            await asyncio.gather(task1, task2)
+
+            ```
+
+            If the calls are made serially (each one awaited before the next),
+            providing a lock is not necessary.
+
+            ```
+            await mrd.download_ranges(ranges1)
+            await mrd.download_ranges(ranges2)
+
+            # ... some other code ...
+
+            ```
+
+        :type retry_policy: :class:`~google.api_core.retry_async.AsyncRetry`
+        :param retry_policy: (Optional) The retry policy to use for the operation.
+
+        :type metadata: List[Tuple[str, str]]
+        :param metadata: (Optional) Additional metadata to be sent with the request.
+
+        :raises ValueError: if the underlying bidi-gRPC stream is not open.
+        :raises ValueError: if the length of read_ranges is more than 1000.
+        :raises DataCorruption: if a checksum mismatch is detected while reading data.
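+
+        End-to-end sketch (illustrative only; ``client`` is assumed to be an
+        existing ``AsyncGrpcClient`` and the bucket/object names are placeholders):
+
+        ```
+        buf = BytesIO()
+        mrd = await AsyncMultiRangeDownloader.create_mrd(
+            client, "my-bucket", "my-object"
+        )
+        await mrd.download_ranges([(0, 1024, buf)])  # first 1 KiB into buf
+        await mrd.close()
+        ```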
+ + """ + + if len(read_ranges) > 1000: + raise ValueError( + "Invalid input - length of read_ranges cannot be more than 1000" + ) + + if not self._is_stream_open: + raise ValueError("Underlying bidi-gRPC stream is not open") + + if lock is None: + lock = asyncio.Lock() + + if retry_policy is None: + retry_policy = AsyncRetry(predicate=_is_read_retryable) + + # Initialize Global State for Retry Strategy + download_states = {} + for read_range in read_ranges: + read_id = generate_random_56_bit_integer() + download_states[read_id] = _DownloadState( + initial_offset=read_range[0], + initial_length=read_range[1], + user_buffer=read_range[2], + ) + + initial_state = { + "download_states": download_states, + "read_handle": self.read_handle, + "routing_token": None, + } + + # Track attempts to manage stream reuse + attempt_count = 0 + + def send_ranges_and_get_bytes( + requests: List[_storage_v2.ReadRange], + state: Dict[str, Any], + metadata: Optional[List[Tuple[str, str]]] = None, + ): + async def generator(): + nonlocal attempt_count + attempt_count += 1 + + if attempt_count > 1: + logger.info( + f"Resuming download (attempt {attempt_count - 1}) for {len(requests)} ranges." + ) + + async with lock: + current_handle = state.get("read_handle") + current_token = state.get("routing_token") + + # We reopen if it's a redirect (token exists) OR if this is a retry + # (not first attempt). This prevents trying to send data on a dead + # stream from a previous failed attempt. + should_reopen = ( + (attempt_count > 1) + or (current_token is not None) + or (metadata is not None) + ) + + if should_reopen: + if current_token: + logger.info( + f"Re-opening stream with routing token: {current_token}" + ) + # Close existing stream if any + if self.read_obj_str and self.read_obj_str.is_stream_open: + await self.read_obj_str.close() + + # Re-initialize stream + self.read_obj_str = _AsyncReadObjectStream( + client=self.client.grpc_client, + bucket_name=self.bucket_name, + object_name=self.object_name, + generation_number=self.generation, + read_handle=current_handle, + ) + + # Inject routing_token into metadata if present + current_metadata = list(metadata) if metadata else [] + if current_token: + current_metadata.append( + ( + "x-goog-request-params", + f"routing_token={current_token}", + ) + ) + + await self.read_obj_str.open( + metadata=current_metadata if current_metadata else None + ) + self._is_stream_open = True + + pending_read_ids = {r.read_id for r in requests} + + # Send Requests + for i in range( + 0, len(requests), _MAX_READ_RANGES_PER_BIDI_READ_REQUEST + ): + batch = requests[i : i + _MAX_READ_RANGES_PER_BIDI_READ_REQUEST] + await self.read_obj_str.send( + _storage_v2.BidiReadObjectRequest(read_ranges=batch) + ) + + while pending_read_ids: + response = await self.read_obj_str.recv() + if response is None: + break + if response.object_data_ranges: + for data_range in response.object_data_ranges: + if data_range.range_end: + pending_read_ids.discard( + data_range.read_range.read_id + ) + yield response + + return generator() + + strategy = _ReadResumptionStrategy() + retry_manager = _BidiStreamRetryManager( + strategy, lambda r, s: send_ranges_and_get_bytes(r, s, metadata=metadata) + ) + + await retry_manager.execute(initial_state, retry_policy) + + if initial_state.get("read_handle"): + self.read_handle = initial_state["read_handle"] + + async def close(self): + """ + Closes the underlying bidi-gRPC connection. 
+ """ + if not self._is_stream_open: + raise ValueError("Underlying bidi-gRPC stream is not open") + + if self.read_obj_str: + await self.read_obj_str.close() + self.read_obj_str = None + self._is_stream_open = False + + @property + def is_stream_open(self) -> bool: + return self._is_stream_open diff --git a/google/cloud/storage/asyncio/async_read_object_stream.py b/google/cloud/storage/asyncio/async_read_object_stream.py new file mode 100644 index 000000000..d456f16cc --- /dev/null +++ b/google/cloud/storage/asyncio/async_read_object_stream.py @@ -0,0 +1,188 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import List, Optional, Tuple +from google.cloud import _storage_v2 +from google.cloud.storage.asyncio.async_grpc_client import AsyncGrpcClient +from google.cloud.storage.asyncio.async_abstract_object_stream import ( + _AsyncAbstractObjectStream, +) + +from google.api_core.bidi_async import AsyncBidiRpc + + +class _AsyncReadObjectStream(_AsyncAbstractObjectStream): + """Class representing a gRPC bidi-stream for reading data from a GCS ``Object``. + + This class provides a unix socket-like interface to a GCS ``Object``, with + methods like ``open``, ``close``, ``send``, and ``recv``. + + :type client: :class:`~google.cloud.storage.asyncio.async_grpc_client.AsyncGrpcClient.grpc_client` + :param client: async grpc client to use for making API requests. + + :type bucket_name: str + :param bucket_name: The name of the GCS ``bucket`` containing the object. + + :type object_name: str + :param object_name: The name of the GCS ``object`` to be read. + + :type generation_number: int + :param generation_number: (Optional) If present, selects a specific revision of + this object. + + :type read_handle: _storage_v2.BidiReadHandle + :param read_handle: (Optional) An existing handle for reading the object. + If provided, opening the bidi-gRPC connection will be faster. 
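+
+    Minimal usage sketch (illustrative only; this class is internal and is
+    normally driven by ``AsyncMultiRangeDownloader`` rather than used directly,
+    and ``grpc_client`` stands in for an ``AsyncGrpcClient.grpc_client``):
+
+    ```
+    stream = _AsyncReadObjectStream(
+        client=grpc_client, bucket_name="my-bucket", object_name="my-object"
+    )
+    await stream.open()
+    await stream.send(
+        _storage_v2.BidiReadObjectRequest(
+            read_ranges=[
+                _storage_v2.ReadRange(read_offset=0, read_length=100, read_id=1)
+            ]
+        )
+    )
+    response = await stream.recv()
+    await stream.close()
+    ```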
+ """ + + def __init__( + self, + client: AsyncGrpcClient.grpc_client, + bucket_name: str, + object_name: str, + generation_number: Optional[int] = None, + read_handle: Optional[_storage_v2.BidiReadHandle] = None, + ) -> None: + if client is None: + raise ValueError("client must be provided") + if bucket_name is None: + raise ValueError("bucket_name must be provided") + if object_name is None: + raise ValueError("object_name must be provided") + + super().__init__( + bucket_name=bucket_name, + object_name=object_name, + generation_number=generation_number, + ) + self.client: AsyncGrpcClient.grpc_client = client + self.read_handle: Optional[_storage_v2.BidiReadHandle] = read_handle + + self._full_bucket_name = f"projects/_/buckets/{self.bucket_name}" + + self.rpc = self.client._client._transport._wrapped_methods[ + self.client._client._transport.bidi_read_object + ] + self.metadata = (("x-goog-request-params", f"bucket={self._full_bucket_name}"),) + self.socket_like_rpc: Optional[AsyncBidiRpc] = None + self._is_stream_open: bool = False + self.persisted_size: Optional[int] = None + + async def open(self, metadata: Optional[List[Tuple[str, str]]] = None) -> None: + """Opens the bidi-gRPC connection to read from the object. + + This method sends an initial request to start the stream and receives + the first response containing metadata and a read handle. + + Args: + metadata (Optional[List[Tuple[str, str]]]): Additional metadata + to send with the initial stream request, e.g., for routing tokens. + """ + if self._is_stream_open: + raise ValueError("Stream is already open") + + read_handle = self.read_handle if self.read_handle else None + + read_object_spec = _storage_v2.BidiReadObjectSpec( + bucket=self._full_bucket_name, + object=self.object_name, + generation=self.generation_number if self.generation_number else None, + read_handle=read_handle, + ) + self.first_bidi_read_req = _storage_v2.BidiReadObjectRequest( + read_object_spec=read_object_spec + ) + + # Build the x-goog-request-params header + request_params = [f"bucket={self._full_bucket_name}"] + other_metadata = [] + if metadata: + for key, value in metadata: + if key == "x-goog-request-params": + request_params.append(value) + else: + other_metadata.append((key, value)) + + current_metadata = other_metadata + current_metadata.append(("x-goog-request-params", ",".join(request_params))) + + self.socket_like_rpc = AsyncBidiRpc( + self.rpc, + initial_request=self.first_bidi_read_req, + metadata=current_metadata, + ) + await self.socket_like_rpc.open() # this is actually 1 send + response = await self.socket_like_rpc.recv() + # populated only in the first response of bidi-stream and when opened + # without using `read_handle` + if hasattr(response, "metadata") and response.metadata: + if self.generation_number is None: + self.generation_number = response.metadata.generation + # update persisted size + self.persisted_size = response.metadata.size + + if response and response.read_handle: + self.read_handle = response.read_handle + + self._is_stream_open = True + + async def close(self) -> None: + """Closes the bidi-gRPC connection.""" + if not self._is_stream_open: + raise ValueError("Stream is not open") + await self.requests_done() + await self.socket_like_rpc.close() + self._is_stream_open = False + + async def requests_done(self): + """Signals that all requests have been sent.""" + + await self.socket_like_rpc.send(None) + await self.socket_like_rpc.recv() + + async def send( + self, bidi_read_object_request: 
_storage_v2.BidiReadObjectRequest
+    ) -> None:
+        """Sends a request message on the stream.
+
+        Args:
+            bidi_read_object_request (:class:`~google.cloud._storage_v2.types.BidiReadObjectRequest`):
+                The request message to send. This is typically used to specify
+                the read ranges (offset and length) to download.
+        """
+        if not self._is_stream_open:
+            raise ValueError("Stream is not open")
+        await self.socket_like_rpc.send(bidi_read_object_request)
+
+    async def recv(self) -> _storage_v2.BidiReadObjectResponse:
+        """Receives a response from the stream.
+
+        This method waits for the next message from the server, which could
+        contain object data or metadata.
+
+        Returns:
+            :class:`~google.cloud._storage_v2.types.BidiReadObjectResponse`:
+                The response message from the server.
+        """
+        if not self._is_stream_open:
+            raise ValueError("Stream is not open")
+        response = await self.socket_like_rpc.recv()
+        # Update read_handle if present in response
+        if response and response.read_handle:
+            self.read_handle = response.read_handle
+        return response
+
+    @property
+    def is_stream_open(self) -> bool:
+        return self._is_stream_open
diff --git a/google/cloud/storage/asyncio/async_write_object_stream.py b/google/cloud/storage/asyncio/async_write_object_stream.py
new file mode 100644
index 000000000..319f394dd
--- /dev/null
+++ b/google/cloud/storage/asyncio/async_write_object_stream.py
@@ -0,0 +1,236 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import List, Optional, Tuple
+import grpc
+from google.cloud import _storage_v2
+from google.cloud.storage.asyncio import _utils
+from google.cloud.storage.asyncio.async_grpc_client import AsyncGrpcClient
+from google.cloud.storage.asyncio.async_abstract_object_stream import (
+    _AsyncAbstractObjectStream,
+)
+from google.api_core.bidi_async import AsyncBidiRpc
+
+
+class _AsyncWriteObjectStream(_AsyncAbstractObjectStream):
+    """Class representing a gRPC bidi-stream for writing data to a GCS
+    ``Appendable Object``.
+
+    This class provides a unix socket-like interface to a GCS ``Object``, with
+    methods like ``open``, ``close``, ``send``, and ``recv``.
+
+    :type client: :class:`~google.cloud.storage.asyncio.async_grpc_client.AsyncGrpcClient.grpc_client`
+    :param client: async grpc client to use for making API requests.
+
+    :type bucket_name: str
+    :param bucket_name: The name of the GCS ``bucket`` containing the object.
+
+    :type object_name: str
+    :param object_name: The name of the GCS ``Appendable Object`` to be written.
+
+    :type generation_number: int
+    :param generation_number: (Optional) If present, creates a writer for that
+        specific revision of the object. Use this to append data to an
+        existing Appendable Object.
+
+        Setting it to ``0`` makes `writer.open()` succeed only if the object
+        doesn't already exist in the bucket (useful to avoid accidentally
+        overwriting existing objects).
+
+        Warning: If `None`, a new object is created. If an object with the
+        same name already exists, it will be overwritten the moment
+        `writer.open()` is called.
+ + :type write_handle: _storage_v2.BidiWriteHandle + :param write_handle: (Optional) An existing handle for writing the object. + If provided, opening the bidi-gRPC connection will be faster. + """ + + def __init__( + self, + client: AsyncGrpcClient.grpc_client, + bucket_name: str, + object_name: str, + generation_number: Optional[int] = None, # None means new object + write_handle: Optional[_storage_v2.BidiWriteHandle] = None, + routing_token: Optional[str] = None, + ) -> None: + if client is None: + raise ValueError("client must be provided") + if bucket_name is None: + raise ValueError("bucket_name must be provided") + if object_name is None: + raise ValueError("object_name must be provided") + + super().__init__( + bucket_name=bucket_name, + object_name=object_name, + generation_number=generation_number, + ) + self.client: AsyncGrpcClient.grpc_client = client + self.write_handle: Optional[_storage_v2.BidiWriteHandle] = write_handle + self.routing_token: Optional[str] = routing_token + + self._full_bucket_name = f"projects/_/buckets/{self.bucket_name}" + + self.rpc = self.client._client._transport._wrapped_methods[ + self.client._client._transport.bidi_write_object + ] + + self.metadata = (("x-goog-request-params", f"bucket={self._full_bucket_name}"),) + self.socket_like_rpc: Optional[AsyncBidiRpc] = None + self._is_stream_open: bool = False + self.first_bidi_write_req = None + self.persisted_size = 0 + self.object_resource: Optional[_storage_v2.Object] = None + + async def open(self, metadata: Optional[List[Tuple[str, str]]] = None) -> None: + """ + Opens the bidi-gRPC connection to write to the object. + + This method sends an initial request to start the stream and receives + the first response containing metadata and a write handle. + + :rtype: None + :raises ValueError: If the stream is already open. + :raises google.api_core.exceptions.FailedPrecondition: + if `generation_number` is 0 and object already exists. + """ + if self._is_stream_open: + raise ValueError("Stream is already open") + + # Create a new object or overwrite existing one if generation_number + # is None. This makes it consistent with GCS JSON API behavior. + # Created object type would be Appendable Object. + # if `generation_number` == 0 new object will be created only if there + # isn't any existing object. 
+        if self.generation_number is None or self.generation_number == 0:
+            self.first_bidi_write_req = _storage_v2.BidiWriteObjectRequest(
+                write_object_spec=_storage_v2.WriteObjectSpec(
+                    resource=_storage_v2.Object(
+                        name=self.object_name, bucket=self._full_bucket_name
+                    ),
+                    appendable=True,
+                    if_generation_match=self.generation_number,
+                ),
+            )
+        else:
+            self.first_bidi_write_req = _storage_v2.BidiWriteObjectRequest(
+                append_object_spec=_storage_v2.AppendObjectSpec(
+                    bucket=self._full_bucket_name,
+                    object=self.object_name,
+                    generation=self.generation_number,
+                    write_handle=self.write_handle if self.write_handle else None,
+                    routing_token=self.routing_token if self.routing_token else None,
+                ),
+            )
+
+        request_param_values = [f"bucket={self._full_bucket_name}"]
+        final_metadata = []
+        if metadata:
+            for key, value in metadata:
+                if key == "x-goog-request-params":
+                    request_param_values.append(value)
+                else:
+                    final_metadata.append((key, value))
+
+        final_metadata.append(("x-goog-request-params", ",".join(request_param_values)))
+
+        self.socket_like_rpc = AsyncBidiRpc(
+            self.rpc,
+            initial_request=self.first_bidi_write_req,
+            metadata=final_metadata,
+        )
+
+        await self.socket_like_rpc.open()  # this is actually 1 send
+        response = await self.socket_like_rpc.recv()
+        self._is_stream_open = True
+
+        if response.persisted_size:
+            self.persisted_size = response.persisted_size
+
+        if response.resource:
+            if not response.resource.size:
+                # Appending to a 0 byte appendable object.
+                self.persisted_size = 0
+            else:
+                self.persisted_size = response.resource.size
+
+            self.generation_number = response.resource.generation
+
+        if response.write_handle:
+            self.write_handle = response.write_handle
+
+    async def close(self) -> None:
+        """Closes the bidi-gRPC connection."""
+        if not self._is_stream_open:
+            raise ValueError("Stream is not open")
+        await self.requests_done()
+        await self.socket_like_rpc.close()
+        self._is_stream_open = False
+
+    async def requests_done(self):
+        """Signals that all requests have been sent."""
+        await self.socket_like_rpc.send(None)
+
+        # The server may send a final "EOF" response immediately, or it may
+        # first send an intermediate response followed by the EOF response,
+        # depending on whether the object was finalized.
+        first_resp = await self.socket_like_rpc.recv()
+        _utils.update_write_handle_if_exists(self, first_resp)
+
+        if first_resp != grpc.aio.EOF:
+            self.persisted_size = first_resp.persisted_size
+            second_resp = await self.socket_like_rpc.recv()
+            assert second_resp == grpc.aio.EOF
+
+    async def send(
+        self, bidi_write_object_request: _storage_v2.BidiWriteObjectRequest
+    ) -> None:
+        """Sends a request message on the stream.
+
+        Args:
+            bidi_write_object_request (:class:`~google.cloud._storage_v2.types.BidiWriteObjectRequest`):
+                The request message to send. This is typically used to carry
+                the write offset and the checksummed data to append.
+        """
+        if not self._is_stream_open:
+            raise ValueError("Stream is not open")
+        await self.socket_like_rpc.send(bidi_write_object_request)
+
+    async def recv(self) -> _storage_v2.BidiWriteObjectResponse:
+        """Receives a response from the stream.
+
+        This method waits for the next message from the server, which typically
+        contains the persisted size, an updated write handle, or the object
+        resource.
+
+        Returns:
+            :class:`~google.cloud._storage_v2.types.BidiWriteObjectResponse`:
+                The response message from the server.
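+
+        Illustrative send/recv sequence (simplified; real callers go through
+        ``AsyncAppendableObjectWriter``, which manages when a response is
+        actually expected, and ``b"more data"`` is a placeholder payload):
+
+        ```
+        await stream.send(
+            _storage_v2.BidiWriteObjectRequest(
+                write_offset=stream.persisted_size,
+                checksummed_data=_storage_v2.ChecksummedData(content=b"more data"),
+                flush=True,
+            )
+        )
+        response = await stream.recv()
+        ```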
+ """ + if not self._is_stream_open: + raise ValueError("Stream is not open") + response = await self.socket_like_rpc.recv() + # Update write_handle if present in response + if response: + if response.write_handle: + self.write_handle = response.write_handle + if response.persisted_size is not None: + self.persisted_size = response.persisted_size + if response.resource and response.resource.size: + self.persisted_size = response.resource.size + return response + + @property + def is_stream_open(self) -> bool: + return self._is_stream_open diff --git a/google/cloud/storage/asyncio/retry/_helpers.py b/google/cloud/storage/asyncio/retry/_helpers.py new file mode 100644 index 000000000..d9ad2462e --- /dev/null +++ b/google/cloud/storage/asyncio/retry/_helpers.py @@ -0,0 +1,125 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import logging +from typing import Tuple, Optional + +from google.api_core import exceptions +from google.cloud._storage_v2.types import ( + BidiReadObjectRedirectedError, + BidiWriteObjectRedirectedError, +) +from google.rpc import status_pb2 + +_BIDI_READ_REDIRECTED_TYPE_URL = ( + "type.googleapis.com/google.storage.v2.BidiReadObjectRedirectedError" +) +_BIDI_WRITE_REDIRECTED_TYPE_URL = ( + "type.googleapis.com/google.storage.v2.BidiWriteObjectRedirectedError" +) +logger = logging.getLogger(__name__) + + +def _handle_redirect( + exc: Exception, +) -> Tuple[Optional[str], Optional[bytes]]: + """ + Extracts routing token and read handle from a gRPC error. + + :type exc: Exception + :param exc: The exception to parse. + + :rtype: Tuple[Optional[str], Optional[bytes]] + :returns: A tuple of (routing_token, read_handle). 
+ """ + routing_token = None + read_handle = None + + grpc_error = None + if isinstance(exc, exceptions.Aborted) and exc.errors: + grpc_error = exc.errors[0] + + if grpc_error: + if isinstance(grpc_error, BidiReadObjectRedirectedError): + routing_token = grpc_error.routing_token + if grpc_error.read_handle: + read_handle = grpc_error.read_handle + return routing_token, read_handle + + if hasattr(grpc_error, "trailing_metadata"): + trailers = grpc_error.trailing_metadata() + if not trailers: + return None, None + + status_details_bin = None + for key, value in trailers: + if key == "grpc-status-details-bin": + status_details_bin = value + break + + if status_details_bin: + status_proto = status_pb2.Status() + try: + status_proto.ParseFromString(status_details_bin) + for detail in status_proto.details: + if detail.type_url == _BIDI_READ_REDIRECTED_TYPE_URL: + redirect_proto = BidiReadObjectRedirectedError.deserialize( + detail.value + ) + if redirect_proto.routing_token: + routing_token = redirect_proto.routing_token + if redirect_proto.read_handle: + read_handle = redirect_proto.read_handle + break + except Exception as e: + logger.error(f"Error unpacking redirect: {e}") + + return routing_token, read_handle + + +def _extract_bidi_writes_redirect_proto(exc: Exception): + grpc_error = None + if isinstance(exc, exceptions.Aborted) and exc.errors: + grpc_error = exc.errors[0] + + if grpc_error: + if isinstance(grpc_error, BidiWriteObjectRedirectedError): + return grpc_error + + if hasattr(grpc_error, "trailing_metadata"): + trailers = grpc_error.trailing_metadata() + if not trailers: + return + + status_details_bin = None + for key, value in trailers: + if key == "grpc-status-details-bin": + status_details_bin = value + break + + if status_details_bin: + status_proto = status_pb2.Status() + try: + status_proto.ParseFromString(status_details_bin) + for detail in status_proto.details: + if detail.type_url == _BIDI_WRITE_REDIRECTED_TYPE_URL: + redirect_proto = BidiWriteObjectRedirectedError.deserialize( + detail.value + ) + return redirect_proto + except Exception: + logger.error("Error unpacking redirect details from gRPC error.") + pass diff --git a/google/cloud/storage/asyncio/retry/base_strategy.py b/google/cloud/storage/asyncio/retry/base_strategy.py new file mode 100644 index 000000000..ff193f109 --- /dev/null +++ b/google/cloud/storage/asyncio/retry/base_strategy.py @@ -0,0 +1,83 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import abc +from typing import Any, Iterable + + +class _BaseResumptionStrategy(abc.ABC): + """Abstract base class defining the interface for a bidi stream resumption strategy. + + This class defines the skeleton for a pluggable strategy that contains + all the service-specific logic for a given bidi operation (e.g., reads + or writes). This allows a generic retry manager to handle the common + retry loop while sending the state management and request generation + to a concrete implementation of this class. 
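+
+    Collaboration sketch (simplified; see ``_BidiStreamRetryManager.execute``
+    for the actual loop):
+
+    ```
+    requests = strategy.generate_requests(state)
+    try:
+        async for response in send_and_recv(requests, state):
+            strategy.update_state_from_response(response, state)
+    except Exception as error:
+        await strategy.recover_state_on_failure(error, state)
+        # the retry policy then decides whether to run another attempt
+    ```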
+ """ + + @abc.abstractmethod + def generate_requests(self, state: Any) -> Iterable[Any]: + """Generates the next batch of requests based on the current state. + + This method is called at the beginning of each retry attempt. It should + inspect the provided state object and generate the appropriate list of + request protos to send to the server. For example, a read strategy + would use this to implement "Smarter Resumption" by creating smaller + `ReadRange` requests for partially downloaded ranges. For bidi-writes, + it will set the `write_offset` field to the persisted size received + from the server in the next request. + + :type state: Any + :param state: An object containing all the state needed for the + operation (e.g., requested ranges, user buffers, + bytes written). + """ + pass + + @abc.abstractmethod + def update_state_from_response(self, response: Any, state: Any) -> None: + """Updates the state based on a successful server response. + + This method is called for every message received from the server. It is + responsible for processing the response and updating the shared state + object. + + :type response: Any + :param response: The response message received from the server. + + :type state: Any + :param state: The shared state object for the operation, which will be + mutated by this method. + """ + pass + + @abc.abstractmethod + async def recover_state_on_failure(self, error: Exception, state: Any) -> None: + """Prepares the state for the next retry attempt after a failure. + + This method is called when a retriable gRPC error occurs. It is + responsible for performing any necessary actions to ensure the next + retry attempt can succeed. For bidi reads, its primary role is to + handle the `BidiReadObjectRedirectError` by extracting the + `routing_token` and updating the state. For bidi writes, it will update + the state to reflect any bytes that were successfully persisted before + the failure. + + :type error: :class:`Exception` + :param error: The exception that was caught by the retry engine. + + :type state: Any + :param state: The shared state object for the operation. + """ + pass diff --git a/google/cloud/storage/asyncio/retry/bidi_stream_retry_manager.py b/google/cloud/storage/asyncio/retry/bidi_stream_retry_manager.py new file mode 100644 index 000000000..23bffb63d --- /dev/null +++ b/google/cloud/storage/asyncio/retry/bidi_stream_retry_manager.py @@ -0,0 +1,69 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +from typing import Any, AsyncIterator, Callable + +from google.cloud.storage.asyncio.retry.base_strategy import ( + _BaseResumptionStrategy, +) + +logger = logging.getLogger(__name__) + + +class _BidiStreamRetryManager: + """Manages the generic retry loop for a bidi streaming operation.""" + + def __init__( + self, + strategy: _BaseResumptionStrategy, + send_and_recv: Callable[..., AsyncIterator[Any]], + ): + """Initializes the retry manager. 
+ Args: + strategy: The strategy for managing the state of a specific + bidi operation (e.g., reads or writes). + send_and_recv: An async callable that opens a new gRPC stream. + """ + self._strategy = strategy + self._send_and_recv = send_and_recv + + async def execute(self, initial_state: Any, retry_policy): + """ + Executes the bidi operation with the configured retry policy. + Args: + initial_state: An object containing all state for the operation. + retry_policy: The `google.api_core.retry.AsyncRetry` object to + govern the retry behavior for this specific operation. + """ + state = initial_state + + async def attempt(): + requests = self._strategy.generate_requests(state) + stream = self._send_and_recv(requests, state) + try: + async for response in stream: + self._strategy.update_state_from_response(response, state) + return + except Exception as e: + if retry_policy._predicate(e): + logger.info( + f"Bidi stream operation failed: {e}. Attempting state recovery and retry." + ) + await self._strategy.recover_state_on_failure(e, state) + raise e + + wrapped_attempt = retry_policy(attempt) + + await wrapped_attempt() diff --git a/google/cloud/storage/asyncio/retry/reads_resumption_strategy.py b/google/cloud/storage/asyncio/retry/reads_resumption_strategy.py new file mode 100644 index 000000000..468954332 --- /dev/null +++ b/google/cloud/storage/asyncio/retry/reads_resumption_strategy.py @@ -0,0 +1,157 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any, Dict, List, IO +import logging + +from google_crc32c import Checksum +from google.cloud import _storage_v2 as storage_v2 +from google.cloud.storage.exceptions import DataCorruption +from google.cloud.storage.asyncio.retry._helpers import ( + _handle_redirect, +) +from google.cloud.storage.asyncio.retry.base_strategy import ( + _BaseResumptionStrategy, +) + + +_BIDI_READ_REDIRECTED_TYPE_URL = ( + "type.googleapis.com/google.storage.v2.BidiReadObjectRedirectedError" +) +logger = logging.getLogger(__name__) + + +class _DownloadState: + """A helper class to track the state of a single range download.""" + + def __init__( + self, initial_offset: int, initial_length: int, user_buffer: IO[bytes] + ): + self.initial_offset = initial_offset + self.initial_length = initial_length + self.user_buffer = user_buffer + self.bytes_written = 0 + self.next_expected_offset = initial_offset + self.is_complete = False + + +class _ReadResumptionStrategy(_BaseResumptionStrategy): + """The concrete resumption strategy for bidi reads.""" + + def generate_requests(self, state: Dict[str, Any]) -> List[storage_v2.ReadRange]: + """Generates new ReadRange requests for all incomplete downloads. + + :type state: dict + :param state: A dictionary mapping a read_id to its corresponding + _DownloadState object. 
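+
+        For example, if a range was requested as ``(read_offset=0,
+        read_length=1000)`` and 400 bytes were written before a failure, the
+        regenerated request is ``ReadRange(read_offset=400, read_length=600,
+        read_id=<same id>)``; a read-to-end range (``read_length=0``) keeps
+        ``read_length=0``.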
+ """ + pending_requests = [] + download_states: Dict[int, _DownloadState] = state["download_states"] + + for read_id, read_state in download_states.items(): + if not read_state.is_complete: + new_offset = read_state.initial_offset + read_state.bytes_written + + # Calculate remaining length. If initial_length is 0 (read to end), + # it stays 0. Otherwise, subtract bytes_written. + new_length = 0 + if read_state.initial_length > 0: + new_length = read_state.initial_length - read_state.bytes_written + + new_request = storage_v2.ReadRange( + read_offset=new_offset, + read_length=new_length, + read_id=read_id, + ) + pending_requests.append(new_request) + return pending_requests + + def update_state_from_response( + self, response: storage_v2.BidiReadObjectResponse, state: Dict[str, Any] + ) -> None: + """Processes a server response, performs integrity checks, and updates state.""" + + # Capture read_handle if provided. + if response.read_handle: + state["read_handle"] = response.read_handle + + download_states = state["download_states"] + + for object_data_range in response.object_data_ranges: + # Ignore empty ranges or ranges for IDs not in our state + # (e.g., from a previously cancelled request on the same stream). + if not object_data_range.read_range: + logger.warning( + "Received response with missing read_range field; ignoring." + ) + continue + + read_id = object_data_range.read_range.read_id + if read_id not in download_states: + logger.warning( + f"Received data for unknown or stale read_id {read_id}; ignoring." + ) + continue + + read_state = download_states[read_id] + + # Offset Verification + chunk_offset = object_data_range.read_range.read_offset + if chunk_offset != read_state.next_expected_offset: + raise DataCorruption( + response, + f"Offset mismatch for read_id {read_id}. " + f"Expected {read_state.next_expected_offset}, got {chunk_offset}", + ) + + # Checksum Verification + # We must validate data before updating state or writing to buffer. + data = object_data_range.checksummed_data.content + server_checksum = object_data_range.checksummed_data.crc32c + + if server_checksum is not None: + client_checksum = int.from_bytes(Checksum(data).digest(), "big") + if server_checksum != client_checksum: + raise DataCorruption( + response, + f"Checksum mismatch for read_id {read_id}. " + f"Server sent {server_checksum}, client calculated {client_checksum}.", + ) + + # Update State & Write Data + chunk_size = len(data) + read_state.user_buffer.write(data) + read_state.bytes_written += chunk_size + read_state.next_expected_offset += chunk_size + + # Final Byte Count Verification + if object_data_range.range_end: + read_state.is_complete = True + if ( + read_state.initial_length != 0 + and read_state.bytes_written > read_state.initial_length + ): + raise DataCorruption( + response, + f"Byte count mismatch for read_id {read_id}. 
" + f"Expected {read_state.initial_length}, got {read_state.bytes_written}", + ) + + async def recover_state_on_failure(self, error: Exception, state: Any) -> None: + """Handles BidiReadObjectRedirectedError for reads.""" + routing_token, read_handle = _handle_redirect(error) + if routing_token: + state["routing_token"] = routing_token + if read_handle: + state["read_handle"] = read_handle diff --git a/google/cloud/storage/asyncio/retry/writes_resumption_strategy.py b/google/cloud/storage/asyncio/retry/writes_resumption_strategy.py new file mode 100644 index 000000000..b98b9b2e7 --- /dev/null +++ b/google/cloud/storage/asyncio/retry/writes_resumption_strategy.py @@ -0,0 +1,147 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any, Dict, IO, List, Optional, Union + +import google_crc32c +from google.cloud._storage_v2.types import storage as storage_type +from google.cloud._storage_v2.types.storage import BidiWriteObjectRedirectedError +from google.cloud.storage.asyncio.retry.base_strategy import ( + _BaseResumptionStrategy, +) +from google.cloud.storage.asyncio.retry._helpers import ( + _extract_bidi_writes_redirect_proto, +) + + +class _WriteState: + """A helper class to track the state of a single upload operation. + + :type chunk_size: int + :param chunk_size: The size of chunks to write to the server. + + :type user_buffer: IO[bytes] + :param user_buffer: The data source. + + :type flush_interval: int + :param flush_interval: The flush interval at which the data is flushed. + """ + + def __init__( + self, + chunk_size: int, + user_buffer: IO[bytes], + flush_interval: int, + ): + self.chunk_size = chunk_size + self.user_buffer = user_buffer + self.persisted_size: int = 0 + self.bytes_sent: int = 0 + self.bytes_since_last_flush: int = 0 + self.flush_interval: int = flush_interval + self.write_handle: Union[bytes, storage_type.BidiWriteHandle, None] = None + self.routing_token: Optional[str] = None + self.is_finalized: bool = False + + +class _WriteResumptionStrategy(_BaseResumptionStrategy): + """The concrete resumption strategy for bidi writes.""" + + def generate_requests( + self, state: Dict[str, Any] + ) -> List[storage_type.BidiWriteObjectRequest]: + """Generates BidiWriteObjectRequests to resume or continue the upload. + + This method is not applicable for `open` methods. + """ + write_state: _WriteState = state["write_state"] + + requests = [] + # The buffer should already be seeked to the correct position (persisted_size) + # by the `recover_state_on_failure` method before this is called. 
+        while not write_state.is_finalized:
+            chunk = write_state.user_buffer.read(write_state.chunk_size)
+
+            # End of File detection
+            if not chunk:
+                break
+
+            checksummed_data = storage_type.ChecksummedData(content=chunk)
+            checksum = google_crc32c.Checksum(chunk)
+            checksummed_data.crc32c = int.from_bytes(checksum.digest(), "big")
+
+            request = storage_type.BidiWriteObjectRequest(
+                write_offset=write_state.bytes_sent,
+                checksummed_data=checksummed_data,
+            )
+            chunk_len = len(chunk)
+            write_state.bytes_sent += chunk_len
+            write_state.bytes_since_last_flush += chunk_len
+
+            if write_state.bytes_since_last_flush >= write_state.flush_interval:
+                request.flush = True
+                # reset counter after marking flush
+                write_state.bytes_since_last_flush = 0
+
+            requests.append(request)
+        return requests
+
+    def update_state_from_response(
+        self, response: storage_type.BidiWriteObjectResponse, state: Dict[str, Any]
+    ) -> None:
+        """Processes a server response and updates the write state."""
+        write_state: _WriteState = state["write_state"]
+        if response is None:
+            return
+        if response.persisted_size:
+            write_state.persisted_size = response.persisted_size
+
+        if response.write_handle:
+            write_state.write_handle = response.write_handle
+
+        if response.resource:
+            write_state.persisted_size = response.resource.size
+            if response.resource.finalize_time:
+                write_state.is_finalized = True
+
+    async def recover_state_on_failure(
+        self, error: Exception, state: Dict[str, Any]
+    ) -> None:
+        """
+        Handles errors, specifically BidiWriteObjectRedirectedError, and rewinds state.
+
+        This method rewinds the user buffer and internal byte tracking to the
+        last confirmed 'persisted_size' from the server.
+        """
+        write_state: _WriteState = state["write_state"]
+
+        redirect_proto = None
+
+        if isinstance(error, BidiWriteObjectRedirectedError):
+            redirect_proto = error
+        else:
+            redirect_proto = _extract_bidi_writes_redirect_proto(error)
+
+        # Extract routing token and potentially a new write handle for redirection.
+        if redirect_proto:
+            if redirect_proto.routing_token:
+                write_state.routing_token = redirect_proto.routing_token
+            if redirect_proto.write_handle:
+                write_state.write_handle = redirect_proto.write_handle
+
+        # We must assume any data sent beyond 'persisted_size' was lost.
+        # Reset the user buffer to the last known good byte confirmed by the server.
+        write_state.user_buffer.seek(write_state.persisted_size)
+        write_state.bytes_sent = write_state.persisted_size
+        write_state.bytes_since_last_flush = 0
diff --git a/google/cloud/storage/blob.py b/google/cloud/storage/blob.py
index 746334d1c..0b022985f 100644
--- a/google/cloud/storage/blob.py
+++ b/google/cloud/storage/blob.py
@@ -5034,6 +5034,19 @@ def hard_delete_time(self):
         if hard_delete_time is not None:
             return _rfc3339_nanos_to_datetime(hard_delete_time)
 
+    @property
+    def finalized_time(self):
+        """The time at which this object was finalized.
+
+        :rtype: :class:`datetime.datetime` or ``NoneType``
+        :returns:
+            (readonly) The time at which the object was finalized.
+            Note this property is only set once the object has been finalized
+            (e.g., for finalized appendable objects).
+        """
+        finalize_time = self._properties.get("finalizedTime", None)
+        if finalize_time is not None:
+            return _rfc3339_nanos_to_datetime(finalize_time)
+
 
 def _get_host_name(connection):
     """Returns the host name from the given connection.
diff --git a/google/cloud/storage/client.py b/google/cloud/storage/client.py index afa0b3a4a..4a2c623e9 100644 --- a/google/cloud/storage/client.py +++ b/google/cloud/storage/client.py @@ -50,6 +50,7 @@ _marker = base_client.marker + def _buckets_page_start(iterator, page, response): """Grab unreachable buckets after a :class:`~google.cloud.iterator.Page` started.""" unreachable = response.get("unreachable", []) @@ -139,15 +140,16 @@ def __init__( client_options=client_options, use_auth_w_custom_endpoint=use_auth_w_custom_endpoint, extra_headers=extra_headers, - api_key=api_key + api_key=api_key, ) # Pass extra_headers to Connection - connection = Connection(self, **self.connection_kw_args) # connection_kw_args would always be set in base class + connection = Connection( + self, **self.connection_kw_args + ) # connection_kw_args would always be set in base class connection.extra_headers = extra_headers self._connection = connection - def get_service_account_email( self, project=None, timeout=_DEFAULT_TIMEOUT, retry=DEFAULT_RETRY ): diff --git a/google/cloud/storage/grpc_client.py b/google/cloud/storage/grpc_client.py new file mode 100644 index 000000000..7a739b7b7 --- /dev/null +++ b/google/cloud/storage/grpc_client.py @@ -0,0 +1,122 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""A client for interacting with Google Cloud Storage using the gRPC API.""" + +from google.cloud.client import ClientWithProject +from google.cloud import _storage_v2 as storage_v2 + +_marker = object() + + +class GrpcClient(ClientWithProject): + """A client for interacting with Google Cloud Storage using the gRPC API. + + :type project: str or None + :param project: The project which the client acts on behalf of. If not + passed, falls back to the default inferred from the + environment. + + :type credentials: :class:`~google.auth.credentials.Credentials` + :param credentials: (Optional) The OAuth2 Credentials to use for this + client. If not passed, falls back to the default + inferred from the environment. + + :type client_info: :class:`~google.api_core.client_info.ClientInfo` + :param client_info: + The client info used to send a user-agent string along with API + requests. If ``None``, then default info will be used. Generally, + you only need to set this if you're developing your own library + or partner tool. + + :type client_options: :class:`~google.api_core.client_options.ClientOptions` or :class:`dict` + :param client_options: (Optional) Client options used to set user options + on the client. A non-default universe domain or API endpoint should be + set through client_options. + + :type api_key: string + :param api_key: + (Optional) An API key. Mutually exclusive with any other credentials. + This parameter is an alias for setting `client_options.api_key` and + will supersede any API key set in the `client_options` parameter. + + :type attempt_direct_path: bool + :param attempt_direct_path: + (Optional) Whether to attempt to use DirectPath for gRPC connections. 
+ This provides a direct, unproxied connection to GCS for lower latency + and higher throughput, and is highly recommended when running on Google + Cloud infrastructure. Defaults to ``True``. + """ + + def __init__( + self, + project=_marker, + credentials=None, + client_info=None, + client_options=None, + *, + api_key=None, + attempt_direct_path=True, + ): + super(GrpcClient, self).__init__(project=project, credentials=credentials) + + if isinstance(client_options, dict): + if api_key: + client_options["api_key"] = api_key + elif client_options is None: + client_options = {} if not api_key else {"api_key": api_key} + elif api_key: + client_options.api_key = api_key + + self._grpc_client = self._create_gapic_client( + credentials=credentials, + client_info=client_info, + client_options=client_options, + attempt_direct_path=attempt_direct_path, + ) + + def _create_gapic_client( + self, + credentials=None, + client_info=None, + client_options=None, + attempt_direct_path=True, + ): + """Creates and configures the low-level GAPIC `storage_v2` client.""" + transport_cls = storage_v2.StorageClient.get_transport_class("grpc") + + channel = transport_cls.create_channel(attempt_direct_path=attempt_direct_path) + + transport = transport_cls(credentials=credentials, channel=channel) + + return storage_v2.StorageClient( + credentials=credentials, + transport=transport, + client_info=client_info, + client_options=client_options, + ) + + @property + def grpc_client(self): + """The underlying gRPC client. + + This property gives users direct access to the `storage_v2.StorageClient` + instance. This can be useful for accessing + newly added or experimental RPCs that are not yet exposed through + the high-level GrpcClient. + + Returns: + google.cloud.storage_v2.StorageClient: The configured GAPIC client. + """ + return self._grpc_client diff --git a/google/cloud/storage/version.py b/google/cloud/storage/version.py index dc87b3c5b..0bc275357 100644 --- a/google/cloud/storage/version.py +++ b/google/cloud/storage/version.py @@ -12,4 +12,4 @@ # See the License for the specific language governing permissions and # limitations under the License. -__version__ = "3.7.0" +__version__ = "3.9.0" diff --git a/noxfile.py b/noxfile.py index 14dfb29d0..1cef2a75f 100644 --- a/noxfile.py +++ b/noxfile.py @@ -192,7 +192,14 @@ def system(session): # 2021-05-06: defer installing 'google-cloud-*' to after this package, # in order to work around Python 2.7 googolapis-common-protos # issue. - session.install("mock", "pytest", "pytest-rerunfailures", "-c", constraints_path) + session.install( + "mock", + "pytest", + "pytest-rerunfailures", + "pytest-asyncio", + "-c", + constraints_path, + ) session.install("-e", ".", "-c", constraints_path) session.install( "google-cloud-testutils", @@ -225,10 +232,22 @@ def conftest_retry(session): if not conformance_test_folder_exists: session.skip("Conformance tests were not found") + constraints_path = str( + CURRENT_DIRECTORY / "testing" / f"constraints-{session.python}.txt" + ) + # Install all test dependencies and pytest plugin to run tests in parallel. # Then install this package in-place. - session.install("pytest", "pytest-xdist") - session.install("-e", ".") + session.install( + "pytest", + "pytest-xdist", + "grpcio", + "grpcio-status", + "grpc-google-iam-v1", + "-c", + constraints_path, + ) + session.install("-e", ".", "-c", constraints_path) # Run #CPU processes in parallel if no test session arguments are passed in. 
if session.posargs: diff --git a/samples/README.md b/samples/README.md index 490af710a..118a778cb 100644 --- a/samples/README.md +++ b/samples/README.md @@ -34,8 +34,15 @@ for more detailed instructions. ``` source /bin/activate ``` +3. To run samples for [Zonal Buckets](https://github.com/googleapis/python-storage/tree/main/samples/snippets/zonal_buckets) -3. Install the dependencies needed to run the samples. + ``` + pip install "google-cloud-storage[grpc]" + python samples/snippets/zonal_buckets/storage_create_and_write_appendable_object.py --bucket_name --object_name + + ``` + +4. Install the dependencies needed to run the samples. ``` cd samples/snippets pip install -r requirements.txt diff --git a/samples/generated_samples/snippet_metadata_google.storage.v2.json b/samples/generated_samples/snippet_metadata_google.storage.v2.json index 4af7ef641..1889f0c5d 100644 --- a/samples/generated_samples/snippet_metadata_google.storage.v2.json +++ b/samples/generated_samples/snippet_metadata_google.storage.v2.json @@ -8,7 +8,7 @@ ], "language": "PYTHON", "name": "google-cloud-storage", - "version": "3.6.0" + "version": "3.9.0" }, "snippets": [ { diff --git a/samples/snippets/notification_polling.py b/samples/snippets/notification_polling.py index 2ee6789c3..1359c9cfa 100644 --- a/samples/snippets/notification_polling.py +++ b/samples/snippets/notification_polling.py @@ -32,10 +32,10 @@ https://console.cloud.google.com/flows/enableapi?apiid=pubsub 3. Create a Google Cloud Storage bucket: - $ gsutil mb gs://testbucket + $ gcloud storage buckets create gs://testbucket 4. Create a Cloud Pub/Sub topic and publish bucket notifications there: - $ gsutil notification create -f json -t testtopic gs://testbucket + $ gcloud storage buckets notifications create gs://testbucket --topic=testtopic --payload-format=json 5. 
Create a subscription for your new topic: $ gcloud pubsub subscriptions create testsubscription --topic=testtopic diff --git a/samples/snippets/snippets_test.py b/samples/snippets/snippets_test.py index 0edba46ca..1d3c8c1c4 100644 --- a/samples/snippets/snippets_test.py +++ b/samples/snippets/snippets_test.py @@ -18,6 +18,7 @@ import tempfile import time import uuid +import sys from google.cloud import storage import google.cloud.exceptions @@ -99,8 +100,10 @@ import storage_upload_with_kms_key KMS_KEY = os.environ.get("CLOUD_KMS_KEY") +IS_PYTHON_3_14 = sys.version_info[:2] == (3, 14) +@pytest.mark.skipif(IS_PYTHON_3_14, reason="b/470276398") def test_enable_default_kms_key(test_bucket): storage_set_bucket_default_kms_key.enable_default_kms_key( bucket_name=test_bucket.name, kms_key_name=KMS_KEY @@ -305,6 +308,7 @@ def test_upload_blob_from_stream(test_bucket, capsys): assert "Stream data uploaded to test_upload_blob" in out +@pytest.mark.skipif(IS_PYTHON_3_14, reason="b/470276398") def test_upload_blob_with_kms(test_bucket): blob_name = f"test_upload_with_kms_{uuid.uuid4().hex}" with tempfile.NamedTemporaryFile() as source_file: @@ -598,6 +602,7 @@ def test_create_bucket_dual_region(test_bucket_create, capsys): assert "dual-region" in out +@pytest.mark.skipif(IS_PYTHON_3_14, reason="b/470276398") def test_bucket_delete_default_kms_key(test_bucket, capsys): test_bucket.default_kms_key_name = KMS_KEY test_bucket.patch() @@ -646,6 +651,7 @@ def test_define_bucket_website_configuration(test_bucket): assert bucket._properties["website"] == website_val +@pytest.mark.skipif(IS_PYTHON_3_14, reason="b/470276398") def test_object_get_kms_key(test_bucket): with tempfile.NamedTemporaryFile() as source_file: storage_upload_with_kms_key.upload_blob_with_kms( diff --git a/samples/snippets/storage_get_metadata.py b/samples/snippets/storage_get_metadata.py index 7216efdb4..1e332b445 100644 --- a/samples/snippets/storage_get_metadata.py +++ b/samples/snippets/storage_get_metadata.py @@ -34,6 +34,7 @@ def blob_metadata(bucket_name, blob_name): blob = bucket.get_blob(blob_name) print(f"Blob: {blob.name}") + print(f"Blob finalization: {blob.finalized_time}") print(f"Bucket: {blob.bucket.name}") print(f"Storage class: {blob.storage_class}") print(f"ID: {blob.id}") diff --git a/samples/snippets/zonal_buckets/README.md b/samples/snippets/zonal_buckets/README.md new file mode 100644 index 000000000..71c17e5c3 --- /dev/null +++ b/samples/snippets/zonal_buckets/README.md @@ -0,0 +1,78 @@ +# Google Cloud Storage - Zonal Buckets Snippets + +This directory contains snippets for interacting with Google Cloud Storage zonal buckets. + +## Prerequisites + +- A Google Cloud Platform project with the Cloud Storage API enabled. +- A zonal Google Cloud Storage bucket. + +## Running the snippets + +### Create and write to an appendable object + +This snippet uploads an appendable object to a zonal bucket. + +```bash +python samples/snippets/zonal_buckets/storage_create_and_write_appendable_object.py --bucket_name --object_name +``` + +### Finalize an appendable object upload + +This snippet creates, writes to, and finalizes an appendable object. + +```bash +python samples/snippets/zonal_buckets/storage_finalize_appendable_object_upload.py --bucket_name --object_name +``` + +### Pause and resume an appendable object upload + +This snippet demonstrates pausing and resuming an appendable object upload. 
+ +```bash +python samples/snippets/zonal_buckets/storage_pause_and_resume_appendable_upload.py --bucket_name --object_name +``` + +### Tail an appendable object + +This snippet demonstrates tailing an appendable GCS object, similar to `tail -f`. + +```bash +python samples/snippets/zonal_buckets/storage_read_appendable_object_tail.py --bucket_name --object_name --duration +``` + + +### Download a range of bytes from an object + +This snippet downloads a range of bytes from an object. + +```bash +python samples/snippets/zonal_buckets/storage_open_object_single_ranged_read.py --bucket_name --object_name --start_byte --size +``` + + +### Download multiple ranges of bytes from a single object + +This snippet downloads multiple ranges of bytes from a single object into different buffers. + +```bash +python samples/snippets/zonal_buckets/storage_open_object_multiple_ranged_read.py --bucket_name --object_name +``` + +### Download the entire content of an object + +This snippet downloads the entire content of an object using a multi-range downloader. + +```bash +python samples/snippets/zonal_buckets/storage_open_object_read_full_object.py --bucket_name --object_name +``` + + + +### Download a range of bytes from multiple objects concurrently + +This snippet downloads a range of bytes from multiple objects concurrently. + +```bash +python samples/snippets/zonal_buckets/storage_open_multiple_objects_ranged_read.py --bucket_name --object_names +``` \ No newline at end of file diff --git a/samples/snippets/zonal_buckets/storage_create_and_write_appendable_object.py b/samples/snippets/zonal_buckets/storage_create_and_write_appendable_object.py new file mode 100644 index 000000000..725eeb2bd --- /dev/null +++ b/samples/snippets/zonal_buckets/storage_create_and_write_appendable_object.py @@ -0,0 +1,77 @@ +#!/usr/bin/env python + +# Copyright 2026 Google Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the 'License'); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import asyncio + +from google.cloud.storage.asyncio.async_appendable_object_writer import ( + AsyncAppendableObjectWriter, +) +from google.cloud.storage.asyncio.async_grpc_client import AsyncGrpcClient + + +# [START storage_create_and_write_appendable_object] + + +async def storage_create_and_write_appendable_object( + bucket_name, object_name, grpc_client=None +): + """Uploads an appendable object to zonal bucket. + + grpc_client: an existing grpc_client to use, this is only for testing. + """ + + if grpc_client is None: + grpc_client = AsyncGrpcClient() + writer = AsyncAppendableObjectWriter( + client=grpc_client, + bucket_name=bucket_name, + object_name=object_name, + generation=0, # throws `FailedPrecondition` if object already exists. + ) + # This creates a new appendable object of size 0 and opens it for appending. + await writer.open() + + # appends data to the object + # you can perform `.append` multiple times as needed. Data will be appended + # to the end of the object. 
+ await writer.append(b"Some data") + + # Once all appends are done, close the gRPC bidirectional stream. + await writer.close() + + print( + f"Appended object {object_name} created of size {writer.persisted_size} bytes." + ) + + +# [END storage_create_and_write_appendable_object] + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter + ) + parser.add_argument("--bucket_name", help="Your Cloud Storage bucket name.") + parser.add_argument("--object_name", help="Your Cloud Storage object name.") + + args = parser.parse_args() + + asyncio.run( + storage_create_and_write_appendable_object( + bucket_name=args.bucket_name, + object_name=args.object_name, + ) + ) diff --git a/samples/snippets/zonal_buckets/storage_finalize_appendable_object_upload.py b/samples/snippets/zonal_buckets/storage_finalize_appendable_object_upload.py new file mode 100644 index 000000000..807fe40a5 --- /dev/null +++ b/samples/snippets/zonal_buckets/storage_finalize_appendable_object_upload.py @@ -0,0 +1,78 @@ +#!/usr/bin/env python + +# Copyright 2026 Google Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the 'License'); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import asyncio + +from google.cloud.storage.asyncio.async_appendable_object_writer import ( + AsyncAppendableObjectWriter, +) +from google.cloud.storage.asyncio.async_grpc_client import AsyncGrpcClient + + +# [START storage_finalize_appendable_object_upload] +async def storage_finalize_appendable_object_upload( + bucket_name, object_name, grpc_client=None +): + """Creates, writes to, and finalizes an appendable object. + + grpc_client: an existing grpc_client to use, this is only for testing. + """ + + if grpc_client is None: + grpc_client = AsyncGrpcClient() + writer = AsyncAppendableObjectWriter( + client=grpc_client, + bucket_name=bucket_name, + object_name=object_name, + generation=0, # throws `FailedPrecondition` if object already exists. + ) + # This creates a new appendable object of size 0 and opens it for appending. + await writer.open() + + # Appends data to the object. + await writer.append(b"Some data") + + # finalize the appendable object, + # NOTE: + # 1. once finalized no more appends can be done to the object. + # 2. If you don't want to finalize, you can simply call `writer.close` + # 3. calling `.finalize()` also closes the grpc-bidi stream, calling + # `.close` after `.finalize` may lead to undefined behavior. 
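+    # For instance (an illustrative sketch; `should_finalize` is a hypothetical
+    # flag, not part of this sample), the choice could be expressed as:
+    #
+    #     if should_finalize:
+    #         object_resource = await writer.finalize()
+    #     else:
+    #         await writer.close()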
+ object_resource = await writer.finalize() + + print(f"Appendable object {object_name} created and finalized.") + print("Object Metadata:") + print(object_resource) + + +# [END storage_finalize_appendable_object_upload] + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter + ) + parser.add_argument("--bucket_name", help="Your Cloud Storage bucket name.") + parser.add_argument("--object_name", help="Your Cloud Storage object name.") + + args = parser.parse_args() + + asyncio.run( + storage_finalize_appendable_object_upload( + bucket_name=args.bucket_name, + object_name=args.object_name, + ) + ) diff --git a/samples/snippets/zonal_buckets/storage_open_multiple_objects_ranged_read.py b/samples/snippets/zonal_buckets/storage_open_multiple_objects_ranged_read.py new file mode 100644 index 000000000..bed580d36 --- /dev/null +++ b/samples/snippets/zonal_buckets/storage_open_multiple_objects_ranged_read.py @@ -0,0 +1,90 @@ +#!/usr/bin/env python + +# Copyright 2026 Google Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the 'License'); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Downloads a range of bytes from multiple objects concurrently. +Example usage: + ```python samples/snippets/zonal_buckets/storage_open_multiple_objects_ranged_read.py \ + --bucket_name \ + --object_names ``` +""" +import argparse +import asyncio +from io import BytesIO + +from google.cloud.storage.asyncio.async_grpc_client import ( + AsyncGrpcClient, +) +from google.cloud.storage.asyncio.async_multi_range_downloader import ( + AsyncMultiRangeDownloader, +) + + +# [START storage_open_multiple_objects_ranged_read] +async def storage_open_multiple_objects_ranged_read( + bucket_name, object_names, grpc_client=None +): + """Downloads a range of bytes from multiple objects concurrently. + + grpc_client: an existing grpc_client to use, this is only for testing. + """ + if grpc_client is None: + grpc_client = AsyncGrpcClient() + + async def _download_range(object_name): + """Helper coroutine to download a range from a single object.""" + mrd = AsyncMultiRangeDownloader(grpc_client, bucket_name, object_name) + try: + # Open the object, mrd always opens in read mode. + await mrd.open() + + # Each object downloads the first 100 bytes. + start_byte = 0 + size = 100 + + # requested range will be downloaded into this buffer, user may provide + # their own buffer or file-like object. + output_buffer = BytesIO() + await mrd.download_ranges([(start_byte, size, output_buffer)]) + finally: + if mrd.is_stream_open: + await mrd.close() + + # Downloaded size can differ from requested size if object is smaller. + # mrd will download at most up to the end of the object. 
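+        # For example, requesting 100 bytes from a 40-byte object fills the
+        # buffer with only the 40 available bytes; the read is truncated at
+        # the end of the object.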
+ downloaded_size = output_buffer.getbuffer().nbytes + print(f"Downloaded {downloaded_size} bytes from {object_name}") + + download_tasks = [_download_range(name) for name in object_names] + await asyncio.gather(*download_tasks) + + +# [END storage_open_multiple_objects_ranged_read] + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter + ) + parser.add_argument("--bucket_name", help="Your Cloud Storage bucket name.") + parser.add_argument( + "--object_names", nargs="+", help="Your Cloud Storage object name(s)." + ) + + args = parser.parse_args() + + asyncio.run( + storage_open_multiple_objects_ranged_read(args.bucket_name, args.object_names) + ) diff --git a/samples/snippets/zonal_buckets/storage_open_object_multiple_ranged_read.py b/samples/snippets/zonal_buckets/storage_open_object_multiple_ranged_read.py new file mode 100644 index 000000000..b0f64c486 --- /dev/null +++ b/samples/snippets/zonal_buckets/storage_open_object_multiple_ranged_read.py @@ -0,0 +1,85 @@ +#!/usr/bin/env python + +# Copyright 2026 Google Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the 'License'); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import asyncio +from io import BytesIO + +from google.cloud.storage.asyncio.async_grpc_client import AsyncGrpcClient +from google.cloud.storage.asyncio.async_multi_range_downloader import ( + AsyncMultiRangeDownloader, +) + + +# [START storage_open_object_multiple_ranged_read] +async def storage_open_object_multiple_ranged_read( + bucket_name, object_name, grpc_client=None +): + """Downloads multiple ranges of bytes from a single object into different buffers. + + grpc_client: an existing grpc_client to use, this is only for testing. + """ + if grpc_client is None: + grpc_client = AsyncGrpcClient() + + mrd = AsyncMultiRangeDownloader(grpc_client, bucket_name, object_name) + + try: + # Open the object, mrd always opens in read mode. + await mrd.open() + + # Specify four different buffers to download ranges into. + buffers = [BytesIO(), BytesIO(), BytesIO(), BytesIO()] + + # Define the ranges to download. Each range is a tuple of (start_byte, size, buffer). + # All ranges will download 10 bytes from different starting positions. + # We choose arbitrary start bytes for this example. An object should be large enough. + # A user can choose any start byte between 0 and `object_size`. + # If `start_bytes` is greater than `object_size`, mrd will throw an error. + ranges = [ + (0, 10, buffers[0]), + (20, 10, buffers[1]), + (40, 10, buffers[2]), + (60, 10, buffers[3]), + ] + + await mrd.download_ranges(ranges) + + finally: + await mrd.close() + + # Print the downloaded content from each buffer. 
+ for i, output_buffer in enumerate(buffers): + downloaded_size = output_buffer.getbuffer().nbytes + print( + f"Downloaded {downloaded_size} bytes into buffer {i + 1} from start byte {ranges[i][0]}: {output_buffer.getvalue()}" + ) + + +# [END storage_open_object_multiple_ranged_read] + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter + ) + parser.add_argument("--bucket_name", help="Your Cloud Storage bucket name.") + parser.add_argument("--object_name", help="Your Cloud Storage object name.") + + args = parser.parse_args() + + asyncio.run( + storage_open_object_multiple_ranged_read(args.bucket_name, args.object_name) + ) diff --git a/samples/snippets/zonal_buckets/storage_open_object_read_full_object.py b/samples/snippets/zonal_buckets/storage_open_object_read_full_object.py new file mode 100644 index 000000000..2e18caabe --- /dev/null +++ b/samples/snippets/zonal_buckets/storage_open_object_read_full_object.py @@ -0,0 +1,72 @@ +#!/usr/bin/env python + +# Copyright 2026 Google Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the 'License'); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import asyncio +from io import BytesIO + +from google.cloud.storage.asyncio.async_grpc_client import AsyncGrpcClient +from google.cloud.storage.asyncio.async_multi_range_downloader import ( + AsyncMultiRangeDownloader, +) + + +# [START storage_open_object_read_full_object] +async def storage_open_object_read_full_object( + bucket_name, object_name, grpc_client=None +): + """Downloads the entire content of an object using a multi-range downloader. + + grpc_client: an existing grpc_client to use, this is only for testing. + """ + if grpc_client is None: + grpc_client = AsyncGrpcClient() + + # mrd = Multi-Range-Downloader + mrd = AsyncMultiRangeDownloader(grpc_client, bucket_name, object_name) + + try: + # Open the object, mrd always opens in read mode. + await mrd.open() + + # This could be any buffer or file-like object. + output_buffer = BytesIO() + # A download range of (0, 0) means to read from the beginning to the end. + await mrd.download_ranges([(0, 0, output_buffer)]) + finally: + if mrd.is_stream_open: + await mrd.close() + + downloaded_bytes = output_buffer.getvalue() + print( + f"Downloaded all {len(downloaded_bytes)} bytes from object {object_name} in bucket {bucket_name}." 
+ ) + + +# [END storage_open_object_read_full_object] + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter + ) + parser.add_argument("--bucket_name", help="Your Cloud Storage bucket name.") + parser.add_argument("--object_name", help="Your Cloud Storage object name.") + + args = parser.parse_args() + + asyncio.run( + storage_open_object_read_full_object(args.bucket_name, args.object_name) + ) diff --git a/samples/snippets/zonal_buckets/storage_open_object_single_ranged_read.py b/samples/snippets/zonal_buckets/storage_open_object_single_ranged_read.py new file mode 100644 index 000000000..74bec43f6 --- /dev/null +++ b/samples/snippets/zonal_buckets/storage_open_object_single_ranged_read.py @@ -0,0 +1,77 @@ +#!/usr/bin/env python + +# Copyright 2026 Google Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the 'License'); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import asyncio +from io import BytesIO + +from google.cloud.storage.asyncio.async_grpc_client import AsyncGrpcClient +from google.cloud.storage.asyncio.async_multi_range_downloader import ( + AsyncMultiRangeDownloader, +) + + +# [START storage_open_object_single_ranged_read] +async def storage_open_object_single_ranged_read( + bucket_name, object_name, start_byte, size, grpc_client=None +): + """Downloads a range of bytes from an object. + + grpc_client: an existing grpc_client to use, this is only for testing. + """ + if grpc_client is None: + grpc_client = AsyncGrpcClient() + + mrd = AsyncMultiRangeDownloader(grpc_client, bucket_name, object_name) + + try: + # Open the object, mrd always opens in read mode. + await mrd.open() + + # requested range will be downloaded into this buffer, user may provide + # their own buffer or file-like object. + output_buffer = BytesIO() + await mrd.download_ranges([(start_byte, size, output_buffer)]) + finally: + if mrd.is_stream_open: + await mrd.close() + + # Downloaded size can differ from requested size if object is smaller. + # mrd will download at most up to the end of the object. + downloaded_size = output_buffer.getbuffer().nbytes + print(f"Downloaded {downloaded_size} bytes from {object_name}") + + +# [END storage_open_object_single_ranged_read] + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter + ) + parser.add_argument("--bucket_name", help="Your Cloud Storage bucket name.") + parser.add_argument("--object_name", help="Your Cloud Storage object name.") + parser.add_argument( + "--start_byte", type=int, help="The starting byte of the range." 
+ ) + parser.add_argument("--size", type=int, help="The number of bytes to download.") + + args = parser.parse_args() + + asyncio.run( + storage_open_object_single_ranged_read( + args.bucket_name, args.object_name, args.start_byte, args.size + ) + ) diff --git a/samples/snippets/zonal_buckets/storage_pause_and_resume_appendable_upload.py b/samples/snippets/zonal_buckets/storage_pause_and_resume_appendable_upload.py new file mode 100644 index 000000000..c758dc641 --- /dev/null +++ b/samples/snippets/zonal_buckets/storage_pause_and_resume_appendable_upload.py @@ -0,0 +1,94 @@ +#!/usr/bin/env python + +# Copyright 2026 Google Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the 'License'); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import asyncio + +from google.cloud.storage.asyncio.async_appendable_object_writer import ( + AsyncAppendableObjectWriter, +) +from google.cloud.storage.asyncio.async_grpc_client import AsyncGrpcClient + + +# [START storage_pause_and_resume_appendable_upload] +async def storage_pause_and_resume_appendable_upload( + bucket_name, object_name, grpc_client=None +): + """Demonstrates pausing and resuming an appendable object upload. + + grpc_client: an existing grpc_client to use, this is only for testing. + """ + if grpc_client is None: + grpc_client = AsyncGrpcClient() + + writer1 = AsyncAppendableObjectWriter( + client=grpc_client, + bucket_name=bucket_name, + object_name=object_name, + ) + await writer1.open() + await writer1.append(b"First part of the data. ") + print(f"Appended {writer1.persisted_size} bytes with the first writer.") + + # 2. After appending some data, close the writer to "pause" the upload. + # NOTE: you can pause indefinitely and still read the conetent uploaded so far using MRD. + await writer1.close() + + print("First writer closed. Upload is 'paused'.") + + # 3. Create a new writer, passing the generation number from the previous + # writer. This is a precondition to ensure that the object hasn't been + # modified since we last accessed it. + generation_to_resume = writer1.generation + print(f"Generation to resume from is: {generation_to_resume}") + + writer2 = AsyncAppendableObjectWriter( + client=grpc_client, + bucket_name=bucket_name, + object_name=object_name, + generation=generation_to_resume, + ) + # 4. Open the new writer. + try: + await writer2.open() + + # 5. Append some more data using the new writer. + await writer2.append(b"Second part of the data.") + print(f"Appended more data. Total size is now {writer2.persisted_size} bytes.") + finally: + # 6. Finally, close the new writer. + if writer2._is_stream_open: + await writer2.close() + print("Second writer closed. 
Full object uploaded.") + + +# [END storage_pause_and_resume_appendable_upload] + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter + ) + parser.add_argument("--bucket_name", help="Your Cloud Storage bucket name.") + parser.add_argument("--object_name", help="Your Cloud Storage object name.") + + args = parser.parse_args() + + asyncio.run( + storage_pause_and_resume_appendable_upload( + bucket_name=args.bucket_name, + object_name=args.object_name, + ) + ) diff --git a/samples/snippets/zonal_buckets/storage_read_appendable_object_tail.py b/samples/snippets/zonal_buckets/storage_read_appendable_object_tail.py new file mode 100644 index 000000000..9e4dcd738 --- /dev/null +++ b/samples/snippets/zonal_buckets/storage_read_appendable_object_tail.py @@ -0,0 +1,141 @@ +#!/usr/bin/env python + +# Copyright 2026 Google Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the 'License'); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import asyncio +import time +from datetime import datetime +from io import BytesIO + +from google.cloud.storage.asyncio.async_appendable_object_writer import ( + AsyncAppendableObjectWriter, +) +from google.cloud.storage.asyncio.async_grpc_client import AsyncGrpcClient +from google.cloud.storage.asyncio.async_multi_range_downloader import ( + AsyncMultiRangeDownloader, +) + +BYTES_TO_APPEND = b"fav_bytes." +NUM_BYTES_TO_APPEND_EVERY_SECOND = len(BYTES_TO_APPEND) + + +# [START storage_read_appendable_object_tail] +async def appender(writer: AsyncAppendableObjectWriter, duration: int): + """Appends 10 bytes to the object every second for a given duration.""" + print("Appender started.") + bytes_appended = 0 + for i in range(duration): + await writer.append(BYTES_TO_APPEND) + now = datetime.now().strftime("%Y-%m-%d %H:%M:%S.%f")[:-3] + bytes_appended += NUM_BYTES_TO_APPEND_EVERY_SECOND + print( + f"[{now}] Appended {NUM_BYTES_TO_APPEND_EVERY_SECOND} new bytes. Total appended: {bytes_appended} bytes." + ) + await asyncio.sleep(1) + print("Appender finished.") + + +async def tailer( + bucket_name: str, object_name: str, duration: int, client: AsyncGrpcClient +): + """Tails the object by reading new data as it is appended.""" + print("Tailer started.") + start_byte = 0 + start_time = time.monotonic() + mrd = AsyncMultiRangeDownloader(client, bucket_name, object_name) + try: + await mrd.open() + # Run the tailer for the specified duration. + while time.monotonic() - start_time < duration: + output_buffer = BytesIO() + # A download range of (start, 0) means to read from 'start' to the end. + await mrd.download_ranges([(start_byte, 0, output_buffer)]) + + bytes_downloaded = output_buffer.getbuffer().nbytes + if bytes_downloaded > 0: + now = datetime.now().strftime("%Y-%m-%d %H:%M:%S.%f")[:-3] + print( + f"[{now}] Tailer read {bytes_downloaded} new bytes: {output_buffer.getvalue()}" + ) + start_byte += bytes_downloaded + + await asyncio.sleep(0.1) # Poll for new data every 0.1 seconds. 
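+            # Because `start_byte` advances by the number of bytes read, each
+            # poll only fetches data appended since the previous iteration,
+            # mimicking `tail -f`.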
+ finally: + if mrd.is_stream_open: + await mrd.close() + print("Tailer finished.") + + +# read_appendable_object_tail simulates a "tail -f" command on a GCS object. It +# repeatedly polls an appendable object for new content. In a real +# application, the object would be written to by a separate process. +async def read_appendable_object_tail( + bucket_name: str, object_name: str, duration: int, grpc_client=None +): + """Main function to create an appendable object and run tasks. + + grpc_client: an existing grpc_client to use, this is only for testing. + """ + if grpc_client is None: + grpc_client = AsyncGrpcClient() + writer = AsyncAppendableObjectWriter( + client=grpc_client, + bucket_name=bucket_name, + object_name=object_name, + ) + # 1. Create an empty appendable object. + try: + # 1. Create an empty appendable object. + await writer.open() + print(f"Created empty appendable object: {object_name}") + + # 2. Create the appender and tailer coroutines. + appender_task = asyncio.create_task(appender(writer, duration)) + tailer_task = asyncio.create_task( + tailer(bucket_name, object_name, duration, grpc_client) + ) + + # 3. Execute the coroutines concurrently. + await asyncio.gather(appender_task, tailer_task) + finally: + if writer._is_stream_open: + await writer.close() + print("Writer closed.") + + +# [END storage_read_appendable_object_tail] + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description="Demonstrates tailing an appendable GCS object.", + formatter_class=argparse.RawDescriptionHelpFormatter, + ) + parser.add_argument("--bucket_name", help="Your Cloud Storage bucket name.") + parser.add_argument( + "--object_name", help="Your Cloud Storage object name to be created." + ) + parser.add_argument( + "--duration", + type=int, + default=60, + help="Duration in seconds to run the demo.", + ) + + args = parser.parse_args() + + asyncio.run( + read_appendable_object_tail(args.bucket_name, args.object_name, args.duration) + ) diff --git a/samples/snippets/zonal_buckets/zonal_snippets_test.py b/samples/snippets/zonal_buckets/zonal_snippets_test.py new file mode 100644 index 000000000..6852efe22 --- /dev/null +++ b/samples/snippets/zonal_buckets/zonal_snippets_test.py @@ -0,0 +1,260 @@ +# Copyright 2025 Google, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
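+# NOTE: these tests are skipped unless zonal system testing is explicitly
+# enabled. An illustrative invocation (the exact command line is an
+# assumption; the environment variables match the gating used below):
+#
+#   RUN_ZONAL_SYSTEM_TESTS=True ZONAL_BUCKET=<your-zonal-bucket> \
+#       pytest samples/snippets/zonal_buckets/zonal_snippets_test.py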
+ +import asyncio +import uuid +import os + +import pytest +from google.cloud.storage import Client +import contextlib + +from google.cloud.storage.asyncio.async_grpc_client import AsyncGrpcClient +from google.cloud.storage.asyncio.async_appendable_object_writer import ( + AsyncAppendableObjectWriter, +) + +# Import all the snippets +import storage_create_and_write_appendable_object +import storage_finalize_appendable_object_upload +import storage_open_multiple_objects_ranged_read +import storage_open_object_multiple_ranged_read +import storage_open_object_read_full_object +import storage_open_object_single_ranged_read +import storage_pause_and_resume_appendable_upload +import storage_read_appendable_object_tail + +pytestmark = pytest.mark.skipif( + os.getenv("RUN_ZONAL_SYSTEM_TESTS") != "True", + reason="Zonal system tests need to be explicitly enabled. This helps scheduling tests in Kokoro and Cloud Build.", +) + + +# TODO: replace this with a fixture once zonal bucket creation / deletion +# is supported in grpc client or json client client. +_ZONAL_BUCKET = os.getenv("ZONAL_BUCKET") + + +async def create_async_grpc_client(): + """Initializes async client and gets the current event loop.""" + return AsyncGrpcClient() + + +# Forcing a single event loop for the whole test session +@pytest.fixture(scope="session") +def event_loop(): + """Redefine pytest-asyncio's event_loop fixture to be session-scoped.""" + loop = asyncio.get_event_loop_policy().new_event_loop() + yield loop + loop.close() + + +@pytest.fixture(scope="session") +def async_grpc_client(event_loop): + """Yields a StorageAsyncClient that is closed after the test session.""" + grpc_client = event_loop.run_until_complete(create_async_grpc_client()) + yield grpc_client + + +@pytest.fixture(scope="session") +def json_client(): + client = Client() + with contextlib.closing(client): + yield client + + +async def create_appendable_object(grpc_client, object_name, data): + writer = AsyncAppendableObjectWriter( + client=grpc_client, + bucket_name=_ZONAL_BUCKET, + object_name=object_name, + generation=0, # throws `FailedPrecondition` if object already exists. + ) + await writer.open() + await writer.append(data) + await writer.close() + return writer.generation + + +# TODO: replace this with a fixture once zonal bucket creation / deletion +# is supported in grpc client or json client client. +_ZONAL_BUCKET = os.getenv("ZONAL_BUCKET") + + +def test_storage_create_and_write_appendable_object( + async_grpc_client, json_client, event_loop, capsys +): + object_name = f"zonal-snippets-test-{uuid.uuid4()}" + + event_loop.run_until_complete( + storage_create_and_write_appendable_object.storage_create_and_write_appendable_object( + _ZONAL_BUCKET, object_name, grpc_client=async_grpc_client + ) + ) + out, _ = capsys.readouterr() + assert f"Appended object {object_name} created of size" in out + + blob = json_client.bucket(_ZONAL_BUCKET).blob(object_name) + blob.delete() + + +def test_storage_finalize_appendable_object_upload( + async_grpc_client, json_client, event_loop, capsys +): + object_name = f"test-finalize-appendable-{uuid.uuid4()}" + event_loop.run_until_complete( + storage_finalize_appendable_object_upload.storage_finalize_appendable_object_upload( + _ZONAL_BUCKET, object_name, grpc_client=async_grpc_client + ) + ) + out, _ = capsys.readouterr() + assert f"Appendable object {object_name} created and finalized." 
in out + blob = json_client.bucket(_ZONAL_BUCKET).get_blob(object_name) + blob.delete() + + +def test_storage_pause_and_resume_appendable_upload( + async_grpc_client, json_client, event_loop, capsys +): + object_name = f"test-pause-resume-{uuid.uuid4()}" + event_loop.run_until_complete( + storage_pause_and_resume_appendable_upload.storage_pause_and_resume_appendable_upload( + _ZONAL_BUCKET, object_name, grpc_client=async_grpc_client + ) + ) + out, _ = capsys.readouterr() + assert "First writer closed. Upload is 'paused'." in out + assert "Second writer closed. Full object uploaded." in out + + blob = json_client.bucket(_ZONAL_BUCKET).get_blob(object_name) + blob.delete() + + +def test_storage_read_appendable_object_tail( + async_grpc_client, json_client, event_loop, capsys +): + object_name = f"test-read-tail-{uuid.uuid4()}" + event_loop.run_until_complete( + storage_read_appendable_object_tail.read_appendable_object_tail( + _ZONAL_BUCKET, object_name, duration=3, grpc_client=async_grpc_client + ) + ) + out, _ = capsys.readouterr() + assert f"Created empty appendable object: {object_name}" in out + assert "Appender started." in out + assert "Tailer started." in out + assert "Tailer read" in out + assert "Tailer finished." in out + assert "Writer closed." in out + + bucket = json_client.bucket(_ZONAL_BUCKET) + blob = bucket.blob(object_name) + blob.delete() + + +def test_storage_open_object_read_full_object( + async_grpc_client, json_client, event_loop, capsys +): + object_name = f"test-open-read-full-{uuid.uuid4()}" + data = b"Hello, is it me you're looking for?" + event_loop.run_until_complete( + create_appendable_object(async_grpc_client, object_name, data) + ) + event_loop.run_until_complete( + storage_open_object_read_full_object.storage_open_object_read_full_object( + _ZONAL_BUCKET, object_name, grpc_client=async_grpc_client + ) + ) + out, _ = capsys.readouterr() + assert ( + f"Downloaded all {len(data)} bytes from object {object_name} in bucket {_ZONAL_BUCKET}." + in out + ) + blob = json_client.bucket(_ZONAL_BUCKET).blob(object_name) + blob.delete() + + +def test_storage_open_object_single_ranged_read( + async_grpc_client, json_client, event_loop, capsys +): + object_name = f"test-open-single-range-{uuid.uuid4()}" + event_loop.run_until_complete( + create_appendable_object( + async_grpc_client, object_name, b"Hello, is it me you're looking for?" 
+ ) + ) + download_size = 5 + event_loop.run_until_complete( + storage_open_object_single_ranged_read.storage_open_object_single_ranged_read( + _ZONAL_BUCKET, + object_name, + start_byte=0, + size=download_size, + grpc_client=async_grpc_client, + ) + ) + out, _ = capsys.readouterr() + assert f"Downloaded {download_size} bytes from {object_name}" in out + blob = json_client.bucket(_ZONAL_BUCKET).blob(object_name) + blob.delete() + + +def test_storage_open_object_multiple_ranged_read( + async_grpc_client, json_client, event_loop, capsys +): + object_name = f"test-open-multi-range-{uuid.uuid4()}" + data = b"a" * 100 + event_loop.run_until_complete( + create_appendable_object(async_grpc_client, object_name, data) + ) + event_loop.run_until_complete( + storage_open_object_multiple_ranged_read.storage_open_object_multiple_ranged_read( + _ZONAL_BUCKET, object_name, grpc_client=async_grpc_client + ) + ) + out, _ = capsys.readouterr() + assert "Downloaded 10 bytes into buffer 1 from start byte 0: b'aaaaaaaaaa'" in out + assert "Downloaded 10 bytes into buffer 2 from start byte 20: b'aaaaaaaaaa'" in out + assert "Downloaded 10 bytes into buffer 3 from start byte 40: b'aaaaaaaaaa'" in out + assert "Downloaded 10 bytes into buffer 4 from start byte 60: b'aaaaaaaaaa'" in out + blob = json_client.bucket(_ZONAL_BUCKET).blob(object_name) + blob.delete() + + +def test_storage_open_multiple_objects_ranged_read( + async_grpc_client, json_client, event_loop, capsys +): + blob1_name = f"multi-obj-1-{uuid.uuid4()}" + blob2_name = f"multi-obj-2-{uuid.uuid4()}" + data1 = b"Content of object 1" + data2 = b"Content of object 2" + event_loop.run_until_complete( + create_appendable_object(async_grpc_client, blob1_name, data1) + ) + event_loop.run_until_complete( + create_appendable_object(async_grpc_client, blob2_name, data2) + ) + + event_loop.run_until_complete( + storage_open_multiple_objects_ranged_read.storage_open_multiple_objects_ranged_read( + _ZONAL_BUCKET, [blob1_name, blob2_name], grpc_client=async_grpc_client + ) + ) + out, _ = capsys.readouterr() + assert f"Downloaded {len(data1)} bytes from {blob1_name}" in out + assert f"Downloaded {len(data2)} bytes from {blob2_name}" in out + blob1 = json_client.bucket(_ZONAL_BUCKET).blob(blob1_name) + blob2 = json_client.bucket(_ZONAL_BUCKET).blob(blob2_name) + blob1.delete() + blob2.delete() diff --git a/setup.py b/setup.py index b45053856..02cd11140 100644 --- a/setup.py +++ b/setup.py @@ -59,6 +59,30 @@ "tracing": [ "opentelemetry-api >= 1.1.0, < 2.0.0", ], + "testing": [ + "google-cloud-testutils", + "numpy", + "psutil", + "py-cpuinfo", + "pytest-benchmark", + "PyYAML", + "mock", + "pytest", + "pytest-cov", + "pytest-asyncio", + "pytest-rerunfailures", + "pytest-xdist", + "google-cloud-testutils", + "google-cloud-iam", + "google-cloud-pubsub", + "google-cloud-kms", + "brotli", + "coverage", + "pyopenssl", + "opentelemetry-sdk", + "flake8", + "black", + ], } diff --git a/tests/conformance/test_bidi_reads.py b/tests/conformance/test_bidi_reads.py new file mode 100644 index 000000000..4157182cb --- /dev/null +++ b/tests/conformance/test_bidi_reads.py @@ -0,0 +1,266 @@ +import asyncio +import io +import uuid +import grpc +import requests + +from google.api_core import exceptions +from google.auth import credentials as auth_credentials +from google.cloud import _storage_v2 as storage_v2 + +from google.cloud.storage._experimental.asyncio.async_multi_range_downloader import ( + AsyncMultiRangeDownloader, +) + +# --- Configuration --- +PROJECT_NUMBER = "12345" # A dummy 
project number is fine for the testbench. +GRPC_ENDPOINT = "localhost:8888" +HTTP_ENDPOINT = "http://localhost:9000" +CONTENT_LENGTH = 1024 * 10 # 10 KB + + +def _is_retriable(exc): + """Predicate for identifying retriable errors.""" + return isinstance( + exc, + ( + exceptions.ServiceUnavailable, + exceptions.Aborted, # Required to retry on redirect + exceptions.InternalServerError, + exceptions.ResourceExhausted, + ), + ) + + +async def run_test_scenario( + gapic_client, http_client, bucket_name, object_name, scenario +): + """Runs a single fault-injection test scenario.""" + print(f"\n--- RUNNING SCENARIO: {scenario['name']} ---") + + retry_test_id = None + try: + # 1. Create a Retry Test resource on the testbench. + retry_test_config = { + "instructions": {scenario["method"]: [scenario["instruction"]]}, + "transport": "GRPC", + } + resp = http_client.post(f"{HTTP_ENDPOINT}/retry_test", json=retry_test_config) + resp.raise_for_status() + retry_test_id = resp.json()["id"] + + # 2. Set up downloader and metadata for fault injection. + downloader = await AsyncMultiRangeDownloader.create_mrd( + gapic_client, bucket_name, object_name + ) + fault_injection_metadata = (("x-retry-test-id", retry_test_id),) + + buffer = io.BytesIO() + + # 3. Execute the download and assert the outcome. + try: + await downloader.download_ranges( + [(0, 5 * 1024, buffer), (6 * 1024, 4 * 1024, buffer)], + metadata=fault_injection_metadata, + ) + # If an exception was expected, this line should not be reached. + if scenario["expected_error"] is not None: + raise AssertionError( + f"Expected exception {scenario['expected_error']} was not raised." + ) + + assert len(buffer.getvalue()) == 9 * 1024 + + except scenario["expected_error"] as e: + print(f"Caught expected exception for {scenario['name']}: {e}") + + await downloader.close() + + finally: + # 4. Clean up the Retry Test resource. + if retry_test_id: + http_client.delete(f"{HTTP_ENDPOINT}/retry_test/{retry_test_id}") + + +async def main(): + """Main function to set up resources and run all test scenarios.""" + channel = grpc.aio.insecure_channel(GRPC_ENDPOINT) + creds = auth_credentials.AnonymousCredentials() + transport = storage_v2.services.storage.transports.StorageGrpcAsyncIOTransport( + channel=channel, credentials=creds + ) + gapic_client = storage_v2.StorageAsyncClient(transport=transport) + http_client = requests.Session() + + bucket_name = f"grpc-test-bucket-{uuid.uuid4().hex[:8]}" + object_name = "retry-test-object" + + # Define all test scenarios + test_scenarios = [ + { + "name": "Retry on Service Unavailable (503)", + "method": "storage.objects.get", + "instruction": "return-503", + "expected_error": None, + }, + { + "name": "Retry on 500", + "method": "storage.objects.get", + "instruction": "return-500", + "expected_error": None, + }, + { + "name": "Retry on 504", + "method": "storage.objects.get", + "instruction": "return-504", + "expected_error": None, + }, + { + "name": "Retry on 429", + "method": "storage.objects.get", + "instruction": "return-429", + "expected_error": None, + }, + { + "name": "Smarter Resumption: Retry 503 after partial data", + "method": "storage.objects.get", + "instruction": "return-broken-stream-after-2K", + "expected_error": None, + }, + { + "name": "Retry on BidiReadObjectRedirectedError", + "method": "storage.objects.get", + "instruction": "redirect-send-handle-and-token-tokenval", # Testbench instruction for redirect + "expected_error": None, + }, + ] + + try: + # Create a single bucket and object for all tests to use. 
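+        # The testbench accepts anonymous credentials and the "projects/_"
+        # parent, so no real GCP project or authentication is needed here.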
+ content = b"A" * CONTENT_LENGTH + bucket_resource = storage_v2.Bucket(project=f"projects/{PROJECT_NUMBER}") + create_bucket_request = storage_v2.CreateBucketRequest( + parent="projects/_", bucket_id=bucket_name, bucket=bucket_resource + ) + await gapic_client.create_bucket(request=create_bucket_request) + + write_spec = storage_v2.WriteObjectSpec( + resource=storage_v2.Object( + bucket=f"projects/_/buckets/{bucket_name}", name=object_name + ) + ) + + async def write_req_gen(): + yield storage_v2.WriteObjectRequest( + write_object_spec=write_spec, + checksummed_data={"content": content}, + finish_write=True, + ) + + await gapic_client.write_object(requests=write_req_gen()) + + # Run all defined test scenarios. + for scenario in test_scenarios: + await run_test_scenario( + gapic_client, http_client, bucket_name, object_name, scenario + ) + + # Define and run test scenarios specifically for the open() method + open_test_scenarios = [ + { + "name": "Open: Retry on 503", + "method": "storage.objects.get", + "instruction": "return-503", + "expected_error": None, + }, + { + "name": "Open: Retry on BidiReadObjectRedirectedError", + "method": "storage.objects.get", + "instruction": "redirect-send-handle-and-token-tokenval", + "expected_error": None, + }, + { + "name": "Open: Fail Fast on 401", + "method": "storage.objects.get", + "instruction": "return-401", + "expected_error": exceptions.Unauthorized, + }, + ] + for scenario in open_test_scenarios: + await run_open_test_scenario( + gapic_client, http_client, bucket_name, object_name, scenario + ) + + except Exception: + import traceback + + traceback.print_exc() + finally: + # Clean up the test bucket. + try: + delete_object_req = storage_v2.DeleteObjectRequest( + bucket="projects/_/buckets/" + bucket_name, object=object_name + ) + await gapic_client.delete_object(request=delete_object_req) + + delete_bucket_req = storage_v2.DeleteBucketRequest( + name=f"projects/_/buckets/{bucket_name}" + ) + await gapic_client.delete_bucket(request=delete_bucket_req) + except Exception as e: + print(f"Warning: Cleanup failed: {e}") + + +async def run_open_test_scenario( + gapic_client, http_client, bucket_name, object_name, scenario +): + """Runs a fault-injection test scenario specifically for the open() method.""" + print(f"\n--- RUNNING SCENARIO: {scenario['name']} ---") + + retry_test_id = None + try: + # 1. Create a Retry Test resource on the testbench. + retry_test_config = { + "instructions": {scenario["method"]: [scenario["instruction"]]}, + "transport": "GRPC", + } + resp = http_client.post(f"{HTTP_ENDPOINT}/retry_test", json=retry_test_config) + resp.raise_for_status() + retry_test_id = resp.json()["id"] + print(f"Retry Test created with ID: {retry_test_id}") + + # 2. Set up metadata for fault injection. + fault_injection_metadata = (("x-retry-test-id", retry_test_id),) + + # 3. Execute the open (via create_mrd) and assert the outcome. + try: + downloader = await AsyncMultiRangeDownloader.create_mrd( + gapic_client, + bucket_name, + object_name, + metadata=fault_injection_metadata, + ) + + # If open was successful, perform a simple download to ensure the stream is usable. + buffer = io.BytesIO() + await downloader.download_ranges([(0, 1024, buffer)]) + await downloader.close() + assert len(buffer.getvalue()) == 1024 + + # If an exception was expected, this line should not be reached. + if scenario["expected_error"] is not None: + raise AssertionError( + f"Expected exception {scenario['expected_error']} was not raised." 
+ ) + + except scenario["expected_error"] as e: + print(f"Caught expected exception for {scenario['name']}: {e}") + + finally: + # 4. Clean up the Retry Test resource. + if retry_test_id: + http_client.delete(f"{HTTP_ENDPOINT}/retry_test/{retry_test_id}") + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/tests/conformance/test_bidi_writes.py b/tests/conformance/test_bidi_writes.py new file mode 100644 index 000000000..90dfaf5f8 --- /dev/null +++ b/tests/conformance/test_bidi_writes.py @@ -0,0 +1,267 @@ +import asyncio +import uuid +import grpc +import requests + +from google.api_core import exceptions +from google.auth import credentials as auth_credentials +from google.cloud import _storage_v2 as storage_v2 + +from google.api_core.retry_async import AsyncRetry +from google.cloud.storage._experimental.asyncio.async_appendable_object_writer import ( + AsyncAppendableObjectWriter, +) + +# --- Configuration --- +PROJECT_NUMBER = "12345" # A dummy project number is fine for the testbench. +GRPC_ENDPOINT = "localhost:8888" +HTTP_ENDPOINT = "http://localhost:9000" +CONTENT = b"A" * 1024 * 10 # 10 KB + + +def _is_retryable(exc): + return isinstance( + exc, + ( + exceptions.InternalServerError, + exceptions.ServiceUnavailable, + exceptions.DeadlineExceeded, + exceptions.TooManyRequests, + exceptions.Aborted, # For Redirects + ), + ) + + +async def run_test_scenario( + gapic_client, + http_client, + bucket_name, + object_name, + scenario, +): + """Runs a single fault-injection test scenario.""" + print(f"\n--- RUNNING SCENARIO: {scenario['name']} ---") + retry_count = 0 + + def on_retry_error(exc): + nonlocal retry_count + retry_count += 1 + print(f"Retry attempt {retry_count} triggered by: {type(exc).__name__}") + + custom_retry = AsyncRetry( + predicate=_is_retryable, + on_error=on_retry_error, + initial=0.1, # Short backoff for fast tests + multiplier=1.0, + ) + + use_default = scenario.get("use_default_policy", False) + policy_to_pass = None if use_default else custom_retry + + retry_test_id = None + try: + # 1. Create a Retry Test resource on the testbench. + retry_test_config = { + "instructions": {scenario["method"]: [scenario["instruction"]]}, + "transport": "GRPC", + } + resp = http_client.post(f"{HTTP_ENDPOINT}/retry_test", json=retry_test_config) + resp.raise_for_status() + retry_test_id = resp.json()["id"] + + # 2. Set up writer and metadata for fault injection. + writer = AsyncAppendableObjectWriter( + gapic_client, + bucket_name, + object_name, + ) + fault_injection_metadata = (("x-retry-test-id", retry_test_id),) + + # 3. Execute the write and assert the outcome. + try: + await writer.open( + metadata=fault_injection_metadata, retry_policy=policy_to_pass + ) + await writer.append( + CONTENT, metadata=fault_injection_metadata, retry_policy=policy_to_pass + ) + # await writer.finalize() + await writer.close(finalize_on_close=True) + + # If an exception was expected, this line should not be reached. + if scenario["expected_error"] is not None: + raise AssertionError( + f"Expected exception {scenario['expected_error']} was not raised." + ) + + # 4. Verify the object content. 
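+            # The object is streamed back with ReadObject and compared
+            # byte-for-byte against the appended payload.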
+ read_request = storage_v2.ReadObjectRequest( + bucket=f"projects/_/buckets/{bucket_name}", + object=object_name, + ) + read_stream = await gapic_client.read_object(request=read_request) + data = b"" + async for chunk in read_stream: + data += chunk.checksummed_data.content + assert data == CONTENT + if scenario["expected_error"] is None: + # Scenarios like 503, 500, smarter resumption, and redirects + # SHOULD trigger at least one retry attempt. + if not use_default: + assert ( + retry_count > 0 + ), f"Test passed but no retry was actually triggered for {scenario['name']}!" + else: + print("Successfully recovered using library's default policy.") + print(f"Success: {scenario['name']}") + + except Exception as e: + if scenario["expected_error"] is None or not isinstance( + e, scenario["expected_error"] + ): + raise + + if not use_default: + assert ( + retry_count == 0 + ), f"Retry was incorrectly triggered for non-retriable error in {scenario['name']}!" + print(f"Success: caught expected exception for {scenario['name']}: {e}") + + finally: + # 5. Clean up the Retry Test resource. + if retry_test_id: + http_client.delete(f"{HTTP_ENDPOINT}/retry_test/{retry_test_id}") + + +async def main(): + """Main function to set up resources and run all test scenarios.""" + channel = grpc.aio.insecure_channel(GRPC_ENDPOINT) + creds = auth_credentials.AnonymousCredentials() + transport = storage_v2.services.storage.transports.StorageGrpcAsyncIOTransport( + channel=channel, + credentials=creds, + ) + gapic_client = storage_v2.StorageAsyncClient(transport=transport) + http_client = requests.Session() + + bucket_name = f"grpc-test-bucket-{uuid.uuid4().hex[:8]}" + object_name_prefix = "retry-test-object-" + + # Define all test scenarios + test_scenarios = [ + { + "name": "Retry on Service Unavailable (503)", + "method": "storage.objects.insert", + "instruction": "return-503", + "expected_error": None, + }, + { + "name": "Retry on 500", + "method": "storage.objects.insert", + "instruction": "return-500", + "expected_error": None, + }, + { + "name": "Retry on 504", + "method": "storage.objects.insert", + "instruction": "return-504", + "expected_error": None, + }, + { + "name": "Retry on 429", + "method": "storage.objects.insert", + "instruction": "return-429", + "expected_error": None, + }, + { + "name": "Smarter Resumption: Retry 503 after partial data", + "method": "storage.objects.insert", + "instruction": "return-503-after-2K", + "expected_error": None, + }, + { + "name": "Retry on BidiWriteObjectRedirectedError", + "method": "storage.objects.insert", + "instruction": "redirect-send-handle-and-token-tokenval", + "expected_error": None, + }, + { + "name": "Fail on 401", + "method": "storage.objects.insert", + "instruction": "return-401", + "expected_error": exceptions.Unauthorized, + }, + { + "name": "Default Policy: Retry on 503", + "method": "storage.objects.insert", + "instruction": "return-503", + "expected_error": None, + "use_default_policy": True, + }, + { + "name": "Default Policy: Retry on 503", + "method": "storage.objects.insert", + "instruction": "return-500", + "expected_error": None, + "use_default_policy": True, + }, + { + "name": "Default Policy: Retry on BidiWriteObjectRedirectedError", + "method": "storage.objects.insert", + "instruction": "redirect-send-handle-and-token-tokenval", + "expected_error": None, + "use_default_policy": True, + }, + { + "name": "Default Policy: Smarter Ressumption", + "method": "storage.objects.insert", + "instruction": "return-503-after-2K", + "expected_error": 
None, + "use_default_policy": True, + }, + ] + + try: + bucket_resource = storage_v2.Bucket(project=f"projects/{PROJECT_NUMBER}") + create_bucket_request = storage_v2.CreateBucketRequest( + parent="projects/_", bucket_id=bucket_name, bucket=bucket_resource + ) + await gapic_client.create_bucket(request=create_bucket_request) + + for i, scenario in enumerate(test_scenarios): + object_name = f"{object_name_prefix}{i}" + await run_test_scenario( + gapic_client, + http_client, + bucket_name, + object_name, + scenario, + ) + + except Exception: + import traceback + + traceback.print_exc() + finally: + # Clean up the test bucket. + try: + list_objects_req = storage_v2.ListObjectsRequest( + parent=f"projects/_/buckets/{bucket_name}", + ) + list_objects_res = await gapic_client.list_objects(request=list_objects_req) + async for obj in list_objects_res: + delete_object_req = storage_v2.DeleteObjectRequest( + bucket=f"projects/_/buckets/{bucket_name}", object=obj.name + ) + await gapic_client.delete_object(request=delete_object_req) + + delete_bucket_req = storage_v2.DeleteBucketRequest( + name=f"projects/_/buckets/{bucket_name}" + ) + await gapic_client.delete_bucket(request=delete_bucket_req) + except Exception as e: + print(f"Warning: Cleanup failed: {e}") + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/tests/perf/__init__.py b/tests/perf/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/perf/microbenchmarks/README.md b/tests/perf/microbenchmarks/README.md new file mode 100644 index 000000000..a3e045682 --- /dev/null +++ b/tests/perf/microbenchmarks/README.md @@ -0,0 +1,41 @@ +# Performance Microbenchmarks + +This directory contains performance microbenchmarks for the Python Storage client library. + +## Usage + +To run the benchmarks, use `pytest` with the `--benchmark-json` flag to specify an output file for the results. + +Example: +```bash +pytest --benchmark-json=output.json -vv -s tests/perf/microbenchmarks/reads/test_reads.py +``` + +### Running a Specific Test + +To run a single test, append `::` followed by the test name to the file path. + +Examples: +```bash +pytest --benchmark-json=output.json -vv -s tests/perf/microbenchmarks/reads/test_reads.py::test_downloads_single_proc_single_coro +``` +```bash +pytest --benchmark-json=output.json -vv -s tests/perf/microbenchmarks/writes/test_writes.py::test_uploads_single_proc_single_coro +``` + +## Configuration + +The benchmarks are configured using `config.yaml` files located in the respective subdirectories (e.g., `reads/config.yaml`). + +## Overriding Buckets + +You can override the buckets used in the benchmarks by setting environment variables. Please refer to the specific benchmark implementation for the environment variable names. + +## Output + +The benchmarks produce a JSON file with the results. This file can be converted to a CSV file for easier analysis in spreadsheets using the provided `json_to_csv.py` script. + +Example: +```bash +python3 tests/perf/microbenchmarks/json_to_csv.py output.json +``` diff --git a/tests/perf/microbenchmarks/__init__.py b/tests/perf/microbenchmarks/__init__.py new file mode 100644 index 000000000..58d482ea3 --- /dev/null +++ b/tests/perf/microbenchmarks/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/perf/microbenchmarks/_utils.py b/tests/perf/microbenchmarks/_utils.py new file mode 100644 index 000000000..ff29b8783 --- /dev/null +++ b/tests/perf/microbenchmarks/_utils.py @@ -0,0 +1,167 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Any, List +import statistics +import io +import os + + +def publish_benchmark_extra_info( + benchmark: Any, + params: Any, + benchmark_group: str = "read", + true_times: List[float] = [], +) -> None: + """ + Helper function to publish benchmark parameters to the extra_info property. + """ + + benchmark.extra_info["num_files"] = params.num_files + benchmark.extra_info["file_size"] = params.file_size_bytes + benchmark.extra_info["chunk_size"] = params.chunk_size_bytes + if benchmark_group == "write": + benchmark.extra_info["pattern"] = "seq" + else: + benchmark.extra_info["pattern"] = params.pattern + benchmark.extra_info["coros"] = params.num_coros + benchmark.extra_info["rounds"] = params.rounds + benchmark.extra_info["bucket_name"] = params.bucket_name + benchmark.extra_info["bucket_type"] = params.bucket_type + benchmark.extra_info["processes"] = params.num_processes + benchmark.group = benchmark_group + + object_size = params.file_size_bytes + num_files = params.num_files + total_uploaded_mib = object_size / (1024 * 1024) * num_files + min_throughput = total_uploaded_mib / benchmark.stats["max"] + max_throughput = total_uploaded_mib / benchmark.stats["min"] + mean_throughput = total_uploaded_mib / benchmark.stats["mean"] + median_throughput = total_uploaded_mib / benchmark.stats["median"] + + benchmark.extra_info["throughput_MiB_s_min"] = min_throughput + benchmark.extra_info["throughput_MiB_s_max"] = max_throughput + benchmark.extra_info["throughput_MiB_s_mean"] = mean_throughput + benchmark.extra_info["throughput_MiB_s_median"] = median_throughput + + print("\nThroughput Statistics (MiB/s):") + print(f" Min: {min_throughput:.2f} (from max time)") + print(f" Max: {max_throughput:.2f} (from min time)") + print(f" Mean: {mean_throughput:.2f} (approx, from mean time)") + print(f" Median: {median_throughput:.2f} (approx, from median time)") + + if true_times: + throughputs = [total_uploaded_mib / t for t in true_times] + true_min_throughput = min(throughputs) + true_max_throughput = max(throughputs) + true_mean_throughput = statistics.mean(throughputs) + true_median_throughput = statistics.median(throughputs) + + benchmark.extra_info["true_throughput_MiB_s_min"] = true_min_throughput + benchmark.extra_info["true_throughput_MiB_s_max"] = true_max_throughput + 
benchmark.extra_info["true_throughput_MiB_s_mean"] = true_mean_throughput + benchmark.extra_info["true_throughput_MiB_s_median"] = true_median_throughput + + print("\nThroughput Statistics from true_times (MiB/s):") + print(f" Min: {true_min_throughput:.2f}") + print(f" Max: {true_max_throughput:.2f}") + print(f" Mean: {true_mean_throughput:.2f}") + print(f" Median: {true_median_throughput:.2f}") + + # Get benchmark name, rounds, and iterations + name = benchmark.name + rounds = benchmark.stats["rounds"] + iterations = benchmark.stats["iterations"] + + # Header for throughput table + header = "\n\n" + "-" * 125 + "\n" + header += "Throughput Benchmark (MiB/s)\n" + header += "-" * 125 + "\n" + header += f"{'Name':<50} {'Min':>10} {'Max':>10} {'Mean':>10} {'StdDev':>10} {'Median':>10} {'Rounds':>8} {'Iterations':>12}\n" + header += "-" * 125 + + # Data row for throughput table + # The table headers (Min, Max) refer to the throughput values. + row = f"{name:<50} {min_throughput:>10.4f} {max_throughput:>10.4f} {mean_throughput:>10.4f} {'N/A':>10} {median_throughput:>10.4f} {rounds:>8} {iterations:>12}" + + print(header) + print(row) + print("-" * 125) + + +class RandomBytesIO(io.RawIOBase): + """ + A file-like object that generates random bytes using os.urandom. + It enforces a fixed size and an upper safety cap. + """ + + # 10 GiB default safety cap + DEFAULT_CAP = 10 * 1024 * 1024 * 1024 + + def __init__(self, size, max_size=DEFAULT_CAP): + """ + Args: + size (int): The exact size of the virtual file in bytes. + max_size (int): The maximum allowed size to prevent safety issues. + """ + if size is None: + raise ValueError("Size must be defined (cannot be infinite).") + + if size > max_size: + raise ValueError( + f"Requested size {size} exceeds the maximum limit of {max_size} bytes (10 GiB)." + ) + + self._size = size + self._pos = 0 + + def read(self, n=-1): + # 1. Handle "read all" (n=-1) + if n is None or n < 0: + n = self._size - self._pos + + # 2. Handle EOF (End of File) + if self._pos >= self._size: + return b"" + + # 3. Clamp read amount to remaining size + # This ensures we stop exactly at `size` bytes. + n = min(n, self._size - self._pos) + + # 4. Generate data + data = os.urandom(n) + self._pos += len(data) + return data + + def readable(self): + return True + + def seekable(self): + return True + + def tell(self): + return self._pos + + def seek(self, offset, whence=io.SEEK_SET): + if whence == io.SEEK_SET: + new_pos = offset + elif whence == io.SEEK_CUR: + new_pos = self._pos + offset + elif whence == io.SEEK_END: + new_pos = self._size + offset + else: + raise ValueError(f"Invalid whence: {whence}") + + # Clamp position to valid range [0, size] + self._pos = max(0, min(new_pos, self._size)) + return self._pos diff --git a/tests/perf/microbenchmarks/conftest.py b/tests/perf/microbenchmarks/conftest.py new file mode 100644 index 000000000..e748c6e43 --- /dev/null +++ b/tests/perf/microbenchmarks/conftest.py @@ -0,0 +1,160 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import contextlib
+from typing import Any
+from tests.perf.microbenchmarks.resource_monitor import ResourceMonitor
+import pytest
+from tests.system._helpers import delete_blob
+
+import asyncio
+import multiprocessing
+import os
+import uuid
+from google.cloud import storage
+from google.cloud.storage.asyncio.async_appendable_object_writer import (
+    AsyncAppendableObjectWriter,
+)
+from google.cloud.storage.asyncio.async_grpc_client import AsyncGrpcClient
+from tests.perf.microbenchmarks.writes.parameters import WriteParameters
+
+_OBJECT_NAME_PREFIX = "micro-benchmark"
+
+
+@pytest.fixture(scope="function")
+def blobs_to_delete():
+    blobs_to_delete = []
+
+    yield blobs_to_delete
+
+    for blob in blobs_to_delete:
+        delete_blob(blob)
+
+
+@pytest.fixture(scope="session")
+def storage_client():
+    from google.cloud.storage import Client
+
+    client = Client()
+    with contextlib.closing(client):
+        yield client
+
+
+@pytest.fixture
+def monitor():
+    """
+    Provides the ResourceMonitor class.
+    Usage: with monitor() as m: ...
+    """
+    return ResourceMonitor
+
+
+def publish_resource_metrics(benchmark: Any, monitor: ResourceMonitor) -> None:
+    """
+    Helper function to publish resource monitor results to the extra_info property.
+    """
+    benchmark.extra_info.update(
+        {
+            "cpu_max_global": f"{monitor.max_cpu:.2f}",
+            "mem_max": f"{monitor.max_mem:.2f}",
+            "net_throughput_mb_s": f"{monitor.throughput_mb_s:.2f}",
+            "vcpus": monitor.vcpus,
+        }
+    )
+
+
+async def upload_appendable_object(bucket_name, object_name, object_size, chunk_size):
+    # The flush interval is set to a little over 1 GiB to minimize the number of
+    # flushes. This helper only creates the "appendable" objects that the read
+    # benchmarks consume, so write performance is not a concern here.
+    writer = AsyncAppendableObjectWriter(
+        AsyncGrpcClient(),
+        bucket_name,
+        object_name,
+        writer_options={"FLUSH_INTERVAL_BYTES": 1026 * 1024**2},
+    )
+    await writer.open()
+    uploaded_bytes = 0
+    while uploaded_bytes < object_size:
+        bytes_to_upload = min(chunk_size, object_size - uploaded_bytes)
+        await writer.append(os.urandom(bytes_to_upload))
+        uploaded_bytes += bytes_to_upload
+    object_metadata = await writer.close(finalize_on_close=True)
+    assert object_metadata.size == uploaded_bytes
+    return uploaded_bytes
+
+
+def upload_simple_object(bucket_name, object_name, object_size, chunk_size):
+    storage_client = storage.Client()
+    bucket = storage_client.bucket(bucket_name)
+    blob = bucket.blob(object_name)
+    blob.chunk_size = chunk_size
+    data = os.urandom(object_size)
+    blob.upload_from_string(data)
+    return object_size
+
+
+def _upload_worker(args):
+    bucket_name, object_name, object_size, chunk_size, bucket_type = args
+    if bucket_type == "zonal":
+        uploaded_bytes = asyncio.run(
+            upload_appendable_object(bucket_name, object_name, object_size, chunk_size)
+        )
+    else:
+        uploaded_bytes = upload_simple_object(
+            bucket_name, object_name, object_size, chunk_size
+        )
+    return object_name, uploaded_bytes
+
+
+def _create_files(
+    num_files, bucket_name, bucket_type, object_size, chunk_size=1024 * 1024 * 1024
+):
+    """
+    Create/Upload objects for benchmarking and return a list of their names.
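+
+    Objects are uploaded in parallel via a spawn-context multiprocessing.Pool,
+    one upload task per object: zonal buckets use the appendable-object gRPC
+    writer, regional buckets use a plain JSON upload.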
+ """ + object_names = [ + f"{_OBJECT_NAME_PREFIX}-{uuid.uuid4().hex[:5]}" for _ in range(num_files) + ] + + args_list = [ + (bucket_name, object_names[i], object_size, chunk_size, bucket_type) + for i in range(num_files) + ] + + ctx = multiprocessing.get_context("spawn") + with ctx.Pool() as pool: + results = pool.map(_upload_worker, args_list) + + total_uploaded_bytes = sum(r[1] for r in results) + assert total_uploaded_bytes == object_size * num_files + + return [r[0] for r in results] + + +@pytest.fixture +def workload_params(request): + params = request.param + if isinstance(params, WriteParameters): + files_names = [ + f"{_OBJECT_NAME_PREFIX}-{uuid.uuid4().hex[:5]}" + for _ in range(params.num_files) + ] + else: + files_names = _create_files( + params.num_files, + params.bucket_name, + params.bucket_type, + params.file_size_bytes, + ) + return params, files_names diff --git a/tests/perf/microbenchmarks/json_to_csv.py b/tests/perf/microbenchmarks/json_to_csv.py new file mode 100644 index 000000000..1ef58f907 --- /dev/null +++ b/tests/perf/microbenchmarks/json_to_csv.py @@ -0,0 +1,190 @@ +import json +import csv +import argparse +import logging +import numpy as np + +MB = 1024 * 1024 + + +def _process_benchmark_result(bench, headers, extra_info_headers, stats_headers): + """ + Process a single benchmark result and prepare it for CSV reporting. + + This function extracts relevant statistics and metadata from a benchmark + run, calculates derived metrics like percentiles and throughput, and + formats it as a dictionary. + + Args: + bench (dict): The dictionary for a single benchmark from the JSON output. + headers (list): The list of all header names for the CSV. + extra_info_headers (list): Headers from the 'extra_info' section. + stats_headers (list): Headers from the 'stats' section. + + """ + row = {h: "" for h in headers} + row["name"] = bench.get("name", "") + row["group"] = bench.get("group", "") + + extra_info = bench.get("extra_info", {}) + + # Populate extra_info and stats + for key in extra_info_headers: + row[key] = extra_info.get(key) + for key in stats_headers: + row[key] = bench.get("stats", {}).get(key) + + # Handle threads/coros mapping + if "threads" in row: + row["threads"] = extra_info.get("num_coros", extra_info.get("coros")) + + # Calculate percentiles + timings = bench.get("stats", {}).get("data") + if timings: + row["p90"] = np.percentile(timings, 90) + row["p95"] = np.percentile(timings, 95) + row["p99"] = np.percentile(timings, 99) + + # Calculate max throughput + file_size = extra_info.get("file_size_bytes", extra_info.get("file_size", 0)) + num_files = extra_info.get("num_files", 1) + total_bytes = file_size * num_files + + min_time = bench.get("stats", {}).get("min") + if min_time and min_time > 0: + row["max_throughput_mb_s"] = (total_bytes / min_time) / MB + else: + row["max_throughput_mb_s"] = 0.0 + + return row + + +def _generate_report(json_path, csv_path): + """Generate a CSV summary report from the pytest-benchmark JSON output. + + Args: + json_path (str): The path to the JSON file containing benchmark results. + csv_path (str): The path where the CSV report will be saved. + + Returns: + str: The path to the generated CSV report file. 
+ + """ + logging.info(f"Generating CSV report from {json_path}") + + with open(json_path, "r") as f: + data = json.load(f) + + benchmarks = data.get("benchmarks", []) + if not benchmarks: + logging.warning("No benchmarks found in the JSON file.") + return + + # headers order - name group block_size bucket_name bucket_type chunk_size cpu_max_global file_size mem_max net_throughput_mb_s num_files pattern processes rounds threads vcpus min max mean median stddev p90 p95 p99 max_throughput_mb_s + # if there are any other column keep it at the afterwards. + ordered_headers = [ + "name", + "group", + "block_size", + "bucket_name", + "bucket_type", + "chunk_size", + "cpu_max_global", + "file_size", + "mem_max", + "net_throughput_mb_s", + "num_files", + "pattern", + "processes", + "rounds", + "threads", + "vcpus", + "min", + "max", + "mean", + "median", + "stddev", + "p90", + "p95", + "p99", + "max_throughput_mb_s", + ] + + # Gather all available headers from the data + all_available_headers = set(["name", "group"]) + stats_headers = ["min", "max", "mean", "median", "stddev"] + custom_headers = ["p90", "p95", "p99", "max_throughput_mb_s"] + + all_available_headers.update(stats_headers) + all_available_headers.update(custom_headers) + + extra_info_keys = set() + for bench in benchmarks: + if "extra_info" in bench and isinstance(bench["extra_info"], dict): + extra_info_keys.update(bench["extra_info"].keys()) + all_available_headers.update(extra_info_keys) + + # Construct the final header list + final_headers = list(ordered_headers) + + # Add any headers from the data that are not in the ordered list + for header in sorted(list(all_available_headers)): + if header not in final_headers: + final_headers.append(header) + + # We still need the full list of extra_info headers for _process_benchmark_result + extra_info_headers = sorted(list(extra_info_keys)) + + with open(csv_path, "w", newline="") as f: + writer = csv.writer(f) + writer.writerow(final_headers) + + for bench in benchmarks: + row = _process_benchmark_result( + bench, final_headers, extra_info_headers, stats_headers + ) + writer.writerow([row.get(h, "") for h in final_headers]) + + logging.info(f"CSV report generated at {csv_path}") + return csv_path + + +def main(): + """ + Converts a JSON benchmark file to a CSV file. + + The CSV file will contain the 'name' of each benchmark and all fields + from the 'extra_info' section. 
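+
+    Example invocation (file names are illustrative; they match the argparse
+    defaults below):
+
+        python json_to_csv.py --input_file output.json --output_file output.csv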
+ """ + parser = argparse.ArgumentParser(description="Convert benchmark JSON to CSV.") + parser.add_argument( + "--input_file", + nargs="?", + default="output.json", + help="Path to the input JSON file (default: output.json)", + ) + parser.add_argument( + "--output_file", + nargs="?", + default="output.csv", + help="Path to the output CSV file (default: output.csv)", + ) + args = parser.parse_args() + + logging.basicConfig( + level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s" + ) + + try: + _generate_report(args.input_file, args.output_file) + print(f"Successfully converted {args.input_file} to {args.output_file}") + except FileNotFoundError: + logging.error(f"Error: Input file not found at {args.input_file}") + except json.JSONDecodeError: + logging.error(f"Error: Could not decode JSON from {args.input_file}") + except Exception as e: + logging.error(f"An unexpected error occurred: {e}") + + +if __name__ == "__main__": + main() diff --git a/tests/perf/microbenchmarks/parameters.py b/tests/perf/microbenchmarks/parameters.py new file mode 100644 index 000000000..72b8476b6 --- /dev/null +++ b/tests/perf/microbenchmarks/parameters.py @@ -0,0 +1,28 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from dataclasses import dataclass + + +@dataclass +class IOBenchmarkParameters: + name: str + workload_name: str + bucket_name: str + bucket_type: str + num_coros: int + num_processes: int + num_files: int + rounds: int + chunk_size_bytes: int + file_size_bytes: int diff --git a/tests/perf/microbenchmarks/reads/__init__.py b/tests/perf/microbenchmarks/reads/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/perf/microbenchmarks/reads/config.py b/tests/perf/microbenchmarks/reads/config.py new file mode 100644 index 000000000..7d83e3f8e --- /dev/null +++ b/tests/perf/microbenchmarks/reads/config.py @@ -0,0 +1,113 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import itertools +import os +from typing import Dict, List + +import yaml + +try: + from tests.perf.microbenchmarks.reads.parameters import ReadParameters +except ModuleNotFoundError: + from reads.parameters import ReadParameters + + +def _get_params() -> Dict[str, List[ReadParameters]]: + """Generates a dictionary of benchmark parameters for read operations. + + This function reads configuration from a `config.yaml` file, which defines + common parameters (like bucket types, file sizes) and different workloads. 
+ It then generates all possible combinations of these parameters for each + workload using `itertools.product`. + + The resulting parameter sets are encapsulated in `ReadParameters` objects + and organized by workload name in the returned dictionary. + + Bucket names can be overridden by setting the `DEFAULT_RAPID_ZONAL_BUCKET` + and `DEFAULT_STANDARD_BUCKET` environment variables. + + Returns: + Dict[str, List[ReadParameters]]: A dictionary where keys are workload + names (e.g., 'read_seq', 'read_rand_multi_coros') and values are lists + of `ReadParameters` objects, each representing a unique benchmark scenario. + """ + params: Dict[str, List[ReadParameters]] = {} + config_path = os.path.join(os.path.dirname(__file__), "config.yaml") + with open(config_path, "r") as f: + config = yaml.safe_load(f) + + common_params = config["common"] + bucket_types = common_params["bucket_types"] + file_sizes_mib = common_params["file_sizes_mib"] + chunk_sizes_mib = common_params["chunk_sizes_mib"] + rounds = common_params["rounds"] + + bucket_map = { + "zonal": os.environ.get( + "DEFAULT_RAPID_ZONAL_BUCKET", + config["defaults"]["DEFAULT_RAPID_ZONAL_BUCKET"], + ), + "regional": os.environ.get( + "DEFAULT_STANDARD_BUCKET", config["defaults"]["DEFAULT_STANDARD_BUCKET"] + ), + } + + for workload in config["workload"]: + workload_name = workload["name"] + params[workload_name] = [] + pattern = workload["pattern"] + processes = workload["processes"] + coros = workload["coros"] + + # Create a product of all parameter combinations + product = itertools.product( + bucket_types, + file_sizes_mib, + chunk_sizes_mib, + processes, + coros, + ) + + for ( + bucket_type, + file_size_mib, + chunk_size_mib, + num_processes, + num_coros, + ) in product: + file_size_bytes = file_size_mib * 1024 * 1024 + chunk_size_bytes = chunk_size_mib * 1024 * 1024 + bucket_name = bucket_map[bucket_type] + + num_files = num_processes * num_coros + + # Create a descriptive name for the parameter set + name = f"{pattern}_{bucket_type}_{num_processes}p_{num_coros}c" + + params[workload_name].append( + ReadParameters( + name=name, + workload_name=workload_name, + pattern=pattern, + bucket_name=bucket_name, + bucket_type=bucket_type, + num_coros=num_coros, + num_processes=num_processes, + num_files=num_files, + rounds=rounds, + chunk_size_bytes=chunk_size_bytes, + file_size_bytes=file_size_bytes, + ) + ) + return params diff --git a/tests/perf/microbenchmarks/reads/config.yaml b/tests/perf/microbenchmarks/reads/config.yaml new file mode 100644 index 000000000..25bfd92c8 --- /dev/null +++ b/tests/perf/microbenchmarks/reads/config.yaml @@ -0,0 +1,49 @@ +common: + bucket_types: + - "regional" + - "zonal" + file_sizes_mib: + - 1024 # 1GiB + chunk_sizes_mib: [100] + rounds: 10 + +workload: + + ############# single process single coroutine ######### + - name: "read_seq" + pattern: "seq" + coros: [1] + processes: [1] + + - name: "read_rand" + pattern: "rand" + coros: [1] + processes: [1] + + ############# single process multi coroutine ######### + + - name: "read_seq_multi_coros" + pattern: "seq" + coros: [2, 4, 8, 16] + processes: [1] + + - name: "read_rand_multi_coros" + pattern: "rand" + coros: [2, 4, 8, 16] + processes: [1] + + ############# multi process multi coroutine ######### + - name: "read_seq_multi_process" + pattern: "seq" + coros: [1, 2, 4] + processes: [2, 16, 48, 96] + + - name: "read_rand_multi_process" + pattern: "rand" + coros: [1, 2, 4] + processes: [2, 16, 48, 96] + + +defaults: + DEFAULT_RAPID_ZONAL_BUCKET: 
"chandrasiri-benchmarks-zb" + DEFAULT_STANDARD_BUCKET: "chandrasiri-benchmarks-rb" diff --git a/tests/perf/microbenchmarks/reads/parameters.py b/tests/perf/microbenchmarks/reads/parameters.py new file mode 100644 index 000000000..0785a4147 --- /dev/null +++ b/tests/perf/microbenchmarks/reads/parameters.py @@ -0,0 +1,20 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from dataclasses import dataclass +from ..parameters import IOBenchmarkParameters + + +@dataclass +class ReadParameters(IOBenchmarkParameters): + pattern: str diff --git a/tests/perf/microbenchmarks/reads/test_reads.py b/tests/perf/microbenchmarks/reads/test_reads.py new file mode 100644 index 000000000..d51102cea --- /dev/null +++ b/tests/perf/microbenchmarks/reads/test_reads.py @@ -0,0 +1,422 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Microbenchmarks for Google Cloud Storage read operations. + +This module contains performance benchmarks for various read patterns from Google Cloud Storage. +It includes three main test functions: +- `test_downloads_single_proc_single_coro`: Benchmarks reads using a single process and a single coroutine. +- `test_downloads_single_proc_multi_coro`: Benchmarks reads using a single process and multiple coroutines. +- `test_downloads_multi_proc_multi_coro`: Benchmarks reads using multiple processes and multiple coroutines. + +All other functions in this module are helper methods for these three tests. +""" + +import time +import asyncio +import random +from io import BytesIO +import logging + +import pytest + +from google.cloud.storage.asyncio.async_grpc_client import AsyncGrpcClient +from google.cloud.storage.asyncio.async_multi_range_downloader import ( + AsyncMultiRangeDownloader, +) +from tests.perf.microbenchmarks._utils import publish_benchmark_extra_info +from tests.perf.microbenchmarks.conftest import ( + publish_resource_metrics, +) +import tests.perf.microbenchmarks.reads.config as config +from concurrent.futures import ThreadPoolExecutor +import multiprocessing + +all_params = config._get_params() + + +async def create_client(): + """Initializes async client and gets the current event loop.""" + return AsyncGrpcClient() + + +async def download_chunks_using_mrd_async(client, filename, other_params, chunks): + # start timer. 
+ start_time = time.monotonic_ns() + + total_bytes_downloaded = 0 + mrd = AsyncMultiRangeDownloader(client, other_params.bucket_name, filename) + await mrd.open() + for offset, size in chunks: + buffer = BytesIO() + await mrd.download_ranges([(offset, size, buffer)]) + total_bytes_downloaded += buffer.tell() + await mrd.close() + + assert total_bytes_downloaded == other_params.file_size_bytes + + # end timer. + end_time = time.monotonic_ns() + elapsed_time = end_time - start_time + return elapsed_time / 1_000_000_000 + + +def download_chunks_using_mrd(loop, client, filename, other_params, chunks): + return loop.run_until_complete( + download_chunks_using_mrd_async(client, filename, other_params, chunks) + ) + + +def download_chunks_using_json(_, json_client, filename, other_params, chunks): + bucket = json_client.bucket(other_params.bucket_name) + blob = bucket.blob(filename) + start_time = time.monotonic_ns() + for offset, size in chunks: + _ = blob.download_as_bytes(start=offset, end=offset + size - 1) + return (time.monotonic_ns() - start_time) / 1_000_000_000 + + +@pytest.mark.parametrize( + "workload_params", + all_params["read_rand"] + all_params["read_seq"], + indirect=True, + ids=lambda p: p.name, +) +def test_downloads_single_proc_single_coro( + benchmark, storage_client, blobs_to_delete, monitor, workload_params +): + """ + Benchmarks reads using a single process and a single coroutine. + It creates chunks based on object size and chunk_size, then passes them to either + `download_chunks_using_mrd` (for zonal buckets) or `download_chunks_using_json` (for regional buckets) + for benchmarking using `benchmark.pedantic`. + """ + params, files_names = workload_params + + object_size = params.file_size_bytes + chunk_size = params.chunk_size_bytes + chunks = [] + for offset in range(0, object_size, chunk_size): + size = min(chunk_size, object_size - offset) + chunks.append((offset, size)) + + if params.pattern == "rand": + logging.info("randomizing chunks") + random.shuffle(chunks) + + if params.bucket_type == "zonal": + logging.info("bucket type zonal") + target_func = download_chunks_using_mrd + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + client = loop.run_until_complete(create_client()) + else: + logging.info("bucket type regional") + target_func = download_chunks_using_json + loop = None + client = storage_client + + output_times = [] + + def target_wrapper(*args, **kwargs): + result = target_func(*args, **kwargs) + output_times.append(result) + return output_times + + try: + with monitor() as m: + output_times = benchmark.pedantic( + target=target_wrapper, + iterations=1, + rounds=params.rounds, + args=( + loop, + client, + files_names[0], + params, + chunks, + ), + ) + finally: + if loop is not None: + tasks = asyncio.all_tasks(loop=loop) + for task in tasks: + task.cancel() + loop.run_until_complete(asyncio.gather(*tasks, return_exceptions=True)) + loop.close() + publish_benchmark_extra_info(benchmark, params, true_times=output_times) + publish_resource_metrics(benchmark, m) + + blobs_to_delete.extend( + storage_client.bucket(params.bucket_name).blob(f) for f in files_names + ) + + +def download_files_using_mrd_multi_coro(loop, client, files, other_params, chunks): + """ + Downloads multiple files concurrently using AsyncMultiRangeDownloader (MRD) with asyncio. + + For each file, it creates a coroutine to download its chunks using `download_chunks_using_mrd_async`. + All coroutines are then executed concurrently using `asyncio.gather`. 
+ The function returns the maximum latency observed among all coroutines. + + Args: + loop: The asyncio event loop. + client: The AsyncGrpcClient instance. + files (list): A list of filenames to download. + other_params: An object containing benchmark parameters (e.g., bucket_name, file_size_bytes). + chunks (list): A list of (offset, size) tuples representing the parts of each file to download. + + Returns: + float: The maximum latency (in seconds) among all coroutines. + """ + + async def main(): + if len(files) == 1: + result = await download_chunks_using_mrd_async( + client, files[0], other_params, chunks + ) + return [result] + else: + tasks = [] + for f in files: + tasks.append( + download_chunks_using_mrd_async(client, f, other_params, chunks) + ) + return await asyncio.gather(*tasks) + + results = loop.run_until_complete(main()) + return max(results) + + +def download_files_using_json_multi_threaded( + _, json_client, files, other_params, chunks +): + """ + Downloads multiple files concurrently using the JSON API with a ThreadPoolExecutor. + + For each file, it submits a task to a `ThreadPoolExecutor` to download its chunks + using `download_chunks_using_json`. The number of concurrent downloads is + determined by `other_params.num_coros` (which acts as `max_workers`). + The function returns the maximum latency among all concurrent downloads. + + The `chunks` parameter is a list of (offset, size) tuples representing + the parts of each file to download. + """ + results = [] + # In the context of multi-coro, num_coros is the number of files to download concurrently. + # So we can use it as max_workers for the thread pool. + with ThreadPoolExecutor(max_workers=other_params.num_coros) as executor: + futures = [] + for f in files: + future = executor.submit( + download_chunks_using_json, None, json_client, f, other_params, chunks + ) + futures.append(future) + + for future in futures: + results.append(future.result()) + + return max(results) + + +@pytest.mark.parametrize( + "workload_params", + all_params["read_seq_multi_coros"] + all_params["read_rand_multi_coros"], + indirect=True, + ids=lambda p: p.name, +) +def test_downloads_single_proc_multi_coro( + benchmark, storage_client, blobs_to_delete, monitor, workload_params +): + params, files_names = workload_params + + object_size = params.file_size_bytes + chunk_size = params.chunk_size_bytes + chunks = [] + for offset in range(0, object_size, chunk_size): + size = min(chunk_size, object_size - offset) + chunks.append((offset, size)) + + if params.pattern == "rand": + logging.info("randomizing chunks") + random.shuffle(chunks) + + if params.bucket_type == "zonal": + logging.info("bucket type zonal") + target_func = download_files_using_mrd_multi_coro + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + client = loop.run_until_complete(create_client()) + else: + logging.info("bucket type regional") + target_func = download_files_using_json_multi_threaded + loop = None + client = storage_client + + output_times = [] + + def target_wrapper(*args, **kwargs): + result = target_func(*args, **kwargs) + output_times.append(result) + return output_times + + try: + with monitor() as m: + output_times = benchmark.pedantic( + target=target_wrapper, + iterations=1, + rounds=params.rounds, + args=( + loop, + client, + files_names, + params, + chunks, + ), + ) + finally: + if loop is not None: + tasks = asyncio.all_tasks(loop=loop) + for task in tasks: + task.cancel() + loop.run_until_complete(asyncio.gather(*tasks, 
return_exceptions=True)) + loop.close() + publish_benchmark_extra_info(benchmark, params, true_times=output_times) + publish_resource_metrics(benchmark, m) + + blobs_to_delete.extend( + storage_client.bucket(params.bucket_name).blob(f) for f in files_names + ) + + +# --- Global Variables for Worker Process --- +worker_loop = None +worker_client = None +worker_json_client = None + + +def _worker_init(bucket_type): + """Initializes a persistent event loop and client for each worker process.""" + global worker_loop, worker_client, worker_json_client + if bucket_type == "zonal": + worker_loop = asyncio.new_event_loop() + asyncio.set_event_loop(worker_loop) + worker_client = worker_loop.run_until_complete(create_client()) + else: # regional + from google.cloud import storage + + worker_json_client = storage.Client() + + +def _download_files_worker(files_to_download, other_params, chunks, bucket_type): + if bucket_type == "zonal": + # The loop and client are already initialized in _worker_init. + # download_files_using_mrd_multi_coro returns max latency of coros + return download_files_using_mrd_multi_coro( + worker_loop, worker_client, files_to_download, other_params, chunks + ) + else: # regional + # download_files_using_json_multi_threaded returns max latency of threads + return download_files_using_json_multi_threaded( + None, worker_json_client, files_to_download, other_params, chunks + ) + + +def download_files_mp_mc_wrapper(pool, files_names, params, chunks, bucket_type): + num_coros = params.num_coros # This is n, number of files per process + + # Distribute filenames to processes + filenames_per_process = [ + files_names[i : i + num_coros] for i in range(0, len(files_names), num_coros) + ] + args = [ + ( + filenames, + params, + chunks, + bucket_type, + ) + for filenames in filenames_per_process + ] + + results = pool.starmap(_download_files_worker, args) + return max(results) + + +@pytest.mark.parametrize( + "workload_params", + all_params["read_seq_multi_process"] + all_params["read_rand_multi_process"], + indirect=True, + ids=lambda p: p.name, +) +def test_downloads_multi_proc_multi_coro( + benchmark, storage_client, blobs_to_delete, monitor, workload_params +): + """ + Benchmarks reads using multiple processes and multiple coroutines. + + This test distributes `m*n` files among `m` processes, where each process + downloads `n` files concurrently using `n` coroutines. The processes are spawned + in "spawn" mode. The reported latency for each round is the maximum latency + observed across all processes. 
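+
+    Each worker process is initialized once via `_worker_init`, so creating the
+    event loop and gRPC or JSON client happens outside the timed download calls.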
+ """ + params, files_names = workload_params + logging.info(f"num files: {len(files_names)}") + + object_size = params.file_size_bytes + chunk_size = params.chunk_size_bytes + chunks = [] + for offset in range(0, object_size, chunk_size): + size = min(chunk_size, object_size - offset) + chunks.append((offset, size)) + + if params.pattern == "rand": + logging.info("randomizing chunks") + random.shuffle(chunks) + + ctx = multiprocessing.get_context("spawn") + pool = ctx.Pool( + processes=params.num_processes, + initializer=_worker_init, + initargs=(params.bucket_type,), + ) + output_times = [] + + def target_wrapper(*args, **kwargs): + result = download_files_mp_mc_wrapper(pool, *args, **kwargs) + output_times.append(result) + return output_times + + try: + with monitor() as m: + output_times = benchmark.pedantic( + target=target_wrapper, + iterations=1, + rounds=params.rounds, + args=( + files_names, + params, + chunks, + params.bucket_type, + ), + ) + finally: + pool.close() + pool.join() + publish_benchmark_extra_info(benchmark, params, true_times=output_times) + publish_resource_metrics(benchmark, m) + + blobs_to_delete.extend( + storage_client.bucket(params.bucket_name).blob(f) for f in files_names + ) diff --git a/tests/perf/microbenchmarks/resource_monitor.py b/tests/perf/microbenchmarks/resource_monitor.py new file mode 100644 index 000000000..8ad2a27b7 --- /dev/null +++ b/tests/perf/microbenchmarks/resource_monitor.py @@ -0,0 +1,99 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
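+#
+# Illustrative usage (the monitored workload below is a placeholder):
+#
+#     with ResourceMonitor() as m:
+#         run_workload()  # hypothetical function under measurement
+#     print(m.max_cpu, m.max_mem, m.throughput_mb_s)
+#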
+import threading +import time + +import psutil + + +class ResourceMonitor: + def __init__(self): + self.interval = 1.0 + + self.vcpus = psutil.cpu_count() or 1 + self.max_cpu = 0.0 + self.max_mem = 0.0 + + # Network and Time tracking + self.start_time = 0.0 + self.duration = 0.0 + self.start_net = None + self.net_sent_mb = 0.0 + self.net_recv_mb = 0.0 + + self._stop_event = threading.Event() + self._thread = None + + def __enter__(self): + self.start_net = psutil.net_io_counters() + self.start_time = time.perf_counter() + self.start() + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + self.stop() + self.duration = time.perf_counter() - self.start_time + end_net = psutil.net_io_counters() + + self.net_sent_mb = (end_net.bytes_sent - self.start_net.bytes_sent) / ( + 1024 * 1024 + ) + self.net_recv_mb = (end_net.bytes_recv - self.start_net.bytes_recv) / ( + 1024 * 1024 + ) + + def _monitor(self): + psutil.cpu_percent(interval=None) + current_process = psutil.Process() + while not self._stop_event.is_set(): + try: + # CPU and Memory tracking for current process tree + total_cpu = current_process.cpu_percent(interval=None) + current_mem = current_process.memory_info().rss + for child in current_process.children(recursive=True): + try: + total_cpu += child.cpu_percent(interval=None) + current_mem += child.memory_info().rss + except (psutil.NoSuchProcess, psutil.AccessDenied): + continue + + # Normalize CPU by number of vcpus + global_cpu = total_cpu / self.vcpus + + mem = current_mem + + if global_cpu > self.max_cpu: + self.max_cpu = global_cpu + if mem > self.max_mem: + self.max_mem = mem + except psutil.NoSuchProcess: + pass + + time.sleep(self.interval) + + def start(self): + self._thread = threading.Thread(target=self._monitor, daemon=True) + self._thread.start() + + def stop(self): + self._stop_event.set() + if self._thread: + self._thread.join() + + @property + def throughput_mb_s(self): + """Calculates combined network throughput.""" + if self.duration <= 0: + return 0.0 + return (self.net_sent_mb + self.net_recv_mb) / self.duration diff --git a/tests/perf/microbenchmarks/writes/__init__.py b/tests/perf/microbenchmarks/writes/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/perf/microbenchmarks/writes/config.py b/tests/perf/microbenchmarks/writes/config.py new file mode 100644 index 000000000..d823260f9 --- /dev/null +++ b/tests/perf/microbenchmarks/writes/config.py @@ -0,0 +1,105 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import itertools +import os +from typing import Dict, List + +import yaml + +try: + from tests.perf.microbenchmarks.writes.parameters import WriteParameters +except ModuleNotFoundError: + from parameters import WriteParameters + + +def get_write_params() -> Dict[str, List[WriteParameters]]: + """Generates benchmark parameters from a YAML configuration file. 
+ + This function reads the configuration from `config.yaml`, located in the + same directory, and generates all possible combinations of write parameters + based on the defined workloads. It uses `itertools.product` to create + a Cartesian product of parameters like bucket types, file sizes, etc. + + Returns: + Dict[str, List[WriteParameters]]: A dictionary where keys are workload + names and values are lists of `WriteParameters` instances for that + workload. + """ + params: Dict[str, List[WriteParameters]] = {} + config_path = os.path.join(os.path.dirname(__file__), "config.yaml") + with open(config_path, "r") as f: + config = yaml.safe_load(f) + + common_params = config["common"] + bucket_types = common_params["bucket_types"] + file_sizes_mib = common_params["file_sizes_mib"] + chunk_sizes_mib = common_params["chunk_sizes_mib"] + rounds = common_params["rounds"] + + bucket_map = { + "zonal": os.environ.get( + "DEFAULT_RAPID_ZONAL_BUCKET", + config["defaults"]["DEFAULT_RAPID_ZONAL_BUCKET"], + ), + "regional": os.environ.get( + "DEFAULT_STANDARD_BUCKET", config["defaults"]["DEFAULT_STANDARD_BUCKET"] + ), + } + + for workload in config["workload"]: + workload_name = workload["name"] + params[workload_name] = [] + processes = workload["processes"] + coros = workload["coros"] + + # Create a product of all parameter combinations + product = itertools.product( + bucket_types, + file_sizes_mib, + chunk_sizes_mib, + processes, + coros, + ) + + for ( + bucket_type, + file_size_mib, + chunk_size_mib, + num_processes, + num_coros, + ) in product: + file_size_bytes = file_size_mib * 1024 * 1024 + chunk_size_bytes = chunk_size_mib * 1024 * 1024 + bucket_name = bucket_map[bucket_type] + + num_files = num_processes * num_coros + + # Create a descriptive name for the parameter set + name = f"{workload_name}_{bucket_type}_{num_processes}p_{num_coros}c" + + params[workload_name].append( + WriteParameters( + name=name, + workload_name=workload_name, + bucket_name=bucket_name, + bucket_type=bucket_type, + num_coros=num_coros, + num_processes=num_processes, + num_files=num_files, + rounds=rounds, + chunk_size_bytes=chunk_size_bytes, + file_size_bytes=file_size_bytes, + ) + ) + return params diff --git a/tests/perf/microbenchmarks/writes/config.yaml b/tests/perf/microbenchmarks/writes/config.yaml new file mode 100644 index 000000000..b4d93ba52 --- /dev/null +++ b/tests/perf/microbenchmarks/writes/config.yaml @@ -0,0 +1,34 @@ +common: + bucket_types: + - "regional" + - "zonal" + file_sizes_mib: + - 1024 # 1GiB + chunk_sizes_mib: [100] + rounds: 10 + +workload: + + ############# single proc single coroutines ######### + - name: "write_seq" + pattern: "seq" + coros: [1] + processes: [1] + + ############# single proc multiple coroutines ######### + + - name: "write_seq_multi_coros" + pattern: "seq" + coros: [2, 4, 8, 16] + processes: [1] + + ############# multiple proc multiple coroutines ######### + - name: "write_seq_multi_process" + pattern: "seq" + coros: [1, 2] + processes: [8, 16, 32, 64] + + +defaults: + DEFAULT_RAPID_ZONAL_BUCKET: "chandrasiri-benchmarks-zb" + DEFAULT_STANDARD_BUCKET: "chandrasiri-benchmarks-rb" diff --git a/tests/perf/microbenchmarks/writes/parameters.py b/tests/perf/microbenchmarks/writes/parameters.py new file mode 100644 index 000000000..8d44b93dc --- /dev/null +++ b/tests/perf/microbenchmarks/writes/parameters.py @@ -0,0 +1,20 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with 
the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from dataclasses import dataclass +from ..parameters import IOBenchmarkParameters + + +@dataclass +class WriteParameters(IOBenchmarkParameters): + pass diff --git a/tests/perf/microbenchmarks/writes/test_writes.py b/tests/perf/microbenchmarks/writes/test_writes.py new file mode 100644 index 000000000..02a0f5e4f --- /dev/null +++ b/tests/perf/microbenchmarks/writes/test_writes.py @@ -0,0 +1,434 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Microbenchmarks for Google Cloud Storage write operations. + +This module contains performance benchmarks for various write patterns to Google Cloud Storage. +It includes three main test functions: +- `test_uploads_single_proc_single_coro`: Benchmarks uploads using a single process and a single coroutine. +- `test_uploads_single_proc_multi_coro`: Benchmarks uploads using a single process and multiple coroutines. +- `test_uploads_multi_proc_multi_coro`: Benchmarks uploads using multiple processes and multiple coroutines. + +All other functions in this module are helper methods for these three tests. +""" + +import os +import time +import asyncio +from concurrent.futures import ThreadPoolExecutor +import multiprocessing +import logging + +import pytest +from google.cloud.storage.asyncio.async_grpc_client import AsyncGrpcClient +from google.cloud.storage.asyncio.async_appendable_object_writer import ( + AsyncAppendableObjectWriter, +) + +from tests.perf.microbenchmarks._utils import ( + publish_benchmark_extra_info, + RandomBytesIO, +) +from tests.perf.microbenchmarks.conftest import publish_resource_metrics +import tests.perf.microbenchmarks.writes.config as config +from google.cloud import storage + +# Get write parameters +all_params = config.get_write_params() + + +async def create_client(): + """Initializes async client and gets the current event loop.""" + return AsyncGrpcClient() + + +async def upload_chunks_using_grpc_async(client, filename, other_params): + """Uploads a file in chunks using the gRPC API asynchronously. + + Args: + client: The async gRPC client. + filename (str): The name of the object to create. + other_params: An object containing benchmark parameters like bucket_name, + file_size_bytes, and chunk_size_bytes. + + Returns: + float: The total time taken for the upload in seconds. 
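+
+    Note: the timed region includes generating each chunk with os.urandom,
+    opening the writer, appending all chunks, and closing the stream.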
+ """ + start_time = time.monotonic_ns() + + writer = AsyncAppendableObjectWriter( + client=client, bucket_name=other_params.bucket_name, object_name=filename + ) + await writer.open() + + uploaded_bytes = 0 + upload_size = other_params.file_size_bytes + chunk_size = other_params.chunk_size_bytes + + while uploaded_bytes < upload_size: + bytes_to_upload = min(chunk_size, upload_size - uploaded_bytes) + data = os.urandom(bytes_to_upload) + await writer.append(data) + uploaded_bytes += bytes_to_upload + await writer.close() + + assert uploaded_bytes == upload_size + + end_time = time.monotonic_ns() + elapsed_time = end_time - start_time + return elapsed_time / 1_000_000_000 + + +def upload_chunks_using_grpc(loop, client, filename, other_params): + """Wrapper to run the async gRPC upload in a synchronous context. + + Args: + loop: The asyncio event loop. + client: The async gRPC client. + filename (str): The name of the object to create. + other_params: An object containing benchmark parameters. + + Returns: + float: The total time taken for the upload in seconds. + """ + return loop.run_until_complete( + upload_chunks_using_grpc_async(client, filename, other_params) + ) + + +def upload_using_json(_, json_client, filename, other_params): + """Uploads a file using the JSON API. + + Args: + _ (any): Unused. + json_client: The standard Python Storage client. + filename (str): The name of the object to create. + other_params: An object containing benchmark parameters like bucket_name + and file_size_bytes. + + Returns: + float: The total time taken for the upload in seconds. + """ + start_time = time.monotonic_ns() + + bucket = json_client.bucket(other_params.bucket_name) + blob = bucket.blob(filename) + upload_size = other_params.file_size_bytes + # Don't use BytesIO because it'll report high memory usage for large files. + # `RandomBytesIO` generates random bytes on the fly. + in_mem_file = RandomBytesIO(upload_size) + blob.upload_from_file(in_mem_file) + + end_time = time.monotonic_ns() + elapsed_time = end_time - start_time + return elapsed_time / 1_000_000_000 + + +@pytest.mark.parametrize( + "workload_params", + all_params["write_seq"], + indirect=True, + ids=lambda p: p.name, +) +def test_uploads_single_proc_single_coro( + benchmark, storage_client, blobs_to_delete, monitor, workload_params +): + """ + Benchmarks uploads using a single process and a single coroutine. + It passes the workload to either `upload_chunks_using_grpc` (for zonal buckets) + or `upload_using_json` (for regional buckets) for benchmarking using `benchmark.pedantic`. 
+ """ + params, files_names = workload_params + + if params.bucket_type == "zonal": + logging.info("bucket type zonal") + target_func = upload_chunks_using_grpc + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + client = loop.run_until_complete(create_client()) + else: + logging.info("bucket type regional") + target_func = upload_using_json + loop = None + client = storage_client + + output_times = [] + + def target_wrapper(*args, **kwargs): + result = target_func(*args, **kwargs) + output_times.append(result) + return output_times + + try: + with monitor() as m: + output_times = benchmark.pedantic( + target=target_wrapper, + iterations=1, + rounds=params.rounds, + args=( + loop, + client, + files_names[0], + params, + ), + ) + finally: + if loop is not None: + tasks = asyncio.all_tasks(loop=loop) + for task in tasks: + task.cancel() + loop.run_until_complete(asyncio.gather(*tasks, return_exceptions=True)) + loop.close() + publish_benchmark_extra_info( + benchmark, params, benchmark_group="write", true_times=output_times + ) + publish_resource_metrics(benchmark, m) + + blobs_to_delete.extend( + storage_client.bucket(params.bucket_name).blob(f) for f in files_names + ) + + +def upload_files_using_grpc_multi_coro(loop, client, files, other_params): + """Uploads multiple files concurrently using gRPC with asyncio. + + Args: + loop: The asyncio event loop. + client: The async gRPC client. + files (list): A list of filenames to upload. + other_params: An object containing benchmark parameters. + + Returns: + float: The maximum latency observed among all coroutines. + """ + + async def main(): + tasks = [] + for f in files: + tasks.append(upload_chunks_using_grpc_async(client, f, other_params)) + return await asyncio.gather(*tasks) + + results = loop.run_until_complete(main()) + return max(results) + + +def upload_files_using_json_multi_threaded(_, json_client, files, other_params): + """Uploads multiple files concurrently using the JSON API with a ThreadPoolExecutor. + + Args: + _ (any): Unused. + json_client: The standard Python Storage client. + files (list): A list of filenames to upload. + other_params: An object containing benchmark parameters. + + Returns: + float: The maximum latency observed among all concurrent uploads. + """ + results = [] + with ThreadPoolExecutor(max_workers=other_params.num_coros) as executor: + futures = [] + for f in files: + future = executor.submit( + upload_using_json, None, json_client, f, other_params + ) + futures.append(future) + + for future in futures: + results.append(future.result()) + + return max(results) + + +@pytest.mark.parametrize( + "workload_params", + all_params["write_seq_multi_coros"], + indirect=True, + ids=lambda p: p.name, +) +def test_uploads_single_proc_multi_coro( + benchmark, storage_client, blobs_to_delete, monitor, workload_params +): + """ + Benchmarks uploads using a single process and multiple coroutines. + + For zonal buckets, it uses `upload_files_using_grpc_multi_coro` to upload + multiple files concurrently with asyncio. For regional buckets, it uses + `upload_files_using_json_multi_threaded` with a ThreadPoolExecutor. 
+ """ + params, files_names = workload_params + + if params.bucket_type == "zonal": + logging.info("bucket type zonal") + target_func = upload_files_using_grpc_multi_coro + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + client = loop.run_until_complete(create_client()) + else: + logging.info("bucket type regional") + target_func = upload_files_using_json_multi_threaded + loop = None + client = storage_client + + output_times = [] + + def target_wrapper(*args, **kwargs): + result = target_func(*args, **kwargs) + output_times.append(result) + return output_times + + try: + with monitor() as m: + output_times = benchmark.pedantic( + target=target_wrapper, + iterations=1, + rounds=params.rounds, + args=( + loop, + client, + files_names, + params, + ), + ) + finally: + if loop is not None: + tasks = asyncio.all_tasks(loop=loop) + for task in tasks: + task.cancel() + loop.run_until_complete(asyncio.gather(*tasks, return_exceptions=True)) + loop.close() + publish_benchmark_extra_info( + benchmark, params, benchmark_group="write", true_times=output_times + ) + publish_resource_metrics(benchmark, m) + + blobs_to_delete.extend( + storage_client.bucket(params.bucket_name).blob(f) for f in files_names + ) + + +def _upload_files_worker(files_to_upload, other_params, bucket_type): + """A worker function for multi-processing uploads. + + Initializes a client and calls the appropriate multi-coroutine upload function. + This function is intended to be called in a separate process. + + Args: + files_to_upload (list): List of filenames for this worker to upload. + other_params: An object containing benchmark parameters. + bucket_type (str): The type of bucket ('zonal' or 'regional'). + + Returns: + float: The maximum latency from the uploads performed by this worker. + """ + if bucket_type == "zonal": + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + client = loop.run_until_complete(create_client()) + try: + result = upload_files_using_grpc_multi_coro( + loop, client, files_to_upload, other_params + ) + finally: + # cleanup loop + tasks = asyncio.all_tasks(loop=loop) + for task in tasks: + task.cancel() + loop.run_until_complete(asyncio.gather(*tasks, return_exceptions=True)) + loop.close() + return result + else: # regional + json_client = storage.Client() + return upload_files_using_json_multi_threaded( + None, json_client, files_to_upload, other_params + ) + + +def upload_files_mp_mc_wrapper(files_names, params): + """Wrapper for multi-process, multi-coroutine uploads. + + Distributes files among a pool of processes and calls the worker function. + + Args: + files_names (list): The full list of filenames to upload. + params: An object containing benchmark parameters (num_processes, num_coros). + + Returns: + float: The maximum latency observed across all processes. 
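+
+    Unlike the read benchmarks, which reuse a single process pool across rounds,
+    this wrapper creates a fresh spawn-context pool on every call, so pool
+    start-up cost shows up in the pytest-benchmark timings (the returned
+    per-upload latencies exclude it).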
+ """ + num_processes = params.num_processes + num_coros = params.num_coros + + filenames_per_process = [ + files_names[i : i + num_coros] for i in range(0, len(files_names), num_coros) + ] + + args = [ + ( + filenames, + params, + params.bucket_type, + ) + for filenames in filenames_per_process + ] + + ctx = multiprocessing.get_context("spawn") + with ctx.Pool(processes=num_processes) as pool: + results = pool.starmap(_upload_files_worker, args) + + return max(results) + + +@pytest.mark.parametrize( + "workload_params", + all_params["write_seq_multi_process"], + indirect=True, + ids=lambda p: p.name, +) +def test_uploads_multi_proc_multi_coro( + benchmark, storage_client, blobs_to_delete, monitor, workload_params +): + """ + Benchmarks uploads using multiple processes and multiple coroutines. + + This test distributes files among a pool of processes using `upload_files_mp_mc_wrapper`. + The reported latency for each round is the maximum latency observed across all processes. + """ + params, files_names = workload_params + + output_times = [] + + def target_wrapper(*args, **kwargs): + result = upload_files_mp_mc_wrapper(*args, **kwargs) + output_times.append(result) + return output_times + + try: + with monitor() as m: + output_times = benchmark.pedantic( + target=target_wrapper, + iterations=1, + rounds=params.rounds, + args=( + files_names, + params, + ), + ) + finally: + publish_benchmark_extra_info( + benchmark, params, benchmark_group="write", true_times=output_times + ) + publish_resource_metrics(benchmark, m) + + blobs_to_delete.extend( + storage_client.bucket(params.bucket_name).blob(f) for f in files_names + ) diff --git a/tests/system/test_notification.py b/tests/system/test_notification.py index c21d836a3..28c07aeb0 100644 --- a/tests/system/test_notification.py +++ b/tests/system/test_notification.py @@ -60,13 +60,19 @@ def topic_path(storage_client, topic_name): @pytest.fixture(scope="session") def notification_topic(storage_client, publisher_client, topic_path, no_mtls): _helpers.retry_429(publisher_client.create_topic)(request={"name": topic_path}) - policy = publisher_client.get_iam_policy(request={"resource": topic_path}) - binding = policy.bindings.add() - binding.role = "roles/pubsub.publisher" - binding.members.append( - f"serviceAccount:{storage_client.get_service_account_email()}" - ) - publisher_client.set_iam_policy(request={"resource": topic_path, "policy": policy}) + try: + policy = publisher_client.get_iam_policy(request={"resource": topic_path}) + binding = policy.bindings.add() + binding.role = "roles/pubsub.publisher" + binding.members.append( + f"serviceAccount:{storage_client.get_service_account_email()}" + ) + publisher_client.set_iam_policy( + request={"resource": topic_path, "policy": policy} + ) + yield topic_path + finally: + publisher_client.delete_topic(request={"topic": topic_path}) @pytest.mark.skip(reason="until b/470069573 is fixed") diff --git a/tests/system/test_zonal.py b/tests/system/test_zonal.py index 05bb317d7..e62f46e1b 100644 --- a/tests/system/test_zonal.py +++ b/tests/system/test_zonal.py @@ -5,16 +5,22 @@ from io import BytesIO # python additional imports +import google_crc32c + import pytest +import gc # current library imports -from google.cloud.storage._experimental.asyncio.async_grpc_client import AsyncGrpcClient -from google.cloud.storage._experimental.asyncio.async_appendable_object_writer import ( +from google.cloud.storage.asyncio.async_grpc_client import AsyncGrpcClient +from 
google.cloud.storage.asyncio.async_appendable_object_writer import ( AsyncAppendableObjectWriter, + _DEFAULT_FLUSH_INTERVAL_BYTES, ) -from google.cloud.storage._experimental.asyncio.async_multi_range_downloader import ( +from google.cloud.storage.asyncio.async_multi_range_downloader import ( AsyncMultiRangeDownloader, ) +from google.api_core.exceptions import FailedPrecondition, NotFound + pytestmark = pytest.mark.skipif( os.getenv("RUN_ZONAL_SYSTEM_TESTS") != "True", @@ -28,113 +34,570 @@ _BYTES_TO_UPLOAD = b"dummy_bytes_to_write_read_and_delete_appendable_object" -async def write_one_appendable_object( - bucket_name: str, - object_name: str, - data: bytes, -) -> None: - """Helper to write an appendable object.""" - grpc_client = AsyncGrpcClient(attempt_direct_path=True).grpc_client - writer = AsyncAppendableObjectWriter(grpc_client, bucket_name, object_name) - await writer.open() - await writer.append(data) - await writer.close() - - -@pytest.fixture(scope="function") -def appendable_object(storage_client, blobs_to_delete): - """Fixture to create and cleanup an appendable object.""" - object_name = f"appendable_obj_for_mrd-{str(uuid.uuid4())[:4]}" - asyncio.run( - write_one_appendable_object( - _ZONAL_BUCKET, - object_name, - _BYTES_TO_UPLOAD, - ) - ) - yield object_name +async def create_async_grpc_client(attempt_direct_path=True): + """Initializes async client and gets the current event loop.""" + return AsyncGrpcClient(attempt_direct_path=attempt_direct_path) + + +@pytest.fixture(scope="session") +def event_loop(): + """Redefine pytest-asyncio's event_loop fixture to be session-scoped.""" + loop = asyncio.new_event_loop() + yield loop + loop.close() + + +@pytest.fixture(scope="session") +def grpc_clients(event_loop): + # grpc clients has to be instantiated in the event loop, + # otherwise grpc creates it's own event loop and attaches to the client. + # Which will lead to deadlock because client running in one event loop and + # MRD or Appendable-Writer in another. + # https://github.com/grpc/grpc/blob/61fe9b40a986792ab7d4eb8924027b671faf26ba/src/python/grpcio/grpc/aio/_channel.py#L369 + # https://github.com/grpc/grpc/blob/61fe9b40a986792ab7d4eb8924027b671faf26ba/src/python/grpcio/grpc/_cython/_cygrpc/aio/common.pyx.pxi#L249 + clients = { + True: event_loop.run_until_complete( + create_async_grpc_client(attempt_direct_path=True) + ), + False: event_loop.run_until_complete( + create_async_grpc_client(attempt_direct_path=False) + ), + } + return clients - # Clean up; use json client (i.e. `storage_client` fixture) to delete. 
- blobs_to_delete.append(storage_client.bucket(_ZONAL_BUCKET).blob(object_name)) +# This fixture is for tests that are NOT parametrized by attempt_direct_path +@pytest.fixture +def grpc_client(grpc_clients): + return grpc_clients[False] -@pytest.mark.asyncio + +@pytest.fixture +def grpc_client_direct(grpc_clients): + return grpc_clients[True] + + +def _get_equal_dist(a: int, b: int) -> tuple[int, int]: + step = (b - a) // 3 + return a + step, a + 2 * step + + +@pytest.mark.parametrize( + "object_size", + [ + 256, # less than _chunk size + 10 * 1024 * 1024, # less than _MAX_BUFFER_SIZE_BYTES + 20 * 1024 * 1024, # greater than _MAX_BUFFER_SIZE + ], +) @pytest.mark.parametrize( "attempt_direct_path", [True, False], ) -async def test_basic_wrd(storage_client, blobs_to_delete, attempt_direct_path): +def test_basic_wrd( + storage_client, + blobs_to_delete, + attempt_direct_path, + object_size, + event_loop, + grpc_clients, +): + object_name = f"test_basic_wrd-{str(uuid.uuid4())}" + + async def _run(): + object_data = os.urandom(object_size) + object_checksum = google_crc32c.value(object_data) + grpc_client = grpc_clients[attempt_direct_path] + + writer = AsyncAppendableObjectWriter(grpc_client, _ZONAL_BUCKET, object_name) + await writer.open() + await writer.append(object_data) + object_metadata = await writer.close(finalize_on_close=True) + assert object_metadata.size == object_size + assert int(object_metadata.checksums.crc32c) == object_checksum + + buffer = BytesIO() + async with AsyncMultiRangeDownloader( + grpc_client, _ZONAL_BUCKET, object_name + ) as mrd: + # (0, 0) means read the whole object + await mrd.download_ranges([(0, 0, buffer)]) + assert mrd.persisted_size == object_size + + assert buffer.getvalue() == object_data + + # Clean up; use json client (i.e. `storage_client` fixture) to delete. + blobs_to_delete.append(storage_client.bucket(_ZONAL_BUCKET).blob(object_name)) + del writer + gc.collect() + + event_loop.run_until_complete(_run()) + + +@pytest.mark.parametrize( + "object_size", + [ + 10, # less than _chunk size, + 10 * 1024 * 1024, # less than _MAX_BUFFER_SIZE_BYTES + 20 * 1024 * 1024, # greater than _MAX_BUFFER_SIZE_BYTES + ], +) +def test_basic_wrd_in_slices( + storage_client, blobs_to_delete, object_size, event_loop, grpc_client +): + object_name = f"test_basic_wrd-{str(uuid.uuid4())}" + + async def _run(): + object_data = os.urandom(object_size) + object_checksum = google_crc32c.value(object_data) + + writer = AsyncAppendableObjectWriter(grpc_client, _ZONAL_BUCKET, object_name) + await writer.open() + mark1, mark2 = _get_equal_dist(0, object_size) + await writer.append(object_data[0:mark1]) + await writer.append(object_data[mark1:mark2]) + await writer.append(object_data[mark2:]) + object_metadata = await writer.close(finalize_on_close=True) + assert object_metadata.size == object_size + assert int(object_metadata.checksums.crc32c) == object_checksum + + mrd = AsyncMultiRangeDownloader(grpc_client, _ZONAL_BUCKET, object_name) + buffer = BytesIO() + await mrd.open() + # (0, 0) means read the whole object + await mrd.download_ranges([(0, 0, buffer)]) + await mrd.close() + assert buffer.getvalue() == object_data + assert mrd.persisted_size == object_size + + # Clean up; use json client (i.e. `storage_client` fixture) to delete. 
+ blobs_to_delete.append(storage_client.bucket(_ZONAL_BUCKET).blob(object_name)) + del writer + del mrd + gc.collect() + + event_loop.run_until_complete(_run()) + + +@pytest.mark.parametrize( + "flush_interval", + [ + 2 * 1024 * 1024, + 4 * 1024 * 1024, + 8 * 1024 * 1024, + _DEFAULT_FLUSH_INTERVAL_BYTES, + ], +) +def test_wrd_with_non_default_flush_interval( + storage_client, + blobs_to_delete, + flush_interval, + event_loop, + grpc_client, +): object_name = f"test_basic_wrd-{str(uuid.uuid4())}" + object_size = 9 * 1024 * 1024 + + async def _run(): + object_data = os.urandom(object_size) + object_checksum = google_crc32c.value(object_data) + + writer = AsyncAppendableObjectWriter( + grpc_client, + _ZONAL_BUCKET, + object_name, + writer_options={"FLUSH_INTERVAL_BYTES": flush_interval}, + ) + await writer.open() + mark1, mark2 = _get_equal_dist(0, object_size) + await writer.append(object_data[0:mark1]) + await writer.append(object_data[mark1:mark2]) + await writer.append(object_data[mark2:]) + object_metadata = await writer.close(finalize_on_close=True) + assert object_metadata.size == object_size + assert int(object_metadata.checksums.crc32c) == object_checksum + + mrd = AsyncMultiRangeDownloader(grpc_client, _ZONAL_BUCKET, object_name) + buffer = BytesIO() + await mrd.open() + # (0, 0) means read the whole object + await mrd.download_ranges([(0, 0, buffer)]) + await mrd.close() + assert buffer.getvalue() == object_data + assert mrd.persisted_size == object_size + + # Clean up; use json client (i.e. `storage_client` fixture) to delete. + blobs_to_delete.append(storage_client.bucket(_ZONAL_BUCKET).blob(object_name)) + del writer + del mrd + gc.collect() + + event_loop.run_until_complete(_run()) + + +def test_read_unfinalized_appendable_object( + storage_client, blobs_to_delete, event_loop, grpc_client_direct +): + object_name = f"read_unfinalized_appendable_object-{str(uuid.uuid4())[:4]}" + + async def _run(): + grpc_client = grpc_client_direct + writer = AsyncAppendableObjectWriter(grpc_client, _ZONAL_BUCKET, object_name) + await writer.open() + await writer.append(_BYTES_TO_UPLOAD) + await writer.flush() + + mrd = AsyncMultiRangeDownloader(grpc_client, _ZONAL_BUCKET, object_name) + buffer = BytesIO() + await mrd.open() + assert mrd.persisted_size == len(_BYTES_TO_UPLOAD) + # (0, 0) means read the whole object + await mrd.download_ranges([(0, 0, buffer)]) + await mrd.close() + assert buffer.getvalue() == _BYTES_TO_UPLOAD + + # Clean up; use json client (i.e. `storage_client` fixture) to delete. 
+ blobs_to_delete.append(storage_client.bucket(_ZONAL_BUCKET).blob(object_name)) + del writer + del mrd + gc.collect() + + event_loop.run_until_complete(_run()) + + +@pytest.mark.skip(reason="Flaky test b/478129078") +def test_mrd_open_with_read_handle(event_loop, grpc_client_direct): + object_name = f"test_read_handl-{str(uuid.uuid4())[:4]}" + + async def _run(): + writer = AsyncAppendableObjectWriter( + grpc_client_direct, _ZONAL_BUCKET, object_name + ) + await writer.open() + await writer.append(_BYTES_TO_UPLOAD) + await writer.close() + + mrd = AsyncMultiRangeDownloader(grpc_client_direct, _ZONAL_BUCKET, object_name) + await mrd.open() + read_handle = mrd.read_handle + await mrd.close() + + # Open a new MRD using the `read_handle` obtained above + new_mrd = AsyncMultiRangeDownloader( + grpc_client_direct, _ZONAL_BUCKET, object_name, read_handle=read_handle + ) + await new_mrd.open() + # persisted_size not set when opened with read_handle + assert new_mrd.persisted_size is None + buffer = BytesIO() + await new_mrd.download_ranges([(0, 0, buffer)]) + await new_mrd.close() + assert buffer.getvalue() == _BYTES_TO_UPLOAD + del mrd + del new_mrd + gc.collect() + + event_loop.run_until_complete(_run()) + + +def test_mrd_open_with_read_handle_over_cloud_path(event_loop, grpc_client): + object_name = f"test_read_handl-{str(uuid.uuid4())[:4]}" + + async def _run(): + writer = AsyncAppendableObjectWriter(grpc_client, _ZONAL_BUCKET, object_name) + await writer.open() + await writer.append(_BYTES_TO_UPLOAD) + await writer.close() + + mrd = AsyncMultiRangeDownloader(grpc_client, _ZONAL_BUCKET, object_name) + await mrd.open() + read_handle = mrd.read_handle + await mrd.close() + + # Open a new MRD using the `read_handle` obtained above + new_mrd = AsyncMultiRangeDownloader( + grpc_client, _ZONAL_BUCKET, object_name, read_handle=read_handle + ) + await new_mrd.open() + # persisted_size is set regardless of whether we use read_handle or not + # because read_handle won't work in CLOUD_PATH. + assert new_mrd.persisted_size == len(_BYTES_TO_UPLOAD) + buffer = BytesIO() + await new_mrd.download_ranges([(0, 0, buffer)]) + await new_mrd.close() + assert buffer.getvalue() == _BYTES_TO_UPLOAD + del mrd + del new_mrd + gc.collect() - # Client instantiation; it cannot be part of fixture because. - # grpc_client's event loop and event loop of coroutine running it - # (i.e. this test) must be same. - # Note: - # 1. @pytest.mark.asyncio ensures new event loop for each test. - # 2. we can keep the same event loop for entire module but that may - # create issues if tests are run in parallel and one test hogs the event - # loop slowing down other tests. - grpc_client = AsyncGrpcClient(attempt_direct_path=attempt_direct_path).grpc_client - - writer = AsyncAppendableObjectWriter(grpc_client, _ZONAL_BUCKET, object_name) - await writer.open() - await writer.append(_BYTES_TO_UPLOAD) - object_metadata = await writer.close(finalize_on_close=True) - assert object_metadata.size == len(_BYTES_TO_UPLOAD) - - mrd = AsyncMultiRangeDownloader(grpc_client, _ZONAL_BUCKET, object_name) - buffer = BytesIO() - await mrd.open() - # (0, 0) means read the whole object - await mrd.download_ranges([(0, 0, buffer)]) - await mrd.close() - assert buffer.getvalue() == _BYTES_TO_UPLOAD - assert mrd.persisted_size == len(_BYTES_TO_UPLOAD) - - # Clean up; use json client (i.e. `storage_client` fixture) to delete. 
- blobs_to_delete.append(storage_client.bucket(_ZONAL_BUCKET).blob(object_name)) - - -@pytest.mark.asyncio -async def test_read_unfinalized_appendable_object(storage_client, blobs_to_delete): + event_loop.run_until_complete(_run()) + + +def test_wrd_open_with_write_handle( + event_loop, grpc_client_direct, storage_client, blobs_to_delete +): + object_name = f"test_write_handl-{str(uuid.uuid4())[:4]}" + + async def _run(): + # 1. Create an object and get its write_handle + writer = AsyncAppendableObjectWriter( + grpc_client_direct, _ZONAL_BUCKET, object_name + ) + await writer.open() + write_handle = writer.write_handle + await writer.close() + + # 2. Open a new writer using the obtained `write_handle` and generation + new_writer = AsyncAppendableObjectWriter( + grpc_client_direct, + _ZONAL_BUCKET, + object_name, + write_handle=write_handle, + generation=writer.generation, + ) + await new_writer.open() + # Verify that the new writer is open and has the same write_handle + assert new_writer.is_stream_open + assert new_writer.generation == writer.generation + + # 3. Append some data using the new writer + test_data = b"data_from_new_writer" + await new_writer.append(test_data) + await new_writer.close() + + # 4. Verify the data was written correctly by reading it back + mrd = AsyncMultiRangeDownloader(grpc_client_direct, _ZONAL_BUCKET, object_name) + buffer = BytesIO() + await mrd.open() + await mrd.download_ranges([(0, 0, buffer)]) + await mrd.close() + assert buffer.getvalue() == test_data + + # Clean up + blobs_to_delete.append(storage_client.bucket(_ZONAL_BUCKET).blob(object_name)) + del writer + del new_writer + del mrd + gc.collect() + + event_loop.run_until_complete(_run()) + + +def test_read_unfinalized_appendable_object_with_generation( + storage_client, blobs_to_delete, event_loop, grpc_client_direct +): object_name = f"read_unfinalized_appendable_object-{str(uuid.uuid4())[:4]}" - grpc_client = AsyncGrpcClient(attempt_direct_path=True).grpc_client - - writer = AsyncAppendableObjectWriter(grpc_client, _ZONAL_BUCKET, object_name) - await writer.open() - await writer.append(_BYTES_TO_UPLOAD) - await writer.flush() - - mrd = AsyncMultiRangeDownloader(grpc_client, _ZONAL_BUCKET, object_name) - buffer = BytesIO() - await mrd.open() - assert mrd.persisted_size == len(_BYTES_TO_UPLOAD) - # (0, 0) means read the whole object - await mrd.download_ranges([(0, 0, buffer)]) - await mrd.close() - assert buffer.getvalue() == _BYTES_TO_UPLOAD - - # Clean up; use json client (i.e. `storage_client` fixture) to delete. 
- blobs_to_delete.append(storage_client.bucket(_ZONAL_BUCKET).blob(object_name)) - - -@pytest.mark.asyncio -async def test_mrd_open_with_read_handle(appendable_object): - grpc_client = AsyncGrpcClient(attempt_direct_path=True).grpc_client - - mrd = AsyncMultiRangeDownloader(grpc_client, _ZONAL_BUCKET, appendable_object) - await mrd.open() - read_handle = mrd.read_handle - await mrd.close() - - # Open a new MRD using the `read_handle` obtained above - new_mrd = AsyncMultiRangeDownloader( - grpc_client, _ZONAL_BUCKET, appendable_object, read_handle=read_handle - ) - await new_mrd.open() - # persisted_size not set when opened with read_handle - assert new_mrd.persisted_size is None - buffer = BytesIO() - await new_mrd.download_ranges([(0, 0, buffer)]) - await new_mrd.close() - assert buffer.getvalue() == _BYTES_TO_UPLOAD + grpc_client = grpc_client_direct + + async def _run(): + async def _read_and_verify(expected_content, generation=None): + """Helper to read object content and verify against expected.""" + mrd = AsyncMultiRangeDownloader( + grpc_client, _ZONAL_BUCKET, object_name, generation + ) + buffer = BytesIO() + await mrd.open() + try: + assert mrd.persisted_size == len(expected_content) + await mrd.download_ranges([(0, 0, buffer)]) + assert buffer.getvalue() == expected_content + finally: + await mrd.close() + return mrd + + # First write + writer = AsyncAppendableObjectWriter(grpc_client, _ZONAL_BUCKET, object_name) + await writer.open() + await writer.append(_BYTES_TO_UPLOAD) + await writer.flush() + generation = writer.generation + + # First read + mrd = await _read_and_verify(_BYTES_TO_UPLOAD) + + # Second write, using generation from the first write. + writer_2 = AsyncAppendableObjectWriter( + grpc_client, _ZONAL_BUCKET, object_name, generation=generation + ) + await writer_2.open() + await writer_2.append(_BYTES_TO_UPLOAD) + await writer_2.flush() + + # Second read + mrd_2 = await _read_and_verify(_BYTES_TO_UPLOAD + _BYTES_TO_UPLOAD, generation) + + # Clean up + blobs_to_delete.append(storage_client.bucket(_ZONAL_BUCKET).blob(object_name)) + del writer + del writer_2 + del mrd + del mrd_2 + gc.collect() + + event_loop.run_until_complete(_run()) + + +def test_append_flushes_and_state_lookup( + storage_client, blobs_to_delete, event_loop, grpc_client +): + """ + System test for AsyncAppendableObjectWriter, verifying flushing behavior + for both small and large appends. + """ + object_name = f"test-append-flush-varied-size-{uuid.uuid4()}" + + async def _run(): + writer = AsyncAppendableObjectWriter(grpc_client, _ZONAL_BUCKET, object_name) + + # Schedule for cleanup + blobs_to_delete.append(storage_client.bucket(_ZONAL_BUCKET).blob(object_name)) + + # --- Part 1: Test with small data --- + small_data = b"small data" + + await writer.open() + assert writer._is_stream_open + + await writer.append(small_data) + persisted_size = await writer.state_lookup() + assert persisted_size == len(small_data) + + # --- Part 2: Test with large data --- + large_data = os.urandom(38 * 1024 * 1024) + + # Append data larger than the default flush interval (16 MiB). + # This should trigger the interval-based flushing logic. + await writer.append(large_data) + + # Verify the total data has been persisted. 
+ total_size = len(small_data) + len(large_data) + persisted_size = await writer.state_lookup() + assert persisted_size == total_size + + # --- Part 3: Finalize and verify --- + final_object = await writer.close(finalize_on_close=True) + + assert not writer._is_stream_open + assert final_object.size == total_size + + # Verify the full content of the object. + full_data = small_data + large_data + mrd = AsyncMultiRangeDownloader(grpc_client, _ZONAL_BUCKET, object_name) + buffer = BytesIO() + await mrd.open() + # (0, 0) means read the whole object + await mrd.download_ranges([(0, 0, buffer)]) + await mrd.close() + content = buffer.getvalue() + assert content == full_data + + event_loop.run_until_complete(_run()) + + +def test_open_with_generation_zero( + storage_client, blobs_to_delete, event_loop, grpc_client +): + """Tests that using `generation=0` fails if the object already exists. + + This test verifies that: + 1. An object can be created using `AsyncAppendableObjectWriter` with `generation=0`. + 2. Attempting to create the same object again with `generation=0` raises a + `FailedPrecondition` error with a 400 status code, because the + precondition (object must not exist) is not met. + """ + object_name = f"test_append_with_generation-{uuid.uuid4()}" + + async def _run(): + writer = AsyncAppendableObjectWriter( + grpc_client, _ZONAL_BUCKET, object_name, generation=0 + ) + + # Empty object is created. + await writer.open() + assert writer.is_stream_open + + await writer.close() + assert not writer.is_stream_open + + with pytest.raises(FailedPrecondition) as exc_info: + writer_fail = AsyncAppendableObjectWriter( + grpc_client, _ZONAL_BUCKET, object_name, generation=0 + ) + await writer_fail.open() + assert exc_info.value.code == 400 + + # cleanup + del writer + gc.collect() + + blobs_to_delete.append(storage_client.bucket(_ZONAL_BUCKET).blob(object_name)) + + event_loop.run_until_complete(_run()) + + +def test_open_existing_object_with_gen_None_overrides_existing( + storage_client, blobs_to_delete, event_loop, grpc_client +): + """ + Test that a new writer when specifies `None` overrides the existing object. + """ + object_name = f"test_append_with_generation-{uuid.uuid4()}" + + async def _run(): + writer = AsyncAppendableObjectWriter( + grpc_client, _ZONAL_BUCKET, object_name, generation=0 + ) + + # Empty object is created. + await writer.open() + assert writer.is_stream_open + old_gen = writer.generation + + await writer.close() + assert not writer.is_stream_open + + new_writer = AsyncAppendableObjectWriter( + grpc_client, _ZONAL_BUCKET, object_name, generation=None + ) + await new_writer.open() + assert new_writer.generation != old_gen + + # cleanup + del writer + del new_writer + gc.collect() + + blobs_to_delete.append(storage_client.bucket(_ZONAL_BUCKET).blob(object_name)) + + event_loop.run_until_complete(_run()) + + +def test_delete_object_using_grpc_client(event_loop, grpc_client_direct): + """ + Test that a new writer when specifies `None` overrides the existing object. + """ + object_name = f"test_append_with_generation-{uuid.uuid4()}" + + async def _run(): + writer = AsyncAppendableObjectWriter( + grpc_client_direct, _ZONAL_BUCKET, object_name, generation=0 + ) + + # Empty object is created. + await writer.open() + await writer.append(b"some_bytes") + await writer.close() + + await grpc_client_direct.delete_object(_ZONAL_BUCKET, object_name) + + # trying to get raises raises 404. 
+ with pytest.raises(NotFound): + # TODO: Remove this once GET_OBJECT is exposed in `AsyncGrpcClient` + await grpc_client_direct._grpc_client.get_object( + bucket=f"projects/_/buckets/{_ZONAL_BUCKET}", object_=object_name + ) + # cleanup + del writer + gc.collect() + + event_loop.run_until_complete(_run()) diff --git a/tests/unit/asyncio/retry/test_bidi_stream_retry_manager.py b/tests/unit/asyncio/retry/test_bidi_stream_retry_manager.py new file mode 100644 index 000000000..e0eba9030 --- /dev/null +++ b/tests/unit/asyncio/retry/test_bidi_stream_retry_manager.py @@ -0,0 +1,156 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from unittest import mock + +import pytest +from google.api_core import exceptions +from google.api_core.retry_async import AsyncRetry + +from google.cloud.storage.asyncio.retry import ( + bidi_stream_retry_manager as manager, +) +from google.cloud.storage.asyncio.retry import base_strategy + + +def _is_retriable(exc): + return isinstance(exc, exceptions.ServiceUnavailable) + + +DEFAULT_TEST_RETRY = AsyncRetry(predicate=_is_retriable, deadline=1) + + +class TestBidiStreamRetryManager: + @pytest.mark.asyncio + async def test_execute_success_on_first_try(self): + mock_strategy = mock.AsyncMock(spec=base_strategy._BaseResumptionStrategy) + + async def mock_send_and_recv(*args, **kwargs): + yield "response_1" + + retry_manager = manager._BidiStreamRetryManager( + strategy=mock_strategy, send_and_recv=mock_send_and_recv + ) + await retry_manager.execute(initial_state={}, retry_policy=DEFAULT_TEST_RETRY) + mock_strategy.generate_requests.assert_called_once() + mock_strategy.update_state_from_response.assert_called_once_with( + "response_1", {} + ) + mock_strategy.recover_state_on_failure.assert_not_called() + + @pytest.mark.asyncio + async def test_execute_success_on_empty_stream(self): + mock_strategy = mock.AsyncMock(spec=base_strategy._BaseResumptionStrategy) + + async def mock_send_and_recv(*args, **kwargs): + if False: + yield + + retry_manager = manager._BidiStreamRetryManager( + strategy=mock_strategy, send_and_recv=mock_send_and_recv + ) + await retry_manager.execute(initial_state={}, retry_policy=DEFAULT_TEST_RETRY) + + mock_strategy.generate_requests.assert_called_once() + mock_strategy.update_state_from_response.assert_not_called() + mock_strategy.recover_state_on_failure.assert_not_called() + + @pytest.mark.asyncio + async def test_execute_retries_on_initial_failure_and_succeeds(self): + mock_strategy = mock.AsyncMock(spec=base_strategy._BaseResumptionStrategy) + attempt_count = 0 + + async def mock_send_and_recv(*args, **kwargs): + nonlocal attempt_count + attempt_count += 1 + if attempt_count == 1: + raise exceptions.ServiceUnavailable("Service is down") + else: + yield "response_2" + + retry_manager = manager._BidiStreamRetryManager( + strategy=mock_strategy, send_and_recv=mock_send_and_recv + ) + retry_policy = AsyncRetry(predicate=_is_retriable, initial=0.01) + + with mock.patch("asyncio.sleep", new_callable=mock.AsyncMock): + 
await retry_manager.execute(initial_state={}, retry_policy=retry_policy) + + assert attempt_count == 2 + assert mock_strategy.generate_requests.call_count == 2 + mock_strategy.recover_state_on_failure.assert_called_once() + mock_strategy.update_state_from_response.assert_called_once_with( + "response_2", {} + ) + + @pytest.mark.asyncio + async def test_execute_retries_and_succeeds_mid_stream(self): + """Test retry logic for a stream that fails after yielding some data.""" + mock_strategy = mock.AsyncMock(spec=base_strategy._BaseResumptionStrategy) + attempt_count = 0 + # Use a list to simulate stream content for each attempt + stream_content = [ + ["response_1", exceptions.ServiceUnavailable("Service is down")], + ["response_2"], + ] + + async def mock_send_and_recv(*args, **kwargs): + nonlocal attempt_count + content = stream_content[attempt_count] + attempt_count += 1 + for item in content: + if isinstance(item, Exception): + raise item + else: + yield item + + retry_manager = manager._BidiStreamRetryManager( + strategy=mock_strategy, send_and_recv=mock_send_and_recv + ) + retry_policy = AsyncRetry(predicate=_is_retriable, initial=0.01) + + with mock.patch("asyncio.sleep", new_callable=mock.AsyncMock) as mock_sleep: + await retry_manager.execute(initial_state={}, retry_policy=retry_policy) + + assert attempt_count == 2 + mock_sleep.assert_called_once() + + assert mock_strategy.generate_requests.call_count == 2 + mock_strategy.recover_state_on_failure.assert_called_once() + assert mock_strategy.update_state_from_response.call_count == 2 + mock_strategy.update_state_from_response.assert_has_calls( + [ + mock.call("response_1", {}), + mock.call("response_2", {}), + ] + ) + + @pytest.mark.asyncio + async def test_execute_fails_immediately_on_non_retriable_error(self): + mock_strategy = mock.AsyncMock(spec=base_strategy._BaseResumptionStrategy) + + async def mock_send_and_recv(*args, **kwargs): + if False: + yield + raise exceptions.PermissionDenied("Auth error") + + retry_manager = manager._BidiStreamRetryManager( + strategy=mock_strategy, send_and_recv=mock_send_and_recv + ) + with pytest.raises(exceptions.PermissionDenied): + await retry_manager.execute( + initial_state={}, retry_policy=DEFAULT_TEST_RETRY + ) + + mock_strategy.recover_state_on_failure.assert_not_called() diff --git a/tests/unit/asyncio/retry/test_reads_resumption_strategy.py b/tests/unit/asyncio/retry/test_reads_resumption_strategy.py index e6b343f86..1055127eb 100644 --- a/tests/unit/asyncio/retry/test_reads_resumption_strategy.py +++ b/tests/unit/asyncio/retry/test_reads_resumption_strategy.py @@ -12,20 +12,22 @@ # See the License for the specific language governing permissions and # limitations under the License. 
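
The `_BidiStreamRetryManager` tests above pin down an orchestration contract rather than any transport detail: each attempt asks the strategy for a fresh batch of requests, every streamed response is folded back into shared state, `recover_state_on_failure` runs only when a retriable error interrupts the stream, and non-retriable errors surface immediately. The loop below is a simplified illustration of that contract, not the library's implementation; apart from the strategy hooks exercised above, the names are made up for the sketch.

import asyncio


async def run_with_resumption(strategy, send_and_recv, state, is_retriable, max_attempts=5):
    """Illustrative resumption loop mirroring the behaviour asserted above."""
    last_exc = None
    for attempt in range(max_attempts):
        # Each attempt regenerates requests from the current state, so work
        # that was already acknowledged is not re-sent.
        requests = strategy.generate_requests(state)
        try:
            async for response in send_and_recv(requests):
                strategy.update_state_from_response(response, state)
            return  # Stream drained cleanly.
        except Exception as exc:
            if not is_retriable(exc):
                raise  # Non-retriable errors are not absorbed.
            last_exc = exc
            # Let the strategy repair its bookkeeping (routing tokens,
            # rewound buffers) before the next attempt.
            await strategy.recover_state_on_failure(exc, state)
            await asyncio.sleep(min(0.01 * (2 ** attempt), 1.0))
    raise last_exc
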
+import asyncio import io import unittest -import pytest +from google_crc32c import Checksum from google.cloud.storage.exceptions import DataCorruption from google.api_core import exceptions from google.cloud import _storage_v2 as storage_v2 -from google.cloud.storage._experimental.asyncio.retry.reads_resumption_strategy import ( +from google.cloud.storage.asyncio.retry.reads_resumption_strategy import ( _DownloadState, _ReadResumptionStrategy, ) from google.cloud._storage_v2.types.storage import BidiReadObjectRedirectedError _READ_ID = 1 +LOGGER_NAME = "google.cloud.storage.asyncio.retry.reads_resumption_strategy" class TestDownloadState(unittest.TestCase): @@ -45,14 +47,67 @@ def test_initialization(self): class TestReadResumptionStrategy(unittest.TestCase): + def setUp(self): + self.strategy = _ReadResumptionStrategy() + + self.state = {"download_states": {}, "read_handle": None, "routing_token": None} + + def _add_download(self, read_id, offset=0, length=100, buffer=None): + """Helper to inject a download state into the correct nested location.""" + if buffer is None: + buffer = io.BytesIO() + state = _DownloadState( + initial_offset=offset, initial_length=length, user_buffer=buffer + ) + self.state["download_states"][read_id] = state + return state + + def _create_response( + self, + content, + read_id, + offset, + crc=None, + range_end=False, + handle=None, + has_read_range=True, + ): + """Helper to create a response object.""" + checksummed_data = None + if content is not None: + if crc is None: + c = Checksum(content) + crc = int.from_bytes(c.digest(), "big") + checksummed_data = storage_v2.ChecksummedData(content=content, crc32c=crc) + + read_range = None + if has_read_range: + read_range = storage_v2.ReadRange(read_id=read_id, read_offset=offset) + + read_handle_message = None + if handle: + read_handle_message = storage_v2.BidiReadHandle(handle=handle) + self.state["read_handle"] = handle + + return storage_v2.BidiReadObjectResponse( + object_data_ranges=[ + storage_v2.ObjectRangeData( + checksummed_data=checksummed_data, + read_range=read_range, + range_end=range_end, + ) + ], + read_handle=read_handle_message, + ) + + # --- Request Generation Tests --- + def test_generate_requests_single_incomplete(self): """Test generating a request for a single incomplete download.""" - read_state = _DownloadState(0, 100, io.BytesIO()) + read_state = self._add_download(_READ_ID, offset=0, length=100) read_state.bytes_written = 20 - state = {_READ_ID: read_state} - read_strategy = _ReadResumptionStrategy() - requests = read_strategy.generate_requests(state) + requests = self.strategy.generate_requests(self.state) self.assertEqual(len(requests), 1) self.assertEqual(requests[0].read_offset, 20) @@ -62,173 +117,244 @@ def test_generate_requests_single_incomplete(self): def test_generate_requests_multiple_incomplete(self): """Test generating requests for multiple incomplete downloads.""" read_id2 = 2 - read_state1 = _DownloadState(0, 100, io.BytesIO()) - read_state1.bytes_written = 50 - read_state2 = _DownloadState(200, 100, io.BytesIO()) - state = {_READ_ID: read_state1, read_id2: read_state2} + rs1 = self._add_download(_READ_ID, offset=0, length=100) + rs1.bytes_written = 50 + + self._add_download(read_id2, offset=200, length=100) - read_strategy = _ReadResumptionStrategy() - requests = read_strategy.generate_requests(state) + requests = self.strategy.generate_requests(self.state) self.assertEqual(len(requests), 2) - req1 = next(request for request in requests if request.read_id == _READ_ID) - 
req2 = next(request for request in requests if request.read_id == read_id2) + requests.sort(key=lambda r: r.read_id) + req1 = requests[0] + req2 = requests[1] + + self.assertEqual(req1.read_id, _READ_ID) self.assertEqual(req1.read_offset, 50) self.assertEqual(req1.read_length, 50) + + self.assertEqual(req2.read_id, read_id2) self.assertEqual(req2.read_offset, 200) self.assertEqual(req2.read_length, 100) + def test_generate_requests_read_to_end_resumption(self): + """Test resumption for 'read to end' (length=0) requests.""" + read_state = self._add_download(_READ_ID, offset=0, length=0) + read_state.bytes_written = 500 + + requests = self.strategy.generate_requests(self.state) + + self.assertEqual(len(requests), 1) + self.assertEqual(requests[0].read_offset, 500) + self.assertEqual(requests[0].read_length, 0) + def test_generate_requests_with_complete(self): """Test that no request is generated for a completed download.""" - read_state = _DownloadState(0, 100, io.BytesIO()) + read_state = self._add_download(_READ_ID) read_state.is_complete = True - state = {_READ_ID: read_state} - - read_strategy = _ReadResumptionStrategy() - requests = read_strategy.generate_requests(state) + requests = self.strategy.generate_requests(self.state) self.assertEqual(len(requests), 0) + def test_generate_requests_multiple_mixed_states(self): + """Test generating requests with mixed complete, partial, and fresh states.""" + s1 = self._add_download(1, length=100) + s1.is_complete = True + + s2 = self._add_download(2, offset=0, length=100) + s2.bytes_written = 50 + + s3 = self._add_download(3, offset=200, length=100) + s3.bytes_written = 0 + + requests = self.strategy.generate_requests(self.state) + + self.assertEqual(len(requests), 2) + requests.sort(key=lambda r: r.read_id) + + self.assertEqual(requests[0].read_id, 2) + self.assertEqual(requests[1].read_id, 3) + def test_generate_requests_empty_state(self): """Test generating requests with an empty state.""" - read_strategy = _ReadResumptionStrategy() - requests = read_strategy.generate_requests({}) + requests = self.strategy.generate_requests(self.state) self.assertEqual(len(requests), 0) + # --- Update State and response processing Tests --- + def test_update_state_processes_single_chunk_successfully(self): """Test updating state from a successful response.""" - buffer = io.BytesIO() - read_state = _DownloadState(0, 100, buffer) - state = {_READ_ID: read_state} + read_state = self._add_download(_READ_ID, offset=0, length=100) data = b"test_data" - read_strategy = _ReadResumptionStrategy() - response = storage_v2.BidiReadObjectResponse( - object_data_ranges=[ - storage_v2.types.ObjectRangeData( - read_range=storage_v2.ReadRange( - read_id=_READ_ID, read_offset=0, read_length=len(data) - ), - checksummed_data=storage_v2.ChecksummedData(content=data), - ) - ] - ) + response = self._create_response(data, _READ_ID, offset=0) - read_strategy.update_state_from_response(response, state) + self.strategy.update_state_from_response(response, self.state) self.assertEqual(read_state.bytes_written, len(data)) self.assertEqual(read_state.next_expected_offset, len(data)) self.assertFalse(read_state.is_complete) - self.assertEqual(buffer.getvalue(), data) + self.assertEqual(read_state.user_buffer.getvalue(), data) + + def test_update_state_accumulates_chunks(self): + """Verify that state updates correctly over multiple chunks.""" + read_state = self._add_download(_READ_ID, offset=0, length=8) + + resp1 = self._create_response(b"test", _READ_ID, offset=0) + 
self.strategy.update_state_from_response(resp1, self.state) + + self.assertEqual(read_state.bytes_written, 4) + self.assertEqual(read_state.user_buffer.getvalue(), b"test") + + resp2 = self._create_response(b"data", _READ_ID, offset=4, range_end=True) + self.strategy.update_state_from_response(resp2, self.state) + + self.assertEqual(read_state.bytes_written, 8) + self.assertTrue(read_state.is_complete) + self.assertEqual(read_state.user_buffer.getvalue(), b"testdata") + + def test_update_state_captures_read_handle(self): + """Verify read_handle is extracted from the response.""" + self._add_download(_READ_ID) + + new_handle = b"optimized_handle" + response = self._create_response(b"data", _READ_ID, 0, handle=new_handle) - def test_update_state_from_response_offset_mismatch(self): + self.strategy.update_state_from_response(response, self.state) + self.assertEqual(self.state["read_handle"].handle, new_handle) + + def test_update_state_unknown_id(self): + """Verify we ignore data for IDs not in our tracking state.""" + self._add_download(_READ_ID) + response = self._create_response(b"ghost", read_id=999, offset=0) + + self.strategy.update_state_from_response(response, self.state) + self.assertEqual(self.state["download_states"][_READ_ID].bytes_written, 0) + + def test_update_state_missing_read_range(self): + """Verify we ignore ranges without read_range metadata.""" + response = self._create_response(b"data", _READ_ID, 0, has_read_range=False) + self.strategy.update_state_from_response(response, self.state) + + def test_update_state_offset_mismatch(self): """Test that an offset mismatch raises DataCorruption.""" - read_state = _DownloadState(0, 100, io.BytesIO()) + read_state = self._add_download(_READ_ID, offset=0) read_state.next_expected_offset = 10 - state = {_READ_ID: read_state} - read_strategy = _ReadResumptionStrategy() - response = storage_v2.BidiReadObjectResponse( - object_data_ranges=[ - storage_v2.types.ObjectRangeData( - read_range=storage_v2.ReadRange( - read_id=_READ_ID, read_offset=0, read_length=4 - ), - checksummed_data=storage_v2.ChecksummedData(content=b"data"), - ) - ] - ) + response = self._create_response(b"data", _READ_ID, offset=0) - with pytest.raises(DataCorruption) as exc_info: - read_strategy.update_state_from_response(response, state) - assert "Offset mismatch" in str(exc_info.value) + with self.assertRaisesRegex(DataCorruption, "Offset mismatch"): + self.strategy.update_state_from_response(response, self.state) - def test_update_state_from_response_final_byte_count_mismatch(self): - """Test that a final byte count mismatch raises DataCorruption.""" - read_state = _DownloadState(0, 100, io.BytesIO()) - state = {_READ_ID: read_state} - read_strategy = _ReadResumptionStrategy() + def test_update_state_checksum_mismatch(self): + """Test that a CRC32C mismatch raises DataCorruption.""" + self._add_download(_READ_ID) + response = self._create_response(b"data", _READ_ID, offset=0, crc=999999) - response = storage_v2.BidiReadObjectResponse( - object_data_ranges=[ - storage_v2.types.ObjectRangeData( - read_range=storage_v2.ReadRange( - read_id=_READ_ID, read_offset=0, read_length=4 - ), - checksummed_data=storage_v2.ChecksummedData(content=b"data"), - range_end=True, - ) - ] - ) + with self.assertRaisesRegex(DataCorruption, "Checksum mismatch"): + self.strategy.update_state_from_response(response, self.state) + + def test_update_state_final_byte_count_mismatch(self): + """Test mismatch between expected length and actual bytes written on completion.""" + 
self._add_download(_READ_ID, length=100) + + data = b"data" * 30 + response = self._create_response(data, _READ_ID, offset=0, range_end=True) - with pytest.raises(DataCorruption) as exc_info: - read_strategy.update_state_from_response(response, state) - assert "Byte count mismatch" in str(exc_info.value) + with self.assertRaisesRegex(DataCorruption, "Byte count mismatch"): + self.strategy.update_state_from_response(response, self.state) - def test_update_state_from_response_completes_download(self): + def test_update_state_completes_download(self): """Test that the download is marked complete on range_end.""" - buffer = io.BytesIO() data = b"test_data" - read_state = _DownloadState(0, len(data), buffer) - state = {_READ_ID: read_state} - read_strategy = _ReadResumptionStrategy() + read_state = self._add_download(_READ_ID, length=len(data)) - response = storage_v2.BidiReadObjectResponse( - object_data_ranges=[ - storage_v2.types.ObjectRangeData( - read_range=storage_v2.ReadRange( - read_id=_READ_ID, read_offset=0, read_length=len(data) - ), - checksummed_data=storage_v2.ChecksummedData(content=data), - range_end=True, - ) - ] - ) + response = self._create_response(data, _READ_ID, offset=0, range_end=True) - read_strategy.update_state_from_response(response, state) + self.strategy.update_state_from_response(response, self.state) self.assertTrue(read_state.is_complete) self.assertEqual(read_state.bytes_written, len(data)) - self.assertEqual(buffer.getvalue(), data) - def test_update_state_from_response_completes_download_zero_length(self): + def test_update_state_completes_download_zero_length(self): """Test completion for a download with initial_length of 0.""" - buffer = io.BytesIO() + read_state = self._add_download(_READ_ID, length=0) data = b"test_data" - read_state = _DownloadState(0, 0, buffer) - state = {_READ_ID: read_state} - read_strategy = _ReadResumptionStrategy() - response = storage_v2.BidiReadObjectResponse( - object_data_ranges=[ - storage_v2.types.ObjectRangeData( - read_range=storage_v2.ReadRange( - read_id=_READ_ID, read_offset=0, read_length=len(data) - ), - checksummed_data=storage_v2.ChecksummedData(content=data), - range_end=True, - ) - ] - ) + response = self._create_response(data, _READ_ID, offset=0, range_end=True) - read_strategy.update_state_from_response(response, state) + self.strategy.update_state_from_response(response, self.state) self.assertTrue(read_state.is_complete) self.assertEqual(read_state.bytes_written, len(data)) - async def test_recover_state_on_failure_handles_redirect(self): + def test_update_state_zero_byte_file(self): + """Test downloading a completely empty file.""" + read_state = self._add_download(_READ_ID, length=0) + + response = self._create_response(b"", _READ_ID, offset=0, range_end=True) + + self.strategy.update_state_from_response(response, self.state) + + self.assertTrue(read_state.is_complete) + self.assertEqual(read_state.bytes_written, 0) + self.assertEqual(read_state.user_buffer.getvalue(), b"") + + def test_update_state_missing_read_range_logs_warning(self): + """Verify we log a warning and continue when read_range is missing.""" + response = self._create_response(b"data", _READ_ID, 0, has_read_range=False) + + # assertLogs captures logs for the given logger name and minimum level + with self.assertLogs(LOGGER_NAME, level="WARNING") as cm: + self.strategy.update_state_from_response(response, self.state) + + self.assertTrue( + any("missing read_range field" in output for output in cm.output) + ) + + def 
test_update_state_unknown_id_logs_warning(self): + """Verify we log a warning and continue when read_id is unknown.""" + unknown_id = 999 + self._add_download(_READ_ID) + response = self._create_response(b"ghost", read_id=unknown_id, offset=0) + + with self.assertLogs(LOGGER_NAME, level="WARNING") as cm: + self.strategy.update_state_from_response(response, self.state) + + self.assertTrue( + any( + f"unknown or stale read_id {unknown_id}" in output + for output in cm.output + ) + ) + + # --- Recovery Tests --- + + def test_recover_state_on_failure_handles_redirect(self): """Verify recover_state_on_failure correctly extracts routing_token.""" - strategy = _ReadResumptionStrategy() + token = "dummy-routing-token" + redirect_error = BidiReadObjectRedirectedError(routing_token=token) + final_error = exceptions.Aborted("Retry failed", errors=[redirect_error]) + + async def run(): + await self.strategy.recover_state_on_failure(final_error, self.state) + + asyncio.new_event_loop().run_until_complete(run()) + + self.assertEqual(self.state["routing_token"], token) - state = {} - self.assertIsNone(state.get("routing_token")) + def test_recover_state_ignores_standard_errors(self): + """Verify that non-redirect errors do not corrupt the routing token.""" + self.state["routing_token"] = "existing-token" - dummy_token = "dummy-routing-token" - redirect_error = BidiReadObjectRedirectedError(routing_token=dummy_token) + std_error = exceptions.ServiceUnavailable("Maintenance") + final_error = exceptions.RetryError("Retry failed", cause=std_error) - final_error = exceptions.RetryError("Retry failed", cause=redirect_error) + async def run(): + await self.strategy.recover_state_on_failure(final_error, self.state) - await strategy.recover_state_on_failure(final_error, state) + asyncio.new_event_loop().run_until_complete(run()) - self.assertEqual(state.get("routing_token"), dummy_token) + # Token should remain unchanged + self.assertEqual(self.state["routing_token"], "existing-token") diff --git a/tests/unit/asyncio/retry/test_writes_resumption_strategy.py b/tests/unit/asyncio/retry/test_writes_resumption_strategy.py new file mode 100644 index 000000000..ca354e84a --- /dev/null +++ b/tests/unit/asyncio/retry/test_writes_resumption_strategy.py @@ -0,0 +1,373 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
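
The read-resumption tests above all reduce to one piece of arithmetic: a partially satisfied range resumes at `initial_offset + bytes_written`, a bounded range shrinks by the bytes already delivered, and a length of 0 ("read to end") stays 0. A tiny standalone restatement of that bookkeeping, using hypothetical names rather than the strategy's internals:

def resume_read_range(initial_offset: int, initial_length: int, bytes_written: int):
    """Return the (offset, length) to request when resuming a range read."""
    new_offset = initial_offset + bytes_written
    new_length = 0 if initial_length == 0 else initial_length - bytes_written
    return new_offset, new_length


# Matches the expectations asserted in the tests above.
assert resume_read_range(0, 100, 20) == (20, 80)
assert resume_read_range(200, 100, 0) == (200, 100)
assert resume_read_range(0, 0, 500) == (500, 0)  # read-to-end stays length 0
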
+ +import io +import unittest.mock as mock +from datetime import datetime + +import pytest +import google_crc32c +from google.rpc import status_pb2 +from google.api_core import exceptions + +from google.cloud._storage_v2.types import storage as storage_type +from google.cloud.storage.asyncio.retry.writes_resumption_strategy import ( + _WriteState, + _WriteResumptionStrategy, +) +from google.cloud._storage_v2.types.storage import BidiWriteObjectRedirectedError + + +@pytest.fixture +def strategy(): + """Fixture to provide a WriteResumptionStrategy instance.""" + return _WriteResumptionStrategy() + + +class TestWriteResumptionStrategy: + """Test suite for WriteResumptionStrategy.""" + + # ------------------------------------------------------------------------- + # Tests for generate_requests + # ------------------------------------------------------------------------- + + def test_generate_requests_initial_chunking(self, strategy): + """Verify initial data generation starts at offset 0 and chunks correctly.""" + mock_buffer = io.BytesIO(b"abcdefghij") + write_state = _WriteState( + chunk_size=3, user_buffer=mock_buffer, flush_interval=10 + ) + state = {"write_state": write_state} + + requests = strategy.generate_requests(state) + + # Expected: 4 requests (3, 3, 3, 1) + assert len(requests) == 4 + + # Verify Request 1 + assert requests[0].write_offset == 0 + assert requests[0].checksummed_data.content == b"abc" + + # Verify Request 2 + assert requests[1].write_offset == 3 + assert requests[1].checksummed_data.content == b"def" + + # Verify Request 3 + assert requests[2].write_offset == 6 + assert requests[2].checksummed_data.content == b"ghi" + + # Verify Request 4 + assert requests[3].write_offset == 9 + assert requests[3].checksummed_data.content == b"j" + + def test_generate_requests_resumption(self, strategy): + """ + Verify request generation when resuming. + The strategy should generate chunks starting from the current 'bytes_sent'. 
+ """ + mock_buffer = io.BytesIO(b"0123456789") + write_state = _WriteState( + chunk_size=4, user_buffer=mock_buffer, flush_interval=10 + ) + + # Simulate resumption state: 4 bytes already sent/persisted + write_state.persisted_size = 4 + write_state.bytes_sent = 4 + # Buffer must be seeked to 4 before calling generate + mock_buffer.seek(4) + + state = {"write_state": write_state} + + requests = strategy.generate_requests(state) + + # Since 4 bytes are done, we expect remaining 6 bytes: [4 bytes, 2 bytes] + assert len(requests) == 2 + + # Check first generated request starts at offset 4 + assert requests[0].write_offset == 4 + assert requests[0].checksummed_data.content == b"4567" + + # Check second generated request starts at offset 8 + assert requests[1].write_offset == 8 + assert requests[1].checksummed_data.content == b"89" + + def test_generate_requests_empty_file(self, strategy): + """Verify request sequence for an empty file.""" + mock_buffer = io.BytesIO(b"") + write_state = _WriteState( + chunk_size=4, user_buffer=mock_buffer, flush_interval=10 + ) + state = {"write_state": write_state} + + requests = strategy.generate_requests(state) + + assert len(requests) == 0 + + def test_generate_requests_checksum_verification(self, strategy): + """Verify CRC32C is calculated correctly for each chunk.""" + chunk_data = b"test_data" + mock_buffer = io.BytesIO(chunk_data) + write_state = _WriteState( + chunk_size=10, user_buffer=mock_buffer, flush_interval=10 + ) + state = {"write_state": write_state} + + requests = strategy.generate_requests(state) + + expected_crc = google_crc32c.Checksum(chunk_data).digest() + expected_int = int.from_bytes(expected_crc, "big") + assert requests[0].checksummed_data.crc32c == expected_int + + def test_generate_requests_flush_logic_exact_interval(self, strategy): + """Verify the flush bit is set exactly when the interval is reached.""" + mock_buffer = io.BytesIO(b"A" * 12) + # 2 byte chunks, flush every 4 bytes + write_state = _WriteState( + chunk_size=2, user_buffer=mock_buffer, flush_interval=4 + ) + state = {"write_state": write_state} + + requests = strategy.generate_requests(state) + + # Request index 1 (4 bytes total) should have flush=True + assert requests[0].flush is False + assert requests[1].flush is True + + # Request index 2 (8 bytes total) should have flush=True + assert requests[2].flush is False + assert requests[3].flush is True + + # Request index 3 (12 bytes total) should have flush=True + assert requests[4].flush is False + assert requests[5].flush is True + + # Verify counter reset in state + assert write_state.bytes_since_last_flush == 0 + + def test_generate_requests_flush_logic_data_less_than_interval(self, strategy): + """Verify flush is not set if data sent is less than interval.""" + mock_buffer = io.BytesIO(b"A" * 5) + # Flush every 10 bytes + write_state = _WriteState( + chunk_size=2, user_buffer=mock_buffer, flush_interval=10 + ) + state = {"write_state": write_state} + + requests = strategy.generate_requests(state) + + # Total 5 bytes < 10 bytes interval + for req in requests: + assert req.flush is False + + assert write_state.bytes_since_last_flush == 5 + + def test_generate_requests_honors_finalized_state(self, strategy): + """If state is already finalized, no requests should be generated.""" + mock_buffer = io.BytesIO(b"data") + write_state = _WriteState( + chunk_size=4, user_buffer=mock_buffer, flush_interval=10 + ) + write_state.is_finalized = True + state = {"write_state": write_state} + + requests = 
strategy.generate_requests(state) + assert len(requests) == 0 + + @pytest.mark.asyncio + async def test_generate_requests_after_failure_and_recovery(self, strategy): + """ + Verify recovery and resumption flow (Integration of recover + generate). + """ + mock_buffer = io.BytesIO(b"0123456789abcdef") # 16 bytes + write_state = _WriteState( + chunk_size=4, user_buffer=mock_buffer, flush_interval=10 + ) + state = {"write_state": write_state} + + # Simulate initial progress: sent 8 bytes + write_state.bytes_sent = 8 + mock_buffer.seek(8) + + strategy.update_state_from_response( + storage_type.BidiWriteObjectResponse( + persisted_size=4, + write_handle=storage_type.BidiWriteHandle(handle=b"handle-1"), + ), + state, + ) + + # Simulate Failure Triggering Recovery + await strategy.recover_state_on_failure(Exception("network error"), state) + + # Assertions after recovery + # 1. Buffer should rewind to persisted_size (4) + assert mock_buffer.tell() == 4 + # 2. bytes_sent should track persisted_size (4) + assert write_state.bytes_sent == 4 + + requests = strategy.generate_requests(state) + + # Remaining data from offset 4 to 16 (12 bytes total) + # Chunks: [4-8], [8-12], [12-16] + assert len(requests) == 3 + + # Verify resumption offset + assert requests[0].write_offset == 4 + assert requests[0].checksummed_data.content == b"4567" + + # ------------------------------------------------------------------------- + # Tests for update_state_from_response + # ------------------------------------------------------------------------- + + def test_update_state_from_response_all_fields(self, strategy): + """Verify all fields from a BidiWriteObjectResponse update the state.""" + write_state = _WriteState( + chunk_size=4, user_buffer=io.BytesIO(), flush_interval=10 + ) + state = {"write_state": write_state} + + # 1. Update persisted_size + strategy.update_state_from_response( + storage_type.BidiWriteObjectResponse(persisted_size=123), state + ) + assert write_state.persisted_size == 123 + + # 2. Update write_handle + handle = storage_type.BidiWriteHandle(handle=b"new-handle") + strategy.update_state_from_response( + storage_type.BidiWriteObjectResponse(write_handle=handle), state + ) + assert write_state.write_handle == handle + + # 3. 
Update from Resource (finalization) + resource = storage_type.Object(size=1000, finalize_time=datetime.now()) + strategy.update_state_from_response( + storage_type.BidiWriteObjectResponse(resource=resource), state + ) + assert write_state.persisted_size == 1000 + assert write_state.is_finalized + + def test_update_state_from_response_none(self, strategy): + """Verify None response doesn't crash.""" + write_state = _WriteState( + chunk_size=4, user_buffer=io.BytesIO(), flush_interval=10 + ) + state = {"write_state": write_state} + strategy.update_state_from_response(None, state) + assert write_state.persisted_size == 0 + + # ------------------------------------------------------------------------- + # Tests for recover_state_on_failure + # ------------------------------------------------------------------------- + + @pytest.mark.asyncio + async def test_recover_state_on_failure_rewind_logic(self, strategy): + """Verify buffer seek and counter resets on generic failure (Non-redirect).""" + mock_buffer = io.BytesIO(b"0123456789") + write_state = _WriteState( + chunk_size=2, user_buffer=mock_buffer, flush_interval=100 + ) + + # Simulate progress: sent 8 bytes, but server only persisted 4 + write_state.bytes_sent = 8 + write_state.persisted_size = 4 + write_state.bytes_since_last_flush = 2 + mock_buffer.seek(8) + + # Simulate generic 503 error without trailers + await strategy.recover_state_on_failure( + exceptions.ServiceUnavailable("busy"), {"write_state": write_state} + ) + + # Buffer must be seeked back to 4 + assert mock_buffer.tell() == 4 + assert write_state.bytes_sent == 4 + # Flush counter must be reset to avoid incorrect firing after resume + assert write_state.bytes_since_last_flush == 0 + + @pytest.mark.asyncio + async def test_recover_state_on_failure_direct_redirect(self, strategy): + """Verify handling when the error is a BidiWriteObjectRedirectedError.""" + write_state = _WriteState( + chunk_size=4, user_buffer=io.BytesIO(), flush_interval=100 + ) + state = {"write_state": write_state} + + redirect = BidiWriteObjectRedirectedError( + routing_token="tok-1", + write_handle=storage_type.BidiWriteHandle(handle=b"h-1"), + ) + + await strategy.recover_state_on_failure(redirect, state) + + assert write_state.routing_token == "tok-1" + assert write_state.write_handle.handle == b"h-1" + + @pytest.mark.asyncio + async def test_recover_state_on_failure_wrapped_redirect(self, strategy): + """Verify handling when RedirectedError is inside Aborted.errors.""" + write_state = _WriteState( + chunk_size=4, user_buffer=io.BytesIO(), flush_interval=10 + ) + + redirect = BidiWriteObjectRedirectedError(routing_token="tok-wrapped") + # google-api-core Aborted often wraps multiple errors + error = exceptions.Aborted("conflict", errors=[redirect]) + + await strategy.recover_state_on_failure(error, {"write_state": write_state}) + + assert write_state.routing_token == "tok-wrapped" + + @pytest.mark.asyncio + async def test_recover_state_on_failure_trailer_metadata_redirect(self, strategy): + """Verify complex parsing from 'grpc-status-details-bin' in trailers.""" + write_state = _WriteState( + chunk_size=4, user_buffer=io.BytesIO(), flush_interval=10 + ) + + redirect_proto = BidiWriteObjectRedirectedError(routing_token="metadata-token") + status = status_pb2.Status() + detail = status.details.add() + detail.type_url = ( + "type.googleapis.com/google.storage.v2.BidiWriteObjectRedirectedError" + ) + detail.value = BidiWriteObjectRedirectedError.serialize(redirect_proto) + + # FIX: No spec= here, because 
Aborted doesn't have trailing_metadata in its base definition + mock_error = mock.MagicMock() + mock_error.errors = [] + mock_error.trailing_metadata.return_value = [ + ("grpc-status-details-bin", status.SerializeToString()) + ] + + with mock.patch( + "google.cloud.storage.asyncio.retry.writes_resumption_strategy._extract_bidi_writes_redirect_proto", + return_value=redirect_proto, + ): + await strategy.recover_state_on_failure( + mock_error, {"write_state": write_state} + ) + + assert write_state.routing_token == "metadata-token" + + def test_write_state_initialization(self): + """Verify WriteState starts with clean counters.""" + buffer = io.BytesIO(b"test") + ws = _WriteState(chunk_size=10, user_buffer=buffer, flush_interval=100) + + assert ws.persisted_size == 0 + assert ws.bytes_sent == 0 + assert ws.bytes_since_last_flush == 0 + assert ws.flush_interval == 100 + assert not ws.is_finalized diff --git a/tests/unit/asyncio/test_async_appendable_object_writer.py b/tests/unit/asyncio/test_async_appendable_object_writer.py index 089c3d88f..1bbeb5330 100644 --- a/tests/unit/asyncio/test_async_appendable_object_writer.py +++ b/tests/unit/asyncio/test_async_appendable_object_writer.py @@ -12,508 +12,405 @@ # See the License for the specific language governing permissions and # limitations under the License. +import io +import unittest.mock as mock +from unittest.mock import AsyncMock, MagicMock import pytest -from unittest import mock -from google.cloud.storage._experimental.asyncio.async_appendable_object_writer import ( +from google.api_core import exceptions +from google.rpc import status_pb2 +from google.cloud._storage_v2.types import storage as storage_type +from google.cloud._storage_v2.types.storage import BidiWriteObjectRedirectedError +from google.cloud.storage.asyncio.async_appendable_object_writer import ( AsyncAppendableObjectWriter, + _is_write_retryable, + _MAX_CHUNK_SIZE_BYTES, + _DEFAULT_FLUSH_INTERVAL_BYTES, ) -from google.cloud import _storage_v2 - +# Constants BUCKET = "test-bucket" OBJECT = "test-object" GENERATION = 123 WRITE_HANDLE = b"test-write-handle" PERSISTED_SIZE = 456 +EIGHT_MIB = 8 * 1024 * 1024 + + +class TestIsWriteRetryable: + """Exhaustive tests for retry predicate logic.""" + + def test_standard_transient_errors(self, mock_appendable_writer): + for exc in [ + exceptions.InternalServerError("500"), + exceptions.ServiceUnavailable("503"), + exceptions.DeadlineExceeded("timeout"), + exceptions.TooManyRequests("429"), + ]: + assert _is_write_retryable(exc) + + def test_aborted_with_redirect_proto(self, mock_appendable_writer): + # Direct redirect error wrapped in Aborted + redirect = BidiWriteObjectRedirectedError(routing_token="token") + exc = exceptions.Aborted("aborted", errors=[redirect]) + assert _is_write_retryable(exc) + + def test_aborted_with_trailers(self, mock_appendable_writer): + # Setup Status with Redirect Detail + status = status_pb2.Status() + detail = status.details.add() + detail.type_url = ( + "type.googleapis.com/google.storage.v2.BidiWriteObjectRedirectedError" + ) + + # Mock error with trailing_metadata method + mock_grpc_error = MagicMock() + mock_grpc_error.trailing_metadata.return_value = [ + ("grpc-status-details-bin", status.SerializeToString()) + ] + + # Aborted wraps the grpc error + exc = exceptions.Aborted("aborted", errors=[mock_grpc_error]) + assert _is_write_retryable(exc) + + def test_aborted_without_metadata(self, mock_appendable_writer): + mock_grpc_error = MagicMock() + mock_grpc_error.trailing_metadata.return_value = [] + 
exc = exceptions.Aborted("bare aborted", errors=[mock_grpc_error]) + assert not _is_write_retryable(exc) + + def test_non_retryable_errors(self, mock_appendable_writer): + assert not _is_write_retryable(exceptions.BadRequest("400")) + assert not _is_write_retryable(exceptions.NotFound("404")) @pytest.fixture -def mock_client(): - """Mock the async gRPC client.""" - return mock.AsyncMock() - - -@mock.patch( - "google.cloud.storage._experimental.asyncio.async_appendable_object_writer._AsyncWriteObjectStream" -) -def test_init(mock_write_object_stream, mock_client): - """Test the constructor of AsyncAppendableObjectWriter.""" - writer = AsyncAppendableObjectWriter(mock_client, BUCKET, OBJECT) - - assert writer.client == mock_client - assert writer.bucket_name == BUCKET - assert writer.object_name == OBJECT - assert writer.generation is None - assert writer.write_handle is None - assert not writer._is_stream_open - assert writer.offset is None - assert writer.persisted_size is None - - mock_write_object_stream.assert_called_once_with( - client=mock_client, - bucket_name=BUCKET, - object_name=OBJECT, - generation_number=None, - write_handle=None, - ) - assert writer.write_obj_stream == mock_write_object_stream.return_value - - -@mock.patch( - "google.cloud.storage._experimental.asyncio.async_appendable_object_writer._AsyncWriteObjectStream" -) -def test_init_with_optional_args(mock_write_object_stream, mock_client): - """Test the constructor with optional arguments.""" - writer = AsyncAppendableObjectWriter( - mock_client, - BUCKET, - OBJECT, - generation=GENERATION, - write_handle=WRITE_HANDLE, - ) - - assert writer.generation == GENERATION - assert writer.write_handle == WRITE_HANDLE - - mock_write_object_stream.assert_called_once_with( - client=mock_client, - bucket_name=BUCKET, - object_name=OBJECT, - generation_number=GENERATION, - write_handle=WRITE_HANDLE, +def mock_appendable_writer(): + """Fixture to provide a mock AsyncAppendableObjectWriter setup.""" + mock_client = mock.MagicMock() + mock_client.grpc_client = mock.AsyncMock() + # Internal stream class patch + stream_patcher = mock.patch( + "google.cloud.storage.asyncio.async_appendable_object_writer._AsyncWriteObjectStream" ) + mock_stream_cls = stream_patcher.start() + mock_stream = mock_stream_cls.return_value + # Configure all async methods explicitly + mock_stream.open = AsyncMock() + mock_stream.close = AsyncMock() + mock_stream.send = AsyncMock() + mock_stream.recv = AsyncMock() -@pytest.mark.asyncio -@mock.patch( - "google.cloud.storage._experimental.asyncio.async_appendable_object_writer._AsyncWriteObjectStream" -) -async def test_state_lookup(mock_write_object_stream, mock_client): - """Test state_lookup method.""" - # Arrange - writer = AsyncAppendableObjectWriter(mock_client, BUCKET, OBJECT) - writer._is_stream_open = True - mock_stream = mock_write_object_stream.return_value - mock_stream.send = mock.AsyncMock() - mock_stream.recv = mock.AsyncMock( - return_value=_storage_v2.BidiWriteObjectResponse(persisted_size=PERSISTED_SIZE) - ) - - expected_request = _storage_v2.BidiWriteObjectRequest(state_lookup=True) - - # Act - response = await writer.state_lookup() - - # Assert - mock_stream.send.assert_awaited_once_with(expected_request) - mock_stream.recv.assert_awaited_once() - assert writer.persisted_size == PERSISTED_SIZE - assert response == PERSISTED_SIZE - - -@pytest.mark.asyncio -async def test_state_lookup_without_open_raises_value_error(mock_client): - """Test that state_lookup raises an error if the stream is not 
open.""" - writer = AsyncAppendableObjectWriter(mock_client, BUCKET, OBJECT) - with pytest.raises( - ValueError, - match="Stream is not open. Call open\\(\\) before state_lookup\\(\\).", - ): - await writer.state_lookup() - - -@pytest.mark.asyncio -@mock.patch( - "google.cloud.storage._experimental.asyncio.async_appendable_object_writer._AsyncWriteObjectStream" -) -async def test_open_appendable_object_writer(mock_write_object_stream, mock_client): - """Test the open method.""" - # Arrange - writer = AsyncAppendableObjectWriter(mock_client, BUCKET, OBJECT) - mock_stream = mock_write_object_stream.return_value - mock_stream.open = mock.AsyncMock() - - mock_stream.generation_number = GENERATION - mock_stream.write_handle = WRITE_HANDLE + # Default mock properties + mock_stream.is_stream_open = False mock_stream.persisted_size = 0 - - # Act - await writer.open() - - # Assert - mock_stream.open.assert_awaited_once() - assert writer._is_stream_open - assert writer.generation == GENERATION - assert writer.write_handle == WRITE_HANDLE - assert writer.persisted_size == 0 - - -@pytest.mark.asyncio -@mock.patch( - "google.cloud.storage._experimental.asyncio.async_appendable_object_writer._AsyncWriteObjectStream" -) -async def test_open_appendable_object_writer_existing_object( - mock_write_object_stream, mock_client -): - """Test the open method.""" - # Arrange - writer = AsyncAppendableObjectWriter( - mock_client, BUCKET, OBJECT, generation=GENERATION - ) - mock_stream = mock_write_object_stream.return_value - mock_stream.open = mock.AsyncMock() - mock_stream.generation_number = GENERATION mock_stream.write_handle = WRITE_HANDLE mock_stream.persisted_size = PERSISTED_SIZE - # Act - await writer.open() + yield { + "mock_client": mock_client, + "mock_stream_cls": mock_stream_cls, + "mock_stream": mock_stream, + } + + stream_patcher.stop() + + +class TestAsyncAppendableObjectWriter: + def _make_one(self, mock_client, **kwargs): + return AsyncAppendableObjectWriter(mock_client, BUCKET, OBJECT, **kwargs) + + # ------------------------------------------------------------------------- + # Initialization & Configuration Tests + # ------------------------------------------------------------------------- + + def test_init_defaults(self, mock_appendable_writer): + writer = self._make_one(mock_appendable_writer["mock_client"]) + assert writer.bucket_name == BUCKET + assert writer.object_name == OBJECT + assert writer.persisted_size is None + assert writer.bytes_appended_since_last_flush == 0 + assert writer.flush_interval == _DEFAULT_FLUSH_INTERVAL_BYTES + + def test_init_with_writer_options(self, mock_appendable_writer): + writer = self._make_one( + mock_appendable_writer["mock_client"], + writer_options={"FLUSH_INTERVAL_BYTES": EIGHT_MIB}, + ) + assert writer.flush_interval == EIGHT_MIB + + def test_init_validation_chunk_size_raises(self, mock_appendable_writer): + with pytest.raises(exceptions.OutOfRange): + self._make_one( + mock_appendable_writer["mock_client"], + writer_options={"FLUSH_INTERVAL_BYTES": _MAX_CHUNK_SIZE_BYTES - 1}, + ) + + def test_init_validation_multiple_raises(self, mock_appendable_writer): + with pytest.raises(exceptions.OutOfRange): + self._make_one( + mock_appendable_writer["mock_client"], + writer_options={"FLUSH_INTERVAL_BYTES": _MAX_CHUNK_SIZE_BYTES + 1}, + ) + + def test_init_raises_if_crc32c_missing(self, mock_appendable_writer): + with mock.patch( + "google.cloud.storage.asyncio._utils.google_crc32c" + ) as mock_crc: + mock_crc.implementation = "python" + with 
pytest.raises(exceptions.FailedPrecondition): + self._make_one(mock_appendable_writer["mock_client"]) + + # ------------------------------------------------------------------------- + # Stream Lifecycle Tests + # ------------------------------------------------------------------------- + + @pytest.mark.asyncio + async def test_state_lookup(self, mock_appendable_writer): + writer = self._make_one(mock_appendable_writer["mock_client"]) + writer._is_stream_open = True + writer.write_obj_stream = mock_appendable_writer["mock_stream"] + + mock_appendable_writer[ + "mock_stream" + ].recv.return_value = storage_type.BidiWriteObjectResponse(persisted_size=100) + + size = await writer.state_lookup() + + mock_appendable_writer["mock_stream"].send.assert_awaited_once() + assert size == 100 + assert writer.persisted_size == 100 + + @pytest.mark.asyncio + async def test_open_success(self, mock_appendable_writer): + writer = self._make_one(mock_appendable_writer["mock_client"]) + mock_appendable_writer["mock_stream"].generation_number = 456 + mock_appendable_writer["mock_stream"].write_handle = b"new-h" + mock_appendable_writer["mock_stream"].persisted_size = 0 - # Assert - mock_stream.open.assert_awaited_once() - assert writer._is_stream_open - assert writer.generation == GENERATION - assert writer.write_handle == WRITE_HANDLE - assert writer.persisted_size == PERSISTED_SIZE - - -@pytest.mark.asyncio -@mock.patch( - "google.cloud.storage._experimental.asyncio.async_appendable_object_writer._AsyncWriteObjectStream" -) -async def test_open_when_already_open_raises_error( - mock_write_object_stream, mock_client -): - """Test that opening an already open writer raises a ValueError.""" - # Arrange - writer = AsyncAppendableObjectWriter(mock_client, BUCKET, OBJECT) - writer._is_stream_open = True # Manually set to open - - # Act & Assert - with pytest.raises(ValueError, match="Underlying bidi-gRPC stream is already open"): await writer.open() + assert writer._is_stream_open + assert writer.generation == 456 + assert writer.write_handle == b"new-h" + mock_appendable_writer["mock_stream"].open.assert_awaited_once() + + def test_on_open_error_redirection(self, mock_appendable_writer): + """Verify redirect info is extracted from helper.""" + writer = self._make_one(mock_appendable_writer["mock_client"]) + redirect = BidiWriteObjectRedirectedError( + routing_token="rt1", + write_handle=storage_type.BidiWriteHandle(handle=b"h1"), + generation=777, + ) + + with mock.patch( + "google.cloud.storage.asyncio.async_appendable_object_writer._extract_bidi_writes_redirect_proto", + return_value=redirect, + ): + writer._on_open_error(exceptions.Aborted("redirect")) + + assert writer._routing_token == "rt1" + assert writer.write_handle.handle == b"h1" + assert writer.generation == 777 + + # ------------------------------------------------------------------------- + # Append Tests + # ------------------------------------------------------------------------- + + @pytest.mark.asyncio + async def test_append_basic_success(self, mock_appendable_writer): + """Verify append orchestrates manager and drives the internal generator.""" + writer = self._make_one(mock_appendable_writer["mock_client"]) + writer._is_stream_open = True + writer.write_obj_stream = mock_appendable_writer["mock_stream"] + writer.persisted_size = 0 + + data = b"test-data" + + with mock.patch( + "google.cloud.storage.asyncio.async_appendable_object_writer._BidiStreamRetryManager" + ) as MockManager: + + async def mock_execute(state, policy): + factory = 
MockManager.call_args[0][1] + dummy_reqs = [storage_type.BidiWriteObjectRequest()] + gen = factory(dummy_reqs, state) + + mock_appendable_writer["mock_stream"].recv.side_effect = [ + storage_type.BidiWriteObjectResponse( + persisted_size=len(data), + write_handle=storage_type.BidiWriteHandle(handle=b"h2"), + ), + None, + ] + async for _ in gen: + pass + + MockManager.return_value.execute.side_effect = mock_execute + await writer.append(data) + + assert writer.persisted_size == len(data) + sent_req = mock_appendable_writer["mock_stream"].send.call_args[0][0] + assert sent_req.state_lookup + assert sent_req.flush + + @pytest.mark.asyncio + async def test_append_recovery_reopens_stream(self, mock_appendable_writer): + """Verifies re-opening logic on retry.""" + writer = self._make_one( + mock_appendable_writer["mock_client"], write_handle=b"h1" + ) + writer._is_stream_open = True + writer.write_obj_stream = mock_appendable_writer["mock_stream"] + # Setup mock to allow close() call + mock_appendable_writer["mock_stream"].is_stream_open = True + + async def mock_open(metadata=None): + writer.write_obj_stream = mock_appendable_writer["mock_stream"] + writer._is_stream_open = True + writer.persisted_size = 5 + writer.write_handle = b"h_recovered" + + with mock.patch.object( + writer, "open", side_effect=mock_open + ) as mock_writer_open: + with mock.patch( + "google.cloud.storage.asyncio.async_appendable_object_writer._BidiStreamRetryManager" + ) as MockManager: + + async def mock_execute(state, policy): + factory = MockManager.call_args[0][1] + # Simulate Attempt 1 fail + gen1 = factory([], state) + try: + await gen1.__anext__() + except Exception: + pass + # Simulate Attempt 2 + gen2 = factory([], state) + mock_appendable_writer["mock_stream"].recv.return_value = None + async for _ in gen2: + pass + + MockManager.return_value.execute.side_effect = mock_execute + await writer.append(b"0123456789") + + mock_appendable_writer["mock_stream"].close.assert_awaited() + mock_writer_open.assert_awaited() + assert writer.persisted_size == 5 + + @pytest.mark.asyncio + async def test_append_unimplemented_string_raises(self, mock_appendable_writer): + writer = self._make_one(mock_appendable_writer["mock_client"]) + with pytest.raises(NotImplementedError): + await writer.append_from_string("test") + + # ------------------------------------------------------------------------- + # Flush, Close, Finalize + # ------------------------------------------------------------------------- + + @pytest.mark.asyncio + async def test_flush_resets_counters(self, mock_appendable_writer): + writer = self._make_one(mock_appendable_writer["mock_client"]) + writer._is_stream_open = True + writer.write_obj_stream = mock_appendable_writer["mock_stream"] + writer.bytes_appended_since_last_flush = 100 + + mock_appendable_writer[ + "mock_stream" + ].recv.return_value = storage_type.BidiWriteObjectResponse(persisted_size=200) -@pytest.mark.asyncio -async def test_unimplemented_methods_raise_error(mock_client): - """Test that all currently unimplemented methods raise NotImplementedError.""" - writer = AsyncAppendableObjectWriter(mock_client, BUCKET, OBJECT) - - with pytest.raises(NotImplementedError): - await writer.append_from_string("data") - - with pytest.raises(NotImplementedError): - await writer.append_from_stream(mock.Mock()) - - with pytest.raises(NotImplementedError): - await writer.append_from_file("file.txt") - - -@pytest.mark.asyncio -@mock.patch( - 
"google.cloud.storage._experimental.asyncio.async_appendable_object_writer._AsyncWriteObjectStream" -) -async def test_flush(mock_write_object_stream, mock_client): - """Test that flush sends the correct request and updates state.""" - writer = AsyncAppendableObjectWriter(mock_client, BUCKET, OBJECT) - writer._is_stream_open = True - mock_stream = mock_write_object_stream.return_value - mock_stream.send = mock.AsyncMock() - mock_stream.recv = mock.AsyncMock( - return_value=_storage_v2.BidiWriteObjectResponse(persisted_size=1024) - ) - - persisted_size = await writer.flush() - - expected_request = _storage_v2.BidiWriteObjectRequest(flush=True, state_lookup=True) - mock_stream.send.assert_awaited_once_with(expected_request) - mock_stream.recv.assert_awaited_once() - assert writer.persisted_size == 1024 - assert writer.offset == 1024 - assert persisted_size == 1024 - - -@pytest.mark.asyncio -async def test_flush_without_open_raises_value_error(mock_client): - """Test that flush raises an error if the stream is not open.""" - writer = AsyncAppendableObjectWriter(mock_client, BUCKET, OBJECT) - with pytest.raises( - ValueError, match="Stream is not open. Call open\\(\\) before flush\\(\\)." - ): await writer.flush() + assert writer.bytes_appended_since_last_flush == 0 + assert writer.persisted_size == 200 -@pytest.mark.asyncio -@mock.patch( - "google.cloud.storage._experimental.asyncio.async_appendable_object_writer._AsyncWriteObjectStream" -) -async def test_simple_flush(mock_write_object_stream, mock_client): - """Test that flush sends the correct request and updates state.""" - # Arrange - writer = AsyncAppendableObjectWriter(mock_client, BUCKET, OBJECT) - writer._is_stream_open = True - mock_stream = mock_write_object_stream.return_value - mock_stream.send = mock.AsyncMock() - - # Act - await writer.simple_flush() - - # Assert - mock_stream.send.assert_awaited_once_with( - _storage_v2.BidiWriteObjectRequest(flush=True) - ) - + @pytest.mark.asyncio + async def test_simple_flush(self, mock_appendable_writer): + writer = self._make_one(mock_appendable_writer["mock_client"]) + writer._is_stream_open = True + writer.write_obj_stream = mock_appendable_writer["mock_stream"] + writer.bytes_appended_since_last_flush = 50 -@pytest.mark.asyncio -async def test_simple_flush_without_open_raises_value_error(mock_client): - """Test that flush raises an error if the stream is not open.""" - writer = AsyncAppendableObjectWriter(mock_client, BUCKET, OBJECT) - with pytest.raises( - ValueError, - match="Stream is not open. 
Call open\\(\\) before simple_flush\\(\\).", - ): await writer.simple_flush() - -@pytest.mark.asyncio -@mock.patch( - "google.cloud.storage._experimental.asyncio.async_appendable_object_writer._AsyncWriteObjectStream" -) -async def test_close(mock_write_object_stream, mock_client): - writer = AsyncAppendableObjectWriter(mock_client, BUCKET, OBJECT) - writer._is_stream_open = True - writer.offset = 1024 - mock_stream = mock_write_object_stream.return_value - mock_stream.send = mock.AsyncMock() - mock_stream.recv = mock.AsyncMock( - return_value=_storage_v2.BidiWriteObjectResponse(persisted_size=1024) - ) - mock_stream.close = mock.AsyncMock() - writer.finalize = mock.AsyncMock() - - persisted_size = await writer.close() - - writer.finalize.assert_not_awaited() - mock_stream.close.assert_awaited_once() - assert writer.offset is None - assert persisted_size == 1024 - assert not writer._is_stream_open - - -@pytest.mark.asyncio -async def test_close_without_open_raises_value_error(mock_client): - """Test that close raises an error if the stream is not open.""" - writer = AsyncAppendableObjectWriter(mock_client, BUCKET, OBJECT) - with pytest.raises( - ValueError, match="Stream is not open. Call open\\(\\) before close\\(\\)." - ): - await writer.close() - - -@pytest.mark.asyncio -@mock.patch( - "google.cloud.storage._experimental.asyncio.async_appendable_object_writer._AsyncWriteObjectStream" -) -async def test_finalize_on_close(mock_write_object_stream, mock_client): - """Test close with finalizing.""" - # Arrange - mock_resource = _storage_v2.Object(name=OBJECT, bucket=BUCKET, size=2048) - writer = AsyncAppendableObjectWriter(mock_client, BUCKET, OBJECT) - writer._is_stream_open = True - writer.offset = 1024 - mock_stream = mock_write_object_stream.return_value - mock_stream.send = mock.AsyncMock() - mock_stream.recv = mock.AsyncMock( - return_value=_storage_v2.BidiWriteObjectResponse(resource=mock_resource) - ) - mock_stream.close = mock.AsyncMock() - - # Act - result = await writer.close(finalize_on_close=True) - - # Assert - mock_stream.close.assert_awaited_once() - assert not writer._is_stream_open - assert writer.offset is None - assert writer.object_resource == mock_resource - assert writer.persisted_size == 2048 - assert result == mock_resource - - -@pytest.mark.asyncio -@mock.patch( - "google.cloud.storage._experimental.asyncio.async_appendable_object_writer._AsyncWriteObjectStream" -) -async def test_finalize(mock_write_object_stream, mock_client): - """Test that finalize sends the correct request and updates state.""" - writer = AsyncAppendableObjectWriter(mock_client, BUCKET, OBJECT) - writer._is_stream_open = True - mock_resource = _storage_v2.Object(name=OBJECT, bucket=BUCKET, size=123) - mock_stream = mock_write_object_stream.return_value - mock_stream.send = mock.AsyncMock() - mock_stream.recv = mock.AsyncMock( - return_value=_storage_v2.BidiWriteObjectResponse(resource=mock_resource) - ) - - gcs_object = await writer.finalize() - - mock_stream.send.assert_awaited_once_with( - _storage_v2.BidiWriteObjectRequest(finish_write=True) - ) - mock_stream.recv.assert_awaited_once() - assert writer.object_resource == mock_resource - assert writer.persisted_size == 123 - assert gcs_object == mock_resource - - -@pytest.mark.asyncio -async def test_finalize_without_open_raises_value_error(mock_client): - """Test that finalize raises an error if the stream is not open.""" - writer = AsyncAppendableObjectWriter(mock_client, BUCKET, OBJECT) - with pytest.raises( - ValueError, match="Stream is 
not open. Call open\\(\\) before finalize\\(\\)." - ): - await writer.finalize() - - -@pytest.mark.asyncio -@mock.patch( - "google.cloud.storage._experimental.asyncio.async_appendable_object_writer._AsyncWriteObjectStream" -) -async def test_append_raises_error_if_not_open(mock_write_object_stream, mock_client): - """Test that append raises an error if the stream is not open.""" - writer = AsyncAppendableObjectWriter(mock_client, BUCKET, OBJECT) - with pytest.raises( - ValueError, match="Stream is not open. Call open\\(\\) before append\\(\\)." - ): - await writer.append(b"some data") - - -@pytest.mark.asyncio -@mock.patch( - "google.cloud.storage._experimental.asyncio.async_appendable_object_writer._AsyncWriteObjectStream" -) -async def test_append_with_empty_data(mock_write_object_stream, mock_client): - """Test that append does nothing if data is empty.""" - writer = AsyncAppendableObjectWriter(mock_client, BUCKET, OBJECT) - writer._is_stream_open = True - mock_stream = mock_write_object_stream.return_value - mock_stream.send = mock.AsyncMock() - - await writer.append(b"") - - mock_stream.send.assert_not_awaited() - - -@pytest.mark.asyncio -@mock.patch( - "google.cloud.storage._experimental.asyncio.async_appendable_object_writer._AsyncWriteObjectStream" -) -async def test_append_sends_data_in_chunks(mock_write_object_stream, mock_client): - """Test that append sends data in chunks and updates offset.""" - from google.cloud.storage._experimental.asyncio.async_appendable_object_writer import ( - _MAX_CHUNK_SIZE_BYTES, - ) - - writer = AsyncAppendableObjectWriter(mock_client, BUCKET, OBJECT) - writer._is_stream_open = True - writer.persisted_size = 100 - mock_stream = mock_write_object_stream.return_value - mock_stream.send = mock.AsyncMock() - writer.simple_flush = mock.AsyncMock() - - data = b"a" * (_MAX_CHUNK_SIZE_BYTES + 1) - await writer.append(data) - - assert mock_stream.send.await_count == 2 - first_call = mock_stream.send.await_args_list[0] - second_call = mock_stream.send.await_args_list[1] - - # First chunk - assert first_call[0][0].write_offset == 100 - assert len(first_call[0][0].checksummed_data.content) == _MAX_CHUNK_SIZE_BYTES - - # Second chunk - assert second_call[0][0].write_offset == 100 + _MAX_CHUNK_SIZE_BYTES - assert len(second_call[0][0].checksummed_data.content) == 1 - - assert writer.offset == 100 + len(data) - writer.simple_flush.assert_not_awaited() - - -@pytest.mark.asyncio -@mock.patch( - "google.cloud.storage._experimental.asyncio.async_appendable_object_writer._AsyncWriteObjectStream" -) -async def test_append_flushes_when_buffer_is_full( - mock_write_object_stream, mock_client -): - """Test that append flushes the stream when the buffer size is reached.""" - from google.cloud.storage._experimental.asyncio.async_appendable_object_writer import ( - _MAX_BUFFER_SIZE_BYTES, - ) - - writer = AsyncAppendableObjectWriter(mock_client, BUCKET, OBJECT) - writer._is_stream_open = True - writer.persisted_size = 0 - mock_stream = mock_write_object_stream.return_value - mock_stream.send = mock.AsyncMock() - writer.simple_flush = mock.AsyncMock() - - data = b"a" * _MAX_BUFFER_SIZE_BYTES - await writer.append(data) - - writer.simple_flush.assert_awaited_once() - - -@pytest.mark.asyncio -@mock.patch( - "google.cloud.storage._experimental.asyncio.async_appendable_object_writer._AsyncWriteObjectStream" -) -async def test_append_handles_large_data(mock_write_object_stream, mock_client): - """Test that append handles data larger than the buffer size.""" - from 
google.cloud.storage._experimental.asyncio.async_appendable_object_writer import ( - _MAX_BUFFER_SIZE_BYTES, - ) - - writer = AsyncAppendableObjectWriter(mock_client, BUCKET, OBJECT) - writer._is_stream_open = True - writer.persisted_size = 0 - mock_stream = mock_write_object_stream.return_value - mock_stream.send = mock.AsyncMock() - writer.simple_flush = mock.AsyncMock() - - data = b"a" * (_MAX_BUFFER_SIZE_BYTES * 2 + 1) - await writer.append(data) - - assert writer.simple_flush.await_count == 2 - - -@pytest.mark.asyncio -@mock.patch( - "google.cloud.storage._experimental.asyncio.async_appendable_object_writer._AsyncWriteObjectStream" -) -async def test_append_data_two_times(mock_write_object_stream, mock_client): - """Test that append sends data correctly when called multiple times.""" - from google.cloud.storage._experimental.asyncio.async_appendable_object_writer import ( - _MAX_CHUNK_SIZE_BYTES, - ) - - writer = AsyncAppendableObjectWriter(mock_client, BUCKET, OBJECT) - writer._is_stream_open = True - writer.persisted_size = 0 - mock_stream = mock_write_object_stream.return_value - mock_stream.send = mock.AsyncMock() - writer.simple_flush = mock.AsyncMock() - - data1 = b"a" * (_MAX_CHUNK_SIZE_BYTES + 10) - await writer.append(data1) - - data2 = b"b" * (_MAX_CHUNK_SIZE_BYTES + 20) - await writer.append(data2) - - total_data_length = len(data1) + len(data2) - assert writer.offset == total_data_length - assert writer.simple_flush.await_count == 0 + mock_appendable_writer["mock_stream"].send.assert_awaited_with( + storage_type.BidiWriteObjectRequest(flush=True) + ) + assert writer.bytes_appended_since_last_flush == 0 + + @pytest.mark.asyncio + async def test_close_without_finalize(self, mock_appendable_writer): + writer = self._make_one(mock_appendable_writer["mock_client"]) + writer._is_stream_open = True + writer.write_obj_stream = mock_appendable_writer["mock_stream"] + writer.persisted_size = 50 + + size = await writer.close() + + mock_appendable_writer["mock_stream"].close.assert_awaited() + assert not writer._is_stream_open + assert size == 50 + + @pytest.mark.asyncio + async def test_finalize_lifecycle(self, mock_appendable_writer): + writer = self._make_one(mock_appendable_writer["mock_client"]) + writer._is_stream_open = True + writer.write_obj_stream = mock_appendable_writer["mock_stream"] + + resource = storage_type.Object(size=999) + mock_appendable_writer[ + "mock_stream" + ].recv.return_value = storage_type.BidiWriteObjectResponse(resource=resource) + + res = await writer.finalize() + + assert res == resource + assert writer.persisted_size == 999 + mock_appendable_writer["mock_stream"].send.assert_awaited_with( + storage_type.BidiWriteObjectRequest(finish_write=True) + ) + mock_appendable_writer["mock_stream"].close.assert_awaited() + assert not writer._is_stream_open + + @pytest.mark.asyncio + async def test_close_with_finalize_on_close(self, mock_appendable_writer): + writer = self._make_one(mock_appendable_writer["mock_client"]) + writer._is_stream_open = True + writer.finalize = AsyncMock() + + await writer.close(finalize_on_close=True) + writer.finalize.assert_awaited_once() + + # ------------------------------------------------------------------------- + # Helper Tests + # ------------------------------------------------------------------------- + + @pytest.mark.asyncio + async def test_append_from_file(self, mock_appendable_writer): + writer = self._make_one(mock_appendable_writer["mock_client"]) + writer._is_stream_open = True + writer.append = AsyncMock() + + fp = 
io.BytesIO(b"a" * 12) + await writer.append_from_file(fp, block_size=4) + + assert writer.append.await_count == 3 + + @pytest.mark.asyncio + async def test_methods_require_open_stream_raises(self, mock_appendable_writer): + writer = self._make_one(mock_appendable_writer["mock_client"]) + methods = [ + writer.append(b"data"), + writer.flush(), + writer.simple_flush(), + writer.close(), + writer.finalize(), + writer.state_lookup(), + ] + for coro in methods: + with pytest.raises(ValueError, match="Stream is not open"): + await coro diff --git a/tests/unit/asyncio/test_async_client.py b/tests/unit/asyncio/test_async_client.py index 64481a0d4..e7e232425 100644 --- a/tests/unit/asyncio/test_async_client.py +++ b/tests/unit/asyncio/test_async_client.py @@ -31,7 +31,7 @@ def _make_credentials(): @pytest.mark.skipif( sys.version_info < (3, 10), - reason="Async Client requires Python 3.10+ due to google-auth-library asyncio support" + reason="Async Client requires Python 3.10+ due to google-auth-library asyncio support", ) class TestAsyncClient: @staticmethod @@ -46,7 +46,9 @@ def test_ctor_defaults(self): credentials = _make_credentials() # We mock AsyncConnection to prevent network logic during init - with mock.patch("google.cloud.storage._experimental.asyncio.async_client.AsyncConnection") as MockConn: + with mock.patch( + "google.cloud.storage._experimental.asyncio.async_client.AsyncConnection" + ) as MockConn: client = self._make_one(project=PROJECT, credentials=credentials) assert client.project == PROJECT @@ -66,21 +68,26 @@ def test_ctor_mtls_raises_error(self): credentials = _make_credentials() # Simulate environment where mTLS is enabled - with mock.patch("google.cloud.storage.abstracts.base_client.BaseClient._use_client_cert", new_callable=mock.PropertyMock) as mock_mtls: + with mock.patch( + "google.cloud.storage.abstracts.base_client.BaseClient._use_client_cert", + new_callable=mock.PropertyMock, + ) as mock_mtls: mock_mtls.return_value = True - with pytest.raises(ValueError, match="Async Client currently do not support mTLS"): + with pytest.raises( + ValueError, match="Async Client currently do not support mTLS" + ): self._make_one(credentials=credentials) def test_ctor_w_async_http_passed(self): credentials = _make_credentials() async_http = mock.Mock() - with mock.patch("google.cloud.storage._experimental.asyncio.async_client.AsyncConnection"): + with mock.patch( + "google.cloud.storage._experimental.asyncio.async_client.AsyncConnection" + ): client = self._make_one( - project="PROJECT", - credentials=credentials, - _async_http=async_http + project="PROJECT", credentials=credentials, _async_http=async_http ) assert client._async_http_internal is async_http @@ -88,13 +95,17 @@ def test_ctor_w_async_http_passed(self): def test_async_http_property_creates_session(self): credentials = _make_credentials() - with mock.patch("google.cloud.storage._experimental.asyncio.async_client.AsyncConnection"): + with mock.patch( + "google.cloud.storage._experimental.asyncio.async_client.AsyncConnection" + ): client = self._make_one(project="PROJECT", credentials=credentials) assert client._async_http_internal is None # Mock the auth session class - with mock.patch("google.cloud.storage._experimental.asyncio.async_client.AsyncSession") as MockSession: + with mock.patch( + "google.cloud.storage._experimental.asyncio.async_client.AsyncSession" + ) as MockSession: session = client.async_http assert session is MockSession.return_value @@ -102,12 +113,14 @@ def 
test_async_http_property_creates_session(self): # Should be initialized with the AsyncCredsWrapper, not the raw credentials MockSession.assert_called_once() call_kwargs = MockSession.call_args[1] - assert call_kwargs['credentials'] == client.credentials + assert call_kwargs["credentials"] == client.credentials @pytest.mark.asyncio async def test_close_manages_session_lifecycle(self): credentials = _make_credentials() - with mock.patch("google.cloud.storage._experimental.asyncio.async_client.AsyncConnection"): + with mock.patch( + "google.cloud.storage._experimental.asyncio.async_client.AsyncConnection" + ): client = self._make_one(project="PROJECT", credentials=credentials) # 1. Internal session created by client -> Client closes it @@ -123,11 +136,11 @@ async def test_close_ignores_user_session(self): credentials = _make_credentials() user_session = mock.AsyncMock() - with mock.patch("google.cloud.storage._experimental.asyncio.async_client.AsyncConnection"): + with mock.patch( + "google.cloud.storage._experimental.asyncio.async_client.AsyncConnection" + ): client = self._make_one( - project="PROJECT", - credentials=credentials, - _async_http=user_session + project="PROJECT", credentials=credentials, _async_http=user_session ) # 2. External session passed by user -> Client DOES NOT close it @@ -140,12 +153,13 @@ async def test_get_resource(self): query_params = {"foo": "bar"} credentials = _make_credentials() - with mock.patch("google.cloud.storage._experimental.asyncio.async_client.AsyncConnection"): + with mock.patch( + "google.cloud.storage._experimental.asyncio.async_client.AsyncConnection" + ): client = self._make_one(project="PROJECT", credentials=credentials) # Mock the connection's api_request - client._connection.api_request = mock.AsyncMock( - return_value="response") + client._connection.api_request = mock.AsyncMock(return_value="response") result = await client._get_resource(path, query_params=query_params) @@ -157,7 +171,7 @@ async def test_get_resource(self): headers=None, timeout=mock.ANY, retry=mock.ANY, - _target_object=None + _target_object=None, ) @pytest.mark.asyncio @@ -166,14 +180,13 @@ async def test_list_resource(self): item_to_value = mock.Mock() credentials = _make_credentials() - with mock.patch("google.cloud.storage._experimental.asyncio.async_client.AsyncConnection"): + with mock.patch( + "google.cloud.storage._experimental.asyncio.async_client.AsyncConnection" + ): client = self._make_one(project="PROJECT", credentials=credentials) iterator = client._list_resource( - path=path, - item_to_value=item_to_value, - max_results=10, - page_token="token" + path=path, item_to_value=item_to_value, max_results=10, page_token="token" ) assert isinstance(iterator, AsyncHTTPIterator) @@ -186,7 +199,9 @@ async def test_patch_resource(self): data = {"key": "val"} credentials = _make_credentials() - with mock.patch("google.cloud.storage._experimental.asyncio.async_client.AsyncConnection"): + with mock.patch( + "google.cloud.storage._experimental.asyncio.async_client.AsyncConnection" + ): client = self._make_one(project="PROJECT", credentials=credentials) client._connection.api_request = mock.AsyncMock() @@ -201,7 +216,7 @@ async def test_patch_resource(self): headers=None, timeout=mock.ANY, retry=None, - _target_object=None + _target_object=None, ) @pytest.mark.asyncio @@ -210,7 +225,9 @@ async def test_put_resource(self): data = b"bytes" credentials = _make_credentials() - with mock.patch("google.cloud.storage._experimental.asyncio.async_client.AsyncConnection"): + with 
mock.patch( + "google.cloud.storage._experimental.asyncio.async_client.AsyncConnection" + ): client = self._make_one(project="PROJECT", credentials=credentials) client._connection.api_request = mock.AsyncMock() @@ -225,7 +242,7 @@ async def test_put_resource(self): headers=None, timeout=mock.ANY, retry=None, - _target_object=None + _target_object=None, ) @pytest.mark.asyncio @@ -234,7 +251,9 @@ async def test_post_resource(self): data = {"source": []} credentials = _make_credentials() - with mock.patch("google.cloud.storage._experimental.asyncio.async_client.AsyncConnection"): + with mock.patch( + "google.cloud.storage._experimental.asyncio.async_client.AsyncConnection" + ): client = self._make_one(project="PROJECT", credentials=credentials) client._connection.api_request = mock.AsyncMock() @@ -249,7 +268,7 @@ async def test_post_resource(self): headers=None, timeout=mock.ANY, retry=None, - _target_object=None + _target_object=None, ) @pytest.mark.asyncio @@ -257,7 +276,9 @@ async def test_delete_resource(self): path = "/b/bucket" credentials = _make_credentials() - with mock.patch("google.cloud.storage._experimental.asyncio.async_client.AsyncConnection"): + with mock.patch( + "google.cloud.storage._experimental.asyncio.async_client.AsyncConnection" + ): client = self._make_one(project="PROJECT", credentials=credentials) client._connection.api_request = mock.AsyncMock() @@ -271,12 +292,14 @@ async def test_delete_resource(self): headers=None, timeout=mock.ANY, retry=mock.ANY, - _target_object=None + _target_object=None, ) def test_bucket_not_implemented(self): credentials = _make_credentials() - with mock.patch("google.cloud.storage._experimental.asyncio.async_client.AsyncConnection"): + with mock.patch( + "google.cloud.storage._experimental.asyncio.async_client.AsyncConnection" + ): client = self._make_one(project="PROJECT", credentials=credentials) with pytest.raises(NotImplementedError): diff --git a/tests/unit/asyncio/test_async_creds.py b/tests/unit/asyncio/test_async_creds.py index 0a45bca5d..3dad11fd0 100644 --- a/tests/unit/asyncio/test_async_creds.py +++ b/tests/unit/asyncio/test_async_creds.py @@ -4,28 +4,30 @@ from google.auth import credentials as google_creds from google.cloud.storage._experimental.asyncio import async_creds + @pytest.fixture def mock_aio_modules(): """Patches sys.modules to simulate google.auth.aio existence.""" mock_creds_module = unittest.mock.MagicMock() # We must set the base class to object so our wrapper can inherit safely in tests - mock_creds_module.Credentials = object - + mock_creds_module.Credentials = object + modules = { - 'google.auth.aio': unittest.mock.MagicMock(), - 'google.auth.aio.credentials': mock_creds_module, + "google.auth.aio": unittest.mock.MagicMock(), + "google.auth.aio.credentials": mock_creds_module, } - + with unittest.mock.patch.dict(sys.modules, modules): # We also need to manually flip the flag in the module to True for the test context # because the module was likely already imported with the flag set to False/True # depending on the real environment. - with unittest.mock.patch.object(async_creds, '_AIO_AVAILABLE', True): + with unittest.mock.patch.object(async_creds, "_AIO_AVAILABLE", True): # We also need to ensure BaseCredentials in the module points to our mock # if we want strictly correct inheritance, though duck typing usually suffices. 
- with unittest.mock.patch.object(async_creds, 'BaseCredentials', object): + with unittest.mock.patch.object(async_creds, "BaseCredentials", object): yield + @pytest.fixture def mock_sync_creds(): """Creates a mock of the synchronous Google Credentials object.""" @@ -33,14 +35,15 @@ def mock_sync_creds(): type(creds).valid = unittest.mock.PropertyMock(return_value=True) return creds + @pytest.fixture def async_wrapper(mock_aio_modules, mock_sync_creds): """Instantiates the wrapper with the mock credentials.""" # This instantiation would raise ImportError if mock_aio_modules didn't set _AIO_AVAILABLE=True return async_creds.AsyncCredsWrapper(mock_sync_creds) + class TestAsyncCredsWrapper: - @pytest.mark.asyncio async def test_init_sets_attributes(self, async_wrapper, mock_sync_creds): """Test that the wrapper initializes correctly.""" @@ -51,19 +54,19 @@ async def test_valid_property_delegates(self, async_wrapper, mock_sync_creds): """Test that the .valid property maps to the sync creds .valid property.""" type(mock_sync_creds).valid = unittest.mock.PropertyMock(return_value=True) assert async_wrapper.valid is True - + type(mock_sync_creds).valid = unittest.mock.PropertyMock(return_value=False) assert async_wrapper.valid is False @pytest.mark.asyncio async def test_refresh_offloads_to_executor(self, async_wrapper, mock_sync_creds): - """Test that refresh() gets the running loop and calls sync refresh in executor.""" - with unittest.mock.patch('asyncio.get_running_loop') as mock_get_loop: + """Test that refresh() gets the running loop and calls sync refresh in executor.""" + with unittest.mock.patch("asyncio.get_running_loop") as mock_get_loop: mock_loop = unittest.mock.AsyncMock() mock_get_loop.return_value = mock_loop - + await async_wrapper.refresh(None) - + mock_loop.run_in_executor.assert_called_once() args, _ = mock_loop.run_in_executor.call_args assert args[1] == mock_sync_creds.refresh @@ -72,10 +75,10 @@ async def test_refresh_offloads_to_executor(self, async_wrapper, mock_sync_creds async def test_before_request_valid_creds(self, async_wrapper, mock_sync_creds): """Test before_request when credentials are ALREADY valid.""" type(mock_sync_creds).valid = unittest.mock.PropertyMock(return_value=True) - + headers = {} await async_wrapper.before_request(None, "GET", "http://example.com", headers) - + mock_sync_creds.apply.assert_called_once_with(headers) mock_sync_creds.before_request.assert_not_called() @@ -83,12 +86,12 @@ async def test_before_request_valid_creds(self, async_wrapper, mock_sync_creds): async def test_before_request_invalid_creds(self, async_wrapper, mock_sync_creds): """Test before_request when credentials are INVALID (refresh path).""" type(mock_sync_creds).valid = unittest.mock.PropertyMock(return_value=False) - + headers = {} method = "GET" url = "http://example.com" - with unittest.mock.patch('asyncio.get_running_loop') as mock_get_loop: + with unittest.mock.patch("asyncio.get_running_loop") as mock_get_loop: mock_loop = unittest.mock.AsyncMock() mock_get_loop.return_value = mock_loop @@ -101,8 +104,8 @@ async def test_before_request_invalid_creds(self, async_wrapper, mock_sync_creds def test_missing_aio_raises_error(self, mock_sync_creds): """Ensure ImportError is raised if _AIO_AVAILABLE is False.""" # We manually simulate the environment where AIO is missing - with unittest.mock.patch.object(async_creds, '_AIO_AVAILABLE', False): + with unittest.mock.patch.object(async_creds, "_AIO_AVAILABLE", False): with pytest.raises(ImportError) as excinfo: 
async_creds.AsyncCredsWrapper(mock_sync_creds) - + assert "Failed to import 'google.auth.aio'" in str(excinfo.value) diff --git a/tests/unit/asyncio/test_async_grpc_client.py b/tests/unit/asyncio/test_async_grpc_client.py index eb06ab938..f193acb60 100644 --- a/tests/unit/asyncio/test_async_grpc_client.py +++ b/tests/unit/asyncio/test_async_grpc_client.py @@ -12,10 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. -import unittest from unittest import mock +import pytest from google.auth import credentials as auth_credentials from google.auth.credentials import AnonymousCredentials +from google.api_core import client_info as client_info_lib +from google.cloud.storage.asyncio import async_grpc_client +from google.cloud.storage import __version__ def _make_credentials(spec=None): @@ -24,36 +27,65 @@ def _make_credentials(spec=None): return mock.Mock(spec=spec) -class TestAsyncGrpcClient(unittest.TestCase): +class TestAsyncGrpcClient: @mock.patch("google.cloud._storage_v2.StorageAsyncClient") def test_constructor_default_options(self, mock_async_storage_client): - from google.cloud.storage._experimental.asyncio import async_grpc_client - + # Arrange mock_transport_cls = mock.MagicMock() mock_async_storage_client.get_transport_class.return_value = mock_transport_cls mock_creds = _make_credentials() + # Act async_grpc_client.AsyncGrpcClient(credentials=mock_creds) + # Assert mock_async_storage_client.get_transport_class.assert_called_once_with( "grpc_asyncio" ) + kwargs = mock_async_storage_client.call_args.kwargs + client_info = kwargs["client_info"] + agent_version = f"gcloud-python/{__version__}" + assert agent_version in client_info.user_agent + primary_user_agent = client_info.to_user_agent() + expected_options = (("grpc.primary_user_agent", primary_user_agent),) + mock_transport_cls.create_channel.assert_called_once_with( - attempt_direct_path=True, credentials=mock_creds + attempt_direct_path=True, + credentials=mock_creds, + options=expected_options, ) mock_channel = mock_transport_cls.create_channel.return_value mock_transport_cls.assert_called_once_with(channel=mock_channel) mock_transport = mock_transport_cls.return_value - mock_async_storage_client.assert_called_once_with( - transport=mock_transport, - client_options=None, - client_info=None, + assert kwargs["transport"] is mock_transport + assert kwargs["client_options"] is None + + @mock.patch("google.cloud._storage_v2.StorageAsyncClient") + def test_constructor_with_client_info(self, mock_async_storage_client): + mock_transport_cls = mock.MagicMock() + mock_async_storage_client.get_transport_class.return_value = mock_transport_cls + mock_creds = _make_credentials() + client_info = client_info_lib.ClientInfo( + client_library_version="1.2.3", + ) + + async_grpc_client.AsyncGrpcClient( + credentials=mock_creds, client_info=client_info + ) + + agent_version = f"gcloud-python/{__version__}" + assert agent_version in client_info.user_agent + primary_user_agent = client_info.to_user_agent() + expected_options = (("grpc.primary_user_agent", primary_user_agent),) + + mock_transport_cls.create_channel.assert_called_once_with( + attempt_direct_path=True, + credentials=mock_creds, + options=expected_options, ) @mock.patch("google.cloud._storage_v2.StorageAsyncClient") def test_constructor_disables_directpath(self, mock_async_storage_client): - from google.cloud.storage._experimental.asyncio import async_grpc_client - mock_transport_cls = mock.MagicMock() 
mock_async_storage_client.get_transport_class.return_value = mock_transport_cls mock_creds = _make_credentials() @@ -62,61 +94,70 @@ def test_constructor_disables_directpath(self, mock_async_storage_client): credentials=mock_creds, attempt_direct_path=False ) + kwargs = mock_async_storage_client.call_args.kwargs + client_info = kwargs["client_info"] + agent_version = f"gcloud-python/{__version__}" + assert agent_version in client_info.user_agent + primary_user_agent = client_info.to_user_agent() + expected_options = (("grpc.primary_user_agent", primary_user_agent),) + mock_transport_cls.create_channel.assert_called_once_with( - attempt_direct_path=False, credentials=mock_creds + attempt_direct_path=False, + credentials=mock_creds, + options=expected_options, ) mock_channel = mock_transport_cls.create_channel.return_value mock_transport_cls.assert_called_once_with(channel=mock_channel) @mock.patch("google.cloud._storage_v2.StorageAsyncClient") def test_grpc_client_property(self, mock_grpc_gapic_client): - from google.cloud.storage._experimental.asyncio import async_grpc_client - # Arrange mock_transport_cls = mock.MagicMock() mock_grpc_gapic_client.get_transport_class.return_value = mock_transport_cls channel_sentinel = mock.sentinel.channel - mock_transport_cls.create_channel.return_value = channel_sentinel - mock_transport_cls.return_value = mock.sentinel.transport + mock_transport_instance = mock.sentinel.transport + mock_transport_cls.return_value = mock_transport_instance mock_creds = _make_credentials() - mock_client_info = mock.sentinel.client_info + # Use a real ClientInfo instance instead of a mock to properly test user agent logic + client_info = client_info_lib.ClientInfo(user_agent="test-user-agent") mock_client_options = mock.sentinel.client_options mock_attempt_direct_path = mock.sentinel.attempt_direct_path # Act client = async_grpc_client.AsyncGrpcClient( credentials=mock_creds, - client_info=mock_client_info, + client_info=client_info, client_options=mock_client_options, attempt_direct_path=mock_attempt_direct_path, ) + retrieved_client = client.grpc_client # This is what is being tested - mock_grpc_gapic_client.get_transport_class.return_value = mock_transport_cls - - mock_transport_cls.create_channel.return_value = channel_sentinel - mock_transport_instance = mock.sentinel.transport - mock_transport_cls.return_value = mock_transport_instance + # Assert - verify that gcloud-python agent version was added + agent_version = f"gcloud-python/{__version__}" + assert agent_version in client_info.user_agent + # Also verify original user_agent is still there + assert "test-user-agent" in client_info.user_agent - retrieved_client = client.grpc_client + primary_user_agent = client_info.to_user_agent() + expected_options = (("grpc.primary_user_agent", primary_user_agent),) - # Assert mock_transport_cls.create_channel.assert_called_once_with( - attempt_direct_path=mock_attempt_direct_path, credentials=mock_creds + attempt_direct_path=mock_attempt_direct_path, + credentials=mock_creds, + options=expected_options, ) - mock_transport_cls.assere_with(channel=channel_sentinel) + mock_transport_cls.assert_called_once_with(channel=channel_sentinel) mock_grpc_gapic_client.assert_called_once_with( transport=mock_transport_instance, - client_info=mock_client_info, + client_info=client_info, client_options=mock_client_options, ) - self.assertIs(retrieved_client, mock_grpc_gapic_client.return_value) + assert retrieved_client is mock_grpc_gapic_client.return_value 
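Aside, for readers of the user-agent assertions in this module: the tests expect AsyncGrpcClient to append its own `gcloud-python/<version>` token to the supplied `ClientInfo` and to forward the combined string as the gRPC primary user agent channel option. The sketch below is built only from the values asserted in these tests; exactly where and how the real client mutates `client_info` is an assumption, not taken from its source.

```python
# Sketch of the user-agent handling the surrounding tests assert; the mutation
# point inside AsyncGrpcClient is assumed, only the asserted values are reused.
from google.api_core import client_info as client_info_lib
from google.cloud.storage import __version__

client_info = client_info_lib.ClientInfo(user_agent="custom-app/1.0")

# Append the library token, as asserted in test_user_agent_with_custom_client_info.
client_info.user_agent += f" gcloud-python/{__version__} "
assert client_info.user_agent == f"custom-app/1.0 gcloud-python/{__version__} "

# The combined agent string becomes the gRPC primary user agent channel option,
# matching the `expected_options` tuples checked throughout this test module.
options = (("grpc.primary_user_agent", client_info.to_user_agent()),)
```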
@mock.patch("google.cloud._storage_v2.StorageAsyncClient") def test_grpc_client_with_anon_creds(self, mock_grpc_gapic_client): - from google.cloud.storage._experimental.asyncio import async_grpc_client - # Arrange mock_transport_cls = mock.MagicMock() mock_grpc_gapic_client.get_transport_class.return_value = mock_transport_cls @@ -131,9 +172,160 @@ def test_grpc_client_with_anon_creds(self, mock_grpc_gapic_client): retrieved_client = client.grpc_client # Assert - self.assertIs(retrieved_client, mock_grpc_gapic_client.return_value) + assert retrieved_client is mock_grpc_gapic_client.return_value + + kwargs = mock_grpc_gapic_client.call_args.kwargs + client_info = kwargs["client_info"] + agent_version = f"gcloud-python/{__version__}" + assert agent_version in client_info.user_agent + primary_user_agent = client_info.to_user_agent() + expected_options = (("grpc.primary_user_agent", primary_user_agent),) mock_transport_cls.create_channel.assert_called_once_with( - attempt_direct_path=True, credentials=anonymous_creds + attempt_direct_path=True, + credentials=anonymous_creds, + options=expected_options, ) mock_transport_cls.assert_called_once_with(channel=channel_sentinel) + + @mock.patch("google.cloud._storage_v2.StorageAsyncClient") + def test_user_agent_with_custom_client_info(self, mock_async_storage_client): + """Test that gcloud-python user agent is appended to existing user agent. + + Regression test similar to test__http.py::TestConnection::test_duplicate_user_agent + """ + mock_transport_cls = mock.MagicMock() + mock_async_storage_client.get_transport_class.return_value = mock_transport_cls + mock_creds = _make_credentials() + + # Create a client_info with an existing user_agent + client_info = client_info_lib.ClientInfo(user_agent="custom-app/1.0") + + # Act + async_grpc_client.AsyncGrpcClient( + credentials=mock_creds, + client_info=client_info, + ) + + # Assert - verify that gcloud-python version was appended + agent_version = f"gcloud-python/{__version__}" + expected_user_agent = f"custom-app/1.0 {agent_version} " + assert client_info.user_agent == expected_user_agent + + @mock.patch("google.cloud._storage_v2.StorageAsyncClient") + @pytest.mark.asyncio + async def test_delete_object(self, mock_async_storage_client): + # Arrange + mock_transport_cls = mock.MagicMock() + mock_async_storage_client.get_transport_class.return_value = mock_transport_cls + mock_gapic_client = mock.AsyncMock() + mock_async_storage_client.return_value = mock_gapic_client + + client = async_grpc_client.AsyncGrpcClient( + credentials=_make_credentials(spec=AnonymousCredentials) + ) + + bucket_name = "bucket" + object_name = "object" + generation = 123 + if_generation_match = 456 + if_generation_not_match = 789 + if_metageneration_match = 111 + if_metageneration_not_match = 222 + + # Act + await client.delete_object( + bucket_name, + object_name, + generation=generation, + if_generation_match=if_generation_match, + if_generation_not_match=if_generation_not_match, + if_metageneration_match=if_metageneration_match, + if_metageneration_not_match=if_metageneration_not_match, + ) + + # Assert + call_args, call_kwargs = mock_gapic_client.delete_object.call_args + request = call_kwargs["request"] + assert request.bucket == "projects/_/buckets/bucket" + assert request.object == "object" + assert request.generation == generation + assert request.if_generation_match == if_generation_match + assert request.if_generation_not_match == if_generation_not_match + assert request.if_metageneration_match == 
if_metageneration_match + assert request.if_metageneration_not_match == if_metageneration_not_match + + @mock.patch("google.cloud._storage_v2.StorageAsyncClient") + @pytest.mark.asyncio + async def test_get_object(self, mock_async_storage_client): + # Arrange + mock_transport_cls = mock.MagicMock() + mock_async_storage_client.get_transport_class.return_value = mock_transport_cls + mock_gapic_client = mock.AsyncMock() + mock_async_storage_client.return_value = mock_gapic_client + + client = async_grpc_client.AsyncGrpcClient( + credentials=_make_credentials(spec=AnonymousCredentials) + ) + + bucket_name = "bucket" + object_name = "object" + + # Act + await client.get_object( + bucket_name, + object_name, + ) + + # Assert + call_args, call_kwargs = mock_gapic_client.get_object.call_args + request = call_kwargs["request"] + assert request.bucket == "projects/_/buckets/bucket" + assert request.object == "object" + assert request.soft_deleted is False + + @mock.patch("google.cloud._storage_v2.StorageAsyncClient") + @pytest.mark.asyncio + async def test_get_object_with_all_parameters(self, mock_async_storage_client): + # Arrange + mock_transport_cls = mock.MagicMock() + mock_async_storage_client.get_transport_class.return_value = mock_transport_cls + mock_gapic_client = mock.AsyncMock() + mock_async_storage_client.return_value = mock_gapic_client + + client = async_grpc_client.AsyncGrpcClient( + credentials=_make_credentials(spec=AnonymousCredentials) + ) + + bucket_name = "bucket" + object_name = "object" + generation = 123 + if_generation_match = 456 + if_generation_not_match = 789 + if_metageneration_match = 111 + if_metageneration_not_match = 222 + soft_deleted = True + + # Act + await client.get_object( + bucket_name, + object_name, + generation=generation, + if_generation_match=if_generation_match, + if_generation_not_match=if_generation_not_match, + if_metageneration_match=if_metageneration_match, + if_metageneration_not_match=if_metageneration_not_match, + soft_deleted=soft_deleted, + ) + + # Assert + call_args, call_kwargs = mock_gapic_client.get_object.call_args + request = call_kwargs["request"] + assert request.bucket == "projects/_/buckets/bucket" + assert request.object == "object" + assert request.generation == generation + assert request.if_generation_match == if_generation_match + assert request.if_generation_not_match == if_generation_not_match + assert request.if_metageneration_match == if_metageneration_match + assert request.if_metageneration_not_match == if_metageneration_not_match + assert request.soft_deleted is True diff --git a/tests/unit/asyncio/test_async_helpers.py b/tests/unit/asyncio/test_async_helpers.py index 58ebbea31..d125f2b57 100644 --- a/tests/unit/asyncio/test_async_helpers.py +++ b/tests/unit/asyncio/test_async_helpers.py @@ -27,7 +27,6 @@ async def _safe_anext(iterator): class TestAsyncHTTPIterator: - def _make_one(self, *args, **kw): return AsyncHTTPIterator(*args, **kw) @@ -35,11 +34,9 @@ def _make_one(self, *args, **kw): async def test_iterate_items_single_page(self): """Test simple iteration over one page of results.""" client = mock.Mock() - api_request = mock.AsyncMock() - api_request.return_value = { - "items": ["a", "b"] - } - + api_request = mock.AsyncMock() + api_request.return_value = {"items": ["a", "b"]} + iterator = self._make_one( client=client, api_request=api_request, @@ -53,11 +50,9 @@ async def test_iterate_items_single_page(self): assert results == ["A", "B"] assert iterator.num_results == 2 - assert iterator.page_number == 1 + assert 
iterator.page_number == 1 api_request.assert_awaited_once_with( - method="GET", - path="/path", - query_params={} + method="GET", path="/path", query_params={} ) @pytest.mark.asyncio @@ -65,14 +60,14 @@ async def test_iterate_items_multiple_pages(self): """Test pagination flow passes tokens correctly.""" client = mock.Mock() api_request = mock.AsyncMock() - + # Setup Response: 2 Pages api_request.side_effect = [ - {"items": ["1", "2"], "nextPageToken": "token-A"}, # Page 1 - {"items": ["3"], "nextPageToken": "token-B"}, # Page 2 - {"items": []} # Page 3 (Empty/End) + {"items": ["1", "2"], "nextPageToken": "token-A"}, # Page 1 + {"items": ["3"], "nextPageToken": "token-B"}, # Page 2 + {"items": []}, # Page 3 (Empty/End) ] - + iterator = self._make_one( client=client, api_request=api_request, @@ -84,7 +79,7 @@ async def test_iterate_items_multiple_pages(self): assert results == [1, 2, 3] assert api_request.call_count == 3 - + calls = api_request.call_args_list assert calls[0].kwargs["query_params"] == {} assert calls[1].kwargs["query_params"] == {"pageToken": "token-A"} @@ -95,12 +90,12 @@ async def test_iterate_pages_public_property(self): """Test the .pages property which yields Page objects instead of items.""" client = mock.Mock() api_request = mock.AsyncMock() - + api_request.side_effect = [ {"items": ["a"], "nextPageToken": "next"}, - {"items": ["b"]} + {"items": ["b"]}, ] - + iterator = self._make_one( client=client, api_request=api_request, @@ -115,7 +110,7 @@ async def test_iterate_pages_public_property(self): assert len(pages) == 2 assert list(pages[0]) == ["a"] - assert list(pages[1]) == ["b"] + assert list(pages[1]) == ["b"] assert iterator.page_number == 2 @pytest.mark.asyncio @@ -123,7 +118,7 @@ async def test_max_results_limits_requests(self): """Test that max_results alters the request parameters dynamically.""" client = mock.Mock() api_request = mock.AsyncMock() - + # Setup: We want 5 items total. # Page 1 returns 3 items. # Page 2 *should* only be asked for 2 items. 
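As a side note on what test_max_results_limits_requests (continued in the next hunk) pins down: each page request's `maxResults` is expected to shrink by the number of items already yielded, and `page_size` is only used when no overall `max_results` is set. The helper below is purely illustrative; its name and structure are assumptions and do not reflect the actual AsyncHTTPIterator internals.

```python
# Illustrative bookkeeping only: reproduces the query_params the assertions
# expect, i.e. maxResults=5 for page 1 and 5 - 3 = 2 for page 2.
def next_page_params(max_results=None, page_size=None, num_results=0, page_token=None):
    params = {}
    if page_token:
        params["pageToken"] = page_token
    if max_results is not None:
        params["maxResults"] = max_results - num_results
    elif page_size is not None:
        params["maxResults"] = page_size
    return params

assert next_page_params(max_results=5)["maxResults"] == 5
assert next_page_params(max_results=5, num_results=3, page_token="t1") == {
    "pageToken": "t1",
    "maxResults": 2,
}
```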
@@ -131,24 +126,24 @@ async def test_max_results_limits_requests(self): {"items": ["a", "b", "c"], "nextPageToken": "t1"}, {"items": ["d", "e"], "nextPageToken": "t2"}, ] - + iterator = self._make_one( client=client, api_request=api_request, path="/path", item_to_value=lambda _, x: x, - max_results=5 # <--- Limit set here + max_results=5, # <--- Limit set here ) results = [i async for i in iterator] assert len(results) == 5 assert results == ["a", "b", "c", "d", "e"] - + # Verify Request 1: Asked for max 5 call1_params = api_request.call_args_list[0].kwargs["query_params"] assert call1_params["maxResults"] == 5 - + # Verify Request 2: Asked for max 2 (5 - 3 already fetched) call2_params = api_request.call_args_list[1].kwargs["query_params"] assert call2_params["maxResults"] == 2 @@ -159,15 +154,15 @@ async def test_extra_params_passthrough(self): """Test that extra_params are merged into every request.""" client = mock.Mock() api_request = mock.AsyncMock(return_value={"items": []}) - + custom_params = {"projection": "full", "delimiter": "/"} - + iterator = self._make_one( client=client, api_request=api_request, path="/path", item_to_value=mock.Mock(), - extra_params=custom_params # <--- Input + extra_params=custom_params, # <--- Input ) # Trigger a request @@ -183,13 +178,13 @@ async def test_page_size_configuration(self): """Test that page_size is sent as maxResults if no global max_results is set.""" client = mock.Mock() api_request = mock.AsyncMock(return_value={"items": []}) - + iterator = self._make_one( client=client, api_request=api_request, path="/path", item_to_value=mock.Mock(), - page_size=50 # <--- User preference + page_size=50, # <--- User preference ) await _safe_anext(iterator) @@ -210,7 +205,7 @@ async def test_page_start_callback(self): api_request=api_request, path="/path", item_to_value=lambda _, x: x, - page_start=callback + page_start=callback, ) # Run iteration @@ -258,7 +253,7 @@ async def test_error_if_iterated_twice(self): # First Start async for _ in iterator: pass - + # Second Start (Should Fail) with pytest.raises(ValueError, match="Iterator has already started"): async for _ in iterator: diff --git a/tests/unit/asyncio/test_async_multi_range_downloader.py b/tests/unit/asyncio/test_async_multi_range_downloader.py index 1460e4df8..379e6410b 100644 --- a/tests/unit/asyncio/test_async_multi_range_downloader.py +++ b/tests/unit/asyncio/test_async_multi_range_downloader.py @@ -20,10 +20,10 @@ from google.api_core import exceptions from google_crc32c import Checksum -from google.cloud.storage._experimental.asyncio.async_multi_range_downloader import ( +from google.cloud.storage.asyncio.async_multi_range_downloader import ( AsyncMultiRangeDownloader, ) -from google.cloud.storage._experimental.asyncio import async_read_object_stream +from google.cloud.storage.asyncio import async_read_object_stream from io import BytesIO from google.cloud.storage.exceptions import DataCorruption @@ -39,22 +39,22 @@ class TestAsyncMultiRangeDownloader: def create_read_ranges(self, num_ranges): ranges = [] for i in range(num_ranges): - ranges.append( - _storage_v2.ReadRange(read_offset=i, read_length=1, read_id=i) - ) + ranges.append((i, 1, BytesIO())) return ranges # helper method @pytest.mark.asyncio async def _make_mock_mrd( self, - mock_grpc_client, mock_cls_async_read_object_stream, bucket_name=_TEST_BUCKET_NAME, object_name=_TEST_OBJECT_NAME, - generation_number=_TEST_GENERATION_NUMBER, + generation=_TEST_GENERATION_NUMBER, read_handle=_TEST_READ_HANDLE, ): + mock_client = 
mock.MagicMock() + mock_client.grpc_client = mock.AsyncMock() + mock_stream = mock_cls_async_read_object_stream.return_value mock_stream.open = AsyncMock() mock_stream.generation_number = _TEST_GENERATION_NUMBER @@ -62,39 +62,22 @@ async def _make_mock_mrd( mock_stream.read_handle = _TEST_READ_HANDLE mrd = await AsyncMultiRangeDownloader.create_mrd( - mock_grpc_client, bucket_name, object_name, generation_number, read_handle + mock_client, bucket_name, object_name, generation, read_handle ) - return mrd + return mrd, mock_client @mock.patch( - "google.cloud.storage._experimental.asyncio.async_multi_range_downloader._AsyncReadObjectStream" - ) - @mock.patch( - "google.cloud.storage._experimental.asyncio.async_grpc_client.AsyncGrpcClient.grpc_client" + "google.cloud.storage.asyncio.async_multi_range_downloader._AsyncReadObjectStream" ) @pytest.mark.asyncio - async def test_create_mrd( - self, mock_grpc_client, mock_cls_async_read_object_stream - ): + async def test_create_mrd(self, mock_cls_async_read_object_stream): # Arrange & Act - mrd = await self._make_mock_mrd( - mock_grpc_client, mock_cls_async_read_object_stream - ) - - # Assert - mock_cls_async_read_object_stream.assert_called_once_with( - client=mock_grpc_client, - bucket_name=_TEST_BUCKET_NAME, - object_name=_TEST_OBJECT_NAME, - generation_number=_TEST_GENERATION_NUMBER, - read_handle=_TEST_READ_HANDLE, - ) + mrd, mock_client = await self._make_mock_mrd(mock_cls_async_read_object_stream) - mrd.read_obj_str.open.assert_called_once() # Assert mock_cls_async_read_object_stream.assert_called_once_with( - client=mock_grpc_client, + client=mock_client.grpc_client, bucket_name=_TEST_BUCKET_NAME, object_name=_TEST_OBJECT_NAME, generation_number=_TEST_GENERATION_NUMBER, @@ -103,26 +86,23 @@ async def test_create_mrd( mrd.read_obj_str.open.assert_called_once() - assert mrd.client == mock_grpc_client + assert mrd.client == mock_client assert mrd.bucket_name == _TEST_BUCKET_NAME assert mrd.object_name == _TEST_OBJECT_NAME - assert mrd.generation_number == _TEST_GENERATION_NUMBER + assert mrd.generation == _TEST_GENERATION_NUMBER assert mrd.read_handle == _TEST_READ_HANDLE assert mrd.persisted_size == _TEST_OBJECT_SIZE assert mrd.is_stream_open @mock.patch( - "google.cloud.storage._experimental.asyncio.async_multi_range_downloader.generate_random_56_bit_integer" - ) - @mock.patch( - "google.cloud.storage._experimental.asyncio.async_multi_range_downloader._AsyncReadObjectStream" + "google.cloud.storage.asyncio.async_multi_range_downloader.generate_random_56_bit_integer" ) @mock.patch( - "google.cloud.storage._experimental.asyncio.async_grpc_client.AsyncGrpcClient.grpc_client" + "google.cloud.storage.asyncio.async_multi_range_downloader._AsyncReadObjectStream" ) @pytest.mark.asyncio async def test_download_ranges_via_async_gather( - self, mock_grpc_client, mock_cls_async_read_object_stream, mock_random_int + self, mock_cls_async_read_object_stream, mock_random_int ): # Arrange data = b"these_are_18_chars" @@ -132,10 +112,10 @@ async def test_download_ranges_via_async_gather( Checksum(data[10:16]).digest(), "big" ) - mock_mrd = await self._make_mock_mrd( - mock_grpc_client, mock_cls_async_read_object_stream - ) - mock_random_int.side_effect = [123, 456, 789, 91011] # for _func_id and read_id + mock_mrd, _ = await self._make_mock_mrd(mock_cls_async_read_object_stream) + + mock_random_int.side_effect = [456, 91011] + mock_mrd.read_obj_str.send = AsyncMock() mock_mrd.read_obj_str.recv = AsyncMock() @@ -167,12 +147,14 @@ async def 
test_download_ranges_via_async_gather( ) ], ), + None, ] # Act buffer = BytesIO() second_buffer = BytesIO() lock = asyncio.Lock() + task1 = asyncio.create_task(mock_mrd.download_ranges([(0, 18, buffer)], lock)) task2 = asyncio.create_task( mock_mrd.download_ranges([(10, 6, second_buffer)], lock) @@ -180,58 +162,46 @@ async def test_download_ranges_via_async_gather( await asyncio.gather(task1, task2) # Assert - mock_mrd.read_obj_str.send.side_effect = [ - _storage_v2.BidiReadObjectRequest( - read_ranges=[ - _storage_v2.ReadRange(read_offset=0, read_length=18, read_id=456) - ] - ), - _storage_v2.BidiReadObjectRequest( - read_ranges=[ - _storage_v2.ReadRange(read_offset=10, read_length=6, read_id=91011) - ] - ), - ] assert buffer.getvalue() == data assert second_buffer.getvalue() == data[10:16] @mock.patch( - "google.cloud.storage._experimental.asyncio.async_multi_range_downloader.generate_random_56_bit_integer" - ) - @mock.patch( - "google.cloud.storage._experimental.asyncio.async_multi_range_downloader._AsyncReadObjectStream" + "google.cloud.storage.asyncio.async_multi_range_downloader.generate_random_56_bit_integer" ) @mock.patch( - "google.cloud.storage._experimental.asyncio.async_grpc_client.AsyncGrpcClient.grpc_client" + "google.cloud.storage.asyncio.async_multi_range_downloader._AsyncReadObjectStream" ) @pytest.mark.asyncio async def test_download_ranges( - self, mock_grpc_client, mock_cls_async_read_object_stream, mock_random_int + self, mock_cls_async_read_object_stream, mock_random_int ): # Arrange data = b"these_are_18_chars" crc32c = Checksum(data).digest() crc32c_int = int.from_bytes(crc32c, "big") - mock_mrd = await self._make_mock_mrd( - mock_grpc_client, mock_cls_async_read_object_stream - ) - mock_random_int.side_effect = [123, 456] # for _func_id and read_id + mock_mrd, _ = await self._make_mock_mrd(mock_cls_async_read_object_stream) + + mock_random_int.side_effect = [456] + mock_mrd.read_obj_str.send = AsyncMock() mock_mrd.read_obj_str.recv = AsyncMock() - mock_mrd.read_obj_str.recv.return_value = _storage_v2.BidiReadObjectResponse( - object_data_ranges=[ - _storage_v2.ObjectRangeData( - checksummed_data=_storage_v2.ChecksummedData( - content=data, crc32c=crc32c_int - ), - range_end=True, - read_range=_storage_v2.ReadRange( - read_offset=0, read_length=18, read_id=456 - ), - ) - ], - ) + mock_mrd.read_obj_str.recv.side_effect = [ + _storage_v2.BidiReadObjectResponse( + object_data_ranges=[ + _storage_v2.ObjectRangeData( + checksummed_data=_storage_v2.ChecksummedData( + content=data, crc32c=crc32c_int + ), + range_end=True, + read_range=_storage_v2.ReadRange( + read_offset=0, read_length=18, read_id=456 + ), + ) + ], + ), + None, + ] # Act buffer = BytesIO() @@ -247,16 +217,12 @@ async def test_download_ranges( ) assert buffer.getvalue() == data - @mock.patch( - "google.cloud.storage._experimental.asyncio.async_grpc_client.AsyncGrpcClient.grpc_client" - ) @pytest.mark.asyncio - async def test_downloading_ranges_with_more_than_1000_should_throw_error( - self, mock_grpc_client - ): + async def test_downloading_ranges_with_more_than_1000_should_throw_error(self): # Arrange + mock_client = mock.MagicMock() mrd = AsyncMultiRangeDownloader( - mock_grpc_client, _TEST_BUCKET_NAME, _TEST_OBJECT_NAME + mock_client, _TEST_BUCKET_NAME, _TEST_OBJECT_NAME ) # Act + Assert @@ -270,18 +236,15 @@ async def test_downloading_ranges_with_more_than_1000_should_throw_error( ) @mock.patch( - "google.cloud.storage._experimental.asyncio.async_multi_range_downloader._AsyncReadObjectStream" - ) - 
@mock.patch( - "google.cloud.storage._experimental.asyncio.async_grpc_client.AsyncGrpcClient.grpc_client" + "google.cloud.storage.asyncio.async_multi_range_downloader._AsyncReadObjectStream" ) @pytest.mark.asyncio async def test_opening_mrd_more_than_once_should_throw_error( - self, mock_grpc_client, mock_cls_async_read_object_stream + self, mock_cls_async_read_object_stream ): # Arrange - mrd = await self._make_mock_mrd( - mock_grpc_client, mock_cls_async_read_object_stream + mrd, _ = await self._make_mock_mrd( + mock_cls_async_read_object_stream ) # mock mrd is already opened # Act + Assert @@ -292,16 +255,13 @@ async def test_opening_mrd_more_than_once_should_throw_error( assert str(exc.value) == "Underlying bidi-gRPC stream is already open" @mock.patch( - "google.cloud.storage._experimental.asyncio.async_multi_range_downloader._AsyncReadObjectStream" - ) - @mock.patch( - "google.cloud.storage._experimental.asyncio.async_grpc_client.AsyncGrpcClient.grpc_client" + "google.cloud.storage.asyncio.async_multi_range_downloader._AsyncReadObjectStream" ) @pytest.mark.asyncio - async def test_close_mrd(self, mock_grpc_client, mock_cls_async_read_object_stream): + async def test_close_mrd(self, mock_cls_async_read_object_stream): # Arrange - mrd = await self._make_mock_mrd( - mock_grpc_client, mock_cls_async_read_object_stream + mrd, _ = await self._make_mock_mrd( + mock_cls_async_read_object_stream ) # mock mrd is already opened mrd.read_obj_str.close = AsyncMock() @@ -311,16 +271,13 @@ async def test_close_mrd(self, mock_grpc_client, mock_cls_async_read_object_stre # Assert assert not mrd.is_stream_open - @mock.patch( - "google.cloud.storage._experimental.asyncio.async_grpc_client.AsyncGrpcClient.grpc_client" - ) @pytest.mark.asyncio - async def test_close_mrd_not_opened_should_throw_error(self, mock_grpc_client): + async def test_close_mrd_not_opened_should_throw_error(self): # Arrange + mock_client = mock.MagicMock() mrd = AsyncMultiRangeDownloader( - mock_grpc_client, _TEST_BUCKET_NAME, _TEST_OBJECT_NAME + mock_client, _TEST_BUCKET_NAME, _TEST_OBJECT_NAME ) - # Act + Assert with pytest.raises(ValueError) as exc: await mrd.close() @@ -329,16 +286,12 @@ async def test_close_mrd_not_opened_should_throw_error(self, mock_grpc_client): assert str(exc.value) == "Underlying bidi-gRPC stream is not open" assert not mrd.is_stream_open - @mock.patch( - "google.cloud.storage._experimental.asyncio.async_grpc_client.AsyncGrpcClient.grpc_client" - ) @pytest.mark.asyncio - async def test_downloading_without_opening_should_throw_error( - self, mock_grpc_client - ): + async def test_downloading_without_opening_should_throw_error(self): # Arrange + mock_client = mock.MagicMock() mrd = AsyncMultiRangeDownloader( - mock_grpc_client, _TEST_BUCKET_NAME, _TEST_OBJECT_NAME + mock_client, _TEST_BUCKET_NAME, _TEST_OBJECT_NAME ) # Act + Assert @@ -349,34 +302,28 @@ async def test_downloading_without_opening_should_throw_error( assert str(exc.value) == "Underlying bidi-gRPC stream is not open" assert not mrd.is_stream_open - @mock.patch( - "google.cloud.storage._experimental.asyncio.async_multi_range_downloader.google_crc32c" - ) - @mock.patch( - "google.cloud.storage._experimental.asyncio.async_grpc_client.AsyncGrpcClient.grpc_client" - ) - def test_init_raises_if_crc32c_c_extension_is_missing( - self, mock_grpc_client, mock_google_crc32c - ): + @mock.patch("google.cloud.storage.asyncio._utils.google_crc32c") + def test_init_raises_if_crc32c_c_extension_is_missing(self, mock_google_crc32c): 
mock_google_crc32c.implementation = "python" + mock_client = mock.MagicMock() - with pytest.raises(exceptions.NotFound) as exc_info: - AsyncMultiRangeDownloader(mock_grpc_client, "bucket", "object") + with pytest.raises(exceptions.FailedPrecondition) as exc_info: + AsyncMultiRangeDownloader(mock_client, "bucket", "object") assert "The google-crc32c package is not installed with C support" in str( exc_info.value ) @pytest.mark.asyncio - @mock.patch( - "google.cloud.storage._experimental.asyncio.async_multi_range_downloader.Checksum" - ) - @mock.patch( - "google.cloud.storage._experimental.asyncio.async_grpc_client.AsyncGrpcClient.grpc_client" - ) + @mock.patch("google.cloud.storage.asyncio.retry.reads_resumption_strategy.Checksum") async def test_download_ranges_raises_on_checksum_mismatch( - self, mock_client, mock_checksum_class + self, mock_checksum_class ): + from google.cloud.storage.asyncio.async_multi_range_downloader import ( + AsyncMultiRangeDownloader, + ) + + mock_client = mock.MagicMock() mock_stream = mock.AsyncMock( spec=async_read_object_stream._AsyncReadObjectStream ) @@ -392,7 +339,9 @@ async def test_download_ranges_raises_on_checksum_mismatch( checksummed_data=_storage_v2.ChecksummedData( content=test_data, crc32c=server_checksum ), - read_range=_storage_v2.ReadRange(read_id=0), + read_range=_storage_v2.ReadRange( + read_id=0, read_offset=0, read_length=len(test_data) + ), range_end=True, ) ] @@ -405,7 +354,95 @@ async def test_download_ranges_raises_on_checksum_mismatch( mrd._is_stream_open = True with pytest.raises(DataCorruption) as exc_info: - await mrd.download_ranges([(0, len(test_data), BytesIO())]) + with mock.patch( + "google.cloud.storage.asyncio.async_multi_range_downloader.generate_random_56_bit_integer", + return_value=0, + ): + await mrd.download_ranges([(0, len(test_data), BytesIO())]) assert "Checksum mismatch" in str(exc_info.value) mock_checksum_class.assert_called_once_with(test_data) + + @mock.patch( + "google.cloud.storage.asyncio.async_multi_range_downloader.AsyncMultiRangeDownloader.open", + new_callable=AsyncMock, + ) + @mock.patch( + "google.cloud.storage.asyncio.async_multi_range_downloader.AsyncMultiRangeDownloader.close", + new_callable=AsyncMock, + ) + @pytest.mark.asyncio + async def test_async_context_manager_calls_open_and_close( + self, mock_close, mock_open + ): + # Arrange + mock_client = mock.MagicMock() + mrd = AsyncMultiRangeDownloader( + mock_client, _TEST_BUCKET_NAME, _TEST_OBJECT_NAME + ) + + # To simulate the behavior of open and close changing the stream state + async def open_side_effect(): + mrd._is_stream_open = True + + async def close_side_effect(): + mrd._is_stream_open = False + + mock_open.side_effect = open_side_effect + mock_close.side_effect = close_side_effect + mrd._is_stream_open = False + + # Act + async with mrd as downloader: + # Assert + mock_open.assert_called_once() + assert downloader == mrd + assert mrd.is_stream_open + + mock_close.assert_called_once() + assert not mrd.is_stream_open + + @mock.patch( + "google.cloud.storage.asyncio.async_multi_range_downloader._AsyncReadObjectStream" + ) + @pytest.mark.asyncio + async def test_create_mrd_with_generation_number( + self, mock_cls_async_read_object_stream, caplog + ): + # Arrange + mock_client = mock.MagicMock() + mock_client.grpc_client = mock.AsyncMock() + + mock_stream = mock_cls_async_read_object_stream.return_value + mock_stream.open = AsyncMock() + mock_stream.generation_number = _TEST_GENERATION_NUMBER + mock_stream.persisted_size = _TEST_OBJECT_SIZE + 
mock_stream.read_handle = _TEST_READ_HANDLE + + # Act + mrd = await AsyncMultiRangeDownloader.create_mrd( + mock_client, + _TEST_BUCKET_NAME, + _TEST_OBJECT_NAME, + generation_number=_TEST_GENERATION_NUMBER, + read_handle=_TEST_READ_HANDLE, + ) + + # Assert + assert mrd.generation == _TEST_GENERATION_NUMBER + assert "'generation_number' is deprecated" in caplog.text + + @pytest.mark.asyncio + async def test_create_mrd_with_both_generation_and_generation_number(self): + # Arrange + mock_client = mock.MagicMock() + + # Act & Assert + with pytest.raises(TypeError): + await AsyncMultiRangeDownloader.create_mrd( + mock_client, + _TEST_BUCKET_NAME, + _TEST_OBJECT_NAME, + generation=_TEST_GENERATION_NUMBER, + generation_number=_TEST_GENERATION_NUMBER, + ) diff --git a/tests/unit/asyncio/test_async_read_object_stream.py b/tests/unit/asyncio/test_async_read_object_stream.py index 4ba8d34a1..2d2f28edd 100644 --- a/tests/unit/asyncio/test_async_read_object_stream.py +++ b/tests/unit/asyncio/test_async_read_object_stream.py @@ -17,8 +17,8 @@ from unittest.mock import AsyncMock from google.cloud import _storage_v2 -from google.cloud.storage._experimental.asyncio import async_read_object_stream -from google.cloud.storage._experimental.asyncio.async_read_object_stream import ( +from google.cloud.storage.asyncio import async_read_object_stream +from google.cloud.storage.asyncio.async_read_object_stream import ( _AsyncReadObjectStream, ) @@ -79,11 +79,9 @@ async def instantiate_read_obj_stream_with_read_handle( return read_obj_stream +@mock.patch("google.cloud.storage.asyncio.async_read_object_stream.AsyncBidiRpc") @mock.patch( - "google.cloud.storage._experimental.asyncio.async_read_object_stream.AsyncBidiRpc" -) -@mock.patch( - "google.cloud.storage._experimental.asyncio.async_grpc_client.AsyncGrpcClient.grpc_client" + "google.cloud.storage.asyncio.async_grpc_client.AsyncGrpcClient.grpc_client" ) def test_init_with_bucket_object_generation(mock_client, mock_async_bidi_rpc): # Arrange @@ -110,11 +108,9 @@ def test_init_with_bucket_object_generation(mock_client, mock_async_bidi_rpc): assert read_obj_stream.rpc == rpc_sentinel +@mock.patch("google.cloud.storage.asyncio.async_read_object_stream.AsyncBidiRpc") @mock.patch( - "google.cloud.storage._experimental.asyncio.async_read_object_stream.AsyncBidiRpc" -) -@mock.patch( - "google.cloud.storage._experimental.asyncio.async_grpc_client.AsyncGrpcClient.grpc_client" + "google.cloud.storage.asyncio.async_grpc_client.AsyncGrpcClient.grpc_client" ) @pytest.mark.asyncio async def test_open(mock_client, mock_cls_async_bidi_rpc): @@ -136,11 +132,9 @@ async def test_open(mock_client, mock_cls_async_bidi_rpc): assert read_obj_stream.is_stream_open +@mock.patch("google.cloud.storage.asyncio.async_read_object_stream.AsyncBidiRpc") @mock.patch( - "google.cloud.storage._experimental.asyncio.async_read_object_stream.AsyncBidiRpc" -) -@mock.patch( - "google.cloud.storage._experimental.asyncio.async_grpc_client.AsyncGrpcClient.grpc_client" + "google.cloud.storage.asyncio.async_grpc_client.AsyncGrpcClient.grpc_client" ) @pytest.mark.asyncio async def test_open_with_read_handle(mock_client, mock_cls_async_bidi_rpc): @@ -162,11 +156,9 @@ async def test_open_with_read_handle(mock_client, mock_cls_async_bidi_rpc): assert read_obj_stream.is_stream_open +@mock.patch("google.cloud.storage.asyncio.async_read_object_stream.AsyncBidiRpc") @mock.patch( - "google.cloud.storage._experimental.asyncio.async_read_object_stream.AsyncBidiRpc" -) -@mock.patch( - 
"google.cloud.storage._experimental.asyncio.async_grpc_client.AsyncGrpcClient.grpc_client" + "google.cloud.storage.asyncio.async_grpc_client.AsyncGrpcClient.grpc_client" ) @pytest.mark.asyncio async def test_open_when_already_open_should_raise_error( @@ -185,11 +177,9 @@ async def test_open_when_already_open_should_raise_error( assert str(exc.value) == "Stream is already open" +@mock.patch("google.cloud.storage.asyncio.async_read_object_stream.AsyncBidiRpc") @mock.patch( - "google.cloud.storage._experimental.asyncio.async_read_object_stream.AsyncBidiRpc" -) -@mock.patch( - "google.cloud.storage._experimental.asyncio.async_grpc_client.AsyncGrpcClient.grpc_client" + "google.cloud.storage.asyncio.async_grpc_client.AsyncGrpcClient.grpc_client" ) @pytest.mark.asyncio async def test_close(mock_client, mock_cls_async_bidi_rpc): @@ -197,20 +187,42 @@ async def test_close(mock_client, mock_cls_async_bidi_rpc): read_obj_stream = await instantiate_read_obj_stream( mock_client, mock_cls_async_bidi_rpc, open=True ) + read_obj_stream.requests_done = AsyncMock() # act await read_obj_stream.close() # assert + read_obj_stream.requests_done.assert_called_once() read_obj_stream.socket_like_rpc.close.assert_called_once() assert not read_obj_stream.is_stream_open +@mock.patch("google.cloud.storage.asyncio.async_read_object_stream.AsyncBidiRpc") @mock.patch( - "google.cloud.storage._experimental.asyncio.async_read_object_stream.AsyncBidiRpc" + "google.cloud.storage.asyncio.async_grpc_client.AsyncGrpcClient.grpc_client" ) +@pytest.mark.asyncio +async def test_requests_done(mock_client, mock_cls_async_bidi_rpc): + """Test that requests_done signals the end of requests.""" + # Arrange + read_obj_stream = await instantiate_read_obj_stream( + mock_client, mock_cls_async_bidi_rpc, open=True + ) + read_obj_stream.socket_like_rpc.send = AsyncMock() + read_obj_stream.socket_like_rpc.recv = AsyncMock() + + # Act + await read_obj_stream.requests_done() + + # Assert + read_obj_stream.socket_like_rpc.send.assert_called_once_with(None) + read_obj_stream.socket_like_rpc.recv.assert_called_once() + + +@mock.patch("google.cloud.storage.asyncio.async_read_object_stream.AsyncBidiRpc") @mock.patch( - "google.cloud.storage._experimental.asyncio.async_grpc_client.AsyncGrpcClient.grpc_client" + "google.cloud.storage.asyncio.async_grpc_client.AsyncGrpcClient.grpc_client" ) @pytest.mark.asyncio async def test_close_without_open_should_raise_error( @@ -229,11 +241,9 @@ async def test_close_without_open_should_raise_error( assert str(exc.value) == "Stream is not open" +@mock.patch("google.cloud.storage.asyncio.async_read_object_stream.AsyncBidiRpc") @mock.patch( - "google.cloud.storage._experimental.asyncio.async_read_object_stream.AsyncBidiRpc" -) -@mock.patch( - "google.cloud.storage._experimental.asyncio.async_grpc_client.AsyncGrpcClient.grpc_client" + "google.cloud.storage.asyncio.async_grpc_client.AsyncGrpcClient.grpc_client" ) @pytest.mark.asyncio async def test_send(mock_client, mock_cls_async_bidi_rpc): @@ -252,11 +262,9 @@ async def test_send(mock_client, mock_cls_async_bidi_rpc): ) +@mock.patch("google.cloud.storage.asyncio.async_read_object_stream.AsyncBidiRpc") @mock.patch( - "google.cloud.storage._experimental.asyncio.async_read_object_stream.AsyncBidiRpc" -) -@mock.patch( - "google.cloud.storage._experimental.asyncio.async_grpc_client.AsyncGrpcClient.grpc_client" + "google.cloud.storage.asyncio.async_grpc_client.AsyncGrpcClient.grpc_client" ) @pytest.mark.asyncio async def test_send_without_open_should_raise_error( @@ 
-275,11 +283,9 @@ async def test_send_without_open_should_raise_error( assert str(exc.value) == "Stream is not open" +@mock.patch("google.cloud.storage.asyncio.async_read_object_stream.AsyncBidiRpc") @mock.patch( - "google.cloud.storage._experimental.asyncio.async_read_object_stream.AsyncBidiRpc" -) -@mock.patch( - "google.cloud.storage._experimental.asyncio.async_grpc_client.AsyncGrpcClient.grpc_client" + "google.cloud.storage.asyncio.async_grpc_client.AsyncGrpcClient.grpc_client" ) @pytest.mark.asyncio async def test_recv(mock_client, mock_cls_async_bidi_rpc): @@ -300,11 +306,9 @@ async def test_recv(mock_client, mock_cls_async_bidi_rpc): assert response == bidi_read_object_response +@mock.patch("google.cloud.storage.asyncio.async_read_object_stream.AsyncBidiRpc") @mock.patch( - "google.cloud.storage._experimental.asyncio.async_read_object_stream.AsyncBidiRpc" -) -@mock.patch( - "google.cloud.storage._experimental.asyncio.async_grpc_client.AsyncGrpcClient.grpc_client" + "google.cloud.storage.asyncio.async_grpc_client.AsyncGrpcClient.grpc_client" ) @pytest.mark.asyncio async def test_recv_without_open_should_raise_error( @@ -323,11 +327,9 @@ async def test_recv_without_open_should_raise_error( assert str(exc.value) == "Stream is not open" +@mock.patch("google.cloud.storage.asyncio.async_read_object_stream.AsyncBidiRpc") @mock.patch( - "google.cloud.storage._experimental.asyncio.async_read_object_stream.AsyncBidiRpc" -) -@mock.patch( - "google.cloud.storage._experimental.asyncio.async_grpc_client.AsyncGrpcClient.grpc_client" + "google.cloud.storage.asyncio.async_grpc_client.AsyncGrpcClient.grpc_client" ) @pytest.mark.asyncio async def test_recv_updates_read_handle_on_refresh( diff --git a/tests/unit/asyncio/test_async_write_object_stream.py b/tests/unit/asyncio/test_async_write_object_stream.py index c6ea8a8ff..4e952336b 100644 --- a/tests/unit/asyncio/test_async_write_object_stream.py +++ b/tests/unit/asyncio/test_async_write_object_stream.py @@ -12,11 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+import unittest.mock as mock +from unittest.mock import AsyncMock, MagicMock import pytest -from unittest import mock +import grpc -from unittest.mock import AsyncMock -from google.cloud.storage._experimental.asyncio.async_write_object_stream import ( + +from google.cloud.storage.asyncio.async_write_object_stream import ( _AsyncWriteObjectStream, ) from google.cloud import _storage_v2 @@ -25,372 +27,233 @@ OBJECT = "my-object" GENERATION = 12345 WRITE_HANDLE = b"test-handle" +FULL_BUCKET_PATH = f"projects/_/buckets/{BUCKET}" @pytest.fixture def mock_client(): - """Mock the async gRPC client.""" - mock_transport = mock.AsyncMock() + """Fixture to provide a mock gRPC client.""" + client = MagicMock() + # Mocking transport internal structures + mock_transport = MagicMock() mock_transport.bidi_write_object = mock.sentinel.bidi_write_object mock_transport._wrapped_methods = { mock.sentinel.bidi_write_object: mock.sentinel.wrapped_bidi_write_object } - - mock_gapic_client = mock.AsyncMock() - mock_gapic_client._transport = mock_transport - - client = mock.AsyncMock() - client._client = mock_gapic_client + client._client._transport = mock_transport return client -async def instantiate_write_obj_stream(mock_client, mock_cls_async_bidi_rpc, open=True): - """Helper to create an instance of _AsyncWriteObjectStream and open it by default.""" - socket_like_rpc = AsyncMock() - mock_cls_async_bidi_rpc.return_value = socket_like_rpc - socket_like_rpc.open = AsyncMock() - socket_like_rpc.send = AsyncMock() - socket_like_rpc.close = AsyncMock() - - mock_response = mock.MagicMock(spec=_storage_v2.BidiWriteObjectResponse) - mock_response.resource = mock.MagicMock(spec=_storage_v2.Object) - mock_response.resource.generation = GENERATION - mock_response.resource.size = 0 - mock_response.write_handle = WRITE_HANDLE - socket_like_rpc.recv = AsyncMock(return_value=mock_response) - - write_obj_stream = _AsyncWriteObjectStream(mock_client, BUCKET, OBJECT) - - if open: - await write_obj_stream.open() - - return write_obj_stream - - -def test_async_write_object_stream_init(mock_client): - """Test the constructor of _AsyncWriteObjectStream.""" - stream = _AsyncWriteObjectStream(mock_client, BUCKET, OBJECT) - - assert stream.client == mock_client - assert stream.bucket_name == BUCKET - assert stream.object_name == OBJECT - assert stream.generation_number is None - assert stream.write_handle is None - assert stream._full_bucket_name == f"projects/_/buckets/{BUCKET}" - assert stream.rpc == mock.sentinel.wrapped_bidi_write_object - assert stream.metadata == ( - ("x-goog-request-params", f"bucket=projects/_/buckets/{BUCKET}"), - ) - assert stream.socket_like_rpc is None - assert not stream._is_stream_open - assert stream.first_bidi_write_req is None - assert stream.persisted_size == 0 - assert stream.object_resource is None - - -def test_async_write_object_stream_init_with_generation_and_handle(mock_client): - """Test the constructor with optional arguments.""" - generation = 12345 - write_handle = b"test-handle" - stream = _AsyncWriteObjectStream( - mock_client, - BUCKET, - OBJECT, - generation_number=generation, - write_handle=write_handle, - ) - - assert stream.generation_number == generation - assert stream.write_handle == write_handle - - -def test_async_write_object_stream_init_raises_value_error(): - """Test that the constructor raises ValueError for missing arguments.""" - with pytest.raises(ValueError, match="client must be provided"): - _AsyncWriteObjectStream(None, BUCKET, OBJECT) - - with 
pytest.raises(ValueError, match="bucket_name must be provided"): - _AsyncWriteObjectStream(mock.Mock(), None, OBJECT) - - with pytest.raises(ValueError, match="object_name must be provided"): - _AsyncWriteObjectStream(mock.Mock(), BUCKET, None) - - -@pytest.mark.asyncio -@mock.patch( - "google.cloud.storage._experimental.asyncio.async_write_object_stream.AsyncBidiRpc" -) -async def test_open_for_new_object(mock_async_bidi_rpc, mock_client): - """Test opening a stream for a new object.""" - # Arrange - socket_like_rpc = mock.AsyncMock() - mock_async_bidi_rpc.return_value = socket_like_rpc - socket_like_rpc.open = mock.AsyncMock() - - mock_response = mock.MagicMock(spec=_storage_v2.BidiWriteObjectResponse) - mock_response.resource = mock.MagicMock(spec=_storage_v2.Object) - mock_response.resource.generation = GENERATION - mock_response.resource.size = 0 - mock_response.write_handle = WRITE_HANDLE - socket_like_rpc.recv = mock.AsyncMock(return_value=mock_response) - - stream = _AsyncWriteObjectStream(mock_client, BUCKET, OBJECT) - - # Act - await stream.open() - - # Assert - assert stream._is_stream_open - socket_like_rpc.open.assert_called_once() - socket_like_rpc.recv.assert_called_once() - assert stream.generation_number == GENERATION - assert stream.write_handle == WRITE_HANDLE - assert stream.persisted_size == 0 - - -@pytest.mark.asyncio -@mock.patch( - "google.cloud.storage._experimental.asyncio.async_write_object_stream.AsyncBidiRpc" -) -async def test_open_for_existing_object(mock_async_bidi_rpc, mock_client): - """Test opening a stream for an existing object.""" - # Arrange - socket_like_rpc = mock.AsyncMock() - mock_async_bidi_rpc.return_value = socket_like_rpc - socket_like_rpc.open = mock.AsyncMock() - - mock_response = mock.MagicMock(spec=_storage_v2.BidiWriteObjectResponse) - mock_response.resource = mock.MagicMock(spec=_storage_v2.Object) - mock_response.resource.size = 1024 - mock_response.resource.generation = GENERATION - mock_response.write_handle = WRITE_HANDLE - socket_like_rpc.recv = mock.AsyncMock(return_value=mock_response) - - stream = _AsyncWriteObjectStream( - mock_client, BUCKET, OBJECT, generation_number=GENERATION - ) - - # Act - await stream.open() - - # Assert - assert stream._is_stream_open - socket_like_rpc.open.assert_called_once() - socket_like_rpc.recv.assert_called_once() - assert stream.generation_number == GENERATION - assert stream.write_handle == WRITE_HANDLE - assert stream.persisted_size == 1024 - - -@pytest.mark.asyncio -@mock.patch( - "google.cloud.storage._experimental.asyncio.async_write_object_stream.AsyncBidiRpc" -) -async def test_open_when_already_open_raises_error(mock_async_bidi_rpc, mock_client): - """Test that opening an already open stream raises a ValueError.""" - # Arrange - socket_like_rpc = mock.AsyncMock() - mock_async_bidi_rpc.return_value = socket_like_rpc - socket_like_rpc.open = mock.AsyncMock() - - mock_response = mock.MagicMock(spec=_storage_v2.BidiWriteObjectResponse) - mock_response.resource = mock.MagicMock(spec=_storage_v2.Object) - mock_response.resource.generation = GENERATION - mock_response.resource.size = 0 - mock_response.write_handle = WRITE_HANDLE - socket_like_rpc.recv = mock.AsyncMock(return_value=mock_response) - - stream = _AsyncWriteObjectStream(mock_client, BUCKET, OBJECT) - await stream.open() - - # Act & Assert - with pytest.raises(ValueError, match="Stream is already open"): - await stream.open() +class TestAsyncWriteObjectStream: + """Test suite for AsyncWriteObjectStream.""" + # 
------------------------------------------------------------------------- + # Initialization Tests + # ------------------------------------------------------------------------- -@pytest.mark.asyncio -@mock.patch( - "google.cloud.storage._experimental.asyncio.async_write_object_stream.AsyncBidiRpc" -) -async def test_open_raises_error_on_missing_object_resource( - mock_async_bidi_rpc, mock_client -): - """Test that open raises ValueError if object_resource is not in the response.""" - socket_like_rpc = mock.AsyncMock() - mock_async_bidi_rpc.return_value = socket_like_rpc - - mock_reponse = mock.AsyncMock() - type(mock_reponse).resource = mock.PropertyMock(return_value=None) - socket_like_rpc.recv.return_value = mock_reponse - - # Note: Don't use below code as unittest library automatically assigns an - # `AsyncMock` object to an attribute, if not set. - # socket_like_rpc.recv.return_value = mock.AsyncMock( - # return_value=_storage_v2.BidiWriteObjectResponse(resource=None) - # ) - - stream = _AsyncWriteObjectStream(mock_client, BUCKET, OBJECT) - with pytest.raises( - ValueError, match="Failed to obtain object resource after opening the stream" - ): - await stream.open() - - -@pytest.mark.asyncio -@mock.patch( - "google.cloud.storage._experimental.asyncio.async_write_object_stream.AsyncBidiRpc" -) -async def test_open_raises_error_on_missing_generation( - mock_async_bidi_rpc, mock_client -): - """Test that open raises ValueError if generation is not in the response.""" - socket_like_rpc = mock.AsyncMock() - mock_async_bidi_rpc.return_value = socket_like_rpc - - # Configure the mock response object - mock_response = mock.AsyncMock() - type(mock_response.resource).generation = mock.PropertyMock(return_value=None) - socket_like_rpc.recv.return_value = mock_response - - stream = _AsyncWriteObjectStream(mock_client, BUCKET, OBJECT) - with pytest.raises( - ValueError, match="Failed to obtain object generation after opening the stream" - ): + def test_init_basic(self, mock_client): + stream = _AsyncWriteObjectStream(mock_client, BUCKET, OBJECT) + assert stream.bucket_name == BUCKET + assert stream.object_name == OBJECT + assert stream._full_bucket_name == FULL_BUCKET_PATH + assert stream.metadata == ( + ("x-goog-request-params", f"bucket={FULL_BUCKET_PATH}"), + ) + assert not stream.is_stream_open + + def test_init_raises_value_error(self, mock_client): + with pytest.raises(ValueError, match="client must be provided"): + _AsyncWriteObjectStream(None, BUCKET, OBJECT) + with pytest.raises(ValueError, match="bucket_name must be provided"): + _AsyncWriteObjectStream(mock_client, None, OBJECT) + with pytest.raises(ValueError, match="object_name must be provided"): + _AsyncWriteObjectStream(mock_client, BUCKET, None) + + # ------------------------------------------------------------------------- + # Open Stream Tests + # ------------------------------------------------------------------------- + + @mock.patch("google.cloud.storage.asyncio.async_write_object_stream.AsyncBidiRpc") + @pytest.mark.asyncio + async def test_open_new_object(self, mock_rpc_cls, mock_client): + mock_rpc = mock_rpc_cls.return_value + mock_rpc.open = AsyncMock() + + # We don't use spec here to avoid descriptor issues with nested protos + mock_response = MagicMock() + mock_response.persisted_size = 0 + mock_response.resource.generation = GENERATION + mock_response.resource.size = 0 + mock_response.write_handle = WRITE_HANDLE + mock_rpc.recv = AsyncMock(return_value=mock_response) + + stream = _AsyncWriteObjectStream(mock_client, 
BUCKET, OBJECT) await stream.open() - -@pytest.mark.asyncio -@mock.patch( - "google.cloud.storage._experimental.asyncio.async_write_object_stream.AsyncBidiRpc" -) -async def test_open_raises_error_on_missing_write_handle( - mock_async_bidi_rpc, mock_client -): - """Test that open raises ValueError if write_handle is not in the response.""" - socket_like_rpc = mock.AsyncMock() - mock_async_bidi_rpc.return_value = socket_like_rpc - socket_like_rpc.recv = mock.AsyncMock( - return_value=_storage_v2.BidiWriteObjectResponse( - resource=_storage_v2.Object(generation=GENERATION), write_handle=None + # Check if BidiRpc was initialized with WriteObjectSpec + call_args = mock_rpc_cls.call_args + initial_request = call_args.kwargs["initial_request"] + assert initial_request.write_object_spec is not None + assert initial_request.write_object_spec.resource.name == OBJECT + assert initial_request.write_object_spec.appendable + + assert stream.is_stream_open + assert stream.write_handle == WRITE_HANDLE + assert stream.generation_number == GENERATION + + @mock.patch("google.cloud.storage.asyncio.async_write_object_stream.AsyncBidiRpc") + @pytest.mark.asyncio + async def test_open_existing_object_with_token(self, mock_rpc_cls, mock_client): + mock_rpc = mock_rpc_cls.return_value + mock_rpc.open = AsyncMock() + + # Ensure resource is None so persisted_size logic doesn't get overwritten by child mocks + mock_response = MagicMock() + mock_response.persisted_size = 1024 + mock_response.resource = None + mock_response.write_handle = WRITE_HANDLE + mock_rpc.recv = AsyncMock(return_value=mock_response) + + stream = _AsyncWriteObjectStream( + mock_client, + BUCKET, + OBJECT, + generation_number=GENERATION, + routing_token="token-123", ) - ) - stream = _AsyncWriteObjectStream(mock_client, BUCKET, OBJECT) - with pytest.raises(ValueError, match="Failed to obtain write_handle"): await stream.open() + # Verify AppendObjectSpec attributes + initial_request = mock_rpc_cls.call_args.kwargs["initial_request"] + assert initial_request.append_object_spec is not None + assert initial_request.append_object_spec.generation == GENERATION + assert initial_request.append_object_spec.routing_token == "token-123" + assert stream.persisted_size == 1024 + + @mock.patch("google.cloud.storage.asyncio.async_write_object_stream.AsyncBidiRpc") + @pytest.mark.asyncio + async def test_open_metadata_merging(self, mock_rpc_cls, mock_client): + mock_rpc = mock_rpc_cls.return_value + mock_rpc.open = AsyncMock() + mock_rpc.recv = AsyncMock(return_value=MagicMock(resource=None)) + + stream = _AsyncWriteObjectStream(mock_client, BUCKET, OBJECT) + extra_metadata = [("x-custom", "val"), ("x-goog-request-params", "extra=param")] + + await stream.open(metadata=extra_metadata) + + # Verify that metadata combined bucket and extra params + passed_metadata = mock_rpc_cls.call_args.kwargs["metadata"] + meta_dict = dict(passed_metadata) + assert meta_dict["x-custom"] == "val" + # Params should be comma separated + params = meta_dict["x-goog-request-params"] + assert f"bucket={FULL_BUCKET_PATH}" in params + assert "extra=param" in params + + @pytest.mark.asyncio + async def test_open_already_open_raises(self, mock_client): + stream = _AsyncWriteObjectStream(mock_client, BUCKET, OBJECT) + stream._is_stream_open = True + with pytest.raises(ValueError, match="already open"): + await stream.open() + + # ------------------------------------------------------------------------- + # Send & Recv & Close Tests + # 
------------------------------------------------------------------------- + + @mock.patch("google.cloud.storage.asyncio.async_write_object_stream.AsyncBidiRpc") + @pytest.mark.asyncio + async def test_send_and_recv_logic(self, mock_rpc_cls, mock_client): + # Setup open stream + mock_rpc = mock_rpc_cls.return_value + mock_rpc.open = AsyncMock() + mock_rpc.send = AsyncMock() # Crucial: Must be AsyncMock + mock_rpc.recv = AsyncMock(return_value=MagicMock(resource=None)) + + stream = _AsyncWriteObjectStream(mock_client, BUCKET, OBJECT) + await stream.open() -@pytest.mark.asyncio -@mock.patch( - "google.cloud.storage._experimental.asyncio.async_write_object_stream.AsyncBidiRpc" -) -async def test_close(mock_cls_async_bidi_rpc, mock_client): - """Test that close successfully closes the stream.""" - # Arrange - write_obj_stream = await instantiate_write_obj_stream( - mock_client, mock_cls_async_bidi_rpc, open=True - ) - - # Act - await write_obj_stream.close() - - # Assert - write_obj_stream.socket_like_rpc.close.assert_called_once() - assert not write_obj_stream.is_stream_open - - -@pytest.mark.asyncio -@mock.patch( - "google.cloud.storage._experimental.asyncio.async_write_object_stream.AsyncBidiRpc" -) -async def test_close_without_open_should_raise_error( - mock_cls_async_bidi_rpc, mock_client -): - """Test that closing a stream that is not open raises a ValueError.""" - # Arrange - write_obj_stream = await instantiate_write_obj_stream( - mock_client, mock_cls_async_bidi_rpc, open=False - ) - - # Act & Assert - with pytest.raises(ValueError, match="Stream is not open"): - await write_obj_stream.close() - - -@pytest.mark.asyncio -@mock.patch( - "google.cloud.storage._experimental.asyncio.async_write_object_stream.AsyncBidiRpc" -) -async def test_send(mock_cls_async_bidi_rpc, mock_client): - """Test that send calls the underlying rpc's send method.""" - # Arrange - write_obj_stream = await instantiate_write_obj_stream( - mock_client, mock_cls_async_bidi_rpc, open=True - ) - - # Act - bidi_write_object_request = _storage_v2.BidiWriteObjectRequest() - await write_obj_stream.send(bidi_write_object_request) - - # Assert - write_obj_stream.socket_like_rpc.send.assert_called_once_with( - bidi_write_object_request - ) - - -@pytest.mark.asyncio -@mock.patch( - "google.cloud.storage._experimental.asyncio.async_write_object_stream.AsyncBidiRpc" -) -async def test_send_without_open_should_raise_error( - mock_cls_async_bidi_rpc, mock_client -): - """Test that sending on a stream that is not open raises a ValueError.""" - # Arrange - write_obj_stream = await instantiate_write_obj_stream( - mock_client, mock_cls_async_bidi_rpc, open=False - ) - - # Act & Assert - with pytest.raises(ValueError, match="Stream is not open"): - await write_obj_stream.send(_storage_v2.BidiWriteObjectRequest()) - - -@pytest.mark.asyncio -@mock.patch( - "google.cloud.storage._experimental.asyncio.async_write_object_stream.AsyncBidiRpc" -) -async def test_recv(mock_cls_async_bidi_rpc, mock_client): - """Test that recv calls the underlying rpc's recv method.""" - # Arrange - write_obj_stream = await instantiate_write_obj_stream( - mock_client, mock_cls_async_bidi_rpc, open=True - ) - bidi_write_object_response = _storage_v2.BidiWriteObjectResponse() - write_obj_stream.socket_like_rpc.recv = AsyncMock( - return_value=bidi_write_object_response - ) - - # Act - response = await write_obj_stream.recv() - - # Assert - write_obj_stream.socket_like_rpc.recv.assert_called_once() - assert response == bidi_write_object_response - - 
-@pytest.mark.asyncio -@mock.patch( - "google.cloud.storage._experimental.asyncio.async_write_object_stream.AsyncBidiRpc" -) -async def test_recv_without_open_should_raise_error( - mock_cls_async_bidi_rpc, mock_client -): - """Test that receiving on a stream that is not open raises a ValueError.""" - # Arrange - write_obj_stream = await instantiate_write_obj_stream( - mock_client, mock_cls_async_bidi_rpc, open=False - ) - - # Act & Assert - with pytest.raises(ValueError, match="Stream is not open"): - await write_obj_stream.recv() + # Test Send + req = _storage_v2.BidiWriteObjectRequest(write_offset=0) + await stream.send(req) + mock_rpc.send.assert_awaited_with(req) + + # Test Recv with state update + mock_response = MagicMock() + mock_response.persisted_size = 5000 + mock_response.write_handle = b"new-handle" + mock_response.resource = None + mock_rpc.recv.return_value = mock_response + + res = await stream.recv() + assert res.persisted_size == 5000 + assert stream.persisted_size == 5000 + assert stream.write_handle == b"new-handle" + + @pytest.mark.asyncio + async def test_close_success(self, mock_client): + stream = _AsyncWriteObjectStream(mock_client, BUCKET, OBJECT) + stream._is_stream_open = True + stream.socket_like_rpc = AsyncMock() + + stream.socket_like_rpc.send = AsyncMock() + first_resp = _storage_v2.BidiWriteObjectResponse(persisted_size=100) + stream.socket_like_rpc.recv = AsyncMock(side_effect=[first_resp, grpc.aio.EOF]) + stream.socket_like_rpc.close = AsyncMock() + + await stream.close() + stream.socket_like_rpc.close.assert_awaited_once() + assert not stream.is_stream_open + assert stream.persisted_size == 100 + + @pytest.mark.asyncio + async def test_close_with_persisted_size_then_eof(self, mock_client): + """Test close when first recv has persisted_size, second is EOF.""" + stream = _AsyncWriteObjectStream(mock_client, BUCKET, OBJECT) + stream._is_stream_open = True + stream.socket_like_rpc = AsyncMock() + + # First response has persisted_size (NOT EOF, intermediate) + persisted_resp = _storage_v2.BidiWriteObjectResponse(persisted_size=500) + # Second response is EOF (None) + eof_resp = grpc.aio.EOF + + stream.socket_like_rpc.send = AsyncMock() + stream.socket_like_rpc.recv = AsyncMock(side_effect=[persisted_resp, eof_resp]) + stream.socket_like_rpc.close = AsyncMock() + + await stream.close() + + # Verify two recv calls: first has persisted_size (NOT EOF), so read second (EOF) + assert stream.socket_like_rpc.recv.await_count == 2 + assert stream.persisted_size == 500 + assert not stream.is_stream_open + + @pytest.mark.asyncio + async def test_close_with_grpc_aio_eof_response(self, mock_client): + """Test close when first recv is grpc.aio.EOF sentinel.""" + stream = _AsyncWriteObjectStream(mock_client, BUCKET, OBJECT) + stream._is_stream_open = True + stream.socket_like_rpc = AsyncMock() + + # First recv returns grpc.aio.EOF (explicit sentinel from finalize) + stream.socket_like_rpc.send = AsyncMock() + stream.socket_like_rpc.recv = AsyncMock(return_value=grpc.aio.EOF) + stream.socket_like_rpc.close = AsyncMock() + + await stream.close() + + # Verify only one recv call (grpc.aio.EOF=EOF, so don't read second) + assert stream.socket_like_rpc.recv.await_count == 1 + assert not stream.is_stream_open + + @pytest.mark.asyncio + async def test_methods_require_open_raises(self, mock_client): + stream = _AsyncWriteObjectStream(mock_client, BUCKET, OBJECT) + with pytest.raises(ValueError, match="Stream is not open"): + await stream.send(MagicMock()) + with 
pytest.raises(ValueError, match="Stream is not open"): + await stream.recv() + with pytest.raises(ValueError, match="Stream is not open"): + await stream.close() diff --git a/tests/unit/test_blob.py b/tests/unit/test_blob.py index cbf53b398..2359de501 100644 --- a/tests/unit/test_blob.py +++ b/tests/unit/test_blob.py @@ -228,6 +228,29 @@ def test__set_properties_w_kms_key_name(self): ) self._set_properties_helper(kms_key_name=kms_resource) + def test_finalized_time_property_is_none(self): + BLOB_NAME = "blob-name" + bucket = _Bucket() + blob = self._make_one(BLOB_NAME, bucket=bucket) + self.assertIsNone(blob.finalized_time) + + def test_finalized_time_property_is_not_none(self): + from google.cloud.storage import blob as blob_module + + BLOB_NAME = "blob-name" + bucket = _Bucket() + blob = self._make_one(BLOB_NAME, bucket=bucket) + + timestamp = "2024-07-29T12:34:56.123456Z" + blob._properties["finalizedTime"] = timestamp + + mock_datetime = mock.Mock() + with mock.patch.object( + blob_module, "_rfc3339_nanos_to_datetime", return_value=mock_datetime + ) as mocked: + self.assertEqual(blob.finalized_time, mock_datetime) + mocked.assert_called_once_with(timestamp) + def test_chunk_size_ctor(self): from google.cloud.storage.blob import Blob @@ -3064,7 +3087,13 @@ def _make_resumable_transport( fake_response2 = self._mock_requests_response( http.client.PERMANENT_REDIRECT, headers2 ) - json_body = json.dumps({"size": str(total_bytes), "md5Hash": md5_checksum_value, "crc32c": crc32c_checksum_value}) + json_body = json.dumps( + { + "size": str(total_bytes), + "md5Hash": md5_checksum_value, + "crc32c": crc32c_checksum_value, + } + ) if data_corruption: fake_response3 = DataCorruption(None) else: diff --git a/tests/unit/test_grpc_client.py b/tests/unit/test_grpc_client.py index 9eca1b280..6dbbfbaa6 100644 --- a/tests/unit/test_grpc_client.py +++ b/tests/unit/test_grpc_client.py @@ -16,6 +16,7 @@ from unittest import mock from google.auth import credentials as auth_credentials from google.api_core import client_options as client_options_lib +from google.cloud.storage import grpc_client def _make_credentials(spec=None): @@ -30,8 +31,6 @@ class TestGrpcClient(unittest.TestCase): def test_constructor_defaults_and_options( self, mock_storage_client, mock_base_client ): - from google.cloud.storage._experimental import grpc_client - mock_transport_cls = mock.MagicMock() mock_storage_client.get_transport_class.return_value = mock_transport_cls mock_creds = _make_credentials(spec=["_base", "_get_project_id"]) @@ -71,13 +70,11 @@ def test_constructor_defaults_and_options( # 4. Assert the client instance holds the mocked GAPIC client. 
self.assertIs(client.grpc_client, mock_storage_client.return_value) - @mock.patch("google.cloud.storage._experimental.grpc_client.ClientWithProject") + @mock.patch("google.cloud.storage.grpc_client.ClientWithProject") @mock.patch("google.cloud._storage_v2.StorageClient") def test_constructor_disables_direct_path( self, mock_storage_client, mock_base_client ): - from google.cloud.storage._experimental import grpc_client - mock_transport_cls = mock.MagicMock() mock_storage_client.get_transport_class.return_value = mock_transport_cls mock_creds = _make_credentials() @@ -94,13 +91,11 @@ def test_constructor_disables_direct_path( attempt_direct_path=False ) - @mock.patch("google.cloud.storage._experimental.grpc_client.ClientWithProject") + @mock.patch("google.cloud.storage.grpc_client.ClientWithProject") @mock.patch("google.cloud._storage_v2.StorageClient") def test_constructor_initialize_with_api_key( self, mock_storage_client, mock_base_client ): - from google.cloud.storage._experimental import grpc_client - mock_transport_cls = mock.MagicMock() mock_storage_client.get_transport_class.return_value = mock_transport_cls mock_creds = _make_credentials() @@ -124,11 +119,9 @@ def test_constructor_initialize_with_api_key( client_options={"api_key": "test-api-key"}, ) - @mock.patch("google.cloud.storage._experimental.grpc_client.ClientWithProject") + @mock.patch("google.cloud.storage.grpc_client.ClientWithProject") @mock.patch("google.cloud._storage_v2.StorageClient") def test_grpc_client_property(self, mock_storage_client, mock_base_client): - from google.cloud.storage._experimental import grpc_client - mock_creds = _make_credentials() mock_base_client.return_value._credentials = mock_creds @@ -138,13 +131,11 @@ def test_grpc_client_property(self, mock_storage_client, mock_base_client): self.assertIs(retrieved_client, mock_storage_client.return_value) - @mock.patch("google.cloud.storage._experimental.grpc_client.ClientWithProject") + @mock.patch("google.cloud.storage.grpc_client.ClientWithProject") @mock.patch("google.cloud._storage_v2.StorageClient") def test_constructor_with_api_key_and_client_options( self, mock_storage_client, mock_base_client ): - from google.cloud.storage._experimental import grpc_client - mock_transport_cls = mock.MagicMock() mock_storage_client.get_transport_class.return_value = mock_transport_cls mock_transport = mock_transport_cls.return_value @@ -173,13 +164,11 @@ def test_constructor_with_api_key_and_client_options( ) self.assertEqual(client_options_obj.api_key, "new-test-key") - @mock.patch("google.cloud.storage._experimental.grpc_client.ClientWithProject") + @mock.patch("google.cloud.storage.grpc_client.ClientWithProject") @mock.patch("google.cloud._storage_v2.StorageClient") def test_constructor_with_api_key_and_dict_options( self, mock_storage_client, mock_base_client ): - from google.cloud.storage._experimental import grpc_client - mock_creds = _make_credentials() mock_base_instance = mock_base_client.return_value mock_base_instance._credentials = mock_creds