Skip to content

Commit 3854b18

Browse files
authored
Merge pull request #25 from filintod/filinto/add-grpc-options
add grpc options (stacked over continue_as_new PR)
2 parents 8d711f3 + 5cddf4f commit 3854b18

File tree

9 files changed

+206
-52
lines changed

9 files changed

+206
-52
lines changed

durabletask/aio/client.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ def __init__(
3535
log_formatter: Optional[logging.Formatter] = None,
3636
secure_channel: bool = False,
3737
interceptors: Optional[Sequence[ClientInterceptor]] = None,
38+
channel_options: Optional[Sequence[tuple[str, Any]]] = None,
3839
):
3940
if interceptors is not None:
4041
interceptors = list(interceptors)
@@ -46,7 +47,10 @@ def __init__(
4647
interceptors = None
4748

4849
channel = get_grpc_aio_channel(
49-
host_address=host_address, secure_channel=secure_channel, interceptors=interceptors
50+
host_address=host_address,
51+
secure_channel=secure_channel,
52+
interceptors=interceptors,
53+
options=channel_options,
5054
)
5155
self._channel = channel
5256
self._stub = stubs.TaskHubSidecarServiceStub(channel)

durabletask/aio/internal/shared.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
import grpc
77
from grpc import aio as grpc_aio
8+
from grpc.aio import ChannelArgumentType
89

910
from durabletask.internal.shared import (
1011
INSECURE_PROTOCOLS,
@@ -24,7 +25,16 @@ def get_grpc_aio_channel(
2425
host_address: Optional[str],
2526
secure_channel: bool = False,
2627
interceptors: Optional[Sequence[ClientInterceptor]] = None,
28+
options: Optional[ChannelArgumentType] = None,
2729
) -> grpc_aio.Channel:
30+
"""create a grpc asyncio channel
31+
32+
Args:
33+
host_address: The host address of the gRPC server. If None, uses the default address.
34+
secure_channel: Whether to use a secure channel (TLS/SSL). Defaults to False.
35+
interceptors: Optional sequence of client interceptors to apply to the channel.
36+
options: Optional sequence of gRPC channel options as (key, value) tuples. Keys defined in https://grpc.github.io/grpc/core/group__grpc__arg__keys.html
37+
"""
2838
if host_address is None:
2939
host_address = get_default_host_address()
3040

@@ -42,9 +52,11 @@ def get_grpc_aio_channel(
4252

4353
if secure_channel:
4454
channel = grpc_aio.secure_channel(
45-
host_address, grpc.ssl_channel_credentials(), interceptors=interceptors
55+
host_address, grpc.ssl_channel_credentials(), interceptors=interceptors, options=options
4656
)
4757
else:
48-
channel = grpc_aio.insecure_channel(host_address, interceptors=interceptors)
58+
channel = grpc_aio.insecure_channel(
59+
host_address, interceptors=interceptors, options=options
60+
)
4961

5062
return channel

durabletask/client.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,7 @@ def __init__(
108108
log_formatter: Optional[logging.Formatter] = None,
109109
secure_channel: bool = False,
110110
interceptors: Optional[Sequence[shared.ClientInterceptor]] = None,
111+
channel_options: Optional[Sequence[tuple[str, Any]]] = None,
111112
):
112113
# If the caller provided metadata, we need to create a new interceptor for it and
113114
# add it to the list of interceptors.
@@ -121,7 +122,10 @@ def __init__(
121122
interceptors = None
122123

123124
channel = shared.get_grpc_channel(
124-
host_address=host_address, secure_channel=secure_channel, interceptors=interceptors
125+
host_address=host_address,
126+
secure_channel=secure_channel,
127+
interceptors=interceptors,
128+
options=channel_options,
125129
)
126130
self._stub = stubs.TaskHubSidecarServiceStub(channel)
127131
self._logger = shared.get_logger("client", log_handler, log_formatter)

durabletask/internal/shared.py

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -31,8 +31,8 @@ def get_default_host_address() -> str:
3131
Honors environment variables if present; otherwise defaults to localhost:4001.
3232
3333
Supported environment variables (checked in order):
34-
- DURABLETASK_GRPC_ENDPOINT (e.g., "localhost:4001", "grpcs://host:443")
35-
- DURABLETASK_GRPC_HOST and DURABLETASK_GRPC_PORT
34+
- DAPR_GRPC_ENDPOINT (e.g., "localhost:4001", "grpcs://host:443")
35+
- DAPR_GRPC_HOST/DAPR_RUNTIME_HOST and DAPR_GRPC_PORT
3636
"""
3737

3838
# Full endpoint overrides
@@ -54,7 +54,16 @@ def get_grpc_channel(
5454
host_address: Optional[str],
5555
secure_channel: bool = False,
5656
interceptors: Optional[Sequence[ClientInterceptor]] = None,
57+
options: Optional[Sequence[tuple[str, Any]]] = None,
5758
) -> grpc.Channel:
59+
"""create a grpc channel
60+
61+
Args:
62+
host_address: The host address of the gRPC server. If None, uses the default address (as defined in get_default_host_address above).
63+
secure_channel: Whether to use a secure channel (TLS/SSL). Defaults to False.
64+
interceptors: Optional sequence of client interceptors to apply to the channel.
65+
options: Optional sequence of gRPC channel options as (key, value) tuples. Keys defined in https://grpc.github.io/grpc/core/group__grpc__arg__keys.html
66+
"""
5867
if host_address is None:
5968
host_address = get_default_host_address()
6069

@@ -72,11 +81,10 @@ def get_grpc_channel(
7281
host_address = host_address[len(protocol) :]
7382
break
7483

75-
# Create the base channel
7684
if secure_channel:
77-
channel = grpc.secure_channel(host_address, grpc.ssl_channel_credentials())
85+
channel = grpc.secure_channel(host_address, grpc.ssl_channel_credentials(), options=options)
7886
else:
79-
channel = grpc.insecure_channel(host_address)
87+
channel = grpc.insecure_channel(host_address, options=options)
8088

8189
# Apply interceptors ONLY if they exist
8290
if interceptors:

durabletask/worker.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -223,13 +223,15 @@ def __init__(
223223
secure_channel: bool = False,
224224
interceptors: Optional[Sequence[shared.ClientInterceptor]] = None,
225225
concurrency_options: Optional[ConcurrencyOptions] = None,
226+
channel_options: Optional[Sequence[tuple[str, Any]]] = None,
226227
):
227228
self._registry = _Registry()
228229
self._host_address = host_address if host_address else shared.get_default_host_address()
229230
self._logger = shared.get_logger("worker", log_handler, log_formatter)
230231
self._shutdown = Event()
231232
self._is_running = False
232233
self._secure_channel = secure_channel
234+
self._channel_options = channel_options
233235

234236
# Use provided concurrency options or create default ones
235237
self._concurrency_options = (
@@ -306,7 +308,10 @@ def create_fresh_connection():
306308
current_stub = None
307309
try:
308310
current_channel = shared.get_grpc_channel(
309-
self._host_address, self._secure_channel, self._interceptors
311+
self._host_address,
312+
self._secure_channel,
313+
self._interceptors,
314+
options=self._channel_options,
310315
)
311316
current_stub = stubs.TaskHubSidecarServiceStub(current_channel)
312317
current_stub.Hello(empty_pb2.Empty())

tests/durabletask/test_client.py

Lines changed: 64 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from unittest.mock import ANY, patch
1+
from unittest.mock import patch
22

33
from durabletask.internal.grpc_interceptor import DefaultClientInterceptorImpl
44
from durabletask.internal.shared import get_default_host_address, get_grpc_channel
@@ -11,7 +11,9 @@
1111
def test_get_grpc_channel_insecure():
1212
with patch("grpc.insecure_channel") as mock_channel:
1313
get_grpc_channel(HOST_ADDRESS, False, interceptors=INTERCEPTORS)
14-
mock_channel.assert_called_once_with(HOST_ADDRESS)
14+
args, kwargs = mock_channel.call_args
15+
assert args[0] == HOST_ADDRESS
16+
assert "options" in kwargs and kwargs["options"] is None
1517

1618

1719
def test_get_grpc_channel_secure():
@@ -20,13 +22,18 @@ def test_get_grpc_channel_secure():
2022
patch("grpc.ssl_channel_credentials") as mock_credentials,
2123
):
2224
get_grpc_channel(HOST_ADDRESS, True, interceptors=INTERCEPTORS)
23-
mock_channel.assert_called_once_with(HOST_ADDRESS, mock_credentials.return_value)
25+
args, kwargs = mock_channel.call_args
26+
assert args[0] == HOST_ADDRESS
27+
assert args[1] == mock_credentials.return_value
28+
assert "options" in kwargs and kwargs["options"] is None
2429

2530

2631
def test_get_grpc_channel_default_host_address():
2732
with patch("grpc.insecure_channel") as mock_channel:
2833
get_grpc_channel(None, False, interceptors=INTERCEPTORS)
29-
mock_channel.assert_called_once_with(get_default_host_address())
34+
args, kwargs = mock_channel.call_args
35+
assert args[0] == get_default_host_address()
36+
assert "options" in kwargs and kwargs["options"] is None
3037

3138

3239
def test_get_grpc_channel_with_metadata():
@@ -35,7 +42,9 @@ def test_get_grpc_channel_with_metadata():
3542
patch("grpc.intercept_channel") as mock_intercept_channel,
3643
):
3744
get_grpc_channel(HOST_ADDRESS, False, interceptors=INTERCEPTORS)
38-
mock_channel.assert_called_once_with(HOST_ADDRESS)
45+
args, kwargs = mock_channel.call_args
46+
assert args[0] == HOST_ADDRESS
47+
assert "options" in kwargs and kwargs["options"] is None
3948
mock_intercept_channel.assert_called_once()
4049

4150
# Capture and check the arguments passed to intercept_channel()
@@ -54,40 +63,80 @@ def test_grpc_channel_with_host_name_protocol_stripping():
5463

5564
prefix = "grpc://"
5665
get_grpc_channel(prefix + host_name, interceptors=INTERCEPTORS)
57-
mock_insecure_channel.assert_called_with(host_name)
66+
args, kwargs = mock_insecure_channel.call_args
67+
assert args[0] == host_name
68+
assert "options" in kwargs and kwargs["options"] is None
5869

5970
prefix = "http://"
6071
get_grpc_channel(prefix + host_name, interceptors=INTERCEPTORS)
61-
mock_insecure_channel.assert_called_with(host_name)
72+
args, kwargs = mock_insecure_channel.call_args
73+
assert args[0] == host_name
74+
assert "options" in kwargs and kwargs["options"] is None
6275

6376
prefix = "HTTP://"
6477
get_grpc_channel(prefix + host_name, interceptors=INTERCEPTORS)
65-
mock_insecure_channel.assert_called_with(host_name)
78+
args, kwargs = mock_insecure_channel.call_args
79+
assert args[0] == host_name
80+
assert "options" in kwargs and kwargs["options"] is None
6681

6782
prefix = "GRPC://"
6883
get_grpc_channel(prefix + host_name, interceptors=INTERCEPTORS)
69-
mock_insecure_channel.assert_called_with(host_name)
84+
args, kwargs = mock_insecure_channel.call_args
85+
assert args[0] == host_name
86+
assert "options" in kwargs and kwargs["options"] is None
7087

7188
prefix = ""
7289
get_grpc_channel(prefix + host_name, interceptors=INTERCEPTORS)
73-
mock_insecure_channel.assert_called_with(host_name)
90+
args, kwargs = mock_insecure_channel.call_args
91+
assert args[0] == host_name
92+
assert "options" in kwargs and kwargs["options"] is None
7493

7594
prefix = "grpcs://"
7695
get_grpc_channel(prefix + host_name, interceptors=INTERCEPTORS)
77-
mock_secure_channel.assert_called_with(host_name, ANY)
96+
args, kwargs = mock_secure_channel.call_args
97+
assert args[0] == host_name
98+
assert "options" in kwargs and kwargs["options"] is None
7899

79100
prefix = "https://"
80101
get_grpc_channel(prefix + host_name, interceptors=INTERCEPTORS)
81-
mock_secure_channel.assert_called_with(host_name, ANY)
102+
args, kwargs = mock_secure_channel.call_args
103+
assert args[0] == host_name
104+
assert "options" in kwargs and kwargs["options"] is None
82105

83106
prefix = "HTTPS://"
84107
get_grpc_channel(prefix + host_name, interceptors=INTERCEPTORS)
85-
mock_secure_channel.assert_called_with(host_name, ANY)
108+
args, kwargs = mock_secure_channel.call_args
109+
assert args[0] == host_name
110+
assert "options" in kwargs and kwargs["options"] is None
86111

87112
prefix = "GRPCS://"
88113
get_grpc_channel(prefix + host_name, interceptors=INTERCEPTORS)
89-
mock_secure_channel.assert_called_with(host_name, ANY)
114+
args, kwargs = mock_secure_channel.call_args
115+
assert args[0] == host_name
116+
assert "options" in kwargs and kwargs["options"] is None
90117

91118
prefix = ""
92119
get_grpc_channel(prefix + host_name, True, interceptors=INTERCEPTORS)
93-
mock_secure_channel.assert_called_with(host_name, ANY)
120+
args, kwargs = mock_secure_channel.call_args
121+
assert args[0] == host_name
122+
assert "options" in kwargs and kwargs["options"] is None
123+
124+
125+
def test_sync_channel_passes_base_options_and_max_lengths():
126+
base_options = [
127+
("grpc.max_send_message_length", 1234),
128+
("grpc.max_receive_message_length", 5678),
129+
("grpc.primary_user_agent", "durabletask-tests"),
130+
]
131+
with patch("grpc.insecure_channel") as mock_channel:
132+
get_grpc_channel(HOST_ADDRESS, False, options=base_options)
133+
# Ensure called with options kwarg
134+
assert mock_channel.call_count == 1
135+
args, kwargs = mock_channel.call_args
136+
assert args[0] == HOST_ADDRESS
137+
assert "options" in kwargs
138+
opts = kwargs["options"]
139+
# Check our base options made it through
140+
assert ("grpc.max_send_message_length", 1234) in opts
141+
assert ("grpc.max_receive_message_length", 5678) in opts
142+
assert ("grpc.primary_user_agent", "durabletask-tests") in opts

0 commit comments

Comments
 (0)