summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorLuke Shumaker <lukeshu@lukeshu.com>2022-08-01 20:28:23 -0600
committerLuke Shumaker <lukeshu@lukeshu.com>2022-08-01 20:30:34 -0600
commit8ac23f9c53d4dd6878e2f5814ee7ffd48524f67c (patch)
tree15cff0afa275b616d574313a29130b558b35267f
parent90c2f468869b0ef7e14955f63dca9c8dd8723b89 (diff)
decodeConfig
-rw-r--r--lib/lowmemjson/adapter_test.go7
-rw-r--r--lib/lowmemjson/decode.go48
2 files changed, 38 insertions, 17 deletions
diff --git a/lib/lowmemjson/adapter_test.go b/lib/lowmemjson/adapter_test.go
index ffdc2ee..954381a 100644
--- a/lib/lowmemjson/adapter_test.go
+++ b/lib/lowmemjson/adapter_test.go
@@ -73,6 +73,7 @@ func Unmarshal(data []byte, ptr any) error {
type _Decoder struct {
src *bufio.Reader
+ cfg decodeConfig
}
func NewDecoder(r io.Reader) *_Decoder {
@@ -81,8 +82,8 @@ func NewDecoder(r io.Reader) *_Decoder {
}
}
-func (dec *_Decoder) DisallowUnknownFields() {}
-func (dec *_Decoder) UseNumber() {}
+func (dec *_Decoder) DisallowUnknownFields() { dec.cfg.disallowUnknownFields = true }
+func (dec *_Decoder) UseNumber() { dec.cfg.useNumber = true }
func (dec *_Decoder) Buffered() io.Reader {
dat, _ := dec.src.Peek(dec.src.Buffered())
@@ -90,7 +91,7 @@ func (dec *_Decoder) Buffered() io.Reader {
}
func (dec *_Decoder) Decode(v any) error {
- return Decode(dec.src, v)
+ return _Decode(dec.src, v, dec.cfg)
}
func (dec *_Decoder) InputOffset() int64 { return 0 }
func (dec *_Decoder) More() bool {
diff --git a/lib/lowmemjson/decode.go b/lib/lowmemjson/decode.go
index 1604e87..c2e7ab1 100644
--- a/lib/lowmemjson/decode.go
+++ b/lib/lowmemjson/decode.go
@@ -75,7 +75,16 @@ func expectRune(r io.RuneReader, exp rune) {
}
}
+type decodeConfig struct {
+ disallowUnknownFields bool
+ useNumber bool
+}
+
func Decode(r io.RuneScanner, ptr any) (err error) {
+ return _Decode(r, ptr, decodeConfig{})
+}
+
+func _Decode(r io.RuneScanner, ptr any, cfg decodeConfig) (err error) {
ptrVal := reflect.ValueOf(ptr)
if ptrVal.Kind() != reflect.Pointer || ptrVal.IsNil() || !ptrVal.Elem().CanSet() {
return &json.InvalidUnmarshalError{
@@ -93,7 +102,7 @@ func Decode(r io.RuneScanner, ptr any) (err error) {
}
}()
decodeWS(r)
- decode(r, ptrVal.Elem())
+ decode(r, ptrVal.Elem(), cfg)
return nil
}
@@ -124,7 +133,7 @@ var kind2bits = map[reflect.Kind]int{
reflect.Float64: 64,
}
-func decode(r io.RuneScanner, val reflect.Value) {
+func decode(r io.RuneScanner, val reflect.Value, cfg decodeConfig) {
typ := val.Type()
switch {
case val.CanAddr() && reflect.PointerTo(typ) == rawMessagePtrType:
@@ -193,14 +202,14 @@ func decode(r io.RuneScanner, val reflect.Value) {
case reflect.Interface:
if val.IsNil() {
if typ == anyType {
- val.Set(reflect.ValueOf(decodeAny(r)))
+ val.Set(reflect.ValueOf(decodeAny(r, cfg)))
} else {
panic(decodeError{&json.UnsupportedTypeError{
Type: typ,
}})
}
} else {
- decode(r, val.Elem())
+ decode(r, val.Elem(), cfg)
}
case reflect.Struct:
index := indexStruct(typ)
@@ -209,6 +218,9 @@ func decode(r io.RuneScanner, val reflect.Value) {
name := nameBuf.String()
idx, ok := index.byName[name]
if !ok {
+ if cfg.disallowUnknownFields {
+ panic(decodeError{fmt.Errorf("unknown field %q", name)})
+ }
scan(r, io.Discard)
return
}
@@ -239,12 +251,12 @@ func decode(r io.RuneScanner, val reflect.Value) {
}
case '"':
decodeString(r, &buf)
- decode(&buf, fVal)
+ decode(&buf, fVal, cfg)
default:
panic(decodeError{fmt.Errorf("invalid character %q for ,string struct value", peekRune(r))})
}
} else {
- decode(r, fVal)
+ decode(r, fVal, cfg)
}
})
case reflect.Map:
@@ -288,7 +300,7 @@ func decode(r io.RuneScanner, val reflect.Value) {
}
fValPtr := reflect.New(typ.Elem())
- decode(r, fValPtr.Elem())
+ decode(r, fValPtr.Elem(), cfg)
val.SetMapIndex(nameValPtr.Elem(), fValPtr.Elem())
})
@@ -313,7 +325,7 @@ func decode(r io.RuneScanner, val reflect.Value) {
}
decodeArray(r, func() {
mValPtr := reflect.New(typ.Elem())
- decode(r, mValPtr.Elem())
+ decode(r, mValPtr.Elem(), cfg)
val.Set(reflect.Append(val, mValPtr.Elem()))
})
default:
@@ -324,7 +336,7 @@ func decode(r io.RuneScanner, val reflect.Value) {
i := 0
decodeArray(r, func() {
mValPtr := reflect.New(typ.Elem())
- decode(r, mValPtr.Elem())
+ decode(r, mValPtr.Elem(), cfg)
val.Index(i).Set(mValPtr.Elem())
i++
})
@@ -332,7 +344,7 @@ func decode(r io.RuneScanner, val reflect.Value) {
if val.IsNil() {
val.Set(reflect.New(typ.Elem()))
}
- decode(r, val.Elem())
+ decode(r, val.Elem(), cfg)
default:
panic(decodeError{&json.UnsupportedTypeError{
Type: typ,
@@ -401,20 +413,20 @@ func scanNumber(r io.RuneScanner, out io.Writer) {
}
}
-func decodeAny(r io.RuneScanner) any {
+func decodeAny(r io.RuneScanner, cfg decodeConfig) any {
c := peekRune(r)
switch c {
case '{':
ret := make(map[string]any)
var nameBuf strings.Builder
decodeObject(r, &nameBuf, func() {
- ret[nameBuf.String()] = decodeAny(r)
+ ret[nameBuf.String()] = decodeAny(r, cfg)
})
return ret
case '[':
ret := []any{}
decodeArray(r, func() {
- ret = append(ret, decodeAny(r))
+ ret = append(ret, decodeAny(r, cfg))
})
return ret
case '"':
@@ -424,7 +436,15 @@ func decodeAny(r io.RuneScanner) any {
case '-', '0', '1', '2', '3', '4', '5', '6', '7', '8', '9':
var buf strings.Builder
scanNumber(r, &buf)
- return json.Number(buf.String())
+ num := json.Number(buf.String())
+ if cfg.useNumber {
+ return num
+ }
+ f64, err := num.Float64()
+ if err != nil {
+ panic(decodeError{err})
+ }
+ return f64
case 't', 'f':
return decodeBool(r)
case 'n':