diff --git a/any.go b/any.go new file mode 100644 index 0000000..7f6e876 --- /dev/null +++ b/any.go @@ -0,0 +1,26 @@ +package gomockctx + +import ( + "context" + + "github.com/golang/mock/gomock" +) + +// Any returns a gomock.Matcher which matches any context.Context object. +func Any() gomock.Matcher { + return &anyMatcher{} +} + +type anyMatcher struct{} + +var _ gomock.Matcher = &anyMatcher{} + +func (cm *anyMatcher) Matches(x interface{}) bool { + _, ok := x.(context.Context) + + return ok +} + +func (cm *anyMatcher) String() string { + return "is a context.Context" +} diff --git a/any_test.go b/any_test.go new file mode 100644 index 0000000..380021c --- /dev/null +++ b/any_test.go @@ -0,0 +1,155 @@ +package gomockctx + +import ( + "context" + "testing" + + "github.com/golang/mock/gomock" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestAny(t *testing.T) { + tests := []struct { + name string + x interface{} + want bool + }{ + { + name: "nil", + x: nil, + want: false, + }, + { + name: "empty string", + x: "", + want: false, + }, + { + name: "string", + x: "foo bar", + want: false, + }, + { + name: "int", + x: 42, + want: false, + }, + { + name: "int8", + x: int8(64), + want: false, + }, + { + name: "int16", + x: int16(1024), + want: false, + }, + { + name: "int32", + x: int32(1123456789), + want: false, + }, + { + name: "int64", + x: int64(16123456789), + want: false, + }, + { + name: "uint", + x: uint(616), + want: false, + }, + { + name: "uint8", + x: uint8(64), + want: false, + }, + { + name: "uint16", + x: uint16(1024), + want: false, + }, + { + name: "uint32", + x: uint32(1123456789), + want: false, + }, + { + name: "uint64", + x: uint64(16123456789), + want: false, + }, + { + name: "byte", + x: byte('A'), + want: false, + }, + { + name: "rune", + x: rune('A'), + want: false, + }, + { + name: "float32", + x: float32(6.16), + want: false, + }, + { + name: "float64", + x: float64(6.16), + want: false, + }, + { + name: "bool", + x: true, + want: false, + }, + { + name: "slice", + x: []string{"foo", "bar"}, + want: false, + }, + { + name: "array", + x: [2]string{"foo", "bar"}, + want: false, + }, + { + name: "channel", + x: make(chan bool), + want: false, + }, + { + name: "func", + x: func() {}, + want: false, + }, + { + name: "context.Background()", + x: context.Background(), + want: true, + }, + { + name: "context.TODO()", + x: context.TODO(), + want: true, + }, + { + name: "custom context", + x: context.WithValue(context.Background(), ctxKey, "foo"), + want: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + m := Any() + require.Implements(t, (*gomock.Matcher)(nil), m) + + got := m.Matches(tt.x) + + assert.Equal(t, tt.want, got) + assert.Equal(t, "is a context.Context", m.String()) + }) + } +} diff --git a/gomockctx.go b/gomockctx.go index a8a63e2..bd07c14 100644 --- a/gomockctx.go +++ b/gomockctx.go @@ -4,8 +4,6 @@ package gomockctx import ( "context" - "fmt" - "reflect" "github.com/golang/mock/gomock" ) @@ -15,76 +13,10 @@ type ( contextValue string ) -var ctxKey contextKey = "gomockctx context ID" - -// New returns a context as a child of the given parent, and with a randomized -// gomockctx ID value set, making it a gomockctx context. This can then be used -// with Is to get a gomock Matcher which returns true for the context from New, -// or any child contexts of it. -// -// If crypto/rand returns an error, this will panic trying to generate the -// gomockctx ID. In practise though, crypto/rand should never return a error. -func New(parent context.Context) context.Context { - return context.WithValue(parent, ctxKey, newCtxID()) -} - -// Is accepts a context with a gomockctx ID value (as returned from New), and -// returns a gomock.Matcher which returns true for the given context, of any -// child contexts of it. -// -// If ctx was not returned from New, the resulting matcher will ALWAYS return -// false. -func Is(ctx context.Context) gomock.Matcher { - return WithValue(ctxKey, value(ctx)) -} - -// ID returns the gomockctx ID value in the given context, or a empty string if -// the context does not have a gomockctx ID value. -func ID(ctx context.Context) string { - if ctx == nil { - return "" - } - - return string(value(ctx)) -} - -// Any returns a gomock.Matcher which matches any context.Context object. -func Any() gomock.Matcher { - return gomock.AssignableToTypeOf( - reflect.TypeOf((*context.Context)(nil)).Elem(), - ) -} - -// WithValue returns a gomock.Matcher which matches any context that has the -// specified key and value. -func WithValue(key interface{}, value interface{}) gomock.Matcher { - return &contextMatcher{ - key: key, - value: value, - } -} - -type contextMatcher struct { - key interface{} - value interface{} -} - -var _ gomock.Matcher = &contextMatcher{} - -func (cm *contextMatcher) Matches(x interface{}) bool { - if ctx, ok := x.(context.Context); ok { - return reflect.DeepEqual(cm.value, ctx.Value(cm.key)) - } - - return false -} - -func (cm *contextMatcher) String() string { - return fmt.Sprintf(`context with "%+v" = "%+v"`, cm.key, cm.value) -} +var ctxKey contextKey = "gomockctx ID" func newCtxID() contextValue { - id, err := randString(64) + id, err := randString(32) if err != nil { panic(err) } @@ -92,7 +24,7 @@ func newCtxID() contextValue { return contextValue(id) } -func value(ctx context.Context) contextValue { +func getValue(ctx context.Context) contextValue { var value contextValue if ctx == nil { return value @@ -105,3 +37,34 @@ func value(ctx context.Context) contextValue { return value } + +// New returns a context as a child of the given parent, which includes a +// randomized gomockctx ID value set, which makes it a gomockctx context. This +// can then be used with Is to get a gomock Matcher which returns true for the +// context from New, or any child contexts of it. +// +// If crypto/rand returns an error, this will panic trying to generate the +// gomockctx ID. In practice though, crypto/rand should never return a error. +func New(parent context.Context) context.Context { + return context.WithValue(parent, ctxKey, newCtxID()) +} + +// Is accepts a context with a gomockctx ID value (as returned from New), and +// returns a gomock.Matcher which returns true for the given context, of any +// child contexts of it. +// +// If ctx was not returned from New, the resulting matcher will ALWAYS return +// false. +func Is(ctx context.Context) gomock.Matcher { + return WithValue(ctxKey, getValue(ctx)) +} + +// ID returns the gomockctx ID value in the given context, or a empty string if +// the context does not have a gomockctx ID value. +func ID(ctx context.Context) string { + if ctx == nil { + return "" + } + + return string(getValue(ctx)) +} diff --git a/gomockctx_test.go b/gomockctx_test.go index 4f603d0..4e59398 100644 --- a/gomockctx_test.go +++ b/gomockctx_test.go @@ -2,11 +2,113 @@ package gomockctx import ( "context" + "regexp" "testing" + "github.com/golang/mock/gomock" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) +func TestNew(t *testing.T) { + type key struct{} + k := key{} + ctxIDRegexp := regexp.MustCompile(`^[A-Za-z0-9]+$`) + parent := context.WithValue(context.Background(), k, "the parent") + + ids := map[contextValue]struct{}{} + limit := 1000 + + for i := 0; i < limit; i++ { + ctx := New(parent) + require.Equal(t, "the parent", ctx.Value(k)) + + v := ctx.Value(ctxKey) + require.IsType(t, contextValue(""), v) + require.Len(t, v, 32) + require.Regexp(t, ctxIDRegexp, v) + cv, _ := v.(contextValue) + ids[cv] = struct{}{} + } + + assert.Len(t, ids, limit) +} + +func TestIs(t *testing.T) { + type strKey string + + type args struct { + ctx context.Context + } + tests := []struct { + name string + args args + want *valueMatcher + }{ + { + name: "nil", + args: args{ + ctx: nil, + }, + want: &valueMatcher{ + key: ctxKey, + value: contextValue(""), + }, + }, + { + name: "context without gomockctx ID", + args: args{ + ctx: context.Background(), + }, + want: &valueMatcher{ + key: ctxKey, + value: contextValue(""), + }, + }, + { + name: "context with gomockctx ID", + args: args{ + ctx: context.WithValue( + context.Background(), + ctxKey, + contextValue("z9KZVcfmA4sWJX0yuIIESVcEARlwiAT2"), + ), + }, + want: &valueMatcher{ + key: ctxKey, + value: contextValue("z9KZVcfmA4sWJX0yuIIESVcEARlwiAT2"), + }, + }, + { + name: "child context of context with gomockctx ID", + args: args{ + ctx: context.WithValue( + context.WithValue( + context.Background(), + ctxKey, + contextValue("hWEKf4Gtj15iLx4R7IFlHc5ooj5tU4UW"), + ), + strKey("foo"), + "bar", + ), + }, + want: &valueMatcher{ + key: ctxKey, + value: contextValue("hWEKf4Gtj15iLx4R7IFlHc5ooj5tU4UW"), + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := Is(tt.args.ctx) + + assert.Implements(t, (*gomock.Matcher)(nil), got) + + assert.Equal(t, tt.want, got) + }) + } +} + func TestID(t *testing.T) { tests := []struct { name string diff --git a/with_value.go b/with_value.go new file mode 100644 index 0000000..1c49d18 --- /dev/null +++ b/with_value.go @@ -0,0 +1,37 @@ +package gomockctx + +import ( + "context" + "fmt" + "reflect" + + "github.com/golang/mock/gomock" +) + +// WithValue returns a gomock.Matcher which matches any context that has the +// specified key and value. +func WithValue(key interface{}, value interface{}) gomock.Matcher { + return &valueMatcher{ + key: key, + value: value, + } +} + +type valueMatcher struct { + key interface{} + value interface{} +} + +var _ gomock.Matcher = &valueMatcher{} + +func (cm *valueMatcher) Matches(x interface{}) bool { + if ctx, ok := x.(context.Context); ok { + return reflect.DeepEqual(cm.value, ctx.Value(cm.key)) + } + + return false +} + +func (cm *valueMatcher) String() string { + return fmt.Sprintf(`context with "%+v" = "%+v"`, cm.key, cm.value) +} diff --git a/with_value_test.go b/with_value_test.go new file mode 100644 index 0000000..e7f7785 --- /dev/null +++ b/with_value_test.go @@ -0,0 +1,468 @@ +package gomockctx + +import ( + "context" + "testing" + + "github.com/golang/mock/gomock" + "github.com/stretchr/testify/assert" +) + +func TestWithValue(t *testing.T) { + type args struct { + key interface{} + value interface{} + } + tests := []struct { + name string + args args + want *valueMatcher + }{ + { + name: "nil", + args: args{ + key: nil, + value: nil, + }, + want: &valueMatcher{ + key: nil, + value: nil, + }, + }, + { + name: "string", + args: args{ + key: "foo", + value: "bar", + }, + want: &valueMatcher{ + key: "foo", + value: "bar", + }, + }, + { + name: "gomockctx ctxKey", + args: args{ + key: ctxKey, + value: "FrAcGnKKpVk1rB3AWC9S8Dnff04svNtN", + }, + want: &valueMatcher{ + key: ctxKey, + value: "FrAcGnKKpVk1rB3AWC9S8Dnff04svNtN", + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := WithValue(tt.args.key, tt.args.value) + + assert.Implements(t, (*gomock.Matcher)(nil), got) + + assert.Equal(t, tt.want, got) + }) + } +} + +func Test_valueMatcher_Matches(t *testing.T) { + ctx := context.Background() + + type stringKey string + strKey := stringKey("strKey") + strKeyVal1 := "foo" + strKeyVal2 := "bar" + ctxWithStrKey := context.WithValue(ctx, strKey, strKeyVal1) + + structKey := struct{}{} + structKeyVal1 := struct{ name string }{name: "foo"} + structKeyVal2 := struct{ name string }{name: "bar"} + ctxWithStructKey := context.WithValue(ctx, structKey, structKeyVal1) + + ctxKeyVal1 := "XAzb0Cr7yLuLzO369vTodjxKL3GUpspE" + ctxKeyVal2 := "r11X0FOejbPamvLiWhAuGiSqXzdmGnIm" + ctxWithCtxKey := context.WithValue(ctx, ctxKey, ctxKeyVal1) + + type fields struct { + key interface{} + value interface{} + } + type args struct { + x interface{} + } + tests := []struct { + name string + fields fields + args args + want bool + }{ + { + name: "nil", + fields: fields{ + key: strKey, + value: strKeyVal1, + }, + args: args{x: nil}, + want: false, + }, + { + name: "empty string", + fields: fields{ + key: strKey, + value: strKeyVal1, + }, + args: args{x: ""}, + want: false, + }, + { + name: "string", + fields: fields{ + key: strKey, + value: strKeyVal1, + }, + args: args{x: "hello world"}, + want: false, + }, + { + name: "int", + fields: fields{ + key: strKey, + value: strKeyVal1, + }, + args: args{x: int(616)}, + want: false, + }, + { + name: "int8", + fields: fields{ + key: strKey, + value: strKeyVal1, + }, + args: args{x: int8(64)}, + want: false, + }, + { + name: "int16", + fields: fields{ + key: strKey, + value: strKeyVal1, + }, + args: args{x: int16(1024)}, + want: false, + }, + { + name: "int32", + fields: fields{ + key: strKey, + value: strKeyVal1, + }, + args: args{x: int32(1123456789)}, + want: false, + }, + { + name: "int64", + fields: fields{ + key: strKey, + value: strKeyVal1, + }, + args: args{x: int64(16123456789)}, + want: false, + }, + { + name: "uint", + fields: fields{ + key: strKey, + value: strKeyVal1, + }, + args: args{x: uint(616)}, + want: false, + }, + { + name: "uint8", + fields: fields{ + key: strKey, + value: strKeyVal1, + }, + args: args{x: uint8(64)}, + want: false, + }, + { + name: "uint16", + fields: fields{ + key: strKey, + value: strKeyVal1, + }, + args: args{x: uint16(1024)}, + want: false, + }, + { + name: "uint32", + fields: fields{ + key: strKey, + value: strKeyVal1, + }, + args: args{x: uint32(1123456789)}, + want: false, + }, + { + name: "uint64", + fields: fields{ + key: strKey, + value: strKeyVal1, + }, + args: args{x: uint64(16123456789)}, + want: false, + }, + { + name: "byte", + fields: fields{ + key: strKey, + value: strKeyVal1, + }, + args: args{x: byte('A')}, + want: false, + }, + { + name: "rune", + fields: fields{ + key: strKey, + value: strKeyVal1, + }, + args: args{x: rune('A')}, + want: false, + }, + { + name: "float32", + fields: fields{ + key: strKey, + value: strKeyVal1, + }, + args: args{x: float32(6.16)}, + want: false, + }, + { + name: "float64", + fields: fields{ + key: strKey, + value: strKeyVal1, + }, + args: args{x: float64(6.16)}, + want: false, + }, + { + name: "bool", + fields: fields{ + key: strKey, + value: strKeyVal1, + }, + args: args{x: 616}, + want: false, + }, + { + name: "slice", + fields: fields{ + key: strKey, + value: strKeyVal1, + }, + args: args{x: []string{"foo", "bar"}}, + want: false, + }, + { + name: "array", + fields: fields{ + key: strKey, + value: strKeyVal1, + }, + args: args{x: [2]string{"foo", "bar"}}, + want: false, + }, + { + name: "channel", + fields: fields{ + key: strKey, + value: strKeyVal1, + }, + args: args{x: make(chan bool)}, + want: false, + }, + { + name: "func", + fields: fields{ + key: strKey, + value: strKeyVal1, + }, + args: args{x: func() {}}, + want: false, + }, + { + name: "context with strKey", + fields: fields{ + key: strKey, + value: strKeyVal1, + }, + args: args{x: ctxWithStrKey}, + want: true, + }, + { + name: "context with different strKey", + fields: fields{ + key: strKey, + value: strKeyVal2, + }, + args: args{x: ctxWithStrKey}, + want: false, + }, + { + name: "context without strKey", + fields: fields{ + key: strKey, + value: strKeyVal1, + }, + args: args{x: ctx}, + want: false, + }, + { + name: "context with structKey", + fields: fields{ + key: structKey, + value: structKeyVal1, + }, + args: args{x: ctxWithStructKey}, + want: true, + }, + { + name: "context with different structKey", + fields: fields{ + key: structKey, + value: structKeyVal2, + }, + args: args{x: ctxWithStructKey}, + want: false, + }, + { + name: "context without structKey", + fields: fields{ + key: structKey, + value: structKeyVal1, + }, + args: args{x: ctx}, + want: false, + }, + { + name: "context with ctxKey", + fields: fields{ + key: ctxKey, + value: ctxKeyVal1, + }, + args: args{x: ctxWithCtxKey}, + want: true, + }, + { + name: "context with different ctxKey", + fields: fields{ + key: ctxKey, + value: ctxKeyVal2, + }, + args: args{x: ctxWithCtxKey}, + want: false, + }, + { + name: "context without ctxKey", + fields: fields{ + key: ctxKey, + value: ctxKeyVal1, + }, + args: args{x: ctx}, + want: false, + }, + { + name: "context with ctxKey and empty value", + fields: fields{ + key: ctxKey, + value: contextValue(""), + }, + args: args{x: ctxWithCtxKey}, + want: false, + }, + { + name: "context with different ctxKey and empty value", + fields: fields{ + key: ctxKey, + value: contextValue(""), + }, + args: args{x: ctxWithCtxKey}, + want: false, + }, + { + name: "context without ctxKey and empty value", + fields: fields{ + key: ctxKey, + value: contextValue(""), + }, + args: args{x: ctx}, + want: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + vm := &valueMatcher{ + key: tt.fields.key, + value: tt.fields.value, + } + + got := vm.Matches(tt.args.x) + + assert.Equal(t, tt.want, got) + }) + } +} + +func Test_valueMatcher_String(t *testing.T) { + type stringKey string + type structKey struct{ name string } + + type fields struct { + key interface{} + value interface{} + } + tests := []struct { + name string + fields fields + want string + }{ + { + name: "stringKey", + fields: fields{ + key: stringKey("foo"), + value: "hello world", + }, + want: `context with "foo" = "hello world"`, + }, + { + name: "structKey", + fields: fields{ + key: structKey{name: "bar"}, + value: "okay then", + }, + want: `context with "{name:bar}" = "okay then"`, + }, + { + name: "gomockctx ctxKey", + fields: fields{ + key: ctxKey, + value: "foobar", + }, + want: `context with "gomockctx ID" = "foobar"`, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + vm := &valueMatcher{ + key: tt.fields.key, + value: tt.fields.value, + } + + got := vm.String() + + assert.Equal(t, tt.want, got) + }) + } +}