Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 19 additions & 0 deletions include/eventide/async/runtime/frame.h
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,14 @@ class async_node {
/// Dump the async graph reachable from this node as a DOT (graphviz) graph.
std::string dump_dot() const;

/// Returns the async_node whose coroutine body is currently executing on
/// this thread, or nullptr if no coroutine is active.
static async_node* current() noexcept;

/// Convenience: calls dump_dot() on the current node. Returns an empty
/// string if no coroutine is active.
static std::string dump_current_dot();

private:
const static async_node* get_awaiter(const async_node* node);
const static sync_primitive* get_resource_parent(const async_node* node);
Expand Down Expand Up @@ -420,4 +428,15 @@ class system_op : public async_node {
void complete() noexcept;
};

namespace detail {

/// Thread-local pointer to the async_node whose coroutine body is currently
/// executing on this thread. Set by the tracking awaiters in task.h and
/// save/restored around every resume entry-point so that user code can always
/// call async_node::current() to obtain it.
void set_current_node(async_node* node) noexcept;
async_node* current_node() noexcept;

} // namespace detail

} // namespace eventide
23 changes: 21 additions & 2 deletions include/eventide/async/runtime/task.h
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,24 @@ struct promise_result<void, E, C> {
}
};

// ============================================================================
// initial_tracking_suspend — sets current node when coroutine body first runs
// ============================================================================

struct initial_tracking_suspend {
async_node* node;

bool await_ready() const noexcept {
return false;
}

void await_suspend(std::coroutine_handle<>) const noexcept {}

void await_resume() const noexcept {
detail::set_current_node(node);
}
};

// ============================================================================
// promise_exception, transition_await, cancel()
// ============================================================================
Expand Down Expand Up @@ -245,6 +263,7 @@ std::coroutine_handle<> propagate_fail(async_node* child_node, async_node* paren

// Exception: let parent resume normally; await_resume will rethrow.
if(child->propagated_exception) {
detail::set_current_node(parent_node);
return parent_task->handle();
}

Expand Down Expand Up @@ -301,8 +320,8 @@ struct task_promise_object : standard_task, promise_result<T, E, void>, promise_
return coroutine_handle::from_promise(*this);
}

auto initial_suspend() const noexcept {
return std::suspend_always();
auto initial_suspend() noexcept {
return initial_tracking_suspend{static_cast<async_node*>(this)};
}

auto final_suspend() const noexcept {
Expand Down
36 changes: 34 additions & 2 deletions src/async/runtime/frame.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,27 @@

namespace eventide {

static thread_local async_node* current_running_node = nullptr;

void detail::set_current_node(async_node* node) noexcept {
current_running_node = node;
}

async_node* detail::current_node() noexcept {
return current_running_node;
}

async_node* async_node::current() noexcept {
return current_running_node;
}

std::string async_node::dump_current_dot() {
if(auto* node = current_running_node) {
return node->dump_dot();
}
return {};
}

namespace {

#if ETD_WORKAROUND_MSVC_COROUTINE_ASAN_UAF
Expand Down Expand Up @@ -44,7 +65,9 @@ void drain_pending_destroys() {

void detail::resume_and_drain(std::coroutine_handle<> handle) {
if(handle) {
auto* prev = current_running_node;
handle.resume();
current_running_node = prev;
}
#if ETD_WORKAROUND_MSVC_COROUTINE_ASAN_UAF
drain_pending_destroys();
Expand All @@ -66,17 +89,22 @@ std::coroutine_handle<> aggregate_op::deliver_deferred() noexcept {
awaiter->clear_awaitee();

switch(deferred) {
case Deferred::Resume: return static_cast<standard_task*>(awaiter)->handle();
case Deferred::Resume:
current_running_node = awaiter;
return static_cast<standard_task*>(awaiter)->handle();

case Deferred::Cancel:
if(policy & InterceptCancel) {
state = Cancelled;
current_running_node = awaiter;
return static_cast<standard_task*>(awaiter)->handle();
}
awaiter->state = Cancelled;
return awaiter->final_transition();

case Deferred::Error: return static_cast<standard_task*>(awaiter)->handle();
case Deferred::Error:
current_running_node = awaiter;
return static_cast<standard_task*>(awaiter)->handle();

case Deferred::None: break;
}
Expand Down Expand Up @@ -168,7 +196,9 @@ void async_node::cancel() {
void async_node::resume() {
if(is_standard_task()) {
if(!is_cancelled() && !is_failed()) {
auto* prev = current_running_node;
static_cast<standard_task*>(this)->handle().resume();
current_running_node = prev;
#if ETD_WORKAROUND_MSVC_COROUTINE_ASAN_UAF
drain_pending_destroys();
#endif
Expand Down Expand Up @@ -277,6 +307,7 @@ std::coroutine_handle<> async_node::handle_subtask_result(async_node* child) {
if(child->state == Cancelled) {
if(child->policy & InterceptCancel) {
self->awaitee = nullptr;
current_running_node = self;
return self->handle();
}

Expand All @@ -297,6 +328,7 @@ std::coroutine_handle<> async_node::handle_subtask_result(async_node* child) {
return propagate(child, self);
}
}
current_running_node = self;
return self->handle();
}

Expand Down
Loading
Loading