From 2a70a611558daa248e4fc1a11a9aa0ceb3ed397a Mon Sep 17 00:00:00 2001
From: "Luke T. Shumaker" <lukeshu@lukeshu.com>
Date: Sun, 23 Mar 2025 02:09:30 -0600
Subject: lib9p: protogen: pull h.py out of __init__.py

---
 lib9p/protogen/h.py | 447 ++++++++++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 447 insertions(+)
 create mode 100644 lib9p/protogen/h.py

(limited to 'lib9p/protogen/h.py')

diff --git a/lib9p/protogen/h.py b/lib9p/protogen/h.py
new file mode 100644
index 0000000..7785ca1
--- /dev/null
+++ b/lib9p/protogen/h.py
@@ -0,0 +1,447 @@
+# lib9p/protogen/h.py - Generate 9p.generated.h
+#
+# Copyright (C) 2024-2025  Luke T. Shumaker <lukeshu@lukeshu.com>
+# SPDX-License-Identifier: AGPL-3.0-or-later
+
+import sys
+import typing
+
+import idl
+
+from . import c9util, cutil, idlutil
+
+# This strives to be "general-purpose" in that it just acts on the
+# *.9p inputs; but (unfortunately?) there are a few special-cases in
+# this script, marked with "SPECIAL".
+
+# pylint: disable=unused-variable
+__all__ = ["gen_h"]
+
+# get_buffer_size() ############################################################
+
+
+class BufferSize(typing.NamedTuple):
+    min_size: int  # really just here to sanity-check against typ.min_size(version)
+    exp_size: int  # "expected" or max-reasonable size
+    max_size: int  # really just here to sanity-check against typ.max_size(version)
+    max_copy: int
+    max_copy_extra: str
+    max_iov: int
+    max_iov_extra: str
+
+
+class TmpBufferSize:
+    min_size: int
+    exp_size: int
+    max_size: int
+    max_copy: int
+    max_copy_extra: str
+    max_iov: int
+    max_iov_extra: str
+
+    tmp_starts_with_copy: bool
+    tmp_ends_with_copy: bool
+
+    def __init__(self) -> None:
+        self.min_size = 0
+        self.exp_size = 0
+        self.max_size = 0
+        self.max_copy = 0
+        self.max_copy_extra = ""
+        self.max_iov = 0
+        self.max_iov_extra = ""
+        self.tmp_starts_with_copy = False
+        self.tmp_ends_with_copy = False
+
+
+def _get_buffer_size(typ: idl.Type, version: str) -> TmpBufferSize:
+    assert isinstance(typ, idl.Primitive) or (version in typ.in_versions)
+
+    ret = TmpBufferSize()
+
+    if not isinstance(typ, idl.Struct):
+        assert typ.static_size
+        ret.min_size = typ.static_size
+        ret.exp_size = typ.static_size
+        ret.max_size = typ.static_size
+        ret.max_copy = typ.static_size
+        ret.max_iov = 1
+        ret.tmp_starts_with_copy = True
+        ret.tmp_ends_with_copy = True
+        return ret
+
+    def handle(path: idlutil.Path) -> tuple[idlutil.WalkCmd, None]:
+        nonlocal ret
+        if path.elems:
+            child = path.elems[-1]
+            if version not in child.in_versions:
+                return idlutil.WalkCmd.DONT_RECURSE, None
+            if child.cnt:
+                if child.typ.static_size == 1:  # SPECIAL (zerocopy)
+                    ret.max_iov += 1
+                    # HEURISTIC: 27 for strings (max-strlen from 9P1), 8KiB for other data
+                    ret.exp_size += 27 if child.membname == "utf8" else 8192
+                    ret.max_size += child.max_cnt
+                    ret.tmp_ends_with_copy = False
+                    return idlutil.WalkCmd.DONT_RECURSE, None
+                sub = _get_buffer_size(child.typ, version)
+                ret.exp_size += sub.exp_size * 16  # HEURISTIC: MAXWELEM
+                ret.max_size += sub.max_size * child.max_cnt
+                if child.membname == "wname" and path.root.typname in (
+                    "Tsread",
+                    "Tswrite",
+                ):  # SPECIAL (9P2000.e)
+                    assert ret.tmp_ends_with_copy
+                    assert sub.tmp_starts_with_copy
+                    assert not sub.tmp_ends_with_copy
+                    ret.max_copy_extra = (
+                        f" + (CONFIG_9P_MAX_9P2000_e_WELEM * {sub.max_copy})"
+                    )
+                    ret.max_iov_extra = (
+                        f" + (CONFIG_9P_MAX_9P2000_e_WELEM * {sub.max_iov})"
+                    )
+                    ret.max_iov -= 1
+                else:
+                    ret.max_copy += sub.max_copy * child.max_cnt
+                    if sub.max_iov == 1 and sub.tmp_starts_with_copy:  # is purely copy
+                        ret.max_iov += 1
+                    else:  # contains zero-copy segments
+                        ret.max_iov += sub.max_iov * child.max_cnt
+                    if ret.tmp_ends_with_copy and sub.tmp_starts_with_copy:
+                        # we can merge this one
+                        ret.max_iov -= 1
+                    if (
+                        sub.tmp_ends_with_copy
+                        and sub.tmp_starts_with_copy
+                        and sub.max_iov > 1
+                    ):
+                        # we can merge these
+                        ret.max_iov -= child.max_cnt - 1
+                ret.tmp_ends_with_copy = sub.tmp_ends_with_copy
+                return idlutil.WalkCmd.DONT_RECURSE, None
+            if not isinstance(child.typ, idl.Struct):
+                assert child.typ.static_size
+                if not ret.tmp_ends_with_copy:
+                    if ret.max_size == 0:
+                        ret.tmp_starts_with_copy = True
+                    ret.max_iov += 1
+                    ret.tmp_ends_with_copy = True
+                ret.min_size += child.typ.static_size
+                ret.exp_size += child.typ.static_size
+                ret.max_size += child.typ.static_size
+                ret.max_copy += child.typ.static_size
+        return idlutil.WalkCmd.KEEP_GOING, None
+
+    idlutil.walk(typ, handle)
+    assert ret.min_size == typ.min_size(version)
+    assert ret.max_size == typ.max_size(version)
+    return ret
+
+
+def get_buffer_size(typ: idl.Type, version: str) -> BufferSize:
+    tmp = _get_buffer_size(typ, version)
+    return BufferSize(
+        min_size=tmp.min_size,
+        exp_size=tmp.exp_size,
+        max_size=tmp.max_size,
+        max_copy=tmp.max_copy,
+        max_copy_extra=tmp.max_copy_extra,
+        max_iov=tmp.max_iov,
+        max_iov_extra=tmp.max_iov_extra,
+    )
+
+
+# Generate .h ##################################################################
+
+
+def gen_h(versions: set[str], typs: list[idl.UserType]) -> str:
+    cutil.ifdef_init()
+
+    ret = f"""/* Generated by `{' '.join(sys.argv)}`.  DO NOT EDIT!  */
+
+#ifndef _LIB9P_9P_H_
+\t#error Do not include <lib9p/9p.generated.h> directly; include <lib9p/9p.h> instead
+#endif
+
+#include <stdint.h> /* for uint{{n}}_t types */
+
+#include <libhw/generic/net.h> /* for struct iovec */
+"""
+
+    id2typ: dict[int, idl.Message] = {}
+    for msg in [msg for msg in typs if isinstance(msg, idl.Message)]:
+        id2typ[msg.msgid] = msg
+
+    ret += """
+/* config *********************************************************************/
+
+#include "config.h"
+"""
+    for ver in sorted(versions):
+        ret += "\n"
+        ret += f"#ifndef {c9util.ver_ifdef({ver})}\n"
+        ret += f"\t#error config.h must define {c9util.ver_ifdef({ver})}\n"
+        if ver == "9P2000.e":  # SPECIAL (9P2000.e)
+            ret += "#else\n"
+            ret += f"\t#if {c9util.ver_ifdef({ver})}\n"
+            ret += "\t\t#ifndef CONFIG_9P_MAX_9P2000_e_WELEM\n"
+            ret += f"\t\t\t#error if {c9util.ver_ifdef({ver})} then config.h must define CONFIG_9P_MAX_9P2000_e_WELEM\n"
+            ret += "\t\t#endif\n"
+            ret += "\t\tstatic_assert(CONFIG_9P_MAX_9P2000_e_WELEM > 0);\n"
+            ret += "\t#endif\n"
+        ret += "#endif\n"
+
+    ret += f"""
+/* enum version ***************************************************************/
+
+enum {c9util.ident('version')} {{
+"""
+    fullversions = ["unknown = 0", *sorted(versions)]
+    verwidth = max(len(v) for v in fullversions)
+    for ver in fullversions:
+        if ver in versions:
+            ret += cutil.ifdef_push(1, c9util.ver_ifdef({ver}))
+        ret += f"\t{c9util.ver_enum(ver)},"
+        ret += (" " * (verwidth - len(ver))) + ' /* "' + ver.split()[0] + '" */\n'
+    ret += cutil.ifdef_pop(0)
+    ret += f"\t{c9util.ver_enum('NUM')},\n"
+    ret += "};\n"
+
+    ret += """
+/* enum msg_type **************************************************************/
+
+"""
+    ret += f"enum {c9util.ident('msg_type')} {{ /* uint8_t */\n"
+    namewidth = max(len(msg.typname) for msg in typs if isinstance(msg, idl.Message))
+    for n in range(0x100):
+        if n not in id2typ:
+            continue
+        msg = id2typ[n]
+        ret += cutil.ifdef_push(1, c9util.ver_ifdef(msg.in_versions))
+        ret += f"\t{c9util.Ident(f'TYP_{msg.typname:<{namewidth}}')} = {msg.msgid},\n"
+    ret += cutil.ifdef_pop(0)
+    ret += "};\n"
+
+    ret += """
+/* payload types **************************************************************/
+"""
+
+    def per_version_comment(
+        typ: idl.UserType, fn: typing.Callable[[idl.UserType, str], str]
+    ) -> str:
+        lines: dict[str, str] = {}
+        for version in sorted(typ.in_versions):
+            lines[version] = fn(typ, version)
+        if len(set(lines.values())) == 1:
+            for _, line in lines.items():
+                return f"/* {line} */\n"
+            assert False
+        else:
+            ret = ""
+            v_width = max(len(c9util.ver_enum(v)) for v in typ.in_versions)
+            for version, line in lines.items():
+                ret += f"/* {c9util.ver_enum(version):<{v_width}}: {line} */\n"
+            return ret
+
+    for typ in idlutil.topo_sorted(typs):
+        ret += "\n"
+        ret += cutil.ifdef_push(1, c9util.ver_ifdef(typ.in_versions))
+
+        def sum_size(typ: idl.UserType, version: str) -> str:
+            sz = get_buffer_size(typ, version)
+            assert (
+                sz.min_size <= sz.exp_size
+                and sz.exp_size <= sz.max_size
+                and sz.max_size < cutil.UINT64_MAX
+            )
+            ret = ""
+            if sz.min_size == sz.max_size:
+                ret += f"size = {sz.min_size:,}"
+            else:
+                ret += f"min_size = {sz.min_size:,} ; exp_size = {sz.exp_size:,} ; max_size = {sz.max_size:,}"
+            if sz.max_size > cutil.UINT32_MAX:
+                ret += " (warning: >UINT32_MAX)"
+            ret += f" ; max_iov = {sz.max_iov:,}{sz.max_iov_extra} ; max_copy = {sz.max_copy:,}{sz.max_copy_extra}"
+            return ret
+
+        ret += per_version_comment(typ, sum_size)
+
+        match typ:
+            case idl.Number():
+                ret += f"typedef {c9util.typename(typ.prim)} {c9util.typename(typ)};\n"
+                prefix = f"{c9util.IDENT(typ.typname)}_"
+                namewidth = max(len(name) for name in typ.vals)
+                for name, val in typ.vals.items():
+                    ret += f"#define {prefix}{name:<{namewidth}} (({c9util.typename(typ)})UINT{typ.static_size*8}_C({val}))\n"
+            case idl.Bitfield():
+                ret += f"typedef {c9util.typename(typ.prim)} {c9util.typename(typ)};\n"
+
+                def bitname(val: idl.Bit | idl.BitAlias) -> str:
+                    s = val.bitname
+                    match val:
+                        case idl.Bit(cat=idl.BitCat.RESERVED):
+                            s = "_RESERVED_" + s
+                        case idl.Bit(cat=idl.BitCat.SUBFIELD):
+                            assert isinstance(typ, idl.Bitfield)
+                            n = sum(
+                                1
+                                for b in typ.bits[: val.num]
+                                if b.cat == idl.BitCat.SUBFIELD
+                                and b.bitname == val.bitname
+                            )
+                            s = f"_{s}_{n}"
+                        case idl.Bit(cat=idl.BitCat.UNUSED):
+                            return ""
+                    return c9util.Ident(c9util.add_prefix(typ.typname.upper() + "_", s))
+
+                namewidth = max(
+                    len(bitname(val)) for val in [*typ.bits, *typ.names.values()]
+                )
+
+                ret += "\n"
+                for bit in reversed(typ.bits):
+                    vers = bit.in_versions
+                    if bit.cat == idl.BitCat.UNUSED:
+                        vers = typ.in_versions
+                    ret += cutil.ifdef_push(2, c9util.ver_ifdef(vers))
+
+                    # It is important all of the `beg` strings have
+                    # the same length.
+                    end = ""
+                    match bit.cat:
+                        case (
+                            idl.BitCat.USED | idl.BitCat.RESERVED | idl.BitCat.SUBFIELD
+                        ):
+                            if cutil.ifdef_leaf_is_noop():
+                                beg = "#define  "
+                            else:
+                                beg = "#  define"
+                        case idl.BitCat.UNUSED:
+                            beg = "/* unused"
+                            end = " */"
+
+                    c_name = bitname(bit)
+                    c_val = f"1<<{bit.num}"
+                    ret += f"{beg} {c_name:<{namewidth}}  (({c9util.typename(typ)})({c_val})){end}\n"
+                if aliases := [
+                    alias
+                    for alias in typ.names.values()
+                    if isinstance(alias, idl.BitAlias)
+                ]:
+                    ret += "\n"
+
+                    for alias in aliases:
+                        ret += cutil.ifdef_push(2, c9util.ver_ifdef(alias.in_versions))
+
+                        end = ""
+                        if cutil.ifdef_leaf_is_noop():
+                            beg = "#define  "
+                        else:
+                            beg = "#  define"
+
+                        c_name = bitname(alias)
+                        c_val = alias.val
+                        ret += f"{beg} {c_name:<{namewidth}}  (({c9util.typename(typ)})({c_val})){end}\n"
+                ret += cutil.ifdef_pop(1)
+                del bitname
+            case idl.Struct():  # and idl.Message():
+                ret += c9util.typename(typ) + " {"
+                if not typ.members:
+                    ret += "};\n"
+                    continue
+                ret += "\n"
+
+                typewidth = max(len(c9util.typename(m.typ, m)) for m in typ.members)
+
+                for member in typ.members:
+                    if member.val:
+                        continue
+                    ret += cutil.ifdef_push(2, c9util.ver_ifdef(member.in_versions))
+                    ret += f"\t{c9util.typename(member.typ, member):<{typewidth}}  {'*' if member.cnt else ' '}{member.membname};\n"
+                ret += cutil.ifdef_pop(1)
+                ret += "};\n"
+        del typ
+    ret += cutil.ifdef_pop(0)
+
+    ret += """
+/* containers *****************************************************************/
+"""
+    ret += "\n"
+    ret += f"#define {c9util.IDENT('_MAX')}(a, b) ((a) > (b)) ? (a) : (b)\n"
+
+    tmsg_max_iov: dict[str, int] = {}
+    tmsg_max_copy: dict[str, int] = {}
+    rmsg_max_iov: dict[str, int] = {}
+    rmsg_max_copy: dict[str, int] = {}
+    for typ in typs:
+        if not isinstance(typ, idl.Message):
+            continue
+        if typ.typname in ("Tsread", "Tswrite"):  # SPECIAL (9P2000.e)
+            continue
+        max_iov = tmsg_max_iov if typ.msgid % 2 == 0 else rmsg_max_iov
+        max_copy = tmsg_max_copy if typ.msgid % 2 == 0 else rmsg_max_copy
+        for version in typ.in_versions:
+            if version not in max_iov:
+                max_iov[version] = 0
+                max_copy[version] = 0
+            sz = get_buffer_size(typ, version)
+            if sz.max_iov > max_iov[version]:
+                max_iov[version] = sz.max_iov
+            if sz.max_copy > max_copy[version]:
+                max_copy[version] = sz.max_copy
+
+    for name, table in [
+        ("tmsg_max_iov", tmsg_max_iov),
+        ("tmsg_max_copy", tmsg_max_copy),
+        ("rmsg_max_iov", rmsg_max_iov),
+        ("rmsg_max_copy", rmsg_max_copy),
+    ]:
+        inv: dict[int, set[str]] = {}
+        for version, maxval in table.items():
+            if maxval not in inv:
+                inv[maxval] = set()
+            inv[maxval].add(version)
+
+        ret += "\n"
+        directive = "if"
+        seen_e = False  # SPECIAL (9P2000.e)
+        for maxval in sorted(inv, reverse=True):
+            ret += f"#{directive} {c9util.ver_ifdef(inv[maxval])}\n"
+            indent = 1
+            if name.startswith("tmsg") and not seen_e:  # SPECIAL (9P2000.e)
+                typ = next(typ for typ in typs if typ.typname == "Tswrite")
+                sz = get_buffer_size(typ, "9P2000.e")
+                match name:
+                    case "tmsg_max_iov":
+                        maxexpr = f"{sz.max_iov}{sz.max_iov_extra}"
+                    case "tmsg_max_copy":
+                        maxexpr = f"{sz.max_copy}{sz.max_copy_extra}"
+                    case _:
+                        assert False
+                ret += f"\t#if {c9util.ver_ifdef({"9P2000.e"})}\n"
+                ret += f"\t\t#define {c9util.IDENT(name)} {c9util.IDENT('_MAX')}({maxval}, {maxexpr})\n"
+                ret += "\t#else\n"
+                indent += 1
+            ret += f"{'\t'*indent}#define {c9util.IDENT(name)} {maxval}\n"
+            if name.startswith("tmsg") and not seen_e:  # SPECIAL (9P2000.e)
+                ret += "\t#endif\n"
+                if "9P2000.e" in inv[maxval]:
+                    seen_e = True
+            directive = "elif"
+        ret += "#endif\n"
+
+    ret += "\n"
+    ret += f"struct {c9util.ident('Tmsg_send_buf')} {{\n"
+    ret += "\tsize_t          iov_cnt;\n"
+    ret += f"\tstruct iovec    iov[{c9util.IDENT('TMSG_MAX_IOV')}];\n"
+    ret += f"\tuint8_t         copied[{c9util.IDENT('TMSG_MAX_COPY')}];\n"
+    ret += "};\n"
+
+    ret += "\n"
+    ret += f"struct {c9util.ident('Rmsg_send_buf')} {{\n"
+    ret += "\tsize_t          iov_cnt;\n"
+    ret += f"\tstruct iovec    iov[{c9util.IDENT('RMSG_MAX_IOV')}];\n"
+    ret += f"\tuint8_t         copied[{c9util.IDENT('RMSG_MAX_COPY')}];\n"
+    ret += "};\n"
+
+    return ret
-- 
cgit v1.2.3-2-g168b