#!/usr/bin/env python # lib9p/idl.gen - Generate C marshalers/unmarshalers for .9p files # defining 9P protocol variants. # # Copyright (C) 2024-2025 Luke T. Shumaker # SPDX-License-Identifier: AGPL-3.0-or-later import graphlib import os.path import sys import typing sys.path.insert(0, os.path.normpath(os.path.join(__file__, ".."))) import idl # This strives to be "general-purpose" in that it just acts on the # *.9p inputs; but (unfortunately?) there are a few special-cases in # this script, marked with "SPECIAL". # Utilities #################################################################### idprefix = "lib9p_" u32max = (1 << 32) - 1 u64max = (1 << 64) - 1 def tab_ljust(s: str, width: int) -> str: cur = len(s.expandtabs(tabsize=8)) if cur >= width: return s return s + " " * (width - cur) def add_prefix(p: str, s: str) -> str: if s.startswith("_"): return "_" + p + s[1:] return p + s def c_macro(full: str) -> str: full = full.rstrip() assert "\n" in full lines = [l.rstrip() for l in full.split("\n")] width = max(len(l.expandtabs(tabsize=8)) for l in lines[:-1]) lines = [tab_ljust(l, width) for l in lines] return " \\\n".join(lines).rstrip() + "\n" def c_ver_enum(ver: str) -> str: return f"{idprefix.upper()}VER_{ver.replace('.', '_')}" def c_ver_ifdef(versions: set[str]) -> str: return " || ".join( f"CONFIG_9P_ENABLE_{v.replace('.', '_')}" for v in sorted(versions) ) def c_ver_cond(versions: set[str]) -> str: if len(versions) == 1: v = next(v for v in versions) return f"is_ver(ctx, {v.replace('.', '_')})" return "( " + (" || ".join(c_ver_cond({v}) for v in sorted(versions))) + " )" def c_typename(typ: idl.Type, parent: idl.StructMember | None = None) -> str: match typ: case idl.Primitive(): if typ.value == 1 and parent and parent.cnt: # SPECIAL (string) return "[[gnu::nonstring]] char" return f"uint{typ.value*8}_t" case idl.Number(): return f"{idprefix}{typ.name}_t" case idl.Bitfield(): return f"{idprefix}{typ.name}_t" case idl.Message(): return f"struct {idprefix}msg_{typ.name}" case idl.Struct(): return f"struct {idprefix}{typ.name}" case _: raise ValueError(f"not a type: {typ.__class__.__name__}") def c_expr(expr: idl.Expr) -> str: ret: list[str] = [] for tok in expr.tokens: match tok: case idl.ExprOp(): ret.append(tok.op) case idl.ExprLit(): ret.append(str(tok.val)) case idl.ExprSym(name="end"): ret.append("ctx->net_offset") case idl.ExprSym(name="s32_max"): ret.append("INT32_MAX") case idl.ExprSym(name="s64_max"): ret.append("INT64_MAX") case idl.ExprSym(): ret.append(f"_{tok.name[1:]}_offset") return " ".join(ret) _ifdef_stack: list[str | None] = [] def ifdef_push(n: int, _newval: str) -> str: # Grow the stack as needed global _ifdef_stack while len(_ifdef_stack) < n: _ifdef_stack.append(None) # Set some variables parentval: str | None = None for x in _ifdef_stack[:-1]: if x is not None: parentval = x oldval = _ifdef_stack[-1] newval: str | None = _newval if newval == parentval: newval = None # Put newval on the stack. _ifdef_stack[-1] = newval # Build output. ret = "" if newval != oldval: if oldval is not None: ret += f"#endif /* {oldval} */\n" if newval is not None: ret += f"#if {newval}\n" return ret def ifdef_pop(n: int) -> str: global _ifdef_stack ret = "" while len(_ifdef_stack) > n: if _ifdef_stack[-1] is not None: ret += f"#endif /* {_ifdef_stack[-1]} */\n" _ifdef_stack = _ifdef_stack[:-1] return ret def topo_sorted(typs: list[idl.Type]) -> typing.Iterable[idl.Type]: ts: graphlib.TopologicalSorter[idl.Type] = graphlib.TopologicalSorter() for typ in typs: match typ: case idl.Number(): ts.add(typ) case idl.Bitfield(): ts.add(typ) case idl.Struct(): # and idl.Message(): deps = [ member.typ for member in typ.members if not isinstance(member.typ, idl.Primitive) ] ts.add(typ, *deps) return ts.static_order() # Generate .h ################################################################## def gen_h(versions: set[str], typs: list[idl.Type]) -> str: global _ifdef_stack _ifdef_stack = [] ret = f"""/* Generated by `{' '.join(sys.argv)}`. DO NOT EDIT! */ #ifndef _LIB9P_9P_H_ \t#error Do not include directly; include instead #endif #include /* for uint{{n}}_t types */ """ id2typ: dict[int, idl.Message] = {} for msg in [msg for msg in typs if isinstance(msg, idl.Message)]: id2typ[msg.msgid] = msg ret += f""" /* config *********************************************************************/ #include "config.h" """ for ver in sorted(versions): ret += "\n" ret += f"#ifndef {c_ver_ifdef({ver})}\n" ret += f"\t#error config.h must define {c_ver_ifdef({ver})}\n" ret += "#endif\n" ret += f""" /* enum version ***************************************************************/ enum {idprefix}version {{ """ fullversions = ["unknown = 0", *sorted(versions)] verwidth = max(len(v) for v in fullversions) for ver in fullversions: if ver in versions: ret += ifdef_push(1, c_ver_ifdef({ver})) ret += f"\t{c_ver_enum(ver)}," ret += (" " * (verwidth - len(ver))) + ' /* "' + ver.split()[0] + '" */\n' ret += ifdef_pop(0) ret += f"\t{c_ver_enum('NUM')},\n" ret += "};\n" ret += """ /* enum msg_type **************************************************************/ """ ret += f"enum {idprefix}msg_type {{ /* uint8_t */\n" namewidth = max(len(msg.name) for msg in typs if isinstance(msg, idl.Message)) for n in range(0x100): if n not in id2typ: continue msg = id2typ[n] ret += ifdef_push(1, c_ver_ifdef(msg.in_versions)) ret += f"\t{idprefix.upper()}TYP_{msg.name.ljust(namewidth)} = {msg.msgid},\n" ret += ifdef_pop(0) ret += "};\n" ret += """ /* payload types **************************************************************/ """ def per_version_comment( typ: idl.Type, fn: typing.Callable[[idl.Type, str], str] ) -> str: lines: dict[str, str] = {} for version in sorted(typ.in_versions): lines[version] = fn(typ, version) if len(set(lines.values())) == 1: for _, line in lines.items(): return f"/* {line} */\n" assert False else: ret = "" v_width = max(len(c_ver_enum(v)) for v in typ.in_versions) for version, line in lines.items(): ret += f"/* {c_ver_enum(version).ljust(v_width)}: {line} */\n" return ret for typ in topo_sorted(typs): ret += "\n" ret += ifdef_push(1, c_ver_ifdef(typ.in_versions)) def sum_size(typ: idl.Type, version: str) -> str: min_size = typ.min_size(version) max_size = typ.max_size(version) assert min_size <= max_size and max_size < u64max ret = "" if min_size == max_size: ret += f"size = {min_size:,}" else: ret += f"min_size = {min_size:,} ; max_size = {max_size:,}" if max_size > u32max: ret += " (warning: >UINT32_MAX)" return ret ret += per_version_comment(typ, sum_size) match typ: case idl.Number(): ret += f"typedef {c_typename(typ.prim)} {c_typename(typ)};\n" prefix = f"{idprefix.upper()}{typ.name.upper()}_" namewidth = max(len(name) for name in typ.vals) for name, val in typ.vals.items(): ret += f"#define {prefix}{name.ljust(namewidth)} (({c_typename(typ)})UINT{typ.static_size*8}_C({val}))\n" case idl.Bitfield(): ret += f"typedef {c_typename(typ.prim)} {c_typename(typ)};\n" names = [ typ.bits[n] or f" {n}" for n in reversed(range(0, len(typ.bits))) ] if aliases := [k for k in typ.names if k not in typ.bits]: names.append("") names.extend(aliases) prefix = f"{idprefix.upper()}{typ.name.upper()}_" namewidth = max(len(add_prefix(prefix, name)) for name in names) ret += "\n" for name in names: if name == "": ret += "\n" continue if name.startswith(" "): vers = typ.in_versions c_name = "" c_val = f"1<<{name[1:]}" else: vers = typ.names[name].in_versions c_name = add_prefix(prefix, name) c_val = f"{typ.names[name].val}" ret += ifdef_push(2, c_ver_ifdef(vers)) # It is important all of the `beg` strings have # the same length. end = "" if name.startswith(" "): beg = "/* unused" end = " */" elif _ifdef_stack[-1]: beg = "# define" else: beg = "#define " ret += f"{beg} {c_name.ljust(namewidth)} (({c_typename(typ)})({c_val})){end}\n" ret += ifdef_pop(1) case idl.Struct(): # and idl.Message(): ret += c_typename(typ) + " {" if not typ.members: ret += "};\n" continue ret += "\n" typewidth = max(len(c_typename(m.typ, m)) for m in typ.members) for member in typ.members: if member.val: continue ret += ifdef_push(2, c_ver_ifdef(member.in_versions)) ret += f"\t{c_typename(member.typ, member).ljust(typewidth)} {'*' if member.cnt else ' '}{member.name};\n" ret += ifdef_pop(1) ret += "};\n" ret += ifdef_pop(0) return ret # Generate .c ################################################################## def gen_c(versions: set[str], typs: list[idl.Type]) -> str: global _ifdef_stack _ifdef_stack = [] ret = f"""/* Generated by `{' '.join(sys.argv)}`. DO NOT EDIT! */ #include #include /* for size_t */ #include /* for PRI* macros */ #include /* for memset() */ #include #include #include "internal.h" """ # utilities ################################################################ ret += f""" /* utilities ******************************************************************/ """ def used(arg: str) -> str: return arg def unused(arg: str) -> str: return f"LM_UNUSED({arg})" id2typ: dict[int, idl.Message] = {} for msg in [msg for msg in typs if isinstance(msg, idl.Message)]: id2typ[msg.msgid] = msg def msg_table(grp: str, meth: str, tentry: str, rng: tuple[int, int, int]) -> str: ret = f"const {tentry} _{idprefix}table_{grp}_{meth}[{c_ver_enum('NUM')}][{hex(len(range(*rng)))}] = {{\n" for ver in ["unknown", *sorted(versions)]: if ver != "unknown": ret += ifdef_push(1, c_ver_ifdef({ver})) ret += f"\t[{c_ver_enum(ver)}] = {{\n" for n in range(*rng): xmsg: idl.Message | None = id2typ.get(n, None) if xmsg: if ver == "unknown": # SPECIAL (initialization) if xmsg.name not in ["Tversion", "Rversion", "Rerror"]: xmsg = None else: if ver not in xmsg.in_versions: xmsg = None if xmsg: ret += f"\t\t_MSG_{meth.upper()}({xmsg.name}),\n" ret += "\t},\n" ret += ifdef_pop(0) ret += "};\n" return ret for v in sorted(versions): ret += f"#if CONFIG_9P_ENABLE_{v.replace('.', '_')}\n" ret += f"\t#define _is_ver_{v.replace('.', '_')}(v) (v == {c_ver_enum(v)})\n" ret += "#else\n" ret += f"\t#define _is_ver_{v.replace('.', '_')}(v) false\n" ret += "#endif\n" ret += "\n" ret += "/**\n" ret += f" * is_ver(ctx, ver) is essentially `(ctx->ctx->version == {idprefix.upper()}VER_##ver)`,\n" ret += f" * but compiles correctly (to `false`) even if `{idprefix.upper()}VER_##ver` isn't defined\n" ret += " * (because `!CONFIG_9P_ENABLE_##ver`). This is useful when `||`ing\n" ret += " * several version checks together.\n" ret += " */\n" ret += "#define is_ver(CTX, ver) _is_ver_##ver(CTX->ctx->version)\n" # strings ################################################################## ret += f""" /* strings ********************************************************************/ const char *_lib9p_table_ver_name[{c_ver_enum('NUM')}] = {{ """ for ver in ["unknown", *sorted(versions)]: if ver in versions: ret += ifdef_push(1, c_ver_ifdef({ver})) ret += f'\t[{c_ver_enum(ver)}] = "{ver}",\n' ret += ifdef_pop(0) ret += "};\n" ret += "\n" ret += f"#define _MSG_NAME(typ) [{idprefix.upper()}TYP_##typ] = #typ\n" ret += msg_table("msg", "name", "char *", (0, 0x100, 1)) # bitmasks ################################################################# ret += f""" /* bitmasks *******************************************************************/ """ for typ in typs: if not isinstance(typ, idl.Bitfield): continue ret += "\n" ret += ifdef_push(1, c_ver_ifdef(typ.in_versions)) ret += f"static const {c_typename(typ)} {typ.name}_masks[{c_ver_enum('NUM')}] = {{\n" verwidth = max(len(ver) for ver in versions) for ver in sorted(versions): ret += ifdef_push(2, c_ver_ifdef({ver})) ret += ( f"\t[{c_ver_enum(ver)}]{' '*(verwidth-len(ver))} = 0b" + "".join( "1" if typ.bit_is_valid(bitname, ver) else "0" for bitname in reversed(typ.bits) ) + ",\n" ) ret += ifdef_pop(1) ret += "};\n" ret += ifdef_pop(0) # validate_* ############################################################### ret += """ /* validate_* *****************************************************************/ LM_ALWAYS_INLINE static bool _validate_size_net(struct _validate_ctx *ctx, uint32_t n) { \tif (__builtin_add_overflow(ctx->net_offset, n, &ctx->net_offset)) \t\t/* If needed-net-size overflowed uint32_t, then \t\t * there's no way that actual-net-size will live up to \t\t * that. */ \t\treturn lib9p_error(ctx->ctx, LINUX_EBADMSG, "message is too short for content"); \tif (ctx->net_offset > ctx->net_size) \t\treturn lib9p_error(ctx->ctx, LINUX_EBADMSG, "message is too short for content"); \treturn false; } LM_ALWAYS_INLINE static bool _validate_size_host(struct _validate_ctx *ctx, size_t n) { \tif (__builtin_add_overflow(ctx->host_extra, n, &ctx->host_extra)) \t\t/* If needed-host-size overflowed size_t, then there's \t\t * no way that actual-net-size will live up to \t\t * that. */ \t\treturn lib9p_error(ctx->ctx, LINUX_EBADMSG, "message is too short for content"); \treturn false; } LM_ALWAYS_INLINE static bool _validate_list(struct _validate_ctx *ctx, size_t cnt, _validate_fn_t item_fn, size_t item_host_size) { \tfor (size_t i = 0; i < cnt; i++) \t\tif (_validate_size_host(ctx, item_host_size) || item_fn(ctx)) \t\t\treturn true; \treturn false; } LM_ALWAYS_INLINE static bool validate_1(struct _validate_ctx *ctx) { return _validate_size_net(ctx, 1); } LM_ALWAYS_INLINE static bool validate_2(struct _validate_ctx *ctx) { return _validate_size_net(ctx, 2); } LM_ALWAYS_INLINE static bool validate_4(struct _validate_ctx *ctx) { return _validate_size_net(ctx, 4); } LM_ALWAYS_INLINE static bool validate_8(struct _validate_ctx *ctx) { return _validate_size_net(ctx, 8); } """ for typ in topo_sorted(typs): inline = "LM_FLATTEN" if isinstance(typ, idl.Message) else "LM_ALWAYS_INLINE" argfn = unused if (isinstance(typ, idl.Struct) and not typ.members) else used ret += "\n" ret += ifdef_push(1, c_ver_ifdef(typ.in_versions)) ret += f"{inline} static bool validate_{typ.name}(struct _validate_ctx *{argfn('ctx')}) {{\n" match typ: case idl.Number(): ret += f"\treturn validate_{typ.prim.name}(ctx);\n" case idl.Bitfield(): ret += f"\t if (validate_{typ.static_size}(ctx))\n" ret += "\t\treturn true;\n" ret += ( f"\t{c_typename(typ)} mask = {typ.name}_masks[ctx->ctx->version];\n" ) if typ.static_size == 1: ret += f"\t{c_typename(typ)} val = ctx->net_bytes[ctx->net_offset-1];\n" else: ret += f"\t{c_typename(typ)} val = uint{typ.static_size*8}le_decode(&ctx->net_bytes[ctx->net_offset-{typ.static_size}]);\n" ret += f"\tif (val & ~mask)\n" ret += f'\t\treturn lib9p_errorf(ctx->ctx, LINUX_EBADMSG, "unknown bits in {typ.name} bitfield: %#0{typ.static_size}"PRIx{typ.static_size*8}, val & ~mask);\n' ret += "\treturn false;\n" case idl.Struct(): # and idl.Message() if len(typ.members) == 0: ret += "\treturn false;\n" ret += "}\n" continue def should_save_value(member: idl.StructMember) -> bool: nonlocal typ assert isinstance(typ, idl.Struct) return bool( member.max or member.val or any(m.cnt == member for m in typ.members) ) # Pass 1 - declare value variables for member in typ.members: if should_save_value(member): ret += ifdef_push(2, c_ver_ifdef(member.in_versions)) ret += f"\t{c_typename(member.typ)} {member.name};\n" ret += ifdef_pop(1) # Pass 2 - declare offset variables mark_offset: set[str] = set() for member in typ.members: for tok in [*member.max.tokens, *member.val.tokens]: if isinstance(tok, idl.ExprSym) and tok.name.startswith("&"): if tok.name[1:] not in mark_offset: ret += f"\tuint32_t _{tok.name[1:]}_offset;\n" mark_offset.add(tok.name[1:]) # Pass 3 - main pass ret += "\treturn false\n" for member in typ.members: ret += ifdef_push(2, c_ver_ifdef(member.in_versions)) ret += f"\t || " if member.in_versions != typ.in_versions: ret += "( " + c_ver_cond(member.in_versions) + " && " if member.cnt is not None: ret += f"_validate_list(ctx, {member.cnt.name}, validate_{member.typ.name}, sizeof({c_typename(member.typ)}))" if typ.name == "s": # SPECIAL (string) ret += f'\n\t || ({{ (!is_valid_utf8_without_nul(&ctx->net_bytes[ctx->net_offset-len], len)) && lib9p_error(ctx->ctx, LINUX_EBADMSG, "message contains invalid UTF-8"); }})' else: if should_save_value(member): ret += "(" if member.name in mark_offset: ret += f"({{ _{member.name}_offset = ctx->net_offset; " ret += f"validate_{member.typ.name}(ctx)" if member.name in mark_offset: ret += "; })" if should_save_value(member): nbytes = member.static_size assert nbytes if nbytes == 1: ret += f" || ({{ {member.name} = ctx->net_bytes[ctx->net_offset-1]; false; }}))" else: ret += f" || ({{ {member.name} = uint{nbytes*8}le_decode(&ctx->net_bytes[ctx->net_offset-{nbytes}]); false; }}))" if member.in_versions != typ.in_versions: ret += " )" ret += "\n" # Pass 4 - validate ,max= and ,val= constraints for member in typ.members: if member.max: assert member.static_size nbits = member.static_size * 8 ret += ifdef_push(2, c_ver_ifdef(member.in_versions)) ret += f"\t || ({{ uint{nbits}_t max = {c_expr(member.max)}; (((uint{nbits}_t){member.name}) > max) &&\n" ret += f'\t lib9p_errorf(ctx->ctx, LINUX_EBADMSG, "{member.name} value is too large (%"PRIu{nbits}" > %"PRIu{nbits}")", {member.name}, max); }})\n' if member.val: assert member.static_size nbits = member.static_size * 8 ret += ifdef_push(2, c_ver_ifdef(member.in_versions)) ret += f"\t || ({{ uint{nbits}_t exp = {c_expr(member.val)}; (((uint{nbits}_t){member.name}) != exp) &&\n" ret += f'\t lib9p_errorf(ctx->ctx, LINUX_EBADMSG, "{member.name} value is wrong (actual:%"PRIu{nbits}" != correct:%"PRIu{nbits}")", (uint{nbits}_t){member.name}, exp); }})\n' ret += ifdef_pop(1) ret += "\t ;\n" ret += "}\n" ret += ifdef_pop(0) # unmarshal_* ############################################################## ret += """ /* unmarshal_* ****************************************************************/ LM_ALWAYS_INLINE static void unmarshal_1(struct _unmarshal_ctx *ctx, uint8_t *out) { \t*out = ctx->net_bytes[ctx->net_offset]; \tctx->net_offset += 1; } LM_ALWAYS_INLINE static void unmarshal_2(struct _unmarshal_ctx *ctx, uint16_t *out) { \t*out = uint16le_decode(&ctx->net_bytes[ctx->net_offset]); \tctx->net_offset += 2; } LM_ALWAYS_INLINE static void unmarshal_4(struct _unmarshal_ctx *ctx, uint32_t *out) { \t*out = uint32le_decode(&ctx->net_bytes[ctx->net_offset]); \tctx->net_offset += 4; } LM_ALWAYS_INLINE static void unmarshal_8(struct _unmarshal_ctx *ctx, uint64_t *out) { \t*out = uint64le_decode(&ctx->net_bytes[ctx->net_offset]); \tctx->net_offset += 8; } """ for typ in topo_sorted(typs): inline = "LM_FLATTEN" if isinstance(typ, idl.Message) else "LM_ALWAYS_INLINE" argfn = unused if (isinstance(typ, idl.Struct) and not typ.members) else used ret += "\n" ret += ifdef_push(1, c_ver_ifdef(typ.in_versions)) ret += f"{inline} static void unmarshal_{typ.name}(struct _unmarshal_ctx *{argfn('ctx')}, {c_typename(typ)} *out) {{\n" match typ: case idl.Number(): ret += f"\tunmarshal_{typ.prim.name}(ctx, ({c_typename(typ.prim)} *)out);\n" case idl.Bitfield(): ret += f"\tunmarshal_{typ.prim.name}(ctx, ({c_typename(typ.prim)} *)out);\n" case idl.Struct(): ret += "\tmemset(out, 0, sizeof(*out));\n" for member in typ.members: ret += ifdef_push(2, c_ver_ifdef(member.in_versions)) if member.val: ret += f"\tctx->net_offset += {member.static_size};\n" continue ret += "\t" prefix = "\t" if member.in_versions != typ.in_versions: ret += "if ( " + c_ver_cond(member.in_versions) + " ) " prefix = "\t\t" if member.cnt: if member.in_versions != typ.in_versions: ret += "{\n" ret += prefix ret += f"out->{member.name} = ctx->extra;\n" ret += f"{prefix}ctx->extra += sizeof(out->{member.name}[0]) * out->{member.cnt.name};\n" ret += f"{prefix}for (typeof(out->{member.cnt.name}) i = 0; i < out->{member.cnt.name}; i++)\n" if member.typ.static_size == 1: # SPECIAL (string) # Special-case is that we cast from `char` to `uint8_t`. ret += f"{prefix}\tunmarshal_{member.typ.name}(ctx, (uint8_t *)&out->{member.name}[i]);\n" else: ret += f"{prefix}\tunmarshal_{member.typ.name}(ctx, &out->{member.name}[i]);\n" if member.in_versions != typ.in_versions: ret += "\t}\n" else: ret += ( f"unmarshal_{member.typ.name}(ctx, &out->{member.name});\n" ) ret += ifdef_pop(1) ret += "}\n" ret += ifdef_pop(0) # marshal_* ################################################################ ret += """ /* marshal_* ******************************************************************/ LM_ALWAYS_INLINE static bool _marshal_too_large(struct _marshal_ctx *ctx) { \tlib9p_errorf(ctx->ctx, LINUX_ERANGE, "%s too large to marshal into %s limit (limit=%"PRIu32")", \t\t(ctx->net_bytes[4] % 2 == 0) ? "T-message" : "R-message", \t\tctx->ctx->version ? "negotiated" : ((ctx->net_bytes[4] % 2 == 0) ? "client" : "server"), \t\tctx->ctx->max_msg_size); \treturn true; } LM_ALWAYS_INLINE static bool marshal_1(struct _marshal_ctx *ctx, uint8_t *val) { \tif (ctx->net_offset + 1 > ctx->ctx->max_msg_size) \t\treturn _marshal_too_large(ctx); \tctx->net_bytes[ctx->net_offset] = *val; \tctx->net_offset += 1; \treturn false; } LM_ALWAYS_INLINE static bool marshal_2(struct _marshal_ctx *ctx, uint16_t *val) { \tif (ctx->net_offset + 2 > ctx->ctx->max_msg_size) \t\treturn _marshal_too_large(ctx); \tuint16le_encode(&ctx->net_bytes[ctx->net_offset], *val); \tctx->net_offset += 2; \treturn false; } LM_ALWAYS_INLINE static bool marshal_4(struct _marshal_ctx *ctx, uint32_t *val) { \tif (ctx->net_offset + 4 > ctx->ctx->max_msg_size) \t\treturn true; \tuint32le_encode(&ctx->net_bytes[ctx->net_offset], *val); \tctx->net_offset += 4; \treturn false; } LM_ALWAYS_INLINE static bool marshal_8(struct _marshal_ctx *ctx, uint64_t *val) { \tif (ctx->net_offset + 8 > ctx->ctx->max_msg_size) \t\treturn true; \tuint64le_encode(&ctx->net_bytes[ctx->net_offset], *val); \tctx->net_offset += 8; \treturn false; } """ for typ in topo_sorted(typs): inline = "LM_FLATTEN" if isinstance(typ, idl.Message) else "LM_ALWAYS_INLINE" argfn = unused if (isinstance(typ, idl.Struct) and not typ.members) else used ret += "\n" ret += ifdef_push(1, c_ver_ifdef(typ.in_versions)) ret += f"{inline} static bool marshal_{typ.name}(struct _marshal_ctx *{argfn('ctx')}, {c_typename(typ)} *{argfn('val')}) {{\n" match typ: case idl.Number(): ret += f"\treturn marshal_{typ.prim.name}(ctx, ({c_typename(typ.prim)} *)val);\n" case idl.Bitfield(): ret += f"\t{c_typename(typ)} masked_val = *val & {typ.name}_masks[ctx->ctx->version];\n" ret += f"\treturn marshal_{typ.prim.name}(ctx, ({c_typename(typ.prim)} *)&masked_val);\n" case idl.Struct(): if len(typ.members) == 0: ret += "\treturn false;\n" ret += "}\n" continue # Pass 1 - declare offset variables mark_offset = set() for member in typ.members: if member.val: if member.name not in mark_offset: ret += f"\tuint32_t _{member.name}_offset;\n" mark_offset.add(member.name) for tok in member.val.tokens: if isinstance(tok, idl.ExprSym) and tok.name.startswith("&"): if tok.name[1:] not in mark_offset: ret += f"\tuint32_t _{tok.name[1:]}_offset;\n" mark_offset.add(tok.name[1:]) # Pass 2 - main pass ret += "\treturn false\n" for member in typ.members: ret += ifdef_push(2, c_ver_ifdef(member.in_versions)) ret += "\t || " if member.in_versions != typ.in_versions: ret += "( " + c_ver_cond(member.in_versions) + " && " if member.name in mark_offset: ret += f"({{ _{member.name}_offset = ctx->net_offset; " if member.cnt: ret += "({ bool err = false;\n" ret += f"\t for (typeof(val->{member.cnt.name}) i = 0; i < val->{member.cnt.name} && !err; i++)\n" ret += "\t \terr = " if member.typ.static_size == 1: # SPECIAL (string) # Special-case is that we cast from `char` to `uint8_t`. ret += f"marshal_{member.typ.name}(ctx, (uint8_t *)&val->{member.name}[i]);\n" else: ret += f"marshal_{member.typ.name}(ctx, &val->{member.name}[i]);\n" ret += f"\t err; }})" elif member.val: # Just increment net_offset, don't actually marshal anything (yet). assert member.static_size ret += ( f"({{ ctx->net_offset += {member.static_size}; false; }})" ) else: ret += f"marshal_{member.typ.name}(ctx, &val->{member.name})" if member.name in mark_offset: ret += "; })" if member.in_versions != typ.in_versions: ret += " )" ret += "\n" # Pass 3 - marshal ,val= members for member in typ.members: if member.val: assert member.static_size ret += ifdef_push(2, c_ver_ifdef(member.in_versions)) if member.static_size == 1: ret += f"\t || ({{ ctx->net_bytes[_{member.name}_offset] = {c_expr(member.val)}; false; }})\n" else: ret += f"\t || ({{ uint{member.static_size*8}le_encode(&ctx->net_bytes[_{member.name}_offset], {c_expr(member.val)}); false; }})\n" ret += ifdef_pop(1) ret += "\t ;\n" ret += "}\n" ret += ifdef_pop(0) # function tables ########################################################## ret += """ /* function tables ************************************************************/ """ ret += c_macro( f"#define _MSG_RECV(typ) [{idprefix.upper()}TYP_##typ/2] = {{\n" f"\t\t.basesize = sizeof(struct {idprefix}msg_##typ),\n" f"\t\t.validate = validate_##typ,\n" f"\t\t.unmarshal = (_unmarshal_fn_t)unmarshal_##typ,\n" f"\t}}\n" ) ret += c_macro( f"#define _MSG_SEND(typ) [{idprefix.upper()}TYP_##typ/2] = {{\n" f"\t\t.marshal = (_marshal_fn_t)marshal_##typ,\n" f"\t}}\n" ) ret += "\n" ret += msg_table("Tmsg", "recv", f"struct _{idprefix}recv_tentry", (0, 0x100, 2)) ret += "\n" ret += msg_table("Rmsg", "recv", f"struct _{idprefix}recv_tentry", (1, 0x100, 2)) ret += "\n" ret += msg_table("Tmsg", "send", f"struct _{idprefix}send_tentry", (0, 0x100, 2)) ret += "\n" ret += msg_table("Rmsg", "send", f"struct _{idprefix}send_tentry", (1, 0x100, 2)) ret += f""" LM_FLATTEN bool _{idprefix}stat_validate(struct _validate_ctx *ctx) {{ \treturn validate_stat(ctx); }} LM_FLATTEN void _{idprefix}stat_unmarshal(struct _unmarshal_ctx *ctx, struct lib9p_stat *out) {{ \tunmarshal_stat(ctx, out); }} LM_FLATTEN bool _{idprefix}stat_marshal(struct _marshal_ctx *ctx, struct lib9p_stat *val) {{ \treturn marshal_stat(ctx, val); }} """ ############################################################################ return ret # Main ######################################################################### if __name__ == "__main__": import sys if typing.TYPE_CHECKING: class ANSIColors: MAGENTA = "\x1b[35m" RED = "\x1b[31m" RESET = "\x1b[0m" else: from _colorize import ANSIColors # Present in Python 3.13+ if len(sys.argv) < 2: raise ValueError("requires at least 1 .9p filename") parser = idl.Parser() for txtname in sys.argv[1:]: try: parser.parse_file(txtname) except SyntaxError as e: print( f"{ANSIColors.RED}{e.filename}{ANSIColors.RESET}:{ANSIColors.MAGENTA}{e.lineno}{ANSIColors.RESET}: {e.msg}", file=sys.stderr, ) assert e.text print(f"\t{e.text}", file=sys.stderr) print( f"\t{ANSIColors.RED}{'~'*len(e.text)}{ANSIColors.RESET}", file=sys.stderr, ) sys.exit(2) versions, typs = parser.all() outdir = os.path.normpath(os.path.join(sys.argv[0], "..")) with open(os.path.join(outdir, "include/lib9p/9p.generated.h"), "w") as fh: fh.write(gen_h(versions, typs)) with open(os.path.join(outdir, "9p.generated.c"), "w") as fh: fh.write(gen_c(versions, typs))