Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
190 changes: 173 additions & 17 deletions bridges/ai/abort_helpers.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,31 +3,187 @@ package ai
import (
"context"
"fmt"
"strings"
"unicode"
"unicode/utf8"

"maunium.net/go/mautrix/bridgev2"
"maunium.net/go/mautrix/id"
)

func formatAbortNotice(stopped int) string {
if stopped <= 0 {
return "Agent was aborted."
type stopPlanKind string

const (
stopPlanKindNoMatch stopPlanKind = "no-match"
stopPlanKindRoomWide stopPlanKind = "room-wide"
stopPlanKindActive stopPlanKind = "active-turn"
stopPlanKindQueued stopPlanKind = "queued-turn"
)

type userStopRequest struct {
Portal *bridgev2.Portal
Meta *PortalMetadata
ReplyTo id.EventID
RequestedByEventID id.EventID
RequestedVia string
}

type userStopPlan struct {
Kind stopPlanKind
Scope string
TargetKind string
TargetEventID id.EventID
}

type userStopResult struct {
Plan userStopPlan
ActiveStopped bool
QueuedStopped int
SubagentsStopped int
}

func stopLabel(count int, singular string) string {
if count == 1 {
return singular
}
return singular + "s"
}

func formatAbortNotice(result userStopResult) string {
switch result.Plan.Kind {
case stopPlanKindNoMatch:
return "No matching active or queued turn found for that reply."
case stopPlanKindActive:
if result.SubagentsStopped > 0 {
return fmt.Sprintf("Stopped that turn. Stopped %d %s.", result.SubagentsStopped, stopLabel(result.SubagentsStopped, "sub-agent"))
}
return "Stopped that turn."
case stopPlanKindQueued:
if result.QueuedStopped <= 1 {
return "Stopped that queued turn."
}
return fmt.Sprintf("Stopped %d queued %s.", result.QueuedStopped, stopLabel(result.QueuedStopped, "turn"))
case stopPlanKindRoomWide:
parts := make([]string, 0, 3)
if result.ActiveStopped {
parts = append(parts, "stopped the active turn")
}
if result.QueuedStopped > 0 {
parts = append(parts, fmt.Sprintf("removed %d queued %s", result.QueuedStopped, stopLabel(result.QueuedStopped, "turn")))
}
if result.SubagentsStopped > 0 {
parts = append(parts, fmt.Sprintf("stopped %d %s", result.SubagentsStopped, stopLabel(result.SubagentsStopped, "sub-agent")))
}
if len(parts) == 0 {
return "No active or queued turns to stop."
}
for i := range parts {
r, size := utf8.DecodeRuneInString(parts[i])
parts[i] = string(unicode.ToUpper(r)) + parts[i][size:]
}
return strings.Join(parts, ". ") + "."
default:
return "No active or queued turns to stop."
}
}

func buildStopMetadata(plan userStopPlan, req userStopRequest) *assistantStopMetadata {
return &assistantStopMetadata{
Reason: "user_stop",
Scope: plan.Scope,
TargetKind: plan.TargetKind,
TargetEventID: plan.TargetEventID.String(),
RequestedByEventID: req.RequestedByEventID.String(),
RequestedVia: strings.TrimSpace(req.RequestedVia),
}
label := "sub-agents"
if stopped == 1 {
label = "sub-agent"
}

func (oc *AIClient) resolveUserStopPlan(req userStopRequest) userStopPlan {
if req.Portal == nil || req.Portal.MXID == "" {
return userStopPlan{Kind: stopPlanKindNoMatch}
}
if req.ReplyTo == "" {
return userStopPlan{
Kind: stopPlanKindRoomWide,
Scope: "room",
TargetKind: "all",
}
}

_, sourceEventID, initialEventID, _ := oc.roomRunTarget(req.Portal.MXID)
if initialEventID != "" && req.ReplyTo == initialEventID {
return userStopPlan{
Kind: stopPlanKindActive,
Scope: "turn",
TargetKind: "placeholder_event",
TargetEventID: req.ReplyTo,
}
}
if sourceEventID != "" && req.ReplyTo == sourceEventID {
return userStopPlan{
Kind: stopPlanKindActive,
Scope: "turn",
TargetKind: "source_event",
TargetEventID: req.ReplyTo,
}
}
return userStopPlan{
Kind: stopPlanKindQueued,
Scope: "turn",
TargetKind: "source_event",
TargetEventID: req.ReplyTo,
}
return fmt.Sprintf("Agent was aborted. Stopped %d %s.", stopped, label)
}

func (oc *AIClient) abortRoom(ctx context.Context, portal *bridgev2.Portal, meta *PortalMetadata) int {
if portal == nil {
return 0
func (oc *AIClient) finalizeStoppedQueueItems(ctx context.Context, items []pendingQueueItem) int {
for _, item := range items {
oc.removePendingAckReactions(oc.backgroundContext(ctx), item.pending.Portal, item.pending)
oc.sendQueueRejectedStatus(ctx, item.pending.Portal, item.pending.Event, item.pending.StatusEvents, "Stopped.")
}
oc.cancelRoomRun(portal.MXID)
oc.clearPendingQueue(portal.MXID)
stopped := oc.stopSubagentRuns(portal.MXID)
if meta != nil {
meta.AbortedLastRun = true
oc.savePortalQuiet(ctx, portal, "abort")
return len(items)
}

func (oc *AIClient) executeUserStopPlan(ctx context.Context, req userStopRequest, plan userStopPlan) userStopResult {
result := userStopResult{Plan: plan}
if req.Portal == nil || req.Portal.MXID == "" {
return result
}
roomID := req.Portal.MXID
switch plan.Kind {
case stopPlanKindRoomWide:
if oc.markRoomRunStopped(roomID, buildStopMetadata(plan, req)) {
result.ActiveStopped = oc.cancelRoomRun(roomID)
}
result.QueuedStopped = oc.finalizeStoppedQueueItems(ctx, oc.drainPendingQueue(roomID))
result.SubagentsStopped = oc.stopSubagentRuns(ctx, roomID)
case stopPlanKindActive:
markedStopped := oc.markRoomRunStopped(roomID, buildStopMetadata(plan, req))
if markedStopped {
result.ActiveStopped = oc.cancelRoomRun(roomID)
}
if result.ActiveStopped {
result.SubagentsStopped = oc.stopSubagentRuns(ctx, roomID)
} else {
result.Plan.Kind = stopPlanKindNoMatch
}
case stopPlanKindQueued:
result.QueuedStopped = oc.finalizeStoppedQueueItems(ctx, oc.removePendingQueueBySourceEvent(roomID, plan.TargetEventID))
if result.QueuedStopped == 0 {
result.Plan.Kind = stopPlanKindNoMatch
}
}

if req.Meta != nil && (result.ActiveStopped || result.QueuedStopped > 0 || result.SubagentsStopped > 0) {
req.Meta.AbortedLastRun = true
oc.savePortalQuiet(ctx, req.Portal, "stop")
}
return stopped
if req.Meta != nil && result.QueuedStopped > 0 {
oc.notifySessionMutation(ctx, req.Portal, req.Meta, false)
}
return result
}

func (oc *AIClient) handleUserStop(ctx context.Context, req userStopRequest) userStopResult {
plan := oc.resolveUserStopPlan(req)
return oc.executeUserStopPlan(ctx, req, plan)
}
187 changes: 187 additions & 0 deletions bridges/ai/abort_helpers_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,187 @@
package ai

import (
"context"
"testing"

"maunium.net/go/mautrix/bridgev2"
"maunium.net/go/mautrix/bridgev2/database"
"maunium.net/go/mautrix/id"

bridgesdk "github.com/beeper/agentremote/sdk"
)

func TestResolveUserStopPlanRoomWideWithoutReply(t *testing.T) {
oc := &AIClient{}
portal := &bridgev2.Portal{Portal: &database.Portal{MXID: "!room:test"}}
req := userStopRequest{Portal: portal, RequestedVia: "command"}

plan := oc.resolveUserStopPlan(req)
if plan.Kind != stopPlanKindRoomWide {
t.Fatalf("expected room-wide stop, got %#v", plan)
}
if plan.TargetKind != "all" || plan.Scope != "room" {
t.Fatalf("unexpected room-wide stop plan: %#v", plan)
}
}

func TestResolveUserStopPlanMatchesActiveReplyTargets(t *testing.T) {
roomID := id.RoomID("!room:test")
oc := &AIClient{
activeRoomRuns: map[id.RoomID]*roomRunState{
roomID: {
sourceEvent: id.EventID("$user"),
initialEvent: id.EventID("$assistant"),
},
},
}
portal := &bridgev2.Portal{Portal: &database.Portal{MXID: roomID}}

placeholderPlan := oc.resolveUserStopPlan(userStopRequest{
Portal: portal,
ReplyTo: id.EventID("$assistant"),
})
if placeholderPlan.Kind != stopPlanKindActive || placeholderPlan.TargetKind != "placeholder_event" {
t.Fatalf("expected placeholder-targeted active stop, got %#v", placeholderPlan)
}

sourcePlan := oc.resolveUserStopPlan(userStopRequest{
Portal: portal,
ReplyTo: id.EventID("$user"),
})
if sourcePlan.Kind != stopPlanKindActive || sourcePlan.TargetKind != "source_event" {
t.Fatalf("expected source-targeted active stop, got %#v", sourcePlan)
}
}

func TestResolveUserStopPlanSpeculativelyReturnsQueued(t *testing.T) {
oc := &AIClient{}
portal := &bridgev2.Portal{Portal: &database.Portal{MXID: "!room:test"}}

plan := oc.resolveUserStopPlan(userStopRequest{
Portal: portal,
ReplyTo: id.EventID("$unknown"),
})
if plan.Kind != stopPlanKindQueued || plan.TargetKind != "source_event" {
t.Fatalf("expected speculative queued stop plan, got %#v", plan)
}
}

func TestExecuteUserStopPlanFallsBackToNoMatch(t *testing.T) {
oc := &AIClient{}
portal := &bridgev2.Portal{Portal: &database.Portal{MXID: "!room:test"}}

result := oc.executeUserStopPlan(context.Background(), userStopRequest{
Portal: portal,
}, userStopPlan{
Kind: stopPlanKindQueued,
Scope: "turn",
TargetKind: "source_event",
TargetEventID: id.EventID("$nonexistent"),
})
if result.Plan.Kind != stopPlanKindNoMatch {
t.Fatalf("expected no-match fallback, got %#v", result.Plan)
}
if result.QueuedStopped != 0 {
t.Fatalf("expected zero queued stopped, got %d", result.QueuedStopped)
}
}

func TestExecuteUserStopPlanRemovesOnlyTargetedQueuedTurn(t *testing.T) {
roomID := id.RoomID("!room:test")
oc := &AIClient{
pendingQueues: map[id.RoomID]*pendingQueue{
roomID: {
items: []pendingQueueItem{
{pending: pendingMessage{SourceEventID: id.EventID("$one")}},
{pending: pendingMessage{SourceEventID: id.EventID("$two")}},
},
},
},
}
portal := &bridgev2.Portal{Portal: &database.Portal{MXID: roomID}}

result := oc.executeUserStopPlan(context.Background(), userStopRequest{
Portal: portal,
}, userStopPlan{
Kind: stopPlanKindQueued,
Scope: "turn",
TargetKind: "source_event",
TargetEventID: id.EventID("$one"),
})
if result.QueuedStopped != 1 {
t.Fatalf("expected one queued turn to stop, got %#v", result)
}
snapshot := oc.getQueueSnapshot(roomID)
if snapshot == nil || len(snapshot.items) != 1 {
t.Fatalf("expected one queued item to remain, got %#v", snapshot)
}
if got := snapshot.items[0].pending.sourceEventID(); got != id.EventID("$two") {
t.Fatalf("expected remaining queued event $two, got %q", got)
}
}

func TestExecuteUserStopPlanActiveNoOpFallsBackToNoMatch(t *testing.T) {
roomID := id.RoomID("!room:test")
oc := &AIClient{
activeRoomRuns: map[id.RoomID]*roomRunState{
roomID: {
sourceEvent: id.EventID("$user"),
},
},
}
portal := &bridgev2.Portal{Portal: &database.Portal{MXID: roomID}}

result := oc.executeUserStopPlan(context.Background(), userStopRequest{
Portal: portal,
ReplyTo: id.EventID("$user"),
}, userStopPlan{
Kind: stopPlanKindActive,
Scope: "turn",
TargetKind: "source_event",
TargetEventID: id.EventID("$user"),
})
if result.Plan.Kind != stopPlanKindNoMatch {
t.Fatalf("expected no-match fallback for no-op active stop, got %#v", result.Plan)
}
if result.ActiveStopped {
t.Fatalf("expected active stop to report false, got %#v", result)
}
}

func TestBuildStreamUIMessageIncludesStopMetadata(t *testing.T) {
oc := &AIClient{}
conv := bridgesdk.NewConversation[*AIClient, *Config](context.Background(), nil, nil, bridgev2.EventSender{}, nil, nil)
turn := conv.StartTurn(context.Background(), nil, &bridgesdk.SourceRef{EventID: "$user", SenderID: "@user:test"})
turn.SetID("turn-stop")
state := &streamingState{
turn: turn,
finishReason: "stop",
responseID: "resp_123",
completedAtMs: 1,
}
state.stop.Store(&assistantStopMetadata{
Reason: "user_stop",
Scope: "turn",
TargetKind: "source_event",
TargetEventID: "$user",
RequestedByEventID: "$stop",
RequestedVia: "command",
})

ui := oc.buildStreamUIMessage(state, nil, nil)
metadata, ok := ui["metadata"].(map[string]any)
if !ok {
t.Fatalf("expected metadata map, got %T", ui["metadata"])
}
stop, ok := metadata["stop"].(map[string]any)
if !ok {
t.Fatalf("expected nested stop metadata, got %#v", metadata["stop"])
}
if stop["reason"] != "user_stop" || stop["requested_via"] != "command" {
t.Fatalf("unexpected stop metadata: %#v", stop)
}
if metadata["response_status"] != "cancelled" {
t.Fatalf("expected cancelled response status for stopped turn, got %#v", metadata["response_status"])
}
}
Loading
Loading