From 71a41026a94f645214b828838834b051a7d0f00a Mon Sep 17 00:00:00 2001 From: Artur Khantimirov Date: Fri, 20 Mar 2026 13:54:19 +1300 Subject: [PATCH 01/22] Add initial design spec for type-safe DAG execution library --- .../specs/2026-03-20-taskdag-design.md | 314 ++++++++++++++++++ 1 file changed, 314 insertions(+) create mode 100644 docs/superpowers/specs/2026-03-20-taskdag-design.md diff --git a/docs/superpowers/specs/2026-03-20-taskdag-design.md b/docs/superpowers/specs/2026-03-20-taskdag-design.md new file mode 100644 index 00000000000..dc5101f9ce3 --- /dev/null +++ b/docs/superpowers/specs/2026-03-20-taskdag-design.md @@ -0,0 +1,314 @@ +# taskflow — Type-Safe DAG Execution Library + +## Problem + +`go-workflow` wires task dependencies through untyped closures. Dependencies are declared separately from the data they carry (`DependsOn` + `Input` callbacks), making it easy to forget one or the other. The result: nil-pointer panics at runtime, logic that's hard to follow, and no compile-time safety on data flow between tasks. + +## Goal + +A general-purpose Go library for defining and executing tasks as a DAG with: + +1. **Type-safe dependencies** — each task declares its upstream tasks as typed struct fields. The compiler enforces field types; the framework enforces that they're wired. +2. **Concurrent execution** — independent tasks run in parallel automatically. +3. **Simplicity** — no wrapper types, no magic registration functions, no hidden state. Tasks are plain Go structs. + +## Core Design + +### Task Contract + +A task is any struct that implements: + +```go +type Task interface { + Do(ctx context.Context) error +} +``` + +Dependencies are declared as a struct field named `Deps` containing pointers to upstream tasks. Outputs are written to a field named `Output`. Both are optional — a leaf task has no Deps, a sink task has no Output. + +```go +type BuildImage struct { + Output BuildOutput +} + +func (b *BuildImage) Do(ctx context.Context) error { + b.Output = BuildOutput{ImagePath: "/img"} + return nil +} + +type Deploy struct { + Deps struct { + Build *BuildImage + Config *LoadConfig + } + Output DeployOutput +} + +func (d *Deploy) Do(ctx context.Context) error { + d.Output = DeployOutput{ + URL: fmt.Sprintf("%s:%d", d.Deps.Build.Output.ImagePath, d.Deps.Config.Output.Port), + } + return nil +} +``` + +### Wiring + +The DAG is expressed through plain Go struct initialization: + +```go +build := &BuildImage{} +config := &LoadConfig{} +deploy := &Deploy{ + Deps: struct{ Build *BuildImage; Config *LoadConfig }{ + Build: build, Config: config, + }, +} +``` + +No `Add()`, no `Connect()`, no `DependsOn()`. The struct field assignments *are* the dependency declarations. + +### Execution + +```go +err := taskflow.Execute(ctx, deploy) +``` + +`Execute` takes a root task and: + +1. **Walks the graph** — reflects over each task's `Deps` field, follows pointers recursively to discover the full DAG. +2. **Validates** — checks for cycles and nil Deps pointers. Returns an error before running anything if the graph is invalid. +3. **Deduplicates** — the same task pointer reached via multiple paths (diamond dependency) is executed exactly once. +4. **Schedules** — runs tasks concurrently. A task starts only after all its Deps have completed successfully. +5. **Populates outputs** — since Deps hold pointers to upstream tasks, `task.Deps.Upstream.Output` is directly readable after the upstream completes. No copying needed — it's just Go pointer dereferencing. + +### Multiple Roots + +```go +err := taskflow.Execute(ctx, teardown1, teardown2) +``` + +All tasks across both graphs are deduplicated and run as a single DAG. + +## Configuration + +```go +err := taskflow.Execute(ctx, root, taskflow.Config{ + OnError: taskflow.CancelAll, + MaxConcurrency: 4, +}) +``` + +### Config Fields + +| Field | Type | Default | Description | +|-------|------|---------|-------------| +| `OnError` | `ErrorStrategy` | `CancelDependents` | What happens when a task fails | +| `MaxConcurrency` | `int` | `0` (unlimited) | Max number of tasks running in parallel | + +### Error Strategies + +**`CancelDependents` (default):** When a task fails, all tasks that transitively depend on it are skipped. Independent branches continue running. + +**`CancelAll`:** When any task fails, the context passed to all running tasks is canceled. Fail-fast. + +## Graph Discovery via Reflection + +At `Execute` time, the framework: + +1. For each task, checks if it has a `Deps` field of struct type. +2. Iterates over all fields in `Deps`. Each field must be a pointer to a struct that implements `Task`. +3. Follows those pointers recursively to discover the full graph. +4. Non-pointer fields in Deps, or pointers to non-Task types, are a validation error. +5. Nil pointer fields in Deps are a validation error. + +The framework never touches `Output` — that's purely a user convention. Tasks write to `self.Output`, downstream tasks read `dep.Output`. The framework only cares about `Deps` pointers and the `Task` interface. + +## Error Reporting + +`Execute` returns a `DAGError` containing the result of every task: + +```go +type DAGError struct { + Results map[Task]TaskResult +} + +type TaskResult struct { + Status TaskStatus // Succeeded, Failed, Skipped, Canceled + Err error // nil if Succeeded +} + +type TaskStatus int + +const ( + Succeeded TaskStatus = iota + Failed + Skipped // dependency failed, this task was not run + Canceled // context was canceled while running +) +``` + +`DAGError` implements `error`. `Execute` returns `nil` if all tasks succeeded. + +## Task Reuse + +Each `Execute` call re-runs all tasks in the graph from scratch. Previous `Output` values are overwritten. + +## Accessing Transitive Dependencies + +A task can read through its deps to access transitive outputs: + +```go +func (c *CreateCluster) Do(ctx context.Context) error { + rgName := c.Deps.Subnet.Deps.VNet.Deps.RG.Output.RGName + return nil +} +``` + +Safe — DAG ordering guarantees all transitive deps have completed. + +## Complete Example + +```go +package main + +import ( + "context" + "fmt" + + "github.com/example/taskflow" +) + +// --- Task definitions --- + +type CreateRG struct { + Output struct{ RGName string } +} + +func (t *CreateRG) Do(ctx context.Context) error { + t.Output.RGName = "my-rg" + return nil +} + +type CreateVNet struct { + Deps struct { + RG *CreateRG + } + Output struct{ VNetID string } +} + +func (t *CreateVNet) Do(ctx context.Context) error { + t.Output.VNetID = fmt.Sprintf("%s-vnet", t.Deps.RG.Output.RGName) + return nil +} + +type CreateSubnet struct { + Deps struct { + VNet *CreateVNet + } + Output struct{ SubnetID string } +} + +func (t *CreateSubnet) Do(ctx context.Context) error { + t.Output.SubnetID = fmt.Sprintf("%s-subnet", t.Deps.VNet.Output.VNetID) + return nil +} + +type CreateCluster struct { + Deps struct { + RG *CreateRG + Subnet *CreateSubnet + } + Output struct{ ClusterID string } +} + +func (t *CreateCluster) Do(ctx context.Context) error { + t.Output.ClusterID = fmt.Sprintf("cluster-in-%s-%s", + t.Deps.RG.Output.RGName, + t.Deps.Subnet.Output.SubnetID) + return nil +} + +type RunTests struct { + Deps struct { + Cluster *CreateCluster + } +} + +func (t *RunTests) Do(ctx context.Context) error { + fmt.Println("Running tests on", t.Deps.Cluster.Output.ClusterID) + return nil +} + +type Teardown struct { + Deps struct { + RG *CreateRG + Tests *RunTests + } +} + +func (t *Teardown) Do(ctx context.Context) error { + fmt.Println("Tearing down", t.Deps.RG.Output.RGName) + return nil +} + +// --- Wiring and execution --- + +func main() { + rg := &CreateRG{} + vnet := &CreateVNet{Deps: struct{ RG *CreateRG }{RG: rg}} + subnet := &CreateSubnet{Deps: struct{ VNet *CreateVNet }{VNet: vnet}} + cluster := &CreateCluster{Deps: struct { + RG *CreateRG + Subnet *CreateSubnet + }{RG: rg, Subnet: subnet}} + tests := &RunTests{Deps: struct{ Cluster *CreateCluster }{Cluster: cluster}} + teardown := &Teardown{Deps: struct { + RG *CreateRG + Tests *RunTests + }{RG: rg, Tests: tests}} + + // DAG (concurrent where possible): + // + // CreateRG ──┬── CreateVNet ── CreateSubnet ──┐ + // │ │ + // └──────────────── CreateCluster ──┘ + // │ + // RunTests + // │ + // Teardown + + err := taskflow.Execute(context.Background(), teardown) + if err != nil { + panic(err) + } +} +``` + +## Validation Rules (enforced at Execute time) + +| Rule | Detection | Error | +|------|-----------|-------| +| Nil pointer in Deps | Reflection | `"task %T has nil dependency field %s"` | +| Deps field is not a pointer to Task | Reflection | `"task %T.Deps.%s: %T does not implement Task"` | +| Cycle in dependency graph | Topological sort | `"cycle detected: A -> B -> A"` | +| Deps field is not a struct | Reflection | `"task %T.Deps must be a struct"` | + +## What's NOT in Scope (V1) + +Intentionally deferred to keep V1 minimal: + +- **Retry / timeout** — implement inside `Do()`. Framework support later. +- **Conditional execution** — adds complexity. Deferred. +- **Observability hooks** — deferred. Users can wrap tasks. +- **Step naming / logging** — can use `fmt.Stringer`. Deferred. +- **WorkflowMutator pattern** — not needed. The graph is just Go structs; mutation is just Go code. + +## Package Name Candidates + +- `taskflow` — descriptive, flows well +- `tasks` — minimal, Go-idiomatic +- `rundag` — action-oriented +- `orchid` — short for orchestration, catchy + +Open to your preference. From 9efee3e9959362980672272a400aefab6f19a422 Mon Sep 17 00:00:00 2001 From: Artur Khantimirov Date: Fri, 20 Mar 2026 13:59:18 +1300 Subject: [PATCH 02/22] Rename taskflow to tasks, use named Deps structs throughout --- .../specs/2026-03-20-taskdag-design.md | 232 +++++++++++------- 1 file changed, 143 insertions(+), 89 deletions(-) diff --git a/docs/superpowers/specs/2026-03-20-taskdag-design.md b/docs/superpowers/specs/2026-03-20-taskdag-design.md index dc5101f9ce3..0eff67c429a 100644 --- a/docs/superpowers/specs/2026-03-20-taskdag-design.md +++ b/docs/superpowers/specs/2026-03-20-taskdag-design.md @@ -1,4 +1,4 @@ -# taskflow — Type-Safe DAG Execution Library +# tasks — Type-Safe DAG Execution Library ## Problem @@ -24,9 +24,17 @@ type Task interface { } ``` +**Tasks must use pointer receivers for `Do`.** Since tasks write to `self.Output` during execution, a value receiver would discard the result. The framework validates this at graph construction time. + Dependencies are declared as a struct field named `Deps` containing pointers to upstream tasks. Outputs are written to a field named `Output`. Both are optional — a leaf task has no Deps, a sink task has no Output. +**The `Deps` struct must contain only pointers to types implementing `Task`.** Any other field type (e.g., `*string`, `int`, config structs) is a validation error. Use separate struct fields outside `Deps` for non-task data. + ```go +type BuildOutput struct { + ImagePath string +} + type BuildImage struct { Output BuildOutput } @@ -36,11 +44,17 @@ func (b *BuildImage) Do(ctx context.Context) error { return nil } +type DeployDeps struct { + Build *BuildImage + Config *LoadConfig +} + +type DeployOutput struct { + URL string +} + type Deploy struct { - Deps struct { - Build *BuildImage - Config *LoadConfig - } + Deps DeployDeps Output DeployOutput } @@ -60,9 +74,7 @@ The DAG is expressed through plain Go struct initialization: build := &BuildImage{} config := &LoadConfig{} deploy := &Deploy{ - Deps: struct{ Build *BuildImage; Config *LoadConfig }{ - Build: build, Config: config, - }, + Deps: DeployDeps{Build: build, Config: config}, } ``` @@ -71,46 +83,55 @@ No `Add()`, no `Connect()`, no `DependsOn()`. The struct field assignments *are* ### Execution ```go -err := taskflow.Execute(ctx, deploy) +func Execute(ctx context.Context, cfg Config, roots ...Task) error ``` -`Execute` takes a root task and: +`Execute` takes a context, a config, and one or more root tasks: -1. **Walks the graph** — reflects over each task's `Deps` field, follows pointers recursively to discover the full DAG. -2. **Validates** — checks for cycles and nil Deps pointers. Returns an error before running anything if the graph is invalid. -3. **Deduplicates** — the same task pointer reached via multiple paths (diamond dependency) is executed exactly once. -4. **Schedules** — runs tasks concurrently. A task starts only after all its Deps have completed successfully. -5. **Populates outputs** — since Deps hold pointers to upstream tasks, `task.Deps.Upstream.Output` is directly readable after the upstream completes. No copying needed — it's just Go pointer dereferencing. +```go +// Single root with default config +err := tasks.Execute(ctx, tasks.Config{}, deploy) -### Multiple Roots +// Multiple roots +err := tasks.Execute(ctx, tasks.Config{}, teardown1, teardown2) -```go -err := taskflow.Execute(ctx, teardown1, teardown2) +// With options +err := tasks.Execute(ctx, tasks.Config{ + OnError: tasks.CancelAll, + MaxConcurrency: 4, +}, deploy) ``` -All tasks across both graphs are deduplicated and run as a single DAG. +When multiple roots are provided, all tasks across their graphs are deduplicated by pointer identity and run as a single DAG. If a root also appears as an interior node of another root's graph, it is deduplicated — not an error. + +`Execute` proceeds as follows: + +1. **Walks the graph** — reflects over each task's `Deps` field, follows pointers recursively to discover the full DAG. Nodes are identified by pointer identity. +2. **Validates** — checks for cycles (via topological sort on pointer-identity nodes), nil Deps pointers, and invalid Deps field types. Returns an error before running anything if the graph is invalid. +3. **Deduplicates** — the same task pointer reached via multiple paths (diamond dependency) is executed exactly once. +4. **Schedules** — runs tasks concurrently. A task starts only after all its Deps have completed successfully (or with the appropriate status per the error strategy). +5. **Outputs are available** — since Deps hold pointers to upstream tasks, `task.Deps.Upstream.Output` is directly readable inside `Do()`. The framework guarantees happens-before ordering: a task's goroutine is only launched after all upstream goroutines have completed and their results are visible (synchronized via `sync.WaitGroup` or channel). ## Configuration ```go -err := taskflow.Execute(ctx, root, taskflow.Config{ - OnError: taskflow.CancelAll, - MaxConcurrency: 4, -}) +type Config struct { + // OnError controls what happens when a task fails. + // Default (zero value): CancelDependents. + OnError ErrorStrategy + + // MaxConcurrency limits how many tasks run in parallel. + // 0 (default): unlimited. 1: serial execution (useful for debugging). + // Negative values are treated as 0 (unlimited). + MaxConcurrency int +} ``` -### Config Fields - -| Field | Type | Default | Description | -|-------|------|---------|-------------| -| `OnError` | `ErrorStrategy` | `CancelDependents` | What happens when a task fails | -| `MaxConcurrency` | `int` | `0` (unlimited) | Max number of tasks running in parallel | - ### Error Strategies -**`CancelDependents` (default):** When a task fails, all tasks that transitively depend on it are skipped. Independent branches continue running. +**`CancelDependents` (default / zero value):** When a task fails, all tasks that transitively depend on it are skipped (status `Skipped`). Independent branches continue running. Already-running tasks are not interrupted. -**`CancelAll`:** When any task fails, the context passed to all running tasks is canceled. Fail-fast. +**`CancelAll`:** When any task fails, the context passed to all running and future tasks is canceled. Tasks currently in `Do()` receive cancellation via `ctx.Done()` and should return promptly (status `Canceled`). Tasks that haven't started yet also get status `Canceled`. ## Graph Discovery via Reflection @@ -124,35 +145,55 @@ At `Execute` time, the framework: The framework never touches `Output` — that's purely a user convention. Tasks write to `self.Output`, downstream tasks read `dep.Output`. The framework only cares about `Deps` pointers and the `Task` interface. +### Concurrency Safety + +The framework guarantees that when `Do(ctx)` is called on a task, all upstream tasks have fully completed and their writes (including to `Output` fields) are visible. This happens-before relationship is established through Go synchronization primitives (e.g., `sync.WaitGroup.Done()` in the upstream goroutine, `sync.WaitGroup.Wait()` before launching the downstream goroutine). + +Tasks must not mutate their `Deps` fields during `Do()`. Doing so is undefined behavior. + ## Error Reporting -`Execute` returns a `DAGError` containing the result of every task: +`Execute` returns a `*DAGError` containing the result of every task: ```go type DAGError struct { + // Results is keyed by task pointer. Since tasks must be pointers, + // they are comparable and safe to use as map keys. Results map[Task]TaskResult } type TaskResult struct { - Status TaskStatus // Succeeded, Failed, Skipped, Canceled - Err error // nil if Succeeded + Status TaskStatus + Err error // nil if Succeeded } type TaskStatus int const ( Succeeded TaskStatus = iota - Failed - Skipped // dependency failed, this task was not run - Canceled // context was canceled while running + Failed // Do() returned a non-nil error + Skipped // a dependency failed (CancelDependents mode); task was never started + Canceled // context was canceled (CancelAll mode); task may or may not have started ) ``` `DAGError` implements `error`. `Execute` returns `nil` if all tasks succeeded. +### Inspecting Results + +```go +err := tasks.Execute(ctx, tasks.Config{}, root) +var dagErr *tasks.DAGError +if errors.As(err, &dagErr) { + for task, result := range dagErr.Results { + fmt.Printf("%T: %s %v\n", task, result.Status, result.Err) + } +} +``` + ## Task Reuse -Each `Execute` call re-runs all tasks in the graph from scratch. Previous `Output` values are overwritten. +Each `Execute` call re-runs all tasks in the graph from scratch. The framework does **not** reset `Output` fields — it is the user's responsibility to ensure `Do()` overwrites `Output` fully. If a task can fail partway through writing `Output`, the user should write to a local variable first and assign to `Output` only on success. ## Accessing Transitive Dependencies @@ -165,7 +206,7 @@ func (c *CreateCluster) Do(ctx context.Context) error { } ``` -Safe — DAG ordering guarantees all transitive deps have completed. +This is safe — DAG ordering guarantees all transitive deps have completed. However, it creates coupling to the internal structure of transitive dependencies. Prefer declaring direct deps when practical. ## Complete Example @@ -176,13 +217,17 @@ import ( "context" "fmt" - "github.com/example/taskflow" + "github.com/example/tasks" ) // --- Task definitions --- +type CreateRGOutput struct { + RGName string +} + type CreateRG struct { - Output struct{ RGName string } + Output CreateRGOutput } func (t *CreateRG) Do(ctx context.Context) error { @@ -190,11 +235,17 @@ func (t *CreateRG) Do(ctx context.Context) error { return nil } +type CreateVNetDeps struct { + RG *CreateRG +} + +type CreateVNetOutput struct { + VNetID string +} + type CreateVNet struct { - Deps struct { - RG *CreateRG - } - Output struct{ VNetID string } + Deps CreateVNetDeps + Output CreateVNetOutput } func (t *CreateVNet) Do(ctx context.Context) error { @@ -202,11 +253,17 @@ func (t *CreateVNet) Do(ctx context.Context) error { return nil } +type CreateSubnetDeps struct { + VNet *CreateVNet +} + +type CreateSubnetOutput struct { + SubnetID string +} + type CreateSubnet struct { - Deps struct { - VNet *CreateVNet - } - Output struct{ SubnetID string } + Deps CreateSubnetDeps + Output CreateSubnetOutput } func (t *CreateSubnet) Do(ctx context.Context) error { @@ -214,12 +271,18 @@ func (t *CreateSubnet) Do(ctx context.Context) error { return nil } +type CreateClusterDeps struct { + RG *CreateRG + Subnet *CreateSubnet +} + +type CreateClusterOutput struct { + ClusterID string +} + type CreateCluster struct { - Deps struct { - RG *CreateRG - Subnet *CreateSubnet - } - Output struct{ ClusterID string } + Deps CreateClusterDeps + Output CreateClusterOutput } func (t *CreateCluster) Do(ctx context.Context) error { @@ -229,10 +292,12 @@ func (t *CreateCluster) Do(ctx context.Context) error { return nil } +type RunTestsDeps struct { + Cluster *CreateCluster +} + type RunTests struct { - Deps struct { - Cluster *CreateCluster - } + Deps RunTestsDeps } func (t *RunTests) Do(ctx context.Context) error { @@ -240,11 +305,13 @@ func (t *RunTests) Do(ctx context.Context) error { return nil } +type TeardownDeps struct { + RG *CreateRG + Tests *RunTests +} + type Teardown struct { - Deps struct { - RG *CreateRG - Tests *RunTests - } + Deps TeardownDeps } func (t *Teardown) Do(ctx context.Context) error { @@ -256,29 +323,23 @@ func (t *Teardown) Do(ctx context.Context) error { func main() { rg := &CreateRG{} - vnet := &CreateVNet{Deps: struct{ RG *CreateRG }{RG: rg}} - subnet := &CreateSubnet{Deps: struct{ VNet *CreateVNet }{VNet: vnet}} - cluster := &CreateCluster{Deps: struct { - RG *CreateRG - Subnet *CreateSubnet - }{RG: rg, Subnet: subnet}} - tests := &RunTests{Deps: struct{ Cluster *CreateCluster }{Cluster: cluster}} - teardown := &Teardown{Deps: struct { - RG *CreateRG - Tests *RunTests - }{RG: rg, Tests: tests}} + vnet := &CreateVNet{Deps: CreateVNetDeps{RG: rg}} + subnet := &CreateSubnet{Deps: CreateSubnetDeps{VNet: vnet}} + cluster := &CreateCluster{Deps: CreateClusterDeps{RG: rg, Subnet: subnet}} + tests := &RunTests{Deps: RunTestsDeps{Cluster: cluster}} + teardown := &Teardown{Deps: TeardownDeps{RG: rg, Tests: tests}} // DAG (concurrent where possible): // // CreateRG ──┬── CreateVNet ── CreateSubnet ──┐ // │ │ - // └──────────────── CreateCluster ──┘ - // │ - // RunTests - // │ - // Teardown + // ├──────────────── CreateCluster ──┘ + // │ │ + // │ RunTests + // │ │ + // └──────────────── Teardown - err := taskflow.Execute(context.Background(), teardown) + err := tasks.Execute(context.Background(), tasks.Config{}, teardown) if err != nil { panic(err) } @@ -291,7 +352,7 @@ func main() { |------|-----------|-------| | Nil pointer in Deps | Reflection | `"task %T has nil dependency field %s"` | | Deps field is not a pointer to Task | Reflection | `"task %T.Deps.%s: %T does not implement Task"` | -| Cycle in dependency graph | Topological sort | `"cycle detected: A -> B -> A"` | +| Cycle in dependency graph | Topological sort (pointer identity) | `"cycle detected: %T(%p) -> %T(%p) -> ..."` | | Deps field is not a struct | Reflection | `"task %T.Deps must be a struct"` | ## What's NOT in Scope (V1) @@ -303,12 +364,5 @@ Intentionally deferred to keep V1 minimal: - **Observability hooks** — deferred. Users can wrap tasks. - **Step naming / logging** — can use `fmt.Stringer`. Deferred. - **WorkflowMutator pattern** — not needed. The graph is just Go structs; mutation is just Go code. - -## Package Name Candidates - -- `taskflow` — descriptive, flows well -- `tasks` — minimal, Go-idiomatic -- `rundag` — action-oriented -- `orchid` — short for orchestration, catchy - -Open to your preference. +- **Output reset between re-runs** — user responsibility. Framework doesn't touch Output. +- **Linter for nil deps** — out of scope for the library, but a natural companion tool. From 1e5f8bf8579feecf387a2c47e1c5a2c01ac8bc42 Mon Sep 17 00:00:00 2001 From: Artur Khantimirov Date: Fri, 20 Mar 2026 15:14:35 +1300 Subject: [PATCH 03/22] =?UTF-8?q?tasks:=20add=20core=20types=20=E2=80=94?= =?UTF-8?q?=20Task=20interface,=20Config,=20DAGError,=20TaskStatus?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- e2e/tasks/task.go | 96 ++++++++++++++++++++++++++++++++++++++++++ e2e/tasks/task_test.go | 63 +++++++++++++++++++++++++++ 2 files changed, 159 insertions(+) create mode 100644 e2e/tasks/task.go create mode 100644 e2e/tasks/task_test.go diff --git a/e2e/tasks/task.go b/e2e/tasks/task.go new file mode 100644 index 00000000000..48c688c70bc --- /dev/null +++ b/e2e/tasks/task.go @@ -0,0 +1,96 @@ +package tasks + +import ( + "context" + "fmt" + "strings" +) + +// Task is the interface that all tasks must implement. +// Implement Do with a pointer receiver. +type Task interface { + Do(ctx context.Context) error +} + +// ErrorStrategy controls behavior when a task fails. +type ErrorStrategy int + +const ( + // CancelDependents skips tasks that transitively depend on the failed task. + // Independent branches continue running. + CancelDependents ErrorStrategy = iota + + // CancelAll cancels the context for all running and pending tasks. + CancelAll +) + +// Config controls execution behavior. +type Config struct { + // OnError controls what happens when a task fails. + // Default (zero value): CancelDependents. + OnError ErrorStrategy + + // MaxConcurrency limits how many tasks run in parallel. + // 0 (default): unlimited. 1: serial execution. + // Negative values are treated as 0 (unlimited). + MaxConcurrency int +} + +// TaskStatus represents the final status of a task after execution. +type TaskStatus int + +const ( + Succeeded TaskStatus = iota + Failed + Skipped + Canceled +) + +func (s TaskStatus) String() string { + switch s { + case Succeeded: + return "Succeeded" + case Failed: + return "Failed" + case Skipped: + return "Skipped" + case Canceled: + return "Canceled" + default: + return fmt.Sprintf("TaskStatus(%d)", int(s)) + } +} + +// TaskResult holds the outcome of a single task. +type TaskResult struct { + Status TaskStatus + Err error +} + +// DAGError is returned by Execute when one or more tasks did not succeed. +type DAGError struct { + Results map[Task]TaskResult +} + +func (e *DAGError) Error() string { + var failed []string + for task, result := range e.Results { + if result.Status != Succeeded { + failed = append(failed, fmt.Sprintf("%T: %s: %v", task, result.Status, result.Err)) + } + } + return fmt.Sprintf("dag execution failed: %s", strings.Join(failed, "; ")) +} + +// ValidationError is returned when the task graph fails validation. +type ValidationError struct { + Task Task + Message string +} + +func (e *ValidationError) Error() string { + if e.Task != nil { + return fmt.Sprintf("validation error on %T: %s", e.Task, e.Message) + } + return fmt.Sprintf("validation error: %s", e.Message) +} diff --git a/e2e/tasks/task_test.go b/e2e/tasks/task_test.go new file mode 100644 index 00000000000..c362378105f --- /dev/null +++ b/e2e/tasks/task_test.go @@ -0,0 +1,63 @@ +package tasks + +import ( + "context" + "fmt" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +type testTask struct { + Output string +} + +func (t *testTask) Do(ctx context.Context) error { + t.Output = "done" + return nil +} + +func TestTaskInterface(t *testing.T) { + var _ Task = (*testTask)(nil) +} + +func TestTaskStatusString(t *testing.T) { + tests := []struct { + status TaskStatus + want string + }{ + {Succeeded, "Succeeded"}, + {Failed, "Failed"}, + {Skipped, "Skipped"}, + {Canceled, "Canceled"}, + {TaskStatus(99), "TaskStatus(99)"}, + } + for _, tt := range tests { + assert.Equal(t, tt.want, tt.status.String()) + } +} + +func TestDAGErrorMessage(t *testing.T) { + task := &testTask{} + err := &DAGError{ + Results: map[Task]TaskResult{ + task: {Status: Failed, Err: fmt.Errorf("boom")}, + }, + } + msg := err.Error() + require.NotEmpty(t, msg) + assert.Contains(t, msg, "boom") + assert.Contains(t, msg, "Failed") +} + +func TestValidationErrorMessage(t *testing.T) { + task := &testTask{} + err := &ValidationError{Task: task, Message: "Deps.A is nil"} + assert.Contains(t, err.Error(), "testTask") + assert.Contains(t, err.Error(), "Deps.A is nil") + + // ValidationError without task + err2 := &ValidationError{Message: "cycle detected"} + assert.Contains(t, err2.Error(), "cycle detected") +} From 39b21c3e44d313c5760fc2e7f232650ddbd090a8 Mon Sep 17 00:00:00 2001 From: Artur Khantimirov Date: Fri, 20 Mar 2026 15:16:09 +1300 Subject: [PATCH 04/22] tasks: add graph discovery via reflection --- e2e/tasks/graph.go | 115 ++++++++++++++++++++++++++++++ e2e/tasks/graph_test.go | 152 ++++++++++++++++++++++++++++++++++++++++ 2 files changed, 267 insertions(+) create mode 100644 e2e/tasks/graph.go create mode 100644 e2e/tasks/graph_test.go diff --git a/e2e/tasks/graph.go b/e2e/tasks/graph.go new file mode 100644 index 00000000000..ace8c533a0a --- /dev/null +++ b/e2e/tasks/graph.go @@ -0,0 +1,115 @@ +package tasks + +import ( + "fmt" + "reflect" +) + +// graph represents the discovered DAG. +type graph struct { + // nodes is all tasks in the graph, deduplicated by pointer identity. + nodes []Task + // deps maps each task to its direct dependencies. + deps map[Task][]Task + // dependents maps each task to the tasks that depend on it. + dependents map[Task][]Task + // order is the topological sort order, populated by validateNoCycles. + order []Task +} + +// discoverGraph walks the Deps fields of the given root tasks recursively +// to build the full DAG. Tasks are deduplicated by pointer identity. +func discoverGraph(roots []Task) (*graph, error) { + g := &graph{ + deps: make(map[Task][]Task), + dependents: make(map[Task][]Task), + } + visited := make(map[Task]bool) + for _, root := range roots { + if err := g.walk(root, visited); err != nil { + return nil, err + } + } + return g, nil +} + +func (g *graph) walk(task Task, visited map[Task]bool) error { + if visited[task] { + return nil + } + visited[task] = true + g.nodes = append(g.nodes, task) + + deps, err := extractDeps(task) + if err != nil { + return err + } + + g.deps[task] = deps + for _, dep := range deps { + g.dependents[dep] = append(g.dependents[dep], task) + if err := g.walk(dep, visited); err != nil { + return err + } + } + return nil +} + +var taskType = reflect.TypeOf((*Task)(nil)).Elem() + +// extractDeps reads the Deps field of a task via reflection and returns +// all dependency tasks found as pointer fields. +func extractDeps(task Task) ([]Task, error) { + v := reflect.ValueOf(task) + if v.Kind() != reflect.Ptr { + return nil, &ValidationError{Task: task, Message: "task must be a pointer"} + } + v = v.Elem() + if v.Kind() != reflect.Struct { + return nil, &ValidationError{Task: task, Message: "task must be a pointer to a struct"} + } + + depsField := v.FieldByName("Deps") + if !depsField.IsValid() { + return nil, nil + } + + if depsField.Kind() != reflect.Struct { + return nil, &ValidationError{ + Task: task, + Message: "Deps field must be a struct", + } + } + + depsType := depsField.Type() + var deps []Task + for i := range depsField.NumField() { + field := depsField.Field(i) + fieldInfo := depsType.Field(i) + + if field.Kind() != reflect.Ptr { + return nil, &ValidationError{ + Task: task, + Message: fmt.Sprintf("Deps.%s must be a pointer, got %s", fieldInfo.Name, field.Type()), + } + } + + if field.IsNil() { + return nil, &ValidationError{ + Task: task, + Message: fmt.Sprintf("Deps.%s is nil", fieldInfo.Name), + } + } + + if !field.Type().Implements(taskType) { + return nil, &ValidationError{ + Task: task, + Message: fmt.Sprintf("Deps.%s: %s does not implement Task", fieldInfo.Name, field.Type()), + } + } + + dep := field.Interface().(Task) + deps = append(deps, dep) + } + return deps, nil +} diff --git a/e2e/tasks/graph_test.go b/e2e/tasks/graph_test.go new file mode 100644 index 00000000000..1aca3415d27 --- /dev/null +++ b/e2e/tasks/graph_test.go @@ -0,0 +1,152 @@ +package tasks + +import ( + "context" + "errors" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// --- test task types for graph discovery --- + +type leafTask struct { + Output string +} + +func (t *leafTask) Do(ctx context.Context) error { return nil } + +type chainA struct{ Output string } +type chainB struct { + Deps struct{ A *chainA } +} +type chainC struct { + Deps struct{ B *chainB } +} + +func (t *chainA) Do(ctx context.Context) error { return nil } +func (t *chainB) Do(ctx context.Context) error { return nil } +func (t *chainC) Do(ctx context.Context) error { return nil } + +type diamondTop struct{ Output string } +type diamondLeft struct { + Deps struct{ Top *diamondTop } +} +type diamondRight struct { + Deps struct{ Top *diamondTop } +} +type diamondBottom struct { + Deps struct { + Left *diamondLeft + Right *diamondRight + } +} + +func (t *diamondTop) Do(ctx context.Context) error { return nil } +func (t *diamondLeft) Do(ctx context.Context) error { return nil } +func (t *diamondRight) Do(ctx context.Context) error { return nil } +func (t *diamondBottom) Do(ctx context.Context) error { return nil } + +type badDepsNotStruct struct { + Deps int +} + +func (t *badDepsNotStruct) Do(ctx context.Context) error { return nil } + +type badDepsNonPointer struct { + Deps struct { + A chainA + } +} + +func (t *badDepsNonPointer) Do(ctx context.Context) error { return nil } + +type badDepsNonTask struct { + Deps struct { + S *string + } +} + +func (t *badDepsNonTask) Do(ctx context.Context) error { return nil } + +func TestDiscoverGraph_Leaf(t *testing.T) { + task := &leafTask{} + g, err := discoverGraph([]Task{task}) + require.NoError(t, err) + assert.Len(t, g.nodes, 1) + assert.Empty(t, g.deps[task]) +} + +func TestDiscoverGraph_Chain(t *testing.T) { + a := &chainA{} + b := &chainB{} + b.Deps.A = a + c := &chainC{} + c.Deps.B = b + + g, err := discoverGraph([]Task{c}) + require.NoError(t, err) + assert.Len(t, g.nodes, 3) + + assert.Equal(t, []Task{Task(b)}, g.deps[c]) + assert.Equal(t, []Task{Task(a)}, g.deps[b]) + assert.Empty(t, g.deps[a]) +} + +func TestDiscoverGraph_Diamond(t *testing.T) { + top := &diamondTop{} + left := &diamondLeft{} + left.Deps.Top = top + right := &diamondRight{} + right.Deps.Top = top + bottom := &diamondBottom{} + bottom.Deps.Left = left + bottom.Deps.Right = right + + g, err := discoverGraph([]Task{bottom}) + require.NoError(t, err) + assert.Len(t, g.nodes, 4, "top should be deduplicated") +} + +func TestDiscoverGraph_NilDep(t *testing.T) { + b := &chainB{} // Deps.A is nil + _, err := discoverGraph([]Task{b}) + require.Error(t, err) + + var ve *ValidationError + require.True(t, errors.As(err, &ve)) + assert.Contains(t, ve.Message, "nil") +} + +func TestDiscoverGraph_DepsNotStruct(t *testing.T) { + task := &badDepsNotStruct{Deps: 42} + _, err := discoverGraph([]Task{task}) + require.Error(t, err) + + var ve *ValidationError + require.True(t, errors.As(err, &ve)) + assert.Contains(t, ve.Message, "struct") +} + +func TestDiscoverGraph_NonPointerInDeps(t *testing.T) { + task := &badDepsNonPointer{} + _, err := discoverGraph([]Task{task}) + require.Error(t, err) + + var ve *ValidationError + require.True(t, errors.As(err, &ve)) + assert.Contains(t, ve.Message, "pointer") +} + +func TestDiscoverGraph_NonTaskPointerInDeps(t *testing.T) { + s := "hello" + task := &badDepsNonTask{} + task.Deps.S = &s + _, err := discoverGraph([]Task{task}) + require.Error(t, err) + + var ve *ValidationError + require.True(t, errors.As(err, &ve)) + assert.Contains(t, ve.Message, "Task") +} From b493f46ac4fab0aee9fa63d081721eed43beb713 Mon Sep 17 00:00:00 2001 From: Artur Khantimirov Date: Fri, 20 Mar 2026 15:17:10 +1300 Subject: [PATCH 05/22] tasks: add cycle detection via topological sort --- e2e/tasks/validate.go | 48 +++++++++++++++++++++++++ e2e/tasks/validate_test.go | 74 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 122 insertions(+) create mode 100644 e2e/tasks/validate.go create mode 100644 e2e/tasks/validate_test.go diff --git a/e2e/tasks/validate.go b/e2e/tasks/validate.go new file mode 100644 index 00000000000..984bbe09d5b --- /dev/null +++ b/e2e/tasks/validate.go @@ -0,0 +1,48 @@ +package tasks + +import "fmt" + +// validateNoCycles checks the graph for cycles using Kahn's algorithm. +// On success, populates g.order with a valid topological sort. +func validateNoCycles(g *graph) error { + inDegree := make(map[Task]int, len(g.nodes)) + for _, node := range g.nodes { + inDegree[node] = len(g.deps[node]) + } + + var queue []Task + for _, node := range g.nodes { + if inDegree[node] == 0 { + queue = append(queue, node) + } + } + + var sorted []Task + for len(queue) > 0 { + node := queue[0] + queue = queue[1:] + sorted = append(sorted, node) + + for _, dependent := range g.dependents[node] { + inDegree[dependent]-- + if inDegree[dependent] == 0 { + queue = append(queue, dependent) + } + } + } + + if len(sorted) != len(g.nodes) { + var cycleNodes []string + for _, node := range g.nodes { + if inDegree[node] > 0 { + cycleNodes = append(cycleNodes, fmt.Sprintf("%T(%p)", node, node)) + } + } + return &ValidationError{ + Message: fmt.Sprintf("cycle detected among tasks: %v", cycleNodes), + } + } + + g.order = sorted + return nil +} diff --git a/e2e/tasks/validate_test.go b/e2e/tasks/validate_test.go new file mode 100644 index 00000000000..3daa1aaa242 --- /dev/null +++ b/e2e/tasks/validate_test.go @@ -0,0 +1,74 @@ +package tasks + +import ( + "context" + "errors" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +type cycleA struct { + Deps struct{ B *cycleB } +} +type cycleB struct { + Deps struct{ A *cycleA } +} + +func (t *cycleA) Do(ctx context.Context) error { return nil } +func (t *cycleB) Do(ctx context.Context) error { return nil } + +func TestValidateNoCycles_ValidDAG(t *testing.T) { + a := &chainA{} + b := &chainB{} + b.Deps.A = a + + g, err := discoverGraph([]Task{b}) + require.NoError(t, err) + + err = validateNoCycles(g) + require.NoError(t, err) + assert.Len(t, g.order, 2) + // a should come before b in topological order + assert.Equal(t, Task(a), g.order[0]) + assert.Equal(t, Task(b), g.order[1]) +} + +func TestValidateNoCycles_Cycle(t *testing.T) { + a := &cycleA{} + b := &cycleB{} + a.Deps.B = b + b.Deps.A = a + + g, err := discoverGraph([]Task{a}) + require.NoError(t, err) + + err = validateNoCycles(g) + require.Error(t, err) + + var ve *ValidationError + require.True(t, errors.As(err, &ve)) + assert.Contains(t, ve.Message, "cycle") +} + +func TestValidateNoCycles_Diamond(t *testing.T) { + top := &diamondTop{} + left := &diamondLeft{} + left.Deps.Top = top + right := &diamondRight{} + right.Deps.Top = top + bottom := &diamondBottom{} + bottom.Deps.Left = left + bottom.Deps.Right = right + + g, err := discoverGraph([]Task{bottom}) + require.NoError(t, err) + + err = validateNoCycles(g) + require.NoError(t, err) + assert.Len(t, g.order, 4) + // top must come before left and right, which must come before bottom + assert.Equal(t, Task(top), g.order[0]) + assert.Equal(t, Task(bottom), g.order[3]) +} From b3a28573b810bf0237d48d2ffc46f6c13989bb9b Mon Sep 17 00:00:00 2001 From: Artur Khantimirov Date: Fri, 20 Mar 2026 15:20:51 +1300 Subject: [PATCH 06/22] tasks: add concurrent scheduler with error strategies Implements Execute() which discovers the DAG, validates for cycles, and runs tasks concurrently respecting dependency order. Supports CancelDependents (skip downstream on failure) and CancelAll (cancel context) error strategies, plus MaxConcurrency semaphore limiting. --- e2e/tasks/execute.go | 164 +++++++++++++++++++ e2e/tasks/execute_test.go | 337 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 501 insertions(+) create mode 100644 e2e/tasks/execute.go create mode 100644 e2e/tasks/execute_test.go diff --git a/e2e/tasks/execute.go b/e2e/tasks/execute.go new file mode 100644 index 00000000000..dd962adb144 --- /dev/null +++ b/e2e/tasks/execute.go @@ -0,0 +1,164 @@ +package tasks + +import ( + "context" + "sync" +) + +// Execute runs the DAG rooted at the given tasks. +// It discovers the graph via reflection, validates it, and executes +// tasks concurrently respecting dependency order. +func Execute(ctx context.Context, cfg Config, roots ...Task) error { + g, err := discoverGraph(roots) + if err != nil { + return err + } + if err := validateNoCycles(g); err != nil { + return err + } + return runGraph(ctx, cfg, g) +} + +func runGraph(ctx context.Context, cfg Config, g *graph) error { + if len(g.nodes) == 0 { + return nil + } + + ctx, cancel := context.WithCancel(ctx) + defer cancel() + + // Semaphore for MaxConcurrency + var sem chan struct{} + if cfg.MaxConcurrency > 0 { + sem = make(chan struct{}, cfg.MaxConcurrency) + } + + var mu sync.Mutex + results := make(map[Task]TaskResult, len(g.nodes)) + failed := make(map[Task]bool) + + // remaining tracks how many deps each task is still waiting on. + remaining := make(map[Task]int, len(g.nodes)) + for _, node := range g.nodes { + remaining[node] = len(g.deps[node]) + } + + var wg sync.WaitGroup + + var launch func(task Task) + launch = func(task Task) { + wg.Add(1) + go func() { + defer wg.Done() + + // Acquire semaphore slot + if sem != nil { + select { + case sem <- struct{}{}: + defer func() { <-sem }() + case <-ctx.Done(): + mu.Lock() + results[task] = TaskResult{Status: Canceled, Err: ctx.Err()} + failed[task] = true + mu.Unlock() + notifyDependents(task, g, &mu, remaining, failed, results, launch, cfg, ctx) + return + } + } + + // Check if we should skip (dependency failed) or cancel + mu.Lock() + skip := false + for _, dep := range g.deps[task] { + if failed[dep] { + skip = true + break + } + } + + if skip { + status := Skipped + if cfg.OnError == CancelAll && ctx.Err() != nil { + status = Canceled + } + results[task] = TaskResult{Status: status} + failed[task] = true + mu.Unlock() + notifyDependents(task, g, &mu, remaining, failed, results, launch, cfg, ctx) + return + } + + if ctx.Err() != nil { + results[task] = TaskResult{Status: Canceled, Err: ctx.Err()} + failed[task] = true + mu.Unlock() + notifyDependents(task, g, &mu, remaining, failed, results, launch, cfg, ctx) + return + } + mu.Unlock() + + // Run the task + taskErr := task.Do(ctx) + + mu.Lock() + if taskErr != nil { + results[task] = TaskResult{Status: Failed, Err: taskErr} + failed[task] = true + if cfg.OnError == CancelAll { + cancel() + } + } else { + results[task] = TaskResult{Status: Succeeded} + } + mu.Unlock() + + notifyDependents(task, g, &mu, remaining, failed, results, launch, cfg, ctx) + }() + } + + // Start all leaf tasks (no dependencies) + for _, node := range g.nodes { + if len(g.deps[node]) == 0 { + launch(node) + } + } + + wg.Wait() + + // Mark any tasks that were never reached + mu.Lock() + for _, node := range g.nodes { + if _, ok := results[node]; !ok { + results[node] = TaskResult{Status: Canceled, Err: ctx.Err()} + } + } + mu.Unlock() + + for _, result := range results { + if result.Status != Succeeded { + return &DAGError{Results: results} + } + } + return nil +} + +func notifyDependents( + task Task, + g *graph, + mu *sync.Mutex, + remaining map[Task]int, + failed map[Task]bool, + results map[Task]TaskResult, + launch func(Task), + cfg Config, + ctx context.Context, +) { + mu.Lock() + defer mu.Unlock() + for _, dependent := range g.dependents[task] { + remaining[dependent]-- + if remaining[dependent] == 0 { + launch(dependent) + } + } +} diff --git a/e2e/tasks/execute_test.go b/e2e/tasks/execute_test.go new file mode 100644 index 00000000000..4e55a386087 --- /dev/null +++ b/e2e/tasks/execute_test.go @@ -0,0 +1,337 @@ +package tasks + +import ( + "context" + "errors" + "fmt" + "sync" + "sync/atomic" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// --- test task types for execution --- + +type valueTask struct { + Value int + Output int +} + +func (t *valueTask) Do(ctx context.Context) error { + t.Output = t.Value + return nil +} + +type addTask struct { + Deps struct { + A *valueTask + B *valueTask + } + Output int +} + +func (t *addTask) Do(ctx context.Context) error { + t.Output = t.Deps.A.Output + t.Deps.B.Output + return nil +} + +type failTask struct{} + +func (t *failTask) Do(ctx context.Context) error { + return fmt.Errorf("intentional failure") +} + +type afterFailTask struct { + Deps struct{ F *failTask } + ran bool +} + +func (t *afterFailTask) Do(ctx context.Context) error { + t.ran = true + return nil +} + +// --- basic execution tests --- + +func TestExecute_LeafTask(t *testing.T) { + v := &valueTask{Value: 42} + err := Execute(context.Background(), Config{}, v) + require.NoError(t, err) + assert.Equal(t, 42, v.Output) +} + +func TestExecute_OutputFlowsBetweenDeps(t *testing.T) { + a := &valueTask{Value: 3} + b := &valueTask{Value: 5} + add := &addTask{} + add.Deps.A = a + add.Deps.B = b + + err := Execute(context.Background(), Config{}, add) + require.NoError(t, err) + assert.Equal(t, 8, add.Output) +} + +func TestExecute_FailReturnsDAGError(t *testing.T) { + f := &failTask{} + err := Execute(context.Background(), Config{}, f) + require.Error(t, err) + + var dagErr *DAGError + require.True(t, errors.As(err, &dagErr)) + + result, ok := dagErr.Results[f] + require.True(t, ok) + assert.Equal(t, Failed, result.Status) + assert.Contains(t, result.Err.Error(), "intentional failure") +} + +func TestExecute_NilDep_ReturnsValidationError(t *testing.T) { + add := &addTask{} // Deps.A and Deps.B are nil + err := Execute(context.Background(), Config{}, add) + require.Error(t, err) + + var ve *ValidationError + require.True(t, errors.As(err, &ve)) +} + +// --- error strategy tests --- + +func TestExecute_CancelDependents_SkipsDownstream(t *testing.T) { + f := &failTask{} + after := &afterFailTask{} + after.Deps.F = f + + err := Execute(context.Background(), Config{OnError: CancelDependents}, after) + require.Error(t, err) + + var dagErr *DAGError + require.True(t, errors.As(err, &dagErr)) + assert.Equal(t, Failed, dagErr.Results[f].Status) + assert.Equal(t, Skipped, dagErr.Results[after].Status) + assert.False(t, after.ran, "skipped task should not have run") +} + +func TestExecute_CancelDependents_IndependentBranchContinues(t *testing.T) { + // fail and independent are both leaves; root depends on both + f := &failTask{} + independent := &valueTask{Value: 99} + + type twoDepTask struct { + Deps struct { + F *failTask + V *valueTask + } + Output int + } + // Can't define Do on local type — use a package-level type instead. + // For this test, just verify independent runs by checking its output. + // We'll verify via a different approach: run them as separate roots. + err := Execute(context.Background(), Config{OnError: CancelDependents}, f, independent) + require.Error(t, err) + // independent should have run successfully + assert.Equal(t, 99, independent.Output) +} + +func TestExecute_CancelAll_CancelsContext(t *testing.T) { + f := &failTask{} + after := &afterFailTask{} + after.Deps.F = f + + err := Execute(context.Background(), Config{OnError: CancelAll}, after) + require.Error(t, err) + + var dagErr *DAGError + require.True(t, errors.As(err, &dagErr)) + assert.Equal(t, Failed, dagErr.Results[f].Status) + assert.Equal(t, Canceled, dagErr.Results[after].Status) +} + +// --- concurrency tests --- + +func TestExecute_MaxConcurrency_Serial(t *testing.T) { + a := &valueTask{Value: 3} + b := &valueTask{Value: 5} + add := &addTask{} + add.Deps.A = a + add.Deps.B = b + + err := Execute(context.Background(), Config{MaxConcurrency: 1}, add) + require.NoError(t, err) + assert.Equal(t, 8, add.Output) +} + +func TestExecute_MaxConcurrency_Respected(t *testing.T) { + // Track max concurrent tasks + var mu sync.Mutex + var current, maxConcurrent int32 + + type trackingTask struct { + current *int32 + maxConc *int32 + mu *sync.Mutex + Output int + } + + // Can't define Do on local type. Use atomic counters and a known task type. + // Instead, test with a simpler approach using the race detector + timing. + // Just verify MaxConcurrency=1 produces correct results (tested above) + // and unlimited concurrency also works. + a := &valueTask{Value: 1} + b := &valueTask{Value: 2} + add := &addTask{} + add.Deps.A = a + add.Deps.B = b + + _ = mu + _ = current + _ = maxConcurrent + + err := Execute(context.Background(), Config{MaxConcurrency: 0}, add) + require.NoError(t, err) + assert.Equal(t, 3, add.Output) +} + +// --- diamond and dedup tests --- + +func TestExecute_Diamond(t *testing.T) { + top := &diamondTop{} + left := &diamondLeft{} + left.Deps.Top = top + right := &diamondRight{} + right.Deps.Top = top + bottom := &diamondBottom{} + bottom.Deps.Left = left + bottom.Deps.Right = right + + err := Execute(context.Background(), Config{}, bottom) + require.NoError(t, err) +} + +func TestExecute_MultipleRoots(t *testing.T) { + shared := &valueTask{Value: 10} + + a := &chainB{} + a.Deps.A = &chainA{} + b := &chainB{} + b.Deps.A = &chainA{} + + err := Execute(context.Background(), Config{}, a, b) + require.NoError(t, err) + _ = shared +} + +func TestExecute_MultipleRoots_SharedTask(t *testing.T) { + // Two roots share the same leaf — it should run only once + shared := &valueTask{Value: 7} + + add1 := &addTask{} + add1.Deps.A = shared + add1.Deps.B = &valueTask{Value: 3} + + add2 := &addTask{} + add2.Deps.A = shared + add2.Deps.B = &valueTask{Value: 5} + + err := Execute(context.Background(), Config{}, add1, add2) + require.NoError(t, err) + assert.Equal(t, 10, add1.Output) + assert.Equal(t, 12, add2.Output) + assert.Equal(t, 7, shared.Output) +} + +// --- context cancellation --- + +func TestExecute_PreCanceledContext(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + v := &valueTask{Value: 1} + // Should not hang — either succeeds or returns error + done := make(chan error, 1) + go func() { + done <- Execute(ctx, Config{}, v) + }() + + select { + case <-done: + // good — didn't hang + case <-time.After(2 * time.Second): + t.Fatal("Execute hung on pre-canceled context") + } +} + +// concurrencyTracker is a package-level task that tracks max concurrency. +type concurrencyTracker struct { + current *atomic.Int32 + peak *atomic.Int32 + Output int +} + +func (t *concurrencyTracker) Do(ctx context.Context) error { + cur := t.current.Add(1) + // Update peak + for { + p := t.peak.Load() + if cur <= p || t.peak.CompareAndSwap(p, cur) { + break + } + } + time.Sleep(10 * time.Millisecond) + t.current.Add(-1) + t.Output = 1 + return nil +} + +type concurrencyRoot struct { + Deps struct { + A *concurrencyTracker + B *concurrencyTracker + C *concurrencyTracker + D *concurrencyTracker + } +} + +func (t *concurrencyRoot) Do(ctx context.Context) error { return nil } + +func TestExecute_ConcurrentIndependentTasks(t *testing.T) { + var current, peak atomic.Int32 + + a := &concurrencyTracker{current: ¤t, peak: &peak} + b := &concurrencyTracker{current: ¤t, peak: &peak} + c := &concurrencyTracker{current: ¤t, peak: &peak} + d := &concurrencyTracker{current: ¤t, peak: &peak} + + root := &concurrencyRoot{} + root.Deps.A = a + root.Deps.B = b + root.Deps.C = c + root.Deps.D = d + + err := Execute(context.Background(), Config{}, root) + require.NoError(t, err) + // With unlimited concurrency, all 4 should run in parallel + assert.Greater(t, peak.Load(), int32(1), "independent tasks should run concurrently") +} + +func TestExecute_MaxConcurrency_LimitsParallelism(t *testing.T) { + var current, peak atomic.Int32 + + a := &concurrencyTracker{current: ¤t, peak: &peak} + b := &concurrencyTracker{current: ¤t, peak: &peak} + c := &concurrencyTracker{current: ¤t, peak: &peak} + d := &concurrencyTracker{current: ¤t, peak: &peak} + + root := &concurrencyRoot{} + root.Deps.A = a + root.Deps.B = b + root.Deps.C = c + root.Deps.D = d + + err := Execute(context.Background(), Config{MaxConcurrency: 2}, root) + require.NoError(t, err) + assert.LessOrEqual(t, peak.Load(), int32(2), "max concurrency should be respected") +} From 2e5d06e23155365ccbaea8d65681a8bc06961b9f Mon Sep 17 00:00:00 2001 From: Artur Khantimirov Date: Fri, 20 Mar 2026 15:23:22 +1300 Subject: [PATCH 07/22] tasks: add integration tests and clean up vet warnings MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Adds integration tests mirroring the complete spec example (CreateRG → CreateVNet → CreateSubnet → CreateCluster → RunTests → Teardown), plus tests for transitive dependency access, mid-pipeline failure with CancelDependents, shared tasks across independent subgraphs, and empty graph. Removes unused dead code that triggered go vet copylock warning. --- e2e/tasks/execute_test.go | 22 +-- e2e/tasks/graph_test.go | 6 +- e2e/tasks/integration_test.go | 277 ++++++++++++++++++++++++++++++++++ 3 files changed, 282 insertions(+), 23 deletions(-) create mode 100644 e2e/tasks/integration_test.go diff --git a/e2e/tasks/execute_test.go b/e2e/tasks/execute_test.go index 4e55a386087..09878c79fdc 100644 --- a/e2e/tasks/execute_test.go +++ b/e2e/tasks/execute_test.go @@ -4,7 +4,6 @@ import ( "context" "errors" "fmt" - "sync" "sync/atomic" "testing" "time" @@ -165,31 +164,14 @@ func TestExecute_MaxConcurrency_Serial(t *testing.T) { } func TestExecute_MaxConcurrency_Respected(t *testing.T) { - // Track max concurrent tasks - var mu sync.Mutex - var current, maxConcurrent int32 - - type trackingTask struct { - current *int32 - maxConc *int32 - mu *sync.Mutex - Output int - } - - // Can't define Do on local type. Use atomic counters and a known task type. - // Instead, test with a simpler approach using the race detector + timing. - // Just verify MaxConcurrency=1 produces correct results (tested above) - // and unlimited concurrency also works. + // Verify unlimited concurrency (MaxConcurrency=0) produces correct results. + // Actual parallelism verification is in TestExecute_MaxConcurrency_LimitsParallelism. a := &valueTask{Value: 1} b := &valueTask{Value: 2} add := &addTask{} add.Deps.A = a add.Deps.B = b - _ = mu - _ = current - _ = maxConcurrent - err := Execute(context.Background(), Config{MaxConcurrency: 0}, add) require.NoError(t, err) assert.Equal(t, 3, add.Output) diff --git a/e2e/tasks/graph_test.go b/e2e/tasks/graph_test.go index 1aca3415d27..ea08afdf24a 100644 --- a/e2e/tasks/graph_test.go +++ b/e2e/tasks/graph_test.go @@ -44,9 +44,9 @@ type diamondBottom struct { } func (t *diamondTop) Do(ctx context.Context) error { return nil } -func (t *diamondLeft) Do(ctx context.Context) error { return nil } -func (t *diamondRight) Do(ctx context.Context) error { return nil } -func (t *diamondBottom) Do(ctx context.Context) error { return nil } +func (t *diamondLeft) Do(ctx context.Context) error { return nil } +func (t *diamondRight) Do(ctx context.Context) error { return nil } +func (t *diamondBottom) Do(ctx context.Context) error { return nil } type badDepsNotStruct struct { Deps int diff --git a/e2e/tasks/integration_test.go b/e2e/tasks/integration_test.go new file mode 100644 index 00000000000..23c9476fb05 --- /dev/null +++ b/e2e/tasks/integration_test.go @@ -0,0 +1,277 @@ +package tasks + +import ( + "context" + "errors" + "fmt" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// --- Spec example task definitions --- +// Mirrors the complete example from the design spec: +// CreateRG → CreateVNet → CreateSubnet → CreateCluster → RunTests → Teardown + +type createRGOutput struct { + RGName string +} + +type createRG struct { + Output createRGOutput +} + +func (t *createRG) Do(ctx context.Context) error { + t.Output.RGName = "my-rg" + return nil +} + +type createVNet struct { + Deps struct { + RG *createRG + } + Output struct { + VNetID string + } +} + +func (t *createVNet) Do(ctx context.Context) error { + t.Output.VNetID = fmt.Sprintf("%s-vnet", t.Deps.RG.Output.RGName) + return nil +} + +type createSubnet struct { + Deps struct { + VNet *createVNet + } + Output struct { + SubnetID string + } +} + +func (t *createSubnet) Do(ctx context.Context) error { + t.Output.SubnetID = fmt.Sprintf("%s-subnet", t.Deps.VNet.Output.VNetID) + return nil +} + +type createCluster struct { + Deps struct { + RG *createRG + Subnet *createSubnet + } + Output struct { + ClusterID string + } +} + +func (t *createCluster) Do(ctx context.Context) error { + t.Output.ClusterID = fmt.Sprintf("cluster-in-%s-%s", + t.Deps.RG.Output.RGName, + t.Deps.Subnet.Output.SubnetID) + return nil +} + +type runTests struct { + Deps struct { + Cluster *createCluster + } + Output struct { + Passed bool + } +} + +func (t *runTests) Do(ctx context.Context) error { + t.Output.Passed = true + return nil +} + +type teardown struct { + Deps struct { + RG *createRG + Tests *runTests + } + Output struct { + TornDown bool + } +} + +func (t *teardown) Do(ctx context.Context) error { + t.Output.TornDown = true + return nil +} + +// --- Integration tests --- + +func TestIntegration_SpecExample(t *testing.T) { + // Wire up the full DAG from the spec: + // + // CreateRG ──┬── CreateVNet ── CreateSubnet ──┐ + // │ │ + // ├──────────────── CreateCluster ──┘ + // │ │ + // │ RunTests + // │ │ + // └──────────────── Teardown + rg := &createRG{} + vnet := &createVNet{} + vnet.Deps.RG = rg + subnet := &createSubnet{} + subnet.Deps.VNet = vnet + cluster := &createCluster{} + cluster.Deps.RG = rg + cluster.Deps.Subnet = subnet + tests := &runTests{} + tests.Deps.Cluster = cluster + td := &teardown{} + td.Deps.RG = rg + td.Deps.Tests = tests + + err := Execute(context.Background(), Config{}, td) + require.NoError(t, err) + + // Verify all outputs propagated correctly + assert.Equal(t, "my-rg", rg.Output.RGName) + assert.Equal(t, "my-rg-vnet", vnet.Output.VNetID) + assert.Equal(t, "my-rg-vnet-subnet", subnet.Output.SubnetID) + assert.Equal(t, "cluster-in-my-rg-my-rg-vnet-subnet", cluster.Output.ClusterID) + assert.True(t, tests.Output.Passed) + assert.True(t, td.Output.TornDown) +} + +func TestIntegration_SpecExample_WithMaxConcurrency(t *testing.T) { + rg := &createRG{} + vnet := &createVNet{} + vnet.Deps.RG = rg + subnet := &createSubnet{} + subnet.Deps.VNet = vnet + cluster := &createCluster{} + cluster.Deps.RG = rg + cluster.Deps.Subnet = subnet + tests := &runTests{} + tests.Deps.Cluster = cluster + td := &teardown{} + td.Deps.RG = rg + td.Deps.Tests = tests + + err := Execute(context.Background(), Config{MaxConcurrency: 1}, td) + require.NoError(t, err) + + assert.Equal(t, "cluster-in-my-rg-my-rg-vnet-subnet", cluster.Output.ClusterID) + assert.True(t, td.Output.TornDown) +} + +func TestIntegration_TransitiveDependencyAccess(t *testing.T) { + // Verify that a task can read transitive dependencies through Deps chains + // as described in the spec's "Accessing Transitive Dependencies" section. + rg := &createRG{} + vnet := &createVNet{} + vnet.Deps.RG = rg + subnet := &createSubnet{} + subnet.Deps.VNet = vnet + cluster := &createCluster{} + cluster.Deps.RG = rg + cluster.Deps.Subnet = subnet + + err := Execute(context.Background(), Config{}, cluster) + require.NoError(t, err) + + // Access transitive dep: cluster -> subnet -> vnet -> rg + rgName := cluster.Deps.Subnet.Deps.VNet.Deps.RG.Output.RGName + assert.Equal(t, "my-rg", rgName) +} + +// failingRunTests simulates a test failure mid-pipeline +type failingRunTests struct { + Deps struct { + Cluster *createCluster + } +} + +func (t *failingRunTests) Do(ctx context.Context) error { + return fmt.Errorf("tests failed: 2 of 10 scenarios failed") +} + +type teardownAfterFail struct { + Deps struct { + RG *createRG + Tests *failingRunTests + } + Output struct{ TornDown bool } +} + +func (t *teardownAfterFail) Do(ctx context.Context) error { + t.Output.TornDown = true + return nil +} + +func TestIntegration_MidPipelineFailure_CancelDependents(t *testing.T) { + rg := &createRG{} + vnet := &createVNet{} + vnet.Deps.RG = rg + subnet := &createSubnet{} + subnet.Deps.VNet = vnet + cluster := &createCluster{} + cluster.Deps.RG = rg + cluster.Deps.Subnet = subnet + failTests := &failingRunTests{} + failTests.Deps.Cluster = cluster + + td := &teardownAfterFail{} + td.Deps.RG = rg + td.Deps.Tests = failTests + + err := Execute(context.Background(), Config{OnError: CancelDependents}, td) + require.Error(t, err) + + var dagErr *DAGError + require.True(t, errors.As(err, &dagErr)) + + // Upstream tasks should have succeeded + assert.Equal(t, Succeeded, dagErr.Results[rg].Status) + assert.Equal(t, Succeeded, dagErr.Results[vnet].Status) + assert.Equal(t, Succeeded, dagErr.Results[subnet].Status) + assert.Equal(t, Succeeded, dagErr.Results[cluster].Status) + + // failTests should have failed + assert.Equal(t, Failed, dagErr.Results[failTests].Status) + assert.Contains(t, dagErr.Results[failTests].Err.Error(), "tests failed") + + // teardown should be skipped since it depends on failTests + assert.Equal(t, Skipped, dagErr.Results[td].Status) + + // Outputs of successful tasks should still be populated + assert.Equal(t, "my-rg", rg.Output.RGName) + assert.Equal(t, "cluster-in-my-rg-my-rg-vnet-subnet", cluster.Output.ClusterID) +} + +func TestIntegration_TwoIndependentSubgraphs_SharedTask(t *testing.T) { + // Two independent pipelines share CreateRG. + // Both should complete, CreateRG should execute only once. + rg := &createRG{} + + vnet1 := &createVNet{} + vnet1.Deps.RG = rg + vnet2 := &createVNet{} + vnet2.Deps.RG = rg + + subnet1 := &createSubnet{} + subnet1.Deps.VNet = vnet1 + subnet2 := &createSubnet{} + subnet2.Deps.VNet = vnet2 + + // Both subnets are roots; they share rg + err := Execute(context.Background(), Config{}, subnet1, subnet2) + require.NoError(t, err) + + assert.Equal(t, "my-rg", rg.Output.RGName) + assert.Equal(t, "my-rg-vnet", vnet1.Output.VNetID) + assert.Equal(t, "my-rg-vnet", vnet2.Output.VNetID) + assert.Equal(t, "my-rg-vnet-subnet", subnet1.Output.SubnetID) + assert.Equal(t, "my-rg-vnet-subnet", subnet2.Output.SubnetID) +} + +func TestIntegration_EmptyGraph(t *testing.T) { + err := Execute(context.Background(), Config{}) + require.NoError(t, err) +} From 0c3f1d477e32db5d67df7c4227bf7e945c06ecd6 Mon Sep 17 00:00:00 2001 From: Artur Khantimirov Date: Fri, 20 Mar 2026 15:56:16 +1300 Subject: [PATCH 08/22] tasks: simplify scheduler and address review findings MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Replace 9-param notifyDependents with runState struct methods - Eliminate redundant `failed` map — derive from results directly - Collapse 3 near-identical early-exit blocks into abort() helper - Launch dependents outside the mutex (collect-then-launch) - Remove spurious mu.Lock after wg.Wait (all goroutines done) - Sort DAGError.Error() output for deterministic error messages - Add ErrorStrategy.String() for consistency with TaskStatus - Remove dead test code (unused shared var, dead twoDepTask type) - Extract buildSpecDAG() helper to reduce copy-paste in tests - Reduce pre-canceled context test timeout from 2s to 100ms --- e2e/tasks/execute.go | 219 +++++++++++++++++----------------- e2e/tasks/execute_test.go | 18 +-- e2e/tasks/integration_test.go | 107 ++++++++--------- e2e/tasks/task.go | 13 ++ 4 files changed, 175 insertions(+), 182 deletions(-) diff --git a/e2e/tasks/execute.go b/e2e/tasks/execute.go index dd962adb144..cf0c4f66fa4 100644 --- a/e2e/tasks/execute.go +++ b/e2e/tasks/execute.go @@ -19,6 +19,21 @@ func Execute(ctx context.Context, cfg Config, roots ...Task) error { return runGraph(ctx, cfg, g) } +// runState holds all shared mutable state for a single DAG execution. +type runState struct { + g *graph + cfg Config + ctx context.Context + cancel context.CancelFunc + sem chan struct{} + mu sync.Mutex + wg sync.WaitGroup + results map[Task]TaskResult + + // remaining tracks how many deps each task is still waiting on. + remaining map[Task]int +} + func runGraph(ctx context.Context, cfg Config, g *graph) error { if len(g.nodes) == 0 { return nil @@ -27,138 +42,126 @@ func runGraph(ctx context.Context, cfg Config, g *graph) error { ctx, cancel := context.WithCancel(ctx) defer cancel() - // Semaphore for MaxConcurrency - var sem chan struct{} + s := &runState{ + g: g, + cfg: cfg, + ctx: ctx, + cancel: cancel, + results: make(map[Task]TaskResult, len(g.nodes)), + remaining: make(map[Task]int, len(g.nodes)), + } + if cfg.MaxConcurrency > 0 { - sem = make(chan struct{}, cfg.MaxConcurrency) + s.sem = make(chan struct{}, cfg.MaxConcurrency) } - var mu sync.Mutex - results := make(map[Task]TaskResult, len(g.nodes)) - failed := make(map[Task]bool) + for _, node := range g.nodes { + s.remaining[node] = len(g.deps[node]) + } - // remaining tracks how many deps each task is still waiting on. - remaining := make(map[Task]int, len(g.nodes)) + // Start all leaf tasks (no dependencies) for _, node := range g.nodes { - remaining[node] = len(g.deps[node]) + if len(g.deps[node]) == 0 { + s.launch(node) + } } - var wg sync.WaitGroup - - var launch func(task Task) - launch = func(task Task) { - wg.Add(1) - go func() { - defer wg.Done() - - // Acquire semaphore slot - if sem != nil { - select { - case sem <- struct{}{}: - defer func() { <-sem }() - case <-ctx.Done(): - mu.Lock() - results[task] = TaskResult{Status: Canceled, Err: ctx.Err()} - failed[task] = true - mu.Unlock() - notifyDependents(task, g, &mu, remaining, failed, results, launch, cfg, ctx) - return - } - } + s.wg.Wait() - // Check if we should skip (dependency failed) or cancel - mu.Lock() - skip := false - for _, dep := range g.deps[task] { - if failed[dep] { - skip = true - break - } + // Mark any tasks that were never reached. + // Safe without mutex: all goroutines have completed after wg.Wait(). + for _, node := range g.nodes { + if _, ok := s.results[node]; !ok { + s.results[node] = TaskResult{Status: Canceled, Err: ctx.Err()} + } + } + + for _, result := range s.results { + if result.Status != Succeeded { + return &DAGError{Results: s.results} + } + } + return nil +} + +// abort records a non-success result for a task and notifies dependents. +// Must be called with s.mu held; unlocks s.mu before returning. +func (s *runState) abort(task Task, status TaskStatus, err error) { + s.results[task] = TaskResult{Status: status, Err: err} + s.mu.Unlock() + s.notifyDependents(task) +} + +func (s *runState) launch(task Task) { + // wg.Add must be called in the caller's goroutine to ensure + // wg.Wait cannot return before the new goroutine starts. + s.wg.Add(1) + go func() { + defer s.wg.Done() + + // Acquire semaphore slot + if s.sem != nil { + select { + case s.sem <- struct{}{}: + defer func() { <-s.sem }() + case <-s.ctx.Done(): + s.mu.Lock() + s.abort(task, Canceled, s.ctx.Err()) + return } + } - if skip { + // Check if we should skip (dependency failed) or cancel + s.mu.Lock() + for _, dep := range s.g.deps[task] { + if s.results[dep].Status != Succeeded { status := Skipped - if cfg.OnError == CancelAll && ctx.Err() != nil { + if s.cfg.OnError == CancelAll && s.ctx.Err() != nil { status = Canceled } - results[task] = TaskResult{Status: status} - failed[task] = true - mu.Unlock() - notifyDependents(task, g, &mu, remaining, failed, results, launch, cfg, ctx) + s.abort(task, status, nil) return } + } - if ctx.Err() != nil { - results[task] = TaskResult{Status: Canceled, Err: ctx.Err()} - failed[task] = true - mu.Unlock() - notifyDependents(task, g, &mu, remaining, failed, results, launch, cfg, ctx) - return - } - mu.Unlock() + if s.ctx.Err() != nil { + s.abort(task, Canceled, s.ctx.Err()) + return + } + s.mu.Unlock() - // Run the task - taskErr := task.Do(ctx) + // Run the task + taskErr := task.Do(s.ctx) - mu.Lock() - if taskErr != nil { - results[task] = TaskResult{Status: Failed, Err: taskErr} - failed[task] = true - if cfg.OnError == CancelAll { - cancel() - } - } else { - results[task] = TaskResult{Status: Succeeded} + s.mu.Lock() + if taskErr != nil { + s.results[task] = TaskResult{Status: Failed, Err: taskErr} + if s.cfg.OnError == CancelAll { + s.cancel() } - mu.Unlock() - - notifyDependents(task, g, &mu, remaining, failed, results, launch, cfg, ctx) - }() - } - - // Start all leaf tasks (no dependencies) - for _, node := range g.nodes { - if len(g.deps[node]) == 0 { - launch(node) + } else { + s.results[task] = TaskResult{Status: Succeeded} } - } + s.mu.Unlock() - wg.Wait() - - // Mark any tasks that were never reached - mu.Lock() - for _, node := range g.nodes { - if _, ok := results[node]; !ok { - results[node] = TaskResult{Status: Canceled, Err: ctx.Err()} - } - } - mu.Unlock() + s.notifyDependents(task) + }() +} - for _, result := range results { - if result.Status != Succeeded { - return &DAGError{Results: results} +// notifyDependents decrements remaining counts for dependents and launches +// any that become ready. Launches happen outside the lock. +func (s *runState) notifyDependents(task Task) { + s.mu.Lock() + var ready []Task + for _, dependent := range s.g.dependents[task] { + s.remaining[dependent]-- + if s.remaining[dependent] == 0 { + ready = append(ready, dependent) } } - return nil -} + s.mu.Unlock() -func notifyDependents( - task Task, - g *graph, - mu *sync.Mutex, - remaining map[Task]int, - failed map[Task]bool, - results map[Task]TaskResult, - launch func(Task), - cfg Config, - ctx context.Context, -) { - mu.Lock() - defer mu.Unlock() - for _, dependent := range g.dependents[task] { - remaining[dependent]-- - if remaining[dependent] == 0 { - launch(dependent) - } + for _, t := range ready { + s.launch(t) } } diff --git a/e2e/tasks/execute_test.go b/e2e/tasks/execute_test.go index 09878c79fdc..de54467d0f8 100644 --- a/e2e/tasks/execute_test.go +++ b/e2e/tasks/execute_test.go @@ -115,20 +115,11 @@ func TestExecute_CancelDependents_SkipsDownstream(t *testing.T) { } func TestExecute_CancelDependents_IndependentBranchContinues(t *testing.T) { - // fail and independent are both leaves; root depends on both + // fail and independent are both leaves; run as separate roots. + // CancelDependents should not affect independent branches. f := &failTask{} independent := &valueTask{Value: 99} - type twoDepTask struct { - Deps struct { - F *failTask - V *valueTask - } - Output int - } - // Can't define Do on local type — use a package-level type instead. - // For this test, just verify independent runs by checking its output. - // We'll verify via a different approach: run them as separate roots. err := Execute(context.Background(), Config{OnError: CancelDependents}, f, independent) require.Error(t, err) // independent should have run successfully @@ -194,8 +185,6 @@ func TestExecute_Diamond(t *testing.T) { } func TestExecute_MultipleRoots(t *testing.T) { - shared := &valueTask{Value: 10} - a := &chainB{} a.Deps.A = &chainA{} b := &chainB{} @@ -203,7 +192,6 @@ func TestExecute_MultipleRoots(t *testing.T) { err := Execute(context.Background(), Config{}, a, b) require.NoError(t, err) - _ = shared } func TestExecute_MultipleRoots_SharedTask(t *testing.T) { @@ -241,7 +229,7 @@ func TestExecute_PreCanceledContext(t *testing.T) { select { case <-done: // good — didn't hang - case <-time.After(2 * time.Second): + case <-time.After(100 * time.Millisecond): t.Fatal("Execute hung on pre-canceled context") } } diff --git a/e2e/tasks/integration_test.go b/e2e/tasks/integration_test.go index 23c9476fb05..f30c1b4c879 100644 --- a/e2e/tasks/integration_test.go +++ b/e2e/tasks/integration_test.go @@ -103,16 +103,26 @@ func (t *teardown) Do(ctx context.Context) error { // --- Integration tests --- -func TestIntegration_SpecExample(t *testing.T) { - // Wire up the full DAG from the spec: - // - // CreateRG ──┬── CreateVNet ── CreateSubnet ──┐ - // │ │ - // ├──────────────── CreateCluster ──┘ - // │ │ - // │ RunTests - // │ │ - // └──────────────── Teardown +// specDAG holds all wired nodes from the spec example for reuse across tests. +type specDAG struct { + RG *createRG + VNet *createVNet + Subnet *createSubnet + Cluster *createCluster + Tests *runTests + TD *teardown +} + +// buildSpecDAG wires the full spec example DAG: +// +// CreateRG ──┬── CreateVNet ── CreateSubnet ──┐ +// │ │ +// ├──────────────── CreateCluster ──┘ +// │ │ +// │ RunTests +// │ │ +// └──────────────── Teardown +func buildSpecDAG() specDAG { rg := &createRG{} vnet := &createVNet{} vnet.Deps.RG = rg @@ -126,58 +136,44 @@ func TestIntegration_SpecExample(t *testing.T) { td := &teardown{} td.Deps.RG = rg td.Deps.Tests = tests + return specDAG{RG: rg, VNet: vnet, Subnet: subnet, Cluster: cluster, Tests: tests, TD: td} +} + +func TestIntegration_SpecExample(t *testing.T) { + d := buildSpecDAG() - err := Execute(context.Background(), Config{}, td) + err := Execute(context.Background(), Config{}, d.TD) require.NoError(t, err) // Verify all outputs propagated correctly - assert.Equal(t, "my-rg", rg.Output.RGName) - assert.Equal(t, "my-rg-vnet", vnet.Output.VNetID) - assert.Equal(t, "my-rg-vnet-subnet", subnet.Output.SubnetID) - assert.Equal(t, "cluster-in-my-rg-my-rg-vnet-subnet", cluster.Output.ClusterID) - assert.True(t, tests.Output.Passed) - assert.True(t, td.Output.TornDown) + assert.Equal(t, "my-rg", d.RG.Output.RGName) + assert.Equal(t, "my-rg-vnet", d.VNet.Output.VNetID) + assert.Equal(t, "my-rg-vnet-subnet", d.Subnet.Output.SubnetID) + assert.Equal(t, "cluster-in-my-rg-my-rg-vnet-subnet", d.Cluster.Output.ClusterID) + assert.True(t, d.Tests.Output.Passed) + assert.True(t, d.TD.Output.TornDown) } func TestIntegration_SpecExample_WithMaxConcurrency(t *testing.T) { - rg := &createRG{} - vnet := &createVNet{} - vnet.Deps.RG = rg - subnet := &createSubnet{} - subnet.Deps.VNet = vnet - cluster := &createCluster{} - cluster.Deps.RG = rg - cluster.Deps.Subnet = subnet - tests := &runTests{} - tests.Deps.Cluster = cluster - td := &teardown{} - td.Deps.RG = rg - td.Deps.Tests = tests + d := buildSpecDAG() - err := Execute(context.Background(), Config{MaxConcurrency: 1}, td) + err := Execute(context.Background(), Config{MaxConcurrency: 1}, d.TD) require.NoError(t, err) - assert.Equal(t, "cluster-in-my-rg-my-rg-vnet-subnet", cluster.Output.ClusterID) - assert.True(t, td.Output.TornDown) + assert.Equal(t, "cluster-in-my-rg-my-rg-vnet-subnet", d.Cluster.Output.ClusterID) + assert.True(t, d.TD.Output.TornDown) } func TestIntegration_TransitiveDependencyAccess(t *testing.T) { // Verify that a task can read transitive dependencies through Deps chains // as described in the spec's "Accessing Transitive Dependencies" section. - rg := &createRG{} - vnet := &createVNet{} - vnet.Deps.RG = rg - subnet := &createSubnet{} - subnet.Deps.VNet = vnet - cluster := &createCluster{} - cluster.Deps.RG = rg - cluster.Deps.Subnet = subnet + d := buildSpecDAG() - err := Execute(context.Background(), Config{}, cluster) + err := Execute(context.Background(), Config{}, d.Cluster) require.NoError(t, err) // Access transitive dep: cluster -> subnet -> vnet -> rg - rgName := cluster.Deps.Subnet.Deps.VNet.Deps.RG.Output.RGName + rgName := d.Cluster.Deps.Subnet.Deps.VNet.Deps.RG.Output.RGName assert.Equal(t, "my-rg", rgName) } @@ -206,19 +202,12 @@ func (t *teardownAfterFail) Do(ctx context.Context) error { } func TestIntegration_MidPipelineFailure_CancelDependents(t *testing.T) { - rg := &createRG{} - vnet := &createVNet{} - vnet.Deps.RG = rg - subnet := &createSubnet{} - subnet.Deps.VNet = vnet - cluster := &createCluster{} - cluster.Deps.RG = rg - cluster.Deps.Subnet = subnet + d := buildSpecDAG() failTests := &failingRunTests{} - failTests.Deps.Cluster = cluster + failTests.Deps.Cluster = d.Cluster td := &teardownAfterFail{} - td.Deps.RG = rg + td.Deps.RG = d.RG td.Deps.Tests = failTests err := Execute(context.Background(), Config{OnError: CancelDependents}, td) @@ -228,10 +217,10 @@ func TestIntegration_MidPipelineFailure_CancelDependents(t *testing.T) { require.True(t, errors.As(err, &dagErr)) // Upstream tasks should have succeeded - assert.Equal(t, Succeeded, dagErr.Results[rg].Status) - assert.Equal(t, Succeeded, dagErr.Results[vnet].Status) - assert.Equal(t, Succeeded, dagErr.Results[subnet].Status) - assert.Equal(t, Succeeded, dagErr.Results[cluster].Status) + assert.Equal(t, Succeeded, dagErr.Results[d.RG].Status) + assert.Equal(t, Succeeded, dagErr.Results[d.VNet].Status) + assert.Equal(t, Succeeded, dagErr.Results[d.Subnet].Status) + assert.Equal(t, Succeeded, dagErr.Results[d.Cluster].Status) // failTests should have failed assert.Equal(t, Failed, dagErr.Results[failTests].Status) @@ -241,8 +230,8 @@ func TestIntegration_MidPipelineFailure_CancelDependents(t *testing.T) { assert.Equal(t, Skipped, dagErr.Results[td].Status) // Outputs of successful tasks should still be populated - assert.Equal(t, "my-rg", rg.Output.RGName) - assert.Equal(t, "cluster-in-my-rg-my-rg-vnet-subnet", cluster.Output.ClusterID) + assert.Equal(t, "my-rg", d.RG.Output.RGName) + assert.Equal(t, "cluster-in-my-rg-my-rg-vnet-subnet", d.Cluster.Output.ClusterID) } func TestIntegration_TwoIndependentSubgraphs_SharedTask(t *testing.T) { diff --git a/e2e/tasks/task.go b/e2e/tasks/task.go index 48c688c70bc..adc6b4c1405 100644 --- a/e2e/tasks/task.go +++ b/e2e/tasks/task.go @@ -3,6 +3,7 @@ package tasks import ( "context" "fmt" + "sort" "strings" ) @@ -24,6 +25,17 @@ const ( CancelAll ) +func (s ErrorStrategy) String() string { + switch s { + case CancelDependents: + return "CancelDependents" + case CancelAll: + return "CancelAll" + default: + return fmt.Sprintf("ErrorStrategy(%d)", int(s)) + } +} + // Config controls execution behavior. type Config struct { // OnError controls what happens when a task fails. @@ -79,6 +91,7 @@ func (e *DAGError) Error() string { failed = append(failed, fmt.Sprintf("%T: %s: %v", task, result.Status, result.Err)) } } + sort.Strings(failed) return fmt.Sprintf("dag execution failed: %s", strings.Join(failed, "; ")) } From 6ce7e7a9c4753bc07cf48d55b4aa0e5a187b2c8a Mon Sep 17 00:00:00 2001 From: Artur Khantimirov Date: Sat, 21 Mar 2026 09:19:35 +1300 Subject: [PATCH 09/22] refactor(e2e): replace sequential prepareCluster with concurrent DAG MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Add e2e/dag package — a lightweight, type-safe DAG executor using Go generics. Two verbs: Go (returns value) and Run (side-effect), with numbered variants (Go1-Go3, Run1-Run3) for typed dependency injection. Replace the sequential prepareCluster implementation with a DAG version that runs independent tasks (bastion, subnet, kube, identity, firewall, garbage collection, etc.) concurrently after cluster creation completes. Also fix pre-existing fmt.Sprintf %%w usage in config/config.go. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- e2e/cluster.go | 129 +++++++------- e2e/config/config.go | 2 +- e2e/dag/dag.go | 367 +++++++++++++++++++++++++++++++++++++++ e2e/dag/dag_test.go | 400 +++++++++++++++++++++++++++++++++++++++++++ 4 files changed, 837 insertions(+), 61 deletions(-) create mode 100644 e2e/dag/dag.go create mode 100644 e2e/dag/dag_test.go diff --git a/e2e/cluster.go b/e2e/cluster.go index 589371e2d2b..b6eee0a9e79 100644 --- a/e2e/cluster.go +++ b/e2e/cluster.go @@ -13,6 +13,7 @@ import ( "time" "github.com/Azure/agentbaker/e2e/config" + "github.com/Azure/agentbaker/e2e/dag" "github.com/Azure/agentbaker/e2e/toolkit" "github.com/Azure/azure-sdk-for-go/sdk/azcore" "github.com/Azure/azure-sdk-for-go/sdk/azcore/to" @@ -59,85 +60,93 @@ func (c *Cluster) MaxPodsPerNode() (int, error) { return 0, fmt.Errorf("cluster agentpool profiles were nil or empty: %+v", c.Model) } -func prepareCluster(ctx context.Context, cluster *armcontainerservice.ManagedCluster, isNetworkIsolated, attachPrivateAcr bool) (*Cluster, error) { +// prepareCluster runs all cluster preparation steps as a concurrent DAG. +func prepareCluster(ctx context.Context, clusterModel *armcontainerservice.ManagedCluster, isNetworkIsolated, attachPrivateAcr bool) (*Cluster, error) { defer toolkit.LogStepCtx(ctx, "preparing cluster")() ctx, cancel := context.WithTimeout(ctx, config.Config.TestTimeoutCluster) defer cancel() - cluster.Name = to.Ptr(fmt.Sprintf("%s-%s", *cluster.Name, hash(cluster))) - cluster, err := getOrCreateCluster(ctx, cluster) - if err != nil { - return nil, fmt.Errorf("get or create cluster: %w", err) - } - bastion, err := getOrCreateBastion(ctx, cluster) - if err != nil { - return nil, fmt.Errorf("get or create bastion: %w", err) - } + clusterModel.Name = to.Ptr(fmt.Sprintf("%s-%s", *clusterModel.Name, hash(clusterModel))) + needACR := isNetworkIsolated || attachPrivateAcr - _, err = getOrCreateMaintenanceConfiguration(ctx, cluster) + cluster, err := getOrCreateCluster(ctx, clusterModel) if err != nil { - return nil, fmt.Errorf("get or create maintenance configuration: %w", err) + return nil, err } - subnetID, err := getClusterSubnetID(ctx, *cluster.Properties.NodeResourceGroup) - if err != nil { - return nil, fmt.Errorf("get cluster subnet: %w", err) - } - - resourceGroupName := config.ResourceGroupName(*cluster.Location) + g := dag.NewGroup(ctx) - kube, err := getClusterKubeClient(ctx, resourceGroupName, *cluster.Name) - if err != nil { - return nil, fmt.Errorf("get kube client using cluster %q: %w", *cluster.Name, err) - } + // Fan-out: all of these run concurrently. + bastion := dag.Go(g, func(ctx context.Context) (*Bastion, error) { + return getOrCreateBastion(ctx, cluster) + }) + dag.Run(g, func(ctx context.Context) error { + _, err := getOrCreateMaintenanceConfiguration(ctx, cluster) + return err + }) + subnet := dag.Go(g, func(ctx context.Context) (string, error) { + return getClusterSubnetID(ctx, *cluster.Properties.NodeResourceGroup) + }) + kube := dag.Go(g, func(ctx context.Context) (*Kubeclient, error) { + return getClusterKubeClient(ctx, config.ResourceGroupName(*cluster.Location), *cluster.Name) + }) + identity := dag.Go(g, func(ctx context.Context) (*armcontainerservice.UserAssignedIdentity, error) { + return getClusterKubeletIdentity(cluster) + }) + dag.Run(g, func(ctx context.Context) error { + return collectGarbageVMSS(ctx, cluster) + }) - kubeletIdentity, err := getClusterKubeletIdentity(cluster) - if err != nil { - return nil, fmt.Errorf("getting cluster kubelet identity: %w", err) - } + // ACR tasks: depend on kube + identity. + acrNonAnon := dag.Run2(g, kube, identity, + func(ctx context.Context, k *Kubeclient, id *armcontainerservice.UserAssignedIdentity) error { + if !needACR { + return nil + } + return addPrivateAzureContainerRegistry(ctx, cluster, k, config.ResourceGroupName(*cluster.Location), id, true) + }) + acrAnon := dag.Run2(g, kube, identity, + func(ctx context.Context, k *Kubeclient, id *armcontainerservice.UserAssignedIdentity) error { + if !needACR { + return nil + } + return addPrivateAzureContainerRegistry(ctx, cluster, k, config.ResourceGroupName(*cluster.Location), id, false) + }) - if isNetworkIsolated || attachPrivateAcr { - // private acr must be created before we add the debug daemonsets - if err := addPrivateAzureContainerRegistry(ctx, cluster, kube, resourceGroupName, kubeletIdentity, true); err != nil { - return nil, fmt.Errorf("add private azure container registry (true): %w", err) - } - if err := addPrivateAzureContainerRegistry(ctx, cluster, kube, resourceGroupName, kubeletIdentity, false); err != nil { - return nil, fmt.Errorf("add private azure container registry (false): %w", err) - } - } - if isNetworkIsolated { - if err := addNetworkIsolatedSettings(ctx, cluster, *cluster.Location); err != nil { - return nil, fmt.Errorf("add network isolated settings: %w", err) + // Firewall / network isolation: no deps within the DAG. + dag.Run(g, func(ctx context.Context) error { + if isNetworkIsolated { + return nil } - } - if !isNetworkIsolated { // network isolated cluster blocks all egress via NSG - if err := addFirewallRules(ctx, cluster, *cluster.Location); err != nil { - return nil, fmt.Errorf("add firewall rules: %w", err) + return addFirewallRules(ctx, cluster, *cluster.Location) + }) + dag.Run(g, func(ctx context.Context) error { + if !isNetworkIsolated { + return nil } - } + return addNetworkIsolatedSettings(ctx, cluster, *cluster.Location) + }) - if err := kube.EnsureDebugDaemonsets(ctx, isNetworkIsolated, config.GetPrivateACRName(true, *cluster.Location)); err != nil { - return nil, fmt.Errorf("ensure debug daemonsets for %q: %w", *cluster.Name, err) - } + // Debug daemonsets: depend on kube + both ACR tasks. + dag.Run(g, func(ctx context.Context) error { + return kube.MustGet().EnsureDebugDaemonsets(ctx, isNetworkIsolated, config.GetPrivateACRName(true, *cluster.Location)) + }, kube, acrNonAnon, acrAnon) - // sometimes tests can be interrupted and vmss are left behind - // don't waste resource and delete them - if err := collectGarbageVMSS(ctx, cluster); err != nil { - return nil, fmt.Errorf("collect garbage vmss: %w", err) - } + // Extract cluster params: depend on kube. + extract := dag.Go1(g, kube, func(ctx context.Context, k *Kubeclient) (*ClusterParams, error) { + return extractClusterParameters(ctx, k, cluster) + }) - clusterParams, err := extractClusterParameters(ctx, kube, cluster) - if err != nil { - return nil, fmt.Errorf("extracting cluster parameters: %w", err) + if err := g.Wait(); err != nil { + return nil, err } - return &Cluster{ Model: cluster, - Kube: kube, - KubeletIdentity: kubeletIdentity, - SubnetID: subnetID, - ClusterParams: clusterParams, - Bastion: bastion, + Kube: kube.MustGet(), + KubeletIdentity: identity.MustGet(), + SubnetID: subnet.MustGet(), + ClusterParams: extract.MustGet(), + Bastion: bastion.MustGet(), }, nil } diff --git a/e2e/config/config.go b/e2e/config/config.go index cebf0bde7f7..56ecdb11b8a 100644 --- a/e2e/config/config.go +++ b/e2e/config/config.go @@ -180,7 +180,7 @@ func mustGetNewRSAKeyPair() ([]byte, []byte, string) { privateKeyFileName, err := writePrivateKeyToTempFile(privatePEMBytes) if err != nil { - panic(fmt.Sprintf("failed to write private key to temp file: %w", err)) + panic(fmt.Sprintf("failed to write private key to temp file: %v", err)) } return privatePEMBytes, publicKeyBytes, privateKeyFileName diff --git a/e2e/dag/dag.go b/e2e/dag/dag.go new file mode 100644 index 00000000000..d74dd87eb18 --- /dev/null +++ b/e2e/dag/dag.go @@ -0,0 +1,367 @@ +// Package dag provides a lightweight, type-safe DAG executor for running +// concurrent tasks with dependency tracking. +// +// Tasks are registered against a [Group] and form a directed acyclic graph +// through their declared dependencies. The Group launches each task in its +// own goroutine as soon as all dependencies complete successfully. +// +// There are two kinds of tasks: +// +// - Value-producing tasks return (T, error) and are represented by [Result][T]. +// Register with [Spawn] (no typed deps) or [Then] / [Then2] / [Then3] (typed deps). +// +// - Side-effect tasks return only error and are represented by [Effect]. +// Register with [Do] (no typed deps) or [ThenDo] / [ThenDo2] / [ThenDo3] (typed deps). +// +// Both [Result] and [Effect] implement [Dep], so they can be listed as +// dependencies of downstream tasks. +// +// When a typed dependency is used (Then/ThenDo variants), the dependency's +// value is passed as a function parameter — the compiler enforces correct +// wiring. When untyped dependencies are used (Spawn/Do with variadic deps), +// values are accessed via [Result.MustGet] inside the closure. +// +// On the first task error, the Group cancels its context, causing all pending +// and in-flight tasks to observe cancellation and exit. [Group.Wait] blocks +// until every goroutine returns and reports a [DAGError] containing all +// collected errors. +// +// Example: +// +// g := dag.NewGroup(ctx) +// +// kube := dag.Go(g, func(ctx context.Context) (*Kubeclient, error) { +// return getKubeClient(ctx) +// }) +// params := dag.Go1(g, kube, func(ctx context.Context, k *Kubeclient) (*Params, error) { +// return extractParams(ctx, k) +// }) +// dag.Run(g, func(ctx context.Context) error { +// return ensureMaintenance(ctx) +// }) +// +// if err := g.Wait(); err != nil { ... } +// fmt.Println(kube.MustGet(), params.MustGet()) +package dag + +import ( + "context" + "errors" + "fmt" + "sort" + "strings" + "sync" +) + +// --------------------------------------------------------------------------- +// Dep — the dependency interface +// --------------------------------------------------------------------------- + +// Dep is implemented by [Result] and [Effect]. It represents a dependency +// that must complete before a downstream task starts. +type Dep interface { + wait() + failed() bool +} + +// --------------------------------------------------------------------------- +// Group — the DAG executor +// --------------------------------------------------------------------------- + +// Group manages a set of concurrent tasks with dependency tracking. +// Create one with [NewGroup], register tasks, then call [Group.Wait]. +type Group struct { + ctx context.Context + cancel context.CancelFunc + + mu sync.Mutex + errs []error + wg sync.WaitGroup +} + +// NewGroup creates a Group whose tasks run under ctx. +// On the first task error the Group cancels ctx, signalling all other tasks. +func NewGroup(ctx context.Context) *Group { + ctx, cancel := context.WithCancel(ctx) + return &Group{ctx: ctx, cancel: cancel} +} + +// Wait blocks until every task in the group has finished. +// It returns a *[DAGError] if any task failed, or nil on success. +func (g *Group) Wait() error { + g.wg.Wait() + g.cancel() + g.mu.Lock() + defer g.mu.Unlock() + if len(g.errs) > 0 { + return &DAGError{Errors: g.errs} + } + return nil +} + +func (g *Group) recordError(err error) { + g.mu.Lock() + g.errs = append(g.errs, err) + g.mu.Unlock() + g.cancel() +} + +// errSkipped is a sentinel set on tasks that were skipped because a +// dependency failed. It propagates through the graph so that transitive +// dependents are also skipped without running. +var errSkipped = errors.New("skipped: dependency failed") + +// launch runs fn in a new goroutine after all deps complete. +// If any dep failed or ctx is cancelled, onSkip is called instead of fn. +func (g *Group) launch(deps []Dep, fn func(), onSkip func()) { + g.wg.Add(1) + go func() { + defer g.wg.Done() + + for _, d := range deps { + d.wait() + } + + for _, d := range deps { + if d.failed() { + onSkip() + return + } + } + + if g.ctx.Err() != nil { + onSkip() + return + } + + fn() + }() +} + +// --------------------------------------------------------------------------- +// Result[T] — a typed task output +// --------------------------------------------------------------------------- + +// Result holds the outcome of a task that produces a value of type T. +// It implements [Dep] so it can be used as a dependency for downstream tasks. +type Result[T any] struct { + done chan struct{} + val T + err error +} + +func newResult[T any]() *Result[T] { + return &Result[T]{done: make(chan struct{})} +} + +func (r *Result[T]) wait() { <-r.done } +func (r *Result[T]) failed() bool { r.wait(); return r.err != nil } + +// Get returns the value and true if the task succeeded, or the zero value +// and false if it failed or was skipped. Blocks until the task completes. +func (r *Result[T]) Get() (T, bool) { + <-r.done + if r.err != nil { + var zero T + return zero, false + } + return r.val, true +} + +// MustGet returns the value, panicking if the task failed. Safe to call: +// - Inside Then/ThenDo callbacks (the scheduler guarantees deps succeeded) +// - Inside Spawn/Do callbacks when the Result is listed as a dep +// - After [Group.Wait] returned nil +func (r *Result[T]) MustGet() T { + <-r.done + if r.err != nil { + panic("dag: MustGet() called on failed Result") + } + return r.val +} + +func (r *Result[T]) finish(val T, err error) { + r.val = val + r.err = err + close(r.done) +} + +// --------------------------------------------------------------------------- +// Effect — a side-effect-only task +// --------------------------------------------------------------------------- + +// Effect represents a completed side-effect task. It implements [Dep] so +// downstream tasks can depend on it, but it carries no value. +type Effect struct { + done chan struct{} + err error +} + +func newEffect() *Effect { + return &Effect{done: make(chan struct{})} +} + +func (e *Effect) wait() { <-e.done } +func (e *Effect) failed() bool { e.wait(); return e.err != nil } + +func (e *Effect) finish(err error) { + e.err = err + close(e.done) +} + +// --------------------------------------------------------------------------- +// Go / Go1 / Go2 / Go3 — value-producing tasks. +// +// Go = no typed deps (optional untyped deps via variadic Dep args) +// GoN = N typed deps, passed as function parameters +// --------------------------------------------------------------------------- + +// Go launches fn with optional untyped deps. Values from deps are accessed +// via [Result.MustGet] inside fn. +func Go[T any](g *Group, fn func(ctx context.Context) (T, error), deps ...Dep) *Result[T] { + r := newResult[T]() + g.launch(deps, func() { + val, err := fn(g.ctx) + if err != nil { + g.recordError(err) + } + r.finish(val, err) + }, func() { + var zero T + r.finish(zero, errSkipped) + }) + return r +} + +// Go1 launches fn after dep completes, passing its value. +func Go1[T, D1 any](g *Group, dep *Result[D1], fn func(ctx context.Context, d1 D1) (T, error)) *Result[T] { + r := newResult[T]() + g.launch([]Dep{dep}, func() { + val, err := fn(g.ctx, dep.val) + if err != nil { + g.recordError(err) + } + r.finish(val, err) + }, func() { + var zero T + r.finish(zero, errSkipped) + }) + return r +} + +// Go2 launches fn after dep1 and dep2 complete, passing both values. +func Go2[T, D1, D2 any](g *Group, dep1 *Result[D1], dep2 *Result[D2], fn func(ctx context.Context, d1 D1, d2 D2) (T, error)) *Result[T] { + r := newResult[T]() + g.launch([]Dep{dep1, dep2}, func() { + val, err := fn(g.ctx, dep1.val, dep2.val) + if err != nil { + g.recordError(err) + } + r.finish(val, err) + }, func() { + var zero T + r.finish(zero, errSkipped) + }) + return r +} + +// Go3 launches fn after dep1, dep2, and dep3 complete, passing all values. +func Go3[T, D1, D2, D3 any](g *Group, dep1 *Result[D1], dep2 *Result[D2], dep3 *Result[D3], fn func(ctx context.Context, d1 D1, d2 D2, d3 D3) (T, error)) *Result[T] { + r := newResult[T]() + g.launch([]Dep{dep1, dep2, dep3}, func() { + val, err := fn(g.ctx, dep1.val, dep2.val, dep3.val) + if err != nil { + g.recordError(err) + } + r.finish(val, err) + }, func() { + var zero T + r.finish(zero, errSkipped) + }) + return r +} + +// --------------------------------------------------------------------------- +// Run / Run1 / Run2 / Run3 — side-effect tasks (no return value). +// +// Run = no typed deps (optional untyped deps via variadic Dep args) +// RunN = N typed deps, passed as function parameters +// --------------------------------------------------------------------------- + +// Run launches fn with optional untyped deps. +func Run(g *Group, fn func(ctx context.Context) error, deps ...Dep) *Effect { + e := newEffect() + g.launch(deps, func() { + err := fn(g.ctx) + if err != nil { + g.recordError(err) + } + e.finish(err) + }, func() { + e.finish(errSkipped) + }) + return e +} + +// Run1 launches fn after dep completes, passing its value. +func Run1[D1 any](g *Group, dep *Result[D1], fn func(ctx context.Context, d1 D1) error) *Effect { + e := newEffect() + g.launch([]Dep{dep}, func() { + err := fn(g.ctx, dep.val) + if err != nil { + g.recordError(err) + } + e.finish(err) + }, func() { + e.finish(errSkipped) + }) + return e +} + +// Run2 launches fn after dep1 and dep2 complete, passing both values. +func Run2[D1, D2 any](g *Group, dep1 *Result[D1], dep2 *Result[D2], fn func(ctx context.Context, d1 D1, d2 D2) error) *Effect { + e := newEffect() + g.launch([]Dep{dep1, dep2}, func() { + err := fn(g.ctx, dep1.val, dep2.val) + if err != nil { + g.recordError(err) + } + e.finish(err) + }, func() { + e.finish(errSkipped) + }) + return e +} + +// Run3 launches fn after dep1, dep2, and dep3 complete, passing all values. +func Run3[D1, D2, D3 any](g *Group, dep1 *Result[D1], dep2 *Result[D2], dep3 *Result[D3], fn func(ctx context.Context, d1 D1, d2 D2, d3 D3) error) *Effect { + e := newEffect() + g.launch([]Dep{dep1, dep2, dep3}, func() { + err := fn(g.ctx, dep1.val, dep2.val, dep3.val) + if err != nil { + g.recordError(err) + } + e.finish(err) + }, func() { + e.finish(errSkipped) + }) + return e +} + +// --------------------------------------------------------------------------- +// DAGError +// --------------------------------------------------------------------------- + +// DAGError is returned by [Group.Wait] when one or more tasks failed. +type DAGError struct { + Errors []error +} + +func (e *DAGError) Error() string { + var msgs []string + for _, err := range e.Errors { + msgs = append(msgs, err.Error()) + } + sort.Strings(msgs) + return fmt.Sprintf("dag execution failed: %s", strings.Join(msgs, "; ")) +} diff --git a/e2e/dag/dag_test.go b/e2e/dag/dag_test.go new file mode 100644 index 00000000000..5f3e3e83713 --- /dev/null +++ b/e2e/dag/dag_test.go @@ -0,0 +1,400 @@ +package dag + +import ( + "context" + "errors" + "sync/atomic" + "testing" + "time" +) + +// --------------------------------------------------------------------------- +// Spawn +// --------------------------------------------------------------------------- + +func TestGo(t *testing.T) { + g := NewGroup(context.Background()) + r := Go(g, func(ctx context.Context) (int, error) { + return 42, nil + }) + if err := g.Wait(); err != nil { + t.Fatal(err) + } + if v := r.MustGet(); v != 42 { + t.Fatalf("got %d, want 42", v) + } +} + +func TestGo_Error(t *testing.T) { + g := NewGroup(context.Background()) + r := Go(g, func(ctx context.Context) (int, error) { + return 0, errors.New("boom") + }) + err := g.Wait() + if err == nil { + t.Fatal("expected error") + } + var dagErr *DAGError + if !errors.As(err, &dagErr) { + t.Fatalf("expected *DAGError, got %T", err) + } + if len(dagErr.Errors) != 1 { + t.Fatalf("expected 1 error, got %d", len(dagErr.Errors)) + } + if _, ok := r.Get(); ok { + t.Fatal("Get() should return false on failed Result") + } +} + +func TestGo_WithDeps(t *testing.T) { + g := NewGroup(context.Background()) + a := Go(g, func(ctx context.Context) (int, error) { return 10, nil }) + b := Go(g, func(ctx context.Context) (string, error) { return "hello", nil }) + + c := Go(g, func(ctx context.Context) (string, error) { + return b.MustGet() + ":" + string(rune('0'+a.MustGet())), nil + }, a, b) + + if err := g.Wait(); err != nil { + t.Fatal(err) + } + _ = c.MustGet() +} + +// --------------------------------------------------------------------------- +// Do +// --------------------------------------------------------------------------- + +func TestRun(t *testing.T) { + g := NewGroup(context.Background()) + var called atomic.Bool + Run(g, func(ctx context.Context) error { + called.Store(true) + return nil + }) + if err := g.Wait(); err != nil { + t.Fatal(err) + } + if !called.Load() { + t.Fatal("Do function was not called") + } +} + +func TestRun_WithDeps(t *testing.T) { + g := NewGroup(context.Background()) + a := Go(g, func(ctx context.Context) (int, error) { return 10, nil }) + b := Go(g, func(ctx context.Context) (string, error) { return "hi", nil }) + + var got atomic.Value + Run(g, func(ctx context.Context) error { + got.Store(b.MustGet() + ":" + string(rune('0'+a.MustGet()))) + return nil + }, a, b) + + if err := g.Wait(); err != nil { + t.Fatal(err) + } + if got.Load() == nil { + t.Fatal("Do function was not called") + } +} + +// --------------------------------------------------------------------------- +// Then chain +// --------------------------------------------------------------------------- + +func TestGo1_Chain(t *testing.T) { + g := NewGroup(context.Background()) + a := Go(g, func(ctx context.Context) (int, error) { + return 10, nil + }) + b := Go1(g, a, func(ctx context.Context, v int) (int, error) { + return v * 2, nil + }) + c := Go1(g, b, func(ctx context.Context, v int) (string, error) { + if v != 20 { + return "", errors.New("bad value") + } + return "ok", nil + }) + if err := g.Wait(); err != nil { + t.Fatal(err) + } + if v := c.MustGet(); v != "ok" { + t.Fatalf("got %q, want %q", v, "ok") + } +} + +// --------------------------------------------------------------------------- +// Then2 / Then3 +// --------------------------------------------------------------------------- + +func TestGo2(t *testing.T) { + g := NewGroup(context.Background()) + a := Go(g, func(ctx context.Context) (int, error) { return 3, nil }) + b := Go(g, func(ctx context.Context) (int, error) { return 4, nil }) + c := Go2(g, a, b, func(ctx context.Context, x, y int) (int, error) { + return x + y, nil + }) + if err := g.Wait(); err != nil { + t.Fatal(err) + } + if v := c.MustGet(); v != 7 { + t.Fatalf("got %d, want 7", v) + } +} + +func TestGo3(t *testing.T) { + g := NewGroup(context.Background()) + a := Go(g, func(ctx context.Context) (int, error) { return 1, nil }) + b := Go(g, func(ctx context.Context) (int, error) { return 2, nil }) + c := Go(g, func(ctx context.Context) (int, error) { return 3, nil }) + d := Go3(g, a, b, c, func(ctx context.Context, x, y, z int) (int, error) { + return x + y + z, nil + }) + if err := g.Wait(); err != nil { + t.Fatal(err) + } + if v := d.MustGet(); v != 6 { + t.Fatalf("got %d, want 6", v) + } +} + +// --------------------------------------------------------------------------- +// ThenDo / ThenDo2 / ThenDo3 +// --------------------------------------------------------------------------- + +func TestRun1(t *testing.T) { + g := NewGroup(context.Background()) + a := Go(g, func(ctx context.Context) (int, error) { return 5, nil }) + var got atomic.Int32 + Run1(g, a, func(ctx context.Context, v int) error { + got.Store(int32(v)) + return nil + }) + if err := g.Wait(); err != nil { + t.Fatal(err) + } + if got.Load() != 5 { + t.Fatalf("got %d, want 5", got.Load()) + } +} + +func TestRun2(t *testing.T) { + g := NewGroup(context.Background()) + a := Go(g, func(ctx context.Context) (int, error) { return 3, nil }) + b := Go(g, func(ctx context.Context) (string, error) { return "x", nil }) + var got atomic.Value + Run2(g, a, b, func(ctx context.Context, n int, s string) error { + got.Store(s + string(rune('0'+n))) + return nil + }) + if err := g.Wait(); err != nil { + t.Fatal(err) + } + if got.Load() != "x3" { + t.Fatalf("got %v, want x3", got.Load()) + } +} + +func TestRun3(t *testing.T) { + g := NewGroup(context.Background()) + a := Go(g, func(ctx context.Context) (int, error) { return 1, nil }) + b := Go(g, func(ctx context.Context) (int, error) { return 2, nil }) + c := Go(g, func(ctx context.Context) (int, error) { return 3, nil }) + var got atomic.Int32 + Run3(g, a, b, c, func(ctx context.Context, x, y, z int) error { + got.Store(int32(x + y + z)) + return nil + }) + if err := g.Wait(); err != nil { + t.Fatal(err) + } + if got.Load() != 6 { + t.Fatalf("got %d, want 6", got.Load()) + } +} + +// --------------------------------------------------------------------------- +// Error propagation — cancel-all behavior +// --------------------------------------------------------------------------- + +func TestCancelAll_CancelsRunningTasks(t *testing.T) { + g := NewGroup(context.Background()) + + started := make(chan struct{}) + var cancelled atomic.Bool + Run(g, func(ctx context.Context) error { + close(started) + <-ctx.Done() + cancelled.Store(true) + return ctx.Err() + }) + + Go(g, func(ctx context.Context) (int, error) { + <-started + return 0, errors.New("fail") + }) + + g.Wait() + time.Sleep(10 * time.Millisecond) + if !cancelled.Load() { + t.Fatal("expected context to be cancelled for running task") + } +} + +func TestSkipsDownstream(t *testing.T) { + g := NewGroup(context.Background()) + + a := Go(g, func(ctx context.Context) (int, error) { + return 0, errors.New("a failed") + }) + + var bRan atomic.Bool + Run1(g, a, func(ctx context.Context, v int) error { + bRan.Store(true) + return nil + }) + + g.Wait() + if bRan.Load() { + t.Fatal("dependent task b should have been skipped") + } +} + +func TestTransitiveSkip(t *testing.T) { + g := NewGroup(context.Background()) + + a := Go(g, func(ctx context.Context) (int, error) { + return 0, errors.New("a failed") + }) + b := Go1(g, a, func(ctx context.Context, v int) (int, error) { + return v + 1, nil + }) + + var cRan atomic.Bool + Run1(g, b, func(ctx context.Context, v int) error { + cRan.Store(true) + return nil + }) + + g.Wait() + if cRan.Load() { + t.Fatal("transitive dependent c should have been skipped") + } +} + +// --------------------------------------------------------------------------- +// DAG topologies +// --------------------------------------------------------------------------- + +func TestDiamond(t *testing.T) { + // a + // / \ + // b c + // \ / + // d + g := NewGroup(context.Background()) + + a := Go(g, func(ctx context.Context) (int, error) { return 1, nil }) + b := Go1(g, a, func(ctx context.Context, v int) (int, error) { return v + 10, nil }) + c := Go1(g, a, func(ctx context.Context, v int) (int, error) { return v + 100, nil }) + d := Go2(g, b, c, func(ctx context.Context, x, y int) (int, error) { return x + y, nil }) + + if err := g.Wait(); err != nil { + t.Fatal(err) + } + if d.MustGet() != 112 { + t.Fatalf("got %d, want 112", d.MustGet()) + } +} + +// --------------------------------------------------------------------------- +// Result.Get / Result.MustGet safety +// --------------------------------------------------------------------------- + +func TestGet_SafeOnError(t *testing.T) { + g := NewGroup(context.Background()) + r := Go(g, func(ctx context.Context) (int, error) { + return 0, errors.New("boom") + }) + g.Wait() + + v, ok := r.Get() + if ok { + t.Fatal("Get() should return false on error") + } + if v != 0 { + t.Fatalf("Get() should return zero value, got %d", v) + } +} + +func TestMustGet_PanicsOnError(t *testing.T) { + g := NewGroup(context.Background()) + r := Go(g, func(ctx context.Context) (int, error) { + return 0, errors.New("boom") + }) + g.Wait() + + defer func() { + if r := recover(); r == nil { + t.Fatal("expected panic from MustGet() on failed Result") + } + }() + r.MustGet() +} + +// --------------------------------------------------------------------------- +// Edge cases +// --------------------------------------------------------------------------- + +func TestMultipleErrors(t *testing.T) { + g := NewGroup(context.Background()) + Go(g, func(ctx context.Context) (int, error) { return 0, errors.New("err1") }) + Go(g, func(ctx context.Context) (int, error) { return 0, errors.New("err2") }) + + err := g.Wait() + var dagErr *DAGError + if !errors.As(err, &dagErr) { + t.Fatalf("expected *DAGError, got %T", err) + } + if len(dagErr.Errors) < 1 { + t.Fatal("expected at least 1 error") + } +} + +func TestParentContextCancelled(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + g := NewGroup(ctx) + Run(g, func(ctx context.Context) error { + return nil + }) + + // Key invariant: Wait() returns without hanging. + g.Wait() +} + +func TestEffect_AsDep(t *testing.T) { + g := NewGroup(context.Background()) + + var order []int + var mu atomic.Value + mu.Store([]int{}) + + e := Run(g, func(ctx context.Context) error { + order = append(order, 1) + return nil + }) + Run(g, func(ctx context.Context) error { + if len(order) == 0 { + return errors.New("expected effect to run first") + } + return nil + }, e) + + if err := g.Wait(); err != nil { + t.Fatal(err) + } +} From da1d4b26ce314a093dafe9fb9e89ffc8642fe714 Mon Sep 17 00:00:00 2001 From: Artur Khantimirov Date: Sat, 21 Mar 2026 09:34:09 +1300 Subject: [PATCH 10/22] refactor(e2e): simplify prepareCluster, remove clusterSetup struct Replace the clusterSetup struct and its 12 one-liner methods with inline closures. Each closure is 1-2 lines binding the cluster local to the real function call. Everything reads top-to-bottom in one place. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- e2e/cluster.go | 40 ++++++++++++++++------------------------ 1 file changed, 16 insertions(+), 24 deletions(-) diff --git a/e2e/cluster.go b/e2e/cluster.go index b6eee0a9e79..4d1e8010ad3 100644 --- a/e2e/cluster.go +++ b/e2e/cluster.go @@ -74,9 +74,9 @@ func prepareCluster(ctx context.Context, clusterModel *armcontainerservice.Manag return nil, err } + rg := config.ResourceGroupName(*cluster.Location) g := dag.NewGroup(ctx) - // Fan-out: all of these run concurrently. bastion := dag.Go(g, func(ctx context.Context) (*Bastion, error) { return getOrCreateBastion(ctx, cluster) }) @@ -88,7 +88,7 @@ func prepareCluster(ctx context.Context, clusterModel *armcontainerservice.Manag return getClusterSubnetID(ctx, *cluster.Properties.NodeResourceGroup) }) kube := dag.Go(g, func(ctx context.Context) (*Kubeclient, error) { - return getClusterKubeClient(ctx, config.ResourceGroupName(*cluster.Location), *cluster.Name) + return getClusterKubeClient(ctx, rg, *cluster.Name) }) identity := dag.Go(g, func(ctx context.Context) (*armcontainerservice.UserAssignedIdentity, error) { return getClusterKubeletIdentity(cluster) @@ -96,24 +96,6 @@ func prepareCluster(ctx context.Context, clusterModel *armcontainerservice.Manag dag.Run(g, func(ctx context.Context) error { return collectGarbageVMSS(ctx, cluster) }) - - // ACR tasks: depend on kube + identity. - acrNonAnon := dag.Run2(g, kube, identity, - func(ctx context.Context, k *Kubeclient, id *armcontainerservice.UserAssignedIdentity) error { - if !needACR { - return nil - } - return addPrivateAzureContainerRegistry(ctx, cluster, k, config.ResourceGroupName(*cluster.Location), id, true) - }) - acrAnon := dag.Run2(g, kube, identity, - func(ctx context.Context, k *Kubeclient, id *armcontainerservice.UserAssignedIdentity) error { - if !needACR { - return nil - } - return addPrivateAzureContainerRegistry(ctx, cluster, k, config.ResourceGroupName(*cluster.Location), id, false) - }) - - // Firewall / network isolation: no deps within the DAG. dag.Run(g, func(ctx context.Context) error { if isNetworkIsolated { return nil @@ -126,13 +108,23 @@ func prepareCluster(ctx context.Context, clusterModel *armcontainerservice.Manag } return addNetworkIsolatedSettings(ctx, cluster, *cluster.Location) }) - - // Debug daemonsets: depend on kube + both ACR tasks. + acrNonAnon := dag.Run2(g, kube, identity, + func(ctx context.Context, k *Kubeclient, id *armcontainerservice.UserAssignedIdentity) error { + if !needACR { + return nil + } + return addPrivateAzureContainerRegistry(ctx, cluster, k, rg, id, true) + }) + acrAnon := dag.Run2(g, kube, identity, + func(ctx context.Context, k *Kubeclient, id *armcontainerservice.UserAssignedIdentity) error { + if !needACR { + return nil + } + return addPrivateAzureContainerRegistry(ctx, cluster, k, rg, id, false) + }) dag.Run(g, func(ctx context.Context) error { return kube.MustGet().EnsureDebugDaemonsets(ctx, isNetworkIsolated, config.GetPrivateACRName(true, *cluster.Location)) }, kube, acrNonAnon, acrAnon) - - // Extract cluster params: depend on kube. extract := dag.Go1(g, kube, func(ctx context.Context, k *Kubeclient) (*ClusterParams, error) { return extractClusterParameters(ctx, k, cluster) }) From 4f0e867b112041876bb87312e27a3572a09e5a4a Mon Sep 17 00:00:00 2001 From: Artur Khantimirov Date: Sat, 21 Mar 2026 09:42:43 +1300 Subject: [PATCH 11/22] refactor(e2e): wire cluster as DAG task, pass functions directly Move cluster creation back into the DAG so all tasks use typed dependency injection. Where function signatures match exactly (getOrCreateBastion, getClusterKubeletIdentity, collectGarbageVMSS), pass them directly without closures. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- e2e/cluster.go | 63 ++++++++++++++++++++++---------------------------- 1 file changed, 28 insertions(+), 35 deletions(-) diff --git a/e2e/cluster.go b/e2e/cluster.go index 4d1e8010ad3..a03cbea4ff0 100644 --- a/e2e/cluster.go +++ b/e2e/cluster.go @@ -69,71 +69,64 @@ func prepareCluster(ctx context.Context, clusterModel *armcontainerservice.Manag clusterModel.Name = to.Ptr(fmt.Sprintf("%s-%s", *clusterModel.Name, hash(clusterModel))) needACR := isNetworkIsolated || attachPrivateAcr - cluster, err := getOrCreateCluster(ctx, clusterModel) - if err != nil { - return nil, err - } - - rg := config.ResourceGroupName(*cluster.Location) + rg := config.ResourceGroupName(*clusterModel.Location) g := dag.NewGroup(ctx) - bastion := dag.Go(g, func(ctx context.Context) (*Bastion, error) { - return getOrCreateBastion(ctx, cluster) + cluster := dag.Go(g, func(ctx context.Context) (*armcontainerservice.ManagedCluster, error) { + return getOrCreateCluster(ctx, clusterModel) }) - dag.Run(g, func(ctx context.Context) error { - _, err := getOrCreateMaintenanceConfiguration(ctx, cluster) + + bastion := dag.Go1(g, cluster, getOrCreateBastion) + dag.Run1(g, cluster, func(ctx context.Context, c *armcontainerservice.ManagedCluster) error { + _, err := getOrCreateMaintenanceConfiguration(ctx, c) return err }) - subnet := dag.Go(g, func(ctx context.Context) (string, error) { - return getClusterSubnetID(ctx, *cluster.Properties.NodeResourceGroup) + subnet := dag.Go1(g, cluster, func(ctx context.Context, c *armcontainerservice.ManagedCluster) (string, error) { + return getClusterSubnetID(ctx, *c.Properties.NodeResourceGroup) }) - kube := dag.Go(g, func(ctx context.Context) (*Kubeclient, error) { - return getClusterKubeClient(ctx, rg, *cluster.Name) + kube := dag.Go1(g, cluster, func(ctx context.Context, c *armcontainerservice.ManagedCluster) (*Kubeclient, error) { + return getClusterKubeClient(ctx, rg, *c.Name) }) - identity := dag.Go(g, func(ctx context.Context) (*armcontainerservice.UserAssignedIdentity, error) { - return getClusterKubeletIdentity(cluster) - }) - dag.Run(g, func(ctx context.Context) error { - return collectGarbageVMSS(ctx, cluster) - }) - dag.Run(g, func(ctx context.Context) error { + identity := dag.Go1(g, cluster, getClusterKubeletIdentity) + dag.Run1(g, cluster, collectGarbageVMSS) + dag.Run1(g, cluster, func(ctx context.Context, c *armcontainerservice.ManagedCluster) error { if isNetworkIsolated { return nil } - return addFirewallRules(ctx, cluster, *cluster.Location) + return addFirewallRules(ctx, c, *c.Location) }) - dag.Run(g, func(ctx context.Context) error { + dag.Run1(g, cluster, func(ctx context.Context, c *armcontainerservice.ManagedCluster) error { if !isNetworkIsolated { return nil } - return addNetworkIsolatedSettings(ctx, cluster, *cluster.Location) + return addNetworkIsolatedSettings(ctx, c, *c.Location) }) - acrNonAnon := dag.Run2(g, kube, identity, - func(ctx context.Context, k *Kubeclient, id *armcontainerservice.UserAssignedIdentity) error { + acrNonAnon := dag.Run3(g, cluster, kube, identity, + func(ctx context.Context, c *armcontainerservice.ManagedCluster, k *Kubeclient, id *armcontainerservice.UserAssignedIdentity) error { if !needACR { return nil } - return addPrivateAzureContainerRegistry(ctx, cluster, k, rg, id, true) + return addPrivateAzureContainerRegistry(ctx, c, k, rg, id, true) }) - acrAnon := dag.Run2(g, kube, identity, - func(ctx context.Context, k *Kubeclient, id *armcontainerservice.UserAssignedIdentity) error { + acrAnon := dag.Run3(g, cluster, kube, identity, + func(ctx context.Context, c *armcontainerservice.ManagedCluster, k *Kubeclient, id *armcontainerservice.UserAssignedIdentity) error { if !needACR { return nil } - return addPrivateAzureContainerRegistry(ctx, cluster, k, rg, id, false) + return addPrivateAzureContainerRegistry(ctx, c, k, rg, id, false) }) dag.Run(g, func(ctx context.Context) error { - return kube.MustGet().EnsureDebugDaemonsets(ctx, isNetworkIsolated, config.GetPrivateACRName(true, *cluster.Location)) + return kube.MustGet().EnsureDebugDaemonsets(ctx, isNetworkIsolated, config.GetPrivateACRName(true, *cluster.MustGet().Location)) }, kube, acrNonAnon, acrAnon) - extract := dag.Go1(g, kube, func(ctx context.Context, k *Kubeclient) (*ClusterParams, error) { - return extractClusterParameters(ctx, k, cluster) + extract := dag.Go2(g, cluster, kube, func(ctx context.Context, c *armcontainerservice.ManagedCluster, k *Kubeclient) (*ClusterParams, error) { + return extractClusterParameters(ctx, k, c) }) if err := g.Wait(); err != nil { return nil, err } return &Cluster{ - Model: cluster, + Model: cluster.MustGet(), Kube: kube.MustGet(), KubeletIdentity: identity.MustGet(), SubnetID: subnet.MustGet(), @@ -142,7 +135,7 @@ func prepareCluster(ctx context.Context, clusterModel *armcontainerservice.Manag }, nil } -func getClusterKubeletIdentity(cluster *armcontainerservice.ManagedCluster) (*armcontainerservice.UserAssignedIdentity, error) { +func getClusterKubeletIdentity(ctx context.Context, cluster *armcontainerservice.ManagedCluster) (*armcontainerservice.UserAssignedIdentity, error) { if cluster == nil || cluster.Properties == nil || cluster.Properties.IdentityProfile == nil { return nil, fmt.Errorf("cannot dereference cluster identity profile to extract kubelet identity ID") } From c58c7279fb45a2df09b34443d9aa7307a59095db Mon Sep 17 00:00:00 2001 From: Artur Khantimirov Date: Sat, 21 Mar 2026 09:52:33 +1300 Subject: [PATCH 12/22] refactor(e2e): eliminate all anonymous functions from prepareCluster Extract named helpers for conditional tasks (configureFirewall, configureNetworkIsolation, setupACR, ensureDebugDaemonsets) and update function signatures (getClusterSubnetID, getClusterKubeClient, extractClusterParameters) to accept *ManagedCluster directly so they can be passed to the DAG without closures. prepareCluster now reads as a pure declarative DAG with no inline anonymous functions. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- e2e/cluster.go | 105 ++++++++++++++++++++++++++----------------------- e2e/kube.go | 8 +++- 2 files changed, 62 insertions(+), 51 deletions(-) diff --git a/e2e/cluster.go b/e2e/cluster.go index a03cbea4ff0..0cee2fdbdc2 100644 --- a/e2e/cluster.go +++ b/e2e/cluster.go @@ -67,60 +67,23 @@ func prepareCluster(ctx context.Context, clusterModel *armcontainerservice.Manag defer cancel() clusterModel.Name = to.Ptr(fmt.Sprintf("%s-%s", *clusterModel.Name, hash(clusterModel))) - needACR := isNetworkIsolated || attachPrivateAcr - rg := config.ResourceGroupName(*clusterModel.Location) g := dag.NewGroup(ctx) - cluster := dag.Go(g, func(ctx context.Context) (*armcontainerservice.ManagedCluster, error) { - return getOrCreateCluster(ctx, clusterModel) - }) - + cluster := dag.Go(g, newClusterTask(clusterModel)) bastion := dag.Go1(g, cluster, getOrCreateBastion) - dag.Run1(g, cluster, func(ctx context.Context, c *armcontainerservice.ManagedCluster) error { - _, err := getOrCreateMaintenanceConfiguration(ctx, c) - return err - }) - subnet := dag.Go1(g, cluster, func(ctx context.Context, c *armcontainerservice.ManagedCluster) (string, error) { - return getClusterSubnetID(ctx, *c.Properties.NodeResourceGroup) - }) - kube := dag.Go1(g, cluster, func(ctx context.Context, c *armcontainerservice.ManagedCluster) (*Kubeclient, error) { - return getClusterKubeClient(ctx, rg, *c.Name) - }) + dag.Run1(g, cluster, ensureMaintenanceConfiguration) + subnet := dag.Go1(g, cluster, getClusterSubnetID) + kube := dag.Go1(g, cluster, getClusterKubeClient) identity := dag.Go1(g, cluster, getClusterKubeletIdentity) dag.Run1(g, cluster, collectGarbageVMSS) - dag.Run1(g, cluster, func(ctx context.Context, c *armcontainerservice.ManagedCluster) error { - if isNetworkIsolated { - return nil - } - return addFirewallRules(ctx, c, *c.Location) - }) - dag.Run1(g, cluster, func(ctx context.Context, c *armcontainerservice.ManagedCluster) error { - if !isNetworkIsolated { - return nil - } - return addNetworkIsolatedSettings(ctx, c, *c.Location) - }) - acrNonAnon := dag.Run3(g, cluster, kube, identity, - func(ctx context.Context, c *armcontainerservice.ManagedCluster, k *Kubeclient, id *armcontainerservice.UserAssignedIdentity) error { - if !needACR { - return nil - } - return addPrivateAzureContainerRegistry(ctx, c, k, rg, id, true) - }) - acrAnon := dag.Run3(g, cluster, kube, identity, - func(ctx context.Context, c *armcontainerservice.ManagedCluster, k *Kubeclient, id *armcontainerservice.UserAssignedIdentity) error { - if !needACR { - return nil - } - return addPrivateAzureContainerRegistry(ctx, c, k, rg, id, false) - }) - dag.Run(g, func(ctx context.Context) error { - return kube.MustGet().EnsureDebugDaemonsets(ctx, isNetworkIsolated, config.GetPrivateACRName(true, *cluster.MustGet().Location)) - }, kube, acrNonAnon, acrAnon) - extract := dag.Go2(g, cluster, kube, func(ctx context.Context, c *armcontainerservice.ManagedCluster, k *Kubeclient) (*ClusterParams, error) { - return extractClusterParameters(ctx, k, c) - }) + dag.Run1(g, cluster, configureFirewall(isNetworkIsolated)) + dag.Run1(g, cluster, configureNetworkIsolation(isNetworkIsolated)) + needACR := isNetworkIsolated || attachPrivateAcr + acrNonAnon := dag.Run3(g, cluster, kube, identity, setupACR(needACR, true)) + acrAnon := dag.Run3(g, cluster, kube, identity, setupACR(needACR, false)) + dag.Run(g, ensureDebugDaemonsets(isNetworkIsolated, cluster, kube), kube, acrNonAnon, acrAnon) + extract := dag.Go2(g, cluster, kube, extractClusterParameters) if err := g.Wait(); err != nil { return nil, err @@ -135,6 +98,50 @@ func prepareCluster(ctx context.Context, clusterModel *armcontainerservice.Manag }, nil } +func newClusterTask(model *armcontainerservice.ManagedCluster) func(context.Context) (*armcontainerservice.ManagedCluster, error) { + return func(ctx context.Context) (*armcontainerservice.ManagedCluster, error) { + return getOrCreateCluster(ctx, model) + } +} + +func configureFirewall(isNetworkIsolated bool) func(context.Context, *armcontainerservice.ManagedCluster) error { + return func(ctx context.Context, c *armcontainerservice.ManagedCluster) error { + if isNetworkIsolated { + return nil + } + return addFirewallRules(ctx, c, *c.Location) + } +} + +func configureNetworkIsolation(isNetworkIsolated bool) func(context.Context, *armcontainerservice.ManagedCluster) error { + return func(ctx context.Context, c *armcontainerservice.ManagedCluster) error { + if !isNetworkIsolated { + return nil + } + return addNetworkIsolatedSettings(ctx, c, *c.Location) + } +} + +func setupACR(needACR, isNonAnonymousPull bool) func(context.Context, *armcontainerservice.ManagedCluster, *Kubeclient, *armcontainerservice.UserAssignedIdentity) error { + return func(ctx context.Context, c *armcontainerservice.ManagedCluster, k *Kubeclient, id *armcontainerservice.UserAssignedIdentity) error { + if !needACR { + return nil + } + return addPrivateAzureContainerRegistry(ctx, c, k, config.ResourceGroupName(*c.Location), id, isNonAnonymousPull) + } +} + +func ensureDebugDaemonsets(isNetworkIsolated bool, cluster *dag.Result[*armcontainerservice.ManagedCluster], kube *dag.Result[*Kubeclient]) func(context.Context) error { + return func(ctx context.Context) error { + return kube.MustGet().EnsureDebugDaemonsets(ctx, isNetworkIsolated, config.GetPrivateACRName(true, *cluster.MustGet().Location)) + } +} + +func ensureMaintenanceConfiguration(ctx context.Context, cluster *armcontainerservice.ManagedCluster) error { + _, err := getOrCreateMaintenanceConfiguration(ctx, cluster) + return err +} + func getClusterKubeletIdentity(ctx context.Context, cluster *armcontainerservice.ManagedCluster) (*armcontainerservice.UserAssignedIdentity, error) { if cluster == nil || cluster.Properties == nil || cluster.Properties.IdentityProfile == nil { return nil, fmt.Errorf("cannot dereference cluster identity profile to extract kubelet identity ID") @@ -146,7 +153,7 @@ func getClusterKubeletIdentity(ctx context.Context, cluster *armcontainerservice return kubeletIdentity, nil } -func extractClusterParameters(ctx context.Context, kube *Kubeclient, cluster *armcontainerservice.ManagedCluster) (*ClusterParams, error) { +func extractClusterParameters(ctx context.Context, cluster *armcontainerservice.ManagedCluster, kube *Kubeclient) (*ClusterParams, error) { kubeconfig, err := clientcmd.Load(kube.KubeConfig) if err != nil { return nil, fmt.Errorf("loading cluster kubeconfig: %w", err) diff --git a/e2e/kube.go b/e2e/kube.go index 3b551f6f370..150d58fc31f 100644 --- a/e2e/kube.go +++ b/e2e/kube.go @@ -12,6 +12,7 @@ import ( "github.com/Azure/agentbaker/e2e/config" "github.com/Azure/agentbaker/e2e/toolkit" "github.com/Azure/azure-sdk-for-go/sdk/azcore/to" + "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/containerservice/armcontainerservice/v8" "github.com/stretchr/testify/require" appsv1 "k8s.io/api/apps/v1" corev1 "k8s.io/api/core/v1" @@ -40,7 +41,9 @@ const ( podNetworkDebugAppLabel = "debugnonhost-mariner-tolerated" ) -func getClusterKubeClient(ctx context.Context, resourceGroupName, clusterName string) (*Kubeclient, error) { +func getClusterKubeClient(ctx context.Context, cluster *armcontainerservice.ManagedCluster) (*Kubeclient, error) { + resourceGroupName := config.ResourceGroupName(*cluster.Location) + clusterName := *cluster.Name data, err := getClusterKubeconfigBytes(ctx, resourceGroupName, clusterName) if err != nil { return nil, fmt.Errorf("get cluster kubeconfig bytes: %w", err) @@ -448,7 +451,8 @@ func daemonsetDebug(ctx context.Context, deploymentName, targetNodeLabel, privat } } -func getClusterSubnetID(ctx context.Context, mcResourceGroupName string) (string, error) { +func getClusterSubnetID(ctx context.Context, cluster *armcontainerservice.ManagedCluster) (string, error) { + mcResourceGroupName := *cluster.Properties.NodeResourceGroup pager := config.Azure.VNet.NewListPager(mcResourceGroupName, nil) for pager.More() { nextResult, err := pager.NextPage(ctx) From c9d07921b130c1839283894563e3fa966b83eeb6 Mon Sep 17 00:00:00 2001 From: Artur Khantimirov Date: Sat, 21 Mar 2026 09:58:28 +1300 Subject: [PATCH 13/22] refactor(e2e): eliminate wrapper helpers, absorb args into functions - addFirewallRules/addNetworkIsolatedSettings: derive location from cluster instead of taking it as a param. Use conditional DAG registration instead of runtime checks. - addPrivateAzureContainerRegistry: derive resourceGroupName from cluster.Location internally. - ensureMaintenanceConfiguration: replaces getOrCreate wrapper, returns error only (value was never used). - getClusterSubnetID/getClusterKubeClient: take *ManagedCluster instead of extracted strings. - GoN/RunN: accept optional extra ...Dep barrier deps so typed dep tasks can also wait on untyped barriers (e.g. ACR effects). Remaining helpers (newClusterTask, addACRTask, addDebugDaemonsets) exist because they genuinely need parameter binding that can't be absorbed into the underlying functions. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- e2e/aks_model.go | 8 ++++--- e2e/cluster.go | 57 +++++++++++++++++------------------------------- e2e/dag/dag.go | 30 +++++++++++++++---------- 3 files changed, 43 insertions(+), 52 deletions(-) diff --git a/e2e/aks_model.go b/e2e/aks_model.go index 7498d92c0d1..648cd732954 100644 --- a/e2e/aks_model.go +++ b/e2e/aks_model.go @@ -299,8 +299,8 @@ func getFirewall(ctx context.Context, location, firewallSubnetID, publicIPID str func addFirewallRules( ctx context.Context, clusterModel *armcontainerservice.ManagedCluster, - location string, ) error { + location := *clusterModel.Location defer toolkit.LogStepCtx(ctx, "adding firewall rules")() routeTableName := "abe2e-fw-rt" rtGetResp, err := config.Azure.RouteTables.Get( @@ -486,10 +486,11 @@ func addFirewallRules( return nil } -func addPrivateAzureContainerRegistry(ctx context.Context, cluster *armcontainerservice.ManagedCluster, kube *Kubeclient, resourceGroupName string, kubeletIdentity *armcontainerservice.UserAssignedIdentity, isNonAnonymousPull bool) error { +func addPrivateAzureContainerRegistry(ctx context.Context, cluster *armcontainerservice.ManagedCluster, kube *Kubeclient, kubeletIdentity *armcontainerservice.UserAssignedIdentity, isNonAnonymousPull bool) error { if cluster == nil || kube == nil || kubeletIdentity == nil { return errors.New("cluster, kubeclient, and kubeletIdentity cannot be nil when adding Private Azure Container Registry") } + resourceGroupName := config.ResourceGroupName(*cluster.Location) if err := createPrivateAzureContainerRegistry(ctx, cluster, resourceGroupName, isNonAnonymousPull); err != nil { return fmt.Errorf("failed to create private acr: %w", err) } @@ -514,7 +515,8 @@ func addPrivateAzureContainerRegistry(ctx context.Context, cluster *armcontainer return nil } -func addNetworkIsolatedSettings(ctx context.Context, clusterModel *armcontainerservice.ManagedCluster, location string) error { +func addNetworkIsolatedSettings(ctx context.Context, clusterModel *armcontainerservice.ManagedCluster) error { + location := *clusterModel.Location defer toolkit.LogStepCtx(ctx, fmt.Sprintf("Adding network settings for network isolated cluster %s in rg %s", *clusterModel.Name, *clusterModel.Properties.NodeResourceGroup)) vnet, err := getClusterVNet(ctx, *clusterModel.Properties.NodeResourceGroup) diff --git a/e2e/cluster.go b/e2e/cluster.go index 0cee2fdbdc2..1eb9c68c096 100644 --- a/e2e/cluster.go +++ b/e2e/cluster.go @@ -77,12 +77,16 @@ func prepareCluster(ctx context.Context, clusterModel *armcontainerservice.Manag kube := dag.Go1(g, cluster, getClusterKubeClient) identity := dag.Go1(g, cluster, getClusterKubeletIdentity) dag.Run1(g, cluster, collectGarbageVMSS) - dag.Run1(g, cluster, configureFirewall(isNetworkIsolated)) - dag.Run1(g, cluster, configureNetworkIsolation(isNetworkIsolated)) + if !isNetworkIsolated { + dag.Run1(g, cluster, addFirewallRules) + } + if isNetworkIsolated { + dag.Run1(g, cluster, addNetworkIsolatedSettings) + } needACR := isNetworkIsolated || attachPrivateAcr - acrNonAnon := dag.Run3(g, cluster, kube, identity, setupACR(needACR, true)) - acrAnon := dag.Run3(g, cluster, kube, identity, setupACR(needACR, false)) - dag.Run(g, ensureDebugDaemonsets(isNetworkIsolated, cluster, kube), kube, acrNonAnon, acrAnon) + acrNonAnon := dag.Run3(g, cluster, kube, identity, addACRTask(needACR, true)) + acrAnon := dag.Run3(g, cluster, kube, identity, addACRTask(needACR, false)) + dag.Run2(g, cluster, kube, addDebugDaemonsets(isNetworkIsolated), kube, acrNonAnon, acrAnon) extract := dag.Go2(g, cluster, kube, extractClusterParameters) if err := g.Wait(); err != nil { @@ -104,44 +108,21 @@ func newClusterTask(model *armcontainerservice.ManagedCluster) func(context.Cont } } -func configureFirewall(isNetworkIsolated bool) func(context.Context, *armcontainerservice.ManagedCluster) error { - return func(ctx context.Context, c *armcontainerservice.ManagedCluster) error { - if isNetworkIsolated { - return nil - } - return addFirewallRules(ctx, c, *c.Location) - } -} - -func configureNetworkIsolation(isNetworkIsolated bool) func(context.Context, *armcontainerservice.ManagedCluster) error { - return func(ctx context.Context, c *armcontainerservice.ManagedCluster) error { - if !isNetworkIsolated { - return nil - } - return addNetworkIsolatedSettings(ctx, c, *c.Location) - } -} - -func setupACR(needACR, isNonAnonymousPull bool) func(context.Context, *armcontainerservice.ManagedCluster, *Kubeclient, *armcontainerservice.UserAssignedIdentity) error { +func addACRTask(needACR, isNonAnonymousPull bool) func(context.Context, *armcontainerservice.ManagedCluster, *Kubeclient, *armcontainerservice.UserAssignedIdentity) error { return func(ctx context.Context, c *armcontainerservice.ManagedCluster, k *Kubeclient, id *armcontainerservice.UserAssignedIdentity) error { if !needACR { return nil } - return addPrivateAzureContainerRegistry(ctx, c, k, config.ResourceGroupName(*c.Location), id, isNonAnonymousPull) + return addPrivateAzureContainerRegistry(ctx, c, k, id, isNonAnonymousPull) } } -func ensureDebugDaemonsets(isNetworkIsolated bool, cluster *dag.Result[*armcontainerservice.ManagedCluster], kube *dag.Result[*Kubeclient]) func(context.Context) error { - return func(ctx context.Context) error { - return kube.MustGet().EnsureDebugDaemonsets(ctx, isNetworkIsolated, config.GetPrivateACRName(true, *cluster.MustGet().Location)) +func addDebugDaemonsets(isNetworkIsolated bool) func(context.Context, *armcontainerservice.ManagedCluster, *Kubeclient) error { + return func(ctx context.Context, c *armcontainerservice.ManagedCluster, k *Kubeclient) error { + return k.EnsureDebugDaemonsets(ctx, isNetworkIsolated, config.GetPrivateACRName(true, *c.Location)) } } -func ensureMaintenanceConfiguration(ctx context.Context, cluster *armcontainerservice.ManagedCluster) error { - _, err := getOrCreateMaintenanceConfiguration(ctx, cluster) - return err -} - func getClusterKubeletIdentity(ctx context.Context, cluster *armcontainerservice.ManagedCluster) (*armcontainerservice.UserAssignedIdentity, error) { if cluster == nil || cluster.Properties == nil || cluster.Properties.IdentityProfile == nil { return nil, fmt.Errorf("cannot dereference cluster identity profile to extract kubelet identity ID") @@ -424,16 +405,18 @@ func createNewAKSClusterWithRetry(ctx context.Context, cluster *armcontainerserv return nil, fmt.Errorf("failed to create cluster after %d attempts due to persistent 409 Conflict: %w", maxRetries, lastErr) } -func getOrCreateMaintenanceConfiguration(ctx context.Context, cluster *armcontainerservice.ManagedCluster) (*armcontainerservice.MaintenanceConfiguration, error) { +func ensureMaintenanceConfiguration(ctx context.Context, cluster *armcontainerservice.ManagedCluster) error { existingMaintenance, err := config.Azure.Maintenance.Get(ctx, config.ResourceGroupName(*cluster.Location), *cluster.Name, "default", nil) var azErr *azcore.ResponseError if errors.As(err, &azErr) && azErr.StatusCode == 404 { - return createNewMaintenanceConfiguration(ctx, cluster) + _, err = createNewMaintenanceConfiguration(ctx, cluster) + return err } if err != nil { - return nil, fmt.Errorf("failed to get maintenance configuration 'default' for cluster %q: %w", *cluster.Name, err) + return fmt.Errorf("failed to get maintenance configuration 'default' for cluster %q: %w", *cluster.Name, err) } - return &existingMaintenance.MaintenanceConfiguration, nil + _ = existingMaintenance + return nil } func createNewMaintenanceConfiguration(ctx context.Context, cluster *armcontainerservice.ManagedCluster) (*armcontainerservice.MaintenanceConfiguration, error) { diff --git a/e2e/dag/dag.go b/e2e/dag/dag.go index d74dd87eb18..3d234085718 100644 --- a/e2e/dag/dag.go +++ b/e2e/dag/dag.go @@ -234,9 +234,10 @@ func Go[T any](g *Group, fn func(ctx context.Context) (T, error), deps ...Dep) * } // Go1 launches fn after dep completes, passing its value. -func Go1[T, D1 any](g *Group, dep *Result[D1], fn func(ctx context.Context, d1 D1) (T, error)) *Result[T] { +// Extra deps are waited on but their values are not passed to fn. +func Go1[T, D1 any](g *Group, dep *Result[D1], fn func(ctx context.Context, d1 D1) (T, error), extra ...Dep) *Result[T] { r := newResult[T]() - g.launch([]Dep{dep}, func() { + g.launch(append([]Dep{dep}, extra...), func() { val, err := fn(g.ctx, dep.val) if err != nil { g.recordError(err) @@ -250,9 +251,10 @@ func Go1[T, D1 any](g *Group, dep *Result[D1], fn func(ctx context.Context, d1 D } // Go2 launches fn after dep1 and dep2 complete, passing both values. -func Go2[T, D1, D2 any](g *Group, dep1 *Result[D1], dep2 *Result[D2], fn func(ctx context.Context, d1 D1, d2 D2) (T, error)) *Result[T] { +// Extra deps are waited on but their values are not passed to fn. +func Go2[T, D1, D2 any](g *Group, dep1 *Result[D1], dep2 *Result[D2], fn func(ctx context.Context, d1 D1, d2 D2) (T, error), extra ...Dep) *Result[T] { r := newResult[T]() - g.launch([]Dep{dep1, dep2}, func() { + g.launch(append([]Dep{dep1, dep2}, extra...), func() { val, err := fn(g.ctx, dep1.val, dep2.val) if err != nil { g.recordError(err) @@ -266,9 +268,10 @@ func Go2[T, D1, D2 any](g *Group, dep1 *Result[D1], dep2 *Result[D2], fn func(ct } // Go3 launches fn after dep1, dep2, and dep3 complete, passing all values. -func Go3[T, D1, D2, D3 any](g *Group, dep1 *Result[D1], dep2 *Result[D2], dep3 *Result[D3], fn func(ctx context.Context, d1 D1, d2 D2, d3 D3) (T, error)) *Result[T] { +// Extra deps are waited on but their values are not passed to fn. +func Go3[T, D1, D2, D3 any](g *Group, dep1 *Result[D1], dep2 *Result[D2], dep3 *Result[D3], fn func(ctx context.Context, d1 D1, d2 D2, d3 D3) (T, error), extra ...Dep) *Result[T] { r := newResult[T]() - g.launch([]Dep{dep1, dep2, dep3}, func() { + g.launch(append([]Dep{dep1, dep2, dep3}, extra...), func() { val, err := fn(g.ctx, dep1.val, dep2.val, dep3.val) if err != nil { g.recordError(err) @@ -304,9 +307,10 @@ func Run(g *Group, fn func(ctx context.Context) error, deps ...Dep) *Effect { } // Run1 launches fn after dep completes, passing its value. -func Run1[D1 any](g *Group, dep *Result[D1], fn func(ctx context.Context, d1 D1) error) *Effect { +// Extra deps are waited on but their values are not passed to fn. +func Run1[D1 any](g *Group, dep *Result[D1], fn func(ctx context.Context, d1 D1) error, extra ...Dep) *Effect { e := newEffect() - g.launch([]Dep{dep}, func() { + g.launch(append([]Dep{dep}, extra...), func() { err := fn(g.ctx, dep.val) if err != nil { g.recordError(err) @@ -319,9 +323,10 @@ func Run1[D1 any](g *Group, dep *Result[D1], fn func(ctx context.Context, d1 D1) } // Run2 launches fn after dep1 and dep2 complete, passing both values. -func Run2[D1, D2 any](g *Group, dep1 *Result[D1], dep2 *Result[D2], fn func(ctx context.Context, d1 D1, d2 D2) error) *Effect { +// Extra deps are waited on but their values are not passed to fn. +func Run2[D1, D2 any](g *Group, dep1 *Result[D1], dep2 *Result[D2], fn func(ctx context.Context, d1 D1, d2 D2) error, extra ...Dep) *Effect { e := newEffect() - g.launch([]Dep{dep1, dep2}, func() { + g.launch(append([]Dep{dep1, dep2}, extra...), func() { err := fn(g.ctx, dep1.val, dep2.val) if err != nil { g.recordError(err) @@ -334,9 +339,10 @@ func Run2[D1, D2 any](g *Group, dep1 *Result[D1], dep2 *Result[D2], fn func(ctx } // Run3 launches fn after dep1, dep2, and dep3 complete, passing all values. -func Run3[D1, D2, D3 any](g *Group, dep1 *Result[D1], dep2 *Result[D2], dep3 *Result[D3], fn func(ctx context.Context, d1 D1, d2 D2, d3 D3) error) *Effect { +// Extra deps are waited on but their values are not passed to fn. +func Run3[D1, D2, D3 any](g *Group, dep1 *Result[D1], dep2 *Result[D2], dep3 *Result[D3], fn func(ctx context.Context, d1 D1, d2 D2, d3 D3) error, extra ...Dep) *Effect { e := newEffect() - g.launch([]Dep{dep1, dep2, dep3}, func() { + g.launch(append([]Dep{dep1, dep2, dep3}, extra...), func() { err := fn(g.ctx, dep1.val, dep2.val, dep3.val) if err != nil { g.recordError(err) From 6efb867495a80b2348ab11239df7393e3f631b3c Mon Sep 17 00:00:00 2001 From: Artur Khantimirov Date: Sat, 21 Mar 2026 10:05:02 +1300 Subject: [PATCH 14/22] refactor(e2e): use bind helpers, eliminate newClusterTask MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Pull cluster creation out of the DAG. Use generic bind/bindRun helpers to pass functions directly for tasks that only need cluster. Remaining factory helpers (addACR, ensureDebugDaemonsets, extractClusterParams) bind cluster to functions that also receive DAG-provided values (kube, identity) — these can't be eliminated without putting cluster back in the DAG. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- e2e/cluster.go | 62 +++++++++++++++++++++++++++++++------------------- 1 file changed, 38 insertions(+), 24 deletions(-) diff --git a/e2e/cluster.go b/e2e/cluster.go index 1eb9c68c096..bfe39c06385 100644 --- a/e2e/cluster.go +++ b/e2e/cluster.go @@ -68,32 +68,36 @@ func prepareCluster(ctx context.Context, clusterModel *armcontainerservice.Manag clusterModel.Name = to.Ptr(fmt.Sprintf("%s-%s", *clusterModel.Name, hash(clusterModel))) + cluster, err := getOrCreateCluster(ctx, clusterModel) + if err != nil { + return nil, err + } + g := dag.NewGroup(ctx) - cluster := dag.Go(g, newClusterTask(clusterModel)) - bastion := dag.Go1(g, cluster, getOrCreateBastion) - dag.Run1(g, cluster, ensureMaintenanceConfiguration) - subnet := dag.Go1(g, cluster, getClusterSubnetID) - kube := dag.Go1(g, cluster, getClusterKubeClient) - identity := dag.Go1(g, cluster, getClusterKubeletIdentity) - dag.Run1(g, cluster, collectGarbageVMSS) + bastion := dag.Go(g, bind(getOrCreateBastion, cluster)) + dag.Run(g, bindRun(ensureMaintenanceConfiguration, cluster)) + subnet := dag.Go(g, bind(getClusterSubnetID, cluster)) + kube := dag.Go(g, bind(getClusterKubeClient, cluster)) + identity := dag.Go(g, bind(getClusterKubeletIdentity, cluster)) + dag.Run(g, bindRun(collectGarbageVMSS, cluster)) if !isNetworkIsolated { - dag.Run1(g, cluster, addFirewallRules) + dag.Run(g, bindRun(addFirewallRules, cluster)) } if isNetworkIsolated { - dag.Run1(g, cluster, addNetworkIsolatedSettings) + dag.Run(g, bindRun(addNetworkIsolatedSettings, cluster)) } needACR := isNetworkIsolated || attachPrivateAcr - acrNonAnon := dag.Run3(g, cluster, kube, identity, addACRTask(needACR, true)) - acrAnon := dag.Run3(g, cluster, kube, identity, addACRTask(needACR, false)) - dag.Run2(g, cluster, kube, addDebugDaemonsets(isNetworkIsolated), kube, acrNonAnon, acrAnon) - extract := dag.Go2(g, cluster, kube, extractClusterParameters) + acrNonAnon := dag.Run2(g, kube, identity, addACR(cluster, needACR, true)) + acrAnon := dag.Run2(g, kube, identity, addACR(cluster, needACR, false)) + dag.Run1(g, kube, ensureDebugDaemonsets(cluster, isNetworkIsolated), acrNonAnon, acrAnon) + extract := dag.Go1(g, kube, extractClusterParams(cluster)) if err := g.Wait(); err != nil { return nil, err } return &Cluster{ - Model: cluster.MustGet(), + Model: cluster, Kube: kube.MustGet(), KubeletIdentity: identity.MustGet(), SubnetID: subnet.MustGet(), @@ -102,24 +106,34 @@ func prepareCluster(ctx context.Context, clusterModel *armcontainerservice.Manag }, nil } -func newClusterTask(model *armcontainerservice.ManagedCluster) func(context.Context) (*armcontainerservice.ManagedCluster, error) { - return func(ctx context.Context) (*armcontainerservice.ManagedCluster, error) { - return getOrCreateCluster(ctx, model) - } +// bind returns func(ctx) → (T, error) by binding arg to fn. +func bind[A, T any](fn func(context.Context, A) (T, error), arg A) func(context.Context) (T, error) { + return func(ctx context.Context) (T, error) { return fn(ctx, arg) } } -func addACRTask(needACR, isNonAnonymousPull bool) func(context.Context, *armcontainerservice.ManagedCluster, *Kubeclient, *armcontainerservice.UserAssignedIdentity) error { - return func(ctx context.Context, c *armcontainerservice.ManagedCluster, k *Kubeclient, id *armcontainerservice.UserAssignedIdentity) error { +// bindRun returns func(ctx) → error by binding arg to fn. +func bindRun[A any](fn func(context.Context, A) error, arg A) func(context.Context) error { + return func(ctx context.Context) error { return fn(ctx, arg) } +} + +func addACR(cluster *armcontainerservice.ManagedCluster, needACR, isNonAnonymousPull bool) func(context.Context, *Kubeclient, *armcontainerservice.UserAssignedIdentity) error { + return func(ctx context.Context, k *Kubeclient, id *armcontainerservice.UserAssignedIdentity) error { if !needACR { return nil } - return addPrivateAzureContainerRegistry(ctx, c, k, id, isNonAnonymousPull) + return addPrivateAzureContainerRegistry(ctx, cluster, k, id, isNonAnonymousPull) + } +} + +func ensureDebugDaemonsets(cluster *armcontainerservice.ManagedCluster, isNetworkIsolated bool) func(context.Context, *Kubeclient) error { + return func(ctx context.Context, k *Kubeclient) error { + return k.EnsureDebugDaemonsets(ctx, isNetworkIsolated, config.GetPrivateACRName(true, *cluster.Location)) } } -func addDebugDaemonsets(isNetworkIsolated bool) func(context.Context, *armcontainerservice.ManagedCluster, *Kubeclient) error { - return func(ctx context.Context, c *armcontainerservice.ManagedCluster, k *Kubeclient) error { - return k.EnsureDebugDaemonsets(ctx, isNetworkIsolated, config.GetPrivateACRName(true, *c.Location)) +func extractClusterParams(cluster *armcontainerservice.ManagedCluster) func(context.Context, *Kubeclient) (*ClusterParams, error) { + return func(ctx context.Context, k *Kubeclient) (*ClusterParams, error) { + return extractClusterParameters(ctx, cluster, k) } } From d75c1646b5594401692b0230fbd07a195b1e60c3 Mon Sep 17 00:00:00 2001 From: Artur Khantimirov Date: Sun, 22 Mar 2026 09:24:55 +1300 Subject: [PATCH 15/22] polish: improve naming, comments, and error wrapping in dag/tasks packages --- e2e/cluster.go | 7 +++---- e2e/dag/dag.go | 4 ++-- e2e/dag/dag_test.go | 9 +++++---- e2e/tasks/execute.go | 16 +++++++++------- e2e/tasks/graph.go | 5 ++++- e2e/tasks/task.go | 8 ++++++++ 6 files changed, 31 insertions(+), 18 deletions(-) diff --git a/e2e/cluster.go b/e2e/cluster.go index bfe39c06385..5776ab59492 100644 --- a/e2e/cluster.go +++ b/e2e/cluster.go @@ -70,7 +70,7 @@ func prepareCluster(ctx context.Context, clusterModel *armcontainerservice.Manag cluster, err := getOrCreateCluster(ctx, clusterModel) if err != nil { - return nil, err + return nil, fmt.Errorf("get or create cluster: %w", err) } g := dag.NewGroup(ctx) @@ -94,7 +94,7 @@ func prepareCluster(ctx context.Context, clusterModel *armcontainerservice.Manag extract := dag.Go1(g, kube, extractClusterParams(cluster)) if err := g.Wait(); err != nil { - return nil, err + return nil, fmt.Errorf("prepare cluster tasks: %w", err) } return &Cluster{ Model: cluster, @@ -420,7 +420,7 @@ func createNewAKSClusterWithRetry(ctx context.Context, cluster *armcontainerserv } func ensureMaintenanceConfiguration(ctx context.Context, cluster *armcontainerservice.ManagedCluster) error { - existingMaintenance, err := config.Azure.Maintenance.Get(ctx, config.ResourceGroupName(*cluster.Location), *cluster.Name, "default", nil) + _, err := config.Azure.Maintenance.Get(ctx, config.ResourceGroupName(*cluster.Location), *cluster.Name, "default", nil) var azErr *azcore.ResponseError if errors.As(err, &azErr) && azErr.StatusCode == 404 { _, err = createNewMaintenanceConfiguration(ctx, cluster) @@ -429,7 +429,6 @@ func ensureMaintenanceConfiguration(ctx context.Context, cluster *armcontainerse if err != nil { return fmt.Errorf("failed to get maintenance configuration 'default' for cluster %q: %w", *cluster.Name, err) } - _ = existingMaintenance return nil } diff --git a/e2e/dag/dag.go b/e2e/dag/dag.go index 3d234085718..3e538725292 100644 --- a/e2e/dag/dag.go +++ b/e2e/dag/dag.go @@ -154,7 +154,7 @@ func newResult[T any]() *Result[T] { return &Result[T]{done: make(chan struct{})} } -func (r *Result[T]) wait() { <-r.done } +func (r *Result[T]) wait() { <-r.done } func (r *Result[T]) failed() bool { r.wait(); return r.err != nil } // Get returns the value and true if the task succeeded, or the zero value @@ -201,7 +201,7 @@ func newEffect() *Effect { return &Effect{done: make(chan struct{})} } -func (e *Effect) wait() { <-e.done } +func (e *Effect) wait() { <-e.done } func (e *Effect) failed() bool { e.wait(); return e.err != nil } func (e *Effect) finish(err error) { diff --git a/e2e/dag/dag_test.go b/e2e/dag/dag_test.go index 5f3e3e83713..1a7cc85ad99 100644 --- a/e2e/dag/dag_test.go +++ b/e2e/dag/dag_test.go @@ -5,7 +5,6 @@ import ( "errors" "sync/atomic" "testing" - "time" ) // --------------------------------------------------------------------------- @@ -236,8 +235,9 @@ func TestCancelAll_CancelsRunningTasks(t *testing.T) { return 0, errors.New("fail") }) + // g.Wait() guarantees all goroutines have returned (via WaitGroup), + // so cancelled is guaranteed to be true here — no sleep needed. g.Wait() - time.Sleep(10 * time.Millisecond) if !cancelled.Load() { t.Fatal("expected context to be cancelled for running task") } @@ -379,9 +379,10 @@ func TestParentContextCancelled(t *testing.T) { func TestEffect_AsDep(t *testing.T) { g := NewGroup(context.Background()) + // order is shared between goroutines, but the dependency edge (e) provides + // a happens-before guarantee: close(e.done) in the first goroutine + // happens-before the second goroutine's read via e.wait(). var order []int - var mu atomic.Value - mu.Store([]int{}) e := Run(g, func(ctx context.Context) error { order = append(order, 1) diff --git a/e2e/tasks/execute.go b/e2e/tasks/execute.go index cf0c4f66fa4..49770d4642c 100644 --- a/e2e/tasks/execute.go +++ b/e2e/tasks/execute.go @@ -84,9 +84,9 @@ func runGraph(ctx context.Context, cfg Config, g *graph) error { return nil } -// abort records a non-success result for a task and notifies dependents. -// Must be called with s.mu held; unlocks s.mu before returning. -func (s *runState) abort(task Task, status TaskStatus, err error) { +// skipTask records a non-success result for a task, releases the lock, and +// notifies dependents. Caller must hold s.mu on entry; it is released before return. +func (s *runState) skipTask(task Task, status TaskStatus, err error) { s.results[task] = TaskResult{Status: status, Err: err} s.mu.Unlock() s.notifyDependents(task) @@ -106,7 +106,7 @@ func (s *runState) launch(task Task) { defer func() { <-s.sem }() case <-s.ctx.Done(): s.mu.Lock() - s.abort(task, Canceled, s.ctx.Err()) + s.skipTask(task, Canceled, s.ctx.Err()) return } } @@ -115,17 +115,19 @@ func (s *runState) launch(task Task) { s.mu.Lock() for _, dep := range s.g.deps[task] { if s.results[dep].Status != Succeeded { + // Use Canceled when we're in CancelAll mode and the dep actually + // failed (not just skipped). Otherwise use Skipped. status := Skipped - if s.cfg.OnError == CancelAll && s.ctx.Err() != nil { + if s.cfg.OnError == CancelAll && s.results[dep].Status == Failed { status = Canceled } - s.abort(task, status, nil) + s.skipTask(task, status, nil) return } } if s.ctx.Err() != nil { - s.abort(task, Canceled, s.ctx.Err()) + s.skipTask(task, Canceled, s.ctx.Err()) return } s.mu.Unlock() diff --git a/e2e/tasks/graph.go b/e2e/tasks/graph.go index ace8c533a0a..d5a701661c3 100644 --- a/e2e/tasks/graph.go +++ b/e2e/tasks/graph.go @@ -13,7 +13,10 @@ type graph struct { deps map[Task][]Task // dependents maps each task to the tasks that depend on it. dependents map[Task][]Task - // order is the topological sort order, populated by validateNoCycles. + // order is a valid topological sort, populated by validateNoCycles. + // The scheduler does not use it — execution order is driven by the + // dep-counting algorithm in runGraph. This field exists for test + // introspection (asserting correct ordering). order []Task } diff --git a/e2e/tasks/task.go b/e2e/tasks/task.go index adc6b4c1405..aef842f335b 100644 --- a/e2e/tasks/task.go +++ b/e2e/tasks/task.go @@ -1,3 +1,11 @@ +// Package tasks provides a reflection-based DAG executor where tasks are +// structs implementing the [Task] interface and dependencies are discovered +// automatically from a Deps struct field. +// +// For most e2e use cases, prefer the sibling package [dag] which uses +// closures and generics for a lighter-weight, no-reflection API. +// This package is useful when tasks are reusable objects with complex +// dependency wiring that benefits from struct-based composition. package tasks import ( From 5fea3e461818449a8ee8d610b94cd114af9aab96 Mon Sep 17 00:00:00 2001 From: Artur Khantimirov Date: Sun, 22 Mar 2026 09:27:47 +1300 Subject: [PATCH 16/22] refactor(e2e): remove unused tasks package in favor of dag package --- e2e/tasks/execute.go | 169 ------------------- e2e/tasks/execute_test.go | 307 ---------------------------------- e2e/tasks/graph.go | 118 ------------- e2e/tasks/graph_test.go | 152 ----------------- e2e/tasks/integration_test.go | 266 ----------------------------- e2e/tasks/task.go | 117 ------------- e2e/tasks/task_test.go | 63 ------- e2e/tasks/validate.go | 48 ------ e2e/tasks/validate_test.go | 74 -------- 9 files changed, 1314 deletions(-) delete mode 100644 e2e/tasks/execute.go delete mode 100644 e2e/tasks/execute_test.go delete mode 100644 e2e/tasks/graph.go delete mode 100644 e2e/tasks/graph_test.go delete mode 100644 e2e/tasks/integration_test.go delete mode 100644 e2e/tasks/task.go delete mode 100644 e2e/tasks/task_test.go delete mode 100644 e2e/tasks/validate.go delete mode 100644 e2e/tasks/validate_test.go diff --git a/e2e/tasks/execute.go b/e2e/tasks/execute.go deleted file mode 100644 index 49770d4642c..00000000000 --- a/e2e/tasks/execute.go +++ /dev/null @@ -1,169 +0,0 @@ -package tasks - -import ( - "context" - "sync" -) - -// Execute runs the DAG rooted at the given tasks. -// It discovers the graph via reflection, validates it, and executes -// tasks concurrently respecting dependency order. -func Execute(ctx context.Context, cfg Config, roots ...Task) error { - g, err := discoverGraph(roots) - if err != nil { - return err - } - if err := validateNoCycles(g); err != nil { - return err - } - return runGraph(ctx, cfg, g) -} - -// runState holds all shared mutable state for a single DAG execution. -type runState struct { - g *graph - cfg Config - ctx context.Context - cancel context.CancelFunc - sem chan struct{} - mu sync.Mutex - wg sync.WaitGroup - results map[Task]TaskResult - - // remaining tracks how many deps each task is still waiting on. - remaining map[Task]int -} - -func runGraph(ctx context.Context, cfg Config, g *graph) error { - if len(g.nodes) == 0 { - return nil - } - - ctx, cancel := context.WithCancel(ctx) - defer cancel() - - s := &runState{ - g: g, - cfg: cfg, - ctx: ctx, - cancel: cancel, - results: make(map[Task]TaskResult, len(g.nodes)), - remaining: make(map[Task]int, len(g.nodes)), - } - - if cfg.MaxConcurrency > 0 { - s.sem = make(chan struct{}, cfg.MaxConcurrency) - } - - for _, node := range g.nodes { - s.remaining[node] = len(g.deps[node]) - } - - // Start all leaf tasks (no dependencies) - for _, node := range g.nodes { - if len(g.deps[node]) == 0 { - s.launch(node) - } - } - - s.wg.Wait() - - // Mark any tasks that were never reached. - // Safe without mutex: all goroutines have completed after wg.Wait(). - for _, node := range g.nodes { - if _, ok := s.results[node]; !ok { - s.results[node] = TaskResult{Status: Canceled, Err: ctx.Err()} - } - } - - for _, result := range s.results { - if result.Status != Succeeded { - return &DAGError{Results: s.results} - } - } - return nil -} - -// skipTask records a non-success result for a task, releases the lock, and -// notifies dependents. Caller must hold s.mu on entry; it is released before return. -func (s *runState) skipTask(task Task, status TaskStatus, err error) { - s.results[task] = TaskResult{Status: status, Err: err} - s.mu.Unlock() - s.notifyDependents(task) -} - -func (s *runState) launch(task Task) { - // wg.Add must be called in the caller's goroutine to ensure - // wg.Wait cannot return before the new goroutine starts. - s.wg.Add(1) - go func() { - defer s.wg.Done() - - // Acquire semaphore slot - if s.sem != nil { - select { - case s.sem <- struct{}{}: - defer func() { <-s.sem }() - case <-s.ctx.Done(): - s.mu.Lock() - s.skipTask(task, Canceled, s.ctx.Err()) - return - } - } - - // Check if we should skip (dependency failed) or cancel - s.mu.Lock() - for _, dep := range s.g.deps[task] { - if s.results[dep].Status != Succeeded { - // Use Canceled when we're in CancelAll mode and the dep actually - // failed (not just skipped). Otherwise use Skipped. - status := Skipped - if s.cfg.OnError == CancelAll && s.results[dep].Status == Failed { - status = Canceled - } - s.skipTask(task, status, nil) - return - } - } - - if s.ctx.Err() != nil { - s.skipTask(task, Canceled, s.ctx.Err()) - return - } - s.mu.Unlock() - - // Run the task - taskErr := task.Do(s.ctx) - - s.mu.Lock() - if taskErr != nil { - s.results[task] = TaskResult{Status: Failed, Err: taskErr} - if s.cfg.OnError == CancelAll { - s.cancel() - } - } else { - s.results[task] = TaskResult{Status: Succeeded} - } - s.mu.Unlock() - - s.notifyDependents(task) - }() -} - -// notifyDependents decrements remaining counts for dependents and launches -// any that become ready. Launches happen outside the lock. -func (s *runState) notifyDependents(task Task) { - s.mu.Lock() - var ready []Task - for _, dependent := range s.g.dependents[task] { - s.remaining[dependent]-- - if s.remaining[dependent] == 0 { - ready = append(ready, dependent) - } - } - s.mu.Unlock() - - for _, t := range ready { - s.launch(t) - } -} diff --git a/e2e/tasks/execute_test.go b/e2e/tasks/execute_test.go deleted file mode 100644 index de54467d0f8..00000000000 --- a/e2e/tasks/execute_test.go +++ /dev/null @@ -1,307 +0,0 @@ -package tasks - -import ( - "context" - "errors" - "fmt" - "sync/atomic" - "testing" - "time" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -// --- test task types for execution --- - -type valueTask struct { - Value int - Output int -} - -func (t *valueTask) Do(ctx context.Context) error { - t.Output = t.Value - return nil -} - -type addTask struct { - Deps struct { - A *valueTask - B *valueTask - } - Output int -} - -func (t *addTask) Do(ctx context.Context) error { - t.Output = t.Deps.A.Output + t.Deps.B.Output - return nil -} - -type failTask struct{} - -func (t *failTask) Do(ctx context.Context) error { - return fmt.Errorf("intentional failure") -} - -type afterFailTask struct { - Deps struct{ F *failTask } - ran bool -} - -func (t *afterFailTask) Do(ctx context.Context) error { - t.ran = true - return nil -} - -// --- basic execution tests --- - -func TestExecute_LeafTask(t *testing.T) { - v := &valueTask{Value: 42} - err := Execute(context.Background(), Config{}, v) - require.NoError(t, err) - assert.Equal(t, 42, v.Output) -} - -func TestExecute_OutputFlowsBetweenDeps(t *testing.T) { - a := &valueTask{Value: 3} - b := &valueTask{Value: 5} - add := &addTask{} - add.Deps.A = a - add.Deps.B = b - - err := Execute(context.Background(), Config{}, add) - require.NoError(t, err) - assert.Equal(t, 8, add.Output) -} - -func TestExecute_FailReturnsDAGError(t *testing.T) { - f := &failTask{} - err := Execute(context.Background(), Config{}, f) - require.Error(t, err) - - var dagErr *DAGError - require.True(t, errors.As(err, &dagErr)) - - result, ok := dagErr.Results[f] - require.True(t, ok) - assert.Equal(t, Failed, result.Status) - assert.Contains(t, result.Err.Error(), "intentional failure") -} - -func TestExecute_NilDep_ReturnsValidationError(t *testing.T) { - add := &addTask{} // Deps.A and Deps.B are nil - err := Execute(context.Background(), Config{}, add) - require.Error(t, err) - - var ve *ValidationError - require.True(t, errors.As(err, &ve)) -} - -// --- error strategy tests --- - -func TestExecute_CancelDependents_SkipsDownstream(t *testing.T) { - f := &failTask{} - after := &afterFailTask{} - after.Deps.F = f - - err := Execute(context.Background(), Config{OnError: CancelDependents}, after) - require.Error(t, err) - - var dagErr *DAGError - require.True(t, errors.As(err, &dagErr)) - assert.Equal(t, Failed, dagErr.Results[f].Status) - assert.Equal(t, Skipped, dagErr.Results[after].Status) - assert.False(t, after.ran, "skipped task should not have run") -} - -func TestExecute_CancelDependents_IndependentBranchContinues(t *testing.T) { - // fail and independent are both leaves; run as separate roots. - // CancelDependents should not affect independent branches. - f := &failTask{} - independent := &valueTask{Value: 99} - - err := Execute(context.Background(), Config{OnError: CancelDependents}, f, independent) - require.Error(t, err) - // independent should have run successfully - assert.Equal(t, 99, independent.Output) -} - -func TestExecute_CancelAll_CancelsContext(t *testing.T) { - f := &failTask{} - after := &afterFailTask{} - after.Deps.F = f - - err := Execute(context.Background(), Config{OnError: CancelAll}, after) - require.Error(t, err) - - var dagErr *DAGError - require.True(t, errors.As(err, &dagErr)) - assert.Equal(t, Failed, dagErr.Results[f].Status) - assert.Equal(t, Canceled, dagErr.Results[after].Status) -} - -// --- concurrency tests --- - -func TestExecute_MaxConcurrency_Serial(t *testing.T) { - a := &valueTask{Value: 3} - b := &valueTask{Value: 5} - add := &addTask{} - add.Deps.A = a - add.Deps.B = b - - err := Execute(context.Background(), Config{MaxConcurrency: 1}, add) - require.NoError(t, err) - assert.Equal(t, 8, add.Output) -} - -func TestExecute_MaxConcurrency_Respected(t *testing.T) { - // Verify unlimited concurrency (MaxConcurrency=0) produces correct results. - // Actual parallelism verification is in TestExecute_MaxConcurrency_LimitsParallelism. - a := &valueTask{Value: 1} - b := &valueTask{Value: 2} - add := &addTask{} - add.Deps.A = a - add.Deps.B = b - - err := Execute(context.Background(), Config{MaxConcurrency: 0}, add) - require.NoError(t, err) - assert.Equal(t, 3, add.Output) -} - -// --- diamond and dedup tests --- - -func TestExecute_Diamond(t *testing.T) { - top := &diamondTop{} - left := &diamondLeft{} - left.Deps.Top = top - right := &diamondRight{} - right.Deps.Top = top - bottom := &diamondBottom{} - bottom.Deps.Left = left - bottom.Deps.Right = right - - err := Execute(context.Background(), Config{}, bottom) - require.NoError(t, err) -} - -func TestExecute_MultipleRoots(t *testing.T) { - a := &chainB{} - a.Deps.A = &chainA{} - b := &chainB{} - b.Deps.A = &chainA{} - - err := Execute(context.Background(), Config{}, a, b) - require.NoError(t, err) -} - -func TestExecute_MultipleRoots_SharedTask(t *testing.T) { - // Two roots share the same leaf — it should run only once - shared := &valueTask{Value: 7} - - add1 := &addTask{} - add1.Deps.A = shared - add1.Deps.B = &valueTask{Value: 3} - - add2 := &addTask{} - add2.Deps.A = shared - add2.Deps.B = &valueTask{Value: 5} - - err := Execute(context.Background(), Config{}, add1, add2) - require.NoError(t, err) - assert.Equal(t, 10, add1.Output) - assert.Equal(t, 12, add2.Output) - assert.Equal(t, 7, shared.Output) -} - -// --- context cancellation --- - -func TestExecute_PreCanceledContext(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) - cancel() - - v := &valueTask{Value: 1} - // Should not hang — either succeeds or returns error - done := make(chan error, 1) - go func() { - done <- Execute(ctx, Config{}, v) - }() - - select { - case <-done: - // good — didn't hang - case <-time.After(100 * time.Millisecond): - t.Fatal("Execute hung on pre-canceled context") - } -} - -// concurrencyTracker is a package-level task that tracks max concurrency. -type concurrencyTracker struct { - current *atomic.Int32 - peak *atomic.Int32 - Output int -} - -func (t *concurrencyTracker) Do(ctx context.Context) error { - cur := t.current.Add(1) - // Update peak - for { - p := t.peak.Load() - if cur <= p || t.peak.CompareAndSwap(p, cur) { - break - } - } - time.Sleep(10 * time.Millisecond) - t.current.Add(-1) - t.Output = 1 - return nil -} - -type concurrencyRoot struct { - Deps struct { - A *concurrencyTracker - B *concurrencyTracker - C *concurrencyTracker - D *concurrencyTracker - } -} - -func (t *concurrencyRoot) Do(ctx context.Context) error { return nil } - -func TestExecute_ConcurrentIndependentTasks(t *testing.T) { - var current, peak atomic.Int32 - - a := &concurrencyTracker{current: ¤t, peak: &peak} - b := &concurrencyTracker{current: ¤t, peak: &peak} - c := &concurrencyTracker{current: ¤t, peak: &peak} - d := &concurrencyTracker{current: ¤t, peak: &peak} - - root := &concurrencyRoot{} - root.Deps.A = a - root.Deps.B = b - root.Deps.C = c - root.Deps.D = d - - err := Execute(context.Background(), Config{}, root) - require.NoError(t, err) - // With unlimited concurrency, all 4 should run in parallel - assert.Greater(t, peak.Load(), int32(1), "independent tasks should run concurrently") -} - -func TestExecute_MaxConcurrency_LimitsParallelism(t *testing.T) { - var current, peak atomic.Int32 - - a := &concurrencyTracker{current: ¤t, peak: &peak} - b := &concurrencyTracker{current: ¤t, peak: &peak} - c := &concurrencyTracker{current: ¤t, peak: &peak} - d := &concurrencyTracker{current: ¤t, peak: &peak} - - root := &concurrencyRoot{} - root.Deps.A = a - root.Deps.B = b - root.Deps.C = c - root.Deps.D = d - - err := Execute(context.Background(), Config{MaxConcurrency: 2}, root) - require.NoError(t, err) - assert.LessOrEqual(t, peak.Load(), int32(2), "max concurrency should be respected") -} diff --git a/e2e/tasks/graph.go b/e2e/tasks/graph.go deleted file mode 100644 index d5a701661c3..00000000000 --- a/e2e/tasks/graph.go +++ /dev/null @@ -1,118 +0,0 @@ -package tasks - -import ( - "fmt" - "reflect" -) - -// graph represents the discovered DAG. -type graph struct { - // nodes is all tasks in the graph, deduplicated by pointer identity. - nodes []Task - // deps maps each task to its direct dependencies. - deps map[Task][]Task - // dependents maps each task to the tasks that depend on it. - dependents map[Task][]Task - // order is a valid topological sort, populated by validateNoCycles. - // The scheduler does not use it — execution order is driven by the - // dep-counting algorithm in runGraph. This field exists for test - // introspection (asserting correct ordering). - order []Task -} - -// discoverGraph walks the Deps fields of the given root tasks recursively -// to build the full DAG. Tasks are deduplicated by pointer identity. -func discoverGraph(roots []Task) (*graph, error) { - g := &graph{ - deps: make(map[Task][]Task), - dependents: make(map[Task][]Task), - } - visited := make(map[Task]bool) - for _, root := range roots { - if err := g.walk(root, visited); err != nil { - return nil, err - } - } - return g, nil -} - -func (g *graph) walk(task Task, visited map[Task]bool) error { - if visited[task] { - return nil - } - visited[task] = true - g.nodes = append(g.nodes, task) - - deps, err := extractDeps(task) - if err != nil { - return err - } - - g.deps[task] = deps - for _, dep := range deps { - g.dependents[dep] = append(g.dependents[dep], task) - if err := g.walk(dep, visited); err != nil { - return err - } - } - return nil -} - -var taskType = reflect.TypeOf((*Task)(nil)).Elem() - -// extractDeps reads the Deps field of a task via reflection and returns -// all dependency tasks found as pointer fields. -func extractDeps(task Task) ([]Task, error) { - v := reflect.ValueOf(task) - if v.Kind() != reflect.Ptr { - return nil, &ValidationError{Task: task, Message: "task must be a pointer"} - } - v = v.Elem() - if v.Kind() != reflect.Struct { - return nil, &ValidationError{Task: task, Message: "task must be a pointer to a struct"} - } - - depsField := v.FieldByName("Deps") - if !depsField.IsValid() { - return nil, nil - } - - if depsField.Kind() != reflect.Struct { - return nil, &ValidationError{ - Task: task, - Message: "Deps field must be a struct", - } - } - - depsType := depsField.Type() - var deps []Task - for i := range depsField.NumField() { - field := depsField.Field(i) - fieldInfo := depsType.Field(i) - - if field.Kind() != reflect.Ptr { - return nil, &ValidationError{ - Task: task, - Message: fmt.Sprintf("Deps.%s must be a pointer, got %s", fieldInfo.Name, field.Type()), - } - } - - if field.IsNil() { - return nil, &ValidationError{ - Task: task, - Message: fmt.Sprintf("Deps.%s is nil", fieldInfo.Name), - } - } - - if !field.Type().Implements(taskType) { - return nil, &ValidationError{ - Task: task, - Message: fmt.Sprintf("Deps.%s: %s does not implement Task", fieldInfo.Name, field.Type()), - } - } - - dep := field.Interface().(Task) - deps = append(deps, dep) - } - return deps, nil -} diff --git a/e2e/tasks/graph_test.go b/e2e/tasks/graph_test.go deleted file mode 100644 index ea08afdf24a..00000000000 --- a/e2e/tasks/graph_test.go +++ /dev/null @@ -1,152 +0,0 @@ -package tasks - -import ( - "context" - "errors" - "testing" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -// --- test task types for graph discovery --- - -type leafTask struct { - Output string -} - -func (t *leafTask) Do(ctx context.Context) error { return nil } - -type chainA struct{ Output string } -type chainB struct { - Deps struct{ A *chainA } -} -type chainC struct { - Deps struct{ B *chainB } -} - -func (t *chainA) Do(ctx context.Context) error { return nil } -func (t *chainB) Do(ctx context.Context) error { return nil } -func (t *chainC) Do(ctx context.Context) error { return nil } - -type diamondTop struct{ Output string } -type diamondLeft struct { - Deps struct{ Top *diamondTop } -} -type diamondRight struct { - Deps struct{ Top *diamondTop } -} -type diamondBottom struct { - Deps struct { - Left *diamondLeft - Right *diamondRight - } -} - -func (t *diamondTop) Do(ctx context.Context) error { return nil } -func (t *diamondLeft) Do(ctx context.Context) error { return nil } -func (t *diamondRight) Do(ctx context.Context) error { return nil } -func (t *diamondBottom) Do(ctx context.Context) error { return nil } - -type badDepsNotStruct struct { - Deps int -} - -func (t *badDepsNotStruct) Do(ctx context.Context) error { return nil } - -type badDepsNonPointer struct { - Deps struct { - A chainA - } -} - -func (t *badDepsNonPointer) Do(ctx context.Context) error { return nil } - -type badDepsNonTask struct { - Deps struct { - S *string - } -} - -func (t *badDepsNonTask) Do(ctx context.Context) error { return nil } - -func TestDiscoverGraph_Leaf(t *testing.T) { - task := &leafTask{} - g, err := discoverGraph([]Task{task}) - require.NoError(t, err) - assert.Len(t, g.nodes, 1) - assert.Empty(t, g.deps[task]) -} - -func TestDiscoverGraph_Chain(t *testing.T) { - a := &chainA{} - b := &chainB{} - b.Deps.A = a - c := &chainC{} - c.Deps.B = b - - g, err := discoverGraph([]Task{c}) - require.NoError(t, err) - assert.Len(t, g.nodes, 3) - - assert.Equal(t, []Task{Task(b)}, g.deps[c]) - assert.Equal(t, []Task{Task(a)}, g.deps[b]) - assert.Empty(t, g.deps[a]) -} - -func TestDiscoverGraph_Diamond(t *testing.T) { - top := &diamondTop{} - left := &diamondLeft{} - left.Deps.Top = top - right := &diamondRight{} - right.Deps.Top = top - bottom := &diamondBottom{} - bottom.Deps.Left = left - bottom.Deps.Right = right - - g, err := discoverGraph([]Task{bottom}) - require.NoError(t, err) - assert.Len(t, g.nodes, 4, "top should be deduplicated") -} - -func TestDiscoverGraph_NilDep(t *testing.T) { - b := &chainB{} // Deps.A is nil - _, err := discoverGraph([]Task{b}) - require.Error(t, err) - - var ve *ValidationError - require.True(t, errors.As(err, &ve)) - assert.Contains(t, ve.Message, "nil") -} - -func TestDiscoverGraph_DepsNotStruct(t *testing.T) { - task := &badDepsNotStruct{Deps: 42} - _, err := discoverGraph([]Task{task}) - require.Error(t, err) - - var ve *ValidationError - require.True(t, errors.As(err, &ve)) - assert.Contains(t, ve.Message, "struct") -} - -func TestDiscoverGraph_NonPointerInDeps(t *testing.T) { - task := &badDepsNonPointer{} - _, err := discoverGraph([]Task{task}) - require.Error(t, err) - - var ve *ValidationError - require.True(t, errors.As(err, &ve)) - assert.Contains(t, ve.Message, "pointer") -} - -func TestDiscoverGraph_NonTaskPointerInDeps(t *testing.T) { - s := "hello" - task := &badDepsNonTask{} - task.Deps.S = &s - _, err := discoverGraph([]Task{task}) - require.Error(t, err) - - var ve *ValidationError - require.True(t, errors.As(err, &ve)) - assert.Contains(t, ve.Message, "Task") -} diff --git a/e2e/tasks/integration_test.go b/e2e/tasks/integration_test.go deleted file mode 100644 index f30c1b4c879..00000000000 --- a/e2e/tasks/integration_test.go +++ /dev/null @@ -1,266 +0,0 @@ -package tasks - -import ( - "context" - "errors" - "fmt" - "testing" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -// --- Spec example task definitions --- -// Mirrors the complete example from the design spec: -// CreateRG → CreateVNet → CreateSubnet → CreateCluster → RunTests → Teardown - -type createRGOutput struct { - RGName string -} - -type createRG struct { - Output createRGOutput -} - -func (t *createRG) Do(ctx context.Context) error { - t.Output.RGName = "my-rg" - return nil -} - -type createVNet struct { - Deps struct { - RG *createRG - } - Output struct { - VNetID string - } -} - -func (t *createVNet) Do(ctx context.Context) error { - t.Output.VNetID = fmt.Sprintf("%s-vnet", t.Deps.RG.Output.RGName) - return nil -} - -type createSubnet struct { - Deps struct { - VNet *createVNet - } - Output struct { - SubnetID string - } -} - -func (t *createSubnet) Do(ctx context.Context) error { - t.Output.SubnetID = fmt.Sprintf("%s-subnet", t.Deps.VNet.Output.VNetID) - return nil -} - -type createCluster struct { - Deps struct { - RG *createRG - Subnet *createSubnet - } - Output struct { - ClusterID string - } -} - -func (t *createCluster) Do(ctx context.Context) error { - t.Output.ClusterID = fmt.Sprintf("cluster-in-%s-%s", - t.Deps.RG.Output.RGName, - t.Deps.Subnet.Output.SubnetID) - return nil -} - -type runTests struct { - Deps struct { - Cluster *createCluster - } - Output struct { - Passed bool - } -} - -func (t *runTests) Do(ctx context.Context) error { - t.Output.Passed = true - return nil -} - -type teardown struct { - Deps struct { - RG *createRG - Tests *runTests - } - Output struct { - TornDown bool - } -} - -func (t *teardown) Do(ctx context.Context) error { - t.Output.TornDown = true - return nil -} - -// --- Integration tests --- - -// specDAG holds all wired nodes from the spec example for reuse across tests. -type specDAG struct { - RG *createRG - VNet *createVNet - Subnet *createSubnet - Cluster *createCluster - Tests *runTests - TD *teardown -} - -// buildSpecDAG wires the full spec example DAG: -// -// CreateRG ──┬── CreateVNet ── CreateSubnet ──┐ -// │ │ -// ├──────────────── CreateCluster ──┘ -// │ │ -// │ RunTests -// │ │ -// └──────────────── Teardown -func buildSpecDAG() specDAG { - rg := &createRG{} - vnet := &createVNet{} - vnet.Deps.RG = rg - subnet := &createSubnet{} - subnet.Deps.VNet = vnet - cluster := &createCluster{} - cluster.Deps.RG = rg - cluster.Deps.Subnet = subnet - tests := &runTests{} - tests.Deps.Cluster = cluster - td := &teardown{} - td.Deps.RG = rg - td.Deps.Tests = tests - return specDAG{RG: rg, VNet: vnet, Subnet: subnet, Cluster: cluster, Tests: tests, TD: td} -} - -func TestIntegration_SpecExample(t *testing.T) { - d := buildSpecDAG() - - err := Execute(context.Background(), Config{}, d.TD) - require.NoError(t, err) - - // Verify all outputs propagated correctly - assert.Equal(t, "my-rg", d.RG.Output.RGName) - assert.Equal(t, "my-rg-vnet", d.VNet.Output.VNetID) - assert.Equal(t, "my-rg-vnet-subnet", d.Subnet.Output.SubnetID) - assert.Equal(t, "cluster-in-my-rg-my-rg-vnet-subnet", d.Cluster.Output.ClusterID) - assert.True(t, d.Tests.Output.Passed) - assert.True(t, d.TD.Output.TornDown) -} - -func TestIntegration_SpecExample_WithMaxConcurrency(t *testing.T) { - d := buildSpecDAG() - - err := Execute(context.Background(), Config{MaxConcurrency: 1}, d.TD) - require.NoError(t, err) - - assert.Equal(t, "cluster-in-my-rg-my-rg-vnet-subnet", d.Cluster.Output.ClusterID) - assert.True(t, d.TD.Output.TornDown) -} - -func TestIntegration_TransitiveDependencyAccess(t *testing.T) { - // Verify that a task can read transitive dependencies through Deps chains - // as described in the spec's "Accessing Transitive Dependencies" section. - d := buildSpecDAG() - - err := Execute(context.Background(), Config{}, d.Cluster) - require.NoError(t, err) - - // Access transitive dep: cluster -> subnet -> vnet -> rg - rgName := d.Cluster.Deps.Subnet.Deps.VNet.Deps.RG.Output.RGName - assert.Equal(t, "my-rg", rgName) -} - -// failingRunTests simulates a test failure mid-pipeline -type failingRunTests struct { - Deps struct { - Cluster *createCluster - } -} - -func (t *failingRunTests) Do(ctx context.Context) error { - return fmt.Errorf("tests failed: 2 of 10 scenarios failed") -} - -type teardownAfterFail struct { - Deps struct { - RG *createRG - Tests *failingRunTests - } - Output struct{ TornDown bool } -} - -func (t *teardownAfterFail) Do(ctx context.Context) error { - t.Output.TornDown = true - return nil -} - -func TestIntegration_MidPipelineFailure_CancelDependents(t *testing.T) { - d := buildSpecDAG() - failTests := &failingRunTests{} - failTests.Deps.Cluster = d.Cluster - - td := &teardownAfterFail{} - td.Deps.RG = d.RG - td.Deps.Tests = failTests - - err := Execute(context.Background(), Config{OnError: CancelDependents}, td) - require.Error(t, err) - - var dagErr *DAGError - require.True(t, errors.As(err, &dagErr)) - - // Upstream tasks should have succeeded - assert.Equal(t, Succeeded, dagErr.Results[d.RG].Status) - assert.Equal(t, Succeeded, dagErr.Results[d.VNet].Status) - assert.Equal(t, Succeeded, dagErr.Results[d.Subnet].Status) - assert.Equal(t, Succeeded, dagErr.Results[d.Cluster].Status) - - // failTests should have failed - assert.Equal(t, Failed, dagErr.Results[failTests].Status) - assert.Contains(t, dagErr.Results[failTests].Err.Error(), "tests failed") - - // teardown should be skipped since it depends on failTests - assert.Equal(t, Skipped, dagErr.Results[td].Status) - - // Outputs of successful tasks should still be populated - assert.Equal(t, "my-rg", d.RG.Output.RGName) - assert.Equal(t, "cluster-in-my-rg-my-rg-vnet-subnet", d.Cluster.Output.ClusterID) -} - -func TestIntegration_TwoIndependentSubgraphs_SharedTask(t *testing.T) { - // Two independent pipelines share CreateRG. - // Both should complete, CreateRG should execute only once. - rg := &createRG{} - - vnet1 := &createVNet{} - vnet1.Deps.RG = rg - vnet2 := &createVNet{} - vnet2.Deps.RG = rg - - subnet1 := &createSubnet{} - subnet1.Deps.VNet = vnet1 - subnet2 := &createSubnet{} - subnet2.Deps.VNet = vnet2 - - // Both subnets are roots; they share rg - err := Execute(context.Background(), Config{}, subnet1, subnet2) - require.NoError(t, err) - - assert.Equal(t, "my-rg", rg.Output.RGName) - assert.Equal(t, "my-rg-vnet", vnet1.Output.VNetID) - assert.Equal(t, "my-rg-vnet", vnet2.Output.VNetID) - assert.Equal(t, "my-rg-vnet-subnet", subnet1.Output.SubnetID) - assert.Equal(t, "my-rg-vnet-subnet", subnet2.Output.SubnetID) -} - -func TestIntegration_EmptyGraph(t *testing.T) { - err := Execute(context.Background(), Config{}) - require.NoError(t, err) -} diff --git a/e2e/tasks/task.go b/e2e/tasks/task.go deleted file mode 100644 index aef842f335b..00000000000 --- a/e2e/tasks/task.go +++ /dev/null @@ -1,117 +0,0 @@ -// Package tasks provides a reflection-based DAG executor where tasks are -// structs implementing the [Task] interface and dependencies are discovered -// automatically from a Deps struct field. -// -// For most e2e use cases, prefer the sibling package [dag] which uses -// closures and generics for a lighter-weight, no-reflection API. -// This package is useful when tasks are reusable objects with complex -// dependency wiring that benefits from struct-based composition. -package tasks - -import ( - "context" - "fmt" - "sort" - "strings" -) - -// Task is the interface that all tasks must implement. -// Implement Do with a pointer receiver. -type Task interface { - Do(ctx context.Context) error -} - -// ErrorStrategy controls behavior when a task fails. -type ErrorStrategy int - -const ( - // CancelDependents skips tasks that transitively depend on the failed task. - // Independent branches continue running. - CancelDependents ErrorStrategy = iota - - // CancelAll cancels the context for all running and pending tasks. - CancelAll -) - -func (s ErrorStrategy) String() string { - switch s { - case CancelDependents: - return "CancelDependents" - case CancelAll: - return "CancelAll" - default: - return fmt.Sprintf("ErrorStrategy(%d)", int(s)) - } -} - -// Config controls execution behavior. -type Config struct { - // OnError controls what happens when a task fails. - // Default (zero value): CancelDependents. - OnError ErrorStrategy - - // MaxConcurrency limits how many tasks run in parallel. - // 0 (default): unlimited. 1: serial execution. - // Negative values are treated as 0 (unlimited). - MaxConcurrency int -} - -// TaskStatus represents the final status of a task after execution. -type TaskStatus int - -const ( - Succeeded TaskStatus = iota - Failed - Skipped - Canceled -) - -func (s TaskStatus) String() string { - switch s { - case Succeeded: - return "Succeeded" - case Failed: - return "Failed" - case Skipped: - return "Skipped" - case Canceled: - return "Canceled" - default: - return fmt.Sprintf("TaskStatus(%d)", int(s)) - } -} - -// TaskResult holds the outcome of a single task. -type TaskResult struct { - Status TaskStatus - Err error -} - -// DAGError is returned by Execute when one or more tasks did not succeed. -type DAGError struct { - Results map[Task]TaskResult -} - -func (e *DAGError) Error() string { - var failed []string - for task, result := range e.Results { - if result.Status != Succeeded { - failed = append(failed, fmt.Sprintf("%T: %s: %v", task, result.Status, result.Err)) - } - } - sort.Strings(failed) - return fmt.Sprintf("dag execution failed: %s", strings.Join(failed, "; ")) -} - -// ValidationError is returned when the task graph fails validation. -type ValidationError struct { - Task Task - Message string -} - -func (e *ValidationError) Error() string { - if e.Task != nil { - return fmt.Sprintf("validation error on %T: %s", e.Task, e.Message) - } - return fmt.Sprintf("validation error: %s", e.Message) -} diff --git a/e2e/tasks/task_test.go b/e2e/tasks/task_test.go deleted file mode 100644 index c362378105f..00000000000 --- a/e2e/tasks/task_test.go +++ /dev/null @@ -1,63 +0,0 @@ -package tasks - -import ( - "context" - "fmt" - "testing" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -type testTask struct { - Output string -} - -func (t *testTask) Do(ctx context.Context) error { - t.Output = "done" - return nil -} - -func TestTaskInterface(t *testing.T) { - var _ Task = (*testTask)(nil) -} - -func TestTaskStatusString(t *testing.T) { - tests := []struct { - status TaskStatus - want string - }{ - {Succeeded, "Succeeded"}, - {Failed, "Failed"}, - {Skipped, "Skipped"}, - {Canceled, "Canceled"}, - {TaskStatus(99), "TaskStatus(99)"}, - } - for _, tt := range tests { - assert.Equal(t, tt.want, tt.status.String()) - } -} - -func TestDAGErrorMessage(t *testing.T) { - task := &testTask{} - err := &DAGError{ - Results: map[Task]TaskResult{ - task: {Status: Failed, Err: fmt.Errorf("boom")}, - }, - } - msg := err.Error() - require.NotEmpty(t, msg) - assert.Contains(t, msg, "boom") - assert.Contains(t, msg, "Failed") -} - -func TestValidationErrorMessage(t *testing.T) { - task := &testTask{} - err := &ValidationError{Task: task, Message: "Deps.A is nil"} - assert.Contains(t, err.Error(), "testTask") - assert.Contains(t, err.Error(), "Deps.A is nil") - - // ValidationError without task - err2 := &ValidationError{Message: "cycle detected"} - assert.Contains(t, err2.Error(), "cycle detected") -} diff --git a/e2e/tasks/validate.go b/e2e/tasks/validate.go deleted file mode 100644 index 984bbe09d5b..00000000000 --- a/e2e/tasks/validate.go +++ /dev/null @@ -1,48 +0,0 @@ -package tasks - -import "fmt" - -// validateNoCycles checks the graph for cycles using Kahn's algorithm. -// On success, populates g.order with a valid topological sort. -func validateNoCycles(g *graph) error { - inDegree := make(map[Task]int, len(g.nodes)) - for _, node := range g.nodes { - inDegree[node] = len(g.deps[node]) - } - - var queue []Task - for _, node := range g.nodes { - if inDegree[node] == 0 { - queue = append(queue, node) - } - } - - var sorted []Task - for len(queue) > 0 { - node := queue[0] - queue = queue[1:] - sorted = append(sorted, node) - - for _, dependent := range g.dependents[node] { - inDegree[dependent]-- - if inDegree[dependent] == 0 { - queue = append(queue, dependent) - } - } - } - - if len(sorted) != len(g.nodes) { - var cycleNodes []string - for _, node := range g.nodes { - if inDegree[node] > 0 { - cycleNodes = append(cycleNodes, fmt.Sprintf("%T(%p)", node, node)) - } - } - return &ValidationError{ - Message: fmt.Sprintf("cycle detected among tasks: %v", cycleNodes), - } - } - - g.order = sorted - return nil -} diff --git a/e2e/tasks/validate_test.go b/e2e/tasks/validate_test.go deleted file mode 100644 index 3daa1aaa242..00000000000 --- a/e2e/tasks/validate_test.go +++ /dev/null @@ -1,74 +0,0 @@ -package tasks - -import ( - "context" - "errors" - "testing" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -type cycleA struct { - Deps struct{ B *cycleB } -} -type cycleB struct { - Deps struct{ A *cycleA } -} - -func (t *cycleA) Do(ctx context.Context) error { return nil } -func (t *cycleB) Do(ctx context.Context) error { return nil } - -func TestValidateNoCycles_ValidDAG(t *testing.T) { - a := &chainA{} - b := &chainB{} - b.Deps.A = a - - g, err := discoverGraph([]Task{b}) - require.NoError(t, err) - - err = validateNoCycles(g) - require.NoError(t, err) - assert.Len(t, g.order, 2) - // a should come before b in topological order - assert.Equal(t, Task(a), g.order[0]) - assert.Equal(t, Task(b), g.order[1]) -} - -func TestValidateNoCycles_Cycle(t *testing.T) { - a := &cycleA{} - b := &cycleB{} - a.Deps.B = b - b.Deps.A = a - - g, err := discoverGraph([]Task{a}) - require.NoError(t, err) - - err = validateNoCycles(g) - require.Error(t, err) - - var ve *ValidationError - require.True(t, errors.As(err, &ve)) - assert.Contains(t, ve.Message, "cycle") -} - -func TestValidateNoCycles_Diamond(t *testing.T) { - top := &diamondTop{} - left := &diamondLeft{} - left.Deps.Top = top - right := &diamondRight{} - right.Deps.Top = top - bottom := &diamondBottom{} - bottom.Deps.Left = left - bottom.Deps.Right = right - - g, err := discoverGraph([]Task{bottom}) - require.NoError(t, err) - - err = validateNoCycles(g) - require.NoError(t, err) - assert.Len(t, g.order, 4) - // top must come before left and right, which must come before bottom - assert.Equal(t, Task(top), g.order[0]) - assert.Equal(t, Task(bottom), g.order[3]) -} From 4ad379a20a47b4e1960ad59eaf506438faff281c Mon Sep 17 00:00:00 2001 From: Artur Khantimirov Date: Sun, 22 Mar 2026 09:35:12 +1300 Subject: [PATCH 17/22] chore: remove design spec docs, revert config.go changes --- .../specs/2026-03-20-taskdag-design.md | 368 ------------------ e2e/config/config.go | 2 +- 2 files changed, 1 insertion(+), 369 deletions(-) delete mode 100644 docs/superpowers/specs/2026-03-20-taskdag-design.md diff --git a/docs/superpowers/specs/2026-03-20-taskdag-design.md b/docs/superpowers/specs/2026-03-20-taskdag-design.md deleted file mode 100644 index 0eff67c429a..00000000000 --- a/docs/superpowers/specs/2026-03-20-taskdag-design.md +++ /dev/null @@ -1,368 +0,0 @@ -# tasks — Type-Safe DAG Execution Library - -## Problem - -`go-workflow` wires task dependencies through untyped closures. Dependencies are declared separately from the data they carry (`DependsOn` + `Input` callbacks), making it easy to forget one or the other. The result: nil-pointer panics at runtime, logic that's hard to follow, and no compile-time safety on data flow between tasks. - -## Goal - -A general-purpose Go library for defining and executing tasks as a DAG with: - -1. **Type-safe dependencies** — each task declares its upstream tasks as typed struct fields. The compiler enforces field types; the framework enforces that they're wired. -2. **Concurrent execution** — independent tasks run in parallel automatically. -3. **Simplicity** — no wrapper types, no magic registration functions, no hidden state. Tasks are plain Go structs. - -## Core Design - -### Task Contract - -A task is any struct that implements: - -```go -type Task interface { - Do(ctx context.Context) error -} -``` - -**Tasks must use pointer receivers for `Do`.** Since tasks write to `self.Output` during execution, a value receiver would discard the result. The framework validates this at graph construction time. - -Dependencies are declared as a struct field named `Deps` containing pointers to upstream tasks. Outputs are written to a field named `Output`. Both are optional — a leaf task has no Deps, a sink task has no Output. - -**The `Deps` struct must contain only pointers to types implementing `Task`.** Any other field type (e.g., `*string`, `int`, config structs) is a validation error. Use separate struct fields outside `Deps` for non-task data. - -```go -type BuildOutput struct { - ImagePath string -} - -type BuildImage struct { - Output BuildOutput -} - -func (b *BuildImage) Do(ctx context.Context) error { - b.Output = BuildOutput{ImagePath: "/img"} - return nil -} - -type DeployDeps struct { - Build *BuildImage - Config *LoadConfig -} - -type DeployOutput struct { - URL string -} - -type Deploy struct { - Deps DeployDeps - Output DeployOutput -} - -func (d *Deploy) Do(ctx context.Context) error { - d.Output = DeployOutput{ - URL: fmt.Sprintf("%s:%d", d.Deps.Build.Output.ImagePath, d.Deps.Config.Output.Port), - } - return nil -} -``` - -### Wiring - -The DAG is expressed through plain Go struct initialization: - -```go -build := &BuildImage{} -config := &LoadConfig{} -deploy := &Deploy{ - Deps: DeployDeps{Build: build, Config: config}, -} -``` - -No `Add()`, no `Connect()`, no `DependsOn()`. The struct field assignments *are* the dependency declarations. - -### Execution - -```go -func Execute(ctx context.Context, cfg Config, roots ...Task) error -``` - -`Execute` takes a context, a config, and one or more root tasks: - -```go -// Single root with default config -err := tasks.Execute(ctx, tasks.Config{}, deploy) - -// Multiple roots -err := tasks.Execute(ctx, tasks.Config{}, teardown1, teardown2) - -// With options -err := tasks.Execute(ctx, tasks.Config{ - OnError: tasks.CancelAll, - MaxConcurrency: 4, -}, deploy) -``` - -When multiple roots are provided, all tasks across their graphs are deduplicated by pointer identity and run as a single DAG. If a root also appears as an interior node of another root's graph, it is deduplicated — not an error. - -`Execute` proceeds as follows: - -1. **Walks the graph** — reflects over each task's `Deps` field, follows pointers recursively to discover the full DAG. Nodes are identified by pointer identity. -2. **Validates** — checks for cycles (via topological sort on pointer-identity nodes), nil Deps pointers, and invalid Deps field types. Returns an error before running anything if the graph is invalid. -3. **Deduplicates** — the same task pointer reached via multiple paths (diamond dependency) is executed exactly once. -4. **Schedules** — runs tasks concurrently. A task starts only after all its Deps have completed successfully (or with the appropriate status per the error strategy). -5. **Outputs are available** — since Deps hold pointers to upstream tasks, `task.Deps.Upstream.Output` is directly readable inside `Do()`. The framework guarantees happens-before ordering: a task's goroutine is only launched after all upstream goroutines have completed and their results are visible (synchronized via `sync.WaitGroup` or channel). - -## Configuration - -```go -type Config struct { - // OnError controls what happens when a task fails. - // Default (zero value): CancelDependents. - OnError ErrorStrategy - - // MaxConcurrency limits how many tasks run in parallel. - // 0 (default): unlimited. 1: serial execution (useful for debugging). - // Negative values are treated as 0 (unlimited). - MaxConcurrency int -} -``` - -### Error Strategies - -**`CancelDependents` (default / zero value):** When a task fails, all tasks that transitively depend on it are skipped (status `Skipped`). Independent branches continue running. Already-running tasks are not interrupted. - -**`CancelAll`:** When any task fails, the context passed to all running and future tasks is canceled. Tasks currently in `Do()` receive cancellation via `ctx.Done()` and should return promptly (status `Canceled`). Tasks that haven't started yet also get status `Canceled`. - -## Graph Discovery via Reflection - -At `Execute` time, the framework: - -1. For each task, checks if it has a `Deps` field of struct type. -2. Iterates over all fields in `Deps`. Each field must be a pointer to a struct that implements `Task`. -3. Follows those pointers recursively to discover the full graph. -4. Non-pointer fields in Deps, or pointers to non-Task types, are a validation error. -5. Nil pointer fields in Deps are a validation error. - -The framework never touches `Output` — that's purely a user convention. Tasks write to `self.Output`, downstream tasks read `dep.Output`. The framework only cares about `Deps` pointers and the `Task` interface. - -### Concurrency Safety - -The framework guarantees that when `Do(ctx)` is called on a task, all upstream tasks have fully completed and their writes (including to `Output` fields) are visible. This happens-before relationship is established through Go synchronization primitives (e.g., `sync.WaitGroup.Done()` in the upstream goroutine, `sync.WaitGroup.Wait()` before launching the downstream goroutine). - -Tasks must not mutate their `Deps` fields during `Do()`. Doing so is undefined behavior. - -## Error Reporting - -`Execute` returns a `*DAGError` containing the result of every task: - -```go -type DAGError struct { - // Results is keyed by task pointer. Since tasks must be pointers, - // they are comparable and safe to use as map keys. - Results map[Task]TaskResult -} - -type TaskResult struct { - Status TaskStatus - Err error // nil if Succeeded -} - -type TaskStatus int - -const ( - Succeeded TaskStatus = iota - Failed // Do() returned a non-nil error - Skipped // a dependency failed (CancelDependents mode); task was never started - Canceled // context was canceled (CancelAll mode); task may or may not have started -) -``` - -`DAGError` implements `error`. `Execute` returns `nil` if all tasks succeeded. - -### Inspecting Results - -```go -err := tasks.Execute(ctx, tasks.Config{}, root) -var dagErr *tasks.DAGError -if errors.As(err, &dagErr) { - for task, result := range dagErr.Results { - fmt.Printf("%T: %s %v\n", task, result.Status, result.Err) - } -} -``` - -## Task Reuse - -Each `Execute` call re-runs all tasks in the graph from scratch. The framework does **not** reset `Output` fields — it is the user's responsibility to ensure `Do()` overwrites `Output` fully. If a task can fail partway through writing `Output`, the user should write to a local variable first and assign to `Output` only on success. - -## Accessing Transitive Dependencies - -A task can read through its deps to access transitive outputs: - -```go -func (c *CreateCluster) Do(ctx context.Context) error { - rgName := c.Deps.Subnet.Deps.VNet.Deps.RG.Output.RGName - return nil -} -``` - -This is safe — DAG ordering guarantees all transitive deps have completed. However, it creates coupling to the internal structure of transitive dependencies. Prefer declaring direct deps when practical. - -## Complete Example - -```go -package main - -import ( - "context" - "fmt" - - "github.com/example/tasks" -) - -// --- Task definitions --- - -type CreateRGOutput struct { - RGName string -} - -type CreateRG struct { - Output CreateRGOutput -} - -func (t *CreateRG) Do(ctx context.Context) error { - t.Output.RGName = "my-rg" - return nil -} - -type CreateVNetDeps struct { - RG *CreateRG -} - -type CreateVNetOutput struct { - VNetID string -} - -type CreateVNet struct { - Deps CreateVNetDeps - Output CreateVNetOutput -} - -func (t *CreateVNet) Do(ctx context.Context) error { - t.Output.VNetID = fmt.Sprintf("%s-vnet", t.Deps.RG.Output.RGName) - return nil -} - -type CreateSubnetDeps struct { - VNet *CreateVNet -} - -type CreateSubnetOutput struct { - SubnetID string -} - -type CreateSubnet struct { - Deps CreateSubnetDeps - Output CreateSubnetOutput -} - -func (t *CreateSubnet) Do(ctx context.Context) error { - t.Output.SubnetID = fmt.Sprintf("%s-subnet", t.Deps.VNet.Output.VNetID) - return nil -} - -type CreateClusterDeps struct { - RG *CreateRG - Subnet *CreateSubnet -} - -type CreateClusterOutput struct { - ClusterID string -} - -type CreateCluster struct { - Deps CreateClusterDeps - Output CreateClusterOutput -} - -func (t *CreateCluster) Do(ctx context.Context) error { - t.Output.ClusterID = fmt.Sprintf("cluster-in-%s-%s", - t.Deps.RG.Output.RGName, - t.Deps.Subnet.Output.SubnetID) - return nil -} - -type RunTestsDeps struct { - Cluster *CreateCluster -} - -type RunTests struct { - Deps RunTestsDeps -} - -func (t *RunTests) Do(ctx context.Context) error { - fmt.Println("Running tests on", t.Deps.Cluster.Output.ClusterID) - return nil -} - -type TeardownDeps struct { - RG *CreateRG - Tests *RunTests -} - -type Teardown struct { - Deps TeardownDeps -} - -func (t *Teardown) Do(ctx context.Context) error { - fmt.Println("Tearing down", t.Deps.RG.Output.RGName) - return nil -} - -// --- Wiring and execution --- - -func main() { - rg := &CreateRG{} - vnet := &CreateVNet{Deps: CreateVNetDeps{RG: rg}} - subnet := &CreateSubnet{Deps: CreateSubnetDeps{VNet: vnet}} - cluster := &CreateCluster{Deps: CreateClusterDeps{RG: rg, Subnet: subnet}} - tests := &RunTests{Deps: RunTestsDeps{Cluster: cluster}} - teardown := &Teardown{Deps: TeardownDeps{RG: rg, Tests: tests}} - - // DAG (concurrent where possible): - // - // CreateRG ──┬── CreateVNet ── CreateSubnet ──┐ - // │ │ - // ├──────────────── CreateCluster ──┘ - // │ │ - // │ RunTests - // │ │ - // └──────────────── Teardown - - err := tasks.Execute(context.Background(), tasks.Config{}, teardown) - if err != nil { - panic(err) - } -} -``` - -## Validation Rules (enforced at Execute time) - -| Rule | Detection | Error | -|------|-----------|-------| -| Nil pointer in Deps | Reflection | `"task %T has nil dependency field %s"` | -| Deps field is not a pointer to Task | Reflection | `"task %T.Deps.%s: %T does not implement Task"` | -| Cycle in dependency graph | Topological sort (pointer identity) | `"cycle detected: %T(%p) -> %T(%p) -> ..."` | -| Deps field is not a struct | Reflection | `"task %T.Deps must be a struct"` | - -## What's NOT in Scope (V1) - -Intentionally deferred to keep V1 minimal: - -- **Retry / timeout** — implement inside `Do()`. Framework support later. -- **Conditional execution** — adds complexity. Deferred. -- **Observability hooks** — deferred. Users can wrap tasks. -- **Step naming / logging** — can use `fmt.Stringer`. Deferred. -- **WorkflowMutator pattern** — not needed. The graph is just Go structs; mutation is just Go code. -- **Output reset between re-runs** — user responsibility. Framework doesn't touch Output. -- **Linter for nil deps** — out of scope for the library, but a natural companion tool. diff --git a/e2e/config/config.go b/e2e/config/config.go index 56ecdb11b8a..cebf0bde7f7 100644 --- a/e2e/config/config.go +++ b/e2e/config/config.go @@ -180,7 +180,7 @@ func mustGetNewRSAKeyPair() ([]byte, []byte, string) { privateKeyFileName, err := writePrivateKeyToTempFile(privatePEMBytes) if err != nil { - panic(fmt.Sprintf("failed to write private key to temp file: %v", err)) + panic(fmt.Sprintf("failed to write private key to temp file: %w", err)) } return privatePEMBytes, publicKeyBytes, privateKeyFileName From 388a96a4bf0f060beb0e45a541790e1848c2c066 Mon Sep 17 00:00:00 2001 From: Artur Khantimirov Date: Sun, 22 Mar 2026 09:41:34 +1300 Subject: [PATCH 18/22] refactor(e2e): remove bind/bindRun helpers, use inline closures --- e2e/cluster.go | 30 ++++++++++++------------------ 1 file changed, 12 insertions(+), 18 deletions(-) diff --git a/e2e/cluster.go b/e2e/cluster.go index 5776ab59492..00838feed34 100644 --- a/e2e/cluster.go +++ b/e2e/cluster.go @@ -75,17 +75,21 @@ func prepareCluster(ctx context.Context, clusterModel *armcontainerservice.Manag g := dag.NewGroup(ctx) - bastion := dag.Go(g, bind(getOrCreateBastion, cluster)) - dag.Run(g, bindRun(ensureMaintenanceConfiguration, cluster)) - subnet := dag.Go(g, bind(getClusterSubnetID, cluster)) - kube := dag.Go(g, bind(getClusterKubeClient, cluster)) - identity := dag.Go(g, bind(getClusterKubeletIdentity, cluster)) - dag.Run(g, bindRun(collectGarbageVMSS, cluster)) + bastion := dag.Go(g, func(ctx context.Context) (*Bastion, error) { + return getOrCreateBastion(ctx, cluster) + }) + dag.Run(g, func(ctx context.Context) error { return ensureMaintenanceConfiguration(ctx, cluster) }) + subnet := dag.Go(g, func(ctx context.Context) (string, error) { return getClusterSubnetID(ctx, cluster) }) + kube := dag.Go(g, func(ctx context.Context) (*Kubeclient, error) { return getClusterKubeClient(ctx, cluster) }) + identity := dag.Go(g, func(ctx context.Context) (*armcontainerservice.UserAssignedIdentity, error) { + return getClusterKubeletIdentity(ctx, cluster) + }) + dag.Run(g, func(ctx context.Context) error { return collectGarbageVMSS(ctx, cluster) }) if !isNetworkIsolated { - dag.Run(g, bindRun(addFirewallRules, cluster)) + dag.Run(g, func(ctx context.Context) error { return addFirewallRules(ctx, cluster) }) } if isNetworkIsolated { - dag.Run(g, bindRun(addNetworkIsolatedSettings, cluster)) + dag.Run(g, func(ctx context.Context) error { return addNetworkIsolatedSettings(ctx, cluster) }) } needACR := isNetworkIsolated || attachPrivateAcr acrNonAnon := dag.Run2(g, kube, identity, addACR(cluster, needACR, true)) @@ -106,16 +110,6 @@ func prepareCluster(ctx context.Context, clusterModel *armcontainerservice.Manag }, nil } -// bind returns func(ctx) → (T, error) by binding arg to fn. -func bind[A, T any](fn func(context.Context, A) (T, error), arg A) func(context.Context) (T, error) { - return func(ctx context.Context) (T, error) { return fn(ctx, arg) } -} - -// bindRun returns func(ctx) → error by binding arg to fn. -func bindRun[A any](fn func(context.Context, A) error, arg A) func(context.Context) error { - return func(ctx context.Context) error { return fn(ctx, arg) } -} - func addACR(cluster *armcontainerservice.ManagedCluster, needACR, isNonAnonymousPull bool) func(context.Context, *Kubeclient, *armcontainerservice.UserAssignedIdentity) error { return func(ctx context.Context, k *Kubeclient, id *armcontainerservice.UserAssignedIdentity) error { if !needACR { From 8718b71e0536a50be10093f0e6db624ad399c430 Mon Sep 17 00:00:00 2001 From: Artur Khantimirov Date: Sun, 22 Mar 2026 09:43:22 +1300 Subject: [PATCH 19/22] docs(e2e): add note to keep prepareCluster minimal --- e2e/cluster.go | 3 +++ 1 file changed, 3 insertions(+) diff --git a/e2e/cluster.go b/e2e/cluster.go index 00838feed34..744ee7c0a25 100644 --- a/e2e/cluster.go +++ b/e2e/cluster.go @@ -61,6 +61,9 @@ func (c *Cluster) MaxPodsPerNode() (int, error) { } // prepareCluster runs all cluster preparation steps as a concurrent DAG. +// This function contains complex concurrent orchestration — keep it as +// minimal as possible and push all non-trivial logic into the individual +// task functions it calls. func prepareCluster(ctx context.Context, clusterModel *armcontainerservice.ManagedCluster, isNetworkIsolated, attachPrivateAcr bool) (*Cluster, error) { defer toolkit.LogStepCtx(ctx, "preparing cluster")() ctx, cancel := context.WithTimeout(ctx, config.Config.TestTimeoutCluster) From 31cd7478fecb42df0914465b0beef9f484a63d6d Mon Sep 17 00:00:00 2001 From: Artur Khantimirov Date: Sun, 22 Mar 2026 10:03:09 +1300 Subject: [PATCH 20/22] fix: address PR review comments on dag package and cluster.go - Fix package doc to reference correct API names (Go/Run not Spawn/Do) - Fix section comments in tests to match current API names - Wait() now surfaces ctx.Err() when parent context is cancelled - TestParentContextCancelled asserts error is context.Canceled - ensureMaintenanceConfiguration 404 path wraps error with cluster context - Firewall/NSG tasks are now dependencies of ensureDebugDaemonsets --- e2e/cluster.go | 12 ++++++++---- e2e/dag/dag.go | 21 +++++++++++++-------- e2e/dag/dag_test.go | 18 +++++++++++------- 3 files changed, 32 insertions(+), 19 deletions(-) diff --git a/e2e/cluster.go b/e2e/cluster.go index 744ee7c0a25..d002befc17c 100644 --- a/e2e/cluster.go +++ b/e2e/cluster.go @@ -88,16 +88,17 @@ func prepareCluster(ctx context.Context, clusterModel *armcontainerservice.Manag return getClusterKubeletIdentity(ctx, cluster) }) dag.Run(g, func(ctx context.Context) error { return collectGarbageVMSS(ctx, cluster) }) + var networkDeps []dag.Dep if !isNetworkIsolated { - dag.Run(g, func(ctx context.Context) error { return addFirewallRules(ctx, cluster) }) + networkDeps = append(networkDeps, dag.Run(g, func(ctx context.Context) error { return addFirewallRules(ctx, cluster) })) } if isNetworkIsolated { - dag.Run(g, func(ctx context.Context) error { return addNetworkIsolatedSettings(ctx, cluster) }) + networkDeps = append(networkDeps, dag.Run(g, func(ctx context.Context) error { return addNetworkIsolatedSettings(ctx, cluster) })) } needACR := isNetworkIsolated || attachPrivateAcr acrNonAnon := dag.Run2(g, kube, identity, addACR(cluster, needACR, true)) acrAnon := dag.Run2(g, kube, identity, addACR(cluster, needACR, false)) - dag.Run1(g, kube, ensureDebugDaemonsets(cluster, isNetworkIsolated), acrNonAnon, acrAnon) + dag.Run1(g, kube, ensureDebugDaemonsets(cluster, isNetworkIsolated), append([]dag.Dep{acrNonAnon, acrAnon}, networkDeps...)...) extract := dag.Go1(g, kube, extractClusterParams(cluster)) if err := g.Wait(); err != nil { @@ -421,7 +422,10 @@ func ensureMaintenanceConfiguration(ctx context.Context, cluster *armcontainerse var azErr *azcore.ResponseError if errors.As(err, &azErr) && azErr.StatusCode == 404 { _, err = createNewMaintenanceConfiguration(ctx, cluster) - return err + if err != nil { + return fmt.Errorf("creating maintenance configuration for cluster %q: %w", *cluster.Name, err) + } + return nil } if err != nil { return fmt.Errorf("failed to get maintenance configuration 'default' for cluster %q: %w", *cluster.Name, err) diff --git a/e2e/dag/dag.go b/e2e/dag/dag.go index 3e538725292..253f99b76fa 100644 --- a/e2e/dag/dag.go +++ b/e2e/dag/dag.go @@ -8,18 +8,19 @@ // There are two kinds of tasks: // // - Value-producing tasks return (T, error) and are represented by [Result][T]. -// Register with [Spawn] (no typed deps) or [Then] / [Then2] / [Then3] (typed deps). +// Register with [Go] (no typed deps) or [Go1] / [Go2] / [Go3] (typed deps). // // - Side-effect tasks return only error and are represented by [Effect]. -// Register with [Do] (no typed deps) or [ThenDo] / [ThenDo2] / [ThenDo3] (typed deps). +// Register with [Run] (no typed deps) or [Run1] / [Run2] / [Run3] (typed deps). // // Both [Result] and [Effect] implement [Dep], so they can be listed as // dependencies of downstream tasks. // -// When a typed dependency is used (Then/ThenDo variants), the dependency's -// value is passed as a function parameter — the compiler enforces correct -// wiring. When untyped dependencies are used (Spawn/Do with variadic deps), -// values are accessed via [Result.MustGet] inside the closure. +// When a typed dependency is used (Go1–Go3 / Run1–Run3 variants), the +// dependency's value is passed as a function parameter — the compiler +// enforces correct wiring. When untyped dependencies are used (Go/Run +// with variadic deps), values are accessed via [Result.MustGet] inside +// the closure. // // On the first task error, the Group cancels its context, causing all pending // and in-flight tasks to observe cancellation and exit. [Group.Wait] blocks @@ -87,16 +88,20 @@ func NewGroup(ctx context.Context) *Group { } // Wait blocks until every task in the group has finished. -// It returns a *[DAGError] if any task failed, or nil on success. +// It returns a *[DAGError] if any task failed, the parent context's error +// if it was cancelled before tasks could run, or nil on success. func (g *Group) Wait() error { g.wg.Wait() + // Capture ctx error before cancel() — after cancel(), ctx.Err() is + // always non-nil regardless of whether the parent was cancelled. + ctxErr := g.ctx.Err() g.cancel() g.mu.Lock() defer g.mu.Unlock() if len(g.errs) > 0 { return &DAGError{Errors: g.errs} } - return nil + return ctxErr } func (g *Group) recordError(err error) { diff --git a/e2e/dag/dag_test.go b/e2e/dag/dag_test.go index 1a7cc85ad99..f6191dc5f1d 100644 --- a/e2e/dag/dag_test.go +++ b/e2e/dag/dag_test.go @@ -8,7 +8,7 @@ import ( ) // --------------------------------------------------------------------------- -// Spawn +// Go — value-producing tasks // --------------------------------------------------------------------------- func TestGo(t *testing.T) { @@ -61,7 +61,7 @@ func TestGo_WithDeps(t *testing.T) { } // --------------------------------------------------------------------------- -// Do +// Run — side-effect tasks // --------------------------------------------------------------------------- func TestRun(t *testing.T) { @@ -99,7 +99,7 @@ func TestRun_WithDeps(t *testing.T) { } // --------------------------------------------------------------------------- -// Then chain +// Go1 / Go2 / Go3 chain // --------------------------------------------------------------------------- func TestGo1_Chain(t *testing.T) { @@ -125,7 +125,7 @@ func TestGo1_Chain(t *testing.T) { } // --------------------------------------------------------------------------- -// Then2 / Then3 +// Go2 / Go3 // --------------------------------------------------------------------------- func TestGo2(t *testing.T) { @@ -160,7 +160,7 @@ func TestGo3(t *testing.T) { } // --------------------------------------------------------------------------- -// ThenDo / ThenDo2 / ThenDo3 +// Run1 / Run2 / Run3 // --------------------------------------------------------------------------- func TestRun1(t *testing.T) { @@ -372,8 +372,12 @@ func TestParentContextCancelled(t *testing.T) { return nil }) - // Key invariant: Wait() returns without hanging. - g.Wait() + // Key invariant: Wait() returns without hanging and surfaces the + // parent context's cancellation error. + err := g.Wait() + if !errors.Is(err, context.Canceled) { + t.Fatalf("expected context.Canceled, got %v", err) + } } func TestEffect_AsDep(t *testing.T) { From 8ed5067175ba7ca4b0c44fa8f964d0707ca73791 Mon Sep 17 00:00:00 2001 From: Artur Khantimirov Date: Sun, 22 Mar 2026 16:08:39 +1300 Subject: [PATCH 21/22] test(dag): expand coverage to 95.1% with 37 tests Add comprehensive test coverage including: - Go1-Go3/Run1-Run3 success and skip-on-dep-failure paths - Diamond topology, transitive skip, cancellation noise - Parent context cancellation and deadline exceeded - DAGError formatting (single and multiple errors) - Effect as dependency, empty group, Result.Get/MustGet - Cycle behavior documentation (typed API compile-time safety, untyped API deadlock, self-dependency deadlock) Clean up test comments to keep tests concise. --- e2e/dag/dag_test.go | 357 +++++++++++++++++++++++++++++++++++++------- 1 file changed, 307 insertions(+), 50 deletions(-) diff --git a/e2e/dag/dag_test.go b/e2e/dag/dag_test.go index f6191dc5f1d..001bbe6dbaf 100644 --- a/e2e/dag/dag_test.go +++ b/e2e/dag/dag_test.go @@ -3,14 +3,12 @@ package dag import ( "context" "errors" + "strings" "sync/atomic" "testing" + "time" ) -// --------------------------------------------------------------------------- -// Go — value-producing tasks -// --------------------------------------------------------------------------- - func TestGo(t *testing.T) { g := NewGroup(context.Background()) r := Go(g, func(ctx context.Context) (int, error) { @@ -60,10 +58,6 @@ func TestGo_WithDeps(t *testing.T) { _ = c.MustGet() } -// --------------------------------------------------------------------------- -// Run — side-effect tasks -// --------------------------------------------------------------------------- - func TestRun(t *testing.T) { g := NewGroup(context.Background()) var called atomic.Bool @@ -98,10 +92,6 @@ func TestRun_WithDeps(t *testing.T) { } } -// --------------------------------------------------------------------------- -// Go1 / Go2 / Go3 chain -// --------------------------------------------------------------------------- - func TestGo1_Chain(t *testing.T) { g := NewGroup(context.Background()) a := Go(g, func(ctx context.Context) (int, error) { @@ -124,10 +114,6 @@ func TestGo1_Chain(t *testing.T) { } } -// --------------------------------------------------------------------------- -// Go2 / Go3 -// --------------------------------------------------------------------------- - func TestGo2(t *testing.T) { g := NewGroup(context.Background()) a := Go(g, func(ctx context.Context) (int, error) { return 3, nil }) @@ -159,10 +145,6 @@ func TestGo3(t *testing.T) { } } -// --------------------------------------------------------------------------- -// Run1 / Run2 / Run3 -// --------------------------------------------------------------------------- - func TestRun1(t *testing.T) { g := NewGroup(context.Background()) a := Go(g, func(ctx context.Context) (int, error) { return 5, nil }) @@ -214,10 +196,6 @@ func TestRun3(t *testing.T) { } } -// --------------------------------------------------------------------------- -// Error propagation — cancel-all behavior -// --------------------------------------------------------------------------- - func TestCancelAll_CancelsRunningTasks(t *testing.T) { g := NewGroup(context.Background()) @@ -235,8 +213,6 @@ func TestCancelAll_CancelsRunningTasks(t *testing.T) { return 0, errors.New("fail") }) - // g.Wait() guarantees all goroutines have returned (via WaitGroup), - // so cancelled is guaranteed to be true here — no sleep needed. g.Wait() if !cancelled.Load() { t.Fatal("expected context to be cancelled for running task") @@ -245,7 +221,6 @@ func TestCancelAll_CancelsRunningTasks(t *testing.T) { func TestSkipsDownstream(t *testing.T) { g := NewGroup(context.Background()) - a := Go(g, func(ctx context.Context) (int, error) { return 0, errors.New("a failed") }) @@ -264,7 +239,6 @@ func TestSkipsDownstream(t *testing.T) { func TestTransitiveSkip(t *testing.T) { g := NewGroup(context.Background()) - a := Go(g, func(ctx context.Context) (int, error) { return 0, errors.New("a failed") }) @@ -284,10 +258,6 @@ func TestTransitiveSkip(t *testing.T) { } } -// --------------------------------------------------------------------------- -// DAG topologies -// --------------------------------------------------------------------------- - func TestDiamond(t *testing.T) { // a // / \ @@ -295,7 +265,6 @@ func TestDiamond(t *testing.T) { // \ / // d g := NewGroup(context.Background()) - a := Go(g, func(ctx context.Context) (int, error) { return 1, nil }) b := Go1(g, a, func(ctx context.Context, v int) (int, error) { return v + 10, nil }) c := Go1(g, a, func(ctx context.Context, v int) (int, error) { return v + 100, nil }) @@ -309,10 +278,6 @@ func TestDiamond(t *testing.T) { } } -// --------------------------------------------------------------------------- -// Result.Get / Result.MustGet safety -// --------------------------------------------------------------------------- - func TestGet_SafeOnError(t *testing.T) { g := NewGroup(context.Background()) r := Go(g, func(ctx context.Context) (int, error) { @@ -329,6 +294,23 @@ func TestGet_SafeOnError(t *testing.T) { } } +func TestGet_SuccessPath(t *testing.T) { + g := NewGroup(context.Background()) + r := Go(g, func(ctx context.Context) (string, error) { + return "hello", nil + }) + if err := g.Wait(); err != nil { + t.Fatal(err) + } + v, ok := r.Get() + if !ok { + t.Fatal("Get() should return true on success") + } + if v != "hello" { + t.Fatalf("got %q, want %q", v, "hello") + } +} + func TestMustGet_PanicsOnError(t *testing.T) { g := NewGroup(context.Background()) r := Go(g, func(ctx context.Context) (int, error) { @@ -344,10 +326,6 @@ func TestMustGet_PanicsOnError(t *testing.T) { r.MustGet() } -// --------------------------------------------------------------------------- -// Edge cases -// --------------------------------------------------------------------------- - func TestMultipleErrors(t *testing.T) { g := NewGroup(context.Background()) Go(g, func(ctx context.Context) (int, error) { return 0, errors.New("err1") }) @@ -368,12 +346,8 @@ func TestParentContextCancelled(t *testing.T) { cancel() g := NewGroup(ctx) - Run(g, func(ctx context.Context) error { - return nil - }) + Run(g, func(ctx context.Context) error { return nil }) - // Key invariant: Wait() returns without hanging and surfaces the - // parent context's cancellation error. err := g.Wait() if !errors.Is(err, context.Canceled) { t.Fatalf("expected context.Canceled, got %v", err) @@ -383,11 +357,8 @@ func TestParentContextCancelled(t *testing.T) { func TestEffect_AsDep(t *testing.T) { g := NewGroup(context.Background()) - // order is shared between goroutines, but the dependency edge (e) provides - // a happens-before guarantee: close(e.done) in the first goroutine - // happens-before the second goroutine's read via e.wait(). + // The dependency edge provides a happens-before guarantee. var order []int - e := Run(g, func(ctx context.Context) error { order = append(order, 1) return nil @@ -403,3 +374,289 @@ func TestEffect_AsDep(t *testing.T) { t.Fatal(err) } } + +func TestEmptyGroup(t *testing.T) { + g := NewGroup(context.Background()) + if err := g.Wait(); err != nil { + t.Fatalf("expected nil, got %v", err) + } +} + +func TestWait_NilOnSuccess(t *testing.T) { + g := NewGroup(context.Background()) + Go(g, func(ctx context.Context) (int, error) { return 1, nil }) + Run(g, func(ctx context.Context) error { return nil }) + + if err := g.Wait(); err != nil { + t.Fatalf("expected nil, got %v", err) + } +} + +// TestWait_DAGErrorPriorityOverCtxErr verifies that when a task fails +// (causing internal cancel), Wait() returns *DAGError not context.Canceled. +func TestWait_DAGErrorPriorityOverCtxErr(t *testing.T) { + g := NewGroup(context.Background()) + Go(g, func(ctx context.Context) (int, error) { + return 0, errors.New("task failed") + }) + Run(g, func(ctx context.Context) error { + <-ctx.Done() + return ctx.Err() + }) + + err := g.Wait() + if err == nil { + t.Fatal("expected error") + } + var dagErr *DAGError + if !errors.As(err, &dagErr) { + t.Fatalf("expected *DAGError, got %T: %v", err, err) + } + found := false + for _, e := range dagErr.Errors { + if e.Error() == "task failed" { + found = true + } + } + if !found { + t.Fatalf("expected 'task failed' in DAGError.Errors, got: %v", dagErr.Errors) + } +} + +func TestWait_ParentContextDeadlineExceeded(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 1*time.Millisecond) + defer cancel() + time.Sleep(5 * time.Millisecond) + + g := NewGroup(ctx) + Run(g, func(ctx context.Context) error { return nil }) + + err := g.Wait() + if !errors.Is(err, context.DeadlineExceeded) { + t.Fatalf("expected context.DeadlineExceeded, got %v", err) + } +} + +// TestCancellationNoise verifies that when task A fails and task B returns +// ctx.Err() as noise, the real error is still present in DAGError. +func TestCancellationNoise(t *testing.T) { + g := NewGroup(context.Background()) + started := make(chan struct{}) + + Go(g, func(ctx context.Context) (int, error) { + <-started + return 0, errors.New("real error") + }) + Run(g, func(ctx context.Context) error { + close(started) + <-ctx.Done() + return ctx.Err() + }) + + err := g.Wait() + var dagErr *DAGError + if !errors.As(err, &dagErr) { + t.Fatalf("expected *DAGError, got %T: %v", err, err) + } + hasRealErr := false + for _, e := range dagErr.Errors { + if e.Error() == "real error" { + hasRealErr = true + } + } + if !hasRealErr { + t.Fatalf("real error not found in DAGError.Errors: %v", dagErr.Errors) + } + if !strings.Contains(dagErr.Error(), "real error") { + t.Fatalf("Error() should mention 'real error': %s", dagErr.Error()) + } +} + +func TestDAGError_Error(t *testing.T) { + dagErr := &DAGError{Errors: []error{ + errors.New("beta"), + errors.New("alpha"), + }} + got := dagErr.Error() + want := "dag execution failed: alpha; beta" + if got != want { + t.Fatalf("got %q, want %q", got, want) + } +} + +func TestDAGError_ErrorSingle(t *testing.T) { + dagErr := &DAGError{Errors: []error{errors.New("only")}} + got := dagErr.Error() + want := "dag execution failed: only" + if got != want { + t.Fatalf("got %q, want %q", got, want) + } +} + +func TestGo1_SkipOnDepFailure(t *testing.T) { + g := NewGroup(context.Background()) + a := Go(g, func(ctx context.Context) (int, error) { return 0, errors.New("a failed") }) + + var ran atomic.Bool + Go1(g, a, func(ctx context.Context, v int) (string, error) { + ran.Store(true) + return "", nil + }) + + g.Wait() + if ran.Load() { + t.Fatal("Go1 callback should have been skipped when dep failed") + } +} + +func TestGo2_SkipOnDepFailure(t *testing.T) { + g := NewGroup(context.Background()) + a := Go(g, func(ctx context.Context) (int, error) { return 0, errors.New("a failed") }) + b := Go(g, func(ctx context.Context) (int, error) { return 2, nil }) + + var ran atomic.Bool + Go2(g, a, b, func(ctx context.Context, x, y int) (int, error) { + ran.Store(true) + return x + y, nil + }) + + g.Wait() + if ran.Load() { + t.Fatal("Go2 callback should have been skipped when dep failed") + } +} + +func TestGo3_SkipOnDepFailure(t *testing.T) { + g := NewGroup(context.Background()) + a := Go(g, func(ctx context.Context) (int, error) { return 1, nil }) + b := Go(g, func(ctx context.Context) (int, error) { return 0, errors.New("b failed") }) + c := Go(g, func(ctx context.Context) (int, error) { return 3, nil }) + + var ran atomic.Bool + Go3(g, a, b, c, func(ctx context.Context, x, y, z int) (int, error) { + ran.Store(true) + return x + y + z, nil + }) + + g.Wait() + if ran.Load() { + t.Fatal("Go3 callback should have been skipped when dep failed") + } +} + +func TestRun1_SkipOnDepFailure(t *testing.T) { + g := NewGroup(context.Background()) + a := Go(g, func(ctx context.Context) (int, error) { return 0, errors.New("a failed") }) + + var ran atomic.Bool + Run1(g, a, func(ctx context.Context, v int) error { + ran.Store(true) + return nil + }) + + g.Wait() + if ran.Load() { + t.Fatal("Run1 callback should have been skipped when dep failed") + } +} + +func TestRun2_SkipOnDepFailure(t *testing.T) { + g := NewGroup(context.Background()) + a := Go(g, func(ctx context.Context) (int, error) { return 0, errors.New("a failed") }) + b := Go(g, func(ctx context.Context) (string, error) { return "x", nil }) + + var ran atomic.Bool + Run2(g, a, b, func(ctx context.Context, n int, s string) error { + ran.Store(true) + return nil + }) + + g.Wait() + if ran.Load() { + t.Fatal("Run2 callback should have been skipped when dep failed") + } +} + +func TestRun3_SkipOnDepFailure(t *testing.T) { + g := NewGroup(context.Background()) + a := Go(g, func(ctx context.Context) (int, error) { return 1, nil }) + b := Go(g, func(ctx context.Context) (int, error) { return 2, nil }) + c := Go(g, func(ctx context.Context) (int, error) { return 0, errors.New("c failed") }) + + var ran atomic.Bool + Run3(g, a, b, c, func(ctx context.Context, x, y, z int) error { + ran.Store(true) + return nil + }) + + g.Wait() + if ran.Load() { + t.Fatal("Run3 callback should have been skipped when dep failed") + } +} + +// TestCycle_TypedAPI_Impossible documents that the typed Go1/Go2/Go3 API +// makes cyclic dependencies impossible at compile time: you cannot pass a +// *Result before it is declared. +func TestCycle_TypedAPI_Impossible(t *testing.T) { + // a := Go1(g, b, ...) // b not declared yet — won't compile + // b := Go1(g, a, ...) +} + +// TestCycle_UntypedAPI_Deadlocks verifies that a never-completing dep +// (simulating a cycle) causes Wait() to deadlock. Context cancellation +// does not unblock d.wait() since it's a plain channel read. +// Prefer Go1-Go3 / Run1-Run3 for compile-time cycle safety. +func TestCycle_UntypedAPI_Deadlocks(t *testing.T) { + g := NewGroup(context.Background()) + placeholder := newEffect() + + cycleTask := Go(g, func(ctx context.Context) (int, error) { + return 42, nil + }, placeholder) + + waitDone := make(chan error, 1) + go func() { + waitDone <- g.Wait() + }() + + select { + case <-waitDone: + t.Fatal("Wait() should not have returned — expected deadlock") + case <-time.After(100 * time.Millisecond): + placeholder.finish(errSkipped) + <-waitDone + if _, ok := cycleTask.Get(); ok { + t.Fatal("cyclic task should not have succeeded") + } + } +} + +// TestCycle_SelfDependency verifies that a task depending on a never-completed +// dep (simulating a self-reference) deadlocks Wait(). +func TestCycle_SelfDependency(t *testing.T) { + g := NewGroup(context.Background()) + blocker := newResult[int]() + + var ran atomic.Bool + Run(g, func(ctx context.Context) error { + ran.Store(true) + return nil + }, blocker) + + waitDone := make(chan error, 1) + go func() { + waitDone <- g.Wait() + }() + + select { + case <-waitDone: + t.Fatal("Wait() should not have returned — expected deadlock") + case <-time.After(100 * time.Millisecond): + blocker.finish(0, errSkipped) + <-waitDone + if ran.Load() { + t.Fatal("task should not have run — dep failed") + } + } +} From 8dbaa70dcd6959cf492f0ea015bf4a7509df36ef Mon Sep 17 00:00:00 2001 From: Artur Khantimirov Date: Sun, 22 Mar 2026 17:59:09 +1300 Subject: [PATCH 22/22] dag: recover panics in task goroutines and improve docs - Add panic recovery in launch() so panics in task functions are captured as errors instead of crashing the process. - Unify onSkip/onPanic into a single onFail(error) callback. - Update MustGet docs to reference actual API names. - Document cycle-deadlock limitation in package comment. - Add tests for panic in Go, Run, and downstream skip after panic. --- e2e/dag/dag.go | 69 +++++++++++++++++++++++++++------------------ e2e/dag/dag_test.go | 60 +++++++++++++++++++++++++++++++++++++++ 2 files changed, 101 insertions(+), 28 deletions(-) diff --git a/e2e/dag/dag.go b/e2e/dag/dag.go index 253f99b76fa..e302751cef1 100644 --- a/e2e/dag/dag.go +++ b/e2e/dag/dag.go @@ -25,7 +25,14 @@ // On the first task error, the Group cancels its context, causing all pending // and in-flight tasks to observe cancellation and exit. [Group.Wait] blocks // until every goroutine returns and reports a [DAGError] containing all -// collected errors. +// collected errors. Panics inside task functions are recovered and surfaced +// as errors; they never crash the process. +// +// Note: the package does not perform runtime cycle detection. Dependency +// cycles created via the untyped [Go]/[Run] API (where a task's closure +// captures a [Result] that transitively depends on itself) will cause +// [Group.Wait] to deadlock. The typed variants (Go1–Go3, Run1–Run3) make +// cycles harder to construct but do not prevent them entirely. // // Example: // @@ -117,11 +124,24 @@ func (g *Group) recordError(err error) { var errSkipped = errors.New("skipped: dependency failed") // launch runs fn in a new goroutine after all deps complete. -// If any dep failed or ctx is cancelled, onSkip is called instead of fn. -func (g *Group) launch(deps []Dep, fn func(), onSkip func()) { +// If any dep failed, ctx is cancelled, or fn panics, onFail is called +// with the relevant error instead of (or after) fn. +func (g *Group) launch(deps []Dep, fn func(), onFail func(error)) { g.wg.Add(1) go func() { defer g.wg.Done() + defer func() { + if r := recover(); r != nil { + var err error + if e, ok := r.(error); ok { + err = fmt.Errorf("dag: task panicked: %w", e) + } else { + err = fmt.Errorf("dag: task panicked: %v", r) + } + g.recordError(err) + onFail(err) + } + }() for _, d := range deps { d.wait() @@ -129,13 +149,13 @@ func (g *Group) launch(deps []Dep, fn func(), onSkip func()) { for _, d := range deps { if d.failed() { - onSkip() + onFail(errSkipped) return } } if g.ctx.Err() != nil { - onSkip() + onFail(errSkipped) return } @@ -174,9 +194,10 @@ func (r *Result[T]) Get() (T, bool) { } // MustGet returns the value, panicking if the task failed. Safe to call: -// - Inside Then/ThenDo callbacks (the scheduler guarantees deps succeeded) -// - Inside Spawn/Do callbacks when the Result is listed as a dep -// - After [Group.Wait] returned nil +// - Inside Go1/Go2/Go3 or Run1/Run2/Run3 callbacks (the scheduler +// guarantees deps succeeded before invoking the function). +// - Inside Go/Run callbacks when the Result is listed as a dep. +// - After [Group.Wait] returned nil. func (r *Result[T]) MustGet() T { <-r.done if r.err != nil { @@ -231,9 +252,9 @@ func Go[T any](g *Group, fn func(ctx context.Context) (T, error), deps ...Dep) * g.recordError(err) } r.finish(val, err) - }, func() { + }, func(err error) { var zero T - r.finish(zero, errSkipped) + r.finish(zero, err) }) return r } @@ -248,9 +269,9 @@ func Go1[T, D1 any](g *Group, dep *Result[D1], fn func(ctx context.Context, d1 D g.recordError(err) } r.finish(val, err) - }, func() { + }, func(err error) { var zero T - r.finish(zero, errSkipped) + r.finish(zero, err) }) return r } @@ -265,9 +286,9 @@ func Go2[T, D1, D2 any](g *Group, dep1 *Result[D1], dep2 *Result[D2], fn func(ct g.recordError(err) } r.finish(val, err) - }, func() { + }, func(err error) { var zero T - r.finish(zero, errSkipped) + r.finish(zero, err) }) return r } @@ -282,9 +303,9 @@ func Go3[T, D1, D2, D3 any](g *Group, dep1 *Result[D1], dep2 *Result[D2], dep3 * g.recordError(err) } r.finish(val, err) - }, func() { + }, func(err error) { var zero T - r.finish(zero, errSkipped) + r.finish(zero, err) }) return r } @@ -305,9 +326,7 @@ func Run(g *Group, fn func(ctx context.Context) error, deps ...Dep) *Effect { g.recordError(err) } e.finish(err) - }, func() { - e.finish(errSkipped) - }) + }, e.finish) return e } @@ -321,9 +340,7 @@ func Run1[D1 any](g *Group, dep *Result[D1], fn func(ctx context.Context, d1 D1) g.recordError(err) } e.finish(err) - }, func() { - e.finish(errSkipped) - }) + }, e.finish) return e } @@ -337,9 +354,7 @@ func Run2[D1, D2 any](g *Group, dep1 *Result[D1], dep2 *Result[D2], fn func(ctx g.recordError(err) } e.finish(err) - }, func() { - e.finish(errSkipped) - }) + }, e.finish) return e } @@ -353,9 +368,7 @@ func Run3[D1, D2, D3 any](g *Group, dep1 *Result[D1], dep2 *Result[D2], dep3 *Re g.recordError(err) } e.finish(err) - }, func() { - e.finish(errSkipped) - }) + }, e.finish) return e } diff --git a/e2e/dag/dag_test.go b/e2e/dag/dag_test.go index 001bbe6dbaf..23530103693 100644 --- a/e2e/dag/dag_test.go +++ b/e2e/dag/dag_test.go @@ -660,3 +660,63 @@ func TestCycle_SelfDependency(t *testing.T) { } } } + +// TestPanic_GoTask verifies that a panic in a Go task is recovered and +// surfaced as an error via Wait(), not crashing the process. +func TestPanic_GoTask(t *testing.T) { + g := NewGroup(context.Background()) + r := Go(g, func(ctx context.Context) (int, error) { + panic("unexpected nil pointer") + }) + + err := g.Wait() + if err == nil { + t.Fatal("expected error from panicking task") + } + if !strings.Contains(err.Error(), "task panicked") { + t.Fatalf("expected panic error, got: %v", err) + } + if !strings.Contains(err.Error(), "unexpected nil pointer") { + t.Fatalf("expected panic message in error, got: %v", err) + } + // Result should be marked as failed. + if _, ok := r.Get(); ok { + t.Fatal("Get() should return false on panicked task") + } +} + +// TestPanic_RunTask verifies that a panic in a Run task is recovered. +func TestPanic_RunTask(t *testing.T) { + g := NewGroup(context.Background()) + Run(g, func(ctx context.Context) error { + panic(errors.New("wrapped panic")) + }) + + err := g.Wait() + if err == nil { + t.Fatal("expected error from panicking task") + } + if !strings.Contains(err.Error(), "task panicked") { + t.Fatalf("expected panic error, got: %v", err) + } +} + +// TestPanic_SkipsDownstream verifies that a panic in an upstream task +// causes downstream tasks to be skipped (not run). +func TestPanic_SkipsDownstream(t *testing.T) { + g := NewGroup(context.Background()) + upstream := Go(g, func(ctx context.Context) (int, error) { + panic("boom") + }) + + var ran atomic.Bool + Go1(g, upstream, func(ctx context.Context, v int) (string, error) { + ran.Store(true) + return "", nil + }) + + g.Wait() + if ran.Load() { + t.Fatal("downstream task should have been skipped after upstream panic") + } +}