summaryrefslogtreecommitdiff
path: root/decode.go
diff options
context:
space:
mode:
authorLuke Shumaker <lukeshu@lukeshu.com>2023-02-17 19:21:37 -0700
committerLuke Shumaker <lukeshu@lukeshu.com>2023-02-25 00:47:52 -0700
commitd01fa91dcdfd428fb4b1c46b3961a1497c7a1102 (patch)
tree0724da76664e893e7ecfad2d8ce4ea1b05918598 /decode.go
parentf369aff688697a881833d86c13b18156e8376f08 (diff)
decode: Don't bail on type errors
Diffstat (limited to 'decode.go')
-rw-r--r--decode.go132
1 files changed, 88 insertions, 44 deletions
diff --git a/decode.go b/decode.go
index 126b904..c987d79 100644
--- a/decode.go
+++ b/decode.go
@@ -96,6 +96,7 @@ type Decoder struct {
// state
posStack []int64
structStack []decodeStackItem
+ typeErr *DecodeError
}
const maxNestingDepth = 10000
@@ -241,19 +242,26 @@ func (dec *Decoder) Decode(ptr any) (err error) {
}
}
+ dec.typeErr = nil
dec.io.Reset()
dec.io.PushReadBarrier()
if err := dec.decode(ptrVal.Elem(), false); err != nil {
return err
}
dec.io.PopReadBarrier()
+ if dec.typeErr != nil {
+ return dec.typeErr
+ }
return nil
}
// io helpers //////////////////////////////////////////////////////////////////////////////////////
-func (dec *Decoder) newTypeError(jTyp string, gTyp reflect.Type, err error) *DecodeError {
- return &DecodeError{
+func (dec *Decoder) newTypeError(jTyp string, gTyp reflect.Type, err error) {
+ if dec.typeErr != nil {
+ return
+ }
+ dec.typeErr = &DecodeError{
Field: dec.structStackStr(),
FieldParent: dec.structStackParent(),
FieldName: dec.structStackName(),
@@ -313,17 +321,6 @@ func (dec *Decoder) expectRuneOrPanic(ec rune, et jsonparse.RuneType) *DecodeErr
return nil
}
-func (dec *Decoder) expectRuneType(ec rune, et jsonparse.RuneType, gt reflect.Type) *DecodeError {
- ac, at, err := dec.readRune()
- if err != nil {
- return err
- }
- if ac != ec || at != et {
- return dec.newTypeError(at.JSONType(), gt, nil)
- }
- return nil
-}
-
type decRuneScanner struct {
dec *Decoder
eof bool
@@ -373,10 +370,13 @@ func (dec *Decoder) withLimitingScanner(gTyp reflect.Type, fn func(io.RuneScanne
}()
l := &decRuneScanner{dec: dec}
if err := fn(l); err != nil {
- return dec.newTypeError(t.JSONType(), gTyp, err)
+ dec.newTypeError(t.JSONType(), gTyp, err)
}
if _, _, err := l.ReadRune(); err != io.EOF {
- return dec.newTypeError(t.JSONType(), gTyp, fmt.Errorf("did not consume entire %s", t.JSONType()))
+ dec.newTypeError(t.JSONType(), gTyp, fmt.Errorf("did not consume entire %s", t.JSONType()))
+ for err != io.EOF {
+ _, _, err = l.ReadRune()
+ }
}
return nil
}
@@ -417,7 +417,7 @@ func (dec *Decoder) decode(val reflect.Value, nullOK bool) *DecodeError {
return err
}
if err := val.Addr().Interface().(*RawMessage).UnmarshalJSON(buf.Bytes()); err != nil {
- return dec.newTypeError(t.JSONType(), reflect.PointerTo(typ), err)
+ dec.newTypeError(t.JSONType(), reflect.PointerTo(typ), err)
}
case val.CanAddr() && reflect.PointerTo(typ).Implements(decodableType):
obj := val.Addr().Interface().(Decodable)
@@ -433,7 +433,7 @@ func (dec *Decoder) decode(val reflect.Value, nullOK bool) *DecodeError {
}
obj := val.Addr().Interface().(jsonUnmarshaler)
if err := obj.UnmarshalJSON(buf.Bytes()); err != nil {
- return dec.newTypeError(t.JSONType(), reflect.PointerTo(typ), err)
+ dec.newTypeError(t.JSONType(), reflect.PointerTo(typ), err)
}
case val.CanAddr() && reflect.PointerTo(typ).Implements(textUnmarshalerType):
if ok, err := dec.maybeDecodeNull(nullOK); ok {
@@ -445,7 +445,7 @@ func (dec *Decoder) decode(val reflect.Value, nullOK bool) *DecodeError {
}
obj := val.Addr().Interface().(encoding.TextUnmarshaler)
if err := obj.UnmarshalText(buf.Bytes()); err != nil {
- return dec.newTypeError("string", reflect.PointerTo(typ), err)
+ dec.newTypeError("string", reflect.PointerTo(typ), err)
}
default:
switch kind := typ.Kind(); kind {
@@ -465,7 +465,8 @@ func (dec *Decoder) decode(val reflect.Value, nullOK bool) *DecodeError {
if t, err := dec.peekRuneType(); err != nil {
return err
} else if !t.IsNumber() {
- return dec.newTypeError(t.JSONType(), typ, nil)
+ dec.newTypeError(t.JSONType(), typ, nil)
+ return dec.scan(fastio.Discard)
}
var buf strings.Builder
if err := dec.scan(&buf); err != nil {
@@ -473,7 +474,8 @@ func (dec *Decoder) decode(val reflect.Value, nullOK bool) *DecodeError {
}
n, err := strconv.ParseInt(buf.String(), 10, kind2bits[kind])
if err != nil {
- return dec.newTypeError("number "+buf.String(), typ, err)
+ dec.newTypeError("number "+buf.String(), typ, err)
+ return nil
}
val.SetInt(n)
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
@@ -483,7 +485,8 @@ func (dec *Decoder) decode(val reflect.Value, nullOK bool) *DecodeError {
if t, err := dec.peekRuneType(); err != nil {
return err
} else if !t.IsNumber() {
- return dec.newTypeError(t.JSONType(), typ, nil)
+ dec.newTypeError(t.JSONType(), typ, nil)
+ return dec.scan(fastio.Discard)
}
var buf strings.Builder
if err := dec.scan(&buf); err != nil {
@@ -491,7 +494,8 @@ func (dec *Decoder) decode(val reflect.Value, nullOK bool) *DecodeError {
}
n, err := strconv.ParseUint(buf.String(), 10, kind2bits[kind])
if err != nil {
- return dec.newTypeError("number "+buf.String(), typ, err)
+ dec.newTypeError("number "+buf.String(), typ, err)
+ return nil
}
val.SetUint(n)
case reflect.Float32, reflect.Float64:
@@ -501,7 +505,8 @@ func (dec *Decoder) decode(val reflect.Value, nullOK bool) *DecodeError {
if t, err := dec.peekRuneType(); err != nil {
return err
} else if !t.IsNumber() {
- return dec.newTypeError(t.JSONType(), typ, nil)
+ dec.newTypeError(t.JSONType(), typ, nil)
+ return dec.scan(fastio.Discard)
}
var buf strings.Builder
if err := dec.scan(&buf); err != nil {
@@ -509,7 +514,8 @@ func (dec *Decoder) decode(val reflect.Value, nullOK bool) *DecodeError {
}
n, err := strconv.ParseFloat(buf.String(), kind2bits[kind])
if err != nil {
- return dec.newTypeError("number "+buf.String(), typ, err)
+ dec.newTypeError("number "+buf.String(), typ, err)
+ return nil
}
val.SetFloat(n)
case reflect.String:
@@ -526,9 +532,10 @@ func (dec *Decoder) decode(val reflect.Value, nullOK bool) *DecodeError {
return err
}
if !t.IsNumber() {
- return dec.newTypeError(t.JSONType(), typ,
+ dec.newTypeError(t.JSONType(), typ,
fmt.Errorf("json: invalid number literal, trying to unmarshal %q into Number",
buf.String()))
+ return nil
}
val.SetString(buf.String())
} else {
@@ -543,7 +550,8 @@ func (dec *Decoder) decode(val reflect.Value, nullOK bool) *DecodeError {
return err
}
if typ.NumMethod() > 0 {
- return dec.newTypeError(t.JSONType(), typ, ErrDecodeNonEmptyInterface)
+ dec.newTypeError(t.JSONType(), typ, ErrDecodeNonEmptyInterface)
+ return dec.scan(fastio.Discard)
}
// If the interface stores a pointer, try to use the type information of the pointer.
if !val.IsNil() && val.Elem().Kind() == reflect.Pointer {
@@ -587,7 +595,9 @@ func (dec *Decoder) decode(val reflect.Value, nullOK bool) *DecodeError {
if err != nil {
return err
}
- val.Set(reflect.ValueOf(v))
+ if v != nil {
+ val.Set(reflect.ValueOf(v))
+ }
}
case reflect.Struct:
if ok, err := dec.maybeDecodeNull(nullOK); ok {
@@ -618,7 +628,7 @@ func (dec *Decoder) decode(val reflect.Value, nullOK bool) *DecodeError {
}
if !ok {
if dec.disallowUnknownFields {
- return dec.newTypeError("", typ, fmt.Errorf("json: unknown field %q", name))
+ dec.newTypeError("", typ, fmt.Errorf("json: unknown field %q", name))
}
return dec.scan(fastio.Discard)
}
@@ -627,9 +637,10 @@ func (dec *Decoder) decode(val reflect.Value, nullOK bool) *DecodeError {
for _, idx := range field.Path {
if fVal.Kind() == reflect.Pointer {
if fVal.IsNil() && !fVal.CanSet() { // https://golang.org/issue/21357
- return dec.newTypeError("", fVal.Type().Elem(),
+ dec.newTypeError("", fVal.Type().Elem(),
fmt.Errorf("json: cannot set embedded pointer to unexported struct: %v",
fVal.Type().Elem()))
+ return dec.scan(fastio.Discard)
}
t, err := dec.peekRuneType()
if err != nil {
@@ -670,15 +681,16 @@ func (dec *Decoder) decode(val reflect.Value, nullOK bool) *DecodeError {
}
if err := NewDecoder(bytes.NewReader(buf.Bytes())).Decode(fVal.Addr().Interface()); err != nil {
if str := buf.String(); str != "null" {
- return dec.newTypeError("", fVal.Type(),
+ dec.newTypeError("", fVal.Type(),
fmt.Errorf("json: invalid use of ,string struct tag, trying to unmarshal %q into %v",
str, fVal.Type()))
}
}
default:
- return dec.newTypeError(t.JSONType(), fVal.Type(),
+ dec.newTypeError(t.JSONType(), fVal.Type(),
fmt.Errorf("json: invalid use of ,string struct tag, trying to unmarshal unquoted value into %v",
fVal.Type()))
+ return dec.scan(fastio.Discard)
}
return nil
} else {
@@ -715,7 +727,7 @@ func (dec *Decoder) decode(val reflect.Value, nullOK bool) *DecodeError {
case reflect.PointerTo(nameValTyp).Implements(textUnmarshalerType):
obj := nameValPtr.Interface().(encoding.TextUnmarshaler)
if err := obj.UnmarshalText(nameBuf.Bytes()); err != nil {
- return dec.newTypeError("string", reflect.PointerTo(nameValTyp), err)
+ dec.newTypeError("string", reflect.PointerTo(nameValTyp), err)
}
default:
switch nameValTyp.Kind() {
@@ -724,17 +736,19 @@ func (dec *Decoder) decode(val reflect.Value, nullOK bool) *DecodeError {
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
n, err := strconv.ParseInt(nameBuf.String(), 10, kind2bits[nameValTyp.Kind()])
if err != nil {
- return dec.newTypeError("number "+nameBuf.String(), nameValTyp, err)
+ dec.newTypeError("number "+nameBuf.String(), nameValTyp, err)
+ return nil
}
nameValPtr.Elem().SetInt(n)
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
n, err := strconv.ParseUint(nameBuf.String(), 10, kind2bits[nameValTyp.Kind()])
if err != nil {
- return dec.newTypeError("number "+nameBuf.String(), nameValTyp, err)
+ dec.newTypeError("number "+nameBuf.String(), nameValTyp, err)
+ return nil
}
nameValPtr.Elem().SetUint(n)
default:
- return dec.newTypeError("object", typ, &DecodeArgumentError{Type: nameValTyp})
+ dec.newTypeError("object", typ, &DecodeArgumentError{Type: nameValTyp})
}
}
return nil
@@ -753,7 +767,7 @@ func (dec *Decoder) decode(val reflect.Value, nullOK bool) *DecodeError {
return nil
})
default:
- return dec.newTypeError(t.JSONType(), typ, nil)
+ dec.newTypeError(t.JSONType(), typ, nil)
}
case reflect.Slice:
t, err := dec.peekRuneType()
@@ -792,7 +806,7 @@ func (dec *Decoder) decode(val reflect.Value, nullOK bool) *DecodeError {
}
}
default:
- return dec.newTypeError(t.JSONType(), typ, nil)
+ dec.newTypeError(t.JSONType(), typ, nil)
}
default:
switch t {
@@ -823,7 +837,7 @@ func (dec *Decoder) decode(val reflect.Value, nullOK bool) *DecodeError {
return nil
})
default:
- return dec.newTypeError(t.JSONType(), typ, nil)
+ dec.newTypeError(t.JSONType(), typ, nil)
}
}
case reflect.Array:
@@ -874,7 +888,7 @@ func (dec *Decoder) decode(val reflect.Value, nullOK bool) *DecodeError {
return dec.decode(val.Elem(), false)
}
default:
- return dec.newTypeError("", typ, fmt.Errorf("unsupported type (kind=%v)", typ.Kind()))
+ dec.newTypeError("", typ, fmt.Errorf("unsupported type (kind=%v)", typ.Kind()))
}
}
return nil
@@ -962,7 +976,8 @@ func (dec *Decoder) decodeAny() (any, *DecodeError) {
}
f64, err := num.Float64()
if err != nil {
- return nil, dec.newTypeError("number "+buf.String(), float64Type, err)
+ dec.newTypeError("number "+buf.String(), float64Type, err)
+ return nil, nil
}
return f64, nil
case jsonparse.RuneTypeTrueT, jsonparse.RuneTypeFalseF:
@@ -989,6 +1004,11 @@ func DecodeObject(r io.RuneScanner, decodeKey, decodeVal func(io.RuneScanner) er
} else {
dec = NewDecoder(r)
}
+ if dec.typeErr != nil {
+ oldTypeErr := dec.typeErr
+ dec.typeErr = nil
+ defer func() { dec.typeErr = oldTypeErr }()
+ }
dec.posStackPush()
defer dec.posStackPop()
if err := dec.decodeObject(nil,
@@ -1006,12 +1026,19 @@ func DecodeObject(r io.RuneScanner, decodeKey, decodeVal func(io.RuneScanner) er
}); err != nil {
return err
}
+ if dec.typeErr != nil {
+ return dec.typeErr
+ }
return nil
}
func (dec *Decoder) decodeObject(gTyp reflect.Type, decodeKey, decodeVal func() *DecodeError) *DecodeError {
- if err := dec.expectRuneType('{', jsonparse.RuneTypeObjectBeg, gTyp); err != nil {
+ if _, t, err := dec.readRune(); err != nil {
return err
+ } else if t != jsonparse.RuneTypeObjectBeg {
+ dec.newTypeError(t.JSONType(), gTyp, nil)
+ dec.unreadRune()
+ return dec.scan(fastio.Discard)
}
_, t, err := dec.readRune()
if err != nil {
@@ -1067,6 +1094,11 @@ func DecodeArray(r io.RuneScanner, decodeMember func(r io.RuneScanner) error) er
} else {
dec = NewDecoder(r)
}
+ if dec.typeErr != nil {
+ oldTypeErr := dec.typeErr
+ dec.typeErr = nil
+ defer func() { dec.typeErr = oldTypeErr }()
+ }
dec.posStackPush()
defer dec.posStackPop()
if err := dec.decodeArray(nil, func() *DecodeError {
@@ -1077,12 +1109,19 @@ func DecodeArray(r io.RuneScanner, decodeMember func(r io.RuneScanner) error) er
}); err != nil {
return err
}
+ if dec.typeErr != nil {
+ return dec.typeErr
+ }
return nil
}
func (dec *Decoder) decodeArray(gTyp reflect.Type, decodeMember func() *DecodeError) *DecodeError {
- if err := dec.expectRuneType('[', jsonparse.RuneTypeArrayBeg, gTyp); err != nil {
+ if _, t, err := dec.readRune(); err != nil {
return err
+ } else if t != jsonparse.RuneTypeArrayBeg {
+ dec.newTypeError(t.JSONType(), gTyp, nil)
+ dec.unreadRune()
+ return dec.scan(fastio.Discard)
}
_, t, err := dec.readRune()
if err != nil {
@@ -1113,8 +1152,12 @@ func (dec *Decoder) decodeArray(gTyp reflect.Type, decodeMember func() *DecodeEr
}
func (dec *Decoder) decodeString(gTyp reflect.Type, out fastio.RuneWriter) *DecodeError {
- if err := dec.expectRuneType('"', jsonparse.RuneTypeStringBeg, gTyp); err != nil {
+ if _, t, err := dec.readRune(); err != nil {
return err
+ } else if t != jsonparse.RuneTypeStringBeg {
+ dec.newTypeError(t.JSONType(), gTyp, nil)
+ dec.unreadRune()
+ return dec.scan(fastio.Discard)
}
var uhex [3]byte
for {
@@ -1251,7 +1294,8 @@ func (dec *Decoder) decodeBool(gTyp reflect.Type) (bool, *DecodeError) {
}
return false, nil
default:
- return false, dec.newTypeError(t.JSONType(), gTyp, nil)
+ dec.newTypeError(t.JSONType(), gTyp, nil)
+ return false, nil
}
}