// Copyright (C) 2022 Luke Shumaker // // SPDX-License-Identifier: GPL-2.0-or-later package lowmemjson import ( "bufio" "bytes" "encoding" "encoding/json" "fmt" "io" "reflect" "strconv" "strings" ) type Decodable interface { DecodeJSON(io.RuneScanner) error } type runeBuffer interface { io.Writer WriteRune(rune) (int, error) Reset() } type Decoder struct { r io.RuneScanner // config disallowUnknownFields bool useNumber bool // state err error curPos int64 nxtPos int64 stack []any } var forceBufio bool func NewDecoder(r io.Reader) *Decoder { rs, ok := r.(io.RuneScanner) if forceBufio || !ok { rs = bufio.NewReader(r) } return &Decoder{ r: rs, } } func (dec *Decoder) DisallowUnknownFields() { dec.disallowUnknownFields = true } func (dec *Decoder) UseNumber() { dec.useNumber = true } func (dec *Decoder) InputOffset() int64 { return dec.curPos } func (dec *Decoder) More() bool { dec.decodeWS() _, ok := dec.peekRuneOrEOF() return ok } func (dec *Decoder) stackStr() string { var buf strings.Builder buf.WriteString("v") for _, item := range dec.stack { fmt.Fprintf(&buf, "[%#v]", item) } return buf.String() } func (dec *Decoder) stackPush(idx any) { dec.stack = append(dec.stack, idx) } func (dec *Decoder) stackPop() { dec.stack = dec.stack[:len(dec.stack)-1] } type decodeError struct { Err error } func (dec *Decoder) panicIO(err error) { panic(decodeError{fmt.Errorf("json: I/O error at input byte %v: %s: %w", dec.nxtPos, dec.stackStr(), err)}) } func (dec *Decoder) panicSyntax(err error) { panic(decodeError{fmt.Errorf("json: syntax error at input byte %v: %s: %w", dec.curPos, dec.stackStr(), err)}) } func (dec *Decoder) panicType(typ reflect.Type, err error) { panic(decodeError{fmt.Errorf("json: type mismatch error at input byte %v: %s: type %v: %w", dec.curPos, dec.stackStr(), typ, err)}) } func Decode(r io.Reader, ptr any) error { return NewDecoder(r).Decode(ptr) } func (dec *Decoder) Decode(ptr any) (err error) { ptrVal := reflect.ValueOf(ptr) if ptrVal.Kind() != reflect.Pointer || ptrVal.IsNil() || !ptrVal.Elem().CanSet() { return &json.InvalidUnmarshalError{ // don't use ptrVal.Type() because ptrVal might be invalid if ptr==nil Type: reflect.TypeOf(ptr), } } if dec.err != nil { return dec.err } defer func() { if r := recover(); r != nil { if de, ok := r.(decodeError); ok { dec.err = de.Err err = dec.err } else { panic(r) } } }() dec.decodeWS() dec.decode(ptrVal.Elem(), false) return nil } func (dec *Decoder) readRune() rune { c, size, err := dec.r.ReadRune() if err != nil { if err == io.EOF { dec.panicSyntax(io.ErrUnexpectedEOF) } dec.panicIO(err) } dec.curPos = dec.nxtPos dec.nxtPos = dec.curPos + int64(size) return c } func (dec *Decoder) readRuneOrEOF() (c rune, ok bool) { c, size, err := dec.r.ReadRune() if err != nil { if err == io.EOF { return 0, false } dec.panicIO(err) } dec.curPos = dec.nxtPos dec.nxtPos = dec.curPos + int64(size) return c, true } func (dec *Decoder) unreadRune() { if err := dec.r.UnreadRune(); err != nil { // .UnreadRune() must succeed if the previous call was // .ReadRune(), which it always is for this code. panic(err) } dec.nxtPos = dec.curPos } func (dec *Decoder) peekRune() rune { c, _, err := dec.r.ReadRune() if err != nil { if err == io.EOF { dec.panicSyntax(io.ErrUnexpectedEOF) } dec.panicIO(err) } if err := dec.r.UnreadRune(); err != nil { // .UnreadRune() must succeed if the previous call was // .ReadRune(), which it always is for this code. panic(err) } return c } func (dec *Decoder) peekRuneOrEOF() (rune, bool) { c, _, err := dec.r.ReadRune() if err != nil { if err == io.EOF { return 0, false } dec.panicIO(err) } if err := dec.r.UnreadRune(); err != nil { // .UnreadRune() must succeed if the previous call was // .ReadRune(), which it always is for this code. panic(err) } return c, true } func (dec *Decoder) expectRune(exp rune) { act := dec.readRune() if act != exp { dec.panicSyntax(fmt.Errorf("expected %q but got %q", exp, act)) } } var ( rawMessagePtrType = reflect.TypeOf((*json.RawMessage)(nil)) decodableType = reflect.TypeOf((*Decodable)(nil)).Elem() jsonUnmarshalerType = reflect.TypeOf((*json.Unmarshaler)(nil)).Elem() textUnmarshalerType = reflect.TypeOf((*encoding.TextUnmarshaler)(nil)).Elem() ) var kind2bits = map[reflect.Kind]int{ reflect.Int: int(32 << (^uint(0) >> 63)), reflect.Int8: 8, reflect.Int16: 16, reflect.Int32: 32, reflect.Int64: 64, reflect.Uint: int(32 << (^uint(0) >> 63)), reflect.Uint8: 8, reflect.Uint16: 16, reflect.Uint32: 32, reflect.Uint64: 64, reflect.Uintptr: int(32 << (^uintptr(0) >> 63)), reflect.Float32: 32, reflect.Float64: 64, } func (dec *Decoder) decode(val reflect.Value, nullOK bool) { typ := val.Type() switch { case val.CanAddr() && reflect.PointerTo(typ) == rawMessagePtrType: var buf bytes.Buffer dec.scan(&buf) if err := val.Addr().Interface().(*json.RawMessage).UnmarshalJSON(buf.Bytes()); err != nil { dec.panicSyntax(err) } case val.CanAddr() && reflect.PointerTo(typ).Implements(decodableType): obj := val.Addr().Interface().(Decodable) if err := obj.DecodeJSON(dec.r); err != nil { dec.panicSyntax(err) } case val.CanAddr() && reflect.PointerTo(typ).Implements(jsonUnmarshalerType): var buf bytes.Buffer dec.scan(&buf) obj := val.Addr().Interface().(json.Unmarshaler) if err := obj.UnmarshalJSON(buf.Bytes()); err != nil { dec.panicSyntax(err) } case val.CanAddr() && reflect.PointerTo(typ).Implements(textUnmarshalerType): if nullOK && dec.peekRune() == 'n' { dec.decodeNull() return } var buf bytes.Buffer dec.decodeString(&buf) obj := val.Addr().Interface().(encoding.TextUnmarshaler) if err := obj.UnmarshalText(buf.Bytes()); err != nil { dec.panicSyntax(err) } default: kind := typ.Kind() switch kind { case reflect.Bool: if nullOK && dec.peekRune() == 'n' { dec.decodeNull() return } val.SetBool(dec.decodeBool()) case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: if nullOK && dec.peekRune() == 'n' { dec.decodeNull() return } var buf strings.Builder dec.scanNumber(&buf) n, err := strconv.ParseInt(buf.String(), 10, kind2bits[kind]) if err != nil { dec.panicSyntax(err) } val.SetInt(n) case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: if nullOK && dec.peekRune() == 'n' { dec.decodeNull() return } var buf strings.Builder dec.scanNumber(&buf) n, err := strconv.ParseUint(buf.String(), 10, kind2bits[kind]) if err != nil { dec.panicSyntax(err) } val.SetUint(n) case reflect.Float32, reflect.Float64: if nullOK && dec.peekRune() == 'n' { dec.decodeNull() return } var buf strings.Builder dec.scanNumber(&buf) n, err := strconv.ParseFloat(buf.String(), kind2bits[kind]) if err != nil { dec.panicSyntax(err) } val.SetFloat(n) case reflect.String: if nullOK && dec.peekRune() == 'n' { dec.decodeNull() return } var buf strings.Builder if typ == numberType { dec.scanNumber(&buf) val.SetString(buf.String()) } else { dec.decodeString(&buf) val.SetString(buf.String()) } case reflect.Interface: if typ.NumMethod() > 0 { dec.panicType(typ, fmt.Errorf("cannot decode in to non-empty interface")) } switch dec.peekRune() { case 'n': if !val.IsNil() && val.Elem().Kind() == reflect.Pointer && val.Elem().Elem().Kind() == reflect.Pointer { // XXX: I can't justify this case, other than "it's what encoding/json does, but // I don't understand their rationale". dec.decode(val.Elem(), false) } else { dec.decodeNull() val.Set(reflect.Zero(typ)) } default: if !val.IsNil() && val.Elem().Kind() == reflect.Pointer { dec.decode(val.Elem(), false) } else { val.Set(reflect.ValueOf(dec.decodeAny())) } } case reflect.Struct: if nullOK && dec.peekRune() == 'n' { dec.decodeNull() return } index := indexStruct(typ) var nameBuf strings.Builder dec.decodeObject(&nameBuf, func() { name := nameBuf.String() dec.stackPush(name) defer dec.stackPop() idx, ok := index.byName[name] if !ok { if dec.disallowUnknownFields { dec.panicType(typ, fmt.Errorf("unknown field %q", name)) } dec.scan(io.Discard) return } field := index.byPos[idx] fVal := val for _, idx := range field.Path { if fVal.Kind() == reflect.Pointer { if fVal.IsNil() { if !fVal.CanSet() { // https://golang.org/issue/21357 dec.panicType(fVal.Type().Elem(), fmt.Errorf("cannot set embedded pointer to unexported type")) } fVal.Set(reflect.New(fVal.Type().Elem())) } fVal = fVal.Elem() } fVal = fVal.Field(idx) } if field.Quote { switch dec.peekRune() { case 'n': dec.decodeNull() switch fVal.Kind() { // XXX: I can't justify this list, other than "it's what encoding/json // does, but I don't understand their rationale". case reflect.Interface, reflect.Pointer, reflect.Map, reflect.Slice: fVal.Set(reflect.Zero(fVal.Type())) } case '"': // TODO: Figure out how to do this without buffering. var buf bytes.Buffer subD := *dec // capture the .curPos *before* calling .decodeString dec.decodeString(&buf) subD.r = &buf subD.decode(fVal, false) default: dec.panicSyntax(fmt.Errorf(",string field: expected %q or %q but got %q", 'n', '"', dec.peekRune())) } } else { dec.decode(fVal, true) } }) case reflect.Map: switch dec.peekRune() { case 'n': dec.decodeNull() val.Set(reflect.Zero(typ)) case '{': if val.IsNil() { val.Set(reflect.MakeMap(typ)) } var nameBuf bytes.Buffer dec.decodeObject(&nameBuf, func() { nameValTyp := typ.Key() nameValPtr := reflect.New(nameValTyp) switch { case reflect.PointerTo(nameValTyp).Implements(textUnmarshalerType): obj := nameValPtr.Interface().(encoding.TextUnmarshaler) if err := obj.UnmarshalText(nameBuf.Bytes()); err != nil { dec.panicSyntax(err) } default: switch nameValTyp.Kind() { case reflect.String: nameValPtr.Elem().SetString(nameBuf.String()) case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: n, err := strconv.ParseInt(nameBuf.String(), 10, kind2bits[nameValTyp.Kind()]) if err != nil { dec.panicSyntax(err) } nameValPtr.Elem().SetInt(n) case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: n, err := strconv.ParseUint(nameBuf.String(), 10, kind2bits[nameValTyp.Kind()]) if err != nil { dec.panicSyntax(err) } nameValPtr.Elem().SetUint(n) default: dec.panicType(typ, fmt.Errorf("invalid map key type: %v", nameValTyp)) } } dec.stackPush(nameValPtr.Elem()) defer dec.stackPop() fValPtr := reflect.New(typ.Elem()) dec.decode(fValPtr.Elem(), false) val.SetMapIndex(nameValPtr.Elem(), fValPtr.Elem()) }) default: dec.panicSyntax(fmt.Errorf("map: expected %q or %q bug got %q", 'n', '{', dec.peekRune())) } case reflect.Slice: switch { case typ.Elem().Kind() == reflect.Uint8: switch dec.peekRune() { case 'n': dec.decodeNull() val.Set(reflect.Zero(typ)) case '"': var buf bytes.Buffer dec.decodeString(newBase64Decoder(&buf)) if typ.Elem() == byteType { val.Set(reflect.ValueOf(buf.Bytes())) } else { bs := buf.Bytes() // TODO: Surely there's a better way. val.Set(reflect.MakeSlice(typ, len(bs), len(bs))) for i := 0; i < len(bs); i++ { val.Index(i).Set(reflect.ValueOf(bs[i]).Convert(typ.Elem())) } } default: dec.panicSyntax(fmt.Errorf("byte slice: expected %q or %q but got %q", 'n', '"', dec.peekRune())) } default: switch dec.peekRune() { case 'n': dec.decodeNull() val.Set(reflect.Zero(typ)) case '[': if val.IsNil() { val.Set(reflect.MakeSlice(typ, 0, 0)) } if val.Len() > 0 { val.Set(val.Slice(0, 0)) } i := 0 dec.decodeArray(func() { dec.stackPush(i) defer dec.stackPop() mValPtr := reflect.New(typ.Elem()) dec.decode(mValPtr.Elem(), false) val.Set(reflect.Append(val, mValPtr.Elem())) i++ }) default: dec.panicSyntax(fmt.Errorf("slice: expected %q or %q but got %q", 'n', '[', dec.peekRune())) } } case reflect.Array: if nullOK && dec.peekRune() == 'n' { dec.decodeNull() return } i := 0 n := val.Len() dec.decodeArray(func() { dec.stackPush(i) defer dec.stackPop() if i < n { mValPtr := reflect.New(typ.Elem()) dec.decode(mValPtr.Elem(), false) val.Index(i).Set(mValPtr.Elem()) } else { dec.scan(io.Discard) } i++ }) for ; i < n; i++ { val.Index(i).Set(reflect.Zero(typ.Elem())) } case reflect.Pointer: switch dec.peekRune() { case 'n': dec.decodeNull() /* for typ.Elem().Kind() == reflect.Pointer { if val.IsNil() || !val.Elem().CanSet() { val.Set(reflect.New(typ.Elem())) } val = val.Elem() typ = val.Type() } */ val.Set(reflect.Zero(typ)) default: if val.IsNil() { val.Set(reflect.New(typ.Elem())) } dec.decode(val.Elem(), false) } default: dec.panicType(typ, fmt.Errorf("unsupported type (kind=%v)", typ.Kind())) } } } func (dec *Decoder) decodeWS() { for { c, ok := dec.readRuneOrEOF() if !ok { return } switch c { // NB: The JSON definition of whitespace is more // narrow than unicode.IsSpace case 0x0020, 0x000A, 0x000D, 0x0009: // do nothing default: dec.unreadRune() return } } } func (dec *Decoder) scan(out io.Writer) { scanner := &ReEncoder{ Out: out, Compact: true, } if _, err := scanner.WriteRune(dec.readRune()); err != nil { dec.panicSyntax(err) } scanner.bailAfterCurrent = true var err error var eof bool for err == nil { c, ok := dec.readRuneOrEOF() if ok { _, err = scanner.WriteRune(c) } else { eof = true err = scanner.Flush() break } } if err != nil { if err == errBailedAfterCurrent { if !eof { dec.unreadRune() } } else { dec.panicSyntax(err) } } } func (dec *Decoder) scanNumber(out io.Writer) { c := dec.peekRune() switch c { case '-', '0', '1', '2', '3', '4', '5', '6', '7', '8', '9': dec.scan(out) default: dec.panicSyntax(fmt.Errorf("number: expected %q or a digit, but got %q", '-', c)) } } func (dec *Decoder) decodeAny() any { c := dec.peekRune() switch c { case '{': ret := make(map[string]any) var nameBuf strings.Builder dec.decodeObject(&nameBuf, func() { name := nameBuf.String() dec.stackPush(name) defer dec.stackPop() ret[name] = dec.decodeAny() }) return ret case '[': ret := []any{} dec.decodeArray(func() { dec.stackPush(len(ret)) defer dec.stackPop() ret = append(ret, dec.decodeAny()) }) return ret case '"': var buf strings.Builder dec.decodeString(&buf) return buf.String() case '-', '0', '1', '2', '3', '4', '5', '6', '7', '8', '9': var buf strings.Builder dec.scanNumber(&buf) num := json.Number(buf.String()) if dec.useNumber { return num } f64, err := num.Float64() if err != nil { dec.panicSyntax(err) } return f64 case 't', 'f': return dec.decodeBool() case 'n': dec.decodeNull() return nil default: dec.panicSyntax(fmt.Errorf("any: unexpected character: %c", c)) panic("not reached") } } func (dec *Decoder) decodeObject(nameBuf runeBuffer, decodeKVal func()) { dec.expectRune('{') dec.decodeWS() c := dec.readRune() switch c { case '"': decodeMember: dec.unreadRune() nameBuf.Reset() dec.decodeString(nameBuf) dec.decodeWS() dec.expectRune(':') dec.decodeWS() decodeKVal() dec.decodeWS() c := dec.readRune() switch c { case ',': dec.decodeWS() dec.expectRune('"') goto decodeMember case '}': return default: dec.panicSyntax(fmt.Errorf("object: expected %q or %q but got %q", ',', '}', c)) } case '}': return default: dec.panicSyntax(fmt.Errorf("object: expected %q or %q but got %q", '"', '}', c)) } } func (dec *Decoder) decodeArray(decodeMember func()) { dec.expectRune('[') dec.decodeWS() c := dec.readRune() switch c { case ']': return default: dec.unreadRune() decodeNextMember: decodeMember() dec.decodeWS() c := dec.readRune() switch c { case ',': dec.decodeWS() goto decodeNextMember case ']': return default: dec.panicSyntax(fmt.Errorf("array: expected %c or %c but got %c", ',', ']', c)) } } } func (dec *Decoder) decodeHex() rune { c := dec.readRune() switch { case '0' <= c && c <= '9': return c - '0' case 'a' <= c && c <= 'f': return c - 'a' + 10 case 'A' <= c && c <= 'F': return c - 'A' + 10 default: dec.panicSyntax(fmt.Errorf("string: expected a hex digit but got %q", c)) panic("not reached") } } func (dec *Decoder) decodeString(out io.Writer) { dec.expectRune('"') for { c := dec.readRune() switch { case 0x0020 <= c && c <= 0x10FFFF && c != '"' && c != '\\': if _, err := writeRune(out, c); err != nil { dec.panicSyntax(err) } case c == '\\': c = dec.readRune() switch c { case '"': if _, err := writeRune(out, '"'); err != nil { dec.panicSyntax(err) } case '\\': if _, err := writeRune(out, '\\'); err != nil { dec.panicSyntax(err) } case '/': if _, err := writeRune(out, '/'); err != nil { dec.panicSyntax(err) } case 'b': if _, err := writeRune(out, '\b'); err != nil { dec.panicSyntax(err) } case 'f': if _, err := writeRune(out, '\f'); err != nil { dec.panicSyntax(err) } case 'n': if _, err := writeRune(out, '\n'); err != nil { dec.panicSyntax(err) } case 'r': if _, err := writeRune(out, '\r'); err != nil { dec.panicSyntax(err) } case 't': if _, err := writeRune(out, '\t'); err != nil { dec.panicSyntax(err) } case 'u': c = dec.decodeHex() c = (c << 4) | dec.decodeHex() c = (c << 4) | dec.decodeHex() c = (c << 4) | dec.decodeHex() if _, err := writeRune(out, c); err != nil { dec.panicSyntax(err) } } case c == '"': return default: dec.panicSyntax(fmt.Errorf("string: unexpected %c", c)) } } } func (dec *Decoder) decodeBool() bool { c := dec.readRune() switch c { case 't': dec.expectRune('r') dec.expectRune('u') dec.expectRune('e') return true case 'f': dec.expectRune('a') dec.expectRune('l') dec.expectRune('s') dec.expectRune('e') return false default: dec.panicSyntax(fmt.Errorf("bool: expected %q or %q but got %q", 't', 'f', c)) panic("not reached") } } func (dec *Decoder) decodeNull() { dec.expectRune('n') dec.expectRune('u') dec.expectRune('l') dec.expectRune('l') }