Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions examples/ml/eval_callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,13 +31,13 @@
from flyte.io import File

image = flyte.Image.from_debian_base(name="lightning-eval").with_pip_packages(
"lightning==2.6.1", "flyteplugins-pytorch==2.0.2"
"lightning==2.6.1", "flyteplugins-pytorch==2.0.3"
)

# Multi-node training: 2 nodes, 1 process per node
train_env = flyte.TaskEnvironment(
name="distributed-train",
resources=flyte.Resources(cpu=4, memory="25Gi", gpu="L4:1"),
resources=flyte.Resources(cpu=4, memory="25Gi", gpu="T4:1"),
plugin_config=Elastic(
nproc_per_node=1,
nnodes=2,
Expand Down
34 changes: 19 additions & 15 deletions examples/plugins/torch_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,28 +5,24 @@
import torch.nn as nn
import torch.optim as optim
from flyteplugins.pytorch.task import Elastic
from flyteplugins.wandb import get_wandb_run, wandb_config, wandb_init
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader, DistributedSampler, TensorDataset

import flyte

# Install flyteplugins-torch from the wheel for development.
# In production, you would just specify the package name and version.
# from flyte._image import DIST_FOLDER, PythonWheels
# image = flyte.Image.from_debian_base(name="torch").clone(
# addl_layer=PythonWheels(wheel_dir=DIST_FOLDER, package_name="flyteplugins-pytorch", pre=True)
# )

image = flyte.Image.from_debian_base(name="torch").with_pip_packages("flyteplugins-pytorch")
image = flyte.Image.from_debian_base(name="torch").with_pip_packages(
"flyteplugins-pytorch", "flyteplugins-wandb"
)

torch_env = flyte.TaskEnvironment(
name="torch_env",
resources=flyte.Resources(cpu=(1, 2), memory=("1Gi", "2Gi")),
resources=flyte.Resources(cpu=(1, 2), memory=("1Gi", "2Gi"), gpu="T4:1"),
plugin_config=Elastic(
nproc_per_node=1,
# if you want to do local testing set nnodes=1
nnodes=2,
),
secrets=flyte.Secret(key="NIELS_WANDB_API_KEY", as_env_var="WANDB_API_KEY"),
image=image,
)

Expand All @@ -45,9 +41,9 @@ def prepare_dataloader(rank: int, world_size: int, batch_size: int = 2) -> DataL
Prepare a DataLoader with a DistributedSampler so each rank
gets a shard of the dataset.
"""
# Dummy dataset
x_train = torch.tensor([[1.0], [2.0], [3.0], [4.0]])
y_train = torch.tensor([[3.0], [5.0], [7.0], [9.0]])
n_samples = 100
x_train = torch.randn(n_samples, 1)
y_train = 2.0 * x_train + 1.0 + 0.1 * torch.randn(n_samples, 1)
dataset = TensorDataset(x_train, y_train)

# Distributed-aware sampler
Expand Down Expand Up @@ -76,8 +72,9 @@ def train_loop(epochs: int = 3) -> float:
optimizer = optim.SGD(model.parameters(), lr=0.01)

final_loss = 0.0
wandb_run = get_wandb_run()

for _ in range(epochs):
for epoch in range(epochs):
for x, y in dataloader:
outputs = model(x)
loss = criterion(outputs, y)
Expand All @@ -87,12 +84,16 @@ def train_loop(epochs: int = 3) -> float:
optimizer.step()

final_loss = loss.item()

if wandb_run:
wandb_run.log({"loss": final_loss, "epoch": epoch})
if torch.distributed.get_rank() == 0:
print(f"Loss: {final_loss}")

return final_loss


@wandb_init
@torch_env.task
def torch_distributed_train(epochs: int) -> typing.Optional[float]:
"""
Expand All @@ -106,6 +107,9 @@ def torch_distributed_train(epochs: int) -> typing.Optional[float]:

if __name__ == "__main__":
flyte.init_from_config()
run = flyte.with_runcontext(mode="remote").run(torch_distributed_train, epochs=3)
run = flyte.with_runcontext(
mode="remote",
custom_context=wandb_config(project="torch-distributed-training"),
).run(torch_distributed_train, epochs=1_000_000)
print("run name:", run.name)
print("run url:", run.url)
20 changes: 20 additions & 0 deletions examples/scripts/hello.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
"""
Run with:

```bash
flyte run --follow python-script hello.py --output-dir output
```
"""

import os


def main(output_dir: str = "output") -> None:
    """Write a greeting to stdout and a small text file.

    Args:
        output_dir: Directory to create (if missing) and write
            ``hello.txt`` into. Defaults to ``"output"`` to match the
            ``flyte run --output-dir output`` example in the module
            docstring.
    """
    print("Hello, world!")
    # exist_ok avoids a crash when the directory was created by a prior run
    os.makedirs(output_dir, exist_ok=True)
    with open(os.path.join(output_dir, "hello.txt"), "w") as f:
        f.write("Hello, file!")


if __name__ == "__main__":
    main()
2 changes: 1 addition & 1 deletion src/flyte/_code_bundle/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@

from ._ignore import Ignore, IgnoreGroup, StandardIgnore

CopyFiles = Literal["loaded_modules", "all", "none"]
CopyFiles = Literal["loaded_modules", "all", "none", "python_script"]


def compress_scripts(source_path: str, destination: str, modules: List[ModuleType]):
Expand Down
48 changes: 48 additions & 0 deletions src/flyte/_internal/resolvers/script.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
"""Resolver for python-script tasks.

Bundles a plain Python script into the code bundle and recreates
the ``execute_script`` task at runtime so the Flyte entrypoint can
run it without pickling.
"""

from pathlib import Path
from typing import List, Optional

from flyte._internal.resolvers.common import Resolver
from flyte._task import TaskTemplate


class ScriptTaskResolver(Resolver):
    """Resolve a bundled Python script into an executable task.

    At serialization time the resolver records the script filename,
    optional output directory, and timeout as flat key/value loader
    args. At runtime ``load_task`` rebuilds a lightweight task that
    runs the script via ``subprocess``.
    """

    def __init__(self, script_name: str = "", output_dir: Optional[str] = None, timeout: int = 3600):
        self._script_name = script_name
        self._output_dir = output_dir
        self._timeout = timeout

    @property
    def import_path(self) -> str:
        return "flyte._internal.resolvers.script.ScriptTaskResolver"

    def load_task(self, loader_args: List[str]) -> TaskTemplate:
        # loader_args is a flat [key, value, key, value, ...] list;
        # fold consecutive pairs back into a dict.
        parsed = {loader_args[i]: loader_args[i + 1] for i in range(0, len(loader_args), 2)}

        from flyte._run_python_script import _build_script_runner_task

        return _build_script_runner_task(
            parsed["script"],
            parsed.get("output_dir", None),
            int(parsed.get("timeout", "3600")),
        )

    def loader_args(self, task: TaskTemplate, root_dir: Optional[Path] = None) -> List[str]:
        # Serialize as key/value pairs; output_dir is omitted when unset
        # so load_task's .get() default round-trips it as None.
        pairs = [("script", self._script_name)]
        if self._output_dir is not None:
            pairs.append(("output_dir", self._output_dir))
        pairs.append(("timeout", str(self._timeout)))
        return [item for pair in pairs for item in pair]
62 changes: 44 additions & 18 deletions src/flyte/_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,10 +116,14 @@ def __init__(
preserve_original_types: bool | None = None,
debug: bool = False,
_tracker: Any = None,
_bundle_relative_paths: tuple[str, ...] | None = None,
_bundle_from_dir: pathlib.Path | None = None,
):
from flyte._tools import ipython_check

self._tracker = _tracker
self._bundle_relative_paths = _bundle_relative_paths
self._bundle_from_dir = _bundle_from_dir
init_config = _get_init_config()
client = init_config.client if init_config else None
if not force_mode and client is not None:
Expand Down Expand Up @@ -223,16 +227,28 @@ async def _run_remote(self, obj: TaskTemplate[P, R, F] | LazyEntity, *args: P.ar
upload_to_controlplane=not self._dry_run,
copy_bundle_to=self._copy_bundle_to,
)
else:
if self._copy_files != "none":
code_bundle = await build_code_bundle(
from_dir=cfg.root_dir,
dryrun=self._dry_run,
copy_bundle_to=self._copy_bundle_to,
copy_style=self._copy_files,
elif self._copy_files == "python_script":
from ._code_bundle.bundle import build_code_bundle_from_relative_paths

if not self._bundle_relative_paths or not self._bundle_from_dir:
raise ValueError(
"copy_style='python_script' requires _bundle_relative_paths and _bundle_from_dir"
)
else:
code_bundle = None
code_bundle = await build_code_bundle_from_relative_paths(
self._bundle_relative_paths,
from_dir=self._bundle_from_dir,
dryrun=self._dry_run,
copy_bundle_to=self._copy_bundle_to,
)
elif self._copy_files != "none":
code_bundle = await build_code_bundle(
from_dir=cfg.root_dir,
dryrun=self._dry_run,
copy_bundle_to=self._copy_bundle_to,
copy_style=self._copy_files,
)
else:
code_bundle = None
if not self._disable_run_cache:
_RUN_CACHE[_CacheKey(obj_id=id(obj), dry_run=self._dry_run)] = _CacheValue(
code_bundle=code_bundle, image_cache=image_cache
Expand Down Expand Up @@ -469,16 +485,26 @@ async def _run_hybrid(self, obj: TaskTemplate[P, R, F], *args: P.args, **kwargs:
upload_to_controlplane=not self._dry_run,
copy_bundle_to=self._copy_bundle_to,
)
elif self._copy_files == "python_script":
from ._code_bundle.bundle import build_code_bundle_from_relative_paths

if not self._bundle_relative_paths or not self._bundle_from_dir:
raise ValueError("copy_style='python_script' requires _bundle_relative_paths and _bundle_from_dir")
code_bundle = await build_code_bundle_from_relative_paths(
self._bundle_relative_paths,
from_dir=self._bundle_from_dir,
dryrun=self._dry_run,
copy_bundle_to=self._copy_bundle_to,
)
elif self._copy_files != "none":
code_bundle = await build_code_bundle(
from_dir=cfg.root_dir,
dryrun=self._dry_run,
copy_bundle_to=self._copy_bundle_to,
copy_style=self._copy_files,
)
else:
if self._copy_files != "none":
code_bundle = await build_code_bundle(
from_dir=cfg.root_dir,
dryrun=self._dry_run,
copy_bundle_to=self._copy_bundle_to,
copy_style=self._copy_files,
)
else:
code_bundle = None
code_bundle = None

version = self._version or (
code_bundle.computed_version if code_bundle and code_bundle.computed_version else None
Expand Down
Loading
Loading