diff --git a/docs/guides/execution.md b/docs/guides/execution.md index 6f2c0063..b29904e9 100644 --- a/docs/guides/execution.md +++ b/docs/guides/execution.md @@ -53,6 +53,7 @@ The packager support matrix is described below: | SkypilotExecutor | run.Packager, run.GitArchivePackager, run.PatternPackager, run.HybridPackager | | DGXCloudExecutor | run.Packager, run.GitArchivePackager, run.PatternPackager, run.HybridPackager | | LeptonExecutor | run.Packager, run.GitArchivePackager, run.PatternPackager, run.HybridPackager | +| PyTorchJobExecutor | run.Packager | `run.Packager` is a passthrough base packager. @@ -293,6 +294,34 @@ def your_dgx_executor(nodes: int, gpus_per_node: int, container_image: str): For a complete end-to-end example using DGX Cloud with NeMo, refer to the [NVIDIA DGX Cloud NeMo End-to-End Workflow Example](https://docs.nvidia.com/dgx-cloud/run-ai/latest/nemo-e2e-example.html). +#### PyTorchJobExecutor + +The `PyTorchJobExecutor` integrates with the [Kubeflow Training Operator](https://github.com/kubeflow/training-operator) to run distributed PyTorchJobs on any Kubernetes cluster. It submits PyTorchJob CRDs directly via the Kubernetes API — no `kubectl` or separate tooling is required for job submission, status checks, or cancellation. Note that log retrieval (`fetch_logs`) currently shells out to `kubectl`, so `kubectl` must be available on `PATH` if you need logs. + +Kubernetes configuration is loaded automatically: local kubeconfig is tried first, falling back to in-cluster config when running inside a pod. 
+ +Here's an example configuration: + +```python +executor = run.PyTorchJobExecutor( + namespace="runai-nemo-ci", + image="nvcr.io/nvidian/nemo:nightly", + num_workers=2, # Worker replicas; a Master replica is always added + nproc_per_node=8, # Maps to spec.nprocPerNode + gpus_per_node=8, + cpu_requests="16", + memory_requests="64Gi", + volumes=[ + {"name": "model-cache", "persistentVolumeClaim": {"claimName": "nemo-ci-datasets-project-nkf5l"}} + ], + volume_mounts=[{"name": "model-cache", "mountPath": "/nemo-workspace"}], + labels={"app": "nemo-ci-training"}, + env_vars={"NCCL_DEBUG": "INFO"}, +) +``` + +`cancel(wait=True)` polls until both the PyTorchJob CR and all associated pods are fully terminated; it returns `True` once cleanup is complete, or `False` if the CR or its pods are still present when the timeout expires. + #### LeptonExecutor + The `LeptonExecutor` integrates with an NVIDIA DGX Cloud Lepton cluster's Python SDK to launch distributed jobs. It uses API calls behind the Lepton SDK to authenticate, identify the target node group and resource shapes, and submit the job specification which will be launched as a batch job on the cluster. 
diff --git a/local/example.py b/local/example.py new file mode 100644 index 00000000..59007ff5 --- /dev/null +++ b/local/example.py @@ -0,0 +1,62 @@ +import time + +from nemo_run.core.execution.pytorchjob import PyTorchJobExecutor + +EXPECTED_LOG_CONTENT = "NEMO_TEST_OK" + +e = PyTorchJobExecutor( + namespace="runai-nemo-ci", + image="nvcr.io/nvidian/nemo:nightly", + num_workers=2, + nproc_per_node=8, + gpus_per_node=8, + cpu_requests="16", + memory_requests="64Gi", + volumes=[ + { + "name": "model-cache", + "persistentVolumeClaim": {"claimName": "nemo-ci-datasets-project-nkf5l"}, + } + ], + volume_mounts=[{"name": "model-cache", "mountPath": "/nemo-workspace"}], + labels={"app": "nemo-ci-training"}, +) + +# Script: print the sentinel, then sleep so we can read logs and cancel cleanly +cmd = [ + "/bin/bash", + "-c", + f"echo 'print(\"{EXPECTED_LOG_CONTENT}\"); import time; time.sleep(300)' > /tmp/test.py && " + "torchrun --nnodes=$PET_NNODES --nproc_per_node=$PET_NPROC_PER_NODE " + "--node_rank=$RANK --master_addr=$MASTER_ADDR --master_port=$MASTER_PORT /tmp/test.py", +] + +# ── Launch and wait until RUNNING ──────────────────────────────────────────── +job_name, state = e.launch("nemo-ci-training", cmd, wait=True, timeout=300) +print(f"Launched: {job_name}, state: {state}") + +# ── Fetch logs and verify sentinel ──────────────────────────────────────────── +print("Polling logs until sentinel appears (up to 2 min)...") +logs = [] +deadline = time.time() + 120 +while time.time() < deadline: + logs = list(e.fetch_logs(job_name, stream=False, lines=50)) + if any(EXPECTED_LOG_CONTENT in line for line in logs): + break + print(f" waiting for sentinel ({len(logs)} lines so far)...") + time.sleep(5) + +print(f" received {len(logs)} lines") +for line in logs[:5]: + print(f" | {line}") + +assert any(EXPECTED_LOG_CONTENT in line for line in logs), ( + f"Expected '{EXPECTED_LOG_CONTENT}' not found in logs.\nGot: {logs}" +) +print(f"✓ Log sentinel '{EXPECTED_LOG_CONTENT}' 
verified") + +# ── Cancel and wait for full cleanup ───────────────────────────────────────── +print("Cancelling job and waiting for cleanup...") +cleaned = e.cancel(job_name, wait=True, timeout=120) +assert cleaned, "Cleanup failed — pods or CR still present after timeout" +print("Full cycle complete without kubectl") diff --git a/nemo_run/__init__.py b/nemo_run/__init__.py index 04f56916..5f57b911 100644 --- a/nemo_run/__init__.py +++ b/nemo_run/__init__.py @@ -24,6 +24,7 @@ from nemo_run.core.execution.base import Executor, ExecutorMacros, import_executor from nemo_run.core.execution.dgxcloud import DGXCloudExecutor from nemo_run.core.execution.docker import DockerExecutor +from nemo_run.core.execution.pytorchjob import PyTorchJobExecutor from nemo_run.core.execution.launcher import FaultTolerance, SlurmRay, SlurmTemplate, Torchrun from nemo_run.core.execution.lepton import LeptonExecutor from nemo_run.core.execution.local import LocalExecutor @@ -66,6 +67,7 @@ "Packager", "Partial", "Plugin", + "PyTorchJobExecutor", "run", "Script", "SkypilotExecutor", diff --git a/nemo_run/core/execution/__init__.py b/nemo_run/core/execution/__init__.py index 7c787a16..089b5172 100644 --- a/nemo_run/core/execution/__init__.py +++ b/nemo_run/core/execution/__init__.py @@ -16,6 +16,7 @@ from nemo_run.core.execution.dgxcloud import DGXCloudExecutor from nemo_run.core.execution.lepton import LeptonExecutor from nemo_run.core.execution.local import LocalExecutor +from nemo_run.core.execution.pytorchjob import PyTorchJobExecutor from nemo_run.core.execution.skypilot import SkypilotExecutor from nemo_run.core.execution.slurm import SlurmExecutor @@ -25,4 +26,5 @@ "SkypilotExecutor", "DGXCloudExecutor", "LeptonExecutor", + "PyTorchJobExecutor", ] diff --git a/nemo_run/core/execution/pytorchjob.py b/nemo_run/core/execution/pytorchjob.py new file mode 100644 index 00000000..7ea1fe6e --- /dev/null +++ b/nemo_run/core/execution/pytorchjob.py @@ -0,0 +1,346 @@ +# SPDX-FileCopyrightText: 
Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +import os +import subprocess +import time +from dataclasses import dataclass, field +from enum import Enum +from typing import Any, Iterable, Optional + +from kubernetes import client, config +from kubernetes.client.rest import ApiException + +from nemo_run.core.execution.base import Executor, ExecutorMacros +from nemo_run.core.packaging.base import Packager + +logger = logging.getLogger(__name__) + +GROUP = "kubeflow.org" +VERSION = "v1" +PLURAL = "pytorchjobs" +KIND = "PyTorchJob" + + +class PyTorchJobState(Enum): + CREATED = "Created" + RUNNING = "Running" + SUCCEEDED = "Succeeded" + FAILED = "Failed" + UNKNOWN = "Unknown" + + +@dataclass(kw_only=True) +class PyTorchJobExecutor(Executor): + """ + Dataclass to configure a PyTorchJob Executor for the Kubeflow Training Operator on Kubernetes. + + Submits distributed PyTorchJob CRDs to a Kubernetes cluster running the Kubeflow Training + Operator. Kubernetes configuration is loaded automatically (local kubeconfig with in-cluster + fallback). 
+ """ + + namespace: str = "default" + image: str = "" + num_workers: int = 1 + nproc_per_node: int = 1 + gpus_per_node: Optional[int] = None + cpu_requests: Optional[str] = None + memory_requests: Optional[str] = None + cpu_limits: Optional[str] = None + memory_limits: Optional[str] = None + volume_mounts: list[dict[str, Any]] = field(default_factory=list) + volumes: list[dict[str, Any]] = field(default_factory=list) + labels: dict[str, Any] = field(default_factory=dict) + annotations: dict[str, Any] = field(default_factory=dict) + restart_policy: str = "OnFailure" + image_pull_secrets: list[str] = field(default_factory=list) + spec_kwargs: dict[str, Any] = field(default_factory=dict) + container_kwargs: dict[str, Any] = field(default_factory=dict) + + def __post_init__(self): + try: + config.load_kube_config() + except Exception as original_exc: + try: + config.load_incluster_config() + except Exception: + raise original_exc + self._custom_objects_api = client.CustomObjectsApi() + self._core_v1_api = client.CoreV1Api() + + def assign( + self, + exp_id: str, + exp_dir: str, + task_id: str, + task_dir: str, + ) -> None: + self.experiment_id = exp_id + self.experiment_dir = exp_dir + self.job_name = task_id + self.job_dir = os.path.join(exp_dir, task_dir) + + def nnodes(self) -> int: + return 1 + self.num_workers + + def get_job_body(self, name: str, command: list[str]) -> dict: + """Build the PyTorchJob CRD manifest dict.""" + resources: dict[str, Any] = {} + limits: dict[str, Any] = {} + requests: dict[str, Any] = {} + + if self.gpus_per_node is not None: + limits["nvidia.com/gpu"] = str(self.gpus_per_node) + requests["nvidia.com/gpu"] = str(self.gpus_per_node) + if self.cpu_requests: + requests["cpu"] = self.cpu_requests + if self.memory_requests: + requests["memory"] = self.memory_requests + if self.cpu_limits: + limits["cpu"] = self.cpu_limits + if self.memory_limits: + limits["memory"] = self.memory_limits + if limits: + resources["limits"] = limits + if 
requests: + resources["requests"] = requests + + env = [{"name": k, "value": v} for k, v in self.env_vars.items()] + + container: dict[str, Any] = { + "name": "pytorch", + "image": self.image, + "command": command, + "env": env, + } + if self.volume_mounts: + container["volumeMounts"] = self.volume_mounts + if resources: + container["resources"] = resources + container.update(self.container_kwargs) + + pod_spec: dict[str, Any] = {"containers": [container]} + if self.volumes: + pod_spec["volumes"] = self.volumes + if self.image_pull_secrets: + pod_spec["imagePullSecrets"] = [{"name": s} for s in self.image_pull_secrets] + + template_metadata: dict[str, Any] = {} + if self.labels: + template_metadata["labels"] = self.labels + if self.annotations: + template_metadata["annotations"] = self.annotations + + replica_spec: dict[str, Any] = { + "restartPolicy": self.restart_policy, + "template": { + "metadata": template_metadata, + "spec": pod_spec, + }, + } + + spec: dict[str, Any] = { + "nprocPerNode": str(self.nproc_per_node), + "pytorchReplicaSpecs": { + "Master": { + "replicas": 1, + **replica_spec, + }, + "Worker": { + "replicas": self.num_workers, + **replica_spec, + }, + }, + **self.spec_kwargs, + } + + return { + "apiVersion": f"{GROUP}/{VERSION}", + "kind": KIND, + "metadata": { + "name": name, + "namespace": self.namespace, + "labels": self.labels, + "annotations": self.annotations, + }, + "spec": spec, + } + + def launch( + self, + name: str, + cmd: list[str], + wait: bool = False, + timeout: int = 300, + poll_interval: int = 10, + ) -> tuple[str, PyTorchJobState]: + name = name.replace("_", "-").replace(".", "-").lower() + job_body = self.get_job_body(name, cmd) + try: + self._custom_objects_api.create_namespaced_custom_object( + group=GROUP, + version=VERSION, + namespace=self.namespace, + plural=PLURAL, + body=job_body, + ) + except ApiException as e: + if e.status == 409: + raise RuntimeError( + f"PyTorchJob {name} already exists in namespace 
{self.namespace}" + ) from e + raise + + if not wait: + return name, PyTorchJobState.CREATED + + deadline = time.time() + timeout + state = PyTorchJobState.CREATED + while time.time() < deadline: + state = self.status(name) or PyTorchJobState.UNKNOWN + if state == PyTorchJobState.RUNNING: + return name, state + if state in (PyTorchJobState.SUCCEEDED, PyTorchJobState.FAILED): + return name, state + time.sleep(poll_interval) + + raise RuntimeError( + f"PyTorchJob {name} did not reach RUNNING within {timeout}s, last state: {state}" + ) + + def status(self, job_name: str) -> Optional[PyTorchJobState]: + try: + resp = self._custom_objects_api.get_namespaced_custom_object( + group=GROUP, + version=VERSION, + namespace=self.namespace, + plural=PLURAL, + name=job_name, + ) + except ApiException as e: + if e.status == 404: + return None + logger.warning("API error getting status for %s: %s", job_name, e) + return None + + conditions = resp.get("status", {}).get("conditions", []) + state_map = { + "Running": PyTorchJobState.RUNNING, + "Succeeded": PyTorchJobState.SUCCEEDED, + "Failed": PyTorchJobState.FAILED, + } + for cond in reversed(conditions): + if cond.get("status") == "True" and cond.get("type") in state_map: + return state_map[cond["type"]] + return PyTorchJobState.UNKNOWN + + def fetch_logs( + self, + job_name: str, + stream: bool = False, + lines: int = 100, + timeout: int = 60, + ) -> Iterable[str]: + label_selector = f"training.kubeflow.org/job-name={job_name}" + cmd = [ + "kubectl", + "logs", + "-l", + label_selector, + "-n", + self.namespace, + "--tail", + str(lines), + ] + if stream: + cmd.append("-f") + proc = subprocess.Popen(cmd, stdout=subprocess.PIPE, text=True, bufsize=1) + try: + for line in iter(proc.stdout.readline, ""): + if line: + yield line + if proc.poll() is not None: + break + except Exception as e: + logger.error("Error streaming logs: %s", e) + finally: + proc.terminate() + proc.wait(timeout=2) + else: + result = subprocess.run(cmd, 
capture_output=True, text=True, timeout=timeout) + yield from result.stdout.splitlines() + + def cancel( + self, + job_name: str, + wait: bool = False, + timeout: int = 300, + poll_interval: int = 5, + ) -> Optional[bool]: + try: + self._custom_objects_api.delete_namespaced_custom_object( + group=GROUP, + version=VERSION, + namespace=self.namespace, + plural=PLURAL, + name=job_name, + ) + except ApiException as e: + if e.status == 404: + logger.info("PyTorchJob %s already deleted", job_name) + return None + raise + + if not wait: + return None + + label_selector = f"training.kubeflow.org/job-name={job_name}" + deadline = time.time() + timeout + + while time.time() < deadline: + time.sleep(poll_interval) + + # Check if CR is gone + try: + self._custom_objects_api.get_namespaced_custom_object( + group=GROUP, + version=VERSION, + namespace=self.namespace, + plural=PLURAL, + name=job_name, + ) + # CR still present + continue + except ApiException as e: + if e.status != 404: + continue + + # CR is gone; check pods + pods = self._core_v1_api.list_namespaced_pod( + namespace=self.namespace, + label_selector=label_selector, + ) + if len(pods.items) == 0: + return True + + return False + + def package(self, packager: Packager, job_name: str) -> None: + pass + + def macro_values(self) -> Optional[ExecutorMacros]: + return None diff --git a/nemo_run/run/torchx_backend/schedulers/pytorchjob.py b/nemo_run/run/torchx_backend/schedulers/pytorchjob.py new file mode 100644 index 00000000..f24a3a9f --- /dev/null +++ b/nemo_run/run/torchx_backend/schedulers/pytorchjob.py @@ -0,0 +1,255 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +import logging +import os +import shutil +import tempfile +from dataclasses import dataclass +from datetime import datetime +from pathlib import Path +from typing import Any, Iterable, Optional + +import fiddle as fdl +import fiddle._src.experimental.dataclasses as fdl_dc +from torchx.schedulers.api import ( + AppDryRunInfo, + DescribeAppResponse, + ListAppResponse, + Scheduler, + Stream, + split_lines, +) +from torchx.specs import AppDef, AppState, ReplicaStatus, Role, RoleStatus, runopts + +from nemo_run.config import get_nemorun_home +from nemo_run.core.execution.base import Executor +from nemo_run.core.execution.pytorchjob import PyTorchJobExecutor, PyTorchJobState +from nemo_run.core.serialization.zlib_json import ZlibJSONSerializer +from nemo_run.run.torchx_backend.schedulers.api import SchedulerMixin + +PYTORCHJOB_JOB_DIRS = os.path.join(get_nemorun_home(), ".pytorchjob_jobs.json") + +PYTORCHJOB_STATES: dict[Optional[PyTorchJobState], AppState] = { + PyTorchJobState.CREATED: AppState.SUBMITTED, + PyTorchJobState.RUNNING: AppState.RUNNING, + PyTorchJobState.SUCCEEDED: AppState.SUCCEEDED, + PyTorchJobState.FAILED: AppState.FAILED, + PyTorchJobState.UNKNOWN: AppState.PENDING, + None: AppState.PENDING, +} + +log = logging.getLogger(__name__) + + +@dataclass +class PyTorchJobRequest: + """Wrapper around the TorchX AppDef and the PyTorchJobExecutor.""" + + app: AppDef + executor: PyTorchJobExecutor + cmd: list[str] + name: str + + +class PyTorchJobScheduler(SchedulerMixin, Scheduler[dict[str, str]]): # type: ignore + def __init__(self, 
session_name: str) -> None: + super().__init__("pytorchjob", session_name) + + def _run_opts(self) -> runopts: + opts = runopts() + opts.add( + "job_dir", + type_=str, + help="The directory to place the job code and outputs.", + ) + return opts + + def _submit_dryrun( # type: ignore + self, + app: AppDef, + cfg: Executor, + ) -> AppDryRunInfo[PyTorchJobRequest]: + assert isinstance(cfg, PyTorchJobExecutor), ( + f"{cfg.__class__} not supported for PyTorchJob scheduler." + ) + executor = cfg + assert len(app.roles) == 1, "Only single-role apps are supported." + role = app.roles[0] + values = cfg.macro_values() + if values: + role = values.apply(role) + + cmd = [role.entrypoint] + role.args + req = PyTorchJobRequest(app=app, executor=executor, cmd=cmd, name=role.name) + + return AppDryRunInfo( + req, + lambda r: f"PyTorchJob for app: {r.app.name}, cmd: {' '.join(r.cmd)}", + ) + + def schedule(self, dryrun_info: AppDryRunInfo[PyTorchJobRequest]) -> str: + req = dryrun_info.request + executor = req.executor + + executor.package(executor.packager, job_name=executor.job_name) + + job_name, status = executor.launch(name=req.name, cmd=req.cmd) + if not job_name: + raise RuntimeError("Failed scheduling run on PyTorchJob: no job_name returned") + + role_name = req.app.roles[0].name + experiment_id = getattr(executor, "experiment_id", "pytorchjob_experiment") + app_id = f"{experiment_id}___{role_name}___{job_name}" + + _save_job_dir(app_id, job_status=status.value, executor=executor, job_name=job_name) + return app_id + + def describe(self, app_id: str) -> Optional[DescribeAppResponse]: + stored_data = _get_job_dirs() + job_info = stored_data.get(app_id) + parts = app_id.split("___") + role_name = parts[1] if len(parts) > 1 else app_id + roles = [Role(name=role_name, image="", num_replicas=1)] + roles_statuses = [ + RoleStatus( + role_name, + replicas=[ + ReplicaStatus(id=0, role=role_name, state=AppState.SUBMITTED, hostname="") + ], + ) + ] + + if not job_info: + return None 
+ + executor: PyTorchJobExecutor = job_info.get("executor", None) # type: ignore + if not executor: + return None + + # Use stored job_name to avoid re-splitting app_id (handles role names with '___') + job_name = job_info.get("job_name") or parts[-1] + pj_state = executor.status(job_name) + app_state = PYTORCHJOB_STATES.get(pj_state, AppState.PENDING) + roles_statuses[0].replicas[0].state = app_state + + return DescribeAppResponse( + app_id=app_id, + roles=roles, + roles_statuses=roles_statuses, + state=app_state, + msg="", + ) + + def log_iter( + self, + app_id: str, + role_name: str, + k: int = 0, + regex: Optional[str] = None, + since: Optional[datetime] = None, + until: Optional[datetime] = None, + should_tail: bool = False, + streams: Optional[Stream] = None, + ) -> Iterable[str]: + stored_data = _get_job_dirs() + job_info = stored_data.get(app_id) + if not job_info: + return [] + job_name = job_info.get("job_name") or app_id.split("___")[-1] + executor: Optional[PyTorchJobExecutor] = job_info.get("executor", None) # type: ignore + if not executor: + return [] + + logs = executor.fetch_logs(job_name=job_name, stream=should_tail) + if isinstance(logs, str): + if len(logs) == 0: + logs = [] + else: + logs = split_lines(logs) + + return logs + + def _cancel_existing(self, app_id: str) -> None: + stored_data = _get_job_dirs() + job_info = stored_data.get(app_id) + if not job_info: + return None + job_name = job_info.get("job_name") or app_id.split("___")[-1] + executor: PyTorchJobExecutor = job_info.get("executor", None) # type: ignore + if not executor: + return None + executor.cancel(job_name) + + def list(self) -> list[ListAppResponse]: ... 
+ + def _validate(self, app: AppDef, scheduler: str) -> None: + pass + + +def create_scheduler(session_name: str, **kwargs: Any) -> PyTorchJobScheduler: + return PyTorchJobScheduler(session_name=session_name) + + +def _save_job_dir( + app_id: str, job_status: str, executor: PyTorchJobExecutor, job_name: str = "" +) -> None: + original_apps = {} + job_dirs_path = os.path.join(get_nemorun_home(), ".pytorchjob_jobs.json") + os.makedirs(os.path.dirname(job_dirs_path), exist_ok=True) + if not os.path.isfile(job_dirs_path): + Path(job_dirs_path).touch() + + serializer = ZlibJSONSerializer() + with open(job_dirs_path, "r+") as f: + try: + original_apps = json.load(f) + except Exception: + original_apps = {} + + app = { + "job_status": job_status, + "job_name": job_name, + "executor": serializer.serialize( + fdl_dc.convert_dataclasses_to_configs(executor, allow_post_init=True) + ), + } + original_apps[app_id] = app + + with tempfile.NamedTemporaryFile(mode="w+", delete=False) as fp: + json.dump(original_apps, fp) + temp_path = fp.name + + f.close() + shutil.move(temp_path, job_dirs_path) + + +def _get_job_dirs() -> dict[str, dict[str, Any]]: + job_dirs_path = os.path.join(get_nemorun_home(), ".pytorchjob_jobs.json") + if not os.path.isfile(job_dirs_path): + return {} + with open(job_dirs_path, "r") as f: + data = json.load(f) + + serializer = ZlibJSONSerializer() + for app in data.values(): + try: + app["executor"] = fdl.build(serializer.deserialize(app["executor"])) + except Exception as e: + log.debug("Failed to deserialize executor: %s", e) + continue + + return data diff --git a/pyproject.toml b/pyproject.toml index 4e0d00f9..85ec87e9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -48,6 +48,7 @@ docker_persistent = "nemo_run.run.torchx_backend.schedulers.docker:create_schedu dgx_cloud = "nemo_run.run.torchx_backend.schedulers.dgxcloud:create_scheduler" lepton = "nemo_run.run.torchx_backend.schedulers.lepton:create_scheduler" skypilot_jobs = 
"nemo_run.run.torchx_backend.schedulers.skypilot_jobs:create_scheduler" +pytorchjob = "nemo_run.run.torchx_backend.schedulers.pytorchjob:create_scheduler" [project.optional-dependencies] skypilot = [ diff --git a/test/core/execution/test_pytorchjob.py b/test/core/execution/test_pytorchjob.py new file mode 100644 index 00000000..266bd31d --- /dev/null +++ b/test/core/execution/test_pytorchjob.py @@ -0,0 +1,381 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from unittest.mock import MagicMock, patch + +import pytest +from kubernetes.client.rest import ApiException + +from nemo_run.core.execution.pytorchjob import PyTorchJobExecutor, PyTorchJobState + + +class TestPyTorchJobExecutor: + @pytest.fixture + def mock_k8s_clients(self): + with ( + patch("nemo_run.core.execution.pytorchjob.config.load_kube_config"), + patch("nemo_run.core.execution.pytorchjob.client.CustomObjectsApi") as mock_custom, + patch("nemo_run.core.execution.pytorchjob.client.CoreV1Api") as mock_core, + ): + yield mock_custom.return_value, mock_core.return_value + + @pytest.fixture + def executor(self, mock_k8s_clients): + return PyTorchJobExecutor( + image="nvcr.io/nvidian/nemo:nightly", + num_workers=2, + gpus_per_node=8, + ) + + # ── Initialization ────────────────────────────────────────────────────────── + + def test_executor_defaults(self, executor): + assert executor.namespace == "default" + assert executor.restart_policy == "OnFailure" + assert executor.nproc_per_node == 1 + + def test_kubeconfig_fallback_to_incluster(self): + with ( + patch("nemo_run.core.execution.pytorchjob.config.load_kube_config") as mock_load, + patch( + "nemo_run.core.execution.pytorchjob.config.load_incluster_config" + ) as mock_incluster, + patch("nemo_run.core.execution.pytorchjob.client.CustomObjectsApi"), + patch("nemo_run.core.execution.pytorchjob.client.CoreV1Api"), + ): + mock_load.side_effect = Exception("no kubeconfig") + PyTorchJobExecutor(image="test:latest") + mock_incluster.assert_called_once() + + def test_kubeconfig_both_fail_raises(self): + with ( + patch("nemo_run.core.execution.pytorchjob.config.load_kube_config") as mock_load, + patch( + "nemo_run.core.execution.pytorchjob.config.load_incluster_config" + ) as mock_incluster, + patch("nemo_run.core.execution.pytorchjob.client.CustomObjectsApi"), + patch("nemo_run.core.execution.pytorchjob.client.CoreV1Api"), + ): + mock_load.side_effect = Exception("no kubeconfig") + mock_incluster.side_effect = 
Exception("not in cluster") + with pytest.raises(Exception, match="no kubeconfig"): + PyTorchJobExecutor(image="test:latest") + + def test_nnodes(self, executor): + assert executor.nnodes() == 3 # 1 Master + 2 Workers + + def test_nproc_per_node(self, mock_k8s_clients): + e = PyTorchJobExecutor(image="test:latest", nproc_per_node=4) + assert e.nproc_per_node == 4 + + def test_assign(self, executor): + executor.assign("exp-1", "/tmp/exp", "task-0", "task-0") + assert executor.experiment_id == "exp-1" + assert executor.experiment_dir == "/tmp/exp" + assert executor.job_dir == "/tmp/exp/task-0" + + # ── Manifest generation ────────────────────────────────────────────────────── + + def test_get_job_body_structure(self, executor): + body = executor.get_job_body("my-job", ["/bin/bash", "-c", "echo hi"]) + assert body["apiVersion"] == "kubeflow.org/v1" + assert body["kind"] == "PyTorchJob" + assert body["metadata"]["name"] == "my-job" + spec = body["spec"] + assert spec["nprocPerNode"] == "1" + assert "Master" in spec["pytorchReplicaSpecs"] + assert "Worker" in spec["pytorchReplicaSpecs"] + assert spec["pytorchReplicaSpecs"]["Master"]["replicas"] == 1 + assert spec["pytorchReplicaSpecs"]["Worker"]["replicas"] == 2 + + def test_get_job_body_resources(self, executor): + executor.cpu_requests = "16" + executor.memory_requests = "64Gi" + body = executor.get_job_body("my-job", ["python", "train.py"]) + container = body["spec"]["pytorchReplicaSpecs"]["Master"]["template"]["spec"]["containers"][ + 0 + ] + resources = container["resources"] + assert resources["limits"]["nvidia.com/gpu"] == "8" + assert resources["requests"]["cpu"] == "16" + assert resources["requests"]["memory"] == "64Gi" + + def test_get_job_body_no_gpu(self, mock_k8s_clients): + e = PyTorchJobExecutor(image="test:latest", gpus_per_node=None) + body = e.get_job_body("cpu-job", ["python", "train.py"]) + container = body["spec"]["pytorchReplicaSpecs"]["Master"]["template"]["spec"]["containers"][ + 0 + ] + 
        # NOTE(review): tail of the preceding test (its `def` is above this
        # window) — presumably it builds a job body with no GPU request and
        # checks that no "nvidia.com/gpu" entry leaks into the container
        # resources. Confirm against the full file.
        resources = container.get("resources", {})
        limits = resources.get("limits", {})
        requests = resources.get("requests", {})
        assert "nvidia.com/gpu" not in limits
        assert "nvidia.com/gpu" not in requests

    def test_get_job_body_volumes(self, mock_k8s_clients):
        """volumes/volume_mounts propagate into the Master pod spec and container."""
        e = PyTorchJobExecutor(
            image="test:latest",
            volumes=[{"name": "data", "persistentVolumeClaim": {"claimName": "my-pvc"}}],
            volume_mounts=[{"name": "data", "mountPath": "/data"}],
        )
        body = e.get_job_body("vol-job", ["echo", "hi"])
        spec = body["spec"]["pytorchReplicaSpecs"]["Master"]["template"]["spec"]
        assert spec["volumes"] == [
            {"name": "data", "persistentVolumeClaim": {"claimName": "my-pvc"}}
        ]
        container = spec["containers"][0]
        assert container["volumeMounts"] == [{"name": "data", "mountPath": "/data"}]

    def test_get_job_body_env_vars(self, mock_k8s_clients):
        """env_vars are rendered as {"name": ..., "value": ...} entries on the container."""
        e = PyTorchJobExecutor(
            image="test:latest",
            env_vars={"MY_VAR": "hello", "OTHER": "world"},
        )
        body = e.get_job_body("env-job", ["echo"])
        container = body["spec"]["pytorchReplicaSpecs"]["Master"]["template"]["spec"]["containers"][
            0
        ]
        env_names = {item["name"]: item["value"] for item in container["env"]}
        assert env_names["MY_VAR"] == "hello"
        assert env_names["OTHER"] == "world"

    def test_get_job_body_labels_annotations(self, mock_k8s_clients):
        """labels/annotations go on the CR metadata; labels also on the pod template."""
        e = PyTorchJobExecutor(
            image="test:latest",
            labels={"app": "my-app"},
            annotations={"note": "test"},
        )
        body = e.get_job_body("labeled-job", ["echo"])
        assert body["metadata"]["labels"] == {"app": "my-app"}
        assert body["metadata"]["annotations"] == {"note": "test"}
        pod_meta = body["spec"]["pytorchReplicaSpecs"]["Master"]["template"]["metadata"]
        assert pod_meta["labels"] == {"app": "my-app"}

    def test_get_job_body_image_pull_secrets(self, mock_k8s_clients):
        """image_pull_secrets strings become {"name": ...} dicts in imagePullSecrets."""
        e = PyTorchJobExecutor(
            image="test:latest",
            image_pull_secrets=["my-secret", "other-secret"],
        )
        body = e.get_job_body("secret-job", ["echo"])
        pod_spec = body["spec"]["pytorchReplicaSpecs"]["Master"]["template"]["spec"]
        assert pod_spec["imagePullSecrets"] == [
            {"name": "my-secret"},
            {"name": "other-secret"},
        ]

    def test_get_job_body_spec_kwargs(self, mock_k8s_clients):
        """spec_kwargs are merged verbatim into the top-level PyTorchJob spec."""
        e = PyTorchJobExecutor(
            image="test:latest",
            spec_kwargs={"elasticPolicy": {"maxRestarts": 3}},
        )
        body = e.get_job_body("spec-job", ["echo"])
        assert body["spec"]["elasticPolicy"] == {"maxRestarts": 3}

    def test_get_job_body_container_kwargs(self, mock_k8s_clients):
        """container_kwargs are merged verbatim into the container spec."""
        e = PyTorchJobExecutor(
            image="test:latest",
            container_kwargs={"securityContext": {"runAsUser": 1000}},
        )
        body = e.get_job_body("ckwargs-job", ["echo"])
        container = body["spec"]["pytorchReplicaSpecs"]["Master"]["template"]["spec"]["containers"][
            0
        ]
        assert container["securityContext"] == {"runAsUser": 1000}

    def test_get_job_body_artifact(self, mock_k8s_clients):
        """End-to-end check of a fully-configured job body.

        Verifies the CR envelope (apiVersion/kind/metadata), that
        nprocPerNode is serialized as a string, Master/Worker replica
        counts, and that image + GPU/CPU/memory resources are applied to
        both replica types.
        """
        e = PyTorchJobExecutor(
            image="nvcr.io/nvidian/nemo:nightly",
            namespace="runai-nemo-ci",
            num_workers=2,
            nproc_per_node=8,
            gpus_per_node=8,
            cpu_requests="16",
            memory_requests="64Gi",
            volumes=[{"name": "model-cache", "persistentVolumeClaim": {"claimName": "my-pvc"}}],
            volume_mounts=[{"name": "model-cache", "mountPath": "/nemo-workspace"}],
            labels={"app": "nemo-ci-training"},
        )
        body = e.get_job_body("nemo-ci-training", ["/bin/bash", "-c", "echo hi"])

        assert body["apiVersion"] == "kubeflow.org/v1"
        assert body["kind"] == "PyTorchJob"
        assert body["metadata"]["name"] == "nemo-ci-training"
        assert body["metadata"]["namespace"] == "runai-nemo-ci"
        spec = body["spec"]
        # nprocPerNode is a string in the Kubeflow CRD schema.
        assert spec["nprocPerNode"] == "8"
        master = spec["pytorchReplicaSpecs"]["Master"]
        worker = spec["pytorchReplicaSpecs"]["Worker"]
        # One Master is always added; num_workers controls Worker replicas.
        assert master["replicas"] == 1
        assert worker["replicas"] == 2
        for replica in [master, worker]:
            container = replica["template"]["spec"]["containers"][0]
            assert container["image"] == "nvcr.io/nvidian/nemo:nightly"
            assert container["resources"]["limits"]["nvidia.com/gpu"] == "8"
            assert container["resources"]["requests"]["cpu"] == "16"
            assert container["resources"]["requests"]["memory"] == "64Gi"

    # ── Launch / status / cancel ─────────────────────────────────────────────────

    def test_launch_success(self, executor, mock_k8s_clients):
        """launch() creates the CR once and returns (job_name, CREATED)."""
        mock_custom, _ = mock_k8s_clients
        mock_custom.create_namespaced_custom_object.return_value = {}

        job_name, state = executor.launch("test-job", ["/bin/bash", "-c", "echo hi"])
        assert job_name == "test-job"
        assert state == PyTorchJobState.CREATED
        mock_custom.create_namespaced_custom_object.assert_called_once()

    def test_launch_wait_until_running(self, executor, mock_k8s_clients):
        """With wait=True, launch() polls status until the Running condition is True."""
        mock_custom, _ = mock_k8s_clients
        mock_custom.create_namespaced_custom_object.return_value = {}
        # First poll sees only Created; second poll sees Running.
        mock_custom.get_namespaced_custom_object.side_effect = [
            {"status": {"conditions": [{"type": "Created", "status": "True"}]}},
            {"status": {"conditions": [{"type": "Running", "status": "True"}]}},
        ]

        with patch("time.sleep"):
            job_name, state = executor.launch(
                "test-job", ["/bin/bash", "-c", "echo hi"], wait=True, timeout=30
            )
        assert state == PyTorchJobState.RUNNING

    def test_launch_wait_timeout(self, executor, mock_k8s_clients):
        """A job stuck in Created raises RuntimeError when the wait deadline passes."""
        mock_custom, _ = mock_k8s_clients
        mock_custom.create_namespaced_custom_object.return_value = {}
        mock_custom.get_namespaced_custom_object.return_value = {
            "status": {"conditions": [{"type": "Created", "status": "True"}]}
        }

        with patch("time.sleep"):
            # timeout=-1 makes the deadline already expired, so no real waiting.
            with pytest.raises(RuntimeError, match="did not reach RUNNING"):
                executor.launch("test-job", ["echo"], wait=True, timeout=-1)

    def test_launch_conflict(self, executor, mock_k8s_clients):
        """HTTP 409 from the API is surfaced as RuntimeError('already exists')."""
        mock_custom, _ = mock_k8s_clients
        mock_custom.create_namespaced_custom_object.side_effect = ApiException(status=409)

        with pytest.raises(RuntimeError, match="already exists"):
            executor.launch("test-job", ["/bin/bash", "-c", "echo hi"])

    def test_status_running(self, executor, mock_k8s_clients):
        """Created+Running conditions map to RUNNING."""
        mock_custom, _ = mock_k8s_clients
        mock_custom.get_namespaced_custom_object.return_value = {
            "status": {
                "conditions": [
                    {"type": "Created", "status": "True"},
                    {"type": "Running", "status": "True"},
                ]
            }
        }
        assert executor.status("test-job") == PyTorchJobState.RUNNING

    def test_status_succeeded(self, executor, mock_k8s_clients):
        """Succeeded=True wins over a stale Running=False condition."""
        mock_custom, _ = mock_k8s_clients
        mock_custom.get_namespaced_custom_object.return_value = {
            "status": {
                "conditions": [
                    {"type": "Running", "status": "False"},
                    {"type": "Succeeded", "status": "True"},
                ]
            }
        }
        assert executor.status("test-job") == PyTorchJobState.SUCCEEDED

    def test_status_failed(self, executor, mock_k8s_clients):
        """Failed=True maps to FAILED."""
        mock_custom, _ = mock_k8s_clients
        mock_custom.get_namespaced_custom_object.return_value = {
            "status": {
                "conditions": [
                    {"type": "Running", "status": "False"},
                    {"type": "Failed", "status": "True"},
                ]
            }
        }
        assert executor.status("test-job") == PyTorchJobState.FAILED

    def test_status_not_found(self, executor, mock_k8s_clients):
        """A 404 (job never created or already deleted) yields None, not an exception."""
        mock_custom, _ = mock_k8s_clients
        mock_custom.get_namespaced_custom_object.side_effect = ApiException(status=404)
        assert executor.status("missing-job") is None

    def test_status_api_error(self, executor, mock_k8s_clients):
        """Transient server errors (500) also yield None rather than raising."""
        mock_custom, _ = mock_k8s_clients
        mock_custom.get_namespaced_custom_object.side_effect = ApiException(status=500)
        assert executor.status("bad-job") is None

    def test_cancel(self, executor, mock_k8s_clients):
        """cancel() deletes the CR exactly once."""
        mock_custom, _ = mock_k8s_clients
        mock_custom.delete_namespaced_custom_object.return_value = {}
        # Should not raise
        executor.cancel("test-job")
        mock_custom.delete_namespaced_custom_object.assert_called_once()

    def test_cancel_already_deleted(self, executor, mock_k8s_clients):
        """Cancelling an already-deleted job (404) is a graceful no-op."""
        mock_custom, _ = mock_k8s_clients
        mock_custom.delete_namespaced_custom_object.side_effect = ApiException(status=404)
        result = executor.cancel("gone-job")
        assert result is None  # handled gracefully

    def test_cancel_with_wait(self, executor, mock_k8s_clients):
        """cancel(wait=True) returns True once the CR is gone and no pods remain."""
        mock_custom, mock_core = mock_k8s_clients
        mock_custom.delete_namespaced_custom_object.return_value = {}
        # CR is gone on first poll
        mock_custom.get_namespaced_custom_object.side_effect = ApiException(status=404)
        mock_core.list_namespaced_pod.return_value = MagicMock(items=[])

        with patch("time.sleep"):
            result = executor.cancel("test-job", wait=True, timeout=30, poll_interval=0)
        assert result is True

    def test_cancel_with_wait_timeout(self, executor, mock_k8s_clients):
        """cancel(wait=True) returns False if the CR never disappears before timeout."""
        mock_custom, mock_core = mock_k8s_clients
        mock_custom.delete_namespaced_custom_object.return_value = {}
        # CR never disappears
        mock_custom.get_namespaced_custom_object.return_value = {"metadata": {"name": "test-job"}}

        with patch("time.sleep"):
            result = executor.cancel("test-job", wait=True, timeout=-1, poll_interval=0)
        assert result is False

    # ── Logs ─────────────────────────────────────────────────────────────────────

    def test_fetch_logs_no_follow(self, executor, mock_k8s_clients):
        """Non-streaming logs run once with --tail and the job-name label, no -f."""
        with patch("subprocess.run") as mock_run:
            mock_run.return_value = MagicMock(stdout="line1\nline2\n")
            lines = list(executor.fetch_logs("my-job", stream=False, lines=50))

        mock_run.assert_called_once()
        called_cmd = mock_run.call_args[0][0]
        assert "--tail" in called_cmd
        assert "50" in called_cmd
        label_arg = " ".join(called_cmd)
        # Pods are selected via the Kubeflow job-name label.
        assert "training.kubeflow.org/job-name=my-job" in label_arg
        assert "-f" not in called_cmd
        assert lines == ["line1", "line2"]

    def test_fetch_logs_follow(self, executor, mock_k8s_clients):
        """Streaming logs use Popen with -f and yield lines with newlines kept."""
        import io

        mock_proc = MagicMock()
        mock_proc.stdout = io.StringIO("line1\nline2\n")
        mock_proc.poll.return_value = None  # still running; loop exits when readline() hits EOF

        with patch("subprocess.Popen", return_value=mock_proc) as mock_popen:
            lines = list(executor.fetch_logs("my-job", stream=True, lines=100))

        mock_popen.assert_called_once()
        called_cmd = mock_popen.call_args[0][0]
        assert "-f" in called_cmd
        assert lines == ["line1\n", "line2\n"]


# ---------------------------------------------------------------------------
# New file: test/run/torchx_backend/schedulers/test_pytorchjob.py
# ---------------------------------------------------------------------------
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for the TorchX PyTorchJobScheduler backend.

Covers scheduler lifecycle (dryrun/schedule), PyTorchJobState -> AppState
mapping, cancel/log plumbing into the executor, and the on-disk job-dir
persistence helpers (_save_job_dir / _get_job_dirs).
"""

from unittest import mock
from unittest.mock import MagicMock, patch

import pytest
from torchx.schedulers.api import AppDryRunInfo
from torchx.specs import AppDef, AppState, Role

from nemo_run.core.execution.pytorchjob import PyTorchJobExecutor, PyTorchJobState
from nemo_run.run.torchx_backend.schedulers.pytorchjob import (
    PYTORCHJOB_STATES,
    PyTorchJobScheduler,
    create_scheduler,
)


@pytest.fixture
def mock_k8s():
    """Patch kube config loading and both Kubernetes API clients.

    Yields the (CustomObjectsApi, CoreV1Api) mock instances so tests can
    program responses / assert calls.
    """
    with (
        patch("nemo_run.core.execution.pytorchjob.config.load_kube_config"),
        patch("nemo_run.core.execution.pytorchjob.client.CustomObjectsApi") as mock_custom,
        patch("nemo_run.core.execution.pytorchjob.client.CoreV1Api") as mock_core,
    ):
        yield mock_custom.return_value, mock_core.return_value


@pytest.fixture
def executor(mock_k8s, tmp_path):
    """A PyTorchJobExecutor wired to a temp dir with a known experiment/role id."""
    e = PyTorchJobExecutor(
        image="nvcr.io/nvidian/nemo:nightly",
        num_workers=2,
        gpus_per_node=8,
    )
    e.experiment_id = "test_exp"
    e.job_dir = str(tmp_path)
    e.experiment_dir = str(tmp_path)
    e.job_name = "test_role"
    return e


@pytest.fixture
def scheduler():
    # Fresh scheduler per test; session name participates in app ids.
    return create_scheduler(session_name="test")


@pytest.fixture
def mock_app_def():
    """Minimal single-role AppDef used for submit/schedule tests."""
    return AppDef(
        name="test_app",
        roles=[
            Role(
                name="test_role",
                image="nvcr.io/nvidian/nemo:nightly",
                entrypoint="python",
                args=["train.py"],
            )
        ],
    )


# ── Scheduler lifecycle ───────────────────────────────────────────────────────


def test_create_scheduler():
    """create_scheduler returns a PyTorchJobScheduler with the given session name."""
    s = create_scheduler(session_name="test")
    assert isinstance(s, PyTorchJobScheduler)
    assert s.session_name == "test"


def test_submit_dryrun(scheduler, mock_app_def, executor):
    """_submit_dryrun packages the executor and yields an AppDryRunInfo with a request."""
    with mock.patch.object(PyTorchJobExecutor, "package") as mock_pkg:
        mock_pkg.return_value = None
        dryrun_info = scheduler._submit_dryrun(mock_app_def, executor)
        assert isinstance(dryrun_info, AppDryRunInfo)
        assert dryrun_info.request is not None


def test_schedule(scheduler, mock_app_def, executor):
    """schedule() launches via the executor and builds the exp___role___job app id."""
    with (
        mock.patch.object(PyTorchJobExecutor, "package") as mock_pkg,
        mock.patch.object(PyTorchJobExecutor, "launch") as mock_launch,
    ):
        mock_pkg.return_value = None
        mock_launch.return_value = ("test-job", PyTorchJobState.CREATED)

        dryrun_info = scheduler._submit_dryrun(mock_app_def, executor)
        app_id = scheduler.schedule(dryrun_info)

        assert app_id == "test_exp___test_role___test-job"
        mock_pkg.assert_called_once()
        mock_launch.assert_called_once()


# ── State mapping ─────────────────────────────────────────────────────────────


def test_describe_running(scheduler, executor):
    """PyTorchJobState.RUNNING maps to AppState.RUNNING in describe()."""
    with mock.patch("nemo_run.run.torchx_backend.schedulers.pytorchjob._get_job_dirs") as mock_dirs:
        mock_dirs.return_value = {
            "test_exp___test_role___test-job": {
                "job_status": "Created",
                "job_name": "test-job",
                "executor": executor,
            }
        }
        with mock.patch.object(PyTorchJobExecutor, "status", return_value=PyTorchJobState.RUNNING):
            resp = scheduler.describe("test_exp___test_role___test-job")
            assert resp is not None
            assert resp.state == AppState.RUNNING


def test_describe_succeeded(scheduler, executor):
    """PyTorchJobState.SUCCEEDED maps to AppState.SUCCEEDED."""
    with mock.patch("nemo_run.run.torchx_backend.schedulers.pytorchjob._get_job_dirs") as mock_dirs:
        mock_dirs.return_value = {
            "test_exp___test_role___test-job": {
                "job_status": "Created",
                "job_name": "test-job",
                "executor": executor,
            }
        }
        with mock.patch.object(
            PyTorchJobExecutor, "status", return_value=PyTorchJobState.SUCCEEDED
        ):
            resp = scheduler.describe("test_exp___test_role___test-job")
            assert resp.state == AppState.SUCCEEDED


def test_describe_failed(scheduler, executor):
    """PyTorchJobState.FAILED maps to AppState.FAILED."""
    with mock.patch("nemo_run.run.torchx_backend.schedulers.pytorchjob._get_job_dirs") as mock_dirs:
        mock_dirs.return_value = {
            "test_exp___test_role___test-job": {
                "job_status": "Created",
                "job_name": "test-job",
                "executor": executor,
            }
        }
        with mock.patch.object(PyTorchJobExecutor, "status", return_value=PyTorchJobState.FAILED):
            resp = scheduler.describe("test_exp___test_role___test-job")
            assert resp.state == AppState.FAILED


def test_describe_unknown_maps_to_pending(scheduler, executor):
    # None status (transient error) must not become FAILED — avoids false failures
    with mock.patch("nemo_run.run.torchx_backend.schedulers.pytorchjob._get_job_dirs") as mock_dirs:
        mock_dirs.return_value = {
            "test_exp___test_role___test-job": {
                "job_status": "Created",
                "job_name": "test-job",
                "executor": executor,
            }
        }
        with mock.patch.object(PyTorchJobExecutor, "status", return_value=None):
            resp = scheduler.describe("test_exp___test_role___test-job")
            assert resp.state == AppState.PENDING


def test_describe_uses_stored_job_id_not_split(scheduler, executor):
    # Regression: role names containing '___' must not corrupt app_id parsing.
    real_job_name = "real-job-abc123"
    app_id = f"experiment___role_name___{real_job_name}"

    with (
        mock.patch("nemo_run.run.torchx_backend.schedulers.pytorchjob._get_job_dirs") as mock_dirs,
        mock.patch.object(
            PyTorchJobExecutor, "status", return_value=PyTorchJobState.RUNNING
        ) as mock_status,
    ):
        mock_dirs.return_value = {
            app_id: {
                "job_status": "Created",
                "job_name": real_job_name,
                "executor": executor,
            }
        }
        resp = scheduler.describe(app_id)

    assert resp is not None
    # describe() must use the persisted job_name, not a split of app_id.
    mock_status.assert_called_once_with(real_job_name)


# ── Cancel / logs ─────────────────────────────────────────────────────────────


def test_cancel_existing(scheduler, executor):
    """_cancel_existing forwards the persisted job name to executor.cancel."""
    with (
        mock.patch("nemo_run.run.torchx_backend.schedulers.pytorchjob._get_job_dirs") as mock_dirs,
        mock.patch.object(PyTorchJobExecutor, "cancel") as mock_cancel,
    ):
        mock_dirs.return_value = {
            "test_exp___test_role___test-job": {
                "job_status": "Running",
                "job_name": "test-job",
                "executor": executor,
            }
        }
        scheduler._cancel_existing("test_exp___test_role___test-job")
        mock_cancel.assert_called_once_with("test-job")


def test_log_iter_list(scheduler, executor):
    """log_iter yields executor.fetch_logs output when it returns a list."""
    with mock.patch("nemo_run.run.torchx_backend.schedulers.pytorchjob._get_job_dirs") as mock_dirs:
        mock_dirs.return_value = {
            "test_exp___test_role___test-job": {
                "job_status": "Running",
                "job_name": "test-job",
                "executor": executor,
            }
        }
        executor.fetch_logs = MagicMock(return_value=["log line 1", "log line 2"])

        lines = list(scheduler.log_iter("test_exp___test_role___test-job", "test_role"))
        assert lines == ["log line 1", "log line 2"]


def test_log_iter_str(scheduler, executor):
    """log_iter also handles a single newline-joined string from fetch_logs."""
    with mock.patch("nemo_run.run.torchx_backend.schedulers.pytorchjob._get_job_dirs") as mock_dirs:
        mock_dirs.return_value = {
            "test_exp___test_role___test-job": {
                "job_status": "Running",
                "job_name": "test-job",
                "executor": executor,
            }
        }
        executor.fetch_logs = MagicMock(return_value="log line 1\nlog line 2")

        lines = list(scheduler.log_iter("test_exp___test_role___test-job", "test_role"))
        # Accept either split convention (with or without trailing newline).
        assert "log line 1\n" in lines or "log line 1" in lines


# ── Persistence ───────────────────────────────────────────────────────────────


def test_save_job_dir_new_file(executor, tmp_path):
    """_save_job_dir creates the store and round-trips executor + job_name."""
    from nemo_run.config import set_nemorun_home

    set_nemorun_home(str(tmp_path))

    from nemo_run.run.torchx_backend.schedulers.pytorchjob import _get_job_dirs, _save_job_dir

    _save_job_dir("my_app_id", job_status="Created", executor=executor, job_name="my-job")
    dirs = _get_job_dirs()
    assert "my_app_id" in dirs
    assert dirs["my_app_id"]["job_name"] == "my-job"
    assert isinstance(dirs["my_app_id"]["executor"], PyTorchJobExecutor)


def test_save_job_dir_existing_file(executor, tmp_path):
    """A second _save_job_dir call appends without clobbering the first entry."""
    from nemo_run.config import set_nemorun_home

    set_nemorun_home(str(tmp_path))

    from nemo_run.run.torchx_backend.schedulers.pytorchjob import _get_job_dirs, _save_job_dir

    _save_job_dir("app_id_1", job_status="Created", executor=executor, job_name="job-1")
    _save_job_dir("app_id_2", job_status="Running", executor=executor, job_name="job-2")

    dirs = _get_job_dirs()
    assert "app_id_1" in dirs
    assert "app_id_2" in dirs


def test_get_job_dirs_file_not_found(tmp_path):
    """_get_job_dirs returns {} when no store file exists yet."""
    from nemo_run.config import set_nemorun_home

    set_nemorun_home(str(tmp_path))

    from nemo_run.run.torchx_backend.schedulers.pytorchjob import _get_job_dirs

    result = _get_job_dirs()
    assert result == {}


# ── State map ─────────────────────────────────────────────────────────────────


def test_unknown_state_maps_to_pending():
    """UNKNOWN must map to PENDING in the static state table."""
    assert PYTORCHJOB_STATES[PyTorchJobState.UNKNOWN] == AppState.PENDING