From ea036bd66ec247d0d86e218472921e9b24aa5598 Mon Sep 17 00:00:00 2001 From: Yee Hing Tong Date: Thu, 27 Nov 2025 17:06:56 -0800 Subject: [PATCH 01/22] examples Signed-off-by: Yee Hing Tong --- README.md | 18 ++++++ examples/advanced/hybrid_mode.py | 81 ++++++++++++++++++++++++++ examples/advanced/remote_controller.py | 72 +++++++++++++++++++++++ 3 files changed, 171 insertions(+) create mode 100644 examples/advanced/hybrid_mode.py create mode 100644 examples/advanced/remote_controller.py diff --git a/README.md b/README.md index 9f77f4d3a..c7df35c78 100644 --- a/README.md +++ b/README.md @@ -313,3 +313,21 @@ python maint_tools/build_default_image.py --registry ghcr.io/my-org --name my-fl ## 📄 License Flyte 2 is licensed under the [Apache 2.0 License](LICENSE). + +## Developing the Core Controller + +The following instructions are for helping to build the default multi-arch image. Each architecture needs a different wheel. Each wheel needs to be built by a different docker image. + +### Setup Builders +`cd` into `rs_controller` and run `make build-builders`. This will build the builder images once, so you can keep using them as the rust code changes. + +### Iteration Cycle +Make sure you have `CLOUD_REPO=/Users//go/src/github.com/unionai/cloud` exported and checked out to a branch that has the latest prost generated code. Delete this comment and update make target in the future if it gets merged/published. + +Then run `make build-wheels`. 
+ +`cd` back up to the root folder of this project and proceed with +```bash +make dist +python maint_tools/build_default_image.py +``` \ No newline at end of file diff --git a/examples/advanced/hybrid_mode.py b/examples/advanced/hybrid_mode.py new file mode 100644 index 000000000..49523134d --- /dev/null +++ b/examples/advanced/hybrid_mode.py @@ -0,0 +1,81 @@ +import asyncio +import os +from pathlib import Path +from typing import List + +import flyte +import flyte.storage +from flyte.storage import S3 + +env = flyte.TaskEnvironment(name="hello_world", cache="disable") + + +@env.task +async def say_hello_hybrid(data: str, lt: List[int]) -> str: + print(f"Hello, world! - {flyte.ctx().action}") + return f"Hello {data} {lt}" + + +@env.task +async def squared(i: int = 3) -> int: + print(flyte.ctx().action) + return i * i + + +@env.task +async def squared_2(i: int = 3) -> int: + print(flyte.ctx().action) + return i * i + + +@env.task +async def say_hello_hybrid_nested(data: str = "default string") -> str: + print(f"Hello, nested! 
- {flyte.ctx().action}") + coros = [] + for i in range(3): + coros.append(squared(i=i)) + + vals = await asyncio.gather(*coros) + return await say_hello_hybrid(data=data, lt=vals) + + +@env.task +async def hybrid_parent_placeholder(): + import sys + import time + + print(f"Hello, hybrid parent placeholder - Task Context: {flyte.ctx()}") + print(f"Run command: {sys.argv}") + print("Environment Variables:") + for k, value in sorted(os.environ.items()): + if k.startswith("FLYTE_") or k.startswith("_U"): # noqa: PIE810 + print(f"{k}: {value}") + + print("Sleeping for 24 hours to simulate a long-running task...", flush=True) + time.sleep(86400) # noqa: ASYNC251 + + +if __name__ == "__main__": + # Get current working directory + current_directory = Path(os.getcwd()) + # change to root directory of the project + os.chdir(current_directory.parent.parent) + config = S3.for_sandbox() + # config = flyte.storage.S3.auto() + flyte.init( + endpoint="dns:///localhost:8090", + insecure=True, + org="testorg", + project="testproject", + domain="development", + storage=config, + log_level=10, + ) + # Kick off a run of hybrid_parent_placeholder and fill in with kicked off things. 
+ run_name = "rxmjfwt6nz2rkwzntrtl" + outputs = flyte.with_runcontext( + mode="hybrid", + name=run_name, + run_base_dir=f"s3://bucket/metadata/v2/testorg/testproject/development/{run_name}", + ).run(say_hello_hybrid_nested, data="hello world") + print("Output:", outputs) diff --git a/examples/advanced/remote_controller.py b/examples/advanced/remote_controller.py new file mode 100644 index 000000000..a1bfbafc0 --- /dev/null +++ b/examples/advanced/remote_controller.py @@ -0,0 +1,72 @@ +import asyncio + +# from cloud_mod.cloud_mod import cloudidl +# from cloud_mod.cloud_mod import Action +from pathlib import Path + +from flyte_controller_base import Action, BaseController, cloudidl + +from examples.advanced.hybrid_mode import say_hello_hybrid +from flyte._internal.imagebuild.image_builder import ImageCache +from flyte._internal.runtime.task_serde import translate_task_to_wire +from flyte.models import ( + CodeBundle, + SerializationContext, +) + +img_cache = ImageCache.from_transport( + "H4sIAAAAAAAC/wXBSQ6AIAwAwL/0TsG6hs8YlILEpUbFxBj/7swLaXWR+0VkzjvYF1y+BCzEaTwwic5bks0lJeepw/JcbPenxKJUt0FCM1CLnu+KVAwjd559g54M1aYtavi+H56TcPxgAAAA" +) +s_ctx = SerializationContext( + project="testproject", + domain="development", + org="testorg", + code_bundle=CodeBundle( + computed_version="605136feba679aeb1936677f4c5593f6", + tgz="s3://bucket/testproject/development/MBITN7V2M6NOWGJWM57UYVMT6Y======/fast0dc2ef669a983610a0b9793e974fb288.tar.gz", + ), + version="605136feba679aeb1936677f4c5593f6", + image_cache=img_cache, + root_dir=Path("/Users/ytong/go/src/github.com/unionai/unionv2"), +) +task_spec = translate_task_to_wire(say_hello_hybrid, s_ctx) +xxx = task_spec.SerializeToString() + +yyy = cloudidl.workflow.TaskSpec.decode(xxx) +print(yyy) + + +class MyRunner(BaseController): + ... 
+ # play around with this + # def __init__(self, run_id: cloudidl.workflow.RunIdentifier): + # super().__new__(BaseController, run_id) + + +async def main(): + run_id = cloudidl.workflow.RunIdentifier( + org="testorg", domain="development", name="rxp79l5qjpmmdd84qg7j", project="testproject" + ) + + sub_action_id = cloudidl.workflow.ActionIdentifier(name="sub_action_3", run=run_id) + + action = Action.from_task( + sub_action_id=sub_action_id, + parent_action_name="a0", + group_data=None, + task_spec=yyy, + inputs_uri="s3://bucket/metadata/v2/testorg/testproject/development/rllmmzgh6v4xjc8pswc8/4jzwmmj06fnpql20rtlqz4aq2/inputs.pb", + run_output_base="s3://bucket/metadata/v2/testorg/testproject/development/rllmmzgh6v4xjc8pswc8", + cache_key=None, + ) + + runner = MyRunner(run_id=run_id) + + result = await runner.submit_action(action) + print("First submit done", flush=True) + breakpoint() + print(result) + + +if __name__ == "__main__": + asyncio.run(main()) From e9d80d94b5fb504d0fe2a8cd29165d5bacd425ef Mon Sep 17 00:00:00 2001 From: Yee Hing Tong Date: Thu, 27 Nov 2025 19:12:20 -0800 Subject: [PATCH 02/22] copy files over Signed-off-by: Yee Hing Tong --- examples/advanced/hybrid_mode.py | 21 +- rs_controller/.cargo/config.toml | 3 + rs_controller/.gitignore | 1 + rs_controller/Cargo.lock | 1706 +++++++++++++++++ rs_controller/Cargo.toml | 35 + rs_controller/Dockerfile.maturinx | 13 + rs_controller/Makefile | 43 + rs_controller/build.rs | 4 + rs_controller/pyproject.toml | 15 + rs_controller/src/action.rs | 209 ++ rs_controller/src/informer.rs | 314 +++ rs_controller/src/lib.rs | 508 +++++ rs_controller/temp_cargo_config.toml | 2 + src/flyte/_internal/controllers/__init__.py | 12 +- .../controllers/_local_controller.py | 14 +- .../_internal/controllers/remote/_action.py | 4 +- .../controllers/remote/_controller.py | 12 +- .../controllers/remote/_r_controller.py | 557 ++++++ src/flyte/_internal/runtime/convert.py | 7 +- src/flyte/_run.py | 5 +- src/flyte/models.py | 6 - 
21 files changed, 3465 insertions(+), 26 deletions(-) create mode 100644 rs_controller/.cargo/config.toml create mode 100644 rs_controller/.gitignore create mode 100644 rs_controller/Cargo.lock create mode 100644 rs_controller/Cargo.toml create mode 100644 rs_controller/Dockerfile.maturinx create mode 100644 rs_controller/Makefile create mode 100644 rs_controller/build.rs create mode 100644 rs_controller/pyproject.toml create mode 100644 rs_controller/src/action.rs create mode 100644 rs_controller/src/informer.rs create mode 100644 rs_controller/src/lib.rs create mode 100644 rs_controller/temp_cargo_config.toml create mode 100644 src/flyte/_internal/controllers/remote/_r_controller.py diff --git a/examples/advanced/hybrid_mode.py b/examples/advanced/hybrid_mode.py index 49523134d..ede8c85b5 100644 --- a/examples/advanced/hybrid_mode.py +++ b/examples/advanced/hybrid_mode.py @@ -60,17 +60,18 @@ async def hybrid_parent_placeholder(): current_directory = Path(os.getcwd()) # change to root directory of the project os.chdir(current_directory.parent.parent) - config = S3.for_sandbox() + # config = S3.for_sandbox() # config = flyte.storage.S3.auto() - flyte.init( - endpoint="dns:///localhost:8090", - insecure=True, - org="testorg", - project="testproject", - domain="development", - storage=config, - log_level=10, - ) + flyte.init_from_config("/Users/ytong/.flyte/config-k3d.yaml") + # flyte.init( + # endpoint="dns:///localhost:8090", + # insecure=True, + # org="testorg", + # project="testproject", + # domain="development", + # storage=config, + # log_level=10, + # ) # Kick off a run of hybrid_parent_placeholder and fill in with kicked off things. 
run_name = "rxmjfwt6nz2rkwzntrtl" outputs = flyte.with_runcontext( diff --git a/rs_controller/.cargo/config.toml b/rs_controller/.cargo/config.toml new file mode 100644 index 000000000..b39600ead --- /dev/null +++ b/rs_controller/.cargo/config.toml @@ -0,0 +1,3 @@ +# .cargo/config.toml +[build] +rustflags = ["-L", "/opt/homebrew/opt/python@3.12/Frameworks/Python.framework/Versions/3.12/lib"] \ No newline at end of file diff --git a/rs_controller/.gitignore b/rs_controller/.gitignore new file mode 100644 index 000000000..5bb4638f0 --- /dev/null +++ b/rs_controller/.gitignore @@ -0,0 +1 @@ +docker_cargo_cache/ diff --git a/rs_controller/Cargo.lock b/rs_controller/Cargo.lock new file mode 100644 index 000000000..b43b6f4e3 --- /dev/null +++ b/rs_controller/Cargo.lock @@ -0,0 +1,1706 @@ +# This file is automatically @generated by Cargo. +# It is not intended for manual editing. +version = 4 + +[[package]] +name = "addr2line" +version = "0.24.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dfbe277e56a376000877090da837660b4427aad530e3028d44e0bffe4f89a1c1" +dependencies = [ + "gimli", +] + +[[package]] +name = "adler2" +version = "2.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "512761e0bb2578dd7380c6baaa0f4ce03e84f95e960231d1dec8bf4d7d6e2627" + +[[package]] +name = "aho-corasick" +version = "1.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8e60d3430d3a69478ad0993f19238d2df97c507009a52b3c10addcd7f6bcb916" +dependencies = [ + "memchr", +] + +[[package]] +name = "anyhow" +version = "1.0.98" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e16d2d3311acee920a9eb8d33b8cbc1787ce4a264e85f964c2404b969bdcd487" + +[[package]] +name = "async-stream" +version = "0.3.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0b5a71a6f37880a80d1d7f19efd781e4b5de42c88f0722cc13bcb6cc2cfe8476" +dependencies = [ + 
"async-stream-impl", + "futures-core", + "pin-project-lite", +] + +[[package]] +name = "async-stream-impl" +version = "0.3.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c7c24de15d275a1ecfd47a380fb4d5ec9bfe0933f309ed5e705b775596a3574d" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "async-trait" +version = "0.1.88" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e539d3fca749fcee5236ab05e93a52867dd549cc157c8cb7f99595f3cedffdb5" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "atomic-waker" +version = "1.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1505bd5d3d116872e7271a6d4e16d81d0c8570876c8de68093a09ac269d8aac0" + +[[package]] +name = "autocfg" +version = "1.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ace50bade8e6234aa140d9a2f552bbee1db4d353f69b8217bc503490fc1a9f26" + +[[package]] +name = "axum" +version = "0.7.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "edca88bc138befd0323b20752846e6587272d3b03b0343c8ea28a6f819e6e71f" +dependencies = [ + "async-trait", + "axum-core", + "bytes", + "futures-util", + "http", + "http-body", + "http-body-util", + "itoa", + "matchit", + "memchr", + "mime", + "percent-encoding", + "pin-project-lite", + "rustversion", + "serde", + "sync_wrapper", + "tower 0.5.2", + "tower-layer", + "tower-service", +] + +[[package]] +name = "axum-core" +version = "0.4.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "09f2bd6146b97ae3359fa0cc6d6b376d9539582c7b4220f041a33ec24c226199" +dependencies = [ + "async-trait", + "bytes", + "futures-util", + "http", + "http-body", + "http-body-util", + "mime", + "pin-project-lite", + "rustversion", + "sync_wrapper", + "tower-layer", + "tower-service", +] + +[[package]] +name = "backtrace" +version = "0.3.74" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "8d82cb332cdfaed17ae235a638438ac4d4839913cc2af585c3c6746e8f8bee1a" +dependencies = [ + "addr2line", + "cfg-if", + "libc", + "miniz_oxide", + "object", + "rustc-demangle", + "windows-targets", +] + +[[package]] +name = "base64" +version = "0.21.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9d297deb1925b89f2ccc13d7635fa0714f12c87adce1c75356b39ca9b7178567" + +[[package]] +name = "base64" +version = "0.22.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "72b3254f16251a8381aa12e40e3c4d2f0199f8c6508fbecb9d91f575e0fbb8c6" + +[[package]] +name = "bitflags" +version = "2.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5c8214115b7bf84099f1309324e63141d4c5d7cc26862f97a0a857dbefe165bd" + +[[package]] +name = "bytes" +version = "1.10.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d71b6127be86fdcfddb610f7182ac57211d4b18a3e9c82eb2d17662f2227ad6a" + +[[package]] +name = "cc" +version = "1.2.29" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5c1599538de2394445747c8cf7935946e3cc27e9625f889d979bfb2aaf569362" +dependencies = [ + "shlex", +] + +[[package]] +name = "cfg-if" +version = "1.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" + +[[package]] +name = "chrono" +version = "0.4.41" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c469d952047f47f91b68d1cba3f10d63c11d73e4636f24f08daf0278abf01c4d" +dependencies = [ + "num-traits", +] + +[[package]] +name = "cloudidl" +version = "0.1.0" +dependencies = [ + "async-trait", + "futures", + "pbjson", + "pbjson-build", + "pbjson-types", + "prettyplease", + "prost 0.13.5", + "prost-build 0.14.1", + "prost-types 0.12.6", + "protobuf", + "pyo3", + "pyo3-async-runtimes", + "quote", 
+ "regex", + "serde", + "syn", + "thiserror", + "tokio", + "tonic", + "tower 0.4.13", + "tower-http", + "tracing", + "tracing-subscriber", +] + +[[package]] +name = "either" +version = "1.15.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "48c757948c5ede0e46177b7add2e67155f70e33c07fea8284df6576da70b3719" + +[[package]] +name = "equivalent" +version = "1.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "877a4ace8713b0bcf2a4e7eec82529c029f1d0619886d18145fea96c3ffe5c0f" + +[[package]] +name = "errno" +version = "0.3.13" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "778e2ac28f6c47af28e4907f13ffd1e1ddbd400980a9abd7c8df189bf578a5ad" +dependencies = [ + "libc", + "windows-sys", +] + +[[package]] +name = "fastrand" +version = "2.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "37909eebbb50d72f9059c3b6d82c0463f2ff062c9e95845c43a6c9c0355411be" + +[[package]] +name = "fixedbitset" +version = "0.5.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1d674e81391d1e1ab681a28d99df07927c6d4aa5b027d7da16ba32d1d21ecd99" + +[[package]] +name = "fnv" +version = "1.0.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3f9eec918d3f24069decb9af1554cad7c880e2da24a9afd88aca000531ab82c1" + +[[package]] +name = "futures" +version = "0.3.31" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "65bc07b1a8bc7c85c5f2e110c476c7389b4554ba72af57d8445ea63a576b0876" +dependencies = [ + "futures-channel", + "futures-core", + "futures-executor", + "futures-io", + "futures-sink", + "futures-task", + "futures-util", +] + +[[package]] +name = "futures-channel" +version = "0.3.31" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2dff15bf788c671c1934e366d07e30c1814a8ef514e1af724a602e8a2fbe1b10" +dependencies = [ + "futures-core", + "futures-sink", +] + 
+[[package]] +name = "futures-core" +version = "0.3.31" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "05f29059c0c2090612e8d742178b0580d2dc940c837851ad723096f87af6663e" + +[[package]] +name = "futures-executor" +version = "0.3.31" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1e28d1d997f585e54aebc3f97d39e72338912123a67330d723fdbb564d646c9f" +dependencies = [ + "futures-core", + "futures-task", + "futures-util", +] + +[[package]] +name = "futures-io" +version = "0.3.31" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9e5c1b78ca4aae1ac06c48a526a655760685149f0d465d21f37abfe57ce075c6" + +[[package]] +name = "futures-macro" +version = "0.3.31" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "162ee34ebcb7c64a8abebc059ce0fee27c2262618d7b60ed8faf72fef13c3650" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "futures-sink" +version = "0.3.31" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e575fab7d1e0dcb8d0c7bcf9a63ee213816ab51902e6d244a95819acacf1d4f7" + +[[package]] +name = "futures-task" +version = "0.3.31" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f90f7dce0722e95104fcb095585910c0977252f286e354b5e3bd38902cd99988" + +[[package]] +name = "futures-util" +version = "0.3.31" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9fa08315bb612088cc391249efdc3bc77536f16c91f6cf495e6fbe85b20a4a81" +dependencies = [ + "futures-channel", + "futures-core", + "futures-io", + "futures-macro", + "futures-sink", + "futures-task", + "memchr", + "pin-project-lite", + "pin-utils", + "slab", +] + +[[package]] +name = "getrandom" +version = "0.2.16" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "335ff9f135e4384c8150d6f27c6daed433577f86b4750418338c01a1a2528592" +dependencies = [ + "cfg-if", + "libc", + "wasi 
0.11.0+wasi-snapshot-preview1", +] + +[[package]] +name = "getrandom" +version = "0.3.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "26145e563e54f2cadc477553f1ec5ee650b00862f0a58bcd12cbdc5f0ea2d2f4" +dependencies = [ + "cfg-if", + "libc", + "r-efi", + "wasi 0.14.2+wasi-0.2.4", +] + +[[package]] +name = "gimli" +version = "0.31.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "07e28edb80900c19c28f1072f2e8aeca7fa06b23cd4169cefe1af5aa3260783f" + +[[package]] +name = "h2" +version = "0.4.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a9421a676d1b147b16b82c9225157dc629087ef8ec4d5e2960f9437a90dac0a5" +dependencies = [ + "atomic-waker", + "bytes", + "fnv", + "futures-core", + "futures-sink", + "http", + "indexmap 2.9.0", + "slab", + "tokio", + "tokio-util", + "tracing", +] + +[[package]] +name = "hashbrown" +version = "0.12.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8a9ee70c43aaf417c914396645a0fa852624801b24ebb7ae78fe8272889ac888" + +[[package]] +name = "hashbrown" +version = "0.15.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "84b26c544d002229e640969970a2e74021aadf6e2f96372b9c58eff97de08eb3" + +[[package]] +name = "heck" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2304e00983f87ffb38b55b444b5e3b60a884b5d30c0fca7d82fe33449bbe55ea" + +[[package]] +name = "http" +version = "1.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f4a85d31aea989eead29a3aaf9e1115a180df8282431156e533de47660892565" +dependencies = [ + "bytes", + "fnv", + "itoa", +] + +[[package]] +name = "http-body" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1efedce1fb8e6913f23e0c92de8e62cd5b772a67e7b3946df930a62566c93184" +dependencies = [ + "bytes", + "http", +] + +[[package]] +name = "http-body-util" 
+version = "0.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b021d93e26becf5dc7e1b75b1bed1fd93124b374ceb73f43d4d4eafec896a64a" +dependencies = [ + "bytes", + "futures-core", + "http", + "http-body", + "pin-project-lite", +] + +[[package]] +name = "httparse" +version = "1.10.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6dbf3de79e51f3d586ab4cb9d5c3e2c14aa28ed23d180cf89b4df0454a69cc87" + +[[package]] +name = "httpdate" +version = "1.0.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "df3b46402a9d5adb4c86a0cf463f42e19994e3ee891101b1841f30a545cb49a9" + +[[package]] +name = "hyper" +version = "1.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cc2b571658e38e0c01b1fdca3bbbe93c00d3d71693ff2770043f8c29bc7d6f80" +dependencies = [ + "bytes", + "futures-channel", + "futures-util", + "h2", + "http", + "http-body", + "httparse", + "httpdate", + "itoa", + "pin-project-lite", + "smallvec", + "tokio", + "want", +] + +[[package]] +name = "hyper-timeout" +version = "0.5.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2b90d566bffbce6a75bd8b09a05aa8c2cb1fabb6cb348f8840c9e4c90a0d83b0" +dependencies = [ + "hyper", + "hyper-util", + "pin-project-lite", + "tokio", + "tower-service", +] + +[[package]] +name = "hyper-util" +version = "0.1.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dc2fdfdbff08affe55bb779f33b053aa1fe5dd5b54c257343c17edfa55711bdb" +dependencies = [ + "bytes", + "futures-channel", + "futures-core", + "futures-util", + "http", + "http-body", + "hyper", + "libc", + "pin-project-lite", + "socket2", + "tokio", + "tower-service", + "tracing", +] + +[[package]] +name = "indexmap" +version = "1.9.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bd070e393353796e801d209ad339e89596eb4c8d430d18ede6a1cced8fafbd99" +dependencies = [ + "autocfg", + 
"hashbrown 0.12.3", +] + +[[package]] +name = "indexmap" +version = "2.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cea70ddb795996207ad57735b50c5982d8844f38ba9ee5f1aedcfb708a2aa11e" +dependencies = [ + "equivalent", + "hashbrown 0.15.3", +] + +[[package]] +name = "indoc" +version = "2.0.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f4c7245a08504955605670dbf141fceab975f15ca21570696aebe9d2e71576bd" + +[[package]] +name = "itertools" +version = "0.12.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ba291022dbbd398a455acf126c1e341954079855bc60dfdda641363bd6922569" +dependencies = [ + "either", +] + +[[package]] +name = "itertools" +version = "0.13.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "413ee7dfc52ee1a4949ceeb7dbc8a33f2d6c088194d9f922fb8318faf1f01186" +dependencies = [ + "either", +] + +[[package]] +name = "itoa" +version = "1.0.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4a5f13b858c8d314ee3e8f639011f7ccefe71f97f96e50151fb991f267928e2c" + +[[package]] +name = "lazy_static" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bbd2bcb4c963f2ddae06a2efc7e9f3591312473c50c6685e1f298068316e66fe" + +[[package]] +name = "libc" +version = "0.2.172" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d750af042f7ef4f724306de029d18836c26c1765a54a6a3f094cbd23a7267ffa" + +[[package]] +name = "linux-raw-sys" +version = "0.9.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cd945864f07fe9f5371a27ad7b52a172b4b499999f1d97574c9fa68373937e12" + +[[package]] +name = "lock_api" +version = "0.4.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "07af8b9cdd281b7915f413fa73f29ebd5d55d0d3f0155584dade1ff18cea1b17" +dependencies = [ + "autocfg", + "scopeguard", +] + +[[package]] 
+name = "log" +version = "0.4.27" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "13dc2df351e3202783a1fe0d44375f7295ffb4049267b0f3018346dc122a1d94" + +[[package]] +name = "matchit" +version = "0.7.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0e7465ac9959cc2b1404e8e2367b43684a6d13790fe23056cc8c6c5a6b7bcb94" + +[[package]] +name = "memchr" +version = "2.7.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "78ca9ab1a0babb1e7d5695e3530886289c18cf2f87ec19a575a0abdce112e3a3" + +[[package]] +name = "memoffset" +version = "0.9.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "488016bfae457b036d996092f6cb448677611ce4449e970ceaf42695203f218a" +dependencies = [ + "autocfg", +] + +[[package]] +name = "mime" +version = "0.3.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6877bb514081ee2a7ff5ef9de3281f14a4dd4bceac4c09388074a6b5df8a139a" + +[[package]] +name = "miniz_oxide" +version = "0.8.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3be647b768db090acb35d5ec5db2b0e1f1de11133ca123b9eacf5137868f892a" +dependencies = [ + "adler2", +] + +[[package]] +name = "mio" +version = "1.0.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2886843bf800fba2e3377cff24abf6379b4c4d5c6681eaf9ea5b0d15090450bd" +dependencies = [ + "libc", + "wasi 0.11.0+wasi-snapshot-preview1", + "windows-sys", +] + +[[package]] +name = "multimap" +version = "0.10.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1d87ecb2933e8aeadb3e3a02b828fed80a7528047e68b4f424523a0981a3a084" + +[[package]] +name = "nu-ansi-term" +version = "0.46.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "77a8165726e8236064dbb45459242600304b42a5ea24ee2948e18e023bf7ba84" +dependencies = [ + "overload", + "winapi", +] + +[[package]] +name = "num-traits" 
+version = "0.2.19" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "071dfc062690e90b734c0b2273ce72ad0ffa95f0c74596bc250dcfd960262841" +dependencies = [ + "autocfg", +] + +[[package]] +name = "object" +version = "0.36.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "62948e14d923ea95ea2c7c86c71013138b66525b86bdc08d2dcc262bdb497b87" +dependencies = [ + "memchr", +] + +[[package]] +name = "once_cell" +version = "1.21.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "42f5e15c9953c5e4ccceeb2e7382a716482c34515315f7b03532b8b4e8393d2d" + +[[package]] +name = "overload" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b15813163c1d831bf4a13c3610c05c0d03b39feb07f7e09fa234dac9b15aaf39" + +[[package]] +name = "parking_lot" +version = "0.12.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f1bf18183cf54e8d6059647fc3063646a1801cf30896933ec2311622cc4b9a27" +dependencies = [ + "lock_api", + "parking_lot_core", +] + +[[package]] +name = "parking_lot_core" +version = "0.9.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1e401f977ab385c9e4e3ab30627d6f26d00e2c73eef317493c4ec6d468726cf8" +dependencies = [ + "cfg-if", + "libc", + "redox_syscall", + "smallvec", + "windows-targets", +] + +[[package]] +name = "paste" +version = "1.0.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "57c0d7b74b563b49d38dae00a0c37d4d6de9b432382b2892f0574ddcae73fd0a" + +[[package]] +name = "pbjson" +version = "0.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c7e6349fa080353f4a597daffd05cb81572a9c031a6d4fff7e504947496fcc68" +dependencies = [ + "base64 0.21.7", + "serde", +] + +[[package]] +name = "pbjson-build" +version = "0.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"6eea3058763d6e656105d1403cb04e0a41b7bbac6362d413e7c33be0c32279c9" +dependencies = [ + "heck", + "itertools 0.13.0", + "prost 0.13.5", + "prost-types 0.13.5", +] + +[[package]] +name = "pbjson-types" +version = "0.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e54e5e7bfb1652f95bc361d76f3c780d8e526b134b85417e774166ee941f0887" +dependencies = [ + "bytes", + "chrono", + "pbjson", + "pbjson-build", + "prost 0.13.5", + "prost-build 0.13.5", + "serde", +] + +[[package]] +name = "percent-encoding" +version = "2.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e3148f5046208a5d56bcfc03053e3ca6334e51da8dfb19b6cdc8b306fae3283e" + +[[package]] +name = "petgraph" +version = "0.7.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3672b37090dbd86368a4145bc067582552b29c27377cad4e0a306c97f9bd7772" +dependencies = [ + "fixedbitset", + "indexmap 2.9.0", +] + +[[package]] +name = "pin-project" +version = "1.1.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "677f1add503faace112b9f1373e43e9e054bfdd22ff1a63c1bc485eaec6a6a8a" +dependencies = [ + "pin-project-internal", +] + +[[package]] +name = "pin-project-internal" +version = "1.1.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6e918e4ff8c4549eb882f14b3a4bc8c8bc93de829416eacf579f1207a8fbf861" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "pin-project-lite" +version = "0.2.16" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3b3cff922bd51709b605d9ead9aa71031d81447142d828eb4a6eba76fe619f9b" + +[[package]] +name = "pin-utils" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8b870d8c151b6f2fb93e84a13146138f05d02ed11c7e7c54f8826aaaf7c9f184" + +[[package]] +name = "portable-atomic" +version = "1.11.0" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "350e9b48cbc6b0e028b0473b114454c6316e57336ee184ceab6e53f72c178b3e" + +[[package]] +name = "ppv-lite86" +version = "0.2.21" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "85eae3c4ed2f50dcfe72643da4befc30deadb458a9b590d720cde2f2b1e97da9" +dependencies = [ + "zerocopy", +] + +[[package]] +name = "prettyplease" +version = "0.2.35" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "061c1221631e079b26479d25bbf2275bfe5917ae8419cd7e34f13bfc2aa7539a" +dependencies = [ + "proc-macro2", + "syn", +] + +[[package]] +name = "proc-macro2" +version = "1.0.95" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "02b3e5e68a3a1a02aad3ec490a98007cbc13c37cbe84a3cd7b8e406d76e7f778" +dependencies = [ + "unicode-ident", +] + +[[package]] +name = "prost" +version = "0.12.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "deb1435c188b76130da55f17a466d252ff7b1418b2ad3e037d127b94e3411f29" +dependencies = [ + "bytes", + "prost-derive 0.12.6", +] + +[[package]] +name = "prost" +version = "0.13.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2796faa41db3ec313a31f7624d9286acf277b52de526150b7e69f3debf891ee5" +dependencies = [ + "bytes", + "prost-derive 0.13.5", +] + +[[package]] +name = "prost" +version = "0.14.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7231bd9b3d3d33c86b58adbac74b5ec0ad9f496b19d22801d773636feaa95f3d" +dependencies = [ + "bytes", + "prost-derive 0.14.1", +] + +[[package]] +name = "prost-build" +version = "0.13.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "be769465445e8c1474e9c5dac2018218498557af32d9ed057325ec9a41ae81bf" +dependencies = [ + "heck", + "itertools 0.13.0", + "log", + "multimap", + "once_cell", + "petgraph", + "prettyplease", + "prost 0.13.5", + "prost-types 0.13.5", + "regex", + 
"syn", + "tempfile", +] + +[[package]] +name = "prost-build" +version = "0.14.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ac6c3320f9abac597dcbc668774ef006702672474aad53c6d596b62e487b40b1" +dependencies = [ + "heck", + "itertools 0.13.0", + "log", + "multimap", + "once_cell", + "petgraph", + "prettyplease", + "prost 0.14.1", + "prost-types 0.14.1", + "regex", + "syn", + "tempfile", +] + +[[package]] +name = "prost-derive" +version = "0.12.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "81bddcdb20abf9501610992b6759a4c888aef7d1a7247ef75e2404275ac24af1" +dependencies = [ + "anyhow", + "itertools 0.12.1", + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "prost-derive" +version = "0.13.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8a56d757972c98b346a9b766e3f02746cde6dd1cd1d1d563472929fdd74bec4d" +dependencies = [ + "anyhow", + "itertools 0.13.0", + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "prost-derive" +version = "0.14.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9120690fafc389a67ba3803df527d0ec9cbbc9cc45e4cc20b332996dfb672425" +dependencies = [ + "anyhow", + "itertools 0.13.0", + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "prost-types" +version = "0.12.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9091c90b0a32608e984ff2fa4091273cbdd755d54935c51d520887f4a1dbd5b0" +dependencies = [ + "prost 0.12.6", +] + +[[package]] +name = "prost-types" +version = "0.13.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "52c2c1bf36ddb1a1c396b3601a3cec27c2462e45f07c386894ec3ccf5332bd16" +dependencies = [ + "prost 0.13.5", +] + +[[package]] +name = "prost-types" +version = "0.14.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b9b4db3d6da204ed77bb26ba83b6122a73aeb2e87e25fbf7ad2e84c4ccbf8f72" 
+dependencies = [ + "prost 0.14.1", +] + +[[package]] +name = "protobuf" +version = "4.31.1-release" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7967deb2e74ba240cdcfbf447858848c9edaae6231027d31f4120d42243543a2" +dependencies = [ + "cc", + "paste", +] + +[[package]] +name = "pyo3" +version = "0.24.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e5203598f366b11a02b13aa20cab591229ff0a89fd121a308a5df751d5fc9219" +dependencies = [ + "cfg-if", + "indoc", + "libc", + "memoffset", + "once_cell", + "portable-atomic", + "pyo3-build-config", + "pyo3-ffi", + "pyo3-macros", + "unindent", +] + +[[package]] +name = "pyo3-async-runtimes" +version = "0.24.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dd0b83dc42f9d41f50d38180dad65f0c99763b65a3ff2a81bf351dd35a1df8bf" +dependencies = [ + "futures", + "once_cell", + "pin-project-lite", + "pyo3", + "tokio", +] + +[[package]] +name = "pyo3-build-config" +version = "0.24.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "99636d423fa2ca130fa5acde3059308006d46f98caac629418e53f7ebb1e9999" +dependencies = [ + "once_cell", + "target-lexicon", +] + +[[package]] +name = "pyo3-ffi" +version = "0.24.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "78f9cf92ba9c409279bc3305b5409d90db2d2c22392d443a87df3a1adad59e33" +dependencies = [ + "libc", + "pyo3-build-config", +] + +[[package]] +name = "pyo3-macros" +version = "0.24.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0b999cb1a6ce21f9a6b147dcf1be9ffedf02e0043aec74dc390f3007047cecd9" +dependencies = [ + "proc-macro2", + "pyo3-macros-backend", + "quote", + "syn", +] + +[[package]] +name = "pyo3-macros-backend" +version = "0.24.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "822ece1c7e1012745607d5cf0bcb2874769f0f7cb34c4cde03b9358eb9ef911a" +dependencies = [ + "heck", + 
"proc-macro2", + "pyo3-build-config", + "quote", + "syn", +] + +[[package]] +name = "quote" +version = "1.0.40" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1885c039570dc00dcb4ff087a89e185fd56bae234ddc7f056a945bf36467248d" +dependencies = [ + "proc-macro2", +] + +[[package]] +name = "r-efi" +version = "5.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "69cdb34c158ceb288df11e18b4bd39de994f6657d83847bdffdbd7f346754b0f" + +[[package]] +name = "rand" +version = "0.8.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "34af8d1a0e25924bc5b7c43c079c942339d8f0a8b57c39049bef581b46327404" +dependencies = [ + "libc", + "rand_chacha", + "rand_core", +] + +[[package]] +name = "rand_chacha" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e6c10a63a0fa32252be49d21e7709d4d4baf8d231c2dbce1eaa8141b9b127d88" +dependencies = [ + "ppv-lite86", + "rand_core", +] + +[[package]] +name = "rand_core" +version = "0.6.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ec0be4795e2f6a28069bec0b5ff3e2ac9bafc99e6a9a7dc3547996c5c816922c" +dependencies = [ + "getrandom 0.2.16", +] + +[[package]] +name = "redox_syscall" +version = "0.5.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "928fca9cf2aa042393a8325b9ead81d2f0df4cb12e1e24cef072922ccd99c5af" +dependencies = [ + "bitflags", +] + +[[package]] +name = "regex" +version = "1.11.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b544ef1b4eac5dc2db33ea63606ae9ffcfac26c1416a2806ae0bf5f56b201191" +dependencies = [ + "aho-corasick", + "memchr", + "regex-automata", + "regex-syntax", +] + +[[package]] +name = "regex-automata" +version = "0.4.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "809e8dc61f6de73b46c85f4c96486310fe304c434cfa43669d7b40f711150908" +dependencies = [ + 
"aho-corasick", + "memchr", + "regex-syntax", +] + +[[package]] +name = "regex-syntax" +version = "0.8.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2b15c43186be67a4fd63bee50d0303afffcef381492ebe2c5d87f324e1b8815c" + +[[package]] +name = "rustc-demangle" +version = "0.1.24" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "719b953e2095829ee67db738b3bfa9fa368c94900df327b3f07fe6e794d2fe1f" + +[[package]] +name = "rustix" +version = "1.0.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c71e83d6afe7ff64890ec6b71d6a69bb8a610ab78ce364b3352876bb4c801266" +dependencies = [ + "bitflags", + "errno", + "libc", + "linux-raw-sys", + "windows-sys", +] + +[[package]] +name = "rustversion" +version = "1.0.20" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "eded382c5f5f786b989652c49544c4877d9f015cc22e145a5ea8ea66c2921cd2" + +[[package]] +name = "scopeguard" +version = "1.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49" + +[[package]] +name = "serde" +version = "1.0.219" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5f0e2c6ed6606019b4e29e69dbaba95b11854410e5347d525002456dbbb786b6" +dependencies = [ + "serde_derive", +] + +[[package]] +name = "serde_derive" +version = "1.0.219" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5b0276cf7f2c73365f7157c8123c21cd9a50fbbd844757af28ca1f5925fc2a00" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "sharded-slab" +version = "0.1.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f40ca3c46823713e0d4209592e8d6e826aa57e928f09752619fc696c499637f6" +dependencies = [ + "lazy_static", +] + +[[package]] +name = "shlex" +version = "1.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" 
+checksum = "0fda2ff0d084019ba4d7c6f371c95d8fd75ce3524c3cb8fb653a3023f6323e64" + +[[package]] +name = "signal-hook-registry" +version = "1.4.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9203b8055f63a2a00e2f593bb0510367fe707d7ff1e5c872de2f537b339e5410" +dependencies = [ + "libc", +] + +[[package]] +name = "slab" +version = "0.4.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8f92a496fb766b417c996b9c5e57daf2f7ad3b0bebe1ccfca4856390e3d3bb67" +dependencies = [ + "autocfg", +] + +[[package]] +name = "smallvec" +version = "1.15.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8917285742e9f3e1683f0a9c4e6b57960b7314d0b08d30d1ecd426713ee2eee9" + +[[package]] +name = "socket2" +version = "0.5.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4f5fd57c80058a56cf5c777ab8a126398ece8e442983605d280a44ce79d0edef" +dependencies = [ + "libc", + "windows-sys", +] + +[[package]] +name = "syn" +version = "2.0.104" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "17b6f705963418cdb9927482fa304bc562ece2fdd4f616084c50b7023b435a40" +dependencies = [ + "proc-macro2", + "quote", + "unicode-ident", +] + +[[package]] +name = "sync_wrapper" +version = "1.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0bf256ce5efdfa370213c1dabab5935a12e49f2c58d15e9eac2870d3b4f27263" + +[[package]] +name = "target-lexicon" +version = "0.13.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e502f78cdbb8ba4718f566c418c52bc729126ffd16baee5baa718cf25dd5a69a" + +[[package]] +name = "tempfile" +version = "3.20.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e8a64e3985349f2441a1a9ef0b853f869006c3855f2cda6862a94d26ebb9d6a1" +dependencies = [ + "fastrand", + "getrandom 0.3.3", + "once_cell", + "rustix", + "windows-sys", +] + +[[package]] +name = "thiserror" 
+version = "1.0.69" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b6aaf5339b578ea85b50e080feb250a3e8ae8cfcdff9a461c9ec2904bc923f52" +dependencies = [ + "thiserror-impl", +] + +[[package]] +name = "thiserror-impl" +version = "1.0.69" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4fee6c4efc90059e10f81e6d42c60a18f76588c3d74cb83a0b242a2b6c7504c1" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "thread_local" +version = "1.1.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8b9ef9bad013ada3808854ceac7b46812a6465ba368859a37e2100283d2d719c" +dependencies = [ + "cfg-if", + "once_cell", +] + +[[package]] +name = "tokio" +version = "1.44.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e6b88822cbe49de4185e3a4cbf8321dd487cf5fe0c5c65695fef6346371e9c48" +dependencies = [ + "backtrace", + "bytes", + "libc", + "mio", + "parking_lot", + "pin-project-lite", + "signal-hook-registry", + "socket2", + "tokio-macros", + "windows-sys", +] + +[[package]] +name = "tokio-macros" +version = "2.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6e06d43f1345a3bcd39f6a56dbb7dcab2ba47e68e8ac134855e7e2bdbaf8cab8" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "tokio-stream" +version = "0.1.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "eca58d7bba4a75707817a2c44174253f9236b2d5fbd055602e9d5c07c139a047" +dependencies = [ + "futures-core", + "pin-project-lite", + "tokio", +] + +[[package]] +name = "tokio-util" +version = "0.7.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "66a539a9ad6d5d281510d5bd368c973d636c02dbf8a67300bfb6b950696ad7df" +dependencies = [ + "bytes", + "futures-core", + "futures-sink", + "pin-project-lite", + "tokio", +] + +[[package]] +name = "tonic" +version = "0.12.3" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "877c5b330756d856ffcc4553ab34a5684481ade925ecc54bcd1bf02b1d0d4d52" +dependencies = [ + "async-stream", + "async-trait", + "axum", + "base64 0.22.1", + "bytes", + "h2", + "http", + "http-body", + "http-body-util", + "hyper", + "hyper-timeout", + "hyper-util", + "percent-encoding", + "pin-project", + "prost 0.13.5", + "socket2", + "tokio", + "tokio-stream", + "tower 0.4.13", + "tower-layer", + "tower-service", + "tracing", +] + +[[package]] +name = "tower" +version = "0.4.13" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b8fa9be0de6cf49e536ce1851f987bd21a43b771b09473c3549a6c853db37c1c" +dependencies = [ + "futures-core", + "futures-util", + "indexmap 1.9.3", + "pin-project", + "pin-project-lite", + "rand", + "slab", + "tokio", + "tokio-util", + "tower-layer", + "tower-service", + "tracing", +] + +[[package]] +name = "tower" +version = "0.5.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d039ad9159c98b70ecfd540b2573b97f7f52c3e8d9f8ad57a24b916a536975f9" +dependencies = [ + "futures-core", + "futures-util", + "pin-project-lite", + "sync_wrapper", + "tower-layer", + "tower-service", +] + +[[package]] +name = "tower-http" +version = "0.5.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1e9cd434a998747dd2c4276bc96ee2e0c7a2eadf3cae88e52be55a05fa9053f5" +dependencies = [ + "bitflags", + "bytes", + "http", + "http-body", + "http-body-util", + "pin-project-lite", + "tower-layer", + "tower-service", + "tracing", +] + +[[package]] +name = "tower-layer" +version = "0.3.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "121c2a6cda46980bb0fcd1647ffaf6cd3fc79a013de288782836f6df9c48780e" + +[[package]] +name = "tower-service" +version = "0.3.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8df9b6e13f2d32c91b9bd719c00d1958837bc7dec474d94952798cc8e69eeec3" + 
+[[package]] +name = "tracing" +version = "0.1.41" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "784e0ac535deb450455cbfa28a6f0df145ea1bb7ae51b821cf5e7927fdcfbdd0" +dependencies = [ + "log", + "pin-project-lite", + "tracing-attributes", + "tracing-core", +] + +[[package]] +name = "tracing-attributes" +version = "0.1.28" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "395ae124c09f9e6918a2310af6038fba074bcf474ac352496d5910dd59a2226d" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "tracing-core" +version = "0.1.33" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e672c95779cf947c5311f83787af4fa8fffd12fb27e4993211a84bdfd9610f9c" +dependencies = [ + "once_cell", + "valuable", +] + +[[package]] +name = "tracing-log" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ee855f1f400bd0e5c02d150ae5de3840039a3f54b025156404e34c23c03f47c3" +dependencies = [ + "log", + "once_cell", + "tracing-core", +] + +[[package]] +name = "tracing-subscriber" +version = "0.3.19" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e8189decb5ac0fa7bc8b96b7cb9b2701d60d48805aca84a238004d665fcc4008" +dependencies = [ + "nu-ansi-term", + "sharded-slab", + "smallvec", + "thread_local", + "tracing-core", + "tracing-log", +] + +[[package]] +name = "try-lock" +version = "0.2.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e421abadd41a4225275504ea4d6566923418b7f05506fbc9c0fe86ba7396114b" + +[[package]] +name = "unicode-ident" +version = "1.0.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5a5f39404a5da50712a4c1eecf25e90dd62b613502b7e925fd4e4d19b5c96512" + +[[package]] +name = "unindent" +version = "0.2.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"7264e107f553ccae879d21fbea1d6724ac785e8c3bfc762137959b5802826ef3" + +[[package]] +name = "union_rust_controller" +version = "0.1.0" +dependencies = [ + "async-trait", + "cloudidl", + "futures", + "prost 0.13.5", + "prost-types 0.12.6", + "pyo3", + "pyo3-async-runtimes", + "pyo3-build-config", + "thiserror", + "tokio", + "tonic", + "tower 0.4.13", + "tower-http", + "tracing", + "tracing-subscriber", +] + +[[package]] +name = "valuable" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ba73ea9cf16a25df0c8caa16c51acb937d5712a8429db78a3ee29d5dcacd3a65" + +[[package]] +name = "want" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bfa7760aed19e106de2c7c0b581b509f2f25d3dacaf737cb82ac61bc6d760b0e" +dependencies = [ + "try-lock", +] + +[[package]] +name = "wasi" +version = "0.11.0+wasi-snapshot-preview1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9c8d87e72b64a3b4db28d11ce29237c246188f4f51057d65a7eab63b7987e423" + +[[package]] +name = "wasi" +version = "0.14.2+wasi-0.2.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9683f9a5a998d873c0d21fcbe3c083009670149a8fab228644b8bd36b2c48cb3" +dependencies = [ + "wit-bindgen-rt", +] + +[[package]] +name = "winapi" +version = "0.3.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5c839a674fcd7a98952e593242ea400abe93992746761e38641405d28b00f419" +dependencies = [ + "winapi-i686-pc-windows-gnu", + "winapi-x86_64-pc-windows-gnu", +] + +[[package]] +name = "winapi-i686-pc-windows-gnu" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ac3b87c63620426dd9b991e5ce0329eff545bccbbb34f3be09ff6fb6ab51b7b6" + +[[package]] +name = "winapi-x86_64-pc-windows-gnu" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"712e227841d057c1ee1cd2fb22fa7e5a5461ae8e48fa2ca79ec42cfc1931183f" + +[[package]] +name = "windows-sys" +version = "0.52.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "282be5f36a8ce781fad8c8ae18fa3f9beff57ec1b52cb3de0789201425d9a33d" +dependencies = [ + "windows-targets", +] + +[[package]] +name = "windows-targets" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9b724f72796e036ab90c1021d4780d4d3d648aca59e491e6b98e725b84e99973" +dependencies = [ + "windows_aarch64_gnullvm", + "windows_aarch64_msvc", + "windows_i686_gnu", + "windows_i686_gnullvm", + "windows_i686_msvc", + "windows_x86_64_gnu", + "windows_x86_64_gnullvm", + "windows_x86_64_msvc", +] + +[[package]] +name = "windows_aarch64_gnullvm" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "32a4622180e7a0ec044bb555404c800bc9fd9ec262ec147edd5989ccd0c02cd3" + +[[package]] +name = "windows_aarch64_msvc" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "09ec2a7bb152e2252b53fa7803150007879548bc709c039df7627cabbd05d469" + +[[package]] +name = "windows_i686_gnu" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8e9b5ad5ab802e97eb8e295ac6720e509ee4c243f69d781394014ebfe8bbfa0b" + +[[package]] +name = "windows_i686_gnullvm" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0eee52d38c090b3caa76c563b86c3a4bd71ef1a819287c19d586d7334ae8ed66" + +[[package]] +name = "windows_i686_msvc" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "240948bc05c5e7c6dabba28bf89d89ffce3e303022809e73deaefe4f6ec56c66" + +[[package]] +name = "windows_x86_64_gnu" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"147a5c80aabfbf0c7d901cb5895d1de30ef2907eb21fbbab29ca94c5b08b1a78" + +[[package]] +name = "windows_x86_64_gnullvm" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "24d5b23dc417412679681396f2b49f3de8c1473deb516bd34410872eff51ed0d" + +[[package]] +name = "windows_x86_64_msvc" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "589f6da84c646204747d1270a2a5661ea66ed1cced2631d546fdfb155959f9ec" + +[[package]] +name = "wit-bindgen-rt" +version = "0.39.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6f42320e61fe2cfd34354ecb597f86f413484a798ba44a8ca1165c58d42da6c1" +dependencies = [ + "bitflags", +] + +[[package]] +name = "zerocopy" +version = "0.8.25" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a1702d9583232ddb9174e01bb7c15a2ab8fb1bc6f227aa1233858c351a3ba0cb" +dependencies = [ + "zerocopy-derive", +] + +[[package]] +name = "zerocopy-derive" +version = "0.8.25" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "28a6e20d751156648aa063f3800b706ee209a32c0b4d9f24be3d980b01be55ef" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] diff --git a/rs_controller/Cargo.toml b/rs_controller/Cargo.toml new file mode 100644 index 000000000..08c250c15 --- /dev/null +++ b/rs_controller/Cargo.toml @@ -0,0 +1,35 @@ +[package] +name = "union_rust_controller" +version = "0.1.0" +edition = "2021" + +[lib] +name = "union_rust_controller" +path = "src/lib.rs" +crate-type = ["cdylib"] + +[dependencies] +cloudidl = { path = "../../cloud/gen/pb_rust" } +pyo3 = { version = "0.24", features = ["extension-module", "abi3-py310"] } +pyo3-async-runtimes = { version = "0.24", features = ["tokio-runtime"] } +tokio = { version = "1.0", features = ["full"] } +tonic = "0.12" +prost = { version = "0.13.5", features = ["std"] } +prost-types = { version = "0.12", features = ["std"] } +futures = "0.3" +tower = "0.4" 
+tower-http = { version = "0.5", features = ["trace"] } +tracing = "0.1" +tracing-subscriber = "0.3" +async-trait = "0.1" +thiserror = "1.0" +pyo3-build-config = "0.24.2" + +[build-dependencies] +pyo3 = { version = "0.24", features = [ + "extension-module", + "abi3-py310", +] } # Use your pyo3 version + +# Need to do this for some reason otherwise maturin develop fails horribly +# export RUSTFLAGS="-C link-arg=-undefined -C link-arg=dynamic_lookup" diff --git a/rs_controller/Dockerfile.maturinx b/rs_controller/Dockerfile.maturinx new file mode 100644 index 000000000..fc48de38e --- /dev/null +++ b/rs_controller/Dockerfile.maturinx @@ -0,0 +1,13 @@ +ARG ARCH=aarch64 +FROM quay.io/pypa/manylinux_2_28_${ARCH} + +# Install any extra system dependencies (beyond manylinux defaults; add yours if needed, e.g., for your Rust crate) +RUN yum install -y --enablerepo=powertools curl pkgconfig openssl-devel gcc-c++ make libffi-devel vim + +# Install Rust +RUN curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y +ENV PATH="/root/.cargo/bin:${PATH}" + +# Install Maturin using one of the available Pythons (e.g., 3.12; manylinux has 3.8–3.13+) +RUN /opt/python/cp310-cp310/bin/pip install maturin + diff --git a/rs_controller/Makefile b/rs_controller/Makefile new file mode 100644 index 000000000..95839d487 --- /dev/null +++ b/rs_controller/Makefile @@ -0,0 +1,43 @@ +# Makefile for building wheel-builder images and wheels using those images +# This is for making the main flyte base with_local_v2() image, which needs to +# be multi-arch to run on live clusters. + + +# Build both architecture-specific builder images +build-builders: build-arm64 build-amd64 + +# Build the arm64 builder image +build-arm64: + docker buildx build --platform linux/arm64 --build-arg ARCH=aarch64 -f Dockerfile.maturinx -t wheel-builder:arm64 . 
+ +# Build the amd64 builder image (emulated on arm64 host) +build-amd64: + docker buildx build --platform linux/amd64 --build-arg ARCH=x86_64 -f Dockerfile.maturinx -t wheel-builder:amd64 . + +CARGO_CACHE_DIR := docker_cargo_cache +DIST_DIRS := dist + +dist-dirs: + mkdir -p $(DIST_DIRS) $(CARGO_CACHE_DIR) + +define BUILD_WHEELS_RECIPE +docker run --rm \ + -v $(PWD):/io \ + -v $(PWD)/docker_cargo_cache:/root/.cargo/registry \ + -v $(CLOUD_REPO):/cloud \ + -v $(PWD)/temp_cargo_config.toml:/io/.cargo/config.toml \ + wheel-builder:$(1) /bin/bash -c "\ + cd /io; \ + /opt/python/cp310-cp310/bin/maturin build --release --find-interpreter --out /io/dist/ \ + " +endef + +# Build wheels for arm64 (depends on dist-dirs) +build-wheels-arm64: dist-dirs + $(call BUILD_WHEELS_RECIPE,arm64) + +# Build wheels for amd64 (depends on dist-dirs) +build-wheels-amd64: dist-dirs + $(call BUILD_WHEELS_RECIPE,amd64) + +build-wheels: build-wheels-arm64 build-wheels-amd64 diff --git a/rs_controller/build.rs b/rs_controller/build.rs new file mode 100644 index 000000000..40462f970 --- /dev/null +++ b/rs_controller/build.rs @@ -0,0 +1,4 @@ +// build.rs +fn main() { + // pyo3_build_config::use_pyo3_cfgs(); +} \ No newline at end of file diff --git a/rs_controller/pyproject.toml b/rs_controller/pyproject.toml new file mode 100644 index 000000000..5431a94d5 --- /dev/null +++ b/rs_controller/pyproject.toml @@ -0,0 +1,15 @@ +[build-system] +requires = ["maturin>=1.4,<2.0"] +build-backend = "maturin" + +[project] +name = "rust_controller" +version = "0.1.0" +description = "Rust controller for Union" +requires-python = ">=3.10" +classifiers = ["Programming Language :: Python", "Programming Language :: Rust"] + +[tool.maturin] +module-name = "flyte_controller_base" +features = ["pyo3/extension-module"] + diff --git a/rs_controller/src/action.rs b/rs_controller/src/action.rs new file mode 100644 index 000000000..0abd52792 --- /dev/null +++ b/rs_controller/src/action.rs @@ -0,0 +1,209 @@ +use 
pyo3::prelude::*; + +use cloudidl::{ + cloudidl::workflow::{ActionIdentifier, ActionUpdate, Phase, TaskSpec}, + flyteidl::core::ExecutionError, +}; +use tracing::debug; + +#[pyclass(eq, eq_int)] +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum ActionType { + Task = 0, + Trace = 1, +} + +#[pyclass(dict, get_all, set_all)] +#[derive(Debug, Clone, PartialEq)] +pub struct Action { + pub action_id: ActionIdentifier, + pub parent_action_name: String, + pub action_type: ActionType, + pub friendly_name: Option, + pub group: Option, + pub task: Option, + pub inputs_uri: Option, + pub run_output_base: Option, + pub realized_outputs_uri: Option, + pub err: Option, + pub phase: Option, + pub started: bool, + pub retries: u32, + pub client_err: Option, // Changed from PyErr to String for serializability + pub cache_key: Option, +} + +impl Action { + pub fn get_run_name(&self) -> String { + match self.action_id.run.clone() { + Some(run_id) => run_id.name, + None => String::from("missing run name"), + } + } + + pub fn get_action_name(&self) -> String { + self.action_id.name.clone() + } + + pub fn set_client_err(&mut self, err: String) { + debug!( + "Setting client error on action {:?} to {}", + self.action_id, err + ); + self.client_err = Some(err); + } + + pub fn mark_cancelled(&mut self) { + debug!("Marking action {:?} as cancelled", self.action_id); + self.mark_started(); + self.phase = Some(Phase::Aborted); + } + + pub fn mark_started(&mut self) { + debug!("Marking action {:?} as started", self.action_id); + self.started = true; + // clear self.task in the future to save memory + } + + pub fn merge_update(&mut self, obj: &ActionUpdate) { + if let Ok(new_phase) = Phase::try_from(obj.phase) { + if self.phase.is_none() || self.phase != Some(new_phase) { + self.phase = Some(new_phase); + if obj.error.is_some() { + self.err = obj.error.clone(); + } + } + } + if !obj.output_uri.is_empty() { + self.realized_outputs_uri = Some(obj.output_uri.clone()); + } + self.started = 
true; + } + + pub fn new_from_update(parent_action_name: String, obj: ActionUpdate) -> Self { + let action_id = obj.action_id.unwrap(); + let phase = Phase::try_from(obj.phase).unwrap(); + Action { + action_id: action_id.clone(), + parent_action_name, + action_type: ActionType::Task, + friendly_name: None, + group: None, + task: None, + inputs_uri: None, + run_output_base: None, + realized_outputs_uri: Some(obj.output_uri), + err: obj.error, + phase: Some(phase), + started: true, + retries: 0, + client_err: None, + cache_key: None, + } + } + + pub fn is_action_terminal(&self) -> bool { + if let Some(phase) = &self.phase { + matches!( + phase, + Phase::Succeeded | Phase::Failed | Phase::Aborted | Phase::TimedOut + ) + } else { + false + } + } + + // action here is the submitted action, invoked by the informer's manual submit. + pub fn merge_from_submit(&mut self, action: &Action) { + self.run_output_base = action.run_output_base.clone(); + self.inputs_uri = action.inputs_uri.clone(); + self.group = action.group.clone(); + self.friendly_name = action.friendly_name.clone(); + + if !self.started { + self.task = action.task.clone(); + } + + self.cache_key = action.cache_key.clone(); + } +} + +#[pymethods] +impl Action { + #[staticmethod] + pub fn from_task( + sub_action_id: ActionIdentifier, + parent_action_name: String, + group_data: Option, + task_spec: TaskSpec, + inputs_uri: String, + run_output_base: String, + cache_key: Option, + ) -> Self { + debug!("Creating Action from task for ID {:?}", &sub_action_id); + Action { + action_id: sub_action_id, + parent_action_name, + action_type: ActionType::Task, + friendly_name: task_spec + .task_template + .as_ref() + .and_then(|tt| tt.id.as_ref().and_then(|id| Some(id.name.clone()))), + group: group_data, + task: Some(task_spec), + inputs_uri: Some(inputs_uri), + run_output_base: Some(run_output_base), + realized_outputs_uri: None, + err: None, + phase: Some(Phase::Unspecified), + started: false, + retries: 0, + client_err: 
None, + cache_key, + } + } + + /// This creates a new action for tracing purposes. It is used to track the execution of a trace + #[staticmethod] + pub fn from_trace( + parent_action_name: String, + action_id: ActionIdentifier, + friendly_name: String, + group_data: Option, + inputs_uri: String, + outputs_uri: String, + ) -> Self { + debug!("Creating Action from trace for ID {:?}", &action_id); + Action { + action_id, + parent_action_name, + action_type: ActionType::Trace, + friendly_name: Some(friendly_name), + group: group_data, + task: None, + inputs_uri: Some(inputs_uri), + run_output_base: None, + realized_outputs_uri: Some(outputs_uri), + err: None, + phase: Some(Phase::Succeeded), + started: true, + retries: 0, + client_err: None, + cache_key: None, + } + } + + #[getter(run_name)] + fn run_name(&self) -> String { + self.get_run_name() + } + + #[getter(name)] + fn name(&self) -> String { + self.get_action_name() + } + + fn has_error(&self) -> bool { + self.err.is_some() || self.client_err.is_some() + } +} diff --git a/rs_controller/src/informer.rs b/rs_controller/src/informer.rs new file mode 100644 index 000000000..a09b83ee1 --- /dev/null +++ b/rs_controller/src/informer.rs @@ -0,0 +1,314 @@ +use crate::action::Action; +use crate::ControllerError; +use cloudidl::cloudidl::workflow::state_service_client::StateServiceClient; +use cloudidl::cloudidl::workflow::watch_request; +use cloudidl::cloudidl::workflow::watch_response::Message; +use cloudidl::cloudidl::workflow::{ActionIdentifier, RunIdentifier, WatchRequest, WatchResponse}; +use std::collections::HashMap; +use std::sync::Arc; +use std::time::Duration; +use tokio::select; +use tokio::sync::RwLock; +use tokio::sync::{mpsc, oneshot, Notify}; +use tokio::task::JoinHandle; +use tokio::time::sleep; +use tonic::transport::channel::Channel; +use tonic::transport::Endpoint; +use tracing::{debug, error, info, warn}; +use tracing_subscriber::fmt; + +#[derive(Clone, Debug)] +pub struct Informer { + client: 
StateServiceClient<Channel>, // NOTE(review): generic parameters in this file were stripped during extraction and are reconstructed here — confirm against the original source.
    run_id: RunIdentifier,
    // Cache of all actions seen for this parent, keyed by action name.
    action_cache: Arc<RwLock<HashMap<String, Action>>>,
    parent_action_name: String,
    // Actions flow back to the controller's worker loop through this queue.
    shared_queue: mpsc::Sender<Action>,
    // Notified once the watch stream delivers its sentinel control message.
    ready: Arc<Notify>,
    // One-shot senders fired when an action reaches a terminal state.
    completion_events: Arc<RwLock<HashMap<String, oneshot::Sender<()>>>>,
}

impl Informer {
    /// Create an informer for `parent_action_name` within `run_id`.
    pub fn new(
        client: StateServiceClient<Channel>,
        run_id: RunIdentifier,
        parent_action_name: String,
        shared_queue: mpsc::Sender<Action>,
    ) -> Self {
        Informer {
            client,
            run_id,
            action_cache: Arc::new(RwLock::new(HashMap::new())),
            parent_action_name,
            shared_queue,
            ready: Arc::new(Notify::new()),
            completion_events: Arc::new(RwLock::new(HashMap::new())),
        }
    }

    /// Copy `action.client_err` onto the cached copy of the same action, if any.
    pub async fn set_action_client_err(&self, action: &Action) -> Result<(), ControllerError> {
        if let Some(client_err) = &action.client_err {
            let mut cache = self.action_cache.write().await;
            let action_name = action.action_id.name.clone();
            if let Some(cached) = cache.get_mut(&action_name) {
                cached.set_client_err(client_err.clone());
                Ok(())
            } else {
                Err(ControllerError::RuntimeError(format!(
                    "Action {} not found in cache",
                    action_name
                )))
            }
        } else {
            Ok(())
        }
    }

    /// Process one message from the watch stream.
    ///
    /// Returns `Ok(Some(action))` when the cache was updated and the action
    /// should be re-enqueued, `Ok(None)` for control (sentinel) messages.
    async fn handle_watch_response(
        &self,
        response: WatchResponse,
    ) -> Result<Option<Action>, ControllerError> {
        debug!(
            "Informer for {:?}::{} processing incoming message {:?}",
            self.run_id.name, self.parent_action_name, &response
        );
        if let Some(msg) = response.message {
            match msg {
                Message::ControlMessage(_) => {
                    debug!("Received sentinel for parent {}", self.parent_action_name);
                    // The sentinel means the server has replayed current state;
                    // unblock anyone waiting in `wait_ready_or_timeout`.
                    self.ready.notify_one();
                    Ok(None)
                }
                Message::ActionUpdate(action_update) => {
                    debug!("Received action update: {:?}", action_update.action_id);
                    let mut cache = self.action_cache.write().await;
                    let action_name = action_update
                        .action_id
                        .as_ref()
                        .map(|act_id| act_id.name.clone())
                        .ok_or_else(|| {
                            ControllerError::RuntimeError(format!(
                                "Action update received without a name: {:?}",
                                action_update
                            ))
                        })?;

                    if let Some(existing) = cache.get_mut(&action_name) {
                        existing.merge_update(&action_update);
                        // Don't fire a completion event here - returning Some
                        // re-enqueues the action and the controller detects and
                        // fires completion.
                    } else {
                        debug!(
                            "Action update for {:?} not in cache, adding",
                            action_update.action_id
                        );
                        let action_from_update =
                            Action::new_from_update(self.parent_action_name.clone(), action_update);
                        cache.insert(action_name.clone(), action_from_update);
                        // No completion event here either: the submit that
                        // registers the completion event may not have fired yet,
                        // so just add to the cache for now.
                    }

                    Ok(Some(cache.get(&action_name).unwrap().clone()))
                }
            }
        } else {
            Err(ControllerError::BadContext(
                "No message in response".to_string(),
            ))
        }
    }

    /// Run the server-side watch stream; only returns when it errors out.
    async fn watch_actions(&self) -> ControllerError {
        let action_id = ActionIdentifier {
            name: self.parent_action_name.clone(),
            run: Some(self.run_id.clone()),
        };
        let request = WatchRequest {
            filter: Some(watch_request::Filter::ParentActionId(action_id)),
        };

        let stream = self.client.clone().watch(request).await;
        let mut stream = match stream {
            Ok(s) => s.into_inner(),
            Err(e) => {
                error!("Failed to start watch stream: {:?}", e);
                return ControllerError::from(e);
            }
        };

        loop {
            match stream.message().await {
                Ok(Some(response)) => match self.handle_watch_response(response).await {
                    Ok(Some(action)) => {
                        if let Err(e) = self.shared_queue.send(action).await {
                            error!(
                                "Informer watch failed sending action back to shared queue: {:?}",
                                e
                            );
                            return ControllerError::RuntimeError(format!(
                                "Failed to send action to shared queue: {}",
                                e
                            ));
                        }
                    }
                    Ok(None) => {
                        debug!("Received None from handle_watch_response, continuing watch loop.");
                    }
                    Err(err) => {
                        // This cascades up to the controller to restart the informer;
                        // too many informer restarts should fail the controller.
                        error!("Error in informer watch {:?}", err);
                        return err;
                    }
                },
                Ok(None) => {
                    debug!("Stream received empty message, maybe no more messages? Repeating watch loop.");
                }
                Err(e) => {
                    error!("Error receiving message from stream: {:?}", e);
                    return ControllerError::from(e);
                }
            }
        }
    }

    /// Wait for the watch stream's sentinel, or give up after 100ms.
    async fn wait_ready_or_timeout(ready: Arc<Notify>) -> Result<(), ControllerError> {
        select! {
            _ = ready.notified() => {
                debug!("Ready sentinel ack'ed");
                Ok(())
            }
            // Fixed: the timeout error previously carried an empty message.
            _ = sleep(Duration::from_millis(100)) => Err(ControllerError::SystemError(
                "Timed out waiting for informer ready sentinel".to_string(),
            )),
        }
    }

    /// Spawn the watch loop and wait (briefly) for it to become ready.
    /// A ready timeout is logged but not fatal — the stream may still connect.
    pub async fn start(informer: Arc<Informer>) -> Result<tokio::task::JoinHandle<()>, ControllerError> {
        let me = informer.clone();
        let ready = me.ready.clone();
        let watch_handle = tokio::spawn(async move {
            // TODO: surface errors from the watch loop instead of dropping them.
            let _ = me.watch_actions().await;
        });

        match Self::wait_ready_or_timeout(ready).await {
            Ok(()) => Ok(watch_handle),
            Err(_) => {
                warn!("Timed out waiting for sentinel");
                Ok(watch_handle)
            }
        }
    }

    /// Look up a cached action by name.
    pub async fn get_action(&self, action_name: String) -> Option<Action> {
        let cache = self.action_cache.read().await;
        cache.get(&action_name).cloned()
    }

    /// Register a completion event for `action` and enqueue it for processing.
    pub async fn submit_action(
        &self,
        action: Action,
        done_tx: oneshot::Sender<()>,
    ) -> Result<(), ControllerError> {
        let action_name = action.action_id.name.clone();

        // Store the completion event sender first so a fast terminal update
        // cannot race past an unregistered event.
        {
            let mut completion_events = self.completion_events.write().await;
            completion_events.insert(action_name.clone(), done_tx);
        }

        // Add action to shared queue.
        self.shared_queue.send(action).await.map_err(|e| {
            ControllerError::RuntimeError(format!("Failed to send action to shared queue: {}", e))
        })?;

        Ok(())
    }

    /// Fire (and consume) the completion event registered for `action_name`.
    pub async fn fire_completion_event(&self, action_name: &str) -> Result<(), ControllerError> {
        info!("Firing completion event for action: {}", action_name);
        let mut completion_events = self.completion_events.write().await;
        if let Some(done_tx) = completion_events.remove(action_name) {
            done_tx.send(()).map_err(|_| {
                ControllerError::RuntimeError(format!(
                    "Failed to send completion event for action: {}",
                    action_name
                ))
            })?;
        } else {
            // Fixed: removed leftover "-----" debug noise from the log message.
            error!("No completion event found for action: {}", action_name);
            // Return error, which should cause the informer to re-enqueue.
            return Err(ControllerError::RuntimeError(format!(
                "No completion event found for action: {}. This may be because the informer is still starting up.",
                action_name
            )));
        }
        Ok(())
    }
}

/// Manual integration entry point: connects to a locally running server.
async fn informer_main() {
    // Create the shared queue that will be shared between the Controller and
    // the informer. The receiver side is unused in this manual driver.
    let (tx, _rx) = mpsc::channel::<Action>(64);
    let endpoint = Endpoint::from_static("http://localhost:8090");
    let channel = endpoint.connect().await.unwrap();
    let client = StateServiceClient::new(channel);

    let run_id = RunIdentifier {
        org: String::from("testorg"),
        project: String::from("testproject"),
        domain: String::from("development"),
        name: String::from("qdtc266r2z8clscl2lj5"),
    };

    let informer = Arc::new(Informer::new(client, run_id, "a0".to_string(), tx.clone()));

    let watch_task = Informer::start(informer.clone()).await;

    println!("{:?}: {:?}", informer, watch_task);
}

/// Install a global DEBUG tracing subscriber exactly once (test helper).
fn init_tracing() {
    static INIT: std::sync::Once = std::sync::Once::new();
    INIT.call_once(|| {
        let subscriber = fmt()
            .with_max_level(tracing::Level::DEBUG)
            .with_test_writer() // so logs show in test output
            .finish();
        tracing::subscriber::set_global_default(subscriber)
            .expect("setting default subscriber failed");
    });
}

#[cfg(test)]
mod tests {
    use super::*;

    // Integration test: requires a live state service on localhost:8090.
    // Marked #[ignore] so a plain `cargo test` does not hang without one.
    #[test]
    #[ignore]
    fn test_informer() {
        init_tracing();
        let rt = tokio::runtime::Runtime::new().unwrap();
        rt.block_on(informer_main());
+ } +} diff --git a/rs_controller/src/lib.rs b/rs_controller/src/lib.rs new file mode 100644 index 000000000..ca1a34eea --- /dev/null +++ b/rs_controller/src/lib.rs @@ -0,0 +1,508 @@ +mod action; +mod informer; + +use std::sync::Arc; +use std::time::Duration; + +use futures::TryFutureExt; +use pyo3::prelude::*; +use tokio::sync::{mpsc, Mutex, Notify}; +use tracing::{debug, error, info, warn}; + +use thiserror::Error; + +use crate::action::{Action, ActionType}; +use crate::informer::Informer; + +use cloudidl::cloudidl::workflow::queue_service_client::QueueServiceClient; +use cloudidl::cloudidl::workflow::state_service_client::StateServiceClient; +use cloudidl::cloudidl::workflow::{ + enqueue_action_request, ActionIdentifier, EnqueueActionRequest, EnqueueActionResponse, + RunIdentifier, TaskAction, TaskIdentifier, +}; +use cloudidl::google; +use google::protobuf::StringValue; +use pyo3::exceptions; +use pyo3::types::PyAny; +use pyo3_async_runtimes::tokio::future_into_py; +use pyo3_async_runtimes::tokio::get_runtime; +use tokio::sync::{oneshot, OnceCell}; +use tokio::time::sleep; +use tonic::transport::Endpoint; +use tonic::Status; +use tracing_subscriber::FmtSubscriber; + +#[derive(Error, Debug)] +pub enum ControllerError { + #[error("Bad context: {0}")] + BadContext(String), + #[error("Runtime error: {0}")] + RuntimeError(String), + #[error("System error: {0}")] + SystemError(String), + #[error("gRPC error: {0}")] + GrpcError(#[from] tonic::Status), + #[error("Task error: {0}")] + TaskError(String), +} + +impl From for ControllerError { + fn from(err: tonic::transport::Error) -> Self { + ControllerError::SystemError(format!("Transport error: {:?}", err)) + } +} + +impl From for PyErr { + // can better map errors in the future + fn from(err: ControllerError) -> Self { + exceptions::PyRuntimeError::new_err(err.to_string()) + } +} + +struct CoreBaseController { + endpoint: String, + state_client: StateServiceClient, + queue_client: QueueServiceClient, + informer: 
OnceCell>, + shared_queue: mpsc::Sender, + rx_of_shared_queue: Arc>>, +} + +impl CoreBaseController { + pub fn try_new(endpoint: String) -> Result, ControllerError> { + info!("Creating CoreBaseController with endpoint {:?}", endpoint); + // play with taking str slice instead of String instead of intentionally leaking. + let endpoint_static: &'static str = Box::leak(Box::new(endpoint.clone().into_boxed_str())); + // shared queue + let (shared_tx, rx_of_shared_queue) = mpsc::channel::(64); + + let rt = get_runtime(); + let (state_client, queue_client) = rt.block_on(async { + // Need to update to with auth to read API key + let endpoint = Endpoint::from_static(&endpoint_static); + let channel = endpoint.connect().await.map_err(|e| ControllerError::from(e))?; + Ok::<_, ControllerError>(( + StateServiceClient::new(channel.clone()), + QueueServiceClient::new(channel), + )) + })?; + + let real_base_controller = CoreBaseController { + endpoint, + state_client, + queue_client, + informer: OnceCell::new(), + shared_queue: shared_tx, + rx_of_shared_queue: Arc::new(tokio::sync::Mutex::new(rx_of_shared_queue)), + }; + + let real_base_controller = Arc::new(real_base_controller); + // Start the background worker + let controller_clone = real_base_controller.clone(); + rt.spawn(async move { + controller_clone.bg_worker().await; + }); + Ok(real_base_controller) + } + + async fn bg_worker(&self) { + const MIN_BACKOFF_ON_ERR: Duration = Duration::from_millis(100); + const MAX_RETRIES: u32 = 5; + + debug!( + "Launching core controller background task on thread {:?}", + std::thread::current().name() + ); + loop { + // Receive actions from shared queue + let mut rx = self.rx_of_shared_queue.lock().await; + match rx.recv().await { + Some(mut action) => { + let run_name = &action + .action_id + .run + .as_ref() + .map_or(String::from(""), |i| i.name.clone()); + debug!( + "Controller worker processing action {}::{}", + run_name, action.action_id.name + ); + + // Drop the mutex guard before 
processing + drop(rx); + + match self.handle_action(&mut action).await { + Ok(_) => {} + Err(e) => { + error!("Error in controller loop: {:?}", e); + // Handle backoff and retry logic + sleep(MIN_BACKOFF_ON_ERR).await; + action.retries += 1; + + if action.retries > MAX_RETRIES { + error!( + "Controller failed processing {}::{}, system retries {} crossed threshold {}", + run_name, action.action_id.name, action.retries, MAX_RETRIES + ); + action.client_err = Some(format!( + "Controller failed {}::{}, system retries {} crossed threshold {}", + run_name, action.action_id.name, action.retries, MAX_RETRIES + )); + + // Fire completion event for failed action + if let Some(informer) = self.informer.get() { + // todo: check these two errors + + // Before firing completion event, update the action in the + // informer, otherwise client_err will not be set. + let _ = informer.set_action_client_err(&action).await; + let _ = informer + .fire_completion_event(&action.action_id.name) + .await; + } else { + error!( + "Max retries hit for action but informer still not yet initialized for action: {}", + action.action_id.name + ); + } + } else { + // Re-queue the action for retry + info!( + "Re-queuing action {}::{} for retry, attempt {}/{}", + run_name, action.action_id.name, action.retries, MAX_RETRIES + ); + if let Err(send_err) = self.shared_queue.send(action).await { + error!("Failed to re-queue action for retry: {}", send_err); + } + } + } + } + } + None => { + warn!("Shared queue channel closed, stopping bg_worker"); + break; + } + } + } + } + + async fn handle_action(&self, action: &mut Action) -> Result<(), ControllerError> { + if !action.started { + // Action not started, launch it + self.bg_launch(action).await?; + } else if action.is_action_terminal() { + // Action is terminal, fire completion event + if let Some(informer) = self.informer.get() { + debug!( + "handle action firing completion event for {:?}", + &action.action_id.name + ); + informer + 
.fire_completion_event(&action.action_id.name) + .await?; + } else { + error!( + "Informer not yet initialized for action: {}", + action.action_id.name + ); + return Err(ControllerError::BadContext(format!( + "Informer not initialized for action: {}. This may be because the informer is still starting up.", + action.action_id.name + ))); + } + } else { + // Action still in progress + debug!("Resource {} still in progress...", action.action_id.name); + } + Ok(()) + } + + async fn bg_launch(&self, action: &Action) -> Result<(), ControllerError> { + match self.launch_task(action).await { + Ok(_) => { + debug!("Successfully launched action: {}", action.action_id.name); + Ok(()) + } + Err(e) => { + error!( + "Failed to launch action: {}, error: {}", + action.action_id.name, e + ); + Err(ControllerError::RuntimeError(format!( + "Launch failed: {}", + e + ))) + } + } + } + + async fn cancel_action(&self, action: &mut Action) -> Result<(), ControllerError> { + if action.is_action_terminal() { + info!( + "Action {} is already terminal, no need to cancel.", + action.action_id.name + ); + return Ok(()); + } + + debug!("Cancelling action: {}", action.action_id.name); + action.mark_cancelled(); + + if let Some(informer) = self.informer.get() { + let _ = informer + .fire_completion_event(&action.action_id.name) + .await?; + } else { + debug!( + "Informer missing when trying to cancel action: {}", + action.action_id.name + ); + } + Ok(()) + } + + async fn get_action(&self, action_id: ActionIdentifier) -> Result { + if let Some(informer) = self.informer.get() { + let action_name = action_id.name.clone(); + match informer.get_action(action_name).await { + Some(action) => Ok(action), + None => Err(ControllerError::RuntimeError(format!( + "Action not found: {}", + action_id.name + ))), + } + } else { + Err(ControllerError::BadContext( + "Informer not initialized".to_string(), + )) + } + } + + fn create_enqueue_action_request( + &self, + action: &Action, + ) -> Result { + // todo-pr: 
handle trace action + let task_identifier = action + .task + .as_ref() + .and_then(|task| task.task_template.as_ref()) + .and_then(|task_template| task_template.id.as_ref()) + .and_then(|core_task_id| { + Some(TaskIdentifier { + version: core_task_id.version.clone(), + org: core_task_id.org.clone(), + project: core_task_id.project.clone(), + domain: core_task_id.domain.clone(), + name: core_task_id.name.clone(), + }) + }) + .ok_or(ControllerError::RuntimeError(format!( + "TaskIdentifier missing from Action {:?}", + action + )))?; + + let input_uri = action + .inputs_uri + .clone() + .ok_or(ControllerError::RuntimeError(format!( + "Inputs URI missing from Action {:?}", + action + )))?; + let run_output_base = + action + .run_output_base + .clone() + .ok_or(ControllerError::RuntimeError(format!( + "Run output base missing from Action {:?}", + action + )))?; + let group = action.group.clone().unwrap_or_default(); + let task_action = TaskAction { + id: Some(task_identifier), + spec: action.task.clone(), + cache_key: action + .cache_key + .as_ref() + .map(|ck| StringValue { value: ck.clone() }), + }; + + Ok(EnqueueActionRequest { + action_id: Some(action.action_id.clone()), + parent_action_name: Some(action.parent_action_name.clone()), + spec: Some(enqueue_action_request::Spec::Task(task_action)), + run_spec: None, + input_uri, + run_output_base, + group, + subject: String::default(), // Subject is not used in the current implementation + }) + } + + async fn launch_task(&self, action: &Action) -> Result { + if !action.started && action.task.is_some() { + let enqueue_request = self + .create_enqueue_action_request(action) + .expect("Failed to create EnqueueActionRequest"); + let mut client = self.queue_client.clone(); + // todo: tonic doesn't seem to have wait_for_ready, or maybe the .ready is already doing this. 
+ let enqueue_result = client.enqueue_action(enqueue_request).await; + match enqueue_result { + Ok(response) => { + debug!("Successfully launched action: {:?}", action.action_id); + Ok(response.into_inner()) + } + Err(e) => { + if e.code() == tonic::Code::AlreadyExists { + info!( + "Action {} already exists, continuing to monitor.", + action.action_id.name + ); + Ok(EnqueueActionResponse {}) + } else { + error!( + "Failed to launch action: {:?}, backing off...", + action.action_id + ); + error!("Error details: {}", e); + // Handle backoff logic here + Err(e) + } + } + } + } else { + debug!( + "Action {} is already started or has no task, skipping launch.", + action.action_id.name + ); + Ok(EnqueueActionResponse {}) + } + } + + pub async fn _submit_action(&self, action: Action) -> Result { + let action_name = action.action_id.name.clone(); + let parent_action_name = action.parent_action_name.clone(); + // The first action that gets submitted determines the run_id that will be used. + // This is obviously not going to work, + + let run_id = action + .action_id + .run + .clone() + .ok_or(ControllerError::RuntimeError(format!( + "Run ID missing from submit action {}", + action_name.clone() + )))?; + let informer: &Arc = self + .informer // OnceCell> + .get_or_try_init(|| async move { + info!("Creating informer set to run_id {:?}", run_id); + let inf = Arc::new(Informer::new( + self.state_client.clone(), + run_id, + parent_action_name, + self.shared_queue.clone(), + )); + + Informer::start(inf.clone()).await?; + + // Using PyErr for now, but any errors coming from the informer will not really + // be py errs, will need to add and map later. 
+ Ok::, ControllerError>(inf) + }) + .await?; + let (done_tx, done_rx) = oneshot::channel(); + informer.submit_action(action, done_tx).await?; + + done_rx.await.map_err(|_| { + ControllerError::BadContext(String::from("Failed to receive done signal from informer")) + })?; + debug!( + "Action {} complete, looking up final value and returning", + action_name + ); + + // get the action and return it + let final_action = informer.get_action(action_name).await; + final_action.ok_or(ControllerError::BadContext(String::from( + "Action not found after done", + ))) + } +} + +/// Base class for RemoteController to eventually inherit from +#[pyclass(subclass)] +struct BaseController(Arc); + +#[pymethods] +impl BaseController { + #[new] + #[pyo3(signature = (*, endpoint))] + fn new(endpoint: String) -> PyResult { + info!("Creating controller wrapper with endpoint {:?}", endpoint); + let core_base = CoreBaseController::try_new(endpoint)?; + Ok(BaseController(core_base)) + } + + /// `async def submit(self, action: Action) -> Action` + /// + /// Enqueue `action`. 
+ fn submit_action<'py>(&self, py: Python<'py>, action: Action) -> PyResult> { + let real_base = self.0.clone(); + let py_fut = future_into_py(py, async move { + let action_id = action.action_id.clone(); + real_base._submit_action(action).await.map_err(|e| { + error!("Error submitting action {:?}: {:?}", action_id, e); + exceptions::PyRuntimeError::new_err(format!("Failed to submit action: {}", e)) + }) + }); + py_fut + } + + fn cancel_action<'py>(&self, py: Python<'py>, action: Action) -> PyResult> { + let real_base = self.0.clone(); + let mut a = action.clone(); + let py_fut = future_into_py(py, async move { + real_base.cancel_action(&mut a).await.map_err(|e| { + error!("Error cancelling action {:?}: {:?}", action.action_id, e); + exceptions::PyRuntimeError::new_err(format!("Failed to cancel action: {}", e)) + }) + }); + py_fut + } + + fn get_action<'py>( + &self, + py: Python<'py>, + action_id: ActionIdentifier, + ) -> PyResult> { + let real_base = self.0.clone(); + let py_fut = future_into_py(py, async move { + real_base.get_action(action_id.clone()).await.map_err(|e| { + error!("Error getting action {:?}: {:?}", action_id, e); + exceptions::PyRuntimeError::new_err(format!("Failed to cancel action: {}", e)) + }) + }); + py_fut + } +} + +use cloudidl::pymodules::cloud_mod; + +#[pymodule] +fn flyte_controller_base(py: Python<'_>, m: &Bound<'_, PyModule>) -> PyResult<()> { + static INIT: std::sync::Once = std::sync::Once::new(); + INIT.call_once(|| { + let subscriber = FmtSubscriber::builder() + .with_max_level(tracing::Level::DEBUG) + .finish(); + tracing::subscriber::set_global_default(subscriber) + .expect("Failed to set global tracing subscriber"); + }); + + m.add_class::()?; + m.add_class::()?; + m.add_class::()?; + cloud_mod(py, m)?; + Ok(()) +} diff --git a/rs_controller/temp_cargo_config.toml b/rs_controller/temp_cargo_config.toml new file mode 100644 index 000000000..1ea088bc1 --- /dev/null +++ b/rs_controller/temp_cargo_config.toml @@ -0,0 +1,2 @@ 
+[patch.crates-io]
+cloudidl = { path = "/cloud/gen/pb_rust" }
diff --git a/src/flyte/_internal/controllers/__init__.py b/src/flyte/_internal/controllers/__init__.py
index 88b02cacd..ea7969168 100644
--- a/src/flyte/_internal/controllers/__init__.py
+++ b/src/flyte/_internal/controllers/__init__.py
@@ -15,7 +15,7 @@
 if TYPE_CHECKING:
     import concurrent.futures
 
-ControllerType = Literal["local", "remote"]
+ControllerType = Literal["local", "remote", "rust"]
 
 R = TypeVar("R")
 
@@ -120,7 +120,15 @@ def create_controller(
         case "remote" | "hybrid":
             from flyte._internal.controllers.remote import create_remote_controller
 
-            controller = create_remote_controller(**kwargs)
+            # controller = create_remote_controller(**kwargs)
+            from flyte._internal.controllers.remote._r_controller import RemoteController
+
+            # NOTE(review): endpoint is hard-coded for dev (docker-internal host) and
+            # kwargs are ignored — TODO confirm this is temporary before merge.
+            controller = RemoteController(endpoint="http://host.docker.internal:8090", workers=10, max_system_retries=5)
+        case "rust":
+            # hybrid case, despite the case statement above, meant for local runs not inside docker
+            from flyte._internal.controllers.remote._r_controller import RemoteController
+
+            controller = RemoteController(endpoint="http://localhost:8090", workers=10, max_system_retries=5)
         case _:
             raise ValueError(f"{ct} is not a valid controller type.")
diff --git a/src/flyte/_internal/controllers/_local_controller.py b/src/flyte/_internal/controllers/_local_controller.py
index d5de88012..9b9be43cb 100644
--- a/src/flyte/_internal/controllers/_local_controller.py
+++ b/src/flyte/_internal/controllers/_local_controller.py
@@ -4,7 +4,7 @@
 import os
 import pathlib
 import threading
-from typing import Any, Callable, Tuple, TypeVar
+from typing import Any, Callable, Tuple, TypeVar, Protocol
 
 import flyte.errors
 from flyte._cache.cache import VersionParameters, cache_from_request
@@ -22,6 +22,16 @@
 
 R = TypeVar("R")
 
+# Structural interface shared by the Local and Remote controllers.
+class ControllerProtocol(Protocol):
+    async def submit(self, _task: "TaskTemplate", *args, **kwargs) -> Any: ...
+    def submit_sync(self, _task: "TaskTemplate", *args, **kwargs) -> concurrent.futures.Future: ...
+    async def finalize_parent_action(self, action: "ActionID"): ...
+    async def get_action_outputs(
+        self, _interface: "NativeInterface", _func: Callable, *args, **kwargs
+    ) -> Tuple["TraceInfo", bool]: ...
+    async def record_trace(self, info: "TraceInfo"): ...
+    async def submit_task_ref(self, _task: "task_definition_pb2.TaskDetails", *args, **kwargs) -> Any: ...
+
 class _TaskRunner:
     """A task runner that runs an asyncio event loop on a background thread."""
@@ -69,7 +79,7 @@ def get_run_future(self, coro: Any) -> concurrent.futures.Future:
         return fut
 
-class LocalController:
+class LocalController(ControllerProtocol):
     def __init__(self):
         logger.debug("LocalController init")
         self._runner_map: dict[str, _TaskRunner] = {}
diff --git a/src/flyte/_internal/controllers/remote/_action.py b/src/flyte/_internal/controllers/remote/_action.py
index 52b756bc7..5a1518161 100644
--- a/src/flyte/_internal/controllers/remote/_action.py
+++ b/src/flyte/_internal/controllers/remote/_action.py
@@ -28,7 +28,7 @@ class Action:
     parent_action_name: str
     type: ActionType = "task"  # type of action, task or trace
     friendly_name: str | None = None
-    group: GroupData | None = None
+    group: str | None = None
     task: task_definition_pb2.TaskSpec | None = None
     trace: run_definition_pb2.TraceAction | None = None
     inputs_uri: str | None = None
@@ -117,7 +117,7 @@ def from_task(
         cls,
         parent_action_name: str,
         sub_action_id: identifier_pb2.ActionIdentifier,
-        group_data: GroupData | None,
+        group_data: str | None,
         task_spec: task_definition_pb2.TaskSpec,
         inputs_uri: str,
         run_output_base: str,
diff --git a/src/flyte/_internal/controllers/remote/_controller.py b/src/flyte/_internal/controllers/remote/_controller.py
index 987c3caeb..bafb5e5e4 100644
--- a/src/flyte/_internal/controllers/remote/_controller.py
+++ b/src/flyte/_internal/controllers/remote/_controller.py
@@ -19,7 +19,7 @@ from flyte._context import
internal_ctx from flyte._internal.controllers import TraceInfo from flyte._internal.controllers.remote._action import Action -from flyte._internal.controllers.remote._core import Controller +# from flyte._internal.controllers.remote._core import Controller from flyte._internal.controllers.remote._service_protocol import ClientSet from flyte._internal.runtime import convert, io from flyte._internal.runtime.task_serde import translate_task_to_wire @@ -30,6 +30,16 @@ from flyte.models import MAX_INLINE_IO_BYTES, ActionID, NativeInterface, SerializationContext from flyte.remote._task import TaskDetails +# Import the Rust Controller instead of the Python one +# try: +# from flyte._internal.controllers.remote.rust_controller import Controller +# +# except ImportError: +# # Fallback to Python implementation during development +# from flyte._internal.controllers.remote._core import Controller + +from flyte._internal.controllers.remote.rust_controller import Controller + R = TypeVar("R") MAX_TRACE_BYTES = MAX_INLINE_IO_BYTES diff --git a/src/flyte/_internal/controllers/remote/_r_controller.py b/src/flyte/_internal/controllers/remote/_r_controller.py new file mode 100644 index 000000000..008f47134 --- /dev/null +++ b/src/flyte/_internal/controllers/remote/_r_controller.py @@ -0,0 +1,557 @@ +from __future__ import annotations + +import asyncio +import concurrent.futures +import os +import threading +from asyncio import Event +from collections import defaultdict +from collections.abc import Callable +from pathlib import Path +from typing import Any, AsyncIterable, DefaultDict, Tuple, TypeVar + +from flyte_controller_base import Action, BaseController, cloudidl + +import flyte +import flyte.errors +import flyte.storage as storage +import flyte.types as types +from flyte._code_bundle import build_pkl_bundle +from flyte._context import internal_ctx +from flyte._internal.controllers import TraceInfo +from flyte._internal.runtime import convert, io +from 
flyte._internal.runtime.task_serde import translate_task_to_wire +from flyte._logging import logger +from flyte._protos.workflow import run_definition_pb2, task_definition_pb2 +from flyte._task import TaskTemplate +from flyte._utils.helpers import _selector_policy +from flyte.models import ActionID, NativeInterface, SerializationContext + +R = TypeVar("R") + + +async def upload_inputs_with_retry(serialized_inputs: AsyncIterable[bytes] | bytes, inputs_uri: str) -> None: + """ + Upload inputs to the specified URI with error handling. + + Args: + serialized_inputs: The serialized inputs to upload + inputs_uri: The destination URI + + Raises: + RuntimeSystemError: If the upload fails + """ + try: + # TODO Add retry decorator to this + await storage.put_stream(serialized_inputs, to_path=inputs_uri) + except Exception as e: + logger.exception("Failed to upload inputs") + raise flyte.errors.RuntimeSystemError(type(e).__name__, str(e)) from e + + +async def handle_action_failure(action: Action, task_name: str) -> Exception: + """ + Handle action failure by loading error details or raising a RuntimeSystemError. 
+ + Args: + action: The updated action + task_name: The name of the task + + Raises: + Exception: The converted native exception or RuntimeSystemError + """ + err = action.err or action.client_err + if not err and action.phase == 6: # PHASE_FAILED + logger.error(f"Server reported failure for action {action.name}, checking error file.") + try: + error_path = io.error_path(f"{action.run_output_base}/{action.action_id.name}/1") + err = await io.load_error(error_path) + except Exception as e: + logger.exception("Failed to load error file", e) + err = flyte.errors.RuntimeSystemError(type(e).__name__, f"Failed to load error file: {e}") + else: + logger.error(f"Server reported failure for action {action.action_id.name}, error: {err}") + + exc = convert.convert_error_to_native(err) + if not exc: + return flyte.errors.RuntimeSystemError("UnableToConvertError", f"Error in task {task_name}: {err}") + return exc + + +async def load_and_convert_outputs(iface: NativeInterface, realized_outputs_uri: str) -> Any: + """ + Load outputs from the given URI and convert them to native format. 
+ + Args: + iface: The Native interface + realized_outputs_uri: The URI where outputs are stored + + Returns: + The converted native outputs + """ + outputs_file_path = io.outputs_path(realized_outputs_uri) + outputs = await io.load_outputs(outputs_file_path) + return await convert.convert_outputs_to_native(iface, outputs) + + +def unique_action_name(action_id: ActionID) -> str: + return f"{action_id.name}_{action_id.run_name}" + + +class RemoteController(BaseController): + """ + This a specialized controller that wraps the core controller and performs IO, serialization and deserialization + """ + + def __new__( + cls, + endpoint: str, + workers: int, + max_system_retries: int, + default_parent_concurrency: int = 100, + ): + return super().__new__(cls, endpoint=endpoint) + + def __init__( + self, + endpoint: str, + workers: int, + max_system_retries: int, + default_parent_concurrency: int = 100, + ): + """ """ + self._default_parent_concurrency = default_parent_concurrency + self._parent_action_semaphore: DefaultDict[str, asyncio.Semaphore] = defaultdict( + lambda: asyncio.Semaphore(default_parent_concurrency) + ) + self._parent_action_task_call_sequence: DefaultDict[str, DefaultDict[int, int]] = defaultdict( + lambda: defaultdict(int) + ) + self._submit_loop: asyncio.AbstractEventLoop | None = None + self._submit_thread: threading.Thread | None = None + + def generate_task_call_sequence(self, task_obj: object, action_id: ActionID) -> int: + """ + Generate a task call sequence for the given task object and action ID. + This is used to track the number of times a task is called within an action. 
+ """ + current_action_sequencer = self._parent_action_task_call_sequence[unique_action_name(action_id)] + current_task_id = id(task_obj) + v = current_action_sequencer[current_task_id] + new_seq = v + 1 + current_action_sequencer[current_task_id] = new_seq + return new_seq + + async def _submit(self, _task_call_seq: int, _task: TaskTemplate, *args, **kwargs) -> Any: + ctx = internal_ctx() + tctx = ctx.data.task_context + if tctx is None: + raise flyte.errors.RuntimeSystemError("BadContext", "Task context not initialized") + current_action_id = tctx.action + + # In the case of a regular code bundle, we will just pass it down as it is to the downstream tasks + # It is not allowed to change the code bundle (for regular code bundles) in the middle of a run. + code_bundle = tctx.code_bundle + + if code_bundle and code_bundle.pkl: + logger.debug(f"Building new pkl bundle for task {_task.name}") + code_bundle = await build_pkl_bundle( + _task, + upload_to_controlplane=False, + upload_from_dataplane_base_path=tctx.run_base_dir, + ) + + inputs = await convert.convert_from_native_to_inputs(_task.native_interface, *args, **kwargs) + + root_dir = Path(code_bundle.destination).absolute() if code_bundle else Path.cwd() + # Don't set output path in sec context because node executor will set it + new_serialization_context = SerializationContext( + project=current_action_id.project, + domain=current_action_id.domain, + org=current_action_id.org, + code_bundle=code_bundle, + version=tctx.version, + # supplied version. 
+ # input_path=inputs_uri, + image_cache=tctx.compiled_image_cache, + root_dir=root_dir, + ) + + task_spec = translate_task_to_wire(_task, new_serialization_context) + inputs_hash = convert.generate_inputs_hash_from_proto(inputs.proto_inputs) + sub_action_id, sub_action_output_path = convert.generate_sub_action_id_and_output_path( + tctx, task_spec, inputs_hash, _task_call_seq + ) + + serialized_inputs = inputs.proto_inputs.SerializeToString(deterministic=True) + inputs_uri = io.inputs_path(sub_action_output_path) + await upload_inputs_with_retry(serialized_inputs, inputs_uri) + + md = task_spec.task_template.metadata + ignored_input_vars = [] + if len(md.cache_ignore_input_vars) > 0: + ignored_input_vars = list(md.cache_ignore_input_vars) + cache_key = None + if task_spec.task_template.metadata and task_spec.task_template.metadata.discoverable: + discovery_version = task_spec.task_template.metadata.discovery_version + cache_key = convert.generate_cache_key_hash( + _task.name, + inputs_hash, + task_spec.task_template.interface, + discovery_version, + ignored_input_vars, + inputs.proto_inputs, + ) + + # Clear to free memory + serialized_inputs = None # type: ignore + inputs_hash = None # type: ignore + + translated_task_spec = self._translate_task_spec_new_proto(task_spec) + action = Action.from_task( + sub_action_id=self._get_action_identifier_new_proto( + action_name=sub_action_id.name, + run_name=current_action_id.run_name, + project=current_action_id.project, + domain=current_action_id.domain, + org=current_action_id.org, + ), + parent_action_name=current_action_id.name, + group_data=str(tctx.group_data) if tctx.group_data else None, + task_spec=translated_task_spec, + inputs_uri=inputs_uri, + run_output_base=tctx.run_base_dir, + cache_key=cache_key, + ) + + try: + logger.info( + f"Submitting action Run:[{action.run_name}, Parent:[{action.parent_action_name}], " + f"task:[{_task.name}], action:[{action.name}]" + ) + n = await self.submit_action(action) + 
logger.info(f"Action for task [{_task.name}] action id: {action.name}, completed!") + except asyncio.CancelledError: + # If the action is cancelled, we need to cancel the action on the server as well + logger.info(f"Action {action.action_id.name} cancelled, cancelling on server") + await self.cancel_action(action) + raise + + if n.has_error() or n.phase == 6: # failed + exc = await handle_action_failure(action, _task.name) + raise exc + + if _task.native_interface.outputs: + if not n.realized_outputs_uri: + raise flyte.errors.RuntimeSystemError( + "RuntimeError", + f"Task {n.action_id.name} did not return an output path, but the task has outputs defined.", + ) + return await load_and_convert_outputs(_task.native_interface, n.realized_outputs_uri) + return None + + async def submit(self, _task: TaskTemplate, *args, **kwargs) -> Any: + """ + Submit a task to the remote controller.This creates a new action on the queue service. + """ + ctx = internal_ctx() + tctx = ctx.data.task_context + if tctx is None: + raise flyte.errors.RuntimeSystemError("BadContext", "Task context not initialized") + current_action_id = tctx.action + task_call_seq = self.generate_task_call_sequence(_task, current_action_id) + async with self._parent_action_semaphore[unique_action_name(current_action_id)]: + return await self._submit(task_call_seq, _task, *args, **kwargs) + + def _sync_thread_loop_runner(self) -> None: + """This method runs the event loop and should be invoked in a separate thread.""" + + loop = self._submit_loop + assert loop is not None + try: + loop.run_forever() + finally: + loop.close() + + def submit_sync(self, _task: TaskTemplate, *args, **kwargs) -> concurrent.futures.Future: + """ + # todo-pr: unclear if this will work. this calls submit on another thread, which then calls submit_action + This function creates a cached thread and loop for the purpose of calling the submit method synchronously, + returning a concurrent Future that can be awaited. 
There's no need for a lock because this function itself is + single threaded and non-async. This pattern here is basically the trivial/degenerate case of the thread pool + in the LocalController. + Please see additional comments in protocol. + + :param _task: + :param args: + :param kwargs: + :return: + """ + if self._submit_thread is None: + # Please see LocalController for the general implementation of this pattern. + def exc_handler(loop, context): + logger.error(f"Remote controller submit sync loop caught exception in {loop}: {context}") + + with _selector_policy(): + self._submit_loop = asyncio.new_event_loop() + self._submit_loop.set_exception_handler(exc_handler) + + self._submit_thread = threading.Thread( + name=f"remote-controller-{os.getpid()}-submitter", daemon=True, target=self._sync_thread_loop_runner + ) + self._submit_thread.start() + + coro = self.submit(_task, *args, **kwargs) + assert self._submit_loop is not None, "Submit loop should always have been initialized by now" + fut = asyncio.run_coroutine_threadsafe(coro, self._submit_loop) + return fut + + # will be implemented in the future, should await for errors coming from the rust layer + async def watch_for_errors(self): + e = Event() + await e.wait() + + async def stop(self): + """ + Stop the controller. Incomplete, needs to gracefully shut down the rust controller as well. + """ + if self._submit_loop is not None: + self._submit_loop.stop() + if self._submit_thread is not None: + self._submit_thread.join() + self._submit_loop = None + self._submit_thread = None + logger.info("RemoteController stopped.") + + async def finalize_parent_action(self, action_id: ActionID): + """ + This method is invoked when the parent action is finished. It will finalize the run and upload the outputs + to the control plane. 
+ """ + # todo-pr: implement any cleanup + # translate the ActionID python object to something handleable in pyo3 + # await super()._finalize_parent_action(run_id=run_id, parent_action_name=action_id.name) + self._parent_action_semaphore.pop(unique_action_name(action_id), None) + self._parent_action_task_call_sequence.pop(unique_action_name(action_id), None) + + async def get_action_outputs( + self, _interface: NativeInterface, _func: Callable, *args, **kwargs + ) -> Tuple[TraceInfo, bool]: + """ + This method returns the outputs of the action, if it is available. + If not available it raises a NotFoundError. + :param _interface: NativeInterface + :param _func: Function name + :param args: Arguments + :param kwargs: Keyword arguments + :return: + """ + ctx = internal_ctx() + tctx = ctx.data.task_context + if tctx is None: + raise flyte.errors.RuntimeSystemError("BadContext", "Task context not initialized") + current_action_id = tctx.action + + func_name = _func.__name__ + invoke_seq_num = self.generate_task_call_sequence(_func, current_action_id) + inputs = await convert.convert_from_native_to_inputs(_interface, *args, **kwargs) + serialized_inputs = inputs.proto_inputs.SerializeToString(deterministic=True) + + sub_action_id, sub_action_output_path = convert.generate_sub_action_id_and_output_path( + tctx, func_name, serialized_inputs, invoke_seq_num + ) + + inputs_uri = io.inputs_path(sub_action_output_path) + await upload_inputs_with_retry(serialized_inputs, inputs_uri) + # Clear to free memory + serialized_inputs = None # type: ignore + + sub_action_id = self._get_action_identifier_new_proto( + action_name=sub_action_id.name, + run_name=current_action_id.run_name, + project=current_action_id.project, + domain=current_action_id.domain, + org=current_action_id.org, + ) + prev_action = await self.get_action( + sub_action_id, + current_action_id.name, + ) + + if prev_action is None: + return TraceInfo(sub_action_id, _interface, inputs_uri), False + + if 
prev_action.phase == 6: # failed + if prev_action.has_error(): + exc = convert.convert_error_to_native(prev_action.err) + return TraceInfo(sub_action_id, _interface, inputs_uri, error=exc), True + else: + logger.warning(f"Action {prev_action.action_id.name} failed, but no error was found, re-running trace!") + elif prev_action.realized_outputs_uri is not None: + outputs_file_path = io.outputs_path(prev_action.realized_outputs_uri) + o = await io.load_outputs(outputs_file_path) + outputs = await convert.convert_outputs_to_native(_interface, o) + return TraceInfo(sub_action_id, _interface, inputs_uri, output=outputs), True + + return TraceInfo(sub_action_id, _interface, inputs_uri), False + + @staticmethod + def _get_action_identifier_new_proto( + *, action_name: str, run_name: str, project: str, domain: str, org: str + ) -> cloudidl.workflow.ActionIdentifier: + return cloudidl.workflow.ActionIdentifier( + name=action_name, + run=cloudidl.workflow.RunIdentifier( + name=run_name, + project=project, + domain=domain, + org=org, + ), + ) + + @staticmethod + def _translate_task_spec_new_proto(task_spec: task_definition_pb2.TaskSpec) -> cloudidl.workflow.TaskSpec: + task_spec_bytes = task_spec.SerializeToString() + new_task_spec = cloudidl.workflow.TaskSpec.decode(task_spec_bytes) + return new_task_spec + + async def record_trace(self, info: TraceInfo): + """ + Record a trace action. 
This is used to record the trace of the action and should be called when the action + :param info: + :return: + """ + ctx = internal_ctx() + tctx = ctx.data.task_context + if tctx is None: + raise flyte.errors.RuntimeSystemError("BadContext", "Task context not initialized") + + current_action_id = tctx.action + current_output_path = tctx.output_path + sub_run_output_path = storage.join(current_output_path, info.action.name) + + if info.interface.has_outputs(): + outputs_file_path: str = "" + if info.output: + outputs = await convert.convert_from_native_to_outputs(info.output, info.interface) + outputs_file_path = io.outputs_path(sub_run_output_path) + await io.upload_outputs(outputs, outputs_file_path) + elif info.error: + err = convert.convert_from_native_to_error(info.error) + error_path = io.error_path(sub_run_output_path) + await io.upload_error(err.err, error_path) + else: + raise flyte.errors.RuntimeSystemError("BadTraceInfo", "Trace info does not have output or error") + + action_id = self._get_action_identifier_new_proto( + action_name=info.action.name, + run_name=current_action_id.run_name, + project=current_action_id.project, + domain=current_action_id.domain, + org=current_action_id.org, + ) + trace_action = Action.from_trace( + parent_action_name=current_action_id.name, + action_id=action_id, + inputs_uri=info.inputs_path, + outputs_uri=outputs_file_path, + friendly_name=info.name, + group_data=str(tctx.group_data) if tctx.group_data else None, + ) + try: + logger.info( + f"Submitting Trace action Run:[{trace_action.run_name}, Parent:[{trace_action.parent_action_name}]," + f" Trace fn:[{info.name}], action:[{info.action.name}]" + ) + await self.submit_action(trace_action) + logger.info(f"Trace Action for [{info.name}] action id: {info.action.name}, completed!") + except asyncio.CancelledError: + # If the action is cancelled, we need to cancel the action on the server as well + raise + + async def submit_task_ref(self, _task: 
task_definition_pb2.TaskDetails, *args, **kwargs) -> Any: + ctx = internal_ctx() + tctx = ctx.data.task_context + if tctx is None: + raise flyte.errors.RuntimeSystemError("BadContext", "Task context not initialized") + current_action_id = tctx.action + task_name = _task.spec.task_template.id.name + + invoke_seq_num = self.generate_task_call_sequence(_task, current_action_id) + + native_interface = types.guess_interface( + _task.spec.task_template.interface, default_inputs=_task.spec.default_inputs + ) + inputs = await convert.convert_from_native_to_inputs(native_interface, *args, **kwargs) + inputs_hash = convert.generate_inputs_hash_from_proto(inputs.proto_inputs) + sub_action_id, sub_action_output_path = convert.generate_sub_action_id_and_output_path( + tctx, task_name, inputs_hash, invoke_seq_num + ) + + serialized_inputs = inputs.proto_inputs.SerializeToString(deterministic=True) + inputs_uri = io.inputs_path(sub_action_output_path) + await upload_inputs_with_retry(serialized_inputs, inputs_uri) + # cache key - task name, task signature, inputs, cache version + cache_key = None + md = _task.spec.task_template.metadata + ignored_input_vars = [] + if len(md.cache_ignore_input_vars) > 0: + ignored_input_vars = list(md.cache_ignore_input_vars) + if _task.spec.task_template.metadata and _task.spec.task_template.metadata.discoverable: + discovery_version = _task.spec.task_template.metadata.discovery_version + cache_key = convert.generate_cache_key_hash( + task_name, + inputs_hash, + _task.spec.task_template.interface, + discovery_version, + ignored_input_vars, + inputs.proto_inputs, + ) + + # Clear to free memory + serialized_inputs = None # type: ignore + inputs_hash = None # type: ignore + + sub_action_id = self._get_action_identifier_new_proto( + action_name=sub_action_id.name, + run_name=current_action_id.run_name, + project=current_action_id.project, + domain=current_action_id.domain, + org=current_action_id.org, + ) + translated_task_spec = 
self._translate_task_spec_new_proto(_task.spec) + action = Action.from_task( + sub_action_id=sub_action_id, + parent_action_name=current_action_id.name, + group_data=str(tctx.group_data) if tctx.group_data else None, + task_spec=translated_task_spec, + inputs_uri=inputs_uri, + run_output_base=tctx.run_base_dir, + cache_key=cache_key, + ) + + try: + logger.info( + f"Submitting action Run:[{action.run_name}, Parent:[{action.parent_action_name}], " + f"task:[{task_name}], action:[{action.name}]" + ) + n = await self.submit_action(action) + logger.info(f"Action for task [{task_name}] action id: {action.name}, completed!") + except asyncio.CancelledError: + # If the action is cancelled, we need to cancel the action on the server as well + logger.info(f"Action {action.action_id.name} cancelled, cancelling on server") + await self.cancel_action(action) + raise + + if n.has_error() or n.phase == 6: # failed + exc = await handle_action_failure(action, task_name) + raise exc + + if native_interface.outputs: + if not n.realized_outputs_uri: + raise flyte.errors.RuntimeSystemError( + "RuntimeError", + f"Task {n.action_id.name} did not return an output path, but the task has outputs defined.", + ) + return await load_and_convert_outputs(native_interface, n.realized_outputs_uri) + return None diff --git a/src/flyte/_internal/runtime/convert.py b/src/flyte/_internal/runtime/convert.py index fbc98a0ac..aa0705a5a 100644 --- a/src/flyte/_internal/runtime/convert.py +++ b/src/flyte/_internal/runtime/convert.py @@ -234,7 +234,12 @@ async def convert_outputs_to_native(interface: NativeInterface, outputs: Outputs return tuple(kwargs[k] for k in interface.outputs.keys()) -def convert_error_to_native(err: execution_pb2.ExecutionError | Exception | Error) -> Exception | None: +from flyte_controller_base import cloudidl + + +def convert_error_to_native( + err: execution_pb2.ExecutionError | Exception | Error | cloudidl.workflow.ExecutionError, +) -> Exception | None: if not err: return None 
diff --git a/src/flyte/_run.py b/src/flyte/_run.py index 7391207c3..b4c2178f3 100644 --- a/src/flyte/_run.py +++ b/src/flyte/_run.py @@ -422,7 +422,8 @@ async def _run_hybrid(self, obj: TaskTemplate[P, R, F], *args: P.args, **kwargs: run_name = self._name random_id = str(uuid.uuid4())[:6] - controller = create_controller("remote", endpoint="localhost:8090", insecure=True) + # controller = create_controller("remote", endpoint="localhost:8090", insecure=True) + controller = create_controller("rust", endpoint="localhost:8090", insecure=True) action = ActionID(name=action_name, run_name=run_name, project=project, domain=domain, org=org) inputs = obj.native_interface.convert_to_kwargs(*args, **kwargs) @@ -452,7 +453,7 @@ async def _run_task() -> Tuple[Any, Optional[Exception]]: checkpoints=checkpoints, code_bundle=code_bundle, output_path=output_path, - version=version if version else "na", + version=version, # if version else "na", raw_data_path=raw_data_path_obj, compiled_image_cache=image_cache, run_base_dir=run_base_dir, diff --git a/src/flyte/models.py b/src/flyte/models.py index 33316ee96..5a8a6ec38 100644 --- a/src/flyte/models.py +++ b/src/flyte/models.py @@ -179,12 +179,6 @@ def get_random_remote_path(self, file_name: Optional[str] = None) -> str: return remote_path -@rich.repr.auto -@dataclass(frozen=True) -class GroupData: - name: str - - @rich.repr.auto @dataclass(frozen=True, kw_only=True) class TaskContext: From e34e7684fc87bf3e3578922b8a64f837e4c6c037 Mon Sep 17 00:00:00 2001 From: Yee Hing Tong Date: Thu, 27 Nov 2025 23:16:29 -0800 Subject: [PATCH 03/22] remove cloudidl and get it compiling Signed-off-by: Yee Hing Tong --- README.md | 9 +- rs_controller/Cargo.lock | 186 +++++++++--------- rs_controller/Cargo.toml | 12 +- rs_controller/Makefile | 6 + rs_controller/pyproject.toml | 3 +- rs_controller/src/action.rs | 16 +- rs_controller/src/informer.rs | 49 +++-- rs_controller/src/lib.rs | 26 +-- rs_controller/temp_cargo_config.toml | 2 - 
.../_internal/controllers/remote/_action.py | 1 + .../controllers/remote/_controller.py | 12 +- 11 files changed, 169 insertions(+), 153 deletions(-) delete mode 100644 rs_controller/temp_cargo_config.toml diff --git a/README.md b/README.md index c7df35c78..f8b0a7589 100644 --- a/README.md +++ b/README.md @@ -316,6 +316,13 @@ Flyte 2 is licensed under the [Apache 2.0 License](LICENSE). ## Developing the Core Controller +Create a separate virtual environment for the Rust contoller inside the rs_controller folder. The reason for this is +because the rust controller should be a separate pypi package. The reason it should be a separate pypi package is that +including it into the main SDK as a core component means the entire build toolchain for the SDK will need to become +rust/maturin based. We should probably move to this model in the future though. + +Keep important dependencies the same though, namely flyteidl2. + The following instructions are for helping to build the default multi-arch image. Each architecture needs a different wheel. Each wheel needs to be built by a different docker image. ### Setup Builders @@ -330,4 +337,4 @@ Then run `make build-wheels`. 
```bash make dist python maint_tools/build_default_image.py -``` \ No newline at end of file +``` diff --git a/rs_controller/Cargo.lock b/rs_controller/Cargo.lock index b43b6f4e3..22ec3e40e 100644 --- a/rs_controller/Cargo.lock +++ b/rs_controller/Cargo.lock @@ -19,9 +19,9 @@ checksum = "512761e0bb2578dd7380c6baaa0f4ce03e84f95e960231d1dec8bf4d7d6e2627" [[package]] name = "aho-corasick" -version = "1.1.3" +version = "1.1.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8e60d3430d3a69478ad0993f19238d2df97c507009a52b3c10addcd7f6bcb916" +checksum = "ddd31a130427c27518df266943a5308ed92d4b226cc639f5a8f1002816174301" dependencies = [ "memchr", ] @@ -165,10 +165,11 @@ checksum = "d71b6127be86fdcfddb610f7182ac57211d4b18a3e9c82eb2d17662f2227ad6a" [[package]] name = "cc" -version = "1.2.29" +version = "1.2.47" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5c1599538de2394445747c8cf7935946e3cc27e9625f889d979bfb2aaf569362" +checksum = "cd405d82c84ff7f35739f175f67d8b9fb7687a0e84ccdc78bd3568839827cf07" dependencies = [ + "find-msvc-tools", "shlex", ] @@ -180,42 +181,13 @@ checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" [[package]] name = "chrono" -version = "0.4.41" +version = "0.4.42" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c469d952047f47f91b68d1cba3f10d63c11d73e4636f24f08daf0278abf01c4d" +checksum = "145052bdd345b87320e369255277e3fb5152762ad123a901ef5c262dd38fe8d2" dependencies = [ "num-traits", ] -[[package]] -name = "cloudidl" -version = "0.1.0" -dependencies = [ - "async-trait", - "futures", - "pbjson", - "pbjson-build", - "pbjson-types", - "prettyplease", - "prost 0.13.5", - "prost-build 0.14.1", - "prost-types 0.12.6", - "protobuf", - "pyo3", - "pyo3-async-runtimes", - "quote", - "regex", - "serde", - "syn", - "thiserror", - "tokio", - "tonic", - "tower 0.4.13", - "tower-http", - "tracing", - "tracing-subscriber", -] - [[package]] name = 
"either" version = "1.15.0" @@ -230,9 +202,9 @@ checksum = "877a4ace8713b0bcf2a4e7eec82529c029f1d0619886d18145fea96c3ffe5c0f" [[package]] name = "errno" -version = "0.3.13" +version = "0.3.14" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "778e2ac28f6c47af28e4907f13ffd1e1ddbd400980a9abd7c8df189bf578a5ad" +checksum = "39cab71617ae0d63f51a36d69f866391735b51691dbda63cf6f96d042b63efeb" dependencies = [ "libc", "windows-sys", @@ -244,12 +216,70 @@ version = "2.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "37909eebbb50d72f9059c3b6d82c0463f2ff062c9e95845c43a6c9c0355411be" +[[package]] +name = "find-msvc-tools" +version = "0.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3a3076410a55c90011c298b04d0cfa770b00fa04e1e3c97d3f6c9de105a03844" + [[package]] name = "fixedbitset" version = "0.5.7" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1d674e81391d1e1ab681a28d99df07927c6d4aa5b027d7da16ba32d1d21ecd99" +[[package]] +name = "flyte_core" +version = "0.1.0" +dependencies = [ + "async-trait", + "flyteidl2", + "futures", + "prost 0.13.5", + "prost-types 0.12.6", + "pyo3", + "pyo3-async-runtimes", + "pyo3-build-config", + "thiserror", + "tokio", + "tonic", + "tower 0.4.13", + "tower-http", + "tracing", + "tracing-subscriber", +] + +[[package]] +name = "flyteidl2" +version = "2.0.0-alpha9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "841c8fd5b71ff4cfbb2c70135aeaf51d02c847d853d5b333a2e4cd512d9a863c" +dependencies = [ + "async-trait", + "futures", + "pbjson", + "pbjson-build", + "pbjson-types", + "prettyplease", + "prost 0.13.5", + "prost-build 0.14.1", + "prost-types 0.12.6", + "protobuf", + "pyo3", + "pyo3-async-runtimes", + "quote", + "regex", + "serde", + "syn", + "thiserror", + "tokio", + "tonic", + "tower 0.4.13", + "tower-http", + "tracing", + "tracing-subscriber", +] + [[package]] name = "fnv" version = "1.0.7" 
@@ -353,19 +383,19 @@ checksum = "335ff9f135e4384c8150d6f27c6daed433577f86b4750418338c01a1a2528592" dependencies = [ "cfg-if", "libc", - "wasi 0.11.0+wasi-snapshot-preview1", + "wasi", ] [[package]] name = "getrandom" -version = "0.3.3" +version = "0.3.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "26145e563e54f2cadc477553f1ec5ee650b00862f0a58bcd12cbdc5f0ea2d2f4" +checksum = "899def5c37c4fd7b2664648c28120ecec138e4d395b459e5ca34f9cce2dd77fd" dependencies = [ "cfg-if", "libc", "r-efi", - "wasi 0.14.2+wasi-0.2.4", + "wasip2", ] [[package]] @@ -576,9 +606,9 @@ checksum = "d750af042f7ef4f724306de029d18836c26c1765a54a6a3f094cbd23a7267ffa" [[package]] name = "linux-raw-sys" -version = "0.9.4" +version = "0.11.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cd945864f07fe9f5371a27ad7b52a172b4b499999f1d97574c9fa68373937e12" +checksum = "df1d3c3b53da64cf5760482273a98e575c651a67eec7f77df96b5b642de8f039" [[package]] name = "lock_api" @@ -639,7 +669,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2886843bf800fba2e3377cff24abf6379b4c4d5c6681eaf9ea5b0d15090450bd" dependencies = [ "libc", - "wasi 0.11.0+wasi-snapshot-preview1", + "wasi", "windows-sys", ] @@ -820,9 +850,9 @@ dependencies = [ [[package]] name = "prettyplease" -version = "0.2.35" +version = "0.2.37" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "061c1221631e079b26479d25bbf2275bfe5917ae8419cd7e34f13bfc2aa7539a" +checksum = "479ca8adacdd7ce8f1fb39ce9ecccbfe93a3f1344b3d0d97f20bc0196208f62b" dependencies = [ "proc-macro2", "syn", @@ -1115,9 +1145,9 @@ dependencies = [ [[package]] name = "regex" -version = "1.11.1" +version = "1.12.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b544ef1b4eac5dc2db33ea63606ae9ffcfac26c1416a2806ae0bf5f56b201191" +checksum = "843bc0191f75f3e22651ae5f1e72939ab2f72a4bc30fa80a066bd66edefc24d4" dependencies = [ "aho-corasick", 
"memchr", @@ -1127,9 +1157,9 @@ dependencies = [ [[package]] name = "regex-automata" -version = "0.4.9" +version = "0.4.13" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "809e8dc61f6de73b46c85f4c96486310fe304c434cfa43669d7b40f711150908" +checksum = "5276caf25ac86c8d810222b3dbb938e512c55c6831a10f3e6ed1c93b84041f1c" dependencies = [ "aho-corasick", "memchr", @@ -1138,9 +1168,9 @@ dependencies = [ [[package]] name = "regex-syntax" -version = "0.8.5" +version = "0.8.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2b15c43186be67a4fd63bee50d0303afffcef381492ebe2c5d87f324e1b8815c" +checksum = "7a2d987857b319362043e95f5353c0535c1f58eec5336fdfcf626430af7def58" [[package]] name = "rustc-demangle" @@ -1150,9 +1180,9 @@ checksum = "719b953e2095829ee67db738b3bfa9fa368c94900df327b3f07fe6e794d2fe1f" [[package]] name = "rustix" -version = "1.0.7" +version = "1.1.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c71e83d6afe7ff64890ec6b71d6a69bb8a610ab78ce364b3352876bb4c801266" +checksum = "cd15f8a2c5551a84d56efdc1cd049089e409ac19a3072d5037a17fd70719ff3e" dependencies = [ "bitflags", "errno", @@ -1244,9 +1274,9 @@ dependencies = [ [[package]] name = "syn" -version = "2.0.104" +version = "2.0.111" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "17b6f705963418cdb9927482fa304bc562ece2fdd4f616084c50b7023b435a40" +checksum = "390cc9a294ab71bdb1aa2e99d13be9c753cd2d7bd6560c77118597410c4d2e87" dependencies = [ "proc-macro2", "quote", @@ -1267,12 +1297,12 @@ checksum = "e502f78cdbb8ba4718f566c418c52bc729126ffd16baee5baa718cf25dd5a69a" [[package]] name = "tempfile" -version = "3.20.0" +version = "3.23.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e8a64e3985349f2441a1a9ef0b853f869006c3855f2cda6862a94d26ebb9d6a1" +checksum = "2d31c77bdf42a745371d260a26ca7163f1e0924b64afa0b688e61b5a9fa02f16" dependencies = [ "fastrand", - "getrandom 
0.3.3", + "getrandom 0.3.4", "once_cell", "rustix", "windows-sys", @@ -1530,27 +1560,6 @@ version = "0.2.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7264e107f553ccae879d21fbea1d6724ac785e8c3bfc762137959b5802826ef3" -[[package]] -name = "union_rust_controller" -version = "0.1.0" -dependencies = [ - "async-trait", - "cloudidl", - "futures", - "prost 0.13.5", - "prost-types 0.12.6", - "pyo3", - "pyo3-async-runtimes", - "pyo3-build-config", - "thiserror", - "tokio", - "tonic", - "tower 0.4.13", - "tower-http", - "tracing", - "tracing-subscriber", -] - [[package]] name = "valuable" version = "0.1.1" @@ -1573,12 +1582,12 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9c8d87e72b64a3b4db28d11ce29237c246188f4f51057d65a7eab63b7987e423" [[package]] -name = "wasi" -version = "0.14.2+wasi-0.2.4" +name = "wasip2" +version = "1.0.1+wasi-0.2.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9683f9a5a998d873c0d21fcbe3c083009670149a8fab228644b8bd36b2c48cb3" +checksum = "0562428422c63773dad2c345a1882263bbf4d65cf3f42e90921f787ef5ad58e7" dependencies = [ - "wit-bindgen-rt", + "wit-bindgen", ] [[package]] @@ -1677,13 +1686,10 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "589f6da84c646204747d1270a2a5661ea66ed1cced2631d546fdfb155959f9ec" [[package]] -name = "wit-bindgen-rt" -version = "0.39.0" +name = "wit-bindgen" +version = "0.46.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6f42320e61fe2cfd34354ecb597f86f413484a798ba44a8ca1165c58d42da6c1" -dependencies = [ - "bitflags", -] +checksum = "f17a85883d4e6d00e8a97c586de764dabcc06133f7f1d55dce5cdc070ad7fe59" [[package]] name = "zerocopy" diff --git a/rs_controller/Cargo.toml b/rs_controller/Cargo.toml index 08c250c15..8120bc422 100644 --- a/rs_controller/Cargo.toml +++ b/rs_controller/Cargo.toml @@ -1,15 +1,15 @@ [package] -name = "union_rust_controller" +name = "flyte_core" 
version = "0.1.0" edition = "2021" [lib] -name = "union_rust_controller" +name = "flyte_controller_base" path = "src/lib.rs" crate-type = ["cdylib"] [dependencies] -cloudidl = { path = "../../cloud/gen/pb_rust" } +# cloudidl = { path = "../../cloud/gen/pb_rust" } pyo3 = { version = "0.24", features = ["extension-module", "abi3-py310"] } pyo3-async-runtimes = { version = "0.24", features = ["tokio-runtime"] } tokio = { version = "1.0", features = ["full"] } @@ -24,12 +24,10 @@ tracing-subscriber = "0.3" async-trait = "0.1" thiserror = "1.0" pyo3-build-config = "0.24.2" +flyteidl2 = "2.0.0-alpha14" [build-dependencies] -pyo3 = { version = "0.24", features = [ - "extension-module", - "abi3-py310", -] } # Use your pyo3 version +pyo3 = { version = "0.24", features = ["extension-module", "abi3-py310"] } # Need to do this for some reason otherwise maturin develop fails horribly # export RUSTFLAGS="-C link-arg=-undefined -C link-arg=dynamic_lookup" diff --git a/rs_controller/Makefile b/rs_controller/Makefile index 95839d487..f1c5d219e 100644 --- a/rs_controller/Makefile +++ b/rs_controller/Makefile @@ -41,3 +41,9 @@ build-wheels-amd64: dist-dirs $(call BUILD_WHEELS_RECIPE,amd64) build-wheels: build-wheels-arm64 build-wheels-amd64 + + +# This is for Mac users, since the other targets won't build macos wheels (only local arch so probably arm64) +build-wheel-local: dist-dirs + python -m build --wheel --outdir dist + diff --git a/rs_controller/pyproject.toml b/rs_controller/pyproject.toml index 5431a94d5..1e768aa67 100644 --- a/rs_controller/pyproject.toml +++ b/rs_controller/pyproject.toml @@ -3,7 +3,7 @@ requires = ["maturin>=1.4,<2.0"] build-backend = "maturin" [project] -name = "rust_controller" +name = "flyte_controller_base" version = "0.1.0" description = "Rust controller for Union" requires-python = ">=3.10" @@ -12,4 +12,3 @@ classifiers = ["Programming Language :: Python", "Programming Language :: Rust"] [tool.maturin] module-name = "flyte_controller_base" features = 
["pyo3/extension-module"] - diff --git a/rs_controller/src/action.rs b/rs_controller/src/action.rs index 0abd52792..29bf03572 100644 --- a/rs_controller/src/action.rs +++ b/rs_controller/src/action.rs @@ -1,9 +1,12 @@ use pyo3::prelude::*; -use cloudidl::{ - cloudidl::workflow::{ActionIdentifier, ActionUpdate, Phase, TaskSpec}, - flyteidl::core::ExecutionError, -}; +use flyteidl2::flyteidl::common::ActionIdentifier; +use flyteidl2::flyteidl::workflow::ActionUpdate; + +use flyteidl2::flyteidl::core::ExecutionError; +use flyteidl2::flyteidl::task::TaskSpec; +use flyteidl2::flyteidl::workflow::Phase; + use tracing::debug; #[pyclass(eq, eq_int)] @@ -31,6 +34,7 @@ pub struct Action { pub retries: u32, pub client_err: Option, // Changed from PyErr to String for serializability pub cache_key: Option, + pub queue: Option, } impl Action { @@ -99,6 +103,7 @@ impl Action { retries: 0, client_err: None, cache_key: None, + queue: None, } } @@ -139,6 +144,7 @@ impl Action { inputs_uri: String, run_output_base: String, cache_key: Option, + queue: Option, ) -> Self { debug!("Creating Action from task for ID {:?}", &sub_action_id); Action { @@ -160,6 +166,7 @@ impl Action { retries: 0, client_err: None, cache_key, + queue, } } @@ -190,6 +197,7 @@ impl Action { retries: 0, client_err: None, cache_key: None, + queue: None, } } diff --git a/rs_controller/src/informer.rs b/rs_controller/src/informer.rs index a09b83ee1..d385b3009 100644 --- a/rs_controller/src/informer.rs +++ b/rs_controller/src/informer.rs @@ -1,9 +1,13 @@ use crate::action::Action; use crate::ControllerError; -use cloudidl::cloudidl::workflow::state_service_client::StateServiceClient; -use cloudidl::cloudidl::workflow::watch_request; -use cloudidl::cloudidl::workflow::watch_response::Message; -use cloudidl::cloudidl::workflow::{ActionIdentifier, RunIdentifier, WatchRequest, WatchResponse}; + +use flyteidl2::flyteidl::common::ActionIdentifier; +use flyteidl2::flyteidl::common::RunIdentifier; +use 
flyteidl2::flyteidl::workflow::state_service_client::StateServiceClient; +use flyteidl2::flyteidl::workflow::{ + watch_request, watch_response::Message, WatchRequest, WatchResponse, +}; + use std::collections::HashMap; use std::sync::Arc; use std::time::Duration; @@ -132,11 +136,7 @@ impl Informer { filter: Some(watch_request::Filter::ParentActionId(action_id)), }; - let mut stream = self - .client - .clone() - .watch(request) - .await; + let mut stream = self.client.clone().watch(request).await; let mut stream = match stream { Ok(s) => s.into_inner(), @@ -151,22 +151,22 @@ impl Informer { Ok(Some(response)) => { let handle_response = self.handle_watch_response(response).await; match handle_response { - Ok(Some(action)) => { - match self.shared_queue.send(action).await { - Ok(_) => { - continue; - } - Err(e) => { - error!("Informer watch failed sending action back to shared queue: {:?}", e); - return ControllerError::RuntimeError(format!( - "Failed to send action to shared queue: {}", - e - )); - } + Ok(Some(action)) => match self.shared_queue.send(action).await { + Ok(_) => { + continue; } - } + Err(e) => { + error!("Informer watch failed sending action back to shared queue: {:?}", e); + return ControllerError::RuntimeError(format!( + "Failed to send action to shared queue: {}", + e + )); + } + }, Ok(None) => { - debug!("Received None from handle_watch_response, continuing watch loop."); + debug!( + "Received None from handle_watch_response, continuing watch loop." + ); } Err(err) => { // this should cascade up to the controller to restart the informer, and if there @@ -178,13 +178,12 @@ impl Informer { } Ok(None) => { debug!("Stream received empty message, maybe no more messages? 
Repeating watch loop."); - }, // Stream ended, exit loop + } // Stream ended, exit loop Err(e) => { error!("Error receiving message from stream: {:?}", e); return ControllerError::from(e); } } - } } diff --git a/rs_controller/src/lib.rs b/rs_controller/src/lib.rs index ca1a34eea..433f73d5d 100644 --- a/rs_controller/src/lib.rs +++ b/rs_controller/src/lib.rs @@ -6,7 +6,7 @@ use std::time::Duration; use futures::TryFutureExt; use pyo3::prelude::*; -use tokio::sync::{mpsc, Mutex, Notify}; +use tokio::sync::mpsc; use tracing::{debug, error, info, warn}; use thiserror::Error; @@ -14,13 +14,14 @@ use thiserror::Error; use crate::action::{Action, ActionType}; use crate::informer::Informer; -use cloudidl::cloudidl::workflow::queue_service_client::QueueServiceClient; -use cloudidl::cloudidl::workflow::state_service_client::StateServiceClient; -use cloudidl::cloudidl::workflow::{ - enqueue_action_request, ActionIdentifier, EnqueueActionRequest, EnqueueActionResponse, - RunIdentifier, TaskAction, TaskIdentifier, -}; -use cloudidl::google; +use flyteidl2::flyteidl::common::ActionIdentifier; +use flyteidl2::flyteidl::task::TaskIdentifier; +use flyteidl2::flyteidl::workflow::state_service_client::StateServiceClient; +use flyteidl2::flyteidl::workflow::{EnqueueActionRequest, EnqueueActionResponse, TaskAction}; + +use flyteidl2::flyteidl::workflow::enqueue_action_request; +use flyteidl2::flyteidl::workflow::queue_service_client::QueueServiceClient; +use flyteidl2::google; use google::protobuf::StringValue; use pyo3::exceptions; use pyo3::types::PyAny; @@ -80,7 +81,10 @@ impl CoreBaseController { let (state_client, queue_client) = rt.block_on(async { // Need to update to with auth to read API key let endpoint = Endpoint::from_static(&endpoint_static); - let channel = endpoint.connect().await.map_err(|e| ControllerError::from(e))?; + let channel = endpoint + .connect() + .await + .map_err(|e| ControllerError::from(e))?; Ok::<_, ControllerError>(( 
StateServiceClient::new(channel.clone()), QueueServiceClient::new(channel), @@ -325,6 +329,7 @@ impl CoreBaseController { .cache_key .as_ref() .map(|ck| StringValue { value: ck.clone() }), + cluster: action.queue.clone().unwrap_or("".to_string()), }; Ok(EnqueueActionRequest { @@ -487,7 +492,7 @@ impl BaseController { } } -use cloudidl::pymodules::cloud_mod; +// use cloudidl::pymodules::cloud_mod; #[pymodule] fn flyte_controller_base(py: Python<'_>, m: &Bound<'_, PyModule>) -> PyResult<()> { @@ -503,6 +508,5 @@ fn flyte_controller_base(py: Python<'_>, m: &Bound<'_, PyModule>) -> PyResult<() m.add_class::()?; m.add_class::()?; m.add_class::()?; - cloud_mod(py, m)?; Ok(()) } diff --git a/rs_controller/temp_cargo_config.toml b/rs_controller/temp_cargo_config.toml deleted file mode 100644 index 1ea088bc1..000000000 --- a/rs_controller/temp_cargo_config.toml +++ /dev/null @@ -1,2 +0,0 @@ -[patch.crates-io] -cloudidl = { path = "/cloud/gen/pb_rust" } diff --git a/src/flyte/_internal/controllers/remote/_action.py b/src/flyte/_internal/controllers/remote/_action.py index 5a1518161..d5cd4a5d3 100644 --- a/src/flyte/_internal/controllers/remote/_action.py +++ b/src/flyte/_internal/controllers/remote/_action.py @@ -17,6 +17,7 @@ ActionType = Literal["task", "trace"] +# This class should be deleted following move to pyo3. 
@dataclass class Action: """ diff --git a/src/flyte/_internal/controllers/remote/_controller.py b/src/flyte/_internal/controllers/remote/_controller.py index bafb5e5e4..987c3caeb 100644 --- a/src/flyte/_internal/controllers/remote/_controller.py +++ b/src/flyte/_internal/controllers/remote/_controller.py @@ -19,7 +19,7 @@ from flyte._context import internal_ctx from flyte._internal.controllers import TraceInfo from flyte._internal.controllers.remote._action import Action -# from flyte._internal.controllers.remote._core import Controller +from flyte._internal.controllers.remote._core import Controller from flyte._internal.controllers.remote._service_protocol import ClientSet from flyte._internal.runtime import convert, io from flyte._internal.runtime.task_serde import translate_task_to_wire @@ -30,16 +30,6 @@ from flyte.models import MAX_INLINE_IO_BYTES, ActionID, NativeInterface, SerializationContext from flyte.remote._task import TaskDetails -# Import the Rust Controller instead of the Python one -# try: -# from flyte._internal.controllers.remote.rust_controller import Controller -# -# except ImportError: -# # Fallback to Python implementation during development -# from flyte._internal.controllers.remote._core import Controller - -from flyte._internal.controllers.remote.rust_controller import Controller - R = TypeVar("R") MAX_TRACE_BYTES = MAX_INLINE_IO_BYTES From 98a2414bd9486451f3f7ab697111c81104ccbfc5 Mon Sep 17 00:00:00 2001 From: Yee Hing Tong Date: Thu, 27 Nov 2025 23:31:11 -0800 Subject: [PATCH 04/22] wheels Signed-off-by: Yee Hing Tong --- rs_controller/Makefile | 1 - 1 file changed, 1 deletion(-) diff --git a/rs_controller/Makefile b/rs_controller/Makefile index f1c5d219e..5ef3427a9 100644 --- a/rs_controller/Makefile +++ b/rs_controller/Makefile @@ -25,7 +25,6 @@ docker run --rm \ -v $(PWD):/io \ -v $(PWD)/docker_cargo_cache:/root/.cargo/registry \ -v $(CLOUD_REPO):/cloud \ - -v $(PWD)/temp_cargo_config.toml:/io/.cargo/config.toml \ wheel-builder:$(1) 
/bin/bash -c "\ cd /io; \ /opt/python/cp310-cp310/bin/maturin build --release --find-interpreter --out /io/dist/ \ From d3af2f90db38fcf5eac1f0fd030b875a2d81d77e Mon Sep 17 00:00:00 2001 From: Yee Hing Tong Date: Sun, 30 Nov 2025 19:10:22 -0800 Subject: [PATCH 05/22] add basic trace support to action, untested Signed-off-by: Yee Hing Tong --- rs_controller/src/action.rs | 54 +++++++++++++++++++++++++++++++++---- rs_controller/src/lib.rs | 11 +++----- 2 files changed, 53 insertions(+), 12 deletions(-) diff --git a/rs_controller/src/action.rs b/rs_controller/src/action.rs index 29bf03572..0ef735bf0 100644 --- a/rs_controller/src/action.rs +++ b/rs_controller/src/action.rs @@ -1,10 +1,11 @@ +use flyteidl2::google::protobuf::Timestamp; use pyo3::prelude::*; use flyteidl2::flyteidl::common::ActionIdentifier; -use flyteidl2::flyteidl::workflow::ActionUpdate; +use flyteidl2::flyteidl::workflow::{ActionUpdate, TraceAction}; -use flyteidl2::flyteidl::core::ExecutionError; -use flyteidl2::flyteidl::task::TaskSpec; +use flyteidl2::flyteidl::core::{ExecutionError, TypedInterface}; +use flyteidl2::flyteidl::task::{OutputReferences, TaskSpec, TraceSpec}; use flyteidl2::flyteidl::workflow::Phase; use tracing::debug; @@ -35,6 +36,7 @@ pub struct Action { pub client_err: Option, // Changed from PyErr to String for serializability pub cache_key: Option, pub queue: Option, + pub trace: Option, } impl Action { @@ -104,6 +106,7 @@ impl Action { client_err: None, cache_key: None, queue: None, + trace: None, } } @@ -167,6 +170,7 @@ impl Action { client_err: None, cache_key, queue, + trace: None, } } @@ -179,8 +183,47 @@ impl Action { group_data: Option, inputs_uri: String, outputs_uri: String, + start_time: f64, // Unix timestamp in seconds with fractional seconds + end_time: f64, // Unix timestamp in seconds with fractional seconds + run_output_base: String, + report_uri: Option, + typed_interface: Option, ) -> Self { debug!("Creating Action from trace for ID {:?}", &action_id); + let 
trace_spec = TraceSpec { + interface: typed_interface, + }; + let start_secs = start_time.floor() as i64; + let start_nanos = ((start_time - start_time.floor()) * 1e9) as i32; + + let end_secs = end_time.floor() as i64; + let end_nanos = ((end_time - end_time.floor()) * 1e9) as i32; + + // TraceAction expects an optional OutputReferences - let's only include it if there's something to include + let output_references = if report_uri.is_some() || !outputs_uri.is_empty() { + Some(OutputReferences { + output_uri: outputs_uri.clone(), + report_uri: report_uri.clone().unwrap_or("".to_string()), + }) + } else { + None + }; + + let trace_action = TraceAction { + name: friendly_name.clone(), + phase: Phase::Succeeded.into(), + start_time: Some(Timestamp { + seconds: start_secs, + nanos: start_nanos, + }), + end_time: Some(Timestamp { + seconds: end_secs, + nanos: end_nanos, + }), + outputs: output_references, + spec: Some(trace_spec), + }; + Action { action_id, parent_action_name, @@ -189,15 +232,16 @@ impl Action { group: group_data, task: None, inputs_uri: Some(inputs_uri), - run_output_base: None, + run_output_base: Some(run_output_base), realized_outputs_uri: Some(outputs_uri), + phase: Phase::Succeeded.into(), err: None, - phase: Some(Phase::Succeeded), started: true, retries: 0, client_err: None, cache_key: None, queue: None, + trace: Some(trace_action), } } diff --git a/rs_controller/src/lib.rs b/rs_controller/src/lib.rs index 433f73d5d..a0404fe96 100644 --- a/rs_controller/src/lib.rs +++ b/rs_controller/src/lib.rs @@ -1,3 +1,5 @@ +#![allow(clippy::too_many_arguments)] + mod action; mod informer; @@ -61,7 +63,6 @@ impl From for PyErr { } struct CoreBaseController { - endpoint: String, state_client: StateServiceClient, queue_client: QueueServiceClient, informer: OnceCell>, @@ -80,11 +81,8 @@ impl CoreBaseController { let rt = get_runtime(); let (state_client, queue_client) = rt.block_on(async { // Need to update to with auth to read API key - let endpoint = 
Endpoint::from_static(&endpoint_static); - let channel = endpoint - .connect() - .await - .map_err(|e| ControllerError::from(e))?; + let endpoint = Endpoint::from_static(endpoint_static); + let channel = endpoint.connect().await.map_err(ControllerError::from)?; Ok::<_, ControllerError>(( StateServiceClient::new(channel.clone()), QueueServiceClient::new(channel), @@ -92,7 +90,6 @@ impl CoreBaseController { })?; let real_base_controller = CoreBaseController { - endpoint, state_client, queue_client, informer: OnceCell::new(), From 90d90c9aa4a60f28238b9b1a3beef511917c596d Mon Sep 17 00:00:00 2001 From: Yee Hing Tong Date: Sun, 30 Nov 2025 23:58:53 -0800 Subject: [PATCH 06/22] at least temporarily replace groupdata with just a string, update controller to match current remote controller Signed-off-by: Yee Hing Tong --- examples/advanced/hybrid_mode.py | 1 - examples/advanced/remote_controller.py | 1 - examples/context/custom_context.py | 3 + rs_controller/src/action.rs | 2 +- src/flyte/_context.py | 4 +- src/flyte/_group.py | 3 +- src/flyte/_internal/controllers/__init__.py | 2 +- .../controllers/_local_controller.py | 5 +- .../_internal/controllers/remote/_action.py | 4 +- .../_internal/controllers/remote/_core.py | 2 +- .../controllers/remote/_r_controller.py | 280 ++++++++++-------- src/flyte/_internal/runtime/convert.py | 7 +- src/flyte/models.py | 2 +- 13 files changed, 180 insertions(+), 136 deletions(-) diff --git a/examples/advanced/hybrid_mode.py b/examples/advanced/hybrid_mode.py index ede8c85b5..618a41aef 100644 --- a/examples/advanced/hybrid_mode.py +++ b/examples/advanced/hybrid_mode.py @@ -5,7 +5,6 @@ import flyte import flyte.storage -from flyte.storage import S3 env = flyte.TaskEnvironment(name="hello_world", cache="disable") diff --git a/examples/advanced/remote_controller.py b/examples/advanced/remote_controller.py index a1bfbafc0..fa4b93e5d 100644 --- a/examples/advanced/remote_controller.py +++ b/examples/advanced/remote_controller.py @@ -64,7 +64,6 
@@ async def main(): result = await runner.submit_action(action) print("First submit done", flush=True) - breakpoint() print(result) diff --git a/examples/context/custom_context.py b/examples/context/custom_context.py index ec6b0c45f..5e1292403 100644 --- a/examples/context/custom_context.py +++ b/examples/context/custom_context.py @@ -2,16 +2,19 @@ env = flyte.TaskEnvironment("custom-context-example") + @env.task async def leaf_task() -> str: # Reads run-level context print("leaf sees:", flyte.ctx().custom_context) return flyte.ctx().custom_context.get("trace_id") + @env.task async def root() -> str: return await leaf_task() + if __name__ == "__main__": flyte.init_from_config() # Base context for the entire run diff --git a/rs_controller/src/action.rs b/rs_controller/src/action.rs index 0ef735bf0..76b6d45e6 100644 --- a/rs_controller/src/action.rs +++ b/rs_controller/src/action.rs @@ -143,7 +143,7 @@ impl Action { sub_action_id: ActionIdentifier, parent_action_name: String, group_data: Option, - task_spec: TaskSpec, + task_spec: TaskSpec, // document what this error is inputs_uri: String, run_output_base: String, cache_key: Option, diff --git a/src/flyte/_context.py b/src/flyte/_context.py index 20689560f..a2a4116c5 100644 --- a/src/flyte/_context.py +++ b/src/flyte/_context.py @@ -5,7 +5,7 @@ from typing import TYPE_CHECKING, Awaitable, Callable, Optional, ParamSpec, TypeVar from flyte._logging import logger -from flyte.models import GroupData, RawDataPath, TaskContext +from flyte.models import RawDataPath, TaskContext if TYPE_CHECKING: from flyte.report import Report @@ -26,7 +26,7 @@ class ContextData: will be None. 
""" - group_data: Optional[GroupData] = None + group_data: Optional[str] = None task_context: Optional[TaskContext] = None raw_data_path: Optional[RawDataPath] = None diff --git a/src/flyte/_group.py b/src/flyte/_group.py index 8fa8d4260..4fc2bae49 100644 --- a/src/flyte/_group.py +++ b/src/flyte/_group.py @@ -1,7 +1,6 @@ from contextlib import contextmanager from ._context import internal_ctx -from .models import GroupData @contextmanager @@ -26,7 +25,7 @@ async def my_task(): yield return tctx = ctx.data.task_context - new_tctx = tctx.replace(group_data=GroupData(name)) + new_tctx = tctx.replace(group_data=name) with ctx.replace_task_context(new_tctx): yield # Exit the context and restore the previous context diff --git a/src/flyte/_internal/controllers/__init__.py b/src/flyte/_internal/controllers/__init__.py index ea7969168..600b85485 100644 --- a/src/flyte/_internal/controllers/__init__.py +++ b/src/flyte/_internal/controllers/__init__.py @@ -118,7 +118,7 @@ def create_controller( controller = LocalController() case "remote" | "hybrid": - from flyte._internal.controllers.remote import create_remote_controller + # from flyte._internal.controllers.remote import create_remote_controller # controller = create_remote_controller(**kwargs) from flyte._internal.controllers.remote._r_controller import RemoteController diff --git a/src/flyte/_internal/controllers/_local_controller.py b/src/flyte/_internal/controllers/_local_controller.py index 9b9be43cb..69ae4ad80 100644 --- a/src/flyte/_internal/controllers/_local_controller.py +++ b/src/flyte/_internal/controllers/_local_controller.py @@ -4,7 +4,9 @@ import os import pathlib import threading -from typing import Any, Callable, Tuple, TypeVar, Protocol +from typing import Any, Callable, Protocol, Tuple, TypeVar + +from flyteidl2.task import task_definition_pb2 import flyte.errors from flyte._cache.cache import VersionParameters, cache_from_request @@ -22,6 +24,7 @@ R = TypeVar("R") + class ControllerProtocol(Protocol): 
async def submit(self, _task: "TaskTemplate", *args, **kwargs) -> Any: ... def submit_sync(self, _task: "TaskTemplate", *args, **kwargs) -> concurrent.futures.Future: ... diff --git a/src/flyte/_internal/controllers/remote/_action.py b/src/flyte/_internal/controllers/remote/_action.py index d5cd4a5d3..9c9ebb9d2 100644 --- a/src/flyte/_internal/controllers/remote/_action.py +++ b/src/flyte/_internal/controllers/remote/_action.py @@ -12,8 +12,6 @@ ) from google.protobuf import timestamp_pb2 -from flyte.models import GroupData - ActionType = Literal["task", "trace"] @@ -166,7 +164,7 @@ def from_trace( parent_action_name: str, action_id: identifier_pb2.ActionIdentifier, friendly_name: str, - group_data: GroupData | None, + group_data: str | None, inputs_uri: str, outputs_uri: str, start_time: float, # Unix timestamp in seconds with fractional seconds diff --git a/src/flyte/_internal/controllers/remote/_core.py b/src/flyte/_internal/controllers/remote/_core.py index f442d824e..7e6ff6449 100644 --- a/src/flyte/_internal/controllers/remote/_core.py +++ b/src/flyte/_internal/controllers/remote/_core.py @@ -364,7 +364,7 @@ async def _bg_launch(self, action: Action): trace=trace, input_uri=action.inputs_uri, run_output_base=action.run_output_base, - group=action.group.name if action.group else None, + group=action.group if action.group else None, # Subject is not used in the current implementation ), wait_for_ready=True, diff --git a/src/flyte/_internal/controllers/remote/_r_controller.py b/src/flyte/_internal/controllers/remote/_r_controller.py index 008f47134..0f5c074af 100644 --- a/src/flyte/_internal/controllers/remote/_r_controller.py +++ b/src/flyte/_internal/controllers/remote/_r_controller.py @@ -8,39 +8,49 @@ from collections import defaultdict from collections.abc import Callable from pathlib import Path -from typing import Any, AsyncIterable, DefaultDict, Tuple, TypeVar +from typing import Any, DefaultDict, Tuple, TypeVar -from flyte_controller_base import Action, 
BaseController, cloudidl +from flyte_controller_base import Action, BaseController +from flyteidl2.common import identifier_pb2 +from flyteidl2.workflow import run_definition_pb2 import flyte import flyte.errors import flyte.storage as storage -import flyte.types as types from flyte._code_bundle import build_pkl_bundle from flyte._context import internal_ctx from flyte._internal.controllers import TraceInfo from flyte._internal.runtime import convert, io from flyte._internal.runtime.task_serde import translate_task_to_wire +from flyte._internal.runtime.types_serde import transform_native_to_typed_interface from flyte._logging import logger -from flyte._protos.workflow import run_definition_pb2, task_definition_pb2 from flyte._task import TaskTemplate from flyte._utils.helpers import _selector_policy -from flyte.models import ActionID, NativeInterface, SerializationContext +from flyte.models import MAX_INLINE_IO_BYTES, ActionID, NativeInterface, SerializationContext +from flyte.remote._task import TaskDetails R = TypeVar("R") +MAX_TRACE_BYTES = MAX_INLINE_IO_BYTES -async def upload_inputs_with_retry(serialized_inputs: AsyncIterable[bytes] | bytes, inputs_uri: str) -> None: + +async def upload_inputs_with_retry(serialized_inputs: bytes, inputs_uri: str, max_bytes: int) -> None: """ Upload inputs to the specified URI with error handling. 
Args: serialized_inputs: The serialized inputs to upload inputs_uri: The destination URI + max_bytes: Maximum number of bytes to read from the input stream Raises: RuntimeSystemError: If the upload fails """ + if len(serialized_inputs) > max_bytes: + raise flyte.errors.InlineIOMaxBytesBreached( + f"Inputs exceed max_bytes limit of {max_bytes / 1024 / 1024} MB," + f" actual size: {len(serialized_inputs) / 1024 / 1024} MB" + ) try: # TODO Add retry decorator to this await storage.put_stream(serialized_inputs, to_path=inputs_uri) @@ -61,7 +71,7 @@ async def handle_action_failure(action: Action, task_name: str) -> Exception: Exception: The converted native exception or RuntimeSystemError """ err = action.err or action.client_err - if not err and action.phase == 6: # PHASE_FAILED + if not err and action.phase == run_definition_pb2.PHASE_FAILED: logger.error(f"Server reported failure for action {action.name}, checking error file.") try: error_path = io.error_path(f"{action.run_output_base}/{action.action_id.name}/1") @@ -78,19 +88,20 @@ async def handle_action_failure(action: Action, task_name: str) -> Exception: return exc -async def load_and_convert_outputs(iface: NativeInterface, realized_outputs_uri: str) -> Any: +async def load_and_convert_outputs(iface: NativeInterface, realized_outputs_uri: str, max_bytes: int) -> Any: """ Load outputs from the given URI and convert them to native format. 
Args: iface: The Native interface realized_outputs_uri: The URI where outputs are stored + max_bytes: Maximum number of bytes to read from the output file Returns: The converted native outputs """ outputs_file_path = io.outputs_path(realized_outputs_uri) - outputs = await io.load_outputs(outputs_file_path) + outputs = await io.load_outputs(outputs_file_path, max_bytes=max_bytes) return await convert.convert_outputs_to_native(iface, outputs) @@ -106,20 +117,18 @@ class RemoteController(BaseController): def __new__( cls, endpoint: str, - workers: int, - max_system_retries: int, - default_parent_concurrency: int = 100, + workers: int = 20, + max_system_retries: int = 10, ): return super().__new__(cls, endpoint=endpoint) def __init__( self, endpoint: str, - workers: int, - max_system_retries: int, - default_parent_concurrency: int = 100, + workers: int = 20, + max_system_retries: int = 10, ): - """ """ + default_parent_concurrency = int(os.getenv("_F_P_CNC", "1000")) self._default_parent_concurrency = default_parent_concurrency self._parent_action_semaphore: DefaultDict[str, asyncio.Semaphore] = defaultdict( lambda: asyncio.Semaphore(default_parent_concurrency) @@ -135,11 +144,18 @@ def generate_task_call_sequence(self, task_obj: object, action_id: ActionID) -> Generate a task call sequence for the given task object and action ID. This is used to track the number of times a task is called within an action. 
""" - current_action_sequencer = self._parent_action_task_call_sequence[unique_action_name(action_id)] + uniq = unique_action_name(action_id) + current_action_sequencer = self._parent_action_task_call_sequence[uniq] current_task_id = id(task_obj) v = current_action_sequencer[current_task_id] new_seq = v + 1 current_action_sequencer[current_task_id] = new_seq + name = "" + if hasattr(task_obj, "__name__"): + name = task_obj.__name__ + elif hasattr(task_obj, "name"): + name = task_obj.name + logger.info(f"For action {uniq}, task {name} call sequence is {new_seq}") return new_seq async def _submit(self, _task_call_seq: int, _task: TaskTemplate, *args, **kwargs) -> Any: @@ -153,7 +169,7 @@ async def _submit(self, _task_call_seq: int, _task: TaskTemplate, *args, **kwarg # It is not allowed to change the code bundle (for regular code bundles) in the middle of a run. code_bundle = tctx.code_bundle - if code_bundle and code_bundle.pkl: + if tctx.interactive_mode or (code_bundle and code_bundle.pkl): logger.debug(f"Building new pkl bundle for task {_task.name}") code_bundle = await build_pkl_bundle( _task, @@ -182,10 +198,11 @@ async def _submit(self, _task_call_seq: int, _task: TaskTemplate, *args, **kwarg sub_action_id, sub_action_output_path = convert.generate_sub_action_id_and_output_path( tctx, task_spec, inputs_hash, _task_call_seq ) + logger.info(f"Sub action {sub_action_id} output path {sub_action_output_path}") serialized_inputs = inputs.proto_inputs.SerializeToString(deterministic=True) inputs_uri = io.inputs_path(sub_action_output_path) - await upload_inputs_with_retry(serialized_inputs, inputs_uri) + await upload_inputs_with_retry(serialized_inputs, inputs_uri, max_bytes=_task.max_inline_io_bytes) md = task_spec.task_template.metadata ignored_input_vars = [] @@ -207,21 +224,23 @@ async def _submit(self, _task_call_seq: int, _task: TaskTemplate, *args, **kwarg serialized_inputs = None # type: ignore inputs_hash = None # type: ignore - translated_task_spec = 
self._translate_task_spec_new_proto(task_spec) action = Action.from_task( - sub_action_id=self._get_action_identifier_new_proto( - action_name=sub_action_id.name, - run_name=current_action_id.run_name, - project=current_action_id.project, - domain=current_action_id.domain, - org=current_action_id.org, + sub_action_id=identifier_pb2.ActionIdentifier( + name=sub_action_id.name, + run=identifier_pb2.RunIdentifier( + name=current_action_id.run_name, + project=current_action_id.project, + domain=current_action_id.domain, + org=current_action_id.org, + ), ), parent_action_name=current_action_id.name, group_data=str(tctx.group_data) if tctx.group_data else None, - task_spec=translated_task_spec, + task_spec=task_spec, inputs_uri=inputs_uri, run_output_base=tctx.run_base_dir, cache_key=cache_key, + queue=_task.queue, ) try: @@ -237,7 +256,22 @@ async def _submit(self, _task_call_seq: int, _task: TaskTemplate, *args, **kwarg await self.cancel_action(action) raise - if n.has_error() or n.phase == 6: # failed + # If the action is aborted, we should abort the controller as well + if n.phase == run_definition_pb2.PHASE_ABORTED: + logger.warning(f"Action {n.action_id.name} was aborted, aborting current Action {current_action_id.name}") + raise flyte.errors.RunAbortedError( + f"Action {n.action_id.name} was aborted, aborting current Action {current_action_id.name}" + ) + + if n.phase == run_definition_pb2.PHASE_TIMED_OUT: + logger.warning( + f"Action {n.action_id.name} timed out, raising timeout exception Action {current_action_id.name}" + ) + raise flyte.errors.TaskTimeoutError( + f"Action {n.action_id.name} timed out, raising exception in current Action {current_action_id.name}" + ) + + if n.has_error() or n.phase == run_definition_pb2.PHASE_FAILED: exc = await handle_action_failure(action, _task.name) raise exc @@ -247,7 +281,9 @@ async def _submit(self, _task_call_seq: int, _task: TaskTemplate, *args, **kwarg "RuntimeError", f"Task {n.action_id.name} did not return an output 
path, but the task has outputs defined.", ) - return await load_and_convert_outputs(_task.native_interface, n.realized_outputs_uri) + return await load_and_convert_outputs( + _task.native_interface, n.realized_outputs_uri, max_bytes=_task.max_inline_io_bytes + ) return None async def submit(self, _task: TaskTemplate, *args, **kwargs) -> Any: @@ -297,7 +333,9 @@ def exc_handler(loop, context): self._submit_loop.set_exception_handler(exc_handler) self._submit_thread = threading.Thread( - name=f"remote-controller-{os.getpid()}-submitter", daemon=True, target=self._sync_thread_loop_runner + name=f"remote-controller-{os.getpid()}-submitter", + daemon=True, + target=self._sync_thread_loop_runner, ) self._submit_thread.start() @@ -330,6 +368,13 @@ async def finalize_parent_action(self, action_id: ActionID): """ # todo-pr: implement any cleanup # translate the ActionID python object to something handleable in pyo3 + # will need to do this after we have multiple informers. + # run_id = identifier_pb2.RunIdentifier( + # name=action_id.run_name, + # project=action_id.project, + # domain=action_id.domain, + # org=action_id.org, + # ) # await super()._finalize_parent_action(run_id=run_id, parent_action_name=action_id.name) self._parent_action_semaphore.pop(unique_action_name(action_id), None) self._parent_action_task_call_sequence.pop(unique_action_name(action_id), None) @@ -356,64 +401,49 @@ async def get_action_outputs( invoke_seq_num = self.generate_task_call_sequence(_func, current_action_id) inputs = await convert.convert_from_native_to_inputs(_interface, *args, **kwargs) serialized_inputs = inputs.proto_inputs.SerializeToString(deterministic=True) + inputs_hash = convert.generate_inputs_hash_from_proto(inputs.proto_inputs) sub_action_id, sub_action_output_path = convert.generate_sub_action_id_and_output_path( - tctx, func_name, serialized_inputs, invoke_seq_num + tctx, func_name, inputs_hash, invoke_seq_num ) inputs_uri = io.inputs_path(sub_action_output_path) - await 
upload_inputs_with_retry(serialized_inputs, inputs_uri) + await upload_inputs_with_retry(serialized_inputs, inputs_uri, max_bytes=MAX_TRACE_BYTES) # Clear to free memory serialized_inputs = None # type: ignore - sub_action_id = self._get_action_identifier_new_proto( - action_name=sub_action_id.name, - run_name=current_action_id.run_name, - project=current_action_id.project, - domain=current_action_id.domain, - org=current_action_id.org, + sub_action_id_pb = identifier_pb2.ActionIdentifier( + name=sub_action_id.name, + run=identifier_pb2.RunIdentifier( + name=current_action_id.run_name, + project=current_action_id.project, + domain=current_action_id.domain, + org=current_action_id.org, + ), ) prev_action = await self.get_action( - sub_action_id, + sub_action_id_pb, current_action_id.name, ) if prev_action is None: - return TraceInfo(sub_action_id, _interface, inputs_uri), False + return TraceInfo(func_name, sub_action_id, _interface, inputs_uri), False - if prev_action.phase == 6: # failed + if prev_action.phase == run_definition_pb2.PHASE_FAILED: if prev_action.has_error(): exc = convert.convert_error_to_native(prev_action.err) - return TraceInfo(sub_action_id, _interface, inputs_uri, error=exc), True + return ( + TraceInfo(func_name, sub_action_id, _interface, inputs_uri, error=exc), + True, + ) else: logger.warning(f"Action {prev_action.action_id.name} failed, but no error was found, re-running trace!") elif prev_action.realized_outputs_uri is not None: - outputs_file_path = io.outputs_path(prev_action.realized_outputs_uri) - o = await io.load_outputs(outputs_file_path) + o = await io.load_outputs(prev_action.realized_outputs_uri, max_bytes=MAX_TRACE_BYTES) outputs = await convert.convert_outputs_to_native(_interface, o) - return TraceInfo(sub_action_id, _interface, inputs_uri, output=outputs), True - - return TraceInfo(sub_action_id, _interface, inputs_uri), False - - @staticmethod - def _get_action_identifier_new_proto( - *, action_name: str, run_name: str, 
project: str, domain: str, org: str - ) -> cloudidl.workflow.ActionIdentifier: - return cloudidl.workflow.ActionIdentifier( - name=action_name, - run=cloudidl.workflow.RunIdentifier( - name=run_name, - project=project, - domain=domain, - org=org, - ), - ) + return TraceInfo(func_name, sub_action_id, _interface, inputs_uri, output=outputs), True - @staticmethod - def _translate_task_spec_new_proto(task_spec: task_definition_pb2.TaskSpec) -> cloudidl.workflow.TaskSpec: - task_spec_bytes = task_spec.SerializeToString() - new_task_spec = cloudidl.workflow.TaskSpec.decode(task_spec_bytes) - return new_task_spec + return TraceInfo(func_name, sub_action_id, _interface, inputs_uri), False async def record_trace(self, info: TraceInfo): """ @@ -427,40 +457,47 @@ async def record_trace(self, info: TraceInfo): raise flyte.errors.RuntimeSystemError("BadContext", "Task context not initialized") current_action_id = tctx.action - current_output_path = tctx.output_path - sub_run_output_path = storage.join(current_output_path, info.action.name) + sub_run_output_path = storage.join(tctx.run_base_dir, info.action.name) + outputs_file_path: str = "" if info.interface.has_outputs(): - outputs_file_path: str = "" - if info.output: - outputs = await convert.convert_from_native_to_outputs(info.output, info.interface) - outputs_file_path = io.outputs_path(sub_run_output_path) - await io.upload_outputs(outputs, outputs_file_path) - elif info.error: + if info.error: err = convert.convert_from_native_to_error(info.error) - error_path = io.error_path(sub_run_output_path) - await io.upload_error(err.err, error_path) + await io.upload_error(err.err, sub_run_output_path) else: - raise flyte.errors.RuntimeSystemError("BadTraceInfo", "Trace info does not have output or error") + outputs = await convert.convert_from_native_to_outputs(info.output, info.interface) + outputs_file_path = io.outputs_path(sub_run_output_path) + await io.upload_outputs(outputs, sub_run_output_path, 
max_bytes=MAX_TRACE_BYTES) - action_id = self._get_action_identifier_new_proto( - action_name=info.action.name, - run_name=current_action_id.run_name, - project=current_action_id.project, - domain=current_action_id.domain, - org=current_action_id.org, - ) - trace_action = Action.from_trace( - parent_action_name=current_action_id.name, - action_id=action_id, - inputs_uri=info.inputs_path, - outputs_uri=outputs_file_path, - friendly_name=info.name, - group_data=str(tctx.group_data) if tctx.group_data else None, - ) + typed_interface = transform_native_to_typed_interface(info.interface) + + trace_action = Action.from_trace( + parent_action_name=current_action_id.name, + action_id=identifier_pb2.ActionIdentifier( + name=info.action.name, + run=identifier_pb2.RunIdentifier( + name=current_action_id.run_name, + project=current_action_id.project, + domain=current_action_id.domain, + org=current_action_id.org, + ), + ), + inputs_uri=info.inputs_path, + outputs_uri=outputs_file_path, + friendly_name=info.name, + group_data=tctx.group_data, + run_output_base=tctx.run_base_dir, + start_time=info.start_time, + end_time=info.end_time, + typed_interface=typed_interface if typed_interface else None, + ) + + async with self._parent_action_semaphore[unique_action_name(current_action_id)]: + # todo: remove the noop try catch try: logger.info( - f"Submitting Trace action Run:[{trace_action.run_name}, Parent:[{trace_action.parent_action_name}]," + f"Submitting Trace action Run:[{trace_action.run_name}," + f" Parent:[{trace_action.parent_action_name}]," f" Trace fn:[{info.name}], action:[{info.action.name}]" ) await self.submit_action(trace_action) @@ -469,19 +506,17 @@ async def record_trace(self, info: TraceInfo): # If the action is cancelled, we need to cancel the action on the server as well raise - async def submit_task_ref(self, _task: task_definition_pb2.TaskDetails, *args, **kwargs) -> Any: + async def _submit_task_ref(self, invoke_seq_num: int, _task: TaskDetails, *args, 
**kwargs) -> Any: ctx = internal_ctx() tctx = ctx.data.task_context if tctx is None: raise flyte.errors.RuntimeSystemError("BadContext", "Task context not initialized") current_action_id = tctx.action - task_name = _task.spec.task_template.id.name + task_name = _task.name - invoke_seq_num = self.generate_task_call_sequence(_task, current_action_id) + native_interface = _task.interface + pb_interface = _task.pb2.spec.task_template.interface - native_interface = types.guess_interface( - _task.spec.task_template.interface, default_inputs=_task.spec.default_inputs - ) inputs = await convert.convert_from_native_to_inputs(native_interface, *args, **kwargs) inputs_hash = convert.generate_inputs_hash_from_proto(inputs.proto_inputs) sub_action_id, sub_action_output_path = convert.generate_sub_action_id_and_output_path( @@ -490,19 +525,19 @@ async def submit_task_ref(self, _task: task_definition_pb2.TaskDetails, *args, * serialized_inputs = inputs.proto_inputs.SerializeToString(deterministic=True) inputs_uri = io.inputs_path(sub_action_output_path) - await upload_inputs_with_retry(serialized_inputs, inputs_uri) + await upload_inputs_with_retry(serialized_inputs, inputs_uri, _task.max_inline_io_bytes) # cache key - task name, task signature, inputs, cache version cache_key = None - md = _task.spec.task_template.metadata + md = _task.pb2.spec.task_template.metadata ignored_input_vars = [] if len(md.cache_ignore_input_vars) > 0: ignored_input_vars = list(md.cache_ignore_input_vars) - if _task.spec.task_template.metadata and _task.spec.task_template.metadata.discoverable: - discovery_version = _task.spec.task_template.metadata.discovery_version + if md and md.discoverable: + discovery_version = md.discovery_version cache_key = convert.generate_cache_key_hash( task_name, inputs_hash, - _task.spec.task_template.interface, + pb_interface, discovery_version, ignored_input_vars, inputs.proto_inputs, @@ -512,22 +547,23 @@ async def submit_task_ref(self, _task: 
task_definition_pb2.TaskDetails, *args, * serialized_inputs = None # type: ignore inputs_hash = None # type: ignore - sub_action_id = self._get_action_identifier_new_proto( - action_name=sub_action_id.name, - run_name=current_action_id.run_name, - project=current_action_id.project, - domain=current_action_id.domain, - org=current_action_id.org, - ) - translated_task_spec = self._translate_task_spec_new_proto(_task.spec) action = Action.from_task( - sub_action_id=sub_action_id, + sub_action_id=identifier_pb2.ActionIdentifier( + name=sub_action_id.name, + run=identifier_pb2.RunIdentifier( + name=current_action_id.run_name, + project=current_action_id.project, + domain=current_action_id.domain, + org=current_action_id.org, + ), + ), parent_action_name=current_action_id.name, - group_data=str(tctx.group_data) if tctx.group_data else None, - task_spec=translated_task_spec, + group_data=tctx.group_data, + task_spec=_task.pb2.spec, inputs_uri=inputs_uri, run_output_base=tctx.run_base_dir, cache_key=cache_key, + queue=None, ) try: @@ -543,7 +579,7 @@ async def submit_task_ref(self, _task: task_definition_pb2.TaskDetails, *args, * await self.cancel_action(action) raise - if n.has_error() or n.phase == 6: # failed + if n.has_error() or n.phase == run_definition_pb2.PHASE_FAILED: exc = await handle_action_failure(action, task_name) raise exc @@ -553,5 +589,15 @@ async def submit_task_ref(self, _task: task_definition_pb2.TaskDetails, *args, * "RuntimeError", f"Task {n.action_id.name} did not return an output path, but the task has outputs defined.", ) - return await load_and_convert_outputs(native_interface, n.realized_outputs_uri) + return await load_and_convert_outputs(native_interface, n.realized_outputs_uri, _task.max_inline_io_bytes) return None + + async def submit_task_ref(self, _task: TaskDetails, *args, **kwargs) -> Any: + ctx = internal_ctx() + tctx = ctx.data.task_context + if tctx is None: + raise flyte.errors.RuntimeSystemError("BadContext", "Task context not 
initialized") + current_action_id = tctx.action + task_call_seq = self.generate_task_call_sequence(_task, current_action_id) + async with self._parent_action_semaphore[unique_action_name(current_action_id)]: + return await self._submit_task_ref(task_call_seq, _task, *args, **kwargs) diff --git a/src/flyte/_internal/runtime/convert.py b/src/flyte/_internal/runtime/convert.py index aa0705a5a..1496565cc 100644 --- a/src/flyte/_internal/runtime/convert.py +++ b/src/flyte/_internal/runtime/convert.py @@ -234,11 +234,8 @@ async def convert_outputs_to_native(interface: NativeInterface, outputs: Outputs return tuple(kwargs[k] for k in interface.outputs.keys()) -from flyte_controller_base import cloudidl - - def convert_error_to_native( - err: execution_pb2.ExecutionError | Exception | Error | cloudidl.workflow.ExecutionError, + err: execution_pb2.ExecutionError | Exception | Error, ) -> Exception | None: if not err: return None @@ -484,7 +481,7 @@ def generate_sub_action_id_and_output_path( sub_action_id = current_action_id.new_sub_action_from( task_hash=task_hash, input_hash=inputs_hash, - group=tctx.group_data.name if tctx.group_data else None, + group=tctx.group_data if tctx.group_data else None, task_call_seq=invoke_seq, ) sub_run_output_path = storage.join(current_output_path, sub_action_id.name) diff --git a/src/flyte/models.py b/src/flyte/models.py index 5a8a6ec38..01f4d0650 100644 --- a/src/flyte/models.py +++ b/src/flyte/models.py @@ -200,7 +200,7 @@ class TaskContext: output_path: str run_base_dir: str report: Report - group_data: GroupData | None = None + group_data: str | None = None checkpoints: Checkpoints | None = None code_bundle: CodeBundle | None = None compiled_image_cache: ImageCache | None = None From 3c51e783489febe36cbace15b2318881e8eeba5d Mon Sep 17 00:00:00 2001 From: Yee Hing Tong Date: Mon, 1 Dec 2025 00:43:51 -0800 Subject: [PATCH 07/22] update for regular hybrid mode Signed-off-by: Yee Hing Tong --- examples/advanced/hybrid_mode.py | 21 
++++++--------------- src/flyte/_internal/controllers/__init__.py | 10 +++++----- src/flyte/_run.py | 6 +++--- 3 files changed, 14 insertions(+), 23 deletions(-) diff --git a/examples/advanced/hybrid_mode.py b/examples/advanced/hybrid_mode.py index 618a41aef..343d0ee17 100644 --- a/examples/advanced/hybrid_mode.py +++ b/examples/advanced/hybrid_mode.py @@ -5,6 +5,7 @@ import flyte import flyte.storage +from flyte.storage import S3 env = flyte.TaskEnvironment(name="hello_world", cache="disable") @@ -57,22 +58,12 @@ async def hybrid_parent_placeholder(): if __name__ == "__main__": # Get current working directory current_directory = Path(os.getcwd()) - # change to root directory of the project - os.chdir(current_directory.parent.parent) - # config = S3.for_sandbox() - # config = flyte.storage.S3.auto() - flyte.init_from_config("/Users/ytong/.flyte/config-k3d.yaml") - # flyte.init( - # endpoint="dns:///localhost:8090", - # insecure=True, - # org="testorg", - # project="testproject", - # domain="development", - # storage=config, - # log_level=10, - # ) + repo_root = current_directory.parent.parent + s3_sandbox = S3.for_sandbox() + flyte.init_from_config("/Users/ytong/.flyte/config-k3d.yaml", root_dir=repo_root, storage=s3_sandbox) + # Kick off a run of hybrid_parent_placeholder and fill in with kicked off things. 
- run_name = "rxmjfwt6nz2rkwzntrtl" + run_name = "rbddrnv8pd9lslpzw89w" outputs = flyte.with_runcontext( mode="hybrid", name=run_name, diff --git a/src/flyte/_internal/controllers/__init__.py b/src/flyte/_internal/controllers/__init__.py index 600b85485..43acf162b 100644 --- a/src/flyte/_internal/controllers/__init__.py +++ b/src/flyte/_internal/controllers/__init__.py @@ -118,12 +118,12 @@ def create_controller( controller = LocalController() case "remote" | "hybrid": - # from flyte._internal.controllers.remote import create_remote_controller + from flyte._internal.controllers.remote import create_remote_controller - # controller = create_remote_controller(**kwargs) - from flyte._internal.controllers.remote._r_controller import RemoteController - - controller = RemoteController(endpoint="http://host.docker.internal:8090", workers=10, max_system_retries=5) + controller = create_remote_controller(**kwargs) + # from flyte._internal.controllers.remote._r_controller import RemoteController + # + # controller = RemoteController(endpoint="http://host.docker.internal:8090", workers=10, max_system_retries=5) case "rust": # hybrid case, despite the case statement above, meant for local runs not inside docker from flyte._internal.controllers.remote._r_controller import RemoteController diff --git a/src/flyte/_run.py b/src/flyte/_run.py index b4c2178f3..b12f71968 100644 --- a/src/flyte/_run.py +++ b/src/flyte/_run.py @@ -69,7 +69,7 @@ async def _get_code_bundle_for_run(name: str) -> CodeBundle | None: run = await Run.get.aio(name=name) if run: run_details = await run.details.aio() - spec = run_details.action_details.pb2.resolved_task_spec + spec = run_details.action_details.pb2.task return extract_code_bundle(spec) return None @@ -422,8 +422,8 @@ async def _run_hybrid(self, obj: TaskTemplate[P, R, F], *args: P.args, **kwargs: run_name = self._name random_id = str(uuid.uuid4())[:6] - # controller = create_controller("remote", endpoint="localhost:8090", insecure=True) - 
controller = create_controller("rust", endpoint="localhost:8090", insecure=True) + controller = create_controller("remote", endpoint="localhost:8090", insecure=True) + # controller = create_controller("rust", endpoint="localhost:8090", insecure=True) action = ActionID(name=action_name, run_name=run_name, project=project, domain=domain, org=org) inputs = obj.native_interface.convert_to_kwargs(*args, **kwargs) From f6bb34264438e0b022e5763e8727c51dde039ea3 Mon Sep 17 00:00:00 2001 From: Yee Hing Tong Date: Mon, 1 Dec 2025 01:49:28 -0800 Subject: [PATCH 08/22] readme and alternate between original and rs controller in hybrid Signed-off-by: Yee Hing Tong --- README.md | 5 +++++ src/flyte/_run.py | 4 ++-- 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index f8b0a7589..9af082633 100644 --- a/README.md +++ b/README.md @@ -338,3 +338,8 @@ Then run `make build-wheels`. make dist python maint_tools/build_default_image.py ``` + +To install the wheel locally for testing, use the following command with your venv active. 
+```bash +uv pip install --find-links ./rs_controller/dist --no-index --force-reinstall flyte_controller_base +``` diff --git a/src/flyte/_run.py b/src/flyte/_run.py index b12f71968..0712619af 100644 --- a/src/flyte/_run.py +++ b/src/flyte/_run.py @@ -422,8 +422,8 @@ async def _run_hybrid(self, obj: TaskTemplate[P, R, F], *args: P.args, **kwargs: run_name = self._name random_id = str(uuid.uuid4())[:6] - controller = create_controller("remote", endpoint="localhost:8090", insecure=True) - # controller = create_controller("rust", endpoint="localhost:8090", insecure=True) + # controller = create_controller("remote", endpoint="localhost:8090", insecure=True) + controller = create_controller("rust", endpoint="localhost:8090", insecure=True) action = ActionID(name=action_name, run_name=run_name, project=project, domain=domain, org=org) inputs = obj.native_interface.convert_to_kwargs(*args, **kwargs) From 251d47358221fbc77afdd0885fb4169f3f43c1e3 Mon Sep 17 00:00:00 2001 From: Yee Hing Tong Date: Mon, 1 Dec 2025 02:18:13 -0800 Subject: [PATCH 09/22] pin idl to 14 because of semver lexical ordering and hybrid mode Signed-off-by: Yee Hing Tong --- examples/advanced/hybrid_mode.py | 2 +- rs_controller/Cargo.lock | 4 +- rs_controller/Cargo.toml | 2 +- rs_controller/src/action.rs | 39 ++++++++--- .../controllers/remote/_r_controller.py | 69 +++++++++++-------- 5 files changed, 72 insertions(+), 44 deletions(-) diff --git a/examples/advanced/hybrid_mode.py b/examples/advanced/hybrid_mode.py index 343d0ee17..73b7b579e 100644 --- a/examples/advanced/hybrid_mode.py +++ b/examples/advanced/hybrid_mode.py @@ -63,7 +63,7 @@ async def hybrid_parent_placeholder(): flyte.init_from_config("/Users/ytong/.flyte/config-k3d.yaml", root_dir=repo_root, storage=s3_sandbox) # Kick off a run of hybrid_parent_placeholder and fill in with kicked off things. 
- run_name = "rbddrnv8pd9lslpzw89w" + run_name = "rt26xx54p886brkhcns2" outputs = flyte.with_runcontext( mode="hybrid", name=run_name, diff --git a/rs_controller/Cargo.lock b/rs_controller/Cargo.lock index 22ec3e40e..e7561527b 100644 --- a/rs_controller/Cargo.lock +++ b/rs_controller/Cargo.lock @@ -251,9 +251,9 @@ dependencies = [ [[package]] name = "flyteidl2" -version = "2.0.0-alpha9" +version = "2.0.0-alpha14" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "841c8fd5b71ff4cfbb2c70135aeaf51d02c847d853d5b333a2e4cd512d9a863c" +checksum = "5da488fa330fc42fc4c2daa03f10e8bcbc69d302673036952e7f56c92718d13e" dependencies = [ "async-trait", "futures", diff --git a/rs_controller/Cargo.toml b/rs_controller/Cargo.toml index 8120bc422..d1d508e4f 100644 --- a/rs_controller/Cargo.toml +++ b/rs_controller/Cargo.toml @@ -24,7 +24,7 @@ tracing-subscriber = "0.3" async-trait = "0.1" thiserror = "1.0" pyo3-build-config = "0.24.2" -flyteidl2 = "2.0.0-alpha14" +flyteidl2 = "=2.0.0-alpha14" [build-dependencies] pyo3 = { version = "0.24", features = ["extension-module", "abi3-py310"] } diff --git a/rs_controller/src/action.rs b/rs_controller/src/action.rs index 76b6d45e6..795156390 100644 --- a/rs_controller/src/action.rs +++ b/rs_controller/src/action.rs @@ -1,4 +1,5 @@ use flyteidl2::google::protobuf::Timestamp; +use prost::Message; use pyo3::prelude::*; use flyteidl2::flyteidl::common::ActionIdentifier; @@ -140,17 +141,24 @@ impl Action { impl Action { #[staticmethod] pub fn from_task( - sub_action_id: ActionIdentifier, + sub_action_id_bytes: &[u8], parent_action_name: String, group_data: Option, - task_spec: TaskSpec, // document what this error is + task_spec_bytes: &[u8], inputs_uri: String, run_output_base: String, cache_key: Option, queue: Option, - ) -> Self { + ) -> PyResult { + // Deserialize bytes to Rust protobuf types since Python and Rust have different generated protobufs + let sub_action_id = ActionIdentifier::decode(sub_action_id_bytes) + 
.map_err(|e| pyo3::exceptions::PyValueError::new_err(format!("Failed to decode ActionIdentifier: {}", e)))?; + + let task_spec = TaskSpec::decode(task_spec_bytes) + .map_err(|e| pyo3::exceptions::PyValueError::new_err(format!("Failed to decode TaskSpec: {}", e)))?; + debug!("Creating Action from task for ID {:?}", &sub_action_id); - Action { + Ok(Action { action_id: sub_action_id, parent_action_name, action_type: ActionType::Task, @@ -171,14 +179,14 @@ impl Action { cache_key, queue, trace: None, - } + }) } /// This creates a new action for tracing purposes. It is used to track the execution of a trace #[staticmethod] pub fn from_trace( parent_action_name: String, - action_id: ActionIdentifier, + action_id_bytes: &[u8], friendly_name: String, group_data: Option, inputs_uri: String, @@ -187,8 +195,19 @@ impl Action { end_time: f64, // Unix timestamp in seconds with fractional seconds run_output_base: String, report_uri: Option, - typed_interface: Option, - ) -> Self { + typed_interface_bytes: Option<&[u8]>, + ) -> PyResult { + // Deserialize bytes to Rust protobuf types + let action_id = ActionIdentifier::decode(action_id_bytes) + .map_err(|e| pyo3::exceptions::PyValueError::new_err(format!("Failed to decode ActionIdentifier: {}", e)))?; + + let typed_interface = if let Some(bytes) = typed_interface_bytes { + Some(TypedInterface::decode(bytes) + .map_err(|e| pyo3::exceptions::PyValueError::new_err(format!("Failed to decode TypedInterface: {}", e)))?) 
+ } else { + None + }; + debug!("Creating Action from trace for ID {:?}", &action_id); let trace_spec = TraceSpec { interface: typed_interface, @@ -224,7 +243,7 @@ impl Action { spec: Some(trace_spec), }; - Action { + Ok(Action { action_id, parent_action_name, action_type: ActionType::Trace, @@ -242,7 +261,7 @@ impl Action { cache_key: None, queue: None, trace: Some(trace_action), - } + }) } #[getter(run_name)] diff --git a/src/flyte/_internal/controllers/remote/_r_controller.py b/src/flyte/_internal/controllers/remote/_r_controller.py index 0f5c074af..865396dc2 100644 --- a/src/flyte/_internal/controllers/remote/_r_controller.py +++ b/src/flyte/_internal/controllers/remote/_r_controller.py @@ -224,19 +224,22 @@ async def _submit(self, _task_call_seq: int, _task: TaskTemplate, *args, **kwarg serialized_inputs = None # type: ignore inputs_hash = None # type: ignore - action = Action.from_task( - sub_action_id=identifier_pb2.ActionIdentifier( - name=sub_action_id.name, - run=identifier_pb2.RunIdentifier( - name=current_action_id.run_name, - project=current_action_id.project, - domain=current_action_id.domain, - org=current_action_id.org, - ), + # Serialize protobuf objects to bytes for Rust interop + sub_action_id_pb = identifier_pb2.ActionIdentifier( + name=sub_action_id.name, + run=identifier_pb2.RunIdentifier( + name=current_action_id.run_name, + project=current_action_id.project, + domain=current_action_id.domain, + org=current_action_id.org, ), + ) + + action = Action.from_task( + sub_action_id_bytes=sub_action_id_pb.SerializeToString(), parent_action_name=current_action_id.name, group_data=str(tctx.group_data) if tctx.group_data else None, - task_spec=task_spec, + task_spec_bytes=task_spec.SerializeToString(), inputs_uri=inputs_uri, run_output_base=tctx.run_base_dir, cache_key=cache_key, @@ -471,17 +474,20 @@ async def record_trace(self, info: TraceInfo): typed_interface = transform_native_to_typed_interface(info.interface) + # Serialize protobuf objects to 
bytes for Rust interop + action_id_pb = identifier_pb2.ActionIdentifier( + name=info.action.name, + run=identifier_pb2.RunIdentifier( + name=current_action_id.run_name, + project=current_action_id.project, + domain=current_action_id.domain, + org=current_action_id.org, + ), + ) + trace_action = Action.from_trace( parent_action_name=current_action_id.name, - action_id=identifier_pb2.ActionIdentifier( - name=info.action.name, - run=identifier_pb2.RunIdentifier( - name=current_action_id.run_name, - project=current_action_id.project, - domain=current_action_id.domain, - org=current_action_id.org, - ), - ), + action_id_bytes=action_id_pb.SerializeToString(), inputs_uri=info.inputs_path, outputs_uri=outputs_file_path, friendly_name=info.name, @@ -489,7 +495,7 @@ async def record_trace(self, info: TraceInfo): run_output_base=tctx.run_base_dir, start_time=info.start_time, end_time=info.end_time, - typed_interface=typed_interface if typed_interface else None, + typed_interface_bytes=typed_interface.SerializeToString() if typed_interface else None, ) async with self._parent_action_semaphore[unique_action_name(current_action_id)]: @@ -547,19 +553,22 @@ async def _submit_task_ref(self, invoke_seq_num: int, _task: TaskDetails, *args, serialized_inputs = None # type: ignore inputs_hash = None # type: ignore - action = Action.from_task( - sub_action_id=identifier_pb2.ActionIdentifier( - name=sub_action_id.name, - run=identifier_pb2.RunIdentifier( - name=current_action_id.run_name, - project=current_action_id.project, - domain=current_action_id.domain, - org=current_action_id.org, - ), + # Serialize protobuf objects to bytes for Rust interop + sub_action_id_pb = identifier_pb2.ActionIdentifier( + name=sub_action_id.name, + run=identifier_pb2.RunIdentifier( + name=current_action_id.run_name, + project=current_action_id.project, + domain=current_action_id.domain, + org=current_action_id.org, ), + ) + + action = Action.from_task( + 
sub_action_id_bytes=sub_action_id_pb.SerializeToString(), parent_action_name=current_action_id.name, group_data=tctx.group_data, - task_spec=_task.pb2.spec, + task_spec_bytes=_task.pb2.spec.SerializeToString(), inputs_uri=inputs_uri, run_output_base=tctx.run_base_dir, cache_key=cache_key, From 4fa9a8aa6e5485597e4fcbe5040c55cc10cbc566 Mon Sep 17 00:00:00 2001 From: Yee Hing Tong Date: Wed, 3 Dec 2025 22:20:10 -0800 Subject: [PATCH 10/22] pr into pr ignore (#376) Signed-off-by: Yee Hing Tong --- README.md | 1 + examples/basics/devbox_one.py | 8 + rs_controller/AUTH_IMPLEMENTATION.md | 130 ++ rs_controller/Cargo.lock | 1088 ++++++++++++++++- rs_controller/Cargo.toml | 17 +- rs_controller/Makefile | 9 +- rs_controller/build.rs | 2 +- rs_controller/pyproject.toml | 6 +- rs_controller/src/action.rs | 53 +- rs_controller/src/auth/client_credentials.rs | 205 ++++ rs_controller/src/auth/config.rs | 55 + rs_controller/src/auth/errors.rs | 23 + rs_controller/src/auth/middleware.rs | 125 ++ rs_controller/src/auth/mod.rs | 11 + rs_controller/src/auth/token_client.rs | 96 ++ rs_controller/src/bin/test_auth.rs | 65 + rs_controller/src/informer.rs | 12 +- rs_controller/src/lib.rs | 504 +++++++- rs_controller/src/lib_auth.rs | 5 + rs_controller/src/proto/flyteidl.service.rs | 78 ++ .../src/proto/flyteidl.service.tonic.rs | 409 +++++++ rs_controller/src/proto/mod.rs | 12 + rs_controller/test_auth_direct.py | 25 + rs_controller/test_auth_simple.py | 98 ++ rs_controller/try_rust_controller.py | 108 ++ src/flyte/_internal/controllers/__init__.py | 12 +- .../controllers/remote/_r_controller.py | 84 +- 27 files changed, 3134 insertions(+), 107 deletions(-) create mode 100644 rs_controller/AUTH_IMPLEMENTATION.md create mode 100644 rs_controller/src/auth/client_credentials.rs create mode 100644 rs_controller/src/auth/config.rs create mode 100644 rs_controller/src/auth/errors.rs create mode 100644 rs_controller/src/auth/middleware.rs create mode 100644 rs_controller/src/auth/mod.rs create 
mode 100644 rs_controller/src/auth/token_client.rs create mode 100644 rs_controller/src/bin/test_auth.rs create mode 100644 rs_controller/src/lib_auth.rs create mode 100644 rs_controller/src/proto/flyteidl.service.rs create mode 100644 rs_controller/src/proto/flyteidl.service.tonic.rs create mode 100644 rs_controller/src/proto/mod.rs create mode 100644 rs_controller/test_auth_direct.py create mode 100644 rs_controller/test_auth_simple.py create mode 100644 rs_controller/try_rust_controller.py diff --git a/README.md b/README.md index 9af082633..41e85914b 100644 --- a/README.md +++ b/README.md @@ -343,3 +343,4 @@ To install the wheel locally for testing, use the following command with your ve ```bash uv pip install --find-links ./rs_controller/dist --no-index --force-reinstall flyte_controller_base ``` +Repeat this process to iterate - build new wheels, force reinstall the controller package. diff --git a/examples/basics/devbox_one.py b/examples/basics/devbox_one.py index 5a619b28a..22cd573a4 100644 --- a/examples/basics/devbox_one.py +++ b/examples/basics/devbox_one.py @@ -1,12 +1,20 @@ import asyncio import logging +from pathlib import Path from typing import List import flyte +from flyte._image import PythonWheels + +controller_dist_folder = Path("/Users/ytong/go/src/github.com/flyteorg/flyte-sdk/rs_controller/dist") +wheel_layer = PythonWheels(wheel_dir=controller_dist_folder, package_name="flyte_controller_base") +base = flyte.Image.from_debian_base() +rs_controller_image = base.clone(addl_layer=wheel_layer) env = flyte.TaskEnvironment( name="hello_world", resources=flyte.Resources(cpu=1, memory="1Gi"), + image=rs_controller_image, ) diff --git a/rs_controller/AUTH_IMPLEMENTATION.md b/rs_controller/AUTH_IMPLEMENTATION.md new file mode 100644 index 000000000..eff0c8a4c --- /dev/null +++ b/rs_controller/AUTH_IMPLEMENTATION.md @@ -0,0 +1,130 @@ +# Rust Authentication Implementation for Flyte + +## Summary + +I've implemented client credentials OAuth2 authentication 
for the Rust gRPC clients, modeled after the Python implementation. The implementation includes: + +1. **Auth Module Structure** (`src/auth/`) + - `config.rs` - Auth configuration and helper traits + - `token_client.rs` - OAuth2 token retrieval logic + - `client_credentials.rs` - Client credentials authenticator with token caching + - `interceptor.rs` - gRPC interceptor for adding auth headers and handling 401s + +2. **Proto Module** (`src/proto/`) + - Organized generated protobuf files from v1 Flyte IDL + - Includes `AuthMetadataService` for fetching OAuth2 metadata + +3. **Key Features** + - Automatic token fetching on first request + - Token caching with expiration tracking + - Automatic refresh on 401/Unauthenticated errors + - Thread-safe credential management using RwLock + - Retry logic with automatic credential refresh + +## Current Status + +**Implemented but not fully compiling** - There are compilation issues with the generated proto files: + +1. Some proto files have `#[derive(Copy)]` on structs with non-Copy fields (String) +2. There may be missing prost-types features needed for Timestamp handling +3. Some module visibility issues to resolve + +## How It Works + +### Authentication Flow + +``` +1. Client creates AuthConfig with endpoint, client_id, client_secret +2. ClientCredentialsAuthenticator is created +3. On first gRPC call: + a. Authenticator fetches OAuth2 metadata from AuthMetadataService + b. Calls token endpoint with client credentials + c. Caches the access token with expiration time +4. AuthInterceptor adds "Bearer {token}" to request metadata +5. If request returns 401: + a. Interceptor triggers credential refresh + b. 
Retries the request with new token +``` + +### Usage Example + +```rust +use flyte_controller_base::auth::{AuthConfig, AuthInterceptor, ClientCredentialsAuthenticator}; + +// Create auth config +let auth_config = AuthConfig { + endpoint: "dns:///flyte.example.com:443".to_string(), + client_id: "your_client_id".to_string(), + client_secret: "your_secret".to_string(), + scopes: None, + audience: None, +}; + +// Create authenticator +let authenticator = Arc::new(ClientCredentialsAuthenticator::new(auth_config)); + +// Connect to endpoint +let channel = Endpoint::from_shared(endpoint)? + .connect() + .await?; + +// Create auth interceptor +let auth_interceptor = AuthInterceptor::new(authenticator, channel.clone()); + +// Make authenticated calls using the with_auth! macro +let response = with_auth!( + auth_interceptor, + my_client, + my_method, + request +)?; +``` + +## Files Created/Modified + +### New Files +- `src/auth/mod.rs` - Auth module exports +- `src/auth/config.rs` - Configuration types +- `src/auth/token_client.rs` - OAuth2 token client +- `src/auth/client_credentials.rs` - Authenticator implementation +- `src/auth/interceptor.rs` - gRPC interceptor with retry logic +- `src/proto/mod.rs` - Proto module organization +- `src/lib_auth.rs` - Re-exports for external use +- `examples/simple_auth_test.rs` - Test script +- `examples/auth_test.rs` - Example with actual API calls + +### Modified Files +- `Cargo.toml` - Added dependencies (reqwest, serde, base64, urlencoding) +- `src/lib.rs` - Added auth and proto modules + +## Next Steps to Fix Compilation + +1. **Fix proto file issues:** + - Remove `Copy` derives from structs with String fields in generated files + - OR regenerate the proto files with correct options + - OR use only the minimal auth-related protos + +2. **Check prost-types dependency:** + ```toml + prost-types = { version = "0.12", features = ["std"] } + ``` + May need to match the prost version exactly. + +3. 
**Simplest fix:** Extract just the `AuthMetadataService` related types into a minimal hand-written proto module (I started this in `src/proto/auth_service.rs`) + +## Testing + +Once compilation is fixed, test with: + +```bash +FLYTE_ENDPOINT=dns:///your-endpoint:443 \ +FLYTE_CLIENT_ID=your_id \ +FLYTE_CLIENT_SECRET=your_secret \ +cargo run --example simple_auth_test +``` + +## References + +- Python implementation: `/Users/ytong/go/src/github.com/flyteorg/flyte-sdk/src/flyte/remote/_client/auth/` +- Flytekit PR #2416: https://github.com/flyteorg/flytekit/pull/2416/files +- Proto definitions: https://github.com/flyteorg/flyte/tree/v2/flyteidl2 diff --git a/rs_controller/Cargo.lock b/rs_controller/Cargo.lock index e7561527b..f4016dffd 100644 --- a/rs_controller/Cargo.lock +++ b/rs_controller/Cargo.lock @@ -157,6 +157,12 @@ version = "2.9.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5c8214115b7bf84099f1309324e63141d4c5d7cc26862f97a0a857dbefe165bd" +[[package]] +name = "bumpalo" +version = "3.19.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "46c5e41b57b8bba42a04676d81cb89e9ee8e859a1a66f80a5a72e1cb76b34d43" + [[package]] name = "bytes" version = "1.10.1" @@ -179,6 +185,12 @@ version = "1.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" +[[package]] +name = "cfg_aliases" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "613afe47fcd5fac7ccf1db93babcb082c5994d996f20b8b159f2ad1658eb5724" + [[package]] name = "chrono" version = "0.4.42" @@ -188,12 +200,58 @@ dependencies = [ "num-traits", ] +[[package]] +name = "core-foundation" +version = "0.9.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "91e195e091a93c46f7102ec7818a2aa394e1e1771c3ab4825963fa03e45afb8f" +dependencies = [ + "core-foundation-sys", + "libc", +] + +[[package]] 
+name = "core-foundation" +version = "0.10.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b2a6cd9ae233e7f62ba4e9353e81a88df7fc8a5987b8d445b4d90c879bd156f6" +dependencies = [ + "core-foundation-sys", + "libc", +] + +[[package]] +name = "core-foundation-sys" +version = "0.8.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "773648b94d0e5d620f64f280777445740e61fe701025087ec8b57f45c791888b" + +[[package]] +name = "displaydoc" +version = "0.2.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "97369cbbc041bc366949bc74d34658d6cda5621039731c6310521892a3a20ae0" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "either" version = "1.15.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "48c757948c5ede0e46177b7add2e67155f70e33c07fea8284df6576da70b3719" +[[package]] +name = "encoding_rs" +version = "0.8.35" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "75030f3c4f45dafd7586dd6780965a8c7e8e285a5ecb86713e63a79c5b2766f3" +dependencies = [ + "cfg-if", +] + [[package]] name = "equivalent" version = "1.0.2" @@ -207,7 +265,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "39cab71617ae0d63f51a36d69f866391735b51691dbda63cf6f96d042b63efeb" dependencies = [ "libc", - "windows-sys", + "windows-sys 0.61.2", ] [[package]] @@ -233,20 +291,29 @@ name = "flyte_core" version = "0.1.0" dependencies = [ "async-trait", + "base64 0.22.1", + "bytes", "flyteidl2", "futures", + "http", + "http-body-util", "prost 0.13.5", - "prost-types 0.12.6", + "prost-types 0.13.5", "pyo3", "pyo3-async-runtimes", "pyo3-build-config", - "thiserror", + "reqwest", + "serde", + "serde_json", + "thiserror 1.0.69", "tokio", "tonic", "tower 0.4.13", - "tower-http", + "tower-http 0.5.2", "tracing", "tracing-subscriber", + "url", + "urlencoding", ] [[package]] @@ -271,11 +338,11 @@ dependencies = [ "regex", 
"serde", "syn", - "thiserror", + "thiserror 1.0.69", "tokio", "tonic", "tower 0.4.13", - "tower-http", + "tower-http 0.5.2", "tracing", "tracing-subscriber", ] @@ -286,6 +353,30 @@ version = "1.0.7" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3f9eec918d3f24069decb9af1554cad7c880e2da24a9afd88aca000531ab82c1" +[[package]] +name = "foreign-types" +version = "0.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f6f339eb8adc052cd2ca78910fda869aefa38d22d5cb648e6485e4d3fc06f3b1" +dependencies = [ + "foreign-types-shared", +] + +[[package]] +name = "foreign-types-shared" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "00b0228411908ca8685dba7fc2cdd70ec9990a6e753e89b6ac91a84c40fbaf4b" + +[[package]] +name = "form_urlencoded" +version = "1.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cb4cb245038516f5f85277875cdaa4f7d2c9a0fa0468de06ed190163b1581fcf" +dependencies = [ + "percent-encoding", +] + [[package]] name = "futures" version = "0.3.31" @@ -382,8 +473,10 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "335ff9f135e4384c8150d6f27c6daed433577f86b4750418338c01a1a2528592" dependencies = [ "cfg-if", + "js-sys", "libc", "wasi", + "wasm-bindgen", ] [[package]] @@ -393,9 +486,11 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "899def5c37c4fd7b2664648c28120ecec138e4d395b459e5ca34f9cce2dd77fd" dependencies = [ "cfg-if", + "js-sys", "libc", "r-efi", "wasip2", + "wasm-bindgen", ] [[package]] @@ -508,6 +603,23 @@ dependencies = [ "want", ] +[[package]] +name = "hyper-rustls" +version = "0.27.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e3c93eb611681b207e1fe55d5a71ecf91572ec8a6705cdb6857f7d8d5242cf58" +dependencies = [ + "http", + "hyper", + "hyper-util", + "rustls", + "rustls-pki-types", + "tokio", + "tokio-rustls", + "tower-service", + 
"webpki-roots", +] + [[package]] name = "hyper-timeout" version = "0.5.2" @@ -521,12 +633,29 @@ dependencies = [ "tower-service", ] +[[package]] +name = "hyper-tls" +version = "0.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "70206fc6890eaca9fde8a0bf71caa2ddfc9fe045ac9e5c70df101a7dbde866e0" +dependencies = [ + "bytes", + "http-body-util", + "hyper", + "hyper-util", + "native-tls", + "tokio", + "tokio-native-tls", + "tower-service", +] + [[package]] name = "hyper-util" version = "0.1.14" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "dc2fdfdbff08affe55bb779f33b053aa1fe5dd5b54c257343c17edfa55711bdb" dependencies = [ + "base64 0.22.1", "bytes", "futures-channel", "futures-core", @@ -534,12 +663,118 @@ dependencies = [ "http", "http-body", "hyper", + "ipnet", "libc", + "percent-encoding", "pin-project-lite", "socket2", + "system-configuration", "tokio", "tower-service", "tracing", + "windows-registry", +] + +[[package]] +name = "icu_collections" +version = "2.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4c6b649701667bbe825c3b7e6388cb521c23d88644678e83c0c4d0a621a34b43" +dependencies = [ + "displaydoc", + "potential_utf", + "yoke", + "zerofrom", + "zerovec", +] + +[[package]] +name = "icu_locale_core" +version = "2.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "edba7861004dd3714265b4db54a3c390e880ab658fec5f7db895fae2046b5bb6" +dependencies = [ + "displaydoc", + "litemap", + "tinystr", + "writeable", + "zerovec", +] + +[[package]] +name = "icu_normalizer" +version = "2.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5f6c8828b67bf8908d82127b2054ea1b4427ff0230ee9141c54251934ab1b599" +dependencies = [ + "icu_collections", + "icu_normalizer_data", + "icu_properties", + "icu_provider", + "smallvec", + "zerovec", +] + +[[package]] +name = "icu_normalizer_data" +version = "2.1.1" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "7aedcccd01fc5fe81e6b489c15b247b8b0690feb23304303a9e560f37efc560a" + +[[package]] +name = "icu_properties" +version = "2.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e93fcd3157766c0c8da2f8cff6ce651a31f0810eaa1c51ec363ef790bbb5fb99" +dependencies = [ + "icu_collections", + "icu_locale_core", + "icu_properties_data", + "icu_provider", + "zerotrie", + "zerovec", +] + +[[package]] +name = "icu_properties_data" +version = "2.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "02845b3647bb045f1100ecd6480ff52f34c35f82d9880e029d329c21d1054899" + +[[package]] +name = "icu_provider" +version = "2.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "85962cf0ce02e1e0a629cc34e7ca3e373ce20dda4c4d7294bbd0bf1fdb59e614" +dependencies = [ + "displaydoc", + "icu_locale_core", + "writeable", + "yoke", + "zerofrom", + "zerotrie", + "zerovec", +] + +[[package]] +name = "idna" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3b0875f23caa03898994f6ddc501886a45c7d3d62d04d2d90788d47be1b1e4de" +dependencies = [ + "idna_adapter", + "smallvec", + "utf8_iter", +] + +[[package]] +name = "idna_adapter" +version = "1.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3acae9609540aa318d1bc588455225fb2085b9ed0c4f6bd0d9d5bcd86f1a0344" +dependencies = [ + "icu_normalizer", + "icu_properties", ] [[package]] @@ -568,6 +803,22 @@ version = "2.0.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f4c7245a08504955605670dbf141fceab975f15ca21570696aebe9d2e71576bd" +[[package]] +name = "ipnet" +version = "2.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "469fb0b9cefa57e3ef31275ee7cacb78f2fdca44e4765491884a2b119d4eb130" + +[[package]] +name = "iri-string" +version = "0.7.9" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "4f867b9d1d896b67beb18518eda36fdb77a32ea590de864f1325b294a6d14397" +dependencies = [ + "memchr", + "serde", +] + [[package]] name = "itertools" version = "0.12.1" @@ -592,6 +843,16 @@ version = "1.0.15" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4a5f13b858c8d314ee3e8f639011f7ccefe71f97f96e50151fb991f267928e2c" +[[package]] +name = "js-sys" +version = "0.3.83" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "464a3709c7f55f1f721e5389aa6ea4e3bc6aba669353300af094b29ffbdde1d8" +dependencies = [ + "once_cell", + "wasm-bindgen", +] + [[package]] name = "lazy_static" version = "1.5.0" @@ -610,6 +871,12 @@ version = "0.11.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "df1d3c3b53da64cf5760482273a98e575c651a67eec7f77df96b5b642de8f039" +[[package]] +name = "litemap" +version = "0.8.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6373607a59f0be73a39b6fe456b8192fcc3585f602af20751600e974dd455e77" + [[package]] name = "lock_api" version = "0.4.12" @@ -626,6 +893,12 @@ version = "0.4.27" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "13dc2df351e3202783a1fe0d44375f7295ffb4049267b0f3018346dc122a1d94" +[[package]] +name = "lru-slab" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "112b39cec0b298b6c1999fee3e31427f74f676e4cb9879ed1a121b43661a4154" + [[package]] name = "matchit" version = "0.7.3" @@ -670,7 +943,7 @@ checksum = "2886843bf800fba2e3377cff24abf6379b4c4d5c6681eaf9ea5b0d15090450bd" dependencies = [ "libc", "wasi", - "windows-sys", + "windows-sys 0.52.0", ] [[package]] @@ -679,6 +952,23 @@ version = "0.10.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1d87ecb2933e8aeadb3e3a02b828fed80a7528047e68b4f424523a0981a3a084" +[[package]] +name = "native-tls" +version = "0.2.14" 
+source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "87de3442987e9dbec73158d5c715e7ad9072fda936bb03d19d7fa10e00520f0e" +dependencies = [ + "libc", + "log", + "openssl", + "openssl-probe", + "openssl-sys", + "schannel", + "security-framework 2.11.1", + "security-framework-sys", + "tempfile", +] + [[package]] name = "nu-ansi-term" version = "0.46.0" @@ -713,6 +1003,50 @@ version = "1.21.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "42f5e15c9953c5e4ccceeb2e7382a716482c34515315f7b03532b8b4e8393d2d" +[[package]] +name = "openssl" +version = "0.10.75" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "08838db121398ad17ab8531ce9de97b244589089e290a384c900cb9ff7434328" +dependencies = [ + "bitflags", + "cfg-if", + "foreign-types", + "libc", + "once_cell", + "openssl-macros", + "openssl-sys", +] + +[[package]] +name = "openssl-macros" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a948666b637a0f465e8564c73e89d4dde00d72d4d473cc972f390fc3dcee7d9c" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "openssl-probe" +version = "0.1.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d05e27ee213611ffe7d6348b942e8f942b37114c00cc03cec254295a4a17852e" + +[[package]] +name = "openssl-sys" +version = "0.9.111" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "82cab2d520aa75e3c58898289429321eb788c3106963d0dc886ec7a5f4adc321" +dependencies = [ + "cc", + "libc", + "pkg-config", + "vcpkg", +] + [[package]] name = "overload" version = "0.1.1" @@ -833,12 +1167,27 @@ version = "0.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8b870d8c151b6f2fb93e84a13146138f05d02ed11c7e7c54f8826aaaf7c9f184" +[[package]] +name = "pkg-config" +version = "0.3.32" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"7edddbd0b52d732b21ad9a5fab5c704c14cd949e5e9a1ec5929a24fded1b904c" + [[package]] name = "portable-atomic" version = "1.11.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "350e9b48cbc6b0e028b0473b114454c6316e57336ee184ceab6e53f72c178b3e" +[[package]] +name = "potential_utf" +version = "0.1.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b73949432f5e2a09657003c25bca5e19a0e9c84f8058ca374f49e0ebe605af77" +dependencies = [ + "zerovec", +] + [[package]] name = "ppv-lite86" version = "0.2.21" @@ -1089,6 +1438,61 @@ dependencies = [ "syn", ] +[[package]] +name = "quinn" +version = "0.11.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b9e20a958963c291dc322d98411f541009df2ced7b5a4f2bd52337638cfccf20" +dependencies = [ + "bytes", + "cfg_aliases", + "pin-project-lite", + "quinn-proto", + "quinn-udp", + "rustc-hash", + "rustls", + "socket2", + "thiserror 2.0.17", + "tokio", + "tracing", + "web-time", +] + +[[package]] +name = "quinn-proto" +version = "0.11.13" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f1906b49b0c3bc04b5fe5d86a77925ae6524a19b816ae38ce1e426255f1d8a31" +dependencies = [ + "bytes", + "getrandom 0.3.4", + "lru-slab", + "rand 0.9.2", + "ring", + "rustc-hash", + "rustls", + "rustls-pki-types", + "slab", + "thiserror 2.0.17", + "tinyvec", + "tracing", + "web-time", +] + +[[package]] +name = "quinn-udp" +version = "0.5.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "addec6a0dcad8a8d96a771f815f0eaf55f9d1805756410b39f5fa81332574cbd" +dependencies = [ + "cfg_aliases", + "libc", + "once_cell", + "socket2", + "tracing", + "windows-sys 0.52.0", +] + [[package]] name = "quote" version = "1.0.40" @@ -1111,8 +1515,18 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "34af8d1a0e25924bc5b7c43c079c942339d8f0a8b57c39049bef581b46327404" dependencies = [ "libc", - "rand_chacha", - 
"rand_core", + "rand_chacha 0.3.1", + "rand_core 0.6.4", +] + +[[package]] +name = "rand" +version = "0.9.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6db2770f06117d490610c7488547d543617b21bfa07796d7a12f6f1bd53850d1" +dependencies = [ + "rand_chacha 0.9.0", + "rand_core 0.9.3", ] [[package]] @@ -1122,7 +1536,17 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e6c10a63a0fa32252be49d21e7709d4d4baf8d231c2dbce1eaa8141b9b127d88" dependencies = [ "ppv-lite86", - "rand_core", + "rand_core 0.6.4", +] + +[[package]] +name = "rand_chacha" +version = "0.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d3022b5f1df60f26e1ffddd6c66e8aa15de382ae63b3a0c1bfc0e4d3e3f325cb" +dependencies = [ + "ppv-lite86", + "rand_core 0.9.3", ] [[package]] @@ -1134,6 +1558,15 @@ dependencies = [ "getrandom 0.2.16", ] +[[package]] +name = "rand_core" +version = "0.9.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "99d9a13982dcf210057a8a78572b2217b667c3beacbf3a0d8b454f6f82837d38" +dependencies = [ + "getrandom 0.3.4", +] + [[package]] name = "redox_syscall" version = "0.5.12" @@ -1173,36 +1606,208 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7a2d987857b319362043e95f5353c0535c1f58eec5336fdfcf626430af7def58" [[package]] -name = "rustc-demangle" -version = "0.1.24" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "719b953e2095829ee67db738b3bfa9fa368c94900df327b3f07fe6e794d2fe1f" - -[[package]] -name = "rustix" -version = "1.1.2" +name = "reqwest" +version = "0.12.24" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cd15f8a2c5551a84d56efdc1cd049089e409ac19a3072d5037a17fd70719ff3e" +checksum = "9d0946410b9f7b082a427e4ef5c8ff541a88b357bc6c637c40db3a68ac70a36f" dependencies = [ - "bitflags", - "errno", - "libc", - "linux-raw-sys", - "windows-sys", -] - -[[package]] + "base64 
0.22.1", + "bytes", + "encoding_rs", + "futures-core", + "h2", + "http", + "http-body", + "http-body-util", + "hyper", + "hyper-rustls", + "hyper-tls", + "hyper-util", + "js-sys", + "log", + "mime", + "native-tls", + "percent-encoding", + "pin-project-lite", + "quinn", + "rustls", + "rustls-pki-types", + "serde", + "serde_json", + "serde_urlencoded", + "sync_wrapper", + "tokio", + "tokio-native-tls", + "tokio-rustls", + "tower 0.5.2", + "tower-http 0.6.7", + "tower-service", + "url", + "wasm-bindgen", + "wasm-bindgen-futures", + "web-sys", + "webpki-roots", +] + +[[package]] +name = "ring" +version = "0.17.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a4689e6c2294d81e88dc6261c768b63bc4fcdb852be6d1352498b114f61383b7" +dependencies = [ + "cc", + "cfg-if", + "getrandom 0.2.16", + "libc", + "untrusted", + "windows-sys 0.52.0", +] + +[[package]] +name = "rustc-demangle" +version = "0.1.24" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "719b953e2095829ee67db738b3bfa9fa368c94900df327b3f07fe6e794d2fe1f" + +[[package]] +name = "rustc-hash" +version = "2.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "357703d41365b4b27c590e3ed91eabb1b663f07c4c084095e60cbed4362dff0d" + +[[package]] +name = "rustix" +version = "1.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cd15f8a2c5551a84d56efdc1cd049089e409ac19a3072d5037a17fd70719ff3e" +dependencies = [ + "bitflags", + "errno", + "libc", + "linux-raw-sys", + "windows-sys 0.61.2", +] + +[[package]] +name = "rustls" +version = "0.23.35" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "533f54bc6a7d4f647e46ad909549eda97bf5afc1585190ef692b4286b198bd8f" +dependencies = [ + "log", + "once_cell", + "ring", + "rustls-pki-types", + "rustls-webpki", + "subtle", + "zeroize", +] + +[[package]] +name = "rustls-native-certs" +version = "0.8.1" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "7fcff2dd52b58a8d98a70243663a0d234c4e2b79235637849d15913394a247d3" +dependencies = [ + "openssl-probe", + "rustls-pki-types", + "schannel", + "security-framework 3.5.1", +] + +[[package]] +name = "rustls-pemfile" +version = "2.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dce314e5fee3f39953d46bb63bb8a46d40c2f8fb7cc5a3b6cab2bde9721d6e50" +dependencies = [ + "rustls-pki-types", +] + +[[package]] +name = "rustls-pki-types" +version = "1.13.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "708c0f9d5f54ba0272468c1d306a52c495b31fa155e91bc25371e6df7996908c" +dependencies = [ + "web-time", + "zeroize", +] + +[[package]] +name = "rustls-webpki" +version = "0.103.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2ffdfa2f5286e2247234e03f680868ac2815974dc39e00ea15adc445d0aafe52" +dependencies = [ + "ring", + "rustls-pki-types", + "untrusted", +] + +[[package]] name = "rustversion" version = "1.0.20" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "eded382c5f5f786b989652c49544c4877d9f015cc22e145a5ea8ea66c2921cd2" +[[package]] +name = "ryu" +version = "1.0.20" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "28d3b2b1366ec20994f1fd18c3c594f05c5dd4bc44d8bb0c1c632c8d6829481f" + +[[package]] +name = "schannel" +version = "0.1.28" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "891d81b926048e76efe18581bf793546b4c0eaf8448d72be8de2bbee5fd166e1" +dependencies = [ + "windows-sys 0.61.2", +] + [[package]] name = "scopeguard" version = "1.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49" +[[package]] +name = "security-framework" +version = "2.11.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"897b2245f0b511c87893af39b033e5ca9cce68824c4d7e7630b5a1d339658d02" +dependencies = [ + "bitflags", + "core-foundation 0.9.4", + "core-foundation-sys", + "libc", + "security-framework-sys", +] + +[[package]] +name = "security-framework" +version = "3.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b3297343eaf830f66ede390ea39da1d462b6b0c1b000f420d0a83f898bbbe6ef" +dependencies = [ + "bitflags", + "core-foundation 0.10.1", + "core-foundation-sys", + "libc", + "security-framework-sys", +] + +[[package]] +name = "security-framework-sys" +version = "2.15.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cc1f0cbffaac4852523ce30d8bd3c5cdc873501d96ff467ca09b6767bb8cd5c0" +dependencies = [ + "core-foundation-sys", + "libc", +] + [[package]] name = "serde" version = "1.0.219" @@ -1223,6 +1828,30 @@ dependencies = [ "syn", ] +[[package]] +name = "serde_json" +version = "1.0.143" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d401abef1d108fbd9cbaebc3e46611f4b1021f714a0597a71f41ee463f5f4a5a" +dependencies = [ + "itoa", + "memchr", + "ryu", + "serde", +] + +[[package]] +name = "serde_urlencoded" +version = "0.7.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d3491c14715ca2294c4d6a88f15e84739788c1d030eed8c110436aafdaa2f3fd" +dependencies = [ + "form_urlencoded", + "itoa", + "ryu", + "serde", +] + [[package]] name = "sharded-slab" version = "0.1.7" @@ -1269,9 +1898,21 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4f5fd57c80058a56cf5c777ab8a126398ece8e442983605d280a44ce79d0edef" dependencies = [ "libc", - "windows-sys", + "windows-sys 0.52.0", ] +[[package]] +name = "stable_deref_trait" +version = "1.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6ce2be8dc25455e1f91df71bfa12ad37d7af1092ae736f3a6cd0e37bc7810596" + +[[package]] +name = "subtle" +version = "2.6.1" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "13c2bddecc57b384dee18652358fb23172facb8a2c51ccc10d74c157bdea3292" + [[package]] name = "syn" version = "2.0.111" @@ -1288,6 +1929,41 @@ name = "sync_wrapper" version = "1.0.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0bf256ce5efdfa370213c1dabab5935a12e49f2c58d15e9eac2870d3b4f27263" +dependencies = [ + "futures-core", +] + +[[package]] +name = "synstructure" +version = "0.13.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "728a70f3dbaf5bab7f0c4b1ac8d7ae5ea60a4b5549c8a5914361c99147a709d2" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "system-configuration" +version = "0.6.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3c879d448e9d986b661742763247d3693ed13609438cf3d006f51f5368a5ba6b" +dependencies = [ + "bitflags", + "core-foundation 0.9.4", + "system-configuration-sys", +] + +[[package]] +name = "system-configuration-sys" +version = "0.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8e1d1b10ced5ca923a1fcb8d03e96b8d3268065d724548c0211415ff6ac6bac4" +dependencies = [ + "core-foundation-sys", + "libc", +] [[package]] name = "target-lexicon" @@ -1305,7 +1981,7 @@ dependencies = [ "getrandom 0.3.4", "once_cell", "rustix", - "windows-sys", + "windows-sys 0.61.2", ] [[package]] @@ -1314,7 +1990,16 @@ version = "1.0.69" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b6aaf5339b578ea85b50e080feb250a3e8ae8cfcdff9a461c9ec2904bc923f52" dependencies = [ - "thiserror-impl", + "thiserror-impl 1.0.69", +] + +[[package]] +name = "thiserror" +version = "2.0.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f63587ca0f12b72a0600bcba1d40081f830876000bb46dd2337a3051618f4fc8" +dependencies = [ + "thiserror-impl 2.0.17", ] [[package]] @@ -1328,6 +2013,17 @@ dependencies = [ "syn", ] 
+[[package]] +name = "thiserror-impl" +version = "2.0.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3ff15c8ecd7de3849db632e14d18d2571fa09dfc5ed93479bc4485c7a517c913" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "thread_local" version = "1.1.8" @@ -1338,6 +2034,31 @@ dependencies = [ "once_cell", ] +[[package]] +name = "tinystr" +version = "0.8.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "42d3e9c45c09de15d06dd8acf5f4e0e399e85927b7f00711024eb7ae10fa4869" +dependencies = [ + "displaydoc", + "zerovec", +] + +[[package]] +name = "tinyvec" +version = "1.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bfa5fdc3bce6191a1dbc8c02d5c8bffcf557bafa17c124c5264a458f1b0613fa" +dependencies = [ + "tinyvec_macros", +] + +[[package]] +name = "tinyvec_macros" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1f3ccbac311fea05f86f61904b462b55fb3df8837a366dfc601a0161d0532f20" + [[package]] name = "tokio" version = "1.44.2" @@ -1353,7 +2074,7 @@ dependencies = [ "signal-hook-registry", "socket2", "tokio-macros", - "windows-sys", + "windows-sys 0.52.0", ] [[package]] @@ -1367,6 +2088,26 @@ dependencies = [ "syn", ] +[[package]] +name = "tokio-native-tls" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bbae76ab933c85776efabc971569dd6119c580d8f5d448769dec1764bf796ef2" +dependencies = [ + "native-tls", + "tokio", +] + +[[package]] +name = "tokio-rustls" +version = "0.26.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1729aa945f29d91ba541258c8df89027d5792d85a8841fb65e8bf0f4ede4ef61" +dependencies = [ + "rustls", + "tokio", +] + [[package]] name = "tokio-stream" version = "0.1.17" @@ -1412,8 +2153,11 @@ dependencies = [ "percent-encoding", "pin-project", "prost 0.13.5", + "rustls-native-certs", + "rustls-pemfile", 
"socket2", "tokio", + "tokio-rustls", "tokio-stream", "tower 0.4.13", "tower-layer", @@ -1432,7 +2176,7 @@ dependencies = [ "indexmap 1.9.3", "pin-project", "pin-project-lite", - "rand", + "rand 0.8.5", "slab", "tokio", "tokio-util", @@ -1451,6 +2195,7 @@ dependencies = [ "futures-util", "pin-project-lite", "sync_wrapper", + "tokio", "tower-layer", "tower-service", ] @@ -1472,6 +2217,24 @@ dependencies = [ "tracing", ] +[[package]] +name = "tower-http" +version = "0.6.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9cf146f99d442e8e68e585f5d798ccd3cad9a7835b917e09728880a862706456" +dependencies = [ + "bitflags", + "bytes", + "futures-util", + "http", + "http-body", + "iri-string", + "pin-project-lite", + "tower 0.5.2", + "tower-layer", + "tower-service", +] + [[package]] name = "tower-layer" version = "0.3.3" @@ -1560,12 +2323,47 @@ version = "0.2.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7264e107f553ccae879d21fbea1d6724ac785e8c3bfc762137959b5802826ef3" +[[package]] +name = "untrusted" +version = "0.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8ecb6da28b8a351d773b68d5825ac39017e680750f980f3a1a85cd8dd28a47c1" + +[[package]] +name = "url" +version = "2.5.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "32f8b686cadd1473f4bd0117a5d28d36b1ade384ea9b5069a1c40aefed7fda60" +dependencies = [ + "form_urlencoded", + "idna", + "percent-encoding", +] + +[[package]] +name = "urlencoding" +version = "2.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "daf8dba3b7eb870caf1ddeed7bc9d2a049f3cfdfae7cb521b087cc33ae4c49da" + +[[package]] +name = "utf8_iter" +version = "1.0.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b6c140620e7ffbb22c2dee59cafe6084a59b5ffc27a8859a5f0d494b5d52b6be" + [[package]] name = "valuable" version = "0.1.1" source = 
"registry+https://github.com/rust-lang/crates.io-index" checksum = "ba73ea9cf16a25df0c8caa16c51acb937d5712a8429db78a3ee29d5dcacd3a65" +[[package]] +name = "vcpkg" +version = "0.2.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "accd4ea62f7bb7a82fe23066fb0957d48ef677f6eeb8215f372f52e48bb32426" + [[package]] name = "want" version = "0.3.1" @@ -1590,6 +2388,93 @@ dependencies = [ "wit-bindgen", ] +[[package]] +name = "wasm-bindgen" +version = "0.2.106" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0d759f433fa64a2d763d1340820e46e111a7a5ab75f993d1852d70b03dbb80fd" +dependencies = [ + "cfg-if", + "once_cell", + "rustversion", + "wasm-bindgen-macro", + "wasm-bindgen-shared", +] + +[[package]] +name = "wasm-bindgen-futures" +version = "0.4.56" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "836d9622d604feee9e5de25ac10e3ea5f2d65b41eac0d9ce72eb5deae707ce7c" +dependencies = [ + "cfg-if", + "js-sys", + "once_cell", + "wasm-bindgen", + "web-sys", +] + +[[package]] +name = "wasm-bindgen-macro" +version = "0.2.106" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "48cb0d2638f8baedbc542ed444afc0644a29166f1595371af4fecf8ce1e7eeb3" +dependencies = [ + "quote", + "wasm-bindgen-macro-support", +] + +[[package]] +name = "wasm-bindgen-macro-support" +version = "0.2.106" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cefb59d5cd5f92d9dcf80e4683949f15ca4b511f4ac0a6e14d4e1ac60c6ecd40" +dependencies = [ + "bumpalo", + "proc-macro2", + "quote", + "syn", + "wasm-bindgen-shared", +] + +[[package]] +name = "wasm-bindgen-shared" +version = "0.2.106" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cbc538057e648b67f72a982e708d485b2efa771e1ac05fec311f9f63e5800db4" +dependencies = [ + "unicode-ident", +] + +[[package]] +name = "web-sys" +version = "0.3.83" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "9b32828d774c412041098d182a8b38b16ea816958e07cf40eec2bc080ae137ac" +dependencies = [ + "js-sys", + "wasm-bindgen", +] + +[[package]] +name = "web-time" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5a6580f308b1fad9207618087a65c04e7a10bc77e02c8e84e9b00dd4b12fa0bb" +dependencies = [ + "js-sys", + "wasm-bindgen", +] + +[[package]] +name = "webpki-roots" +version = "1.0.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b2878ef029c47c6e8cf779119f20fcf52bde7ad42a731b2a304bc221df17571e" +dependencies = [ + "rustls-pki-types", +] + [[package]] name = "winapi" version = "0.3.9" @@ -1612,6 +2497,47 @@ version = "0.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "712e227841d057c1ee1cd2fb22fa7e5a5461ae8e48fa2ca79ec42cfc1931183f" +[[package]] +name = "windows-link" +version = "0.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5e6ad25900d524eaabdbbb96d20b4311e1e7ae1699af4fb28c17ae66c80d798a" + +[[package]] +name = "windows-link" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f0805222e57f7521d6a62e36fa9163bc891acd422f971defe97d64e70d0a4fe5" + +[[package]] +name = "windows-registry" +version = "0.5.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5b8a9ed28765efc97bbc954883f4e6796c33a06546ebafacbabee9696967499e" +dependencies = [ + "windows-link 0.1.3", + "windows-result", + "windows-strings", +] + +[[package]] +name = "windows-result" +version = "0.3.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "56f42bd332cc6c8eac5af113fc0c1fd6a8fd2aa08a0119358686e5160d0586c6" +dependencies = [ + "windows-link 0.1.3", +] + +[[package]] +name = "windows-strings" +version = "0.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"56e6c93f3a0c3b36176cb1327a4958a0353d5d166c2a35cb268ace15e91d3b57" +dependencies = [ + "windows-link 0.1.3", +] + [[package]] name = "windows-sys" version = "0.52.0" @@ -1621,6 +2547,15 @@ dependencies = [ "windows-targets", ] +[[package]] +name = "windows-sys" +version = "0.61.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ae137229bcbd6cdf0f7b80a31df61766145077ddf49416a728b02cb3921ff3fc" +dependencies = [ + "windows-link 0.2.1", +] + [[package]] name = "windows-targets" version = "0.52.6" @@ -1691,6 +2626,35 @@ version = "0.46.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f17a85883d4e6d00e8a97c586de764dabcc06133f7f1d55dce5cdc070ad7fe59" +[[package]] +name = "writeable" +version = "0.6.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9edde0db4769d2dc68579893f2306b26c6ecfbe0ef499b013d731b7b9247e0b9" + +[[package]] +name = "yoke" +version = "0.8.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "72d6e5c6afb84d73944e5cedb052c4680d5657337201555f9f2a16b7406d4954" +dependencies = [ + "stable_deref_trait", + "yoke-derive", + "zerofrom", +] + +[[package]] +name = "yoke-derive" +version = "0.8.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b659052874eb698efe5b9e8cf382204678a0086ebf46982b79d6ca3182927e5d" +dependencies = [ + "proc-macro2", + "quote", + "syn", + "synstructure", +] + [[package]] name = "zerocopy" version = "0.8.25" @@ -1710,3 +2674,63 @@ dependencies = [ "quote", "syn", ] + +[[package]] +name = "zerofrom" +version = "0.1.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "50cc42e0333e05660c3587f3bf9d0478688e15d870fab3346451ce7f8c9fbea5" +dependencies = [ + "zerofrom-derive", +] + +[[package]] +name = "zerofrom-derive" +version = "0.1.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"d71e5d6e06ab090c67b5e44993ec16b72dcbaabc526db883a360057678b48502" +dependencies = [ + "proc-macro2", + "quote", + "syn", + "synstructure", +] + +[[package]] +name = "zeroize" +version = "1.8.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b97154e67e32c85465826e8bcc1c59429aaaf107c1e4a9e53c8d8ccd5eff88d0" + +[[package]] +name = "zerotrie" +version = "0.2.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2a59c17a5562d507e4b54960e8569ebee33bee890c70aa3fe7b97e85a9fd7851" +dependencies = [ + "displaydoc", + "yoke", + "zerofrom", +] + +[[package]] +name = "zerovec" +version = "0.11.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6c28719294829477f525be0186d13efa9a3c602f7ec202ca9e353d310fb9a002" +dependencies = [ + "yoke", + "zerofrom", + "zerovec-derive", +] + +[[package]] +name = "zerovec-derive" +version = "0.11.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "eadce39539ca5cb3985590102671f2567e659fca9666581ad3411d59207951f3" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] diff --git a/rs_controller/Cargo.toml b/rs_controller/Cargo.toml index d1d508e4f..7be07429a 100644 --- a/rs_controller/Cargo.toml +++ b/rs_controller/Cargo.toml @@ -6,16 +6,16 @@ edition = "2021" [lib] name = "flyte_controller_base" path = "src/lib.rs" -crate-type = ["cdylib"] +crate-type = ["cdylib", "rlib"] [dependencies] # cloudidl = { path = "../../cloud/gen/pb_rust" } pyo3 = { version = "0.24", features = ["extension-module", "abi3-py310"] } pyo3-async-runtimes = { version = "0.24", features = ["tokio-runtime"] } tokio = { version = "1.0", features = ["full"] } -tonic = "0.12" -prost = { version = "0.13.5", features = ["std"] } -prost-types = { version = "0.12", features = ["std"] } +tonic = { version = "0.12", features = ["tls", "tls-native-roots"] } +prost = { version = "0.13", features = ["std"] } +prost-types = { version = "0.13", features = ["std"] } 
futures = "0.3" tower = "0.4" tower-http = { version = "0.5", features = ["trace"] } @@ -25,6 +25,15 @@ async-trait = "0.1" thiserror = "1.0" pyo3-build-config = "0.24.2" flyteidl2 = "=2.0.0-alpha14" +reqwest = { version = "0.12", features = ["json", "rustls-tls"] } +serde = { version = "1.0", features = ["derive"] } +serde_json = "1.0" +base64 = "0.22" +url = "2.5" +http = "1.0" +urlencoding = "2.1" +bytes = "1.0" +http-body-util = "0.1" [build-dependencies] pyo3 = { version = "0.24", features = ["extension-module", "abi3-py310"] } diff --git a/rs_controller/Makefile b/rs_controller/Makefile index 5ef3427a9..ee3ed9401 100644 --- a/rs_controller/Makefile +++ b/rs_controller/Makefile @@ -16,6 +16,8 @@ build-amd64: CARGO_CACHE_DIR := docker_cargo_cache DIST_DIRS := dist +# Derive version from git tags, convert to PEP 440 format +VERSION := $(shell git describe --tags --always --dirty 2>/dev/null | sed -E 's/^v//; s/(-[0-9]+)-g([0-9a-f]+)/.dev\1+g\2/; s/\.dev-/\.dev/; s/-dirty/.dirty/') dist-dirs: mkdir -p $(DIST_DIRS) $(CARGO_CACHE_DIR) @@ -27,6 +29,7 @@ docker run --rm \ -v $(CLOUD_REPO):/cloud \ wheel-builder:$(1) /bin/bash -c "\ cd /io; \ + sed -i 's/^version = .*/version = \"$(VERSION)\"/' pyproject.toml; \ /opt/python/cp310-cp310/bin/maturin build --release --find-interpreter --out /io/dist/ \ " endef @@ -39,10 +42,14 @@ build-wheels-arm64: dist-dirs build-wheels-amd64: dist-dirs $(call BUILD_WHEELS_RECIPE,amd64) -build-wheels: build-wheels-arm64 build-wheels-amd64 +clean_dist: + rm -rf $(DIST_DIRS)/*whl +build-wheels: clean_dist build-wheels-arm64 build-wheels-amd64 # This is for Mac users, since the other targets won't build macos wheels (only local arch so probably arm64) build-wheel-local: dist-dirs + @echo "Building version: $(VERSION)" + sed -i.bak 's/^version = .*/version = "$(VERSION)"/' pyproject.toml && rm pyproject.toml.bak python -m build --wheel --outdir dist diff --git a/rs_controller/build.rs b/rs_controller/build.rs index 40462f970..5c7492443 
100644 --- a/rs_controller/build.rs +++ b/rs_controller/build.rs @@ -1,4 +1,4 @@ // build.rs fn main() { // pyo3_build_config::use_pyo3_cfgs(); -} \ No newline at end of file +} diff --git a/rs_controller/pyproject.toml b/rs_controller/pyproject.toml index 1e768aa67..f572ca789 100644 --- a/rs_controller/pyproject.toml +++ b/rs_controller/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "maturin" [project] name = "flyte_controller_base" -version = "0.1.0" +version = "2.0.0b33.dev33+g3d028ba" description = "Rust controller for Union" requires-python = ">=3.10" classifiers = ["Programming Language :: Python", "Programming Language :: Rust"] @@ -12,3 +12,7 @@ classifiers = ["Programming Language :: Python", "Programming Language :: Rust"] [tool.maturin] module-name = "flyte_controller_base" features = ["pyo3/extension-module"] + +[tool.ruff] +line-length = 120 +ignore = ["E501"] diff --git a/rs_controller/src/action.rs b/rs_controller/src/action.rs index 795156390..f2e9095d4 100644 --- a/rs_controller/src/action.rs +++ b/rs_controller/src/action.rs @@ -151,11 +151,16 @@ impl Action { queue: Option, ) -> PyResult { // Deserialize bytes to Rust protobuf types since Python and Rust have different generated protobufs - let sub_action_id = ActionIdentifier::decode(sub_action_id_bytes) - .map_err(|e| pyo3::exceptions::PyValueError::new_err(format!("Failed to decode ActionIdentifier: {}", e)))?; + let sub_action_id = ActionIdentifier::decode(sub_action_id_bytes).map_err(|e| { + pyo3::exceptions::PyValueError::new_err(format!( + "Failed to decode ActionIdentifier: {}", + e + )) + })?; - let task_spec = TaskSpec::decode(task_spec_bytes) - .map_err(|e| pyo3::exceptions::PyValueError::new_err(format!("Failed to decode TaskSpec: {}", e)))?; + let task_spec = TaskSpec::decode(task_spec_bytes).map_err(|e| { + pyo3::exceptions::PyValueError::new_err(format!("Failed to decode TaskSpec: {}", e)) + })?; debug!("Creating Action from task for ID {:?}", &sub_action_id); Ok(Action { @@ 
-198,12 +203,20 @@ impl Action { typed_interface_bytes: Option<&[u8]>, ) -> PyResult { // Deserialize bytes to Rust protobuf types - let action_id = ActionIdentifier::decode(action_id_bytes) - .map_err(|e| pyo3::exceptions::PyValueError::new_err(format!("Failed to decode ActionIdentifier: {}", e)))?; + let action_id = ActionIdentifier::decode(action_id_bytes).map_err(|e| { + pyo3::exceptions::PyValueError::new_err(format!( + "Failed to decode ActionIdentifier: {}", + e + )) + })?; let typed_interface = if let Some(bytes) = typed_interface_bytes { - Some(TypedInterface::decode(bytes) - .map_err(|e| pyo3::exceptions::PyValueError::new_err(format!("Failed to decode TypedInterface: {}", e)))?) + Some(TypedInterface::decode(bytes).map_err(|e| { + pyo3::exceptions::PyValueError::new_err(format!( + "Failed to decode TypedInterface: {}", + e + )) + })?) } else { None }; @@ -277,4 +290,28 @@ impl Action { fn has_error(&self) -> bool { self.err.is_some() || self.client_err.is_some() } + + /// Get action_id as serialized bytes for Python interop + #[getter] + fn action_id_bytes(&self) -> PyResult> { + Ok(self.action_id.encode_to_vec()) + } + + /// Get err as serialized bytes for Python interop (returns None if no error) + #[getter] + fn err_bytes(&self) -> Option> { + self.err.as_ref().map(|e| e.encode_to_vec()) + } + + /// Get task as serialized bytes for Python interop (returns None if no task) + #[getter] + fn task_bytes(&self) -> Option> { + self.task.as_ref().map(|t| t.encode_to_vec()) + } + + /// Get phase as i32 for Python interop (returns None if no phase) + #[getter] + fn phase_value(&self) -> Option { + self.phase.map(|p| p as i32) + } } diff --git a/rs_controller/src/auth/client_credentials.rs b/rs_controller/src/auth/client_credentials.rs new file mode 100644 index 000000000..c0795571e --- /dev/null +++ b/rs_controller/src/auth/client_credentials.rs @@ -0,0 +1,205 @@ +use std::sync::Arc; +use std::time::{Duration, SystemTime}; +use tokio::sync::RwLock; +use 
tonic::transport::Channel; +use tracing::{debug, info}; + +use super::config::{AuthConfig, ClientConfigExt}; +use super::errors::TokenError; +use super::token_client::{self, GrantType, TokenResponse}; +use crate::proto::{ + AuthMetadataServiceClient, OAuth2MetadataRequest, OAuth2MetadataResponse, + PublicClientAuthConfigRequest, PublicClientAuthConfigResponse, +}; + +/// Stored credentials with expiration tracking +#[derive(Debug, Clone)] +pub struct Credentials { + pub access_token: String, + pub refresh_token: Option, + pub expires_at: SystemTime, +} + +impl Credentials { + fn from_token_response(response: TokenResponse) -> Self { + let expires_at = SystemTime::now() + Duration::from_secs(response.expires_in as u64); + Self { + access_token: response.access_token, + refresh_token: response.refresh_token, + expires_at, + } + } + + fn is_expired(&self) -> bool { + // Consider expired if less than 60 seconds remaining + SystemTime::now() + Duration::from_secs(60) >= self.expires_at + } +} + +/// Client credentials authenticator +#[derive(Debug)] +pub struct ClientCredentialsAuthenticator { + config: AuthConfig, + credentials: Arc>>, + client_config: Arc>>, + oauth2_metadata: Arc>>, +} + +impl ClientCredentialsAuthenticator { + pub fn new(config: AuthConfig) -> Self { + Self { + config, + credentials: Arc::new(RwLock::new(None)), + client_config: Arc::new(RwLock::new(None)), + oauth2_metadata: Arc::new(RwLock::new(None)), + } + } + + /// Get the client configuration from the auth metadata service + async fn fetch_client_config( + &self, + channel: Channel, + ) -> Result { + let mut client = AuthMetadataServiceClient::new(channel.clone()); + let request = tonic::Request::new(PublicClientAuthConfigRequest {}); + + let response = client + .get_public_client_config(request) + .await + .map_err(|e| TokenError::AuthError(format!("Failed to get client config: {}", e)))?; + + Ok(response.into_inner()) + } + + /// Get the OAuth2 metadata from the auth metadata service + async 
fn fetch_oauth2_metadata( + &self, + channel: Channel, + ) -> Result { + let mut client = AuthMetadataServiceClient::new(channel); + let request = tonic::Request::new(OAuth2MetadataRequest {}); + + let response = client + .get_o_auth2_metadata(request) + .await + .map_err(|e| TokenError::AuthError(format!("Failed to get OAuth2 metadata: {}", e)))?; + + Ok(response.into_inner()) + } + + /// Refresh credentials using client credentials flow + async fn refresh_credentials_internal( + &self, + channel: Channel, + ) -> Result { + tracing::info!("🔄 refresh_credentials_internal: Starting..."); + // First, get the client config if we don't have it (cached) + let client_config = { + let config_lock = self.client_config.read().await; + if let Some(cfg) = config_lock.as_ref() { + tracing::info!("🔄 Using cached client_config"); + cfg.clone() + } else { + drop(config_lock); + tracing::info!("🔄 Fetching client_config from auth service..."); + let cfg = self.fetch_client_config(channel.clone()).await?; + tracing::info!("🔄 Got client_config response"); + let mut config_lock = self.client_config.write().await; + *config_lock = Some(cfg.clone()); + cfg + } + }; + + // Get OAuth2 metadata to find the token endpoint (cached) + let oauth2_metadata = { + let metadata_lock = self.oauth2_metadata.read().await; + if let Some(metadata) = metadata_lock.as_ref() { + metadata.clone() + } else { + drop(metadata_lock); + let metadata = self.fetch_oauth2_metadata(channel).await?; + let mut metadata_lock = self.oauth2_metadata.write().await; + *metadata_lock = Some(metadata.clone()); + metadata + } + }; + + debug!( + "Client credentials flow with client_id: {}", + self.config.client_id + ); + + // Request the token + let token_response = token_client::get_token( + &oauth2_metadata.token_endpoint, + &self.config.client_id, + &self.config.client_secret, + Some(client_config.scopes.as_slice()), + Some(client_config.audience.as_str()), + GrantType::ClientCredentials, + ) + .await?; + + info!( + 
"Retrieved new token, expires in {} seconds", + token_response.expires_in + ); + + Ok(Credentials::from_token_response(token_response)) + } + + /// Get current credentials, refreshing if necessary + pub async fn get_credentials(&self, channel: Channel) -> Result { + tracing::info!("🔐 get_credentials: Starting..."); + // Check if we have valid credentials + { + tracing::info!("🔐 get_credentials: Acquiring read lock..."); + let creds_lock = self.credentials.read().await; + tracing::info!("🔐 get_credentials: Got read lock"); + if let Some(creds) = creds_lock.as_ref() { + if !creds.is_expired() { + return Ok(creds.clone()); + } + } + } + tracing::info!("🔐 get_credentials: Need to refresh, acquiring write lock..."); + + // Need to refresh - acquire write lock + let mut creds_lock = self.credentials.write().await; + tracing::info!( + "🔐 get_credentials: Got write lock, calling refresh_credentials_internal..." + ); + + // Double-check after acquiring write lock (another thread might have refreshed) + if let Some(creds) = creds_lock.as_ref() { + if !creds.is_expired() { + return Ok(creds.clone()); + } + } + + // Refresh the credentials + let new_creds = self.refresh_credentials_internal(channel).await?; + *creds_lock = Some(new_creds.clone()); + + Ok(new_creds) + } + + /// Force a refresh of credentials + pub async fn refresh_credentials(&self, channel: Channel) -> Result<(), TokenError> { + let new_creds = self.refresh_credentials_internal(channel).await?; + let mut creds_lock = self.credentials.write().await; + *creds_lock = Some(new_creds); + Ok(()) + } + + /// Get the header key to use for authentication + pub async fn get_header_key(&self) -> String { + let config_lock = self.client_config.read().await; + if let Some(cfg) = config_lock.as_ref() { + // get rid of this + cfg.header_key().to_string() + } else { + "authorization".to_string() + } + } +} diff --git a/rs_controller/src/auth/config.rs b/rs_controller/src/auth/config.rs new file mode 100644 index 
000000000..3c9046355 --- /dev/null +++ b/rs_controller/src/auth/config.rs @@ -0,0 +1,55 @@ +use crate::auth::errors::AuthConfigError; +use base64::{engine, Engine}; + +/// Configuration for authentication +#[derive(Debug, Clone)] +pub struct AuthConfig { + pub endpoint: String, + pub client_id: String, + pub client_secret: String, +} + +/// Extension trait to add helper methods to the proto-generated PublicClientAuthConfigResponse +pub trait ClientConfigExt { + fn header_key(&self) -> &str; +} + +// todo: get rid of this +impl ClientConfigExt for crate::proto::PublicClientAuthConfigResponse { + fn header_key(&self) -> &str { + if self.authorization_metadata_key.is_empty() { + "authorization" + } else { + &self.authorization_metadata_key + } + } +} + +impl AuthConfig { + pub fn new_from_api_key(api_key: &str) -> Result { + let decoded = engine::general_purpose::STANDARD.decode(api_key)?; + let api_key_str = String::from_utf8(decoded)?; + let split: Vec<_> = api_key_str.split(':').collect(); + + if split.len() != 4 { + return Err(AuthConfigError::InvalidFormat(split.len())); + } + + let parts: [String; 4] = split + .into_iter() + .map(String::from) + .collect::>() + .try_into() + .unwrap(); + let [endpoint, client_id, client_secret, _org] = parts; + + // the api key comes back just with the domain, we add https:// to it for rust rather than dns:/// + let endpoint = "https://".to_owned() + &endpoint; + + Ok(AuthConfig { + endpoint, + client_id, + client_secret, + }) + } +} diff --git a/rs_controller/src/auth/errors.rs b/rs_controller/src/auth/errors.rs new file mode 100644 index 000000000..c73d9c6e8 --- /dev/null +++ b/rs_controller/src/auth/errors.rs @@ -0,0 +1,23 @@ +use thiserror::Error; + +#[derive(Error, Debug)] +pub enum AuthConfigError { + #[error("Failed to decode base64: {0}")] + Base64DecodeError(#[from] base64::DecodeError), + + #[error("Invalid API key format: expected 4 colon-separated parts, got {0}")] + InvalidFormat(usize), + + #[error("Invalid UTF-8 
in decoded API key")] + InvalidUtf8(#[from] std::string::FromUtf8Error), +} + +#[derive(Error, Debug)] +pub enum TokenError { + #[error("HTTP error: {0}")] + HttpError(#[from] reqwest::Error), + #[error("Authentication error: {0}")] + AuthError(String), + #[error("JSON error: {0}")] + JsonError(#[from] serde_json::Error), +} diff --git a/rs_controller/src/auth/middleware.rs b/rs_controller/src/auth/middleware.rs new file mode 100644 index 000000000..8aca41bb3 --- /dev/null +++ b/rs_controller/src/auth/middleware.rs @@ -0,0 +1,125 @@ +use std::sync::Arc; +use std::task::{Context, Poll}; +use tonic::body::BoxBody; +use tonic::transport::Channel; +use tower::{Layer, Service, ServiceExt}; +use tracing::{error, warn}; + +use super::client_credentials::ClientCredentialsAuthenticator; + +/// Tower layer that adds authentication to gRPC requests +#[derive(Clone)] +pub struct AuthLayer { + authenticator: Arc, + channel: Channel, +} + +impl AuthLayer { + pub fn new(authenticator: Arc, channel: Channel) -> Self { + Self { + authenticator, + channel, + } + } +} + +impl Layer for AuthLayer { + type Service = AuthService; + + fn layer(&self, inner: S) -> Self::Service { + AuthService { + inner, + authenticator: self.authenticator.clone(), + channel: self.channel.clone(), + } + } +} + +/// Tower service that intercepts requests to add authentication +#[derive(Clone)] +pub struct AuthService { + inner: S, + authenticator: Arc, + channel: Channel, +} + +impl std::fmt::Debug for AuthService { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("AuthService") + .field("inner", &self.inner) + .field("authenticator", &self.authenticator) + .field("channel", &"") + .finish() + } +} + +impl Service> for AuthService +where + S: Service, Response = http::Response> + Clone + Send + 'static, + S::Future: Send + 'static, + S::Error: Into> + Send, +{ + type Response = S::Response; + type Error = S::Error; + type Future = std::pin::Pin< + Box> + Send>, + >; + 
+ fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { + self.inner.poll_ready(cx) + } + + fn call(&mut self, mut request: http::Request) -> Self::Future { + let authenticator = self.authenticator.clone(); + let channel = self.channel.clone(); + let mut inner = self.inner.clone(); + + Box::pin(async move { + // Get credentials and add auth header + match authenticator.get_credentials(channel.clone()).await { + Ok(creds) => { + let header_key = authenticator.get_header_key().await; + let token_value = format!("Bearer {}", creds.access_token); + + warn!("Adding auth header: {}", header_key); + + // Insert the authorization header + if let Ok(header_value) = token_value.parse::() { + request + .headers_mut() + .insert(http::header::AUTHORIZATION, header_value); + } else { + warn!("Failed to parse auth token as header value"); + } + } + Err(e) => { + warn!("Failed to get credentials: {}", e); + // Continue without auth - let the server reject it + } + } + + if let Err(e) = inner.ready().await { + error!("Inner service failed to become ready!!!"); + // Return the error from the inner service's ready check + return Err(e); + } + + // Make the request + let result = inner.call(request).await; + + // Check for 401/Unauthenticated and refresh credentials for next time + if let Ok(ref response) = result { + if response.status() == http::StatusCode::UNAUTHORIZED { + warn!("Got 401, refreshing credentials for next request"); + + // Refresh credentials in background so next request will have fresh creds + if let Err(e) = authenticator.refresh_credentials(channel.clone()).await { + warn!("Failed to refresh credentials: {}", e); + } + } + } + + result + }) + } +} diff --git a/rs_controller/src/auth/mod.rs b/rs_controller/src/auth/mod.rs new file mode 100644 index 000000000..0418ad58d --- /dev/null +++ b/rs_controller/src/auth/mod.rs @@ -0,0 +1,11 @@ +mod client_credentials; +mod config; +mod errors; +mod middleware; +mod token_client; + +pub use 
client_credentials::{ClientCredentialsAuthenticator, Credentials}; +pub use config::{AuthConfig, ClientConfigExt}; +pub use errors::{AuthConfigError, TokenError}; +pub use middleware::{AuthLayer, AuthService}; +pub use token_client::{get_token, GrantType}; diff --git a/rs_controller/src/auth/token_client.rs b/rs_controller/src/auth/token_client.rs new file mode 100644 index 000000000..83477f3e7 --- /dev/null +++ b/rs_controller/src/auth/token_client.rs @@ -0,0 +1,96 @@ +use crate::auth::errors::TokenError; +use base64::{engine::general_purpose, Engine as _}; +use reqwest; +use serde::{Deserialize, Serialize}; +use std::collections::HashMap; +use tracing::debug; + +#[derive(Debug, Clone, Copy)] +pub enum GrantType { + ClientCredentials, +} + +impl GrantType { + fn as_str(&self) -> &'static str { + match self { + GrantType::ClientCredentials => "client_credentials", + } + } +} + +#[derive(Debug, Deserialize, Serialize)] +pub struct TokenResponse { + pub access_token: String, + #[serde(default)] + pub refresh_token: Option, + pub expires_in: i64, + #[serde(default)] + pub token_type: String, +} + +/// Creates a basic authorization header from client ID and secret +pub fn get_basic_authorization_header(client_id: &str, client_secret: &str) -> String { + let encoded_secret = urlencoding::encode(client_secret); + let concatenated = format!("{}:{}", client_id, encoded_secret); + let encoded = general_purpose::STANDARD.encode(concatenated.as_bytes()); + format!("Basic {}", encoded) +} + +/// Retrieves an access token from the token endpoint +pub async fn get_token( + token_endpoint: &str, + client_id: &str, + client_secret: &str, + scopes: Option<&[String]>, + audience: Option<&str>, + grant_type: GrantType, +) -> Result { + let client = reqwest::Client::new(); + + let authorization_header = get_basic_authorization_header(client_id, client_secret); + + let mut body = HashMap::new(); + body.insert("grant_type", grant_type.as_str().to_string()); + + if let Some(scopes) = 
scopes { + let scope_str = scopes.join(" "); + body.insert("scope", scope_str); + } + + if let Some(aud) = audience { + body.insert("audience", aud.to_string()); + } + + debug!( + "Requesting token from {} with grant_type {}", + token_endpoint, + grant_type.as_str() + ); + + let response = client + .post(token_endpoint) + .header("Authorization", authorization_header) + .header("Cache-Control", "no-cache") + .header("Accept", "application/json") + .header("Content-Type", "application/x-www-form-urlencoded") + .form(&body) + .send() + .await?; + + if !response.status().is_success() { + let status = response.status(); + let error_text = response.text().await.unwrap_or_default(); + return Err(TokenError::AuthError(format!( + "Token request failed with status {}: {}", + status, error_text + ))); + } + + let token_response: TokenResponse = response.json().await?; + debug!( + "Retrieved new token, expires in {} seconds", + token_response.expires_in + ); + + Ok(token_response) +} diff --git a/rs_controller/src/bin/test_auth.rs b/rs_controller/src/bin/test_auth.rs new file mode 100644 index 000000000..183481e6c --- /dev/null +++ b/rs_controller/src/bin/test_auth.rs @@ -0,0 +1,65 @@ +/// Standalone authentication test binary +/// +/// Usage: +/// FLYTE_ENDPOINT=dns:///your-endpoint:443 \ +/// FLYTE_CLIENT_ID=your_id \ +/// FLYTE_CLIENT_SECRET=your_secret \ +/// cargo run --bin test_auth +use flyte_controller_base::auth::{AuthConfig, ClientCredentialsAuthenticator}; +use std::env; +use std::sync::Arc; +use tonic::transport::Endpoint; +use tracing_subscriber; + +#[tokio::main] +async fn main() -> Result<(), Box> { + tracing_subscriber::fmt() + .with_max_level(tracing::Level::INFO) + .init(); + + println!("=== Flyte Client Credentials Authentication Test ===\n"); + + let endpoint = + env::var("FLYTE_ENDPOINT").unwrap_or_else(|_| "dns:///localhost:8089".to_string()); + let client_id = env::var("FLYTE_CLIENT_ID").expect("FLYTE_CLIENT_ID must be set"); + let client_secret = 
env::var("FLYTE_CLIENT_SECRET").expect("FLYTE_CLIENT_SECRET must be set"); + + println!("Endpoint: {}", endpoint); + println!("Client ID: {}\n", client_id); + + let auth_config = AuthConfig { + endpoint: endpoint.clone(), + client_id, + client_secret, + }; + + let authenticator = Arc::new(ClientCredentialsAuthenticator::new(auth_config)); + + println!("Connecting to endpoint..."); + let channel = Endpoint::from_shared(endpoint)?.connect().await?; + println!("✓ Connected\n"); + + println!("Fetching OAuth2 metadata and retrieving access token..."); + match authenticator.get_credentials(channel.clone()).await { + Ok(creds) => { + println!("✓ Successfully obtained access token!"); + let preview = &creds.access_token[..20.min(creds.access_token.len())]; + println!(" Token (first 20 chars): {}...", preview); + println!(" Expires at: {:?}\n", creds.expires_at); + + println!("Testing cached credential retrieval..."); + match authenticator.get_credentials(channel).await { + Ok(_) => println!("✓ Successfully retrieved cached credentials\n"), + Err(e) => eprintln!("✗ Failed: {}\n", e), + } + + println!("=== Test Complete ==="); + println!("Authentication is working correctly!"); + Ok(()) + } + Err(e) => { + eprintln!("✗ Failed to obtain access token: {}", e); + Err(e.into()) + } + } +} diff --git a/rs_controller/src/informer.rs b/rs_controller/src/informer.rs index d385b3009..97f049331 100644 --- a/rs_controller/src/informer.rs +++ b/rs_controller/src/informer.rs @@ -1,5 +1,6 @@ use crate::action::Action; use crate::ControllerError; +use crate::StateClient; use flyteidl2::flyteidl::common::ActionIdentifier; use flyteidl2::flyteidl::common::RunIdentifier; @@ -23,7 +24,7 @@ use tracing_subscriber::fmt; #[derive(Clone, Debug)] pub struct Informer { - client: StateServiceClient, + client: StateClient, run_id: RunIdentifier, action_cache: Arc>>, parent_action_name: String, @@ -34,7 +35,7 @@ pub struct Informer { impl Informer { pub fn new( - client: StateServiceClient, + client: 
StateClient, run_id: RunIdentifier, parent_action_name: String, shared_queue: mpsc::Sender, @@ -280,7 +281,12 @@ async fn informer_main() { name: String::from("qdtc266r2z8clscl2lj5"), }; - let informer = Arc::new(Informer::new(client, run_id, "a0".to_string(), tx.clone())); + let informer = Arc::new(Informer::new( + StateClient::Plain(client), + run_id, + "a0".to_string(), + tx.clone(), + )); let watch_task = Informer::start(informer.clone()).await; diff --git a/rs_controller/src/lib.rs b/rs_controller/src/lib.rs index a0404fe96..5820139eb 100644 --- a/rs_controller/src/lib.rs +++ b/rs_controller/src/lib.rs @@ -1,14 +1,19 @@ #![allow(clippy::too_many_arguments)] mod action; +pub mod auth; // Public for use in other crates mod informer; +pub mod proto; // Public for use in other crates +use std::default; +use std::sync::mpsc::channel; use std::sync::Arc; use std::time::Duration; use futures::TryFutureExt; use pyo3::prelude::*; use tokio::sync::mpsc; +use tower::ServiceExt; use tracing::{debug, error, info, warn}; use thiserror::Error; @@ -16,25 +21,65 @@ use thiserror::Error; use crate::action::{Action, ActionType}; use crate::informer::Informer; -use flyteidl2::flyteidl::common::ActionIdentifier; +use crate::auth::{AuthConfig, AuthConfigError, AuthLayer, ClientCredentialsAuthenticator}; +use flyteidl2::flyteidl::common::{ActionIdentifier, ProjectIdentifier}; +use flyteidl2::flyteidl::task::task_service_client::TaskServiceClient; use flyteidl2::flyteidl::task::TaskIdentifier; -use flyteidl2::flyteidl::workflow::state_service_client::StateServiceClient; -use flyteidl2::flyteidl::workflow::{EnqueueActionRequest, EnqueueActionResponse, TaskAction}; - +use flyteidl2::flyteidl::task::{list_tasks_request, ListTasksRequest}; use flyteidl2::flyteidl::workflow::enqueue_action_request; use flyteidl2::flyteidl::workflow::queue_service_client::QueueServiceClient; +use flyteidl2::flyteidl::workflow::state_service_client::StateServiceClient; +use flyteidl2::flyteidl::workflow::{ 
+ EnqueueActionRequest, EnqueueActionResponse, TaskAction, WatchRequest, WatchResponse, +}; use flyteidl2::google; use google::protobuf::StringValue; use pyo3::exceptions; use pyo3::types::PyAny; use pyo3_async_runtimes::tokio::future_into_py; use pyo3_async_runtimes::tokio::get_runtime; -use tokio::sync::{oneshot, OnceCell}; +use std::sync::OnceLock; +use tokio::sync::oneshot; +use tokio::sync::OnceCell; use tokio::time::sleep; -use tonic::transport::Endpoint; +use tonic::transport::{Certificate, ClientTlsConfig, Endpoint}; use tonic::Status; use tracing_subscriber::FmtSubscriber; +// Fetches Amazon root CA certificate from Amazon Trust Services +async fn fetch_amazon_root_ca() -> Result { + // Amazon Root CA 1 - the main root used by AWS services + let url = "https://www.amazontrust.com/repository/AmazonRootCA1.pem"; + + let response = reqwest::get(url) + .await + .map_err(|e| ControllerError::SystemError(format!("Failed to fetch certificate: {}", e)))?; + + let cert_pem = response + .text() + .await + .map_err(|e| ControllerError::SystemError(format!("Failed to read certificate: {}", e)))?; + + Ok(Certificate::from_pem(cert_pem)) +} + +// Helper to create TLS-configured endpoint with Amazon CA certificate +// todo: when we resolve the pem issue, also remove the need to have both inputs which are basically the same +async fn create_tls_endpoint(url: &'static str, domain: &str) -> Result { + // Fetch Amazon root CA dynamically + let cert = fetch_amazon_root_ca().await?; + + let tls_config = ClientTlsConfig::new() + .domain_name(domain) + .ca_certificate(cert); + + let endpoint = Endpoint::from_static(url) + .tls_config(tls_config) + .map_err(|e| ControllerError::SystemError(format!("TLS config error: {}", e)))?; + + Ok(endpoint) +} + #[derive(Error, Debug)] pub enum ControllerError { #[error("Bad context: {0}")] @@ -62,37 +107,178 @@ impl From for PyErr { } } +impl From for PyErr { + fn from(err: AuthConfigError) -> Self { + 
exceptions::PyRuntimeError::new_err(err.to_string()) + } +} + +enum ChannelType { + Plain(tonic::transport::Channel), + Authenticated(crate::auth::AuthService), +} + +#[derive(Clone, Debug)] +pub enum StateClient { + Plain(StateServiceClient), + Authenticated(StateServiceClient>), +} + +impl StateClient { + pub async fn watch( + &mut self, + request: impl tonic::IntoRequest, + ) -> Result>, tonic::Status> { + match self { + StateClient::Plain(client) => client.watch(request).await, + StateClient::Authenticated(client) => client.watch(request).await, + } + } +} + +#[derive(Clone, Debug)] +pub enum QueueClient { + Plain(QueueServiceClient), + Authenticated(QueueServiceClient>), +} + +impl QueueClient { + pub async fn enqueue_action( + &mut self, + request: impl tonic::IntoRequest, + ) -> Result, tonic::Status> { + match self { + QueueClient::Plain(client) => client.enqueue_action(request).await, + QueueClient::Authenticated(client) => client.enqueue_action(request).await, + } + } +} + struct CoreBaseController { - state_client: StateServiceClient, - queue_client: QueueServiceClient, + channel: ChannelType, informer: OnceCell>, + state_client_cache: OnceLock, + queue_client_cache: OnceLock, shared_queue: mpsc::Sender, rx_of_shared_queue: Arc>>, } impl CoreBaseController { - pub fn try_new(endpoint: String) -> Result, ControllerError> { - info!("Creating CoreBaseController with endpoint {:?}", endpoint); - // play with taking str slice instead of String instead of intentionally leaking. 
- let endpoint_static: &'static str = Box::leak(Box::new(endpoint.clone().into_boxed_str())); + // Helper methods to get cached clients (constructed once, reused thereafter) + fn state_client(&self) -> StateClient { + self.state_client_cache + .get_or_init(|| match &self.channel { + ChannelType::Plain(ch) => StateClient::Plain(StateServiceClient::new(ch.clone())), + ChannelType::Authenticated(ch) => { + StateClient::Authenticated(StateServiceClient::new(ch.clone())) + } + }) + .clone() + } + + fn queue_client(&self) -> QueueClient { + self.queue_client_cache + .get_or_init(|| match &self.channel { + ChannelType::Plain(ch) => QueueClient::Plain(QueueServiceClient::new(ch.clone())), + ChannelType::Authenticated(ch) => { + QueueClient::Authenticated(QueueServiceClient::new(ch.clone())) + } + }) + .clone() + } + + pub fn new_with_auth() -> Result, ControllerError> { + use crate::auth::{AuthConfig, AuthLayer, ClientCredentialsAuthenticator}; + use tower::ServiceBuilder; + + info!("Creating CoreBaseController from _UNION_EAGER_API_KEY env var (with auth)"); + // Read from env var and use auth + let api_key = std::env::var("_UNION_EAGER_API_KEY").map_err(|_| { + ControllerError::SystemError( + "_UNION_EAGER_API_KEY env var must be provided".to_string(), + ) + })?; + let auth_config = AuthConfig::new_from_api_key(&api_key).expect("Bad api key"); + let endpoint_url = auth_config.endpoint.clone(); + + let endpoint_static: &'static str = + Box::leak(Box::new(endpoint_url.clone().into_boxed_str())); // shared queue let (shared_tx, rx_of_shared_queue) = mpsc::channel::(64); let rt = get_runtime(); - let (state_client, queue_client) = rt.block_on(async { - // Need to update to with auth to read API key - let endpoint = Endpoint::from_static(endpoint_static); + let channel = rt.block_on(async { + // todo: escape hatch for localhost + // Strip "https://" to get just the hostname for TLS config + let domain = endpoint_url.strip_prefix("https://").ok_or_else(|| { + 
ControllerError::SystemError( + "Endpoint must start with https:// when using auth".to_string(), + ) + })?; + + // Create TLS-configured endpoint + let endpoint = create_tls_endpoint(endpoint_static, domain).await?; let channel = endpoint.connect().await.map_err(ControllerError::from)?; - Ok::<_, ControllerError>(( - StateServiceClient::new(channel.clone()), - QueueServiceClient::new(channel), - )) + + let authenticator = Arc::new(ClientCredentialsAuthenticator::new(auth_config.clone())); + let auth_channel = ServiceBuilder::new() + .layer(AuthLayer::new(authenticator, channel.clone())) + .service(channel); + + Ok::<_, ControllerError>(ChannelType::Authenticated(auth_channel)) + })?; + + let real_base_controller = CoreBaseController { + channel, + informer: OnceCell::new(), + state_client_cache: OnceLock::new(), + queue_client_cache: OnceLock::new(), + shared_queue: shared_tx, + rx_of_shared_queue: Arc::new(tokio::sync::Mutex::new(rx_of_shared_queue)), + }; + + let real_base_controller = Arc::new(real_base_controller); + // Start the background worker + let controller_clone = real_base_controller.clone(); + rt.spawn(async move { + controller_clone.bg_worker().await; + }); + Ok(real_base_controller) + } + + pub fn new_without_auth(endpoint: String) -> Result, ControllerError> { + let endpoint_static: &'static str = Box::leak(Box::new(endpoint.clone().into_boxed_str())); + // shared queue + let (shared_tx, rx_of_shared_queue) = mpsc::channel::(64); + + let rt = get_runtime(); + let channel = rt.block_on(async { + let chan = if endpoint.starts_with("http://") { + let endpoint = Endpoint::from_static(endpoint_static); + endpoint.connect().await.map_err(ControllerError::from)? 
+ } else if endpoint.starts_with("https://") { + // Strip "https://" to get just the hostname for TLS config + let domain = endpoint.strip_prefix("https://").ok_or_else(|| { + ControllerError::SystemError("Endpoint must start with https://".to_string()) + })?; + + // Create TLS-configured endpoint + let endpoint = create_tls_endpoint(endpoint_static, domain).await?; + endpoint.connect().await.map_err(ControllerError::from)? + } else { + return Err(ControllerError::SystemError(format!( + "Malformed endpoint {}", + endpoint + ))); + }; + Ok::<_, ControllerError>(ChannelType::Plain(chan)) })?; let real_base_controller = CoreBaseController { - state_client, - queue_client, + channel, informer: OnceCell::new(), + state_client_cache: OnceLock::new(), + queue_client_cache: OnceLock::new(), shared_queue: shared_tx, rx_of_shared_queue: Arc::new(tokio::sync::Mutex::new(rx_of_shared_queue)), }; @@ -346,7 +532,7 @@ impl CoreBaseController { let enqueue_request = self .create_enqueue_action_request(action) .expect("Failed to create EnqueueActionRequest"); - let mut client = self.queue_client.clone(); + let mut client = self.queue_client(); // todo: tonic doesn't seem to have wait_for_ready, or maybe the .ready is already doing this. 
let enqueue_result = client.enqueue_action(enqueue_request).await; match enqueue_result { @@ -400,7 +586,7 @@ impl CoreBaseController { .get_or_try_init(|| async move { info!("Creating informer set to run_id {:?}", run_id); let inf = Arc::new(Informer::new( - self.state_client.clone(), + self.state_client(), run_id, parent_action_name, self.shared_queue.clone(), @@ -439,13 +625,275 @@ struct BaseController(Arc); #[pymethods] impl BaseController { #[new] - #[pyo3(signature = (*, endpoint))] - fn new(endpoint: String) -> PyResult { - info!("Creating controller wrapper with endpoint {:?}", endpoint); - let core_base = CoreBaseController::try_new(endpoint)?; + #[pyo3(signature = (*, endpoint=None))] + fn new(endpoint: Option) -> PyResult { + let core_base = if let Some(ep) = endpoint { + info!("Creating controller wrapper with endpoint {:?}", ep); + CoreBaseController::new_without_auth(ep)? + } else { + info!("Creating controller wrapper from _UNION_EAGER_API_KEY env var"); + CoreBaseController::new_with_auth()? 
+ }; Ok(BaseController(core_base)) } + #[staticmethod] + fn try_list_tasks(py: Python<'_>) -> PyResult> { + future_into_py(py, async move { + use crate::auth::{AuthConfig, AuthLayer, ClientCredentialsAuthenticator}; + use flyteidl2::flyteidl::common::ListRequest; + use flyteidl2::flyteidl::common::ProjectIdentifier; + use flyteidl2::flyteidl::task::task_service_client::TaskServiceClient; + use flyteidl2::flyteidl::task::{list_tasks_request, ListTasksRequest}; + use tonic::Code; + use tower::ServiceBuilder; + + let api_key = std::env::var("_UNION_EAGER_API_KEY").unwrap_or_else(|_| { + warn!("_UNION_EAGER_API_KEY env var not set, using empty string"); + String::new() + }); + + let auth_config = AuthConfig::new_from_api_key(api_key.as_str())?; + let endpoint = auth_config.endpoint.clone(); + let static_endpoint = endpoint.clone().leak(); + // Strip "https://" (8 chars) to get just the hostname for TLS config + let domain = endpoint.strip_prefix("https://").ok_or_else(|| { + ControllerError::SystemError("Endpoint must start with https://".to_string()) + })?; + let endpoint = create_tls_endpoint(static_endpoint, domain).await?; + let channel = endpoint.connect().await.map_err(ControllerError::from)?; + + let authenticator = Arc::new(ClientCredentialsAuthenticator::new(auth_config)); + + let auth_handling_channel = ServiceBuilder::new() + .layer(AuthLayer::new(authenticator, channel.clone())) + .service(channel); + + let mut task_client = TaskServiceClient::new(auth_handling_channel); + + let list_request_base = ListRequest { + limit: 100, + ..Default::default() + }; + let req = ListTasksRequest { + request: Some(list_request_base), + known_filters: vec![], + scope_by: Some(list_tasks_request::ScopeBy::ProjectId(ProjectIdentifier { + organization: "demo".to_string(), + domain: "development".to_string(), + name: "flytesnacks".to_string(), + })), + }; + + let mut attempts = 0; + let final_result = loop { + let result = task_client.list_tasks(req.clone()).await; + match 
result { + Ok(response) => { + println!("Success: {:?}", response.into_inner()); + break Ok(true); + } + Err(status) if status.code() == Code::Unauthenticated && attempts < 1 => { + attempts += 1; + continue; + } + Err(status) => { + eprintln!("Error calling gRPC: {}", status); + break Err(exceptions::PyRuntimeError::new_err(format!( + "gRPC error: {}", + status + ))); + } + } + }; + warn!("Finished try_list_tasks with result {:?}", final_result); + final_result + }) + } + + #[staticmethod] + fn try_watch(py: Python<'_>) -> PyResult> { + future_into_py(py, async move { + use crate::auth::{AuthConfig, AuthLayer, ClientCredentialsAuthenticator}; + use flyteidl2::flyteidl::common::ActionIdentifier; + use flyteidl2::flyteidl::common::RunIdentifier; + use flyteidl2::flyteidl::workflow::watch_request::Filter; + use flyteidl2::flyteidl::workflow::WatchRequest; + use std::time::Duration; + use tokio::time::sleep; + use tower::ServiceBuilder; + + info!("Starting watch example with authentication and retry..."); + + // Read in the api key which gives us the endpoint to connect to as well as the credentials + let api_key = std::env::var("_UNION_EAGER_API_KEY").unwrap_or_else(|_| { + warn!("_UNION_EAGER_API_KEY env var not set, using empty string"); + String::new() + }); + + let auth_config = AuthConfig::new_from_api_key(api_key.as_str())?; + let endpoint = auth_config.endpoint.clone(); + let static_endpoint = endpoint.clone().leak(); + // Strip "https://" (8 chars) to get just the hostname for TLS config + let domain = endpoint.strip_prefix("https://").ok_or_else(|| { + ControllerError::SystemError("Endpoint must start with https://".to_string()) + })?; + let endpoint = create_tls_endpoint(static_endpoint, domain).await?; + let channel = endpoint.connect().await.map_err(ControllerError::from)?; + + let authenticator = Arc::new(ClientCredentialsAuthenticator::new(auth_config)); + + // Wrap channel with auth layer - ALL calls now automatically authenticated! 
+ let auth_channel = ServiceBuilder::new() + .layer(AuthLayer::new(authenticator, channel.clone())) + .service(channel); + + let mut client = StateServiceClient::new(auth_channel); + + // Watch configuration (matching Python example) + let run_id = RunIdentifier { + org: "demo".to_string(), + project: "flytesnacks".to_string(), + domain: "development".to_string(), + name: "r57jklb4mw4k6bkb2p88".to_string(), + }; + let parent_action_name = "a0".to_string(); + + // Retry parameters (matching Python defaults) + let min_watch_backoff = Duration::from_secs(1); + let max_watch_backoff = Duration::from_secs(30); + let max_watch_retries = 10; + + // Watch loop with retry logic (following Python _informer.py pattern) + let mut retries = 0; + let mut message_count = 0; + + while retries < max_watch_retries { + if retries >= 1 { + warn!("Watch retrying, attempt {}/{}", retries, max_watch_retries); + } + + // Create watch request + let request = WatchRequest { + filter: Some(Filter::ParentActionId(ActionIdentifier { + name: parent_action_name.clone(), + run: Some(run_id.clone()), + })), + }; + + // Establish the watch stream + // The outer retry loop handles failures, middleware handles auth refresh + let stream_result = client.watch(request.clone()).await; + + match stream_result { + Ok(response) => { + info!("Successfully established watch stream"); + let mut stream = response.into_inner(); + + // Process messages from the stream + loop { + match stream.message().await { + Ok(Some(watch_response)) => { + // Successfully received a message - reset retry counter + retries = 0; + message_count += 1; + + // Process the message (enum with ActionUpdate or ControlMessage) + use flyteidl2::flyteidl::workflow::watch_response::Message; + match &watch_response.message { + Some(Message::ControlMessage(control_msg)) => { + if control_msg.sentinel { + info!( + "Received Sentinel for parent action: {}", + parent_action_name + ); + } + } + Some(Message::ActionUpdate(action_update)) => { + 
info!( + "Received action update for: {} (phase: {:?})", + action_update + .action_id + .as_ref() + .map(|id| id.name.as_str()) + .unwrap_or("unknown"), + action_update.phase + ); + + if !action_update.output_uri.is_empty() { + info!("Output URI: {}", action_update.output_uri); + } + + if action_update.phase == 4 { + // PHASE_FAILED + if action_update.error.is_some() { + error!( + "Action failed with error: {:?}", + action_update.error + ); + } + } + } + None => { + warn!("Received empty watch response"); + } + } + + // For demo purposes, exit after receiving a few messages + if message_count >= 50 { + info!("Received {} messages, exiting demo", message_count); + return Ok(true); + } + } + Ok(None) => { + warn!("Watch stream ended gracefully"); + break; // Stream ended, retry + } + Err(status) => { + error!("Error receiving message from watch stream: {}", status); + + // Check if it's an auth error + if status.code() == tonic::Code::Unauthenticated { + warn!("Unauthenticated error - credentials will be refreshed on retry"); + } + + break; // Break inner loop to retry + } + } + } + } + Err(status) => { + error!("Failed to establish watch stream: {}", status); + + if status.code() == tonic::Code::Unauthenticated { + warn!("Unauthenticated error - credentials will be refreshed on retry"); + } + } + } + + // Increment retry counter and apply exponential backoff + retries += 1; + if retries < max_watch_retries { + let backoff = min_watch_backoff + .saturating_mul(2_u32.pow(retries as u32)) + .min(max_watch_backoff); + warn!("Watch failed, retrying in {:?}...", backoff); + sleep(backoff).await; + } + } + + // Exceeded max retries + error!( + "Watch failure retries crossed threshold {}/{}, exiting!", + retries, max_watch_retries + ); + Err(exceptions::PyRuntimeError::new_err(format!( + "Max watch retries ({}) exceeded", + max_watch_retries + ))) + }) + } + /// `async def submit(self, action: Action) -> Action` /// /// Enqueue `action`. 
@@ -489,8 +937,6 @@ impl BaseController { } } -// use cloudidl::pymodules::cloud_mod; - #[pymodule] fn flyte_controller_base(py: Python<'_>, m: &Bound<'_, PyModule>) -> PyResult<()> { static INIT: std::sync::Once = std::sync::Once::new(); diff --git a/rs_controller/src/lib_auth.rs b/rs_controller/src/lib_auth.rs new file mode 100644 index 000000000..bdd33b1db --- /dev/null +++ b/rs_controller/src/lib_auth.rs @@ -0,0 +1,5 @@ +// Re-export auth module for use in examples and external crates +pub mod auth; +pub mod proto; + +pub use auth::{AuthConfig, AuthInterceptor, ClientCredentialsAuthenticator, Credentials}; diff --git a/rs_controller/src/proto/flyteidl.service.rs b/rs_controller/src/proto/flyteidl.service.rs new file mode 100644 index 000000000..bcd69cb84 --- /dev/null +++ b/rs_controller/src/proto/flyteidl.service.rs @@ -0,0 +1,78 @@ +// @generated +// This file is @generated by prost-build. +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, Copy, PartialEq, ::prost::Message)] +pub struct OAuth2MetadataRequest {} +/// OAuth2MetadataResponse defines an RFC-Compliant response for /.well-known/oauth-authorization-server metadata +/// as defined in +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct OAuth2MetadataResponse { + /// Defines the issuer string in all JWT tokens this server issues. The issuer can be admin itself or an external + /// issuer. + #[prost(string, tag = "1")] + pub issuer: ::prost::alloc::string::String, + /// URL of the authorization server's authorization endpoint \[RFC6749\]. This is REQUIRED unless no grant types are + /// supported that use the authorization endpoint. + #[prost(string, tag = "2")] + pub authorization_endpoint: ::prost::alloc::string::String, + /// URL of the authorization server's token endpoint \[RFC6749\]. 
+ #[prost(string, tag = "3")] + pub token_endpoint: ::prost::alloc::string::String, + /// Array containing a list of the OAuth 2.0 response_type values that this authorization server supports. + #[prost(string, repeated, tag = "4")] + pub response_types_supported: ::prost::alloc::vec::Vec<::prost::alloc::string::String>, + /// JSON array containing a list of the OAuth 2.0 \[RFC6749\] scope values that this authorization server supports. + #[prost(string, repeated, tag = "5")] + pub scopes_supported: ::prost::alloc::vec::Vec<::prost::alloc::string::String>, + /// JSON array containing a list of client authentication methods supported by this token endpoint. + #[prost(string, repeated, tag = "6")] + pub token_endpoint_auth_methods_supported: + ::prost::alloc::vec::Vec<::prost::alloc::string::String>, + /// URL of the authorization server's JWK Set \[JWK\] document. The referenced document contains the signing key(s) the + /// client uses to validate signatures from the authorization server. + #[prost(string, tag = "7")] + pub jwks_uri: ::prost::alloc::string::String, + /// JSON array containing a list of Proof Key for Code Exchange (PKCE) \[RFC7636\] code challenge methods supported by + /// this authorization server. + #[prost(string, repeated, tag = "8")] + pub code_challenge_methods_supported: ::prost::alloc::vec::Vec<::prost::alloc::string::String>, + /// JSON array containing a list of the OAuth 2.0 grant type values that this authorization server supports. 
+ #[prost(string, repeated, tag = "9")] + pub grant_types_supported: ::prost::alloc::vec::Vec<::prost::alloc::string::String>, + /// URL of the authorization server's device authorization endpoint, as defined in Section 3.1 of \[RFC8628\] + #[prost(string, tag = "10")] + pub device_authorization_endpoint: ::prost::alloc::string::String, +} +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, Copy, PartialEq, ::prost::Message)] +pub struct PublicClientAuthConfigRequest {} +/// FlyteClientResponse encapsulates public information that flyte clients (CLIs... etc.) can use to authenticate users. +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct PublicClientAuthConfigResponse { + /// client_id to use when initiating OAuth2 authorization requests. + #[prost(string, tag = "1")] + pub client_id: ::prost::alloc::string::String, + /// redirect uri to use when initiating OAuth2 authorization requests. + #[prost(string, tag = "2")] + pub redirect_uri: ::prost::alloc::string::String, + /// scopes to request when initiating OAuth2 authorization requests. + #[prost(string, repeated, tag = "3")] + pub scopes: ::prost::alloc::vec::Vec<::prost::alloc::string::String>, + /// Authorization Header to use when passing Access Tokens to the server. If not provided, the client should use the + /// default http `Authorization` header. + #[prost(string, tag = "4")] + pub authorization_metadata_key: ::prost::alloc::string::String, + /// ServiceHttpEndpoint points to the http endpoint for the backend. If empty, clients can assume the endpoint used + /// to configure the gRPC connection can be used for the http one respecting the insecure flag to choose between + /// SSL or no SSL connections. + #[prost(string, tag = "5")] + pub service_http_endpoint: ::prost::alloc::string::String, + /// audience to use when initiating OAuth2 authorization requests. 
+ #[prost(string, tag = "6")] + pub audience: ::prost::alloc::string::String, +} + +include!("flyteidl.service.tonic.rs"); +// @@protoc_insertion_point(module) diff --git a/rs_controller/src/proto/flyteidl.service.tonic.rs b/rs_controller/src/proto/flyteidl.service.tonic.rs new file mode 100644 index 000000000..cc2cb4b92 --- /dev/null +++ b/rs_controller/src/proto/flyteidl.service.tonic.rs @@ -0,0 +1,409 @@ +// @generated +/// Generated client implementations. +pub mod auth_metadata_service_client { + #![allow(unused_variables, dead_code, missing_docs, clippy::let_unit_value)] + use tonic::codegen::*; + use tonic::codegen::http::Uri; + /** The following defines an RPC service that is also served over HTTP via grpc-gateway. + Standard response codes for both are defined here: https://github.com/grpc-ecosystem/grpc-gateway/blob/master/runtime/errors.go + RPCs defined in this service must be anonymously accessible. +*/ + #[derive(Debug, Clone)] + pub struct AuthMetadataServiceClient { + inner: tonic::client::Grpc, + } + impl AuthMetadataServiceClient { + /// Attempt to create a new client by connecting to a given endpoint. 
+ pub async fn connect(dst: D) -> Result + where + D: TryInto, + D::Error: Into, + { + let conn = tonic::transport::Endpoint::new(dst)?.connect().await?; + Ok(Self::new(conn)) + } + } + impl AuthMetadataServiceClient + where + T: tonic::client::GrpcService, + T::Error: Into, + T::ResponseBody: Body + Send + 'static, + ::Error: Into + Send, + { + pub fn new(inner: T) -> Self { + let inner = tonic::client::Grpc::new(inner); + Self { inner } + } + pub fn with_origin(inner: T, origin: Uri) -> Self { + let inner = tonic::client::Grpc::with_origin(inner, origin); + Self { inner } + } + pub fn with_interceptor( + inner: T, + interceptor: F, + ) -> AuthMetadataServiceClient> + where + F: tonic::service::Interceptor, + T::ResponseBody: Default, + T: tonic::codegen::Service< + http::Request, + Response = http::Response< + >::ResponseBody, + >, + >, + , + >>::Error: Into + Send + Sync, + { + AuthMetadataServiceClient::new(InterceptedService::new(inner, interceptor)) + } + /// Compress requests with the given encoding. + /// + /// This requires the server to support it otherwise it might respond with an + /// error. + #[must_use] + pub fn send_compressed(mut self, encoding: CompressionEncoding) -> Self { + self.inner = self.inner.send_compressed(encoding); + self + } + /// Enable decompressing responses. + #[must_use] + pub fn accept_compressed(mut self, encoding: CompressionEncoding) -> Self { + self.inner = self.inner.accept_compressed(encoding); + self + } + /// Limits the maximum size of a decoded message. + /// + /// Default: `4MB` + #[must_use] + pub fn max_decoding_message_size(mut self, limit: usize) -> Self { + self.inner = self.inner.max_decoding_message_size(limit); + self + } + /// Limits the maximum size of an encoded message. + /// + /// Default: `usize::MAX` + #[must_use] + pub fn max_encoding_message_size(mut self, limit: usize) -> Self { + self.inner = self.inner.max_encoding_message_size(limit); + self + } + /** Anonymously accessible. 
Retrieves local or external oauth authorization server metadata. +*/ + pub async fn get_o_auth2_metadata( + &mut self, + request: impl tonic::IntoRequest, + ) -> std::result::Result< + tonic::Response, + tonic::Status, + > { + self.inner + .ready() + .await + .map_err(|e| { + tonic::Status::new( + tonic::Code::Unknown, + format!("Service was not ready: {}", e.into()), + ) + })?; + let codec = tonic::codec::ProstCodec::default(); + let path = http::uri::PathAndQuery::from_static( + "/flyteidl.service.AuthMetadataService/GetOAuth2Metadata", + ); + let mut req = request.into_request(); + req.extensions_mut() + .insert( + GrpcMethod::new( + "flyteidl.service.AuthMetadataService", + "GetOAuth2Metadata", + ), + ); + self.inner.unary(req, path, codec).await + } + /** Anonymously accessible. Retrieves the client information clients should use when initiating OAuth2 authorization + requests. +*/ + pub async fn get_public_client_config( + &mut self, + request: impl tonic::IntoRequest, + ) -> std::result::Result< + tonic::Response, + tonic::Status, + > { + self.inner + .ready() + .await + .map_err(|e| { + tonic::Status::new( + tonic::Code::Unknown, + format!("Service was not ready: {}", e.into()), + ) + })?; + let codec = tonic::codec::ProstCodec::default(); + let path = http::uri::PathAndQuery::from_static( + "/flyteidl.service.AuthMetadataService/GetPublicClientConfig", + ); + let mut req = request.into_request(); + req.extensions_mut() + .insert( + GrpcMethod::new( + "flyteidl.service.AuthMetadataService", + "GetPublicClientConfig", + ), + ); + self.inner.unary(req, path, codec).await + } + } +} +/// Generated server implementations. +pub mod auth_metadata_service_server { + #![allow(unused_variables, dead_code, missing_docs, clippy::let_unit_value)] + use tonic::codegen::*; + /// Generated trait containing gRPC methods that should be implemented for use with AuthMetadataServiceServer. 
+ #[async_trait] + pub trait AuthMetadataService: Send + Sync + 'static { + /** Anonymously accessible. Retrieves local or external oauth authorization server metadata. +*/ + async fn get_o_auth2_metadata( + &self, + request: tonic::Request, + ) -> std::result::Result< + tonic::Response, + tonic::Status, + >; + /** Anonymously accessible. Retrieves the client information clients should use when initiating OAuth2 authorization + requests. +*/ + async fn get_public_client_config( + &self, + request: tonic::Request, + ) -> std::result::Result< + tonic::Response, + tonic::Status, + >; + } + /** The following defines an RPC service that is also served over HTTP via grpc-gateway. + Standard response codes for both are defined here: https://github.com/grpc-ecosystem/grpc-gateway/blob/master/runtime/errors.go + RPCs defined in this service must be anonymously accessible. +*/ + #[derive(Debug)] + pub struct AuthMetadataServiceServer { + inner: _Inner, + accept_compression_encodings: EnabledCompressionEncodings, + send_compression_encodings: EnabledCompressionEncodings, + max_decoding_message_size: Option, + max_encoding_message_size: Option, + } + struct _Inner(Arc); + impl AuthMetadataServiceServer { + pub fn new(inner: T) -> Self { + Self::from_arc(Arc::new(inner)) + } + pub fn from_arc(inner: Arc) -> Self { + let inner = _Inner(inner); + Self { + inner, + accept_compression_encodings: Default::default(), + send_compression_encodings: Default::default(), + max_decoding_message_size: None, + max_encoding_message_size: None, + } + } + pub fn with_interceptor( + inner: T, + interceptor: F, + ) -> InterceptedService + where + F: tonic::service::Interceptor, + { + InterceptedService::new(Self::new(inner), interceptor) + } + /// Enable decompressing requests with the given encoding. 
+ #[must_use] + pub fn accept_compressed(mut self, encoding: CompressionEncoding) -> Self { + self.accept_compression_encodings.enable(encoding); + self + } + /// Compress responses with the given encoding, if the client supports it. + #[must_use] + pub fn send_compressed(mut self, encoding: CompressionEncoding) -> Self { + self.send_compression_encodings.enable(encoding); + self + } + /// Limits the maximum size of a decoded message. + /// + /// Default: `4MB` + #[must_use] + pub fn max_decoding_message_size(mut self, limit: usize) -> Self { + self.max_decoding_message_size = Some(limit); + self + } + /// Limits the maximum size of an encoded message. + /// + /// Default: `usize::MAX` + #[must_use] + pub fn max_encoding_message_size(mut self, limit: usize) -> Self { + self.max_encoding_message_size = Some(limit); + self + } + } + impl tonic::codegen::Service> for AuthMetadataServiceServer + where + T: AuthMetadataService, + B: Body + Send + 'static, + B::Error: Into + Send + 'static, + { + type Response = http::Response; + type Error = std::convert::Infallible; + type Future = BoxFuture; + fn poll_ready( + &mut self, + _cx: &mut Context<'_>, + ) -> Poll> { + Poll::Ready(Ok(())) + } + fn call(&mut self, req: http::Request) -> Self::Future { + let inner = self.inner.clone(); + match req.uri().path() { + "/flyteidl.service.AuthMetadataService/GetOAuth2Metadata" => { + #[allow(non_camel_case_types)] + struct GetOAuth2MetadataSvc(pub Arc); + impl< + T: AuthMetadataService, + > tonic::server::UnaryService + for GetOAuth2MetadataSvc { + type Response = super::OAuth2MetadataResponse; + type Future = BoxFuture< + tonic::Response, + tonic::Status, + >; + fn call( + &mut self, + request: tonic::Request, + ) -> Self::Future { + let inner = Arc::clone(&self.0); + let fut = async move { + ::get_o_auth2_metadata( + &inner, + request, + ) + .await + }; + Box::pin(fut) + } + } + let accept_compression_encodings = self.accept_compression_encodings; + let send_compression_encodings 
= self.send_compression_encodings; + let max_decoding_message_size = self.max_decoding_message_size; + let max_encoding_message_size = self.max_encoding_message_size; + let inner = self.inner.clone(); + let fut = async move { + let inner = inner.0; + let method = GetOAuth2MetadataSvc(inner); + let codec = tonic::codec::ProstCodec::default(); + let mut grpc = tonic::server::Grpc::new(codec) + .apply_compression_config( + accept_compression_encodings, + send_compression_encodings, + ) + .apply_max_message_size_config( + max_decoding_message_size, + max_encoding_message_size, + ); + let res = grpc.unary(method, req).await; + Ok(res) + }; + Box::pin(fut) + } + "/flyteidl.service.AuthMetadataService/GetPublicClientConfig" => { + #[allow(non_camel_case_types)] + struct GetPublicClientConfigSvc(pub Arc); + impl< + T: AuthMetadataService, + > tonic::server::UnaryService + for GetPublicClientConfigSvc { + type Response = super::PublicClientAuthConfigResponse; + type Future = BoxFuture< + tonic::Response, + tonic::Status, + >; + fn call( + &mut self, + request: tonic::Request, + ) -> Self::Future { + let inner = Arc::clone(&self.0); + let fut = async move { + ::get_public_client_config( + &inner, + request, + ) + .await + }; + Box::pin(fut) + } + } + let accept_compression_encodings = self.accept_compression_encodings; + let send_compression_encodings = self.send_compression_encodings; + let max_decoding_message_size = self.max_decoding_message_size; + let max_encoding_message_size = self.max_encoding_message_size; + let inner = self.inner.clone(); + let fut = async move { + let inner = inner.0; + let method = GetPublicClientConfigSvc(inner); + let codec = tonic::codec::ProstCodec::default(); + let mut grpc = tonic::server::Grpc::new(codec) + .apply_compression_config( + accept_compression_encodings, + send_compression_encodings, + ) + .apply_max_message_size_config( + max_decoding_message_size, + max_encoding_message_size, + ); + let res = grpc.unary(method, req).await; + 
Ok(res) + }; + Box::pin(fut) + } + _ => { + Box::pin(async move { + Ok( + http::Response::builder() + .status(200) + .header("grpc-status", "12") + .header("content-type", "application/grpc") + .body(empty_body()) + .unwrap(), + ) + }) + } + } + } + } + impl Clone for AuthMetadataServiceServer { + fn clone(&self) -> Self { + let inner = self.inner.clone(); + Self { + inner, + accept_compression_encodings: self.accept_compression_encodings, + send_compression_encodings: self.send_compression_encodings, + max_decoding_message_size: self.max_decoding_message_size, + max_encoding_message_size: self.max_encoding_message_size, + } + } + } + impl Clone for _Inner { + fn clone(&self) -> Self { + Self(Arc::clone(&self.0)) + } + } + impl std::fmt::Debug for _Inner { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{:?}", self.0) + } + } + impl tonic::server::NamedService + for AuthMetadataServiceServer { + const NAME: &'static str = "flyteidl.service.AuthMetadataService"; + } +} diff --git a/rs_controller/src/proto/mod.rs b/rs_controller/src/proto/mod.rs new file mode 100644 index 000000000..6419d5808 --- /dev/null +++ b/rs_controller/src/proto/mod.rs @@ -0,0 +1,12 @@ +// Generated protobuf files for Flyte IDL +// Only including the files needed for authentication and basic operations + +#[path = "flyteidl.service.rs"] +pub mod service; + +// Re-export the auth-related types and services for convenience +pub use service::auth_metadata_service_client::AuthMetadataServiceClient; +pub use service::OAuth2MetadataRequest; +pub use service::OAuth2MetadataResponse; +pub use service::PublicClientAuthConfigRequest; +pub use service::PublicClientAuthConfigResponse; diff --git a/rs_controller/test_auth_direct.py b/rs_controller/test_auth_direct.py new file mode 100644 index 000000000..996e5098f --- /dev/null +++ b/rs_controller/test_auth_direct.py @@ -0,0 +1,25 @@ +#!/usr/bin/env python3 +""" +Direct test of auth metadata service without middleware 
+""" + +import asyncio +import logging + +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + + +async def test_auth_service_direct(): + """Test calling auth metadata service directly""" + from flyte_controller_base import BaseController + + # This will try to call the auth service + logger.info("Testing direct auth service call...") + result = await BaseController.try_list_tasks() + logger.info(f"Result: {result}") + return result + + +if __name__ == "__main__": + asyncio.run(test_auth_service_direct()) diff --git a/rs_controller/test_auth_simple.py b/rs_controller/test_auth_simple.py new file mode 100644 index 000000000..b342fafc4 --- /dev/null +++ b/rs_controller/test_auth_simple.py @@ -0,0 +1,98 @@ +#!/usr/bin/env python3 +""" +Simple test script for Rust controller auth functionality. + +Usage: + export CLIENT_SECRET="your-secret-here" + python test_auth_simple.py +""" + +import asyncio +import logging + +# Configure logging +logging.basicConfig( + level=logging.INFO, + format="%(asctime)s - %(levelname)s - %(message)s", +) +logger = logging.getLogger(__name__) + + +async def test_list_tasks(): + """Test unary gRPC call with auth (list tasks endpoint)""" + from flyte_controller_base import BaseController + + logger.info("=" * 60) + logger.info("Testing unary gRPC call: list_tasks") + logger.info("=" * 60) + + try: + logger.info("About to call try_list_tasks...") + result = await BaseController.try_list_tasks() + logger.info(f"Got result: {result}") + + if result: + logger.info("✅ list_tasks test PASSED") + else: + logger.warning("⚠️ list_tasks test returned False") + + return result + except Exception as e: + logger.error(f"❌ list_tasks test FAILED: {e}", exc_info=True) + return False + + +async def test_watch(): + """Test streaming gRPC call with auth (watch endpoint)""" + from flyte_controller_base import BaseController + + logger.info("\n" + "=" * 60) + logger.info("Testing streaming gRPC call: watch") + logger.info("=" * 60) + + 
try: + result = await BaseController.try_watch() + + if result: + logger.info("✅ watch test PASSED") + else: + logger.warning("⚠️ watch test returned False") + + return result + except Exception as e: + logger.error(f"❌ watch test FAILED: {e}", exc_info=True) + return False + + +async def main(): + """Run all tests""" + import os + + logger.info("Starting Rust controller authentication tests") + logger.info(f"CLIENT_SECRET set: {'Yes' if os.getenv('EAGER_API_KEY') else 'No (will use empty string)'}") + + # Test 1: Unary call (list tasks) + result1 = await test_list_tasks() + print(result1) + + # Test 2: Streaming call (watch) + result2 = await test_watch() + print(result2) + + # # Summary + # logger.info("\n" + "=" * 60) # logger.info("Test Summary") + # logger.info("=" * 60) + # logger.info(f"list_tasks (unary): {'✅ PASSED' if result1 else '❌ FAILED'}") + # logger.info(f"watch (streaming): {'✅ PASSED' if result2 else '❌ FAILED'}") + # + # # Exit code + # if result1 and result2: + # logger.info("\n🎉 All tests passed!") + # sys.exit(0) + # else: + # logger.error("\n💥 Some tests failed") + # sys.exit(1) + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/rs_controller/try_rust_controller.py b/rs_controller/try_rust_controller.py new file mode 100644 index 000000000..5b8490aac --- /dev/null +++ b/rs_controller/try_rust_controller.py @@ -0,0 +1,108 @@ +#!/usr/bin/env python3 +""" +Test script for Rust controller auth functionality. + +This script tests both unary (list_tasks) and streaming (watch) gRPC calls +with authentication and retry logic. 
+ +Usage: + export CLIENT_SECRET="your-secret-here" + python try_rust_controller.py +""" + +import asyncio +import logging +import sys + +# Configure logging +logging.basicConfig( + level=logging.INFO, + format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", +) +logger = logging.getLogger(__name__) + + +def test_list_tasks(): + """Test unary gRPC call with auth (list tasks endpoint)""" + logger.info("=" * 60) + logger.info("Testing unary gRPC call: list_tasks") + logger.info("=" * 60) + + try: + from flyte_controller_base import BaseController + + # Run the async test + result = asyncio.run(BaseController.try_list_tasks()) + + if result: + logger.info("✅ list_tasks test PASSED") + else: + logger.warning("⚠️ list_tasks test returned False") + + return result + + except Exception as e: + logger.error(f"❌ list_tasks test FAILED: {e}", exc_info=True) + return False + + +def test_watch(): + """Test streaming gRPC call with auth (watch endpoint)""" + logger.info("\n" + "=" * 60) + logger.info("Testing streaming gRPC call: watch") + logger.info("=" * 60) + + try: + from flyte_controller_base import BaseController + + # Run the async test + result = asyncio.run(BaseController.try_watch()) + + if result: + logger.info("✅ watch test PASSED") + else: + logger.warning("⚠️ watch test returned False") + + return result + + except Exception as e: + logger.error(f"❌ watch test FAILED: {e}", exc_info=True) + return False + + +def main(): + """Run all tests""" + import os + + logger.info("Starting Rust controller authentication tests") + logger.info(f"CLIENT_SECRET env var set: {'Yes' if os.getenv('CLIENT_SECRET') else 'No (will use empty string)'}") + + results = [] + + # Test 1: Unary call (list tasks) + results.append(("list_tasks (unary)", test_list_tasks())) + + # Test 2: Streaming call (watch) + results.append(("watch (streaming)", test_watch())) + + # Summary + logger.info("\n" + "=" * 60) + logger.info("Test Summary") + logger.info("=" * 60) + + for test_name, result 
in results: + status = "✅ PASSED" if result else "❌ FAILED" + logger.info(f"{test_name}: {status}") + + # Exit code + all_passed = all(result for _, result in results) + if all_passed: + logger.info("\n🎉 All tests passed!") + sys.exit(0) + else: + logger.error("\n💥 Some tests failed") + sys.exit(1) + + +if __name__ == "__main__": + test_list_tasks() diff --git a/src/flyte/_internal/controllers/__init__.py b/src/flyte/_internal/controllers/__init__.py index 43acf162b..670a46351 100644 --- a/src/flyte/_internal/controllers/__init__.py +++ b/src/flyte/_internal/controllers/__init__.py @@ -118,12 +118,14 @@ def create_controller( controller = LocalController() case "remote" | "hybrid": - from flyte._internal.controllers.remote import create_remote_controller - - controller = create_remote_controller(**kwargs) - # from flyte._internal.controllers.remote._r_controller import RemoteController + # from flyte._internal.controllers.remote import create_remote_controller # - # controller = RemoteController(endpoint="http://host.docker.internal:8090", workers=10, max_system_retries=5) + # controller = create_remote_controller(**kwargs) + from flyte._internal.controllers.remote._r_controller import RemoteController + + # controller = RemoteController(endpoint="http://host.docker.internal:8090", workers=10, + # max_system_retries=5) + controller = RemoteController(workers=10, max_system_retries=5) case "rust": # hybrid case, despite the case statement above, meant for local runs not inside docker from flyte._internal.controllers.remote._r_controller import RemoteController diff --git a/src/flyte/_internal/controllers/remote/_r_controller.py b/src/flyte/_internal/controllers/remote/_r_controller.py index 865396dc2..d101cf544 100644 --- a/src/flyte/_internal/controllers/remote/_r_controller.py +++ b/src/flyte/_internal/controllers/remote/_r_controller.py @@ -70,17 +70,32 @@ async def handle_action_failure(action: Action, task_name: str) -> Exception: Raises: Exception: The 
converted native exception or RuntimeSystemError """ - err = action.err or action.client_err - if not err and action.phase == run_definition_pb2.PHASE_FAILED: + # Deserialize err from bytes if present + from flyteidl2.core import execution_pb2 + + err = None + if action.err_bytes: + err_pb = execution_pb2.ExecutionError() + err_pb.ParseFromString(action.err_bytes) + err = err_pb + + err = err or action.client_err + if not err and action.phase_value == run_definition_pb2.PHASE_FAILED: logger.error(f"Server reported failure for action {action.name}, checking error file.") try: - error_path = io.error_path(f"{action.run_output_base}/{action.action_id.name}/1") + # Deserialize action_id to get the name + action_id_pb = identifier_pb2.ActionIdentifier() + action_id_pb.ParseFromString(action.action_id_bytes) + error_path = io.error_path(f"{action.run_output_base}/{action_id_pb.name}/1") err = await io.load_error(error_path) except Exception as e: logger.exception("Failed to load error file", e) err = flyte.errors.RuntimeSystemError(type(e).__name__, f"Failed to load error file: {e}") else: - logger.error(f"Server reported failure for action {action.action_id.name}, error: {err}") + # Deserialize action_id to get the name for logging + action_id_pb = identifier_pb2.ActionIdentifier() + action_id_pb.ParseFromString(action.action_id_bytes) + logger.error(f"Server reported failure for action {action_id_pb.name}, error: {err}") exc = convert.convert_error_to_native(err) if not exc: @@ -116,15 +131,16 @@ class RemoteController(BaseController): def __new__( cls, - endpoint: str, + endpoint: str | None = None, workers: int = 20, max_system_retries: int = 10, ): + # No endpoint means must have the api key env var return super().__new__(cls, endpoint=endpoint) def __init__( self, - endpoint: str, + endpoint: str | None = None, workers: int = 20, max_system_retries: int = 10, ): @@ -255,34 +271,44 @@ async def _submit(self, _task_call_seq: int, _task: TaskTemplate, *args, **kwarg 
logger.info(f"Action for task [{_task.name}] action id: {action.name}, completed!") except asyncio.CancelledError: # If the action is cancelled, we need to cancel the action on the server as well - logger.info(f"Action {action.action_id.name} cancelled, cancelling on server") + action_id_pb = identifier_pb2.ActionIdentifier() + action_id_pb.ParseFromString(action.action_id_bytes) + logger.info(f"Action {action_id_pb.name} cancelled, cancelling on server") await self.cancel_action(action) raise # If the action is aborted, we should abort the controller as well - if n.phase == run_definition_pb2.PHASE_ABORTED: - logger.warning(f"Action {n.action_id.name} was aborted, aborting current Action {current_action_id.name}") + if n.phase_value == run_definition_pb2.PHASE_ABORTED: + n_action_id_pb = identifier_pb2.ActionIdentifier() + n_action_id_pb.ParseFromString(n.action_id_bytes) + logger.warning( + f"Action {n_action_id_pb.name} was aborted, aborting current Action {current_action_id.name}" + ) raise flyte.errors.RunAbortedError( - f"Action {n.action_id.name} was aborted, aborting current Action {current_action_id.name}" + f"Action {n_action_id_pb.name} was aborted, aborting current Action {current_action_id.name}" ) - if n.phase == run_definition_pb2.PHASE_TIMED_OUT: + if n.phase_value == run_definition_pb2.PHASE_TIMED_OUT: + n_action_id_pb = identifier_pb2.ActionIdentifier() + n_action_id_pb.ParseFromString(n.action_id_bytes) logger.warning( - f"Action {n.action_id.name} timed out, raising timeout exception Action {current_action_id.name}" + f"Action {n_action_id_pb.name} timed out, raising timeout exception Action {current_action_id.name}" ) raise flyte.errors.TaskTimeoutError( - f"Action {n.action_id.name} timed out, raising exception in current Action {current_action_id.name}" + f"Action {n_action_id_pb.name} timed out, raising exception in current Action {current_action_id.name}" ) - if n.has_error() or n.phase == run_definition_pb2.PHASE_FAILED: - exc = await 
handle_action_failure(action, _task.name) + if n.has_error() or n.phase_value == run_definition_pb2.PHASE_FAILED: + exc = await handle_action_failure(n, _task.name) raise exc if _task.native_interface.outputs: if not n.realized_outputs_uri: + n_action_id_pb = identifier_pb2.ActionIdentifier() + n_action_id_pb.ParseFromString(n.action_id_bytes) raise flyte.errors.RuntimeSystemError( "RuntimeError", - f"Task {n.action_id.name} did not return an output path, but the task has outputs defined.", + f"Task {n_action_id_pb.name} did not return an output path, but the task has outputs defined.", ) return await load_and_convert_outputs( _task.native_interface, n.realized_outputs_uri, max_bytes=_task.max_inline_io_bytes @@ -432,15 +458,23 @@ async def get_action_outputs( if prev_action is None: return TraceInfo(func_name, sub_action_id, _interface, inputs_uri), False - if prev_action.phase == run_definition_pb2.PHASE_FAILED: + if prev_action.phase_value == run_definition_pb2.PHASE_FAILED: if prev_action.has_error(): - exc = convert.convert_error_to_native(prev_action.err) + # Deserialize err from bytes + from flyteidl2.core import execution_pb2 + + err_pb = execution_pb2.ExecutionError() + err_pb.ParseFromString(prev_action.err_bytes) + exc = convert.convert_error_to_native(err_pb) return ( TraceInfo(func_name, sub_action_id, _interface, inputs_uri, error=exc), True, ) else: - logger.warning(f"Action {prev_action.action_id.name} failed, but no error was found, re-running trace!") + # Deserialize action_id for logging + prev_action_id_pb = identifier_pb2.ActionIdentifier() + prev_action_id_pb.ParseFromString(prev_action.action_id_bytes) + logger.warning(f"Action {prev_action_id_pb.name} failed, but no error was found, re-running trace!") elif prev_action.realized_outputs_uri is not None: o = await io.load_outputs(prev_action.realized_outputs_uri, max_bytes=MAX_TRACE_BYTES) outputs = await convert.convert_outputs_to_native(_interface, o) @@ -584,19 +618,23 @@ async def 
_submit_task_ref(self, invoke_seq_num: int, _task: TaskDetails, *args, logger.info(f"Action for task [{task_name}] action id: {action.name}, completed!") except asyncio.CancelledError: # If the action is cancelled, we need to cancel the action on the server as well - logger.info(f"Action {action.action_id.name} cancelled, cancelling on server") + action_id_pb = identifier_pb2.ActionIdentifier() + action_id_pb.ParseFromString(action.action_id_bytes) + logger.info(f"Action {action_id_pb.name} cancelled, cancelling on server") await self.cancel_action(action) raise - if n.has_error() or n.phase == run_definition_pb2.PHASE_FAILED: - exc = await handle_action_failure(action, task_name) + if n.has_error() or n.phase_value == run_definition_pb2.PHASE_FAILED: + exc = await handle_action_failure(n, task_name) raise exc if native_interface.outputs: if not n.realized_outputs_uri: + n_action_id_pb = identifier_pb2.ActionIdentifier() + n_action_id_pb.ParseFromString(n.action_id_bytes) raise flyte.errors.RuntimeSystemError( "RuntimeError", - f"Task {n.action_id.name} did not return an output path, but the task has outputs defined.", + f"Task {n_action_id_pb.name} did not return an output path, but the task has outputs defined.", ) return await load_and_convert_outputs(native_interface, n.realized_outputs_uri, _task.max_inline_io_bytes) return None From 2d5d693128e005e1d113e9d92fd83e44729bd4cd Mon Sep 17 00:00:00 2001 From: Yee Hing Tong Date: Sat, 13 Dec 2025 00:07:22 +0800 Subject: [PATCH 11/22] ignore - pr into pr (#423) * Split out controller into core.rs and lib.rs. core.rs should be pure Rust and not depend on Python at all. * Update cargo.toml with two features - one that turns on pyo3/auto-initialize for building Rust crates (rlib), which allows us to cargo run binary files, and one that turns on pyo3/extension-module, which lets the manylinux image build wheels. * Temporarily adding in local flyte v2 idl - will remove next pr. * Update Phase -> ActionPhase. 
* Move the test functions in the Python BaseController class into separate test binaries now that we have rlib. (can delete these in the future) * Add an informer cache to support multiple root runs. * Hook up finalize parent action, which exposed the fact that Action::merge_from_submit was not being called, added. * Move controller errors out to a new file. * Update readme for setup docs a bit. --------- Signed-off-by: Yee Hing Tong --- .gitignore | 1 + README.md | 39 +- examples/advanced/hybrid_mode.py | 2 +- rs_controller/.cargo/config.toml | 3 - rs_controller/AUTH_IMPLEMENTATION.md | 130 --- rs_controller/Cargo.lock | 6 +- rs_controller/Cargo.toml | 26 +- rs_controller/Makefile | 1 + rs_controller/pyproject.toml | 8 +- rs_controller/src/action.rs | 58 +- rs_controller/src/bin/test_controller.rs | 39 + rs_controller/src/bin/try_list_tasks.rs | 83 ++ rs_controller/src/bin/try_watch.rs | 191 ++++ rs_controller/src/core.rs | 617 ++++++++++++ rs_controller/src/error.rs | 39 + rs_controller/src/informer.rs | 423 +++++--- rs_controller/src/lib.rs | 915 ++---------------- rs_controller/uv.lock | 8 + src/flyte/_internal/controllers/__init__.py | 17 +- .../controllers/remote/_r_controller.py | 15 +- 20 files changed, 1469 insertions(+), 1152 deletions(-) delete mode 100644 rs_controller/.cargo/config.toml delete mode 100644 rs_controller/AUTH_IMPLEMENTATION.md create mode 100644 rs_controller/src/bin/test_controller.rs create mode 100644 rs_controller/src/bin/try_list_tasks.rs create mode 100644 rs_controller/src/bin/try_watch.rs create mode 100644 rs_controller/src/core.rs create mode 100644 rs_controller/src/error.rs create mode 100644 rs_controller/uv.lock diff --git a/.gitignore b/.gitignore index 1acc3e124..3f8488175 100644 --- a/.gitignore +++ b/.gitignore @@ -190,3 +190,4 @@ examples/remote_example.py config.yaml .claude/ tests/flyte/internal/bin/flyte-inputs.json +.cargo/ diff --git a/README.md b/README.md index 41e85914b..125ff6bf5 100644 --- a/README.md +++ 
b/README.md @@ -329,9 +329,9 @@ The following instructions are for helping to build the default multi-arch image `cd` into `rs_controller` and run `make build-builders`. This will build the builder images once, so you can keep using them as the rust code changes. ### Iteration Cycle -Make sure you have `CLOUD_REPO=/Users//go/src/github.com/unionai/cloud` exported and checked out to a branch that has the latest prost generated code. Delete this comment and update make target in the future if it gets merged/published. - -Then run `make build-wheels`. +Run `make build-wheels` to actually build the multi-arch wheels. This command should probably be updated to build all three, +currently it only builds for linux/amd64 and linux/arm64... the `make build-wheel-local` command builds a macosx wheel, +unclear what the difference is between that and the arm64 one, and unclear if both are present, which one pip chooses. `cd` back up to the root folder of this project and proceed with ```bash @@ -341,6 +341,37 @@ python maint_tools/build_default_image.py To install the wheel locally for testing, use the following command with your venv active. ```bash -uv pip install --find-links ./rs_controller/dist --no-index --force-reinstall flyte_controller_base +uv pip install --find-links ./rs_controller/dist --no-index --force-reinstall --no-deps flyte_controller_base ``` Repeat this process to iterate - build new wheels, force reinstall the controller package. + +### Build Configuration Summary + +In order to support both Rust crate publication and Python wheel distribution, we have +to sometimes use and sometimes not use the 'pyo3/extension-module' feature. To do this, this +project's Cargo.toml itself can toggle this on and off. 
+ + [features] + default = ["pyo3/auto-initialize"] # For Rust crate users (links to libpython) + extension-module = ["pyo3/extension-module"] # For Python wheels (no libpython linking) + +The cargo file contains + + # Cargo.toml + [lib] + crate-type = ["rlib", "cdylib"] # Support both Rust and Python usage + +When using 'default', 'auto-initialize' is turned on, which requires linking to libpython, which exists on local Mac so +this works nicely. It is not available in manylinux however, so trying to build with this feature in a manylinux docker +image will fail. But that's okay, because the purpose of the manylinux container is to build wheels, +and for wheels, we need the 'extension-module' feature, which disables linking to libpython. + +The key insight: auto-initialize is for embedding Python in Rust (needs libpython), while +extension-module is for extending Python with Rust (must NOT link libpython for portability). + +This setup makes it possible to build wheels and also run Rust binaries with `cargo run --bin`. + +(not sure if this is needed) + # pyproject.toml + [tool.maturin] + features = ["extension-module"] # Tells maturin to use extension-module feature diff --git a/examples/advanced/hybrid_mode.py b/examples/advanced/hybrid_mode.py index 73b7b579e..db8af829e 100644 --- a/examples/advanced/hybrid_mode.py +++ b/examples/advanced/hybrid_mode.py @@ -63,7 +63,7 @@ async def hybrid_parent_placeholder(): flyte.init_from_config("/Users/ytong/.flyte/config-k3d.yaml", root_dir=repo_root, storage=s3_sandbox) # Kick off a run of hybrid_parent_placeholder and fill in with kicked off things. 
- run_name = "rt26xx54p886brkhcns2" + run_name = "r9sfvk6plj7gld7fds6f" outputs = flyte.with_runcontext( mode="hybrid", name=run_name, diff --git a/rs_controller/.cargo/config.toml b/rs_controller/.cargo/config.toml deleted file mode 100644 index b39600ead..000000000 --- a/rs_controller/.cargo/config.toml +++ /dev/null @@ -1,3 +0,0 @@ -# .cargo/config.toml -[build] -rustflags = ["-L", "/opt/homebrew/opt/python@3.12/Frameworks/Python.framework/Versions/3.12/lib"] \ No newline at end of file diff --git a/rs_controller/AUTH_IMPLEMENTATION.md b/rs_controller/AUTH_IMPLEMENTATION.md deleted file mode 100644 index eff0c8a4c..000000000 --- a/rs_controller/AUTH_IMPLEMENTATION.md +++ /dev/null @@ -1,130 +0,0 @@ -# Rust Authentication Implementation for Flyte - -## Summary - -I've implemented client credentials OAuth2 authentication for the Rust gRPC clients, modeled after the Python implementation. The implementation includes: - -1. **Auth Module Structure** (`src/auth/`) - - `config.rs` - Auth configuration and helper traits - - `token_client.rs` - OAuth2 token retrieval logic - - `client_credentials.rs` - Client credentials authenticator with token caching - - `interceptor.rs` - gRPC interceptor for adding auth headers and handling 401s - -2. **Proto Module** (`src/proto/`) - - Organized generated protobuf files from v1 Flyte IDL - - Includes `AuthMetadataService` for fetching OAuth2 metadata - -3. **Key Features** - - Automatic token fetching on first request - - Token caching with expiration tracking - - Automatic refresh on 401/Unauthenticated errors - - Thread-safe credential management using RwLock - - Retry logic with automatic credential refresh - -## Current Status - -**Implemented but not fully compiling** - There are compilation issues with the generated proto files: - -1. Some proto files have `#[derive(Copy)]` on structs with non-Copy fields (String) -2. There may be missing prost-types features needed for Timestamp handling -3. 
Some module visibility issues to resolve - -## How It Works - -### Authentication Flow - -``` -1. Client creates AuthConfig with endpoint, client_id, client_secret -2. ClientCredentialsAuthenticator is created -3. On first gRPC call: - a. Authenticator fetches OAuth2 metadata from AuthMetadataService - b. Calls token endpoint with client credentials - c. Caches the access token with expiration time -4. AuthInterceptor adds "Bearer {token}" to request metadata -5. If request returns 401: - a. Interceptor triggers credential refresh - b. Retries the request with new token -``` - -### Usage Example - -```rust -use flyte_controller_base::auth::{AuthConfig, AuthInterceptor, ClientCredentialsAuthenticator}; - -// Create auth config -let auth_config = AuthConfig { - endpoint: "dns:///flyte.example.com:443".to_string(), - client_id: "your_client_id".to_string(), - client_secret: "your_secret".to_string(), - scopes: None, - audience: None, -}; - -// Create authenticator -let authenticator = Arc::new(ClientCredentialsAuthenticator::new(auth_config)); - -// Connect to endpoint -let channel = Endpoint::from_shared(endpoint)? - .connect() - .await?; - -// Create auth interceptor -let auth_interceptor = AuthInterceptor::new(authenticator, channel.clone()); - -// Make authenticated calls using the with_auth! 
macro -let response = with_auth!( - auth_interceptor, - my_client, - my_method, - request -)?; -``` - -## Files Created/Modified - -### New Files -- `src/auth/mod.rs` - Auth module exports -- `src/auth/config.rs` - Configuration types -- `src/auth/token_client.rs` - OAuth2 token client -- `src/auth/client_credentials.rs` - Authenticator implementation -- `src/auth/interceptor.rs` - gRPC interceptor with retry logic -- `src/proto/mod.rs` - Proto module organization -- `src/lib_auth.rs` - Re-exports for external use -- `examples/simple_auth_test.rs` - Test script -- `examples/auth_test.rs` - Example with actual API calls - -### Modified Files -- `Cargo.toml` - Added dependencies (reqwest, serde, base64, urlencoding) -- `src/lib.rs` - Added auth and proto modules - -## Next Steps to Fix Compilation - -1. **Fix proto file issues:** - - Remove `Copy` derives from structs with String fields in generated files - - OR regenerate the proto files with correct options - - OR use only the minimal auth-related protos - -2. **Check prost-types dependency:** - ```toml - prost-types = { version = "0.12", features = ["std"] } - ``` - May need to match the prost version exactly. - -3. 
**Simplest fix:** Extract just the `AuthMetadataService` related types into a minimal hand-written proto module (I started this in `src/proto/auth_service.rs`) - -## Testing - -Once compilation is fixed, test with: - -```bash -FLYTE_ENDPOINT=dns:///your-endpoint:443 \ -FLYTE_CLIENT_ID=your_id \ -FLYTE_CLIENT_SECRET=your_secret \ -cargo run --example simple_auth_test -``` - -## References - -- Python implementation: `/Users/ytong/go/src/github.com/flyteorg/flyte-sdk/src/flyte/remote/_client/auth/` -- Flytekit PR #2416: https://github.com/flyteorg/flytekit/pull/2416/files -- Proto definitions: https://github.com/flyteorg/flyte/tree/v2/flyteidl2 diff --git a/rs_controller/Cargo.lock b/rs_controller/Cargo.lock index f4016dffd..d72a2f3ac 100644 --- a/rs_controller/Cargo.lock +++ b/rs_controller/Cargo.lock @@ -301,12 +301,12 @@ dependencies = [ "prost-types 0.13.5", "pyo3", "pyo3-async-runtimes", - "pyo3-build-config", "reqwest", "serde", "serde_json", "thiserror 1.0.69", "tokio", + "tokio-util", "tonic", "tower 0.4.13", "tower-http 0.5.2", @@ -318,9 +318,7 @@ dependencies = [ [[package]] name = "flyteidl2" -version = "2.0.0-alpha14" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5da488fa330fc42fc4c2daa03f10e8bcbc69d302673036952e7f56c92718d13e" +version = "0.1.0" dependencies = [ "async-trait", "futures", diff --git a/rs_controller/Cargo.toml b/rs_controller/Cargo.toml index 7be07429a..c851c35eb 100644 --- a/rs_controller/Cargo.toml +++ b/rs_controller/Cargo.toml @@ -6,13 +6,22 @@ edition = "2021" [lib] name = "flyte_controller_base" path = "src/lib.rs" -crate-type = ["cdylib", "rlib"] +# Default to rlib for Rust users and binaries +# Maturin will override to ["cdylib"] when building wheels +crate-type = ["rlib", "cdylib"] + +[features] +# Default features for Rust crate users (includes auto-initialize for binaries) +default = ["pyo3/auto-initialize"] +# Extension module feature for Python wheels (no auto-initialize, no linking) 
+extension-module = ["pyo3/extension-module"] [dependencies] -# cloudidl = { path = "../../cloud/gen/pb_rust" } -pyo3 = { version = "0.24", features = ["extension-module", "abi3-py310"] } +# Use abi3 for stable API, auto-initialize only when not building extension-module +pyo3 = { version = "0.24", features = ["abi3-py310"] } pyo3-async-runtimes = { version = "0.24", features = ["tokio-runtime"] } tokio = { version = "1.0", features = ["full"] } +tokio-util = "0.7" tonic = { version = "0.12", features = ["tls", "tls-native-roots"] } prost = { version = "0.13", features = ["std"] } prost-types = { version = "0.13", features = ["std"] } @@ -23,8 +32,9 @@ tracing = "0.1" tracing-subscriber = "0.3" async-trait = "0.1" thiserror = "1.0" -pyo3-build-config = "0.24.2" -flyteidl2 = "=2.0.0-alpha14" +# Using local flyteidl2 without extension-module +flyteidl2 = { path = "/Users/ytong/go/src/github.com/flyteorg/flyte/gen/rust" } +# flyteidl2 = "=2.0.0-alpha14" reqwest = { version = "0.12", features = ["json", "rustls-tls"] } serde = { version = "1.0", features = ["derive"] } serde_json = "1.0" @@ -34,9 +44,3 @@ http = "1.0" urlencoding = "2.1" bytes = "1.0" http-body-util = "0.1" - -[build-dependencies] -pyo3 = { version = "0.24", features = ["extension-module", "abi3-py310"] } - -# Need to do this for some reason otherwise maturin develop fails horribly -# export RUSTFLAGS="-C link-arg=-undefined -C link-arg=dynamic_lookup" diff --git a/rs_controller/Makefile b/rs_controller/Makefile index ee3ed9401..7864bf29f 100644 --- a/rs_controller/Makefile +++ b/rs_controller/Makefile @@ -27,6 +27,7 @@ docker run --rm \ -v $(PWD):/io \ -v $(PWD)/docker_cargo_cache:/root/.cargo/registry \ -v $(CLOUD_REPO):/cloud \ + -v /Users/ytong/go/src/github.com/flyteorg/flyte/gen/rust:/Users/ytong/go/src/github.com/flyteorg/flyte/gen/rust \ wheel-builder:$(1) /bin/bash -c "\ cd /io; \ sed -i 's/^version = .*/version = \"$(VERSION)\"/' pyproject.toml; \ diff --git a/rs_controller/pyproject.toml 
b/rs_controller/pyproject.toml index f572ca789..37ef3b58b 100644 --- a/rs_controller/pyproject.toml +++ b/rs_controller/pyproject.toml @@ -4,14 +4,16 @@ build-backend = "maturin" [project] name = "flyte_controller_base" -version = "2.0.0b33.dev33+g3d028ba" +version = "2.0.0b33.dev35+g4f8cb912.dirty" description = "Rust controller for Union" requires-python = ">=3.10" classifiers = ["Programming Language :: Python", "Programming Language :: Rust"] - +dependencies = [ + "flyteidl2>=2.0.0a14", +] [tool.maturin] module-name = "flyte_controller_base" -features = ["pyo3/extension-module"] +features = ["extension-module"] [tool.ruff] line-length = 120 diff --git a/rs_controller/src/action.rs b/rs_controller/src/action.rs index f2e9095d4..4f207d4ee 100644 --- a/rs_controller/src/action.rs +++ b/rs_controller/src/action.rs @@ -2,13 +2,12 @@ use flyteidl2::google::protobuf::Timestamp; use prost::Message; use pyo3::prelude::*; -use flyteidl2::flyteidl::common::ActionIdentifier; +use flyteidl2::flyteidl::common::{ActionIdentifier, RunIdentifier}; use flyteidl2::flyteidl::workflow::{ActionUpdate, TraceAction}; +use flyteidl2::flyteidl::common::ActionPhase; use flyteidl2::flyteidl::core::{ExecutionError, TypedInterface}; use flyteidl2::flyteidl::task::{OutputReferences, TaskSpec, TraceSpec}; -use flyteidl2::flyteidl::workflow::Phase; - use tracing::debug; #[pyclass(eq, eq_int)] @@ -31,7 +30,7 @@ pub struct Action { pub run_output_base: Option, pub realized_outputs_uri: Option, pub err: Option, - pub phase: Option, + pub phase: Option, pub started: bool, pub retries: u32, pub client_err: Option, // Changed from PyErr to String for serializability @@ -42,10 +41,36 @@ pub struct Action { impl Action { pub fn get_run_name(&self) -> String { - match self.action_id.run.clone() { - Some(run_id) => run_id.name, - None => String::from("missing run name"), - } + let run_name = self + .action_id + .run + .as_ref() + .expect("Action ID missing run") + .name + .clone(); + 
assert!(!run_name.is_empty()); + run_name + } + + pub fn get_run_identifier(&self) -> RunIdentifier { + self.action_id + .run + .as_ref() + .expect("Action ID missing run") + .clone() + } + + pub fn get_full_name(&self) -> String { + format!( + "{}:{}", + &self + .action_id + .run + .as_ref() + .expect("Action ID missing run") + .name, + self.action_id.name + ) } pub fn get_action_name(&self) -> String { @@ -63,7 +88,7 @@ impl Action { pub fn mark_cancelled(&mut self) { debug!("Marking action {:?} as cancelled", self.action_id); self.mark_started(); - self.phase = Some(Phase::Aborted); + self.phase = Some(ActionPhase::Aborted); } pub fn mark_started(&mut self) { @@ -73,7 +98,7 @@ impl Action { } pub fn merge_update(&mut self, obj: &ActionUpdate) { - if let Ok(new_phase) = Phase::try_from(obj.phase) { + if let Ok(new_phase) = ActionPhase::try_from(obj.phase) { if self.phase.is_none() || self.phase != Some(new_phase) { self.phase = Some(new_phase); if obj.error.is_some() { @@ -89,7 +114,7 @@ impl Action { pub fn new_from_update(parent_action_name: String, obj: ActionUpdate) -> Self { let action_id = obj.action_id.unwrap(); - let phase = Phase::try_from(obj.phase).unwrap(); + let phase = ActionPhase::try_from(obj.phase).unwrap(); Action { action_id: action_id.clone(), parent_action_name, @@ -115,7 +140,10 @@ impl Action { if let Some(phase) = &self.phase { matches!( phase, - Phase::Succeeded | Phase::Failed | Phase::Aborted | Phase::TimedOut + ActionPhase::Succeeded + | ActionPhase::Failed + | ActionPhase::Aborted + | ActionPhase::TimedOut ) } else { false @@ -177,7 +205,7 @@ impl Action { run_output_base: Some(run_output_base), realized_outputs_uri: None, err: None, - phase: Some(Phase::Unspecified), + phase: Some(ActionPhase::Unspecified), started: false, retries: 0, client_err: None, @@ -243,7 +271,7 @@ impl Action { let trace_action = TraceAction { name: friendly_name.clone(), - phase: Phase::Succeeded.into(), + phase: ActionPhase::Succeeded.into(), start_time: 
Some(Timestamp { seconds: start_secs, nanos: start_nanos, @@ -266,7 +294,7 @@ impl Action { inputs_uri: Some(inputs_uri), run_output_base: Some(run_output_base), realized_outputs_uri: Some(outputs_uri), - phase: Phase::Succeeded.into(), + phase: ActionPhase::Succeeded.into(), err: None, started: true, retries: 0, diff --git a/rs_controller/src/bin/test_controller.rs b/rs_controller/src/bin/test_controller.rs new file mode 100644 index 000000000..6672ffa00 --- /dev/null +++ b/rs_controller/src/bin/test_controller.rs @@ -0,0 +1,39 @@ +/// Usage: +/// _UNION_EAGER_API_KEY=your_api_key cargo run --bin test_controller +/// +/// Or without auth: +/// cargo run --bin test_controller -- http://localhost:8089 +use flyte_controller_base::core::CoreBaseController; +use std::env; +use tracing_subscriber; + +fn main() -> Result<(), Box> { + tracing_subscriber::fmt() + .with_max_level(tracing::Level::INFO) + .init(); + + println!("=== Flyte Core Controller Test ===\n"); + + // Try to create a controller + let controller = if let Ok(api_key) = env::var("_UNION_EAGER_API_KEY") { + println!("Using auth from _UNION_EAGER_API_KEY"); + // Set the env var back since CoreBaseController::new_with_auth reads it + env::set_var("_UNION_EAGER_API_KEY", api_key); + CoreBaseController::new_with_auth()? + } else { + let endpoint = env::args() + .nth(1) + .unwrap_or_else(|| "http://localhost:8090".to_string()); + println!("Using endpoint: {}", endpoint); + CoreBaseController::new_without_auth(endpoint)? 
+ }; + + println!("✓ Successfully created CoreBaseController!"); + println!("✓ This proves that:"); + println!(" - The core module is accessible from binaries"); + println!(" - The Action type (with #[pyclass]) can be used"); + println!(" - No PyO3 linking errors occur"); + println!("\n=== Test Complete ==="); + + Ok(()) +} diff --git a/rs_controller/src/bin/try_list_tasks.rs b/rs_controller/src/bin/try_list_tasks.rs new file mode 100644 index 000000000..b9a5483db --- /dev/null +++ b/rs_controller/src/bin/try_list_tasks.rs @@ -0,0 +1,83 @@ +/// Test binary to list tasks from the Flyte API +/// +/// Usage: +/// _UNION_EAGER_API_KEY=your_api_key cargo run --bin try_list_tasks +use std::sync::Arc; +use tower::ServiceBuilder; +use tracing::warn; + +use flyte_controller_base::auth::{AuthConfig, AuthLayer, ClientCredentialsAuthenticator}; +use flyte_controller_base::error::ControllerError; + +use flyteidl2::flyteidl::common::{ListRequest, ProjectIdentifier}; +use flyteidl2::flyteidl::task::task_service_client::TaskServiceClient; +use flyteidl2::flyteidl::task::{list_tasks_request, ListTasksRequest}; +use tonic::Code; + +#[tokio::main] +async fn main() -> Result<(), Box> { + tracing_subscriber::fmt() + .with_max_level(tracing::Level::INFO) + .init(); + + let api_key = std::env::var("_UNION_EAGER_API_KEY").unwrap_or_else(|_| { + warn!("_UNION_EAGER_API_KEY env var not set, using empty string"); + String::new() + }); + + let auth_config = AuthConfig::new_from_api_key(api_key.as_str())?; + let endpoint = auth_config.endpoint.clone(); + let static_endpoint = endpoint.clone().leak(); + // Strip "https://" (8 chars) to get just the hostname for TLS config + let domain = endpoint.strip_prefix("https://").ok_or_else(|| { + ControllerError::SystemError("Endpoint must start with https://".to_string()) + })?; + let endpoint = + flyte_controller_base::core::create_tls_endpoint(static_endpoint, domain).await?; + let channel = endpoint.connect().await.map_err(ControllerError::from)?; + 
+ let authenticator = Arc::new(ClientCredentialsAuthenticator::new(auth_config)); + + let auth_handling_channel = ServiceBuilder::new() + .layer(AuthLayer::new(authenticator, channel.clone())) + .service(channel); + + let mut task_client = TaskServiceClient::new(auth_handling_channel); + + let list_request_base = ListRequest { + limit: 100, + ..Default::default() + }; + + let req = ListTasksRequest { + request: Some(list_request_base), + known_filters: vec![], + scope_by: Some(list_tasks_request::ScopeBy::ProjectId(ProjectIdentifier { + organization: "demo".to_string(), + domain: "development".to_string(), + name: "flytesnacks".to_string(), + })), + }; + + let mut attempts = 0; + let final_result: Result = loop { + let result = task_client.list_tasks(req.clone()).await; + match result { + Ok(response) => { + println!("Success: {:?}", response.into_inner()); + break Ok(true); + } + Err(status) if status.code() == Code::Unauthenticated && attempts < 1 => { + attempts += 1; + continue; + } + Err(status) => { + eprintln!("Error calling gRPC: {}", status); + break Err(format!("gRPC error: {}", status).into()); + } + } + }; + warn!("Finished try_list_tasks with result {:?}", final_result); + final_result?; + Ok(()) +} diff --git a/rs_controller/src/bin/try_watch.rs b/rs_controller/src/bin/try_watch.rs new file mode 100644 index 000000000..9d6a282ce --- /dev/null +++ b/rs_controller/src/bin/try_watch.rs @@ -0,0 +1,191 @@ +/// Test binary to watch action updates from the Flyte API +/// +/// Usage: +/// _UNION_EAGER_API_KEY=your_api_key cargo run --bin try_watch +use std::sync::Arc; +use std::time::Duration; +use tokio::time::sleep; +use tower::ServiceBuilder; +use tracing::{error, info, warn}; + +use flyte_controller_base::auth::{AuthConfig, AuthLayer, ClientCredentialsAuthenticator}; +use flyte_controller_base::error::ControllerError; + +use flyteidl2::flyteidl::common::{ActionIdentifier, RunIdentifier}; +use 
flyteidl2::flyteidl::workflow::state_service_client::StateServiceClient; +use flyteidl2::flyteidl::workflow::watch_request::Filter; +use flyteidl2::flyteidl::workflow::WatchRequest; + +#[tokio::main] +async fn main() -> Result<(), Box> { + tracing_subscriber::fmt() + .with_max_level(tracing::Level::INFO) + .init(); + + info!("Starting watch example with authentication and retry..."); + + // Read in the api key which gives us the endpoint to connect to as well as the credentials + let api_key = std::env::var("_UNION_EAGER_API_KEY").unwrap_or_else(|_| { + warn!("_UNION_EAGER_API_KEY env var not set, using empty string"); + String::new() + }); + + let auth_config = AuthConfig::new_from_api_key(api_key.as_str())?; + let endpoint = auth_config.endpoint.clone(); + let static_endpoint = endpoint.clone().leak(); + // Strip "https://" (8 chars) to get just the hostname for TLS config + let domain = endpoint.strip_prefix("https://").ok_or_else(|| { + ControllerError::SystemError("Endpoint must start with https://".to_string()) + })?; + let endpoint = + flyte_controller_base::core::create_tls_endpoint(static_endpoint, domain).await?; + let channel = endpoint.connect().await.map_err(ControllerError::from)?; + + let authenticator = Arc::new(ClientCredentialsAuthenticator::new(auth_config)); + + // Wrap channel with auth layer - ALL calls now automatically authenticated! 
+ let auth_channel = ServiceBuilder::new() + .layer(AuthLayer::new(authenticator, channel.clone())) + .service(channel); + + let mut client = StateServiceClient::new(auth_channel); + + // Watch configuration (matching Python example) + let run_id = RunIdentifier { + org: "demo".to_string(), + project: "flytesnacks".to_string(), + domain: "development".to_string(), + name: "rz8bf5zksgxsmrzcrkx4".to_string(), + }; + let parent_action_name = "a0".to_string(); + + // Retry parameters (matching Python defaults) + let min_watch_backoff = Duration::from_secs(1); + let max_watch_backoff = Duration::from_secs(30); + let max_watch_retries = 10; + + // Watch loop with retry logic (following Python _informer.py pattern) + let mut retries = 0; + let mut message_count = 0; + + while retries < max_watch_retries { + if retries >= 1 { + warn!("Watch retrying, attempt {}/{}", retries, max_watch_retries); + } + + // Create watch request + let request = WatchRequest { + filter: Some(Filter::ParentActionId(ActionIdentifier { + name: parent_action_name.clone(), + run: Some(run_id.clone()), + })), + }; + + // Establish the watch stream + // The outer retry loop handles failures, middleware handles auth refresh + let stream_result = client.watch(request.clone()).await; + + match stream_result { + Ok(response) => { + info!("Successfully established watch stream"); + let mut stream = response.into_inner(); + + // Process messages from the stream + loop { + match stream.message().await { + Ok(Some(watch_response)) => { + // Successfully received a message - reset retry counter + retries = 0; + message_count += 1; + + // Process the message (enum with ActionUpdate or ControlMessage) + use flyteidl2::flyteidl::workflow::watch_response::Message; + match &watch_response.message { + Some(Message::ControlMessage(control_msg)) => { + if control_msg.sentinel { + info!( + "Received Sentinel for parent action: {}", + parent_action_name + ); + } + } + Some(Message::ActionUpdate(action_update)) => { + 
info!( + "Received action update for: {} (phase: {:?})", + action_update + .action_id + .as_ref() + .map(|id| id.name.as_str()) + .unwrap_or("unknown"), + action_update.phase + ); + + if !action_update.output_uri.is_empty() { + info!("Output URI: {}", action_update.output_uri); + } + + if action_update.phase == 4 { + // PHASE_FAILED + if action_update.error.is_some() { + error!( + "Action failed with error: {:?}", + action_update.error + ); + } + } + } + None => { + warn!("Received empty watch response"); + } + } + + // For demo purposes, exit after receiving a few messages + if message_count >= 50 { + info!("Received {} messages, exiting demo", message_count); + return Ok(()); + } + } + Ok(None) => { + warn!("Watch stream ended gracefully"); + break; // Stream ended, retry + } + Err(status) => { + error!("Error receiving message from watch stream: {}", status); + + // Check if it's an auth error + if status.code() == tonic::Code::Unauthenticated { + warn!("Unauthenticated error - credentials will be refreshed on retry"); + } + + break; // Break inner loop to retry + } + } + } + } + Err(status) => { + error!("Failed to establish watch stream: {}", status); + + if status.code() == tonic::Code::Unauthenticated { + warn!("Unauthenticated error - credentials will be refreshed on retry"); + } + } + } + + // Increment retry counter and apply exponential backoff + retries += 1; + if retries < max_watch_retries { + let backoff = min_watch_backoff + .saturating_mul(2_u32.pow(retries as u32)) + .min(max_watch_backoff); + warn!("Watch failed, retrying in {:?}...", backoff); + sleep(backoff).await; + } + } + + // Exceeded max retries + error!( + "Watch failure retries crossed threshold {}/{}, exiting!", + retries, max_watch_retries + ); + Err(format!("Max watch retries ({}) exceeded", max_watch_retries).into()) +} diff --git a/rs_controller/src/core.rs b/rs_controller/src/core.rs new file mode 100644 index 000000000..417640cb4 --- /dev/null +++ b/rs_controller/src/core.rs @@ 
-0,0 +1,617 @@ +//! Core controller implementation - Pure Rust, no PyO3 dependencies +//! This module can be used by both Python bindings and standalone Rust binaries + +use std::sync::Arc; +use std::sync::OnceLock; +use std::time::Duration; + +use pyo3_async_runtimes::tokio::get_runtime; +use tokio::sync::mpsc; +use tokio::sync::oneshot; +use tokio::sync::OnceCell; +use tokio::time::sleep; +use tonic::transport::{Certificate, ClientTlsConfig, Endpoint}; +use tonic::Status; +use tower::ServiceBuilder; +use tracing::{debug, error, info, warn}; + +use crate::action::{Action, ActionType}; +use crate::auth::{AuthConfig, AuthLayer, ClientCredentialsAuthenticator}; +use crate::error::{ControllerError, InformerError}; +use crate::informer::{Informer, InformerCache}; +use flyteidl2::flyteidl::common::{ActionIdentifier, RunIdentifier}; +use flyteidl2::flyteidl::task::TaskIdentifier; +use flyteidl2::flyteidl::workflow::enqueue_action_request; +use flyteidl2::flyteidl::workflow::queue_service_client::QueueServiceClient; +use flyteidl2::flyteidl::workflow::state_service_client::StateServiceClient; +use flyteidl2::flyteidl::workflow::{ + EnqueueActionRequest, EnqueueActionResponse, TaskAction, WatchRequest, WatchResponse, +}; +use flyteidl2::google; +use google::protobuf::StringValue; + +// Fetches Amazon root CA certificate from Amazon Trust Services +pub async fn fetch_amazon_root_ca() -> Result { + // Amazon Root CA 1 - the main root used by AWS services + let url = "https://www.amazontrust.com/repository/AmazonRootCA1.pem"; + + let response = reqwest::get(url) + .await + .map_err(|e| ControllerError::SystemError(format!("Failed to fetch certificate: {}", e)))?; + + let cert_pem = response + .text() + .await + .map_err(|e| ControllerError::SystemError(format!("Failed to read certificate: {}", e)))?; + + Ok(Certificate::from_pem(cert_pem)) +} + +// Helper to create TLS-configured endpoint with Amazon CA certificate +// todo: when we resolve the pem issue, also remove the need 
to have both inputs which are basically the same +pub async fn create_tls_endpoint( + url: &'static str, + domain: &str, +) -> Result { + // Fetch Amazon root CA dynamically + let cert = fetch_amazon_root_ca().await?; + + let tls_config = ClientTlsConfig::new() + .domain_name(domain) + .ca_certificate(cert); + + let endpoint = Endpoint::from_static(url) + .tls_config(tls_config) + .map_err(|e| ControllerError::SystemError(format!("TLS config error: {}", e)))? + .keep_alive_while_idle(true); + + Ok(endpoint) +} + +enum ChannelType { + Plain(tonic::transport::Channel), + Authenticated(crate::auth::AuthService), +} + +#[derive(Clone, Debug)] +pub enum StateClient { + Plain(StateServiceClient), + Authenticated(StateServiceClient>), +} + +impl StateClient { + pub async fn watch( + &mut self, + request: impl tonic::IntoRequest, + ) -> Result>, tonic::Status> { + match self { + StateClient::Plain(client) => client.watch(request).await, + StateClient::Authenticated(client) => client.watch(request).await, + } + } +} + +#[derive(Clone, Debug)] +pub enum QueueClient { + Plain(QueueServiceClient), + Authenticated(QueueServiceClient>), +} + +impl QueueClient { + pub async fn enqueue_action( + &mut self, + request: impl tonic::IntoRequest, + ) -> Result, tonic::Status> { + match self { + QueueClient::Plain(client) => client.enqueue_action(request).await, + QueueClient::Authenticated(client) => client.enqueue_action(request).await, + } + } +} + +pub struct CoreBaseController { + channel: ChannelType, + informer_cache: InformerCache, + state_client: StateClient, + queue_client: QueueClient, + shared_queue: mpsc::Sender, + shared_queue_rx: Arc>>, + failure_rx: mpsc::Receiver, +} + +impl CoreBaseController { + pub fn new_with_auth() -> Result, ControllerError> { + info!("Creating CoreBaseController from _UNION_EAGER_API_KEY env var (with auth)"); + // Read from env var and use auth + let api_key = std::env::var("_UNION_EAGER_API_KEY").map_err(|_| { + ControllerError::SystemError( + 
"_UNION_EAGER_API_KEY env var must be provided".to_string(), + ) + })?; + let auth_config = AuthConfig::new_from_api_key(&api_key)?; + let endpoint_url = auth_config.endpoint.clone(); + + let endpoint_static: &'static str = + Box::leak(Box::new(endpoint_url.clone().into_boxed_str())); + // shared queue + let (shared_tx, shared_queue_rx) = mpsc::channel::(64); + + let rt = get_runtime(); + let channel = rt.block_on(async { + // todo: escape hatch for localhost + // Strip "https://" to get just the hostname for TLS config + let domain = endpoint_url.strip_prefix("https://").ok_or_else(|| { + ControllerError::SystemError( + "Endpoint must start with https:// when using auth".to_string(), + ) + })?; + + // Create TLS-configured endpoint + let endpoint = create_tls_endpoint(endpoint_static, domain).await?; + let channel = endpoint.connect().await.map_err(ControllerError::from)?; + + let authenticator = Arc::new(ClientCredentialsAuthenticator::new(auth_config.clone())); + let auth_channel = ServiceBuilder::new() + .layer(AuthLayer::new(authenticator, channel.clone())) + .service(channel); + + Ok::<_, ControllerError>(ChannelType::Authenticated(auth_channel)) + })?; + + let (failure_tx, failure_rx) = mpsc::channel::(10); + + let state_client = match &channel { + ChannelType::Plain(ch) => StateClient::Plain(StateServiceClient::new(ch.clone())), + ChannelType::Authenticated(ch) => { + StateClient::Authenticated(StateServiceClient::new(ch.clone())) + } + }; + + let queue_client = match &channel { + ChannelType::Plain(ch) => QueueClient::Plain(QueueServiceClient::new(ch.clone())), + ChannelType::Authenticated(ch) => { + QueueClient::Authenticated(QueueServiceClient::new(ch.clone())) + } + }; + + let informer_cache = + InformerCache::new(state_client.clone(), shared_tx.clone(), failure_tx); + + let real_base_controller = CoreBaseController { + channel, + informer_cache, + state_client, + queue_client, + shared_queue: shared_tx, + shared_queue_rx: 
Arc::new(tokio::sync::Mutex::new(shared_queue_rx)), + failure_rx, + }; + + let real_base_controller = Arc::new(real_base_controller); + // Start the background worker + let controller_clone = real_base_controller.clone(); + rt.spawn(async move { + controller_clone.bg_worker().await; + }); + Ok(real_base_controller) + } + + pub fn new_without_auth(endpoint: String) -> Result, ControllerError> { + let endpoint_static: &'static str = Box::leak(Box::new(endpoint.clone().into_boxed_str())); + // shared queue + let (shared_tx, shared_queue_rx) = mpsc::channel::(64); + + let rt = get_runtime(); + let channel = rt.block_on(async { + let chan = if endpoint.starts_with("http://") { + let endpoint = Endpoint::from_static(endpoint_static) + .keep_alive_while_idle(true); + endpoint.connect().await.map_err(ControllerError::from)? + } else if endpoint.starts_with("https://") { + // Strip "https://" to get just the hostname for TLS config + let domain = endpoint.strip_prefix("https://").ok_or_else(|| { + ControllerError::SystemError("Endpoint must start with https://".to_string()) + })?; + + // Create TLS-configured endpoint + let endpoint = create_tls_endpoint(endpoint_static, domain).await?; + endpoint.connect().await.map_err(ControllerError::from)? 
+ } else { + return Err(ControllerError::SystemError(format!( + "Malformed endpoint {}", + endpoint + ))); + }; + Ok::<_, ControllerError>(ChannelType::Plain(chan)) + })?; + + let (failure_tx, failure_rx) = mpsc::channel::(10); + + let state_client = match &channel { + ChannelType::Plain(ch) => StateClient::Plain(StateServiceClient::new(ch.clone())), + ChannelType::Authenticated(ch) => { + StateClient::Authenticated(StateServiceClient::new(ch.clone())) + } + }; + + let queue_client = match &channel { + ChannelType::Plain(ch) => QueueClient::Plain(QueueServiceClient::new(ch.clone())), + ChannelType::Authenticated(ch) => { + QueueClient::Authenticated(QueueServiceClient::new(ch.clone())) + } + }; + + let informer_cache = + InformerCache::new(state_client.clone(), shared_tx.clone(), failure_tx); + + let real_base_controller = CoreBaseController { + channel, + informer_cache, + state_client, + queue_client, + shared_queue: shared_tx, + shared_queue_rx: Arc::new(tokio::sync::Mutex::new(shared_queue_rx)), + failure_rx, + }; + + let real_base_controller = Arc::new(real_base_controller); + // Start the background worker + let controller_clone = real_base_controller.clone(); + rt.spawn(async move { + controller_clone.bg_worker().await; + }); + Ok(real_base_controller) + } + + async fn bg_worker(&self) { + const MIN_BACKOFF_ON_ERR: Duration = Duration::from_millis(100); + const MAX_RETRIES: u32 = 5; + + debug!( + "Launching core controller background task on thread {:?}", + std::thread::current().name() + ); + loop { + // Receive actions from shared queue + let mut rx = self.shared_queue_rx.lock().await; + match rx.recv().await { + Some(mut action) => { + let run_name = &action + .action_id + .run + .as_ref() + .map_or(String::from(""), |i| i.name.clone()); + debug!( + "Controller worker processing action {}::{}", + run_name, action.action_id.name + ); + + // Drop the mutex guard before processing + drop(rx); + + match self.handle_action(&mut action).await { + Ok(_) => {} + 
// Add handling here for new slow down error + Err(e) => { + error!("Error in controller loop: {:?}", e); + // Handle backoff and retry logic + sleep(MIN_BACKOFF_ON_ERR).await; + action.retries += 1; + + if action.retries > MAX_RETRIES { + error!( + "Controller failed processing {}::{}, system retries {} crossed threshold {}", + run_name, action.action_id.name, action.retries, MAX_RETRIES + ); + action.client_err = Some(format!( + "Controller failed {}::{}, system retries {} crossed threshold {}", + run_name, action.action_id.name, action.retries, MAX_RETRIES + )); + + // Fire completion event for failed action + let opt_informer = self + .informer_cache + .get(&action.get_run_identifier(), &action.parent_action_name) + .await; + if let Some(informer) = opt_informer { + // todo: check these two errors + + // Before firing completion event, update the action in the + // informer, otherwise client_err will not be set. + let _ = informer.set_action_client_err(&action).await; + let _ = informer + .fire_completion_event(&action.action_id.name) + .await; + } else { + error!( + "Max retries hit for action but informer missing: {:?}", + action.action_id + ); + } + } else { + // Re-queue the action for retry + info!( + "Re-queuing action {}::{} for retry, attempt {}/{}", + run_name, action.action_id.name, action.retries, MAX_RETRIES + ); + if let Err(send_err) = self.shared_queue.send(action).await { + error!("Failed to re-queue action for retry: {}", send_err); + } + } + } + } + } + None => { + warn!("Shared queue channel closed, stopping bg_worker"); + break; + } + } + } + } + + async fn handle_action(&self, action: &mut Action) -> Result<(), ControllerError> { + if !action.started { + // Action not started, launch it + warn!("Action is not started, launching action {:?}", action); + self.bg_launch(action).await?; + } else if action.is_action_terminal() { + // Action is terminal, fire completion event + if let Some(arc_informer) = self + .informer_cache + 
.get(&action.get_run_identifier(), &action.parent_action_name) + .await + { + debug!( + "handle action firing completion event for {:?}", + &action.action_id.name + ); + arc_informer + .fire_completion_event(&action.action_id.name) + .await?; + } else { + error!( + "Unable to find informer to fire completion event for action: {}", + action.get_full_name(), + ); + return Err(ControllerError::BadContext(format!( + "Informer missing for action: {} while handling.", + action.get_full_name() + ))); + } + } else { + // Action still in progress + debug!("Resource {} still in progress...", action.action_id.name); + } + Ok(()) + } + + async fn bg_launch(&self, action: &Action) -> Result<(), ControllerError> { + match self.launch_task(action).await { + Ok(_) => { + debug!("Successfully launched action: {}", action.action_id.name); + Ok(()) + } + Err(e) => { + error!( + "Failed to launch action: {}, error: {}", + action.action_id.name, e + ); + Err(ControllerError::RuntimeError(format!( + "Launch failed: {}", + e + ))) + } + } + } + + pub async fn cancel_action(&self, action: &mut Action) -> Result<(), ControllerError> { + if action.is_action_terminal() { + info!( + "Action {} is already terminal, no need to cancel.", + action.action_id.name + ); + return Ok(()); + } + + debug!("Cancelling action: {}", action.action_id.name); + action.mark_cancelled(); + + if let Some(informer) = self + .informer_cache + .get(&action.get_run_identifier(), &action.parent_action_name) + .await + { + let _ = informer + .fire_completion_event(&action.action_id.name) + .await?; + } else { + debug!( + "Informer missing when trying to cancel action: {}", + action.action_id.name + ); + } + Ok(()) + } + + pub async fn get_action( + &self, + action_id: ActionIdentifier, + parent_action_name: &str, + ) -> Result { + let run = action_id + .run + .as_ref() + .ok_or(ControllerError::RuntimeError(format!( + "Action {:?} doesn't have a run, can't get action", + action_id + )))?; + if let Some(informer) = 
self.informer_cache.get(run, parent_action_name).await { + let action_name = action_id.name.clone(); + match informer.get_action(action_name).await { + Some(action) => Ok(action), + None => Err(ControllerError::RuntimeError(format!( + "Action not found getting from action_id: {:?}", + action_id + ))), + } + } else { + Err(ControllerError::BadContext( + "Informer not initialized".to_string(), + )) + } + } + + fn create_enqueue_action_request( + &self, + action: &Action, + ) -> Result { + // todo-pr: handle trace action + let task_identifier = action + .task + .as_ref() + .and_then(|task| task.task_template.as_ref()) + .and_then(|task_template| task_template.id.as_ref()) + .and_then(|core_task_id| { + Some(TaskIdentifier { + version: core_task_id.version.clone(), + org: core_task_id.org.clone(), + project: core_task_id.project.clone(), + domain: core_task_id.domain.clone(), + name: core_task_id.name.clone(), + }) + }) + .ok_or(ControllerError::RuntimeError(format!( + "TaskIdentifier missing from Action {:?}", + action + )))?; + + let input_uri = action + .inputs_uri + .clone() + .ok_or(ControllerError::RuntimeError(format!( + "Inputs URI missing from Action {:?}", + action + )))?; + let run_output_base = + action + .run_output_base + .clone() + .ok_or(ControllerError::RuntimeError(format!( + "Run output base missing from Action {:?}", + action + )))?; + let group = action.group.clone().unwrap_or_default(); + let task_action = TaskAction { + id: Some(task_identifier), + spec: action.task.clone(), + cache_key: action + .cache_key + .as_ref() + .map(|ck| StringValue { value: ck.clone() }), + cluster: action.queue.clone().unwrap_or("".to_string()), + }; + + Ok(EnqueueActionRequest { + action_id: Some(action.action_id.clone()), + parent_action_name: Some(action.parent_action_name.clone()), + spec: Some(enqueue_action_request::Spec::Task(task_action)), + run_spec: None, + input_uri, + run_output_base, + group, + subject: String::default(), // Subject is not used in the 
current implementation + }) + } + + async fn launch_task(&self, action: &Action) -> Result { + if !action.started && action.task.is_some() { + let enqueue_request = self + .create_enqueue_action_request(action) + .expect("Failed to create EnqueueActionRequest"); + let mut client = self.queue_client.clone(); + // todo: tonic doesn't seem to have wait_for_ready, or maybe the .ready is already doing this. + let enqueue_result = client.enqueue_action(enqueue_request).await; + // Add logic from resiliency pr here, return certain errors, but change others to be a specific slowdown error. + match enqueue_result { + Ok(response) => { + debug!("Successfully enqueued action: {:?}", action.action_id); + Ok(response.into_inner()) + } + Err(e) => { + if e.code() == tonic::Code::AlreadyExists { + info!( + "Action {} already exists, continuing to monitor.", + action.action_id.name + ); + Ok(EnqueueActionResponse {}) + } else { + error!( + "Failed to launch action: {:?}, backing off...", + action.action_id + ); + error!("Error details: {}", e); + // Handle backoff logic here + Err(e) + } + } + } + } else { + debug!( + "Action {} is already started or has no task, skipping launch.", + action.action_id.name + ); + Ok(EnqueueActionResponse {}) + } + } + + pub async fn submit_action(&self, action: Action) -> Result { + let action_name = action.action_id.name.clone(); + // The first action that gets submitted determines the run_id that will be used. 
+ // This is obviously not going to work, + + let run_id = action + .action_id + .run + .clone() + .ok_or(ControllerError::RuntimeError(format!( + "Run ID missing from submit action {}", + action_name.clone() + )))?; + info!("Creating informer set to run_id {:?}", run_id); + let informer: Arc = self + .informer_cache + .get_or_create_informer(&action.get_run_identifier(), &action.parent_action_name) + .await; + let (done_tx, done_rx) = oneshot::channel(); + informer.submit_action(action, done_tx).await?; + + done_rx.await.map_err(|_| { + ControllerError::BadContext(String::from("Failed to receive done signal from informer")) + })?; + debug!( + "Action {} complete, looking up final value and returning", + action_name + ); + + // get the action and return it + let final_action = informer.get_action(action_name).await; + final_action.ok_or(ControllerError::BadContext(String::from( + "Action not found after done", + ))) + } + + pub async fn finalize_parent_action(&self, run_id: &RunIdentifier, parent_action_name: &str) { + let opt_informer = self.informer_cache.remove(run_id, parent_action_name).await; + match opt_informer { + Some(informer) => { + informer.stop().await; + } + None => { + warn!( + "No informer found when finalizing parent action {}", + parent_action_name + ); + } + } + } +} diff --git a/rs_controller/src/error.rs b/rs_controller/src/error.rs new file mode 100644 index 000000000..6bf8c73e8 --- /dev/null +++ b/rs_controller/src/error.rs @@ -0,0 +1,39 @@ +use thiserror::Error; + +use crate::auth::AuthConfigError; + +#[derive(Error, Debug)] +pub enum ControllerError { + #[error("Bad context: {0}")] + BadContext(String), + #[error("Runtime error: {0}")] + RuntimeError(String), + #[error("System error: {0}")] + SystemError(String), + #[error("gRPC error: {0}")] + GrpcError(#[from] tonic::Status), + #[error("Task error: {0}")] + TaskError(String), +} + +impl From for ControllerError { + fn from(err: tonic::transport::Error) -> Self { + 
ControllerError::SystemError(format!("Transport error: {:?}", err)) + } +} + +impl From for ControllerError { + fn from(err: AuthConfigError) -> Self { + ControllerError::SystemError(err.to_string()) + } +} + +#[derive(Error, Debug)] +pub enum InformerError { + #[error("Informer watch failed for run {run_name}, parent action {parent_action_name}: {error_message}")] + WatchFailed { + run_name: String, + parent_action_name: String, + error_message: String, + }, +} diff --git a/rs_controller/src/informer.rs b/rs_controller/src/informer.rs index 97f049331..7195e920e 100644 --- a/rs_controller/src/informer.rs +++ b/rs_controller/src/informer.rs @@ -1,6 +1,7 @@ use crate::action::Action; -use crate::ControllerError; -use crate::StateClient; +use crate::core::StateClient; +use crate::error::{ControllerError, InformerError}; +use tokio_util::sync::CancellationToken; use flyteidl2::flyteidl::common::ActionIdentifier; use flyteidl2::flyteidl::common::RunIdentifier; @@ -9,16 +10,18 @@ use flyteidl2::flyteidl::workflow::{ watch_request, watch_response::Message, WatchRequest, WatchResponse, }; +use pyo3_async_runtimes::tokio::run; use std::collections::HashMap; +use std::sync::atomic::{AtomicBool, Ordering}; use std::sync::Arc; use std::time::Duration; use tokio::select; use tokio::sync::RwLock; use tokio::sync::{mpsc, oneshot, Notify}; -use tokio::task::JoinHandle; use tokio::time::sleep; use tonic::transport::channel::Channel; use tonic::transport::Endpoint; +use tracing::log::Level::Info; use tracing::{debug, error, info, warn}; use tracing_subscriber::fmt; @@ -30,7 +33,10 @@ pub struct Informer { parent_action_name: String, shared_queue: mpsc::Sender, ready: Arc, + is_ready: Arc, completion_events: Arc>>>, + cancellation_token: CancellationToken, + watch_handle: Arc>>>, } impl Informer { @@ -47,7 +53,10 @@ impl Informer { parent_action_name, shared_queue, ready: Arc::new(Notify::new()), + is_ready: Arc::new(AtomicBool::new(false)), completion_events: 
Arc::new(RwLock::new(HashMap::new())), + cancellation_token: CancellationToken::new(), + watch_handle: Arc::new(RwLock::new(None)), } } @@ -82,7 +91,8 @@ impl Informer { Message::ControlMessage(_) => { // Handle control messages if needed debug!("Received sentinel for parent {}", self.parent_action_name); - self.ready.notify_one(); + self.is_ready.store(true, Ordering::Release); + self.ready.notify_waiters(); Ok(None) } Message::ActionUpdate(action_update) => { @@ -128,7 +138,7 @@ impl Informer { } } - async fn watch_actions(&self) -> ControllerError { + async fn watch_actions(&self) -> Result<(), ControllerError> { let action_id = ActionIdentifier { name: self.parent_action_name.clone(), run: Some(self.run_id.clone()), @@ -137,80 +147,62 @@ impl Informer { filter: Some(watch_request::Filter::ParentActionId(action_id)), }; - let mut stream = self.client.clone().watch(request).await; + let stream = self.client.clone().watch(request).await; let mut stream = match stream { Ok(s) => s.into_inner(), Err(e) => { error!("Failed to start watch stream: {:?}", e); - return ControllerError::from(e); + return Err(ControllerError::from(e)); } }; loop { - match stream.message().await { - Ok(Some(response)) => { - let handle_response = self.handle_watch_response(response).await; - match handle_response { - Ok(Some(action)) => match self.shared_queue.send(action).await { - Ok(_) => { - continue; - } - Err(e) => { - error!("Informer watch failed sending action back to shared queue: {:?}", e); - return ControllerError::RuntimeError(format!( - "Failed to send action to shared queue: {}", - e - )); + select! 
{ + _ = self.cancellation_token.cancelled() => { + warn!("Cancellation token got - exiting from watch_actions: {}", self.parent_action_name); + return Ok(()) + } + + result = stream.message() => { + match result { + Ok(Some(response)) => { + let handle_response = self.handle_watch_response(response).await; + match handle_response { + Ok(Some(action)) => match self.shared_queue.send(action).await { + Ok(_) => { + continue; + } + Err(e) => { + error!("Informer watch failed sending action back to shared queue: {:?}", e); + return Err(ControllerError::RuntimeError(format!( + "Failed to send action to shared queue: {}", + e + ))); + } + }, + Ok(None) => { + debug!( + "Received None from handle_watch_response, continuing watch loop." + ); + } + Err(err) => { + // this should cascade up to the controller to restart the informer, and if there + // are too many informer restarts, the controller should fail + error!("Error in informer watch {:?}", err); + return Err(err); + } } - }, - Ok(None) => { - debug!( - "Received None from handle_watch_response, continuing watch loop." - ); } - Err(err) => { - // this should cascade up to the controller to restart the informer, and if there - // are too many informer restarts, the controller should fail - error!("Error in informer watch {:?}", err); - return err; + Ok(None) => { + debug!("Stream received empty message, maybe no more messages? Repeating watch loop."); + } // Stream ended, exit loop + Err(e) => { + error!("Error receiving message from stream: {:?}", e); + return Err(ControllerError::from(e)); } } } - Ok(None) => { - debug!("Stream received empty message, maybe no more messages? Repeating watch loop."); - } // Stream ended, exit loop - Err(e) => { - error!("Error receiving message from stream: {:?}", e); - return ControllerError::from(e); - } - } - } - } - - async fn wait_ready_or_timeout(ready: Arc) -> Result<(), ControllerError> { - select! 
{ - _ = ready.notified() => { - debug!("Ready sentinel ack'ed"); - Ok(()) - } - _ = sleep(Duration::from_millis(100)) => Err(ControllerError::SystemError("".to_string())) - } - } - - pub async fn start(informer: Arc) -> Result, ControllerError> { - let me = informer.clone(); - let ready = me.ready.clone(); - let _watch_handle = tokio::spawn(async move { - // handle errors later - me.watch_actions().await; - }); - - match Self::wait_ready_or_timeout(ready).await { - Ok(()) => Ok(_watch_handle), - Err(_) => { - warn!("Timed out waiting for sentinel"); - Ok(_watch_handle) } } } @@ -227,14 +219,29 @@ impl Informer { ) -> Result<(), ControllerError> { let action_name = action.action_id.name.clone(); + let merged_action = { + let mut cache = self.action_cache.write().await; + let cached_action = cache.get_mut(&action_name); + if let Some(some_action) = cached_action { + warn!("Submitting action {} and it's already in the cache!!! Existing {:?} <<<--->>> New: {:?}", action_name, some_action, action); + some_action.merge_from_submit(&action); + some_action.clone() + } else { + // don't need to write anything. 
return the original + action + } + }; + warn!("Merged action: ===> {} {:?}", action_name, merged_action); + // Store the completion event sender { let mut completion_events = self.completion_events.write().await; completion_events.insert(action_name.clone(), done_tx); + warn!("---------> Adding completion event in submit action {:?}", action_name); } // Add action to shared queue - self.shared_queue.send(action).await.map_err(|e| { + self.shared_queue.send(merged_action).await.map_err(|e| { ControllerError::RuntimeError(format!("Failed to send action to shared queue: {}", e)) })?; @@ -252,64 +259,264 @@ impl Informer { )) })?; } else { - error!( + warn!( "No completion event found for action---------------------: {}", action_name, ); - // Return error, which should cause informer to re-enqueue - return Err(ControllerError::RuntimeError(format!( - "No completion event found for action: {}. This may be because the informer is still starting up.", - action_name - ))); + // Maybe the action hasn't started yet. 
+ return Ok(()) } Ok(()) } + + pub async fn stop(&self) { + self.cancellation_token.cancel(); + if let Some(handle) = self.watch_handle.write().await.take() { + warn!("Awaiting taken handle"); + let _ = handle.await; + warn!("Taken handle finished..."); + } else { + warn!("No handle to take ------------------------"); + } + warn!("Stopped informer {:?}", self.parent_action_name); + } } -async fn informer_main() { - // Create an informer but first create the shared_queue that will be shared between the - // Controller and the informer - let (tx, rx) = mpsc::channel::(64); - let endpoint = Endpoint::from_static("http://localhost:8090"); - let channel = endpoint.connect().await.unwrap(); - let client = StateServiceClient::new(channel); - - let run_id = RunIdentifier { - org: String::from("testorg"), - project: String::from("testproject"), - domain: String::from("development"), - name: String::from("qdtc266r2z8clscl2lj5"), - }; - - let informer = Arc::new(Informer::new( - StateClient::Plain(client), - run_id, - "a0".to_string(), - tx.clone(), - )); - - let watch_task = Informer::start(informer.clone()).await; - - println!("{:?}: {:?}", informer, watch_task); - // do creation and start of informer behind a once +pub struct InformerCache { + cache: Arc>>>, + client: StateClient, + shared_queue: mpsc::Sender, + failure_tx: mpsc::Sender, } -fn init_tracing() { - static INIT: std::sync::Once = std::sync::Once::new(); - INIT.call_once(|| { - let subscriber = fmt() - .with_max_level(tracing::Level::DEBUG) - .with_test_writer() // so logs show in test output - .finish(); - tracing::subscriber::set_global_default(subscriber) - .expect("setting default subscriber failed"); - }); +impl InformerCache { + pub fn new( + client: StateClient, + shared_queue: mpsc::Sender, + failure_tx: mpsc::Sender, + ) -> Self { + Self { + cache: Arc::new(RwLock::new(HashMap::new())), + client, + shared_queue, + failure_tx, + } + } + + fn mkname(run_name: &str, parent_action_name: &str) -> String { + 
format!("{}.{}", run_name, parent_action_name) + } + + pub async fn get_or_create_informer( + &self, + run_id: &RunIdentifier, + parent_action_name: &str, + ) -> Arc { + let informer_name = Self::mkname(&run_id.name, parent_action_name); + info!(">>> get_or_create_informer called for: {}", informer_name); + let timeout = Duration::from_millis(100); + + // Check if exists (with read lock) + { + debug!("Acquiring read lock to check cache for: {}", informer_name); + let map = self.cache.read().await; + debug!("Read lock acquired, checking cache..."); + if let Some(informer) = map.get(&informer_name) { + info!("CACHE HIT: Found existing informer for: {}", informer_name); + let arc_informer = Arc::clone(informer); + // Release read lock before waiting + drop(map); + debug!("Read lock released, waiting for ready..."); + Self::wait_for_ready(&arc_informer, timeout).await; + info!("<<< Returning existing informer for: {}", informer_name); + return arc_informer; + } + debug!("CACHE MISS: Informer not found in cache: {}", informer_name); + } + + // Create new informer (with write lock) + debug!("Acquiring write lock to create informer for: {}", informer_name); + let mut map = self.cache.write().await; + info!("Write lock acquired for: {}", informer_name); + + // Double-check it wasn't created while we were waiting for write lock + if let Some(informer) = map.get(&informer_name) { + info!("RACE: Informer was created while waiting for write lock: {}", informer_name); + let arc_informer = Arc::clone(informer); + drop(map); + debug!("Write lock released after race condition"); + Self::wait_for_ready(&arc_informer, timeout).await; + info!("<<< Returning race-created informer for: {}", informer_name); + return arc_informer; + } + + // Create and add to cache + info!("CREATING new informer for: {}", informer_name); + let informer = Arc::new(Informer::new( + self.client.clone(), + run_id.clone(), + parent_action_name.to_string(), + self.shared_queue.clone(), + )); + debug!("Informer 
object created, inserting into cache..."); + map.insert(informer_name.clone(), Arc::clone(&informer)); + info!("Informer inserted into cache: {}", informer_name); + + // Release write lock before starting (starting involves waiting) + drop(map); + debug!("Write lock released for: {}", informer_name); + + let me = Arc::clone(&informer); + let failure_tx = self.failure_tx.clone(); + + info!("Spawning watch task for: {}", informer_name); + let _watch_handle = tokio::spawn(async move { + debug!("Watch task started for: {}", me.parent_action_name); + let watch_actions_result = me.watch_actions().await; + + // If there are errors with the watch then notify the channel + if watch_actions_result.is_err() { + let err = watch_actions_result.err().unwrap(); + error!( + "Informer watch_actions failed for run {}, parent action {}: {:?}", + me.run_id.name, me.parent_action_name, err + ); + + let failure = InformerError::WatchFailed { + run_name: me.run_id.name.clone(), + parent_action_name: me.parent_action_name.clone(), + error_message: err.to_string(), + }; + + if let Err(e) = failure_tx.send(failure).await { + error!("Failed to send informer failure event: {:?}", e); + } + } else { + info!("Informer watch_actions completed successfully for {}", me.run_id.name); + } + }); + + // save the value and ignore the returned reference. 
+ debug!("Acquiring write lock to save watch handle for: {}", informer_name); + let _ = informer.watch_handle.write().await.insert(_watch_handle); + info!("Watch handle saved for: {}", informer_name); + + // Optimistically wait for ready (sentinel) with timeout + debug!("Waiting for informer to be ready: {}", informer_name); + Self::wait_for_ready(&informer, timeout).await; + + info!("<<< Returning newly created informer for: {}", informer_name); + informer + } + + pub async fn get( + &self, + run_id: &RunIdentifier, + parent_action_name: &str, + ) -> Option> { + let informer_name = InformerCache::mkname(&run_id.name, parent_action_name); + debug!("InformerCache::get called for: {}", informer_name); + let map = self.cache.read().await; + let opt_informer = map.get(&informer_name).cloned(); + if opt_informer.is_some() { + debug!("InformerCache::get - found: {}", informer_name); + } else { + debug!("InformerCache::get - not found: {}", informer_name); + } + opt_informer + } + + /// Wait for informer to be ready with a timeout. If timeout occurs, set ready anyway + /// and log a warning - this is optimistic, assuming the informer will become ready eventually. + /// Once ready has been set, future calls return immediately without waiting. 
+ async fn wait_for_ready(informer: &Arc, timeout: Duration) { + debug!("wait_for_ready called for: {}", informer.parent_action_name); + + // Subscribe to notifications first, before checking ready + // This ensures we don't miss a notification that happens between the check and the wait + let ready_fut = informer.ready.notified(); + + // Quick check - if already ready, return immediately + if informer.is_ready.load(Ordering::Acquire) { + info!("Informer already ready for: {}", informer.parent_action_name); + return; + } + + debug!("Waiting for ready signal with timeout {:?}...", timeout); + // Otherwise wait with timeout + match tokio::time::timeout(timeout, ready_fut).await { + Ok(_) => { + info!("Informer ready signal received for: {}", informer.parent_action_name); + } + Err(_) => { + warn!( + "Informer ready TIMEOUT after {:?} for {}:{} - continuing optimistically", + timeout, informer.run_id.name, informer.parent_action_name + ); + // Set ready anyway so future calls don't wait + informer.is_ready.store(true, Ordering::Release); + } + } + } + + pub async fn remove( + &self, + run_id: &RunIdentifier, + parent_action_name: &str, + ) -> Option> { + let informer_name = InformerCache::mkname(&run_id.name, parent_action_name); + info!("InformerCache::remove called for: {}", informer_name); + let mut map = self.cache.write().await; + let opt_informer = map.remove(&informer_name); + if opt_informer.is_some() { + info!("InformerCache::remove - removed: {}", informer_name); + } else { + warn!("InformerCache::remove - not found: {}", informer_name); + } + opt_informer + } } + #[cfg(test)] mod tests { use super::*; + async fn informer_main() { + // Create an informer but first create the shared_queue that will be shared between the + // Controller and the informer + let (tx, _rx) = mpsc::channel::(64); + let endpoint = Endpoint::from_static("http://localhost:8090"); + let channel = endpoint.connect().await.unwrap(); + let client = StateServiceClient::new(channel); + + 
let run_id = RunIdentifier {
+ org: String::from("testorg"),
+ project: String::from("testproject"),
+ domain: String::from("development"),
+ name: String::from("rchn685b8jgwtvz4k795"),
+ };
+ let (failure_tx, _failure_rx) = mpsc::channel::(1);
+
+ let informer_cache = InformerCache::new(StateClient::Plain(client), tx.clone(), failure_tx);
+ let informer = informer_cache.get_or_create_informer(&run_id, "a0").await;
+
+ println!("{:?}", informer);
+ }
+
+ fn init_tracing() {
+ static INIT: std::sync::Once = std::sync::Once::new();
+ INIT.call_once(|| {
+ let subscriber = fmt()
+ .with_max_level(tracing::Level::DEBUG)
+ .with_test_writer() // so logs show in test output
+ .finish();
+ tracing::subscriber::set_global_default(subscriber)
+ .expect("setting default subscriber failed");
+ });
+ }
+
+ // cargo test --lib informer::tests::test_informer -- --nocapture --show-output #[test] fn test_informer() { init_tracing(); diff --git a/rs_controller/src/lib.rs b/rs_controller/src/lib.rs index 5820139eb..b2fdcf3ce 100644 --- a/rs_controller/src/lib.rs +++ b/rs_controller/src/lib.rs @@ -1,623 +1,49 @@ #![allow(clippy::too_many_arguments)] -mod action; -pub mod auth; // Public for use in other crates +// Core modules - public for use by binaries and other crates +pub mod action; +pub mod auth; +pub mod core; +pub mod error; mod informer; -pub mod proto; // Public for use in other crates +pub mod proto; -use std::default; -use std::sync::mpsc::channel; +// Python bindings - thin wrappers around core types use std::sync::Arc; use std::time::Duration; -use futures::TryFutureExt; +use pyo3::exceptions; use pyo3::prelude::*; -use tokio::sync::mpsc; -use tower::ServiceExt; -use tracing::{debug, error, info, warn}; - -use thiserror::Error; +use pyo3::types::PyAny; +use pyo3_async_runtimes::tokio::future_into_py; +use tower::ServiceBuilder; +use tracing::{error, info, warn}; +use tracing_subscriber::FmtSubscriber; use crate::action::{Action, ActionType}; -use 
crate::informer::Informer; - -use crate::auth::{AuthConfig, AuthConfigError, AuthLayer, ClientCredentialsAuthenticator}; -use flyteidl2::flyteidl::common::{ActionIdentifier, ProjectIdentifier}; +use crate::auth::{AuthConfig, AuthLayer, ClientCredentialsAuthenticator}; +use crate::core::CoreBaseController; +use crate::error::ControllerError; +use flyteidl2::flyteidl::common::{ActionIdentifier, ProjectIdentifier, RunIdentifier}; use flyteidl2::flyteidl::task::task_service_client::TaskServiceClient; -use flyteidl2::flyteidl::task::TaskIdentifier; use flyteidl2::flyteidl::task::{list_tasks_request, ListTasksRequest}; -use flyteidl2::flyteidl::workflow::enqueue_action_request; -use flyteidl2::flyteidl::workflow::queue_service_client::QueueServiceClient; use flyteidl2::flyteidl::workflow::state_service_client::StateServiceClient; -use flyteidl2::flyteidl::workflow::{ - EnqueueActionRequest, EnqueueActionResponse, TaskAction, WatchRequest, WatchResponse, -}; -use flyteidl2::google; -use google::protobuf::StringValue; -use pyo3::exceptions; -use pyo3::types::PyAny; -use pyo3_async_runtimes::tokio::future_into_py; -use pyo3_async_runtimes::tokio::get_runtime; -use std::sync::OnceLock; -use tokio::sync::oneshot; -use tokio::sync::OnceCell; -use tokio::time::sleep; -use tonic::transport::{Certificate, ClientTlsConfig, Endpoint}; -use tonic::Status; -use tracing_subscriber::FmtSubscriber; - -// Fetches Amazon root CA certificate from Amazon Trust Services -async fn fetch_amazon_root_ca() -> Result { - // Amazon Root CA 1 - the main root used by AWS services - let url = "https://www.amazontrust.com/repository/AmazonRootCA1.pem"; - - let response = reqwest::get(url) - .await - .map_err(|e| ControllerError::SystemError(format!("Failed to fetch certificate: {}", e)))?; - - let cert_pem = response - .text() - .await - .map_err(|e| ControllerError::SystemError(format!("Failed to read certificate: {}", e)))?; - - Ok(Certificate::from_pem(cert_pem)) -} - -// Helper to create 
TLS-configured endpoint with Amazon CA certificate -// todo: when we resolve the pem issue, also remove the need to have both inputs which are basically the same -async fn create_tls_endpoint(url: &'static str, domain: &str) -> Result { - // Fetch Amazon root CA dynamically - let cert = fetch_amazon_root_ca().await?; - - let tls_config = ClientTlsConfig::new() - .domain_name(domain) - .ca_certificate(cert); - - let endpoint = Endpoint::from_static(url) - .tls_config(tls_config) - .map_err(|e| ControllerError::SystemError(format!("TLS config error: {}", e)))?; - - Ok(endpoint) -} - -#[derive(Error, Debug)] -pub enum ControllerError { - #[error("Bad context: {0}")] - BadContext(String), - #[error("Runtime error: {0}")] - RuntimeError(String), - #[error("System error: {0}")] - SystemError(String), - #[error("gRPC error: {0}")] - GrpcError(#[from] tonic::Status), - #[error("Task error: {0}")] - TaskError(String), -} - -impl From for ControllerError { - fn from(err: tonic::transport::Error) -> Self { - ControllerError::SystemError(format!("Transport error: {:?}", err)) - } -} +use prost::Message; +use tonic::transport::Endpoint; +// Python error conversions impl From for PyErr { - // can better map errors in the future fn from(err: ControllerError) -> Self { exceptions::PyRuntimeError::new_err(err.to_string()) } } -impl From for PyErr { - fn from(err: AuthConfigError) -> Self { +impl From for PyErr { + fn from(err: crate::auth::AuthConfigError) -> Self { exceptions::PyRuntimeError::new_err(err.to_string()) } } -enum ChannelType { - Plain(tonic::transport::Channel), - Authenticated(crate::auth::AuthService), -} - -#[derive(Clone, Debug)] -pub enum StateClient { - Plain(StateServiceClient), - Authenticated(StateServiceClient>), -} - -impl StateClient { - pub async fn watch( - &mut self, - request: impl tonic::IntoRequest, - ) -> Result>, tonic::Status> { - match self { - StateClient::Plain(client) => client.watch(request).await, - StateClient::Authenticated(client) => 
client.watch(request).await, - } - } -} - -#[derive(Clone, Debug)] -pub enum QueueClient { - Plain(QueueServiceClient), - Authenticated(QueueServiceClient>), -} - -impl QueueClient { - pub async fn enqueue_action( - &mut self, - request: impl tonic::IntoRequest, - ) -> Result, tonic::Status> { - match self { - QueueClient::Plain(client) => client.enqueue_action(request).await, - QueueClient::Authenticated(client) => client.enqueue_action(request).await, - } - } -} - -struct CoreBaseController { - channel: ChannelType, - informer: OnceCell>, - state_client_cache: OnceLock, - queue_client_cache: OnceLock, - shared_queue: mpsc::Sender, - rx_of_shared_queue: Arc>>, -} - -impl CoreBaseController { - // Helper methods to get cached clients (constructed once, reused thereafter) - fn state_client(&self) -> StateClient { - self.state_client_cache - .get_or_init(|| match &self.channel { - ChannelType::Plain(ch) => StateClient::Plain(StateServiceClient::new(ch.clone())), - ChannelType::Authenticated(ch) => { - StateClient::Authenticated(StateServiceClient::new(ch.clone())) - } - }) - .clone() - } - - fn queue_client(&self) -> QueueClient { - self.queue_client_cache - .get_or_init(|| match &self.channel { - ChannelType::Plain(ch) => QueueClient::Plain(QueueServiceClient::new(ch.clone())), - ChannelType::Authenticated(ch) => { - QueueClient::Authenticated(QueueServiceClient::new(ch.clone())) - } - }) - .clone() - } - - pub fn new_with_auth() -> Result, ControllerError> { - use crate::auth::{AuthConfig, AuthLayer, ClientCredentialsAuthenticator}; - use tower::ServiceBuilder; - - info!("Creating CoreBaseController from _UNION_EAGER_API_KEY env var (with auth)"); - // Read from env var and use auth - let api_key = std::env::var("_UNION_EAGER_API_KEY").map_err(|_| { - ControllerError::SystemError( - "_UNION_EAGER_API_KEY env var must be provided".to_string(), - ) - })?; - let auth_config = AuthConfig::new_from_api_key(&api_key).expect("Bad api key"); - let endpoint_url = 
auth_config.endpoint.clone(); - - let endpoint_static: &'static str = - Box::leak(Box::new(endpoint_url.clone().into_boxed_str())); - // shared queue - let (shared_tx, rx_of_shared_queue) = mpsc::channel::(64); - - let rt = get_runtime(); - let channel = rt.block_on(async { - // todo: escape hatch for localhost - // Strip "https://" to get just the hostname for TLS config - let domain = endpoint_url.strip_prefix("https://").ok_or_else(|| { - ControllerError::SystemError( - "Endpoint must start with https:// when using auth".to_string(), - ) - })?; - - // Create TLS-configured endpoint - let endpoint = create_tls_endpoint(endpoint_static, domain).await?; - let channel = endpoint.connect().await.map_err(ControllerError::from)?; - - let authenticator = Arc::new(ClientCredentialsAuthenticator::new(auth_config.clone())); - let auth_channel = ServiceBuilder::new() - .layer(AuthLayer::new(authenticator, channel.clone())) - .service(channel); - - Ok::<_, ControllerError>(ChannelType::Authenticated(auth_channel)) - })?; - - let real_base_controller = CoreBaseController { - channel, - informer: OnceCell::new(), - state_client_cache: OnceLock::new(), - queue_client_cache: OnceLock::new(), - shared_queue: shared_tx, - rx_of_shared_queue: Arc::new(tokio::sync::Mutex::new(rx_of_shared_queue)), - }; - - let real_base_controller = Arc::new(real_base_controller); - // Start the background worker - let controller_clone = real_base_controller.clone(); - rt.spawn(async move { - controller_clone.bg_worker().await; - }); - Ok(real_base_controller) - } - - pub fn new_without_auth(endpoint: String) -> Result, ControllerError> { - let endpoint_static: &'static str = Box::leak(Box::new(endpoint.clone().into_boxed_str())); - // shared queue - let (shared_tx, rx_of_shared_queue) = mpsc::channel::(64); - - let rt = get_runtime(); - let channel = rt.block_on(async { - let chan = if endpoint.starts_with("http://") { - let endpoint = Endpoint::from_static(endpoint_static); - 
endpoint.connect().await.map_err(ControllerError::from)? - } else if endpoint.starts_with("https://") { - // Strip "https://" to get just the hostname for TLS config - let domain = endpoint.strip_prefix("https://").ok_or_else(|| { - ControllerError::SystemError("Endpoint must start with https://".to_string()) - })?; - - // Create TLS-configured endpoint - let endpoint = create_tls_endpoint(endpoint_static, domain).await?; - endpoint.connect().await.map_err(ControllerError::from)? - } else { - return Err(ControllerError::SystemError(format!( - "Malformed endpoint {}", - endpoint - ))); - }; - Ok::<_, ControllerError>(ChannelType::Plain(chan)) - })?; - - let real_base_controller = CoreBaseController { - channel, - informer: OnceCell::new(), - state_client_cache: OnceLock::new(), - queue_client_cache: OnceLock::new(), - shared_queue: shared_tx, - rx_of_shared_queue: Arc::new(tokio::sync::Mutex::new(rx_of_shared_queue)), - }; - - let real_base_controller = Arc::new(real_base_controller); - // Start the background worker - let controller_clone = real_base_controller.clone(); - rt.spawn(async move { - controller_clone.bg_worker().await; - }); - Ok(real_base_controller) - } - - async fn bg_worker(&self) { - const MIN_BACKOFF_ON_ERR: Duration = Duration::from_millis(100); - const MAX_RETRIES: u32 = 5; - - debug!( - "Launching core controller background task on thread {:?}", - std::thread::current().name() - ); - loop { - // Receive actions from shared queue - let mut rx = self.rx_of_shared_queue.lock().await; - match rx.recv().await { - Some(mut action) => { - let run_name = &action - .action_id - .run - .as_ref() - .map_or(String::from(""), |i| i.name.clone()); - debug!( - "Controller worker processing action {}::{}", - run_name, action.action_id.name - ); - - // Drop the mutex guard before processing - drop(rx); - - match self.handle_action(&mut action).await { - Ok(_) => {} - Err(e) => { - error!("Error in controller loop: {:?}", e); - // Handle backoff and retry logic 
- sleep(MIN_BACKOFF_ON_ERR).await; - action.retries += 1; - - if action.retries > MAX_RETRIES { - error!( - "Controller failed processing {}::{}, system retries {} crossed threshold {}", - run_name, action.action_id.name, action.retries, MAX_RETRIES - ); - action.client_err = Some(format!( - "Controller failed {}::{}, system retries {} crossed threshold {}", - run_name, action.action_id.name, action.retries, MAX_RETRIES - )); - - // Fire completion event for failed action - if let Some(informer) = self.informer.get() { - // todo: check these two errors - - // Before firing completion event, update the action in the - // informer, otherwise client_err will not be set. - let _ = informer.set_action_client_err(&action).await; - let _ = informer - .fire_completion_event(&action.action_id.name) - .await; - } else { - error!( - "Max retries hit for action but informer still not yet initialized for action: {}", - action.action_id.name - ); - } - } else { - // Re-queue the action for retry - info!( - "Re-queuing action {}::{} for retry, attempt {}/{}", - run_name, action.action_id.name, action.retries, MAX_RETRIES - ); - if let Err(send_err) = self.shared_queue.send(action).await { - error!("Failed to re-queue action for retry: {}", send_err); - } - } - } - } - } - None => { - warn!("Shared queue channel closed, stopping bg_worker"); - break; - } - } - } - } - - async fn handle_action(&self, action: &mut Action) -> Result<(), ControllerError> { - if !action.started { - // Action not started, launch it - self.bg_launch(action).await?; - } else if action.is_action_terminal() { - // Action is terminal, fire completion event - if let Some(informer) = self.informer.get() { - debug!( - "handle action firing completion event for {:?}", - &action.action_id.name - ); - informer - .fire_completion_event(&action.action_id.name) - .await?; - } else { - error!( - "Informer not yet initialized for action: {}", - action.action_id.name - ); - return 
Err(ControllerError::BadContext(format!( - "Informer not initialized for action: {}. This may be because the informer is still starting up.", - action.action_id.name - ))); - } - } else { - // Action still in progress - debug!("Resource {} still in progress...", action.action_id.name); - } - Ok(()) - } - - async fn bg_launch(&self, action: &Action) -> Result<(), ControllerError> { - match self.launch_task(action).await { - Ok(_) => { - debug!("Successfully launched action: {}", action.action_id.name); - Ok(()) - } - Err(e) => { - error!( - "Failed to launch action: {}, error: {}", - action.action_id.name, e - ); - Err(ControllerError::RuntimeError(format!( - "Launch failed: {}", - e - ))) - } - } - } - - async fn cancel_action(&self, action: &mut Action) -> Result<(), ControllerError> { - if action.is_action_terminal() { - info!( - "Action {} is already terminal, no need to cancel.", - action.action_id.name - ); - return Ok(()); - } - - debug!("Cancelling action: {}", action.action_id.name); - action.mark_cancelled(); - - if let Some(informer) = self.informer.get() { - let _ = informer - .fire_completion_event(&action.action_id.name) - .await?; - } else { - debug!( - "Informer missing when trying to cancel action: {}", - action.action_id.name - ); - } - Ok(()) - } - - async fn get_action(&self, action_id: ActionIdentifier) -> Result { - if let Some(informer) = self.informer.get() { - let action_name = action_id.name.clone(); - match informer.get_action(action_name).await { - Some(action) => Ok(action), - None => Err(ControllerError::RuntimeError(format!( - "Action not found: {}", - action_id.name - ))), - } - } else { - Err(ControllerError::BadContext( - "Informer not initialized".to_string(), - )) - } - } - - fn create_enqueue_action_request( - &self, - action: &Action, - ) -> Result { - // todo-pr: handle trace action - let task_identifier = action - .task - .as_ref() - .and_then(|task| task.task_template.as_ref()) - .and_then(|task_template| 
task_template.id.as_ref()) - .and_then(|core_task_id| { - Some(TaskIdentifier { - version: core_task_id.version.clone(), - org: core_task_id.org.clone(), - project: core_task_id.project.clone(), - domain: core_task_id.domain.clone(), - name: core_task_id.name.clone(), - }) - }) - .ok_or(ControllerError::RuntimeError(format!( - "TaskIdentifier missing from Action {:?}", - action - )))?; - - let input_uri = action - .inputs_uri - .clone() - .ok_or(ControllerError::RuntimeError(format!( - "Inputs URI missing from Action {:?}", - action - )))?; - let run_output_base = - action - .run_output_base - .clone() - .ok_or(ControllerError::RuntimeError(format!( - "Run output base missing from Action {:?}", - action - )))?; - let group = action.group.clone().unwrap_or_default(); - let task_action = TaskAction { - id: Some(task_identifier), - spec: action.task.clone(), - cache_key: action - .cache_key - .as_ref() - .map(|ck| StringValue { value: ck.clone() }), - cluster: action.queue.clone().unwrap_or("".to_string()), - }; - - Ok(EnqueueActionRequest { - action_id: Some(action.action_id.clone()), - parent_action_name: Some(action.parent_action_name.clone()), - spec: Some(enqueue_action_request::Spec::Task(task_action)), - run_spec: None, - input_uri, - run_output_base, - group, - subject: String::default(), // Subject is not used in the current implementation - }) - } - - async fn launch_task(&self, action: &Action) -> Result { - if !action.started && action.task.is_some() { - let enqueue_request = self - .create_enqueue_action_request(action) - .expect("Failed to create EnqueueActionRequest"); - let mut client = self.queue_client(); - // todo: tonic doesn't seem to have wait_for_ready, or maybe the .ready is already doing this. 
- let enqueue_result = client.enqueue_action(enqueue_request).await; - match enqueue_result { - Ok(response) => { - debug!("Successfully launched action: {:?}", action.action_id); - Ok(response.into_inner()) - } - Err(e) => { - if e.code() == tonic::Code::AlreadyExists { - info!( - "Action {} already exists, continuing to monitor.", - action.action_id.name - ); - Ok(EnqueueActionResponse {}) - } else { - error!( - "Failed to launch action: {:?}, backing off...", - action.action_id - ); - error!("Error details: {}", e); - // Handle backoff logic here - Err(e) - } - } - } - } else { - debug!( - "Action {} is already started or has no task, skipping launch.", - action.action_id.name - ); - Ok(EnqueueActionResponse {}) - } - } - - pub async fn _submit_action(&self, action: Action) -> Result { - let action_name = action.action_id.name.clone(); - let parent_action_name = action.parent_action_name.clone(); - // The first action that gets submitted determines the run_id that will be used. - // This is obviously not going to work, - - let run_id = action - .action_id - .run - .clone() - .ok_or(ControllerError::RuntimeError(format!( - "Run ID missing from submit action {}", - action_name.clone() - )))?; - let informer: &Arc = self - .informer // OnceCell> - .get_or_try_init(|| async move { - info!("Creating informer set to run_id {:?}", run_id); - let inf = Arc::new(Informer::new( - self.state_client(), - run_id, - parent_action_name, - self.shared_queue.clone(), - )); - - Informer::start(inf.clone()).await?; - - // Using PyErr for now, but any errors coming from the informer will not really - // be py errs, will need to add and map later. 
- Ok::, ControllerError>(inf) - }) - .await?; - let (done_tx, done_rx) = oneshot::channel(); - informer.submit_action(action, done_tx).await?; - - done_rx.await.map_err(|_| { - ControllerError::BadContext(String::from("Failed to receive done signal from informer")) - })?; - debug!( - "Action {} complete, looking up final value and returning", - action_name - ); - - // get the action and return it - let final_action = informer.get_action(action_name).await; - final_action.ok_or(ControllerError::BadContext(String::from( - "Action not found after done", - ))) - } -} - /// Base class for RemoteController to eventually inherit from #[pyclass(subclass)] struct BaseController(Arc); @@ -637,263 +63,6 @@ impl BaseController { Ok(BaseController(core_base)) } - #[staticmethod] - fn try_list_tasks(py: Python<'_>) -> PyResult> { - future_into_py(py, async move { - use crate::auth::{AuthConfig, AuthLayer, ClientCredentialsAuthenticator}; - use flyteidl2::flyteidl::common::ListRequest; - use flyteidl2::flyteidl::common::ProjectIdentifier; - use flyteidl2::flyteidl::task::task_service_client::TaskServiceClient; - use flyteidl2::flyteidl::task::{list_tasks_request, ListTasksRequest}; - use tonic::Code; - use tower::ServiceBuilder; - - let api_key = std::env::var("_UNION_EAGER_API_KEY").unwrap_or_else(|_| { - warn!("_UNION_EAGER_API_KEY env var not set, using empty string"); - String::new() - }); - - let auth_config = AuthConfig::new_from_api_key(api_key.as_str())?; - let endpoint = auth_config.endpoint.clone(); - let static_endpoint = endpoint.clone().leak(); - // Strip "https://" (8 chars) to get just the hostname for TLS config - let domain = endpoint.strip_prefix("https://").ok_or_else(|| { - ControllerError::SystemError("Endpoint must start with https://".to_string()) - })?; - let endpoint = create_tls_endpoint(static_endpoint, domain).await?; - let channel = endpoint.connect().await.map_err(ControllerError::from)?; - - let authenticator = 
Arc::new(ClientCredentialsAuthenticator::new(auth_config)); - - let auth_handling_channel = ServiceBuilder::new() - .layer(AuthLayer::new(authenticator, channel.clone())) - .service(channel); - - let mut task_client = TaskServiceClient::new(auth_handling_channel); - - let list_request_base = ListRequest { - limit: 100, - ..Default::default() - }; - let req = ListTasksRequest { - request: Some(list_request_base), - known_filters: vec![], - scope_by: Some(list_tasks_request::ScopeBy::ProjectId(ProjectIdentifier { - organization: "demo".to_string(), - domain: "development".to_string(), - name: "flytesnacks".to_string(), - })), - }; - - let mut attempts = 0; - let final_result = loop { - let result = task_client.list_tasks(req.clone()).await; - match result { - Ok(response) => { - println!("Success: {:?}", response.into_inner()); - break Ok(true); - } - Err(status) if status.code() == Code::Unauthenticated && attempts < 1 => { - attempts += 1; - continue; - } - Err(status) => { - eprintln!("Error calling gRPC: {}", status); - break Err(exceptions::PyRuntimeError::new_err(format!( - "gRPC error: {}", - status - ))); - } - } - }; - warn!("Finished try_list_tasks with result {:?}", final_result); - final_result - }) - } - - #[staticmethod] - fn try_watch(py: Python<'_>) -> PyResult> { - future_into_py(py, async move { - use crate::auth::{AuthConfig, AuthLayer, ClientCredentialsAuthenticator}; - use flyteidl2::flyteidl::common::ActionIdentifier; - use flyteidl2::flyteidl::common::RunIdentifier; - use flyteidl2::flyteidl::workflow::watch_request::Filter; - use flyteidl2::flyteidl::workflow::WatchRequest; - use std::time::Duration; - use tokio::time::sleep; - use tower::ServiceBuilder; - - info!("Starting watch example with authentication and retry..."); - - // Read in the api key which gives us the endpoint to connect to as well as the credentials - let api_key = std::env::var("_UNION_EAGER_API_KEY").unwrap_or_else(|_| { - warn!("_UNION_EAGER_API_KEY env var not set, using 
empty string"); - String::new() - }); - - let auth_config = AuthConfig::new_from_api_key(api_key.as_str())?; - let endpoint = auth_config.endpoint.clone(); - let static_endpoint = endpoint.clone().leak(); - // Strip "https://" (8 chars) to get just the hostname for TLS config - let domain = endpoint.strip_prefix("https://").ok_or_else(|| { - ControllerError::SystemError("Endpoint must start with https://".to_string()) - })?; - let endpoint = create_tls_endpoint(static_endpoint, domain).await?; - let channel = endpoint.connect().await.map_err(ControllerError::from)?; - - let authenticator = Arc::new(ClientCredentialsAuthenticator::new(auth_config)); - - // Wrap channel with auth layer - ALL calls now automatically authenticated! - let auth_channel = ServiceBuilder::new() - .layer(AuthLayer::new(authenticator, channel.clone())) - .service(channel); - - let mut client = StateServiceClient::new(auth_channel); - - // Watch configuration (matching Python example) - let run_id = RunIdentifier { - org: "demo".to_string(), - project: "flytesnacks".to_string(), - domain: "development".to_string(), - name: "r57jklb4mw4k6bkb2p88".to_string(), - }; - let parent_action_name = "a0".to_string(); - - // Retry parameters (matching Python defaults) - let min_watch_backoff = Duration::from_secs(1); - let max_watch_backoff = Duration::from_secs(30); - let max_watch_retries = 10; - - // Watch loop with retry logic (following Python _informer.py pattern) - let mut retries = 0; - let mut message_count = 0; - - while retries < max_watch_retries { - if retries >= 1 { - warn!("Watch retrying, attempt {}/{}", retries, max_watch_retries); - } - - // Create watch request - let request = WatchRequest { - filter: Some(Filter::ParentActionId(ActionIdentifier { - name: parent_action_name.clone(), - run: Some(run_id.clone()), - })), - }; - - // Establish the watch stream - // The outer retry loop handles failures, middleware handles auth refresh - let stream_result = 
client.watch(request.clone()).await; - - match stream_result { - Ok(response) => { - info!("Successfully established watch stream"); - let mut stream = response.into_inner(); - - // Process messages from the stream - loop { - match stream.message().await { - Ok(Some(watch_response)) => { - // Successfully received a message - reset retry counter - retries = 0; - message_count += 1; - - // Process the message (enum with ActionUpdate or ControlMessage) - use flyteidl2::flyteidl::workflow::watch_response::Message; - match &watch_response.message { - Some(Message::ControlMessage(control_msg)) => { - if control_msg.sentinel { - info!( - "Received Sentinel for parent action: {}", - parent_action_name - ); - } - } - Some(Message::ActionUpdate(action_update)) => { - info!( - "Received action update for: {} (phase: {:?})", - action_update - .action_id - .as_ref() - .map(|id| id.name.as_str()) - .unwrap_or("unknown"), - action_update.phase - ); - - if !action_update.output_uri.is_empty() { - info!("Output URI: {}", action_update.output_uri); - } - - if action_update.phase == 4 { - // PHASE_FAILED - if action_update.error.is_some() { - error!( - "Action failed with error: {:?}", - action_update.error - ); - } - } - } - None => { - warn!("Received empty watch response"); - } - } - - // For demo purposes, exit after receiving a few messages - if message_count >= 50 { - info!("Received {} messages, exiting demo", message_count); - return Ok(true); - } - } - Ok(None) => { - warn!("Watch stream ended gracefully"); - break; // Stream ended, retry - } - Err(status) => { - error!("Error receiving message from watch stream: {}", status); - - // Check if it's an auth error - if status.code() == tonic::Code::Unauthenticated { - warn!("Unauthenticated error - credentials will be refreshed on retry"); - } - - break; // Break inner loop to retry - } - } - } - } - Err(status) => { - error!("Failed to establish watch stream: {}", status); - - if status.code() == tonic::Code::Unauthenticated 
{ - warn!("Unauthenticated error - credentials will be refreshed on retry"); - } - } - - // Increment retry counter and apply exponential backoff - retries += 1; - if retries < max_watch_retries { - let backoff = min_watch_backoff - .saturating_mul(2_u32.pow(retries as u32)) - .min(max_watch_backoff); - warn!("Watch failed, retrying in {:?}...", backoff); - sleep(backoff).await; - } - } - - // Exceeded max retries - error!( - "Watch failure retries crossed threshold {}/{}, exiting!", - retries, max_watch_retries - ); - Err(exceptions::PyRuntimeError::new_err(format!( - "Max watch retries ({}) exceeded", - max_watch_retries - ))) - }) - } - /// `async def submit(self, action: Action) -> Action` /// /// Enqueue `action`. @@ -901,7 +70,7 @@ impl BaseController { let real_base = self.0.clone(); let py_fut = future_into_py(py, async move { let action_id = action.action_id.clone(); - real_base._submit_action(action).await.map_err(|e| { + real_base.submit_action(action).await.map_err(|e| { error!("Error submitting action {:?}: {:?}", action_id, e); exceptions::PyRuntimeError::new_err(format!("Failed to submit action: {}", e)) }) @@ -925,20 +94,48 @@ impl BaseController { &self, py: Python<'py>, action_id: ActionIdentifier, + parent_action_name: String, ) -> PyResult> { let real_base = self.0.clone(); let py_fut = future_into_py(py, async move { - real_base.get_action(action_id.clone()).await.map_err(|e| { - error!("Error getting action {:?}: {:?}", action_id, e); - exceptions::PyRuntimeError::new_err(format!("Failed to cancel action: {}", e)) - }) + real_base + .get_action(action_id.clone(), parent_action_name.as_str()) + .await + .map_err(|e| { + error!("Error getting action {:?}: {:?}", action_id, e); + exceptions::PyRuntimeError::new_err(format!("Failed to get action: {}", e)) + }) + }); + py_fut + } + + fn finalize_parent_action<'py>( + &self, + py: Python<'py>, + // run_id: RunIdentifier, + run_id_bytes: &[u8], + parent_action_name: &str, + ) -> PyResult> { + 
let base = self.0.clone(); + let parent_action_string = parent_action_name.to_string(); + let run_id = RunIdentifier::decode(run_id_bytes).map_err(|e| { + pyo3::exceptions::PyValueError::new_err(format!( + "Failed to decode RunIdentifier: {}", + e + )) + })?; + let py_fut = future_into_py(py, async move { + base.finalize_parent_action(&run_id, &parent_action_string) + .await; + warn!("Parent action finalize: {}", parent_action_string); + Ok(()) }); py_fut } } #[pymodule] -fn flyte_controller_base(py: Python<'_>, m: &Bound<'_, PyModule>) -> PyResult<()> { +fn flyte_controller_base(_py: Python<'_>, m: &Bound<'_, PyModule>) -> PyResult<()> { static INIT: std::sync::Once = std::sync::Once::new(); INIT.call_once(|| { let subscriber = FmtSubscriber::builder() diff --git a/rs_controller/uv.lock b/rs_controller/uv.lock new file mode 100644 index 000000000..641e23575 --- /dev/null +++ b/rs_controller/uv.lock @@ -0,0 +1,8 @@ +version = 1 +revision = 3 +requires-python = ">=3.10" + +[[package]] +name = "flyte-controller-base" +version = "2.0.0b33.dev33+g3d028ba" +source = { editable = "." 
} diff --git a/src/flyte/_internal/controllers/__init__.py b/src/flyte/_internal/controllers/__init__.py index 670a46351..5789ad257 100644 --- a/src/flyte/_internal/controllers/__init__.py +++ b/src/flyte/_internal/controllers/__init__.py @@ -117,15 +117,20 @@ def create_controller( from ._local_controller import LocalController controller = LocalController() - case "remote" | "hybrid": - # from flyte._internal.controllers.remote import create_remote_controller - # - # controller = create_remote_controller(**kwargs) - from flyte._internal.controllers.remote._r_controller import RemoteController + case "remote": + from flyte._internal.controllers.remote import create_remote_controller + + controller = create_remote_controller(**kwargs) + # from flyte._internal.controllers.remote._r_controller import RemoteController # controller = RemoteController(endpoint="http://host.docker.internal:8090", workers=10, # max_system_retries=5) - controller = RemoteController(workers=10, max_system_retries=5) + # controller = RemoteController(workers=10, max_system_retries=5) + case "hybrid": + from flyte._internal.controllers.remote._r_controller import RemoteController + + controller = RemoteController(endpoint="http://host.docker.internal:8090", workers=10, max_system_retries=5) + # controller = RemoteController(workers=10, max_system_retries=5) case "rust": # hybrid case, despite the case statement above, meant for local runs not inside docker from flyte._internal.controllers.remote._r_controller import RemoteController diff --git a/src/flyte/_internal/controllers/remote/_r_controller.py b/src/flyte/_internal/controllers/remote/_r_controller.py index d101cf544..73dfdd0a4 100644 --- a/src/flyte/_internal/controllers/remote/_r_controller.py +++ b/src/flyte/_internal/controllers/remote/_r_controller.py @@ -395,16 +395,15 @@ async def finalize_parent_action(self, action_id: ActionID): This method is invoked when the parent action is finished. 
It will finalize the run and upload the outputs to the control plane. """ - # todo-pr: implement any cleanup # translate the ActionID python object to something handleable in pyo3 # will need to do this after we have multiple informers. - # run_id = identifier_pb2.RunIdentifier( - # name=action_id.run_name, - # project=action_id.project, - # domain=action_id.domain, - # org=action_id.org, - # ) - # await super()._finalize_parent_action(run_id=run_id, parent_action_name=action_id.name) + run_id = identifier_pb2.RunIdentifier( + name=action_id.run_name, + project=action_id.project, + domain=action_id.domain, + org=action_id.org, + ) + await super().finalize_parent_action(run_id_bytes=run_id.SerializeToString(), parent_action_name=action_id.name) self._parent_action_semaphore.pop(unique_action_name(action_id), None) self._parent_action_task_call_sequence.pop(unique_action_name(action_id), None) From 8246b2774c8d94572b7e42d5ba21327b495bba15 Mon Sep 17 00:00:00 2001 From: Yee Hing Tong Date: Sun, 14 Dec 2025 10:48:25 +0800 Subject: [PATCH 12/22] ignore pr into pr (#424) Merged main into ctrl-rs. This PR is needed to get things working again. 
--------- Signed-off-by: Yee Hing Tong --- rs_controller/Cargo.lock | 4 +++- rs_controller/Cargo.toml | 6 +++--- rs_controller/Makefile | 4 +++- rs_controller/pyproject.toml | 2 +- src/flyte/_bin/runtime.py | 2 +- src/flyte/_internal/controllers/__init__.py | 3 ++- .../_internal/controllers/remote/_r_controller.py | 14 +++++++------- 7 files changed, 20 insertions(+), 15 deletions(-) diff --git a/rs_controller/Cargo.lock b/rs_controller/Cargo.lock index d72a2f3ac..90239fcfe 100644 --- a/rs_controller/Cargo.lock +++ b/rs_controller/Cargo.lock @@ -318,7 +318,9 @@ dependencies = [ [[package]] name = "flyteidl2" -version = "0.1.0" +version = "2.0.0-alpha15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5e5abd4c0481acd1132137e34f92c5e4e1aedf86aa49c8f770283349cfc6a618" dependencies = [ "async-trait", "futures", diff --git a/rs_controller/Cargo.toml b/rs_controller/Cargo.toml index c851c35eb..00758a655 100644 --- a/rs_controller/Cargo.toml +++ b/rs_controller/Cargo.toml @@ -32,9 +32,9 @@ tracing = "0.1" tracing-subscriber = "0.3" async-trait = "0.1" thiserror = "1.0" -# Using local flyteidl2 without extension-module -flyteidl2 = { path = "/Users/ytong/go/src/github.com/flyteorg/flyte/gen/rust" } -# flyteidl2 = "=2.0.0-alpha14" +# Uncomment this if you need to use local flyteidl2 +#flyteidl2 = { path = "/Users/ytong/go/src/github.com/flyteorg/flyte/gen/rust" } +flyteidl2 = "=2.0.0-alpha15" reqwest = { version = "0.12", features = ["json", "rustls-tls"] } serde = { version = "1.0", features = ["derive"] } serde_json = "1.0" diff --git a/rs_controller/Makefile b/rs_controller/Makefile index 7864bf29f..5bdec2a93 100644 --- a/rs_controller/Makefile +++ b/rs_controller/Makefile @@ -22,12 +22,14 @@ VERSION := $(shell git describe --tags --always --dirty 2>/dev/null | se dist-dirs: mkdir -p $(DIST_DIRS) $(CARGO_CACHE_DIR) +# Add the below to use local flyteidl2 +# -v 
/Users/ytong/go/src/github.com/flyteorg/flyte/gen/rust:/Users/ytong/go/src/github.com/flyteorg/flyte/gen/rust + define BUILD_WHEELS_RECIPE docker run --rm \ -v $(PWD):/io \ -v $(PWD)/docker_cargo_cache:/root/.cargo/registry \ -v $(CLOUD_REPO):/cloud \ - -v /Users/ytong/go/src/github.com/flyteorg/flyte/gen/rust:/Users/ytong/go/src/github.com/flyteorg/flyte/gen/rust \ wheel-builder:$(1) /bin/bash -c "\ cd /io; \ sed -i 's/^version = .*/version = \"$(VERSION)\"/' pyproject.toml; \ diff --git a/rs_controller/pyproject.toml b/rs_controller/pyproject.toml index 37ef3b58b..eb3fbbd8c 100644 --- a/rs_controller/pyproject.toml +++ b/rs_controller/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "maturin" [project] name = "flyte_controller_base" -version = "2.0.0b33.dev35+g4f8cb912.dirty" +version = "2.0.0b36.dev16+g8af195c9.dirty" description = "Rust controller for Union" requires-python = ">=3.10" classifiers = ["Programming Language :: Python", "Programming Language :: Rust"] diff --git a/src/flyte/_bin/runtime.py b/src/flyte/_bin/runtime.py index 851801fa4..68372ff0e 100644 --- a/src/flyte/_bin/runtime.py +++ b/src/flyte/_bin/runtime.py @@ -132,7 +132,7 @@ def main( if tgz or pkl: bundle = CodeBundle(tgz=tgz, pkl=pkl, destination=dest, computed_version=version) # Controller is created with the same kwargs as init, so that it can be used to run tasks - controller = create_controller(ct="remote", **controller_kwargs) + controller = create_controller(ct="rust", **controller_kwargs) ic = ImageCache.from_transport(image_cache) if image_cache else None diff --git a/src/flyte/_internal/controllers/__init__.py b/src/flyte/_internal/controllers/__init__.py index 5789ad257..398a68cdd 100644 --- a/src/flyte/_internal/controllers/__init__.py +++ b/src/flyte/_internal/controllers/__init__.py @@ -135,7 +135,8 @@ def create_controller( # hybrid case, despite the case statement above, meant for local runs not inside docker from flyte._internal.controllers.remote._r_controller import 
RemoteController - controller = RemoteController(endpoint="http://localhost:8090", workers=10, max_system_retries=5) + # controller = RemoteController(endpoint="http://localhost:8090", workers=10, max_system_retries=5) + controller = RemoteController(workers=10, max_system_retries=5) case _: raise ValueError(f"{ct} is not a valid controller type.") diff --git a/src/flyte/_internal/controllers/remote/_r_controller.py b/src/flyte/_internal/controllers/remote/_r_controller.py index 73dfdd0a4..f6b46a979 100644 --- a/src/flyte/_internal/controllers/remote/_r_controller.py +++ b/src/flyte/_internal/controllers/remote/_r_controller.py @@ -11,7 +11,7 @@ from typing import Any, DefaultDict, Tuple, TypeVar from flyte_controller_base import Action, BaseController -from flyteidl2.common import identifier_pb2 +from flyteidl2.common import identifier_pb2, phase_pb2 from flyteidl2.workflow import run_definition_pb2 import flyte @@ -80,7 +80,7 @@ async def handle_action_failure(action: Action, task_name: str) -> Exception: err = err_pb err = err or action.client_err - if not err and action.phase_value == run_definition_pb2.PHASE_FAILED: + if not err and action.phase_value == phase_pb2.ACTION_PHASE_FAILED: logger.error(f"Server reported failure for action {action.name}, checking error file.") try: # Deserialize action_id to get the name @@ -278,7 +278,7 @@ async def _submit(self, _task_call_seq: int, _task: TaskTemplate, *args, **kwarg raise # If the action is aborted, we should abort the controller as well - if n.phase_value == run_definition_pb2.PHASE_ABORTED: + if n.phase_value == phase_pb2.ACTION_PHASE_ABORTED: n_action_id_pb = identifier_pb2.ActionIdentifier() n_action_id_pb.ParseFromString(n.action_id_bytes) logger.warning( @@ -288,7 +288,7 @@ async def _submit(self, _task_call_seq: int, _task: TaskTemplate, *args, **kwarg f"Action {n_action_id_pb.name} was aborted, aborting current Action {current_action_id.name}" ) - if n.phase_value == run_definition_pb2.PHASE_TIMED_OUT: + 
if n.phase_value == phase_pb2.ACTION_PHASE_TIMED_OUT: n_action_id_pb = identifier_pb2.ActionIdentifier() n_action_id_pb.ParseFromString(n.action_id_bytes) logger.warning( @@ -298,7 +298,7 @@ async def _submit(self, _task_call_seq: int, _task: TaskTemplate, *args, **kwarg f"Action {n_action_id_pb.name} timed out, raising exception in current Action {current_action_id.name}" ) - if n.has_error() or n.phase_value == run_definition_pb2.PHASE_FAILED: + if n.has_error() or n.phase_value == phase_pb2.ACTION_PHASE_FAILED: exc = await handle_action_failure(n, _task.name) raise exc @@ -457,7 +457,7 @@ async def get_action_outputs( if prev_action is None: return TraceInfo(func_name, sub_action_id, _interface, inputs_uri), False - if prev_action.phase_value == run_definition_pb2.PHASE_FAILED: + if prev_action.phase_value == phase_pb2.ACTION_PHASE_FAILED: if prev_action.has_error(): # Deserialize err from bytes from flyteidl2.core import execution_pb2 @@ -623,7 +623,7 @@ async def _submit_task_ref(self, invoke_seq_num: int, _task: TaskDetails, *args, await self.cancel_action(action) raise - if n.has_error() or n.phase_value == run_definition_pb2.PHASE_FAILED: + if n.has_error() or n.phase_value == phase_pb2.ACTION_PHASE_FAILED: exc = await handle_action_failure(n, task_name) raise exc From 2b17e424cf6cedb41bc240a362ea68bfea0d4488 Mon Sep 17 00:00:00 2001 From: Yee Hing Tong Date: Sat, 13 Dec 2025 18:49:14 -0800 Subject: [PATCH 13/22] remove rs_controller from gitignore Signed-off-by: Yee Hing Tong --- .gitignore | 3 --- 1 file changed, 3 deletions(-) diff --git a/.gitignore b/.gitignore index ebbbf02b5..efb0f6f99 100644 --- a/.gitignore +++ b/.gitignore @@ -9,9 +9,6 @@ __pycache__/ # C extensions *.so -# Temporary -rs_controller/ - # Distribution / packaging .Python build/ From 6e63f23bc631d1fd4ad46a8a024d3586568701e6 Mon Sep 17 00:00:00 2001 From: Yee Hing Tong Date: Tue, 16 Dec 2025 10:07:44 +0800 Subject: [PATCH 14/22] pr into ctrl-rs - error handling (#427) * Update 
informer errors to use its own errors. * Add watch_for_errors. * Store the singular controller worker into a field so it can be raced in watch_for_errors. * Also race the informer error channel. * Some cleanup: remove unused fields from core base controller. * Adding a timeout for now to thread.join in _r_controller's stop. Otherwise sync tasks were hanging. needs further investigation. --------- Signed-off-by: Yee Hing Tong --- examples/advanced/cancel_tasks.py | 13 +- examples/basics/hello.py | 16 ++ examples/stress/crash_recovery_trace.py | 12 ++ examples/stress/large_dir_io.py | 12 ++ examples/stress/large_file_io.py | 12 ++ examples/stress/long_recovery.py | 12 +- rs_controller/pyproject.toml | 2 +- rs_controller/src/core.rs | 118 ++++++++++--- rs_controller/src/error.rs | 20 ++- rs_controller/src/informer.rs | 160 ++++++++++++++---- rs_controller/src/lib.rs | 32 +++- .../controllers/remote/_r_controller.py | 9 +- 12 files changed, 344 insertions(+), 74 deletions(-) diff --git a/examples/advanced/cancel_tasks.py b/examples/advanced/cancel_tasks.py index c776b7c7e..64ca870f3 100644 --- a/examples/advanced/cancel_tasks.py +++ b/examples/advanced/cancel_tasks.py @@ -2,7 +2,18 @@ import flyte.errors -env = flyte.TaskEnvironment("cancel") +from pathlib import Path + +import flyte +from flyte._image import PythonWheels + +controller_dist_folder = Path("/Users/ytong/go/src/github.com/flyteorg/flyte-sdk/rs_controller/dist") +wheel_layer = PythonWheels(wheel_dir=controller_dist_folder, package_name="flyte_controller_base") +base = flyte.Image.from_debian_base() +rs_controller_image = base.clone(addl_layer=wheel_layer) + + +env = flyte.TaskEnvironment("cancel", image=rs_controller_image) @env.task diff --git a/examples/basics/hello.py b/examples/basics/hello.py index 46095d0ad..226590d45 100644 --- a/examples/basics/hello.py +++ b/examples/basics/hello.py @@ -1,9 +1,18 @@ import flyte +from flyte._image import PythonWheels +from pathlib import Path + 
+controller_dist_folder = Path("/Users/ytong/go/src/github.com/flyteorg/flyte-sdk/rs_controller/dist") +wheel_layer = PythonWheels(wheel_dir=controller_dist_folder, package_name="flyte_controller_base") +base = flyte.Image.from_debian_base() +rs_controller_image = base.clone(addl_layer=wheel_layer) + # TaskEnvironments provide a simple way of grouping configuration used by tasks (more later). env = flyte.TaskEnvironment( name="hello_world", resources=flyte.Resources(memory="250Mi"), + image=rs_controller_image, ) @@ -26,6 +35,13 @@ def main(x_list: list[int]) -> float: return y_mean +@env.task +def main2(x_list: list[int]) -> float: + y = fn(x_list[0]) + print(f"y = {y}!!!", flush=True) + return float(y) + + if __name__ == "__main__": flyte.init_from_config() # establish remote connection from within your script. run = flyte.run(main, x_list=list(range(10))) # run remotely inline and pass data. diff --git a/examples/stress/crash_recovery_trace.py b/examples/stress/crash_recovery_trace.py index 1650b92e2..6ad140177 100644 --- a/examples/stress/crash_recovery_trace.py +++ b/examples/stress/crash_recovery_trace.py @@ -3,9 +3,21 @@ import flyte import flyte.errors +from pathlib import Path + +import flyte +from flyte._image import PythonWheels + +controller_dist_folder = Path("/Users/ytong/go/src/github.com/flyteorg/flyte-sdk/rs_controller/dist") +wheel_layer = PythonWheels(wheel_dir=controller_dist_folder, package_name="flyte_controller_base") +base = flyte.Image.from_debian_base() +rs_controller_image = base.clone(addl_layer=wheel_layer) + + env = flyte.TaskEnvironment( name="crash_recovery_trace", resources=flyte.Resources(memory="250Mi", cpu=1), + image=rs_controller_image ) diff --git a/examples/stress/large_dir_io.py b/examples/stress/large_dir_io.py index 5c7005446..c692f7e05 100644 --- a/examples/stress/large_dir_io.py +++ b/examples/stress/large_dir_io.py @@ -10,12 +10,24 @@ import flyte.io import flyte.storage +from pathlib import Path + +import flyte +from 
flyte._image import PythonWheels + +controller_dist_folder = Path("/Users/ytong/go/src/github.com/flyteorg/flyte-sdk/rs_controller/dist") +wheel_layer = PythonWheels(wheel_dir=controller_dist_folder, package_name="flyte_controller_base") +base = flyte.Image.from_debian_base() +rs_controller_image = base.clone(addl_layer=wheel_layer) + + env = flyte.TaskEnvironment( "large_dir_io", resources=flyte.Resources( cpu=4, memory="16Gi", ), + image=rs_controller_image, ) diff --git a/examples/stress/large_file_io.py b/examples/stress/large_file_io.py index 09575bfca..2a6f15537 100644 --- a/examples/stress/large_file_io.py +++ b/examples/stress/large_file_io.py @@ -9,12 +9,24 @@ import flyte.io import flyte.storage +from pathlib import Path + +import flyte +from flyte._image import PythonWheels + +controller_dist_folder = Path("/Users/ytong/go/src/github.com/flyteorg/flyte-sdk/rs_controller/dist") +wheel_layer = PythonWheels(wheel_dir=controller_dist_folder, package_name="flyte_controller_base") +base = flyte.Image.from_debian_base() +rs_controller_image = base.clone(addl_layer=wheel_layer) + + env = flyte.TaskEnvironment( "large_file_io", resources=flyte.Resources( cpu=4, memory="16Gi", ), + image=rs_controller_image, ) diff --git a/examples/stress/long_recovery.py b/examples/stress/long_recovery.py index 348787da6..ae9f731b7 100644 --- a/examples/stress/long_recovery.py +++ b/examples/stress/long_recovery.py @@ -5,7 +5,17 @@ import flyte import flyte.errors -env = flyte.TaskEnvironment(name="controller_stressor") +from pathlib import Path + +import flyte +from flyte._image import PythonWheels + +controller_dist_folder = Path("/Users/ytong/go/src/github.com/flyteorg/flyte-sdk/rs_controller/dist") +wheel_layer = PythonWheels(wheel_dir=controller_dist_folder, package_name="flyte_controller_base") +base = flyte.Image.from_debian_base() +rs_controller_image = base.clone(addl_layer=wheel_layer) + +env = flyte.TaskEnvironment(name="controller_stressor", image=rs_controller_image) 
@env.task diff --git a/rs_controller/pyproject.toml b/rs_controller/pyproject.toml index eb3fbbd8c..73f549785 100644 --- a/rs_controller/pyproject.toml +++ b/rs_controller/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "maturin" [project] name = "flyte_controller_base" -version = "2.0.0b36.dev16+g8af195c9.dirty" +version = "2.0.0b36.dev23+gdb40f621.dirty" description = "Rust controller for Union" requires-python = ">=3.10" classifiers = ["Programming Language :: Python", "Programming Language :: Rust"] diff --git a/rs_controller/src/core.rs b/rs_controller/src/core.rs index 417640cb4..31b534aa3 100644 --- a/rs_controller/src/core.rs +++ b/rs_controller/src/core.rs @@ -2,20 +2,18 @@ //! This module can be used by both Python bindings and standalone Rust binaries use std::sync::Arc; -use std::sync::OnceLock; use std::time::Duration; use pyo3_async_runtimes::tokio::get_runtime; use tokio::sync::mpsc; use tokio::sync::oneshot; -use tokio::sync::OnceCell; use tokio::time::sleep; use tonic::transport::{Certificate, ClientTlsConfig, Endpoint}; use tonic::Status; use tower::ServiceBuilder; use tracing::{debug, error, info, warn}; -use crate::action::{Action, ActionType}; +use crate::action::Action; use crate::auth::{AuthConfig, AuthLayer, ClientCredentialsAuthenticator}; use crate::error::{ControllerError, InformerError}; use crate::informer::{Informer, InformerCache}; @@ -110,13 +108,12 @@ impl QueueClient { } pub struct CoreBaseController { - channel: ChannelType, informer_cache: InformerCache, - state_client: StateClient, queue_client: QueueClient, shared_queue: mpsc::Sender, shared_queue_rx: Arc>>, - failure_rx: mpsc::Receiver, + failure_rx: Arc>>>, + bg_worker_handle: Arc>>>, } impl CoreBaseController { @@ -178,21 +175,24 @@ impl CoreBaseController { InformerCache::new(state_client.clone(), shared_tx.clone(), failure_tx); let real_base_controller = CoreBaseController { - channel, informer_cache, - state_client, queue_client, shared_queue: shared_tx, shared_queue_rx: 
Arc::new(tokio::sync::Mutex::new(shared_queue_rx)), - failure_rx, + failure_rx: Arc::new(std::sync::Mutex::new(Some(failure_rx))), + bg_worker_handle: Arc::new(std::sync::Mutex::new(None)), }; let real_base_controller = Arc::new(real_base_controller); // Start the background worker let controller_clone = real_base_controller.clone(); - rt.spawn(async move { + let handle = rt.spawn(async move { controller_clone.bg_worker().await; }); + + // Store the handle + *real_base_controller.bg_worker_handle.lock().unwrap() = Some(handle); + Ok(real_base_controller) } @@ -204,8 +204,7 @@ impl CoreBaseController { let rt = get_runtime(); let channel = rt.block_on(async { let chan = if endpoint.starts_with("http://") { - let endpoint = Endpoint::from_static(endpoint_static) - .keep_alive_while_idle(true); + let endpoint = Endpoint::from_static(endpoint_static).keep_alive_while_idle(true); endpoint.connect().await.map_err(ControllerError::from)? } else if endpoint.starts_with("https://") { // Strip "https://" to get just the hostname for TLS config @@ -245,21 +244,24 @@ impl CoreBaseController { InformerCache::new(state_client.clone(), shared_tx.clone(), failure_tx); let real_base_controller = CoreBaseController { - channel, informer_cache, - state_client, queue_client, shared_queue: shared_tx, shared_queue_rx: Arc::new(tokio::sync::Mutex::new(shared_queue_rx)), - failure_rx, + failure_rx: Arc::new(std::sync::Mutex::new(Some(failure_rx))), + bg_worker_handle: Arc::new(std::sync::Mutex::new(None)), }; let real_base_controller = Arc::new(real_base_controller); // Start the background worker let controller_clone = real_base_controller.clone(); - rt.spawn(async move { + let handle = rt.spawn(async move { controller_clone.bg_worker().await; }); + + // Store the handle + *real_base_controller.bg_worker_handle.lock().unwrap() = Some(handle); + Ok(real_base_controller) } @@ -314,14 +316,29 @@ impl CoreBaseController { .get(&action.get_run_identifier(), &action.parent_action_name) .await; 
if let Some(informer) = opt_informer { - // todo: check these two errors - // Before firing completion event, update the action in the // informer, otherwise client_err will not be set. - let _ = informer.set_action_client_err(&action).await; - let _ = informer + // todo: gain a better understanding of these two errors and handle + let res = informer.set_action_client_err(&action).await; + match res { + Ok(()) => {} + Err(e) => { + error!( + "Error setting error for failed action {}: {}", + &action.get_full_name(), + e + ) + } + } + let res = informer .fire_completion_event(&action.action_id.name) .await; + match res { + Ok(()) => {} + Err(e) => { + error!("Error firing completion event for failed action {}: {}", &action.get_full_name(), e) + } + } } else { error!( "Max retries hit for action but informer missing: {:?}", @@ -413,7 +430,8 @@ impl CoreBaseController { return Ok(()); } - debug!("Cancelling action: {}", action.action_id.name); + // debug + warn!("Cancelling action!!!: {}", action.action_id.name); action.mark_cancelled(); if let Some(informer) = self @@ -614,4 +632,62 @@ impl CoreBaseController { } } } + + pub async fn watch_for_errors(&self) -> Result<(), ControllerError> { + // Take ownership of both (can only be called once) + let handle = self.bg_worker_handle.lock().unwrap().take(); + let failure_rx = self.failure_rx.lock().unwrap().take(); + + match (handle, failure_rx) { + (Some(handle), Some(mut rx)) => { + // Race bg_worker completion vs informer errors + tokio::select! 
{ + // bg_worker completed or panicked + result = handle => { + match result { + Ok(_) => { + error!("Background worker exited unexpectedly"); + Err(ControllerError::RuntimeError( + "Background worker exited unexpectedly".to_string(), + )) + } + Err(e) if e.is_panic() => { + error!("Background worker panicked: {:?}", e); + Err(ControllerError::RuntimeError(format!( + "Background worker panicked: {:?}", + e + ))) + } + Err(e) => { + error!("Background worker was cancelled: {:?}", e); + Err(ControllerError::RuntimeError(format!( + "Background worker cancelled: {:?}", + e + ))) + } + } + } + + // Informer error received + informer_err = rx.recv() => { + match informer_err { + Some(err) => { + error!("Informer error received: {:?}", err); + Err(ControllerError::Informer(err)) + } + None => { + error!("Informer error channel closed unexpectedly"); + Err(ControllerError::RuntimeError( + "Informer error channel closed unexpectedly".to_string(), + )) + } + } + } + } + } + _ => Err(ControllerError::RuntimeError( + "watch_for_errors already called or resources not available".to_string(), + )), + } + } } diff --git a/rs_controller/src/error.rs b/rs_controller/src/error.rs index 6bf8c73e8..90ca3b418 100644 --- a/rs_controller/src/error.rs +++ b/rs_controller/src/error.rs @@ -14,6 +14,8 @@ pub enum ControllerError { GrpcError(#[from] tonic::Status), #[error("Task error: {0}")] TaskError(String), + #[error("Informer error: {0}")] + Informer(#[from] InformerError), } impl From for ControllerError { @@ -28,7 +30,7 @@ impl From for ControllerError { } } -#[derive(Error, Debug)] +#[derive(Error, Debug, Clone)] pub enum InformerError { #[error("Informer watch failed for run {run_name}, parent action {parent_action_name}: {error_message}")] WatchFailed { @@ -36,4 +38,20 @@ pub enum InformerError { parent_action_name: String, error_message: String, }, + #[error("gRPC error in watch stream: {0}")] + GrpcError(String), + #[error("Stream error: {0}")] + StreamError(String), + 
#[error("Failed to send action to queue: {0}")] + QueueSendError(String), + #[error("Watch cancelled")] + Cancelled, + #[error("Bad context: {0}")] + BadContext(String), +} + +impl From for InformerError { + fn from(status: tonic::Status) -> Self { + InformerError::GrpcError(format!("{:?}", status)) + } } diff --git a/rs_controller/src/informer.rs b/rs_controller/src/informer.rs index 7195e920e..f7f915ad9 100644 --- a/rs_controller/src/informer.rs +++ b/rs_controller/src/informer.rs @@ -1,16 +1,29 @@ use crate::action::Action; use crate::core::StateClient; use crate::error::{ControllerError, InformerError}; +use tokio::time; use tokio_util::sync::CancellationToken; +/// Determine if an InformerError is retryable +fn is_retryable_error(err: &InformerError) -> bool { + match err { + // Retryable gRPC and stream errors + InformerError::GrpcError(_) => true, + InformerError::StreamError(_) => true, + + // Don't retry these + InformerError::Cancelled => false, + InformerError::BadContext(_) => false, + InformerError::QueueSendError(_) => false, + InformerError::WatchFailed { .. 
} => false, + } +} use flyteidl2::flyteidl::common::ActionIdentifier; use flyteidl2::flyteidl::common::RunIdentifier; -use flyteidl2::flyteidl::workflow::state_service_client::StateServiceClient; use flyteidl2::flyteidl::workflow::{ watch_request, watch_response::Message, WatchRequest, WatchResponse, }; -use pyo3_async_runtimes::tokio::run; use std::collections::HashMap; use std::sync::atomic::{AtomicBool, Ordering}; use std::sync::Arc; @@ -18,12 +31,7 @@ use std::time::Duration; use tokio::select; use tokio::sync::RwLock; use tokio::sync::{mpsc, oneshot, Notify}; -use tokio::time::sleep; -use tonic::transport::channel::Channel; -use tonic::transport::Endpoint; -use tracing::log::Level::Info; use tracing::{debug, error, info, warn}; -use tracing_subscriber::fmt; #[derive(Clone, Debug)] pub struct Informer { @@ -81,7 +89,7 @@ impl Informer { async fn handle_watch_response( &self, response: WatchResponse, - ) -> Result, ControllerError> { + ) -> Result, InformerError> { debug!( "Informer for {:?}::{} processing incoming message {:?}", self.run_id.name, self.parent_action_name, &response @@ -103,7 +111,7 @@ impl Informer { .action_id .as_ref() .map(|act_id| act_id.name.clone()) - .ok_or(ControllerError::RuntimeError(format!( + .ok_or(InformerError::StreamError(format!( "Action update received without a name: {:?}", action_update )))?; @@ -132,13 +140,13 @@ impl Informer { } } } else { - Err(ControllerError::BadContext( + Err(InformerError::BadContext( "No message in response".to_string(), )) } } - async fn watch_actions(&self) -> Result<(), ControllerError> { + async fn watch_actions(&self) -> Result<(), InformerError> { let action_id = ActionIdentifier { name: self.parent_action_name.clone(), run: Some(self.run_id.clone()), @@ -153,7 +161,7 @@ impl Informer { Ok(s) => s.into_inner(), Err(e) => { error!("Failed to start watch stream: {:?}", e); - return Err(ControllerError::from(e)); + return Err(InformerError::from(e)); } }; @@ -161,7 +169,7 @@ impl Informer { 
select! { _ = self.cancellation_token.cancelled() => { warn!("Cancellation token got - exiting from watch_actions: {}", self.parent_action_name); - return Ok(()) + return Err(InformerError::Cancelled) } result = stream.message() => { @@ -175,7 +183,7 @@ impl Informer { } Err(e) => { error!("Informer watch failed sending action back to shared queue: {:?}", e); - return Err(ControllerError::RuntimeError(format!( + return Err(InformerError::QueueSendError(format!( "Failed to send action to shared queue: {}", e ))); @@ -187,8 +195,7 @@ impl Informer { ); } Err(err) => { - // this should cascade up to the controller to restart the informer, and if there - // are too many informer restarts, the controller should fail + // this should cascade up to retry logic error!("Error in informer watch {:?}", err); return Err(err); } @@ -199,7 +206,7 @@ impl Informer { } // Stream ended, exit loop Err(e) => { error!("Error receiving message from stream: {:?}", e); - return Err(ControllerError::from(e)); + return Err(InformerError::from(e)); } } } @@ -237,7 +244,10 @@ impl Informer { { let mut completion_events = self.completion_events.write().await; completion_events.insert(action_name.clone(), done_tx); - warn!("---------> Adding completion event in submit action {:?}", action_name); + warn!( + "---------> Adding completion event in submit action {:?}", + action_name + ); } // Add action to shared queue @@ -264,7 +274,7 @@ impl Informer { action_name, ); // Maybe the action hasn't started yet. 
- return Ok(()) + return Ok(()); } Ok(()) } @@ -335,13 +345,19 @@ impl InformerCache { } // Create new informer (with write lock) - debug!("Acquiring write lock to create informer for: {}", informer_name); + debug!( + "Acquiring write lock to create informer for: {}", + informer_name + ); let mut map = self.cache.write().await; info!("Write lock acquired for: {}", informer_name); // Double-check it wasn't created while we were waiting for write lock if let Some(informer) = map.get(&informer_name) { - info!("RACE: Informer was created while waiting for write lock: {}", informer_name); + info!( + "RACE: Informer was created while waiting for write lock: {}", + informer_name + ); let arc_informer = Arc::clone(informer); drop(map); debug!("Write lock released after race condition"); @@ -371,33 +387,99 @@ impl InformerCache { info!("Spawning watch task for: {}", informer_name); let _watch_handle = tokio::spawn(async move { + const MAX_RETRIES: u32 = 10; + const MIN_BACKOFF_SECS: f64 = 1.0; + const MAX_BACKOFF_SECS: f64 = 30.0; + + let mut retries = 0; + let mut last_error: Option = None; debug!("Watch task started for: {}", me.parent_action_name); - let watch_actions_result = me.watch_actions().await; - // If there are errors with the watch then notify the channel - if watch_actions_result.is_err() { - let err = watch_actions_result.err().unwrap(); + while retries < MAX_RETRIES { + if retries > 0 { + warn!( + "Informer watch retrying for {}, attempt {}/{}", + me.parent_action_name, + retries + 1, + MAX_RETRIES + ); + } + + let watch_result = me.watch_actions().await; + match watch_result { + Ok(()) => { + // Clean exit (should only happen on cancellation) + info!("Watch completed cleanly for {}", me.parent_action_name); + last_error = None; + break; + } + Err(InformerError::Cancelled) => { + // Don't retry cancellations + info!( + "Watch cancelled for {}, exiting without retry", + me.parent_action_name + ); + last_error = None; + break; + } + Err(err) if 
is_retryable_error(&err) => { + retries += 1; + last_error = Some(err.clone()); + + warn!( + "Watch failed for {} (retry {}/{}): {:?}", + me.parent_action_name, retries, MAX_RETRIES, err + ); + + if retries < MAX_RETRIES { + // Exponential backoff: 1s, 2s, 4s, 8s, 16s, 30s (capped) + let backoff = MIN_BACKOFF_SECS * 2_f64.powi((retries - 1) as i32); + let backoff = backoff.min(MAX_BACKOFF_SECS); + warn!("Backing off for {:.2}s before retry", backoff); + time::sleep(Duration::from_secs_f64(backoff)).await; + } + } + Err(err) => { + // Non-retryable error + error!( + "Non-retryable error for {}: {:?}", + me.parent_action_name, err + ); + last_error = Some(err); + break; + } + } + } + + // Only send error if we have one (clean exits and cancellations set last_error = None) + if let Some(err) = last_error { + // We have an error - either exhausted retries or non-retryable error!( - "Informer watch_actions failed for run {}, parent action {}: {:?}", - me.run_id.name, me.parent_action_name, err + "Informer watch failed for run {}, parent action {} (retries: {}/{}): {:?}", + me.run_id.name, me.parent_action_name, retries, MAX_RETRIES, err ); let failure = InformerError::WatchFailed { run_name: me.run_id.name.clone(), parent_action_name: me.parent_action_name.clone(), - error_message: err.to_string(), + error_message: format!( + "Retries ({}/{}) exhausted. Last error: {}", + retries, MAX_RETRIES, err + ), }; if let Err(e) = failure_tx.send(failure).await { error!("Failed to send informer failure event: {:?}", e); } - } else { - info!("Informer watch_actions completed successfully for {}", me.run_id.name); } + // If last_error is None, it's a clean exit (Ok or Cancelled) - no error to send }); // save the value and ignore the returned reference. 
- debug!("Acquiring write lock to save watch handle for: {}", informer_name); + debug!( + "Acquiring write lock to save watch handle for: {}", + informer_name + ); let _ = informer.watch_handle.write().await.insert(_watch_handle); info!("Watch handle saved for: {}", informer_name); @@ -405,7 +487,10 @@ impl InformerCache { debug!("Waiting for informer to be ready: {}", informer_name); Self::wait_for_ready(&informer, timeout).await; - info!("<<< Returning newly created informer for: {}", informer_name); + info!( + "<<< Returning newly created informer for: {}", + informer_name + ); informer } @@ -438,7 +523,10 @@ impl InformerCache { // Quick check - if already ready, return immediately if informer.is_ready.load(Ordering::Acquire) { - info!("Informer already ready for: {}", informer.parent_action_name); + info!( + "Informer already ready for: {}", + informer.parent_action_name + ); return; } @@ -446,7 +534,10 @@ impl InformerCache { // Otherwise wait with timeout match tokio::time::timeout(timeout, ready_fut).await { Ok(_) => { - info!("Informer ready signal received for: {}", informer.parent_action_name); + info!( + "Informer ready signal received for: {}", + informer.parent_action_name + ); } Err(_) => { warn!( @@ -477,7 +568,6 @@ impl InformerCache { } } - #[cfg(test)] mod tests { use super::*; diff --git a/rs_controller/src/lib.rs b/rs_controller/src/lib.rs index b2fdcf3ce..d60f0e02c 100644 --- a/rs_controller/src/lib.rs +++ b/rs_controller/src/lib.rs @@ -10,26 +10,19 @@ pub mod proto; // Python bindings - thin wrappers around core types use std::sync::Arc; -use std::time::Duration; use pyo3::exceptions; use pyo3::prelude::*; use pyo3::types::PyAny; use pyo3_async_runtimes::tokio::future_into_py; -use tower::ServiceBuilder; use tracing::{error, info, warn}; use tracing_subscriber::FmtSubscriber; use crate::action::{Action, ActionType}; -use crate::auth::{AuthConfig, AuthLayer, ClientCredentialsAuthenticator}; use crate::core::CoreBaseController; use 
crate::error::ControllerError; -use flyteidl2::flyteidl::common::{ActionIdentifier, ProjectIdentifier, RunIdentifier}; -use flyteidl2::flyteidl::task::task_service_client::TaskServiceClient; -use flyteidl2::flyteidl::task::{list_tasks_request, ListTasksRequest}; -use flyteidl2::flyteidl::workflow::state_service_client::StateServiceClient; +use flyteidl2::flyteidl::common::{ActionIdentifier, RunIdentifier}; use prost::Message; -use tonic::transport::Endpoint; // Python error conversions impl From for PyErr { @@ -93,10 +86,17 @@ impl BaseController { fn get_action<'py>( &self, py: Python<'py>, - action_id: ActionIdentifier, + // action_id: ActionIdentifier, + action_id_bytes: &[u8], parent_action_name: String, ) -> PyResult> { let real_base = self.0.clone(); + let action_id = ActionIdentifier::decode(action_id_bytes).map_err(|e| { + pyo3::exceptions::PyValueError::new_err(format!( + "Failed to decode ActionIdentifier: {}", + e + )) + })?; let py_fut = future_into_py(py, async move { real_base .get_action(action_id.clone(), parent_action_name.as_str()) @@ -132,6 +132,20 @@ impl BaseController { }); py_fut } + + fn watch_for_errors<'py>(&self, py: Python<'py>) -> PyResult> { + let base = self.0.clone(); + let py_fut = future_into_py(py, async move { + base.watch_for_errors().await.map_err(|e| { + error!("Controller watch_for_errors detected failure: {:?}", e); + exceptions::PyRuntimeError::new_err(format!( + "Controller watch ended with failure: {}", + e + )) + }) + }); + py_fut + } } #[pymodule] diff --git a/src/flyte/_internal/controllers/remote/_r_controller.py b/src/flyte/_internal/controllers/remote/_r_controller.py index f6b46a979..4085d0d9d 100644 --- a/src/flyte/_internal/controllers/remote/_r_controller.py +++ b/src/flyte/_internal/controllers/remote/_r_controller.py @@ -373,10 +373,9 @@ def exc_handler(loop, context): fut = asyncio.run_coroutine_threadsafe(coro, self._submit_loop) return fut - # will be implemented in the future, should await for errors 
coming from the rust layer async def watch_for_errors(self): - e = Event() - await e.wait() + """ This pattern works better with utils.run_coros """ + await super().watch_for_errors() async def stop(self): """ @@ -385,7 +384,7 @@ async def stop(self): if self._submit_loop is not None: self._submit_loop.stop() if self._submit_thread is not None: - self._submit_thread.join() + self._submit_thread.join(0.01) self._submit_loop = None self._submit_thread = None logger.info("RemoteController stopped.") @@ -450,7 +449,7 @@ async def get_action_outputs( ), ) prev_action = await self.get_action( - sub_action_id_pb, + sub_action_id_pb.SerializeToString(), current_action_id.name, ) From 6e81b2953e019b96a55fd7ecd24c0711ea6d4737 Mon Sep 17 00:00:00 2001 From: Nary Yeh <60069744+machichima@users.noreply.github.com> Date: Tue, 16 Dec 2025 03:29:01 -0800 Subject: [PATCH 15/22] Improve Rust controller devex (#431) - Add `make dev-rs-dist` to run `make build-wheels`, `make dist`, build image, and install `flyte_controller_base` locally - Stop `_submit_loop` in rust controller correctly when running sync task - Remove color code in remote logging (see following image) image --------- Signed-off-by: machichima --- Makefile | 11 +++++++++++ rs_controller/Makefile | 2 +- rs_controller/src/lib.rs | 6 ++++++ src/flyte/_internal/controllers/__init__.py | 3 ++- .../_internal/controllers/remote/_r_controller.py | 4 ++-- 5 files changed, 22 insertions(+), 4 deletions(-) diff --git a/Makefile b/Makefile index 14e15c903..e66f17701 100644 --- a/Makefile +++ b/Makefile @@ -1,3 +1,8 @@ +# Default registry for image builds +REGISTRY ?= ghcr.io/flyteorg +# Default name for connector image +CONNECTOR_IMAGE_NAME ?= flyte-connector + # Default target: show all available targets .PHONY: help help: @@ -78,6 +83,12 @@ unit_test_plugins: fi \ done +.PHONY: dev-rs-dist +dev-rs-dist: + cd rs_controller && $(MAKE) build-wheels + $(MAKE) dist + uv run python maint_tools/build_default_image.py --registry 
$(REGISTRY) --name $(CONNECTOR_IMAGE_NAME) + uv pip install --find-links ./rs_controller/dist --no-index --force-reinstall --no-deps flyte_controller_base .PHONY: cli-docs-gen cli-docs-gen: ## Generate CLI documentation diff --git a/rs_controller/Makefile b/rs_controller/Makefile index 5bdec2a93..1c8106a0d 100644 --- a/rs_controller/Makefile +++ b/rs_controller/Makefile @@ -48,7 +48,7 @@ build-wheels-amd64: dist-dirs clean_dist: rm -rf $(DIST_DIRS)/*whl -build-wheels: clean_dist build-wheels-arm64 build-wheels-amd64 +build-wheels: build-wheels-arm64 build-wheels-amd64 build-wheel-local # This is for Mac users, since the other targets won't build macos wheels (only local arch so probably arm64) build-wheel-local: dist-dirs diff --git a/rs_controller/src/lib.rs b/rs_controller/src/lib.rs index d60f0e02c..bf69c7590 100644 --- a/rs_controller/src/lib.rs +++ b/rs_controller/src/lib.rs @@ -152,8 +152,14 @@ impl BaseController { fn flyte_controller_base(_py: Python<'_>, m: &Bound<'_, PyModule>) -> PyResult<()> { static INIT: std::sync::Once = std::sync::Once::new(); INIT.call_once(|| { + // Check if running remotely by checking if FLYTE_INTERNAL_EXECUTION_PROJECT is set + let is_remote = std::env::var("FLYTE_INTERNAL_EXECUTION_PROJECT").is_ok(); + let is_rich_logging_disabled = std::env::var("DISABLE_RICH_LOGGING").is_ok(); + let disable_ansi = is_remote || is_rich_logging_disabled; + let subscriber = FmtSubscriber::builder() .with_max_level(tracing::Level::DEBUG) + .with_ansi(!disable_ansi) .finish(); tracing::subscriber::set_global_default(subscriber) .expect("Failed to set global tracing subscriber"); diff --git a/src/flyte/_internal/controllers/__init__.py b/src/flyte/_internal/controllers/__init__.py index 398a68cdd..52e9821d1 100644 --- a/src/flyte/_internal/controllers/__init__.py +++ b/src/flyte/_internal/controllers/__init__.py @@ -135,7 +135,8 @@ def create_controller( # hybrid case, despite the case statement above, meant for local runs not inside docker from 
flyte._internal.controllers.remote._r_controller import RemoteController - # controller = RemoteController(endpoint="http://localhost:8090", workers=10, max_system_retries=5) + # for devbox + # controller = RemoteController(endpoint="http://host.docker.internal:8090", workers=10, max_system_retries=5) # noqa: E501 controller = RemoteController(workers=10, max_system_retries=5) case _: raise ValueError(f"{ct} is not a valid controller type.") diff --git a/src/flyte/_internal/controllers/remote/_r_controller.py b/src/flyte/_internal/controllers/remote/_r_controller.py index 4085d0d9d..d0270195b 100644 --- a/src/flyte/_internal/controllers/remote/_r_controller.py +++ b/src/flyte/_internal/controllers/remote/_r_controller.py @@ -382,9 +382,9 @@ async def stop(self): Stop the controller. Incomplete, needs to gracefully shut down the rust controller as well. """ if self._submit_loop is not None: - self._submit_loop.stop() + self._submit_loop.call_soon_threadsafe(self._submit_loop.stop) if self._submit_thread is not None: - self._submit_thread.join(0.01) + self._submit_thread.join() self._submit_loop = None self._submit_thread = None logger.info("RemoteController stopped.") From 7f56f9a8e0098a9d71f0fa864952794f6a1865f4 Mon Sep 17 00:00:00 2001 From: Sergey Vilgelm <523825+SVilgelm@users.noreply.github.com> Date: Wed, 24 Dec 2025 12:28:54 -0800 Subject: [PATCH 16/22] refactor[rs_controller]: Use nightly fmt to reorganize imports (#473) Reorganize module imports across several files to use grouped and brace-formatted use statements for better readability and consistency. Add a Makefile fmt target that enforces use of the nightly toolchain and runs `cargo +nightly fmt --all` to standardize formatting in CI and local development. These changes improve code consistency, maintainability, and make it easier to apply automatic formatting across the repository. 
Signed-off-by: Sergey Vilgelm --- .github/workflows/lint.yml | 13 ++++++ rs_controller/Makefile | 10 ++++ rs_controller/rustfmt.toml | 6 +++ rs_controller/src/action.rs | 17 +++---- rs_controller/src/auth/client_credentials.rs | 15 ++++-- rs_controller/src/auth/config.rs | 3 +- rs_controller/src/auth/middleware.rs | 10 ++-- rs_controller/src/auth/token_client.rs | 6 ++- rs_controller/src/bin/test_auth.rs | 4 +- rs_controller/src/bin/test_controller.rs | 3 +- rs_controller/src/bin/try_list_tasks.rs | 18 +++---- rs_controller/src/bin/try_watch.rs | 17 +++---- rs_controller/src/core.rs | 49 +++++++++++--------- rs_controller/src/informer.rs | 45 ++++++++++-------- rs_controller/src/lib.rs | 16 +++---- rs_controller/src/proto/mod.rs | 9 ++-- 16 files changed, 149 insertions(+), 92 deletions(-) create mode 100644 rs_controller/rustfmt.toml diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml index d5d358875..859866a12 100644 --- a/.github/workflows/lint.yml +++ b/.github/workflows/lint.yml @@ -26,6 +26,19 @@ jobs: - name: Lint run: | make fmt + rs-fmt: + name: rust fmt + runs-on: ubuntu-latest + steps: + - name: Fetch the code + uses: actions/checkout@v4 + - name: Install nightly toolchain + run: | + rustup toolchain install nightly + rustup component add --toolchain nightly-x86_64-unknown-linux-gnu rustfmt + - name: fmt + run: | + make -C rs_controller check-fmt mypy: name: make mypy runs-on: ubuntu-latest diff --git a/rs_controller/Makefile b/rs_controller/Makefile index 1c8106a0d..d617254aa 100644 --- a/rs_controller/Makefile +++ b/rs_controller/Makefile @@ -56,3 +56,13 @@ build-wheel-local: dist-dirs sed -i.bak 's/^version = .*/version = "$(VERSION)"/' pyproject.toml && rm pyproject.toml.bak python -m build --wheel --outdir dist +check-nightly-toolchain: + @rustup toolchain list | grep -q "nightly" || (echo "Error: nightly toolchain not found. 
Run 'rustup toolchain install nightly'" && exit 1) + +.PHONY: fmt +fmt: check-nightly-toolchain + cargo +nightly fmt --all + +.PHONY: check-fmt +check-fmt: check-nightly-toolchain + cargo +nightly fmt --check diff --git a/rs_controller/rustfmt.toml b/rs_controller/rustfmt.toml new file mode 100644 index 000000000..93992e467 --- /dev/null +++ b/rs_controller/rustfmt.toml @@ -0,0 +1,6 @@ +edition = "2021" +indent_style = "Block" +reorder_imports = true +reorder_modules = true +imports_granularity = "Crate" +group_imports = "StdExternalCrate" diff --git a/rs_controller/src/action.rs b/rs_controller/src/action.rs index 4f207d4ee..abd7f102a 100644 --- a/rs_controller/src/action.rs +++ b/rs_controller/src/action.rs @@ -1,13 +1,14 @@ -use flyteidl2::google::protobuf::Timestamp; +use flyteidl2::{ + flyteidl::{ + common::{ActionIdentifier, ActionPhase, RunIdentifier}, + core::{ExecutionError, TypedInterface}, + task::{OutputReferences, TaskSpec, TraceSpec}, + workflow::{ActionUpdate, TraceAction}, + }, + google::protobuf::Timestamp, +}; use prost::Message; use pyo3::prelude::*; - -use flyteidl2::flyteidl::common::{ActionIdentifier, RunIdentifier}; -use flyteidl2::flyteidl::workflow::{ActionUpdate, TraceAction}; - -use flyteidl2::flyteidl::common::ActionPhase; -use flyteidl2::flyteidl::core::{ExecutionError, TypedInterface}; -use flyteidl2::flyteidl::task::{OutputReferences, TaskSpec, TraceSpec}; use tracing::debug; #[pyclass(eq, eq_int)] diff --git a/rs_controller/src/auth/client_credentials.rs b/rs_controller/src/auth/client_credentials.rs index c0795571e..3b0593166 100644 --- a/rs_controller/src/auth/client_credentials.rs +++ b/rs_controller/src/auth/client_credentials.rs @@ -1,12 +1,17 @@ -use std::sync::Arc; -use std::time::{Duration, SystemTime}; +use std::{ + sync::Arc, + time::{Duration, SystemTime}, +}; + use tokio::sync::RwLock; use tonic::transport::Channel; use tracing::{debug, info}; -use super::config::{AuthConfig, ClientConfigExt}; -use 
super::errors::TokenError; -use super::token_client::{self, GrantType, TokenResponse}; +use super::{ + config::{AuthConfig, ClientConfigExt}, + errors::TokenError, + token_client::{self, GrantType, TokenResponse}, +}; use crate::proto::{ AuthMetadataServiceClient, OAuth2MetadataRequest, OAuth2MetadataResponse, PublicClientAuthConfigRequest, PublicClientAuthConfigResponse, diff --git a/rs_controller/src/auth/config.rs b/rs_controller/src/auth/config.rs index 3c9046355..b27ae2a17 100644 --- a/rs_controller/src/auth/config.rs +++ b/rs_controller/src/auth/config.rs @@ -1,6 +1,7 @@ -use crate::auth::errors::AuthConfigError; use base64::{engine, Engine}; +use crate::auth::errors::AuthConfigError; + /// Configuration for authentication #[derive(Debug, Clone)] pub struct AuthConfig { diff --git a/rs_controller/src/auth/middleware.rs b/rs_controller/src/auth/middleware.rs index 8aca41bb3..32b3b3bc1 100644 --- a/rs_controller/src/auth/middleware.rs +++ b/rs_controller/src/auth/middleware.rs @@ -1,7 +1,9 @@ -use std::sync::Arc; -use std::task::{Context, Poll}; -use tonic::body::BoxBody; -use tonic::transport::Channel; +use std::{ + sync::Arc, + task::{Context, Poll}, +}; + +use tonic::{body::BoxBody, transport::Channel}; use tower::{Layer, Service, ServiceExt}; use tracing::{error, warn}; diff --git a/rs_controller/src/auth/token_client.rs b/rs_controller/src/auth/token_client.rs index 83477f3e7..0c9cfb32c 100644 --- a/rs_controller/src/auth/token_client.rs +++ b/rs_controller/src/auth/token_client.rs @@ -1,10 +1,12 @@ -use crate::auth::errors::TokenError; +use std::collections::HashMap; + use base64::{engine::general_purpose, Engine as _}; use reqwest; use serde::{Deserialize, Serialize}; -use std::collections::HashMap; use tracing::debug; +use crate::auth::errors::TokenError; + #[derive(Debug, Clone, Copy)] pub enum GrantType { ClientCredentials, diff --git a/rs_controller/src/bin/test_auth.rs b/rs_controller/src/bin/test_auth.rs index 183481e6c..552515a9a 100644 --- 
a/rs_controller/src/bin/test_auth.rs +++ b/rs_controller/src/bin/test_auth.rs @@ -1,3 +1,5 @@ +use std::{env, sync::Arc}; + /// Standalone authentication test binary /// /// Usage: @@ -6,8 +8,6 @@ /// FLYTE_CLIENT_SECRET=your_secret \ /// cargo run --bin test_auth use flyte_controller_base::auth::{AuthConfig, ClientCredentialsAuthenticator}; -use std::env; -use std::sync::Arc; use tonic::transport::Endpoint; use tracing_subscriber; diff --git a/rs_controller/src/bin/test_controller.rs b/rs_controller/src/bin/test_controller.rs index 6672ffa00..e459ff59e 100644 --- a/rs_controller/src/bin/test_controller.rs +++ b/rs_controller/src/bin/test_controller.rs @@ -1,10 +1,11 @@ +use std::env; + /// Usage: /// _UNION_EAGER_API_KEY=your_api_key cargo run --bin test_controller /// /// Or without auth: /// cargo run --bin test_controller -- http://localhost:8089 use flyte_controller_base::core::CoreBaseController; -use std::env; use tracing_subscriber; fn main() -> Result<(), Box> { diff --git a/rs_controller/src/bin/try_list_tasks.rs b/rs_controller/src/bin/try_list_tasks.rs index b9a5483db..aaf38a70a 100644 --- a/rs_controller/src/bin/try_list_tasks.rs +++ b/rs_controller/src/bin/try_list_tasks.rs @@ -3,16 +3,18 @@ /// Usage: /// _UNION_EAGER_API_KEY=your_api_key cargo run --bin try_list_tasks use std::sync::Arc; -use tower::ServiceBuilder; -use tracing::warn; - -use flyte_controller_base::auth::{AuthConfig, AuthLayer, ClientCredentialsAuthenticator}; -use flyte_controller_base::error::ControllerError; -use flyteidl2::flyteidl::common::{ListRequest, ProjectIdentifier}; -use flyteidl2::flyteidl::task::task_service_client::TaskServiceClient; -use flyteidl2::flyteidl::task::{list_tasks_request, ListTasksRequest}; +use flyte_controller_base::{ + auth::{AuthConfig, AuthLayer, ClientCredentialsAuthenticator}, + error::ControllerError, +}; +use flyteidl2::flyteidl::{ + common::{ListRequest, ProjectIdentifier}, + task::{list_tasks_request, task_service_client::TaskServiceClient, 
ListTasksRequest}, +}; use tonic::Code; +use tower::ServiceBuilder; +use tracing::warn; #[tokio::main] async fn main() -> Result<(), Box> { diff --git a/rs_controller/src/bin/try_watch.rs b/rs_controller/src/bin/try_watch.rs index 9d6a282ce..7fed9ae88 100644 --- a/rs_controller/src/bin/try_watch.rs +++ b/rs_controller/src/bin/try_watch.rs @@ -4,18 +4,19 @@ /// _UNION_EAGER_API_KEY=your_api_key cargo run --bin try_watch use std::sync::Arc; use std::time::Duration; + +use flyte_controller_base::{ + auth::{AuthConfig, AuthLayer, ClientCredentialsAuthenticator}, + error::ControllerError, +}; +use flyteidl2::flyteidl::{ + common::{ActionIdentifier, RunIdentifier}, + workflow::{state_service_client::StateServiceClient, watch_request::Filter, WatchRequest}, +}; use tokio::time::sleep; use tower::ServiceBuilder; use tracing::{error, info, warn}; -use flyte_controller_base::auth::{AuthConfig, AuthLayer, ClientCredentialsAuthenticator}; -use flyte_controller_base::error::ControllerError; - -use flyteidl2::flyteidl::common::{ActionIdentifier, RunIdentifier}; -use flyteidl2::flyteidl::workflow::state_service_client::StateServiceClient; -use flyteidl2::flyteidl::workflow::watch_request::Filter; -use flyteidl2::flyteidl::workflow::WatchRequest; - #[tokio::main] async fn main() -> Result<(), Box> { tracing_subscriber::fmt() diff --git a/rs_controller/src/core.rs b/rs_controller/src/core.rs index 31b534aa3..5bc22305c 100644 --- a/rs_controller/src/core.rs +++ b/rs_controller/src/core.rs @@ -1,32 +1,39 @@ //! Core controller implementation - Pure Rust, no PyO3 dependencies //! 
This module can be used by both Python bindings and standalone Rust binaries -use std::sync::Arc; -use std::time::Duration; - +use std::{sync::Arc, time::Duration}; + +use flyteidl2::{ + flyteidl::{ + common::{ActionIdentifier, RunIdentifier}, + task::TaskIdentifier, + workflow::{ + enqueue_action_request, queue_service_client::QueueServiceClient, + state_service_client::StateServiceClient, EnqueueActionRequest, EnqueueActionResponse, + TaskAction, WatchRequest, WatchResponse, + }, + }, + google, +}; +use google::protobuf::StringValue; use pyo3_async_runtimes::tokio::get_runtime; -use tokio::sync::mpsc; -use tokio::sync::oneshot; -use tokio::time::sleep; -use tonic::transport::{Certificate, ClientTlsConfig, Endpoint}; -use tonic::Status; +use tokio::{ + sync::{mpsc, oneshot}, + time::sleep, +}; +use tonic::{ + transport::{Certificate, ClientTlsConfig, Endpoint}, + Status, +}; use tower::ServiceBuilder; use tracing::{debug, error, info, warn}; -use crate::action::Action; -use crate::auth::{AuthConfig, AuthLayer, ClientCredentialsAuthenticator}; -use crate::error::{ControllerError, InformerError}; -use crate::informer::{Informer, InformerCache}; -use flyteidl2::flyteidl::common::{ActionIdentifier, RunIdentifier}; -use flyteidl2::flyteidl::task::TaskIdentifier; -use flyteidl2::flyteidl::workflow::enqueue_action_request; -use flyteidl2::flyteidl::workflow::queue_service_client::QueueServiceClient; -use flyteidl2::flyteidl::workflow::state_service_client::StateServiceClient; -use flyteidl2::flyteidl::workflow::{ - EnqueueActionRequest, EnqueueActionResponse, TaskAction, WatchRequest, WatchResponse, +use crate::{ + action::Action, + auth::{AuthConfig, AuthLayer, ClientCredentialsAuthenticator}, + error::{ControllerError, InformerError}, + informer::{Informer, InformerCache}, }; -use flyteidl2::google; -use google::protobuf::StringValue; // Fetches Amazon root CA certificate from Amazon Trust Services pub async fn fetch_amazon_root_ca() -> Result { diff --git 
a/rs_controller/src/informer.rs b/rs_controller/src/informer.rs index f7f915ad9..d6d075c61 100644 --- a/rs_controller/src/informer.rs +++ b/rs_controller/src/informer.rs @@ -1,8 +1,30 @@ -use crate::action::Action; -use crate::core::StateClient; -use crate::error::{ControllerError, InformerError}; -use tokio::time; +use std::{ + collections::HashMap, + sync::{ + atomic::{AtomicBool, Ordering}, + Arc, + }, + time::Duration, +}; + +use flyteidl2::flyteidl::{ + common::{ActionIdentifier, RunIdentifier}, + workflow::{watch_request, watch_response::Message, WatchRequest, WatchResponse}, +}; +use tokio::{ + select, + sync::{mpsc, oneshot, Notify, RwLock}, + time, +}; use tokio_util::sync::CancellationToken; +use tracing::{debug, error, info, warn}; + +use crate::{ + action::Action, + core::StateClient, + error::{ControllerError, InformerError}, +}; + /// Determine if an InformerError is retryable fn is_retryable_error(err: &InformerError) -> bool { match err { @@ -18,21 +40,6 @@ fn is_retryable_error(err: &InformerError) -> bool { } } -use flyteidl2::flyteidl::common::ActionIdentifier; -use flyteidl2::flyteidl::common::RunIdentifier; -use flyteidl2::flyteidl::workflow::{ - watch_request, watch_response::Message, WatchRequest, WatchResponse, -}; - -use std::collections::HashMap; -use std::sync::atomic::{AtomicBool, Ordering}; -use std::sync::Arc; -use std::time::Duration; -use tokio::select; -use tokio::sync::RwLock; -use tokio::sync::{mpsc, oneshot, Notify}; -use tracing::{debug, error, info, warn}; - #[derive(Clone, Debug)] pub struct Informer { client: StateClient, diff --git a/rs_controller/src/lib.rs b/rs_controller/src/lib.rs index bf69c7590..1aaf0cda1 100644 --- a/rs_controller/src/lib.rs +++ b/rs_controller/src/lib.rs @@ -11,18 +11,18 @@ pub mod proto; // Python bindings - thin wrappers around core types use std::sync::Arc; -use pyo3::exceptions; -use pyo3::prelude::*; -use pyo3::types::PyAny; +use flyteidl2::flyteidl::common::{ActionIdentifier, RunIdentifier}; 
+use prost::Message; +use pyo3::{exceptions, prelude::*, types::PyAny}; use pyo3_async_runtimes::tokio::future_into_py; use tracing::{error, info, warn}; use tracing_subscriber::FmtSubscriber; -use crate::action::{Action, ActionType}; -use crate::core::CoreBaseController; -use crate::error::ControllerError; -use flyteidl2::flyteidl::common::{ActionIdentifier, RunIdentifier}; -use prost::Message; +use crate::{ + action::{Action, ActionType}, + core::CoreBaseController, + error::ControllerError, +}; // Python error conversions impl From for PyErr { diff --git a/rs_controller/src/proto/mod.rs b/rs_controller/src/proto/mod.rs index 6419d5808..e40b8bb28 100644 --- a/rs_controller/src/proto/mod.rs +++ b/rs_controller/src/proto/mod.rs @@ -5,8 +5,7 @@ pub mod service; // Re-export the auth-related types and services for convenience -pub use service::auth_metadata_service_client::AuthMetadataServiceClient; -pub use service::OAuth2MetadataRequest; -pub use service::OAuth2MetadataResponse; -pub use service::PublicClientAuthConfigRequest; -pub use service::PublicClientAuthConfigResponse; +pub use service::{ + auth_metadata_service_client::AuthMetadataServiceClient, OAuth2MetadataRequest, + OAuth2MetadataResponse, PublicClientAuthConfigRequest, PublicClientAuthConfigResponse, +}; From e76ac3440add10c3d93acd26af5501c8e1ec6aa9 Mon Sep 17 00:00:00 2001 From: Sergey Vilgelm <523825+SVilgelm@users.noreply.github.com> Date: Fri, 26 Dec 2025 20:37:53 -0800 Subject: [PATCH 17/22] feat[rs_controller]: add Rust lint targets and GitHub Action job (#480) Add Makefile targets for linting and autofix: - lint: run cargo clippy - lint-fix: run cargo clippy --fix Update CI workflow to run the Rust lint job: - add rs-lint job that checks out code, ensures toolchain, and runs make -C rs_controller lint Fix all clippy warnings. 
Signed-off-by: Sergey Vilgelm --- .github/workflows/lint.yml | 12 ++++++++++++ rs_controller/Makefile | 8 ++++++++ rs_controller/src/action.rs | 2 +- rs_controller/src/bin/test_auth.rs | 1 - rs_controller/src/bin/test_controller.rs | 3 +-- rs_controller/src/bin/try_list_tasks.rs | 2 +- rs_controller/src/core.rs | 16 +++++++--------- rs_controller/src/error.rs | 2 +- rs_controller/src/informer.rs | 6 +++++- 9 files changed, 36 insertions(+), 16 deletions(-) diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml index 859866a12..3ec25179e 100644 --- a/.github/workflows/lint.yml +++ b/.github/workflows/lint.yml @@ -39,6 +39,18 @@ jobs: - name: fmt run: | make -C rs_controller check-fmt + rs-lint: + name: rust lint + runs-on: ubuntu-latest + steps: + - name: Fetch the code + uses: actions/checkout@v4 + - name: Install toolchain + run: | + rustup toolchain install + - name: lint + run: | + make -C rs_controller lint mypy: name: make mypy runs-on: ubuntu-latest diff --git a/rs_controller/Makefile b/rs_controller/Makefile index d617254aa..6937d85a4 100644 --- a/rs_controller/Makefile +++ b/rs_controller/Makefile @@ -66,3 +66,11 @@ fmt: check-nightly-toolchain .PHONY: check-fmt check-fmt: check-nightly-toolchain cargo +nightly fmt --check + +.PHONY: lint +lint: + cargo clippy --all-targets -- -D warnings + +.PHONY: lint-fix +lint-fix: + cargo clippy --all-targets --fix diff --git a/rs_controller/src/action.rs b/rs_controller/src/action.rs index abd7f102a..c5060601f 100644 --- a/rs_controller/src/action.rs +++ b/rs_controller/src/action.rs @@ -199,7 +199,7 @@ impl Action { friendly_name: task_spec .task_template .as_ref() - .and_then(|tt| tt.id.as_ref().and_then(|id| Some(id.name.clone()))), + .and_then(|tt| tt.id.as_ref().map(|id| id.name.clone())), group: group_data, task: Some(task_spec), inputs_uri: Some(inputs_uri), diff --git a/rs_controller/src/bin/test_auth.rs b/rs_controller/src/bin/test_auth.rs index 552515a9a..37a213fe1 100644 --- 
a/rs_controller/src/bin/test_auth.rs +++ b/rs_controller/src/bin/test_auth.rs @@ -9,7 +9,6 @@ use std::{env, sync::Arc}; /// cargo run --bin test_auth use flyte_controller_base::auth::{AuthConfig, ClientCredentialsAuthenticator}; use tonic::transport::Endpoint; -use tracing_subscriber; #[tokio::main] async fn main() -> Result<(), Box> { diff --git a/rs_controller/src/bin/test_controller.rs b/rs_controller/src/bin/test_controller.rs index e459ff59e..9b4d25da9 100644 --- a/rs_controller/src/bin/test_controller.rs +++ b/rs_controller/src/bin/test_controller.rs @@ -6,7 +6,6 @@ use std::env; /// Or without auth: /// cargo run --bin test_controller -- http://localhost:8089 use flyte_controller_base::core::CoreBaseController; -use tracing_subscriber; fn main() -> Result<(), Box> { tracing_subscriber::fmt() @@ -16,7 +15,7 @@ fn main() -> Result<(), Box> { println!("=== Flyte Core Controller Test ===\n"); // Try to create a controller - let controller = if let Ok(api_key) = env::var("_UNION_EAGER_API_KEY") { + let _controller = if let Ok(api_key) = env::var("_UNION_EAGER_API_KEY") { println!("Using auth from _UNION_EAGER_API_KEY"); // Set the env var back since CoreBaseController::new_with_auth reads it env::set_var("_UNION_EAGER_API_KEY", api_key); diff --git a/rs_controller/src/bin/try_list_tasks.rs b/rs_controller/src/bin/try_list_tasks.rs index aaf38a70a..977cea97b 100644 --- a/rs_controller/src/bin/try_list_tasks.rs +++ b/rs_controller/src/bin/try_list_tasks.rs @@ -75,7 +75,7 @@ async fn main() -> Result<(), Box> { } Err(status) => { eprintln!("Error calling gRPC: {}", status); - break Err(format!("gRPC error: {}", status).into()); + break Err(format!("gRPC error: {}", status)); } } }; diff --git a/rs_controller/src/core.rs b/rs_controller/src/core.rs index 5bc22305c..92ad2322d 100644 --- a/rs_controller/src/core.rs +++ b/rs_controller/src/core.rs @@ -446,7 +446,7 @@ impl CoreBaseController { .get(&action.get_run_identifier(), &action.parent_action_name) .await { - let 
_ = informer + informer .fire_completion_event(&action.action_id.name) .await?; } else { @@ -496,14 +496,12 @@ impl CoreBaseController { .as_ref() .and_then(|task| task.task_template.as_ref()) .and_then(|task_template| task_template.id.as_ref()) - .and_then(|core_task_id| { - Some(TaskIdentifier { - version: core_task_id.version.clone(), - org: core_task_id.org.clone(), - project: core_task_id.project.clone(), - domain: core_task_id.domain.clone(), - name: core_task_id.name.clone(), - }) + .map(|core_task_id| TaskIdentifier { + version: core_task_id.version.clone(), + org: core_task_id.org.clone(), + project: core_task_id.project.clone(), + domain: core_task_id.domain.clone(), + name: core_task_id.name.clone(), }) .ok_or(ControllerError::RuntimeError(format!( "TaskIdentifier missing from Action {:?}", diff --git a/rs_controller/src/error.rs b/rs_controller/src/error.rs index 90ca3b418..beccfddba 100644 --- a/rs_controller/src/error.rs +++ b/rs_controller/src/error.rs @@ -11,7 +11,7 @@ pub enum ControllerError { #[error("System error: {0}")] SystemError(String), #[error("gRPC error: {0}")] - GrpcError(#[from] tonic::Status), + GrpcError(#[from] Box), #[error("Task error: {0}")] TaskError(String), #[error("Informer error: {0}")] diff --git a/rs_controller/src/informer.rs b/rs_controller/src/informer.rs index d6d075c61..c1c429f1d 100644 --- a/rs_controller/src/informer.rs +++ b/rs_controller/src/informer.rs @@ -487,7 +487,7 @@ impl InformerCache { "Acquiring write lock to save watch handle for: {}", informer_name ); - let _ = informer.watch_handle.write().await.insert(_watch_handle); + *informer.watch_handle.write().await = Some(_watch_handle); info!("Watch handle saved for: {}", informer_name); // Optimistically wait for ready (sentinel) with timeout @@ -577,6 +577,10 @@ impl InformerCache { #[cfg(test)] mod tests { + use flyteidl2::flyteidl::workflow::state_service_client::StateServiceClient; + use tonic::transport::Endpoint; + use tracing_subscriber::fmt; + use 
super::*; async fn informer_main() { From 05f5ff4a9d6e6526ecaaaf92fdd8cbf09ff7fdee Mon Sep 17 00:00:00 2001 From: Nary Yeh <60069744+machichima@users.noreply.github.com> Date: Wed, 31 Dec 2025 13:17:36 -0800 Subject: [PATCH 18/22] [Rust Controller] Add worker pool (#477) Add worker pool for rust controller Signed-off-by: machichima Signed-off-by: Yee Hing Tong Co-authored-by: Yee Hing Tong --- rs_controller/src/bin/test_controller.rs | 9 ++-- rs_controller/src/core.rs | 67 +++++++++++++++++------- rs_controller/src/lib.rs | 13 ++--- 3 files changed, 60 insertions(+), 29 deletions(-) diff --git a/rs_controller/src/bin/test_controller.rs b/rs_controller/src/bin/test_controller.rs index 9b4d25da9..048b6d830 100644 --- a/rs_controller/src/bin/test_controller.rs +++ b/rs_controller/src/bin/test_controller.rs @@ -15,17 +15,18 @@ fn main() -> Result<(), Box> { println!("=== Flyte Core Controller Test ===\n"); // Try to create a controller + let workers = 20; // Default number of workers let _controller = if let Ok(api_key) = env::var("_UNION_EAGER_API_KEY") { - println!("Using auth from _UNION_EAGER_API_KEY"); + println!("Using auth from _UNION_EAGER_API_KEY with {} workers", workers); // Set the env var back since CoreBaseController::new_with_auth reads it env::set_var("_UNION_EAGER_API_KEY", api_key); - CoreBaseController::new_with_auth()? + CoreBaseController::new_with_auth(workers)? } else { let endpoint = env::args() .nth(1) .unwrap_or_else(|| "http://localhost:8090".to_string()); - println!("Using endpoint: {}", endpoint); - CoreBaseController::new_without_auth(endpoint)? + println!("Using endpoint: {} with {} workers", endpoint, workers); + CoreBaseController::new_without_auth(endpoint, workers)? 
}; println!("✓ Successfully created CoreBaseController!"); diff --git a/rs_controller/src/core.rs b/rs_controller/src/core.rs index 92ad2322d..1691a6723 100644 --- a/rs_controller/src/core.rs +++ b/rs_controller/src/core.rs @@ -121,11 +121,12 @@ pub struct CoreBaseController { shared_queue_rx: Arc>>, failure_rx: Arc>>>, bg_worker_handle: Arc>>>, + workers: usize, } impl CoreBaseController { - pub fn new_with_auth() -> Result, ControllerError> { - info!("Creating CoreBaseController from _UNION_EAGER_API_KEY env var (with auth)"); + pub fn new_with_auth(workers: usize) -> Result, ControllerError> { + info!("Creating CoreBaseController from _UNION_EAGER_API_KEY env var (with auth) with {} workers", workers); // Read from env var and use auth let api_key = std::env::var("_UNION_EAGER_API_KEY").map_err(|_| { ControllerError::SystemError( @@ -188,13 +189,14 @@ impl CoreBaseController { shared_queue_rx: Arc::new(tokio::sync::Mutex::new(shared_queue_rx)), failure_rx: Arc::new(std::sync::Mutex::new(Some(failure_rx))), bg_worker_handle: Arc::new(std::sync::Mutex::new(None)), + workers, }; let real_base_controller = Arc::new(real_base_controller); - // Start the background worker + // Start the background worker pool let controller_clone = real_base_controller.clone(); let handle = rt.spawn(async move { - controller_clone.bg_worker().await; + controller_clone.bg_worker_pool().await; }); // Store the handle @@ -203,7 +205,7 @@ impl CoreBaseController { Ok(real_base_controller) } - pub fn new_without_auth(endpoint: String) -> Result, ControllerError> { + pub fn new_without_auth(endpoint: String, workers: usize) -> Result, ControllerError> { let endpoint_static: &'static str = Box::leak(Box::new(endpoint.clone().into_boxed_str())); // shared queue let (shared_tx, shared_queue_rx) = mpsc::channel::(64); @@ -257,13 +259,14 @@ impl CoreBaseController { shared_queue_rx: Arc::new(tokio::sync::Mutex::new(shared_queue_rx)), failure_rx: Arc::new(std::sync::Mutex::new(Some(failure_rx))), 
bg_worker_handle: Arc::new(std::sync::Mutex::new(None)), + workers, }; let real_base_controller = Arc::new(real_base_controller); - // Start the background worker + // Start the background worker pool let controller_clone = real_base_controller.clone(); let handle = rt.spawn(async move { - controller_clone.bg_worker().await; + controller_clone.bg_worker_pool().await; }); // Store the handle @@ -272,12 +275,38 @@ impl CoreBaseController { Ok(real_base_controller) } - async fn bg_worker(&self) { + async fn bg_worker_pool(self: Arc) { + debug!( + "Starting controller worker pool with {} workers on thread {:?}", + self.workers, + std::thread::current().name() + ); + + let mut handles = Vec::new(); + for i in 0..self.workers { + let controller = Arc::clone(&self); + let worker_id = format!("worker-{}", i); + let handle = tokio::spawn(async move { + controller.bg_worker(worker_id).await; + }); + handles.push(handle); + } + + // Wait for all workers to complete + for handle in handles { + if let Err(e) = handle.await { + error!("Worker task failed: {:?}", e); + } + } + } + + async fn bg_worker(&self, worker_id: String) { const MIN_BACKOFF_ON_ERR: Duration = Duration::from_millis(100); const MAX_RETRIES: u32 = 5; - debug!( - "Launching core controller background task on thread {:?}", + info!( + "Worker {} started on thread {:?}", + worker_id, std::thread::current().name() ); loop { @@ -291,8 +320,8 @@ impl CoreBaseController { .as_ref() .map_or(String::from(""), |i| i.name.clone()); debug!( - "Controller worker processing action {}::{}", - run_name, action.action_id.name + "[{}] Controller worker processing action {}::{}", + worker_id, run_name, action.action_id.name ); // Drop the mutex guard before processing @@ -309,12 +338,12 @@ impl CoreBaseController { if action.retries > MAX_RETRIES { error!( - "Controller failed processing {}::{}, system retries {} crossed threshold {}", - run_name, action.action_id.name, action.retries, MAX_RETRIES + "[{}] Controller failed 
processing {}::{}, system retries {} crossed threshold {}", + worker_id, run_name, action.action_id.name, action.retries, MAX_RETRIES ); action.client_err = Some(format!( - "Controller failed {}::{}, system retries {} crossed threshold {}", - run_name, action.action_id.name, action.retries, MAX_RETRIES + "[{}] Controller failed {}::{}, system retries {} crossed threshold {}", + worker_id, run_name, action.action_id.name, action.retries, MAX_RETRIES )); // Fire completion event for failed action @@ -355,11 +384,11 @@ impl CoreBaseController { } else { // Re-queue the action for retry info!( - "Re-queuing action {}::{} for retry, attempt {}/{}", - run_name, action.action_id.name, action.retries, MAX_RETRIES + "[{}] Re-queuing action {}::{} for retry, attempt {}/{}", + worker_id, run_name, action.action_id.name, action.retries, MAX_RETRIES ); if let Err(send_err) = self.shared_queue.send(action).await { - error!("Failed to re-queue action for retry: {}", send_err); + error!("[{}] Failed to re-queue action for retry: {}", worker_id, send_err); } } } diff --git a/rs_controller/src/lib.rs b/rs_controller/src/lib.rs index 1aaf0cda1..8ea2c7edc 100644 --- a/rs_controller/src/lib.rs +++ b/rs_controller/src/lib.rs @@ -44,14 +44,15 @@ struct BaseController(Arc); #[pymethods] impl BaseController { #[new] - #[pyo3(signature = (*, endpoint=None))] - fn new(endpoint: Option) -> PyResult { + #[pyo3(signature = (*, endpoint=None, workers=None))] + fn new(endpoint: Option, workers: Option) -> PyResult { + let workers = workers.unwrap_or(20); let core_base = if let Some(ep) = endpoint { - info!("Creating controller wrapper with endpoint {:?}", ep); - CoreBaseController::new_without_auth(ep)? + info!("Creating controller wrapper with endpoint {:?} and {} workers", ep, workers); + CoreBaseController::new_without_auth(ep, workers)? } else { - info!("Creating controller wrapper from _UNION_EAGER_API_KEY env var"); - CoreBaseController::new_with_auth()? 
+ info!("Creating controller wrapper from _UNION_EAGER_API_KEY env var with {} workers", workers); + CoreBaseController::new_with_auth(workers)? }; Ok(BaseController(core_base)) } From 475ae6e9b8b2b27c00a07cd2098bfa8a7ef552b3 Mon Sep 17 00:00:00 2001 From: Yee Hing Tong Date: Thu, 1 Jan 2026 07:14:00 +0800 Subject: [PATCH 19/22] ctrl-rs select with env var (#429) * Add an env var `_F_USE_RUST_CONTROLLER` that selects the rust based controller library. This will get recursively added also. Will not expose through a CLI switch yet... too experimental. Signed-off-by: Yee Hing Tong --- examples/advanced/cancel_tasks.py | 6 ++---- examples/basics/devbox_one.py | 2 +- examples/basics/hello.py | 6 +++--- examples/stress/crash_recovery_trace.py | 11 +++-------- examples/stress/large_dir_io.py | 7 ++----- examples/stress/large_file_io.py | 7 ++----- examples/stress/long_recovery.py | 7 ++----- rs_controller/Makefile | 2 +- rs_controller/pyproject.toml | 2 +- rs_controller/src/bin/test_controller.rs | 5 ++++- rs_controller/src/core.rs | 16 +++++++++++++--- rs_controller/src/lib.rs | 10 ++++++++-- src/flyte/_bin/runtime.py | 7 ++++++- src/flyte/_internal/controllers/__init__.py | 12 ++++++++---- .../controllers/remote/_r_controller.py | 4 +--- src/flyte/_run.py | 5 +++++ 16 files changed, 62 insertions(+), 47 deletions(-) diff --git a/examples/advanced/cancel_tasks.py b/examples/advanced/cancel_tasks.py index 64ca870f3..14a7b7eb6 100644 --- a/examples/advanced/cancel_tasks.py +++ b/examples/advanced/cancel_tasks.py @@ -1,13 +1,11 @@ import asyncio - -import flyte.errors - from pathlib import Path import flyte +import flyte.errors from flyte._image import PythonWheels -controller_dist_folder = Path("/Users/ytong/go/src/github.com/flyteorg/flyte-sdk/rs_controller/dist") +controller_dist_folder = Path("/Users/ytong/go/src/github.com/flyteorg/sdk-rust/rs_controller/dist") wheel_layer = PythonWheels(wheel_dir=controller_dist_folder, package_name="flyte_controller_base") base = 
flyte.Image.from_debian_base() rs_controller_image = base.clone(addl_layer=wheel_layer) diff --git a/examples/basics/devbox_one.py b/examples/basics/devbox_one.py index 22cd573a4..27cefa552 100644 --- a/examples/basics/devbox_one.py +++ b/examples/basics/devbox_one.py @@ -6,7 +6,7 @@ import flyte from flyte._image import PythonWheels -controller_dist_folder = Path("/Users/ytong/go/src/github.com/flyteorg/flyte-sdk/rs_controller/dist") +controller_dist_folder = Path("/Users/ytong/go/src/github.com/flyteorg/sdk-rust/rs_controller/dist") wheel_layer = PythonWheels(wheel_dir=controller_dist_folder, package_name="flyte_controller_base") base = flyte.Image.from_debian_base() rs_controller_image = base.clone(addl_layer=wheel_layer) diff --git a/examples/basics/hello.py b/examples/basics/hello.py index 226590d45..1441963a1 100644 --- a/examples/basics/hello.py +++ b/examples/basics/hello.py @@ -1,9 +1,9 @@ -import flyte +from pathlib import Path +import flyte from flyte._image import PythonWheels -from pathlib import Path -controller_dist_folder = Path("/Users/ytong/go/src/github.com/flyteorg/flyte-sdk/rs_controller/dist") +controller_dist_folder = Path("/Users/ytong/go/src/github.com/flyteorg/sdk-rust/rs_controller/dist") wheel_layer = PythonWheels(wheel_dir=controller_dist_folder, package_name="flyte_controller_base") base = flyte.Image.from_debian_base() rs_controller_image = base.clone(addl_layer=wheel_layer) diff --git a/examples/stress/crash_recovery_trace.py b/examples/stress/crash_recovery_trace.py index 6ad140177..7c9006aaa 100644 --- a/examples/stress/crash_recovery_trace.py +++ b/examples/stress/crash_recovery_trace.py @@ -1,23 +1,18 @@ import os - -import flyte -import flyte.errors - from pathlib import Path import flyte +import flyte.errors from flyte._image import PythonWheels -controller_dist_folder = Path("/Users/ytong/go/src/github.com/flyteorg/flyte-sdk/rs_controller/dist") +controller_dist_folder = 
Path("/Users/ytong/go/src/github.com/flyteorg/sdk-rust/rs_controller/dist") wheel_layer = PythonWheels(wheel_dir=controller_dist_folder, package_name="flyte_controller_base") base = flyte.Image.from_debian_base() rs_controller_image = base.clone(addl_layer=wheel_layer) env = flyte.TaskEnvironment( - name="crash_recovery_trace", - resources=flyte.Resources(memory="250Mi", cpu=1), - image=rs_controller_image + name="crash_recovery_trace", resources=flyte.Resources(memory="250Mi", cpu=1), image=rs_controller_image ) diff --git a/examples/stress/large_dir_io.py b/examples/stress/large_dir_io.py index c692f7e05..67443e004 100644 --- a/examples/stress/large_dir_io.py +++ b/examples/stress/large_dir_io.py @@ -4,18 +4,15 @@ import signal import tempfile import time +from pathlib import Path from typing import Tuple import flyte import flyte.io import flyte.storage - -from pathlib import Path - -import flyte from flyte._image import PythonWheels -controller_dist_folder = Path("/Users/ytong/go/src/github.com/flyteorg/flyte-sdk/rs_controller/dist") +controller_dist_folder = Path("/Users/ytong/go/src/github.com/flyteorg/sdk-rust/rs_controller/dist") wheel_layer = PythonWheels(wheel_dir=controller_dist_folder, package_name="flyte_controller_base") base = flyte.Image.from_debian_base() rs_controller_image = base.clone(addl_layer=wheel_layer) diff --git a/examples/stress/large_file_io.py b/examples/stress/large_file_io.py index 2a6f15537..428873eda 100644 --- a/examples/stress/large_file_io.py +++ b/examples/stress/large_file_io.py @@ -3,18 +3,15 @@ import signal import tempfile import time +from pathlib import Path from typing import Tuple import flyte import flyte.io import flyte.storage - -from pathlib import Path - -import flyte from flyte._image import PythonWheels -controller_dist_folder = Path("/Users/ytong/go/src/github.com/flyteorg/flyte-sdk/rs_controller/dist") +controller_dist_folder = Path("/Users/ytong/go/src/github.com/flyteorg/sdk-rust/rs_controller/dist") 
wheel_layer = PythonWheels(wheel_dir=controller_dist_folder, package_name="flyte_controller_base") base = flyte.Image.from_debian_base() rs_controller_image = base.clone(addl_layer=wheel_layer) diff --git a/examples/stress/long_recovery.py b/examples/stress/long_recovery.py index ae9f731b7..cc93feb6d 100644 --- a/examples/stress/long_recovery.py +++ b/examples/stress/long_recovery.py @@ -1,16 +1,13 @@ import asyncio import os import typing - -import flyte -import flyte.errors - from pathlib import Path import flyte +import flyte.errors from flyte._image import PythonWheels -controller_dist_folder = Path("/Users/ytong/go/src/github.com/flyteorg/flyte-sdk/rs_controller/dist") +controller_dist_folder = Path("/Users/ytong/go/src/github.com/flyteorg/sdk-rust/rs_controller/dist") wheel_layer = PythonWheels(wheel_dir=controller_dist_folder, package_name="flyte_controller_base") base = flyte.Image.from_debian_base() rs_controller_image = base.clone(addl_layer=wheel_layer) diff --git a/rs_controller/Makefile b/rs_controller/Makefile index 6937d85a4..2ff7137d8 100644 --- a/rs_controller/Makefile +++ b/rs_controller/Makefile @@ -48,7 +48,7 @@ build-wheels-amd64: dist-dirs clean_dist: rm -rf $(DIST_DIRS)/*whl -build-wheels: build-wheels-arm64 build-wheels-amd64 build-wheel-local +build-wheels: clean_dist build-wheels-arm64 build-wheels-amd64 build-wheel-local # This is for Mac users, since the other targets won't build macos wheels (only local arch so probably arm64) build-wheel-local: dist-dirs diff --git a/rs_controller/pyproject.toml b/rs_controller/pyproject.toml index 73f549785..3df91c181 100644 --- a/rs_controller/pyproject.toml +++ b/rs_controller/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "maturin" [project] name = "flyte_controller_base" -version = "2.0.0b36.dev23+gdb40f621.dirty" +version = "2.0.0b36.dev25+gcedbfba0.dirty" description = "Rust controller for Union" requires-python = ">=3.10" classifiers = ["Programming Language :: Python", "Programming Language :: 
Rust"] diff --git a/rs_controller/src/bin/test_controller.rs b/rs_controller/src/bin/test_controller.rs index 048b6d830..65b63e459 100644 --- a/rs_controller/src/bin/test_controller.rs +++ b/rs_controller/src/bin/test_controller.rs @@ -17,7 +17,10 @@ fn main() -> Result<(), Box> { // Try to create a controller let workers = 20; // Default number of workers let _controller = if let Ok(api_key) = env::var("_UNION_EAGER_API_KEY") { - println!("Using auth from _UNION_EAGER_API_KEY with {} workers", workers); + println!( + "Using auth from _UNION_EAGER_API_KEY with {} workers", + workers + ); // Set the env var back since CoreBaseController::new_with_auth reads it env::set_var("_UNION_EAGER_API_KEY", api_key); CoreBaseController::new_with_auth(workers)? diff --git a/rs_controller/src/core.rs b/rs_controller/src/core.rs index 1691a6723..f20739dff 100644 --- a/rs_controller/src/core.rs +++ b/rs_controller/src/core.rs @@ -205,7 +205,10 @@ impl CoreBaseController { Ok(real_base_controller) } - pub fn new_without_auth(endpoint: String, workers: usize) -> Result, ControllerError> { + pub fn new_without_auth( + endpoint: String, + workers: usize, + ) -> Result, ControllerError> { let endpoint_static: &'static str = Box::leak(Box::new(endpoint.clone().into_boxed_str())); // shared queue let (shared_tx, shared_queue_rx) = mpsc::channel::(64); @@ -385,10 +388,17 @@ impl CoreBaseController { // Re-queue the action for retry info!( "[{}] Re-queuing action {}::{} for retry, attempt {}/{}", - worker_id, run_name, action.action_id.name, action.retries, MAX_RETRIES + worker_id, + run_name, + action.action_id.name, + action.retries, + MAX_RETRIES ); if let Err(send_err) = self.shared_queue.send(action).await { - error!("[{}] Failed to re-queue action for retry: {}", worker_id, send_err); + error!( + "[{}] Failed to re-queue action for retry: {}", + worker_id, send_err + ); } } } diff --git a/rs_controller/src/lib.rs b/rs_controller/src/lib.rs index 8ea2c7edc..e700bb14e 100644 --- 
a/rs_controller/src/lib.rs +++ b/rs_controller/src/lib.rs @@ -48,10 +48,16 @@ impl BaseController { fn new(endpoint: Option, workers: Option) -> PyResult { let workers = workers.unwrap_or(20); let core_base = if let Some(ep) = endpoint { - info!("Creating controller wrapper with endpoint {:?} and {} workers", ep, workers); + info!( + "Creating controller wrapper with endpoint {:?} and {} workers", + ep, workers + ); CoreBaseController::new_without_auth(ep, workers)? } else { - info!("Creating controller wrapper from _UNION_EAGER_API_KEY env var with {} workers", workers); + info!( + "Creating controller wrapper from _UNION_EAGER_API_KEY env var with {} workers", + workers + ); CoreBaseController::new_with_auth(workers)? }; Ok(BaseController(core_base)) diff --git a/src/flyte/_bin/runtime.py b/src/flyte/_bin/runtime.py index 68372ff0e..416e3458f 100644 --- a/src/flyte/_bin/runtime.py +++ b/src/flyte/_bin/runtime.py @@ -33,6 +33,7 @@ _UNION_EAGER_API_KEY_ENV_VAR = "_UNION_EAGER_API_KEY" _F_PATH_REWRITE = "_F_PATH_REWRITE" +_F_USE_RUST_CONTROLLER = "_F_USE_RUST_CONTROLLER" @click.group() @@ -132,7 +133,11 @@ def main( if tgz or pkl: bundle = CodeBundle(tgz=tgz, pkl=pkl, destination=dest, computed_version=version) # Controller is created with the same kwargs as init, so that it can be used to run tasks - controller = create_controller(ct="rust", **controller_kwargs) + # Use Rust controller if env var is set, otherwise default to Python controller + use_rust = os.getenv(_F_USE_RUST_CONTROLLER, "").lower() in ("1", "true", "yes") + controller_type = "rust" if use_rust else "remote" + print(f"In runtime: controller kwargs are: {controller_kwargs}") + controller = create_controller(ct=controller_type, **controller_kwargs) # type: ignore[arg-type] ic = ImageCache.from_transport(image_cache) if image_cache else None diff --git a/src/flyte/_internal/controllers/__init__.py b/src/flyte/_internal/controllers/__init__.py index 52e9821d1..023af21a6 100644 --- 
a/src/flyte/_internal/controllers/__init__.py +++ b/src/flyte/_internal/controllers/__init__.py @@ -132,12 +132,16 @@ def create_controller( controller = RemoteController(endpoint="http://host.docker.internal:8090", workers=10, max_system_retries=5) # controller = RemoteController(workers=10, max_system_retries=5) case "rust": - # hybrid case, despite the case statement above, meant for local runs not inside docker + # Rust controller - works for both local (endpoint-based) and remote (API key from env) from flyte._internal.controllers.remote._r_controller import RemoteController - # for devbox - # controller = RemoteController(endpoint="http://host.docker.internal:8090", workers=10, max_system_retries=5) # noqa: E501 - controller = RemoteController(workers=10, max_system_retries=5) + # Extract endpoint if provided, otherwise Rust controller will use API key from env var + endpoint = kwargs.get("endpoint") + # Rust requires scheme prefix (http:// or https://) + if endpoint and not endpoint.startswith(("http://", "https://")): + # Default to http:// for local endpoints + endpoint = f"http://{endpoint}" + controller = RemoteController(endpoint=endpoint, workers=10, max_system_retries=5) case _: raise ValueError(f"{ct} is not a valid controller type.") diff --git a/src/flyte/_internal/controllers/remote/_r_controller.py b/src/flyte/_internal/controllers/remote/_r_controller.py index d0270195b..25219f50e 100644 --- a/src/flyte/_internal/controllers/remote/_r_controller.py +++ b/src/flyte/_internal/controllers/remote/_r_controller.py @@ -4,7 +4,6 @@ import concurrent.futures import os import threading -from asyncio import Event from collections import defaultdict from collections.abc import Callable from pathlib import Path @@ -12,7 +11,6 @@ from flyte_controller_base import Action, BaseController from flyteidl2.common import identifier_pb2, phase_pb2 -from flyteidl2.workflow import run_definition_pb2 import flyte import flyte.errors @@ -374,7 +372,7 @@ def 
exc_handler(loop, context): return fut async def watch_for_errors(self): - """ This pattern works better with utils.run_coros """ + """This pattern works better with utils.run_coros""" await super().watch_for_errors() async def stop(self): diff --git a/src/flyte/_run.py b/src/flyte/_run.py index cdb182d39..7933f87d6 100644 --- a/src/flyte/_run.py +++ b/src/flyte/_run.py @@ -1,6 +1,7 @@ from __future__ import annotations import asyncio +import os import pathlib import sys import uuid @@ -228,6 +229,10 @@ async def _run_remote(self, obj: TaskTemplate[P, R, F] | LazyEntity, *args: P.ar env["LOG_LEVEL"] = str(logger.getEffectiveLevel()) env["LOG_FORMAT"] = self._log_format + use_rust_controller_env_var = os.getenv("_F_USE_RUST_CONTROLLER") + if use_rust_controller_env_var: + env["_F_USE_RUST_CONTROLLER"] = use_rust_controller_env_var + # These paths will be appended to sys.path at runtime. if cfg.sync_local_sys_paths: env[FLYTE_SYS_PATH] = ":".join( From 332fe226552b876fa318c12ee24ad70447e69dd5 Mon Sep 17 00:00:00 2001 From: Yee Hing Tong Date: Thu, 1 Jan 2026 07:54:29 +0800 Subject: [PATCH 20/22] wip - pr into ctrl-rs (#487) Revert group data change Signed-off-by: Yee Hing Tong --- rs_controller/pyproject.toml | 2 +- src/flyte/_context.py | 4 ++-- src/flyte/_group.py | 3 ++- src/flyte/_internal/controllers/remote/_action.py | 8 +++++--- src/flyte/_internal/controllers/remote/_r_controller.py | 6 +++--- src/flyte/_internal/runtime/convert.py | 2 +- src/flyte/models.py | 8 +++++++- 7 files changed, 21 insertions(+), 12 deletions(-) diff --git a/rs_controller/pyproject.toml b/rs_controller/pyproject.toml index 3df91c181..3d44612a1 100644 --- a/rs_controller/pyproject.toml +++ b/rs_controller/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "maturin" [project] name = "flyte_controller_base" -version = "2.0.0b36.dev25+gcedbfba0.dirty" +version = "2.0.0b42.dev21+g8b68f350.dirty" description = "Rust controller for Union" requires-python = ">=3.10" classifiers = ["Programming 
Language :: Python", "Programming Language :: Rust"] diff --git a/src/flyte/_context.py b/src/flyte/_context.py index a2a4116c5..20689560f 100644 --- a/src/flyte/_context.py +++ b/src/flyte/_context.py @@ -5,7 +5,7 @@ from typing import TYPE_CHECKING, Awaitable, Callable, Optional, ParamSpec, TypeVar from flyte._logging import logger -from flyte.models import RawDataPath, TaskContext +from flyte.models import GroupData, RawDataPath, TaskContext if TYPE_CHECKING: from flyte.report import Report @@ -26,7 +26,7 @@ class ContextData: will be None. """ - group_data: Optional[str] = None + group_data: Optional[GroupData] = None task_context: Optional[TaskContext] = None raw_data_path: Optional[RawDataPath] = None diff --git a/src/flyte/_group.py b/src/flyte/_group.py index 4fc2bae49..8fa8d4260 100644 --- a/src/flyte/_group.py +++ b/src/flyte/_group.py @@ -1,6 +1,7 @@ from contextlib import contextmanager from ._context import internal_ctx +from .models import GroupData @contextmanager @@ -25,7 +26,7 @@ async def my_task(): yield return tctx = ctx.data.task_context - new_tctx = tctx.replace(group_data=name) + new_tctx = tctx.replace(group_data=GroupData(name)) with ctx.replace_task_context(new_tctx): yield # Exit the context and restore the previous context diff --git a/src/flyte/_internal/controllers/remote/_action.py b/src/flyte/_internal/controllers/remote/_action.py index fd55ff574..7002c92c4 100644 --- a/src/flyte/_internal/controllers/remote/_action.py +++ b/src/flyte/_internal/controllers/remote/_action.py @@ -12,6 +12,8 @@ ) from google.protobuf import timestamp_pb2 +from flyte.models import GroupData + ActionType = Literal["task", "trace"] @@ -27,7 +29,7 @@ class Action: parent_action_name: str type: ActionType = "task" # type of action, task or trace friendly_name: str | None = None - group: str | None = None + group: GroupData | None = None task: task_definition_pb2.TaskSpec | None = None trace: run_definition_pb2.TraceAction | None = None inputs_uri: str | 
None = None @@ -116,7 +118,7 @@ def from_task( cls, parent_action_name: str, sub_action_id: identifier_pb2.ActionIdentifier, - group_data: str | None, + group_data: GroupData | None, task_spec: task_definition_pb2.TaskSpec, inputs_uri: str, run_output_base: str, @@ -164,7 +166,7 @@ def from_trace( parent_action_name: str, action_id: identifier_pb2.ActionIdentifier, friendly_name: str, - group_data: str | None, + group_data: GroupData | None, inputs_uri: str, outputs_uri: str, start_time: float, # Unix timestamp in seconds with fractional seconds diff --git a/src/flyte/_internal/controllers/remote/_r_controller.py b/src/flyte/_internal/controllers/remote/_r_controller.py index 25219f50e..8b3844af5 100644 --- a/src/flyte/_internal/controllers/remote/_r_controller.py +++ b/src/flyte/_internal/controllers/remote/_r_controller.py @@ -252,7 +252,7 @@ async def _submit(self, _task_call_seq: int, _task: TaskTemplate, *args, **kwarg action = Action.from_task( sub_action_id_bytes=sub_action_id_pb.SerializeToString(), parent_action_name=current_action_id.name, - group_data=str(tctx.group_data) if tctx.group_data else None, + group_data=tctx.group_data.name if tctx.group_data else None, task_spec_bytes=task_spec.SerializeToString(), inputs_uri=inputs_uri, run_output_base=tctx.run_base_dir, @@ -521,7 +521,7 @@ async def record_trace(self, info: TraceInfo): inputs_uri=info.inputs_path, outputs_uri=outputs_file_path, friendly_name=info.name, - group_data=tctx.group_data, + group_data=tctx.group_data.name if tctx.group_data else None, run_output_base=tctx.run_base_dir, start_time=info.start_time, end_time=info.end_time, @@ -597,7 +597,7 @@ async def _submit_task_ref(self, invoke_seq_num: int, _task: TaskDetails, *args, action = Action.from_task( sub_action_id_bytes=sub_action_id_pb.SerializeToString(), parent_action_name=current_action_id.name, - group_data=tctx.group_data, + group_data=tctx.group_data.name if tctx.group_data else None, 
task_spec_bytes=_task.pb2.spec.SerializeToString(), inputs_uri=inputs_uri, run_output_base=tctx.run_base_dir, diff --git a/src/flyte/_internal/runtime/convert.py b/src/flyte/_internal/runtime/convert.py index 1496565cc..4a20bf970 100644 --- a/src/flyte/_internal/runtime/convert.py +++ b/src/flyte/_internal/runtime/convert.py @@ -481,7 +481,7 @@ def generate_sub_action_id_and_output_path( sub_action_id = current_action_id.new_sub_action_from( task_hash=task_hash, input_hash=inputs_hash, - group=tctx.group_data if tctx.group_data else None, + group=tctx.group_data.name if tctx.group_data else None, task_call_seq=invoke_seq, ) sub_run_output_path = storage.join(current_output_path, sub_action_id.name) diff --git a/src/flyte/models.py b/src/flyte/models.py index a531b6d06..f4744511e 100644 --- a/src/flyte/models.py +++ b/src/flyte/models.py @@ -180,6 +180,12 @@ def get_random_remote_path(self, file_name: Optional[str] = None) -> str: return remote_path +@rich.repr.auto +@dataclass(frozen=True) +class GroupData: + name: str + + @rich.repr.auto @dataclass(frozen=True, kw_only=True) class TaskContext: @@ -201,7 +207,7 @@ class TaskContext: output_path: str run_base_dir: str report: Report - group_data: str | None = None + group_data: GroupData | None = None checkpoints: Checkpoints | None = None code_bundle: CodeBundle | None = None compiled_image_cache: ImageCache | None = None From e3e3322086b246b5fbab45644bff4bc840a369cc Mon Sep 17 00:00:00 2001 From: Yee Hing Tong Date: Wed, 7 Jan 2026 01:52:20 +0800 Subject: [PATCH 21/22] wip - fixes to ctrl-rs (#489) * Save action to informer action cache upon submit if not found in informer's action cache * `get_action` should call informer get or create, not just get, and return None rather than an error if informer doesn't have the action * Update flyteidl2 dep to match main project. 
--------- Signed-off-by: Yee Hing Tong --- examples/stress/crash_recovery_trace.py | 1 + rs_controller/Cargo.lock | 4 ++-- rs_controller/Cargo.toml | 2 +- rs_controller/pyproject.toml | 2 +- rs_controller/src/core.rs | 24 +++++++++---------- rs_controller/src/informer.rs | 2 +- rs_controller/src/lib.rs | 1 + .../controllers/remote/_r_controller.py | 2 ++ 8 files changed, 20 insertions(+), 18 deletions(-) diff --git a/examples/stress/crash_recovery_trace.py b/examples/stress/crash_recovery_trace.py index 7c9006aaa..dc8f78e61 100644 --- a/examples/stress/crash_recovery_trace.py +++ b/examples/stress/crash_recovery_trace.py @@ -45,6 +45,7 @@ async def main() -> list[int]: attempt_number = get_attempt_number() # Fail at attempts 0, 1, and 2 at for i = 100, 200, 300 respectively, then succeed if i == (attempt_number + 1) * 100 and attempt_number < 3: + print(f"Simulating crash for element {i=} and {attempt_number=}", flush=True) raise flyte.errors.RuntimeSystemError( "simulated", f"Simulated failure on attempt {get_attempt_number()} at iteration {i}" ) diff --git a/rs_controller/Cargo.lock b/rs_controller/Cargo.lock index 90239fcfe..ceba46e6b 100644 --- a/rs_controller/Cargo.lock +++ b/rs_controller/Cargo.lock @@ -318,9 +318,9 @@ dependencies = [ [[package]] name = "flyteidl2" -version = "2.0.0-alpha15" +version = "2.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5e5abd4c0481acd1132137e34f92c5e4e1aedf86aa49c8f770283349cfc6a618" +checksum = "3fa9ca7ab4b9678e656821c53656ae9aa28115ad3e4579cc160b58b8b592abc2" dependencies = [ "async-trait", "futures", diff --git a/rs_controller/Cargo.toml b/rs_controller/Cargo.toml index 00758a655..134eb1e0c 100644 --- a/rs_controller/Cargo.toml +++ b/rs_controller/Cargo.toml @@ -34,7 +34,7 @@ async-trait = "0.1" thiserror = "1.0" # Uncomment this if you need to use local flyteidl2 #flyteidl2 = { path = "/Users/ytong/go/src/github.com/flyteorg/flyte/gen/rust" } -flyteidl2 = "=2.0.0-alpha15" +flyteidl2 = 
"=2.0.0" reqwest = { version = "0.12", features = ["json", "rustls-tls"] } serde = { version = "1.0", features = ["derive"] } serde_json = "1.0" diff --git a/rs_controller/pyproject.toml b/rs_controller/pyproject.toml index 3d44612a1..6dd883ffc 100644 --- a/rs_controller/pyproject.toml +++ b/rs_controller/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "maturin" [project] name = "flyte_controller_base" -version = "2.0.0b42.dev21+g8b68f350.dirty" +version = "2.0.0b42.dev24+g976c9b15.dirty" description = "Rust controller for Union" requires-python = ">=3.10" classifiers = ["Programming Language :: Python", "Programming Language :: Rust"] diff --git a/rs_controller/src/core.rs b/rs_controller/src/core.rs index f20739dff..60ac49ac3 100644 --- a/rs_controller/src/core.rs +++ b/rs_controller/src/core.rs @@ -501,7 +501,7 @@ impl CoreBaseController { &self, action_id: ActionIdentifier, parent_action_name: &str, - ) -> Result { + ) -> Result, ControllerError> { let run = action_id .run .as_ref() @@ -509,19 +509,17 @@ impl CoreBaseController { "Action {:?} doesn't have a run, can't get action", action_id )))?; - if let Some(informer) = self.informer_cache.get(run, parent_action_name).await { - let action_name = action_id.name.clone(); - match informer.get_action(action_name).await { - Some(action) => Ok(action), - None => Err(ControllerError::RuntimeError(format!( - "Action not found getting from action_id: {:?}", - action_id - ))), + let informer = self + .informer_cache + .get_or_create_informer(run, parent_action_name) + .await; + let action_name = action_id.name.clone(); + match informer.get_action(action_name).await { + Some(action) => Ok(Some(action)), + None => { + debug!("Action not found getting from action_id: {:?}", action_id); + Ok(None) } - } else { - Err(ControllerError::BadContext( - "Informer not initialized".to_string(), - )) } } diff --git a/rs_controller/src/informer.rs b/rs_controller/src/informer.rs index c1c429f1d..ec86dd3b2 100644 --- 
a/rs_controller/src/informer.rs +++ b/rs_controller/src/informer.rs @@ -241,7 +241,7 @@ impl Informer { some_action.merge_from_submit(&action); some_action.clone() } else { - // don't need to write anything. return the original + cache.insert(action_name.clone(), action.clone()); action } }; diff --git a/rs_controller/src/lib.rs b/rs_controller/src/lib.rs index e700bb14e..70bc2a6e2 100644 --- a/rs_controller/src/lib.rs +++ b/rs_controller/src/lib.rs @@ -90,6 +90,7 @@ impl BaseController { py_fut } + // todo: what happens if we change this to async? fn get_action<'py>( &self, py: Python<'py>, diff --git a/src/flyte/_internal/controllers/remote/_r_controller.py b/src/flyte/_internal/controllers/remote/_r_controller.py index 8b3844af5..dda07563d 100644 --- a/src/flyte/_internal/controllers/remote/_r_controller.py +++ b/src/flyte/_internal/controllers/remote/_r_controller.py @@ -446,6 +446,7 @@ async def get_action_outputs( org=current_action_id.org, ), ) + prev_action = await self.get_action( sub_action_id_pb.SerializeToString(), current_action_id.name, @@ -525,6 +526,7 @@ async def record_trace(self, info: TraceInfo): run_output_base=tctx.run_base_dir, start_time=info.start_time, end_time=info.end_time, + report_uri=None, typed_interface_bytes=typed_interface.SerializeToString() if typed_interface else None, ) From 2baaef9ae3ad63177a3ddb4f61b46abbada3b253 Mon Sep 17 00:00:00 2001 From: Nary Yeh <60069744+machichima@users.noreply.github.com> Date: Sun, 8 Feb 2026 17:19:49 -0800 Subject: [PATCH 22/22] [Rust Controller] Handle SlowDownError (retry with backoff) (#578) Handle `SlowDownError` like what we had in python controller --------- Signed-off-by: machichima --- rs_controller/src/core.rs | 194 +++++++++++------- rs_controller/src/error.rs | 3 + .../_internal/controllers/remote/_core.py | 2 +- 3 files changed, 120 insertions(+), 79 deletions(-) diff --git a/rs_controller/src/core.rs b/rs_controller/src/core.rs index 60ac49ac3..bf9310e17 100644 --- 
a/rs_controller/src/core.rs +++ b/rs_controller/src/core.rs @@ -21,10 +21,7 @@ use tokio::{ sync::{mpsc, oneshot}, time::sleep, }; -use tonic::{ - transport::{Certificate, ClientTlsConfig, Endpoint}, - Status, -}; +use tonic::transport::{Certificate, ClientTlsConfig, Endpoint}; use tower::ServiceBuilder; use tracing::{debug, error, info, warn}; @@ -304,9 +301,6 @@ impl CoreBaseController { } async fn bg_worker(&self, worker_id: String) { - const MIN_BACKOFF_ON_ERR: Duration = Duration::from_millis(100); - const MAX_RETRIES: u32 = 5; - info!( "Worker {} started on thread {:?}", worker_id, @@ -317,7 +311,7 @@ impl CoreBaseController { let mut rx = self.shared_queue_rx.lock().await; match rx.recv().await { Some(mut action) => { - let run_name = &action + let run_name = action .action_id .run .as_ref() @@ -330,76 +324,43 @@ impl CoreBaseController { // Drop the mutex guard before processing drop(rx); - match self.handle_action(&mut action).await { + match self + .process_action_with_retry(&mut action, &worker_id) + .await + { Ok(_) => {} - // Add handling here for new slow down error Err(e) => { - error!("Error in controller loop: {:?}", e); - // Handle backoff and retry logic - sleep(MIN_BACKOFF_ON_ERR).await; - action.retries += 1; - - if action.retries > MAX_RETRIES { - error!( - "[{}] Controller failed processing {}::{}, system retries {} crossed threshold {}", - worker_id, run_name, action.action_id.name, action.retries, MAX_RETRIES - ); - action.client_err = Some(format!( - "[{}] Controller failed {}::{}, system retries {} crossed threshold {}", - worker_id, run_name, action.action_id.name, action.retries, MAX_RETRIES - )); - - // Fire completion event for failed action - let opt_informer = self - .informer_cache - .get(&action.get_run_identifier(), &action.parent_action_name) - .await; - if let Some(informer) = opt_informer { - // Before firing completion event, update the action in the - // informer, otherwise client_err will not be set. 
- // todo: gain a better understanding of these two errors and handle - let res = informer.set_action_client_err(&action).await; - match res { - Ok(()) => {} - Err(e) => { - error!( - "Error setting error for failed action {}: {}", - &action.get_full_name(), - e - ) - } - } - let res = informer - .fire_completion_event(&action.action_id.name) - .await; - match res { - Ok(()) => {} - Err(e) => { - error!("Error firing completion event for failed action {}: {}", &action.get_full_name(), e) - } - } - } else { + // Unified error handling for all failures and exceed max retries + error!( + "[{}] Error in controller loop for {}::{}: {:?}", + worker_id, run_name, action.action_id.name, e + ); + action.client_err = Some(e.to_string()); + + let opt_informer = self + .informer_cache + .get(&action.get_run_identifier(), &action.parent_action_name) + .await; + if let Some(informer) = opt_informer { + if let Err(set_err) = informer.set_action_client_err(&action).await + { error!( - "Max retries hit for action but informer missing: {:?}", - action.action_id + "Error setting error for failed action {}: {}", + action.get_full_name(), + set_err ); } - } else { - // Re-queue the action for retry - info!( - "[{}] Re-queuing action {}::{} for retry, attempt {}/{}", - worker_id, - run_name, - action.action_id.name, - action.retries, - MAX_RETRIES - ); - if let Err(send_err) = self.shared_queue.send(action).await { + if let Err(fire_err) = + informer.fire_completion_event(&action.action_id.name).await + { error!( - "[{}] Failed to re-queue action for retry: {}", - worker_id, send_err + "Error firing completion event for failed action {}: {}", + action.get_full_name(), + fire_err ); } + } else { + error!("Informer missing for action: {:?}", action.action_id); } } } @@ -412,6 +373,74 @@ impl CoreBaseController { } } + async fn process_action_with_retry( + &self, + action: &mut Action, + worker_id: &str, + ) -> Result<(), ControllerError> { + const MIN_BACKOFF_ON_ERR: Duration = 
Duration::from_millis(500); + const MAX_BACKOFF_ON_ERR: Duration = Duration::from_secs(10); + const MAX_RETRIES: u32 = 5; + + let run_name = action + .action_id + .run + .as_ref() + .map_or(String::from(""), |i| i.name.clone()); + + match self.handle_action(action).await { + Ok(_) => Ok(()), + // Process action with retry logic for SlowDownError + Err(ControllerError::SlowDownError(msg)) => { + action.retries += 1; + + if action.retries > MAX_RETRIES { + // Max retries exceeded, return error to be handled by caller + Err(ControllerError::RuntimeError(format!( + "[{}] Controller failed {}::{}, system retries {} crossed threshold {}: SlowDownError: {}", + worker_id, run_name, action.action_id.name, action.retries, MAX_RETRIES, msg + ))) + } else { + // Calculate exponential backoff: min(MIN * 2^(retries-1), MAX) + let backoff_millis = + MIN_BACKOFF_ON_ERR.as_millis() as u64 * 2u64.pow(action.retries - 1); + let backoff = Duration::from_millis(backoff_millis).min(MAX_BACKOFF_ON_ERR); + + warn!( + "[{}] Backing off for {:?} [retry {}/{}] on action {}::{} due to error: {}", + worker_id, + backoff, + action.retries, + MAX_RETRIES, + run_name, + action.action_id.name, + msg + ); + sleep(backoff).await; + + warn!( + "[{}] Retrying action {}::{} after backoff", + worker_id, run_name, action.action_id.name + ); + + // Re-queue the action for retry + self.shared_queue.send(action.clone()).await.map_err(|e| { + ControllerError::RuntimeError(format!( + "[{}] Failed to re-queue action for retry: {}", + worker_id, e + )) + })?; + + Ok(()) + } + } + Err(e) => { + // All other errors are propagated up immediately + Err(e) + } + } + } + async fn handle_action(&self, action: &mut Action) -> Result<(), ControllerError> { if !action.started { // Action not started, launch it @@ -459,10 +488,8 @@ impl CoreBaseController { "Failed to launch action: {}, error: {}", action.action_id.name, e ); - Err(ControllerError::RuntimeError(format!( - "Launch failed: {}", - e - ))) + // Propagate the 
error as-is + Err(e) } } } @@ -583,7 +610,7 @@ impl CoreBaseController { }) } - async fn launch_task(&self, action: &Action) -> Result { + async fn launch_task(&self, action: &Action) -> Result { if !action.started && action.task.is_some() { let enqueue_request = self .create_enqueue_action_request(action) @@ -591,7 +618,7 @@ impl CoreBaseController { let mut client = self.queue_client.clone(); // todo: tonic doesn't seem to have wait_for_ready, or maybe the .ready is already doing this. let enqueue_result = client.enqueue_action(enqueue_request).await; - // Add logic from resiliency pr here, return certain errors, but change others to be a specific slowdown error. + match enqueue_result { Ok(response) => { debug!("Successfully enqueued action: {:?}", action.action_id); @@ -604,14 +631,25 @@ impl CoreBaseController { action.action_id.name ); Ok(EnqueueActionResponse {}) + } else if e.code() == tonic::Code::FailedPrecondition + || e.code() == tonic::Code::InvalidArgument + || e.code() == tonic::Code::NotFound + { + Err(ControllerError::RuntimeError(format!( + "Precondition failed: {}", + e + ))) } else { + // For all other errors, retry with backoff through raising SlowDownError error!( "Failed to launch action: {:?}, backing off...", action.action_id ); error!("Error details: {}", e); - // Handle backoff logic here - Err(e) + Err(ControllerError::SlowDownError(format!( + "Failed to launch action: {}", + e + ))) } } } diff --git a/rs_controller/src/error.rs b/rs_controller/src/error.rs index beccfddba..cd82b4b9f 100644 --- a/rs_controller/src/error.rs +++ b/rs_controller/src/error.rs @@ -16,6 +16,9 @@ pub enum ControllerError { TaskError(String), #[error("Informer error: {0}")] Informer(#[from] InformerError), + // Error type that triggers retry with backoff + #[error("Slow down error: {0}")] + SlowDownError(String), } impl From for ControllerError { diff --git a/src/flyte/_internal/controllers/remote/_core.py b/src/flyte/_internal/controllers/remote/_core.py index 
595d4aa7f..5ecf9c941 100644 --- a/src/flyte/_internal/controllers/remote/_core.py +++ b/src/flyte/_internal/controllers/remote/_core.py @@ -362,7 +362,7 @@ async def _bg_launch(self, action: Action): trace=trace, input_uri=action.inputs_uri, run_output_base=action.run_output_base, - group=action.group if action.group else None, + group=action.group.name if action.group else None, # Subject is not used in the current implementation ), wait_for_ready=True,