# lib9p/protogen/c_validate.py - Generate C validation functions
#
# Copyright (C) 2024-2025  Luke T. Shumaker <lukeshu@lukeshu.com>
# SPDX-License-Identifier: AGPL-3.0-or-later

import typing

import idl

from . import c9util, cutil, idlutil

# This strives to be "general-purpose" in that it just acts on the
# *.9p inputs; but (unfortunately?) there are a few special-cases in
# this script, marked with "SPECIAL".


# pylint: disable=unused-variable
__all__ = ["gen_c_validate"]


def should_save_offset(parent: idl.Struct, child: idl.StructMember) -> bool:
    """Report whether the net offset of `child` must be saved.

    The offset is needed when `child` itself carries a `val` or `max`
    expression or is a bitfield, or when any member of `parent` has a
    `val`/`max` expression containing the symbol `&<child.membname>`.
    """
    if child.val:
        return True
    if child.max:
        return True
    if isinstance(child.typ, idl.Bitfield):
        return True

    wanted = f"&{child.membname}"
    for sibling in parent.members:
        for expr in (sibling.val, sibling.max):
            if not expr:
                continue
            if any(
                isinstance(tok, idl.ExprSym) and tok.symname == wanted
                for tok in expr.tokens
            ):
                return True
    return False


def should_save_end_offset(struct: idl.Struct) -> bool:
    """Report whether any member's `val`/`max` expression uses `end`.

    Returns True as soon as an `idl.ExprSym` token whose symname is
    exactly "end" is found in a member's `val` or `max` expression.
    """
    return any(
        isinstance(tok, idl.ExprSym) and tok.symname == "end"
        for memb in struct.members
        for expr in (memb.val, memb.max)
        if expr
        for tok in expr.tokens
    )


def gen_c_validate(versions: set[str], typs: list[idl.UserType]) -> str:
    ret = """
/* validate_* *****************************************************************/

"""
    ret += cutil.macro(
        "#define VALIDATE_NET_BYTES(n)\n"
        "\tif (__builtin_add_overflow(net_offset, n, &net_offset))\n"
        "\t\t/* If needed-net-size overflowed uint32_t, then\n"
        "\t\t * there's no way that actual-net-size will live up to\n"
        "\t\t * that.  */\n"
        '\t\treturn lib9p_error(ctx, LINUX_EBADMSG, "message is too short for content");\n'
        "\tif (net_offset > net_size)\n"
        '\t\treturn lib9p_errorf(ctx, LINUX_EBADMSG, "message is too short for content (%"PRIu32" > %"PRIu32") @ %d", net_offset, net_size, __LINE__);\n'
    )
    ret += cutil.macro(
        "#define VALIDATE_NET_UTF8(n)\n"
        "\t{\n"
        "\t\tsize_t len = n;\n"
        "\t\tVALIDATE_NET_BYTES(len);\n"
        "\t\tif (!is_valid_utf8_without_nul(&net_bytes[net_offset-len], len))\n"
        '\t\t\treturn lib9p_error(ctx, LINUX_EBADMSG, "message contains invalid UTF-8");\n'
        "\t}\n"
    )
    ret += cutil.macro(
        "#define RESERVE_HOST_BYTES(n)\n"
        "\tif (__builtin_add_overflow(host_size, n, &host_size))\n"
        "\t\t/* If needed-host-size overflowed ssize_t, then there's\n"
        "\t\t * no way that actual-net-size will live up to\n"
        "\t\t * that.  */\n"
        '\t\treturn lib9p_error(ctx, LINUX_EBADMSG, "message is too short for content");\n'
    )

    ret += "#define GET_U8LE(off)  (net_bytes[off])\n"
    ret += "#define GET_U16LE(off) uint16le_decode(&net_bytes[off])\n"
    ret += "#define GET_U32LE(off) uint32le_decode(&net_bytes[off])\n"
    ret += "#define GET_U64LE(off) uint64le_decode(&net_bytes[off])\n"

    ret += "#define LAST_U8LE()  GET_U8LE(net_offset-1)\n"
    ret += "#define LAST_U16LE() GET_U16LE(net_offset-2)\n"
    ret += "#define LAST_U32LE() GET_U32LE(net_offset-4)\n"
    ret += "#define LAST_U64LE() GET_U64LE(net_offset-8)\n"

    class IndentLevel(typing.NamedTuple):
        ifdef: bool  # whether this is both `{` and `#if`, or just `{`

    indent_stack: list[IndentLevel]

    def ifdef_lvl() -> int:
        return sum(1 if lvl.ifdef else 0 for lvl in indent_stack)

    def indent_lvl() -> int:
        return len(indent_stack)

    incr_buf: int

    def incr_flush() -> None:
        nonlocal ret
        nonlocal incr_buf
        if incr_buf:
            ret += f"{'\t'*indent_lvl()}VALIDATE_NET_BYTES({incr_buf});\n"
            incr_buf = 0

    def gen_validate_size(path: idlutil.Path) -> None:
        nonlocal ret
        nonlocal incr_buf
        nonlocal indent_stack

        assert path.elems
        child = path.elems[-1]
        parent = path.elems[-2].typ if len(path.elems) > 1 else path.root
        assert isinstance(parent, idl.Struct)

        if child.in_versions < parent.in_versions:
            if line := cutil.ifdef_push(
                ifdef_lvl() + 1, c9util.ver_ifdef(child.in_versions)
            ):
                incr_flush()
                ret += line
                ret += (
                    f"{'\t'*indent_lvl()}if ({c9util.ver_cond(child.in_versions)}) {{\n"
                )
                indent_stack.append(IndentLevel(ifdef=True))
        if should_save_offset(parent, child):
            ret += f"{'\t'*indent_lvl()}uint32_t offsetof{''.join('_'+m.membname for m in path.elems)} = net_offset + {incr_buf};\n"
        if child.cnt:
            assert child.cnt.typ.static_size
            cnt_path = path.parent().add(child.cnt)
            incr_flush()
            if child.membname == "utf8":  # SPECIAL (string)
                # Yes, this is content-validation and "belongs" in
                # gen_validate_content(), not here.  But it's just
                # easier this way.
                ret += f"{'\t'*indent_lvl()}VALIDATE_NET_UTF8(LAST_U{child.cnt.typ.static_size*8}LE());\n"
                return
            if child.typ.static_size == 1:  # SPECIAL (zerocopy)
                ret += f"{'\t'*indent_lvl()}VALIDATE_NET_BYTES(LAST_U{child.cnt.typ.static_size*8}LE());\n"
                return
            loopdepth = sum(1 for elem in path.elems if elem.cnt)
            loopvar = chr(ord("i") + loopdepth - 1)
            ret += f"{'\t'*indent_lvl()}for ({c9util.typename(child.cnt.typ)} {loopvar} = 0, cnt = LAST_U{child.cnt.typ.static_size*8}LE(); {loopvar} < cnt; {loopvar}++) {{\n"
            indent_stack.append(IndentLevel(ifdef=False))
            ret += f"{'\t'*indent_lvl()}RESERVE_HOST_BYTES(sizeof({c9util.typename(child.typ)}));\n"
        if not isinstance(child.typ, idl.Struct):
            incr_buf += child.typ.static_size

    def gen_validate_content(path: idlutil.Path) -> None:
        nonlocal ret
        nonlocal incr_buf
        nonlocal indent_stack

        assert path.elems
        child = path.elems[-1]
        parent = path.elems[-2].typ if len(path.elems) > 1 else path.root
        assert isinstance(parent, idl.Struct)

        def lookup_sym(sym: str) -> str:
            if sym.startswith("&"):
                sym = sym[1:]
            return f"offsetof{''.join('_'+m.membname for m in path.elems[:-1])}_{sym}"

        if child.val:
            incr_flush()
            assert child.typ.static_size
            nbits = child.typ.static_size * 8
            nbits = child.typ.static_size * 8
            if nbits < 32 and any(
                isinstance(tok, idl.ExprSym)
                and (tok.symname == "end" or tok.symname.startswith("&"))
                for tok in child.val.tokens
            ):
                nbits = 32
            act = f"(uint{nbits}_t)GET_U{nbits}LE({lookup_sym(f'&{child.membname}')})"
            exp = f"(uint{nbits}_t)({c9util.idl_expr(child.val, lookup_sym)})"
            ret += f"{'\t'*indent_lvl()}if ({act} != {exp})\n"
            ret += f'{"\t"*(indent_lvl()+1)}return lib9p_errorf(ctx, LINUX_EBADMSG, "{path} value is wrong: actual: %"PRIu{nbits}" != correct:%"PRIu{nbits},\n'
            ret += f"{'\t'*(indent_lvl()+2)}{act}, {exp});\n"
        if child.max:
            incr_flush()
            assert child.typ.static_size
            nbits = child.typ.static_size * 8
            if nbits < 32 and any(
                isinstance(tok, idl.ExprSym)
                and (tok.symname == "end" or tok.symname.startswith("&"))
                for tok in child.max.tokens
            ):
                nbits = 32
            act = f"(uint{nbits}_t)GET_U{nbits}LE({lookup_sym(f'&{child.membname}')})"
            exp = f"(uint{nbits}_t)({c9util.idl_expr(child.max, lookup_sym)})"
            ret += f"{'\t'*indent_lvl()}if ({act} > {exp})\n"
            ret += f'{"\t"*(indent_lvl()+1)}return lib9p_errorf(ctx, LINUX_EBADMSG, "{path} value is too large: %"PRIu{nbits}" > %"PRIu{nbits},\n'
            ret += f"{'\t'*(indent_lvl()+2)}{act}, {exp});\n"
        if isinstance(child.typ, idl.Bitfield):
            incr_flush()
            nbytes = child.typ.static_size
            nbits = nbytes * 8
            act = f"GET_U{nbits}LE({lookup_sym(f'&{child.membname}')})"
            ret += f"{'\t'*indent_lvl()}if ({act} & ~{child.typ.typname}_masks[ctx->version])\n"
            ret += f'{"\t"*(indent_lvl()+1)}return lib9p_errorf(ctx, LINUX_EBADMSG, "unknown bits in {child.typ.typname} bitfield: %#0{nbytes*2}"PRIx{nbits},\n'
            ret += f"{'\t'*(indent_lvl()+2)}{act} & ~{child.typ.typname}_masks[ctx->version]);\n"

    def handle(
        path: idlutil.Path,
    ) -> tuple[idlutil.WalkCmd, typing.Callable[[], None]]:
        nonlocal ret
        nonlocal incr_buf
        nonlocal indent_stack
        indent_stack_len = len(indent_stack)
        pop_struct = path.elems[-1].typ if path.elems else path.root
        pop_path = path
        pop_indent_stack_len: int

        def pop() -> None:
            nonlocal ret
            nonlocal indent_stack
            nonlocal indent_stack_len
            nonlocal pop_struct
            nonlocal pop_path
            nonlocal pop_indent_stack_len
            if isinstance(pop_struct, idl.Struct):
                while len(indent_stack) > pop_indent_stack_len:
                    incr_flush()
                    ret += f"{'\t'*(indent_lvl()-1)}}}\n"
                    if indent_stack.pop().ifdef:
                        ret += cutil.ifdef_pop(ifdef_lvl())
                parent = pop_struct
                path = pop_path
                if should_save_end_offset(parent):
                    ret += f"{'\t'*indent_lvl()}uint32_t offsetof{''.join('_'+m.membname for m in path.elems)}_end = net_offset + {incr_buf};\n"
                for child in parent.members:
                    gen_validate_content(pop_path.add(child))
            while len(indent_stack) > indent_stack_len:
                if len(indent_stack) == indent_stack_len + 1 and indent_stack[-1].ifdef:
                    break
                incr_flush()
                ret += f"{'\t'*(indent_lvl()-1)}}}\n"
                if indent_stack.pop().ifdef:
                    ret += cutil.ifdef_pop(ifdef_lvl())

        if path.elems:
            gen_validate_size(path)

        pop_indent_stack_len = len(indent_stack)

        return idlutil.WalkCmd.KEEP_GOING, pop

    for typ in typs:
        if not (
            isinstance(typ, idl.Message) or typ.typname == "stat"
        ):  # SPECIAL (include stat)
            continue
        assert isinstance(typ, idl.Struct)
        ret += "\n"
        ret += cutil.ifdef_push(1, c9util.ver_ifdef(typ.in_versions))
        if typ.typname == "stat":  # SPECIAL (stat)
            ret += f"static ssize_t validate_{typ.typname}(struct lib9p_ctx *ctx, uint32_t net_size, uint8_t *net_bytes, uint32_t *ret_net_size) {{\n"
        else:
            ret += f"static ssize_t validate_{typ.typname}(struct lib9p_ctx *ctx, uint32_t net_size, uint8_t *net_bytes) {{\n"

        ret += "\tuint32_t net_offset = 0;\n"
        ret += f"\tssize_t host_size = sizeof({c9util.typename(typ)});\n"

        incr_buf = 0
        indent_stack = [IndentLevel(ifdef=True)]
        idlutil.walk(typ, handle)
        while len(indent_stack) > 1:
            incr_flush()
            ret += f"{'\t'*(indent_lvl()-1)}}}\n"
            if indent_stack.pop().ifdef:
                ret += cutil.ifdef_pop(ifdef_lvl())

        incr_flush()
        if typ.typname == "stat":  # SPECIAL (stat)
            ret += "\tif (ret_net_size)\n"
            ret += "\t\t*ret_net_size = net_offset;\n"
        ret += "\treturn (ssize_t)host_size;\n"
        ret += "}\n"
    ret += cutil.ifdef_pop(0)
    return ret