diff options
Diffstat (limited to 'lib/lowmemjson/decode.go')
-rw-r--r-- | lib/lowmemjson/decode.go | 501 |
1 files changed, 501 insertions, 0 deletions
diff --git a/lib/lowmemjson/decode.go b/lib/lowmemjson/decode.go new file mode 100644 index 0000000..33b4222 --- /dev/null +++ b/lib/lowmemjson/decode.go @@ -0,0 +1,501 @@ +// Copyright (C) 2022 Luke Shumaker <lukeshu@lukeshu.com> +// +// SPDX-License-Identifier: GPL-2.0-or-later + +package lowmemjson + +import ( + "bytes" + "encoding" + "encoding/json" + "fmt" + "io" + "reflect" + "strconv" + "strings" +) + +type Decoder interface { + DecodeJSON(io.RuneScanner) error +} + +type decodeError struct { + Err error +} + +type runeBuffer interface { + *bytes.Buffer | *strings.Builder + WriteRune(rune) (int, error) + Reset() +} + +func readRune(r io.RuneReader) rune { + c, _, err := r.ReadRune() + if err != nil { + panic(decodeError{err}) + } + return c +} + +func readRuneOrEOF(r io.RuneReader) (c rune, ok bool) { + c, _, err := r.ReadRune() + if err != nil { + if err == io.EOF { + return 0, false + } + panic(decodeError{err}) + } + return c, true +} + +func unreadRune(r io.RuneScanner) { + if err := r.UnreadRune(); err != nil { + panic(decodeError{err}) + } +} + +func peekRune(r io.RuneScanner) rune { + c := readRune(r) + unreadRune(r) + return c +} + +func expectRune(r io.RuneReader, exp rune) { + act := readRune(r) + if act != exp { + panic(decodeError{fmt.Errorf("expected %c but got %c", exp, act)}) + } +} + +func Decode(r io.RuneScanner, ptr any) (err error) { + ptrVal := reflect.ValueOf(ptr) + if ptrVal.Kind() != reflect.Pointer || ptrVal.IsNil() || !ptrVal.Elem().CanSet() { + return &json.InvalidUnmarshalError{ + Type: ptrVal.Type(), + } + } + + defer func() { + if r := recover(); r != nil { + if e, ok := r.(decodeError); ok { + err = e.Err + } else { + panic(r) + } + } + }() + decodeWS(r) + decode(r, ptrVal) + return nil +} + +var ( + rawMessagePtrType = reflect.TypeOf((*json.RawMessage)(nil)) + anyType = reflect.TypeOf((*any)(nil)).Elem() + decoderType = reflect.TypeOf((*Decoder)(nil)).Elem() + jsonUnmarshalerType = reflect.TypeOf((*json.Unmarshaler)(nil)).Elem() + textUnmarshalerType = reflect.TypeOf((*encoding.TextUnmarshaler)(nil)).Elem() +) + +var kind2bits = map[reflect.Kind]int{ + reflect.Int: int(32 << (^uint(0) >> 63)), + reflect.Int8: 8, + reflect.Int16: 16, + reflect.Int32: 32, + reflect.Int64: 64, + + reflect.Uint: int(32 << (^uint(0) >> 63)), + reflect.Uint8: 8, + reflect.Uint16: 16, + reflect.Uint32: 32, + reflect.Uint64: 64, + + reflect.Uintptr: int(32 << (^uintptr(0) >> 63)), + + reflect.Float32: 32, + reflect.Float64: 64, +} + +func decode(r io.RuneScanner, ptrVal reflect.Value) { + switch { + case ptrVal.Type() == rawMessagePtrType: + var buf bytes.Buffer + scan(r, &buf) + if err := ptrVal.Interface().(*json.RawMessage).UnmarshalJSON(buf.Bytes()); err != nil { + panic(decodeError{err}) + } + case ptrVal.Type().Implements(decoderType): + obj := ptrVal.Interface().(Decoder) + if err := obj.DecodeJSON(r); err != nil { + panic(decodeError{err}) + } + case ptrVal.Type().Implements(jsonUnmarshalerType): + var buf bytes.Buffer + scan(r, &buf) + obj := ptrVal.Interface().(json.Unmarshaler) + if err := obj.UnmarshalJSON(buf.Bytes()); err != nil { + panic(decodeError{err}) + } + case ptrVal.Type().Implements(textUnmarshalerType): + var buf bytes.Buffer + decodeString(r, &buf) + obj := ptrVal.Interface().(encoding.TextUnmarshaler) + if err := obj.UnmarshalText(buf.Bytes()); err != nil { + panic(decodeError{err}) + } + default: + typ := ptrVal.Type().Elem() + kind := typ.Kind() + switch kind { + case reflect.Bool: + ptrVal.Elem().SetBool(decodeBool(r)) + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + var buf strings.Builder + scanNumber(r, &buf) + n, err := strconv.ParseInt(buf.String(), 10, kind2bits[kind]) + if err != nil { + panic(decodeError{err}) + } + ptrVal.Elem().SetInt(n) + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: + var buf strings.Builder + scanNumber(r, &buf) + n, err := strconv.ParseUint(buf.String(), 10, kind2bits[kind]) + if err != nil { + panic(decodeError{err}) + } + ptrVal.Elem().SetUint(n) + case reflect.Float32, reflect.Float64: + var buf strings.Builder + scanNumber(r, &buf) + n, err := strconv.ParseFloat(buf.String(), kind2bits[kind]) + if err != nil { + panic(decodeError{err}) + } + ptrVal.Elem().SetFloat(n) + case reflect.String: + var buf strings.Builder + if typ == numberType { + scanNumber(r, &buf) + ptrVal.Elem().SetString(buf.String()) + } else { + decodeString(r, &buf) + ptrVal.Elem().SetString(buf.String()) + } + case reflect.Interface: + if typ == anyType { + ptrVal.Elem().Set(reflect.ValueOf(decodeAny(r))) + } else { + panic(decodeError{&json.UnsupportedTypeError{ + Type: typ, + }}) + } + case reflect.Struct: + index := indexStruct(typ) + var nameBuf strings.Builder + decodeObject(r, &nameBuf, func() { + name := nameBuf.String() + idx, ok := index.byName[name] + if !ok { + scan(r, io.Discard) + return + } + field := index.byPos[idx] + fVal := ptrVal.Elem() + for _, idx := range field.Path { + if fVal.Kind() == reflect.Pointer { + if fVal.IsNil() { + if !fVal.CanSet() { // https://golang.org/issue/21357 + panic(decodeError{fmt.Errorf("cannot set embedded pointer to unexported type %v", fVal.Type().Elem())}) + } + fVal.Set(reflect.New(fVal.Type().Elem())) + } + fVal = fVal.Elem() + } + fVal = fVal.Field(idx) + } + decode(r, fVal.Addr()) + }) + case reflect.Map: + if ptrVal.Elem().IsNil() { + ptrVal.Elem().Set(reflect.MakeMap(typ)) + } + var nameBuf bytes.Buffer + decodeObject(r, &nameBuf, func() { + nameValTyp := typ.Key() + nameValPtr := reflect.New(nameValTyp) + switch { + case reflect.PointerTo(nameValTyp).Implements(textUnmarshalerType): + obj := ptrVal.Interface().(encoding.TextUnmarshaler) + if err := obj.UnmarshalText(nameBuf.Bytes()); err != nil { + panic(decodeError{err}) + } + default: + switch nameValTyp.Kind() { + case reflect.String: + nameValPtr.Elem().SetString(nameBuf.String()) + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + n, err := strconv.ParseInt(nameBuf.String(), 10, kind2bits[nameValTyp.Kind()]) + if err != nil { + panic(decodeError{err}) + } + nameValPtr.Elem().SetInt(n) + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: + n, err := strconv.ParseUint(nameBuf.String(), 10, kind2bits[nameValTyp.Kind()]) + if err != nil { + panic(decodeError{err}) + } + nameValPtr.Elem().SetUint(n) + default: + panic(decodeError{fmt.Errorf("invalid map key type: %v", nameValTyp)}) + } + } + + fValPtr := reflect.New(typ.Elem()) + decode(r, fValPtr) + + ptrVal.Elem().SetMapIndex(nameValPtr.Elem(), fValPtr.Elem()) + }) + case reflect.Slice: + if ptrVal.Elem().IsNil() { + ptrVal.Elem().Set(reflect.MakeSlice(typ.Elem(), 0, 0)) + } + decodeArray(r, func() { + mValPtr := reflect.New(typ.Elem()) + decode(r, mValPtr) + ptrVal.Set(reflect.Append(ptrVal.Elem(), mValPtr.Elem())) + }) + case reflect.Pointer: + val := reflect.New(typ.Elem()) + decode(r, val) + ptrVal.Elem().Set(val) + default: + panic(decodeError{&json.UnsupportedTypeError{ + Type: typ, + }}) + } + } +} + +func decodeWS(r io.RuneScanner) { + for { + switch readRune(r) { + // NB: The JSON definition of whitespace is more + // narrow than unicode.IsSpace + case 0x0020, 0x000A, 0x000D, 0x0009: + // do nothing + default: + unreadRune(r) + return + } + } +} + +func scan(r io.RuneScanner, out io.Writer) { + scanner := &ReEncoder{ + Out: out, + Compact: true, + } + if _, err := scanner.WriteRune(readRune(r)); err != nil { + panic(decodeError{err}) + } + scanner.bailAfterCurrent = true + var err error + for err == nil { + c, ok := readRuneOrEOF(r) + if ok { + _, err = scanner.WriteRune(c) + } else { + err = scanner.Flush() + break + } + } + if err != nil { + if err == errBailedAfterCurrent { + unreadRune(r) + } else { + panic(decodeError{err}) + } + } +} + +func scanNumber(r io.RuneScanner, out io.Writer) { + c := peekRune(r) + switch c { + case '-', '0', '1', '2', '3', '4', '5', '6', '7', '8', '9': + scan(r, out) + default: + panic(decodeError{fmt.Errorf("expected a nubmer bug got %c", c)}) + } +} + +func decodeAny(r io.RuneScanner) any { + c := peekRune(r) + switch c { + case '{': + ret := make(map[string]any) + var nameBuf strings.Builder + decodeObject(r, &nameBuf, func() { + ret[nameBuf.String()] = decodeAny(r) + }) + return ret + case '[': + ret := []any{} + decodeArray(r, func() { + ret = append(ret, decodeAny(r)) + }) + return ret + case '"': + var buf strings.Builder + decodeString(r, &buf) + return buf.String() + case '-', '0', '1', '2', '3', '4', '5', '6', '7', '8', '9': + var buf strings.Builder + scanNumber(r, &buf) + return json.Number(buf.String()) + case 't', 'f': + return decodeBool(r) + case 'n': + expectRune(r, 'n') + expectRune(r, 'u') + expectRune(r, 'l') + expectRune(r, 'l') + return nil + default: + panic(decodeError{fmt.Errorf("unexpected character: %c", c)}) + } +} + +func decodeObject[bufT runeBuffer](r io.RuneScanner, nameBuf bufT, decodeKVal func()) { + expectRune(r, '{') + decodeWS(r) + c := readRune(r) + switch c { + case '"': + decodeMember: + unreadRune(r) + nameBuf.Reset() + decodeString(r, nameBuf) + decodeWS(r) + expectRune(r, ':') + decodeWS(r) + decodeKVal() + decodeWS(r) + c := readRune(r) + switch c { + case ',': + decodeWS(r) + expectRune(r, '"') + goto decodeMember + case '}': + return + default: + panic(decodeError{fmt.Errorf("expected %c or %c but got %c", ',', '}', c)}) + } + case '}': + return + default: + panic(decodeError{fmt.Errorf("expected %c or %c but got %c", '"', '}', c)}) + } +} + +func decodeArray(r io.RuneScanner, decodeMember func()) { + expectRune(r, '[') + decodeWS(r) + c := readRune(r) + switch c { + case ']': + return + default: + decodeNextMember: + unreadRune(r) + decodeMember() + decodeWS(r) + c := readRune(r) + switch c { + case ',': + decodeWS(r) + goto decodeNextMember + case ']': + return + default: + panic(decodeError{fmt.Errorf("expected %c or %c but got %c", ',', ']', c)}) + } + } +} + +func decodeHex(r io.RuneReader) rune { + c := readRune(r) + switch { + case '0' <= c && c <= '9': + return c - '0' + case 'a' <= c && c <= 'f': + return c - 'a' + 10 + case 'A' <= c && c <= 'F': + return c - 'A' + 10 + default: + panic(decodeError{fmt.Errorf("unexpected %c in unicode literal", c)}) + } +} + +func decodeString[bufT runeBuffer](r io.RuneScanner, out bufT) { + // No need to check errors from out.WriteRune because 'out' + // guaranteed (by the 'runeBuffer' type constraint) to always + // either a *bytes.Buffer or a *string.Builder, neither of + // which return errors. + expectRune(r, '"') + for { + c := readRune(r) + switch { + case 0x0020 <= c && c <= 0x10FFFF && c != '"' && c != '\\': + _, _ = out.WriteRune(c) + case c == '\\': + c = readRune(r) + switch c { + case '"': + _, _ = out.WriteRune('"') + case '\\': + _, _ = out.WriteRune('\\') + case 'b': + _, _ = out.WriteRune('\b') + case 'f': + _, _ = out.WriteRune('\f') + case 'n': + _, _ = out.WriteRune('\n') + case 'r': + _, _ = out.WriteRune('\r') + case 't': + _, _ = out.WriteRune('\t') + case 'u': + c = decodeHex(r) + c = (c << 4) | decodeHex(r) + c = (c << 4) | decodeHex(r) + c = (c << 4) | decodeHex(r) + _, _ = out.WriteRune(c) + } + case c == '"': + return + default: + panic(decodeError{fmt.Errorf("unexpected %c in string", c)}) + } + } +} + +func decodeBool(r io.RuneReader) bool { + c := readRune(r) + switch c { + case 't': + expectRune(r, 'r') + expectRune(r, 'u') + expectRune(r, 'e') + return true + case 'f': + expectRune(r, 'a') + expectRune(r, 'l') + expectRune(r, 's') + expectRune(r, 'e') + return false + default: + panic(decodeError{fmt.Errorf("unexpected character: %c", c)}) + } +} |