diff --git a/examples/apps/app_with_raw_data_path.py b/examples/apps/app_with_raw_data_path.py
new file mode 100644
index 000000000..74e77f21e
--- /dev/null
+++ b/examples/apps/app_with_raw_data_path.py
@@ -0,0 +1,37 @@
+"""FastAPI app that returns the raw data path from flyte.app.ctx."""
+
+from fastapi import FastAPI
+
+import flyte
+from flyte.app import ctx
+from flyte.app.extras import FastAPIAppEnvironment
+
+app = FastAPI(
+    title="Raw Data Path Demo",
+    description="Returns the raw data path from flyte.app.ctx",
+    version="1.0.0",
+)
+
+app_env = FastAPIAppEnvironment(
+    name="raw-data-path-demo",
+    app=app,
+    description="App that returns the raw data path from ctx",
+    image=flyte.Image.from_debian_base().with_pip_packages("fastapi", "uvicorn"),
+    resources=flyte.Resources(cpu=1, memory="512Mi"),
+    requires_auth=False,
+    env_vars={"LOG_LEVEL": "10"},
+)
+
+
+@app.get("/")
+async def root() -> str:
+    """Return the raw data path from flyte.app.ctx."""
+    return ctx().raw_data_path or ""
+
+
+if __name__ == "__main__":
+    import pathlib
+
+    flyte.init_from_config(root_dir=pathlib.Path(__file__).parent)
+    app_handle = flyte.serve(app_env)
+    print(f"Deployed app: {app_handle.url}")
diff --git a/src/flyte/_bin/serve.py b/src/flyte/_bin/serve.py
index 2daecf9f7..1f1769a73 100644
--- a/src/flyte/_bin/serve.py
+++ b/src/flyte/_bin/serve.py
@@ -206,9 +206,16 @@ def _bind_parameters(
 async def _serve(
     app_env: AppEnvironment,
     materialized_parameters: dict[str, str | flyte.io.File | flyte.io.Dir],
+    raw_data_path: str | None = None,
 ):
     import signal
 
+    if raw_data_path:
+        from flyte.app._context import set_raw_data_path
+
+        set_raw_data_path(raw_data_path)
+        logger.info(f"Set raw_data_path in AppContext: {raw_data_path}")
+
     logger.info("Running app via server function")
     assert app_env._server is not None
 
@@ -262,6 +269,7 @@ def run_sync():
 @click.option("--tgz", required=False)
 @click.option("--pkl", required=False)
 @click.option("--dest", required=False)
+@click.option("--raw-data-path", "-r", required=False)
 @click.option("--project", envvar=PROJECT_NAME, required=False)
 @click.option("--domain", envvar=DOMAIN_NAME, required=False)
 @click.option("--org", envvar=ORG_NAME, required=False)
@@ -277,6 +285,7 @@ def main(
     tgz: str,
     pkl: str,
     dest: str,
+    raw_data_path: str | None = None,
     command: tuple[str, ...] | None = None,
     project: str | None = None,
     domain: str | None = None,
@@ -328,10 +337,18 @@ def main(
 
     os.environ[RUNTIME_PARAMETERS_FILE] = parameters_file
 
+    logger.debug(f"RAW DATA PATH: {raw_data_path}")
+
     if app_env and app_env._server is not None:
-        asyncio.run(_serve(app_env, materialized_parameters))
+        asyncio.run(_serve(app_env, materialized_parameters, raw_data_path=raw_data_path))
         exit(0)
 
+    if raw_data_path:
+        from flyte.app._context import set_raw_data_path
+
+        set_raw_data_path(raw_data_path)
+        logger.info(f"Set raw_data_path in AppContext: {raw_data_path}")
+
     if command is None or len(command) == 0:
         raise ValueError("No command provided to execute")
 
diff --git a/src/flyte/_serve.py b/src/flyte/_serve.py
index 2b850921f..1c6461291 100644
--- a/src/flyte/_serve.py
+++ b/src/flyte/_serve.py
@@ -345,6 +345,7 @@ def __init__(
         health_check_timeout: float | None = None,
         health_check_interval: float | None = None,
         health_check_path: str | None = None,
+        raw_data_path: str | None = None,
     ):
         """
         Initialize serve context.
@@ -375,6 +376,10 @@ def __init__(
                 Defaults to `1` second.
             health_check_path: URL path used for the local health-check probe (e.g. `"/healthz"`).
                 Defaults to `"/health"`.
+            raw_data_path: Raw data path for the app. For local serving, used when testing apps
+                that read ctx().raw_data_path. Defaults to ``/tmp/flyte/raw_data`` when mode is
+                local and not specified. For remote serving, the backend provides this via the
+                container command.
         """
         from flyte._initialize import _get_init_config
 
@@ -407,6 +412,9 @@ def __init__(
             health_check_interval if health_check_interval is not None else _LOCAL_IS_ACTIVE_INTERVAL
         )
         self._health_check_path = health_check_path if health_check_path is not None else _LOCAL_HEALTH_CHECK_PATH
+        self._raw_data_path = (
+            raw_data_path if raw_data_path is not None else ("/tmp/flyte/raw_data" if self._mode == "local" else "")
+        )
 
     # ------------------------------------------------------------------
     # Local serving
@@ -478,6 +486,12 @@ def _serve_local_with_server_func(
         local_app = _LocalApp(app_env=app_env, _serve_obj=self, host=host, port=port)
 
         def _run():
+            # Contextvars don't propagate to new threads. Set raw_data_path in this
+            # thread's context so ctx().raw_data_path works in request handlers.
+            from flyte.app._context import set_raw_data_path
+
+            set_raw_data_path(self._raw_data_path or "")
+
             loop = asyncio.new_event_loop()
             asyncio.set_event_loop(loop)
             local_app._thread_loop = loop
@@ -757,6 +771,7 @@ def with_servecontext(
     health_check_timeout: float | None = None,
     health_check_interval: float | None = None,
     health_check_path: str | None = None,
+    raw_data_path: str | None = None,
 ) -> _Serve:
     """
     Create a serve context with custom configuration.
@@ -821,6 +836,9 @@ def with_servecontext(
             Defaults to 1 s.
         health_check_path: URL path used for the local health-check probe (e.g. ``"/healthz"``).
             Defaults to ``"/health"``.
+        raw_data_path: Raw data path for the app. For local serving, sets ctx().raw_data_path
+            so apps can read it. Defaults to ``/tmp/flyte/raw_data`` when mode is local.
+            For remote serving, the backend provides this via the container command.
 
     Returns:
         _Serve: Serve context manager with configured settings
@@ -853,6 +871,7 @@ def with_servecontext(
         health_check_timeout=health_check_timeout,
         health_check_interval=health_check_interval,
         health_check_path=health_check_path,
+        raw_data_path=raw_data_path,
     )
 
 
diff --git a/src/flyte/app/_app_environment.py b/src/flyte/app/_app_environment.py
index 5ea7f6115..15fc3a709 100644
--- a/src/flyte/app/_app_environment.py
+++ b/src/flyte/app/_app_environment.py
@@ -249,6 +249,9 @@ def container_cmd(
             cmd.append("--parameters")
             cmd.append(self._serialize_parameters(parameter_overrides))
 
+        # Add raw-data-path with template variable for backend to substitute at runtime
+        cmd.extend(["--raw-data-path", "{{.rawOutputDataPrefix}}"])
+
         # Only add resolver args if _caller_frame is set and we can extract the module
         # (i.e., app was created in a module and can be found)
         if self._caller_frame is not None:
diff --git a/src/flyte/app/_context.py b/src/flyte/app/_context.py
index 696734c7f..87448ce09 100644
--- a/src/flyte/app/_context.py
+++ b/src/flyte/app/_context.py
@@ -1,14 +1,20 @@
 import os
+from contextvars import ContextVar
 from dataclasses import dataclass
 from typing import cast
 
 from flyte._serve import ServeMode
 
+_raw_data_path_var: ContextVar[str | None] = ContextVar("raw_data_path", default=None)
+_RAW_DATA_PATH_ENV = "_FLYTE_APP_RAW_DATA_PATH"
+
 
 @dataclass(frozen=True)
 class AppContext:
     mode: ServeMode = "remote"
     project: str = ""
     domain: str = ""
+    # Declared last so existing positional AppContext(mode, project, domain) calls keep their meaning.
+    raw_data_path: str = ""
 
 
@@ -22,8 +28,17 @@ def ctx() -> AppContext:
     mode = os.getenv("_RUN_MODE", "remote")
     project = os.getenv("FLYTE_INTERNAL_EXECUTION_PROJECT", "")
     domain = os.getenv("FLYTE_INTERNAL_EXECUTION_DOMAIN", "")
+    raw_data_path = _raw_data_path_var.get() or os.getenv(_RAW_DATA_PATH_ENV, "") or ""
     return AppContext(
         mode=cast(ServeMode, mode),
         project=project,
         domain=domain,
+        raw_data_path=raw_data_path,
     )
+
+
+def set_raw_data_path(raw_data_path: str | None) -> None:
+    """Set the raw data path in the current context and as env var for thread propagation."""
+    value = raw_data_path or ""
+    _raw_data_path_var.set(value)
+    os.environ[_RAW_DATA_PATH_ENV] = value