From 27401b6ea459921a6152ab1744da1618358465f4 Mon Sep 17 00:00:00 2001 From: Luke Shumaker Date: Sun, 10 Jul 2022 13:18:30 -0600 Subject: Rename the module, mv pkg lib --- lib/binstruct/binint.go | 35 ++++++ lib/binstruct/binint/builtins.go | 241 +++++++++++++++++++++++++++++++++++++++ lib/binstruct/binstruct_test.go | 60 ++++++++++ lib/binstruct/marshal.go | 42 +++++++ lib/binstruct/size.go | 57 +++++++++ lib/binstruct/structs.go | 186 ++++++++++++++++++++++++++++++ lib/binstruct/unmarshal.go | 54 +++++++++ 7 files changed, 675 insertions(+) create mode 100644 lib/binstruct/binint.go create mode 100644 lib/binstruct/binint/builtins.go create mode 100644 lib/binstruct/binstruct_test.go create mode 100644 lib/binstruct/marshal.go create mode 100644 lib/binstruct/size.go create mode 100644 lib/binstruct/structs.go create mode 100644 lib/binstruct/unmarshal.go (limited to 'lib/binstruct') diff --git a/lib/binstruct/binint.go b/lib/binstruct/binint.go new file mode 100644 index 0000000..89bb4f6 --- /dev/null +++ b/lib/binstruct/binint.go @@ -0,0 +1,35 @@ +package binstruct + +import ( + "reflect" + + "git.lukeshu.com/btrfs-progs-ng/lib/binstruct/binint" +) + +type ( + U8 = binint.U8 + U16le = binint.U16le + U32le = binint.U32le + U64le = binint.U64le + U16be = binint.U16be + U32be = binint.U32be + U64be = binint.U64be + I8 = binint.I8 + I16le = binint.I16le + I32le = binint.I32le + I64le = binint.I64le + I16be = binint.I16be + I32be = binint.I32be + I64be = binint.I64be +) + +var intKind2Type = map[reflect.Kind]reflect.Type{ + reflect.Uint8: reflect.TypeOf(U8(0)), + reflect.Int8: reflect.TypeOf(I8(0)), + reflect.Uint16: reflect.TypeOf(U16le(0)), + reflect.Int16: reflect.TypeOf(I16le(0)), + reflect.Uint32: reflect.TypeOf(U32le(0)), + reflect.Int32: reflect.TypeOf(I32le(0)), + reflect.Uint64: reflect.TypeOf(U64le(0)), + reflect.Int64: reflect.TypeOf(I64le(0)), +} diff --git a/lib/binstruct/binint/builtins.go b/lib/binstruct/binint/builtins.go new file mode 100644 index 0000000..04fc477 --- /dev/null +++ b/lib/binstruct/binint/builtins.go @@ -0,0 +1,241 @@ +package binint + +import ( + "encoding/binary" + "fmt" +) + +func needNBytes(t interface{}, dat []byte, n int) error { + if len(dat) < n { + return fmt.Errorf("%T.UnmarshalBinary: need at least %v bytes, only have %v", t, n, len(dat)) + } + return nil +} + +// unsigned + +type U8 uint8 + +func (U8) BinaryStaticSize() int { return 1 } +func (x U8) MarshalBinary() ([]byte, error) { return []byte{byte(x)}, nil } +func (x *U8) UnmarshalBinary(dat []byte) (int, error) { + if err := needNBytes(*x, dat, 1); err != nil { + return 0, err + } + *x = U8(dat[0]) + return 1, nil +} + +// unsigned little endian + +type U16le uint16 + +func (U16le) BinaryStaticSize() int { return 2 } +func (x U16le) MarshalBinary() ([]byte, error) { + var buf [2]byte + binary.LittleEndian.PutUint16(buf[:], uint16(x)) + return buf[:], nil +} +func (x *U16le) UnmarshalBinary(dat []byte) (int, error) { + if err := needNBytes(*x, dat, 2); err != nil { + return 0, err + } + *x = U16le(binary.LittleEndian.Uint16(dat)) + return 2, nil +} + +type U32le uint32 + +func (U32le) BinaryStaticSize() int { return 4 } +func (x U32le) MarshalBinary() ([]byte, error) { + var buf [4]byte + binary.LittleEndian.PutUint32(buf[:], uint32(x)) + return buf[:], nil +} +func (x *U32le) UnmarshalBinary(dat []byte) (int, error) { + if err := needNBytes(*x, dat, 4); err != nil { + return 0, err + } + *x = U32le(binary.LittleEndian.Uint32(dat)) + return 4, nil +} + +type U64le uint64 + +func (U64le) BinaryStaticSize() int { return 8 } +func (x U64le) MarshalBinary() ([]byte, error) { + var buf [8]byte + binary.LittleEndian.PutUint64(buf[:], uint64(x)) + return buf[:], nil +} +func (x *U64le) UnmarshalBinary(dat []byte) (int, error) { + if err := needNBytes(*x, dat, 8); err != nil { + return 0, err + } + *x = U64le(binary.LittleEndian.Uint64(dat)) + return 8, nil +} + +// unsigned big endian + +type U16be uint16 + +func (U16be) BinaryStaticSize() int { return 2 } +func (x U16be) MarshalBinary() ([]byte, error) { + var buf [2]byte + binary.BigEndian.PutUint16(buf[:], uint16(x)) + return buf[:], nil +} +func (x *U16be) UnmarshalBinary(dat []byte) (int, error) { + if err := needNBytes(*x, dat, 2); err != nil { + return 0, err + } + *x = U16be(binary.BigEndian.Uint16(dat)) + return 2, nil +} + +type U32be uint32 + +func (U32be) BinaryStaticSize() int { return 4 } +func (x U32be) MarshalBinary() ([]byte, error) { + var buf [4]byte + binary.BigEndian.PutUint32(buf[:], uint32(x)) + return buf[:], nil +} +func (x *U32be) UnmarshalBinary(dat []byte) (int, error) { + if err := needNBytes(*x, dat, 4); err != nil { + return 0, err + } + *x = U32be(binary.BigEndian.Uint32(dat)) + return 4, nil +} + +type U64be uint64 + +func (U64be) BinaryStaticSize() int { return 8 } +func (x U64be) MarshalBinary() ([]byte, error) { + var buf [8]byte + binary.BigEndian.PutUint64(buf[:], uint64(x)) + return buf[:], nil +} +func (x *U64be) UnmarshalBinary(dat []byte) (int, error) { + if err := needNBytes(*x, dat, 8); err != nil { + return 0, err + } + *x = U64be(binary.BigEndian.Uint64(dat)) + return 8, nil +} + +// signed + +type I8 int8 + +func (I8) BinaryStaticSize() int { return 1 } +func (x I8) MarshalBinary() ([]byte, error) { return []byte{byte(x)}, nil } +func (x *I8) UnmarshalBinary(dat []byte) (int, error) { + if err := needNBytes(*x, dat, 1); err != nil { + return 0, err + } + *x = I8(dat[0]) + return 1, nil +} + +// signed little endian + +type I16le int16 + +func (I16le) BinaryStaticSize() int { return 2 } +func (x I16le) MarshalBinary() ([]byte, error) { + var buf [2]byte + binary.LittleEndian.PutUint16(buf[:], uint16(x)) + return buf[:], nil +} +func (x *I16le) UnmarshalBinary(dat []byte) (int, error) { + if err := needNBytes(*x, dat, 2); err != nil { + return 0, err + } + *x = I16le(binary.LittleEndian.Uint16(dat)) + return 2, nil +} + +type I32le int32 + +func (I32le) BinaryStaticSize() int { return 4 } +func (x I32le) MarshalBinary() ([]byte, error) { + var buf [4]byte + binary.LittleEndian.PutUint32(buf[:], uint32(x)) + return buf[:], nil +} +func (x *I32le) UnmarshalBinary(dat []byte) (int, error) { + if err := needNBytes(*x, dat, 4); err != nil { + return 0, err + } + *x = I32le(binary.LittleEndian.Uint32(dat)) + return 4, nil +} + +type I64le int64 + +func (I64le) BinaryStaticSize() int { return 8 } +func (x I64le) MarshalBinary() ([]byte, error) { + var buf [8]byte + binary.LittleEndian.PutUint64(buf[:], uint64(x)) + return buf[:], nil +} +func (x *I64le) UnmarshalBinary(dat []byte) (int, error) { + if err := needNBytes(*x, dat, 8); err != nil { + return 0, err + } + *x = I64le(binary.LittleEndian.Uint64(dat)) + return 8, nil +} + +// signed big endian + +type I16be int16 + +func (I16be) BinaryStaticSize() int { return 2 } +func (x I16be) MarshalBinary() ([]byte, error) { + var buf [2]byte + binary.BigEndian.PutUint16(buf[:], uint16(x)) + return buf[:], nil +} +func (x *I16be) UnmarshalBinary(dat []byte) (int, error) { + if err := needNBytes(*x, dat, 2); err != nil { + return 0, err + } + *x = I16be(binary.BigEndian.Uint16(dat)) + return 2, nil +} + +type I32be int32 + +func (I32be) BinaryStaticSize() int { return 4 } +func (x I32be) MarshalBinary() ([]byte, error) { + var buf [4]byte + binary.BigEndian.PutUint32(buf[:], uint32(x)) + return buf[:], nil +} +func (x *I32be) UnmarshalBinary(dat []byte) (int, error) { + if err := needNBytes(*x, dat, 4); err != nil { + return 0, err + } + *x = I32be(binary.BigEndian.Uint32(dat)) + return 4, nil +} + +type I64be int64 + +func (I64be) BinaryStaticSize() int { return 8 } +func (x I64be) MarshalBinary() ([]byte, error) { + var buf [8]byte + binary.BigEndian.PutUint64(buf[:], uint64(x)) + return buf[:], nil +} +func (x *I64be) UnmarshalBinary(dat []byte) (int, error) { + if err := needNBytes(*x, dat, 8); err != nil { + return 0, err + } + *x = I64be(binary.BigEndian.Uint64(dat)) + return 8, nil +} diff --git a/lib/binstruct/binstruct_test.go b/lib/binstruct/binstruct_test.go new file mode 100644 index 0000000..542746f --- /dev/null +++ b/lib/binstruct/binstruct_test.go @@ -0,0 +1,60 @@ +package binstruct_test + +import ( + "testing" + + "github.com/stretchr/testify/assert" + + "git.lukeshu.com/btrfs-progs-ng/lib/binstruct" +) + +func TestSmoke(t *testing.T) { + type UUID [16]byte + type PhysicalAddr int64 + type DevItem struct { + DeviceID uint64 `bin:"off=0x0, siz=0x8"` // device id + + NumBytes uint64 `bin:"off=0x8, siz=0x8"` // number of bytes + NumBytesUsed uint64 `bin:"off=0x10, siz=0x8"` // number of bytes used + + IOOptimalAlign uint32 `bin:"off=0x18, siz=0x4"` // optimal I/O align + IOOptimalWidth uint32 `bin:"off=0x1c, siz=0x4"` // optimal I/O width + IOMinSize uint32 `bin:"off=0x20, siz=0x4"` // minimal I/O size (sector size) + + Type uint64 `bin:"off=0x24, siz=0x8"` // type + Generation uint64 `bin:"off=0x2c, siz=0x8"` // generation + StartOffset uint64 `bin:"off=0x34, siz=0x8"` // start offset + DevGroup uint32 `bin:"off=0x3c, siz=0x4"` // dev group + SeekSpeed uint8 `bin:"off=0x40, siz=0x1"` // seek speed + Bandwidth uint8 `bin:"off=0x41, siz=0x1"` // bandwidth + + DevUUID UUID `bin:"off=0x42, siz=0x10"` // device UUID + FSUUID UUID `bin:"off=0x52, siz=0x10"` // FS UUID + + binstruct.End `bin:"off=0x62"` + } + type TestType struct { + Magic [5]byte `bin:"off=0x0,siz=0x5"` + Dev DevItem `bin:"off=0x5,siz=0x62"` + Addr PhysicalAddr `bin:"off=0x67, siz=0x8"` + + binstruct.End `bin:"off=0x6F"` + } + + assert.Equal(t, 0x6F, binstruct.StaticSize(TestType{})) + + input := TestType{} + copy(input.Magic[:], "mAgIc") + input.Dev.DeviceID = 12 + input.Addr = 0xBEEF + + bs, err := binstruct.Marshal(input) + assert.NoError(t, err) + assert.Equal(t, 0x6F, len(bs)) + + var output TestType + n, err := binstruct.Unmarshal(bs, &output) + assert.NoError(t, err) + assert.Equal(t, 0x6F, n) + assert.Equal(t, input, output) +} diff --git a/lib/binstruct/marshal.go b/lib/binstruct/marshal.go new file mode 100644 index 0000000..684d2f3 --- /dev/null +++ b/lib/binstruct/marshal.go @@ -0,0 +1,42 @@ +package binstruct + +import ( + "encoding" + "fmt" + "reflect" +) + +type Marshaler = encoding.BinaryMarshaler + +func Marshal(obj any) ([]byte, error) { + if mar, ok := obj.(Marshaler); ok { + return mar.MarshalBinary() + } + return MarshalWithoutInterface(obj) +} + +func MarshalWithoutInterface(obj any) ([]byte, error) { + val := reflect.ValueOf(obj) + switch val.Kind() { + case reflect.Uint8, reflect.Int8, reflect.Uint16, reflect.Int16, reflect.Uint32, reflect.Int32, reflect.Uint64, reflect.Int64: + typ := intKind2Type[val.Kind()] + return val.Convert(typ).Interface().(Marshaler).MarshalBinary() + case reflect.Ptr: + return Marshal(val.Elem().Interface()) + case reflect.Array: + var ret []byte + for i := 0; i < val.Len(); i++ { + bs, err := Marshal(val.Index(i).Interface()) + ret = append(ret, bs...) + if err != nil { + return ret, err + } + } + return ret, nil + case reflect.Struct: + return getStructHandler(val.Type()).Marshal(val) + default: + panic(fmt.Errorf("type=%v does not implement binfmt.Marshaler and kind=%v is not a supported statically-sized kind", + val.Type(), val.Kind())) + } +} diff --git a/lib/binstruct/size.go b/lib/binstruct/size.go new file mode 100644 index 0000000..6563455 --- /dev/null +++ b/lib/binstruct/size.go @@ -0,0 +1,57 @@ +package binstruct + +import ( + "fmt" + "reflect" +) + +type StaticSizer interface { + BinaryStaticSize() int +} + +func StaticSize(obj any) int { + sz, err := staticSize(reflect.TypeOf(obj)) + if err != nil { + panic(err) + } + return sz +} + +var ( + staticSizerType = reflect.TypeOf((*StaticSizer)(nil)).Elem() + marshalerType = reflect.TypeOf((*Marshaler)(nil)).Elem() + unmarshalerType = reflect.TypeOf((*Unmarshaler)(nil)).Elem() +) + +func staticSize(typ reflect.Type) (int, error) { + if typ.Implements(staticSizerType) { + return reflect.New(typ).Elem().Interface().(StaticSizer).BinaryStaticSize(), nil + } + switch typ.Kind() { + case reflect.Uint8, reflect.Int8: + return 1, nil + case reflect.Uint16, reflect.Int16: + return 2, nil + case reflect.Uint32, reflect.Int32: + return 4, nil + case reflect.Uint64, reflect.Int64: + return 8, nil + case reflect.Ptr: + return staticSize(typ.Elem()) + case reflect.Array: + elemSize, err := staticSize(typ.Elem()) + if err != nil { + return 0, err + } + return elemSize * typ.Len(), nil + case reflect.Struct: + if !(typ.Implements(marshalerType) || typ.Implements(unmarshalerType)) { + return getStructHandler(typ).Size, nil + } + return 0, fmt.Errorf("type=%v (kind=%v) does not implement binfmt.StaticSizer but does implement binfmt.Marshaler or binfmt.Unmarshaler", + typ, typ.Kind()) + default: + return 0, fmt.Errorf("type=%v does not implement binfmt.StaticSizer and kind=%v is not a supported statically-sized kind", + typ, typ.Kind()) + } +} diff --git a/lib/binstruct/structs.go b/lib/binstruct/structs.go new file mode 100644 index 0000000..ec2bb7d --- /dev/null +++ b/lib/binstruct/structs.go @@ -0,0 +1,186 @@ +package binstruct + +import ( + "fmt" + "reflect" + "strconv" + "strings" +) + +type End struct{} + +var endType = reflect.TypeOf(End{}) + +type tag struct { + skip bool + + off int + siz int +} + +func parseStructTag(str string) (tag, error) { + var ret tag + for _, part := range strings.Split(str, ",") { + part = strings.TrimSpace(part) + if part == "" { + continue + } + if part == "-" { + return tag{skip: true}, nil + } + keyval := strings.SplitN(part, "=", 2) + if len(keyval) != 2 { + return tag{}, fmt.Errorf("option is not a key=value pair: %q", part) + } + key := keyval[0] + val := keyval[1] + switch key { + case "off": + vint, err := strconv.ParseInt(val, 0, 0) + if err != nil { + return tag{}, err + } + ret.off = int(vint) + case "siz": + vint, err := strconv.ParseInt(val, 0, 0) + if err != nil { + return tag{}, err + } + ret.siz = int(vint) + default: + return tag{}, fmt.Errorf("unrecognized option %q", key) + } + } + return ret, nil +} + +type structHandler struct { + name string + Size int + fields []structField +} + +type structField struct { + name string + tag +} + +func (sh structHandler) Unmarshal(dat []byte, dst reflect.Value) (int, error) { + var n int + for i, field := range sh.fields { + if field.skip { + continue + } + _n, err := Unmarshal(dat[n:], dst.Field(i).Addr().Interface()) + if err != nil { + if _n >= 0 { + n += _n + } + return n, fmt.Errorf("struct %q field %v %q: %w", + sh.name, i, field.name, err) + } + if _n != field.siz { + return n, fmt.Errorf("struct %q field %v %q: consumed %v bytes but should have consumed %v bytes", + sh.name, i, field.name, _n, field.siz) + } + n += _n + } + return n, nil +} + +func (sh structHandler) Marshal(val reflect.Value) ([]byte, error) { + ret := make([]byte, 0, sh.Size) + for i, field := range sh.fields { + if field.skip { + continue + } + bs, err := Marshal(val.Field(i).Interface()) + ret = append(ret, bs...) + if err != nil { + return ret, fmt.Errorf("struct %q field %v %q: %w", + sh.name, i, field.name, err) + } + } + return ret, nil +} + +func genStructHandler(structInfo reflect.Type) (structHandler, error) { + var ret structHandler + + ret.name = structInfo.String() + + var curOffset, endOffset int + for i := 0; i < structInfo.NumField(); i++ { + var fieldInfo reflect.StructField = structInfo.Field(i) + + if fieldInfo.Anonymous && fieldInfo.Type != endType { + err := fmt.Errorf("binstruct does not support embedded fields") + return ret, fmt.Errorf("struct %q field %v %q: %w", + ret.name, i, fieldInfo.Name, err) + } + + fieldTag, err := parseStructTag(fieldInfo.Tag.Get("bin")) + if err != nil { + return ret, fmt.Errorf("struct %q field %v %q: %w", + ret.name, i, fieldInfo.Name, err) + } + if fieldTag.skip { + ret.fields = append(ret.fields, structField{ + tag: fieldTag, + name: fieldInfo.Name, + }) + continue + } + + if fieldTag.off != curOffset { + err := fmt.Errorf("tag says off=%#x but curOffset=%#x", fieldTag.off, curOffset) + return ret, fmt.Errorf("struct %q field %v %q: %w", + ret.name, i, fieldInfo.Name, err) + } + if fieldInfo.Type == endType { + endOffset = curOffset + } + + fieldSize, err := staticSize(fieldInfo.Type) + if err != nil { + return ret, fmt.Errorf("struct %q field %v %q: %w", + ret.name, i, fieldInfo.Name, err) + } + + if fieldTag.siz != fieldSize { + err := fmt.Errorf("tag says siz=%#x but StaticSize(typ)=%#x", fieldTag.siz, fieldSize) + return ret, fmt.Errorf("struct %q field %v %q: %w", + ret.name, i, fieldInfo.Name, err) + } + curOffset += fieldTag.siz + + ret.fields = append(ret.fields, structField{ + name: fieldInfo.Name, + tag: fieldTag, + }) + } + ret.Size = curOffset + + if ret.Size != endOffset { + return ret, fmt.Errorf("struct %q: .Size=%v but endOffset=%v", + ret.name, ret.Size, endOffset) + } + + return ret, nil +} + +var structCache = make(map[reflect.Type]structHandler) + +func getStructHandler(typ reflect.Type) structHandler { + h, ok := structCache[typ] + if ok { + return h + } + + h, err := genStructHandler(typ) + if err != nil { + panic(err) + } + structCache[typ] = h + return h +} diff --git a/lib/binstruct/unmarshal.go b/lib/binstruct/unmarshal.go new file mode 100644 index 0000000..1959d45 --- /dev/null +++ b/lib/binstruct/unmarshal.go @@ -0,0 +1,54 @@ +package binstruct + +import ( + "fmt" + "reflect" +) + +type Unmarshaler interface { + UnmarshalBinary([]byte) (int, error) +} + +func Unmarshal(dat []byte, dstPtr any) (int, error) { + if unmar, ok := dstPtr.(Unmarshaler); ok { + return unmar.UnmarshalBinary(dat) + } + return UnmarshalWithoutInterface(dat, dstPtr) +} + +func UnmarshalWithoutInterface(dat []byte, dstPtr any) (int, error) { + _dstPtr := reflect.ValueOf(dstPtr) + if _dstPtr.Kind() != reflect.Ptr { + return 0, fmt.Errorf("not a pointer: %v", _dstPtr.Type()) + } + dst := _dstPtr.Elem() + + switch dst.Kind() { + case reflect.Uint8, reflect.Int8, reflect.Uint16, reflect.Int16, reflect.Uint32, reflect.Int32, reflect.Uint64, reflect.Int64: + typ := intKind2Type[dst.Kind()] + newDstPtr := reflect.New(typ) + n, err := Unmarshal(dat, newDstPtr.Interface()) + dst.Set(newDstPtr.Elem().Convert(dst.Type())) + return n, err + case reflect.Ptr: + elemPtr := reflect.New(dst.Type().Elem()) + n, err := Unmarshal(dat, elemPtr.Interface()) + dst.Set(elemPtr.Convert(dst.Type())) + return n, err + case reflect.Array: + var n int + for i := 0; i < dst.Len(); i++ { + _n, err := Unmarshal(dat[n:], dst.Index(i).Addr().Interface()) + n += _n + if err != nil { + return n, err + } + } + return n, nil + case reflect.Struct: + return getStructHandler(dst.Type()).Unmarshal(dat, dst) + default: + panic(fmt.Errorf("type=%v does not implement binfmt.Unmarshaler and kind=%v is not a supported statically-sized kind", + dst.Type(), dst.Kind())) + } +} -- cgit v1.2.3-2-g168b