diff --git a/examples/ml/eval_callback.py b/examples/ml/eval_callback.py index fdc68fa95..55c2a248e 100644 --- a/examples/ml/eval_callback.py +++ b/examples/ml/eval_callback.py @@ -31,13 +31,13 @@ from flyte.io import File image = flyte.Image.from_debian_base(name="lightning-eval").with_pip_packages( - "lightning==2.6.1", "flyteplugins-pytorch==2.0.2" + "lightning==2.6.1", "flyteplugins-pytorch==2.0.3" ) # Multi-node training: 2 nodes, 1 process per node train_env = flyte.TaskEnvironment( name="distributed-train", - resources=flyte.Resources(cpu=4, memory="25Gi", gpu="L4:1"), + resources=flyte.Resources(cpu=4, memory="25Gi", gpu="T4:1"), plugin_config=Elastic( nproc_per_node=1, nnodes=2, diff --git a/examples/plugins/torch_example.py b/examples/plugins/torch_example.py index 59c19cd8f..677141db2 100644 --- a/examples/plugins/torch_example.py +++ b/examples/plugins/torch_example.py @@ -5,28 +5,24 @@ import torch.nn as nn import torch.optim as optim from flyteplugins.pytorch.task import Elastic +from flyteplugins.wandb import get_wandb_run, wandb_config, wandb_init from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.data import DataLoader, DistributedSampler, TensorDataset import flyte -# Install flyteplugins-torch from the wheel for development. -# In production, you would just specify the package name and version. -# from flyte._image import DIST_FOLDER, PythonWheels -# image = flyte.Image.from_debian_base(name="torch").clone( -# addl_layer=PythonWheels(wheel_dir=DIST_FOLDER, package_name="flyteplugins-pytorch", pre=True) -# ) - -image = flyte.Image.from_debian_base(name="torch").with_pip_packages("flyteplugins-pytorch") +image = flyte.Image.from_debian_base(name="torch").with_pip_packages( + "flyteplugins-pytorch", "flyteplugins-wandb" +) torch_env = flyte.TaskEnvironment( name="torch_env", - resources=flyte.Resources(cpu=(1, 2), memory=("1Gi", "2Gi")), + resources=flyte.Resources(cpu=(1, 2), memory=("1Gi", "2Gi"), gpu="T4:1"), plugin_config=Elastic( nproc_per_node=1, - # if you want to do local testing set nnodes=1 nnodes=2, ), + secrets=flyte.Secret(key="NIELS_WANDB_API_KEY", as_env_var="WANDB_API_KEY"), image=image, ) @@ -45,9 +41,9 @@ def prepare_dataloader(rank: int, world_size: int, batch_size: int = 2) -> DataL Prepare a DataLoader with a DistributedSampler so each rank gets a shard of the dataset. """ - # Dummy dataset - x_train = torch.tensor([[1.0], [2.0], [3.0], [4.0]]) - y_train = torch.tensor([[3.0], [5.0], [7.0], [9.0]]) + n_samples = 100 + x_train = torch.randn(n_samples, 1) + y_train = 2.0 * x_train + 1.0 + 0.1 * torch.randn(n_samples, 1) dataset = TensorDataset(x_train, y_train) # Distributed-aware sampler @@ -76,8 +72,9 @@ def train_loop(epochs: int = 3) -> float: optimizer = optim.SGD(model.parameters(), lr=0.01) final_loss = 0.0 + wandb_run = get_wandb_run() - for _ in range(epochs): + for epoch in range(epochs): for x, y in dataloader: outputs = model(x) loss = criterion(outputs, y) @@ -87,12 +84,16 @@ def train_loop(epochs: int = 3) -> float: optimizer.step() final_loss = loss.item() + + if wandb_run: + wandb_run.log({"loss": final_loss, "epoch": epoch}) if torch.distributed.get_rank() == 0: print(f"Loss: {final_loss}") return final_loss +@wandb_init @torch_env.task def torch_distributed_train(epochs: int) -> typing.Optional[float]: """ @@ -106,6 +107,9 @@ def torch_distributed_train(epochs: int) -> typing.Optional[float]: if __name__ == "__main__": flyte.init_from_config() - run = flyte.with_runcontext(mode="remote").run(torch_distributed_train, epochs=3) + run = flyte.with_runcontext( + mode="remote", + custom_context=wandb_config(project="torch-distributed-training"), + ).run(torch_distributed_train, epochs=1_000_000) print("run name:", run.name) print("run url:", run.url) diff --git a/examples/scripts/hello.py b/examples/scripts/hello.py new file mode 100644 index 000000000..a065c7164 --- /dev/null +++ b/examples/scripts/hello.py @@ -0,0 +1,20 @@ +""" +Run with: + +```bash +flyte run --follow python-script hello.py --output-dir output +``` +""" + +import os + + +def main(): + print("Hello, world!") + os.makedirs("output", exist_ok=True) + with open("output/hello.txt", "w") as f: + f.write("Hello, file!") + + +if __name__ == "__main__": + main() diff --git a/src/flyte/_code_bundle/_utils.py b/src/flyte/_code_bundle/_utils.py index 5ebc45f06..18eef7f95 100644 --- a/src/flyte/_code_bundle/_utils.py +++ b/src/flyte/_code_bundle/_utils.py @@ -22,7 +22,7 @@ from ._ignore import Ignore, IgnoreGroup, StandardIgnore -CopyFiles = Literal["loaded_modules", "all", "none"] +CopyFiles = Literal["loaded_modules", "all", "none", "python_script"] def compress_scripts(source_path: str, destination: str, modules: List[ModuleType]): diff --git a/src/flyte/_internal/resolvers/script.py b/src/flyte/_internal/resolvers/script.py new file mode 100644 index 000000000..8216e36dc --- /dev/null +++ b/src/flyte/_internal/resolvers/script.py @@ -0,0 +1,48 @@ +"""Resolver for python-script tasks. + +Bundles a plain Python script into the code bundle and recreates +the ``execute_script`` task at runtime so the Flyte entrypoint can +run it without pickling. +""" + +from pathlib import Path +from typing import List, Optional + +from flyte._internal.resolvers.common import Resolver +from flyte._task import TaskTemplate + + +class ScriptTaskResolver(Resolver): + """Resolve a bundled Python script into an executable task. + + During serialization the resolver stores the script filename and timeout + as loader args. At runtime ``load_task`` recreates a lightweight task + that executes the script via ``subprocess``. + """ + + def __init__(self, script_name: str = "", output_dir: Optional[str] = None, timeout: int = 3600): + self._script_name = script_name + self._output_dir = output_dir + self._timeout = timeout + + @property + def import_path(self) -> str: + return "flyte._internal.resolvers.script.ScriptTaskResolver" + + def load_task(self, loader_args: List[str]) -> TaskTemplate: + args_iter = iter(loader_args) + parsed = dict(zip(args_iter, args_iter)) + script_name = parsed["script"] + output_dir = parsed.get("output_dir", None) + timeout = int(parsed.get("timeout", "3600")) + + from flyte._run_python_script import _build_script_runner_task + + return _build_script_runner_task(script_name, output_dir, timeout) + + def loader_args(self, task: TaskTemplate, root_dir: Optional[Path] = None) -> List[str]: + args = ["script", self._script_name] + if self._output_dir is not None: + args.extend(["output_dir", self._output_dir]) + args.extend(["timeout", str(self._timeout)]) + return args diff --git a/src/flyte/_run.py b/src/flyte/_run.py index 322766fd3..3364c5120 100644 --- a/src/flyte/_run.py +++ b/src/flyte/_run.py @@ -116,10 +116,14 @@ def __init__( preserve_original_types: bool | None = None, debug: bool = False, _tracker: Any = None, + _bundle_relative_paths: tuple[str, ...] | None = None, + _bundle_from_dir: pathlib.Path | None = None, ): from flyte._tools import ipython_check self._tracker = _tracker + self._bundle_relative_paths = _bundle_relative_paths + self._bundle_from_dir = _bundle_from_dir init_config = _get_init_config() client = init_config.client if init_config else None if not force_mode and client is not None: @@ -223,16 +227,28 @@ async def _run_remote(self, obj: TaskTemplate[P, R, F] | LazyEntity, *args: P.ar upload_to_controlplane=not self._dry_run, copy_bundle_to=self._copy_bundle_to, ) - else: - if self._copy_files != "none": - code_bundle = await build_code_bundle( - from_dir=cfg.root_dir, - dryrun=self._dry_run, - copy_bundle_to=self._copy_bundle_to, - copy_style=self._copy_files, + elif self._copy_files == "python_script": + from ._code_bundle.bundle import build_code_bundle_from_relative_paths + + if not self._bundle_relative_paths or not self._bundle_from_dir: + raise ValueError( + "copy_style='python_script' requires _bundle_relative_paths and _bundle_from_dir" ) - else: - code_bundle = None + code_bundle = await build_code_bundle_from_relative_paths( + self._bundle_relative_paths, + from_dir=self._bundle_from_dir, + dryrun=self._dry_run, + copy_bundle_to=self._copy_bundle_to, + ) + elif self._copy_files != "none": + code_bundle = await build_code_bundle( + from_dir=cfg.root_dir, + dryrun=self._dry_run, + copy_bundle_to=self._copy_bundle_to, + copy_style=self._copy_files, + ) + else: + code_bundle = None if not self._disable_run_cache: _RUN_CACHE[_CacheKey(obj_id=id(obj), dry_run=self._dry_run)] = _CacheValue( code_bundle=code_bundle, image_cache=image_cache @@ -469,16 +485,26 @@ async def _run_hybrid(self, obj: TaskTemplate[P, R, F], *args: P.args, **kwargs: upload_to_controlplane=not self._dry_run, copy_bundle_to=self._copy_bundle_to, ) + elif self._copy_files == "python_script": + from ._code_bundle.bundle import build_code_bundle_from_relative_paths + + if not self._bundle_relative_paths or not self._bundle_from_dir: + raise ValueError("copy_style='python_script' requires _bundle_relative_paths and _bundle_from_dir") + code_bundle = await build_code_bundle_from_relative_paths( + self._bundle_relative_paths, + from_dir=self._bundle_from_dir, + dryrun=self._dry_run, + copy_bundle_to=self._copy_bundle_to, + ) + elif self._copy_files != "none": + code_bundle = await build_code_bundle( + from_dir=cfg.root_dir, + dryrun=self._dry_run, + copy_bundle_to=self._copy_bundle_to, + copy_style=self._copy_files, + ) else: - if self._copy_files != "none": - code_bundle = await build_code_bundle( - from_dir=cfg.root_dir, - dryrun=self._dry_run, - copy_bundle_to=self._copy_bundle_to, - copy_style=self._copy_files, - ) - else: - code_bundle = None + code_bundle = None version = self._version or ( code_bundle.computed_version if code_bundle and code_bundle.computed_version else None diff --git a/src/flyte/_run_python_script.py b/src/flyte/_run_python_script.py index c2b664797..f84b25cb1 100644 --- a/src/flyte/_run_python_script.py +++ b/src/flyte/_run_python_script.py @@ -14,9 +14,11 @@ # class object at decoration time, not a deferred string. import pathlib +from dataclasses import dataclass from datetime import timedelta from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union +import flyte.io from flyte.syncify import syncify if TYPE_CHECKING: @@ -24,37 +26,73 @@ from flyte.remote import Run +@dataclass +class PythonScriptOutput: + exit_code: int + stdout: str + stderr: str + output_dir: Optional[flyte.io.Dir] + + def _build_task( env: Any, + script_name: str, timeout: int, short_name: str, + output_dir: "Optional[str]" = None, + task_resolver: Any = None, ) -> Any: - """Build the execute_script task. + """Build the ``execute_script`` task for serialization. - Defined separately so that the ``File`` import is evaluated eagerly - and available as a real type when ``@env.task`` inspects annotations. + The *script_name* is captured via closure for local execution. When + running remotely the :class:`ScriptTaskResolver` recreates the task from + the loader args embedded in the container command, so the closure value + is not carried over the wire. """ - from flyte.io import File - task_timeout = timedelta(seconds=timeout) - @env.task(timeout=task_timeout, short_name=short_name) - async def execute_script(script_file: File, args: list, task_timeout: int) -> dict: + @env.task(timeout=task_timeout, short_name=short_name, task_resolver=task_resolver) + async def execute_script(args: list[str], task_timeout: int) -> PythonScriptOutput: """Execute a Python script on a remote machine.""" import subprocess import sys - local_path = await script_file.download() - cmd = [sys.executable, local_path, *args] - result = subprocess.run(cmd, text=True, check=False, timeout=task_timeout - 60) # noqa: ASYNC221 + cmd = [sys.executable, script_name, *args] + result = subprocess.run( # noqa: ASYNC221 + cmd, + capture_output=True, + text=True, + check=True, + timeout=task_timeout - 60, + ) - return { - "exit_code": result.returncode, - } + _dir: Optional[flyte.io.Dir] = None + if output_dir: + _dir = await flyte.io.Dir.from_local(output_dir) + + return PythonScriptOutput( + exit_code=result.returncode, + stdout=result.stdout[-1000:], + stderr=result.stderr[-1000:], + output_dir=_dir, + ) return execute_script +def _build_script_runner_task(script_name: str, output_dir: "Optional[str]" = None, timeout: int = 3600) -> Any: + """Build the ``execute_script`` task at runtime (called by :class:`ScriptTaskResolver`). + + Creates a minimal :class:`~flyte.TaskEnvironment` — only the function + signature matters here because the container already has the correct + image and resources. + """ + import flyte + + env = flyte.TaskEnvironment(name="python_script") + return _build_task(env, script_name, timeout, short_name=script_name, output_dir=output_dir) + + @syncify async def run_python_script( script: pathlib.Path, @@ -69,11 +107,15 @@ async def run_python_script( queue: "Optional[str]" = None, wait: bool = False, name: "Optional[str]" = None, + debug: bool = False, + output_dir: "Optional[str]" = None, ) -> "Run": """Package and run a Python script on a remote Flyte cluster. - Uploads the script via :class:`~flyte.io.File`, passes it as a typed input - to a Flyte task, and executes it remotely with the requested resources. + Bundles the script into a Flyte code bundle and executes it remotely + with the requested resources. Unlike ``interactive_mode`` (which + pickles the task), this approach uses a :class:`ScriptTaskResolver` + so the task can be properly debugged with ``debug=True``. Project and domain are read from the init config (set via ``flyte.init()`` or ``flyte.init_from_config()``), consistent with ``flyte.run()``. @@ -96,6 +138,9 @@ async def run_python_script( :param queue: Flyte queue / cluster override. :param wait: If True, block until execution completes before returning. :param name: Run name. If omitted, a random name is generated. + :param debug: If True, run the task as a VS Code debug task, starting a + code-server in the container so you can connect via the UI to + interactively debug/run the task. :return: A :class:`~flyte.remote.Run` handle for the remote execution. Example:: @@ -120,7 +165,8 @@ async def run_python_script( run = flyte.run_python_script(Path("analysis.py"), image=img) """ import flyte - from flyte.io import File + from flyte._internal.resolvers.script import ScriptTaskResolver + from flyte._run import _Runner script = pathlib.Path(script).resolve() if not script.exists(): @@ -153,20 +199,24 @@ async def run_python_script( env_kwargs["queue"] = queue env = flyte.TaskEnvironment(**env_kwargs) - # Build task (in a separate function so File annotation resolves correctly) + # Build task with the ScriptTaskResolver so the runner knows how to + # serialize and reload it without pickling. + resolver = ScriptTaskResolver(script.name, output_dir=output_dir, timeout=timeout) task_short_name = name or script.stem - execute_script = _build_task(env, timeout, short_name=task_short_name) - - script_file: File = await File.from_local(script) + execute_script = _build_task( + env, script.name, timeout, short_name=task_short_name, output_dir=output_dir, task_resolver=resolver + ) - runner = flyte.with_runcontext( - mode="remote", + runner = _Runner( + force_mode="remote", name=name, - interactive_mode=True, + debug=debug, + copy_style="python_script", + _bundle_relative_paths=(script.name,), + _bundle_from_dir=script.parent, ) run = await runner.run.aio( execute_script, - script_file=script_file, args=extra_args or [], task_timeout=timeout, ) diff --git a/src/flyte/_task.py b/src/flyte/_task.py index c21477cb2..05eb6fc53 100644 --- a/src/flyte/_task.py +++ b/src/flyte/_task.py @@ -470,6 +470,7 @@ class AsyncFunctionTaskTemplate(TaskTemplate[P, R, F]): func: F plugin_config: Optional[Any] = None # This is used to pass plugin specific configuration debuggable: bool = True + task_resolver: Optional[Any] = None def __post_init__(self): super().__post_init__() @@ -556,21 +557,23 @@ def container_args(self, serialize_context: SerializationContext) -> List[str]: if not serialize_context.code_bundle or not serialize_context.code_bundle.pkl: # If we do not have a code bundle, or if we have one, but it is not a pkl, we need to add the resolver + resolver = self.task_resolver + if resolver is None: + from flyte._internal.resolvers.default import DefaultTaskResolver - from flyte._internal.resolvers.default import DefaultTaskResolver + resolver = DefaultTaskResolver() if not serialize_context.root_dir: raise RuntimeSystemError( "SerializationError", "Root dir is required for default task resolver when no code bundle is provided.", ) - _task_resolver = DefaultTaskResolver() args = [ *args, *[ "--resolver", - _task_resolver.import_path, - *_task_resolver.loader_args(task=self, root_dir=serialize_context.root_dir), + resolver.import_path, + *resolver.loader_args(task=self, root_dir=serialize_context.root_dir), ], ] diff --git a/src/flyte/_task_environment.py b/src/flyte/_task_environment.py index 9bed35bd1..7c989ccd0 100644 --- a/src/flyte/_task_environment.py +++ b/src/flyte/_task_environment.py @@ -170,6 +170,7 @@ def task( queue: Optional[str] = None, triggers: Tuple[Trigger, ...] | Trigger = (), links: Tuple[Link, ...] | Link = (), + task_resolver: Any | None = None, ) -> Callable[[Callable[P, R]], AsyncFunctionTaskTemplate[P, R, Callable[P, R]]]: ... @overload @@ -195,6 +196,7 @@ def task( queue: Optional[str] = None, triggers: Tuple[Trigger, ...] | Trigger = (), links: Tuple[Link, ...] | Link = (), + task_resolver: Any | None = None, ) -> Callable[[F], AsyncFunctionTaskTemplate[P, R, F]] | AsyncFunctionTaskTemplate[P, R, F]: """ Decorate a function to be a task. @@ -277,6 +279,7 @@ def decorator(func: F) -> AsyncFunctionTaskTemplate[P, R, F]: interruptible=interruptible if interruptible is not None else self.interruptible, triggers=triggers if isinstance(triggers, tuple) else (triggers,), links=links if isinstance(links, tuple) else (links,), + task_resolver=task_resolver, ) self._tasks[task_name] = tmpl return tmpl diff --git a/src/flyte/cli/_run.py b/src/flyte/cli/_run.py index f100899f3..854daeccd 100644 --- a/src/flyte/cli/_run.py +++ b/src/flyte/cli/_run.py @@ -14,6 +14,7 @@ from .._code_bundle._utils import CopyFiles from .._task import TaskTemplate from ..remote import Run +from ..syncify import syncify from . import _common as common from ._params import to_click_option @@ -22,6 +23,7 @@ initialize_config = common.initialize_config +@syncify async def _render_debug_url(console, result: Run, config: common.CLIConfig) -> None: """Poll the run for the VS Code Debugger URL and print it.""" from flyte._debug.client import watch_for_vscode_url @@ -331,7 +333,7 @@ async def _render_remote_success(self, console, result, config): console.print(common.get_panel("Remote Run", run_info, config.output_format)) if self.run_args.debug: - await _render_debug_url(console, result, config) + await _render_debug_url.aio(console, result, config) if self.run_args.follow: from flyte._status import status @@ -554,7 +556,7 @@ async def _render_remote_success(self, console, result, config): console.print(common.get_panel("Remote Run", run_info, config.output_format)) if self.run_args.debug: - await _render_debug_url(console, result, config) + await _render_debug_url.aio(console, result, config) if self.run_args.follow: from flyte._status import status diff --git a/src/flyte/cli/_run_python_script.py b/src/flyte/cli/_run_python_script.py index 04a1d274d..b83628bc6 100644 --- a/src/flyte/cli/_run_python_script.py +++ b/src/flyte/cli/_run_python_script.py @@ -59,6 +59,12 @@ class PythonScriptCommand(CommandBase): help="Extra arguments passed to the script (comma-separated).", ) @click.option("--queue", type=str, default=None, help="Flyte queue / cluster override.") +@click.option( + "--output-dir", + type=str, + default=None, + help="Directory path (inside the container) to upload as output after the script finishes.", +) @click.pass_obj def python_script( cfg: common.CLIConfig, @@ -72,6 +78,7 @@ def python_script( timeout: int, extra_args: str | None, queue: str | None, + output_dir: str | None, ) -> None: """Run a Python script on a remote Flyte cluster. @@ -116,6 +123,7 @@ def python_script( name = run_args.name if run_args else None project = run_args.project if run_args else None domain = run_args.domain if run_args else None + debug = run_args.debug if run_args else False # Initialize flyte config (like prefetch does) initialize_config( @@ -150,6 +158,8 @@ def python_script( queue=queue, wait=False, name=name, + debug=debug, + output_dir=output_dir, ) url = run.url @@ -158,6 +168,11 @@ def python_script( f" Check the console for status at [link={url}]{url}[/link]" ) + if debug: + from flyte.cli._run import _render_debug_url + + _render_debug_url(console, run, cfg) + if follow: run.wait() try: diff --git a/src/flyte/io/_file.py b/src/flyte/io/_file.py index 1d1278f92..bc1e9c726 100644 --- a/src/flyte/io/_file.py +++ b/src/flyte/io/_file.py @@ -455,7 +455,7 @@ async def stream_read(f: File) -> str: # Fall back to aiofiles fs = storage.get_underlying_filesystem(path=self.path) if "file" in fs.protocol: - async with aiofiles.open(self.path, mode=mode, **kwargs) as f: + async with aiofiles.open(self.path, mode=mode, **kwargs) as f: # type: ignore[call-overload] yield f return raise diff --git a/tests/user_api/test_run_python_script.py b/tests/user_api/test_run_python_script.py index 16d554b85..f1b0d616d 100644 --- a/tests/user_api/test_run_python_script.py +++ b/tests/user_api/test_run_python_script.py @@ -23,27 +23,21 @@ def script(tmp_path): @pytest.fixture def mock_remote(): - """Mock File.from_local and flyte.with_runcontext so no real remote call happens.""" - mock_file = AsyncMock() + """Mock _Runner so no real remote call happens.""" mock_run = MagicMock() mock_runner = MagicMock() mock_runner.run.aio = AsyncMock(return_value=mock_run) mock_run.wait.aio = AsyncMock() - with ( - patch.object(flyte.io.File, "from_local", new_callable=AsyncMock, return_value=mock_file) as mock_from_local, - patch("flyte.with_runcontext", return_value=mock_runner) as mock_runcontext, - ): + with patch("flyte._run._Runner", return_value=mock_runner) as mock_runner_cls: yield { - "file": mock_file, - "from_local": mock_from_local, "run": mock_run, "runner": mock_runner, - "runcontext": mock_runcontext, + "runner_cls": mock_runner_cls, } -# --------------------------------------------------------------------------- +# ---------------------------------------------------------- ----------------- # _build_task # --------------------------------------------------------------------------- @@ -53,22 +47,22 @@ class TestBuildTask: def test_task_short_name(self): env = flyte.TaskEnvironment(name="test_env") - task = _build_task(env, timeout=3600, short_name="my_script") + task = _build_task(env, script_name="my_script.py", timeout=3600, short_name="my_script") assert task.short_name == "my_script" def test_task_short_name_custom(self): env = flyte.TaskEnvironment(name="test_env2") - task = _build_task(env, timeout=3600, short_name="custom_name") + task = _build_task(env, script_name="script.py", timeout=3600, short_name="custom_name") assert task.short_name == "custom_name" def test_task_timeout(self): env = flyte.TaskEnvironment(name="test_env3") - task = _build_task(env, timeout=7200, short_name="t") + task = _build_task(env, script_name="script.py", timeout=7200, short_name="t") assert task.timeout == timedelta(seconds=7200) def test_task_registered_in_env(self): env = flyte.TaskEnvironment(name="test_env4") - task = _build_task(env, timeout=3600, short_name="t") + task = _build_task(env, script_name="script.py", timeout=3600, short_name="t") assert task in env.tasks.values() @@ -126,25 +120,103 @@ def test_short_name_for_nested_script(self, tmp_path, mock_remote): # --------------------------------------------------------------------------- -# run_python_script -File.from_local usage +# run_python_script -code bundle # --------------------------------------------------------------------------- -class TestRunPythonScriptFileUpload: - """Tests that the script is uploaded via File.from_local.""" +class TestRunPythonScriptCodeBundle: + """Tests that the runner is configured with python_script copy_style.""" - def test_uses_file_from_local(self, script, mock_remote): - """Verify File.from_local is called with the resolved script path.""" + def test_runner_uses_python_script_copy_style(self, script, mock_remote): + """Verify _Runner is constructed with copy_style='python_script'.""" run_python_script(script) - mock_remote["from_local"].assert_awaited_once_with(script.resolve()) + mock_remote["runner_cls"].assert_called_once() + call_kwargs = mock_remote["runner_cls"].call_args[1] + assert call_kwargs["copy_style"] == "python_script" - def test_file_passed_to_runner(self, script, mock_remote): - """Verify the File object from from_local is passed to the runner.""" + def test_runner_bundle_relative_paths(self, script, mock_remote): + """Verify _Runner receives the script filename as bundle_relative_paths.""" run_python_script(script) - call_kwargs = mock_remote["runner"].run.aio.call_args[1] - assert call_kwargs["script_file"] is mock_remote["file"] + call_kwargs = mock_remote["runner_cls"].call_args[1] + assert call_kwargs["_bundle_relative_paths"] == (script.name,) + + def test_runner_bundle_from_dir(self, script, mock_remote): + """Verify _Runner receives the script's parent as bundle_from_dir.""" + run_python_script(script) + + call_kwargs = mock_remote["runner_cls"].call_args[1] + assert call_kwargs["_bundle_from_dir"] == script.resolve().parent + + def test_task_has_script_resolver(self, script, mock_remote): + """Verify the task has a ScriptTaskResolver attached.""" + from flyte._internal.resolvers.script import ScriptTaskResolver + + run_python_script(script) + + task_arg = mock_remote["runner"].run.aio.call_args[0][0] + assert isinstance(task_arg.task_resolver, ScriptTaskResolver) + assert task_arg.task_resolver._script_name == script.name + + def test_resolver_output_dir_none_by_default(self, script, mock_remote): + """Verify the resolver has output_dir=None when not specified.""" + run_python_script(script) + + task_arg = mock_remote["runner"].run.aio.call_args[0][0] + assert task_arg.task_resolver._output_dir is None + + def test_resolver_output_dir_passed_through(self, script, mock_remote): + """Verify the resolver receives the output_dir value.""" + run_python_script(script, output_dir="/tmp/results") + + task_arg = mock_remote["runner"].run.aio.call_args[0][0] + assert task_arg.task_resolver._output_dir == "/tmp/results" + + +# --------------------------------------------------------------------------- +# run_python_script -output_dir +# --------------------------------------------------------------------------- + + +class TestRunPythonScriptOutputDir: + """Tests that the output_dir parameter is propagated correctly.""" + + def test_output_dir_default_is_none(self, script, mock_remote): + """Without output_dir=, the resolver should have output_dir=None.""" + from flyte._internal.resolvers.script import ScriptTaskResolver + + run_python_script(script) + + task_arg = mock_remote["runner"].run.aio.call_args[0][0] + resolver = task_arg.task_resolver + assert isinstance(resolver, ScriptTaskResolver) + assert resolver._output_dir is None + + def test_output_dir_passed_to_resolver(self, script, mock_remote): + """output_dir= should be stored on the resolver for serialization.""" + run_python_script(script, output_dir="/tmp/output") + + task_arg = mock_remote["runner"].run.aio.call_args[0][0] + assert task_arg.task_resolver._output_dir == "/tmp/output" + + def test_resolver_loader_args_includes_output_dir(self): + """loader_args should include output_dir when set.""" + from flyte._internal.resolvers.script import ScriptTaskResolver + + resolver = ScriptTaskResolver("script.py", output_dir="/tmp/out", timeout=600) + args = resolver.loader_args(MagicMock()) + assert "output_dir" in args + idx = args.index("output_dir") + assert args[idx + 1] == "/tmp/out" + + def test_resolver_loader_args_excludes_output_dir_when_none(self): + """loader_args should not include output_dir when None.""" + from flyte._internal.resolvers.script import ScriptTaskResolver + + resolver = ScriptTaskResolver("script.py", timeout=600) + args = resolver.loader_args(MagicMock()) + assert "output_dir" not in args # --------------------------------------------------------------------------- @@ -275,20 +347,21 @@ def test_no_extra_args_defaults_to_empty_list(self, script, mock_remote): class TestRunPythonScriptRunContext: - """Tests that with_runcontext is called with correct parameters.""" + """Tests that _Runner is constructed with correct parameters.""" - def test_runcontext_mode_remote(self, script, mock_remote): + def test_runner_mode_remote(self, script, mock_remote): run_python_script(script) - mock_remote["runcontext"].assert_called_once_with( - mode="remote", - name=None, - interactive_mode=True, - ) + call_kwargs = mock_remote["runner_cls"].call_args[1] + assert call_kwargs["force_mode"] == "remote" + assert call_kwargs["name"] is None + assert call_kwargs["debug"] is False - def test_runcontext_passes_name(self, script, mock_remote): + def test_runner_passes_name(self, script, mock_remote): run_python_script(script, name="my-run") - mock_remote["runcontext"].assert_called_once_with( - mode="remote", - name="my-run", - interactive_mode=True, - ) + call_kwargs = mock_remote["runner_cls"].call_args[1] + assert call_kwargs["name"] == "my-run" + + def test_runner_passes_debug(self, script, mock_remote): + run_python_script(script, debug=True) + call_kwargs = mock_remote["runner_cls"].call_args[1] + assert call_kwargs["debug"] is True