From 4f05919a0f2695934df2e67399b507896b52c3bc Mon Sep 17 00:00:00 2001 From: Luke Shumaker Date: Mon, 9 Jan 2023 03:05:50 -0700 Subject: binstruct: Optimize based on the CPU profiler when running scandevices --- lib/binstruct/structs.go | 10 +++--- lib/binstruct/unmarshal.go | 77 +++++++++++++++++++++++++++++++++++----------- 2 files changed, 65 insertions(+), 22 deletions(-) (limited to 'lib/binstruct') diff --git a/lib/binstruct/structs.go b/lib/binstruct/structs.go index 52e5406..91bfec7 100644 --- a/lib/binstruct/structs.go +++ b/lib/binstruct/structs.go @@ -69,7 +69,8 @@ type structHandler struct { } type structField struct { - name string + name string + isUnmarshaler bool tag } @@ -82,7 +83,7 @@ func (sh structHandler) Unmarshal(dat []byte, dst reflect.Value) (int, error) { if field.skip { continue } - _n, err := Unmarshal(dat[n:], dst.Field(i).Addr().Interface()) + _n, err := unmarshal(dat[n:], dst.Field(i), field.isUnmarshaler) if err != nil { if _n >= 0 { n += _n @@ -166,8 +167,9 @@ func genStructHandler(structInfo reflect.Type) (structHandler, error) { curOffset += fieldTag.siz ret.fields = append(ret.fields, structField{ - name: fieldInfo.Name, - tag: fieldTag, + name: fieldInfo.Name, + isUnmarshaler: reflect.PtrTo(fieldInfo.Type).Implements(unmarshalerType), + tag: fieldTag, }) } ret.Size = curOffset diff --git a/lib/binstruct/unmarshal.go b/lib/binstruct/unmarshal.go index 4cb8a59..eae4b84 100644 --- a/lib/binstruct/unmarshal.go +++ b/lib/binstruct/unmarshal.go @@ -32,6 +32,23 @@ func Unmarshal(dat []byte, dstPtr any) (int, error) { return UnmarshalWithoutInterface(dat, dstPtr) } +// unmarshal is like Unmarshal, but for internal use to avoid some +// slow round-tripping between `any` and `reflect.Value`. +func unmarshal(dat []byte, dst reflect.Value, isUnmarshaler bool) (int, error) { + if isUnmarshaler { + n, err := dst.Addr().Interface().(Unmarshaler).UnmarshalBinary(dat) + if err != nil { + err = &UnmarshalError{ + Type: reflect.PtrTo(dst.Type()), + Method: "UnmarshalBinary", + Err: err, + } + } + return n, err + } + return unmarshalWithoutInterface(dat, dst) +} + func UnmarshalWithoutInterface(dat []byte, dstPtr any) (int, error) { _dstPtr := reflect.ValueOf(dstPtr) if _dstPtr.Kind() != reflect.Ptr { @@ -40,46 +57,70 @@ func UnmarshalWithoutInterface(dat []byte, dstPtr any) (int, error) { Err: errors.New("not a pointer"), }) } - dst := _dstPtr.Elem() + return unmarshalWithoutInterface(dat, _dstPtr.Elem()) +} +func unmarshalWithoutInterface(dat []byte, dst reflect.Value) (int, error) { switch dst.Kind() { - case reflect.Uint8, reflect.Int8: + case reflect.Uint8: + if err := binutil.NeedNBytes(dat, sizeof8); err != nil { + return 0, err + } + dst.SetUint(uint64(dat[0])) + return sizeof8, nil + case reflect.Int8: if err := binutil.NeedNBytes(dat, sizeof8); err != nil { return 0, err } - val := reflect.ValueOf(dat[0]) - dst.Set(val.Convert(dst.Type())) + dst.SetInt(int64(dat[0])) return sizeof8, nil - case reflect.Uint16, reflect.Int16: + case reflect.Uint16: if err := binutil.NeedNBytes(dat, sizeof16); err != nil { return 0, err } - val := reflect.ValueOf(binary.LittleEndian.Uint16(dat[:sizeof16])) - dst.Set(val.Convert(dst.Type())) + dst.SetUint(uint64(binary.LittleEndian.Uint16(dat[:sizeof16]))) return sizeof16, nil - case reflect.Uint32, reflect.Int32: + case reflect.Int16: + if err := binutil.NeedNBytes(dat, sizeof16); err != nil { + return 0, err + } + dst.SetInt(int64(binary.LittleEndian.Uint16(dat[:sizeof16]))) + return sizeof16, nil + case reflect.Uint32: + if err := binutil.NeedNBytes(dat, sizeof32); err != nil { + return 0, err + } + dst.SetUint(uint64(binary.LittleEndian.Uint32(dat[:sizeof32]))) + return sizeof32, nil + case reflect.Int32: if err := binutil.NeedNBytes(dat, sizeof32); err != nil { return 0, err } - val := reflect.ValueOf(binary.LittleEndian.Uint32(dat[:sizeof32])) - dst.Set(val.Convert(dst.Type())) + dst.SetInt(int64(binary.LittleEndian.Uint32(dat[:sizeof32]))) return sizeof32, nil - case reflect.Uint64, reflect.Int64: + case reflect.Uint64: if err := binutil.NeedNBytes(dat, sizeof64); err != nil { return 0, err } - val := reflect.ValueOf(binary.LittleEndian.Uint64(dat[:sizeof64])) - dst.Set(val.Convert(dst.Type())) + dst.SetUint(binary.LittleEndian.Uint64(dat[:sizeof64])) + return sizeof64, nil + case reflect.Int64: + if err := binutil.NeedNBytes(dat, sizeof64); err != nil { + return 0, err + } + dst.SetInt(int64(binary.LittleEndian.Uint64(dat[:sizeof64]))) return sizeof64, nil case reflect.Ptr: - elemPtr := reflect.New(dst.Type().Elem()) - n, err := Unmarshal(dat, elemPtr.Interface()) - dst.Set(elemPtr.Convert(dst.Type())) + typ := dst.Type() + elemPtr := reflect.New(typ.Elem()) + n, err := unmarshal(dat, elemPtr.Elem(), typ.Implements(unmarshalerType)) + dst.SetPointer(elemPtr.UnsafePointer()) return n, err case reflect.Array: + isUnmarshaler := dst.Type().Elem().Implements(unmarshalerType) var n int for i := 0; i < dst.Len(); i++ { - _n, err := Unmarshal(dat[n:], dst.Index(i).Addr().Interface()) + _n, err := unmarshal(dat[n:], dst.Index(i), isUnmarshaler) n += _n if err != nil { return n, err @@ -90,7 +131,7 @@ func UnmarshalWithoutInterface(dat []byte, dstPtr any) (int, error) { return getStructHandler(dst.Type()).Unmarshal(dat, dst) default: panic(&InvalidTypeError{ - Type: _dstPtr.Type(), + Type: reflect.PtrTo(dst.Type()), Err: fmt.Errorf("does not implement binfmt.Unmarshaler and kind=%v is not a supported statically-sized kind", dst.Kind()), }) -- cgit v1.2.3-2-g168b