Skip to content

Commit a1dffb5

Browse files
committed
generated sync code
1 parent b81a9be commit a1dffb5

File tree

4 files changed

+168
-38
lines changed

4 files changed

+168
-38
lines changed
Lines changed: 133 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,133 @@
1+
# Copyright 2023 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License
14+
#
15+
16+
# This file is automatically generated by CrossSync. Do not edit manually.
17+
18+
from __future__ import annotations
19+
from typing import Callable
20+
from grpc import ChannelConnectivity
21+
from grpc import Call
22+
from grpc import Channel
23+
from grpc import UnaryUnaryMultiCallable
24+
from grpc import UnaryStreamMultiCallable
25+
from grpc import StreamUnaryMultiCallable
26+
from grpc import StreamStreamMultiCallable
27+
28+
29+
class _WrappedMultiCallable:
30+
"""
31+
Wrapper class that implements the grpc MultiCallable interface.
32+
Allows generic functions that return calls to pass checks for
33+
MultiCallable objects.
34+
"""
35+
36+
def __init__(self, call_factory: Callable[..., Call]):
37+
self._call_factory = call_factory
38+
39+
def __call__(self, *args, **kwargs) -> Call:
40+
return self._call_factory(*args, **kwargs)
41+
42+
43+
class WrappedUnaryUnaryMultiCallable(_WrappedMultiCallable, UnaryUnaryMultiCallable):
44+
pass
45+
46+
47+
class WrappedUnaryStreamMultiCallable(_WrappedMultiCallable, UnaryStreamMultiCallable):
48+
pass
49+
50+
51+
class WrappedStreamUnaryMultiCallable(_WrappedMultiCallable, StreamUnaryMultiCallable):
52+
pass
53+
54+
55+
class WrappedStreamStreamMultiCallable(
56+
_WrappedMultiCallable, StreamStreamMultiCallable
57+
):
58+
pass
59+
60+
61+
class _WrappedChannel(Channel):
62+
"""
63+
A wrapper around a gRPC channel. All methods are passed
64+
through to the underlying channel.
65+
"""
66+
67+
def __init__(self, channel: Channel):
68+
self._channel = channel
69+
70+
def unary_unary(self, *args, **kwargs) -> UnaryUnaryMultiCallable:
71+
return WrappedUnaryUnaryMultiCallable(
72+
lambda *call_args, **call_kwargs: self._channel.unary_unary(
73+
*args, **kwargs
74+
)(*call_args, **call_kwargs)
75+
)
76+
77+
def unary_stream(self, *args, **kwargs) -> UnaryStreamMultiCallable:
78+
return WrappedUnaryStreamMultiCallable(
79+
lambda *call_args, **call_kwargs: self._channel.unary_stream(
80+
*args, **kwargs
81+
)(*call_args, **call_kwargs)
82+
)
83+
84+
def stream_unary(self, *args, **kwargs) -> StreamUnaryMultiCallable:
85+
return WrappedStreamUnaryMultiCallable(
86+
lambda *call_args, **call_kwargs: self._channel.stream_unary(
87+
*args, **kwargs
88+
)(*call_args, **call_kwargs)
89+
)
90+
91+
def stream_stream(self, *args, **kwargs) -> StreamStreamMultiCallable:
92+
return WrappedStreamStreamMultiCallable(
93+
lambda *call_args, **call_kwargs: self._channel.stream_stream(
94+
*args, **kwargs
95+
)(*call_args, **call_kwargs)
96+
)
97+
98+
def close(self, grace=None):
99+
return self._channel.close(grace=grace)
100+
101+
def channel_ready(self):
102+
return self._channel.channel_ready()
103+
104+
def __enter__(self):
105+
self._channel.__enter__()
106+
return self
107+
108+
def __exit__(self, exc_type, exc_val, exc_tb):
109+
return self._channel.__exit__(exc_type, exc_val, exc_tb)
110+
111+
def get_state(self, try_to_connect: bool = False) -> ChannelConnectivity:
112+
return self._channel.get_state(try_to_connect=try_to_connect)
113+
114+
def wait_for_state_change(self, last_observed_state):
115+
return self._channel.wait_for_state_change(last_observed_state)
116+
117+
def __getattr__(self, name):
118+
return getattr(self._channel, name)
119+
120+
121+
class _ReplaceableChannel(_WrappedChannel):
122+
def __init__(self, channel_fn: Callable[[], Channel]):
123+
self._channel_fn = channel_fn
124+
self._channel = channel_fn()
125+
126+
def create_channel(self) -> Channel:
127+
new_channel = self._channel_fn()
128+
return new_channel
129+
130+
def replace_wrapped_channel(self, new_channel: Channel) -> Channel:
131+
old_channel = self._channel
132+
self._channel = new_channel
133+
return old_channel

google/cloud/bigtable/data/_sync_autogen/client.py

Lines changed: 18 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -75,11 +75,11 @@
7575
from google.cloud.bigtable.data._cross_sync import CrossSync
7676
from typing import Iterable
7777
from grpc import insecure_channel
78-
from grpc import intercept_channel
7978
from google.cloud.bigtable_v2.services.bigtable.transports import (
8079
BigtableGrpcTransport as TransportType,
8180
)
8281
from google.cloud.bigtable.data._sync_autogen.mutations_batcher import _MB_SIZE
82+
from google.cloud.bigtable.data._async._replaceable_channel import _ReplaceableChannel
8383

8484
if TYPE_CHECKING:
8585
from google.cloud.bigtable.data._helpers import RowKeySamples
@@ -131,15 +131,13 @@ def __init__(
131131
client_options = cast(
132132
Optional[client_options_lib.ClientOptions], client_options
133133
)
134-
custom_channel = None
135134
self._emulator_host = os.getenv(BIGTABLE_EMULATOR)
136135
if self._emulator_host is not None:
137136
warnings.warn(
138137
"Connecting to Bigtable emulator at {}".format(self._emulator_host),
139138
RuntimeWarning,
140139
stacklevel=2,
141140
)
142-
custom_channel = insecure_channel(self._emulator_host)
143141
if credentials is None:
144142
credentials = google.auth.credentials.AnonymousCredentials()
145143
if project is None:
@@ -155,7 +153,7 @@ def __init__(
155153
client_options=client_options,
156154
client_info=self.client_info,
157155
transport=lambda *args, **kwargs: TransportType(
158-
*args, **kwargs, channel=custom_channel
156+
*args, **kwargs, channel=self._build_grpc_channel
159157
),
160158
)
161159
self._is_closed = CrossSync._Sync_Impl.Event()
@@ -179,6 +177,13 @@ def __init__(
179177
stacklevel=2,
180178
)
181179

180+
def _build_grpc_channel(self, *args, **kwargs) -> _ReplaceableChannel:
181+
if self._emulator_host is not None:
182+
create_channel_fn = partial(insecure_channel, self._emulator_host)
183+
else:
184+
create_channel_fn = partial(TransportType.create_channel, *args, **kwargs)
185+
return _ReplaceableChannel(create_channel_fn)
186+
182187
@staticmethod
183188
def _client_version() -> str:
184189
"""Helper function to return the client version string for this client"""
@@ -277,27 +282,28 @@ def _manage_channel(
277282
between `refresh_interval_min` and `refresh_interval_max`
278283
grace_period: time to allow previous channel to serve existing
279284
requests before closing, in seconds"""
285+
if not isinstance(self.transport.grpc_channel, _AsyncReplaceableChannel):
286+
warnings.warn("Channel does not support auto-refresh.")
287+
return
288+
super_channel: _AsyncReplaceableChannel = self.transport.grpc_channel
280289
first_refresh = self._channel_init_time + random.uniform(
281290
refresh_interval_min, refresh_interval_max
282291
)
283292
next_sleep = max(first_refresh - time.monotonic(), 0)
284293
if next_sleep > 0:
285-
self._ping_and_warm_instances(channel=self.transport.grpc_channel)
294+
self._ping_and_warm_instances(channel=super_channel)
286295
while not self._is_closed.is_set():
287296
CrossSync._Sync_Impl.event_wait(
288297
self._is_closed, next_sleep, async_break_early=False
289298
)
290299
if self._is_closed.is_set():
291300
break
292301
start_timestamp = time.monotonic()
293-
old_channel = self.transport.grpc_channel
294-
new_channel = self.transport.create_channel()
295-
new_channel = intercept_channel(new_channel, self.transport._interceptor)
302+
new_channel = super_channel.create_channel()
296303
self._ping_and_warm_instances(channel=new_channel)
297-
self.transport._grpc_channel = new_channel
298-
self.transport._logged_channel = new_channel
299-
self.transport._stubs = {}
300-
self.transport._prep_wrapped_messages(self.client_info)
304+
old_channel = super_channel.replace_wrapped_channel(
305+
new_channel, grace_period
306+
)
301307
if grace_period:
302308
self._is_closed.wait(grace_period)
303309
old_channel.close()

tests/system/data/test_system_autogen.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -201,13 +201,14 @@ def test_channel_refresh(self, table_id, instance_id, temp_rows):
201201
CrossSync._Sync_Impl.yield_to_event_loop()
202202
with client.get_table(instance_id, table_id) as table:
203203
rows = table.read_rows({})
204-
first_channel = client.transport.grpc_channel
204+
channel_wrapper = client.transport.grpc_channel
205+
first_channel = client.transport.grpc_channel._channel
205206
assert len(rows) == 2
206207
CrossSync._Sync_Impl.sleep(2)
207208
rows_after_refresh = table.read_rows({})
208209
assert len(rows_after_refresh) == 2
209-
assert client.transport.grpc_channel is not first_channel
210-
print(table)
210+
assert client.transport.grpc_channel is channel_wrapper
211+
assert client.transport.grpc_channel._channel is not first_channel
211212
finally:
212213
client.close()
213214

tests/unit/data/_sync_autogen/test_client.py

Lines changed: 13 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -176,7 +176,7 @@ def test__start_background_channel_refresh_task_exists(self):
176176
client.close()
177177

178178
def test__start_background_channel_refresh(self):
179-
client = self._make_client(project="project-id")
179+
client = self._make_client(project="project-id", use_emulator=False)
180180
with mock.patch.object(
181181
client, "_ping_and_warm_instances", CrossSync._Sync_Impl.Mock()
182182
) as ping_and_warm:
@@ -282,7 +282,7 @@ def test__manage_channel_first_sleep(
282282
with mock.patch.object(CrossSync._Sync_Impl, "event_wait") as sleep:
283283
sleep.side_effect = asyncio.CancelledError
284284
try:
285-
client = self._make_client(project="project-id")
285+
client = self._make_client(project="project-id", use_emulator=False)
286286
client._channel_init_time = -wait_time
287287
client._manage_channel(refresh_interval, refresh_interval)
288288
except asyncio.CancelledError:
@@ -296,36 +296,29 @@ def test__manage_channel_first_sleep(
296296

297297
def test__manage_channel_ping_and_warm(self):
298298
"""_manage channel should call ping and warm internally"""
299-
import time
300299
import threading
301-
from google.cloud.bigtable_v2.services.bigtable.transports.grpc import (
302-
_LoggingClientInterceptor as Interceptor,
303-
)
304300

305-
client_mock = mock.Mock()
306-
client_mock.transport._interceptor = Interceptor()
307-
client_mock._is_closed.is_set.return_value = False
308-
client_mock._channel_init_time = time.monotonic()
309-
orig_channel = client_mock.transport.grpc_channel
301+
client = self._make_client(project="project-id", use_emulator=True)
302+
orig_channel = client.transport.grpc_channel
310303
sleep_tuple = (
311304
(asyncio, "sleep")
312305
if CrossSync._Sync_Impl.is_async
313306
else (threading.Event, "wait")
314307
)
315-
with mock.patch.object(*sleep_tuple):
316-
orig_channel.close.side_effect = asyncio.CancelledError
308+
with mock.patch.object(*sleep_tuple) as sleep_mock:
309+
sleep_mock.side_effect = [None, asyncio.CancelledError]
317310
ping_and_warm = (
318-
client_mock._ping_and_warm_instances
311+
client._ping_and_warm_instances
319312
) = CrossSync._Sync_Impl.Mock()
320313
try:
321-
self._get_target_class()._manage_channel(client_mock, 10)
314+
client._manage_channel(10)
322315
except asyncio.CancelledError:
323316
pass
324317
assert ping_and_warm.call_count == 2
325-
assert client_mock.transport._grpc_channel != orig_channel
318+
assert client.transport.grpc_channel._channel != orig_channel
326319
called_with = [call[1]["channel"] for call in ping_and_warm.call_args_list]
327320
assert orig_channel in called_with
328-
assert client_mock.transport.grpc_channel in called_with
321+
assert client.transport.grpc_channel._channel in called_with
329322

330323
@pytest.mark.parametrize(
331324
"refresh_interval, num_cycles, expected_sleep",
@@ -335,8 +328,6 @@ def test__manage_channel_sleeps(self, refresh_interval, num_cycles, expected_sle
335328
import time
336329
import random
337330

338-
channel = mock.Mock()
339-
channel.close = CrossSync._Sync_Impl.Mock()
340331
with mock.patch.object(random, "uniform") as uniform:
341332
uniform.side_effect = lambda min_, max_: min_
342333
with mock.patch.object(time, "time") as time_mock:
@@ -345,8 +336,7 @@ def test__manage_channel_sleeps(self, refresh_interval, num_cycles, expected_sle
345336
sleep.side_effect = [None for i in range(num_cycles - 1)] + [
346337
asyncio.CancelledError
347338
]
348-
client = self._make_client(project="project-id")
349-
client.transport._grpc_channel = channel
339+
client = self._make_client(project="project-id", use_emulator=True)
350340
with mock.patch.object(
351341
client.transport, "create_channel", CrossSync._Sync_Impl.Mock
352342
):
@@ -374,7 +364,7 @@ def test__manage_channel_random(self):
374364
uniform.return_value = 0
375365
try:
376366
uniform.side_effect = asyncio.CancelledError
377-
client = self._make_client(project="project-id")
367+
client = self._make_client(project="project-id", use_emulator=False)
378368
except asyncio.CancelledError:
379369
uniform.side_effect = None
380370
uniform.reset_mock()
@@ -405,7 +395,7 @@ def test__manage_channel_refresh(self, num_cycles):
405395
CrossSync._Sync_Impl.grpc_helpers, "create_channel"
406396
) as create_channel:
407397
create_channel.return_value = new_channel
408-
client = self._make_client(project="project-id")
398+
client = self._make_client(project="project-id", use_emulator=False)
409399
create_channel.reset_mock()
410400
try:
411401
client._manage_channel(

0 commit comments

Comments
 (0)