summary refs log tree commit diff
diff options
context:
space:
mode:
author Luke Shumaker <lukeshu@lukeshu.com> 2023-02-03 19:50:35 -0700
committer Luke Shumaker <lukeshu@lukeshu.com> 2023-02-04 20:53:23 -0700
commit ef60daef395b20b67042c011f5b2a1131049e275 (patch)
tree c70aa1661272e10883bbc57373cf00ab980ef336
parent 77f3c0d7cd21274d00984b72dfce05394d11bdd0 (diff)
rebuildmappings: Optimize the KMP search
-rw-r--r-- lib/btrfs/btrfssum/sumrun.go 20
-rw-r--r-- lib/btrfsprogs/btrfsinspect/rebuildmappings/fuzzymatchsums.go 2
-rw-r--r-- lib/btrfsprogs/btrfsinspect/rebuildmappings/kmp.go 106
-rw-r--r-- lib/btrfsprogs/btrfsinspect/rebuildmappings/kmp_test.go 52
-rw-r--r-- lib/btrfsprogs/btrfsinspect/rebuildmappings/matchsums.go 5
-rw-r--r-- lib/btrfsprogs/btrfsinspect/rebuildmappings/rebuildmappings.go 4
-rw-r--r-- lib/btrfsprogs/btrfsinspect/rebuildmappings/sumrunwithgaps.go 21
-rw-r--r-- lib/diskio/seq.go 63
8 files changed, 103 insertions, 170 deletions
diff --git a/lib/btrfs/btrfssum/sumrun.go b/lib/btrfs/btrfssum/sumrun.go
index 1000e7a..bc2db3f 100644
--- a/lib/btrfs/btrfssum/sumrun.go
+++ b/lib/btrfs/btrfssum/sumrun.go
@@ -6,9 +6,9 @@ package btrfssum
import (
"context"
- "io"
"git.lukeshu.com/btrfs-progs-ng/lib/btrfs/btrfsvol"
+ "git.lukeshu.com/btrfs-progs-ng/lib/diskio"
)
type SumRun[Addr btrfsvol.IntAddr[Addr]] struct {
@@ -20,21 +20,21 @@ type SumRun[Addr btrfsvol.IntAddr[Addr]] struct {
Sums ShortSum
}
-func (run SumRun[Addr]) NumSums() int {
+var _ diskio.Sequence[int, ShortSum] = SumRun[btrfsvol.LogicalAddr]{}
+
+// SeqLen implements diskio.Sequence[int, ShortSum].
+func (run SumRun[Addr]) SeqLen() int {
return len(run.Sums) / run.ChecksumSize
}
func (run SumRun[Addr]) Size() btrfsvol.AddrDelta {
- return btrfsvol.AddrDelta(run.NumSums()) * BlockSize
+ return btrfsvol.AddrDelta(run.SeqLen()) * BlockSize
}
-// Get implements diskio.Sequence[int, ShortSum]
-func (run SumRun[Addr]) Get(sumIdx int64) (ShortSum, error) {
- if sumIdx < 0 || int(sumIdx) >= run.NumSums() {
- return "", io.EOF
- }
- off := int(sumIdx) * run.ChecksumSize
- return run.Sums[off : off+run.ChecksumSize], nil
+// SeqGet implements diskio.Sequence[int, ShortSum].
+func (run SumRun[Addr]) SeqGet(sumIdx int) ShortSum {
+ off := sumIdx * run.ChecksumSize
+ return run.Sums[off : off+run.ChecksumSize]
}
func (run SumRun[Addr]) SumForAddr(addr Addr) (ShortSum, bool) {
diff --git a/lib/btrfsprogs/btrfsinspect/rebuildmappings/fuzzymatchsums.go b/lib/btrfsprogs/btrfsinspect/rebuildmappings/fuzzymatchsums.go
index b51526b..6b75d84 100644
--- a/lib/btrfsprogs/btrfsinspect/rebuildmappings/fuzzymatchsums.go
+++ b/lib/btrfsprogs/btrfsinspect/rebuildmappings/fuzzymatchsums.go
@@ -69,7 +69,7 @@ func fuzzyMatchBlockGroupSums(ctx context.Context,
blockgroup := blockgroups[bgLAddr]
bgRun := SumsForLogicalRegion(logicalSums, blockgroup.LAddr, blockgroup.Size)
- d := bgRun.NumSums()
+ d := bgRun.PatLen()
matches := make(map[btrfsvol.QualifiedPhysicalAddr]int)
if err := bgRun.Walk(ctx, func(laddr btrfsvol.LogicalAddr, sum btrfssum.ShortSum) error { // O(n*…
off := laddr.Sub(bgRun.Addr)
diff --git a/lib/btrfsprogs/btrfsinspect/rebuildmappings/kmp.go b/lib/btrfsprogs/btrfsinspect/rebuildmappings/kmp.go
index eeaab0c..20772ba 100644
--- a/lib/btrfsprogs/btrfsinspect/rebuildmappings/kmp.go
+++ b/lib/btrfsprogs/btrfsinspect/rebuildmappings/kmp.go
@@ -6,48 +6,24 @@ package rebuildmappings
import (
"errors"
- "io"
"git.lukeshu.com/btrfs-progs-ng/lib/diskio"
)
-var ErrWildcard = errors.New("wildcard")
-
-func kmpEq2[K ~int64, V comparable](aS diskio.Sequence[K, V], aI K, bS diskio.Sequence[K, V], bI K) bool {
- aV, aErr := aS.Get(aI)
- bV, bErr := bS.Get(bI)
- if aErr != nil {
- //nolint:errorlint // The == is just a fast-path; we still fall back to errors.Is.
- if aErr == ErrWildcard || errors.Is(aErr, ErrWildcard) {
- aV = bV
- aErr = nil
- } else {
- panic(aErr)
- }
- }
- if bErr != nil {
- //nolint:errorlint // The == is just a fast-path; we still fall back to errors.Is.
- if bErr == ErrWildcard || errors.Is(bErr, ErrWildcard) {
- bV = aV
- bErr = nil
- } else {
- panic(bErr)
- }
- }
- if aErr != nil || bErr != nil {
- return false
- }
- return aV == bV
+type kmpPattern[K ~int64 | ~int, V comparable] interface {
+ PatLen() K
+ // Get the value at 'pos' in the sequence. Positions start at
+ // 0 and increment naturally. It is invalid to call PatGet(pos)
+ // with a pos that is >= PatLen(). If there is a gap/wildcard at
+ // pos, ok is false.
+ PatGet(pos K) (v V, ok bool)
}
-func kmpEq1[K ~int64, V comparable](aV V, bS diskio.Sequence[K, V], bI K) bool {
- bV, bErr := bS.Get(bI)
- if bErr != nil {
- //nolint:errorlint // The == is just a fast-path; we still fall back to errors.Is.
- if bErr == ErrWildcard || errors.Is(bErr, ErrWildcard) {
- return true
- }
- panic(bErr)
+func kmpSelfEq[K ~int64 | ~int, V comparable](s kmpPattern[K, V], aI K, bI K) bool {
+ aV, aOK := s.PatGet(aI)
+ bV, bOK := s.PatGet(bI)
+ if !aOK || !bOK {
+ return true
}
return aV == bV
}
@@ -55,18 +31,8 @@ func kmpEq1[K ~int64, V comparable](aV V, bS diskio.Sequence[K, V], bI K) bool {
// buildKMPTable takes the string 'substr', and returns a table such
// that 'table[matchLen-1]' is the largest value 'val' for which 'val < matchLen' and
// 'substr[:val] == substr[matchLen-val:matchLen]'.
-func buildKMPTable[K ~int64, V comparable](substr diskio.Sequence[K, V]) ([]K, error) {
- var substrLen K
- for {
- //nolint:errorlint // The == is just a fast-path; we still fall back to errors.Is.
- if _, err := substr.Get(substrLen); err != nil && !(err == ErrWildcard || errors.Is(err, ErrWildcard)) {
- if errors.Is(err, io.EOF) {
- break
- }
- return nil, err
- }
- substrLen++
- }
+func buildKMPTable[K ~int64 | ~int, V comparable](substr kmpPattern[K, V]) []K {
+ substrLen := substr.PatLen()
table := make([]K, substrLen)
for j := K(0); j < substrLen; j++ {
@@ -77,26 +43,31 @@ func buildKMPTable[K ~int64, V comparable](substr diskio.Sequence[K, V]) ([]K, e
}
val := table[j-1]
// not a match; go back
- for val > 0 && !kmpEq2(substr, j, substr, val) {
+ for val > 0 && !kmpSelfEq(substr, j, val) {
val = table[val-1]
}
// is a match; go forward
- if kmpEq2(substr, val, substr, j) {
+ if kmpSelfEq(substr, val, j) {
val++
}
table[j] = val
}
- return table, nil
+ return table
}
-// IndexAll returns the starting-position of all possibly-overlapping
+func kmpEq[K ~int64 | ~int, V comparable](aV V, bS kmpPattern[K, V], bI K) bool {
+ bV, ok := bS.PatGet(bI)
+ if !ok {
+ return true
+ }
+ return aV == bV
+}
+
+// indexAll returns the starting-position of all possibly-overlapping
// occurrences of 'substr' in the 'str' sequence.
//
// Will hop around in 'substr', but will only get the natural sequence
-// [0...) in order from 'str'. When hopping around in 'substr' it
-// assumes that once it has gotten a given index without error, it can
-// continue to do so without error; errors appearing later will cause
-// panics.
+// [0...) in order from 'str'.
//
// Will panic if the length of 'substr' is 0.
//
@@ -104,11 +75,8 @@ func buildKMPTable[K ~int64, V comparable](substr diskio.Sequence[K, V]) ([]K, e
// ErrWildcard for a position.
//
// Uses the Knuth-Morris-Pratt algorithm.
-func IndexAll[K ~int64, V comparable](str, substr diskio.Sequence[K, V]) ([]K, error) {
- table, err := buildKMPTable(substr)
- if err != nil {
- return nil, err
- }
+func indexAll[K ~int64 | ~int, V comparable](str diskio.Sequence[K, V], substr kmpPattern[K, V]) []K {
+ table := buildKMPTable(substr)
substrLen := K(len(table))
if substrLen == 0 {
panic(errors.New("rebuildmappings.IndexAll: empty substring"))
@@ -118,22 +86,17 @@ func IndexAll[K ~int64, V comparable](str, substr diskio.Sequence[K, V]) ([]K, e
var curMatchBeg K
var curMatchLen K
- for pos := K(0); ; pos++ {
- chr, err := str.Get(pos)
- if err != nil {
- if errors.Is(err, io.EOF) {
- err = nil
- }
- return matches, err
- }
+ strLen := str.SeqLen()
+ for pos := K(0); pos < strLen; pos++ {
+ chr := str.SeqGet(pos)
// Consider 'chr'
- for curMatchLen > 0 && !kmpEq1(chr, substr, curMatchLen) { // shorten the match
+ for curMatchLen > 0 && !kmpEq(chr, substr, curMatchLen) { // shorten the match
overlap := table[curMatchLen-1]
curMatchBeg += curMatchLen - overlap
curMatchLen = overlap
}
- if kmpEq1(chr, substr, curMatchLen) { // lengthen the match
+ if kmpEq(chr, substr, curMatchLen) { // lengthen the match
if curMatchLen == 0 {
curMatchBeg = pos
}
@@ -146,4 +109,5 @@ func IndexAll[K ~int64, V comparable](str, substr diskio.Sequence[K, V]) ([]K, e
}
}
}
+ return matches
}
diff --git a/lib/btrfsprogs/btrfsinspect/rebuildmappings/kmp_test.go b/lib/btrfsprogs/btrfsinspect/rebuildmappings/kmp_test.go
index 910452a..acec9b8 100644
--- a/lib/btrfsprogs/btrfsinspect/rebuildmappings/kmp_test.go
+++ b/lib/btrfsprogs/btrfsinspect/rebuildmappings/kmp_test.go
@@ -6,7 +6,6 @@ package rebuildmappings
import (
"bytes"
- "io"
"testing"
"github.com/stretchr/testify/assert"
@@ -15,11 +14,28 @@ import (
"git.lukeshu.com/btrfs-progs-ng/lib/diskio"
)
+type bytePattern[K ~int64 | ~int] []byte
+
+var _ kmpPattern[int, byte] = bytePattern[int]{}
+
+// PatLen implements kmpPattern.
+func (s bytePattern[K]) PatLen() K {
+ return K(len(s))
+}
+
+// PatGet implements kmpPattern.
+func (s bytePattern[K]) PatGet(i K) (byte, bool) {
+ chr := s[int(i)]
+ if chr == '.' {
+ return 0, false
+ }
+ return chr, true
+}
+
func TestBuildKMPTable(t *testing.T) {
t.Parallel()
- substr := diskio.SliceSequence[int64, byte]([]byte("ababaa"))
- table, err := buildKMPTable[int64, byte](substr)
- require.NoError(t, err)
+ substr := bytePattern[int64]([]byte("ababaa"))
+ table := buildKMPTable[int64, byte](substr)
require.Equal(t,
[]int64{0, 0, 1, 2, 3, 1},
table)
@@ -33,8 +49,7 @@ func TestBuildKMPTable(t *testing.T) {
func FuzzBuildKMPTable(f *testing.F) {
f.Add([]byte("ababaa"))
f.Fuzz(func(t *testing.T, substr []byte) {
- table, err := buildKMPTable[int64, byte](diskio.SliceSequence[int64, byte](substr))
- require.NoError(t, err)
+ table := buildKMPTable[int64, byte](bytePattern[int64](substr))
require.Equal(t, len(substr), len(table), "length")
for j, val := range table {
matchLen := j + 1
@@ -62,27 +77,13 @@ func FuzzIndexAll(f *testing.F) {
t.Logf("str =%q", str)
t.Logf("substr=%q", substr)
exp := NaiveIndexAll(str, substr)
- act, err := IndexAll[int64, byte](
- &diskio.ByteReaderSequence[int64]{R: bytes.NewReader(str)},
- diskio.SliceSequence[int64, byte](substr))
- assert.NoError(t, err)
+ act := indexAll[int64, byte](
+ diskio.SliceSequence[int64, byte](str),
+ bytePattern[int64](substr))
assert.Equal(t, exp, act)
})
}
-type RESeq string
-
-func (re RESeq) Get(i int64) (byte, error) {
- if i < 0 || i >= int64(len(re)) {
- return 0, io.EOF
- }
- chr := re[int(i)]
- if chr == '.' {
- return 0, ErrWildcard
- }
- return chr, nil
-}
-
func TestKMPWildcard(t *testing.T) {
t.Parallel()
type testcase struct {
@@ -116,10 +117,9 @@ func TestKMPWildcard(t *testing.T) {
tc := tc
t.Run(tcName, func(t *testing.T) {
t.Parallel()
- matches, err := IndexAll[int64, byte](
+ matches := indexAll[int64, byte](
diskio.StringSequence[int64](tc.InStr),
- RESeq(tc.InSubstr))
- assert.NoError(t, err)
+ bytePattern[int64](tc.InSubstr))
assert.Equal(t, tc.ExpMatches, matches)
})
}
diff --git a/lib/btrfsprogs/btrfsinspect/rebuildmappings/matchsums.go b/lib/btrfsprogs/btrfsinspect/rebuildmappings/matchsums.go
index eda37bd..a3e724e 100644
--- a/lib/btrfsprogs/btrfsinspect/rebuildmappings/matchsums.go
+++ b/lib/btrfsprogs/btrfsinspect/rebuildmappings/matchsums.go
@@ -36,10 +36,7 @@ func matchBlockGroupSums(ctx context.Context,
var matches []btrfsvol.QualifiedPhysicalAddr
if err := WalkUnmappedPhysicalRegions(ctx, physicalSums, regions, func(devID btrfsvol.DeviceID, region btrfssum.SumRun[btrfsvol.PhysicalAddr]) error {
- rawMatches, err := IndexAll[int64, btrfssum.ShortSum](region, bgRun)
- if err != nil {
- return err
- }
+ rawMatches := indexAll[int, btrfssum.ShortSum](region, bgRun)
for _, match := range rawMatches {
matches = append(matches, btrfsvol.QualifiedPhysicalAddr{
Dev: devID,
diff --git a/lib/btrfsprogs/btrfsinspect/rebuildmappings/rebuildmappings.go b/lib/btrfsprogs/btrfsinspect/rebuildmappings/rebuildmappings.go
index 665bc96..cdf5e5a 100644
--- a/lib/btrfsprogs/btrfsinspect/rebuildmappings/rebuildmappings.go
+++ b/lib/btrfsprogs/btrfsinspect/rebuildmappings/rebuildmappings.go
@@ -158,10 +158,6 @@ func RebuildMappings(ctx context.Context, fs *btrfs.FS, scanResults btrfsinspect
// The fuzzy-search is only fast because the exact-search is so good at getting `physicalBlocks` down.
// Empirically: if I remove the exact-search step, then the fuzzy-match step is more than an order of magnitude
// slower.
- //
- // The exact-search probably could be optimized to be substantially faster (by a constant factor; not affecting
- // the big-O) by figuring out how to inline function calls and get reduce allocations, but IMO it's "fast
- // enough" for now.
ctx = dlog.WithField(_ctx, "btrfsinspect.rebuild-mappings.step", "5/6")
dlog.Infof(_ctx, "5/6: Searching for %d block groups in checksum map (exact)...", len(bgs))
physicalSums := ExtractPhysicalSums(scanResults)
diff --git a/lib/btrfsprogs/btrfsinspect/rebuildmappings/sumrunwithgaps.go b/lib/btrfsprogs/btrfsinspect/rebuildmappings/sumrunwithgaps.go
index d1064d8..e574540 100644
--- a/lib/btrfsprogs/btrfsinspect/rebuildmappings/sumrunwithgaps.go
+++ b/lib/btrfsprogs/btrfsinspect/rebuildmappings/sumrunwithgaps.go
@@ -30,15 +30,16 @@ var (
_ lowmemjson.Decodable = (*SumRunWithGaps[btrfsvol.LogicalAddr])(nil)
)
-func (sg SumRunWithGaps[Addr]) NumSums() int {
+// PatLen implements kmpPattern[int, ShortSum].
+func (sg SumRunWithGaps[Addr]) PatLen() int {
return int(sg.Size / btrfssum.BlockSize)
}
func (sg SumRunWithGaps[Addr]) PctFull() float64 {
- total := sg.NumSums()
+ total := sg.PatLen()
var full int
for _, run := range sg.Runs {
- full += run.NumSums()
+ full += run.SeqLen()
}
return float64(full) / float64(total)
}
@@ -56,21 +57,21 @@ func (sg SumRunWithGaps[Addr]) RunForAddr(addr Addr) (btrfssum.SumRun[Addr], Add
return btrfssum.SumRun[Addr]{}, math.MaxInt64, false
}
-func (sg SumRunWithGaps[Addr]) SumForAddr(addr Addr) (btrfssum.ShortSum, error) {
+func (sg SumRunWithGaps[Addr]) SumForAddr(addr Addr) (btrfssum.ShortSum, bool) {
if addr < sg.Addr || addr >= sg.Addr.Add(sg.Size) {
- return "", io.EOF
+ return "", false
}
for _, run := range sg.Runs {
if run.Addr > addr {
- return "", ErrWildcard
+ return "", false
}
if run.Addr.Add(run.Size()) <= addr {
continue
}
off := int((addr-run.Addr)/btrfssum.BlockSize) * run.ChecksumSize
- return run.Sums[off : off+run.ChecksumSize], nil
+ return run.Sums[off : off+run.ChecksumSize], true
}
- return "", ErrWildcard
+ return "", false
}
func (sg SumRunWithGaps[Addr]) Walk(ctx context.Context, fn func(Addr, btrfssum.ShortSum) error) error {
@@ -82,8 +83,8 @@ func (sg SumRunWithGaps[Addr]) Walk(ctx context.Context, fn func(Addr, btrfssum.
return nil
}
-// Get implements diskio.Sequence[int, ShortSum]
-func (sg SumRunWithGaps[Addr]) Get(sumIdx int64) (btrfssum.ShortSum, error) {
+// PatGet implements kmpPattern[int, ShortSum].
+func (sg SumRunWithGaps[Addr]) PatGet(sumIdx int) (btrfssum.ShortSum, bool) {
addr := sg.Addr.Add(btrfsvol.AddrDelta(sumIdx) * btrfssum.BlockSize)
return sg.SumForAddr(addr)
}
diff --git a/lib/diskio/seq.go b/lib/diskio/seq.go
index 3c5f4ae..f8e6ea8 100644
--- a/lib/diskio/seq.go
+++ b/lib/diskio/seq.go
@@ -1,68 +1,43 @@
-// Copyright (C) 2022 Luke Shumaker <lukeshu@lukeshu.com>
+// Copyright (C) 2022-2023 Luke Shumaker <lukeshu@lukeshu.com>
//
// SPDX-License-Identifier: GPL-2.0-or-later
package diskio
-import (
- "fmt"
- "io"
-)
-
// interface /////////////////////////////////////////////////////////
-type Sequence[K ~int64, V any] interface {
+type Sequence[K ~int64 | ~int, V any] interface {
+ SeqLen() K
// Get the value at 'pos' in the sequence. Positions start at
- // 0 and increment naturally. Return an error that is io.EOF
- // if 'pos' is past the end of the sequence'.
- Get(pos K) (V, error)
+ // 0 and increment naturally. It is invalid to call
+ // SeqGet(pos) with a pos that is >= SeqLen().
+ SeqGet(pos K) V
}
// implementation: slice /////////////////////////////////////////////
-type SliceSequence[K ~int64, V any] []V
+type SliceSequence[K ~int64 | ~int, V any] []V
+
+var _ Sequence[assertAddr, byte] = SliceSequence[assertAddr, byte](nil)
-var _ Sequence[assertAddr, byte] = SliceSequence[assertAddr, byte]([]byte(nil))
+func (s SliceSequence[K, V]) SeqLen() K {
+ return K(len(s))
+}
-func (s SliceSequence[K, V]) Get(i K) (V, error) {
- if i >= K(len(s)) {
- var v V
- return v, io.EOF
- }
- return s[int(i)], nil
+func (s SliceSequence[K, V]) SeqGet(i K) V {
+ return s[int(i)]
}
// implementation: string ////////////////////////////////////////////
-type StringSequence[K ~int64] string
+type StringSequence[K ~int64 | ~int] string
var _ Sequence[assertAddr, byte] = StringSequence[assertAddr]("")
-func (s StringSequence[K]) Get(i K) (byte, error) {
- if i >= K(len(s)) {
- return 0, io.EOF
- }
- return s[int(i)], nil
+func (s StringSequence[K]) SeqLen() K {
+ return K(len(s))
}
-// implementation: io.ByteReader /////////////////////////////////////
-
-type ByteReaderSequence[K ~int64] struct {
- R io.ByteReader
- pos K
-}
-
-var _ Sequence[assertAddr, byte] = &ByteReaderSequence[assertAddr]{R: nil}
-
-func (s *ByteReaderSequence[K]) Get(i K) (byte, error) {
- if i != s.pos {
- return 0, fmt.Errorf("%T.Get(%v): can only call .Get(%v)",
- s, i, s.pos)
- }
- chr, err := s.R.ReadByte()
- if err != nil {
- return chr, err
- }
- s.pos++
- return chr, nil
+func (s StringSequence[K]) SeqGet(i K) byte {
+ return s[int(i)]
}