From ab1da5feecf7f05233187424effa10637247c218 Mon Sep 17 00:00:00 2001 From: Luke Shumaker Date: Sat, 30 Jul 2022 15:11:13 -0600 Subject: wip decode --- lib/lowmemjson/decode.go | 501 +++++++++++++++++++++++++++++++++++++++++++++ lib/lowmemjson/encode.go | 2 +- lib/lowmemjson/reencode.go | 7 + lib/lowmemjson/struct.go | 20 +- 4 files changed, 524 insertions(+), 6 deletions(-) create mode 100644 lib/lowmemjson/decode.go diff --git a/lib/lowmemjson/decode.go b/lib/lowmemjson/decode.go new file mode 100644 index 0000000..33b4222 --- /dev/null +++ b/lib/lowmemjson/decode.go @@ -0,0 +1,501 @@ +// Copyright (C) 2022 Luke Shumaker +// +// SPDX-License-Identifier: GPL-2.0-or-later + +package lowmemjson + +import ( + "bytes" + "encoding" + "encoding/json" + "fmt" + "io" + "reflect" + "strconv" + "strings" +) + +type Decoder interface { + DecodeJSON(io.RuneScanner) error +} + +type decodeError struct { + Err error +} + +type runeBuffer interface { + *bytes.Buffer | *strings.Builder + WriteRune(rune) (int, error) + Reset() +} + +func readRune(r io.RuneReader) rune { + c, _, err := r.ReadRune() + if err != nil { + panic(decodeError{err}) + } + return c +} + +func readRuneOrEOF(r io.RuneReader) (c rune, ok bool) { + c, _, err := r.ReadRune() + if err != nil { + if err == io.EOF { + return 0, false + } + panic(decodeError{err}) + } + return c, true +} + +func unreadRune(r io.RuneScanner) { + if err := r.UnreadRune(); err != nil { + panic(decodeError{err}) + } +} + +func peekRune(r io.RuneScanner) rune { + c := readRune(r) + unreadRune(r) + return c +} + +func expectRune(r io.RuneReader, exp rune) { + act := readRune(r) + if act != exp { + panic(decodeError{fmt.Errorf("expected %c but got %c", exp, act)}) + } +} + +func Decode(r io.RuneScanner, ptr any) (err error) { + ptrVal := reflect.ValueOf(ptr) + if ptrVal.Kind() != reflect.Pointer || ptrVal.IsNil() || !ptrVal.Elem().CanSet() { + return &json.InvalidUnmarshalError{ + Type: ptrVal.Type(), + } + } + + defer func() { + if r := recover(); r != nil { + if e, ok := r.(decodeError); ok { + err = e.Err + } else { + panic(r) + } + } + }() + decodeWS(r) + decode(r, ptrVal) + return nil +} + +var ( + rawMessagePtrType = reflect.TypeOf((*json.RawMessage)(nil)) + anyType = reflect.TypeOf((*any)(nil)).Elem() + decoderType = reflect.TypeOf((*Decoder)(nil)).Elem() + jsonUnmarshalerType = reflect.TypeOf((*json.Unmarshaler)(nil)).Elem() + textUnmarshalerType = reflect.TypeOf((*encoding.TextUnmarshaler)(nil)).Elem() +) + +var kind2bits = map[reflect.Kind]int{ + reflect.Int: int(32 << (^uint(0) >> 63)), + reflect.Int8: 8, + reflect.Int16: 16, + reflect.Int32: 32, + reflect.Int64: 64, + + reflect.Uint: int(32 << (^uint(0) >> 63)), + reflect.Uint8: 8, + reflect.Uint16: 16, + reflect.Uint32: 32, + reflect.Uint64: 64, + + reflect.Uintptr: int(32 << (^uintptr(0) >> 63)), + + reflect.Float32: 32, + reflect.Float64: 64, +} + +func decode(r io.RuneScanner, ptrVal reflect.Value) { + switch { + case ptrVal.Type() == rawMessagePtrType: + var buf bytes.Buffer + scan(r, &buf) + if err := ptrVal.Interface().(*json.RawMessage).UnmarshalJSON(buf.Bytes()); err != nil { + panic(decodeError{err}) + } + case ptrVal.Type().Implements(decoderType): + obj := ptrVal.Interface().(Decoder) + if err := obj.DecodeJSON(r); err != nil { + panic(decodeError{err}) + } + case ptrVal.Type().Implements(jsonUnmarshalerType): + var buf bytes.Buffer + scan(r, &buf) + obj := ptrVal.Interface().(json.Unmarshaler) + if err := obj.UnmarshalJSON(buf.Bytes()); err != nil { + panic(decodeError{err}) + } + case ptrVal.Type().Implements(textUnmarshalerType): + var buf bytes.Buffer + decodeString(r, &buf) + obj := ptrVal.Interface().(encoding.TextUnmarshaler) + if err := obj.UnmarshalText(buf.Bytes()); err != nil { + panic(decodeError{err}) + } + default: + typ := ptrVal.Type().Elem() + kind := typ.Kind() + switch kind { + case reflect.Bool: + ptrVal.Elem().SetBool(decodeBool(r)) + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + var buf strings.Builder + scanNumber(r, &buf) + n, err := strconv.ParseInt(buf.String(), 10, kind2bits[kind]) + if err != nil { + panic(decodeError{err}) + } + ptrVal.Elem().SetInt(n) + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: + var buf strings.Builder + scanNumber(r, &buf) + n, err := strconv.ParseUint(buf.String(), 10, kind2bits[kind]) + if err != nil { + panic(decodeError{err}) + } + ptrVal.Elem().SetUint(n) + case reflect.Float32, reflect.Float64: + var buf strings.Builder + scanNumber(r, &buf) + n, err := strconv.ParseFloat(buf.String(), kind2bits[kind]) + if err != nil { + panic(decodeError{err}) + } + ptrVal.Elem().SetFloat(n) + case reflect.String: + var buf strings.Builder + if typ == numberType { + scanNumber(r, &buf) + ptrVal.Elem().SetString(buf.String()) + } else { + decodeString(r, &buf) + ptrVal.Elem().SetString(buf.String()) + } + case reflect.Interface: + if typ == anyType { + ptrVal.Elem().Set(reflect.ValueOf(decodeAny(r))) + } else { + panic(decodeError{&json.UnsupportedTypeError{ + Type: typ, + }}) + } + case reflect.Struct: + index := indexStruct(typ) + var nameBuf strings.Builder + decodeObject(r, &nameBuf, func() { + name := nameBuf.String() + idx, ok := index.byName[name] + if !ok { + scan(r, io.Discard) + return + } + field := index.byPos[idx] + fVal := ptrVal.Elem() + for _, idx := range field.Path { + if fVal.Kind() == reflect.Pointer { + if fVal.IsNil() { + if !fVal.CanSet() { // https://golang.org/issue/21357 + panic(decodeError{fmt.Errorf("cannot set embedded pointer to unexported type %v", fVal.Type().Elem())}) + } + fVal.Set(reflect.New(fVal.Type().Elem())) + } + fVal = fVal.Elem() + } + fVal = fVal.Field(idx) + } + decode(r, fVal.Addr()) + }) + case reflect.Map: + if ptrVal.Elem().IsNil() { + ptrVal.Elem().Set(reflect.MakeMap(typ)) + } + var nameBuf bytes.Buffer + decodeObject(r, &nameBuf, func() { + nameValTyp := typ.Key() + nameValPtr := reflect.New(nameValTyp) + switch { + case reflect.PointerTo(nameValTyp).Implements(textUnmarshalerType): + obj := ptrVal.Interface().(encoding.TextUnmarshaler) + if err := obj.UnmarshalText(nameBuf.Bytes()); err != nil { + panic(decodeError{err}) + } + default: + switch nameValTyp.Kind() { + case reflect.String: + nameValPtr.Elem().SetString(nameBuf.String()) + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + n, err := strconv.ParseInt(nameBuf.String(), 10, kind2bits[nameValTyp.Kind()]) + if err != nil { + panic(decodeError{err}) + } + nameValPtr.Elem().SetInt(n) + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: + n, err := strconv.ParseUint(nameBuf.String(), 10, kind2bits[nameValTyp.Kind()]) + if err != nil { + panic(decodeError{err}) + } + nameValPtr.Elem().SetUint(n) + default: + panic(decodeError{fmt.Errorf("invalid map key type: %v", nameValTyp)}) + } + } + + fValPtr := reflect.New(typ.Elem()) + decode(r, fValPtr) + + ptrVal.Elem().SetMapIndex(nameValPtr.Elem(), fValPtr.Elem()) + }) + case reflect.Slice: + if ptrVal.Elem().IsNil() { + ptrVal.Elem().Set(reflect.MakeSlice(typ.Elem(), 0, 0)) + } + decodeArray(r, func() { + mValPtr := reflect.New(typ.Elem()) + decode(r, mValPtr) + ptrVal.Set(reflect.Append(ptrVal.Elem(), mValPtr.Elem())) + }) + case reflect.Pointer: + val := reflect.New(typ.Elem()) + decode(r, val) + ptrVal.Elem().Set(val) + default: + panic(decodeError{&json.UnsupportedTypeError{ + Type: typ, + }}) + } + } +} + +func decodeWS(r io.RuneScanner) { + for { + switch readRune(r) { + // NB: The JSON definition of whitespace is more + // narrow than unicode.IsSpace + case 0x0020, 0x000A, 0x000D, 0x0009: + // do nothing + default: + unreadRune(r) + return + } + } +} + +func scan(r io.RuneScanner, out io.Writer) { + scanner := &ReEncoder{ + Out: out, + Compact: true, + } + if _, err := scanner.WriteRune(readRune(r)); err != nil { + panic(decodeError{err}) + } + scanner.bailAfterCurrent = true + var err error + for err == nil { + c, ok := readRuneOrEOF(r) + if ok { + _, err = scanner.WriteRune(c) + } else { + err = scanner.Flush() + break + } + } + if err != nil { + if err == errBailedAfterCurrent { + unreadRune(r) + } else { + panic(decodeError{err}) + } + } +} + +func scanNumber(r io.RuneScanner, out io.Writer) { + c := peekRune(r) + switch c { + case '-', '0', '1', '2', '3', '4', '5', '6', '7', '8', '9': + scan(r, out) + default: + panic(decodeError{fmt.Errorf("expected a nubmer bug got %c", c)}) + } +} + +func decodeAny(r io.RuneScanner) any { + c := peekRune(r) + switch c { + case '{': + ret := make(map[string]any) + var nameBuf strings.Builder + decodeObject(r, &nameBuf, func() { + ret[nameBuf.String()] = decodeAny(r) + }) + return ret + case '[': + ret := []any{} + decodeArray(r, func() { + ret = append(ret, decodeAny(r)) + }) + return ret + case '"': + var buf strings.Builder + decodeString(r, &buf) + return buf.String() + case '-', '0', '1', '2', '3', '4', '5', '6', '7', '8', '9': + var buf strings.Builder + scanNumber(r, &buf) + return json.Number(buf.String()) + case 't', 'f': + return decodeBool(r) + case 'n': + expectRune(r, 'n') + expectRune(r, 'u') + expectRune(r, 'l') + expectRune(r, 'l') + return nil + default: + panic(decodeError{fmt.Errorf("unexpected character: %c", c)}) + } +} + +func decodeObject[bufT runeBuffer](r io.RuneScanner, nameBuf bufT, decodeKVal func()) { + expectRune(r, '{') + decodeWS(r) + c := readRune(r) + switch c { + case '"': + decodeMember: + unreadRune(r) + nameBuf.Reset() + decodeString(r, nameBuf) + decodeWS(r) + expectRune(r, ':') + decodeWS(r) + decodeKVal() + decodeWS(r) + c := readRune(r) + switch c { + case ',': + decodeWS(r) + expectRune(r, '"') + goto decodeMember + case '}': + return + default: + panic(decodeError{fmt.Errorf("expected %c or %c but got %c", ',', '}', c)}) + } + case '}': + return + default: + panic(decodeError{fmt.Errorf("expected %c or %c but got %c", '"', '}', c)}) + } +} + +func decodeArray(r io.RuneScanner, decodeMember func()) { + expectRune(r, '[') + decodeWS(r) + c := readRune(r) + switch c { + case ']': + return + default: + decodeNextMember: + unreadRune(r) + decodeMember() + decodeWS(r) + c := readRune(r) + switch c { + case ',': + decodeWS(r) + goto decodeNextMember + case ']': + return + default: + panic(decodeError{fmt.Errorf("expected %c or %c but got %c", ',', ']', c)}) + } + } +} + +func decodeHex(r io.RuneReader) rune { + c := readRune(r) + switch { + case '0' <= c && c <= '9': + return c - '0' + case 'a' <= c && c <= 'f': + return c - 'a' + 10 + case 'A' <= c && c <= 'F': + return c - 'A' + 10 + default: + panic(decodeError{fmt.Errorf("unexpected %c in unicode literal", c)}) + } +} + +func decodeString[bufT runeBuffer](r io.RuneScanner, out bufT) { + // No need to check errors from out.WriteRune because 'out' + // guaranteed (by the 'runeBuffer' type constraint) to always + // either a *bytes.Buffer or a *string.Builder, neither of + // which return errors. + expectRune(r, '"') + for { + c := readRune(r) + switch { + case 0x0020 <= c && c <= 0x10FFFF && c != '"' && c != '\\': + _, _ = out.WriteRune(c) + case c == '\\': + c = readRune(r) + switch c { + case '"': + _, _ = out.WriteRune('"') + case '\\': + _, _ = out.WriteRune('\\') + case 'b': + _, _ = out.WriteRune('\b') + case 'f': + _, _ = out.WriteRune('\f') + case 'n': + _, _ = out.WriteRune('\n') + case 'r': + _, _ = out.WriteRune('\r') + case 't': + _, _ = out.WriteRune('\t') + case 'u': + c = decodeHex(r) + c = (c << 4) | decodeHex(r) + c = (c << 4) | decodeHex(r) + c = (c << 4) | decodeHex(r) + _, _ = out.WriteRune(c) + } + case c == '"': + return + default: + panic(decodeError{fmt.Errorf("unexpected %c in string", c)}) + } + } +} + +func decodeBool(r io.RuneReader) bool { + c := readRune(r) + switch c { + case 't': + expectRune(r, 'r') + expectRune(r, 'u') + expectRune(r, 'e') + return true + case 'f': + expectRune(r, 'a') + expectRune(r, 'l') + expectRune(r, 's') + expectRune(r, 'e') + return false + default: + panic(decodeError{fmt.Errorf("unexpected character: %c", c)}) + } +} diff --git a/lib/lowmemjson/encode.go b/lib/lowmemjson/encode.go index 107b4da..d22d86d 100644 --- a/lib/lowmemjson/encode.go +++ b/lib/lowmemjson/encode.go @@ -189,7 +189,7 @@ func encode(w io.Writer, val reflect.Value, quote bool) { case reflect.Struct: encodeWriteByte(w, '{') empty := true - for _, field := range indexStruct(val.Type()) { + for _, field := range indexStruct(val.Type()).byPos { fVal, err := val.FieldByIndexErr(field.Path) if err != nil { continue diff --git a/lib/lowmemjson/reencode.go b/lib/lowmemjson/reencode.go index 836c5f6..76aedc9 100644 --- a/lib/lowmemjson/reencode.go +++ b/lib/lowmemjson/reencode.go @@ -31,6 +31,8 @@ type ReEncoder struct { // If not set, then EscapeUnicodeDefault is used. UnicodeEscape func(rune, bool) bool + bailAfterCurrent bool + // state: .Write's utf8-decoding buffer buf [utf8.UTFMax]byte bufLen int @@ -160,8 +162,13 @@ func (enc *ReEncoder) popState() { enc.stack = enc.stack[:len(enc.stack)-1] } +var errBailedAfterCurrent = errors.New("bailed after current") + func (enc *ReEncoder) state(c rune) error { if len(enc.stack) == 0 { + if enc.bailAfterCurrent { + return errBailedAfterCurrent + } enc.pushState(enc.stateAny, false) } return enc.stack[len(enc.stack)-1](c) diff --git a/lib/lowmemjson/struct.go b/lib/lowmemjson/struct.go index 434d3dc..c27fb81 100644 --- a/lib/lowmemjson/struct.go +++ b/lib/lowmemjson/struct.go @@ -16,13 +16,20 @@ type structField struct { Quote bool } -func indexStruct(typ reflect.Type) []structField { +type structIndex struct { + byPos []structField + byName map[string]int +} + +func indexStruct(typ reflect.Type) structIndex { byName := make(map[string][]structField) var byPos []string indexStructInner(typ, nil, byName, &byPos) - var ret []structField + ret := structIndex{ + byName: make(map[string]int), + } for _, name := range byPos { fields := byName[name] @@ -31,7 +38,8 @@ func indexStruct(typ reflect.Type) []structField { case 0: // do nothing case 1: - ret = append(ret, fields[0]) + ret.byName[name] = len(ret.byPos) + ret.byPos = append(ret.byPos, fields[0]) default: // To quote the encoding/json docs (version 1.18.4): // @@ -77,10 +85,12 @@ func indexStruct(typ reflect.Type) []structField { case 0: // do nothing case 1: - ret = append(ret, fields[untaggedIdx]) + ret.byName[name] = len(ret.byPos) + ret.byPos = append(ret.byPos, fields[untaggedIdx]) } case 1: - ret = append(ret, fields[taggedIdx]) + ret.byName[name] = len(ret.byPos) + ret.byPos = append(ret.byPos, fields[taggedIdx]) } } } -- cgit v1.1-4-g5e80