diff options
Diffstat (limited to 'lib/lowmemjson/decode.go')
-rw-r--r-- | lib/lowmemjson/decode.go | 112 |
1 files changed, 84 insertions, 28 deletions
diff --git a/lib/lowmemjson/decode.go b/lib/lowmemjson/decode.go index 2dee59c..f9ea8a2 100644 --- a/lib/lowmemjson/decode.go +++ b/lib/lowmemjson/decode.go @@ -78,22 +78,21 @@ func (dec *Decoder) stackPop() { dec.stack = dec.stack[:len(dec.stack)-1] } -type decodeError struct{} +type decodeError struct { + Err error +} func (dec *Decoder) panicIO(err error) { - dec.err = fmt.Errorf("json: I/O error at input byte %v: %s: %w", - dec.nxtPos, dec.stackStr(), err) - panic(decodeError{}) + panic(decodeError{fmt.Errorf("json: I/O error at input byte %v: %s: %w", + dec.nxtPos, dec.stackStr(), err)}) } func (dec *Decoder) panicSyntax(err error) { - dec.err = fmt.Errorf("json: syntax error at input byte %v: %s: %w", - dec.curPos, dec.stackStr(), err) - panic(decodeError{}) + panic(decodeError{fmt.Errorf("json: syntax error at input byte %v: %s: %w", + dec.curPos, dec.stackStr(), err)}) } func (dec *Decoder) panicType(typ reflect.Type, err error) { - dec.err = fmt.Errorf("json: type mismatch error at input byte %v: %s: type %v: %w", - dec.curPos, dec.stackStr(), typ, err) - panic(decodeError{}) + panic(decodeError{fmt.Errorf("json: type mismatch error at input byte %v: %s: type %v: %w", + dec.curPos, dec.stackStr(), typ, err)}) } func Decode(r io.Reader, ptr any) error { @@ -104,7 +103,8 @@ func (dec *Decoder) Decode(ptr any) (err error) { ptrVal := reflect.ValueOf(ptr) if ptrVal.Kind() != reflect.Pointer || ptrVal.IsNil() || !ptrVal.Elem().CanSet() { return &json.InvalidUnmarshalError{ - Type: ptrVal.Type(), + // don't use ptrVal.Type() because ptrVal might be invalid if ptr==nil + Type: reflect.TypeOf(ptr), } } @@ -114,7 +114,8 @@ func (dec *Decoder) Decode(ptr any) (err error) { defer func() { if r := recover(); r != nil { - if _, ok := r.(decodeError); ok { + if de, ok := r.(decodeError); ok { + dec.err = de.Err err = dec.err } else { panic(r) @@ -122,7 +123,7 @@ func (dec *Decoder) Decode(ptr any) (err error) { } }() dec.decodeWS() - dec.decode(ptrVal.Elem()) + dec.decode(ptrVal.Elem(), false) return nil } @@ -226,7 +227,7 @@ var kind2bits = map[reflect.Kind]int{ reflect.Float64: 64, } -func (dec *Decoder) decode(val reflect.Value) { +func (dec *Decoder) decode(val reflect.Value, nullOK bool) { typ := val.Type() switch { case val.CanAddr() && reflect.PointerTo(typ) == rawMessagePtrType: @@ -248,6 +249,10 @@ func (dec *Decoder) decode(val reflect.Value) { dec.panicSyntax(err) } case val.CanAddr() && reflect.PointerTo(typ).Implements(textUnmarshalerType): + if nullOK && dec.peekRune() == 'n' { + dec.decodeNull() + return + } var buf bytes.Buffer dec.decodeString(&buf) obj := val.Addr().Interface().(encoding.TextUnmarshaler) @@ -258,8 +263,16 @@ func (dec *Decoder) decode(val reflect.Value) { kind := typ.Kind() switch kind { case reflect.Bool: + if nullOK && dec.peekRune() == 'n' { + dec.decodeNull() + return + } val.SetBool(dec.decodeBool()) case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + if nullOK && dec.peekRune() == 'n' { + dec.decodeNull() + return + } var buf strings.Builder dec.scanNumber(&buf) n, err := strconv.ParseInt(buf.String(), 10, kind2bits[kind]) @@ -268,6 +281,10 @@ func (dec *Decoder) decode(val reflect.Value) { } val.SetInt(n) case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: + if nullOK && dec.peekRune() == 'n' { + dec.decodeNull() + return + } var buf strings.Builder dec.scanNumber(&buf) n, err := strconv.ParseUint(buf.String(), 10, kind2bits[kind]) @@ -276,6 +293,10 @@ func (dec *Decoder) decode(val reflect.Value) { } val.SetUint(n) case reflect.Float32, reflect.Float64: + if nullOK && dec.peekRune() == 'n' { + dec.decodeNull() + return + } var buf strings.Builder dec.scanNumber(&buf) n, err := strconv.ParseFloat(buf.String(), kind2bits[kind]) @@ -284,6 +305,10 @@ func (dec *Decoder) decode(val reflect.Value) { } val.SetFloat(n) case reflect.String: + if nullOK && dec.peekRune() == 'n' { + dec.decodeNull() + return + } var buf strings.Builder if typ == numberType { dec.scanNumber(&buf) @@ -301,19 +326,23 @@ func (dec *Decoder) decode(val reflect.Value) { if !val.IsNil() && val.Elem().Kind() == reflect.Pointer && val.Elem().Elem().Kind() == reflect.Pointer { // XXX: I can't justify this case, other than "it's what encoding/json does, but // I don't understand their rationale". - dec.decode(val.Elem()) + dec.decode(val.Elem(), false) } else { dec.decodeNull() val.Set(reflect.Zero(typ)) } default: if !val.IsNil() && val.Elem().Kind() == reflect.Pointer { - dec.decode(val.Elem()) + dec.decode(val.Elem(), false) } else { val.Set(reflect.ValueOf(dec.decodeAny())) } } case reflect.Struct: + if nullOK && dec.peekRune() == 'n' { + dec.decodeNull() + return + } index := indexStruct(typ) var nameBuf strings.Builder dec.decodeObject(&nameBuf, func() { @@ -358,13 +387,13 @@ func (dec *Decoder) decode(val reflect.Value) { subD := *dec // capture the .curPos *before* calling .decodeString dec.decodeString(&buf) subD.r = &buf - subD.decode(fVal) + subD.decode(fVal, false) default: dec.panicSyntax(fmt.Errorf(",string field: expected %q or %q but got %q", 'n', '"', dec.peekRune())) } } else { - dec.decode(fVal) + dec.decode(fVal, true) } }) case reflect.Map: @@ -410,7 +439,7 @@ func (dec *Decoder) decode(val reflect.Value) { defer dec.stackPop() fValPtr := reflect.New(typ.Elem()) - dec.decode(fValPtr.Elem()) + dec.decode(fValPtr.Elem(), false) val.SetMapIndex(nameValPtr.Elem(), fValPtr.Elem()) }) @@ -422,7 +451,16 @@ func (dec *Decoder) decode(val reflect.Value) { case typ.Elem().Kind() == reflect.Uint8: var buf bytes.Buffer dec.decodeString(newBase64Decoder(&buf)) - val.Set(reflect.ValueOf(buf.Bytes())) + if typ.Elem() == byteType { + val.Set(reflect.ValueOf(buf.Bytes())) + } else { + bs := buf.Bytes() + // TODO: Surely there's a better way. + val.Set(reflect.MakeSlice(typ, len(bs), len(bs))) + for i := 0; i < len(bs); i++ { + val.Index(i).Set(reflect.ValueOf(bs[i]).Convert(typ.Elem())) + } + } default: switch dec.peekRune() { case 'n': @@ -432,12 +470,15 @@ func (dec *Decoder) decode(val reflect.Value) { if val.IsNil() { val.Set(reflect.MakeSlice(typ, 0, 0)) } + if val.Len() > 0 { + val.Set(val.Slice(0, 0)) + } i := 0 dec.decodeArray(func() { dec.stackPush(i) defer dec.stackPop() mValPtr := reflect.New(typ.Elem()) - dec.decode(mValPtr.Elem()) + dec.decode(mValPtr.Elem(), false) val.Set(reflect.Append(val, mValPtr.Elem())) i++ }) @@ -446,29 +487,44 @@ func (dec *Decoder) decode(val reflect.Value) { } } case reflect.Array: + if nullOK && dec.peekRune() == 'n' { + dec.decodeNull() + return + } i := 0 + n := val.Len() dec.decodeArray(func() { dec.stackPush(i) defer dec.stackPop() - mValPtr := reflect.New(typ.Elem()) - dec.decode(mValPtr.Elem()) - val.Index(i).Set(mValPtr.Elem()) + if i < n { + mValPtr := reflect.New(typ.Elem()) + dec.decode(mValPtr.Elem(), false) + val.Index(i).Set(mValPtr.Elem()) + } else { + dec.scan(io.Discard) + } i++ }) + for ; i < n; i++ { + val.Index(i).Set(reflect.Zero(typ.Elem())) + } case reflect.Pointer: switch dec.peekRune() { case 'n': dec.decodeNull() - for val.IsNil() && typ.Elem().Kind() == reflect.Pointer { - val.Set(reflect.New(typ.Elem())) + for typ.Elem().Kind() == reflect.Pointer { + if val.IsNil() || !val.Elem().CanSet() { + val.Set(reflect.New(typ.Elem())) + } val = val.Elem() + typ = val.Type() } - val.Elem().Set(reflect.Zero(val.Type().Elem())) + val.Set(reflect.Zero(typ)) default: if val.IsNil() { val.Set(reflect.New(typ.Elem())) } - dec.decode(val.Elem()) + dec.decode(val.Elem(), false) } default: dec.panicType(typ, fmt.Errorf("unsupported type (kind=%v)", typ.Kind())) |