diff options
Diffstat (limited to 'lib/lowmemjson/decode.go')
-rw-r--r-- | lib/lowmemjson/decode.go | 221 |
1 files changed, 138 insertions, 83 deletions
diff --git a/lib/lowmemjson/decode.go b/lib/lowmemjson/decode.go index 4873d43..b2eaacf 100644 --- a/lib/lowmemjson/decode.go +++ b/lib/lowmemjson/decode.go @@ -24,7 +24,7 @@ type decodeError struct { } type runeBuffer interface { - *bytes.Buffer | *strings.Builder + io.Writer WriteRune(rune) (int, error) Reset() } @@ -93,7 +93,7 @@ func Decode(r io.RuneScanner, ptr any) (err error) { } }() decodeWS(r) - decode(r, ptrVal) + decode(r, ptrVal.Elem()) return nil } @@ -124,39 +124,39 @@ var kind2bits = map[reflect.Kind]int{ reflect.Float64: 64, } -func decode(r io.RuneScanner, ptrVal reflect.Value) { +func decode(r io.RuneScanner, val reflect.Value) { + typ := val.Type() switch { - case ptrVal.Type() == rawMessagePtrType: + case val.CanAddr() && reflect.PointerTo(typ) == rawMessagePtrType: var buf bytes.Buffer scan(r, &buf) - if err := ptrVal.Interface().(*json.RawMessage).UnmarshalJSON(buf.Bytes()); err != nil { + if err := val.Addr().Interface().(*json.RawMessage).UnmarshalJSON(buf.Bytes()); err != nil { panic(decodeError{err}) } - case ptrVal.Type().Implements(decoderType): - obj := ptrVal.Interface().(Decoder) + case val.CanAddr() && reflect.PointerTo(typ).Implements(decoderType): + obj := val.Addr().Interface().(Decoder) if err := obj.DecodeJSON(r); err != nil { panic(decodeError{err}) } - case ptrVal.Type().Implements(jsonUnmarshalerType): + case val.CanAddr() && reflect.PointerTo(typ).Implements(jsonUnmarshalerType): var buf bytes.Buffer scan(r, &buf) - obj := ptrVal.Interface().(json.Unmarshaler) + obj := val.Addr().Interface().(json.Unmarshaler) if err := obj.UnmarshalJSON(buf.Bytes()); err != nil { panic(decodeError{err}) } - case ptrVal.Type().Implements(textUnmarshalerType): + case val.CanAddr() && reflect.PointerTo(typ).Implements(textUnmarshalerType): var buf bytes.Buffer decodeString(r, &buf) - obj := ptrVal.Interface().(encoding.TextUnmarshaler) + obj := val.Addr().Interface().(encoding.TextUnmarshaler) if err := obj.UnmarshalText(buf.Bytes()); err != nil { panic(decodeError{err}) } default: - typ := ptrVal.Type().Elem() kind := typ.Kind() switch kind { case reflect.Bool: - ptrVal.Elem().SetBool(decodeBool(r)) + val.SetBool(decodeBool(r)) case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: var buf strings.Builder scanNumber(r, &buf) @@ -164,7 +164,7 @@ func decode(r io.RuneScanner, ptrVal reflect.Value) { if err != nil { panic(decodeError{err}) } - ptrVal.Elem().SetInt(n) + val.SetInt(n) case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: var buf strings.Builder scanNumber(r, &buf) @@ -172,7 +172,7 @@ func decode(r io.RuneScanner, ptrVal reflect.Value) { if err != nil { panic(decodeError{err}) } - ptrVal.Elem().SetUint(n) + val.SetUint(n) case reflect.Float32, reflect.Float64: var buf strings.Builder scanNumber(r, &buf) @@ -180,23 +180,27 @@ func decode(r io.RuneScanner, ptrVal reflect.Value) { if err != nil { panic(decodeError{err}) } - ptrVal.Elem().SetFloat(n) + val.SetFloat(n) case reflect.String: var buf strings.Builder if typ == numberType { scanNumber(r, &buf) - ptrVal.Elem().SetString(buf.String()) + val.SetString(buf.String()) } else { decodeString(r, &buf) - ptrVal.Elem().SetString(buf.String()) + val.SetString(buf.String()) } case reflect.Interface: - if typ == anyType { - ptrVal.Elem().Set(reflect.ValueOf(decodeAny(r))) + if val.IsNil() { + if typ == anyType { + val.Set(reflect.ValueOf(decodeAny(r))) + } else { + panic(decodeError{&json.UnsupportedTypeError{ + Type: typ, + }}) + } } else { - panic(decodeError{&json.UnsupportedTypeError{ - Type: typ, - }}) + decode(r, val.Elem()) } case reflect.Struct: index := indexStruct(typ) @@ -209,7 +213,7 @@ func decode(r io.RuneScanner, ptrVal reflect.Value) { return } field := index.byPos[idx] - fVal := ptrVal.Elem() + fVal := val for _, idx := range field.Path { if fVal.Kind() == reflect.Pointer { if fVal.IsNil() { @@ -225,58 +229,91 @@ func decode(r io.RuneScanner, ptrVal reflect.Value) { decode(r, fVal.Addr()) }) case reflect.Map: - if ptrVal.Elem().IsNil() { - ptrVal.Elem().Set(reflect.MakeMap(typ)) - } - var nameBuf bytes.Buffer - decodeObject(r, &nameBuf, func() { - nameValTyp := typ.Key() - nameValPtr := reflect.New(nameValTyp) - switch { - case reflect.PointerTo(nameValTyp).Implements(textUnmarshalerType): - obj := ptrVal.Interface().(encoding.TextUnmarshaler) - if err := obj.UnmarshalText(nameBuf.Bytes()); err != nil { - panic(decodeError{err}) - } - default: - switch nameValTyp.Kind() { - case reflect.String: - nameValPtr.Elem().SetString(nameBuf.String()) - case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: - n, err := strconv.ParseInt(nameBuf.String(), 10, kind2bits[nameValTyp.Kind()]) - if err != nil { - panic(decodeError{err}) - } - nameValPtr.Elem().SetInt(n) - case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: - n, err := strconv.ParseUint(nameBuf.String(), 10, kind2bits[nameValTyp.Kind()]) - if err != nil { + switch peekRune(r) { + case 'n': + decodeNull(r) + val.Set(reflect.Zero(typ)) + case '{': + if val.IsNil() { + val.Set(reflect.MakeMap(typ)) + } + var nameBuf bytes.Buffer + decodeObject(r, &nameBuf, func() { + nameValTyp := typ.Key() + nameValPtr := reflect.New(nameValTyp) + switch { + case reflect.PointerTo(nameValTyp).Implements(textUnmarshalerType): + obj := nameValPtr.Interface().(encoding.TextUnmarshaler) + if err := obj.UnmarshalText(nameBuf.Bytes()); err != nil { panic(decodeError{err}) } - nameValPtr.Elem().SetUint(n) default: - panic(decodeError{fmt.Errorf("invalid map key type: %v", nameValTyp)}) + switch nameValTyp.Kind() { + case reflect.String: + nameValPtr.Elem().SetString(nameBuf.String()) + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + n, err := strconv.ParseInt(nameBuf.String(), 10, kind2bits[nameValTyp.Kind()]) + if err != nil { + panic(decodeError{err}) + } + nameValPtr.Elem().SetInt(n) + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: + n, err := strconv.ParseUint(nameBuf.String(), 10, kind2bits[nameValTyp.Kind()]) + if err != nil { + panic(decodeError{err}) + } + nameValPtr.Elem().SetUint(n) + default: + panic(decodeError{fmt.Errorf("invalid map key type: %v", nameValTyp)}) + } } - } - fValPtr := reflect.New(typ.Elem()) - decode(r, fValPtr) + fValPtr := reflect.New(typ.Elem()) + decode(r, fValPtr) - ptrVal.Elem().SetMapIndex(nameValPtr.Elem(), fValPtr.Elem()) - }) + val.SetMapIndex(nameValPtr.Elem(), fValPtr.Elem()) + }) + default: + panic(decodeError{fmt.Errorf("invalid character %q for map value", peekRune(r))}) + } case reflect.Slice: - if ptrVal.Elem().IsNil() { - ptrVal.Elem().Set(reflect.MakeSlice(typ.Elem(), 0, 0)) + switch { + case typ.Elem().Kind() == reflect.Uint8: + var buf bytes.Buffer + dec := newBase64Decoder(&buf) + decodeString(r, dec) + val.Set(reflect.ValueOf(buf.Bytes())) + default: + switch peekRune(r) { + case 'n': + decodeNull(r) + val.Set(reflect.Zero(typ)) + case '[': + if val.IsNil() { + val.Set(reflect.MakeSlice(typ, 0, 0)) + } + decodeArray(r, func() { + mValPtr := reflect.New(typ.Elem()) + decode(r, mValPtr) + val.Set(reflect.Append(val, mValPtr.Elem())) + }) + default: + panic(decodeError{fmt.Errorf("invalid character %q for slice value", peekRune(r))}) + } } + case reflect.Array: + i := 0 decodeArray(r, func() { mValPtr := reflect.New(typ.Elem()) decode(r, mValPtr) - ptrVal.Set(reflect.Append(ptrVal.Elem(), mValPtr.Elem())) + val.Index(i).Set(mValPtr.Elem()) + i++ }) case reflect.Pointer: - val := reflect.New(typ.Elem()) - decode(r, val) - ptrVal.Elem().Set(val) + if val.IsNil() { + val.Set(reflect.New(typ.Elem())) + } + decode(r, val.Elem()) default: panic(decodeError{&json.UnsupportedTypeError{ Type: typ, @@ -337,7 +374,7 @@ func scanNumber(r io.RuneScanner, out io.Writer) { case '-', '0', '1', '2', '3', '4', '5', '6', '7', '8', '9': scan(r, out) default: - panic(decodeError{fmt.Errorf("expected a nubmer bug got %c", c)}) + panic(decodeError{fmt.Errorf("expected a nubmer but got %c", c)}) } } @@ -368,17 +405,14 @@ func decodeAny(r io.RuneScanner) any { case 't', 'f': return decodeBool(r) case 'n': - expectRune(r, 'n') - expectRune(r, 'u') - expectRune(r, 'l') - expectRune(r, 'l') + decodeNull(r) return nil default: panic(decodeError{fmt.Errorf("unexpected character: %c", c)}) } } -func decodeObject[bufT runeBuffer](r io.RuneScanner, nameBuf bufT, decodeKVal func()) { +func decodeObject(r io.RuneScanner, nameBuf runeBuffer, decodeKVal func()) { expectRune(r, '{') decodeWS(r) c := readRune(r) @@ -450,40 +484,54 @@ func decodeHex(r io.RuneReader) rune { } } -func decodeString[bufT runeBuffer](r io.RuneScanner, out bufT) { - // No need to check errors from out.WriteRune because 'out' - // guaranteed (by the 'runeBuffer' type constraint) to always - // either a *bytes.Buffer or a *string.Builder, neither of - // which return errors. +func decodeString(r io.RuneScanner, out io.Writer) { expectRune(r, '"') for { c := readRune(r) switch { case 0x0020 <= c && c <= 0x10FFFF && c != '"' && c != '\\': - _, _ = out.WriteRune(c) + if _, err := writeRune(out, c); err != nil { + panic(decodeError{err}) + } case c == '\\': c = readRune(r) switch c { case '"': - _, _ = out.WriteRune('"') + if _, err := writeRune(out, '"'); err != nil { + panic(decodeError{err}) + } case '\\': - _, _ = out.WriteRune('\\') + if _, err := writeRune(out, '\\'); err != nil { + panic(decodeError{err}) + } case 'b': - _, _ = out.WriteRune('\b') + if _, err := writeRune(out, '\b'); err != nil { + panic(decodeError{err}) + } case 'f': - _, _ = out.WriteRune('\f') + if _, err := writeRune(out, '\f'); err != nil { + panic(decodeError{err}) + } case 'n': - _, _ = out.WriteRune('\n') + if _, err := writeRune(out, '\n'); err != nil { + panic(decodeError{err}) + } case 'r': - _, _ = out.WriteRune('\r') + if _, err := writeRune(out, '\r'); err != nil { + panic(decodeError{err}) + } case 't': - _, _ = out.WriteRune('\t') + if _, err := writeRune(out, '\t'); err != nil { + panic(decodeError{err}) + } case 'u': c = decodeHex(r) c = (c << 4) | decodeHex(r) c = (c << 4) | decodeHex(r) c = (c << 4) | decodeHex(r) - _, _ = out.WriteRune(c) + if _, err := writeRune(out, c); err != nil { + panic(decodeError{err}) + } } case c == '"': return @@ -511,3 +559,10 @@ func decodeBool(r io.RuneReader) bool { panic(decodeError{fmt.Errorf("unexpected character: %c", c)}) } } + +func decodeNull(r io.RuneReader) { + expectRune(r, 'n') + expectRune(r, 'u') + expectRune(r, 'l') + expectRune(r, 'l') +} |