From d01fa91dcdfd428fb4b1c46b3961a1497c7a1102 Mon Sep 17 00:00:00 2001 From: Luke Shumaker Date: Fri, 17 Feb 2023 19:21:37 -0700 Subject: decode: Don't bail on type errors --- ReleaseNotes.md | 14 +++++ compat/json/compat_test.go | 8 +++ decode.go | 132 ++++++++++++++++++++++++++++++--------------- decode_test.go | 47 ++++++++++++++++ 4 files changed, 157 insertions(+), 44 deletions(-) diff --git a/ReleaseNotes.md b/ReleaseNotes.md index 8f3be1a..20bcd65 100644 --- a/ReleaseNotes.md +++ b/ReleaseNotes.md @@ -1,3 +1,17 @@ +# v0.3.8 (TBD) + + Theme: Fixes from fuzzing (part 2/?) + + User-facing changes: + + - Change: Decoder: No longer bails when a type error + (`DecodeTypeError`) is encountered. The part of the output value + with the type error is either unmodified (if already existing) or + set to nil/zero (if not already existing), and decoding + continues. If no later fatal error (syntax, I/O) is encountered, + then the first type error encountered is returned. This is + consistent with the behavior of `encoding/json`. + # v0.3.7 (2023-02-20) Theme: Fixes from fuzzing (part 1?) diff --git a/compat/json/compat_test.go b/compat/json/compat_test.go index df9d387..098ac85 100644 --- a/compat/json/compat_test.go +++ b/compat/json/compat_test.go @@ -181,6 +181,10 @@ func TestCompatUnmarshal(t *testing.T) { "two-objs": {In: `{} {}`, ExpOut: nil, ExpErr: `invalid character '{' after top-level value`}, "two-numbers1": {In: `00`, ExpOut: nil, ExpErr: `invalid character '0' after top-level value`}, "two-numbers2": {In: `1 2`, ExpOut: nil, ExpErr: `invalid character '2' after top-level value`}, + // 2e308 is slightly more than math.MaxFloat64 (~1.79e308) + "obj-overflow": {In: `{"foo":"bar", "baz":2e308, "qux": "orb"}`, ExpOut: map[string]any{"foo": "bar", "baz": nil, "qux": "orb"}, ExpErr: `json: cannot unmarshal number 2e308 into Go value of type float64`}, + "ary-overflow": {In: `["foo",2e308,"bar",3e308]`, ExpOut: []any{"foo", nil, "bar", nil}, ExpErr: `json: cannot unmarshal number 2e308 into Go value of type float64`}, + "existing-overflow": {In: `2e308`, InPtr: func() any { x := 4; return &x }(), ExpOut: 4, ExpErr: `json: cannot unmarshal number 2e308 into Go value of type int`}, } for tcName, tc := range testcases { tc := tc @@ -219,6 +223,10 @@ func TestCompatDecode(t *testing.T) { "two-objs": {In: `{} {}`, ExpOut: map[string]any{}}, "two-numbers1": {In: `00`, ExpOut: float64(0)}, "two-numbers2": {In: `1 2`, ExpOut: float64(1)}, + // 2e308 is slightly more than math.MaxFloat64 (~1.79e308) + "obj-overflow": {In: `{"foo":"bar", "baz":2e308, "qux": "orb"}`, ExpOut: map[string]any{"foo": "bar", "baz": nil, "qux": "orb"}, ExpErr: `json: cannot unmarshal number 2e308 into Go value of type float64`}, + "ary-overflow": {In: `["foo",2e308,"bar",3e308]`, ExpOut: []any{"foo", nil, "bar", nil}, ExpErr: `json: cannot unmarshal number 2e308 into Go value of type float64`}, + "existing-overflow": {In: `2e308`, InPtr: func() any { x := 4; return &x }(), ExpOut: 4, ExpErr: `json: cannot unmarshal number 2e308 into Go value of type int`}, } for tcName, tc := range testcases { tc := tc diff --git a/decode.go b/decode.go index 126b904..c987d79 100644 --- a/decode.go +++ b/decode.go @@ -96,6 +96,7 @@ type Decoder struct { // state posStack []int64 structStack []decodeStackItem + typeErr *DecodeError } const maxNestingDepth = 10000 @@ -241,19 +242,26 @@ func (dec *Decoder) Decode(ptr any) (err error) { } } + dec.typeErr = nil dec.io.Reset() dec.io.PushReadBarrier() if err := dec.decode(ptrVal.Elem(), false); err != nil { return err } dec.io.PopReadBarrier() + if dec.typeErr != nil { + return dec.typeErr + } return nil } // io helpers ////////////////////////////////////////////////////////////////////////////////////// -func (dec *Decoder) newTypeError(jTyp string, gTyp reflect.Type, err error) *DecodeError { - return &DecodeError{ +func (dec *Decoder) newTypeError(jTyp string, gTyp reflect.Type, err error) { + if dec.typeErr != nil { + return + } + dec.typeErr = &DecodeError{ Field: dec.structStackStr(), FieldParent: dec.structStackParent(), FieldName: dec.structStackName(), @@ -313,17 +321,6 @@ func (dec *Decoder) expectRuneOrPanic(ec rune, et jsonparse.RuneType) *DecodeErr return nil } -func (dec *Decoder) expectRuneType(ec rune, et jsonparse.RuneType, gt reflect.Type) *DecodeError { - ac, at, err := dec.readRune() - if err != nil { - return err - } - if ac != ec || at != et { - return dec.newTypeError(at.JSONType(), gt, nil) - } - return nil -} - type decRuneScanner struct { dec *Decoder eof bool @@ -373,10 +370,13 @@ func (dec *Decoder) withLimitingScanner(gTyp reflect.Type, fn func(io.RuneScanne }() l := &decRuneScanner{dec: dec} if err := fn(l); err != nil { - return dec.newTypeError(t.JSONType(), gTyp, err) + dec.newTypeError(t.JSONType(), gTyp, err) } if _, _, err := l.ReadRune(); err != io.EOF { - return dec.newTypeError(t.JSONType(), gTyp, fmt.Errorf("did not consume entire %s", t.JSONType())) + dec.newTypeError(t.JSONType(), gTyp, fmt.Errorf("did not consume entire %s", t.JSONType())) + for err != io.EOF { + _, _, err = l.ReadRune() + } } return nil } @@ -417,7 +417,7 @@ func (dec *Decoder) decode(val reflect.Value, nullOK bool) *DecodeError { return err } if err := val.Addr().Interface().(*RawMessage).UnmarshalJSON(buf.Bytes()); err != nil { - return dec.newTypeError(t.JSONType(), reflect.PointerTo(typ), err) + dec.newTypeError(t.JSONType(), reflect.PointerTo(typ), err) } case val.CanAddr() && reflect.PointerTo(typ).Implements(decodableType): obj := val.Addr().Interface().(Decodable) @@ -433,7 +433,7 @@ func (dec *Decoder) decode(val reflect.Value, nullOK bool) *DecodeError { } obj := val.Addr().Interface().(jsonUnmarshaler) if err := obj.UnmarshalJSON(buf.Bytes()); err != nil { - return dec.newTypeError(t.JSONType(), reflect.PointerTo(typ), err) + dec.newTypeError(t.JSONType(), reflect.PointerTo(typ), err) } case val.CanAddr() && reflect.PointerTo(typ).Implements(textUnmarshalerType): if ok, err := dec.maybeDecodeNull(nullOK); ok { @@ -445,7 +445,7 @@ func (dec *Decoder) decode(val reflect.Value, nullOK bool) *DecodeError { } obj := val.Addr().Interface().(encoding.TextUnmarshaler) if err := obj.UnmarshalText(buf.Bytes()); err != nil { - return dec.newTypeError("string", reflect.PointerTo(typ), err) + dec.newTypeError("string", reflect.PointerTo(typ), err) } default: switch kind := typ.Kind(); kind { @@ -465,7 +465,8 @@ func (dec *Decoder) decode(val reflect.Value, nullOK bool) *DecodeError { if t, err := dec.peekRuneType(); err != nil { return err } else if !t.IsNumber() { - return dec.newTypeError(t.JSONType(), typ, nil) + dec.newTypeError(t.JSONType(), typ, nil) + return dec.scan(fastio.Discard) } var buf strings.Builder if err := dec.scan(&buf); err != nil { @@ -473,7 +474,8 @@ func (dec *Decoder) decode(val reflect.Value, nullOK bool) *DecodeError { } n, err := strconv.ParseInt(buf.String(), 10, kind2bits[kind]) if err != nil { - return dec.newTypeError("number "+buf.String(), typ, err) + dec.newTypeError("number "+buf.String(), typ, err) + return nil } val.SetInt(n) case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: @@ -483,7 +485,8 @@ func (dec *Decoder) decode(val reflect.Value, nullOK bool) *DecodeError { if t, err := dec.peekRuneType(); err != nil { return err } else if !t.IsNumber() { - return dec.newTypeError(t.JSONType(), typ, nil) + dec.newTypeError(t.JSONType(), typ, nil) + return dec.scan(fastio.Discard) } var buf strings.Builder if err := dec.scan(&buf); err != nil { @@ -491,7 +494,8 @@ func (dec *Decoder) decode(val reflect.Value, nullOK bool) *DecodeError { } n, err := strconv.ParseUint(buf.String(), 10, kind2bits[kind]) if err != nil { - return dec.newTypeError("number "+buf.String(), typ, err) + dec.newTypeError("number "+buf.String(), typ, err) + return nil } val.SetUint(n) case reflect.Float32, reflect.Float64: @@ -501,7 +505,8 @@ func (dec *Decoder) decode(val reflect.Value, nullOK bool) *DecodeError { if t, err := dec.peekRuneType(); err != nil { return err } else if !t.IsNumber() { - return dec.newTypeError(t.JSONType(), typ, nil) + dec.newTypeError(t.JSONType(), typ, nil) + return dec.scan(fastio.Discard) } var buf strings.Builder if err := dec.scan(&buf); err != nil { @@ -509,7 +514,8 @@ func (dec *Decoder) decode(val reflect.Value, nullOK bool) *DecodeError { } n, err := strconv.ParseFloat(buf.String(), kind2bits[kind]) if err != nil { - return dec.newTypeError("number "+buf.String(), typ, err) + dec.newTypeError("number "+buf.String(), typ, err) + return nil } val.SetFloat(n) case reflect.String: @@ -526,9 +532,10 @@ func (dec *Decoder) decode(val reflect.Value, nullOK bool) *DecodeError { return err } if !t.IsNumber() { - return dec.newTypeError(t.JSONType(), typ, + dec.newTypeError(t.JSONType(), typ, fmt.Errorf("json: invalid number literal, trying to unmarshal %q into Number", buf.String())) + return nil } val.SetString(buf.String()) } else { @@ -543,7 +550,8 @@ func (dec *Decoder) decode(val reflect.Value, nullOK bool) *DecodeError { return err } if typ.NumMethod() > 0 { - return dec.newTypeError(t.JSONType(), typ, ErrDecodeNonEmptyInterface) + dec.newTypeError(t.JSONType(), typ, ErrDecodeNonEmptyInterface) + return dec.scan(fastio.Discard) } // If the interface stores a pointer, try to use the type information of the pointer. if !val.IsNil() && val.Elem().Kind() == reflect.Pointer { @@ -587,7 +595,9 @@ func (dec *Decoder) decode(val reflect.Value, nullOK bool) *DecodeError { if err != nil { return err } - val.Set(reflect.ValueOf(v)) + if v != nil { + val.Set(reflect.ValueOf(v)) + } } case reflect.Struct: if ok, err := dec.maybeDecodeNull(nullOK); ok { @@ -618,7 +628,7 @@ func (dec *Decoder) decode(val reflect.Value, nullOK bool) *DecodeError { } if !ok { if dec.disallowUnknownFields { - return dec.newTypeError("", typ, fmt.Errorf("json: unknown field %q", name)) + dec.newTypeError("", typ, fmt.Errorf("json: unknown field %q", name)) } return dec.scan(fastio.Discard) } @@ -627,9 +637,10 @@ func (dec *Decoder) decode(val reflect.Value, nullOK bool) *DecodeError { for _, idx := range field.Path { if fVal.Kind() == reflect.Pointer { if fVal.IsNil() && !fVal.CanSet() { // https://golang.org/issue/21357 - return dec.newTypeError("", fVal.Type().Elem(), + dec.newTypeError("", fVal.Type().Elem(), fmt.Errorf("json: cannot set embedded pointer to unexported struct: %v", fVal.Type().Elem())) + return dec.scan(fastio.Discard) } t, err := dec.peekRuneType() if err != nil { @@ -670,15 +681,16 @@ func (dec *Decoder) decode(val reflect.Value, nullOK bool) *DecodeError { } if err := NewDecoder(bytes.NewReader(buf.Bytes())).Decode(fVal.Addr().Interface()); err != nil { if str := buf.String(); str != "null" { - return dec.newTypeError("", fVal.Type(), + dec.newTypeError("", fVal.Type(), fmt.Errorf("json: invalid use of ,string struct tag, trying to unmarshal %q into %v", str, fVal.Type())) } } default: - return dec.newTypeError(t.JSONType(), fVal.Type(), + dec.newTypeError(t.JSONType(), fVal.Type(), fmt.Errorf("json: invalid use of ,string struct tag, trying to unmarshal unquoted value into %v", fVal.Type())) + return dec.scan(fastio.Discard) } return nil } else { @@ -715,7 +727,7 @@ func (dec *Decoder) decode(val reflect.Value, nullOK bool) *DecodeError { case reflect.PointerTo(nameValTyp).Implements(textUnmarshalerType): obj := nameValPtr.Interface().(encoding.TextUnmarshaler) if err := obj.UnmarshalText(nameBuf.Bytes()); err != nil { - return dec.newTypeError("string", reflect.PointerTo(nameValTyp), err) + dec.newTypeError("string", reflect.PointerTo(nameValTyp), err) } default: switch nameValTyp.Kind() { @@ -724,17 +736,19 @@ func (dec *Decoder) decode(val reflect.Value, nullOK bool) *DecodeError { case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: n, err := strconv.ParseInt(nameBuf.String(), 10, kind2bits[nameValTyp.Kind()]) if err != nil { - return dec.newTypeError("number "+nameBuf.String(), nameValTyp, err) + dec.newTypeError("number "+nameBuf.String(), nameValTyp, err) + return nil } nameValPtr.Elem().SetInt(n) case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: n, err := strconv.ParseUint(nameBuf.String(), 10, kind2bits[nameValTyp.Kind()]) if err != nil { - return dec.newTypeError("number "+nameBuf.String(), nameValTyp, err) + dec.newTypeError("number "+nameBuf.String(), nameValTyp, err) + return nil } nameValPtr.Elem().SetUint(n) default: - return dec.newTypeError("object", typ, &DecodeArgumentError{Type: nameValTyp}) + dec.newTypeError("object", typ, &DecodeArgumentError{Type: nameValTyp}) } } return nil @@ -753,7 +767,7 @@ func (dec *Decoder) decode(val reflect.Value, nullOK bool) *DecodeError { return nil }) default: - return dec.newTypeError(t.JSONType(), typ, nil) + dec.newTypeError(t.JSONType(), typ, nil) } case reflect.Slice: t, err := dec.peekRuneType() @@ -792,7 +806,7 @@ func (dec *Decoder) decode(val reflect.Value, nullOK bool) *DecodeError { } } default: - return dec.newTypeError(t.JSONType(), typ, nil) + dec.newTypeError(t.JSONType(), typ, nil) } default: switch t { @@ -823,7 +837,7 @@ func (dec *Decoder) decode(val reflect.Value, nullOK bool) *DecodeError { return nil }) default: - return dec.newTypeError(t.JSONType(), typ, nil) + dec.newTypeError(t.JSONType(), typ, nil) } } case reflect.Array: @@ -874,7 +888,7 @@ func (dec *Decoder) decode(val reflect.Value, nullOK bool) *DecodeError { return dec.decode(val.Elem(), false) } default: - return dec.newTypeError("", typ, fmt.Errorf("unsupported type (kind=%v)", typ.Kind())) + dec.newTypeError("", typ, fmt.Errorf("unsupported type (kind=%v)", typ.Kind())) } } return nil @@ -962,7 +976,8 @@ func (dec *Decoder) decodeAny() (any, *DecodeError) { } f64, err := num.Float64() if err != nil { - return nil, dec.newTypeError("number "+buf.String(), float64Type, err) + dec.newTypeError("number "+buf.String(), float64Type, err) + return nil, nil } return f64, nil case jsonparse.RuneTypeTrueT, jsonparse.RuneTypeFalseF: @@ -989,6 +1004,11 @@ func DecodeObject(r io.RuneScanner, decodeKey, decodeVal func(io.RuneScanner) er } else { dec = NewDecoder(r) } + if dec.typeErr != nil { + oldTypeErr := dec.typeErr + dec.typeErr = nil + defer func() { dec.typeErr = oldTypeErr }() + } dec.posStackPush() defer dec.posStackPop() if err := dec.decodeObject(nil, @@ -1006,12 +1026,19 @@ func DecodeObject(r io.RuneScanner, decodeKey, decodeVal func(io.RuneScanner) er }); err != nil { return err } + if dec.typeErr != nil { + return dec.typeErr + } return nil } func (dec *Decoder) decodeObject(gTyp reflect.Type, decodeKey, decodeVal func() *DecodeError) *DecodeError { - if err := dec.expectRuneType('{', jsonparse.RuneTypeObjectBeg, gTyp); err != nil { + if _, t, err := dec.readRune(); err != nil { return err + } else if t != jsonparse.RuneTypeObjectBeg { + dec.newTypeError(t.JSONType(), gTyp, nil) + dec.unreadRune() + return dec.scan(fastio.Discard) } _, t, err := dec.readRune() if err != nil { @@ -1067,6 +1094,11 @@ func DecodeArray(r io.RuneScanner, decodeMember func(r io.RuneScanner) error) er } else { dec = NewDecoder(r) } + if dec.typeErr != nil { + oldTypeErr := dec.typeErr + dec.typeErr = nil + defer func() { dec.typeErr = oldTypeErr }() + } dec.posStackPush() defer dec.posStackPop() if err := dec.decodeArray(nil, func() *DecodeError { @@ -1077,12 +1109,19 @@ func DecodeArray(r io.RuneScanner, decodeMember func(r io.RuneScanner) error) er }); err != nil { return err } + if dec.typeErr != nil { + return dec.typeErr + } return nil } func (dec *Decoder) decodeArray(gTyp reflect.Type, decodeMember func() *DecodeError) *DecodeError { - if err := dec.expectRuneType('[', jsonparse.RuneTypeArrayBeg, gTyp); err != nil { + if _, t, err := dec.readRune(); err != nil { return err + } else if t != jsonparse.RuneTypeArrayBeg { + dec.newTypeError(t.JSONType(), gTyp, nil) + dec.unreadRune() + return dec.scan(fastio.Discard) } _, t, err := dec.readRune() if err != nil { @@ -1113,8 +1152,12 @@ func (dec *Decoder) decodeArray(gTyp reflect.Type, decodeMember func() *DecodeEr } func (dec *Decoder) decodeString(gTyp reflect.Type, out fastio.RuneWriter) *DecodeError { - if err := dec.expectRuneType('"', jsonparse.RuneTypeStringBeg, gTyp); err != nil { + if _, t, err := dec.readRune(); err != nil { return err + } else if t != jsonparse.RuneTypeStringBeg { + dec.newTypeError(t.JSONType(), gTyp, nil) + dec.unreadRune() + return dec.scan(fastio.Discard) } var uhex [3]byte for { @@ -1251,7 +1294,8 @@ func (dec *Decoder) decodeBool(gTyp reflect.Type) (bool, *DecodeError) { } return false, nil default: - return false, dec.newTypeError(t.JSONType(), gTyp, nil) + dec.newTypeError(t.JSONType(), gTyp, nil) + return false, nil } } diff --git a/decode_test.go b/decode_test.go index 456f363..c224f3a 100644 --- a/decode_test.go +++ b/decode_test.go @@ -48,3 +48,50 @@ func TestDecodeGrowing(t *testing.T) { assert.NoError(t, dec.Decode(&x)) assert.ErrorIs(t, dec.Decode(&x), io.EOF) } + +type testAry []int + +func (a *testAry) DecodeJSON(r io.RuneScanner) error { + return DecodeArray(r, func(r io.RuneScanner) error { + var x int + if err := NewDecoder(r).Decode(&x); err != nil { + return err + } + *a = append(*a, x) + return nil + }) +} + +type testObj map[string]int + +func (o *testObj) DecodeJSON(r io.RuneScanner) error { + *o = make(testObj) + var key string + return DecodeObject(r, + func(r io.RuneScanner) error { + return NewDecoder(r).Decode(&key) + }, + func(r io.RuneScanner) error { + var val int + if err := NewDecoder(r).Decode(&val); err != nil { + return err + } + (*o)[key] = val + return nil + }, + ) +} + +func TestDecodeTypeError(t *testing.T) { + t.Parallel() + type outType struct { + First int + Second testAry + Third testObj + } + var out outType + err := NewDecoder(strings.NewReader(`{"First": 1.2, "Second": [3], "Third": {"a":4}}`)).Decode(&out) + assert.EqualError(t, err, + `json: v["First"]: cannot decode JSON number 1.2 at input byte 9 into Go int: strconv.ParseInt: parsing "1.2": invalid syntax`) + assert.Equal(t, outType{First: 0, Second: testAry{3}, Third: testObj{"a": 4}}, out) +} -- cgit v1.2.3-2-g168b