Skip to content

Commit 0c56aac

Browse files
authored
Add healthcheck support for JetStream (#90)
* Add healthcheck support for JetStream
* fix indentation
* fix pylint unit test
* use pyink to reformat generated protos
1 parent a223df9 commit 0c56aac

File tree

5 files changed: +81 −3 lines changed

jetstream/core/orchestrator.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -885,3 +885,17 @@ async def Decode( # pylint: disable=invalid-overridden-method
885885
)
886886
# Reset buffer after flushed.
887887
buffered_response_list = []
888+
889+
async def HealthCheck( # pylint: disable=invalid-overridden-method
890+
self,
891+
request: jetstream_pb2.HealthCheckRequest,
892+
context: Optional[grpc.aio.ServicerContext] = None,
893+
) -> jetstream_pb2.HealthCheckResponse:
894+
"""HealthCheck."""
895+
if context is None:
896+
logging.warning(
897+
"LLM orchestrator is being used in offline test mode, and will not"
898+
" respond to gRPC queries - only direct function calls."
899+
)
900+
is_live = self._driver.live
901+
return jetstream_pb2.HealthCheckResponse(is_live=is_live)

jetstream/core/proto/jetstream.proto

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,8 @@ package jetstream_proto;
2121
service Orchestrator {
2222
// Query LLM to generate text or tokens.
2323
rpc Decode(DecodeRequest) returns (stream DecodeResponse) {}
24+
// Checks if the model server is live.
25+
rpc HealthCheck(HealthCheckRequest) returns (HealthCheckResponse) {}
2426
}
2527

2628
message DecodeRequest {
@@ -74,4 +76,11 @@ message DecodeResponse {
7476
}
7577
reserved 1;
7678
// Next ID: 4
79+
}
80+
81+
message HealthCheckRequest {}
82+
83+
message HealthCheckResponse {
84+
// Denotes whether the model server is live
85+
bool is_live = 1;
7786
}

jetstream/core/proto/jetstream_pb2.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828

2929

3030
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
31-
b'\n$jetstream/core/proto/jetstream.proto\x12\x0fjetstream_proto"\xa7\x02\n\rDecodeRequest\x12\x15\n\rsession_cache\x18\x01 \x01(\t\x12\x10\n\x08priority\x18\x03 \x01(\x05\x12\x12\n\nmax_tokens\x18\x04 \x01(\x05\x12\x42\n\x0ctext_content\x18\x05 \x01(\x0b\x32*.jetstream_proto.DecodeRequest.TextContentH\x00\x12\x44\n\rtoken_content\x18\x06 \x01(\x0b\x32+.jetstream_proto.DecodeRequest.TokenContentH\x00\x1a\x1b\n\x0bTextContent\x12\x0c\n\x04text\x18\x01 \x01(\t\x1a!\n\x0cTokenContent\x12\x11\n\ttoken_ids\x18\x01 \x03(\x05\x42\t\n\x07\x63ontentJ\x04\x08\x02\x10\x03"\xcb\x02\n\x0e\x44\x65\x63odeResponse\x12I\n\x0finitial_content\x18\x02 \x01(\x0b\x32..jetstream_proto.DecodeResponse.InitialContentH\x00\x12G\n\x0estream_content\x18\x03 \x01(\x0b\x32-.jetstream_proto.DecodeResponse.StreamContentH\x00\x1a\x10\n\x0eInitialContent\x1a\x81\x01\n\rStreamContent\x12\x45\n\x07samples\x18\x01 \x03(\x0b\x32\x34.jetstream_proto.DecodeResponse.StreamContent.Sample\x1a)\n\x06Sample\x12\x0c\n\x04text\x18\x01 \x01(\t\x12\x11\n\ttoken_ids\x18\x02 \x03(\x05\x42\t\n\x07\x63ontentJ\x04\x08\x01\x10\x02\x32]\n\x0cOrchestrator\x12M\n\x06\x44\x65\x63ode\x12\x1e.jetstream_proto.DecodeRequest\x1a\x1f.jetstream_proto.DecodeResponse"\x00\x30\x01\x62\x06proto3'
31+
b'\n$jetstream/core/proto/jetstream.proto\x12\x0fjetstream_proto"\xa7\x02\n\rDecodeRequest\x12\x15\n\rsession_cache\x18\x01 \x01(\t\x12\x10\n\x08priority\x18\x03 \x01(\x05\x12\x12\n\nmax_tokens\x18\x04 \x01(\x05\x12\x42\n\x0ctext_content\x18\x05 \x01(\x0b\x32*.jetstream_proto.DecodeRequest.TextContentH\x00\x12\x44\n\rtoken_content\x18\x06 \x01(\x0b\x32+.jetstream_proto.DecodeRequest.TokenContentH\x00\x1a\x1b\n\x0bTextContent\x12\x0c\n\x04text\x18\x01 \x01(\t\x1a!\n\x0cTokenContent\x12\x11\n\ttoken_ids\x18\x01 \x03(\x05\x42\t\n\x07\x63ontentJ\x04\x08\x02\x10\x03"\xcb\x02\n\x0e\x44\x65\x63odeResponse\x12I\n\x0finitial_content\x18\x02 \x01(\x0b\x32..jetstream_proto.DecodeResponse.InitialContentH\x00\x12G\n\x0estream_content\x18\x03 \x01(\x0b\x32-.jetstream_proto.DecodeResponse.StreamContentH\x00\x1a\x10\n\x0eInitialContent\x1a\x81\x01\n\rStreamContent\x12\x45\n\x07samples\x18\x01 \x03(\x0b\x32\x34.jetstream_proto.DecodeResponse.StreamContent.Sample\x1a)\n\x06Sample\x12\x0c\n\x04text\x18\x01 \x01(\t\x12\x11\n\ttoken_ids\x18\x02 \x03(\x05\x42\t\n\x07\x63ontentJ\x04\x08\x01\x10\x02"\x14\n\x12HealthCheckRequest"&\n\x13HealthCheckResponse\x12\x0f\n\x07is_live\x18\x01 \x01(\x08\x32\xb9\x01\n\x0cOrchestrator\x12M\n\x06\x44\x65\x63ode\x12\x1e.jetstream_proto.DecodeRequest\x1a\x1f.jetstream_proto.DecodeResponse"\x00\x30\x01\x12Z\n\x0bHealthCheck\x12#.jetstream_proto.HealthCheckRequest\x1a$.jetstream_proto.HealthCheckResponse"\x00\x62\x06proto3'
3232
)
3333

3434
_globals = globals()
@@ -52,6 +52,10 @@
5252
_globals["_DECODERESPONSE_STREAMCONTENT"]._serialized_end = 670
5353
_globals["_DECODERESPONSE_STREAMCONTENT_SAMPLE"]._serialized_start = 629
5454
_globals["_DECODERESPONSE_STREAMCONTENT_SAMPLE"]._serialized_end = 670
55-
_globals["_ORCHESTRATOR"]._serialized_start = 689
56-
_globals["_ORCHESTRATOR"]._serialized_end = 782
55+
_globals["_HEALTHCHECKREQUEST"]._serialized_start = 689
56+
_globals["_HEALTHCHECKREQUEST"]._serialized_end = 709
57+
_globals["_HEALTHCHECKRESPONSE"]._serialized_start = 711
58+
_globals["_HEALTHCHECKRESPONSE"]._serialized_end = 749
59+
_globals["_ORCHESTRATOR"]._serialized_start = 752
60+
_globals["_ORCHESTRATOR"]._serialized_end = 937
5761
# @@protoc_insertion_point(module_scope)

jetstream/core/proto/jetstream_pb2_grpc.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,11 @@ def __init__(self, channel):
3434
request_serializer=jetstream_dot_core_dot_proto_dot_jetstream__pb2.DecodeRequest.SerializeToString,
3535
response_deserializer=jetstream_dot_core_dot_proto_dot_jetstream__pb2.DecodeResponse.FromString,
3636
)
37+
self.HealthCheck = channel.unary_unary(
38+
"/jetstream_proto.Orchestrator/HealthCheck",
39+
request_serializer=jetstream_dot_core_dot_proto_dot_jetstream__pb2.HealthCheckRequest.SerializeToString,
40+
response_deserializer=jetstream_dot_core_dot_proto_dot_jetstream__pb2.HealthCheckResponse.FromString,
41+
)
3742

3843

3944
class OrchestratorServicer(object):
@@ -45,6 +50,12 @@ def Decode(self, request, context):
4550
context.set_details("Method not implemented!")
4651
raise NotImplementedError("Method not implemented!")
4752

53+
def HealthCheck(self, request, context):
54+
"""Checks if the model server is live."""
55+
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
56+
context.set_details("Method not implemented!")
57+
raise NotImplementedError("Method not implemented!")
58+
4859

4960
def add_OrchestratorServicer_to_server(servicer, server):
5061
rpc_method_handlers = {
@@ -53,6 +64,11 @@ def add_OrchestratorServicer_to_server(servicer, server):
5364
request_deserializer=jetstream_dot_core_dot_proto_dot_jetstream__pb2.DecodeRequest.FromString,
5465
response_serializer=jetstream_dot_core_dot_proto_dot_jetstream__pb2.DecodeResponse.SerializeToString,
5566
),
67+
"HealthCheck": grpc.unary_unary_rpc_method_handler(
68+
servicer.HealthCheck,
69+
request_deserializer=jetstream_dot_core_dot_proto_dot_jetstream__pb2.HealthCheckRequest.FromString,
70+
response_serializer=jetstream_dot_core_dot_proto_dot_jetstream__pb2.HealthCheckResponse.SerializeToString,
71+
),
5672
}
5773
generic_handler = grpc.method_handlers_generic_handler(
5874
"jetstream_proto.Orchestrator", rpc_method_handlers
@@ -92,3 +108,32 @@ def Decode(
92108
timeout,
93109
metadata,
94110
)
111+
112+
@staticmethod
113+
def HealthCheck(
114+
request,
115+
target,
116+
options=(),
117+
channel_credentials=None,
118+
call_credentials=None,
119+
insecure=False,
120+
compression=None,
121+
wait_for_ready=None,
122+
timeout=None,
123+
metadata=None,
124+
):
125+
return grpc.experimental.unary_unary(
126+
request,
127+
target,
128+
"/jetstream_proto.Orchestrator/HealthCheck",
129+
jetstream_dot_core_dot_proto_dot_jetstream__pb2.HealthCheckRequest.SerializeToString,
130+
jetstream_dot_core_dot_proto_dot_jetstream__pb2.HealthCheckResponse.FromString,
131+
options,
132+
channel_credentials,
133+
insecure,
134+
call_credentials,
135+
compression,
136+
wait_for_ready,
137+
timeout,
138+
metadata,
139+
)

jetstream/tests/core/test_server.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,12 @@ async def test_server(
8282
) as channel:
8383
stub = jetstream_pb2_grpc.OrchestratorStub(channel)
8484

85+
healthcheck_request = jetstream_pb2.HealthCheckRequest()
86+
healthcheck_response = stub.HealthCheck(healthcheck_request)
87+
healthcheck_response = await healthcheck_response
88+
89+
assert healthcheck_response.is_live is True
90+
8591
# The string representation of np.array([[65, 66]]), [2] will be prepended
8692
# as BOS
8793
text = "AB"

0 commit comments

Comments
 (0)