diff options
author | Luke Shumaker <lukeshu@lukeshu.com> | 2023-02-03 19:50:35 -0700 |
---|---|---|
committer | Luke Shumaker <lukeshu@lukeshu.com> | 2023-02-04 20:53:23 -0700 |
commit | ef60daef395b20b67042c011f5b2a1131049e275 (patch) | |
tree | c70aa1661272e10883bbc57373cf00ab980ef336 | |
parent | 77f3c0d7cd21274d00984b72dfce05394d11bdd0 (diff) |
rebuildmappings: Optimize the KMP search
-rw-r--r-- | lib/btrfs/btrfssum/sumrun.go | 20 | ||||
-rw-r--r-- | lib/btrfsprogs/btrfsinspect/rebuildmappings/fuzzymatchsums.go | 2 | ||||
-rw-r--r-- | lib/btrfsprogs/btrfsinspect/rebuildmappings/kmp.go | 106 | ||||
-rw-r--r-- | lib/btrfsprogs/btrfsinspect/rebuildmappings/kmp_test.go | 52 | ||||
-rw-r--r-- | lib/btrfsprogs/btrfsinspect/rebuildmappings/matchsums.go | 5 | ||||
-rw-r--r-- | lib/btrfsprogs/btrfsinspect/rebuildmappings/rebuildmappings.go | 4 | ||||
-rw-r--r-- | lib/btrfsprogs/btrfsinspect/rebuildmappings/sumrunwithgaps.go | 21 | ||||
-rw-r--r-- | lib/diskio/seq.go | 63 |
8 files changed, 103 insertions, 170 deletions
diff --git a/lib/btrfs/btrfssum/sumrun.go b/lib/btrfs/btrfssum/sumrun.go index 1000e7a..bc2db3f 100644 --- a/lib/btrfs/btrfssum/sumrun.go +++ b/lib/btrfs/btrfssum/sumrun.go @@ -6,9 +6,9 @@ package btrfssum import ( "context" - "io" "git.lukeshu.com/btrfs-progs-ng/lib/btrfs/btrfsvol" + "git.lukeshu.com/btrfs-progs-ng/lib/diskio" ) type SumRun[Addr btrfsvol.IntAddr[Addr]] struct { @@ -20,21 +20,21 @@ type SumRun[Addr btrfsvol.IntAddr[Addr]] struct { Sums ShortSum } -func (run SumRun[Addr]) NumSums() int { +var _ diskio.Sequence[int, ShortSum] = SumRun[btrfsvol.LogicalAddr]{} + +// SeqLen implements diskio.Sequence[int, ShortSum]. +func (run SumRun[Addr]) SeqLen() int { return len(run.Sums) / run.ChecksumSize } func (run SumRun[Addr]) Size() btrfsvol.AddrDelta { - return btrfsvol.AddrDelta(run.NumSums()) * BlockSize + return btrfsvol.AddrDelta(run.SeqLen()) * BlockSize } -// Get implements diskio.Sequence[int, ShortSum] -func (run SumRun[Addr]) Get(sumIdx int64) (ShortSum, error) { - if sumIdx < 0 || int(sumIdx) >= run.NumSums() { - return "", io.EOF - } - off := int(sumIdx) * run.ChecksumSize - return run.Sums[off : off+run.ChecksumSize], nil +// SeqGet implements diskio.Sequence[int, ShortSum]. +func (run SumRun[Addr]) SeqGet(sumIdx int) ShortSum { + off := sumIdx * run.ChecksumSize + return run.Sums[off : off+run.ChecksumSize] } func (run SumRun[Addr]) SumForAddr(addr Addr) (ShortSum, bool) { diff --git a/lib/btrfsprogs/btrfsinspect/rebuildmappings/fuzzymatchsums.go b/lib/btrfsprogs/btrfsinspect/rebuildmappings/fuzzymatchsums.go index b51526b..6b75d84 100644 --- a/lib/btrfsprogs/btrfsinspect/rebuildmappings/fuzzymatchsums.go +++ b/lib/btrfsprogs/btrfsinspect/rebuildmappings/fuzzymatchsums.go @@ -69,7 +69,7 @@ func fuzzyMatchBlockGroupSums(ctx context.Context, blockgroup := blockgroups[bgLAddr] bgRun := SumsForLogicalRegion(logicalSums, blockgroup.LAddr, blockgroup.Size) - d := bgRun.NumSums() + d := bgRun.PatLen() matches := make(map[btrfsvol.QualifiedPhysicalAddr]int) if err := bgRun.Walk(ctx, func(laddr btrfsvol.LogicalAddr, sum btrfssum.ShortSum) error { // O(n*… off := laddr.Sub(bgRun.Addr) diff --git a/lib/btrfsprogs/btrfsinspect/rebuildmappings/kmp.go b/lib/btrfsprogs/btrfsinspect/rebuildmappings/kmp.go index eeaab0c..20772ba 100644 --- a/lib/btrfsprogs/btrfsinspect/rebuildmappings/kmp.go +++ b/lib/btrfsprogs/btrfsinspect/rebuildmappings/kmp.go @@ -6,48 +6,24 @@ package rebuildmappings import ( "errors" - "io" "git.lukeshu.com/btrfs-progs-ng/lib/diskio" ) -var ErrWildcard = errors.New("wildcard") - -func kmpEq2[K ~int64, V comparable](aS diskio.Sequence[K, V], aI K, bS diskio.Sequence[K, V], bI K) bool { - aV, aErr := aS.Get(aI) - bV, bErr := bS.Get(bI) - if aErr != nil { - //nolint:errorlint // The == is just a fast-path; we still fall back to errors.Is. - if aErr == ErrWildcard || errors.Is(aErr, ErrWildcard) { - aV = bV - aErr = nil - } else { - panic(aErr) - } - } - if bErr != nil { - //nolint:errorlint // The == is just a fast-path; we still fall back to errors.Is. - if bErr == ErrWildcard || errors.Is(bErr, ErrWildcard) { - bV = aV - bErr = nil - } else { - panic(bErr) - } - } - if aErr != nil || bErr != nil { - return false - } - return aV == bV +type kmpPattern[K ~int64 | ~int, V comparable] interface { + PatLen() K + // Get the value at 'pos' in the sequence. Positions start at + // 0 and increment naturally. It is invalid to call Get(pos) + // with a pos that is >= Len(). If there is a gap/wildcard at + // pos, ok is false. + PatGet(pos K) (v V, ok bool) } -func kmpEq1[K ~int64, V comparable](aV V, bS diskio.Sequence[K, V], bI K) bool { - bV, bErr := bS.Get(bI) - if bErr != nil { - //nolint:errorlint // The == is just a fast-path; we still fall back to errors.Is. - if bErr == ErrWildcard || errors.Is(bErr, ErrWildcard) { - return true - } - panic(bErr) +func kmpSelfEq[K ~int64 | ~int, V comparable](s kmpPattern[K, V], aI K, bI K) bool { + aV, aOK := s.PatGet(aI) + bV, bOK := s.PatGet(bI) + if !aOK || !bOK { + return true } return aV == bV } @@ -55,18 +31,8 @@ func kmpEq1[K ~int64, V comparable](aV V, bS diskio.Sequence[K, V], bI K) bool { // buildKMPTable takes the string 'substr', and returns a table such // that 'table[matchLen-1]' is the largest value 'val' for which 'val < matchLen' and // 'substr[:val] == substr[matchLen-val:matchLen]'. -func buildKMPTable[K ~int64, V comparable](substr diskio.Sequence[K, V]) ([]K, error) { - var substrLen K - for { - //nolint:errorlint // The == is just a fast-path; we still fall back to errors.Is. - if _, err := substr.Get(substrLen); err != nil && !(err == ErrWildcard || errors.Is(err, ErrWildcard)) { - if errors.Is(err, io.EOF) { - break - } - return nil, err - } - substrLen++ - } +func buildKMPTable[K ~int64 | ~int, V comparable](substr kmpPattern[K, V]) []K { + substrLen := substr.PatLen() table := make([]K, substrLen) for j := K(0); j < substrLen; j++ { @@ -77,26 +43,31 @@ func buildKMPTable[K ~int64, V comparable](substr diskio.Sequence[K, V]) ([]K, e } val := table[j-1] // not a match; go back - for val > 0 && !kmpEq2(substr, j, substr, val) { + for val > 0 && !kmpSelfEq(substr, j, val) { val = table[val-1] } // is a match; go forward - if kmpEq2(substr, val, substr, j) { + if kmpSelfEq(substr, val, j) { val++ } table[j] = val } - return table, nil + return table } -// IndexAll returns the starting-position of all possibly-overlapping +func kmpEq[K ~int64 | ~int, V comparable](aV V, bS kmpPattern[K, V], bI K) bool { + bV, ok := bS.PatGet(bI) + if !ok { + return true + } + return aV == bV +} + +// indexAll returns the starting-position of all possibly-overlapping // occurrences of 'substr' in the 'str' sequence. // // Will hop around in 'substr', but will only get the natural sequence -// [0...) in order from 'str'. When hopping around in 'substr' it -// assumes that once it has gotten a given index without error, it can -// continue to do so without error; errors appearing later will cause -// panics. +// [0...) in order from 'str'. // // Will panic if the length of 'substr' is 0. // @@ -104,11 +75,8 @@ func buildKMPTable[K ~int64, V comparable](substr diskio.Sequence[K, V]) ([]K, e // ErrWildcard for a position. // // Uses the Knuth-Morris-Pratt algorithm. -func IndexAll[K ~int64, V comparable](str, substr diskio.Sequence[K, V]) ([]K, error) { - table, err := buildKMPTable(substr) - if err != nil { - return nil, err - } +func indexAll[K ~int64 | ~int, V comparable](str diskio.Sequence[K, V], substr kmpPattern[K, V]) []K { + table := buildKMPTable(substr) substrLen := K(len(table)) if substrLen == 0 { panic(errors.New("rebuildmappings.IndexAll: empty substring")) @@ -118,22 +86,17 @@ func IndexAll[K ~int64, V comparable](str, substr diskio.Sequence[K, V]) ([]K, e var curMatchBeg K var curMatchLen K - for pos := K(0); ; pos++ { - chr, err := str.Get(pos) - if err != nil { - if errors.Is(err, io.EOF) { - err = nil - } - return matches, err - } + strLen := str.SeqLen() + for pos := K(0); pos < strLen; pos++ { + chr := str.SeqGet(pos) // Consider 'chr' - for curMatchLen > 0 && !kmpEq1(chr, substr, curMatchLen) { // shorten the match + for curMatchLen > 0 && !kmpEq(chr, substr, curMatchLen) { // shorten the match overlap := table[curMatchLen-1] curMatchBeg += curMatchLen - overlap curMatchLen = overlap } - if kmpEq1(chr, substr, curMatchLen) { // lengthen the match + if kmpEq(chr, substr, curMatchLen) { // lengthen the match if curMatchLen == 0 { curMatchBeg = pos } @@ -146,4 +109,5 @@ func IndexAll[K ~int64, V comparable](str, substr diskio.Sequence[K, V]) ([]K, e } } } + return matches } diff --git a/lib/btrfsprogs/btrfsinspect/rebuildmappings/kmp_test.go b/lib/btrfsprogs/btrfsinspect/rebuildmappings/kmp_test.go index 910452a..acec9b8 100644 --- a/lib/btrfsprogs/btrfsinspect/rebuildmappings/kmp_test.go +++ b/lib/btrfsprogs/btrfsinspect/rebuildmappings/kmp_test.go @@ -6,7 +6,6 @@ package rebuildmappings import ( "bytes" - "io" "testing" "github.com/stretchr/testify/assert" @@ -15,11 +14,28 @@ import ( "git.lukeshu.com/btrfs-progs-ng/lib/diskio" ) +type bytePattern[K ~int64 | ~int] []byte + +var _ kmpPattern[int, byte] = bytePattern[int]{} + +// PatLen implements kmpPattern. +func (s bytePattern[K]) PatLen() K { + return K(len(s)) +} + +// PatGet implements kmpPattern. +func (s bytePattern[K]) PatGet(i K) (byte, bool) { + chr := s[int(i)] + if chr == '.' { + return 0, false + } + return chr, true +} + func TestBuildKMPTable(t *testing.T) { t.Parallel() - substr := diskio.SliceSequence[int64, byte]([]byte("ababaa")) - table, err := buildKMPTable[int64, byte](substr) - require.NoError(t, err) + substr := bytePattern[int64]([]byte("ababaa")) + table := buildKMPTable[int64, byte](substr) require.Equal(t, []int64{0, 0, 1, 2, 3, 1}, table) @@ -33,8 +49,7 @@ func TestBuildKMPTable(t *testing.T) { func FuzzBuildKMPTable(f *testing.F) { f.Add([]byte("ababaa")) f.Fuzz(func(t *testing.T, substr []byte) { - table, err := buildKMPTable[int64, byte](diskio.SliceSequence[int64, byte](substr)) - require.NoError(t, err) + table := buildKMPTable[int64, byte](bytePattern[int64](substr)) require.Equal(t, len(substr), len(table), "length") for j, val := range table { matchLen := j + 1 @@ -62,27 +77,13 @@ func FuzzIndexAll(f *testing.F) { t.Logf("str =%q", str) t.Logf("substr=%q", substr) exp := NaiveIndexAll(str, substr) - act, err := IndexAll[int64, byte]( - &diskio.ByteReaderSequence[int64]{R: bytes.NewReader(str)}, - diskio.SliceSequence[int64, byte](substr)) - assert.NoError(t, err) + act := indexAll[int64, byte]( + diskio.SliceSequence[int64, byte](str), + bytePattern[int64](substr)) assert.Equal(t, exp, act) }) } -type RESeq string - -func (re RESeq) Get(i int64) (byte, error) { - if i < 0 || i >= int64(len(re)) { - return 0, io.EOF - } - chr := re[int(i)] - if chr == '.' { - return 0, ErrWildcard - } - return chr, nil -} - func TestKMPWildcard(t *testing.T) { t.Parallel() type testcase struct { @@ -116,10 +117,9 @@ func TestKMPWildcard(t *testing.T) { tc := tc t.Run(tcName, func(t *testing.T) { t.Parallel() - matches, err := IndexAll[int64, byte]( + matches := indexAll[int64, byte]( diskio.StringSequence[int64](tc.InStr), - RESeq(tc.InSubstr)) - assert.NoError(t, err) + bytePattern[int64](tc.InSubstr)) assert.Equal(t, tc.ExpMatches, matches) }) } diff --git a/lib/btrfsprogs/btrfsinspect/rebuildmappings/matchsums.go b/lib/btrfsprogs/btrfsinspect/rebuildmappings/matchsums.go index eda37bd..a3e724e 100644 --- a/lib/btrfsprogs/btrfsinspect/rebuildmappings/matchsums.go +++ b/lib/btrfsprogs/btrfsinspect/rebuildmappings/matchsums.go @@ -36,10 +36,7 @@ func matchBlockGroupSums(ctx context.Context, var matches []btrfsvol.QualifiedPhysicalAddr if err := WalkUnmappedPhysicalRegions(ctx, physicalSums, regions, func(devID btrfsvol.DeviceID, region btrfssum.SumRun[btrfsvol.PhysicalAddr]) error { - rawMatches, err := IndexAll[int64, btrfssum.ShortSum](region, bgRun) - if err != nil { - return err - } + rawMatches := indexAll[int, btrfssum.ShortSum](region, bgRun) for _, match := range rawMatches { matches = append(matches, btrfsvol.QualifiedPhysicalAddr{ Dev: devID, diff --git a/lib/btrfsprogs/btrfsinspect/rebuildmappings/rebuildmappings.go b/lib/btrfsprogs/btrfsinspect/rebuildmappings/rebuildmappings.go index 665bc96..cdf5e5a 100644 --- a/lib/btrfsprogs/btrfsinspect/rebuildmappings/rebuildmappings.go +++ b/lib/btrfsprogs/btrfsinspect/rebuildmappings/rebuildmappings.go @@ -158,10 +158,6 @@ func RebuildMappings(ctx context.Context, fs *btrfs.FS, scanResults btrfsinspect // The fuzzy-search is only fast because the exact-search is so good at getting `physicalBlocks` down. // Empirically: if I remove the exact-search step, then the fuzzy-match step is more than an order of magnitude // slower. - // - // The exact-search probably could be optimized to be substantially faster (by a constant factor; not affecting - // the big-O) by figuring out how to inline function calls and get reduce allocations, but IMO it's "fast - // enough" for now. ctx = dlog.WithField(_ctx, "btrfsinspect.rebuild-mappings.step", "5/6") dlog.Infof(_ctx, "5/6: Searching for %d block groups in checksum map (exact)...", len(bgs)) physicalSums := ExtractPhysicalSums(scanResults) diff --git a/lib/btrfsprogs/btrfsinspect/rebuildmappings/sumrunwithgaps.go b/lib/btrfsprogs/btrfsinspect/rebuildmappings/sumrunwithgaps.go index d1064d8..e574540 100644 --- a/lib/btrfsprogs/btrfsinspect/rebuildmappings/sumrunwithgaps.go +++ b/lib/btrfsprogs/btrfsinspect/rebuildmappings/sumrunwithgaps.go @@ -30,15 +30,16 @@ var ( _ lowmemjson.Decodable = (*SumRunWithGaps[btrfsvol.LogicalAddr])(nil) ) -func (sg SumRunWithGaps[Addr]) NumSums() int { +// PatLen implements kmpPattern[int, ShortSum]. +func (sg SumRunWithGaps[Addr]) PatLen() int { return int(sg.Size / btrfssum.BlockSize) } func (sg SumRunWithGaps[Addr]) PctFull() float64 { - total := sg.NumSums() + total := sg.PatLen() var full int for _, run := range sg.Runs { - full += run.NumSums() + full += run.SeqLen() } return float64(full) / float64(total) } @@ -56,21 +57,21 @@ func (sg SumRunWithGaps[Addr]) RunForAddr(addr Addr) (btrfssum.SumRun[Addr], Add return btrfssum.SumRun[Addr]{}, math.MaxInt64, false } -func (sg SumRunWithGaps[Addr]) SumForAddr(addr Addr) (btrfssum.ShortSum, error) { +func (sg SumRunWithGaps[Addr]) SumForAddr(addr Addr) (btrfssum.ShortSum, bool) { if addr < sg.Addr || addr >= sg.Addr.Add(sg.Size) { - return "", io.EOF + return "", false } for _, run := range sg.Runs { if run.Addr > addr { - return "", ErrWildcard + return "", false } if run.Addr.Add(run.Size()) <= addr { continue } off := int((addr-run.Addr)/btrfssum.BlockSize) * run.ChecksumSize - return run.Sums[off : off+run.ChecksumSize], nil + return run.Sums[off : off+run.ChecksumSize], true } - return "", ErrWildcard + return "", false } func (sg SumRunWithGaps[Addr]) Walk(ctx context.Context, fn func(Addr, btrfssum.ShortSum) error) error { @@ -82,8 +83,8 @@ func (sg SumRunWithGaps[Addr]) Walk(ctx context.Context, fn func(Addr, btrfssum. return nil } -// Get implements diskio.Sequence[int, ShortSum] -func (sg SumRunWithGaps[Addr]) Get(sumIdx int64) (btrfssum.ShortSum, error) { +// PatGet implements kmpPattern[int, ShortSum]. +func (sg SumRunWithGaps[Addr]) PatGet(sumIdx int) (btrfssum.ShortSum, bool) { addr := sg.Addr.Add(btrfsvol.AddrDelta(sumIdx) * btrfssum.BlockSize) return sg.SumForAddr(addr) } diff --git a/lib/diskio/seq.go b/lib/diskio/seq.go index 3c5f4ae..f8e6ea8 100644 --- a/lib/diskio/seq.go +++ b/lib/diskio/seq.go @@ -1,68 +1,43 @@ -// Copyright (C) 2022 Luke Shumaker <lukeshu@lukeshu.com> +// Copyright (C) 2022-2023 Luke Shumaker <lukeshu@lukeshu.com> // // SPDX-License-Identifier: GPL-2.0-or-later package diskio -import ( - "fmt" - "io" -) - // interface ///////////////////////////////////////////////////////// -type Sequence[K ~int64, V any] interface { +type Sequence[K ~int64 | ~int, V any] interface { + SeqLen() K // Get the value at 'pos' in the sequence. Positions start at - // 0 and increment naturally. Return an error that is io.EOF - // if 'pos' is past the end of the sequence'. - Get(pos K) (V, error) + // 0 and increment naturally. It is invalid to call + // SeqGet(pos) with a pos that is >= SeqLen(). + SeqGet(pos K) V } // implementation: slice ///////////////////////////////////////////// -type SliceSequence[K ~int64, V any] []V +type SliceSequence[K ~int64 | ~int, V any] []V + +var _ Sequence[assertAddr, byte] = SliceSequence[assertAddr, byte](nil) -var _ Sequence[assertAddr, byte] = SliceSequence[assertAddr, byte]([]byte(nil)) +func (s SliceSequence[K, V]) SeqLen() K { + return K(len(s)) +} -func (s SliceSequence[K, V]) Get(i K) (V, error) { - if i >= K(len(s)) { - var v V - return v, io.EOF - } - return s[int(i)], nil +func (s SliceSequence[K, V]) SeqGet(i K) V { + return s[int(i)] } // implementation: string //////////////////////////////////////////// -type StringSequence[K ~int64] string +type StringSequence[K ~int64 | ~int] string var _ Sequence[assertAddr, byte] = StringSequence[assertAddr]("") -func (s StringSequence[K]) Get(i K) (byte, error) { - if i >= K(len(s)) { - return 0, io.EOF - } - return s[int(i)], nil +func (s StringSequence[K]) SeqLen() K { + return K(len(s)) } -// implementation: io.ByteReader ///////////////////////////////////// - -type ByteReaderSequence[K ~int64] struct { - R io.ByteReader - pos K -} - -var _ Sequence[assertAddr, byte] = &ByteReaderSequence[assertAddr]{R: nil} - -func (s *ByteReaderSequence[K]) Get(i K) (byte, error) { - if i != s.pos { - return 0, fmt.Errorf("%T.Get(%v): can only call .Get(%v)", - s, i, s.pos) - } - chr, err := s.R.ReadByte() - if err != nil { - return chr, err - } - s.pos++ - return chr, nil +func (s StringSequence[K]) SeqGet(i K) byte { + return s[int(i)] } |