Skip to content

Commit

Permalink
Create errorAssertionFunction creators
Browse files Browse the repository at this point in the history
To clean up table driven test with often repeated functions we can add a
simple function creator that generates a errorAssertionFunction
  • Loading branch information
JERHAV committed Mar 5, 2024
1 parent bb548d0 commit 3f0a0fc
Show file tree
Hide file tree
Showing 4 changed files with 110 additions and 2 deletions.
34 changes: 34 additions & 0 deletions assert/assertions.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,40 @@ type ErrorAssertionFunc func(TestingT, error, ...interface{}) bool
// Comparison is a custom function that returns true on success and false on failure
type Comparison func() (success bool)

// ErrorIsFor returns an [ErrorAssertionFunc] which tests if the error wraps target.
func ErrorIsFor(target error) ErrorAssertionFunc {
return func(t TestingT, err error, msgAndArgs ...interface{}) bool {
if h, ok := t.(tHelper); ok {
h.Helper()
}

return ErrorIs(t, err, target, msgAndArgs...)
}
}

// ErrorAsFor returns an [ErrorAssertionFunc] which tests if the any error in err's tree matches target and if so, assigns it to target.
// The returned function panics if target is not a non-nil pointer to either a type that implements error, or to any interface type.
func ErrorAsFor(target interface{}) ErrorAssertionFunc {
return func(t TestingT, err error, msgAndArgs ...interface{}) bool {
if h, ok := t.(tHelper); ok {
h.Helper()
}

return ErrorAs(t, err, target, msgAndArgs...)
}
}

// ErrorTypeFor returns an [ErrorAssertionFunc] which tests whether the type of the error matches the targets type.
func ErrorTypeFor(target error) ErrorAssertionFunc {
return func(t TestingT, err error, msgAndArgs ...interface{}) bool {
if h, ok := t.(tHelper); ok {
h.Helper()
}

return IsType(t, target, err, msgAndArgs...)
}
}

/*
Helper functions
*/
Expand Down
27 changes: 25 additions & 2 deletions assert/assertions_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2750,7 +2750,13 @@ func ExampleErrorAssertionFunc() {
t := &testing.T{} // provided by test

dumbParseNum := func(input string, v interface{}) error {
return json.Unmarshal([]byte(input), v)

err := json.Unmarshal([]byte(input), v)
if err != nil {
return testingError{"could not Unmarshal " + input}
}

return nil
}

tests := []struct {
Expand All @@ -2760,8 +2766,9 @@ func ExampleErrorAssertionFunc() {
}{
{"1.2 is number", "1.2", NoError},
{"1.2.3 not number", "1.2.3", Error},
{"true is not number", "true", Error},
{"true is not number", "true", ErrorAsFor(&testingError{})},
{"3 is number", "3", NoError},
{"3% is not a valid number", "3%", ErrorIsFor(testingError{"could not Unmarshal 3%"})},
}

for _, tt := range tests {
Expand All @@ -2772,14 +2779,30 @@ func ExampleErrorAssertionFunc() {
}
}

type testingError struct {
extraInfo string
}

func (t testingError) Error() string {
return t.extraInfo
}

func TestErrorAssertionFunc(t *testing.T) {
var testError = errors.New("test error")
tests := []struct {
name string
err error
assertion ErrorAssertionFunc
}{
{"noError", nil, NoError},
{"error", errors.New("whoops"), Error},
{"errorIs", testError, ErrorIsFor(testError)},
{"errorAs", testingError{extraInfo: "something"}, ErrorAsFor(&testingError{})},
{"errorType", testingError{extraInfo: "something"}, ErrorTypeFor(testingError{})},
{"wrappedErrorAs", fmt.Errorf("This wrapped error: %w", testingError{extraInfo: "something"}),
ErrorAsFor(&testingError{})},
{"wrappedErrorIs", fmt.Errorf("This wrapped error: %w", testError),
ErrorIsFor(testError)},
}

for _, tt := range tests {
Expand Down
34 changes: 34 additions & 0 deletions require/requirements.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,4 +26,38 @@ type BoolAssertionFunc func(TestingT, bool, ...interface{})
// for table driven tests.
type ErrorAssertionFunc func(TestingT, error, ...interface{})

// ErrorIsFunc returns an [ErrorAssertionFunc] which tests if the error wraps target.
func ErrorIsFor(expectedError error) ErrorAssertionFunc {
return func(t TestingT, err error, msgsAndArgs ...interface{}) {
if h, ok := t.(tHelper); ok {
h.Helper()
}

ErrorIs(t, err, expectedError, msgsAndArgs...)
}
}

// ErrorAsFunc returns an [ErrorAssertionFunc] which tests if the any error in err's tree matches target and if so, assigns it to target.
// The returned function panics if target is not a non-nil pointer to either a type that implements error, or to any interface type.
func ErrorAsFor(expectedInterface interface{}) ErrorAssertionFunc {
return func(t TestingT, err error, msgsAndArgs ...interface{}) {
if h, ok := t.(tHelper); ok {
h.Helper()
}

ErrorAs(t, err, expectedInterface, msgsAndArgs...)
}
}

// ErrorTypeFor returns an [ErrorAssertionFunc] which tests whether the type of the error matches the targets type.
func ErrorTypeFor(target interface{}) ErrorAssertionFunc {
return func(t TestingT, err error, msgAndArgs ...interface{}) {
if h, ok := t.(tHelper); ok {
h.Helper()
}

IsType(t, target, err, msgAndArgs...)
}
}

//go:generate sh -c "cd ../_codegen && go build && cd - && ../_codegen/_codegen -output-package=require -template=require.go.tmpl -include-format-funcs"
17 changes: 17 additions & 0 deletions require/requirements_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package require
import (
"encoding/json"
"errors"
"fmt"
"testing"
"time"
)
Expand Down Expand Up @@ -665,14 +666,30 @@ func ExampleErrorAssertionFunc() {
}
}

type testingError struct {
extraInfo string
}

func (t testingError) Error() string {
return t.extraInfo
}

func TestErrorAssertionFunc(t *testing.T) {
var testError = errors.New("test error")
tests := []struct {
name string
err error
assertion ErrorAssertionFunc
}{
{"noError", nil, NoError},
{"error", errors.New("whoops"), Error},
{"errorIs", testError, ErrorIsFor(testError)},
{"errorAs", testingError{extraInfo: "something"}, ErrorAsFor(&testingError{})},
{"errorType", testingError{extraInfo: "something"}, ErrorTypeFor(testingError{})},
{"wrappedErrorAs", fmt.Errorf("This wrapped error: %w", testingError{extraInfo: "something"}),
ErrorAsFor(&testingError{})},
{"wrappedErrorIs", fmt.Errorf("This wrapped error: %w", testError),
ErrorIsFor(testError)},
}

for _, tt := range tests {
Expand Down

0 comments on commit 3f0a0fc

Please sign in to comment.