summaryrefslogtreecommitdiff
path: root/lib/lowmemjson/decode.go
diff options
context:
space:
mode:
Diffstat (limited to 'lib/lowmemjson/decode.go')
-rw-r--r--lib/lowmemjson/decode.go221
1 files changed, 138 insertions, 83 deletions
diff --git a/lib/lowmemjson/decode.go b/lib/lowmemjson/decode.go
index 4873d43..b2eaacf 100644
--- a/lib/lowmemjson/decode.go
+++ b/lib/lowmemjson/decode.go
@@ -24,7 +24,7 @@ type decodeError struct {
}
type runeBuffer interface {
- *bytes.Buffer | *strings.Builder
+ io.Writer
WriteRune(rune) (int, error)
Reset()
}
@@ -93,7 +93,7 @@ func Decode(r io.RuneScanner, ptr any) (err error) {
}
}()
decodeWS(r)
- decode(r, ptrVal)
+ decode(r, ptrVal.Elem())
return nil
}
@@ -124,39 +124,39 @@ var kind2bits = map[reflect.Kind]int{
reflect.Float64: 64,
}
-func decode(r io.RuneScanner, ptrVal reflect.Value) {
+func decode(r io.RuneScanner, val reflect.Value) {
+ typ := val.Type()
switch {
- case ptrVal.Type() == rawMessagePtrType:
+ case val.CanAddr() && reflect.PointerTo(typ) == rawMessagePtrType:
var buf bytes.Buffer
scan(r, &buf)
- if err := ptrVal.Interface().(*json.RawMessage).UnmarshalJSON(buf.Bytes()); err != nil {
+ if err := val.Addr().Interface().(*json.RawMessage).UnmarshalJSON(buf.Bytes()); err != nil {
panic(decodeError{err})
}
- case ptrVal.Type().Implements(decoderType):
- obj := ptrVal.Interface().(Decoder)
+ case val.CanAddr() && reflect.PointerTo(typ).Implements(decoderType):
+ obj := val.Addr().Interface().(Decoder)
if err := obj.DecodeJSON(r); err != nil {
panic(decodeError{err})
}
- case ptrVal.Type().Implements(jsonUnmarshalerType):
+ case val.CanAddr() && reflect.PointerTo(typ).Implements(jsonUnmarshalerType):
var buf bytes.Buffer
scan(r, &buf)
- obj := ptrVal.Interface().(json.Unmarshaler)
+ obj := val.Addr().Interface().(json.Unmarshaler)
if err := obj.UnmarshalJSON(buf.Bytes()); err != nil {
panic(decodeError{err})
}
- case ptrVal.Type().Implements(textUnmarshalerType):
+ case val.CanAddr() && reflect.PointerTo(typ).Implements(textUnmarshalerType):
var buf bytes.Buffer
decodeString(r, &buf)
- obj := ptrVal.Interface().(encoding.TextUnmarshaler)
+ obj := val.Addr().Interface().(encoding.TextUnmarshaler)
if err := obj.UnmarshalText(buf.Bytes()); err != nil {
panic(decodeError{err})
}
default:
- typ := ptrVal.Type().Elem()
kind := typ.Kind()
switch kind {
case reflect.Bool:
- ptrVal.Elem().SetBool(decodeBool(r))
+ val.SetBool(decodeBool(r))
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
var buf strings.Builder
scanNumber(r, &buf)
@@ -164,7 +164,7 @@ func decode(r io.RuneScanner, ptrVal reflect.Value) {
if err != nil {
panic(decodeError{err})
}
- ptrVal.Elem().SetInt(n)
+ val.SetInt(n)
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
var buf strings.Builder
scanNumber(r, &buf)
@@ -172,7 +172,7 @@ func decode(r io.RuneScanner, ptrVal reflect.Value) {
if err != nil {
panic(decodeError{err})
}
- ptrVal.Elem().SetUint(n)
+ val.SetUint(n)
case reflect.Float32, reflect.Float64:
var buf strings.Builder
scanNumber(r, &buf)
@@ -180,23 +180,27 @@ func decode(r io.RuneScanner, ptrVal reflect.Value) {
if err != nil {
panic(decodeError{err})
}
- ptrVal.Elem().SetFloat(n)
+ val.SetFloat(n)
case reflect.String:
var buf strings.Builder
if typ == numberType {
scanNumber(r, &buf)
- ptrVal.Elem().SetString(buf.String())
+ val.SetString(buf.String())
} else {
decodeString(r, &buf)
- ptrVal.Elem().SetString(buf.String())
+ val.SetString(buf.String())
}
case reflect.Interface:
- if typ == anyType {
- ptrVal.Elem().Set(reflect.ValueOf(decodeAny(r)))
+ if val.IsNil() {
+ if typ == anyType {
+ val.Set(reflect.ValueOf(decodeAny(r)))
+ } else {
+ panic(decodeError{&json.UnsupportedTypeError{
+ Type: typ,
+ }})
+ }
} else {
- panic(decodeError{&json.UnsupportedTypeError{
- Type: typ,
- }})
+ decode(r, val.Elem())
}
case reflect.Struct:
index := indexStruct(typ)
@@ -209,7 +213,7 @@ func decode(r io.RuneScanner, ptrVal reflect.Value) {
return
}
field := index.byPos[idx]
- fVal := ptrVal.Elem()
+ fVal := val
for _, idx := range field.Path {
if fVal.Kind() == reflect.Pointer {
if fVal.IsNil() {
@@ -225,58 +229,91 @@ func decode(r io.RuneScanner, ptrVal reflect.Value) {
decode(r, fVal.Addr())
})
case reflect.Map:
- if ptrVal.Elem().IsNil() {
- ptrVal.Elem().Set(reflect.MakeMap(typ))
- }
- var nameBuf bytes.Buffer
- decodeObject(r, &nameBuf, func() {
- nameValTyp := typ.Key()
- nameValPtr := reflect.New(nameValTyp)
- switch {
- case reflect.PointerTo(nameValTyp).Implements(textUnmarshalerType):
- obj := ptrVal.Interface().(encoding.TextUnmarshaler)
- if err := obj.UnmarshalText(nameBuf.Bytes()); err != nil {
- panic(decodeError{err})
- }
- default:
- switch nameValTyp.Kind() {
- case reflect.String:
- nameValPtr.Elem().SetString(nameBuf.String())
- case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
- n, err := strconv.ParseInt(nameBuf.String(), 10, kind2bits[nameValTyp.Kind()])
- if err != nil {
- panic(decodeError{err})
- }
- nameValPtr.Elem().SetInt(n)
- case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
- n, err := strconv.ParseUint(nameBuf.String(), 10, kind2bits[nameValTyp.Kind()])
- if err != nil {
+ switch peekRune(r) {
+ case 'n':
+ decodeNull(r)
+ val.Set(reflect.Zero(typ))
+ case '{':
+ if val.IsNil() {
+ val.Set(reflect.MakeMap(typ))
+ }
+ var nameBuf bytes.Buffer
+ decodeObject(r, &nameBuf, func() {
+ nameValTyp := typ.Key()
+ nameValPtr := reflect.New(nameValTyp)
+ switch {
+ case reflect.PointerTo(nameValTyp).Implements(textUnmarshalerType):
+ obj := nameValPtr.Interface().(encoding.TextUnmarshaler)
+ if err := obj.UnmarshalText(nameBuf.Bytes()); err != nil {
panic(decodeError{err})
}
- nameValPtr.Elem().SetUint(n)
default:
- panic(decodeError{fmt.Errorf("invalid map key type: %v", nameValTyp)})
+ switch nameValTyp.Kind() {
+ case reflect.String:
+ nameValPtr.Elem().SetString(nameBuf.String())
+ case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
+ n, err := strconv.ParseInt(nameBuf.String(), 10, kind2bits[nameValTyp.Kind()])
+ if err != nil {
+ panic(decodeError{err})
+ }
+ nameValPtr.Elem().SetInt(n)
+ case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
+ n, err := strconv.ParseUint(nameBuf.String(), 10, kind2bits[nameValTyp.Kind()])
+ if err != nil {
+ panic(decodeError{err})
+ }
+ nameValPtr.Elem().SetUint(n)
+ default:
+ panic(decodeError{fmt.Errorf("invalid map key type: %v", nameValTyp)})
+ }
}
- }
- fValPtr := reflect.New(typ.Elem())
- decode(r, fValPtr)
+ fValPtr := reflect.New(typ.Elem())
+ decode(r, fValPtr)
- ptrVal.Elem().SetMapIndex(nameValPtr.Elem(), fValPtr.Elem())
- })
+ val.SetMapIndex(nameValPtr.Elem(), fValPtr.Elem())
+ })
+ default:
+ panic(decodeError{fmt.Errorf("invalid character %q for map value", peekRune(r))})
+ }
case reflect.Slice:
- if ptrVal.Elem().IsNil() {
- ptrVal.Elem().Set(reflect.MakeSlice(typ.Elem(), 0, 0))
+ switch {
+ case typ.Elem().Kind() == reflect.Uint8:
+ var buf bytes.Buffer
+ dec := newBase64Decoder(&buf)
+ decodeString(r, dec)
+ val.Set(reflect.ValueOf(buf.Bytes()))
+ default:
+ switch peekRune(r) {
+ case 'n':
+ decodeNull(r)
+ val.Set(reflect.Zero(typ))
+ case '[':
+ if val.IsNil() {
+ val.Set(reflect.MakeSlice(typ, 0, 0))
+ }
+ decodeArray(r, func() {
+ mValPtr := reflect.New(typ.Elem())
+ decode(r, mValPtr)
+ val.Set(reflect.Append(val, mValPtr.Elem()))
+ })
+ default:
+ panic(decodeError{fmt.Errorf("invalid character %q for slice value", peekRune(r))})
+ }
}
+ case reflect.Array:
+ i := 0
decodeArray(r, func() {
mValPtr := reflect.New(typ.Elem())
decode(r, mValPtr)
- ptrVal.Set(reflect.Append(ptrVal.Elem(), mValPtr.Elem()))
+ val.Index(i).Set(mValPtr.Elem())
+ i++
})
case reflect.Pointer:
- val := reflect.New(typ.Elem())
- decode(r, val)
- ptrVal.Elem().Set(val)
+ if val.IsNil() {
+ val.Set(reflect.New(typ.Elem()))
+ }
+ decode(r, val.Elem())
default:
panic(decodeError{&json.UnsupportedTypeError{
Type: typ,
@@ -337,7 +374,7 @@ func scanNumber(r io.RuneScanner, out io.Writer) {
case '-', '0', '1', '2', '3', '4', '5', '6', '7', '8', '9':
scan(r, out)
default:
- panic(decodeError{fmt.Errorf("expected a nubmer bug got %c", c)})
+ panic(decodeError{fmt.Errorf("expected a nubmer but got %c", c)})
}
}
@@ -368,17 +405,14 @@ func decodeAny(r io.RuneScanner) any {
case 't', 'f':
return decodeBool(r)
case 'n':
- expectRune(r, 'n')
- expectRune(r, 'u')
- expectRune(r, 'l')
- expectRune(r, 'l')
+ decodeNull(r)
return nil
default:
panic(decodeError{fmt.Errorf("unexpected character: %c", c)})
}
}
-func decodeObject[bufT runeBuffer](r io.RuneScanner, nameBuf bufT, decodeKVal func()) {
+func decodeObject(r io.RuneScanner, nameBuf runeBuffer, decodeKVal func()) {
expectRune(r, '{')
decodeWS(r)
c := readRune(r)
@@ -450,40 +484,54 @@ func decodeHex(r io.RuneReader) rune {
}
}
-func decodeString[bufT runeBuffer](r io.RuneScanner, out bufT) {
- // No need to check errors from out.WriteRune because 'out'
- // guaranteed (by the 'runeBuffer' type constraint) to always
- // either a *bytes.Buffer or a *string.Builder, neither of
- // which return errors.
+func decodeString(r io.RuneScanner, out io.Writer) {
expectRune(r, '"')
for {
c := readRune(r)
switch {
case 0x0020 <= c && c <= 0x10FFFF && c != '"' && c != '\\':
- _, _ = out.WriteRune(c)
+ if _, err := writeRune(out, c); err != nil {
+ panic(decodeError{err})
+ }
case c == '\\':
c = readRune(r)
switch c {
case '"':
- _, _ = out.WriteRune('"')
+ if _, err := writeRune(out, '"'); err != nil {
+ panic(decodeError{err})
+ }
case '\\':
- _, _ = out.WriteRune('\\')
+ if _, err := writeRune(out, '\\'); err != nil {
+ panic(decodeError{err})
+ }
case 'b':
- _, _ = out.WriteRune('\b')
+ if _, err := writeRune(out, '\b'); err != nil {
+ panic(decodeError{err})
+ }
case 'f':
- _, _ = out.WriteRune('\f')
+ if _, err := writeRune(out, '\f'); err != nil {
+ panic(decodeError{err})
+ }
case 'n':
- _, _ = out.WriteRune('\n')
+ if _, err := writeRune(out, '\n'); err != nil {
+ panic(decodeError{err})
+ }
case 'r':
- _, _ = out.WriteRune('\r')
+ if _, err := writeRune(out, '\r'); err != nil {
+ panic(decodeError{err})
+ }
case 't':
- _, _ = out.WriteRune('\t')
+ if _, err := writeRune(out, '\t'); err != nil {
+ panic(decodeError{err})
+ }
case 'u':
c = decodeHex(r)
c = (c << 4) | decodeHex(r)
c = (c << 4) | decodeHex(r)
c = (c << 4) | decodeHex(r)
- _, _ = out.WriteRune(c)
+ if _, err := writeRune(out, c); err != nil {
+ panic(decodeError{err})
+ }
}
case c == '"':
return
@@ -511,3 +559,10 @@ func decodeBool(r io.RuneReader) bool {
panic(decodeError{fmt.Errorf("unexpected character: %c", c)})
}
}
+
+func decodeNull(r io.RuneReader) {
+ expectRune(r, 'n')
+ expectRune(r, 'u')
+ expectRune(r, 'l')
+ expectRune(r, 'l')
+}