summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorLuke Shumaker <lukeshu@lukeshu.com>2023-02-17 19:21:37 -0700
committerLuke Shumaker <lukeshu@lukeshu.com>2023-02-25 00:47:52 -0700
commitd01fa91dcdfd428fb4b1c46b3961a1497c7a1102 (patch)
tree0724da76664e893e7ecfad2d8ce4ea1b05918598
parentf369aff688697a881833d86c13b18156e8376f08 (diff)
decode: Don't bail on type errors
-rw-r--r--ReleaseNotes.md14
-rw-r--r--compat/json/compat_test.go8
-rw-r--r--decode.go132
-rw-r--r--decode_test.go47
4 files changed, 157 insertions, 44 deletions
diff --git a/ReleaseNotes.md b/ReleaseNotes.md
index 8f3be1a..20bcd65 100644
--- a/ReleaseNotes.md
+++ b/ReleaseNotes.md
@@ -1,3 +1,17 @@
+# v0.3.8 (TBD)
+
+ Theme: Fixes from fuzzing (part 2/?)
+
+ User-facing changes:
+
+ - Change: Decoder: No longer bails when a type error
+ (`DecodeTypeError`) is encountered. The part of the output value
+ with the type error is either unmodified (if already existing) or
+ set to nil/zero (if not already existing), and decoding
+ continues. If no later fatal error (syntax, I/O) is encountered,
+ then the first type error encountered is returned. This is
+ consistent with the behavior of `encoding/json`.
+
# v0.3.7 (2023-02-20)
Theme: Fixes from fuzzing (part 1?)
diff --git a/compat/json/compat_test.go b/compat/json/compat_test.go
index df9d387..098ac85 100644
--- a/compat/json/compat_test.go
+++ b/compat/json/compat_test.go
@@ -181,6 +181,10 @@ func TestCompatUnmarshal(t *testing.T) {
"two-objs": {In: `{} {}`, ExpOut: nil, ExpErr: `invalid character '{' after top-level value`},
"two-numbers1": {In: `00`, ExpOut: nil, ExpErr: `invalid character '0' after top-level value`},
"two-numbers2": {In: `1 2`, ExpOut: nil, ExpErr: `invalid character '2' after top-level value`},
+ // 2e308 is slightly more than math.MaxFloat64 (~1.79e308)
+ "obj-overflow": {In: `{"foo":"bar", "baz":2e308, "qux": "orb"}`, ExpOut: map[string]any{"foo": "bar", "baz": nil, "qux": "orb"}, ExpErr: `json: cannot unmarshal number 2e308 into Go value of type float64`},
+ "ary-overflow": {In: `["foo",2e308,"bar",3e308]`, ExpOut: []any{"foo", nil, "bar", nil}, ExpErr: `json: cannot unmarshal number 2e308 into Go value of type float64`},
+ "existing-overflow": {In: `2e308`, InPtr: func() any { x := 4; return &x }(), ExpOut: 4, ExpErr: `json: cannot unmarshal number 2e308 into Go value of type int`},
}
for tcName, tc := range testcases {
tc := tc
@@ -219,6 +223,10 @@ func TestCompatDecode(t *testing.T) {
"two-objs": {In: `{} {}`, ExpOut: map[string]any{}},
"two-numbers1": {In: `00`, ExpOut: float64(0)},
"two-numbers2": {In: `1 2`, ExpOut: float64(1)},
+ // 2e308 is slightly more than math.MaxFloat64 (~1.79e308)
+ "obj-overflow": {In: `{"foo":"bar", "baz":2e308, "qux": "orb"}`, ExpOut: map[string]any{"foo": "bar", "baz": nil, "qux": "orb"}, ExpErr: `json: cannot unmarshal number 2e308 into Go value of type float64`},
+ "ary-overflow": {In: `["foo",2e308,"bar",3e308]`, ExpOut: []any{"foo", nil, "bar", nil}, ExpErr: `json: cannot unmarshal number 2e308 into Go value of type float64`},
+ "existing-overflow": {In: `2e308`, InPtr: func() any { x := 4; return &x }(), ExpOut: 4, ExpErr: `json: cannot unmarshal number 2e308 into Go value of type int`},
}
for tcName, tc := range testcases {
tc := tc
diff --git a/decode.go b/decode.go
index 126b904..c987d79 100644
--- a/decode.go
+++ b/decode.go
@@ -96,6 +96,7 @@ type Decoder struct {
// state
posStack []int64
structStack []decodeStackItem
+ typeErr *DecodeError
}
const maxNestingDepth = 10000
@@ -241,19 +242,26 @@ func (dec *Decoder) Decode(ptr any) (err error) {
}
}
+ dec.typeErr = nil
dec.io.Reset()
dec.io.PushReadBarrier()
if err := dec.decode(ptrVal.Elem(), false); err != nil {
return err
}
dec.io.PopReadBarrier()
+ if dec.typeErr != nil {
+ return dec.typeErr
+ }
return nil
}
// io helpers //////////////////////////////////////////////////////////////////////////////////////
-func (dec *Decoder) newTypeError(jTyp string, gTyp reflect.Type, err error) *DecodeError {
- return &DecodeError{
+func (dec *Decoder) newTypeError(jTyp string, gTyp reflect.Type, err error) {
+ if dec.typeErr != nil {
+ return
+ }
+ dec.typeErr = &DecodeError{
Field: dec.structStackStr(),
FieldParent: dec.structStackParent(),
FieldName: dec.structStackName(),
@@ -313,17 +321,6 @@ func (dec *Decoder) expectRuneOrPanic(ec rune, et jsonparse.RuneType) *DecodeErr
return nil
}
-func (dec *Decoder) expectRuneType(ec rune, et jsonparse.RuneType, gt reflect.Type) *DecodeError {
- ac, at, err := dec.readRune()
- if err != nil {
- return err
- }
- if ac != ec || at != et {
- return dec.newTypeError(at.JSONType(), gt, nil)
- }
- return nil
-}
-
type decRuneScanner struct {
dec *Decoder
eof bool
@@ -373,10 +370,13 @@ func (dec *Decoder) withLimitingScanner(gTyp reflect.Type, fn func(io.RuneScanne
}()
l := &decRuneScanner{dec: dec}
if err := fn(l); err != nil {
- return dec.newTypeError(t.JSONType(), gTyp, err)
+ dec.newTypeError(t.JSONType(), gTyp, err)
}
if _, _, err := l.ReadRune(); err != io.EOF {
- return dec.newTypeError(t.JSONType(), gTyp, fmt.Errorf("did not consume entire %s", t.JSONType()))
+ dec.newTypeError(t.JSONType(), gTyp, fmt.Errorf("did not consume entire %s", t.JSONType()))
+ for err != io.EOF {
+ _, _, err = l.ReadRune()
+ }
}
return nil
}
@@ -417,7 +417,7 @@ func (dec *Decoder) decode(val reflect.Value, nullOK bool) *DecodeError {
return err
}
if err := val.Addr().Interface().(*RawMessage).UnmarshalJSON(buf.Bytes()); err != nil {
- return dec.newTypeError(t.JSONType(), reflect.PointerTo(typ), err)
+ dec.newTypeError(t.JSONType(), reflect.PointerTo(typ), err)
}
case val.CanAddr() && reflect.PointerTo(typ).Implements(decodableType):
obj := val.Addr().Interface().(Decodable)
@@ -433,7 +433,7 @@ func (dec *Decoder) decode(val reflect.Value, nullOK bool) *DecodeError {
}
obj := val.Addr().Interface().(jsonUnmarshaler)
if err := obj.UnmarshalJSON(buf.Bytes()); err != nil {
- return dec.newTypeError(t.JSONType(), reflect.PointerTo(typ), err)
+ dec.newTypeError(t.JSONType(), reflect.PointerTo(typ), err)
}
case val.CanAddr() && reflect.PointerTo(typ).Implements(textUnmarshalerType):
if ok, err := dec.maybeDecodeNull(nullOK); ok {
@@ -445,7 +445,7 @@ func (dec *Decoder) decode(val reflect.Value, nullOK bool) *DecodeError {
}
obj := val.Addr().Interface().(encoding.TextUnmarshaler)
if err := obj.UnmarshalText(buf.Bytes()); err != nil {
- return dec.newTypeError("string", reflect.PointerTo(typ), err)
+ dec.newTypeError("string", reflect.PointerTo(typ), err)
}
default:
switch kind := typ.Kind(); kind {
@@ -465,7 +465,8 @@ func (dec *Decoder) decode(val reflect.Value, nullOK bool) *DecodeError {
if t, err := dec.peekRuneType(); err != nil {
return err
} else if !t.IsNumber() {
- return dec.newTypeError(t.JSONType(), typ, nil)
+ dec.newTypeError(t.JSONType(), typ, nil)
+ return dec.scan(fastio.Discard)
}
var buf strings.Builder
if err := dec.scan(&buf); err != nil {
@@ -473,7 +474,8 @@ func (dec *Decoder) decode(val reflect.Value, nullOK bool) *DecodeError {
}
n, err := strconv.ParseInt(buf.String(), 10, kind2bits[kind])
if err != nil {
- return dec.newTypeError("number "+buf.String(), typ, err)
+ dec.newTypeError("number "+buf.String(), typ, err)
+ return nil
}
val.SetInt(n)
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
@@ -483,7 +485,8 @@ func (dec *Decoder) decode(val reflect.Value, nullOK bool) *DecodeError {
if t, err := dec.peekRuneType(); err != nil {
return err
} else if !t.IsNumber() {
- return dec.newTypeError(t.JSONType(), typ, nil)
+ dec.newTypeError(t.JSONType(), typ, nil)
+ return dec.scan(fastio.Discard)
}
var buf strings.Builder
if err := dec.scan(&buf); err != nil {
@@ -491,7 +494,8 @@ func (dec *Decoder) decode(val reflect.Value, nullOK bool) *DecodeError {
}
n, err := strconv.ParseUint(buf.String(), 10, kind2bits[kind])
if err != nil {
- return dec.newTypeError("number "+buf.String(), typ, err)
+ dec.newTypeError("number "+buf.String(), typ, err)
+ return nil
}
val.SetUint(n)
case reflect.Float32, reflect.Float64:
@@ -501,7 +505,8 @@ func (dec *Decoder) decode(val reflect.Value, nullOK bool) *DecodeError {
if t, err := dec.peekRuneType(); err != nil {
return err
} else if !t.IsNumber() {
- return dec.newTypeError(t.JSONType(), typ, nil)
+ dec.newTypeError(t.JSONType(), typ, nil)
+ return dec.scan(fastio.Discard)
}
var buf strings.Builder
if err := dec.scan(&buf); err != nil {
@@ -509,7 +514,8 @@ func (dec *Decoder) decode(val reflect.Value, nullOK bool) *DecodeError {
}
n, err := strconv.ParseFloat(buf.String(), kind2bits[kind])
if err != nil {
- return dec.newTypeError("number "+buf.String(), typ, err)
+ dec.newTypeError("number "+buf.String(), typ, err)
+ return nil
}
val.SetFloat(n)
case reflect.String:
@@ -526,9 +532,10 @@ func (dec *Decoder) decode(val reflect.Value, nullOK bool) *DecodeError {
return err
}
if !t.IsNumber() {
- return dec.newTypeError(t.JSONType(), typ,
+ dec.newTypeError(t.JSONType(), typ,
fmt.Errorf("json: invalid number literal, trying to unmarshal %q into Number",
buf.String()))
+ return nil
}
val.SetString(buf.String())
} else {
@@ -543,7 +550,8 @@ func (dec *Decoder) decode(val reflect.Value, nullOK bool) *DecodeError {
return err
}
if typ.NumMethod() > 0 {
- return dec.newTypeError(t.JSONType(), typ, ErrDecodeNonEmptyInterface)
+ dec.newTypeError(t.JSONType(), typ, ErrDecodeNonEmptyInterface)
+ return dec.scan(fastio.Discard)
}
// If the interface stores a pointer, try to use the type information of the pointer.
if !val.IsNil() && val.Elem().Kind() == reflect.Pointer {
@@ -587,7 +595,9 @@ func (dec *Decoder) decode(val reflect.Value, nullOK bool) *DecodeError {
if err != nil {
return err
}
- val.Set(reflect.ValueOf(v))
+ if v != nil {
+ val.Set(reflect.ValueOf(v))
+ }
}
case reflect.Struct:
if ok, err := dec.maybeDecodeNull(nullOK); ok {
@@ -618,7 +628,7 @@ func (dec *Decoder) decode(val reflect.Value, nullOK bool) *DecodeError {
}
if !ok {
if dec.disallowUnknownFields {
- return dec.newTypeError("", typ, fmt.Errorf("json: unknown field %q", name))
+ dec.newTypeError("", typ, fmt.Errorf("json: unknown field %q", name))
}
return dec.scan(fastio.Discard)
}
@@ -627,9 +637,10 @@ func (dec *Decoder) decode(val reflect.Value, nullOK bool) *DecodeError {
for _, idx := range field.Path {
if fVal.Kind() == reflect.Pointer {
if fVal.IsNil() && !fVal.CanSet() { // https://golang.org/issue/21357
- return dec.newTypeError("", fVal.Type().Elem(),
+ dec.newTypeError("", fVal.Type().Elem(),
fmt.Errorf("json: cannot set embedded pointer to unexported struct: %v",
fVal.Type().Elem()))
+ return dec.scan(fastio.Discard)
}
t, err := dec.peekRuneType()
if err != nil {
@@ -670,15 +681,16 @@ func (dec *Decoder) decode(val reflect.Value, nullOK bool) *DecodeError {
}
if err := NewDecoder(bytes.NewReader(buf.Bytes())).Decode(fVal.Addr().Interface()); err != nil {
if str := buf.String(); str != "null" {
- return dec.newTypeError("", fVal.Type(),
+ dec.newTypeError("", fVal.Type(),
fmt.Errorf("json: invalid use of ,string struct tag, trying to unmarshal %q into %v",
str, fVal.Type()))
}
}
default:
- return dec.newTypeError(t.JSONType(), fVal.Type(),
+ dec.newTypeError(t.JSONType(), fVal.Type(),
fmt.Errorf("json: invalid use of ,string struct tag, trying to unmarshal unquoted value into %v",
fVal.Type()))
+ return dec.scan(fastio.Discard)
}
return nil
} else {
@@ -715,7 +727,7 @@ func (dec *Decoder) decode(val reflect.Value, nullOK bool) *DecodeError {
case reflect.PointerTo(nameValTyp).Implements(textUnmarshalerType):
obj := nameValPtr.Interface().(encoding.TextUnmarshaler)
if err := obj.UnmarshalText(nameBuf.Bytes()); err != nil {
- return dec.newTypeError("string", reflect.PointerTo(nameValTyp), err)
+ dec.newTypeError("string", reflect.PointerTo(nameValTyp), err)
}
default:
switch nameValTyp.Kind() {
@@ -724,17 +736,19 @@ func (dec *Decoder) decode(val reflect.Value, nullOK bool) *DecodeError {
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
n, err := strconv.ParseInt(nameBuf.String(), 10, kind2bits[nameValTyp.Kind()])
if err != nil {
- return dec.newTypeError("number "+nameBuf.String(), nameValTyp, err)
+ dec.newTypeError("number "+nameBuf.String(), nameValTyp, err)
+ return nil
}
nameValPtr.Elem().SetInt(n)
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
n, err := strconv.ParseUint(nameBuf.String(), 10, kind2bits[nameValTyp.Kind()])
if err != nil {
- return dec.newTypeError("number "+nameBuf.String(), nameValTyp, err)
+ dec.newTypeError("number "+nameBuf.String(), nameValTyp, err)
+ return nil
}
nameValPtr.Elem().SetUint(n)
default:
- return dec.newTypeError("object", typ, &DecodeArgumentError{Type: nameValTyp})
+ dec.newTypeError("object", typ, &DecodeArgumentError{Type: nameValTyp})
}
}
return nil
@@ -753,7 +767,7 @@ func (dec *Decoder) decode(val reflect.Value, nullOK bool) *DecodeError {
return nil
})
default:
- return dec.newTypeError(t.JSONType(), typ, nil)
+ dec.newTypeError(t.JSONType(), typ, nil)
}
case reflect.Slice:
t, err := dec.peekRuneType()
@@ -792,7 +806,7 @@ func (dec *Decoder) decode(val reflect.Value, nullOK bool) *DecodeError {
}
}
default:
- return dec.newTypeError(t.JSONType(), typ, nil)
+ dec.newTypeError(t.JSONType(), typ, nil)
}
default:
switch t {
@@ -823,7 +837,7 @@ func (dec *Decoder) decode(val reflect.Value, nullOK bool) *DecodeError {
return nil
})
default:
- return dec.newTypeError(t.JSONType(), typ, nil)
+ dec.newTypeError(t.JSONType(), typ, nil)
}
}
case reflect.Array:
@@ -874,7 +888,7 @@ func (dec *Decoder) decode(val reflect.Value, nullOK bool) *DecodeError {
return dec.decode(val.Elem(), false)
}
default:
- return dec.newTypeError("", typ, fmt.Errorf("unsupported type (kind=%v)", typ.Kind()))
+ dec.newTypeError("", typ, fmt.Errorf("unsupported type (kind=%v)", typ.Kind()))
}
}
return nil
@@ -962,7 +976,8 @@ func (dec *Decoder) decodeAny() (any, *DecodeError) {
}
f64, err := num.Float64()
if err != nil {
- return nil, dec.newTypeError("number "+buf.String(), float64Type, err)
+ dec.newTypeError("number "+buf.String(), float64Type, err)
+ return nil, nil
}
return f64, nil
case jsonparse.RuneTypeTrueT, jsonparse.RuneTypeFalseF:
@@ -989,6 +1004,11 @@ func DecodeObject(r io.RuneScanner, decodeKey, decodeVal func(io.RuneScanner) er
} else {
dec = NewDecoder(r)
}
+ if dec.typeErr != nil {
+ oldTypeErr := dec.typeErr
+ dec.typeErr = nil
+ defer func() { dec.typeErr = oldTypeErr }()
+ }
dec.posStackPush()
defer dec.posStackPop()
if err := dec.decodeObject(nil,
@@ -1006,12 +1026,19 @@ func DecodeObject(r io.RuneScanner, decodeKey, decodeVal func(io.RuneScanner) er
}); err != nil {
return err
}
+ if dec.typeErr != nil {
+ return dec.typeErr
+ }
return nil
}
func (dec *Decoder) decodeObject(gTyp reflect.Type, decodeKey, decodeVal func() *DecodeError) *DecodeError {
- if err := dec.expectRuneType('{', jsonparse.RuneTypeObjectBeg, gTyp); err != nil {
+ if _, t, err := dec.readRune(); err != nil {
return err
+ } else if t != jsonparse.RuneTypeObjectBeg {
+ dec.newTypeError(t.JSONType(), gTyp, nil)
+ dec.unreadRune()
+ return dec.scan(fastio.Discard)
}
_, t, err := dec.readRune()
if err != nil {
@@ -1067,6 +1094,11 @@ func DecodeArray(r io.RuneScanner, decodeMember func(r io.RuneScanner) error) er
} else {
dec = NewDecoder(r)
}
+ if dec.typeErr != nil {
+ oldTypeErr := dec.typeErr
+ dec.typeErr = nil
+ defer func() { dec.typeErr = oldTypeErr }()
+ }
dec.posStackPush()
defer dec.posStackPop()
if err := dec.decodeArray(nil, func() *DecodeError {
@@ -1077,12 +1109,19 @@ func DecodeArray(r io.RuneScanner, decodeMember func(r io.RuneScanner) error) er
}); err != nil {
return err
}
+ if dec.typeErr != nil {
+ return dec.typeErr
+ }
return nil
}
func (dec *Decoder) decodeArray(gTyp reflect.Type, decodeMember func() *DecodeError) *DecodeError {
- if err := dec.expectRuneType('[', jsonparse.RuneTypeArrayBeg, gTyp); err != nil {
+ if _, t, err := dec.readRune(); err != nil {
return err
+ } else if t != jsonparse.RuneTypeArrayBeg {
+ dec.newTypeError(t.JSONType(), gTyp, nil)
+ dec.unreadRune()
+ return dec.scan(fastio.Discard)
}
_, t, err := dec.readRune()
if err != nil {
@@ -1113,8 +1152,12 @@ func (dec *Decoder) decodeArray(gTyp reflect.Type, decodeMember func() *DecodeEr
}
func (dec *Decoder) decodeString(gTyp reflect.Type, out fastio.RuneWriter) *DecodeError {
- if err := dec.expectRuneType('"', jsonparse.RuneTypeStringBeg, gTyp); err != nil {
+ if _, t, err := dec.readRune(); err != nil {
return err
+ } else if t != jsonparse.RuneTypeStringBeg {
+ dec.newTypeError(t.JSONType(), gTyp, nil)
+ dec.unreadRune()
+ return dec.scan(fastio.Discard)
}
var uhex [3]byte
for {
@@ -1251,7 +1294,8 @@ func (dec *Decoder) decodeBool(gTyp reflect.Type) (bool, *DecodeError) {
}
return false, nil
default:
- return false, dec.newTypeError(t.JSONType(), gTyp, nil)
+ dec.newTypeError(t.JSONType(), gTyp, nil)
+ return false, nil
}
}
diff --git a/decode_test.go b/decode_test.go
index 456f363..c224f3a 100644
--- a/decode_test.go
+++ b/decode_test.go
@@ -48,3 +48,50 @@ func TestDecodeGrowing(t *testing.T) {
assert.NoError(t, dec.Decode(&x))
assert.ErrorIs(t, dec.Decode(&x), io.EOF)
}
+
+type testAry []int
+
+func (a *testAry) DecodeJSON(r io.RuneScanner) error {
+ return DecodeArray(r, func(r io.RuneScanner) error {
+ var x int
+ if err := NewDecoder(r).Decode(&x); err != nil {
+ return err
+ }
+ *a = append(*a, x)
+ return nil
+ })
+}
+
+type testObj map[string]int
+
+func (o *testObj) DecodeJSON(r io.RuneScanner) error {
+ *o = make(testObj)
+ var key string
+ return DecodeObject(r,
+ func(r io.RuneScanner) error {
+ return NewDecoder(r).Decode(&key)
+ },
+ func(r io.RuneScanner) error {
+ var val int
+ if err := NewDecoder(r).Decode(&val); err != nil {
+ return err
+ }
+ (*o)[key] = val
+ return nil
+ },
+ )
+}
+
+func TestDecodeTypeError(t *testing.T) {
+ t.Parallel()
+ type outType struct {
+ First int
+ Second testAry
+ Third testObj
+ }
+ var out outType
+ err := NewDecoder(strings.NewReader(`{"First": 1.2, "Second": [3], "Third": {"a":4}}`)).Decode(&out)
+ assert.EqualError(t, err,
+ `json: v["First"]: cannot decode JSON number 1.2 at input byte 9 into Go int: strconv.ParseInt: parsing "1.2": invalid syntax`)
+ assert.Equal(t, outType{First: 0, Second: testAry{3}, Third: testObj{"a": 4}}, out)
+}