diff --git a/go.mod b/go.mod index 54784f1..181cdd4 100644 --- a/go.mod +++ b/go.mod @@ -2,4 +2,13 @@ module github.com/romdo/gomockctx go 1.17 -require github.com/golang/mock v1.6.0 +require ( + github.com/golang/mock v1.6.0 + github.com/stretchr/testify v1.7.0 +) + +require ( + github.com/davecgh/go-spew v1.1.1 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect + gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b // indirect +) diff --git a/go.sum b/go.sum index d067127..50a3a75 100644 --- a/go.sum +++ b/go.sum @@ -1,5 +1,13 @@ +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/golang/mock v1.6.0 h1:ErTB+efbowRARo13NNdxyJji2egdxLGQhRaY+DUumQc= github.com/golang/mock v1.6.0/go.mod h1:p6yTPP+5HYm5mzsMV8JkE6ZKdX+/wYM6Hr+LicevLPs= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/testify v1.7.0 h1:nwc3DEeHmmLAfoZucVR881uASk0Mfjw8xYJ99tb5CcY= +github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/yuin/goldmark v1.3.5/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= @@ -23,3 +31,8 @@ golang.org/x/tools v0.1.1/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b h1:h8qDotaEPuJATrMmW04NCwg7v22aHH28wwpauUhK9Oo= +gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/gomockctx.go b/gomockctx.go index d12099d..a8a63e2 100644 --- a/gomockctx.go +++ b/gomockctx.go @@ -17,25 +17,6 @@ type ( var ctxKey contextKey = "gomockctx context ID" -func newCtxID() contextValue { - id, err := randString(64) - if err != nil { - panic(err) - } - - return contextValue(id) -} - -func value(ctx context.Context) contextValue { - var value contextValue - v := ctx.Value(ctxKey) - if s, ok := v.(contextValue); ok { - value = s - } - - return value -} - // New returns a context as a child of the given parent, and with a randomized // gomockctx ID value set, making it a gomockctx context. This can then be used // with Is to get a gomock Matcher which returns true for the context from New, @@ -57,8 +38,25 @@ func Is(ctx context.Context) gomock.Matcher { return WithValue(ctxKey, value(ctx)) } -// WithValue creates a generic gomock context matcher which returns true for any -// context which has the specified key/value. +// ID returns the gomockctx ID value in the given context, or a empty string if +// the context does not have a gomockctx ID value. +func ID(ctx context.Context) string { + if ctx == nil { + return "" + } + + return string(value(ctx)) +} + +// Any returns a gomock.Matcher which matches any context.Context object. +func Any() gomock.Matcher { + return gomock.AssignableToTypeOf( + reflect.TypeOf((*context.Context)(nil)).Elem(), + ) +} + +// WithValue returns a gomock.Matcher which matches any context that has the +// specified key and value. func WithValue(key interface{}, value interface{}) gomock.Matcher { return &contextMatcher{ key: key, @@ -66,12 +64,6 @@ func WithValue(key interface{}, value interface{}) gomock.Matcher { } } -// ID returns the gomockctx ID value in the given context, or a empty string if -// it not a gomockctx context. -func ID(ctx context.Context) string { - return string(value(ctx)) -} - type contextMatcher struct { key interface{} value interface{} @@ -79,15 +71,37 @@ type contextMatcher struct { var _ gomock.Matcher = &contextMatcher{} -func (e *contextMatcher) Matches(x interface{}) bool { - ctx, ok := x.(context.Context) - if !ok { - return false +func (cm *contextMatcher) Matches(x interface{}) bool { + if ctx, ok := x.(context.Context); ok { + return reflect.DeepEqual(cm.value, ctx.Value(cm.key)) } - return reflect.DeepEqual(e.value, ctx.Value(e.key)) + return false } -func (e *contextMatcher) String() string { - return fmt.Sprintf(`context with "%+v" = "%+v"`, e.key, e.value) +func (cm *contextMatcher) String() string { + return fmt.Sprintf(`context with "%+v" = "%+v"`, cm.key, cm.value) +} + +func newCtxID() contextValue { + id, err := randString(64) + if err != nil { + panic(err) + } + + return contextValue(id) +} + +func value(ctx context.Context) contextValue { + var value contextValue + if ctx == nil { + return value + } + + v := ctx.Value(ctxKey) + if s, ok := v.(contextValue); ok { + value = s + } + + return value } diff --git a/gomockctx_test.go b/gomockctx_test.go new file mode 100644 index 0000000..4f603d0 --- /dev/null +++ b/gomockctx_test.go @@ -0,0 +1,41 @@ +package gomockctx + +import ( + "context" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestID(t *testing.T) { + tests := []struct { + name string + ctx context.Context + want string + }{ + { + name: "nil", + ctx: nil, + want: "", + }, + { + name: "without ID", + ctx: context.Background(), + want: "", + }, + { + name: "with ID", + ctx: context.WithValue( + context.Background(), ctxKey, contextValue("xI2UWC8MvdYcU22B"), + ), + want: "xI2UWC8MvdYcU22B", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := ID(tt.ctx) + + assert.Equal(t, tt.want, got) + }) + } +}