From 33b8d5fcc6038c3fc3cf2260f815ce57d017dc12 Mon Sep 17 00:00:00 2001 From: Songhao Jia Date: Mon, 2 Mar 2026 14:06:11 -0800 Subject: [PATCH] Support recording inplace op intermediate output (#17796) Summary: This diff updates codegen to move `event_tracer_log_evalue` after reassign the return value to the tensor to support logging in-place updated operators like index-put. Differential Revision: D93646471 --- codegen/gen.py | 2 +- codegen/test/test_executorch_gen.py | 8 +- devtools/etdump/etdump_flatcc.cpp | 8 + devtools/inspector/tests/inspector_test.py | 253 +++++++++++++++++++++ 4 files changed, 266 insertions(+), 5 deletions(-) diff --git a/codegen/gen.py b/codegen/gen.py index 643b1c07608..9b102aa1bf6 100644 --- a/codegen/gen.py +++ b/codegen/gen.py @@ -309,8 +309,8 @@ def __call__( internal::EventTracerProfileOpScope event_tracer_op_scope(context.internal_event_tracer(), "native_call_{f.func.name}"); EXECUTORCH_SCOPE_PROF("native_call_{f.func.name}"); {ret_prefix}{kernel_call}(context, {args_str}); - {event_tracer_output_logging} {return_assignment} + {event_tracer_output_logging} {exception_boundary_end} }} ), diff --git a/codegen/test/test_executorch_gen.py b/codegen/test/test_executorch_gen.py index d9c575c1398..cd445acee62 100644 --- a/codegen/test/test_executorch_gen.py +++ b/codegen/test/test_executorch_gen.py @@ -516,9 +516,9 @@ def test_codegen_unboxed_specialized(self) -> None: internal::EventTracerProfileOpScope event_tracer_op_scope(context.internal_event_tracer(), "native_call_op_1"); EXECUTORCH_SCOPE_PROF("native_call_op_1"); bool result_ = at::native::default_kernel(context, ); + *stack[0] = EValue(result_); internal::event_tracer_log_evalue(context.internal_event_tracer(), *stack[0]); - *stack[0] = EValue(result_); } ), @@ -615,9 +615,9 @@ def test_codegen_unboxed_default(self) -> None: internal::EventTracerProfileOpScope event_tracer_op_scope(context.internal_event_tracer(), "native_call_op_1"); EXECUTORCH_SCOPE_PROF("native_call_op_1"); bool result_ = at::native::default_kernel(context, ); + *stack[0] = EValue(result_); internal::event_tracer_log_evalue(context.internal_event_tracer(), *stack[0]); - *stack[0] = EValue(result_); } ), @@ -642,9 +642,9 @@ def test_codegen_unboxed_default(self) -> None: internal::EventTracerProfileOpScope event_tracer_op_scope(context.internal_event_tracer(), "native_call_op_1"); EXECUTORCH_SCOPE_PROF("native_call_op_1"); bool result_ = at::native::default_kernel(context, ); + *stack[0] = EValue(result_); internal::event_tracer_log_evalue(context.internal_event_tracer(), *stack[0]); - *stack[0] = EValue(result_); } catch (const std::exception& ex) { ET_LOG(Error, "Kernel threw an exception: %s", ex.what()); context.fail(torch::executor::Error::Internal); @@ -686,9 +686,9 @@ def test_codegen_unboxed_default_kernel_key_selected(self) -> None: internal::EventTracerProfileOpScope event_tracer_op_scope(context.internal_event_tracer(), "native_call_op_1"); EXECUTORCH_SCOPE_PROF("native_call_op_1"); bool result_ = at::native::default_kernel(context, ); + *stack[0] = EValue(result_); internal::event_tracer_log_evalue(context.internal_event_tracer(), *stack[0]); - *stack[0] = EValue(result_); } ), diff --git a/devtools/etdump/etdump_flatcc.cpp b/devtools/etdump/etdump_flatcc.cpp index a6e0a105069..d841c45afc5 100644 --- a/devtools/etdump/etdump_flatcc.cpp +++ b/devtools/etdump/etdump_flatcc.cpp @@ -714,6 +714,14 @@ Result ETDumpGen::write_tensor_or_return_error(Tensor tensor) { return static_cast(-1); } + // A tensor with nbytes > 0 but null data pointer indicates a corrupt PTE + // or a bug in the system. This should not happen in normal operation. + ET_CHECK_OR_RETURN_ERROR( + tensor.const_data_ptr() != nullptr, + InvalidState, + "Tensor has nbytes=%zu but null data pointer. This indicates a corrupt program or internal error.", + tensor.nbytes()); + if (!data_sink_) { return Error::InvalidArgument; } diff --git a/devtools/inspector/tests/inspector_test.py b/devtools/inspector/tests/inspector_test.py index f7050593cc2..4adbe29b3fb 100644 --- a/devtools/inspector/tests/inspector_test.py +++ b/devtools/inspector/tests/inspector_test.py @@ -58,12 +58,30 @@ to_edge, to_edge_transform_and_lower, ) +from executorch.exir.capture._config import ExecutorchBackendConfig from executorch.extension.pybindings.portable_lib import ( _load_for_executorch_from_buffer, ) from torch.export import export, ExportedProgram +# Models for testing inplace ops intermediate output logging +class IndexPutModel(torch.nn.Module): + """ + A model that uses index_put to update a tensor at specific indices. + When the reinplace_pass is enabled, this will be converted to index_put_ + (the inplace variant), which was causing issues with event tracer logging. + """ + + def __init__(self): + super().__init__() + self.register_buffer("data", torch.zeros(5, 3)) + + def forward(self, indices: torch.Tensor, values: torch.Tensor) -> torch.Tensor: + result = self.data.index_put((indices,), values) + return result.sum() + + OP_TYPE = "aten::add" EVENT_BLOCK_NAME = "block_0" EVENTS_SIZE = 10 @@ -1788,3 +1806,238 @@ def _gen_random_events(self) -> List[Event]: ) ) return events + + +class TestInplaceOpsIntermediateOutput(unittest.TestCase): + """ + Test suite for verifying that inplace operators correctly log intermediate + outputs when the event tracer is enabled. + + This validates the fix for an issue where inplace ops converted by the + reinplace_pass could cause logging errors because the output tensor's data + pointer was null at the time of logging. + + Note: The reinplace_pass currently only supports converting index_put to + index_put_ (see executorch/exir/passes/reinplace.py). + """ + + def _run_model_and_get_inspector( + self, + model: torch.nn.Module, + example_inputs: tuple, + run_reinplace_pass: bool = True, + ) -> Inspector: + """ + Helper method to export a model, run it with event tracing, and return + an Inspector instance for verifying intermediate outputs. + """ + model.eval() + + with tempfile.TemporaryDirectory() as tmp_dir: + model_path = os.path.join(tmp_dir, "model.pte") + etrecord_path = os.path.join(tmp_dir, "etrecord.bin") + etdump_path = os.path.join(tmp_dir, "etdump.etdp") + debug_buffer_path = os.path.join(tmp_dir, "debug_buffer.bin") + + # Step 1: Export the model + exported_program = export(model, example_inputs) + self.assertIsNotNone(exported_program) + + # Step 2: Convert to edge dialect + edge_compile_config = EdgeCompileConfig(_check_ir_validity=False) + edge_program = to_edge(exported_program, compile_config=edge_compile_config) + self.assertIsNotNone(edge_program) + + # Keep a copy for etrecord + edge_program_copy = to_edge( + export(model, example_inputs), compile_config=edge_compile_config + ) + + # Step 3: Convert to executorch with reinplace_pass enabled + executorch_config = ExecutorchBackendConfig( + run_reinplace_pass=run_reinplace_pass + ) + executorch_program = edge_program.to_executorch(config=executorch_config) + self.assertIsNotNone(executorch_program) + + # Step 4: Generate ETRecord + generate_etrecord( + etrecord_path, + edge_program_copy, + executorch_program, + ) + + # Step 5: Save the PTE file + with open(model_path, "wb") as f: + executorch_program.write_to_file(f) + + # Step 6: Load and run with event tracing enabled + with open(model_path, "rb") as f: + pte_buffer = f.read() + + executorch_module = _load_for_executorch_from_buffer( + pte_buffer, + enable_etdump=True, + debug_buffer_size=1024 * 1024, # 1MB for testing + ) + self.assertIsNotNone(executorch_module) + + # Run the model + import torch.utils._pytree as pytree + + flattened_inputs = pytree.tree_flatten(example_inputs)[0] + executorch_module.run_method("forward", tuple(flattened_inputs)) + + # Write ETDump results + executorch_module.write_etdump_result_to_file( + etdump_path, debug_buffer_path + ) + + # Check if event tracer captured data + if not os.path.exists(etdump_path): + self.skipTest( + "Event tracer not enabled. Run with --config executorch.event_tracer_enabled=true" + ) + + # Step 7: Create Inspector and return + inspector = Inspector( + etdump_path=etdump_path, + etrecord=etrecord_path, + debug_buffer_path=debug_buffer_path, + ) + return inspector + + def test_index_put_without_reinplace_pass(self): + """ + Test that the model works correctly without the reinplace pass as a + baseline comparison, and verify intermediate output correctness. + """ + model = IndexPutModel() + indices = torch.tensor([0, 2, 4]) + values = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]]) + example_inputs = (indices, values) + + # Compute expected intermediate output of index_put + # index_put on zeros(5,3) with indices [0,2,4] and values [[1,2,3],[4,5,6],[7,8,9]] + # Result should be: + # [[1, 2, 3], + # [0, 0, 0], + # [4, 5, 6], + # [0, 0, 0], + # [7, 8, 9]] + expected_index_put_output = torch.zeros(5, 3) + expected_index_put_output[0] = torch.tensor([1.0, 2.0, 3.0]) + expected_index_put_output[2] = torch.tensor([4.0, 5.0, 6.0]) + expected_index_put_output[4] = torch.tensor([7.0, 8.0, 9.0]) + + inspector = self._run_model_and_get_inspector( + model, example_inputs, run_reinplace_pass=False + ) + + self.assertIsNotNone(inspector) + self.assertGreater(len(inspector.event_blocks), 0) + + # Verify intermediate output correctness (same validation as with reinplace) + found_index_put_output = False + for event_block in inspector.event_blocks: + for event in event_block.events: + if hasattr(event, "debug_data") and event.debug_data is not None: + for debug_entry in event.debug_data: + if isinstance(debug_entry, torch.Tensor): + # Verify tensor has valid data pointer + self.assertIsNotNone( + debug_entry.data_ptr(), + "Intermediate output tensor should have valid data pointer", + ) + self.assertNotEqual( + debug_entry.data_ptr(), + 0, + "Intermediate output tensor data pointer should not be null", + ) + + # Check if this matches our expected index_put output shape + if debug_entry.shape == expected_index_put_output.shape: + if torch.allclose( + debug_entry, expected_index_put_output, atol=1e-5 + ): + found_index_put_output = True + + self.assertTrue( + found_index_put_output, + "Expected to find index_put intermediate output with correct tensor data (without reinplace pass).", + ) + + def test_index_put_intermediate_output_data_correctness(self): + """ + Test that the intermediate output values captured by the event tracer + are valid tensors with correct data. + + This specifically validates that: + 1. The output tensor has a valid (non-null) data pointer + 2. The output tensor contains the correct values after index_put_ + """ + model = IndexPutModel() + # Use simple values to verify correctness + indices = torch.tensor([0, 1]) + values = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]) + example_inputs = (indices, values) + + # Compute expected intermediate output of index_put + # index_put on zeros(5,3) with indices [0,1] and values [[1,2,3],[4,5,6]] + # Result should be: + # [[1, 2, 3], + # [4, 5, 6], + # [0, 0, 0], + # [0, 0, 0], + # [0, 0, 0]] + expected_index_put_output = torch.zeros(5, 3) + expected_index_put_output[0] = torch.tensor([1.0, 2.0, 3.0]) + expected_index_put_output[1] = torch.tensor([4.0, 5.0, 6.0]) + + inspector = self._run_model_and_get_inspector( + model, example_inputs, run_reinplace_pass=True + ) + + self.assertIsNotNone(inspector) + self.assertGreater(len(inspector.event_blocks), 0) + + total_events = sum(len(eb.events) for eb in inspector.event_blocks) + self.assertGreater( + total_events, 0, "Expected at least one event to be captured" + ) + + # Find and verify the index_put_ output + found_index_put_output = False + for event_block in inspector.event_blocks: + for event in event_block.events: + # Check if this event has debug_data (intermediate outputs) + if hasattr(event, "debug_data") and event.debug_data is not None: + for debug_entry in event.debug_data: + if isinstance(debug_entry, torch.Tensor): + # Verify tensor has valid data pointer + self.assertIsNotNone( + debug_entry.data_ptr(), + "Intermediate output tensor should have valid data pointer", + ) + self.assertNotEqual( + debug_entry.data_ptr(), + 0, + "Intermediate output tensor data pointer should not be null", + ) + + # Check if this matches our expected index_put output shape + if debug_entry.shape == expected_index_put_output.shape: + # Verify the data is correct + if torch.allclose( + debug_entry, expected_index_put_output, atol=1e-5 + ): + found_index_put_output = True + + # Assert that we found the expected index_put output with correct data + # This validates that the intermediate output was properly logged + # and contains the correct tensor values + self.assertTrue( + found_index_put_output, + "Expected to find index_put intermediate output with correct tensor data. " + "The output tensor should match the expected result of index_put operation.", + )