summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
author Luke Shumaker <lukeshu@lukeshu.com> 2023-01-26 21:02:56 -0700
committer Luke Shumaker <lukeshu@lukeshu.com> 2023-01-30 22:00:25 -0700
commit 2828fa21c0ffd2a32a108b37c0417b01abc42929 (patch)
tree ae671b894fa952e01a410c94fe27e1d0fec37e80
parent 8aa12d3cb043859229810947da6c52e600d34b55 (diff)
Avoid doing type switching in inner functions
The CPU profiler tells me that the encoder is spending a lot of time on type switches.
-rw-r--r-- decode.go 40
-rw-r--r-- encode.go 14
-rw-r--r-- encode_string.go 16
-rw-r--r-- internal/allwriter.go 174
-rw-r--r-- internal/base64.go 9
-rw-r--r-- ioutil.go 31
-rw-r--r-- reencode.go 49
7 files changed, 261 insertions, 72 deletions
diff --git a/decode.go b/decode.go
index 7ae723c..91be865 100644
--- a/decode.go
+++ b/decode.go
@@ -565,7 +565,7 @@ func (dec *Decoder) decode(val reflect.Value, nullOK bool) {
if dec.disallowUnknownFields {
dec.panicType("", typ, fmt.Errorf("json: unknown field %q", name))
}
- dec.scan(io.Discard)
+ dec.scan(internal.Discard)
return
}
field := index.byPos[idx]
@@ -749,7 +749,7 @@ func (dec *Decoder) decode(val reflect.Value, nullOK bool) {
dec.decode(mValPtr.Elem(), false)
val.Index(i).Set(mValPtr.Elem())
} else {
- dec.scan(io.Discard)
+ dec.scan(internal.Discard)
}
i++
})
@@ -773,18 +773,18 @@ func (dec *Decoder) decode(val reflect.Value, nullOK bool) {
}
}
-func (dec *Decoder) scan(out io.Writer) {
+func (dec *Decoder) scan(out internal.RuneWriter) {
limiter := dec.limitingScanner()
for {
c, _, err := limiter.ReadRune()
if err == io.EOF {
return
}
- _, _ = writeRune(out, c)
+ _, _ = out.WriteRune(c)
}
}
-func (dec *Decoder) scanNumber(gTyp reflect.Type, out io.Writer) {
+func (dec *Decoder) scanNumber(gTyp reflect.Type, out internal.RuneWriter) {
if t := dec.peekRuneType(); !t.IsNumber() {
dec.panicType(t.JSONType(), gTyp, nil)
}
@@ -991,34 +991,34 @@ func (dec *Decoder) decodeArray(gTyp reflect.Type, decodeMember func()) {
}
}
-func (dec *Decoder) decodeString(gTyp reflect.Type, out io.Writer) {
+func (dec *Decoder) decodeString(gTyp reflect.Type, out internal.RuneWriter) {
dec.expectRuneType('"', internal.RuneTypeStringBeg, gTyp)
var uhex [4]byte
for {
c, t := dec.readRune()
switch t {
case internal.RuneTypeStringChar:
- _, _ = writeRune(out, c)
+ _, _ = out.WriteRune(c)
case internal.RuneTypeStringEsc, internal.RuneTypeStringEscU:
// do nothing
case internal.RuneTypeStringEsc1:
switch c {
case '"':
- _, _ = writeRune(out, '"')
+ _, _ = out.WriteRune('"')
case '\\':
- _, _ = writeRune(out, '\\')
+ _, _ = out.WriteRune('\\')
case '/':
- _, _ = writeRune(out, '/')
+ _, _ = out.WriteRune('/')
case 'b':
- _, _ = writeRune(out, '\b')
+ _, _ = out.WriteRune('\b')
case 'f':
- _, _ = writeRune(out, '\f')
+ _, _ = out.WriteRune('\f')
case 'n':
- _, _ = writeRune(out, '\n')
+ _, _ = out.WriteRune('\n')
case 'r':
- _, _ = writeRune(out, '\r')
+ _, _ = out.WriteRune('\r')
case 't':
- _, _ = writeRune(out, '\t')
+ _, _ = out.WriteRune('\t')
default:
panic("should not happen")
}
@@ -1038,12 +1038,12 @@ func (dec *Decoder) decodeString(gTyp reflect.Type, out io.Writer) {
handleUnicode:
if utf16.IsSurrogate(c) {
if dec.peekRuneType() != internal.RuneTypeStringEsc {
- _, _ = writeRune(out, utf8.RuneError)
+ _, _ = out.WriteRune(utf8.RuneError)
break
}
dec.expectRune('\\', internal.RuneTypeStringEsc)
if dec.peekRuneType() != internal.RuneTypeStringEscU {
- _, _ = writeRune(out, utf8.RuneError)
+ _, _ = out.WriteRune(utf8.RuneError)
break
}
dec.expectRune('u', internal.RuneTypeStringEscU)
@@ -1063,13 +1063,13 @@ func (dec *Decoder) decodeString(gTyp reflect.Type, out io.Writer) {
rune(uhex[3])<<0
d := utf16.DecodeRune(c, c2)
if d == utf8.RuneError {
- _, _ = writeRune(out, utf8.RuneError)
+ _, _ = out.WriteRune(utf8.RuneError)
c = c2
goto handleUnicode
}
- _, _ = writeRune(out, d)
+ _, _ = out.WriteRune(d)
} else {
- _, _ = writeRune(out, c)
+ _, _ = out.WriteRune(c)
}
case internal.RuneTypeStringEnd:
return
diff --git a/encode.go b/encode.go
index e9c7ac6..c5a29b3 100644
--- a/encode.go
+++ b/encode.go
@@ -18,6 +18,8 @@ import (
"strconv"
"strings"
"unsafe"
+
+ "git.lukeshu.com/go/lowmemjson/internal"
)
// Encodable is the interface implemented by types that can encode
@@ -34,14 +36,14 @@ type encodeError struct {
Err error
}
-func encodeWriteByte(w io.Writer, b byte) {
- if err := writeByte(w, b); err != nil {
+func encodeWriteByte(w io.ByteWriter, b byte) {
+ if err := w.WriteByte(b); err != nil {
panic(encodeError{err})
}
}
-func encodeWriteString(w io.Writer, str string) {
- if _, err := io.WriteString(w, str); err != nil {
+func encodeWriteString(w io.StringWriter, str string) {
+ if _, err := w.WriteString(str); err != nil {
panic(encodeError{err})
}
}
@@ -115,7 +117,7 @@ var (
const startDetectingCyclesAfter = 1000
-func encode(w io.Writer, val reflect.Value, escaper BackslashEscaper, quote bool, cycleDepth uint, cycleSeen map[any]struct{}) {
+func encode(w internal.AllWriter, val reflect.Value, escaper BackslashEscaper, quote bool, cycleDepth uint, cycleSeen map[any]struct{}) {
if !val.IsValid() {
encodeWriteString(w, "null")
return
@@ -436,7 +438,7 @@ func encode(w io.Writer, val reflect.Value, escaper BackslashEscaper, quote bool
}
}
-func encodeArray(w io.Writer, val reflect.Value, escaper BackslashEscaper, cycleDepth uint, cycleSeen map[any]struct{}) {
+func encodeArray(w internal.AllWriter, val reflect.Value, escaper BackslashEscaper, cycleDepth uint, cycleSeen map[any]struct{}) {
encodeWriteByte(w, '[')
n := val.Len()
for i := 0; i < n; i++ {
diff --git a/encode_string.go b/encode_string.go
index c5cb442..831a038 100644
--- a/encode_string.go
+++ b/encode_string.go
@@ -45,7 +45,7 @@ func writeStringShortEscape(w io.Writer, c rune) (int, error) {
return w.Write(buf[:])
}
-func writeStringChar(w io.Writer, c rune, wasEscaped BackslashEscapeMode, escaper BackslashEscaper) (int, error) {
+func writeStringChar(w internal.AllWriter, c rune, wasEscaped BackslashEscapeMode, escaper BackslashEscaper) (int, error) {
if escaper == nil {
escaper = EscapeDefault
}
@@ -62,19 +62,19 @@ func writeStringChar(w io.Writer, c rune, wasEscaped BackslashEscapeMode, escape
case c == '"' || c == '\\': // override, gotta escape these
return writeStringShortEscape(w, c)
default: // obey
- return writeRune(w, c)
+ return w.WriteRune(c)
}
case BackslashEscapeShort:
switch c {
case '"', '\\', '/', '\b', '\f', '\n', '\r', '\t': // obey
return writeStringShortEscape(w, c)
default: // override, can't short-escape these
- return writeRune(w, c)
+ return w.WriteRune(c)
}
case BackslashEscapeUnicode:
switch {
case c > 0xFFFF: // override, can't escape these (TODO: unless we use UTF-16 surrogates?)
- return writeRune(w, c)
+ return w.WriteRune(c)
default: // obey
return writeStringUnicodeEscape(w, c)
}
@@ -83,7 +83,7 @@ func writeStringChar(w io.Writer, c rune, wasEscaped BackslashEscapeMode, escape
}
}
-func encodeStringFromString(w io.Writer, escaper BackslashEscaper, str string) {
+func encodeStringFromString(w internal.AllWriter, escaper BackslashEscaper, str string) {
encodeWriteByte(w, '"')
for _, c := range str {
if _, err := writeStringChar(w, c, BackslashEscapeNone, escaper); err != nil {
@@ -93,7 +93,7 @@ func encodeStringFromString(w io.Writer, escaper BackslashEscaper, str string) {
encodeWriteByte(w, '"')
}
-func encodeStringFromBytes(w io.Writer, escaper BackslashEscaper, str []byte) {
+func encodeStringFromBytes(w internal.AllWriter, escaper BackslashEscaper, str []byte) {
encodeWriteByte(w, '"')
for i := 0; i < len(str); {
c, size := utf8.DecodeRune(str[i:])
@@ -106,6 +106,6 @@ func encodeStringFromBytes(w io.Writer, escaper BackslashEscaper, str []byte) {
}
func init() {
- internal.EncodeStringFromString = func(w io.Writer, s string) { encodeStringFromString(w, nil, s) }
- internal.EncodeStringFromBytes = func(w io.Writer, s []byte) { encodeStringFromBytes(w, nil, s) }
+ internal.EncodeStringFromString = func(w io.Writer, s string) { encodeStringFromString(internal.NewAllWriter(w), nil, s) }
+ internal.EncodeStringFromBytes = func(w io.Writer, s []byte) { encodeStringFromBytes(internal.NewAllWriter(w), nil, s) }
}
diff --git a/internal/allwriter.go b/internal/allwriter.go
new file mode 100644
index 0000000..187aa8e
--- /dev/null
+++ b/internal/allwriter.go
@@ -0,0 +1,174 @@
+// Copyright (C) 2023 Luke Shumaker <lukeshu@lukeshu.com>
+//
+// SPDX-License-Identifier: GPL-2.0-or-later
+
+package internal
+
+import (
+ "io"
+ "unicode/utf8"
+)
+
+// interfaces /////////////////////////////////////////////////////////////////
+
+// RuneWriter is the interface that wraps the WriteRune method.
+//
+// WriteRune writes a single rune and returns a byte count and an
+// error, in the same shape as io.Writer's Write.
+type RuneWriter interface {
+ WriteRune(rune) (int, error)
+}
+
+// An AllWriter is the union of several common writer interfaces:
+// io.Writer, io.ByteWriter, RuneWriter, and io.StringWriter.
+//
+// Code that holds an AllWriter can call any of the four write methods
+// directly, without a run-time type assertion on the writer.
+type AllWriter interface {
+ io.Writer
+ io.ByteWriter
+ RuneWriter
+ io.StringWriter
+}
+
+// implementations ////////////////////////////////////////////////////////////
+
+// WriteByte writes the single byte b to w using only w.Write.  It
+// returns the error (if any) from that call; the byte count reported
+// by w.Write is discarded.
+//
+// It does not check whether w has a WriteByte method of its own.
+func WriteByte(w io.Writer, b byte) error {
+ var buf [1]byte
+ buf[0] = b
+ _, err := w.Write(buf[:])
+ return err
+}
+
+// WriteRune UTF-8 encodes r with utf8.EncodeRune and writes the
+// encoded bytes to w with a single w.Write call, returning that
+// call's results unchanged.
+//
+// It does not check whether w has a WriteRune method of its own.
+func WriteRune(w io.Writer, r rune) (int, error) {
+ var buf [utf8.UTFMax]byte
+ n := utf8.EncodeRune(buf[:], r)
+ return w.Write(buf[:n])
+}
+
+// WriteString writes s to w by converting it to a []byte and calling
+// w.Write, returning that call's results unchanged.
+//
+// Unlike io.WriteString, it never calls a WriteString method on w;
+// it can therefore be used to implement a type's own WriteString
+// method in terms of its Write method without recursing.
+func WriteString(w io.Writer, s string) (int, error) {
+ return w.Write([]byte(s))
+}
+
+// wrappers ///////////////////////////////////////////////////////////////////
+
+// Each wrapper type below is named with three Y/N flags that record
+// which of the optional methods the wrapped writer already has, in
+// the order WriteByte, WriteRune, WriteString.  A wrapper embeds the
+// wrapped writer (so the methods it does have are promoted) and
+// implements each missing method with the package-level function of
+// the same name, which uses only Write.
+
+// NNN: the inner writer has only Write.
+
+type (
+ writerNNN interface{ io.Writer }
+ writerNNNWrapper struct{ writerNNN }
+)
+
+func (w writerNNNWrapper) WriteByte(b byte) error { return WriteByte(w, b) }
+func (w writerNNNWrapper) WriteRune(r rune) (int, error) { return WriteRune(w, r) }
+func (w writerNNNWrapper) WriteString(s string) (int, error) { return WriteString(w, s) }
+
+// NNY: the inner writer has Write and WriteString.
+
+type (
+ writerNNY interface {
+ io.Writer
+ io.StringWriter
+ }
+ writerNNYWrapper struct{ writerNNY }
+)
+
+func (w writerNNYWrapper) WriteByte(b byte) error { return WriteByte(w, b) }
+func (w writerNNYWrapper) WriteRune(r rune) (int, error) { return WriteRune(w, r) }
+
+// NYN: the inner writer has Write and WriteRune.
+
+type (
+ writerNYN interface {
+ io.Writer
+ RuneWriter
+ }
+ writerNYNWrapper struct{ writerNYN }
+)
+
+func (w writerNYNWrapper) WriteByte(b byte) error { return WriteByte(w, b) }
+func (w writerNYNWrapper) WriteString(s string) (int, error) { return WriteString(w, s) }
+
+// NYY: the inner writer has Write, WriteRune, and WriteString.
+
+type (
+ writerNYY interface {
+ io.Writer
+ RuneWriter
+ io.StringWriter
+ }
+ writerNYYWrapper struct{ writerNYY }
+)
+
+func (w writerNYYWrapper) WriteByte(b byte) error { return WriteByte(w, b) }
+
+// YNN: the inner writer has Write and WriteByte.
+
+type (
+ writerYNN interface {
+ io.Writer
+ io.ByteWriter
+ }
+ writerYNNWrapper struct{ writerYNN }
+)
+
+func (w writerYNNWrapper) WriteRune(r rune) (int, error) { return WriteRune(w, r) }
+func (w writerYNNWrapper) WriteString(s string) (int, error) { return WriteString(w, s) }
+
+// YNY: the inner writer has Write, WriteByte, and WriteString.
+
+type (
+ writerYNY interface {
+ io.Writer
+ io.ByteWriter
+ io.StringWriter
+ }
+ writerYNYWrapper struct{ writerYNY }
+)
+
+func (w writerYNYWrapper) WriteRune(r rune) (int, error) { return WriteRune(w, r) }
+
+// YYN: the inner writer has Write, WriteByte, and WriteRune.
+
+type (
+ writerYYN interface {
+ io.Writer
+ io.ByteWriter
+ RuneWriter
+ }
+ writerYYNWrapper struct{ writerYYN }
+)
+
+func (w writerYYNWrapper) WriteString(s string) (int, error) { return WriteString(w, s) }
+
+// NewAllWriter wraps an io.Writer, turning it into an AllWriter.  If
+// the io.Writer already has any of the other write methods, then its
+// native versions of those methods are used.
+//
+// The type switch tests the interfaces with the most methods first,
+// because a writer that has several of the optional methods also
+// satisfies the smaller interfaces and the first matching case wins.
+func NewAllWriter(inner io.Writer) AllWriter {
+ switch inner := inner.(type) {
+ // 3 Y bits
+ case AllWriter: // YYY:
+ return inner
+ // 2 Y bits
+ case writerNYY:
+ return writerNYYWrapper{writerNYY: inner}
+ case writerYNY:
+ return writerYNYWrapper{writerYNY: inner}
+ case writerYYN:
+ return writerYYNWrapper{writerYYN: inner}
+ // 1 Y bit
+ case writerNNY:
+ return writerNNYWrapper{writerNNY: inner}
+ case writerNYN:
+ return writerNYNWrapper{writerNYN: inner}
+ case writerYNN:
+ return writerYNNWrapper{writerYNN: inner}
+ // 0 Y bits
+ default: // NNN:
+ return writerNNNWrapper{writerNNN: inner}
+ }
+}
+
+// discard /////////////////////////////////////////////////////////////////////
+
+// Discard is like io.Discard, but implements AllWriter.
+var Discard = discard{}
+
+// discard is a writer whose methods all succeed without storing or
+// forwarding anything.  Each method that returns a count reports the
+// number of bytes that a real writer would have consumed.
+type discard struct{}
+
+// Write reports all of p as written.
+func (discard) Write(p []byte) (int, error) { return len(p), nil }
+
+// WriteByte always succeeds.
+func (discard) WriteByte(b byte) error { return nil }
+
+// WriteRune reports the length of the UTF-8 encoding of r, so that
+// its count agrees with what Write and WriteString would report for
+// the same data.
+func (discard) WriteRune(r rune) (int, error) {
+	n := utf8.RuneLen(r)
+	if n < 0 {
+		// r is not encodable; utf8.EncodeRune would emit
+		// utf8.RuneError in its place, so report that length.
+		n = utf8.RuneLen(utf8.RuneError)
+	}
+	return n, nil
+}
+
+// WriteString reports all of s as written.
+func (discard) WriteString(s string) (int, error) { return len(s), nil }
diff --git a/internal/base64.go b/internal/base64.go
index 15adbf4..291a229 100644
--- a/internal/base64.go
+++ b/internal/base64.go
@@ -19,7 +19,10 @@ type base64Decoder struct {
bufLen int
}
-func NewBase64Decoder(w io.Writer) io.WriteCloser {
+func NewBase64Decoder(w io.Writer) interface {
+ io.WriteCloser
+ RuneWriter
+} {
return &base64Decoder{
dst: w,
}
@@ -112,6 +115,10 @@ func (dec *base64Decoder) Write(dat []byte) (int, error) {
return len(dat), nil
}
+// WriteRune implements RuneWriter.  It passes dec and r to the
+// package-level WriteRune function, so the rune reaches the decoder
+// through dec.Write like any other input.
+func (dec *base64Decoder) WriteRune(r rune) (int, error) {
+ return WriteRune(dec, r)
+}
+
func (dec *base64Decoder) Close() error {
if dec.bufLen == 0 {
return nil
diff --git a/ioutil.go b/ioutil.go
deleted file mode 100644
index a53eac3..0000000
--- a/ioutil.go
+++ /dev/null
@@ -1,31 +0,0 @@
-// Copyright (C) 2022-2023 Luke Shumaker <lukeshu@lukeshu.com>
-//
-// SPDX-License-Identifier: GPL-2.0-or-later
-
-package lowmemjson
-
-import (
- "io"
- "unicode/utf8"
-)
-
-func writeByte(w io.Writer, c byte) error {
- if br, ok := w.(interface{ WriteByte(byte) error }); ok {
- return br.WriteByte(c)
- }
- var buf [1]byte
- buf[0] = c
- if _, err := w.Write(buf[:]); err != nil {
- return err
- }
- return nil
-}
-
-func writeRune(w io.Writer, c rune) (int, error) {
- if rw, ok := w.(interface{ WriteRune(rune) (int, error) }); ok {
- return rw.WriteRune(c)
- }
- var buf [utf8.UTFMax]byte
- n := utf8.EncodeRune(buf[:], c)
- return w.Write(buf[:n])
-}
diff --git a/reencode.go b/reencode.go
index 876af62..393e8c6 100644
--- a/reencode.go
+++ b/reencode.go
@@ -71,7 +71,7 @@ type ReEncoderConfig struct {
func NewReEncoder(out io.Writer, cfg ReEncoderConfig) *ReEncoder {
return &ReEncoder{
ReEncoderConfig: cfg,
- out: out,
+ out: internal.NewAllWriter(out),
}
}
@@ -85,9 +85,9 @@ func NewReEncoder(out io.Writer, cfg ReEncoderConfig) *ReEncoder {
// The memory use of a ReEncoder is O( (CompactIfUnder+1)^2 + depth).
type ReEncoder struct {
ReEncoderConfig
- out io.Writer
+ out internal.AllWriter
- // state: .Write's utf8-decoding buffer
+ // state: .Write's and .WriteString's utf8-decoding buffer
buf [utf8.UTFMax]byte
bufLen int
@@ -119,6 +119,11 @@ type speculation struct {
// public API //////////////////////////////////////////////////////////////////
+// Compile-time assertions that *ReEncoder implements
+// internal.AllWriter and io.Closer.
+var (
+ _ internal.AllWriter = (*ReEncoder)(nil)
+ _ io.Closer = (*ReEncoder)(nil)
+)
+
// Write implements io.Writer; it does what you'd expect.
//
// It is worth noting that Write returns the number of bytes consumed
@@ -152,6 +157,38 @@ func (enc *ReEncoder) Write(p []byte) (int, error) {
return len(p), nil
}
+// WriteString implements io.StringWriter; it does what you'd expect,
+// but see the notes on the Write method.
+//
+// The string is decoded as UTF-8 and forwarded to WriteRune one rune
+// at a time.  If p ends part-way through a multi-byte encoding, the
+// trailing bytes are kept in enc.buf and completed by later calls,
+// however many calls the encoding is spread over.  Bytes that do not
+// form valid UTF-8 are forwarded as utf8.RuneError, one per
+// offending byte.
+//
+// On success it returns len(p), even when the last bytes of p were
+// only buffered.  If WriteRune fails it returns that error and the
+// number of bytes of p consumed before the failing rune (0 if the
+// failing rune began in a previous call).
+//
+// NOTE(review): Write appears to share the carried-over-rune logic
+// this method was copied from; confirm whether it needs the same fix.
+func (enc *ReEncoder) WriteString(p string) (int, error) {
+	if len(p) == 0 {
+		return 0, nil
+	}
+	var n int
+	// First finish the partial rune (if any) left over from an
+	// earlier call.  This is a loop because an invalid sequence may
+	// consume fewer bytes than were carried over.
+	for enc.bufLen > 0 {
+		avail := enc.bufLen + copy(enc.buf[enc.bufLen:], p[n:])
+		// Only look at the bytes that are actually valid; anything
+		// past avail is stale data from an earlier rune.
+		pending := enc.buf[:avail]
+		if !utf8.FullRune(pending) {
+			// Still incomplete, which implies that all of p fit in
+			// the buffer; keep waiting for more input.
+			enc.bufLen = avail
+			return len(p), nil
+		}
+		c, size := utf8.DecodeRune(pending)
+		if size >= enc.bufLen {
+			// The rune used up the carried-over bytes plus
+			// size-bufLen bytes of p.
+			n += size - enc.bufLen
+			enc.bufLen = 0
+		} else {
+			// Invalid sequence: only some of the carried-over
+			// bytes were consumed.  Shift the rest down and go
+			// around again without consuming any of p.
+			enc.bufLen = copy(enc.buf[:], enc.buf[size:enc.bufLen])
+		}
+		if _, err := enc.WriteRune(c); err != nil {
+			return 0, err
+		}
+	}
+	for utf8.FullRuneInString(p[n:]) {
+		c, size := utf8.DecodeRuneInString(p[n:])
+		if _, err := enc.WriteRune(c); err != nil {
+			return n, err
+		}
+		n += size
+	}
+	// Stash any trailing partial rune for the next call.
+	enc.bufLen = copy(enc.buf[:], p[n:])
+	return len(p), nil
+}
+
+// WriteByte implements io.ByteWriter by handing b to
+// internal.WriteByte, which passes it to enc.Write as a one-byte
+// slice; see the notes on the Write method.
+func (enc *ReEncoder) WriteByte(b byte) error {
+ return internal.WriteByte(enc, b)
+}
+
// Close implements io.Closer; it does what you'd expect, mostly.
//
// The *ReEncoder may continue to be written to with new JSON values
@@ -471,7 +508,7 @@ func (enc *ReEncoder) handleRuneMain(c rune, t internal.RuneType) error {
}
func (enc *ReEncoder) emitByte(c byte) error {
- err := writeByte(enc.out, c)
+ err := enc.out.WriteByte(c)
if err == nil {
enc.written++
}
@@ -488,12 +525,12 @@ func (enc *ReEncoder) emitNlIndent() error {
return err
}
if enc.Prefix != "" {
- if err := enc.emit(io.WriteString(enc.out, enc.Prefix)); err != nil {
+ if err := enc.emit(enc.out.WriteString(enc.Prefix)); err != nil {
return err
}
}
for i := 0; i < enc.handleRuneState.curIndent; i++ {
- if err := enc.emit(io.WriteString(enc.out, enc.Indent)); err != nil {
+ if err := enc.emit(enc.out.WriteString(enc.Indent)); err != nil {
return err
}
}