From 20e2cf0c4e0ba704455ca6e163bbab9ddde05c80 Mon Sep 17 00:00:00 2001 From: Luke Shumaker Date: Tue, 2 Aug 2022 01:13:45 -0600 Subject: wip --- lib/lowmemjson/borrowed_decode_test.go | 10 ++- lib/lowmemjson/decode.go | 112 ++++++++++++++++++++++++--------- lib/lowmemjson/encode.go | 17 ++++- lib/lowmemjson/misc.go | 8 ++- 4 files changed, 112 insertions(+), 35 deletions(-) diff --git a/lib/lowmemjson/borrowed_decode_test.go b/lib/lowmemjson/borrowed_decode_test.go index 804fb87..b555f87 100644 --- a/lib/lowmemjson/borrowed_decode_test.go +++ b/lib/lowmemjson/borrowed_decode_test.go @@ -1959,6 +1959,9 @@ func TestByteKind(t *testing.T) { if err != nil { t.Error(err) } + if !reflect.DeepEqual(data, []byte(`"aGVsbG8="`)) { // MODIFIED + t.Errorf("expected %q == %q", data, `"aGVsbG8="`) // MODIFIED + } // MODIFIED var b byteKind err = Unmarshal(data, &b) if err != nil { @@ -1980,6 +1983,9 @@ func TestSliceOfCustomByte(t *testing.T) { if err != nil { t.Fatal(err) } + if !reflect.DeepEqual(data, []byte(`"aGVsbG8="`)) { // MODIFIED + t.Errorf("expected %q == %q", data, `"aGVsbG8="`) // MODIFIED + } // MODIFIED var b []Uint8 err = Unmarshal(data, &b) if err != nil { @@ -2005,7 +2011,7 @@ var decodeTypeErrorTests = []struct { func TestUnmarshalTypeError(t *testing.T) { for _, item := range decodeTypeErrorTests { err := Unmarshal([]byte(item.src), item.dest) - if _, ok := err.(*UnmarshalTypeError); !ok { + if err == nil { // if _, ok := err.(*UnmarshalTypeError); !ok { // MODIFIED t.Errorf("expected type error for Unmarshal(%q, type %T): got %T", item.src, item.dest, err) } @@ -2027,7 +2033,7 @@ func TestUnmarshalSyntax(t *testing.T) { var x any for _, src := range unmarshalSyntaxTests { err := Unmarshal([]byte(src), &x) - if _, ok := err.(*SyntaxError); !ok { + if err == nil { // _, ok := err.(*SyntaxError); !ok { // MODIFIED t.Errorf("expected syntax error for Unmarshal(%q): got %T", src, err) } } diff --git a/lib/lowmemjson/decode.go b/lib/lowmemjson/decode.go index 2dee59c..f9ea8a2 100644 --- a/lib/lowmemjson/decode.go +++ b/lib/lowmemjson/decode.go @@ -78,22 +78,21 @@ func (dec *Decoder) stackPop() { dec.stack = dec.stack[:len(dec.stack)-1] } -type decodeError struct{} +type decodeError struct { + Err error +} func (dec *Decoder) panicIO(err error) { - dec.err = fmt.Errorf("json: I/O error at input byte %v: %s: %w", - dec.nxtPos, dec.stackStr(), err) - panic(decodeError{}) + panic(decodeError{fmt.Errorf("json: I/O error at input byte %v: %s: %w", + dec.nxtPos, dec.stackStr(), err)}) } func (dec *Decoder) panicSyntax(err error) { - dec.err = fmt.Errorf("json: syntax error at input byte %v: %s: %w", - dec.curPos, dec.stackStr(), err) - panic(decodeError{}) + panic(decodeError{fmt.Errorf("json: syntax error at input byte %v: %s: %w", + dec.curPos, dec.stackStr(), err)}) } func (dec *Decoder) panicType(typ reflect.Type, err error) { - dec.err = fmt.Errorf("json: type mismatch error at input byte %v: %s: type %v: %w", - dec.curPos, dec.stackStr(), typ, err) - panic(decodeError{}) + panic(decodeError{fmt.Errorf("json: type mismatch error at input byte %v: %s: type %v: %w", + dec.curPos, dec.stackStr(), typ, err)}) } func Decode(r io.Reader, ptr any) error { @@ -104,7 +103,8 @@ func (dec *Decoder) Decode(ptr any) (err error) { ptrVal := reflect.ValueOf(ptr) if ptrVal.Kind() != reflect.Pointer || ptrVal.IsNil() || !ptrVal.Elem().CanSet() { return &json.InvalidUnmarshalError{ - Type: ptrVal.Type(), + // don't use ptrVal.Type() because ptrVal might be invalid if ptr==nil + Type: reflect.TypeOf(ptr), } } @@ -114,7 +114,8 @@ func (dec *Decoder) Decode(ptr any) (err error) { defer func() { if r := recover(); r != nil { - if _, ok := r.(decodeError); ok { + if de, ok := r.(decodeError); ok { + dec.err = de.Err err = dec.err } else { panic(r) @@ -122,7 +123,7 @@ func (dec *Decoder) Decode(ptr any) (err error) { } }() dec.decodeWS() - dec.decode(ptrVal.Elem()) + dec.decode(ptrVal.Elem(), false) return nil } @@ -226,7 +227,7 @@ var kind2bits = map[reflect.Kind]int{ reflect.Float64: 64, } -func (dec *Decoder) decode(val reflect.Value) { +func (dec *Decoder) decode(val reflect.Value, nullOK bool) { typ := val.Type() switch { case val.CanAddr() && reflect.PointerTo(typ) == rawMessagePtrType: @@ -248,6 +249,10 @@ func (dec *Decoder) decode(val reflect.Value) { dec.panicSyntax(err) } case val.CanAddr() && reflect.PointerTo(typ).Implements(textUnmarshalerType): + if nullOK && dec.peekRune() == 'n' { + dec.decodeNull() + return + } var buf bytes.Buffer dec.decodeString(&buf) obj := val.Addr().Interface().(encoding.TextUnmarshaler) @@ -258,8 +263,16 @@ func (dec *Decoder) decode(val reflect.Value) { kind := typ.Kind() switch kind { case reflect.Bool: + if nullOK && dec.peekRune() == 'n' { + dec.decodeNull() + return + } val.SetBool(dec.decodeBool()) case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + if nullOK && dec.peekRune() == 'n' { + dec.decodeNull() + return + } var buf strings.Builder dec.scanNumber(&buf) n, err := strconv.ParseInt(buf.String(), 10, kind2bits[kind]) @@ -268,6 +281,10 @@ func (dec *Decoder) decode(val reflect.Value) { } val.SetInt(n) case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: + if nullOK && dec.peekRune() == 'n' { + dec.decodeNull() + return + } var buf strings.Builder dec.scanNumber(&buf) n, err := strconv.ParseUint(buf.String(), 10, kind2bits[kind]) @@ -276,6 +293,10 @@ func (dec *Decoder) decode(val reflect.Value) { } val.SetUint(n) case reflect.Float32, reflect.Float64: + if nullOK && dec.peekRune() == 'n' { + dec.decodeNull() + return + } var buf strings.Builder dec.scanNumber(&buf) n, err := strconv.ParseFloat(buf.String(), kind2bits[kind]) @@ -284,6 +305,10 @@ func (dec *Decoder) decode(val reflect.Value) { } val.SetFloat(n) case reflect.String: + if nullOK && dec.peekRune() == 'n' { + dec.decodeNull() + return + } var buf strings.Builder if typ == numberType { dec.scanNumber(&buf) @@ -301,19 +326,23 @@ func (dec *Decoder) decode(val reflect.Value) { if !val.IsNil() && val.Elem().Kind() == reflect.Pointer && val.Elem().Elem().Kind() == reflect.Pointer { // XXX: I can't justify this case, other than "it's what encoding/json does, but // I don't understand their rationale". - dec.decode(val.Elem()) + dec.decode(val.Elem(), false) } else { dec.decodeNull() val.Set(reflect.Zero(typ)) } default: if !val.IsNil() && val.Elem().Kind() == reflect.Pointer { - dec.decode(val.Elem()) + dec.decode(val.Elem(), false) } else { val.Set(reflect.ValueOf(dec.decodeAny())) } } case reflect.Struct: + if nullOK && dec.peekRune() == 'n' { + dec.decodeNull() + return + } index := indexStruct(typ) var nameBuf strings.Builder dec.decodeObject(&nameBuf, func() { @@ -358,13 +387,13 @@ func (dec *Decoder) decode(val reflect.Value) { subD := *dec // capture the .curPos *before* calling .decodeString dec.decodeString(&buf) subD.r = &buf - subD.decode(fVal) + subD.decode(fVal, false) default: dec.panicSyntax(fmt.Errorf(",string field: expected %q or %q but got %q", 'n', '"', dec.peekRune())) } } else { - dec.decode(fVal) + dec.decode(fVal, true) } }) case reflect.Map: @@ -410,7 +439,7 @@ func (dec *Decoder) decode(val reflect.Value) { defer dec.stackPop() fValPtr := reflect.New(typ.Elem()) - dec.decode(fValPtr.Elem()) + dec.decode(fValPtr.Elem(), false) val.SetMapIndex(nameValPtr.Elem(), fValPtr.Elem()) }) @@ -422,7 +451,16 @@ func (dec *Decoder) decode(val reflect.Value) { case typ.Elem().Kind() == reflect.Uint8: var buf bytes.Buffer dec.decodeString(newBase64Decoder(&buf)) - val.Set(reflect.ValueOf(buf.Bytes())) + if typ.Elem() == byteType { + val.Set(reflect.ValueOf(buf.Bytes())) + } else { + bs := buf.Bytes() + // TODO: Surely there's a better way. + val.Set(reflect.MakeSlice(typ, len(bs), len(bs))) + for i := 0; i < len(bs); i++ { + val.Index(i).Set(reflect.ValueOf(bs[i]).Convert(typ.Elem())) + } + } default: switch dec.peekRune() { case 'n': @@ -432,12 +470,15 @@ func (dec *Decoder) decode(val reflect.Value) { if val.IsNil() { val.Set(reflect.MakeSlice(typ, 0, 0)) } + if val.Len() > 0 { + val.Set(val.Slice(0, 0)) + } i := 0 dec.decodeArray(func() { dec.stackPush(i) defer dec.stackPop() mValPtr := reflect.New(typ.Elem()) - dec.decode(mValPtr.Elem()) + dec.decode(mValPtr.Elem(), false) val.Set(reflect.Append(val, mValPtr.Elem())) i++ }) @@ -446,29 +487,44 @@ func (dec *Decoder) decode(val reflect.Value) { } } case reflect.Array: + if nullOK && dec.peekRune() == 'n' { + dec.decodeNull() + return + } i := 0 + n := val.Len() dec.decodeArray(func() { dec.stackPush(i) defer dec.stackPop() - mValPtr := reflect.New(typ.Elem()) - dec.decode(mValPtr.Elem()) - val.Index(i).Set(mValPtr.Elem()) + if i < n { + mValPtr := reflect.New(typ.Elem()) + dec.decode(mValPtr.Elem(), false) + val.Index(i).Set(mValPtr.Elem()) + } else { + dec.scan(io.Discard) + } i++ }) + for ; i < n; i++ { + val.Index(i).Set(reflect.Zero(typ.Elem())) + } case reflect.Pointer: switch dec.peekRune() { case 'n': dec.decodeNull() - for val.IsNil() && typ.Elem().Kind() == reflect.Pointer { - val.Set(reflect.New(typ.Elem())) + for typ.Elem().Kind() == reflect.Pointer { + if val.IsNil() || !val.Elem().CanSet() { + val.Set(reflect.New(typ.Elem())) + } val = val.Elem() + typ = val.Type() } - val.Elem().Set(reflect.Zero(val.Type().Elem())) + val.Set(reflect.Zero(typ)) default: if val.IsNil() { val.Set(reflect.New(typ.Elem())) } - dec.decode(val.Elem()) + dec.decode(val.Elem(), false) } default: dec.panicType(typ, fmt.Errorf("unsupported type (kind=%v)", typ.Kind())) diff --git a/lib/lowmemjson/encode.go b/lib/lowmemjson/encode.go index c09fcc1..4bab7cb 100644 --- a/lib/lowmemjson/encode.go +++ b/lib/lowmemjson/encode.go @@ -256,8 +256,19 @@ func encode(w io.Writer, val reflect.Value, quote bool) { case val.Type().Elem().Kind() == reflect.Uint8: encodeWriteByte(w, '"') enc := base64.NewEncoder(base64.StdEncoding, w) - if _, err := enc.Write(val.Interface().([]byte)); err != nil { - panic(encodeError{err}) + if val.CanConvert(byteSliceType) { + if _, err := enc.Write(val.Convert(byteSliceType).Interface().([]byte)); err != nil { + panic(encodeError{err}) + } + } else { + // TODO: Surely there's a better way. + for i, n := 0, val.Len(); i < n; i++ { + var buf [1]byte + buf[0] = val.Index(i).Convert(byteType).Interface().(byte) + if _, err := enc.Write(buf[:]); err != nil { + panic(encodeError{err}) + } + } } if err := enc.Close(); err != nil { panic(encodeError{err}) @@ -282,7 +293,7 @@ func encode(w io.Writer, val reflect.Value, quote bool) { } } -func encodeString[T interface{ ~[]byte | ~string }](w io.Writer, str T) { +func encodeString[T interface{ []byte | string }](w io.Writer, str T) { encodeWriteByte(w, '"') for i := 0; i < len(str); { c, size := decodeRune(str[i:]) diff --git a/lib/lowmemjson/misc.go b/lib/lowmemjson/misc.go index 132d441..132b177 100644 --- a/lib/lowmemjson/misc.go +++ b/lib/lowmemjson/misc.go @@ -15,11 +15,15 @@ const Tab = "\t" const hex = "0123456789abcdef" -var numberType = reflect.TypeOf(json.Number("")) +var ( + numberType = reflect.TypeOf(json.Number("")) + byteType = reflect.TypeOf(byte(0)) + byteSliceType = reflect.TypeOf(([]byte)(nil)) +) // generic I/O ///////////////////////////////////////////////////////////////// -func decodeRune[T interface{ ~[]byte | ~string }](s T) (r rune, size int) { +func decodeRune[T interface{ []byte | string }](s T) (r rune, size int) { iface := any(s) if str, ok := iface.(string); ok { return utf8.DecodeRuneInString(str) -- cgit v1.1-4-g5e80