summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorLuke Shumaker <lukeshu@lukeshu.com>2022-08-02 01:13:45 -0600
committerLuke Shumaker <lukeshu@lukeshu.com>2022-08-02 01:13:45 -0600
commit20e2cf0c4e0ba704455ca6e163bbab9ddde05c80 (patch)
treee26be8b735e8b4ab649257f23df730f4e9659524
parent06d7be736ac93cd548db84d7f898d6ef0b257c1f (diff)
wip
-rw-r--r--lib/lowmemjson/borrowed_decode_test.go10
-rw-r--r--lib/lowmemjson/decode.go112
-rw-r--r--lib/lowmemjson/encode.go17
-rw-r--r--lib/lowmemjson/misc.go8
4 files changed, 112 insertions, 35 deletions
diff --git a/lib/lowmemjson/borrowed_decode_test.go b/lib/lowmemjson/borrowed_decode_test.go
index 804fb87..b555f87 100644
--- a/lib/lowmemjson/borrowed_decode_test.go
+++ b/lib/lowmemjson/borrowed_decode_test.go
@@ -1959,6 +1959,9 @@ func TestByteKind(t *testing.T) {
if err != nil {
t.Error(err)
}
+ if !reflect.DeepEqual(data, []byte(`"aGVsbG8="`)) { // MODIFIED
+ t.Errorf("expected %q == %q", data, `"aGVsbG8="`) // MODIFIED
+ } // MODIFIED
var b byteKind
err = Unmarshal(data, &b)
if err != nil {
@@ -1980,6 +1983,9 @@ func TestSliceOfCustomByte(t *testing.T) {
if err != nil {
t.Fatal(err)
}
+ if !reflect.DeepEqual(data, []byte(`"aGVsbG8="`)) { // MODIFIED
+ t.Errorf("expected %q == %q", data, `"aGVsbG8="`) // MODIFIED
+ } // MODIFIED
var b []Uint8
err = Unmarshal(data, &b)
if err != nil {
@@ -2005,7 +2011,7 @@ var decodeTypeErrorTests = []struct {
func TestUnmarshalTypeError(t *testing.T) {
for _, item := range decodeTypeErrorTests {
err := Unmarshal([]byte(item.src), item.dest)
- if _, ok := err.(*UnmarshalTypeError); !ok {
+ if err == nil { // if _, ok := err.(*UnmarshalTypeError); !ok { // MODIFIED
t.Errorf("expected type error for Unmarshal(%q, type %T): got %T",
item.src, item.dest, err)
}
@@ -2027,7 +2033,7 @@ func TestUnmarshalSyntax(t *testing.T) {
var x any
for _, src := range unmarshalSyntaxTests {
err := Unmarshal([]byte(src), &x)
- if _, ok := err.(*SyntaxError); !ok {
+ if err == nil { // _, ok := err.(*SyntaxError); !ok { // MODIFIED
t.Errorf("expected syntax error for Unmarshal(%q): got %T", src, err)
}
}
diff --git a/lib/lowmemjson/decode.go b/lib/lowmemjson/decode.go
index 2dee59c..f9ea8a2 100644
--- a/lib/lowmemjson/decode.go
+++ b/lib/lowmemjson/decode.go
@@ -78,22 +78,21 @@ func (dec *Decoder) stackPop() {
dec.stack = dec.stack[:len(dec.stack)-1]
}
-type decodeError struct{}
+type decodeError struct {
+ Err error
+}
func (dec *Decoder) panicIO(err error) {
- dec.err = fmt.Errorf("json: I/O error at input byte %v: %s: %w",
- dec.nxtPos, dec.stackStr(), err)
- panic(decodeError{})
+ panic(decodeError{fmt.Errorf("json: I/O error at input byte %v: %s: %w",
+ dec.nxtPos, dec.stackStr(), err)})
}
func (dec *Decoder) panicSyntax(err error) {
- dec.err = fmt.Errorf("json: syntax error at input byte %v: %s: %w",
- dec.curPos, dec.stackStr(), err)
- panic(decodeError{})
+ panic(decodeError{fmt.Errorf("json: syntax error at input byte %v: %s: %w",
+ dec.curPos, dec.stackStr(), err)})
}
func (dec *Decoder) panicType(typ reflect.Type, err error) {
- dec.err = fmt.Errorf("json: type mismatch error at input byte %v: %s: type %v: %w",
- dec.curPos, dec.stackStr(), typ, err)
- panic(decodeError{})
+ panic(decodeError{fmt.Errorf("json: type mismatch error at input byte %v: %s: type %v: %w",
+ dec.curPos, dec.stackStr(), typ, err)})
}
func Decode(r io.Reader, ptr any) error {
@@ -104,7 +103,8 @@ func (dec *Decoder) Decode(ptr any) (err error) {
ptrVal := reflect.ValueOf(ptr)
if ptrVal.Kind() != reflect.Pointer || ptrVal.IsNil() || !ptrVal.Elem().CanSet() {
return &json.InvalidUnmarshalError{
- Type: ptrVal.Type(),
+ // don't use ptrVal.Type() because ptrVal might be invalid if ptr==nil
+ Type: reflect.TypeOf(ptr),
}
}
@@ -114,7 +114,8 @@ func (dec *Decoder) Decode(ptr any) (err error) {
defer func() {
if r := recover(); r != nil {
- if _, ok := r.(decodeError); ok {
+ if de, ok := r.(decodeError); ok {
+ dec.err = de.Err
err = dec.err
} else {
panic(r)
@@ -122,7 +123,7 @@ func (dec *Decoder) Decode(ptr any) (err error) {
}
}()
dec.decodeWS()
- dec.decode(ptrVal.Elem())
+ dec.decode(ptrVal.Elem(), false)
return nil
}
@@ -226,7 +227,7 @@ var kind2bits = map[reflect.Kind]int{
reflect.Float64: 64,
}
-func (dec *Decoder) decode(val reflect.Value) {
+func (dec *Decoder) decode(val reflect.Value, nullOK bool) {
typ := val.Type()
switch {
case val.CanAddr() && reflect.PointerTo(typ) == rawMessagePtrType:
@@ -248,6 +249,10 @@ func (dec *Decoder) decode(val reflect.Value) {
dec.panicSyntax(err)
}
case val.CanAddr() && reflect.PointerTo(typ).Implements(textUnmarshalerType):
+ if nullOK && dec.peekRune() == 'n' {
+ dec.decodeNull()
+ return
+ }
var buf bytes.Buffer
dec.decodeString(&buf)
obj := val.Addr().Interface().(encoding.TextUnmarshaler)
@@ -258,8 +263,16 @@ func (dec *Decoder) decode(val reflect.Value) {
kind := typ.Kind()
switch kind {
case reflect.Bool:
+ if nullOK && dec.peekRune() == 'n' {
+ dec.decodeNull()
+ return
+ }
val.SetBool(dec.decodeBool())
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
+ if nullOK && dec.peekRune() == 'n' {
+ dec.decodeNull()
+ return
+ }
var buf strings.Builder
dec.scanNumber(&buf)
n, err := strconv.ParseInt(buf.String(), 10, kind2bits[kind])
@@ -268,6 +281,10 @@ func (dec *Decoder) decode(val reflect.Value) {
}
val.SetInt(n)
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
+ if nullOK && dec.peekRune() == 'n' {
+ dec.decodeNull()
+ return
+ }
var buf strings.Builder
dec.scanNumber(&buf)
n, err := strconv.ParseUint(buf.String(), 10, kind2bits[kind])
@@ -276,6 +293,10 @@ func (dec *Decoder) decode(val reflect.Value) {
}
val.SetUint(n)
case reflect.Float32, reflect.Float64:
+ if nullOK && dec.peekRune() == 'n' {
+ dec.decodeNull()
+ return
+ }
var buf strings.Builder
dec.scanNumber(&buf)
n, err := strconv.ParseFloat(buf.String(), kind2bits[kind])
@@ -284,6 +305,10 @@ func (dec *Decoder) decode(val reflect.Value) {
}
val.SetFloat(n)
case reflect.String:
+ if nullOK && dec.peekRune() == 'n' {
+ dec.decodeNull()
+ return
+ }
var buf strings.Builder
if typ == numberType {
dec.scanNumber(&buf)
@@ -301,19 +326,23 @@ func (dec *Decoder) decode(val reflect.Value) {
if !val.IsNil() && val.Elem().Kind() == reflect.Pointer && val.Elem().Elem().Kind() == reflect.Pointer {
// XXX: I can't justify this case, other than "it's what encoding/json does, but
// I don't understand their rationale".
- dec.decode(val.Elem())
+ dec.decode(val.Elem(), false)
} else {
dec.decodeNull()
val.Set(reflect.Zero(typ))
}
default:
if !val.IsNil() && val.Elem().Kind() == reflect.Pointer {
- dec.decode(val.Elem())
+ dec.decode(val.Elem(), false)
} else {
val.Set(reflect.ValueOf(dec.decodeAny()))
}
}
case reflect.Struct:
+ if nullOK && dec.peekRune() == 'n' {
+ dec.decodeNull()
+ return
+ }
index := indexStruct(typ)
var nameBuf strings.Builder
dec.decodeObject(&nameBuf, func() {
@@ -358,13 +387,13 @@ func (dec *Decoder) decode(val reflect.Value) {
subD := *dec // capture the .curPos *before* calling .decodeString
dec.decodeString(&buf)
subD.r = &buf
- subD.decode(fVal)
+ subD.decode(fVal, false)
default:
dec.panicSyntax(fmt.Errorf(",string field: expected %q or %q but got %q",
'n', '"', dec.peekRune()))
}
} else {
- dec.decode(fVal)
+ dec.decode(fVal, true)
}
})
case reflect.Map:
@@ -410,7 +439,7 @@ func (dec *Decoder) decode(val reflect.Value) {
defer dec.stackPop()
fValPtr := reflect.New(typ.Elem())
- dec.decode(fValPtr.Elem())
+ dec.decode(fValPtr.Elem(), false)
val.SetMapIndex(nameValPtr.Elem(), fValPtr.Elem())
})
@@ -422,7 +451,16 @@ func (dec *Decoder) decode(val reflect.Value) {
case typ.Elem().Kind() == reflect.Uint8:
var buf bytes.Buffer
dec.decodeString(newBase64Decoder(&buf))
- val.Set(reflect.ValueOf(buf.Bytes()))
+ if typ.Elem() == byteType {
+ val.Set(reflect.ValueOf(buf.Bytes()))
+ } else {
+ bs := buf.Bytes()
+ // TODO: Surely there's a better way.
+ val.Set(reflect.MakeSlice(typ, len(bs), len(bs)))
+ for i := 0; i < len(bs); i++ {
+ val.Index(i).Set(reflect.ValueOf(bs[i]).Convert(typ.Elem()))
+ }
+ }
default:
switch dec.peekRune() {
case 'n':
@@ -432,12 +470,15 @@ func (dec *Decoder) decode(val reflect.Value) {
if val.IsNil() {
val.Set(reflect.MakeSlice(typ, 0, 0))
}
+ if val.Len() > 0 {
+ val.Set(val.Slice(0, 0))
+ }
i := 0
dec.decodeArray(func() {
dec.stackPush(i)
defer dec.stackPop()
mValPtr := reflect.New(typ.Elem())
- dec.decode(mValPtr.Elem())
+ dec.decode(mValPtr.Elem(), false)
val.Set(reflect.Append(val, mValPtr.Elem()))
i++
})
@@ -446,29 +487,44 @@ func (dec *Decoder) decode(val reflect.Value) {
}
}
case reflect.Array:
+ if nullOK && dec.peekRune() == 'n' {
+ dec.decodeNull()
+ return
+ }
i := 0
+ n := val.Len()
dec.decodeArray(func() {
dec.stackPush(i)
defer dec.stackPop()
- mValPtr := reflect.New(typ.Elem())
- dec.decode(mValPtr.Elem())
- val.Index(i).Set(mValPtr.Elem())
+ if i < n {
+ mValPtr := reflect.New(typ.Elem())
+ dec.decode(mValPtr.Elem(), false)
+ val.Index(i).Set(mValPtr.Elem())
+ } else {
+ dec.scan(io.Discard)
+ }
i++
})
+ for ; i < n; i++ {
+ val.Index(i).Set(reflect.Zero(typ.Elem()))
+ }
case reflect.Pointer:
switch dec.peekRune() {
case 'n':
dec.decodeNull()
- for val.IsNil() && typ.Elem().Kind() == reflect.Pointer {
- val.Set(reflect.New(typ.Elem()))
+ for typ.Elem().Kind() == reflect.Pointer {
+ if val.IsNil() || !val.Elem().CanSet() {
+ val.Set(reflect.New(typ.Elem()))
+ }
val = val.Elem()
+ typ = val.Type()
}
- val.Elem().Set(reflect.Zero(val.Type().Elem()))
+ val.Set(reflect.Zero(typ))
default:
if val.IsNil() {
val.Set(reflect.New(typ.Elem()))
}
- dec.decode(val.Elem())
+ dec.decode(val.Elem(), false)
}
default:
dec.panicType(typ, fmt.Errorf("unsupported type (kind=%v)", typ.Kind()))
diff --git a/lib/lowmemjson/encode.go b/lib/lowmemjson/encode.go
index c09fcc1..4bab7cb 100644
--- a/lib/lowmemjson/encode.go
+++ b/lib/lowmemjson/encode.go
@@ -256,8 +256,19 @@ func encode(w io.Writer, val reflect.Value, quote bool) {
case val.Type().Elem().Kind() == reflect.Uint8:
encodeWriteByte(w, '"')
enc := base64.NewEncoder(base64.StdEncoding, w)
- if _, err := enc.Write(val.Interface().([]byte)); err != nil {
- panic(encodeError{err})
+ if val.CanConvert(byteSliceType) {
+ if _, err := enc.Write(val.Convert(byteSliceType).Interface().([]byte)); err != nil {
+ panic(encodeError{err})
+ }
+ } else {
+ // TODO: Surely there's a better way.
+ for i, n := 0, val.Len(); i < n; i++ {
+ var buf [1]byte
+ buf[0] = val.Index(i).Convert(byteType).Interface().(byte)
+ if _, err := enc.Write(buf[:]); err != nil {
+ panic(encodeError{err})
+ }
+ }
}
if err := enc.Close(); err != nil {
panic(encodeError{err})
@@ -282,7 +293,7 @@ func encode(w io.Writer, val reflect.Value, quote bool) {
}
}
-func encodeString[T interface{ ~[]byte | ~string }](w io.Writer, str T) {
+func encodeString[T interface{ []byte | string }](w io.Writer, str T) {
encodeWriteByte(w, '"')
for i := 0; i < len(str); {
c, size := decodeRune(str[i:])
diff --git a/lib/lowmemjson/misc.go b/lib/lowmemjson/misc.go
index 132d441..132b177 100644
--- a/lib/lowmemjson/misc.go
+++ b/lib/lowmemjson/misc.go
@@ -15,11 +15,15 @@ const Tab = "\t"
const hex = "0123456789abcdef"
-var numberType = reflect.TypeOf(json.Number(""))
+var (
+ numberType = reflect.TypeOf(json.Number(""))
+ byteType = reflect.TypeOf(byte(0))
+ byteSliceType = reflect.TypeOf(([]byte)(nil))
+)
// generic I/O /////////////////////////////////////////////////////////////////
-func decodeRune[T interface{ ~[]byte | ~string }](s T) (r rune, size int) {
+func decodeRune[T interface{ []byte | string }](s T) (r rune, size int) {
iface := any(s)
if str, ok := iface.(string); ok {
return utf8.DecodeRuneInString(str)