// Copyright (C) 2022  Luke Shumaker <lukeshu@lukeshu.com>
//
// SPDX-License-Identifier: GPL-2.0-or-later

package lowmemjson

import (
	"bytes"
	"fmt"
	"io"
	"unicode/utf8"
)

// speculation holds the state of one object/array that is being
// re-encoded two ways at once while the *ReEncoder decides (per
// CompactIfUnder) which form to emit: compactFmt is a copy of the
// encoder with Compact forced on, writing into compactBuf, and
// indentFmt is a plain copy of the encoder, writing into indentBuf.
// See (*ReEncoder).handleRune.
type speculation struct {
	compactFmt ReEncoder
	compactBuf bytes.Buffer
	indentFmt  ReEncoder
	indentBuf  bytes.Buffer
}

// A ReEncoder accepts JSON text through Write and WriteRune, checks it
// with its Parser, and writes it back out to Out, re-formatted
// according to the settings below.  Call Close to signal the end of
// the input.
//
// The memory use of a ReEncoder is O( (CompactIfUnder+1)^2 + depth).
type ReEncoder struct {
	// Out is the writer that the re-encoded JSON is written to.
	Out io.Writer

	// Whether more than one top-level value is accepted.  If
	// false, a rune that starts a second top-level value is
	// reported as a syntax error.
	AllowMultipleValues bool

	// Whether to minify the JSON.
	//
	// Trims all whitespace, except that it emits a newline
	// between two *number* top-level values (or puts a newline
	// after all top-level values if ForceTrailingNewlines).
	//
	// Trims superfluous 0s from numbers.
	//
	// NOTE(review): the number shortening done in handleRunePre
	// is not actually conditional on Compact; confirm which
	// behavior is intended.
	Compact bool
	// CompactIfUnder causes the *ReEncoder to behave as if
	// Compact=true for individual elements if doing so would
	// cause that element to be under this number of bytes.
	//
	// Has no effect if it is zero, if Compact is true, or if
	// Indent is empty.
	//
	// This has O((CompactIfUnder+1)^2) memory overhead, so set
	// with caution.
	CompactIfUnder int
	// String to use to indent; ignored if Compact is true.
	//
	// Newlines are emitted *between* top-level values; a newline is
	// not emitted after the *last* top-level value (unless
	// ForceTrailingNewlines is on).
	Indent string
	// String to write after each indentation newline, before the
	// Indent strings.
	Prefix string
	// Whether to emit a newline after each top-level value.  See
	// the comments on Compact and Indent for discussion of how
	// this is different than the usual behavior.
	ForceTrailingNewlines bool
	// Returns whether a given character in a string should be
	// backslash-escaped.  The bool argument is whether it was
	// \u-escaped in the input.  This does not affect characters
	// that must or must-not be escaped to be valid JSON.
	//
	// If not set, then EscapeDefault is used.
	BackslashEscape BackslashEscaper

	// state: .Write's utf8-decoding buffer
	buf    [utf8.UTFMax]byte // bytes of an incomplete trailing UTF-8 sequence
	bufLen int               // number of valid bytes in buf

	// state: .WriteRune
	err      error  // sticky error; once set, every later call returns it
	par      Parser // syntax checker / rune classifier
	written  int    // bytes written to Out by the current WriteRune call
	inputPos int64  // input offset used in ReEncodeSyntaxError.Offset

	// state: .handleRune
	handleRuneState struct {
		lastNonSpace RuneType // type of the last non-whitespace rune handled
		wasNumber    bool     // whether the previous top-level value was a number
		curIndent    int      // current indentation depth
		uhex         [4]byte  // "\uABCD"-encoded characters in strings
		fracZeros    int64    // '0' fraction digits held back by number shortening
		expZero      bool     // whether the exponent has had no significant digit yet

		specu *speculation // non-nil while speculating on a container (see handleRune)
	}
}

// public API //////////////////////////////////////////////////////////////////

// Write implements io.Writer: it decodes p as UTF-8 and passes each
// rune to WriteRune.
//
// A multi-byte UTF-8 sequence may be split across any number of calls
// to Write.  An incomplete sequence at the end of p is held in enc.buf
// (at most utf8.UTFMax bytes) and completed by the following call(s);
// Close reports an error if it is never completed.  Invalid UTF-8 is
// passed on as one utf8.RuneError per invalid byte, exactly as
// utf8.DecodeRune would decode the concatenated input.
//
// On success Write returns len(p), counting any bytes that are still
// buffered.  On failure it returns the number of bytes of p that were
// consumed before the failing rune.
func (enc *ReEncoder) Write(p []byte) (int, error) {
	if len(p) == 0 {
		return 0, nil
	}
	var n int // bytes of p consumed so far

	// Finish off any partial rune left over from an earlier call.
	for enc.bufLen > 0 {
		have := enc.bufLen + copy(enc.buf[enc.bufLen:], p[n:])
		if !utf8.FullRune(enc.buf[:have]) {
			// Still incomplete.  Since utf8.UTFMax bytes always
			// count as a "full" rune, all of p is now buffered.
			enc.bufLen = have
			return len(p), nil
		}
		c, size := utf8.DecodeRune(enc.buf[:have])
		// If the buffered bytes turn out to be an invalid sequence,
		// size can be smaller than bufLen; the remaining buffered
		// bytes must then be decoded again rather than skipped.
		var rest [utf8.UTFMax]byte
		restLen, used := 0, 0
		if size >= enc.bufLen {
			used = size - enc.bufLen
		} else {
			restLen = copy(rest[:], enc.buf[size:enc.bufLen])
		}
		// WriteRune refuses to run while a partial rune is buffered.
		enc.bufLen = 0
		if _, err := enc.WriteRune(c); err != nil {
			return n, err
		}
		n += used
		enc.bufLen = copy(enc.buf[:], rest[:restLen])
	}

	for utf8.FullRune(p[n:]) {
		c, size := utf8.DecodeRune(p[n:])
		if _, err := enc.WriteRune(c); err != nil {
			return n, err
		}
		n += size
	}
	// Keep an incomplete trailing sequence for the next call (or Close).
	enc.bufLen = copy(enc.buf[:], p[n:])
	return len(p), nil
}

// Close signals the end of the input.  It reports a
// *ReEncodeSyntaxError if the input stopped in the middle of a UTF-8
// sequence or if the parser rejects the end of input. Otherwise it
// finishes the current top-level value (for example, writing the
// trailing newline when ForceTrailingNewlines is set).
//
// If an earlier call failed, Close returns that same error and does
// nothing else.  Write errors from Out are returned as-is, the same
// way WriteRune returns them.
//
// The *ReEncoder may continue to be written to with new JSON values
// if enc.AllowMultipleValues is set.
func (enc *ReEncoder) Close() error {
	if enc.err != nil {
		return enc.err
	}
	if enc.bufLen > 0 {
		enc.err = &ReEncodeSyntaxError{
			Offset: enc.inputPos,
			Err:    fmt.Errorf("%w: unflushed unicode garbage: %q", io.ErrUnexpectedEOF, enc.buf[:enc.bufLen]),
		}
		return enc.err
	}
	if _, err := enc.par.HandleEOF(); err != nil {
		enc.err = &ReEncodeSyntaxError{
			Err:    err,
			Offset: enc.inputPos,
		}
		return enc.err
	}
	// handleRune can only fail while writing to Out; that is not a
	// syntax error, so don't dress it up as one.
	if err := enc.handleRune(0, RuneTypeError); err != nil {
		enc.err = err
		return enc.err
	}
	if enc.AllowMultipleValues {
		enc.par.Reset()
	}
	return nil
}

// WriteRune feeds one rune of JSON input to the re-encoder and returns
// the number of bytes written to Out as a result.
//
// Errors are sticky: after a failure every later call returns the same
// error.  Syntax errors are *ReEncodeSyntaxError values carrying the
// input offset of the offending rune.  WriteRune may not be mixed with
// a Write that left a partial UTF-8 sequence buffered.
//
// When the parser reports that c begins a new top-level value, the
// previous value is finished first.  With AllowMultipleValues the
// parser is then reset and c handled as the start of the next value;
// without it, c is a syntax error.
func (enc *ReEncoder) WriteRune(c rune) (n int, err error) {
	switch {
	case enc.err != nil:
		return 0, enc.err
	case enc.bufLen > 0:
		enc.err = fmt.Errorf("lowmemjson.ReEncoder: cannot .WriteRune() when there is a partial rune that has been .Write()en: %q", enc.buf[:enc.bufLen])
		return 0, enc.err
	}

	enc.written = 0
	for {
		typ, parseErr := enc.par.HandleRune(c)
		if parseErr != nil {
			enc.err = &ReEncodeSyntaxError{
				Err:    parseErr,
				Offset: enc.inputPos,
			}
			return enc.written, enc.err
		}
		enc.err = enc.handleRune(c, typ)
		if enc.err != nil || typ != RuneTypeEOF {
			break
		}
		// The previous value just ended, and c starts another one.
		if !enc.AllowMultipleValues {
			enc.err = &ReEncodeSyntaxError{
				Err:    fmt.Errorf("invalid character %q after top-level value", c),
				Offset: enc.inputPos,
			}
			return enc.written, enc.err
		}
		enc.par.Reset()
	}

	enc.inputPos += int64(utf8.RuneLen(c))
	return enc.written, enc.err
}

// internal ////////////////////////////////////////////////////////////////////

// handleRune re-encodes rune c, which the parser has classified as t.
//
// If CompactIfUnder is positive (and Compact is false and Indent is
// non-empty), each object or array is speculatively encoded two ways
// at once: compactly into specu.compactBuf and indented into
// specu.indentBuf.  As soon as the compact form reaches CompactIfUnder
// bytes, the indented form is flushed to enc.Out, and encoding carries
// on with the indenting encoder's state (which may itself be
// speculating on a nested container).  If the container closes first,
// the compact form is flushed instead.  Everything else goes straight
// through handleRuneNoSpeculation.
//
// Bytes flushed from the speculation buffers are counted in
// enc.written, just like bytes that are emitted directly.
func (enc *ReEncoder) handleRune(c rune, t RuneType) error {
	if enc.CompactIfUnder <= 0 || enc.Compact || enc.Indent == "" {
		return enc.handleRuneNoSpeculation(c, t)
	}

	specu := enc.handleRuneState.specu
	if specu == nil { // not speculating
		if t != RuneTypeObjectBeg && t != RuneTypeArrayBeg {
			return enc.handleRuneNoSpeculation(c, t)
		}
		// Start speculating: emit whatever precedes the container
		// to the real output, then fork two encoders for the rest.
		// (handleRunePre never declines an ObjectBeg/ArrayBeg.)
		if err, _ := enc.handleRunePre(c, t); err != nil {
			return err
		}
		specu = &speculation{
			compactFmt: *enc,
			indentFmt:  *enc,
		}
		specu.compactFmt.Compact = true
		specu.compactFmt.Out = &specu.compactBuf
		specu.indentFmt.Out = &specu.indentBuf
		enc.handleRuneState.specu = specu
		if err := specu.compactFmt.handleRuneMain(c, t); err != nil {
			return err
		}
		return specu.indentFmt.handleRuneMain(c, t)
	}

	// speculating

	// canCompress is whether we're 1-up from the leaf (the indenting
	// encoder is not speculating on a nested container); it must be
	// computed *before* the calls to .handleRune below.
	canCompress := specu.indentFmt.handleRuneState.specu == nil

	if err := specu.compactFmt.handleRune(c, t); err != nil {
		return err
	}
	if err := specu.indentFmt.handleRune(c, t); err != nil {
		return err
	}

	switch {
	case specu.compactBuf.Len() >= enc.CompactIfUnder: // stop speculating; use indent
		n, err := specu.indentBuf.WriteTo(enc.Out)
		enc.written += int(n)
		if err != nil {
			return err
		}
		enc.handleRuneState = specu.indentFmt.handleRuneState
	case canCompress && (t == RuneTypeObjectEnd || t == RuneTypeArrayEnd): // stop speculating; use compact
		n, err := specu.compactBuf.WriteTo(enc.Out)
		enc.written += int(n)
		if err != nil {
			return err
		}
		enc.handleRuneState.lastNonSpace = t
		enc.handleRuneState.curIndent--
		enc.handleRuneState.specu = nil
	}
	return nil
}

// handleRuneNoSpeculation handles rune c (classified by the parser as
// t) directly against enc.Out.  It runs handleRunePre first and, unless
// that failed or consumed the rune, then runs handleRuneMain.
func (enc *ReEncoder) handleRuneNoSpeculation(c rune, t RuneType) error {
	if err, shouldHandle := enc.handleRunePre(c, t); err != nil || !shouldHandle {
		return err
	}
	return enc.handleRuneMain(c, t)
}

// handleRunePre handles buffered things that need to happen before the
// new rune c (classified by the parser as t) is itself handled by
// handleRuneMain.  That covers three things:
//
// The newline that separates top-level values.
//
// Number shortening.  '0's in the fraction part after its first digit
// are held back.  They are written out if a non-zero digit follows,
// and dropped when the number ends, so at least one fraction digit
// always remains.  Leading '0's of the exponent are dropped, except
// that a single '0' is kept if the exponent has no other digits.
//
// Whitespace, which the encoder manages itself when Compact is set or
// Indent is non-empty.
//
// It returns any write error, and whether the caller should go on to
// handle c with handleRuneMain (false means c was fully dealt with
// here).
func (enc *ReEncoder) handleRunePre(c rune, t RuneType) (error, bool) {
	// emit newlines between top-level values
	//
	// Whitespace does not start a value, so the newline waits for the
	// first non-space rune.  Otherwise, after Close() has reset the
	// parser (AllowMultipleValues), every leading whitespace rune of
	// the next value would emit its own newline.
	if enc.handleRuneState.lastNonSpace == RuneTypeEOF && t != RuneTypeSpace {
		switch {
		case enc.handleRuneState.wasNumber && t.IsNumber():
			if err := enc.emitByte('\n'); err != nil {
				return err, false
			}
		case enc.Indent != "" && !enc.Compact:
			if err := enc.emitByte('\n'); err != nil {
				return err, false
			}
		}
	}

	// shorten numbers
	switch t { // trim trailing '0's from the fraction-part, but don't remove all digits
	case RuneTypeNumberFracDot:
		enc.handleRuneState.fracZeros = 0
	case RuneTypeNumberFracDig:
		// The first fraction digit always gets through, because
		// lastNonSpace is still the '.' at that point.
		if c == '0' && enc.handleRuneState.lastNonSpace == RuneTypeNumberFracDig {
			enc.handleRuneState.fracZeros++
			return nil, false
		}
		// A non-zero digit: the held-back zeros were significant.
		for enc.handleRuneState.fracZeros > 0 {
			if err := enc.emitByte('0'); err != nil {
				return err, false
			}
			enc.handleRuneState.fracZeros--
		}
	default:
		// The fraction (if any) is over, so any held-back zeros
		// were trailing zeros: drop them.
		enc.handleRuneState.fracZeros = 0
	}
	switch t { // trim leading '0's from the exponent-part, but don't remove all digits
	case RuneTypeNumberExpE, RuneTypeNumberExpSign:
		enc.handleRuneState.expZero = true
	case RuneTypeNumberExpDig:
		if c == '0' && enc.handleRuneState.expZero {
			return nil, false
		}
		enc.handleRuneState.expZero = false
	default:
		if enc.handleRuneState.expZero {
			if err := enc.emitByte('0'); err != nil {
				return err, false
			}
			enc.handleRuneState.expZero = false
		}
	}

	// whitespace
	switch {
	case enc.Compact:
		if t == RuneTypeSpace {
			return nil, false
		}
	case enc.Indent != "":
		switch t {
		case RuneTypeSpace:
			// let us manage whitespace, don't pass it through
			return nil, false
		case RuneTypeObjectEnd, RuneTypeArrayEnd:
			enc.handleRuneState.curIndent--
			switch enc.handleRuneState.lastNonSpace {
			case RuneTypeObjectBeg, RuneTypeArrayBeg:
				// collapse
			default:
				if err := enc.emitNlIndent(); err != nil {
					return err, false
				}
			}
		default:
			switch enc.handleRuneState.lastNonSpace {
			case RuneTypeObjectBeg, RuneTypeObjectComma, RuneTypeArrayBeg, RuneTypeArrayComma:
				if err := enc.emitNlIndent(); err != nil {
					return err, false
				}
			case RuneTypeObjectColon:
				if err := enc.emitByte(' '); err != nil {
					return err, false
				}
			}
			switch t {
			case RuneTypeObjectBeg, RuneTypeArrayBeg:
				enc.handleRuneState.curIndent++
			}
		}
	}

	return nil, true
}

// handleRuneMain handles the new rune itself, not buffered things: it
// writes the output for rune c (classified by the parser as t) to
// enc.Out, and afterwards records t in handleRuneState.lastNonSpace
// unless t is RuneTypeSpace.
//
// Inside strings, the runes typed RuneTypeStringEsc and
// RuneTypeStringEscU and the first three hex digits of a \uXXXX escape
// produce no output by themselves.  The character they denote is
// written through writeStringChar, which consults enc.BackslashEscape,
// once it is complete.
//
// NOTE(review): each \uXXXX escape is decoded on its own, so a UTF-16
// surrogate pair reaches writeStringChar as two separate surrogate
// runes; confirm that writeStringChar copes with that.
func (enc *ReEncoder) handleRuneMain(c rune, t RuneType) error {
	// Deferred, and closing over t, so that it stores the value t
	// has on return; the EOF cases below reassign t on purpose.
	defer func() {
		if t != RuneTypeSpace {
			enc.handleRuneState.lastNonSpace = t
		}
	}()

	switch t {

	case RuneTypeStringChar:
		return enc.emit(writeStringChar(enc.Out, c, BackslashEscapeNone, enc.BackslashEscape))
	case RuneTypeStringEsc, RuneTypeStringEscU:
		return nil
	case RuneTypeStringEsc1:
		// c names the escaped character ('n' for '\n', etc.); pass
		// the character itself on.
		switch c {
		case '"':
			return enc.emit(writeStringChar(enc.Out, '"', BackslashEscapeShort, enc.BackslashEscape))
		case '\\':
			return enc.emit(writeStringChar(enc.Out, '\\', BackslashEscapeShort, enc.BackslashEscape))
		case '/':
			return enc.emit(writeStringChar(enc.Out, '/', BackslashEscapeShort, enc.BackslashEscape))
		case 'b':
			return enc.emit(writeStringChar(enc.Out, '\b', BackslashEscapeShort, enc.BackslashEscape))
		case 'f':
			return enc.emit(writeStringChar(enc.Out, '\f', BackslashEscapeShort, enc.BackslashEscape))
		case 'n':
			return enc.emit(writeStringChar(enc.Out, '\n', BackslashEscapeShort, enc.BackslashEscape))
		case 'r':
			return enc.emit(writeStringChar(enc.Out, '\r', BackslashEscapeShort, enc.BackslashEscape))
		case 't':
			return enc.emit(writeStringChar(enc.Out, '\t', BackslashEscapeShort, enc.BackslashEscape))
		default:
			// presumably unreachable: the parser should only type
			// the characters above as RuneTypeStringEsc1.
			panic("should not happen")
		}
	// The hex2int errors below are discarded; presumably the parser
	// has already validated each hex digit.
	case RuneTypeStringEscUA:
		enc.handleRuneState.uhex[0], _ = hex2int(c)
		return nil
	case RuneTypeStringEscUB:
		enc.handleRuneState.uhex[1], _ = hex2int(c)
		return nil
	case RuneTypeStringEscUC:
		enc.handleRuneState.uhex[2], _ = hex2int(c)
		return nil
	case RuneTypeStringEscUD:
		enc.handleRuneState.uhex[3], _ = hex2int(c)
		// Assemble the four hex digits, most significant first.
		c := 0 |
			rune(enc.handleRuneState.uhex[0])<<12 |
			rune(enc.handleRuneState.uhex[1])<<8 |
			rune(enc.handleRuneState.uhex[2])<<4 |
			rune(enc.handleRuneState.uhex[3])<<0
		return enc.emit(writeStringChar(enc.Out, c, BackslashEscapeUnicode, enc.BackslashEscape))

	case RuneTypeError: // EOF explicitly stated by .Close()
		fallthrough
	case RuneTypeEOF: // EOF implied by the start of the next top-level value
		// Remembered so handleRunePre can separate two adjacent
		// top-level numbers with a newline.
		enc.handleRuneState.wasNumber = enc.handleRuneState.lastNonSpace.IsNumber()
		switch {
		case enc.ForceTrailingNewlines:
			t = RuneTypeError // enc.handleRuneState.lastNonSpace : an NL isn't needed (we already printed one)
			return enc.emitByte('\n')
		default:
			t = RuneTypeEOF // enc.handleRuneState.lastNonSpace : an NL *might* be needed
			return nil
		}
	default:
		// Everything else is written as one byte.  Presumably c is
		// always ASCII here; a non-ASCII c would be truncated.
		return enc.emitByte(byte(c))
	}
}

// emitByte writes the single byte c to enc.Out and, only if that write
// succeeded, counts it in enc.written.
func (enc *ReEncoder) emitByte(c byte) error {
	if err := writeByte(enc.Out, c); err != nil {
		return err
	}
	enc.written++
	return nil
}

// emit adds count to enc.written and passes writeErr through.  Its
// parameters match the results of an (int, error) write call, so a
// call like enc.emit(io.WriteString(enc.Out, s)) both writes and
// counts.
func (enc *ReEncoder) emit(count int, writeErr error) error {
	enc.written += count
	return writeErr
}

// emitNlIndent starts a new indented output line.  It writes a newline,
// then Prefix (if any), then Indent once per level of
// handleRuneState.curIndent (none when curIndent is zero or negative).
// It stops at the first write error.
func (enc *ReEncoder) emitNlIndent() error {
	if err := enc.emitByte('\n'); err != nil {
		return err
	}
	if enc.Prefix != "" {
		if err := enc.emit(io.WriteString(enc.Out, enc.Prefix)); err != nil {
			return err
		}
	}
	for depth := enc.handleRuneState.curIndent; depth > 0; depth-- {
		if err := enc.emit(io.WriteString(enc.Out, enc.Indent)); err != nil {
			return err
		}
	}
	return nil
}