diff options
author | Luke T. Shumaker <lukeshu@lukeshu.com> | 2025-03-23 02:26:08 -0600 |
---|---|---|
committer | Luke T. Shumaker <lukeshu@lukeshu.com> | 2025-03-23 03:05:06 -0600 |
commit | 82b733e4f8b3febc3b51c133a52fb62b54180b4b (patch) | |
tree | 0cb858fd7b55f19eda2d027e628b580aab155342 | |
parent | 2a70a611558daa248e4fc1a11a9aa0ceb3ed397a (diff) |
lib9p: protogen: pull c.py and c_*.py out of __init__.py
-rw-r--r-- | lib9p/protogen/__init__.py | 749 | ||||
-rw-r--r-- | lib9p/protogen/c.py | 200 | ||||
-rw-r--r-- | lib9p/protogen/c9util.py | 8 | ||||
-rw-r--r-- | lib9p/protogen/c_marshal.py | 357 | ||||
-rw-r--r-- | lib9p/protogen/c_unmarshal.py | 92 | ||||
-rw-r--r-- | lib9p/protogen/c_validate.py | 171 |
6 files changed, 830 insertions, 747 deletions
diff --git a/lib9p/protogen/__init__.py b/lib9p/protogen/__init__.py index 37cf6f5..c2c6173 100644 --- a/lib9p/protogen/__init__.py +++ b/lib9p/protogen/__init__.py @@ -10,757 +10,12 @@ import typing import idl -from . import c9util, cutil, h, idlutil - -# This strives to be "general-purpose" in that it just acts on the -# *.9p inputs; but (unfortunately?) there are a few special-cases in -# this script, marked with "SPECIAL". - +from . import c, h # pylint: disable=unused-variable __all__ = ["main"] -# Generate .c ################################################################## - - -def gen_c(versions: set[str], typs: list[idl.UserType]) -> str: - cutil.ifdef_init() - - ret = f"""/* Generated by `{' '.join(sys.argv)}`. DO NOT EDIT! */ - -#include <stdbool.h> -#include <stddef.h> /* for size_t */ -#include <inttypes.h> /* for PRI* macros */ -#include <string.h> /* for memset() */ - -#include <libmisc/assert.h> - -#include <lib9p/9p.h> - -#include "internal.h" -""" - - # utilities ################################################################ - ret += """ -/* utilities ******************************************************************/ -""" - - def used(arg: str) -> str: - return arg - - def unused(arg: str) -> str: - return f"LM_UNUSED({arg})" - - id2typ: dict[int, idl.Message] = {} - for msg in [msg for msg in typs if isinstance(msg, idl.Message)]: - id2typ[msg.msgid] = msg - - def msg_table(grp: str, meth: str, tentry: str, rng: tuple[int, int, int]) -> str: - ret = f"const {tentry} {c9util.ident(f'_table_{grp}_{meth}')}[{c9util.ver_enum('NUM')}][{hex(len(range(*rng)))}] = {{\n" - for ver in ["unknown", *sorted(versions)]: - if ver != "unknown": - ret += cutil.ifdef_push(1, c9util.ver_ifdef({ver})) - ret += f"\t[{c9util.ver_enum(ver)}] = {{\n" - for n in range(*rng): - xmsg: idl.Message | None = id2typ.get(n, None) - if xmsg: - if ver == "unknown": # SPECIAL (initialization) - if xmsg.typname not in ["Tversion", "Rversion", "Rerror"]: - xmsg = None - else: - if ver not in xmsg.in_versions: - xmsg = None - if xmsg: - ret += f"\t\t_MSG_{meth.upper()}({xmsg.typname}),\n" - ret += "\t},\n" - ret += cutil.ifdef_pop(0) - ret += "};\n" - return ret - - for v in sorted(versions): - ret += f"#if CONFIG_9P_ENABLE_{v.replace('.', '_')}\n" - ret += ( - f"\t#define _is_ver_{v.replace('.', '_')}(v) (v == {c9util.ver_enum(v)})\n" - ) - ret += "#else\n" - ret += f"\t#define _is_ver_{v.replace('.', '_')}(v) false\n" - ret += "#endif\n" - ret += "\n" - ret += "/**\n" - ret += f" * is_ver(ctx, ver) is essentially `(ctx->ctx->version == {c9util.Ident('VER_')}##ver)`,\n" - ret += f" * but compiles correctly (to `false`) even if `{c9util.Ident('VER_')}##ver` isn't defined\n" - ret += " * (because `!CONFIG_9P_ENABLE_##ver`). This is useful when `||`ing\n" - ret += " * several version checks together.\n" - ret += " */\n" - ret += "#define is_ver(CTX, ver) _is_ver_##ver(CTX->ctx->version)\n" - - # strings ################################################################## - ret += f""" -/* strings ********************************************************************/ - -const char *const {c9util.ident('_table_ver_name')}[{c9util.ver_enum('NUM')}] = {{ -""" - for ver in ["unknown", *sorted(versions)]: - if ver in versions: - ret += cutil.ifdef_push(1, c9util.ver_ifdef({ver})) - ret += f'\t[{c9util.ver_enum(ver)}] = "{ver}",\n' - ret += cutil.ifdef_pop(0) - ret += "};\n" - - ret += "\n" - ret += f"#define _MSG_NAME(typ) [{c9util.Ident('TYP_')}##typ] = #typ\n" - ret += msg_table("msg", "name", "char *const", (0, 0x100, 1)) - - # bitmasks ################################################################# - ret += """ -/* bitmasks *******************************************************************/ -""" - for typ in typs: - if not isinstance(typ, idl.Bitfield): - continue - ret += "\n" - ret += cutil.ifdef_push(1, c9util.ver_ifdef(typ.in_versions)) - ret += f"static const {c9util.typename(typ)} {typ.typname}_masks[{c9util.ver_enum('NUM')}] = {{\n" - verwidth = max(len(ver) for ver in versions) - for ver in sorted(versions): - ret += cutil.ifdef_push(2, c9util.ver_ifdef({ver})) - ret += ( - f"\t[{c9util.ver_enum(ver)}]{' '*(verwidth-len(ver))} = 0b" - + "".join( - ( - "1" - if bit.cat in (idl.BitCat.USED, idl.BitCat.SUBFIELD) - and ver in bit.in_versions - else "0" - ) - for bit in reversed(typ.bits) - ) - + ",\n" - ) - ret += cutil.ifdef_pop(1) - ret += "};\n" - ret += cutil.ifdef_pop(0) - - # validate_* ############################################################### - ret += """ -/* validate_* *****************************************************************/ - -LM_ALWAYS_INLINE static bool _validate_size_net(struct _validate_ctx *ctx, uint32_t n) { -\tif (__builtin_add_overflow(ctx->net_offset, n, &ctx->net_offset)) -\t\t/* If needed-net-size overflowed uint32_t, then -\t\t * there's no way that actual-net-size will live up to -\t\t * that. */ -\t\treturn lib9p_error(ctx->ctx, LINUX_EBADMSG, "message is too short for content"); -\tif (ctx->net_offset > ctx->net_size) -\t\treturn lib9p_error(ctx->ctx, LINUX_EBADMSG, "message is too short for content"); -\treturn false; -} - -LM_ALWAYS_INLINE static bool _validate_size_host(struct _validate_ctx *ctx, size_t n) { -\tif (__builtin_add_overflow(ctx->host_extra, n, &ctx->host_extra)) -\t\t/* If needed-host-size overflowed size_t, then there's -\t\t * no way that actual-net-size will live up to -\t\t * that. */ -\t\treturn lib9p_error(ctx->ctx, LINUX_EBADMSG, "message is too short for content"); -\treturn false; -} - -LM_ALWAYS_INLINE static bool _validate_list(struct _validate_ctx *ctx, - size_t cnt, - _validate_fn_t item_fn, size_t item_host_size) { -\tfor (size_t i = 0; i < cnt; i++) -\t\tif (_validate_size_host(ctx, item_host_size) || item_fn(ctx)) -\t\t\treturn true; -\treturn false; -} - -LM_ALWAYS_INLINE static bool validate_1(struct _validate_ctx *ctx) { return _validate_size_net(ctx, 1); } -LM_ALWAYS_INLINE static bool validate_2(struct _validate_ctx *ctx) { return _validate_size_net(ctx, 2); } -LM_ALWAYS_INLINE static bool validate_4(struct _validate_ctx *ctx) { return _validate_size_net(ctx, 4); } -LM_ALWAYS_INLINE static bool validate_8(struct _validate_ctx *ctx) { return _validate_size_net(ctx, 8); } -""" - - def should_save_value(typ: idl.Struct, member: idl.StructMember) -> bool: - return bool( - member.max or member.val or any(m.cnt == member for m in typ.members) - ) - - for typ in idlutil.topo_sorted(typs): - inline = "LM_FLATTEN" if isinstance(typ, idl.Message) else "LM_ALWAYS_INLINE" - argfn = unused if (isinstance(typ, idl.Struct) and not typ.members) else used - ret += "\n" - ret += cutil.ifdef_push(1, c9util.ver_ifdef(typ.in_versions)) - ret += f"{inline} static bool validate_{typ.typname}(struct _validate_ctx *{argfn('ctx')}) {{\n" - - match typ: - case idl.Number(): - ret += f"\treturn validate_{typ.prim.typname}(ctx);\n" - case idl.Bitfield(): - ret += f"\t if (validate_{typ.static_size}(ctx))\n" - ret += "\t\treturn true;\n" - ret += f"\t{c9util.typename(typ)} mask = {typ.typname}_masks[ctx->ctx->version];\n" - if typ.static_size == 1: - ret += f"\t{c9util.typename(typ)} val = ctx->net_bytes[ctx->net_offset-1];\n" - else: - ret += f"\t{c9util.typename(typ)} val = uint{typ.static_size*8}le_decode(&ctx->net_bytes[ctx->net_offset-{typ.static_size}]);\n" - ret += "\tif (val & ~mask)\n" - ret += f'\t\treturn lib9p_errorf(ctx->ctx, LINUX_EBADMSG, "unknown bits in {typ.typname} bitfield: %#0{typ.static_size}"PRIx{typ.static_size*8}, val & ~mask);\n' - ret += "\treturn false;\n" - case idl.Struct(): # and idl.Message() - if len(typ.members) == 0: - ret += "\treturn false;\n" - ret += "}\n" - continue - - # Pass 1 - declare value variables - for member in typ.members: - if should_save_value(typ, member): - ret += cutil.ifdef_push(2, c9util.ver_ifdef(member.in_versions)) - ret += f"\t{c9util.typename(member.typ)} {member.membname};\n" - ret += cutil.ifdef_pop(1) - - # Pass 2 - declare offset variables - mark_offset: set[str] = set() - for member in typ.members: - for tok in [*member.max.tokens, *member.val.tokens]: - if isinstance(tok, idl.ExprSym) and tok.symname.startswith("&"): - if tok.symname[1:] not in mark_offset: - ret += f"\tuint32_t _{tok.symname[1:]}_offset;\n" - mark_offset.add(tok.symname[1:]) - - # Pass 3 - main pass - ret += "\treturn false\n" - for member in typ.members: - ret += cutil.ifdef_push(2, c9util.ver_ifdef(member.in_versions)) - ret += "\t || " - if member.in_versions != typ.in_versions: - ret += "( " + c9util.ver_cond(member.in_versions) + " && " - if member.cnt is not None: - if member.typ.static_size == 1: # SPECIAL (zerocopy) - ret += f"_validate_size_net(ctx, {member.cnt.membname})" - else: - ret += f"_validate_list(ctx, {member.cnt.membname}, validate_{member.typ.typname}, sizeof({c9util.typename(member.typ)}))" - if typ.typname == "s": # SPECIAL (string) - ret += '\n\t || ({ (!is_valid_utf8_without_nul(&ctx->net_bytes[ctx->net_offset-len], len)) && lib9p_error(ctx->ctx, LINUX_EBADMSG, "message contains invalid UTF-8"); })' - else: - if should_save_value(typ, member): - ret += "(" - if member.membname in mark_offset: - ret += f"({{ _{member.membname}_offset = ctx->net_offset; " - ret += f"validate_{member.typ.typname}(ctx)" - if member.membname in mark_offset: - ret += "; })" - if should_save_value(typ, member): - nbytes = member.static_size - assert nbytes - if nbytes == 1: - ret += f" || ({{ {member.membname} = ctx->net_bytes[ctx->net_offset-1]; false; }}))" - else: - ret += f" || ({{ {member.membname} = uint{nbytes*8}le_decode(&ctx->net_bytes[ctx->net_offset-{nbytes}]); false; }}))" - if member.in_versions != typ.in_versions: - ret += " )" - ret += "\n" - - # Pass 4 - validate ,max= and ,val= constraints - for member in typ.members: - - def lookup_sym(sym: str) -> str: - match sym: - case "end": - return "ctx->net_offset" - case _: - assert sym.startswith("&") - return f"_{sym[1:]}_offset" - - if member.max: - assert member.static_size - nbits = member.static_size * 8 - ret += cutil.ifdef_push(2, c9util.ver_ifdef(member.in_versions)) - ret += f"\t || ({{ uint{nbits}_t max = {c9util.idl_expr(member.max, lookup_sym)}; (((uint{nbits}_t){member.membname}) > max) &&\n" - ret += f'\t lib9p_errorf(ctx->ctx, LINUX_EBADMSG, "{member.membname} value is too large (%"PRIu{nbits}" > %"PRIu{nbits}")", {member.membname}, max); }})\n' - if member.val: - assert member.static_size - nbits = member.static_size * 8 - ret += cutil.ifdef_push(2, c9util.ver_ifdef(member.in_versions)) - ret += f"\t || ({{ uint{nbits}_t exp = {c9util.idl_expr(member.val, lookup_sym)}; (((uint{nbits}_t){member.membname}) != exp) &&\n" - ret += f'\t lib9p_errorf(ctx->ctx, LINUX_EBADMSG, "{member.membname} value is wrong (actual:%"PRIu{nbits}" != correct:%"PRIu{nbits}")", (uint{nbits}_t){member.membname}, exp); }})\n' - - ret += cutil.ifdef_pop(1) - ret += "\t ;\n" - ret += "}\n" - ret += cutil.ifdef_pop(0) - - # unmarshal_* ############################################################## - ret += """ -/* unmarshal_* ****************************************************************/ - -LM_ALWAYS_INLINE static void unmarshal_1(struct _unmarshal_ctx *ctx, uint8_t *out) { -\t*out = ctx->net_bytes[ctx->net_offset]; -\tctx->net_offset += 1; -} - -LM_ALWAYS_INLINE static void unmarshal_2(struct _unmarshal_ctx *ctx, uint16_t *out) { -\t*out = uint16le_decode(&ctx->net_bytes[ctx->net_offset]); -\tctx->net_offset += 2; -} - -LM_ALWAYS_INLINE static void unmarshal_4(struct _unmarshal_ctx *ctx, uint32_t *out) { -\t*out = uint32le_decode(&ctx->net_bytes[ctx->net_offset]); -\tctx->net_offset += 4; -} - -LM_ALWAYS_INLINE static void unmarshal_8(struct _unmarshal_ctx *ctx, uint64_t *out) { -\t*out = uint64le_decode(&ctx->net_bytes[ctx->net_offset]); -\tctx->net_offset += 8; -} -""" - for typ in idlutil.topo_sorted(typs): - inline = "LM_FLATTEN" if isinstance(typ, idl.Message) else "LM_ALWAYS_INLINE" - argfn = unused if (isinstance(typ, idl.Struct) and not typ.members) else used - ret += "\n" - ret += cutil.ifdef_push(1, c9util.ver_ifdef(typ.in_versions)) - ret += f"{inline} static void unmarshal_{typ.typname}(struct _unmarshal_ctx *{argfn('ctx')}, {c9util.typename(typ)} *out) {{\n" - match typ: - case idl.Number(): - ret += f"\tunmarshal_{typ.prim.typname}(ctx, ({c9util.typename(typ.prim)} *)out);\n" - case idl.Bitfield(): - ret += f"\tunmarshal_{typ.prim.typname}(ctx, ({c9util.typename(typ.prim)} *)out);\n" - case idl.Struct(): - ret += "\tmemset(out, 0, sizeof(*out));\n" - - for member in typ.members: - ret += cutil.ifdef_push(2, c9util.ver_ifdef(member.in_versions)) - if member.val: - ret += f"\tctx->net_offset += {member.static_size};\n" - continue - ret += "\t" - - prefix = "\t" - if member.in_versions != typ.in_versions: - ret += "if ( " + c9util.ver_cond(member.in_versions) + " ) " - prefix = "\t\t" - if member.cnt: - if member.in_versions != typ.in_versions: - ret += "{\n" - ret += prefix - if member.typ.static_size == 1: # SPECIAL (string, zerocopy) - ret += f"out->{member.membname} = (char *)&ctx->net_bytes[ctx->net_offset];\n" - ret += f"{prefix}ctx->net_offset += out->{member.cnt.membname};\n" - else: - ret += f"out->{member.membname} = ctx->extra;\n" - ret += f"{prefix}ctx->extra += sizeof(out->{member.membname}[0]) * out->{member.cnt.membname};\n" - ret += f"{prefix}for (typeof(out->{member.cnt.membname}) i = 0; i < out->{member.cnt.membname}; i++)\n" - ret += f"{prefix}\tunmarshal_{member.typ.typname}(ctx, &out->{member.membname}[i]);\n" - if member.in_versions != typ.in_versions: - ret += "\t}\n" - else: - ret += f"unmarshal_{member.typ.typname}(ctx, &out->{member.membname});\n" - ret += cutil.ifdef_pop(1) - ret += "}\n" - ret += cutil.ifdef_pop(0) - - # marshal_* ################################################################ - ret += """ -/* marshal_* ******************************************************************/ - -""" - ret += cutil.macro( - "#define MARSHAL_BYTES_ZEROCOPY(ctx, data, len)\n" - "\tif (ctx->net_iov[ctx->net_iov_cnt-1].iov_len)\n" - "\t\tctx->net_iov_cnt++;\n" - "\tctx->net_iov[ctx->net_iov_cnt-1].iov_base = data;\n" - "\tctx->net_iov[ctx->net_iov_cnt-1].iov_len = len;\n" - "\tctx->net_iov_cnt++;\n" - ) - ret += cutil.macro( - "#define MARSHAL_BYTES(ctx, data, len)\n" - "\tif (!ctx->net_iov[ctx->net_iov_cnt-1].iov_base)\n" - "\t\tctx->net_iov[ctx->net_iov_cnt-1].iov_base = &ctx->net_copied[ctx->net_copied_size];\n" - "\tmemcpy(&ctx->net_copied[ctx->net_copied_size], data, len);\n" - "\tctx->net_copied_size += len;\n" - "\tctx->net_iov[ctx->net_iov_cnt-1].iov_len += len;\n" - ) - ret += cutil.macro( - "#define MARSHAL_U8LE(ctx, val)\n" - "\tif (!ctx->net_iov[ctx->net_iov_cnt-1].iov_base)\n" - "\t\tctx->net_iov[ctx->net_iov_cnt-1].iov_base = &ctx->net_copied[ctx->net_copied_size];\n" - "\tctx->net_copied[ctx->net_copied_size] = val;\n" - "\tctx->net_copied_size += 1;\n" - "\tctx->net_iov[ctx->net_iov_cnt-1].iov_len += 1;\n" - ) - ret += cutil.macro( - "#define MARSHAL_U16LE(ctx, val)\n" - "\tif (!ctx->net_iov[ctx->net_iov_cnt-1].iov_base)\n" - "\t\tctx->net_iov[ctx->net_iov_cnt-1].iov_base = &ctx->net_copied[ctx->net_copied_size];\n" - "\tuint16le_encode(&ctx->net_copied[ctx->net_copied_size], val);\n" - "\tctx->net_copied_size += 2;\n" - "\tctx->net_iov[ctx->net_iov_cnt-1].iov_len += 2;\n" - ) - ret += cutil.macro( - "#define MARSHAL_U32LE(ctx, val)\n" - "\tif (!ctx->net_iov[ctx->net_iov_cnt-1].iov_base)\n" - "\t\tctx->net_iov[ctx->net_iov_cnt-1].iov_base = &ctx->net_copied[ctx->net_copied_size];\n" - "\tuint32le_encode(&ctx->net_copied[ctx->net_copied_size], val);\n" - "\tctx->net_copied_size += 4;\n" - "\tctx->net_iov[ctx->net_iov_cnt-1].iov_len += 4;\n" - ) - ret += cutil.macro( - "#define MARSHAL_U64LE(ctx, val)\n" - "\tif (!ctx->net_iov[ctx->net_iov_cnt-1].iov_base)\n" - "\t\tctx->net_iov[ctx->net_iov_cnt-1].iov_base = &ctx->net_copied[ctx->net_copied_size];\n" - "\tuint64le_encode(&ctx->net_copied[ctx->net_copied_size], val);\n" - "\tctx->net_copied_size += 8;\n" - "\tctx->net_iov[ctx->net_iov_cnt-1].iov_len += 8;\n" - ) - - class OffsetExpr: - static: int - cond: dict[frozenset[str], "OffsetExpr"] - rep: list[tuple[idlutil.Path, "OffsetExpr"]] - - def __init__(self) -> None: - self.static = 0 - self.rep = [] - self.cond = {} - - def add(self, other: "OffsetExpr") -> None: - self.static += other.static - self.rep += other.rep - for k, v in other.cond.items(): - if k in self.cond: - self.cond[k].add(v) - else: - self.cond[k] = v - - def gen_c( - self, - dsttyp: str, - dstvar: str, - root: str, - indent_depth: int, - loop_depth: int, - ) -> str: - oneline: list[str] = [] - multiline = "" - if self.static: - oneline.append(str(self.static)) - for cnt, sub in self.rep: - if not sub.cond and not sub.rep: - if sub.static == 1: - oneline.append(cnt.c_str(root)) - else: - oneline.append(f"({cnt.c_str(root)})*{sub.static}") - continue - loopvar = chr(ord("i") + loop_depth) - multiline += f"{'\t'*indent_depth}for ({c9util.typename(cnt.elems[-1].typ)} {loopvar} = 0; {loopvar} < {cnt.c_str(root)}; {loopvar}++) {{\n" - multiline += sub.gen_c( - "", dstvar, root, indent_depth + 1, loop_depth + 1 - ) - multiline += f"{'\t'*indent_depth}}}\n" - for vers, sub in self.cond.items(): - multiline += cutil.ifdef_push(indent_depth + 1, c9util.ver_ifdef(vers)) - multiline += f"{'\t'*indent_depth}if {c9util.ver_cond(vers)} {{\n" - multiline += sub.gen_c("", dstvar, root, indent_depth + 1, loop_depth) - multiline += f"{'\t'*indent_depth}}}\n" - multiline += cutil.ifdef_pop(indent_depth) - if dsttyp: - if not oneline: - oneline.append("0") - ret = f"{'\t'*indent_depth}{dsttyp} {dstvar} = {' + '.join(oneline)};\n" - elif oneline: - ret = f"{'\t'*indent_depth}{dstvar} += {' + '.join(oneline)};\n" - ret += multiline - return ret - - type OffsetExprRecursion = typing.Callable[[idlutil.Path], idlutil.WalkCmd] - - def get_offset_expr(typ: idl.UserType, recurse: OffsetExprRecursion) -> OffsetExpr: - if not isinstance(typ, idl.Struct): - assert typ.static_size - ret = OffsetExpr() - ret.static = typ.static_size - return ret - - stack: list[tuple[idlutil.Path, OffsetExpr, typing.Callable[[], None]]] - - def pop_root() -> None: - assert False - - def pop_cond() -> None: - nonlocal stack - key = frozenset(stack[-1][0].elems[-1].in_versions) - if key in stack[-2][1].cond: - stack[-2][1].cond[key].add(stack[-1][1]) - else: - stack[-2][1].cond[key] = stack[-1][1] - stack = stack[:-1] - - def pop_rep() -> None: - nonlocal stack - member_path = stack[-1][0] - member = member_path.elems[-1] - assert member.cnt - cnt_path = member_path.parent().add(member.cnt) - stack[-2][1].rep.append((cnt_path, stack[-1][1])) - stack = stack[:-1] - - def handle( - path: idlutil.Path, - ) -> tuple[idlutil.WalkCmd, typing.Callable[[], None] | None]: - nonlocal recurse - - ret = recurse(path) - if ret != idlutil.WalkCmd.KEEP_GOING: - return ret, None - - nonlocal stack - stack_len = len(stack) - - def pop() -> None: - nonlocal stack - nonlocal stack_len - while len(stack) > stack_len: - stack[-1][2]() - - if path.elems: - child = path.elems[-1] - parent = path.elems[-2].typ if len(path.elems) > 1 else path.root - if child.in_versions < parent.in_versions: - stack.append((path, OffsetExpr(), pop_cond)) - if child.cnt: - stack.append((path, OffsetExpr(), pop_rep)) - if not isinstance(child.typ, idl.Struct): - assert child.typ.static_size - stack[-1][1].static += child.typ.static_size - return ret, pop - - stack = [(idlutil.Path(typ), OffsetExpr(), pop_root)] - idlutil.walk(typ, handle) - return stack[0][1] - - def go_to_end(path: idlutil.Path) -> idlutil.WalkCmd: - return idlutil.WalkCmd.KEEP_GOING - - def go_to_tok(name: str) -> typing.Callable[[idlutil.Path], idlutil.WalkCmd]: - def ret(path: idlutil.Path) -> idlutil.WalkCmd: - if len(path.elems) == 1 and path.elems[0].membname == name: - return idlutil.WalkCmd.ABORT - return idlutil.WalkCmd.KEEP_GOING - - return ret - - for typ in typs: - if not ( - isinstance(typ, idl.Message) or typ.typname == "stat" - ): # SPECIAL (include stat) - continue - assert isinstance(typ, idl.Struct) - ret += "\n" - ret += cutil.ifdef_push(1, c9util.ver_ifdef(typ.in_versions)) - ret += f"static bool marshal_{typ.typname}(struct _marshal_ctx *ctx, {c9util.typename(typ)} *val) {{\n" - - # Pass 1 - check size - max_size = max(typ.max_size(v) for v in typ.in_versions) - - if max_size > cutil.UINT32_MAX: # SPECIAL (9P2000.e) - ret += get_offset_expr(typ, go_to_end).gen_c( - "uint64_t", "needed_size", "val->", 1, 0 - ) - ret += "\tif (needed_size > (uint64_t)(ctx->ctx->max_msg_size)) {\n" - else: - ret += get_offset_expr(typ, go_to_end).gen_c( - "uint32_t", "needed_size", "val->", 1, 0 - ) - ret += "\tif (needed_size > ctx->ctx->max_msg_size) {\n" - if isinstance(typ, idl.Message): # SPECIAL (disable for stat) - ret += '\t\tlib9p_errorf(ctx->ctx, LINUX_ERANGE, "%s message too large to marshal into %s limit (limit=%"PRIu32")",\n' - ret += f'\t\t\t"{typ.typname}",\n' - ret += f'\t\t\tctx->ctx->version ? "negotiated" : "{'client' if typ.msgid % 2 == 0 else 'server'}",\n' - ret += "\t\t\tctx->ctx->max_msg_size);\n" - ret += "\t\treturn true;\n" - ret += "\t}\n" - - # Pass 2 - write data - ifdef_depth = 1 - stack: list[tuple[idlutil.Path, bool]] = [(idlutil.Path(typ), False)] - - def handle( - path: idlutil.Path, - ) -> tuple[idlutil.WalkCmd, typing.Callable[[], None]]: - nonlocal ret - nonlocal ifdef_depth - nonlocal stack - stack_len = len(stack) - - def pop() -> None: - nonlocal ret - nonlocal ifdef_depth - nonlocal stack - nonlocal stack_len - while len(stack) > stack_len: - ret += f"{'\t'*(len(stack)-1)}}}\n" - if stack[-1][1]: - ifdef_depth -= 1 - ret += cutil.ifdef_pop(ifdef_depth) - stack = stack[:-1] - - loopdepth = sum(1 for elem in path.elems if elem.cnt) - struct = path.elems[-1].typ if path.elems else path.root - if isinstance(struct, idl.Struct): - offsets: list[str] = [] - for member in struct.members: - if not member.val: - continue - for tok in member.val.tokens: - if not isinstance(tok, idl.ExprSym): - continue - if tok.symname == "end" or tok.symname.startswith("&"): - if tok.symname not in offsets: - offsets.append(tok.symname) - for name in offsets: - name_prefix = "offsetof_" + "".join( - m.membname + "_" for m in path.elems - ) - if name == "end": - if not path.elems: - nonlocal max_size - if max_size > cutil.UINT32_MAX: - ret += f"{'\t'*len(stack)}uint32_t {name_prefix}end = (uint32_t)needed_size;\n" - else: - ret += f"{'\t'*len(stack)}uint32_t {name_prefix}end = needed_size;\n" - continue - recurse: OffsetExprRecursion = go_to_end - else: - assert name.startswith("&") - name = name[1:] - recurse = go_to_tok(name) - expr = get_offset_expr(struct, recurse) - expr_prefix = path.c_str("val->", loopdepth) - if not expr_prefix.endswith(">"): - expr_prefix += "." - ret += expr.gen_c( - "uint32_t", - name_prefix + name, - expr_prefix, - len(stack), - loopdepth, - ) - if path.elems: - child = path.elems[-1] - parent = path.elems[-2].typ if len(path.elems) > 1 else path.root - if child.in_versions < parent.in_versions: - ret += cutil.ifdef_push( - ifdef_depth + 1, c9util.ver_ifdef(child.in_versions) - ) - ifdef_depth += 1 - ret += f"{'\t'*len(stack)}if ({c9util.ver_cond(child.in_versions)}) {{\n" - stack.append((path, True)) - if child.cnt: - cnt_path = path.parent().add(child.cnt) - if child.typ.static_size == 1: # SPECIAL (zerocopy) - if path.root.typname == "stat": # SPECIAL (stat) - ret += f"{'\t'*len(stack)}MARSHAL_BYTES(ctx, {path.c_str('val->')[:-3]}, {cnt_path.c_str('val->')});\n" - else: - ret += f"{'\t'*len(stack)}MARSHAL_BYTES_ZEROCOPY(ctx, {path.c_str('val->')[:-3]}, {cnt_path.c_str('val->')});\n" - return idlutil.WalkCmd.KEEP_GOING, pop - loopvar = chr(ord("i") + loopdepth - 1) - ret += f"{'\t'*len(stack)}for ({c9util.typename(child.cnt.typ)} {loopvar} = 0; {loopvar} < {cnt_path.c_str('val->')}; {loopvar}++) {{\n" - stack.append((path, False)) - if not isinstance(child.typ, idl.Struct): - if child.val: - - def lookup_sym(sym: str) -> str: - nonlocal path - if sym.startswith("&"): - sym = sym[1:] - return ( - "offsetof_" - + "".join(m.membname + "_" for m in path.elems[:-1]) - + sym - ) - - val = c9util.idl_expr(child.val, lookup_sym) - else: - val = path.c_str("val->") - if isinstance(child.typ, idl.Bitfield): - val += f" & {child.typ.typname}_masks[ctx->ctx->version]" - ret += f"{'\t'*len(stack)}MARSHAL_U{child.typ.static_size*8}LE(ctx, {val});\n" - return idlutil.WalkCmd.KEEP_GOING, pop - - idlutil.walk(typ, handle) - del handle - del stack - del max_size - - ret += "\treturn false;\n" - ret += "}\n" - ret += cutil.ifdef_pop(0) - - # function tables ########################################################## - ret += """ -/* function tables ************************************************************/ -""" - - ret += "\n" - ret += f"const uint32_t {c9util.ident('_table_msg_min_size')}[{c9util.ver_enum('NUM')}] = {{\n" - rerror = next(typ for typ in typs if typ.typname == "Rerror") - ret += f"\t[{c9util.ver_enum('unknown')}] = {rerror.min_size('9P2000')},\n" # SPECIAL (initialization) - for ver in sorted(versions): - ret += cutil.ifdef_push(1, c9util.ver_ifdef({ver})) - ret += f"\t[{c9util.ver_enum(ver)}] = {rerror.min_size(ver)},\n" - ret += cutil.ifdef_pop(0) - ret += "};\n" - - ret += "\n" - ret += cutil.macro( - f"#define _MSG_RECV(typ) [{c9util.Ident('TYP_')}##typ/2] = {{\n" - f"\t\t.basesize = sizeof(struct {c9util.ident('msg_')}##typ),\n" - f"\t\t.validate = validate_##typ,\n" - f"\t\t.unmarshal = (_unmarshal_fn_t)unmarshal_##typ,\n" - f"\t}}\n" - ) - ret += cutil.macro( - f"#define _MSG_SEND(typ) [{c9util.Ident('TYP_')}##typ/2] = {{\n" - f"\t\t.marshal = (_marshal_fn_t)marshal_##typ,\n" - f"\t}}\n" - ) - ret += "\n" - ret += msg_table( - "Tmsg", "recv", f"struct {c9util.ident('_recv_tentry')}", (0, 0x100, 2) - ) - ret += "\n" - ret += msg_table( - "Rmsg", "recv", f"struct {c9util.ident('_recv_tentry')}", (1, 0x100, 2) - ) - ret += "\n" - ret += msg_table( - "Tmsg", "send", f"struct {c9util.ident('_send_tentry')}", (0, 0x100, 2) - ) - ret += "\n" - ret += msg_table( - "Rmsg", "send", f"struct {c9util.ident('_send_tentry')}", (1, 0x100, 2) - ) - - ret += f""" -LM_FLATTEN bool {c9util.ident('_stat_validate')}(struct _validate_ctx *ctx) {{ -\treturn validate_stat(ctx); -}} -LM_FLATTEN void {c9util.ident('_stat_unmarshal')}(struct _unmarshal_ctx *ctx, struct {c9util.ident('stat')} *out) {{ -\tunmarshal_stat(ctx, out); -}} -LM_FLATTEN bool {c9util.ident('_stat_marshal')}(struct _marshal_ctx *ctx, struct {c9util.ident('stat')} *val) {{ -\treturn marshal_stat(ctx, val); -}} -""" - - ############################################################################ - return ret - - -# Main ######################################################################### - - def main() -> None: if typing.TYPE_CHECKING: @@ -799,4 +54,4 @@ def main() -> None: ) as fh: fh.write(h.gen_h(versions, typs)) with open(os.path.join(outdir, "9p.generated.c"), "w", encoding="utf-8") as fh: - fh.write(gen_c(versions, typs)) + fh.write(c.gen_c(versions, typs)) diff --git a/lib9p/protogen/c.py b/lib9p/protogen/c.py new file mode 100644 index 0000000..a7e1773 --- /dev/null +++ b/lib9p/protogen/c.py @@ -0,0 +1,200 @@ +# lib9p/protogen/c.py - Generate 9p.generated.c +# +# Copyright (C) 2024-2025 Luke T. Shumaker <lukeshu@lukeshu.com> +# SPDX-License-Identifier: AGPL-3.0-or-later + +import sys + +import idl + +from . import c9util, c_marshal, c_unmarshal, c_validate, cutil + +# This strives to be "general-purpose" in that it just acts on the +# *.9p inputs; but (unfortunately?) there are a few special-cases in +# this script, marked with "SPECIAL". + + +# pylint: disable=unused-variable +__all__ = ["gen_c"] + + +def gen_c(versions: set[str], typs: list[idl.UserType]) -> str: + cutil.ifdef_init() + + ret = f"""/* Generated by `{' '.join(sys.argv)}`. DO NOT EDIT! */ + +#include <stdbool.h> +#include <stddef.h> /* for size_t */ +#include <inttypes.h> /* for PRI* macros */ +#include <string.h> /* for memset() */ + +#include <libmisc/assert.h> + +#include <lib9p/9p.h> + +#include "internal.h" +""" + + # utilities ################################################################ + ret += """ +/* utilities ******************************************************************/ +""" + + id2typ: dict[int, idl.Message] = {} + for msg in [msg for msg in typs if isinstance(msg, idl.Message)]: + id2typ[msg.msgid] = msg + + def msg_table(grp: str, meth: str, tentry: str, rng: tuple[int, int, int]) -> str: + ret = f"const {tentry} {c9util.ident(f'_table_{grp}_{meth}')}[{c9util.ver_enum('NUM')}][{hex(len(range(*rng)))}] = {{\n" + for ver in ["unknown", *sorted(versions)]: + if ver != "unknown": + ret += cutil.ifdef_push(1, c9util.ver_ifdef({ver})) + ret += f"\t[{c9util.ver_enum(ver)}] = {{\n" + for n in range(*rng): + xmsg: idl.Message | None = id2typ.get(n, None) + if xmsg: + if ver == "unknown": # SPECIAL (initialization) + if xmsg.typname not in ["Tversion", "Rversion", "Rerror"]: + xmsg = None + else: + if ver not in xmsg.in_versions: + xmsg = None + if xmsg: + ret += f"\t\t_MSG_{meth.upper()}({xmsg.typname}),\n" + ret += "\t},\n" + ret += cutil.ifdef_pop(0) + ret += "};\n" + return ret + + for v in sorted(versions): + ret += f"#if CONFIG_9P_ENABLE_{v.replace('.', '_')}\n" + ret += ( + f"\t#define _is_ver_{v.replace('.', '_')}(v) (v == {c9util.ver_enum(v)})\n" + ) + ret += "#else\n" + ret += f"\t#define _is_ver_{v.replace('.', '_')}(v) false\n" + ret += "#endif\n" + ret += "\n" + ret += "/**\n" + ret += f" * is_ver(ctx, ver) is essentially `(ctx->ctx->version == {c9util.Ident('VER_')}##ver)`,\n" + ret += f" * but compiles correctly (to `false`) even if `{c9util.Ident('VER_')}##ver` isn't defined\n" + ret += " * (because `!CONFIG_9P_ENABLE_##ver`). This is useful when `||`ing\n" + ret += " * several version checks together.\n" + ret += " */\n" + ret += "#define is_ver(CTX, ver) _is_ver_##ver(CTX->ctx->version)\n" + + # strings ################################################################## + ret += f""" +/* strings ********************************************************************/ + +const char *const {c9util.ident('_table_ver_name')}[{c9util.ver_enum('NUM')}] = {{ +""" + for ver in ["unknown", *sorted(versions)]: + if ver in versions: + ret += cutil.ifdef_push(1, c9util.ver_ifdef({ver})) + ret += f'\t[{c9util.ver_enum(ver)}] = "{ver}",\n' + ret += cutil.ifdef_pop(0) + ret += "};\n" + + ret += "\n" + ret += f"#define _MSG_NAME(typ) [{c9util.Ident('TYP_')}##typ] = #typ\n" + ret += msg_table("msg", "name", "char *const", (0, 0x100, 1)) + + # bitmasks ################################################################# + ret += """ +/* bitmasks *******************************************************************/ +""" + for typ in typs: + if not isinstance(typ, idl.Bitfield): + continue + ret += "\n" + ret += cutil.ifdef_push(1, c9util.ver_ifdef(typ.in_versions)) + ret += f"static const {c9util.typename(typ)} {typ.typname}_masks[{c9util.ver_enum('NUM')}] = {{\n" + verwidth = max(len(ver) for ver in versions) + for ver in sorted(versions): + ret += cutil.ifdef_push(2, c9util.ver_ifdef({ver})) + ret += ( + f"\t[{c9util.ver_enum(ver)}]{' '*(verwidth-len(ver))} = 0b" + + "".join( + ( + "1" + if bit.cat in (idl.BitCat.USED, idl.BitCat.SUBFIELD) + and ver in bit.in_versions + else "0" + ) + for bit in reversed(typ.bits) + ) + + ",\n" + ) + ret += cutil.ifdef_pop(1) + ret += "};\n" + ret += cutil.ifdef_pop(0) + + # validate_* ############################################################### + ret += c_validate.gen_c_validate(versions, typs) + + # unmarshal_* ############################################################## + ret += c_unmarshal.gen_c_unmarshal(versions, typs) + + # marshal_* ################################################################ + ret += c_marshal.gen_c_marshal(versions, typs) + + # function tables ########################################################## + ret += """ +/* function tables ************************************************************/ +""" + + ret += "\n" + ret += f"const uint32_t {c9util.ident('_table_msg_min_size')}[{c9util.ver_enum('NUM')}] = {{\n" + rerror = next(typ for typ in typs if typ.typname == "Rerror") + ret += f"\t[{c9util.ver_enum('unknown')}] = {rerror.min_size('9P2000')},\n" # SPECIAL (initialization) + for ver in sorted(versions): + ret += cutil.ifdef_push(1, c9util.ver_ifdef({ver})) + ret += f"\t[{c9util.ver_enum(ver)}] = {rerror.min_size(ver)},\n" + ret += cutil.ifdef_pop(0) + ret += "};\n" + + ret += "\n" + ret += cutil.macro( + f"#define _MSG_RECV(typ) [{c9util.Ident('TYP_')}##typ/2] = {{\n" + f"\t\t.basesize = sizeof(struct {c9util.ident('msg_')}##typ),\n" + f"\t\t.validate = validate_##typ,\n" + f"\t\t.unmarshal = (_unmarshal_fn_t)unmarshal_##typ,\n" + f"\t}}\n" + ) + ret += cutil.macro( + f"#define _MSG_SEND(typ) [{c9util.Ident('TYP_')}##typ/2] = {{\n" + f"\t\t.marshal = (_marshal_fn_t)marshal_##typ,\n" + f"\t}}\n" + ) + ret += "\n" + ret += msg_table( + "Tmsg", "recv", f"struct {c9util.ident('_recv_tentry')}", (0, 0x100, 2) + ) + ret += "\n" + ret += msg_table( + "Rmsg", "recv", f"struct {c9util.ident('_recv_tentry')}", (1, 0x100, 2) + ) + ret += "\n" + ret += msg_table( + "Tmsg", "send", f"struct {c9util.ident('_send_tentry')}", (0, 0x100, 2) + ) + ret += "\n" + ret += msg_table( + "Rmsg", "send", f"struct {c9util.ident('_send_tentry')}", (1, 0x100, 2) + ) + + ret += f""" +LM_FLATTEN bool {c9util.ident('_stat_validate')}(struct _validate_ctx *ctx) {{ +\treturn validate_stat(ctx); +}} +LM_FLATTEN void {c9util.ident('_stat_unmarshal')}(struct _unmarshal_ctx *ctx, struct {c9util.ident('stat')} *out) {{ +\tunmarshal_stat(ctx, out); +}} +LM_FLATTEN bool {c9util.ident('_stat_marshal')}(struct _marshal_ctx *ctx, struct {c9util.ident('stat')} *val) {{ +\treturn marshal_stat(ctx, val); +}} +""" + + ############################################################################ + return ret diff --git a/lib9p/protogen/c9util.py b/lib9p/protogen/c9util.py index e7ad999..f9c49fc 100644 --- a/lib9p/protogen/c9util.py +++ b/lib9p/protogen/c9util.py @@ -107,3 +107,11 @@ def idl_expr(expr: idl.Expr, lookup_sym: typing.Callable[[str], str]) -> str: case _: assert False return " ".join(ret) + + +def arg_used(arg: str) -> str: + return arg + + +def arg_unused(arg: str) -> str: + return f"LM_UNUSED({arg})" diff --git a/lib9p/protogen/c_marshal.py b/lib9p/protogen/c_marshal.py new file mode 100644 index 0000000..152206d --- /dev/null +++ b/lib9p/protogen/c_marshal.py @@ -0,0 +1,357 @@ +# lib9p/protogen/c_marshal.py - Generate C marshal functions +# +# Copyright (C) 2024-2025 Luke T. Shumaker <lukeshu@lukeshu.com> +# SPDX-License-Identifier: AGPL-3.0-or-later + +import typing + +import idl + +from . import c9util, cutil, idlutil + +# This strives to be "general-purpose" in that it just acts on the +# *.9p inputs; but (unfortunately?) there are a few special-cases in +# this script, marked with "SPECIAL". + + +# pylint: disable=unused-variable +__all__ = ["gen_c_marshal"] + + +def gen_c_marshal(versions: set[str], typs: list[idl.UserType]) -> str: + ret = """ +/* marshal_* ******************************************************************/ + +""" + ret += cutil.macro( + "#define MARSHAL_BYTES_ZEROCOPY(ctx, data, len)\n" + "\tif (ctx->net_iov[ctx->net_iov_cnt-1].iov_len)\n" + "\t\tctx->net_iov_cnt++;\n" + "\tctx->net_iov[ctx->net_iov_cnt-1].iov_base = data;\n" + "\tctx->net_iov[ctx->net_iov_cnt-1].iov_len = len;\n" + "\tctx->net_iov_cnt++;\n" + ) + ret += cutil.macro( + "#define MARSHAL_BYTES(ctx, data, len)\n" + "\tif (!ctx->net_iov[ctx->net_iov_cnt-1].iov_base)\n" + "\t\tctx->net_iov[ctx->net_iov_cnt-1].iov_base = &ctx->net_copied[ctx->net_copied_size];\n" + "\tmemcpy(&ctx->net_copied[ctx->net_copied_size], data, len);\n" + "\tctx->net_copied_size += len;\n" + "\tctx->net_iov[ctx->net_iov_cnt-1].iov_len += len;\n" + ) + ret += cutil.macro( + "#define MARSHAL_U8LE(ctx, val)\n" + "\tif (!ctx->net_iov[ctx->net_iov_cnt-1].iov_base)\n" + "\t\tctx->net_iov[ctx->net_iov_cnt-1].iov_base = &ctx->net_copied[ctx->net_copied_size];\n" + "\tctx->net_copied[ctx->net_copied_size] = val;\n" + "\tctx->net_copied_size += 1;\n" + "\tctx->net_iov[ctx->net_iov_cnt-1].iov_len += 1;\n" + ) + ret += cutil.macro( + "#define MARSHAL_U16LE(ctx, val)\n" + "\tif (!ctx->net_iov[ctx->net_iov_cnt-1].iov_base)\n" + "\t\tctx->net_iov[ctx->net_iov_cnt-1].iov_base = &ctx->net_copied[ctx->net_copied_size];\n" + "\tuint16le_encode(&ctx->net_copied[ctx->net_copied_size], val);\n" + "\tctx->net_copied_size += 2;\n" + "\tctx->net_iov[ctx->net_iov_cnt-1].iov_len += 2;\n" + ) + ret += cutil.macro( + "#define MARSHAL_U32LE(ctx, val)\n" + "\tif (!ctx->net_iov[ctx->net_iov_cnt-1].iov_base)\n" + "\t\tctx->net_iov[ctx->net_iov_cnt-1].iov_base = &ctx->net_copied[ctx->net_copied_size];\n" + "\tuint32le_encode(&ctx->net_copied[ctx->net_copied_size], val);\n" + "\tctx->net_copied_size += 4;\n" + "\tctx->net_iov[ctx->net_iov_cnt-1].iov_len += 4;\n" + ) + ret += cutil.macro( + "#define MARSHAL_U64LE(ctx, val)\n" + "\tif (!ctx->net_iov[ctx->net_iov_cnt-1].iov_base)\n" + "\t\tctx->net_iov[ctx->net_iov_cnt-1].iov_base = &ctx->net_copied[ctx->net_copied_size];\n" + "\tuint64le_encode(&ctx->net_copied[ctx->net_copied_size], val);\n" + "\tctx->net_copied_size += 8;\n" + "\tctx->net_iov[ctx->net_iov_cnt-1].iov_len += 8;\n" + ) + + class OffsetExpr: + static: int + cond: dict[frozenset[str], "OffsetExpr"] + rep: list[tuple[idlutil.Path, "OffsetExpr"]] + + def __init__(self) -> None: + self.static = 0 + self.rep = [] + self.cond = {} + + def add(self, other: "OffsetExpr") -> None: + self.static += other.static + self.rep += other.rep + for k, v in other.cond.items(): + if k in self.cond: + self.cond[k].add(v) + else: + self.cond[k] = v + + def gen_c( + self, + dsttyp: str, + dstvar: str, + root: str, + indent_depth: int, + loop_depth: int, + ) -> str: + oneline: list[str] = [] + multiline = "" + if self.static: + oneline.append(str(self.static)) + for cnt, sub in self.rep: + if not sub.cond and not sub.rep: + if sub.static == 1: + oneline.append(cnt.c_str(root)) + else: + oneline.append(f"({cnt.c_str(root)})*{sub.static}") + continue + loopvar = chr(ord("i") + loop_depth) + multiline += f"{'\t'*indent_depth}for ({c9util.typename(cnt.elems[-1].typ)} {loopvar} = 0; {loopvar} < {cnt.c_str(root)}; {loopvar}++) {{\n" + multiline += sub.gen_c( + "", dstvar, root, indent_depth + 1, loop_depth + 1 + ) + multiline += f"{'\t'*indent_depth}}}\n" + for vers, sub in self.cond.items(): + multiline += cutil.ifdef_push(indent_depth + 1, c9util.ver_ifdef(vers)) + multiline += f"{'\t'*indent_depth}if {c9util.ver_cond(vers)} {{\n" + multiline += sub.gen_c("", dstvar, root, indent_depth + 1, loop_depth) + multiline += f"{'\t'*indent_depth}}}\n" + multiline += cutil.ifdef_pop(indent_depth) + if dsttyp: + if not oneline: + oneline.append("0") + ret = f"{'\t'*indent_depth}{dsttyp} {dstvar} = {' + '.join(oneline)};\n" + elif oneline: + ret = f"{'\t'*indent_depth}{dstvar} += {' + '.join(oneline)};\n" + ret += multiline + return ret + + type OffsetExprRecursion = typing.Callable[[idlutil.Path], idlutil.WalkCmd] + + def get_offset_expr(typ: idl.UserType, recurse: OffsetExprRecursion) -> OffsetExpr: + if not isinstance(typ, idl.Struct): + assert typ.static_size + ret = OffsetExpr() + ret.static = typ.static_size + return ret + + stack: list[tuple[idlutil.Path, OffsetExpr, typing.Callable[[], None]]] + + def pop_root() -> None: + assert False + + def pop_cond() -> None: + nonlocal stack + key = frozenset(stack[-1][0].elems[-1].in_versions) + if key in stack[-2][1].cond: + stack[-2][1].cond[key].add(stack[-1][1]) + else: + stack[-2][1].cond[key] = stack[-1][1] + stack = stack[:-1] + + def pop_rep() -> None: + nonlocal stack + member_path = stack[-1][0] + member = member_path.elems[-1] + assert member.cnt + cnt_path = member_path.parent().add(member.cnt) + stack[-2][1].rep.append((cnt_path, stack[-1][1])) + stack = stack[:-1] + + def handle( + path: idlutil.Path, + ) -> tuple[idlutil.WalkCmd, typing.Callable[[], None] | None]: + nonlocal recurse + + ret = recurse(path) + if ret != idlutil.WalkCmd.KEEP_GOING: + return ret, None + + nonlocal stack + stack_len = len(stack) + + def pop() -> None: + nonlocal stack + nonlocal stack_len + while len(stack) > stack_len: + stack[-1][2]() + + if path.elems: + child = path.elems[-1] + parent = path.elems[-2].typ if len(path.elems) > 1 else path.root + if child.in_versions < parent.in_versions: + stack.append((path, OffsetExpr(), pop_cond)) + if child.cnt: + stack.append((path, OffsetExpr(), pop_rep)) + if not isinstance(child.typ, idl.Struct): + assert child.typ.static_size + stack[-1][1].static += child.typ.static_size + return ret, pop + + stack = [(idlutil.Path(typ), OffsetExpr(), pop_root)] + idlutil.walk(typ, handle) + return stack[0][1] + + def go_to_end(path: idlutil.Path) -> idlutil.WalkCmd: + return idlutil.WalkCmd.KEEP_GOING + + def go_to_tok(name: str) -> typing.Callable[[idlutil.Path], idlutil.WalkCmd]: + def ret(path: idlutil.Path) -> idlutil.WalkCmd: + if len(path.elems) == 1 and path.elems[0].membname == name: + return idlutil.WalkCmd.ABORT + return idlutil.WalkCmd.KEEP_GOING + + return ret + + for typ in typs: + if not ( + isinstance(typ, idl.Message) or typ.typname == "stat" + ): # SPECIAL (include stat) + continue + assert isinstance(typ, idl.Struct) + ret += "\n" + ret += cutil.ifdef_push(1, c9util.ver_ifdef(typ.in_versions)) + ret += f"static bool marshal_{typ.typname}(struct _marshal_ctx *ctx, {c9util.typename(typ)} *val) {{\n" + + # Pass 1 - check size + max_size = max(typ.max_size(v) for v in typ.in_versions) + + if max_size > cutil.UINT32_MAX: # SPECIAL (9P2000.e) + ret += get_offset_expr(typ, go_to_end).gen_c( + "uint64_t", "needed_size", "val->", 1, 0 + ) + ret += "\tif (needed_size > (uint64_t)(ctx->ctx->max_msg_size)) {\n" + else: + ret += get_offset_expr(typ, go_to_end).gen_c( + "uint32_t", "needed_size", "val->", 1, 0 + ) + ret += "\tif (needed_size > ctx->ctx->max_msg_size) {\n" + if isinstance(typ, idl.Message): # SPECIAL (disable for stat) + ret += '\t\tlib9p_errorf(ctx->ctx, LINUX_ERANGE, "%s message too large to marshal into %s limit (limit=%"PRIu32")",\n' + ret += f'\t\t\t"{typ.typname}",\n' + ret += f'\t\t\tctx->ctx->version ? "negotiated" : "{'client' if typ.msgid % 2 == 0 else 'server'}",\n' + ret += "\t\t\tctx->ctx->max_msg_size);\n" + ret += "\t\treturn true;\n" + ret += "\t}\n" + + # Pass 2 - write data + ifdef_depth = 1 + stack: list[tuple[idlutil.Path, bool]] = [(idlutil.Path(typ), False)] + + def handle( + path: idlutil.Path, + ) -> tuple[idlutil.WalkCmd, typing.Callable[[], None]]: + nonlocal ret + nonlocal ifdef_depth + nonlocal stack + stack_len = len(stack) + + def pop() -> None: + nonlocal ret + nonlocal ifdef_depth + nonlocal stack + nonlocal stack_len + while len(stack) > stack_len: + ret += f"{'\t'*(len(stack)-1)}}}\n" + if stack[-1][1]: + ifdef_depth -= 1 + ret += cutil.ifdef_pop(ifdef_depth) + stack = stack[:-1] + + loopdepth = sum(1 for elem in path.elems if elem.cnt) + struct = path.elems[-1].typ if path.elems else path.root + if isinstance(struct, idl.Struct): + offsets: list[str] = [] + for member in struct.members: + if not member.val: + continue + for tok in member.val.tokens: + if not isinstance(tok, idl.ExprSym): + continue + if tok.symname == "end" or tok.symname.startswith("&"): + if tok.symname not in offsets: + offsets.append(tok.symname) + for name in offsets: + name_prefix = "offsetof_" + "".join( + m.membname + "_" for m in path.elems + ) + if name == "end": + if not path.elems: + nonlocal max_size + if max_size > cutil.UINT32_MAX: + ret += f"{'\t'*len(stack)}uint32_t {name_prefix}end = (uint32_t)needed_size;\n" + else: + ret += f"{'\t'*len(stack)}uint32_t {name_prefix}end = needed_size;\n" + continue + recurse: OffsetExprRecursion = go_to_end + else: + assert name.startswith("&") + name = name[1:] + recurse = go_to_tok(name) + expr = get_offset_expr(struct, recurse) + expr_prefix = path.c_str("val->", loopdepth) + if not expr_prefix.endswith(">"): + expr_prefix += "." + ret += expr.gen_c( + "uint32_t", + name_prefix + name, + expr_prefix, + len(stack), + loopdepth, + ) + if path.elems: + child = path.elems[-1] + parent = path.elems[-2].typ if len(path.elems) > 1 else path.root + if child.in_versions < parent.in_versions: + ret += cutil.ifdef_push( + ifdef_depth + 1, c9util.ver_ifdef(child.in_versions) + ) + ifdef_depth += 1 + ret += f"{'\t'*len(stack)}if ({c9util.ver_cond(child.in_versions)}) {{\n" + stack.append((path, True)) + if child.cnt: + cnt_path = path.parent().add(child.cnt) + if child.typ.static_size == 1: # SPECIAL (zerocopy) + if path.root.typname == "stat": # SPECIAL (stat) + ret += f"{'\t'*len(stack)}MARSHAL_BYTES(ctx, {path.c_str('val->')[:-3]}, {cnt_path.c_str('val->')});\n" + else: + ret += f"{'\t'*len(stack)}MARSHAL_BYTES_ZEROCOPY(ctx, {path.c_str('val->')[:-3]}, {cnt_path.c_str('val->')});\n" + return idlutil.WalkCmd.KEEP_GOING, pop + loopvar = chr(ord("i") + loopdepth - 1) + ret += f"{'\t'*len(stack)}for ({c9util.typename(child.cnt.typ)} {loopvar} = 0; {loopvar} < {cnt_path.c_str('val->')}; {loopvar}++) {{\n" + stack.append((path, False)) + if not isinstance(child.typ, idl.Struct): + if child.val: + + def lookup_sym(sym: str) -> str: + nonlocal path + if sym.startswith("&"): + sym = sym[1:] + return ( + "offsetof_" + + "".join(m.membname + "_" for m in path.elems[:-1]) + + sym + ) + + val = c9util.idl_expr(child.val, lookup_sym) + else: + val = path.c_str("val->") + if isinstance(child.typ, idl.Bitfield): + val += f" & {child.typ.typname}_masks[ctx->ctx->version]" + ret += f"{'\t'*len(stack)}MARSHAL_U{child.typ.static_size*8}LE(ctx, {val});\n" + return idlutil.WalkCmd.KEEP_GOING, pop + + idlutil.walk(typ, handle) + del handle + del stack + del max_size + + ret += "\treturn false;\n" + ret += "}\n" + ret += cutil.ifdef_pop(0) + return ret diff --git a/lib9p/protogen/c_unmarshal.py b/lib9p/protogen/c_unmarshal.py new file mode 100644 index 0000000..e17f456 --- /dev/null +++ b/lib9p/protogen/c_unmarshal.py @@ -0,0 +1,92 @@ +# lib9p/protogen/c_unmarshal.py - Generate C unmarshal functions +# +# Copyright (C) 2024-2025 Luke T. Shumaker <lukeshu@lukeshu.com> +# SPDX-License-Identifier: AGPL-3.0-or-later + + +import idl + +from . import c9util, cutil, idlutil + +# This strives to be "general-purpose" in that it just acts on the +# *.9p inputs; but (unfortunately?) there are a few special-cases in +# this script, marked with "SPECIAL". + + +# pylint: disable=unused-variable +__all__ = ["gen_c_unmarshal"] + + +def gen_c_unmarshal(versions: set[str], typs: list[idl.UserType]) -> str: + ret = """ +/* unmarshal_* ****************************************************************/ + +LM_ALWAYS_INLINE static void unmarshal_1(struct _unmarshal_ctx *ctx, uint8_t *out) { +\t*out = ctx->net_bytes[ctx->net_offset]; +\tctx->net_offset += 1; +} + +LM_ALWAYS_INLINE static void unmarshal_2(struct _unmarshal_ctx *ctx, uint16_t *out) { +\t*out = uint16le_decode(&ctx->net_bytes[ctx->net_offset]); +\tctx->net_offset += 2; +} + +LM_ALWAYS_INLINE static void unmarshal_4(struct _unmarshal_ctx *ctx, uint32_t *out) { +\t*out = uint32le_decode(&ctx->net_bytes[ctx->net_offset]); +\tctx->net_offset += 4; +} + +LM_ALWAYS_INLINE static void unmarshal_8(struct _unmarshal_ctx *ctx, uint64_t *out) { +\t*out = uint64le_decode(&ctx->net_bytes[ctx->net_offset]); +\tctx->net_offset += 8; +} +""" + for typ in idlutil.topo_sorted(typs): + inline = "LM_FLATTEN" if isinstance(typ, idl.Message) else "LM_ALWAYS_INLINE" + argfn = ( + c9util.arg_unused + if (isinstance(typ, idl.Struct) and not typ.members) + else c9util.arg_used + ) + ret += "\n" + ret += cutil.ifdef_push(1, c9util.ver_ifdef(typ.in_versions)) + ret += f"{inline} static void unmarshal_{typ.typname}(struct _unmarshal_ctx *{argfn('ctx')}, {c9util.typename(typ)} *out) {{\n" + match typ: + case idl.Number(): + ret += f"\tunmarshal_{typ.prim.typname}(ctx, ({c9util.typename(typ.prim)} *)out);\n" + case idl.Bitfield(): + ret += f"\tunmarshal_{typ.prim.typname}(ctx, ({c9util.typename(typ.prim)} *)out);\n" + case idl.Struct(): + ret += "\tmemset(out, 0, sizeof(*out));\n" + + for member in typ.members: + ret += cutil.ifdef_push(2, c9util.ver_ifdef(member.in_versions)) + if member.val: + ret += f"\tctx->net_offset += {member.static_size};\n" + continue + ret += "\t" + + prefix = "\t" + if member.in_versions != typ.in_versions: + ret += "if ( " + c9util.ver_cond(member.in_versions) + " ) " + prefix = "\t\t" + if member.cnt: + if member.in_versions != typ.in_versions: + ret += "{\n" + ret += prefix + if member.typ.static_size == 1: # SPECIAL (string, zerocopy) + ret += f"out->{member.membname} = (char *)&ctx->net_bytes[ctx->net_offset];\n" + ret += f"{prefix}ctx->net_offset += out->{member.cnt.membname};\n" + else: + ret += f"out->{member.membname} = ctx->extra;\n" + ret += f"{prefix}ctx->extra += sizeof(out->{member.membname}[0]) * out->{member.cnt.membname};\n" + ret += f"{prefix}for (typeof(out->{member.cnt.membname}) i = 0; i < out->{member.cnt.membname}; i++)\n" + ret += f"{prefix}\tunmarshal_{member.typ.typname}(ctx, &out->{member.membname}[i]);\n" + if member.in_versions != typ.in_versions: + ret += "\t}\n" + else: + ret += f"unmarshal_{member.typ.typname}(ctx, &out->{member.membname});\n" + ret += cutil.ifdef_pop(1) + ret += "}\n" + ret += cutil.ifdef_pop(0) + return ret diff --git a/lib9p/protogen/c_validate.py b/lib9p/protogen/c_validate.py new file mode 100644 index 0000000..a3f4348 --- /dev/null +++ b/lib9p/protogen/c_validate.py @@ -0,0 +1,171 @@ +# lib9p/protogen/c_validate.py - Generate C validation functions +# +# Copyright (C) 2024-2025 Luke T. Shumaker <lukeshu@lukeshu.com> +# SPDX-License-Identifier: AGPL-3.0-or-later + + +import idl + +from . import c9util, cutil, idlutil + +# This strives to be "general-purpose" in that it just acts on the +# *.9p inputs; but (unfortunately?) there are a few special-cases in +# this script, marked with "SPECIAL". + + +# pylint: disable=unused-variable +__all__ = ["gen_c_validate"] + + +def should_save_value(typ: idl.Struct, member: idl.StructMember) -> bool: + return bool(member.max or member.val or any(m.cnt == member for m in typ.members)) + + +def gen_c_validate(versions: set[str], typs: list[idl.UserType]) -> str: + ret = """ +/* validate_* *****************************************************************/ + +LM_ALWAYS_INLINE static bool _validate_size_net(struct _validate_ctx *ctx, uint32_t n) { +\tif (__builtin_add_overflow(ctx->net_offset, n, &ctx->net_offset)) +\t\t/* If needed-net-size overflowed uint32_t, then +\t\t * there's no way that actual-net-size will live up to +\t\t * that. */ +\t\treturn lib9p_error(ctx->ctx, LINUX_EBADMSG, "message is too short for content"); +\tif (ctx->net_offset > ctx->net_size) +\t\treturn lib9p_error(ctx->ctx, LINUX_EBADMSG, "message is too short for content"); +\treturn false; +} + +LM_ALWAYS_INLINE static bool _validate_size_host(struct _validate_ctx *ctx, size_t n) { +\tif (__builtin_add_overflow(ctx->host_extra, n, &ctx->host_extra)) +\t\t/* If needed-host-size overflowed size_t, then there's +\t\t * no way that actual-net-size will live up to +\t\t * that. */ +\t\treturn lib9p_error(ctx->ctx, LINUX_EBADMSG, "message is too short for content"); +\treturn false; +} + +LM_ALWAYS_INLINE static bool _validate_list(struct _validate_ctx *ctx, + size_t cnt, + _validate_fn_t item_fn, size_t item_host_size) { +\tfor (size_t i = 0; i < cnt; i++) +\t\tif (_validate_size_host(ctx, item_host_size) || item_fn(ctx)) +\t\t\treturn true; +\treturn false; +} + +LM_ALWAYS_INLINE static bool validate_1(struct _validate_ctx *ctx) { return _validate_size_net(ctx, 1); } +LM_ALWAYS_INLINE static bool validate_2(struct _validate_ctx *ctx) { return _validate_size_net(ctx, 2); } +LM_ALWAYS_INLINE static bool validate_4(struct _validate_ctx *ctx) { return _validate_size_net(ctx, 4); } +LM_ALWAYS_INLINE static bool validate_8(struct _validate_ctx *ctx) { return _validate_size_net(ctx, 8); } +""" + + for typ in idlutil.topo_sorted(typs): + inline = "LM_FLATTEN" if isinstance(typ, idl.Message) else "LM_ALWAYS_INLINE" + argfn = ( + c9util.arg_unused + if (isinstance(typ, idl.Struct) and not typ.members) + else c9util.arg_used + ) + ret += "\n" + ret += cutil.ifdef_push(1, c9util.ver_ifdef(typ.in_versions)) + ret += f"{inline} static bool validate_{typ.typname}(struct _validate_ctx *{argfn('ctx')}) {{\n" + + match typ: + case idl.Number(): + ret += f"\treturn validate_{typ.prim.typname}(ctx);\n" + case idl.Bitfield(): + ret += f"\t if (validate_{typ.static_size}(ctx))\n" + ret += "\t\treturn true;\n" + ret += f"\t{c9util.typename(typ)} mask = {typ.typname}_masks[ctx->ctx->version];\n" + if typ.static_size == 1: + ret += f"\t{c9util.typename(typ)} val = ctx->net_bytes[ctx->net_offset-1];\n" + else: + ret += f"\t{c9util.typename(typ)} val = uint{typ.static_size*8}le_decode(&ctx->net_bytes[ctx->net_offset-{typ.static_size}]);\n" + ret += "\tif (val & ~mask)\n" + ret += f'\t\treturn lib9p_errorf(ctx->ctx, LINUX_EBADMSG, "unknown bits in {typ.typname} bitfield: %#0{typ.static_size}"PRIx{typ.static_size*8}, val & ~mask);\n' + ret += "\treturn false;\n" + case idl.Struct(): # and idl.Message() + if len(typ.members) == 0: + ret += "\treturn false;\n" + ret += "}\n" + continue + + # Pass 1 - declare value variables + for member in typ.members: + if should_save_value(typ, member): + ret += cutil.ifdef_push(2, c9util.ver_ifdef(member.in_versions)) + ret += f"\t{c9util.typename(member.typ)} {member.membname};\n" + ret += cutil.ifdef_pop(1) + + # Pass 2 - declare offset variables + mark_offset: set[str] = set() + for member in typ.members: + for tok in [*member.max.tokens, *member.val.tokens]: + if isinstance(tok, idl.ExprSym) and tok.symname.startswith("&"): + if tok.symname[1:] not in mark_offset: + ret += f"\tuint32_t _{tok.symname[1:]}_offset;\n" + mark_offset.add(tok.symname[1:]) + + # Pass 3 - main pass + ret += "\treturn false\n" + for member in typ.members: + ret += cutil.ifdef_push(2, c9util.ver_ifdef(member.in_versions)) + ret += "\t || " + if member.in_versions != typ.in_versions: + ret += "( " + c9util.ver_cond(member.in_versions) + " && " + if member.cnt is not None: + if member.typ.static_size == 1: # SPECIAL (zerocopy) + ret += f"_validate_size_net(ctx, {member.cnt.membname})" + else: + ret += f"_validate_list(ctx, {member.cnt.membname}, validate_{member.typ.typname}, sizeof({c9util.typename(member.typ)}))" + if typ.typname == "s": # SPECIAL (string) + ret += '\n\t || ({ (!is_valid_utf8_without_nul(&ctx->net_bytes[ctx->net_offset-len], len)) && lib9p_error(ctx->ctx, LINUX_EBADMSG, "message contains invalid UTF-8"); })' + else: + if should_save_value(typ, member): + ret += "(" + if member.membname in mark_offset: + ret += f"({{ _{member.membname}_offset = ctx->net_offset; " + ret += f"validate_{member.typ.typname}(ctx)" + if member.membname in mark_offset: + ret += "; })" + if should_save_value(typ, member): + nbytes = member.static_size + assert nbytes + if nbytes == 1: + ret += f" || ({{ {member.membname} = ctx->net_bytes[ctx->net_offset-1]; false; }}))" + else: + ret += f" || ({{ {member.membname} = uint{nbytes*8}le_decode(&ctx->net_bytes[ctx->net_offset-{nbytes}]); false; }}))" + if member.in_versions != typ.in_versions: + ret += " )" + ret += "\n" + + # Pass 4 - validate ,max= and ,val= constraints + for member in typ.members: + + def lookup_sym(sym: str) -> str: + match sym: + case "end": + return "ctx->net_offset" + case _: + assert sym.startswith("&") + return f"_{sym[1:]}_offset" + + if member.max: + assert member.static_size + nbits = member.static_size * 8 + ret += cutil.ifdef_push(2, c9util.ver_ifdef(member.in_versions)) + ret += f"\t || ({{ uint{nbits}_t max = {c9util.idl_expr(member.max, lookup_sym)}; (((uint{nbits}_t){member.membname}) > max) &&\n" + ret += f'\t lib9p_errorf(ctx->ctx, LINUX_EBADMSG, "{member.membname} value is too large (%"PRIu{nbits}" > %"PRIu{nbits}")", {member.membname}, max); }})\n' + if member.val: + assert member.static_size + nbits = member.static_size * 8 + ret += cutil.ifdef_push(2, c9util.ver_ifdef(member.in_versions)) + ret += f"\t || ({{ uint{nbits}_t exp = {c9util.idl_expr(member.val, lookup_sym)}; (((uint{nbits}_t){member.membname}) != exp) &&\n" + ret += f'\t lib9p_errorf(ctx->ctx, LINUX_EBADMSG, "{member.membname} value is wrong (actual:%"PRIu{nbits}" != correct:%"PRIu{nbits}")", (uint{nbits}_t){member.membname}, exp); }})\n' + + ret += cutil.ifdef_pop(1) + ret += "\t ;\n" + ret += "}\n" + ret += cutil.ifdef_pop(0) + return ret |