From 2828fa21c0ffd2a32a108b37c0417b01abc42929 Mon Sep 17 00:00:00 2001 From: Luke Shumaker Date: Thu, 26 Jan 2023 21:02:56 -0700 Subject: Avoid doing type switching in inner functions The CPU profiler tells me that the encoder is spending a lot of time on type switches. --- decode.go | 40 ++++++------ encode.go | 14 ++-- encode_string.go | 16 ++--- internal/allwriter.go | 174 ++++++++++++++++++++++++++++++++++++++++++++++++++ internal/base64.go | 9 ++- ioutil.go | 31 --------- reencode.go | 49 ++++++++++++-- 7 files changed, 261 insertions(+), 72 deletions(-) create mode 100644 internal/allwriter.go delete mode 100644 ioutil.go diff --git a/decode.go b/decode.go index 7ae723c..91be865 100644 --- a/decode.go +++ b/decode.go @@ -565,7 +565,7 @@ func (dec *Decoder) decode(val reflect.Value, nullOK bool) { if dec.disallowUnknownFields { dec.panicType("", typ, fmt.Errorf("json: unknown field %q", name)) } - dec.scan(io.Discard) + dec.scan(internal.Discard) return } field := index.byPos[idx] @@ -749,7 +749,7 @@ func (dec *Decoder) decode(val reflect.Value, nullOK bool) { dec.decode(mValPtr.Elem(), false) val.Index(i).Set(mValPtr.Elem()) } else { - dec.scan(io.Discard) + dec.scan(internal.Discard) } i++ }) @@ -773,18 +773,18 @@ func (dec *Decoder) decode(val reflect.Value, nullOK bool) { } } -func (dec *Decoder) scan(out io.Writer) { +func (dec *Decoder) scan(out internal.RuneWriter) { limiter := dec.limitingScanner() for { c, _, err := limiter.ReadRune() if err == io.EOF { return } - _, _ = writeRune(out, c) + _, _ = out.WriteRune(c) } } -func (dec *Decoder) scanNumber(gTyp reflect.Type, out io.Writer) { +func (dec *Decoder) scanNumber(gTyp reflect.Type, out internal.RuneWriter) { if t := dec.peekRuneType(); !t.IsNumber() { dec.panicType(t.JSONType(), gTyp, nil) } @@ -991,34 +991,34 @@ func (dec *Decoder) decodeArray(gTyp reflect.Type, decodeMember func()) { } } -func (dec *Decoder) decodeString(gTyp reflect.Type, out io.Writer) { +func (dec *Decoder) decodeString(gTyp reflect.Type, out internal.RuneWriter) { dec.expectRuneType('"', internal.RuneTypeStringBeg, gTyp) var uhex [4]byte for { c, t := dec.readRune() switch t { case internal.RuneTypeStringChar: - _, _ = writeRune(out, c) + _, _ = out.WriteRune(c) case internal.RuneTypeStringEsc, internal.RuneTypeStringEscU: // do nothing case internal.RuneTypeStringEsc1: switch c { case '"': - _, _ = writeRune(out, '"') + _, _ = out.WriteRune('"') case '\\': - _, _ = writeRune(out, '\\') + _, _ = out.WriteRune('\\') case '/': - _, _ = writeRune(out, '/') + _, _ = out.WriteRune('/') case 'b': - _, _ = writeRune(out, '\b') + _, _ = out.WriteRune('\b') case 'f': - _, _ = writeRune(out, '\f') + _, _ = out.WriteRune('\f') case 'n': - _, _ = writeRune(out, '\n') + _, _ = out.WriteRune('\n') case 'r': - _, _ = writeRune(out, '\r') + _, _ = out.WriteRune('\r') case 't': - _, _ = writeRune(out, '\t') + _, _ = out.WriteRune('\t') default: panic("should not happen") } @@ -1038,12 +1038,12 @@ func (dec *Decoder) decodeString(gTyp reflect.Type, out io.Writer) { handleUnicode: if utf16.IsSurrogate(c) { if dec.peekRuneType() != internal.RuneTypeStringEsc { - _, _ = writeRune(out, utf8.RuneError) + _, _ = out.WriteRune(utf8.RuneError) break } dec.expectRune('\\', internal.RuneTypeStringEsc) if dec.peekRuneType() != internal.RuneTypeStringEscU { - _, _ = writeRune(out, utf8.RuneError) + _, _ = out.WriteRune(utf8.RuneError) break } dec.expectRune('u', internal.RuneTypeStringEscU) @@ -1063,13 +1063,13 @@ func (dec *Decoder) decodeString(gTyp reflect.Type, out io.Writer) { rune(uhex[3])<<0 d := utf16.DecodeRune(c, c2) if d == utf8.RuneError { - _, _ = writeRune(out, utf8.RuneError) + _, _ = out.WriteRune(utf8.RuneError) c = c2 goto handleUnicode } - _, _ = writeRune(out, d) + _, _ = out.WriteRune(d) } else { - _, _ = writeRune(out, c) + _, _ = out.WriteRune(c) } case internal.RuneTypeStringEnd: return diff --git a/encode.go b/encode.go index e9c7ac6..c5a29b3 100644 --- a/encode.go +++ b/encode.go @@ -18,6 +18,8 @@ import ( "strconv" "strings" "unsafe" + + "git.lukeshu.com/go/lowmemjson/internal" ) // Encodable is the interface implemented by types that can encode @@ -34,14 +36,14 @@ type encodeError struct { Err error } -func encodeWriteByte(w io.Writer, b byte) { - if err := writeByte(w, b); err != nil { +func encodeWriteByte(w io.ByteWriter, b byte) { + if err := w.WriteByte(b); err != nil { panic(encodeError{err}) } } -func encodeWriteString(w io.Writer, str string) { - if _, err := io.WriteString(w, str); err != nil { +func encodeWriteString(w io.StringWriter, str string) { + if _, err := w.WriteString(str); err != nil { panic(encodeError{err}) } } @@ -115,7 +117,7 @@ var ( const startDetectingCyclesAfter = 1000 -func encode(w io.Writer, val reflect.Value, escaper BackslashEscaper, quote bool, cycleDepth uint, cycleSeen map[any]struct{}) { +func encode(w internal.AllWriter, val reflect.Value, escaper BackslashEscaper, quote bool, cycleDepth uint, cycleSeen map[any]struct{}) { if !val.IsValid() { encodeWriteString(w, "null") return @@ -436,7 +438,7 @@ func encode(w io.Writer, val reflect.Value, escaper BackslashEscaper, quote bool } } -func encodeArray(w io.Writer, val reflect.Value, escaper BackslashEscaper, cycleDepth uint, cycleSeen map[any]struct{}) { +func encodeArray(w internal.AllWriter, val reflect.Value, escaper BackslashEscaper, cycleDepth uint, cycleSeen map[any]struct{}) { encodeWriteByte(w, '[') n := val.Len() for i := 0; i < n; i++ { diff --git a/encode_string.go b/encode_string.go index c5cb442..831a038 100644 --- a/encode_string.go +++ b/encode_string.go @@ -45,7 +45,7 @@ func writeStringShortEscape(w io.Writer, c rune) (int, error) { return w.Write(buf[:]) } -func writeStringChar(w io.Writer, c rune, wasEscaped BackslashEscapeMode, escaper BackslashEscaper) (int, error) { +func writeStringChar(w internal.AllWriter, c rune, wasEscaped BackslashEscapeMode, escaper BackslashEscaper) (int, error) { if escaper == nil { escaper = EscapeDefault } @@ -62,19 +62,19 @@ func writeStringChar(w io.Writer, c rune, wasEscaped BackslashEscapeMode, escape case c == '"' || c == '\\': // override, gotta escape these return writeStringShortEscape(w, c) default: // obey - return writeRune(w, c) + return w.WriteRune(c) } case BackslashEscapeShort: switch c { case '"', '\\', '/', '\b', '\f', '\n', '\r', '\t': // obey return writeStringShortEscape(w, c) default: // override, can't short-escape these - return writeRune(w, c) + return w.WriteRune(c) } case BackslashEscapeUnicode: switch { case c > 0xFFFF: // override, can't escape these (TODO: unless we use UTF-16 surrogates?) - return writeRune(w, c) + return w.WriteRune(c) default: // obey return writeStringUnicodeEscape(w, c) } @@ -83,7 +83,7 @@ func writeStringChar(w io.Writer, c rune, wasEscaped BackslashEscapeMode, escape } } -func encodeStringFromString(w io.Writer, escaper BackslashEscaper, str string) { +func encodeStringFromString(w internal.AllWriter, escaper BackslashEscaper, str string) { encodeWriteByte(w, '"') for _, c := range str { if _, err := writeStringChar(w, c, BackslashEscapeNone, escaper); err != nil { @@ -93,7 +93,7 @@ func encodeStringFromString(w io.Writer, escaper BackslashEscaper, str string) { encodeWriteByte(w, '"') } -func encodeStringFromBytes(w io.Writer, escaper BackslashEscaper, str []byte) { +func encodeStringFromBytes(w internal.AllWriter, escaper BackslashEscaper, str []byte) { encodeWriteByte(w, '"') for i := 0; i < len(str); { c, size := utf8.DecodeRune(str[i:]) @@ -106,6 +106,6 @@ func encodeStringFromBytes(w io.Writer, escaper BackslashEscaper, str []byte) { } func init() { - internal.EncodeStringFromString = func(w io.Writer, s string) { encodeStringFromString(w, nil, s) } - internal.EncodeStringFromBytes = func(w io.Writer, s []byte) { encodeStringFromBytes(w, nil, s) } + internal.EncodeStringFromString = func(w io.Writer, s string) { encodeStringFromString(internal.NewAllWriter(w), nil, s) } + internal.EncodeStringFromBytes = func(w io.Writer, s []byte) { encodeStringFromBytes(internal.NewAllWriter(w), nil, s) } } diff --git a/internal/allwriter.go b/internal/allwriter.go new file mode 100644 index 0000000..187aa8e --- /dev/null +++ b/internal/allwriter.go @@ -0,0 +1,174 @@ +// Copyright (C) 2023 Luke Shumaker +// +// SPDX-License-Identifier: GPL-2.0-or-later + +package internal + +import ( + "io" + "unicode/utf8" +) + +// interfaces ///////////////////////////////////////////////////////////////// + +type RuneWriter interface { + WriteRune(rune) (int, error) +} + +// An AllWriter is the union of several common writer interfaces. +type AllWriter interface { + io.Writer + io.ByteWriter + RuneWriter + io.StringWriter +} + +// implementations //////////////////////////////////////////////////////////// + +func WriteByte(w io.Writer, b byte) error { + var buf [1]byte + buf[0] = b + _, err := w.Write(buf[:]) + return err +} + +func WriteRune(w io.Writer, r rune) (int, error) { + var buf [utf8.UTFMax]byte + n := utf8.EncodeRune(buf[:], r) + return w.Write(buf[:n]) +} + +func WriteString(w io.Writer, s string) (int, error) { + return w.Write([]byte(s)) +} + +// wrappers /////////////////////////////////////////////////////////////////// + +// NNN + +type ( + writerNNN interface{ io.Writer } + writerNNNWrapper struct{ writerNNN } +) + +func (w writerNNNWrapper) WriteByte(b byte) error { return WriteByte(w, b) } +func (w writerNNNWrapper) WriteRune(r rune) (int, error) { return WriteRune(w, r) } +func (w writerNNNWrapper) WriteString(s string) (int, error) { return WriteString(w, s) } + +// NNY + +type ( + writerNNY interface { + io.Writer + io.StringWriter + } + writerNNYWrapper struct{ writerNNY } +) + +func (w writerNNYWrapper) WriteByte(b byte) error { return WriteByte(w, b) } +func (w writerNNYWrapper) WriteRune(r rune) (int, error) { return WriteRune(w, r) } + +// NYN + +type ( + writerNYN interface { + io.Writer + RuneWriter + } + writerNYNWrapper struct{ writerNYN } +) + +func (w writerNYNWrapper) WriteByte(b byte) error { return WriteByte(w, b) } +func (w writerNYNWrapper) WriteString(s string) (int, error) { return WriteString(w, s) } + +// NYY + +type ( + writerNYY interface { + io.Writer + RuneWriter + io.StringWriter + } + writerNYYWrapper struct{ writerNYY } +) + +func (w writerNYYWrapper) WriteByte(b byte) error { return WriteByte(w, b) } + +// YNN + +type ( + writerYNN interface { + io.Writer + io.ByteWriter + } + writerYNNWrapper struct{ writerYNN } +) + +func (w writerYNNWrapper) WriteRune(r rune) (int, error) { return WriteRune(w, r) } +func (w writerYNNWrapper) WriteString(s string) (int, error) { return WriteString(w, s) } + +// YNY + +type ( + writerYNY interface { + io.Writer + io.ByteWriter + io.StringWriter + } + writerYNYWrapper struct{ writerYNY } +) + +func (w writerYNYWrapper) WriteRune(r rune) (int, error) { return WriteRune(w, r) } + +// YYN + +type ( + writerYYN interface { + io.Writer + io.ByteWriter + RuneWriter + } + writerYYNWrapper struct{ writerYYN } +) + +func (w writerYYNWrapper) WriteString(s string) (int, error) { return WriteString(w, s) } + +// NewAllWriter wraps an io.Writer turning it in to an AllWriter. If +// the io.Writer already has any of the other write methods, then its +// native version of those methods are used. +func NewAllWriter(inner io.Writer) AllWriter { + switch inner := inner.(type) { + // 3 Y bits + case AllWriter: // YYY: + return inner + // 2 Y bits + case writerNYY: + return writerNYYWrapper{writerNYY: inner} + case writerYNY: + return writerYNYWrapper{writerYNY: inner} + case writerYYN: + return writerYYNWrapper{writerYYN: inner} + // 1 Y bit + case writerNNY: + return writerNNYWrapper{writerNNY: inner} + case writerNYN: + return writerNYNWrapper{writerNYN: inner} + case writerYNN: + return writerYNNWrapper{writerYNN: inner} + // 0 Y bits + default: // NNN: + return writerNNNWrapper{writerNNN: inner} + } +} + +// discard ///////////////////////////////////////////////////////////////////// + +// Discard is like io.Discard, but implements AllWriter. +var Discard = discard{} + +type discard struct{} + +func (discard) Write(p []byte) (int, error) { return len(p), nil } +func (discard) WriteByte(b byte) error { return nil } +func (discard) WriteRune(r rune) (int, error) { return 0, nil } +func (discard) WriteString(s string) (int, error) { return len(s), nil } diff --git a/internal/base64.go b/internal/base64.go index 15adbf4..291a229 100644 --- a/internal/base64.go +++ b/internal/base64.go @@ -19,7 +19,10 @@ type base64Decoder struct { bufLen int } -func NewBase64Decoder(w io.Writer) io.WriteCloser { +func NewBase64Decoder(w io.Writer) interface { + io.WriteCloser + RuneWriter +} { return &base64Decoder{ dst: w, } @@ -112,6 +115,10 @@ func (dec *base64Decoder) Write(dat []byte) (int, error) { return len(dat), nil } +func (dec *base64Decoder) WriteRune(r rune) (int, error) { + return WriteRune(dec, r) +} + func (dec *base64Decoder) Close() error { if dec.bufLen == 0 { return nil diff --git a/ioutil.go b/ioutil.go deleted file mode 100644 index a53eac3..0000000 --- a/ioutil.go +++ /dev/null @@ -1,31 +0,0 @@ -// Copyright (C) 2022-2023 Luke Shumaker -// -// SPDX-License-Identifier: GPL-2.0-or-later - -package lowmemjson - -import ( - "io" - "unicode/utf8" -) - -func writeByte(w io.Writer, c byte) error { - if br, ok := w.(interface{ WriteByte(byte) error }); ok { - return br.WriteByte(c) - } - var buf [1]byte - buf[0] = c - if _, err := w.Write(buf[:]); err != nil { - return err - } - return nil -} - -func writeRune(w io.Writer, c rune) (int, error) { - if rw, ok := w.(interface{ WriteRune(rune) (int, error) }); ok { - return rw.WriteRune(c) - } - var buf [utf8.UTFMax]byte - n := utf8.EncodeRune(buf[:], c) - return w.Write(buf[:n]) -} diff --git a/reencode.go b/reencode.go index 876af62..393e8c6 100644 --- a/reencode.go +++ b/reencode.go @@ -71,7 +71,7 @@ type ReEncoderConfig struct { func NewReEncoder(out io.Writer, cfg ReEncoderConfig) *ReEncoder { return &ReEncoder{ ReEncoderConfig: cfg, - out: out, + out: internal.NewAllWriter(out), } } @@ -85,9 +85,9 @@ func NewReEncoder(out io.Writer, cfg ReEncoderConfig) *ReEncoder { // The memory use of a ReEncoder is O( (CompactIfUnder+1)^2 + depth). type ReEncoder struct { ReEncoderConfig - out io.Writer + out internal.AllWriter - // state: .Write's utf8-decoding buffer + // state: .Write's and .WriteString's utf8-decoding buffer buf [utf8.UTFMax]byte bufLen int @@ -119,6 +119,11 @@ type speculation struct { // public API ////////////////////////////////////////////////////////////////// +var ( + _ internal.AllWriter = (*ReEncoder)(nil) + _ io.Closer = (*ReEncoder)(nil) +) + // Write implements io.Writer; it does what you'd expect. // // It is worth noting that Write returns the number of bytes consumed @@ -152,6 +157,38 @@ func (enc *ReEncoder) Write(p []byte) (int, error) { return len(p), nil } +// WriteString implements io.StringWriter; it does what you'd expect, +// but see the notes on the Write method. +func (enc *ReEncoder) WriteString(p string) (int, error) { + if len(p) == 0 { + return 0, nil + } + var n int + if enc.bufLen > 0 { + copy(enc.buf[enc.bufLen:], p) + c, size := utf8.DecodeRune(enc.buf[:]) + n += size - enc.bufLen + enc.bufLen = 0 + if _, err := enc.WriteRune(c); err != nil { + return 0, err + } + } + for utf8.FullRuneInString(p[n:]) { + c, size := utf8.DecodeRuneInString(p[n:]) + if _, err := enc.WriteRune(c); err != nil { + return n, err + } + n += size + } + enc.bufLen = copy(enc.buf[:], p[n:]) + return len(p), nil +} + +// WriteByte implements io.ByteWriter; it does what you'd expect. +func (enc *ReEncoder) WriteByte(b byte) error { + return internal.WriteByte(enc, b) +} + // Close implements io.Closer; it does what you'd expect, mostly. // // The *ReEncoder may continue to be written to with new JSON values @@ -471,7 +508,7 @@ func (enc *ReEncoder) handleRuneMain(c rune, t internal.RuneType) error { } func (enc *ReEncoder) emitByte(c byte) error { - err := writeByte(enc.out, c) + err := enc.out.WriteByte(c) if err == nil { enc.written++ } @@ -488,12 +525,12 @@ func (enc *ReEncoder) emitNlIndent() error { return err } if enc.Prefix != "" { - if err := enc.emit(io.WriteString(enc.out, enc.Prefix)); err != nil { + if err := enc.emit(enc.out.WriteString(enc.Prefix)); err != nil { return err } } for i := 0; i < enc.handleRuneState.curIndent; i++ { - if err := enc.emit(io.WriteString(enc.out, enc.Indent)); err != nil { + if err := enc.emit(enc.out.WriteString(enc.Indent)); err != nil { return err } } -- cgit v1.2.3-2-g168b