From 46f4b0c2a67911b7438621b6181f5888b8be55b6 Mon Sep 17 00:00:00 2001 From: Luke Shumaker Date: Fri, 10 Feb 2023 18:40:08 -0700 Subject: decode: Have .scan() and .scanNumber() return *DecodeError --- decode.go | 45 ++++++++++++++++++++++++++++++--------------- 1 file changed, 30 insertions(+), 15 deletions(-) (limited to 'decode.go') diff --git a/decode.go b/decode.go index 522e41e..de10370 100644 --- a/decode.go +++ b/decode.go @@ -398,7 +398,9 @@ func (dec *Decoder) decode(val reflect.Value, nullOK bool) (_err *DecodeError) { case val.CanAddr() && reflect.PointerTo(typ) == rawMessagePtrType: t := dec.peekRuneType() var buf bytes.Buffer - dec.scan(&buf) + if err := dec.scan(&buf); err != nil { + return err + } if err := val.Addr().Interface().(*RawMessage).UnmarshalJSON(buf.Bytes()); err != nil { return dec.newTypeError(t.JSONType(), reflect.PointerTo(typ), err) } @@ -417,7 +419,9 @@ func (dec *Decoder) decode(val reflect.Value, nullOK bool) (_err *DecodeError) { case val.CanAddr() && reflect.PointerTo(typ).Implements(jsonUnmarshalerType): t := dec.peekRuneType() var buf bytes.Buffer - dec.scan(&buf) + if err := dec.scan(&buf); err != nil { + return err + } obj := val.Addr().Interface().(jsonUnmarshaler) if err := obj.UnmarshalJSON(buf.Bytes()); err != nil { return dec.newTypeError(t.JSONType(), reflect.PointerTo(typ), err) @@ -448,7 +452,9 @@ func (dec *Decoder) decode(val reflect.Value, nullOK bool) (_err *DecodeError) { return dec.decodeNull() } var buf strings.Builder - dec.scanNumber(typ, &buf) + if err := dec.scanNumber(typ, &buf); err != nil { + return err + } n, err := strconv.ParseInt(buf.String(), 10, kind2bits[kind]) if err != nil { return dec.newTypeError("number "+buf.String(), typ, err) @@ -459,7 +465,9 @@ func (dec *Decoder) decode(val reflect.Value, nullOK bool) (_err *DecodeError) { return dec.decodeNull() } var buf strings.Builder - dec.scanNumber(typ, &buf) + if err := dec.scanNumber(typ, &buf); err != nil { + return err + } n, err := strconv.ParseUint(buf.String(), 10, kind2bits[kind]) if err != nil { return dec.newTypeError("number "+buf.String(), typ, err) @@ -470,7 +478,9 @@ func (dec *Decoder) decode(val reflect.Value, nullOK bool) (_err *DecodeError) { return dec.decodeNull() } var buf strings.Builder - dec.scanNumber(typ, &buf) + if err := dec.scanNumber(typ, &buf); err != nil { + return err + } n, err := strconv.ParseFloat(buf.String(), kind2bits[kind]) if err != nil { return dec.newTypeError("number "+buf.String(), typ, err) @@ -483,7 +493,9 @@ func (dec *Decoder) decode(val reflect.Value, nullOK bool) (_err *DecodeError) { var buf strings.Builder if typ == numberType { t := dec.peekRuneType() - dec.scan(&buf) + if err := dec.scan(&buf); err != nil { + return err + } if !t.IsNumber() { return dec.newTypeError(t.JSONType(), typ, fmt.Errorf("json: invalid number literal, trying to unmarshal %q into Number", @@ -574,8 +586,7 @@ func (dec *Decoder) decode(val reflect.Value, nullOK bool) (_err *DecodeError) { if dec.disallowUnknownFields { return dec.newTypeError("", typ, fmt.Errorf("json: unknown field %q", name)) } - dec.scan(fastio.Discard) - return nil + return dec.scan(fastio.Discard) } field := index.ByPos[idx] fVal := val @@ -775,7 +786,9 @@ func (dec *Decoder) decode(val reflect.Value, nullOK bool) (_err *DecodeError) { } val.Index(i).Set(mValPtr.Elem()) } else { - dec.scan(fastio.Discard) + if err := dec.scan(fastio.Discard); err != nil { + return err + } } i++ return nil @@ -805,7 +818,7 @@ func (dec *Decoder) decode(val reflect.Value, nullOK bool) (_err *DecodeError) { return nil } -func (dec *Decoder) scan(out fastio.RuneWriter) { +func (dec *Decoder) scan(out fastio.RuneWriter) *DecodeError { dec.io.PushReadBarrier() for { c, t := dec.readRune() @@ -815,14 +828,14 @@ func (dec *Decoder) scan(out fastio.RuneWriter) { _, _ = out.WriteRune(c) } dec.io.PopReadBarrier() + return nil } -func (dec *Decoder) scanNumber(gTyp reflect.Type, out fastio.RuneWriter) { +func (dec *Decoder) scanNumber(gTyp reflect.Type, out fastio.RuneWriter) *DecodeError { if t := dec.peekRuneType(); !t.IsNumber() { - err := dec.newTypeError(t.JSONType(), gTyp, nil) - panic(decodeError(*err)) + return dec.newTypeError(t.JSONType(), gTyp, nil) } - dec.scan(out) + return dec.scan(out) } func (dec *Decoder) decodeAny() (any, *DecodeError) { @@ -879,7 +892,9 @@ func (dec *Decoder) decodeAny() (any, *DecodeError) { return buf.String(), nil case '-', '0', '1', '2', '3', '4', '5', '6', '7', '8', '9': var buf strings.Builder - dec.scan(&buf) + if err := dec.scan(&buf); err != nil { + return nil, err + } num := Number(buf.String()) if dec.useNumber { return num, nil -- cgit v1.2.3-2-g168b