diff --git a/assert/assertions.go b/assert/assertions.go index a27e70546..381ff1f30 100644 --- a/assert/assertions.go +++ b/assert/assertions.go @@ -13,6 +13,7 @@ import ( "runtime" "runtime/debug" "strings" + "sync/atomic" "time" "unicode" "unicode/utf8" @@ -2046,6 +2047,44 @@ type CollectT struct { // If it's non-nil but len(c.errors) == 0, this is also a failure // obtained by direct c.FailNow() call. errors []error + + // Tells [EventuallyWithT] to halt after the current tick. + halt uint32 +} + +// HaltT wraps a [CollectT] so that any failure recorded through it tells +// [EventuallyWithT] to halt. +type HaltT struct { + inner *CollectT +} + +// Errorf collects an error in the underlying [CollectT], and halts +// [EventuallyWithT] after the current tick. +func (t *HaltT) Errorf(format string, args ...interface{}) { + atomic.StoreUint32(&t.inner.halt, 1) + t.inner.Errorf(format, args...) +} + +// FailNow calls [CollectT.FailNow] and immediately halts [EventuallyWithT]. +func (t *HaltT) FailNow() { + atomic.StoreUint32(&t.inner.halt, 1) + t.inner.FailNow() +} + +// Halt returns a [HaltT] view of this CollectT. Any failure recorded +// through it tells [EventuallyWithT] to halt. +// +// Use it inside a condition when a terminal, non-recovering prerequisite should +// halt [EventuallyWithT] immediately (via require) or after the current tick +// (via assert): +// +// condition := func(c *CollectT) { +// require.True(c.Halt(), socketsOpen(), "socket must be open before proceeding") +// assert.True(c, eventuallyTrue(), "non-fatal checks still run until success/timeout") +// } +// assert.EventuallyWithT(t, condition, time.Second, 10*time.Millisecond) +func (c *CollectT) Halt() *HaltT { + return &HaltT{c} } // Helper is like [testing.T.Helper] but does nothing. @@ -2138,12 +2177,21 @@ func EventuallyWithT(t TestingT, condition func(collect *CollectT), waitFor time tickC = nil go checkCond() case collect := <-ch: - if !collect.failed() { + switch { + case atomic.LoadUint32(&collect.halt) == 1: + for _, err := range collect.errors { + t.Errorf("%v", err) + } + return Fail(t, "Halted", msgAndArgs...) + + case !collect.failed(): return true + + default: + // Keep the errors from the last ended condition, so that they can be copied to t if timeout is reached. + lastFinishedTickErrs = collect.errors + tickC = ticker.C } - // Keep the errors from the last ended condition, so that they can be copied to t if timeout is reached. - lastFinishedTickErrs = collect.errors - tickC = ticker.C } } } diff --git a/assert/assertions_test.go b/assert/assertions_test.go index 4975f5e41..ab54b52ce 100644 --- a/assert/assertions_test.go +++ b/assert/assertions_test.go @@ -14,6 +14,7 @@ import ( "regexp" "runtime" "strings" + "sync/atomic" "testing" "time" ) @@ -3543,6 +3544,27 @@ func TestEventuallyWithTFailNow(t *testing.T) { Len(t, mockT.errors, 1) } +func TestEventuallyWithTHalts(t *testing.T) { + t.Parallel() + + mockT := new(CollectT) + + var timesCalled int32 + condition := func(collect *CollectT) { + if times := atomic.AddInt32(×Called, 1); times == 1 { + FailNow(collect.Halt(), "This should be captured") + Fail(collect, "FailNow(CollectT.Halt(), ...) didn't diverge") + } + } + + False(t, EventuallyWithT(mockT, condition, 100*time.Millisecond, 20*time.Millisecond)) + Equal(t, int32(1), atomic.LoadInt32(×Called), "Condition wasn't called exactly once") + if Len(t, mockT.errors, 2) { + ErrorContains(t, mockT.errors[0], "This should be captured") + ErrorContains(t, mockT.errors[1], "Halted") + } +} + // Check that a long running condition doesn't block Eventually. // See issue 805 (and its long tail of following issues) func TestEventuallyTimeout(t *testing.T) {