# lib9p/protogen/h.py - Generate 9p.generated.h # # Copyright (C) 2024-2025 Luke T. Shumaker # SPDX-License-Identifier: AGPL-3.0-or-later import sys import typing import idl from . import c9util, cutil, idlutil # This strives to be "general-purpose" in that it just acts on the # *.9p inputs; but (unfortunately?) there are a few special-cases in # this script, marked with "SPECIAL". # pylint: disable=unused-variable __all__ = ["gen_h"] # get_buffer_size() ############################################################ class BufferSize(typing.NamedTuple): min_size: int # really just here to sanity-check against typ.min_size(version) exp_size: int # "expected" or max-reasonable size max_size: int # really just here to sanity-check against typ.max_size(version) max_copy: int max_copy_extra: str max_iov: int max_iov_extra: str class TmpBufferSize: min_size: int exp_size: int max_size: int max_copy: int max_copy_extra: str max_iov: int max_iov_extra: str tmp_starts_with_copy: bool tmp_ends_with_copy: bool def __init__(self) -> None: self.min_size = 0 self.exp_size = 0 self.max_size = 0 self.max_copy = 0 self.max_copy_extra = "" self.max_iov = 0 self.max_iov_extra = "" self.tmp_starts_with_copy = False self.tmp_ends_with_copy = False def _get_buffer_size(typ: idl.Type, version: str) -> TmpBufferSize: assert isinstance(typ, idl.Primitive) or (version in typ.in_versions) ret = TmpBufferSize() if not isinstance(typ, idl.Struct): assert typ.static_size ret.min_size = typ.static_size ret.exp_size = typ.static_size ret.max_size = typ.static_size ret.max_copy = typ.static_size ret.max_iov = 1 ret.tmp_starts_with_copy = True ret.tmp_ends_with_copy = True return ret def handle(path: idlutil.Path) -> tuple[idlutil.WalkCmd, None]: nonlocal ret if path.elems: child = path.elems[-1] if version not in child.in_versions: return idlutil.WalkCmd.DONT_RECURSE, None if child.cnt: if child.typ.static_size == 1: # SPECIAL (zerocopy) ret.max_iov += 1 # HEURISTIC: 27 for strings (max-strlen from 9P1), 8KiB for other data ret.exp_size += 27 if child.membname == "utf8" else 8192 ret.max_size += child.max_cnt ret.tmp_ends_with_copy = False return idlutil.WalkCmd.DONT_RECURSE, None sub = _get_buffer_size(child.typ, version) ret.exp_size += sub.exp_size * 16 # HEURISTIC: MAXWELEM ret.max_size += sub.max_size * child.max_cnt if child.membname == "wname" and path.root.typname in ( "Tsread", "Tswrite", ): # SPECIAL (9P2000.e) assert ret.tmp_ends_with_copy assert sub.tmp_starts_with_copy assert not sub.tmp_ends_with_copy ret.max_copy_extra = ( f" + (CONFIG_9P_MAX_9P2000_e_WELEM * {sub.max_copy})" ) ret.max_iov_extra = ( f" + (CONFIG_9P_MAX_9P2000_e_WELEM * {sub.max_iov})" ) ret.max_iov -= 1 else: ret.max_copy += sub.max_copy * child.max_cnt if sub.max_iov == 1 and sub.tmp_starts_with_copy: # is purely copy ret.max_iov += 1 else: # contains zero-copy segments ret.max_iov += sub.max_iov * child.max_cnt if ret.tmp_ends_with_copy and sub.tmp_starts_with_copy: # we can merge this one ret.max_iov -= 1 if ( sub.tmp_ends_with_copy and sub.tmp_starts_with_copy and sub.max_iov > 1 ): # we can merge these ret.max_iov -= child.max_cnt - 1 ret.tmp_ends_with_copy = sub.tmp_ends_with_copy return idlutil.WalkCmd.DONT_RECURSE, None if not isinstance(child.typ, idl.Struct): assert child.typ.static_size if not ret.tmp_ends_with_copy: if ret.max_size == 0: ret.tmp_starts_with_copy = True ret.max_iov += 1 ret.tmp_ends_with_copy = True ret.min_size += child.typ.static_size ret.exp_size += child.typ.static_size ret.max_size += child.typ.static_size ret.max_copy += child.typ.static_size return idlutil.WalkCmd.KEEP_GOING, None idlutil.walk(typ, handle) assert ret.min_size == typ.min_size(version) assert ret.max_size == typ.max_size(version) return ret def get_buffer_size(typ: idl.Type, version: str) -> BufferSize: tmp = _get_buffer_size(typ, version) return BufferSize( min_size=tmp.min_size, exp_size=tmp.exp_size, max_size=tmp.max_size, max_copy=tmp.max_copy, max_copy_extra=tmp.max_copy_extra, max_iov=tmp.max_iov, max_iov_extra=tmp.max_iov_extra, ) # Generate .h ################################################################## def gen_h(versions: set[str], typs: list[idl.UserType]) -> str: cutil.ifdef_init() ret = f"""/* Generated by `{' '.join(sys.argv)}`. DO NOT EDIT! */ #ifndef _LIB9P_9P_H_ \t#error Do not include directly; include instead #endif #include /* for uint{{n}}_t types */ #include /* for struct iovec */ """ id2typ: dict[int, idl.Message] = {} for msg in [msg for msg in typs if isinstance(msg, idl.Message)]: id2typ[msg.msgid] = msg ret += """ /* config *********************************************************************/ #include "config.h" """ for ver in sorted(versions): ret += "\n" ret += f"#ifndef {c9util.ver_ifdef({ver})}\n" ret += f"\t#error config.h must define {c9util.ver_ifdef({ver})}\n" if ver == "9P2000.e": # SPECIAL (9P2000.e) ret += "#else\n" ret += f"\t#if {c9util.ver_ifdef({ver})}\n" ret += "\t\t#ifndef CONFIG_9P_MAX_9P2000_e_WELEM\n" ret += f"\t\t\t#error if {c9util.ver_ifdef({ver})} then config.h must define CONFIG_9P_MAX_9P2000_e_WELEM\n" ret += "\t\t#endif\n" ret += "\t\tstatic_assert(CONFIG_9P_MAX_9P2000_e_WELEM > 0);\n" ret += "\t#endif\n" ret += "#endif\n" ret += f""" /* enum version ***************************************************************/ enum {c9util.ident('version')} {{ """ fullversions = ["unknown = 0", *sorted(versions)] verwidth = max(len(v) for v in fullversions) for ver in fullversions: if ver in versions: ret += cutil.ifdef_push(1, c9util.ver_ifdef({ver})) ret += f"\t{c9util.ver_enum(ver)}," ret += (" " * (verwidth - len(ver))) + ' /* "' + ver.split()[0] + '" */\n' ret += cutil.ifdef_pop(0) ret += f"\t{c9util.ver_enum('NUM')},\n" ret += "};\n" ret += """ /* enum msg_type **************************************************************/ """ ret += f"enum {c9util.ident('msg_type')} {{ /* uint8_t */\n" namewidth = max(len(msg.typname) for msg in typs if isinstance(msg, idl.Message)) for n in range(0x100): if n not in id2typ: continue msg = id2typ[n] ret += cutil.ifdef_push(1, c9util.ver_ifdef(msg.in_versions)) ret += f"\t{c9util.Ident(f'TYP_{msg.typname:<{namewidth}}')} = {msg.msgid},\n" ret += cutil.ifdef_pop(0) ret += "};\n" ret += """ /* payload types **************************************************************/ """ def per_version_comment( typ: idl.UserType, fn: typing.Callable[[idl.UserType, str], str] ) -> str: lines: dict[str, str] = {} for version in sorted(typ.in_versions): lines[version] = fn(typ, version) if len(set(lines.values())) == 1: for _, line in lines.items(): return f"/* {line} */\n" assert False else: ret = "" v_width = max(len(c9util.ver_enum(v)) for v in typ.in_versions) for version, line in lines.items(): ret += f"/* {c9util.ver_enum(version):<{v_width}}: {line} */\n" return ret for typ in idlutil.topo_sorted(typs): ret += "\n" ret += cutil.ifdef_push(1, c9util.ver_ifdef(typ.in_versions)) def sum_size(typ: idl.UserType, version: str) -> str: sz = get_buffer_size(typ, version) assert ( sz.min_size <= sz.exp_size and sz.exp_size <= sz.max_size and sz.max_size < cutil.UINT64_MAX ) ret = "" if sz.min_size == sz.max_size: ret += f"size = {sz.min_size:,}" else: ret += f"min_size = {sz.min_size:,} ; exp_size = {sz.exp_size:,} ; max_size = {sz.max_size:,}" if sz.max_size > cutil.UINT32_MAX: ret += " (warning: >UINT32_MAX)" ret += f" ; max_iov = {sz.max_iov:,}{sz.max_iov_extra} ; max_copy = {sz.max_copy:,}{sz.max_copy_extra}" return ret ret += per_version_comment(typ, sum_size) match typ: case idl.Number(): ret += gen_number(typ) case idl.Bitfield(): ret += gen_bitfield(typ) case idl.Struct(): # and idl.Message(): ret += gen_struct(typ) ret += cutil.ifdef_pop(0) ret += """ /* containers *****************************************************************/ """ ret += "\n" ret += f"#define {c9util.IDENT('_MAX')}(a, b) ((a) > (b)) ? (a) : (b)\n" tmsg_max_iov: dict[str, int] = {} tmsg_max_copy: dict[str, int] = {} rmsg_max_iov: dict[str, int] = {} rmsg_max_copy: dict[str, int] = {} for typ in typs: if not isinstance(typ, idl.Message): continue if typ.typname in ("Tsread", "Tswrite"): # SPECIAL (9P2000.e) continue max_iov = tmsg_max_iov if typ.msgid % 2 == 0 else rmsg_max_iov max_copy = tmsg_max_copy if typ.msgid % 2 == 0 else rmsg_max_copy for version in typ.in_versions: if version not in max_iov: max_iov[version] = 0 max_copy[version] = 0 sz = get_buffer_size(typ, version) if sz.max_iov > max_iov[version]: max_iov[version] = sz.max_iov if sz.max_copy > max_copy[version]: max_copy[version] = sz.max_copy for name, table in [ ("tmsg_max_iov", tmsg_max_iov), ("tmsg_max_copy", tmsg_max_copy), ("rmsg_max_iov", rmsg_max_iov), ("rmsg_max_copy", rmsg_max_copy), ]: inv: dict[int, set[str]] = {} for version, maxval in table.items(): if maxval not in inv: inv[maxval] = set() inv[maxval].add(version) ret += "\n" directive = "if" seen_e = False # SPECIAL (9P2000.e) for maxval in sorted(inv, reverse=True): ret += f"#{directive} {c9util.ver_ifdef(inv[maxval])}\n" indent = 1 if name.startswith("tmsg") and not seen_e: # SPECIAL (9P2000.e) typ = next(typ for typ in typs if typ.typname == "Tswrite") sz = get_buffer_size(typ, "9P2000.e") match name: case "tmsg_max_iov": maxexpr = f"{sz.max_iov}{sz.max_iov_extra}" case "tmsg_max_copy": maxexpr = f"{sz.max_copy}{sz.max_copy_extra}" case _: assert False ret += f"\t#if {c9util.ver_ifdef({"9P2000.e"})}\n" ret += f"\t\t#define {c9util.IDENT(name)} {c9util.IDENT('_MAX')}({maxval}, {maxexpr})\n" ret += "\t#else\n" indent += 1 ret += f"{'\t'*indent}#define {c9util.IDENT(name)} {maxval}\n" if name.startswith("tmsg") and not seen_e: # SPECIAL (9P2000.e) ret += "\t#endif\n" if "9P2000.e" in inv[maxval]: seen_e = True directive = "elif" ret += "#endif\n" ret += "\n" ret += f"struct {c9util.ident('Tmsg_send_buf')} {{\n" ret += "\tsize_t iov_cnt;\n" ret += f"\tstruct iovec iov[{c9util.IDENT('TMSG_MAX_IOV')}];\n" ret += f"\tuint8_t copied[{c9util.IDENT('TMSG_MAX_COPY')}];\n" ret += "};\n" ret += "\n" ret += f"struct {c9util.ident('Rmsg_send_buf')} {{\n" ret += "\tsize_t iov_cnt;\n" ret += f"\tstruct iovec iov[{c9util.IDENT('RMSG_MAX_IOV')}];\n" ret += f"\tuint8_t copied[{c9util.IDENT('RMSG_MAX_COPY')}];\n" ret += "};\n" return ret def gen_number(typ: idl.Number) -> str: ret = f"typedef {c9util.typename(typ.prim)} {c9util.typename(typ)};\n" prefix = f"{c9util.IDENT(typ.typname)}_" namewidth = max(len(name) for name in typ.vals) for name, val in typ.vals.items(): ret += f"#define {prefix}{name:<{namewidth}} (({c9util.typename(typ)})UINT{typ.static_size*8}_C({val}))\n" return ret def gen_bitfield(typ: idl.Bitfield) -> str: ret = f"typedef {c9util.typename(typ.prim)} {c9util.typename(typ)};\n" def bitname(val: idl.Bit | idl.BitAlias) -> str: s = val.bitname match val: case idl.Bit(cat=idl.BitCat.RESERVED): s = "_RESERVED_" + s case idl.Bit(cat=idl.BitCat.SUBFIELD): assert isinstance(typ, idl.Bitfield) n = sum( 1 for b in typ.bits[: val.num] if b.cat == idl.BitCat.SUBFIELD and b.bitname == val.bitname ) s = f"_{s}_{n}" case idl.Bit(cat=idl.BitCat.UNUSED): return "" return c9util.Ident(c9util.add_prefix(typ.typname.upper() + "_", s)) namewidth = max(len(bitname(val)) for val in [*typ.bits, *typ.names.values()]) ret += "\n" for bit in reversed(typ.bits): vers = bit.in_versions if bit.cat == idl.BitCat.UNUSED: vers = typ.in_versions ret += cutil.ifdef_push(2, c9util.ver_ifdef(vers)) # It is important all of the `beg` strings have # the same length. end = "" match bit.cat: case idl.BitCat.USED | idl.BitCat.RESERVED | idl.BitCat.SUBFIELD: if cutil.ifdef_leaf_is_noop(): beg = "#define " else: beg = "# define" case idl.BitCat.UNUSED: beg = "/* unused" end = " */" c_name = bitname(bit) c_val = f"1<<{bit.num}" ret += ( f"{beg} {c_name:<{namewidth}} (({c9util.typename(typ)})({c_val})){end}\n" ) if aliases := [ alias for alias in typ.names.values() if isinstance(alias, idl.BitAlias) ]: ret += "\n" for alias in aliases: ret += cutil.ifdef_push(2, c9util.ver_ifdef(alias.in_versions)) end = "" if cutil.ifdef_leaf_is_noop(): beg = "#define " else: beg = "# define" c_name = bitname(alias) c_val = alias.val ret += f"{beg} {c_name:<{namewidth}} (({c9util.typename(typ)})({c_val})){end}\n" ret += cutil.ifdef_pop(1) return ret def gen_struct(typ: idl.Struct) -> str: # and idl.Message ret = c9util.typename(typ) + " {" if not typ.members: ret += "};\n" return ret ret += "\n" typewidth = max(len(c9util.typename(m.typ, m)) for m in typ.members) for member in typ.members: if member.val: continue ret += cutil.ifdef_push(2, c9util.ver_ifdef(member.in_versions)) ret += f"\t{c9util.typename(member.typ, member):<{typewidth}} {'*' if member.cnt else ' '}{member.membname};\n" ret += cutil.ifdef_pop(1) ret += "};\n" return ret