summary refs log tree commit diff
path: root/lib
diff options
context:
space:
mode:
author Luke Shumaker <lukeshu@lukeshu.com> 2022-07-17 11:54:49 -0600
committer Luke Shumaker <lukeshu@lukeshu.com> 2022-07-17 11:54:49 -0600
commit 4952fc1880bf0f4286b17cbfbe0c49a132d09ebc (patch)
tree 38052a3f6119d70216d796bcce65e8f1fb984ab2 /lib
parent 987a4b238e047238bd83384c87b8317afdd45ad8 (diff)
implement wildcards in the KMP IndexAll
Diffstat (limited to 'lib')
-rw-r--r-- lib/diskio/kmp.go 53
-rw-r--r-- lib/diskio/kmp_test.go 54
2 files changed, 97 insertions, 10 deletions
diff --git a/lib/diskio/kmp.go b/lib/diskio/kmp.go
index da19e81..15537de 100644
--- a/lib/diskio/kmp.go
+++ b/lib/diskio/kmp.go
@@ -9,12 +9,42 @@ import (
"io"
)
-func mustGet[K ~int64, V any](seq Sequence[K, V], i K) V {
- val, err := seq.Get(i)
- if err != nil {
- panic(err)
+// ErrWildcard is the error that a pattern Sequence's Get returns for
+// a position that is a wildcard.
+var ErrWildcard = errors.New("wildcard")
+
+// kmpEq2 reports whether element aI of aS equals element bI of bS.
+// A position whose Get returns ErrWildcard takes on the other side's
+// value before the comparison is made.
+//
+// Panics if either Get returns an error other than ErrWildcard.
+func kmpEq2[K ~int64, V comparable](aS Sequence[K, V], aI K, bS Sequence[K, V], bI K) bool {
+	aV, aErr := aS.Get(aI)
+	bV, bErr := bS.Get(bI)
+	// Resolve the 'a' side first: anything but a wildcard is fatal.
+	if aErr != nil {
+		if !errors.Is(aErr, ErrWildcard) {
+			panic(aErr)
+		}
+		aV, aErr = bV, nil
+	}
+	// Then resolve the 'b' side the same way.
+	if bErr != nil {
+		if !errors.Is(bErr, ErrWildcard) {
+			panic(bErr)
+		}
+		bV, bErr = aV, nil
+	}
+	// Defensive only: by this point each error has either been
+	// cleared or has caused a panic above.
+	if aErr != nil || bErr != nil {
+		return false
}
- return val
+	return aV == bV
+}
+
+// kmpEq1 reports whether the value aV matches element bI of bS.  A
+// position for which bS.Get returns ErrWildcard matches any aV.
+//
+// Panics if bS.Get returns any other error.
+func kmpEq1[K ~int64, V comparable](aV V, bS Sequence[K, V], bI K) bool {
+	bV, bErr := bS.Get(bI)
+	switch {
+	case bErr == nil:
+		return aV == bV
+	case errors.Is(bErr, ErrWildcard):
+		return true
+	default:
+		panic(bErr)
+	}
}
// buildKMPTable takes the string 'substr', and returns a table such
@@ -23,7 +53,7 @@ func mustGet[K ~int64, V any](seq Sequence[K, V], i K) V {
func buildKMPTable[K ~int64, V comparable](substr Sequence[K, V]) ([]K, error) {
var substrLen K
for {
- if _, err := substr.Get(substrLen); err != nil {
+ if _, err := substr.Get(substrLen); err != nil && !errors.Is(err, ErrWildcard) {
if errors.Is(err, io.EOF) {
break
}
@@ -41,11 +71,11 @@ func buildKMPTable[K ~int64, V comparable](substr Sequence[K, V]) ([]K, error) {
}
val := table[j-1]
// not a match; go back
- for val > 0 && mustGet(substr, j) != mustGet(substr, val) {
+ for val > 0 && !kmpEq2(substr, j, substr, val) {
val = table[val-1]
}
// is a match; go forward
- if mustGet(substr, val) == mustGet(substr, j) {
+ if kmpEq2(substr, val, substr, j) {
val++
}
table[j] = val
@@ -64,6 +94,9 @@ func buildKMPTable[K ~int64, V comparable](substr Sequence[K, V]) ([]K, error) {
//
// Will panic if the length of 'substr' is 0.
//
+// The 'substr' may include wildcard characters by returning
+// ErrWildcard for a position.
+//
// Uses the Knuth-Morris-Pratt algorithm.
func IndexAll[K ~int64, V comparable](str, substr Sequence[K, V]) ([]K, error) {
table, err := buildKMPTable(substr)
@@ -89,12 +122,12 @@ func IndexAll[K ~int64, V comparable](str, substr Sequence[K, V]) ([]K, error) {
}
// Consider 'chr'
- for curMatchLen > 0 && chr != mustGet(substr, curMatchLen) { // shorten the match
+ for curMatchLen > 0 && !kmpEq1(chr, substr, curMatchLen) { // shorten the match
overlap := table[curMatchLen-1]
curMatchBeg += curMatchLen - overlap
curMatchLen = overlap
}
- if chr == mustGet(substr, curMatchLen) { // lengthen the match
+ if kmpEq1(chr, substr, curMatchLen) { // lengthen the match
if curMatchLen == 0 {
curMatchBeg = pos
}
diff --git a/lib/diskio/kmp_test.go b/lib/diskio/kmp_test.go
index 51c7b5e..59b6224 100644
--- a/lib/diskio/kmp_test.go
+++ b/lib/diskio/kmp_test.go
@@ -6,6 +6,7 @@ package diskio
import (
"bytes"
+ "io"
"testing"
"github.com/stretchr/testify/assert"
@@ -65,3 +66,56 @@ func FuzzIndexAll(f *testing.F) {
assert.Equal(t, exp, act)
})
}
+
+// RESeq is a byte sequence backed by a string in which the byte '.'
+// stands for a wildcard position.
+type RESeq string
+
+// Get returns the byte at index i.  It returns io.EOF when i is out
+// of range, and ErrWildcard when the byte at i is '.'.
+func (re RESeq) Get(i int64) (byte, error) {
+	if i < 0 || int64(len(re)) <= i {
+		return 0, io.EOF
+	}
+	switch b := re[int(i)]; b {
+	case '.':
+		return 0, ErrWildcard
+	default:
+		return b, nil
+	}
+}
+
+// TestKMPWildcard checks that IndexAll honors wildcard positions in
+// 'substr'.  Patterns are given as RESeq values, in which '.' is a
+// wildcard.
+func TestKMPWildcard(t *testing.T) {
+	type testcase struct {
+		InStr      string
+		InSubstr   string
+		ExpMatches []int64
+	}
+	testcases := map[string]testcase{
+		"trivial-bar": {
+			InStr:      "foo_bar",
+			InSubstr:   "foo.ba.",
+			ExpMatches: []int64{0},
+		},
+		"trivial-baz": {
+			InStr:      "foo-baz",
+			InSubstr:   "foo.ba.",
+			ExpMatches: []int64{0},
+		},
+		"suffix": {
+			InStr:      "foobarbaz",
+			InSubstr:   "...baz",
+			ExpMatches: []int64{3},
+		},
+		"overlap": {
+			InStr:      "foobarbar",
+			InSubstr:   "...bar",
+			ExpMatches: []int64{0, 3},
+		},
+	}
+	for tcName, tc := range testcases {
+		tc := tc // capture a per-iteration copy for the subtest closure
+		t.Run(tcName, func(t *testing.T) {
+			matches, err := IndexAll[int64, byte](
+				StringSequence[int64](tc.InStr),
+				RESeq(tc.InSubstr))
+			assert.NoError(t, err)
+			assert.Equal(t, tc.ExpMatches, matches)
+		})
+	}
+}