Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 37 additions & 0 deletions examples/apps/app_with_raw_data_path.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
"""FastAPI app that returns the raw data path from flyte.app.ctx."""

from fastapi import FastAPI

import flyte
from flyte.app import ctx
from flyte.app.extras import FastAPIAppEnvironment

# The FastAPI application object that the route below is registered on and
# that is handed to the Flyte app environment.
app = FastAPI(
    title="Raw Data Path Demo",
    description="Returns the raw data path from flyte.app.ctx",
    version="1.0.0",
)

# Flyte environment wrapping the FastAPI app defined above.
app_env = FastAPIAppEnvironment(
    name="raw-data-path-demo",
    app=app,
    description="App that returns the raw data path from ctx",
    # Debian base image with the two packages installed via pip.
    image=flyte.Image.from_debian_base().with_pip_packages("fastapi", "uvicorn"),
    resources=flyte.Resources(cpu=1, memory="512Mi"),
    requires_auth=False,
    # NOTE(review): "10" presumably maps to the DEBUG logging level -- confirm.
    env_vars={"LOG_LEVEL": "10"},
)


@app.get("/")
async def root() -> str:
    """Serve the raw data path exposed by ``flyte.app.ctx``.

    Returns:
        The ``raw_data_path`` of the current app context, or an empty string
        when that value is missing or empty.
    """
    current_path = ctx().raw_data_path
    if current_path:
        return current_path
    return ""


if __name__ == "__main__":
    import pathlib

    # Initialize Flyte with this example's own directory as the root dir.
    example_dir = pathlib.Path(__file__).parent
    flyte.init_from_config(root_dir=example_dir)

    # Serve the environment and report the URL exposed by the returned handle.
    app_handle = flyte.serve(app_env)
    print(f"Deployed app: {app_handle.url}")
19 changes: 18 additions & 1 deletion src/flyte/_bin/serve.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,9 +206,16 @@ def _bind_parameters(
async def _serve(
app_env: AppEnvironment,
materialized_parameters: dict[str, str | flyte.io.File | flyte.io.Dir],
raw_data_path: str | None = None,
):
import signal

if raw_data_path:
from flyte.app._context import set_raw_data_path

set_raw_data_path(raw_data_path)
logger.info(f"Set raw_data_path in AppContext: {raw_data_path}")

logger.info("Running app via server function")
assert app_env._server is not None

Expand Down Expand Up @@ -262,6 +269,7 @@ def run_sync():
@click.option("--tgz", required=False)
@click.option("--pkl", required=False)
@click.option("--dest", required=False)
@click.option("--raw-data-path", "-r", required=False)
@click.option("--project", envvar=PROJECT_NAME, required=False)
@click.option("--domain", envvar=DOMAIN_NAME, required=False)
@click.option("--org", envvar=ORG_NAME, required=False)
Expand All @@ -277,6 +285,7 @@ def main(
tgz: str,
pkl: str,
dest: str,
raw_data_path: str | None = None,
command: tuple[str, ...] | None = None,
project: str | None = None,
domain: str | None = None,
Expand Down Expand Up @@ -328,10 +337,18 @@ def main(

os.environ[RUNTIME_PARAMETERS_FILE] = parameters_file

logger.info(f"RAW DATA PATH: {raw_data_path}")

if app_env and app_env._server is not None:
asyncio.run(_serve(app_env, materialized_parameters))
asyncio.run(_serve(app_env, materialized_parameters, raw_data_path=raw_data_path))
exit(0)

if raw_data_path:
from flyte.app._context import set_raw_data_path

set_raw_data_path(raw_data_path)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Instead of updating the env var in set_raw_data_path, would it be better to just pass raw_data_path to _serve()?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done. Could you re-approve (+1) when you get a chance?

logger.info(f"Set raw_data_path in AppContext: {raw_data_path}")

if command is None or len(command) == 0:
raise ValueError("No command provided to execute")

Expand Down
19 changes: 19 additions & 0 deletions src/flyte/_serve.py
Original file line number Diff line number Diff line change
Expand Up @@ -345,6 +345,7 @@ def __init__(
health_check_timeout: float | None = None,
health_check_interval: float | None = None,
health_check_path: str | None = None,
raw_data_path: str | None = None,
):
"""
Initialize serve context.
Expand Down Expand Up @@ -375,6 +376,10 @@ def __init__(
Defaults to `1` second.
health_check_path: URL path used for the local health-check probe (e.g. `"/healthz"`).
Defaults to `"/health"`.
raw_data_path: Raw data path for the app. For local serving, used when testing apps
that read ctx().raw_data_path. Defaults to ``/tmp/flyte/raw_data`` when mode is
local and not specified. For remote serving, the backend provides this via the
container command.
"""
from flyte._initialize import _get_init_config

Expand Down Expand Up @@ -407,6 +412,9 @@ def __init__(
health_check_interval if health_check_interval is not None else _LOCAL_IS_ACTIVE_INTERVAL
)
self._health_check_path = health_check_path if health_check_path is not None else _LOCAL_HEALTH_CHECK_PATH
self._raw_data_path = (
raw_data_path if raw_data_path is not None else ("/tmp/flyte/raw_data" if self._mode == "local" else "")
)

# ------------------------------------------------------------------
# Local serving
Expand Down Expand Up @@ -478,6 +486,12 @@ def _serve_local_with_server_func(
local_app = _LocalApp(app_env=app_env, _serve_obj=self, host=host, port=port)

def _run():
# Contextvars don't propagate to new threads. Set raw_data_path in this
# thread's context so ctx().raw_data_path works in request handlers.
from flyte.app._context import set_raw_data_path

set_raw_data_path(self._raw_data_path or "")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How does it work if we run two apps locally? Both of them will override the same env vars, right?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yep, but it should be the same raw data path in /tmp for both apps.


loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
local_app._thread_loop = loop
Expand Down Expand Up @@ -757,6 +771,7 @@ def with_servecontext(
health_check_timeout: float | None = None,
health_check_interval: float | None = None,
health_check_path: str | None = None,
raw_data_path: str | None = None,
) -> _Serve:
"""
Create a serve context with custom configuration.
Expand Down Expand Up @@ -821,6 +836,9 @@ def with_servecontext(
Defaults to 1 s.
health_check_path: URL path used for the local health-check probe (e.g. ``"/healthz"``).
Defaults to ``"/health"``.
raw_data_path: Raw data path for the app. For local serving, sets ctx().raw_data_path
so apps can read it. Defaults to ``/tmp/flyte/raw_data`` when mode is local.
For remote serving, the backend provides this via the container command.

Returns:
_Serve: Serve context manager with configured settings
Expand Down Expand Up @@ -853,6 +871,7 @@ def with_servecontext(
health_check_timeout=health_check_timeout,
health_check_interval=health_check_interval,
health_check_path=health_check_path,
raw_data_path=raw_data_path,
)


Expand Down
3 changes: 3 additions & 0 deletions src/flyte/app/_app_environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,6 +249,9 @@ def container_cmd(
cmd.append("--parameters")
cmd.append(self._serialize_parameters(parameter_overrides))

# Add raw-data-path with template variable for backend to substitute at runtime
cmd.extend(["--raw-data-path", "{{.rawOutputDataPrefix}}"])

# Only add resolver args if _caller_frame is set and we can extract the module
# (i.e., app was created in a module and can be found)
if self._caller_frame is not None:
Expand Down
14 changes: 14 additions & 0 deletions src/flyte/app/_context.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,19 @@
import os
from contextvars import ContextVar
from dataclasses import dataclass
from typing import cast

from flyte._serve import ServeMode

# Holds the raw data path for the current context; stays ``None`` until
# set_raw_data_path() stores a value. ctx() consults this first.
_raw_data_path_var: ContextVar[str | None] = ContextVar("raw_data_path", default=None)
# Name of the environment variable that mirrors the raw data path. ctx() falls
# back to it when the context variable is unset or empty.
_RAW_DATA_PATH_ENV = "_FLYTE_APP_RAW_DATA_PATH"


@dataclass(frozen=True)
class AppContext:
    """Immutable bundle of values describing the running app's context.

    Every field has a default, so ``AppContext()`` can be built with no
    arguments. Instances are frozen: fields cannot be reassigned after
    construction.
    """

    # Serving mode; "remote" unless stated otherwise.
    mode: ServeMode = "remote"
    # Project name; empty string when not provided.
    project: str = ""
    # Raw data path; empty string when not provided.
    raw_data_path: str = ""
    # Domain name; empty string when not provided.
    domain: str = ""


Expand All @@ -22,8 +27,17 @@ def ctx() -> AppContext:
mode = os.getenv("_RUN_MODE", "remote")
project = os.getenv("FLYTE_INTERNAL_EXECUTION_PROJECT", "")
domain = os.getenv("FLYTE_INTERNAL_EXECUTION_DOMAIN", "")
raw_data_path = _raw_data_path_var.get() or os.getenv(_RAW_DATA_PATH_ENV, "") or ""
return AppContext(
mode=cast(ServeMode, mode),
project=project,
domain=domain,
raw_data_path=raw_data_path,
)


def set_raw_data_path(raw_data_path: str | None) -> None:
    """Publish the raw data path for later lookup.

    The value is written to two places: the context variable of the current
    context, and the process environment (so it remains readable from other
    threads, where the context variable is not propagated).

    Args:
        raw_data_path: Path to store. ``None`` or an empty string is stored
            as an empty string.
    """
    normalized = raw_data_path if raw_data_path else ""
    _raw_data_path_var.set(normalized)
    os.environ[_RAW_DATA_PATH_ENV] = normalized
Loading