# lib9p/protogen/c_marshal.py - Generate C marshal functions
#
# Copyright (C) 2024-2025  Luke T. Shumaker <lukeshu@lukeshu.com>
# SPDX-License-Identifier: AGPL-3.0-or-later

import typing

import idl

from . import c9util, cutil, idlutil

# This strives to be "general-purpose" in that it just acts on the
# *.9p inputs; but (unfortunately?) there are a few special-cases in
# this script, marked with "SPECIAL".


# pylint: disable=unused-variable
# Export only the code-generator entry point from this module.
__all__ = ["gen_c_marshal"]

# get_offset_expr() ############################################################


class OffsetExpr:
    """A symbolic byte count (a size or an offset) that can be emitted as C.

    The represented quantity is the sum of:

      - `static`: a fixed number of bytes;
      - for each `(cnt, sub)` in `rep`: `sub` evaluated once per element,
        where the element count is the C expression `cnt.c_str(root)`;
      - for each `(vers, sub)` in `cond`: `sub`, but only under the C
        guards produced by `c9util.ver_ifdef(vers)` and
        `c9util.ver_cond(vers)`.
    """

    # Fixed number of bytes that are always counted.
    static: int
    # Version-guarded sub-expressions, keyed by their set of version names.
    cond: dict[frozenset[str], "OffsetExpr"]
    # Repeated sub-expressions: (path of the count member, per-element expr).
    rep: list[tuple[idlutil.Path, "OffsetExpr"]]

    def __init__(self) -> None:
        self.static = 0
        self.rep = []
        self.cond = {}

    def add(self, other: "OffsetExpr") -> None:
        """Add `other` into this expression, in place.

        Static parts are summed, repeated terms are concatenated, and
        conditional terms are merged per version set.  Sub-expressions of
        `other` are shared rather than copied, so a later `add` on a merged
        `cond` entry also mutates the object that came from `other`.
        """
        self.static += other.static
        self.rep += other.rep
        for k, v in other.cond.items():
            if k in self.cond:
                self.cond[k].add(v)
            else:
                self.cond[k] = v

    def gen_c(
        self,
        dsttyp: str,
        dstvar: str,
        root: str,
        indent_depth: int,
        loop_depth: int,
    ) -> str:
        """Render this expression as C statements that compute it into `dstvar`.

        Args:
            dsttyp: If non-empty, `dstvar` is declared as
                `<dsttyp> <dstvar> = <terms>;` (using `0` when there are no
                one-line terms).  If empty, the one-line terms (if any) are
                accumulated with `<dstvar> += <terms>;`.
            dstvar: Name of the C variable being computed.
            root: Prefix passed to `Path.c_str()` to reach the value's
                fields (e.g. "val->").
            indent_depth: Number of tabs to indent with.  The same number
                is also passed to cutil.ifdef_push()/ifdef_pop() as the
                #if nesting level.
            loop_depth: Nesting depth of generated loops, used to pick the
                loop variable name ('i', 'j', ...).

        Returns:
            The generated C source lines.
        """
        indent = "\t" * indent_depth
        oneline: list[str] = []
        multiline = ""
        if self.static:
            oneline.append(str(self.static))
        for cnt, sub in self.rep:
            if not sub.cond and not sub.rep:
                # Per-element size is a constant, so fold the repetition
                # into a multiplication instead of emitting a loop.
                if sub.static == 1:
                    oneline.append(cnt.c_str(root))
                else:
                    oneline.append(f"({cnt.c_str(root)})*{sub.static}")
                continue
            loopvar = chr(ord("i") + loop_depth)
            multiline += f"{indent}for ({c9util.typename(cnt.elems[-1].typ)} {loopvar} = 0; {loopvar} < {cnt.c_str(root)}; {loopvar}++) {{\n"
            multiline += sub.gen_c("", dstvar, root, indent_depth + 1, loop_depth + 1)
            multiline += f"{indent}}}\n"
        for vers, sub in self.cond.items():
            multiline += cutil.ifdef_push(indent_depth + 1, c9util.ver_ifdef(vers))
            # Parenthesize the condition here (as gen_c_marshal's handle()
            # does) so the emitted `if` is valid C even when ver_cond()
            # returns an unparenthesized expression.
            multiline += f"{indent}if ({c9util.ver_cond(vers)}) {{\n"
            multiline += sub.gen_c("", dstvar, root, indent_depth + 1, loop_depth)
            multiline += f"{indent}}}\n"
            multiline += cutil.ifdef_pop(indent_depth)
        ret = ""
        if dsttyp:
            if not oneline:
                oneline.append("0")
            ret += f"{indent}{dsttyp} {dstvar} = {' + '.join(oneline)};\n"
        elif oneline:
            ret += f"{indent}{dstvar} += {' + '.join(oneline)};\n"
        ret += multiline
        return ret


# Callback that get_offset_expr() consults for every member path before it
# is accounted for.  It returns the idlutil.WalkCmd to use for that path
# (see go_to_end and go_to_tok() below).
type OffsetExprRecursion = typing.Callable[[idlutil.Path], idlutil.WalkCmd]


def get_offset_expr(typ: idl.UserType, recurse: OffsetExprRecursion) -> OffsetExpr:
    """Build a symbolic marshaled size/offset for `typ` as an OffsetExpr.

    A non-struct type must have a static size, which is returned directly.
    For a struct, member paths are visited with idlutil.walk().  `recurse`
    is consulted first for each path, and only paths for which it returns
    idlutil.WalkCmd.KEEP_GOING are accounted for.  With `go_to_end` every
    member is counted (the full size).  With `go_to_tok(name)` the walk is
    told to ABORT at the top-level member `name`, presumably yielding that
    member's offset.

    Members present in fewer versions than their container are collected
    into OffsetExpr.cond, and counted members into OffsetExpr.rep.  Both
    nest in the same way the struct's members nest.
    """
    if not isinstance(typ, idl.Struct):
        assert typ.static_size
        ret = OffsetExpr()
        ret.static = typ.static_size
        return ret

    # One stack frame per currently-open conditional/repeated region.  The
    # bottom frame accumulates the result for `typ` itself.
    class ExprStackItem(typing.NamedTuple):
        path: idlutil.Path
        expr: OffsetExpr
        # Folds `expr` into the frame below and removes this frame.
        pop: typing.Callable[[], None]

    expr_stack: list[ExprStackItem]

    def pop_root() -> None:
        # handle()'s cleanup never unwinds below the depth it started at
        # (always >= 1), so the root frame is never popped.
        assert False

    def pop_cond() -> None:
        # Merge the top (version-conditional) frame into its parent's
        # `cond`, keyed by the member's version set, then drop the frame.
        nonlocal expr_stack
        key = frozenset(expr_stack[-1].path.elems[-1].in_versions)
        if key in expr_stack[-2].expr.cond:
            expr_stack[-2].expr.cond[key].add(expr_stack[-1].expr)
        else:
            expr_stack[-2].expr.cond[key] = expr_stack[-1].expr
        expr_stack = expr_stack[:-1]

    def pop_rep() -> None:
        # Record the top frame as a repeated term in its parent's `rep`,
        # with the count member resolved relative to the counted member's
        # parent path, then drop the frame.
        nonlocal expr_stack
        member_path = expr_stack[-1].path
        member = member_path.elems[-1]
        assert member.cnt
        cnt_path = member_path.parent().add(member.cnt)
        expr_stack[-2].expr.rep.append((cnt_path, expr_stack[-1].expr))
        expr_stack = expr_stack[:-1]

    def handle(
        path: idlutil.Path,
    ) -> tuple[idlutil.WalkCmd, typing.Callable[[], None] | None]:
        # idlutil.walk() callback.  Returns the walk command plus a cleanup
        # that unwinds any frames pushed for this path (presumably invoked
        # by walk() once the path's children have been visited).
        nonlocal recurse

        ret = recurse(path)
        if ret != idlutil.WalkCmd.KEEP_GOING:
            # Not accounted for; nothing pushed, so no cleanup needed.
            return ret, None

        nonlocal expr_stack
        expr_stack_len = len(expr_stack)

        def pop() -> None:
            nonlocal expr_stack
            nonlocal expr_stack_len
            while len(expr_stack) > expr_stack_len:
                expr_stack[-1].pop()

        if path.elems:
            child = path.elems[-1]
            parent = path.elems[-2].typ if len(path.elems) > 1 else path.root
            # Member exists in fewer versions than its container: its size
            # goes into a version-conditional frame.
            if child.in_versions < parent.in_versions:
                expr_stack.append(
                    ExprStackItem(path=path, expr=OffsetExpr(), pop=pop_cond)
                )
            # Counted member: whatever is added below is per-element.
            if child.cnt:
                expr_stack.append(
                    ExprStackItem(path=path, expr=OffsetExpr(), pop=pop_rep)
                )
            # Leaf (non-struct) members contribute their static size to the
            # innermost open frame; struct members are covered by the walk
            # visiting their own members.
            if not isinstance(child.typ, idl.Struct):
                assert child.typ.static_size
                expr_stack[-1].expr.static += child.typ.static_size
        return ret, pop

    expr_stack = [
        ExprStackItem(path=idlutil.Path(typ), expr=OffsetExpr(), pop=pop_root)
    ]
    idlutil.walk(typ, handle)
    return expr_stack[0].expr


def go_to_end(path: idlutil.Path) -> idlutil.WalkCmd:
    """OffsetExprRecursion that never stops: every member path is counted.

    Used with get_offset_expr() to compute the full size of a type.
    """
    _ = path  # every path is acceptable; the argument is irrelevant
    keep_going = idlutil.WalkCmd.KEEP_GOING
    return keep_going


def go_to_tok(name: str) -> typing.Callable[[idlutil.Path], idlutil.WalkCmd]:
    """Make an OffsetExprRecursion that ABORTs at top-level member `name`.

    Every other path gets KEEP_GOING.  With get_offset_expr(), this makes
    the walk stop when it reaches the top-level member called `name`.
    """

    def stop_at_member(path: idlutil.Path) -> idlutil.WalkCmd:
        is_target = len(path.elems) == 1 and path.elems[0].membname == name
        return idlutil.WalkCmd.ABORT if is_target else idlutil.WalkCmd.KEEP_GOING

    return stop_at_member


# Generate .c ##################################################################


def gen_c_marshal(versions: set[str], typs: list[idl.UserType]) -> str:
    ret = """
/* marshal_* ******************************************************************/

"""
    ret += cutil.macro(
        "#define MARSHAL_BYTES_ZEROCOPY(ctx, data, len)\n"
        "\tif (ret->net_iov[ret->net_iov_cnt-1].iov_len)\n"
        "\t\tret->net_iov_cnt++;\n"
        "\tret->net_iov[ret->net_iov_cnt-1].iov_base = data;\n"
        "\tret->net_iov[ret->net_iov_cnt-1].iov_len = len;\n"
        "\tret->net_iov_cnt++;\n"
    )
    ret += cutil.macro(
        "#define MARSHAL_BYTES(ctx, data, len)\n"
        "\tif (!ret->net_iov[ret->net_iov_cnt-1].iov_base)\n"
        "\t\tret->net_iov[ret->net_iov_cnt-1].iov_base = &ret->net_copied[ret->net_copied_size];\n"
        "\tmemcpy(&ret->net_copied[ret->net_copied_size], data, len);\n"
        "\tret->net_copied_size += len;\n"
        "\tret->net_iov[ret->net_iov_cnt-1].iov_len += len;\n"
    )
    ret += cutil.macro(
        "#define MARSHAL_U8LE(ctx, val)\n"
        "\tif (!ret->net_iov[ret->net_iov_cnt-1].iov_base)\n"
        "\t\tret->net_iov[ret->net_iov_cnt-1].iov_base = &ret->net_copied[ret->net_copied_size];\n"
        "\tret->net_copied[ret->net_copied_size] = val;\n"
        "\tret->net_copied_size += 1;\n"
        "\tret->net_iov[ret->net_iov_cnt-1].iov_len += 1;\n"
    )
    ret += cutil.macro(
        "#define MARSHAL_U16LE(ctx, val)\n"
        "\tif (!ret->net_iov[ret->net_iov_cnt-1].iov_base)\n"
        "\t\tret->net_iov[ret->net_iov_cnt-1].iov_base = &ret->net_copied[ret->net_copied_size];\n"
        "\tuint16le_encode(&ret->net_copied[ret->net_copied_size], val);\n"
        "\tret->net_copied_size += 2;\n"
        "\tret->net_iov[ret->net_iov_cnt-1].iov_len += 2;\n"
    )
    ret += cutil.macro(
        "#define MARSHAL_U32LE(ctx, val)\n"
        "\tif (!ret->net_iov[ret->net_iov_cnt-1].iov_base)\n"
        "\t\tret->net_iov[ret->net_iov_cnt-1].iov_base = &ret->net_copied[ret->net_copied_size];\n"
        "\tuint32le_encode(&ret->net_copied[ret->net_copied_size], val);\n"
        "\tret->net_copied_size += 4;\n"
        "\tret->net_iov[ret->net_iov_cnt-1].iov_len += 4;\n"
    )
    ret += cutil.macro(
        "#define MARSHAL_U64LE(ctx, val)\n"
        "\tif (!ret->net_iov[ret->net_iov_cnt-1].iov_base)\n"
        "\t\tret->net_iov[ret->net_iov_cnt-1].iov_base = &ret->net_copied[ret->net_copied_size];\n"
        "\tuint64le_encode(&ret->net_copied[ret->net_copied_size], val);\n"
        "\tret->net_copied_size += 8;\n"
        "\tret->net_iov[ret->net_iov_cnt-1].iov_len += 8;\n"
    )

    class IndentLevel(typing.NamedTuple):
        ifdef: bool  # whether this is both `{` and `#if`, or just `{`

    indent_stack: list[IndentLevel]

    def ifdef_lvl() -> int:
        return sum(1 if lvl.ifdef else 0 for lvl in indent_stack)

    def indent_lvl() -> int:
        return len(indent_stack)

    max_size: int

    def handle(
        path: idlutil.Path,
    ) -> tuple[idlutil.WalkCmd, typing.Callable[[], None]]:
        nonlocal ret
        nonlocal indent_stack
        nonlocal max_size
        indent_stack_len = len(indent_stack)

        def pop() -> None:
            nonlocal ret
            nonlocal indent_stack
            nonlocal indent_stack_len
            while len(indent_stack) > indent_stack_len:
                if len(indent_stack) == indent_stack_len + 1 and indent_stack[-1].ifdef:
                    break
                ret += f"{'\t'*(indent_lvl()-1)}}}\n"
                if indent_stack.pop().ifdef:
                    ret += cutil.ifdef_pop(ifdef_lvl())

        loopdepth = sum(1 for elem in path.elems if elem.cnt)
        struct = path.elems[-1].typ if path.elems else path.root
        if isinstance(struct, idl.Struct):
            offsets: list[str] = []
            for member in struct.members:
                if not member.val:
                    continue
                for tok in member.val.tokens:
                    if not isinstance(tok, idl.ExprSym):
                        continue
                    if tok.symname == "end" or tok.symname.startswith("&"):
                        if tok.symname not in offsets:
                            offsets.append(tok.symname)
            for name in offsets:
                name_prefix = f"offsetof{''.join('_'+m.membname for m in path.elems)}_"
                if name == "end":
                    if not path.elems:
                        if max_size > cutil.UINT32_MAX:
                            ret += f"{'\t'*indent_lvl()}uint32_t {name_prefix}end = (uint32_t)needed_size;\n"
                        else:
                            ret += f"{'\t'*indent_lvl()}uint32_t {name_prefix}end = needed_size;\n"
                        continue
                    recurse: OffsetExprRecursion = go_to_end
                else:
                    assert name.startswith("&")
                    name = name[1:]
                    recurse = go_to_tok(name)
                expr = get_offset_expr(struct, recurse)
                expr_prefix = path.c_str("val->", loopdepth)
                if not expr_prefix.endswith(">"):
                    expr_prefix += "."
                ret += expr.gen_c(
                    "uint32_t",
                    name_prefix + name,
                    expr_prefix,
                    indent_lvl(),
                    loopdepth,
                )
        if not path.elems:
            return idlutil.WalkCmd.KEEP_GOING, pop

        child = path.elems[-1]
        parent = path.elems[-2].typ if len(path.elems) > 1 else path.root
        if child.in_versions < parent.in_versions:
            if line := cutil.ifdef_push(
                ifdef_lvl() + 1, c9util.ver_ifdef(child.in_versions)
            ):
                ret += line
                ret += (
                    f"{'\t'*indent_lvl()}if ({c9util.ver_cond(child.in_versions)}) {{\n"
                )
                indent_stack.append(IndentLevel(ifdef=True))
        if child.cnt:
            cnt_path = path.parent().add(child.cnt)
            if child.typ.static_size == 1:  # SPECIAL (zerocopy)
                if path.root.typname == "stat":  # SPECIAL (stat)
                    ret += f"{'\t'*indent_lvl()}MARSHAL_BYTES(ctx, {path.c_str('val->')[:-3]}, {cnt_path.c_str('val->')});\n"
                else:
                    if (
                        c9util.typename(child.typ, child)
                        == f"struct {c9util.ident("_iovec")}"
                    ):
                        ret += f"{'\t'*indent_lvl()}for (int iov_i = 0; iov_i < {path.c_str('val->')[:-3]}->iovcnt; iov_i++)\n"
                        ret += f"{'\t'*(indent_lvl()+1)}MARSHAL_BYTES_ZEROCOPY(ctx, {path.c_str('val->')[:-3]}->iov[iov_i].iov_base, {path.c_str('val->')[:-3]}->iov[iov_i].iov_len);\n"
                    else:
                        ret += f"{'\t'*indent_lvl()}MARSHAL_BYTES_ZEROCOPY(ctx, {path.c_str('val->')[:-3]}, {cnt_path.c_str('val->')});\n"
                return idlutil.WalkCmd.KEEP_GOING, pop
            loopvar = chr(ord("i") + loopdepth - 1)
            ret += f"{'\t'*indent_lvl()}for ({c9util.typename(child.cnt.typ)} {loopvar} = 0; {loopvar} < {cnt_path.c_str('val->')}; {loopvar}++) {{\n"
            indent_stack.append(IndentLevel(ifdef=False))
        if not isinstance(child.typ, idl.Struct):
            if child.val:

                def lookup_sym(sym: str) -> str:
                    nonlocal path
                    if sym.startswith("&"):
                        sym = sym[1:]
                    return f"offsetof{''.join('_'+m.membname for m in path.elems[:-1])}_{sym}"

                val = c9util.idl_expr(child.val, lookup_sym)
            else:
                val = path.c_str("val->")
            if isinstance(child.typ, idl.Bitfield):
                val += f" & {child.typ.typname}_masks[ctx->version]"
            ret += f"{'\t'*indent_lvl()}MARSHAL_U{child.typ.static_size*8}LE(ctx, {val});\n"
        return idlutil.WalkCmd.KEEP_GOING, pop

    for typ in typs:
        if not (
            isinstance(typ, idl.Message) or typ.typname == "stat"
        ):  # SPECIAL (include stat)
            continue
        assert isinstance(typ, idl.Struct)
        ret += "\n"
        ret += cutil.ifdef_push(1, c9util.ver_ifdef(typ.in_versions))
        ret += f"static bool marshal_{typ.typname}(struct lib9p_ctx *ctx, {c9util.typename(typ)} *val, struct _marshal_ret *ret) {{\n"

        # Pass 1 - check size
        max_size = max(typ.max_size(v) for v in typ.in_versions)

        if max_size > cutil.UINT32_MAX:  # SPECIAL (9P2000.e)
            ret += get_offset_expr(typ, go_to_end).gen_c(
                "uint64_t", "needed_size", "val->", 1, 0
            )
            ret += "\tif (needed_size > (uint64_t)(ctx->max_msg_size)) {\n"
        else:
            ret += get_offset_expr(typ, go_to_end).gen_c(
                "uint32_t", "needed_size", "val->", 1, 0
            )
            ret += "\tif (needed_size > ctx->max_msg_size) {\n"
        if isinstance(typ, idl.Message):  # SPECIAL (disable for stat)
            ret += '\t\tlib9p_errorf(ctx, LINUX_ERANGE, "%s message too large to marshal into %s limit (limit=%"PRIu32")",\n'
            ret += f'\t\t\t"{typ.typname}",\n'
            ret += f'\t\t\tctx->version ? "negotiated" : "{'client' if typ.msgid % 2 == 0 else 'server'}",\n'
            ret += "\t\t\tctx->max_msg_size);\n"
        ret += "\t\treturn true;\n"
        ret += "\t}\n"

        # Pass 2 - write data
        indent_stack = [IndentLevel(ifdef=True)]
        idlutil.walk(typ, handle)
        while len(indent_stack) > 1:
            ret += f"{'\t'*(indent_lvl()-1)}}}\n"
            if indent_stack.pop().ifdef:
                ret += cutil.ifdef_pop(ifdef_lvl())

        # Return
        ret += "\treturn false;\n"
        ret += "}\n"
    ret += cutil.ifdef_pop(0)
    return ret