Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 43 additions & 19 deletions iohblade/problem.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,11 @@
]

import copy
import re

from .solution import Solution
from .utils import TimeoutException

import re


def simplify_subprocess_error(stderr: str, solution=None):
"""
Expand Down Expand Up @@ -68,6 +67,7 @@ def simplify_subprocess_error(stderr: str, solution=None):

def evaluate_in_subprocess(problem, conn, solution):
"""Evaluate a solution in a dedicated virtual environment."""
proc = None
try:
env_path = problem._env_path
python_bin = problem._python_bin
Expand Down Expand Up @@ -105,20 +105,35 @@ def evaluate_in_subprocess(problem, conn, solution):
repo_root = Path(__file__).resolve().parents[1]
env["PYTHONPATH"] = f"{repo_root}{os.pathsep}" + env.get("PYTHONPATH", "")

proc = subprocess.Popen(
[str(python_bin), str(script_path)],
env=env,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
text=True,
)
try:
res = subprocess.run(
[str(python_bin), str(script_path)],
check=True,
env=env,
capture_output=True,
text=True,
stdout, stderr = proc.communicate(timeout=problem.eval_timeout)
except subprocess.TimeoutExpired:
proc.kill()
stdout, stderr = proc.communicate()
conn.send(
{
"error": f"Evaluation timed out after {problem.eval_timeout} seconds.",
"stdout": stdout,
"stderr": stderr,
}
)
with open(result_pickle, "rb") as f:
result = cloudpickle.load(f)
conn.send({"result": result, "stdout": res.stdout, "stderr": res.stderr})
except subprocess.CalledProcessError as e:
error_msg = simplify_subprocess_error(e.stderr, solution)
conn.send({"error": error_msg, "stdout": e.stdout, "stderr": e.stderr})
return

if proc.returncode != 0:
error_msg = simplify_subprocess_error(stderr, solution)
conn.send({"error": error_msg, "stdout": stdout, "stderr": stderr})
return

with open(result_pickle, "rb") as f:
result = cloudpickle.load(f)
conn.send({"result": result, "stdout": stdout, "stderr": stderr})

except Exception as e:
tb = traceback.extract_tb(e.__traceback__)[-1]
Expand All @@ -141,6 +156,9 @@ def evaluate_in_subprocess(problem, conn, solution):
}
)
finally:
if proc and proc.poll() is None:
proc.kill()
proc.communicate()
conn.close()


Expand Down Expand Up @@ -233,6 +251,9 @@ def __call__(self, solution: Solution, logger=None):
stderr = ""
self._last_stdout = ""
self._last_stderr = ""
process: multiprocessing.Process | None = None
parent_conn = None
child_conn = None
try:
self._ensure_env()
(
Expand All @@ -243,7 +264,7 @@ def __call__(self, solution: Solution, logger=None):
target=evaluate_in_subprocess, args=(self, child_conn, solution)
)
process.start()
process.join(timeout=self.eval_timeout)
process.join(timeout=self.eval_timeout + 1)

if process.is_alive():
raise TimeoutException(
Expand Down Expand Up @@ -294,11 +315,14 @@ def __call__(self, solution: Solution, logger=None):
error=f"{e}",
)
finally:
try:
process.terminate()
if process is not None:
if process.is_alive():
process.kill()
process.join()
except Exception:
pass
if parent_conn is not None:
parent_conn.close()
if child_conn is not None:
child_conn.close()

self._last_stdout = stdout
self._last_stderr = stderr
Expand Down
44 changes: 44 additions & 0 deletions tests/unit/test_problem.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import platform
import subprocess
import time
from unittest.mock import MagicMock

Expand Down Expand Up @@ -43,6 +45,21 @@ def to_dict(self):
return {}


class HangingProblem(Problem):
def get_prompt(self):
return "Problem prompt"

def evaluate(self, s):
time.sleep(5)
return s

def test(self, s):
return s

def to_dict(self):
return {}


def test_problem_abstract_methods():
dp = DummyProblem(name="dummy")
assert dp.name == "dummy"
Expand All @@ -59,3 +76,30 @@ def test_problem_timeout():
sol = sp(sol)
# We expect a TimeoutException or similar
assert "timed out" in str(sol.feedback)


def _active_run_eval_processes():
try:
if platform.system() == "Windows":
output = subprocess.check_output(["tasklist"], text=True)
else:
output = subprocess.check_output(["ps", "-eo", "pid,cmd"], text=True)
except (FileNotFoundError, subprocess.CalledProcessError):
return []

return [line for line in output.splitlines() if "run_eval.py" in line]


def test_timeout_cleans_child_processes():
before = _active_run_eval_processes()

problem = HangingProblem(eval_timeout=1)
sol = Solution()
sol = problem(sol)

assert "timed out" in str(sol.feedback)

time.sleep(0.5)
after = _active_run_eval_processes()

assert not after or after == before