diff options
Diffstat (limited to 'encode.go')
-rw-r--r-- | encode.go | 333 |
1 files changed, 333 insertions, 0 deletions
diff --git a/encode.go b/encode.go new file mode 100644 index 0000000..377b9b9 --- /dev/null +++ b/encode.go @@ -0,0 +1,333 @@ +// Copyright (C) 2022 Luke Shumaker <lukeshu@lukeshu.com> +// +// SPDX-License-Identifier: GPL-2.0-or-later + +package lowmemjson + +import ( + "bytes" + "encoding" + "encoding/base64" + "encoding/json" + "io" + "reflect" + "sort" + "strconv" + "strings" +) + +type Encodable interface { + EncodeJSON(w io.Writer) error +} + +type encodeError struct { + Err error +} + +func encodeWriteByte(w io.Writer, b byte) { + if err := writeByte(w, b); err != nil { + panic(encodeError{err}) + } +} + +func encodeWriteString(w io.Writer, str string) { + if _, err := io.WriteString(w, str); err != nil { + panic(encodeError{err}) + } +} + +func Encode(w io.Writer, obj any) (err error) { + defer func() { + if r := recover(); r != nil { + if e, ok := r.(encodeError); ok { + err = e.Err + } else { + panic(r) + } + } + }() + encode(w, reflect.ValueOf(obj), false) + if f, ok := w.(interface{ Flush() error }); ok { + return f.Flush() + } + return nil +} + +var ( + encodableType = reflect.TypeOf((*Encodable)(nil)).Elem() + jsonMarshalerType = reflect.TypeOf((*json.Marshaler)(nil)).Elem() + textMarshalerType = reflect.TypeOf((*encoding.TextMarshaler)(nil)).Elem() +) + +func encode(w io.Writer, val reflect.Value, quote bool) { + if !val.IsValid() { + encodeWriteString(w, "null") + return + } + switch { + + case val.Kind() != reflect.Pointer && val.CanAddr() && reflect.PointerTo(val.Type()).Implements(encodableType): + val = val.Addr() + fallthrough + case val.Type().Implements(encodableType): + if val.Kind() == reflect.Pointer && val.IsNil() { + encodeWriteString(w, "null") + return + } + obj, ok := val.Interface().(Encodable) + if !ok { + encodeWriteString(w, "null") + return + } + if err := obj.EncodeJSON(w); err != nil { + panic(encodeError{err}) + } + + case val.Kind() != reflect.Pointer && val.CanAddr() && reflect.PointerTo(val.Type()).Implements(jsonMarshalerType): + val = val.Addr() + fallthrough + case val.Type().Implements(jsonMarshalerType): + if val.Kind() == reflect.Pointer && val.IsNil() { + encodeWriteString(w, "null") + return + } + obj, ok := val.Interface().(json.Marshaler) + if !ok { + encodeWriteString(w, "null") + return + } + dat, err := obj.MarshalJSON() + if err != nil { + panic(encodeError{err}) + } + if _, err := w.Write(dat); err != nil { + panic(encodeError{err}) + } + + case val.Kind() != reflect.Pointer && val.CanAddr() && reflect.PointerTo(val.Type()).Implements(textMarshalerType): + val = val.Addr() + fallthrough + case val.Type().Implements(textMarshalerType): + if val.Kind() == reflect.Pointer && val.IsNil() { + encodeWriteString(w, "null") + return + } + obj, ok := val.Interface().(encoding.TextMarshaler) + if !ok { + encodeWriteString(w, "null") + return + } + text, err := obj.MarshalText() + if err != nil { + panic(encodeError{err}) + } + encodeString(w, text) + + default: + switch val.Kind() { + case reflect.Bool: + if quote { + encodeWriteByte(w, '"') + } + if val.Bool() { + encodeWriteString(w, "true") + } else { + encodeWriteString(w, "false") + } + if quote { + encodeWriteByte(w, '"') + } + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + if quote { + encodeWriteByte(w, '"') + } + encodeWriteString(w, strconv.FormatInt(val.Int(), 10)) + if quote { + encodeWriteByte(w, '"') + } + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: + if quote { + encodeWriteByte(w, '"') + } + encodeWriteString(w, strconv.FormatUint(val.Uint(), 10)) + if quote { + encodeWriteByte(w, '"') + } + case reflect.Float32, reflect.Float64: + if quote { + encodeWriteByte(w, '"') + } + encodeTODO(w, val) + if quote { + encodeWriteByte(w, '"') + } + case reflect.String: + if val.Type() == numberType { + numStr := val.String() + if numStr == "" { + numStr = "0" + } + if quote { + encodeWriteByte(w, '"') + } + encodeWriteString(w, numStr) + if quote { + encodeWriteByte(w, '"') + } + } else { + if quote { + var buf bytes.Buffer + encodeString(&buf, val.String()) + encodeString(w, buf.Bytes()) + } else { + encodeString(w, val.String()) + } + } + case reflect.Interface: + if val.IsNil() { + encodeWriteString(w, "null") + } else { + encode(w, val.Elem(), quote) + } + case reflect.Struct: + encodeWriteByte(w, '{') + empty := true + for _, field := range indexStruct(val.Type()).byPos { + fVal, err := val.FieldByIndexErr(field.Path) + if err != nil { + continue + } + if field.OmitEmpty && isEmptyValue(fVal) { + continue + } + if !empty { + encodeWriteByte(w, ',') + } + empty = false + encodeString(w, field.Name) + encodeWriteByte(w, ':') + encode(w, fVal, field.Quote) + } + encodeWriteByte(w, '}') + case reflect.Map: + if val.IsNil() { + encodeWriteString(w, "null") + return + } + if val.Len() == 0 { + encodeWriteString(w, "{}") + return + } + encodeWriteByte(w, '{') + + type kv struct { + K string + V reflect.Value + } + kvs := make([]kv, val.Len()) + iter := val.MapRange() + for i := 0; iter.Next(); i++ { + var k strings.Builder + encode(&k, iter.Key(), false) + kStr := k.String() + if kStr == "null" { + kStr = `""` + } + if !strings.HasPrefix(kStr, `"`) { + k.Reset() + encodeString(&k, kStr) + kStr = k.String() + } + kvs[i].K = kStr + kvs[i].V = iter.Value() + } + sort.Slice(kvs, func(i, j int) bool { + return kvs[i].K < kvs[j].K + }) + + for i, kv := range kvs { + if i > 0 { + encodeWriteByte(w, ',') + } + encodeWriteString(w, kv.K) + encodeWriteByte(w, ':') + encode(w, kv.V, false) + } + encodeWriteByte(w, '}') + case reflect.Slice: + switch { + case val.IsNil(): + encodeWriteString(w, "null") + case val.Type().Elem().Kind() == reflect.Uint8: + encodeWriteByte(w, '"') + enc := base64.NewEncoder(base64.StdEncoding, w) + if val.CanConvert(byteSliceType) { + if _, err := enc.Write(val.Convert(byteSliceType).Interface().([]byte)); err != nil { + panic(encodeError{err}) + } + } else { + // TODO: Surely there's a better way. + for i, n := 0, val.Len(); i < n; i++ { + var buf [1]byte + buf[0] = val.Index(i).Convert(byteType).Interface().(byte) + if _, err := enc.Write(buf[:]); err != nil { + panic(encodeError{err}) + } + } + } + if err := enc.Close(); err != nil { + panic(encodeError{err}) + } + encodeWriteByte(w, '"') + default: + encodeArray(w, val) + } + case reflect.Array: + encodeArray(w, val) + case reflect.Pointer: + if val.IsNil() { + encodeWriteString(w, "null") + } else { + encode(w, val.Elem(), quote) + } + default: + panic(encodeError{&json.UnsupportedTypeError{ + Type: val.Type(), + }}) + } + } +} + +func encodeString[T interface{ []byte | string }](w io.Writer, str T) { + encodeWriteByte(w, '"') + for i := 0; i < len(str); { + c, size := decodeRune(str[i:]) + if _, err := writeStringChar(w, c, false, nil); err != nil { + panic(encodeError{err}) + } + i += size + } + encodeWriteByte(w, '"') +} + +func encodeArray(w io.Writer, val reflect.Value) { + encodeWriteByte(w, '[') + n := val.Len() + for i := 0; i < n; i++ { + if i > 0 { + encodeWriteByte(w, ',') + } + encode(w, val.Index(i), false) + } + encodeWriteByte(w, ']') +} + +func encodeTODO(w io.Writer, val reflect.Value) { + bs, err := json.Marshal(val.Interface()) + if err != nil { + panic(encodeError{err}) + } + if _, err := w.Write(bs); err != nil { + panic(encodeError{err}) + } +} |