# Patch summary (text before the first "diff --git" header is not part of any hunk).
#
# routes/task_instances.py, hunk in ti_skip_downstream():
#   Adds `skippable_state_clause` to the WHERE of the bulk UPDATE that sets
#   state=SKIPPED, start_date=now, end_date=now. A row is updated only when its
#   state IS NULL or is not one of RUNNING / SUCCESS / FAILED. The IS NULL branch
#   is explicit because, as the in-code note says, NULL NOT IN (...) is falsy in
#   SQL. Every other state (e.g. QUEUED) is still overwritten with SKIPPED.
#   Background: https://github.com/apache/airflow/issues/59378
#
# test_task_instances.py, new class TestTISkipDownstreamRaceCondition:
#   - parametrized over RUNNING / SUCCESS / FAILED: after PATCH
#     .../skip-downstream returns 204, the downstream TI keeps its initial state;
#   - a QUEUED downstream TI ends up SKIPPED;
#   - a downstream TI whose state the test never sets ends up SKIPPED.
#
# NOTE(review): this patch had been flattened onto three lines. Line breaks were
# restored so that they agree with the hunk headers (-545,9 +545,27 and
# -1481,6 +1481,95); indentation was reconstructed from Python nesting --
# confirm the context lines against the target files before applying.
diff --git a/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py b/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py
index 9273cc8b3d487..5aaf0e2ef56bf 100644
--- a/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py
+++ b/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py
@@ -545,9 +545,27 @@ def ti_skip_downstream(
     task_ids = [task if isinstance(task, tuple) else (task, -1) for task in tasks]
     log.debug("Prepared task IDs for skipping", task_ids=task_ids)
 
+    # Don't overwrite tasks that are already executing or finished.
+    # See: https://github.com/apache/airflow/issues/59378
+    # Note: SQL NULL NOT IN (...) is falsy, so we need an explicit IS NULL check.
+    skippable_state_clause = or_(
+        TI.state.is_(None),
+        TI.state.not_in(
+            [
+                TaskInstanceState.RUNNING,
+                TaskInstanceState.SUCCESS,
+                TaskInstanceState.FAILED,
+            ]
+        ),
+    )
     query = (
         update(TI)
-        .where(TI.dag_id == dag_id, TI.run_id == run_id, tuple_(TI.task_id, TI.map_index).in_(task_ids))
+        .where(
+            TI.dag_id == dag_id,
+            TI.run_id == run_id,
+            tuple_(TI.task_id, TI.map_index).in_(task_ids),
+            skippable_state_clause,
+        )
         .values(state=TaskInstanceState.SKIPPED, start_date=now, end_date=now)
         .execution_options(synchronize_session=False)
     )
diff --git a/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_task_instances.py b/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_task_instances.py
index d9ec3916187ee..c064c399183a1 100644
--- a/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_task_instances.py
+++ b/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_task_instances.py
@@ -1481,6 +1481,95 @@ def test_ti_skip_downstream(self, client, session, create_task_instance, dag_mak
         assert ti1.state == State.SKIPPED
 
 
+class TestTISkipDownstreamRaceCondition:
+    """Regression tests for #59378: state guard in ti_skip_downstream()."""
+
+    def setup_method(self):
+        clear_db_runs()
+
+    def teardown_method(self):
+        clear_db_runs()
+
+    @pytest.mark.parametrize(
+        "initial_state",
+        [
+            State.RUNNING,
+            State.SUCCESS,
+            State.FAILED,
+        ],
+    )
+    def test_skip_downstream_does_not_overwrite_terminal_or_running_ti(
+        self, client, session, dag_maker, initial_state
+    ):
+        with dag_maker(f"skip_race_dag_{initial_state}", session=session):
+            branch = EmptyOperator(task_id="branch")
+            downstream = EmptyOperator(task_id="downstream")
+            branch >> downstream
+        dr = dag_maker.create_dagrun(run_id="run")
+
+        ti_branch = dr.get_task_instance("branch")
+        ti_branch.set_state(State.SUCCESS)
+
+        ti_downstream = dr.get_task_instance("downstream")
+        ti_downstream.set_state(initial_state)
+        session.commit()
+
+        response = client.patch(
+            f"/execution/task-instances/{ti_branch.id}/skip-downstream",
+            json={"tasks": ["downstream"]},
+        )
+        assert response.status_code == 204
+
+        session.expire_all()
+        ti_downstream = dr.get_task_instance("downstream")
+        assert ti_downstream.state == initial_state
+
+    def test_skip_downstream_does_skip_queued_ti(self, client, session, dag_maker):
+        with dag_maker("skip_race_dag_queued", session=session):
+            branch = EmptyOperator(task_id="branch")
+            downstream = EmptyOperator(task_id="downstream")
+            branch >> downstream
+        dr = dag_maker.create_dagrun(run_id="run")
+
+        ti_branch = dr.get_task_instance("branch")
+        ti_branch.set_state(State.SUCCESS)
+
+        ti_downstream = dr.get_task_instance("downstream")
+        ti_downstream.set_state(TaskInstanceState.QUEUED)
+        session.commit()
+
+        response = client.patch(
+            f"/execution/task-instances/{ti_branch.id}/skip-downstream",
+            json={"tasks": ["downstream"]},
+        )
+        assert response.status_code == 204
+
+        session.expire_all()
+        ti_downstream = dr.get_task_instance("downstream")
+        assert ti_downstream.state == State.SKIPPED
+
+    def test_skip_downstream_still_skips_none_state_ti(self, client, session, dag_maker):
+        with dag_maker("skip_race_dag_normal", session=session):
+            branch = EmptyOperator(task_id="branch")
+            downstream = EmptyOperator(task_id="downstream")
+            branch >> downstream
+        dr = dag_maker.create_dagrun(run_id="run")
+
+        ti_branch = dr.get_task_instance("branch")
+        ti_branch.set_state(State.SUCCESS)
+        session.commit()
+
+        response = client.patch(
+            f"/execution/task-instances/{ti_branch.id}/skip-downstream",
+            json={"tasks": ["downstream"]},
+        )
+        assert response.status_code == 204
+
+        session.expire_all()
+        ti_downstream = dr.get_task_instance("downstream")
+        assert ti_downstream.state == State.SKIPPED
+
+
 class TestTIHealthEndpoint:
     def setup_method(self):
         clear_db_runs()