summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorLuke Shumaker <lukeshu@lukeshu.com>2022-07-30 15:11:13 -0600
committerLuke Shumaker <lukeshu@lukeshu.com>2022-07-31 14:49:04 -0600
commitab1da5feecf7f05233187424effa10637247c218 (patch)
tree20c0518d125035a459aecced71c6bae74fe65661
parentd163b167496d914a00355a3f44897d0a43af96df (diff)
wip decode
-rw-r--r--lib/lowmemjson/decode.go501
-rw-r--r--lib/lowmemjson/encode.go2
-rw-r--r--lib/lowmemjson/reencode.go7
-rw-r--r--lib/lowmemjson/struct.go20
4 files changed, 524 insertions, 6 deletions
diff --git a/lib/lowmemjson/decode.go b/lib/lowmemjson/decode.go
new file mode 100644
index 0000000..33b4222
--- /dev/null
+++ b/lib/lowmemjson/decode.go
@@ -0,0 +1,501 @@
+// Copyright (C) 2022 Luke Shumaker <lukeshu@lukeshu.com>
+//
+// SPDX-License-Identifier: GPL-2.0-or-later
+
+package lowmemjson
+
+import (
+ "bytes"
+ "encoding"
+ "encoding/json"
+ "fmt"
+ "io"
+ "reflect"
+ "strconv"
+ "strings"
+)
+
+type Decoder interface {
+ DecodeJSON(io.RuneScanner) error
+}
+
+type decodeError struct {
+ Err error
+}
+
+type runeBuffer interface {
+ *bytes.Buffer | *strings.Builder
+ WriteRune(rune) (int, error)
+ Reset()
+}
+
+func readRune(r io.RuneReader) rune {
+ c, _, err := r.ReadRune()
+ if err != nil {
+ panic(decodeError{err})
+ }
+ return c
+}
+
+func readRuneOrEOF(r io.RuneReader) (c rune, ok bool) {
+ c, _, err := r.ReadRune()
+ if err != nil {
+ if err == io.EOF {
+ return 0, false
+ }
+ panic(decodeError{err})
+ }
+ return c, true
+}
+
+func unreadRune(r io.RuneScanner) {
+ if err := r.UnreadRune(); err != nil {
+ panic(decodeError{err})
+ }
+}
+
+func peekRune(r io.RuneScanner) rune {
+ c := readRune(r)
+ unreadRune(r)
+ return c
+}
+
+func expectRune(r io.RuneReader, exp rune) {
+ act := readRune(r)
+ if act != exp {
+ panic(decodeError{fmt.Errorf("expected %c but got %c", exp, act)})
+ }
+}
+
+func Decode(r io.RuneScanner, ptr any) (err error) {
+ ptrVal := reflect.ValueOf(ptr)
+ if ptrVal.Kind() != reflect.Pointer || ptrVal.IsNil() || !ptrVal.Elem().CanSet() {
+ return &json.InvalidUnmarshalError{
+ Type: ptrVal.Type(),
+ }
+ }
+
+ defer func() {
+ if r := recover(); r != nil {
+ if e, ok := r.(decodeError); ok {
+ err = e.Err
+ } else {
+ panic(r)
+ }
+ }
+ }()
+ decodeWS(r)
+ decode(r, ptrVal)
+ return nil
+}
+
+var (
+ rawMessagePtrType = reflect.TypeOf((*json.RawMessage)(nil))
+ anyType = reflect.TypeOf((*any)(nil)).Elem()
+ decoderType = reflect.TypeOf((*Decoder)(nil)).Elem()
+ jsonUnmarshalerType = reflect.TypeOf((*json.Unmarshaler)(nil)).Elem()
+ textUnmarshalerType = reflect.TypeOf((*encoding.TextUnmarshaler)(nil)).Elem()
+)
+
+var kind2bits = map[reflect.Kind]int{
+ reflect.Int: int(32 << (^uint(0) >> 63)),
+ reflect.Int8: 8,
+ reflect.Int16: 16,
+ reflect.Int32: 32,
+ reflect.Int64: 64,
+
+ reflect.Uint: int(32 << (^uint(0) >> 63)),
+ reflect.Uint8: 8,
+ reflect.Uint16: 16,
+ reflect.Uint32: 32,
+ reflect.Uint64: 64,
+
+ reflect.Uintptr: int(32 << (^uintptr(0) >> 63)),
+
+ reflect.Float32: 32,
+ reflect.Float64: 64,
+}
+
+func decode(r io.RuneScanner, ptrVal reflect.Value) {
+ switch {
+ case ptrVal.Type() == rawMessagePtrType:
+ var buf bytes.Buffer
+ scan(r, &buf)
+ if err := ptrVal.Interface().(*json.RawMessage).UnmarshalJSON(buf.Bytes()); err != nil {
+ panic(decodeError{err})
+ }
+ case ptrVal.Type().Implements(decoderType):
+ obj := ptrVal.Interface().(Decoder)
+ if err := obj.DecodeJSON(r); err != nil {
+ panic(decodeError{err})
+ }
+ case ptrVal.Type().Implements(jsonUnmarshalerType):
+ var buf bytes.Buffer
+ scan(r, &buf)
+ obj := ptrVal.Interface().(json.Unmarshaler)
+ if err := obj.UnmarshalJSON(buf.Bytes()); err != nil {
+ panic(decodeError{err})
+ }
+ case ptrVal.Type().Implements(textUnmarshalerType):
+ var buf bytes.Buffer
+ decodeString(r, &buf)
+ obj := ptrVal.Interface().(encoding.TextUnmarshaler)
+ if err := obj.UnmarshalText(buf.Bytes()); err != nil {
+ panic(decodeError{err})
+ }
+ default:
+ typ := ptrVal.Type().Elem()
+ kind := typ.Kind()
+ switch kind {
+ case reflect.Bool:
+ ptrVal.Elem().SetBool(decodeBool(r))
+ case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
+ var buf strings.Builder
+ scanNumber(r, &buf)
+ n, err := strconv.ParseInt(buf.String(), 10, kind2bits[kind])
+ if err != nil {
+ panic(decodeError{err})
+ }
+ ptrVal.Elem().SetInt(n)
+ case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
+ var buf strings.Builder
+ scanNumber(r, &buf)
+ n, err := strconv.ParseUint(buf.String(), 10, kind2bits[kind])
+ if err != nil {
+ panic(decodeError{err})
+ }
+ ptrVal.Elem().SetUint(n)
+ case reflect.Float32, reflect.Float64:
+ var buf strings.Builder
+ scanNumber(r, &buf)
+ n, err := strconv.ParseFloat(buf.String(), kind2bits[kind])
+ if err != nil {
+ panic(decodeError{err})
+ }
+ ptrVal.Elem().SetFloat(n)
+ case reflect.String:
+ var buf strings.Builder
+ if typ == numberType {
+ scanNumber(r, &buf)
+ ptrVal.Elem().SetString(buf.String())
+ } else {
+ decodeString(r, &buf)
+ ptrVal.Elem().SetString(buf.String())
+ }
+ case reflect.Interface:
+ if typ == anyType {
+ ptrVal.Elem().Set(reflect.ValueOf(decodeAny(r)))
+ } else {
+ panic(decodeError{&json.UnsupportedTypeError{
+ Type: typ,
+ }})
+ }
+ case reflect.Struct:
+ index := indexStruct(typ)
+ var nameBuf strings.Builder
+ decodeObject(r, &nameBuf, func() {
+ name := nameBuf.String()
+ idx, ok := index.byName[name]
+ if !ok {
+ scan(r, io.Discard)
+ return
+ }
+ field := index.byPos[idx]
+ fVal := ptrVal.Elem()
+ for _, idx := range field.Path {
+ if fVal.Kind() == reflect.Pointer {
+ if fVal.IsNil() {
+ if !fVal.CanSet() { // https://golang.org/issue/21357
+ panic(decodeError{fmt.Errorf("cannot set embedded pointer to unexported type %v", fVal.Type().Elem())})
+ }
+ fVal.Set(reflect.New(fVal.Type().Elem()))
+ }
+ fVal = fVal.Elem()
+ }
+ fVal = fVal.Field(idx)
+ }
+ decode(r, fVal.Addr())
+ })
+ case reflect.Map:
+ if ptrVal.Elem().IsNil() {
+ ptrVal.Elem().Set(reflect.MakeMap(typ))
+ }
+ var nameBuf bytes.Buffer
+ decodeObject(r, &nameBuf, func() {
+ nameValTyp := typ.Key()
+ nameValPtr := reflect.New(nameValTyp)
+ switch {
+ case reflect.PointerTo(nameValTyp).Implements(textUnmarshalerType):
+ obj := ptrVal.Interface().(encoding.TextUnmarshaler)
+ if err := obj.UnmarshalText(nameBuf.Bytes()); err != nil {
+ panic(decodeError{err})
+ }
+ default:
+ switch nameValTyp.Kind() {
+ case reflect.String:
+ nameValPtr.Elem().SetString(nameBuf.String())
+ case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
+ n, err := strconv.ParseInt(nameBuf.String(), 10, kind2bits[nameValTyp.Kind()])
+ if err != nil {
+ panic(decodeError{err})
+ }
+ nameValPtr.Elem().SetInt(n)
+ case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
+ n, err := strconv.ParseUint(nameBuf.String(), 10, kind2bits[nameValTyp.Kind()])
+ if err != nil {
+ panic(decodeError{err})
+ }
+ nameValPtr.Elem().SetUint(n)
+ default:
+ panic(decodeError{fmt.Errorf("invalid map key type: %v", nameValTyp)})
+ }
+ }
+
+ fValPtr := reflect.New(typ.Elem())
+ decode(r, fValPtr)
+
+ ptrVal.Elem().SetMapIndex(nameValPtr.Elem(), fValPtr.Elem())
+ })
+ case reflect.Slice:
+ if ptrVal.Elem().IsNil() {
+ ptrVal.Elem().Set(reflect.MakeSlice(typ.Elem(), 0, 0))
+ }
+ decodeArray(r, func() {
+ mValPtr := reflect.New(typ.Elem())
+ decode(r, mValPtr)
+ ptrVal.Set(reflect.Append(ptrVal.Elem(), mValPtr.Elem()))
+ })
+ case reflect.Pointer:
+ val := reflect.New(typ.Elem())
+ decode(r, val)
+ ptrVal.Elem().Set(val)
+ default:
+ panic(decodeError{&json.UnsupportedTypeError{
+ Type: typ,
+ }})
+ }
+ }
+}
+
+func decodeWS(r io.RuneScanner) {
+ for {
+ switch readRune(r) {
+ // NB: The JSON definition of whitespace is more
+ // narrow than unicode.IsSpace
+ case 0x0020, 0x000A, 0x000D, 0x0009:
+ // do nothing
+ default:
+ unreadRune(r)
+ return
+ }
+ }
+}
+
+func scan(r io.RuneScanner, out io.Writer) {
+ scanner := &ReEncoder{
+ Out: out,
+ Compact: true,
+ }
+ if _, err := scanner.WriteRune(readRune(r)); err != nil {
+ panic(decodeError{err})
+ }
+ scanner.bailAfterCurrent = true
+ var err error
+ for err == nil {
+ c, ok := readRuneOrEOF(r)
+ if ok {
+ _, err = scanner.WriteRune(c)
+ } else {
+ err = scanner.Flush()
+ break
+ }
+ }
+ if err != nil {
+ if err == errBailedAfterCurrent {
+ unreadRune(r)
+ } else {
+ panic(decodeError{err})
+ }
+ }
+}
+
+func scanNumber(r io.RuneScanner, out io.Writer) {
+ c := peekRune(r)
+ switch c {
+ case '-', '0', '1', '2', '3', '4', '5', '6', '7', '8', '9':
+ scan(r, out)
+ default:
+ panic(decodeError{fmt.Errorf("expected a nubmer bug got %c", c)})
+ }
+}
+
+func decodeAny(r io.RuneScanner) any {
+ c := peekRune(r)
+ switch c {
+ case '{':
+ ret := make(map[string]any)
+ var nameBuf strings.Builder
+ decodeObject(r, &nameBuf, func() {
+ ret[nameBuf.String()] = decodeAny(r)
+ })
+ return ret
+ case '[':
+ ret := []any{}
+ decodeArray(r, func() {
+ ret = append(ret, decodeAny(r))
+ })
+ return ret
+ case '"':
+ var buf strings.Builder
+ decodeString(r, &buf)
+ return buf.String()
+ case '-', '0', '1', '2', '3', '4', '5', '6', '7', '8', '9':
+ var buf strings.Builder
+ scanNumber(r, &buf)
+ return json.Number(buf.String())
+ case 't', 'f':
+ return decodeBool(r)
+ case 'n':
+ expectRune(r, 'n')
+ expectRune(r, 'u')
+ expectRune(r, 'l')
+ expectRune(r, 'l')
+ return nil
+ default:
+ panic(decodeError{fmt.Errorf("unexpected character: %c", c)})
+ }
+}
+
+func decodeObject[bufT runeBuffer](r io.RuneScanner, nameBuf bufT, decodeKVal func()) {
+ expectRune(r, '{')
+ decodeWS(r)
+ c := readRune(r)
+ switch c {
+ case '"':
+ decodeMember:
+ unreadRune(r)
+ nameBuf.Reset()
+ decodeString(r, nameBuf)
+ decodeWS(r)
+ expectRune(r, ':')
+ decodeWS(r)
+ decodeKVal()
+ decodeWS(r)
+ c := readRune(r)
+ switch c {
+ case ',':
+ decodeWS(r)
+ expectRune(r, '"')
+ goto decodeMember
+ case '}':
+ return
+ default:
+ panic(decodeError{fmt.Errorf("expected %c or %c but got %c", ',', '}', c)})
+ }
+ case '}':
+ return
+ default:
+ panic(decodeError{fmt.Errorf("expected %c or %c but got %c", '"', '}', c)})
+ }
+}
+
+func decodeArray(r io.RuneScanner, decodeMember func()) {
+ expectRune(r, '[')
+ decodeWS(r)
+ c := readRune(r)
+ switch c {
+ case ']':
+ return
+ default:
+ decodeNextMember:
+ unreadRune(r)
+ decodeMember()
+ decodeWS(r)
+ c := readRune(r)
+ switch c {
+ case ',':
+ decodeWS(r)
+ goto decodeNextMember
+ case ']':
+ return
+ default:
+ panic(decodeError{fmt.Errorf("expected %c or %c but got %c", ',', ']', c)})
+ }
+ }
+}
+
+func decodeHex(r io.RuneReader) rune {
+ c := readRune(r)
+ switch {
+ case '0' <= c && c <= '9':
+ return c - '0'
+ case 'a' <= c && c <= 'f':
+ return c - 'a' + 10
+ case 'A' <= c && c <= 'F':
+ return c - 'A' + 10
+ default:
+ panic(decodeError{fmt.Errorf("unexpected %c in unicode literal", c)})
+ }
+}
+
+func decodeString[bufT runeBuffer](r io.RuneScanner, out bufT) {
+ // No need to check errors from out.WriteRune because 'out'
+ // guaranteed (by the 'runeBuffer' type constraint) to always
+ // either a *bytes.Buffer or a *string.Builder, neither of
+ // which return errors.
+ expectRune(r, '"')
+ for {
+ c := readRune(r)
+ switch {
+ case 0x0020 <= c && c <= 0x10FFFF && c != '"' && c != '\\':
+ _, _ = out.WriteRune(c)
+ case c == '\\':
+ c = readRune(r)
+ switch c {
+ case '"':
+ _, _ = out.WriteRune('"')
+ case '\\':
+ _, _ = out.WriteRune('\\')
+ case 'b':
+ _, _ = out.WriteRune('\b')
+ case 'f':
+ _, _ = out.WriteRune('\f')
+ case 'n':
+ _, _ = out.WriteRune('\n')
+ case 'r':
+ _, _ = out.WriteRune('\r')
+ case 't':
+ _, _ = out.WriteRune('\t')
+ case 'u':
+ c = decodeHex(r)
+ c = (c << 4) | decodeHex(r)
+ c = (c << 4) | decodeHex(r)
+ c = (c << 4) | decodeHex(r)
+ _, _ = out.WriteRune(c)
+ }
+ case c == '"':
+ return
+ default:
+ panic(decodeError{fmt.Errorf("unexpected %c in string", c)})
+ }
+ }
+}
+
+func decodeBool(r io.RuneReader) bool {
+ c := readRune(r)
+ switch c {
+ case 't':
+ expectRune(r, 'r')
+ expectRune(r, 'u')
+ expectRune(r, 'e')
+ return true
+ case 'f':
+ expectRune(r, 'a')
+ expectRune(r, 'l')
+ expectRune(r, 's')
+ expectRune(r, 'e')
+ return false
+ default:
+ panic(decodeError{fmt.Errorf("unexpected character: %c", c)})
+ }
+}
diff --git a/lib/lowmemjson/encode.go b/lib/lowmemjson/encode.go
index 107b4da..d22d86d 100644
--- a/lib/lowmemjson/encode.go
+++ b/lib/lowmemjson/encode.go
@@ -189,7 +189,7 @@ func encode(w io.Writer, val reflect.Value, quote bool) {
case reflect.Struct:
encodeWriteByte(w, '{')
empty := true
- for _, field := range indexStruct(val.Type()) {
+ for _, field := range indexStruct(val.Type()).byPos {
fVal, err := val.FieldByIndexErr(field.Path)
if err != nil {
continue
diff --git a/lib/lowmemjson/reencode.go b/lib/lowmemjson/reencode.go
index 836c5f6..76aedc9 100644
--- a/lib/lowmemjson/reencode.go
+++ b/lib/lowmemjson/reencode.go
@@ -31,6 +31,8 @@ type ReEncoder struct {
// If not set, then EscapeUnicodeDefault is used.
UnicodeEscape func(rune, bool) bool
+ bailAfterCurrent bool
+
// state: .Write's utf8-decoding buffer
buf [utf8.UTFMax]byte
bufLen int
@@ -160,8 +162,13 @@ func (enc *ReEncoder) popState() {
enc.stack = enc.stack[:len(enc.stack)-1]
}
+var errBailedAfterCurrent = errors.New("bailed after current")
+
func (enc *ReEncoder) state(c rune) error {
if len(enc.stack) == 0 {
+ if enc.bailAfterCurrent {
+ return errBailedAfterCurrent
+ }
enc.pushState(enc.stateAny, false)
}
return enc.stack[len(enc.stack)-1](c)
diff --git a/lib/lowmemjson/struct.go b/lib/lowmemjson/struct.go
index 434d3dc..c27fb81 100644
--- a/lib/lowmemjson/struct.go
+++ b/lib/lowmemjson/struct.go
@@ -16,13 +16,20 @@ type structField struct {
Quote bool
}
-func indexStruct(typ reflect.Type) []structField {
+type structIndex struct {
+ byPos []structField
+ byName map[string]int
+}
+
+func indexStruct(typ reflect.Type) structIndex {
byName := make(map[string][]structField)
var byPos []string
indexStructInner(typ, nil, byName, &byPos)
- var ret []structField
+ ret := structIndex{
+ byName: make(map[string]int),
+ }
for _, name := range byPos {
fields := byName[name]
@@ -31,7 +38,8 @@ func indexStruct(typ reflect.Type) []structField {
case 0:
// do nothing
case 1:
- ret = append(ret, fields[0])
+ ret.byName[name] = len(ret.byPos)
+ ret.byPos = append(ret.byPos, fields[0])
default:
// To quote the encoding/json docs (version 1.18.4):
//
@@ -77,10 +85,12 @@ func indexStruct(typ reflect.Type) []structField {
case 0:
// do nothing
case 1:
- ret = append(ret, fields[untaggedIdx])
+ ret.byName[name] = len(ret.byPos)
+ ret.byPos = append(ret.byPos, fields[untaggedIdx])
}
case 1:
- ret = append(ret, fields[taggedIdx])
+ ret.byName[name] = len(ret.byPos)
+ ret.byPos = append(ret.byPos, fields[taggedIdx])
}
}
}