# lib9p/protogen/c_unmarshal.py - Generate C unmarshal functions
#
# Copyright (C) 2024-2025  Luke T. Shumaker <lukeshu@lukeshu.com>
# SPDX-License-Identifier: AGPL-3.0-or-later

import typing

import idl

from . import c9util, cutil, idlutil

# This strives to be "general-purpose" in that it just acts on the
# *.9p inputs; but (unfortunately?) there are a few special cases in
# this script, marked with "SPECIAL".


# pylint: disable=unused-variable
__all__ = ["gen_c_unmarshal"]


def gen_c_unmarshal(versions: set[str], typs: list[idl.UserType]) -> str:
    """Return the C source text for the ``unmarshal_*`` functions.

    The output starts with a banner comment and the ``UNMARSHAL_*``
    helper macros, followed by one ``static void unmarshal_<typname>()``
    function for every entry of `typs` that is an `idl.Message` or is
    named "stat".  Each function body is produced by walking the type
    with `idlutil.walk` and emitting one statement (or loop/`if`
    block) per member.

    `versions` is accepted but not consulted by this function.
    """
    c_src = """
/* unmarshal_* ****************************************************************/

"""
    # Helper macros.  They all read from `net_bytes` at `net_offset`
    # and then advance `net_offset` by the number of bytes consumed.
    c_src += cutil.macro(
        "#define UNMARSHAL_BYTES(ctx, data_lvalue, len)\n"
        "\tdata_lvalue = (char *)&net_bytes[net_offset];\n"
        "\tnet_offset += len;\n"
    )
    c_src += cutil.macro(
        "#define UNMARSHAL_U8LE(ctx, val_lvalue)\n"
        "\tval_lvalue = net_bytes[net_offset];\n"
        "\tnet_offset += 1;\n"
    )
    # The multi-byte variants differ only in their width.
    for bits in (16, 32, 64):
        c_src += cutil.macro(
            f"#define UNMARSHAL_U{bits}LE(ctx, val_lvalue)\n"
            f"\tval_lvalue = uint{bits}le_decode(&net_bytes[net_offset]);\n"
            f"\tnet_offset += {bits // 8};\n"
        )

    class IndentLevel(typing.NamedTuple):
        ifdef: bool  # whether this is both `{` and `#if`, or just `{`

    # One entry per C block (`{`) that is currently open in the
    # function being generated; rebound for each generated function.
    scopes: list[IndentLevel] = []

    def ifdef_lvl() -> int:
        # Number of open blocks that are also wrapped in an `#if`.
        return sum(1 for lvl in scopes if lvl.ifdef)

    def tabs(delta: int = 0) -> str:
        # Indentation for the current block depth, adjusted by `delta`.
        return "\t" * (len(scopes) + delta)

    def handle(
        path: idlutil.Path,
    ) -> tuple[idlutil.WalkCmd, typing.Callable[[], None]]:
        """Emit the C code for the member at `path`.

        Returns the walk command together with a callback that closes
        the blocks this call opened.
        """
        nonlocal c_src
        # Block depth on entry; `pop` unwinds back down towards it.
        base_depth = len(scopes)

        def pop() -> None:
            nonlocal c_src
            while len(scopes) > base_depth:
                # A lone remaining ifdef level is deliberately left open.
                # NOTE(review): presumably so that a following sibling with
                # the same version condition can share the block (see the
                # empty-string check on ifdef_push() below) -- confirm
                # against cutil.ifdef_push().
                only_ifdef_left = (
                    len(scopes) == base_depth + 1 and scopes[-1].ifdef
                )
                if only_ifdef_left:
                    break
                c_src += tabs(-1) + "}\n"
                closed = scopes.pop()
                if closed.ifdef:
                    c_src += cutil.ifdef_pop(ifdef_lvl())

        result = (idlutil.WalkCmd.KEEP_GOING, pop)

        if not path.elems:
            # The root of the walk: nothing to emit.
            return result

        member = path.elems[-1]
        owner = path.elems[-2].typ if len(path.elems) > 1 else path.root

        # Member exists in fewer versions than its container: guard it
        # with both a preprocessor `#if` and a run-time `if`.
        if member.in_versions < owner.in_versions:
            guard = cutil.ifdef_push(
                ifdef_lvl() + 1, c9util.ver_ifdef(member.in_versions)
            )
            # An empty string means no new block is opened.
            if guard:
                c_src += guard
                c_src += f"{tabs()}if ({c9util.ver_cond(member.in_versions)}) {{\n"
                scopes.append(IndentLevel(ifdef=True))

        # Repeated member: either a fixed count or a count held in
        # another member of the same parent.
        if member.cnt:
            if isinstance(member.cnt, int):
                cnt_str = str(member.cnt)
                cnt_typ = "size_t"
            else:
                cnt_str = path.parent().add(member.cnt).c_str("out->")
                cnt_typ = c9util.typename(member.cnt.typ)
            # The member's C expression with its last 3 characters removed.
            # NOTE(review): presumably that strips a trailing index such as
            # "[i]" to name the array itself -- confirm against
            # idlutil.Path.c_str().
            array_lvalue = path.c_str("out->")[:-3]
            if member.typ.static_size == 1:  # SPECIAL (zerocopy)
                c_src += f"{tabs()}UNMARSHAL_BYTES(ctx, {array_lvalue}, {cnt_str});\n"
                return result
            # Carve the array out of the `extra` area, then loop over it.
            nesting = sum(1 for elem in path.elems if elem.cnt)
            idx = chr(ord("i") + nesting - 1)  # i, j, k, ... by loop depth
            pad = tabs()
            c_src += (
                f"{pad}{array_lvalue} = extra;\n"
                f"{pad}extra += sizeof({array_lvalue}[0]) * {cnt_str};\n"
                f"{pad}for ({cnt_typ} {idx} = 0; {idx} < {cnt_str}; {idx}++) {{\n"
            )
            scopes.append(IndentLevel(ifdef=False))

        # Non-struct members are read directly; struct members emit
        # nothing here.
        if not isinstance(member.typ, idl.Struct):
            width = member.typ.static_size
            if member.val:
                # Members with a `val` are skipped over rather than stored.
                c_src += f"{tabs()}net_offset += {width};\n"
            else:
                c_src += (
                    f"{tabs()}UNMARSHAL_U{width*8}LE(ctx, {path.c_str('out->')});\n"
                )
        return result

    for typ in typs:
        # SPECIAL (include stat)
        if not isinstance(typ, idl.Message) and typ.typname != "stat":
            continue
        assert isinstance(typ, idl.Struct)
        c_src += "\n"
        c_src += cutil.ifdef_push(1, c9util.ver_ifdef(typ.in_versions))
        c_src += f"static void unmarshal_{typ.typname}([[gnu::unused]] struct lib9p_ctx *ctx, uint8_t *net_bytes, void *out_buf) {{\n"
        c_src += f"\t{c9util.typename(typ)} *out = out_buf;\n"
        c_src += "\t[[gnu::unused]] void *extra = &out[1];\n"
        c_src += "\tuint32_t net_offset = 0;\n"

        # The function body itself is the outermost open block.
        scopes = [IndentLevel(ifdef=True)]
        idlutil.walk(typ, handle)
        # Close everything that is still open, including the function
        # body; its own `#if` is left for the next ifdef_push()/the
        # final ifdef_pop(0) to close.
        while scopes:
            c_src += tabs(-1) + "}\n"
            if scopes.pop().ifdef and scopes:
                c_src += cutil.ifdef_pop(ifdef_lvl())
    c_src += cutil.ifdef_pop(0)
    return c_src