summaryrefslogtreecommitdiff
path: root/encode.go
diff options
context:
space:
mode:
Diffstat (limited to 'encode.go')
-rw-r--r--encode.go69
1 files changed, 56 insertions, 13 deletions
diff --git a/encode.go b/encode.go
index 8479785..c881369 100644
--- a/encode.go
+++ b/encode.go
@@ -9,11 +9,13 @@ import (
"encoding"
"encoding/base64"
"encoding/json"
+ "fmt"
"io"
"reflect"
"sort"
"strconv"
"strings"
+ "unsafe"
)
type Encodable interface {
@@ -46,7 +48,7 @@ func Encode(w io.Writer, obj any) (err error) {
}
}
}()
- encode(w, reflect.ValueOf(obj), false)
+ encode(w, reflect.ValueOf(obj), false, 0, map[unsafe.Pointer]struct{}{})
if f, ok := w.(interface{ Flush() error }); ok {
return f.Flush()
}
@@ -59,7 +61,9 @@ var (
textMarshalerType = reflect.TypeOf((*encoding.TextMarshaler)(nil)).Elem()
)
-func encode(w io.Writer, val reflect.Value, quote bool) {
+const startDetectingCyclesAfter = 1000
+
+func encode(w io.Writer, val reflect.Value, quote bool, cycleDepth uint, cycleSeen map[unsafe.Pointer]struct{}) {
if !val.IsValid() {
encodeWriteString(w, "null")
return
@@ -187,7 +191,7 @@ func encode(w io.Writer, val reflect.Value, quote bool) {
if val.IsNil() {
encodeWriteString(w, "null")
} else {
- encode(w, val.Elem(), quote)
+ encode(w, val.Elem(), quote, cycleDepth, cycleSeen)
}
case reflect.Struct:
encodeWriteByte(w, '{')
@@ -206,7 +210,7 @@ func encode(w io.Writer, val reflect.Value, quote bool) {
empty = false
encodeString(w, field.Name)
encodeWriteByte(w, ':')
- encode(w, fVal, field.Quote)
+ encode(w, fVal, field.Quote, cycleDepth, cycleSeen)
}
encodeWriteByte(w, '}')
case reflect.Map:
@@ -218,6 +222,17 @@ func encode(w io.Writer, val reflect.Value, quote bool) {
encodeWriteString(w, "{}")
return
}
+ if cycleDepth++; cycleDepth > startDetectingCyclesAfter {
+ ptr := val.UnsafePointer()
+ if _, seen := cycleSeen[ptr]; seen {
+ panic(encodeError{&EncodeValueError{
+ Value: val,
+ Str: fmt.Sprintf("encountered a cycle via %s", val.Type()),
+ }})
+ }
+ cycleSeen[ptr] = struct{}{}
+ defer delete(cycleSeen, ptr)
+ }
encodeWriteByte(w, '{')
type kv struct {
@@ -228,7 +243,7 @@ func encode(w io.Writer, val reflect.Value, quote bool) {
iter := val.MapRange()
for i := 0; iter.Next(); i++ {
var k strings.Builder
- encode(&k, iter.Key(), false)
+ encode(&k, iter.Key(), false, cycleDepth, cycleSeen)
kStr := k.String()
if kStr == "null" {
kStr = `""`
@@ -251,14 +266,20 @@ func encode(w io.Writer, val reflect.Value, quote bool) {
}
encodeWriteString(w, kv.K)
encodeWriteByte(w, ':')
- encode(w, kv.V, false)
+ encode(w, kv.V, false, cycleDepth, cycleSeen)
}
encodeWriteByte(w, '}')
case reflect.Slice:
switch {
case val.IsNil():
encodeWriteString(w, "null")
- case val.Type().Elem().Kind() == reflect.Uint8:
+ case val.Type().Elem().Kind() == reflect.Uint8 && !(false ||
+ val.Type().Elem().Implements(encodableType) ||
+ reflect.PointerTo(val.Type().Elem()).Implements(encodableType) ||
+ val.Type().Elem().Implements(jsonMarshalerType) ||
+ reflect.PointerTo(val.Type().Elem()).Implements(jsonMarshalerType) ||
+ val.Type().Elem().Implements(textMarshalerType) ||
+ reflect.PointerTo(val.Type().Elem()).Implements(textMarshalerType)):
encodeWriteByte(w, '"')
enc := base64.NewEncoder(base64.StdEncoding, w)
if val.CanConvert(byteSliceType) {
@@ -280,18 +301,40 @@ func encode(w io.Writer, val reflect.Value, quote bool) {
}
encodeWriteByte(w, '"')
default:
- encodeArray(w, val)
+ if cycleDepth++; cycleDepth > startDetectingCyclesAfter {
+ ptr := val.UnsafePointer()
+ if _, seen := cycleSeen[ptr]; seen {
+ panic(encodeError{&EncodeValueError{
+ Value: val,
+ Str: fmt.Sprintf("encountered a cycle via %s", val.Type()),
+ }})
+ }
+ cycleSeen[ptr] = struct{}{}
+ defer delete(cycleSeen, ptr)
+ }
+ encodeArray(w, val, cycleDepth, cycleSeen)
}
case reflect.Array:
- encodeArray(w, val)
+ encodeArray(w, val, cycleDepth, cycleSeen)
case reflect.Pointer:
if val.IsNil() {
encodeWriteString(w, "null")
} else {
- encode(w, val.Elem(), quote)
+ if cycleDepth++; cycleDepth > startDetectingCyclesAfter {
+ ptr := val.UnsafePointer()
+ if _, seen := cycleSeen[ptr]; seen {
+ panic(encodeError{&EncodeValueError{
+ Value: val,
+ Str: fmt.Sprintf("encountered a cycle via %s", val.Type()),
+ }})
+ }
+ cycleSeen[ptr] = struct{}{}
+ defer delete(cycleSeen, ptr)
+ }
+ encode(w, val.Elem(), quote, cycleDepth, cycleSeen)
}
default:
- panic(encodeError{&json.UnsupportedTypeError{
+ panic(encodeError{&EncodeTypeError{
Type: val.Type(),
}})
}
@@ -310,14 +353,14 @@ func encodeString[T interface{ []byte | string }](w io.Writer, str T) {
encodeWriteByte(w, '"')
}
-func encodeArray(w io.Writer, val reflect.Value) {
+func encodeArray(w io.Writer, val reflect.Value, cycleDepth uint, cycleSeen map[unsafe.Pointer]struct{}) {
encodeWriteByte(w, '[')
n := val.Len()
for i := 0; i < n; i++ {
if i > 0 {
encodeWriteByte(w, ',')
}
- encode(w, val.Index(i), false)
+ encode(w, val.Index(i), false, cycleDepth, cycleSeen)
}
encodeWriteByte(w, ']')
}