diff options
Diffstat (limited to 'lib9p/protogen/h.py')
-rw-r--r-- | lib9p/protogen/h.py | 447 |
1 files changed, 447 insertions, 0 deletions
diff --git a/lib9p/protogen/h.py b/lib9p/protogen/h.py new file mode 100644 index 0000000..7785ca1 --- /dev/null +++ b/lib9p/protogen/h.py @@ -0,0 +1,447 @@ +# lib9p/protogen/h.py - Generate 9p.generated.h +# +# Copyright (C) 2024-2025 Luke T. Shumaker <lukeshu@lukeshu.com> +# SPDX-License-Identifier: AGPL-3.0-or-later + +import sys +import typing + +import idl + +from . import c9util, cutil, idlutil + +# This strives to be "general-purpose" in that it just acts on the +# *.9p inputs; but (unfortunately?) there are a few special-cases in +# this script, marked with "SPECIAL". + +# pylint: disable=unused-variable +__all__ = ["gen_h"] + +# get_buffer_size() ############################################################ + + +class BufferSize(typing.NamedTuple): + min_size: int # really just here to sanity-check against typ.min_size(version) + exp_size: int # "expected" or max-reasonable size + max_size: int # really just here to sanity-check against typ.max_size(version) + max_copy: int + max_copy_extra: str + max_iov: int + max_iov_extra: str + + +class TmpBufferSize: + min_size: int + exp_size: int + max_size: int + max_copy: int + max_copy_extra: str + max_iov: int + max_iov_extra: str + + tmp_starts_with_copy: bool + tmp_ends_with_copy: bool + + def __init__(self) -> None: + self.min_size = 0 + self.exp_size = 0 + self.max_size = 0 + self.max_copy = 0 + self.max_copy_extra = "" + self.max_iov = 0 + self.max_iov_extra = "" + self.tmp_starts_with_copy = False + self.tmp_ends_with_copy = False + + +def _get_buffer_size(typ: idl.Type, version: str) -> TmpBufferSize: + assert isinstance(typ, idl.Primitive) or (version in typ.in_versions) + + ret = TmpBufferSize() + + if not isinstance(typ, idl.Struct): + assert typ.static_size + ret.min_size = typ.static_size + ret.exp_size = typ.static_size + ret.max_size = typ.static_size + ret.max_copy = typ.static_size + ret.max_iov = 1 + ret.tmp_starts_with_copy = True + ret.tmp_ends_with_copy = True + return ret + + def handle(path: idlutil.Path) -> tuple[idlutil.WalkCmd, None]: + nonlocal ret + if path.elems: + child = path.elems[-1] + if version not in child.in_versions: + return idlutil.WalkCmd.DONT_RECURSE, None + if child.cnt: + if child.typ.static_size == 1: # SPECIAL (zerocopy) + ret.max_iov += 1 + # HEURISTIC: 27 for strings (max-strlen from 9P1), 8KiB for other data + ret.exp_size += 27 if child.membname == "utf8" else 8192 + ret.max_size += child.max_cnt + ret.tmp_ends_with_copy = False + return idlutil.WalkCmd.DONT_RECURSE, None + sub = _get_buffer_size(child.typ, version) + ret.exp_size += sub.exp_size * 16 # HEURISTIC: MAXWELEM + ret.max_size += sub.max_size * child.max_cnt + if child.membname == "wname" and path.root.typname in ( + "Tsread", + "Tswrite", + ): # SPECIAL (9P2000.e) + assert ret.tmp_ends_with_copy + assert sub.tmp_starts_with_copy + assert not sub.tmp_ends_with_copy + ret.max_copy_extra = ( + f" + (CONFIG_9P_MAX_9P2000_e_WELEM * {sub.max_copy})" + ) + ret.max_iov_extra = ( + f" + (CONFIG_9P_MAX_9P2000_e_WELEM * {sub.max_iov})" + ) + ret.max_iov -= 1 + else: + ret.max_copy += sub.max_copy * child.max_cnt + if sub.max_iov == 1 and sub.tmp_starts_with_copy: # is purely copy + ret.max_iov += 1 + else: # contains zero-copy segments + ret.max_iov += sub.max_iov * child.max_cnt + if ret.tmp_ends_with_copy and sub.tmp_starts_with_copy: + # we can merge this one + ret.max_iov -= 1 + if ( + sub.tmp_ends_with_copy + and sub.tmp_starts_with_copy + and sub.max_iov > 1 + ): + # we can merge these + ret.max_iov -= child.max_cnt - 1 + ret.tmp_ends_with_copy = sub.tmp_ends_with_copy + return idlutil.WalkCmd.DONT_RECURSE, None + if not isinstance(child.typ, idl.Struct): + assert child.typ.static_size + if not ret.tmp_ends_with_copy: + if ret.max_size == 0: + ret.tmp_starts_with_copy = True + ret.max_iov += 1 + ret.tmp_ends_with_copy = True + ret.min_size += child.typ.static_size + ret.exp_size += child.typ.static_size + ret.max_size += child.typ.static_size + ret.max_copy += child.typ.static_size + return idlutil.WalkCmd.KEEP_GOING, None + + idlutil.walk(typ, handle) + assert ret.min_size == typ.min_size(version) + assert ret.max_size == typ.max_size(version) + return ret + + +def get_buffer_size(typ: idl.Type, version: str) -> BufferSize: + tmp = _get_buffer_size(typ, version) + return BufferSize( + min_size=tmp.min_size, + exp_size=tmp.exp_size, + max_size=tmp.max_size, + max_copy=tmp.max_copy, + max_copy_extra=tmp.max_copy_extra, + max_iov=tmp.max_iov, + max_iov_extra=tmp.max_iov_extra, + ) + + +# Generate .h ################################################################## + + +def gen_h(versions: set[str], typs: list[idl.UserType]) -> str: + cutil.ifdef_init() + + ret = f"""/* Generated by `{' '.join(sys.argv)}`. DO NOT EDIT! */ + +#ifndef _LIB9P_9P_H_ +\t#error Do not include <lib9p/9p.generated.h> directly; include <lib9p/9p.h> instead +#endif + +#include <stdint.h> /* for uint{{n}}_t types */ + +#include <libhw/generic/net.h> /* for struct iovec */ +""" + + id2typ: dict[int, idl.Message] = {} + for msg in [msg for msg in typs if isinstance(msg, idl.Message)]: + id2typ[msg.msgid] = msg + + ret += """ +/* config *********************************************************************/ + +#include "config.h" +""" + for ver in sorted(versions): + ret += "\n" + ret += f"#ifndef {c9util.ver_ifdef({ver})}\n" + ret += f"\t#error config.h must define {c9util.ver_ifdef({ver})}\n" + if ver == "9P2000.e": # SPECIAL (9P2000.e) + ret += "#else\n" + ret += f"\t#if {c9util.ver_ifdef({ver})}\n" + ret += "\t\t#ifndef CONFIG_9P_MAX_9P2000_e_WELEM\n" + ret += f"\t\t\t#error if {c9util.ver_ifdef({ver})} then config.h must define CONFIG_9P_MAX_9P2000_e_WELEM\n" + ret += "\t\t#endif\n" + ret += "\t\tstatic_assert(CONFIG_9P_MAX_9P2000_e_WELEM > 0);\n" + ret += "\t#endif\n" + ret += "#endif\n" + + ret += f""" +/* enum version ***************************************************************/ + +enum {c9util.ident('version')} {{ +""" + fullversions = ["unknown = 0", *sorted(versions)] + verwidth = max(len(v) for v in fullversions) + for ver in fullversions: + if ver in versions: + ret += cutil.ifdef_push(1, c9util.ver_ifdef({ver})) + ret += f"\t{c9util.ver_enum(ver)}," + ret += (" " * (verwidth - len(ver))) + ' /* "' + ver.split()[0] + '" */\n' + ret += cutil.ifdef_pop(0) + ret += f"\t{c9util.ver_enum('NUM')},\n" + ret += "};\n" + + ret += """ +/* enum msg_type **************************************************************/ + +""" + ret += f"enum {c9util.ident('msg_type')} {{ /* uint8_t */\n" + namewidth = max(len(msg.typname) for msg in typs if isinstance(msg, idl.Message)) + for n in range(0x100): + if n not in id2typ: + continue + msg = id2typ[n] + ret += cutil.ifdef_push(1, c9util.ver_ifdef(msg.in_versions)) + ret += f"\t{c9util.Ident(f'TYP_{msg.typname:<{namewidth}}')} = {msg.msgid},\n" + ret += cutil.ifdef_pop(0) + ret += "};\n" + + ret += """ +/* payload types **************************************************************/ +""" + + def per_version_comment( + typ: idl.UserType, fn: typing.Callable[[idl.UserType, str], str] + ) -> str: + lines: dict[str, str] = {} + for version in sorted(typ.in_versions): + lines[version] = fn(typ, version) + if len(set(lines.values())) == 1: + for _, line in lines.items(): + return f"/* {line} */\n" + assert False + else: + ret = "" + v_width = max(len(c9util.ver_enum(v)) for v in typ.in_versions) + for version, line in lines.items(): + ret += f"/* {c9util.ver_enum(version):<{v_width}}: {line} */\n" + return ret + + for typ in idlutil.topo_sorted(typs): + ret += "\n" + ret += cutil.ifdef_push(1, c9util.ver_ifdef(typ.in_versions)) + + def sum_size(typ: idl.UserType, version: str) -> str: + sz = get_buffer_size(typ, version) + assert ( + sz.min_size <= sz.exp_size + and sz.exp_size <= sz.max_size + and sz.max_size < cutil.UINT64_MAX + ) + ret = "" + if sz.min_size == sz.max_size: + ret += f"size = {sz.min_size:,}" + else: + ret += f"min_size = {sz.min_size:,} ; exp_size = {sz.exp_size:,} ; max_size = {sz.max_size:,}" + if sz.max_size > cutil.UINT32_MAX: + ret += " (warning: >UINT32_MAX)" + ret += f" ; max_iov = {sz.max_iov:,}{sz.max_iov_extra} ; max_copy = {sz.max_copy:,}{sz.max_copy_extra}" + return ret + + ret += per_version_comment(typ, sum_size) + + match typ: + case idl.Number(): + ret += f"typedef {c9util.typename(typ.prim)} {c9util.typename(typ)};\n" + prefix = f"{c9util.IDENT(typ.typname)}_" + namewidth = max(len(name) for name in typ.vals) + for name, val in typ.vals.items(): + ret += f"#define {prefix}{name:<{namewidth}} (({c9util.typename(typ)})UINT{typ.static_size*8}_C({val}))\n" + case idl.Bitfield(): + ret += f"typedef {c9util.typename(typ.prim)} {c9util.typename(typ)};\n" + + def bitname(val: idl.Bit | idl.BitAlias) -> str: + s = val.bitname + match val: + case idl.Bit(cat=idl.BitCat.RESERVED): + s = "_RESERVED_" + s + case idl.Bit(cat=idl.BitCat.SUBFIELD): + assert isinstance(typ, idl.Bitfield) + n = sum( + 1 + for b in typ.bits[: val.num] + if b.cat == idl.BitCat.SUBFIELD + and b.bitname == val.bitname + ) + s = f"_{s}_{n}" + case idl.Bit(cat=idl.BitCat.UNUSED): + return "" + return c9util.Ident(c9util.add_prefix(typ.typname.upper() + "_", s)) + + namewidth = max( + len(bitname(val)) for val in [*typ.bits, *typ.names.values()] + ) + + ret += "\n" + for bit in reversed(typ.bits): + vers = bit.in_versions + if bit.cat == idl.BitCat.UNUSED: + vers = typ.in_versions + ret += cutil.ifdef_push(2, c9util.ver_ifdef(vers)) + + # It is important all of the `beg` strings have + # the same length. + end = "" + match bit.cat: + case ( + idl.BitCat.USED | idl.BitCat.RESERVED | idl.BitCat.SUBFIELD + ): + if cutil.ifdef_leaf_is_noop(): + beg = "#define " + else: + beg = "# define" + case idl.BitCat.UNUSED: + beg = "/* unused" + end = " */" + + c_name = bitname(bit) + c_val = f"1<<{bit.num}" + ret += f"{beg} {c_name:<{namewidth}} (({c9util.typename(typ)})({c_val})){end}\n" + if aliases := [ + alias + for alias in typ.names.values() + if isinstance(alias, idl.BitAlias) + ]: + ret += "\n" + + for alias in aliases: + ret += cutil.ifdef_push(2, c9util.ver_ifdef(alias.in_versions)) + + end = "" + if cutil.ifdef_leaf_is_noop(): + beg = "#define " + else: + beg = "# define" + + c_name = bitname(alias) + c_val = alias.val + ret += f"{beg} {c_name:<{namewidth}} (({c9util.typename(typ)})({c_val})){end}\n" + ret += cutil.ifdef_pop(1) + del bitname + case idl.Struct(): # and idl.Message(): + ret += c9util.typename(typ) + " {" + if not typ.members: + ret += "};\n" + continue + ret += "\n" + + typewidth = max(len(c9util.typename(m.typ, m)) for m in typ.members) + + for member in typ.members: + if member.val: + continue + ret += cutil.ifdef_push(2, c9util.ver_ifdef(member.in_versions)) + ret += f"\t{c9util.typename(member.typ, member):<{typewidth}} {'*' if member.cnt else ' '}{member.membname};\n" + ret += cutil.ifdef_pop(1) + ret += "};\n" + del typ + ret += cutil.ifdef_pop(0) + + ret += """ +/* containers *****************************************************************/ +""" + ret += "\n" + ret += f"#define {c9util.IDENT('_MAX')}(a, b) ((a) > (b)) ? (a) : (b)\n" + + tmsg_max_iov: dict[str, int] = {} + tmsg_max_copy: dict[str, int] = {} + rmsg_max_iov: dict[str, int] = {} + rmsg_max_copy: dict[str, int] = {} + for typ in typs: + if not isinstance(typ, idl.Message): + continue + if typ.typname in ("Tsread", "Tswrite"): # SPECIAL (9P2000.e) + continue + max_iov = tmsg_max_iov if typ.msgid % 2 == 0 else rmsg_max_iov + max_copy = tmsg_max_copy if typ.msgid % 2 == 0 else rmsg_max_copy + for version in typ.in_versions: + if version not in max_iov: + max_iov[version] = 0 + max_copy[version] = 0 + sz = get_buffer_size(typ, version) + if sz.max_iov > max_iov[version]: + max_iov[version] = sz.max_iov + if sz.max_copy > max_copy[version]: + max_copy[version] = sz.max_copy + + for name, table in [ + ("tmsg_max_iov", tmsg_max_iov), + ("tmsg_max_copy", tmsg_max_copy), + ("rmsg_max_iov", rmsg_max_iov), + ("rmsg_max_copy", rmsg_max_copy), + ]: + inv: dict[int, set[str]] = {} + for version, maxval in table.items(): + if maxval not in inv: + inv[maxval] = set() + inv[maxval].add(version) + + ret += "\n" + directive = "if" + seen_e = False # SPECIAL (9P2000.e) + for maxval in sorted(inv, reverse=True): + ret += f"#{directive} {c9util.ver_ifdef(inv[maxval])}\n" + indent = 1 + if name.startswith("tmsg") and not seen_e: # SPECIAL (9P2000.e) + typ = next(typ for typ in typs if typ.typname == "Tswrite") + sz = get_buffer_size(typ, "9P2000.e") + match name: + case "tmsg_max_iov": + maxexpr = f"{sz.max_iov}{sz.max_iov_extra}" + case "tmsg_max_copy": + maxexpr = f"{sz.max_copy}{sz.max_copy_extra}" + case _: + assert False + ret += f"\t#if {c9util.ver_ifdef({"9P2000.e"})}\n" + ret += f"\t\t#define {c9util.IDENT(name)} {c9util.IDENT('_MAX')}({maxval}, {maxexpr})\n" + ret += "\t#else\n" + indent += 1 + ret += f"{'\t'*indent}#define {c9util.IDENT(name)} {maxval}\n" + if name.startswith("tmsg") and not seen_e: # SPECIAL (9P2000.e) + ret += "\t#endif\n" + if "9P2000.e" in inv[maxval]: + seen_e = True + directive = "elif" + ret += "#endif\n" + + ret += "\n" + ret += f"struct {c9util.ident('Tmsg_send_buf')} {{\n" + ret += "\tsize_t iov_cnt;\n" + ret += f"\tstruct iovec iov[{c9util.IDENT('TMSG_MAX_IOV')}];\n" + ret += f"\tuint8_t copied[{c9util.IDENT('TMSG_MAX_COPY')}];\n" + ret += "};\n" + + ret += "\n" + ret += f"struct {c9util.ident('Rmsg_send_buf')} {{\n" + ret += "\tsize_t iov_cnt;\n" + ret += f"\tstruct iovec iov[{c9util.IDENT('RMSG_MAX_IOV')}];\n" + ret += f"\tuint8_t copied[{c9util.IDENT('RMSG_MAX_COPY')}];\n" + ret += "};\n" + + return ret |