// Copyright (C) 2022-2023  Luke Shumaker <lukeshu@lukeshu.com>
//
// SPDX-License-Identifier: GPL-2.0-or-later

package lowmemjson

import (
	"bytes"
	"fmt"
	"io"
	"unicode/utf8"

	"git.lukeshu.com/go/lowmemjson/internal"
)

// A ReEncoderConfig controls how a ReEncoder should behave.
type ReEncoderConfig struct {
	// A JSON document is specified to be a single JSON element;
	// but it is often desirable to handle streams of multiple
	// JSON elements.  If AllowMultipleValues is set, the
	// ReEncoder accepts a stream of top-level values instead of
	// rejecting anything that follows the first one.
	AllowMultipleValues bool

	// Whether to minify the JSON.
	//
	// Trims all whitespace, except that it emits a newline
	// between two *number* top-level values (or puts a newline
	// after all top-level values if ForceTrailingNewlines).
	//
	// Trims superfluous 0s from numbers.
	Compact bool

	// CompactIfUnder causes the *ReEncoder to behave as if
	// Compact=true for individual elements if doing so would
	// cause that element to be under this number of bytes.
	//
	// Has no effect if it is 0, if Compact is true, or if Indent
	// is empty.
	//
	// This has O(2^min(CompactIfUnder, depth)) time overhead, so
	// set with caution.
	CompactIfUnder int

	// String to use to indent; ignored if Compact is true.
	//
	// Newlines are emitted *between* top-level values; a newline is
	// not emitted after the *last* top-level value (unless
	// ForceTrailingNewlines is on).
	Indent string

	// String to put before indents: it is written at the start of
	// each indented line, after the newline and before the
	// repetitions of Indent.
	Prefix string

	// Whether to emit a newline after each top-level value.  See
	// the comments on Compact and Indent for discussion of how
	// this is different than the usual behavior.
	ForceTrailingNewlines bool

	// Returns whether a given character in a string should be
	// backslash-escaped.  The bool argument is whether it was
	// \u-escaped in the input.  This does not affect characters
	// that must or must-not be escaped to be valid JSON.
	//
	// NOTE(review): the ReEncoder hands string characters to
	// writeStringChar together with a BackslashEscapeNone /
	// BackslashEscapeShort / BackslashEscapeUnicode mode rather
	// than a bool, so the sentence above may be stale -- confirm
	// against the BackslashEscaper type.
	//
	// If not set, then EscapeDefault is used.
	BackslashEscape BackslashEscaper
}

// NewReEncoder returns a new ReEncoder that re-encodes the JSON
// written to it according to cfg, and writes the result to out.
//
// A ReEncoder tends to make many small writes; if out.Write
// calls are syscalls, then you may want to wrap out in a
// bufio.Writer.
func NewReEncoder(out io.Writer, cfg ReEncoderConfig) *ReEncoder {
	enc := &ReEncoder{ReEncoderConfig: cfg}
	enc.out = internal.NewAllWriter(out)
	enc.specu = &speculation{}
	return enc
}

// A ReEncoder takes a stream of JSON elements (by way of implementing
// io.Writer and WriteRune), and re-encodes the JSON, writing it to
// the io.Writer that was passed to NewReEncoder.
//
// This is useful for prettifying, minifying, sanitizing, and/or
// validating JSON.
//
// The memory use of a ReEncoder is O(CompactIfUnder+depth).
type ReEncoder struct {
	ReEncoderConfig
	out internal.AllWriter

	// state: .Write's and .WriteString's utf8-decoding buffer;
	// holds the leading bytes of a UTF-8 sequence that was split
	// across calls.
	buf    [utf8.UTFMax]byte
	bufLen int

	// state: .WriteRune
	//
	// err is sticky: once it is set, WriteRune returns it without
	// doing any more work.  written counts the bytes emitted to out
	// during the current WriteRune call.  inputPos is the input
	// offset reported in errors; it is relative to the innermost
	// write barrier (pushWriteBarrier resets it to 0).
	err      error
	par      internal.Parser
	written  int
	inputPos int64

	// state: .handleRune
	//
	// lastNonSpace is the RuneType of the most recently handled
	// non-whitespace rune (it may be RuneTypeEOF or RuneTypeError
	// once a top-level value has ended); lastNonSpaceNonEOF is the
	// same but never records RuneTypeEOF, and popWriteBarrier
	// copies it back into lastNonSpace.  wasNumber records whether
	// the previous top-level value was a number.  fracZeros counts
	// '0's held back from a number's fraction-part; expZero is true
	// while no non-zero exponent digit has been seen yet.
	lastNonSpace       internal.RuneType
	lastNonSpaceNonEOF internal.RuneType
	wasNumber          bool
	curIndent          int
	uhex               [4]byte // "\uABCD"-encoded characters in strings
	fracZeros          int64
	expZero            bool
	specu              *speculation

	// state: .pushWriteBarrier and .popWriteBarrier
	barriers []barrier
}

// A barrier is the state saved by pushWriteBarrier and consumed by
// popWriteBarrier.  inputPos is the input offset at the time of the
// push (inside the barrier inputPos restarts at 0, and this value is
// added back on pop).  stackSize is enc.stackSize() at the time of
// the push.
type barrier struct {
	inputPos  int64
	stackSize int
}

// A speculation is the state handleRune uses to decide whether an
// object or array should be compacted (for CompactIfUnder).
//
// While speculating, every rune is recorded in buf and is also fed to
// fmt, a Compact-mode ReEncoder that writes into compact.  If compact
// grows to CompactIfUnder bytes or more, the recorded runes are
// replayed through the indenting path; if instead the element ends
// first (the stack size drops to endWhenStackSize), compact is copied
// to the real output.
type speculation struct {
	speculating      bool
	endWhenStackSize int
	fmt              ReEncoder
	compact          bytes.Buffer
	buf              []inputTuple
}

// Reset puts specu back into the idle (not speculating) state.  The
// compact buffer and the recorded-rune list are emptied but keep
// their allocated storage for the next speculation.
func (specu *speculation) Reset() {
	specu.buf = specu.buf[:0]
	specu.compact.Reset()
	specu.fmt = ReEncoder{}
	specu.speculating, specu.endWhenStackSize = false, 0
}

// An inputTuple is one rune recorded during speculation: the rune,
// its RuneType, and the stack size it was handled at, so that it can
// later be replayed through handleRune.
type inputTuple struct {
	c         rune
	t         internal.RuneType
	stackSize int
}

// public API //////////////////////////////////////////////////////////////////

var (
	_ internal.AllWriter = (*ReEncoder)(nil)
	_ io.Closer          = (*ReEncoder)(nil)
)

// Write implements io.Writer; it does what you'd expect.
//
// It is worth noting that Write returns the number of bytes consumed
// from p, not number of bytes written to the output stream.  This is
// a distinction that most io.Writer implementations don't need to
// make, but *ReEncoder does because it transforms the data written to
// it, and the number of bytes written may be wildly different than
// the number of bytes handled.
//
// A UTF-8 sequence may be split across calls to Write/WriteString:
// trailing bytes of p that form an incomplete (but so far valid)
// sequence are buffered and combined with the start of the next call.
// Invalid bytes are handled as utf8.RuneError.
func (enc *ReEncoder) Write(p []byte) (int, error) {
	if len(p) == 0 {
		return 0, nil
	}
	if enc.err != nil {
		return 0, enc.err
	}
	var n int
	if enc.bufLen > 0 {
		// Finish the partial rune left over from a previous call.
		// The leftover bytes and the head of p are assembled in
		// scratch space (not in enc.buf) so that stale bytes from an
		// older partial rune can never be decoded as input.  This is
		// a loop because if p does not continue the sequence, the
		// leftover bytes decode one at a time as RuneErrors.
		var scratch [2 * utf8.UTFMax]byte
		buffered := copy(scratch[:], enc.buf[:enc.bufLen])
		have := buffered + copy(scratch[buffered:], p)
		enc.bufLen = 0 // WriteRune refuses to run while bytes are buffered
		pos := 0
		for pos < buffered {
			if !utf8.FullRune(scratch[pos:have]) {
				// Still incomplete.  scratch has room for more
				// than a full rune past the leftovers, so this
				// means all of p fit in scratch and is consumed.
				enc.bufLen = copy(enc.buf[:], scratch[pos:have])
				return len(p), nil
			}
			c, size := utf8.DecodeRune(scratch[pos:have])
			if _, err := enc.WriteRune(c); err != nil {
				// The failing rune started in the leftovers, so
				// nothing from p has been consumed yet.
				return 0, err
			}
			pos += size
		}
		n = pos - buffered
	}
	for utf8.FullRune(p[n:]) {
		c, size := utf8.DecodeRune(p[n:])
		if _, err := enc.WriteRune(c); err != nil {
			return n, err
		}
		n += size
	}
	enc.bufLen = copy(enc.buf[:], p[n:])
	return len(p), nil
}

// WriteString implements io.StringWriter; it does what you'd expect,
// but see the notes on the Write method.
//
// As with Write, a UTF-8 sequence may be split across calls; an
// incomplete trailing sequence is buffered until the next call.
func (enc *ReEncoder) WriteString(p string) (int, error) {
	if len(p) == 0 {
		return 0, nil
	}
	if enc.err != nil {
		return 0, enc.err
	}
	var n int
	if enc.bufLen > 0 {
		// Finish the partial rune left over from a previous call.
		// The leftover bytes and the head of p are assembled in
		// scratch space (not in enc.buf) so that stale bytes from an
		// older partial rune can never be decoded as input.  This is
		// a loop because if p does not continue the sequence, the
		// leftover bytes decode one at a time as RuneErrors.
		var scratch [2 * utf8.UTFMax]byte
		buffered := copy(scratch[:], enc.buf[:enc.bufLen])
		have := buffered + copy(scratch[buffered:], p)
		enc.bufLen = 0 // WriteRune refuses to run while bytes are buffered
		pos := 0
		for pos < buffered {
			if !utf8.FullRune(scratch[pos:have]) {
				// Still incomplete.  scratch has room for more
				// than a full rune past the leftovers, so this
				// means all of p fit in scratch and is consumed.
				enc.bufLen = copy(enc.buf[:], scratch[pos:have])
				return len(p), nil
			}
			c, size := utf8.DecodeRune(scratch[pos:have])
			if _, err := enc.WriteRune(c); err != nil {
				// The failing rune started in the leftovers, so
				// nothing from p has been consumed yet.
				return 0, err
			}
			pos += size
		}
		n = pos - buffered
	}
	for utf8.FullRuneInString(p[n:]) {
		c, size := utf8.DecodeRuneInString(p[n:])
		if _, err := enc.WriteRune(c); err != nil {
			return n, err
		}
		n += size
	}
	enc.bufLen = copy(enc.buf[:], p[n:])
	return len(p), nil
}

// WriteByte implements io.ByteWriter; it does what you'd expect.  The
// work is delegated to internal.WriteByte, which is given enc itself.
func (enc *ReEncoder) WriteByte(b byte) error {
	err := internal.WriteByte(enc, b)
	return err
}

// Close implements io.Closer; it does what you'd expect, mostly.
//
// It reports a *ReEncodeSyntaxError if the input ended in the middle
// of a UTF-8 sequence or in the middle of a JSON value.  When no write
// barrier is active, it then finishes off the top-level value (for
// example, emitting the trailing newline for ForceTrailingNewlines).
//
// Errors are sticky: if an earlier write or Close failed, Close
// returns that same error.
//
// The *ReEncoder may continue to be written to with new JSON values
// if enc.AllowMultipleValues is set.
func (enc *ReEncoder) Close() error {
	if enc.err != nil {
		// The parser state is not trustworthy after a failure;
		// report the original error instead of re-driving it.
		return enc.err
	}
	if enc.bufLen > 0 {
		enc.err = &ReEncodeSyntaxError{
			Offset: enc.inputPos,
			Err:    fmt.Errorf("%w: unflushed unicode garbage: %q", io.ErrUnexpectedEOF, enc.buf[:enc.bufLen]),
		}
		return enc.err
	}
	if _, err := enc.par.HandleEOF(); err != nil {
		enc.err = &ReEncodeSyntaxError{
			Err:    err,
			Offset: enc.inputPos,
		}
		return enc.err
	}
	if len(enc.barriers) == 0 {
		// RuneTypeError here stands for "EOF explicitly stated by
		// Close" (see handleRuneMain).
		if err := enc.handleRune(0, internal.RuneTypeError, enc.stackSize()); err != nil {
			enc.err = &ReEncodeSyntaxError{
				Err:    err,
				Offset: enc.inputPos,
			}
			return enc.err
		}
		if enc.AllowMultipleValues {
			enc.par.Reset()
		}
	}
	return nil
}

// WriteRune writes a single Unicode code point, returning the number
// of bytes written to the output stream and any error.
//
// Even when there is no error, the number of bytes written may be
// zero (for example, when the rune is whitespace and the ReEncoder is
// minifying the JSON), or it may be substantially longer than one
// code point's worth (for example, when `\uXXXX` escaping a character
// in a string, or when outputting extra whitespace when the ReEncoder
// is prettifying the JSON).
//
// Once WriteRune has returned an error, every later call returns
// that same error.  Calling WriteRune while Write/WriteString is
// holding a partial UTF-8 sequence is an error.
func (enc *ReEncoder) WriteRune(c rune) (n int, err error) {
	if enc.err != nil {
		return 0, enc.err
	}
	if enc.bufLen > 0 {
		enc.err = fmt.Errorf("lowmemjson.ReEncoder: cannot .WriteRune() when there is a partial rune that has been .Write()en: %q", enc.buf[:enc.bufLen])
		return 0, enc.err
	}

	enc.written = 0
	for {
		t, perr := enc.par.HandleRune(c)
		if perr != nil {
			enc.err = &ReEncodeSyntaxError{
				Err:    perr,
				Offset: enc.inputPos,
			}
			return enc.written, enc.err
		}
		enc.err = enc.handleRune(c, t, enc.stackSize())
		if enc.err != nil || t != internal.RuneTypeEOF {
			break
		}
		// c did not belong to the value that just ended; it is only
		// acceptable as the start of another top-level value.
		if !enc.AllowMultipleValues || len(enc.barriers) > 0 {
			enc.err = &ReEncodeSyntaxError{
				Err:    fmt.Errorf("invalid character %q after top-level value", c),
				Offset: enc.inputPos,
			}
			return enc.written, enc.err
		}
		// Start a fresh value and feed c to it on the next pass.
		enc.par.Reset()
	}

	enc.inputPos += int64(utf8.RuneLen(c))
	return enc.written, enc.err
}

// semi-public API /////////////////////////////////////////////////////////////

// pushWriteBarrier begins a nested section of input.  It saves the
// current input offset and enc.stackSize(), asks the parser to push a
// write barrier, and restarts inputPos at 0 so that offsets inside the
// section are relative to its start.
func (enc *ReEncoder) pushWriteBarrier() {
	saved := barrier{
		inputPos:  enc.inputPos,
		stackSize: enc.stackSize(),
	}
	enc.barriers = append(enc.barriers, saved)
	enc.par.PushWriteBarrier()
	enc.inputPos = 0
}

// popWriteBarrier ends the section begun by the matching
// pushWriteBarrier.  It pops the parser's barrier, turns inputPos back
// into an offset within the enclosing section, drops the saved
// barrier, and copies lastNonSpaceNonEOF into lastNonSpace so that an
// EOF recorded at the end of the nested value does not leak into the
// enclosing section.
func (enc *ReEncoder) popWriteBarrier() {
	enc.par.PopBarrier()
	top := len(enc.barriers) - 1
	enc.inputPos += enc.barriers[top].inputPos
	enc.barriers = enc.barriers[:top]
	enc.lastNonSpace = enc.lastNonSpaceNonEOF
}

// internal ////////////////////////////////////////////////////////////////////

// stackSize returns the overall depth of the parse stack, including
// the frames that lie outside of any active write barriers.
//
// pushWriteBarrier records enc.stackSize() (the *cumulative* size,
// which already includes every barrier beneath it) in each barrier, so
// only the innermost barrier's figure is added here; summing all of
// them would count the outer frames once per extra barrier.
//
// NOTE(review): this assumes enc.par.StackSize() reports only the
// frames pushed since the parser's innermost barrier -- confirm
// against internal.Parser.
func (enc *ReEncoder) stackSize() int {
	sz := enc.par.StackSize()
	if n := len(enc.barriers); n > 0 {
		sz += enc.barriers[n-1].stackSize
	}
	return sz
}

// handleRune routes one parsed rune c, of type t, to the output.
// stackSize is the parse-stack depth it was handled at.
//
// When CompactIfUnder is in effect (non-zero, with Indent set and
// Compact unset), each object or array is first encoded in compact
// form into enc.specu.compact while its runes are recorded in
// enc.specu.buf.  If the compact form reaches CompactIfUnder bytes,
// the recorded runes are replayed through the indenting path, which
// speculates again on nested elements.  Otherwise, once the element
// closes, the compact form is copied to the output.
func (enc *ReEncoder) handleRune(c rune, t internal.RuneType, stackSize int) error {
	if enc.CompactIfUnder == 0 || enc.Compact || enc.Indent == "" {
		return enc.handleRuneNoSpeculation(c, t)
	}

	// main
	if !enc.specu.speculating { // not speculating
		switch t {
		case internal.RuneTypeObjectBeg, internal.RuneTypeArrayBeg: // start speculating
			// Run the pre-step on the real encoder now: it emits any
			// newline/indentation belonging before the element and
			// bumps curIndent.
			if err, _ := enc.handleRunePre(c, t); err != nil {
				return err
			}
			enc.specu.speculating = true
			enc.specu.endWhenStackSize = stackSize - 1
			enc.specu.fmt = ReEncoder{
				ReEncoderConfig: enc.ReEncoderConfig,
				out:             &enc.specu.compact,
			}
			enc.specu.fmt.Compact = true
			enc.specu.buf = append(enc.specu.buf, inputTuple{
				c:         c,
				t:         t,
				stackSize: stackSize,
			})
			if err := enc.specu.fmt.handleRuneMain(c, t); err != nil {
				return err
			}
		default:
			if err := enc.handleRuneNoSpeculation(c, t); err != nil {
				return err
			}
		}
	} else { // speculating
		enc.specu.buf = append(enc.specu.buf, inputTuple{
			c:         c,
			t:         t,
			stackSize: stackSize,
		})
		if err := enc.specu.fmt.handleRune(c, t, stackSize); err != nil {
			return err
		}
		switch {
		case enc.specu.compact.Len() >= enc.CompactIfUnder: // stop speculating; use indent
			// Copy the recording: Reset truncates specu.buf, and the
			// replay below appends to it again.
			buf := append([]inputTuple(nil), enc.specu.buf...)
			enc.specu.Reset()
			// The first rune's pre-step already ran when speculation began.
			if err := enc.handleRuneMain(buf[0].c, buf[0].t); err != nil {
				return err
			}
			for _, tuple := range buf[1:] {
				if err := enc.handleRune(tuple.c, tuple.t, tuple.stackSize); err != nil {
					return err
				}
			}
		case stackSize == enc.specu.endWhenStackSize: // stop speculating; use compact
			n, err := enc.specu.compact.WriteTo(enc.out)
			enc.written += int(n) // count these bytes in WriteRune's result
			if err != nil {
				return err
			}
			enc.specu.Reset()
			// Record the closing rune as the most recent token, the
			// same way handleRuneMain would have, so that a later
			// popWriteBarrier restores an up-to-date value.
			enc.lastNonSpace = t
			if t != internal.RuneTypeEOF {
				enc.lastNonSpaceNonEOF = t
			}
			// Undo the curIndent++ done by handleRunePre when
			// speculation began (the closing rune skipped it).
			enc.curIndent--
		}
	}

	return nil
}

// handleRuneNoSpeculation handles a rune directly on enc, bypassing
// the CompactIfUnder machinery.  It runs the pre-step, then runs the
// main step unless the pre-step consumed the rune.
func (enc *ReEncoder) handleRuneNoSpeculation(c rune, t internal.RuneType) error {
	err, shouldHandle := enc.handleRunePre(c, t)
	switch {
	case err != nil:
		return err
	case shouldHandle:
		return enc.handleRuneMain(c, t)
	default:
		return nil
	}
}

// handleRunePre handles buffered things that need to happen before
// the new rune itself is handled: newlines between top-level values,
// digits held back by number shortening, and whitespace/indentation.
//
// The error is non-nil if writing to the output failed.  The bool is
// false when the rune has been fully dealt with here (held back or
// dropped whitespace) and handleRuneMain must not be called for it.
func (enc *ReEncoder) handleRunePre(c rune, t internal.RuneType) (error, bool) {
	// emit newlines between top-level values
	if enc.lastNonSpace == internal.RuneTypeEOF {
		switch {
		case enc.wasNumber && t.IsNumber():
			// Adjacent numbers must be separated, even when compacting.
			if err := enc.emitByte('\n'); err != nil {
				return err, false
			}
		case enc.Indent != "" && !enc.Compact:
			if err := enc.emitByte('\n'); err != nil {
				return err, false
			}
		}
	}

	// shorten numbers: trim trailing '0's from the fraction-part,
	// but don't remove all digits.  A '0' that follows another
	// fraction digit is held back (counted in fracZeros).  The held
	// zeros are emitted only if a non-zero fraction digit follows
	// them, and are discarded once the fraction-part ends.
	switch t {
	case internal.RuneTypeNumberFracDig:
		if c == '0' && enc.lastNonSpace == internal.RuneTypeNumberFracDig {
			enc.fracZeros++
			return nil, false
		}
		for enc.fracZeros > 0 {
			if err := enc.emitByte('0'); err != nil {
				return err, false
			}
			enc.fracZeros--
		}
	default:
		// Not (or no longer) in a fraction-part: any held-back
		// zeros were trailing zeros, so drop them.
		enc.fracZeros = 0
	}
	switch t { // trim leading '0's from the exponent-part, but don't remove all digits
	case internal.RuneTypeNumberExpE, internal.RuneTypeNumberExpSign:
		enc.expZero = true
	case internal.RuneTypeNumberExpDig:
		if c == '0' && enc.expZero {
			return nil, false
		}
		enc.expZero = false
	default:
		if enc.expZero {
			if err := enc.emitByte('0'); err != nil {
				return err, false
			}
			enc.expZero = false
		}
	}

	// whitespace
	switch {
	case enc.Compact:
		if t == internal.RuneTypeSpace {
			return nil, false
		}
	case enc.Indent != "":
		switch t {
		case internal.RuneTypeSpace:
			// let us manage whitespace, don't pass it through
			return nil, false
		case internal.RuneTypeObjectEnd, internal.RuneTypeArrayEnd:
			enc.curIndent--
			switch enc.lastNonSpace {
			case internal.RuneTypeObjectBeg, internal.RuneTypeArrayBeg:
				// collapse
			default:
				if err := enc.emitNlIndent(); err != nil {
					return err, false
				}
			}
		default:
			switch enc.lastNonSpace {
			case internal.RuneTypeObjectBeg, internal.RuneTypeObjectComma, internal.RuneTypeArrayBeg, internal.RuneTypeArrayComma:
				if err := enc.emitNlIndent(); err != nil {
					return err, false
				}
			case internal.RuneTypeObjectColon:
				if err := enc.emitByte(' '); err != nil {
					return err, false
				}
			}
			switch t {
			case internal.RuneTypeObjectBeg, internal.RuneTypeArrayBeg:
				enc.curIndent++
			}
		}
	}

	return nil, true
}

// handleRuneMain emits the output for the new rune itself (anything
// buffered from earlier runes is handled by handleRunePre), and then
// records t as the most recent non-whitespace RuneType.
//
// String escapes are decoded here and re-encoded via writeStringChar,
// which is told how the character was escaped in the input.  At the
// end of a top-level value, the recorded type is RuneTypeError if the
// ForceTrailingNewlines newline was just printed, and RuneTypeEOF
// otherwise.
func (enc *ReEncoder) handleRuneMain(c rune, t internal.RuneType) error {
	var err error
	switch t {
	case internal.RuneTypeStringChar:
		err = enc.emit(writeStringChar(enc.out, c, BackslashEscapeNone, enc.BackslashEscape))
	case internal.RuneTypeStringEsc, internal.RuneTypeStringEscU:
		// The backslash (and the 'u') produce nothing on their own;
		// output happens once the escape sequence is complete.
	case internal.RuneTypeStringEsc1:
		var decoded rune
		switch c {
		case '"', '\\', '/':
			decoded = c
		case 'b':
			decoded = '\b'
		case 'f':
			decoded = '\f'
		case 'n':
			decoded = '\n'
		case 'r':
			decoded = '\r'
		case 't':
			decoded = '\t'
		default:
			panic("should not happen")
		}
		err = enc.emit(writeStringChar(enc.out, decoded, BackslashEscapeShort, enc.BackslashEscape))
	case internal.RuneTypeStringEscUA:
		enc.uhex[0], _ = internal.HexToInt(c)
	case internal.RuneTypeStringEscUB:
		enc.uhex[1], _ = internal.HexToInt(c)
	case internal.RuneTypeStringEscUC:
		enc.uhex[2], _ = internal.HexToInt(c)
	case internal.RuneTypeStringEscUD:
		enc.uhex[3], _ = internal.HexToInt(c)
		// Assemble the four hex digits, most significant first.
		var decoded rune
		for _, nibble := range enc.uhex {
			decoded = decoded<<4 | rune(nibble)
		}
		err = enc.emit(writeStringChar(enc.out, decoded, BackslashEscapeUnicode, enc.BackslashEscape))
	case internal.RuneTypeError, internal.RuneTypeEOF:
		// End of a top-level value: RuneTypeError is passed by
		// Close, RuneTypeEOF is implied by the start of the next
		// top-level value.
		enc.wasNumber = enc.lastNonSpace.IsNumber()
		if enc.ForceTrailingNewlines && len(enc.barriers) == 0 {
			// The newline is printed now; recording RuneTypeError
			// keeps handleRunePre from printing another one.
			t = internal.RuneTypeError
			err = enc.emitByte('\n')
		} else {
			// Recording RuneTypeEOF lets handleRunePre decide
			// whether a newline is needed before the next value.
			t = internal.RuneTypeEOF
		}
	default:
		err = enc.emitByte(byte(c))
	}

	if t == internal.RuneTypeSpace {
		return err
	}
	enc.lastNonSpace = t
	if t != internal.RuneTypeEOF {
		enc.lastNonSpaceNonEOF = t
	}
	return err
}

// emitByte writes the single byte c to the output, adding it to
// enc.written only if the write succeeded.
func (enc *ReEncoder) emitByte(c byte) error {
	if err := enc.out.WriteByte(c); err != nil {
		return err
	}
	enc.written++
	return nil
}

// emit takes the (count, error) result of a write to enc.out.  It
// adds the count to enc.written even when err is non-nil, since a
// failed write may still have written some bytes, and returns err.
func (enc *ReEncoder) emit(n int, err error) error {
	enc.written = enc.written + n
	return err
}

// emitNlIndent starts a new indented output line: a newline, then
// Prefix (if any), then Indent repeated curIndent times.  It stops at
// the first write error.
func (enc *ReEncoder) emitNlIndent() error {
	err := enc.emitByte('\n')
	if err == nil && enc.Prefix != "" {
		err = enc.emit(enc.out.WriteString(enc.Prefix))
	}
	for depth := enc.curIndent; err == nil && depth > 0; depth-- {
		err = enc.emit(enc.out.WriteString(enc.Indent))
	}
	return err
}