diff options
-rw-r--r-- | compat/json/borrowed_misc.go | 24 | ||||
-rw-r--r-- | compat/json/compat.go | 44 | ||||
-rw-r--r-- | encode.go | 42 |
3 files changed, 86 insertions, 24 deletions
diff --git a/compat/json/borrowed_misc.go b/compat/json/borrowed_misc.go index 30a3b0e..e6b0162 100644 --- a/compat/json/borrowed_misc.go +++ b/compat/json/borrowed_misc.go @@ -4,6 +4,10 @@ package json +import ( + "reflect" +) + // A SyntaxError is a description of a JSON syntax error. // Unmarshal will return a SyntaxError if the JSON can't be parsed. type SyntaxError struct { @@ -12,3 +16,23 @@ type SyntaxError struct { } func (e *SyntaxError) Error() string { return e.msg } + +// A MarshalerError represents an error from calling a MarshalJSON or MarshalText method. +type MarshalerError struct { + Type reflect.Type + Err error + sourceFunc string +} + +func (e *MarshalerError) Error() string { + srcFunc := e.sourceFunc + if srcFunc == "" { + srcFunc = "MarshalJSON" + } + return "json: error calling " + srcFunc + + " for type " + e.Type.String() + + ": " + e.Err.Error() +} + +// Unwrap returns the underlying error. +func (e *MarshalerError) Unwrap() error { return e.Err } diff --git a/compat/json/compat.go b/compat/json/compat.go index 04dfb24..7145a86 100644 --- a/compat/json/compat.go +++ b/compat/json/compat.go @@ -29,35 +29,45 @@ type ( InvalidUnmarshalError = json.InvalidUnmarshalError // lowmemjson.DecodeArgumentError // marshal errors - InvalidUTF8Error = json.InvalidUTF8Error - MarshalerError = lowmemjson.EncodeMethodError // expose a field + InvalidUTF8Error = json.InvalidUTF8Error + //MarshalerError = lowmemjson.EncodeMethodError // expose a field UnsupportedTypeError = json.UnsupportedTypeError UnsupportedValueError = json.UnsupportedValueError ) -///////////////////////////////////////////////////////////////////// +// Encode wrappers /////////////////////////////////////////////////// -func MarshalIndent(v any, prefix, indent string) ([]byte, error) { +func marshal(v any, formatter *lowmemjson.ReEncoder) ([]byte, error) { var buf bytes.Buffer - formatter := &lowmemjson.ReEncoder{ - Out: &buf, + formatter.Out = &buf + if err := lowmemjson.Encode(formatter, v); err != nil { + if me, ok := err.(*lowmemjson.EncodeMethodError); ok { + err = &MarshalerError{ + Type: me.Type, + Err: me.Err, + sourceFunc: me.SourceFunc, + } + } + return nil, err + } + return buf.Bytes(), nil +} + +func MarshalIndent(v any, prefix, indent string) ([]byte, error) { + return marshal(v, &lowmemjson.ReEncoder{ Indent: indent, Prefix: prefix, - } - err := lowmemjson.Encode(formatter, v) - return buf.Bytes(), err + }) } func Marshal(v any) ([]byte, error) { - var buf bytes.Buffer - formatter := &lowmemjson.ReEncoder{ - Out: &buf, + return marshal(v, &lowmemjson.ReEncoder{ Compact: true, - } - err := lowmemjson.Encode(formatter, v) - return buf.Bytes(), err + }) } +// ReEncode wrappers ///////////////////////////////////////////////// + func HTMLEscape(dst *bytes.Buffer, src []byte) { formatter := &lowmemjson.ReEncoder{ Out: dst, @@ -101,12 +111,12 @@ func Valid(data []byte) bool { return err == nil } +// Decode wrappers /////////////////////////////////////////////////// + func Unmarshal(data []byte, ptr any) error { return NewDecoder(bytes.NewReader(data)).Decode(ptr) } -///////////////////////////////////////////////////////////////////// - type Decoder struct { *lowmemjson.Decoder buf *bufio.Reader @@ -48,7 +48,7 @@ func Encode(w io.Writer, obj any) (err error) { } } }() - encode(w, reflect.ValueOf(obj), false, 0, map[unsafe.Pointer]struct{}{}) + encode(w, reflect.ValueOf(obj), false, 0, map[any]struct{}{}) if f, ok := w.(interface{ Flush() error }); ok { return f.Flush() } @@ -63,7 +63,7 @@ var ( const startDetectingCyclesAfter = 1000 -func encode(w io.Writer, val reflect.Value, quote bool, cycleDepth uint, cycleSeen map[unsafe.Pointer]struct{}) { +func encode(w io.Writer, val reflect.Value, quote bool, cycleDepth uint, cycleSeen map[any]struct{}) { if !val.IsValid() { encodeWriteString(w, "null") return @@ -83,7 +83,15 @@ func encode(w io.Writer, val reflect.Value, quote bool, cycleDepth uint, cycleSe encodeWriteString(w, "null") return } - if err := obj.EncodeJSON(w); err != nil { + validator := &ReEncoder{Out: w} + if err := obj.EncodeJSON(validator); err != nil { + panic(encodeError{&EncodeMethodError{ + Type: val.Type(), + Err: err, + SourceFunc: "EncodeJSON", + }}) + } + if err := validator.Close(); err != nil { panic(encodeError{err}) } @@ -102,9 +110,17 @@ func encode(w io.Writer, val reflect.Value, quote bool, cycleDepth uint, cycleSe } dat, err := obj.MarshalJSON() if err != nil { + panic(encodeError{&EncodeMethodError{ + Type: val.Type(), + Err: err, + SourceFunc: "MarshalJSON", + }}) + } + validator := &ReEncoder{Out: w} + if _, err := validator.Write(dat); err != nil { panic(encodeError{err}) } - if _, err := w.Write(dat); err != nil { + if err := validator.Close(); err != nil { panic(encodeError{err}) } @@ -123,7 +139,11 @@ func encode(w io.Writer, val reflect.Value, quote bool, cycleDepth uint, cycleSe } text, err := obj.MarshalText() if err != nil { - panic(encodeError{err}) + panic(encodeError{&EncodeMethodError{ + Type: val.Type(), + Err: err, + SourceFunc: "MarshalText", + }}) } encodeString(w, text) @@ -302,7 +322,15 @@ func encode(w io.Writer, val reflect.Value, quote bool, cycleDepth uint, cycleSe encodeWriteByte(w, '"') default: if cycleDepth++; cycleDepth > startDetectingCyclesAfter { - ptr := val.UnsafePointer() + // For slices, val.UnsafePointer() doesn't return a pointer to the slice header + // or anything like that, it returns a pointer *to the first element in the + // slice*. That means that the pointer isn't enough to uniquely identify the + // slice! So we pair the pointer with the length of the slice, which is + // sufficient. + ptr := struct { + ptr unsafe.Pointer + len int + }{val.UnsafePointer(), val.Len()} if _, seen := cycleSeen[ptr]; seen { panic(encodeError{&EncodeValueError{ Value: val, @@ -353,7 +381,7 @@ func encodeString[T interface{ []byte | string }](w io.Writer, str T) { encodeWriteByte(w, '"') } -func encodeArray(w io.Writer, val reflect.Value, cycleDepth uint, cycleSeen map[unsafe.Pointer]struct{}) { +func encodeArray(w io.Writer, val reflect.Value, cycleDepth uint, cycleSeen map[any]struct{}) { encodeWriteByte(w, '[') n := val.Len() for i := 0; i < n; i++ { |