From 8ac23f9c53d4dd6878e2f5814ee7ffd48524f67c Mon Sep 17 00:00:00 2001 From: Luke Shumaker Date: Mon, 1 Aug 2022 20:28:23 -0600 Subject: decodeConfig --- lib/lowmemjson/adapter_test.go | 7 +++--- lib/lowmemjson/decode.go | 48 ++++++++++++++++++++++++++++++------------ 2 files changed, 38 insertions(+), 17 deletions(-) diff --git a/lib/lowmemjson/adapter_test.go b/lib/lowmemjson/adapter_test.go index ffdc2ee..954381a 100644 --- a/lib/lowmemjson/adapter_test.go +++ b/lib/lowmemjson/adapter_test.go @@ -73,6 +73,7 @@ func Unmarshal(data []byte, ptr any) error { type _Decoder struct { src *bufio.Reader + cfg decodeConfig } func NewDecoder(r io.Reader) *_Decoder { @@ -81,8 +82,8 @@ func NewDecoder(r io.Reader) *_Decoder { } } -func (dec *_Decoder) DisallowUnknownFields() {} -func (dec *_Decoder) UseNumber() {} +func (dec *_Decoder) DisallowUnknownFields() { dec.cfg.disallowUnknownFields = true } +func (dec *_Decoder) UseNumber() { dec.cfg.useNumber = true } func (dec *_Decoder) Buffered() io.Reader { dat, _ := dec.src.Peek(dec.src.Buffered()) @@ -90,7 +91,7 @@ func (dec *_Decoder) Buffered() io.Reader { } func (dec *_Decoder) Decode(v any) error { - return Decode(dec.src, v) + return _Decode(dec.src, v, dec.cfg) } func (dec *_Decoder) InputOffset() int64 { return 0 } func (dec *_Decoder) More() bool { diff --git a/lib/lowmemjson/decode.go b/lib/lowmemjson/decode.go index 1604e87..c2e7ab1 100644 --- a/lib/lowmemjson/decode.go +++ b/lib/lowmemjson/decode.go @@ -75,7 +75,16 @@ func expectRune(r io.RuneReader, exp rune) { } } +type decodeConfig struct { + disallowUnknownFields bool + useNumber bool +} + func Decode(r io.RuneScanner, ptr any) (err error) { + return _Decode(r, ptr, decodeConfig{}) +} + +func _Decode(r io.RuneScanner, ptr any, cfg decodeConfig) (err error) { ptrVal := reflect.ValueOf(ptr) if ptrVal.Kind() != reflect.Pointer || ptrVal.IsNil() || !ptrVal.Elem().CanSet() { return &json.InvalidUnmarshalError{ @@ -93,7 +102,7 @@ func Decode(r io.RuneScanner, ptr any) (err error) { } }() decodeWS(r) - decode(r, ptrVal.Elem()) + decode(r, ptrVal.Elem(), cfg) return nil } @@ -124,7 +133,7 @@ var kind2bits = map[reflect.Kind]int{ reflect.Float64: 64, } -func decode(r io.RuneScanner, val reflect.Value) { +func decode(r io.RuneScanner, val reflect.Value, cfg decodeConfig) { typ := val.Type() switch { case val.CanAddr() && reflect.PointerTo(typ) == rawMessagePtrType: @@ -193,14 +202,14 @@ func decode(r io.RuneScanner, val reflect.Value) { case reflect.Interface: if val.IsNil() { if typ == anyType { - val.Set(reflect.ValueOf(decodeAny(r))) + val.Set(reflect.ValueOf(decodeAny(r, cfg))) } else { panic(decodeError{&json.UnsupportedTypeError{ Type: typ, }}) } } else { - decode(r, val.Elem()) + decode(r, val.Elem(), cfg) } case reflect.Struct: index := indexStruct(typ) @@ -209,6 +218,9 @@ func decode(r io.RuneScanner, val reflect.Value) { name := nameBuf.String() idx, ok := index.byName[name] if !ok { + if cfg.disallowUnknownFields { + panic(decodeError{fmt.Errorf("unknown field %q", name)}) + } scan(r, io.Discard) return } @@ -239,12 +251,12 @@ func decode(r io.RuneScanner, val reflect.Value) { } case '"': decodeString(r, &buf) - decode(&buf, fVal) + decode(&buf, fVal, cfg) default: panic(decodeError{fmt.Errorf("invalid character %q for ,string struct value", peekRune(r))}) } } else { - decode(r, fVal) + decode(r, fVal, cfg) } }) case reflect.Map: @@ -288,7 +300,7 @@ func decode(r io.RuneScanner, val reflect.Value) { } fValPtr := reflect.New(typ.Elem()) - decode(r, fValPtr.Elem()) + decode(r, fValPtr.Elem(), cfg) val.SetMapIndex(nameValPtr.Elem(), fValPtr.Elem()) }) @@ -313,7 +325,7 @@ func decode(r io.RuneScanner, val reflect.Value) { } decodeArray(r, func() { mValPtr := reflect.New(typ.Elem()) - decode(r, mValPtr.Elem()) + decode(r, mValPtr.Elem(), cfg) val.Set(reflect.Append(val, mValPtr.Elem())) }) default: @@ -324,7 +336,7 @@ func decode(r io.RuneScanner, val reflect.Value) { i := 0 decodeArray(r, func() { mValPtr := reflect.New(typ.Elem()) - decode(r, mValPtr.Elem()) + decode(r, mValPtr.Elem(), cfg) val.Index(i).Set(mValPtr.Elem()) i++ }) @@ -332,7 +344,7 @@ func decode(r io.RuneScanner, val reflect.Value) { if val.IsNil() { val.Set(reflect.New(typ.Elem())) } - decode(r, val.Elem()) + decode(r, val.Elem(), cfg) default: panic(decodeError{&json.UnsupportedTypeError{ Type: typ, @@ -401,20 +413,20 @@ func scanNumber(r io.RuneScanner, out io.Writer) { } } -func decodeAny(r io.RuneScanner) any { +func decodeAny(r io.RuneScanner, cfg decodeConfig) any { c := peekRune(r) switch c { case '{': ret := make(map[string]any) var nameBuf strings.Builder decodeObject(r, &nameBuf, func() { - ret[nameBuf.String()] = decodeAny(r) + ret[nameBuf.String()] = decodeAny(r, cfg) }) return ret case '[': ret := []any{} decodeArray(r, func() { - ret = append(ret, decodeAny(r)) + ret = append(ret, decodeAny(r, cfg)) }) return ret case '"': @@ -424,7 +436,15 @@ func decodeAny(r io.RuneScanner) any { case '-', '0', '1', '2', '3', '4', '5', '6', '7', '8', '9': var buf strings.Builder scanNumber(r, &buf) - return json.Number(buf.String()) + num := json.Number(buf.String()) + if cfg.useNumber { + return num + } + f64, err := num.Float64() + if err != nil { + panic(decodeError{err}) + } + return f64 case 't', 'f': return decodeBool(r) case 'n': -- cgit v1.2.3-2-g168b