// Copyright (C) 2022-2023 Luke Shumaker // // SPDX-License-Identifier: GPL-2.0-or-later package binstruct import ( "fmt" "reflect" "strconv" "strings" "git.lukeshu.com/btrfs-progs-ng/lib/binstruct/binutil" ) type End struct{} var endType = reflect.TypeOf(End{}) type tag struct { skip bool off int siz int } func parseStructTag(str string) (tag, error) { var ret tag for _, part := range strings.Split(str, ",") { part = strings.TrimSpace(part) if part == "" { continue } if part == "-" { return tag{skip: true}, nil } keyval := strings.SplitN(part, "=", 2) if len(keyval) != 2 { return tag{}, fmt.Errorf("option is not a key=value pair: %q", part) } key := keyval[0] val := keyval[1] switch key { case "off": vint, err := strconv.ParseInt(val, 0, 0) if err != nil { return tag{}, err } ret.off = int(vint) case "siz": vint, err := strconv.ParseInt(val, 0, 0) if err != nil { return tag{}, err } ret.siz = int(vint) default: return tag{}, fmt.Errorf("unrecognized option %q", key) } } return ret, nil } type structHandler struct { name string Size int fields []structField } type structField struct { name string tag } func (sh structHandler) Unmarshal(dat []byte, dst reflect.Value) (int, error) { if err := binutil.NeedNBytes(dat, sh.Size); err != nil { return 0, fmt.Errorf("struct %q %w", sh.name, err) } var n int for i, field := range sh.fields { if field.skip { continue } _n, err := Unmarshal(dat[n:], dst.Field(i).Addr().Interface()) if err != nil { if _n >= 0 { n += _n } return n, fmt.Errorf("struct %q field %v %q: %w", sh.name, i, field.name, err) } if _n != field.siz { return n, fmt.Errorf("struct %q field %v %q: consumed %v bytes but should have consumed %v bytes", sh.name, i, field.name, _n, field.siz) } n += _n } return n, nil } func (sh structHandler) Marshal(val reflect.Value) ([]byte, error) { ret := make([]byte, 0, sh.Size) for i, field := range sh.fields { if field.skip { continue } bs, err := Marshal(val.Field(i).Interface()) ret = append(ret, bs...) if err != nil { return ret, fmt.Errorf("struct %q field %v %q: %w", sh.name, i, field.name, err) } } return ret, nil } func genStructHandler(structInfo reflect.Type) (structHandler, error) { var ret structHandler ret.name = structInfo.String() var curOffset, endOffset int for i := 0; i < structInfo.NumField(); i++ { fieldInfo := structInfo.Field(i) if fieldInfo.Anonymous && fieldInfo.Type != endType { err := fmt.Errorf("binstruct does not support embedded fields") return ret, fmt.Errorf("struct %q field %v %q: %w", ret.name, i, fieldInfo.Name, err) } fieldTag, err := parseStructTag(fieldInfo.Tag.Get("bin")) if err != nil { return ret, fmt.Errorf("struct %q field %v %q: %w", ret.name, i, fieldInfo.Name, err) } if fieldTag.skip { ret.fields = append(ret.fields, structField{ tag: fieldTag, name: fieldInfo.Name, }) continue } if fieldTag.off != curOffset { err := fmt.Errorf("tag says off=%#x but curOffset=%#x", fieldTag.off, curOffset) return ret, fmt.Errorf("struct %q field %v %q: %w", ret.name, i, fieldInfo.Name, err) } if fieldInfo.Type == endType { endOffset = curOffset } fieldSize, err := staticSize(fieldInfo.Type) if err != nil { return ret, fmt.Errorf("struct %q field %v %q: %w", ret.name, i, fieldInfo.Name, err) } if fieldTag.siz != fieldSize { err := fmt.Errorf("tag says siz=%#x but StaticSize(typ)=%#x", fieldTag.siz, fieldSize) return ret, fmt.Errorf("struct %q field %v %q: %w", ret.name, i, fieldInfo.Name, err) } curOffset += fieldTag.siz ret.fields = append(ret.fields, structField{ name: fieldInfo.Name, tag: fieldTag, }) } ret.Size = curOffset if ret.Size != endOffset { return ret, fmt.Errorf("struct %q: .Size=%v but endOffset=%v", ret.name, ret.Size, endOffset) } return ret, nil } var structCache = make(map[reflect.Type]structHandler) func getStructHandler(typ reflect.Type) structHandler { h, ok := structCache[typ] if ok { return h } h, err := genStructHandler(typ) if err != nil { panic(&InvalidTypeError{ Type: typ, Err: err, }) } structCache[typ] = h return h }