From c1c159585e63861c175c2c8d8c3f7ed3ef780308 Mon Sep 17 00:00:00 2001 From: Kevin Su Date: Tue, 10 Feb 2026 22:20:36 -0800 Subject: [PATCH 01/17] wip Signed-off-by: Kevin Su --- examples/airflow-migration/bash_operator.py | 15 ++ plugins/airflow/README.md | 12 ++ plugins/airflow/pyproject.toml | 79 +++++++ .../src/flyteplugins/airflow/__init__.py | 3 + .../airflow/src/flyteplugins/airflow/task.py | 200 ++++++++++++++++++ 5 files changed, 309 insertions(+) create mode 100644 examples/airflow-migration/bash_operator.py create mode 100644 plugins/airflow/README.md create mode 100644 plugins/airflow/pyproject.toml create mode 100644 plugins/airflow/src/flyteplugins/airflow/__init__.py create mode 100644 plugins/airflow/src/flyteplugins/airflow/task.py diff --git a/examples/airflow-migration/bash_operator.py b/examples/airflow-migration/bash_operator.py new file mode 100644 index 000000000..89cd49ae8 --- /dev/null +++ b/examples/airflow-migration/bash_operator.py @@ -0,0 +1,15 @@ +from airflow import DAG +from airflow.operators.bash import BashOperator +from pendulum import datetime + +with DAG( + dag_id='simple_bash_operator_example', + start_date=datetime(2025, 1, 1), + schedule=None, + catchup=False, +) as dag: + # Define the BashOperator task + hello_task = BashOperator( + task_id='say_hello', + bash_command='echo "Hello Airflow!"', + ) diff --git a/plugins/airflow/README.md b/plugins/airflow/README.md new file mode 100644 index 000000000..95349163b --- /dev/null +++ b/plugins/airflow/README.md @@ -0,0 +1,12 @@ +# Flyte Airflow Plugin + +Airflow plugin allows you to seamlessly run Airflow tasks in the Flyte workflow without changing any code. + +- Compile Airflow tasks to Flyte tasks +- Use Airflow sensors/operators in Flyte workflows +- Add support for running Airflow tasks locally without running a cluster + + +```bash +pip install --pre flyteplugins-airflow +``` diff --git a/plugins/airflow/pyproject.toml b/plugins/airflow/pyproject.toml new file mode 100644 index 000000000..fe1fc611e --- /dev/null +++ b/plugins/airflow/pyproject.toml @@ -0,0 +1,79 @@ +[project] +name = "flyteplugins-airflow" +dynamic = ["version"] +description = "Airflow plugin for flyte" +readme = "README.md" +authors = [{ name = "Kevin Su", email = "pingsutw@users.noreply.github.com" }] +requires-python = ">=3.10,<3.13" +dependencies = [ + "apache-airflow", + "flyte", + "jsonpickle" +] + +[dependency-groups] +dev = [ + "pytest>=8.3.5", + "pytest-asyncio>=0.26.0", +] + +[build-system] +requires = ["setuptools", "setuptools_scm"] +build-backend = "setuptools.build_meta" + +[tool.setuptools] +include-package-data = true +license-files = ["licenses/*.txt", "LICENSE"] + +[tool.setuptools.packages.find] +where = ["src"] +include = ["flyteplugins*"] + +[tool.setuptools_scm] +root = "../../" + +[tool.pytest.ini_options] +norecursedirs = [] +log_cli = true +log_cli_level = 20 +markers = [] +asyncio_default_fixture_loop_scope = "function" + +[tool.coverage.run] +branch = true + +[tool.ruff] +line-length = 120 + +[tool.ruff.lint] +select = [ + "E", + "W", + "F", + "I", + "PLW", + "YTT", + "ASYNC", + "C4", + "T10", + "EXE", + "ISC", + "LOG", + "PIE", + "Q", + "RSE", + "FLY", + "PGH", + "PLC", + "PLE", + "PLW", + "FURB", + "RUF", +] +ignore = ["PGH003", "PLC0415"] + +[tool.ruff.lint.per-file-ignores] +"examples/*" = ["E402"] + +[tool.uv.sources] +flyte = { path = "../../", editable = true } diff --git a/plugins/airflow/src/flyteplugins/airflow/__init__.py b/plugins/airflow/src/flyteplugins/airflow/__init__.py new file mode 100644 index 000000000..13dac4a10 --- /dev/null +++ b/plugins/airflow/src/flyteplugins/airflow/__init__.py @@ -0,0 +1,3 @@ +from flyteplugins.ray.task import HeadNodeConfig, RayJobConfig, WorkerNodeConfig + +__all__ = ["HeadNodeConfig", "RayJobConfig", "WorkerNodeConfig"] diff --git a/plugins/airflow/src/flyteplugins/airflow/task.py b/plugins/airflow/src/flyteplugins/airflow/task.py new file mode 100644 index 000000000..b166877ef --- /dev/null +++ b/plugins/airflow/src/flyteplugins/airflow/task.py @@ -0,0 +1,200 @@ +import importlib +import logging +import typing +from dataclasses import dataclass +from typing import Any, Dict, Optional, Type + +import airflow + +import airflow.models as airflow_models +import airflow.sensors.base as airflow_sensors +import jsonpickle +from airflow.triggers.base as airflow_triggers +import airflow.utils.context as airflow_context + +from flyte import logger +from flyte._internal.resolvers.common import Resolver +from flyte._task import TaskTemplate +from flyte.extend import AsyncFunctionTaskTemplate, TaskPluginRegistry + + +@dataclass +class AirflowObj(object): + """ + This class is used to store the Airflow task configuration. It is serialized and stored in the Flyte task config. + It can be trigger, hook, operator or sensor. For example: + + from airflow.sensors.filesystem import FileSensor + sensor = FileSensor(task_id="id", filepath="/tmp/1234") + + In this case, the attributes of AirflowObj will be: + module: airflow.sensors.filesystem + name: FileSensor + parameters: {"task_id": "id", "filepath": "/tmp/1234"} + """ + + module: str + name: str + parameters: typing.Dict[str, Any] + + +class AirflowTaskResolver(Resolver): + """ + This class is used to resolve an Airflow task. It will load an airflow task in the container. + """ + + @property + def import_path(self) -> str: + return "flyteplugins.airflow.task.AirflowTaskResolver" + + def load_task(self, loader_args: typing.List[str]) -> TaskTemplate: + """ + This method is used to load an Airflow task. + """ + _, task_module, _, task_name, _, task_config = loader_args + task_module = importlib.import_module(name=task_module) # type: ignore + task_def = getattr(task_module, task_name) + return task_def(name=task_name, task_config=jsonpickle.decode(task_config)) + + def loader_args(self, task: TaskTemplate, root_dir: Path) -> List[str]: # type:ignore + return [ + "task-module", + task.__module__, + "task-name", + task.__class__.__name__, + "task-config", + jsonpickle.encode(task.task_config), + ] + +airflow_task_resolver = AirflowTaskResolver() + + +class AirflowContainerTask(AsyncFunctionTaskTemplate): + """ + This python container task is used to wrap an Airflow task. It is used to run an Airflow task in a container. + The airflow task module, name and parameters are stored in the task config. + + Some of the Airflow operators are not deferrable, For example, BeamRunJavaPipelineOperator, BeamRunPythonPipelineOperator. + These tasks don't have an async method to get the job status, so cannot be used in the Flyte connector. We run these tasks in a container. + """ + + def __init__( + self, + name: str, + task_config: AirflowObj, + # inputs: Optional[Dict[str, Type]] = None, + **kwargs, + ): + super().__init__( + name=name, + plugin_config=task_config, + # interface=Interface(inputs=inputs or {}), + **kwargs, + ) + self._task_resolver = airflow_task_resolver + + def execute(self, **kwargs) -> Any: + logger.info("Executing Airflow task") + _get_airflow_instance(self.plugin_config).execute(context=airflow_context.Context()) + + +class AirflowTask(PythonTask[AirflowObj]): + """ + This python task is used to wrap an Airflow task. + It is used to run an Airflow task in Flyte connector. + The airflow task module, name and parameters are stored in the task config. + We run the Airflow task in the connector. + """ + + _TASK_TYPE = "airflow" + + def __init__( + self, + name: str, + task_config: Optional[AirflowObj], + inputs: Optional[Dict[str, Type]] = None, + **kwargs, + ): + super().__init__( + name=name, + task_config=task_config, + interface=Interface(inputs=inputs or {}), + task_type=self._TASK_TYPE, + **kwargs, + ) + + def get_custom(self, settings: SerializationSettings) -> Dict[str, Any]: + # Use jsonpickle to serialize the Airflow task config since the return value should be json serializable. + return {"task_config_pkl": jsonpickle.encode(self.task_config)} + + +def _get_airflow_instance( + airflow_obj: AirflowObj, +) -> typing.Union[airflow_models.BaseOperator, airflow_sensors.BaseSensorOperator, airflow_triggers.BaseTrigger]: + # Set the GET_ORIGINAL_TASK attribute to True so that obj_def will return the original + # airflow task instead of the Flyte task. + ctx = FlyteContextManager.current_context() + ctx.user_space_params.builder().add_attr("GET_ORIGINAL_TASK", True).build() + + obj_module = importlib.import_module(name=airflow_obj.module) + obj_def = getattr(obj_module, airflow_obj.name) + if _is_deferrable(obj_def): + try: + return obj_def(**airflow_obj.parameters, deferrable=True) + except airflow.exceptions.AirflowException as e: + logger.debug(f"Failed to create operator {airflow_obj.name} with err: {e}.") + logger.debug(f"Airflow operator {airflow_obj.name} does not support deferring.") + + return obj_def(**airflow_obj.parameters) + + +def _is_deferrable(cls: Type) -> bool: + """ + This function is used to check if the Airflow operator is deferrable. + If the operator is not deferrable, we run it in a container instead of the connector. + """ + # Only Airflow operators are deferrable. + if not issubclass(cls, airflow_models.BaseOperator): + return False + # Airflow sensors are not deferrable. The Sensor is a subclass of BaseOperator. + if issubclass(cls, airflow_sensors.BaseSensorOperator): + return False + try: + from airflow.providers.apache.beam.operators.beam import BeamBasePipelineOperator + + # Dataflow operators are not deferrable. + if issubclass(cls, BeamBasePipelineOperator): + return False + except ImportError: + logger.debug("Failed to import BeamBasePipelineOperator") + return True + + +def _flyte_operator(*args, **kwargs): + """ + This function is called by the Airflow operator to create a new task. We intercept this call and return a Flyte + task instead. + """ + cls = args[0] + try: + if FlyteContextManager.current_context().user_space_params.get_original_task: + # Return an original task when running in the connector. + return object.__new__(cls) + except AssertionError: + # This happens when the task is created in the dynamic workflow. + # We don't need to return the original task in this case. + logging.debug("failed to get the attribute GET_ORIGINAL_TASK from user space params") + + container_image = kwargs.pop("container_image", None) + task_id = kwargs.get("task_id", cls.__name__) + config = AirflowObj(module=cls.__module__, name=cls.__name__, parameters=kwargs) + + if not issubclass(cls, airflow_sensors.BaseSensorOperator) and not _is_deferrable(cls): + # Dataflow operators are not deferrable, so we run them in a container. + return AirflowContainerTask(name=task_id, task_config=config, container_image=container_image)() + return AirflowTask(name=task_id, task_config=config)() + + +# Monkey patches the Airflow operator. Instead of creating an airflow task, it returns a Flyte task. +airflow_models.BaseOperator.__new__ = _flyte_operator + From 6947151854abd7034253656355d9695c10314ff9 Mon Sep 17 00:00:00 2001 From: Kevin Su Date: Tue, 24 Feb 2026 09:33:46 -0800 Subject: [PATCH 02/17] wip Signed-off-by: Kevin Su --- examples/airflow-migration/bash_operator.py | 52 +++++++--- .../src/flyteplugins/airflow/__init__.py | 2 - .../airflow/src/flyteplugins/airflow/task.py | 94 +++++++------------ .../src/flyteplugins/anthropic/__init__.py | 9 -- tests/flyte/test_image_cache.py | 2 +- 5 files changed, 77 insertions(+), 82 deletions(-) delete mode 100644 plugins/anthropic/src/flyteplugins/anthropic/__init__.py diff --git a/examples/airflow-migration/bash_operator.py b/examples/airflow-migration/bash_operator.py index 89cd49ae8..6576ed293 100644 --- a/examples/airflow-migration/bash_operator.py +++ b/examples/airflow-migration/bash_operator.py @@ -1,15 +1,45 @@ +# from flyteplugins.airflow.task import AirflowContainerTask from airflow import DAG +from airflow.models.baseoperator import ExecutorSafeguard from airflow.operators.bash import BashOperator from pendulum import datetime +import airflow.utils.context as airflow_context +import flyte -with DAG( - dag_id='simple_bash_operator_example', - start_date=datetime(2025, 1, 1), - schedule=None, - catchup=False, -) as dag: - # Define the BashOperator task - hello_task = BashOperator( - task_id='say_hello', - bash_command='echo "Hello Airflow!"', - ) +# with DAG( +# dag_id='simple_bash_operator_example', +# start_date=datetime(2025, 1, 1), +# schedule=None, +# catchup=False, +# ) as dag: +# # Define the BashOperator task +# hello_task = BashOperator( +# task_id='say_hello', +# bash_command='echo "Hello Airflow!"', +# ) + + +env = flyte.TaskEnvironment( + name="hello_airflow", +) + + +@env.task +async def fn(name: str) -> None: + print("starting to run airflow task") + # ExecutorSafeguard stores a sentinel in a threading.local() dict. That dict + # is initialised on the main thread at import time, but Flyte runs tasks in a + # background async thread where the thread-local has no 'callers' key yet. + if not hasattr(ExecutorSafeguard._sentinel, "callers"): + ExecutorSafeguard._sentinel.callers = {} + BashOperator( + task_id='airflow', + bash_command=f'echo "Hello {name}!"', + ).execute(context=airflow_context.Context()) + print("finished running airflow task") + + +if __name__ == '__main__': + flyte.init_from_config() + run = flyte.with_runcontext(mode="local", log_level="10").run(fn, name="Airflow") + print(run.url) diff --git a/plugins/airflow/src/flyteplugins/airflow/__init__.py b/plugins/airflow/src/flyteplugins/airflow/__init__.py index 13dac4a10..8b1378917 100644 --- a/plugins/airflow/src/flyteplugins/airflow/__init__.py +++ b/plugins/airflow/src/flyteplugins/airflow/__init__.py @@ -1,3 +1 @@ -from flyteplugins.ray.task import HeadNodeConfig, RayJobConfig, WorkerNodeConfig -__all__ = ["HeadNodeConfig", "RayJobConfig", "WorkerNodeConfig"] diff --git a/plugins/airflow/src/flyteplugins/airflow/task.py b/plugins/airflow/src/flyteplugins/airflow/task.py index b166877ef..ed72e5c7f 100644 --- a/plugins/airflow/src/flyteplugins/airflow/task.py +++ b/plugins/airflow/src/flyteplugins/airflow/task.py @@ -2,20 +2,22 @@ import logging import typing from dataclasses import dataclass -from typing import Any, Dict, Optional, Type +from typing import Any, Dict, Optional, Type, List import airflow - +from pathlib import Path import airflow.models as airflow_models import airflow.sensors.base as airflow_sensors import jsonpickle -from airflow.triggers.base as airflow_triggers +import airflow.triggers.base as airflow_triggers import airflow.utils.context as airflow_context -from flyte import logger +import flyte +from flyte import logger, get_custom_context from flyte._internal.resolvers.common import Resolver from flyte._task import TaskTemplate from flyte.extend import AsyncFunctionTaskTemplate, TaskPluginRegistry +from flyte.models import SerializationContext, NativeInterface @dataclass @@ -47,7 +49,7 @@ class AirflowTaskResolver(Resolver): def import_path(self) -> str: return "flyteplugins.airflow.task.AirflowTaskResolver" - def load_task(self, loader_args: typing.List[str]) -> TaskTemplate: + def load_task(self, loader_args: typing.List[str]) -> AsyncFunctionTaskTemplate: """ This method is used to load an Airflow task. """ @@ -56,20 +58,21 @@ def load_task(self, loader_args: typing.List[str]) -> TaskTemplate: task_def = getattr(task_module, task_name) return task_def(name=task_name, task_config=jsonpickle.decode(task_config)) - def loader_args(self, task: TaskTemplate, root_dir: Path) -> List[str]: # type:ignore + def loader_args(self, task: AsyncFunctionTaskTemplate, root_dir: Path) -> List[str]: # type:ignore return [ "task-module", task.__module__, "task-name", task.__class__.__name__, "task-config", - jsonpickle.encode(task.task_config), + jsonpickle.encode(task.plugin_config), ] + airflow_task_resolver = AirflowTaskResolver() -class AirflowContainerTask(AsyncFunctionTaskTemplate): +class AirflowContainerTask(TaskTemplate): """ This python container task is used to wrap an Airflow task. It is used to run an Airflow task in a container. The airflow task module, name and parameters are stored in the task config. @@ -81,71 +84,47 @@ class AirflowContainerTask(AsyncFunctionTaskTemplate): def __init__( self, name: str, - task_config: AirflowObj, + plugin_config: AirflowObj, # inputs: Optional[Dict[str, Type]] = None, **kwargs, ): super().__init__( name=name, - plugin_config=task_config, - # interface=Interface(inputs=inputs or {}), + # plugin_config=plugin_config, + interface=NativeInterface(inputs={}, outputs={}), **kwargs, ) self._task_resolver = airflow_task_resolver + self._plugin_config = plugin_config def execute(self, **kwargs) -> Any: + # ExecutorSafeguard stores a sentinel in a threading.local() dict. That + # dict is initialised on the main thread at import time, but tasks may + # run in a background thread where the thread-local has no 'callers' key. + from airflow.models.baseoperator import ExecutorSafeguard + if not hasattr(ExecutorSafeguard._sentinel, "callers"): + ExecutorSafeguard._sentinel.callers = {} logger.info("Executing Airflow task") - _get_airflow_instance(self.plugin_config).execute(context=airflow_context.Context()) - - -class AirflowTask(PythonTask[AirflowObj]): - """ - This python task is used to wrap an Airflow task. - It is used to run an Airflow task in Flyte connector. - The airflow task module, name and parameters are stored in the task config. - We run the Airflow task in the connector. - """ - - _TASK_TYPE = "airflow" - - def __init__( - self, - name: str, - task_config: Optional[AirflowObj], - inputs: Optional[Dict[str, Type]] = None, - **kwargs, - ): - super().__init__( - name=name, - task_config=task_config, - interface=Interface(inputs=inputs or {}), - task_type=self._TASK_TYPE, - **kwargs, - ) - - def get_custom(self, settings: SerializationSettings) -> Dict[str, Any]: - # Use jsonpickle to serialize the Airflow task config since the return value should be json serializable. - return {"task_config_pkl": jsonpickle.encode(self.task_config)} + _get_airflow_instance(self._plugin_config).execute(context=airflow_context.Context()) def _get_airflow_instance( - airflow_obj: AirflowObj, + airflow_obj: AirflowObj, ) -> typing.Union[airflow_models.BaseOperator, airflow_sensors.BaseSensorOperator, airflow_triggers.BaseTrigger]: # Set the GET_ORIGINAL_TASK attribute to True so that obj_def will return the original # airflow task instead of the Flyte task. - ctx = FlyteContextManager.current_context() - ctx.user_space_params.builder().add_attr("GET_ORIGINAL_TASK", True).build() + with flyte.custom_context(GET_ORIGINAL_TASK="True"): - obj_module = importlib.import_module(name=airflow_obj.module) - obj_def = getattr(obj_module, airflow_obj.name) - if _is_deferrable(obj_def): - try: - return obj_def(**airflow_obj.parameters, deferrable=True) - except airflow.exceptions.AirflowException as e: - logger.debug(f"Failed to create operator {airflow_obj.name} with err: {e}.") - logger.debug(f"Airflow operator {airflow_obj.name} does not support deferring.") + obj_module = importlib.import_module(name=airflow_obj.module) + obj_def = getattr(obj_module, airflow_obj.name) + if _is_deferrable(obj_def): + try: + return obj_def(**airflow_obj.parameters, deferrable=True) + except airflow.exceptions.AirflowException as e: + logger.debug(f"Failed to create operator {airflow_obj.name} with err: {e}.") + logger.debug(f"Airflow operator {airflow_obj.name} does not support deferring.") - return obj_def(**airflow_obj.parameters) + return obj_def(**airflow_obj.parameters) def _is_deferrable(cls: Type) -> bool: @@ -177,7 +156,7 @@ def _flyte_operator(*args, **kwargs): """ cls = args[0] try: - if FlyteContextManager.current_context().user_space_params.get_original_task: + if get_custom_context().get("GET_ORIGINAL_TASK", "False") == "True": # Return an original task when running in the connector. return object.__new__(cls) except AssertionError: @@ -189,10 +168,7 @@ def _flyte_operator(*args, **kwargs): task_id = kwargs.get("task_id", cls.__name__) config = AirflowObj(module=cls.__module__, name=cls.__name__, parameters=kwargs) - if not issubclass(cls, airflow_sensors.BaseSensorOperator) and not _is_deferrable(cls): - # Dataflow operators are not deferrable, so we run them in a container. - return AirflowContainerTask(name=task_id, task_config=config, container_image=container_image)() - return AirflowTask(name=task_id, task_config=config)() + return AirflowContainerTask(name=task_id, plugin_config=config, image=container_image).execute(**kwargs) # Monkey patches the Airflow operator. Instead of creating an airflow task, it returns a Flyte task. diff --git a/plugins/anthropic/src/flyteplugins/anthropic/__init__.py b/plugins/anthropic/src/flyteplugins/anthropic/__init__.py deleted file mode 100644 index 8f74819e4..000000000 --- a/plugins/anthropic/src/flyteplugins/anthropic/__init__.py +++ /dev/null @@ -1,9 +0,0 @@ -"""Anthropic Claude plugin for Flyte. - -This plugin provides integration between Flyte tasks and Anthropic's Claude API, -enabling you to use Flyte tasks as tools for Claude agents. -""" - -from .agents import Agent, function_tool, run_agent - -__all__ = ["Agent", "function_tool", "run_agent"] diff --git a/tests/flyte/test_image_cache.py b/tests/flyte/test_image_cache.py index 44282d93b..51a0d606d 100644 --- a/tests/flyte/test_image_cache.py +++ b/tests/flyte/test_image_cache.py @@ -18,7 +18,7 @@ def test_image_cache_serialization_round_trip(): # Deserialize back into an ImageCache object # This should also save the serialized form into the object for downstream tasks to get it. - restored_cache = ImageCache.from_transport(serialized) + restored_cache = ImageCache.from_transport("H4sIAAAAAAAC/53MSw6DIBAA0LvMuqK0pB8uQ6YjMQZwCAMxqfHu7cbEdQ/w3gZzwsm7yBxaBrsBFaTgRiZxEucEFu7GPMzL6KcZrmoMRXkqqknnUWqnFSb88IKrKOLUvz0SL73UfNODjVi9VLic1ym3P9OfPMZ9/wKPjjm1ugAAAA==") # Check that the deserialized data matches the original assert restored_cache.image_lookup == original_data["image_lookup"] From 39df347025c704c2e83db8ae74dae2a726055800 Mon Sep 17 00:00:00 2001 From: Kevin Su Date: Tue, 24 Feb 2026 12:24:18 -0800 Subject: [PATCH 03/17] wip Signed-off-by: Kevin Su --- examples/airflow-migration/bash_operator.py | 14 +++++--------- plugins/airflow/src/flyteplugins/airflow/task.py | 4 +++- 2 files changed, 8 insertions(+), 10 deletions(-) diff --git a/examples/airflow-migration/bash_operator.py b/examples/airflow-migration/bash_operator.py index 6576ed293..0cc234c25 100644 --- a/examples/airflow-migration/bash_operator.py +++ b/examples/airflow-migration/bash_operator.py @@ -1,4 +1,4 @@ -# from flyteplugins.airflow.task import AirflowContainerTask +from flyteplugins.airflow.task import AirflowContainerTask # type: ignore from airflow import DAG from airflow.models.baseoperator import ExecutorSafeguard from airflow.operators.bash import BashOperator @@ -21,25 +21,21 @@ env = flyte.TaskEnvironment( name="hello_airflow", + image=flyte.Image.from_debian_base().with_pip_packages("apache-airflow<3.0.0", "jsonpickle").with_local_v2() ) @env.task -async def fn(name: str) -> None: +async def main(name: str) -> None: print("starting to run airflow task") - # ExecutorSafeguard stores a sentinel in a threading.local() dict. That dict - # is initialised on the main thread at import time, but Flyte runs tasks in a - # background async thread where the thread-local has no 'callers' key yet. - if not hasattr(ExecutorSafeguard._sentinel, "callers"): - ExecutorSafeguard._sentinel.callers = {} BashOperator( task_id='airflow', bash_command=f'echo "Hello {name}!"', - ).execute(context=airflow_context.Context()) + ) print("finished running airflow task") if __name__ == '__main__': flyte.init_from_config() - run = flyte.with_runcontext(mode="local", log_level="10").run(fn, name="Airflow") + run = flyte.with_runcontext(mode="remote", log_level="10").run(main, name="Airflow") print(run.url) diff --git a/plugins/airflow/src/flyteplugins/airflow/task.py b/plugins/airflow/src/flyteplugins/airflow/task.py index ed72e5c7f..dcfdc1987 100644 --- a/plugins/airflow/src/flyteplugins/airflow/task.py +++ b/plugins/airflow/src/flyteplugins/airflow/task.py @@ -158,6 +158,7 @@ def _flyte_operator(*args, **kwargs): try: if get_custom_context().get("GET_ORIGINAL_TASK", "False") == "True": # Return an original task when running in the connector. + print("Returning original Airflow task") return object.__new__(cls) except AssertionError: # This happens when the task is created in the dynamic workflow. @@ -168,7 +169,8 @@ def _flyte_operator(*args, **kwargs): task_id = kwargs.get("task_id", cls.__name__) config = AirflowObj(module=cls.__module__, name=cls.__name__, parameters=kwargs) - return AirflowContainerTask(name=task_id, plugin_config=config, image=container_image).execute(**kwargs) + print(f"Creating AirflowContainerTask with config: {config}") + return AirflowContainerTask(name=task_id, plugin_config=config, image=container_image)() # Monkey patches the Airflow operator. Instead of creating an airflow task, it returns a Flyte task. From 6ebb1f2a1099514ce1758ae807594b2d65e27161 Mon Sep 17 00:00:00 2001 From: Kevin Su Date: Tue, 24 Feb 2026 16:32:56 -0800 Subject: [PATCH 04/17] wip Signed-off-by: Kevin Su --- examples/airflow-migration/bash_operator.py | 44 ++--- examples/basics/hello.py | 3 +- examples/connectors/bigquery_example.py | 15 +- .../airflow/src/flyteplugins/airflow/dag.py | 186 ++++++++++++++++++ .../airflow/src/flyteplugins/airflow/task.py | 58 +++++- src/flyte/_run.py | 7 +- 6 files changed, 278 insertions(+), 35 deletions(-) create mode 100644 plugins/airflow/src/flyteplugins/airflow/dag.py diff --git a/examples/airflow-migration/bash_operator.py b/examples/airflow-migration/bash_operator.py index 0cc234c25..1d7c4ec90 100644 --- a/examples/airflow-migration/bash_operator.py +++ b/examples/airflow-migration/bash_operator.py @@ -1,41 +1,33 @@ -from flyteplugins.airflow.task import AirflowContainerTask # type: ignore +from flyteplugins.airflow.task import AirflowContainerTask # triggers DAG + operator patches # type: ignore from airflow import DAG -from airflow.models.baseoperator import ExecutorSafeguard from airflow.operators.bash import BashOperator -from pendulum import datetime -import airflow.utils.context as airflow_context import flyte -# with DAG( -# dag_id='simple_bash_operator_example', -# start_date=datetime(2025, 1, 1), -# schedule=None, -# catchup=False, -# ) as dag: -# # Define the BashOperator task -# hello_task = BashOperator( -# task_id='say_hello', -# bash_command='echo "Hello Airflow!"', -# ) - - env = flyte.TaskEnvironment( name="hello_airflow", image=flyte.Image.from_debian_base().with_pip_packages("apache-airflow<3.0.0", "jsonpickle").with_local_v2() ) - -@env.task -async def main(name: str) -> None: - print("starting to run airflow task") - BashOperator( - task_id='airflow', - bash_command=f'echo "Hello {name}!"', +# Standard Airflow DAG definition — no Flyte-specific changes needed inside the block. +# Pass flyte_env so the generated workflow task uses the right container image. +with DAG( + dag_id='simple_bash_operator_example', + flyte_env=env, +) as dag: + t1 = BashOperator( + task_id='say_hello', + bash_command='echo "Hello Airflow!"', + ) + t2 = BashOperator( + task_id='say_goodbye', + bash_command='echo "Goodbye Airflow!"', ) - print("finished running airflow task") + # t1 >> t2 # t2 runs after t1 if __name__ == '__main__': flyte.init_from_config() - run = flyte.with_runcontext(mode="remote", log_level="10").run(main, name="Airflow") + # dag.run() is a convenience wrapper — equivalent to: + run = flyte.with_runcontext(mode="local", log_level="10").run(dag) + # run = dag.run(mode="local", log_level="10") print(run.url) diff --git a/examples/basics/hello.py b/examples/basics/hello.py index 46095d0ad..5066ee1bb 100644 --- a/examples/basics/hello.py +++ b/examples/basics/hello.py @@ -17,6 +17,7 @@ def fn(x: int) -> int: # type annotations are recommended. # tasks can also call other tasks, which will be manifested in different containers. @env.task def main(x_list: list[int]) -> float: + fn(x=2) x_len = len(x_list) if x_len < 10: raise ValueError(f"x_list doesn't have a larger enough sample size, found: {x_len}") @@ -28,7 +29,7 @@ def main(x_list: list[int]) -> float: if __name__ == "__main__": flyte.init_from_config() # establish remote connection from within your script. - run = flyte.run(main, x_list=list(range(10))) # run remotely inline and pass data. + run = flyte.with_runcontext(mode="local").run(main, x_list=list(range(10))) # run remotely inline and pass data. # print various attributes of the run. print(run.name) diff --git a/examples/connectors/bigquery_example.py b/examples/connectors/bigquery_example.py index a17fc9d3d..f3eedb23b 100644 --- a/examples/connectors/bigquery_example.py +++ b/examples/connectors/bigquery_example.py @@ -11,10 +11,21 @@ query_template="SELECT * from dataset.flyte_table3;", ) -flyte.TaskEnvironment.from_task("bigquery_env", bigquery_task) +bigquery_env = flyte.TaskEnvironment.from_task("bigquery_env", bigquery_task) + +env = flyte.TaskEnvironment( + name="bigquery_example_env", + image=flyte.Image.from_debian_base().with_pip_packages("flyteplugins-bigquery"), + depends_on=[bigquery_env], +) + + +@env.task() +def main(version: int): + bigquery_task(version=version) if __name__ == "__main__": flyte.init_from_config() - run = flyte.with_runcontext(mode="local").run(bigquery_task, 123) + run = flyte.with_runcontext(mode="remote").run(main, 123) print(run.url) diff --git a/plugins/airflow/src/flyteplugins/airflow/dag.py b/plugins/airflow/src/flyteplugins/airflow/dag.py new file mode 100644 index 000000000..746cdf1a4 --- /dev/null +++ b/plugins/airflow/src/flyteplugins/airflow/dag.py @@ -0,0 +1,186 @@ +""" +Monkey-patches airflow.DAG so that a standard Airflow DAG definition is transparently +converted into a runnable Flyte task, with no changes to the DAG code required. + +Usage +----- + from flyteplugins.airflow.task import AirflowContainerTask # triggers patches + from airflow import DAG + from airflow.operators.bash import BashOperator + import flyte + + env = flyte.TaskEnvironment(name="hello_airflow", image=...) + + with DAG(dag_id="my_dag", flyte_env=env) as dag: + t1 = BashOperator(task_id="step1", bash_command='echo step1') + t2 = BashOperator(task_id="step2", bash_command='echo step2') + t1 >> t2 # optional: explicit dependency + + if __name__ == "__main__": + flyte.init_from_config() + run = dag.run(mode="local") + print(run.url) + +Notes +----- +- ``flyte_env`` is an optional kwarg accepted by the patched DAG. If omitted a + default ``TaskEnvironment(name=dag_id)`` is created. +- Operator dependency arrows (``>>``, ``<<``) update the execution order. + If no explicit dependencies are declared the operators run in definition order. +- ``dag.run(**kwargs)`` is a convenience wrapper around + ``flyte.with_runcontext(**kwargs).run(dag.flyte_task)``. +""" + +from __future__ import annotations + +import logging +from collections import defaultdict +from typing import TYPE_CHECKING, Dict, List, Optional + +import airflow.models.dag as _airflow_dag_module + +if TYPE_CHECKING: + from flyteplugins.airflow.task import AirflowContainerTask + +log = logging.getLogger(__name__) + +# --------------------------------------------------------------------------- +# Module-level state +# --------------------------------------------------------------------------- + +#: Set when the code is inside a ``with DAG(...) as dag:`` block. +_current_flyte_dag: Optional["FlyteDAG"] = None + + +# --------------------------------------------------------------------------- +# FlyteDAG – collects operators and builds the Flyte workflow task +# --------------------------------------------------------------------------- + +class FlyteDAG: + """Collects Airflow operators during a DAG definition and converts them + into a single Flyte task that runs them in dependency order.""" + + def __init__(self, dag_id: str, env=None) -> None: + self.dag_id = dag_id + self.env = env + # Ordered dict preserves insertion (creation) order as the default. + self._tasks: Dict[str, "AirflowContainerTask"] = {} + # task_id -> set of upstream task_ids + self._upstream: Dict[str, Set[str]] = defaultdict(set) + + # ------------------------------------------------------------------ + # Registration (called by _flyte_operator during DAG definition) + # ------------------------------------------------------------------ + + def add_task(self, task_id: str, task: "AirflowContainerTask") -> None: + self._tasks[task_id] = task + # Ensure a dependency entry exists even with no upstream tasks. + _ = self._upstream[task_id] + + def set_dependency(self, upstream_id: str, downstream_id: str) -> None: + """Record that *upstream_id* must run before *downstream_id*.""" + self._upstream[downstream_id].add(upstream_id) + + # ------------------------------------------------------------------ + # Flyte task construction + # ------------------------------------------------------------------ + + def build(self) -> None: + """Annotate each task with its downstream tasks and create a Flyte + workflow task whose entry function calls only the root tasks. + + Each root task's execute() will trigger its downstream tasks in + parallel via asyncio.gather, propagating the chain automatically. + """ + import flyte + + env = self.env + if env is None: + env = flyte.TaskEnvironment(name=self.dag_id) + + # Build downstream map from the upstream map. + downstream: Dict[str, List[str]] = defaultdict(list) + for tid, upstreams in self._upstream.items(): + for up in upstreams: + downstream[up].append(tid) + + # Annotate each AirflowContainerTask with its downstream tasks. + for tid, task in self._tasks.items(): + task._downstream_flyte_tasks = [ + self._tasks[d] for d in downstream[tid] if d in self._tasks + ] + + # Root tasks: those with no upstream dependencies. + root_tasks = [ + self._tasks[tid] + for tid, ups in self._upstream.items() + if len(ups) == 0 + ] + + # Snapshot to avoid capturing mutable references in the closure. + root_snapshot = list(root_tasks) + + def _dag_entry() -> None: + for task in root_snapshot: + task() # _call_as_synchronous=True → submit_sync → blocks until done + + _dag_entry.__name__ = f"dag_{self.dag_id}" + _dag_entry.__qualname__ = f"dag_{self.dag_id}" + + self.flyte_task = env.task(_dag_entry) + + +# --------------------------------------------------------------------------- +# DAG monkey-patch helpers +# --------------------------------------------------------------------------- + +_original_dag_init = _airflow_dag_module.DAG.__init__ +_original_dag_enter = _airflow_dag_module.DAG.__enter__ +_original_dag_exit = _airflow_dag_module.DAG.__exit__ + + +def _patched_dag_init(self, *args, **kwargs) -> None: # type: ignore[override] + # Pull out our custom kwarg before passing the rest to Airflow. + flyte_env = kwargs.pop("flyte_env", None) + _original_dag_init(self, *args, **kwargs) + self._flyte_env = flyte_env + + +def _patched_dag_enter(self): # type: ignore[override] + global _current_flyte_dag + _current_flyte_dag = FlyteDAG(dag_id=self.dag_id, env=getattr(self, "_flyte_env", None)) + return _original_dag_enter(self) + + +def _patched_dag_exit(self, exc_type, exc_val, exc_tb): # type: ignore[override] + global _current_flyte_dag + try: + if exc_type is None and _current_flyte_dag is not None: + flyte_dag = _current_flyte_dag + flyte_dag.build() + # Attach the Flyte task and a convenience run() to the DAG object. + self.flyte_task = flyte_dag.flyte_task + self.run = _make_run(flyte_dag.flyte_task) + finally: + _current_flyte_dag = None + + return _original_dag_exit(self, exc_type, exc_val, exc_tb) + + +def _make_run(flyte_task): + """Return a ``run(**kwargs)`` helper bound to *flyte_task*.""" + import flyte + + def run(**kwargs): + return flyte.with_runcontext(**kwargs).run(flyte_task) + + return run + + +# --------------------------------------------------------------------------- +# Apply patches +# --------------------------------------------------------------------------- + +_airflow_dag_module.DAG.__init__ = _patched_dag_init +_airflow_dag_module.DAG.__enter__ = _patched_dag_enter +_airflow_dag_module.DAG.__exit__ = _patched_dag_exit diff --git a/plugins/airflow/src/flyteplugins/airflow/task.py b/plugins/airflow/src/flyteplugins/airflow/task.py index dcfdc1987..f174b817a 100644 --- a/plugins/airflow/src/flyteplugins/airflow/task.py +++ b/plugins/airflow/src/flyteplugins/airflow/task.py @@ -1,5 +1,7 @@ import importlib import logging +import os +import threading import typing from dataclasses import dataclass from typing import Any, Dict, Optional, Type, List @@ -14,11 +16,20 @@ import flyte from flyte import logger, get_custom_context +from flyte._context import internal_ctx, root_context_var +from flyte._internal.controllers import get_controller +from flyte._internal.controllers._local_controller import _TaskRunner from flyte._internal.resolvers.common import Resolver from flyte._task import TaskTemplate from flyte.extend import AsyncFunctionTaskTemplate, TaskPluginRegistry from flyte.models import SerializationContext, NativeInterface +# Per-thread _TaskRunner instances used by _flyte_operator for sync blocking submission. +_airflow_runners: Dict[str, _TaskRunner] = {} + +# Import dag module to apply DAG monkey-patches when this module is imported. +from flyteplugins.airflow import dag as _dag_module # noqa: E402 + @dataclass class AirflowObj(object): @@ -96,8 +107,27 @@ def __init__( ) self._task_resolver = airflow_task_resolver self._plugin_config = plugin_config - - def execute(self, **kwargs) -> Any: + self._call_as_synchronous = True + self._downstream_flyte_tasks: List["AirflowContainerTask"] = [] + + # ------------------------------------------------------------------ + # Airflow dependency-arrow support (>> / <<) + # Records the dependency in the active FlyteDAG if one is being built. + # ------------------------------------------------------------------ + + def __rshift__(self, other: "AirflowContainerTask") -> "AirflowContainerTask": + """``self >> other`` — other runs after self.""" + if _dag_module._current_flyte_dag is not None: + _dag_module._current_flyte_dag.set_dependency(self.name, other.name) + return other + + def __lshift__(self, other: "AirflowContainerTask") -> "AirflowContainerTask": + """``self << other`` — self runs after other.""" + if _dag_module._current_flyte_dag is not None: + _dag_module._current_flyte_dag.set_dependency(other.name, self.name) + return other + + async def execute(self, **kwargs) -> Any: # ExecutorSafeguard stores a sentinel in a threading.local() dict. That # dict is initialised on the main thread at import time, but tasks may # run in a background thread where the thread-local has no 'callers' key. @@ -106,6 +136,10 @@ def execute(self, **kwargs) -> Any: ExecutorSafeguard._sentinel.callers = {} logger.info("Executing Airflow task") _get_airflow_instance(self._plugin_config).execute(context=airflow_context.Context()) + # Trigger downstream tasks in parallel after this operator completes. + if self._downstream_flyte_tasks: + import asyncio + await asyncio.gather(*[t.aio() for t in self._downstream_flyte_tasks]) def _get_airflow_instance( @@ -170,9 +204,23 @@ def _flyte_operator(*args, **kwargs): config = AirflowObj(module=cls.__module__, name=cls.__name__, parameters=kwargs) print(f"Creating AirflowContainerTask with config: {config}") - return AirflowContainerTask(name=task_id, plugin_config=config, image=container_image)() + task = AirflowContainerTask(name=task_id, plugin_config=config, image=container_image) + # ── Case 1: inside a ``with DAG(...) as dag:`` block ──────────────────── + # Register the task with the active FlyteDAG collector so it can be wired + # into the Flyte workflow when the DAG context exits. Do NOT execute yet. + if _dag_module._current_flyte_dag is not None: + _dag_module._current_flyte_dag.add_task(task_id, task) + return task -# Monkey patches the Airflow operator. Instead of creating an airflow task, it returns a Flyte task. -airflow_models.BaseOperator.__new__ = _flyte_operator + # ── Case 2: inside a Flyte task execution ─────────────────────────────── + # The dag workflow function is executing; submit the operator as a sub-task. + if internal_ctx().is_task_context(): + return task() + # ── Case 3: outside any context (e.g. serialization / import scan) ────── + return task + + +# Monkey patches the Airflow operator. Instead of creating an airflow task, it returns a Flyte task. +airflow_models.BaseOperator.__new__ = _flyte_operator \ No newline at end of file diff --git a/src/flyte/_run.py b/src/flyte/_run.py index 504239f46..273f35884 100644 --- a/src/flyte/_run.py +++ b/src/flyte/_run.py @@ -660,8 +660,13 @@ async def example_task(x: int, y: str) -> str: if isinstance(task, (LazyEntity, TaskDetails)) and self._mode != "remote": raise ValueError("Remote task can only be run in remote mode.") + # Allow objects (e.g. Airflow DAGs wrapped by flyteplugins) that expose a + # .flyte_task attribute to be passed directly to run(). if not isinstance(task, TaskTemplate) and not isinstance(task, (LazyEntity, TaskDetails)): - raise TypeError(f"On Flyte tasks can be run, not generic functions or methods '{type(task)}'.") + if hasattr(task, "flyte_task") and isinstance(task.flyte_task, TaskTemplate): + task = task.flyte_task + else: + raise TypeError(f"On Flyte tasks can be run, not generic functions or methods '{type(task)}'.") # Set the run mode in the context variable so that offloaded types (files, directories, dataframes) # can check the mode for controlling auto-uploading behavior (only enabled in remote mode). From ce86c24e3b20a6f5f3656c48f997af3e06faa1d8 Mon Sep 17 00:00:00 2001 From: Kevin Su Date: Wed, 25 Feb 2026 10:55:21 -0800 Subject: [PATCH 05/17] work version Signed-off-by: Kevin Su --- examples/airflow-migration/bash_operator.py | 10 ++++--- .../airflow/src/flyteplugins/airflow/dag.py | 30 +++++++++++++++++-- .../airflow/src/flyteplugins/airflow/task.py | 1 + 3 files changed, 34 insertions(+), 7 deletions(-) diff --git a/examples/airflow-migration/bash_operator.py b/examples/airflow-migration/bash_operator.py index 1d7c4ec90..91b6dc893 100644 --- a/examples/airflow-migration/bash_operator.py +++ b/examples/airflow-migration/bash_operator.py @@ -1,3 +1,5 @@ +from pathlib import Path + from flyteplugins.airflow.task import AirflowContainerTask # triggers DAG + operator patches # type: ignore from airflow import DAG from airflow.operators.bash import BashOperator @@ -16,18 +18,18 @@ ) as dag: t1 = BashOperator( task_id='say_hello', - bash_command='echo "Hello Airflow!"', + bash_command='echo "Hello Airflow1!"', ) t2 = BashOperator( task_id='say_goodbye', - bash_command='echo "Goodbye Airflow!"', + bash_command='echo "Goodbye Airflow2!"', ) # t1 >> t2 # t2 runs after t1 if __name__ == '__main__': - flyte.init_from_config() + flyte.init_from_config(root_dir=Path("/Users/kevin/git/flyte-sdk")) # dag.run() is a convenience wrapper — equivalent to: - run = flyte.with_runcontext(mode="local", log_level="10").run(dag) + run = flyte.with_runcontext(mode="remote", log_level="10").run(dag) # run = dag.run(mode="local", log_level="10") print(run.url) diff --git a/plugins/airflow/src/flyteplugins/airflow/dag.py b/plugins/airflow/src/flyteplugins/airflow/dag.py index 746cdf1a4..f96ec4de8 100644 --- a/plugins/airflow/src/flyteplugins/airflow/dag.py +++ b/plugins/airflow/src/flyteplugins/airflow/dag.py @@ -14,7 +14,7 @@ with DAG(dag_id="my_dag", flyte_env=env) as dag: t1 = BashOperator(task_id="step1", bash_command='echo step1') t2 = BashOperator(task_id="step2", bash_command='echo step2') - t1 >> t2 # optional: explicit dependency + t1 >> t2 # optional: explicit dependency if __name__ == "__main__": flyte.init_from_config() @@ -26,7 +26,7 @@ - ``flyte_env`` is an optional kwarg accepted by the patched DAG. If omitted a default ``TaskEnvironment(name=dag_id)`` is created. - Operator dependency arrows (``>>``, ``<<``) update the execution order. - If no explicit dependencies are declared the operators run in definition order. + If no explicit dependencies are declared, the operators run in definition order. - ``dag.run(**kwargs)`` is a convenience wrapper around ``flyte.with_runcontext(**kwargs).run(dag.flyte_task)``. """ @@ -35,7 +35,7 @@ import logging from collections import defaultdict -from typing import TYPE_CHECKING, Dict, List, Optional +from typing import TYPE_CHECKING, Dict, List, Optional, Set import airflow.models.dag as _airflow_dag_module @@ -124,11 +124,35 @@ def _dag_entry() -> None: for task in root_snapshot: task() # _call_as_synchronous=True → submit_sync → blocks until done + # Find the first call frame outside this module so the Flyte task is + # registered under the user's module, not dag.py. The + # DefaultTaskResolver records (module, name) at submission time and + # on the remote worker it imports that module — which re-runs the + # DAG definition and re-injects the task — before calling + # getattr(module, task_name). + import inspect + import sys as _sys + + _caller_module_name = __name__ + _caller_module = None + for _fi in inspect.stack(): + _mod = _fi.frame.f_globals.get("__name__", "") + if _mod and _mod != __name__: + _caller_module_name = _mod + _caller_module = _sys.modules.get(_caller_module_name) + break + _dag_entry.__name__ = f"dag_{self.dag_id}" _dag_entry.__qualname__ = f"dag_{self.dag_id}" + _dag_entry.__module__ = _caller_module_name self.flyte_task = env.task(_dag_entry) + # Inject the task into the caller's module so DefaultTaskResolver + # can find it via getattr(module, task_name) on both local and remote. + if _caller_module is not None: + setattr(_caller_module, _dag_entry.__name__, self.flyte_task) + # --------------------------------------------------------------------------- # DAG monkey-patch helpers diff --git a/plugins/airflow/src/flyteplugins/airflow/task.py b/plugins/airflow/src/flyteplugins/airflow/task.py index f174b817a..2eae9a7bd 100644 --- a/plugins/airflow/src/flyteplugins/airflow/task.py +++ b/plugins/airflow/src/flyteplugins/airflow/task.py @@ -205,6 +205,7 @@ def _flyte_operator(*args, **kwargs): print(f"Creating AirflowContainerTask with config: {config}") task = AirflowContainerTask(name=task_id, plugin_config=config, image=container_image) + flyte.TaskEnvironment.from_task(task_id, task) # ── Case 1: inside a ``with DAG(...) as dag:`` block ──────────────────── # Register the task with the active FlyteDAG collector so it can be wired From cc58587278efae9da832beea0d6868b02a274f7d Mon Sep 17 00:00:00 2001 From: Kevin Su Date: Wed, 25 Feb 2026 11:33:14 -0800 Subject: [PATCH 06/17] update env Signed-off-by: Kevin Su --- examples/airflow-migration/bash_operator.py | 1 - plugins/airflow/src/flyteplugins/airflow/dag.py | 8 ++++++++ plugins/airflow/src/flyteplugins/airflow/task.py | 2 -- 3 files changed, 8 insertions(+), 3 deletions(-) diff --git a/examples/airflow-migration/bash_operator.py b/examples/airflow-migration/bash_operator.py index 91b6dc893..0e2819122 100644 --- a/examples/airflow-migration/bash_operator.py +++ b/examples/airflow-migration/bash_operator.py @@ -9,7 +9,6 @@ name="hello_airflow", image=flyte.Image.from_debian_base().with_pip_packages("apache-airflow<3.0.0", "jsonpickle").with_local_v2() ) - # Standard Airflow DAG definition — no Flyte-specific changes needed inside the block. # Pass flyte_env so the generated workflow task uses the right container image. with DAG( diff --git a/plugins/airflow/src/flyteplugins/airflow/dag.py b/plugins/airflow/src/flyteplugins/airflow/dag.py index f96ec4de8..c21b83d66 100644 --- a/plugins/airflow/src/flyteplugins/airflow/dag.py +++ b/plugins/airflow/src/flyteplugins/airflow/dag.py @@ -124,6 +124,14 @@ def _dag_entry() -> None: for task in root_snapshot: task() # _call_as_synchronous=True → submit_sync → blocks until done + # Register all operator tasks with the DAG's TaskEnvironment so that + # they get parent_env / parent_env_name (required for serialization and + # image lookup) and appear in env.tasks (required for deployment). + # from_task validates image consistency across tasks and returns a new + # env; we reassign `env` so the orchestrator task (env.task below) + # ends up in the same environment. + env = flyte.TaskEnvironment.from_task(env.name, *self._tasks.values()) + # Find the first call frame outside this module so the Flyte task is # registered under the user's module, not dag.py. The # DefaultTaskResolver records (module, name) at submission time and diff --git a/plugins/airflow/src/flyteplugins/airflow/task.py b/plugins/airflow/src/flyteplugins/airflow/task.py index 2eae9a7bd..b660f25d4 100644 --- a/plugins/airflow/src/flyteplugins/airflow/task.py +++ b/plugins/airflow/src/flyteplugins/airflow/task.py @@ -203,9 +203,7 @@ def _flyte_operator(*args, **kwargs): task_id = kwargs.get("task_id", cls.__name__) config = AirflowObj(module=cls.__module__, name=cls.__name__, parameters=kwargs) - print(f"Creating AirflowContainerTask with config: {config}") task = AirflowContainerTask(name=task_id, plugin_config=config, image=container_image) - flyte.TaskEnvironment.from_task(task_id, task) # ── Case 1: inside a ``with DAG(...) as dag:`` block ──────────────────── # Register the task with the active FlyteDAG collector so it can be wired From 00b0963e8d87a56c0933f9422ce8bdb26369f063 Mon Sep 17 00:00:00 2001 From: Kevin Su Date: Wed, 25 Feb 2026 12:34:12 -0800 Subject: [PATCH 07/17] work version Signed-off-by: Kevin Su --- examples/connectors/bigquery_example.py | 1 - .../airflow/src/flyteplugins/airflow/dag.py | 20 +++++++++++++++---- 2 files changed, 16 insertions(+), 5 deletions(-) diff --git a/examples/connectors/bigquery_example.py b/examples/connectors/bigquery_example.py index f3eedb23b..2f9b33db6 100644 --- a/examples/connectors/bigquery_example.py +++ b/examples/connectors/bigquery_example.py @@ -12,7 +12,6 @@ ) bigquery_env = flyte.TaskEnvironment.from_task("bigquery_env", bigquery_task) - env = flyte.TaskEnvironment( name="bigquery_example_env", image=flyte.Image.from_debian_base().with_pip_packages("flyteplugins-bigquery"), diff --git a/plugins/airflow/src/flyteplugins/airflow/dag.py b/plugins/airflow/src/flyteplugins/airflow/dag.py index c21b83d66..fc40eb06c 100644 --- a/plugins/airflow/src/flyteplugins/airflow/dag.py +++ b/plugins/airflow/src/flyteplugins/airflow/dag.py @@ -124,13 +124,25 @@ def _dag_entry() -> None: for task in root_snapshot: task() # _call_as_synchronous=True → submit_sync → blocks until done + # Operator tasks are created without an image (image=None). Resolve + # the env's image to an Image object (mirroring TaskTemplate.__post_init__) + # and assign it to each task that has no explicit image, so that + # from_task sees a consistent set of images and the tasks can be + # serialized correctly when submitted as sub-tasks during remote execution. + _env_image = env.image + if _env_image == "auto": + _env_image = flyte.Image.from_debian_base() + elif isinstance(_env_image, str): + _env_image = flyte.Image.from_base(_env_image) + + for _op_task in self._tasks.values(): + if _op_task.image is None: + _op_task.image = _env_image + # Register all operator tasks with the DAG's TaskEnvironment so that # they get parent_env / parent_env_name (required for serialization and # image lookup) and appear in env.tasks (required for deployment). - # from_task validates image consistency across tasks and returns a new - # env; we reassign `env` so the orchestrator task (env.task below) - # ends up in the same environment. - env = flyte.TaskEnvironment.from_task(env.name, *self._tasks.values()) + env = env.from_task(env.name, *self._tasks.values()) # Find the first call frame outside this module so the Flyte task is # registered under the user's module, not dag.py. The From f28193370be5c1df11e6a16ab6d81df6d773e087 Mon Sep 17 00:00:00 2001 From: Kevin Su Date: Wed, 25 Feb 2026 12:49:55 -0800 Subject: [PATCH 08/17] parallel Signed-off-by: Kevin Su --- .../airflow/src/flyteplugins/airflow/dag.py | 18 ++-- tests/flyte/test_airflow_dag.py | 90 +++++++++++++++++++ 2 files changed, 103 insertions(+), 5 deletions(-) create mode 100644 tests/flyte/test_airflow_dag.py diff --git a/plugins/airflow/src/flyteplugins/airflow/dag.py b/plugins/airflow/src/flyteplugins/airflow/dag.py index fc40eb06c..fc49e8179 100644 --- a/plugins/airflow/src/flyteplugins/airflow/dag.py +++ b/plugins/airflow/src/flyteplugins/airflow/dag.py @@ -24,7 +24,8 @@ Notes ----- - ``flyte_env`` is an optional kwarg accepted by the patched DAG. If omitted a - default ``TaskEnvironment(name=dag_id)`` is created. + default ``TaskEnvironment`` is created using the dag_id as the name and a + Debian-base image with ``apache-airflow<3.0.0`` and ``jsonpickle`` installed. - Operator dependency arrows (``>>``, ``<<``) update the execution order. If no explicit dependencies are declared, the operators run in definition order. - ``dag.run(**kwargs)`` is a convenience wrapper around @@ -96,7 +97,12 @@ def build(self) -> None: env = self.env if env is None: - env = flyte.TaskEnvironment(name=self.dag_id) + env = flyte.TaskEnvironment( + name=self.dag_id, + image=flyte.Image.from_debian_base().with_pip_packages( + "apache-airflow<3.0.0", "jsonpickle" + ).with_local_v2(), + ) # Build downstream map from the upstream map. downstream: Dict[str, List[str]] = defaultdict(list) @@ -120,9 +126,11 @@ def build(self) -> None: # Snapshot to avoid capturing mutable references in the closure. root_snapshot = list(root_tasks) - def _dag_entry() -> None: - for task in root_snapshot: - task() # _call_as_synchronous=True → submit_sync → blocks until done + async def _dag_entry() -> None: + import asyncio + # Root tasks run in parallel; each task's execute() chains its + # downstream tasks (set via >>) after it completes. + await asyncio.gather(*[t.aio() for t in root_snapshot]) # Operator tasks are created without an image (image=None). Resolve # the env's image to an Image object (mirroring TaskTemplate.__post_init__) diff --git a/tests/flyte/test_airflow_dag.py b/tests/flyte/test_airflow_dag.py new file mode 100644 index 000000000..b32c71963 --- /dev/null +++ b/tests/flyte/test_airflow_dag.py @@ -0,0 +1,90 @@ +""" +Tests for the Airflow DAG monkey-patch in flyteplugins.airflow.dag. +""" +import pytest + +import flyte +from flyte._image import Image + + +@pytest.fixture(autouse=True) +def _reset_environment_registry(): + """Clean up TaskEnvironment instances created during each test.""" + from flyte._environment import _ENVIRONMENT_REGISTRY + initial_len = len(_ENVIRONMENT_REGISTRY) + yield + del _ENVIRONMENT_REGISTRY[initial_len:] + + +def test_flyte_env_image_preserved_after_dag_build(): + """The image supplied via flyte_env= must survive FlyteDAG.build(). + + Previously, build() called TaskEnvironment.from_task() which derived a new + environment from the operator tasks' images (all None), silently discarding + the user-supplied image. + """ + from flyteplugins.airflow.task import AirflowContainerTask # applies patches + from airflow import DAG + from airflow.operators.bash import BashOperator + + custom_image = Image.from_debian_base().with_pip_packages("apache-airflow<3.0.0") + env = flyte.TaskEnvironment(name="test-dag-env", image=custom_image) + + with DAG(dag_id="test_image_preserved", flyte_env=env) as dag: + BashOperator(task_id="say_hello", bash_command='echo hello') + + assert dag.flyte_task is not None + parent_env = dag.flyte_task.parent_env() + assert parent_env is not None, "flyte_task must be attached to an environment" + assert parent_env.image is custom_image, ( + f"Expected the user-supplied image to be preserved, got: {parent_env.image}" + ) + + # Operator tasks must also inherit the env's image so they can be + # serialized correctly when submitted as sub-tasks during remote execution. + for task_name, op_task in parent_env.tasks.items(): + if task_name != f"{env.name}.dag_test_image_preserved": + assert op_task.image is custom_image, ( + f"Operator task {task_name!r} did not inherit env image: {op_task.image}" + ) + + +def test_operator_tasks_registered_in_env(monkeypatch): + """Operator tasks must appear in env.tasks so they are included in deployment.""" + from flyteplugins.airflow.task import AirflowContainerTask # applies patches + from airflow import DAG + from airflow.operators.bash import BashOperator + + env = flyte.TaskEnvironment(name="test-dag-tasks-env") + + with DAG(dag_id="test_tasks_registered", flyte_env=env) as dag: + BashOperator(task_id="step1", bash_command='echo step1') + BashOperator(task_id="step2", bash_command='echo step2') + + parent_env = dag.flyte_task.parent_env() + # Both operator tasks and the orchestrator task must be in env.tasks. + env_task_names = list(parent_env.tasks.keys()) + assert any("step1" in name for name in env_task_names), ( + f"step1 not found in env.tasks: {env_task_names}" + ) + assert any("step2" in name for name in env_task_names), ( + f"step2 not found in env.tasks: {env_task_names}" + ) + + +def test_default_env_created_when_flyte_env_omitted(): + """When flyte_env is not supplied, a default TaskEnvironment is created using + the dag_id as the name and a Debian-base image with airflow packages.""" + from flyteplugins.airflow.task import AirflowContainerTask # applies patches + from airflow import DAG + from airflow.operators.bash import BashOperator + + with DAG(dag_id="test_default_env") as dag: + BashOperator(task_id="greet", bash_command='echo hi') + + assert dag.flyte_task is not None + parent_env = dag.flyte_task.parent_env() + assert parent_env is not None + assert parent_env.name == "test_default_env" + # Default env should have a real image (not None or "auto") + assert isinstance(parent_env.image, Image) From a4ad2081b8132c60d088614395777710eac90af1 Mon Sep 17 00:00:00 2001 From: Kevin Su Date: Wed, 25 Feb 2026 16:54:02 -0800 Subject: [PATCH 09/17] update example Signed-off-by: Kevin Su --- examples/airflow-migration/bash_operator.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/examples/airflow-migration/bash_operator.py b/examples/airflow-migration/bash_operator.py index 0e2819122..caa521cfd 100644 --- a/examples/airflow-migration/bash_operator.py +++ b/examples/airflow-migration/bash_operator.py @@ -5,15 +5,10 @@ from airflow.operators.bash import BashOperator import flyte -env = flyte.TaskEnvironment( - name="hello_airflow", - image=flyte.Image.from_debian_base().with_pip_packages("apache-airflow<3.0.0", "jsonpickle").with_local_v2() -) # Standard Airflow DAG definition — no Flyte-specific changes needed inside the block. # Pass flyte_env so the generated workflow task uses the right container image. with DAG( dag_id='simple_bash_operator_example', - flyte_env=env, ) as dag: t1 = BashOperator( task_id='say_hello', From 5b5f74ed7e925113e9ea7858decf3873459c20e2 Mon Sep 17 00:00:00 2001 From: Kevin Su Date: Thu, 26 Feb 2026 11:28:02 -0800 Subject: [PATCH 10/17] update example Signed-off-by: Kevin Su --- examples/airflow-migration/bash_operator.py | 18 ++++++++--- .../airflow/src/flyteplugins/airflow/dag.py | 31 ++++++++++++++----- .../airflow/src/flyteplugins/airflow/task.py | 4 --- tests/flyte/test_image_cache.py | 2 +- 4 files changed, 37 insertions(+), 18 deletions(-) diff --git a/examples/airflow-migration/bash_operator.py b/examples/airflow-migration/bash_operator.py index caa521cfd..a763497ba 100644 --- a/examples/airflow-migration/bash_operator.py +++ b/examples/airflow-migration/bash_operator.py @@ -3,8 +3,14 @@ from flyteplugins.airflow.task import AirflowContainerTask # triggers DAG + operator patches # type: ignore from airflow import DAG from airflow.operators.bash import BashOperator +from airflow.operators.python import PythonOperator import flyte + +def hello_python(): + print("Hello from PythonOperator!") + + # Standard Airflow DAG definition — no Flyte-specific changes needed inside the block. # Pass flyte_env so the generated workflow task uses the right container image. with DAG( @@ -12,18 +18,20 @@ ) as dag: t1 = BashOperator( task_id='say_hello', - bash_command='echo "Hello Airflow1!"', + bash_command='echo "Hello Airflow!"', ) t2 = BashOperator( task_id='say_goodbye', - bash_command='echo "Goodbye Airflow2!"', + bash_command='echo "Goodbye Airflow!"', + ) + t3 = PythonOperator( + task_id='hello_python', + python_callable=hello_python, ) - # t1 >> t2 # t2 runs after t1 + t1 >> t2 # t2 runs after t1 if __name__ == '__main__': flyte.init_from_config(root_dir=Path("/Users/kevin/git/flyte-sdk")) - # dag.run() is a convenience wrapper — equivalent to: run = flyte.with_runcontext(mode="remote", log_level="10").run(dag) - # run = dag.run(mode="local", log_level="10") print(run.url) diff --git a/plugins/airflow/src/flyteplugins/airflow/dag.py b/plugins/airflow/src/flyteplugins/airflow/dag.py index fc49e8179..54f8cb583 100644 --- a/plugins/airflow/src/flyteplugins/airflow/dag.py +++ b/plugins/airflow/src/flyteplugins/airflow/dag.py @@ -87,11 +87,14 @@ def set_dependency(self, upstream_id: str, downstream_id: str) -> None: # ------------------------------------------------------------------ def build(self) -> None: - """Annotate each task with its downstream tasks and create a Flyte - workflow task whose entry function calls only the root tasks. - - Each root task's execute() will trigger its downstream tasks in - parallel via asyncio.gather, propagating the chain automatically. + """Create a Flyte workflow task whose entry function runs all + operator tasks in dependency order. + + The entry function captures the full dependency graph in its closure + and orchestrates execution directly, starting from root tasks and + chaining downstream tasks after each completes. This ensures + correct ordering in both local and remote execution (where sub-tasks + are resolved independently and lose their in-memory references). """ import flyte @@ -126,11 +129,23 @@ def build(self) -> None: # Snapshot to avoid capturing mutable references in the closure. root_snapshot = list(root_tasks) + # Capture the full dependency graph so _dag_entry can orchestrate + # all tasks itself. In remote execution each sub-task is resolved + # independently and loses its _downstream_flyte_tasks references, + # so the entry function must drive the execution order. + all_tasks = dict(self._tasks) + downstream_snapshot = dict(downstream) + async def _dag_entry() -> None: import asyncio - # Root tasks run in parallel; each task's execute() chains its - # downstream tasks (set via >>) after it completes. - await asyncio.gather(*[t.aio() for t in root_snapshot]) + + async def _run_chain(task): + await task.aio() + ds = downstream_snapshot.get(task.name, []) + if ds: + await asyncio.gather(*[_run_chain(all_tasks[d]) for d in ds]) + + await asyncio.gather(*[_run_chain(t) for t in root_snapshot]) # Operator tasks are created without an image (image=None). Resolve # the env's image to an Image object (mirroring TaskTemplate.__post_init__) diff --git a/plugins/airflow/src/flyteplugins/airflow/task.py b/plugins/airflow/src/flyteplugins/airflow/task.py index b660f25d4..a63e7c2c6 100644 --- a/plugins/airflow/src/flyteplugins/airflow/task.py +++ b/plugins/airflow/src/flyteplugins/airflow/task.py @@ -136,10 +136,6 @@ async def execute(self, **kwargs) -> Any: ExecutorSafeguard._sentinel.callers = {} logger.info("Executing Airflow task") _get_airflow_instance(self._plugin_config).execute(context=airflow_context.Context()) - # Trigger downstream tasks in parallel after this operator completes. - if self._downstream_flyte_tasks: - import asyncio - await asyncio.gather(*[t.aio() for t in self._downstream_flyte_tasks]) def _get_airflow_instance( diff --git a/tests/flyte/test_image_cache.py b/tests/flyte/test_image_cache.py index 51a0d606d..44282d93b 100644 --- a/tests/flyte/test_image_cache.py +++ b/tests/flyte/test_image_cache.py @@ -18,7 +18,7 @@ def test_image_cache_serialization_round_trip(): # Deserialize back into an ImageCache object # This should also save the serialized form into the object for downstream tasks to get it. - restored_cache = ImageCache.from_transport("H4sIAAAAAAAC/53MSw6DIBAA0LvMuqK0pB8uQ6YjMQZwCAMxqfHu7cbEdQ/w3gZzwsm7yBxaBrsBFaTgRiZxEucEFu7GPMzL6KcZrmoMRXkqqknnUWqnFSb88IKrKOLUvz0SL73UfNODjVi9VLic1ym3P9OfPMZ9/wKPjjm1ugAAAA==") + restored_cache = ImageCache.from_transport(serialized) # Check that the deserialized data matches the original assert restored_cache.image_lookup == original_data["image_lookup"] From a7e091705ef1519e137b18b842977c9e7e5b838f Mon Sep 17 00:00:00 2001 From: Kevin Su Date: Thu, 26 Feb 2026 23:53:13 -0800 Subject: [PATCH 11/17] run the command Signed-off-by: Kevin Su --- .../{bash_operator.py => basic.py} | 6 +- .../airflow/src/flyteplugins/airflow/dag.py | 40 ++--- .../airflow/src/flyteplugins/airflow/task.py | 167 +++++++++++------- src/flyte/_internal/resolvers/common.py | 2 +- src/flyte/_task.py | 11 +- tests/flyte/test_airflow_dag.py | 6 +- 6 files changed, 137 insertions(+), 95 deletions(-) rename examples/airflow-migration/{bash_operator.py => basic.py} (79%) diff --git a/examples/airflow-migration/bash_operator.py b/examples/airflow-migration/basic.py similarity index 79% rename from examples/airflow-migration/bash_operator.py rename to examples/airflow-migration/basic.py index a763497ba..e76a511b3 100644 --- a/examples/airflow-migration/bash_operator.py +++ b/examples/airflow-migration/basic.py @@ -1,6 +1,6 @@ from pathlib import Path -from flyteplugins.airflow.task import AirflowContainerTask # triggers DAG + operator patches # type: ignore +from flyteplugins.airflow.task import AirflowFunctionTask # triggers DAG + operator patches # type: ignore from airflow import DAG from airflow.operators.bash import BashOperator from airflow.operators.python import PythonOperator @@ -14,7 +14,7 @@ def hello_python(): # Standard Airflow DAG definition — no Flyte-specific changes needed inside the block. # Pass flyte_env so the generated workflow task uses the right container image. with DAG( - dag_id='simple_bash_operator_example', + dag_id='simple_airflow_workflow', ) as dag: t1 = BashOperator( task_id='say_hello', @@ -32,6 +32,6 @@ def hello_python(): if __name__ == '__main__': - flyte.init_from_config(root_dir=Path("/Users/kevin/git/flyte-sdk")) + flyte.init_from_config(root_dir=Path(__file__).parent.parent.parent) run = flyte.with_runcontext(mode="remote", log_level="10").run(dag) print(run.url) diff --git a/plugins/airflow/src/flyteplugins/airflow/dag.py b/plugins/airflow/src/flyteplugins/airflow/dag.py index 54f8cb583..faa717c82 100644 --- a/plugins/airflow/src/flyteplugins/airflow/dag.py +++ b/plugins/airflow/src/flyteplugins/airflow/dag.py @@ -18,18 +18,16 @@ if __name__ == "__main__": flyte.init_from_config() - run = dag.run(mode="local") + run = flyte.with_runcontext(mode="remote", log_level="10").run(dag) print(run.url) Notes ----- - ``flyte_env`` is an optional kwarg accepted by the patched DAG. If omitted a default ``TaskEnvironment`` is created using the dag_id as the name and a - Debian-base image with ``apache-airflow<3.0.0`` and ``jsonpickle`` installed. + Debian-base image with ``flyteplugins-airflow`` and ``jsonpickle`` installed. - Operator dependency arrows (``>>``, ``<<``) update the execution order. If no explicit dependencies are declared, the operators run in definition order. -- ``dag.run(**kwargs)`` is a convenience wrapper around - ``flyte.with_runcontext(**kwargs).run(dag.flyte_task)``. """ from __future__ import annotations @@ -41,7 +39,9 @@ import airflow.models.dag as _airflow_dag_module if TYPE_CHECKING: - from flyteplugins.airflow.task import AirflowContainerTask + from flyteplugins.airflow.task import AirflowFunctionTask, AirflowRawContainerTask + + AirflowTask = AirflowFunctionTask | AirflowRawContainerTask log = logging.getLogger(__name__) @@ -65,7 +65,7 @@ def __init__(self, dag_id: str, env=None) -> None: self.dag_id = dag_id self.env = env # Ordered dict preserves insertion (creation) order as the default. - self._tasks: Dict[str, "AirflowContainerTask"] = {} + self._tasks: Dict[str, "AirflowTask"] = {} # task_id -> set of upstream task_ids self._upstream: Dict[str, Set[str]] = defaultdict(set) @@ -73,7 +73,7 @@ def __init__(self, dag_id: str, env=None) -> None: # Registration (called by _flyte_operator during DAG definition) # ------------------------------------------------------------------ - def add_task(self, task_id: str, task: "AirflowContainerTask") -> None: + def add_task(self, task_id: str, task: "AirflowTask") -> None: self._tasks[task_id] = task # Ensure a dependency entry exists even with no upstream tasks. _ = self._upstream[task_id] @@ -98,16 +98,16 @@ def build(self) -> None: """ import flyte - env = self.env - if env is None: - env = flyte.TaskEnvironment( + if self.env is None: + self.env = flyte.TaskEnvironment( name=self.dag_id, image=flyte.Image.from_debian_base().with_pip_packages( "apache-airflow<3.0.0", "jsonpickle" ).with_local_v2(), ) + env = self.env - # Build downstream map from the upstream map. + # Build the downstream map from the upstream map. downstream: Dict[str, List[str]] = defaultdict(list) for tid, upstreams in self._upstream.items(): for up in upstreams: @@ -147,25 +147,15 @@ async def _run_chain(task): await asyncio.gather(*[_run_chain(t) for t in root_snapshot]) - # Operator tasks are created without an image (image=None). Resolve - # the env's image to an Image object (mirroring TaskTemplate.__post_init__) - # and assign it to each task that has no explicit image, so that - # from_task sees a consistent set of images and the tasks can be - # serialized correctly when submitted as sub-tasks during remote execution. - _env_image = env.image - if _env_image == "auto": - _env_image = flyte.Image.from_debian_base() - elif isinstance(_env_image, str): - _env_image = flyte.Image.from_base(_env_image) - for _op_task in self._tasks.values(): if _op_task.image is None: - _op_task.image = _env_image + _op_task.image = self.env.image # Register all operator tasks with the DAG's TaskEnvironment so that # they get parent_env / parent_env_name (required for serialization and # image lookup) and appear in env.tasks (required for deployment). - env = env.from_task(env.name, *self._tasks.values()) + for _op_task in self._tasks.values(): + self.env.add_dependency(flyte.TaskEnvironment.from_task(_op_task.name, _op_task)) # Find the first call frame outside this module so the Flyte task is # registered under the user's module, not dag.py. The @@ -189,7 +179,7 @@ async def _run_chain(task): _dag_entry.__qualname__ = f"dag_{self.dag_id}" _dag_entry.__module__ = _caller_module_name - self.flyte_task = env.task(_dag_entry) + self.flyte_task = self.env.task(_dag_entry) # Inject the task into the caller's module so DefaultTaskResolver # can find it via getattr(module, task_name) on both local and remote. diff --git a/plugins/airflow/src/flyteplugins/airflow/task.py b/plugins/airflow/src/flyteplugins/airflow/task.py index a63e7c2c6..4d5fecf5a 100644 --- a/plugins/airflow/src/flyteplugins/airflow/task.py +++ b/plugins/airflow/src/flyteplugins/airflow/task.py @@ -13,6 +13,8 @@ import jsonpickle import airflow.triggers.base as airflow_triggers import airflow.utils.context as airflow_context +from airflow.operators.bash import BashOperator +from airflow.operators.python import PythonOperator import flyte from flyte import logger, get_custom_context @@ -20,6 +22,7 @@ from flyte._internal.controllers import get_controller from flyte._internal.controllers._local_controller import _TaskRunner from flyte._internal.resolvers.common import Resolver +from flyte._module import extract_obj_module from flyte._task import TaskTemplate from flyte.extend import AsyncFunctionTaskTemplate, TaskPluginRegistry from flyte.models import SerializationContext, NativeInterface @@ -32,7 +35,7 @@ @dataclass -class AirflowObj(object): +class AirflowTaskMetadata(object): """ This class is used to store the Airflow task configuration. It is serialized and stored in the Flyte task config. It can be trigger, hook, operator or sensor. For example: @@ -64,26 +67,91 @@ def load_task(self, loader_args: typing.List[str]) -> AsyncFunctionTaskTemplate: """ This method is used to load an Airflow task. """ - _, task_module, _, task_name, _, task_config = loader_args - task_module = importlib.import_module(name=task_module) # type: ignore - task_def = getattr(task_module, task_name) - return task_def(name=task_name, task_config=jsonpickle.decode(task_config)) + _, airflow_task_module, _, airflow_task_name, _, airflow_task_parameters, _, func_module, _, func_name = loader_args + func_module = importlib.import_module(name=func_module) + func_def = getattr(func_module, func_name) + return AirflowFunctionTask( + name=airflow_task_name, + airflow_task_metadata=AirflowTaskMetadata( + module=airflow_task_module, + name=airflow_task_name, + parameters=jsonpickle.decode(airflow_task_parameters) + ), + func=func_def, + ) - def loader_args(self, task: AsyncFunctionTaskTemplate, root_dir: Path) -> List[str]: # type:ignore + def loader_args(self, task: "AirflowFunctionTask", root_dir: Path) -> List[str]: # type:ignore + entity_module_name, _ = extract_obj_module(task.func, root_dir) return [ - "task-module", - task.__module__, - "task-name", - task.__class__.__name__, - "task-config", - jsonpickle.encode(task.plugin_config), + "airflow-task-module", + task.airflow_task_metadata.module, + "airflow-task-name", + task.airflow_task_metadata.name, + "airflow-task-parameters", + jsonpickle.encode(task.airflow_task_metadata.parameters), + "airflow-func-module", + entity_module_name, + "airflow-func-name", + task.func.__name__ ] -airflow_task_resolver = AirflowTaskResolver() +class AirflowRawContainerTask(TaskTemplate): + """ + Running Bash command in the container. + """ + + def __init__( + self, + name: str, + airflow_task_metadata: AirflowTaskMetadata, + command: str, + # inputs: Optional[Dict[str, Type]] = None, + **kwargs, + ): + super().__init__( + name=name, + interface=NativeInterface(inputs={}, outputs={}), + **kwargs, + ) + self.resolver = AirflowTaskResolver() + self._airflow_task_metadata = airflow_task_metadata + self._command = command + self._call_as_synchronous = True + self._downstream_flyte_tasks: List["AirflowFunctionTask"] = [] + + # ------------------------------------------------------------------ + # Airflow dependency-arrow support (>> / <<) + # Records the dependency in the active FlyteDAG if one is being built. + # ------------------------------------------------------------------ + + def __rshift__(self, other: "AirflowFunctionTask") -> "AirflowFunctionTask": + """``self >> other`` — other runs after self.""" + if _dag_module._current_flyte_dag is not None: + _dag_module._current_flyte_dag.set_dependency(self.name, other.name) + return other + + def __lshift__(self, other: "AirflowFunctionTask") -> "AirflowFunctionTask": + """``self << other`` — self runs after other.""" + if _dag_module._current_flyte_dag is not None: + _dag_module._current_flyte_dag.set_dependency(other.name, self.name) + return other + + def container_args(self, sctx: SerializationContext) -> List[str]: + return self._command.split() + + async def execute(self, **kwargs) -> Any: + # ExecutorSafeguard stores a sentinel in a threading.local() dict. That + # dict is initialised on the main thread at import time, but tasks may + # run in a background thread where the thread-local has no 'callers' key. + from airflow.models.baseoperator import ExecutorSafeguard + if not hasattr(ExecutorSafeguard._sentinel, "callers"): + ExecutorSafeguard._sentinel.callers = {} + logger.info("Executing Airflow task") + return _get_airflow_instance(self._airflow_task_metadata).execute(context=airflow_context.Context()) -class AirflowContainerTask(TaskTemplate): +class AirflowFunctionTask(AsyncFunctionTaskTemplate): """ This python container task is used to wrap an Airflow task. It is used to run an Airflow task in a container. The airflow task module, name and parameters are stored in the task config. @@ -95,33 +163,35 @@ class AirflowContainerTask(TaskTemplate): def __init__( self, name: str, - plugin_config: AirflowObj, + airflow_task_metadata: AirflowTaskMetadata, + func: Optional[callable], # inputs: Optional[Dict[str, Type]] = None, **kwargs, ): super().__init__( name=name, # plugin_config=plugin_config, + func=func, interface=NativeInterface(inputs={}, outputs={}), **kwargs, ) - self._task_resolver = airflow_task_resolver - self._plugin_config = plugin_config + self.resolver = AirflowTaskResolver() + self.airflow_task_metadata = airflow_task_metadata self._call_as_synchronous = True - self._downstream_flyte_tasks: List["AirflowContainerTask"] = [] + self._downstream_flyte_tasks: List["AirflowFunctionTask"] = [] # ------------------------------------------------------------------ # Airflow dependency-arrow support (>> / <<) # Records the dependency in the active FlyteDAG if one is being built. # ------------------------------------------------------------------ - def __rshift__(self, other: "AirflowContainerTask") -> "AirflowContainerTask": + def __rshift__(self, other: "AirflowFunctionTask") -> "AirflowFunctionTask": """``self >> other`` — other runs after self.""" if _dag_module._current_flyte_dag is not None: _dag_module._current_flyte_dag.set_dependency(self.name, other.name) return other - def __lshift__(self, other: "AirflowContainerTask") -> "AirflowContainerTask": + def __lshift__(self, other: "AirflowFunctionTask") -> "AirflowFunctionTask": """``self << other`` — self runs after other.""" if _dag_module._current_flyte_dag is not None: _dag_module._current_flyte_dag.set_dependency(other.name, self.name) @@ -135,48 +205,19 @@ async def execute(self, **kwargs) -> Any: if not hasattr(ExecutorSafeguard._sentinel, "callers"): ExecutorSafeguard._sentinel.callers = {} logger.info("Executing Airflow task") - _get_airflow_instance(self._plugin_config).execute(context=airflow_context.Context()) + self.airflow_task_metadata.parameters["python_callable"] = self.func + return _get_airflow_instance(self.airflow_task_metadata).execute(context=airflow_context.Context()) def _get_airflow_instance( - airflow_obj: AirflowObj, + airflow_task_metadata: AirflowTaskMetadata, ) -> typing.Union[airflow_models.BaseOperator, airflow_sensors.BaseSensorOperator, airflow_triggers.BaseTrigger]: # Set the GET_ORIGINAL_TASK attribute to True so that obj_def will return the original # airflow task instead of the Flyte task. with flyte.custom_context(GET_ORIGINAL_TASK="True"): - - obj_module = importlib.import_module(name=airflow_obj.module) - obj_def = getattr(obj_module, airflow_obj.name) - if _is_deferrable(obj_def): - try: - return obj_def(**airflow_obj.parameters, deferrable=True) - except airflow.exceptions.AirflowException as e: - logger.debug(f"Failed to create operator {airflow_obj.name} with err: {e}.") - logger.debug(f"Airflow operator {airflow_obj.name} does not support deferring.") - - return obj_def(**airflow_obj.parameters) - - -def _is_deferrable(cls: Type) -> bool: - """ - This function is used to check if the Airflow operator is deferrable. - If the operator is not deferrable, we run it in a container instead of the connector. - """ - # Only Airflow operators are deferrable. - if not issubclass(cls, airflow_models.BaseOperator): - return False - # Airflow sensors are not deferrable. The Sensor is a subclass of BaseOperator. - if issubclass(cls, airflow_sensors.BaseSensorOperator): - return False - try: - from airflow.providers.apache.beam.operators.beam import BeamBasePipelineOperator - - # Dataflow operators are not deferrable. - if issubclass(cls, BeamBasePipelineOperator): - return False - except ImportError: - logger.debug("Failed to import BeamBasePipelineOperator") - return True + obj_module = importlib.import_module(name=airflow_task_metadata.module) + obj_def = getattr(obj_module, airflow_task_metadata.name) + return obj_def(**airflow_task_metadata.parameters) def _flyte_operator(*args, **kwargs): @@ -197,9 +238,17 @@ def _flyte_operator(*args, **kwargs): container_image = kwargs.pop("container_image", None) task_id = kwargs.get("task_id", cls.__name__) - config = AirflowObj(module=cls.__module__, name=cls.__name__, parameters=kwargs) - - task = AirflowContainerTask(name=task_id, plugin_config=config, image=container_image) + airflow_task_metadata = AirflowTaskMetadata(module=cls.__module__, name=cls.__name__, parameters=kwargs) + + if cls == BashOperator: + command = kwargs.get("bash_command", "") + task = AirflowRawContainerTask(name=task_id, airflow_task_metadata=airflow_task_metadata, command=command) + elif cls == PythonOperator: + func = kwargs.get("python_callable", None) + kwargs.pop("python_callable", None) + task = AirflowFunctionTask(name=task_id, airflow_task_metadata=airflow_task_metadata, func=func, image=container_image) + else: + raise ValueError(f"Unsupported Airflow operator: {cls.__name__}") # ── Case 1: inside a ``with DAG(...) as dag:`` block ──────────────────── # Register the task with the active FlyteDAG collector so it can be wired @@ -218,4 +267,4 @@ def _flyte_operator(*args, **kwargs): # Monkey patches the Airflow operator. Instead of creating an airflow task, it returns a Flyte task. -airflow_models.BaseOperator.__new__ = _flyte_operator \ No newline at end of file +airflow_models.BaseOperator.__new__ = _flyte_operator diff --git a/src/flyte/_internal/resolvers/common.py b/src/flyte/_internal/resolvers/common.py index d2b73a6c3..4dead7390 100644 --- a/src/flyte/_internal/resolvers/common.py +++ b/src/flyte/_internal/resolvers/common.py @@ -30,7 +30,7 @@ def load_app_env(self, loader_args: str) -> AppEnvironment: """ raise NotImplementedError - def loader_args(self, t: TaskTemplate, root_dir: Optional[Path]) -> List[str] | str: + def loader_args(self, task: TaskTemplate, root_dir: Optional[Path]) -> List[str] | str: """ Return a list of strings that can help identify the parameter TaskTemplate. Each string should not have spaces or special characters. This is used to identify the task in the resolver. diff --git a/src/flyte/_task.py b/src/flyte/_task.py index d141d1c06..ecbb146a0 100644 --- a/src/flyte/_task.py +++ b/src/flyte/_task.py @@ -1,6 +1,7 @@ from __future__ import annotations import asyncio +import functools import weakref from dataclasses import dataclass, field, replace from inspect import iscoroutinefunction @@ -41,7 +42,7 @@ if TYPE_CHECKING: from flyteidl2.core.tasks_pb2 import DataLoadingConfig - + from ._internal.resolvers.common import Resolver from ._task_environment import TaskEnvironment P = ParamSpec("P") # capture the function's parameters @@ -111,6 +112,7 @@ def my_task(): report: bool = False queue: Optional[str] = None debuggable: bool = False + resolver: Optional[Resolver] = None parent_env: Optional[weakref.ReferenceType[TaskEnvironment]] = None parent_env_name: Optional[str] = None @@ -147,6 +149,9 @@ def __post_init__(self): # If short_name is not set, use the name of the task self.short_name = self.name + from ._internal.resolvers.default import DefaultTaskResolver + self.resolver = self.resolver or DefaultTaskResolver() + def __getstate__(self): """ This method is called when the object is pickled. We need to remove the parent_env reference @@ -548,14 +553,12 @@ def container_args(self, serialize_context: SerializationContext) -> List[str]: if not serialize_context.code_bundle or not serialize_context.code_bundle.pkl: # If we do not have a code bundle, or if we have one, but it is not a pkl, we need to add the resolver - from flyte._internal.resolvers.default import DefaultTaskResolver - if not serialize_context.root_dir: raise RuntimeSystemError( "SerializationError", "Root dir is required for default task resolver when no code bundle is provided.", ) - _task_resolver = DefaultTaskResolver() + _task_resolver = self.resolver args = [ *args, *[ diff --git a/tests/flyte/test_airflow_dag.py b/tests/flyte/test_airflow_dag.py index b32c71963..3379f0940 100644 --- a/tests/flyte/test_airflow_dag.py +++ b/tests/flyte/test_airflow_dag.py @@ -23,7 +23,7 @@ def test_flyte_env_image_preserved_after_dag_build(): environment from the operator tasks' images (all None), silently discarding the user-supplied image. """ - from flyteplugins.airflow.task import AirflowContainerTask # applies patches + from flyteplugins.airflow.task import AirflowFunctionTask # applies patches from airflow import DAG from airflow.operators.bash import BashOperator @@ -51,7 +51,7 @@ def test_flyte_env_image_preserved_after_dag_build(): def test_operator_tasks_registered_in_env(monkeypatch): """Operator tasks must appear in env.tasks so they are included in deployment.""" - from flyteplugins.airflow.task import AirflowContainerTask # applies patches + from flyteplugins.airflow.task import AirflowFunctionTask # applies patches from airflow import DAG from airflow.operators.bash import BashOperator @@ -75,7 +75,7 @@ def test_operator_tasks_registered_in_env(monkeypatch): def test_default_env_created_when_flyte_env_omitted(): """When flyte_env is not supplied, a default TaskEnvironment is created using the dag_id as the name and a Debian-base image with airflow packages.""" - from flyteplugins.airflow.task import AirflowContainerTask # applies patches + from flyteplugins.airflow.task import AirflowFunctionTask # applies patches from airflow import DAG from airflow.operators.bash import BashOperator From 5bc58c7c493b0dd2934547df19eeb3b58a88f703 Mon Sep 17 00:00:00 2001 From: Kevin Su Date: Fri, 27 Feb 2026 00:05:54 -0800 Subject: [PATCH 12/17] refactor Signed-off-by: Kevin Su --- .../airflow/src/flyteplugins/airflow/dag.py | 122 +++++----- .../airflow/src/flyteplugins/airflow/task.py | 211 +++++++++--------- 2 files changed, 171 insertions(+), 162 deletions(-) diff --git a/plugins/airflow/src/flyteplugins/airflow/dag.py b/plugins/airflow/src/flyteplugins/airflow/dag.py index faa717c82..4c5991e9f 100644 --- a/plugins/airflow/src/flyteplugins/airflow/dag.py +++ b/plugins/airflow/src/flyteplugins/airflow/dag.py @@ -32,13 +32,17 @@ from __future__ import annotations +import inspect import logging +import sys as _sys from collections import defaultdict -from typing import TYPE_CHECKING, Dict, List, Optional, Set +from typing import TYPE_CHECKING, Dict, List, Optional, Set, Tuple import airflow.models.dag as _airflow_dag_module if TYPE_CHECKING: + import types + from flyteplugins.airflow.task import AirflowFunctionTask, AirflowRawContainerTask AirflowTask = AirflowFunctionTask | AirflowRawContainerTask @@ -86,6 +90,56 @@ def set_dependency(self, upstream_id: str, downstream_id: str) -> None: # Flyte task construction # ------------------------------------------------------------------ + def _build_downstream_map(self) -> Dict[str, List[str]]: + """Invert ``self._upstream`` to get downstream adjacency lists.""" + downstream: Dict[str, List[str]] = defaultdict(list) + for tid, upstreams in self._upstream.items(): + for up in upstreams: + downstream[up].append(tid) + return downstream + + def _find_caller_module(self) -> Tuple[str, Optional["types.ModuleType"]]: + """Walk the call stack to find the first frame outside this module. + + The Flyte task must be registered under the *user's* module so that + ``DefaultTaskResolver`` can locate it via ``getattr(module, name)`` + on the remote worker (which re-imports the module and re-runs the + DAG definition). + """ + for fi in inspect.stack(): + mod = fi.frame.f_globals.get("__name__", "") + if mod and mod != __name__: + return mod, _sys.modules.get(mod) + return __name__, None + + def _create_dag_entry( + self, + all_tasks: Dict[str, "AirflowTask"], + downstream_map: Dict[str, List[str]], + root_tasks: List["AirflowTask"], + ): + """Build the async entry function that orchestrates task execution.""" + # Snapshot to avoid capturing mutable references in the closure. + root_snapshot = list(root_tasks) + downstream_snapshot = dict(downstream_map) + + async def _dag_entry() -> None: + import asyncio + + async def _run_chain(task): + await task.aio() + ds = downstream_snapshot.get(task.name, []) + if ds: + await asyncio.gather(*[_run_chain(all_tasks[d]) for d in ds]) + + await asyncio.gather(*[_run_chain(t) for t in root_snapshot]) + + caller_module_name, caller_module = self._find_caller_module() + _dag_entry.__name__ = f"dag_{self.dag_id}" + _dag_entry.__qualname__ = f"dag_{self.dag_id}" + _dag_entry.__module__ = caller_module_name + return _dag_entry, caller_module + def build(self) -> None: """Create a Flyte workflow task whose entry function runs all operator tasks in dependency order. @@ -105,15 +159,10 @@ def build(self) -> None: "apache-airflow<3.0.0", "jsonpickle" ).with_local_v2(), ) - env = self.env - # Build the downstream map from the upstream map. - downstream: Dict[str, List[str]] = defaultdict(list) - for tid, upstreams in self._upstream.items(): - for up in upstreams: - downstream[up].append(tid) + downstream = self._build_downstream_map() - # Annotate each AirflowContainerTask with its downstream tasks. + # Annotate each task with its downstream tasks. for tid, task in self._tasks.items(): task._downstream_flyte_tasks = [ self._tasks[d] for d in downstream[tid] if d in self._tasks @@ -126,65 +175,24 @@ def build(self) -> None: if len(ups) == 0 ] - # Snapshot to avoid capturing mutable references in the closure. - root_snapshot = list(root_tasks) - - # Capture the full dependency graph so _dag_entry can orchestrate - # all tasks itself. In remote execution each sub-task is resolved - # independently and loses its _downstream_flyte_tasks references, - # so the entry function must drive the execution order. - all_tasks = dict(self._tasks) - downstream_snapshot = dict(downstream) - - async def _dag_entry() -> None: - import asyncio - - async def _run_chain(task): - await task.aio() - ds = downstream_snapshot.get(task.name, []) - if ds: - await asyncio.gather(*[_run_chain(all_tasks[d]) for d in ds]) - - await asyncio.gather(*[_run_chain(t) for t in root_snapshot]) + _dag_entry, caller_module = self._create_dag_entry( + all_tasks=dict(self._tasks), + downstream_map=downstream, + root_tasks=root_tasks, + ) + # Set image and register operator tasks with the DAG's TaskEnvironment. for _op_task in self._tasks.values(): if _op_task.image is None: _op_task.image = self.env.image - - # Register all operator tasks with the DAG's TaskEnvironment so that - # they get parent_env / parent_env_name (required for serialization and - # image lookup) and appear in env.tasks (required for deployment). - for _op_task in self._tasks.values(): self.env.add_dependency(flyte.TaskEnvironment.from_task(_op_task.name, _op_task)) - # Find the first call frame outside this module so the Flyte task is - # registered under the user's module, not dag.py. The - # DefaultTaskResolver records (module, name) at submission time and - # on the remote worker it imports that module — which re-runs the - # DAG definition and re-injects the task — before calling - # getattr(module, task_name). - import inspect - import sys as _sys - - _caller_module_name = __name__ - _caller_module = None - for _fi in inspect.stack(): - _mod = _fi.frame.f_globals.get("__name__", "") - if _mod and _mod != __name__: - _caller_module_name = _mod - _caller_module = _sys.modules.get(_caller_module_name) - break - - _dag_entry.__name__ = f"dag_{self.dag_id}" - _dag_entry.__qualname__ = f"dag_{self.dag_id}" - _dag_entry.__module__ = _caller_module_name - self.flyte_task = self.env.task(_dag_entry) # Inject the task into the caller's module so DefaultTaskResolver # can find it via getattr(module, task_name) on both local and remote. - if _caller_module is not None: - setattr(_caller_module, _dag_entry.__name__, self.flyte_task) + if caller_module is not None: + setattr(caller_module, _dag_entry.__name__, self.flyte_task) # --------------------------------------------------------------------------- diff --git a/plugins/airflow/src/flyteplugins/airflow/task.py b/plugins/airflow/src/flyteplugins/airflow/task.py index 4d5fecf5a..7ca99807e 100644 --- a/plugins/airflow/src/flyteplugins/airflow/task.py +++ b/plugins/airflow/src/flyteplugins/airflow/task.py @@ -1,52 +1,47 @@ import importlib import logging -import os -import threading import typing from dataclasses import dataclass -from typing import Any, Dict, Optional, Type, List - -import airflow from pathlib import Path +from typing import Any, List, Optional + import airflow.models as airflow_models import airflow.sensors.base as airflow_sensors -import jsonpickle import airflow.triggers.base as airflow_triggers import airflow.utils.context as airflow_context +import jsonpickle from airflow.operators.bash import BashOperator from airflow.operators.python import PythonOperator import flyte -from flyte import logger, get_custom_context -from flyte._context import internal_ctx, root_context_var -from flyte._internal.controllers import get_controller -from flyte._internal.controllers._local_controller import _TaskRunner +from flyte import get_custom_context, logger +from flyte._context import internal_ctx from flyte._internal.resolvers.common import Resolver from flyte._module import extract_obj_module from flyte._task import TaskTemplate -from flyte.extend import AsyncFunctionTaskTemplate, TaskPluginRegistry -from flyte.models import SerializationContext, NativeInterface - -# Per-thread _TaskRunner instances used by _flyte_operator for sync blocking submission. -_airflow_runners: Dict[str, _TaskRunner] = {} +from flyte.extend import AsyncFunctionTaskTemplate +from flyte.models import NativeInterface, SerializationContext # Import dag module to apply DAG monkey-patches when this module is imported. from flyteplugins.airflow import dag as _dag_module # noqa: E402 +# --------------------------------------------------------------------------- +# Data models +# --------------------------------------------------------------------------- + @dataclass -class AirflowTaskMetadata(object): - """ - This class is used to store the Airflow task configuration. It is serialized and stored in the Flyte task config. - It can be trigger, hook, operator or sensor. For example: +class AirflowTaskMetadata: + """Stores the Airflow operator class location and constructor kwargs. + + For example, given:: - from airflow.sensors.filesystem import FileSensor - sensor = FileSensor(task_id="id", filepath="/tmp/1234") + FileSensor(task_id="id", filepath="/tmp/1234") - In this case, the attributes of AirflowObj will be: - module: airflow.sensors.filesystem - name: FileSensor - parameters: {"task_id": "id", "filepath": "/tmp/1234"} + the fields would be: + module: "airflow.sensors.filesystem" + name: "FileSensor" + parameters: {"task_id": "id", "filepath": "/tmp/1234"} """ module: str @@ -54,9 +49,15 @@ class AirflowTaskMetadata(object): parameters: typing.Dict[str, Any] +# --------------------------------------------------------------------------- +# Resolver +# --------------------------------------------------------------------------- + class AirflowTaskResolver(Resolver): - """ - This class is used to resolve an Airflow task. It will load an airflow task in the container. + """Resolves an AirflowFunctionTask on the remote worker. + + The resolver records the Airflow operator metadata and the wrapped Python + callable so that the task can be reconstructed from loader args alone. """ @property @@ -64,9 +65,6 @@ def import_path(self) -> str: return "flyteplugins.airflow.task.AirflowTaskResolver" def load_task(self, loader_args: typing.List[str]) -> AsyncFunctionTaskTemplate: - """ - This method is used to load an Airflow task. - """ _, airflow_task_module, _, airflow_task_name, _, airflow_task_parameters, _, func_module, _, func_name = loader_args func_module = importlib.import_module(name=func_module) func_def = getattr(func_module, func_name) @@ -96,34 +94,24 @@ def loader_args(self, task: "AirflowFunctionTask", root_dir: Path) -> List[str]: ] -class AirflowRawContainerTask(TaskTemplate): - """ - Running Bash command in the container. +# --------------------------------------------------------------------------- +# Shared task behaviour (mixin) +# --------------------------------------------------------------------------- + +class _AirflowTaskMixin: + """Shared behaviour for both raw-container and function Airflow tasks. + + Provides Airflow-style dependency arrows (``>>`` / ``<<``) and the + ``ExecutorSafeguard`` workaround needed when tasks run on background threads. """ - def __init__( - self, - name: str, - airflow_task_metadata: AirflowTaskMetadata, - command: str, - # inputs: Optional[Dict[str, Type]] = None, - **kwargs, - ): - super().__init__( - name=name, - interface=NativeInterface(inputs={}, outputs={}), - **kwargs, - ) + def _init_airflow_mixin(self) -> None: self.resolver = AirflowTaskResolver() - self._airflow_task_metadata = airflow_task_metadata - self._command = command self._call_as_synchronous = True self._downstream_flyte_tasks: List["AirflowFunctionTask"] = [] - # ------------------------------------------------------------------ # Airflow dependency-arrow support (>> / <<) # Records the dependency in the active FlyteDAG if one is being built. - # ------------------------------------------------------------------ def __rshift__(self, other: "AirflowFunctionTask") -> "AirflowFunctionTask": """``self >> other`` — other runs after self.""" @@ -137,27 +125,59 @@ def __lshift__(self, other: "AirflowFunctionTask") -> "AirflowFunctionTask": _dag_module._current_flyte_dag.set_dependency(other.name, self.name) return other - def container_args(self, sctx: SerializationContext) -> List[str]: - return self._command.split() + @staticmethod + def _patch_executor_safeguard() -> None: + """Ensure ExecutorSafeguard's thread-local has a ``callers`` dict. - async def execute(self, **kwargs) -> Any: - # ExecutorSafeguard stores a sentinel in a threading.local() dict. That - # dict is initialised on the main thread at import time, but tasks may - # run in a background thread where the thread-local has no 'callers' key. + ExecutorSafeguard stores a sentinel in a ``threading.local()`` dict + that is initialised on the main thread at import time. Tasks may run + on a background thread where the thread-local has no ``callers`` key. + """ from airflow.models.baseoperator import ExecutorSafeguard if not hasattr(ExecutorSafeguard._sentinel, "callers"): ExecutorSafeguard._sentinel.callers = {} + + +# --------------------------------------------------------------------------- +# Task classes +# --------------------------------------------------------------------------- + +class AirflowRawContainerTask(_AirflowTaskMixin, TaskTemplate): + """Wraps an Airflow BashOperator as a Flyte raw-container task.""" + + def __init__( + self, + name: str, + airflow_task_metadata: AirflowTaskMetadata, + command: str, + **kwargs, + ): + super().__init__( + name=name, + interface=NativeInterface(inputs={}, outputs={}), + **kwargs, + ) + self._init_airflow_mixin() + self._airflow_task_metadata = airflow_task_metadata + self._command = command + + def container_args(self, sctx: SerializationContext) -> List[str]: + return self._command.split() + + async def execute(self, **kwargs) -> Any: + self._patch_executor_safeguard() logger.info("Executing Airflow task") return _get_airflow_instance(self._airflow_task_metadata).execute(context=airflow_context.Context()) -class AirflowFunctionTask(AsyncFunctionTaskTemplate): - """ - This python container task is used to wrap an Airflow task. It is used to run an Airflow task in a container. - The airflow task module, name and parameters are stored in the task config. +class AirflowFunctionTask(_AirflowTaskMixin, AsyncFunctionTaskTemplate): + """Wraps an Airflow PythonOperator as a Flyte function task. - Some of the Airflow operators are not deferrable, For example, BeamRunJavaPipelineOperator, BeamRunPythonPipelineOperator. - These tasks don't have an async method to get the job status, so cannot be used in the Flyte connector. We run these tasks in a container. + The airflow task module, name, and parameters are stored in the task + config. Some Airflow operators are not deferrable (e.g. + ``BeamRunJavaPipelineOperator``). These tasks lack an async method to + poll job status so they cannot use the Flyte connector — we run them in + a container instead. """ def __init__( @@ -165,75 +185,59 @@ def __init__( name: str, airflow_task_metadata: AirflowTaskMetadata, func: Optional[callable], - # inputs: Optional[Dict[str, Type]] = None, **kwargs, ): super().__init__( name=name, - # plugin_config=plugin_config, func=func, interface=NativeInterface(inputs={}, outputs={}), **kwargs, ) - self.resolver = AirflowTaskResolver() + self._init_airflow_mixin() self.airflow_task_metadata = airflow_task_metadata - self._call_as_synchronous = True - self._downstream_flyte_tasks: List["AirflowFunctionTask"] = [] - - # ------------------------------------------------------------------ - # Airflow dependency-arrow support (>> / <<) - # Records the dependency in the active FlyteDAG if one is being built. - # ------------------------------------------------------------------ - - def __rshift__(self, other: "AirflowFunctionTask") -> "AirflowFunctionTask": - """``self >> other`` — other runs after self.""" - if _dag_module._current_flyte_dag is not None: - _dag_module._current_flyte_dag.set_dependency(self.name, other.name) - return other - - def __lshift__(self, other: "AirflowFunctionTask") -> "AirflowFunctionTask": - """``self << other`` — self runs after other.""" - if _dag_module._current_flyte_dag is not None: - _dag_module._current_flyte_dag.set_dependency(other.name, self.name) - return other async def execute(self, **kwargs) -> Any: - # ExecutorSafeguard stores a sentinel in a threading.local() dict. That - # dict is initialised on the main thread at import time, but tasks may - # run in a background thread where the thread-local has no 'callers' key. - from airflow.models.baseoperator import ExecutorSafeguard - if not hasattr(ExecutorSafeguard._sentinel, "callers"): - ExecutorSafeguard._sentinel.callers = {} + self._patch_executor_safeguard() logger.info("Executing Airflow task") self.airflow_task_metadata.parameters["python_callable"] = self.func return _get_airflow_instance(self.airflow_task_metadata).execute(context=airflow_context.Context()) +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + def _get_airflow_instance( airflow_task_metadata: AirflowTaskMetadata, ) -> typing.Union[airflow_models.BaseOperator, airflow_sensors.BaseSensorOperator, airflow_triggers.BaseTrigger]: - # Set the GET_ORIGINAL_TASK attribute to True so that obj_def will return the original - # airflow task instead of the Flyte task. + """Instantiate the original Airflow operator from its metadata.""" + # Set GET_ORIGINAL_TASK so that obj_def returns the real Airflow + # operator instead of being intercepted by _flyte_operator. with flyte.custom_context(GET_ORIGINAL_TASK="True"): obj_module = importlib.import_module(name=airflow_task_metadata.module) obj_def = getattr(obj_module, airflow_task_metadata.name) return obj_def(**airflow_task_metadata.parameters) +# --------------------------------------------------------------------------- +# Operator intercept (monkey-patch) +# --------------------------------------------------------------------------- + def _flyte_operator(*args, **kwargs): - """ - This function is called by the Airflow operator to create a new task. We intercept this call and return a Flyte - task instead. + """Intercept Airflow operator construction and return a Flyte task instead. + + Called via the monkey-patched ``BaseOperator.__new__``. Depending on + context this either registers the task with an active FlyteDAG, submits + it as a sub-task during execution, or returns the task object for later + serialization. """ cls = args[0] try: if get_custom_context().get("GET_ORIGINAL_TASK", "False") == "True": - # Return an original task when running in the connector. - print("Returning original Airflow task") + logger.debug("Returning original Airflow task") return object.__new__(cls) except AssertionError: # This happens when the task is created in the dynamic workflow. - # We don't need to return the original task in this case. logging.debug("failed to get the attribute GET_ORIGINAL_TASK from user space params") container_image = kwargs.pop("container_image", None) @@ -250,21 +254,18 @@ def _flyte_operator(*args, **kwargs): else: raise ValueError(f"Unsupported Airflow operator: {cls.__name__}") - # ── Case 1: inside a ``with DAG(...) as dag:`` block ──────────────────── - # Register the task with the active FlyteDAG collector so it can be wired - # into the Flyte workflow when the DAG context exits. Do NOT execute yet. + # Case 1: inside a ``with DAG(...) as dag:`` block — register with FlyteDAG. if _dag_module._current_flyte_dag is not None: _dag_module._current_flyte_dag.add_task(task_id, task) return task - # ── Case 2: inside a Flyte task execution ─────────────────────────────── - # The dag workflow function is executing; submit the operator as a sub-task. + # Case 2: inside a Flyte task execution — submit the operator as a sub-task. if internal_ctx().is_task_context(): return task() - # ── Case 3: outside any context (e.g. serialization / import scan) ────── + # Case 3: outside any context (e.g. serialization / import scan). return task -# Monkey patches the Airflow operator. Instead of creating an airflow task, it returns a Flyte task. +# Monkey-patch: intercept Airflow operator construction. airflow_models.BaseOperator.__new__ = _flyte_operator From 53d410804e08bb09b0865ac74e00bacb3f7002b2 Mon Sep 17 00:00:00 2001 From: Kevin Su Date: Fri, 27 Feb 2026 00:26:14 -0800 Subject: [PATCH 13/17] Rewrite airflow plugin for clarity and maintainability - Extract _AirflowTaskMixin to DRY shared behavior (dependency arrows, ExecutorSafeguard workaround, common init) between task classes - Rename classes for clarity: AirflowRawContainerTask -> AirflowShellTask, AirflowFunctionTask -> AirflowPythonFunctionTask, AirflowTaskResolver -> AirflowPythonTaskResolver - Break FlyteDAG.build() into smaller methods: _build_downstream_map(), _find_caller_module(), _create_dag_entry() - Replace global statement with module-level _state dict - Remove unused imports, dead code (_downstream_flyte_tasks, commented-out params), and replace print() with logger.debug() - Consolidate two iteration loops in build() into one - Update README with current API and quick-start example Co-Authored-By: Claude Opus 4.6 Signed-off-by: Kevin Su --- examples/airflow-migration/basic.py | 13 +-- plugins/airflow/README.md | 39 +++++++- .../airflow/src/flyteplugins/airflow/dag.py | 40 +++----- .../airflow/src/flyteplugins/airflow/task.py | 93 ++++++++++--------- tests/flyte/test_airflow_dag.py | 22 ++--- 5 files changed, 113 insertions(+), 94 deletions(-) diff --git a/examples/airflow-migration/basic.py b/examples/airflow-migration/basic.py index e76a511b3..39b4644ff 100644 --- a/examples/airflow-migration/basic.py +++ b/examples/airflow-migration/basic.py @@ -1,9 +1,10 @@ from pathlib import Path -from flyteplugins.airflow.task import AirflowFunctionTask # triggers DAG + operator patches # type: ignore +import flyteplugins.airflow.task # noqa: F401 — triggers DAG + operator monkey-patches from airflow import DAG from airflow.operators.bash import BashOperator from airflow.operators.python import PythonOperator + import flyte @@ -14,24 +15,24 @@ def hello_python(): # Standard Airflow DAG definition — no Flyte-specific changes needed inside the block. # Pass flyte_env so the generated workflow task uses the right container image. with DAG( - dag_id='simple_airflow_workflow', + dag_id="simple_airflow_workflow", ) as dag: t1 = BashOperator( - task_id='say_hello', + task_id="say_hello", bash_command='echo "Hello Airflow!"', ) t2 = BashOperator( - task_id='say_goodbye', + task_id="say_goodbye", bash_command='echo "Goodbye Airflow!"', ) t3 = PythonOperator( - task_id='hello_python', + task_id="hello_python", python_callable=hello_python, ) t1 >> t2 # t2 runs after t1 -if __name__ == '__main__': +if __name__ == "__main__": flyte.init_from_config(root_dir=Path(__file__).parent.parent.parent) run = flyte.with_runcontext(mode="remote", log_level="10").run(dag) print(run.url) diff --git a/plugins/airflow/README.md b/plugins/airflow/README.md index 95349163b..38f3d3ba3 100644 --- a/plugins/airflow/README.md +++ b/plugins/airflow/README.md @@ -1,12 +1,41 @@ # Flyte Airflow Plugin -Airflow plugin allows you to seamlessly run Airflow tasks in the Flyte workflow without changing any code. +Run existing Airflow DAGs on Flyte with minimal code changes. The plugin +monkey-patches `airflow.DAG` and `BaseOperator` so that standard Airflow +definitions are transparently converted into Flyte tasks. -- Compile Airflow tasks to Flyte tasks -- Use Airflow sensors/operators in Flyte workflows -- Add support for running Airflow tasks locally without running a cluster +## Features +- Write a normal `with DAG(...) as dag:` block — the plugin intercepts + operator construction and wires everything into a Flyte workflow. +- Supports `BashOperator` (`AirflowShellTask`) and `PythonOperator` + (`AirflowPythonFunctionTask`). +- Dependency arrows (`>>`, `<<`) are preserved as execution order. +- Runs locally or remotely — no Airflow cluster required. + +## Installation ```bash -pip install --pre flyteplugins-airflow +pip install flyteplugins-airflow ``` + +## Quick start + +```python +import flyteplugins.airflow.task # triggers DAG + operator monkey-patches +from airflow import DAG +from airflow.operators.bash import BashOperator +import flyte + +with DAG(dag_id="my_dag") as dag: + t1 = BashOperator(task_id="step1", bash_command="echo step1") + t2 = BashOperator(task_id="step2", bash_command="echo step2") + t1 >> t2 + +if __name__ == "__main__": + flyte.init_from_config() + run = flyte.with_runcontext(mode="remote").run(dag) + print(run.url) +``` + +See `examples/airflow-migration/` for a full example including `PythonOperator`. diff --git a/plugins/airflow/src/flyteplugins/airflow/dag.py b/plugins/airflow/src/flyteplugins/airflow/dag.py index 4c5991e9f..83094d452 100644 --- a/plugins/airflow/src/flyteplugins/airflow/dag.py +++ b/plugins/airflow/src/flyteplugins/airflow/dag.py @@ -43,9 +43,9 @@ if TYPE_CHECKING: import types - from flyteplugins.airflow.task import AirflowFunctionTask, AirflowRawContainerTask + from flyteplugins.airflow.task import AirflowPythonFunctionTask, AirflowShellTask - AirflowTask = AirflowFunctionTask | AirflowRawContainerTask + AirflowTask = AirflowPythonFunctionTask | AirflowShellTask log = logging.getLogger(__name__) @@ -53,14 +53,16 @@ # Module-level state # --------------------------------------------------------------------------- -#: Set when the code is inside a ``with DAG(...) as dag:`` block. -_current_flyte_dag: Optional["FlyteDAG"] = None +#: Mutable container for the active FlyteDAG (set inside a ``with DAG(...)`` block). +#: Using a dict avoids ``global`` statements in the patch functions. +_state: Dict[str, Optional["FlyteDAG"]] = {"current_flyte_dag": None} # --------------------------------------------------------------------------- -# FlyteDAG – collects operators and builds the Flyte workflow task +# FlyteDAG - collects operators and builds the Flyte workflow task # --------------------------------------------------------------------------- + class FlyteDAG: """Collects Airflow operators during a DAG definition and converts them into a single Flyte task that runs them in dependency order.""" @@ -155,25 +157,15 @@ def build(self) -> None: if self.env is None: self.env = flyte.TaskEnvironment( name=self.dag_id, - image=flyte.Image.from_debian_base().with_pip_packages( - "apache-airflow<3.0.0", "jsonpickle" - ).with_local_v2(), + image=flyte.Image.from_debian_base() + .with_pip_packages("apache-airflow<3.0.0", "jsonpickle") + .with_local_v2(), ) downstream = self._build_downstream_map() - # Annotate each task with its downstream tasks. - for tid, task in self._tasks.items(): - task._downstream_flyte_tasks = [ - self._tasks[d] for d in downstream[tid] if d in self._tasks - ] - # Root tasks: those with no upstream dependencies. - root_tasks = [ - self._tasks[tid] - for tid, ups in self._upstream.items() - if len(ups) == 0 - ] + root_tasks = [self._tasks[tid] for tid, ups in self._upstream.items() if len(ups) == 0] _dag_entry, caller_module = self._create_dag_entry( all_tasks=dict(self._tasks), @@ -212,22 +204,20 @@ def _patched_dag_init(self, *args, **kwargs) -> None: # type: ignore[override] def _patched_dag_enter(self): # type: ignore[override] - global _current_flyte_dag - _current_flyte_dag = FlyteDAG(dag_id=self.dag_id, env=getattr(self, "_flyte_env", None)) + _state["current_flyte_dag"] = FlyteDAG(dag_id=self.dag_id, env=getattr(self, "_flyte_env", None)) return _original_dag_enter(self) def _patched_dag_exit(self, exc_type, exc_val, exc_tb): # type: ignore[override] - global _current_flyte_dag try: - if exc_type is None and _current_flyte_dag is not None: - flyte_dag = _current_flyte_dag + if exc_type is None and _state["current_flyte_dag"] is not None: + flyte_dag = _state["current_flyte_dag"] flyte_dag.build() # Attach the Flyte task and a convenience run() to the DAG object. self.flyte_task = flyte_dag.flyte_task self.run = _make_run(flyte_dag.flyte_task) finally: - _current_flyte_dag = None + _state["current_flyte_dag"] = None return _original_dag_exit(self, exc_type, exc_val, exc_tb) diff --git a/plugins/airflow/src/flyteplugins/airflow/task.py b/plugins/airflow/src/flyteplugins/airflow/task.py index 7ca99807e..b39527bae 100644 --- a/plugins/airflow/src/flyteplugins/airflow/task.py +++ b/plugins/airflow/src/flyteplugins/airflow/task.py @@ -1,19 +1,11 @@ import importlib -import logging import typing from dataclasses import dataclass from pathlib import Path from typing import Any, List, Optional -import airflow.models as airflow_models -import airflow.sensors.base as airflow_sensors -import airflow.triggers.base as airflow_triggers -import airflow.utils.context as airflow_context -import jsonpickle -from airflow.operators.bash import BashOperator -from airflow.operators.python import PythonOperator - import flyte +import jsonpickle from flyte import get_custom_context, logger from flyte._context import internal_ctx from flyte._internal.resolvers.common import Resolver @@ -22,14 +14,21 @@ from flyte.extend import AsyncFunctionTaskTemplate from flyte.models import NativeInterface, SerializationContext -# Import dag module to apply DAG monkey-patches when this module is imported. -from flyteplugins.airflow import dag as _dag_module # noqa: E402 +import airflow.models as airflow_models +import airflow.sensors.base as airflow_sensors +import airflow.triggers.base as airflow_triggers +import airflow.utils.context as airflow_context +from airflow.operators.bash import BashOperator +from airflow.operators.python import PythonOperator +# Import dag module to apply DAG monkey-patches when this module is imported. +from flyteplugins.airflow import dag as _dag_module # --------------------------------------------------------------------------- # Data models # --------------------------------------------------------------------------- + @dataclass class AirflowTaskMetadata: """Stores the Airflow operator class location and constructor kwargs. @@ -53,8 +52,9 @@ class AirflowTaskMetadata: # Resolver # --------------------------------------------------------------------------- -class AirflowTaskResolver(Resolver): - """Resolves an AirflowFunctionTask on the remote worker. + +class AirflowPythonTaskResolver(Resolver): + """Resolves an AirflowPythonFunctionTask on the remote worker. The resolver records the Airflow operator metadata and the wrapped Python callable so that the task can be reconstructed from loader args alone. @@ -62,23 +62,25 @@ class AirflowTaskResolver(Resolver): @property def import_path(self) -> str: - return "flyteplugins.airflow.task.AirflowTaskResolver" + return "flyteplugins.airflow.task.AirflowPythonTaskResolver" def load_task(self, loader_args: typing.List[str]) -> AsyncFunctionTaskTemplate: - _, airflow_task_module, _, airflow_task_name, _, airflow_task_parameters, _, func_module, _, func_name = loader_args + _, airflow_task_module, _, airflow_task_name, _, airflow_task_parameters, _, func_module, _, func_name = ( + loader_args + ) func_module = importlib.import_module(name=func_module) func_def = getattr(func_module, func_name) - return AirflowFunctionTask( + return AirflowPythonFunctionTask( name=airflow_task_name, airflow_task_metadata=AirflowTaskMetadata( module=airflow_task_module, name=airflow_task_name, - parameters=jsonpickle.decode(airflow_task_parameters) + parameters=jsonpickle.decode(airflow_task_parameters), ), func=func_def, ) - def loader_args(self, task: "AirflowFunctionTask", root_dir: Path) -> List[str]: # type:ignore + def loader_args(self, task: "AirflowPythonFunctionTask", root_dir: Path) -> List[str]: # type:ignore entity_module_name, _ = extract_obj_module(task.func, root_dir) return [ "airflow-task-module", @@ -90,7 +92,7 @@ def loader_args(self, task: "AirflowFunctionTask", root_dir: Path) -> List[str]: "airflow-func-module", entity_module_name, "airflow-func-name", - task.func.__name__ + task.func.__name__, ] @@ -98,6 +100,7 @@ def loader_args(self, task: "AirflowFunctionTask", root_dir: Path) -> List[str]: # Shared task behaviour (mixin) # --------------------------------------------------------------------------- + class _AirflowTaskMixin: """Shared behaviour for both raw-container and function Airflow tasks. @@ -106,23 +109,21 @@ class _AirflowTaskMixin: """ def _init_airflow_mixin(self) -> None: - self.resolver = AirflowTaskResolver() self._call_as_synchronous = True - self._downstream_flyte_tasks: List["AirflowFunctionTask"] = [] # Airflow dependency-arrow support (>> / <<) # Records the dependency in the active FlyteDAG if one is being built. - def __rshift__(self, other: "AirflowFunctionTask") -> "AirflowFunctionTask": + def __rshift__(self, other: "AirflowPythonFunctionTask") -> "AirflowPythonFunctionTask": """``self >> other`` — other runs after self.""" - if _dag_module._current_flyte_dag is not None: - _dag_module._current_flyte_dag.set_dependency(self.name, other.name) + if _dag_module._state["current_flyte_dag"] is not None: + _dag_module._state["current_flyte_dag"].set_dependency(self.name, other.name) return other - def __lshift__(self, other: "AirflowFunctionTask") -> "AirflowFunctionTask": + def __lshift__(self, other: "AirflowPythonFunctionTask") -> "AirflowPythonFunctionTask": """``self << other`` — self runs after other.""" - if _dag_module._current_flyte_dag is not None: - _dag_module._current_flyte_dag.set_dependency(other.name, self.name) + if _dag_module._state["current_flyte_dag"] is not None: + _dag_module._state["current_flyte_dag"].set_dependency(other.name, self.name) return other @staticmethod @@ -134,6 +135,7 @@ def _patch_executor_safeguard() -> None: on a background thread where the thread-local has no ``callers`` key. """ from airflow.models.baseoperator import ExecutorSafeguard + if not hasattr(ExecutorSafeguard._sentinel, "callers"): ExecutorSafeguard._sentinel.callers = {} @@ -142,7 +144,8 @@ def _patch_executor_safeguard() -> None: # Task classes # --------------------------------------------------------------------------- -class AirflowRawContainerTask(_AirflowTaskMixin, TaskTemplate): + +class AirflowShellTask(_AirflowTaskMixin, TaskTemplate): """Wraps an Airflow BashOperator as a Flyte raw-container task.""" def __init__( @@ -166,11 +169,11 @@ def container_args(self, sctx: SerializationContext) -> List[str]: async def execute(self, **kwargs) -> Any: self._patch_executor_safeguard() - logger.info("Executing Airflow task") - return _get_airflow_instance(self._airflow_task_metadata).execute(context=airflow_context.Context()) + logger.info("Executing Airflow bash operator") + _get_airflow_instance(self._airflow_task_metadata).execute(context=airflow_context.Context()) -class AirflowFunctionTask(_AirflowTaskMixin, AsyncFunctionTaskTemplate): +class AirflowPythonFunctionTask(_AirflowTaskMixin, AsyncFunctionTaskTemplate): """Wraps an Airflow PythonOperator as a Flyte function task. The airflow task module, name, and parameters are stored in the task @@ -194,19 +197,20 @@ def __init__( **kwargs, ) self._init_airflow_mixin() + self.resolver = AirflowPythonTaskResolver() self.airflow_task_metadata = airflow_task_metadata async def execute(self, **kwargs) -> Any: - self._patch_executor_safeguard() - logger.info("Executing Airflow task") + logger.info("Executing Airflow python task") self.airflow_task_metadata.parameters["python_callable"] = self.func - return _get_airflow_instance(self.airflow_task_metadata).execute(context=airflow_context.Context()) + _get_airflow_instance(self.airflow_task_metadata).execute(context=airflow_context.Context()) # --------------------------------------------------------------------------- # Helpers # --------------------------------------------------------------------------- + def _get_airflow_instance( airflow_task_metadata: AirflowTaskMetadata, ) -> typing.Union[airflow_models.BaseOperator, airflow_sensors.BaseSensorOperator, airflow_triggers.BaseTrigger]: @@ -223,6 +227,7 @@ def _get_airflow_instance( # Operator intercept (monkey-patch) # --------------------------------------------------------------------------- + def _flyte_operator(*args, **kwargs): """Intercept Airflow operator construction and return a Flyte task instead. @@ -232,13 +237,9 @@ def _flyte_operator(*args, **kwargs): serialization. """ cls = args[0] - try: - if get_custom_context().get("GET_ORIGINAL_TASK", "False") == "True": - logger.debug("Returning original Airflow task") - return object.__new__(cls) - except AssertionError: - # This happens when the task is created in the dynamic workflow. - logging.debug("failed to get the attribute GET_ORIGINAL_TASK from user space params") + if get_custom_context().get("GET_ORIGINAL_TASK", "False") == "True": + logger.debug("Returning original Airflow task") + return object.__new__(cls) container_image = kwargs.pop("container_image", None) task_id = kwargs.get("task_id", cls.__name__) @@ -246,17 +247,19 @@ def _flyte_operator(*args, **kwargs): if cls == BashOperator: command = kwargs.get("bash_command", "") - task = AirflowRawContainerTask(name=task_id, airflow_task_metadata=airflow_task_metadata, command=command) + task = AirflowShellTask(name=task_id, airflow_task_metadata=airflow_task_metadata, command=command) elif cls == PythonOperator: func = kwargs.get("python_callable", None) kwargs.pop("python_callable", None) - task = AirflowFunctionTask(name=task_id, airflow_task_metadata=airflow_task_metadata, func=func, image=container_image) + task = AirflowPythonFunctionTask( + name=task_id, airflow_task_metadata=airflow_task_metadata, func=func, image=container_image + ) else: raise ValueError(f"Unsupported Airflow operator: {cls.__name__}") # Case 1: inside a ``with DAG(...) as dag:`` block — register with FlyteDAG. - if _dag_module._current_flyte_dag is not None: - _dag_module._current_flyte_dag.add_task(task_id, task) + if _dag_module._state["current_flyte_dag"] is not None: + _dag_module._state["current_flyte_dag"].add_task(task_id, task) return task # Case 2: inside a Flyte task execution — submit the operator as a sub-task. diff --git a/tests/flyte/test_airflow_dag.py b/tests/flyte/test_airflow_dag.py index 3379f0940..46eabe5b3 100644 --- a/tests/flyte/test_airflow_dag.py +++ b/tests/flyte/test_airflow_dag.py @@ -1,6 +1,8 @@ """ Tests for the Airflow DAG monkey-patch in flyteplugins.airflow.dag. """ + +import flyteplugins.airflow.task # noqa: F401 — triggers DAG + operator monkey-patches import pytest import flyte @@ -11,6 +13,7 @@ def _reset_environment_registry(): """Clean up TaskEnvironment instances created during each test.""" from flyte._environment import _ENVIRONMENT_REGISTRY + initial_len = len(_ENVIRONMENT_REGISTRY) yield del _ENVIRONMENT_REGISTRY[initial_len:] @@ -23,7 +26,6 @@ def test_flyte_env_image_preserved_after_dag_build(): environment from the operator tasks' images (all None), silently discarding the user-supplied image. """ - from flyteplugins.airflow.task import AirflowFunctionTask # applies patches from airflow import DAG from airflow.operators.bash import BashOperator @@ -31,7 +33,7 @@ def test_flyte_env_image_preserved_after_dag_build(): env = flyte.TaskEnvironment(name="test-dag-env", image=custom_image) with DAG(dag_id="test_image_preserved", flyte_env=env) as dag: - BashOperator(task_id="say_hello", bash_command='echo hello') + BashOperator(task_id="say_hello", bash_command="echo hello") assert dag.flyte_task is not None parent_env = dag.flyte_task.parent_env() @@ -51,36 +53,30 @@ def test_flyte_env_image_preserved_after_dag_build(): def test_operator_tasks_registered_in_env(monkeypatch): """Operator tasks must appear in env.tasks so they are included in deployment.""" - from flyteplugins.airflow.task import AirflowFunctionTask # applies patches from airflow import DAG from airflow.operators.bash import BashOperator env = flyte.TaskEnvironment(name="test-dag-tasks-env") with DAG(dag_id="test_tasks_registered", flyte_env=env) as dag: - BashOperator(task_id="step1", bash_command='echo step1') - BashOperator(task_id="step2", bash_command='echo step2') + BashOperator(task_id="step1", bash_command="echo step1") + BashOperator(task_id="step2", bash_command="echo step2") parent_env = dag.flyte_task.parent_env() # Both operator tasks and the orchestrator task must be in env.tasks. env_task_names = list(parent_env.tasks.keys()) - assert any("step1" in name for name in env_task_names), ( - f"step1 not found in env.tasks: {env_task_names}" - ) - assert any("step2" in name for name in env_task_names), ( - f"step2 not found in env.tasks: {env_task_names}" - ) + assert any("step1" in name for name in env_task_names), f"step1 not found in env.tasks: {env_task_names}" + assert any("step2" in name for name in env_task_names), f"step2 not found in env.tasks: {env_task_names}" def test_default_env_created_when_flyte_env_omitted(): """When flyte_env is not supplied, a default TaskEnvironment is created using the dag_id as the name and a Debian-base image with airflow packages.""" - from flyteplugins.airflow.task import AirflowFunctionTask # applies patches from airflow import DAG from airflow.operators.bash import BashOperator with DAG(dag_id="test_default_env") as dag: - BashOperator(task_id="greet", bash_command='echo hi') + BashOperator(task_id="greet", bash_command="echo hi") assert dag.flyte_task is not None parent_env = dag.flyte_task.parent_env() From 781a596bd65bbd7abb1b34421cfc0fb42d31574b Mon Sep 17 00:00:00 2001 From: Kevin Su Date: Fri, 27 Feb 2026 00:30:43 -0800 Subject: [PATCH 14/17] refactor Signed-off-by: Kevin Su --- plugins/airflow/src/flyteplugins/airflow/dag.py | 12 +++++++----- plugins/airflow/src/flyteplugins/airflow/task.py | 12 ++++++------ .../airflow/tests}/test_airflow_dag.py | 0 .../flyteplugins/anthropic/agents/_function_tools.py | 3 +-- src/flyte/_task.py | 3 ++- 5 files changed, 16 insertions(+), 14 deletions(-) rename {tests/flyte => plugins/airflow/tests}/test_airflow_dag.py (100%) diff --git a/plugins/airflow/src/flyteplugins/airflow/dag.py b/plugins/airflow/src/flyteplugins/airflow/dag.py index 83094d452..8e953eb7b 100644 --- a/plugins/airflow/src/flyteplugins/airflow/dag.py +++ b/plugins/airflow/src/flyteplugins/airflow/dag.py @@ -53,9 +53,11 @@ # Module-level state # --------------------------------------------------------------------------- +_CURRENT_FLYTE_DAG = "current_flyte_dag" + #: Mutable container for the active FlyteDAG (set inside a ``with DAG(...)`` block). #: Using a dict avoids ``global`` statements in the patch functions. -_state: Dict[str, Optional["FlyteDAG"]] = {"current_flyte_dag": None} +_state: Dict[str, Optional["FlyteDAG"]] = {_CURRENT_FLYTE_DAG: None} # --------------------------------------------------------------------------- @@ -204,20 +206,20 @@ def _patched_dag_init(self, *args, **kwargs) -> None: # type: ignore[override] def _patched_dag_enter(self): # type: ignore[override] - _state["current_flyte_dag"] = FlyteDAG(dag_id=self.dag_id, env=getattr(self, "_flyte_env", None)) + _state[_CURRENT_FLYTE_DAG] = FlyteDAG(dag_id=self.dag_id, env=getattr(self, "_flyte_env", None)) return _original_dag_enter(self) def _patched_dag_exit(self, exc_type, exc_val, exc_tb): # type: ignore[override] try: - if exc_type is None and _state["current_flyte_dag"] is not None: - flyte_dag = _state["current_flyte_dag"] + if exc_type is None and _state[_CURRENT_FLYTE_DAG] is not None: + flyte_dag = _state[_CURRENT_FLYTE_DAG] flyte_dag.build() # Attach the Flyte task and a convenience run() to the DAG object. self.flyte_task = flyte_dag.flyte_task self.run = _make_run(flyte_dag.flyte_task) finally: - _state["current_flyte_dag"] = None + _state[_CURRENT_FLYTE_DAG] = None return _original_dag_exit(self, exc_type, exc_val, exc_tb) diff --git a/plugins/airflow/src/flyteplugins/airflow/task.py b/plugins/airflow/src/flyteplugins/airflow/task.py index b39527bae..4d0f43b6d 100644 --- a/plugins/airflow/src/flyteplugins/airflow/task.py +++ b/plugins/airflow/src/flyteplugins/airflow/task.py @@ -116,14 +116,14 @@ def _init_airflow_mixin(self) -> None: def __rshift__(self, other: "AirflowPythonFunctionTask") -> "AirflowPythonFunctionTask": """``self >> other`` — other runs after self.""" - if _dag_module._state["current_flyte_dag"] is not None: - _dag_module._state["current_flyte_dag"].set_dependency(self.name, other.name) + if _dag_module._state[_dag_module._CURRENT_FLYTE_DAG] is not None: + _dag_module._state[_dag_module._CURRENT_FLYTE_DAG].set_dependency(self.name, other.name) return other def __lshift__(self, other: "AirflowPythonFunctionTask") -> "AirflowPythonFunctionTask": """``self << other`` — self runs after other.""" - if _dag_module._state["current_flyte_dag"] is not None: - _dag_module._state["current_flyte_dag"].set_dependency(other.name, self.name) + if _dag_module._state[_dag_module._CURRENT_FLYTE_DAG] is not None: + _dag_module._state[_dag_module._CURRENT_FLYTE_DAG].set_dependency(other.name, self.name) return other @staticmethod @@ -258,8 +258,8 @@ def _flyte_operator(*args, **kwargs): raise ValueError(f"Unsupported Airflow operator: {cls.__name__}") # Case 1: inside a ``with DAG(...) as dag:`` block — register with FlyteDAG. - if _dag_module._state["current_flyte_dag"] is not None: - _dag_module._state["current_flyte_dag"].add_task(task_id, task) + if _dag_module._state[_dag_module._CURRENT_FLYTE_DAG] is not None: + _dag_module._state[_dag_module._CURRENT_FLYTE_DAG].add_task(task_id, task) return task # Case 2: inside a Flyte task execution — submit the operator as a sub-task. diff --git a/tests/flyte/test_airflow_dag.py b/plugins/airflow/tests/test_airflow_dag.py similarity index 100% rename from tests/flyte/test_airflow_dag.py rename to plugins/airflow/tests/test_airflow_dag.py diff --git a/plugins/anthropic/src/flyteplugins/anthropic/agents/_function_tools.py b/plugins/anthropic/src/flyteplugins/anthropic/agents/_function_tools.py index d4c4c948d..2a7152338 100644 --- a/plugins/anthropic/src/flyteplugins/anthropic/agents/_function_tools.py +++ b/plugins/anthropic/src/flyteplugins/anthropic/agents/_function_tools.py @@ -13,11 +13,10 @@ from dataclasses import dataclass, field from functools import partial +import anthropic from flyte._task import AsyncFunctionTaskTemplate from flyte.models import NativeInterface -import anthropic - logger = logging.getLogger(__name__) diff --git a/src/flyte/_task.py b/src/flyte/_task.py index 0e671d90c..8f12e1680 100644 --- a/src/flyte/_task.py +++ b/src/flyte/_task.py @@ -1,7 +1,6 @@ from __future__ import annotations import asyncio -import functools import weakref from dataclasses import dataclass, field, replace from inspect import iscoroutinefunction @@ -42,6 +41,7 @@ if TYPE_CHECKING: from flyteidl2.core.tasks_pb2 import DataLoadingConfig + from ._internal.resolvers.common import Resolver from ._task_environment import TaskEnvironment @@ -150,6 +150,7 @@ def __post_init__(self): self.short_name = self.name from ._internal.resolvers.default import DefaultTaskResolver + self.resolver = self.resolver or DefaultTaskResolver() def __getstate__(self): From 91cd7f242289a76c4d8c95db7d8990ec61fdde3a Mon Sep 17 00:00:00 2001 From: Kevin Su Date: Fri, 27 Feb 2026 00:31:31 -0800 Subject: [PATCH 15/17] lint Signed-off-by: Kevin Su --- plugins/airflow/tests/test_airflow_dag.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/plugins/airflow/tests/test_airflow_dag.py b/plugins/airflow/tests/test_airflow_dag.py index 46eabe5b3..0952bd3a3 100644 --- a/plugins/airflow/tests/test_airflow_dag.py +++ b/plugins/airflow/tests/test_airflow_dag.py @@ -2,12 +2,12 @@ Tests for the Airflow DAG monkey-patch in flyteplugins.airflow.dag. """ -import flyteplugins.airflow.task # noqa: F401 — triggers DAG + operator monkey-patches -import pytest - import flyte +import pytest from flyte._image import Image +import flyteplugins.airflow.task # noqa: F401 — triggers DAG + operator monkey-patches + @pytest.fixture(autouse=True) def _reset_environment_registry(): From 56a78e9f5feecc3e9926dbc355a7276bc55b1700 Mon Sep 17 00:00:00 2001 From: Kevin Su Date: Fri, 27 Feb 2026 01:00:06 -0800 Subject: [PATCH 16/17] lint Signed-off-by: Kevin Su --- examples/basics/hello.py | 2 +- examples/connectors/bigquery_example.py | 12 +------ .../airflow/src/flyteplugins/airflow/dag.py | 32 +++++++++++++++---- src/flyte/_run.py | 9 ++---- src/flyte/_task.py | 4 ++- 5 files changed, 33 insertions(+), 26 deletions(-) diff --git a/examples/basics/hello.py b/examples/basics/hello.py index ad43e507a..1135ae329 100644 --- a/examples/basics/hello.py +++ b/examples/basics/hello.py @@ -28,7 +28,7 @@ def main(x_list: list[int] = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]) -> float: if __name__ == "__main__": flyte.init_from_config() # establish remote connection from within your script. - run = flyte.with_runcontext(mode="local").run(main, x_list=list(range(10))) # run remotely inline and pass data. + run = flyte.run(main, x_list=list(range(10))) # run remotely inline and pass data. # print various attributes of the run. print(run.name) diff --git a/examples/connectors/bigquery_example.py b/examples/connectors/bigquery_example.py index 2f9b33db6..347642e1c 100644 --- a/examples/connectors/bigquery_example.py +++ b/examples/connectors/bigquery_example.py @@ -12,19 +12,9 @@ ) bigquery_env = flyte.TaskEnvironment.from_task("bigquery_env", bigquery_task) -env = flyte.TaskEnvironment( - name="bigquery_example_env", - image=flyte.Image.from_debian_base().with_pip_packages("flyteplugins-bigquery"), - depends_on=[bigquery_env], -) - - -@env.task() -def main(version: int): - bigquery_task(version=version) if __name__ == "__main__": flyte.init_from_config() - run = flyte.with_runcontext(mode="remote").run(main, 123) + run = flyte.with_runcontext(mode="local").run(bigquery_task, 123) print(run.url) diff --git a/plugins/airflow/src/flyteplugins/airflow/dag.py b/plugins/airflow/src/flyteplugins/airflow/dag.py index 8e953eb7b..22109aee0 100644 --- a/plugins/airflow/src/flyteplugins/airflow/dag.py +++ b/plugins/airflow/src/flyteplugins/airflow/dag.py @@ -38,6 +38,8 @@ from collections import defaultdict from typing import TYPE_CHECKING, Dict, List, Optional, Set, Tuple +from flyte._task import TaskTemplate + import airflow.models.dag as _airflow_dag_module if TYPE_CHECKING: @@ -190,9 +192,29 @@ def build(self) -> None: # --------------------------------------------------------------------------- -# DAG monkey-patch helpers +# Proxy class — makes DAG instances pass isinstance(dag, TaskTemplate) # --------------------------------------------------------------------------- + +# All names defined on TaskTemplate (fields, methods, properties) that should +# be proxied to the underlying flyte_task rather than resolved on the DAG. +_TASK_TEMPLATE_NAMES = frozenset(name for name in dir(TaskTemplate) if not name.startswith("_")) | frozenset( + TaskTemplate.__dataclass_fields__.keys() +) + + +class _FlyteDAG(_airflow_dag_module.DAG, TaskTemplate): + """Makes an Airflow DAG pass ``isinstance(dag, TaskTemplate)`` by proxying + TaskTemplate attribute access to the attached ``flyte_task``. + """ + + def __getattribute__(self, name): + if name in _TASK_TEMPLATE_NAMES: + ft = object.__getattribute__(self, "__dict__").get("flyte_task") + if ft is not None: + return getattr(ft, name) + return super().__getattribute__(name) + _original_dag_init = _airflow_dag_module.DAG.__init__ _original_dag_enter = _airflow_dag_module.DAG.__enter__ _original_dag_exit = _airflow_dag_module.DAG.__exit__ @@ -215,9 +237,11 @@ def _patched_dag_exit(self, exc_type, exc_val, exc_tb): # type: ignore[override if exc_type is None and _state[_CURRENT_FLYTE_DAG] is not None: flyte_dag = _state[_CURRENT_FLYTE_DAG] flyte_dag.build() - # Attach the Flyte task and a convenience run() to the DAG object. + # Attach the Flyte task and a convenience run() to the DAG object, + # then swap __class__ so the DAG passes isinstance(dag, TaskTemplate). self.flyte_task = flyte_dag.flyte_task self.run = _make_run(flyte_dag.flyte_task) + self.__class__ = _FlyteDAG finally: _state[_CURRENT_FLYTE_DAG] = None @@ -234,10 +258,6 @@ def run(**kwargs): return run -# --------------------------------------------------------------------------- -# Apply patches -# --------------------------------------------------------------------------- - _airflow_dag_module.DAG.__init__ = _patched_dag_init _airflow_dag_module.DAG.__enter__ = _patched_dag_enter _airflow_dag_module.DAG.__exit__ = _patched_dag_exit diff --git a/src/flyte/_run.py b/src/flyte/_run.py index 273f35884..fdc08b280 100644 --- a/src/flyte/_run.py +++ b/src/flyte/_run.py @@ -660,13 +660,8 @@ async def example_task(x: int, y: str) -> str: if isinstance(task, (LazyEntity, TaskDetails)) and self._mode != "remote": raise ValueError("Remote task can only be run in remote mode.") - # Allow objects (e.g. Airflow DAGs wrapped by flyteplugins) that expose a - # .flyte_task attribute to be passed directly to run(). - if not isinstance(task, TaskTemplate) and not isinstance(task, (LazyEntity, TaskDetails)): - if hasattr(task, "flyte_task") and isinstance(task.flyte_task, TaskTemplate): - task = task.flyte_task - else: - raise TypeError(f"On Flyte tasks can be run, not generic functions or methods '{type(task)}'.") + if not isinstance(task, (TaskTemplate, LazyEntity, TaskDetails)): + raise TypeError(f"Only Flyte tasks can be run, not '{type(task)}'.") # Set the run mode in the context variable so that offloaded types (files, directories, dataframes) # can check the mode for controlling auto-uploading behavior (only enabled in remote mode). diff --git a/src/flyte/_task.py b/src/flyte/_task.py index 8f12e1680..3a34dcbe3 100644 --- a/src/flyte/_task.py +++ b/src/flyte/_task.py @@ -563,12 +563,14 @@ def container_args(self, serialize_context: SerializationContext) -> List[str]: if not serialize_context.code_bundle or not serialize_context.code_bundle.pkl: # If we do not have a code bundle, or if we have one, but it is not a pkl, we need to add the resolver + from flyte._internal.resolvers.default import DefaultTaskResolver + if not serialize_context.root_dir: raise RuntimeSystemError( "SerializationError", "Root dir is required for default task resolver when no code bundle is provided.", ) - _task_resolver = self.resolver + _task_resolver = self.resolver or DefaultTaskResolver() args = [ *args, *[ From bf57e26e1fcbc76c254cead44393a61e0d5df62e Mon Sep 17 00:00:00 2001 From: Kevin Su Date: Fri, 27 Feb 2026 01:00:53 -0800 Subject: [PATCH 17/17] lint Signed-off-by: Kevin Su --- plugins/airflow/src/flyteplugins/airflow/dag.py | 1 + 1 file changed, 1 insertion(+) diff --git a/plugins/airflow/src/flyteplugins/airflow/dag.py b/plugins/airflow/src/flyteplugins/airflow/dag.py index 22109aee0..cb86d92cc 100644 --- a/plugins/airflow/src/flyteplugins/airflow/dag.py +++ b/plugins/airflow/src/flyteplugins/airflow/dag.py @@ -215,6 +215,7 @@ def __getattribute__(self, name): return getattr(ft, name) return super().__getattribute__(name) + _original_dag_init = _airflow_dag_module.DAG.__init__ _original_dag_enter = _airflow_dag_module.DAG.__enter__ _original_dag_exit = _airflow_dag_module.DAG.__exit__