diff --git a/include/eventide/async/runtime/frame.h b/include/eventide/async/runtime/frame.h index f0bcd15a..048d46d7 100644 --- a/include/eventide/async/runtime/frame.h +++ b/include/eventide/async/runtime/frame.h @@ -122,6 +122,14 @@ class async_node { /// Dump the async graph reachable from this node as a DOT (graphviz) graph. std::string dump_dot() const; + /// Returns the async_node whose coroutine body is currently executing on + /// this thread, or nullptr if no coroutine is active. + static async_node* current() noexcept; + + /// Convenience: calls dump_dot() on the current node. Returns an empty + /// string if no coroutine is active. + static std::string dump_current_dot(); + private: const static async_node* get_awaiter(const async_node* node); const static sync_primitive* get_resource_parent(const async_node* node); @@ -420,4 +428,15 @@ class system_op : public async_node { void complete() noexcept; }; +namespace detail { + +/// Thread-local pointer to the async_node whose coroutine body is currently +/// executing on this thread. Set by the tracking awaiters in task.h and +/// save/restored around every resume entry-point so that user code can always +/// call async_node::current() to obtain it. +void set_current_node(async_node* node) noexcept; +async_node* current_node() noexcept; + +} // namespace detail + } // namespace eventide diff --git a/include/eventide/async/runtime/task.h b/include/eventide/async/runtime/task.h index ee250321..4251c236 100644 --- a/include/eventide/async/runtime/task.h +++ b/include/eventide/async/runtime/task.h @@ -65,6 +65,24 @@ struct promise_result { } }; +// ============================================================================ +// initial_tracking_suspend — sets current node when coroutine body first runs +// ============================================================================ + +struct initial_tracking_suspend { + async_node* node; + + bool await_ready() const noexcept { + return false; + } + + void await_suspend(std::coroutine_handle<>) const noexcept {} + + void await_resume() const noexcept { + detail::set_current_node(node); + } +}; + // ============================================================================ // promise_exception, transition_await, cancel() // ============================================================================ @@ -245,6 +263,7 @@ std::coroutine_handle<> propagate_fail(async_node* child_node, async_node* paren // Exception: let parent resume normally; await_resume will rethrow. if(child->propagated_exception) { + detail::set_current_node(parent_node); return parent_task->handle(); } @@ -301,8 +320,8 @@ struct task_promise_object : standard_task, promise_result, promise_ return coroutine_handle::from_promise(*this); } - auto initial_suspend() const noexcept { - return std::suspend_always(); + auto initial_suspend() noexcept { + return initial_tracking_suspend{static_cast(this)}; } auto final_suspend() const noexcept { diff --git a/src/async/runtime/frame.cpp b/src/async/runtime/frame.cpp index b7db31f6..b14360b0 100644 --- a/src/async/runtime/frame.cpp +++ b/src/async/runtime/frame.cpp @@ -10,6 +10,27 @@ namespace eventide { +static thread_local async_node* current_running_node = nullptr; + +void detail::set_current_node(async_node* node) noexcept { + current_running_node = node; +} + +async_node* detail::current_node() noexcept { + return current_running_node; +} + +async_node* async_node::current() noexcept { + return current_running_node; +} + +std::string async_node::dump_current_dot() { + if(auto* node = current_running_node) { + return node->dump_dot(); + } + return {}; +} + namespace { #if ETD_WORKAROUND_MSVC_COROUTINE_ASAN_UAF @@ -44,7 +65,9 @@ void drain_pending_destroys() { void detail::resume_and_drain(std::coroutine_handle<> handle) { if(handle) { + auto* prev = current_running_node; handle.resume(); + current_running_node = prev; } #if ETD_WORKAROUND_MSVC_COROUTINE_ASAN_UAF drain_pending_destroys(); @@ -66,17 +89,22 @@ std::coroutine_handle<> aggregate_op::deliver_deferred() noexcept { awaiter->clear_awaitee(); switch(deferred) { - case Deferred::Resume: return static_cast(awaiter)->handle(); + case Deferred::Resume: + current_running_node = awaiter; + return static_cast(awaiter)->handle(); case Deferred::Cancel: if(policy & InterceptCancel) { state = Cancelled; + current_running_node = awaiter; return static_cast(awaiter)->handle(); } awaiter->state = Cancelled; return awaiter->final_transition(); - case Deferred::Error: return static_cast(awaiter)->handle(); + case Deferred::Error: + current_running_node = awaiter; + return static_cast(awaiter)->handle(); case Deferred::None: break; } @@ -168,7 +196,9 @@ void async_node::cancel() { void async_node::resume() { if(is_standard_task()) { if(!is_cancelled() && !is_failed()) { + auto* prev = current_running_node; static_cast(this)->handle().resume(); + current_running_node = prev; #if ETD_WORKAROUND_MSVC_COROUTINE_ASAN_UAF drain_pending_destroys(); #endif @@ -277,6 +307,7 @@ std::coroutine_handle<> async_node::handle_subtask_result(async_node* child) { if(child->state == Cancelled) { if(child->policy & InterceptCancel) { self->awaitee = nullptr; + current_running_node = self; return self->handle(); } @@ -297,6 +328,7 @@ std::coroutine_handle<> async_node::handle_subtask_result(async_node* child) { return propagate(child, self); } } + current_running_node = self; return self->handle(); } diff --git a/tests/unit/common/functional_tests.cpp b/tests/unit/common/functional_tests.cpp index ca03cba1..b8caf6e5 100644 --- a/tests/unit/common/functional_tests.cpp +++ b/tests/unit/common/functional_tests.cpp @@ -63,6 +63,72 @@ struct NonTrivialCallable { static_assert(!std::is_trivially_copyable_v); static_assert(function::sbo_eligible); +struct TrackedCallable { + int* counter; + int val; + + TrackedCallable(int* counter, int val) : counter(counter), val(val) { + ++(*counter); + } + + TrackedCallable(const TrackedCallable&) = delete; + TrackedCallable& operator=(const TrackedCallable&) = delete; + + TrackedCallable(TrackedCallable&& other) noexcept : counter(other.counter), val(other.val) { + ++(*counter); + } + + TrackedCallable& operator=(TrackedCallable&& other) noexcept { + counter = other.counter; + val = other.val; + return *this; + } + + ~TrackedCallable() { + --(*counter); + } + + int operator()(int x) const { + return val + x; + } +}; + +static_assert(function::sbo_eligible); + +struct LargeTrackedCallable { + int* counter; + int val; + [[maybe_unused]] char padding[32]{}; + + LargeTrackedCallable(int* counter, int val) : counter(counter), val(val) { + ++(*counter); + } + + LargeTrackedCallable(const LargeTrackedCallable&) = delete; + LargeTrackedCallable& operator=(const LargeTrackedCallable&) = delete; + + LargeTrackedCallable(LargeTrackedCallable&& other) noexcept : + counter(other.counter), val(other.val) { + ++(*counter); + } + + LargeTrackedCallable& operator=(LargeTrackedCallable&& other) noexcept { + counter = other.counter; + val = other.val; + return *this; + } + + ~LargeTrackedCallable() { + --(*counter); + } + + int operator()(int x) const { + return val + x; + } +}; + +static_assert(!function::sbo_eligible); + // --- Tests --- TEST_SUITE(functional) { @@ -300,6 +366,230 @@ TEST_CASE(function_move_assign_non_trivial_sbo) { EXPECT_EQ(fn2(10), 11); }; +// ===== destructor correctness tests ===== + +TEST_CASE(sbo_destructor_once) { + int counter = 0; + { + TrackedCallable tc{&counter, 10}; + EXPECT_EQ(counter, 1); + function fn(std::move(tc)); + EXPECT_EQ(fn(5), 15); + } + EXPECT_EQ(counter, 0); +}; + +TEST_CASE(heap_destructor_once) { + int counter = 0; + { + LargeTrackedCallable ltc{&counter, 20}; + EXPECT_EQ(counter, 1); + function fn(std::move(ltc)); + EXPECT_EQ(fn(5), 25); + } + EXPECT_EQ(counter, 0); +}; + +TEST_CASE(move_construct_sbo_destructor) { + int counter = 0; + { + TrackedCallable tc{&counter, 7}; + function fn1(std::move(tc)); + function fn2(std::move(fn1)); + EXPECT_EQ(fn2(3), 10); + } + EXPECT_EQ(counter, 0); +}; + +TEST_CASE(move_construct_heap_destructor) { + int counter = 0; + { + LargeTrackedCallable ltc{&counter, 7}; + function fn1(std::move(ltc)); + function fn2(std::move(fn1)); + EXPECT_EQ(fn2(3), 10); + } + EXPECT_EQ(counter, 0); +}; + +TEST_CASE(move_assign_sbo_destructor) { + int counter = 0; + { + TrackedCallable tc1{&counter, 1}; + TrackedCallable tc2{&counter, 2}; + function fn1(std::move(tc1)); + function fn2(std::move(tc2)); + fn2 = std::move(fn1); + EXPECT_EQ(fn2(10), 11); + } + EXPECT_EQ(counter, 0); +}; + +TEST_CASE(move_assign_heap_destructor) { + int counter = 0; + { + LargeTrackedCallable ltc1{&counter, 1}; + LargeTrackedCallable ltc2{&counter, 2}; + function fn1(std::move(ltc1)); + function fn2(std::move(ltc2)); + fn2 = std::move(fn1); + EXPECT_EQ(fn2(10), 11); + } + EXPECT_EQ(counter, 0); +}; + +TEST_CASE(heap_move_chain_destructor) { + int counter = 0; + { + LargeTrackedCallable ltc{&counter, 5}; + function fn1(std::move(ltc)); + function fn2(std::move(fn1)); + function fn3(std::move(fn2)); + function fn4(std::move(fn3)); + EXPECT_EQ(fn4(10), 15); + } + EXPECT_EQ(counter, 0); +}; + +// ===== cross-storage move assignment tests ===== + +TEST_CASE(move_assign_sbo_to_heap) { + int counter = 0; + { + TrackedCallable tc{&counter, 10}; + LargeTrackedCallable ltc{&counter, 20}; + function fn1(std::move(tc)); + function fn2(std::move(ltc)); + EXPECT_EQ(fn1(1), 11); + EXPECT_EQ(fn2(1), 21); + fn2 = std::move(fn1); + EXPECT_EQ(fn2(1), 11); + } + EXPECT_EQ(counter, 0); +}; + +TEST_CASE(move_assign_heap_to_sbo) { + int counter = 0; + { + LargeTrackedCallable ltc{&counter, 30}; + TrackedCallable tc{&counter, 40}; + function fn1(std::move(ltc)); + function fn2(std::move(tc)); + EXPECT_EQ(fn1(1), 31); + EXPECT_EQ(fn2(1), 41); + fn2 = std::move(fn1); + EXPECT_EQ(fn2(1), 31); + } + EXPECT_EQ(counter, 0); +}; + +TEST_CASE(move_assign_fnptr_to_sbo) { + int counter = 0; + { + TrackedCallable tc{&counter, 5}; + function fn1(free_negate); + function fn2(std::move(tc)); + EXPECT_EQ(fn1(3), -3); + EXPECT_EQ(fn2(3), 8); + fn2 = std::move(fn1); + EXPECT_EQ(fn2(3), -3); + } + EXPECT_EQ(counter, 0); +}; + +TEST_CASE(move_assign_sbo_to_fnptr) { + int counter = 0; + { + TrackedCallable tc{&counter, 5}; + function fn1(std::move(tc)); + function fn2(free_negate); + EXPECT_EQ(fn1(3), 8); + EXPECT_EQ(fn2(3), -3); + fn2 = std::move(fn1); + EXPECT_EQ(fn2(3), 8); + } + EXPECT_EQ(counter, 0); +}; + +// ===== self-move assignment ===== + +TEST_CASE(self_move_assign_sbo) { + function fn([](int x) -> int { return x + 42; }); + auto* ptr = &fn; + fn = std::move(*ptr); + EXPECT_EQ(fn(0), 42); +}; + +TEST_CASE(self_move_assign_heap) { + [[maybe_unused]] char padding[32]{}; + int capture = 10; + auto lambda = [capture, padding](int x) -> int { + return capture + x; + }; + static_assert(sizeof(lambda) > function::sbo_size); + function fn(std::move(lambda)); + auto* ptr = &fn; + fn = std::move(*ptr); + EXPECT_EQ(fn(5), 15); +}; + +// ===== additional function_ref tests ===== + +TEST_CASE(function_ref_mutable_lambda) { + int state = 0; + auto lambda = [&state](int x) -> int { + state += x; + return state; + }; + function_ref fn(lambda); + EXPECT_EQ(fn(5), 5); + EXPECT_EQ(fn(3), 8); + EXPECT_EQ(state, 8); +}; + +TEST_CASE(function_ref_reassign) { + function_ref fn(free_add); + EXPECT_EQ(fn(1, 2), 3); + auto mul = [](int a, int b) -> int { + return a * b; + }; + function_ref fn2(mul); + fn = fn2; + EXPECT_EQ(fn(3, 4), 12); +}; + +TEST_CASE(bind_ref_reflects_mutation) { + Adder adder{0}; + auto fn = bind_ref<&Adder::add>(adder); + EXPECT_EQ(fn(5), 5); + adder.base = 100; + EXPECT_EQ(fn(5), 105); +}; + +// ===== complex type tests ===== + +TEST_CASE(function_string_return) { + function fn([](int x) -> std::string { return "val=" + std::to_string(x); }); + EXPECT_EQ(fn(42), std::string("val=42")); + + function fn2(std::move(fn)); + EXPECT_EQ(fn2(0), std::string("val=0")); +}; + +TEST_CASE(function_multiple_args) { + function fn([](int a, int b, int c) -> int { return a + b + c; }); + EXPECT_EQ(fn(1, 2, 3), 6); + + function fn2(std::move(fn)); + EXPECT_EQ(fn2(10, 20, 30), 60); +}; + +TEST_CASE(bind_const_mem_fn) { + Adder adder{10}; + auto fn = bind<&Adder::add_const>(adder); + EXPECT_EQ(fn(5), 15); +}; + // ===== mem_fn tests ===== TEST_CASE(mem_fn_non_const) {