diff options
-rw-r--r-- | ReleaseNotes.md | 14 | ||||
-rw-r--r-- | compat/json/compat_test.go | 8 | ||||
-rw-r--r-- | decode.go | 267 | ||||
-rw-r--r-- | decode_test.go | 47 |
4 files changed, 210 insertions, 126 deletions
diff --git a/ReleaseNotes.md b/ReleaseNotes.md index 8f3be1a..20bcd65 100644 --- a/ReleaseNotes.md +++ b/ReleaseNotes.md @@ -1,3 +1,17 @@ +# v0.3.8 (TBD) + + Theme: Fixes from fuzzing (part 2/?) + + User-facing changes: + + - Change: Decoder: No longer bails when a type error + (`DecodeTypeError`) is encountered. The part of the output value + with the type error is either unmodified (if already existing) or + set to nil/zero (if not already existing), and decoding + continues. If no later fatal error (syntax, I/O) is encountered, + then the first type error encountered is returned. This is + consistent with the behavior of `encoding/json`. + # v0.3.7 (2023-02-20) Theme: Fixes from fuzzing (part 1?) diff --git a/compat/json/compat_test.go b/compat/json/compat_test.go index df9d387..098ac85 100644 --- a/compat/json/compat_test.go +++ b/compat/json/compat_test.go @@ -181,6 +181,10 @@ func TestCompatUnmarshal(t *testing.T) { "two-objs": {In: `{} {}`, ExpOut: nil, ExpErr: `invalid character '{' after top-level value`}, "two-numbers1": {In: `00`, ExpOut: nil, ExpErr: `invalid character '0' after top-level value`}, "two-numbers2": {In: `1 2`, ExpOut: nil, ExpErr: `invalid character '2' after top-level value`}, + // 2e308 is slightly more than math.MaxFloat64 (~1.79e308) + "obj-overflow": {In: `{"foo":"bar", "baz":2e308, "qux": "orb"}`, ExpOut: map[string]any{"foo": "bar", "baz": nil, "qux": "orb"}, ExpErr: `json: cannot unmarshal number 2e308 into Go value of type float64`}, + "ary-overflow": {In: `["foo",2e308,"bar",3e308]`, ExpOut: []any{"foo", nil, "bar", nil}, ExpErr: `json: cannot unmarshal number 2e308 into Go value of type float64`}, + "existing-overflow": {In: `2e308`, InPtr: func() any { x := 4; return &x }(), ExpOut: 4, ExpErr: `json: cannot unmarshal number 2e308 into Go value of type int`}, } for tcName, tc := range testcases { tc := tc @@ -219,6 +223,10 @@ func TestCompatDecode(t *testing.T) { "two-objs": {In: `{} {}`, ExpOut: map[string]any{}}, "two-numbers1": {In: `00`, ExpOut: float64(0)}, "two-numbers2": {In: `1 2`, ExpOut: float64(1)}, + // 2e308 is slightly more than math.MaxFloat64 (~1.79e308) + "obj-overflow": {In: `{"foo":"bar", "baz":2e308, "qux": "orb"}`, ExpOut: map[string]any{"foo": "bar", "baz": nil, "qux": "orb"}, ExpErr: `json: cannot unmarshal number 2e308 into Go value of type float64`}, + "ary-overflow": {In: `["foo",2e308,"bar",3e308]`, ExpOut: []any{"foo", nil, "bar", nil}, ExpErr: `json: cannot unmarshal number 2e308 into Go value of type float64`}, + "existing-overflow": {In: `2e308`, InPtr: func() any { x := 4; return &x }(), ExpOut: 4, ExpErr: `json: cannot unmarshal number 2e308 into Go value of type int`}, } for tcName, tc := range testcases { tc := tc @@ -96,6 +96,7 @@ type Decoder struct { // state posStack []int64 structStack []decodeStackItem + typeErr *DecodeError } const maxNestingDepth = 10000 @@ -241,19 +242,26 @@ func (dec *Decoder) Decode(ptr any) (err error) { } } + dec.typeErr = nil dec.io.Reset() dec.io.PushReadBarrier() if err := dec.decode(ptrVal.Elem(), false); err != nil { return err } dec.io.PopReadBarrier() + if dec.typeErr != nil { + return dec.typeErr + } return nil } // io helpers ////////////////////////////////////////////////////////////////////////////////////// -func (dec *Decoder) newTypeError(jTyp string, gTyp reflect.Type, err error) *DecodeError { - return &DecodeError{ +func (dec *Decoder) newTypeError(jTyp string, gTyp reflect.Type, err error) { + if dec.typeErr != nil { + return + } + dec.typeErr = &DecodeError{ Field: dec.structStackStr(), FieldParent: dec.structStackParent(), FieldName: dec.structStackName(), @@ -296,7 +304,13 @@ func (dec *Decoder) peekRuneType() (jsonparse.RuneType, *DecodeError) { return t, nil } -func (dec *Decoder) expectRune(ec rune, et jsonparse.RuneType) *DecodeError { +// expectRuneOrPanic is for when you *know* what the next +// non-whitespace rune is going to be; for it to be anything else +// would be a syntax error. It will return an error for I/O errors +// and syntax errors, but panic if the result is not what was +// expected; as that would indicate a bug in the agreement between the +// parser and the decoder. +func (dec *Decoder) expectRuneOrPanic(ec rune, et jsonparse.RuneType) *DecodeError { ac, at, err := dec.readRune() if err != nil { return err @@ -307,17 +321,6 @@ func (dec *Decoder) expectRune(ec rune, et jsonparse.RuneType) *DecodeError { return nil } -func (dec *Decoder) expectRuneType(ec rune, et jsonparse.RuneType, gt reflect.Type) *DecodeError { - ac, at, err := dec.readRune() - if err != nil { - return err - } - if ac != ec || at != et { - return dec.newTypeError(at.JSONType(), gt, nil) - } - return nil -} - type decRuneScanner struct { dec *Decoder eof bool @@ -350,7 +353,11 @@ func (sc *decRuneScanner) UnreadRune() error { return sc.dec.io.UnreadRune() } -func (dec *Decoder) withLimitingScanner(fn func(io.RuneScanner) *DecodeError) (err *DecodeError) { +func (dec *Decoder) withLimitingScanner(gTyp reflect.Type, fn func(io.RuneScanner) error) (err *DecodeError) { + t, err := dec.peekRuneType() + if err != nil { + return err + } dec.io.PushReadBarrier() defer func() { if r := recover(); r != nil { @@ -361,8 +368,15 @@ func (dec *Decoder) withLimitingScanner(fn func(io.RuneScanner) *DecodeError) (e } } }() - if err := fn(&decRuneScanner{dec: dec}); err != nil { - return err + l := &decRuneScanner{dec: dec} + if err := fn(l); err != nil { + dec.newTypeError(t.JSONType(), gTyp, err) + } + if _, _, err := l.ReadRune(); err != io.EOF { + dec.newTypeError(t.JSONType(), gTyp, fmt.Errorf("did not consume entire %s", t.JSONType())) + for err != io.EOF { + _, _, err = l.ReadRune() + } } return nil } @@ -403,23 +417,11 @@ func (dec *Decoder) decode(val reflect.Value, nullOK bool) *DecodeError { return err } if err := val.Addr().Interface().(*RawMessage).UnmarshalJSON(buf.Bytes()); err != nil { - return dec.newTypeError(t.JSONType(), reflect.PointerTo(typ), err) + dec.newTypeError(t.JSONType(), reflect.PointerTo(typ), err) } case val.CanAddr() && reflect.PointerTo(typ).Implements(decodableType): - t, err := dec.peekRuneType() - if err != nil { - return err - } obj := val.Addr().Interface().(Decodable) - return dec.withLimitingScanner(func(l io.RuneScanner) *DecodeError { - if err := obj.DecodeJSON(l); err != nil { - return dec.newTypeError(t.JSONType(), reflect.PointerTo(typ), err) - } - if _, _, err := l.ReadRune(); err != io.EOF { - return dec.newTypeError(t.JSONType(), reflect.PointerTo(typ), fmt.Errorf("did not consume entire %s", t.JSONType())) - } - return nil - }) + return dec.withLimitingScanner(reflect.PointerTo(typ), obj.DecodeJSON) case val.CanAddr() && reflect.PointerTo(typ).Implements(jsonUnmarshalerType): t, err := dec.peekRuneType() if err != nil { @@ -431,7 +433,7 @@ func (dec *Decoder) decode(val reflect.Value, nullOK bool) *DecodeError { } obj := val.Addr().Interface().(jsonUnmarshaler) if err := obj.UnmarshalJSON(buf.Bytes()); err != nil { - return dec.newTypeError(t.JSONType(), reflect.PointerTo(typ), err) + dec.newTypeError(t.JSONType(), reflect.PointerTo(typ), err) } case val.CanAddr() && reflect.PointerTo(typ).Implements(textUnmarshalerType): if ok, err := dec.maybeDecodeNull(nullOK); ok { @@ -443,7 +445,7 @@ func (dec *Decoder) decode(val reflect.Value, nullOK bool) *DecodeError { } obj := val.Addr().Interface().(encoding.TextUnmarshaler) if err := obj.UnmarshalText(buf.Bytes()); err != nil { - return dec.newTypeError("string", reflect.PointerTo(typ), err) + dec.newTypeError("string", reflect.PointerTo(typ), err) } default: switch kind := typ.Kind(); kind { @@ -460,39 +462,60 @@ func (dec *Decoder) decode(val reflect.Value, nullOK bool) *DecodeError { if ok, err := dec.maybeDecodeNull(nullOK); ok { return err } + if t, err := dec.peekRuneType(); err != nil { + return err + } else if !t.IsNumber() { + dec.newTypeError(t.JSONType(), typ, nil) + return dec.scan(fastio.Discard) + } var buf strings.Builder - if err := dec.scanNumber(typ, &buf); err != nil { + if err := dec.scan(&buf); err != nil { return err } n, err := strconv.ParseInt(buf.String(), 10, kind2bits[kind]) if err != nil { - return dec.newTypeError("number "+buf.String(), typ, err) + dec.newTypeError("number "+buf.String(), typ, err) + return nil } val.SetInt(n) case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: if ok, err := dec.maybeDecodeNull(nullOK); ok { return err } + if t, err := dec.peekRuneType(); err != nil { + return err + } else if !t.IsNumber() { + dec.newTypeError(t.JSONType(), typ, nil) + return dec.scan(fastio.Discard) + } var buf strings.Builder - if err := dec.scanNumber(typ, &buf); err != nil { + if err := dec.scan(&buf); err != nil { return err } n, err := strconv.ParseUint(buf.String(), 10, kind2bits[kind]) if err != nil { - return dec.newTypeError("number "+buf.String(), typ, err) + dec.newTypeError("number "+buf.String(), typ, err) + return nil } val.SetUint(n) case reflect.Float32, reflect.Float64: if ok, err := dec.maybeDecodeNull(nullOK); ok { return err } + if t, err := dec.peekRuneType(); err != nil { + return err + } else if !t.IsNumber() { + dec.newTypeError(t.JSONType(), typ, nil) + return dec.scan(fastio.Discard) + } var buf strings.Builder - if err := dec.scanNumber(typ, &buf); err != nil { + if err := dec.scan(&buf); err != nil { return err } n, err := strconv.ParseFloat(buf.String(), kind2bits[kind]) if err != nil { - return dec.newTypeError("number "+buf.String(), typ, err) + dec.newTypeError("number "+buf.String(), typ, err) + return nil } val.SetFloat(n) case reflect.String: @@ -509,9 +532,10 @@ func (dec *Decoder) decode(val reflect.Value, nullOK bool) *DecodeError { return err } if !t.IsNumber() { - return dec.newTypeError(t.JSONType(), typ, + dec.newTypeError(t.JSONType(), typ, fmt.Errorf("json: invalid number literal, trying to unmarshal %q into Number", buf.String())) + return nil } val.SetString(buf.String()) } else { @@ -526,7 +550,8 @@ func (dec *Decoder) decode(val reflect.Value, nullOK bool) *DecodeError { return err } if typ.NumMethod() > 0 { - return dec.newTypeError(t.JSONType(), typ, ErrDecodeNonEmptyInterface) + dec.newTypeError(t.JSONType(), typ, ErrDecodeNonEmptyInterface) + return dec.scan(fastio.Discard) } // If the interface stores a pointer, try to use the type information of the pointer. if !val.IsNil() && val.Elem().Kind() == reflect.Pointer { @@ -570,7 +595,9 @@ func (dec *Decoder) decode(val reflect.Value, nullOK bool) *DecodeError { if err != nil { return err } - val.Set(reflect.ValueOf(v)) + if v != nil { + val.Set(reflect.ValueOf(v)) + } } case reflect.Struct: if ok, err := dec.maybeDecodeNull(nullOK); ok { @@ -601,7 +628,7 @@ func (dec *Decoder) decode(val reflect.Value, nullOK bool) *DecodeError { } if !ok { if dec.disallowUnknownFields { - return dec.newTypeError("", typ, fmt.Errorf("json: unknown field %q", name)) + dec.newTypeError("", typ, fmt.Errorf("json: unknown field %q", name)) } return dec.scan(fastio.Discard) } @@ -610,9 +637,10 @@ func (dec *Decoder) decode(val reflect.Value, nullOK bool) *DecodeError { for _, idx := range field.Path { if fVal.Kind() == reflect.Pointer { if fVal.IsNil() && !fVal.CanSet() { // https://golang.org/issue/21357 - return dec.newTypeError("", fVal.Type().Elem(), + dec.newTypeError("", fVal.Type().Elem(), fmt.Errorf("json: cannot set embedded pointer to unexported struct: %v", fVal.Type().Elem())) + return dec.scan(fastio.Discard) } t, err := dec.peekRuneType() if err != nil { @@ -653,15 +681,16 @@ func (dec *Decoder) decode(val reflect.Value, nullOK bool) *DecodeError { } if err := NewDecoder(bytes.NewReader(buf.Bytes())).Decode(fVal.Addr().Interface()); err != nil { if str := buf.String(); str != "null" { - return dec.newTypeError("", fVal.Type(), + dec.newTypeError("", fVal.Type(), fmt.Errorf("json: invalid use of ,string struct tag, trying to unmarshal %q into %v", str, fVal.Type())) } } default: - return dec.newTypeError(t.JSONType(), fVal.Type(), + dec.newTypeError(t.JSONType(), fVal.Type(), fmt.Errorf("json: invalid use of ,string struct tag, trying to unmarshal unquoted value into %v", fVal.Type())) + return dec.scan(fastio.Discard) } return nil } else { @@ -698,7 +727,7 @@ func (dec *Decoder) decode(val reflect.Value, nullOK bool) *DecodeError { case reflect.PointerTo(nameValTyp).Implements(textUnmarshalerType): obj := nameValPtr.Interface().(encoding.TextUnmarshaler) if err := obj.UnmarshalText(nameBuf.Bytes()); err != nil { - return dec.newTypeError("string", reflect.PointerTo(nameValTyp), err) + dec.newTypeError("string", reflect.PointerTo(nameValTyp), err) } default: switch nameValTyp.Kind() { @@ -707,17 +736,19 @@ func (dec *Decoder) decode(val reflect.Value, nullOK bool) *DecodeError { case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: n, err := strconv.ParseInt(nameBuf.String(), 10, kind2bits[nameValTyp.Kind()]) if err != nil { - return dec.newTypeError("number "+nameBuf.String(), nameValTyp, err) + dec.newTypeError("number "+nameBuf.String(), nameValTyp, err) + return nil } nameValPtr.Elem().SetInt(n) case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: n, err := strconv.ParseUint(nameBuf.String(), 10, kind2bits[nameValTyp.Kind()]) if err != nil { - return dec.newTypeError("number "+nameBuf.String(), nameValTyp, err) + dec.newTypeError("number "+nameBuf.String(), nameValTyp, err) + return nil } nameValPtr.Elem().SetUint(n) default: - return dec.newTypeError("object", typ, &DecodeArgumentError{Type: nameValTyp}) + dec.newTypeError("object", typ, &DecodeArgumentError{Type: nameValTyp}) } } return nil @@ -736,7 +767,7 @@ func (dec *Decoder) decode(val reflect.Value, nullOK bool) *DecodeError { return nil }) default: - return dec.newTypeError(t.JSONType(), typ, nil) + dec.newTypeError(t.JSONType(), typ, nil) } case reflect.Slice: t, err := dec.peekRuneType() @@ -775,7 +806,7 @@ func (dec *Decoder) decode(val reflect.Value, nullOK bool) *DecodeError { } } default: - return dec.newTypeError(t.JSONType(), typ, nil) + dec.newTypeError(t.JSONType(), typ, nil) } default: switch t { @@ -806,7 +837,7 @@ func (dec *Decoder) decode(val reflect.Value, nullOK bool) *DecodeError { return nil }) default: - return dec.newTypeError(t.JSONType(), typ, nil) + dec.newTypeError(t.JSONType(), typ, nil) } } case reflect.Array: @@ -857,7 +888,7 @@ func (dec *Decoder) decode(val reflect.Value, nullOK bool) *DecodeError { return dec.decode(val.Elem(), false) } default: - return dec.newTypeError("", typ, fmt.Errorf("unsupported type (kind=%v)", typ.Kind())) + dec.newTypeError("", typ, fmt.Errorf("unsupported type (kind=%v)", typ.Kind())) } } return nil @@ -879,17 +910,6 @@ func (dec *Decoder) scan(out fastio.RuneWriter) *DecodeError { return nil } -func (dec *Decoder) scanNumber(gTyp reflect.Type, out fastio.RuneWriter) *DecodeError { - t, err := dec.peekRuneType() - if err != nil { - return err - } - if !t.IsNumber() { - return dec.newTypeError(t.JSONType(), gTyp, nil) - } - return dec.scan(out) -} - func (dec *Decoder) decodeAny() (any, *DecodeError) { t, err := dec.peekRuneType() if err != nil { @@ -956,7 +976,8 @@ func (dec *Decoder) decodeAny() (any, *DecodeError) { } f64, err := num.Float64() if err != nil { - return nil, dec.newTypeError("number "+buf.String(), float64Type, err) + dec.newTypeError("number "+buf.String(), float64Type, err) + return nil, nil } return f64, nil case jsonparse.RuneTypeTrueT, jsonparse.RuneTypeFalseF: @@ -983,51 +1004,41 @@ func DecodeObject(r io.RuneScanner, decodeKey, decodeVal func(io.RuneScanner) er } else { dec = NewDecoder(r) } + if dec.typeErr != nil { + oldTypeErr := dec.typeErr + dec.typeErr = nil + defer func() { dec.typeErr = oldTypeErr }() + } dec.posStackPush() defer dec.posStackPop() if err := dec.decodeObject(nil, func() *DecodeError { dec.posStackPush() defer dec.posStackPop() - return dec.withLimitingScanner(func(l io.RuneScanner) *DecodeError { - if err := decodeKey(l); err != nil { - // TODO: Find a better Go type to use than `nil`. - return dec.newTypeError("string", nil, err) - } - if _, _, err := l.ReadRune(); err != io.EOF { - // TODO: Find a better Go type to use than `nil`. - return dec.newTypeError("string", nil, fmt.Errorf("did not consume entire string")) - } - return nil - }) + // TODO: Find a better Go type to use than `nil`. + return dec.withLimitingScanner(nil, decodeKey) }, func() *DecodeError { dec.posStackPush() defer dec.posStackPop() - t, err := dec.peekRuneType() - if err != nil { - return err - } - return dec.withLimitingScanner(func(l io.RuneScanner) *DecodeError { - if err := decodeVal(l); err != nil { - // TODO: Find a better Go type to use than `nil`. - return dec.newTypeError(t.JSONType(), nil, err) - } - if _, _, err := l.ReadRune(); err != io.EOF { - // TODO: Find a better Go type to use than `nil`. - return dec.newTypeError(t.JSONType(), nil, fmt.Errorf("did not consume entire %s", t.JSONType())) - } - return nil - }) + // TODO: Find a better Go type to use than `nil`. + return dec.withLimitingScanner(nil, decodeVal) }); err != nil { return err } + if dec.typeErr != nil { + return dec.typeErr + } return nil } func (dec *Decoder) decodeObject(gTyp reflect.Type, decodeKey, decodeVal func() *DecodeError) *DecodeError { - if err := dec.expectRuneType('{', jsonparse.RuneTypeObjectBeg, gTyp); err != nil { + if _, t, err := dec.readRune(); err != nil { return err + } else if t != jsonparse.RuneTypeObjectBeg { + dec.newTypeError(t.JSONType(), gTyp, nil) + dec.unreadRune() + return dec.scan(fastio.Discard) } _, t, err := dec.readRune() if err != nil { @@ -1042,7 +1053,7 @@ func (dec *Decoder) decodeObject(gTyp reflect.Type, decodeKey, decodeVal func() if err := decodeKey(); err != nil { return err } - if err := dec.expectRune(':', jsonparse.RuneTypeObjectColon); err != nil { + if err := dec.expectRuneOrPanic(':', jsonparse.RuneTypeObjectColon); err != nil { return err } if err := decodeVal(); err != nil { @@ -1054,7 +1065,7 @@ func (dec *Decoder) decodeObject(gTyp reflect.Type, decodeKey, decodeVal func() } switch t { case jsonparse.RuneTypeObjectComma: - if err := dec.expectRune('"', jsonparse.RuneTypeStringBeg); err != nil { + if err := dec.expectRuneOrPanic('"', jsonparse.RuneTypeStringBeg); err != nil { return err } goto decodeMember @@ -1083,35 +1094,34 @@ func DecodeArray(r io.RuneScanner, decodeMember func(r io.RuneScanner) error) er } else { dec = NewDecoder(r) } + if dec.typeErr != nil { + oldTypeErr := dec.typeErr + dec.typeErr = nil + defer func() { dec.typeErr = oldTypeErr }() + } dec.posStackPush() defer dec.posStackPop() if err := dec.decodeArray(nil, func() *DecodeError { dec.posStackPush() defer dec.posStackPop() - t, err := dec.peekRuneType() - if err != nil { - return err - } - return dec.withLimitingScanner(func(l io.RuneScanner) *DecodeError { - if err := decodeMember(l); err != nil { - // TODO: Find a better Go type to use than `nil`. - return dec.newTypeError(t.JSONType(), nil, err) - } - if _, _, err := l.ReadRune(); err != io.EOF { - // TODO: Find a better Go type to use than `nil`. - return dec.newTypeError(t.JSONType(), nil, fmt.Errorf("did not consume entire %s", t.JSONType())) - } - return nil - }) + // TODO: Find a better Go type to use than `nil`. + return dec.withLimitingScanner(nil, decodeMember) }); err != nil { return err } + if dec.typeErr != nil { + return dec.typeErr + } return nil } func (dec *Decoder) decodeArray(gTyp reflect.Type, decodeMember func() *DecodeError) *DecodeError { - if err := dec.expectRuneType('[', jsonparse.RuneTypeArrayBeg, gTyp); err != nil { + if _, t, err := dec.readRune(); err != nil { return err + } else if t != jsonparse.RuneTypeArrayBeg { + dec.newTypeError(t.JSONType(), gTyp, nil) + dec.unreadRune() + return dec.scan(fastio.Discard) } _, t, err := dec.readRune() if err != nil { @@ -1142,8 +1152,12 @@ func (dec *Decoder) decodeArray(gTyp reflect.Type, decodeMember func() *DecodeEr } func (dec *Decoder) decodeString(gTyp reflect.Type, out fastio.RuneWriter) *DecodeError { - if err := dec.expectRuneType('"', jsonparse.RuneTypeStringBeg, gTyp); err != nil { + if _, t, err := dec.readRune(); err != nil { return err + } else if t != jsonparse.RuneTypeStringBeg { + dec.newTypeError(t.JSONType(), gTyp, nil) + dec.unreadRune() + return dec.scan(fastio.Discard) } var uhex [3]byte for { @@ -1195,7 +1209,7 @@ func (dec *Decoder) decodeString(gTyp reflect.Type, out fastio.RuneWriter) *Deco _, _ = out.WriteRune(utf8.RuneError) break } - if err := dec.expectRune('\\', jsonparse.RuneTypeStringEsc); err != nil { + if err := dec.expectRuneOrPanic('\\', jsonparse.RuneTypeStringEsc); err != nil { return err } t, err = dec.peekRuneType() @@ -1206,7 +1220,7 @@ func (dec *Decoder) decodeString(gTyp reflect.Type, out fastio.RuneWriter) *Deco _, _ = out.WriteRune(utf8.RuneError) break } - if err := dec.expectRune('u', jsonparse.RuneTypeStringEscU); err != nil { + if err := dec.expectRuneOrPanic('u', jsonparse.RuneTypeStringEscU); err != nil { return err } @@ -1255,46 +1269,47 @@ func (dec *Decoder) decodeBool(gTyp reflect.Type) (bool, *DecodeError) { } switch c { case 't': - if err := dec.expectRune('r', jsonparse.RuneTypeTrueR); err != nil { + if err := dec.expectRuneOrPanic('r', jsonparse.RuneTypeTrueR); err != nil { return false, err } - if err := dec.expectRune('u', jsonparse.RuneTypeTrueU); err != nil { + if err := dec.expectRuneOrPanic('u', jsonparse.RuneTypeTrueU); err != nil { return false, err } - if err := dec.expectRune('e', jsonparse.RuneTypeTrueE); err != nil { + if err := dec.expectRuneOrPanic('e', jsonparse.RuneTypeTrueE); err != nil { return false, err } return true, nil case 'f': - if err := dec.expectRune('a', jsonparse.RuneTypeFalseA); err != nil { + if err := dec.expectRuneOrPanic('a', jsonparse.RuneTypeFalseA); err != nil { return false, err } - if err := dec.expectRune('l', jsonparse.RuneTypeFalseL); err != nil { + if err := dec.expectRuneOrPanic('l', jsonparse.RuneTypeFalseL); err != nil { return false, err } - if err := dec.expectRune('s', jsonparse.RuneTypeFalseS); err != nil { + if err := dec.expectRuneOrPanic('s', jsonparse.RuneTypeFalseS); err != nil { return false, err } - if err := dec.expectRune('e', jsonparse.RuneTypeFalseE); err != nil { + if err := dec.expectRuneOrPanic('e', jsonparse.RuneTypeFalseE); err != nil { return false, err } return false, nil default: - return false, dec.newTypeError(t.JSONType(), gTyp, nil) + dec.newTypeError(t.JSONType(), gTyp, nil) + return false, nil } } func (dec *Decoder) decodeNull() *DecodeError { - if err := dec.expectRune('n', jsonparse.RuneTypeNullN); err != nil { + if err := dec.expectRuneOrPanic('n', jsonparse.RuneTypeNullN); err != nil { return err } - if err := dec.expectRune('u', jsonparse.RuneTypeNullU); err != nil { + if err := dec.expectRuneOrPanic('u', jsonparse.RuneTypeNullU); err != nil { return err } - if err := dec.expectRune('l', jsonparse.RuneTypeNullL1); err != nil { + if err := dec.expectRuneOrPanic('l', jsonparse.RuneTypeNullL1); err != nil { return err } - if err := dec.expectRune('l', jsonparse.RuneTypeNullL2); err != nil { + if err := dec.expectRuneOrPanic('l', jsonparse.RuneTypeNullL2); err != nil { return err } return nil diff --git a/decode_test.go b/decode_test.go index 456f363..c224f3a 100644 --- a/decode_test.go +++ b/decode_test.go @@ -48,3 +48,50 @@ func TestDecodeGrowing(t *testing.T) { assert.NoError(t, dec.Decode(&x)) assert.ErrorIs(t, dec.Decode(&x), io.EOF) } + +type testAry []int + +func (a *testAry) DecodeJSON(r io.RuneScanner) error { + return DecodeArray(r, func(r io.RuneScanner) error { + var x int + if err := NewDecoder(r).Decode(&x); err != nil { + return err + } + *a = append(*a, x) + return nil + }) +} + +type testObj map[string]int + +func (o *testObj) DecodeJSON(r io.RuneScanner) error { + *o = make(testObj) + var key string + return DecodeObject(r, + func(r io.RuneScanner) error { + return NewDecoder(r).Decode(&key) + }, + func(r io.RuneScanner) error { + var val int + if err := NewDecoder(r).Decode(&val); err != nil { + return err + } + (*o)[key] = val + return nil + }, + ) +} + +func TestDecodeTypeError(t *testing.T) { + t.Parallel() + type outType struct { + First int + Second testAry + Third testObj + } + var out outType + err := NewDecoder(strings.NewReader(`{"First": 1.2, "Second": [3], "Third": {"a":4}}`)).Decode(&out) + assert.EqualError(t, err, + `json: v["First"]: cannot decode JSON number 1.2 at input byte 9 into Go int: strconv.ParseInt: parsing "1.2": invalid syntax`) + assert.Equal(t, outType{First: 0, Second: testAry{3}, Third: testObj{"a": 4}}, out) +} |