diff --git a/examples/ml/eval_callback.py b/examples/ml/eval_callback.py index fdc68fa95..55c2a248e 100644 --- a/examples/ml/eval_callback.py +++ b/examples/ml/eval_callback.py @@ -31,13 +31,13 @@ from flyte.io import File image = flyte.Image.from_debian_base(name="lightning-eval").with_pip_packages( - "lightning==2.6.1", "flyteplugins-pytorch==2.0.2" + "lightning==2.6.1", "flyteplugins-pytorch==2.0.3" ) # Multi-node training: 2 nodes, 1 process per node train_env = flyte.TaskEnvironment( name="distributed-train", - resources=flyte.Resources(cpu=4, memory="25Gi", gpu="L4:1"), + resources=flyte.Resources(cpu=4, memory="25Gi", gpu="T4:1"), plugin_config=Elastic( nproc_per_node=1, nnodes=2, diff --git a/examples/plugins/torch_example.py b/examples/plugins/torch_example.py index 59c19cd8f..98b605944 100644 --- a/examples/plugins/torch_example.py +++ b/examples/plugins/torch_example.py @@ -10,18 +10,11 @@ import flyte -# Install flyteplugins-torch from the wheel for development. -# In production, you would just specify the package name and version. -# from flyte._image import DIST_FOLDER, PythonWheels -# image = flyte.Image.from_debian_base(name="torch").clone( -# addl_layer=PythonWheels(wheel_dir=DIST_FOLDER, package_name="flyteplugins-pytorch", pre=True) -# ) - image = flyte.Image.from_debian_base(name="torch").with_pip_packages("flyteplugins-pytorch") torch_env = flyte.TaskEnvironment( name="torch_env", - resources=flyte.Resources(cpu=(1, 2), memory=("1Gi", "2Gi")), + resources=flyte.Resources(cpu=(1, 2), memory=("1Gi", "2Gi"), gpu="T4:1"), plugin_config=Elastic( nproc_per_node=1, # if you want to do local testing set nnodes=1 @@ -106,6 +99,6 @@ def torch_distributed_train(epochs: int) -> typing.Optional[float]: if __name__ == "__main__": flyte.init_from_config() - run = flyte.with_runcontext(mode="remote").run(torch_distributed_train, epochs=3) + run = flyte.with_runcontext(mode="remote").run(torch_distributed_train, epochs=1000) print("run name:", run.name) print("run url:", run.url) diff --git a/examples/scripts/hello.py b/examples/scripts/hello.py new file mode 100644 index 000000000..a065c7164 --- /dev/null +++ b/examples/scripts/hello.py @@ -0,0 +1,20 @@ +""" +Run with: + +```bash +flyte run --follow python-script hello.py --output-dir output +``` +""" + +import os + + +def main(): + print("Hello, world!") + os.makedirs("output", exist_ok=True) + with open("output/hello.txt", "w") as f: + f.write("Hello, file!") + + +if __name__ == "__main__": + main() diff --git a/src/flyte/_code_bundle/__init__.py b/src/flyte/_code_bundle/__init__.py index d7c4bcc33..f34a23456 100644 --- a/src/flyte/_code_bundle/__init__.py +++ b/src/flyte/_code_bundle/__init__.py @@ -1,8 +1,20 @@ from ._ignore import GitIgnore, IgnoreGroup, StandardIgnore from ._utils import CopyFiles -from .bundle import build_code_bundle, build_pkl_bundle, download_bundle +from .bundle import ( + build_code_bundle, + build_code_bundle_from_relative_paths, + build_pkl_bundle, + download_bundle, +) -__all__ = ["CopyFiles", "build_code_bundle", "build_pkl_bundle", "default_ignores", "download_bundle"] +__all__ = [ + "CopyFiles", + "build_code_bundle", + "build_code_bundle_from_relative_paths", + "build_pkl_bundle", + "default_ignores", + "download_bundle", +] default_ignores = [GitIgnore, StandardIgnore, IgnoreGroup] diff --git a/src/flyte/_code_bundle/_utils.py b/src/flyte/_code_bundle/_utils.py index 5ebc45f06..487691c4c 100644 --- a/src/flyte/_code_bundle/_utils.py +++ b/src/flyte/_code_bundle/_utils.py @@ -22,7 +22,7 @@ from ._ignore import Ignore, IgnoreGroup, StandardIgnore -CopyFiles = Literal["loaded_modules", "all", "none"] +CopyFiles = Literal["loaded_modules", "all", "none", "custom"] def compress_scripts(source_path: str, destination: str, modules: List[ModuleType]): diff --git a/src/flyte/_debug/vscode.py b/src/flyte/_debug/vscode.py index d1e93271f..9694f35b4 100644 --- a/src/flyte/_debug/vscode.py +++ b/src/flyte/_debug/vscode.py @@ -206,6 +206,12 @@ def prepare_launch_json(ctx: click.Context, pid: int): ctx.params["version"], "--run-base-dir", ctx.params["run_base_dir"], + "--raw-data-path", + ctx.params["raw_data_path"], + "--checkpoint-path", + ctx.params["checkpoint_path"], + "--prev-checkpoint", + ctx.params["prev_checkpoint"], "--name", name, "--run-name", diff --git a/src/flyte/_internal/resolvers/internal.py b/src/flyte/_internal/resolvers/internal.py new file mode 100644 index 000000000..5c22bb7d5 --- /dev/null +++ b/src/flyte/_internal/resolvers/internal.py @@ -0,0 +1,56 @@ +"""Generic resolver for internal Flyte tasks. + +Stores an import path to a task-builder function and arbitrary keyword +arguments. At runtime ``load_task`` dynamically imports the builder and +calls it with the stored kwargs, recreating a lightweight task without +pickling. This is the same mechanism used by ``run_python_script`` and +can be reused for prefetch, custom bundling, and other internal tasks. +""" + +import importlib +from pathlib import Path +from typing import Any, Dict, List, Optional + +from flyte._internal.resolvers.common import Resolver +from flyte._task import TaskTemplate + + +class InternalTaskResolver(Resolver): + """Resolve an internal task by dynamically importing its builder. + + During serialization the resolver stores: + + * ``task_builder`` - fully-qualified import path of a callable that + returns a :class:`TaskTemplate` (e.g. + ``"flyte._run_python_script._build_script_runner_task"``). + * Arbitrary keyword arguments forwarded to the builder. + + At runtime :meth:`load_task` re-imports the builder and calls it with + the stored kwargs. + """ + + def __init__(self, task_builder: str = "", **kwargs: Any): + self._task_builder = task_builder + self._kwargs = kwargs + + @property + def import_path(self) -> str: + return "flyte._internal.resolvers.internal.InternalTaskResolver" + + def load_task(self, loader_args: List[str]) -> TaskTemplate: + args_iter = iter(loader_args) + parsed: Dict[str, str] = dict(zip(args_iter, args_iter)) + + builder_path = parsed.pop("task_builder") + module_path, func_name = builder_path.rsplit(".", 1) + module = importlib.import_module(module_path) + builder = getattr(module, func_name) + + return builder(**parsed) + + def loader_args(self, task: TaskTemplate, root_dir: Optional[Path] = None) -> List[str]: + args = ["task_builder", self._task_builder] + for key, value in self._kwargs.items(): + if value is not None: + args.extend([key, str(value)]) + return args diff --git a/src/flyte/_run.py b/src/flyte/_run.py index 322766fd3..70d770fe2 100644 --- a/src/flyte/_run.py +++ b/src/flyte/_run.py @@ -116,10 +116,14 @@ def __init__( preserve_original_types: bool | None = None, debug: bool = False, _tracker: Any = None, + _bundle_relative_paths: tuple[str, ...] | None = None, + _bundle_from_dir: pathlib.Path | None = None, ): from flyte._tools import ipython_check self._tracker = _tracker + self._bundle_relative_paths = _bundle_relative_paths + self._bundle_from_dir = _bundle_from_dir init_config = _get_init_config() client = init_config.client if init_config else None if not force_mode and client is not None: @@ -169,7 +173,7 @@ async def _run_remote(self, obj: TaskTemplate[P, R, F] | LazyEntity, *args: P.ar from flyte.remote import Run from flyte.remote._task import LazyEntity, TaskDetails - from ._code_bundle import build_code_bundle, build_pkl_bundle + from ._code_bundle import build_code_bundle, build_code_bundle_from_relative_paths, build_pkl_bundle from ._deploy import build_images from ._internal.runtime.convert import convert_from_native_to_inputs from ._internal.runtime.task_serde import translate_task_to_wire @@ -223,16 +227,24 @@ async def _run_remote(self, obj: TaskTemplate[P, R, F] | LazyEntity, *args: P.ar upload_to_controlplane=not self._dry_run, copy_bundle_to=self._copy_bundle_to, ) + elif self._copy_files == "custom": + if not self._bundle_relative_paths or not self._bundle_from_dir: + raise ValueError("copy_style='custom' requires _bundle_relative_paths and _bundle_from_dir") + code_bundle = await build_code_bundle_from_relative_paths( + self._bundle_relative_paths, + from_dir=self._bundle_from_dir, + dryrun=self._dry_run, + copy_bundle_to=self._copy_bundle_to, + ) + elif self._copy_files != "none": + code_bundle = await build_code_bundle( + from_dir=cfg.root_dir, + dryrun=self._dry_run, + copy_bundle_to=self._copy_bundle_to, + copy_style=self._copy_files, + ) else: - if self._copy_files != "none": - code_bundle = await build_code_bundle( - from_dir=cfg.root_dir, - dryrun=self._dry_run, - copy_bundle_to=self._copy_bundle_to, - copy_style=self._copy_files, - ) - else: - code_bundle = None + code_bundle = None if not self._disable_run_cache: _RUN_CACHE[_CacheKey(obj_id=id(obj), dry_run=self._dry_run)] = _CacheValue( code_bundle=code_bundle, image_cache=image_cache @@ -435,7 +447,7 @@ async def _run_hybrid(self, obj: TaskTemplate[P, R, F], *args: P.args, **kwargs: over the longer term we will productize this. """ import flyte.report - from flyte._code_bundle import build_code_bundle, build_pkl_bundle + from flyte._code_bundle import build_code_bundle, build_code_bundle_from_relative_paths, build_pkl_bundle from flyte._deploy import build_images from flyte.models import RawDataPath from flyte.storage import ABFS, GCS, S3 @@ -469,16 +481,24 @@ async def _run_hybrid(self, obj: TaskTemplate[P, R, F], *args: P.args, **kwargs: upload_to_controlplane=not self._dry_run, copy_bundle_to=self._copy_bundle_to, ) + elif self._copy_files == "custom": + if not self._bundle_relative_paths or not self._bundle_from_dir: + raise ValueError("copy_style='custom' requires _bundle_relative_paths and _bundle_from_dir") + code_bundle = await build_code_bundle_from_relative_paths( + self._bundle_relative_paths, + from_dir=self._bundle_from_dir, + dryrun=self._dry_run, + copy_bundle_to=self._copy_bundle_to, + ) + elif self._copy_files != "none": + code_bundle = await build_code_bundle( + from_dir=cfg.root_dir, + dryrun=self._dry_run, + copy_bundle_to=self._copy_bundle_to, + copy_style=self._copy_files, + ) else: - if self._copy_files != "none": - code_bundle = await build_code_bundle( - from_dir=cfg.root_dir, - dryrun=self._dry_run, - copy_bundle_to=self._copy_bundle_to, - copy_style=self._copy_files, - ) - else: - code_bundle = None + code_bundle = None version = self._version or ( code_bundle.computed_version if code_bundle and code_bundle.computed_version else None @@ -802,6 +822,8 @@ async def example_task(x: int, y: str) -> str: """ if mode == "hybrid" and not name and not run_base_dir: raise ValueError("Run name and run base dir are required for hybrid mode") + if copy_style == "custom": + raise ValueError("copy_style='custom' is not yet supported through with_runcontext.") if copy_style == "none" and not version: raise ValueError("Version is required when copy_style is 'none'") diff --git a/src/flyte/_run_python_script.py b/src/flyte/_run_python_script.py index c2b664797..449e55447 100644 --- a/src/flyte/_run_python_script.py +++ b/src/flyte/_run_python_script.py @@ -14,9 +14,11 @@ # class object at decoration time, not a deferred string. import pathlib +from dataclasses import dataclass from datetime import timedelta from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union +import flyte.io from flyte.syncify import syncify if TYPE_CHECKING: @@ -24,35 +26,85 @@ from flyte.remote import Run +@dataclass +class PythonScriptOutput: + exit_code: int + stdout: str + output_dir: Optional[flyte.io.Dir] + + def _build_task( env: Any, + script_name: str, timeout: int, short_name: str, + output_dir: "Optional[str]" = None, + task_resolver: Any = None, ) -> Any: - """Build the execute_script task. + """Build the ``execute_script`` task for serialization. - Defined separately so that the ``File`` import is evaluated eagerly - and available as a real type when ``@env.task`` inspects annotations. + The *script_name* is captured via closure for local execution. When + running remotely the :class:`InternalTaskResolver` recreates the task from + the loader args embedded in the container command, so the closure value + is not carried over the wire. """ - from flyte.io import File - task_timeout = timedelta(seconds=timeout) - @env.task(timeout=task_timeout, short_name=short_name) - async def execute_script(script_file: File, args: list, task_timeout: int) -> dict: + @env.task(timeout=task_timeout, short_name=short_name, task_resolver=task_resolver) + async def execute_script(args: list[str], task_timeout: int) -> PythonScriptOutput: """Execute a Python script on a remote machine.""" + import os import subprocess import sys + import tempfile + + tail_bytes = 1000 + cmd = [sys.executable, script_name, *args] + + with tempfile.TemporaryFile(mode="w+") as out_f, tempfile.TemporaryFile(mode="w+") as err_f: + result = subprocess.run( # noqa: ASYNC221 + cmd, + stdout=out_f, + stderr=err_f, + check=False, + timeout=task_timeout - 60, + ) + + for f in (out_f, err_f): + f.seek(0, os.SEEK_END) + pos = f.tell() + f.seek(max(0, pos - tail_bytes)) + + stdout_tail = out_f.read() + stderr_tail = err_f.read() + + if result.returncode != 0: + raise RuntimeError(f"Script failed with exit code {result.returncode}, stderr: {stderr_tail}") + + _dir: Optional[flyte.io.Dir] = None + if output_dir: + _dir = await flyte.io.Dir.from_local(output_dir) + + return PythonScriptOutput( + exit_code=result.returncode, + stdout=stdout_tail, + output_dir=_dir, + ) - local_path = await script_file.download() - cmd = [sys.executable, local_path, *args] - result = subprocess.run(cmd, text=True, check=False, timeout=task_timeout - 60) # noqa: ASYNC221 + return execute_script - return { - "exit_code": result.returncode, - } - return execute_script +def _build_script_runner_task(script_name: str, output_dir: "Optional[str]" = None, timeout: str = "3600") -> Any: + """Build the ``execute_script`` task at runtime (called by :class:`InternalTaskResolver`). + + Creates a minimal :class:`~flyte.TaskEnvironment` — only the function + signature matters here because the container already has the correct + image and resources. + """ + import flyte + + env = flyte.TaskEnvironment(name="python_script") + return _build_task(env, script_name, int(timeout), short_name=script_name, output_dir=output_dir) @syncify @@ -69,11 +121,15 @@ async def run_python_script( queue: "Optional[str]" = None, wait: bool = False, name: "Optional[str]" = None, + debug: bool = False, + output_dir: "Optional[str]" = None, ) -> "Run": """Package and run a Python script on a remote Flyte cluster. - Uploads the script via :class:`~flyte.io.File`, passes it as a typed input - to a Flyte task, and executes it remotely with the requested resources. + Bundles the script into a Flyte code bundle and executes it remotely + with the requested resources. Unlike ``interactive_mode`` (which + pickles the task), this approach uses an :class:`InternalTaskResolver` + so the task can be properly debugged with ``debug=True``. Project and domain are read from the init config (set via ``flyte.init()`` or ``flyte.init_from_config()``), consistent with ``flyte.run()``. @@ -96,6 +152,9 @@ async def run_python_script( :param queue: Flyte queue / cluster override. :param wait: If True, block until execution completes before returning. :param name: Run name. If omitted, a random name is generated. + :param debug: If True, run the task as a VS Code debug task, starting a + code-server in the container so you can connect via the UI to + interactively debug/run the task. :return: A :class:`~flyte.remote.Run` handle for the remote execution. Example:: @@ -120,7 +179,8 @@ async def run_python_script( run = flyte.run_python_script(Path("analysis.py"), image=img) """ import flyte - from flyte.io import File + from flyte._internal.resolvers.internal import InternalTaskResolver + from flyte._run import _Runner script = pathlib.Path(script).resolve() if not script.exists(): @@ -153,20 +213,29 @@ async def run_python_script( env_kwargs["queue"] = queue env = flyte.TaskEnvironment(**env_kwargs) - # Build task (in a separate function so File annotation resolves correctly) + # Build task with the InternalTaskResolver so the runner knows how to + # serialize and reload it without pickling. + resolver = InternalTaskResolver( + "flyte._run_python_script._build_script_runner_task", + script_name=script.name, + output_dir=output_dir, + timeout=timeout, + ) task_short_name = name or script.stem - execute_script = _build_task(env, timeout, short_name=task_short_name) - - script_file: File = await File.from_local(script) + execute_script = _build_task( + env, script.name, timeout, short_name=task_short_name, output_dir=output_dir, task_resolver=resolver + ) - runner = flyte.with_runcontext( - mode="remote", + runner = _Runner( + force_mode="remote", name=name, - interactive_mode=True, + debug=debug, + copy_style="custom", + _bundle_relative_paths=(script.name,), + _bundle_from_dir=script.parent, ) run = await runner.run.aio( execute_script, - script_file=script_file, args=extra_args or [], task_timeout=timeout, ) diff --git a/src/flyte/_task.py b/src/flyte/_task.py index c21477cb2..05eb6fc53 100644 --- a/src/flyte/_task.py +++ b/src/flyte/_task.py @@ -470,6 +470,7 @@ class AsyncFunctionTaskTemplate(TaskTemplate[P, R, F]): func: F plugin_config: Optional[Any] = None # This is used to pass plugin specific configuration debuggable: bool = True + task_resolver: Optional[Any] = None def __post_init__(self): super().__post_init__() @@ -556,21 +557,23 @@ def container_args(self, serialize_context: SerializationContext) -> List[str]: if not serialize_context.code_bundle or not serialize_context.code_bundle.pkl: # If we do not have a code bundle, or if we have one, but it is not a pkl, we need to add the resolver + resolver = self.task_resolver + if resolver is None: + from flyte._internal.resolvers.default import DefaultTaskResolver - from flyte._internal.resolvers.default import DefaultTaskResolver + resolver = DefaultTaskResolver() if not serialize_context.root_dir: raise RuntimeSystemError( "SerializationError", "Root dir is required for default task resolver when no code bundle is provided.", ) - _task_resolver = DefaultTaskResolver() args = [ *args, *[ "--resolver", - _task_resolver.import_path, - *_task_resolver.loader_args(task=self, root_dir=serialize_context.root_dir), + resolver.import_path, + *resolver.loader_args(task=self, root_dir=serialize_context.root_dir), ], ] diff --git a/src/flyte/_task_environment.py b/src/flyte/_task_environment.py index 9bed35bd1..7c989ccd0 100644 --- a/src/flyte/_task_environment.py +++ b/src/flyte/_task_environment.py @@ -170,6 +170,7 @@ def task( queue: Optional[str] = None, triggers: Tuple[Trigger, ...] | Trigger = (), links: Tuple[Link, ...] | Link = (), + task_resolver: Any | None = None, ) -> Callable[[Callable[P, R]], AsyncFunctionTaskTemplate[P, R, Callable[P, R]]]: ... @overload @@ -195,6 +196,7 @@ def task( queue: Optional[str] = None, triggers: Tuple[Trigger, ...] | Trigger = (), links: Tuple[Link, ...] | Link = (), + task_resolver: Any | None = None, ) -> Callable[[F], AsyncFunctionTaskTemplate[P, R, F]] | AsyncFunctionTaskTemplate[P, R, F]: """ Decorate a function to be a task. @@ -277,6 +279,7 @@ def decorator(func: F) -> AsyncFunctionTaskTemplate[P, R, F]: interruptible=interruptible if interruptible is not None else self.interruptible, triggers=triggers if isinstance(triggers, tuple) else (triggers,), links=links if isinstance(links, tuple) else (links,), + task_resolver=task_resolver, ) self._tasks[task_name] = tmpl return tmpl diff --git a/src/flyte/cli/_run.py b/src/flyte/cli/_run.py index f100899f3..854daeccd 100644 --- a/src/flyte/cli/_run.py +++ b/src/flyte/cli/_run.py @@ -14,6 +14,7 @@ from .._code_bundle._utils import CopyFiles from .._task import TaskTemplate from ..remote import Run +from ..syncify import syncify from . import _common as common from ._params import to_click_option @@ -22,6 +23,7 @@ initialize_config = common.initialize_config +@syncify async def _render_debug_url(console, result: Run, config: common.CLIConfig) -> None: """Poll the run for the VS Code Debugger URL and print it.""" from flyte._debug.client import watch_for_vscode_url @@ -331,7 +333,7 @@ async def _render_remote_success(self, console, result, config): console.print(common.get_panel("Remote Run", run_info, config.output_format)) if self.run_args.debug: - await _render_debug_url(console, result, config) + await _render_debug_url.aio(console, result, config) if self.run_args.follow: from flyte._status import status @@ -554,7 +556,7 @@ async def _render_remote_success(self, console, result, config): console.print(common.get_panel("Remote Run", run_info, config.output_format)) if self.run_args.debug: - await _render_debug_url(console, result, config) + await _render_debug_url.aio(console, result, config) if self.run_args.follow: from flyte._status import status diff --git a/src/flyte/cli/_run_python_script.py b/src/flyte/cli/_run_python_script.py index 04a1d274d..b83628bc6 100644 --- a/src/flyte/cli/_run_python_script.py +++ b/src/flyte/cli/_run_python_script.py @@ -59,6 +59,12 @@ class PythonScriptCommand(CommandBase): help="Extra arguments passed to the script (comma-separated).", ) @click.option("--queue", type=str, default=None, help="Flyte queue / cluster override.") +@click.option( + "--output-dir", + type=str, + default=None, + help="Directory path (inside the container) to upload as output after the script finishes.", +) @click.pass_obj def python_script( cfg: common.CLIConfig, @@ -72,6 +78,7 @@ def python_script( timeout: int, extra_args: str | None, queue: str | None, + output_dir: str | None, ) -> None: """Run a Python script on a remote Flyte cluster. @@ -116,6 +123,7 @@ def python_script( name = run_args.name if run_args else None project = run_args.project if run_args else None domain = run_args.domain if run_args else None + debug = run_args.debug if run_args else False # Initialize flyte config (like prefetch does) initialize_config( @@ -150,6 +158,8 @@ def python_script( queue=queue, wait=False, name=name, + debug=debug, + output_dir=output_dir, ) url = run.url @@ -158,6 +168,11 @@ def python_script( f" Check the console for status at [link={url}]{url}[/link]" ) + if debug: + from flyte.cli._run import _render_debug_url + + _render_debug_url(console, run, cfg) + if follow: run.wait() try: diff --git a/src/flyte/io/_file.py b/src/flyte/io/_file.py index 1d1278f92..bc1e9c726 100644 --- a/src/flyte/io/_file.py +++ b/src/flyte/io/_file.py @@ -455,7 +455,7 @@ async def stream_read(f: File) -> str: # Fall back to aiofiles fs = storage.get_underlying_filesystem(path=self.path) if "file" in fs.protocol: - async with aiofiles.open(self.path, mode=mode, **kwargs) as f: + async with aiofiles.open(self.path, mode=mode, **kwargs) as f: # type: ignore[call-overload] yield f return raise diff --git a/tests/user_api/test_run_python_script.py b/tests/user_api/test_run_python_script.py index 16d554b85..84b68954e 100644 --- a/tests/user_api/test_run_python_script.py +++ b/tests/user_api/test_run_python_script.py @@ -23,27 +23,21 @@ def script(tmp_path): @pytest.fixture def mock_remote(): - """Mock File.from_local and flyte.with_runcontext so no real remote call happens.""" - mock_file = AsyncMock() + """Mock _Runner so no real remote call happens.""" mock_run = MagicMock() mock_runner = MagicMock() mock_runner.run.aio = AsyncMock(return_value=mock_run) mock_run.wait.aio = AsyncMock() - with ( - patch.object(flyte.io.File, "from_local", new_callable=AsyncMock, return_value=mock_file) as mock_from_local, - patch("flyte.with_runcontext", return_value=mock_runner) as mock_runcontext, - ): + with patch("flyte._run._Runner", return_value=mock_runner) as mock_runner_cls: yield { - "file": mock_file, - "from_local": mock_from_local, "run": mock_run, "runner": mock_runner, - "runcontext": mock_runcontext, + "runner_cls": mock_runner_cls, } -# --------------------------------------------------------------------------- +# ---------------------------------------------------------- ----------------- # _build_task # --------------------------------------------------------------------------- @@ -53,22 +47,22 @@ class TestBuildTask: def test_task_short_name(self): env = flyte.TaskEnvironment(name="test_env") - task = _build_task(env, timeout=3600, short_name="my_script") + task = _build_task(env, script_name="my_script.py", timeout=3600, short_name="my_script") assert task.short_name == "my_script" def test_task_short_name_custom(self): env = flyte.TaskEnvironment(name="test_env2") - task = _build_task(env, timeout=3600, short_name="custom_name") + task = _build_task(env, script_name="script.py", timeout=3600, short_name="custom_name") assert task.short_name == "custom_name" def test_task_timeout(self): env = flyte.TaskEnvironment(name="test_env3") - task = _build_task(env, timeout=7200, short_name="t") + task = _build_task(env, script_name="script.py", timeout=7200, short_name="t") assert task.timeout == timedelta(seconds=7200) def test_task_registered_in_env(self): env = flyte.TaskEnvironment(name="test_env4") - task = _build_task(env, timeout=3600, short_name="t") + task = _build_task(env, script_name="script.py", timeout=3600, short_name="t") assert task in env.tasks.values() @@ -126,25 +120,112 @@ def test_short_name_for_nested_script(self, tmp_path, mock_remote): # --------------------------------------------------------------------------- -# run_python_script -File.from_local usage +# run_python_script -code bundle # --------------------------------------------------------------------------- -class TestRunPythonScriptFileUpload: - """Tests that the script is uploaded via File.from_local.""" +class TestRunPythonScriptCodeBundle: + """Tests that the runner is configured with custom copy_style.""" - def test_uses_file_from_local(self, script, mock_remote): - """Verify File.from_local is called with the resolved script path.""" + def test_runner_uses_custom_copy_style(self, script, mock_remote): + """Verify _Runner is constructed with copy_style='custom'.""" run_python_script(script) - mock_remote["from_local"].assert_awaited_once_with(script.resolve()) + mock_remote["runner_cls"].assert_called_once() + call_kwargs = mock_remote["runner_cls"].call_args[1] + assert call_kwargs["copy_style"] == "custom" - def test_file_passed_to_runner(self, script, mock_remote): - """Verify the File object from from_local is passed to the runner.""" + def test_runner_bundle_relative_paths(self, script, mock_remote): + """Verify _Runner receives the script filename as bundle_relative_paths.""" run_python_script(script) - call_kwargs = mock_remote["runner"].run.aio.call_args[1] - assert call_kwargs["script_file"] is mock_remote["file"] + call_kwargs = mock_remote["runner_cls"].call_args[1] + assert call_kwargs["_bundle_relative_paths"] == (script.name,) + + def test_runner_bundle_from_dir(self, script, mock_remote): + """Verify _Runner receives the script's parent as bundle_from_dir.""" + run_python_script(script) + + call_kwargs = mock_remote["runner_cls"].call_args[1] + assert call_kwargs["_bundle_from_dir"] == script.resolve().parent + + def test_task_has_internal_resolver(self, script, mock_remote): + """Verify the task has an InternalTaskResolver attached.""" + from flyte._internal.resolvers.internal import InternalTaskResolver + + run_python_script(script) + + task_arg = mock_remote["runner"].run.aio.call_args[0][0] + assert isinstance(task_arg.task_resolver, InternalTaskResolver) + assert task_arg.task_resolver._kwargs["script_name"] == script.name + + def test_resolver_output_dir_none_by_default(self, script, mock_remote): + """Verify the resolver has output_dir=None when not specified.""" + run_python_script(script) + + task_arg = mock_remote["runner"].run.aio.call_args[0][0] + assert task_arg.task_resolver._kwargs.get("output_dir") is None + + def test_resolver_output_dir_passed_through(self, script, mock_remote): + """Verify the resolver receives the output_dir value.""" + run_python_script(script, output_dir="/tmp/results") + + task_arg = mock_remote["runner"].run.aio.call_args[0][0] + assert task_arg.task_resolver._kwargs["output_dir"] == "/tmp/results" + + +# --------------------------------------------------------------------------- +# run_python_script -output_dir +# --------------------------------------------------------------------------- + + +class TestRunPythonScriptOutputDir: + """Tests that the output_dir parameter is propagated correctly.""" + + def test_output_dir_default_is_none(self, script, mock_remote): + """Without output_dir=, the resolver should have output_dir=None.""" + from flyte._internal.resolvers.internal import InternalTaskResolver + + run_python_script(script) + + task_arg = mock_remote["runner"].run.aio.call_args[0][0] + resolver = task_arg.task_resolver + assert isinstance(resolver, InternalTaskResolver) + assert resolver._kwargs.get("output_dir") is None + + def test_output_dir_passed_to_resolver(self, script, mock_remote): + """output_dir= should be stored on the resolver for serialization.""" + run_python_script(script, output_dir="/tmp/output") + + task_arg = mock_remote["runner"].run.aio.call_args[0][0] + assert task_arg.task_resolver._kwargs["output_dir"] == "/tmp/output" + + def test_resolver_loader_args_includes_output_dir(self): + """loader_args should include output_dir when set.""" + from flyte._internal.resolvers.internal import InternalTaskResolver + + resolver = InternalTaskResolver( + "flyte._run_python_script._build_script_runner_task", + script_name="script.py", + output_dir="/tmp/out", + timeout=600, + ) + args = resolver.loader_args(MagicMock()) + assert "output_dir" in args + idx = args.index("output_dir") + assert args[idx + 1] == "/tmp/out" + + def test_resolver_loader_args_excludes_output_dir_when_none(self): + """loader_args should not include output_dir when None.""" + from flyte._internal.resolvers.internal import InternalTaskResolver + + resolver = InternalTaskResolver( + "flyte._run_python_script._build_script_runner_task", + script_name="script.py", + timeout=600, + ) + args = resolver.loader_args(MagicMock()) + assert "output_dir" not in args # --------------------------------------------------------------------------- @@ -275,20 +356,21 @@ def test_no_extra_args_defaults_to_empty_list(self, script, mock_remote): class TestRunPythonScriptRunContext: - """Tests that with_runcontext is called with correct parameters.""" + """Tests that _Runner is constructed with correct parameters.""" - def test_runcontext_mode_remote(self, script, mock_remote): + def test_runner_mode_remote(self, script, mock_remote): run_python_script(script) - mock_remote["runcontext"].assert_called_once_with( - mode="remote", - name=None, - interactive_mode=True, - ) + call_kwargs = mock_remote["runner_cls"].call_args[1] + assert call_kwargs["force_mode"] == "remote" + assert call_kwargs["name"] is None + assert call_kwargs["debug"] is False - def test_runcontext_passes_name(self, script, mock_remote): + def test_runner_passes_name(self, script, mock_remote): run_python_script(script, name="my-run") - mock_remote["runcontext"].assert_called_once_with( - mode="remote", - name="my-run", - interactive_mode=True, - ) + call_kwargs = mock_remote["runner_cls"].call_args[1] + assert call_kwargs["name"] == "my-run" + + def test_runner_passes_debug(self, script, mock_remote): + run_python_script(script, debug=True) + call_kwargs = mock_remote["runner_cls"].call_args[1] + assert call_kwargs["debug"] is True