Skip to content

Commit 58fd4e3

Browse files
committed
WIP
1 parent 3d8edf6 commit 58fd4e3

File tree

3 files changed

+72
-54
lines changed

3 files changed

+72
-54
lines changed

compiler_gym/service/CompilerGymServiceContext.h

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,9 @@ namespace compiler_gym {
1717
* sessions. An instance of this class is passed to every new
1818
* CompilationSession.
1919
*
20-
* You may subclass CompilerGymServiceContext to add additional mutable state.
21-
* The subclass .
20+
* You may subclass CompilerGymServiceContext to add additional mutable state,
21+
* or startup and shutdown routines. When overriding methods, subclasses should
22+
* call the parent class implementation first.
2223
*
2324
* \code{.cpp}
2425
*

compiler_gym/service/compiler_gym_service_context.py

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -16,14 +16,15 @@ class CompilerGymServiceContext:
1616
sessions. An instance of this class is passed to every new
1717
CompilationSession.
1818
19-
You may subclass CompilerGymServiceContext to add additional mutable state.
20-
The subclass .
19+
You may subclass CompilerGymServiceContext to add additional mutable state,
20+
or startup and shutdown routines. When overriding methods, subclasses should
21+
call the parent class implementation first.
2122
2223
.. code-block:: python
2324
24-
from compiler_gym.service import CompilationSession
25-
from compiler_gym.service import CompilerGymServiceContext
26-
from compiler_gym.service import runtime
25+
from compiler_gym.service import CompilationSession from
26+
compiler_gym.service import CompilerGymServiceContext from
27+
compiler_gym.service import runtime
2728
2829
class MyServiceContext(CompilerGymServiceContext):
2930
...
@@ -58,3 +59,11 @@ def shutdown(self) -> None:
5859
service will terminate with a nonzero error code.
5960
"""
6061
logger.debug("Closing compiler service context")
62+
63+
def __enter__(self) -> "CompilerGymServiceContext":
64+
"""Support 'with' syntax."""
65+
return self
66+
67+
def __exit__(self, *args):
68+
"""Support 'with' syntax."""
69+
self.shutdown()

compiler_gym/service/runtime/create_and_run_compiler_gym_service.py

Lines changed: 55 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
import grpc
1919
from absl import app, flags, logging
2020

21-
from compiler_gym.service import connection
21+
from compiler_gym.service import CompilerGymServiceContext, connection
2222
from compiler_gym.service.compilation_session import CompilationSession
2323
from compiler_gym.service.proto import compiler_gym_service_pb2_grpc
2424
from compiler_gym.service.runtime.compiler_gym_service import CompilerGymService
@@ -51,6 +51,9 @@ def _shutdown_handler(signal_number, stack_frame): # pragma: no cover
5151

5252
def create_and_run_compiler_gym_service(
5353
compilation_session_type: Type[CompilationSession],
54+
compiler_gym_service_context_type: Type[
55+
CompilerGymServiceContext
56+
] = CompilerGymServiceContext,
5457
): # pragma: no cover
5558
"""Create and run an RPC service for the given compilation session.
5659
@@ -92,52 +95,57 @@ def main(argv):
9295
logging.get_absl_handler().use_absl_log_file()
9396
logging.set_verbosity(dbg.get_logging_level())
9497

95-
# Create the service.
96-
server = grpc.server(
97-
futures.ThreadPoolExecutor(max_workers=FLAGS.rpc_service_threads),
98-
options=connection.GRPC_CHANNEL_OPTIONS,
99-
)
100-
service = CompilerGymService(
101-
working_directory=working_dir,
102-
compilation_session_type=compilation_session_type,
103-
)
104-
compiler_gym_service_pb2_grpc.add_CompilerGymServiceServicer_to_server(
105-
service, server
106-
)
107-
108-
address = f"0.0.0.0:{FLAGS.port}" if FLAGS.port else "0.0.0.0:0"
109-
port = server.add_insecure_port(address)
110-
111-
with atomic_file_write(working_dir / "port.txt", fileobj=True, mode="w") as f:
112-
f.write(str(port))
113-
114-
with atomic_file_write(working_dir / "pid.txt", fileobj=True, mode="w") as f:
115-
f.write(str(os.getpid()))
116-
117-
logging.info(
118-
"Service %s listening on %d, PID = %d", working_dir, port, os.getpid()
119-
)
120-
121-
server.start()
122-
123-
# Block on the RPC service in a separate thread. This enables the
124-
# current thread to handle the shutdown routine.
125-
server_thread = Thread(target=server.wait_for_termination)
126-
server_thread.start()
127-
128-
# Block until the shutdown signal is received.
129-
shutdown_signal.wait()
130-
logging.info("Shutting down the RPC service")
131-
server.stop(60).wait()
132-
server_thread.join()
133-
logging.info("Service closed")
134-
135-
if len(service.sessions):
136-
print(
137-
"ERROR: Killing a service with",
138-
plural(len(service.session), "active session", "active sessions"),
139-
file=sys.stderr,
98+
with compiler_gym_service_context_type(working_dir) as context:
99+
# Create the service.
100+
server = grpc.server(
101+
futures.ThreadPoolExecutor(max_workers=FLAGS.rpc_service_threads),
102+
options=connection.GRPC_CHANNEL_OPTIONS,
103+
)
104+
service = CompilerGymService(
105+
compilation_session_type=compilation_session_type,
106+
context=context,
107+
)
108+
compiler_gym_service_pb2_grpc.add_CompilerGymServiceServicer_to_server(
109+
service, server
110+
)
111+
112+
address = f"0.0.0.0:{FLAGS.port}" if FLAGS.port else "0.0.0.0:0"
113+
port = server.add_insecure_port(address)
114+
115+
with atomic_file_write(
116+
working_dir / "port.txt", fileobj=True, mode="w"
117+
) as f:
118+
f.write(str(port))
119+
120+
with atomic_file_write(
121+
working_dir / "pid.txt", fileobj=True, mode="w"
122+
) as f:
123+
f.write(str(os.getpid()))
124+
125+
logging.info(
126+
"Service %s listening on %d, PID = %d", working_dir, port, os.getpid()
140127
)
141-
sys.exit(6)
128+
129+
server.start()
130+
131+
# Block on the RPC service in a separate thread. This enables the
132+
# current thread to handle the shutdown routine.
133+
server_thread = Thread(target=server.wait_for_termination)
134+
server_thread.start()
135+
136+
# Block until the shutdown signal is received.
137+
shutdown_signal.wait()
138+
logging.info("Shutting down the RPC service")
139+
server.stop(60).wait()
140+
server_thread.join()
141+
logging.info("Service closed")
142+
143+
if len(service.sessions):
144+
print(
145+
"ERROR: Killing a service with",
146+
plural(len(service.session), "active session", "active sessions"),
147+
file=sys.stderr,
148+
)
149+
sys.exit(6)
142150

143151
app.run(main)

0 commit comments

Comments
 (0)