diff options
author | Luke Shumaker <lukeshu@lukeshu.com> | 2023-01-30 23:00:11 -0700 |
---|---|---|
committer | Luke Shumaker <lukeshu@lukeshu.com> | 2023-01-30 23:00:11 -0700 |
commit | 5d5ac549dc1a25963418fda4b20ed181b7afd8d2 (patch) | |
tree | 8c5c2e9f5bd1c3a2208efe4d11965ecaa95472ec /encode.go | |
parent | 005dfe26e308b965c2f42c81d34a4b48757414a3 (diff) | |
parent | ccf8dc4b21bb1a547f118affab22bca3a02df270 (diff) |
Merge branch 'lukeshu/tune'
Diffstat (limited to 'encode.go')
-rw-r--r-- | encode.go | 326 |
1 files changed, 196 insertions, 130 deletions
@@ -9,10 +9,8 @@ import ( "encoding" "encoding/base64" "encoding/json" - "errors" "fmt" "io" - iofs "io/fs" "reflect" "sort" "strconv" @@ -30,22 +28,6 @@ type Encodable interface { EncodeJSON(w io.Writer) error } -type encodeError struct { - Err error -} - -func encodeWriteByte(w io.Writer, b byte) { - if err := writeByte(w, b); err != nil { - panic(encodeError{err}) - } -} - -func encodeWriteString(w io.Writer, str string) { - if _, err := io.WriteString(w, str); err != nil { - panic(encodeError{err}) - } -} - // An Encoder encodes and writes values to a stream of JSON elements. // // Encoder is analogous to, and has a similar API to the standar @@ -91,22 +73,19 @@ func NewEncoder(w io.Writer) *Encoder { // // [documentation for encoding/json.Marshal]: https://pkg.go.dev/encoding/json@go1.18#Marshal func (enc *Encoder) Encode(obj any) (err error) { - defer func() { - if r := recover(); r != nil { - if e, ok := r.(encodeError); ok { - err = e.Err - } else { - panic(r) - } - } - }() - encode(enc.w, reflect.ValueOf(obj), enc.w.BackslashEscape, false, 0, map[any]struct{}{}) + if err := encode(enc.w, reflect.ValueOf(obj), enc.w.BackslashEscape, false, 0, map[any]struct{}{}); err != nil { + return err + } if enc.closeAfterEncode { return enc.w.Close() } return nil } +func discardInt(_ int, err error) error { + return err +} + var ( encodableType = reflect.TypeOf((*Encodable)(nil)).Elem() jsonMarshalerType = reflect.TypeOf((*json.Marshaler)(nil)).Elem() @@ -115,10 +94,9 @@ var ( const startDetectingCyclesAfter = 1000 -func encode(w io.Writer, val reflect.Value, escaper BackslashEscaper, quote bool, cycleDepth uint, cycleSeen map[any]struct{}) { +func encode(w *ReEncoder, val reflect.Value, escaper BackslashEscaper, quote bool, cycleDepth uint, cycleSeen map[any]struct{}) error { if !val.IsValid() { - encodeWriteString(w, "null") - return + return discardInt(w.WriteString("null")) } switch { @@ -127,129 +105,150 @@ func encode(w io.Writer, val reflect.Value, escaper BackslashEscaper, quote bool fallthrough case val.Type().Implements(encodableType): if val.Kind() == reflect.Pointer && val.IsNil() { - encodeWriteString(w, "null") - return + return discardInt(w.WriteString("null")) } obj, ok := val.Interface().(Encodable) if !ok { - encodeWriteString(w, "null") - return + return discardInt(w.WriteString("null")) } - // Use a sub-ReEncoder to check that it's a full element. - validator := NewReEncoder(w, ReEncoderConfig{BackslashEscape: escaper}) - if err := obj.EncodeJSON(validator); err != nil { - panic(encodeError{&EncodeMethodError{ + w.pushWriteBarrier() + if err := obj.EncodeJSON(w); err != nil { + return &EncodeMethodError{ Type: val.Type(), SourceFunc: "EncodeJSON", Err: err, - }}) + } } - if err := validator.Close(); err != nil && !errors.Is(err, iofs.ErrClosed) { - panic(encodeError{&EncodeMethodError{ + if err := w.Close(); err != nil { + return &EncodeMethodError{ Type: val.Type(), SourceFunc: "EncodeJSON", Err: err, - }}) + } } + w.popWriteBarrier() case val.Kind() != reflect.Pointer && val.CanAddr() && reflect.PointerTo(val.Type()).Implements(jsonMarshalerType): val = val.Addr() fallthrough case val.Type().Implements(jsonMarshalerType): if val.Kind() == reflect.Pointer && val.IsNil() { - encodeWriteString(w, "null") - return + return discardInt(w.WriteString("null")) } obj, ok := val.Interface().(json.Marshaler) if !ok { - encodeWriteString(w, "null") - return + return discardInt(w.WriteString("null")) } dat, err := obj.MarshalJSON() if err != nil { - panic(encodeError{&EncodeMethodError{ + return &EncodeMethodError{ Type: val.Type(), SourceFunc: "MarshalJSON", Err: err, - }}) + } } - // Use a sub-ReEncoder to check that it's a full element. - validator := NewReEncoder(w, ReEncoderConfig{BackslashEscape: escaper}) - if _, err := validator.Write(dat); err != nil { - panic(encodeError{&EncodeMethodError{ + w.pushWriteBarrier() + if _, err := w.Write(dat); err != nil { + return &EncodeMethodError{ Type: val.Type(), SourceFunc: "MarshalJSON", Err: err, - }}) + } } - if err := validator.Close(); err != nil { - panic(encodeError{&EncodeMethodError{ + if err := w.Close(); err != nil { + return &EncodeMethodError{ Type: val.Type(), SourceFunc: "MarshalJSON", Err: err, - }}) + } } + w.popWriteBarrier() case val.Kind() != reflect.Pointer && val.CanAddr() && reflect.PointerTo(val.Type()).Implements(textMarshalerType): val = val.Addr() fallthrough case val.Type().Implements(textMarshalerType): if val.Kind() == reflect.Pointer && val.IsNil() { - encodeWriteString(w, "null") - return + return discardInt(w.WriteString("null")) } obj, ok := val.Interface().(encoding.TextMarshaler) if !ok { - encodeWriteString(w, "null") - return + return discardInt(w.WriteString("null")) } text, err := obj.MarshalText() if err != nil { - panic(encodeError{&EncodeMethodError{ + return &EncodeMethodError{ Type: val.Type(), SourceFunc: "MarshalText", Err: err, - }}) + } + } + if err := encodeStringFromBytes(w, escaper, text); err != nil { + return err } - encodeStringFromBytes(w, escaper, text) - default: switch val.Kind() { case reflect.Bool: if quote { - encodeWriteByte(w, '"') + if err := w.WriteByte('"'); err != nil { + return err + } } if val.Bool() { - encodeWriteString(w, "true") + if _, err := w.WriteString("true"); err != nil { + return err + } } else { - encodeWriteString(w, "false") + if _, err := w.WriteString("false"); err != nil { + return err + } } if quote { - encodeWriteByte(w, '"') + if err := w.WriteByte('"'); err != nil { + return err + } } case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: if quote { - encodeWriteByte(w, '"') + if err := w.WriteByte('"'); err != nil { + return err + } + } + if _, err := w.WriteString(strconv.FormatInt(val.Int(), 10)); err != nil { + return err } - encodeWriteString(w, strconv.FormatInt(val.Int(), 10)) if quote { - encodeWriteByte(w, '"') + if err := w.WriteByte('"'); err != nil { + return err + } } case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: if quote { - encodeWriteByte(w, '"') + if err := w.WriteByte('"'); err != nil { + return err + } + } + if _, err := w.WriteString(strconv.FormatUint(val.Uint(), 10)); err != nil { + return err } - encodeWriteString(w, strconv.FormatUint(val.Uint(), 10)) if quote { - encodeWriteByte(w, '"') + if err := w.WriteByte('"'); err != nil { + return err + } } case reflect.Float32, reflect.Float64: if quote { - encodeWriteByte(w, '"') + if err := w.WriteByte('"'); err != nil { + return err + } + } + if err := encodeTODO(w, val); err != nil { + return err } - encodeTODO(w, val) if quote { - encodeWriteByte(w, '"') + if err := w.WriteByte('"'); err != nil { + return err + } } case reflect.String: if val.Type() == numberType { @@ -258,29 +257,47 @@ func encode(w io.Writer, val reflect.Value, escaper BackslashEscaper, quote bool numStr = "0" } if quote { - encodeWriteByte(w, '"') + if err := w.WriteByte('"'); err != nil { + return err + } + } + if _, err := w.WriteString(numStr); err != nil { + return err } - encodeWriteString(w, numStr) if quote { - encodeWriteByte(w, '"') + if err := w.WriteByte('"'); err != nil { + return err + } } } else { if quote { var buf bytes.Buffer - encodeStringFromString(&buf, escaper, val.String()) - encodeStringFromBytes(w, escaper, buf.Bytes()) + if err := encodeStringFromString(&buf, escaper, val.String()); err != nil { + return err + } + if err := encodeStringFromBytes(w, escaper, buf.Bytes()); err != nil { + return err + } } else { - encodeStringFromString(w, escaper, val.String()) + if err := encodeStringFromString(w, escaper, val.String()); err != nil { + return err + } } } case reflect.Interface: if val.IsNil() { - encodeWriteString(w, "null") + if _, err := w.WriteString("null"); err != nil { + return err + } } else { - encode(w, val.Elem(), escaper, quote, cycleDepth, cycleSeen) + if err := encode(w, val.Elem(), escaper, quote, cycleDepth, cycleSeen); err != nil { + return err + } } case reflect.Struct: - encodeWriteByte(w, '{') + if err := w.WriteByte('{'); err != nil { + return err + } empty := true for _, field := range indexStruct(val.Type()).byPos { fVal, err := val.FieldByIndexErr(field.Path) @@ -291,35 +308,45 @@ func encode(w io.Writer, val reflect.Value, escaper BackslashEscaper, quote bool continue } if !empty { - encodeWriteByte(w, ',') + if err := w.WriteByte(','); err != nil { + return err + } } empty = false - encodeStringFromString(w, escaper, field.Name) - encodeWriteByte(w, ':') - encode(w, fVal, escaper, field.Quote, cycleDepth, cycleSeen) + if err := encodeStringFromString(w, escaper, field.Name); err != nil { + return err + } + if err := w.WriteByte(':'); err != nil { + return err + } + if err := encode(w, fVal, escaper, field.Quote, cycleDepth, cycleSeen); err != nil { + return err + } + } + if err := w.WriteByte('}'); err != nil { + return err } - encodeWriteByte(w, '}') case reflect.Map: if val.IsNil() { - encodeWriteString(w, "null") - return + return discardInt(w.WriteString("null")) } if val.Len() == 0 { - encodeWriteString(w, "{}") - return + return discardInt(w.WriteString("{}")) } if cycleDepth++; cycleDepth > startDetectingCyclesAfter { ptr := val.UnsafePointer() if _, seen := cycleSeen[ptr]; seen { - panic(encodeError{&EncodeValueError{ + return &EncodeValueError{ Value: val, Str: fmt.Sprintf("encountered a cycle via %s", val.Type()), - }}) + } } cycleSeen[ptr] = struct{}{} defer delete(cycleSeen, ptr) } - encodeWriteByte(w, '{') + if err := w.WriteByte('{'); err != nil { + return err + } type kv struct { K string @@ -330,14 +357,18 @@ func encode(w io.Writer, val reflect.Value, escaper BackslashEscaper, quote bool for i := 0; iter.Next(); i++ { // TODO: Avoid buffering the map key var k strings.Builder - encode(&k, iter.Key(), escaper, false, cycleDepth, cycleSeen) + if err := encode(NewReEncoder(&k, ReEncoderConfig{BackslashEscape: escaper}), iter.Key(), escaper, false, cycleDepth, cycleSeen); err != nil { + return err + } kStr := k.String() if kStr == "null" { kStr = `""` } if !strings.HasPrefix(kStr, `"`) { k.Reset() - encodeStringFromString(&k, escaper, kStr) + if err := encodeStringFromString(&k, escaper, kStr); err != nil { + return err + } kStr = k.String() } kvs[i].K = kStr @@ -349,17 +380,29 @@ func encode(w io.Writer, val reflect.Value, escaper BackslashEscaper, quote bool for i, kv := range kvs { if i > 0 { - encodeWriteByte(w, ',') + if err := w.WriteByte(','); err != nil { + return err + } + } + if _, err := w.WriteString(kv.K); err != nil { + return err + } + if err := w.WriteByte(':'); err != nil { + return err + } + if err := encode(w, kv.V, escaper, false, cycleDepth, cycleSeen); err != nil { + return err } - encodeWriteString(w, kv.K) - encodeWriteByte(w, ':') - encode(w, kv.V, escaper, false, cycleDepth, cycleSeen) } - encodeWriteByte(w, '}') + if err := w.WriteByte('}'); err != nil { + return err + } case reflect.Slice: switch { case val.IsNil(): - encodeWriteString(w, "null") + if _, err := w.WriteString("null"); err != nil { + return err + } case val.Type().Elem().Kind() == reflect.Uint8 && !(false || val.Type().Elem().Implements(encodableType) || reflect.PointerTo(val.Type().Elem()).Implements(encodableType) || @@ -367,11 +410,13 @@ func encode(w io.Writer, val reflect.Value, escaper BackslashEscaper, quote bool reflect.PointerTo(val.Type().Elem()).Implements(jsonMarshalerType) || val.Type().Elem().Implements(textMarshalerType) || reflect.PointerTo(val.Type().Elem()).Implements(textMarshalerType)): - encodeWriteByte(w, '"') + if err := w.WriteByte('"'); err != nil { + return err + } enc := base64.NewEncoder(base64.StdEncoding, w) if val.CanConvert(byteSliceType) { if _, err := enc.Write(val.Convert(byteSliceType).Interface().([]byte)); err != nil { - panic(encodeError{err}) + return err } } else { // TODO: Surely there's a better way. @@ -379,14 +424,16 @@ func encode(w io.Writer, val reflect.Value, escaper BackslashEscaper, quote bool var buf [1]byte buf[0] = val.Index(i).Convert(byteType).Interface().(byte) if _, err := enc.Write(buf[:]); err != nil { - panic(encodeError{err}) + return err } } } if err := enc.Close(); err != nil { - panic(encodeError{err}) + return err + } + if err := w.WriteByte('"'); err != nil { + return err } - encodeWriteByte(w, '"') default: if cycleDepth++; cycleDepth > startDetectingCyclesAfter { // For slices, val.UnsafePointer() doesn't return a pointer to the slice header @@ -399,61 +446,80 @@ func encode(w io.Writer, val reflect.Value, escaper BackslashEscaper, quote bool len int }{val.UnsafePointer(), val.Len()} if _, seen := cycleSeen[ptr]; seen { - panic(encodeError{&EncodeValueError{ + return &EncodeValueError{ Value: val, Str: fmt.Sprintf("encountered a cycle via %s", val.Type()), - }}) + } } cycleSeen[ptr] = struct{}{} defer delete(cycleSeen, ptr) } - encodeArray(w, val, escaper, cycleDepth, cycleSeen) + if err := encodeArray(w, val, escaper, cycleDepth, cycleSeen); err != nil { + return err + } } case reflect.Array: - encodeArray(w, val, escaper, cycleDepth, cycleSeen) + if err := encodeArray(w, val, escaper, cycleDepth, cycleSeen); err != nil { + return err + } case reflect.Pointer: if val.IsNil() { - encodeWriteString(w, "null") + if _, err := w.WriteString("null"); err != nil { + return err + } } else { if cycleDepth++; cycleDepth > startDetectingCyclesAfter { ptr := val.UnsafePointer() if _, seen := cycleSeen[ptr]; seen { - panic(encodeError{&EncodeValueError{ + return &EncodeValueError{ Value: val, Str: fmt.Sprintf("encountered a cycle via %s", val.Type()), - }}) + } } cycleSeen[ptr] = struct{}{} defer delete(cycleSeen, ptr) } - encode(w, val.Elem(), escaper, quote, cycleDepth, cycleSeen) + if err := encode(w, val.Elem(), escaper, quote, cycleDepth, cycleSeen); err != nil { + return err + } } default: - panic(encodeError{&EncodeTypeError{ + return &EncodeTypeError{ Type: val.Type(), - }}) + } } } + return nil } -func encodeArray(w io.Writer, val reflect.Value, escaper BackslashEscaper, cycleDepth uint, cycleSeen map[any]struct{}) { - encodeWriteByte(w, '[') +func encodeArray(w *ReEncoder, val reflect.Value, escaper BackslashEscaper, cycleDepth uint, cycleSeen map[any]struct{}) error { + if err := w.WriteByte('['); err != nil { + return err + } n := val.Len() for i := 0; i < n; i++ { if i > 0 { - encodeWriteByte(w, ',') + if err := w.WriteByte(','); err != nil { + return err + } } - encode(w, val.Index(i), escaper, false, cycleDepth, cycleSeen) + if err := encode(w, val.Index(i), escaper, false, cycleDepth, cycleSeen); err != nil { + return err + } + } + if err := w.WriteByte(']'); err != nil { + return err } - encodeWriteByte(w, ']') + return nil } -func encodeTODO(w io.Writer, val reflect.Value) { +func encodeTODO(w io.Writer, val reflect.Value) error { bs, err := json.Marshal(val.Interface()) if err != nil { - panic(encodeError{err}) + return err } if _, err := w.Write(bs); err != nil { - panic(encodeError{err}) + return err } + return nil } |