From 913ebafaf8c9f1341d9e28fc679ad7abf5607daf Mon Sep 17 00:00:00 2001 From: Daniel Nephin Date: Sat, 18 Feb 2023 16:00:10 -0500 Subject: [PATCH 1/3] Make BoolOrComparison a type union This was always the intent. With generics we can express the intent in a way the compiler understands. --- assert/assert.go | 8 +++--- assert/assert_test.go | 54 ++++++++++++++++++++++++--------------- go.mod | 2 +- internal/assert/assert.go | 4 +-- 4 files changed, 42 insertions(+), 26 deletions(-) diff --git a/assert/assert.go b/assert/assert.go index 76a64fa..bfdc639 100644 --- a/assert/assert.go +++ b/assert/assert.go @@ -100,7 +100,9 @@ import ( // BoolOrComparison can be a bool, cmp.Comparison, or error. See Assert for // details about how this type is used. -type BoolOrComparison interface{} +type BoolOrComparison interface { + bool | func() (bool, string) | ~func() cmp.Result +} // TestingT is the subset of testing.T used by the assert package. type TestingT interface { @@ -138,7 +140,7 @@ type helperT interface { // Assert uses t.FailNow to fail the test. Like t.FailNow, Assert must be called // from the goroutine running the test function, not from other // goroutines created during the test. Use Check from other goroutines. -func Assert(t TestingT, comparison BoolOrComparison, msgAndArgs ...interface{}) { +func Assert[C BoolOrComparison](t TestingT, comparison C, msgAndArgs ...interface{}) { if ht, ok := t.(helperT); ok { ht.Helper() } @@ -152,7 +154,7 @@ func Assert(t TestingT, comparison BoolOrComparison, msgAndArgs ...interface{}) // is successful Check returns true. Check may be called from any goroutine. // // See Assert for details about the comparison arg and failure messages. -func Check(t TestingT, comparison BoolOrComparison, msgAndArgs ...interface{}) bool { +func Check[C BoolOrComparison](t TestingT, comparison C, msgAndArgs ...interface{}) bool { if ht, ok := t.(helperT); ok { ht.Helper() } diff --git a/assert/assert_test.go b/assert/assert_test.go index 97ec8dd..3b13abb 100644 --- a/assert/assert_test.go +++ b/assert/assert_test.go @@ -125,28 +125,42 @@ func (c exampleComparison) Compare() (bool, string) { return c.success, c.message } -func TestAssertWithComparisonSuccess(t *testing.T) { - fakeT := &fakeTestingT{} - - cmp := exampleComparison{success: true} - Assert(fakeT, cmp.Compare) - expectSuccess(t, fakeT) -} - -func TestAssertWithComparisonFailure(t *testing.T) { - fakeT := &fakeTestingT{} - - cmp := exampleComparison{message: "oops, not good"} - Assert(fakeT, cmp.Compare) - expectFailNowed(t, fakeT, "assertion failed: oops, not good") -} +func TestAssert_ArgumentTypes(t *testing.T) { + t.Run("compare function success", func(t *testing.T) { + fakeT := &fakeTestingT{} + cmp := exampleComparison{success: true} + Assert(fakeT, cmp.Compare) + expectSuccess(t, fakeT) + }) + t.Run("compare function failure", func(t *testing.T) { + fakeT := &fakeTestingT{} + cmp := exampleComparison{message: "oops, not good"} + Assert(fakeT, cmp.Compare) + expectFailNowed(t, fakeT, "assertion failed: oops, not good") + }) + t.Run("compare function failure with extra message", func(t *testing.T) { + fakeT := &fakeTestingT{} + cmp := exampleComparison{message: "oops, not good"} + Assert(fakeT, cmp.Compare, "extra stuff %v", true) + expectFailNowed(t, fakeT, "assertion failed: oops, not good: extra stuff true") + }) -func TestAssertWithComparisonAndExtraMessage(t *testing.T) { - fakeT := &fakeTestingT{} + t.Run("bool", func(t *testing.T) { + fakeT := &fakeTestingT{} + Assert(fakeT, true) + expectSuccess(t, fakeT) + Assert(fakeT, false) + expectFailNowed(t, fakeT, "assertion failed: false is false") + }) - cmp := exampleComparison{message: "oops, not good"} - Assert(fakeT, cmp.Compare, "extra stuff %v", true) - expectFailNowed(t, fakeT, "assertion failed: oops, not good: extra stuff true") + t.Run("result function", func(t *testing.T) { + fn := func() cmp.Result { + return cmp.ResultSuccess + } + fakeT := &fakeTestingT{} + Assert(fakeT, fn) + expectSuccess(t, fakeT) + }) } type customError struct { diff --git a/go.mod b/go.mod index 3837812..061023f 100644 --- a/go.mod +++ b/go.mod @@ -1,6 +1,6 @@ module gotest.tools/v3 -go 1.17 +go 1.18 require ( github.com/google/go-cmp v0.5.9 diff --git a/internal/assert/assert.go b/internal/assert/assert.go index 2dd8025..dcba41a 100644 --- a/internal/assert/assert.go +++ b/internal/assert/assert.go @@ -27,8 +27,8 @@ const failureMessage = "assertion failed: " func Eval( t LogT, argSelector argSelector, - comparison interface{}, - msgAndArgs ...interface{}, + comparison any, + msgAndArgs ...any, ) bool { if ht, ok := t.(helperT); ok { ht.Helper() From 72a1942051b1b1cf114ed1f5cfa123a7c3911866 Mon Sep 17 00:00:00 2001 From: Daniel Nephin Date: Sat, 18 Feb 2023 16:04:41 -0500 Subject: [PATCH 2/3] Use any instead of interface --- assert/assert.go | 18 +++++++++--------- internal/assert/assert.go | 2 +- internal/format/format.go | 6 +++--- 3 files changed, 13 insertions(+), 13 deletions(-) diff --git a/assert/assert.go b/assert/assert.go index bfdc639..bb3e0e0 100644 --- a/assert/assert.go +++ b/assert/assert.go @@ -108,7 +108,7 @@ type BoolOrComparison interface { type TestingT interface { FailNow() Fail() - Log(args ...interface{}) + Log(args ...any) } type helperT interface { @@ -140,7 +140,7 @@ type helperT interface { // Assert uses t.FailNow to fail the test. Like t.FailNow, Assert must be called // from the goroutine running the test function, not from other // goroutines created during the test. Use Check from other goroutines. -func Assert[C BoolOrComparison](t TestingT, comparison C, msgAndArgs ...interface{}) { +func Assert[C BoolOrComparison](t TestingT, comparison C, msgAndArgs ...any) { if ht, ok := t.(helperT); ok { ht.Helper() } @@ -154,7 +154,7 @@ func Assert[C BoolOrComparison](t TestingT, comparison C, msgAndArgs ...interfac // is successful Check returns true. Check may be called from any goroutine. // // See Assert for details about the comparison arg and failure messages. -func Check[C BoolOrComparison](t TestingT, comparison C, msgAndArgs ...interface{}) bool { +func Check[C BoolOrComparison](t TestingT, comparison C, msgAndArgs ...any) bool { if ht, ok := t.(helperT); ok { ht.Helper() } @@ -171,7 +171,7 @@ func Check[C BoolOrComparison](t TestingT, comparison C, msgAndArgs ...interface // NilError uses t.FailNow to fail the test. Like t.FailNow, NilError must be // called from the goroutine running the test function, not from other // goroutines created during the test. Use Check from other goroutines. -func NilError(t TestingT, err error, msgAndArgs ...interface{}) { +func NilError(t TestingT, err error, msgAndArgs ...any) { if ht, ok := t.(helperT); ok { ht.Helper() } @@ -199,7 +199,7 @@ func NilError(t TestingT, err error, msgAndArgs ...interface{}) { // called from the goroutine running the test function, not from other // goroutines created during the test. Use Check with cmp.Equal from other // goroutines. -func Equal(t TestingT, x, y interface{}, msgAndArgs ...interface{}) { +func Equal(t TestingT, x, y interface{}, msgAndArgs ...any) { if ht, ok := t.(helperT); ok { ht.Helper() } @@ -237,7 +237,7 @@ func DeepEqual(t TestingT, x, y interface{}, opts ...gocmp.Option) { // called from the goroutine running the test function, not from other // goroutines created during the test. Use Check with cmp.Error from other // goroutines. -func Error(t TestingT, err error, expected string, msgAndArgs ...interface{}) { +func Error(t TestingT, err error, expected string, msgAndArgs ...any) { if ht, ok := t.(helperT); ok { ht.Helper() } @@ -254,7 +254,7 @@ func Error(t TestingT, err error, expected string, msgAndArgs ...interface{}) { // must be called from the goroutine running the test function, not from other // goroutines created during the test. Use Check with cmp.ErrorContains from other // goroutines. -func ErrorContains(t TestingT, err error, substring string, msgAndArgs ...interface{}) { +func ErrorContains(t TestingT, err error, substring string, msgAndArgs ...any) { if ht, ok := t.(helperT); ok { ht.Helper() } @@ -288,7 +288,7 @@ func ErrorContains(t TestingT, err error, substring string, msgAndArgs ...interf // goroutines. // // Deprecated: Use ErrorIs -func ErrorType(t TestingT, err error, expected interface{}, msgAndArgs ...interface{}) { +func ErrorType(t TestingT, err error, expected any, msgAndArgs ...any) { if ht, ok := t.(helperT); ok { ht.Helper() } @@ -305,7 +305,7 @@ func ErrorType(t TestingT, err error, expected interface{}, msgAndArgs ...interf // must be called from the goroutine running the test function, not from other // goroutines created during the test. Use Check with cmp.ErrorIs from other // goroutines. -func ErrorIs(t TestingT, err error, expected error, msgAndArgs ...interface{}) { +func ErrorIs(t TestingT, err error, expected error, msgAndArgs ...any) { if ht, ok := t.(helperT); ok { ht.Helper() } diff --git a/internal/assert/assert.go b/internal/assert/assert.go index dcba41a..b31ffe7 100644 --- a/internal/assert/assert.go +++ b/internal/assert/assert.go @@ -79,7 +79,7 @@ func runCompareFunc( return true } -func logFailureFromBool(t LogT, msgAndArgs ...interface{}) { +func logFailureFromBool(t LogT, msgAndArgs ...any) { if ht, ok := t.(helperT); ok { ht.Helper() } diff --git a/internal/format/format.go b/internal/format/format.go index 5097e4b..ffa25d0 100644 --- a/internal/format/format.go +++ b/internal/format/format.go @@ -1,9 +1,9 @@ -package format // import "gotest.tools/v3/internal/format" +package format import "fmt" // Message accepts a msgAndArgs varargs and formats it using fmt.Sprintf -func Message(msgAndArgs ...interface{}) string { +func Message(msgAndArgs ...any) string { switch len(msgAndArgs) { case 0: return "" @@ -15,7 +15,7 @@ func Message(msgAndArgs ...interface{}) string { } // WithCustomMessage accepts one or two messages and formats them appropriately -func WithCustomMessage(source string, msgAndArgs ...interface{}) string { +func WithCustomMessage(source string, msgAndArgs ...any) string { custom := Message(msgAndArgs...) switch { case custom == "": From d39110ebf2951b37471097d37c2043aa011f1c92 Mon Sep 17 00:00:00 2001 From: Daniel Nephin Date: Sat, 18 Feb 2023 16:25:25 -0500 Subject: [PATCH 3/3] Use type constraints for assertions --- assert/assert.go | 4 ++-- assert/assert_test.go | 7 ------- assert/cmp/compare.go | 30 ++++++++++++++++-------------- assert/cmp/compare_test.go | 26 +++++++++++--------------- fs/example_test.go | 2 +- 5 files changed, 30 insertions(+), 39 deletions(-) diff --git a/assert/assert.go b/assert/assert.go index bb3e0e0..1451623 100644 --- a/assert/assert.go +++ b/assert/assert.go @@ -199,7 +199,7 @@ func NilError(t TestingT, err error, msgAndArgs ...any) { // called from the goroutine running the test function, not from other // goroutines created during the test. Use Check with cmp.Equal from other // goroutines. -func Equal(t TestingT, x, y interface{}, msgAndArgs ...any) { +func Equal[ANY any](t TestingT, x, y ANY, msgAndArgs ...any) { if ht, ok := t.(helperT); ok { ht.Helper() } @@ -218,7 +218,7 @@ func Equal(t TestingT, x, y interface{}, msgAndArgs ...any) { // called from the goroutine running the test function, not from other // goroutines created during the test. Use Check with cmp.DeepEqual from other // goroutines. -func DeepEqual(t TestingT, x, y interface{}, opts ...gocmp.Option) { +func DeepEqual[ANY any](t TestingT, x, y ANY, opts ...gocmp.Option) { if ht, ok := t.(helperT); ok { ht.Helper() } diff --git a/assert/assert_test.go b/assert/assert_test.go index 3b13abb..5f7f6fa 100644 --- a/assert/assert_test.go +++ b/assert/assert_test.go @@ -283,13 +283,6 @@ func TestEqualFailure(t *testing.T) { expectFailNowed(t, fakeT, "assertion failed: 1 (actual int) != 3 (expected int)") } -func TestEqualFailureTypes(t *testing.T) { - fakeT := &fakeTestingT{} - - Equal(fakeT, 3, uint(3)) - expectFailNowed(t, fakeT, `assertion failed: 3 (int) != 3 (uint)`) -} - func TestEqualFailureWithSelectorArgument(t *testing.T) { fakeT := &fakeTestingT{} diff --git a/assert/cmp/compare.go b/assert/cmp/compare.go index 2bb9e8e..c59ff0b 100644 --- a/assert/cmp/compare.go +++ b/assert/cmp/compare.go @@ -24,7 +24,7 @@ type Comparison func() Result // The comparison can be customized using comparison Options. // Package http://pkg.go.dev/gotest.tools/v3/assert/opt provides some additional // commonly used Options. -func DeepEqual(x, y interface{}, opts ...cmp.Option) Comparison { +func DeepEqual[ANY any](x, y ANY, opts ...cmp.Option) Comparison { return func() (result Result) { defer func() { if panicmsg, handled := handleCmpPanic(recover()); handled { @@ -63,7 +63,9 @@ func toResult(success bool, msg string) Result { // RegexOrPattern may be either a *regexp.Regexp or a string that is a valid // regexp pattern. -type RegexOrPattern interface{} +type RegexOrPattern interface { + ~string | *regexp.Regexp +} // Regexp succeeds if value v matches regular expression re. // @@ -72,7 +74,7 @@ type RegexOrPattern interface{} // assert.Assert(t, cmp.Regexp("^[0-9a-f]{32}$", str)) // r := regexp.MustCompile("^[0-9a-f]{32}$") // assert.Assert(t, cmp.Regexp(r, str)) -func Regexp(re RegexOrPattern, v string) Comparison { +func Regexp[R RegexOrPattern](re R, v string) Comparison { match := func(re *regexp.Regexp) Result { return toResult( re.MatchString(v), @@ -80,7 +82,7 @@ func Regexp(re RegexOrPattern, v string) Comparison { } return func() Result { - switch regex := re.(type) { + switch regex := any(re).(type) { case *regexp.Regexp: return match(regex) case string: @@ -96,13 +98,13 @@ func Regexp(re RegexOrPattern, v string) Comparison { } // Equal succeeds if x == y. See assert.Equal for full documentation. -func Equal(x, y interface{}) Comparison { +func Equal[ANY any](x, y ANY) Comparison { return func() Result { switch { - case x == y: + case any(x) == any(y): return ResultSuccess case isMultiLineStringCompare(x, y): - diff := format.UnifiedDiff(format.DiffConfig{A: x.(string), B: y.(string)}) + diff := format.UnifiedDiff(format.DiffConfig{A: any(x).(string), B: any(y).(string)}) return multiLineDiffResult(diff, x, y) } return ResultFailureTemplate(` @@ -117,7 +119,7 @@ func Equal(x, y interface{}) Comparison { } } -func isMultiLineStringCompare(x, y interface{}) bool { +func isMultiLineStringCompare(x, y any) bool { strX, ok := x.(string) if !ok { return false @@ -129,7 +131,7 @@ func isMultiLineStringCompare(x, y interface{}) bool { return strings.Contains(strX, "\n") || strings.Contains(strY, "\n") } -func multiLineDiffResult(diff string, x, y interface{}) Result { +func multiLineDiffResult(diff string, x, y any) Result { return ResultFailureTemplate(` --- {{ with callArg 0 }}{{ formatNode . }}{{else}}←{{end}} +++ {{ with callArg 1 }}{{ formatNode . }}{{else}}→{{end}} @@ -138,7 +140,7 @@ func multiLineDiffResult(diff string, x, y interface{}) Result { } // Len succeeds if the sequence has the expected length. -func Len(seq interface{}, expected int) Comparison { +func Len(seq any, expected int) Comparison { return func() (result Result) { defer func() { if e := recover(); e != nil { @@ -163,7 +165,7 @@ func Len(seq interface{}, expected int) Comparison { // If collection is a Map, contains will succeed if item is a key in the map. // If collection is a slice or array, item is compared to each item in the // sequence using reflect.DeepEqual(). -func Contains(collection interface{}, item interface{}) Comparison { +func Contains(collection any, item any) Comparison { return func() Result { colValue := reflect.ValueOf(collection) if !colValue.IsValid() { @@ -261,14 +263,14 @@ func formatErrorMessage(err error) string { // // Use NilError() for comparing errors. Use Len(obj, 0) for comparing slices, // maps, and channels. -func Nil(obj interface{}) Comparison { +func Nil(obj any) Comparison { msgFunc := func(value reflect.Value) string { return fmt.Sprintf("%v (type %s) is not nil", reflect.Indirect(value), value.Type()) } return isNil(obj, msgFunc) } -func isNil(obj interface{}, msgFunc func(reflect.Value) string) Comparison { +func isNil(obj any, msgFunc func(reflect.Value) string) Comparison { return func() Result { if obj == nil { return ResultSuccess @@ -309,7 +311,7 @@ func isNil(obj interface{}, msgFunc func(reflect.Value) string) Comparison { // Fails if err does not implement the reflect.Type. // // Deprecated: Use ErrorIs -func ErrorType(err error, expected interface{}) Comparison { +func ErrorType(err error, expected any) Comparison { return func() Result { switch expectedType := expected.(type) { case func(error) bool: diff --git a/assert/cmp/compare_test.go b/assert/cmp/compare_test.go index e4546ea..b315a21 100644 --- a/assert/cmp/compare_test.go +++ b/assert/cmp/compare_test.go @@ -45,13 +45,15 @@ func TestDeepEqualWithUnexported(t *testing.T) { } func TestRegexp(t *testing.T) { - var testcases = []struct { + type testCase struct { name string - regex interface{} + regex string value string match bool expErr string - }{ + } + + var testcases = []testCase{ { name: "pattern string match", regex: "^[0-9]+$", @@ -70,24 +72,12 @@ func TestRegexp(t *testing.T) { value: "2123423456", expErr: `value "2123423456" does not match regexp "^1"`, }, - { - name: "regexp match", - regex: regexp.MustCompile("^d[0-9a-f]{8}$"), - value: "d1632beef", - match: true, - }, { name: "invalid regexp", regex: "^1(", value: "2", expErr: "error parsing regexp: missing closing ): `^1(`", }, - { - name: "invalid type", - regex: struct{}{}, - value: "some string", - expErr: "invalid type struct {} for regex pattern", - }, } for _, tc := range testcases { @@ -100,6 +90,12 @@ func TestRegexp(t *testing.T) { } }) } + + t.Run("regexp match", func(t *testing.T) { + regex := regexp.MustCompile("^d[0-9a-f]{8}$") + res := Regexp(regex, "d1632beef")() + assertSuccess(t, res) + }) } func TestLen(t *testing.T) { diff --git a/fs/example_test.go b/fs/example_test.go index ece5b4b..1ec61f5 100644 --- a/fs/example_test.go +++ b/fs/example_test.go @@ -29,7 +29,7 @@ func ExampleNewFile() { content, err := os.ReadFile(file.Path()) assert.NilError(t, err) - assert.Equal(t, "content\n", content) + assert.Equal(t, "content\n", string(content)) } // Create a directory and subdirectory with files