// Copyright (C) 2022-2023  Luke Shumaker <lukeshu@lukeshu.com>
//
// SPDX-License-Identifier: GPL-2.0-or-later

package lowmemjson

import (
	"bytes"
	"encoding"
	"encoding/base64"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	iofs "io/fs"
	"reflect"
	"sort"
	"strconv"
	"strings"
	"unsafe"
)

// Encodable is the interface implemented by types that can encode
// themselves to JSON.  It is a low-memory-overhead replacement for
// the json.Marshaler interface: instead of returning the whole
// encoding as a []byte, EncodeJSON writes it to an io.Writer.
//
// The io.Writer passed to EncodeJSON returns an error if invalid
// JSON is written to it.
type Encodable interface {
	// EncodeJSON writes the JSON representation of the receiver to w.
	EncodeJSON(w io.Writer) error
}

// encodeError wraps an error so that it can be carried by panic out of
// the recursive encoding helpers and be told apart from unrelated
// panics when it is recovered.
type encodeError struct{ Err error }

// encodeWriteByte writes the single byte b to w via writeByte, and
// panics with an encodeError wrapping any write failure.
func encodeWriteByte(w io.Writer, b byte) {
	err := writeByte(w, b)
	if err == nil {
		return
	}
	panic(encodeError{err})
}

// encodeWriteString writes str to w using io.WriteString, and panics
// with an encodeError wrapping any write failure.
func encodeWriteString(w io.Writer, str string) {
	_, err := io.WriteString(w, str)
	if err == nil {
		return
	}
	panic(encodeError{err})
}

// An Encoder encodes Go values as JSON and writes them to an output
// stream of JSON elements.
//
// Encoder is analogous to, and has an API similar to, the standard
// library's encoding/json.Encoder.  One difference is that, rather
// than having .SetEscapeHTML and .SetIndent methods, Encoder relies
// on the io.Writer passed to it being a *ReEncoder that carries
// those settings (and more).  If something closer to a json.Encoder
// is desired, lowmemjson/compat/json.Encoder offers those
// .SetEscapeHTML and .SetIndent methods.
type Encoder struct {
	// w is the ReEncoder that all encoded output is written to.
	w *ReEncoder
	// closeAfterEncode reports whether Encode should Close w after
	// writing each value.  NewEncoder sets it from whether w's
	// parser stack was empty when the Encoder was created.
	closeAfterEncode bool
}

// NewEncoder returns a new Encoder that writes to w.
//
// If w is already an *ReEncoder, it is used directly, and the inner
// backslash-escaping of double-encoded ",string"-tagged string values
// follows that *ReEncoder's BackslashEscape policy.  Otherwise w is
// wrapped in a fresh *ReEncoder that allows multiple top-level values.
//
// An Encoder tends to make many small writes; if w.Write calls are
// syscalls, then you may want to wrap w in a bufio.Writer.
func NewEncoder(w io.Writer) *Encoder {
	var re *ReEncoder
	if existing, isReEncoder := w.(*ReEncoder); isReEncoder {
		re = existing
	} else {
		re = &ReEncoder{
			Out:                 w,
			AllowMultipleValues: true,
		}
	}
	enc := &Encoder{w: re}
	// Only close the ReEncoder after each value if it is not already
	// in the middle of some enclosing value.
	enc.closeAfterEncode = re.par.StackIsEmpty()
	return enc
}

// Encode encodes obj as JSON and writes that JSON to the Encoder's
// output stream.  If the Encoder was created outside of any enclosing
// value, its underlying *ReEncoder is closed after the value is
// written.
//
// See the [documentation for encoding/json.Marshal] for details about
// the conversion of Go values to JSON; Encode behaves identically to
// that, except that in addition to the json.Marshaler interface it
// also checks for the Encodable interface.
//
// [documentation for encoding/json.Marshal]: https://pkg.go.dev/encoding/json@go1.18#Marshal
func (enc *Encoder) Encode(obj any) (err error) {
	// The encoding helpers report failures by panicking with an
	// encodeError; convert those back into an ordinary error and let
	// any other panic continue unwinding.
	defer func() {
		r := recover()
		if r == nil {
			return
		}
		wrapped, isEncodeErr := r.(encodeError)
		if !isEncodeErr {
			panic(r)
		}
		err = wrapped.Err
	}()
	seen := make(map[any]struct{})
	encode(enc.w, reflect.ValueOf(obj), enc.w.BackslashEscape, false, 0, seen)
	if !enc.closeAfterEncode {
		return nil
	}
	return enc.w.Close()
}

// Interface types that encode checks values against.  encode consults
// them in the order Encodable, then json.Marshaler, then
// encoding.TextMarshaler.
var (
	encodableType     = reflect.TypeOf(new(Encodable)).Elem()
	jsonMarshalerType = reflect.TypeOf(new(json.Marshaler)).Elem()
	textMarshalerType = reflect.TypeOf(new(encoding.TextMarshaler)).Elem()
)

// startDetectingCyclesAfter is how many maps, slices, and pointers
// encode may descend through before it starts recording visited
// containers to detect reference cycles (presumably so that ordinary,
// shallow values do not pay for that bookkeeping).
const startDetectingCyclesAfter = 1_000

// encode writes the JSON representation of val to w.
//
// Failures are reported by panicking with an encodeError;
// (*Encoder).Encode recovers that panic and returns the wrapped error.
//
// escaper is passed to the string-encoding helpers and to the
// sub-ReEncoders that validate the output of EncodeJSON and
// MarshalJSON methods.  quote requests the ",string" struct-tag
// behavior: booleans, numbers, and strings are written wrapped in an
// extra layer of double quotes.  cycleDepth counts how many maps,
// slices, and pointers have been descended through; once it exceeds
// startDetectingCyclesAfter, each such container is recorded in
// cycleSeen while it is being encoded, and reaching one that is
// already recorded panics with an EncodeValueError.
//
// Map members are written ordered by the decoded string value of their
// keys (as encoding/json does), not by the escaped JSON form of the
// keys.
func encode(w io.Writer, val reflect.Value, escaper BackslashEscaper, quote bool, cycleDepth uint, cycleSeen map[any]struct{}) {
	if !val.IsValid() {
		encodeWriteString(w, "null")
		return
	}
	switch {

	case val.Kind() != reflect.Pointer && val.CanAddr() && reflect.PointerTo(val.Type()).Implements(encodableType):
		// Pointer-receiver methods are usable on addressable values.
		val = val.Addr()
		fallthrough
	case val.Type().Implements(encodableType):
		if val.Kind() == reflect.Pointer && val.IsNil() {
			encodeWriteString(w, "null")
			return
		}
		obj, ok := val.Interface().(Encodable)
		if !ok {
			encodeWriteString(w, "null")
			return
		}
		// Use a sub-ReEncoder to check that it's a full element.
		validator := &ReEncoder{Out: w, BackslashEscape: escaper}
		if err := obj.EncodeJSON(validator); err != nil {
			panic(encodeError{&EncodeMethodError{
				Type:       val.Type(),
				SourceFunc: "EncodeJSON",
				Err:        err,
			}})
		}
		// Tolerate the validator already being closed (presumably by
		// EncodeJSON itself).
		if err := validator.Close(); err != nil && !errors.Is(err, iofs.ErrClosed) {
			panic(encodeError{&EncodeMethodError{
				Type:       val.Type(),
				SourceFunc: "EncodeJSON",
				Err:        err,
			}})
		}

	case val.Kind() != reflect.Pointer && val.CanAddr() && reflect.PointerTo(val.Type()).Implements(jsonMarshalerType):
		val = val.Addr()
		fallthrough
	case val.Type().Implements(jsonMarshalerType):
		if val.Kind() == reflect.Pointer && val.IsNil() {
			encodeWriteString(w, "null")
			return
		}
		obj, ok := val.Interface().(json.Marshaler)
		if !ok {
			encodeWriteString(w, "null")
			return
		}
		dat, err := obj.MarshalJSON()
		if err != nil {
			panic(encodeError{&EncodeMethodError{
				Type:       val.Type(),
				SourceFunc: "MarshalJSON",
				Err:        err,
			}})
		}
		// Use a sub-ReEncoder to check that it's a full element.
		validator := &ReEncoder{Out: w, BackslashEscape: escaper}
		if _, err := validator.Write(dat); err != nil {
			panic(encodeError{&EncodeMethodError{
				Type:       val.Type(),
				SourceFunc: "MarshalJSON",
				Err:        err,
			}})
		}
		if err := validator.Close(); err != nil {
			panic(encodeError{&EncodeMethodError{
				Type:       val.Type(),
				SourceFunc: "MarshalJSON",
				Err:        err,
			}})
		}

	case val.Kind() != reflect.Pointer && val.CanAddr() && reflect.PointerTo(val.Type()).Implements(textMarshalerType):
		val = val.Addr()
		fallthrough
	case val.Type().Implements(textMarshalerType):
		if val.Kind() == reflect.Pointer && val.IsNil() {
			encodeWriteString(w, "null")
			return
		}
		obj, ok := val.Interface().(encoding.TextMarshaler)
		if !ok {
			encodeWriteString(w, "null")
			return
		}
		text, err := obj.MarshalText()
		if err != nil {
			panic(encodeError{&EncodeMethodError{
				Type:       val.Type(),
				SourceFunc: "MarshalText",
				Err:        err,
			}})
		}
		encodeStringFromBytes(w, escaper, text)

	default:
		switch val.Kind() {
		case reflect.Bool:
			if quote {
				encodeWriteByte(w, '"')
			}
			if val.Bool() {
				encodeWriteString(w, "true")
			} else {
				encodeWriteString(w, "false")
			}
			if quote {
				encodeWriteByte(w, '"')
			}
		case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
			if quote {
				encodeWriteByte(w, '"')
			}
			encodeWriteString(w, strconv.FormatInt(val.Int(), 10))
			if quote {
				encodeWriteByte(w, '"')
			}
		case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
			if quote {
				encodeWriteByte(w, '"')
			}
			encodeWriteString(w, strconv.FormatUint(val.Uint(), 10))
			if quote {
				encodeWriteByte(w, '"')
			}
		case reflect.Float32, reflect.Float64:
			if quote {
				encodeWriteByte(w, '"')
			}
			encodeTODO(w, val)
			if quote {
				encodeWriteByte(w, '"')
			}
		case reflect.String:
			if val.Type() == numberType {
				numStr := val.String()
				if numStr == "" {
					numStr = "0"
				}
				if quote {
					encodeWriteByte(w, '"')
				}
				encodeWriteString(w, numStr)
				if quote {
					encodeWriteByte(w, '"')
				}
			} else {
				if quote {
					// ",string" on a string: encode it as a JSON
					// string, then encode that text as a JSON string
					// again.
					var buf bytes.Buffer
					encodeStringFromString(&buf, escaper, val.String())
					encodeStringFromBytes(w, escaper, buf.Bytes())
				} else {
					encodeStringFromString(w, escaper, val.String())
				}
			}
		case reflect.Interface:
			if val.IsNil() {
				encodeWriteString(w, "null")
			} else {
				encode(w, val.Elem(), escaper, quote, cycleDepth, cycleSeen)
			}
		case reflect.Struct:
			encodeWriteByte(w, '{')
			empty := true
			for _, field := range indexStruct(val.Type()).byPos {
				fVal, err := val.FieldByIndexErr(field.Path)
				if err != nil {
					// Skip fields that are unreachable (e.g. through
					// a nil embedded pointer).
					continue
				}
				if field.OmitEmpty && isEmptyValue(fVal) {
					continue
				}
				if !empty {
					encodeWriteByte(w, ',')
				}
				empty = false
				encodeStringFromString(w, escaper, field.Name)
				encodeWriteByte(w, ':')
				encode(w, fVal, escaper, field.Quote, cycleDepth, cycleSeen)
			}
			encodeWriteByte(w, '}')
		case reflect.Map:
			if val.IsNil() {
				encodeWriteString(w, "null")
				return
			}
			if val.Len() == 0 {
				encodeWriteString(w, "{}")
				return
			}
			if cycleDepth++; cycleDepth > startDetectingCyclesAfter {
				ptr := val.UnsafePointer()
				if _, seen := cycleSeen[ptr]; seen {
					panic(encodeError{&EncodeValueError{
						Value: val,
						Str:   fmt.Sprintf("encountered a cycle via %s", val.Type()),
					}})
				}
				cycleSeen[ptr] = struct{}{}
				defer delete(cycleSeen, ptr)
			}
			encodeWriteByte(w, '{')

			type kv struct {
				// K is the key as a JSON string literal, ready to write.
				K string
				// SortKey is the decoded key, used only for ordering.
				SortKey string
				V       reflect.Value
			}
			kvs := make([]kv, val.Len())
			iter := val.MapRange()
			for i := 0; iter.Next(); i++ {
				// TODO: Avoid buffering the map key
				var k strings.Builder
				encode(&k, iter.Key(), escaper, false, cycleDepth, cycleSeen)
				kStr := k.String()
				if kStr == "null" {
					kStr = `""`
				}
				// Non-string encodings (e.g. integers) must be
				// wrapped as JSON strings to be valid object keys.
				if !strings.HasPrefix(kStr, `"`) {
					k.Reset()
					encodeStringFromString(&k, escaper, kStr)
					kStr = k.String()
				}
				kvs[i].K = kStr
				kvs[i].SortKey = mapKeySortString(kStr)
				kvs[i].V = iter.Value()
			}
			// encoding/json orders members by the raw key string.
			// Comparing the escaped literals instead would misorder
			// keys: e.g. "foo bar" would precede "foo" (because the
			// closing '"' sorts after ' '), and `a"` would follow
			// "a#" (because '\\' sorts after '#').
			sort.Slice(kvs, func(i, j int) bool {
				return kvs[i].SortKey < kvs[j].SortKey
			})

			for i, pair := range kvs {
				if i > 0 {
					encodeWriteByte(w, ',')
				}
				encodeWriteString(w, pair.K)
				encodeWriteByte(w, ':')
				encode(w, pair.V, escaper, false, cycleDepth, cycleSeen)
			}
			encodeWriteByte(w, '}')
		case reflect.Slice:
			switch {
			case val.IsNil():
				encodeWriteString(w, "null")
			case val.Type().Elem().Kind() == reflect.Uint8 && !(false ||
				val.Type().Elem().Implements(encodableType) ||
				reflect.PointerTo(val.Type().Elem()).Implements(encodableType) ||
				val.Type().Elem().Implements(jsonMarshalerType) ||
				reflect.PointerTo(val.Type().Elem()).Implements(jsonMarshalerType) ||
				val.Type().Elem().Implements(textMarshalerType) ||
				reflect.PointerTo(val.Type().Elem()).Implements(textMarshalerType)):
				// Byte slices whose elements have no custom
				// encoding are written as base64 strings.
				encodeWriteByte(w, '"')
				enc := base64.NewEncoder(base64.StdEncoding, w)
				if val.CanConvert(byteSliceType) {
					if _, err := enc.Write(val.Convert(byteSliceType).Interface().([]byte)); err != nil {
						panic(encodeError{err})
					}
				} else {
					// TODO: Surely there's a better way.
					for i, n := 0, val.Len(); i < n; i++ {
						var buf [1]byte
						buf[0] = val.Index(i).Convert(byteType).Interface().(byte)
						if _, err := enc.Write(buf[:]); err != nil {
							panic(encodeError{err})
						}
					}
				}
				if err := enc.Close(); err != nil {
					panic(encodeError{err})
				}
				encodeWriteByte(w, '"')
			default:
				if cycleDepth++; cycleDepth > startDetectingCyclesAfter {
					// For slices, val.UnsafePointer() doesn't return a pointer to the slice header
					// or anything like that, it returns a pointer *to the first element in the
					// slice*.  That means that the pointer isn't enough to uniquely identify the
					// slice!  So we pair the pointer with the length of the slice, which is
					// sufficient.
					ptr := struct {
						ptr unsafe.Pointer
						len int
					}{val.UnsafePointer(), val.Len()}
					if _, seen := cycleSeen[ptr]; seen {
						panic(encodeError{&EncodeValueError{
							Value: val,
							Str:   fmt.Sprintf("encountered a cycle via %s", val.Type()),
						}})
					}
					cycleSeen[ptr] = struct{}{}
					defer delete(cycleSeen, ptr)
				}
				encodeArray(w, val, escaper, cycleDepth, cycleSeen)
			}
		case reflect.Array:
			encodeArray(w, val, escaper, cycleDepth, cycleSeen)
		case reflect.Pointer:
			if val.IsNil() {
				encodeWriteString(w, "null")
			} else {
				if cycleDepth++; cycleDepth > startDetectingCyclesAfter {
					ptr := val.UnsafePointer()
					if _, seen := cycleSeen[ptr]; seen {
						panic(encodeError{&EncodeValueError{
							Value: val,
							Str:   fmt.Sprintf("encountered a cycle via %s", val.Type()),
						}})
					}
					cycleSeen[ptr] = struct{}{}
					defer delete(cycleSeen, ptr)
				}
				encode(w, val.Elem(), escaper, quote, cycleDepth, cycleSeen)
			}
		default:
			panic(encodeError{&EncodeTypeError{
				Type: val.Type(),
			}})
		}
	}
}

// mapKeySortString returns the string value represented by the JSON
// string literal encKey, for use as a map-key sort key.  A literal
// containing no backslash escapes is simply unquoted; anything else is
// decoded with encoding/json.  If encKey cannot be decoded, it is
// returned unchanged so that ordering stays deterministic.
func mapKeySortString(encKey string) string {
	if len(encKey) >= 2 && encKey[0] == '"' && encKey[len(encKey)-1] == '"' && !strings.Contains(encKey, `\`) {
		return encKey[1 : len(encKey)-1]
	}
	var decoded string
	if err := json.Unmarshal([]byte(encKey), &decoded); err != nil {
		return encKey
	}
	return decoded
}

// encodeArray writes the elements of val (a slice or array) to w as a
// JSON array, encoding each element with encode.  Elements never use
// the ",string" quoting, and cycle-detection state is passed through
// unchanged.
func encodeArray(w io.Writer, val reflect.Value, escaper BackslashEscaper, cycleDepth uint, cycleSeen map[any]struct{}) {
	encodeWriteByte(w, '[')
	for idx, count := 0, val.Len(); idx < count; idx++ {
		if idx != 0 {
			encodeWriteByte(w, ',')
		}
		elem := val.Index(idx)
		encode(w, elem, escaper, false, cycleDepth, cycleSeen)
	}
	encodeWriteByte(w, ']')
}

// encodeTODO delegates encoding of val to encoding/json.Marshal and
// writes the result to w.  Both marshaling and write failures are
// reported by panicking with an encodeError.
func encodeTODO(w io.Writer, val reflect.Value) {
	encoded, marshalErr := json.Marshal(val.Interface())
	if marshalErr != nil {
		panic(encodeError{marshalErr})
	}
	_, writeErr := w.Write(encoded)
	if writeErr != nil {
		panic(encodeError{writeErr})
	}
}