// Copyright (C) 2022-2023 Luke Shumaker // // SPDX-License-Identifier: GPL-2.0-or-later package lowmemjson import ( "bytes" "encoding" "encoding/base64" "encoding/json" "errors" "fmt" "io" iofs "io/fs" "reflect" "sort" "strconv" "strings" "unsafe" ) // Encodable is the interface implemented by types that can encode // themselves to JSON. Encodable is a low-memory-overhead replacement // for the json.Marshaler interface. // // The io.Writer passed to EncodeJSON returns an error if invalid JSON // is written to it. type Encodable interface { EncodeJSON(w io.Writer) error } type encodeError struct { Err error } func encodeWriteByte(w io.Writer, b byte) { if err := writeByte(w, b); err != nil { panic(encodeError{err}) } } func encodeWriteString(w io.Writer, str string) { if _, err := io.WriteString(w, str); err != nil { panic(encodeError{err}) } } // An Encoder encodes and writes values to a stream of JSON elements. // // Encoder is analogous to, and has a similar API to the standar // library's encoding/json.Encoder. Differences are that rather than // having .SetEscapeHTML and .SetIndent methods, the io.Writer passed // to it may be a *ReEncoder that has these settings (and more). If // something more similar to a json.Encoder is desired, // lowmemjson/compat/json.Encoder offers those .SetEscapeHTML and // .SetIndent methods. type Encoder struct { w *ReEncoder closeAfterEncode bool } // NewEncoder returns a new Encoder that writes to w. // // If w is an *ReEncoder, then the inner backslash-escaping of // double-encoded ",string" tagged string values obeys the // *ReEncoder's BackslashEscape policy. // // An Encoder tends to make many small writes; if w.Write calls are // syscalls, then you may want to wrap w in a bufio.Writer. func NewEncoder(w io.Writer) *Encoder { re, ok := w.(*ReEncoder) if !ok { re = &ReEncoder{ Out: w, AllowMultipleValues: true, } } return &Encoder{ w: re, closeAfterEncode: re.par.StackIsEmpty(), } } // Encode encodes obj to JSON and writes that JSON to the Encoder's // output stream. // // See the [documentation for encoding/json.Marshal] for details about // the conversion Go values to JSON; Encode behaves identically to // that, with the exception that in addition to the json.Marshaler // interface it also checks for the Encodable interface. // // [documentation for encoding/json.Marshal]: https://pkg.go.dev/encoding/json@go1.18#Marshal func (enc *Encoder) Encode(obj any) (err error) { defer func() { if r := recover(); r != nil { if e, ok := r.(encodeError); ok { err = e.Err } else { panic(r) } } }() encode(enc.w, reflect.ValueOf(obj), enc.w.BackslashEscape, false, 0, map[any]struct{}{}) if enc.closeAfterEncode { return enc.w.Close() } return nil } var ( encodableType = reflect.TypeOf((*Encodable)(nil)).Elem() jsonMarshalerType = reflect.TypeOf((*json.Marshaler)(nil)).Elem() textMarshalerType = reflect.TypeOf((*encoding.TextMarshaler)(nil)).Elem() ) const startDetectingCyclesAfter = 1000 func encode(w io.Writer, val reflect.Value, escaper BackslashEscaper, quote bool, cycleDepth uint, cycleSeen map[any]struct{}) { if !val.IsValid() { encodeWriteString(w, "null") return } switch { case val.Kind() != reflect.Pointer && val.CanAddr() && reflect.PointerTo(val.Type()).Implements(encodableType): val = val.Addr() fallthrough case val.Type().Implements(encodableType): if val.Kind() == reflect.Pointer && val.IsNil() { encodeWriteString(w, "null") return } obj, ok := val.Interface().(Encodable) if !ok { encodeWriteString(w, "null") return } // Use a sub-ReEncoder to check that it's a full element. validator := &ReEncoder{Out: w, BackslashEscape: escaper} if err := obj.EncodeJSON(validator); err != nil { panic(encodeError{&EncodeMethodError{ Type: val.Type(), SourceFunc: "EncodeJSON", Err: err, }}) } if err := validator.Close(); err != nil && !errors.Is(err, iofs.ErrClosed) { panic(encodeError{&EncodeMethodError{ Type: val.Type(), SourceFunc: "EncodeJSON", Err: err, }}) } case val.Kind() != reflect.Pointer && val.CanAddr() && reflect.PointerTo(val.Type()).Implements(jsonMarshalerType): val = val.Addr() fallthrough case val.Type().Implements(jsonMarshalerType): if val.Kind() == reflect.Pointer && val.IsNil() { encodeWriteString(w, "null") return } obj, ok := val.Interface().(json.Marshaler) if !ok { encodeWriteString(w, "null") return } dat, err := obj.MarshalJSON() if err != nil { panic(encodeError{&EncodeMethodError{ Type: val.Type(), SourceFunc: "MarshalJSON", Err: err, }}) } // Use a sub-ReEncoder to check that it's a full element. validator := &ReEncoder{Out: w, BackslashEscape: escaper} if _, err := validator.Write(dat); err != nil { panic(encodeError{&EncodeMethodError{ Type: val.Type(), SourceFunc: "MarshalJSON", Err: err, }}) } if err := validator.Close(); err != nil { panic(encodeError{&EncodeMethodError{ Type: val.Type(), SourceFunc: "MarshalJSON", Err: err, }}) } case val.Kind() != reflect.Pointer && val.CanAddr() && reflect.PointerTo(val.Type()).Implements(textMarshalerType): val = val.Addr() fallthrough case val.Type().Implements(textMarshalerType): if val.Kind() == reflect.Pointer && val.IsNil() { encodeWriteString(w, "null") return } obj, ok := val.Interface().(encoding.TextMarshaler) if !ok { encodeWriteString(w, "null") return } text, err := obj.MarshalText() if err != nil { panic(encodeError{&EncodeMethodError{ Type: val.Type(), SourceFunc: "MarshalText", Err: err, }}) } encodeStringFromBytes(w, escaper, text) default: switch val.Kind() { case reflect.Bool: if quote { encodeWriteByte(w, '"') } if val.Bool() { encodeWriteString(w, "true") } else { encodeWriteString(w, "false") } if quote { encodeWriteByte(w, '"') } case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: if quote { encodeWriteByte(w, '"') } encodeWriteString(w, strconv.FormatInt(val.Int(), 10)) if quote { encodeWriteByte(w, '"') } case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: if quote { encodeWriteByte(w, '"') } encodeWriteString(w, strconv.FormatUint(val.Uint(), 10)) if quote { encodeWriteByte(w, '"') } case reflect.Float32, reflect.Float64: if quote { encodeWriteByte(w, '"') } encodeTODO(w, val) if quote { encodeWriteByte(w, '"') } case reflect.String: if val.Type() == numberType { numStr := val.String() if numStr == "" { numStr = "0" } if quote { encodeWriteByte(w, '"') } encodeWriteString(w, numStr) if quote { encodeWriteByte(w, '"') } } else { if quote { var buf bytes.Buffer encodeStringFromString(&buf, escaper, val.String()) encodeStringFromBytes(w, escaper, buf.Bytes()) } else { encodeStringFromString(w, escaper, val.String()) } } case reflect.Interface: if val.IsNil() { encodeWriteString(w, "null") } else { encode(w, val.Elem(), escaper, quote, cycleDepth, cycleSeen) } case reflect.Struct: encodeWriteByte(w, '{') empty := true for _, field := range indexStruct(val.Type()).byPos { fVal, err := val.FieldByIndexErr(field.Path) if err != nil { continue } if field.OmitEmpty && isEmptyValue(fVal) { continue } if !empty { encodeWriteByte(w, ',') } empty = false encodeStringFromString(w, escaper, field.Name) encodeWriteByte(w, ':') encode(w, fVal, escaper, field.Quote, cycleDepth, cycleSeen) } encodeWriteByte(w, '}') case reflect.Map: if val.IsNil() { encodeWriteString(w, "null") return } if val.Len() == 0 { encodeWriteString(w, "{}") return } if cycleDepth++; cycleDepth > startDetectingCyclesAfter { ptr := val.UnsafePointer() if _, seen := cycleSeen[ptr]; seen { panic(encodeError{&EncodeValueError{ Value: val, Str: fmt.Sprintf("encountered a cycle via %s", val.Type()), }}) } cycleSeen[ptr] = struct{}{} defer delete(cycleSeen, ptr) } encodeWriteByte(w, '{') type kv struct { K string V reflect.Value } kvs := make([]kv, val.Len()) iter := val.MapRange() for i := 0; iter.Next(); i++ { // TODO: Avoid buffering the map key var k strings.Builder encode(&k, iter.Key(), escaper, false, cycleDepth, cycleSeen) kStr := k.String() if kStr == "null" { kStr = `""` } if !strings.HasPrefix(kStr, `"`) { k.Reset() encodeStringFromString(&k, escaper, kStr) kStr = k.String() } kvs[i].K = kStr kvs[i].V = iter.Value() } sort.Slice(kvs, func(i, j int) bool { return kvs[i].K < kvs[j].K }) for i, kv := range kvs { if i > 0 { encodeWriteByte(w, ',') } encodeWriteString(w, kv.K) encodeWriteByte(w, ':') encode(w, kv.V, escaper, false, cycleDepth, cycleSeen) } encodeWriteByte(w, '}') case reflect.Slice: switch { case val.IsNil(): encodeWriteString(w, "null") case val.Type().Elem().Kind() == reflect.Uint8 && !(false || val.Type().Elem().Implements(encodableType) || reflect.PointerTo(val.Type().Elem()).Implements(encodableType) || val.Type().Elem().Implements(jsonMarshalerType) || reflect.PointerTo(val.Type().Elem()).Implements(jsonMarshalerType) || val.Type().Elem().Implements(textMarshalerType) || reflect.PointerTo(val.Type().Elem()).Implements(textMarshalerType)): encodeWriteByte(w, '"') enc := base64.NewEncoder(base64.StdEncoding, w) if val.CanConvert(byteSliceType) { if _, err := enc.Write(val.Convert(byteSliceType).Interface().([]byte)); err != nil { panic(encodeError{err}) } } else { // TODO: Surely there's a better way. for i, n := 0, val.Len(); i < n; i++ { var buf [1]byte buf[0] = val.Index(i).Convert(byteType).Interface().(byte) if _, err := enc.Write(buf[:]); err != nil { panic(encodeError{err}) } } } if err := enc.Close(); err != nil { panic(encodeError{err}) } encodeWriteByte(w, '"') default: if cycleDepth++; cycleDepth > startDetectingCyclesAfter { // For slices, val.UnsafePointer() doesn't return a pointer to the slice header // or anything like that, it returns a pointer *to the first element in the // slice*. That means that the pointer isn't enough to uniquely identify the // slice! So we pair the pointer with the length of the slice, which is // sufficient. ptr := struct { ptr unsafe.Pointer len int }{val.UnsafePointer(), val.Len()} if _, seen := cycleSeen[ptr]; seen { panic(encodeError{&EncodeValueError{ Value: val, Str: fmt.Sprintf("encountered a cycle via %s", val.Type()), }}) } cycleSeen[ptr] = struct{}{} defer delete(cycleSeen, ptr) } encodeArray(w, val, escaper, cycleDepth, cycleSeen) } case reflect.Array: encodeArray(w, val, escaper, cycleDepth, cycleSeen) case reflect.Pointer: if val.IsNil() { encodeWriteString(w, "null") } else { if cycleDepth++; cycleDepth > startDetectingCyclesAfter { ptr := val.UnsafePointer() if _, seen := cycleSeen[ptr]; seen { panic(encodeError{&EncodeValueError{ Value: val, Str: fmt.Sprintf("encountered a cycle via %s", val.Type()), }}) } cycleSeen[ptr] = struct{}{} defer delete(cycleSeen, ptr) } encode(w, val.Elem(), escaper, quote, cycleDepth, cycleSeen) } default: panic(encodeError{&EncodeTypeError{ Type: val.Type(), }}) } } } func encodeArray(w io.Writer, val reflect.Value, escaper BackslashEscaper, cycleDepth uint, cycleSeen map[any]struct{}) { encodeWriteByte(w, '[') n := val.Len() for i := 0; i < n; i++ { if i > 0 { encodeWriteByte(w, ',') } encode(w, val.Index(i), escaper, false, cycleDepth, cycleSeen) } encodeWriteByte(w, ']') } func encodeTODO(w io.Writer, val reflect.Value) { bs, err := json.Marshal(val.Interface()) if err != nil { panic(encodeError{err}) } if _, err := w.Write(bs); err != nil { panic(encodeError{err}) } }