// Copyright (C) 2022-2023  Luke Shumaker <lukeshu@lukeshu.com>
//
// SPDX-License-Identifier: GPL-2.0-or-later

package lowmemjson

import (
	"bytes"
	"encoding"
	"encoding/base64"
	"encoding/json"
	"fmt"
	"io"
	"reflect"
	"sort"
	"strconv"
	"strings"
	"unsafe"

	"git.lukeshu.com/go/lowmemjson/internal/jsonstruct"
)

// Encodable is implemented by types that know how to write their own
// JSON representation.  It is the low-memory-overhead counterpart of
// the json.Marshaler interface: rather than returning the encoded
// bytes all at once, the value streams its JSON to the writer it is
// given.
//
// The io.Writer handed to EncodeJSON reports an error if invalid JSON
// is written to it.
type Encodable interface {
	EncodeJSON(w io.Writer) error
}

// An Encoder encodes and writes values to a stream of JSON elements.
//
// Encoder is analogous to, and has a similar API to, the standard
// library's encoding/json.Encoder.  Differences are that rather than
// having .SetEscapeHTML and .SetIndent methods, the io.Writer passed
// to it may be a *ReEncoder that has these settings (and more).  If
// something more similar to a json.Encoder is desired,
// lowmemjson/compat/json.Encoder offers those .SetEscapeHTML and
// .SetIndent methods.
type Encoder struct {
	// w receives all output; NewEncoder wraps any writer that is
	// not already a *ReEncoder in a new one.
	w *ReEncoder

	// isRoot records whether w's parser stack was empty when the
	// Encoder was created; if so, Encode resets w's parser before
	// each value and closes w after it.
	isRoot bool
}

// NewEncoder returns a new Encoder that writes to w.
//
// If w is an *ReEncoder, then the inner backslash-escaping of
// double-encoded ",string" tagged string values obeys the
// *ReEncoder's BackslashEscape policy.
//
// An Encoder tends to make many small writes; if w.Write calls are
// syscalls, then you may want to wrap w in a bufio.Writer.
func NewEncoder(w io.Writer) *Encoder {
	reEnc, isReEncoder := w.(*ReEncoder)
	if !isReEncoder {
		// Plain writers are wrapped in a new ReEncoder that has
		// AllowMultipleValues set.
		reEnc = NewReEncoder(w, ReEncoderConfig{AllowMultipleValues: true})
	}
	enc := new(Encoder)
	enc.w = reEnc
	// Whether the ReEncoder's parser is at the top level right now
	// decides whether Encode resets and closes it around each value.
	enc.isRoot = reEnc.par.StackIsEmpty()
	return enc
}

// Encode encodes obj to JSON and writes that JSON to the Encoder's
// output stream.
//
// See the [documentation for encoding/json.Marshal] for details about
// the conversion of Go values to JSON; Encode behaves identically to
// that, with the exception that in addition to the json.Marshaler
// interface it also checks for the Encodable interface.
//
// Unlike encoding/json.Encoder.Encode, lowmemjson.Encoder.Encode does
// not buffer its output; if an encode-error is encountered, lowmemjson
// may write partial output, whereas encoding/json would not have
// written anything.
//
// [documentation for encoding/json.Marshal]: https://pkg.go.dev/encoding/json@go1.20#Marshal
func (enc *Encoder) Encode(obj any) error {
	// Only an Encoder that started at the top level of its ReEncoder
	// (isRoot) resets the parser before the value and closes the
	// ReEncoder after it; otherwise the value is presumably being
	// written inside something already in progress, so it is left
	// open.
	if enc.isRoot {
		enc.w.par.Reset()
	}
	if err := encode(enc.w, reflect.ValueOf(obj), enc.w.BackslashEscape, false, 0, map[any]struct{}{}); err != nil {
		return err
	}
	if !enc.isRoot {
		return nil
	}
	return enc.w.Close()
}

// discardInt turns an (n int, err error) result pair, such as the one
// returned by WriteString, into a bare error by dropping the count.
func discardInt(_ int, err error) error {
	if err != nil {
		return err
	}
	return nil
}

// The interface types that encode checks a value against, in priority
// order, before falling back to encoding it by its reflect.Kind.
var (
	encodableType     = reflect.TypeOf(new(Encodable)).Elem()
	jsonMarshalerType = reflect.TypeOf(new(json.Marshaler)).Elem()
	textMarshalerType = reflect.TypeOf(new(encoding.TextMarshaler)).Elem()
)

// startDetectingCyclesAfter is how many maps, slices, and pointers
// encode may nest through before it starts recording visited values
// in order to detect reference cycles; shallower values incur no
// bookkeeping.
const startDetectingCyclesAfter = 1000

// encode writes the JSON encoding of val to w.
//
// An invalid (zero) reflect.Value is written as null.  Otherwise val
// is checked, in order, for the Encodable, json.Marshaler, and
// encoding.TextMarshaler interfaces (each first via &val when val is
// an addressable non-pointer, then on val itself).  If none applies,
// val is encoded according to its reflect.Kind.  Kinds with no JSON
// representation produce an *EncodeTypeError.
//
// Parameters:
//   - escaper is the backslash-escaping policy for strings that this
//     function quotes itself.
//   - quote is set for struct fields tagged ",string": bool, integer,
//     float, and numberType values are wrapped in quotes, and other
//     strings are double-encoded.
//   - cycleDepth counts the maps, slices, and pointers entered so far.
//     Once it exceeds startDetectingCyclesAfter, the values entered
//     are recorded in cycleSeen.  Re-entering one while it is still
//     being encoded yields an *EncodeValueError rather than unbounded
//     recursion.
//
// Map members are emitted sorted by their unescaped key text, as
// encoding/json does.
func encode(w *ReEncoder, val reflect.Value, escaper BackslashEscaper, quote bool, cycleDepth uint, cycleSeen map[any]struct{}) error {
	if !val.IsValid() {
		return discardInt(w.WriteString("null"))
	}
	switch {

	case val.Kind() != reflect.Pointer && val.CanAddr() && reflect.PointerTo(val.Type()).Implements(encodableType):
		val = val.Addr()
		fallthrough
	case val.Type().Implements(encodableType):
		if val.Kind() == reflect.Pointer && val.IsNil() {
			return discardInt(w.WriteString("null"))
		}
		obj, ok := val.Interface().(Encodable)
		if !ok {
			return discardInt(w.WriteString("null"))
		}
		w.pushWriteBarrier()
		if err := obj.EncodeJSON(w); err != nil {
			return &EncodeMethodError{
				Type:       val.Type(),
				SourceFunc: "EncodeJSON",
				Err:        err,
			}
		}
		if err := w.Close(); err != nil {
			return &EncodeMethodError{
				Type:       val.Type(),
				SourceFunc: "EncodeJSON",
				Err:        err,
			}
		}
		w.popWriteBarrier()

	case val.Kind() != reflect.Pointer && val.CanAddr() && reflect.PointerTo(val.Type()).Implements(jsonMarshalerType):
		val = val.Addr()
		fallthrough
	case val.Type().Implements(jsonMarshalerType):
		if val.Kind() == reflect.Pointer && val.IsNil() {
			return discardInt(w.WriteString("null"))
		}
		obj, ok := val.Interface().(json.Marshaler)
		if !ok {
			return discardInt(w.WriteString("null"))
		}
		dat, err := obj.MarshalJSON()
		if err != nil {
			return &EncodeMethodError{
				Type:       val.Type(),
				SourceFunc: "MarshalJSON",
				Err:        err,
			}
		}
		w.pushWriteBarrier()
		if _, err := w.Write(dat); err != nil {
			return &EncodeMethodError{
				Type:       val.Type(),
				SourceFunc: "MarshalJSON",
				Err:        err,
			}
		}
		if err := w.Close(); err != nil {
			return &EncodeMethodError{
				Type:       val.Type(),
				SourceFunc: "MarshalJSON",
				Err:        err,
			}
		}
		w.popWriteBarrier()

	case val.Kind() != reflect.Pointer && val.CanAddr() && reflect.PointerTo(val.Type()).Implements(textMarshalerType):
		val = val.Addr()
		fallthrough
	case val.Type().Implements(textMarshalerType):
		if val.Kind() == reflect.Pointer && val.IsNil() {
			return discardInt(w.WriteString("null"))
		}
		obj, ok := val.Interface().(encoding.TextMarshaler)
		if !ok {
			return discardInt(w.WriteString("null"))
		}
		text, err := obj.MarshalText()
		if err != nil {
			return &EncodeMethodError{
				Type:       val.Type(),
				SourceFunc: "MarshalText",
				Err:        err,
			}
		}
		if err := encodeStringFromBytes(w, escaper, text); err != nil {
			return err
		}
	default:
		switch val.Kind() {
		case reflect.Bool:
			if quote {
				if err := w.WriteByte('"'); err != nil {
					return err
				}
			}
			if val.Bool() {
				if _, err := w.WriteString("true"); err != nil {
					return err
				}
			} else {
				if _, err := w.WriteString("false"); err != nil {
					return err
				}
			}
			if quote {
				if err := w.WriteByte('"'); err != nil {
					return err
				}
			}
		case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
			if quote {
				if err := w.WriteByte('"'); err != nil {
					return err
				}
			}
			if _, err := w.WriteString(strconv.FormatInt(val.Int(), 10)); err != nil {
				return err
			}
			if quote {
				if err := w.WriteByte('"'); err != nil {
					return err
				}
			}
		case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
			if quote {
				if err := w.WriteByte('"'); err != nil {
					return err
				}
			}
			if _, err := w.WriteString(strconv.FormatUint(val.Uint(), 10)); err != nil {
				return err
			}
			if quote {
				if err := w.WriteByte('"'); err != nil {
					return err
				}
			}
		case reflect.Float32:
			if quote {
				if err := w.WriteByte('"'); err != nil {
					return err
				}
			}
			if err := encodeFloat(w, 32, val); err != nil {
				return err
			}
			if quote {
				if err := w.WriteByte('"'); err != nil {
					return err
				}
			}
		case reflect.Float64:
			if quote {
				if err := w.WriteByte('"'); err != nil {
					return err
				}
			}
			if err := encodeFloat(w, 64, val); err != nil {
				return err
			}
			if quote {
				if err := w.WriteByte('"'); err != nil {
					return err
				}
			}
		case reflect.String:
			if val.Type() == numberType {
				numStr := val.String()
				if numStr == "" {
					numStr = "0"
				}
				if quote {
					if err := w.WriteByte('"'); err != nil {
						return err
					}
				}
				if _, err := w.WriteString(numStr); err != nil {
					return err
				}
				if quote {
					if err := w.WriteByte('"'); err != nil {
						return err
					}
				}
			} else {
				if quote {
					// ",string" on a string: encode it once, then
					// encode that JSON text as a string again.
					var buf bytes.Buffer
					if err := encodeStringFromString(&buf, escaper, val.String()); err != nil {
						return err
					}
					if err := encodeStringFromBytes(w, escaper, buf.Bytes()); err != nil {
						return err
					}
				} else {
					if err := encodeStringFromString(w, escaper, val.String()); err != nil {
						return err
					}
				}
			}
		case reflect.Interface:
			if val.IsNil() {
				if _, err := w.WriteString("null"); err != nil {
					return err
				}
			} else {
				if err := encode(w, val.Elem(), escaper, quote, cycleDepth, cycleSeen); err != nil {
					return err
				}
			}
		case reflect.Struct:
			if err := w.WriteByte('{'); err != nil {
				return err
			}
			empty := true
			for _, field := range jsonstruct.IndexStruct(val.Type()).ByPos {
				fVal, err := val.FieldByIndexErr(field.Path)
				if err != nil {
					// The path runs through a nil embedded pointer;
					// the field is skipped.
					continue
				}
				if field.OmitEmpty && isEmptyValue(fVal) {
					continue
				}
				if !empty {
					if err := w.WriteByte(','); err != nil {
						return err
					}
				}
				empty = false
				if err := encodeStringFromString(w, escaper, field.Name); err != nil {
					return err
				}
				if err := w.WriteByte(':'); err != nil {
					return err
				}
				if err := encode(w, fVal, escaper, field.Quote, cycleDepth, cycleSeen); err != nil {
					return err
				}
			}
			if err := w.WriteByte('}'); err != nil {
				return err
			}
		case reflect.Map:
			if val.IsNil() {
				return discardInt(w.WriteString("null"))
			}
			if val.Len() == 0 {
				return discardInt(w.WriteString("{}"))
			}
			if cycleDepth++; cycleDepth > startDetectingCyclesAfter {
				ptr := val.UnsafePointer()
				if _, seen := cycleSeen[ptr]; seen {
					return &EncodeValueError{
						Value: val,
						Str:   fmt.Sprintf("encountered a cycle via %s", val.Type()),
					}
				}
				cycleSeen[ptr] = struct{}{}
				defer delete(cycleSeen, ptr)
			}
			if err := w.WriteByte('{'); err != nil {
				return err
			}

			// K is the member name encoded as a JSON string (quotes
			// included); S is the unescaped text of that name, which
			// is what the members are ordered by.
			type kv struct {
				K string
				S string
				V reflect.Value
			}
			kvs := make([]kv, val.Len())
			iter := val.MapRange()
			for i := 0; iter.Next(); i++ {
				// TODO: Avoid buffering the map key
				var k strings.Builder
				if err := encode(NewReEncoder(&k, ReEncoderConfig{BackslashEscape: escaper}), iter.Key(), escaper, false, cycleDepth, cycleSeen); err != nil {
					return err
				}
				kStr := k.String()
				if kStr == "null" {
					kStr = `""`
				}
				if !strings.HasPrefix(kStr, `"`) {
					// Non-string encodings (e.g. integers) become the
					// contents of a JSON string.
					k.Reset()
					if err := encodeStringFromString(&k, escaper, kStr); err != nil {
						return err
					}
					kStr = k.String()
				}
				// encoding/json orders members by the key text from
				// *before* quoting and escaping.  Comparing the
				// escaped forms instead misorders some keys; e.g.
				// `"a b"` would sort before `"a"` because ' ' < '"'.
				sortKey := kStr
				if n := len(kStr); n >= 2 && kStr[0] == '"' && kStr[n-1] == '"' && !strings.ContainsAny(kStr[1:n-1], `\"`) {
					// No escapes: the text is exactly what lies
					// between the quotes.
					sortKey = kStr[1 : n-1]
				} else {
					var unescaped string
					if json.Unmarshal([]byte(kStr), &unescaped) == nil {
						sortKey = unescaped
					}
				}
				kvs[i].K = kStr
				kvs[i].S = sortKey
				kvs[i].V = iter.Value()
			}
			sort.Slice(kvs, func(i, j int) bool {
				if kvs[i].S != kvs[j].S {
					return kvs[i].S < kvs[j].S
				}
				// Distinct keys can produce the same text (e.g. via
				// MarshalText); break ties on the encoded form.
				return kvs[i].K < kvs[j].K
			})

			for i, kv := range kvs {
				if i > 0 {
					if err := w.WriteByte(','); err != nil {
						return err
					}
				}
				if _, err := w.WriteString(kv.K); err != nil {
					return err
				}
				if err := w.WriteByte(':'); err != nil {
					return err
				}
				if err := encode(w, kv.V, escaper, false, cycleDepth, cycleSeen); err != nil {
					return err
				}
			}
			if err := w.WriteByte('}'); err != nil {
				return err
			}
		case reflect.Slice:
			switch {
			case val.IsNil():
				if _, err := w.WriteString("null"); err != nil {
					return err
				}
			case val.Type().Elem().Kind() == reflect.Uint8 && !(false ||
				val.Type().Elem().Implements(encodableType) ||
				reflect.PointerTo(val.Type().Elem()).Implements(encodableType) ||
				val.Type().Elem().Implements(jsonMarshalerType) ||
				reflect.PointerTo(val.Type().Elem()).Implements(jsonMarshalerType) ||
				val.Type().Elem().Implements(textMarshalerType) ||
				reflect.PointerTo(val.Type().Elem()).Implements(textMarshalerType)):
				// Byte slices whose elements have no custom encoding
				// are written as a base64 string.
				if err := w.WriteByte('"'); err != nil {
					return err
				}
				enc := base64.NewEncoder(base64.StdEncoding, w)
				if val.CanConvert(byteSliceType) {
					if _, err := enc.Write(val.Convert(byteSliceType).Interface().([]byte)); err != nil {
						return err
					}
				} else {
					// The slice can't be converted to []byte as a whole
					// (e.g. []T where T is a named uint8 type), so feed
					// the encoder one element at a time.
					for i, n := 0, val.Len(); i < n; i++ {
						var buf [1]byte
						buf[0] = byte(val.Index(i).Uint())
						if _, err := enc.Write(buf[:]); err != nil {
							return err
						}
					}
				}
				if err := enc.Close(); err != nil {
					return err
				}
				if err := w.WriteByte('"'); err != nil {
					return err
				}
			default:
				if cycleDepth++; cycleDepth > startDetectingCyclesAfter {
					// For slices, val.UnsafePointer() doesn't return a pointer to the slice header
					// or anything like that, it returns a pointer *to the first element in the
					// slice*.  That means that the pointer isn't enough to uniquely identify the
					// slice!  So we pair the pointer with the length of the slice, which is
					// sufficient.
					ptr := struct {
						ptr unsafe.Pointer
						len int
					}{val.UnsafePointer(), val.Len()}
					if _, seen := cycleSeen[ptr]; seen {
						return &EncodeValueError{
							Value: val,
							Str:   fmt.Sprintf("encountered a cycle via %s", val.Type()),
						}
					}
					cycleSeen[ptr] = struct{}{}
					defer delete(cycleSeen, ptr)
				}
				if err := encodeArray(w, val, escaper, cycleDepth, cycleSeen); err != nil {
					return err
				}
			}
		case reflect.Array:
			if err := encodeArray(w, val, escaper, cycleDepth, cycleSeen); err != nil {
				return err
			}
		case reflect.Pointer:
			if val.IsNil() {
				if _, err := w.WriteString("null"); err != nil {
					return err
				}
			} else {
				if cycleDepth++; cycleDepth > startDetectingCyclesAfter {
					ptr := val.UnsafePointer()
					if _, seen := cycleSeen[ptr]; seen {
						return &EncodeValueError{
							Value: val,
							Str:   fmt.Sprintf("encountered a cycle via %s", val.Type()),
						}
					}
					cycleSeen[ptr] = struct{}{}
					defer delete(cycleSeen, ptr)
				}
				if err := encode(w, val.Elem(), escaper, quote, cycleDepth, cycleSeen); err != nil {
					return err
				}
			}
		default:
			return &EncodeTypeError{
				Type: val.Type(),
			}
		}
	}
	return nil
}

// encodeArray writes val (a slice or an array, as at both call sites)
// to w as a JSON array, encoding each element with encode.  The
// ",string" quoting option never applies to elements, and the
// cycle-detection state is passed through to them unchanged.
func encodeArray(w *ReEncoder, val reflect.Value, escaper BackslashEscaper, cycleDepth uint, cycleSeen map[any]struct{}) error {
	if err := w.WriteByte('['); err != nil {
		return err
	}
	for i, n := 0, val.Len(); i < n; i++ {
		if i != 0 {
			if err := w.WriteByte(','); err != nil {
				return err
			}
		}
		elem := val.Index(i)
		if err := encode(w, elem, escaper, false, cycleDepth, cycleSeen); err != nil {
			return err
		}
	}
	if err := w.WriteByte(']'); err != nil {
		return err
	}
	return nil
}