summaryrefslogtreecommitdiff
path: root/lib/util
diff options
context:
space:
mode:
Diffstat (limited to 'lib/util')
-rw-r--r--lib/util/bitfield.go50
-rw-r--r--lib/util/fmt.go67
-rw-r--r--lib/util/fmt_test.go99
-rw-r--r--lib/util/generic.go122
-rw-r--r--lib/util/int.go3
-rw-r--r--lib/util/lru.go73
-rw-r--r--lib/util/ref.go54
-rw-r--r--lib/util/uuid.go72
-rw-r--r--lib/util/uuid_test.go63
9 files changed, 603 insertions, 0 deletions
diff --git a/lib/util/bitfield.go b/lib/util/bitfield.go
new file mode 100644
index 0000000..23da17a
--- /dev/null
+++ b/lib/util/bitfield.go
@@ -0,0 +1,50 @@
+package util
+
+import (
+ "fmt"
+ "strings"
+)
+
+type BitfieldFormat uint8
+
+const (
+ HexNone = BitfieldFormat(iota)
+ HexLower
+ HexUpper
+)
+
+func BitfieldString[T ~uint8 | ~uint16 | ~uint32 | ~uint64](bitfield T, bitnames []string, cfg BitfieldFormat) string {
+ var out strings.Builder
+ switch cfg {
+ case HexNone:
+ // do nothing
+ case HexLower:
+ fmt.Fprintf(&out, "0x%0x(", uint64(bitfield))
+ case HexUpper:
+ fmt.Fprintf(&out, "0x%0X(", uint64(bitfield))
+ }
+ if bitfield == 0 {
+ out.WriteString("none")
+ } else {
+ rest := bitfield
+ first := true
+ for i := 0; rest != 0; i++ {
+ if rest&(1<<i) != 0 {
+ if !first {
+ out.WriteRune('|')
+ }
+ if i < len(bitnames) {
+ out.WriteString(bitnames[i])
+ } else {
+ fmt.Fprintf(&out, "(1<<%v)", i)
+ }
+ first = false
+ }
+ rest &^= 1 << i
+ }
+ }
+ if cfg != HexNone {
+ out.WriteRune(')')
+ }
+ return out.String()
+}
diff --git a/lib/util/fmt.go b/lib/util/fmt.go
new file mode 100644
index 0000000..af7404c
--- /dev/null
+++ b/lib/util/fmt.go
@@ -0,0 +1,67 @@
+package util
+
+import (
+ "fmt"
+ "strings"
+)
+
+// FmtStateString returns the fmt.Printf string that produced a given
+// fmt.State and verb.
+func FmtStateString(st fmt.State, verb rune) string {
+ var ret strings.Builder
+ ret.WriteByte('%')
+ for _, flag := range []int{'-', '+', '#', ' ', '0'} {
+ if st.Flag(flag) {
+ ret.WriteByte(byte(flag))
+ }
+ }
+ if width, ok := st.Width(); ok {
+ fmt.Fprintf(&ret, "%v", width)
+ }
+ if prec, ok := st.Precision(); ok {
+ if prec == 0 {
+ ret.WriteByte('.')
+ } else {
+ fmt.Fprintf(&ret, ".%v", prec)
+ }
+ }
+ ret.WriteRune(verb)
+ return ret.String()
+}
+
+// FormatByteArrayStringer is function for helping to implement
+// fmt.Formatter for []byte or [n]byte types that have a custom string
+// representation. Use it like:
+//
+// type MyType [16]byte
+//
+// func (val MyType) String() string {
+// …
+// }
+//
+// func (val MyType) Format(f fmt.State, verb rune) {
+// util.FormatByteArrayStringer(val, val[:], f, verb)
+// }
+func FormatByteArrayStringer(
+ obj interface {
+ fmt.Stringer
+ fmt.Formatter
+ },
+ objBytes []byte,
+ f fmt.State, verb rune) {
+ switch verb {
+ case 'v':
+ if !f.Flag('#') {
+ FormatByteArrayStringer(obj, objBytes, f, 's') // as a string
+ } else {
+ byteStr := fmt.Sprintf("%#v", objBytes)
+ objType := fmt.Sprintf("%T", obj)
+ objStr := objType + strings.TrimPrefix(byteStr, "[]byte")
+ fmt.Fprintf(f, FmtStateString(f, 's'), objStr)
+ }
+ case 's', 'q': // string
+ fmt.Fprintf(f, FmtStateString(f, verb), obj.String())
+ default:
+ fmt.Fprintf(f, FmtStateString(f, verb), objBytes)
+ }
+}
diff --git a/lib/util/fmt_test.go b/lib/util/fmt_test.go
new file mode 100644
index 0000000..4251ecf
--- /dev/null
+++ b/lib/util/fmt_test.go
@@ -0,0 +1,99 @@
+package util_test
+
+import (
+ "fmt"
+ "testing"
+
+ "github.com/stretchr/testify/assert"
+
+ "git.lukeshu.com/btrfs-progs-ng/lib/util"
+)
+
+type FmtState struct {
+ MWidth int
+ MPrec int
+ MFlagMinus bool
+ MFlagPlus bool
+ MFlagSharp bool
+ MFlagSpace bool
+ MFlagZero bool
+}
+
+func (st FmtState) Width() (int, bool) {
+ if st.MWidth < 1 {
+ return 0, false
+ }
+ return st.MWidth, true
+}
+
+func (st FmtState) Precision() (int, bool) {
+ if st.MPrec < 1 {
+ return 0, false
+ }
+ return st.MPrec, true
+}
+
+func (st FmtState) Flag(b int) bool {
+ switch b {
+ case '-':
+ return st.MFlagMinus
+ case '+':
+ return st.MFlagPlus
+ case '#':
+ return st.MFlagSharp
+ case ' ':
+ return st.MFlagSpace
+ case '0':
+ return st.MFlagZero
+ }
+ return false
+}
+
+func (st FmtState) Write([]byte) (int, error) {
+ panic("not implemented")
+}
+
+func (dst *FmtState) Format(src fmt.State, verb rune) {
+ if width, ok := src.Width(); ok {
+ dst.MWidth = width
+ }
+ if prec, ok := src.Precision(); ok {
+ dst.MPrec = prec
+ }
+ dst.MFlagMinus = src.Flag('-')
+ dst.MFlagPlus = src.Flag('+')
+ dst.MFlagSharp = src.Flag('#')
+ dst.MFlagSpace = src.Flag(' ')
+ dst.MFlagZero = src.Flag('0')
+}
+
+// letters only? No 'p', 'T', or 'w'.
+const verbs = "abcdefghijklmnoqrstuvxyzABCDEFGHIJKLMNOPQRSUVWXYZ"
+
+func FuzzFmtStateString(f *testing.F) {
+ f.Fuzz(func(t *testing.T,
+ width, prec uint8,
+ flagMinus, flagPlus, flagSharp, flagSpace, flagZero bool,
+ verbIdx uint8,
+ ) {
+ if flagMinus {
+ flagZero = false
+ }
+ input := FmtState{
+ MWidth: int(width),
+ MPrec: int(prec),
+ MFlagMinus: flagMinus,
+ MFlagPlus: flagPlus,
+ MFlagSharp: flagSharp,
+ MFlagSpace: flagSpace,
+ MFlagZero: flagZero,
+ }
+ verb := rune(verbs[int(verbIdx)%len(verbs)])
+
+ t.Logf("(%#v, %c) => %q", input, verb, util.FmtStateString(input, verb))
+
+ var output FmtState
+ assert.Equal(t, "", fmt.Sprintf(util.FmtStateString(input, verb), &output))
+ assert.Equal(t, input, output)
+ })
+}
diff --git a/lib/util/generic.go b/lib/util/generic.go
new file mode 100644
index 0000000..6882724
--- /dev/null
+++ b/lib/util/generic.go
@@ -0,0 +1,122 @@
+package util
+
+import (
+ "sort"
+ "sync"
+
+ "golang.org/x/exp/constraints"
+)
+
+func InSlice[T comparable](needle T, haystack []T) bool {
+ for _, straw := range haystack {
+ if needle == straw {
+ return true
+ }
+ }
+ return false
+}
+
+func RemoveAllFromSlice[T comparable](haystack []T, needle T) []T {
+ for i, straw := range haystack {
+ if needle == straw {
+ return append(
+ haystack[:i],
+ RemoveAllFromSlice(haystack[i+1:], needle)...)
+ }
+ }
+ return haystack
+}
+
+func RemoveAllFromSliceFunc[T any](haystack []T, f func(T) bool) []T {
+ for i, straw := range haystack {
+ if f(straw) {
+ return append(
+ haystack[:i],
+ RemoveAllFromSliceFunc(haystack[i+1:], f)...)
+ }
+ }
+ return haystack
+}
+
+func ReverseSlice[T any](slice []T) {
+ for i := 0; i < len(slice)/2; i++ {
+ j := (len(slice) - 1) - i
+ slice[i], slice[j] = slice[j], slice[i]
+ }
+}
+
+func Max[T constraints.Ordered](a, b T) T {
+ if a > b {
+ return a
+ }
+ return b
+}
+
+func Min[T constraints.Ordered](a, b T) T {
+ if a < b {
+ return a
+ }
+ return b
+}
+
+func MapKeys[K comparable, V any](m map[K]V) []K {
+ ret := make([]K, 0, len(m))
+ for k := range m {
+ ret = append(ret, k)
+ }
+ return ret
+}
+
+func SortSlice[T constraints.Ordered](slice []T) {
+ sort.Slice(slice, func(i, j int) bool {
+ return slice[i] < slice[j]
+ })
+}
+
+func SortedMapKeys[K constraints.Ordered, V any](m map[K]V) []K {
+ ret := MapKeys(m)
+ SortSlice(ret)
+ return ret
+}
+
+func CmpUint[T constraints.Unsigned](a, b T) int {
+ switch {
+ case a < b:
+ return -1
+ case a == b:
+ return 0
+ default:
+ return 1
+ }
+}
+
+type SyncMap[K comparable, V any] struct {
+ inner sync.Map
+}
+
+func (m *SyncMap[K, V]) Delete(key K) { m.inner.Delete(key) }
+func (m *SyncMap[K, V]) Load(key K) (value V, ok bool) {
+ _value, ok := m.inner.Load(key)
+ if ok {
+ value = _value.(V)
+ }
+ return value, ok
+}
+func (m *SyncMap[K, V]) LoadAndDelete(key K) (value V, loaded bool) {
+ _value, ok := m.inner.LoadAndDelete(key)
+ if ok {
+ value = _value.(V)
+ }
+ return value, ok
+}
+func (m *SyncMap[K, V]) LoadOrStore(key K, value V) (actual V, loaded bool) {
+ _actual, loaded := m.inner.LoadOrStore(key, value)
+ actual = _actual.(V)
+ return actual, loaded
+}
+func (m *SyncMap[K, V]) Range(f func(key K, value V) bool) {
+ m.inner.Range(func(key, value any) bool {
+ return f(key.(K), value.(V))
+ })
+}
+func (m *SyncMap[K, V]) Store(key K, value V) { m.inner.Store(key, value) }
diff --git a/lib/util/int.go b/lib/util/int.go
new file mode 100644
index 0000000..fab553d
--- /dev/null
+++ b/lib/util/int.go
@@ -0,0 +1,3 @@
+package util
+
+const MaxUint64pp = 0x1_00000000_00000000
diff --git a/lib/util/lru.go b/lib/util/lru.go
new file mode 100644
index 0000000..2b62e69
--- /dev/null
+++ b/lib/util/lru.go
@@ -0,0 +1,73 @@
+package util
+
+import (
+ "sync"
+
+ lru "github.com/hashicorp/golang-lru"
+)
+
+type LRUCache[K comparable, V any] struct {
+ initOnce sync.Once
+ inner *lru.ARCCache
+}
+
+func (c *LRUCache[K, V]) init() {
+ c.initOnce.Do(func() {
+ c.inner, _ = lru.NewARC(128)
+ })
+}
+
+func (c *LRUCache[K, V]) Add(key K, value V) {
+ c.init()
+ c.inner.Add(key, value)
+}
+func (c *LRUCache[K, V]) Contains(key K) bool {
+ c.init()
+ return c.inner.Contains(key)
+}
+func (c *LRUCache[K, V]) Get(key K) (value V, ok bool) {
+ c.init()
+ _value, ok := c.inner.Get(key)
+ if ok {
+ value = _value.(V)
+ }
+ return value, ok
+}
+func (c *LRUCache[K, V]) Keys() []K {
+ c.init()
+ untyped := c.inner.Keys()
+ typed := make([]K, len(untyped))
+ for i := range untyped {
+ typed[i] = untyped[i].(K)
+ }
+ return typed
+}
+func (c *LRUCache[K, V]) Len() int {
+ c.init()
+ return c.inner.Len()
+}
+func (c *LRUCache[K, V]) Peek(key K) (value V, ok bool) {
+ c.init()
+ _value, ok := c.inner.Peek(key)
+ if ok {
+ value = _value.(V)
+ }
+ return value, ok
+}
+func (c *LRUCache[K, V]) Purge() {
+ c.init()
+ c.inner.Purge()
+}
+func (c *LRUCache[K, V]) Remove(key K) {
+ c.init()
+ c.inner.Remove(key)
+}
+
+func (c *LRUCache[K, V]) GetOrElse(key K, fn func() V) V {
+ var value V
+ var ok bool
+ for value, ok = c.Get(key); !ok; value, ok = c.Get(key) {
+ c.Add(key, fn())
+ }
+ return value
+}
diff --git a/lib/util/ref.go b/lib/util/ref.go
new file mode 100644
index 0000000..1ac48c9
--- /dev/null
+++ b/lib/util/ref.go
@@ -0,0 +1,54 @@
+package util
+
+import (
+ "fmt"
+ "io"
+
+ "git.lukeshu.com/btrfs-progs-ng/lib/binstruct"
+)
+
+type File[A ~int64] interface {
+ Name() string
+ Size() (A, error)
+ ReadAt(p []byte, off A) (n int, err error)
+ WriteAt(p []byte, off A) (n int, err error)
+}
+
+var (
+ _ io.WriterAt = File[int64](nil)
+ _ io.ReaderAt = File[int64](nil)
+)
+
+type Ref[A ~int64, T any] struct {
+ File File[A]
+ Addr A
+ Data T
+}
+
+func (r *Ref[A, T]) Read() error {
+ size := binstruct.StaticSize(r.Data)
+ buf := make([]byte, size)
+ if _, err := r.File.ReadAt(buf, r.Addr); err != nil {
+ return err
+ }
+ n, err := binstruct.Unmarshal(buf, &r.Data)
+ if err != nil {
+ return err
+ }
+ if n != size {
+ return fmt.Errorf("util.Ref[%T].Read: left over data: read %v bytes but only consumed %v",
+ r.Data, size, n)
+ }
+ return nil
+}
+
+func (r *Ref[A, T]) Write() error {
+ buf, err := binstruct.Marshal(r.Data)
+ if err != nil {
+ return err
+ }
+ if _, err = r.File.WriteAt(buf, r.Addr); err != nil {
+ return err
+ }
+ return nil
+}
diff --git a/lib/util/uuid.go b/lib/util/uuid.go
new file mode 100644
index 0000000..3b4cbaf
--- /dev/null
+++ b/lib/util/uuid.go
@@ -0,0 +1,72 @@
+package util
+
+import (
+ "encoding/hex"
+ "fmt"
+ "strings"
+)
+
+type UUID [16]byte
+
+func (uuid UUID) String() string {
+ str := hex.EncodeToString(uuid[:])
+ return strings.Join([]string{
+ str[:8],
+ str[8:12],
+ str[12:16],
+ str[16:20],
+ str[20:32],
+ }, "-")
+}
+
+func (a UUID) Cmp(b UUID) int {
+ for i := range a {
+ if d := int(a[i]) - int(b[i]); d != 0 {
+ return d
+ }
+ }
+ return 0
+}
+
+func (uuid UUID) Format(f fmt.State, verb rune) {
+ FormatByteArrayStringer(uuid, uuid[:], f, verb)
+}
+
+func ParseUUID(str string) (UUID, error) {
+ var ret UUID
+ j := 0
+ for i := 0; i < len(str); i++ {
+ if j >= len(ret)*2 {
+ return UUID{}, fmt.Errorf("too long to be a UUID: %q|%q", str[:i], str[i:])
+ }
+ c := str[i]
+ var v byte
+ switch {
+ case '0' <= c && c <= '9':
+ v = c - '0'
+ case 'a' <= c && c <= 'f':
+ v = c - 'a' + 10
+ case 'A' <= c && c <= 'F':
+ v = c - 'A' + 10
+ case c == '-':
+ continue
+ default:
+ return UUID{}, fmt.Errorf("illegal byte in UUID: %q|%q|%q", str[:i], str[i:i+1], str[i+1:])
+ }
+ if j%2 == 0 {
+ ret[j/2] = v << 4
+ } else {
+ ret[j/2] = (ret[j/2] & 0xf0) | (v & 0x0f)
+ }
+ j++
+ }
+ return ret, nil
+}
+
+func MustParseUUID(str string) UUID {
+ ret, err := ParseUUID(str)
+ if err != nil {
+ panic(err)
+ }
+ return ret
+}
diff --git a/lib/util/uuid_test.go b/lib/util/uuid_test.go
new file mode 100644
index 0000000..7e0e07a
--- /dev/null
+++ b/lib/util/uuid_test.go
@@ -0,0 +1,63 @@
+package util_test
+
+import (
+ "fmt"
+ "testing"
+
+ "github.com/stretchr/testify/assert"
+
+ "git.lukeshu.com/btrfs-progs-ng/lib/util"
+)
+
+func TestParseUUID(t *testing.T) {
+ t.Parallel()
+ type TestCase struct {
+ Input string
+ OutputVal util.UUID
+ OutputErr string
+ }
+ testcases := map[string]TestCase{
+ "basic": {Input: "a0dd94ed-e60c-42e8-8632-64e8d4765a43", OutputVal: util.UUID{0xa0, 0xdd, 0x94, 0xed, 0xe6, 0x0c, 0x42, 0xe8, 0x86, 0x32, 0x64, 0xe8, 0xd4, 0x76, 0x5a, 0x43}},
+ "too-long": {Input: "a0dd94ed-e60c-42e8-8632-64e8d4765a43a", OutputErr: `too long to be a UUID: "a0dd94ed-e60c-42e8-8632-64e8d4765a43"|"a"`},
+ "bad char": {Input: "a0dd94ej-e60c-42e8-8632-64e8d4765a43a", OutputErr: `illegal byte in UUID: "a0dd94e"|"j"|"-e60c-42e8-8632-64e8d4765a43a"`},
+ }
+ for tcName, tc := range testcases {
+ tc := tc
+ t.Run(tcName, func(t *testing.T) {
+ t.Parallel()
+ val, err := util.ParseUUID(tc.Input)
+ assert.Equal(t, tc.OutputVal, val)
+ if tc.OutputErr == "" {
+ assert.NoError(t, err)
+ } else {
+ assert.EqualError(t, err, tc.OutputErr)
+ }
+ })
+ }
+}
+
+func TestUUIDFormat(t *testing.T) {
+ t.Parallel()
+ type TestCase struct {
+ InputUUID util.UUID
+ InputFmt string
+ Output string
+ }
+ uuid := util.MustParseUUID("a0dd94ed-e60c-42e8-8632-64e8d4765a43")
+ testcases := map[string]TestCase{
+ "s": {InputUUID: uuid, InputFmt: "%s", Output: "a0dd94ed-e60c-42e8-8632-64e8d4765a43"},
+ "x": {InputUUID: uuid, InputFmt: "%x", Output: "a0dd94ede60c42e8863264e8d4765a43"},
+ "X": {InputUUID: uuid, InputFmt: "%X", Output: "A0DD94EDE60C42E8863264E8D4765A43"},
+ "v": {InputUUID: uuid, InputFmt: "%v", Output: "a0dd94ed-e60c-42e8-8632-64e8d4765a43"},
+ "40s": {InputUUID: uuid, InputFmt: "|% 40s", Output: "| a0dd94ed-e60c-42e8-8632-64e8d4765a43"},
+ "#115v": {InputUUID: uuid, InputFmt: "|%#115v", Output: "| util.UUID{0xa0, 0xdd, 0x94, 0xed, 0xe6, 0xc, 0x42, 0xe8, 0x86, 0x32, 0x64, 0xe8, 0xd4, 0x76, 0x5a, 0x43}"},
+ }
+ for tcName, tc := range testcases {
+ tc := tc
+ t.Run(tcName, func(t *testing.T) {
+ t.Parallel()
+ actual := fmt.Sprintf(tc.InputFmt, tc.InputUUID)
+ assert.Equal(t, tc.Output, actual)
+ })
+ }
+}