# lib9p/protogen/c_marshal.py - Generate C marshal functions
#
# Copyright (C) 2024-2025  Luke T. Shumaker <lukeshu@lukeshu.com>
# SPDX-License-Identifier: AGPL-3.0-or-later

import typing

import idl

from . import c9util, cutil, idlutil

# This strives to be "general-purpose" in that it just acts on the
# *.9p inputs; but (unfortunately?) there are a few special-cases in
# this script, marked with "SPECIAL".


# pylint: disable=unused-variable
# Public API of this module: only gen_c_marshal() is exported.
__all__ = ["gen_c_marshal"]


def gen_c_marshal(versions: set[str], typs: list[idl.UserType]) -> str:
    ret = """
/* marshal_* ******************************************************************/

"""
    ret += cutil.macro(
        "#define MARSHAL_BYTES_ZEROCOPY(ctx, data, len)\n"
        "\tif (ctx->net_iov[ctx->net_iov_cnt-1].iov_len)\n"
        "\t\tctx->net_iov_cnt++;\n"
        "\tctx->net_iov[ctx->net_iov_cnt-1].iov_base = data;\n"
        "\tctx->net_iov[ctx->net_iov_cnt-1].iov_len = len;\n"
        "\tctx->net_iov_cnt++;\n"
    )
    ret += cutil.macro(
        "#define MARSHAL_BYTES(ctx, data, len)\n"
        "\tif (!ctx->net_iov[ctx->net_iov_cnt-1].iov_base)\n"
        "\t\tctx->net_iov[ctx->net_iov_cnt-1].iov_base = &ctx->net_copied[ctx->net_copied_size];\n"
        "\tmemcpy(&ctx->net_copied[ctx->net_copied_size], data, len);\n"
        "\tctx->net_copied_size += len;\n"
        "\tctx->net_iov[ctx->net_iov_cnt-1].iov_len += len;\n"
    )
    ret += cutil.macro(
        "#define MARSHAL_U8LE(ctx, val)\n"
        "\tif (!ctx->net_iov[ctx->net_iov_cnt-1].iov_base)\n"
        "\t\tctx->net_iov[ctx->net_iov_cnt-1].iov_base = &ctx->net_copied[ctx->net_copied_size];\n"
        "\tctx->net_copied[ctx->net_copied_size] = val;\n"
        "\tctx->net_copied_size += 1;\n"
        "\tctx->net_iov[ctx->net_iov_cnt-1].iov_len += 1;\n"
    )
    ret += cutil.macro(
        "#define MARSHAL_U16LE(ctx, val)\n"
        "\tif (!ctx->net_iov[ctx->net_iov_cnt-1].iov_base)\n"
        "\t\tctx->net_iov[ctx->net_iov_cnt-1].iov_base = &ctx->net_copied[ctx->net_copied_size];\n"
        "\tuint16le_encode(&ctx->net_copied[ctx->net_copied_size], val);\n"
        "\tctx->net_copied_size += 2;\n"
        "\tctx->net_iov[ctx->net_iov_cnt-1].iov_len += 2;\n"
    )
    ret += cutil.macro(
        "#define MARSHAL_U32LE(ctx, val)\n"
        "\tif (!ctx->net_iov[ctx->net_iov_cnt-1].iov_base)\n"
        "\t\tctx->net_iov[ctx->net_iov_cnt-1].iov_base = &ctx->net_copied[ctx->net_copied_size];\n"
        "\tuint32le_encode(&ctx->net_copied[ctx->net_copied_size], val);\n"
        "\tctx->net_copied_size += 4;\n"
        "\tctx->net_iov[ctx->net_iov_cnt-1].iov_len += 4;\n"
    )
    ret += cutil.macro(
        "#define MARSHAL_U64LE(ctx, val)\n"
        "\tif (!ctx->net_iov[ctx->net_iov_cnt-1].iov_base)\n"
        "\t\tctx->net_iov[ctx->net_iov_cnt-1].iov_base = &ctx->net_copied[ctx->net_copied_size];\n"
        "\tuint64le_encode(&ctx->net_copied[ctx->net_copied_size], val);\n"
        "\tctx->net_copied_size += 8;\n"
        "\tctx->net_iov[ctx->net_iov_cnt-1].iov_len += 8;\n"
    )

    class OffsetExpr:
        static: int
        cond: dict[frozenset[str], "OffsetExpr"]
        rep: list[tuple[idlutil.Path, "OffsetExpr"]]

        def __init__(self) -> None:
            self.static = 0
            self.rep = []
            self.cond = {}

        def add(self, other: "OffsetExpr") -> None:
            self.static += other.static
            self.rep += other.rep
            for k, v in other.cond.items():
                if k in self.cond:
                    self.cond[k].add(v)
                else:
                    self.cond[k] = v

        def gen_c(
            self,
            dsttyp: str,
            dstvar: str,
            root: str,
            indent_depth: int,
            loop_depth: int,
        ) -> str:
            oneline: list[str] = []
            multiline = ""
            if self.static:
                oneline.append(str(self.static))
            for cnt, sub in self.rep:
                if not sub.cond and not sub.rep:
                    if sub.static == 1:
                        oneline.append(cnt.c_str(root))
                    else:
                        oneline.append(f"({cnt.c_str(root)})*{sub.static}")
                    continue
                loopvar = chr(ord("i") + loop_depth)
                multiline += f"{'\t'*indent_depth}for ({c9util.typename(cnt.elems[-1].typ)} {loopvar} = 0; {loopvar} < {cnt.c_str(root)}; {loopvar}++) {{\n"
                multiline += sub.gen_c(
                    "", dstvar, root, indent_depth + 1, loop_depth + 1
                )
                multiline += f"{'\t'*indent_depth}}}\n"
            for vers, sub in self.cond.items():
                multiline += cutil.ifdef_push(indent_depth + 1, c9util.ver_ifdef(vers))
                multiline += f"{'\t'*indent_depth}if {c9util.ver_cond(vers)} {{\n"
                multiline += sub.gen_c("", dstvar, root, indent_depth + 1, loop_depth)
                multiline += f"{'\t'*indent_depth}}}\n"
                multiline += cutil.ifdef_pop(indent_depth)
            if dsttyp:
                if not oneline:
                    oneline.append("0")
                ret = f"{'\t'*indent_depth}{dsttyp} {dstvar} = {' + '.join(oneline)};\n"
            elif oneline:
                ret = f"{'\t'*indent_depth}{dstvar} += {' + '.join(oneline)};\n"
            ret += multiline
            return ret

    type OffsetExprRecursion = typing.Callable[[idlutil.Path], idlutil.WalkCmd]

    def get_offset_expr(typ: idl.UserType, recurse: OffsetExprRecursion) -> OffsetExpr:
        if not isinstance(typ, idl.Struct):
            assert typ.static_size
            ret = OffsetExpr()
            ret.static = typ.static_size
            return ret

        stack: list[tuple[idlutil.Path, OffsetExpr, typing.Callable[[], None]]]

        def pop_root() -> None:
            assert False

        def pop_cond() -> None:
            nonlocal stack
            key = frozenset(stack[-1][0].elems[-1].in_versions)
            if key in stack[-2][1].cond:
                stack[-2][1].cond[key].add(stack[-1][1])
            else:
                stack[-2][1].cond[key] = stack[-1][1]
            stack = stack[:-1]

        def pop_rep() -> None:
            nonlocal stack
            member_path = stack[-1][0]
            member = member_path.elems[-1]
            assert member.cnt
            cnt_path = member_path.parent().add(member.cnt)
            stack[-2][1].rep.append((cnt_path, stack[-1][1]))
            stack = stack[:-1]

        def handle(
            path: idlutil.Path,
        ) -> tuple[idlutil.WalkCmd, typing.Callable[[], None] | None]:
            nonlocal recurse

            ret = recurse(path)
            if ret != idlutil.WalkCmd.KEEP_GOING:
                return ret, None

            nonlocal stack
            stack_len = len(stack)

            def pop() -> None:
                nonlocal stack
                nonlocal stack_len
                while len(stack) > stack_len:
                    stack[-1][2]()

            if path.elems:
                child = path.elems[-1]
                parent = path.elems[-2].typ if len(path.elems) > 1 else path.root
                if child.in_versions < parent.in_versions:
                    stack.append((path, OffsetExpr(), pop_cond))
                if child.cnt:
                    stack.append((path, OffsetExpr(), pop_rep))
                if not isinstance(child.typ, idl.Struct):
                    assert child.typ.static_size
                    stack[-1][1].static += child.typ.static_size
            return ret, pop

        stack = [(idlutil.Path(typ), OffsetExpr(), pop_root)]
        idlutil.walk(typ, handle)
        return stack[0][1]

    def go_to_end(path: idlutil.Path) -> idlutil.WalkCmd:
        return idlutil.WalkCmd.KEEP_GOING

    def go_to_tok(name: str) -> typing.Callable[[idlutil.Path], idlutil.WalkCmd]:
        def ret(path: idlutil.Path) -> idlutil.WalkCmd:
            if len(path.elems) == 1 and path.elems[0].membname == name:
                return idlutil.WalkCmd.ABORT
            return idlutil.WalkCmd.KEEP_GOING

        return ret

    for typ in typs:
        if not (
            isinstance(typ, idl.Message) or typ.typname == "stat"
        ):  # SPECIAL (include stat)
            continue
        assert isinstance(typ, idl.Struct)
        ret += "\n"
        ret += cutil.ifdef_push(1, c9util.ver_ifdef(typ.in_versions))
        ret += f"static bool marshal_{typ.typname}(struct _marshal_ctx *ctx, {c9util.typename(typ)} *val) {{\n"

        # Pass 1 - check size
        max_size = max(typ.max_size(v) for v in typ.in_versions)

        if max_size > cutil.UINT32_MAX:  # SPECIAL (9P2000.e)
            ret += get_offset_expr(typ, go_to_end).gen_c(
                "uint64_t", "needed_size", "val->", 1, 0
            )
            ret += "\tif (needed_size > (uint64_t)(ctx->ctx->max_msg_size)) {\n"
        else:
            ret += get_offset_expr(typ, go_to_end).gen_c(
                "uint32_t", "needed_size", "val->", 1, 0
            )
            ret += "\tif (needed_size > ctx->ctx->max_msg_size) {\n"
        if isinstance(typ, idl.Message):  # SPECIAL (disable for stat)
            ret += '\t\tlib9p_errorf(ctx->ctx, LINUX_ERANGE, "%s message too large to marshal into %s limit (limit=%"PRIu32")",\n'
            ret += f'\t\t\t"{typ.typname}",\n'
            ret += f'\t\t\tctx->ctx->version ? "negotiated" : "{'client' if typ.msgid % 2 == 0 else 'server'}",\n'
            ret += "\t\t\tctx->ctx->max_msg_size);\n"
        ret += "\t\treturn true;\n"
        ret += "\t}\n"

        # Pass 2 - write data
        ifdef_depth = 1
        stack: list[tuple[idlutil.Path, bool]] = [(idlutil.Path(typ), False)]

        def handle(
            path: idlutil.Path,
        ) -> tuple[idlutil.WalkCmd, typing.Callable[[], None]]:
            nonlocal ret
            nonlocal ifdef_depth
            nonlocal stack
            stack_len = len(stack)

            def pop() -> None:
                nonlocal ret
                nonlocal ifdef_depth
                nonlocal stack
                nonlocal stack_len
                while len(stack) > stack_len:
                    ret += f"{'\t'*(len(stack)-1)}}}\n"
                    if stack[-1][1]:
                        ifdef_depth -= 1
                        ret += cutil.ifdef_pop(ifdef_depth)
                    stack = stack[:-1]

            loopdepth = sum(1 for elem in path.elems if elem.cnt)
            struct = path.elems[-1].typ if path.elems else path.root
            if isinstance(struct, idl.Struct):
                offsets: list[str] = []
                for member in struct.members:
                    if not member.val:
                        continue
                    for tok in member.val.tokens:
                        if not isinstance(tok, idl.ExprSym):
                            continue
                        if tok.symname == "end" or tok.symname.startswith("&"):
                            if tok.symname not in offsets:
                                offsets.append(tok.symname)
                for name in offsets:
                    name_prefix = "offsetof_" + "".join(
                        m.membname + "_" for m in path.elems
                    )
                    if name == "end":
                        if not path.elems:
                            nonlocal max_size
                            if max_size > cutil.UINT32_MAX:
                                ret += f"{'\t'*len(stack)}uint32_t {name_prefix}end = (uint32_t)needed_size;\n"
                            else:
                                ret += f"{'\t'*len(stack)}uint32_t {name_prefix}end = needed_size;\n"
                            continue
                        recurse: OffsetExprRecursion = go_to_end
                    else:
                        assert name.startswith("&")
                        name = name[1:]
                        recurse = go_to_tok(name)
                    expr = get_offset_expr(struct, recurse)
                    expr_prefix = path.c_str("val->", loopdepth)
                    if not expr_prefix.endswith(">"):
                        expr_prefix += "."
                    ret += expr.gen_c(
                        "uint32_t",
                        name_prefix + name,
                        expr_prefix,
                        len(stack),
                        loopdepth,
                    )
            if path.elems:
                child = path.elems[-1]
                parent = path.elems[-2].typ if len(path.elems) > 1 else path.root
                if child.in_versions < parent.in_versions:
                    ret += cutil.ifdef_push(
                        ifdef_depth + 1, c9util.ver_ifdef(child.in_versions)
                    )
                    ifdef_depth += 1
                    ret += f"{'\t'*len(stack)}if ({c9util.ver_cond(child.in_versions)}) {{\n"
                    stack.append((path, True))
                if child.cnt:
                    cnt_path = path.parent().add(child.cnt)
                    if child.typ.static_size == 1:  # SPECIAL (zerocopy)
                        if path.root.typname == "stat":  # SPECIAL (stat)
                            ret += f"{'\t'*len(stack)}MARSHAL_BYTES(ctx, {path.c_str('val->')[:-3]}, {cnt_path.c_str('val->')});\n"
                        else:
                            ret += f"{'\t'*len(stack)}MARSHAL_BYTES_ZEROCOPY(ctx, {path.c_str('val->')[:-3]}, {cnt_path.c_str('val->')});\n"
                        return idlutil.WalkCmd.KEEP_GOING, pop
                    loopvar = chr(ord("i") + loopdepth - 1)
                    ret += f"{'\t'*len(stack)}for ({c9util.typename(child.cnt.typ)} {loopvar} = 0; {loopvar} < {cnt_path.c_str('val->')}; {loopvar}++) {{\n"
                    stack.append((path, False))
                if not isinstance(child.typ, idl.Struct):
                    if child.val:

                        def lookup_sym(sym: str) -> str:
                            nonlocal path
                            if sym.startswith("&"):
                                sym = sym[1:]
                            return (
                                "offsetof_"
                                + "".join(m.membname + "_" for m in path.elems[:-1])
                                + sym
                            )

                        val = c9util.idl_expr(child.val, lookup_sym)
                    else:
                        val = path.c_str("val->")
                    if isinstance(child.typ, idl.Bitfield):
                        val += f" & {child.typ.typname}_masks[ctx->ctx->version]"
                    ret += f"{'\t'*len(stack)}MARSHAL_U{child.typ.static_size*8}LE(ctx, {val});\n"
            return idlutil.WalkCmd.KEEP_GOING, pop

        idlutil.walk(typ, handle)
        del handle
        del stack
        del max_size

        ret += "\treturn false;\n"
        ret += "}\n"
    ret += cutil.ifdef_pop(0)
    return ret