Skip to content
3 changes: 0 additions & 3 deletions debug_gym/gym/terminals/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,16 +15,13 @@ def select_terminal(

logger = logger or DebugGymLogger("debug-gym")
terminal_type = terminal_config["type"]
docker_only = ["base_image", "setup_commands"]
match terminal_type:
case "docker":
terminal_class = DockerTerminal
case "kubernetes":
terminal_class = KubernetesTerminal
case "local":
terminal_class = LocalTerminal
if any(cfg in terminal_config for cfg in docker_only):
logger.warning("Ignoring Docker-only parameters for local terminal.")
case _:
raise ValueError(f"Unknown terminal {terminal_type}")

Expand Down
2 changes: 0 additions & 2 deletions debug_gym/gym/terminals/docker.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@ def __init__(
working_dir: str | None = None,
session_commands: list[str] | None = None,
env_vars: dict[str, str] | None = None,
include_os_env_vars: bool = False,
logger: DebugGymLogger | None = None,
# Docker-specific parameters
base_image: str | None = None,
Expand All @@ -40,7 +39,6 @@ def __init__(
working_dir=working_dir,
session_commands=session_commands,
env_vars=env_vars,
include_os_env_vars=include_os_env_vars,
logger=logger,
**kwargs,
)
Expand Down
37 changes: 27 additions & 10 deletions debug_gym/gym/terminals/kubernetes.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import uuid
from pathlib import Path

from jinja2 import Template
from kubernetes import client, config, stream, watch
from kubernetes.client.rest import ApiException
from kubernetes.stream.ws_client import ERROR_CHANNEL
Expand All @@ -16,6 +17,7 @@
stop_after_attempt,
wait_random_exponential,
)
from yaml import dump, safe_load

from debug_gym.gym.terminals.shell_session import ShellSession
from debug_gym.gym.terminals.terminal import DISABLE_ECHO_COMMAND, Terminal
Expand Down Expand Up @@ -217,15 +219,15 @@ def __init__(
working_dir: str | None = None,
session_commands: list[str] | None = None,
env_vars: dict[str, str] | None = None,
include_os_env_vars: bool = False,
logger: DebugGymLogger | None = None,
setup_commands: list[str] | None = None,
# Kubernetes-specific parameters
setup_commands: list[str] | None = None,
pod_name: str | None = None,
base_image: str | None = None,
registry: str = "",
namespace: str = "default",
kube_config: str | None = None,
kube_context: str | None = None,
extra_labels: dict | None = None,
pod_spec_kwargs: dict = None,
**kwargs,
Expand All @@ -234,7 +236,6 @@ def __init__(
working_dir=working_dir,
session_commands=session_commands,
env_vars=env_vars,
include_os_env_vars=include_os_env_vars,
logger=logger,
**kwargs,
)
Expand All @@ -247,22 +248,32 @@ def __init__(
self._pod_name = pod_name
self.pod_spec_kwargs = pod_spec_kwargs or {}
user = _clean_for_kubernetes(os.environ.get("USER", "unknown"))
self.labels = {"app": "debug-gym", "component": "terminal", "user": user} | (
extra_labels or {}
)
self.labels = {"app": "dbg-gym", "user": user} | (extra_labels or {})
self._pod = None

# Initialize Kubernetes client
self.kube_config = kube_config
self.kube_context = kube_context
if self.kube_config == "incluster":
self.kube_config = None
config.load_incluster_config()
# For in-cluster kubectl access, pass Kubernetes service environment variables
# This enables kubectl to auto-discover the service account credentials
for key in ("KUBERNETES_SERVICE_HOST", "KUBERNETES_SERVICE_PORT"):
if key in os.environ:
self.env_vars.setdefault(key, os.environ[key])
else:
self.kube_config = self.kube_config or os.environ.get(
"KUBECONFIG", "~/.kube/config"
)
self.kube_config = os.path.expanduser(self.kube_config)
config.load_kube_config(self.kube_config)
config.load_kube_config(self.kube_config, self.kube_context)
self.env_vars.setdefault("KUBECONFIG", self.kube_config)

# Ensure helper binaries such as kubectl can be discovered even when
# host environment variables are not inherited.
if "PATH" in os.environ:
self.env_vars.setdefault("PATH", os.environ["PATH"])

self.k8s_client = client.CoreV1Api()
atexit.register(self.close)
Expand Down Expand Up @@ -315,9 +326,9 @@ def pod(self):
@property
def default_shell_command(self) -> list[str]:
"""Expects the pod to have bash installed."""
kubeconfig = f"--kubeconfig {self.kube_config}" if self.kube_config else ""
kubeconfig = f"--kubeconfig {self.kube_config} " if self.kube_config else ""
bash_cmd = "/bin/bash --noprofile --norc --noediting"
return f"kubectl {kubeconfig} exec -it {self.pod.name} -n {self.pod.namespace} -- {bash_cmd}"
return f"kubectl {kubeconfig}exec -it {self.pod.name} -c main -n {self.pod.namespace} -- {bash_cmd}"

def new_shell_session(self):
if not self.pod.is_running():
Expand Down Expand Up @@ -430,6 +441,12 @@ def setup_pod(self) -> None:
f"Setting up pod {pod_name} with image: {self.registry}{self.base_image}"
)

# Render pod_spec_kwargs as a Jinja2 template, replace variables, then load as dict.
pod_spec_yaml = dump(self.pod_spec_kwargs)
pod_spec_template = Template(pod_spec_yaml)
rendered_yaml = pod_spec_template.render(os.environ)
pod_spec_kwargs = safe_load(rendered_yaml)

# Create pod specification for Kubernetes.
pod_body = {
"apiVersion": "v1",
Expand Down Expand Up @@ -462,7 +479,7 @@ def setup_pod(self) -> None:
},
}
],
**self.pod_spec_kwargs, # e.g., nodeSelector, tolerations
**pod_spec_kwargs, # e.g., nodeSelector, tolerations
},
}

Expand Down
23 changes: 23 additions & 0 deletions debug_gym/gym/terminals/local.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,33 @@

from debug_gym.gym.terminals.shell_session import ShellSession
from debug_gym.gym.terminals.terminal import Terminal
from debug_gym.logger import DebugGymLogger


class LocalTerminal(Terminal):

def __init__(
self,
working_dir: str | None = None,
session_commands: list[str] | None = None,
env_vars: dict[str, str] | None = None,
logger: DebugGymLogger | None = None,
# Local-specific parameters
include_os_env_vars: bool = True,
**kwargs,
):
env_vars = env_vars or {}
if include_os_env_vars:
env_vars = env_vars | dict(os.environ)

super().__init__(
working_dir=working_dir,
session_commands=session_commands,
env_vars=env_vars,
logger=logger,
**kwargs,
)

@property
def working_dir(self):
"""Lazy initialization of the working directory."""
Expand Down
11 changes: 8 additions & 3 deletions debug_gym/gym/terminals/shell_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,10 +70,15 @@ def start(self, command=None, read_until=None):

# Prepare entrypoint, combining session commands and command if provided
# For example: `bin/bash -c "session_command1 && session_command2 && pdb"`
entrypoint = self.shell_command
if command:
command = " && ".join(self.session_commands + [command])
entrypoint = f'{self.shell_command} -c "{command}"'
# Build command list: split shell_command, then add ["-c", command]
# Keep the command string intact so constructs like $(which ...) reach the target shell
cmd_list = shlex.split(self.shell_command) + ["-c", command]
entrypoint = f"{self.shell_command} -c {command!r}"
else:
cmd_list = shlex.split(self.shell_command)
entrypoint = self.shell_command

self.logger.debug(f"Starting {self} with entrypoint: {entrypoint}")

Expand All @@ -91,7 +96,7 @@ def start(self, command=None, read_until=None):
termios.tcsetattr(_client, termios.TCSANOW, attrs)

self.process = subprocess.Popen(
shlex.split(entrypoint),
cmd_list,
env=self.env_vars,
cwd=self.working_dir,
stdin=_client,
Expand Down
13 changes: 6 additions & 7 deletions debug_gym/gym/terminals/terminal.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import atexit
import os
import tempfile
from abc import ABC, abstractmethod
from pathlib import Path
Expand All @@ -14,18 +13,15 @@ class Terminal(ABC):

def __init__(
self,
working_dir: str = None,
session_commands: list[str] = None,
env_vars: dict[str, str] = None,
include_os_env_vars: bool = True,
working_dir: str | None = None,
session_commands: list[str] | None = None,
env_vars: dict[str, str] | None = None,
logger: DebugGymLogger | None = None,
**kwargs,
):
self.logger = logger or DebugGymLogger("debug-gym")
self.session_commands = session_commands or []
self.env_vars = env_vars or {}
if include_os_env_vars:
self.env_vars = self.env_vars | dict(os.environ)
# Clean up output by disabling terminal prompt and colors
self.env_vars["NO_COLOR"] = "1" # disable colors
self.env_vars["PYTHONSTARTUP"] = "" # prevent Python from loading startup files
Expand All @@ -35,6 +31,9 @@ def __init__(
self._working_dir = working_dir
self.sessions = []

if kwargs:
self.logger.warning(f"Ignoring unknown parameters: {kwargs}")

@property
def working_dir(self):
"""Lazy initialization of the working directory."""
Expand Down
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -35,4 +35,6 @@ dev = [
"pytest-xdist",
"pytest-timeout",
"pytest-env",
"isort",
"black",
]
88 changes: 86 additions & 2 deletions tests/gym/terminals/test_kubernetes.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,11 +35,24 @@ def is_kubernetes_available():
def test_kubernetes_terminal_init():
terminal = KubernetesTerminal(base_image="ubuntu:latest")
assert terminal.session_commands == []
assert terminal.env_vars == {
expected_base_env = {
"NO_COLOR": "1",
"PS1": DEFAULT_PS1,
"PYTHONSTARTUP": "",
}
for key, value in expected_base_env.items():
assert terminal.env_vars[key] == value

assert terminal.env_vars["PATH"] == os.environ.get("PATH")
if terminal.kube_config:
assert terminal.env_vars["KUBECONFIG"] == terminal.kube_config
else:
assert "KUBECONFIG" not in terminal.env_vars

extra_env_keys = set(terminal.env_vars) - (
set(expected_base_env) | {"PATH", "KUBECONFIG"}
)
assert not extra_env_keys
assert os.path.basename(terminal.working_dir).startswith("Terminal-")
assert terminal.base_image == "ubuntu:latest"
assert terminal.namespace == "default"
Expand Down Expand Up @@ -84,7 +97,15 @@ def test_kubernetes_terminal_init_with_params(tmp_path):
)
assert terminal.working_dir == working_dir
assert terminal.session_commands == session_commands
assert terminal.env_vars == env_vars | {"NO_COLOR": "1", "PS1": DEFAULT_PS1}
assert terminal.env_vars["ENV_VAR"] == "value"
assert terminal.env_vars["NO_COLOR"] == "1"
assert terminal.env_vars["PS1"] == DEFAULT_PS1
assert terminal.env_vars["PYTHONSTARTUP"] == ""
assert terminal.env_vars["PATH"] == os.environ.get("PATH")
if terminal.kube_config:
assert terminal.env_vars["KUBECONFIG"] == terminal.kube_config
else:
assert "KUBECONFIG" not in terminal.env_vars
assert terminal.base_image == base_image

# Create pod.
Expand All @@ -98,6 +119,69 @@ def test_kubernetes_terminal_init_with_params(tmp_path):
assert terminal._pod is None


@if_kubernetes_available
def test_kubernetes_terminal_init_with_pod_specs(tmp_path):
working_dir = str(tmp_path)
# set an environment variable to use in the pod spec
os.environ["HOSTNAME"] = "minikube"
pod_spec_kwargs = {
"affinity": {
"nodeAffinity": {
"requiredDuringSchedulingIgnoredDuringExecution": {
"nodeSelectorTerms": [
{
"matchExpressions": [
{
"key": "kubernetes.io/hostname",
"operator": "In",
"values": ["{{HOSTNAME}}"],
}
]
}
]
}
}
},
"tolerations": [
{
"key": "kubernetes.azure.com/scalesetpriority",
"operator": "Equal",
"value": "spot",
"effect": "NoSchedule",
},
{
"key": "CriticalAddonsOnly",
"operator": "Equal",
"value": "true",
"effect": "NoSchedule",
},
],
}

terminal = KubernetesTerminal(
working_dir=working_dir,
pod_spec_kwargs=pod_spec_kwargs,
kube_context="minikube",
base_image="ubuntu:latest",
)

terminal.pod # Create pod.
assert (
terminal.pod.pod_body["spec"]["tolerations"] == pod_spec_kwargs["tolerations"]
)
# Make sure environment variable was replaced in the pod spec.
spec = terminal.pod.pod_body["spec"]
node_affinity = spec["affinity"]["nodeAffinity"]
required = node_affinity["requiredDuringSchedulingIgnoredDuringExecution"]
term = required["nodeSelectorTerms"][0]
match_expression = term["matchExpressions"][0]
assert match_expression["values"] == [os.environ["HOSTNAME"]]

# Close pod.
terminal.close()
assert terminal._pod is None


@if_kubernetes_available
@pytest.mark.parametrize(
"command",
Expand Down