From 057f17b3f22b8b0112f87b6c3128df6925b8f27a Mon Sep 17 00:00:00 2001 From: "Luke T. Shumaker" Date: Sun, 23 Mar 2025 00:19:29 -0600 Subject: lib9p: start to split idl.gen apart as proto.gen --- lib9p/protogen/__init__.py | 1437 ++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 1437 insertions(+) create mode 100644 lib9p/protogen/__init__.py (limited to 'lib9p/protogen/__init__.py') diff --git a/lib9p/protogen/__init__.py b/lib9p/protogen/__init__.py new file mode 100644 index 0000000..21b5161 --- /dev/null +++ b/lib9p/protogen/__init__.py @@ -0,0 +1,1437 @@ +# lib9p/protogen/__init__.py - Generate C marshalers/unmarshalers for +# .9p files defining 9P protocol +# variants. +# +# Copyright (C) 2024-2025 Luke T. Shumaker +# SPDX-License-Identifier: AGPL-3.0-or-later + +import enum +import graphlib +import os.path +import sys +import typing + +import idl + +# This strives to be "general-purpose" in that it just acts on the +# *.9p inputs; but (unfortunately?) there are a few special-cases in +# this script, marked with "SPECIAL". + + +# pylint: disable=unused-variable +__all__ = ["main"] + +# Utilities #################################################################### + +idprefix = "lib9p_" + +u32max = (1 << 32) - 1 +u64max = (1 << 64) - 1 + + +def tab_ljust(s: str, width: int) -> str: + cur = len(s.expandtabs(tabsize=8)) + if cur >= width: + return s + return s + " " * (width - cur) + + +def add_prefix(p: str, s: str) -> str: + if s.startswith("_"): + return "_" + p + s[1:] + return p + s + + +def c_macro(full: str) -> str: + full = full.rstrip() + assert "\n" in full + lines = [l.rstrip() for l in full.split("\n")] + width = max(len(l.expandtabs(tabsize=8)) for l in lines[:-1]) + lines = [tab_ljust(l, width) for l in lines] + return " \\\n".join(lines).rstrip() + "\n" + + +def c_ver_enum(ver: str) -> str: + return f"{idprefix.upper()}VER_{ver.replace('.', '_')}" + + +def c_ver_ifdef(versions: typing.Collection[str]) -> str: + return " || ".join( + f"CONFIG_9P_ENABLE_{v.replace('.', '_')}" for v in sorted(versions) + ) + + +def c_ver_cond(versions: typing.Collection[str]) -> str: + if len(versions) == 1: + v = next(v for v in versions) + return f"is_ver(ctx, {v.replace('.', '_')})" + return "( " + (" || ".join(c_ver_cond({v}) for v in sorted(versions))) + " )" + + +def c_typename(typ: idl.Type, parent: idl.StructMember | None = None) -> str: + match typ: + case idl.Primitive(): + if typ.value == 1 and parent and parent.cnt: # SPECIAL (string) + return "[[gnu::nonstring]] char" + return f"uint{typ.value*8}_t" + case idl.Number(): + return f"{idprefix}{typ.typname}_t" + case idl.Bitfield(): + return f"{idprefix}{typ.typname}_t" + case idl.Message(): + return f"struct {idprefix}msg_{typ.typname}" + case idl.Struct(): + return f"struct {idprefix}{typ.typname}" + case _: + raise ValueError(f"not a type: {typ.__class__.__name__}") + + +def c_expr(expr: idl.Expr, lookup_sym: typing.Callable[[str], str]) -> str: + ret: list[str] = [] + for tok in expr.tokens: + match tok: + case idl.ExprOp(): + ret.append(tok.op) + case idl.ExprLit(): + ret.append(str(tok.val)) + case idl.ExprSym(symname="s32_max"): + ret.append("INT32_MAX") + case idl.ExprSym(symname="s64_max"): + ret.append("INT64_MAX") + case idl.ExprSym(): + ret.append(lookup_sym(tok.symname)) + case _: + assert False + return " ".join(ret) + + +_ifdef_stack: list[str | None] = [] + + +def ifdef_push(n: int, _newval: str) -> str: + # Grow the stack as needed + while len(_ifdef_stack) < n: + _ifdef_stack.append(None) + + # Set some variables + parentval: str | None = None + for x in _ifdef_stack[:-1]: + if x is not None: + parentval = x + oldval = _ifdef_stack[-1] + newval: str | None = _newval + if newval == parentval: + newval = None + + # Put newval on the stack. + _ifdef_stack[-1] = newval + + # Build output. + ret = "" + if newval != oldval: + if oldval is not None: + ret += f"#endif /* {oldval} */\n" + if newval is not None: + ret += f"#if {newval}\n" + return ret + + +def ifdef_pop(n: int) -> str: + global _ifdef_stack + ret = "" + while len(_ifdef_stack) > n: + if _ifdef_stack[-1] is not None: + ret += f"#endif /* {_ifdef_stack[-1]} */\n" + _ifdef_stack = _ifdef_stack[:-1] + return ret + + +# topo_sorted() ################################################################ + + +def topo_sorted(typs: list[idl.UserType]) -> typing.Iterable[idl.UserType]: + ts: graphlib.TopologicalSorter[idl.UserType] = graphlib.TopologicalSorter() + for typ in typs: + match typ: + case idl.Number(): + ts.add(typ) + case idl.Bitfield(): + ts.add(typ) + case idl.Struct(): # and idl.Message(): + deps = [ + member.typ + for member in typ.members + if not isinstance(member.typ, idl.Primitive) + ] + ts.add(typ, *deps) + return ts.static_order() + + +# walk() ####################################################################### + + +class Path: + root: idl.Type + elems: list[idl.StructMember] + + def __init__( + self, root: idl.Type, elems: list[idl.StructMember] | None = None + ) -> None: + self.root = root + self.elems = elems if elems is not None else [] + + def add(self, elem: idl.StructMember) -> "Path": + return Path(self.root, self.elems + [elem]) + + def parent(self) -> "Path": + return Path(self.root, self.elems[:-1]) + + def c_str(self, base: str, loopdepth: int = 0) -> str: + ret = base + for i, elem in enumerate(self.elems): + if i > 0: + ret += "." + ret += elem.membname + if elem.cnt: + ret += f"[{chr(ord('i')+loopdepth)}]" + loopdepth += 1 + return ret + + def __str__(self) -> str: + return self.c_str(self.root.typname + "->") + + +class WalkCmd(enum.Enum): + KEEP_GOING = 1 + DONT_RECURSE = 2 + ABORT = 3 + + +type WalkHandler = typing.Callable[ + [Path], tuple[WalkCmd, typing.Callable[[], None] | None] +] + + +def _walk(path: Path, handle: WalkHandler) -> WalkCmd: + typ = path.elems[-1].typ if path.elems else path.root + + ret, atexit = handle(path) + + if isinstance(typ, idl.Struct): + match ret: + case WalkCmd.KEEP_GOING: + for member in typ.members: + if _walk(path.add(member), handle) == WalkCmd.ABORT: + ret = WalkCmd.ABORT + break + case WalkCmd.DONT_RECURSE: + ret = WalkCmd.KEEP_GOING + case WalkCmd.ABORT: + ret = WalkCmd.ABORT + case _: + assert False, f"invalid cmd: {ret}" + + if atexit: + atexit() + return ret + + +def walk(typ: idl.Type, handle: WalkHandler) -> None: + _walk(Path(typ), handle) + + +# get_buffer_size() ############################################################ + + +class BufferSize(typing.NamedTuple): + min_size: int # really just here to sanity-check against typ.min_size(version) + exp_size: int # "expected" or max-reasonable size + max_size: int # really just here to sanity-check against typ.max_size(version) + max_copy: int + max_copy_extra: str + max_iov: int + max_iov_extra: str + + +class TmpBufferSize: + min_size: int + exp_size: int + max_size: int + max_copy: int + max_copy_extra: str + max_iov: int + max_iov_extra: str + + tmp_starts_with_copy: bool + tmp_ends_with_copy: bool + + def __init__(self) -> None: + self.min_size = 0 + self.exp_size = 0 + self.max_size = 0 + self.max_copy = 0 + self.max_copy_extra = "" + self.max_iov = 0 + self.max_iov_extra = "" + self.tmp_starts_with_copy = False + self.tmp_ends_with_copy = False + + +def _get_buffer_size(typ: idl.Type, version: str) -> TmpBufferSize: + assert isinstance(typ, idl.Primitive) or (version in typ.in_versions) + + ret = TmpBufferSize() + + if not isinstance(typ, idl.Struct): + assert typ.static_size + ret.min_size = typ.static_size + ret.exp_size = typ.static_size + ret.max_size = typ.static_size + ret.max_copy = typ.static_size + ret.max_iov = 1 + ret.tmp_starts_with_copy = True + ret.tmp_ends_with_copy = True + return ret + + def handle(path: Path) -> tuple[WalkCmd, None]: + nonlocal ret + if path.elems: + child = path.elems[-1] + if version not in child.in_versions: + return WalkCmd.DONT_RECURSE, None + if child.cnt: + if child.typ.static_size == 1: # SPECIAL (zerocopy) + ret.max_iov += 1 + # HEURISTIC: 27 for strings (max-strlen from 9P1), 8KiB for other data + ret.exp_size += 27 if child.membname == "utf8" else 8192 + ret.max_size += child.max_cnt + ret.tmp_ends_with_copy = False + return WalkCmd.DONT_RECURSE, None + sub = _get_buffer_size(child.typ, version) + ret.exp_size += sub.exp_size * 16 # HEURISTIC: MAXWELEM + ret.max_size += sub.max_size * child.max_cnt + if child.membname == "wname" and path.root.typname in ( + "Tsread", + "Tswrite", + ): # SPECIAL (9P2000.e) + assert ret.tmp_ends_with_copy + assert sub.tmp_starts_with_copy + assert not sub.tmp_ends_with_copy + ret.max_copy_extra = ( + f" + (CONFIG_9P_MAX_9P2000_e_WELEM * {sub.max_copy})" + ) + ret.max_iov_extra = ( + f" + (CONFIG_9P_MAX_9P2000_e_WELEM * {sub.max_iov})" + ) + ret.max_iov -= 1 + else: + ret.max_copy += sub.max_copy * child.max_cnt + if sub.max_iov == 1 and sub.tmp_starts_with_copy: # is purely copy + ret.max_iov += 1 + else: # contains zero-copy segments + ret.max_iov += sub.max_iov * child.max_cnt + if ret.tmp_ends_with_copy and sub.tmp_starts_with_copy: + # we can merge this one + ret.max_iov -= 1 + if ( + sub.tmp_ends_with_copy + and sub.tmp_starts_with_copy + and sub.max_iov > 1 + ): + # we can merge these + ret.max_iov -= child.max_cnt - 1 + ret.tmp_ends_with_copy = sub.tmp_ends_with_copy + return WalkCmd.DONT_RECURSE, None + if not isinstance(child.typ, idl.Struct): + assert child.typ.static_size + if not ret.tmp_ends_with_copy: + if ret.max_size == 0: + ret.tmp_starts_with_copy = True + ret.max_iov += 1 + ret.tmp_ends_with_copy = True + ret.min_size += child.typ.static_size + ret.exp_size += child.typ.static_size + ret.max_size += child.typ.static_size + ret.max_copy += child.typ.static_size + return WalkCmd.KEEP_GOING, None + + walk(typ, handle) + assert ret.min_size == typ.min_size(version) + assert ret.max_size == typ.max_size(version) + return ret + + +def get_buffer_size(typ: idl.Type, version: str) -> BufferSize: + tmp = _get_buffer_size(typ, version) + return BufferSize( + min_size=tmp.min_size, + exp_size=tmp.exp_size, + max_size=tmp.max_size, + max_copy=tmp.max_copy, + max_copy_extra=tmp.max_copy_extra, + max_iov=tmp.max_iov, + max_iov_extra=tmp.max_iov_extra, + ) + + +# Generate .h ################################################################## + + +def gen_h(versions: set[str], typs: list[idl.UserType]) -> str: + global _ifdef_stack + _ifdef_stack = [] + + ret = f"""/* Generated by `{' '.join(sys.argv)}`. DO NOT EDIT! */ + +#ifndef _LIB9P_9P_H_ +\t#error Do not include directly; include instead +#endif + +#include /* for uint{{n}}_t types */ + +#include /* for struct iovec */ +""" + + id2typ: dict[int, idl.Message] = {} + for msg in [msg for msg in typs if isinstance(msg, idl.Message)]: + id2typ[msg.msgid] = msg + + ret += """ +/* config *********************************************************************/ + +#include "config.h" +""" + for ver in sorted(versions): + ret += "\n" + ret += f"#ifndef {c_ver_ifdef({ver})}\n" + ret += f"\t#error config.h must define {c_ver_ifdef({ver})}\n" + if ver == "9P2000.e": # SPECIAL (9P2000.e) + ret += "#else\n" + ret += f"\t#if {c_ver_ifdef({ver})}\n" + ret += "\t\t#ifndef CONFIG_9P_MAX_9P2000_e_WELEM\n" + ret += f"\t\t\t#error if {c_ver_ifdef({ver})} then config.h must define CONFIG_9P_MAX_9P2000_e_WELEM\n" + ret += "\t\t#endif\n" + ret += "\t\tstatic_assert(CONFIG_9P_MAX_9P2000_e_WELEM > 0);\n" + ret += "\t#endif\n" + ret += "#endif\n" + + ret += f""" +/* enum version ***************************************************************/ + +enum {idprefix}version {{ +""" + fullversions = ["unknown = 0", *sorted(versions)] + verwidth = max(len(v) for v in fullversions) + for ver in fullversions: + if ver in versions: + ret += ifdef_push(1, c_ver_ifdef({ver})) + ret += f"\t{c_ver_enum(ver)}," + ret += (" " * (verwidth - len(ver))) + ' /* "' + ver.split()[0] + '" */\n' + ret += ifdef_pop(0) + ret += f"\t{c_ver_enum('NUM')},\n" + ret += "};\n" + + ret += """ +/* enum msg_type **************************************************************/ + +""" + ret += f"enum {idprefix}msg_type {{ /* uint8_t */\n" + namewidth = max(len(msg.typname) for msg in typs if isinstance(msg, idl.Message)) + for n in range(0x100): + if n not in id2typ: + continue + msg = id2typ[n] + ret += ifdef_push(1, c_ver_ifdef(msg.in_versions)) + ret += f"\t{idprefix.upper()}TYP_{msg.typname:<{namewidth}} = {msg.msgid},\n" + ret += ifdef_pop(0) + ret += "};\n" + + ret += """ +/* payload types **************************************************************/ +""" + + def per_version_comment( + typ: idl.UserType, fn: typing.Callable[[idl.UserType, str], str] + ) -> str: + lines: dict[str, str] = {} + for version in sorted(typ.in_versions): + lines[version] = fn(typ, version) + if len(set(lines.values())) == 1: + for _, line in lines.items(): + return f"/* {line} */\n" + assert False + else: + ret = "" + v_width = max(len(c_ver_enum(v)) for v in typ.in_versions) + for version, line in lines.items(): + ret += f"/* {c_ver_enum(version):<{v_width}}: {line} */\n" + return ret + + for typ in topo_sorted(typs): + ret += "\n" + ret += ifdef_push(1, c_ver_ifdef(typ.in_versions)) + + def sum_size(typ: idl.UserType, version: str) -> str: + sz = get_buffer_size(typ, version) + assert ( + sz.min_size <= sz.exp_size + and sz.exp_size <= sz.max_size + and sz.max_size < u64max + ) + ret = "" + if sz.min_size == sz.max_size: + ret += f"size = {sz.min_size:,}" + else: + ret += f"min_size = {sz.min_size:,} ; exp_size = {sz.exp_size:,} ; max_size = {sz.max_size:,}" + if sz.max_size > u32max: + ret += " (warning: >UINT32_MAX)" + ret += f" ; max_iov = {sz.max_iov:,}{sz.max_iov_extra} ; max_copy = {sz.max_copy:,}{sz.max_copy_extra}" + return ret + + ret += per_version_comment(typ, sum_size) + + match typ: + case idl.Number(): + ret += f"typedef {c_typename(typ.prim)} {c_typename(typ)};\n" + prefix = f"{idprefix.upper()}{typ.typname.upper()}_" + namewidth = max(len(name) for name in typ.vals) + for name, val in typ.vals.items(): + ret += f"#define {prefix}{name:<{namewidth}} (({c_typename(typ)})UINT{typ.static_size*8}_C({val}))\n" + case idl.Bitfield(): + ret += f"typedef {c_typename(typ.prim)} {c_typename(typ)};\n" + + def bitname(val: idl.Bit | idl.BitAlias) -> str: + s = val.bitname + match val: + case idl.Bit(cat=idl.BitCat.RESERVED): + s = "_RESERVED_" + s + case idl.Bit(cat=idl.BitCat.SUBFIELD): + assert isinstance(typ, idl.Bitfield) + n = sum( + 1 + for b in typ.bits[: val.num] + if b.cat == idl.BitCat.SUBFIELD + and b.bitname == val.bitname + ) + s = f"_{s}_{n}" + case idl.Bit(cat=idl.BitCat.UNUSED): + return "" + return add_prefix(f"{idprefix.upper()}{typ.typname.upper()}_", s) + + namewidth = max( + len(bitname(val)) for val in [*typ.bits, *typ.names.values()] + ) + + ret += "\n" + for bit in reversed(typ.bits): + vers = bit.in_versions + if bit.cat == idl.BitCat.UNUSED: + vers = typ.in_versions + ret += ifdef_push(2, c_ver_ifdef(vers)) + + # It is important all of the `beg` strings have + # the same length. + end = "" + match bit.cat: + case ( + idl.BitCat.USED | idl.BitCat.RESERVED | idl.BitCat.SUBFIELD + ): + if _ifdef_stack[-1]: + beg = "# define" + else: + beg = "#define " + case idl.BitCat.UNUSED: + beg = "/* unused" + end = " */" + + c_name = bitname(bit) + c_val = f"1<<{bit.num}" + ret += f"{beg} {c_name:<{namewidth}} (({c_typename(typ)})({c_val})){end}\n" + if aliases := [ + alias + for alias in typ.names.values() + if isinstance(alias, idl.BitAlias) + ]: + ret += "\n" + + for alias in aliases: + ret += ifdef_push(2, c_ver_ifdef(alias.in_versions)) + + end = "" + if _ifdef_stack[-1]: + beg = "# define" + else: + beg = "#define " + + c_name = bitname(alias) + c_val = alias.val + ret += f"{beg} {c_name:<{namewidth}} (({c_typename(typ)})({c_val})){end}\n" + ret += ifdef_pop(1) + del bitname + case idl.Struct(): # and idl.Message(): + ret += c_typename(typ) + " {" + if not typ.members: + ret += "};\n" + continue + ret += "\n" + + typewidth = max(len(c_typename(m.typ, m)) for m in typ.members) + + for member in typ.members: + if member.val: + continue + ret += ifdef_push(2, c_ver_ifdef(member.in_versions)) + ret += f"\t{c_typename(member.typ, member):<{typewidth}} {'*' if member.cnt else ' '}{member.membname};\n" + ret += ifdef_pop(1) + ret += "};\n" + del typ + ret += ifdef_pop(0) + + ret += """ +/* containers *****************************************************************/ +""" + ret += "\n" + ret += f"#define _{idprefix.upper()}MAX(a, b) ((a) > (b)) ? (a) : (b)\n" + + tmsg_max_iov: dict[str, int] = {} + tmsg_max_copy: dict[str, int] = {} + rmsg_max_iov: dict[str, int] = {} + rmsg_max_copy: dict[str, int] = {} + for typ in typs: + if not isinstance(typ, idl.Message): + continue + if typ.typname in ("Tsread", "Tswrite"): # SPECIAL (9P2000.e) + continue + max_iov = tmsg_max_iov if typ.msgid % 2 == 0 else rmsg_max_iov + max_copy = tmsg_max_copy if typ.msgid % 2 == 0 else rmsg_max_copy + for version in typ.in_versions: + if version not in max_iov: + max_iov[version] = 0 + max_copy[version] = 0 + sz = get_buffer_size(typ, version) + if sz.max_iov > max_iov[version]: + max_iov[version] = sz.max_iov + if sz.max_copy > max_copy[version]: + max_copy[version] = sz.max_copy + + for name, table in [ + ("tmsg_max_iov", tmsg_max_iov), + ("tmsg_max_copy", tmsg_max_copy), + ("rmsg_max_iov", rmsg_max_iov), + ("rmsg_max_copy", rmsg_max_copy), + ]: + inv: dict[int, set[str]] = {} + for version, maxval in table.items(): + if maxval not in inv: + inv[maxval] = set() + inv[maxval].add(version) + + ret += "\n" + directive = "if" + seen_e = False # SPECIAL (9P2000.e) + for maxval in sorted(inv, reverse=True): + ret += f"#{directive} {c_ver_ifdef(inv[maxval])}\n" + indent = 1 + if name.startswith("tmsg") and not seen_e: # SPECIAL (9P2000.e) + typ = next(typ for typ in typs if typ.typname == "Tswrite") + sz = get_buffer_size(typ, "9P2000.e") + match name: + case "tmsg_max_iov": + maxexpr = f"{sz.max_iov}{sz.max_iov_extra}" + case "tmsg_max_copy": + maxexpr = f"{sz.max_copy}{sz.max_copy_extra}" + case _: + assert False + ret += f"\t#if {c_ver_ifdef({"9P2000.e"})}\n" + ret += f"\t\t#define {idprefix.upper()}{name.upper()} _{idprefix.upper()}MAX({maxval}, {maxexpr})\n" + ret += "\t#else\n" + indent += 1 + ret += f"{'\t'*indent}#define {idprefix.upper()}{name.upper()} {maxval}\n" + if name.startswith("tmsg") and not seen_e: # SPECIAL (9P2000.e) + ret += "\t#endif\n" + if "9P2000.e" in inv[maxval]: + seen_e = True + directive = "elif" + ret += "#endif\n" + + ret += "\n" + ret += f"struct {idprefix}Tmsg_send_buf {{\n" + ret += "\tsize_t iov_cnt;\n" + ret += f"\tstruct iovec iov[{idprefix.upper()}TMSG_MAX_IOV];\n" + ret += f"\tuint8_t copied[{idprefix.upper()}TMSG_MAX_COPY];\n" + ret += "};\n" + + ret += "\n" + ret += f"struct {idprefix}Rmsg_send_buf {{\n" + ret += "\tsize_t iov_cnt;\n" + ret += f"\tstruct iovec iov[{idprefix.upper()}RMSG_MAX_IOV];\n" + ret += f"\tuint8_t copied[{idprefix.upper()}RMSG_MAX_COPY];\n" + ret += "};\n" + + return ret + + +# Generate .c ################################################################## + + +def gen_c(versions: set[str], typs: list[idl.UserType]) -> str: + global _ifdef_stack + _ifdef_stack = [] + + ret = f"""/* Generated by `{' '.join(sys.argv)}`. DO NOT EDIT! */ + +#include +#include /* for size_t */ +#include /* for PRI* macros */ +#include /* for memset() */ + +#include + +#include + +#include "internal.h" +""" + + # utilities ################################################################ + ret += """ +/* utilities ******************************************************************/ +""" + + def used(arg: str) -> str: + return arg + + def unused(arg: str) -> str: + return f"LM_UNUSED({arg})" + + id2typ: dict[int, idl.Message] = {} + for msg in [msg for msg in typs if isinstance(msg, idl.Message)]: + id2typ[msg.msgid] = msg + + def msg_table(grp: str, meth: str, tentry: str, rng: tuple[int, int, int]) -> str: + ret = f"const {tentry} _{idprefix}table_{grp}_{meth}[{c_ver_enum('NUM')}][{hex(len(range(*rng)))}] = {{\n" + for ver in ["unknown", *sorted(versions)]: + if ver != "unknown": + ret += ifdef_push(1, c_ver_ifdef({ver})) + ret += f"\t[{c_ver_enum(ver)}] = {{\n" + for n in range(*rng): + xmsg: idl.Message | None = id2typ.get(n, None) + if xmsg: + if ver == "unknown": # SPECIAL (initialization) + if xmsg.typname not in ["Tversion", "Rversion", "Rerror"]: + xmsg = None + else: + if ver not in xmsg.in_versions: + xmsg = None + if xmsg: + ret += f"\t\t_MSG_{meth.upper()}({xmsg.typname}),\n" + ret += "\t},\n" + ret += ifdef_pop(0) + ret += "};\n" + return ret + + for v in sorted(versions): + ret += f"#if CONFIG_9P_ENABLE_{v.replace('.', '_')}\n" + ret += f"\t#define _is_ver_{v.replace('.', '_')}(v) (v == {c_ver_enum(v)})\n" + ret += "#else\n" + ret += f"\t#define _is_ver_{v.replace('.', '_')}(v) false\n" + ret += "#endif\n" + ret += "\n" + ret += "/**\n" + ret += f" * is_ver(ctx, ver) is essentially `(ctx->ctx->version == {idprefix.upper()}VER_##ver)`,\n" + ret += f" * but compiles correctly (to `false`) even if `{idprefix.upper()}VER_##ver` isn't defined\n" + ret += " * (because `!CONFIG_9P_ENABLE_##ver`). This is useful when `||`ing\n" + ret += " * several version checks together.\n" + ret += " */\n" + ret += "#define is_ver(CTX, ver) _is_ver_##ver(CTX->ctx->version)\n" + + # strings ################################################################## + ret += f""" +/* strings ********************************************************************/ + +const char *const _{idprefix}table_ver_name[{c_ver_enum('NUM')}] = {{ +""" + for ver in ["unknown", *sorted(versions)]: + if ver in versions: + ret += ifdef_push(1, c_ver_ifdef({ver})) + ret += f'\t[{c_ver_enum(ver)}] = "{ver}",\n' + ret += ifdef_pop(0) + ret += "};\n" + + ret += "\n" + ret += f"#define _MSG_NAME(typ) [{idprefix.upper()}TYP_##typ] = #typ\n" + ret += msg_table("msg", "name", "char *const", (0, 0x100, 1)) + + # bitmasks ################################################################# + ret += """ +/* bitmasks *******************************************************************/ +""" + for typ in typs: + if not isinstance(typ, idl.Bitfield): + continue + ret += "\n" + ret += ifdef_push(1, c_ver_ifdef(typ.in_versions)) + ret += f"static const {c_typename(typ)} {typ.typname}_masks[{c_ver_enum('NUM')}] = {{\n" + verwidth = max(len(ver) for ver in versions) + for ver in sorted(versions): + ret += ifdef_push(2, c_ver_ifdef({ver})) + ret += ( + f"\t[{c_ver_enum(ver)}]{' '*(verwidth-len(ver))} = 0b" + + "".join( + ( + "1" + if bit.cat in (idl.BitCat.USED, idl.BitCat.SUBFIELD) + and ver in bit.in_versions + else "0" + ) + for bit in reversed(typ.bits) + ) + + ",\n" + ) + ret += ifdef_pop(1) + ret += "};\n" + ret += ifdef_pop(0) + + # validate_* ############################################################### + ret += """ +/* validate_* *****************************************************************/ + +LM_ALWAYS_INLINE static bool _validate_size_net(struct _validate_ctx *ctx, uint32_t n) { +\tif (__builtin_add_overflow(ctx->net_offset, n, &ctx->net_offset)) +\t\t/* If needed-net-size overflowed uint32_t, then +\t\t * there's no way that actual-net-size will live up to +\t\t * that. */ +\t\treturn lib9p_error(ctx->ctx, LINUX_EBADMSG, "message is too short for content"); +\tif (ctx->net_offset > ctx->net_size) +\t\treturn lib9p_error(ctx->ctx, LINUX_EBADMSG, "message is too short for content"); +\treturn false; +} + +LM_ALWAYS_INLINE static bool _validate_size_host(struct _validate_ctx *ctx, size_t n) { +\tif (__builtin_add_overflow(ctx->host_extra, n, &ctx->host_extra)) +\t\t/* If needed-host-size overflowed size_t, then there's +\t\t * no way that actual-net-size will live up to +\t\t * that. */ +\t\treturn lib9p_error(ctx->ctx, LINUX_EBADMSG, "message is too short for content"); +\treturn false; +} + +LM_ALWAYS_INLINE static bool _validate_list(struct _validate_ctx *ctx, + size_t cnt, + _validate_fn_t item_fn, size_t item_host_size) { +\tfor (size_t i = 0; i < cnt; i++) +\t\tif (_validate_size_host(ctx, item_host_size) || item_fn(ctx)) +\t\t\treturn true; +\treturn false; +} + +LM_ALWAYS_INLINE static bool validate_1(struct _validate_ctx *ctx) { return _validate_size_net(ctx, 1); } +LM_ALWAYS_INLINE static bool validate_2(struct _validate_ctx *ctx) { return _validate_size_net(ctx, 2); } +LM_ALWAYS_INLINE static bool validate_4(struct _validate_ctx *ctx) { return _validate_size_net(ctx, 4); } +LM_ALWAYS_INLINE static bool validate_8(struct _validate_ctx *ctx) { return _validate_size_net(ctx, 8); } +""" + + def should_save_value(typ: idl.Struct, member: idl.StructMember) -> bool: + return bool( + member.max or member.val or any(m.cnt == member for m in typ.members) + ) + + for typ in topo_sorted(typs): + inline = "LM_FLATTEN" if isinstance(typ, idl.Message) else "LM_ALWAYS_INLINE" + argfn = unused if (isinstance(typ, idl.Struct) and not typ.members) else used + ret += "\n" + ret += ifdef_push(1, c_ver_ifdef(typ.in_versions)) + ret += f"{inline} static bool validate_{typ.typname}(struct _validate_ctx *{argfn('ctx')}) {{\n" + + match typ: + case idl.Number(): + ret += f"\treturn validate_{typ.prim.typname}(ctx);\n" + case idl.Bitfield(): + ret += f"\t if (validate_{typ.static_size}(ctx))\n" + ret += "\t\treturn true;\n" + ret += f"\t{c_typename(typ)} mask = {typ.typname}_masks[ctx->ctx->version];\n" + if typ.static_size == 1: + ret += f"\t{c_typename(typ)} val = ctx->net_bytes[ctx->net_offset-1];\n" + else: + ret += f"\t{c_typename(typ)} val = uint{typ.static_size*8}le_decode(&ctx->net_bytes[ctx->net_offset-{typ.static_size}]);\n" + ret += "\tif (val & ~mask)\n" + ret += f'\t\treturn lib9p_errorf(ctx->ctx, LINUX_EBADMSG, "unknown bits in {typ.typname} bitfield: %#0{typ.static_size}"PRIx{typ.static_size*8}, val & ~mask);\n' + ret += "\treturn false;\n" + case idl.Struct(): # and idl.Message() + if len(typ.members) == 0: + ret += "\treturn false;\n" + ret += "}\n" + continue + + # Pass 1 - declare value variables + for member in typ.members: + if should_save_value(typ, member): + ret += ifdef_push(2, c_ver_ifdef(member.in_versions)) + ret += f"\t{c_typename(member.typ)} {member.membname};\n" + ret += ifdef_pop(1) + + # Pass 2 - declare offset variables + mark_offset: set[str] = set() + for member in typ.members: + for tok in [*member.max.tokens, *member.val.tokens]: + if isinstance(tok, idl.ExprSym) and tok.symname.startswith("&"): + if tok.symname[1:] not in mark_offset: + ret += f"\tuint32_t _{tok.symname[1:]}_offset;\n" + mark_offset.add(tok.symname[1:]) + + # Pass 3 - main pass + ret += "\treturn false\n" + for member in typ.members: + ret += ifdef_push(2, c_ver_ifdef(member.in_versions)) + ret += "\t || " + if member.in_versions != typ.in_versions: + ret += "( " + c_ver_cond(member.in_versions) + " && " + if member.cnt is not None: + if member.typ.static_size == 1: # SPECIAL (zerocopy) + ret += f"_validate_size_net(ctx, {member.cnt.membname})" + else: + ret += f"_validate_list(ctx, {member.cnt.membname}, validate_{member.typ.typname}, sizeof({c_typename(member.typ)}))" + if typ.typname == "s": # SPECIAL (string) + ret += '\n\t || ({ (!is_valid_utf8_without_nul(&ctx->net_bytes[ctx->net_offset-len], len)) && lib9p_error(ctx->ctx, LINUX_EBADMSG, "message contains invalid UTF-8"); })' + else: + if should_save_value(typ, member): + ret += "(" + if member.membname in mark_offset: + ret += f"({{ _{member.membname}_offset = ctx->net_offset; " + ret += f"validate_{member.typ.typname}(ctx)" + if member.membname in mark_offset: + ret += "; })" + if should_save_value(typ, member): + nbytes = member.static_size + assert nbytes + if nbytes == 1: + ret += f" || ({{ {member.membname} = ctx->net_bytes[ctx->net_offset-1]; false; }}))" + else: + ret += f" || ({{ {member.membname} = uint{nbytes*8}le_decode(&ctx->net_bytes[ctx->net_offset-{nbytes}]); false; }}))" + if member.in_versions != typ.in_versions: + ret += " )" + ret += "\n" + + # Pass 4 - validate ,max= and ,val= constraints + for member in typ.members: + + def lookup_sym(sym: str) -> str: + match sym: + case "end": + return "ctx->net_offset" + case _: + assert sym.startswith("&") + return f"_{sym[1:]}_offset" + + if member.max: + assert member.static_size + nbits = member.static_size * 8 + ret += ifdef_push(2, c_ver_ifdef(member.in_versions)) + ret += f"\t || ({{ uint{nbits}_t max = {c_expr(member.max, lookup_sym)}; (((uint{nbits}_t){member.membname}) > max) &&\n" + ret += f'\t lib9p_errorf(ctx->ctx, LINUX_EBADMSG, "{member.membname} value is too large (%"PRIu{nbits}" > %"PRIu{nbits}")", {member.membname}, max); }})\n' + if member.val: + assert member.static_size + nbits = member.static_size * 8 + ret += ifdef_push(2, c_ver_ifdef(member.in_versions)) + ret += f"\t || ({{ uint{nbits}_t exp = {c_expr(member.val, lookup_sym)}; (((uint{nbits}_t){member.membname}) != exp) &&\n" + ret += f'\t lib9p_errorf(ctx->ctx, LINUX_EBADMSG, "{member.membname} value is wrong (actual:%"PRIu{nbits}" != correct:%"PRIu{nbits}")", (uint{nbits}_t){member.membname}, exp); }})\n' + + ret += ifdef_pop(1) + ret += "\t ;\n" + ret += "}\n" + ret += ifdef_pop(0) + + # unmarshal_* ############################################################## + ret += """ +/* unmarshal_* ****************************************************************/ + +LM_ALWAYS_INLINE static void unmarshal_1(struct _unmarshal_ctx *ctx, uint8_t *out) { +\t*out = ctx->net_bytes[ctx->net_offset]; +\tctx->net_offset += 1; +} + +LM_ALWAYS_INLINE static void unmarshal_2(struct _unmarshal_ctx *ctx, uint16_t *out) { +\t*out = uint16le_decode(&ctx->net_bytes[ctx->net_offset]); +\tctx->net_offset += 2; +} + +LM_ALWAYS_INLINE static void unmarshal_4(struct _unmarshal_ctx *ctx, uint32_t *out) { +\t*out = uint32le_decode(&ctx->net_bytes[ctx->net_offset]); +\tctx->net_offset += 4; +} + +LM_ALWAYS_INLINE static void unmarshal_8(struct _unmarshal_ctx *ctx, uint64_t *out) { +\t*out = uint64le_decode(&ctx->net_bytes[ctx->net_offset]); +\tctx->net_offset += 8; +} +""" + for typ in topo_sorted(typs): + inline = "LM_FLATTEN" if isinstance(typ, idl.Message) else "LM_ALWAYS_INLINE" + argfn = unused if (isinstance(typ, idl.Struct) and not typ.members) else used + ret += "\n" + ret += ifdef_push(1, c_ver_ifdef(typ.in_versions)) + ret += f"{inline} static void unmarshal_{typ.typname}(struct _unmarshal_ctx *{argfn('ctx')}, {c_typename(typ)} *out) {{\n" + match typ: + case idl.Number(): + ret += f"\tunmarshal_{typ.prim.typname}(ctx, ({c_typename(typ.prim)} *)out);\n" + case idl.Bitfield(): + ret += f"\tunmarshal_{typ.prim.typname}(ctx, ({c_typename(typ.prim)} *)out);\n" + case idl.Struct(): + ret += "\tmemset(out, 0, sizeof(*out));\n" + + for member in typ.members: + ret += ifdef_push(2, c_ver_ifdef(member.in_versions)) + if member.val: + ret += f"\tctx->net_offset += {member.static_size};\n" + continue + ret += "\t" + + prefix = "\t" + if member.in_versions != typ.in_versions: + ret += "if ( " + c_ver_cond(member.in_versions) + " ) " + prefix = "\t\t" + if member.cnt: + if member.in_versions != typ.in_versions: + ret += "{\n" + ret += prefix + if member.typ.static_size == 1: # SPECIAL (string, zerocopy) + ret += f"out->{member.membname} = (char *)&ctx->net_bytes[ctx->net_offset];\n" + ret += f"{prefix}ctx->net_offset += out->{member.cnt.membname};\n" + else: + ret += f"out->{member.membname} = ctx->extra;\n" + ret += f"{prefix}ctx->extra += sizeof(out->{member.membname}[0]) * out->{member.cnt.membname};\n" + ret += f"{prefix}for (typeof(out->{member.cnt.membname}) i = 0; i < out->{member.cnt.membname}; i++)\n" + ret += f"{prefix}\tunmarshal_{member.typ.typname}(ctx, &out->{member.membname}[i]);\n" + if member.in_versions != typ.in_versions: + ret += "\t}\n" + else: + ret += f"unmarshal_{member.typ.typname}(ctx, &out->{member.membname});\n" + ret += ifdef_pop(1) + ret += "}\n" + ret += ifdef_pop(0) + + # marshal_* ################################################################ + ret += """ +/* marshal_* ******************************************************************/ + +""" + ret += c_macro( + "#define MARSHAL_BYTES_ZEROCOPY(ctx, data, len)\n" + "\tif (ctx->net_iov[ctx->net_iov_cnt-1].iov_len)\n" + "\t\tctx->net_iov_cnt++;\n" + "\tctx->net_iov[ctx->net_iov_cnt-1].iov_base = data;\n" + "\tctx->net_iov[ctx->net_iov_cnt-1].iov_len = len;\n" + "\tctx->net_iov_cnt++;\n" + ) + ret += c_macro( + "#define MARSHAL_BYTES(ctx, data, len)\n" + "\tif (!ctx->net_iov[ctx->net_iov_cnt-1].iov_base)\n" + "\t\tctx->net_iov[ctx->net_iov_cnt-1].iov_base = &ctx->net_copied[ctx->net_copied_size];\n" + "\tmemcpy(&ctx->net_copied[ctx->net_copied_size], data, len);\n" + "\tctx->net_copied_size += len;\n" + "\tctx->net_iov[ctx->net_iov_cnt-1].iov_len += len;\n" + ) + ret += c_macro( + "#define MARSHAL_U8LE(ctx, val)\n" + "\tif (!ctx->net_iov[ctx->net_iov_cnt-1].iov_base)\n" + "\t\tctx->net_iov[ctx->net_iov_cnt-1].iov_base = &ctx->net_copied[ctx->net_copied_size];\n" + "\tctx->net_copied[ctx->net_copied_size] = val;\n" + "\tctx->net_copied_size += 1;\n" + "\tctx->net_iov[ctx->net_iov_cnt-1].iov_len += 1;\n" + ) + ret += c_macro( + "#define MARSHAL_U16LE(ctx, val)\n" + "\tif (!ctx->net_iov[ctx->net_iov_cnt-1].iov_base)\n" + "\t\tctx->net_iov[ctx->net_iov_cnt-1].iov_base = &ctx->net_copied[ctx->net_copied_size];\n" + "\tuint16le_encode(&ctx->net_copied[ctx->net_copied_size], val);\n" + "\tctx->net_copied_size += 2;\n" + "\tctx->net_iov[ctx->net_iov_cnt-1].iov_len += 2;\n" + ) + ret += c_macro( + "#define MARSHAL_U32LE(ctx, val)\n" + "\tif (!ctx->net_iov[ctx->net_iov_cnt-1].iov_base)\n" + "\t\tctx->net_iov[ctx->net_iov_cnt-1].iov_base = &ctx->net_copied[ctx->net_copied_size];\n" + "\tuint32le_encode(&ctx->net_copied[ctx->net_copied_size], val);\n" + "\tctx->net_copied_size += 4;\n" + "\tctx->net_iov[ctx->net_iov_cnt-1].iov_len += 4;\n" + ) + ret += c_macro( + "#define MARSHAL_U64LE(ctx, val)\n" + "\tif (!ctx->net_iov[ctx->net_iov_cnt-1].iov_base)\n" + "\t\tctx->net_iov[ctx->net_iov_cnt-1].iov_base = &ctx->net_copied[ctx->net_copied_size];\n" + "\tuint64le_encode(&ctx->net_copied[ctx->net_copied_size], val);\n" + "\tctx->net_copied_size += 8;\n" + "\tctx->net_iov[ctx->net_iov_cnt-1].iov_len += 8;\n" + ) + + class OffsetExpr: + static: int + cond: dict[frozenset[str], "OffsetExpr"] + rep: list[tuple[Path, "OffsetExpr"]] + + def __init__(self) -> None: + self.static = 0 + self.rep = [] + self.cond = {} + + def add(self, other: "OffsetExpr") -> None: + self.static += other.static + self.rep += other.rep + for k, v in other.cond.items(): + if k in self.cond: + self.cond[k].add(v) + else: + self.cond[k] = v + + def gen_c( + self, + dsttyp: str, + dstvar: str, + root: str, + indent_depth: int, + loop_depth: int, + ) -> str: + oneline: list[str] = [] + multiline = "" + if self.static: + oneline.append(str(self.static)) + for cnt, sub in self.rep: + if not sub.cond and not sub.rep: + if sub.static == 1: + oneline.append(cnt.c_str(root)) + else: + oneline.append(f"({cnt.c_str(root)})*{sub.static}") + continue + loopvar = chr(ord("i") + loop_depth) + multiline += f"{'\t'*indent_depth}for ({c_typename(cnt.elems[-1].typ)} {loopvar} = 0; {loopvar} < {cnt.c_str(root)}; {loopvar}++) {{\n" + multiline += sub.gen_c( + "", dstvar, root, indent_depth + 1, loop_depth + 1 + ) + multiline += f"{'\t'*indent_depth}}}\n" + for vers, sub in self.cond.items(): + multiline += ifdef_push(indent_depth + 1, c_ver_ifdef(vers)) + multiline += f"{'\t'*indent_depth}if {c_ver_cond(vers)} {{\n" + multiline += sub.gen_c("", dstvar, root, indent_depth + 1, loop_depth) + multiline += f"{'\t'*indent_depth}}}\n" + multiline += ifdef_pop(indent_depth) + if dsttyp: + if not oneline: + oneline.append("0") + ret = f"{'\t'*indent_depth}{dsttyp} {dstvar} = {' + '.join(oneline)};\n" + elif oneline: + ret = f"{'\t'*indent_depth}{dstvar} += {' + '.join(oneline)};\n" + ret += multiline + return ret + + type OffsetExprRecursion = typing.Callable[[Path], WalkCmd] + + def get_offset_expr(typ: idl.UserType, recurse: OffsetExprRecursion) -> OffsetExpr: + if not isinstance(typ, idl.Struct): + assert typ.static_size + ret = OffsetExpr() + ret.static = typ.static_size + return ret + + stack: list[tuple[Path, OffsetExpr, typing.Callable[[], None]]] + + def pop_root() -> None: + assert False + + def pop_cond() -> None: + nonlocal stack + key = frozenset(stack[-1][0].elems[-1].in_versions) + if key in stack[-2][1].cond: + stack[-2][1].cond[key].add(stack[-1][1]) + else: + stack[-2][1].cond[key] = stack[-1][1] + stack = stack[:-1] + + def pop_rep() -> None: + nonlocal stack + member_path = stack[-1][0] + member = member_path.elems[-1] + assert member.cnt + cnt_path = member_path.parent().add(member.cnt) + stack[-2][1].rep.append((cnt_path, stack[-1][1])) + stack = stack[:-1] + + def handle(path: Path) -> tuple[WalkCmd, typing.Callable[[], None] | None]: + nonlocal recurse + + ret = recurse(path) + if ret != WalkCmd.KEEP_GOING: + return ret, None + + nonlocal stack + stack_len = len(stack) + + def pop() -> None: + nonlocal stack + nonlocal stack_len + while len(stack) > stack_len: + stack[-1][2]() + + if path.elems: + child = path.elems[-1] + parent = path.elems[-2].typ if len(path.elems) > 1 else path.root + if child.in_versions < parent.in_versions: + stack.append((path, OffsetExpr(), pop_cond)) + if child.cnt: + stack.append((path, OffsetExpr(), pop_rep)) + if not isinstance(child.typ, idl.Struct): + assert child.typ.static_size + stack[-1][1].static += child.typ.static_size + return ret, pop + + stack = [(Path(typ), OffsetExpr(), pop_root)] + walk(typ, handle) + return stack[0][1] + + def go_to_end(path: Path) -> WalkCmd: + return WalkCmd.KEEP_GOING + + def go_to_tok(name: str) -> typing.Callable[[Path], WalkCmd]: + def ret(path: Path) -> WalkCmd: + if len(path.elems) == 1 and path.elems[0].membname == name: + return WalkCmd.ABORT + return WalkCmd.KEEP_GOING + + return ret + + for typ in typs: + if not ( + isinstance(typ, idl.Message) or typ.typname == "stat" + ): # SPECIAL (include stat) + continue + assert isinstance(typ, idl.Struct) + ret += "\n" + ret += ifdef_push(1, c_ver_ifdef(typ.in_versions)) + ret += f"static bool marshal_{typ.typname}(struct _marshal_ctx *ctx, {c_typename(typ)} *val) {{\n" + + # Pass 1 - check size + max_size = max(typ.max_size(v) for v in typ.in_versions) + + if max_size > u32max: # SPECIAL (9P2000.e) + ret += get_offset_expr(typ, go_to_end).gen_c( + "uint64_t", "needed_size", "val->", 1, 0 + ) + ret += "\tif (needed_size > (uint64_t)(ctx->ctx->max_msg_size)) {\n" + else: + ret += get_offset_expr(typ, go_to_end).gen_c( + "uint32_t", "needed_size", "val->", 1, 0 + ) + ret += "\tif (needed_size > ctx->ctx->max_msg_size) {\n" + if isinstance(typ, idl.Message): # SPECIAL (disable for stat) + ret += '\t\tlib9p_errorf(ctx->ctx, LINUX_ERANGE, "%s message too large to marshal into %s limit (limit=%"PRIu32")",\n' + ret += f'\t\t\t"{typ.typname}",\n' + ret += f'\t\t\tctx->ctx->version ? "negotiated" : "{'client' if typ.msgid % 2 == 0 else 'server'}",\n' + ret += "\t\t\tctx->ctx->max_msg_size);\n" + ret += "\t\treturn true;\n" + ret += "\t}\n" + + # Pass 2 - write data + ifdef_depth = 1 + stack: list[tuple[Path, bool]] = [(Path(typ), False)] + + def handle(path: Path) -> tuple[WalkCmd, typing.Callable[[], None]]: + nonlocal ret + nonlocal ifdef_depth + nonlocal stack + stack_len = len(stack) + + def pop() -> None: + nonlocal ret + nonlocal ifdef_depth + nonlocal stack + nonlocal stack_len + while len(stack) > stack_len: + ret += f"{'\t'*(len(stack)-1)}}}\n" + if stack[-1][1]: + ifdef_depth -= 1 + ret += ifdef_pop(ifdef_depth) + stack = stack[:-1] + + loopdepth = sum(1 for elem in path.elems if elem.cnt) + struct = path.elems[-1].typ if path.elems else path.root + if isinstance(struct, idl.Struct): + offsets: list[str] = [] + for member in struct.members: + if not member.val: + continue + for tok in member.val.tokens: + if not isinstance(tok, idl.ExprSym): + continue + if tok.symname == "end" or tok.symname.startswith("&"): + if tok.symname not in offsets: + offsets.append(tok.symname) + for name in offsets: + name_prefix = "offsetof_" + "".join( + m.membname + "_" for m in path.elems + ) + if name == "end": + if not path.elems: + nonlocal max_size + if max_size > u32max: + ret += f"{'\t'*len(stack)}uint32_t {name_prefix}end = (uint32_t)needed_size;\n" + else: + ret += f"{'\t'*len(stack)}uint32_t {name_prefix}end = needed_size;\n" + continue + recurse: OffsetExprRecursion = go_to_end + else: + assert name.startswith("&") + name = name[1:] + recurse = go_to_tok(name) + expr = get_offset_expr(struct, recurse) + expr_prefix = path.c_str("val->", loopdepth) + if not expr_prefix.endswith(">"): + expr_prefix += "." + ret += expr.gen_c( + "uint32_t", + name_prefix + name, + expr_prefix, + len(stack), + loopdepth, + ) + if path.elems: + child = path.elems[-1] + parent = path.elems[-2].typ if len(path.elems) > 1 else path.root + if child.in_versions < parent.in_versions: + ret += ifdef_push(ifdef_depth + 1, c_ver_ifdef(child.in_versions)) + ifdef_depth += 1 + ret += f"{'\t'*len(stack)}if ({c_ver_cond(child.in_versions)}) {{\n" + stack.append((path, True)) + if child.cnt: + cnt_path = path.parent().add(child.cnt) + if child.typ.static_size == 1: # SPECIAL (zerocopy) + if path.root.typname == "stat": # SPECIAL (stat) + ret += f"{'\t'*len(stack)}MARSHAL_BYTES(ctx, {path.c_str('val->')[:-3]}, {cnt_path.c_str('val->')});\n" + else: + ret += f"{'\t'*len(stack)}MARSHAL_BYTES_ZEROCOPY(ctx, {path.c_str('val->')[:-3]}, {cnt_path.c_str('val->')});\n" + return WalkCmd.KEEP_GOING, pop + loopvar = chr(ord("i") + loopdepth - 1) + ret += f"{'\t'*len(stack)}for ({c_typename(child.cnt.typ)} {loopvar} = 0; {loopvar} < {cnt_path.c_str('val->')}; {loopvar}++) {{\n" + stack.append((path, False)) + if not isinstance(child.typ, idl.Struct): + if child.val: + + def lookup_sym(sym: str) -> str: + nonlocal path + if sym.startswith("&"): + sym = sym[1:] + return ( + "offsetof_" + + "".join(m.membname + "_" for m in path.elems[:-1]) + + sym + ) + + val = c_expr(child.val, lookup_sym) + else: + val = path.c_str("val->") + if isinstance(child.typ, idl.Bitfield): + val += f" & {child.typ.typname}_masks[ctx->ctx->version]" + ret += f"{'\t'*len(stack)}MARSHAL_U{child.typ.static_size*8}LE(ctx, {val});\n" + return WalkCmd.KEEP_GOING, pop + + walk(typ, handle) + del handle + del stack + del max_size + + ret += "\treturn false;\n" + ret += "}\n" + ret += ifdef_pop(0) + + # function tables ########################################################## + ret += """ +/* function tables ************************************************************/ +""" + + ret += "\n" + ret += f"const uint32_t _{idprefix}table_msg_min_size[{c_ver_enum('NUM')}] = {{\n" + rerror = next(typ for typ in typs if typ.typname == "Rerror") + ret += f"\t[{c_ver_enum('unknown')}] = {rerror.min_size('9P2000')},\n" # SPECIAL (initialization) + for ver in sorted(versions): + ret += ifdef_push(1, c_ver_ifdef({ver})) + ret += f"\t[{c_ver_enum(ver)}] = {rerror.min_size(ver)},\n" + ret += ifdef_pop(0) + ret += "};\n" + + ret += "\n" + ret += c_macro( + f"#define _MSG_RECV(typ) [{idprefix.upper()}TYP_##typ/2] = {{\n" + f"\t\t.basesize = sizeof(struct {idprefix}msg_##typ),\n" + f"\t\t.validate = validate_##typ,\n" + f"\t\t.unmarshal = (_unmarshal_fn_t)unmarshal_##typ,\n" + f"\t}}\n" + ) + ret += c_macro( + f"#define _MSG_SEND(typ) [{idprefix.upper()}TYP_##typ/2] = {{\n" + f"\t\t.marshal = (_marshal_fn_t)marshal_##typ,\n" + f"\t}}\n" + ) + ret += "\n" + ret += msg_table("Tmsg", "recv", f"struct _{idprefix}recv_tentry", (0, 0x100, 2)) + ret += "\n" + ret += msg_table("Rmsg", "recv", f"struct _{idprefix}recv_tentry", (1, 0x100, 2)) + ret += "\n" + ret += msg_table("Tmsg", "send", f"struct _{idprefix}send_tentry", (0, 0x100, 2)) + ret += "\n" + ret += msg_table("Rmsg", "send", f"struct _{idprefix}send_tentry", (1, 0x100, 2)) + + ret += f""" +LM_FLATTEN bool _{idprefix}stat_validate(struct _validate_ctx *ctx) {{ +\treturn validate_stat(ctx); +}} +LM_FLATTEN void _{idprefix}stat_unmarshal(struct _unmarshal_ctx *ctx, struct {idprefix}stat *out) {{ +\tunmarshal_stat(ctx, out); +}} +LM_FLATTEN bool _{idprefix}stat_marshal(struct _marshal_ctx *ctx, struct {idprefix}stat *val) {{ +\treturn marshal_stat(ctx, val); +}} +""" + + ############################################################################ + return ret + + +# Main ######################################################################### + + +def main() -> None: + if typing.TYPE_CHECKING: + + class ANSIColors: + MAGENTA = "\x1b[35m" + RED = "\x1b[31m" + RESET = "\x1b[0m" + + else: + from _colorize import ANSIColors # Present in Python 3.13+ + + if len(sys.argv) < 2: + raise ValueError("requires at least 1 .9p filename") + parser = idl.Parser() + for txtname in sys.argv[1:]: + try: + parser.parse_file(txtname) + except SyntaxError as e: + print( + f"{ANSIColors.RED}{e.filename}{ANSIColors.RESET}:{ANSIColors.MAGENTA}{e.lineno}{ANSIColors.RESET}: {e.msg}", + file=sys.stderr, + ) + assert e.text + print(f"\t{e.text}", file=sys.stderr) + text_suffix = e.text.lstrip() + text_prefix = e.text[: -len(text_suffix)] + print( + f"\t{text_prefix}{ANSIColors.RED}{'~'*len(text_suffix)}{ANSIColors.RESET}", + file=sys.stderr, + ) + sys.exit(2) + versions, typs = parser.all() + outdir = os.path.normpath(os.path.join(sys.argv[0], "..")) + with open( + os.path.join(outdir, "include/lib9p/9p.generated.h"), "w", encoding="utf-8" + ) as fh: + fh.write(gen_h(versions, typs)) + with open(os.path.join(outdir, "9p.generated.c"), "w", encoding="utf-8") as fh: + fh.write(gen_c(versions, typs)) -- cgit v1.2.3-2-g168b From c1a1f287ed883bed049627da0fd8395197ebf876 Mon Sep 17 00:00:00 2001 From: "Luke T. Shumaker" Date: Sun, 23 Mar 2025 00:53:44 -0600 Subject: lib9p: protogen: pull cutil.py out of __init__.py --- lib9p/protogen/__init__.py | 187 ++++++++++++++++----------------------------- 1 file changed, 64 insertions(+), 123 deletions(-) (limited to 'lib9p/protogen/__init__.py') diff --git a/lib9p/protogen/__init__.py b/lib9p/protogen/__init__.py index 21b5161..8a9a371 100644 --- a/lib9p/protogen/__init__.py +++ b/lib9p/protogen/__init__.py @@ -1,6 +1,5 @@ # lib9p/protogen/__init__.py - Generate C marshalers/unmarshalers for -# .9p files defining 9P protocol -# variants. +# .9p files defining 9P protocol variants # # Copyright (C) 2024-2025 Luke T. Shumaker # SPDX-License-Identifier: AGPL-3.0-or-later @@ -13,6 +12,8 @@ import typing import idl +from . import cutil + # This strives to be "general-purpose" in that it just acts on the # *.9p inputs; but (unfortunately?) there are a few special-cases in # this script, marked with "SPECIAL". @@ -25,16 +26,6 @@ __all__ = ["main"] idprefix = "lib9p_" -u32max = (1 << 32) - 1 -u64max = (1 << 64) - 1 - - -def tab_ljust(s: str, width: int) -> str: - cur = len(s.expandtabs(tabsize=8)) - if cur >= width: - return s - return s + " " * (width - cur) - def add_prefix(p: str, s: str) -> str: if s.startswith("_"): @@ -42,15 +33,6 @@ def add_prefix(p: str, s: str) -> str: return p + s -def c_macro(full: str) -> str: - full = full.rstrip() - assert "\n" in full - lines = [l.rstrip() for l in full.split("\n")] - width = max(len(l.expandtabs(tabsize=8)) for l in lines[:-1]) - lines = [tab_ljust(l, width) for l in lines] - return " \\\n".join(lines).rstrip() + "\n" - - def c_ver_enum(ver: str) -> str: return f"{idprefix.upper()}VER_{ver.replace('.', '_')}" @@ -105,47 +87,6 @@ def c_expr(expr: idl.Expr, lookup_sym: typing.Callable[[str], str]) -> str: return " ".join(ret) -_ifdef_stack: list[str | None] = [] - - -def ifdef_push(n: int, _newval: str) -> str: - # Grow the stack as needed - while len(_ifdef_stack) < n: - _ifdef_stack.append(None) - - # Set some variables - parentval: str | None = None - for x in _ifdef_stack[:-1]: - if x is not None: - parentval = x - oldval = _ifdef_stack[-1] - newval: str | None = _newval - if newval == parentval: - newval = None - - # Put newval on the stack. - _ifdef_stack[-1] = newval - - # Build output. - ret = "" - if newval != oldval: - if oldval is not None: - ret += f"#endif /* {oldval} */\n" - if newval is not None: - ret += f"#if {newval}\n" - return ret - - -def ifdef_pop(n: int) -> str: - global _ifdef_stack - ret = "" - while len(_ifdef_stack) > n: - if _ifdef_stack[-1] is not None: - ret += f"#endif /* {_ifdef_stack[-1]} */\n" - _ifdef_stack = _ifdef_stack[:-1] - return ret - - # topo_sorted() ################################################################ @@ -378,8 +319,7 @@ def get_buffer_size(typ: idl.Type, version: str) -> BufferSize: def gen_h(versions: set[str], typs: list[idl.UserType]) -> str: - global _ifdef_stack - _ifdef_stack = [] + cutil.ifdef_init() ret = f"""/* Generated by `{' '.join(sys.argv)}`. DO NOT EDIT! */ @@ -424,10 +364,10 @@ enum {idprefix}version {{ verwidth = max(len(v) for v in fullversions) for ver in fullversions: if ver in versions: - ret += ifdef_push(1, c_ver_ifdef({ver})) + ret += cutil.ifdef_push(1, c_ver_ifdef({ver})) ret += f"\t{c_ver_enum(ver)}," ret += (" " * (verwidth - len(ver))) + ' /* "' + ver.split()[0] + '" */\n' - ret += ifdef_pop(0) + ret += cutil.ifdef_pop(0) ret += f"\t{c_ver_enum('NUM')},\n" ret += "};\n" @@ -441,9 +381,9 @@ enum {idprefix}version {{ if n not in id2typ: continue msg = id2typ[n] - ret += ifdef_push(1, c_ver_ifdef(msg.in_versions)) + ret += cutil.ifdef_push(1, c_ver_ifdef(msg.in_versions)) ret += f"\t{idprefix.upper()}TYP_{msg.typname:<{namewidth}} = {msg.msgid},\n" - ret += ifdef_pop(0) + ret += cutil.ifdef_pop(0) ret += "};\n" ret += """ @@ -469,21 +409,21 @@ enum {idprefix}version {{ for typ in topo_sorted(typs): ret += "\n" - ret += ifdef_push(1, c_ver_ifdef(typ.in_versions)) + ret += cutil.ifdef_push(1, c_ver_ifdef(typ.in_versions)) def sum_size(typ: idl.UserType, version: str) -> str: sz = get_buffer_size(typ, version) assert ( sz.min_size <= sz.exp_size and sz.exp_size <= sz.max_size - and sz.max_size < u64max + and sz.max_size < cutil.UINT64_MAX ) ret = "" if sz.min_size == sz.max_size: ret += f"size = {sz.min_size:,}" else: ret += f"min_size = {sz.min_size:,} ; exp_size = {sz.exp_size:,} ; max_size = {sz.max_size:,}" - if sz.max_size > u32max: + if sz.max_size > cutil.UINT32_MAX: ret += " (warning: >UINT32_MAX)" ret += f" ; max_iov = {sz.max_iov:,}{sz.max_iov_extra} ; max_copy = {sz.max_copy:,}{sz.max_copy_extra}" return ret @@ -527,7 +467,7 @@ enum {idprefix}version {{ vers = bit.in_versions if bit.cat == idl.BitCat.UNUSED: vers = typ.in_versions - ret += ifdef_push(2, c_ver_ifdef(vers)) + ret += cutil.ifdef_push(2, c_ver_ifdef(vers)) # It is important all of the `beg` strings have # the same length. @@ -536,10 +476,10 @@ enum {idprefix}version {{ case ( idl.BitCat.USED | idl.BitCat.RESERVED | idl.BitCat.SUBFIELD ): - if _ifdef_stack[-1]: - beg = "# define" - else: + if cutil.ifdef_leaf_is_noop(): beg = "#define " + else: + beg = "# define" case idl.BitCat.UNUSED: beg = "/* unused" end = " */" @@ -555,18 +495,18 @@ enum {idprefix}version {{ ret += "\n" for alias in aliases: - ret += ifdef_push(2, c_ver_ifdef(alias.in_versions)) + ret += cutil.ifdef_push(2, c_ver_ifdef(alias.in_versions)) end = "" - if _ifdef_stack[-1]: - beg = "# define" - else: + if cutil.ifdef_leaf_is_noop(): beg = "#define " + else: + beg = "# define" c_name = bitname(alias) c_val = alias.val ret += f"{beg} {c_name:<{namewidth}} (({c_typename(typ)})({c_val})){end}\n" - ret += ifdef_pop(1) + ret += cutil.ifdef_pop(1) del bitname case idl.Struct(): # and idl.Message(): ret += c_typename(typ) + " {" @@ -580,12 +520,12 @@ enum {idprefix}version {{ for member in typ.members: if member.val: continue - ret += ifdef_push(2, c_ver_ifdef(member.in_versions)) + ret += cutil.ifdef_push(2, c_ver_ifdef(member.in_versions)) ret += f"\t{c_typename(member.typ, member):<{typewidth}} {'*' if member.cnt else ' '}{member.membname};\n" - ret += ifdef_pop(1) + ret += cutil.ifdef_pop(1) ret += "};\n" del typ - ret += ifdef_pop(0) + ret += cutil.ifdef_pop(0) ret += """ /* containers *****************************************************************/ @@ -675,8 +615,7 @@ enum {idprefix}version {{ def gen_c(versions: set[str], typs: list[idl.UserType]) -> str: - global _ifdef_stack - _ifdef_stack = [] + cutil.ifdef_init() ret = f"""/* Generated by `{' '.join(sys.argv)}`. DO NOT EDIT! */ @@ -711,7 +650,7 @@ def gen_c(versions: set[str], typs: list[idl.UserType]) -> str: ret = f"const {tentry} _{idprefix}table_{grp}_{meth}[{c_ver_enum('NUM')}][{hex(len(range(*rng)))}] = {{\n" for ver in ["unknown", *sorted(versions)]: if ver != "unknown": - ret += ifdef_push(1, c_ver_ifdef({ver})) + ret += cutil.ifdef_push(1, c_ver_ifdef({ver})) ret += f"\t[{c_ver_enum(ver)}] = {{\n" for n in range(*rng): xmsg: idl.Message | None = id2typ.get(n, None) @@ -725,7 +664,7 @@ def gen_c(versions: set[str], typs: list[idl.UserType]) -> str: if xmsg: ret += f"\t\t_MSG_{meth.upper()}({xmsg.typname}),\n" ret += "\t},\n" - ret += ifdef_pop(0) + ret += cutil.ifdef_pop(0) ret += "};\n" return ret @@ -752,9 +691,9 @@ const char *const _{idprefix}table_ver_name[{c_ver_enum('NUM')}] = {{ """ for ver in ["unknown", *sorted(versions)]: if ver in versions: - ret += ifdef_push(1, c_ver_ifdef({ver})) + ret += cutil.ifdef_push(1, c_ver_ifdef({ver})) ret += f'\t[{c_ver_enum(ver)}] = "{ver}",\n' - ret += ifdef_pop(0) + ret += cutil.ifdef_pop(0) ret += "};\n" ret += "\n" @@ -769,11 +708,11 @@ const char *const _{idprefix}table_ver_name[{c_ver_enum('NUM')}] = {{ if not isinstance(typ, idl.Bitfield): continue ret += "\n" - ret += ifdef_push(1, c_ver_ifdef(typ.in_versions)) + ret += cutil.ifdef_push(1, c_ver_ifdef(typ.in_versions)) ret += f"static const {c_typename(typ)} {typ.typname}_masks[{c_ver_enum('NUM')}] = {{\n" verwidth = max(len(ver) for ver in versions) for ver in sorted(versions): - ret += ifdef_push(2, c_ver_ifdef({ver})) + ret += cutil.ifdef_push(2, c_ver_ifdef({ver})) ret += ( f"\t[{c_ver_enum(ver)}]{' '*(verwidth-len(ver))} = 0b" + "".join( @@ -787,9 +726,9 @@ const char *const _{idprefix}table_ver_name[{c_ver_enum('NUM')}] = {{ ) + ",\n" ) - ret += ifdef_pop(1) + ret += cutil.ifdef_pop(1) ret += "};\n" - ret += ifdef_pop(0) + ret += cutil.ifdef_pop(0) # validate_* ############################################################### ret += """ @@ -839,7 +778,7 @@ LM_ALWAYS_INLINE static bool validate_8(struct _validate_ctx *ctx) { return _val inline = "LM_FLATTEN" if isinstance(typ, idl.Message) else "LM_ALWAYS_INLINE" argfn = unused if (isinstance(typ, idl.Struct) and not typ.members) else used ret += "\n" - ret += ifdef_push(1, c_ver_ifdef(typ.in_versions)) + ret += cutil.ifdef_push(1, c_ver_ifdef(typ.in_versions)) ret += f"{inline} static bool validate_{typ.typname}(struct _validate_ctx *{argfn('ctx')}) {{\n" match typ: @@ -865,9 +804,9 @@ LM_ALWAYS_INLINE static bool validate_8(struct _validate_ctx *ctx) { return _val # Pass 1 - declare value variables for member in typ.members: if should_save_value(typ, member): - ret += ifdef_push(2, c_ver_ifdef(member.in_versions)) + ret += cutil.ifdef_push(2, c_ver_ifdef(member.in_versions)) ret += f"\t{c_typename(member.typ)} {member.membname};\n" - ret += ifdef_pop(1) + ret += cutil.ifdef_pop(1) # Pass 2 - declare offset variables mark_offset: set[str] = set() @@ -881,7 +820,7 @@ LM_ALWAYS_INLINE static bool validate_8(struct _validate_ctx *ctx) { return _val # Pass 3 - main pass ret += "\treturn false\n" for member in typ.members: - ret += ifdef_push(2, c_ver_ifdef(member.in_versions)) + ret += cutil.ifdef_push(2, c_ver_ifdef(member.in_versions)) ret += "\t || " if member.in_versions != typ.in_versions: ret += "( " + c_ver_cond(member.in_versions) + " && " @@ -925,20 +864,20 @@ LM_ALWAYS_INLINE static bool validate_8(struct _validate_ctx *ctx) { return _val if member.max: assert member.static_size nbits = member.static_size * 8 - ret += ifdef_push(2, c_ver_ifdef(member.in_versions)) + ret += cutil.ifdef_push(2, c_ver_ifdef(member.in_versions)) ret += f"\t || ({{ uint{nbits}_t max = {c_expr(member.max, lookup_sym)}; (((uint{nbits}_t){member.membname}) > max) &&\n" ret += f'\t lib9p_errorf(ctx->ctx, LINUX_EBADMSG, "{member.membname} value is too large (%"PRIu{nbits}" > %"PRIu{nbits}")", {member.membname}, max); }})\n' if member.val: assert member.static_size nbits = member.static_size * 8 - ret += ifdef_push(2, c_ver_ifdef(member.in_versions)) + ret += cutil.ifdef_push(2, c_ver_ifdef(member.in_versions)) ret += f"\t || ({{ uint{nbits}_t exp = {c_expr(member.val, lookup_sym)}; (((uint{nbits}_t){member.membname}) != exp) &&\n" ret += f'\t lib9p_errorf(ctx->ctx, LINUX_EBADMSG, "{member.membname} value is wrong (actual:%"PRIu{nbits}" != correct:%"PRIu{nbits}")", (uint{nbits}_t){member.membname}, exp); }})\n' - ret += ifdef_pop(1) + ret += cutil.ifdef_pop(1) ret += "\t ;\n" ret += "}\n" - ret += ifdef_pop(0) + ret += cutil.ifdef_pop(0) # unmarshal_* ############################################################## ret += """ @@ -968,7 +907,7 @@ LM_ALWAYS_INLINE static void unmarshal_8(struct _unmarshal_ctx *ctx, uint64_t *o inline = "LM_FLATTEN" if isinstance(typ, idl.Message) else "LM_ALWAYS_INLINE" argfn = unused if (isinstance(typ, idl.Struct) and not typ.members) else used ret += "\n" - ret += ifdef_push(1, c_ver_ifdef(typ.in_versions)) + ret += cutil.ifdef_push(1, c_ver_ifdef(typ.in_versions)) ret += f"{inline} static void unmarshal_{typ.typname}(struct _unmarshal_ctx *{argfn('ctx')}, {c_typename(typ)} *out) {{\n" match typ: case idl.Number(): @@ -979,7 +918,7 @@ LM_ALWAYS_INLINE static void unmarshal_8(struct _unmarshal_ctx *ctx, uint64_t *o ret += "\tmemset(out, 0, sizeof(*out));\n" for member in typ.members: - ret += ifdef_push(2, c_ver_ifdef(member.in_versions)) + ret += cutil.ifdef_push(2, c_ver_ifdef(member.in_versions)) if member.val: ret += f"\tctx->net_offset += {member.static_size};\n" continue @@ -1005,16 +944,16 @@ LM_ALWAYS_INLINE static void unmarshal_8(struct _unmarshal_ctx *ctx, uint64_t *o ret += "\t}\n" else: ret += f"unmarshal_{member.typ.typname}(ctx, &out->{member.membname});\n" - ret += ifdef_pop(1) + ret += cutil.ifdef_pop(1) ret += "}\n" - ret += ifdef_pop(0) + ret += cutil.ifdef_pop(0) # marshal_* ################################################################ ret += """ /* marshal_* ******************************************************************/ """ - ret += c_macro( + ret += cutil.macro( "#define MARSHAL_BYTES_ZEROCOPY(ctx, data, len)\n" "\tif (ctx->net_iov[ctx->net_iov_cnt-1].iov_len)\n" "\t\tctx->net_iov_cnt++;\n" @@ -1022,7 +961,7 @@ LM_ALWAYS_INLINE static void unmarshal_8(struct _unmarshal_ctx *ctx, uint64_t *o "\tctx->net_iov[ctx->net_iov_cnt-1].iov_len = len;\n" "\tctx->net_iov_cnt++;\n" ) - ret += c_macro( + ret += cutil.macro( "#define MARSHAL_BYTES(ctx, data, len)\n" "\tif (!ctx->net_iov[ctx->net_iov_cnt-1].iov_base)\n" "\t\tctx->net_iov[ctx->net_iov_cnt-1].iov_base = &ctx->net_copied[ctx->net_copied_size];\n" @@ -1030,7 +969,7 @@ LM_ALWAYS_INLINE static void unmarshal_8(struct _unmarshal_ctx *ctx, uint64_t *o "\tctx->net_copied_size += len;\n" "\tctx->net_iov[ctx->net_iov_cnt-1].iov_len += len;\n" ) - ret += c_macro( + ret += cutil.macro( "#define MARSHAL_U8LE(ctx, val)\n" "\tif (!ctx->net_iov[ctx->net_iov_cnt-1].iov_base)\n" "\t\tctx->net_iov[ctx->net_iov_cnt-1].iov_base = &ctx->net_copied[ctx->net_copied_size];\n" @@ -1038,7 +977,7 @@ LM_ALWAYS_INLINE static void unmarshal_8(struct _unmarshal_ctx *ctx, uint64_t *o "\tctx->net_copied_size += 1;\n" "\tctx->net_iov[ctx->net_iov_cnt-1].iov_len += 1;\n" ) - ret += c_macro( + ret += cutil.macro( "#define MARSHAL_U16LE(ctx, val)\n" "\tif (!ctx->net_iov[ctx->net_iov_cnt-1].iov_base)\n" "\t\tctx->net_iov[ctx->net_iov_cnt-1].iov_base = &ctx->net_copied[ctx->net_copied_size];\n" @@ -1046,7 +985,7 @@ LM_ALWAYS_INLINE static void unmarshal_8(struct _unmarshal_ctx *ctx, uint64_t *o "\tctx->net_copied_size += 2;\n" "\tctx->net_iov[ctx->net_iov_cnt-1].iov_len += 2;\n" ) - ret += c_macro( + ret += cutil.macro( "#define MARSHAL_U32LE(ctx, val)\n" "\tif (!ctx->net_iov[ctx->net_iov_cnt-1].iov_base)\n" "\t\tctx->net_iov[ctx->net_iov_cnt-1].iov_base = &ctx->net_copied[ctx->net_copied_size];\n" @@ -1054,7 +993,7 @@ LM_ALWAYS_INLINE static void unmarshal_8(struct _unmarshal_ctx *ctx, uint64_t *o "\tctx->net_copied_size += 4;\n" "\tctx->net_iov[ctx->net_iov_cnt-1].iov_len += 4;\n" ) - ret += c_macro( + ret += cutil.macro( "#define MARSHAL_U64LE(ctx, val)\n" "\tif (!ctx->net_iov[ctx->net_iov_cnt-1].iov_base)\n" "\t\tctx->net_iov[ctx->net_iov_cnt-1].iov_base = &ctx->net_copied[ctx->net_copied_size];\n" @@ -1108,11 +1047,11 @@ LM_ALWAYS_INLINE static void unmarshal_8(struct _unmarshal_ctx *ctx, uint64_t *o ) multiline += f"{'\t'*indent_depth}}}\n" for vers, sub in self.cond.items(): - multiline += ifdef_push(indent_depth + 1, c_ver_ifdef(vers)) + multiline += cutil.ifdef_push(indent_depth + 1, c_ver_ifdef(vers)) multiline += f"{'\t'*indent_depth}if {c_ver_cond(vers)} {{\n" multiline += sub.gen_c("", dstvar, root, indent_depth + 1, loop_depth) multiline += f"{'\t'*indent_depth}}}\n" - multiline += ifdef_pop(indent_depth) + multiline += cutil.ifdef_pop(indent_depth) if dsttyp: if not oneline: oneline.append("0") @@ -1204,13 +1143,13 @@ LM_ALWAYS_INLINE static void unmarshal_8(struct _unmarshal_ctx *ctx, uint64_t *o continue assert isinstance(typ, idl.Struct) ret += "\n" - ret += ifdef_push(1, c_ver_ifdef(typ.in_versions)) + ret += cutil.ifdef_push(1, c_ver_ifdef(typ.in_versions)) ret += f"static bool marshal_{typ.typname}(struct _marshal_ctx *ctx, {c_typename(typ)} *val) {{\n" # Pass 1 - check size max_size = max(typ.max_size(v) for v in typ.in_versions) - if max_size > u32max: # SPECIAL (9P2000.e) + if max_size > cutil.UINT32_MAX: # SPECIAL (9P2000.e) ret += get_offset_expr(typ, go_to_end).gen_c( "uint64_t", "needed_size", "val->", 1, 0 ) @@ -1247,7 +1186,7 @@ LM_ALWAYS_INLINE static void unmarshal_8(struct _unmarshal_ctx *ctx, uint64_t *o ret += f"{'\t'*(len(stack)-1)}}}\n" if stack[-1][1]: ifdef_depth -= 1 - ret += ifdef_pop(ifdef_depth) + ret += cutil.ifdef_pop(ifdef_depth) stack = stack[:-1] loopdepth = sum(1 for elem in path.elems if elem.cnt) @@ -1270,7 +1209,7 @@ LM_ALWAYS_INLINE static void unmarshal_8(struct _unmarshal_ctx *ctx, uint64_t *o if name == "end": if not path.elems: nonlocal max_size - if max_size > u32max: + if max_size > cutil.UINT32_MAX: ret += f"{'\t'*len(stack)}uint32_t {name_prefix}end = (uint32_t)needed_size;\n" else: ret += f"{'\t'*len(stack)}uint32_t {name_prefix}end = needed_size;\n" @@ -1295,7 +1234,9 @@ LM_ALWAYS_INLINE static void unmarshal_8(struct _unmarshal_ctx *ctx, uint64_t *o child = path.elems[-1] parent = path.elems[-2].typ if len(path.elems) > 1 else path.root if child.in_versions < parent.in_versions: - ret += ifdef_push(ifdef_depth + 1, c_ver_ifdef(child.in_versions)) + ret += cutil.ifdef_push( + ifdef_depth + 1, c_ver_ifdef(child.in_versions) + ) ifdef_depth += 1 ret += f"{'\t'*len(stack)}if ({c_ver_cond(child.in_versions)}) {{\n" stack.append((path, True)) @@ -1338,7 +1279,7 @@ LM_ALWAYS_INLINE static void unmarshal_8(struct _unmarshal_ctx *ctx, uint64_t *o ret += "\treturn false;\n" ret += "}\n" - ret += ifdef_pop(0) + ret += cutil.ifdef_pop(0) # function tables ########################################################## ret += """ @@ -1350,20 +1291,20 @@ LM_ALWAYS_INLINE static void unmarshal_8(struct _unmarshal_ctx *ctx, uint64_t *o rerror = next(typ for typ in typs if typ.typname == "Rerror") ret += f"\t[{c_ver_enum('unknown')}] = {rerror.min_size('9P2000')},\n" # SPECIAL (initialization) for ver in sorted(versions): - ret += ifdef_push(1, c_ver_ifdef({ver})) + ret += cutil.ifdef_push(1, c_ver_ifdef({ver})) ret += f"\t[{c_ver_enum(ver)}] = {rerror.min_size(ver)},\n" - ret += ifdef_pop(0) + ret += cutil.ifdef_pop(0) ret += "};\n" ret += "\n" - ret += c_macro( + ret += cutil.macro( f"#define _MSG_RECV(typ) [{idprefix.upper()}TYP_##typ/2] = {{\n" f"\t\t.basesize = sizeof(struct {idprefix}msg_##typ),\n" f"\t\t.validate = validate_##typ,\n" f"\t\t.unmarshal = (_unmarshal_fn_t)unmarshal_##typ,\n" f"\t}}\n" ) - ret += c_macro( + ret += cutil.macro( f"#define _MSG_SEND(typ) [{idprefix.upper()}TYP_##typ/2] = {{\n" f"\t\t.marshal = (_marshal_fn_t)marshal_##typ,\n" f"\t}}\n" -- cgit v1.2.3-2-g168b From 77249bb45c44ec88c96cd00da0805e1a58a1bfd6 Mon Sep 17 00:00:00 2001 From: "Luke T. Shumaker" Date: Sun, 23 Mar 2025 01:22:27 -0600 Subject: lib9p: protogen: pull c9util.py out of __init__.py --- lib9p/protogen/__init__.py | 264 ++++++++++++++++++--------------------------- 1 file changed, 105 insertions(+), 159 deletions(-) (limited to 'lib9p/protogen/__init__.py') diff --git a/lib9p/protogen/__init__.py b/lib9p/protogen/__init__.py index 8a9a371..73542a2 100644 --- a/lib9p/protogen/__init__.py +++ b/lib9p/protogen/__init__.py @@ -12,7 +12,7 @@ import typing import idl -from . import cutil +from . import c9util, cutil # This strives to be "general-purpose" in that it just acts on the # *.9p inputs; but (unfortunately?) there are a few special-cases in @@ -22,70 +22,6 @@ from . import cutil # pylint: disable=unused-variable __all__ = ["main"] -# Utilities #################################################################### - -idprefix = "lib9p_" - - -def add_prefix(p: str, s: str) -> str: - if s.startswith("_"): - return "_" + p + s[1:] - return p + s - - -def c_ver_enum(ver: str) -> str: - return f"{idprefix.upper()}VER_{ver.replace('.', '_')}" - - -def c_ver_ifdef(versions: typing.Collection[str]) -> str: - return " || ".join( - f"CONFIG_9P_ENABLE_{v.replace('.', '_')}" for v in sorted(versions) - ) - - -def c_ver_cond(versions: typing.Collection[str]) -> str: - if len(versions) == 1: - v = next(v for v in versions) - return f"is_ver(ctx, {v.replace('.', '_')})" - return "( " + (" || ".join(c_ver_cond({v}) for v in sorted(versions))) + " )" - - -def c_typename(typ: idl.Type, parent: idl.StructMember | None = None) -> str: - match typ: - case idl.Primitive(): - if typ.value == 1 and parent and parent.cnt: # SPECIAL (string) - return "[[gnu::nonstring]] char" - return f"uint{typ.value*8}_t" - case idl.Number(): - return f"{idprefix}{typ.typname}_t" - case idl.Bitfield(): - return f"{idprefix}{typ.typname}_t" - case idl.Message(): - return f"struct {idprefix}msg_{typ.typname}" - case idl.Struct(): - return f"struct {idprefix}{typ.typname}" - case _: - raise ValueError(f"not a type: {typ.__class__.__name__}") - - -def c_expr(expr: idl.Expr, lookup_sym: typing.Callable[[str], str]) -> str: - ret: list[str] = [] - for tok in expr.tokens: - match tok: - case idl.ExprOp(): - ret.append(tok.op) - case idl.ExprLit(): - ret.append(str(tok.val)) - case idl.ExprSym(symname="s32_max"): - ret.append("INT32_MAX") - case idl.ExprSym(symname="s64_max"): - ret.append("INT64_MAX") - case idl.ExprSym(): - ret.append(lookup_sym(tok.symname)) - case _: - assert False - return " ".join(ret) - # topo_sorted() ################################################################ @@ -343,13 +279,13 @@ def gen_h(versions: set[str], typs: list[idl.UserType]) -> str: """ for ver in sorted(versions): ret += "\n" - ret += f"#ifndef {c_ver_ifdef({ver})}\n" - ret += f"\t#error config.h must define {c_ver_ifdef({ver})}\n" + ret += f"#ifndef {c9util.ver_ifdef({ver})}\n" + ret += f"\t#error config.h must define {c9util.ver_ifdef({ver})}\n" if ver == "9P2000.e": # SPECIAL (9P2000.e) ret += "#else\n" - ret += f"\t#if {c_ver_ifdef({ver})}\n" + ret += f"\t#if {c9util.ver_ifdef({ver})}\n" ret += "\t\t#ifndef CONFIG_9P_MAX_9P2000_e_WELEM\n" - ret += f"\t\t\t#error if {c_ver_ifdef({ver})} then config.h must define CONFIG_9P_MAX_9P2000_e_WELEM\n" + ret += f"\t\t\t#error if {c9util.ver_ifdef({ver})} then config.h must define CONFIG_9P_MAX_9P2000_e_WELEM\n" ret += "\t\t#endif\n" ret += "\t\tstatic_assert(CONFIG_9P_MAX_9P2000_e_WELEM > 0);\n" ret += "\t#endif\n" @@ -358,31 +294,31 @@ def gen_h(versions: set[str], typs: list[idl.UserType]) -> str: ret += f""" /* enum version ***************************************************************/ -enum {idprefix}version {{ +enum {c9util.ident('version')} {{ """ fullversions = ["unknown = 0", *sorted(versions)] verwidth = max(len(v) for v in fullversions) for ver in fullversions: if ver in versions: - ret += cutil.ifdef_push(1, c_ver_ifdef({ver})) - ret += f"\t{c_ver_enum(ver)}," + ret += cutil.ifdef_push(1, c9util.ver_ifdef({ver})) + ret += f"\t{c9util.ver_enum(ver)}," ret += (" " * (verwidth - len(ver))) + ' /* "' + ver.split()[0] + '" */\n' ret += cutil.ifdef_pop(0) - ret += f"\t{c_ver_enum('NUM')},\n" + ret += f"\t{c9util.ver_enum('NUM')},\n" ret += "};\n" ret += """ /* enum msg_type **************************************************************/ """ - ret += f"enum {idprefix}msg_type {{ /* uint8_t */\n" + ret += f"enum {c9util.ident('msg_type')} {{ /* uint8_t */\n" namewidth = max(len(msg.typname) for msg in typs if isinstance(msg, idl.Message)) for n in range(0x100): if n not in id2typ: continue msg = id2typ[n] - ret += cutil.ifdef_push(1, c_ver_ifdef(msg.in_versions)) - ret += f"\t{idprefix.upper()}TYP_{msg.typname:<{namewidth}} = {msg.msgid},\n" + ret += cutil.ifdef_push(1, c9util.ver_ifdef(msg.in_versions)) + ret += f"\t{c9util.Ident(f'TYP_{msg.typname:<{namewidth}}')} = {msg.msgid},\n" ret += cutil.ifdef_pop(0) ret += "};\n" @@ -402,14 +338,14 @@ enum {idprefix}version {{ assert False else: ret = "" - v_width = max(len(c_ver_enum(v)) for v in typ.in_versions) + v_width = max(len(c9util.ver_enum(v)) for v in typ.in_versions) for version, line in lines.items(): - ret += f"/* {c_ver_enum(version):<{v_width}}: {line} */\n" + ret += f"/* {c9util.ver_enum(version):<{v_width}}: {line} */\n" return ret for typ in topo_sorted(typs): ret += "\n" - ret += cutil.ifdef_push(1, c_ver_ifdef(typ.in_versions)) + ret += cutil.ifdef_push(1, c9util.ver_ifdef(typ.in_versions)) def sum_size(typ: idl.UserType, version: str) -> str: sz = get_buffer_size(typ, version) @@ -432,13 +368,13 @@ enum {idprefix}version {{ match typ: case idl.Number(): - ret += f"typedef {c_typename(typ.prim)} {c_typename(typ)};\n" - prefix = f"{idprefix.upper()}{typ.typname.upper()}_" + ret += f"typedef {c9util.typename(typ.prim)} {c9util.typename(typ)};\n" + prefix = f"{c9util.IDENT(typ.typname)}_" namewidth = max(len(name) for name in typ.vals) for name, val in typ.vals.items(): - ret += f"#define {prefix}{name:<{namewidth}} (({c_typename(typ)})UINT{typ.static_size*8}_C({val}))\n" + ret += f"#define {prefix}{name:<{namewidth}} (({c9util.typename(typ)})UINT{typ.static_size*8}_C({val}))\n" case idl.Bitfield(): - ret += f"typedef {c_typename(typ.prim)} {c_typename(typ)};\n" + ret += f"typedef {c9util.typename(typ.prim)} {c9util.typename(typ)};\n" def bitname(val: idl.Bit | idl.BitAlias) -> str: s = val.bitname @@ -456,7 +392,7 @@ enum {idprefix}version {{ s = f"_{s}_{n}" case idl.Bit(cat=idl.BitCat.UNUSED): return "" - return add_prefix(f"{idprefix.upper()}{typ.typname.upper()}_", s) + return c9util.Ident(c9util.add_prefix(typ.typname.upper() + "_", s)) namewidth = max( len(bitname(val)) for val in [*typ.bits, *typ.names.values()] @@ -467,7 +403,7 @@ enum {idprefix}version {{ vers = bit.in_versions if bit.cat == idl.BitCat.UNUSED: vers = typ.in_versions - ret += cutil.ifdef_push(2, c_ver_ifdef(vers)) + ret += cutil.ifdef_push(2, c9util.ver_ifdef(vers)) # It is important all of the `beg` strings have # the same length. @@ -486,7 +422,7 @@ enum {idprefix}version {{ c_name = bitname(bit) c_val = f"1<<{bit.num}" - ret += f"{beg} {c_name:<{namewidth}} (({c_typename(typ)})({c_val})){end}\n" + ret += f"{beg} {c_name:<{namewidth}} (({c9util.typename(typ)})({c_val})){end}\n" if aliases := [ alias for alias in typ.names.values() @@ -495,7 +431,7 @@ enum {idprefix}version {{ ret += "\n" for alias in aliases: - ret += cutil.ifdef_push(2, c_ver_ifdef(alias.in_versions)) + ret += cutil.ifdef_push(2, c9util.ver_ifdef(alias.in_versions)) end = "" if cutil.ifdef_leaf_is_noop(): @@ -505,23 +441,23 @@ enum {idprefix}version {{ c_name = bitname(alias) c_val = alias.val - ret += f"{beg} {c_name:<{namewidth}} (({c_typename(typ)})({c_val})){end}\n" + ret += f"{beg} {c_name:<{namewidth}} (({c9util.typename(typ)})({c_val})){end}\n" ret += cutil.ifdef_pop(1) del bitname case idl.Struct(): # and idl.Message(): - ret += c_typename(typ) + " {" + ret += c9util.typename(typ) + " {" if not typ.members: ret += "};\n" continue ret += "\n" - typewidth = max(len(c_typename(m.typ, m)) for m in typ.members) + typewidth = max(len(c9util.typename(m.typ, m)) for m in typ.members) for member in typ.members: if member.val: continue - ret += cutil.ifdef_push(2, c_ver_ifdef(member.in_versions)) - ret += f"\t{c_typename(member.typ, member):<{typewidth}} {'*' if member.cnt else ' '}{member.membname};\n" + ret += cutil.ifdef_push(2, c9util.ver_ifdef(member.in_versions)) + ret += f"\t{c9util.typename(member.typ, member):<{typewidth}} {'*' if member.cnt else ' '}{member.membname};\n" ret += cutil.ifdef_pop(1) ret += "};\n" del typ @@ -531,7 +467,7 @@ enum {idprefix}version {{ /* containers *****************************************************************/ """ ret += "\n" - ret += f"#define _{idprefix.upper()}MAX(a, b) ((a) > (b)) ? (a) : (b)\n" + ret += f"#define {c9util.IDENT('_MAX')}(a, b) ((a) > (b)) ? (a) : (b)\n" tmsg_max_iov: dict[str, int] = {} tmsg_max_copy: dict[str, int] = {} @@ -570,7 +506,7 @@ enum {idprefix}version {{ directive = "if" seen_e = False # SPECIAL (9P2000.e) for maxval in sorted(inv, reverse=True): - ret += f"#{directive} {c_ver_ifdef(inv[maxval])}\n" + ret += f"#{directive} {c9util.ver_ifdef(inv[maxval])}\n" indent = 1 if name.startswith("tmsg") and not seen_e: # SPECIAL (9P2000.e) typ = next(typ for typ in typs if typ.typname == "Tswrite") @@ -582,11 +518,11 @@ enum {idprefix}version {{ maxexpr = f"{sz.max_copy}{sz.max_copy_extra}" case _: assert False - ret += f"\t#if {c_ver_ifdef({"9P2000.e"})}\n" - ret += f"\t\t#define {idprefix.upper()}{name.upper()} _{idprefix.upper()}MAX({maxval}, {maxexpr})\n" + ret += f"\t#if {c9util.ver_ifdef({"9P2000.e"})}\n" + ret += f"\t\t#define {c9util.IDENT(name)} {c9util.IDENT('_MAX')}({maxval}, {maxexpr})\n" ret += "\t#else\n" indent += 1 - ret += f"{'\t'*indent}#define {idprefix.upper()}{name.upper()} {maxval}\n" + ret += f"{'\t'*indent}#define {c9util.IDENT(name)} {maxval}\n" if name.startswith("tmsg") and not seen_e: # SPECIAL (9P2000.e) ret += "\t#endif\n" if "9P2000.e" in inv[maxval]: @@ -595,17 +531,17 @@ enum {idprefix}version {{ ret += "#endif\n" ret += "\n" - ret += f"struct {idprefix}Tmsg_send_buf {{\n" + ret += f"struct {c9util.ident('Tmsg_send_buf')} {{\n" ret += "\tsize_t iov_cnt;\n" - ret += f"\tstruct iovec iov[{idprefix.upper()}TMSG_MAX_IOV];\n" - ret += f"\tuint8_t copied[{idprefix.upper()}TMSG_MAX_COPY];\n" + ret += f"\tstruct iovec iov[{c9util.IDENT('TMSG_MAX_IOV')}];\n" + ret += f"\tuint8_t copied[{c9util.IDENT('TMSG_MAX_COPY')}];\n" ret += "};\n" ret += "\n" - ret += f"struct {idprefix}Rmsg_send_buf {{\n" + ret += f"struct {c9util.ident('Rmsg_send_buf')} {{\n" ret += "\tsize_t iov_cnt;\n" - ret += f"\tstruct iovec iov[{idprefix.upper()}RMSG_MAX_IOV];\n" - ret += f"\tuint8_t copied[{idprefix.upper()}RMSG_MAX_COPY];\n" + ret += f"\tstruct iovec iov[{c9util.IDENT('RMSG_MAX_IOV')}];\n" + ret += f"\tuint8_t copied[{c9util.IDENT('RMSG_MAX_COPY')}];\n" ret += "};\n" return ret @@ -647,11 +583,11 @@ def gen_c(versions: set[str], typs: list[idl.UserType]) -> str: id2typ[msg.msgid] = msg def msg_table(grp: str, meth: str, tentry: str, rng: tuple[int, int, int]) -> str: - ret = f"const {tentry} _{idprefix}table_{grp}_{meth}[{c_ver_enum('NUM')}][{hex(len(range(*rng)))}] = {{\n" + ret = f"const {tentry} {c9util.ident(f'_table_{grp}_{meth}')}[{c9util.ver_enum('NUM')}][{hex(len(range(*rng)))}] = {{\n" for ver in ["unknown", *sorted(versions)]: if ver != "unknown": - ret += cutil.ifdef_push(1, c_ver_ifdef({ver})) - ret += f"\t[{c_ver_enum(ver)}] = {{\n" + ret += cutil.ifdef_push(1, c9util.ver_ifdef({ver})) + ret += f"\t[{c9util.ver_enum(ver)}] = {{\n" for n in range(*rng): xmsg: idl.Message | None = id2typ.get(n, None) if xmsg: @@ -670,14 +606,16 @@ def gen_c(versions: set[str], typs: list[idl.UserType]) -> str: for v in sorted(versions): ret += f"#if CONFIG_9P_ENABLE_{v.replace('.', '_')}\n" - ret += f"\t#define _is_ver_{v.replace('.', '_')}(v) (v == {c_ver_enum(v)})\n" + ret += ( + f"\t#define _is_ver_{v.replace('.', '_')}(v) (v == {c9util.ver_enum(v)})\n" + ) ret += "#else\n" ret += f"\t#define _is_ver_{v.replace('.', '_')}(v) false\n" ret += "#endif\n" ret += "\n" ret += "/**\n" - ret += f" * is_ver(ctx, ver) is essentially `(ctx->ctx->version == {idprefix.upper()}VER_##ver)`,\n" - ret += f" * but compiles correctly (to `false`) even if `{idprefix.upper()}VER_##ver` isn't defined\n" + ret += f" * is_ver(ctx, ver) is essentially `(ctx->ctx->version == {c9util.Ident('VER_')}##ver)`,\n" + ret += f" * but compiles correctly (to `false`) even if `{c9util.Ident('VER_')}##ver` isn't defined\n" ret += " * (because `!CONFIG_9P_ENABLE_##ver`). This is useful when `||`ing\n" ret += " * several version checks together.\n" ret += " */\n" @@ -687,17 +625,17 @@ def gen_c(versions: set[str], typs: list[idl.UserType]) -> str: ret += f""" /* strings ********************************************************************/ -const char *const _{idprefix}table_ver_name[{c_ver_enum('NUM')}] = {{ +const char *const {c9util.ident('_table_ver_name')}[{c9util.ver_enum('NUM')}] = {{ """ for ver in ["unknown", *sorted(versions)]: if ver in versions: - ret += cutil.ifdef_push(1, c_ver_ifdef({ver})) - ret += f'\t[{c_ver_enum(ver)}] = "{ver}",\n' + ret += cutil.ifdef_push(1, c9util.ver_ifdef({ver})) + ret += f'\t[{c9util.ver_enum(ver)}] = "{ver}",\n' ret += cutil.ifdef_pop(0) ret += "};\n" ret += "\n" - ret += f"#define _MSG_NAME(typ) [{idprefix.upper()}TYP_##typ] = #typ\n" + ret += f"#define _MSG_NAME(typ) [{c9util.Ident('TYP_')}##typ] = #typ\n" ret += msg_table("msg", "name", "char *const", (0, 0x100, 1)) # bitmasks ################################################################# @@ -708,13 +646,13 @@ const char *const _{idprefix}table_ver_name[{c_ver_enum('NUM')}] = {{ if not isinstance(typ, idl.Bitfield): continue ret += "\n" - ret += cutil.ifdef_push(1, c_ver_ifdef(typ.in_versions)) - ret += f"static const {c_typename(typ)} {typ.typname}_masks[{c_ver_enum('NUM')}] = {{\n" + ret += cutil.ifdef_push(1, c9util.ver_ifdef(typ.in_versions)) + ret += f"static const {c9util.typename(typ)} {typ.typname}_masks[{c9util.ver_enum('NUM')}] = {{\n" verwidth = max(len(ver) for ver in versions) for ver in sorted(versions): - ret += cutil.ifdef_push(2, c_ver_ifdef({ver})) + ret += cutil.ifdef_push(2, c9util.ver_ifdef({ver})) ret += ( - f"\t[{c_ver_enum(ver)}]{' '*(verwidth-len(ver))} = 0b" + f"\t[{c9util.ver_enum(ver)}]{' '*(verwidth-len(ver))} = 0b" + "".join( ( "1" @@ -778,7 +716,7 @@ LM_ALWAYS_INLINE static bool validate_8(struct _validate_ctx *ctx) { return _val inline = "LM_FLATTEN" if isinstance(typ, idl.Message) else "LM_ALWAYS_INLINE" argfn = unused if (isinstance(typ, idl.Struct) and not typ.members) else used ret += "\n" - ret += cutil.ifdef_push(1, c_ver_ifdef(typ.in_versions)) + ret += cutil.ifdef_push(1, c9util.ver_ifdef(typ.in_versions)) ret += f"{inline} static bool validate_{typ.typname}(struct _validate_ctx *{argfn('ctx')}) {{\n" match typ: @@ -787,11 +725,11 @@ LM_ALWAYS_INLINE static bool validate_8(struct _validate_ctx *ctx) { return _val case idl.Bitfield(): ret += f"\t if (validate_{typ.static_size}(ctx))\n" ret += "\t\treturn true;\n" - ret += f"\t{c_typename(typ)} mask = {typ.typname}_masks[ctx->ctx->version];\n" + ret += f"\t{c9util.typename(typ)} mask = {typ.typname}_masks[ctx->ctx->version];\n" if typ.static_size == 1: - ret += f"\t{c_typename(typ)} val = ctx->net_bytes[ctx->net_offset-1];\n" + ret += f"\t{c9util.typename(typ)} val = ctx->net_bytes[ctx->net_offset-1];\n" else: - ret += f"\t{c_typename(typ)} val = uint{typ.static_size*8}le_decode(&ctx->net_bytes[ctx->net_offset-{typ.static_size}]);\n" + ret += f"\t{c9util.typename(typ)} val = uint{typ.static_size*8}le_decode(&ctx->net_bytes[ctx->net_offset-{typ.static_size}]);\n" ret += "\tif (val & ~mask)\n" ret += f'\t\treturn lib9p_errorf(ctx->ctx, LINUX_EBADMSG, "unknown bits in {typ.typname} bitfield: %#0{typ.static_size}"PRIx{typ.static_size*8}, val & ~mask);\n' ret += "\treturn false;\n" @@ -804,8 +742,8 @@ LM_ALWAYS_INLINE static bool validate_8(struct _validate_ctx *ctx) { return _val # Pass 1 - declare value variables for member in typ.members: if should_save_value(typ, member): - ret += cutil.ifdef_push(2, c_ver_ifdef(member.in_versions)) - ret += f"\t{c_typename(member.typ)} {member.membname};\n" + ret += cutil.ifdef_push(2, c9util.ver_ifdef(member.in_versions)) + ret += f"\t{c9util.typename(member.typ)} {member.membname};\n" ret += cutil.ifdef_pop(1) # Pass 2 - declare offset variables @@ -820,15 +758,15 @@ LM_ALWAYS_INLINE static bool validate_8(struct _validate_ctx *ctx) { return _val # Pass 3 - main pass ret += "\treturn false\n" for member in typ.members: - ret += cutil.ifdef_push(2, c_ver_ifdef(member.in_versions)) + ret += cutil.ifdef_push(2, c9util.ver_ifdef(member.in_versions)) ret += "\t || " if member.in_versions != typ.in_versions: - ret += "( " + c_ver_cond(member.in_versions) + " && " + ret += "( " + c9util.ver_cond(member.in_versions) + " && " if member.cnt is not None: if member.typ.static_size == 1: # SPECIAL (zerocopy) ret += f"_validate_size_net(ctx, {member.cnt.membname})" else: - ret += f"_validate_list(ctx, {member.cnt.membname}, validate_{member.typ.typname}, sizeof({c_typename(member.typ)}))" + ret += f"_validate_list(ctx, {member.cnt.membname}, validate_{member.typ.typname}, sizeof({c9util.typename(member.typ)}))" if typ.typname == "s": # SPECIAL (string) ret += '\n\t || ({ (!is_valid_utf8_without_nul(&ctx->net_bytes[ctx->net_offset-len], len)) && lib9p_error(ctx->ctx, LINUX_EBADMSG, "message contains invalid UTF-8"); })' else: @@ -864,14 +802,14 @@ LM_ALWAYS_INLINE static bool validate_8(struct _validate_ctx *ctx) { return _val if member.max: assert member.static_size nbits = member.static_size * 8 - ret += cutil.ifdef_push(2, c_ver_ifdef(member.in_versions)) - ret += f"\t || ({{ uint{nbits}_t max = {c_expr(member.max, lookup_sym)}; (((uint{nbits}_t){member.membname}) > max) &&\n" + ret += cutil.ifdef_push(2, c9util.ver_ifdef(member.in_versions)) + ret += f"\t || ({{ uint{nbits}_t max = {c9util.idl_expr(member.max, lookup_sym)}; (((uint{nbits}_t){member.membname}) > max) &&\n" ret += f'\t lib9p_errorf(ctx->ctx, LINUX_EBADMSG, "{member.membname} value is too large (%"PRIu{nbits}" > %"PRIu{nbits}")", {member.membname}, max); }})\n' if member.val: assert member.static_size nbits = member.static_size * 8 - ret += cutil.ifdef_push(2, c_ver_ifdef(member.in_versions)) - ret += f"\t || ({{ uint{nbits}_t exp = {c_expr(member.val, lookup_sym)}; (((uint{nbits}_t){member.membname}) != exp) &&\n" + ret += cutil.ifdef_push(2, c9util.ver_ifdef(member.in_versions)) + ret += f"\t || ({{ uint{nbits}_t exp = {c9util.idl_expr(member.val, lookup_sym)}; (((uint{nbits}_t){member.membname}) != exp) &&\n" ret += f'\t lib9p_errorf(ctx->ctx, LINUX_EBADMSG, "{member.membname} value is wrong (actual:%"PRIu{nbits}" != correct:%"PRIu{nbits}")", (uint{nbits}_t){member.membname}, exp); }})\n' ret += cutil.ifdef_pop(1) @@ -907,18 +845,18 @@ LM_ALWAYS_INLINE static void unmarshal_8(struct _unmarshal_ctx *ctx, uint64_t *o inline = "LM_FLATTEN" if isinstance(typ, idl.Message) else "LM_ALWAYS_INLINE" argfn = unused if (isinstance(typ, idl.Struct) and not typ.members) else used ret += "\n" - ret += cutil.ifdef_push(1, c_ver_ifdef(typ.in_versions)) - ret += f"{inline} static void unmarshal_{typ.typname}(struct _unmarshal_ctx *{argfn('ctx')}, {c_typename(typ)} *out) {{\n" + ret += cutil.ifdef_push(1, c9util.ver_ifdef(typ.in_versions)) + ret += f"{inline} static void unmarshal_{typ.typname}(struct _unmarshal_ctx *{argfn('ctx')}, {c9util.typename(typ)} *out) {{\n" match typ: case idl.Number(): - ret += f"\tunmarshal_{typ.prim.typname}(ctx, ({c_typename(typ.prim)} *)out);\n" + ret += f"\tunmarshal_{typ.prim.typname}(ctx, ({c9util.typename(typ.prim)} *)out);\n" case idl.Bitfield(): - ret += f"\tunmarshal_{typ.prim.typname}(ctx, ({c_typename(typ.prim)} *)out);\n" + ret += f"\tunmarshal_{typ.prim.typname}(ctx, ({c9util.typename(typ.prim)} *)out);\n" case idl.Struct(): ret += "\tmemset(out, 0, sizeof(*out));\n" for member in typ.members: - ret += cutil.ifdef_push(2, c_ver_ifdef(member.in_versions)) + ret += cutil.ifdef_push(2, c9util.ver_ifdef(member.in_versions)) if member.val: ret += f"\tctx->net_offset += {member.static_size};\n" continue @@ -926,7 +864,7 @@ LM_ALWAYS_INLINE static void unmarshal_8(struct _unmarshal_ctx *ctx, uint64_t *o prefix = "\t" if member.in_versions != typ.in_versions: - ret += "if ( " + c_ver_cond(member.in_versions) + " ) " + ret += "if ( " + c9util.ver_cond(member.in_versions) + " ) " prefix = "\t\t" if member.cnt: if member.in_versions != typ.in_versions: @@ -1041,14 +979,14 @@ LM_ALWAYS_INLINE static void unmarshal_8(struct _unmarshal_ctx *ctx, uint64_t *o oneline.append(f"({cnt.c_str(root)})*{sub.static}") continue loopvar = chr(ord("i") + loop_depth) - multiline += f"{'\t'*indent_depth}for ({c_typename(cnt.elems[-1].typ)} {loopvar} = 0; {loopvar} < {cnt.c_str(root)}; {loopvar}++) {{\n" + multiline += f"{'\t'*indent_depth}for ({c9util.typename(cnt.elems[-1].typ)} {loopvar} = 0; {loopvar} < {cnt.c_str(root)}; {loopvar}++) {{\n" multiline += sub.gen_c( "", dstvar, root, indent_depth + 1, loop_depth + 1 ) multiline += f"{'\t'*indent_depth}}}\n" for vers, sub in self.cond.items(): - multiline += cutil.ifdef_push(indent_depth + 1, c_ver_ifdef(vers)) - multiline += f"{'\t'*indent_depth}if {c_ver_cond(vers)} {{\n" + multiline += cutil.ifdef_push(indent_depth + 1, c9util.ver_ifdef(vers)) + multiline += f"{'\t'*indent_depth}if {c9util.ver_cond(vers)} {{\n" multiline += sub.gen_c("", dstvar, root, indent_depth + 1, loop_depth) multiline += f"{'\t'*indent_depth}}}\n" multiline += cutil.ifdef_pop(indent_depth) @@ -1143,8 +1081,8 @@ LM_ALWAYS_INLINE static void unmarshal_8(struct _unmarshal_ctx *ctx, uint64_t *o continue assert isinstance(typ, idl.Struct) ret += "\n" - ret += cutil.ifdef_push(1, c_ver_ifdef(typ.in_versions)) - ret += f"static bool marshal_{typ.typname}(struct _marshal_ctx *ctx, {c_typename(typ)} *val) {{\n" + ret += cutil.ifdef_push(1, c9util.ver_ifdef(typ.in_versions)) + ret += f"static bool marshal_{typ.typname}(struct _marshal_ctx *ctx, {c9util.typename(typ)} *val) {{\n" # Pass 1 - check size max_size = max(typ.max_size(v) for v in typ.in_versions) @@ -1235,10 +1173,10 @@ LM_ALWAYS_INLINE static void unmarshal_8(struct _unmarshal_ctx *ctx, uint64_t *o parent = path.elems[-2].typ if len(path.elems) > 1 else path.root if child.in_versions < parent.in_versions: ret += cutil.ifdef_push( - ifdef_depth + 1, c_ver_ifdef(child.in_versions) + ifdef_depth + 1, c9util.ver_ifdef(child.in_versions) ) ifdef_depth += 1 - ret += f"{'\t'*len(stack)}if ({c_ver_cond(child.in_versions)}) {{\n" + ret += f"{'\t'*len(stack)}if ({c9util.ver_cond(child.in_versions)}) {{\n" stack.append((path, True)) if child.cnt: cnt_path = path.parent().add(child.cnt) @@ -1249,7 +1187,7 @@ LM_ALWAYS_INLINE static void unmarshal_8(struct _unmarshal_ctx *ctx, uint64_t *o ret += f"{'\t'*len(stack)}MARSHAL_BYTES_ZEROCOPY(ctx, {path.c_str('val->')[:-3]}, {cnt_path.c_str('val->')});\n" return WalkCmd.KEEP_GOING, pop loopvar = chr(ord("i") + loopdepth - 1) - ret += f"{'\t'*len(stack)}for ({c_typename(child.cnt.typ)} {loopvar} = 0; {loopvar} < {cnt_path.c_str('val->')}; {loopvar}++) {{\n" + ret += f"{'\t'*len(stack)}for ({c9util.typename(child.cnt.typ)} {loopvar} = 0; {loopvar} < {cnt_path.c_str('val->')}; {loopvar}++) {{\n" stack.append((path, False)) if not isinstance(child.typ, idl.Struct): if child.val: @@ -1264,7 +1202,7 @@ LM_ALWAYS_INLINE static void unmarshal_8(struct _unmarshal_ctx *ctx, uint64_t *o + sym ) - val = c_expr(child.val, lookup_sym) + val = c9util.idl_expr(child.val, lookup_sym) else: val = path.c_str("val->") if isinstance(child.typ, idl.Bitfield): @@ -1287,45 +1225,53 @@ LM_ALWAYS_INLINE static void unmarshal_8(struct _unmarshal_ctx *ctx, uint64_t *o """ ret += "\n" - ret += f"const uint32_t _{idprefix}table_msg_min_size[{c_ver_enum('NUM')}] = {{\n" + ret += f"const uint32_t {c9util.ident('_table_msg_min_size')}[{c9util.ver_enum('NUM')}] = {{\n" rerror = next(typ for typ in typs if typ.typname == "Rerror") - ret += f"\t[{c_ver_enum('unknown')}] = {rerror.min_size('9P2000')},\n" # SPECIAL (initialization) + ret += f"\t[{c9util.ver_enum('unknown')}] = {rerror.min_size('9P2000')},\n" # SPECIAL (initialization) for ver in sorted(versions): - ret += cutil.ifdef_push(1, c_ver_ifdef({ver})) - ret += f"\t[{c_ver_enum(ver)}] = {rerror.min_size(ver)},\n" + ret += cutil.ifdef_push(1, c9util.ver_ifdef({ver})) + ret += f"\t[{c9util.ver_enum(ver)}] = {rerror.min_size(ver)},\n" ret += cutil.ifdef_pop(0) ret += "};\n" ret += "\n" ret += cutil.macro( - f"#define _MSG_RECV(typ) [{idprefix.upper()}TYP_##typ/2] = {{\n" - f"\t\t.basesize = sizeof(struct {idprefix}msg_##typ),\n" + f"#define _MSG_RECV(typ) [{c9util.Ident('TYP_')}##typ/2] = {{\n" + f"\t\t.basesize = sizeof(struct {c9util.ident('msg_')}##typ),\n" f"\t\t.validate = validate_##typ,\n" f"\t\t.unmarshal = (_unmarshal_fn_t)unmarshal_##typ,\n" f"\t}}\n" ) ret += cutil.macro( - f"#define _MSG_SEND(typ) [{idprefix.upper()}TYP_##typ/2] = {{\n" + f"#define _MSG_SEND(typ) [{c9util.Ident('TYP_')}##typ/2] = {{\n" f"\t\t.marshal = (_marshal_fn_t)marshal_##typ,\n" f"\t}}\n" ) ret += "\n" - ret += msg_table("Tmsg", "recv", f"struct _{idprefix}recv_tentry", (0, 0x100, 2)) + ret += msg_table( + "Tmsg", "recv", f"struct {c9util.ident('_recv_tentry')}", (0, 0x100, 2) + ) ret += "\n" - ret += msg_table("Rmsg", "recv", f"struct _{idprefix}recv_tentry", (1, 0x100, 2)) + ret += msg_table( + "Rmsg", "recv", f"struct {c9util.ident('_recv_tentry')}", (1, 0x100, 2) + ) ret += "\n" - ret += msg_table("Tmsg", "send", f"struct _{idprefix}send_tentry", (0, 0x100, 2)) + ret += msg_table( + "Tmsg", "send", f"struct {c9util.ident('_send_tentry')}", (0, 0x100, 2) + ) ret += "\n" - ret += msg_table("Rmsg", "send", f"struct _{idprefix}send_tentry", (1, 0x100, 2)) + ret += msg_table( + "Rmsg", "send", f"struct {c9util.ident('_send_tentry')}", (1, 0x100, 2) + ) ret += f""" -LM_FLATTEN bool _{idprefix}stat_validate(struct _validate_ctx *ctx) {{ +LM_FLATTEN bool {c9util.ident('_stat_validate')}(struct _validate_ctx *ctx) {{ \treturn validate_stat(ctx); }} -LM_FLATTEN void _{idprefix}stat_unmarshal(struct _unmarshal_ctx *ctx, struct {idprefix}stat *out) {{ +LM_FLATTEN void {c9util.ident('_stat_unmarshal')}(struct _unmarshal_ctx *ctx, struct {c9util.ident('stat')} *out) {{ \tunmarshal_stat(ctx, out); }} -LM_FLATTEN bool _{idprefix}stat_marshal(struct _marshal_ctx *ctx, struct {idprefix}stat *val) {{ +LM_FLATTEN bool {c9util.ident('_stat_marshal')}(struct _marshal_ctx *ctx, struct {c9util.ident('stat')} *val) {{ \treturn marshal_stat(ctx, val); }} """ -- cgit v1.2.3-2-g168b From ee356d885a984d5e79da0da20ce608787f1426f3 Mon Sep 17 00:00:00 2001 From: "Luke T. Shumaker" Date: Sun, 23 Mar 2025 02:05:13 -0600 Subject: lib9p: protogen: pull idlutil.py out of __init__.py --- lib9p/protogen/__init__.py | 156 ++++++++++----------------------------------- 1 file changed, 32 insertions(+), 124 deletions(-) (limited to 'lib9p/protogen/__init__.py') diff --git a/lib9p/protogen/__init__.py b/lib9p/protogen/__init__.py index 73542a2..76e80f3 100644 --- a/lib9p/protogen/__init__.py +++ b/lib9p/protogen/__init__.py @@ -4,15 +4,13 @@ # Copyright (C) 2024-2025 Luke T. Shumaker # SPDX-License-Identifier: AGPL-3.0-or-later -import enum -import graphlib import os.path import sys import typing import idl -from . import c9util, cutil +from . import c9util, cutil, idlutil # This strives to be "general-purpose" in that it just acts on the # *.9p inputs; but (unfortunately?) there are a few special-cases in @@ -23,100 +21,6 @@ from . import c9util, cutil __all__ = ["main"] -# topo_sorted() ################################################################ - - -def topo_sorted(typs: list[idl.UserType]) -> typing.Iterable[idl.UserType]: - ts: graphlib.TopologicalSorter[idl.UserType] = graphlib.TopologicalSorter() - for typ in typs: - match typ: - case idl.Number(): - ts.add(typ) - case idl.Bitfield(): - ts.add(typ) - case idl.Struct(): # and idl.Message(): - deps = [ - member.typ - for member in typ.members - if not isinstance(member.typ, idl.Primitive) - ] - ts.add(typ, *deps) - return ts.static_order() - - -# walk() ####################################################################### - - -class Path: - root: idl.Type - elems: list[idl.StructMember] - - def __init__( - self, root: idl.Type, elems: list[idl.StructMember] | None = None - ) -> None: - self.root = root - self.elems = elems if elems is not None else [] - - def add(self, elem: idl.StructMember) -> "Path": - return Path(self.root, self.elems + [elem]) - - def parent(self) -> "Path": - return Path(self.root, self.elems[:-1]) - - def c_str(self, base: str, loopdepth: int = 0) -> str: - ret = base - for i, elem in enumerate(self.elems): - if i > 0: - ret += "." - ret += elem.membname - if elem.cnt: - ret += f"[{chr(ord('i')+loopdepth)}]" - loopdepth += 1 - return ret - - def __str__(self) -> str: - return self.c_str(self.root.typname + "->") - - -class WalkCmd(enum.Enum): - KEEP_GOING = 1 - DONT_RECURSE = 2 - ABORT = 3 - - -type WalkHandler = typing.Callable[ - [Path], tuple[WalkCmd, typing.Callable[[], None] | None] -] - - -def _walk(path: Path, handle: WalkHandler) -> WalkCmd: - typ = path.elems[-1].typ if path.elems else path.root - - ret, atexit = handle(path) - - if isinstance(typ, idl.Struct): - match ret: - case WalkCmd.KEEP_GOING: - for member in typ.members: - if _walk(path.add(member), handle) == WalkCmd.ABORT: - ret = WalkCmd.ABORT - break - case WalkCmd.DONT_RECURSE: - ret = WalkCmd.KEEP_GOING - case WalkCmd.ABORT: - ret = WalkCmd.ABORT - case _: - assert False, f"invalid cmd: {ret}" - - if atexit: - atexit() - return ret - - -def walk(typ: idl.Type, handle: WalkHandler) -> None: - _walk(Path(typ), handle) - - # get_buffer_size() ############################################################ @@ -170,12 +74,12 @@ def _get_buffer_size(typ: idl.Type, version: str) -> TmpBufferSize: ret.tmp_ends_with_copy = True return ret - def handle(path: Path) -> tuple[WalkCmd, None]: + def handle(path: idlutil.Path) -> tuple[idlutil.WalkCmd, None]: nonlocal ret if path.elems: child = path.elems[-1] if version not in child.in_versions: - return WalkCmd.DONT_RECURSE, None + return idlutil.WalkCmd.DONT_RECURSE, None if child.cnt: if child.typ.static_size == 1: # SPECIAL (zerocopy) ret.max_iov += 1 @@ -183,7 +87,7 @@ def _get_buffer_size(typ: idl.Type, version: str) -> TmpBufferSize: ret.exp_size += 27 if child.membname == "utf8" else 8192 ret.max_size += child.max_cnt ret.tmp_ends_with_copy = False - return WalkCmd.DONT_RECURSE, None + return idlutil.WalkCmd.DONT_RECURSE, None sub = _get_buffer_size(child.typ, version) ret.exp_size += sub.exp_size * 16 # HEURISTIC: MAXWELEM ret.max_size += sub.max_size * child.max_cnt @@ -218,7 +122,7 @@ def _get_buffer_size(typ: idl.Type, version: str) -> TmpBufferSize: # we can merge these ret.max_iov -= child.max_cnt - 1 ret.tmp_ends_with_copy = sub.tmp_ends_with_copy - return WalkCmd.DONT_RECURSE, None + return idlutil.WalkCmd.DONT_RECURSE, None if not isinstance(child.typ, idl.Struct): assert child.typ.static_size if not ret.tmp_ends_with_copy: @@ -230,9 +134,9 @@ def _get_buffer_size(typ: idl.Type, version: str) -> TmpBufferSize: ret.exp_size += child.typ.static_size ret.max_size += child.typ.static_size ret.max_copy += child.typ.static_size - return WalkCmd.KEEP_GOING, None + return idlutil.WalkCmd.KEEP_GOING, None - walk(typ, handle) + idlutil.walk(typ, handle) assert ret.min_size == typ.min_size(version) assert ret.max_size == typ.max_size(version) return ret @@ -343,7 +247,7 @@ enum {c9util.ident('version')} {{ ret += f"/* {c9util.ver_enum(version):<{v_width}}: {line} */\n" return ret - for typ in topo_sorted(typs): + for typ in idlutil.topo_sorted(typs): ret += "\n" ret += cutil.ifdef_push(1, c9util.ver_ifdef(typ.in_versions)) @@ -712,7 +616,7 @@ LM_ALWAYS_INLINE static bool validate_8(struct _validate_ctx *ctx) { return _val member.max or member.val or any(m.cnt == member for m in typ.members) ) - for typ in topo_sorted(typs): + for typ in idlutil.topo_sorted(typs): inline = "LM_FLATTEN" if isinstance(typ, idl.Message) else "LM_ALWAYS_INLINE" argfn = unused if (isinstance(typ, idl.Struct) and not typ.members) else used ret += "\n" @@ -841,7 +745,7 @@ LM_ALWAYS_INLINE static void unmarshal_8(struct _unmarshal_ctx *ctx, uint64_t *o \tctx->net_offset += 8; } """ - for typ in topo_sorted(typs): + for typ in idlutil.topo_sorted(typs): inline = "LM_FLATTEN" if isinstance(typ, idl.Message) else "LM_ALWAYS_INLINE" argfn = unused if (isinstance(typ, idl.Struct) and not typ.members) else used ret += "\n" @@ -943,7 +847,7 @@ LM_ALWAYS_INLINE static void unmarshal_8(struct _unmarshal_ctx *ctx, uint64_t *o class OffsetExpr: static: int cond: dict[frozenset[str], "OffsetExpr"] - rep: list[tuple[Path, "OffsetExpr"]] + rep: list[tuple[idlutil.Path, "OffsetExpr"]] def __init__(self) -> None: self.static = 0 @@ -999,7 +903,7 @@ LM_ALWAYS_INLINE static void unmarshal_8(struct _unmarshal_ctx *ctx, uint64_t *o ret += multiline return ret - type OffsetExprRecursion = typing.Callable[[Path], WalkCmd] + type OffsetExprRecursion = typing.Callable[[idlutil.Path], idlutil.WalkCmd] def get_offset_expr(typ: idl.UserType, recurse: OffsetExprRecursion) -> OffsetExpr: if not isinstance(typ, idl.Struct): @@ -1008,7 +912,7 @@ LM_ALWAYS_INLINE static void unmarshal_8(struct _unmarshal_ctx *ctx, uint64_t *o ret.static = typ.static_size return ret - stack: list[tuple[Path, OffsetExpr, typing.Callable[[], None]]] + stack: list[tuple[idlutil.Path, OffsetExpr, typing.Callable[[], None]]] def pop_root() -> None: assert False @@ -1031,11 +935,13 @@ LM_ALWAYS_INLINE static void unmarshal_8(struct _unmarshal_ctx *ctx, uint64_t *o stack[-2][1].rep.append((cnt_path, stack[-1][1])) stack = stack[:-1] - def handle(path: Path) -> tuple[WalkCmd, typing.Callable[[], None] | None]: + def handle( + path: idlutil.Path, + ) -> tuple[idlutil.WalkCmd, typing.Callable[[], None] | None]: nonlocal recurse ret = recurse(path) - if ret != WalkCmd.KEEP_GOING: + if ret != idlutil.WalkCmd.KEEP_GOING: return ret, None nonlocal stack @@ -1059,18 +965,18 @@ LM_ALWAYS_INLINE static void unmarshal_8(struct _unmarshal_ctx *ctx, uint64_t *o stack[-1][1].static += child.typ.static_size return ret, pop - stack = [(Path(typ), OffsetExpr(), pop_root)] - walk(typ, handle) + stack = [(idlutil.Path(typ), OffsetExpr(), pop_root)] + idlutil.walk(typ, handle) return stack[0][1] - def go_to_end(path: Path) -> WalkCmd: - return WalkCmd.KEEP_GOING + def go_to_end(path: idlutil.Path) -> idlutil.WalkCmd: + return idlutil.WalkCmd.KEEP_GOING - def go_to_tok(name: str) -> typing.Callable[[Path], WalkCmd]: - def ret(path: Path) -> WalkCmd: + def go_to_tok(name: str) -> typing.Callable[[idlutil.Path], idlutil.WalkCmd]: + def ret(path: idlutil.Path) -> idlutil.WalkCmd: if len(path.elems) == 1 and path.elems[0].membname == name: - return WalkCmd.ABORT - return WalkCmd.KEEP_GOING + return idlutil.WalkCmd.ABORT + return idlutil.WalkCmd.KEEP_GOING return ret @@ -1107,9 +1013,11 @@ LM_ALWAYS_INLINE static void unmarshal_8(struct _unmarshal_ctx *ctx, uint64_t *o # Pass 2 - write data ifdef_depth = 1 - stack: list[tuple[Path, bool]] = [(Path(typ), False)] + stack: list[tuple[idlutil.Path, bool]] = [(idlutil.Path(typ), False)] - def handle(path: Path) -> tuple[WalkCmd, typing.Callable[[], None]]: + def handle( + path: idlutil.Path, + ) -> tuple[idlutil.WalkCmd, typing.Callable[[], None]]: nonlocal ret nonlocal ifdef_depth nonlocal stack @@ -1185,7 +1093,7 @@ LM_ALWAYS_INLINE static void unmarshal_8(struct _unmarshal_ctx *ctx, uint64_t *o ret += f"{'\t'*len(stack)}MARSHAL_BYTES(ctx, {path.c_str('val->')[:-3]}, {cnt_path.c_str('val->')});\n" else: ret += f"{'\t'*len(stack)}MARSHAL_BYTES_ZEROCOPY(ctx, {path.c_str('val->')[:-3]}, {cnt_path.c_str('val->')});\n" - return WalkCmd.KEEP_GOING, pop + return idlutil.WalkCmd.KEEP_GOING, pop loopvar = chr(ord("i") + loopdepth - 1) ret += f"{'\t'*len(stack)}for ({c9util.typename(child.cnt.typ)} {loopvar} = 0; {loopvar} < {cnt_path.c_str('val->')}; {loopvar}++) {{\n" stack.append((path, False)) @@ -1208,9 +1116,9 @@ LM_ALWAYS_INLINE static void unmarshal_8(struct _unmarshal_ctx *ctx, uint64_t *o if isinstance(child.typ, idl.Bitfield): val += f" & {child.typ.typname}_masks[ctx->ctx->version]" ret += f"{'\t'*len(stack)}MARSHAL_U{child.typ.static_size*8}LE(ctx, {val});\n" - return WalkCmd.KEEP_GOING, pop + return idlutil.WalkCmd.KEEP_GOING, pop - walk(typ, handle) + idlutil.walk(typ, handle) del handle del stack del max_size -- cgit v1.2.3-2-g168b From 2a70a611558daa248e4fc1a11a9aa0ceb3ed397a Mon Sep 17 00:00:00 2001 From: "Luke T. Shumaker" Date: Sun, 23 Mar 2025 02:09:30 -0600 Subject: lib9p: protogen: pull h.py out of __init__.py --- lib9p/protogen/__init__.py | 434 +-------------------------------------------- 1 file changed, 2 insertions(+), 432 deletions(-) (limited to 'lib9p/protogen/__init__.py') diff --git a/lib9p/protogen/__init__.py b/lib9p/protogen/__init__.py index 76e80f3..37cf6f5 100644 --- a/lib9p/protogen/__init__.py +++ b/lib9p/protogen/__init__.py @@ -10,7 +10,7 @@ import typing import idl -from . import c9util, cutil, idlutil +from . import c9util, cutil, h, idlutil # This strives to be "general-purpose" in that it just acts on the # *.9p inputs; but (unfortunately?) there are a few special-cases in @@ -21,436 +21,6 @@ from . import c9util, cutil, idlutil __all__ = ["main"] -# get_buffer_size() ############################################################ - - -class BufferSize(typing.NamedTuple): - min_size: int # really just here to sanity-check against typ.min_size(version) - exp_size: int # "expected" or max-reasonable size - max_size: int # really just here to sanity-check against typ.max_size(version) - max_copy: int - max_copy_extra: str - max_iov: int - max_iov_extra: str - - -class TmpBufferSize: - min_size: int - exp_size: int - max_size: int - max_copy: int - max_copy_extra: str - max_iov: int - max_iov_extra: str - - tmp_starts_with_copy: bool - tmp_ends_with_copy: bool - - def __init__(self) -> None: - self.min_size = 0 - self.exp_size = 0 - self.max_size = 0 - self.max_copy = 0 - self.max_copy_extra = "" - self.max_iov = 0 - self.max_iov_extra = "" - self.tmp_starts_with_copy = False - self.tmp_ends_with_copy = False - - -def _get_buffer_size(typ: idl.Type, version: str) -> TmpBufferSize: - assert isinstance(typ, idl.Primitive) or (version in typ.in_versions) - - ret = TmpBufferSize() - - if not isinstance(typ, idl.Struct): - assert typ.static_size - ret.min_size = typ.static_size - ret.exp_size = typ.static_size - ret.max_size = typ.static_size - ret.max_copy = typ.static_size - ret.max_iov = 1 - ret.tmp_starts_with_copy = True - ret.tmp_ends_with_copy = True - return ret - - def handle(path: idlutil.Path) -> tuple[idlutil.WalkCmd, None]: - nonlocal ret - if path.elems: - child = path.elems[-1] - if version not in child.in_versions: - return idlutil.WalkCmd.DONT_RECURSE, None - if child.cnt: - if child.typ.static_size == 1: # SPECIAL (zerocopy) - ret.max_iov += 1 - # HEURISTIC: 27 for strings (max-strlen from 9P1), 8KiB for other data - ret.exp_size += 27 if child.membname == "utf8" else 8192 - ret.max_size += child.max_cnt - ret.tmp_ends_with_copy = False - return idlutil.WalkCmd.DONT_RECURSE, None - sub = _get_buffer_size(child.typ, version) - ret.exp_size += sub.exp_size * 16 # HEURISTIC: MAXWELEM - ret.max_size += sub.max_size * child.max_cnt - if child.membname == "wname" and path.root.typname in ( - "Tsread", - "Tswrite", - ): # SPECIAL (9P2000.e) - assert ret.tmp_ends_with_copy - assert sub.tmp_starts_with_copy - assert not sub.tmp_ends_with_copy - ret.max_copy_extra = ( - f" + (CONFIG_9P_MAX_9P2000_e_WELEM * {sub.max_copy})" - ) - ret.max_iov_extra = ( - f" + (CONFIG_9P_MAX_9P2000_e_WELEM * {sub.max_iov})" - ) - ret.max_iov -= 1 - else: - ret.max_copy += sub.max_copy * child.max_cnt - if sub.max_iov == 1 and sub.tmp_starts_with_copy: # is purely copy - ret.max_iov += 1 - else: # contains zero-copy segments - ret.max_iov += sub.max_iov * child.max_cnt - if ret.tmp_ends_with_copy and sub.tmp_starts_with_copy: - # we can merge this one - ret.max_iov -= 1 - if ( - sub.tmp_ends_with_copy - and sub.tmp_starts_with_copy - and sub.max_iov > 1 - ): - # we can merge these - ret.max_iov -= child.max_cnt - 1 - ret.tmp_ends_with_copy = sub.tmp_ends_with_copy - return idlutil.WalkCmd.DONT_RECURSE, None - if not isinstance(child.typ, idl.Struct): - assert child.typ.static_size - if not ret.tmp_ends_with_copy: - if ret.max_size == 0: - ret.tmp_starts_with_copy = True - ret.max_iov += 1 - ret.tmp_ends_with_copy = True - ret.min_size += child.typ.static_size - ret.exp_size += child.typ.static_size - ret.max_size += child.typ.static_size - ret.max_copy += child.typ.static_size - return idlutil.WalkCmd.KEEP_GOING, None - - idlutil.walk(typ, handle) - assert ret.min_size == typ.min_size(version) - assert ret.max_size == typ.max_size(version) - return ret - - -def get_buffer_size(typ: idl.Type, version: str) -> BufferSize: - tmp = _get_buffer_size(typ, version) - return BufferSize( - min_size=tmp.min_size, - exp_size=tmp.exp_size, - max_size=tmp.max_size, - max_copy=tmp.max_copy, - max_copy_extra=tmp.max_copy_extra, - max_iov=tmp.max_iov, - max_iov_extra=tmp.max_iov_extra, - ) - - -# Generate .h ################################################################## - - -def gen_h(versions: set[str], typs: list[idl.UserType]) -> str: - cutil.ifdef_init() - - ret = f"""/* Generated by `{' '.join(sys.argv)}`. DO NOT EDIT! */ - -#ifndef _LIB9P_9P_H_ -\t#error Do not include directly; include instead -#endif - -#include /* for uint{{n}}_t types */ - -#include /* for struct iovec */ -""" - - id2typ: dict[int, idl.Message] = {} - for msg in [msg for msg in typs if isinstance(msg, idl.Message)]: - id2typ[msg.msgid] = msg - - ret += """ -/* config *********************************************************************/ - -#include "config.h" -""" - for ver in sorted(versions): - ret += "\n" - ret += f"#ifndef {c9util.ver_ifdef({ver})}\n" - ret += f"\t#error config.h must define {c9util.ver_ifdef({ver})}\n" - if ver == "9P2000.e": # SPECIAL (9P2000.e) - ret += "#else\n" - ret += f"\t#if {c9util.ver_ifdef({ver})}\n" - ret += "\t\t#ifndef CONFIG_9P_MAX_9P2000_e_WELEM\n" - ret += f"\t\t\t#error if {c9util.ver_ifdef({ver})} then config.h must define CONFIG_9P_MAX_9P2000_e_WELEM\n" - ret += "\t\t#endif\n" - ret += "\t\tstatic_assert(CONFIG_9P_MAX_9P2000_e_WELEM > 0);\n" - ret += "\t#endif\n" - ret += "#endif\n" - - ret += f""" -/* enum version ***************************************************************/ - -enum {c9util.ident('version')} {{ -""" - fullversions = ["unknown = 0", *sorted(versions)] - verwidth = max(len(v) for v in fullversions) - for ver in fullversions: - if ver in versions: - ret += cutil.ifdef_push(1, c9util.ver_ifdef({ver})) - ret += f"\t{c9util.ver_enum(ver)}," - ret += (" " * (verwidth - len(ver))) + ' /* "' + ver.split()[0] + '" */\n' - ret += cutil.ifdef_pop(0) - ret += f"\t{c9util.ver_enum('NUM')},\n" - ret += "};\n" - - ret += """ -/* enum msg_type **************************************************************/ - -""" - ret += f"enum {c9util.ident('msg_type')} {{ /* uint8_t */\n" - namewidth = max(len(msg.typname) for msg in typs if isinstance(msg, idl.Message)) - for n in range(0x100): - if n not in id2typ: - continue - msg = id2typ[n] - ret += cutil.ifdef_push(1, c9util.ver_ifdef(msg.in_versions)) - ret += f"\t{c9util.Ident(f'TYP_{msg.typname:<{namewidth}}')} = {msg.msgid},\n" - ret += cutil.ifdef_pop(0) - ret += "};\n" - - ret += """ -/* payload types **************************************************************/ -""" - - def per_version_comment( - typ: idl.UserType, fn: typing.Callable[[idl.UserType, str], str] - ) -> str: - lines: dict[str, str] = {} - for version in sorted(typ.in_versions): - lines[version] = fn(typ, version) - if len(set(lines.values())) == 1: - for _, line in lines.items(): - return f"/* {line} */\n" - assert False - else: - ret = "" - v_width = max(len(c9util.ver_enum(v)) for v in typ.in_versions) - for version, line in lines.items(): - ret += f"/* {c9util.ver_enum(version):<{v_width}}: {line} */\n" - return ret - - for typ in idlutil.topo_sorted(typs): - ret += "\n" - ret += cutil.ifdef_push(1, c9util.ver_ifdef(typ.in_versions)) - - def sum_size(typ: idl.UserType, version: str) -> str: - sz = get_buffer_size(typ, version) - assert ( - sz.min_size <= sz.exp_size - and sz.exp_size <= sz.max_size - and sz.max_size < cutil.UINT64_MAX - ) - ret = "" - if sz.min_size == sz.max_size: - ret += f"size = {sz.min_size:,}" - else: - ret += f"min_size = {sz.min_size:,} ; exp_size = {sz.exp_size:,} ; max_size = {sz.max_size:,}" - if sz.max_size > cutil.UINT32_MAX: - ret += " (warning: >UINT32_MAX)" - ret += f" ; max_iov = {sz.max_iov:,}{sz.max_iov_extra} ; max_copy = {sz.max_copy:,}{sz.max_copy_extra}" - return ret - - ret += per_version_comment(typ, sum_size) - - match typ: - case idl.Number(): - ret += f"typedef {c9util.typename(typ.prim)} {c9util.typename(typ)};\n" - prefix = f"{c9util.IDENT(typ.typname)}_" - namewidth = max(len(name) for name in typ.vals) - for name, val in typ.vals.items(): - ret += f"#define {prefix}{name:<{namewidth}} (({c9util.typename(typ)})UINT{typ.static_size*8}_C({val}))\n" - case idl.Bitfield(): - ret += f"typedef {c9util.typename(typ.prim)} {c9util.typename(typ)};\n" - - def bitname(val: idl.Bit | idl.BitAlias) -> str: - s = val.bitname - match val: - case idl.Bit(cat=idl.BitCat.RESERVED): - s = "_RESERVED_" + s - case idl.Bit(cat=idl.BitCat.SUBFIELD): - assert isinstance(typ, idl.Bitfield) - n = sum( - 1 - for b in typ.bits[: val.num] - if b.cat == idl.BitCat.SUBFIELD - and b.bitname == val.bitname - ) - s = f"_{s}_{n}" - case idl.Bit(cat=idl.BitCat.UNUSED): - return "" - return c9util.Ident(c9util.add_prefix(typ.typname.upper() + "_", s)) - - namewidth = max( - len(bitname(val)) for val in [*typ.bits, *typ.names.values()] - ) - - ret += "\n" - for bit in reversed(typ.bits): - vers = bit.in_versions - if bit.cat == idl.BitCat.UNUSED: - vers = typ.in_versions - ret += cutil.ifdef_push(2, c9util.ver_ifdef(vers)) - - # It is important all of the `beg` strings have - # the same length. - end = "" - match bit.cat: - case ( - idl.BitCat.USED | idl.BitCat.RESERVED | idl.BitCat.SUBFIELD - ): - if cutil.ifdef_leaf_is_noop(): - beg = "#define " - else: - beg = "# define" - case idl.BitCat.UNUSED: - beg = "/* unused" - end = " */" - - c_name = bitname(bit) - c_val = f"1<<{bit.num}" - ret += f"{beg} {c_name:<{namewidth}} (({c9util.typename(typ)})({c_val})){end}\n" - if aliases := [ - alias - for alias in typ.names.values() - if isinstance(alias, idl.BitAlias) - ]: - ret += "\n" - - for alias in aliases: - ret += cutil.ifdef_push(2, c9util.ver_ifdef(alias.in_versions)) - - end = "" - if cutil.ifdef_leaf_is_noop(): - beg = "#define " - else: - beg = "# define" - - c_name = bitname(alias) - c_val = alias.val - ret += f"{beg} {c_name:<{namewidth}} (({c9util.typename(typ)})({c_val})){end}\n" - ret += cutil.ifdef_pop(1) - del bitname - case idl.Struct(): # and idl.Message(): - ret += c9util.typename(typ) + " {" - if not typ.members: - ret += "};\n" - continue - ret += "\n" - - typewidth = max(len(c9util.typename(m.typ, m)) for m in typ.members) - - for member in typ.members: - if member.val: - continue - ret += cutil.ifdef_push(2, c9util.ver_ifdef(member.in_versions)) - ret += f"\t{c9util.typename(member.typ, member):<{typewidth}} {'*' if member.cnt else ' '}{member.membname};\n" - ret += cutil.ifdef_pop(1) - ret += "};\n" - del typ - ret += cutil.ifdef_pop(0) - - ret += """ -/* containers *****************************************************************/ -""" - ret += "\n" - ret += f"#define {c9util.IDENT('_MAX')}(a, b) ((a) > (b)) ? (a) : (b)\n" - - tmsg_max_iov: dict[str, int] = {} - tmsg_max_copy: dict[str, int] = {} - rmsg_max_iov: dict[str, int] = {} - rmsg_max_copy: dict[str, int] = {} - for typ in typs: - if not isinstance(typ, idl.Message): - continue - if typ.typname in ("Tsread", "Tswrite"): # SPECIAL (9P2000.e) - continue - max_iov = tmsg_max_iov if typ.msgid % 2 == 0 else rmsg_max_iov - max_copy = tmsg_max_copy if typ.msgid % 2 == 0 else rmsg_max_copy - for version in typ.in_versions: - if version not in max_iov: - max_iov[version] = 0 - max_copy[version] = 0 - sz = get_buffer_size(typ, version) - if sz.max_iov > max_iov[version]: - max_iov[version] = sz.max_iov - if sz.max_copy > max_copy[version]: - max_copy[version] = sz.max_copy - - for name, table in [ - ("tmsg_max_iov", tmsg_max_iov), - ("tmsg_max_copy", tmsg_max_copy), - ("rmsg_max_iov", rmsg_max_iov), - ("rmsg_max_copy", rmsg_max_copy), - ]: - inv: dict[int, set[str]] = {} - for version, maxval in table.items(): - if maxval not in inv: - inv[maxval] = set() - inv[maxval].add(version) - - ret += "\n" - directive = "if" - seen_e = False # SPECIAL (9P2000.e) - for maxval in sorted(inv, reverse=True): - ret += f"#{directive} {c9util.ver_ifdef(inv[maxval])}\n" - indent = 1 - if name.startswith("tmsg") and not seen_e: # SPECIAL (9P2000.e) - typ = next(typ for typ in typs if typ.typname == "Tswrite") - sz = get_buffer_size(typ, "9P2000.e") - match name: - case "tmsg_max_iov": - maxexpr = f"{sz.max_iov}{sz.max_iov_extra}" - case "tmsg_max_copy": - maxexpr = f"{sz.max_copy}{sz.max_copy_extra}" - case _: - assert False - ret += f"\t#if {c9util.ver_ifdef({"9P2000.e"})}\n" - ret += f"\t\t#define {c9util.IDENT(name)} {c9util.IDENT('_MAX')}({maxval}, {maxexpr})\n" - ret += "\t#else\n" - indent += 1 - ret += f"{'\t'*indent}#define {c9util.IDENT(name)} {maxval}\n" - if name.startswith("tmsg") and not seen_e: # SPECIAL (9P2000.e) - ret += "\t#endif\n" - if "9P2000.e" in inv[maxval]: - seen_e = True - directive = "elif" - ret += "#endif\n" - - ret += "\n" - ret += f"struct {c9util.ident('Tmsg_send_buf')} {{\n" - ret += "\tsize_t iov_cnt;\n" - ret += f"\tstruct iovec iov[{c9util.IDENT('TMSG_MAX_IOV')}];\n" - ret += f"\tuint8_t copied[{c9util.IDENT('TMSG_MAX_COPY')}];\n" - ret += "};\n" - - ret += "\n" - ret += f"struct {c9util.ident('Rmsg_send_buf')} {{\n" - ret += "\tsize_t iov_cnt;\n" - ret += f"\tstruct iovec iov[{c9util.IDENT('RMSG_MAX_IOV')}];\n" - ret += f"\tuint8_t copied[{c9util.IDENT('RMSG_MAX_COPY')}];\n" - ret += "};\n" - - return ret - - # Generate .c ################################################################## @@ -1227,6 +797,6 @@ def main() -> None: with open( os.path.join(outdir, "include/lib9p/9p.generated.h"), "w", encoding="utf-8" ) as fh: - fh.write(gen_h(versions, typs)) + fh.write(h.gen_h(versions, typs)) with open(os.path.join(outdir, "9p.generated.c"), "w", encoding="utf-8") as fh: fh.write(gen_c(versions, typs)) -- cgit v1.2.3-2-g168b From 82b733e4f8b3febc3b51c133a52fb62b54180b4b Mon Sep 17 00:00:00 2001 From: "Luke T. Shumaker" Date: Sun, 23 Mar 2025 02:26:08 -0600 Subject: lib9p: protogen: pull c.py and c_*.py out of __init__.py --- lib9p/protogen/__init__.py | 749 +-------------------------------------------- 1 file changed, 2 insertions(+), 747 deletions(-) (limited to 'lib9p/protogen/__init__.py') diff --git a/lib9p/protogen/__init__.py b/lib9p/protogen/__init__.py index 37cf6f5..c2c6173 100644 --- a/lib9p/protogen/__init__.py +++ b/lib9p/protogen/__init__.py @@ -10,757 +10,12 @@ import typing import idl -from . import c9util, cutil, h, idlutil - -# This strives to be "general-purpose" in that it just acts on the -# *.9p inputs; but (unfortunately?) there are a few special-cases in -# this script, marked with "SPECIAL". - +from . import c, h # pylint: disable=unused-variable __all__ = ["main"] -# Generate .c ################################################################## - - -def gen_c(versions: set[str], typs: list[idl.UserType]) -> str: - cutil.ifdef_init() - - ret = f"""/* Generated by `{' '.join(sys.argv)}`. DO NOT EDIT! */ - -#include -#include /* for size_t */ -#include /* for PRI* macros */ -#include /* for memset() */ - -#include - -#include - -#include "internal.h" -""" - - # utilities ################################################################ - ret += """ -/* utilities ******************************************************************/ -""" - - def used(arg: str) -> str: - return arg - - def unused(arg: str) -> str: - return f"LM_UNUSED({arg})" - - id2typ: dict[int, idl.Message] = {} - for msg in [msg for msg in typs if isinstance(msg, idl.Message)]: - id2typ[msg.msgid] = msg - - def msg_table(grp: str, meth: str, tentry: str, rng: tuple[int, int, int]) -> str: - ret = f"const {tentry} {c9util.ident(f'_table_{grp}_{meth}')}[{c9util.ver_enum('NUM')}][{hex(len(range(*rng)))}] = {{\n" - for ver in ["unknown", *sorted(versions)]: - if ver != "unknown": - ret += cutil.ifdef_push(1, c9util.ver_ifdef({ver})) - ret += f"\t[{c9util.ver_enum(ver)}] = {{\n" - for n in range(*rng): - xmsg: idl.Message | None = id2typ.get(n, None) - if xmsg: - if ver == "unknown": # SPECIAL (initialization) - if xmsg.typname not in ["Tversion", "Rversion", "Rerror"]: - xmsg = None - else: - if ver not in xmsg.in_versions: - xmsg = None - if xmsg: - ret += f"\t\t_MSG_{meth.upper()}({xmsg.typname}),\n" - ret += "\t},\n" - ret += cutil.ifdef_pop(0) - ret += "};\n" - return ret - - for v in sorted(versions): - ret += f"#if CONFIG_9P_ENABLE_{v.replace('.', '_')}\n" - ret += ( - f"\t#define _is_ver_{v.replace('.', '_')}(v) (v == {c9util.ver_enum(v)})\n" - ) - ret += "#else\n" - ret += f"\t#define _is_ver_{v.replace('.', '_')}(v) false\n" - ret += "#endif\n" - ret += "\n" - ret += "/**\n" - ret += f" * is_ver(ctx, ver) is essentially `(ctx->ctx->version == {c9util.Ident('VER_')}##ver)`,\n" - ret += f" * but compiles correctly (to `false`) even if `{c9util.Ident('VER_')}##ver` isn't defined\n" - ret += " * (because `!CONFIG_9P_ENABLE_##ver`). This is useful when `||`ing\n" - ret += " * several version checks together.\n" - ret += " */\n" - ret += "#define is_ver(CTX, ver) _is_ver_##ver(CTX->ctx->version)\n" - - # strings ################################################################## - ret += f""" -/* strings ********************************************************************/ - -const char *const {c9util.ident('_table_ver_name')}[{c9util.ver_enum('NUM')}] = {{ -""" - for ver in ["unknown", *sorted(versions)]: - if ver in versions: - ret += cutil.ifdef_push(1, c9util.ver_ifdef({ver})) - ret += f'\t[{c9util.ver_enum(ver)}] = "{ver}",\n' - ret += cutil.ifdef_pop(0) - ret += "};\n" - - ret += "\n" - ret += f"#define _MSG_NAME(typ) [{c9util.Ident('TYP_')}##typ] = #typ\n" - ret += msg_table("msg", "name", "char *const", (0, 0x100, 1)) - - # bitmasks ################################################################# - ret += """ -/* bitmasks *******************************************************************/ -""" - for typ in typs: - if not isinstance(typ, idl.Bitfield): - continue - ret += "\n" - ret += cutil.ifdef_push(1, c9util.ver_ifdef(typ.in_versions)) - ret += f"static const {c9util.typename(typ)} {typ.typname}_masks[{c9util.ver_enum('NUM')}] = {{\n" - verwidth = max(len(ver) for ver in versions) - for ver in sorted(versions): - ret += cutil.ifdef_push(2, c9util.ver_ifdef({ver})) - ret += ( - f"\t[{c9util.ver_enum(ver)}]{' '*(verwidth-len(ver))} = 0b" - + "".join( - ( - "1" - if bit.cat in (idl.BitCat.USED, idl.BitCat.SUBFIELD) - and ver in bit.in_versions - else "0" - ) - for bit in reversed(typ.bits) - ) - + ",\n" - ) - ret += cutil.ifdef_pop(1) - ret += "};\n" - ret += cutil.ifdef_pop(0) - - # validate_* ############################################################### - ret += """ -/* validate_* *****************************************************************/ - -LM_ALWAYS_INLINE static bool _validate_size_net(struct _validate_ctx *ctx, uint32_t n) { -\tif (__builtin_add_overflow(ctx->net_offset, n, &ctx->net_offset)) -\t\t/* If needed-net-size overflowed uint32_t, then -\t\t * there's no way that actual-net-size will live up to -\t\t * that. */ -\t\treturn lib9p_error(ctx->ctx, LINUX_EBADMSG, "message is too short for content"); -\tif (ctx->net_offset > ctx->net_size) -\t\treturn lib9p_error(ctx->ctx, LINUX_EBADMSG, "message is too short for content"); -\treturn false; -} - -LM_ALWAYS_INLINE static bool _validate_size_host(struct _validate_ctx *ctx, size_t n) { -\tif (__builtin_add_overflow(ctx->host_extra, n, &ctx->host_extra)) -\t\t/* If needed-host-size overflowed size_t, then there's -\t\t * no way that actual-net-size will live up to -\t\t * that. */ -\t\treturn lib9p_error(ctx->ctx, LINUX_EBADMSG, "message is too short for content"); -\treturn false; -} - -LM_ALWAYS_INLINE static bool _validate_list(struct _validate_ctx *ctx, - size_t cnt, - _validate_fn_t item_fn, size_t item_host_size) { -\tfor (size_t i = 0; i < cnt; i++) -\t\tif (_validate_size_host(ctx, item_host_size) || item_fn(ctx)) -\t\t\treturn true; -\treturn false; -} - -LM_ALWAYS_INLINE static bool validate_1(struct _validate_ctx *ctx) { return _validate_size_net(ctx, 1); } -LM_ALWAYS_INLINE static bool validate_2(struct _validate_ctx *ctx) { return _validate_size_net(ctx, 2); } -LM_ALWAYS_INLINE static bool validate_4(struct _validate_ctx *ctx) { return _validate_size_net(ctx, 4); } -LM_ALWAYS_INLINE static bool validate_8(struct _validate_ctx *ctx) { return _validate_size_net(ctx, 8); } -""" - - def should_save_value(typ: idl.Struct, member: idl.StructMember) -> bool: - return bool( - member.max or member.val or any(m.cnt == member for m in typ.members) - ) - - for typ in idlutil.topo_sorted(typs): - inline = "LM_FLATTEN" if isinstance(typ, idl.Message) else "LM_ALWAYS_INLINE" - argfn = unused if (isinstance(typ, idl.Struct) and not typ.members) else used - ret += "\n" - ret += cutil.ifdef_push(1, c9util.ver_ifdef(typ.in_versions)) - ret += f"{inline} static bool validate_{typ.typname}(struct _validate_ctx *{argfn('ctx')}) {{\n" - - match typ: - case idl.Number(): - ret += f"\treturn validate_{typ.prim.typname}(ctx);\n" - case idl.Bitfield(): - ret += f"\t if (validate_{typ.static_size}(ctx))\n" - ret += "\t\treturn true;\n" - ret += f"\t{c9util.typename(typ)} mask = {typ.typname}_masks[ctx->ctx->version];\n" - if typ.static_size == 1: - ret += f"\t{c9util.typename(typ)} val = ctx->net_bytes[ctx->net_offset-1];\n" - else: - ret += f"\t{c9util.typename(typ)} val = uint{typ.static_size*8}le_decode(&ctx->net_bytes[ctx->net_offset-{typ.static_size}]);\n" - ret += "\tif (val & ~mask)\n" - ret += f'\t\treturn lib9p_errorf(ctx->ctx, LINUX_EBADMSG, "unknown bits in {typ.typname} bitfield: %#0{typ.static_size}"PRIx{typ.static_size*8}, val & ~mask);\n' - ret += "\treturn false;\n" - case idl.Struct(): # and idl.Message() - if len(typ.members) == 0: - ret += "\treturn false;\n" - ret += "}\n" - continue - - # Pass 1 - declare value variables - for member in typ.members: - if should_save_value(typ, member): - ret += cutil.ifdef_push(2, c9util.ver_ifdef(member.in_versions)) - ret += f"\t{c9util.typename(member.typ)} {member.membname};\n" - ret += cutil.ifdef_pop(1) - - # Pass 2 - declare offset variables - mark_offset: set[str] = set() - for member in typ.members: - for tok in [*member.max.tokens, *member.val.tokens]: - if isinstance(tok, idl.ExprSym) and tok.symname.startswith("&"): - if tok.symname[1:] not in mark_offset: - ret += f"\tuint32_t _{tok.symname[1:]}_offset;\n" - mark_offset.add(tok.symname[1:]) - - # Pass 3 - main pass - ret += "\treturn false\n" - for member in typ.members: - ret += cutil.ifdef_push(2, c9util.ver_ifdef(member.in_versions)) - ret += "\t || " - if member.in_versions != typ.in_versions: - ret += "( " + c9util.ver_cond(member.in_versions) + " && " - if member.cnt is not None: - if member.typ.static_size == 1: # SPECIAL (zerocopy) - ret += f"_validate_size_net(ctx, {member.cnt.membname})" - else: - ret += f"_validate_list(ctx, {member.cnt.membname}, validate_{member.typ.typname}, sizeof({c9util.typename(member.typ)}))" - if typ.typname == "s": # SPECIAL (string) - ret += '\n\t || ({ (!is_valid_utf8_without_nul(&ctx->net_bytes[ctx->net_offset-len], len)) && lib9p_error(ctx->ctx, LINUX_EBADMSG, "message contains invalid UTF-8"); })' - else: - if should_save_value(typ, member): - ret += "(" - if member.membname in mark_offset: - ret += f"({{ _{member.membname}_offset = ctx->net_offset; " - ret += f"validate_{member.typ.typname}(ctx)" - if member.membname in mark_offset: - ret += "; })" - if should_save_value(typ, member): - nbytes = member.static_size - assert nbytes - if nbytes == 1: - ret += f" || ({{ {member.membname} = ctx->net_bytes[ctx->net_offset-1]; false; }}))" - else: - ret += f" || ({{ {member.membname} = uint{nbytes*8}le_decode(&ctx->net_bytes[ctx->net_offset-{nbytes}]); false; }}))" - if member.in_versions != typ.in_versions: - ret += " )" - ret += "\n" - - # Pass 4 - validate ,max= and ,val= constraints - for member in typ.members: - - def lookup_sym(sym: str) -> str: - match sym: - case "end": - return "ctx->net_offset" - case _: - assert sym.startswith("&") - return f"_{sym[1:]}_offset" - - if member.max: - assert member.static_size - nbits = member.static_size * 8 - ret += cutil.ifdef_push(2, c9util.ver_ifdef(member.in_versions)) - ret += f"\t || ({{ uint{nbits}_t max = {c9util.idl_expr(member.max, lookup_sym)}; (((uint{nbits}_t){member.membname}) > max) &&\n" - ret += f'\t lib9p_errorf(ctx->ctx, LINUX_EBADMSG, "{member.membname} value is too large (%"PRIu{nbits}" > %"PRIu{nbits}")", {member.membname}, max); }})\n' - if member.val: - assert member.static_size - nbits = member.static_size * 8 - ret += cutil.ifdef_push(2, c9util.ver_ifdef(member.in_versions)) - ret += f"\t || ({{ uint{nbits}_t exp = {c9util.idl_expr(member.val, lookup_sym)}; (((uint{nbits}_t){member.membname}) != exp) &&\n" - ret += f'\t lib9p_errorf(ctx->ctx, LINUX_EBADMSG, "{member.membname} value is wrong (actual:%"PRIu{nbits}" != correct:%"PRIu{nbits}")", (uint{nbits}_t){member.membname}, exp); }})\n' - - ret += cutil.ifdef_pop(1) - ret += "\t ;\n" - ret += "}\n" - ret += cutil.ifdef_pop(0) - - # unmarshal_* ############################################################## - ret += """ -/* unmarshal_* ****************************************************************/ - -LM_ALWAYS_INLINE static void unmarshal_1(struct _unmarshal_ctx *ctx, uint8_t *out) { -\t*out = ctx->net_bytes[ctx->net_offset]; -\tctx->net_offset += 1; -} - -LM_ALWAYS_INLINE static void unmarshal_2(struct _unmarshal_ctx *ctx, uint16_t *out) { -\t*out = uint16le_decode(&ctx->net_bytes[ctx->net_offset]); -\tctx->net_offset += 2; -} - -LM_ALWAYS_INLINE static void unmarshal_4(struct _unmarshal_ctx *ctx, uint32_t *out) { -\t*out = uint32le_decode(&ctx->net_bytes[ctx->net_offset]); -\tctx->net_offset += 4; -} - -LM_ALWAYS_INLINE static void unmarshal_8(struct _unmarshal_ctx *ctx, uint64_t *out) { -\t*out = uint64le_decode(&ctx->net_bytes[ctx->net_offset]); -\tctx->net_offset += 8; -} -""" - for typ in idlutil.topo_sorted(typs): - inline = "LM_FLATTEN" if isinstance(typ, idl.Message) else "LM_ALWAYS_INLINE" - argfn = unused if (isinstance(typ, idl.Struct) and not typ.members) else used - ret += "\n" - ret += cutil.ifdef_push(1, c9util.ver_ifdef(typ.in_versions)) - ret += f"{inline} static void unmarshal_{typ.typname}(struct _unmarshal_ctx *{argfn('ctx')}, {c9util.typename(typ)} *out) {{\n" - match typ: - case idl.Number(): - ret += f"\tunmarshal_{typ.prim.typname}(ctx, ({c9util.typename(typ.prim)} *)out);\n" - case idl.Bitfield(): - ret += f"\tunmarshal_{typ.prim.typname}(ctx, ({c9util.typename(typ.prim)} *)out);\n" - case idl.Struct(): - ret += "\tmemset(out, 0, sizeof(*out));\n" - - for member in typ.members: - ret += cutil.ifdef_push(2, c9util.ver_ifdef(member.in_versions)) - if member.val: - ret += f"\tctx->net_offset += {member.static_size};\n" - continue - ret += "\t" - - prefix = "\t" - if member.in_versions != typ.in_versions: - ret += "if ( " + c9util.ver_cond(member.in_versions) + " ) " - prefix = "\t\t" - if member.cnt: - if member.in_versions != typ.in_versions: - ret += "{\n" - ret += prefix - if member.typ.static_size == 1: # SPECIAL (string, zerocopy) - ret += f"out->{member.membname} = (char *)&ctx->net_bytes[ctx->net_offset];\n" - ret += f"{prefix}ctx->net_offset += out->{member.cnt.membname};\n" - else: - ret += f"out->{member.membname} = ctx->extra;\n" - ret += f"{prefix}ctx->extra += sizeof(out->{member.membname}[0]) * out->{member.cnt.membname};\n" - ret += f"{prefix}for (typeof(out->{member.cnt.membname}) i = 0; i < out->{member.cnt.membname}; i++)\n" - ret += f"{prefix}\tunmarshal_{member.typ.typname}(ctx, &out->{member.membname}[i]);\n" - if member.in_versions != typ.in_versions: - ret += "\t}\n" - else: - ret += f"unmarshal_{member.typ.typname}(ctx, &out->{member.membname});\n" - ret += cutil.ifdef_pop(1) - ret += "}\n" - ret += cutil.ifdef_pop(0) - - # marshal_* ################################################################ - ret += """ -/* marshal_* ******************************************************************/ - -""" - ret += cutil.macro( - "#define MARSHAL_BYTES_ZEROCOPY(ctx, data, len)\n" - "\tif (ctx->net_iov[ctx->net_iov_cnt-1].iov_len)\n" - "\t\tctx->net_iov_cnt++;\n" - "\tctx->net_iov[ctx->net_iov_cnt-1].iov_base = data;\n" - "\tctx->net_iov[ctx->net_iov_cnt-1].iov_len = len;\n" - "\tctx->net_iov_cnt++;\n" - ) - ret += cutil.macro( - "#define MARSHAL_BYTES(ctx, data, len)\n" - "\tif (!ctx->net_iov[ctx->net_iov_cnt-1].iov_base)\n" - "\t\tctx->net_iov[ctx->net_iov_cnt-1].iov_base = &ctx->net_copied[ctx->net_copied_size];\n" - "\tmemcpy(&ctx->net_copied[ctx->net_copied_size], data, len);\n" - "\tctx->net_copied_size += len;\n" - "\tctx->net_iov[ctx->net_iov_cnt-1].iov_len += len;\n" - ) - ret += cutil.macro( - "#define MARSHAL_U8LE(ctx, val)\n" - "\tif (!ctx->net_iov[ctx->net_iov_cnt-1].iov_base)\n" - "\t\tctx->net_iov[ctx->net_iov_cnt-1].iov_base = &ctx->net_copied[ctx->net_copied_size];\n" - "\tctx->net_copied[ctx->net_copied_size] = val;\n" - "\tctx->net_copied_size += 1;\n" - "\tctx->net_iov[ctx->net_iov_cnt-1].iov_len += 1;\n" - ) - ret += cutil.macro( - "#define MARSHAL_U16LE(ctx, val)\n" - "\tif (!ctx->net_iov[ctx->net_iov_cnt-1].iov_base)\n" - "\t\tctx->net_iov[ctx->net_iov_cnt-1].iov_base = &ctx->net_copied[ctx->net_copied_size];\n" - "\tuint16le_encode(&ctx->net_copied[ctx->net_copied_size], val);\n" - "\tctx->net_copied_size += 2;\n" - "\tctx->net_iov[ctx->net_iov_cnt-1].iov_len += 2;\n" - ) - ret += cutil.macro( - "#define MARSHAL_U32LE(ctx, val)\n" - "\tif (!ctx->net_iov[ctx->net_iov_cnt-1].iov_base)\n" - "\t\tctx->net_iov[ctx->net_iov_cnt-1].iov_base = &ctx->net_copied[ctx->net_copied_size];\n" - "\tuint32le_encode(&ctx->net_copied[ctx->net_copied_size], val);\n" - "\tctx->net_copied_size += 4;\n" - "\tctx->net_iov[ctx->net_iov_cnt-1].iov_len += 4;\n" - ) - ret += cutil.macro( - "#define MARSHAL_U64LE(ctx, val)\n" - "\tif (!ctx->net_iov[ctx->net_iov_cnt-1].iov_base)\n" - "\t\tctx->net_iov[ctx->net_iov_cnt-1].iov_base = &ctx->net_copied[ctx->net_copied_size];\n" - "\tuint64le_encode(&ctx->net_copied[ctx->net_copied_size], val);\n" - "\tctx->net_copied_size += 8;\n" - "\tctx->net_iov[ctx->net_iov_cnt-1].iov_len += 8;\n" - ) - - class OffsetExpr: - static: int - cond: dict[frozenset[str], "OffsetExpr"] - rep: list[tuple[idlutil.Path, "OffsetExpr"]] - - def __init__(self) -> None: - self.static = 0 - self.rep = [] - self.cond = {} - - def add(self, other: "OffsetExpr") -> None: - self.static += other.static - self.rep += other.rep - for k, v in other.cond.items(): - if k in self.cond: - self.cond[k].add(v) - else: - self.cond[k] = v - - def gen_c( - self, - dsttyp: str, - dstvar: str, - root: str, - indent_depth: int, - loop_depth: int, - ) -> str: - oneline: list[str] = [] - multiline = "" - if self.static: - oneline.append(str(self.static)) - for cnt, sub in self.rep: - if not sub.cond and not sub.rep: - if sub.static == 1: - oneline.append(cnt.c_str(root)) - else: - oneline.append(f"({cnt.c_str(root)})*{sub.static}") - continue - loopvar = chr(ord("i") + loop_depth) - multiline += f"{'\t'*indent_depth}for ({c9util.typename(cnt.elems[-1].typ)} {loopvar} = 0; {loopvar} < {cnt.c_str(root)}; {loopvar}++) {{\n" - multiline += sub.gen_c( - "", dstvar, root, indent_depth + 1, loop_depth + 1 - ) - multiline += f"{'\t'*indent_depth}}}\n" - for vers, sub in self.cond.items(): - multiline += cutil.ifdef_push(indent_depth + 1, c9util.ver_ifdef(vers)) - multiline += f"{'\t'*indent_depth}if {c9util.ver_cond(vers)} {{\n" - multiline += sub.gen_c("", dstvar, root, indent_depth + 1, loop_depth) - multiline += f"{'\t'*indent_depth}}}\n" - multiline += cutil.ifdef_pop(indent_depth) - if dsttyp: - if not oneline: - oneline.append("0") - ret = f"{'\t'*indent_depth}{dsttyp} {dstvar} = {' + '.join(oneline)};\n" - elif oneline: - ret = f"{'\t'*indent_depth}{dstvar} += {' + '.join(oneline)};\n" - ret += multiline - return ret - - type OffsetExprRecursion = typing.Callable[[idlutil.Path], idlutil.WalkCmd] - - def get_offset_expr(typ: idl.UserType, recurse: OffsetExprRecursion) -> OffsetExpr: - if not isinstance(typ, idl.Struct): - assert typ.static_size - ret = OffsetExpr() - ret.static = typ.static_size - return ret - - stack: list[tuple[idlutil.Path, OffsetExpr, typing.Callable[[], None]]] - - def pop_root() -> None: - assert False - - def pop_cond() -> None: - nonlocal stack - key = frozenset(stack[-1][0].elems[-1].in_versions) - if key in stack[-2][1].cond: - stack[-2][1].cond[key].add(stack[-1][1]) - else: - stack[-2][1].cond[key] = stack[-1][1] - stack = stack[:-1] - - def pop_rep() -> None: - nonlocal stack - member_path = stack[-1][0] - member = member_path.elems[-1] - assert member.cnt - cnt_path = member_path.parent().add(member.cnt) - stack[-2][1].rep.append((cnt_path, stack[-1][1])) - stack = stack[:-1] - - def handle( - path: idlutil.Path, - ) -> tuple[idlutil.WalkCmd, typing.Callable[[], None] | None]: - nonlocal recurse - - ret = recurse(path) - if ret != idlutil.WalkCmd.KEEP_GOING: - return ret, None - - nonlocal stack - stack_len = len(stack) - - def pop() -> None: - nonlocal stack - nonlocal stack_len - while len(stack) > stack_len: - stack[-1][2]() - - if path.elems: - child = path.elems[-1] - parent = path.elems[-2].typ if len(path.elems) > 1 else path.root - if child.in_versions < parent.in_versions: - stack.append((path, OffsetExpr(), pop_cond)) - if child.cnt: - stack.append((path, OffsetExpr(), pop_rep)) - if not isinstance(child.typ, idl.Struct): - assert child.typ.static_size - stack[-1][1].static += child.typ.static_size - return ret, pop - - stack = [(idlutil.Path(typ), OffsetExpr(), pop_root)] - idlutil.walk(typ, handle) - return stack[0][1] - - def go_to_end(path: idlutil.Path) -> idlutil.WalkCmd: - return idlutil.WalkCmd.KEEP_GOING - - def go_to_tok(name: str) -> typing.Callable[[idlutil.Path], idlutil.WalkCmd]: - def ret(path: idlutil.Path) -> idlutil.WalkCmd: - if len(path.elems) == 1 and path.elems[0].membname == name: - return idlutil.WalkCmd.ABORT - return idlutil.WalkCmd.KEEP_GOING - - return ret - - for typ in typs: - if not ( - isinstance(typ, idl.Message) or typ.typname == "stat" - ): # SPECIAL (include stat) - continue - assert isinstance(typ, idl.Struct) - ret += "\n" - ret += cutil.ifdef_push(1, c9util.ver_ifdef(typ.in_versions)) - ret += f"static bool marshal_{typ.typname}(struct _marshal_ctx *ctx, {c9util.typename(typ)} *val) {{\n" - - # Pass 1 - check size - max_size = max(typ.max_size(v) for v in typ.in_versions) - - if max_size > cutil.UINT32_MAX: # SPECIAL (9P2000.e) - ret += get_offset_expr(typ, go_to_end).gen_c( - "uint64_t", "needed_size", "val->", 1, 0 - ) - ret += "\tif (needed_size > (uint64_t)(ctx->ctx->max_msg_size)) {\n" - else: - ret += get_offset_expr(typ, go_to_end).gen_c( - "uint32_t", "needed_size", "val->", 1, 0 - ) - ret += "\tif (needed_size > ctx->ctx->max_msg_size) {\n" - if isinstance(typ, idl.Message): # SPECIAL (disable for stat) - ret += '\t\tlib9p_errorf(ctx->ctx, LINUX_ERANGE, "%s message too large to marshal into %s limit (limit=%"PRIu32")",\n' - ret += f'\t\t\t"{typ.typname}",\n' - ret += f'\t\t\tctx->ctx->version ? "negotiated" : "{'client' if typ.msgid % 2 == 0 else 'server'}",\n' - ret += "\t\t\tctx->ctx->max_msg_size);\n" - ret += "\t\treturn true;\n" - ret += "\t}\n" - - # Pass 2 - write data - ifdef_depth = 1 - stack: list[tuple[idlutil.Path, bool]] = [(idlutil.Path(typ), False)] - - def handle( - path: idlutil.Path, - ) -> tuple[idlutil.WalkCmd, typing.Callable[[], None]]: - nonlocal ret - nonlocal ifdef_depth - nonlocal stack - stack_len = len(stack) - - def pop() -> None: - nonlocal ret - nonlocal ifdef_depth - nonlocal stack - nonlocal stack_len - while len(stack) > stack_len: - ret += f"{'\t'*(len(stack)-1)}}}\n" - if stack[-1][1]: - ifdef_depth -= 1 - ret += cutil.ifdef_pop(ifdef_depth) - stack = stack[:-1] - - loopdepth = sum(1 for elem in path.elems if elem.cnt) - struct = path.elems[-1].typ if path.elems else path.root - if isinstance(struct, idl.Struct): - offsets: list[str] = [] - for member in struct.members: - if not member.val: - continue - for tok in member.val.tokens: - if not isinstance(tok, idl.ExprSym): - continue - if tok.symname == "end" or tok.symname.startswith("&"): - if tok.symname not in offsets: - offsets.append(tok.symname) - for name in offsets: - name_prefix = "offsetof_" + "".join( - m.membname + "_" for m in path.elems - ) - if name == "end": - if not path.elems: - nonlocal max_size - if max_size > cutil.UINT32_MAX: - ret += f"{'\t'*len(stack)}uint32_t {name_prefix}end = (uint32_t)needed_size;\n" - else: - ret += f"{'\t'*len(stack)}uint32_t {name_prefix}end = needed_size;\n" - continue - recurse: OffsetExprRecursion = go_to_end - else: - assert name.startswith("&") - name = name[1:] - recurse = go_to_tok(name) - expr = get_offset_expr(struct, recurse) - expr_prefix = path.c_str("val->", loopdepth) - if not expr_prefix.endswith(">"): - expr_prefix += "." - ret += expr.gen_c( - "uint32_t", - name_prefix + name, - expr_prefix, - len(stack), - loopdepth, - ) - if path.elems: - child = path.elems[-1] - parent = path.elems[-2].typ if len(path.elems) > 1 else path.root - if child.in_versions < parent.in_versions: - ret += cutil.ifdef_push( - ifdef_depth + 1, c9util.ver_ifdef(child.in_versions) - ) - ifdef_depth += 1 - ret += f"{'\t'*len(stack)}if ({c9util.ver_cond(child.in_versions)}) {{\n" - stack.append((path, True)) - if child.cnt: - cnt_path = path.parent().add(child.cnt) - if child.typ.static_size == 1: # SPECIAL (zerocopy) - if path.root.typname == "stat": # SPECIAL (stat) - ret += f"{'\t'*len(stack)}MARSHAL_BYTES(ctx, {path.c_str('val->')[:-3]}, {cnt_path.c_str('val->')});\n" - else: - ret += f"{'\t'*len(stack)}MARSHAL_BYTES_ZEROCOPY(ctx, {path.c_str('val->')[:-3]}, {cnt_path.c_str('val->')});\n" - return idlutil.WalkCmd.KEEP_GOING, pop - loopvar = chr(ord("i") + loopdepth - 1) - ret += f"{'\t'*len(stack)}for ({c9util.typename(child.cnt.typ)} {loopvar} = 0; {loopvar} < {cnt_path.c_str('val->')}; {loopvar}++) {{\n" - stack.append((path, False)) - if not isinstance(child.typ, idl.Struct): - if child.val: - - def lookup_sym(sym: str) -> str: - nonlocal path - if sym.startswith("&"): - sym = sym[1:] - return ( - "offsetof_" - + "".join(m.membname + "_" for m in path.elems[:-1]) - + sym - ) - - val = c9util.idl_expr(child.val, lookup_sym) - else: - val = path.c_str("val->") - if isinstance(child.typ, idl.Bitfield): - val += f" & {child.typ.typname}_masks[ctx->ctx->version]" - ret += f"{'\t'*len(stack)}MARSHAL_U{child.typ.static_size*8}LE(ctx, {val});\n" - return idlutil.WalkCmd.KEEP_GOING, pop - - idlutil.walk(typ, handle) - del handle - del stack - del max_size - - ret += "\treturn false;\n" - ret += "}\n" - ret += cutil.ifdef_pop(0) - - # function tables ########################################################## - ret += """ -/* function tables ************************************************************/ -""" - - ret += "\n" - ret += f"const uint32_t {c9util.ident('_table_msg_min_size')}[{c9util.ver_enum('NUM')}] = {{\n" - rerror = next(typ for typ in typs if typ.typname == "Rerror") - ret += f"\t[{c9util.ver_enum('unknown')}] = {rerror.min_size('9P2000')},\n" # SPECIAL (initialization) - for ver in sorted(versions): - ret += cutil.ifdef_push(1, c9util.ver_ifdef({ver})) - ret += f"\t[{c9util.ver_enum(ver)}] = {rerror.min_size(ver)},\n" - ret += cutil.ifdef_pop(0) - ret += "};\n" - - ret += "\n" - ret += cutil.macro( - f"#define _MSG_RECV(typ) [{c9util.Ident('TYP_')}##typ/2] = {{\n" - f"\t\t.basesize = sizeof(struct {c9util.ident('msg_')}##typ),\n" - f"\t\t.validate = validate_##typ,\n" - f"\t\t.unmarshal = (_unmarshal_fn_t)unmarshal_##typ,\n" - f"\t}}\n" - ) - ret += cutil.macro( - f"#define _MSG_SEND(typ) [{c9util.Ident('TYP_')}##typ/2] = {{\n" - f"\t\t.marshal = (_marshal_fn_t)marshal_##typ,\n" - f"\t}}\n" - ) - ret += "\n" - ret += msg_table( - "Tmsg", "recv", f"struct {c9util.ident('_recv_tentry')}", (0, 0x100, 2) - ) - ret += "\n" - ret += msg_table( - "Rmsg", "recv", f"struct {c9util.ident('_recv_tentry')}", (1, 0x100, 2) - ) - ret += "\n" - ret += msg_table( - "Tmsg", "send", f"struct {c9util.ident('_send_tentry')}", (0, 0x100, 2) - ) - ret += "\n" - ret += msg_table( - "Rmsg", "send", f"struct {c9util.ident('_send_tentry')}", (1, 0x100, 2) - ) - - ret += f""" -LM_FLATTEN bool {c9util.ident('_stat_validate')}(struct _validate_ctx *ctx) {{ -\treturn validate_stat(ctx); -}} -LM_FLATTEN void {c9util.ident('_stat_unmarshal')}(struct _unmarshal_ctx *ctx, struct {c9util.ident('stat')} *out) {{ -\tunmarshal_stat(ctx, out); -}} -LM_FLATTEN bool {c9util.ident('_stat_marshal')}(struct _marshal_ctx *ctx, struct {c9util.ident('stat')} *val) {{ -\treturn marshal_stat(ctx, val); -}} -""" - - ############################################################################ - return ret - - -# Main ######################################################################### - - def main() -> None: if typing.TYPE_CHECKING: @@ -799,4 +54,4 @@ def main() -> None: ) as fh: fh.write(h.gen_h(versions, typs)) with open(os.path.join(outdir, "9p.generated.c"), "w", encoding="utf-8") as fh: - fh.write(gen_c(versions, typs)) + fh.write(c.gen_c(versions, typs)) -- cgit v1.2.3-2-g168b