diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index e268bc2..29d65ac 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -80,7 +80,6 @@ jobs: strategy: matrix: go-version: - - "1.17" - "1.18" - "1.19" - "1.20" @@ -88,11 +87,13 @@ jobs: - "1.22" - "1.23" - "1.24" + - "stable" steps: - uses: actions/checkout@v4 - uses: actions/setup-go@v5 with: go-version: ${{ matrix.go-version }} + check-latest: true - name: Run tests run: make test env: diff --git a/README.md b/README.md index 1ade00b..95e192f 100644 --- a/README.md +++ b/README.md @@ -5,8 +5,8 @@

Go package providing a suite of functions that use crypto/rand - to generate cryptographically secure random strings in various formats, as - well as ints and bytes. + to generate cryptographically secure random data in various forms and + formats.

@@ -59,6 +59,9 @@ n, err := rands.Int(2147483647) // => 1334400235 n, err := rands.Int64(int64(9223372036854775807)) // => 8256935979116161233 b, err := rands.Bytes(8) // => [0 220 137 243 135 204 34 63] + +err := rands.Shuffle(len(arr) func(i, j) { arr[i], arr[j] = arr[j], arr[i] }) +err := rands.ShuffleSlice(arr) ``` ## [`randsmust`](https://pkg.go.dev/github.com/jimeh/rands/randsmust) package @@ -102,6 +105,9 @@ n := randsmust.Int(2147483647) // => 1293388115 n := randsmust.Int64(int64(9223372036854775807)) // => 6168113630900161239 b := randsmust.Bytes(8) // => [205 128 54 95 0 95 53 51] + +randsmust.Shuffle(len(arr) func(i, j) { arr[i], arr[j] = arr[j], arr[i] }) +randsmust.ShuffleSlice(arr) ``` ## Documentation diff --git a/go.mod b/go.mod index 921ca0c..0766348 100644 --- a/go.mod +++ b/go.mod @@ -1,6 +1,6 @@ module github.com/jimeh/rands -go 1.17 +go 1.18 require github.com/stretchr/testify v1.10.0 diff --git a/go.sum b/go.sum index 3ebac7b..713a0b4 100644 --- a/go.sum +++ b/go.sum @@ -1,19 +1,10 @@ -github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= -github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= -github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= -github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= -github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA= -github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= -github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= -github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA= github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= -gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/randsmust/shuffle.go b/randsmust/shuffle.go new file mode 100644 index 0000000..7d42256 --- /dev/null +++ b/randsmust/shuffle.go @@ -0,0 +1,32 @@ +package randsmust + +import "github.com/jimeh/rands" + +// Shuffle randomizes the order of a collection of n elements using +// cryptographically secure random values from crypto/rand. It implements the +// Fisher-Yates shuffle algorithm. +// +// The swap function is called to exchange values at indices i and j. This +// signature is compatible with Shuffle from math/rand and math/rand/v2 for easy +// migration. +// +// If an error occurs during shuffling, this function will panic. +func Shuffle(n int, swap func(i, j int)) { + err := rands.Shuffle(n, swap) + if err != nil { + panic(err) + } +} + +// ShuffleSlice randomizes the order of elements in a slice in-place using +// cryptographically secure random values from crypto/rand. +// +// It implements the Fisher-Yates shuffle algorithm. +// +// If an error occurs during shuffling, this function will panic. +func ShuffleSlice[T any](slice []T) { + err := rands.ShuffleSlice(slice) + if err != nil { + panic(err) + } +} diff --git a/randsmust/shuffle_example_test.go b/randsmust/shuffle_example_test.go new file mode 100644 index 0000000..7bfa142 --- /dev/null +++ b/randsmust/shuffle_example_test.go @@ -0,0 +1,25 @@ +package randsmust_test + +import ( + "fmt" + + "github.com/jimeh/rands/randsmust" +) + +func ExampleShuffle() { + numbers := []int{1, 2, 3, 4, 5} + + randsmust.Shuffle(len(numbers), func(i, j int) { + numbers[i], numbers[j] = numbers[j], numbers[i] + }) + + fmt.Println(numbers) // => [4 2 5 3 1] +} + +func ExampleShuffleSlice() { + mixed := []any{1, "two", 3.14, true, nil} + + randsmust.ShuffleSlice(mixed) + + fmt.Println(mixed) // => [two true 3.14 1] +} diff --git a/randsmust/shuffle_test.go b/randsmust/shuffle_test.go new file mode 100644 index 0000000..287cdad --- /dev/null +++ b/randsmust/shuffle_test.go @@ -0,0 +1,456 @@ +package randsmust + +import ( + "fmt" + "testing" + + "github.com/jimeh/rands" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func factorial(n int) int { + factorial := 1 + for i := 1; i <= n; i++ { + factorial *= i + } + + return factorial +} + +func TestShuffle(t *testing.T) { + t.Parallel() + + t.Run("n < 0", func(t *testing.T) { + t.Parallel() + + p := recoverPanic(func() { + Shuffle(-1, func(_, _ int) {}) + }) + + require.NotNil(t, p, "Expected a panic") + assert.ErrorIs(t, p.(error), rands.ErrInvalidShuffleNegativeN) + assert.ErrorIs(t, p.(error), rands.ErrShuffle) + assert.ErrorIs(t, p.(error), rands.Err) + }) + + t.Run("n == 0", func(t *testing.T) { + t.Parallel() + + swapCount := 0 + p := recoverPanic(func() { + Shuffle(0, func(_, _ int) { + swapCount++ + }) + }) + + require.Nil(t, p, "Did not expect a panic") + assert.Equal(t, 0, swapCount) + }) + + t.Run("n == 1", func(t *testing.T) { + t.Parallel() + + swapCount := 0 + p := recoverPanic(func() { + Shuffle(1, func(_, _ int) { + swapCount++ + }) + }) + + require.Nil(t, p, "Did not expect a panic") + assert.Equal(t, 0, swapCount) + }) + + t.Run("n == 2", func(t *testing.T) { + t.Parallel() + + swapCount := 0 + p := recoverPanic(func() { + Shuffle(2, func(_, _ int) { + swapCount++ + }) + }) + + require.Nil(t, p, "Did not expect a panic") + assert.Equal(t, 1, swapCount) + }) + + t.Run("basic", func(t *testing.T) { + t.Parallel() + + arr := make([]int, 100) + for i := range arr { + arr[i] = i + } + + arrCopy := make([]int, len(arr)) + copy(arrCopy, arr) + + p := recoverPanic(func() { + Shuffle(len(arr), func(i, j int) { + arr[i], arr[j] = arr[j], arr[i] + }) + }) + + require.Nil(t, p, "Did not expect a panic") + assert.NotEqual(t, arrCopy, arr, "Shuffle did not change the array") + assert.ElementsMatch(t, arrCopy, arr, "Shuffle changed elements") + }) + + t.Run("swaps", func(t *testing.T) { + t.Parallel() + + swapSame := 0 + swapDifferent := 0 + arr := make([]int, 100) + for j := range arr { + arr[j] = j + } + + p := recoverPanic(func() { + Shuffle(len(arr), func(i, j int) { + if i == j { + swapSame++ + } else { + swapDifferent++ + } + + arr[i], arr[j] = arr[j], arr[i] + }) + }) + + require.Nil(t, p, "Did not expect a panic") + + // Fisher-Yates with n elements should make exactly n-1 swaps + assert.Equal(t, len(arr)-1, swapSame+swapDifferent, + "Unexpected swaps count", + ) + + // Ensure we have more different-element swaps than self-swaps. The + // lower the input shuffle n value, the more likely this assertion will + // fail. For a n=100 shuffle, this is exceptionally unlikely to fail. + assert.Greater(t, swapDifferent, swapSame, + "Expected more different-element swaps than self-swaps", + ) + }) + + t.Run("swap ranges", func(t *testing.T) { + t.Parallel() + + n := 32 + runs := 1000 + + for run := 0; run < runs; run++ { + called := 0 + p := recoverPanic(func() { + Shuffle(n, func(i, j int) { + called++ + + // Verify indices are in bounds. + assert.True(t, + i >= 0 && i < n, "Out of bounds index i = %d", i, + ) + assert.True(t, + j >= 0 && j < n, "Out of bounds index j = %d", j, + ) + + // For Fisher-Yates, i should be > 0 and j should be in + // range [0,i]. + assert.Greater(t, i, 0, "Expected i > 0, got i=%d", i) + assert.True(t, + j >= 0 && j <= i, + "Expected j in range [0,%d], got j=%d", i, j, + ) + }) + }) + require.Nil(t, p, "Did not expect a panic") + + // Fisher-Yates with n elements should make exactly n-1 swaps + expected := n - 1 + assert.Equal(t, expected, called, + "Expected %d swap calls, got %d", expected, called, + ) + } + }) + + t.Run("all permutations", func(t *testing.T) { + t.Parallel() + + // Use a small array of 5 elements to make it feasible to track all + // permutations. + n := 5 + fact := factorial(n) // 120 + runs := fact * 3000 // 360000 + + permCounts := make(map[string]int) + for i := 0; i < runs; i++ { + arr := make([]int, n) + for i := range arr { + arr[i] = i + } + + p := recoverPanic(func() { + Shuffle(len(arr), func(i, j int) { + arr[i], arr[j] = arr[j], arr[i] + }) + }) + require.Nil(t, p, "Did not expect a panic") + + // Convert the permutation to a string key and count it. + key := fmt.Sprintf("%v", arr) + permCounts[key]++ + } + + assert.Equal(t, fact, len(permCounts), + "Expected %d different permutations", fact, + ) + + wantCount := float64(runs) / float64(fact) + margin := 0.15 + minAcceptable := int(wantCount * (1 - margin)) + maxAcceptable := int(wantCount * (1 + margin)) + + for perm, count := range permCounts { + assert.True(t, + count >= minAcceptable && count <= maxAcceptable, + "Non-uniform distribution for %s: count=%d, expected=%v±%v", + perm, count, wantCount, wantCount*margin, + ) + } + }) + + t.Run("distribution", func(t *testing.T) { + t.Parallel() + // Track which positions received which random indices + n := 100 + posCounts := make([]map[int]int, n) + for i := range posCounts { + posCounts[i] = make(map[int]int) + } + + runs := 3000 + for run := 0; run < runs; run++ { + p := recoverPanic(func() { + Shuffle(n, func(i, j int) { + posCounts[i][j]++ + }) + }) + require.Nil(t, p, "Did not expect a panic") + } + + // For each position, check that it received a reasonable distribution. + for i := n - 1; i >= n-len(posCounts); i-- { + // Calculate how many unique positions we should expect. + // Position i should receive random positions from 0 to i, and + // allow for some statistical variation. + want := int(float64(i+1) * 0.9) + assert.GreaterOrEqual(t, + len(posCounts[i]), want, + "Position %d: expected ~%d unique indices, got %d", + i, want, len(posCounts[i]), + ) + } + }) +} + +func TestShuffleSlice(t *testing.T) { + t.Parallel() + + t.Run("empty slice", func(t *testing.T) { + t.Parallel() + + slice := []int{} + p := recoverPanic(func() { + ShuffleSlice(slice) + }) + require.Nil(t, p, "Did not expect a panic") + assert.Empty(t, slice) + }) + + t.Run("single element", func(t *testing.T) { + t.Parallel() + + slice := []int{42} + origSlice := make([]int, len(slice)) + copy(origSlice, slice) + + p := recoverPanic(func() { + ShuffleSlice(slice) + }) + require.Nil(t, p, "Did not expect a panic") + + assert.Equal(t, + origSlice, slice, "Single element slice should remain unchanged", + ) + }) + + t.Run("two elements", func(t *testing.T) { + t.Parallel() + + slice := []int{1, 2} + origSlice := make([]int, len(slice)) + copy(origSlice, slice) + + p := recoverPanic(func() { + ShuffleSlice(slice) + }) + require.Nil(t, p, "Did not expect a panic") + + // With two elements, the slice might remain the same or be swapped + assert.Len(t, slice, len(origSlice)) + assert.ElementsMatch(t, origSlice, slice) + }) + + t.Run("basic", func(t *testing.T) { + t.Parallel() + + slice := make([]int, 100) + for i := range slice { + slice[i] = i + } + + sliceCopy := make([]int, len(slice)) + copy(sliceCopy, slice) + + p := recoverPanic(func() { + ShuffleSlice(slice) + }) + require.Nil(t, p, "Did not expect a panic") + + assert.NotEqual(t, + sliceCopy, slice, "ShuffleSlice did not change the slice", + ) + assert.ElementsMatch(t, + sliceCopy, slice, "ShuffleSlice changed elements", + ) + }) + + t.Run("string slice", func(t *testing.T) { + t.Parallel() + + strSlice := []string{"a", "b", "c", "d", "e"} + strCopy := make([]string, len(strSlice)) + copy(strCopy, strSlice) + + p := recoverPanic(func() { + ShuffleSlice(strSlice) + }) + require.Nil(t, p, "Did not expect a panic") + + assert.ElementsMatch(t, strCopy, strSlice) + }) + + t.Run("struct slice", func(t *testing.T) { + t.Parallel() + + type testStruct struct { + id int + name string + } + structSlice := []testStruct{ + {1, "one"}, + {2, "two"}, + {3, "three"}, + {4, "four"}, + } + structCopy := make([]testStruct, len(structSlice)) + copy(structCopy, structSlice) + + p := recoverPanic(func() { + ShuffleSlice(structSlice) + }) + require.Nil(t, p, "Did not expect a panic") + assert.ElementsMatch(t, structCopy, structSlice) + }) + + t.Run("all permutations", func(t *testing.T) { + t.Parallel() + + // Use a small slice of 5 elements to make it feasible to track all + // permutations. + n := 5 + fact := factorial(n) // 120 + runs := fact * 3000 // 360000 + + permCounts := make(map[string]int) + for i := 0; i < runs; i++ { + slice := make([]int, n) + for j := range slice { + slice[j] = j + } + + p := recoverPanic(func() { + ShuffleSlice(slice) + }) + require.Nil(t, p, "Did not expect a panic") + + // Convert the permutation to a string key and count it. + key := fmt.Sprintf("%v", slice) + permCounts[key]++ + } + + assert.Equal(t, fact, len(permCounts), + "Expected %d different permutations", fact, + ) + + wantCount := float64(runs) / float64(fact) + margin := 0.15 + minAcceptable := int(wantCount * (1 - margin)) + maxAcceptable := int(wantCount * (1 + margin)) + + for perm, count := range permCounts { + assert.True(t, + count >= minAcceptable && count <= maxAcceptable, + "Non-uniform distribution for %s: count=%d, expected=%v±%v", + perm, count, wantCount, wantCount*margin, + ) + } + }) + + t.Run("distribution", func(t *testing.T) { + t.Parallel() + + // Track where each original index ends up after shuffling + n := 100 + // posCounts[originalPos][newPos] tracks how many times + // the element originally at position i ended up at position j + posCounts := make([]map[int]int, n) + for i := range posCounts { + posCounts[i] = make(map[int]int) + } + + runs := 3000 + for run := 0; run < runs; run++ { + // Create a slice where the value is its original position + slice := make([]int, n) + for i := range slice { + slice[i] = i + } + + p := recoverPanic(func() { + ShuffleSlice(slice) + }) + require.Nil(t, p, "Did not expect a panic") + + // Track where each original position ended up + for newPos, origPos := range slice { + posCounts[origPos][newPos]++ + } + } + + // For each original position, check that it was distributed + // reasonably across all possible new positions + for i := n - 1; i >= n-len(posCounts); i-- { + // Calculate how many unique positions we should expect. + // Position i should receive random positions from 0 to i, and + // allow for some statistical variation. + want := int(float64(i+1) * 0.9) + assert.GreaterOrEqual(t, + len(posCounts[i]), want, + "Original position %d: expected ~%d unique positions, got %d", + i, want, len(posCounts[i]), + ) + } + }) +} diff --git a/shuffle.go b/shuffle.go new file mode 100644 index 0000000..c565043 --- /dev/null +++ b/shuffle.go @@ -0,0 +1,51 @@ +package rands + +import ( + "fmt" +) + +var ( + ErrShuffle = fmt.Errorf("%w: shuffle", Err) + ErrInvalidShuffleNegativeN = fmt.Errorf( + "%w: n must not be negative", ErrShuffle, + ) +) + +// Shuffle randomizes the order of a collection of n elements using +// cryptographically secure random values from crypto/rand. It implements the +// Fisher-Yates shuffle algorithm. +// +// The swap function is called to exchange values at indices i and j. This +// signature is compatible with Shuffle from math/rand and math/rand/v2 for easy +// migration. +func Shuffle(n int, swap func(i, j int)) error { + if n < 0 { + return ErrInvalidShuffleNegativeN + } + + for i := n - 1; i > 0; i-- { + j, err := Int(i + 1) + if err != nil { + return err + } + + swap(i, j) + } + + return nil +} + +// ShuffleSlice randomizes the order of elements in a slice in-place using +// cryptographically secure random values from crypto/rand. +// +// It implements the Fisher-Yates shuffle algorithm. +func ShuffleSlice[T any](slice []T) error { + // If the slice has one or no elements, there's nothing to shuffle. + if len(slice) < 2 { + return nil + } + + return Shuffle(len(slice), func(i, j int) { + slice[i], slice[j] = slice[j], slice[i] + }) +} diff --git a/shuffle_example_test.go b/shuffle_example_test.go new file mode 100644 index 0000000..e43175c --- /dev/null +++ b/shuffle_example_test.go @@ -0,0 +1,32 @@ +package rands_test + +import ( + "fmt" + "log" + + "github.com/jimeh/rands" +) + +func ExampleShuffle() { + numbers := []int{1, 2, 3, 4, 5} + + err := rands.Shuffle(len(numbers), func(i, j int) { + numbers[i], numbers[j] = numbers[j], numbers[i] + }) + if err != nil { + log.Fatal(err) + } + + fmt.Println(numbers) // => [2 4 5 1 3] +} + +func ExampleShuffleSlice() { + mixed := []any{1, "two", 3.14, true, nil} + + err := rands.ShuffleSlice(mixed) + if err != nil { + log.Fatal(err) + } + + fmt.Println(mixed) // => [3.14 true 1 two ] +} diff --git a/shuffle_test.go b/shuffle_test.go new file mode 100644 index 0000000..cb45932 --- /dev/null +++ b/shuffle_test.go @@ -0,0 +1,444 @@ +package rands + +import ( + "fmt" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func factorial(n int) int { + factorial := 1 + for i := 1; i <= n; i++ { + factorial *= i + } + + return factorial +} + +func TestShuffle(t *testing.T) { + t.Parallel() + + t.Run("n < 0", func(t *testing.T) { + t.Parallel() + + err := Shuffle(-1, func(_, _ int) {}) + + require.ErrorIs(t, err, ErrInvalidShuffleNegativeN) + require.ErrorIs(t, err, ErrShuffle) + require.ErrorIs(t, err, Err) + }) + + t.Run("n == 0", func(t *testing.T) { + t.Parallel() + + swapCount := 0 + err := Shuffle(0, func(_, _ int) { + swapCount++ + }) + require.NoError(t, err) + + assert.Equal(t, 0, swapCount) + }) + + t.Run("n == 1", func(t *testing.T) { + t.Parallel() + + swapCount := 0 + err := Shuffle(1, func(_, _ int) { + swapCount++ + }) + require.NoError(t, err) + + assert.Equal(t, 0, swapCount) + }) + + t.Run("n == 2", func(t *testing.T) { + t.Parallel() + + swapCount := 0 + err := Shuffle(2, func(_, _ int) { + swapCount++ + }) + require.NoError(t, err) + + assert.Equal(t, 1, swapCount) + }) + + t.Run("basic", func(t *testing.T) { + t.Parallel() + + arr := make([]int, 100) + for i := range arr { + arr[i] = i + } + + arrCopy := make([]int, len(arr)) + copy(arrCopy, arr) + + err := Shuffle(len(arr), func(i, j int) { + arr[i], arr[j] = arr[j], arr[i] + }) + require.NoError(t, err, "Shuffle returned an error") + + assert.NotEqual(t, arrCopy, arr, "Shuffle did not change the array") + assert.ElementsMatch(t, arrCopy, arr, "Shuffle changed elements") + }) + + t.Run("swaps", func(t *testing.T) { + t.Parallel() + + swapSame := 0 + swapDifferent := 0 + arr := make([]int, 100) + for j := range arr { + arr[j] = j + } + + err := Shuffle(len(arr), func(i, j int) { + if i == j { + swapSame++ + } else { + swapDifferent++ + } + + arr[i], arr[j] = arr[j], arr[i] + }) + require.NoError(t, err, "Shuffle returned an error") + + // Fisher-Yates with n elements should make exactly n-1 swaps + assert.Equal(t, len(arr)-1, swapSame+swapDifferent, + "Unexpected swaps count", + ) + + // Ensure we have more different-element swaps than self-swaps. The + // lower the input shuffle n value, the more likely this assertion will + // fail. For a n=100 shuffle, this is exceptionally unlikely to fail. + assert.Greater(t, swapDifferent, swapSame, + "Expected more different-element swaps than self-swaps", + ) + }) + + t.Run("swap ranges", func(t *testing.T) { + t.Parallel() + + n := 32 + runs := 1000 + + for run := 0; run < runs; run++ { + called := 0 + err := Shuffle(n, func(i, j int) { + called++ + + // Verify indices are in bounds. + assert.True(t, i >= 0 && i < n, "Out of bounds index i = %d", i) + assert.True(t, j >= 0 && j < n, "Out of bounds index j = %d", j) + + // For Fisher-Yates, i should be > 0 and j should be in range + // [0,i]. + assert.Greater(t, i, 0, "Expected i > 0, got i=%d", i) + assert.True(t, + j >= 0 && j <= i, + "Expected j in range [0,%d], got j=%d", i, j, + ) + }) + require.NoError(t, err, "Shuffle returned an error") + + // Fisher-Yates with n elements should make exactly n-1 swaps + expected := n - 1 + assert.Equal(t, expected, called, + "Expected %d swap calls, got %d", expected, called, + ) + } + }) + + t.Run("all permutations", func(t *testing.T) { + t.Parallel() + + // Use a small array of 5 elements to make it feasible to track all + // permutations. + n := 5 + fact := factorial(n) // 120 + runs := fact * 3000 // 360000 + + permCounts := make(map[string]int) + for i := 0; i < runs; i++ { + arr := make([]int, n) + for i := range arr { + arr[i] = i + } + + err := Shuffle(len(arr), func(i, j int) { + arr[i], arr[j] = arr[j], arr[i] + }) + require.NoError(t, err, "Shuffle returned an error") + + // Convert the permutation to a string key and count it. + key := fmt.Sprintf("%v", arr) + permCounts[key]++ + } + + assert.Equal(t, fact, len(permCounts), + "Expected %d different permutations", fact, + ) + + wantCount := float64(runs) / float64(fact) + margin := 0.15 + minAcceptable := int(wantCount * (1 - margin)) + maxAcceptable := int(wantCount * (1 + margin)) + + for perm, count := range permCounts { + assert.True(t, + count >= minAcceptable && count <= maxAcceptable, + "Non-uniform distribution for %s: count=%d, expected=%v±%v", + perm, count, wantCount, wantCount*margin, + ) + } + }) + + t.Run("distribution", func(t *testing.T) { + t.Parallel() + // Track which positions received which random indices + n := 100 + posCounts := make([]map[int]int, n) + for i := range posCounts { + posCounts[i] = make(map[int]int) + } + + runs := 3000 + for run := 0; run < runs; run++ { + err := Shuffle(n, func(i, j int) { + posCounts[i][j]++ + }) + require.NoError(t, err, "Shuffle returned an error") + } + + // For each position, check that it received a reasonable distribution. + for i := n - 1; i >= n-len(posCounts); i-- { + // Calculate how many unique positions we should expect. + // Position i should receive random positions from 0 to i, and + // allow for some statistical variation. + want := int(float64(i+1) * 0.9) + assert.GreaterOrEqual(t, + len(posCounts[i]), want, + "Position %d: expected ~%d unique indices, got %d", + i, want, len(posCounts[i]), + ) + } + }) +} + +func BenchmarkShuffle(b *testing.B) { + ranges := []int{32, 64, 128, 1024, 4096} + for _, n := range ranges { + b.Run(fmt.Sprintf("n=%d", n), func(b *testing.B) { + for i := 0; i < b.N; i++ { + _ = Shuffle(n, func(_, _ int) {}) + } + }) + } +} + +func TestShuffleSlice(t *testing.T) { + t.Parallel() + + t.Run("empty slice", func(t *testing.T) { + t.Parallel() + + slice := []int{} + err := ShuffleSlice(slice) + require.NoError(t, err) + assert.Empty(t, slice) + }) + + t.Run("single element", func(t *testing.T) { + t.Parallel() + + slice := []int{42} + origSlice := make([]int, len(slice)) + copy(origSlice, slice) + + err := ShuffleSlice(slice) + require.NoError(t, err) + + assert.Equal(t, + origSlice, slice, "Single element slice should remain unchanged", + ) + }) + + t.Run("two elements", func(t *testing.T) { + t.Parallel() + + slice := []int{1, 2} + origSlice := make([]int, len(slice)) + copy(origSlice, slice) + + err := ShuffleSlice(slice) + require.NoError(t, err) + + // With two elements, the slice might remain the same or be swapped + assert.Len(t, slice, len(origSlice)) + assert.ElementsMatch(t, origSlice, slice) + }) + + t.Run("basic", func(t *testing.T) { + t.Parallel() + + slice := make([]int, 100) + for i := range slice { + slice[i] = i + } + + sliceCopy := make([]int, len(slice)) + copy(sliceCopy, slice) + + err := ShuffleSlice(slice) + require.NoError(t, err, "ShuffleSlice returned an error") + + assert.NotEqual(t, + sliceCopy, slice, "ShuffleSlice did not change the slice", + ) + assert.ElementsMatch(t, + sliceCopy, slice, "ShuffleSlice changed elements", + ) + }) + + t.Run("string slice", func(t *testing.T) { + t.Parallel() + + strSlice := []string{"a", "b", "c", "d", "e"} + strCopy := make([]string, len(strSlice)) + copy(strCopy, strSlice) + + err := ShuffleSlice(strSlice) + require.NoError(t, err) + + assert.ElementsMatch(t, strCopy, strSlice) + }) + + t.Run("struct slice", func(t *testing.T) { + t.Parallel() + + type testStruct struct { + id int + name string + } + structSlice := []testStruct{ + {1, "one"}, + {2, "two"}, + {3, "three"}, + {4, "four"}, + } + structCopy := make([]testStruct, len(structSlice)) + copy(structCopy, structSlice) + + err := ShuffleSlice(structSlice) + require.NoError(t, err) + assert.ElementsMatch(t, structCopy, structSlice) + }) + + t.Run("all permutations", func(t *testing.T) { + t.Parallel() + + // Use a small slice of 5 elements to make it feasible to track all + // permutations. + n := 5 + fact := factorial(n) // 120 + runs := fact * 3000 // 360000 + + permCounts := make(map[string]int) + for i := 0; i < runs; i++ { + slice := make([]int, n) + for j := range slice { + slice[j] = j + } + + err := ShuffleSlice(slice) + require.NoError(t, err, "ShuffleSlice returned an error") + + // Convert the permutation to a string key and count it. + key := fmt.Sprintf("%v", slice) + permCounts[key]++ + } + + assert.Equal(t, fact, len(permCounts), + "Expected %d different permutations", fact, + ) + + wantCount := float64(runs) / float64(fact) + margin := 0.15 + minAcceptable := int(wantCount * (1 - margin)) + maxAcceptable := int(wantCount * (1 + margin)) + + for perm, count := range permCounts { + assert.True(t, + count >= minAcceptable && count <= maxAcceptable, + "Non-uniform distribution for %s: count=%d, expected=%v±%v", + perm, count, wantCount, wantCount*margin, + ) + } + }) + + t.Run("distribution", func(t *testing.T) { + t.Parallel() + + // Track where each original index ends up after shuffling + n := 100 + // posCounts[originalPos][newPos] tracks how many times + // the element originally at position i ended up at position j + posCounts := make([]map[int]int, n) + for i := range posCounts { + posCounts[i] = make(map[int]int) + } + + runs := 3000 + for run := 0; run < runs; run++ { + // Create a slice where the value is its original position + slice := make([]int, n) + for i := range slice { + slice[i] = i + } + + err := ShuffleSlice(slice) + require.NoError(t, err, "ShuffleSlice returned an error") + + // Track where each original position ended up + for newPos, origPos := range slice { + posCounts[origPos][newPos]++ + } + } + + // For each original position, check that it was distributed + // reasonably across all possible new positions + for i := n - 1; i >= n-len(posCounts); i-- { + // Calculate how many unique positions we should expect. + // Position i should receive random positions from 0 to i, and + // allow for some statistical variation. + want := int(float64(i+1) * 0.9) + assert.GreaterOrEqual(t, + len(posCounts[i]), want, + "Original position %d: expected ~%d unique positions, got %d", + i, want, len(posCounts[i]), + ) + } + }) +} + +func BenchmarkShuffleSlice(b *testing.B) { + ranges := []int{32, 64, 128, 1024, 4096} + for _, n := range ranges { + b.Run(fmt.Sprintf("n=%d", n), func(b *testing.B) { + b.StopTimer() + slice := make([]int, n) + for i := range slice { + slice[i] = i + } + b.StartTimer() + + for i := 0; i < b.N; i++ { + _ = ShuffleSlice(slice) + } + }) + } +}