diff options
Diffstat (limited to 'encode.go')
-rw-r--r-- | encode.go | 69 |
1 files changed, 56 insertions, 13 deletions
@@ -9,11 +9,13 @@ import ( "encoding" "encoding/base64" "encoding/json" + "fmt" "io" "reflect" "sort" "strconv" "strings" + "unsafe" ) type Encodable interface { @@ -46,7 +48,7 @@ func Encode(w io.Writer, obj any) (err error) { } } }() - encode(w, reflect.ValueOf(obj), false) + encode(w, reflect.ValueOf(obj), false, 0, map[unsafe.Pointer]struct{}{}) if f, ok := w.(interface{ Flush() error }); ok { return f.Flush() } @@ -59,7 +61,9 @@ var ( textMarshalerType = reflect.TypeOf((*encoding.TextMarshaler)(nil)).Elem() ) -func encode(w io.Writer, val reflect.Value, quote bool) { +const startDetectingCyclesAfter = 1000 + +func encode(w io.Writer, val reflect.Value, quote bool, cycleDepth uint, cycleSeen map[unsafe.Pointer]struct{}) { if !val.IsValid() { encodeWriteString(w, "null") return @@ -187,7 +191,7 @@ func encode(w io.Writer, val reflect.Value, quote bool) { if val.IsNil() { encodeWriteString(w, "null") } else { - encode(w, val.Elem(), quote) + encode(w, val.Elem(), quote, cycleDepth, cycleSeen) } case reflect.Struct: encodeWriteByte(w, '{') @@ -206,7 +210,7 @@ func encode(w io.Writer, val reflect.Value, quote bool) { empty = false encodeString(w, field.Name) encodeWriteByte(w, ':') - encode(w, fVal, field.Quote) + encode(w, fVal, field.Quote, cycleDepth, cycleSeen) } encodeWriteByte(w, '}') case reflect.Map: @@ -218,6 +222,17 @@ func encode(w io.Writer, val reflect.Value, quote bool) { encodeWriteString(w, "{}") return } + if cycleDepth++; cycleDepth > startDetectingCyclesAfter { + ptr := val.UnsafePointer() + if _, seen := cycleSeen[ptr]; seen { + panic(encodeError{&EncodeValueError{ + Value: val, + Str: fmt.Sprintf("encountered a cycle via %s", val.Type()), + }}) + } + cycleSeen[ptr] = struct{}{} + defer delete(cycleSeen, ptr) + } encodeWriteByte(w, '{') type kv struct { @@ -228,7 +243,7 @@ func encode(w io.Writer, val reflect.Value, quote bool) { iter := val.MapRange() for i := 0; iter.Next(); i++ { var k strings.Builder - encode(&k, iter.Key(), false) + encode(&k, iter.Key(), false, cycleDepth, cycleSeen) kStr := k.String() if kStr == "null" { kStr = `""` @@ -251,14 +266,20 @@ func encode(w io.Writer, val reflect.Value, quote bool) { } encodeWriteString(w, kv.K) encodeWriteByte(w, ':') - encode(w, kv.V, false) + encode(w, kv.V, false, cycleDepth, cycleSeen) } encodeWriteByte(w, '}') case reflect.Slice: switch { case val.IsNil(): encodeWriteString(w, "null") - case val.Type().Elem().Kind() == reflect.Uint8: + case val.Type().Elem().Kind() == reflect.Uint8 && !(false || + val.Type().Elem().Implements(encodableType) || + reflect.PointerTo(val.Type().Elem()).Implements(encodableType) || + val.Type().Elem().Implements(jsonMarshalerType) || + reflect.PointerTo(val.Type().Elem()).Implements(jsonMarshalerType) || + val.Type().Elem().Implements(textMarshalerType) || + reflect.PointerTo(val.Type().Elem()).Implements(textMarshalerType)): encodeWriteByte(w, '"') enc := base64.NewEncoder(base64.StdEncoding, w) if val.CanConvert(byteSliceType) { @@ -280,18 +301,40 @@ func encode(w io.Writer, val reflect.Value, quote bool) { } encodeWriteByte(w, '"') default: - encodeArray(w, val) + if cycleDepth++; cycleDepth > startDetectingCyclesAfter { + ptr := val.UnsafePointer() + if _, seen := cycleSeen[ptr]; seen { + panic(encodeError{&EncodeValueError{ + Value: val, + Str: fmt.Sprintf("encountered a cycle via %s", val.Type()), + }}) + } + cycleSeen[ptr] = struct{}{} + defer delete(cycleSeen, ptr) + } + encodeArray(w, val, cycleDepth, cycleSeen) } case reflect.Array: - encodeArray(w, val) + encodeArray(w, val, cycleDepth, cycleSeen) case reflect.Pointer: if val.IsNil() { encodeWriteString(w, "null") } else { - encode(w, val.Elem(), quote) + if cycleDepth++; cycleDepth > startDetectingCyclesAfter { + ptr := val.UnsafePointer() + if _, seen := cycleSeen[ptr]; seen { + panic(encodeError{&EncodeValueError{ + Value: val, + Str: fmt.Sprintf("encountered a cycle via %s", val.Type()), + }}) + } + cycleSeen[ptr] = struct{}{} + defer delete(cycleSeen, ptr) + } + encode(w, val.Elem(), quote, cycleDepth, cycleSeen) } default: - panic(encodeError{&json.UnsupportedTypeError{ + panic(encodeError{&EncodeTypeError{ Type: val.Type(), }}) } @@ -310,14 +353,14 @@ func encodeString[T interface{ []byte | string }](w io.Writer, str T) { encodeWriteByte(w, '"') } -func encodeArray(w io.Writer, val reflect.Value) { +func encodeArray(w io.Writer, val reflect.Value, cycleDepth uint, cycleSeen map[unsafe.Pointer]struct{}) { encodeWriteByte(w, '[') n := val.Len() for i := 0; i < n; i++ { if i > 0 { encodeWriteByte(w, ',') } - encode(w, val.Index(i), false) + encode(w, val.Index(i), false, cycleDepth, cycleSeen) } encodeWriteByte(w, ']') } |