Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I updated the query to join with DagModel and DagRun to fetch logical_date and owner, and updated with_for_update to lock only the TI row in the joined query. The extra field is constructed as a JSON string containing only host_name, which matches the format used in earlier versions of Airflow.

Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
from cadwyn import VersionedAPIRouter
from fastapi import Body, HTTPException, Query, status
from pydantic import JsonValue
from sqlalchemy import func, or_, tuple_, update
from sqlalchemy import and_, func, or_, tuple_, update
from sqlalchemy.engine import CursorResult
from sqlalchemy.exc import NoResultFound, SQLAlchemyError
from sqlalchemy.orm import joinedload
Expand Down Expand Up @@ -64,6 +64,7 @@
from airflow.models.asset import AssetActive
from airflow.models.dag import DagModel
from airflow.models.dagrun import DagRun as DR
from airflow.models.log import Log
from airflow.models.taskinstance import TaskInstance as TI, _stop_remaining_tasks
from airflow.models.taskreschedule import TaskReschedule
from airflow.models.trigger import Trigger
Expand Down Expand Up @@ -136,10 +137,14 @@ def ti_run(
# This selects the raw JSON value, by-passing the deserialization -- we want that to happen on the
# client
column("next_kwargs", JSON),
DR.logical_date,
DagModel.owners,
)
.select_from(TI)
.join(DR, and_(TI.dag_id == DR.dag_id, TI.run_id == DR.run_id))
.join(DagModel, TI.dag_id == DagModel.dag_id)
.where(TI.id == task_instance_id)
.with_for_update()
.with_for_update(of=TI)
)
try:
ti = session.execute(old).one()
Expand Down Expand Up @@ -195,6 +200,19 @@ def ti_run(
)
else:
log.info("Task started", previous_state=previous_state, hostname=ti_run_payload.hostname)
session.add(
Log(
event=TaskInstanceState.RUNNING.value,
task_id=ti.task_id,
dag_id=ti.dag_id,
run_id=ti.run_id,
map_index=ti.map_index,
try_number=ti.try_number,
logical_date=ti.logical_date,
owner=ti.owners,
extra=json.dumps({"host_name": ti_run_payload.hostname}) if ti_run_payload.hostname else None,
)
)
# Ensure there is no end date set.
query = query.values(
end_date=None,
Expand Down Expand Up @@ -297,16 +315,36 @@ def ti_update_state(
log.debug("Updating task instance state", new_state=ti_patch_payload.state)

old = (
select(TI.state, TI.try_number, TI.max_tries, TI.dag_id)
select(
TI.state,
TI.try_number,
TI.max_tries,
TI.dag_id,
TI.task_id,
TI.run_id,
TI.map_index,
TI.hostname,
DR.logical_date,
DagModel.owners,
)
.select_from(TI)
.join(DR, and_(TI.dag_id == DR.dag_id, TI.run_id == DR.run_id))
.join(DagModel, TI.dag_id == DagModel.dag_id)
.where(TI.id == task_instance_id)
.with_for_update()
.with_for_update(of=TI)
)
try:
(
previous_state,
try_number,
max_tries,
dag_id,
task_id,
run_id,
map_index,
hostname,
logical_date,
owners,
) = session.execute(old).one()
log.debug(
"Retrieved current task instance state",
Expand Down Expand Up @@ -373,6 +411,19 @@ def ti_update_state(
new_state=updated_state,
rows_affected=getattr(result, "rowcount", 0),
)
session.add(
Log(
event=updated_state.value,
task_id=task_id,
dag_id=dag_id,
run_id=run_id,
map_index=map_index,
try_number=try_number,
logical_date=logical_date,
owner=owners,
extra=json.dumps({"host_name": hostname}) if hostname else None,
)
)
except SQLAlchemyError as e:
log.error("Error updating Task Instance state", error=str(e))
raise HTTPException(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
from airflow.exceptions import AirflowSkipException
from airflow.models import RenderedTaskInstanceFields, TaskReschedule, Trigger
from airflow.models.asset import AssetActive, AssetAliasModel, AssetEvent, AssetModel
from airflow.models.log import Log
from airflow.models.taskinstance import TaskInstance
from airflow.models.taskinstancehistory import TaskInstanceHistory
from airflow.providers.standard.operators.empty import EmptyOperator
Expand All @@ -44,6 +45,7 @@
from tests_common.test_utils.db import (
clear_db_assets,
clear_db_dags,
clear_db_logs,
clear_db_runs,
clear_db_serialized_dags,
clear_rendered_ti_fields,
Expand Down Expand Up @@ -127,11 +129,13 @@ def side_effect(cred, validators):

class TestTIRunState:
def setup_method(self):
clear_db_logs()
clear_db_runs()
clear_db_serialized_dags()
clear_db_dags()

def teardown_method(self):
clear_db_logs()
clear_db_runs()
clear_db_serialized_dags()
clear_db_dags()
Expand Down Expand Up @@ -793,14 +797,57 @@ def test_ti_run_with_triggering_user_name(
assert dag_run["run_id"] == "test"
assert dag_run["state"] == "running"

def test_ti_run_creates_audit_log(self, client, session, create_task_instance, time_machine):
"""Test that transitioning to RUNNING creates an audit log record."""
instant_str = "2024-09-30T12:00:00Z"
instant = timezone.parse(instant_str)
time_machine.move_to(instant, tick=False)

ti = create_task_instance(
task_id="test_ti_run_creates_audit_log",
state=State.QUEUED,
dagrun_state=DagRunState.RUNNING,
session=session,
start_date=instant,
dag_id=str(uuid4()),
)
session.commit()

response = client.patch(
f"/execution/task-instances/{ti.id}/run",
json={
"state": "running",
"hostname": "random-hostname",
"unixname": "random-unixname",
"pid": 100,
"start_date": instant_str,
},
)

assert response.status_code == 200

logs = session.scalars(select(Log).where(Log.dag_id == ti.dag_id)).all()
assert len(logs) == 1
assert logs[0].event == TaskInstanceState.RUNNING.value
assert logs[0].task_id == ti.task_id
assert logs[0].dag_id == ti.dag_id
assert logs[0].run_id == ti.run_id
assert logs[0].map_index == ti.map_index
assert logs[0].try_number == ti.try_number
assert logs[0].logical_date == instant
assert logs[0].owner == ti.task.owner
assert logs[0].extra == '{"host_name": "random-hostname"}'


class TestTIUpdateState:
def setup_method(self):
clear_db_assets()
clear_db_logs()
clear_db_runs()

def teardown_method(self):
clear_db_assets()
clear_db_logs()
clear_db_runs()

@pytest.mark.parametrize(
Expand Down Expand Up @@ -838,6 +885,82 @@ def test_ti_update_state_to_terminal(
assert ti.state == expected_state
assert ti.end_date == end_date

@pytest.mark.parametrize(
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Consolidate all the test cases into a parameterized one.

("payload", "expected_event"),
[
pytest.param(
{"state": State.SUCCESS, "end_date": DEFAULT_END_DATE.isoformat()},
State.SUCCESS,
id="success",
),
pytest.param(
{"state": State.FAILED, "end_date": DEFAULT_END_DATE.isoformat()},
State.FAILED,
id="failed",
),
pytest.param(
{"state": State.SKIPPED, "end_date": DEFAULT_END_DATE.isoformat()},
State.SKIPPED,
id="skipped",
),
pytest.param(
{"state": State.UP_FOR_RETRY, "end_date": DEFAULT_END_DATE.isoformat()},
TaskInstanceState.UP_FOR_RETRY.value,
id="up_for_retry",
),
pytest.param(
{
"state": "deferred",
"trigger_kwargs": {"key": "value", "moment": "2026-02-18T00:00:00Z"},
"trigger_timeout": "P1D",
"classpath": "my-classpath",
"next_method": "execute_callback",
},
TaskInstanceState.DEFERRED.value,
id="deferred",
),
pytest.param(
{
"state": "up_for_reschedule",
"reschedule_date": "2026-02-18T11:03:00+00:00",
"end_date": DEFAULT_END_DATE.isoformat(),
},
TaskInstanceState.UP_FOR_RESCHEDULE.value,
id="up_for_reschedule",
),
],
)
def test_ti_update_state_creates_audit_log(
self, client, session, create_task_instance, payload, expected_event
):
"""Test that state transition creates an audit log record."""
ti = create_task_instance(
task_id="test_ti_update_state_creates_audit_log",
start_date=DEFAULT_START_DATE,
state=State.RUNNING,
hostname="random-hostname",
)
session.commit()

response = client.patch(
f"/execution/task-instances/{ti.id}/state",
json=payload,
)

assert response.status_code == 204

logs = session.scalars(select(Log).where(Log.dag_id == ti.dag_id)).all()
assert len(logs) == 1
assert logs[0].event == expected_event
assert logs[0].task_id == ti.task_id
assert logs[0].dag_id == ti.dag_id
assert logs[0].run_id == ti.run_id
assert logs[0].map_index == ti.map_index
assert logs[0].try_number == ti.try_number
assert logs[0].logical_date == ti.dag_run.logical_date
assert logs[0].owner == ti.task.owner
assert logs[0].extra == '{"host_name": "random-hostname"}'

@pytest.mark.parametrize(
("state", "end_date", "expected_state", "rendered_map_index"),
[
Expand Down Expand Up @@ -1063,8 +1186,34 @@ def test_ti_update_state_database_error(self, client, session, create_task_insta
mock.patch(
"airflow.api_fastapi.common.db.common.Session.execute",
side_effect=[
mock.Mock(one=lambda: ("running", 1, 0, "dag")), # First call returns "queued"
mock.Mock(one=lambda: ("running", 1, 0, "dag")), # Second call returns "queued"
mock.Mock(
one=lambda: (
"running",
1,
0,
"dag",
"task",
"run",
-1,
"localhost",
timezone.utcnow(),
"test_owner",
)
), # First call returns "queued"
mock.Mock(
one=lambda: (
"running",
1,
0,
"dag",
"task",
"run",
-1,
"localhost",
timezone.utcnow(),
"test_owner",
)
), # Second call returns "queued"
Comment on lines +1189 to +1216
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Because the TI query now returns more columns, the mocked result tuples are updated to match.

SQLAlchemyError("Database error"), # Last call raises an error
],
),
Expand Down
Loading