summaryrefslogtreecommitdiff
path: root/encode.go
diff options
context:
space:
mode:
Diffstat (limited to 'encode.go')
-rw-r--r--encode.go42
1 files changed, 35 insertions, 7 deletions
diff --git a/encode.go b/encode.go
index c881369..af7c57e 100644
--- a/encode.go
+++ b/encode.go
@@ -48,7 +48,7 @@ func Encode(w io.Writer, obj any) (err error) {
}
}
}()
- encode(w, reflect.ValueOf(obj), false, 0, map[unsafe.Pointer]struct{}{})
+ encode(w, reflect.ValueOf(obj), false, 0, map[any]struct{}{})
if f, ok := w.(interface{ Flush() error }); ok {
return f.Flush()
}
@@ -63,7 +63,7 @@ var (
const startDetectingCyclesAfter = 1000
-func encode(w io.Writer, val reflect.Value, quote bool, cycleDepth uint, cycleSeen map[unsafe.Pointer]struct{}) {
+func encode(w io.Writer, val reflect.Value, quote bool, cycleDepth uint, cycleSeen map[any]struct{}) {
if !val.IsValid() {
encodeWriteString(w, "null")
return
@@ -83,7 +83,15 @@ func encode(w io.Writer, val reflect.Value, quote bool, cycleDepth uint, cycleSe
encodeWriteString(w, "null")
return
}
- if err := obj.EncodeJSON(w); err != nil {
+ validator := &ReEncoder{Out: w}
+ if err := obj.EncodeJSON(validator); err != nil {
+ panic(encodeError{&EncodeMethodError{
+ Type: val.Type(),
+ Err: err,
+ SourceFunc: "EncodeJSON",
+ }})
+ }
+ if err := validator.Close(); err != nil {
panic(encodeError{err})
}
@@ -102,9 +110,17 @@ func encode(w io.Writer, val reflect.Value, quote bool, cycleDepth uint, cycleSe
}
dat, err := obj.MarshalJSON()
if err != nil {
+ panic(encodeError{&EncodeMethodError{
+ Type: val.Type(),
+ Err: err,
+ SourceFunc: "MarshalJSON",
+ }})
+ }
+ validator := &ReEncoder{Out: w}
+ if _, err := validator.Write(dat); err != nil {
panic(encodeError{err})
}
- if _, err := w.Write(dat); err != nil {
+ if err := validator.Close(); err != nil {
panic(encodeError{err})
}
@@ -123,7 +139,11 @@ func encode(w io.Writer, val reflect.Value, quote bool, cycleDepth uint, cycleSe
}
text, err := obj.MarshalText()
if err != nil {
- panic(encodeError{err})
+ panic(encodeError{&EncodeMethodError{
+ Type: val.Type(),
+ Err: err,
+ SourceFunc: "MarshalText",
+ }})
}
encodeString(w, text)
@@ -302,7 +322,15 @@ func encode(w io.Writer, val reflect.Value, quote bool, cycleDepth uint, cycleSe
encodeWriteByte(w, '"')
default:
if cycleDepth++; cycleDepth > startDetectingCyclesAfter {
- ptr := val.UnsafePointer()
+ // For slices, val.UnsafePointer() doesn't return a pointer to the slice header
+ // or anything like that, it returns a pointer *to the first element in the
+ // slice*. That means that the pointer isn't enough to uniquely identify the
+ // slice! So we pair the pointer with the length of the slice, which is
+ // sufficient.
+ ptr := struct {
+ ptr unsafe.Pointer
+ len int
+ }{val.UnsafePointer(), val.Len()}
if _, seen := cycleSeen[ptr]; seen {
panic(encodeError{&EncodeValueError{
Value: val,
@@ -353,7 +381,7 @@ func encodeString[T interface{ []byte | string }](w io.Writer, str T) {
encodeWriteByte(w, '"')
}
-func encodeArray(w io.Writer, val reflect.Value, cycleDepth uint, cycleSeen map[unsafe.Pointer]struct{}) {
+func encodeArray(w io.Writer, val reflect.Value, cycleDepth uint, cycleSeen map[any]struct{}) {
encodeWriteByte(w, '[')
n := val.Len()
for i := 0; i < n; i++ {