# lib9p/protogen/__init__.py - Generate C marshalers/unmarshalers for
# .9p files defining 9P protocol variants
#
# Copyright (C) 2024-2025 Luke T. Shumaker
# SPDX-License-Identifier: AGPL-3.0-or-later

import enum
import graphlib
import os.path
import sys
import typing

import idl

from . import c9util, cutil

# This strives to be "general-purpose" in that it just acts on the
# *.9p inputs; but (unfortunately?) there are a few special-cases in
# this script, marked with "SPECIAL".

# pylint: disable=unused-variable
__all__ = ["main"]

# topo_sorted() ################################################################


def topo_sorted(typs: list[idl.UserType]) -> typing.Iterable[idl.UserType]:
    """Return `typs` in dependency order.

    Every type is registered with a `graphlib.TopologicalSorter`.
    Numbers and bitfields are registered without dependencies; a
    struct is registered with the type of each of its members as a
    predecessor, except for members whose type is an `idl.Primitive`.
    A type matching none of the three cases below is not registered.

    The result is whatever `TopologicalSorter.static_order()` yields,
    i.e. predecessors come before the structs that contain them (and
    graphlib raises `CycleError` if the dependencies are cyclic).
    """
    ts: graphlib.TopologicalSorter[idl.UserType] = graphlib.TopologicalSorter()
    for typ in typs:
        match typ:
            case idl.Number():
                ts.add(typ)
            case idl.Bitfield():
                ts.add(typ)
            case idl.Struct():  # and idl.Message():
                # Primitive member types are not user types, so they
                # are left out of the graph.
                deps = [
                    member.typ
                    for member in typ.members
                    if not isinstance(member.typ, idl.Primitive)
                ]
                ts.add(typ, *deps)
    return ts.static_order()


# walk() #######################################################################


class Path:
    """A root type plus the chain of struct members followed from it.

    `add()` and `parent()` return new `Path` objects built from a new
    list; they do not modify `self.elems`.
    """

    root: idl.Type
    elems: list[idl.StructMember]

    def __init__(
        self, root: idl.Type, elems: list[idl.StructMember] | None = None
    ) -> None:
        """Create a path at `root`; `elems=None` means an empty chain."""
        self.root = root
        # A fresh list per instance when none is supplied.
        self.elems = elems if elems is not None else []

    def add(self, elem: idl.StructMember) -> "Path":
        """Return a new path that is this one extended by `elem`."""
        return Path(self.root, self.elems + [elem])

    def parent(self) -> "Path":
        """Return a new path with the last element removed."""
        return Path(self.root, self.elems[:-1])

    def c_str(self, base: str, loopdepth: int = 0) -> str:
        """Render the path as a C member-access expression.

        The result is `base` followed by the member names joined with
        ".".  Each member that has a count (`elem.cnt` is truthy) gets
        an index suffix `[i]`, `[j]`, ...: the letter is
        `chr(ord('i') + loopdepth)`, and `loopdepth` advances by one
        for every counted member encountered.
        """
        ret = base
        for i, elem in enumerate(self.elems):
            if i > 0:
                ret += "."
            ret += elem.membname
            if elem.cnt:
                ret += f"[{chr(ord('i')+loopdepth)}]"
                loopdepth += 1
        return ret

    def __str__(self) -> str:
        """Render via `c_str()` with `<root typname>->` as the base."""
        return self.c_str(self.root.typname + "->")


class WalkCmd(enum.Enum):
    """Instruction returned by a walk handler (interpreted by `_walk()`)."""

    KEEP_GOING = 1  # descend into the element's members (if it is a struct)
    DONT_RECURSE = 2  # skip the element's members, but continue the walk
    ABORT = 3  # stop visiting further members


# A handler receives the path of the element being visited and returns
# the command for that element plus an optional zero-argument callback,
# which `_walk()` invokes once it is done with the element's members.
type WalkHandler = typing.Callable[
    [Path], tuple[WalkCmd, typing.Callable[[], None] | None]
]


def _walk(path: Path, handle: WalkHandler) -> WalkCmd:
    """Visit the element at `path` and, for structs, its members.

    `handle(path)` is called first.  If the element's type is an
    `idl.Struct`, its command is then interpreted as follows:

    - KEEP_GOING: each member is visited in order (recursively); the
      loop stops at the first member whose visit returns ABORT, and
      ABORT becomes this call's result.
    - DONT_RECURSE: members are skipped and the result becomes
      KEEP_GOING, so that the caller carries on with siblings.
    - ABORT: returned as-is.

    For a non-struct element the handler's command is returned
    unchanged.  The handler's callback, if any, is invoked just before
    returning.
    """
    # The element is the last member on the path, or the root itself
    # when the path is empty.
    typ = path.elems[-1].typ if path.elems else path.root

    # `atexit` is just a local name here (the handler's callback).
    ret, atexit = handle(path)

    if isinstance(typ, idl.Struct):
        match ret:
            case WalkCmd.KEEP_GOING:
                for member in typ.members:
                    if _walk(path.add(member), handle) == WalkCmd.ABORT:
                        ret = WalkCmd.ABORT
                        break
            case WalkCmd.DONT_RECURSE:
                ret = WalkCmd.KEEP_GOING
            case WalkCmd.ABORT:
                ret = WalkCmd.ABORT
            case _:
                assert False, f"invalid cmd: {ret}"

    if atexit:
        atexit()
    return ret


def walk(typ: idl.Type, handle: WalkHandler) -> None:
    """Walk `typ` from an empty path, calling `handle` per element.

    The final `WalkCmd` of the traversal is discarded.
    """
    _walk(Path(typ), handle)


# get_buffer_size() ############################################################


class BufferSize(typing.NamedTuple):
    """Immutable size summary of one type in one protocol version.

    Built by `get_buffer_size()` from the accumulator filled in by
    `_get_buffer_size()`.
    """

    min_size: int  # really just here to sanity-check against typ.min_size(version)
    exp_size: int  # "expected" or max-reasonable size
    max_size: int  # really just here to sanity-check against typ.max_size(version)
    # Upper bound on the data that is copied (as opposed to referenced
    # zero-copy), and on the number of iov segments.
    max_copy: int
    # Text to append after `max_copy` / `max_iov` when emitting C; it is
    # either empty or a " + (CONFIG_9P_MAX_9P2000_e_WELEM * N)" term.
    max_copy_extra: str
    max_iov: int
    max_iov_extra: str


class TmpBufferSize:
    """Mutable accumulator used while `_get_buffer_size()` walks a type.

    Holds the same seven fields as `BufferSize`, plus two flags that
    are only needed during accumulation.
    """

    min_size: int
    exp_size: int
    max_size: int
    max_copy: int
    max_copy_extra: str
    max_iov: int
    max_iov_extra: str

    # Whether the layout accumulated so far begins / currently ends
    # with a copied segment (rather than a zero-copy one).
    tmp_starts_with_copy: bool
    tmp_ends_with_copy: bool

    def __init__(self) -> None:
        """Start with all sizes at 0, empty suffixes and both flags False."""
        self.min_size = 0
        self.exp_size = 0
        self.max_size = 0
        self.max_copy = 0
        self.max_copy_extra = ""
        self.max_iov = 0
        self.max_iov_extra = ""
        self.tmp_starts_with_copy = False
        self.tmp_ends_with_copy = False


def _get_buffer_size(typ: idl.Type, version: str) -> TmpBufferSize:
    """Accumulate the size summary of `typ` as it exists in `version`.

    `typ` must be an `idl.Primitive` or list `version` in its
    `in_versions` (asserted).

    A non-struct type is a single copied segment of `static_size`.
    A struct is traversed with `walk()`; the nested `handle()` adds
    each member's contribution to `ret`.  Before returning, the
    accumulated `min_size` and `max_size` are asserted to equal
    `typ.min_size(version)` and `typ.max_size(version)`.
    """
    assert isinstance(typ, idl.Primitive) or (version in typ.in_versions)

    ret = TmpBufferSize()

    if not isinstance(typ, idl.Struct):
        # Fixed-size value: every size is the static size, and it is
        # one purely-copied segment.
        assert typ.static_size
        ret.min_size = typ.static_size
        ret.exp_size = typ.static_size
        ret.max_size = typ.static_size
        ret.max_copy = typ.static_size
        ret.max_iov = 1
        ret.tmp_starts_with_copy = True
        ret.tmp_ends_with_copy = True
        return ret

    def handle(path: Path) -> tuple[WalkCmd, None]:
        # `ret` is only mutated below, never rebound.
        nonlocal ret
        # The root itself (empty path) contributes nothing; only its
        # members do.
        if path.elems:
            child = path.elems[-1]
            # Members absent from this version are skipped, subtree
            # included.
            if version not in child.in_versions:
                return WalkCmd.DONT_RECURSE, None
            if child.cnt:
                # Counted (repeated) member.  Both branches below return
                # DONT_RECURSE and add nothing to `min_size`.
                if child.typ.static_size == 1:  # SPECIAL (zerocopy)
                    # One-byte elements: the data gets an iov segment
                    # of its own and adds nothing to `max_copy`.
                    ret.max_iov += 1
                    # HEURISTIC: 27 for strings (max-strlen from 9P1), 8KiB for other data
                    ret.exp_size += 27 if child.membname == "utf8" else 8192
                    ret.max_size += child.max_cnt
                    ret.tmp_ends_with_copy = False
                    return WalkCmd.DONT_RECURSE, None
                # Larger elements: size one element recursively, then
                # scale by the count.
                sub = _get_buffer_size(child.typ, version)
                ret.exp_size += sub.exp_size * 16  # HEURISTIC: MAXWELEM
                ret.max_size += sub.max_size * child.max_cnt
                if child.membname == "wname" and path.root.typname in (
                    "Tsread",
                    "Tswrite",
                ):  # SPECIAL (9P2000.e)
                    # The per-element bounds are kept symbolic: they are
                    # recorded as C expression suffixes instead of being
                    # multiplied into `max_copy` / `max_iov`.
                    assert ret.tmp_ends_with_copy
                    assert sub.tmp_starts_with_copy
                    assert not sub.tmp_ends_with_copy
                    ret.max_copy_extra = (
                        f" + (CONFIG_9P_MAX_9P2000_e_WELEM * {sub.max_copy})"
                    )
                    ret.max_iov_extra = (
                        f" + (CONFIG_9P_MAX_9P2000_e_WELEM * {sub.max_iov})"
                    )
                    # NOTE(review): presumably this accounts for the first
                    # element's leading copy merging into the preceding
                    # copied segment (see the asserts above) -- confirm.
                    ret.max_iov -= 1
                else:
                    ret.max_copy += sub.max_copy * child.max_cnt
                    if sub.max_iov == 1 and sub.tmp_starts_with_copy:  # is purely copy
                        ret.max_iov += 1
                    else:  # contains zero-copy segments
                        ret.max_iov += sub.max_iov * child.max_cnt
                    # NOTE(review): the two merge adjustments below are
                    # read as applying after either branch above; confirm
                    # the intended nesting.
                    if ret.tmp_ends_with_copy and sub.tmp_starts_with_copy:
                        # we can merge this one
                        ret.max_iov -= 1
                    if (
                        sub.tmp_ends_with_copy
                        and sub.tmp_starts_with_copy
                        and sub.max_iov > 1
                    ):
                        # we can merge these
                        ret.max_iov -= child.max_cnt - 1
                ret.tmp_ends_with_copy = sub.tmp_ends_with_copy
                return WalkCmd.DONT_RECURSE, None
            if not isinstance(child.typ, idl.Struct):
                # Single fixed-size member: always copied.
                assert child.typ.static_size
                if not ret.tmp_ends_with_copy:
                    # A new copied segment begins here; if nothing has
                    # been accumulated yet it is also the first segment.
                    if ret.max_size == 0:
                        ret.tmp_starts_with_copy = True
                    ret.max_iov += 1
                    ret.tmp_ends_with_copy = True
                ret.min_size += child.typ.static_size
                ret.exp_size += child.typ.static_size
                ret.max_size += child.typ.static_size
                ret.max_copy += child.typ.static_size
        # Uncounted struct members fall through to here, so `walk()`
        # descends into them and their members are handled one by one.
        return WalkCmd.KEEP_GOING, None

    walk(typ, handle)
    # Cross-check the accumulation against the IDL's own computation.
    assert ret.min_size == typ.min_size(version)
    assert ret.max_size == typ.max_size(version)
    return ret


def get_buffer_size(typ: idl.Type, version: str) -> BufferSize:
    """Return the size summary of `typ` in `version` as a `BufferSize`.

    Copies the seven public fields out of the accumulator returned by
    `_get_buffer_size()`; the `tmp_*` flags are dropped.
    """
    tmp = _get_buffer_size(typ, version)
    return BufferSize(
        min_size=tmp.min_size,
        exp_size=tmp.exp_size,
        max_size=tmp.max_size,
        max_copy=tmp.max_copy,
        max_copy_extra=tmp.max_copy_extra,
        max_iov=tmp.max_iov,
        max_iov_extra=tmp.max_iov_extra,
    )


# Generate .h ##################################################################


def gen_h(versions: set[str], typs: list[idl.UserType]) -> str:
    cutil.ifdef_init()

    # NOTE(review): the `#include` directives in this template carry no
    # header name as seen here -- confirm the literal against the
    # generated output.
    ret = f"""/* Generated by `{' '.join(sys.argv)}`. DO NOT EDIT! */ #ifndef _LIB9P_9P_H_ \t#error Do not include directly; include instead #endif #include /* for uint{{n}}_t types */ #include /* for struct iovec */ """

    id2typ: dict[int, idl.Message] = {}
    for msg in [msg for msg in typs if isinstance(msg, idl.Message)]:
        id2typ[msg.msgid] = msg

    ret += """ /* config *********************************************************************/ #include "config.h" """

    for ver in sorted(versions):
        ret += "\n"
        ret += f"#ifndef {c9util.ver_ifdef({ver})}\n"
        ret += f"\t#error config.h must define {c9util.ver_ifdef({ver})}\n"
        if ver == "9P2000.e":  # SPECIAL (9P2000.e)
            ret += "#else\n"
            ret += f"\t#if {c9util.ver_ifdef({ver})}\n"
            ret += "\t\t#ifndef CONFIG_9P_MAX_9P2000_e_WELEM\n"
            ret += f"\t\t\t#error if {c9util.ver_ifdef({ver})} then config.h must define CONFIG_9P_MAX_9P2000_e_WELEM\n"
            ret += "\t\t#endif\n"
            ret += "\t\tstatic_assert(CONFIG_9P_MAX_9P2000_e_WELEM > 0);\n"
            ret += "\t#endif\n"
        ret += "#endif\n"

    ret += f""" /* enum version ***************************************************************/ enum {c9util.ident('version')} {{ """
    fullversions = ["unknown = 0", *sorted(versions)]
    verwidth = max(len(v) for v in fullversions)
    for ver in fullversions:
        if ver in versions:
            ret += cutil.ifdef_push(1, c9util.ver_ifdef({ver}))
        ret +=
f"\t{c9util.ver_enum(ver)}," ret += (" " * (verwidth - len(ver))) + ' /* "' + ver.split()[0] + '" */\n' ret += cutil.ifdef_pop(0) ret += f"\t{c9util.ver_enum('NUM')},\n" ret += "};\n" ret += """ /* enum msg_type **************************************************************/ """ ret += f"enum {c9util.ident('msg_type')} {{ /* uint8_t */\n" namewidth = max(len(msg.typname) for msg in typs if isinstance(msg, idl.Message)) for n in range(0x100): if n not in id2typ: continue msg = id2typ[n] ret += cutil.ifdef_push(1, c9util.ver_ifdef(msg.in_versions)) ret += f"\t{c9util.Ident(f'TYP_{msg.typname:<{namewidth}}')} = {msg.msgid},\n" ret += cutil.ifdef_pop(0) ret += "};\n" ret += """ /* payload types **************************************************************/ """ def per_version_comment( typ: idl.UserType, fn: typing.Callable[[idl.UserType, str], str] ) -> str: lines: dict[str, str] = {} for version in sorted(typ.in_versions): lines[version] = fn(typ, version) if len(set(lines.values())) == 1: for _, line in lines.items(): return f"/* {line} */\n" assert False else: ret = "" v_width = max(len(c9util.ver_enum(v)) for v in typ.in_versions) for version, line in lines.items(): ret += f"/* {c9util.ver_enum(version):<{v_width}}: {line} */\n" return ret for typ in topo_sorted(typs): ret += "\n" ret += cutil.ifdef_push(1, c9util.ver_ifdef(typ.in_versions)) def sum_size(typ: idl.UserType, version: str) -> str: sz = get_buffer_size(typ, version) assert ( sz.min_size <= sz.exp_size and sz.exp_size <= sz.max_size and sz.max_size < cutil.UINT64_MAX ) ret = "" if sz.min_size == sz.max_size: ret += f"size = {sz.min_size:,}" else: ret += f"min_size = {sz.min_size:,} ; exp_size = {sz.exp_size:,} ; max_size = {sz.max_size:,}" if sz.max_size > cutil.UINT32_MAX: ret += " (warning: >UINT32_MAX)" ret += f" ; max_iov = {sz.max_iov:,}{sz.max_iov_extra} ; max_copy = {sz.max_copy:,}{sz.max_copy_extra}" return ret ret += per_version_comment(typ, sum_size) match typ: case idl.Number(): ret += 
f"typedef {c9util.typename(typ.prim)} {c9util.typename(typ)};\n" prefix = f"{c9util.IDENT(typ.typname)}_" namewidth = max(len(name) for name in typ.vals) for name, val in typ.vals.items(): ret += f"#define {prefix}{name:<{namewidth}} (({c9util.typename(typ)})UINT{typ.static_size*8}_C({val}))\n" case idl.Bitfield(): ret += f"typedef {c9util.typename(typ.prim)} {c9util.typename(typ)};\n" def bitname(val: idl.Bit | idl.BitAlias) -> str: s = val.bitname match val: case idl.Bit(cat=idl.BitCat.RESERVED): s = "_RESERVED_" + s case idl.Bit(cat=idl.BitCat.SUBFIELD): assert isinstance(typ, idl.Bitfield) n = sum( 1 for b in typ.bits[: val.num] if b.cat == idl.BitCat.SUBFIELD and b.bitname == val.bitname ) s = f"_{s}_{n}" case idl.Bit(cat=idl.BitCat.UNUSED): return "" return c9util.Ident(c9util.add_prefix(typ.typname.upper() + "_", s)) namewidth = max( len(bitname(val)) for val in [*typ.bits, *typ.names.values()] ) ret += "\n" for bit in reversed(typ.bits): vers = bit.in_versions if bit.cat == idl.BitCat.UNUSED: vers = typ.in_versions ret += cutil.ifdef_push(2, c9util.ver_ifdef(vers)) # It is important all of the `beg` strings have # the same length. 
end = "" match bit.cat: case ( idl.BitCat.USED | idl.BitCat.RESERVED | idl.BitCat.SUBFIELD ): if cutil.ifdef_leaf_is_noop(): beg = "#define " else: beg = "# define" case idl.BitCat.UNUSED: beg = "/* unused" end = " */" c_name = bitname(bit) c_val = f"1<<{bit.num}" ret += f"{beg} {c_name:<{namewidth}} (({c9util.typename(typ)})({c_val})){end}\n" if aliases := [ alias for alias in typ.names.values() if isinstance(alias, idl.BitAlias) ]: ret += "\n" for alias in aliases: ret += cutil.ifdef_push(2, c9util.ver_ifdef(alias.in_versions)) end = "" if cutil.ifdef_leaf_is_noop(): beg = "#define " else: beg = "# define" c_name = bitname(alias) c_val = alias.val ret += f"{beg} {c_name:<{namewidth}} (({c9util.typename(typ)})({c_val})){end}\n" ret += cutil.ifdef_pop(1) del bitname case idl.Struct(): # and idl.Message(): ret += c9util.typename(typ) + " {" if not typ.members: ret += "};\n" continue ret += "\n" typewidth = max(len(c9util.typename(m.typ, m)) for m in typ.members) for member in typ.members: if member.val: continue ret += cutil.ifdef_push(2, c9util.ver_ifdef(member.in_versions)) ret += f"\t{c9util.typename(member.typ, member):<{typewidth}} {'*' if member.cnt else ' '}{member.membname};\n" ret += cutil.ifdef_pop(1) ret += "};\n" del typ ret += cutil.ifdef_pop(0) ret += """ /* containers *****************************************************************/ """ ret += "\n" ret += f"#define {c9util.IDENT('_MAX')}(a, b) ((a) > (b)) ? 
(a) : (b)\n" tmsg_max_iov: dict[str, int] = {} tmsg_max_copy: dict[str, int] = {} rmsg_max_iov: dict[str, int] = {} rmsg_max_copy: dict[str, int] = {} for typ in typs: if not isinstance(typ, idl.Message): continue if typ.typname in ("Tsread", "Tswrite"): # SPECIAL (9P2000.e) continue max_iov = tmsg_max_iov if typ.msgid % 2 == 0 else rmsg_max_iov max_copy = tmsg_max_copy if typ.msgid % 2 == 0 else rmsg_max_copy for version in typ.in_versions: if version not in max_iov: max_iov[version] = 0 max_copy[version] = 0 sz = get_buffer_size(typ, version) if sz.max_iov > max_iov[version]: max_iov[version] = sz.max_iov if sz.max_copy > max_copy[version]: max_copy[version] = sz.max_copy for name, table in [ ("tmsg_max_iov", tmsg_max_iov), ("tmsg_max_copy", tmsg_max_copy), ("rmsg_max_iov", rmsg_max_iov), ("rmsg_max_copy", rmsg_max_copy), ]: inv: dict[int, set[str]] = {} for version, maxval in table.items(): if maxval not in inv: inv[maxval] = set() inv[maxval].add(version) ret += "\n" directive = "if" seen_e = False # SPECIAL (9P2000.e) for maxval in sorted(inv, reverse=True): ret += f"#{directive} {c9util.ver_ifdef(inv[maxval])}\n" indent = 1 if name.startswith("tmsg") and not seen_e: # SPECIAL (9P2000.e) typ = next(typ for typ in typs if typ.typname == "Tswrite") sz = get_buffer_size(typ, "9P2000.e") match name: case "tmsg_max_iov": maxexpr = f"{sz.max_iov}{sz.max_iov_extra}" case "tmsg_max_copy": maxexpr = f"{sz.max_copy}{sz.max_copy_extra}" case _: assert False ret += f"\t#if {c9util.ver_ifdef({"9P2000.e"})}\n" ret += f"\t\t#define {c9util.IDENT(name)} {c9util.IDENT('_MAX')}({maxval}, {maxexpr})\n" ret += "\t#else\n" indent += 1 ret += f"{'\t'*indent}#define {c9util.IDENT(name)} {maxval}\n" if name.startswith("tmsg") and not seen_e: # SPECIAL (9P2000.e) ret += "\t#endif\n" if "9P2000.e" in inv[maxval]: seen_e = True directive = "elif" ret += "#endif\n" ret += "\n" ret += f"struct {c9util.ident('Tmsg_send_buf')} {{\n" ret += "\tsize_t iov_cnt;\n" ret += f"\tstruct iovec 
iov[{c9util.IDENT('TMSG_MAX_IOV')}];\n" ret += f"\tuint8_t copied[{c9util.IDENT('TMSG_MAX_COPY')}];\n" ret += "};\n" ret += "\n" ret += f"struct {c9util.ident('Rmsg_send_buf')} {{\n" ret += "\tsize_t iov_cnt;\n" ret += f"\tstruct iovec iov[{c9util.IDENT('RMSG_MAX_IOV')}];\n" ret += f"\tuint8_t copied[{c9util.IDENT('RMSG_MAX_COPY')}];\n" ret += "};\n" return ret # Generate .c ################################################################## def gen_c(versions: set[str], typs: list[idl.UserType]) -> str: cutil.ifdef_init() ret = f"""/* Generated by `{' '.join(sys.argv)}`. DO NOT EDIT! */ #include #include /* for size_t */ #include /* for PRI* macros */ #include /* for memset() */ #include #include #include "internal.h" """ # utilities ################################################################ ret += """ /* utilities ******************************************************************/ """ def used(arg: str) -> str: return arg def unused(arg: str) -> str: return f"LM_UNUSED({arg})" id2typ: dict[int, idl.Message] = {} for msg in [msg for msg in typs if isinstance(msg, idl.Message)]: id2typ[msg.msgid] = msg def msg_table(grp: str, meth: str, tentry: str, rng: tuple[int, int, int]) -> str: ret = f"const {tentry} {c9util.ident(f'_table_{grp}_{meth}')}[{c9util.ver_enum('NUM')}][{hex(len(range(*rng)))}] = {{\n" for ver in ["unknown", *sorted(versions)]: if ver != "unknown": ret += cutil.ifdef_push(1, c9util.ver_ifdef({ver})) ret += f"\t[{c9util.ver_enum(ver)}] = {{\n" for n in range(*rng): xmsg: idl.Message | None = id2typ.get(n, None) if xmsg: if ver == "unknown": # SPECIAL (initialization) if xmsg.typname not in ["Tversion", "Rversion", "Rerror"]: xmsg = None else: if ver not in xmsg.in_versions: xmsg = None if xmsg: ret += f"\t\t_MSG_{meth.upper()}({xmsg.typname}),\n" ret += "\t},\n" ret += cutil.ifdef_pop(0) ret += "};\n" return ret for v in sorted(versions): ret += f"#if CONFIG_9P_ENABLE_{v.replace('.', '_')}\n" ret += ( f"\t#define _is_ver_{v.replace('.', 
'_')}(v) (v == {c9util.ver_enum(v)})\n" ) ret += "#else\n" ret += f"\t#define _is_ver_{v.replace('.', '_')}(v) false\n" ret += "#endif\n" ret += "\n" ret += "/**\n" ret += f" * is_ver(ctx, ver) is essentially `(ctx->ctx->version == {c9util.Ident('VER_')}##ver)`,\n" ret += f" * but compiles correctly (to `false`) even if `{c9util.Ident('VER_')}##ver` isn't defined\n" ret += " * (because `!CONFIG_9P_ENABLE_##ver`). This is useful when `||`ing\n" ret += " * several version checks together.\n" ret += " */\n" ret += "#define is_ver(CTX, ver) _is_ver_##ver(CTX->ctx->version)\n" # strings ################################################################## ret += f""" /* strings ********************************************************************/ const char *const {c9util.ident('_table_ver_name')}[{c9util.ver_enum('NUM')}] = {{ """ for ver in ["unknown", *sorted(versions)]: if ver in versions: ret += cutil.ifdef_push(1, c9util.ver_ifdef({ver})) ret += f'\t[{c9util.ver_enum(ver)}] = "{ver}",\n' ret += cutil.ifdef_pop(0) ret += "};\n" ret += "\n" ret += f"#define _MSG_NAME(typ) [{c9util.Ident('TYP_')}##typ] = #typ\n" ret += msg_table("msg", "name", "char *const", (0, 0x100, 1)) # bitmasks ################################################################# ret += """ /* bitmasks *******************************************************************/ """ for typ in typs: if not isinstance(typ, idl.Bitfield): continue ret += "\n" ret += cutil.ifdef_push(1, c9util.ver_ifdef(typ.in_versions)) ret += f"static const {c9util.typename(typ)} {typ.typname}_masks[{c9util.ver_enum('NUM')}] = {{\n" verwidth = max(len(ver) for ver in versions) for ver in sorted(versions): ret += cutil.ifdef_push(2, c9util.ver_ifdef({ver})) ret += ( f"\t[{c9util.ver_enum(ver)}]{' '*(verwidth-len(ver))} = 0b" + "".join( ( "1" if bit.cat in (idl.BitCat.USED, idl.BitCat.SUBFIELD) and ver in bit.in_versions else "0" ) for bit in reversed(typ.bits) ) + ",\n" ) ret += cutil.ifdef_pop(1) ret += "};\n" ret += 
cutil.ifdef_pop(0) # validate_* ############################################################### ret += """ /* validate_* *****************************************************************/ LM_ALWAYS_INLINE static bool _validate_size_net(struct _validate_ctx *ctx, uint32_t n) { \tif (__builtin_add_overflow(ctx->net_offset, n, &ctx->net_offset)) \t\t/* If needed-net-size overflowed uint32_t, then \t\t * there's no way that actual-net-size will live up to \t\t * that. */ \t\treturn lib9p_error(ctx->ctx, LINUX_EBADMSG, "message is too short for content"); \tif (ctx->net_offset > ctx->net_size) \t\treturn lib9p_error(ctx->ctx, LINUX_EBADMSG, "message is too short for content"); \treturn false; } LM_ALWAYS_INLINE static bool _validate_size_host(struct _validate_ctx *ctx, size_t n) { \tif (__builtin_add_overflow(ctx->host_extra, n, &ctx->host_extra)) \t\t/* If needed-host-size overflowed size_t, then there's \t\t * no way that actual-net-size will live up to \t\t * that. */ \t\treturn lib9p_error(ctx->ctx, LINUX_EBADMSG, "message is too short for content"); \treturn false; } LM_ALWAYS_INLINE static bool _validate_list(struct _validate_ctx *ctx, size_t cnt, _validate_fn_t item_fn, size_t item_host_size) { \tfor (size_t i = 0; i < cnt; i++) \t\tif (_validate_size_host(ctx, item_host_size) || item_fn(ctx)) \t\t\treturn true; \treturn false; } LM_ALWAYS_INLINE static bool validate_1(struct _validate_ctx *ctx) { return _validate_size_net(ctx, 1); } LM_ALWAYS_INLINE static bool validate_2(struct _validate_ctx *ctx) { return _validate_size_net(ctx, 2); } LM_ALWAYS_INLINE static bool validate_4(struct _validate_ctx *ctx) { return _validate_size_net(ctx, 4); } LM_ALWAYS_INLINE static bool validate_8(struct _validate_ctx *ctx) { return _validate_size_net(ctx, 8); } """ def should_save_value(typ: idl.Struct, member: idl.StructMember) -> bool: return bool( member.max or member.val or any(m.cnt == member for m in typ.members) ) for typ in topo_sorted(typs): inline = "LM_FLATTEN" if 
isinstance(typ, idl.Message) else "LM_ALWAYS_INLINE" argfn = unused if (isinstance(typ, idl.Struct) and not typ.members) else used ret += "\n" ret += cutil.ifdef_push(1, c9util.ver_ifdef(typ.in_versions)) ret += f"{inline} static bool validate_{typ.typname}(struct _validate_ctx *{argfn('ctx')}) {{\n" match typ: case idl.Number(): ret += f"\treturn validate_{typ.prim.typname}(ctx);\n" case idl.Bitfield(): ret += f"\t if (validate_{typ.static_size}(ctx))\n" ret += "\t\treturn true;\n" ret += f"\t{c9util.typename(typ)} mask = {typ.typname}_masks[ctx->ctx->version];\n" if typ.static_size == 1: ret += f"\t{c9util.typename(typ)} val = ctx->net_bytes[ctx->net_offset-1];\n" else: ret += f"\t{c9util.typename(typ)} val = uint{typ.static_size*8}le_decode(&ctx->net_bytes[ctx->net_offset-{typ.static_size}]);\n" ret += "\tif (val & ~mask)\n" ret += f'\t\treturn lib9p_errorf(ctx->ctx, LINUX_EBADMSG, "unknown bits in {typ.typname} bitfield: %#0{typ.static_size}"PRIx{typ.static_size*8}, val & ~mask);\n' ret += "\treturn false;\n" case idl.Struct(): # and idl.Message() if len(typ.members) == 0: ret += "\treturn false;\n" ret += "}\n" continue # Pass 1 - declare value variables for member in typ.members: if should_save_value(typ, member): ret += cutil.ifdef_push(2, c9util.ver_ifdef(member.in_versions)) ret += f"\t{c9util.typename(member.typ)} {member.membname};\n" ret += cutil.ifdef_pop(1) # Pass 2 - declare offset variables mark_offset: set[str] = set() for member in typ.members: for tok in [*member.max.tokens, *member.val.tokens]: if isinstance(tok, idl.ExprSym) and tok.symname.startswith("&"): if tok.symname[1:] not in mark_offset: ret += f"\tuint32_t _{tok.symname[1:]}_offset;\n" mark_offset.add(tok.symname[1:]) # Pass 3 - main pass ret += "\treturn false\n" for member in typ.members: ret += cutil.ifdef_push(2, c9util.ver_ifdef(member.in_versions)) ret += "\t || " if member.in_versions != typ.in_versions: ret += "( " + c9util.ver_cond(member.in_versions) + " && " if member.cnt is 
not None: if member.typ.static_size == 1: # SPECIAL (zerocopy) ret += f"_validate_size_net(ctx, {member.cnt.membname})" else: ret += f"_validate_list(ctx, {member.cnt.membname}, validate_{member.typ.typname}, sizeof({c9util.typename(member.typ)}))" if typ.typname == "s": # SPECIAL (string) ret += '\n\t || ({ (!is_valid_utf8_without_nul(&ctx->net_bytes[ctx->net_offset-len], len)) && lib9p_error(ctx->ctx, LINUX_EBADMSG, "message contains invalid UTF-8"); })' else: if should_save_value(typ, member): ret += "(" if member.membname in mark_offset: ret += f"({{ _{member.membname}_offset = ctx->net_offset; " ret += f"validate_{member.typ.typname}(ctx)" if member.membname in mark_offset: ret += "; })" if should_save_value(typ, member): nbytes = member.static_size assert nbytes if nbytes == 1: ret += f" || ({{ {member.membname} = ctx->net_bytes[ctx->net_offset-1]; false; }}))" else: ret += f" || ({{ {member.membname} = uint{nbytes*8}le_decode(&ctx->net_bytes[ctx->net_offset-{nbytes}]); false; }}))" if member.in_versions != typ.in_versions: ret += " )" ret += "\n" # Pass 4 - validate ,max= and ,val= constraints for member in typ.members: def lookup_sym(sym: str) -> str: match sym: case "end": return "ctx->net_offset" case _: assert sym.startswith("&") return f"_{sym[1:]}_offset" if member.max: assert member.static_size nbits = member.static_size * 8 ret += cutil.ifdef_push(2, c9util.ver_ifdef(member.in_versions)) ret += f"\t || ({{ uint{nbits}_t max = {c9util.idl_expr(member.max, lookup_sym)}; (((uint{nbits}_t){member.membname}) > max) &&\n" ret += f'\t lib9p_errorf(ctx->ctx, LINUX_EBADMSG, "{member.membname} value is too large (%"PRIu{nbits}" > %"PRIu{nbits}")", {member.membname}, max); }})\n' if member.val: assert member.static_size nbits = member.static_size * 8 ret += cutil.ifdef_push(2, c9util.ver_ifdef(member.in_versions)) ret += f"\t || ({{ uint{nbits}_t exp = {c9util.idl_expr(member.val, lookup_sym)}; (((uint{nbits}_t){member.membname}) != exp) &&\n" ret += f'\t 
lib9p_errorf(ctx->ctx, LINUX_EBADMSG, "{member.membname} value is wrong (actual:%"PRIu{nbits}" != correct:%"PRIu{nbits}")", (uint{nbits}_t){member.membname}, exp); }})\n' ret += cutil.ifdef_pop(1) ret += "\t ;\n" ret += "}\n" ret += cutil.ifdef_pop(0) # unmarshal_* ############################################################## ret += """ /* unmarshal_* ****************************************************************/ LM_ALWAYS_INLINE static void unmarshal_1(struct _unmarshal_ctx *ctx, uint8_t *out) { \t*out = ctx->net_bytes[ctx->net_offset]; \tctx->net_offset += 1; } LM_ALWAYS_INLINE static void unmarshal_2(struct _unmarshal_ctx *ctx, uint16_t *out) { \t*out = uint16le_decode(&ctx->net_bytes[ctx->net_offset]); \tctx->net_offset += 2; } LM_ALWAYS_INLINE static void unmarshal_4(struct _unmarshal_ctx *ctx, uint32_t *out) { \t*out = uint32le_decode(&ctx->net_bytes[ctx->net_offset]); \tctx->net_offset += 4; } LM_ALWAYS_INLINE static void unmarshal_8(struct _unmarshal_ctx *ctx, uint64_t *out) { \t*out = uint64le_decode(&ctx->net_bytes[ctx->net_offset]); \tctx->net_offset += 8; } """ for typ in topo_sorted(typs): inline = "LM_FLATTEN" if isinstance(typ, idl.Message) else "LM_ALWAYS_INLINE" argfn = unused if (isinstance(typ, idl.Struct) and not typ.members) else used ret += "\n" ret += cutil.ifdef_push(1, c9util.ver_ifdef(typ.in_versions)) ret += f"{inline} static void unmarshal_{typ.typname}(struct _unmarshal_ctx *{argfn('ctx')}, {c9util.typename(typ)} *out) {{\n" match typ: case idl.Number(): ret += f"\tunmarshal_{typ.prim.typname}(ctx, ({c9util.typename(typ.prim)} *)out);\n" case idl.Bitfield(): ret += f"\tunmarshal_{typ.prim.typname}(ctx, ({c9util.typename(typ.prim)} *)out);\n" case idl.Struct(): ret += "\tmemset(out, 0, sizeof(*out));\n" for member in typ.members: ret += cutil.ifdef_push(2, c9util.ver_ifdef(member.in_versions)) if member.val: ret += f"\tctx->net_offset += {member.static_size};\n" continue ret += "\t" prefix = "\t" if member.in_versions != 
typ.in_versions: ret += "if ( " + c9util.ver_cond(member.in_versions) + " ) " prefix = "\t\t" if member.cnt: if member.in_versions != typ.in_versions: ret += "{\n" ret += prefix if member.typ.static_size == 1: # SPECIAL (string, zerocopy) ret += f"out->{member.membname} = (char *)&ctx->net_bytes[ctx->net_offset];\n" ret += f"{prefix}ctx->net_offset += out->{member.cnt.membname};\n" else: ret += f"out->{member.membname} = ctx->extra;\n" ret += f"{prefix}ctx->extra += sizeof(out->{member.membname}[0]) * out->{member.cnt.membname};\n" ret += f"{prefix}for (typeof(out->{member.cnt.membname}) i = 0; i < out->{member.cnt.membname}; i++)\n" ret += f"{prefix}\tunmarshal_{member.typ.typname}(ctx, &out->{member.membname}[i]);\n" if member.in_versions != typ.in_versions: ret += "\t}\n" else: ret += f"unmarshal_{member.typ.typname}(ctx, &out->{member.membname});\n" ret += cutil.ifdef_pop(1) ret += "}\n" ret += cutil.ifdef_pop(0) # marshal_* ################################################################ ret += """ /* marshal_* ******************************************************************/ """ ret += cutil.macro( "#define MARSHAL_BYTES_ZEROCOPY(ctx, data, len)\n" "\tif (ctx->net_iov[ctx->net_iov_cnt-1].iov_len)\n" "\t\tctx->net_iov_cnt++;\n" "\tctx->net_iov[ctx->net_iov_cnt-1].iov_base = data;\n" "\tctx->net_iov[ctx->net_iov_cnt-1].iov_len = len;\n" "\tctx->net_iov_cnt++;\n" ) ret += cutil.macro( "#define MARSHAL_BYTES(ctx, data, len)\n" "\tif (!ctx->net_iov[ctx->net_iov_cnt-1].iov_base)\n" "\t\tctx->net_iov[ctx->net_iov_cnt-1].iov_base = &ctx->net_copied[ctx->net_copied_size];\n" "\tmemcpy(&ctx->net_copied[ctx->net_copied_size], data, len);\n" "\tctx->net_copied_size += len;\n" "\tctx->net_iov[ctx->net_iov_cnt-1].iov_len += len;\n" ) ret += cutil.macro( "#define MARSHAL_U8LE(ctx, val)\n" "\tif (!ctx->net_iov[ctx->net_iov_cnt-1].iov_base)\n" "\t\tctx->net_iov[ctx->net_iov_cnt-1].iov_base = &ctx->net_copied[ctx->net_copied_size];\n" "\tctx->net_copied[ctx->net_copied_size] 
= val;\n" "\tctx->net_copied_size += 1;\n" "\tctx->net_iov[ctx->net_iov_cnt-1].iov_len += 1;\n" ) ret += cutil.macro( "#define MARSHAL_U16LE(ctx, val)\n" "\tif (!ctx->net_iov[ctx->net_iov_cnt-1].iov_base)\n" "\t\tctx->net_iov[ctx->net_iov_cnt-1].iov_base = &ctx->net_copied[ctx->net_copied_size];\n" "\tuint16le_encode(&ctx->net_copied[ctx->net_copied_size], val);\n" "\tctx->net_copied_size += 2;\n" "\tctx->net_iov[ctx->net_iov_cnt-1].iov_len += 2;\n" ) ret += cutil.macro( "#define MARSHAL_U32LE(ctx, val)\n" "\tif (!ctx->net_iov[ctx->net_iov_cnt-1].iov_base)\n" "\t\tctx->net_iov[ctx->net_iov_cnt-1].iov_base = &ctx->net_copied[ctx->net_copied_size];\n" "\tuint32le_encode(&ctx->net_copied[ctx->net_copied_size], val);\n" "\tctx->net_copied_size += 4;\n" "\tctx->net_iov[ctx->net_iov_cnt-1].iov_len += 4;\n" ) ret += cutil.macro( "#define MARSHAL_U64LE(ctx, val)\n" "\tif (!ctx->net_iov[ctx->net_iov_cnt-1].iov_base)\n" "\t\tctx->net_iov[ctx->net_iov_cnt-1].iov_base = &ctx->net_copied[ctx->net_copied_size];\n" "\tuint64le_encode(&ctx->net_copied[ctx->net_copied_size], val);\n" "\tctx->net_copied_size += 8;\n" "\tctx->net_iov[ctx->net_iov_cnt-1].iov_len += 8;\n" ) class OffsetExpr: static: int cond: dict[frozenset[str], "OffsetExpr"] rep: list[tuple[Path, "OffsetExpr"]] def __init__(self) -> None: self.static = 0 self.rep = [] self.cond = {} def add(self, other: "OffsetExpr") -> None: self.static += other.static self.rep += other.rep for k, v in other.cond.items(): if k in self.cond: self.cond[k].add(v) else: self.cond[k] = v def gen_c( self, dsttyp: str, dstvar: str, root: str, indent_depth: int, loop_depth: int, ) -> str: oneline: list[str] = [] multiline = "" if self.static: oneline.append(str(self.static)) for cnt, sub in self.rep: if not sub.cond and not sub.rep: if sub.static == 1: oneline.append(cnt.c_str(root)) else: oneline.append(f"({cnt.c_str(root)})*{sub.static}") continue loopvar = chr(ord("i") + loop_depth) multiline += f"{'\t'*indent_depth}for 
({c9util.typename(cnt.elems[-1].typ)} {loopvar} = 0; {loopvar} < {cnt.c_str(root)}; {loopvar}++) {{\n" multiline += sub.gen_c( "", dstvar, root, indent_depth + 1, loop_depth + 1 ) multiline += f"{'\t'*indent_depth}}}\n" for vers, sub in self.cond.items(): multiline += cutil.ifdef_push(indent_depth + 1, c9util.ver_ifdef(vers)) multiline += f"{'\t'*indent_depth}if {c9util.ver_cond(vers)} {{\n" multiline += sub.gen_c("", dstvar, root, indent_depth + 1, loop_depth) multiline += f"{'\t'*indent_depth}}}\n" multiline += cutil.ifdef_pop(indent_depth) if dsttyp: if not oneline: oneline.append("0") ret = f"{'\t'*indent_depth}{dsttyp} {dstvar} = {' + '.join(oneline)};\n" elif oneline: ret = f"{'\t'*indent_depth}{dstvar} += {' + '.join(oneline)};\n" ret += multiline return ret type OffsetExprRecursion = typing.Callable[[Path], WalkCmd] def get_offset_expr(typ: idl.UserType, recurse: OffsetExprRecursion) -> OffsetExpr: if not isinstance(typ, idl.Struct): assert typ.static_size ret = OffsetExpr() ret.static = typ.static_size return ret stack: list[tuple[Path, OffsetExpr, typing.Callable[[], None]]] def pop_root() -> None: assert False def pop_cond() -> None: nonlocal stack key = frozenset(stack[-1][0].elems[-1].in_versions) if key in stack[-2][1].cond: stack[-2][1].cond[key].add(stack[-1][1]) else: stack[-2][1].cond[key] = stack[-1][1] stack = stack[:-1] def pop_rep() -> None: nonlocal stack member_path = stack[-1][0] member = member_path.elems[-1] assert member.cnt cnt_path = member_path.parent().add(member.cnt) stack[-2][1].rep.append((cnt_path, stack[-1][1])) stack = stack[:-1] def handle(path: Path) -> tuple[WalkCmd, typing.Callable[[], None] | None]: nonlocal recurse ret = recurse(path) if ret != WalkCmd.KEEP_GOING: return ret, None nonlocal stack stack_len = len(stack) def pop() -> None: nonlocal stack nonlocal stack_len while len(stack) > stack_len: stack[-1][2]() if path.elems: child = path.elems[-1] parent = path.elems[-2].typ if len(path.elems) > 1 else path.root if 
child.in_versions < parent.in_versions: stack.append((path, OffsetExpr(), pop_cond)) if child.cnt: stack.append((path, OffsetExpr(), pop_rep)) if not isinstance(child.typ, idl.Struct): assert child.typ.static_size stack[-1][1].static += child.typ.static_size return ret, pop stack = [(Path(typ), OffsetExpr(), pop_root)] walk(typ, handle) return stack[0][1] def go_to_end(path: Path) -> WalkCmd: return WalkCmd.KEEP_GOING def go_to_tok(name: str) -> typing.Callable[[Path], WalkCmd]: def ret(path: Path) -> WalkCmd: if len(path.elems) == 1 and path.elems[0].membname == name: return WalkCmd.ABORT return WalkCmd.KEEP_GOING return ret for typ in typs: if not ( isinstance(typ, idl.Message) or typ.typname == "stat" ): # SPECIAL (include stat) continue assert isinstance(typ, idl.Struct) ret += "\n" ret += cutil.ifdef_push(1, c9util.ver_ifdef(typ.in_versions)) ret += f"static bool marshal_{typ.typname}(struct _marshal_ctx *ctx, {c9util.typename(typ)} *val) {{\n" # Pass 1 - check size max_size = max(typ.max_size(v) for v in typ.in_versions) if max_size > cutil.UINT32_MAX: # SPECIAL (9P2000.e) ret += get_offset_expr(typ, go_to_end).gen_c( "uint64_t", "needed_size", "val->", 1, 0 ) ret += "\tif (needed_size > (uint64_t)(ctx->ctx->max_msg_size)) {\n" else: ret += get_offset_expr(typ, go_to_end).gen_c( "uint32_t", "needed_size", "val->", 1, 0 ) ret += "\tif (needed_size > ctx->ctx->max_msg_size) {\n" if isinstance(typ, idl.Message): # SPECIAL (disable for stat) ret += '\t\tlib9p_errorf(ctx->ctx, LINUX_ERANGE, "%s message too large to marshal into %s limit (limit=%"PRIu32")",\n' ret += f'\t\t\t"{typ.typname}",\n' ret += f'\t\t\tctx->ctx->version ? 
"negotiated" : "{'client' if typ.msgid % 2 == 0 else 'server'}",\n' ret += "\t\t\tctx->ctx->max_msg_size);\n" ret += "\t\treturn true;\n" ret += "\t}\n" # Pass 2 - write data ifdef_depth = 1 stack: list[tuple[Path, bool]] = [(Path(typ), False)] def handle(path: Path) -> tuple[WalkCmd, typing.Callable[[], None]]: nonlocal ret nonlocal ifdef_depth nonlocal stack stack_len = len(stack) def pop() -> None: nonlocal ret nonlocal ifdef_depth nonlocal stack nonlocal stack_len while len(stack) > stack_len: ret += f"{'\t'*(len(stack)-1)}}}\n" if stack[-1][1]: ifdef_depth -= 1 ret += cutil.ifdef_pop(ifdef_depth) stack = stack[:-1] loopdepth = sum(1 for elem in path.elems if elem.cnt) struct = path.elems[-1].typ if path.elems else path.root if isinstance(struct, idl.Struct): offsets: list[str] = [] for member in struct.members: if not member.val: continue for tok in member.val.tokens: if not isinstance(tok, idl.ExprSym): continue if tok.symname == "end" or tok.symname.startswith("&"): if tok.symname not in offsets: offsets.append(tok.symname) for name in offsets: name_prefix = "offsetof_" + "".join( m.membname + "_" for m in path.elems ) if name == "end": if not path.elems: nonlocal max_size if max_size > cutil.UINT32_MAX: ret += f"{'\t'*len(stack)}uint32_t {name_prefix}end = (uint32_t)needed_size;\n" else: ret += f"{'\t'*len(stack)}uint32_t {name_prefix}end = needed_size;\n" continue recurse: OffsetExprRecursion = go_to_end else: assert name.startswith("&") name = name[1:] recurse = go_to_tok(name) expr = get_offset_expr(struct, recurse) expr_prefix = path.c_str("val->", loopdepth) if not expr_prefix.endswith(">"): expr_prefix += "." 
ret += expr.gen_c( "uint32_t", name_prefix + name, expr_prefix, len(stack), loopdepth, ) if path.elems: child = path.elems[-1] parent = path.elems[-2].typ if len(path.elems) > 1 else path.root if child.in_versions < parent.in_versions: ret += cutil.ifdef_push( ifdef_depth + 1, c9util.ver_ifdef(child.in_versions) ) ifdef_depth += 1 ret += f"{'\t'*len(stack)}if ({c9util.ver_cond(child.in_versions)}) {{\n" stack.append((path, True)) if child.cnt: cnt_path = path.parent().add(child.cnt) if child.typ.static_size == 1: # SPECIAL (zerocopy) if path.root.typname == "stat": # SPECIAL (stat) ret += f"{'\t'*len(stack)}MARSHAL_BYTES(ctx, {path.c_str('val->')[:-3]}, {cnt_path.c_str('val->')});\n" else: ret += f"{'\t'*len(stack)}MARSHAL_BYTES_ZEROCOPY(ctx, {path.c_str('val->')[:-3]}, {cnt_path.c_str('val->')});\n" return WalkCmd.KEEP_GOING, pop loopvar = chr(ord("i") + loopdepth - 1) ret += f"{'\t'*len(stack)}for ({c9util.typename(child.cnt.typ)} {loopvar} = 0; {loopvar} < {cnt_path.c_str('val->')}; {loopvar}++) {{\n" stack.append((path, False)) if not isinstance(child.typ, idl.Struct): if child.val: def lookup_sym(sym: str) -> str: nonlocal path if sym.startswith("&"): sym = sym[1:] return ( "offsetof_" + "".join(m.membname + "_" for m in path.elems[:-1]) + sym ) val = c9util.idl_expr(child.val, lookup_sym) else: val = path.c_str("val->") if isinstance(child.typ, idl.Bitfield): val += f" & {child.typ.typname}_masks[ctx->ctx->version]" ret += f"{'\t'*len(stack)}MARSHAL_U{child.typ.static_size*8}LE(ctx, {val});\n" return WalkCmd.KEEP_GOING, pop walk(typ, handle) del handle del stack del max_size ret += "\treturn false;\n" ret += "}\n" ret += cutil.ifdef_pop(0) # function tables ########################################################## ret += """ /* function tables ************************************************************/ """ ret += "\n" ret += f"const uint32_t {c9util.ident('_table_msg_min_size')}[{c9util.ver_enum('NUM')}] = {{\n" rerror = next(typ for typ in typs if 
def main() -> None:
    """Parse the .9p files named on the command line and write the C output.

    Every argument after the program name is parsed with `idl.Parser`.
    On a `SyntaxError` the location and message are printed to stderr
    (with the offending text underlined when it is available) and the
    process exits with status 2.  Otherwise `include/lib9p/9p.generated.h`
    and `9p.generated.c` are written relative to the parent of
    `sys.argv[0]`.

    Raises:
        ValueError: if no .9p filename was given.
    """

    class _FallbackColors:
        MAGENTA = "\x1b[35m"
        RED = "\x1b[31m"
        RESET = "\x1b[0m"

    if typing.TYPE_CHECKING:
        ANSIColors = _FallbackColors
    else:
        try:
            from _colorize import ANSIColors  # Present in Python 3.13+
        except ImportError:
            # `_colorize` is a private module; fall back to local
            # definitions of the three codes used below.
            ANSIColors = _FallbackColors

    if len(sys.argv) < 2:
        raise ValueError("requires at least 1 .9p filename")

    parser = idl.Parser()
    for txtname in sys.argv[1:]:
        try:
            parser.parse_file(txtname)
        except SyntaxError as e:
            print(
                f"{ANSIColors.RED}{e.filename}{ANSIColors.RESET}:{ANSIColors.MAGENTA}{e.lineno}{ANSIColors.RESET}: {e.msg}",
                file=sys.stderr,
            )
            # `SyntaxError.text` may be None; only show the excerpt
            # when there is one.
            if e.text:
                print(f"\t{e.text}", file=sys.stderr)
                text_suffix = e.text.lstrip()
                # Slice by length from the front: `[:-len(suffix)]`
                # would yield "" instead of the whole text when the
                # suffix is empty.
                text_prefix = e.text[: len(e.text) - len(text_suffix)]
                print(
                    f"\t{text_prefix}{ANSIColors.RED}{'~'*len(text_suffix)}{ANSIColors.RESET}",
                    file=sys.stderr,
                )
            sys.exit(2)
    versions, typs = parser.all()

    # Generate everything before opening any file for writing, so that
    # a failure in a generator does not truncate the previous output.
    header_text = gen_h(versions, typs)
    source_text = gen_c(versions, typs)

    outdir = os.path.normpath(os.path.join(sys.argv[0], ".."))
    with open(
        os.path.join(outdir, "include/lib9p/9p.generated.h"), "w", encoding="utf-8"
    ) as fh:
        fh.write(header_text)
    with open(os.path.join(outdir, "9p.generated.c"), "w", encoding="utf-8") as fh:
        fh.write(source_text)