diff --git a/e2e/aks_model.go b/e2e/aks_model.go index 7498d92c0d1..648cd732954 100644 --- a/e2e/aks_model.go +++ b/e2e/aks_model.go @@ -299,8 +299,8 @@ func getFirewall(ctx context.Context, location, firewallSubnetID, publicIPID str func addFirewallRules( ctx context.Context, clusterModel *armcontainerservice.ManagedCluster, - location string, ) error { + location := *clusterModel.Location defer toolkit.LogStepCtx(ctx, "adding firewall rules")() routeTableName := "abe2e-fw-rt" rtGetResp, err := config.Azure.RouteTables.Get( @@ -486,10 +486,11 @@ func addFirewallRules( return nil } -func addPrivateAzureContainerRegistry(ctx context.Context, cluster *armcontainerservice.ManagedCluster, kube *Kubeclient, resourceGroupName string, kubeletIdentity *armcontainerservice.UserAssignedIdentity, isNonAnonymousPull bool) error { +func addPrivateAzureContainerRegistry(ctx context.Context, cluster *armcontainerservice.ManagedCluster, kube *Kubeclient, kubeletIdentity *armcontainerservice.UserAssignedIdentity, isNonAnonymousPull bool) error { if cluster == nil || kube == nil || kubeletIdentity == nil { return errors.New("cluster, kubeclient, and kubeletIdentity cannot be nil when adding Private Azure Container Registry") } + resourceGroupName := config.ResourceGroupName(*cluster.Location) if err := createPrivateAzureContainerRegistry(ctx, cluster, resourceGroupName, isNonAnonymousPull); err != nil { return fmt.Errorf("failed to create private acr: %w", err) } @@ -514,7 +515,8 @@ func addPrivateAzureContainerRegistry(ctx context.Context, cluster *armcontainer return nil } -func addNetworkIsolatedSettings(ctx context.Context, clusterModel *armcontainerservice.ManagedCluster, location string) error { +func addNetworkIsolatedSettings(ctx context.Context, clusterModel *armcontainerservice.ManagedCluster) error { + location := *clusterModel.Location defer toolkit.LogStepCtx(ctx, fmt.Sprintf("Adding network settings for network isolated cluster %s in rg %s", *clusterModel.Name, *clusterModel.Properties.NodeResourceGroup)) vnet, err := getClusterVNet(ctx, *clusterModel.Properties.NodeResourceGroup) diff --git a/e2e/cluster.go b/e2e/cluster.go index 589371e2d2b..d002befc17c 100644 --- a/e2e/cluster.go +++ b/e2e/cluster.go @@ -13,6 +13,7 @@ import ( "time" "github.com/Azure/agentbaker/e2e/config" + "github.com/Azure/agentbaker/e2e/dag" "github.com/Azure/agentbaker/e2e/toolkit" "github.com/Azure/azure-sdk-for-go/sdk/azcore" "github.com/Azure/azure-sdk-for-go/sdk/azcore/to" @@ -59,89 +60,82 @@ func (c *Cluster) MaxPodsPerNode() (int, error) { return 0, fmt.Errorf("cluster agentpool profiles were nil or empty: %+v", c.Model) } -func prepareCluster(ctx context.Context, cluster *armcontainerservice.ManagedCluster, isNetworkIsolated, attachPrivateAcr bool) (*Cluster, error) { +// prepareCluster runs all cluster preparation steps as a concurrent DAG. +// This function contains complex concurrent orchestration — keep it as +// minimal as possible and push all non-trivial logic into the individual +// task functions it calls. +func prepareCluster(ctx context.Context, clusterModel *armcontainerservice.ManagedCluster, isNetworkIsolated, attachPrivateAcr bool) (*Cluster, error) { defer toolkit.LogStepCtx(ctx, "preparing cluster")() ctx, cancel := context.WithTimeout(ctx, config.Config.TestTimeoutCluster) defer cancel() - cluster.Name = to.Ptr(fmt.Sprintf("%s-%s", *cluster.Name, hash(cluster))) - cluster, err := getOrCreateCluster(ctx, cluster) - if err != nil { - return nil, fmt.Errorf("get or create cluster: %w", err) - } - - bastion, err := getOrCreateBastion(ctx, cluster) - if err != nil { - return nil, fmt.Errorf("get or create bastion: %w", err) - } - - _, err = getOrCreateMaintenanceConfiguration(ctx, cluster) - if err != nil { - return nil, fmt.Errorf("get or create maintenance configuration: %w", err) - } - - subnetID, err := getClusterSubnetID(ctx, *cluster.Properties.NodeResourceGroup) - if err != nil { - return nil, fmt.Errorf("get cluster subnet: %w", err) - } - resourceGroupName := config.ResourceGroupName(*cluster.Location) + clusterModel.Name = to.Ptr(fmt.Sprintf("%s-%s", *clusterModel.Name, hash(clusterModel))) - kube, err := getClusterKubeClient(ctx, resourceGroupName, *cluster.Name) + cluster, err := getOrCreateCluster(ctx, clusterModel) if err != nil { - return nil, fmt.Errorf("get kube client using cluster %q: %w", *cluster.Name, err) + return nil, fmt.Errorf("get or create cluster: %w", err) } - kubeletIdentity, err := getClusterKubeletIdentity(cluster) - if err != nil { - return nil, fmt.Errorf("getting cluster kubelet identity: %w", err) - } + g := dag.NewGroup(ctx) - if isNetworkIsolated || attachPrivateAcr { - // private acr must be created before we add the debug daemonsets - if err := addPrivateAzureContainerRegistry(ctx, cluster, kube, resourceGroupName, kubeletIdentity, true); err != nil { - return nil, fmt.Errorf("add private azure container registry (true): %w", err) - } - if err := addPrivateAzureContainerRegistry(ctx, cluster, kube, resourceGroupName, kubeletIdentity, false); err != nil { - return nil, fmt.Errorf("add private azure container registry (false): %w", err) - } + bastion := dag.Go(g, func(ctx context.Context) (*Bastion, error) { + return getOrCreateBastion(ctx, cluster) + }) + dag.Run(g, func(ctx context.Context) error { return ensureMaintenanceConfiguration(ctx, cluster) }) + subnet := dag.Go(g, func(ctx context.Context) (string, error) { return getClusterSubnetID(ctx, cluster) }) + kube := dag.Go(g, func(ctx context.Context) (*Kubeclient, error) { return getClusterKubeClient(ctx, cluster) }) + identity := dag.Go(g, func(ctx context.Context) (*armcontainerservice.UserAssignedIdentity, error) { + return getClusterKubeletIdentity(ctx, cluster) + }) + dag.Run(g, func(ctx context.Context) error { return collectGarbageVMSS(ctx, cluster) }) + var networkDeps []dag.Dep + if !isNetworkIsolated { + networkDeps = append(networkDeps, dag.Run(g, func(ctx context.Context) error { return addFirewallRules(ctx, cluster) })) } if isNetworkIsolated { - if err := addNetworkIsolatedSettings(ctx, cluster, *cluster.Location); err != nil { - return nil, fmt.Errorf("add network isolated settings: %w", err) - } - } - if !isNetworkIsolated { // network isolated cluster blocks all egress via NSG - if err := addFirewallRules(ctx, cluster, *cluster.Location); err != nil { - return nil, fmt.Errorf("add firewall rules: %w", err) - } + networkDeps = append(networkDeps, dag.Run(g, func(ctx context.Context) error { return addNetworkIsolatedSettings(ctx, cluster) })) } + needACR := isNetworkIsolated || attachPrivateAcr + acrNonAnon := dag.Run2(g, kube, identity, addACR(cluster, needACR, true)) + acrAnon := dag.Run2(g, kube, identity, addACR(cluster, needACR, false)) + dag.Run1(g, kube, ensureDebugDaemonsets(cluster, isNetworkIsolated), append([]dag.Dep{acrNonAnon, acrAnon}, networkDeps...)...) + extract := dag.Go1(g, kube, extractClusterParams(cluster)) - if err := kube.EnsureDebugDaemonsets(ctx, isNetworkIsolated, config.GetPrivateACRName(true, *cluster.Location)); err != nil { - return nil, fmt.Errorf("ensure debug daemonsets for %q: %w", *cluster.Name, err) + if err := g.Wait(); err != nil { + return nil, fmt.Errorf("prepare cluster tasks: %w", err) } + return &Cluster{ + Model: cluster, + Kube: kube.MustGet(), + KubeletIdentity: identity.MustGet(), + SubnetID: subnet.MustGet(), + ClusterParams: extract.MustGet(), + Bastion: bastion.MustGet(), + }, nil +} - // sometimes tests can be interrupted and vmss are left behind - // don't waste resource and delete them - if err := collectGarbageVMSS(ctx, cluster); err != nil { - return nil, fmt.Errorf("collect garbage vmss: %w", err) +func addACR(cluster *armcontainerservice.ManagedCluster, needACR, isNonAnonymousPull bool) func(context.Context, *Kubeclient, *armcontainerservice.UserAssignedIdentity) error { + return func(ctx context.Context, k *Kubeclient, id *armcontainerservice.UserAssignedIdentity) error { + if !needACR { + return nil + } + return addPrivateAzureContainerRegistry(ctx, cluster, k, id, isNonAnonymousPull) } +} - clusterParams, err := extractClusterParameters(ctx, kube, cluster) - if err != nil { - return nil, fmt.Errorf("extracting cluster parameters: %w", err) +func ensureDebugDaemonsets(cluster *armcontainerservice.ManagedCluster, isNetworkIsolated bool) func(context.Context, *Kubeclient) error { + return func(ctx context.Context, k *Kubeclient) error { + return k.EnsureDebugDaemonsets(ctx, isNetworkIsolated, config.GetPrivateACRName(true, *cluster.Location)) } +} - return &Cluster{ - Model: cluster, - Kube: kube, - KubeletIdentity: kubeletIdentity, - SubnetID: subnetID, - ClusterParams: clusterParams, - Bastion: bastion, - }, nil +func extractClusterParams(cluster *armcontainerservice.ManagedCluster) func(context.Context, *Kubeclient) (*ClusterParams, error) { + return func(ctx context.Context, k *Kubeclient) (*ClusterParams, error) { + return extractClusterParameters(ctx, cluster, k) + } } -func getClusterKubeletIdentity(cluster *armcontainerservice.ManagedCluster) (*armcontainerservice.UserAssignedIdentity, error) { +func getClusterKubeletIdentity(ctx context.Context, cluster *armcontainerservice.ManagedCluster) (*armcontainerservice.UserAssignedIdentity, error) { if cluster == nil || cluster.Properties == nil || cluster.Properties.IdentityProfile == nil { return nil, fmt.Errorf("cannot dereference cluster identity profile to extract kubelet identity ID") } @@ -152,7 +146,7 @@ func getClusterKubeletIdentity(cluster *armcontainerservice.ManagedCluster) (*ar return kubeletIdentity, nil } -func extractClusterParameters(ctx context.Context, kube *Kubeclient, cluster *armcontainerservice.ManagedCluster) (*ClusterParams, error) { +func extractClusterParameters(ctx context.Context, cluster *armcontainerservice.ManagedCluster, kube *Kubeclient) (*ClusterParams, error) { kubeconfig, err := clientcmd.Load(kube.KubeConfig) if err != nil { return nil, fmt.Errorf("loading cluster kubeconfig: %w", err) @@ -423,16 +417,20 @@ func createNewAKSClusterWithRetry(ctx context.Context, cluster *armcontainerserv return nil, fmt.Errorf("failed to create cluster after %d attempts due to persistent 409 Conflict: %w", maxRetries, lastErr) } -func getOrCreateMaintenanceConfiguration(ctx context.Context, cluster *armcontainerservice.ManagedCluster) (*armcontainerservice.MaintenanceConfiguration, error) { - existingMaintenance, err := config.Azure.Maintenance.Get(ctx, config.ResourceGroupName(*cluster.Location), *cluster.Name, "default", nil) +func ensureMaintenanceConfiguration(ctx context.Context, cluster *armcontainerservice.ManagedCluster) error { + _, err := config.Azure.Maintenance.Get(ctx, config.ResourceGroupName(*cluster.Location), *cluster.Name, "default", nil) var azErr *azcore.ResponseError if errors.As(err, &azErr) && azErr.StatusCode == 404 { - return createNewMaintenanceConfiguration(ctx, cluster) + _, err = createNewMaintenanceConfiguration(ctx, cluster) + if err != nil { + return fmt.Errorf("creating maintenance configuration for cluster %q: %w", *cluster.Name, err) + } + return nil } if err != nil { - return nil, fmt.Errorf("failed to get maintenance configuration 'default' for cluster %q: %w", *cluster.Name, err) + return fmt.Errorf("failed to get maintenance configuration 'default' for cluster %q: %w", *cluster.Name, err) } - return &existingMaintenance.MaintenanceConfiguration, nil + return nil } func createNewMaintenanceConfiguration(ctx context.Context, cluster *armcontainerservice.ManagedCluster) (*armcontainerservice.MaintenanceConfiguration, error) { diff --git a/e2e/dag/dag.go b/e2e/dag/dag.go new file mode 100644 index 00000000000..e302751cef1 --- /dev/null +++ b/e2e/dag/dag.go @@ -0,0 +1,391 @@ +// Package dag provides a lightweight, type-safe DAG executor for running +// concurrent tasks with dependency tracking. +// +// Tasks are registered against a [Group] and form a directed acyclic graph +// through their declared dependencies. The Group launches each task in its +// own goroutine as soon as all dependencies complete successfully. +// +// There are two kinds of tasks: +// +// - Value-producing tasks return (T, error) and are represented by [Result][T]. +// Register with [Go] (no typed deps) or [Go1] / [Go2] / [Go3] (typed deps). +// +// - Side-effect tasks return only error and are represented by [Effect]. +// Register with [Run] (no typed deps) or [Run1] / [Run2] / [Run3] (typed deps). +// +// Both [Result] and [Effect] implement [Dep], so they can be listed as +// dependencies of downstream tasks. +// +// When a typed dependency is used (Go1–Go3 / Run1–Run3 variants), the +// dependency's value is passed as a function parameter — the compiler +// enforces correct wiring. When untyped dependencies are used (Go/Run +// with variadic deps), values are accessed via [Result.MustGet] inside +// the closure. +// +// On the first task error, the Group cancels its context, causing all pending +// and in-flight tasks to observe cancellation and exit. [Group.Wait] blocks +// until every goroutine returns and reports a [DAGError] containing all +// collected errors. Panics inside task functions are recovered and surfaced +// as errors; they never crash the process. +// +// Note: the package does not perform runtime cycle detection. Dependency +// cycles created via the untyped [Go]/[Run] API (where a task's closure +// captures a [Result] that transitively depends on itself) will cause +// [Group.Wait] to deadlock. The typed variants (Go1–Go3, Run1–Run3) make +// cycles harder to construct but do not prevent them entirely. +// +// Example: +// +// g := dag.NewGroup(ctx) +// +// kube := dag.Go(g, func(ctx context.Context) (*Kubeclient, error) { +// return getKubeClient(ctx) +// }) +// params := dag.Go1(g, kube, func(ctx context.Context, k *Kubeclient) (*Params, error) { +// return extractParams(ctx, k) +// }) +// dag.Run(g, func(ctx context.Context) error { +// return ensureMaintenance(ctx) +// }) +// +// if err := g.Wait(); err != nil { ... } +// fmt.Println(kube.MustGet(), params.MustGet()) +package dag + +import ( + "context" + "errors" + "fmt" + "sort" + "strings" + "sync" +) + +// --------------------------------------------------------------------------- +// Dep — the dependency interface +// --------------------------------------------------------------------------- + +// Dep is implemented by [Result] and [Effect]. It represents a dependency +// that must complete before a downstream task starts. +type Dep interface { + wait() + failed() bool +} + +// --------------------------------------------------------------------------- +// Group — the DAG executor +// --------------------------------------------------------------------------- + +// Group manages a set of concurrent tasks with dependency tracking. +// Create one with [NewGroup], register tasks, then call [Group.Wait]. +type Group struct { + ctx context.Context + cancel context.CancelFunc + + mu sync.Mutex + errs []error + wg sync.WaitGroup +} + +// NewGroup creates a Group whose tasks run under ctx. +// On the first task error the Group cancels ctx, signalling all other tasks. +func NewGroup(ctx context.Context) *Group { + ctx, cancel := context.WithCancel(ctx) + return &Group{ctx: ctx, cancel: cancel} +} + +// Wait blocks until every task in the group has finished. +// It returns a *[DAGError] if any task failed, the parent context's error +// if it was cancelled before tasks could run, or nil on success. +func (g *Group) Wait() error { + g.wg.Wait() + // Capture ctx error before cancel() — after cancel(), ctx.Err() is + // always non-nil regardless of whether the parent was cancelled. + ctxErr := g.ctx.Err() + g.cancel() + g.mu.Lock() + defer g.mu.Unlock() + if len(g.errs) > 0 { + return &DAGError{Errors: g.errs} + } + return ctxErr +} + +func (g *Group) recordError(err error) { + g.mu.Lock() + g.errs = append(g.errs, err) + g.mu.Unlock() + g.cancel() +} + +// errSkipped is a sentinel set on tasks that were skipped because a +// dependency failed. It propagates through the graph so that transitive +// dependents are also skipped without running. +var errSkipped = errors.New("skipped: dependency failed") + +// launch runs fn in a new goroutine after all deps complete. +// If any dep failed, ctx is cancelled, or fn panics, onFail is called +// with the relevant error instead of (or after) fn. +func (g *Group) launch(deps []Dep, fn func(), onFail func(error)) { + g.wg.Add(1) + go func() { + defer g.wg.Done() + defer func() { + if r := recover(); r != nil { + var err error + if e, ok := r.(error); ok { + err = fmt.Errorf("dag: task panicked: %w", e) + } else { + err = fmt.Errorf("dag: task panicked: %v", r) + } + g.recordError(err) + onFail(err) + } + }() + + for _, d := range deps { + d.wait() + } + + for _, d := range deps { + if d.failed() { + onFail(errSkipped) + return + } + } + + if g.ctx.Err() != nil { + onFail(errSkipped) + return + } + + fn() + }() +} + +// --------------------------------------------------------------------------- +// Result[T] — a typed task output +// --------------------------------------------------------------------------- + +// Result holds the outcome of a task that produces a value of type T. +// It implements [Dep] so it can be used as a dependency for downstream tasks. +type Result[T any] struct { + done chan struct{} + val T + err error +} + +func newResult[T any]() *Result[T] { + return &Result[T]{done: make(chan struct{})} +} + +func (r *Result[T]) wait() { <-r.done } +func (r *Result[T]) failed() bool { r.wait(); return r.err != nil } + +// Get returns the value and true if the task succeeded, or the zero value +// and false if it failed or was skipped. Blocks until the task completes. +func (r *Result[T]) Get() (T, bool) { + <-r.done + if r.err != nil { + var zero T + return zero, false + } + return r.val, true +} + +// MustGet returns the value, panicking if the task failed. Safe to call: +// - Inside Go1/Go2/Go3 or Run1/Run2/Run3 callbacks (the scheduler +// guarantees deps succeeded before invoking the function). +// - Inside Go/Run callbacks when the Result is listed as a dep. +// - After [Group.Wait] returned nil. +func (r *Result[T]) MustGet() T { + <-r.done + if r.err != nil { + panic("dag: MustGet() called on failed Result") + } + return r.val +} + +func (r *Result[T]) finish(val T, err error) { + r.val = val + r.err = err + close(r.done) +} + +// --------------------------------------------------------------------------- +// Effect — a side-effect-only task +// --------------------------------------------------------------------------- + +// Effect represents a completed side-effect task. It implements [Dep] so +// downstream tasks can depend on it, but it carries no value. +type Effect struct { + done chan struct{} + err error +} + +func newEffect() *Effect { + return &Effect{done: make(chan struct{})} +} + +func (e *Effect) wait() { <-e.done } +func (e *Effect) failed() bool { e.wait(); return e.err != nil } + +func (e *Effect) finish(err error) { + e.err = err + close(e.done) +} + +// --------------------------------------------------------------------------- +// Go / Go1 / Go2 / Go3 — value-producing tasks. +// +// Go = no typed deps (optional untyped deps via variadic Dep args) +// GoN = N typed deps, passed as function parameters +// --------------------------------------------------------------------------- + +// Go launches fn with optional untyped deps. Values from deps are accessed +// via [Result.MustGet] inside fn. +func Go[T any](g *Group, fn func(ctx context.Context) (T, error), deps ...Dep) *Result[T] { + r := newResult[T]() + g.launch(deps, func() { + val, err := fn(g.ctx) + if err != nil { + g.recordError(err) + } + r.finish(val, err) + }, func(err error) { + var zero T + r.finish(zero, err) + }) + return r +} + +// Go1 launches fn after dep completes, passing its value. +// Extra deps are waited on but their values are not passed to fn. +func Go1[T, D1 any](g *Group, dep *Result[D1], fn func(ctx context.Context, d1 D1) (T, error), extra ...Dep) *Result[T] { + r := newResult[T]() + g.launch(append([]Dep{dep}, extra...), func() { + val, err := fn(g.ctx, dep.val) + if err != nil { + g.recordError(err) + } + r.finish(val, err) + }, func(err error) { + var zero T + r.finish(zero, err) + }) + return r +} + +// Go2 launches fn after dep1 and dep2 complete, passing both values. +// Extra deps are waited on but their values are not passed to fn. +func Go2[T, D1, D2 any](g *Group, dep1 *Result[D1], dep2 *Result[D2], fn func(ctx context.Context, d1 D1, d2 D2) (T, error), extra ...Dep) *Result[T] { + r := newResult[T]() + g.launch(append([]Dep{dep1, dep2}, extra...), func() { + val, err := fn(g.ctx, dep1.val, dep2.val) + if err != nil { + g.recordError(err) + } + r.finish(val, err) + }, func(err error) { + var zero T + r.finish(zero, err) + }) + return r +} + +// Go3 launches fn after dep1, dep2, and dep3 complete, passing all values. +// Extra deps are waited on but their values are not passed to fn. +func Go3[T, D1, D2, D3 any](g *Group, dep1 *Result[D1], dep2 *Result[D2], dep3 *Result[D3], fn func(ctx context.Context, d1 D1, d2 D2, d3 D3) (T, error), extra ...Dep) *Result[T] { + r := newResult[T]() + g.launch(append([]Dep{dep1, dep2, dep3}, extra...), func() { + val, err := fn(g.ctx, dep1.val, dep2.val, dep3.val) + if err != nil { + g.recordError(err) + } + r.finish(val, err) + }, func(err error) { + var zero T + r.finish(zero, err) + }) + return r +} + +// --------------------------------------------------------------------------- +// Run / Run1 / Run2 / Run3 — side-effect tasks (no return value). +// +// Run = no typed deps (optional untyped deps via variadic Dep args) +// RunN = N typed deps, passed as function parameters +// --------------------------------------------------------------------------- + +// Run launches fn with optional untyped deps. +func Run(g *Group, fn func(ctx context.Context) error, deps ...Dep) *Effect { + e := newEffect() + g.launch(deps, func() { + err := fn(g.ctx) + if err != nil { + g.recordError(err) + } + e.finish(err) + }, e.finish) + return e +} + +// Run1 launches fn after dep completes, passing its value. +// Extra deps are waited on but their values are not passed to fn. +func Run1[D1 any](g *Group, dep *Result[D1], fn func(ctx context.Context, d1 D1) error, extra ...Dep) *Effect { + e := newEffect() + g.launch(append([]Dep{dep}, extra...), func() { + err := fn(g.ctx, dep.val) + if err != nil { + g.recordError(err) + } + e.finish(err) + }, e.finish) + return e +} + +// Run2 launches fn after dep1 and dep2 complete, passing both values. +// Extra deps are waited on but their values are not passed to fn. +func Run2[D1, D2 any](g *Group, dep1 *Result[D1], dep2 *Result[D2], fn func(ctx context.Context, d1 D1, d2 D2) error, extra ...Dep) *Effect { + e := newEffect() + g.launch(append([]Dep{dep1, dep2}, extra...), func() { + err := fn(g.ctx, dep1.val, dep2.val) + if err != nil { + g.recordError(err) + } + e.finish(err) + }, e.finish) + return e +} + +// Run3 launches fn after dep1, dep2, and dep3 complete, passing all values. +// Extra deps are waited on but their values are not passed to fn. +func Run3[D1, D2, D3 any](g *Group, dep1 *Result[D1], dep2 *Result[D2], dep3 *Result[D3], fn func(ctx context.Context, d1 D1, d2 D2, d3 D3) error, extra ...Dep) *Effect { + e := newEffect() + g.launch(append([]Dep{dep1, dep2, dep3}, extra...), func() { + err := fn(g.ctx, dep1.val, dep2.val, dep3.val) + if err != nil { + g.recordError(err) + } + e.finish(err) + }, e.finish) + return e +} + +// --------------------------------------------------------------------------- +// DAGError +// --------------------------------------------------------------------------- + +// DAGError is returned by [Group.Wait] when one or more tasks failed. +type DAGError struct { + Errors []error +} + +func (e *DAGError) Error() string { + var msgs []string + for _, err := range e.Errors { + msgs = append(msgs, err.Error()) + } + sort.Strings(msgs) + return fmt.Sprintf("dag execution failed: %s", strings.Join(msgs, "; ")) +} diff --git a/e2e/dag/dag_test.go b/e2e/dag/dag_test.go new file mode 100644 index 00000000000..23530103693 --- /dev/null +++ b/e2e/dag/dag_test.go @@ -0,0 +1,722 @@ +package dag + +import ( + "context" + "errors" + "strings" + "sync/atomic" + "testing" + "time" +) + +func TestGo(t *testing.T) { + g := NewGroup(context.Background()) + r := Go(g, func(ctx context.Context) (int, error) { + return 42, nil + }) + if err := g.Wait(); err != nil { + t.Fatal(err) + } + if v := r.MustGet(); v != 42 { + t.Fatalf("got %d, want 42", v) + } +} + +func TestGo_Error(t *testing.T) { + g := NewGroup(context.Background()) + r := Go(g, func(ctx context.Context) (int, error) { + return 0, errors.New("boom") + }) + err := g.Wait() + if err == nil { + t.Fatal("expected error") + } + var dagErr *DAGError + if !errors.As(err, &dagErr) { + t.Fatalf("expected *DAGError, got %T", err) + } + if len(dagErr.Errors) != 1 { + t.Fatalf("expected 1 error, got %d", len(dagErr.Errors)) + } + if _, ok := r.Get(); ok { + t.Fatal("Get() should return false on failed Result") + } +} + +func TestGo_WithDeps(t *testing.T) { + g := NewGroup(context.Background()) + a := Go(g, func(ctx context.Context) (int, error) { return 10, nil }) + b := Go(g, func(ctx context.Context) (string, error) { return "hello", nil }) + + c := Go(g, func(ctx context.Context) (string, error) { + return b.MustGet() + ":" + string(rune('0'+a.MustGet())), nil + }, a, b) + + if err := g.Wait(); err != nil { + t.Fatal(err) + } + _ = c.MustGet() +} + +func TestRun(t *testing.T) { + g := NewGroup(context.Background()) + var called atomic.Bool + Run(g, func(ctx context.Context) error { + called.Store(true) + return nil + }) + if err := g.Wait(); err != nil { + t.Fatal(err) + } + if !called.Load() { + t.Fatal("Do function was not called") + } +} + +func TestRun_WithDeps(t *testing.T) { + g := NewGroup(context.Background()) + a := Go(g, func(ctx context.Context) (int, error) { return 10, nil }) + b := Go(g, func(ctx context.Context) (string, error) { return "hi", nil }) + + var got atomic.Value + Run(g, func(ctx context.Context) error { + got.Store(b.MustGet() + ":" + string(rune('0'+a.MustGet()))) + return nil + }, a, b) + + if err := g.Wait(); err != nil { + t.Fatal(err) + } + if got.Load() == nil { + t.Fatal("Do function was not called") + } +} + +func TestGo1_Chain(t *testing.T) { + g := NewGroup(context.Background()) + a := Go(g, func(ctx context.Context) (int, error) { + return 10, nil + }) + b := Go1(g, a, func(ctx context.Context, v int) (int, error) { + return v * 2, nil + }) + c := Go1(g, b, func(ctx context.Context, v int) (string, error) { + if v != 20 { + return "", errors.New("bad value") + } + return "ok", nil + }) + if err := g.Wait(); err != nil { + t.Fatal(err) + } + if v := c.MustGet(); v != "ok" { + t.Fatalf("got %q, want %q", v, "ok") + } +} + +func TestGo2(t *testing.T) { + g := NewGroup(context.Background()) + a := Go(g, func(ctx context.Context) (int, error) { return 3, nil }) + b := Go(g, func(ctx context.Context) (int, error) { return 4, nil }) + c := Go2(g, a, b, func(ctx context.Context, x, y int) (int, error) { + return x + y, nil + }) + if err := g.Wait(); err != nil { + t.Fatal(err) + } + if v := c.MustGet(); v != 7 { + t.Fatalf("got %d, want 7", v) + } +} + +func TestGo3(t *testing.T) { + g := NewGroup(context.Background()) + a := Go(g, func(ctx context.Context) (int, error) { return 1, nil }) + b := Go(g, func(ctx context.Context) (int, error) { return 2, nil }) + c := Go(g, func(ctx context.Context) (int, error) { return 3, nil }) + d := Go3(g, a, b, c, func(ctx context.Context, x, y, z int) (int, error) { + return x + y + z, nil + }) + if err := g.Wait(); err != nil { + t.Fatal(err) + } + if v := d.MustGet(); v != 6 { + t.Fatalf("got %d, want 6", v) + } +} + +func TestRun1(t *testing.T) { + g := NewGroup(context.Background()) + a := Go(g, func(ctx context.Context) (int, error) { return 5, nil }) + var got atomic.Int32 + Run1(g, a, func(ctx context.Context, v int) error { + got.Store(int32(v)) + return nil + }) + if err := g.Wait(); err != nil { + t.Fatal(err) + } + if got.Load() != 5 { + t.Fatalf("got %d, want 5", got.Load()) + } +} + +func TestRun2(t *testing.T) { + g := NewGroup(context.Background()) + a := Go(g, func(ctx context.Context) (int, error) { return 3, nil }) + b := Go(g, func(ctx context.Context) (string, error) { return "x", nil }) + var got atomic.Value + Run2(g, a, b, func(ctx context.Context, n int, s string) error { + got.Store(s + string(rune('0'+n))) + return nil + }) + if err := g.Wait(); err != nil { + t.Fatal(err) + } + if got.Load() != "x3" { + t.Fatalf("got %v, want x3", got.Load()) + } +} + +func TestRun3(t *testing.T) { + g := NewGroup(context.Background()) + a := Go(g, func(ctx context.Context) (int, error) { return 1, nil }) + b := Go(g, func(ctx context.Context) (int, error) { return 2, nil }) + c := Go(g, func(ctx context.Context) (int, error) { return 3, nil }) + var got atomic.Int32 + Run3(g, a, b, c, func(ctx context.Context, x, y, z int) error { + got.Store(int32(x + y + z)) + return nil + }) + if err := g.Wait(); err != nil { + t.Fatal(err) + } + if got.Load() != 6 { + t.Fatalf("got %d, want 6", got.Load()) + } +} + +func TestCancelAll_CancelsRunningTasks(t *testing.T) { + g := NewGroup(context.Background()) + + started := make(chan struct{}) + var cancelled atomic.Bool + Run(g, func(ctx context.Context) error { + close(started) + <-ctx.Done() + cancelled.Store(true) + return ctx.Err() + }) + + Go(g, func(ctx context.Context) (int, error) { + <-started + return 0, errors.New("fail") + }) + + g.Wait() + if !cancelled.Load() { + t.Fatal("expected context to be cancelled for running task") + } +} + +func TestSkipsDownstream(t *testing.T) { + g := NewGroup(context.Background()) + a := Go(g, func(ctx context.Context) (int, error) { + return 0, errors.New("a failed") + }) + + var bRan atomic.Bool + Run1(g, a, func(ctx context.Context, v int) error { + bRan.Store(true) + return nil + }) + + g.Wait() + if bRan.Load() { + t.Fatal("dependent task b should have been skipped") + } +} + +func TestTransitiveSkip(t *testing.T) { + g := NewGroup(context.Background()) + a := Go(g, func(ctx context.Context) (int, error) { + return 0, errors.New("a failed") + }) + b := Go1(g, a, func(ctx context.Context, v int) (int, error) { + return v + 1, nil + }) + + var cRan atomic.Bool + Run1(g, b, func(ctx context.Context, v int) error { + cRan.Store(true) + return nil + }) + + g.Wait() + if cRan.Load() { + t.Fatal("transitive dependent c should have been skipped") + } +} + +func TestDiamond(t *testing.T) { + // a + // / \ + // b c + // \ / + // d + g := NewGroup(context.Background()) + a := Go(g, func(ctx context.Context) (int, error) { return 1, nil }) + b := Go1(g, a, func(ctx context.Context, v int) (int, error) { return v + 10, nil }) + c := Go1(g, a, func(ctx context.Context, v int) (int, error) { return v + 100, nil }) + d := Go2(g, b, c, func(ctx context.Context, x, y int) (int, error) { return x + y, nil }) + + if err := g.Wait(); err != nil { + t.Fatal(err) + } + if d.MustGet() != 112 { + t.Fatalf("got %d, want 112", d.MustGet()) + } +} + +func TestGet_SafeOnError(t *testing.T) { + g := NewGroup(context.Background()) + r := Go(g, func(ctx context.Context) (int, error) { + return 0, errors.New("boom") + }) + g.Wait() + + v, ok := r.Get() + if ok { + t.Fatal("Get() should return false on error") + } + if v != 0 { + t.Fatalf("Get() should return zero value, got %d", v) + } +} + +func TestGet_SuccessPath(t *testing.T) { + g := NewGroup(context.Background()) + r := Go(g, func(ctx context.Context) (string, error) { + return "hello", nil + }) + if err := g.Wait(); err != nil { + t.Fatal(err) + } + v, ok := r.Get() + if !ok { + t.Fatal("Get() should return true on success") + } + if v != "hello" { + t.Fatalf("got %q, want %q", v, "hello") + } +} + +func TestMustGet_PanicsOnError(t *testing.T) { + g := NewGroup(context.Background()) + r := Go(g, func(ctx context.Context) (int, error) { + return 0, errors.New("boom") + }) + g.Wait() + + defer func() { + if r := recover(); r == nil { + t.Fatal("expected panic from MustGet() on failed Result") + } + }() + r.MustGet() +} + +func TestMultipleErrors(t *testing.T) { + g := NewGroup(context.Background()) + Go(g, func(ctx context.Context) (int, error) { return 0, errors.New("err1") }) + Go(g, func(ctx context.Context) (int, error) { return 0, errors.New("err2") }) + + err := g.Wait() + var dagErr *DAGError + if !errors.As(err, &dagErr) { + t.Fatalf("expected *DAGError, got %T", err) + } + if len(dagErr.Errors) < 1 { + t.Fatal("expected at least 1 error") + } +} + +func TestParentContextCancelled(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + g := NewGroup(ctx) + Run(g, func(ctx context.Context) error { return nil }) + + err := g.Wait() + if !errors.Is(err, context.Canceled) { + t.Fatalf("expected context.Canceled, got %v", err) + } +} + +func TestEffect_AsDep(t *testing.T) { + g := NewGroup(context.Background()) + + // The dependency edge provides a happens-before guarantee. + var order []int + e := Run(g, func(ctx context.Context) error { + order = append(order, 1) + return nil + }) + Run(g, func(ctx context.Context) error { + if len(order) == 0 { + return errors.New("expected effect to run first") + } + return nil + }, e) + + if err := g.Wait(); err != nil { + t.Fatal(err) + } +} + +func TestEmptyGroup(t *testing.T) { + g := NewGroup(context.Background()) + if err := g.Wait(); err != nil { + t.Fatalf("expected nil, got %v", err) + } +} + +func TestWait_NilOnSuccess(t *testing.T) { + g := NewGroup(context.Background()) + Go(g, func(ctx context.Context) (int, error) { return 1, nil }) + Run(g, func(ctx context.Context) error { return nil }) + + if err := g.Wait(); err != nil { + t.Fatalf("expected nil, got %v", err) + } +} + +// TestWait_DAGErrorPriorityOverCtxErr verifies that when a task fails +// (causing internal cancel), Wait() returns *DAGError not context.Canceled. +func TestWait_DAGErrorPriorityOverCtxErr(t *testing.T) { + g := NewGroup(context.Background()) + Go(g, func(ctx context.Context) (int, error) { + return 0, errors.New("task failed") + }) + Run(g, func(ctx context.Context) error { + <-ctx.Done() + return ctx.Err() + }) + + err := g.Wait() + if err == nil { + t.Fatal("expected error") + } + var dagErr *DAGError + if !errors.As(err, &dagErr) { + t.Fatalf("expected *DAGError, got %T: %v", err, err) + } + found := false + for _, e := range dagErr.Errors { + if e.Error() == "task failed" { + found = true + } + } + if !found { + t.Fatalf("expected 'task failed' in DAGError.Errors, got: %v", dagErr.Errors) + } +} + +func TestWait_ParentContextDeadlineExceeded(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 1*time.Millisecond) + defer cancel() + time.Sleep(5 * time.Millisecond) + + g := NewGroup(ctx) + Run(g, func(ctx context.Context) error { return nil }) + + err := g.Wait() + if !errors.Is(err, context.DeadlineExceeded) { + t.Fatalf("expected context.DeadlineExceeded, got %v", err) + } +} + +// TestCancellationNoise verifies that when task A fails and task B returns +// ctx.Err() as noise, the real error is still present in DAGError. +func TestCancellationNoise(t *testing.T) { + g := NewGroup(context.Background()) + started := make(chan struct{}) + + Go(g, func(ctx context.Context) (int, error) { + <-started + return 0, errors.New("real error") + }) + Run(g, func(ctx context.Context) error { + close(started) + <-ctx.Done() + return ctx.Err() + }) + + err := g.Wait() + var dagErr *DAGError + if !errors.As(err, &dagErr) { + t.Fatalf("expected *DAGError, got %T: %v", err, err) + } + hasRealErr := false + for _, e := range dagErr.Errors { + if e.Error() == "real error" { + hasRealErr = true + } + } + if !hasRealErr { + t.Fatalf("real error not found in DAGError.Errors: %v", dagErr.Errors) + } + if !strings.Contains(dagErr.Error(), "real error") { + t.Fatalf("Error() should mention 'real error': %s", dagErr.Error()) + } +} + +func TestDAGError_Error(t *testing.T) { + dagErr := &DAGError{Errors: []error{ + errors.New("beta"), + errors.New("alpha"), + }} + got := dagErr.Error() + want := "dag execution failed: alpha; beta" + if got != want { + t.Fatalf("got %q, want %q", got, want) + } +} + +func TestDAGError_ErrorSingle(t *testing.T) { + dagErr := &DAGError{Errors: []error{errors.New("only")}} + got := dagErr.Error() + want := "dag execution failed: only" + if got != want { + t.Fatalf("got %q, want %q", got, want) + } +} + +func TestGo1_SkipOnDepFailure(t *testing.T) { + g := NewGroup(context.Background()) + a := Go(g, func(ctx context.Context) (int, error) { return 0, errors.New("a failed") }) + + var ran atomic.Bool + Go1(g, a, func(ctx context.Context, v int) (string, error) { + ran.Store(true) + return "", nil + }) + + g.Wait() + if ran.Load() { + t.Fatal("Go1 callback should have been skipped when dep failed") + } +} + +func TestGo2_SkipOnDepFailure(t *testing.T) { + g := NewGroup(context.Background()) + a := Go(g, func(ctx context.Context) (int, error) { return 0, errors.New("a failed") }) + b := Go(g, func(ctx context.Context) (int, error) { return 2, nil }) + + var ran atomic.Bool + Go2(g, a, b, func(ctx context.Context, x, y int) (int, error) { + ran.Store(true) + return x + y, nil + }) + + g.Wait() + if ran.Load() { + t.Fatal("Go2 callback should have been skipped when dep failed") + } +} + +func TestGo3_SkipOnDepFailure(t *testing.T) { + g := NewGroup(context.Background()) + a := Go(g, func(ctx context.Context) (int, error) { return 1, nil }) + b := Go(g, func(ctx context.Context) (int, error) { return 0, errors.New("b failed") }) + c := Go(g, func(ctx context.Context) (int, error) { return 3, nil }) + + var ran atomic.Bool + Go3(g, a, b, c, func(ctx context.Context, x, y, z int) (int, error) { + ran.Store(true) + return x + y + z, nil + }) + + g.Wait() + if ran.Load() { + t.Fatal("Go3 callback should have been skipped when dep failed") + } +} + +func TestRun1_SkipOnDepFailure(t *testing.T) { + g := NewGroup(context.Background()) + a := Go(g, func(ctx context.Context) (int, error) { return 0, errors.New("a failed") }) + + var ran atomic.Bool + Run1(g, a, func(ctx context.Context, v int) error { + ran.Store(true) + return nil + }) + + g.Wait() + if ran.Load() { + t.Fatal("Run1 callback should have been skipped when dep failed") + } +} + +func TestRun2_SkipOnDepFailure(t *testing.T) { + g := NewGroup(context.Background()) + a := Go(g, func(ctx context.Context) (int, error) { return 0, errors.New("a failed") }) + b := Go(g, func(ctx context.Context) (string, error) { return "x", nil }) + + var ran atomic.Bool + Run2(g, a, b, func(ctx context.Context, n int, s string) error { + ran.Store(true) + return nil + }) + + g.Wait() + if ran.Load() { + t.Fatal("Run2 callback should have been skipped when dep failed") + } +} + +func TestRun3_SkipOnDepFailure(t *testing.T) { + g := NewGroup(context.Background()) + a := Go(g, func(ctx context.Context) (int, error) { return 1, nil }) + b := Go(g, func(ctx context.Context) (int, error) { return 2, nil }) + c := Go(g, func(ctx context.Context) (int, error) { return 0, errors.New("c failed") }) + + var ran atomic.Bool + Run3(g, a, b, c, func(ctx context.Context, x, y, z int) error { + ran.Store(true) + return nil + }) + + g.Wait() + if ran.Load() { + t.Fatal("Run3 callback should have been skipped when dep failed") + } +} + +// TestCycle_TypedAPI_Impossible documents that the typed Go1/Go2/Go3 API +// makes cyclic dependencies impossible at compile time: you cannot pass a +// *Result before it is declared. +func TestCycle_TypedAPI_Impossible(t *testing.T) { + // a := Go1(g, b, ...) // b not declared yet — won't compile + // b := Go1(g, a, ...) +} + +// TestCycle_UntypedAPI_Deadlocks verifies that a never-completing dep +// (simulating a cycle) causes Wait() to deadlock. Context cancellation +// does not unblock d.wait() since it's a plain channel read. +// Prefer Go1-Go3 / Run1-Run3 for compile-time cycle safety. +func TestCycle_UntypedAPI_Deadlocks(t *testing.T) { + g := NewGroup(context.Background()) + placeholder := newEffect() + + cycleTask := Go(g, func(ctx context.Context) (int, error) { + return 42, nil + }, placeholder) + + waitDone := make(chan error, 1) + go func() { + waitDone <- g.Wait() + }() + + select { + case <-waitDone: + t.Fatal("Wait() should not have returned — expected deadlock") + case <-time.After(100 * time.Millisecond): + placeholder.finish(errSkipped) + <-waitDone + if _, ok := cycleTask.Get(); ok { + t.Fatal("cyclic task should not have succeeded") + } + } +} + +// TestCycle_SelfDependency verifies that a task depending on a never-completed +// dep (simulating a self-reference) deadlocks Wait(). +func TestCycle_SelfDependency(t *testing.T) { + g := NewGroup(context.Background()) + blocker := newResult[int]() + + var ran atomic.Bool + Run(g, func(ctx context.Context) error { + ran.Store(true) + return nil + }, blocker) + + waitDone := make(chan error, 1) + go func() { + waitDone <- g.Wait() + }() + + select { + case <-waitDone: + t.Fatal("Wait() should not have returned — expected deadlock") + case <-time.After(100 * time.Millisecond): + blocker.finish(0, errSkipped) + <-waitDone + if ran.Load() { + t.Fatal("task should not have run — dep failed") + } + } +} + +// TestPanic_GoTask verifies that a panic in a Go task is recovered and +// surfaced as an error via Wait(), not crashing the process. +func TestPanic_GoTask(t *testing.T) { + g := NewGroup(context.Background()) + r := Go(g, func(ctx context.Context) (int, error) { + panic("unexpected nil pointer") + }) + + err := g.Wait() + if err == nil { + t.Fatal("expected error from panicking task") + } + if !strings.Contains(err.Error(), "task panicked") { + t.Fatalf("expected panic error, got: %v", err) + } + if !strings.Contains(err.Error(), "unexpected nil pointer") { + t.Fatalf("expected panic message in error, got: %v", err) + } + // Result should be marked as failed. + if _, ok := r.Get(); ok { + t.Fatal("Get() should return false on panicked task") + } +} + +// TestPanic_RunTask verifies that a panic in a Run task is recovered. +func TestPanic_RunTask(t *testing.T) { + g := NewGroup(context.Background()) + Run(g, func(ctx context.Context) error { + panic(errors.New("wrapped panic")) + }) + + err := g.Wait() + if err == nil { + t.Fatal("expected error from panicking task") + } + if !strings.Contains(err.Error(), "task panicked") { + t.Fatalf("expected panic error, got: %v", err) + } +} + +// TestPanic_SkipsDownstream verifies that a panic in an upstream task +// causes downstream tasks to be skipped (not run). +func TestPanic_SkipsDownstream(t *testing.T) { + g := NewGroup(context.Background()) + upstream := Go(g, func(ctx context.Context) (int, error) { + panic("boom") + }) + + var ran atomic.Bool + Go1(g, upstream, func(ctx context.Context, v int) (string, error) { + ran.Store(true) + return "", nil + }) + + g.Wait() + if ran.Load() { + t.Fatal("downstream task should have been skipped after upstream panic") + } +} diff --git a/e2e/kube.go b/e2e/kube.go index 3b551f6f370..150d58fc31f 100644 --- a/e2e/kube.go +++ b/e2e/kube.go @@ -12,6 +12,7 @@ import ( "github.com/Azure/agentbaker/e2e/config" "github.com/Azure/agentbaker/e2e/toolkit" "github.com/Azure/azure-sdk-for-go/sdk/azcore/to" + "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/containerservice/armcontainerservice/v8" "github.com/stretchr/testify/require" appsv1 "k8s.io/api/apps/v1" corev1 "k8s.io/api/core/v1" @@ -40,7 +41,9 @@ const ( podNetworkDebugAppLabel = "debugnonhost-mariner-tolerated" ) -func getClusterKubeClient(ctx context.Context, resourceGroupName, clusterName string) (*Kubeclient, error) { +func getClusterKubeClient(ctx context.Context, cluster *armcontainerservice.ManagedCluster) (*Kubeclient, error) { + resourceGroupName := config.ResourceGroupName(*cluster.Location) + clusterName := *cluster.Name data, err := getClusterKubeconfigBytes(ctx, resourceGroupName, clusterName) if err != nil { return nil, fmt.Errorf("get cluster kubeconfig bytes: %w", err) @@ -448,7 +451,8 @@ func daemonsetDebug(ctx context.Context, deploymentName, targetNodeLabel, privat } } -func getClusterSubnetID(ctx context.Context, mcResourceGroupName string) (string, error) { +func getClusterSubnetID(ctx context.Context, cluster *armcontainerservice.ManagedCluster) (string, error) { + mcResourceGroupName := *cluster.Properties.NodeResourceGroup pager := config.Azure.VNet.NewListPager(mcResourceGroupName, nil) for pager.More() { nextResult, err := pager.NextPage(ctx)