Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -545,9 +545,27 @@ def ti_skip_downstream(
task_ids = [task if isinstance(task, tuple) else (task, -1) for task in tasks]
log.debug("Prepared task IDs for skipping", task_ids=task_ids)

# Don't overwrite tasks that are already executing or finished.
# See: https://github.com/apache/airflow/issues/59378
# Note: SQL NULL NOT IN (...) is falsy, so we need an explicit IS NULL check.
skippable_state_clause = or_(
TI.state.is_(None),
TI.state.not_in(
[
TaskInstanceState.RUNNING,
TaskInstanceState.SUCCESS,
TaskInstanceState.FAILED,
]
),
)
query = (
update(TI)
.where(TI.dag_id == dag_id, TI.run_id == run_id, tuple_(TI.task_id, TI.map_index).in_(task_ids))
.where(
TI.dag_id == dag_id,
TI.run_id == run_id,
tuple_(TI.task_id, TI.map_index).in_(task_ids),
skippable_state_clause,
)
.values(state=TaskInstanceState.SKIPPED, start_date=now, end_date=now)
.execution_options(synchronize_session=False)
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1481,6 +1481,95 @@ def test_ti_skip_downstream(self, client, session, create_task_instance, dag_mak
assert ti1.state == State.SKIPPED


class TestTISkipDownstreamRaceCondition:
"""Regression tests for #59378: state guard in ti_skip_downstream()."""

def setup_method(self):
clear_db_runs()

def teardown_method(self):
clear_db_runs()

@pytest.mark.parametrize(
"initial_state",
[
State.RUNNING,
State.SUCCESS,
State.FAILED,
],
)
def test_skip_downstream_does_not_overwrite_terminal_or_running_ti(
self, client, session, dag_maker, initial_state
):
with dag_maker(f"skip_race_dag_{initial_state}", session=session):
branch = EmptyOperator(task_id="branch")
downstream = EmptyOperator(task_id="downstream")
branch >> downstream
dr = dag_maker.create_dagrun(run_id="run")

ti_branch = dr.get_task_instance("branch")
ti_branch.set_state(State.SUCCESS)

ti_downstream = dr.get_task_instance("downstream")
ti_downstream.set_state(initial_state)
session.commit()

response = client.patch(
f"/execution/task-instances/{ti_branch.id}/skip-downstream",
json={"tasks": ["downstream"]},
)
assert response.status_code == 204

session.expire_all()
ti_downstream = dr.get_task_instance("downstream")
assert ti_downstream.state == initial_state

def test_skip_downstream_does_skip_queued_ti(self, client, session, dag_maker):
with dag_maker("skip_race_dag_queued", session=session):
branch = EmptyOperator(task_id="branch")
downstream = EmptyOperator(task_id="downstream")
branch >> downstream
dr = dag_maker.create_dagrun(run_id="run")

ti_branch = dr.get_task_instance("branch")
ti_branch.set_state(State.SUCCESS)

ti_downstream = dr.get_task_instance("downstream")
ti_downstream.set_state(TaskInstanceState.QUEUED)
session.commit()

response = client.patch(
f"/execution/task-instances/{ti_branch.id}/skip-downstream",
json={"tasks": ["downstream"]},
)
assert response.status_code == 204

session.expire_all()
ti_downstream = dr.get_task_instance("downstream")
assert ti_downstream.state == State.SKIPPED

def test_skip_downstream_still_skips_none_state_ti(self, client, session, dag_maker):
with dag_maker("skip_race_dag_normal", session=session):
branch = EmptyOperator(task_id="branch")
downstream = EmptyOperator(task_id="downstream")
branch >> downstream
dr = dag_maker.create_dagrun(run_id="run")

ti_branch = dr.get_task_instance("branch")
ti_branch.set_state(State.SUCCESS)
session.commit()

response = client.patch(
f"/execution/task-instances/{ti_branch.id}/skip-downstream",
json={"tasks": ["downstream"]},
)
assert response.status_code == 204

session.expire_all()
ti_downstream = dr.get_task_instance("downstream")
assert ti_downstream.state == State.SKIPPED


class TestTIHealthEndpoint:
def setup_method(self):
clear_db_runs()
Expand Down
Loading