summaryrefslogtreecommitdiff
path: root/lib/lowmemjson/decode.go
diff options
context:
space:
mode:
Diffstat (limited to 'lib/lowmemjson/decode.go')
-rw-r--r--lib/lowmemjson/decode.go112
1 files changed, 84 insertions, 28 deletions
diff --git a/lib/lowmemjson/decode.go b/lib/lowmemjson/decode.go
index 2dee59c..f9ea8a2 100644
--- a/lib/lowmemjson/decode.go
+++ b/lib/lowmemjson/decode.go
@@ -78,22 +78,21 @@ func (dec *Decoder) stackPop() {
dec.stack = dec.stack[:len(dec.stack)-1]
}
-type decodeError struct{}
+type decodeError struct {
+ Err error
+}
func (dec *Decoder) panicIO(err error) {
- dec.err = fmt.Errorf("json: I/O error at input byte %v: %s: %w",
- dec.nxtPos, dec.stackStr(), err)
- panic(decodeError{})
+ panic(decodeError{fmt.Errorf("json: I/O error at input byte %v: %s: %w",
+ dec.nxtPos, dec.stackStr(), err)})
}
func (dec *Decoder) panicSyntax(err error) {
- dec.err = fmt.Errorf("json: syntax error at input byte %v: %s: %w",
- dec.curPos, dec.stackStr(), err)
- panic(decodeError{})
+ panic(decodeError{fmt.Errorf("json: syntax error at input byte %v: %s: %w",
+ dec.curPos, dec.stackStr(), err)})
}
func (dec *Decoder) panicType(typ reflect.Type, err error) {
- dec.err = fmt.Errorf("json: type mismatch error at input byte %v: %s: type %v: %w",
- dec.curPos, dec.stackStr(), typ, err)
- panic(decodeError{})
+ panic(decodeError{fmt.Errorf("json: type mismatch error at input byte %v: %s: type %v: %w",
+ dec.curPos, dec.stackStr(), typ, err)})
}
func Decode(r io.Reader, ptr any) error {
@@ -104,7 +103,8 @@ func (dec *Decoder) Decode(ptr any) (err error) {
ptrVal := reflect.ValueOf(ptr)
if ptrVal.Kind() != reflect.Pointer || ptrVal.IsNil() || !ptrVal.Elem().CanSet() {
return &json.InvalidUnmarshalError{
- Type: ptrVal.Type(),
+ // don't use ptrVal.Type() because ptrVal might be invalid if ptr==nil
+ Type: reflect.TypeOf(ptr),
}
}
@@ -114,7 +114,8 @@ func (dec *Decoder) Decode(ptr any) (err error) {
defer func() {
if r := recover(); r != nil {
- if _, ok := r.(decodeError); ok {
+ if de, ok := r.(decodeError); ok {
+ dec.err = de.Err
err = dec.err
} else {
panic(r)
@@ -122,7 +123,7 @@ func (dec *Decoder) Decode(ptr any) (err error) {
}
}()
dec.decodeWS()
- dec.decode(ptrVal.Elem())
+ dec.decode(ptrVal.Elem(), false)
return nil
}
@@ -226,7 +227,7 @@ var kind2bits = map[reflect.Kind]int{
reflect.Float64: 64,
}
-func (dec *Decoder) decode(val reflect.Value) {
+func (dec *Decoder) decode(val reflect.Value, nullOK bool) {
typ := val.Type()
switch {
case val.CanAddr() && reflect.PointerTo(typ) == rawMessagePtrType:
@@ -248,6 +249,10 @@ func (dec *Decoder) decode(val reflect.Value) {
dec.panicSyntax(err)
}
case val.CanAddr() && reflect.PointerTo(typ).Implements(textUnmarshalerType):
+ if nullOK && dec.peekRune() == 'n' {
+ dec.decodeNull()
+ return
+ }
var buf bytes.Buffer
dec.decodeString(&buf)
obj := val.Addr().Interface().(encoding.TextUnmarshaler)
@@ -258,8 +263,16 @@ func (dec *Decoder) decode(val reflect.Value) {
kind := typ.Kind()
switch kind {
case reflect.Bool:
+ if nullOK && dec.peekRune() == 'n' {
+ dec.decodeNull()
+ return
+ }
val.SetBool(dec.decodeBool())
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
+ if nullOK && dec.peekRune() == 'n' {
+ dec.decodeNull()
+ return
+ }
var buf strings.Builder
dec.scanNumber(&buf)
n, err := strconv.ParseInt(buf.String(), 10, kind2bits[kind])
@@ -268,6 +281,10 @@ func (dec *Decoder) decode(val reflect.Value) {
}
val.SetInt(n)
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
+ if nullOK && dec.peekRune() == 'n' {
+ dec.decodeNull()
+ return
+ }
var buf strings.Builder
dec.scanNumber(&buf)
n, err := strconv.ParseUint(buf.String(), 10, kind2bits[kind])
@@ -276,6 +293,10 @@ func (dec *Decoder) decode(val reflect.Value) {
}
val.SetUint(n)
case reflect.Float32, reflect.Float64:
+ if nullOK && dec.peekRune() == 'n' {
+ dec.decodeNull()
+ return
+ }
var buf strings.Builder
dec.scanNumber(&buf)
n, err := strconv.ParseFloat(buf.String(), kind2bits[kind])
@@ -284,6 +305,10 @@ func (dec *Decoder) decode(val reflect.Value) {
}
val.SetFloat(n)
case reflect.String:
+ if nullOK && dec.peekRune() == 'n' {
+ dec.decodeNull()
+ return
+ }
var buf strings.Builder
if typ == numberType {
dec.scanNumber(&buf)
@@ -301,19 +326,23 @@ func (dec *Decoder) decode(val reflect.Value) {
if !val.IsNil() && val.Elem().Kind() == reflect.Pointer && val.Elem().Elem().Kind() == reflect.Pointer {
// XXX: I can't justify this case, other than "it's what encoding/json does, but
// I don't understand their rationale".
- dec.decode(val.Elem())
+ dec.decode(val.Elem(), false)
} else {
dec.decodeNull()
val.Set(reflect.Zero(typ))
}
default:
if !val.IsNil() && val.Elem().Kind() == reflect.Pointer {
- dec.decode(val.Elem())
+ dec.decode(val.Elem(), false)
} else {
val.Set(reflect.ValueOf(dec.decodeAny()))
}
}
case reflect.Struct:
+ if nullOK && dec.peekRune() == 'n' {
+ dec.decodeNull()
+ return
+ }
index := indexStruct(typ)
var nameBuf strings.Builder
dec.decodeObject(&nameBuf, func() {
@@ -358,13 +387,13 @@ func (dec *Decoder) decode(val reflect.Value) {
subD := *dec // capture the .curPos *before* calling .decodeString
dec.decodeString(&buf)
subD.r = &buf
- subD.decode(fVal)
+ subD.decode(fVal, false)
default:
dec.panicSyntax(fmt.Errorf(",string field: expected %q or %q but got %q",
'n', '"', dec.peekRune()))
}
} else {
- dec.decode(fVal)
+ dec.decode(fVal, true)
}
})
case reflect.Map:
@@ -410,7 +439,7 @@ func (dec *Decoder) decode(val reflect.Value) {
defer dec.stackPop()
fValPtr := reflect.New(typ.Elem())
- dec.decode(fValPtr.Elem())
+ dec.decode(fValPtr.Elem(), false)
val.SetMapIndex(nameValPtr.Elem(), fValPtr.Elem())
})
@@ -422,7 +451,16 @@ func (dec *Decoder) decode(val reflect.Value) {
case typ.Elem().Kind() == reflect.Uint8:
var buf bytes.Buffer
dec.decodeString(newBase64Decoder(&buf))
- val.Set(reflect.ValueOf(buf.Bytes()))
+ if typ.Elem() == byteType {
+ val.Set(reflect.ValueOf(buf.Bytes()))
+ } else {
+ bs := buf.Bytes()
+ // TODO: Surely there's a better way.
+ val.Set(reflect.MakeSlice(typ, len(bs), len(bs)))
+ for i := 0; i < len(bs); i++ {
+ val.Index(i).Set(reflect.ValueOf(bs[i]).Convert(typ.Elem()))
+ }
+ }
default:
switch dec.peekRune() {
case 'n':
@@ -432,12 +470,15 @@ func (dec *Decoder) decode(val reflect.Value) {
if val.IsNil() {
val.Set(reflect.MakeSlice(typ, 0, 0))
}
+ if val.Len() > 0 {
+ val.Set(val.Slice(0, 0))
+ }
i := 0
dec.decodeArray(func() {
dec.stackPush(i)
defer dec.stackPop()
mValPtr := reflect.New(typ.Elem())
- dec.decode(mValPtr.Elem())
+ dec.decode(mValPtr.Elem(), false)
val.Set(reflect.Append(val, mValPtr.Elem()))
i++
})
@@ -446,29 +487,44 @@ func (dec *Decoder) decode(val reflect.Value) {
}
}
case reflect.Array:
+ if nullOK && dec.peekRune() == 'n' {
+ dec.decodeNull()
+ return
+ }
i := 0
+ n := val.Len()
dec.decodeArray(func() {
dec.stackPush(i)
defer dec.stackPop()
- mValPtr := reflect.New(typ.Elem())
- dec.decode(mValPtr.Elem())
- val.Index(i).Set(mValPtr.Elem())
+ if i < n {
+ mValPtr := reflect.New(typ.Elem())
+ dec.decode(mValPtr.Elem(), false)
+ val.Index(i).Set(mValPtr.Elem())
+ } else {
+ dec.scan(io.Discard)
+ }
i++
})
+ for ; i < n; i++ {
+ val.Index(i).Set(reflect.Zero(typ.Elem()))
+ }
case reflect.Pointer:
switch dec.peekRune() {
case 'n':
dec.decodeNull()
- for val.IsNil() && typ.Elem().Kind() == reflect.Pointer {
- val.Set(reflect.New(typ.Elem()))
+ for typ.Elem().Kind() == reflect.Pointer {
+ if val.IsNil() || !val.Elem().CanSet() {
+ val.Set(reflect.New(typ.Elem()))
+ }
val = val.Elem()
+ typ = val.Type()
}
- val.Elem().Set(reflect.Zero(val.Type().Elem()))
+ val.Set(reflect.Zero(typ))
default:
if val.IsNil() {
val.Set(reflect.New(typ.Elem()))
}
- dec.decode(val.Elem())
+ dec.decode(val.Elem(), false)
}
default:
dec.panicType(typ, fmt.Errorf("unsupported type (kind=%v)", typ.Kind()))