# lib9p/protogen/h.py - Generate 9p.generated.h
#
# Copyright (C) 2024-2025  Luke T. Shumaker <lukeshu@lukeshu.com>
# SPDX-License-Identifier: AGPL-3.0-or-later

import sys
import typing

import idl

from . import c9util, cutil, idlutil

# This strives to be "general-purpose" in that it just acts on the
# *.9p inputs; but (unfortunately?) there are a few special-cases in
# this script, marked with "SPECIAL".

# pylint: disable=unused-variable
# gen_h() is the only name this module exports via `from ... import *`;
# everything else here is a helper for it.
__all__ = ["gen_h"]

# get_buffer_size() ############################################################


# Final, immutable buffer-size summary of one type in one protocol version
# (see get_buffer_size()).  Fields, in positional order:
#
#   min_size        really just here to sanity-check against typ.min_size(version)
#   exp_size        "expected" or max-reasonable size
#   max_size        really just here to sanity-check against typ.max_size(version)
#   max_copy        max bytes copied into a send buffer (zero-copy data excluded)
#   max_copy_extra  C expression suffix appended after max_copy ("" if none)
#   max_iov         max number of iovec entries needed
#   max_iov_extra   C expression suffix appended after max_iov ("" if none)
BufferSize = typing.NamedTuple(
    "BufferSize",
    [
        ("min_size", int),
        ("exp_size", int),
        ("max_size", int),
        ("max_copy", int),
        ("max_copy_extra", str),
        ("max_iov", int),
        ("max_iov_extra", str),
    ],
)


class TmpBufferSize:
    """Mutable accumulator used while computing a BufferSize.

    Carries the same size fields as BufferSize, plus two bookkeeping flags
    recording whether the data accumulated so far begins / ends with a
    copied (as opposed to zero-copy) segment, so that adjacent copied
    segments can be merged into a single iovec.
    """

    # Same meaning as the identically-named BufferSize fields.
    min_size: int
    exp_size: int
    max_size: int
    max_copy: int
    max_copy_extra: str
    max_iov: int
    max_iov_extra: str

    # Bookkeeping used only while accumulating.
    tmp_starts_with_copy: bool
    tmp_ends_with_copy: bool

    def __init__(self) -> None:
        # Start empty: no bytes, no iovecs, no extra C expressions, and no
        # copied segment at either end.
        self.min_size = self.exp_size = self.max_size = 0
        self.max_copy = self.max_iov = 0
        self.max_copy_extra = self.max_iov_extra = ""
        self.tmp_starts_with_copy = self.tmp_ends_with_copy = False


def _get_buffer_size(typ: idl.Type, version: str) -> TmpBufferSize:
    """Compute buffer-size statistics for marshaling *typ* in *version*.

    For a non-Struct type every size field is ``typ.static_size`` and the
    data is a single copied iovec.  For a Struct, the members present in
    *version* are visited in order via idlutil.walk:

    * fixed-size (non-Struct) members add to the copied bytes, opening a
      new iovec if the previous segment was not a copy;
    * counted members (``child.cnt``) whose element size is 1 are treated
      as zero-copy segments: one iovec each, not counted in ``max_copy``;
    * other counted members are sized recursively and scaled by
      ``child.max_cnt`` (``exp_size`` uses heuristics instead);
    * SPECIAL (9P2000.e): ``wname`` in Tsread/Tswrite is sized in terms of
      CONFIG_9P_MAX_9P2000_e_WELEM via ``max_copy_extra``/``max_iov_extra``.

    The resulting min_size/max_size are asserted to agree with
    ``typ.min_size(version)`` / ``typ.max_size(version)``.
    """
    # Primitives are exempt from the version check; other types must exist
    # in *version*.
    assert isinstance(typ, idl.Primitive) or (version in typ.in_versions)

    ret = TmpBufferSize()

    if not isinstance(typ, idl.Struct):
        # Fixed-size scalar: one copied segment of static_size bytes.
        assert typ.static_size
        ret.min_size = typ.static_size
        ret.exp_size = typ.static_size
        ret.max_size = typ.static_size
        ret.max_copy = typ.static_size
        ret.max_iov = 1
        ret.tmp_starts_with_copy = True
        ret.tmp_ends_with_copy = True
        return ret

    # Visitor for idlutil.walk: looks at the member being visited (the last
    # element of *path*) and accumulates into `ret`.
    def handle(path: idlutil.Path) -> tuple[idlutil.WalkCmd, None]:
        nonlocal ret
        if path.elems:
            child = path.elems[-1]
            # Members absent from this version contribute nothing.
            if version not in child.in_versions:
                return idlutil.WalkCmd.DONT_RECURSE, None
            if child.cnt:
                if child.typ.static_size == 1:  # SPECIAL (zerocopy)
                    # Byte arrays get their own iovec and are not copied, so
                    # max_copy is untouched; min_size is untouched because
                    # the count may be 0.
                    ret.max_iov += 1
                    # HEURISTIC: 27 for strings (max-strlen from 9P1), 8KiB for other data
                    ret.exp_size += 27 if child.membname == "utf8" else 8192
                    ret.max_size += child.max_cnt
                    ret.tmp_ends_with_copy = False
                    return idlutil.WalkCmd.DONT_RECURSE, None
                # Array of non-byte elements: size one element, then scale.
                sub = _get_buffer_size(child.typ, version)
                ret.exp_size += sub.exp_size * 16  # HEURISTIC: MAXWELEM
                ret.max_size += sub.max_size * child.max_cnt
                if child.membname == "wname" and path.root.typname in (
                    "Tsread",
                    "Tswrite",
                ):  # SPECIAL (9P2000.e)
                    # These asserts pin down the layout the formula relies
                    # on: each element starts with a copy (which merges with
                    # the preceding copied segment) and ends with zero-copy
                    # data (so consecutive elements never merge).
                    assert ret.tmp_ends_with_copy
                    assert sub.tmp_starts_with_copy
                    assert not sub.tmp_ends_with_copy
                    # The element count is a config knob, so the array's
                    # contribution is expressed as a C expression suffix.
                    ret.max_copy_extra = (
                        f" + (CONFIG_9P_MAX_9P2000_e_WELEM * {sub.max_copy})"
                    )
                    ret.max_iov_extra = (
                        f" + (CONFIG_9P_MAX_9P2000_e_WELEM * {sub.max_iov})"
                    )
                    # The first element's leading copy merges with the
                    # preceding copied segment (config.h static_asserts
                    # WELEM > 0, so there is always a first element).
                    ret.max_iov -= 1
                else:
                    ret.max_copy += sub.max_copy * child.max_cnt
                    if sub.max_iov == 1 and sub.tmp_starts_with_copy:  # is purely copy
                        # All elements form one contiguous copied run.
                        ret.max_iov += 1
                    else:  # contains zero-copy segments
                        ret.max_iov += sub.max_iov * child.max_cnt
                    if ret.tmp_ends_with_copy and sub.tmp_starts_with_copy:
                        # we can merge this one
                        # (first element's leading copy joins the preceding copy)
                        ret.max_iov -= 1
                    if (
                        sub.tmp_ends_with_copy
                        and sub.tmp_starts_with_copy
                        and sub.max_iov > 1
                    ):
                        # we can merge these
                        # (each of the max_cnt-1 element boundaries joins a
                        # trailing copy with the next leading copy)
                        ret.max_iov -= child.max_cnt - 1
                # The data now ends the way the last element does.
                ret.tmp_ends_with_copy = sub.tmp_ends_with_copy
                # Elements were fully accounted for above; don't descend.
                return idlutil.WalkCmd.DONT_RECURSE, None
            if not isinstance(child.typ, idl.Struct):
                # Fixed-size scalar member: always copied.
                assert child.typ.static_size
                if not ret.tmp_ends_with_copy:
                    # Previous segment was zero-copy (or nothing yet), so
                    # this copy needs a fresh iovec.
                    if ret.max_size == 0:
                        # Nothing emitted before this: data starts with a copy.
                        ret.tmp_starts_with_copy = True
                    ret.max_iov += 1
                    ret.tmp_ends_with_copy = True
                ret.min_size += child.typ.static_size
                ret.exp_size += child.typ.static_size
                ret.max_size += child.typ.static_size
                ret.max_copy += child.typ.static_size
        # The root and non-counted Struct members are descended into.
        return idlutil.WalkCmd.KEEP_GOING, None

    idlutil.walk(typ, handle)
    # Cross-check the walk against the IDL's own size computations.
    assert ret.min_size == typ.min_size(version)
    assert ret.max_size == typ.max_size(version)
    return ret


def get_buffer_size(typ: idl.Type, version: str) -> BufferSize:
    """Return the immutable buffer-size summary of *typ* in *version*.

    This is _get_buffer_size() with its internal bookkeeping flags
    (tmp_starts_with_copy / tmp_ends_with_copy) dropped.
    """
    acc = _get_buffer_size(typ, version)
    return BufferSize._make(getattr(acc, field) for field in BufferSize._fields)


# Generate .h ##################################################################


def gen_h(versions: set[str], typs: list[idl.UserType]) -> str:
    cutil.ifdef_init()

    ret = f"""/* Generated by `{' '.join(sys.argv)}`.  DO NOT EDIT!  */

#ifndef _LIB9P_9P_H_
\t#error Do not include <lib9p/9p.generated.h> directly; include <lib9p/9p.h> instead
#endif

#include <stdint.h> /* for uint{{n}}_t types */

#include <libhw/generic/net.h> /* for struct iovec */
"""

    id2typ: dict[int, idl.Message] = {}
    for msg in [msg for msg in typs if isinstance(msg, idl.Message)]:
        id2typ[msg.msgid] = msg

    ret += """
/* config *********************************************************************/

#include "config.h"
"""
    for ver in sorted(versions):
        ret += "\n"
        ret += f"#ifndef {c9util.ver_ifdef({ver})}\n"
        ret += f"\t#error config.h must define {c9util.ver_ifdef({ver})}\n"
        if ver == "9P2000.e":  # SPECIAL (9P2000.e)
            ret += "#else\n"
            ret += f"\t#if {c9util.ver_ifdef({ver})}\n"
            ret += "\t\t#ifndef CONFIG_9P_MAX_9P2000_e_WELEM\n"
            ret += f"\t\t\t#error if {c9util.ver_ifdef({ver})} then config.h must define CONFIG_9P_MAX_9P2000_e_WELEM\n"
            ret += "\t\t#endif\n"
            ret += "\t\tstatic_assert(CONFIG_9P_MAX_9P2000_e_WELEM > 0);\n"
            ret += "\t#endif\n"
        ret += "#endif\n"

    ret += f"""
/* enum version ***************************************************************/

enum {c9util.ident('version')} {{
"""
    fullversions = ["unknown = 0", *sorted(versions)]
    verwidth = max(len(v) for v in fullversions)
    for ver in fullversions:
        if ver in versions:
            ret += cutil.ifdef_push(1, c9util.ver_ifdef({ver}))
        ret += f"\t{c9util.ver_enum(ver)},"
        ret += (" " * (verwidth - len(ver))) + ' /* "' + ver.split()[0] + '" */\n'
    ret += cutil.ifdef_pop(0)
    ret += f"\t{c9util.ver_enum('NUM')},\n"
    ret += "};\n"

    ret += """
/* enum msg_type **************************************************************/

"""
    ret += f"enum {c9util.ident('msg_type')} {{ /* uint8_t */\n"
    namewidth = max(len(msg.typname) for msg in typs if isinstance(msg, idl.Message))
    for n in range(0x100):
        if n not in id2typ:
            continue
        msg = id2typ[n]
        ret += cutil.ifdef_push(1, c9util.ver_ifdef(msg.in_versions))
        ret += f"\t{c9util.Ident(f'TYP_{msg.typname:<{namewidth}}')} = {msg.msgid},\n"
    ret += cutil.ifdef_pop(0)
    ret += "};\n"

    ret += """
/* payload types **************************************************************/
"""

    def per_version_comment(
        typ: idl.UserType, fn: typing.Callable[[idl.UserType, str], str]
    ) -> str:
        lines: dict[str, str] = {}
        for version in sorted(typ.in_versions):
            lines[version] = fn(typ, version)
        if len(set(lines.values())) == 1:
            for _, line in lines.items():
                return f"/* {line} */\n"
            assert False
        else:
            ret = ""
            v_width = max(len(c9util.ver_enum(v)) for v in typ.in_versions)
            for version, line in lines.items():
                ret += f"/* {c9util.ver_enum(version):<{v_width}}: {line} */\n"
            return ret

    for typ in idlutil.topo_sorted(typs):
        ret += "\n"
        ret += cutil.ifdef_push(1, c9util.ver_ifdef(typ.in_versions))

        def sum_size(typ: idl.UserType, version: str) -> str:
            sz = get_buffer_size(typ, version)
            assert (
                sz.min_size <= sz.exp_size
                and sz.exp_size <= sz.max_size
                and sz.max_size < cutil.UINT64_MAX
            )
            ret = ""
            if sz.min_size == sz.max_size:
                ret += f"size = {sz.min_size:,}"
            else:
                ret += f"min_size = {sz.min_size:,} ; exp_size = {sz.exp_size:,} ; max_size = {sz.max_size:,}"
            if sz.max_size > cutil.UINT32_MAX:
                ret += " (warning: >UINT32_MAX)"
            ret += f" ; max_iov = {sz.max_iov:,}{sz.max_iov_extra} ; max_copy = {sz.max_copy:,}{sz.max_copy_extra}"
            return ret

        ret += per_version_comment(typ, sum_size)

        match typ:
            case idl.Number():
                ret += gen_number(typ)
            case idl.Bitfield():
                ret += gen_bitfield(typ)
            case idl.Struct():  # and idl.Message():
                ret += gen_struct(typ)
    ret += cutil.ifdef_pop(0)

    ret += """
/* containers *****************************************************************/
"""
    ret += "\n"
    ret += f"#define {c9util.IDENT('_MAX')}(a, b) ((a) > (b)) ? (a) : (b)\n"

    tmsg_max_iov: dict[str, int] = {}
    tmsg_max_copy: dict[str, int] = {}
    rmsg_max_iov: dict[str, int] = {}
    rmsg_max_copy: dict[str, int] = {}
    for typ in typs:
        if not isinstance(typ, idl.Message):
            continue
        if typ.typname in ("Tsread", "Tswrite"):  # SPECIAL (9P2000.e)
            continue
        max_iov = tmsg_max_iov if typ.msgid % 2 == 0 else rmsg_max_iov
        max_copy = tmsg_max_copy if typ.msgid % 2 == 0 else rmsg_max_copy
        for version in typ.in_versions:
            if version not in max_iov:
                max_iov[version] = 0
                max_copy[version] = 0
            sz = get_buffer_size(typ, version)
            if sz.max_iov > max_iov[version]:
                max_iov[version] = sz.max_iov
            if sz.max_copy > max_copy[version]:
                max_copy[version] = sz.max_copy

    for name, table in [
        ("tmsg_max_iov", tmsg_max_iov),
        ("tmsg_max_copy", tmsg_max_copy),
        ("rmsg_max_iov", rmsg_max_iov),
        ("rmsg_max_copy", rmsg_max_copy),
    ]:
        inv: dict[int, set[str]] = {}
        for version, maxval in table.items():
            if maxval not in inv:
                inv[maxval] = set()
            inv[maxval].add(version)

        ret += "\n"
        directive = "if"
        seen_e = False  # SPECIAL (9P2000.e)
        for maxval in sorted(inv, reverse=True):
            ret += f"#{directive} {c9util.ver_ifdef(inv[maxval])}\n"
            indent = 1
            if name.startswith("tmsg") and not seen_e:  # SPECIAL (9P2000.e)
                typ = next(typ for typ in typs if typ.typname == "Tswrite")
                sz = get_buffer_size(typ, "9P2000.e")
                match name:
                    case "tmsg_max_iov":
                        maxexpr = f"{sz.max_iov}{sz.max_iov_extra}"
                    case "tmsg_max_copy":
                        maxexpr = f"{sz.max_copy}{sz.max_copy_extra}"
                    case _:
                        assert False
                ret += f"\t#if {c9util.ver_ifdef({"9P2000.e"})}\n"
                ret += f"\t\t#define {c9util.IDENT(name)} {c9util.IDENT('_MAX')}({maxval}, {maxexpr})\n"
                ret += "\t#else\n"
                indent += 1
            ret += f"{'\t'*indent}#define {c9util.IDENT(name)} {maxval}\n"
            if name.startswith("tmsg") and not seen_e:  # SPECIAL (9P2000.e)
                ret += "\t#endif\n"
                if "9P2000.e" in inv[maxval]:
                    seen_e = True
            directive = "elif"
        ret += "#endif\n"

    ret += "\n"
    ret += f"struct {c9util.ident('Tmsg_send_buf')} {{\n"
    ret += "\tsize_t          iov_cnt;\n"
    ret += f"\tstruct iovec    iov[{c9util.IDENT('TMSG_MAX_IOV')}];\n"
    ret += f"\tuint8_t         copied[{c9util.IDENT('TMSG_MAX_COPY')}];\n"
    ret += "};\n"

    ret += "\n"
    ret += f"struct {c9util.ident('Rmsg_send_buf')} {{\n"
    ret += "\tsize_t          iov_cnt;\n"
    ret += f"\tstruct iovec    iov[{c9util.IDENT('RMSG_MAX_IOV')}];\n"
    ret += f"\tuint8_t         copied[{c9util.IDENT('RMSG_MAX_COPY')}];\n"
    ret += "};\n"

    return ret


def gen_number(typ: idl.Number) -> str:
    """Return the C declarations for an IDL ``num`` type.

    Emits ``typedef <prim> <name>;`` followed by one ``#define`` per named
    value in ``typ.vals``.  Each name is prefixed with the upper-cased type
    name (via c9util.add_prefix), and each value is cast to the new typedef.
    """
    ret = f"typedef {c9util.typename(typ.prim)} {c9util.typename(typ)};\n"

    def lookup_sym(sym: str) -> str:
        # Values of a num are not expected to reference other symbols.
        # Raise explicitly: a bare `assert False` is stripped under
        # `python -O`, which would silently hand None to idl_expr.
        raise AssertionError(
            f"unexpected symbol {sym!r} in a value of num {typ.typname}"
        )

    def cname(base: str) -> str:
        prefix = f"{typ.typname}_".upper()
        return c9util.Ident(c9util.add_prefix(prefix, base))

    # default=0: a num with no named values still gets its typedef rather
    # than crashing max() on an empty sequence.
    namewidth = max((len(cname(name)) for name in typ.vals), default=0)
    for name, val in typ.vals.items():
        c_name = cname(name)
        c_val = c9util.idl_expr(val, lookup_sym)
        ret += f"#define {c_name:<{namewidth}} (({c9util.typename(typ)})({c_val}))\n"
    return ret


def gen_bitfield(typ: idl.Bitfield) -> str:
    ret = f"typedef {c9util.typename(typ.prim)} {c9util.typename(typ)};\n"

    def lookup_sym(sym: str) -> str:
        assert False

    # There are 4 parts here: bits, aliases, masks, and numbers.

    # 1. bits

    def bitname(bit: idl.Bit) -> str:
        prefix = f"{typ.typname}_".upper()
        base = bit.bitname
        match bit:
            case idl.Bit(cat="RESERVED"):
                base = "_RESERVED_" + base
            case idl.Bit(cat=idl.BitNum()):
                base += "_*"
            case idl.Bit(cat="UNUSED"):
                base = f"_UNUSED_{bit.num}"
        return c9util.Ident(c9util.add_prefix(prefix, base))

    namewidth = max(len(bitname(bit)) for bit in typ.bits)

    ret += "/* bits */\n"
    for bit in reversed(typ.bits):
        vers = bit.in_versions
        if bit.cat == "UNUSED":
            vers = typ.in_versions
        ret += cutil.ifdef_push(2, c9util.ver_ifdef(vers))

        # It is important all of the `beg` strings have
        # the same length.
        end = ""
        match bit.cat:
            case "USED" | "RESERVED" | "UNUSED":
                if cutil.ifdef_leaf_is_noop():
                    beg = "#define  "
                else:
                    beg = "#  define"
            case idl.BitNum():
                beg = "/* number"
                end = " */"

        c_name = bitname(bit)
        c_val = f"UINT{typ.static_size*8}_C(1)<<{bit.num}"
        ret += (
            f"{beg} {c_name:<{namewidth}}  (({c9util.typename(typ)})({c_val})){end}\n"
        )
    ret += cutil.ifdef_pop(1)

    # 2. aliases
    if typ.aliases:

        def aliasname(alias: idl.BitAlias) -> str:
            prefix = f"{typ.typname}_".upper()
            base = alias.bitname
            return c9util.Ident(c9util.add_prefix(prefix, base))

        ret += "/* aliases */\n"
        for alias in typ.aliases.values():
            ret += cutil.ifdef_push(2, c9util.ver_ifdef(alias.in_versions))

            end = ""
            if cutil.ifdef_leaf_is_noop():
                beg = "#define  "
            else:
                beg = "#  define"

            c_name = aliasname(alias)
            c_val = c9util.idl_expr(alias.val, lookup_sym)
            ret += f"{beg} {c_name:<{namewidth}}  (({c9util.typename(typ)})({c_val})){end}\n"

        ret += cutil.ifdef_pop(1)

    # 3. masks
    if typ.masks:

        def maskname(mask: idl.BitAlias) -> str:
            prefix = f"{typ.typname}_".upper()
            base = mask.bitname
            return c9util.Ident(c9util.add_prefix(prefix, base) + "_MASK")

        ret += "/* masks */\n"
        for mask in typ.masks.values():
            ret += cutil.ifdef_push(2, c9util.ver_ifdef(mask.in_versions))

            end = ""
            if cutil.ifdef_leaf_is_noop():
                beg = "#define  "
            else:
                beg = "#  define"

            c_name = maskname(mask)
            c_val = c9util.idl_expr(mask.val, lookup_sym, bitwidth=typ.static_size * 8)
            ret += f"{beg} {c_name:<{namewidth}}  (({c9util.typename(typ)})({c_val})){end}\n"

        ret += cutil.ifdef_pop(1)

    # 4. numbers
    def numname(num: idl.BitNum, base: str) -> str:
        prefix = f"{typ.typname}_{num.numname}_".upper()
        return c9util.Ident(c9util.add_prefix(prefix, base))

    for num in typ.nums.values():
        namewidth = max(
            len(numname(num, base))
            for base in [
                *[alias.bitname for alias in num.vals.values()],
                "MASK",
            ]
        )
        ret += f"/* number: {num.numname} */\n"
        for alias in num.vals.values():
            ret += cutil.ifdef_push(2, c9util.ver_ifdef(alias.in_versions))

            end = ""
            if cutil.ifdef_leaf_is_noop():
                beg = "#define  "
            else:
                beg = "#  define"

            c_name = numname(num, alias.bitname)
            c_val = c9util.idl_expr(alias.val, lookup_sym)
            ret += f"{beg} {c_name:<{namewidth}}  (({c9util.typename(typ)})({c_val})){end}\n"
        ret += cutil.ifdef_pop(1)
        c_name = numname(num, "MASK")
        c_val = f"{num.mask:#0{typ.static_size*8}b}"
        ret += (
            f"{beg} {c_name:<{namewidth}}  (({c9util.typename(typ)})({c_val})){end}\n"
        )

    return ret


def gen_struct(typ: idl.Struct) -> str:  # and idl.Message
    """Return the C struct definition for *typ* (a Struct or a Message).

    Members with a fixed value (``member.val``) are left out of the C
    struct; counted members (``member.cnt``) are emitted as pointers.
    Each member is version-guarded through cutil's ifdef stack at depth 2.
    Type names are padded so that member names line up.
    """
    parts = [c9util.typename(typ), " {"]
    members = typ.members
    if members:
        parts.append("\n")
        width = max(len(c9util.typename(m.typ, m)) for m in members)
        for m in members:
            if not m.val:
                parts.append(cutil.ifdef_push(2, c9util.ver_ifdef(m.in_versions)))
                ctype = c9util.typename(m.typ, m)
                star = "*" if m.cnt else " "
                parts.append(f"\t{ctype:<{width}}  {star}{m.membname};\n")
        parts.append(cutil.ifdef_pop(1))
    parts.append("};\n")
    return "".join(parts)