diff --git a/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py b/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py index f22d7c125853d..e963770b240e0 100644 --- a/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py +++ b/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py @@ -30,7 +30,7 @@ from cadwyn import VersionedAPIRouter from fastapi import Body, HTTPException, Query, status from pydantic import JsonValue -from sqlalchemy import func, or_, tuple_, update +from sqlalchemy import and_, func, or_, tuple_, update from sqlalchemy.engine import CursorResult from sqlalchemy.exc import NoResultFound, SQLAlchemyError from sqlalchemy.orm import joinedload @@ -64,6 +64,7 @@ from airflow.models.asset import AssetActive from airflow.models.dag import DagModel from airflow.models.dagrun import DagRun as DR +from airflow.models.log import Log from airflow.models.taskinstance import TaskInstance as TI, _stop_remaining_tasks from airflow.models.taskreschedule import TaskReschedule from airflow.models.trigger import Trigger @@ -136,10 +137,14 @@ def ti_run( # This selects the raw JSON value, by-passing the deserialization -- we want that to happen on the # client column("next_kwargs", JSON), + DR.logical_date, + DagModel.owners, ) .select_from(TI) + .join(DR, and_(TI.dag_id == DR.dag_id, TI.run_id == DR.run_id)) + .join(DagModel, TI.dag_id == DagModel.dag_id) .where(TI.id == task_instance_id) - .with_for_update() + .with_for_update(of=TI) ) try: ti = session.execute(old).one() @@ -195,6 +200,19 @@ def ti_run( ) else: log.info("Task started", previous_state=previous_state, hostname=ti_run_payload.hostname) + session.add( + Log( + event=TaskInstanceState.RUNNING.value, + task_id=ti.task_id, + dag_id=ti.dag_id, + run_id=ti.run_id, + map_index=ti.map_index, + try_number=ti.try_number, + logical_date=ti.logical_date, + owner=ti.owners, + extra=json.dumps({"host_name": 
ti_run_payload.hostname}) if ti_run_payload.hostname else None, + ) + ) # Ensure there is no end date set. query = query.values( end_date=None, @@ -297,9 +315,23 @@ def ti_update_state( log.debug("Updating task instance state", new_state=ti_patch_payload.state) old = ( - select(TI.state, TI.try_number, TI.max_tries, TI.dag_id) + select( + TI.state, + TI.try_number, + TI.max_tries, + TI.dag_id, + TI.task_id, + TI.run_id, + TI.map_index, + TI.hostname, + DR.logical_date, + DagModel.owners, + ) + .select_from(TI) + .join(DR, and_(TI.dag_id == DR.dag_id, TI.run_id == DR.run_id)) + .join(DagModel, TI.dag_id == DagModel.dag_id) .where(TI.id == task_instance_id) - .with_for_update() + .with_for_update(of=TI) ) try: ( @@ -307,6 +339,12 @@ def ti_update_state( try_number, max_tries, dag_id, + task_id, + run_id, + map_index, + hostname, + logical_date, + owners, ) = session.execute(old).one() log.debug( "Retrieved current task instance state", @@ -373,6 +411,19 @@ def ti_update_state( new_state=updated_state, rows_affected=getattr(result, "rowcount", 0), ) + session.add( + Log( + event=updated_state.value, + task_id=task_id, + dag_id=dag_id, + run_id=run_id, + map_index=map_index, + try_number=try_number, + logical_date=logical_date, + owner=owners, + extra=json.dumps({"host_name": hostname}) if hostname else None, + ) + ) except SQLAlchemyError as e: log.error("Error updating Task Instance state", error=str(e)) raise HTTPException( diff --git a/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_task_instances.py b/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_task_instances.py index ea1153f01cba5..0fe582cb921b3 100644 --- a/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_task_instances.py +++ b/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_task_instances.py @@ -34,6 +34,7 @@ from airflow.exceptions import AirflowSkipException from airflow.models import RenderedTaskInstanceFields, 
TaskReschedule, Trigger from airflow.models.asset import AssetActive, AssetAliasModel, AssetEvent, AssetModel +from airflow.models.log import Log from airflow.models.taskinstance import TaskInstance from airflow.models.taskinstancehistory import TaskInstanceHistory from airflow.providers.standard.operators.empty import EmptyOperator @@ -44,6 +45,7 @@ from tests_common.test_utils.db import ( clear_db_assets, clear_db_dags, + clear_db_logs, clear_db_runs, clear_db_serialized_dags, clear_rendered_ti_fields, @@ -127,11 +129,13 @@ def side_effect(cred, validators): class TestTIRunState: def setup_method(self): + clear_db_logs() clear_db_runs() clear_db_serialized_dags() clear_db_dags() def teardown_method(self): + clear_db_logs() clear_db_runs() clear_db_serialized_dags() clear_db_dags() @@ -793,14 +797,57 @@ def test_ti_run_with_triggering_user_name( assert dag_run["run_id"] == "test" assert dag_run["state"] == "running" + def test_ti_run_creates_audit_log(self, client, session, create_task_instance, time_machine): + """Test that transitioning to RUNNING creates an audit log record.""" + instant_str = "2024-09-30T12:00:00Z" + instant = timezone.parse(instant_str) + time_machine.move_to(instant, tick=False) + + ti = create_task_instance( + task_id="test_ti_run_creates_audit_log", + state=State.QUEUED, + dagrun_state=DagRunState.RUNNING, + session=session, + start_date=instant, + dag_id=str(uuid4()), + ) + session.commit() + + response = client.patch( + f"/execution/task-instances/{ti.id}/run", + json={ + "state": "running", + "hostname": "random-hostname", + "unixname": "random-unixname", + "pid": 100, + "start_date": instant_str, + }, + ) + + assert response.status_code == 200 + + logs = session.scalars(select(Log).where(Log.dag_id == ti.dag_id)).all() + assert len(logs) == 1 + assert logs[0].event == TaskInstanceState.RUNNING.value + assert logs[0].task_id == ti.task_id + assert logs[0].dag_id == ti.dag_id + assert logs[0].run_id == ti.run_id + assert 
logs[0].map_index == ti.map_index + assert logs[0].try_number == ti.try_number + assert logs[0].logical_date == instant + assert logs[0].owner == ti.task.owner + assert logs[0].extra == '{"host_name": "random-hostname"}' + class TestTIUpdateState: def setup_method(self): clear_db_assets() + clear_db_logs() clear_db_runs() def teardown_method(self): clear_db_assets() + clear_db_logs() clear_db_runs() @pytest.mark.parametrize( @@ -838,6 +885,82 @@ def test_ti_update_state_to_terminal( assert ti.state == expected_state assert ti.end_date == end_date + @pytest.mark.parametrize( + ("payload", "expected_event"), + [ + pytest.param( + {"state": State.SUCCESS, "end_date": DEFAULT_END_DATE.isoformat()}, + State.SUCCESS, + id="success", + ), + pytest.param( + {"state": State.FAILED, "end_date": DEFAULT_END_DATE.isoformat()}, + State.FAILED, + id="failed", + ), + pytest.param( + {"state": State.SKIPPED, "end_date": DEFAULT_END_DATE.isoformat()}, + State.SKIPPED, + id="skipped", + ), + pytest.param( + {"state": State.UP_FOR_RETRY, "end_date": DEFAULT_END_DATE.isoformat()}, + TaskInstanceState.UP_FOR_RETRY.value, + id="up_for_retry", + ), + pytest.param( + { + "state": "deferred", + "trigger_kwargs": {"key": "value", "moment": "2026-02-18T00:00:00Z"}, + "trigger_timeout": "P1D", + "classpath": "my-classpath", + "next_method": "execute_callback", + }, + TaskInstanceState.DEFERRED.value, + id="deferred", + ), + pytest.param( + { + "state": "up_for_reschedule", + "reschedule_date": "2026-02-18T11:03:00+00:00", + "end_date": DEFAULT_END_DATE.isoformat(), + }, + TaskInstanceState.UP_FOR_RESCHEDULE.value, + id="up_for_reschedule", + ), + ], + ) + def test_ti_update_state_creates_audit_log( + self, client, session, create_task_instance, payload, expected_event + ): + """Test that state transition creates an audit log record.""" + ti = create_task_instance( + task_id="test_ti_update_state_creates_audit_log", + start_date=DEFAULT_START_DATE, + state=State.RUNNING, + 
hostname="random-hostname", + ) + session.commit() + + response = client.patch( + f"/execution/task-instances/{ti.id}/state", + json=payload, + ) + + assert response.status_code == 204 + + logs = session.scalars(select(Log).where(Log.dag_id == ti.dag_id)).all() + assert len(logs) == 1 + assert logs[0].event == expected_event + assert logs[0].task_id == ti.task_id + assert logs[0].dag_id == ti.dag_id + assert logs[0].run_id == ti.run_id + assert logs[0].map_index == ti.map_index + assert logs[0].try_number == ti.try_number + assert logs[0].logical_date == ti.dag_run.logical_date + assert logs[0].owner == ti.task.owner + assert logs[0].extra == '{"host_name": "random-hostname"}' + @pytest.mark.parametrize( ("state", "end_date", "expected_state", "rendered_map_index"), [ @@ -1063,8 +1186,34 @@ def test_ti_update_state_database_error(self, client, session, create_task_insta mock.patch( "airflow.api_fastapi.common.db.common.Session.execute", side_effect=[ - mock.Mock(one=lambda: ("running", 1, 0, "dag")), # First call returns "queued" - mock.Mock(one=lambda: ("running", 1, 0, "dag")), # Second call returns "queued" + mock.Mock( + one=lambda: ( + "running", + 1, + 0, + "dag", + "task", + "run", + -1, + "localhost", + timezone.utcnow(), + "test_owner", + ) + ), # First call returns "running" + mock.Mock( + one=lambda: ( + "running", + 1, + 0, + "dag", + "task", + "run", + -1, + "localhost", + timezone.utcnow(), + "test_owner", + ) + ), # Second call returns "running" SQLAlchemyError("Database error"), # Last call raises an error ], ),