From d5b1b73eaaa060ef468f20d8b9eed029eb60ce45 Mon Sep 17 00:00:00 2001 From: Luke Shumaker Date: Fri, 27 Jan 2023 01:24:02 -0700 Subject: encode: Don't use panic for flow-control --- encode.go | 308 +++++++++++++++++++++++++++++++++---------------------- encode_string.go | 38 +++++-- 2 files changed, 216 insertions(+), 130 deletions(-) diff --git a/encode.go b/encode.go index c5a29b3..57f3852 100644 --- a/encode.go +++ b/encode.go @@ -32,22 +32,6 @@ type Encodable interface { EncodeJSON(w io.Writer) error } -type encodeError struct { - Err error -} - -func encodeWriteByte(w io.ByteWriter, b byte) { - if err := w.WriteByte(b); err != nil { - panic(encodeError{err}) - } -} - -func encodeWriteString(w io.StringWriter, str string) { - if _, err := w.WriteString(str); err != nil { - panic(encodeError{err}) - } -} - // An Encoder encodes and writes values to a stream of JSON elements. // // Encoder is analogous to, and has a similar API to the standar @@ -93,22 +77,19 @@ func NewEncoder(w io.Writer) *Encoder { // // [documentation for encoding/json.Marshal]: https://pkg.go.dev/encoding/json@go1.18#Marshal func (enc *Encoder) Encode(obj any) (err error) { - defer func() { - if r := recover(); r != nil { - if e, ok := r.(encodeError); ok { - err = e.Err - } else { - panic(r) - } - } - }() - encode(enc.w, reflect.ValueOf(obj), enc.w.BackslashEscape, false, 0, map[any]struct{}{}) + if err := encode(enc.w, reflect.ValueOf(obj), enc.w.BackslashEscape, false, 0, map[any]struct{}{}); err != nil { + return err + } if enc.closeAfterEncode { return enc.w.Close() } return nil } +func discardInt(_ int, err error) error { + return err +} + var ( encodableType = reflect.TypeOf((*Encodable)(nil)).Elem() jsonMarshalerType = reflect.TypeOf((*json.Marshaler)(nil)).Elem() @@ -117,10 +98,9 @@ var ( const startDetectingCyclesAfter = 1000 -func encode(w internal.AllWriter, val reflect.Value, escaper BackslashEscaper, quote bool, cycleDepth uint, cycleSeen map[any]struct{}) { +func encode(w internal.AllWriter, val reflect.Value, escaper BackslashEscaper, quote bool, cycleDepth uint, cycleSeen map[any]struct{}) error { if !val.IsValid() { - encodeWriteString(w, "null") - return + return discardInt(w.WriteString("null")) } switch { @@ -129,29 +109,27 @@ func encode(w internal.AllWriter, val reflect.Value, escaper BackslashEscaper, q fallthrough case val.Type().Implements(encodableType): if val.Kind() == reflect.Pointer && val.IsNil() { - encodeWriteString(w, "null") - return + return discardInt(w.WriteString("null")) } obj, ok := val.Interface().(Encodable) if !ok { - encodeWriteString(w, "null") - return + return discardInt(w.WriteString("null")) } // Use a sub-ReEncoder to check that it's a full element. validator := NewReEncoder(w, ReEncoderConfig{BackslashEscape: escaper}) if err := obj.EncodeJSON(validator); err != nil { - panic(encodeError{&EncodeMethodError{ + return &EncodeMethodError{ Type: val.Type(), SourceFunc: "EncodeJSON", Err: err, - }}) + } } if err := validator.Close(); err != nil && !errors.Is(err, iofs.ErrClosed) { - panic(encodeError{&EncodeMethodError{ + return &EncodeMethodError{ Type: val.Type(), SourceFunc: "EncodeJSON", Err: err, - }}) + } } case val.Kind() != reflect.Pointer && val.CanAddr() && reflect.PointerTo(val.Type()).Implements(jsonMarshalerType): @@ -159,37 +137,35 @@ func encode(w internal.AllWriter, val reflect.Value, escaper BackslashEscaper, q fallthrough case val.Type().Implements(jsonMarshalerType): if val.Kind() == reflect.Pointer && val.IsNil() { - encodeWriteString(w, "null") - return + return discardInt(w.WriteString("null")) } obj, ok := val.Interface().(json.Marshaler) if !ok { - encodeWriteString(w, "null") - return + return discardInt(w.WriteString("null")) } dat, err := obj.MarshalJSON() if err != nil { - panic(encodeError{&EncodeMethodError{ + return &EncodeMethodError{ Type: val.Type(), SourceFunc: "MarshalJSON", Err: err, - }}) + } } // Use a sub-ReEncoder to check that it's a full element. validator := NewReEncoder(w, ReEncoderConfig{BackslashEscape: escaper}) if _, err := validator.Write(dat); err != nil { - panic(encodeError{&EncodeMethodError{ + return &EncodeMethodError{ Type: val.Type(), SourceFunc: "MarshalJSON", Err: err, - }}) + } } if err := validator.Close(); err != nil { - panic(encodeError{&EncodeMethodError{ + return &EncodeMethodError{ Type: val.Type(), SourceFunc: "MarshalJSON", Err: err, - }}) + } } case val.Kind() != reflect.Pointer && val.CanAddr() && reflect.PointerTo(val.Type()).Implements(textMarshalerType): @@ -197,61 +173,86 @@ func encode(w internal.AllWriter, val reflect.Value, escaper BackslashEscaper, q fallthrough case val.Type().Implements(textMarshalerType): if val.Kind() == reflect.Pointer && val.IsNil() { - encodeWriteString(w, "null") - return + return discardInt(w.WriteString("null")) } obj, ok := val.Interface().(encoding.TextMarshaler) if !ok { - encodeWriteString(w, "null") - return + return discardInt(w.WriteString("null")) } text, err := obj.MarshalText() if err != nil { - panic(encodeError{&EncodeMethodError{ + return &EncodeMethodError{ Type: val.Type(), SourceFunc: "MarshalText", Err: err, - }}) + } + } + if err := encodeStringFromBytes(w, escaper, text); err != nil { + return err } - encodeStringFromBytes(w, escaper, text) - default: switch val.Kind() { case reflect.Bool: if quote { - encodeWriteByte(w, '"') + if err := w.WriteByte('"'); err != nil { + return err + } } if val.Bool() { - encodeWriteString(w, "true") + if _, err := w.WriteString("true"); err != nil { + return err + } } else { - encodeWriteString(w, "false") + if _, err := w.WriteString("false"); err != nil { + return err + } } if quote { - encodeWriteByte(w, '"') + if err := w.WriteByte('"'); err != nil { + return err + } } case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: if quote { - encodeWriteByte(w, '"') + if err := w.WriteByte('"'); err != nil { + return err + } + } + if _, err := w.WriteString(strconv.FormatInt(val.Int(), 10)); err != nil { + return err } - encodeWriteString(w, strconv.FormatInt(val.Int(), 10)) if quote { - encodeWriteByte(w, '"') + if err := w.WriteByte('"'); err != nil { + return err + } } case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: if quote { - encodeWriteByte(w, '"') + if err := w.WriteByte('"'); err != nil { + return err + } + } + if _, err := w.WriteString(strconv.FormatUint(val.Uint(), 10)); err != nil { + return err } - encodeWriteString(w, strconv.FormatUint(val.Uint(), 10)) if quote { - encodeWriteByte(w, '"') + if err := w.WriteByte('"'); err != nil { + return err + } } case reflect.Float32, reflect.Float64: if quote { - encodeWriteByte(w, '"') + if err := w.WriteByte('"'); err != nil { + return err + } + } + if err := encodeTODO(w, val); err != nil { + return err } - encodeTODO(w, val) if quote { - encodeWriteByte(w, '"') + if err := w.WriteByte('"'); err != nil { + return err + } } case reflect.String: if val.Type() == numberType { @@ -260,29 +261,47 @@ func encode(w internal.AllWriter, val reflect.Value, escaper BackslashEscaper, q numStr = "0" } if quote { - encodeWriteByte(w, '"') + if err := w.WriteByte('"'); err != nil { + return err + } + } + if _, err := w.WriteString(numStr); err != nil { + return err } - encodeWriteString(w, numStr) if quote { - encodeWriteByte(w, '"') + if err := w.WriteByte('"'); err != nil { + return err + } } } else { if quote { var buf bytes.Buffer - encodeStringFromString(&buf, escaper, val.String()) - encodeStringFromBytes(w, escaper, buf.Bytes()) + if err := encodeStringFromString(&buf, escaper, val.String()); err != nil { + return err + } + if err := encodeStringFromBytes(w, escaper, buf.Bytes()); err != nil { + return err + } } else { - encodeStringFromString(w, escaper, val.String()) + if err := encodeStringFromString(w, escaper, val.String()); err != nil { + return err + } } } case reflect.Interface: if val.IsNil() { - encodeWriteString(w, "null") + if _, err := w.WriteString("null"); err != nil { + return err + } } else { - encode(w, val.Elem(), escaper, quote, cycleDepth, cycleSeen) + if err := encode(w, val.Elem(), escaper, quote, cycleDepth, cycleSeen); err != nil { + return err + } } case reflect.Struct: - encodeWriteByte(w, '{') + if err := w.WriteByte('{'); err != nil { + return err + } empty := true for _, field := range indexStruct(val.Type()).byPos { fVal, err := val.FieldByIndexErr(field.Path) @@ -293,35 +312,45 @@ func encode(w internal.AllWriter, val reflect.Value, escaper BackslashEscaper, q continue } if !empty { - encodeWriteByte(w, ',') + if err := w.WriteByte(','); err != nil { + return err + } } empty = false - encodeStringFromString(w, escaper, field.Name) - encodeWriteByte(w, ':') - encode(w, fVal, escaper, field.Quote, cycleDepth, cycleSeen) + if err := encodeStringFromString(w, escaper, field.Name); err != nil { + return err + } + if err := w.WriteByte(':'); err != nil { + return err + } + if err := encode(w, fVal, escaper, field.Quote, cycleDepth, cycleSeen); err != nil { + return err + } + } + if err := w.WriteByte('}'); err != nil { + return err } - encodeWriteByte(w, '}') case reflect.Map: if val.IsNil() { - encodeWriteString(w, "null") - return + return discardInt(w.WriteString("null")) } if val.Len() == 0 { - encodeWriteString(w, "{}") - return + return discardInt(w.WriteString("{}")) } if cycleDepth++; cycleDepth > startDetectingCyclesAfter { ptr := val.UnsafePointer() if _, seen := cycleSeen[ptr]; seen { - panic(encodeError{&EncodeValueError{ + return &EncodeValueError{ Value: val, Str: fmt.Sprintf("encountered a cycle via %s", val.Type()), - }}) + } } cycleSeen[ptr] = struct{}{} defer delete(cycleSeen, ptr) } - encodeWriteByte(w, '{') + if err := w.WriteByte('{'); err != nil { + return err + } type kv struct { K string @@ -332,14 +361,18 @@ func encode(w internal.AllWriter, val reflect.Value, escaper BackslashEscaper, q for i := 0; iter.Next(); i++ { // TODO: Avoid buffering the map key var k strings.Builder - encode(&k, iter.Key(), escaper, false, cycleDepth, cycleSeen) + if err := encode(&k, iter.Key(), escaper, false, cycleDepth, cycleSeen); err != nil { + return err + } kStr := k.String() if kStr == "null" { kStr = `""` } if !strings.HasPrefix(kStr, `"`) { k.Reset() - encodeStringFromString(&k, escaper, kStr) + if err := encodeStringFromString(&k, escaper, kStr); err != nil { + return err + } kStr = k.String() } kvs[i].K = kStr @@ -351,17 +384,29 @@ func encode(w internal.AllWriter, val reflect.Value, escaper BackslashEscaper, q for i, kv := range kvs { if i > 0 { - encodeWriteByte(w, ',') + if err := w.WriteByte(','); err != nil { + return err + } + } + if _, err := w.WriteString(kv.K); err != nil { + return err + } + if err := w.WriteByte(':'); err != nil { + return err + } + if err := encode(w, kv.V, escaper, false, cycleDepth, cycleSeen); err != nil { + return err } - encodeWriteString(w, kv.K) - encodeWriteByte(w, ':') - encode(w, kv.V, escaper, false, cycleDepth, cycleSeen) } - encodeWriteByte(w, '}') + if err := w.WriteByte('}'); err != nil { + return err + } case reflect.Slice: switch { case val.IsNil(): - encodeWriteString(w, "null") + if _, err := w.WriteString("null"); err != nil { + return err + } case val.Type().Elem().Kind() == reflect.Uint8 && !(false || val.Type().Elem().Implements(encodableType) || reflect.PointerTo(val.Type().Elem()).Implements(encodableType) || @@ -369,11 +414,13 @@ func encode(w internal.AllWriter, val reflect.Value, escaper BackslashEscaper, q reflect.PointerTo(val.Type().Elem()).Implements(jsonMarshalerType) || val.Type().Elem().Implements(textMarshalerType) || reflect.PointerTo(val.Type().Elem()).Implements(textMarshalerType)): - encodeWriteByte(w, '"') + if err := w.WriteByte('"'); err != nil { + return err + } enc := base64.NewEncoder(base64.StdEncoding, w) if val.CanConvert(byteSliceType) { if _, err := enc.Write(val.Convert(byteSliceType).Interface().([]byte)); err != nil { - panic(encodeError{err}) + return err } } else { // TODO: Surely there's a better way. @@ -381,14 +428,16 @@ func encode(w internal.AllWriter, val reflect.Value, escaper BackslashEscaper, q var buf [1]byte buf[0] = val.Index(i).Convert(byteType).Interface().(byte) if _, err := enc.Write(buf[:]); err != nil { - panic(encodeError{err}) + return err } } } if err := enc.Close(); err != nil { - panic(encodeError{err}) + return err + } + if err := w.WriteByte('"'); err != nil { + return err } - encodeWriteByte(w, '"') default: if cycleDepth++; cycleDepth > startDetectingCyclesAfter { // For slices, val.UnsafePointer() doesn't return a pointer to the slice header @@ -401,61 +450,80 @@ func encode(w internal.AllWriter, val reflect.Value, escaper BackslashEscaper, q len int }{val.UnsafePointer(), val.Len()} if _, seen := cycleSeen[ptr]; seen { - panic(encodeError{&EncodeValueError{ + return &EncodeValueError{ Value: val, Str: fmt.Sprintf("encountered a cycle via %s", val.Type()), - }}) + } } cycleSeen[ptr] = struct{}{} defer delete(cycleSeen, ptr) } - encodeArray(w, val, escaper, cycleDepth, cycleSeen) + if err := encodeArray(w, val, escaper, cycleDepth, cycleSeen); err != nil { + return err + } } case reflect.Array: - encodeArray(w, val, escaper, cycleDepth, cycleSeen) + if err := encodeArray(w, val, escaper, cycleDepth, cycleSeen); err != nil { + return err + } case reflect.Pointer: if val.IsNil() { - encodeWriteString(w, "null") + if _, err := w.WriteString("null"); err != nil { + return err + } } else { if cycleDepth++; cycleDepth > startDetectingCyclesAfter { ptr := val.UnsafePointer() if _, seen := cycleSeen[ptr]; seen { - panic(encodeError{&EncodeValueError{ + return &EncodeValueError{ Value: val, Str: fmt.Sprintf("encountered a cycle via %s", val.Type()), - }}) + } } cycleSeen[ptr] = struct{}{} defer delete(cycleSeen, ptr) } - encode(w, val.Elem(), escaper, quote, cycleDepth, cycleSeen) + if err := encode(w, val.Elem(), escaper, quote, cycleDepth, cycleSeen); err != nil { + return err + } } default: - panic(encodeError{&EncodeTypeError{ + return &EncodeTypeError{ Type: val.Type(), - }}) + } } } + return nil } -func encodeArray(w internal.AllWriter, val reflect.Value, escaper BackslashEscaper, cycleDepth uint, cycleSeen map[any]struct{}) { - encodeWriteByte(w, '[') +func encodeArray(w internal.AllWriter, val reflect.Value, escaper BackslashEscaper, cycleDepth uint, cycleSeen map[any]struct{}) error { + if err := w.WriteByte('['); err != nil { + return err + } n := val.Len() for i := 0; i < n; i++ { if i > 0 { - encodeWriteByte(w, ',') + if err := w.WriteByte(','); err != nil { + return err + } } - encode(w, val.Index(i), escaper, false, cycleDepth, cycleSeen) + if err := encode(w, val.Index(i), escaper, false, cycleDepth, cycleSeen); err != nil { + return err + } + } + if err := w.WriteByte(']'); err != nil { + return err } - encodeWriteByte(w, ']') + return nil } -func encodeTODO(w io.Writer, val reflect.Value) { +func encodeTODO(w io.Writer, val reflect.Value) error { bs, err := json.Marshal(val.Interface()) if err != nil { - panic(encodeError{err}) + return err } if _, err := w.Write(bs); err != nil { - panic(encodeError{err}) + return err } + return nil } diff --git a/encode_string.go b/encode_string.go index 831a038..12f934e 100644 --- a/encode_string.go +++ b/encode_string.go @@ -83,29 +83,47 @@ func writeStringChar(w internal.AllWriter, c rune, wasEscaped BackslashEscapeMod } } -func encodeStringFromString(w internal.AllWriter, escaper BackslashEscaper, str string) { - encodeWriteByte(w, '"') +func encodeStringFromString(w internal.AllWriter, escaper BackslashEscaper, str string) error { + if err := w.WriteByte('"'); err != nil { + return err + } for _, c := range str { if _, err := writeStringChar(w, c, BackslashEscapeNone, escaper); err != nil { - panic(encodeError{err}) + return err } } - encodeWriteByte(w, '"') + if err := w.WriteByte('"'); err != nil { + return err + } + return nil } -func encodeStringFromBytes(w internal.AllWriter, escaper BackslashEscaper, str []byte) { - encodeWriteByte(w, '"') +func encodeStringFromBytes(w internal.AllWriter, escaper BackslashEscaper, str []byte) error { + if err := w.WriteByte('"'); err != nil { + return err + } for i := 0; i < len(str); { c, size := utf8.DecodeRune(str[i:]) if _, err := writeStringChar(w, c, BackslashEscapeNone, escaper); err != nil { - panic(encodeError{err}) + return err } i += size } - encodeWriteByte(w, '"') + if err := w.WriteByte('"'); err != nil { + return err + } + return nil } func init() { - internal.EncodeStringFromString = func(w io.Writer, s string) { encodeStringFromString(internal.NewAllWriter(w), nil, s) } - internal.EncodeStringFromBytes = func(w io.Writer, s []byte) { encodeStringFromBytes(internal.NewAllWriter(w), nil, s) } + internal.EncodeStringFromString = func(w io.Writer, s string) { + if err := encodeStringFromString(internal.NewAllWriter(w), nil, s); err != nil { + panic(err) + } + } + internal.EncodeStringFromBytes = func(w io.Writer, s []byte) { + if err := encodeStringFromBytes(internal.NewAllWriter(w), nil, s); err != nil { + panic(err) + } + } } -- cgit v1.1-4-g5e80