summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorLuke Shumaker <lukeshu@datawire.io>2022-08-15 22:29:52 -0600
committerLuke Shumaker <lukeshu@datawire.io>2022-08-16 00:05:39 -0600
commit83ec1924ae051b60f911aa8b53b741c5371faaf8 (patch)
treef014234317559fe0d5c64a2df79d1a12efae6e38
parent67b78f25f76b8ca43d837fb8055ca8e2b06c7d02 (diff)
Get borrowed_encode_test.go passing [ci-skip]
-rw-r--r--compat/json/borrowed_misc.go24
-rw-r--r--compat/json/compat.go44
-rw-r--r--encode.go42
3 files changed, 86 insertions, 24 deletions
diff --git a/compat/json/borrowed_misc.go b/compat/json/borrowed_misc.go
index 30a3b0e..e6b0162 100644
--- a/compat/json/borrowed_misc.go
+++ b/compat/json/borrowed_misc.go
@@ -4,6 +4,10 @@
package json
+import (
+ "reflect"
+)
+
// A SyntaxError is a description of a JSON syntax error.
// Unmarshal will return a SyntaxError if the JSON can't be parsed.
type SyntaxError struct {
@@ -12,3 +16,23 @@ type SyntaxError struct {
}
func (e *SyntaxError) Error() string { return e.msg }
+
+// A MarshalerError represents an error from calling a MarshalJSON or MarshalText method.
+type MarshalerError struct {
+ Type reflect.Type
+ Err error
+ sourceFunc string
+}
+
+func (e *MarshalerError) Error() string {
+ srcFunc := e.sourceFunc
+ if srcFunc == "" {
+ srcFunc = "MarshalJSON"
+ }
+ return "json: error calling " + srcFunc +
+ " for type " + e.Type.String() +
+ ": " + e.Err.Error()
+}
+
+// Unwrap returns the underlying error.
+func (e *MarshalerError) Unwrap() error { return e.Err }
diff --git a/compat/json/compat.go b/compat/json/compat.go
index 04dfb24..7145a86 100644
--- a/compat/json/compat.go
+++ b/compat/json/compat.go
@@ -29,35 +29,45 @@ type (
InvalidUnmarshalError = json.InvalidUnmarshalError // lowmemjson.DecodeArgumentError
// marshal errors
- InvalidUTF8Error = json.InvalidUTF8Error
- MarshalerError = lowmemjson.EncodeMethodError // expose a field
+ InvalidUTF8Error = json.InvalidUTF8Error
+ //MarshalerError = lowmemjson.EncodeMethodError // expose a field
UnsupportedTypeError = json.UnsupportedTypeError
UnsupportedValueError = json.UnsupportedValueError
)
-/////////////////////////////////////////////////////////////////////
+// Encode wrappers ///////////////////////////////////////////////////
-func MarshalIndent(v any, prefix, indent string) ([]byte, error) {
+func marshal(v any, formatter *lowmemjson.ReEncoder) ([]byte, error) {
var buf bytes.Buffer
- formatter := &lowmemjson.ReEncoder{
- Out: &buf,
+ formatter.Out = &buf
+ if err := lowmemjson.Encode(formatter, v); err != nil {
+ if me, ok := err.(*lowmemjson.EncodeMethodError); ok {
+ err = &MarshalerError{
+ Type: me.Type,
+ Err: me.Err,
+ sourceFunc: me.SourceFunc,
+ }
+ }
+ return nil, err
+ }
+ return buf.Bytes(), nil
+}
+
+func MarshalIndent(v any, prefix, indent string) ([]byte, error) {
+ return marshal(v, &lowmemjson.ReEncoder{
Indent: indent,
Prefix: prefix,
- }
- err := lowmemjson.Encode(formatter, v)
- return buf.Bytes(), err
+ })
}
func Marshal(v any) ([]byte, error) {
- var buf bytes.Buffer
- formatter := &lowmemjson.ReEncoder{
- Out: &buf,
+ return marshal(v, &lowmemjson.ReEncoder{
Compact: true,
- }
- err := lowmemjson.Encode(formatter, v)
- return buf.Bytes(), err
+ })
}
+// ReEncode wrappers /////////////////////////////////////////////////
+
func HTMLEscape(dst *bytes.Buffer, src []byte) {
formatter := &lowmemjson.ReEncoder{
Out: dst,
@@ -101,12 +111,12 @@ func Valid(data []byte) bool {
return err == nil
}
+// Decode wrappers ///////////////////////////////////////////////////
+
func Unmarshal(data []byte, ptr any) error {
return NewDecoder(bytes.NewReader(data)).Decode(ptr)
}
-/////////////////////////////////////////////////////////////////////
-
type Decoder struct {
*lowmemjson.Decoder
buf *bufio.Reader
diff --git a/encode.go b/encode.go
index c881369..af7c57e 100644
--- a/encode.go
+++ b/encode.go
@@ -48,7 +48,7 @@ func Encode(w io.Writer, obj any) (err error) {
}
}
}()
- encode(w, reflect.ValueOf(obj), false, 0, map[unsafe.Pointer]struct{}{})
+ encode(w, reflect.ValueOf(obj), false, 0, map[any]struct{}{})
if f, ok := w.(interface{ Flush() error }); ok {
return f.Flush()
}
@@ -63,7 +63,7 @@ var (
const startDetectingCyclesAfter = 1000
-func encode(w io.Writer, val reflect.Value, quote bool, cycleDepth uint, cycleSeen map[unsafe.Pointer]struct{}) {
+func encode(w io.Writer, val reflect.Value, quote bool, cycleDepth uint, cycleSeen map[any]struct{}) {
if !val.IsValid() {
encodeWriteString(w, "null")
return
@@ -83,7 +83,15 @@ func encode(w io.Writer, val reflect.Value, quote bool, cycleDepth uint, cycleSe
encodeWriteString(w, "null")
return
}
- if err := obj.EncodeJSON(w); err != nil {
+ validator := &ReEncoder{Out: w}
+ if err := obj.EncodeJSON(validator); err != nil {
+ panic(encodeError{&EncodeMethodError{
+ Type: val.Type(),
+ Err: err,
+ SourceFunc: "EncodeJSON",
+ }})
+ }
+ if err := validator.Close(); err != nil {
panic(encodeError{err})
}
@@ -102,9 +110,17 @@ func encode(w io.Writer, val reflect.Value, quote bool, cycleDepth uint, cycleSe
}
dat, err := obj.MarshalJSON()
if err != nil {
+ panic(encodeError{&EncodeMethodError{
+ Type: val.Type(),
+ Err: err,
+ SourceFunc: "MarshalJSON",
+ }})
+ }
+ validator := &ReEncoder{Out: w}
+ if _, err := validator.Write(dat); err != nil {
panic(encodeError{err})
}
- if _, err := w.Write(dat); err != nil {
+ if err := validator.Close(); err != nil {
panic(encodeError{err})
}
@@ -123,7 +139,11 @@ func encode(w io.Writer, val reflect.Value, quote bool, cycleDepth uint, cycleSe
}
text, err := obj.MarshalText()
if err != nil {
- panic(encodeError{err})
+ panic(encodeError{&EncodeMethodError{
+ Type: val.Type(),
+ Err: err,
+ SourceFunc: "MarshalText",
+ }})
}
encodeString(w, text)
@@ -302,7 +322,15 @@ func encode(w io.Writer, val reflect.Value, quote bool, cycleDepth uint, cycleSe
encodeWriteByte(w, '"')
default:
if cycleDepth++; cycleDepth > startDetectingCyclesAfter {
- ptr := val.UnsafePointer()
+ // For slices, val.UnsafePointer() doesn't return a pointer to the slice header
+ // or anything like that, it returns a pointer *to the first element in the
+ // slice*. That means that the pointer isn't enough to uniquely identify the
+ // slice! So we pair the pointer with the length of the slice, which is
+ // sufficient.
+ ptr := struct {
+ ptr unsafe.Pointer
+ len int
+ }{val.UnsafePointer(), val.Len()}
if _, seen := cycleSeen[ptr]; seen {
panic(encodeError{&EncodeValueError{
Value: val,
@@ -353,7 +381,7 @@ func encodeString[T interface{ []byte | string }](w io.Writer, str T) {
encodeWriteByte(w, '"')
}
-func encodeArray(w io.Writer, val reflect.Value, cycleDepth uint, cycleSeen map[unsafe.Pointer]struct{}) {
+func encodeArray(w io.Writer, val reflect.Value, cycleDepth uint, cycleSeen map[any]struct{}) {
encodeWriteByte(w, '[')
n := val.Len()
for i := 0; i < n; i++ {