|
18 | 18 | import grpc |
19 | 19 | from absl import app, flags, logging |
20 | 20 |
|
21 | | -from compiler_gym.service import connection |
| 21 | +from compiler_gym.service import CompilerGymServiceContext, connection |
22 | 22 | from compiler_gym.service.compilation_session import CompilationSession |
23 | 23 | from compiler_gym.service.proto import compiler_gym_service_pb2_grpc |
24 | 24 | from compiler_gym.service.runtime.compiler_gym_service import CompilerGymService |
@@ -51,6 +51,9 @@ def _shutdown_handler(signal_number, stack_frame): # pragma: no cover |
51 | 51 |
|
52 | 52 | def create_and_run_compiler_gym_service( |
53 | 53 | compilation_session_type: Type[CompilationSession], |
| 54 | + compiler_gym_service_context_type: Type[ |
| 55 | + CompilerGymServiceContext |
| 56 | + ] = CompilerGymServiceContext, |
54 | 57 | ): # pragma: no cover |
55 | 58 | """Create and run an RPC service for the given compilation session. |
56 | 59 |
|
@@ -92,52 +95,57 @@ def main(argv): |
92 | 95 | logging.get_absl_handler().use_absl_log_file() |
93 | 96 | logging.set_verbosity(dbg.get_logging_level()) |
94 | 97 |
|
95 | | - # Create the service. |
96 | | - server = grpc.server( |
97 | | - futures.ThreadPoolExecutor(max_workers=FLAGS.rpc_service_threads), |
98 | | - options=connection.GRPC_CHANNEL_OPTIONS, |
99 | | - ) |
100 | | - service = CompilerGymService( |
101 | | - working_directory=working_dir, |
102 | | - compilation_session_type=compilation_session_type, |
103 | | - ) |
104 | | - compiler_gym_service_pb2_grpc.add_CompilerGymServiceServicer_to_server( |
105 | | - service, server |
106 | | - ) |
107 | | - |
108 | | - address = f"0.0.0.0:{FLAGS.port}" if FLAGS.port else "0.0.0.0:0" |
109 | | - port = server.add_insecure_port(address) |
110 | | - |
111 | | - with atomic_file_write(working_dir / "port.txt", fileobj=True, mode="w") as f: |
112 | | - f.write(str(port)) |
113 | | - |
114 | | - with atomic_file_write(working_dir / "pid.txt", fileobj=True, mode="w") as f: |
115 | | - f.write(str(os.getpid())) |
116 | | - |
117 | | - logging.info( |
118 | | - "Service %s listening on %d, PID = %d", working_dir, port, os.getpid() |
119 | | - ) |
120 | | - |
121 | | - server.start() |
122 | | - |
123 | | - # Block on the RPC service in a separate thread. This enables the |
124 | | - # current thread to handle the shutdown routine. |
125 | | - server_thread = Thread(target=server.wait_for_termination) |
126 | | - server_thread.start() |
127 | | - |
128 | | - # Block until the shutdown signal is received. |
129 | | - shutdown_signal.wait() |
130 | | - logging.info("Shutting down the RPC service") |
131 | | - server.stop(60).wait() |
132 | | - server_thread.join() |
133 | | - logging.info("Service closed") |
134 | | - |
135 | | - if len(service.sessions): |
136 | | - print( |
137 | | - "ERROR: Killing a service with", |
138 | | - plural(len(service.session), "active session", "active sessions"), |
139 | | - file=sys.stderr, |
| 98 | + with compiler_gym_service_context_type(working_dir) as context: |
| 99 | + # Create the service. |
| 100 | + server = grpc.server( |
| 101 | + futures.ThreadPoolExecutor(max_workers=FLAGS.rpc_service_threads), |
| 102 | + options=connection.GRPC_CHANNEL_OPTIONS, |
| 103 | + ) |
| 104 | + service = CompilerGymService( |
| 105 | + compilation_session_type=compilation_session_type, |
| 106 | + context=context, |
| 107 | + ) |
| 108 | + compiler_gym_service_pb2_grpc.add_CompilerGymServiceServicer_to_server( |
| 109 | + service, server |
| 110 | + ) |
| 111 | + |
| 112 | + address = f"0.0.0.0:{FLAGS.port}" if FLAGS.port else "0.0.0.0:0" |
| 113 | + port = server.add_insecure_port(address) |
| 114 | + |
| 115 | + with atomic_file_write( |
| 116 | + working_dir / "port.txt", fileobj=True, mode="w" |
| 117 | + ) as f: |
| 118 | + f.write(str(port)) |
| 119 | + |
| 120 | + with atomic_file_write( |
| 121 | + working_dir / "pid.txt", fileobj=True, mode="w" |
| 122 | + ) as f: |
| 123 | + f.write(str(os.getpid())) |
| 124 | + |
| 125 | + logging.info( |
| 126 | + "Service %s listening on %d, PID = %d", working_dir, port, os.getpid() |
140 | 127 | ) |
141 | | - sys.exit(6) |
| 128 | + |
| 129 | + server.start() |
| 130 | + |
| 131 | + # Block on the RPC service in a separate thread. This enables the |
| 132 | + # current thread to handle the shutdown routine. |
| 133 | + server_thread = Thread(target=server.wait_for_termination) |
| 134 | + server_thread.start() |
| 135 | + |
| 136 | + # Block until the shutdown signal is received. |
| 137 | + shutdown_signal.wait() |
| 138 | + logging.info("Shutting down the RPC service") |
| 139 | + server.stop(60).wait() |
| 140 | + server_thread.join() |
| 141 | + logging.info("Service closed") |
| 142 | + |
| 143 | + if len(service.sessions): |
| 144 | + print( |
| 145 | + "ERROR: Killing a service with", |
| 146 | + plural(len(service.session), "active session", "active sessions"), |
| 147 | + file=sys.stderr, |
| 148 | + ) |
| 149 | + sys.exit(6) |
142 | 150 |
|
143 | 151 | app.run(main) |
0 commit comments