summaryrefslogtreecommitdiff
path: root/lib/btrfsprogs/btrfsinspect/rebuildmappings/kmp_test.go
diff options
context:
space:
mode:
Diffstat (limited to 'lib/btrfsprogs/btrfsinspect/rebuildmappings/kmp_test.go')
-rw-r--r--lib/btrfsprogs/btrfsinspect/rebuildmappings/kmp_test.go52
1 files changed, 26 insertions, 26 deletions
diff --git a/lib/btrfsprogs/btrfsinspect/rebuildmappings/kmp_test.go b/lib/btrfsprogs/btrfsinspect/rebuildmappings/kmp_test.go
index 910452a..acec9b8 100644
--- a/lib/btrfsprogs/btrfsinspect/rebuildmappings/kmp_test.go
+++ b/lib/btrfsprogs/btrfsinspect/rebuildmappings/kmp_test.go
@@ -6,7 +6,6 @@ package rebuildmappings
import (
"bytes"
- "io"
"testing"
"github.com/stretchr/testify/assert"
@@ -15,11 +14,28 @@ import (
"git.lukeshu.com/btrfs-progs-ng/lib/diskio"
)
+type bytePattern[K ~int64 | ~int] []byte
+
+var _ kmpPattern[int, byte] = bytePattern[int]{}
+
+// PatLen implements kmpPattern.
+func (s bytePattern[K]) PatLen() K {
+ return K(len(s))
+}
+
+// PatGet implements kmpPattern.
+func (s bytePattern[K]) PatGet(i K) (byte, bool) {
+ chr := s[int(i)]
+ if chr == '.' {
+ return 0, false
+ }
+ return chr, true
+}
+
func TestBuildKMPTable(t *testing.T) {
t.Parallel()
- substr := diskio.SliceSequence[int64, byte]([]byte("ababaa"))
- table, err := buildKMPTable[int64, byte](substr)
- require.NoError(t, err)
+ substr := bytePattern[int64]([]byte("ababaa"))
+ table := buildKMPTable[int64, byte](substr)
require.Equal(t,
[]int64{0, 0, 1, 2, 3, 1},
table)
@@ -33,8 +49,7 @@ func TestBuildKMPTable(t *testing.T) {
func FuzzBuildKMPTable(f *testing.F) {
f.Add([]byte("ababaa"))
f.Fuzz(func(t *testing.T, substr []byte) {
- table, err := buildKMPTable[int64, byte](diskio.SliceSequence[int64, byte](substr))
- require.NoError(t, err)
+ table := buildKMPTable[int64, byte](bytePattern[int64](substr))
require.Equal(t, len(substr), len(table), "length")
for j, val := range table {
matchLen := j + 1
@@ -62,27 +77,13 @@ func FuzzIndexAll(f *testing.F) {
t.Logf("str =%q", str)
t.Logf("substr=%q", substr)
exp := NaiveIndexAll(str, substr)
- act, err := IndexAll[int64, byte](
- &diskio.ByteReaderSequence[int64]{R: bytes.NewReader(str)},
- diskio.SliceSequence[int64, byte](substr))
- assert.NoError(t, err)
+ act := indexAll[int64, byte](
+ diskio.SliceSequence[int64, byte](str),
+ bytePattern[int64](substr))
assert.Equal(t, exp, act)
})
}
-type RESeq string
-
-func (re RESeq) Get(i int64) (byte, error) {
- if i < 0 || i >= int64(len(re)) {
- return 0, io.EOF
- }
- chr := re[int(i)]
- if chr == '.' {
- return 0, ErrWildcard
- }
- return chr, nil
-}
-
func TestKMPWildcard(t *testing.T) {
t.Parallel()
type testcase struct {
@@ -116,10 +117,9 @@ func TestKMPWildcard(t *testing.T) {
tc := tc
t.Run(tcName, func(t *testing.T) {
t.Parallel()
- matches, err := IndexAll[int64, byte](
+ matches := indexAll[int64, byte](
diskio.StringSequence[int64](tc.InStr),
- RESeq(tc.InSubstr))
- assert.NoError(t, err)
+ bytePattern[int64](tc.InSubstr))
assert.Equal(t, tc.ExpMatches, matches)
})
}