diff --git a/examples/airflow-migration/basic.py b/examples/airflow-migration/basic.py new file mode 100644 index 000000000..39b4644ff --- /dev/null +++ b/examples/airflow-migration/basic.py @@ -0,0 +1,38 @@ +from pathlib import Path
+
+import flyteplugins.airflow.task # noqa: F401 — triggers DAG + operator monkey-patches
+from airflow import DAG
+from airflow.operators.bash import BashOperator
+from airflow.operators.python import PythonOperator
+
+import flyte
+
+
+def hello_python():
+    print("Hello from PythonOperator!")
+
+
+# Standard Airflow DAG definition — no Flyte-specific changes needed inside the block.
+# No flyte_env is passed, so a default TaskEnvironment named after the dag_id is created.
+with DAG(
+    dag_id="simple_airflow_workflow",
+) as dag:
+    t1 = BashOperator(
+        task_id="say_hello",
+        bash_command='echo "Hello Airflow!"',
+    )
+    t2 = BashOperator(
+        task_id="say_goodbye",
+        bash_command='echo "Goodbye Airflow!"',
+    )
+    t3 = PythonOperator(
+        task_id="hello_python",
+        python_callable=hello_python,
+    )
+    t1 >> t2  # t2 runs after t1
+
+
+if __name__ == "__main__":
+    flyte.init_from_config(root_dir=Path(__file__).parent.parent.parent)
+    run = flyte.with_runcontext(mode="remote", log_level="10").run(dag)
+    print(run.url) diff --git a/examples/connectors/bigquery_example.py b/examples/connectors/bigquery_example.py index a17fc9d3d..347642e1c 100644 --- a/examples/connectors/bigquery_example.py +++ b/examples/connectors/bigquery_example.py @@ -11,7 +11,7 @@ query_template="SELECT * from dataset.flyte_table3;", ) -flyte.TaskEnvironment.from_task("bigquery_env", bigquery_task) +bigquery_env = flyte.TaskEnvironment.from_task("bigquery_env", bigquery_task) if __name__ == "__main__": diff --git a/plugins/airflow/README.md b/plugins/airflow/README.md new file mode 100644 index 000000000..38f3d3ba3 --- /dev/null +++ b/plugins/airflow/README.md @@ -0,0 +1,41 @@ +# Flyte Airflow Plugin +
+Run existing Airflow DAGs on Flyte with minimal code changes. 
The plugin +monkey-patches `airflow.DAG` and `BaseOperator` so that standard Airflow +definitions are transparently converted into Flyte tasks. + +## Features + +- Write a normal `with DAG(...) as dag:` block — the plugin intercepts + operator construction and wires everything into a Flyte workflow. +- Supports `BashOperator` (`AirflowShellTask`) and `PythonOperator` + (`AirflowPythonFunctionTask`). +- Dependency arrows (`>>`, `<<`) are preserved as execution order. +- Runs locally or remotely — no Airflow cluster required. + +## Installation + +```bash +pip install flyteplugins-airflow +``` + +## Quick start + +```python +import flyteplugins.airflow.task # triggers DAG + operator monkey-patches +from airflow import DAG +from airflow.operators.bash import BashOperator +import flyte + +with DAG(dag_id="my_dag") as dag: + t1 = BashOperator(task_id="step1", bash_command="echo step1") + t2 = BashOperator(task_id="step2", bash_command="echo step2") + t1 >> t2 + +if __name__ == "__main__": + flyte.init_from_config() + run = flyte.with_runcontext(mode="remote").run(dag) + print(run.url) +``` + +See `examples/airflow-migration/` for a full example including `PythonOperator`. 
diff --git a/plugins/airflow/pyproject.toml b/plugins/airflow/pyproject.toml new file mode 100644 index 000000000..fe1fc611e --- /dev/null +++ b/plugins/airflow/pyproject.toml @@ -0,0 +1,79 @@ +[project] +name = "flyteplugins-airflow" +dynamic = ["version"] +description = "Airflow plugin for flyte" +readme = "README.md" +authors = [{ name = "Kevin Su", email = "pingsutw@users.noreply.github.com" }] +requires-python = ">=3.10,<3.13" +dependencies = [ + "apache-airflow", + "flyte", + "jsonpickle" +] + +[dependency-groups] +dev = [ + "pytest>=8.3.5", + "pytest-asyncio>=0.26.0", +] + +[build-system] +requires = ["setuptools", "setuptools_scm"] +build-backend = "setuptools.build_meta" + +[tool.setuptools] +include-package-data = true +license-files = ["licenses/*.txt", "LICENSE"] + +[tool.setuptools.packages.find] +where = ["src"] +include = ["flyteplugins*"] + +[tool.setuptools_scm] +root = "../../" + +[tool.pytest.ini_options] +norecursedirs = [] +log_cli = true +log_cli_level = 20 +markers = [] +asyncio_default_fixture_loop_scope = "function" + +[tool.coverage.run] +branch = true + +[tool.ruff] +line-length = 120 + +[tool.ruff.lint] +select = [ + "E", + "W", + "F", + "I", + "PLW", + "YTT", + "ASYNC", + "C4", + "T10", + "EXE", + "ISC", + "LOG", + "PIE", + "Q", + "RSE", + "FLY", + "PGH", + "PLC", + "PLE", + "PLW", + "FURB", + "RUF", +] +ignore = ["PGH003", "PLC0415"] + +[tool.ruff.lint.per-file-ignores] +"examples/*" = ["E402"] + +[tool.uv.sources] +flyte = { path = "../../", editable = true } diff --git a/plugins/airflow/src/flyteplugins/airflow/__init__.py b/plugins/airflow/src/flyteplugins/airflow/__init__.py new file mode 100644 index 000000000..8b1378917 --- /dev/null +++ b/plugins/airflow/src/flyteplugins/airflow/__init__.py @@ -0,0 +1 @@ + diff --git a/plugins/airflow/src/flyteplugins/airflow/dag.py b/plugins/airflow/src/flyteplugins/airflow/dag.py new file mode 100644 index 000000000..cb86d92cc --- /dev/null +++ b/plugins/airflow/src/flyteplugins/airflow/dag.py 
@@ -0,0 +1,264 @@ +"""
+Monkey-patches airflow.DAG so that a standard Airflow DAG definition is transparently
+converted into a runnable Flyte task, with no changes to the DAG code required.
+
+Usage
+-----
+    import flyteplugins.airflow.task  # noqa: F401 — triggers DAG + operator monkey-patches
+    from airflow import DAG
+    from airflow.operators.bash import BashOperator
+    import flyte
+
+    env = flyte.TaskEnvironment(name="hello_airflow", image=...)
+
+    with DAG(dag_id="my_dag", flyte_env=env) as dag:
+        t1 = BashOperator(task_id="step1", bash_command='echo step1')
+        t2 = BashOperator(task_id="step2", bash_command='echo step2')
+        t1 >> t2  # optional: explicit dependency
+
+    if __name__ == "__main__":
+        flyte.init_from_config()
+        run = flyte.with_runcontext(mode="remote", log_level="10").run(dag)
+        print(run.url)
+
+Notes
+-----
+- ``flyte_env`` is an optional kwarg accepted by the patched DAG. If omitted a
+  default ``TaskEnvironment`` is created using the dag_id as the name and a
+  Debian-base image with ``flyteplugins-airflow`` and ``jsonpickle`` installed.
+- Operator dependency arrows (``>>``, ``<<``) update the execution order.
+  If no explicit dependencies are declared, the operators run in definition order. 
+""" + +from __future__ import annotations + +import inspect +import logging +import sys as _sys +from collections import defaultdict +from typing import TYPE_CHECKING, Dict, List, Optional, Set, Tuple + +from flyte._task import TaskTemplate + +import airflow.models.dag as _airflow_dag_module + +if TYPE_CHECKING: + import types + + from flyteplugins.airflow.task import AirflowPythonFunctionTask, AirflowShellTask + + AirflowTask = AirflowPythonFunctionTask | AirflowShellTask + +log = logging.getLogger(__name__) + +# --------------------------------------------------------------------------- +# Module-level state +# --------------------------------------------------------------------------- + +_CURRENT_FLYTE_DAG = "current_flyte_dag" + +#: Mutable container for the active FlyteDAG (set inside a ``with DAG(...)`` block). +#: Using a dict avoids ``global`` statements in the patch functions. +_state: Dict[str, Optional["FlyteDAG"]] = {_CURRENT_FLYTE_DAG: None} + + +# --------------------------------------------------------------------------- +# FlyteDAG - collects operators and builds the Flyte workflow task +# --------------------------------------------------------------------------- + + +class FlyteDAG: + """Collects Airflow operators during a DAG definition and converts them + into a single Flyte task that runs them in dependency order.""" + + def __init__(self, dag_id: str, env=None) -> None: + self.dag_id = dag_id + self.env = env + # Ordered dict preserves insertion (creation) order as the default. 
+ self._tasks: Dict[str, "AirflowTask"] = {} + # task_id -> set of upstream task_ids + self._upstream: Dict[str, Set[str]] = defaultdict(set) + + # ------------------------------------------------------------------ + # Registration (called by _flyte_operator during DAG definition) + # ------------------------------------------------------------------ + + def add_task(self, task_id: str, task: "AirflowTask") -> None: + self._tasks[task_id] = task + # Ensure a dependency entry exists even with no upstream tasks. + _ = self._upstream[task_id] + + def set_dependency(self, upstream_id: str, downstream_id: str) -> None: + """Record that *upstream_id* must run before *downstream_id*.""" + self._upstream[downstream_id].add(upstream_id) + + # ------------------------------------------------------------------ + # Flyte task construction + # ------------------------------------------------------------------ + + def _build_downstream_map(self) -> Dict[str, List[str]]: + """Invert ``self._upstream`` to get downstream adjacency lists.""" + downstream: Dict[str, List[str]] = defaultdict(list) + for tid, upstreams in self._upstream.items(): + for up in upstreams: + downstream[up].append(tid) + return downstream + + def _find_caller_module(self) -> Tuple[str, Optional["types.ModuleType"]]: + """Walk the call stack to find the first frame outside this module. + + The Flyte task must be registered under the *user's* module so that + ``DefaultTaskResolver`` can locate it via ``getattr(module, name)`` + on the remote worker (which re-imports the module and re-runs the + DAG definition). 
+ """ + for fi in inspect.stack(): + mod = fi.frame.f_globals.get("__name__", "") + if mod and mod != __name__: + return mod, _sys.modules.get(mod) + return __name__, None + + def _create_dag_entry( + self, + all_tasks: Dict[str, "AirflowTask"], + downstream_map: Dict[str, List[str]], + root_tasks: List["AirflowTask"], + ): + """Build the async entry function that orchestrates task execution.""" + # Snapshot to avoid capturing mutable references in the closure. + root_snapshot = list(root_tasks) + downstream_snapshot = dict(downstream_map) + + async def _dag_entry() -> None: + import asyncio + + async def _run_chain(task): + await task.aio() + ds = downstream_snapshot.get(task.name, []) + if ds: + await asyncio.gather(*[_run_chain(all_tasks[d]) for d in ds]) + + await asyncio.gather(*[_run_chain(t) for t in root_snapshot]) + + caller_module_name, caller_module = self._find_caller_module() + _dag_entry.__name__ = f"dag_{self.dag_id}" + _dag_entry.__qualname__ = f"dag_{self.dag_id}" + _dag_entry.__module__ = caller_module_name + return _dag_entry, caller_module + + def build(self) -> None: + """Create a Flyte workflow task whose entry function runs all + operator tasks in dependency order. + + The entry function captures the full dependency graph in its closure + and orchestrates execution directly, starting from root tasks and + chaining downstream tasks after each completes. This ensures + correct ordering in both local and remote execution (where sub-tasks + are resolved independently and lose their in-memory references). + """ + import flyte + + if self.env is None: + self.env = flyte.TaskEnvironment( + name=self.dag_id, + image=flyte.Image.from_debian_base() + .with_pip_packages("apache-airflow<3.0.0", "jsonpickle") + .with_local_v2(), + ) + + downstream = self._build_downstream_map() + + # Root tasks: those with no upstream dependencies. 
+ root_tasks = [self._tasks[tid] for tid, ups in self._upstream.items() if len(ups) == 0] + + _dag_entry, caller_module = self._create_dag_entry( + all_tasks=dict(self._tasks), + downstream_map=downstream, + root_tasks=root_tasks, + ) + + # Set image and register operator tasks with the DAG's TaskEnvironment. + for _op_task in self._tasks.values(): + if _op_task.image is None: + _op_task.image = self.env.image + self.env.add_dependency(flyte.TaskEnvironment.from_task(_op_task.name, _op_task)) + + self.flyte_task = self.env.task(_dag_entry) + + # Inject the task into the caller's module so DefaultTaskResolver + # can find it via getattr(module, task_name) on both local and remote. + if caller_module is not None: + setattr(caller_module, _dag_entry.__name__, self.flyte_task) + + +# --------------------------------------------------------------------------- +# Proxy class — makes DAG instances pass isinstance(dag, TaskTemplate) +# --------------------------------------------------------------------------- + + +# All names defined on TaskTemplate (fields, methods, properties) that should +# be proxied to the underlying flyte_task rather than resolved on the DAG. +_TASK_TEMPLATE_NAMES = frozenset(name for name in dir(TaskTemplate) if not name.startswith("_")) | frozenset( + TaskTemplate.__dataclass_fields__.keys() +) + + +class _FlyteDAG(_airflow_dag_module.DAG, TaskTemplate): + """Makes an Airflow DAG pass ``isinstance(dag, TaskTemplate)`` by proxying + TaskTemplate attribute access to the attached ``flyte_task``. 
+ """ + + def __getattribute__(self, name): + if name in _TASK_TEMPLATE_NAMES: + ft = object.__getattribute__(self, "__dict__").get("flyte_task") + if ft is not None: + return getattr(ft, name) + return super().__getattribute__(name) + + +_original_dag_init = _airflow_dag_module.DAG.__init__ +_original_dag_enter = _airflow_dag_module.DAG.__enter__ +_original_dag_exit = _airflow_dag_module.DAG.__exit__ + + +def _patched_dag_init(self, *args, **kwargs) -> None: # type: ignore[override] + # Pull out our custom kwarg before passing the rest to Airflow. + flyte_env = kwargs.pop("flyte_env", None) + _original_dag_init(self, *args, **kwargs) + self._flyte_env = flyte_env + + +def _patched_dag_enter(self): # type: ignore[override] + _state[_CURRENT_FLYTE_DAG] = FlyteDAG(dag_id=self.dag_id, env=getattr(self, "_flyte_env", None)) + return _original_dag_enter(self) + + +def _patched_dag_exit(self, exc_type, exc_val, exc_tb): # type: ignore[override] + try: + if exc_type is None and _state[_CURRENT_FLYTE_DAG] is not None: + flyte_dag = _state[_CURRENT_FLYTE_DAG] + flyte_dag.build() + # Attach the Flyte task and a convenience run() to the DAG object, + # then swap __class__ so the DAG passes isinstance(dag, TaskTemplate). 
+ self.flyte_task = flyte_dag.flyte_task + self.run = _make_run(flyte_dag.flyte_task) + self.__class__ = _FlyteDAG + finally: + _state[_CURRENT_FLYTE_DAG] = None + + return _original_dag_exit(self, exc_type, exc_val, exc_tb) + + +def _make_run(flyte_task): + """Return a ``run(**kwargs)`` helper bound to *flyte_task*.""" + import flyte + + def run(**kwargs): + return flyte.with_runcontext(**kwargs).run(flyte_task) + + return run + + +_airflow_dag_module.DAG.__init__ = _patched_dag_init +_airflow_dag_module.DAG.__enter__ = _patched_dag_enter +_airflow_dag_module.DAG.__exit__ = _patched_dag_exit diff --git a/plugins/airflow/src/flyteplugins/airflow/task.py b/plugins/airflow/src/flyteplugins/airflow/task.py new file mode 100644 index 000000000..4d0f43b6d --- /dev/null +++ b/plugins/airflow/src/flyteplugins/airflow/task.py @@ -0,0 +1,274 @@ +import importlib +import typing +from dataclasses import dataclass +from pathlib import Path +from typing import Any, List, Optional + +import flyte +import jsonpickle +from flyte import get_custom_context, logger +from flyte._context import internal_ctx +from flyte._internal.resolvers.common import Resolver +from flyte._module import extract_obj_module +from flyte._task import TaskTemplate +from flyte.extend import AsyncFunctionTaskTemplate +from flyte.models import NativeInterface, SerializationContext + +import airflow.models as airflow_models +import airflow.sensors.base as airflow_sensors +import airflow.triggers.base as airflow_triggers +import airflow.utils.context as airflow_context +from airflow.operators.bash import BashOperator +from airflow.operators.python import PythonOperator + +# Import dag module to apply DAG monkey-patches when this module is imported. 
+from flyteplugins.airflow import dag as _dag_module + +# --------------------------------------------------------------------------- +# Data models +# --------------------------------------------------------------------------- + + +@dataclass +class AirflowTaskMetadata: + """Stores the Airflow operator class location and constructor kwargs. + + For example, given:: + + FileSensor(task_id="id", filepath="/tmp/1234") + + the fields would be: + module: "airflow.sensors.filesystem" + name: "FileSensor" + parameters: {"task_id": "id", "filepath": "/tmp/1234"} + """ + + module: str + name: str + parameters: typing.Dict[str, Any] + + +# --------------------------------------------------------------------------- +# Resolver +# --------------------------------------------------------------------------- + + +class AirflowPythonTaskResolver(Resolver): + """Resolves an AirflowPythonFunctionTask on the remote worker. + + The resolver records the Airflow operator metadata and the wrapped Python + callable so that the task can be reconstructed from loader args alone. 
+ """ + + @property + def import_path(self) -> str: + return "flyteplugins.airflow.task.AirflowPythonTaskResolver" + + def load_task(self, loader_args: typing.List[str]) -> AsyncFunctionTaskTemplate: + _, airflow_task_module, _, airflow_task_name, _, airflow_task_parameters, _, func_module, _, func_name = ( + loader_args + ) + func_module = importlib.import_module(name=func_module) + func_def = getattr(func_module, func_name) + return AirflowPythonFunctionTask( + name=airflow_task_name, + airflow_task_metadata=AirflowTaskMetadata( + module=airflow_task_module, + name=airflow_task_name, + parameters=jsonpickle.decode(airflow_task_parameters), + ), + func=func_def, + ) + + def loader_args(self, task: "AirflowPythonFunctionTask", root_dir: Path) -> List[str]: # type:ignore + entity_module_name, _ = extract_obj_module(task.func, root_dir) + return [ + "airflow-task-module", + task.airflow_task_metadata.module, + "airflow-task-name", + task.airflow_task_metadata.name, + "airflow-task-parameters", + jsonpickle.encode(task.airflow_task_metadata.parameters), + "airflow-func-module", + entity_module_name, + "airflow-func-name", + task.func.__name__, + ] + + +# --------------------------------------------------------------------------- +# Shared task behaviour (mixin) +# --------------------------------------------------------------------------- + + +class _AirflowTaskMixin: + """Shared behaviour for both raw-container and function Airflow tasks. + + Provides Airflow-style dependency arrows (``>>`` / ``<<``) and the + ``ExecutorSafeguard`` workaround needed when tasks run on background threads. + """ + + def _init_airflow_mixin(self) -> None: + self._call_as_synchronous = True + + # Airflow dependency-arrow support (>> / <<) + # Records the dependency in the active FlyteDAG if one is being built. 
+ + def __rshift__(self, other: "AirflowPythonFunctionTask") -> "AirflowPythonFunctionTask": + """``self >> other`` — other runs after self.""" + if _dag_module._state[_dag_module._CURRENT_FLYTE_DAG] is not None: + _dag_module._state[_dag_module._CURRENT_FLYTE_DAG].set_dependency(self.name, other.name) + return other + + def __lshift__(self, other: "AirflowPythonFunctionTask") -> "AirflowPythonFunctionTask": + """``self << other`` — self runs after other.""" + if _dag_module._state[_dag_module._CURRENT_FLYTE_DAG] is not None: + _dag_module._state[_dag_module._CURRENT_FLYTE_DAG].set_dependency(other.name, self.name) + return other + + @staticmethod + def _patch_executor_safeguard() -> None: + """Ensure ExecutorSafeguard's thread-local has a ``callers`` dict. + + ExecutorSafeguard stores a sentinel in a ``threading.local()`` dict + that is initialised on the main thread at import time. Tasks may run + on a background thread where the thread-local has no ``callers`` key. + """ + from airflow.models.baseoperator import ExecutorSafeguard + + if not hasattr(ExecutorSafeguard._sentinel, "callers"): + ExecutorSafeguard._sentinel.callers = {} + + +# --------------------------------------------------------------------------- +# Task classes +# --------------------------------------------------------------------------- + + +class AirflowShellTask(_AirflowTaskMixin, TaskTemplate): + """Wraps an Airflow BashOperator as a Flyte raw-container task.""" + + def __init__( + self, + name: str, + airflow_task_metadata: AirflowTaskMetadata, + command: str, + **kwargs, + ): + super().__init__( + name=name, + interface=NativeInterface(inputs={}, outputs={}), + **kwargs, + ) + self._init_airflow_mixin() + self._airflow_task_metadata = airflow_task_metadata + self._command = command + + def container_args(self, sctx: SerializationContext) -> List[str]: + return self._command.split() + + async def execute(self, **kwargs) -> Any: + self._patch_executor_safeguard() + logger.info("Executing 
Airflow bash operator") + _get_airflow_instance(self._airflow_task_metadata).execute(context=airflow_context.Context()) + + +class AirflowPythonFunctionTask(_AirflowTaskMixin, AsyncFunctionTaskTemplate): + """Wraps an Airflow PythonOperator as a Flyte function task. + + The airflow task module, name, and parameters are stored in the task + config. Some Airflow operators are not deferrable (e.g. + ``BeamRunJavaPipelineOperator``). These tasks lack an async method to + poll job status so they cannot use the Flyte connector — we run them in + a container instead. + """ + + def __init__( + self, + name: str, + airflow_task_metadata: AirflowTaskMetadata, + func: Optional[callable], + **kwargs, + ): + super().__init__( + name=name, + func=func, + interface=NativeInterface(inputs={}, outputs={}), + **kwargs, + ) + self._init_airflow_mixin() + self.resolver = AirflowPythonTaskResolver() + self.airflow_task_metadata = airflow_task_metadata + + async def execute(self, **kwargs) -> Any: + logger.info("Executing Airflow python task") + self.airflow_task_metadata.parameters["python_callable"] = self.func + _get_airflow_instance(self.airflow_task_metadata).execute(context=airflow_context.Context()) + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _get_airflow_instance( + airflow_task_metadata: AirflowTaskMetadata, +) -> typing.Union[airflow_models.BaseOperator, airflow_sensors.BaseSensorOperator, airflow_triggers.BaseTrigger]: + """Instantiate the original Airflow operator from its metadata.""" + # Set GET_ORIGINAL_TASK so that obj_def returns the real Airflow + # operator instead of being intercepted by _flyte_operator. 
+ with flyte.custom_context(GET_ORIGINAL_TASK="True"): + obj_module = importlib.import_module(name=airflow_task_metadata.module) + obj_def = getattr(obj_module, airflow_task_metadata.name) + return obj_def(**airflow_task_metadata.parameters) + + +# --------------------------------------------------------------------------- +# Operator intercept (monkey-patch) +# --------------------------------------------------------------------------- + + +def _flyte_operator(*args, **kwargs): + """Intercept Airflow operator construction and return a Flyte task instead. + + Called via the monkey-patched ``BaseOperator.__new__``. Depending on + context this either registers the task with an active FlyteDAG, submits + it as a sub-task during execution, or returns the task object for later + serialization. + """ + cls = args[0] + if get_custom_context().get("GET_ORIGINAL_TASK", "False") == "True": + logger.debug("Returning original Airflow task") + return object.__new__(cls) + + container_image = kwargs.pop("container_image", None) + task_id = kwargs.get("task_id", cls.__name__) + airflow_task_metadata = AirflowTaskMetadata(module=cls.__module__, name=cls.__name__, parameters=kwargs) + + if cls == BashOperator: + command = kwargs.get("bash_command", "") + task = AirflowShellTask(name=task_id, airflow_task_metadata=airflow_task_metadata, command=command) + elif cls == PythonOperator: + func = kwargs.get("python_callable", None) + kwargs.pop("python_callable", None) + task = AirflowPythonFunctionTask( + name=task_id, airflow_task_metadata=airflow_task_metadata, func=func, image=container_image + ) + else: + raise ValueError(f"Unsupported Airflow operator: {cls.__name__}") + + # Case 1: inside a ``with DAG(...) as dag:`` block — register with FlyteDAG. 
+ if _dag_module._state[_dag_module._CURRENT_FLYTE_DAG] is not None: + _dag_module._state[_dag_module._CURRENT_FLYTE_DAG].add_task(task_id, task) + return task + + # Case 2: inside a Flyte task execution — submit the operator as a sub-task. + if internal_ctx().is_task_context(): + return task() + + # Case 3: outside any context (e.g. serialization / import scan). + return task + + +# Monkey-patch: intercept Airflow operator construction. +airflow_models.BaseOperator.__new__ = _flyte_operator diff --git a/plugins/airflow/tests/test_airflow_dag.py b/plugins/airflow/tests/test_airflow_dag.py new file mode 100644 index 000000000..0952bd3a3 --- /dev/null +++ b/plugins/airflow/tests/test_airflow_dag.py @@ -0,0 +1,86 @@ +""" +Tests for the Airflow DAG monkey-patch in flyteplugins.airflow.dag. +""" + +import flyte +import pytest +from flyte._image import Image + +import flyteplugins.airflow.task # noqa: F401 — triggers DAG + operator monkey-patches + + +@pytest.fixture(autouse=True) +def _reset_environment_registry(): + """Clean up TaskEnvironment instances created during each test.""" + from flyte._environment import _ENVIRONMENT_REGISTRY + + initial_len = len(_ENVIRONMENT_REGISTRY) + yield + del _ENVIRONMENT_REGISTRY[initial_len:] + + +def test_flyte_env_image_preserved_after_dag_build(): + """The image supplied via flyte_env= must survive FlyteDAG.build(). + + Previously, build() called TaskEnvironment.from_task() which derived a new + environment from the operator tasks' images (all None), silently discarding + the user-supplied image. 
+ """ + from airflow import DAG + from airflow.operators.bash import BashOperator + + custom_image = Image.from_debian_base().with_pip_packages("apache-airflow<3.0.0") + env = flyte.TaskEnvironment(name="test-dag-env", image=custom_image) + + with DAG(dag_id="test_image_preserved", flyte_env=env) as dag: + BashOperator(task_id="say_hello", bash_command="echo hello") + + assert dag.flyte_task is not None + parent_env = dag.flyte_task.parent_env() + assert parent_env is not None, "flyte_task must be attached to an environment" + assert parent_env.image is custom_image, ( + f"Expected the user-supplied image to be preserved, got: {parent_env.image}" + ) + + # Operator tasks must also inherit the env's image so they can be + # serialized correctly when submitted as sub-tasks during remote execution. + for task_name, op_task in parent_env.tasks.items(): + if task_name != f"{env.name}.dag_test_image_preserved": + assert op_task.image is custom_image, ( + f"Operator task {task_name!r} did not inherit env image: {op_task.image}" + ) + + +def test_operator_tasks_registered_in_env(monkeypatch): + """Operator tasks must appear in env.tasks so they are included in deployment.""" + from airflow import DAG + from airflow.operators.bash import BashOperator + + env = flyte.TaskEnvironment(name="test-dag-tasks-env") + + with DAG(dag_id="test_tasks_registered", flyte_env=env) as dag: + BashOperator(task_id="step1", bash_command="echo step1") + BashOperator(task_id="step2", bash_command="echo step2") + + parent_env = dag.flyte_task.parent_env() + # Both operator tasks and the orchestrator task must be in env.tasks. 
+ env_task_names = list(parent_env.tasks.keys()) + assert any("step1" in name for name in env_task_names), f"step1 not found in env.tasks: {env_task_names}" + assert any("step2" in name for name in env_task_names), f"step2 not found in env.tasks: {env_task_names}" + + +def test_default_env_created_when_flyte_env_omitted(): + """When flyte_env is not supplied, a default TaskEnvironment is created using + the dag_id as the name and a Debian-base image with airflow packages.""" + from airflow import DAG + from airflow.operators.bash import BashOperator + + with DAG(dag_id="test_default_env") as dag: + BashOperator(task_id="greet", bash_command="echo hi") + + assert dag.flyte_task is not None + parent_env = dag.flyte_task.parent_env() + assert parent_env is not None + assert parent_env.name == "test_default_env" + # Default env should have a real image (not None or "auto") + assert isinstance(parent_env.image, Image) diff --git a/plugins/anthropic/src/flyteplugins/anthropic/__init__.py b/plugins/anthropic/src/flyteplugins/anthropic/__init__.py deleted file mode 100644 index 8f74819e4..000000000 --- a/plugins/anthropic/src/flyteplugins/anthropic/__init__.py +++ /dev/null @@ -1,9 +0,0 @@ -"""Anthropic Claude plugin for Flyte. - -This plugin provides integration between Flyte tasks and Anthropic's Claude API, -enabling you to use Flyte tasks as tools for Claude agents. 
-""" - -from .agents import Agent, function_tool, run_agent - -__all__ = ["Agent", "function_tool", "run_agent"] diff --git a/plugins/anthropic/src/flyteplugins/anthropic/agents/_function_tools.py b/plugins/anthropic/src/flyteplugins/anthropic/agents/_function_tools.py index d4c4c948d..2a7152338 100644 --- a/plugins/anthropic/src/flyteplugins/anthropic/agents/_function_tools.py +++ b/plugins/anthropic/src/flyteplugins/anthropic/agents/_function_tools.py @@ -13,11 +13,10 @@ from dataclasses import dataclass, field from functools import partial +import anthropic from flyte._task import AsyncFunctionTaskTemplate from flyte.models import NativeInterface -import anthropic - logger = logging.getLogger(__name__) diff --git a/src/flyte/_internal/resolvers/common.py b/src/flyte/_internal/resolvers/common.py index d2b73a6c3..4dead7390 100644 --- a/src/flyte/_internal/resolvers/common.py +++ b/src/flyte/_internal/resolvers/common.py @@ -30,7 +30,7 @@ def load_app_env(self, loader_args: str) -> AppEnvironment: """ raise NotImplementedError - def loader_args(self, t: TaskTemplate, root_dir: Optional[Path]) -> List[str] | str: + def loader_args(self, task: TaskTemplate, root_dir: Optional[Path]) -> List[str] | str: """ Return a list of strings that can help identify the parameter TaskTemplate. Each string should not have spaces or special characters. This is used to identify the task in the resolver. 
diff --git a/src/flyte/_run.py b/src/flyte/_run.py index 504239f46..fdc08b280 100644 --- a/src/flyte/_run.py +++ b/src/flyte/_run.py @@ -660,8 +660,8 @@ async def example_task(x: int, y: str) -> str: if isinstance(task, (LazyEntity, TaskDetails)) and self._mode != "remote": raise ValueError("Remote task can only be run in remote mode.") - if not isinstance(task, TaskTemplate) and not isinstance(task, (LazyEntity, TaskDetails)): - raise TypeError(f"On Flyte tasks can be run, not generic functions or methods '{type(task)}'.") + if not isinstance(task, (TaskTemplate, LazyEntity, TaskDetails)): + raise TypeError(f"Only Flyte tasks can be run, not '{type(task)}'.") # Set the run mode in the context variable so that offloaded types (files, directories, dataframes) # can check the mode for controlling auto-uploading behavior (only enabled in remote mode). diff --git a/src/flyte/_task.py b/src/flyte/_task.py index c21477cb2..3a34dcbe3 100644 --- a/src/flyte/_task.py +++ b/src/flyte/_task.py @@ -42,6 +42,7 @@ if TYPE_CHECKING: from flyteidl2.core.tasks_pb2 import DataLoadingConfig + from ._internal.resolvers.common import Resolver from ._task_environment import TaskEnvironment P = ParamSpec("P") # capture the function's parameters @@ -111,6 +112,7 @@ def my_task(): report: bool = False queue: Optional[str] = None debuggable: bool = False + resolver: Optional[Resolver] = None parent_env: Optional[weakref.ReferenceType[TaskEnvironment]] = None parent_env_name: Optional[str] = None @@ -147,6 +149,10 @@ def __post_init__(self): # If short_name is not set, use the name of the task self.short_name = self.name + from ._internal.resolvers.default import DefaultTaskResolver + + self.resolver = self.resolver or DefaultTaskResolver() + def __getstate__(self): """ This method is called when the object is pickled. 
We need to remove the parent_env reference @@ -564,7 +570,7 @@ def container_args(self, serialize_context: SerializationContext) -> List[str]: "SerializationError", "Root dir is required for default task resolver when no code bundle is provided.", ) - _task_resolver = DefaultTaskResolver() + _task_resolver = self.resolver or DefaultTaskResolver() args = [ *args, *[