summaryrefslogtreecommitdiff
path: root/compat
diff options
context:
space:
mode:
authorLuke Shumaker <lukeshu@lukeshu.com>2023-02-20 12:47:10 -0700
committerLuke Shumaker <lukeshu@lukeshu.com>2023-02-20 12:47:10 -0700
commitf5ca3478c68e47ae20fd12748c1552fdf81f75f9 (patch)
treeb3d3f889ed25084fe33ed9e01554d6ca51104bb5 /compat
parentd240d0b06c7b5711f583d961eddfc37d07d4546e (diff)
parent49ee8be679add0bd3cf08a2669331b3be7a835f8 (diff)
Merge branch 'lukeshu/fixes'
Diffstat (limited to 'compat')
-rw-r--r--compat/json/compat.go150
-rw-r--r--compat/json/compat_test.go241
-rw-r--r--compat/json/testcompat_test.go14
3 files changed, 379 insertions, 26 deletions
diff --git a/compat/json/compat.go b/compat/json/compat.go
index c96470d..695c1a8 100644
--- a/compat/json/compat.go
+++ b/compat/json/compat.go
@@ -11,10 +11,13 @@ import (
"bytes"
"encoding/json"
"errors"
+ "fmt"
"io"
"strconv"
+ "unicode/utf8"
"git.lukeshu.com/go/lowmemjson"
+ "git.lukeshu.com/go/lowmemjson/internal/jsonstring"
)
//nolint:stylecheck // ST1021 False positive; these aren't comments on individual types.
@@ -144,7 +147,23 @@ func convertReEncodeError(err error) error {
}
func HTMLEscape(dst *bytes.Buffer, src []byte) {
- _, _ = lowmemjson.NewReEncoder(dst, lowmemjson.ReEncoderConfig{}).Write(src)
+ for n := 0; n < len(src); {
+ c, size := utf8.DecodeRune(src[n:])
+ if c == utf8.RuneError && size == 1 {
+ dst.WriteByte(src[n])
+ } else {
+ mode := lowmemjson.EscapeHTMLSafe(c, lowmemjson.BackslashEscapeNone)
+ switch mode {
+ case lowmemjson.BackslashEscapeNone:
+ dst.WriteRune(c)
+ case lowmemjson.BackslashEscapeUnicode:
+ _ = jsonstring.WriteStringUnicodeEscape(dst, c, mode)
+ default:
+ panic(fmt.Errorf("lowmemjson.EscapeHTMLSafe returned an unexpected escape mode=%d", mode))
+ }
+ }
+ n += size
+ }
}
func reencode(dst io.Writer, src []byte, cfg lowmemjson.ReEncoderConfig) error {
@@ -157,38 +176,75 @@ func reencode(dst io.Writer, src []byte, cfg lowmemjson.ReEncoderConfig) error {
}
func Compact(dst *bytes.Buffer, src []byte) error {
- return reencode(dst, src, lowmemjson.ReEncoderConfig{
+ start := dst.Len()
+ err := reencode(dst, src, lowmemjson.ReEncoderConfig{
Compact: true,
+ InvalidUTF8: lowmemjson.InvalidUTF8Preserve,
BackslashEscape: lowmemjson.EscapePreserve,
})
+ if err != nil {
+ dst.Truncate(start)
+ }
+ return err
+}
+
+func isSpace(c byte) bool {
+ switch c {
+ case 0x0020, 0x000A, 0x000D, 0x0009:
+ return true
+ default:
+ return false
+ }
}
func Indent(dst *bytes.Buffer, src []byte, prefix, indent string) error {
- return reencode(dst, src, lowmemjson.ReEncoderConfig{
+ start := dst.Len()
+ err := reencode(dst, src, lowmemjson.ReEncoderConfig{
Indent: indent,
Prefix: prefix,
+ InvalidUTF8: lowmemjson.InvalidUTF8Preserve,
BackslashEscape: lowmemjson.EscapePreserve,
})
+ if err != nil {
+ dst.Truncate(start)
+ return err
+ }
+
+ // Preserve trailing whitespace.
+ lastNonWS := len(src) - 1
+ for ; lastNonWS >= 0 && isSpace(src[lastNonWS]); lastNonWS-- {
+ }
+ if _, err := dst.Write(src[lastNonWS+1:]); err != nil {
+ return err
+ }
+
+ return nil
}
func Valid(data []byte) bool {
formatter := lowmemjson.NewReEncoder(io.Discard, lowmemjson.ReEncoderConfig{
- Compact: true,
+ Compact: true,
+ InvalidUTF8: lowmemjson.InvalidUTF8Error,
})
- _, err := formatter.Write(data)
- return err == nil
+ if _, err := formatter.Write(data); err != nil {
+ return false
+ }
+ if err := formatter.Close(); err != nil {
+ return false
+ }
+ return true
}
// Decode wrappers ///////////////////////////////////////////////////
-func convertDecodeError(err error) error {
+func convertDecodeError(err error, isUnmarshal bool) error {
if derr, ok := err.(*lowmemjson.DecodeError); ok {
switch terr := derr.Err.(type) {
case *lowmemjson.DecodeSyntaxError:
switch {
case errors.Is(terr.Err, io.EOF):
err = io.EOF
- case errors.Is(terr.Err, io.ErrUnexpectedEOF):
+ case errors.Is(terr.Err, io.ErrUnexpectedEOF) && isUnmarshal:
err = &SyntaxError{
msg: "unexpected end of JSON input",
Offset: terr.Offset,
@@ -228,13 +284,66 @@ func convertDecodeError(err error) error {
return err
}
+type decodeValidator struct{}
+
+func (*decodeValidator) DecodeJSON(r io.RuneScanner) error {
+ for {
+ if _, _, err := r.ReadRune(); err != nil {
+
+ if err == io.EOF {
+ return nil
+ }
+ return err
+ }
+ }
+}
+
+var _ lowmemjson.Decodable = (*decodeValidator)(nil)
+
func Unmarshal(data []byte, ptr any) error {
- return convertDecodeError(lowmemjson.NewDecoder(bytes.NewReader(data)).DecodeThenEOF(ptr))
+ if err := convertDecodeError(lowmemjson.NewDecoder(bytes.NewReader(data)).DecodeThenEOF(&decodeValidator{}), true); err != nil {
+ return err
+ }
+ if err := convertDecodeError(lowmemjson.NewDecoder(bytes.NewReader(data)).DecodeThenEOF(ptr), true); err != nil {
+ return err
+ }
+ return nil
+}
+
+type teeRuneScanner struct {
+ src io.RuneScanner
+ dst *bytes.Buffer
+ lastSize int
+}
+
+func (tee *teeRuneScanner) ReadRune() (r rune, size int, err error) {
+ r, size, err = tee.src.ReadRune()
+ if err == nil {
+ if _, err := tee.dst.WriteRune(r); err != nil {
+ return 0, 0, err
+ }
+ }
+
+ tee.lastSize = size
+ return
+}
+
+func (tee *teeRuneScanner) UnreadRune() error {
+ if tee.lastSize == 0 {
+ return lowmemjson.ErrInvalidUnreadRune
+ }
+ _ = tee.src.UnreadRune()
+ tee.dst.Truncate(tee.dst.Len() - tee.lastSize)
+ tee.lastSize = 0
+ return nil
}
type Decoder struct {
+ validatorBuf *bufio.Reader
+ validator *lowmemjson.Decoder
+
+ decoderBuf bytes.Buffer
*lowmemjson.Decoder
- buf *bufio.Reader
}
func NewDecoder(r io.Reader) *Decoder {
@@ -242,18 +351,29 @@ func NewDecoder(r io.Reader) *Decoder {
if !ok {
br = bufio.NewReader(r)
}
- return &Decoder{
- Decoder: lowmemjson.NewDecoder(br),
- buf: br,
+ ret := &Decoder{
+ validatorBuf: br,
}
+ ret.validator = lowmemjson.NewDecoder(&teeRuneScanner{
+ src: ret.validatorBuf,
+ dst: &ret.decoderBuf,
+ })
+ ret.Decoder = lowmemjson.NewDecoder(&ret.decoderBuf)
+ return ret
}
func (dec *Decoder) Decode(ptr any) error {
- return convertDecodeError(dec.Decoder.Decode(ptr))
+ if err := convertDecodeError(dec.validator.Decode(&decodeValidator{}), false); err != nil {
+ return err
+ }
+ if err := convertDecodeError(dec.Decoder.Decode(ptr), false); err != nil {
+ return err
+ }
+ return nil
}
func (dec *Decoder) Buffered() io.Reader {
- dat, _ := dec.buf.Peek(dec.buf.Buffered())
+ dat, _ := dec.validatorBuf.Peek(dec.validatorBuf.Buffered())
return bytes.NewReader(dat)
}
diff --git a/compat/json/compat_test.go b/compat/json/compat_test.go
new file mode 100644
index 0000000..df9d387
--- /dev/null
+++ b/compat/json/compat_test.go
@@ -0,0 +1,241 @@
+// Copyright (C) 2023 Luke Shumaker <lukeshu@lukeshu.com>
+//
+// SPDX-License-Identifier: GPL-2.0-or-later
+
+package json
+
+import (
+ "bytes"
+ "reflect"
+ "strings"
+ "testing"
+
+ "github.com/stretchr/testify/assert"
+)
+
+func TestCompatHTMLEscape(t *testing.T) {
+ t.Parallel()
+ type testcase struct {
+ In string
+ Out string
+ }
+ testcases := map[string]testcase{
+ "invalid": {In: `x`, Out: `x`},
+ "hex-lower": {In: `"\uabcd"`, Out: `"\uabcd"`},
+ "hex-upper": {In: `"\uABCD"`, Out: `"\uABCD"`},
+ "hex-mixed": {In: `"\uAbCd"`, Out: `"\uAbCd"`},
+ }
+ for tcName, tc := range testcases {
+ tc := tc
+ t.Run(tcName, func(t *testing.T) {
+ t.Parallel()
+ t.Logf("in=%q", tc.In)
+ var dst bytes.Buffer
+ HTMLEscape(&dst, []byte(tc.In))
+ assert.Equal(t, tc.Out, dst.String())
+ })
+ }
+}
+
+func TestCompatValid(t *testing.T) {
+ t.Parallel()
+ type testcase struct {
+ In string
+ Exp bool
+ }
+ testcases := map[string]testcase{
+ "empty": {In: ``, Exp: false},
+ "num": {In: `1`, Exp: true},
+ "trunc": {In: `{`, Exp: false},
+ "object": {In: `{}`, Exp: true},
+ "non-utf8": {In: "\"\x85\xcd\"", Exp: false}, // https://github.com/golang/go/issues/58517
+ "hex-lower": {In: `"\uabcd"`, Exp: true},
+ "hex-upper": {In: `"\uABCD"`, Exp: true},
+ "hex-mixed": {In: `"\uAbCd"`, Exp: true},
+ }
+ for tcName, tc := range testcases {
+ tc := tc
+ t.Run(tcName, func(t *testing.T) {
+ t.Parallel()
+ t.Logf("in=%q", tc.In)
+ act := Valid([]byte(tc.In))
+ assert.Equal(t, tc.Exp, act)
+ })
+ }
+}
+
+func TestCompatCompact(t *testing.T) {
+ t.Parallel()
+ type testcase struct {
+ In string
+ Out string
+ Err string
+ }
+ testcases := map[string]testcase{
+ "trunc": {In: `{`, Out: ``, Err: `unexpected end of JSON input`},
+ "object": {In: `{}`, Out: `{}`},
+ "non-utf8": {In: "\"\x85\xcd\"", Out: "\"\x85\xcd\""},
+ "float": {In: `1.200e003`, Out: `1.200e003`},
+ "hex-lower": {In: `"\uabcd"`, Out: `"\uabcd"`},
+ "hex-upper": {In: `"\uABCD"`, Out: `"\uABCD"`},
+ "hex-mixed": {In: `"\uAbCd"`, Out: `"\uAbCd"`},
+ }
+ for tcName, tc := range testcases {
+ tc := tc
+ t.Run(tcName, func(t *testing.T) {
+ t.Parallel()
+ t.Logf("in=%q", tc.In)
+ var out bytes.Buffer
+ err := Compact(&out, []byte(tc.In))
+ assert.Equal(t, tc.Out, out.String())
+ if tc.Err == "" {
+ assert.NoError(t, err)
+ } else {
+ assert.EqualError(t, err, tc.Err)
+ }
+ })
+ }
+}
+
+func TestCompatIndent(t *testing.T) {
+ t.Parallel()
+ type testcase struct {
+ In string
+ Out string
+ Err string
+ }
+ testcases := map[string]testcase{
+ "trunc": {In: `{`, Out: ``, Err: `unexpected end of JSON input`},
+ "object": {In: `{}`, Out: `{}`},
+ "non-utf8": {In: "\"\x85\xcd\"", Out: "\"\x85\xcd\""},
+ "float": {In: `1.200e003`, Out: `1.200e003`},
+ "tailws0": {In: `0`, Out: `0`},
+ "tailws1": {In: `0 `, Out: `0 `},
+ "tailws2": {In: `0 `, Out: `0 `},
+ "tailws3": {In: "0\n", Out: "0\n"},
+ "headws1": {In: ` 0`, Out: `0`},
+ "objws1": {In: `{"a" : 1}`, Out: "{\n>.\"a\": 1\n>}"},
+ "objws2": {In: "{\"a\"\n:\n1}", Out: "{\n>.\"a\": 1\n>}"},
+ "hex-lower": {In: `"\uabcd"`, Out: `"\uabcd"`},
+ "hex-upper": {In: `"\uABCD"`, Out: `"\uABCD"`},
+ "hex-mixed": {In: `"\uAbCd"`, Out: `"\uAbCd"`},
+ }
+ for tcName, tc := range testcases {
+ tc := tc
+ t.Run(tcName, func(t *testing.T) {
+ t.Parallel()
+ t.Logf("in=%q", tc.In)
+ var out bytes.Buffer
+ err := Indent(&out, []byte(tc.In), ">", ".")
+ assert.Equal(t, tc.Out, out.String())
+ if tc.Err == "" {
+ assert.NoError(t, err)
+ } else {
+ assert.EqualError(t, err, tc.Err)
+ }
+ })
+ }
+}
+
+func TestCompatMarshal(t *testing.T) {
+ t.Parallel()
+ type testcase struct {
+ In any
+ Out string
+ Err string
+ }
+ testcases := map[string]testcase{
+ "non-utf8": {In: "\x85\xcd", Out: "\"\\ufffd\\ufffd\""},
+ "urc": {In: "\ufffd", Out: "\"\ufffd\""},
+ "float": {In: 1.2e3, Out: `1200`},
+ }
+ for tcName, tc := range testcases {
+ tc := tc
+ t.Run(tcName, func(t *testing.T) {
+ t.Parallel()
+ out, err := Marshal(tc.In)
+ assert.Equal(t, tc.Out, string(out))
+ if tc.Err == "" {
+ assert.NoError(t, err)
+ } else {
+ assert.EqualError(t, err, tc.Err)
+ }
+ })
+ }
+}
+
+func TestCompatUnmarshal(t *testing.T) {
+ t.Parallel()
+ type testcase struct {
+ In string
+ InPtr any
+ ExpOut any
+ ExpErr string
+ }
+ testcases := map[string]testcase{
+ "empty-obj": {In: `{}`, ExpOut: map[string]any{}},
+ "partial-obj": {In: `{"foo":"bar",`, ExpOut: nil, ExpErr: `unexpected end of JSON input`},
+ "existing-obj": {In: `{"baz":"quz"}`, InPtr: &map[string]string{"foo": "bar"}, ExpOut: map[string]string{"foo": "bar", "baz": "quz"}},
+ "existing-obj-partial": {In: `{"baz":"quz"`, InPtr: &map[string]string{"foo": "bar"}, ExpOut: map[string]string{"foo": "bar"}, ExpErr: "unexpected end of JSON input"},
+ "empty-ary": {In: `[]`, ExpOut: []any{}},
+ "two-objs": {In: `{} {}`, ExpOut: nil, ExpErr: `invalid character '{' after top-level value`},
+ "two-numbers1": {In: `00`, ExpOut: nil, ExpErr: `invalid character '0' after top-level value`},
+ "two-numbers2": {In: `1 2`, ExpOut: nil, ExpErr: `invalid character '2' after top-level value`},
+ }
+ for tcName, tc := range testcases {
+ tc := tc
+ t.Run(tcName, func(t *testing.T) {
+ t.Parallel()
+ ptr := tc.InPtr
+ if ptr == nil {
+ var out any
+ ptr = &out
+ }
+ err := Unmarshal([]byte(tc.In), ptr)
+ assert.Equal(t, tc.ExpOut, reflect.ValueOf(ptr).Elem().Interface())
+ if tc.ExpErr == "" {
+ assert.NoError(t, err)
+ } else {
+ assert.EqualError(t, err, tc.ExpErr)
+ }
+ })
+ }
+}
+
+func TestCompatDecode(t *testing.T) {
+ t.Parallel()
+ type testcase struct {
+ In string
+ InPtr any
+ ExpOut any
+ ExpErr string
+ }
+ testcases := map[string]testcase{
+ "empty-obj": {In: `{}`, ExpOut: map[string]any{}},
+ "partial-obj": {In: `{"foo":"bar",`, ExpOut: nil, ExpErr: `unexpected EOF`},
+ "existing-obj": {In: `{"baz":"quz"}`, InPtr: &map[string]string{"foo": "bar"}, ExpOut: map[string]string{"foo": "bar", "baz": "quz"}},
+ "existing-obj-partial": {In: `{"baz":"quz"`, InPtr: &map[string]string{"foo": "bar"}, ExpOut: map[string]string{"foo": "bar"}, ExpErr: "unexpected EOF"},
+ "empty-ary": {In: `[]`, ExpOut: []any{}},
+ "two-objs": {In: `{} {}`, ExpOut: map[string]any{}},
+ "two-numbers1": {In: `00`, ExpOut: float64(0)},
+ "two-numbers2": {In: `1 2`, ExpOut: float64(1)},
+ }
+ for tcName, tc := range testcases {
+ tc := tc
+ t.Run(tcName, func(t *testing.T) {
+ t.Parallel()
+ ptr := tc.InPtr
+ if ptr == nil {
+ var out any
+ ptr = &out
+ }
+ err := NewDecoder(strings.NewReader(tc.In)).Decode(ptr)
+ assert.Equal(t, tc.ExpOut, reflect.ValueOf(ptr).Elem().Interface())
+ if tc.ExpErr == "" {
+ assert.NoError(t, err)
+ } else {
+ assert.EqualError(t, err, tc.ExpErr)
+ }
+ })
+ }
+}
diff --git a/compat/json/testcompat_test.go b/compat/json/testcompat_test.go
index 42cbf5c..73153d9 100644
--- a/compat/json/testcompat_test.go
+++ b/compat/json/testcompat_test.go
@@ -8,6 +8,7 @@ import (
"bytes"
"encoding/json"
"io"
+ "reflect"
_ "unsafe"
"git.lukeshu.com/go/lowmemjson"
@@ -45,27 +46,18 @@ const (
startDetectingCyclesAfter = 1000
)
-func isSpace(c byte) bool {
- switch c {
- case 0x0020, 0x000A, 0x000D, 0x0009:
- return true
- default:
- return false
- }
-}
-
type encodeState struct {
bytes.Buffer
}
func (es *encodeState) string(str string, _ bool) {
- if err := jsonstring.EncodeStringFromString(&es.Buffer, lowmemjson.EscapeDefault, str); err != nil {
+ if err := jsonstring.EncodeStringFromString(&es.Buffer, lowmemjson.EscapeDefault, 0, reflect.Value{}, str); err != nil {
panic(err)
}
}
func (es *encodeState) stringBytes(str []byte, _ bool) {
- if err := jsonstring.EncodeStringFromBytes(&es.Buffer, lowmemjson.EscapeDefault, str); err != nil {
+ if err := jsonstring.EncodeStringFromBytes(&es.Buffer, lowmemjson.EscapeDefault, 0, reflect.Value{}, str); err != nil {
panic(err)
}
}