# lib9p/protogen/c.py - Generate 9p.generated.c
#
# Copyright (C) 2024-2025  Luke T. Shumaker <lukeshu@lukeshu.com>
# SPDX-License-Identifier: AGPL-3.0-or-later

import sys

import idl

from . import c9util, c_marshal, c_unmarshal, c_validate, cutil

# This strives to be "general-purpose" in that it just acts on the
# *.9p inputs; but (unfortunately?) there are a few special-cases in
# this script, marked with "SPECIAL".


# pylint: disable=unused-variable
__all__ = ["gen_c"]


def gen_c(versions: set[str], typs: list[idl.UserType]) -> str:
    """Return the complete text of the generated C source file.

    The output is assembled, in order, from: a header comment recording
    the generating command line (``sys.argv``) plus the ``#include``
    lines, the ``is_ver()`` utility macros, the version-name and
    message-name string tables, one mask table per ``idl.Bitfield``, the
    sections produced by ``c_validate``, ``c_unmarshal`` and
    ``c_marshal``, the per-version function tables, and finally the
    thin ``stat`` wrapper functions.

    Args:
        versions: Names of the protocol versions to generate code for.
        typs: All user-defined types; ``idl.Message`` entries populate
            the message tables and ``idl.Bitfield`` entries the mask
            tables.  Must contain a type whose ``typname`` is
            ``"Rerror"``.

    Returns:
        The generated C source as a single string.

    Raises:
        ValueError: If ``typs`` contains no type named ``"Rerror"``.
    """
    # NOTE(review): presumably resets cutil's #if-nesting bookkeeping so
    # that the ifdef_push()/ifdef_pop() calls below start from a clean
    # state -- confirm against cutil.
    cutil.ifdef_init()

    ret = f"""/* Generated by `{' '.join(sys.argv)}`.  DO NOT EDIT!  */

#include <stdbool.h>
#include <stddef.h>   /* for size_t */
#include <inttypes.h> /* for PRI* macros */
#include <string.h>   /* for memset() */

#include <libmisc/assert.h>

#include <lib9p/9p.h>

#include "internal.h"
"""

    # utilities ################################################################
    ret += """
/* utilities ******************************************************************/
"""

    # Message-ID -> message lookup; if two messages share an ID the
    # later one in `typs` wins.
    id2typ: dict[int, idl.Message] = {
        msg.msgid: msg for msg in typs if isinstance(msg, idl.Message)
    }

    def msg_table(grp: str, meth: str, tentry: str, rng: tuple[int, int, int]) -> str:
        """Return the C definition of one per-version message table.

        The table is a 2-D array indexed first by version and then by
        message slot.  There is one row for "unknown" followed by one
        row per entry of ``versions`` (sorted); each row lists a
        ``_MSG_<METH>(typname)`` initializer for every message whose ID
        is in ``range(*rng)`` and that exists in that version.

        Args:
            grp: Group name, used in the table identifier
                ``_table_<grp>_<meth>``.
            meth: Method name, used in the table identifier and
                (upper-cased) to pick the ``_MSG_<METH>`` macro.
            tentry: C type of a table entry.
            rng: ``(start, stop, step)`` arguments for ``range()``
                selecting which message IDs belong to this table.
        """
        msgids = range(*rng)
        tbl = f"const {tentry} {c9util.ident(f'_table_{grp}_{meth}')}[{c9util.ver_enum('NUM')}][{hex(len(msgids))}] = {{\n"
        for ver in ["unknown", *sorted(versions)]:
            # The "unknown" row is emitted unconditionally; every real
            # version's row is wrapped in that version's #if guard.
            if ver != "unknown":
                tbl += cutil.ifdef_push(1, c9util.ver_ifdef({ver}))
            tbl += f"\t[{c9util.ver_enum(ver)}] = {{\n"
            for n in msgids:
                xmsg = id2typ.get(n)
                if xmsg is None:
                    continue
                if ver == "unknown":  # SPECIAL (initialization)
                    # Only these three messages get an entry in the
                    # "unknown" row.
                    if xmsg.typname not in ["Tversion", "Rversion", "Rerror"]:
                        continue
                elif ver not in xmsg.in_versions:
                    continue
                tbl += f"\t\t_MSG_{meth.upper()}({xmsg.typname}),\n"
            tbl += "\t},\n"
        tbl += cutil.ifdef_pop(0)
        tbl += "};\n"
        return tbl

    for ver in sorted(versions):
        # Version name as used in C identifiers ('.' is replaced by '_').
        ver_c = ver.replace(".", "_")
        ret += f"#if CONFIG_9P_ENABLE_{ver_c}\n"
        ret += f"\t#define _is_ver_{ver_c}(v) (v == {c9util.ver_enum(ver)})\n"
        ret += "#else\n"
        ret += f"\t#define _is_ver_{ver_c}(v) false\n"
        ret += "#endif\n"
    ret += "\n"
    ret += "/**\n"
    ret += f" * is_ver(ctx, ver) is essentially `(ctx->version == {c9util.Ident('VER_')}##ver)`, but\n"
    ret += f" * compiles correctly (to `false`) even if `{c9util.Ident('VER_')}##ver` isn't defined\n"
    ret += " * (because `!CONFIG_9P_ENABLE_##ver`).  This is useful when `||`ing\n"
    ret += " * several version checks together.\n"
    ret += " */\n"
    ret += "#define is_ver(ctx, ver) _is_ver_##ver((ctx)->version)\n"

    # strings ##################################################################
    ret += f"""
/* strings ********************************************************************/

const char *const {c9util.ident('_table_ver_name')}[{c9util.ver_enum('NUM')}] = {{
"""
    for ver in ["unknown", *sorted(versions)]:
        if ver in versions:
            ret += cutil.ifdef_push(1, c9util.ver_ifdef({ver}))
        ret += f'\t[{c9util.ver_enum(ver)}] = "{ver}",\n'
    ret += cutil.ifdef_pop(0)
    ret += "};\n"

    ret += "\n"
    ret += f"#define _MSG_NAME(typ) [{c9util.Ident('TYP_')}##typ] = #typ\n"
    ret += msg_table("msg", "name", "char *const", (0, 0x100, 1))

    # bitmasks #################################################################
    ret += """
/* bitmasks *******************************************************************/
"""
    # Width of the longest version name, used to align the `=` signs.
    # It does not depend on the bitfield, so compute it once;
    # `default=0` keeps an empty `versions` set from raising.
    verwidth = max((len(ver) for ver in versions), default=0)
    for typ in typs:
        if not isinstance(typ, idl.Bitfield):
            continue
        ret += "\n"
        ret += cutil.ifdef_push(1, c9util.ver_ifdef(typ.in_versions))
        ret += f"static const {c9util.typename(typ)} {typ.typname}_masks[{c9util.ver_enum('NUM')}] = {{\n"
        for ver in sorted(versions):
            ret += cutil.ifdef_push(2, c9util.ver_ifdef({ver}))
            # One binary digit per bit, walking `typ.bits` in reverse:
            # "1" if the bit is used (or part of a subfield) in this
            # version, "0" otherwise.
            ret += (
                f"\t[{c9util.ver_enum(ver)}]{' '*(verwidth-len(ver))} = 0b"
                + "".join(
                    (
                        "1"
                        if bit.cat in (idl.BitCat.USED, idl.BitCat.SUBFIELD)
                        and ver in bit.in_versions
                        else "0"
                    )
                    for bit in reversed(typ.bits)
                )
                + ",\n"
            )
        ret += cutil.ifdef_pop(1)
        ret += "};\n"
    ret += cutil.ifdef_pop(0)

    # validate_* ###############################################################
    ret += c_validate.gen_c_validate(versions, typs)

    # unmarshal_* ##############################################################
    ret += c_unmarshal.gen_c_unmarshal(versions, typs)

    # marshal_* ################################################################
    ret += c_marshal.gen_c_marshal(versions, typs)

    # function tables ##########################################################
    ret += """
/* function tables ************************************************************/
"""

    ret += "\n"
    ret += f"const uint32_t {c9util.ident('_table_msg_min_size')}[{c9util.ver_enum('NUM')}] = {{\n"
    # The min-size table is filled from Rerror's minimum size in each
    # version.  Fail with a clear message rather than a bare
    # StopIteration if the IDL input lacks Rerror.
    rerror = next((typ for typ in typs if typ.typname == "Rerror"), None)
    if rerror is None:
        raise ValueError('typs must include a type named "Rerror"')
    ret += f"\t[{c9util.ver_enum('unknown')}] = {rerror.min_size('9P2000')},\n"  # SPECIAL (initialization)
    for ver in sorted(versions):
        ret += cutil.ifdef_push(1, c9util.ver_ifdef({ver}))
        ret += f"\t[{c9util.ver_enum(ver)}] = {rerror.min_size(ver)},\n"
    ret += cutil.ifdef_pop(0)
    ret += "};\n"

    ret += "\n"
    # The recv/send tables hold every other message ID (step 2 in the
    # ranges below), hence the `/2` in the designated-initializer index.
    ret += cutil.macro(
        f"#define _MSG_RECV(typ) [{c9util.Ident('TYP_')}##typ/2] = {{\n"
        f"\t\t.basesize  = sizeof(struct {c9util.ident('msg_')}##typ),\n"
        f"\t\t.validate  = validate_##typ,\n"
        f"\t\t.unmarshal = (_unmarshal_fn_t)unmarshal_##typ,\n"
        f"\t}}\n"
    )
    ret += cutil.macro(
        f"#define _MSG_SEND(typ) [{c9util.Ident('TYP_')}##typ/2] = {{\n"
        f"\t\t.marshal   = (_marshal_fn_t)marshal_##typ,\n"
        f"\t}}\n"
    )
    ret += "\n"
    # Tables named "Tmsg" cover the even message IDs, "Rmsg" the odd ones.
    ret += msg_table(
        "Tmsg", "recv", f"struct {c9util.ident('_recv_tentry')}", (0, 0x100, 2)
    )
    ret += "\n"
    ret += msg_table(
        "Rmsg", "recv", f"struct {c9util.ident('_recv_tentry')}", (1, 0x100, 2)
    )
    ret += "\n"
    ret += msg_table(
        "Tmsg", "send", f"struct {c9util.ident('_send_tentry')}", (0, 0x100, 2)
    )
    ret += "\n"
    ret += msg_table(
        "Rmsg", "send", f"struct {c9util.ident('_send_tentry')}", (1, 0x100, 2)
    )

    # Wrappers that forward to the validate/unmarshal/marshal routines
    # for `stat`.
    ret += f"""
LM_FLATTEN ssize_t {c9util.ident('_stat_validate')}(struct lib9p_ctx *ctx, uint32_t net_size, uint8_t *net_bytes, uint32_t *ret_net_size) {{
\treturn validate_stat(ctx, net_size, net_bytes, ret_net_size);
}}
LM_FLATTEN void {c9util.ident('_stat_unmarshal')}(struct lib9p_ctx *ctx, uint8_t *net_bytes, void *out) {{
\tunmarshal_stat(ctx, net_bytes, out);
}}
LM_FLATTEN bool {c9util.ident('_stat_marshal')}(struct lib9p_ctx *ctx, struct {c9util.ident('stat')} *val, struct _marshal_ret *ret) {{
\treturn marshal_stat(ctx, val, ret);
}}
"""

    ############################################################################
    return ret